/*
 * Multiplexing File Server Copyright (c) 2018, 2019, James Bailie.
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 *     * Redistributions of source code must retain the above copyright
 * notice, this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 * notice, this list of conditions and the following disclaimer in the
 * documentation and/or other materials provided with the distribution.
 *     * The name of James Bailie may not be used to endorse or promote
 * products derived from this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDER AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 */

#include <errno.h>
#include <string.h>
#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>
#include <signal.h>

#include <sys/types.h>
#include <sys/socket.h>
#include <sys/un.h>
#include <sys/uio.h>

#include <arpa/inet.h>
#include <netdb.h>

#include <openssl/ssl.h>
#include <openssl/err.h>
#include <openssl/tls1.h>
#include <openssl/x509v3.h>

/*
 * Largest payload, in bytes, a single frame may carry.  On the wire a frame
 * is a 2-byte big-endian length followed by that many payload bytes (see
 * fc_send_frame() and fc_read_frame()).
 */
#define fc_max_frame 1056

/*
 * Receives the text of the most recent error reported by the fc_* functions.
 * Declared extern here; the definition lives in another translation unit.
 */
extern char fc_error[ 256 ];

/*
 * Client TLS context shared by every connection: created by fc_init(),
 * used by fc_connect_to_server(), released by fc_deinit().
 */
SSL_CTX *ctx;

/* One TLS connection: the OpenSSL session and the socket it runs over. */
struct fc_ssl
{
   SSL *ssl;   /* session attached to fd with SSL_set_fd() */
   int fd;     /* connected TCP socket descriptor */
};

/*
 * Opens a TCP connection to interface:port.
 *
 * Every address returned by getaddrinfo() is tried in order until one
 * accepts the connection.  Sockets from failed attempts are closed and the
 * address list is always freed.
 *
 * Returns the connected socket descriptor, or -1 with fc_error describing
 * the last failure (getaddrinfo(), socket() or connect()).
 */
int fc_connect_to_tcp_server( char *interface, char *port )
{
   int result, fd = -1, saved_errno = 0;
   const char *failed = "socket";
   struct addrinfo hints, *res, *ptr;

   memset( &hints, 0, sizeof( hints ));
   hints.ai_family = AF_UNSPEC;
   hints.ai_socktype = SOCK_STREAM;

   if (( result = getaddrinfo( interface, port, &hints, &res )))
   {
      /* Report through fc_error like every other failure path. */
      snprintf( fc_error, sizeof( fc_error ), "getaddrinfo(): %s\n", gai_strerror( result ));
      return -1;
   }

   for( ptr = res; ptr != NULL; ptr = ptr->ai_next )
   {
      if (( fd = socket( ptr->ai_family, ptr->ai_socktype, ptr->ai_protocol )) == -1 )
      {
         saved_errno = errno;
         failed = "socket";
         continue;
      }

      if ( connect( fd, ptr->ai_addr, ptr->ai_addrlen ) == 0 )
         break;

      /* Capture errno before close() can overwrite it, then try the next address. */
      saved_errno = errno;
      failed = "connect";
      close( fd );
      fd = -1;
   }

   freeaddrinfo( res );

   if ( fd == -1 )
      snprintf( fc_error, sizeof( fc_error ), "%s(): %s\n", failed, strerror( saved_errno ));

   return fd;
}

void fc_log_error()
{
   unsigned long err;

   err = ERR_get_error();
   ERR_error_string_n( err, fc_error, sizeof( fc_error ));
}

/*
 * Checks that the certificate presented by the peer of conn matches hostname.
 *
 * Returns 0 on a match.  Returns 1, with fc_error set, if hostname is NULL,
 * the server sent no certificate, or the certificate does not match (in that
 * case fc_error names the certificate's subject).
 *
 * NOTE(review): this is a name check only.  Nothing in this function consults
 * SSL_get_verify_result(), so chain trust is not established here.
 */
int fc_verify_hostname( struct fc_ssl *conn, char *hostname )
{
   X509 *cert;
   int result = 0;

   if ( hostname == NULL )
   {
      snprintf( fc_error, sizeof( fc_error ), "%s", "no hostname to verify" );
      return 1;
   }

   /* Returns a new reference that must be released with X509_free(). */
   if (( cert = SSL_get_peer_certificate( conn->ssl )) == NULL )
   {
      snprintf( fc_error, sizeof( fc_error ), "%s", "server did not present certificate" );
      return 1;
   }

   if ( X509_check_host( cert, hostname, strlen( hostname ), 0, NULL ) != 1 )
   {
      char *subject = X509_NAME_oneline( X509_get_subject_name( cert ), NULL, 0 );

      snprintf( fc_error, sizeof( fc_error ), "certificate is for: %s", ( subject == NULL ? "" : subject ));
      OPENSSL_free( subject );   /* no-op on NULL */
      result = 1;
   }

   X509_free( cert );
   return result;
}

/*
 * Returns nonzero when host is a numeric IPv4 or IPv6 address.  RFC 6066
 * forbids IP literals in the SNI host_name, so such hosts are not sent.
 */
static int fc_is_ip_literal( const char *host )
{
   unsigned char addr[ 16 ];   /* large enough for an IPv6 address */

   return inet_pton( AF_INET, host, addr ) == 1 || inet_pton( AF_INET6, host, addr ) == 1;
}

/*
 * Opens a TLS connection to hostname:port using the shared ctx.
 *
 * hostname is sent as SNI, unless it is an IP literal, so that virtual-hosted
 * servers present the matching certificate.  Unless ignore is nonzero, the
 * server certificate must then match hostname (see fc_verify_hostname()).
 *
 * Returns a heap-allocated connection that the caller releases with
 * fc_close_connection().  Returns NULL on failure, with fc_error set.
 */
struct fc_ssl *fc_connect_to_server( char *hostname, char *port, int ignore )
{
   struct fc_ssl conn, *ret;

   if (( conn.fd = fc_connect_to_tcp_server( hostname, port )) < 0 )
      return NULL;

   if (( conn.ssl = SSL_new( ctx )) == NULL )
   {
      fc_log_error();
      goto close_fd;
   }

   if ( ! SSL_set_fd( conn.ssl, conn.fd ))
   {
      fc_log_error();
      goto free_ssl;
   }

   if ( hostname != NULL && ! fc_is_ip_literal( hostname )
        && ! SSL_set_tlsext_host_name( conn.ssl, hostname ))
   {
      fc_log_error();
      goto free_ssl;
   }

   if ( SSL_connect( conn.ssl ) != 1 )
   {
      fc_log_error();
      goto free_ssl;
   }

   if ( ! ignore && fc_verify_hostname( &conn, hostname ))
      goto free_ssl;

   if (( ret = malloc( sizeof( *ret ))) == NULL )
   {
      snprintf( fc_error, sizeof( fc_error ), "malloc(): %s", strerror( errno ));
      goto free_ssl;
   }

   *ret = conn;
   return ret;

free_ssl:
   SSL_free( conn.ssl );
close_fd:
   close( conn.fd );
   return NULL;
}

int fc_init()
{
   SSL_load_error_strings();
   SSL_library_init();

   signal( SIGPIPE, SIG_IGN );

   *fc_error = '\0';

   if (( ctx = SSL_CTX_new( TLS_client_method() )) == NULL )
   {
      fc_log_error();
      exit( 1 );
   }

   return 0;
}

void fc_deinit()
{
   SSL_CTX_free( ctx );
}

/*
 * Shuts down and releases a connection returned by fc_connect_to_server():
 * sends a TLS close_notify, frees the SSL session, closes the socket and
 * frees conn itself.  A NULL conn is ignored.
 */
void fc_close_connection( struct fc_ssl *conn )
{
   if ( conn == NULL )
      return;

   /* Best-effort shutdown.  If it fails (e.g. the peer already went away),
      drop the queued errors so they are not reported by a later, unrelated
      failure. */
   SSL_shutdown( conn->ssl );
   ERR_clear_error();

   SSL_free( conn->ssl );
   close( conn->fd );
   free( conn );
}

/*
 * Sends len bytes from data as one frame: a 2-byte big-endian length
 * followed by the payload, written with a single SSL_write().
 *
 * len must be in 1..fc_max_frame.
 *
 * Returns:
 *    0  success;
 *   -1  illegal size or write error, with fc_error set;
 *   -2  SSL_write() returned 0.
 */
int fc_send_frame( struct fc_ssl *conn, unsigned char *data, unsigned int len )
{
   unsigned char buffer[ fc_max_frame + 2 ];
   int r;

   if ( len < 1 || len > fc_max_frame )
   {
      snprintf( fc_error, sizeof( fc_error ), "illegal frame size: %u", len );
      return -1;
   }

   buffer[ 0 ] = ( len >> 8 ) & 0xff;
   buffer[ 1 ] = len & 0xff;

   memcpy( &buffer[ 2 ], data, len );

   if (( r = SSL_write( conn->ssl, buffer, ( int )( len + 2 ))) <= 0 )
   {
      if ( r == 0 )
         return -2;

      fc_log_error();
      return -1;
   }

   return 0;
}

/*
 * Reads exactly len bytes from conn into buffer, looping over short reads.
 * Returns 1 on success, 0 if SSL_read() returned 0, or -1 on error with
 * fc_error set.
 */
static int fc_read_exact( struct fc_ssl *conn, unsigned char *buffer, unsigned int len )
{
   unsigned int count = 0;
   int r;   /* must be signed: SSL_read() reports errors as negative values */

   while ( count < len )
   {
      if (( r = SSL_read( conn->ssl, &buffer[ count ], ( int )( len - count ))) <= 0 )
      {
         if ( r == 0 )
            return 0;

         fc_log_error();
         return -1;
      }

      count += ( unsigned int )r;
   }

   return 1;
}

/*
 * Reads one frame (2-byte big-endian length plus payload) from conn and
 * stores the payload at the start of buffer, which must hold at least
 * fc_max_frame bytes.
 *
 * Returns:
 *   >0  the payload length;
 *    0  SSL_read() returned 0;
 *   -1  error or a length outside 1..fc_max_frame, with fc_error set.
 *
 * A zero-length frame is rejected: fc_send_frame() never produces one, and
 * returning 0 for it would be indistinguishable from end of stream.
 */
int fc_read_frame( struct fc_ssl *conn, unsigned char *buffer )
{
   unsigned int len;
   int r;

   if (( r = fc_read_exact( conn, buffer, 2 )) <= 0 )
      return r;

   len = ( unsigned int )(( buffer[ 0 ] << 8 ) | buffer[ 1 ] );

   if ( len < 1 || len > fc_max_frame )
   {
      snprintf( fc_error, sizeof( fc_error ), "received illegal frame size: %u", len );
      return -1;
   }

   /* The header bytes are no longer needed; the payload overwrites them. */
   if (( r = fc_read_exact( conn, buffer, len )) <= 0 )
      return r;

   return ( int )len;
}
