diff --git a/temporalio/worker/_workflow_instance.py b/temporalio/worker/_workflow_instance.py index 4ef6e8f10..3f55e19eb 100644 --- a/temporalio/worker/_workflow_instance.py +++ b/temporalio/worker/_workflow_instance.py @@ -272,6 +272,8 @@ def __init__(self, det: WorkflowInstanceDetails) -> None: self._object: Any = None self._is_replaying: bool = False self._random = random.Random(det.randomness_seed) + self._current_seed = det.randomness_seed + self._seed_callbacks: list[Callable[[int], None]] = [] self._read_only = False self._in_query_or_validator = False @@ -1075,6 +1077,14 @@ def _apply_update_random_seed( self, job: temporalio.bridge.proto.workflow_activation.UpdateRandomSeed ) -> None: self._random.seed(job.randomness_seed) + self._current_seed = job.randomness_seed + # Notify all registered callbacks + for callback in self._seed_callbacks: + try: + callback(job.randomness_seed) + except Exception: + # Ignore callback errors to avoid disrupting workflow execution + pass def _make_workflow_input( self, init_job: temporalio.bridge.proto.workflow_activation.InitializeWorkflow @@ -1808,6 +1818,14 @@ def workflow_last_failure(self) -> BaseException | None: return None + def workflow_random_seed(self) -> int: + return self._current_seed + + def workflow_register_random_seed_callback( + self, callback: Callable[[int], None] + ) -> None: + self._seed_callbacks.append(callback) + #### Calls from outbound impl #### # These are in alphabetical order and all start with "_outbound_". diff --git a/temporalio/workflow.py b/temporalio/workflow.py index a1b3cb918..fb1753ed8 100644 --- a/temporalio/workflow.py +++ b/temporalio/workflow.py @@ -902,6 +902,14 @@ def workflow_last_completion_result(self, type_hint: type | None) -> Any | None: @abstractmethod def workflow_last_failure(self) -> BaseException | None: ... + @abstractmethod + def workflow_random_seed(self) -> int: ... + + @abstractmethod + def workflow_register_random_seed_callback( + self, callback: Callable[[int], None] + ) -> None: ... + _current_update_info: contextvars.ContextVar[UpdateInfo] = contextvars.ContextVar( "__temporal_current_update_info" @@ -1156,6 +1164,51 @@ def random() -> Random: return _Runtime.current().workflow_random() +def random_seed() -> int: + """Get the current random seed value from core. + + This returns the seed value currently being used by the workflow's + deterministic random number generator. + + Returns: + The current random seed as an integer. + """ + return _Runtime.current().workflow_random_seed() + + +def register_random_seed_callback(callback: Callable[[int], None]) -> None: + """Register a callback to be notified when the random seed changes. + + The callback will be invoked whenever the workflow receives a new random + seed from the core. This is useful for maintaining external random number + generators that need to stay in sync with the workflow's randomness. + + Args: + callback: Function to be called with the new seed value when it changes. + """ + return _Runtime.current().workflow_register_random_seed_callback(callback) + + +def new_random() -> Random: + """Create a Random instance that automatically reseeds when the workflow seed changes. + + This creates a new Random instance that is initially seeded with the current + workflow seed, and automatically registers a callback to reseed itself + whenever the workflow receives a new seed from core. + + Returns: + A Random instance that stays synchronized with the workflow's randomness. + """ + current_seed = random_seed() + auto_random = Random(current_seed) + + def reseed_callback(new_seed: int) -> None: + auto_random.seed(new_seed) + + register_random_seed_callback(reseed_callback) + return auto_random + + def time() -> float: """Current seconds since the epoch from the workflow perspective. diff --git a/tests/worker/test_workflow.py b/tests/worker/test_workflow.py index b597a85ab..e9ba467e6 100644 --- a/tests/worker/test_workflow.py +++ b/tests/worker/test_workflow.py @@ -8485,3 +8485,120 @@ async def test_disable_logger_sandbox( run_timeout=timedelta(seconds=1), retry_policy=RetryPolicy(maximum_attempts=1), ) + + +@workflow.defn +class RandomSeedTestWorkflow: + def __init__(self) -> None: + self.seed_changes: list[int] = [] + self.continue_signal_received = False + self._ready = False + + @workflow.run + async def run(self) -> dict[str, Any]: + # Get the initial seed + initial_seed = workflow.random_seed() + + # Register callback to track seed changes + workflow.register_random_seed_callback(self._on_seed_change) + + # Create a new random instance that auto-reseeds + auto_random = workflow.new_random() + + # Generate random values before waiting + auto_value1 = auto_random.randint(1, 1000000) + + # Do an activity to give a reset point + await workflow.execute_activity( + say_hello, + "Hi", + schedule_to_close_timeout=timedelta(seconds=5), + ) + + self._ready = True + + # Wait for signal to continue - this allows for workflow reset + await workflow.wait_condition(lambda: self.continue_signal_received) + + # Generate more random values after reset might have occurred + auto_value2 = auto_random.randint(1, 1000000) + + # Get final seed + final_seed = workflow.random_seed() + + return { + "initial_seed": initial_seed, + "final_seed": final_seed, + "seed_changes": self.seed_changes.copy(), + "auto_values": [auto_value1, auto_value2], + } + + def _on_seed_change(self, new_seed: int) -> None: + self.seed_changes.append(new_seed) + + @workflow.signal + def continue_workflow(self) -> None: + self.continue_signal_received = True + + @workflow.query + def ready(self) -> bool: + return self._ready + + +async def test_random_seed_functionality( + client: Client, worker: Worker, env: WorkflowEnvironment +): + if env.supports_time_skipping: + pytest.skip("Java test server doesn't support reset") + async with new_worker( + client, RandomSeedTestWorkflow, activities=[say_hello], max_cached_workflows=0 + ) as worker: + workflow_id = f"test-random-seed-{uuid.uuid4()}" + handle = await client.start_workflow( + RandomSeedTestWorkflow.run, + id=workflow_id, + task_queue=worker.task_queue, + ) + + # Let workflow generate some random values + # Wait for workflow to be ready + async def ready() -> bool: + return await handle.query(RandomSeedTestWorkflow.ready) + + await assert_eq_eventually(True, ready) + + # Reset workflow using raw gRPC call to trigger seed change + from temporalio.api.common.v1.message_pb2 import WorkflowExecution + from temporalio.api.enums.v1.reset_pb2 import ResetReapplyType + from temporalio.api.workflowservice.v1 import ResetWorkflowExecutionRequest + + await client.workflow_service.reset_workflow_execution( + ResetWorkflowExecutionRequest( + namespace=client.namespace, + workflow_execution=WorkflowExecution( + workflow_id=handle.id, + run_id="", + ), + reason="Test seed change", + reset_reapply_type=ResetReapplyType.RESET_REAPPLY_TYPE_UNSPECIFIED, + request_id=str(uuid.uuid4()), + workflow_task_finish_event_id=9, # Reset to after activity completion + ) + ) + + # Get handle to the reset workflow using the new run ID + reset_handle = client.get_workflow_handle( + workflow_id, + ) + + # Continue the workflow + await reset_handle.signal(RandomSeedTestWorkflow.continue_workflow) + + result = await reset_handle.result() + + # Verify basic functionality + assert isinstance(result["initial_seed"], int) + assert isinstance(result["final_seed"], int) + assert isinstance(result["seed_changes"], list) + assert len(result["auto_values"]) == 2 + assert len(result["seed_changes"]) == 1