websocket: use asyncio instead of threads

This commit is contained in:
Thomas Farstrike
2025-11-11 10:40:46 +01:00
parent 26d9fed77c
commit 4152604bf4
3 changed files with 86 additions and 58 deletions
+9 -4
View File
@@ -130,7 +130,7 @@ class WebSocketApp:
"""Send binary data."""
self.send(data, ABNF.OPCODE_BINARY)
def close(self, **kwargs):
async def close(self, **kwargs):
"""Close the WebSocket connection."""
_log_debug("Close requested")
self.running = False
@@ -184,7 +184,7 @@ class WebSocketApp:
_log_debug(f"Connection status: ready={status}")
return status
def run_forever(
async def run_forever(
self,
sockopt=None,
sslopt=None,
@@ -230,7 +230,7 @@ class WebSocketApp:
self.close()
return False
except Exception as e:
_log_error(f"run_forever's _loop.run_until_complete() got general exception: {e}")
_log_error(f"run_forever's _loop.run_until_complete() for {self.url} got general exception: {e}")
self.has_errored = True
self.running = False
#return True
@@ -262,7 +262,7 @@ class WebSocketApp:
try:
await self._connect_and_run() # keep waiting for it, until finished
except Exception as e:
_log_error(f"_async_main got exception: {e}")
_log_error(f"_async_main's await self._connect_and_run() got exception: {e}")
self.has_errored = True
_run_callback(self.on_error, self, e)
if not reconnect:
@@ -298,6 +298,11 @@ class WebSocketApp:
self.session = aiohttp.ClientSession(headers=self.header)
async with self.session.ws_connect(self.url, ssl=ssl_context) as ws:
if not ws:
print("ERROR: ws_connect got None instead of ws object!")
_run_callback(self.on_error, self, str(e))
return
self.ws = ws
_log_debug("WebSocket connected, running on_open callback")
_run_callback(self.on_open, self)
+76 -53
View File
@@ -1,3 +1,4 @@
import asyncio
import unittest
import _thread
import time
@@ -7,14 +8,19 @@ import mpos.apps
from websocket import WebSocketApp
class TestWebsocket(unittest.TestCase):
class TestMutlipleWebsocketsAsyncio(unittest.TestCase):
ws = None
max_allowed_connections = 3 # max that echo.websocket.org allows
on_open_called = None
on_message_called = None
on_ping_called = None
on_close_called = None
#relays = ["wss://echo.websocket.org", "wss://echo.websocket.org", "wss://echo.websocket.org" ] # more gives "too many requests" error
relays = ["wss://echo.websocket.org", "wss://echo.websocket.org", "wss://echo.websocket.org", "wss://echo.websocket.org", "wss://echo.websocket.org" ] # more might give "too many requests" error
wslist = []
on_open_called = 0
on_message_called = 0
on_ping_called = 0
on_close_called = 0
on_error_called = 0
def on_message(self, wsapp, message: str):
print(f"on_message received: {message}")
@@ -22,8 +28,8 @@ class TestWebsocket(unittest.TestCase):
def on_open(self, wsapp):
print(f"on_open called: {wsapp}")
self.on_open_called = True
self.ws.send('{"type": "subscribe","product_ids": ["BTC-USD"],"channels": ["ticker_batch"]}')
self.on_open_called += 1
#wsapp.send('{"type": "subscribe","product_ids": ["BTC-USD"],"channels": ["ticker_batch"]}')
def on_ping(wsapp, message):
print("Got a ping!")
@@ -31,19 +37,63 @@ class TestWebsocket(unittest.TestCase):
def on_close(self, wsapp, close_status_code, close_msg):
print(f"on_close called: {wsapp}")
self.on_close_called = True
self.on_close_called += 1
def websocket_thread(self):
wsurl = "wss://ws-feed.exchange.coinbase.com"
def on_error(self, wsapp, arg1):
print(f"on_error called: {wsapp}, {arg1}")
self.on_error_called += 1
self.ws = WebSocketApp(
wsurl,
on_open=self.on_open,
on_close=self.on_close,
on_message=self.on_message,
on_ping=self.on_ping
) # maybe add other callbacks to reconnect when disconnected etc.
self.ws.run_forever()
async def closeall(self):
await asyncio.sleep(1)
self.on_close_called = 0
print("disconnecting...")
for ws in self.wslist:
await ws.close()
async def main(self) -> None:
tasks = []
self.wslist = []
for idx, wsurl in enumerate(self.relays):
print(f"creating WebSocketApp for {wsurl}")
ws = WebSocketApp(
wsurl,
on_open=self.on_open,
on_close=self.on_close,
on_message=self.on_message,
on_ping=self.on_ping,
on_error=self.on_error
)
print(f"creating task for {wsurl}")
tasks.append(asyncio.create_task(ws.run_forever(),))
print(f"created task for {wsurl}")
self.wslist.append(ws)
print(f"Starting {len(tasks)} concurrent WebSocket connections…")
await asyncio.sleep(2)
await self.closeall()
for _ in range(10):
print("Waiting for on_open to be called...")
if self.on_open_called == min(len(self.relays),self.max_allowed_connections):
print("yes, it was called!")
break
await asyncio.sleep(1)
self.assertTrue(self.on_open_called == min(len(self.relays),self.max_allowed_connections))
for _ in range(10):
print("Waiting for on_close to be called...")
if self.on_close_called == min(len(self.relays),self.max_allowed_connections):
print("yes, it was called!")
break
await asyncio.sleep(1)
self.assertTrue(self.on_close_called == min(len(self.relays),self.max_allowed_connections))
self.assertTrue(self.on_error_called == min(len(self.relays),self.max_allowed_connections))
# Wait for *all* of them to finish (or be cancelled)
# If this hangs, it's also a failure:
await asyncio.gather(*tasks, return_exceptions=True)
def wait_for_ping(self):
self.on_ping_called = False
@@ -55,39 +105,12 @@ class TestWebsocket(unittest.TestCase):
time.sleep(1)
self.assertTrue(self.on_ping_called)
def test_it(self):
on_open_called = False
_thread.stack_size(mpos.apps.good_stack_size())
_thread.start_new_thread(self.websocket_thread, ())
self.on_open_called = False
self.on_message_called = False # message might be received very quickly, before we expect it
for _ in range(5):
print("Waiting for on_open to be called...")
if self.on_open_called:
print("yes, it was called!")
break
time.sleep(1)
self.assertTrue(self.on_open_called)
def test_it_loop(self):
for testnr in range(1):
print(f"starting iteration {testnr}")
asyncio.run(self.do_two())
print(f"finished iteration {testnr}")
self.on_message_called = False # message might be received very quickly, before we expect it
for _ in range(5):
print("Waiting for on_message to be called...")
if self.on_message_called:
print("yes, it was called!")
break
time.sleep(1)
self.assertTrue(self.on_message_called)
def do_two(self):
await self.main()
# Disabled because not all servers send pings:
# self.wait_for_ping()
self.on_close_called = False
self.ws.close()
for _ in range(5):
print("Waiting for on_close to be called...")
if self.on_close_called:
print("yes, it was called!")
break
time.sleep(1)
self.assertTrue(self.on_close_called)