
from cl32 import *
from array import *
import sys, time

def openKeyStore():
    """Open "keyset.p15" as a file keyset and store the handle in global keySet.

    The keyset is opened with CRYPT_KEYOPT_CREATE. createKeyPairAndCert()
    writes into it and closeKeyStore() closes it.
    """
    global keySet
    keySetPath = "keyset.p15"
    keySet = cryptKeysetOpen(CRYPT_UNUSED, CRYPT_KEYSET_FILE, keySetPath,
                             CRYPT_KEYOPT_CREATE)

def createKeyPairAndCert(commonName, label, isCa, isCodeSigning):
    print "Generating key pair and cert for " + label
    global caKeyPair

    #Create key pair
    keyPair = cryptCreateContext(CRYPT_UNUSED, CRYPT_ALGO_RSA)
    cryptSetAttributeString(keyPair, CRYPT_CTXINFO_LABEL, label) # Verbose way of setting attributes
    keyPair.CTXINFO_KEYSIZE = 128                                # Shortcut way of setting attributes 
    cryptGenerateKey(keyPair)
    if isCa:
        caKeyPair = keyPair

    #Create cert
    cert = cryptCreateCert(CRYPT_UNUSED, CRYPT_CERTTYPE_CERTIFICATE)
    if isCa:
        cert.CERTINFO_SELFSIGNED = 1
        cert.CERTINFO_CA = 1
    if isCodeSigning:
        cert.CERTINFO_EXTKEY_CODESIGNING = CRYPT_UNUSED
    cert.CERTINFO_SUBJECTPUBLICKEYINFO = keyPair
    cert.CERTINFO_COUNTRYNAME = "us"
    cert.CERTINFO_COMMONNAME = commonName

    #Sign cert with keyPair
    cryptSignCert(cert, caKeyPair)

    #Add keyPair and cert chain to pkcs15 keyset (.p15)
    cryptAddPrivateKey(keySet, keyPair, "password")
    cryptAddPublicKey(keySet, cert)

    #Write out cert chain to pkcs7 file (.p7b)
    buffer = array('b', [0] * 1000)
    certLength = cryptExportCert(buffer, CRYPT_CERTFORMAT_CERTCHAIN, cert)
    f = open(label+".p7b", "wb")
    f.write(buffer[:certLength])
    f.close()

    #Write out private key and cert chain to pkcs12 file (.pfx)
    #
    # DISABLED CAUSE THE DEFAULT CRYPTLIB DOESN'T SUPPORT PFX
    #
    # pfxKeySet = cryptKeysetOpen(CRYPT_UNUSED, CRYPT_KEYSET_FILE, label+".pfx", CRYPT_KEYOPT_CREATE)
    # cryptAddPrivateKey(pfxKeySet, keyPair, "password")
    # cryptAddPublicKey(pfxKeySet, cert)
    # cryptKeysetClose(pfxKeySet)

    #Cleanup
    cryptDestroyContext(cert)
    cryptDestroyContext(keyPair)

    #Re-get the CA key pair so we get it attached to the certificate
    #This is a kludge, I wish I could just attach the cert directly
    #to the keypair, but we need to do this so we can later sign with
    #this keypair and have the cert automatically added to the chain
    if isCa:
        caKeyPair = cryptGetPrivateKey(keySet, CRYPT_KEYID_NAME, label, "password")

def closeKeyStore():
    """Release the global CA key pair, then close the global keySet.

    This counterpart to openKeyStore() expects createKeyPairAndCert() to have
    set caKeyPair.
    """
    cryptDestroyContext(caKeyPair)  # handle re-fetched from the keyset for the CA
    cryptKeysetClose(keySet)

def makeTempKeyPairsAndCerts(closeKeySet):
    print "Generating temporary key pairs and certs"

    global tempKeyPairs
    global tempCerts
    global tempKeySet

    tempKeyPairs = [0,0]
    tempCerts = [0,0]

    tempKeySet = cryptKeysetOpen(CRYPT_UNUSED, CRYPT_KEYSET_FILE, "keyset.p15", CRYPT_KEYOPT_CREATE)

    for count in range(2):
        #Generate temp key pair
        k = cryptCreateContext(CRYPT_UNUSED, CRYPT_ALGO_RSA)
        k.CTXINFO_LABEL = "Temp Key " + str(count)
        k.CTXINFO_KEYSIZE = 128
        cryptGenerateKey(k)

        #Generate good-for-anything self-signed cert for key pair
        c = cryptCreateCert(CRYPT_UNUSED, CRYPT_CERTTYPE_CERTIFICATE)
        c.CERTINFO_XYZZY = 1 #Special good-for-anything cert
        c.CERTINFO_SUBJECTPUBLICKEYINFO = k
        c.CERTINFO_COMMONNAME = "Temp Cert" + str(count)
        cryptSignCert(c, k)

        #Kludge
        #Must write out key pair and cert to keyset and then read them back
        #to get cert attached to keypair.  Once the key pair  has a certificate
        #attached to it, we can use it as the server certificate for TLS, for example
        cryptAddPrivateKey(tempKeySet, k, "password")
        cryptAddPublicKey(tempKeySet, c)
        cryptDestroyContext(k)
        k = cryptGetPrivateKey(tempKeySet, CRYPT_KEYID_NAME, "Temp Key " + str(count), "password")

        tempKeyPairs[count] = k
        tempCerts[count] = c
    if (closeKeySet):
        cryptKeysetClose(tempKeySet)

def destroyTempKeyPairsAndCerts(closeKeySet):
    """Release the handles created by makeTempKeyPairsAndCerts().

    Args:
        closeKeySet: when true, also close the global tempKeySet. Pass a
            false value if makeTempKeyPairsAndCerts() already closed it.
    """
    for keyHandle, certHandle in zip(tempKeyPairs, tempCerts):
        cryptDestroyContext(keyHandle)
        cryptDestroyCert(certHandle)
    if closeKeySet:
        cryptKeysetClose(tempKeySet)

def testTLSClient(serverName):
    #Open an SSL session to the server
    session = cryptCreateSession(CRYPT_UNUSED, CRYPT_SESSION_SSL)
    session.SESSINFO_SERVER_NAME = serverName
    session.SESSINFO_ACTIVE = 1

    #Set to blocking
    cryptSetAttribute(CRYPT_UNUSED, CRYPT_OPTION_NET_TIMEOUT, 30)

    #Send 10MB of data in 10K packets
    text = array('b', [0] * 10240)
    bytesSent = 0
    bytesRead = 0
    startTime = time.clock()
    for count in range(512):
        bytesSent += cryptPushData(session, text)
        cryptFlushData(session)
        bytesRead += cryptPopData(session, text, len(text))
    endTime = time.clock()
    print "Sent 10 MB over TLS 1.0 with 3DES in " + str(endTime-startTime) + " seconds"
    cryptDestroySession(session)

def testTLSServer():
    #Start up a single-threaded server, using the server keyPair/cert we
    #previously generated, and block waiting for a connection
    print "Waiting for client..."
    session = cryptCreateSession(CRYPT_UNUSED, CRYPT_SESSION_SSL_SERVER)
    session.CRYPT_SESSINFO_PRIVATEKEY = tempKeyPairs[0]
    session.CRYPT_SESSINFO_ACTIVE = 1
    print "Connection established..."

    #Set it to blocking
    cryptSetAttribute(CRYPT_UNUSED, CRYPT_OPTION_NET_TIMEOUT, 30)

    #Send and Receive a total of 10MB in 10KB chunks
    text = array('b', [0] * 10240)
    bytesSent = 0
    bytesRead = 0
    for count in range(512):
        bytesRead += cryptPopData(session, text, len(text))
        bytesSent += cryptPushData(session, text)
        cryptFlushData(session)
    cryptDestroySession(session)
    print "Session closed..."

def testCMS():
    plaintext = "Hi, how are you"
    buffer = array('b', [0] * (32 * 1024)) #equal to cryptlibs buffer size

    #First sign with temp key 0
    print "Signing message"
    e = cryptCreateEnvelope(CRYPT_UNUSED, CRYPT_FORMAT_CRYPTLIB)
    cryptSetAttribute(e, CRYPT_ENVINFO_SIGNATURE, tempKeyPairs[0])
    cryptPushData(e, plaintext)
    cryptFlushData(e)
    bytesCopied = cryptPopData(e, buffer, len(buffer))
    cryptDestroyEnvelope(e)

    #Then encrypt to temp certs 0 and 1, using AES256
    print "Encrypting message"
    e = cryptCreateEnvelope(CRYPT_UNUSED, CRYPT_FORMAT_CRYPTLIB)

    c = cryptCreateContext(CRYPT_UNUSED, CRYPT_ALGO_AES)
    c.CRYPT_CTXINFO_KEYSIZE = 32
    cryptGenerateKey(c)
    e.CRYPT_ENVINFO_SESSIONKEY = c
    cryptDestroyContext(c)

    e.ENVINFO_PUBLICKEY = tempCerts[0]
    e.ENVINFO_PUBLICKEY = tempCerts[1]

    cryptPushData(e, buffer[:bytesCopied])
    cryptFlushData(e)
    bytesCopied = cryptPopData(e, buffer, len(buffer))
    cryptDestroyEnvelope(e)

    #Prepare to decrypt encrypted data
    print "Decrypting message"
    e = cryptCreateEnvelope(CRYPT_UNUSED, CRYPT_FORMAT_AUTO)

    #It tells us we need to add an encryption attribute, via Exceptions
    #This is an unfortunate result of translating status values as
    #Exceptions, but we can live with it
    try:
        cryptPushData(e, buffer[:bytesCopied])
    except CryptException: pass
    try:
        cryptFlushData(e)
    except CryptException: pass

    #Iterate through encryption attributes until we find one that matches key 1
    cryptSetAttribute(e, CRYPT_ENVINFO_CURRENT_COMPONENT, CRYPT_CURSOR_FIRST)
    while 1:
        attributeNeeded = cryptGetAttribute(e, CRYPT_ENVINFO_CURRENT_COMPONENT)
        if attributeNeeded == CRYPT_ENVINFO_PRIVATEKEY:
            try:
                #Test if key 1 works.  If so, exit the loop
                cryptSetAttribute(e, CRYPT_ENVINFO_PRIVATEKEY, tempKeyPairs[1])
                break
            except CryptException, ex:
                #This is not the correct key
                if (ex[0] != CRYPT_ERROR_WRONGKEY):
                    raise CryptException, ex
        cryptSetAttribute(e, CRYPT_ENVINFO_CURRENT_COMPONENT, CRYPT_CURSOR_NEXT)

    bytesCopied = cryptPopData(e, buffer, len(buffer))
    cryptDestroyEnvelope(e)

    #Prepare to verify decrypted data
    print "Verifying message"
    e = cryptCreateEnvelope(CRYPT_UNUSED, CRYPT_FORMAT_AUTO)

    #Check that it was signed by one of the keys in tempKeySet
    cryptSetAttribute(e, CRYPT_ENVINFO_KEYSET_SIGCHECK, tempKeySet)

    cryptPushData(e, buffer[:bytesCopied])
    cryptFlushData(e)
    bytesCopied = cryptPopData(e, buffer, len(buffer))

    result = cryptGetAttribute(e, CRYPT_ENVINFO_SIGNATURE_RESULT)
    if (result != CRYPT_OK):
        raise CryptException, (result, "signature failed")
    print "Decrypted/Verified data is:\n" + str("".join([chr(c) for c in buffer[0:bytesCopied]]))

def testCipher(cryptAlgo, keySize, algoName):
    #Create context with random IV (if it needs one) and a key derived
    #from a password using a salt and iteration count to hinder guessing attacks
    encContext = cryptCreateContext(CRYPT_UNUSED, cryptAlgo)
    encContext.CTXINFO_KEYSIZE = keySize
    encContext.CTXINFO_KEYING_ITERATIONS = 10000
    encContext.CTXINFO_KEYING_SALT = "salt1234"         
    encContext.CTXINFO_KEYING_VALUE = "password123"		

    #Encrypt 10 MB of text
    #We do this in 10 KB chunks for the sake of demonstration
    #--------------------------------------------------------------

    #Create 10MB array
    #We use a direct ByteBuffer, otherwise the entire array needs to be
    #copied to/from JNI each time, which kills performance
    tenMegs = 10 * 1024 * 1024
    text = array('b', [0] * tenMegs)

    #Encrypt each chunk
    startTime = time.clock()
    cryptEncrypt(encContext, text)    #encrypt chunk
    endTime = time.clock()

    #Print results
    print algoName +" 10 MB in %2f seconds" % (endTime-startTime)

    #Destroy the context, zeroing the key
    cryptDestroyContext(encContext)




if __name__ == "__main__":
    if len(sys.argv) == 1:
        print """\
    Commands:
    testcerts <caName> <webServerName> <codeSignerName>
    testtls [server|client] <serverName (if client)>
    testcms
    testciphers"""
        sys.exit()
    
    cryptInit()
    if sys.argv[1].lower() == "testcerts":
        openKeyStore()
        createKeyPairAndCert(sys.argv[2], "ca", 1, 0)
        createKeyPairAndCert(sys.argv[3], "webServer", 0, 0)
        createKeyPairAndCert(sys.argv[4], "codeSigner", 0, 1)
        closeKeyStore()
        
    elif sys.argv[1].lower() == "testtls":
        makeTempKeyPairsAndCerts(1)
        if sys.argv[2].lower() == "client":
            testTLSClient(sys.argv[3])
        else:
            testTLSServer()
        destroyTempKeyPairsAndCerts(0)

    elif sys.argv[1].lower() == "testcms":
        makeTempKeyPairsAndCerts(0)
        testCMS()
        destroyTempKeyPairsAndCerts(1)

    elif sys.argv[1].lower() == "testciphers":
        testCipher(CRYPT_ALGO_HMAC_SHA, 32, "HMAC-SHA1   Processed")
        testCipher(CRYPT_ALGO_RC4, 16, 		"RC4         Encrypted")
        testCipher(CRYPT_ALGO_AES, 16, 		"AES-128 CBC Encrypted")
        testCipher(CRYPT_ALGO_AES, 32, 		"AES-256 CBC Encrypted")
        testCipher(CRYPT_ALGO_3DES, 24,		"3DES CBC    Encrypted")
    else:
        print "unrecognized parameter"
    try:
        cryptEnd()
    except CryptException, ex:
        pass