Socket_Programming/app_secure_client.py

243 lines
8.6 KiB
Python

# -*- coding: utf-8 -*-
"""
Created on Sun Mar 22 10:46:20 2020
@author: cpan
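Simple secure client: exchanges RSA keys and a session key with the server,
sends a hashed passphrase, then sends one AES-encrypted request and prints the
decrypted response. No socket timeout is configured (main() calls
sock.settimeout(None)), so socket operations block without a time limit.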
"""
import sys
import socket
import json
import io
import struct
import hashlib
from Crypto import Random
from Crypto.PublicKey import RSA
from Crypto.Cipher import AES
# Size of the fixed protoheader that starts every frame: a big-endian unsigned
# short (struct format ">H") holding the byte length of the JSON header.
HDRLEN = struct.calcsize(">H")
# Passphrase sent (hashed, RSA-encrypted) to the server during the handshake.
# NOTE(review): hard-coded credential; consider reading it from the
# environment or prompting for it instead.
PASSWORD = "adminadmin"
class Message:
    """Client side of a length-prefixed, optionally AES-encrypted frame protocol.

    Every frame (see ``_create_message``) is laid out as::

        HDRLEN-byte big-endian length of the JSON header (struct ">H")
        UTF-8 JSON header: byteorder, content-type, content-encoding, content-length
        content-length bytes of payload

    Until ``handshake`` is True payloads travel in the clear; afterwards each
    payload is a random IV followed by AES-CFB ciphertext keyed from the
    session key received in ``_recv_handshake``.
    """

    # Length of a SHA3-512 hex digest: 64 bytes -> 128 hex characters.
    _DIGEST_HEX_LEN = 128
    # Length of a ciphertext under this client's 2048-bit RSA key (see __init__).
    _RSA_CIPHERTEXT_LEN = 256

    def __init__(self, sock, addr):
        """Wrap a connected socket and generate this client's RSA key pair.

        Args:
            sock: socket already connected to the server.
            addr: (host, port) of the server, used in log and error messages.
        """
        self._recv_buffer = bytearray(b"")
        self.sock = sock
        # Fresh 2048-bit key pair per connection; the server encrypts the
        # session key to its public half (decrypted in _recv_handshake).
        self.key = RSA.generate(2048, Random.new().read)
        self.sessionkey = None  # set by _recv_handshake
        self.serverkey = None  # server RSA public key, set by _recv_handshake
        self.addr = addr
        self.handshake = False  # True once frames must be AES-encrypted

    def _json_encode(self, obj, encoding):
        """Serialize ``obj`` to JSON text and encode it with ``encoding``."""
        return json.dumps(obj, ensure_ascii=False).encode(encoding)

    def _json_decode(self, json_bytes, encoding):
        """Decode ``json_bytes`` with ``encoding`` and parse the result as JSON."""
        # The context manager closes the wrapper even if json.load raises.
        with io.TextIOWrapper(
            io.BytesIO(json_bytes), encoding=encoding, newline=""
        ) as tiow:
            return json.load(tiow)

    def _session_cipher(self, iv):
        """Return a fresh AES-CFB cipher object for the session key and ``iv``."""
        # NOTE(review): the AES key is the session key followed by its own
        # reverse, so it has no more entropy than the session key. This
        # presumably mirrors the server's derivation, so it is left unchanged.
        key = self.sessionkey + self.sessionkey[::-1]
        return AES.new(key, AES.MODE_CFB, iv)

    def _create_message(self, *, content_bytes, content_type, content_encoding):
        """Build one wire frame around ``content_bytes``.

        After the handshake the payload is encrypted and a random IV of
        ``AES.block_size`` bytes is prepended; the header's ``content-length``
        is the length of what is actually sent (IV + ciphertext).
        """
        if self.handshake:
            iv = Random.new().read(AES.block_size)
            data = iv + self._session_cipher(iv).encrypt(content_bytes)
        else:
            data = content_bytes
        jsonheader = {
            "byteorder": sys.byteorder,
            "content-type": content_type,
            "content-encoding": content_encoding,
            "content-length": len(data),
        }
        jsonheader_bytes = self._json_encode(jsonheader, "utf-8")
        message_hdr = struct.pack(">H", len(jsonheader_bytes))
        return message_hdr + jsonheader_bytes + data

    def send_request(self, request):
        """Frame and send a request dict such as the one ``create_request`` builds.

        ``request`` must provide ``"content"``, ``"type"`` and ``"encoding"``.
        ``"text/json"`` content is JSON-encoded first; any other content is
        sent as-is and must already be bytes.
        """
        content = request["content"]
        content_type = request["type"]
        content_encoding = request["encoding"]
        if content_type == "text/json":
            content_bytes = self._json_encode(content, content_encoding)
        else:
            content_bytes = content
        msg = self._create_message(
            content_bytes=content_bytes,
            content_type=content_type,
            content_encoding=content_encoding,
        )
        print(f"sending request {repr(msg)} to {self.addr}")
        self.sock.sendall(msg)

    def read_to_buffer(self, length):
        """Block until at least ``length`` bytes are in the receive buffer.

        Raises:
            ConnectionError: if the peer closes the connection first.
        """
        while len(self._recv_buffer) < length:
            chunk = self.sock.recv(4096)
            if not chunk:
                # recv() returns b"" once the peer has closed; looping again
                # would spin forever (e.g. when the server rejects the password).
                raise ConnectionError(
                    f"connection to {self.addr} closed with "
                    f"{len(self._recv_buffer)} of {length} bytes received"
                )
            self._recv_buffer += chunk

    def _consume(self, length):
        """Wait for ``length`` bytes, remove them from the buffer and return them."""
        self.read_to_buffer(length)
        chunk = bytes(self._recv_buffer[:length])
        del self._recv_buffer[:length]
        return chunk

    def _read_frame(self):
        """Read one complete frame.

        Returns:
            ``(jsonheader, payload)``: the decoded header dict and the raw,
            still possibly encrypted, payload bytes.

        Raises:
            ValueError: if the header lacks a required field.
            ConnectionError: if the peer closes the connection mid-frame.
        """
        # Protoheader: fixed-length big-endian unsigned short.
        jsonheader_len = struct.unpack(">H", self._consume(HDRLEN))[0]
        jsonheader = self._json_decode(self._consume(jsonheader_len), "utf-8")
        for reqhdr in ("byteorder", "content-length", "content-type",
                       "content-encoding"):
            if reqhdr not in jsonheader:
                raise ValueError(f'Missing required header "{reqhdr}".')
        payload = self._consume(jsonheader["content-length"])
        return jsonheader, payload

    def process_response(self):
        """Receive one response frame, decrypt it if needed, and print it.

        Returns:
            The parsed JSON object for ``"text/json"`` responses, otherwise the
            raw payload bytes.

        Raises:
            ValueError: if the frame header lacks a required field.
            ConnectionError: if the server closes the connection first.
        """
        jsonheader, data = self._read_frame()
        if self.handshake:
            # Encrypted payloads are the IV followed by AES-CFB ciphertext,
            # mirroring _create_message.
            iv = data[:AES.block_size]
            encrypted = data[AES.block_size:]
            data = self._session_cipher(iv).decrypt(encrypted)
        content_type = jsonheader["content-type"]
        if content_type == "text/json":
            content = self._json_decode(data, jsonheader["content-encoding"])
            print("received response", repr(content), "from", self.addr)
            result = content.get("result")
            print(f"got result: {result}")
        else:
            # Binary or unknown content-type
            content = data
            print(f"received {content_type} response from", self.addr)
            print(f"got response: {repr(content)}")
        return content

    def _send_handshake(self, data):
        """Send ``data`` as a binary handshake frame (plaintext until ``handshake``)."""
        message = self._create_message(
            content_bytes=data,
            content_type="binary/custom-server-binary-type",
            content_encoding="binary",
        )
        self.sock.sendall(message)

    def close(self):
        """Shut down and close the socket; safe to call more than once."""
        if self.sock is None:
            return
        try:
            self.sock.shutdown(socket.SHUT_RDWR)
        except OSError:
            # The peer may already have torn the connection down; closing
            # must still happen and must not mask the caller's own error.
            pass
        finally:
            self.sock.close()
            self.sock = None

    def _recv_handshake(self):
        """Receive and verify the server's public key and the session key.

        Expected payload, concatenated in this order::

            server public key (PEM)
            sha3_512 hex digest of that key        (128 bytes)
            session key RSA-encrypted to our key   (256 bytes)
            sha3_512 hex digest of the session key (128 bytes)

        On success sets ``self.sessionkey`` and ``self.serverkey``.

        Raises:
            Exception: after closing the socket, if either digest mismatches.
        """
        _, data = self._read_frame()
        # Peel fields off the end of the payload, last field first.
        sesskeyhash = data[-self._DIGEST_HEX_LEN:]
        data = data[:-self._DIGEST_HEX_LEN]
        encryptedsess = data[-self._RSA_CIPHERTEXT_LEN:]
        data = data[:-self._RSA_CIPHERTEXT_LEN]
        pubkeyhash = data[-self._DIGEST_HEX_LEN:]
        pubkey = data[:-self._DIGEST_HEX_LEN]
        # NOTE(review): an unkeyed hash sent next to the data it covers only
        # detects accidental corruption; an active attacker can replace both.
        if hashlib.sha3_512(pubkey).hexdigest().encode() != pubkeyhash:
            self.close()
            raise Exception('Server public key does not match hash')
        # NOTE(review): key.decrypt()/encrypt() are the legacy PyCrypto
        # raw-RSA (unpadded) operations; PKCS1_OAEP would be safer but needs a
        # matching server change.
        halfsesskey = self.key.decrypt(encryptedsess)
        if hashlib.sha3_512(halfsesskey).hexdigest().encode() != sesskeyhash:
            self.close()
            raise Exception('Session key does not match hash')
        self.sessionkey = halfsesskey
        self.serverkey = RSA.importKey(pubkey)
        print("Received server secrets correctly.")

    def send_passphrase(self, passphrase):
        """Authenticate to the server, then switch to encrypted frames.

        Sends RSA(server key, sha3_512_hex(passphrase) + sha3_512_hex(session
        key)) as a plaintext handshake frame. Afterwards ``self.handshake`` is
        set, so every later frame is AES-encrypted.
        """
        passhash = hashlib.sha3_512(passphrase.encode()).hexdigest().encode()
        sesshash = hashlib.sha3_512(self.sessionkey).hexdigest().encode()
        # Legacy PyCrypto encrypt(plaintext, K) returns a tuple; the
        # ciphertext is its first element.
        data = self.serverkey.encrypt(passhash + sesshash, None)[0]
        print('Sending passphrase + hash + sessionkey hash...')
        self._send_handshake(data)
        self.handshake = True
def create_request(action, value):
    """Build the request dict that ``Message.send_request`` consumes.

    The ``"search"`` action is sent as UTF-8 JSON carrying both fields; any
    other action is sent as the UTF-8 bytes of ``action + value``.
    """
    if action != "search":
        payload = bytes(action + value, encoding="utf-8")
        return {
            "type": "binary/custom-client-binary-type",
            "encoding": "binary",
            "content": payload,
        }
    return {
        "type": "text/json",
        "encoding": "utf-8",
        "content": {"action": action, "value": value},
    }
def main():
    """Command-line entry point: connect, handshake, send one request, print the reply.

    Expects ``sys.argv`` to be ``<prog> <host> <port> <action> <value>``;
    prints usage and exits with status 1 otherwise (including a non-numeric
    port). The socket is always shut down and closed, even on errors.
    """
    def _usage_exit():
        print("usage:", sys.argv[0], "<host> <port> <action> <value>")
        sys.exit(1)

    if len(sys.argv) != 5:
        _usage_exit()
    host = sys.argv[1]
    try:
        port = int(sys.argv[2])
    except ValueError:
        _usage_exit()
    addr = (host, port)
    print("starting connection to", addr)
    sock = socket.create_connection(addr)
    # Blocking mode: no timeout applies to any send/recv below.
    sock.settimeout(None)
    try:
        message = Message(sock, addr)
        # Security handshake, step 1: send our public key plus its SHA3-512
        # hex digest.
        publickey = message.key.publickey().exportKey('PEM')
        keyhash = hashlib.sha3_512(publickey).hexdigest().encode()
        print('Sending client public key and hash...')
        message._send_handshake(publickey + keyhash)
        # Step 2: receive server secrets (public key and session key).
        message._recv_handshake()
        # Step 3: authenticate; every frame after this is AES-encrypted.
        message.send_passphrase(PASSWORD)
        print('Handshake completed. Server would close the sock if password wrong')
        action, value = sys.argv[3], sys.argv[4]
        message.send_request(create_request(action, value))
        # Read and process the response.
        message.process_response()
    finally:
        # Message.close() may already have closed the socket, and the server
        # may have dropped the connection, so shutdown errors are expected.
        try:
            sock.shutdown(socket.SHUT_RDWR)
        except OSError:
            pass
        sock.close()
# Run the client only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()