Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -533,11 +533,11 @@ def feature_flag_is_on(workflow_id: str | None) -> bool:
def feature_flag_selector(
context: temporalio.converter.StorageDriverStoreContext, _payload: Payload
) -> temporalio.converter.StorageDriver | None:
workflow_id = None
if isinstance(context.serialization_context, temporalio.converter.WorkflowSerializationContext):
workflow_id = context.serialization_context.workflow_id
elif isinstance(context.serialization_context, temporalio.converter.ActivitySerializationContext):
workflow_id = context.serialization_context.workflow_id
workflow_id = (
context.target.id
if isinstance(context.target, temporalio.converter.StorageDriverWorkflowInfo)
else None
)
return my_driver if feature_flag_is_on(workflow_id) else None

options = ExternalStorage(
Expand Down
133 changes: 111 additions & 22 deletions temporalio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,9 @@
ActivitySerializationContext,
DataConverter,
SerializationContext,
StorageDriverActivityInfo,
StorageDriverStoreContext,
StorageDriverWorkflowInfo,
WithSerializationContext,
WorkflowSerializationContext,
)
Expand Down Expand Up @@ -6161,11 +6164,16 @@ async def _to_proto(
priority: temporalio.api.common.v1.Priority | None = None
if self.priority:
priority = self.priority._to_proto()
data_converter = client.data_converter.with_context(
data_converter = client.data_converter._with_contexts(
WorkflowSerializationContext(
namespace=client.namespace,
workflow_id=self.id,
)
),
StorageDriverStoreContext(
target=StorageDriverWorkflowInfo(
id=self.id, type=self.workflow, namespace=client.namespace
),
),
)
action = temporalio.api.schedule.v1.ScheduleAction(
start_workflow=temporalio.api.workflow.v1.NewWorkflowExecutionInfo(
Expand Down Expand Up @@ -6210,7 +6218,8 @@ async def _to_proto(
# TODO (dan): confirm whether this be `is not None`
if self.typed_search_attributes:
temporalio.converter.encode_search_attributes(
self.typed_search_attributes, action.start_workflow.search_attributes
self.typed_search_attributes,
action.start_workflow.search_attributes,
)
if self.headers:
await _apply_headers(
Expand Down Expand Up @@ -8077,11 +8086,16 @@ async def _build_signal_with_start_workflow_execution_request(
self, input: StartWorkflowInput
) -> temporalio.api.workflowservice.v1.SignalWithStartWorkflowExecutionRequest:
assert input.start_signal
data_converter = self._client.data_converter.with_context(
data_converter = self._client.data_converter._with_contexts(
WorkflowSerializationContext(
namespace=self._client.namespace,
workflow_id=input.id,
)
),
StorageDriverStoreContext(
target=StorageDriverWorkflowInfo(
id=input.id, type=input.workflow, namespace=self._client.namespace
),
),
)
req = temporalio.api.workflowservice.v1.SignalWithStartWorkflowExecutionRequest(
signal_name=input.start_signal
Expand All @@ -8108,11 +8122,16 @@ async def _populate_start_workflow_execution_request(
),
input: StartWorkflowInput | UpdateWithStartStartWorkflowInput,
) -> None:
data_converter = self._client.data_converter.with_context(
data_converter = self._client.data_converter._with_contexts(
WorkflowSerializationContext(
namespace=self._client.namespace,
workflow_id=input.id,
)
),
StorageDriverStoreContext(
target=StorageDriverWorkflowInfo(
id=input.id, type=input.workflow, namespace=self._client.namespace
),
),
)
req.namespace = self._client.namespace
req.workflow_id = input.id
Expand Down Expand Up @@ -8228,11 +8247,18 @@ async def count_workflows(
)

async def query_workflow(self, input: QueryWorkflowInput) -> Any:
data_converter = self._client.data_converter.with_context(
data_converter = self._client.data_converter._with_contexts(
WorkflowSerializationContext(
namespace=self._client.namespace,
workflow_id=input.id,
)
),
StorageDriverStoreContext(
target=StorageDriverWorkflowInfo(
id=input.id,
run_id=input.run_id or None,
namespace=self._client.namespace,
),
),
)
req = temporalio.api.workflowservice.v1.QueryWorkflowRequest(
namespace=self._client.namespace,
Expand All @@ -8255,7 +8281,10 @@ async def query_workflow(self, input: QueryWorkflowInput) -> Any:
await self._apply_headers(input.headers, req.query.header.fields)
try:
resp = await self._client.workflow_service.query_workflow(
req, retry=True, metadata=input.rpc_metadata, timeout=input.rpc_timeout
req,
retry=True,
metadata=input.rpc_metadata,
timeout=input.rpc_timeout,
)
except RPCError as err:
# If the status is INVALID_ARGUMENT, we can assume it's a query
Expand All @@ -8281,11 +8310,18 @@ async def query_workflow(self, input: QueryWorkflowInput) -> Any:
return results[0]

async def signal_workflow(self, input: SignalWorkflowInput) -> None:
data_converter = self._client.data_converter.with_context(
data_converter = self._client.data_converter._with_contexts(
WorkflowSerializationContext(
namespace=self._client.namespace,
workflow_id=input.id,
)
),
StorageDriverStoreContext(
target=StorageDriverWorkflowInfo(
id=input.id,
run_id=input.run_id or None,
namespace=self._client.namespace,
),
),
)
req = temporalio.api.workflowservice.v1.SignalWorkflowExecutionRequest(
namespace=self._client.namespace,
Expand All @@ -8306,11 +8342,18 @@ async def signal_workflow(self, input: SignalWorkflowInput) -> None:
)

async def terminate_workflow(self, input: TerminateWorkflowInput) -> None:
data_converter = self._client.data_converter.with_context(
data_converter = self._client.data_converter._with_contexts(
WorkflowSerializationContext(
namespace=self._client.namespace,
workflow_id=input.id,
)
),
StorageDriverStoreContext(
target=StorageDriverWorkflowInfo(
id=input.id,
run_id=input.run_id or None,
namespace=self._client.namespace,
),
),
)
req = temporalio.api.workflowservice.v1.TerminateWorkflowExecutionRequest(
namespace=self._client.namespace,
Expand Down Expand Up @@ -8365,7 +8408,7 @@ async def _build_start_activity_execution_request(
self, input: StartActivityInput
) -> temporalio.api.workflowservice.v1.StartActivityExecutionRequest:
"""Build StartActivityExecutionRequest from input."""
data_converter = self._client.data_converter.with_context(
data_converter = self._client.data_converter._with_contexts(
ActivitySerializationContext(
namespace=self._client.namespace,
activity_id=input.id,
Expand All @@ -8374,7 +8417,14 @@ async def _build_start_activity_execution_request(
is_local=False,
workflow_id=None,
workflow_type=None,
)
),
StorageDriverStoreContext(
target=StorageDriverActivityInfo(
id=input.id,
type=input.activity_type,
namespace=self._client.namespace,
),
),
)

req = temporalio.api.workflowservice.v1.StartActivityExecutionRequest(
Expand Down Expand Up @@ -8560,11 +8610,20 @@ async def _build_update_workflow_execution_request(
input: StartWorkflowUpdateInput | UpdateWithStartUpdateWorkflowInput,
workflow_id: str,
) -> temporalio.api.workflowservice.v1.UpdateWorkflowExecutionRequest:
data_converter = self._client.data_converter.with_context(
data_converter = self._client.data_converter._with_contexts(
WorkflowSerializationContext(
namespace=self._client.namespace,
workflow_id=workflow_id,
)
),
StorageDriverStoreContext(
target=StorageDriverWorkflowInfo(
id=workflow_id,
run_id=(input.run_id or None)
if isinstance(input, StartWorkflowUpdateInput)
else None,
namespace=self._client.namespace,
),
),
)
run_id, first_execution_run_id = (
(
Expand Down Expand Up @@ -8739,10 +8798,34 @@ async def _start_workflow_update_with_start(

### Async activity calls

def _get_async_activity_store_context(
    self, id_or_token: AsyncActivityIDReference | bytes
) -> StorageDriverStoreContext:
    """Build the storage-driver store context for an async activity call.

    Args:
        id_or_token: Either an :class:`AsyncActivityIDReference` carrying
            workflow/activity identity, or a raw task token (``bytes``).

    Returns:
        A :class:`StorageDriverStoreContext` whose target prefers the
        workflow identity when the reference carries a workflow ID, falls
        back to the activity identity otherwise, and is ``None`` for a raw
        task token (which carries no usable identity).
    """
    if isinstance(id_or_token, AsyncActivityIDReference):
        if id_or_token.workflow_id:
            return StorageDriverStoreContext(
                target=StorageDriverWorkflowInfo(
                    # Guarded truthy by the branch condition above, so no
                    # "or None" fallback is needed here.
                    id=id_or_token.workflow_id,
                    run_id=id_or_token.run_id or None,
                    namespace=self._client.namespace,
                ),
            )
        return StorageDriverStoreContext(
            target=StorageDriverActivityInfo(
                id=id_or_token.activity_id,
                run_id=id_or_token.run_id or None,
                namespace=self._client.namespace,
            ),
        )
    # A bare task token gives us no workflow/activity identity to target.
    return StorageDriverStoreContext(target=None)

async def heartbeat_async_activity(
self, input: HeartbeatAsyncActivityInput
) -> None:
data_converter = input.data_converter_override or self._client.data_converter
data_converter = (
input.data_converter_override or self._client.data_converter
)._with_store_context(self._get_async_activity_store_context(input.id_or_token))
details = (
None
if not input.details
Expand Down Expand Up @@ -8797,7 +8880,9 @@ async def heartbeat_async_activity(
)

async def complete_async_activity(self, input: CompleteAsyncActivityInput) -> None:
data_converter = input.data_converter_override or self._client.data_converter
data_converter = (
input.data_converter_override or self._client.data_converter
)._with_store_context(self._get_async_activity_store_context(input.id_or_token))
result = (
None
if input.result is temporalio.common._arg_unset
Expand Down Expand Up @@ -8831,7 +8916,9 @@ async def complete_async_activity(self, input: CompleteAsyncActivityInput) -> No
)

async def fail_async_activity(self, input: FailAsyncActivityInput) -> None:
data_converter = input.data_converter_override or self._client.data_converter
data_converter = (
input.data_converter_override or self._client.data_converter
)._with_store_context(self._get_async_activity_store_context(input.id_or_token))

failure = temporalio.api.failure.v1.Failure()
await data_converter.encode_failure(input.error, failure)
Expand Down Expand Up @@ -8872,7 +8959,9 @@ async def fail_async_activity(self, input: FailAsyncActivityInput) -> None:
async def report_cancellation_async_activity(
self, input: ReportCancellationAsyncActivityInput
) -> None:
data_converter = input.data_converter_override or self._client.data_converter
data_converter = (
input.data_converter_override or self._client.data_converter
)._with_store_context(self._get_async_activity_store_context(input.id_or_token))
details = (
None
if not input.details
Expand Down
55 changes: 20 additions & 35 deletions temporalio/contrib/aws/s3driver/_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@
from temporalio.api.common.v1 import Payload
from temporalio.contrib.aws.s3driver._client import S3StorageDriverClient
from temporalio.converter import (
ActivitySerializationContext,
StorageDriver,
StorageDriverActivityInfo,
StorageDriverClaim,
StorageDriverRetrieveContext,
StorageDriverStoreContext,
WorkflowSerializationContext,
StorageDriverWorkflowInfo,
)

_T = TypeVar("_T")
Expand Down Expand Up @@ -113,40 +113,25 @@ async def store(
(e.g. proto binary). The returned list is the same length as
``payloads``.
"""
workflow_id: str | None = None
activity_id: str | None = None
namespace: str | None = None
if isinstance(context.serialization_context, WorkflowSerializationContext):
workflow_id = context.serialization_context.workflow_id
namespace = context.serialization_context.namespace
if isinstance(context.serialization_context, ActivitySerializationContext):
# Prioritize workflow over activity so that the same payload that
# may be stored across workflow and activity boundaries are deduplicated.
if context.serialization_context.workflow_id:
workflow_id = context.serialization_context.workflow_id
elif context.serialization_context.activity_id:
activity_id = context.serialization_context.activity_id
namespace = context.serialization_context.namespace

# URL encode values to avoid characters that break the key format
# e.g. spaces, forward-slashes, etc.
if namespace:
namespace = urllib.parse.quote(namespace, safe="")
if workflow_id:
workflow_id = urllib.parse.quote(workflow_id, safe="")
if activity_id:
activity_id = urllib.parse.quote(activity_id, safe="")

namespace_segments = f"/ns/{namespace}" if namespace else ""

def _quote(val: str | None) -> str | None:
return urllib.parse.quote(val, safe="") if val else None

# Build context segments from the target identity.
context_segments = ""
# Prioritize workflow over activity so that the same payload that
# may be stored across workflow and activity boundaries are deduplicated.
# Workflow and Activity IDs are case sensitive.
if workflow_id:
context_segments += f"/wfi/{workflow_id}"
elif activity_id:
context_segments += f"/aci/{activity_id}"
target = context.target
namespace = _quote(target.namespace) if target is not None else None
namespace_segment = f"/ns/{namespace}" if namespace else ""
if isinstance(target, StorageDriverWorkflowInfo):
wf_type = _quote(target.type) or "null"
wf_id = _quote(target.id) or "null"
wf_run_id = _quote(target.run_id) or "null"
context_segments = f"/wt/{wf_type}/wi/{wf_id}/ri/{wf_run_id}"
elif isinstance(target, StorageDriverActivityInfo):
act_type = _quote(target.type) or "null"
act_id = _quote(target.id) or "null"
act_run_id = _quote(target.run_id) or "null"
context_segments = f"/at/{act_type}/ai/{act_id}/ri/{act_run_id}"

async def _upload(payload: Payload) -> StorageDriverClaim:
bucket = self._get_bucket(context, payload)
Expand All @@ -162,7 +147,7 @@ async def _upload(payload: Payload) -> StorageDriverClaim:

digest_segments = f"/d/sha256/{hash_digest}"

key = f"v0{namespace_segments}{context_segments}{digest_segments}"
key = f"v0{namespace_segment}{context_segments}{digest_segments}"

try:
if not await self._client.object_exists(bucket=bucket, key=key):
Expand Down
4 changes: 4 additions & 0 deletions temporalio/converter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@
from temporalio.converter._extstore import (
ExternalStorage,
StorageDriver,
StorageDriverActivityInfo,
StorageDriverClaim,
StorageDriverRetrieveContext,
StorageDriverStoreContext,
StorageDriverWorkflowInfo,
StorageWarning,
)
from temporalio.converter._failure_converter import (
Expand Down Expand Up @@ -54,9 +56,11 @@
"ActivitySerializationContext",
"ExternalStorage",
"StorageDriver",
"StorageDriverActivityInfo",
"StorageDriverClaim",
"StorageDriverRetrieveContext",
"StorageDriverStoreContext",
"StorageDriverWorkflowInfo",
"StorageWarning",
"AdvancedJSONEncoder",
"BinaryNullPayloadConverter",
Expand Down
Loading