11from attrs import field , frozen
22from datetime import datetime
33from metafold .api import asdatetime , asdict , optional_datetime
4+ from metafold .assets import Asset
45from metafold .client import Client
56from metafold .exceptions import PollTimeout
7+ from metafold .jobs import Job
68from requests import Response
7- from typing import Optional , Union
9+ from typing import Optional , Union , cast
10+ import typing
11+
12+ if typing .TYPE_CHECKING :
13+ from metafold import MetafoldClient
814
915
1016@frozen (kw_only = True )
@@ -21,6 +27,9 @@ class Workflow:
2127 definition: Workflow definition string.
2228 project_id: Project ID.
2329 """
30+ _client : "MetafoldClient"
31+ _jobs : dict [str , str ] = field (factory = dict , init = False )
32+
2433 id : str
2534 jobs : list [str ] = field (factory = list )
2635 state : str
@@ -32,6 +41,56 @@ class Workflow:
3241 definition : str
3342 project_id : str
3443
44+ def get_asset (self , path : str ) -> Union [Asset , None ]:
45+ """Retrieve an asset from the workflow by dot notation.
46+
47+ Args:
48+ path: Path to asset in the form "job.name", e.g. "sample-mesh.volume"
49+ searches for the asset "volume" from the "sample-mesh" job.
50+ """
51+ job_name , asset_name = self ._parse_path (path )
52+ job = self ._find_job (job_name )
53+ if not job or not job .outputs .assets :
54+ return None
55+ for name , asset in job .outputs .assets .items ():
56+ if name == asset_name :
57+ return asset
58+ return None
59+
60+ def get_parameter (self , path : str ) -> Union [str , None ]:
61+ """Retrieve a parameter from the workflow by dot notation.
62+
63+ Args:
64+ path: Path to parameter in the form "job.name", e.g. "sample-mesh.patch_size"
65+ searches for the parameter "patch_size" from the "sample-mesh" job.
66+ """
67+ job_name , param_name = self ._parse_path (path )
68+ job = self ._find_job (job_name )
69+ if not job or not job .outputs .params :
70+ return None
71+ for name , param in job .outputs .params .items ():
72+ if name == param_name :
73+ return param
74+ return None
75+
76+ def _find_job (self , name : str ) -> Union [Job , None ]:
77+ # FIXME(ryan): Update API to return job names as well as IDs.
78+ # For now we cache a mapping b/w job name and job id.
79+ if job_id := self ._jobs .get (name ):
80+ return self ._client .jobs .get (job_id )
81+
82+ for job_id in self .jobs :
83+ job = self ._client .jobs .get (job_id )
84+ if job .name == name :
85+ self ._jobs [name ] = job_id
86+ return job
87+ return None
88+
89+ @staticmethod
90+ def _parse_path (path : str ) -> tuple [str , str ]:
91+ first , second = path .split ("." , maxsplit = 1 )
92+ return first , second
93+
3594
3695class WorkflowsEndpoint :
3796 """Metafold workflows endpoint."""
@@ -61,7 +120,7 @@ def list(
61120 url = f"/projects/{ project_id } /workflows"
62121 payload = asdict (sort = sort , q = q )
63122 r : Response = self ._client .get (url , params = payload )
64- return [Workflow (** w ) for w in r .json ()]
123+ return [Workflow (client = cast ( "MetafoldClient" , self . _client ), ** w ) for w in r .json ()]
65124
66125 def get (self , workflow_id : str , project_id : Optional [str ] = None ) -> Workflow :
67126 """Get a workflow.
@@ -76,7 +135,7 @@ def get(self, workflow_id: str, project_id: Optional[str] = None) -> Workflow:
76135 project_id = self ._client .project_id (project_id )
77136 url = f"/projects/{ project_id } /workflows/{ workflow_id } "
78137 r : Response = self ._client .get (url )
79- return Workflow (** r .json ())
138+ return Workflow (client = cast ( "MetafoldClient" , self . _client ), ** r .json ())
80139
81140 def run (
82141 self , definition : str ,
@@ -110,7 +169,7 @@ def run(
110169 raise RuntimeError (
111170 f"Workflow failed to complete within { timeout } seconds"
112171 ) from e
113- return Workflow (** r .json ())
172+ return Workflow (client = cast ( "MetafoldClient" , self . _client ), ** r .json ())
114173
115174 def cancel (self , workflow_id : str , project_id : Optional [str ] = None ) -> Workflow :
116175 """Cancel a running workflow.
@@ -125,7 +184,7 @@ def cancel(self, workflow_id: str, project_id: Optional[str] = None) -> Workflow
125184 project_id = self ._client .project_id (project_id )
126185 url = f"/projects/{ project_id } /workflows/{ workflow_id } /cancel"
127186 r : Response = self ._client .post (url )
128- return Workflow (** r .json ())
187+ return Workflow (client = cast ( "MetafoldClient" , self . _client ), ** r .json ())
129188
130189 def delete (self , workflow_id : str , project_id : Optional [str ] = None ):
131190 """Delete a workflow.
0 commit comments