From 4a22981c9ca2a37201338aa987cd9dd96e7ed618 Mon Sep 17 00:00:00 2001 From: Thomas Farstrike Date: Tue, 11 Nov 2025 10:51:02 +0100 Subject: [PATCH] Add unit tests --- .gitignore | 1 + tests/manual_test_nostr_asyncio.py | 333 +++++++++++++++++++++++++++++ tests/manual_test_nwcwallet.py | 49 +++++ tests/test_multi_connect.py | 255 ++++++++++++++++++++++ 4 files changed, 638 insertions(+) create mode 100644 tests/manual_test_nostr_asyncio.py create mode 100644 tests/manual_test_nwcwallet.py create mode 100644 tests/test_multi_connect.py diff --git a/.gitignore b/.gitignore index cab0b726..6f7b3193 100644 --- a/.gitignore +++ b/.gitignore @@ -8,4 +8,5 @@ internal_filesystem/SDLPointer_3 # config files etc: internal_filesystem/data +internal_filesystem/sdcard diff --git a/tests/manual_test_nostr_asyncio.py b/tests/manual_test_nostr_asyncio.py new file mode 100644 index 00000000..aa3ea80f --- /dev/null +++ b/tests/manual_test_nostr_asyncio.py @@ -0,0 +1,333 @@ +import asyncio +import json +import ssl +import _thread +import time +import unittest + +from mpos import App, PackageManager +import mpos.apps + +from nostr.relay_manager import RelayManager +from nostr.message_type import ClientMessageType +from nostr.filter import Filter, Filters +from nostr.event import EncryptedDirectMessage +from nostr.key import PrivateKey + + +# keeps a list of items +# The .add() method ensures the list remains unique (via __eq__) +# and sorted (via __lt__) by inserting new items in the correct position. +class UniqueSortedList: + def __init__(self): + self._items = [] + + def add(self, item): + #print(f"before add: {str(self)}") + # Check if item already exists (using __eq__) + if item not in self._items: + # Insert item in sorted position for descending order (using __gt__) + for i, existing_item in enumerate(self._items): + if item > existing_item: + self._items.insert(i, item) + return + # If item is smaller than all existing items, append it + self._items.append(item) + #print(f"after add: {str(self)}") + + def __iter__(self): + # Return iterator for the internal list + return iter(self._items) + + def get(self, index_nr): + # Retrieve item at given index, raise IndexError if invalid + try: + return self._items[index_nr] + except IndexError: + raise IndexError("Index out of range") + + def __len__(self): + # Return the number of items for len() calls + return len(self._items) + + def __str__(self): + #print("UniqueSortedList tostring called") + return "\n".join(str(item) for item in self._items) + + def __eq__(self, other): + if len(self._items) != len(other): + return False + return all(p1 == p2 for p1, p2 in zip(self._items, other)) + +# Payment class remains unchanged +class Payment: + def __init__(self, epoch_time, amount_sats, comment): + self.epoch_time = epoch_time + self.amount_sats = amount_sats + self.comment = comment + + def __str__(self): + sattext = "sats" + if self.amount_sats == 1: + sattext = "sat" + #return f"{self.amount_sats} {sattext} @ {self.epoch_time}: {self.comment}" + return f"{self.amount_sats} {sattext}: {self.comment}" + + def __eq__(self, other): + if not isinstance(other, Payment): + return False + return self.epoch_time == other.epoch_time and self.amount_sats == other.amount_sats and self.comment == other.comment + + def __lt__(self, other): + if not isinstance(other, Payment): + return NotImplemented + return (self.epoch_time, self.amount_sats, self.comment) < (other.epoch_time, other.amount_sats, other.comment) + + def __le__(self, other): + if not isinstance(other, Payment): + return NotImplemented + return (self.epoch_time, self.amount_sats, self.comment) <= (other.epoch_time, other.amount_sats, other.comment) + + def __gt__(self, other): + if not isinstance(other, Payment): + return NotImplemented + return (self.epoch_time, self.amount_sats, self.comment) > (other.epoch_time, other.amount_sats, other.comment) + + def __ge__(self, other): + if not isinstance(other, Payment): + return NotImplemented + return (self.epoch_time, self.amount_sats, self.comment) >= (other.epoch_time, other.amount_sats, other.comment) + + + +class TestNostr(unittest.TestCase): + + PAYMENTS_TO_SHOW = 5 + + keep_running = None + connected = None + balance = -1 + payment_list = [] + transactions_welcome = False + + relays = [ "ws://192.168.1.16:5000/nostrrelay/test", "ws://192.168.1.16:5000/nostrclient/api/v1/relay" ] + #relays = [ "ws://127.0.0.1:5000/nostrrelay/test", "ws://127.0.0.1:5000/nostrclient/api/v1/relay" ] + #relays = [ "wss://relay.damus.io", "wss://nostr-pub.wellorder.net" ] + #relays = [ "ws://127.0.0.1:5000/nostrrelay/test", "ws://127.0.0.1:5000/nostrclient/api/v1/relay", "wss://relay.damus.io", "wss://nostr-pub.wellorder.net" ] + #relays = [ "ws://127.0.0.1:5000/nostrclient/api/v1/relay", "wss://relay.damus.io", "wss://nostr-pub.wellorder.net" ] + secret = "fab0a9a11d4cf4b1d92e901a0b2c56634275e2fa1a7eb396ff1b942f95d59fd3" + wallet_pubkey = "e46762afab282c324278351165122345f9983ea447b47943b052100321227571" + + async def fetch_balance(self): + if not self.keep_running: + return + # Create get_balance request + balance_request = { + "method": "get_balance", + "params": {} + } + print(f"DEBUG: Created balance request: {balance_request}") + print(f"DEBUG: Creating encrypted DM to wallet pubkey: {self.wallet_pubkey}") + dm = EncryptedDirectMessage( + recipient_pubkey=self.wallet_pubkey, + cleartext_content=json.dumps(balance_request), + kind=23194 + ) + print(f"DEBUG: Signing DM {json.dumps(dm)} with private key") + self.private_key.sign_event(dm) # sign also does encryption if it's a encrypted dm + print(f"DEBUG: Publishing encrypted DM") + self.relay_manager.publish_event(dm) + + def handle_new_balance(self, new_balance, fetchPaymentsIfChanged=True): + if not self.keep_running or new_balance is None: + return + if fetchPaymentsIfChanged: # Fetching *all* payments isn't necessary if balance was changed by a payment notification + print("Refreshing payments...") + self.fetch_payments() # if the balance changed, then re-list transactions + + def fetch_payments(self): + if not self.keep_running: + return + # Create get_balance request + list_transactions = { + "method": "list_transactions", + "params": { + "limit": self.PAYMENTS_TO_SHOW + } + } + dm = EncryptedDirectMessage( + recipient_pubkey=self.wallet_pubkey, + cleartext_content=json.dumps(list_transactions), + kind=23194 + ) + self.private_key.sign_event(dm) # sign also does encryption if it's a encrypted dm + print("\nPublishing DM to fetch payments...") + self.relay_manager.publish_event(dm) + self.transactions_welcome = True + + def handle_new_payments(self, new_payments): + if not self.keep_running or not self.transactions_welcome: + return + print("handle_new_payments") + if self.payment_list != new_payments: + print("new list of payments") + self.payment_list = new_payments + self.payments_updated_cb() + + def payments_updated_cb(self): + print("payments_updated_cb called, now closing everything!") + self.keep_running = False + + def getCommentFromTransaction(self, transaction): + comment = "" + try: + comment = transaction["description"] + json_comment = json.loads(comment) + for field in json_comment: + if field[0] == "text/plain": + comment = field[1] + break + else: + print("text/plain field is missing from JSON description") + except Exception as e: + print(f"Info: could not parse comment as JSON, this is fine, using as-is ({e})") + return comment + + + async def NOmainHERE(self): + self.keep_running = True + self.private_key = PrivateKey(bytes.fromhex(self.secret)) + self.relay_manager = RelayManager() + for relay in self.relays: + self.relay_manager.add_relay(relay) + + print(f"DEBUG: Opening relay connections") + await self.relay_manager.open_connections({"cert_reqs": ssl.CERT_NONE}) + self.connected = False + for _ in range(20): + print("Waiting for relay connection...") + await asyncio.sleep(0.5) + nrconnected = 0 + for index, relay in enumerate(self.relays): + try: + relay = self.relay_manager.relays[self.relays[index]] + if relay.connected is True: + print(f"connected: {self.relays[index]}") + nrconnected += 1 + else: + print(f"not connected: {self.relays[index]}") + except Exception as e: + print(f"could not find relay: {e}") + break # not all of them have been initialized, skip... + self.connected = ( nrconnected == len(self.relays) ) + if self.connected: + print("All relays connected!") + break + if not self.connected or not self.keep_running: + print(f"ERROR: could not connect to relay or not self.keep_running, aborting...") + # TODO: call an error callback to notify the user + return + + # Set up subscription to receive response + self.subscription_id = "micropython_nwc_" + str(round(time.time())) + print(f"DEBUG: Setting up subscription with ID: {self.subscription_id}") + self.filters = Filters([Filter( + #event_ids=[self.subscription_id], # would be nice to filter, but not like this + kinds=[23195, 23196], # NWC reponses and notifications + authors=[self.wallet_pubkey], + pubkey_refs=[self.private_key.public_key.hex()] + )]) + print(f"DEBUG: Subscription filters: {self.filters.to_json_array()}") + self.relay_manager.add_subscription(self.subscription_id, self.filters) + print(f"DEBUG: Creating subscription request") + request_message = [ClientMessageType.REQUEST, self.subscription_id] + request_message.extend(self.filters.to_json_array()) + print(f"DEBUG: Publishing subscription request") + self.relay_manager.publish_message(json.dumps(request_message)) + print(f"DEBUG: Published subscription request") + for _ in range(4): + if not self.keep_running: + return + print("Waiting a bit before self.fetch_balance()") + await asyncio.sleep(0.5) + + await self.fetch_balance() + + while True: + print(f"checking for incoming events...") + await asyncio.sleep(1) + if not self.keep_running: + print("NWCWallet: not keep_running, closing connections...") + await self.relay_manager.close_connections() + break + + start_time = time.ticks_ms() + if self.relay_manager.message_pool.has_events(): + print(f"DEBUG: Event received from message pool after {time.ticks_ms()-start_time}ms") + event_msg = self.relay_manager.message_pool.get_event() + event_created_at = event_msg.event.created_at + print(f"Received at {time.localtime()} a message with timestamp {event_created_at} after {time.ticks_ms()-start_time}ms") + try: + # This takes a very long time, even for short messages: + decrypted_content = self.private_key.decrypt_message( + event_msg.event.content, + event_msg.event.public_key, + ) + print(f"DEBUG: Decrypted content: {decrypted_content} after {time.ticks_ms()-start_time}ms") + response = json.loads(decrypted_content) + print(f"DEBUG: Parsed response: {response}") + result = response.get("result") + if result: + if result.get("balance") is not None: + new_balance = round(int(result["balance"]) / 1000) + print(f"Got balance: {new_balance}") + self.handle_new_balance(new_balance) + elif result.get("transactions") is not None: + print("Response contains transactions!") + new_payment_list = UniqueSortedList() + for transaction in result["transactions"]: + amount = transaction["amount"] + amount = round(amount / 1000) + comment = self.getCommentFromTransaction(transaction) + epoch_time = transaction["created_at"] + paymentObj = Payment(epoch_time, amount, comment) + new_payment_list.add(paymentObj) + if len(new_payment_list) > 0: + # do them all in one shot instead of one-by-one because the lv_async() isn't always chronological, + # so when a long list of payments is added, it may be overwritten by a short list + self.handle_new_payments(new_payment_list) + else: + notification = response.get("notification") + if notification: + amount = notification["amount"] + amount = round(amount / 1000) + type = notification["type"] + if type == "outgoing": + amount = -amount + elif type == "incoming": + new_balance = self.last_known_balance + amount + self.handle_new_balance(new_balance, False) # don't trigger full fetch because payment info is in notification + epoch_time = notification["created_at"] + comment = self.getCommentFromTransaction(notification) + paymentObj = Payment(epoch_time, amount, comment) + self.handle_new_payment(paymentObj) + else: + print(f"WARNING: invalid notification type {type}, ignoring.") + else: + print("Unsupported response, ignoring.") + except Exception as e: + print(f"DEBUG: Error processing response: {e}") + else: + #print(f"pool has no events after {time.ticks_ms()-start_time}ms") # completes in 0-1ms + pass + + def test_it(self): + print("before do_two") + asyncio.run(self.do_two()) + print("after do_two") + + def do_two(self): + print("before await self.NOmainHERE()") + await self.NOmainHERE() + print("after await self.NOmainHERE()") + diff --git a/tests/manual_test_nwcwallet.py b/tests/manual_test_nwcwallet.py new file mode 100644 index 00000000..6ce2a3b1 --- /dev/null +++ b/tests/manual_test_nwcwallet.py @@ -0,0 +1,49 @@ +import asyncio +import json +import ssl +import _thread +import time +import unittest + +from mpos import App, PackageManager +import mpos.apps + +import sys +sys.path.append("apps/com.lightningpiggy.displaywallet/assets/") +from wallet import NWCWallet + +class TestNWCWallet(unittest.TestCase): + + redraw_balance_cb_called = 0 + redraw_payments_cb_called = 0 + redraw_static_receive_code_cb_called = 0 + error_callback_called = 0 + + def redraw_balance_cb(self, balance=0): + print(f"redraw_callback called, balance: {balance}") + self.redraw_balance_cb_called += 1 + + def redraw_payments_cb(self): + print(f"redraw_payments_cb called") + self.redraw_payments_cb_called += 1 + + def redraw_static_receive_code_cb(self): + print(f"redraw_static_receive_code_cb called") + self.redraw_static_receive_code_cb_called += 1 + + def error_callback(self, error): + print(f"error_callback called, error: {error}") + self.error_callback_called += 1 + + def test_it(self): + print("starting test") + self.wallet = NWCWallet("nostr+walletconnect://e46762afab282c324278351165122345f9983ea447b47943b052100321227571?relay=ws://192.168.1.16:5000/nostrclient/api/v1/relay&secret=fab0a9a11d4cf4b1d92e901a0b2c56634275e2fa1a7eb396ff1b942f95d59fd3&lud16=test@example.com") + self.wallet.start(self.redraw_balance_cb, self.redraw_payments_cb, self.redraw_static_receive_code_cb, self.error_callback) + time.sleep(15) + self.assertTrue(self.redraw_balance_cb_called > 0) + self.assertTrue(self.redraw_payments_cb_called > 0) + self.assertTrue(self.redraw_static_receive_code_cb_called > 0) + self.assertTrue(self.error_callback_called == 0) + print("test finished") + + diff --git a/tests/test_multi_connect.py b/tests/test_multi_connect.py new file mode 100644 index 00000000..1559f7c4 --- /dev/null +++ b/tests/test_multi_connect.py @@ -0,0 +1,255 @@ +import unittest +import _thread +import time + +from mpos import App, PackageManager +import mpos.apps + +from websocket import WebSocketApp + + +# demo_multiple_ws.py +import asyncio +import aiohttp +from aiohttp import WSMsgType +import logging +import sys +from typing import List + + + +# ---------------------------------------------------------------------- +# Logging +# ---------------------------------------------------------------------- +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s [%(levelname)s] %(message)s", + stream=sys.stdout, +) +log = logging.getLogger(__name__) + + +class TestTwoWebsockets(unittest.TestCase): +#class TestTwoWebsockets(): + + # ---------------------------------------------------------------------- + # Configuration + # ---------------------------------------------------------------------- + # Change these to point to a real echo / chat server you control. + WS_URLS = [ + "wss://echo.websocket.org", # public echo service (may be down) + "wss://echo.websocket.org", # duplicate on purpose – shows concurrency + "wss://echo.websocket.org", + # add more URLs here… + ] + + nr_connected = 0 + + # How many messages each connection should send before closing gracefully + MESSAGES_PER_CONNECTION = 2 + STOP_AFTER = 10 + + # ---------------------------------------------------------------------- + # One connection worker + # ---------------------------------------------------------------------- + async def ws_worker(self, session: aiohttp.ClientSession, url: str, idx: int) -> None: + """ + Handles a single WebSocket connection: + * sends a few messages, + * echoes back everything it receives, + * closes when the remote end says "close" or after MESSAGES_PER_CONNECTION. + """ + try: + async with session.ws_connect(url) as ws: + log.info(f"[{idx}] Connected to {url}") + self.nr_connected += 1 + + # ------------------------------------------------------------------ + # 1. Send a few starter messages + # ------------------------------------------------------------------ + for i in range(self.MESSAGES_PER_CONNECTION): + payload = f"Hello from client #{idx} – msg {i+1}" + await ws.send_str(payload) + log.info(f"[{idx}] → {payload}") + + # give the server a moment to reply + await asyncio.sleep(0.5) + + # ------------------------------------------------------------------ + # 2. Echo-loop – react to incoming messages + # ------------------------------------------------------------------ + msgcounter = 0 + async for msg in ws: + msgcounter += 1 + if msgcounter > self.STOP_AFTER: + print("Max reached, stopping...") + await ws.close() + break + if msg.type == WSMsgType.TEXT: + data: str = msg.data + log.info(f"[{idx}] ← {data}") + + # Echo back (with a suffix) + reply = data + " / answer" + await ws.send_str(reply) + log.info(f"[{idx}] → {reply}") + + # Close if server asks us to + if data.strip().lower() == "close cmd": + log.info(f"[{idx}] Server asked to close → closing") + await ws.close() + break + + elif msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSING, WSMsgType.CLOSED): + log.info(f"[{idx}] Connection closed by remote") + break + + elif msg.type == WSMsgType.ERROR: + log.error(f"[{idx}] WebSocket error: {ws.exception()}") + break + + except asyncio.CancelledError: + log.info(f"[{idx}] Task cancelled") + raise + except Exception as exc: + log.exception(f"[{idx}] Unexpected error on {url}: {exc}") + finally: + log.info(f"[{idx}] Worker finished for {url}") + + # ---------------------------------------------------------------------- + # Main entry point – creates a single ClientSession + many tasks + # ---------------------------------------------------------------------- + async def main(self) -> None: + async with aiohttp.ClientSession() as session: + # Create one task per URL (they all run concurrently) + tasks = [ + asyncio.create_task(self.ws_worker(session, url, idx)) + for idx, url in enumerate(self.WS_URLS) + ] + + log.info(f"Starting {len(tasks)} concurrent WebSocket connections…") + # Wait for *all* of them to finish (or be cancelled) + await asyncio.gather(*tasks, return_exceptions=True) + log.info(f"All tasks stopped successfully!") + self.assertTrue(self.nr_connected, len(self.WS_URLS)) + + def newthread(self): + asyncio.run(self.main()) + + def test_it(self): + _thread.stack_size(mpos.apps.good_stack_size()) + _thread.start_new_thread(self.newthread, ()) + time.sleep(10) + +# This demonstrates a crash when doing asyncio using different threads: +#class TestCrashingSeparateThreads(unittest.TestCase): +class TestCrashingSeparateThreads(): + + # ---------------------------------------------------------------------- + # Configuration + # ---------------------------------------------------------------------- + # Change these to point to a real echo / chat server you control. + WS_URLS = [ + "wss://echo.websocket.org", # public echo service (may be down) + "wss://echo.websocket.org", # duplicate on purpose – shows concurrency + "wss://echo.websocket.org", + # add more URLs here… + ] + + # How many messages each connection should send before closing gracefully + MESSAGES_PER_CONNECTION = 2 + STOP_AFTER = 10 + + # ---------------------------------------------------------------------- + # One connection worker + # ---------------------------------------------------------------------- + async def ws_worker(self, session: aiohttp.ClientSession, url: str, idx: int) -> None: + """ + Handles a single WebSocket connection: + * sends a few messages, + * echoes back everything it receives, + * closes when the remote end says "close" or after MESSAGES_PER_CONNECTION. + """ + try: + async with session.ws_connect(url) as ws: + log.info(f"[{idx}] Connected to {url}") + + # ------------------------------------------------------------------ + # 1. Send a few starter messages + # ------------------------------------------------------------------ + for i in range(self.MESSAGES_PER_CONNECTION): + payload = f"Hello from client #{idx} – msg {i+1}" + await ws.send_str(payload) + log.info(f"[{idx}] → {payload}") + + # give the server a moment to reply + await asyncio.sleep(0.5) + + # ------------------------------------------------------------------ + # 2. Echo-loop – react to incoming messages + # ------------------------------------------------------------------ + msgcounter = 0 + async for msg in ws: + msgcounter += 1 + if msgcounter > self.STOP_AFTER: + print("Max reached, stopping...") + await ws.close() + break + if msg.type == WSMsgType.TEXT: + data: str = msg.data + log.info(f"[{idx}] ← {data}") + + # Echo back (with a suffix) + reply = data + " / answer" + await ws.send_str(reply) + log.info(f"[{idx}] → {reply}") + + # Close if server asks us to + if data.strip().lower() == "close cmd": + log.info(f"[{idx}] Server asked to close → closing") + await ws.close() + break + + elif msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSING, WSMsgType.CLOSED): + log.info(f"[{idx}] Connection closed by remote") + break + + elif msg.type == WSMsgType.ERROR: + log.error(f"[{idx}] WebSocket error: {ws.exception()}") + break + + except asyncio.CancelledError: + log.info(f"[{idx}] Task cancelled") + raise + except Exception as exc: + log.exception(f"[{idx}] Unexpected error on {url}: {exc}") + finally: + log.info(f"[{idx}] Worker finished for {url}") + + # ---------------------------------------------------------------------- + # Main entry point – creates a single ClientSession + many tasks + # ---------------------------------------------------------------------- + async def main(self) -> None: + async with aiohttp.ClientSession() as session: + # Create one task per URL (they all run concurrently) + tasks = [ + asyncio.create_task(self.ws_worker(session, url, idx)) + for idx, url in enumerate(self.WS_URLS) + ] + + log.info(f"Starting {len(tasks)} concurrent WebSocket connections…") + # Wait for *all* of them to finish (or be cancelled) + await asyncio.gather(*tasks, return_exceptions=True) + + async def almostmain(self, url): + async with aiohttp.ClientSession() as session: + asyncio.create_task(self.ws_worker(session, url, idx)) + + def newthread(self, url): + asyncio.run(self.main()) + + def test_it(self): + for url in self.WS_URLS: + _thread.stack_size(mpos.apps.good_stack_size()) + _thread.start_new_thread(self.newthread, (url,)) + time.sleep(15)