/* -*- Mode: C; c-file-style: "bsd" -*- */
#include <string.h>
#include <time.h>
#include <openssl/bn.h>
#include <openssl/sha.h>
#include <openssl/dsa.h>
#include "pgpdsa.h"

#define PGP_DSA_TEST

#define BITS2BYTES( x ) ( ((x)+7) / 8 )

#define PGP_PKT_SIGNATURE 2
#define PGP_PKT_PRIVATE_KEY 5
#define PGP_PKT_PUBLIC_KEY 6
#define PGP_PKT_LITERAL 11

#define PGP_KEY_VERS_1 1
#define PGP_KEY_VERS_2 2
#define PGP_KEY_VERS_3 3
#define PGP_KEY_VERS_4 4

/*
 * Old-format "length of length" codes, stored in the low two bits of the
 * first header octet (CTB).  As used by pgp_read_ctb, the first three
 * select a 1, 2 or 4 octet length field following the CTB, and
 * PGP_LOL_VAR means no length field is present (length reported as -1).
 */

#define PGP_LOL_1 0
#define PGP_LOL_2 1
#define PGP_LOL_4 2
#define PGP_LOL_VAR 3

/* Helpers for the first octet of an old-format packet header (the CTB). */

/* Non-zero when bit 6 of the CTB is clear. */
#define PGP_IS_OLD_CTB( type ) ( ( (type) & 0x40 ) == 0 )
/* Build a CTB: bit 7 set, packet type shifted into bits 2 and up, length
   code in the low two bits.  Both arguments are parenthesised so that
   arbitrary expressions (e.g. `a ? b : c`) expand correctly. */
#define PGP_WRITE_OLD_CTB( type, lol ) ( 0x80 | ((type) << 2) | (lol) )
/* Extract the packet type field from a CTB. */
#define PGP_READ_OLD_CTB_TYPE( type ) ( ( (type) >> 2 ) & 0x1F )
/* Extract the length-of-length code (low two bits) from a CTB. */
#define PGP_READ_OLD_CTB_LOL( type ) ( (type) & 0x3 )

#define PGP_SIG_VERS_1 1
#define PGP_SIG_VERS_2 2
#define PGP_SIG_VERS_3 3
#define PGP_SIG_VERS_4 4

/* signature types */

#define PGP_SIG_TYPE_BINARY 0x00
#define PGP_SIG_TYPE_TEXT 0x01

/* algorithm types */

#define PGP_PK_ALG_DSA 17
#define PGP_HASH_ALG_SHA1 2

/*
 * TRY( x ): helper macro to reduce code repetition.
 *
 * Evaluates x once and stores the result in `ret`.  If the result is
 * negative it is returned from the enclosing function immediately;
 * otherwise `ptr` is advanced by that many bytes.  The enclosing
 * function must therefore have an int `ret` and a pointer `ptr` in scope.
 *
 * NOTE: because the macro returns directly, anything the enclosing
 * function acquired beforehand is not released on the error path.
 */

#define TRY(x) \
    do{ ret=(x); if (ret<0) {return ret;} ptr+=ret; } while(0)

int pgp_read_ctb( const byte* ptr, int* type, int* length )
{
    int lol;
    if ( PGP_IS_OLD_CTB( *ptr ) )
    {
	*type = PGP_READ_OLD_CTB_TYPE( *ptr );
	lol = PGP_READ_OLD_CTB_LOL( *ptr );
	ptr++;
	switch ( lol ) 
	{
	case PGP_LOL_1: 
	    *length = *ptr; 
	    return 2;
	case PGP_LOL_2: 
	    *length = *ptr++; 
	    *length = *length << 8 | *ptr; 
	    return 3;
	case PGP_LOL_4: 
	    *length = *ptr++; 
	    *length = *length << 8 | *ptr; 
	    *length = *length << 8 | *ptr; 
	    *length = *length << 8 | *ptr; 
	    return 5;
	case PGP_LOL_VAR:
	    *length = -1;
	    return 1;
	}
    }
    else
    {
/* only copes with old CTB format for now */
/* fscking Colin Plumb, manual huffman encoder */
	return PGP_ERR_NOT_IMPLEMENTED;
    }
}

/*
 * Write an old-format packet header for a packet of the given type.
 *
 * The shortest length encoding that can hold `length` is chosen; a
 * negative length writes a header with no length field at all.
 * Returns the number of bytes written (1, 2, 3 or 5).
 */
int pgp_write_old_ctb( byte* ptr, int type, int length )
{
    if ( length < 0 ) {
        *ptr++ = PGP_WRITE_OLD_CTB( type, PGP_LOL_VAR );
        return 1;
    } else if ( length < 0x100 ) {
        *ptr++ = PGP_WRITE_OLD_CTB( type, PGP_LOL_1 );
        *ptr++ = length & 0xFF;
        return 2;
    } else if ( length < 0x10000 ) {
        *ptr++ = PGP_WRITE_OLD_CTB( type, PGP_LOL_2 );
        *ptr++ = (length >> 8 ) & 0xFF;
        *ptr++ = length & 0xFF;
        return 3;
    } else {
        /* four length octets follow, so the header must say PGP_LOL_4 */
        *ptr++ = PGP_WRITE_OLD_CTB( type, PGP_LOL_4 );
        *ptr++ = (length >> 24 ) & 0xFF;
        *ptr++ = (length >> 16 ) & 0xFF;
        *ptr++ = (length >> 8 ) & 0xFF;
        *ptr++ = length & 0xFF;
        return 5;
    }
}

/*
 * Read a 4-octet big-endian timestamp from ptr into *unixtime.
 * Returns the number of bytes consumed (always 4).
 */
int pgp_read_time( const byte* ptr, uint32_t* unixtime )
{
    /* each octet is taken from its own position; this is the inverse of
       pgp_write_time */
    *unixtime = (uint32_t) ptr[ 0 ] << 24 |
                (uint32_t) ptr[ 1 ] << 16 |
                (uint32_t) ptr[ 2 ] << 8 |
                (uint32_t) ptr[ 3 ];
    return 4;
}

/*
 * Store unixtime at ptr as four octets, most significant first.
 * Returns the number of bytes written (always 4).
 */
int pgp_write_time( byte* ptr, uint32_t unixtime )
{
    int shift;
    int i = 0;

    for ( shift = 24; shift >= 0; shift -= 8 )
    {
        ptr[ i++ ] = ( unixtime >> shift ) & 0xFF;
    }
    return 4;
}

/*
 * Read a length-prefixed string: one length octet followed by that many
 * bytes.  The bytes are copied to str and NUL terminated, so str must
 * have room for the length plus one; *len receives the length.
 * Returns the number of input bytes consumed (length + 1).
 */
int pgp_read_pascal_string( const byte* ptr, char* str, int* len )
{
    const int n = ptr[ 0 ];

    *len = n;
    memcpy( str, ptr + 1, n );
    str[ n ] = '\0';
    return n + 1;
}

/*
 * Write str as a length-prefixed string: one length octet followed by
 * the characters, without the terminating NUL.
 * Returns the number of bytes written, or PGP_ERR_FAIL when the string
 * is longer than 255 characters.
 */
int pgp_write_pascal_string( byte* ptr, char* str )
{
    int n = strlen( str );

    if ( n <= 255 )
    {
        ptr[ 0 ] = n;
        memcpy( ptr + 1, str, n );
        return n + 1;
    }
    return PGP_ERR_FAIL;
}

/*
 * Read a multi-precision integer: a 2-octet big-endian bit count followed
 * by the magnitude in ceil(bits/8) big-endian octets.
 *
 * The value is stored in *b; BN_bin2bn is handed the current *b, so an
 * existing BIGNUM is reused and a NULL *b causes a new one to be
 * allocated.  On conversion failure *b is left untouched.
 *
 * Returns the number of bytes consumed, or PGP_ERR_FAIL.
 */
int pgp_read_mpi( const byte* ptr, BIGNUM** b )
{
    int bits;
    int len;
    BIGNUM* bn;

    bits = ptr[ 0 ] << 8 | ptr[ 1 ];
    len = BITS2BYTES( bits );

    /* do not clobber the caller's pointer with NULL on failure */
    bn = BN_bin2bn( ptr + 2, len, *b );
    if ( bn == NULL ) { return PGP_ERR_FAIL; }

    *b = bn;
    return len + 2;
}

/*
 * Write b as a multi-precision integer: a 2-octet big-endian bit count
 * followed by the bytes produced by BN_bn2bin.
 * Returns the number of bytes written, PGP_ERR_FAIL when the bit count
 * does not fit in 16 bits, or the (negative) BN_bn2bin result.
 */
int pgp_write_mpi( byte* ptr, BIGNUM* b )
{
    int nbytes;
    const int nbits = BN_num_bits( b );

    if ( nbits < 0 || nbits > 0xFFFF ) { return PGP_ERR_FAIL; }

    ptr[ 0 ] = ( nbits >> 8 ) & 0xFF;
    ptr[ 1 ] = nbits & 0xFF;

    nbytes = BN_bn2bin( b, ptr + 2 );
    if ( nbytes < 0 ) { return nbytes; }

    return nbytes + 2;
}

/*
 * Parse a DSA key packet (public or private) starting at ptr.
 *
 * Only old-format headers, version 4 keys, the DSA algorithm and
 * unencrypted private keys are accepted; anything else yields
 * PGP_ERR_NOT_IMPLEMENTED.  The p, q, g and public value (and, for a
 * private key packet, the private value) are stored in dsa.  keyid
 * receives the trailing PGP_KEYID_SIZE bytes of a SHA-1 digest computed
 * over the public part of the key (see below).
 *
 * Returns the number of bytes consumed, or a negative PGP_ERR_* code
 * (PGP_ERR_BAD_MSG when the private key checksum does not match).
 */
int pgp_read_dsa_key( const byte* ptr, DSA* dsa, byte keyid[ PGP_KEYID_SIZE ] )
{
    int ret;
    int type;
    int len;
    int salt_method;
    uint32_t create_unixtime;
    int pk_alg;
    int key_vers;
    uint16_t checksum_expected;
    uint16_t checksum;
    byte str[3];
    const byte* start = ptr;
    const byte* checksum_ptr;
    const byte* keyid_start;
    int keyid_len;
    SHA_CTX sha1;
    byte digest[ SHA_DIGEST_LENGTH ];
    
    /* packet header; the length it reports is not used further */
    TRY( pgp_read_ctb( ptr, &type, &len ) );
    if ( type != PGP_PKT_PRIVATE_KEY && 
	 type != PGP_PKT_PUBLIC_KEY )
    {
	return PGP_ERR_NOT_IMPLEMENTED;
    }

    /* the hashed key material starts at the version octet */
    keyid_start = ptr;
    key_vers = *ptr++;
    if ( key_vers != PGP_KEY_VERS_4 )
    {
	return PGP_ERR_NOT_IMPLEMENTED;
    }

    /* creation time is parsed but otherwise ignored */
    TRY( pgp_read_time( ptr, &create_unixtime ) );
    pk_alg = *ptr++;
    if ( pk_alg != PGP_PK_ALG_DSA )
    {
	return PGP_ERR_NOT_IMPLEMENTED;
    }
    /* public parameters, in the order p, q, g, y */
    TRY( pgp_read_mpi( ptr, &dsa->p ) );
    TRY( pgp_read_mpi( ptr, &dsa->q ) );
    TRY( pgp_read_mpi( ptr, &dsa->g ) );
    TRY( pgp_read_mpi( ptr, &dsa->pub_key ) );

    /* key id: SHA-1 over the octet 0x99, a 2-octet big-endian length and
       the bytes from the version octet through the last public MPI; the
       final PGP_KEYID_SIZE bytes of the digest are kept */
    SHA1_Init( &sha1 );
    str[ 0 ] = 0x99;
    keyid_len = ptr - keyid_start;
    str[ 1 ] = (keyid_len >> 8) & 0xff;
    str[ 2 ] = keyid_len & 0xff;
    SHA1_Update( &sha1, str, 3 );
    SHA1_Update( &sha1, keyid_start, keyid_len );
    SHA1_Final( digest, &sha1 );
    memcpy( keyid, digest + SHA_DIGEST_LENGTH - PGP_KEYID_SIZE, PGP_KEYID_SIZE );

    if ( type == PGP_PKT_PRIVATE_KEY )
    {
	salt_method = *ptr++;
	/* can't cope with encrypted keys yet */
	if ( salt_method != 0 ) 
	{
	    return PGP_ERR_NOT_IMPLEMENTED;
	}

	/* remember where the private MPI begins so it can be summed */
	checksum_ptr = ptr;
	TRY( pgp_read_mpi( ptr, &dsa->priv_key ) );

	/* 16-bit sum (wrapping, since checksum is a uint16_t) of every
	   byte of the private MPI, including its 2-octet bit count */
	for ( checksum = 0; checksum_ptr < ptr; checksum_ptr++ )
	{
	    checksum += *checksum_ptr;
	}

	/* the stored checksum follows the MPI, most significant octet first */
	checksum_expected = *ptr++;
	checksum_expected = checksum_expected << 8 | *ptr++;
	
	if ( checksum_expected != checksum ) 
	{
	    return PGP_ERR_BAD_MSG;
	}
    }

    return ptr - start;
}

/*
 * Write a version 3 DSA/SHA-1 binary signature packet over in[0..len)
 * to ptr, followed by a literal packet that carries the data itself.
 *
 * The hash covers the input followed by 5 extra bytes: the signature
 * type and the 4-octet creation time.  The signature packet length is
 * derived from the actual sizes of r and s, which can be shorter than
 * 20 bytes when they have leading zero octets.
 *
 * The caller must supply an output buffer large enough for both packets
 * (len plus headers and signature); this function cannot check it.
 *
 * Returns the total number of bytes written, or a negative PGP_ERR_*.
 *
 * NOTE(review): sig->r / sig->s are accessed directly, which presumably
 * requires an OpenSSL release with a non-opaque DSA_SIG -- confirm.
 */
int pgp_write_dsa_sig_binary( byte* ptr, DSA* dsa,
                              const byte keyid[ PGP_KEYID_SIZE ],
                              const byte* in, int len )
{
    byte* start = ptr;
    byte extra[ 5 ];
    SHA_CTX sha1;
    DSA_SIG* sig = NULL;
    byte digest[ SHA_DIGEST_LENGTH ];
    int sig_len;
    int result = PGP_ERR_FAIL;
    int ret;

    if ( len < 0 ) { return PGP_ERR_FAIL; }

    /* the 5 hashed trailer bytes: signature type + time of sig creation */

    extra[ 0 ] = PGP_SIG_TYPE_BINARY;
    ret = pgp_write_time( extra + 1, time( NULL ) );
    if ( ret < 0 ) { return ret; }

    /* compute hash */

    SHA1_Init( &sha1 );
    SHA1_Update( &sha1, in, len );
    SHA1_Update( &sha1, extra, sizeof extra );
    SHA1_Final( digest, &sha1 );

    /* make the signature first, so the packet length can be exact */

    sig = DSA_do_sign( digest, SHA_DIGEST_LENGTH, dsa );
    if ( sig == NULL ) { return PGP_ERR_FAIL; }

    /* version, hashed-length octet, trailer, key id, pk alg, hash alg,
       2 hash check bytes, then two MPIs with a 2-octet bit count each */

    sig_len = 1 + 1 + (int) sizeof extra + PGP_KEYID_SIZE + 1 + 1 + 2
        + 2 + BN_num_bytes( sig->r )
        + 2 + BN_num_bytes( sig->s );

    ret = pgp_write_old_ctb( ptr, PGP_PKT_SIGNATURE, sig_len );
    if ( ret < 0 ) { result = ret; goto done; }
    ptr += ret;

    *ptr++ = PGP_SIG_VERS_3;    /* SHOULD use 4, but ok for now */
    *ptr++ = 5;                 /* number of hashed trailer bytes */
    memcpy( ptr, extra, sizeof extra );
    ptr += sizeof extra;

    memcpy( ptr, keyid, PGP_KEYID_SIZE );
    ptr += PGP_KEYID_SIZE;

    *ptr++ = PGP_PK_ALG_DSA;
    *ptr++ = PGP_HASH_ALG_SHA1;

    /* left 16 bits of signed hash value */

    memcpy( ptr, digest, 2 );
    ptr += 2;

    /* write out the raw r & s values that are the sig */

    ret = pgp_write_mpi( ptr, sig->r );
    if ( ret < 0 ) { result = ret; goto done; }
    ptr += ret;

    ret = pgp_write_mpi( ptr, sig->s );
    if ( ret < 0 ) { result = ret; goto done; }
    ptr += ret;

    /* append the binary text as a literal packet type; the 6 extra
       bytes are the mode octet, the empty filename and the timestamp */

    ret = pgp_write_old_ctb( ptr, PGP_PKT_LITERAL, len + 6 );
    if ( ret < 0 ) { result = ret; goto done; }
    ptr += ret;

    *ptr++ = 'b';               /* binary */

    ret = pgp_write_pascal_string( ptr, "" );   /* filename */
    if ( ret < 0 ) { result = ret; goto done; }
    ptr += ret;

    ret = pgp_write_time( ptr, 0 );             /* timestamp left as 0 */
    if ( ret < 0 ) { result = ret; goto done; }
    ptr += ret;

    memcpy( ptr, in, len );     /* the input file */
    ptr += len;

    result = ptr - start;

done:
    /* single exit once sig exists, so it is released on every path */
    DSA_SIG_free( sig );
    return result;
}

/*
 * Extract the signer's key id from a signature packet without verifying
 * anything.
 *
 * Only old-format headers, version 3 signatures and the binary signature
 * type are supported.  On success PGP_KEYID_SIZE bytes are copied to
 * keyid and 0 is returned; otherwise a negative PGP_ERR_* code.
 */
int pgp_read_dsa_sig_keyid( const byte* ptr, byte keyid[ PGP_KEYID_SIZE ] )
{
    int ret;
    int type;
    int junk;
    int sig_vers;
    int sig_type;
    uint32_t sig_unixtime;

    TRY( pgp_read_ctb( ptr, &type, &junk ) );
    if ( type != PGP_PKT_SIGNATURE ) { return PGP_ERR_FAIL; }

    sig_vers = *ptr++;
    if ( sig_vers != PGP_SIG_VERS_3 )
    {
        /* versions 1, 2 and 4 are not implemented */
        return PGP_ERR_NOT_IMPLEMENTED;
    }

    /* a v3 signature always announces 5 hashed trailer bytes */
    junk = *ptr++;
    if ( junk != 5 ) { return PGP_ERR_BAD_MSG; }

    sig_type = *ptr++;
    if ( sig_type != PGP_SIG_TYPE_BINARY )
    {
        return PGP_ERR_NOT_IMPLEMENTED;
    }

    /* the creation time is skipped; the key id follows it directly */
    TRY( pgp_read_time( ptr, &sig_unixtime ) );
    memcpy( keyid, ptr, PGP_KEYID_SIZE );

    return 0;
}

/*
 * Verify a version 3 DSA/SHA-1 binary signature packet that is followed
 * by a literal packet holding the signed data.
 *
 * The literal data is copied to out and its size stored in *outlen
 * before the signature is checked; the caller must supply a buffer large
 * enough for it, as this function has no way to check the capacity.
 *
 * Returns the number of input bytes consumed when the signature is good,
 * PGP_ERR_BAD_SIG when it is not, PGP_ERR_BAD_MSG for malformed input,
 * PGP_ERR_NOT_IMPLEMENTED for unsupported variants and PGP_ERR_FAIL for
 * other failures.
 */
int pgp_read_dsa_sig( const byte* ptr, DSA* dsa,
                      byte* out, int* outlen )
{
    int ret;
    int result = PGP_ERR_FAIL;
    const byte* start = ptr;
    const byte* extra;
    const byte* extra_stuff;
    SHA_CTX sha1;
    DSA_SIG* sig = NULL;
    byte check[ 2 ];
    byte digest[ SHA_DIGEST_LENGTH ];
    char filename[ 257 ];
    int filename_len;
    int len;
    int sig_vers;
    int type;
    int sig_type;
    int junk;
    uint32_t sig_unixtime;
    uint32_t file_unixtime;
    int pk_alg;
    int hash_alg;

    /* Nothing is allocated yet, so TRY's direct return is safe here. */

    TRY( pgp_read_ctb( ptr, &type, &junk ) );
    if ( type != PGP_PKT_SIGNATURE ) { return PGP_ERR_FAIL; }

    sig_vers = *ptr++;
    if ( sig_vers != PGP_SIG_VERS_3 )
    {
        /* versions 1, 2 and 4 are not implemented */
        return PGP_ERR_NOT_IMPLEMENTED;
    }

    junk = *ptr++;
    if ( junk != 5 ) { return PGP_ERR_BAD_MSG; }

    /* the 5 bytes starting here are hashed after the data */
    extra = ptr;

    sig_type = *ptr++;
    if ( sig_type != PGP_SIG_TYPE_BINARY )
    {
        return PGP_ERR_NOT_IMPLEMENTED;
    }

    TRY( pgp_read_time( ptr, &sig_unixtime ) );

    /* the signer's key id is not needed for verification */
    ptr += PGP_KEYID_SIZE;

    pk_alg = *ptr++;
    if ( pk_alg != PGP_PK_ALG_DSA )
    {
        return PGP_ERR_NOT_IMPLEMENTED;
    }

    hash_alg = *ptr++;
    if ( hash_alg != PGP_HASH_ALG_SHA1 )
    {
        return PGP_ERR_NOT_IMPLEMENTED;
    }

    /* left 16 bits of signed hash value */

    memcpy( check, ptr, 2 );
    ptr += 2;

    /* From here on every exit goes through `done` so sig is released. */

    sig = DSA_SIG_new();
    if ( sig == NULL ) { return PGP_ERR_FAIL; }

    /* read r, s */

    ret = pgp_read_mpi( ptr, &sig->r );
    if ( ret < 0 ) { result = ret; goto done; }
    ptr += ret;

    ret = pgp_read_mpi( ptr, &sig->s );
    if ( ret < 0 ) { result = ret; goto done; }
    ptr += ret;

    /* read the literal packet */

    ret = pgp_read_ctb( ptr, &type, &len );
    if ( ret < 0 ) { result = ret; goto done; }
    ptr += ret;

    if ( type != PGP_PKT_LITERAL )
    {
        result = PGP_ERR_NOT_IMPLEMENTED;
        goto done;
    }
    if ( len < 0 )
    {
        /* literal packets without a length field are not supported */
        result = PGP_ERR_NOT_IMPLEMENTED;
        goto done;
    }

    extra_stuff = ptr;
    sig_type = *ptr++;
    if ( sig_type != 'b' && sig_type != 't' )
    {
        result = PGP_ERR_BAD_MSG;
        goto done;
    }
    if ( sig_type == 't' )
    {
        result = PGP_ERR_NOT_IMPLEMENTED;
        goto done;
    }

    ret = pgp_read_pascal_string( ptr, filename, &filename_len );
    if ( ret < 0 ) { result = ret; goto done; }
    ptr += ret;

    ret = pgp_read_time( ptr, &file_unixtime );
    if ( ret < 0 ) { result = ret; goto done; }
    ptr += ret;

    /* what remains of the packet body is the signed data itself */
    len -= ptr - extra_stuff;
    if ( len < 0 )
    {
        /* the declared length is shorter than the packet's own fields */
        result = PGP_ERR_BAD_MSG;
        goto done;
    }

    SHA1_Init( &sha1 );
    SHA1_Update( &sha1, ptr, len );

    memcpy( out, ptr, len );
    ptr += len;
    *outlen = len;

    SHA1_Update( &sha1, extra, 5 );
    SHA1_Final( digest, &sha1 );

    /* cheap sanity check before the real verification */
    if ( memcmp( digest, check, 2 ) != 0 )
    {
        result = PGP_ERR_BAD_SIG;
        goto done;
    }

    ret = DSA_do_verify( digest, SHA_DIGEST_LENGTH, sig, dsa );
    switch ( ret )
    {
    case 1: result = ptr - start; break;
    case 0: result = PGP_ERR_BAD_SIG; break;
    default: result = PGP_ERR_FAIL; break;
    }

done:
    DSA_SIG_free( sig );
    return result;
}

#if defined( PGP_DSA_TEST )
#include <stdio.h>
#include <stdlib.h>

#define PGP_BUF_SIZE 2048

/*
 * Read up to size bytes of filename into buf and return the number of
 * bytes actually read.  Exits the program if the file cannot be opened.
 */
int file_slurp( byte* buf, size_t size, char* filename )
{
    /* keys and signatures are binary data: "rb" stops platforms that
       distinguish text mode from translating or truncating them */
    FILE* file = fopen( filename, "rb" );
    if ( file == NULL )
    {
        fprintf( stderr, "can not open file %s\n", filename );
        exit( EXIT_FAILURE );
    }
    size = fread( buf, 1, size, file );
    fclose( file );
    return size;
}

/*
 * Write size bytes from buf to filename, replacing any existing file,
 * and return the number of bytes actually written.  Exits the program
 * if the file cannot be opened.
 */
int file_dump( const byte* buf, size_t size, char* filename )
{
    /* the output is binary packet data: "wb" stops platforms that
       distinguish text mode from translating line endings */
    FILE* file = fopen( filename, "wb" );

    if ( file == NULL )
    {
        fprintf( stderr, "can not create or write to file %s\n",
                 filename );
        exit( EXIT_FAILURE );
    }
    size = fwrite( buf, 1, size, file );
    fclose( file );
    return size;
}

/*
 * Report the outcome of a step described by str.
 *
 * A non-negative ret prints "<str> ok" to stderr and returns.  A
 * negative ret prints a description of the error code followed by
 * " while <str>" and exits the program with EXIT_FAILURE.
 */
void CHECK( char* str, int ret )
{
    if ( ret >= 0 )
    {
        fprintf( stderr, "%s ok\n", str );
        return;
    }
    switch ( ret )
    {
    case PGP_ERR_NOT_IMPLEMENTED:
        fprintf( stderr, "not implemented" );
        break;
    case PGP_ERR_BAD_MSG:
        /* distinct text, so malformed input can be told apart from a
           generic failure */
        fprintf( stderr, "bad message" );
        break;
    case PGP_ERR_BAD_SIG:
        fprintf( stderr, "bad signature" );
        break;
    case PGP_ERR_FAIL:
    default:
        fprintf( stderr, "failure" );
        break;
    }
    fprintf( stderr, " while %s\n", str );
    exit( EXIT_FAILURE );
}

/*
 * Print n bytes of data to f as one upper-case hexadecimal number
 * prefixed with "0x", two digits per byte and no trailing newline.
 */
void fprint_hex( FILE* f, byte* data, int n )
{
    int remaining;

    fputs( "0x", f );
    for ( remaining = n; remaining > 0; remaining-- )
    {
        fprintf( f, "%02X", *data );
        data++;
    }
}

/*
 * Test driver: pgpdsa [-s|-v] dsakey.pgp in out
 *
 *   -s  sign `in` with the key in dsakey.pgp, writing the signature
 *       packet followed by the data to `out`
 *   -v  verify the signed message `in` against the key in dsakey.pgp,
 *       writing the recovered data to `out`
 *
 * Any failure terminates the program with EXIT_FAILURE.
 */
int main( int argc, char* argv[] )
{
    /* zero-filled so that parsing a short file never reads
       indeterminate bytes */
    byte key[ PGP_BUF_SIZE ] = { 0 };
    byte in[ PGP_BUF_SIZE ] = { 0 };
    int inlen;
    /* signing adds two packet headers and the signature itself (well
       under 128 bytes) to the input, so out needs that much headroom */
    byte out[ PGP_BUF_SIZE + 128 ];
    int outlen = 0;
    int sign = 0;
    int verify = 0;
    DSA* dsa;
    byte keyid[ PGP_KEYID_SIZE ];
    byte sig_keyid[ PGP_KEYID_SIZE ];

    /* argv[3] (in) and argv[4] (out) are both used below, so all four
       arguments are required */
    if ( argc < 5 ||
         argv[1][0] != '-' ||
         ( argv[1][1] != 's' && argv[1][1] != 'v' ) )
    {
        fprintf( stderr, "usage: pgpdsa [-s|-v] dsakey.pgp in out\n" );
        exit( EXIT_FAILURE );
    }

    sign = ( argv[1][1] == 's' );
    verify = ( argv[1][1] == 'v' );

    dsa = DSA_new();
    if ( dsa == NULL )
    {
        fprintf( stderr, "can not allocate dsa key\n" );
        exit( EXIT_FAILURE );
    }

    ERR_load_crypto_strings();
    /* NOTE(review): a fixed string adds no real entropy.  DSA signing
       needs unpredictable randomness, so this must not be relied on as
       the only seed outside of a test program. */
    RAND_seed( "12341324123413241234132412341324", 32 );

    inlen = file_slurp( in, PGP_BUF_SIZE, argv[3] );

    if ( sign )
    {
        file_slurp( key, PGP_BUF_SIZE, argv[2] );
        CHECK( "reading dsa key", pgp_read_dsa_key( key, dsa, keyid ) );

        fprintf( stderr, "signing with key with keyid = " );
        fprint_hex( stderr, keyid, PGP_KEYID_SIZE );
        fprintf( stderr, "\n" );

        CHECK( "creating dsa sig",
               outlen =
               pgp_write_dsa_sig_binary( out, dsa, keyid, in, inlen ) );
    }
    else if ( verify )
    {
        CHECK( "scanning for dsa sig keyid",
               pgp_read_dsa_sig_keyid( in, sig_keyid ) );

        fprintf( stderr, "sig made by keyid = " );
        fprint_hex( stderr, sig_keyid, PGP_KEYID_SIZE );
        fprintf( stderr, "\n" );

        /* if you have multiple keys, name the key file by keyid,
           or build .db file with lookup by keyid, above call
           gets you the keyid */

        /* but for test program there is only one fixed key file,
           our public key! */

        file_slurp( key, PGP_BUF_SIZE, argv[2] );
        CHECK( "reading dsa key", pgp_read_dsa_key( key, dsa, keyid ) );

        fprintf( stderr, "read in key with keyid = " );
        fprint_hex( stderr, keyid, PGP_KEYID_SIZE );
        fprintf( stderr, "\n" );

        if ( memcmp( keyid, sig_keyid, PGP_KEYID_SIZE ) != 0 )
        {
            fprintf( stderr, "don't have the key this msg was signed with\n" );
            exit( EXIT_FAILURE );
        }

        CHECK( "verifying dsa sig",
               pgp_read_dsa_sig( in, dsa, out, &outlen ) );
    }

    file_dump( out, outlen, argv[4] );

    DSA_free( dsa );
    return EXIT_SUCCESS;
}
#endif
