/*
 * LibMessage Server Copyright (c) 2018-2025, James Bailie.
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 *     * Redistributions of source code must retain the above copyright
 * notice, this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 * notice, this list of conditions and the following disclaimer in the
 * documentation and/or other materials provided with the distribution.
 *     * The name of James Bailie may not be used to endorse or promote
 * products derived from this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDER AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 */

#include <errno.h>
#include <limits.h>
#include <string.h>
#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>
#include <signal.h>

#include <sys/types.h>
#include <sys/socket.h>
#include <sys/un.h>
#include <sys/uio.h>

#include <netdb.h>

#include <openssl/ssl.h>
#include <openssl/err.h>
#include <openssl/tls1.h>
#include <openssl/x509v3.h>

/*
 * Wire-format limits.  Every frame begins with a 2-byte big-endian header
 * (built by mc_send_frame(), parsed by mc_read_frame()): the top bit
 * (TERMINATOR, 0x8000) marks the final frame of a message and the low
 * 15 bits hold the payload length, which may not exceed MAX_FRAME.
 * MAX_MESSAGE is the default cap on a whole message (8 * MAX_FRAME).
 */
#define MAX_MESSAGE 262128
#define MAX_FRAME 32766
#define TERMINATOR 32768

/* Current cap on message size; mc_set_max_message() updates it. */
unsigned int max_message = MAX_MESSAGE;

/* Text of the most recent error; only declared here, defined elsewhere. */
extern char mc_error[ 256 ];

/* Client TLS context shared by all connections; created by mc_init(). */
SSL_CTX *ctx;

/* One TLS connection: the OpenSSL session and the socket it runs over. */
struct mc_ssl
{
   SSL *ssl;   /* TLS session attached to fd */
   int fd;     /* connected TCP socket */
};

/*
 * Opens a TCP connection to interface:port.  Every address returned by
 * getaddrinfo() is tried in turn until one accepts the connection; sockets
 * for addresses that fail are closed before moving on.
 *
 * Returns the connected socket descriptor, or -1 with a description of the
 * last failure in mc_error.
 */
int mc_connect_to_tcp_server( char *interface, char *port )
{
   int result, fd = -1, saved_errno = 0;
   const char *failed_call = "socket()";
   struct addrinfo hints, *res, *ptr;

   memset( &hints, 0, sizeof( hints ));
   hints.ai_family = PF_UNSPEC;
   hints.ai_socktype = SOCK_STREAM;

   if (( result = getaddrinfo( interface, port, &hints, &res )))
   {
      /* Report through mc_error like every other failure in this library. */
      snprintf( mc_error, sizeof( mc_error ), "getaddrinfo(): %s", gai_strerror( result ));
      return -1;
   }

   for( ptr = res; ptr != NULL; ptr = ptr->ai_next )
   {
      if (( fd = socket( ptr->ai_family, ptr->ai_socktype, ptr->ai_protocol )) == -1 )
      {
         saved_errno = errno;
         failed_call = "socket()";
         continue;
      }

      if ( connect( fd, ptr->ai_addr, ptr->ai_addrlen ) == 0 )
         break;

      /* Save errno before close() can overwrite it, then try the next address. */
      saved_errno = errno;
      failed_call = "connect()";
      close( fd );
      fd = -1;
   }

   freeaddrinfo( res );

   if ( fd == -1 )
      snprintf( mc_error, sizeof( mc_error ), "%s: %s", failed_call, strerror( saved_errno ));

   return fd;
}

/*
 * Records the reason for the most recent OpenSSL failure in mc_error.
 *
 * The earliest entry in this thread's OpenSSL error queue is reported and
 * the rest of the queue is discarded, so leftover entries cannot be blamed
 * for a later, unrelated failure.  If the queue is empty (e.g. some
 * socket-level I/O failures), errno is reported instead.
 * NOTE(review): errno may be stale when the failure was not a system-call
 * error; the final fallback covers the errno == 0 case only.
 */
void mc_log_error()
{
   int saved_errno = errno;
   unsigned long err = ERR_get_error();

   ERR_clear_error();

   if ( err != 0 )
      ERR_error_string_n( err, mc_error, sizeof( mc_error ));
   else if ( saved_errno != 0 )
      snprintf( mc_error, sizeof( mc_error ), "%s", strerror( saved_errno ));
   else
      snprintf( mc_error, sizeof( mc_error ), "%s", "unknown TLS error" );
}

int mc_init()
{
   SSL_load_error_strings();
   SSL_library_init();

   signal( SIGPIPE, SIG_IGN );

   if (( ctx = SSL_CTX_new( TLS_client_method() )) == NULL )
   {
      mc_log_error();
      return -1;
   }

   return 0;
}

void mc_deinit()
{
   SSL_CTX_free( ctx );
}

int mc_verify_hostname( struct mc_ssl *conn, char *hostname )
{
   int err;
   char *ptr;
   X509 *cert;

   if (( cert = SSL_get_peer_certificate( conn->ssl )) == NULL )
   {
      snprintf( mc_error, sizeof( mc_error ), "%s", "server did not present certificate" );
      return 1;
   }

   for( ptr = hostname, err = 0; *ptr; ++ptr )
      ++err;

   if ( X509_check_host( cert, hostname, err, 0, NULL ) != 1 )
   {
      char *certificate = X509_NAME_oneline( X509_get_subject_name( cert ), NULL, 0 );

      snprintf( mc_error, sizeof( mc_error ), "certificate is for: %s", ( certificate == NULL ? "" : certificate ));

      if ( certificate != NULL )
         OPENSSL_free( certificate );

      return 1;
   }

   return 0;
}

/*
 * Returns non-zero when host is a numeric IPv4/IPv6 address rather than a
 * DNS name; RFC 6066 forbids sending address literals as SNI.
 */
static int mc_is_address_literal( const char *host )
{
   struct addrinfo hints, *res;

   memset( &hints, 0, sizeof( hints ));
   hints.ai_family = AF_UNSPEC;
   hints.ai_flags = AI_NUMERICHOST;

   if ( getaddrinfo( host, NULL, &hints, &res ))
      return 0;

   freeaddrinfo( res );
   return 1;
}

/*
 * Opens a TLS connection to host:port.  Unless ignore is non-zero, the
 * server's certificate is checked with mc_verify_hostname().
 *
 * Returns a heap-allocated connection to be released with
 * mc_close_connection(), or NULL with the reason in mc_error.
 */
struct mc_ssl *mc_connect_to_server( char *host, char *port, int ignore )
{
   struct mc_ssl conn, *ret;

   if (( conn.fd = mc_connect_to_tcp_server( host, port )) < 0 )
      return NULL;

   if (( conn.ssl = SSL_new( ctx )) == NULL )
   {
      mc_log_error();
      goto close_socket;
   }

   /* Send SNI so servers hosting several names present the right certificate. */
   if ( host != NULL && ! mc_is_address_literal( host ) && ! SSL_set_tlsext_host_name( conn.ssl, host ))
   {
      mc_log_error();
      goto free_ssl;
   }

   if ( ! SSL_set_fd( conn.ssl, conn.fd ))
   {
      mc_log_error();
      goto free_ssl;
   }

   if ( SSL_connect( conn.ssl ) != 1 )
   {
      mc_log_error();
      goto free_ssl;
   }

   if ( ! ignore && mc_verify_hostname( &conn, host ))
      goto free_ssl;

   if (( ret = malloc( sizeof( *ret ))) == NULL )
   {
      snprintf( mc_error, sizeof( mc_error ), "malloc(): %s", strerror( errno ));
      goto free_ssl;
   }

   *ret = conn;
   return ret;

free_ssl:
   SSL_free( conn.ssl );
close_socket:
   close( conn.fd );
   return NULL;
}

/*
 * Sends a TLS close-notify (best effort), then releases the session, the
 * socket and the connection object.  A NULL conn is ignored, like free(NULL).
 */
void mc_close_connection( struct mc_ssl *conn )
{
   if ( conn == NULL )
      return;

   SSL_shutdown( conn->ssl );

   /* The shutdown result is ignored, so drop any errors it queued. */
   ERR_clear_error();

   SSL_free( conn->ssl );
   close( conn->fd );
   free( conn );
}

/*
 * Sends one frame: a 2-byte big-endian header (TERMINATOR bit set when
 * final is non-zero, plus the payload length) followed by len bytes of data.
 * Header and payload go out in a single SSL_write() call.
 *
 * Returns 0 on success, -2 if SSL_write() returns 0, or -1 on error with
 * the reason in mc_error.
 */
int mc_send_frame( struct mc_ssl *conn, unsigned char final, unsigned int len, unsigned char *data )
{
   unsigned char buffer[ MAX_FRAME + 2 ];
   unsigned int header;
   int r;

   if ( len > MAX_FRAME )
   {
      snprintf( mc_error, sizeof( mc_error ), "illegal frame size: %u", len );
      return -1;
   }

   header = ( final ? TERMINATOR : 0 ) + len;
   buffer[ 0 ] = header / 256;
   buffer[ 1 ] = header % 256;

   /* data may be NULL for an empty frame; memcpy() must not see it then. */
   if ( len > 0 )
      memcpy( &buffer[ 2 ], data, len );

   if (( r = SSL_write( conn->ssl, buffer, len + 2 )) <= 0 )
   {
      if ( ! r )
         return -2;

      mc_log_error();
      return -1;
   }

   return 0;
}

/*
 * Sends len bytes of data as one message.  The data is split into frames of
 * MAX_FRAME bytes; the last frame carries the remainder (possibly zero
 * bytes) and is marked final.
 *
 * Returns 0 on success or the negative result of mc_send_frame(); returns
 * -1 with mc_error set if len exceeds max_message.
 */
int mc_send_message( struct mc_ssl *conn, unsigned int len, unsigned char *data )
{
   unsigned int remaining = len;
   unsigned char *payload = data;
   int result;

   if ( len > max_message )
   {
      snprintf( mc_error, sizeof( mc_error ), "illegal message size: %u", len );
      return -1;
   }

   /* Advance only while a full frame remains, so payload never leaves data's bounds. */
   while ( remaining > MAX_FRAME )
   {
      if (( result = mc_send_frame( conn, 0, MAX_FRAME, payload )))
         return result;

      remaining -= MAX_FRAME;
      payload   += MAX_FRAME;
   }

   return mc_send_frame( conn, 1, remaining, payload );
}

/*
 * Reads exactly len bytes from conn into buf, retrying short reads.
 * Returns 0 on success, -2 when SSL_read() returns 0 (typically the peer
 * closed the connection), or -1 on error with the reason in mc_error.
 */
static int mc_read_exact( struct mc_ssl *conn, unsigned char *buf, unsigned int len )
{
   unsigned int count = 0;
   int r;   /* must be signed: SSL_read() reports errors as negative values */

   while ( count < len )
   {
      if (( r = SSL_read( conn->ssl, &buf[ count ], ( int )( len - count ))) <= 0 )
      {
         if ( r == 0 )
            return -2;

         mc_log_error();
         return -1;
      }

      count += ( unsigned int ) r;
   }

   return 0;
}

/*
 * Reads one frame into buffer (at least MAX_FRAME bytes): first the 2-byte
 * header, then the payload, which overwrites the header bytes.  *final is
 * set non-zero when the header's TERMINATOR bit is set.
 *
 * Returns the payload length, or -1/-2 as described for mc_read_exact();
 * an oversize length in the header yields -1 with mc_error set.
 */
int mc_read_frame( struct mc_ssl *conn, unsigned char *buffer, unsigned char *final )
{
   unsigned int len;
   int r;

   if (( r = mc_read_exact( conn, buffer, 2 )))
      return r;

   if (( len = (( buffer[ 0 ] & 0x7F ) << 8 ) + buffer[ 1 ] ) > MAX_FRAME )
   {
      snprintf( mc_error, sizeof( mc_error ), "%s", "fatal error: oversize frame" );
      return -1;
   }

   /* Capture the flag before the payload read overwrites buffer[ 0 ]. */
   *final = buffer[ 0 ] & 0x80;

   if (( r = mc_read_exact( conn, buffer, len )))
      return r;

   return ( int ) len;
}

int mc_read_message( struct mc_ssl *conn, unsigned char **message )
{
   unsigned char buffer[ MAX_FRAME + 2 ], final;
   int r, len;

   *message = NULL;
   len = 0;

   for( ; ; )
   {
      r = mc_read_frame( conn, buffer, &final );

      if ( r < 0 )
      {
         if ( *message != NULL )
         {
            free( *message );
            *message = NULL;
         }

         return r;
      }

      if ( r )
      {
         if ( *message == NULL )
         {
            if (( *message = malloc( r )) == NULL )
            {
               snprintf( mc_error, sizeof( mc_error ), "malloc(): %s", strerror( errno ));
               return -1;
            }
         }
         else if (( len + r ) > max_message )
         {
            snprintf( mc_error, sizeof( mc_error ), "%s", "fatal error: oversize message" );
            free( *message );
            *message = NULL;
            return -1;
         }
         else if (( *message = realloc( *message, len + r )) == NULL )
         {
            snprintf( mc_error, sizeof( mc_error ), "realloc(): %s", strerror( errno ));
            return -1;
         }

         bcopy( buffer, &(( *message )[ len ]), r );
         len += r;
      }

      if ( final )
         break;
   }

   return len;
}

/*
 * Sets the maximum message size accepted by mc_send_message() and
 * mc_read_message().  max must be a non-zero multiple of MAX_FRAME and small
 * enough that a message length plus one more frame still fits in an int,
 * because mc_read_message() returns the length as an int.
 *
 * Returns the new limit, or 0 (limit unchanged) if max is rejected.
 */
unsigned int mc_set_max_message( unsigned int max )
{
   if ( ! max || ( max % MAX_FRAME ) || max > ( unsigned int ) INT_MAX - MAX_FRAME )
      return 0;

   return max_message = max;
}
