Add unit tests

This commit is contained in:
Thomas Farstrike
2025-11-11 10:51:02 +01:00
parent 4152604bf4
commit 4a22981c9c
4 changed files with 638 additions and 0 deletions
+1
View File
@@ -8,4 +8,5 @@ internal_filesystem/SDLPointer_3
# config files etc:
internal_filesystem/data
internal_filesystem/sdcard
+333
View File
@@ -0,0 +1,333 @@
import asyncio
import json
import ssl
import _thread
import time
import unittest
from mpos import App, PackageManager
import mpos.apps
from nostr.relay_manager import RelayManager
from nostr.message_type import ClientMessageType
from nostr.filter import Filter, Filters
from nostr.event import EncryptedDirectMessage
from nostr.key import PrivateKey
# keeps a list of items
# The .add() method ensures the list remains unique (via __eq__)
# and sorted (via __lt__) by inserting new items in the correct position.
class UniqueSortedList:
def __init__(self):
self._items = []
def add(self, item):
#print(f"before add: {str(self)}")
# Check if item already exists (using __eq__)
if item not in self._items:
# Insert item in sorted position for descending order (using __gt__)
for i, existing_item in enumerate(self._items):
if item > existing_item:
self._items.insert(i, item)
return
# If item is smaller than all existing items, append it
self._items.append(item)
#print(f"after add: {str(self)}")
def __iter__(self):
# Return iterator for the internal list
return iter(self._items)
def get(self, index_nr):
# Retrieve item at given index, raise IndexError if invalid
try:
return self._items[index_nr]
except IndexError:
raise IndexError("Index out of range")
def __len__(self):
# Return the number of items for len() calls
return len(self._items)
def __str__(self):
#print("UniqueSortedList tostring called")
return "\n".join(str(item) for item in self._items)
def __eq__(self, other):
if len(self._items) != len(other):
return False
return all(p1 == p2 for p1, p2 in zip(self._items, other))
# Payment class remains unchanged
class Payment:
def __init__(self, epoch_time, amount_sats, comment):
self.epoch_time = epoch_time
self.amount_sats = amount_sats
self.comment = comment
def __str__(self):
sattext = "sats"
if self.amount_sats == 1:
sattext = "sat"
#return f"{self.amount_sats} {sattext} @ {self.epoch_time}: {self.comment}"
return f"{self.amount_sats} {sattext}: {self.comment}"
def __eq__(self, other):
if not isinstance(other, Payment):
return False
return self.epoch_time == other.epoch_time and self.amount_sats == other.amount_sats and self.comment == other.comment
def __lt__(self, other):
if not isinstance(other, Payment):
return NotImplemented
return (self.epoch_time, self.amount_sats, self.comment) < (other.epoch_time, other.amount_sats, other.comment)
def __le__(self, other):
if not isinstance(other, Payment):
return NotImplemented
return (self.epoch_time, self.amount_sats, self.comment) <= (other.epoch_time, other.amount_sats, other.comment)
def __gt__(self, other):
if not isinstance(other, Payment):
return NotImplemented
return (self.epoch_time, self.amount_sats, self.comment) > (other.epoch_time, other.amount_sats, other.comment)
def __ge__(self, other):
if not isinstance(other, Payment):
return NotImplemented
return (self.epoch_time, self.amount_sats, self.comment) >= (other.epoch_time, other.amount_sats, other.comment)
class TestNostr(unittest.TestCase):
PAYMENTS_TO_SHOW = 5
keep_running = None
connected = None
balance = -1
payment_list = []
transactions_welcome = False
relays = [ "ws://192.168.1.16:5000/nostrrelay/test", "ws://192.168.1.16:5000/nostrclient/api/v1/relay" ]
#relays = [ "ws://127.0.0.1:5000/nostrrelay/test", "ws://127.0.0.1:5000/nostrclient/api/v1/relay" ]
#relays = [ "wss://relay.damus.io", "wss://nostr-pub.wellorder.net" ]
#relays = [ "ws://127.0.0.1:5000/nostrrelay/test", "ws://127.0.0.1:5000/nostrclient/api/v1/relay", "wss://relay.damus.io", "wss://nostr-pub.wellorder.net" ]
#relays = [ "ws://127.0.0.1:5000/nostrclient/api/v1/relay", "wss://relay.damus.io", "wss://nostr-pub.wellorder.net" ]
secret = "fab0a9a11d4cf4b1d92e901a0b2c56634275e2fa1a7eb396ff1b942f95d59fd3"
wallet_pubkey = "e46762afab282c324278351165122345f9983ea447b47943b052100321227571"
async def fetch_balance(self):
if not self.keep_running:
return
# Create get_balance request
balance_request = {
"method": "get_balance",
"params": {}
}
print(f"DEBUG: Created balance request: {balance_request}")
print(f"DEBUG: Creating encrypted DM to wallet pubkey: {self.wallet_pubkey}")
dm = EncryptedDirectMessage(
recipient_pubkey=self.wallet_pubkey,
cleartext_content=json.dumps(balance_request),
kind=23194
)
print(f"DEBUG: Signing DM {json.dumps(dm)} with private key")
self.private_key.sign_event(dm) # sign also does encryption if it's a encrypted dm
print(f"DEBUG: Publishing encrypted DM")
self.relay_manager.publish_event(dm)
def handle_new_balance(self, new_balance, fetchPaymentsIfChanged=True):
if not self.keep_running or new_balance is None:
return
if fetchPaymentsIfChanged: # Fetching *all* payments isn't necessary if balance was changed by a payment notification
print("Refreshing payments...")
self.fetch_payments() # if the balance changed, then re-list transactions
def fetch_payments(self):
if not self.keep_running:
return
# Create get_balance request
list_transactions = {
"method": "list_transactions",
"params": {
"limit": self.PAYMENTS_TO_SHOW
}
}
dm = EncryptedDirectMessage(
recipient_pubkey=self.wallet_pubkey,
cleartext_content=json.dumps(list_transactions),
kind=23194
)
self.private_key.sign_event(dm) # sign also does encryption if it's a encrypted dm
print("\nPublishing DM to fetch payments...")
self.relay_manager.publish_event(dm)
self.transactions_welcome = True
def handle_new_payments(self, new_payments):
if not self.keep_running or not self.transactions_welcome:
return
print("handle_new_payments")
if self.payment_list != new_payments:
print("new list of payments")
self.payment_list = new_payments
self.payments_updated_cb()
def payments_updated_cb(self):
print("payments_updated_cb called, now closing everything!")
self.keep_running = False
def getCommentFromTransaction(self, transaction):
comment = ""
try:
comment = transaction["description"]
json_comment = json.loads(comment)
for field in json_comment:
if field[0] == "text/plain":
comment = field[1]
break
else:
print("text/plain field is missing from JSON description")
except Exception as e:
print(f"Info: could not parse comment as JSON, this is fine, using as-is ({e})")
return comment
async def NOmainHERE(self):
self.keep_running = True
self.private_key = PrivateKey(bytes.fromhex(self.secret))
self.relay_manager = RelayManager()
for relay in self.relays:
self.relay_manager.add_relay(relay)
print(f"DEBUG: Opening relay connections")
await self.relay_manager.open_connections({"cert_reqs": ssl.CERT_NONE})
self.connected = False
for _ in range(20):
print("Waiting for relay connection...")
await asyncio.sleep(0.5)
nrconnected = 0
for index, relay in enumerate(self.relays):
try:
relay = self.relay_manager.relays[self.relays[index]]
if relay.connected is True:
print(f"connected: {self.relays[index]}")
nrconnected += 1
else:
print(f"not connected: {self.relays[index]}")
except Exception as e:
print(f"could not find relay: {e}")
break # not all of them have been initialized, skip...
self.connected = ( nrconnected == len(self.relays) )
if self.connected:
print("All relays connected!")
break
if not self.connected or not self.keep_running:
print(f"ERROR: could not connect to relay or not self.keep_running, aborting...")
# TODO: call an error callback to notify the user
return
# Set up subscription to receive response
self.subscription_id = "micropython_nwc_" + str(round(time.time()))
print(f"DEBUG: Setting up subscription with ID: {self.subscription_id}")
self.filters = Filters([Filter(
#event_ids=[self.subscription_id], # would be nice to filter, but not like this
kinds=[23195, 23196], # NWC reponses and notifications
authors=[self.wallet_pubkey],
pubkey_refs=[self.private_key.public_key.hex()]
)])
print(f"DEBUG: Subscription filters: {self.filters.to_json_array()}")
self.relay_manager.add_subscription(self.subscription_id, self.filters)
print(f"DEBUG: Creating subscription request")
request_message = [ClientMessageType.REQUEST, self.subscription_id]
request_message.extend(self.filters.to_json_array())
print(f"DEBUG: Publishing subscription request")
self.relay_manager.publish_message(json.dumps(request_message))
print(f"DEBUG: Published subscription request")
for _ in range(4):
if not self.keep_running:
return
print("Waiting a bit before self.fetch_balance()")
await asyncio.sleep(0.5)
await self.fetch_balance()
while True:
print(f"checking for incoming events...")
await asyncio.sleep(1)
if not self.keep_running:
print("NWCWallet: not keep_running, closing connections...")
await self.relay_manager.close_connections()
break
start_time = time.ticks_ms()
if self.relay_manager.message_pool.has_events():
print(f"DEBUG: Event received from message pool after {time.ticks_ms()-start_time}ms")
event_msg = self.relay_manager.message_pool.get_event()
event_created_at = event_msg.event.created_at
print(f"Received at {time.localtime()} a message with timestamp {event_created_at} after {time.ticks_ms()-start_time}ms")
try:
# This takes a very long time, even for short messages:
decrypted_content = self.private_key.decrypt_message(
event_msg.event.content,
event_msg.event.public_key,
)
print(f"DEBUG: Decrypted content: {decrypted_content} after {time.ticks_ms()-start_time}ms")
response = json.loads(decrypted_content)
print(f"DEBUG: Parsed response: {response}")
result = response.get("result")
if result:
if result.get("balance") is not None:
new_balance = round(int(result["balance"]) / 1000)
print(f"Got balance: {new_balance}")
self.handle_new_balance(new_balance)
elif result.get("transactions") is not None:
print("Response contains transactions!")
new_payment_list = UniqueSortedList()
for transaction in result["transactions"]:
amount = transaction["amount"]
amount = round(amount / 1000)
comment = self.getCommentFromTransaction(transaction)
epoch_time = transaction["created_at"]
paymentObj = Payment(epoch_time, amount, comment)
new_payment_list.add(paymentObj)
if len(new_payment_list) > 0:
# do them all in one shot instead of one-by-one because the lv_async() isn't always chronological,
# so when a long list of payments is added, it may be overwritten by a short list
self.handle_new_payments(new_payment_list)
else:
notification = response.get("notification")
if notification:
amount = notification["amount"]
amount = round(amount / 1000)
type = notification["type"]
if type == "outgoing":
amount = -amount
elif type == "incoming":
new_balance = self.last_known_balance + amount
self.handle_new_balance(new_balance, False) # don't trigger full fetch because payment info is in notification
epoch_time = notification["created_at"]
comment = self.getCommentFromTransaction(notification)
paymentObj = Payment(epoch_time, amount, comment)
self.handle_new_payment(paymentObj)
else:
print(f"WARNING: invalid notification type {type}, ignoring.")
else:
print("Unsupported response, ignoring.")
except Exception as e:
print(f"DEBUG: Error processing response: {e}")
else:
#print(f"pool has no events after {time.ticks_ms()-start_time}ms") # completes in 0-1ms
pass
def test_it(self):
print("before do_two")
asyncio.run(self.do_two())
print("after do_two")
def do_two(self):
print("before await self.NOmainHERE()")
await self.NOmainHERE()
print("after await self.NOmainHERE()")
+49
View File
@@ -0,0 +1,49 @@
import asyncio
import json
import ssl
import _thread
import time
import unittest
from mpos import App, PackageManager
import mpos.apps
import sys
sys.path.append("apps/com.lightningpiggy.displaywallet/assets/")
from wallet import NWCWallet
class TestNWCWallet(unittest.TestCase):
redraw_balance_cb_called = 0
redraw_payments_cb_called = 0
redraw_static_receive_code_cb_called = 0
error_callback_called = 0
def redraw_balance_cb(self, balance=0):
print(f"redraw_callback called, balance: {balance}")
self.redraw_balance_cb_called += 1
def redraw_payments_cb(self):
print(f"redraw_payments_cb called")
self.redraw_payments_cb_called += 1
def redraw_static_receive_code_cb(self):
print(f"redraw_static_receive_code_cb called")
self.redraw_static_receive_code_cb_called += 1
def error_callback(self, error):
print(f"error_callback called, error: {error}")
self.error_callback_called += 1
def test_it(self):
print("starting test")
self.wallet = NWCWallet("nostr+walletconnect://e46762afab282c324278351165122345f9983ea447b47943b052100321227571?relay=ws://192.168.1.16:5000/nostrclient/api/v1/relay&secret=fab0a9a11d4cf4b1d92e901a0b2c56634275e2fa1a7eb396ff1b942f95d59fd3&lud16=test@example.com")
self.wallet.start(self.redraw_balance_cb, self.redraw_payments_cb, self.redraw_static_receive_code_cb, self.error_callback)
time.sleep(15)
self.assertTrue(self.redraw_balance_cb_called > 0)
self.assertTrue(self.redraw_payments_cb_called > 0)
self.assertTrue(self.redraw_static_receive_code_cb_called > 0)
self.assertTrue(self.error_callback_called == 0)
print("test finished")
+255
View File
@@ -0,0 +1,255 @@
import unittest
import _thread
import time
from mpos import App, PackageManager
import mpos.apps
from websocket import WebSocketApp
# demo_multiple_ws.py
import asyncio
import aiohttp
from aiohttp import WSMsgType
import logging
import sys
from typing import List
# ----------------------------------------------------------------------
# Logging
# ----------------------------------------------------------------------
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s [%(levelname)s] %(message)s",
stream=sys.stdout,
)
log = logging.getLogger(__name__)
class TestTwoWebsockets(unittest.TestCase):
#class TestTwoWebsockets():
# ----------------------------------------------------------------------
# Configuration
# ----------------------------------------------------------------------
# Change these to point to a real echo / chat server you control.
WS_URLS = [
"wss://echo.websocket.org", # public echo service (may be down)
"wss://echo.websocket.org", # duplicate on purpose shows concurrency
"wss://echo.websocket.org",
# add more URLs here…
]
nr_connected = 0
# How many messages each connection should send before closing gracefully
MESSAGES_PER_CONNECTION = 2
STOP_AFTER = 10
# ----------------------------------------------------------------------
# One connection worker
# ----------------------------------------------------------------------
async def ws_worker(self, session: aiohttp.ClientSession, url: str, idx: int) -> None:
"""
Handles a single WebSocket connection:
* sends a few messages,
* echoes back everything it receives,
* closes when the remote end says "close" or after MESSAGES_PER_CONNECTION.
"""
try:
async with session.ws_connect(url) as ws:
log.info(f"[{idx}] Connected to {url}")
self.nr_connected += 1
# ------------------------------------------------------------------
# 1. Send a few starter messages
# ------------------------------------------------------------------
for i in range(self.MESSAGES_PER_CONNECTION):
payload = f"Hello from client #{idx} msg {i+1}"
await ws.send_str(payload)
log.info(f"[{idx}] → {payload}")
# give the server a moment to reply
await asyncio.sleep(0.5)
# ------------------------------------------------------------------
# 2. Echo-loop react to incoming messages
# ------------------------------------------------------------------
msgcounter = 0
async for msg in ws:
msgcounter += 1
if msgcounter > self.STOP_AFTER:
print("Max reached, stopping...")
await ws.close()
break
if msg.type == WSMsgType.TEXT:
data: str = msg.data
log.info(f"[{idx}] ← {data}")
# Echo back (with a suffix)
reply = data + " / answer"
await ws.send_str(reply)
log.info(f"[{idx}] → {reply}")
# Close if server asks us to
if data.strip().lower() == "close cmd":
log.info(f"[{idx}] Server asked to close → closing")
await ws.close()
break
elif msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSING, WSMsgType.CLOSED):
log.info(f"[{idx}] Connection closed by remote")
break
elif msg.type == WSMsgType.ERROR:
log.error(f"[{idx}] WebSocket error: {ws.exception()}")
break
except asyncio.CancelledError:
log.info(f"[{idx}] Task cancelled")
raise
except Exception as exc:
log.exception(f"[{idx}] Unexpected error on {url}: {exc}")
finally:
log.info(f"[{idx}] Worker finished for {url}")
# ----------------------------------------------------------------------
# Main entry point creates a single ClientSession + many tasks
# ----------------------------------------------------------------------
async def main(self) -> None:
async with aiohttp.ClientSession() as session:
# Create one task per URL (they all run concurrently)
tasks = [
asyncio.create_task(self.ws_worker(session, url, idx))
for idx, url in enumerate(self.WS_URLS)
]
log.info(f"Starting {len(tasks)} concurrent WebSocket connections…")
# Wait for *all* of them to finish (or be cancelled)
await asyncio.gather(*tasks, return_exceptions=True)
log.info(f"All tasks stopped successfully!")
self.assertTrue(self.nr_connected, len(self.WS_URLS))
def newthread(self):
asyncio.run(self.main())
def test_it(self):
_thread.stack_size(mpos.apps.good_stack_size())
_thread.start_new_thread(self.newthread, ())
time.sleep(10)
# This demonstrates a crash when doing asyncio using different threads:
#class TestCrashingSeparateThreads(unittest.TestCase):
class TestCrashingSeparateThreads():
# ----------------------------------------------------------------------
# Configuration
# ----------------------------------------------------------------------
# Change these to point to a real echo / chat server you control.
WS_URLS = [
"wss://echo.websocket.org", # public echo service (may be down)
"wss://echo.websocket.org", # duplicate on purpose shows concurrency
"wss://echo.websocket.org",
# add more URLs here…
]
# How many messages each connection should send before closing gracefully
MESSAGES_PER_CONNECTION = 2
STOP_AFTER = 10
# ----------------------------------------------------------------------
# One connection worker
# ----------------------------------------------------------------------
async def ws_worker(self, session: aiohttp.ClientSession, url: str, idx: int) -> None:
"""
Handles a single WebSocket connection:
* sends a few messages,
* echoes back everything it receives,
* closes when the remote end says "close" or after MESSAGES_PER_CONNECTION.
"""
try:
async with session.ws_connect(url) as ws:
log.info(f"[{idx}] Connected to {url}")
# ------------------------------------------------------------------
# 1. Send a few starter messages
# ------------------------------------------------------------------
for i in range(self.MESSAGES_PER_CONNECTION):
payload = f"Hello from client #{idx} msg {i+1}"
await ws.send_str(payload)
log.info(f"[{idx}] → {payload}")
# give the server a moment to reply
await asyncio.sleep(0.5)
# ------------------------------------------------------------------
# 2. Echo-loop react to incoming messages
# ------------------------------------------------------------------
msgcounter = 0
async for msg in ws:
msgcounter += 1
if msgcounter > self.STOP_AFTER:
print("Max reached, stopping...")
await ws.close()
break
if msg.type == WSMsgType.TEXT:
data: str = msg.data
log.info(f"[{idx}] ← {data}")
# Echo back (with a suffix)
reply = data + " / answer"
await ws.send_str(reply)
log.info(f"[{idx}] → {reply}")
# Close if server asks us to
if data.strip().lower() == "close cmd":
log.info(f"[{idx}] Server asked to close → closing")
await ws.close()
break
elif msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSING, WSMsgType.CLOSED):
log.info(f"[{idx}] Connection closed by remote")
break
elif msg.type == WSMsgType.ERROR:
log.error(f"[{idx}] WebSocket error: {ws.exception()}")
break
except asyncio.CancelledError:
log.info(f"[{idx}] Task cancelled")
raise
except Exception as exc:
log.exception(f"[{idx}] Unexpected error on {url}: {exc}")
finally:
log.info(f"[{idx}] Worker finished for {url}")
# ----------------------------------------------------------------------
# Main entry point creates a single ClientSession + many tasks
# ----------------------------------------------------------------------
async def main(self) -> None:
async with aiohttp.ClientSession() as session:
# Create one task per URL (they all run concurrently)
tasks = [
asyncio.create_task(self.ws_worker(session, url, idx))
for idx, url in enumerate(self.WS_URLS)
]
log.info(f"Starting {len(tasks)} concurrent WebSocket connections…")
# Wait for *all* of them to finish (or be cancelled)
await asyncio.gather(*tasks, return_exceptions=True)
async def almostmain(self, url):
async with aiohttp.ClientSession() as session:
asyncio.create_task(self.ws_worker(session, url, idx))
def newthread(self, url):
asyncio.run(self.main())
def test_it(self):
for url in self.WS_URLS:
_thread.stack_size(mpos.apps.good_stack_size())
_thread.start_new_thread(self.newthread, (url,))
time.sleep(15)