diff --git a/internal_filesystem/lib/websocket.py b/internal_filesystem/lib/websocket.py index bae33ea8..a2cde71b 100644 --- a/internal_filesystem/lib/websocket.py +++ b/internal_filesystem/lib/websocket.py @@ -130,7 +130,7 @@ class WebSocketApp: """Send binary data.""" self.send(data, ABNF.OPCODE_BINARY) - def close(self, **kwargs): + async def close(self, **kwargs): """Close the WebSocket connection.""" _log_debug("Close requested") self.running = False @@ -184,7 +184,7 @@ class WebSocketApp: _log_debug(f"Connection status: ready={status}") return status - def run_forever( + async def run_forever( self, sockopt=None, sslopt=None, @@ -230,7 +230,7 @@ class WebSocketApp: self.close() return False except Exception as e: - _log_error(f"run_forever's _loop.run_until_complete() got general exception: {e}") + _log_error(f"run_forever's _loop.run_until_complete() for {self.url} got general exception: {e}") self.has_errored = True self.running = False #return True @@ -262,7 +262,7 @@ class WebSocketApp: try: await self._connect_and_run() # keep waiting for it, until finished except Exception as e: - _log_error(f"_async_main got exception: {e}") + _log_error(f"_async_main's await self._connect_and_run() got exception: {e}") self.has_errored = True _run_callback(self.on_error, self, e) if not reconnect: @@ -298,6 +298,11 @@ class WebSocketApp: self.session = aiohttp.ClientSession(headers=self.header) async with self.session.ws_connect(self.url, ssl=ssl_context) as ws: + if not ws: + print("ERROR: ws_connect got None instead of ws object!") + _run_callback(self.on_error, self, str(e)) + return + self.ws = ws _log_debug("WebSocket connected, running on_open callback") _run_callback(self.on_open, self) diff --git a/micropython-nostr b/micropython-nostr index ac2c74ab..da7c2be1 160000 --- a/micropython-nostr +++ b/micropython-nostr @@ -1 +1 @@ -Subproject commit ac2c74ab8377d5ace53136f25366881f205a26fa +Subproject commit da7c2be1ca436a39e8b6ef32b0f279ebc088d77d diff --git a/tests/test_websocket.py b/tests/test_websocket.py index 051fe728..ddeb1123 100644 --- a/tests/test_websocket.py +++ b/tests/test_websocket.py @@ -1,3 +1,4 @@ +import asyncio import unittest import _thread import time @@ -7,14 +8,19 @@ import mpos.apps from websocket import WebSocketApp -class TestWebsocket(unittest.TestCase): +class TestMutlipleWebsocketsAsyncio(unittest.TestCase): - ws = None + max_allowed_connections = 3 # max that echo.websocket.org allows - on_open_called = None - on_message_called = None - on_ping_called = None - on_close_called = None + #relays = ["wss://echo.websocket.org", "wss://echo.websocket.org", "wss://echo.websocket.org" ] # more gives "too many requests" error + relays = ["wss://echo.websocket.org", "wss://echo.websocket.org", "wss://echo.websocket.org", "wss://echo.websocket.org", "wss://echo.websocket.org" ] # more might give "too many requests" error + wslist = [] + + on_open_called = 0 + on_message_called = 0 + on_ping_called = 0 + on_close_called = 0 + on_error_called = 0 def on_message(self, wsapp, message: str): print(f"on_message received: {message}") @@ -22,8 +28,8 @@ class TestWebsocket(unittest.TestCase): def on_open(self, wsapp): print(f"on_open called: {wsapp}") - self.on_open_called = True - self.ws.send('{"type": "subscribe","product_ids": ["BTC-USD"],"channels": ["ticker_batch"]}') + self.on_open_called += 1 + #wsapp.send('{"type": "subscribe","product_ids": ["BTC-USD"],"channels": ["ticker_batch"]}') def on_ping(wsapp, message): print("Got a ping!") @@ -31,19 +37,63 @@ class TestWebsocket(unittest.TestCase): def on_close(self, wsapp, close_status_code, close_msg): print(f"on_close called: {wsapp}") - self.on_close_called = True + self.on_close_called += 1 - def websocket_thread(self): - wsurl = "wss://ws-feed.exchange.coinbase.com" + def on_error(self, wsapp, arg1): + print(f"on_error called: {wsapp}, {arg1}") + self.on_error_called += 1 - self.ws = WebSocketApp( - wsurl, - on_open=self.on_open, - on_close=self.on_close, - on_message=self.on_message, - on_ping=self.on_ping - ) # maybe add other callbacks to reconnect when disconnected etc. - self.ws.run_forever() + async def closeall(self): + await asyncio.sleep(1) + + self.on_close_called = 0 + print("disconnecting...") + for ws in self.wslist: + await ws.close() + + async def main(self) -> None: + tasks = [] + self.wslist = [] + for idx, wsurl in enumerate(self.relays): + print(f"creating WebSocketApp for {wsurl}") + ws = WebSocketApp( + wsurl, + on_open=self.on_open, + on_close=self.on_close, + on_message=self.on_message, + on_ping=self.on_ping, + on_error=self.on_error + ) + print(f"creating task for {wsurl}") + tasks.append(asyncio.create_task(ws.run_forever(),)) + print(f"created task for {wsurl}") + self.wslist.append(ws) + + print(f"Starting {len(tasks)} concurrent WebSocket connections…") + await asyncio.sleep(2) + await self.closeall() + + for _ in range(10): + print("Waiting for on_open to be called...") + if self.on_open_called == min(len(self.relays),self.max_allowed_connections): + print("yes, it was called!") + break + await asyncio.sleep(1) + self.assertTrue(self.on_open_called == min(len(self.relays),self.max_allowed_connections)) + + for _ in range(10): + print("Waiting for on_close to be called...") + if self.on_close_called == min(len(self.relays),self.max_allowed_connections): + print("yes, it was called!") + break + await asyncio.sleep(1) + self.assertTrue(self.on_close_called == min(len(self.relays),self.max_allowed_connections)) + + self.assertTrue(self.on_error_called == min(len(self.relays),self.max_allowed_connections)) + + # Wait for *all* of them to finish (or be cancelled) + # If this hangs, it's also a failure: + await asyncio.gather(*tasks, return_exceptions=True) def wait_for_ping(self): self.on_ping_called = False @@ -55,39 +105,12 @@ class TestWebsocket(unittest.TestCase): time.sleep(1) self.assertTrue(self.on_ping_called) - def test_it(self): - on_open_called = False - _thread.stack_size(mpos.apps.good_stack_size()) - _thread.start_new_thread(self.websocket_thread, ()) - - self.on_open_called = False - self.on_message_called = False # message might be received very quickly, before we expect it - for _ in range(5): - print("Waiting for on_open to be called...") - if self.on_open_called: - print("yes, it was called!") - break - time.sleep(1) - self.assertTrue(self.on_open_called) + def test_it_loop(self): + for testnr in range(1): + print(f"starting iteration {testnr}") + asyncio.run(self.do_two()) + print(f"finished iteration {testnr}") - self.on_message_called = False # message might be received very quickly, before we expect it - for _ in range(5): - print("Waiting for on_message to be called...") - if self.on_message_called: - print("yes, it was called!") - break - time.sleep(1) - self.assertTrue(self.on_message_called) + def do_two(self): + await self.main() - # Disabled because not all servers send pings: - # self.wait_for_ping() - - self.on_close_called = False - self.ws.close() - for _ in range(5): - print("Waiting for on_close to be called...") - if self.on_close_called: - print("yes, it was called!") - break - time.sleep(1) - self.assertTrue(self.on_close_called)