447 lines
11 KiB
Python
447 lines
11 KiB
Python
import hmac
import html
import json
import queue
import socket
import struct
import time
import zlib

from imshow import parallel
|
|
|
|
class SocketMessage():
    """Base class for messages serialised as zlib-compressed JSON.

    The wire form produced by :meth:`tobytes` is
    ``zlib.compress(json.dumps({'type': self.type, 'msg': self.data}))``
    encoded as UTF-8. Subclasses override :meth:`encode` to populate
    ``self.data`` before serialisation and :meth:`decode` to read their
    attributes back out of ``self.data`` after deserialisation.
    """

    def __init__(self, msg=None):
        """Create a message.

        Args:
            msg: JSON-serialisable payload. It is stored by reference.
                When omitted, a new empty dict is created for this instance.
        """
        # A literal ``{}`` default would be created once and shared by every
        # default-constructed message, so mutations would leak between them.
        self.data = {} if msg is None else msg
        self.type = 'empty'

    def encode(self):
        """Hook called by :meth:`tobytes` before serialising; no-op here."""
        pass

    def decode(self):
        """Hook called by :meth:`frombytes` after ``self.data`` is set; no-op here."""
        pass

    def tobytes(self):
        """Serialise the message and return the compressed bytes."""
        self.encode()
        envelope = {'type': self.type, 'msg': self.data}
        return zlib.compress(json.dumps(envelope).encode('utf-8'))

    def frombytes(self, byte):
        """Load ``self.data`` from bytes produced by :meth:`tobytes`.

        Only the ``'msg'`` entry of the envelope is read; the ``'type'``
        entry is ignored and ``self.type`` is left unchanged.

        Raises:
            zlib.error: if ``byte`` is not valid zlib data.
            ValueError: if the decompressed data is not valid UTF-8 JSON.
            KeyError: if the envelope has no ``'msg'`` entry.
        """
        envelope = json.loads(zlib.decompress(byte).decode('utf-8'))
        self.data = envelope['msg']
        self.decode()
|
|
|
|
class SocketServer():
    """TCP listener that hands each accepted connection to a daemon job.

    The listening socket is created, bound and put into listening mode
    (backlog 5) in the constructor. Accepted connections are passed to
    ``message_handler(clientsocket, addr)`` through :meth:`handle_message`.
    """

    def __init__(self, ip='0.0.0.0', port=12345, message_handler=None):
        self.daemon = parallel.daemon
        self.ip = ip
        self.port = port
        self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.socket.bind((self.ip, self.port))
        self.socket.listen(5)
        # Checked by loop() before each accept() call.
        self.terminate = False
        self.handler = message_handler

    def handle_message(self, clientsocket, addr):
        """Forward one accepted connection to the configured handler, if any."""
        if self.handler is None:
            return
        self.handler(clientsocket, addr)

    def loop(self):
        """Accept connections until ``self.terminate`` is set.

        Each connection is queued on the daemon as a job that runs
        :meth:`handle_message`.
        """
        while True:
            if self.terminate:
                break
            clientsocket, addr = self.socket.accept()
            job_name = 'Client[{0}]'.format(addr)
            self.daemon.add_job(
                self.handle_message,
                args=[clientsocket, addr],
                name=job_name
            )

    def start(self, back=True):
        """Run the accept loop, as a daemon job (default) or in the caller."""
        if not back:
            self.loop()
            return
        self.daemon.add_job(self.loop, name='SocketMainLoop')
|
|
|
|
# What a packet header should contain:
# 1. message id
# 2. packet id
# 3. total packets
# 4. size (Packet.tobytes stores the payload length of that packet here)
|
|
|
|
class PacketHeader():
    """Fixed-size 16-byte packet header.

    Fields, in wire order: message id, packet id (index within the
    message), number of packets in the message, and a size field.

    The header is four unsigned 32-bit little-endian integers. A
    standard-size format is used on purpose: the native ``'LLLL'`` format
    packs to 32 bytes where C ``unsigned long`` is 8 bytes, which
    disagrees with ``header_size`` and makes parsing fail. ``'<4L'`` is
    byte-compatible with what ``'LLLL'`` yields on little-endian platforms
    whose ``unsigned long`` is 4 bytes.
    """

    _STRUCT = struct.Struct('<4L')

    def __init__(self, mid=0, pid=0, pn=0, sz=0):
        """Create a header.

        Args:
            mid: message id.
            pid: packet id.
            pn: number of packets in the message.
            sz: size field.
        """
        self.msg_id = mid
        self.pkt_id = pid
        self.pkt_num = pn
        self.msg_sz = sz
        # Derived from the format so the two can never disagree (16 bytes).
        self.header_size = self._STRUCT.size

    def tobytes(self):
        """Return the 16-byte encoding of this header.

        Raises:
            struct.error: if a field does not fit in an unsigned 32-bit int.
        """
        return self._STRUCT.pack(self.msg_id, self.pkt_id, self.pkt_num, self.msg_sz)

    def frombytes(self, b):
        """Load all four fields from exactly 16 bytes and return ``self``.

        Raises:
            struct.error: if ``b`` is not exactly ``header_size`` bytes long.
        """
        self.msg_id, self.pkt_id, self.pkt_num, self.msg_sz = self._STRUCT.unpack(b)
        return self
|
|
|
|
class Packet():
    """A header plus a bytes payload, serialised as header bytes then payload."""

    def __init__(self, header=None, msg=b''):
        """Create a packet.

        Args:
            header: header object exposing ``header_size``, ``msg_sz``,
                ``tobytes()`` and ``frombytes()``. When omitted, a new
                ``PacketHeader`` is created for this packet.
            msg: payload bytes.
        """
        # A ``PacketHeader()`` default would be built once at definition time
        # and shared by every packet; frombytes()/tobytes() mutate the header,
        # so packets would overwrite each other's fields.
        self.header = PacketHeader() if header is None else header
        self.msg = msg
        self.header_size = self.header.header_size

    def frombytes(self, b):
        """Split ``b`` into header and payload, parse the header, return ``self``."""
        self.msg = b[self.header_size:]
        self.header.frombytes(b[:self.header_size])
        return self

    def tobytes(self):
        """Return header bytes followed by the payload.

        The header's ``msg_sz`` field is overwritten with the payload length
        before the header is encoded.
        """
        self.header.msg_sz = len(self.msg)
        return self.header.tobytes() + self.msg
|
|
|
|
class PacketFactory():
    """Splits byte messages into packets and reassembles them.

    Attributes:
        max_size: maximum size in bytes of one serialised packet,
            header included.
        id: message id given to the next message passed to to_packets().
        header_size: size in bytes of a serialised ``PacketHeader``.
        log: callable used for diagnostics; called like ``print``.
    """

    # Header fields are unsigned 32-bit integers, so message ids wrap here.
    _ID_MODULUS = 2 ** 32

    def __init__(self, max_size=8192, log=None):
        self.max_size = max_size
        self.id = 0
        self.header_size = PacketHeader().header_size
        self.log = print if log is None else log

    def to_packets(self, msg):
        """Split ``msg`` into packets carrying at most ``max_size - header_size`` bytes.

        Every packet of one message shares the same message id and carries
        its index and the total packet count. An empty ``msg`` yields an
        empty list. The message id advances on every call.

        Raises:
            ValueError: if ``max_size`` leaves no room for payload.
        """
        capacity = self.max_size - self.header_size
        if capacity <= 0:
            raise ValueError(
                'max_size ({0}) must be larger than the header size ({1})'.format(
                    self.max_size, self.header_size))
        # Integer ceiling division; float division loses precision on
        # very large messages.
        num_packets = (len(msg) + capacity - 1) // capacity
        packets = [
            Packet(PacketHeader(self.id, index, num_packets),
                   msg[index * capacity:(index + 1) * capacity])
            for index in range(num_packets)
        ]
        self.id = (self.id + 1) % self._ID_MODULUS
        return packets

    def from_packets(self, packets):
        """Reassemble a message from its packets.

        Validates that each header's packet count equals ``len(packets)``,
        that all packets share the first packet's message id, and that
        packet ids are 0, 1, 2, ... in list order.

        Returns:
            The concatenated payload, or ``None`` if ``packets`` is empty
            or any check fails (the reason is reported through ``log``).
        """
        num_packet = len(packets)
        self.log('Packet Number:', num_packet)
        if num_packet == 0:
            return None

        msg_id = packets[0].header.msg_id
        parts = []
        for expected_id, packet in enumerate(packets):
            header = packet.header
            if num_packet != header.pkt_num:
                self.log('Uncorrect Package Number')
                self.log('get {0} while it should be {1}'.format(header.pkt_num, num_packet))
                return None
            if msg_id != header.msg_id:
                self.log('Uncorrect Message id')
                # Report the offending message id (the packet count was
                # printed here by mistake).
                self.log('get {0} while it should be {1}'.format(header.msg_id, msg_id))
                return None
            if expected_id != header.pkt_id:
                self.log('Uncorrect pkt id')
                self.log('get {0} while it should be {1}'.format(header.pkt_id, expected_id))
                return None
            parts.append(packet.msg)
        return b''.join(parts)
|
|
|
|
class AuthMessage(SocketMessage):
    """Authentication message carrying a token and a status code.

    Status codes:
        0: auth client
        1: auth success
        2: auth failed.
    """

    def __init__(self, token='None'):
        super().__init__()
        self.type = 'auth'
        self.token = token
        self.stat = 0

    def encode(self):
        """Store the token and status in ``self.data`` for serialisation."""
        self.data = {'token': self.token, 'status': self.stat}

    def decode(self):
        """Read the token and status back out of ``self.data``."""
        payload = self.data
        self.token = payload['token']
        self.stat = payload['status']
|
|
|
|
class SocketConnection():
    """Message-oriented connection layered on a TCP socket.

    Outgoing messages are split into packets by a ``PacketFactory`` and
    written to the socket. A receiver job (``recv_bare``, started through
    ``daemon.add_job``) reads the socket, reassembles packets into messages
    and puts them on ``self.messages``, from which :meth:`recv` takes them.

    Log levels:
        0: debug
        1: message or info
        2: warning
        3: error
    """

    def __init__(self, daemon, log_prefix=''):
        """Create an unconnected connection.

        Args:
            daemon: object providing ``add_job(func, name=...)``; used to
                run the receiver.
            log_prefix: text printed in front of every log line.
        """
        self.sock = None
        self.daemon = daemon
        self.messages = queue.Queue()
        self.terminated = False
        self.loglevel = 2
        self.log_prefix = log_prefix
        self.factory = PacketFactory(log=self.log)
        self.header_size = PacketHeader().header_size

    def log(self, *args, level=0, end='\n'):
        """Print ``args`` after the prefix if ``level`` is at least ``self.loglevel``."""
        if level < self.loglevel:
            return
        # One print call keeps a line together when several jobs log at once.
        print(self.log_prefix, *(str(arg) for arg in args), end=' ' + end)

    def start(self):
        """Queue the receiver loop on the daemon and remember its job id."""
        self.jid = self.daemon.add_job(self.recv_bare, name='connection')

    def SetSock(self, sock):
        """Adopt an already-connected socket and start receiving from it."""
        self.sock = sock
        self.start()

    def connect(self, host, port):
        """Connect to ``(host, port)`` and start the receiver.

        Does nothing except log a warning when a socket is already set, so
        a second receiver is never started on the same socket.

        Raises:
            OSError: if the connection attempt fails; ``self.sock`` stays
                ``None`` so the call can be retried.
        """
        if self.sock is not None:
            self.log('Socket has already established, ingoring connect.', level=2)
            return
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            sock.connect((host, port))
        except OSError:
            sock.close()
            raise
        self.sock = sock
        self.log('Sock Connected', level=1)
        self.start()

    def _read_auth_message(self):
        """Wait for one message and parse it as an AuthMessage, or return None.

        ``None`` is returned when the connection terminated before a
        message arrived or when the message is not a well-formed auth
        message.
        """
        raw = self.recv()
        if raw is None:
            return None
        auth_msg = AuthMessage()
        try:
            auth_msg.frombytes(raw)
        except (zlib.error, ValueError, KeyError, TypeError):
            # Peer input is untrusted; a malformed message must not crash
            # the calling job.
            self.log('Malformed authentication message.', level=2)
            return None
        return auth_msg

    @staticmethod
    def _tokens_match(expected, received):
        """Compare two tokens, in constant time when both are strings."""
        if isinstance(expected, str) and isinstance(received, str):
            # 'surrogatepass' keeps encoding from failing on lone surrogates,
            # which JSON escapes can produce.
            return hmac.compare_digest(
                expected.encode('utf-8', 'surrogatepass'),
                received.encode('utf-8', 'surrogatepass'))
        return expected == received

    def auth(self, token):
        """Send ``token`` to the peer and wait for its verdict.

        The connection is closed unless the peer answers with status 1.
        Always returns ``None``; check ``self.terminated`` for the outcome.
        """
        if self.sock is None:
            self.log('Socket not established, unable to auth.', level=3)
            return None
        self.log('Sending Authentication Message.')
        self.send(AuthMessage(token).tobytes())
        self.log('Waiting Authentication Status.')
        auth_msg = self._read_auth_message()
        if auth_msg is not None:
            self.log('Status Recived! auth_status =', auth_msg.__dict__)
        if auth_msg is not None and auth_msg.stat == 1:
            self.log('Auth Success!', level=1)
        else:
            self.log('Auth Failed!', level=3)
            self.close()

    def WaitAuth(self, token):
        """Wait for the peer's auth message and answer it.

        Replies with status 1 when the received token equals ``token``;
        otherwise replies with status 2 and closes the connection.
        """
        request = self._read_auth_message()
        if request is not None and self._tokens_match(token, request.token):
            self.log('Authentication Success!', level=1)
            reply = AuthMessage('Welcome')
            reply.stat = 1
            self.send(reply.tobytes())
        else:
            self.log('Authentication Failed!', level=1)
            reply = AuthMessage('InvalidAuth')
            reply.stat = 2
            self.send(reply.tobytes())
            self.close()

    def send(self, msg):
        """Split ``msg`` into packets and write them all to the socket.

        Returns:
            ``False`` if the connection is already terminated, else ``True``.
        """
        if self.terminated:
            return False
        packets = self.factory.to_packets(msg)
        self.log('spliting message to {0} packets'.format(len(packets)))
        for packet in packets:
            # sendall: a plain send() may write only part of the buffer.
            self.sock.sendall(packet.tobytes())
        return True

    def commit_message(self, packets):
        """Reassemble ``packets`` and queue the message if it is valid."""
        self.log('Generating final packet...')
        full_msg = self.factory.from_packets(packets)
        if full_msg is not None:
            self.log('Valid package!')
            self.messages.put(full_msg)
        else:
            self.log('Invalid package')

    def recv_bare(self):
        """Receiver loop: read the socket and queue complete messages.

        TCP delivers a byte stream, so one ``recv()`` may return part of a
        packet or several packets at once. Incoming bytes are buffered and
        packets are cut out using the payload length stored in each header.
        The loop ends when the connection is terminated.
        """
        buffer = bytearray()
        msg_id = None
        packets = []
        max_payload = self.factory.max_size - self.header_size
        while not self.terminated:
            try:
                chunk = self.sock.recv(8192)
            except OSError:
                # Includes reset/abort and the socket being closed locally.
                self.log('Connection Stopped')
                self.close()
                continue
            if not chunk:
                # Empty read: the peer closed the connection.
                self.close()
                continue
            buffer += chunk

            while len(buffer) >= self.header_size:
                header = PacketHeader().frombytes(bytes(buffer[:self.header_size]))
                if header.msg_sz > max_payload:
                    # Larger than any packet our factory would build; bail out
                    # rather than buffer without bound.
                    self.log('Oversized packet announced, closing connection.', level=3)
                    self.close()
                    break
                end = self.header_size + header.msg_sz
                if len(buffer) < end:
                    break  # payload not fully received yet
                pkt = Packet(header, bytes(buffer[self.header_size:end]))
                del buffer[:end]

                if msg_id != pkt.header.msg_id:
                    # First packet of a new message: discard any incomplete
                    # predecessor.
                    msg_id = pkt.header.msg_id
                    packets = []
                packets.append(pkt)

                if pkt.header.pkt_id == pkt.header.pkt_num - 1:
                    self.log('Finished')
                    self.commit_message(packets)
                    msg_id = None
                    packets = []

    def recv(self):
        """Block until a message is available and return it.

        Returns:
            The next message, or ``None`` if the connection terminated
            before one arrived.
        """
        self.log('Getting messages')
        msg = None
        while msg is None and not self.terminated:
            try:
                # Short timeout so termination is noticed promptly.
                msg = self.messages.get(timeout=0.1)
            except queue.Empty:
                continue
        self.log('Message Get Finished')
        return msg

    def close(self):
        """Mark the connection terminated and close the socket, if any."""
        self.terminated = True
        if self.sock is not None:
            self.sock.close()
|
|
|
|
|
|
class HttpServer(SocketServer):
    """Minimal HTTP server on port 80 that echoes the request back as HTML."""

    def __init__(self):
        super().__init__(port=80)
        # HTTP header lines are terminated by CRLF; {0} is the body length
        # in bytes.
        self.header = (
            'HTTP/1.1 200 OK\r\n'
            'Server: NaiveHttpServer\r\n'
            'Connection: close\r\n'
            'Content-Length: {0}\r\n'
            'Content-Type: text/html\r\n'
            '\r\n'
        )

    def handle_message(self, clientsocket, addr):
        """Answer one connection with a page showing the peer address and request.

        Reads up to 8192 bytes, sends a single response and always closes
        ``clientsocket``, even if reading or sending fails.
        """
        try:
            request = clientsocket.recv(8192)
            # The request is untrusted: tolerate invalid UTF-8 and escape it
            # before embedding it in HTML.
            text = html.escape(request.decode('utf-8', errors='replace'))
            text = text.replace('\n', '<br>')
            info = ''.join([
                '<html><body>',
                '<h1> Hello, World </h1>',
                '<h2> Your IP Address & Port</h2>\n',
                '<p>{0}</p>'.format(html.escape(str(addr))),
                '<h2> Your Request</h2>\n',
                '<p>{0}</p>\n'.format(text),
                '</body></html>\n',
            ])
            body = info.encode('utf-8')
            # Content-Length counts bytes, not characters.
            header = self.header.format(len(body))
            clientsocket.sendall(header.encode('utf-8') + body)
        finally:
            clientsocket.close()
|
|
|
|
class BridgeConnection():
    """Relays bytes in both directions between a client and a server socket.

    Two jobs are queued on a private ``parallel.ParallelHost``: one copies
    client -> server, the other server -> client. :meth:`run` blocks until
    either direction ends, then stops the host and closes both sockets.
    """

    def __init__(self, client, server, num_threads=8):
        """Create a bridge.

        Args:
            client: connected socket of the incoming side.
            server: connected socket of the destination side.
            num_threads: passed to ``parallel.ParallelHost``.
        """
        self.client = client
        self.server = server
        self.daemon = parallel.ParallelHost(num_threads)
        self.exit = False
        self.client_terminated = False
        self.server_terminated = False

    def is_terminated(self):
        """Return True once either relay direction has finished."""
        return self.client_terminated or self.server_terminated

    def client_recv_handler(self, msg):
        """Write a chunk received from the client to the server."""
        # sendall: a plain send() may write only part of the chunk.
        self.server.sendall(msg)

    def server_recv_handler(self, msg):
        """Write a chunk received from the server to the client."""
        self.client.sendall(msg)

    def stop(self):
        """Stop the worker host and close both sockets."""
        self.exit = True
        self.daemon.stop('kill')
        for sock in (self.server, self.client):
            try:
                sock.close()
            except OSError:
                pass  # already closed or reset; nothing left to release
        print('connection terminated successfully.')

    def _pump(self, source, forward):
        """Copy chunks from ``source`` through ``forward`` until EOF, error or exit.

        Chunks are forwarded inline, in the order they were read. Handing
        each chunk to the worker pool as its own job (presumably executed
        concurrently -- the pool lives in another module) would not
        guarantee that bytes leave in the order they arrived.
        """
        while not self.exit:
            try:
                chunk = source.recv(8192)
                if not chunk:
                    break  # peer closed its side
                forward(chunk)
            except OSError:
                break  # reset/aborted/closed: treat like end of stream
        self.exit = True

    def client_recv(self):
        """Relay client -> server until the stream ends."""
        self._pump(self.client, self.client_recv_handler)
        self.client_terminated = True
        print('Client Recv terminated.')

    def server_recv(self):
        """Relay server -> client until the stream ends."""
        self._pump(self.server, self.server_recv_handler)
        self.server_terminated = True
        print('Server Recv terminated.')

    def run(self):
        """Start both relay jobs, wait until one side ends, then clean up."""
        self.daemon.add_job(self.server_recv)
        self.daemon.add_job(self.client_recv)
        while not self.is_terminated():
            time.sleep(1)
        self.stop()
|
|
|
|
def SendMessage(host, port, msg):
    """Open a TCP connection to ``(host, port)``, send ``msg`` and close.

    Args:
        host: destination host name or address.
        port: destination TCP port.
        msg: bytes to send.

    Raises:
        OSError: if connecting or sending fails; the socket is still closed.
    """
    # The context manager closes the socket on every path.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.connect((host, port))
        # sendall: a plain send() may write only part of a large message.
        s.sendall(msg)
        print('Message Sent!')
|
|
|
|
|
|
|
|
def portforward(dst_ip, dst_port, listen_ip, listen_port):
    """Listen on ``(listen_ip, listen_port)`` and relay each connection to the destination.

    For every accepted client a new connection to ``(dst_ip, dst_port)``
    is opened and the two sockets are joined with a ``BridgeConnection``.

    Returns:
        The started ``SocketServer``, so the caller keeps a handle on it.
    """
    def port_forward_handler(clientsocket, addr):
        print('Handling message from clinet', addr)
        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            s.connect((dst_ip, dst_port))
        except OSError as err:
            # Destination unreachable: release both sockets instead of
            # leaving the client connection open.
            print('Connect to dst server failed:', err)
            s.close()
            clientsocket.close()
            return
        print('Connect to dst server success!')
        try:
            connection = BridgeConnection(clientsocket, s)
            connection.run()
        finally:
            # Closing an already-closed socket is harmless.
            clientsocket.close()
            s.close()
        print('handler exited')

    server = SocketServer(message_handler=port_forward_handler, ip=listen_ip, port=listen_port)
    server.start()
    return server
|
|
|
|
# Disabled example: a port-forwarding server built on this socket server.
#
# if __name__ == '__main__':
#     portforward('192.168.233.101', 22, '0.0.0.0', 30001)
#     portforward('192.168.233.102', 22, '0.0.0.0', 30002)
#
#     con = console.console('PortForward')
#     con.interactive()

if __name__ == '__main__':
    import time

    # Demo: serve the echo page in the background, then keep the main
    # thread alive for 10000 seconds.
    server = HttpServer()
    server.start()
    time.sleep(10000)