import html
import json
import queue
import socket
import struct
import time
import zlib

from imshow import parallel
class SocketMessage():
    """Base class for messages serialised as zlib-compressed JSON.

    The wire form is ``zlib(json({'type': <type>, 'msg': <data>}))``.
    Subclasses set ``type`` and override :meth:`encode` / :meth:`decode` to
    move their own attributes into and out of ``data``.
    """

    def __init__(self, msg=None):
        """Create a message.

        Args:
            msg: JSON-serialisable payload. When omitted, each instance gets
                its own fresh empty dict.
        """
        # A literal ``{}`` default would be one dict shared by every instance
        # constructed without an argument, so mutations would leak between
        # unrelated messages.
        self.data = {} if msg is None else msg
        self.type = 'empty'

    def encode(self):
        """Hook called by :meth:`tobytes`; subclasses fill ``self.data``."""
        pass

    def decode(self):
        """Hook called by :meth:`frombytes`; subclasses read ``self.data``."""
        pass

    def tobytes(self):
        """Serialise the message.

        Returns:
            bytes: zlib-compressed UTF-8 JSON of ``{'type', 'msg'}``.
        """
        self.encode()
        envelope = {'type': self.type, 'msg': self.data}
        return zlib.compress(json.dumps(envelope).encode('utf-8'))

    def frombytes(self, byte):
        """Load ``data`` from bytes produced by :meth:`tobytes`.

        Only the ``'msg'`` field is read; ``self.type`` is left untouched.

        Args:
            byte: Compressed message bytes.

        Raises:
            zlib.error: If ``byte`` is not valid zlib data.
            ValueError: If the decompressed text is not valid UTF-8 JSON.
            KeyError: If the JSON object has no ``'msg'`` entry.
        """
        envelope = json.loads(zlib.decompress(byte).decode('utf-8'))
        self.data = envelope['msg']
        self.decode()
class SocketServer():
    """TCP server that hands every accepted connection to a handler job."""

    def __init__(self, ip='0.0.0.0', port=12345, message_handler=None):
        """Bind and start listening.

        Args:
            ip: Local address to bind to.
            port: Local TCP port to bind to.
            message_handler: Optional callable ``(clientsocket, addr)`` run
                for each accepted connection. The handler owns the socket.
        """
        self.daemon = parallel.daemon
        self.ip = ip
        self.port = port
        self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.socket.bind((self.ip, self.port))
        self.socket.listen(5)
        self.terminate = False
        self.handler = message_handler

    def handle_message(self, clientsocket, addr):
        """Process one accepted connection.

        Args:
            clientsocket: The connected socket returned by ``accept()``.
            addr: The peer address returned by ``accept()``.
        """
        if self.handler is not None:
            self.handler(clientsocket, addr)
        else:
            # Nobody will ever service this connection; close it instead of
            # leaving the descriptor open until garbage collection.
            clientsocket.close()

    def loop(self):
        """Accept connections until ``self.terminate`` is set."""
        while not self.terminate:
            try:
                clientsocket, addr = self.socket.accept()
            except OSError:
                # accept() fails once the listening socket has been closed;
                # treat that as a normal exit when shutdown was requested.
                if self.terminate:
                    break
                raise
            self.daemon.add_job(
                self.handle_message,
                args=[clientsocket, addr],
                name='Client[{0}]'.format(addr)
            )

    def start(self, back=True):
        """Run the accept loop.

        Args:
            back: When true, schedule the loop on the daemon and return;
                otherwise run it in the calling thread.
        """
        if back:
            self.daemon.add_job(self.loop, name='SocketMainLoop')
        else:
            self.loop()
# What a packet header should contain:
# 1. message id
# 2. packet id
# 3. total number of packets
# 4. size of the payload in bytes
class PacketHeader():
    """Fixed 16-byte header: message id, packet id, packet count, size."""

    # '<' selects standard sizes (4 bytes per 'L') with no padding. A bare
    # 'LLLL' uses the platform's native ``unsigned long``, which is 8 bytes on
    # LP64 systems and therefore does not match the 16-byte header_size.
    # Little-endian keeps the bytes identical to what native packing produces
    # on little-endian platforms with a 4-byte long.
    FORMAT = '<LLLL'

    def __init__(self, mid=0, pid=0, pn=0, sz=0):
        """Create a header.

        Args:
            mid: Message id.
            pid: Index of this packet within the message.
            pn: Total number of packets in the message.
            sz: Size field.
        """
        self.msg_id = mid
        self.pkt_id = pid
        self.pkt_num = pn
        self.msg_sz = sz
        self.header_size = struct.calcsize(self.FORMAT)  # always 16

    def tobytes(self):
        """Return the header packed into exactly ``header_size`` bytes.

        Raises:
            struct.error: If a field does not fit in an unsigned 32-bit int.
        """
        return struct.pack(
            self.FORMAT, self.msg_id, self.pkt_id, self.pkt_num, self.msg_sz)

    def frombytes(self, b):
        """Populate the fields from ``header_size`` packed bytes.

        Args:
            b: Exactly ``header_size`` bytes.

        Returns:
            PacketHeader: ``self``, to allow chaining.

        Raises:
            struct.error: If ``b`` has the wrong length.
        """
        (self.msg_id, self.pkt_id,
         self.pkt_num, self.msg_sz) = struct.unpack(self.FORMAT, b)
        return self
class Packet():
    """A header plus a chunk of payload bytes."""

    def __init__(self, header=None, msg=b''):
        """Create a packet.

        Args:
            header: Header object; a fresh ``PacketHeader()`` when omitted.
            msg: Payload bytes carried by this packet.
        """
        # A ``PacketHeader()`` default in the signature is built once and then
        # shared (and mutated by frombytes/tobytes) across every Packet
        # created without an explicit header, so build one per instance.
        self.header = PacketHeader() if header is None else header
        self.msg = msg
        self.header_size = self.header.header_size

    def frombytes(self, b):
        """Split ``b`` into header and payload and parse the header.

        Args:
            b: Raw packet bytes, header first.

        Returns:
            Packet: ``self``, to allow chaining.
        """
        head, self.msg = b[:self.header_size], b[self.header_size:]
        self.header.frombytes(head)
        return self

    def tobytes(self):
        """Serialise the packet as header bytes followed by the payload.

        The header's ``msg_sz`` is overwritten with the payload length first.
        """
        self.header.msg_sz = len(self.msg)
        return self.header.tobytes() + self.msg
class PacketFactory():
    """Splits byte messages into packets and reassembles them."""

    def __init__(self, max_size=8192, log=None):
        """Create a factory.

        Args:
            max_size: Maximum serialised packet size, header included.
            log: Callable used for diagnostics; defaults to ``print``.

        Raises:
            ValueError: If ``max_size`` leaves no room for payload.
        """
        self.max_size = max_size
        self.id = 0
        self.header_size = PacketHeader().header_size
        if max_size <= self.header_size:
            # Otherwise to_packets would divide by zero or slice backwards.
            raise ValueError(
                'max_size must exceed the header size ({0})'.format(
                    self.header_size))
        self.log = print if log is None else log

    def to_packets(self, msg):
        """Split ``msg`` into packets of at most ``max_size`` bytes each.

        Every call consumes one message id, even for an empty ``msg`` (which
        yields an empty list).

        Args:
            msg: The complete message bytes.

        Returns:
            list: ``Packet`` objects in sending order.
        """
        capacity = self.max_size - self.header_size
        # Integer ceiling division; float division loses precision for very
        # large messages.
        num_packets = (len(msg) + capacity - 1) // capacity
        packets = [
            Packet(PacketHeader(self.id, i, num_packets),
                   msg[i * capacity:(i + 1) * capacity])
            for i in range(num_packets)
        ]
        self.id += 1
        return packets

    def from_packets(self, packets):
        """Reassemble a message from its packets.

        All packets must share one message id, arrive in ``pkt_id`` order
        starting at 0, and declare a ``pkt_num`` equal to ``len(packets)``.

        Args:
            packets: Sequence of packets making up one message.

        Returns:
            bytes | None: The joined payload, or ``None`` when ``packets`` is
            empty or fails validation (the reason is logged).
        """
        num_packet = len(packets)
        self.log('Packet Number:', num_packet)
        if num_packet == 0:
            return None
        msg_id = packets[0].header.msg_id
        chunks = []
        for expected_id, packet in enumerate(packets):
            header = packet.header
            if num_packet != header.pkt_num:
                self.log('Uncorrect Package Number')
                self.log('get {0} while it should be {1}'.format(header.pkt_num, num_packet))
                return None
            if msg_id != header.msg_id:
                self.log('Uncorrect Message id')
                # Report the offending message id (previously pkt_num was
                # printed here by mistake).
                self.log('get {0} while it should be {1}'.format(header.msg_id, msg_id))
                return None
            if expected_id != header.pkt_id:
                self.log('Uncorrect pkt id')
                self.log('get {0} while it should be {1}'.format(header.pkt_id, expected_id))
                return None
            chunks.append(packet.msg)
        # Single join instead of repeated ``+=`` on bytes.
        return b''.join(chunks)
class AuthMessage(SocketMessage):
    """Authentication handshake message carrying a token and a status code.

    Status codes:
        0: client is requesting authentication
        1: authentication succeeded
        2: authentication failed
    """

    def __init__(self, token='None'):
        """Create an auth request for ``token`` with status 0."""
        super().__init__()
        self.type = 'auth'
        self.token = token
        self.stat = 0

    def encode(self):
        """Rebuild ``data`` from the current token and status."""
        self.data = {'token': self.token, 'status': self.stat}

    def decode(self):
        """Read the token and status back out of ``data``."""
        payload = self.data
        self.token = payload['token']
        self.stat = payload['status']
class SocketConnection():
    """Message-oriented connection layered on a TCP socket.

    Outgoing messages are split into packets by a ``PacketFactory``. A
    background job (:meth:`recv_bare`) reads the socket, reassembles packets
    into messages and queues them for :meth:`recv`.

    Log levels: 0 debug, 1 info, 2 warning, 3 error. Entries below
    ``loglevel`` are suppressed.
    """

    def __init__(self, daemon, log_prefix=''):
        """Create an unconnected connection.

        Args:
            daemon: Job host providing ``add_job``; runs the receiver.
            log_prefix: Text printed before every log entry.
        """
        self.sock = None
        self.daemon = daemon
        self.messages = queue.Queue()
        self.terminated = False
        self.factory = PacketFactory(log=self.log)
        self.header_size = PacketHeader().header_size
        self.loglevel = 2
        self.log_prefix = log_prefix

    def log(self, *args, level=0, end='\n'):
        """Print ``args`` after the prefix if ``level`` is high enough.

        Args:
            *args: Values to print, each followed by a single space.
            level: Severity of this entry.
            end: Text written after the last value.
        """
        if level >= self.loglevel:
            print(self.log_prefix, *args, end=' ' + end)

    def start(self):
        """Schedule the background receiver on the daemon."""
        self.jid = self.daemon.add_job(self.recv_bare, name='connection')

    def SetSock(self, sock):
        """Adopt an already-connected socket and start receiving."""
        self.sock = sock
        self.start()

    def connect(self, host, port):
        """Open a TCP connection to ``(host, port)`` and start receiving.

        Does nothing except log a warning when a socket is already set.
        """
        if self.sock is not None:
            self.log('Socket has already established, ingoring connect.', level=2)
            # The receiver for the existing socket is already running; a
            # second one would compete for the same byte stream.
            return
        self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.sock.connect((host, port))
        self.log('Sock Connected', level=1)
        self.start()

    def auth(self, token):
        """Client side of the handshake: send ``token``, await the verdict.

        The connection is closed when authentication fails.
        """
        if self.sock is None:
            self.log('Socket not established, unable to auth.', level=3)
            return None
        auth_msg = AuthMessage(token)
        self.log('Sending Authentication Message.')
        self.send(auth_msg.tobytes())
        self.log('Waiting Authentication Status.')
        msg = self.recv()
        if msg is None:
            # recv() yields None once the connection has terminated.
            self.log('Connection closed before authentication status arrived.', level=3)
            return None
        auth_msg.frombytes(msg)
        self.log('Status Recived! auth_status =', auth_msg.__dict__)
        if auth_msg.stat == 1:
            self.log('Auth Success!', level=1)
        else:
            self.log('Auth Failed!', level=3)
            self.close()

    def WaitAuth(self, token):
        """Server side of the handshake: check the peer's token.

        Replies with status 1 on a match; otherwise replies with status 2
        and closes the connection.

        Args:
            token: The token the peer must present.
        """
        msg = self.recv()
        if msg is None:
            self.log('Connection closed before authentication was received.', level=3)
            return
        request = AuthMessage()
        try:
            request.frombytes(msg)
            accepted = request.token == token
        except (zlib.error, ValueError, KeyError, TypeError):
            # The peer is unauthenticated at this point, so malformed input
            # is treated as a failed attempt rather than an unhandled error.
            accepted = False
        if not accepted:
            self.log('Authentication Failed!', level=1)
            reply = AuthMessage('InvalidAuth')
            reply.stat = 2
            self.send(reply.tobytes())
            self.close()
        else:
            self.log('Authentication Success!', level=1)
            reply = AuthMessage('Welcome')
            reply.stat = 1
            self.send(reply.tobytes())

    def send(self, msg):
        """Send one message.

        Args:
            msg: Message bytes.

        Returns:
            bool: False if the connection is terminated or has no socket,
            True once every packet has been written.
        """
        if self.terminated or self.sock is None:
            return False
        packets = self.factory.to_packets(msg)
        self.log('spliting message to {0} packets'.format(len(packets)))
        for packet in packets:
            # send() may write only part of the buffer; sendall() retries
            # until everything is written or an error is raised.
            self.sock.sendall(packet.tobytes())
        return True

    def commit_message(self, packets):
        """Reassemble ``packets`` and queue the message if it is valid."""
        self.log('Generating final packet...')
        full_msg = self.factory.from_packets(packets)
        if full_msg is not None:
            self.log('Valid package!')
            self.messages.put(full_msg)
        else:
            self.log('Invalid package')

    def _take_packets(self, buffer):
        """Remove and return every complete packet at the front of ``buffer``.

        Args:
            buffer: ``bytearray`` of received bytes; consumed in place.

        Returns:
            list: Parsed ``Packet`` objects, possibly empty.
        """
        ready = []
        while len(buffer) >= self.header_size:
            header = PacketHeader().frombytes(bytes(buffer[:self.header_size]))
            # Packet.tobytes() stores the payload length in msg_sz, which
            # gives the frame boundary.
            frame_end = self.header_size + header.msg_sz
            if len(buffer) < frame_end:
                break
            ready.append(Packet(header, bytes(buffer[self.header_size:frame_end])))
            del buffer[:frame_end]
        return ready

    def recv_bare(self):
        """Receiver loop: read the socket and queue completed messages.

        TCP is a byte stream, so one ``recv()`` may return part of a packet
        or several packets at once. Bytes are therefore buffered and framed
        by header. Runs until the connection is terminated.
        """
        buffer = bytearray()
        msg_id = None
        packets = []
        while not self.terminated:
            try:
                chunk = self.sock.recv(8192)
            except OSError:
                # Covers resets/aborts and the socket being closed locally
                # while this job is blocked in recv().
                self.log('Connection Stopped')
                self.close()
                break
            if not chunk:
                # Empty read: the peer closed the connection.
                self.close()
                break
            buffer += chunk
            for pkt in self._take_packets(buffer):
                if msg_id != pkt.header.msg_id:
                    # First packet, or a new message began before the
                    # previous one completed: drop the partial one.
                    msg_id = pkt.header.msg_id
                    packets = []
                packets.append(pkt)
                if pkt.header.pkt_id == pkt.header.pkt_num - 1:
                    self.log('Finished')
                    self.commit_message(packets)
                    msg_id = None
                    packets = []

    def recv(self):
        """Block until a message is available or the connection ends.

        Returns:
            bytes | None: The next message, or ``None`` if the connection
            terminated with nothing queued.
        """
        self.log('Getting messages')
        msg = None
        while msg is None and not self.terminated:
            try:
                msg = self.messages.get(timeout=0.1)
            except queue.Empty:
                continue
        if msg is None:
            # A message may have been queued just before termination (for
            # example a final status followed by a close); deliver it.
            try:
                msg = self.messages.get_nowait()
            except queue.Empty:
                pass
        self.log('Message Get Finished')
        return msg

    def close(self):
        """Mark the connection terminated and close the socket, if any."""
        self.terminated = True
        if self.sock is not None:
            self.sock.close()
class HttpServer(SocketServer):
    """Minimal HTTP server that echoes each request back as an HTML page."""

    def __init__(self, ip='0.0.0.0', port=80):
        """Bind the server.

        Args:
            ip: Local address to bind to.
            port: Local TCP port; defaults to 80.
        """
        super().__init__(ip=ip, port=port)
        # HTTP header lines are terminated by CRLF; {0} is the body length
        # in bytes.
        self.header = (
            'HTTP/1.1 200 OK\r\n'
            'Server: NaiveHttpServer\r\n'
            'Connection: close\r\n'
            'Content-Length: {0}\r\n'
            'Content-Type: text/html; charset=utf-8\r\n'
            '\r\n'
        )

    def handle_message(self, clientsocket, addr):
        """Answer one request with a page showing the peer and its request.

        Args:
            clientsocket: Connected client socket; always closed on return.
            addr: Peer address, shown as the page heading.
        """
        # NOTE(review): the HTML tags in this method were missing from the
        # source under review; '<br>' and '<h1>' are a reconstruction.
        # Confirm against the intended markup.
        try:
            request = clientsocket.recv(8192)
            # The request is untrusted: tolerate invalid UTF-8 and escape it
            # before embedding it in HTML.
            text = html.escape(request.decode('utf-8', errors='replace'))
            text = text.replace('\n', '<br>')
            info = '<h1>{0}</h1>'.format(html.escape(str(addr)))
            info += '{0}<br>\n'.format(text)
            info += '\n'
            body = info.encode('utf-8')
            # Content-Length counts bytes, not characters.
            header = self.header.format(len(body)).encode('utf-8')
            clientsocket.sendall(header + body)
        finally:
            clientsocket.close()


class BridgeConnection():
    """Relays bytes in both directions between two connected sockets."""

    def __init__(self, client, server, num_threads=8):
        """Create a bridge.

        Args:
            client: Socket connected to the originating client.
            server: Socket connected to the destination server.
            num_threads: Size of the worker host running the relay loops.
        """
        self.client = client
        self.server = server
        self.daemon = parallel.ParallelHost(num_threads)
        self.exit = False
        self.client_terminated = False
        self.server_terminated = False

    def is_terminated(self):
        """Return True once either relay direction has ended."""
        # Ending on either side avoids waiting forever on a client whose
        # destination has already gone away.
        return self.client_terminated or self.server_terminated

    def client_recv_handler(self, msg):
        """Forward bytes received from the client to the server."""
        self.server.sendall(msg)

    def server_recv_handler(self, msg):
        """Forward bytes received from the server to the client."""
        self.client.sendall(msg)

    def stop(self):
        """Stop the workers and close both sockets."""
        self.exit = True
        self.daemon.stop('kill')
        self.server.close()
        self.client.close()
        print('connection terminated successfully.')

    def client_recv(self):
        """Relay client -> server until the client closes or an error occurs."""
        while not self.exit:
            try:
                msg = self.client.recv(8192)
                if len(msg) == 0:
                    break
                # Forward in this thread: handing each chunk to a worker pool
                # presumably allows chunks to be sent out of order, which
                # would corrupt a byte stream.
                self.client_recv_handler(msg)
            except OSError:
                break
        self.exit = True
        self.client_terminated = True
        print('Client Recv terminated.')

    def server_recv(self):
        """Relay server -> client until the server closes or an error occurs."""
        while not self.exit:
            try:
                msg = self.server.recv(8192)
                if len(msg) == 0:
                    break
                self.server_recv_handler(msg)
            except OSError:
                break
        self.exit = True
        self.server_terminated = True
        print('Server Recv terminated.')

    def run(self):
        """Start both relay loops and block until the bridge has ended."""
        self.daemon.add_job(self.server_recv)
        self.daemon.add_job(self.client_recv)
        while not self.is_terminated():
            time.sleep(1)
        self.stop()


def SendMessage(host, port, msg):
    """Connect to ``(host, port)``, send ``msg`` (bytes) and disconnect."""
    # The context manager closes the socket even if connect/sendall raises.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.connect((host, port))
        s.sendall(msg)
        print('Message Sent!')


def portforward(dst_ip, dst_port, listen_ip, listen_port):
    """Forward connections on ``listen_ip:listen_port`` to ``dst_ip:dst_port``.

    Starts a background ``SocketServer``; each accepted client is bridged to
    a new connection to the destination.
    """
    def port_forward_handler(clientsocket, addr):
        print('Handling message from clinet', addr)
        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            s.connect((dst_ip, dst_port))
        except OSError:
            # Destination unreachable: release both sockets before
            # propagating the error.
            s.close()
            clientsocket.close()
            raise
        print('Connect to dst server success!')
        connection = BridgeConnection(clientsocket, s)
        connection.run()
        print('handler exited')

    server = SocketServer(message_handler=port_forward_handler, ip=listen_ip, port=listen_port)
    server.start()


# Example: a port forwarding server built on this socket server.
# if __name__ == '__main__':
#     portforward('192.168.233.101', 22, '0.0.0.0', 30001)
#     portforward('192.168.233.102', 22, '0.0.0.0', 30002)
#     con = console.console('PortForward')
#     con.interactive()

if __name__ == '__main__':
    server = HttpServer()
    server.start()
    # The server runs as a background job; keep the main thread alive.
    time.sleep(10000)