diff --git a/README.md b/README.md index f26c7b837..cfd302cf3 100644 --- a/README.md +++ b/README.md @@ -54,6 +54,7 @@ informal introduction to the features and their implementation. - [Data Conversion](#data-conversion) - [Pydantic Support](#pydantic-support) - [Custom Type Data Conversion](#custom-type-data-conversion) + - [External Payload Storage](#external-payload-storage) - [Workers](#workers) - [Workflows](#workflows) - [Definition](#definition) @@ -309,8 +310,9 @@ other_ns_client = Client(**config) Data converters are used to convert raw Temporal payloads to/from actual Python types. A custom data converter of type `temporalio.converter.DataConverter` can be set via the `data_converter` parameter of the `Client` constructor. Data -converters are a combination of payload converters, payload codecs, and failure converters. Payload converters convert -Python values to/from serialized bytes. Payload codecs convert bytes to bytes (e.g. for compression or encryption). +converters are a combination of payload converters, external payload storage, payload codecs, and failure converters. Payload +converters convert Python values to/from serialized bytes. External payload storage optionally stores and retrieves payloads +to/from external storage services using drivers. Payload codecs convert bytes to bytes (e.g. for compression or encryption). Failure converters convert exceptions to/from serialized failures. The default data converter supports converting multiple types including: @@ -455,6 +457,140 @@ my_data_converter = dataclasses.replace( Now `IPv4Address` can be used in type hints including collections, optionals, etc. +##### External Payload Storage + +⚠️ **External payload storage support is currently at an experimental release stage.** ⚠️ + +External payload storage allows large payloads to be offloaded to an external storage service (such as Amazon S3) rather than stored inline in workflow history. 
This is useful when workflows or activities work with data that would otherwise exceed Temporal's payload size limits. + +External payload storage is configured via the `external_storage` parameter on `DataConverter`, which accepts a `temporalio.extstore.StorageConfig` instance. Any driver used to store payloads must also be configured on the component that retrieves them — for example, if the client stores workflow inputs using a driver, the worker must include that driver in its `StorageConfig.drivers` list to retrieve them. + +The simplest setup uses a single storage driver: + +```python +import dataclasses +from temporalio.client import Client +from temporalio.converter import DataConverter +from temporalio.extstore import StorageConfig + +driver = MyDriver() + +client = await Client.connect( + "localhost:7233", + data_converter=dataclasses.replace( + DataConverter.default, + external_storage=StorageConfig(drivers=[driver]), + ), +) +``` + +Some things to note about external payload storage: + +* Only payloads that meet or exceed `StorageConfig.payload_size_threshold` (default 256 KiB) are offloaded. Smaller payloads are stored inline as normal. +* External payload storage applies transparently to workflow inputs/outputs, activity inputs/outputs, signals, updates, queries, and failure details. +* The `DataConverter`'s `payload_codec` (if configured) is applied to the *reference* payload stored in workflow history, not to the externally stored bytes. To encrypt or compress the bytes handed to a driver, use `StorageConfig.payload_codec`. +* Setting `StorageConfig.payload_size_threshold` to `None` causes every payload to be considered for external payload storage regardless of size. + +###### Multiple Drivers and Driver Selection + +When multiple storage backends are needed, list all drivers in `StorageConfig.drivers` and provide a `driver_selector` to control which driver stores new payloads. 
Any driver in the list not chosen for storing is still available for retrieval, which is useful when migrating between storage backends. + +```python +from temporalio.extstore import StorageConfig + +options = StorageConfig( + drivers=[hot_driver, cold_driver], + driver_selector=lambda context, payload: ( + hot_driver if payload.ByteSize() < 5 * 1024 * 1024 else cold_driver + ), +) +``` + +For stateful or class-based selection logic, implement a callable class. If it also implements `temporalio.converter.WithSerializationContext`, it will receive workflow or activity context (namespace, workflow ID, etc.) at serialization time, just like a driver or payload codec: + +```python +from typing_extensions import Self +import temporalio.converter +import temporalio.extstore +from temporalio.api.common.v1 import Payload + +class FeatureFlaggedDriverSelector(temporalio.converter.WithSerializationContext): + def __init__(self, driver: temporalio.extstore.StorageDriver, enabled: bool = False): + self._driver = driver + self._enabled = enabled + + def __call__( + self, _context: temporalio.extstore.StorageDriverContext, _payload: Payload + ) -> temporalio.extstore.StorageDriver | None: + return self._driver if self._enabled else None + + def with_context(self, context: temporalio.converter.SerializationContext) -> Self: + workflow_id = None + if isinstance(context, temporalio.converter.WorkflowSerializationContext) and context.workflow_id: + workflow_id = context.workflow_id + elif isinstance(context, temporalio.converter.ActivitySerializationContext) and context.workflow_id: + workflow_id = context.workflow_id + + return FeatureFlaggedDriverSelector( + self._driver, FeatureFlaggedDriverSelector.feature_flag_is_on(workflow_id) + ) + + @staticmethod + def feature_flag_is_on(workflow_id: str | None) -> bool: + """Mock implementation of a feature flag based on a workflow ID.""" + return workflow_id is not None and len(workflow_id) % 2 == 0 +``` + +Some things to note about 
driver selection: + +* When no `driver_selector` is set, the first driver in `StorageConfig.drivers` is always used for storing. +* Returning `None` from a selector leaves the payload stored inline in workflow history rather than offloading it. +* The driver returned by the selector must be registered in `StorageConfig.drivers`. If it is not, a `RuntimeError` is raised. + +###### Custom Drivers + +Implement `temporalio.extstore.StorageDriver` to integrate with an external storage system: + +```python +from collections.abc import Sequence +from temporalio.extstore import StorageDriver, StorageDriverClaim, StorageDriverContext +from temporalio.api.common.v1 import Payload + +class MyDriver(StorageDriver): + def __init__(self, driver_name: str | None = None): + self._driver_name = driver_name or "my-org:driver:my-driver" + + def name(self) -> str: + return self._driver_name + + async def store( + self, context: StorageDriverContext, payloads: Sequence[Payload] + ) -> list[StorageDriverClaim]: + claims = [] + for payload in payloads: + key = await my_storage.put(payload.SerializeToString()) + claims.append(StorageDriverClaim(data={"key": key})) + return claims + + async def retrieve( + self, context: StorageDriverContext, claims: Sequence[StorageDriverClaim] + ) -> list[Payload]: + payloads = [] + for claim in claims: + data = await my_storage.get(claim.data["key"]) + p = Payload() + p.ParseFromString(data) + payloads.append(p) + return payloads +``` + +Some things to note about implementing a custom driver: + +* `store` and `retrieve` must return lists of the same length as their respective input sequences. +* `StorageDriver.name()` must return a string that is unique among all drivers in `StorageConfig.drivers`. This name is embedded in the reference payload stored in workflow history and used to look up the correct driver during retrieval — changing it after payloads have been stored will break retrieval. 
+* `StorageDriver.type()` is automatically implemented to return the name of the class. This can be overridden in subclasses but must remain consistent across all instances of the subclass. +* Implement `temporalio.converter.WithSerializationContext` on your driver to receive workflow or activity context (namespace, workflow ID, activity ID, etc.) at serialization time. + ### Workers Workers host workflows and/or activities. Here's how to run a worker: diff --git a/temporalio/bridge/worker.py b/temporalio/bridge/worker.py index c98afefca..c2e426d28 100644 --- a/temporalio/bridge/worker.py +++ b/temporalio/bridge/worker.py @@ -303,10 +303,9 @@ async def decode_activation( decode_headers: bool, ) -> None: """Decode all payloads in the activation.""" - if data_converter._decode_payload_has_effect: - await CommandAwarePayloadVisitor( - skip_search_attributes=True, skip_headers=not decode_headers - ).visit(_Visitor(data_converter._decode_payload_sequence), activation) + await CommandAwarePayloadVisitor( + skip_search_attributes=True, skip_headers=not decode_headers + ).visit(_Visitor(data_converter._decode_payload_sequence), activation) async def encode_completion( diff --git a/temporalio/converter.py b/temporalio/converter.py index dc37f5039..18b2e67fd 100644 --- a/temporalio/converter.py +++ b/temporalio/converter.py @@ -21,6 +21,7 @@ from itertools import zip_longest from logging import getLogger from typing import ( + TYPE_CHECKING, Any, ClassVar, Literal, @@ -44,6 +45,9 @@ import temporalio.exceptions import temporalio.types +if TYPE_CHECKING: + import temporalio.extstore # avoid circular import at runtime + if sys.version_info < (3, 11): # Python's datetime.fromisoformat doesn't support certain formats pre-3.11 from dateutil import parser # type: ignore @@ -1359,15 +1363,25 @@ class DataConverter(WithSerializationContext): payload_limits: PayloadLimitsConfig = PayloadLimitsConfig() """Settings for payload size limits.""" + external_storage: 
temporalio.extstore.StorageConfig | None = None + """Options for external storage. If None, external storage is disabled. + + .. warning:: + This API is experimental. + """ + default: ClassVar[DataConverter] """Singleton default data converter.""" + _storage_impl: temporalio.extstore._StorageImpl = dataclasses.field(init=False) + _payload_error_limits: _ServerPayloadErrorLimits | None = None """Server-reported limits for payloads.""" def __post_init__(self) -> None: # noqa: D105 object.__setattr__(self, "payload_converter", self.payload_converter_class()) object.__setattr__(self, "failure_converter", self.failure_converter_class()) + self._reset_external_storage_middleware() async def encode( self, values: Sequence[Any] @@ -1445,18 +1459,22 @@ def with_context(self, context: SerializationContext) -> Self: payload_converter = self.payload_converter payload_codec = self.payload_codec failure_converter = self.failure_converter + external_storage = self.external_storage if isinstance(payload_converter, WithSerializationContext): payload_converter = payload_converter.with_context(context) if isinstance(payload_codec, WithSerializationContext): payload_codec = payload_codec.with_context(context) if isinstance(failure_converter, WithSerializationContext): failure_converter = failure_converter.with_context(context) + if isinstance(external_storage, WithSerializationContext): + external_storage = external_storage.with_context(context) if all( new is orig for new, orig in [ (payload_converter, self.payload_converter), (payload_codec, self.payload_codec), (failure_converter, self.failure_converter), + (external_storage, self.external_storage), ] ): return self @@ -1464,8 +1482,21 @@ def with_context(self, context: SerializationContext) -> Self: object.__setattr__(cloned, "payload_converter", payload_converter) object.__setattr__(cloned, "payload_codec", payload_codec) object.__setattr__(cloned, "failure_converter", failure_converter) + object.__setattr__(cloned, 
"external_storage", external_storage) + cloned._reset_external_storage_middleware(context) return cloned + def _reset_external_storage_middleware( + self, context: SerializationContext | None = None + ) -> None: + import temporalio.extstore # lazy import to avoid circular dependency + + object.__setattr__( + self, + "_external_storage_middleware", + temporalio.extstore._StorageImpl(self.external_storage, context), + ) + def _with_payload_error_limits( self, limits: _ServerPayloadErrorLimits | None ) -> DataConverter: @@ -1523,12 +1554,14 @@ async def _encode_memo_existing( async def _encode_payload( self, payload: temporalio.api.common.v1.Payload ) -> temporalio.api.common.v1.Payload: + payload = await self._storage_impl.store_payload(payload) if self.payload_codec: payload = (await self.payload_codec.encode([payload]))[0] self._validate_payload_limits([payload]) return payload async def _encode_payloads(self, payloads: temporalio.api.common.v1.Payloads): + await self._storage_impl.store_payloads(payloads) if self.payload_codec: await self.payload_codec.encode_wrapper(payloads) self._validate_payload_limits(payloads.payloads) @@ -1536,35 +1569,33 @@ async def _encode_payloads(self, payloads: temporalio.api.common.v1.Payloads): async def _encode_payload_sequence( self, payloads: Sequence[temporalio.api.common.v1.Payload] ) -> list[temporalio.api.common.v1.Payload]: - encoded_payloads = list(payloads) + result = await self._storage_impl.store_payload_sequence(payloads) if self.payload_codec: - encoded_payloads = await self.payload_codec.encode(encoded_payloads) - self._validate_payload_limits(encoded_payloads) - return encoded_payloads + result = await self.payload_codec.encode(result) + self._validate_payload_limits(result) + return result async def _decode_payload( self, payload: temporalio.api.common.v1.Payload ) -> temporalio.api.common.v1.Payload: if self.payload_codec: payload = (await self.payload_codec.decode([payload]))[0] + payload = await 
self._storage_impl.retrieve_payload(payload) return payload async def _decode_payloads(self, payloads: temporalio.api.common.v1.Payloads): if self.payload_codec: await self.payload_codec.decode_wrapper(payloads) + await self._storage_impl.retrieve_payloads(payloads) async def _decode_payload_sequence( self, payloads: Sequence[temporalio.api.common.v1.Payload] ) -> list[temporalio.api.common.v1.Payload]: - if not self.payload_codec: - return list(payloads) - return await self.payload_codec.decode(payloads) - - # Temporary shortcircuit detection while the _decode_* methods may no-op if - # a payload codec is not configured. Remove once those paths have more to them. - @property - def _decode_payload_has_effect(self) -> bool: - return self.payload_codec is not None + result = list(payloads) + if self.payload_codec: + result = await self.payload_codec.decode(result) + result = await self._storage_impl.retrieve_payload_sequence(result) + return result @staticmethod async def _apply_to_failure_payloads( diff --git a/temporalio/extstore.py b/temporalio/extstore.py new file mode 100644 index 000000000..54b446bec --- /dev/null +++ b/temporalio/extstore.py @@ -0,0 +1,501 @@ +"""External payload storage support for offloading payloads to external storage systems.""" + +from __future__ import annotations + +import asyncio +import dataclasses +import warnings +from abc import ABC, abstractmethod +from collections.abc import Callable, Mapping, Sequence +from dataclasses import dataclass + +from typing_extensions import Self + +from temporalio.api.common.v1 import Payload, Payloads +from temporalio.converter import ( + JSONPlainPayloadConverter, + PayloadCodec, + SerializationContext, + WithSerializationContext, +) + + +@dataclass(frozen=True) +class StorageDriverClaim: + """Claim for an externally stored payload. + + .. warning:: + This API is experimental. 
+ """ + + data: Mapping[str, str] + """Driver-defined data for identifying and retrieving an externally stored payload.""" + + +@dataclass(frozen=True) +class StorageDriverContext: + """Context passed to :class:`StorageDriver` and :class:`DriverSelector` calls. + + .. warning:: + This API is experimental. + """ + + serialization_context: SerializationContext | None = None + """The serialization context active when this driver operation was initiated, + or ``None`` if no context has been set. + """ + + +class StorageDriver(ABC): + """Base driver for storing and retrieve payloads from external storage systems. + + .. warning:: + This API is experimental. + """ + + @abstractmethod + def name(self) -> str: + """Returns the name of this driver instance. A driver may allow its name + to be parameterized at construction time so that multiple instances of + the same driver class can coexist in :attr:`StorageConfig.drivers` with + distinct names. + """ + raise NotImplementedError + + def type(self) -> str: + """Returns the type of the storage driver. This string should be the same + across all instantiations of the same driver class. This allows the equivalent + driver implementation in different languages to be named the same. + + Defaults to the class name. Subclasses may override this to return a + stable, language-agnostic identifier. + """ + return type(self).__name__ + + @abstractmethod + async def store( + self, + context: StorageDriverContext, + payloads: Sequence[Payload], + ) -> list[StorageDriverClaim]: + """Stores payloads in external storage and returns a :class:`StorageDriverClaim` + for each one. The returned list must be the same length as ``payloads``. + """ + raise NotImplementedError + + @abstractmethod + async def retrieve( + self, + context: StorageDriverContext, + claims: Sequence[StorageDriverClaim], + ) -> list[Payload]: + """Retrieves payloads from external storage for the given :class:`StorageDriverClaim` + list. 
The returned list must be the same length as ``claims``.
+
+        Raise :class:`PayloadNotFoundError` when a retrieval attempt confirms
+        that a payload is absent from storage. This signals an unrecoverable
+        condition that will fail the workflow rather than retrying the workflow
+        task.
+        """
+        raise NotImplementedError
+
+
+@dataclass(frozen=True)
+class StorageConfig(WithSerializationContext):
+    """Configuration for external storage behavior.
+
+    .. warning::
+        This API is experimental.
+    """
+
+    drivers: Sequence[StorageDriver]
+    """Drivers available for storing and retrieving payloads. At least one
+    driver must be provided.
+
+    When no :attr:`driver_selector` is set, the first driver in this list is
+    used for all store operations. Additional drivers may be included solely to
+    support retrieval — for example, to download payloads that remote callers
+    uploaded to an external storage system that is not your primary storage
+    driver. Drivers in this list are looked up by :meth:`StorageDriver.name`
+    during retrieval, so each driver must have a unique name.
+    """
+
+    driver_selector: (
+        Callable[[StorageDriverContext, Payload], StorageDriver | None] | None
+    ) = None
+    """Controls which driver stores a given payload. Accepts any callable of
+    the form ``(StorageDriverContext, Payload) -> StorageDriver | None``,
+    including a callable class instance.
+
+    When ``None``, the first driver in :attr:`drivers` is used for all store
+    operations. Returning ``None`` from the selector leaves the payload stored
+    inline rather than offloading it to external storage.
+    """
+
+    payload_size_threshold: int | None = 256 * 1024
+    """Minimum payload size in bytes before external storage is considered.
+    Defaults to 256 KiB. Set to ``None`` to consider every payload for
+    external storage regardless of size.
+    """
+
+    payload_codec: PayloadCodec | None = None
+    """Optional codec applied to payloads before they are handed to a
+    :class:`StorageDriver` for storage, and after they are retrieved.
When ``None``,
+    payloads are stored as-is by the driver.
+    """
+
+    def with_context(self, context: SerializationContext) -> Self:
+        """Return a copy of these options with the serialization context applied.
+
+        Propagates *context* to any drivers, the driver selector, and the
+        payload codec that implement :class:`WithSerializationContext`.
+        If none of those fields changed, ``self`` is returned unchanged.
+        """
+        drivers: Sequence[StorageDriver] = self.drivers
+        new_drivers = [
+            driver.with_context(context)
+            if isinstance(driver, WithSerializationContext)
+            else driver
+            for driver in self.drivers
+        ]
+        # Only swap in the copied list if a driver actually changed, so the
+        # identity check below can return self when nothing was touched
+        if any(new is not orig for new, orig in zip(new_drivers, self.drivers)):
+            drivers = new_drivers
+        driver_selector = self.driver_selector
+        if isinstance(driver_selector, WithSerializationContext):
+            driver_selector = driver_selector.with_context(context)
+        payload_codec = self.payload_codec
+        if isinstance(payload_codec, WithSerializationContext):
+            payload_codec = payload_codec.with_context(context)
+        if all(
+            new is orig
+            for new, orig in [
+                (drivers, self.drivers),
+                (driver_selector, self.driver_selector),
+                (payload_codec, self.payload_codec),
+            ]
+        ):
+            return self
+        cloned = dataclasses.replace(self)
+        object.__setattr__(cloned, "drivers", drivers)
+        object.__setattr__(cloned, "driver_selector", driver_selector)
+        object.__setattr__(cloned, "payload_codec", payload_codec)
+        return cloned
+
+
+class StorageWarning(RuntimeWarning):
+    """Warning for external storage issues.
+
+    .. warning::
+        This API is experimental.
+    """
+
+
+@dataclass(frozen=True)
+class _StorageReference:
+    driver_name: str
+    driver_claim: StorageDriverClaim
+
+
+class _StorageImpl:  # type:ignore[reportUnusedClass]
+    # Claim payload encoding is fixed and independent of any user configuration.
+ _claim_converter: JSONPlainPayloadConverter = JSONPlainPayloadConverter( + encoding="json/external-storage-reference" + ) + + def __init__( + self, + options: StorageConfig | None, + context: SerializationContext | None = None, + ): + self._options = options + self._context = context + self._driver_map: dict[str, StorageDriver] = {} + if options is not None: + for driver in options.drivers: + name = driver.name() + if name in self._driver_map: + warnings.warn( + f"StorageConfig.drivers contains multiple drivers with name '{name}'. " + "The first one will be used.", + category=StorageWarning, + ) + else: + self._driver_map[name] = driver + + def _select_driver( + self, context: StorageDriverContext, payload: Payload + ) -> StorageDriver | None: + """Returns the driver to use for this payload, or None to pass through.""" + assert self._options is not None + selector = self._options.driver_selector + if selector is None: + return self._options.drivers[0] if self._options.drivers else None + else: + driver = selector(context, payload) + if driver is None: + return None + registered = self._driver_map.get(driver.name()) + if registered is None: + raise RuntimeError(f"No driver found with name '{driver.name()}'") + return registered + + def _get_driver_by_name(self, name: str) -> StorageDriver: + """Looks up a driver by name, raising :class:`RuntimeError` if not found.""" + driver = self._driver_map.get(name) + if driver is None: + raise RuntimeError(f"No driver found with name '{name}'") + return driver + + async def store_payload(self, payload: Payload) -> Payload: + if self._options is None: + return payload + + size_bytes = payload.ByteSize() + if ( + self._options.payload_size_threshold is not None + and size_bytes < self._options.payload_size_threshold + ): + return payload + + context = StorageDriverContext(serialization_context=self._context) + + driver = self._select_driver(context, payload) + if driver is None: + return payload + + # Optionally encode the 
payload before externally storing it + encoded_payload = payload + if self._options.payload_codec: + encoded_payload = (await self._options.payload_codec.encode([payload]))[0] + + claims = await driver.store(context, [encoded_payload]) + + self._validate_claim_length(claims, expected=1, driver=driver) + + reference = _StorageReference( + driver_name=driver.name(), + driver_claim=claims[0], + ) + reference_payload = self._claim_converter.to_payload(reference) + assert reference_payload is not None + reference_payload.external_payloads.add().size_bytes = ( + encoded_payload.ByteSize() + ) + return reference_payload + + async def store_payloads(self, payloads: Payloads): + stored_payloads = await self.store_payload_sequence(payloads.payloads) + for i, payload in enumerate(stored_payloads): + payloads.payloads[i].CopyFrom(payload) + + async def store_payload_sequence( + self, + payloads: Sequence[Payload], + ) -> list[Payload]: + if self._options is None: + return list(payloads) + + if len(payloads) == 1: + return [await self.store_payload(payloads[0])] + + results = list(payloads) + context = StorageDriverContext(serialization_context=self._context) + + # First pass: determine which payloads to store and which driver to use for each. + # Provide unencoded payloads to give maximal context information to the selector. 
+ to_store: list[tuple[int, Payload, StorageDriver]] = [] + for index, payload in enumerate(payloads): + size_bytes = payload.ByteSize() + if ( + self._options.payload_size_threshold is not None + and size_bytes < self._options.payload_size_threshold + ): + continue + driver = self._select_driver(context, payload) + if driver is None: + continue + to_store.append((index, payload, driver)) + + if not to_store: + return results + + # Optionally encode all payloads destined for external storage + payloads_to_encode = [payload for _, payload, _ in to_store] + encoded_payloads = payloads_to_encode + if self._options.payload_codec: + encoded_payloads = await self._options.payload_codec.encode( + payloads_to_encode + ) + + # Group encoded payloads by driver for batched store calls + # driver -> [(original_index, encoded_payload)] + driver_groups: dict[StorageDriver, list[tuple[int, Payload]]] = {} + for i, (orig_index, _, driver) in enumerate(to_store): + driver_groups.setdefault(driver, []).append( + (orig_index, encoded_payloads[i]) + ) + + # Store all driver groups concurrently then build reference payloads + driver_group_list = list(driver_groups.items()) + + all_claims = await asyncio.gather( + *( + driver.store(context, [p for _, p in indexed_payloads]) + for driver, indexed_payloads in driver_group_list + ) + ) + + for (driver, indexed_payloads), claims in zip(driver_group_list, all_claims): + indices = [idx for idx, _ in indexed_payloads] + sizes = [p.ByteSize() for _, p in indexed_payloads] + + self._validate_claim_length(claims, expected=len(indices), driver=driver) + + for i, claim in enumerate(claims): + reference = _StorageReference( + driver_name=driver.name(), + driver_claim=claim, + ) + reference_payload = self._claim_converter.to_payload(reference) + assert reference_payload is not None + reference_payload.external_payloads.add().size_bytes = sizes[i] + results[indices[i]] = reference_payload + + return results + + async def retrieve_payload( + self, + 
payload: Payload, + ) -> Payload: + if self._options is None or len(self._options.drivers) == 0: + # External storage was not configured (correctly). Warn if there are any external payloads + # since that is likely to cause downstream error when decoding or deserializing. + if len(payload.external_payloads) > 0: + if not self._options: + warnings.warn( + "External storage is not configured, but detected external storage references.", + category=StorageWarning, + ) + elif len(self._options.drivers) == 0: + warnings.warn( + "StorageConfig.drivers is empty, but detected external storage references.", + category=StorageWarning, + ) + return payload + + if len(payload.external_payloads) == 0: + return payload + + reference = self._claim_converter.from_payload(payload, _StorageReference) + if not isinstance(reference, _StorageReference): + return payload + + driver = self._get_driver_by_name(reference.driver_name) + context = StorageDriverContext(serialization_context=self._context) + + stored_payloads = await driver.retrieve(context, [reference.driver_claim]) + + self._validate_payload_length(stored_payloads, expected=1, driver=driver) + + if self._options.payload_codec: + stored_payloads = await self._options.payload_codec.decode(stored_payloads) + + return stored_payloads[0] + + async def retrieve_payloads(self, payloads: Payloads): + stored_payloads = await self.retrieve_payload_sequence(payloads.payloads) + for i, payload in enumerate(stored_payloads): + payloads.payloads[i].CopyFrom(payload) + + async def retrieve_payload_sequence( + self, + payloads: Sequence[Payload], + ) -> list[Payload]: + results = list(payloads) + + if self._options is None or len(self._options.drivers) == 0: + # External storage was not configured, but warn if there are any external payloads + # since that is likely to cause downstream error when decoding or deserializing. 
+ if any(len(p.external_payloads) > 0 for p in payloads): + if not self._options: + warnings.warn( + "External storage is not configured, but detected external storage references.", + category=StorageWarning, + ) + elif len(self._options.drivers) == 0: + warnings.warn( + "StorageConfig.drivers is empty, but detected external storage references.", + category=StorageWarning, + ) + return results + + if len(payloads) == 1: + return [await self.retrieve_payload(payloads[0])] + + # Group claims by driver for batched retrieve calls + # driver -> [(original_index, claim)] + driver_claims: dict[StorageDriver, list[tuple[int, StorageDriverClaim]]] = {} + for index, payload in enumerate(payloads): + if len(payload.external_payloads) == 0: + continue + + reference = self._claim_converter.from_payload(payload, _StorageReference) + if not isinstance(reference, _StorageReference): + continue + + driver = self._get_driver_by_name(reference.driver_name) + driver_claims.setdefault(driver, []).append((index, reference.driver_claim)) + + if not driver_claims: + return results + + context = StorageDriverContext(serialization_context=self._context) + stored_by_index: dict[int, Payload] = {} + + # Retrieve from all drivers concurrently + driver_claim_list = list(driver_claims.items()) + + all_stored = await asyncio.gather( + *( + driver.retrieve(context, [claim for _, claim in indexed_claims]) + for driver, indexed_claims in driver_claim_list + ) + ) + + for (driver, indexed_claims), stored_payloads in zip( + driver_claim_list, all_stored + ): + indices = [idx for idx, _ in indexed_claims] + + self._validate_payload_length( + stored_payloads, + expected=len(indexed_claims), + driver=driver, + ) + + for idx, stored_payload in zip(indices, stored_payloads): + stored_by_index[idx] = stored_payload + + # Decode all retrieved payloads together if a codec is configured + retrieve_indices = sorted(stored_by_index.keys()) + stored_list = [stored_by_index[idx] for idx in retrieve_indices] + + 
decoded_payloads = stored_list + if self._options.payload_codec: + decoded_payloads = await self._options.payload_codec.decode(stored_list) + + for i, retrieved_payload in enumerate(decoded_payloads): + results[retrieve_indices[i]] = retrieved_payload + + return results + + def _validate_claim_length( + self, claims: Sequence[StorageDriverClaim], expected: int, driver: StorageDriver + ) -> None: + if len(claims) != expected: + raise RuntimeError( + f"Driver '{driver.name()}' returned {len(claims)} claims, expected {expected}", + ) + + def _validate_payload_length( + self, payloads: Sequence[Payload], expected: int, driver: StorageDriver + ) -> None: + if len(payloads) != expected: + raise RuntimeError( + f"Driver '{driver.name()}' returned {len(payloads)} payloads, expected {expected}", + ) diff --git a/temporalio/worker/_activity.py b/temporalio/worker/_activity.py index 4e6e06282..a895f54d2 100644 --- a/temporalio/worker/_activity.py +++ b/temporalio/worker/_activity.py @@ -625,7 +625,7 @@ async def _execute_activity( else None, ) - if self._encode_headers and data_converter._decode_payload_has_effect: + if self._encode_headers: for payload in start.header_fields.values(): payload.CopyFrom(await data_converter._decode_payload(payload)) diff --git a/temporalio/worker/_workflow.py b/temporalio/worker/_workflow.py index 18f5599ba..6fff2e8b5 100644 --- a/temporalio/worker/_workflow.py +++ b/temporalio/worker/_workflow.py @@ -340,21 +340,44 @@ async def _handle_activation( "Failed handling activation on workflow with run ID %s", act.run_id ) - completion.failed.failure.SetInParent() - try: - data_converter.failure_converter.to_failure( - err, - data_converter.payload_converter, - completion.failed.failure, - ) - except Exception as inner_err: - logger.exception( - "Failed converting activation exception on workflow with run ID %s", - act.run_id, - ) - completion.failed.failure.message = ( - f"Failed converting activation exception: {inner_err}" - ) + if ( + 
isinstance(err, temporalio.exceptions.ApplicationError) + and err.non_retryable + ): + # Fail the workflow execution terminally rather than failing the task + command = completion.successful.commands.add() + failure = command.fail_workflow_execution.failure + failure.SetInParent() + try: + data_converter.failure_converter.to_failure( + err, + data_converter.payload_converter, + failure, + ) + except Exception as inner_err: + logger.exception( + "Failed converting activation exception on workflow with run ID %s", + act.run_id, + ) + failure.message = ( + f"Failed converting activation exception: {inner_err}" + ) + else: + completion.failed.failure.SetInParent() + try: + data_converter.failure_converter.to_failure( + err, + data_converter.payload_converter, + completion.failed.failure, + ) + except Exception as inner_err: + logger.exception( + "Failed converting activation exception on workflow with run ID %s", + act.run_id, + ) + completion.failed.failure.message = ( + f"Failed converting activation exception: {inner_err}" + ) completion.run_id = act.run_id diff --git a/temporalio/worker/_workflow_instance.py b/temporalio/worker/_workflow_instance.py index 9c22c05ce..d7b6d27b9 100644 --- a/temporalio/worker/_workflow_instance.py +++ b/temporalio/worker/_workflow_instance.py @@ -1781,9 +1781,9 @@ def workflow_set_current_details(self, details: str): self._current_details = details def workflow_is_failure_exception(self, err: BaseException) -> bool: - # An exception is a failure instead of a task fail if it's already a - # failure error or if it is a timeout error or if it is an instance of - # any of the failure types in the worker or workflow-level setting + # An exception causes the workflow to fail (rather than the task) if it + # is already a failure error, a timeout error, or an instance of any of the + # failure exception types configured at the worker or workflow level. 
wf_failure_exception_types = self._defn.failure_exception_types if self._dynamic_failure_exception_types is not None: wf_failure_exception_types = self._dynamic_failure_exception_types diff --git a/tests/test_extstore.py b/tests/test_extstore.py new file mode 100644 index 000000000..c584d8b3b --- /dev/null +++ b/tests/test_extstore.py @@ -0,0 +1,718 @@ +"""Tests for external storage functionality.""" + +import warnings +from collections.abc import Sequence + +import pytest +from typing_extensions import Self + +from temporalio.api.common.v1 import Payload +from temporalio.converter import ( + ActivitySerializationContext, + DataConverter, + JSONPlainPayloadConverter, + PayloadCodec, + SerializationContext, + WithSerializationContext, + WorkflowSerializationContext, +) +from temporalio.exceptions import ApplicationError +from temporalio.extstore import ( + StorageConfig, + StorageDriver, + StorageDriverClaim, + StorageDriverContext, + StorageWarning, + _StorageReference, +) + + +class InMemoryTestDriver(StorageDriver): + """In-memory storage driver for testing.""" + + def __init__( + self, + driver_name: str = "test-driver", + ): + self._driver_name = driver_name + self._storage: dict[str, bytes] = {} + self._store_calls = 0 + self._retrieve_calls = 0 + + def name(self) -> str: + return self._driver_name + + async def store( + self, + context: StorageDriverContext, + payloads: Sequence[Payload], + ) -> list[StorageDriverClaim]: + self._store_calls += 1 + start_index = len(self._storage) + + entries = [ + (f"payload-{start_index + i}", payload.SerializeToString()) + for i, payload in enumerate(payloads) + ] + self._storage.update(entries) + + return [StorageDriverClaim(data={"key": key}) for key, _ in entries] + + async def retrieve( + self, + context: StorageDriverContext, + claims: Sequence[StorageDriverClaim], + ) -> list[Payload]: + self._retrieve_calls += 1 + + def parse_claim( + claim: StorageDriverClaim, + ) -> Payload: + key = claim.data["key"] + if key not 
in self._storage: + raise ApplicationError( + f"Payload not found for key '{key}'", non_retryable=True + ) + payload = Payload() + payload.ParseFromString(self._storage[key]) + return payload + + return [parse_claim(claim) for claim in claims] + + +class WorkflowIdFeatureFlagDriverSelector(WithSerializationContext): + """Example selector that conditionally stores based on workflow ID feature flag. + + This example shows how a callable can implement WithSerializationContext if it + needs to precompute data from the serialization context instead of doing it on + every payload selection call. + + The feature flag in this example is a simple check on the workflow ID length, but in + a real implementation this could be a call to a feature flag service or a lookup in a + configuration store. + """ + + def __init__(self, driver: StorageDriver, enabled: bool = False): + self._driver = driver + self._enabled = enabled + + def __call__( + self, _context: StorageDriverContext, _payload: Payload + ) -> StorageDriver | None: + return self._driver if self._enabled else None + + def with_context(self, context: SerializationContext) -> Self: + workflow_id = None + if isinstance(context, ActivitySerializationContext) and context.workflow_id: + workflow_id = context.workflow_id + if isinstance(context, WorkflowSerializationContext) and context.workflow_id: + workflow_id = context.workflow_id + + # Create new instance with updated enabled flag and propagate context to inner driver + driver = self._driver + if isinstance(driver, WithSerializationContext): + driver = driver.with_context(context) + + return type(self)( + driver, WorkflowIdFeatureFlagDriverSelector.feature_flag_is_on(workflow_id) + ) + + @staticmethod + def feature_flag_is_on(workflow_id: str | None) -> bool: + """Mock implementation of a feature flag based on a workflow ID.""" + return workflow_id is not None and len(workflow_id) % 2 == 0 + + +class TestDataConverterExternalStorage: + """Tests for DataConverter with 
external storage.""" + + async def test_extstore_encode_decode(self): + """Test that large payloads are stored externally.""" + driver = InMemoryTestDriver() + + # Configure with 100-byte threshold + converter = DataConverter( + external_storage=StorageConfig( + drivers=[driver], + payload_size_threshold=100, + ) + ) + + # Small value should not be externalized + small_value = "small" + encoded_small = await converter.encode([small_value]) + assert len(encoded_small) == 1 + assert not encoded_small[0].external_payloads # Not externalized + assert driver._store_calls == 0 + + # Large value should be externalized + large_value = "x" * 200 + encoded_large = await converter.encode([large_value]) + assert len(encoded_large) == 1 + assert len(encoded_large[0].external_payloads) > 0 # Externalized + assert driver._store_calls == 1 + + # Decode large value + decoded = await converter.decode(encoded_large, [str]) + assert len(decoded) == 1 + assert decoded[0] == large_value + assert driver._retrieve_calls == 1 + + async def test_extstore_reference_structure(self): + """Test that external storage creates proper reference structure.""" + converter = DataConverter( + external_storage=StorageConfig( + drivers=[InMemoryTestDriver("test-driver")], + payload_size_threshold=50, + ) + ) + + # Create large payload + large_value = "x" * 100 + encoded = await converter.encode([large_value]) + + # Verify reference structure + reference_payload = encoded[0] + assert len(reference_payload.external_payloads) > 0 + + # The payload should contain a serialized _ExternalStorageReference + # Deserialize it to verify structure using the same encoding + claim_converter = JSONPlainPayloadConverter( + encoding="json/external-storage-reference" + ) + reference = claim_converter.from_payload(reference_payload, _StorageReference) + + assert isinstance(reference, _StorageReference) + assert "test-driver" == reference.driver_name + assert isinstance(reference.driver_claim, StorageDriverClaim) + assert 
"key" in reference.driver_claim.data + + async def test_extstore_composite_conditional(self): + """Test using multiple drivers based on size.""" + hot_driver = InMemoryTestDriver("hot-storage") + cold_driver = InMemoryTestDriver("cold-storage") + + options = StorageConfig( + drivers=[hot_driver, cold_driver], + driver_selector=lambda context, payload: hot_driver + if payload.ByteSize() < 500 + else cold_driver, + payload_size_threshold=100, + ) + converter = DataConverter(external_storage=options) + + # Small payload (not externalized) + small = "x" * 50 + encoded_small = await converter.encode([small]) + assert not encoded_small[0].external_payloads + assert hot_driver._store_calls == 0 + assert cold_driver._store_calls == 0 + + # Medium payload (hot storage) + medium = "x" * 200 + encoded_medium = await converter.encode([medium]) + assert len(encoded_medium[0].external_payloads) > 0 + assert hot_driver._store_calls == 1 + assert cold_driver._store_calls == 0 + + # Large payload (cold storage) + large = "x" * 2000 + encoded_large = await converter.encode([large]) + assert len(encoded_large[0].external_payloads) > 0 + assert hot_driver._store_calls == 1 # Unchanged + assert cold_driver._store_calls == 1 + + # Verify retrieval from correct drivers + decoded_medium = await converter.decode(encoded_medium, [str]) + assert decoded_medium[0] == medium + assert hot_driver._retrieve_calls == 1 + + decoded_large = await converter.decode(encoded_large, [str]) + assert decoded_large[0] == large + assert cold_driver._retrieve_calls == 1 + + async def test_extstore_serialization_context(self): + driver = InMemoryTestDriver() + + # The payload should not be stored externally when it doesn't have a serialization context + # or if the workflow ID doesn't end with "-extstore". This is an example of feature flagging + # external storage using the workflow ID. This is an advanced secnario and requires the "can_store" + # filter to be a WithSerializationContext. 
+        options = StorageConfig(
+            drivers=[driver],
+            driver_selector=WorkflowIdFeatureFlagDriverSelector(driver),
+            payload_size_threshold=1024,
+        )
+        converter = DataConverter(external_storage=options)
+
+        large_value = "c" * (1024 + 100)
+
+        encoded_payloads = await converter.encode([large_value])
+        assert len(encoded_payloads) == 1
+        assert driver._store_calls == 0
+        assert driver._retrieve_calls == 0
+
+        decoded_values = await converter.decode(encoded_payloads, [str])
+        assert len(decoded_values) == 1
+        assert driver._store_calls == 0
+        assert driver._retrieve_calls == 0
+
+        namespace = "my-ns"
+
+        # Now has serialization context, but the odd-length workflow ID keeps the
+        # feature flag off
+        converter = converter.with_context(
+            WorkflowSerializationContext(
+                namespace=namespace,
+                workflow_id="odd-length-workflow-id1",
+            )
+        )
+
+        encoded_payloads = await converter.encode([large_value])
+        assert len(encoded_payloads) == 1
+        assert driver._store_calls == 0
+        assert driver._retrieve_calls == 0
+
+        decoded_values = await converter.decode(encoded_payloads, [str])
+        assert len(decoded_values) == 1
+        assert driver._store_calls == 0
+        assert driver._retrieve_calls == 0
+
+        # Now has serialization context with an even-length workflow ID, so the
+        # feature flag is on
+        converter = converter.with_context(
+            WorkflowSerializationContext(
+                namespace=namespace,
+                workflow_id="even-length-workflow-id1",
+            )
+        )
+
+        encoded_payloads = await converter.encode([large_value])
+        assert len(encoded_payloads) == 1
+        assert driver._store_calls == 1
+        assert driver._retrieve_calls == 0
+
+        decoded_values = await converter.decode(encoded_payloads, [str])
+        assert len(decoded_values) == 1
+        assert driver._store_calls == 1
+        assert driver._retrieve_calls == 1
+
+
+class NotFoundDriver(StorageDriver):
+    """Driver that stores normally but raises non-retryable ApplicationError on retrieve."""
+
+    def __init__(self, driver_name: str = "not-found-driver"):
+        self._driver_name = driver_name
+        self._storage: dict[str, bytes] = {}
+
+    def 
name(self) -> str: + return self._driver_name + + async def store( + self, + context: StorageDriverContext, + payloads: Sequence[Payload], + ) -> list[StorageDriverClaim]: + entries = [ + (f"payload-{i}", payload.SerializeToString()) + for i, payload in enumerate(payloads) + ] + self._storage.update(entries) + return [StorageDriverClaim(data={"key": key}) for key, _ in entries] + + async def retrieve( + self, + context: StorageDriverContext, + claims: Sequence[StorageDriverClaim], + ) -> list[Payload]: + assert len(claims) > 0, "NotFoundDriver expected claims to be provided" + raise ApplicationError("Payload not found.", non_retryable=True) + + +class TestDriverError: + """Tests for RuntimeError raised when a driver violates its contract.""" + + async def test_encode_wrong_claim_count_raises_runtime_error(self): + """store() returning fewer claims than payloads must raise RuntimeError.""" + + class _NoClaimsDriver(InMemoryTestDriver): + async def store( + self, context: StorageDriverContext, payloads: Sequence[Payload] + ) -> list[StorageDriverClaim]: + return [] + + driver = _NoClaimsDriver() + converter = DataConverter( + external_storage=StorageConfig( + drivers=[driver], + payload_size_threshold=10, + ) + ) + with pytest.raises( + RuntimeError, + match=f"Driver '{driver.name()}' returned 0 claims, expected 1", + ): + await converter.encode(["x" * 200]) + + async def test_decode_wrong_payload_count_raises_runtime_error(self): + """retrieve() returning fewer payloads than claims must raise RuntimeError.""" + good_converter = DataConverter( + external_storage=StorageConfig( + drivers=[InMemoryTestDriver()], + payload_size_threshold=10, + ) + ) + encoded = await good_converter.encode(["x" * 200]) + + class _NoPayloadsDriver(InMemoryTestDriver): + async def retrieve( + self, + context: StorageDriverContext, + claims: Sequence[StorageDriverClaim], + ) -> list[Payload]: + return [] + + driver = _NoPayloadsDriver() + bad_converter = DataConverter( + 
external_storage=StorageConfig( + drivers=[driver], + payload_size_threshold=10, + ) + ) + with pytest.raises( + RuntimeError, + match=f"Driver '{driver.name()}' returned 0 payloads, expected 1", + ): + await bad_converter.decode(encoded, [str]) + + +class RecordingPayloadCodec(PayloadCodec): + """Codec that wraps each payload under a recognisable ``encoding`` label. + + Encode sets ``metadata["encoding"]`` to ``encoding_label`` and stores the + serialised inner payload as ``data``. Decode reverses that. The call + counters let tests assert exactly how many payloads each codec processed. + """ + + def __init__(self, encoding_label: str) -> None: + self._encoding_label = encoding_label.encode() + self.encoded_count = 0 + self.decoded_count = 0 + + async def encode(self, payloads: Sequence[Payload]) -> list[Payload]: + self.encoded_count += len(payloads) + results = [] + for p in payloads: + wrapped = Payload() + wrapped.metadata["encoding"] = self._encoding_label + wrapped.data = p.SerializeToString() + results.append(wrapped) + return results + + async def decode(self, payloads: Sequence[Payload]) -> list[Payload]: + self.decoded_count += len(payloads) + results = [] + for p in payloads: + inner = Payload() + inner.ParseFromString(p.data) + results.append(inner) + return results + + +class TestPayloadCodecWithExternalStorage: + """Tests for interaction between DataConverter.payload_codec and external storage.""" + + async def test_dc_payload_codec_encodes_reference_payload(self): + """DataConverter.payload_codec encodes the reference payload in workflow + history but does NOT encode the bytes handed to the driver for storage.""" + driver = InMemoryTestDriver() + dc_codec = RecordingPayloadCodec("binary/dc-encoded") + + converter = DataConverter( + payload_codec=dc_codec, + external_storage=StorageConfig( + drivers=[driver], + payload_size_threshold=50, + ), + ) + + large_value = "x" * 200 + encoded = await converter.encode([large_value]) + assert len(encoded) == 1 
+ assert driver._store_calls == 1 + + # The reference payload written to history must carry the dc_codec label. + assert dc_codec.encoded_count == 1 + assert encoded[0].metadata.get("encoding") == b"binary/dc-encoded" + + # The bytes given to the driver must NOT carry the dc_codec label. + stored_payload = Payload() + stored_payload.ParseFromString(next(iter(driver._storage.values()))) + assert stored_payload.metadata.get("encoding") != b"binary/dc-encoded" + assert stored_payload.metadata.get("encoding") == b"json/plain" + + # Round-trip must recover the original value. + decoded = await converter.decode(encoded, [str]) + assert decoded[0] == large_value + assert dc_codec.decoded_count == 1 + assert driver._retrieve_calls == 1 + + async def test_external_converter_without_codec_does_not_encode_stored_bytes(self): + """When DataConverter.payload_codec is set but StorageConfig.payload_codec + is None, stored bytes are NOT encoded – even though + DataConverter.payload_codec is active for the reference payload in history.""" + driver = InMemoryTestDriver() + dc_codec = RecordingPayloadCodec("binary/dc-encoded") + + converter = DataConverter( + payload_codec=dc_codec, + external_storage=StorageConfig( + drivers=[driver], + payload_size_threshold=50, + ), + ) + + large_value = "x" * 200 + encoded = await converter.encode([large_value]) + assert len(encoded) == 1 + assert driver._store_calls == 1 + + # Reference payload in history is still encoded by DataConverter.payload_codec. + assert dc_codec.encoded_count == 1 + assert encoded[0].metadata.get("encoding") == b"binary/dc-encoded" + + # Stored bytes are NOT encoded. + stored_payload = Payload() + stored_payload.ParseFromString(next(iter(driver._storage.values()))) + assert stored_payload.metadata.get("encoding") != b"binary/dc-encoded" + assert stored_payload.metadata.get("encoding") == b"json/plain" + + # Round-trip. 
+ decoded = await converter.decode(encoded, [str]) + assert decoded[0] == large_value + assert dc_codec.decoded_count == 1 + assert driver._retrieve_calls == 1 + + async def test_external_converter_codec_independent_from_dc_codec(self): + """When both DataConverter.payload_codec and StorageConfig.payload_codec + are set, the reference payload in history uses DataConverter.payload_codec + and the bytes stored by the driver use StorageConfig.payload_codec – + independently.""" + driver = InMemoryTestDriver() + dc_codec = RecordingPayloadCodec("binary/dc-encoded") + ext_codec = RecordingPayloadCodec("binary/ext-encoded") + + converter = DataConverter( + payload_codec=dc_codec, + external_storage=StorageConfig( + drivers=[driver], + payload_size_threshold=50, + payload_codec=ext_codec, + ), + ) + + large_value = "x" * 200 + encoded = await converter.encode([large_value]) + assert len(encoded) == 1 + assert driver._store_calls == 1 + + # Each codec was applied exactly once during encode. + assert dc_codec.encoded_count == 1 + assert ext_codec.encoded_count == 1 + + # Reference payload carries dc_codec's label. + assert encoded[0].metadata.get("encoding") == b"binary/dc-encoded" + + # Stored bytes carry ext_codec's label – different from the reference. + stored_payload = Payload() + stored_payload.ParseFromString(next(iter(driver._storage.values()))) + assert stored_payload.metadata.get("encoding") == b"binary/ext-encoded" + assert stored_payload.metadata.get("encoding") != encoded[0].metadata.get( + "encoding" + ) + + # Round-trip must recover the original value using both codecs. 
+ decoded = await converter.decode(encoded, [str]) + assert decoded[0] == large_value + assert dc_codec.decoded_count == 1 + assert ext_codec.decoded_count == 1 + assert driver._retrieve_calls == 1 + + +class TestMultiDriver: + """Tests for StorageConfig with multiple drivers.""" + + async def test_no_selector_uses_first_driver_for_store(self): + """Without a driver_selector the first driver in the list handles all + store operations. Additional drivers are never called for store.""" + first = InMemoryTestDriver("driver-first") + second = InMemoryTestDriver("driver-second") + + converter = DataConverter( + external_storage=StorageConfig( + drivers=[first, second], + payload_size_threshold=50, + ) + ) + + large = "x" * 200 + encoded = await converter.encode([large]) + + assert first._store_calls == 1 + assert second._store_calls == 0 + + # The reference in history names the first driver. + ref = JSONPlainPayloadConverter( + encoding="json/external-storage-reference" + ).from_payload(encoded[0], _StorageReference) + assert ref.driver_name == "driver-first" + + # Retrieval also goes to the first driver. + decoded = await converter.decode(encoded, [str]) + assert decoded[0] == large + assert first._retrieve_calls == 1 + assert second._retrieve_calls == 0 + + async def test_no_selector_second_driver_is_retrieve_only(self): + """A driver that is second in the list acts as a retrieve-only driver. + References are resolved by name, not by position, so a payload stored + by driver-b is retrieved correctly even when driver-a is listed first.""" + driver_a = InMemoryTestDriver("driver-a") + driver_b = InMemoryTestDriver("driver-b") + + # Store with driver-b as the sole driver. + store_converter = DataConverter( + external_storage=StorageConfig( + drivers=[driver_b], + payload_size_threshold=50, + ) + ) + large = "y" * 200 + encoded = await store_converter.encode([large]) + + # Retrieve with driver-a listed first, driver-b second. 
+ # The "driver-b" name in the reference must route to driver-b. + retrieve_converter = DataConverter( + external_storage=StorageConfig( + drivers=[driver_a, driver_b], + payload_size_threshold=50, + ) + ) + decoded = await retrieve_converter.decode(encoded, [str]) + assert decoded[0] == large + assert driver_a._retrieve_calls == 0 # never consulted + assert driver_b._retrieve_calls == 1 + + async def test_selector_routes_payloads_to_different_drivers_in_single_batch(self): + """When a selector routes different payloads to different drivers, a + single encode([v1, v2, ...]) call batches payloads per driver so each + driver receives exactly one store() call regardless of how many + payloads are routed to it.""" + driver_a = InMemoryTestDriver("driver-a") + driver_b = InMemoryTestDriver("driver-b") + + # Route payloads that serialise to < 500 bytes to driver_a, larger ones + # to driver_b. + def selector(_ctx: object, payload: Payload) -> InMemoryTestDriver: + return driver_a if payload.ByteSize() < 500 else driver_b + + converter = DataConverter( + external_storage=StorageConfig( + drivers=[driver_a, driver_b], + driver_selector=selector, + payload_size_threshold=50, + ) + ) + + small_ext = "a" * 100 # above threshold, serialises well below 500 B + large_ext = "b" * 1000 # serialises above 500 B + + # Encode both values in a single call — they should be batched per driver. + encoded = await converter.encode([small_ext, large_ext]) + assert driver_a._store_calls == 1 # one batched call, not two individual ones + assert driver_b._store_calls == 1 + + # Full round-trip. 
+ decoded = await converter.decode(encoded, [str, str]) + assert decoded == [small_ext, large_ext] + assert driver_a._retrieve_calls == 1 + assert driver_b._retrieve_calls == 1 + + async def test_selector_returning_none_keeps_payload_inline(self): + """A selector that returns None for a payload leaves it stored inline + in workflow history rather than offloading it to any driver, even when + the payload exceeds the size threshold.""" + driver = InMemoryTestDriver("driver-a") + + converter = DataConverter( + external_storage=StorageConfig( + drivers=[driver], + driver_selector=lambda _ctx, _payload: None, + payload_size_threshold=50, + ) + ) + + large = "x" * 200 + encoded = await converter.encode([large]) + + assert driver._store_calls == 0 + assert len(encoded[0].external_payloads) == 0 # payload is inline + + decoded = await converter.decode(encoded, [str]) + assert decoded[0] == large + assert driver._retrieve_calls == 0 + + async def test_selector_returns_unregistered_driver_raises(self): + """A selector that returns a Driver whose name is not present in + StorageConfig.drivers raises RuntimeError during encode.""" + registered = InMemoryTestDriver("registered") + unregistered = InMemoryTestDriver("not-in-list") + + converter = DataConverter( + external_storage=StorageConfig( + drivers=[registered], + driver_selector=lambda _ctx, _payload: unregistered, + payload_size_threshold=50, + ) + ) + + with pytest.raises(RuntimeError): + await converter.encode(["x" * 200]) + + async def test_duplicate_driver_names_warns_and_first_wins_for_retrieval(self): + """Registering two drivers with identical names emits StorageWarning. + The first-registered driver with that name is kept in the name→driver + map; subsequent duplicates are ignored. 
+ + Both store (positional via drivers[0]) and retrieval (name-based map) + therefore resolve to the same driver — no data is lost.""" + first = InMemoryTestDriver("dup-name") + duplicate = InMemoryTestDriver("dup-name") + + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + converter = DataConverter( + external_storage=StorageConfig( + drivers=[first, duplicate], + payload_size_threshold=50, + ) + ) + + storage_warnings = [w for w in caught if issubclass(w.category, StorageWarning)] + assert len(storage_warnings) == 1 + assert "dup-name" in str(storage_warnings[0].message) + + # Store goes to first (drivers[0], positional). + large = "x" * 200 + encoded = await converter.encode([large]) + assert first._store_calls == 1 + assert duplicate._store_calls == 0 + + # Retrieval resolves "dup-name" to first (first-registered wins). + # first has the data, so the round-trip succeeds. + decoded = await converter.decode(encoded, [str]) + assert decoded[0] == large + assert first._retrieve_calls == 1 + assert duplicate._retrieve_calls == 0 + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/worker/test_extstore.py b/tests/worker/test_extstore.py new file mode 100644 index 000000000..84492af6d --- /dev/null +++ b/tests/worker/test_extstore.py @@ -0,0 +1,427 @@ +import dataclasses +import uuid +from collections.abc import Sequence +from dataclasses import dataclass +from datetime import timedelta + +import pytest + +import temporalio +import temporalio.converter +from temporalio import activity, workflow +from temporalio.api.common.v1 import Payload +from temporalio.client import Client, WorkflowFailureError, WorkflowHandle +from temporalio.common import RetryPolicy +from temporalio.exceptions import ActivityError, ApplicationError +from temporalio.extstore import ( + StorageConfig, + StorageDriverClaim, + StorageDriverContext, + StorageWarning, +) +from temporalio.testing._workflow import WorkflowEnvironment 
+from temporalio.worker import Replayer
+from tests.helpers import assert_task_fail_eventually, new_worker
+from tests.test_extstore import InMemoryTestDriver
+
+
+@dataclass(frozen=True)
+class ExtStoreActivityInput:
+    input_data: str
+    output_size: int
+
+
+@activity.defn
+async def ext_store_activity(
+    input: ExtStoreActivityInput,
+) -> str:
+    return "ao" * int(input.output_size / 2)
+
+
+@dataclass(frozen=True)
+class ExtStoreWorkflowInput:
+    input_data: str
+    activity_input_size: int
+    activity_output_size: int
+    output_size: int
+    max_activity_attempts: int | None = None
+
+
+@workflow.defn
+class ExtStoreWorkflow:
+    @workflow.run
+    async def run(self, input: ExtStoreWorkflowInput) -> str:
+        retry_policy = (
+            RetryPolicy(maximum_attempts=input.max_activity_attempts)
+            if input.max_activity_attempts is not None
+            else None
+        )
+        await workflow.execute_activity(
+            ext_store_activity,
+            ExtStoreActivityInput(
+                input_data="ai" * int(input.activity_input_size / 2),
+                output_size=input.activity_output_size,
+            ),
+            schedule_to_close_timeout=timedelta(seconds=3),
+            retry_policy=retry_policy,
+        )
+        return "wo" * int(input.output_size / 2)
+
+
+class BadTestDriver(InMemoryTestDriver):
+    def __init__(
+        self,
+        driver_name: str = "bad-driver",
+        no_store: bool = False,
+        no_retrieve: bool = False,
+        raise_payload_not_found: bool = False,
+    ):
+        super().__init__(driver_name)
+        self._no_store = no_store
+        self._no_retrieve = no_retrieve
+        self._raise_payload_not_found = raise_payload_not_found
+
+    async def store(
+        self,
+        context: StorageDriverContext,
+        payloads: Sequence[Payload],
+    ) -> list[StorageDriverClaim]:
+        if self._no_store:
+            return []
+        return await super().store(context, payloads)
+
+    async def retrieve(
+        self,
+        context: StorageDriverContext,
+        claims: Sequence[StorageDriverClaim],
+    ) -> list[Payload]:
+        if self._no_retrieve:
+            return []
+        if self._raise_payload_not_found:
+            raise ApplicationError(
+                "Payload not found",
+                
type="PayloadNotFoundError", + non_retryable=True, + ) + return await super().retrieve(context, claims) + + +async def test_extstore_activity_input_no_retrieve( + env: WorkflowEnvironment, +): + """Using a small threshold, validate that activity result size over + the threshold causes driver to be invoked.""" + driver = BadTestDriver(no_retrieve=True) + + client = await Client.connect( + env.client.service_client.config.target_host, + namespace=env.client.namespace, + data_converter=dataclasses.replace( + temporalio.converter.default(), + external_storage=StorageConfig( + drivers=[driver], + payload_size_threshold=1024, + ), + ), + ) + + async with new_worker( + client, ExtStoreWorkflow, activities=[ext_store_activity] + ) as worker: + handle = await client.start_workflow( + ExtStoreWorkflow.run, + ExtStoreWorkflowInput( + input_data="workflow input", + activity_input_size=1000, + activity_output_size=10, + output_size=10, + max_activity_attempts=1, + ), + id=f"workflow-{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + + with pytest.raises(WorkflowFailureError) as err: + await handle.result() + + assert isinstance(err.value.cause, ActivityError) + + +async def test_extstore_activity_result_no_store( + env: WorkflowEnvironment, +): + """Using a small threshold, validate that activity result size over + the threshold causes driver to be invoked.""" + driver = BadTestDriver(no_store=True) + + client = await Client.connect( + env.client.service_client.config.target_host, + namespace=env.client.namespace, + data_converter=dataclasses.replace( + temporalio.converter.default(), + external_storage=StorageConfig( + drivers=[driver], + payload_size_threshold=1024, + ), + ), + ) + + async with new_worker( + client, ExtStoreWorkflow, activities=[ext_store_activity] + ) as worker: + handle = await client.start_workflow( + ExtStoreWorkflow.run, + ExtStoreWorkflowInput( + input_data="workflow input", + activity_input_size=10, + activity_output_size=1000, + output_size=10, + 
max_activity_attempts=1, + ), + id=f"workflow-{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + + with pytest.raises(WorkflowFailureError) as err: + await handle.result() + + assert isinstance(err.value.cause, ActivityError) + + +async def test_extstore_worker_missing_driver( + env: WorkflowEnvironment, +): + """Validate that when a worker is provided a workflow history with + external storage references and the worker is not configured for external + storage, it will cause a workflow task failure. + """ + driver = InMemoryTestDriver() + + far_client = await Client.connect( + env.client.service_client.config.target_host, + namespace=env.client.namespace, + data_converter=dataclasses.replace( + temporalio.converter.default(), + external_storage=StorageConfig( + drivers=[driver], + payload_size_threshold=1024, + ), + ), + ) + + worker_client = await Client.connect( + env.client.service_client.config.target_host, + namespace=env.client.namespace, + ) + + async with new_worker( + worker_client, ExtStoreWorkflow, activities=[ext_store_activity] + ) as worker: + handle = await far_client.start_workflow( + ExtStoreWorkflow.run, + ExtStoreWorkflowInput( + input_data="wi" * 1024, + activity_input_size=10, + activity_output_size=10, + output_size=10, + ), + id=f"workflow-{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + + await assert_task_fail_eventually(handle) + + +async def test_extstore_payload_not_found_fails_workflow( + env: WorkflowEnvironment, +): + """When a non-retryable ApplicationError is raised while retrieving workflow input, + the workflow must fail terminally (not retry as a task failure). 
+ """ + client = await Client.connect( + env.client.service_client.config.target_host, + namespace=env.client.namespace, + data_converter=dataclasses.replace( + temporalio.converter.default(), + external_storage=StorageConfig( + drivers=[BadTestDriver(raise_payload_not_found=True)], + payload_size_threshold=1024, + ), + ), + ) + + async with new_worker( + client, ExtStoreWorkflow, activities=[ext_store_activity] + ) as worker: + handle = await client.start_workflow( + ExtStoreWorkflow.run, + ExtStoreWorkflowInput( + input_data="wi" * 512, # exceeds 1024-byte threshold + activity_input_size=10, + activity_output_size=10, + output_size=10, + ), + id=f"workflow-{uuid.uuid4()}", + task_queue=worker.task_queue, + execution_timeout=timedelta(seconds=10), + ) + + with pytest.raises(WorkflowFailureError) as exc_info: + await handle.result() + + assert isinstance(exc_info.value.cause, ApplicationError) + assert exc_info.value.cause.type == "PayloadNotFoundError" + assert exc_info.value.cause.non_retryable is True + + +async def _run_extstore_workflow_and_fetch_history( + env: WorkflowEnvironment, + driver: InMemoryTestDriver, + *, + input_data: str, + activity_output_size: int = 10, +) -> WorkflowHandle: + """Helper: run ExtStoreWorkflow with the given driver and return its history handle.""" + extstore_client = await Client.connect( + env.client.service_client.config.target_host, + namespace=env.client.namespace, + data_converter=dataclasses.replace( + temporalio.converter.default(), + external_storage=StorageConfig( + drivers=[driver], + payload_size_threshold=512, + ), + ), + ) + async with new_worker( + extstore_client, ExtStoreWorkflow, activities=[ext_store_activity] + ) as worker: + handle = await extstore_client.start_workflow( + ExtStoreWorkflow.run, + ExtStoreWorkflowInput( + input_data=input_data, + activity_input_size=10, + activity_output_size=activity_output_size, + output_size=10, + ), + id=f"workflow-{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + 
await handle.result() + return handle + + +async def test_replay_extstore_history_fails_without_extstore( + env: WorkflowEnvironment, +) -> None: + """A history with externalized workflow input fails to replay when the + Replayer has no external storage configured.""" + driver = InMemoryTestDriver() + handle = await _run_extstore_workflow_and_fetch_history( + env, + driver, + input_data="wi" * 512, # exceeds 512-byte threshold + ) + history = await handle.fetch_history() + + # Replay without external storage — the reference payload cannot be decoded. + # The middleware emits a StorageWarning when it encounters a reference payload + # with no driver configured. + with pytest.warns(StorageWarning, match="External storage is not configured"): + result = await Replayer(workflows=[ExtStoreWorkflow]).replay_workflow( + history, raise_on_replay_failure=False + ) + # Must be a task-failure RuntimeError, not a NondeterminismError — external + # storage decode failures are distinct from workflow code changes. + assert isinstance(result.replay_failure, RuntimeError) + assert not isinstance(result.replay_failure, workflow.NondeterminismError) + # The message is the full activation-completion failure string; the + # "Failed decoding arguments" text from _convert_payloads is embedded in it. + assert "Failed decoding arguments" in result.replay_failure.args[0] + + +async def test_replay_extstore_history_succeeds_with_correct_extstore( + env: WorkflowEnvironment, +) -> None: + """A history with externalized workflow input replays successfully when the + Replayer is configured with the same storage driver that holds the data.""" + driver = InMemoryTestDriver() + handle = await _run_extstore_workflow_and_fetch_history( + env, driver, input_data="wi" * 512 + ) + history = await handle.fetch_history() + + # Replay with the same populated driver — must succeed. 
+ await Replayer( + workflows=[ExtStoreWorkflow], + data_converter=dataclasses.replace( + temporalio.converter.default(), + external_storage=StorageConfig( + drivers=[driver], + payload_size_threshold=512, + ), + ), + ).replay_workflow(history) + + +async def test_replay_extstore_history_fails_with_empty_driver( + env: WorkflowEnvironment, +) -> None: + """A history with external storage references fails to replay when the + Replayer has external storage configured but the driver holds no data + (simulates pointing at the wrong backend or a purged store).""" + driver = InMemoryTestDriver() + handle = await _run_extstore_workflow_and_fetch_history( + env, driver, input_data="wi" * 512 + ) + history = await handle.fetch_history() + + # Replay with a fresh empty driver — retrieval will fail. + result = await Replayer( + workflows=[ExtStoreWorkflow], + data_converter=dataclasses.replace( + temporalio.converter.default(), + external_storage=StorageConfig( + drivers=[InMemoryTestDriver()], + payload_size_threshold=512, + ), + ), + ).replay_workflow(history, raise_on_replay_failure=False) + # InMemoryTestDriver raises ApplicationError for absent keys. + # ApplicationError is re-raised without wrapping, so it propagates + # through decode_activation (before the workflow task runs). The core SDK + # receives an activation failure, issues a FailWorkflow command, but the + # next history event is ActivityTaskScheduled — causing a NondeterminismError. 
+ assert isinstance(result.replay_failure, workflow.NondeterminismError) + + +async def test_replay_extstore_activity_result_fails_without_extstore( + env: WorkflowEnvironment, +) -> None: + """A history where only the activity result was stored externally (the + workflow input is small enough to be inline) also fails to replay without + external storage — verifying that mid-workflow decode failures are caught.""" + driver = InMemoryTestDriver() + handle = await _run_extstore_workflow_and_fetch_history( + env, + driver, + input_data="small", # well under 512 bytes — stays inline + activity_output_size=2048, # 2 KB result — stored externally + ) + history = await handle.fetch_history() + + # Replay without external storage. The workflow input decodes fine, but + # when the ActivityTaskCompleted result is delivered back to the workflow + # coroutine it cannot be decoded. + with pytest.warns(StorageWarning, match="External storage is not configured"): + result = await Replayer(workflows=[ExtStoreWorkflow]).replay_workflow( + history, raise_on_replay_failure=False + ) + # Mid-workflow decode failure is still a task failure (RuntimeError), not + # nondeterminism. + assert isinstance(result.replay_failure, RuntimeError) + assert not isinstance(result.replay_failure, workflow.NondeterminismError) + # The message is the full activation-completion failure string; the + # "Failed decoding arguments" text from _convert_payloads is embedded in it. + assert "Failed decoding arguments" in result.replay_failure.args[0]