import threading
class TaskJournal:
    """Thread-safe registry tracking tasks through their asynchronous lifecycle.

    Each registered task owns a ``threading.Event`` that a waiting thread can
    block on, a result slot that starts as ``None``, and the optional id of
    the node it was dispatched to. A task passes through these states:

    * pending   -- ``register`` created the entry; its event is unset.
    * fulfilled -- ``fulfill`` stored a result and set the event.
    * removed   -- ``pop`` dropped the entry from the journal.

    Every read and write of the internal table happens under ``self.lock``.
    """

    def __init__(self):
        self.lock = threading.Lock()
        # task_id -> {"event": threading.Event, "result": stored result
        # (None until fulfilled), "node_id": node id passed to register}
        self.tasks = {}

    def register(self, task_id, node_id=None) -> threading.Event:
        """Create state for a new task and return its notification event.

        Args:
            task_id: Key identifying the task.
            node_id: Optional identifier of the node handling the task.

        Returns:
            A fresh, unset ``threading.Event`` that ``fulfill`` will set.

        Note:
            Registering an id that is already present replaces the old entry.
            An event returned by the earlier call is then no longer reachable
            through the journal, and ``fulfill`` will never set it.
        """
        event = threading.Event()
        entry = {"event": event, "result": None, "node_id": node_id}
        with self.lock:
            self.tasks[task_id] = entry
        return event

    def fulfill(self, task_id, result) -> bool:
        """Store a task's result and wake whichever thread waits on its event.

        Args:
            task_id: Key identifying the task.
            result: Value to record for the task.

        Returns:
            True if the task was registered and is now fulfilled, otherwise
            False (the result is discarded in that case).
        """
        with self.lock:
            entry = self.tasks.get(task_id)
            if entry is None:
                # Never registered, or already popped: nothing to notify.
                return False
            # Store the result before setting the event so a woken waiter
            # reading via get_result() observes it.
            entry["result"] = result
            entry["event"].set()
            return True

    def get_result(self, task_id, default=None):
        """Return the result recorded for *task_id*.

        Args:
            task_id: Key identifying the task.
            default: Value returned when *task_id* is not in the journal.
                Supplying a sentinel lets callers distinguish an unknown task
                from a pending one, since both otherwise yield ``None``.

        Returns:
            The stored result (``None`` while the task is still pending), or
            *default* if the task is not registered.
        """
        with self.lock:
            entry = self.tasks.get(task_id)
            if entry is None:
                return default
            return entry["result"]

    def pop(self, task_id):
        """Remove the task's state from the journal.

        Returns:
            The removed state dict (keys ``"event"``, ``"result"``,
            ``"node_id"``), or ``None`` if *task_id* was not present.
        """
        with self.lock:
            return self.tasks.pop(task_id, None)