Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions temporalio/worker/_workflow_instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,8 @@ def __init__(self, det: WorkflowInstanceDetails) -> None:
self._object: Any = None
self._is_replaying: bool = False
self._random = random.Random(det.randomness_seed)
self._current_seed = det.randomness_seed
self._seed_callbacks: list[Callable[[int], None]] = []
self._read_only = False
self._in_query_or_validator = False

Expand Down Expand Up @@ -1075,6 +1077,14 @@ def _apply_update_random_seed(
self, job: temporalio.bridge.proto.workflow_activation.UpdateRandomSeed
) -> None:
self._random.seed(job.randomness_seed)
self._current_seed = job.randomness_seed
# Notify all registered callbacks
for callback in self._seed_callbacks:
try:
callback(job.randomness_seed)
except Exception:
# Ignore callback errors to avoid disrupting workflow execution
pass

def _make_workflow_input(
self, init_job: temporalio.bridge.proto.workflow_activation.InitializeWorkflow
Expand Down Expand Up @@ -1808,6 +1818,14 @@ def workflow_last_failure(self) -> BaseException | None:

return None

def workflow_random_seed(self) -> int:
return self._current_seed

def workflow_register_random_seed_callback(
self, callback: Callable[[int], None]
) -> None:
self._seed_callbacks.append(callback)

#### Calls from outbound impl ####
# These are in alphabetical order and all start with "_outbound_".

Expand Down
53 changes: 53 additions & 0 deletions temporalio/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -902,6 +902,14 @@ def workflow_last_completion_result(self, type_hint: type | None) -> Any | None:
@abstractmethod
def workflow_last_failure(self) -> BaseException | None: ...

@abstractmethod
def workflow_random_seed(self) -> int: ...

@abstractmethod
def workflow_register_random_seed_callback(
self, callback: Callable[[int], None]
) -> None: ...


_current_update_info: contextvars.ContextVar[UpdateInfo] = contextvars.ContextVar(
"__temporal_current_update_info"
Expand Down Expand Up @@ -1156,6 +1164,51 @@ def random() -> Random:
return _Runtime.current().workflow_random()


def random_seed() -> int:
"""Get the current random seed value from core.

This returns the seed value currently being used by the workflow's
deterministic random number generator.

Returns:
The current random seed as an integer.
"""
return _Runtime.current().workflow_random_seed()


def register_random_seed_callback(callback: Callable[[int], None]) -> None:
"""Register a callback to be notified when the random seed changes.

The callback will be invoked whenever the workflow receives a new random
seed from the core. This is useful for maintaining external random number
generators that need to stay in sync with the workflow's randomness.

Args:
callback: Function to be called with the new seed value when it changes.
"""
return _Runtime.current().workflow_register_random_seed_callback(callback)


def new_random() -> Random:
"""Create a Random instance that automatically reseeds when the workflow seed changes.

This creates a new Random instance that is initially seeded with the current
workflow seed, and automatically registers a callback to reseed itself
whenever the workflow receives a new seed from core.

Returns:
A Random instance that stays synchronized with the workflow's randomness.
"""
current_seed = random_seed()
auto_random = Random(current_seed)

def reseed_callback(new_seed: int) -> None:
auto_random.seed(new_seed)

register_random_seed_callback(reseed_callback)
return auto_random


def time() -> float:
"""Current seconds since the epoch from the workflow perspective.

Expand Down
117 changes: 117 additions & 0 deletions tests/worker/test_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -8485,3 +8485,120 @@ async def test_disable_logger_sandbox(
run_timeout=timedelta(seconds=1),
retry_policy=RetryPolicy(maximum_attempts=1),
)


@workflow.defn
class RandomSeedTestWorkflow:
def __init__(self) -> None:
self.seed_changes: list[int] = []
self.continue_signal_received = False
self._ready = False

@workflow.run
async def run(self) -> dict[str, Any]:
# Get the initial seed
initial_seed = workflow.random_seed()

# Register callback to track seed changes
workflow.register_random_seed_callback(self._on_seed_change)

# Create a new random instance that auto-reseeds
auto_random = workflow.new_random()

# Generate random values before waiting
auto_value1 = auto_random.randint(1, 1000000)

# Do an activity to give a reset point
await workflow.execute_activity(
say_hello,
"Hi",
schedule_to_close_timeout=timedelta(seconds=5),
)

self._ready = True

# Wait for signal to continue - this allows for workflow reset
await workflow.wait_condition(lambda: self.continue_signal_received)

# Generate more random values after reset might have occurred
auto_value2 = auto_random.randint(1, 1000000)

# Get final seed
final_seed = workflow.random_seed()

return {
"initial_seed": initial_seed,
"final_seed": final_seed,
"seed_changes": self.seed_changes.copy(),
"auto_values": [auto_value1, auto_value2],
}

def _on_seed_change(self, new_seed: int) -> None:
self.seed_changes.append(new_seed)

@workflow.signal
def continue_workflow(self) -> None:
self.continue_signal_received = True

@workflow.query
def ready(self) -> bool:
return self._ready


async def test_random_seed_functionality(
client: Client, worker: Worker, env: WorkflowEnvironment
):
if env.supports_time_skipping:
pytest.skip("Java test server doesn't support reset")
async with new_worker(
client, RandomSeedTestWorkflow, activities=[say_hello], max_cached_workflows=0
) as worker:
workflow_id = f"test-random-seed-{uuid.uuid4()}"
handle = await client.start_workflow(
RandomSeedTestWorkflow.run,
id=workflow_id,
task_queue=worker.task_queue,
)

# Let workflow generate some random values
# Wait for workflow to be ready
async def ready() -> bool:
return await handle.query(RandomSeedTestWorkflow.ready)

await assert_eq_eventually(True, ready)

# Reset workflow using raw gRPC call to trigger seed change
from temporalio.api.common.v1.message_pb2 import WorkflowExecution
from temporalio.api.enums.v1.reset_pb2 import ResetReapplyType
from temporalio.api.workflowservice.v1 import ResetWorkflowExecutionRequest

await client.workflow_service.reset_workflow_execution(
ResetWorkflowExecutionRequest(
namespace=client.namespace,
workflow_execution=WorkflowExecution(
workflow_id=handle.id,
run_id="",
),
reason="Test seed change",
reset_reapply_type=ResetReapplyType.RESET_REAPPLY_TYPE_UNSPECIFIED,
request_id=str(uuid.uuid4()),
workflow_task_finish_event_id=9, # Reset to after activity completion
)
)

# Get handle to the reset workflow using the new run ID
reset_handle = client.get_workflow_handle(
workflow_id,
)

# Continue the workflow
await reset_handle.signal(RandomSeedTestWorkflow.continue_workflow)

result = await reset_handle.result()

# Verify basic functionality
assert isinstance(result["initial_seed"], int)
assert isinstance(result["final_seed"], int)
assert isinstance(result["seed_changes"], list)
assert len(result["auto_values"]) == 2
assert len(result["seed_changes"]) == 1