ImshowOverSocket/imshow/network.py
2019-12-28 16:26:19 +08:00

447 lines
11 KiB
Python

import socket
import zlib
import json
import struct
import queue
from imshow import parallel
class SocketMessage():
    """Base class for messages exchanged as zlib-compressed JSON.

    The wire form is ``zlib.compress(json.dumps({'type': ..., 'msg': ...}))``
    encoded as UTF-8. Subclasses override :meth:`encode` to populate
    ``self.data`` from their own attributes before serialization, and
    :meth:`decode` to read their attributes back out of ``self.data``.
    """

    def __init__(self, msg=None):
        """Create a message.

        Args:
            msg: JSON-serializable payload. When omitted, a new empty dict
                is created for this instance.
        """
        # A literal `{}` default would be created once and shared by every
        # instance, so one message's payload could leak into another.
        self.data = {} if msg is None else msg
        self.type = 'empty'

    def encode(self):
        """Hook called by :meth:`tobytes`; the base class does nothing."""
        pass

    def decode(self):
        """Hook called by :meth:`frombytes`; the base class does nothing."""
        pass

    def tobytes(self):
        """Serialize the message and return the compressed bytes."""
        self.encode()
        envelope = {'type': self.type, 'msg': self.data}
        return zlib.compress(json.dumps(envelope).encode('utf-8'))

    def frombytes(self, byte):
        """Load ``self.data`` from bytes produced by :meth:`tobytes`.

        Only the ``'msg'`` field is read; the ``'type'`` field of the
        envelope is ignored. Raises ``zlib.error``, ``ValueError`` or
        ``KeyError`` if ``byte`` is not a valid serialized message.
        """
        envelope = json.loads(zlib.decompress(byte).decode('utf-8'))
        self.data = envelope['msg']
        self.decode()
class SocketServer():
    """Listening TCP socket that dispatches each accepted client as a job.

    The listening socket is created, bound and put into listening mode in
    the constructor. Jobs are submitted through ``parallel.daemon``, the
    job runner exposed by the project's ``parallel`` module.
    """

    def __init__(self, ip='0.0.0.0', port=12345, message_handler=None):
        """Bind to ``(ip, port)`` and remember the optional handler.

        Args:
            ip: Local address to bind to.
            port: Local TCP port to bind to.
            message_handler: Optional callable invoked as
                ``message_handler(clientsocket, addr)`` per connection.
        """
        self.daemon = parallel.daemon
        self.ip = ip
        self.port = port
        self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        address = (self.ip, self.port)
        self.socket.bind(address)
        self.socket.listen(5)
        self.terminate = False
        self.handler = message_handler

    def handle_message(self, clientsocket, addr):
        """Forward one accepted connection to the configured handler, if any."""
        if self.handler is None:
            return
        self.handler(clientsocket, addr)

    def loop(self):
        """Accept connections until ``self.terminate`` is set.

        Each accepted connection is handed to ``handle_message`` as a new
        daemon job named after the client address.
        """
        while True:
            if self.terminate:
                break
            accepted = self.socket.accept()
            clientsocket, addr = accepted
            job_name = 'Client[{0}]'.format(addr)
            self.daemon.add_job(
                self.handle_message,
                args=[clientsocket, addr],
                name=job_name
            )

    def start(self, back=True):
        """Run the accept loop.

        Args:
            back: When true, submit the loop as a daemon job named
                ``'SocketMainLoop'``; otherwise run it in the caller.
        """
        if not back:
            self.loop()
            return
        self.daemon.add_job(self.loop, name='SocketMainLoop')
# What a packet header contains:
# 1. message id
# 2. packet id (index of this packet within its message)
# 3. total number of packets in the message
# 4. payload size of this packet, in bytes
class PacketHeader():
    """Fixed 16-byte header placed in front of every packet payload.

    Fields, in wire order: message id, packet id, number of packets in
    the message, and payload size. Each is an unsigned 32-bit integer.
    """

    # Explicit little-endian, standard-size layout. The native 'LLLL'
    # format packs to 32 bytes on platforms where C `unsigned long` is
    # 8 bytes, which contradicts the 16-byte header_size used for slicing.
    _FORMAT = '<LLLL'

    def __init__(self, mid=0, pid=0, pn=0, sz=0):
        """Store the four header fields.

        Args:
            mid: Message id.
            pid: Index of this packet within the message.
            pn: Total number of packets in the message.
            sz: Payload size in bytes.
        """
        self.msg_id = mid
        self.pkt_id = pid
        self.pkt_num = pn
        self.msg_sz = sz
        self.header_size = struct.calcsize(self._FORMAT)  # always 16

    def tobytes(self):
        """Return the header packed into exactly ``header_size`` bytes.

        Raises ``struct.error`` if a field does not fit in 32 bits.
        """
        return struct.pack(
            self._FORMAT, self.msg_id, self.pkt_id, self.pkt_num, self.msg_sz)

    def frombytes(self, b):
        """Load the fields from ``b`` and return ``self``.

        Raises ``struct.error`` if ``b`` is not exactly ``header_size``
        bytes long.
        """
        self.msg_id, self.pkt_id, self.pkt_num, self.msg_sz = struct.unpack(
            self._FORMAT, b)
        return self
class Packet():
    """A header plus a chunk of payload bytes."""

    def __init__(self, header=None, msg=b''):
        """Create a packet.

        Args:
            header: Header object for this packet. When omitted, a new
                ``PacketHeader`` is created for this instance.
            msg: Payload bytes.
        """
        # A `PacketHeader()` default argument would be built once and shared
        # by every Packet; frombytes() mutates the header, so packets would
        # overwrite each other's fields.
        self.header = PacketHeader() if header is None else header
        self.msg = msg
        self.header_size = self.header.header_size

    def frombytes(self, b):
        """Split ``b`` into header and payload, then return ``self``.

        The first ``header_size`` bytes are parsed by the header; the rest
        is stored as the payload.
        """
        self.header.frombytes(b[:self.header_size])
        self.msg = b[self.header_size:]
        return self

    def tobytes(self):
        """Return header bytes followed by the payload.

        The header's ``msg_sz`` field is updated to the payload length
        before the header is serialized.
        """
        self.header.msg_sz = len(self.msg)
        return self.header.tobytes() + self.msg
class PacketFactory():
    """Splits byte messages into packets and reassembles them."""

    def __init__(self, max_size=8192, log=None):
        """Create a factory.

        Args:
            max_size: Maximum size in bytes of one packet, header included.
            log: Callable used for diagnostics; defaults to ``print``.
        """
        self.max_size = max_size
        self.id = 0
        self.header_size = PacketHeader().header_size
        self.log = print if log is None else log

    def to_packets(self, msg):
        """Split ``msg`` into a list of packets sharing one message id.

        An empty ``msg`` yields an empty list. The message id is advanced
        after every call.

        Raises:
            ValueError: If ``max_size`` leaves no room for payload bytes.
        """
        capacity = self.max_size - self.header_size
        if capacity <= 0:
            raise ValueError(
                'max_size ({0}) must be larger than the header size ({1})'.format(
                    self.max_size, self.header_size))
        # Integer ceiling division; float division loses precision for very
        # large messages.
        num_packets = (len(msg) + capacity - 1) // capacity
        packets = [
            Packet(
                PacketHeader(self.id, index, num_packets),
                msg[index * capacity:(index + 1) * capacity],
            )
            for index in range(num_packets)
        ]
        # Keep the id inside the unsigned 32-bit range of the header field.
        self.id = (self.id + 1) & 0xFFFFFFFF
        return packets

    def from_packets(self, packets):
        """Reassemble the payload of an ordered, complete packet list.

        Returns:
            The concatenated payload bytes, or ``None`` when ``packets`` is
            empty, the packet count does not match the headers, the packets
            belong to different messages, or they are out of order.
        """
        num_packet = len(packets)
        self.log('Packet Number:', num_packet)
        if num_packet == 0:
            return None
        msg_id = packets[0].header.msg_id
        chunks = []
        for packet_id, packet in enumerate(packets):
            header = packet.header
            if num_packet != header.pkt_num:
                self.log('Uncorrect Package Number')
                self.log('get {0} while it should be {1}'.format(header.pkt_num, num_packet))
                return None
            if msg_id != header.msg_id:
                self.log('Uncorrect Message id')
                # Report the offending message id (previously pkt_num was
                # printed here by mistake).
                self.log('get {0} while it should be {1}'.format(header.msg_id, msg_id))
                return None
            if packet_id != header.pkt_id:
                self.log('Uncorrect pkt id')
                self.log('get {0} while it should be {1}'.format(header.pkt_id, packet_id))
                return None
            chunks.append(packet.msg)
        # Single join instead of repeated bytes concatenation.
        return b''.join(chunks)
class AuthMessage(SocketMessage):
    """Message of type ``'auth'`` carrying a token and a status code.

    Status values:
        0: authentication request from a client
        1: authentication succeeded
        2: authentication failed
    """

    def __init__(self, token='None'):
        """Create an authentication request for ``token`` with status 0."""
        super().__init__()
        self.type = 'auth'
        self.token = token
        self.stat = 0

    def encode(self):
        """Rebuild ``self.data`` from the token and status attributes."""
        self.data = {'token': self.token, 'status': self.stat}

    def decode(self):
        """Read the token and status attributes back out of ``self.data``."""
        payload = self.data
        self.token = payload['token']
        self.stat = payload['status']
class SocketConnection():
    """Message-oriented connection built on top of a TCP socket.

    Outgoing messages are split into packets by a ``PacketFactory``. A
    background job (``recv_bare``) reads packets from the socket,
    reassembles complete messages and queues them for :meth:`recv`.

    Log levels: 0 debug, 1 message/info, 2 warning, 3 error. Only entries
    whose level is at least ``self.loglevel`` are printed.
    """

    def __init__(self, daemon, log_prefix=''):
        """Create an unconnected connection.

        Args:
            daemon: Job runner providing ``add_job``; used to run the
                receive loop.
            log_prefix: Text printed in front of every log line.
        """
        self.sock = None
        self.daemon = daemon
        self.messages = queue.Queue()
        self.terminated = False
        self.factory = PacketFactory(log=self.log)
        self.header_size = PacketHeader().header_size
        self.loglevel = 2
        self.log_prefix = log_prefix

    def log(self, *args, level=0, end='\n'):
        """Print ``args`` after the prefix when ``level >= self.loglevel``."""
        if level < self.loglevel:
            return
        # One print call keeps a log line in one piece; the output is the
        # prefix and every argument, each followed by a space, then `end`.
        print(self.log_prefix, *args, end=' ' + end)

    def start(self):
        """Submit the receive loop as a job and remember its id."""
        self.jid = self.daemon.add_job(self.recv_bare, name='connection')

    def SetSock(self, sock):
        """Adopt an already-connected socket and start receiving."""
        self.sock = sock
        self.start()

    def connect(self, host, port):
        """Connect to ``(host, port)`` unless a socket is already set, then
        start the receive loop."""
        if self.sock is None:
            self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            self.sock.connect((host, port))
            self.log('Sock Connected', level=1)
        else:
            self.log('Socket has already established, ingoring connect.', level=2)
        self.start()

    def auth(self, token):
        """Send ``token`` to the peer and wait for its verdict.

        The connection is closed when the peer reports a failure. Always
        returns ``None``.
        """
        if self.sock is None:
            self.log('Socket not established, unable to auth.', level=3)
            return None
        auth_msg = AuthMessage(token)
        self.log('Sending Authentication Message.')
        self.send(auth_msg.tobytes())
        self.log('Waiting Authentication Status.')
        reply = self.recv()
        if reply is None:
            # recv() returns None once the connection has terminated.
            self.log('Connection closed before authentication status arrived.', level=3)
            return None
        auth_msg.frombytes(reply)
        self.log('Status Recived! auth_status =', auth_msg.__dict__)
        if auth_msg.stat == 1:
            self.log('Auth Success!', level=1)
        else:
            self.log('Auth Failed!', level=3)
            self.close()

    def WaitAuth(self, token):
        """Wait for an auth request and answer it.

        Replies with status 1 when the received token equals ``token``;
        otherwise replies with status 2 and closes the connection.
        """
        request = self.recv()
        if request is None:
            # The connection terminated before the peer sent a request.
            self.log('Connection closed before authentication request arrived.', level=3)
            return
        auth_msg = AuthMessage()
        auth_msg.frombytes(request)
        if auth_msg.token != token:
            self.log('Authentication Failed!', level=1)
            auth_msg = AuthMessage('InvalidAuth')
            auth_msg.stat = 2
            self.send(auth_msg.tobytes())
            self.close()
        else:
            self.log('Authentication Success!', level=1)
            auth_msg = AuthMessage('Welcome')
            auth_msg.stat = 1
            self.send(auth_msg.tobytes())

    def send(self, msg):
        """Send ``msg`` (bytes) as one or more packets.

        Returns:
            ``False`` when the connection has been terminated, otherwise
            ``True`` after all packets were written.
        """
        if self.terminated:
            return False
        packets = self.factory.to_packets(msg)
        self.log('spliting message to {0} packets'.format(len(packets)))
        for packet in packets:
            # sendall: socket.send() may transmit only part of the buffer.
            self.sock.sendall(packet.tobytes())
        return True

    def commit_message(self, packets):
        """Reassemble ``packets`` and queue the message if it is valid."""
        self.log('Generating final packet...')
        full_msg = self.factory.from_packets(packets)
        if full_msg is not None:
            self.log('Valid package!')
            self.messages.put(full_msg)
        else:
            self.log('Invalid package')

    def _recv_exact(self, size):
        """Read exactly ``size`` bytes from the socket.

        Returns the bytes, or ``None`` if the peer closed the connection
        before ``size`` bytes arrived.
        """
        chunks = []
        remaining = size
        while remaining > 0:
            chunk = self.sock.recv(remaining)
            if not chunk:
                return None
            chunks.append(chunk)
            remaining -= len(chunk)
        return b''.join(chunks)

    def _recv_packet(self):
        """Read one complete packet, or return ``None`` on end of stream."""
        raw_header = self._recv_exact(self.header_size)
        if raw_header is None:
            return None
        header = PacketHeader().frombytes(raw_header)
        # Packet.tobytes() stores the payload length in msg_sz, so the
        # header tells us exactly how many payload bytes follow.
        payload = self._recv_exact(header.msg_sz)
        if payload is None:
            return None
        return Packet(header, payload)

    def recv_bare(self):
        """Receive loop: read packets and commit every completed message.

        TCP is a byte stream, so a single ``recv`` call may return part of
        a packet or several packets at once; packets are therefore framed
        by reading the fixed-size header and then the announced payload.
        The loop ends when the connection is closed or fails.
        """
        msg_id = None
        packets = []
        while not self.terminated:
            try:
                pkt = self._recv_packet()
            except OSError:
                # Covers reset/aborted connections and a socket closed by
                # another thread while this one was blocked in recv().
                self.log('Connection Stopped')
                self.close()
                continue
            if pkt is None:
                self.close()
                continue
            if msg_id != pkt.header.msg_id:
                # First packet, or a new message began before the previous
                # one completed: drop the stale packets and start over.
                msg_id = pkt.header.msg_id
                packets = []
            packets.append(pkt)
            if pkt.header.pkt_id == pkt.header.pkt_num - 1:
                self.log('Finished')
                self.commit_message(packets)
                msg_id = None
                packets = []

    def recv(self):
        """Block until a complete message is available and return it.

        Returns ``None`` if the connection terminates first.
        """
        self.log('Getting messages')
        msg = None
        while msg is None and not self.terminated:
            try:
                # Short timeout so termination is noticed promptly.
                msg = self.messages.get(timeout=0.1)
            except queue.Empty:
                msg = None
        self.log('Message Get Finished')
        return msg

    def close(self):
        """Mark the connection terminated and close the socket, if any."""
        self.terminated = True
        if self.sock is not None:
            self.sock.close()
class HttpServer(SocketServer):
    """Minimal HTTP server that answers every request with an HTML page
    showing the client's address and the request it sent."""

    def __init__(self, ip='0.0.0.0', port=80):
        """Bind the server.

        Args:
            ip: Local address to bind to.
            port: Local TCP port to bind to (default 80).
        """
        super(HttpServer, self).__init__(ip=ip, port=port)
        # NOTE(review): HTTP specifies CRLF line endings; the bare-LF
        # template is kept unchanged here — confirm clients in use accept it.
        self.header = 'HTTP/1.1 200 OK\nServer: NaiveHttpServer\nConnection: close\nContent-Length: {0}\nContent-Type: text/html\n\n'

    def handle_message(self, clientsocket, addr):
        """Read one request from ``clientsocket`` and reply with the page.

        The client socket is always closed before returning, even when
        receiving or sending fails.
        """
        import html

        try:
            request = clientsocket.recv(8192)
            # The request is untrusted: tolerate invalid UTF-8 and escape
            # it before embedding it in the page.
            text = html.escape(request.decode('utf-8', errors='replace'))
            text = text.replace('\n', '<br>')
            info = ''.join([
                '<html><body>',
                '<h1> Hello, World </h1>',
                '<h2> Your IP Address & Port</h2>\n',
                '<p>{0}</p>'.format(html.escape(str(addr))),
                '<h2> Your Request</h2>\n',
                '<p>{0}</p>\n'.format(text),
                '</body></html>\n',
            ])
            body = info.encode('utf-8')
            # Content-Length counts bytes, not characters.
            header = self.header.format(len(body))
            clientsocket.sendall(header.encode('utf-8') + body)
        finally:
            clientsocket.close()
class BridgeConnection():
    """Relays bytes in both directions between a client and a server socket."""

    def __init__(self, client, server, num_threads=8):
        """Create a bridge.

        Args:
            client: Connected socket of the incoming client.
            server: Connected socket of the destination server.
            num_threads: Size passed to ``parallel.ParallelHost``, which
                runs the two receive loops.
        """
        self.client = client
        self.server = server
        self.daemon = parallel.ParallelHost(num_threads)
        self.exit = False
        self.client_terminated = False
        self.server_terminated = False

    def is_terminated(self):
        """Return ``True`` once either direction of the bridge has ended.

        Waiting for the client side only would leave the bridge running
        forever when the server closes the connection first.
        """
        return self.client_terminated or self.server_terminated

    def client_recv_handler(self, msg):
        """Forward bytes received from the client to the server."""
        self.server.sendall(msg)

    def stop(self):
        """Stop the job runner and close both sockets."""
        self.daemon.stop('kill')
        self.server.close()
        self.client.close()
        print('connection terminated successfully.')

    def client_recv(self):
        """Relay client -> server until the client closes or an error occurs."""
        while not self.exit:
            try:
                msg = self.client.recv(8192)
            except OSError:
                break
            if not msg:
                break
            try:
                # Forward inline: dispatching each chunk as a separate job
                # lets chunks overtake each other and corrupt the stream if
                # the job runner executes jobs concurrently.
                self.client_recv_handler(msg)
            except OSError:
                break
        self.exit = True
        self.client_terminated = True
        print('Client Recv terminated.')

    def server_recv_handler(self, msg):
        """Forward bytes received from the server to the client."""
        self.client.sendall(msg)

    def server_recv(self):
        """Relay server -> client until the server closes or an error occurs."""
        while not self.exit:
            try:
                msg = self.server.recv(8192)
            except OSError:
                break
            if not msg:
                break
            try:
                # Forwarded inline for the same ordering reason as above.
                self.server_recv_handler(msg)
            except OSError:
                break
        self.exit = True
        self.server_terminated = True
        print('Server Recv terminated.')

    def run(self):
        """Start both relay loops and block until the bridge terminates."""
        # Imported here because the module only imports `time` under its
        # `__main__` guard, so the name is missing when used as a library.
        import time

        self.daemon.add_job(self.server_recv)
        self.daemon.add_job(self.client_recv)
        while not self.is_terminated():
            time.sleep(1)
        self.stop()
def SendMessage(host, port, msg):
    """Open a TCP connection to ``(host, port)``, send ``msg`` and close.

    Args:
        host: Destination host name or address.
        port: Destination TCP port.
        msg: Bytes to send.

    Raises:
        OSError: If connecting or sending fails; the socket is closed
            in that case as well.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.connect((host, port))
        # sendall: socket.send() may transmit only part of the buffer.
        s.sendall(msg)
        print('Message Sent!')
def portforward(dst_ip, dst_port, listen_ip, listen_port):
    """Forward TCP connections from a local port to a destination.

    A ``SocketServer`` is started on ``(listen_ip, listen_port)``. For each
    accepted client, a new connection to ``(dst_ip, dst_port)`` is opened
    and the two sockets are joined with a ``BridgeConnection``.

    Returns:
        The started ``SocketServer`` instance.
    """
    def port_forward_handler(clientsocket, addr):
        """Bridge one accepted client to the destination server."""
        print('Handling message from clinet', addr)
        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            s.connect((dst_ip, dst_port))
            print('Connect to dst server success!')
            connection = BridgeConnection(clientsocket, s)
            connection.run()
        finally:
            # Release both sockets even when the destination is unreachable
            # or the bridge fails; closing an already closed socket is a
            # no-op.
            s.close()
            clientsocket.close()
        print('handler exited')

    server = SocketServer(message_handler=port_forward_handler, ip=listen_ip, port=listen_port)
    server.start()
    return server
# if __name__ == '__main__':
# # Let's try a port forwarding server using this socket server.
# portforward('192.168.233.101', 22, '0.0.0.0', 30001)
# portforward('192.168.233.102', 22, '0.0.0.0', 30002)
# con = console.console('PortForward')
# con.interactive()
if __name__ == '__main__':
    # Script entry point: start the demo HTTP server, then sleep so the
    # script does not exit right after start() returns.
    import time

    server = HttpServer()
    server.start()
    time.sleep(10000)