#!/usr/bin/python

#spike proxy spkproxy.py


#listen on port 8080
#for each connection
# spawn off a new thread
# accept that connection
# proxy that connection 
#endfor

#for each thread proxying a connection
# wait for data
# when you get data, save into a buffer, check for end-of-header+end_of_body
# 
import socket
import sys
from threading import Thread
import string
import os
from OpenSSL import SSL

# Start-up banner.  print(...) with a single string argument produces the
# same output under the Python 2 print statement and the Python 3 function.
print("Running")
print("SPIKE Proxy is copyright Dave Aitel 2002")
print("License: GPL v 2.0")
print("Please visit www.immunitysec.com for updates and other useful tools!")

#s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
#HOST = "www.immunitysec.com"
#PORT = 443
#s.connect((HOST, PORT))
#sslobj = socket.ssl(s)
#sslobj.write("GET / HTTP/1.0\r\n\r\n")
#data=sslobj.read(50000)
#print data

#we use pyOpenSSL for SSL stuff! v.5pre or greater is required
class MyConnection:
    """Wrap the browser-side socket so it can be upgraded to SSL in place.

    It starts as a thin wrapper around a plain socket.  startSSLserver()
    replaces the wrapped socket with a server-side pyOpenSSL SSL.Connection,
    so later recv()/send() calls are encrypted without the caller noticing.
    """

    def __init__(self, conn):
        # 1 once startSSLserver() has switched the wrapped socket to SSL.
        self.doSSL = 0
        # Object used for all I/O: a plain socket, later an SSL.Connection.
        self.mysocket = conn

    def recv(self, size):
        """Return up to *size* bytes read from the wrapped connection."""
        return self.mysocket.recv(size)

    def send(self, data):
        """Send all of *data*, looping over partial writes.

        Returns the number of bytes sent (len(data)).
        """
        total = len(data)
        sent = 0
        while sent < total:
            # send() may accept fewer bytes than it was offered.
            sent += self.mysocket.send(data[sent:])
        return sent

    def verify_cb(self, conn, cert, errnum, depth, ok):
        """pyOpenSSL verify callback: log the peer certificate's subject and
        return OpenSSL's own verdict (*ok*) unchanged.

        It takes ``self`` because it is registered as the bound method
        self.verify_cb.  Without it, pyOpenSSL's five-argument call would
        raise TypeError.
        """
        # This obviously has to be updated
        print("Got certificate: %s" % cert.get_subject())
        return ok

    def startSSLserver(self):
        """Answer the browser's CONNECT and switch this connection to SSL.

        Sends "200 Connection established" over the plain socket.  Then wraps
        that socket in a server-side SSL.Connection that uses server.pkey,
        server.cert and CA.cert from the directory containing this script.
        """
        certdir = os.path.dirname(sys.argv[0]) or os.curdir

        # The connection is still plain text here.  self.send() loops, so a
        # partial write cannot truncate the reply.
        self.send("HTTP/1.0 200 Connection established\r\n\r\n")
        # SSLv23_METHOD lets OpenSSL negotiate the best protocol both ends
        # support.  SSLv3_METHOD speaks only SSL 3.0, which current browsers
        # refuse and OpenSSL builds without SSLv3 cannot create.
        ctx = SSL.Context(SSL.SSLv23_METHOD)
        ctx.set_verify(SSL.VERIFY_NONE, self.verify_cb)  # Don't demand a certificate
        ctx.use_privatekey_file(os.path.join(certdir, 'server.pkey'))
        ctx.use_certificate_file(os.path.join(certdir, 'server.cert'))
        ctx.load_verify_locations(os.path.join(certdir, 'CA.cert'))
        self.mysocket = SSL.Connection(ctx, self.mysocket)
        #only works with pyOpenSSL 5.0pre or >
        self.mysocket.set_accept_state()
        self.doSSL = 1

    def close(self):
        """Close the wrapped (plain or SSL) connection."""
        self.mysocket.close()

class header:
    """Incremental parser for one HTTP request or response header block.

    The caller feeds the raw header one character at a time through
    addData().  When the four characters CR LF CR LF that end the header
    arrive, verifyHeader() parses the block.  The results are left in
    attributes such as verb, URL, URLargsDict, version, headerValuesDict,
    connectHost/connectPort (requests) or returncode/returnmessage/firstline
    (responses).  Request headers are also rewritten by massageHeaders().
    """

    def __init__(self):
        self.clear()

    #you need to sometimes set the connection's state into SSL
    def setConnection(self, conn):
        """Remember the client connection.

        parseFirstLine() calls conn.startSSLserver() when it sees CONNECT.
        """
        self.connection = conn
        return

    def clear(self):
        """Reset every parse result to its empty/default value."""
        self.data = []          # raw header characters received so far
        self.done = 0           # 1 once the terminating blank line arrived
        self.goodHeader = 0     # 1 once the header parsed successfully
        self.clientisSSL = 0
        #for the first request, we see a CONNECT verb
        self.sawCONNECT = 0
        self.firstline = ""     # full status line, set for responses only
        #1 if we are reading a response instead of a GET/POST, etc
        self.responseHeader = 0
        self.wasChunked = 0     # 1 if "Transfer-Encoding: chunked" was removed

        #here is basically what we return from parsing the headers
        self.URLargsDict = {}
        self.headerValuesDict = {}  # header name -> list of values, in arrival order
        self.useSSL = 0
        self.connectHost = ""
        self.URL = ""
        self.sawsslinit = 0
        self.connectPort = 0
        self.mybodysize = 0
        self.useRawArguments = 0
        self.allURLargs = ""

        #variables for server response headers
        self.returncode = ""
        self.returnmessage = ""
        return

    def setclientSSL(self):
        """Mark this header as arriving through an established SSL tunnel."""
        self.useSSL = 1
        self.clientisSSL = 1
        return

    def addData(self, moredata):
        """Append one character of header data.

        Parses the header as soon as the terminating CR LF CR LF has been
        received.  Callers pass single characters (they read with recv(1)).
        """
        self.data.append(moredata)
        if self.data[-4:] == ['\r', '\n', '\r', '\n']:
            self.done = 1
            self.verifyHeader()

    def getIntValue(self, keys):
        """Return int() of the first value of the first header in *keys*
        that is present, or 0 if none is.

        Header names are matched exactly (case-sensitive).  When a header
        appears several times, the first occurrence wins.
        """
        for akey in keys:
            if akey in self.headerValuesDict:
                return int(self.headerValuesDict[akey][0])
        return 0

    def getStrValue(self, keys):
        """Like getIntValue() but returns the value as a string.

        Returns the integer 0 (not "") when none of *keys* is present.
        """
        for akey in keys:
            if akey in self.headerValuesDict:
                return str(self.headerValuesDict[akey][0])
        return 0

    def addHeader(self, newheader, newheadervalue):
        """Record one header value.

        Repeated headers are kept as separate list entries rather than being
        comma-joined: just separating them by commas doesn't work for
        hotmail.com.
        """
        self.headerValuesDict.setdefault(newheader, []).append(newheadervalue)

    def verifyHeader(self):
        """Parse the buffered header block.

        Returns 1 on success (goodHeader is set for normal headers,
        sawsslinit for CONNECT) and 0 on a parse error.
        """
        # List of lines without CRLF.  The final two entries are the empty
        # strings produced by the terminating blank line, so drop them.
        self.allheaders = "".join(self.data).split("\r\n")[:-2]
        firstline = self.allheaders[0]
        if not self.parseFirstLine(firstline):
            print("Couldn't parse first line!")
            return 0

        #did we see a CONNECT?
        if self.sawCONNECT:
            self.sawsslinit = 1
            return 1

        for headerLine in self.allheaders[1:]:
            # Split on the first colon only, so values that themselves
            # contain ":" or ": " (URLs, cookies) survive intact.  A missing
            # space after the colon is accepted; MS hotmail sends e.g.
            # P3P:CP="BUS CUR CONo FIN IVDo ONL OUR PHY SAMo TELo"
            name, sep, value = headerLine.partition(":")
            if not sep:
                tempvalues = headerLine.split(":")
                print("len(tempvalues)!=2 ="+str(len(tempvalues))+" in "+str(tempvalues))
                return 0
            self.addHeader(name, value.lstrip(" \t"))

        self.massageHeaders()
        self.goodHeader = 1
        return 1

    def massageHeaders(self):
        """Rewrite parsed headers before they are forwarded.

        Always removes "Transfer-Encoding: chunked" and sets wasChunked; the
        body is later re-sent with a Content-Length.  For requests only, it
        also does the following:
        - renames Proxy-Connection to Connection;
        - replaces User-Agent with a fixed IE 5.0 string;
        - saves Content-Length into mybodysize and removes it, since the
          length is recalculated when the request is rebuilt.
        """
        #non-IE user Agent, for reference
        #User-Agent: Mozilla/5.0 Galeon/1.0.3 (X11; Linux i686; U;) Gecko/0
        IEstring = "Mozilla/4.0 (compatible; MSIE 5.0; Windows NT; Bob)"
        nonIEstring = "Mozilla/5.0 Galeon/1.0.3 (X11; Linux i686; U;) Gecko/0"
        #always massage chunked out of the way
        #this will cause problems if someone sends over a gig of data
        if self.getStrValue(["Transfer-Encoding"]) == "chunked":
            del self.headerValuesDict["Transfer-Encoding"]
            self.wasChunked = 1

        #don't massage a response
        if self.responseHeader:
            return

        #by default, use IE 5.0
        replaceUserAgent = 1
        userAgent = IEstring

        #change Proxy-Connection to Connection
        if "Proxy-Connection" in self.headerValuesDict:
            self.headerValuesDict["Connection"] = self.headerValuesDict.pop("Proxy-Connection")

        if replaceUserAgent:
            self.headerValuesDict.pop("User-Agent", None)
            self.addHeader("User-Agent", userAgent)

        #save this off before we delete it
        self.mybodysize = self.getIntValue(["Content-length", "Content-Length"])
        # Content-Length is recalculated when the request is rebuilt.
        self.headerValuesDict.pop("Content-length", None)
        self.headerValuesDict.pop("Content-Length", None)
        return

    def parseFirstLine(self, firstline):
        """Parse a request line or a response status line.

        Response ("HTTP/1.x CODE REASON"): sets responseHeader, returncode,
        returnmessage (the full reason phrase) and firstline.
        CONNECT host[:port]: sets connectHost and connectPort (default 443,
        stored as int) plus the SSL flags.  It then calls
        self.connection.startSSLserver().
        Anything else: the URL is parsed by processProxyUrl().
        Returns 1 on success, 0 on failure.
        """
        templist = firstline.split(" ")
        if len(templist) < 3:
            print("First line has less than 3 members!")
            return 0
        self.verb = templist[0]

        if self.verb in ["HTTP/1.1", "HTTP/1.0"]:
            self.responseHeader = 1
            self.returncode = templist[1]
            # The reason phrase may contain spaces ("Not Found").
            self.returnmessage = " ".join(templist[2:])
            self.firstline = firstline
            return 1

        #SSL proxy check
        if self.verb == "CONNECT":
            hostport = templist[1].split(":")
            if len(hostport) < 2 or hostport[1] == "":
                #no port would be weird, but maybe it'll happen...
                port = 443
            else:
                try:
                    port = int(hostport[1])
                except ValueError:
                    print("Bad port in CONNECT request: " + templist[1])
                    return 0
            self.connectHost = hostport[0]
            self.connectPort = port
            #signifies we connect to server with ssl
            self.useSSL = 1
            #signifies we connect to client with ssl
            self.clientisSSL = 1
            self.sawCONNECT = 1
            self.connection.startSSLserver()
            return 1

        if not self.processProxyUrl(templist[1]):
            return 0

        self.version = templist[2]
        return 1

    def processProxyUrl(self, proxyurl):
        """Split a proxy request URL into host, port, path and arguments.

        For ordinary proxy requests *proxyurl* is absolute
        ("http://host[:port]/path?query").  Inside an SSL tunnel
        (clientisSSL) it is already a path ("/path?query") and host/port come
        from the earlier CONNECT.  The method sets the following:
        - connectHost and connectPort (default 80);
        - URL, the path without the query;
        - allURLargs, the raw query string;
        - URLargsDict, the query parsed into name=value pairs.
        If the query cannot be represented as a dict (a pair without exactly
        one "=", or a repeated name), useRawArguments is set and the raw
        string is forwarded instead.  Returns 1 on success, 0 on failure.
        """
        self.URLargsDict = {}
        # Keep the SSL flag set by setclientSSL() for tunnelled requests.
        self.useSSL = self.clientisSSL
        self.connectHost = ""
        #this might already be set if we got an SSL proxy request
        if not self.connectPort:
            self.connectPort = 80
        self.URL = ""

        urlbit = proxyurl
        if not self.clientisSSL:
            # Split on the first "://" only; the query may hold further URLs.
            parts = proxyurl.split("://", 1)
            if len(parts) < 2:
                print("Need something after the http:// - exiting this thread")
                return 0
            urltype, urlbit = parts
            if urltype == "https":
                #this is probably broken: REVISIT
                self.useSSL = 1
            elif urltype != "http":
                print("unknown url type "+urltype)
                return 0

            # The authority (host[:port]) ends at the first "/" or "?".
            cut = len(urlbit)
            for delim in "/?":
                pos = urlbit.find(delim)
                if pos != -1 and pos < cut:
                    cut = pos
            hostport = urlbit[:cut]
            urlbit = urlbit[cut:]
            if not urlbit.startswith("/"):
                urlbit = "/" + urlbit

            host, colon, port = hostport.partition(":")
            self.connectHost = host
            if colon and port:
                try:
                    self.connectPort = int(port)
                except ValueError:
                    print("Bad port in URL: " + hostport)
                    return 0
            if self.connectHost == "":
                print("Error: empty connect host!")
                return 0

        # Split off the query at the first "?" only.
        path, qmark, query = urlbit.partition("?")
        self.URL = path
        if qmark:
            self.allURLargs = query
            for pair in query.split("&"):
                if pair == "":
                    continue
                templist2 = pair.split("=")
                if len(templist2) != 2 or templist2[0] in self.URLargsDict:
                    # Not representable as a dict: forward the query verbatim.
                    self.useRawArguments = 1
                    return 1
                self.URLargsDict[templist2[0]] = templist2[1]
        return 1

    def isdone(self):
        """Return 1 once the end of the header has been received, else 0."""
        if self.done == 0:
            return 0
        return 1

    def gotGoodHeader(self):
        """Return 1 if the header was parsed successfully."""
        return self.goodHeader

    def bodySize(self):
        """Return the saved body size (the request's Content-Length)."""
        return self.mybodysize

    def grabHeader(self, header):
        """Return one "Name: value" CRLF line per stored value of *header*,
        or "" if the header is absent."""
        return "".join([header + ": " + value + "\r\n"
                        for value in self.headerValuesDict.get(header, [])])
        

class body:
    """Accumulate an HTTP message body read from a connection.

    self.data is a list of single characters (each received string is
    spliced in with list +=).  self.mysize counts the characters read.
    self.truncated is set when the peer closed the connection (or sent an
    unparseable chunk size) before a length-delimited body was complete.
    """
    def __init__(self):
        self.mysize = 0
        self.data = []
        self.truncated = 0

    def readBlock(self, connection, size):
        """Read *size* bytes from *connection* and append them to the body.

        Stops early, and sets self.truncated, if recv() returns "" because
        the peer closed the connection.  Returns the number of bytes read.
        """
        chunks = []
        got = 0
        while got < size:
            piece = connection.recv(size - got)
            if not piece:
                # Peer closed early; looping again would spin forever.
                self.truncated = 1
                break
            chunks.append(piece)
            got += len(piece)
        self.data += "".join(chunks)
        self.mysize += got
        return got

    def _readLine(self, connection):
        """Read one CRLF-terminated line byte by byte and return it without
        the CRLF.  On EOF, set self.truncated and return what was read."""
        line = []
        while line[-2:] != ["\r", "\n"]:
            ch = connection.recv(1)
            if not ch:
                self.truncated = 1
                return "".join(line)
            line.append(ch)
        return "".join(line[:-2])

    def _readChunked(self, connection):
        """Decode a chunked transfer-coded body; return the decoded size."""
        while 1:
            sizeline = self._readLine(connection)
            if self.truncated:
                return self.mysize
            # Chunk extensions (";name=value") may follow the hex size.
            try:
                chunksize = int(sizeline.split(";")[0].strip(), 16)
            except ValueError:
                print("Bad chunk size line: " + sizeline)
                self.truncated = 1
                return self.mysize
            if chunksize == 0:
                # Skip any trailer headers up to the terminating blank line.
                # _readLine returns "" on EOF, so this always terminates.
                while self._readLine(connection) != "":
                    pass
                return self.mysize
            self.readBlock(connection, chunksize)
            if self.truncated:
                return self.mysize
            # Consume the CRLF that follows each chunk's data.
            self._readLine(connection)

    def _readTillClosed(self, connection):
        """Read until EOF, a socket/SSL error, or a chunk that completes
        "</html>".  Return the number of bytes read."""
        temp = ""
        while 1:
            try:
                piece = connection.recv(1000)
            except (SSL.SysCallError, SSL.ZeroReturnError, socket.error):
                # Connection closed or reset: treat as end of body.
                break
            except Exception:
                print("Unknown exception occured")
                break
            if not piece:
                break
            oldlen = len(temp)
            temp += piece
            # This is necessary because stupid hotmail will not send a FIN
            # after lots of data with "Connection: close".  Only the new data
            # (plus 6 chars of overlap) needs scanning: an earlier match
            # would already have ended the loop.
            if temp.find("</html>", max(0, oldlen - 6)) != -1:
                break
        self.data += temp
        self.mysize += len(temp)
        return len(temp)

    #This handles chunked data cleanly - well, handles it anyways
    def read(self, connection, size, waschunked, readtillclosed):
        """Read a message body from *connection*.

        waschunked     -- decode chunked transfer coding (takes priority).
        readtillclosed -- if true and *size* is 0, read until the peer
                          closes instead of reading a fixed length.
        Otherwise exactly *size* bytes are read.  Returns the number of body
        bytes read.
        """
        if waschunked:
            return self._readChunked(connection)
        if readtillclosed and size == 0:
            return self._readTillClosed(connection)
        return self.readBlock(connection, size)

    def gotGoodBody(self):
        """Return 1 if the body was read completely, else 0."""
        if self.truncated:
            return 0
        if self.mysize == len(self.data):
            return 1
        return 0
        
class spkProxyConnection( Thread ):
    """Proxy one browser connection on its own thread.

    run() reads requests from the browser (a MyConnection) and forwards each
    with sendRequest().  It then writes the rebuilt response back.  After a
    CONNECT, the browser side speaks SSL and every later request goes over
    SSL to the CONNECT target.  One server socket is kept open and reused
    while consecutive requests go to the same host and port.
    """
    def __init__(self, connection):
        Thread.__init__(self)
        #client connection
        self.connection = connection
        self.clientisSSL = 0    # 1 after the browser sent CONNECT
        self.currentHost = ""   # host/port the open server socket talks to
        self.currentPort = 0
        self.haveSocket = 0     # 1 while currentSocket is an open server socket
        self.sslHost = ""       # CONNECT target, used for all later requests
        self.sslPort = ""
        #server connection
        self.currentSocket = -1

    def run(self):
        """Serve requests from the browser until it disconnects or an error
        occurs, then clean up both sockets."""
        while 1:
            myheader = header()
            myheader.setConnection(self.connection)
            if self.clientisSSL:
                myheader.setclientSSL()

            while myheader.isdone() == 0:
                try:
                    data = self.connection.recv(1)
                except Exception:
                    print("Client closed connection")
                    self.cleanup()
                    return
                if not data:
                    break
                myheader.addData(data)

            if myheader.sawsslinit == 1:
                # CONNECT handled: the next header arrives over SSL.
                self.clientisSSL = 1
                self.sslHost = myheader.connectHost
                self.sslPort = myheader.connectPort
                continue

            if not myheader.gotGoodHeader():
                self.cleanup()
                return

            mybody = body()
            if myheader.bodySize() > 0 or myheader.wasChunked:
                #readtillclosed always 0 on client
                mybody.read(self.connection, myheader.bodySize(), myheader.wasChunked, 0)
            #reset this to the truth
            myheader.mybodysize = mybody.mysize
            if not mybody.gotGoodBody():
                self.cleanup()
                return

            response = self.sendRequest(myheader, mybody)
            try:
                self._sendAll(self.connection, response)
            except Exception:
                print("Client closed connection")
                self.cleanup()
                return

    def _sendAll(self, sock, data):
        """Send all of *data* on *sock*, looping over partial writes."""
        total = len(data)
        sent = 0
        while sent < total:
            sent += sock.send(data[sent:])
        return sent

    def _closeServerSocket(self):
        """Close the server socket, if any, and forget which host it served,
        so the next request opens a fresh connection."""
        if self.haveSocket:
            try:
                self.currentSocket.close()
            except Exception:
                # Best effort: the socket may already be dead.
                pass
        self.haveSocket = 0
        self.currentSocket = -1
        self.currentHost = ""
        self.currentPort = 0

    def constructResponse(self, myheader, mybody):
        """Rebuild the server response for the browser.

        Uses the original status line and all headers except Content-Length.
        A Content-Length matching the body actually read is added, followed
        by the body.
        """
        if myheader.firstline == "":
            print("Serious error: response's first line is empty!")

        parts = [myheader.firstline + "\r\n"]
        for akey in myheader.headerValuesDict.keys():
            #don't send 2 Content-lengths
            if akey not in ("Content-Length", "Content-length"):
                parts.append(myheader.grabHeader(akey))
        parts.append("Content-Length: " + str(mybody.mysize) + "\r\n")
        parts.append("\r\n")
        parts.append("".join(mybody.data))
        return "".join(parts)

    def constructRequest(self, myheader, mybody):
        """Rebuild the request to send to the web server.

        Builds the request line from verb, URL, the arguments and version.
        Headers follow, the commonly ordered ones first.  Content-Length is
        added (except for body-less GETs), then the blank line and the body.
        """
        parts = [myheader.verb + " " + myheader.URL]
        if myheader.useRawArguments:
            parts.append("?" + myheader.allURLargs)
        elif myheader.URLargsDict:
            parts.append("?" + "&".join([akey + "=" + myheader.URLargsDict[akey]
                                         for akey in myheader.URLargsDict.keys()]))
        parts.append(" " + myheader.version + "\r\n")

        # Emit these in a fixed order.  Only Host might really need ordering,
        # but a stable order rules out lame bugs later on.
        needOrdered = ["Host", "User-Agent", "Accept", "Accept-Language", "Accept-Encoding",
                       "Accept-Charset", "Keep-Alive", "Connection", "Pragma", "Cache-Control"]
        for avalue in needOrdered:
            parts.append(myheader.grabHeader(avalue))
        for akey in myheader.headerValuesDict.keys():
            # Content-Length is recalculated below; never send two.
            if akey not in needOrdered and akey not in ("Content-Length", "Content-length"):
                parts.append(myheader.grabHeader(akey))

        # Content-Length: 0 should always be valid, but some servers
        # mishandle it on GET requests.
        if mybody.mysize != 0 or myheader.verb != "GET":
            parts.append("Content-Length: " + str(len(mybody.data)) + "\r\n")

        parts.append("\r\n")
        parts.append("".join(mybody.data))
        return "".join(parts)

    def connectToWebServer(self, myheader):
        """Make sure currentSocket is connected to the right server.

        SSL mode connects once to the CONNECT target (sslHost/sslPort) and
        wraps the socket in a client-side SSL.Connection.  Plain mode reuses
        the open socket when host and port match, otherwise reconnects.
        Returns 1 on success and 0 if the connection failed.
        """
        if self.clientisSSL:
            if self.haveSocket:
                return 1
            print("Connecting to "+str(self.sslHost)+" "+str(self.sslPort))
            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            try:
                sock.connect((self.sslHost, int(self.sslPort)))
            except Exception:
                print("Connection refused")
                sock.close()
                return 0
            ctx = SSL.Context(SSL.SSLv23_METHOD)
            sslsock = SSL.Connection(ctx, sock)
            sslsock.set_connect_state()
            print("Set up SSL")
            # Only record the socket once it is really connected.
            self.currentSocket = sslsock
            self.currentHost = self.sslHost
            self.currentPort = self.sslPort
            self.haveSocket = 1
            return 1

        #not SSL
        print("Connecting to "+str(myheader.connectHost)+" "+str(myheader.connectPort))
        if (self.haveSocket and self.currentHost == myheader.connectHost
                and self.currentPort == myheader.connectPort):
            return 1

        # No socket, or one to the wrong host: drop it and connect afresh.
        self._closeServerSocket()
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            sock.connect((myheader.connectHost, int(myheader.connectPort)))
        except Exception:
            print("Connection refused")
            sock.close()
            return 0
        self.currentSocket = sock
        self.currentHost = myheader.connectHost
        self.currentPort = myheader.connectPort
        self.haveSocket = 1
        return 1

    def _readResponseHeader(self):
        """Read response headers from the server, skipping any
        "100 Continue" interim responses.  Return the final header object."""
        returncode = "100"
        serverheader = None
        while returncode == "100":
            serverheader = header()
            serverheader.setConnection(self.currentSocket)
            while serverheader.isdone() == 0:
                try:
                    data = self.currentSocket.recv(1)
                except Exception:
                    # For example SSL.ZeroReturnError on an SSL shutdown.
                    print("Server closed connection - weird")
                    break
                if not data:
                    break
                serverheader.addData(data)
            returncode = serverheader.returncode
        return serverheader

    #given a valid header and body, sends it off, and returns the result
    def sendRequest(self, myheader, mybody):
        """Forward one request and return the rebuilt response as a string.

        If no usable server response is obtained, returns a canned
        "501 No Server There!" page instead.
        """
        byestring = "<html><head><title>Error</title></head><body><h1>No SSL server there, sorry.</h1></body></html>"
        errorpage = "HTTP/1.1 501 No Server There!\r\nContent-Length: "+str(len(byestring))+"\r\n\r\n"+byestring
        if not self.connectToWebServer(myheader):
            print("returning fake 501 page!")
            return errorpage

        myRequest = self.constructRequest(myheader, mybody)
        try:
            self._sendAll(self.currentSocket, myRequest)
        except Exception:
            # Server went away; forget the socket so the next request reconnects.
            self._closeServerSocket()
            return errorpage
        print("Sent request:\n"+myRequest)

        serverheader = self._readResponseHeader()
        if not serverheader.gotGoodHeader():
            print("Bad or missing response header from server")
            self._closeServerSocket()
            return errorpage

        bodylength = serverheader.getIntValue(["Content-Length", "Content-length"])
        connectionvalue = serverheader.getStrValue(["Connection"])
        readtillclosed = 0
        if connectionvalue and connectionvalue.strip().lower() == "close":
            readtillclosed = 1

        print("\nResponse Header:\n"+"".join(serverheader.data))

        serverbody = body()
        serverbody.read(self.currentSocket, bodylength, serverheader.wasChunked, readtillclosed)

        response = self.constructResponse(serverheader, serverbody)
        if readtillclosed:
            # The server is closing; don't reuse this socket for the next request.
            self._closeServerSocket()
        return response

    def cleanup(self):
        """Close the browser connection and any open server socket."""
        self.connection.close()
        self._closeServerSocket()
        return
    
class spkProxy:
    """Listening front end.

    Accepts browser connections and hands each one to its own
    spkProxyConnection thread.
    """
    def __init__(self):
        self.mylistenport = 8080
        # '' binds every interface, so anyone who can reach this host can use
        # the proxy.  NOTE(review): consider '127.0.0.1' for local-only use.
        self.mylistenhost = ''

    def setPort(self, port):
        """Set the TCP port to listen on; non-integral values are rounded."""
        # round() returns a float on Python 2, which bind() does not accept
        # as a port, so coerce the result to int.
        self.mylistenport = int(round(port))

    def run(self):
        """Listen forever, spawning a proxy thread per accepted connection."""
        listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        # Allow an immediate restart while old connections sit in TIME_WAIT.
        listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        try:
            listener.bind((self.mylistenhost, self.mylistenport))
            listener.listen(5)
            while 1:
                conn, addr = listener.accept()
                print('Connected to by ' + str(addr))
                self.handleConnection(MyConnection(conn))
        finally:
            listener.close()

    def handleConnection(self, connection):
        """Serve *connection* on a new spkProxyConnection thread."""
        spkProxyConnection(connection).start()
#end of class spkProxy    


# Script entry point: run the proxy on the port given as the first
# command-line argument, or on the default port 8080.
if __name__ == '__main__':
    proxy = spkProxy()
    if sys.argv[1:]:
        proxy.setPort(int(sys.argv[1]))
    proxy.run()
