You've already forked MicroPythonOS
mirror of
https://github.com/m5stack/MicroPythonOS.git
synced 2026-05-20 11:51:27 -07:00
Add unit tests
This commit is contained in:
@@ -8,4 +8,5 @@ internal_filesystem/SDLPointer_3
|
||||
|
||||
# config files etc:
|
||||
internal_filesystem/data
|
||||
internal_filesystem/sdcard
|
||||
|
||||
|
||||
@@ -0,0 +1,333 @@
|
||||
import asyncio
|
||||
import json
|
||||
import ssl
|
||||
import _thread
|
||||
import time
|
||||
import unittest
|
||||
|
||||
from mpos import App, PackageManager
|
||||
import mpos.apps
|
||||
|
||||
from nostr.relay_manager import RelayManager
|
||||
from nostr.message_type import ClientMessageType
|
||||
from nostr.filter import Filter, Filters
|
||||
from nostr.event import EncryptedDirectMessage
|
||||
from nostr.key import PrivateKey
|
||||
|
||||
|
||||
# keeps a list of items
|
||||
# The .add() method ensures the list remains unique (via __eq__)
|
||||
# and sorted (via __lt__) by inserting new items in the correct position.
|
||||
class UniqueSortedList:
|
||||
def __init__(self):
|
||||
self._items = []
|
||||
|
||||
def add(self, item):
|
||||
#print(f"before add: {str(self)}")
|
||||
# Check if item already exists (using __eq__)
|
||||
if item not in self._items:
|
||||
# Insert item in sorted position for descending order (using __gt__)
|
||||
for i, existing_item in enumerate(self._items):
|
||||
if item > existing_item:
|
||||
self._items.insert(i, item)
|
||||
return
|
||||
# If item is smaller than all existing items, append it
|
||||
self._items.append(item)
|
||||
#print(f"after add: {str(self)}")
|
||||
|
||||
def __iter__(self):
|
||||
# Return iterator for the internal list
|
||||
return iter(self._items)
|
||||
|
||||
def get(self, index_nr):
|
||||
# Retrieve item at given index, raise IndexError if invalid
|
||||
try:
|
||||
return self._items[index_nr]
|
||||
except IndexError:
|
||||
raise IndexError("Index out of range")
|
||||
|
||||
def __len__(self):
|
||||
# Return the number of items for len() calls
|
||||
return len(self._items)
|
||||
|
||||
def __str__(self):
|
||||
#print("UniqueSortedList tostring called")
|
||||
return "\n".join(str(item) for item in self._items)
|
||||
|
||||
def __eq__(self, other):
|
||||
if len(self._items) != len(other):
|
||||
return False
|
||||
return all(p1 == p2 for p1, p2 in zip(self._items, other))
|
||||
|
||||
# Payment class remains unchanged
|
||||
class Payment:
|
||||
def __init__(self, epoch_time, amount_sats, comment):
|
||||
self.epoch_time = epoch_time
|
||||
self.amount_sats = amount_sats
|
||||
self.comment = comment
|
||||
|
||||
def __str__(self):
|
||||
sattext = "sats"
|
||||
if self.amount_sats == 1:
|
||||
sattext = "sat"
|
||||
#return f"{self.amount_sats} {sattext} @ {self.epoch_time}: {self.comment}"
|
||||
return f"{self.amount_sats} {sattext}: {self.comment}"
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, Payment):
|
||||
return False
|
||||
return self.epoch_time == other.epoch_time and self.amount_sats == other.amount_sats and self.comment == other.comment
|
||||
|
||||
def __lt__(self, other):
|
||||
if not isinstance(other, Payment):
|
||||
return NotImplemented
|
||||
return (self.epoch_time, self.amount_sats, self.comment) < (other.epoch_time, other.amount_sats, other.comment)
|
||||
|
||||
def __le__(self, other):
|
||||
if not isinstance(other, Payment):
|
||||
return NotImplemented
|
||||
return (self.epoch_time, self.amount_sats, self.comment) <= (other.epoch_time, other.amount_sats, other.comment)
|
||||
|
||||
def __gt__(self, other):
|
||||
if not isinstance(other, Payment):
|
||||
return NotImplemented
|
||||
return (self.epoch_time, self.amount_sats, self.comment) > (other.epoch_time, other.amount_sats, other.comment)
|
||||
|
||||
def __ge__(self, other):
|
||||
if not isinstance(other, Payment):
|
||||
return NotImplemented
|
||||
return (self.epoch_time, self.amount_sats, self.comment) >= (other.epoch_time, other.amount_sats, other.comment)
|
||||
|
||||
|
||||
|
||||
class TestNostr(unittest.TestCase):
|
||||
|
||||
PAYMENTS_TO_SHOW = 5
|
||||
|
||||
keep_running = None
|
||||
connected = None
|
||||
balance = -1
|
||||
payment_list = []
|
||||
transactions_welcome = False
|
||||
|
||||
relays = [ "ws://192.168.1.16:5000/nostrrelay/test", "ws://192.168.1.16:5000/nostrclient/api/v1/relay" ]
|
||||
#relays = [ "ws://127.0.0.1:5000/nostrrelay/test", "ws://127.0.0.1:5000/nostrclient/api/v1/relay" ]
|
||||
#relays = [ "wss://relay.damus.io", "wss://nostr-pub.wellorder.net" ]
|
||||
#relays = [ "ws://127.0.0.1:5000/nostrrelay/test", "ws://127.0.0.1:5000/nostrclient/api/v1/relay", "wss://relay.damus.io", "wss://nostr-pub.wellorder.net" ]
|
||||
#relays = [ "ws://127.0.0.1:5000/nostrclient/api/v1/relay", "wss://relay.damus.io", "wss://nostr-pub.wellorder.net" ]
|
||||
secret = "fab0a9a11d4cf4b1d92e901a0b2c56634275e2fa1a7eb396ff1b942f95d59fd3"
|
||||
wallet_pubkey = "e46762afab282c324278351165122345f9983ea447b47943b052100321227571"
|
||||
|
||||
async def fetch_balance(self):
|
||||
if not self.keep_running:
|
||||
return
|
||||
# Create get_balance request
|
||||
balance_request = {
|
||||
"method": "get_balance",
|
||||
"params": {}
|
||||
}
|
||||
print(f"DEBUG: Created balance request: {balance_request}")
|
||||
print(f"DEBUG: Creating encrypted DM to wallet pubkey: {self.wallet_pubkey}")
|
||||
dm = EncryptedDirectMessage(
|
||||
recipient_pubkey=self.wallet_pubkey,
|
||||
cleartext_content=json.dumps(balance_request),
|
||||
kind=23194
|
||||
)
|
||||
print(f"DEBUG: Signing DM {json.dumps(dm)} with private key")
|
||||
self.private_key.sign_event(dm) # sign also does encryption if it's a encrypted dm
|
||||
print(f"DEBUG: Publishing encrypted DM")
|
||||
self.relay_manager.publish_event(dm)
|
||||
|
||||
def handle_new_balance(self, new_balance, fetchPaymentsIfChanged=True):
|
||||
if not self.keep_running or new_balance is None:
|
||||
return
|
||||
if fetchPaymentsIfChanged: # Fetching *all* payments isn't necessary if balance was changed by a payment notification
|
||||
print("Refreshing payments...")
|
||||
self.fetch_payments() # if the balance changed, then re-list transactions
|
||||
|
||||
def fetch_payments(self):
|
||||
if not self.keep_running:
|
||||
return
|
||||
# Create get_balance request
|
||||
list_transactions = {
|
||||
"method": "list_transactions",
|
||||
"params": {
|
||||
"limit": self.PAYMENTS_TO_SHOW
|
||||
}
|
||||
}
|
||||
dm = EncryptedDirectMessage(
|
||||
recipient_pubkey=self.wallet_pubkey,
|
||||
cleartext_content=json.dumps(list_transactions),
|
||||
kind=23194
|
||||
)
|
||||
self.private_key.sign_event(dm) # sign also does encryption if it's a encrypted dm
|
||||
print("\nPublishing DM to fetch payments...")
|
||||
self.relay_manager.publish_event(dm)
|
||||
self.transactions_welcome = True
|
||||
|
||||
def handle_new_payments(self, new_payments):
|
||||
if not self.keep_running or not self.transactions_welcome:
|
||||
return
|
||||
print("handle_new_payments")
|
||||
if self.payment_list != new_payments:
|
||||
print("new list of payments")
|
||||
self.payment_list = new_payments
|
||||
self.payments_updated_cb()
|
||||
|
||||
def payments_updated_cb(self):
|
||||
print("payments_updated_cb called, now closing everything!")
|
||||
self.keep_running = False
|
||||
|
||||
def getCommentFromTransaction(self, transaction):
|
||||
comment = ""
|
||||
try:
|
||||
comment = transaction["description"]
|
||||
json_comment = json.loads(comment)
|
||||
for field in json_comment:
|
||||
if field[0] == "text/plain":
|
||||
comment = field[1]
|
||||
break
|
||||
else:
|
||||
print("text/plain field is missing from JSON description")
|
||||
except Exception as e:
|
||||
print(f"Info: could not parse comment as JSON, this is fine, using as-is ({e})")
|
||||
return comment
|
||||
|
||||
|
||||
async def NOmainHERE(self):
|
||||
self.keep_running = True
|
||||
self.private_key = PrivateKey(bytes.fromhex(self.secret))
|
||||
self.relay_manager = RelayManager()
|
||||
for relay in self.relays:
|
||||
self.relay_manager.add_relay(relay)
|
||||
|
||||
print(f"DEBUG: Opening relay connections")
|
||||
await self.relay_manager.open_connections({"cert_reqs": ssl.CERT_NONE})
|
||||
self.connected = False
|
||||
for _ in range(20):
|
||||
print("Waiting for relay connection...")
|
||||
await asyncio.sleep(0.5)
|
||||
nrconnected = 0
|
||||
for index, relay in enumerate(self.relays):
|
||||
try:
|
||||
relay = self.relay_manager.relays[self.relays[index]]
|
||||
if relay.connected is True:
|
||||
print(f"connected: {self.relays[index]}")
|
||||
nrconnected += 1
|
||||
else:
|
||||
print(f"not connected: {self.relays[index]}")
|
||||
except Exception as e:
|
||||
print(f"could not find relay: {e}")
|
||||
break # not all of them have been initialized, skip...
|
||||
self.connected = ( nrconnected == len(self.relays) )
|
||||
if self.connected:
|
||||
print("All relays connected!")
|
||||
break
|
||||
if not self.connected or not self.keep_running:
|
||||
print(f"ERROR: could not connect to relay or not self.keep_running, aborting...")
|
||||
# TODO: call an error callback to notify the user
|
||||
return
|
||||
|
||||
# Set up subscription to receive response
|
||||
self.subscription_id = "micropython_nwc_" + str(round(time.time()))
|
||||
print(f"DEBUG: Setting up subscription with ID: {self.subscription_id}")
|
||||
self.filters = Filters([Filter(
|
||||
#event_ids=[self.subscription_id], # would be nice to filter, but not like this
|
||||
kinds=[23195, 23196], # NWC reponses and notifications
|
||||
authors=[self.wallet_pubkey],
|
||||
pubkey_refs=[self.private_key.public_key.hex()]
|
||||
)])
|
||||
print(f"DEBUG: Subscription filters: {self.filters.to_json_array()}")
|
||||
self.relay_manager.add_subscription(self.subscription_id, self.filters)
|
||||
print(f"DEBUG: Creating subscription request")
|
||||
request_message = [ClientMessageType.REQUEST, self.subscription_id]
|
||||
request_message.extend(self.filters.to_json_array())
|
||||
print(f"DEBUG: Publishing subscription request")
|
||||
self.relay_manager.publish_message(json.dumps(request_message))
|
||||
print(f"DEBUG: Published subscription request")
|
||||
for _ in range(4):
|
||||
if not self.keep_running:
|
||||
return
|
||||
print("Waiting a bit before self.fetch_balance()")
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
await self.fetch_balance()
|
||||
|
||||
while True:
|
||||
print(f"checking for incoming events...")
|
||||
await asyncio.sleep(1)
|
||||
if not self.keep_running:
|
||||
print("NWCWallet: not keep_running, closing connections...")
|
||||
await self.relay_manager.close_connections()
|
||||
break
|
||||
|
||||
start_time = time.ticks_ms()
|
||||
if self.relay_manager.message_pool.has_events():
|
||||
print(f"DEBUG: Event received from message pool after {time.ticks_ms()-start_time}ms")
|
||||
event_msg = self.relay_manager.message_pool.get_event()
|
||||
event_created_at = event_msg.event.created_at
|
||||
print(f"Received at {time.localtime()} a message with timestamp {event_created_at} after {time.ticks_ms()-start_time}ms")
|
||||
try:
|
||||
# This takes a very long time, even for short messages:
|
||||
decrypted_content = self.private_key.decrypt_message(
|
||||
event_msg.event.content,
|
||||
event_msg.event.public_key,
|
||||
)
|
||||
print(f"DEBUG: Decrypted content: {decrypted_content} after {time.ticks_ms()-start_time}ms")
|
||||
response = json.loads(decrypted_content)
|
||||
print(f"DEBUG: Parsed response: {response}")
|
||||
result = response.get("result")
|
||||
if result:
|
||||
if result.get("balance") is not None:
|
||||
new_balance = round(int(result["balance"]) / 1000)
|
||||
print(f"Got balance: {new_balance}")
|
||||
self.handle_new_balance(new_balance)
|
||||
elif result.get("transactions") is not None:
|
||||
print("Response contains transactions!")
|
||||
new_payment_list = UniqueSortedList()
|
||||
for transaction in result["transactions"]:
|
||||
amount = transaction["amount"]
|
||||
amount = round(amount / 1000)
|
||||
comment = self.getCommentFromTransaction(transaction)
|
||||
epoch_time = transaction["created_at"]
|
||||
paymentObj = Payment(epoch_time, amount, comment)
|
||||
new_payment_list.add(paymentObj)
|
||||
if len(new_payment_list) > 0:
|
||||
# do them all in one shot instead of one-by-one because the lv_async() isn't always chronological,
|
||||
# so when a long list of payments is added, it may be overwritten by a short list
|
||||
self.handle_new_payments(new_payment_list)
|
||||
else:
|
||||
notification = response.get("notification")
|
||||
if notification:
|
||||
amount = notification["amount"]
|
||||
amount = round(amount / 1000)
|
||||
type = notification["type"]
|
||||
if type == "outgoing":
|
||||
amount = -amount
|
||||
elif type == "incoming":
|
||||
new_balance = self.last_known_balance + amount
|
||||
self.handle_new_balance(new_balance, False) # don't trigger full fetch because payment info is in notification
|
||||
epoch_time = notification["created_at"]
|
||||
comment = self.getCommentFromTransaction(notification)
|
||||
paymentObj = Payment(epoch_time, amount, comment)
|
||||
self.handle_new_payment(paymentObj)
|
||||
else:
|
||||
print(f"WARNING: invalid notification type {type}, ignoring.")
|
||||
else:
|
||||
print("Unsupported response, ignoring.")
|
||||
except Exception as e:
|
||||
print(f"DEBUG: Error processing response: {e}")
|
||||
else:
|
||||
#print(f"pool has no events after {time.ticks_ms()-start_time}ms") # completes in 0-1ms
|
||||
pass
|
||||
|
||||
def test_it(self):
|
||||
print("before do_two")
|
||||
asyncio.run(self.do_two())
|
||||
print("after do_two")
|
||||
|
||||
def do_two(self):
|
||||
print("before await self.NOmainHERE()")
|
||||
await self.NOmainHERE()
|
||||
print("after await self.NOmainHERE()")
|
||||
|
||||
@@ -0,0 +1,49 @@
|
||||
import asyncio
|
||||
import json
|
||||
import ssl
|
||||
import _thread
|
||||
import time
|
||||
import unittest
|
||||
|
||||
from mpos import App, PackageManager
|
||||
import mpos.apps
|
||||
|
||||
import sys
|
||||
sys.path.append("apps/com.lightningpiggy.displaywallet/assets/")
|
||||
from wallet import NWCWallet
|
||||
|
||||
class TestNWCWallet(unittest.TestCase):
|
||||
|
||||
redraw_balance_cb_called = 0
|
||||
redraw_payments_cb_called = 0
|
||||
redraw_static_receive_code_cb_called = 0
|
||||
error_callback_called = 0
|
||||
|
||||
def redraw_balance_cb(self, balance=0):
|
||||
print(f"redraw_callback called, balance: {balance}")
|
||||
self.redraw_balance_cb_called += 1
|
||||
|
||||
def redraw_payments_cb(self):
|
||||
print(f"redraw_payments_cb called")
|
||||
self.redraw_payments_cb_called += 1
|
||||
|
||||
def redraw_static_receive_code_cb(self):
|
||||
print(f"redraw_static_receive_code_cb called")
|
||||
self.redraw_static_receive_code_cb_called += 1
|
||||
|
||||
def error_callback(self, error):
|
||||
print(f"error_callback called, error: {error}")
|
||||
self.error_callback_called += 1
|
||||
|
||||
def test_it(self):
|
||||
print("starting test")
|
||||
self.wallet = NWCWallet("nostr+walletconnect://e46762afab282c324278351165122345f9983ea447b47943b052100321227571?relay=ws://192.168.1.16:5000/nostrclient/api/v1/relay&secret=fab0a9a11d4cf4b1d92e901a0b2c56634275e2fa1a7eb396ff1b942f95d59fd3&lud16=test@example.com")
|
||||
self.wallet.start(self.redraw_balance_cb, self.redraw_payments_cb, self.redraw_static_receive_code_cb, self.error_callback)
|
||||
time.sleep(15)
|
||||
self.assertTrue(self.redraw_balance_cb_called > 0)
|
||||
self.assertTrue(self.redraw_payments_cb_called > 0)
|
||||
self.assertTrue(self.redraw_static_receive_code_cb_called > 0)
|
||||
self.assertTrue(self.error_callback_called == 0)
|
||||
print("test finished")
|
||||
|
||||
|
||||
@@ -0,0 +1,255 @@
|
||||
import unittest
|
||||
import _thread
|
||||
import time
|
||||
|
||||
from mpos import App, PackageManager
|
||||
import mpos.apps
|
||||
|
||||
from websocket import WebSocketApp
|
||||
|
||||
|
||||
# demo_multiple_ws.py
|
||||
import asyncio
|
||||
import aiohttp
|
||||
from aiohttp import WSMsgType
|
||||
import logging
|
||||
import sys
|
||||
from typing import List
|
||||
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------
|
||||
# Logging
|
||||
# ----------------------------------------------------------------------
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s [%(levelname)s] %(message)s",
|
||||
stream=sys.stdout,
|
||||
)
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TestTwoWebsockets(unittest.TestCase):
|
||||
#class TestTwoWebsockets():
|
||||
|
||||
# ----------------------------------------------------------------------
|
||||
# Configuration
|
||||
# ----------------------------------------------------------------------
|
||||
# Change these to point to a real echo / chat server you control.
|
||||
WS_URLS = [
|
||||
"wss://echo.websocket.org", # public echo service (may be down)
|
||||
"wss://echo.websocket.org", # duplicate on purpose – shows concurrency
|
||||
"wss://echo.websocket.org",
|
||||
# add more URLs here…
|
||||
]
|
||||
|
||||
nr_connected = 0
|
||||
|
||||
# How many messages each connection should send before closing gracefully
|
||||
MESSAGES_PER_CONNECTION = 2
|
||||
STOP_AFTER = 10
|
||||
|
||||
# ----------------------------------------------------------------------
|
||||
# One connection worker
|
||||
# ----------------------------------------------------------------------
|
||||
async def ws_worker(self, session: aiohttp.ClientSession, url: str, idx: int) -> None:
|
||||
"""
|
||||
Handles a single WebSocket connection:
|
||||
* sends a few messages,
|
||||
* echoes back everything it receives,
|
||||
* closes when the remote end says "close" or after MESSAGES_PER_CONNECTION.
|
||||
"""
|
||||
try:
|
||||
async with session.ws_connect(url) as ws:
|
||||
log.info(f"[{idx}] Connected to {url}")
|
||||
self.nr_connected += 1
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 1. Send a few starter messages
|
||||
# ------------------------------------------------------------------
|
||||
for i in range(self.MESSAGES_PER_CONNECTION):
|
||||
payload = f"Hello from client #{idx} – msg {i+1}"
|
||||
await ws.send_str(payload)
|
||||
log.info(f"[{idx}] → {payload}")
|
||||
|
||||
# give the server a moment to reply
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 2. Echo-loop – react to incoming messages
|
||||
# ------------------------------------------------------------------
|
||||
msgcounter = 0
|
||||
async for msg in ws:
|
||||
msgcounter += 1
|
||||
if msgcounter > self.STOP_AFTER:
|
||||
print("Max reached, stopping...")
|
||||
await ws.close()
|
||||
break
|
||||
if msg.type == WSMsgType.TEXT:
|
||||
data: str = msg.data
|
||||
log.info(f"[{idx}] ← {data}")
|
||||
|
||||
# Echo back (with a suffix)
|
||||
reply = data + " / answer"
|
||||
await ws.send_str(reply)
|
||||
log.info(f"[{idx}] → {reply}")
|
||||
|
||||
# Close if server asks us to
|
||||
if data.strip().lower() == "close cmd":
|
||||
log.info(f"[{idx}] Server asked to close → closing")
|
||||
await ws.close()
|
||||
break
|
||||
|
||||
elif msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSING, WSMsgType.CLOSED):
|
||||
log.info(f"[{idx}] Connection closed by remote")
|
||||
break
|
||||
|
||||
elif msg.type == WSMsgType.ERROR:
|
||||
log.error(f"[{idx}] WebSocket error: {ws.exception()}")
|
||||
break
|
||||
|
||||
except asyncio.CancelledError:
|
||||
log.info(f"[{idx}] Task cancelled")
|
||||
raise
|
||||
except Exception as exc:
|
||||
log.exception(f"[{idx}] Unexpected error on {url}: {exc}")
|
||||
finally:
|
||||
log.info(f"[{idx}] Worker finished for {url}")
|
||||
|
||||
# ----------------------------------------------------------------------
|
||||
# Main entry point – creates a single ClientSession + many tasks
|
||||
# ----------------------------------------------------------------------
|
||||
async def main(self) -> None:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
# Create one task per URL (they all run concurrently)
|
||||
tasks = [
|
||||
asyncio.create_task(self.ws_worker(session, url, idx))
|
||||
for idx, url in enumerate(self.WS_URLS)
|
||||
]
|
||||
|
||||
log.info(f"Starting {len(tasks)} concurrent WebSocket connections…")
|
||||
# Wait for *all* of them to finish (or be cancelled)
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
log.info(f"All tasks stopped successfully!")
|
||||
self.assertTrue(self.nr_connected, len(self.WS_URLS))
|
||||
|
||||
def newthread(self):
|
||||
asyncio.run(self.main())
|
||||
|
||||
def test_it(self):
|
||||
_thread.stack_size(mpos.apps.good_stack_size())
|
||||
_thread.start_new_thread(self.newthread, ())
|
||||
time.sleep(10)
|
||||
|
||||
# This demonstrates a crash when doing asyncio using different threads:
|
||||
#class TestCrashingSeparateThreads(unittest.TestCase):
|
||||
class TestCrashingSeparateThreads():
|
||||
|
||||
# ----------------------------------------------------------------------
|
||||
# Configuration
|
||||
# ----------------------------------------------------------------------
|
||||
# Change these to point to a real echo / chat server you control.
|
||||
WS_URLS = [
|
||||
"wss://echo.websocket.org", # public echo service (may be down)
|
||||
"wss://echo.websocket.org", # duplicate on purpose – shows concurrency
|
||||
"wss://echo.websocket.org",
|
||||
# add more URLs here…
|
||||
]
|
||||
|
||||
# How many messages each connection should send before closing gracefully
|
||||
MESSAGES_PER_CONNECTION = 2
|
||||
STOP_AFTER = 10
|
||||
|
||||
# ----------------------------------------------------------------------
|
||||
# One connection worker
|
||||
# ----------------------------------------------------------------------
|
||||
async def ws_worker(self, session: aiohttp.ClientSession, url: str, idx: int) -> None:
|
||||
"""
|
||||
Handles a single WebSocket connection:
|
||||
* sends a few messages,
|
||||
* echoes back everything it receives,
|
||||
* closes when the remote end says "close" or after MESSAGES_PER_CONNECTION.
|
||||
"""
|
||||
try:
|
||||
async with session.ws_connect(url) as ws:
|
||||
log.info(f"[{idx}] Connected to {url}")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 1. Send a few starter messages
|
||||
# ------------------------------------------------------------------
|
||||
for i in range(self.MESSAGES_PER_CONNECTION):
|
||||
payload = f"Hello from client #{idx} – msg {i+1}"
|
||||
await ws.send_str(payload)
|
||||
log.info(f"[{idx}] → {payload}")
|
||||
|
||||
# give the server a moment to reply
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 2. Echo-loop – react to incoming messages
|
||||
# ------------------------------------------------------------------
|
||||
msgcounter = 0
|
||||
async for msg in ws:
|
||||
msgcounter += 1
|
||||
if msgcounter > self.STOP_AFTER:
|
||||
print("Max reached, stopping...")
|
||||
await ws.close()
|
||||
break
|
||||
if msg.type == WSMsgType.TEXT:
|
||||
data: str = msg.data
|
||||
log.info(f"[{idx}] ← {data}")
|
||||
|
||||
# Echo back (with a suffix)
|
||||
reply = data + " / answer"
|
||||
await ws.send_str(reply)
|
||||
log.info(f"[{idx}] → {reply}")
|
||||
|
||||
# Close if server asks us to
|
||||
if data.strip().lower() == "close cmd":
|
||||
log.info(f"[{idx}] Server asked to close → closing")
|
||||
await ws.close()
|
||||
break
|
||||
|
||||
elif msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSING, WSMsgType.CLOSED):
|
||||
log.info(f"[{idx}] Connection closed by remote")
|
||||
break
|
||||
|
||||
elif msg.type == WSMsgType.ERROR:
|
||||
log.error(f"[{idx}] WebSocket error: {ws.exception()}")
|
||||
break
|
||||
|
||||
except asyncio.CancelledError:
|
||||
log.info(f"[{idx}] Task cancelled")
|
||||
raise
|
||||
except Exception as exc:
|
||||
log.exception(f"[{idx}] Unexpected error on {url}: {exc}")
|
||||
finally:
|
||||
log.info(f"[{idx}] Worker finished for {url}")
|
||||
|
||||
# ----------------------------------------------------------------------
|
||||
# Main entry point – creates a single ClientSession + many tasks
|
||||
# ----------------------------------------------------------------------
|
||||
async def main(self) -> None:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
# Create one task per URL (they all run concurrently)
|
||||
tasks = [
|
||||
asyncio.create_task(self.ws_worker(session, url, idx))
|
||||
for idx, url in enumerate(self.WS_URLS)
|
||||
]
|
||||
|
||||
log.info(f"Starting {len(tasks)} concurrent WebSocket connections…")
|
||||
# Wait for *all* of them to finish (or be cancelled)
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
async def almostmain(self, url):
|
||||
async with aiohttp.ClientSession() as session:
|
||||
asyncio.create_task(self.ws_worker(session, url, idx))
|
||||
|
||||
def newthread(self, url):
|
||||
asyncio.run(self.main())
|
||||
|
||||
def test_it(self):
|
||||
for url in self.WS_URLS:
|
||||
_thread.stack_size(mpos.apps.good_stack_size())
|
||||
_thread.start_new_thread(self.newthread, (url,))
|
||||
time.sleep(15)
|
||||
Reference in New Issue
Block a user