import socket
import zlib
import json
import struct
import queue

from imshow import parallel


class SocketMessage():
    """Base class for messages serialized as zlib-compressed JSON.

    The wire form is the UTF-8 JSON object ``{"type": self.type,
    "msg": self.data}``, compressed with zlib. Subclasses override
    :meth:`encode` (fill ``self.data`` from their own attributes before
    serialization) and :meth:`decode` (restore their attributes from
    ``self.data`` after deserialization).
    """

    def __init__(self, msg=None):
        # A fresh dict per instance: the former ``msg={}`` default was one
        # dict object shared by every instance constructed without arguments.
        self.data = {} if msg is None else msg
        self.type = 'empty'

    def encode(self):
        """Hook: populate ``self.data`` before :meth:`tobytes` serializes it."""
        pass

    def decode(self):
        """Hook: populate attributes from ``self.data`` after :meth:`frombytes`."""
        pass

    def tobytes(self):
        """Return the compressed wire representation of this message."""
        self.encode()
        payload = json.dumps({'type': self.type, 'msg': self.data})
        return zlib.compress(payload.encode('utf-8'))

    def frombytes(self, byte):
        """Load ``self.data`` from bytes produced by :meth:`tobytes`, then decode.

        Malformed input propagates the underlying error (zlib.error,
        UnicodeDecodeError, ValueError from json, KeyError for a missing
        ``'msg'`` key).
        """
        js = json.loads(zlib.decompress(byte).decode('utf-8'))
        self.data = js['msg']
        self.decode()


class SocketServer():
    """Minimal TCP server that runs each accepted connection as a daemon job.

    ``parallel.daemon`` (from the project's ``imshow`` package) is used for
    background work via ``add_job``; its exact threading model is not visible
    here.
    """

    def __init__(self, ip='0.0.0.0', port=12345, message_handler=None):
        """Bind to ``(ip, port)`` and start listening (backlog 5).

        message_handler: optional callable ``(clientsocket, addr)`` invoked by
        the default :meth:`handle_message` for every accepted connection.
        """
        self.daemon = parallel.daemon
        self.ip = ip
        self.port = port
        self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.socket.bind((self.ip, self.port))
        self.socket.listen(5)
        self.terminate = False
        self.handler = message_handler

    def handle_message(self, clientsocket, addr):
        """Handle one accepted connection; subclasses may override."""
        if self.handler is not None:
            self.handler(clientsocket, addr)
        else:
            # Nothing will ever use this connection: close it instead of
            # leaking the file descriptor.
            clientsocket.close()

    def loop(self):
        """Accept connections until :meth:`stop` is called."""
        while not self.terminate:
            try:
                clientsocket, addr = self.socket.accept()
            except OSError:
                # stop() closes the listening socket, which makes a blocked
                # accept() fail; any other failure is re-raised as before.
                if self.terminate:
                    break
                raise
            self.daemon.add_job(
                self.handle_message,
                args=[clientsocket, addr],
                name='Client[{0}]'.format(addr)
            )

    def start(self, back=True):
        """Run :meth:`loop` as a background job (default) or in this thread."""
        if back:
            self.daemon.add_job(self.loop, name='SocketMainLoop')
        else:
            self.loop()

    def stop(self):
        """Stop accepting connections and close the listening socket."""
        self.terminate = True
        try:
            # On Linux, shutdown() is what wakes a thread blocked in accept();
            # close() alone may not.
            self.socket.shutdown(socket.SHUT_RDWR)
        except OSError:
            pass  # not connected / already closed: nothing to wake
        self.socket.close()


# Packet header layout (see PacketHeader):
#   1. message id
#   2. packet id (index of the packet within its message)
#   3. total number of packets in the message
#   4. size (Packet.tobytes stores the packet's payload length here)
class PacketHeader():
    """Fixed-size header prepended to every packet.

    Fields, in wire order:
      msg_id  -- id of the message the packet belongs to
      pkt_id  -- 0-based index of the packet within its message
      pkt_num -- total number of packets in the message
      msg_sz  -- payload size of *this* packet (Packet.tobytes stores
                 ``len(self.msg)`` here, not the size of the whole message)
    """

    # Four unsigned 32-bit little-endian integers: 16 bytes on every platform.
    # The former native 'LLLL' format is 16 bytes only where a C long is
    # 4 bytes (e.g. Windows); on 64-bit Linux/macOS it packs to 32 bytes,
    # which contradicted header_size and made frombytes() raise struct.error.
    # '<IIII' is byte-identical to the old layout on little-endian platforms
    # with 4-byte longs, so peers there still interoperate.
    FORMAT = '<IIII'

    def __init__(self, mid=0, pid=0, pn=0, sz=0):
        self.msg_id = mid
        self.pkt_id = pid
        self.pkt_num = pn
        self.msg_sz = sz
        self.header_size = struct.calcsize(self.FORMAT)

    def tobytes(self):
        """Serialize the four fields into ``header_size`` bytes."""
        return struct.pack(self.FORMAT, self.msg_id, self.pkt_id,
                           self.pkt_num, self.msg_sz)

    def frombytes(self, b):
        """Load the fields from exactly ``header_size`` bytes; return self."""
        (self.msg_id, self.pkt_id,
         self.pkt_num, self.msg_sz) = struct.unpack(self.FORMAT, b)
        return self


class Packet():
    """A PacketHeader followed by a chunk of payload bytes."""

    def __init__(self, header=None, msg=b''):
        # A fresh header per packet: the former ``header=PacketHeader()``
        # default was one header object shared (and mutated) by every Packet
        # created without an explicit header.
        self.header = PacketHeader() if header is None else header
        self.msg = msg
        self.header_size = self.header.header_size

    def frombytes(self, b):
        """Split ``b`` into header and payload; return self."""
        header = b[:self.header_size]
        self.msg = b[self.header_size:]
        self.header.frombytes(header)
        return self

    def tobytes(self):
        """Return header + payload, recording the payload length in the header."""
        self.header.msg_sz = len(self.msg)
        return self.header.tobytes() + self.msg


class PacketFactory():
    """Splits byte messages into packets and reassembles them."""

    def __init__(self, max_size=8192, log=None):
        """max_size: maximum packet size in bytes, header included.
        log: print-like callable for diagnostics (defaults to ``print``).
        """
        self.max_size = max_size
        self.id = 0
        self.header_size = PacketHeader().header_size
        self.log = print if log is None else log

    def to_packets(self, msg):
        """Split ``msg`` into packets sharing a fresh message id.

        An empty message still yields one (empty) packet so the receiver
        observes it instead of waiting forever.

        Raises ValueError if ``max_size`` leaves no room for payload.
        """
        capacity = self.max_size - self.header_size
        if capacity <= 0:
            raise ValueError('max_size must exceed the header size '
                             '({0} bytes)'.format(self.header_size))
        num_packets = max(1, -(-len(msg) // capacity))  # integer ceil division
        packets = [
            Packet(PacketHeader(self.id, i, num_packets),
                   msg[i * capacity:(i + 1) * capacity])
            for i in range(num_packets)
        ]
        # Wrap so the id always fits the 32-bit header field.
        self.id = (self.id + 1) & 0xFFFFFFFF
        return packets

    def from_packets(self, packets):
        """Reassemble ``packets`` into the original bytes.

        Returns None (after logging why) when the list is empty, or when
        packet count, message id, or packet order is inconsistent.
        """
        num_packet = len(packets)
        self.log('Packet Number:', num_packet)
        if num_packet == 0:
            return None
        msg_id = packets[0].header.msg_id
        for packet_id, packet in enumerate(packets):
            header = packet.header
            if num_packet != header.pkt_num:
                self.log('Uncorrect Package Number')
                self.log('get {0} while it should be {1}'.format(header.pkt_num, num_packet))
                return None
            if msg_id != header.msg_id:
                self.log('Uncorrect Message id')
                # Report the offending message id (previously logged pkt_num).
                self.log('get {0} while it should be {1}'.format(header.msg_id, msg_id))
                return None
            if packet_id != header.pkt_id:
                self.log('Uncorrect pkt id')
                self.log('get {0} while it should be {1}'.format(header.pkt_id, packet_id))
                return None
        return b''.join(packet.msg for packet in packets)


class AuthMessage(SocketMessage):
    """Token-based authentication handshake message.

    stat:
      0: auth request sent by the client
      1: auth success
      2: auth failed
    """

    def __init__(self, token='None'):
        super(AuthMessage, self).__init__()
        self.token = token
        self.type = 'auth'
        self.stat = 0

    def encode(self):
        self.data = {'token': self.token, 'status': self.stat}

    def decode(self):
        self.token = self.data['token']
        self.stat = self.data['status']


def _tokens_match(received, expected):
    """Compare auth tokens; constant-time when both are strings."""
    import hmac
    if isinstance(received, str) and isinstance(expected, str):
        # compare_digest avoids leaking, through response timing, how many
        # leading characters of an attacker-supplied token were correct.
        return hmac.compare_digest(received.encode('utf-8'),
                                   expected.encode('utf-8'))
    return received == expected


class SocketConnection():
    """Message-oriented connection over a TCP socket.

    Outgoing byte messages are split into packets by a PacketFactory.
    A background job (:meth:`recv_bare`) reads packets, reassembles complete
    messages, and puts them on ``self.messages`` for :meth:`recv`.

    daemon: object exposing ``add_job(fn, name=...)`` (presumably the
    ``imshow.parallel`` daemon -- not visible here).
    """

    def __init__(self, daemon, log_prefix=''):
        self.sock = None
        self.daemon = daemon
        self.messages = queue.Queue()
        self.terminated = False
        self.factory = PacketFactory(log=self.log)
        self.header_size = PacketHeader().header_size
        # log level:
        # 0 : debug
        # 1 : message or info
        # 2 : warning
        # 3 : error
        self.loglevel = 2
        self.log_prefix = log_prefix

    def log(self, *args, level=0, end='\n'):
        """Print ``args`` with the prefix if ``level >= self.loglevel``."""
        if level >= self.loglevel:
            print(self.log_prefix, end=' ')
            for msg in args:
                print(str(msg), end=' ')
            print(end, end='')

    def start(self):
        """Start the background receiver job."""
        self.jid = self.daemon.add_job(self.recv_bare, name='connection')

    def SetSock(self, sock):
        """Adopt an already-connected socket and start receiving."""
        self.sock = sock
        self.start()

    def connect(self, host, port):
        """Connect to ``(host, port)`` and start receiving."""
        if self.sock is None:
            self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            self.sock.connect((host, port))
            self.log('Sock Connected', level=1)
            # Start only for a new connection: an existing socket already has
            # a receiver, and a second one would corrupt packet framing.
            self.start()
        else:
            self.log('Socket has already established, ingoring connect.', level=2)

    def auth(self, token):
        """Client side of the handshake: send ``token`` and await the verdict.

        Closes the connection when authentication fails or no valid reply
        arrives. Returns None.
        """
        if self.sock is None:
            self.log('Socket not established, unable to auth.', level=3)
            return None
        auth_msg = AuthMessage(token)
        self.log('Sending Authentication Message.')
        self.send(auth_msg.tobytes())
        self.log('Waiting Authentication Status.')
        msg = self.recv()
        if msg is None:
            # The connection closed before the server answered.
            self.log('Auth Failed!', level=3)
            self.close()
            return None
        try:
            auth_msg.frombytes(msg)
        except (zlib.error, UnicodeDecodeError, ValueError, KeyError, TypeError):
            self.log('Auth Failed!', level=3)
            self.close()
            return None
        self.log('Status Recived! auth_status =', auth_msg.__dict__)
        if auth_msg.stat == 1:
            self.log('Auth Success!', level=1)
        else:
            self.log('Auth Failed!', level=3)
            self.close()

    def WaitAuth(self, token):
        """Server side of the handshake: verify the peer's token and reply.

        A missing or malformed request counts as a failed authentication;
        on failure the connection is closed.
        """
        msg = self.recv()
        auth_msg = AuthMessage()
        authenticated = False
        if msg is not None:
            try:
                auth_msg.frombytes(msg)
            except (zlib.error, UnicodeDecodeError, ValueError, KeyError, TypeError):
                pass  # untrusted peer sent garbage: treat as a failed auth
            else:
                authenticated = _tokens_match(auth_msg.token, token)
        if not authenticated:
            self.log('Authentication Failed!', level=1)
            reply = AuthMessage('InvalidAuth')
            reply.stat = 2
            self.send(reply.tobytes())
            self.close()
        else:
            self.log('Authentication Success!', level=1)
            reply = AuthMessage('Welcome')
            reply.stat = 1
            self.send(reply.tobytes())

    def send(self, msg):
        """Send byte message ``msg``; return False if the connection is closed."""
        if self.terminated:
            return False
        packets = self.factory.to_packets(msg)
        self.log('spliting message to {0} packets'.format(len(packets)))
        for packet in packets:
            # sendall: send() may transmit only part of the buffer.
            self.sock.sendall(packet.tobytes())
        return True

    def commit_message(self, packets):
        """Reassemble ``packets`` and queue the message if it is valid."""
        self.log('Generating final packet...')
        full_msg = self.factory.from_packets(packets)
        if full_msg is not None:
            self.log('Valid package!')
            self.messages.put(full_msg)
        else:
            self.log('Invalid package')

    def _recv_exact(self, size):
        """Read exactly ``size`` bytes; return None on EOF or socket error."""
        chunks = []
        remaining = size
        while remaining > 0:
            try:
                # Bounded request: recv(n) allocates an n-byte buffer up front.
                chunk = self.sock.recv(min(remaining, 65536))
            except OSError:
                self.log('Connection Stopped')
                return None
            if not chunk:
                return None
            chunks.append(chunk)
            remaining -= len(chunk)
        return b''.join(chunks)

    def recv_bare(self):
        """Receiver loop: read framed packets and commit complete messages.

        TCP is a byte stream, so a single recv() may return a partial packet
        or several packets. Each packet is therefore read as exactly one
        header followed by exactly ``msg_sz`` payload bytes.
        """
        msg_id = None
        packets = []
        # The previous recv(8192) implicitly capped packets at the factory's
        # max_size; keep that cap so a bogus size field cannot make us buffer
        # gigabytes.
        max_payload = self.factory.max_size - self.header_size
        while not self.terminated:
            raw_header = self._recv_exact(self.header_size)
            if raw_header is None:
                self.close()
                break
            header = PacketHeader().frombytes(raw_header)
            if header.msg_sz > max_payload:
                self.log('Invalid packet size', header.msg_sz, level=3)
                self.close()
                break
            payload = self._recv_exact(header.msg_sz)
            if payload is None:
                self.close()
                break
            if msg_id is not None and msg_id != header.msg_id:
                # A new message started before the previous one completed:
                # drop the partial one, but still process this packet (it may
                # complete a single-packet message).
                packets = []
            msg_id = header.msg_id
            packets.append(Packet(header, payload))
            # '>=' rather than '==': an out-of-range pkt_id (or pkt_num == 0)
            # is committed, rejected by from_packets, and discarded instead of
            # accumulating forever.
            if header.pkt_id >= header.pkt_num - 1:
                self.log('Finished')
                self.commit_message(packets)
                msg_id = None
                packets = []

    def recv(self):
        """Block until a message arrives; return None once the connection is closed.

        Messages queued before the connection closed are still returned.
        """
        self.log('Getting messages')
        while True:
            try:
                msg = self.messages.get(timeout=0.1)
                break
            except queue.Empty:
                if not self.terminated:
                    continue
            # Terminated: the receiver may have queued a final message right
            # before closing (e.g. an auth verdict), so look once more.
            try:
                msg = self.messages.get_nowait()
                break
            except queue.Empty:
                return None
        self.log('Message Get Finished')
        return msg

    def close(self):
        """Mark the connection terminated and close the socket (idempotent)."""
        self.terminated = True
        if self.sock is None:
            return
        try:
            # shutdown() wakes a recv() blocked in the receiver job; on Linux,
            # close() alone may leave it blocked.
            self.sock.shutdown(socket.SHUT_RDWR)
        except OSError:
            pass  # already disconnected or closed
        self.sock.close()


class HttpServer(SocketServer):
    """Toy HTTP server that echoes the client's address and raw request."""

    def __init__(self, ip='0.0.0.0', port=80):
        super(HttpServer, self).__init__(ip=ip, port=port)
        # HTTP requires CRLF line endings in the status line and headers.
        self.header = ('HTTP/1.1 200 OK\r\n'
                       'Server: NaiveHttpServer\r\n'
                       'Connection: close\r\n'
                       'Content-Length: {0}\r\n'
                       'Content-Type: text/html; charset=utf-8\r\n'
                       '\r\n')

    def handle_message(self, clientsocket, addr):
        """Read one request, reply with an HTML page, close the connection."""
        import html
        try:
            request = clientsocket.recv(8192)
            # Untrusted bytes: never fail on invalid UTF-8, and escape before
            # embedding in HTML (the raw request was previously injected as-is).
            text = html.escape(request.decode('utf-8', errors='replace'))
            # A single replace is enough; the former "while '\n' in text"
            # loop never terminates if the replacement contains '\n'.
            text = text.replace('\n', '<br/>')
            # NOTE(review): the page's original markup was lost in this copy
            # of the file (only the text survived); the tags below are a
            # reconstruction -- confirm against version history.
            info = (
                '<html><head><title>Hello, World</title></head><body>'
                '<h2>Your IP Address &amp; Port</h2>\n'
                '<p>{0}</p>'
                '<h2>Your Request</h2>\n'
                '<p>{1}</p>\n'
                '</body></html>\n'
            ).format(html.escape(str(addr)), text)
            body = info.encode('utf-8')
            # Content-Length counts bytes, not characters.
            header = self.header.format(len(body))
            clientsocket.sendall(header.encode('utf-8') + body)
        finally:
            clientsocket.close()


class BridgeConnection():
    """Pipes bytes in both directions between two connected sockets."""

    def __init__(self, client, server, num_threads=8):
        self.client = client
        self.server = server
        self.daemon = parallel.ParallelHost(num_threads)
        self.exit = False
        self.client_terminated = False
        self.server_terminated = False

    def is_terminated(self):
        """True once either direction has stopped.

        Previously only the client side was checked, so a server-side
        disconnect left the bridge up until the client sent or closed.
        """
        return self.client_terminated or self.server_terminated

    def client_recv_handler(self, msg):
        """Forward a chunk received from the client to the server."""
        self.server.sendall(msg)

    def stop(self):
        """Stop the worker jobs and close both sockets."""
        self.exit = True
        # Wake any recv() still blocked in a worker before stopping the jobs.
        for sock in (self.server, self.client):
            try:
                sock.shutdown(socket.SHUT_RDWR)
            except OSError:
                pass  # already disconnected
        self.daemon.stop('kill')
        self.server.close()
        # The client socket used to be left open here.
        self.client.close()
        print('connection terminated successfully.')

    def client_recv(self):
        """Pump client -> server until either side ends."""
        while not self.exit:
            try:
                msg = self.client.recv(8192)
            except OSError:
                break
            if len(msg) == 0:
                break
            # Forward inline rather than via a pool job: concurrent jobs
            # could send chunks out of order and corrupt the stream.
            try:
                self.client_recv_handler(msg)
            except OSError:
                break
        self.exit = True
        self.client_terminated = True
        print('Client Recv terminated.')

    def server_recv_handler(self, msg):
        """Forward a chunk received from the server to the client."""
        self.client.sendall(msg)

    def server_recv(self):
        """Pump server -> client until either side ends."""
        while not self.exit:
            try:
                msg = self.server.recv(8192)
            except OSError:
                break
            if len(msg) == 0:
                break
            # Inline forwarding preserves byte order (see client_recv).
            try:
                self.server_recv_handler(msg)
            except OSError:
                break
        self.exit = True
        self.server_terminated = True
        print('Server Recv terminated.')

    def run(self):
        """Run both pumps and block until the bridge terminates."""
        # Local import: 'time' was previously only bound by the __main__
        # block, so run() raised NameError when this module was imported.
        import time
        self.daemon.add_job(self.server_recv)
        self.daemon.add_job(self.client_recv)
        while not self.is_terminated():
            time.sleep(1)
        self.stop()


def SendMessage(host, port, msg):
    """Open a TCP connection to ``(host, port)``, send ``msg`` (bytes), close."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.connect((host, port))
        s.sendall(msg)
        print('Message Sent!')


def portforward(dst_ip, dst_port, listen_ip, listen_port):
    """Forward every connection on ``(listen_ip, listen_port)`` to ``(dst_ip, dst_port)``.

    Starts the listener in the background and returns the SocketServer.
    """
    def port_forward_handler(clientsocket, addr):
        print('Handling message from clinet', addr)
        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            s.connect((dst_ip, dst_port))
        except OSError as err:
            # Don't leak either socket when the destination is unreachable.
            print('Connect to dst server failed:', err)
            s.close()
            clientsocket.close()
            return
        print('Connect to dst server success!')
        connection = BridgeConnection(clientsocket, s)
        connection.run()
        print('handler exited')

    server = SocketServer(message_handler=port_forward_handler,
                          ip=listen_ip, port=listen_port)
    server.start()
    return server

# if __name__ == '__main__':
#     # let's try a port forwarding server using this socket server.
#     portforward('192.168.233.101', 22, '0.0.0.0', 30001)
#     portforward('192.168.233.102', 22, '0.0.0.0', 30002)
#     con = console.console('PortForward')
#     con.interactive()


if __name__ == '__main__':
    server = HttpServer()
    server.start()
    import time
    time.sleep(10000)