# Source code for kuhl_haus.mdp.components.widget_data_service

"""WebSocket-to-Redis bridge for real-time market data delivery to clients.

Manages client WebSocket subscriptions to Redis pub/sub channels, fans out Redis
messages to connected clients, and provides a snapshot API for cached data.
Designed for scenarios where multiple browser/dashboard clients subscribe to the
same market data feeds with minimal latency.
"""
import asyncio
import json
import logging
from typing import Dict, Set, Any

import redis.asyncio as redis
from fastapi import WebSocket

from kuhl_haus.mdp.helpers.observability import get_tracer, get_meter

# Tracer from the project's observability helpers (presumably OpenTelemetry —
# TODO confirm). WidgetDataService methods use its start_as_current_span(...)
# as a decorator so that each call is wrapped in a named span.
tracer = get_tracer(__name__)


class UnauthorizedException(Exception):
    """Raised when a client attempts an operation it is not authorized to perform."""
class WidgetDataService:
    """Fan out Redis pub/sub messages to WebSocket clients.

    Maintains a registry of WebSocket connections per Redis channel or pattern
    and subscribes to channels on demand as clients join. Each incoming pub/sub
    message is fanned out to every client registered under the matching
    channel (for ``message`` events) or pattern (for ``pmessage`` events).

    The background pub/sub task starts lazily on subscription. It is
    (re)started whenever no live task exists, and it stops when the last
    client unsubscribes. Feeds containing ``*`` are registered with
    PSUBSCRIBE.

    Clients can fetch snapshots from the Redis cache via ``get_cache`` for
    initial state before streaming updates.

    Concurrency:
        A single background task (``_handle_pubsub``) polls Redis and sends to
        all WebSockets on the event loop. ``_pubsub_lock`` serializes mutations
        of ``subscriptions`` together with the matching Redis (un)subscribe
        calls.
    """

    def __init__(self, redis_client: redis.Redis, pubsub_client: redis.client.PubSub):
        """Store the Redis clients and create the metric instruments.

        Args:
            redis_client: Client used for PING and cache reads.
            pubsub_client: PubSub object used to (p)subscribe and read messages.
        """
        self.redis_client: redis.Redis = redis_client
        self.pubsub_client: redis.client.PubSub = pubsub_client
        self.logger = logging.getLogger(__name__)

        # Track active WebSocket connections per feed (channel name or pattern).
        self.subscriptions: Dict[str, Set[WebSocket]] = {}
        self._pubsub_task: asyncio.Task | None = None
        self._pubsub_lock = asyncio.Lock()

        # Metrics
        meter = get_meter(__name__)
        # One unit per (feed, client) registration; subscribe/unsubscribe
        # adjust it symmetrically.
        self.subscription_counter = meter.create_up_down_counter(
            name="wds.subscriptions",
            description="Number of active subscriptions",
            unit="1",
        )
        self.cache_hit_counter = meter.create_counter(
            name="wds.cache_hit",
            description="Number of times get_cache returns a result",
            unit="1",
        )
        self.cache_miss_counter = meter.create_counter(
            name="wds.cache_miss",
            description="Number of times get_cache returns nothing",
            unit="1",
        )
        self.message_received_counter = meter.create_counter(
            name="wds.messages_received",
            description="Number of messages received from Redis pub/sub",
            unit="1",
        )
        self.message_sent_counter = meter.create_counter(
            name="wds.messages_sent",
            description="Number of messages sent to WebSocket clients",
            unit="1",
        )
        self.mdc_connected = False

    @tracer.start_as_current_span("wds.start")
    async def start(self):
        """Verify Redis connectivity.

        The pub/sub task starts lazily on the first client subscription. This
        method only PINGs Redis and, on success, sets ``mdc_connected``.
        """
        self.logger.info("wds.starting")
        await self.redis_client.ping()
        self.mdc_connected = True
        self.logger.info("wds.started")

    @tracer.start_as_current_span("wds.stop")
    async def stop(self):
        """Cancel the pub/sub task if it is running.

        Does not close the Redis clients (the caller owns that lifecycle) and
        leaves ``subscriptions`` untouched.
        """
        self.logger.info("wds.stopping")
        if self._pubsub_task is not None:
            task, self._pubsub_task = self._pubsub_task, None
            await self._cancel_pubsub_task(task)
        self.logger.info("wds.stopped")

    @tracer.start_as_current_span("wds.subscribe_feed")
    async def subscribe(self, feed: str, websocket: WebSocket):
        """Add a WebSocket client to a Redis channel/pattern subscription.

        Creates the Redis subscription if this is the first client for the
        feed. Starts the background pub/sub task if none is running: on the
        first subscription, after ``stop()``, or after the previous task died.
        Supports wildcard patterns (e.g., ``leaderboard:*``).

        Args:
            feed: Redis channel name or pattern.
            websocket: FastAPI WebSocket connection.

        Raises:
            Whatever the PubSub ``subscribe``/``psubscribe`` call raises. In
            that case nothing is registered for the feed.

        Side effects:
            Subscribes to the Redis channel; may spawn the background task.
        """
        async with self._pubsub_lock:
            clients = self.subscriptions.get(feed)
            if clients is None:
                # Talk to Redis *before* registering the feed so that a failed
                # (P)SUBSCRIBE cannot leave an empty entry behind. Such an
                # entry would make later subscribers skip the Redis call.
                if "*" in feed:
                    await self.pubsub_client.psubscribe(feed)
                else:
                    await self.pubsub_client.subscribe(feed)
                clients = self.subscriptions[feed] = set()
                self.logger.debug(f"wds.feed.subscribed feed:{feed}, total_feeds:{len(self.subscriptions)}")

            if websocket not in clients:
                clients.add(websocket)
                self.subscription_counter.add(1)

            # Ensure a live reader exists; get_message() needs an active
            # subscription, which is guaranteed at this point.
            if self._pubsub_task is None or self._pubsub_task.done():
                self._pubsub_task = asyncio.create_task(self._handle_pubsub())
                self.logger.debug("wds.pubsub.task_started")

            self.logger.debug(f"wds.client.subscribed feed:{feed}, clients:{len(clients)}")

    @tracer.start_as_current_span("wds.unsubscribe_feed")
    async def unsubscribe(self, feed: str, websocket: WebSocket):
        """Remove a WebSocket client from a Redis channel/pattern subscription.

        Unsubscribes from Redis if this was the last client for the feed.
        Stops the background pub/sub task if no feeds remain. This is a no-op
        when the client is not subscribed to ``feed``.

        Args:
            feed: Redis channel name or pattern.
            websocket: FastAPI WebSocket connection.

        Side effects:
            Unsubscribes from the Redis channel; cancels the background task
            if idle.
        """
        async with self._pubsub_lock:
            clients = self.subscriptions.get(feed)
            if clients is None or websocket not in clients:
                return

            clients.discard(websocket)
            self.subscription_counter.add(-1)

            if clients:
                self.logger.debug(f"wds.client.unsubscribed feed:{feed}, clients:{len(clients)}")
            else:
                # Drop the registry entry first so that a failing UNSUBSCRIBE
                # cannot strand an empty feed that no client can ever clear.
                del self.subscriptions[feed]
                if "*" in feed:
                    await self.pubsub_client.punsubscribe(feed)
                else:
                    await self.pubsub_client.unsubscribe(feed)
                self.logger.debug(f"wds.feed.unsubscribed feed:{feed}, total_feeds:{len(self.subscriptions)}")

            # Last subscription removed: stop pub/sub task
            if not self.subscriptions and self._pubsub_task is not None:
                task, self._pubsub_task = self._pubsub_task, None
                await self._cancel_pubsub_task(task)
                self.logger.debug("wds.pubsub.task_stopped")

    async def _cancel_pubsub_task(self, task: asyncio.Task) -> None:
        """Cancel the pub/sub task and wait for it, unless it is the caller.

        ``_handle_pubsub`` calls ``unsubscribe`` for dead clients, which can
        land here with ``task`` being the currently running task. A task
        cannot await itself. In that case we only request cancellation, and
        the CancelledError is raised at the task's next suspension point, so
        the reader loop actually exits.

        A task that already died with an error is not re-raised here. The
        handler logged that failure when it happened.
        """
        task.cancel()
        if task is asyncio.current_task():
            return
        try:
            await task
        except asyncio.CancelledError:
            pass
        except Exception as e:
            self.logger.warning(f"wds.pubsub.task_failed error:{repr(e)}")

    @tracer.start_as_current_span("wds.disconnect")
    async def disconnect(self, websocket: WebSocket):
        """Unsubscribe a client from every feed it is subscribed to.

        Called when the WebSocket connection closes or the client disconnects.
        The affected feeds are collected under the lock; ``unsubscribe``
        (which takes the lock itself) is then called for each one.
        """
        async with self._pubsub_lock:
            feeds = [feed for feed, clients in self.subscriptions.items() if websocket in clients]
        for feed in feeds:
            self.logger.debug(f"wds.client.disconnecting feed:{feed}")
            await self.unsubscribe(feed, websocket)

    @tracer.start_as_current_span("wds.get_cache")
    async def get_cache(self, cache_key: str, limit: int = 0) -> list[Any] | None | Any:
        """Retrieve cached market data for an initial client snapshot.

        Clients typically call this before subscribing to a feed to get the
        current state, then receive incremental updates via WebSocket.

        Supports Redis string keys (returns the JSON-decoded value) and list
        keys (returns a list of JSON-decoded items). List keys are used for
        rolling news feed caches (news:feed:latest, news:ticker:*).

        Args:
            cache_key: Redis key to fetch.
            limit: Maximum number of items to return from list keys.
                0 (or any non-positive value) means fetch all.

        Returns:
            For list keys, the decoded items, or ``[]`` if the range is empty.
            For string keys, the decoded value, or None if it is empty.
            None for missing keys and any other Redis type.
        """
        self.logger.debug(f"wds.cache.get cache_key:{cache_key}")
        key_type = await self.redis_client.type(cache_key)
        # TYPE may come back as str or bytes depending on how the client
        # decodes responses; normalize to str ("list", "string", "none", ...).
        key_type_str = key_type.decode() if isinstance(key_type, bytes) else key_type

        if key_type_str == "list":
            end = (limit - 1) if limit > 0 else -1
            values = await self.redis_client.lrange(cache_key, 0, end)
            if values:
                self.logger.debug(f"wds.cache.hit cache_key:{cache_key} type:list len:{len(values)}")
                self.cache_hit_counter.add(1)
                return [json.loads(v) for v in values]
            self.logger.debug(f"wds.cache.miss cache_key:{cache_key} type:list")
            self.cache_miss_counter.add(1)
            return []

        if key_type_str == "string":
            value = await self.redis_client.get(cache_key)
            if value:
                self.logger.debug(f"wds.cache.hit cache_key:{cache_key} type:string")
                self.cache_hit_counter.add(1)
                return json.loads(value)

        self.logger.debug(f"wds.cache.miss cache_key:{cache_key}")
        self.cache_miss_counter.add(1)
        return None

    @tracer.start_as_current_span("wds._handle_pubsub")
    async def _handle_pubsub(self):
        """Background task that polls Redis pub/sub and fans out to WebSocket clients.

        Runs until cancelled. It fetches messages from the pub/sub client,
        logs subscription lifecycle events, and sends data messages to every
        WebSocket registered under the source feed. For ``pmessage`` events
        the feed is the matched pattern, since pattern subscribers are
        registered under the pattern rather than the concrete channel.

        Clients whose send fails are unsubscribed from that feed. If that
        removes the last subscription, ``unsubscribe`` cancels this task, and
        the cancellation surfaces at the next await.

        Side effects:
            Calls WebSocket.send_text (network I/O); calls unsubscribe for
            dead connections.
        """
        try:
            self.logger.info("wds.pubsub.starting")
            message_count = 0
            while True:
                # get_message() requires active subscriptions
                message = await self.pubsub_client.get_message(
                    ignore_subscribe_messages=False,
                    timeout=1.0,
                )

                if message is None:
                    # Timeout reached, no message available
                    await asyncio.sleep(0.01)
                    continue

                msg_type = message.get("type")

                # Log subscription lifecycle events
                if msg_type in ("subscribe", "psubscribe"):
                    self.logger.debug(f"wds.pubsub.subscribed channel:{message['channel']}, num_subs:{message['data']}")
                elif msg_type in ("unsubscribe", "punsubscribe"):
                    self.logger.debug(f"wds.pubsub.unsubscribed channel:{message['channel']}, num_subs:{message['data']}")

                # Process actual data messages
                elif msg_type in ("message", "pmessage"):
                    message_count += 1
                    self.message_received_counter.add(1)

                    if msg_type == "pmessage":
                        # Clients subscribed via PSUBSCRIBE are keyed by pattern.
                        feed = message.get("pattern") or message["channel"]
                    else:
                        feed = message["channel"]
                    data = message["data"]
                    self.logger.debug(f"wds.pubsub.message feed:{feed}, data_len:{len(data)}, msg_num:{message_count}")

                    if feed in self.subscriptions:
                        # Fan out to all WebSocket clients subscribed to this feed
                        disconnected = []
                        sent_count = 0
                        # Snapshot prevents RuntimeError on concurrent disconnect.
                        for ws in list(self.subscriptions[feed]):
                            try:
                                await ws.send_text(data)
                                sent_count += 1
                                self.message_sent_counter.add(1)
                            except Exception as e:
                                self.logger.error(f"wds.send.failed feed:{feed}, error:{repr(e)}")
                                disconnected.append(ws)

                        self.logger.debug(f"wds.fanout.complete feed:{feed}, sent:{sent_count}, failed:{len(disconnected)}")

                        # Clean up disconnected clients
                        for ws in disconnected:
                            await self.unsubscribe(feed, ws)
                    else:
                        self.logger.warning(f"wds.pubsub.orphan feed:{feed}, msg:Received message for untracked feed")

        except asyncio.CancelledError:
            self.logger.info("wds.pubsub.cancelled")
            raise
        except Exception as e:
            self.logger.error(f"wds.pubsub.error error:{repr(e)}", exc_info=True)
            raise