You've already forked MicroPythonOS
mirror of
https://github.com/m5stack/MicroPythonOS.git
synced 2026-05-20 11:51:27 -07:00
285 lines
8.8 KiB
Python
285 lines
8.8 KiB
Python
import socket
|
|
import asyncio as a
|
|
import binascii as b
|
|
import random as r
|
|
from collections import namedtuple
|
|
import re
|
|
import struct
|
|
import ssl
|
|
|
|
# WebSocket frame opcodes (low 4 bits of the first header byte).
OP_CONT = const(0x00)
OP_TEXT = const(0x01)
OP_BYTES = const(0x02)
OP_CLOSE = const(0x08)
OP_PING = const(0x09)
OP_PONG = const(0x0A)

# Status codes carried in the payload of a close frame.
CLOSE_OK = const(1000)
CLOSE_GOING_AWAY = const(1001)
CLOSE_PROTOCOL_ERROR = const(1002)
CLOSE_DATA_NOT_SUPPORTED = const(1003)
CLOSE_BAD_DATA = const(1007)
CLOSE_POLICY_VIOLATION = const(1008)
CLOSE_TOO_BIG = const(1009)
CLOSE_MISSING_EXTN = const(1010)
CLOSE_BAD_CONDITION = const(1011)

# Matches ws:// and wss:// URLs; groups are (scheme, host, port, path),
# where port and path are optional and come back as None when absent.
URL_RE = re.compile(r'(wss|ws)://([A-Za-z0-9-\.]+)(?:\:([0-9]+))?(/.+)?')

# Parsed form of a websocket URL, as produced from the URL_RE groups.
URI = namedtuple('URI', ('protocol', 'hostname', 'port', 'path'))
|
|
|
|
class AsyncWebsocketClient:
|
|
def __init__(self, ms_delay_for_read: int = 5):
|
|
self._open = False
|
|
self.delay_read = ms_delay_for_read
|
|
self._lock_for_open = a.Lock()
|
|
self.sock = None
|
|
|
|
async def open(self, new_val: bool = None):
|
|
await self._lock_for_open.acquire()
|
|
if new_val is not None:
|
|
if not new_val and self.sock:
|
|
self.sock.close()
|
|
self.sock = None
|
|
self._open = new_val
|
|
to_return = self._open
|
|
self._lock_for_open.release()
|
|
return to_return
|
|
|
|
async def close(self):
|
|
return await self.open(False)
|
|
|
|
def urlparse(self, uri):
|
|
"""Parse ws or wss:// URLs"""
|
|
match = URL_RE.match(uri)
|
|
if match:
|
|
protocol, host, port, path = match.group(1), match.group(2), match.group(3), match.group(4)
|
|
|
|
if protocol not in ['ws', 'wss']:
|
|
raise ValueError('Scheme {} is invalid'.format(protocol))
|
|
|
|
if port is None:
|
|
port = (80, 443)[protocol == 'wss']
|
|
|
|
return URI(protocol, host, int(port), path)
|
|
|
|
async def a_readline(self):
|
|
line = None
|
|
while line is None:
|
|
line = self.sock.readline()
|
|
await a.sleep_ms(self.delay_read)
|
|
|
|
return line
|
|
|
|
async def a_read(self, size: int = None):
|
|
if size == 0:
|
|
return b''
|
|
chunks = []
|
|
|
|
while True:
|
|
b = self.sock.read(size)
|
|
await a.sleep_ms(self.delay_read)
|
|
|
|
# Continue reading if the socket returns None
|
|
if b is None: continue
|
|
|
|
# In some cases, the socket will return an empty bytes
|
|
# after PING or PONG frames, we need to ignore them.
|
|
if len(b) == 0: break
|
|
|
|
chunks.append(b)
|
|
size -= len(b)
|
|
|
|
# After reading the first chunk, we can break if size is None or 0
|
|
if size is None or size == 0: break
|
|
|
|
# Join all the chunks and return them
|
|
return b''.join(chunks)
|
|
|
|
async def handshake(self, uri, headers=[], keyfile=None, certfile=None, cafile=None, cert_reqs=0):
|
|
if self.sock:
|
|
self.close()
|
|
|
|
self.sock = socket.socket()
|
|
self.uri = self.urlparse(uri)
|
|
ai = socket.getaddrinfo(self.uri.hostname, self.uri.port)
|
|
addr = ai[0][4]
|
|
|
|
self.sock.connect(addr)
|
|
self.sock.setblocking(False)
|
|
|
|
if self.uri.protocol == 'wss':
|
|
cadata = None
|
|
if not cafile is None:
|
|
with open(cafile, 'rb') as f:
|
|
cadata = f.read()
|
|
self.sock = ssl.wrap_socket(
|
|
self.sock, server_side=False,
|
|
key=keyfile, cert=certfile,
|
|
cert_reqs=cert_reqs, # 0 - NONE, 1 - OPTIONAL, 2 - REQUIED
|
|
cadata=cadata,
|
|
server_hostname=self.uri.hostname
|
|
)
|
|
|
|
def send_header(header, *args):
|
|
self.sock.write(header % args + '\r\n')
|
|
|
|
# Sec-WebSocket-Key is 16 bytes of random base64 encoded
|
|
key = b.b2a_base64(bytes(r.getrandbits(8)
|
|
for _ in range(16)))[:-1]
|
|
|
|
send_header(b'GET %s HTTP/1.1', self.uri.path or '/')
|
|
send_header(b'Host: %s:%s', self.uri.hostname, self.uri.port)
|
|
send_header(b'Connection: Upgrade')
|
|
send_header(b'Upgrade: websocket')
|
|
send_header(b'Sec-WebSocket-Key: %s', key)
|
|
send_header(b'Sec-WebSocket-Version: 13')
|
|
send_header(b'Origin: http://{hostname}:{port}'.format(
|
|
hostname=self.uri.hostname,
|
|
port=self.uri.port)
|
|
)
|
|
|
|
for key, value in headers:
|
|
send_header(b'%s: %s', key, value)
|
|
|
|
send_header(b'')
|
|
|
|
line = await self.a_readline()
|
|
header = (line)[:-2]
|
|
if not header.startswith(b'HTTP/1.1 101 '):
|
|
raise Exception(header)
|
|
|
|
# We don't (currently) need these headers
|
|
# FIXME: should we check the return key?
|
|
while header:
|
|
line = await self.a_readline()
|
|
header = (line)[:-2]
|
|
|
|
return await self.open(True)
|
|
|
|
async def read_frame(self, max_size=None):
|
|
# Frame header
|
|
byte1, byte2 = struct.unpack('!BB', await self.a_read(2))
|
|
|
|
# Byte 1: FIN(1) _(1) _(1) _(1) OPCODE(4)
|
|
fin = bool(byte1 & 0x80)
|
|
opcode = byte1 & 0x0f
|
|
|
|
# Byte 2: MASK(1) LENGTH(7)
|
|
mask = bool(byte2 & (1 << 7))
|
|
length = byte2 & 0x7f
|
|
|
|
if length == 126: # Magic number, length header is 2 bytes
|
|
length, = struct.unpack('!H', await self.a_read(2))
|
|
elif length == 127: # Magic number, length header is 8 bytes
|
|
length, = struct.unpack('!Q', await self.a_read(8))
|
|
|
|
if mask: # Mask is 4 bytes
|
|
mask_bits = await self.a_read(4)
|
|
|
|
try:
|
|
data = await self.a_read(length)
|
|
except MemoryError:
|
|
# We can't receive this many bytes, close the socket
|
|
self.close(code=CLOSE_TOO_BIG)
|
|
# await self._stream.drain()
|
|
return True, OP_CLOSE, None
|
|
|
|
if mask:
|
|
data = bytes(b ^ mask_bits[i % 4]
|
|
for i, b in enumerate(data))
|
|
|
|
return fin, opcode, data
|
|
|
|
def write_frame(self, opcode, data=b''):
|
|
fin = True
|
|
mask = True # messages sent by client are masked
|
|
|
|
length = len(data)
|
|
|
|
# Frame header
|
|
# Byte 1: FIN(1) _(1) _(1) _(1) OPCODE(4)
|
|
byte1 = 0x80 if fin else 0
|
|
byte1 |= opcode
|
|
|
|
# Byte 2: MASK(1) LENGTH(7)
|
|
byte2 = 0x80 if mask else 0
|
|
|
|
if length < 126: # 126 is magic value to use 2-byte length header
|
|
byte2 |= length
|
|
self.sock.write(struct.pack('!BB', byte1, byte2))
|
|
|
|
elif length < (1 << 16): # Length fits in 2-bytes
|
|
byte2 |= 126 # Magic code
|
|
self.sock.write(struct.pack('!BBH', byte1, byte2, length))
|
|
|
|
elif length < (1 << 64):
|
|
byte2 |= 127 # Magic code
|
|
self.sock.write(struct.pack('!BBQ', byte1, byte2, length))
|
|
|
|
else:
|
|
raise ValueError()
|
|
|
|
if mask: # Mask is 4 bytes
|
|
mask_bits = struct.pack('!I', r.getrandbits(32))
|
|
self.sock.write(mask_bits)
|
|
data = bytes(b ^ mask_bits[i % 4]
|
|
for i, b in enumerate(data))
|
|
|
|
self.sock.write(data)
|
|
|
|
async def recv(self):
|
|
while await self.open():
|
|
try:
|
|
fin, opcode, data = await self.read_frame()
|
|
# except (ValueError, EOFError) as ex:
|
|
except Exception as ex:
|
|
print('Exception in recv while reading frame:', ex)
|
|
await self.open(False)
|
|
return
|
|
|
|
if not fin:
|
|
raise NotImplementedError()
|
|
|
|
if opcode == OP_TEXT:
|
|
return data.decode('utf-8')
|
|
elif opcode == OP_BYTES:
|
|
return data
|
|
elif opcode == OP_CLOSE:
|
|
await self.open(False)
|
|
return
|
|
elif opcode == OP_PONG:
|
|
# Ignore this frame, keep waiting for a data frame
|
|
continue
|
|
elif opcode == OP_PING:
|
|
try:
|
|
# We need to send a pong frame
|
|
self.write_frame(OP_PONG, data)
|
|
|
|
# And then continue to wait for a data frame
|
|
continue
|
|
except Exception as ex:
|
|
print('Error sending pong frame:', ex)
|
|
# If sending the pong frame fails, close the connection
|
|
await self.open(False)
|
|
return
|
|
elif opcode == OP_CONT:
|
|
# This is a continuation of a previous frame
|
|
raise NotImplementedError(opcode)
|
|
else:
|
|
raise ValueError(opcode)
|
|
|
|
async def send(self, buf):
|
|
if not await self.open():
|
|
return
|
|
if isinstance(buf, str):
|
|
opcode = OP_TEXT
|
|
buf = buf.encode('utf-8')
|
|
elif isinstance(buf, bytes):
|
|
opcode = OP_BYTES
|
|
else:
|
|
raise TypeError()
|
|
self.write_frame(opcode, buf)
|
|
|