22import drmaa
33from drmaa_executor_plugin .drmaa_patches import PatchedSession as drmaaSession
44
5- from typing import TYPE_CHECKING
6- from typing import Optional , List
5+ from typing import TYPE_CHECKING , Optional , TypedDict , Generator , Dict , cast
76
87from functools import wraps
98
2827 drmaa .JobState .FAILED : State .FAILED
2928}
3029
30+ JobID = int
31+ _TaskInstanceKeyDict = TypedDict ('_TaskInstanceKeyDict' , {
32+ 'dag_id' : str ,
33+ 'task_id' : str ,
34+ 'run_id' : str ,
35+ 'try_number' : int
36+ })
37+
38+ JobTrackingType = Dict [JobID , _TaskInstanceKeyDict ]
39+
3140
3241def check_started (method ):
3342 '''
@@ -58,49 +67,78 @@ def __init__(self, max_concurrent_jobs: Optional[int] = None):
5867 # Not yet implemented
5968 self .max_concurrent_jobs : Optional [int ] = max_concurrent_jobs
6069
70+ def iter_scheduled_jobs (
71+ self ) -> Generator [tuple [JobID , TaskInstanceKey ], None , None ]:
72+ '''
73+ Iterate over scheduled jobs
74+ '''
75+ for job_id , instance_info in self ._get_or_create_job_ids ().items ():
76+ yield job_id , TaskInstanceKey (** instance_info )
77+
6178 @property
6279 def active_jobs (self ) -> int :
6380 return len (self ._get_or_create_job_ids ())
6481
82+ @active_jobs .setter
83+ def active_jobs (self , val ):
84+ self .active_jobs = val
85+
6586 # TODO: Make `scheduler_job_ids` configurable under [executor]
6687 @provide_session
6788 def _get_or_create_job_ids (self ,
68- session : Optional [Session ] = None ) -> List [int ]:
89+ session : Optional [Session ] = None
90+ ) -> JobTrackingType :
91+
6992 current_jobs = Variable .get ("scheduler_job_ids" ,
7093 default_var = None ,
7194 deserialize_json = True )
7295 if not current_jobs :
96+ current_jobs = {}
7397 self .log .info ("Setting up job tracking Airflow variable..." )
74- Variable .set ("scheduler_job_ids" , {"jobs" : []},
98+ Variable .set ("scheduler_job_ids" ,
99+ current_jobs ,
75100 description = "Scheduler Job ID tracking" ,
76101 serialize_json = True ,
77102 session = session )
78103 self .log .info ("Created `scheduler_job_ids` variable" )
79104
80- return current_jobs [ "jobs" ]
105+ return current_jobs
81106
82107 @provide_session
83108 def _update_job_tracker (self ,
84- jobs : List [ int ] ,
109+ jobs : JobTrackingType ,
85110 session : Optional [Session ] = None ) -> None :
86- write_json = {"jobs" : jobs }
87111 Variable .update ("scheduler_job_ids" ,
88- write_json ,
112+ jobs ,
89113 serialize_json = True ,
90114 session = session )
91115
92- def _drop_from_tracking (self , job_id : int ) -> None :
116+ def _drop_from_tracking (self , job_id : JobID ) -> None :
93117 self .log .info (
94- "Removing Job {job_id} from tracking variable `scheduler_job_ids`" )
95- new_state = [j for j in self ._get_or_create_job_ids () if j != job_id ]
96- self ._update_job_tracker (new_state )
97- self .log .info ("Successfully removed {job_id} from `scheduler_job_ids`" )
118+ f"Removing Job { job_id } from tracking variable `scheduler_job_ids`"
119+ )
120+
121+ jobs = self ._get_or_create_job_ids ()
122+ try :
123+ jobs .pop (job_id )
124+ except KeyError :
125+ self .log .error (f"Failed to remove { job_id } , job was not"
126+ " being tracked by Airflow!" )
127+ else :
128+ self ._update_job_tracker (jobs )
129+ self .log .info (
130+ f"Successfully removed { job_id } from `scheduler_job_ids`" )
98131
99- def _push_to_tracking (self , job_id : int ) -> None :
132+ def _push_to_tracking (self , job_id : JobID , key : TaskInstanceKey ) -> None :
100133 self .log .info (
101134 "Adding Job {job_id} to tracking variable `scheduler_job_ids`" )
135+
136+ # Convert TaskInstanceKey to serializable form
137+ key_dict = _taskkey_to_dict (key )
138+ entry = {job_id : key_dict }
139+
102140 current_jobs = self ._get_or_create_job_ids ()
103- current_jobs .append ( job_id )
141+ current_jobs .update ( entry )
104142 self ._update_job_tracker (current_jobs )
105143 self .log .info ("Successfully added {job_id} to `scheduler_job_ids`" )
106144
@@ -114,7 +152,7 @@ def start(self) -> None:
114152 current_jobs = self ._get_or_create_job_ids ()
115153
116154 if current_jobs :
117- print_jobs = "\n " .join ([j_id for j_id in current_jobs [ "jobs" ] ])
155+ print_jobs = "\n " .join ([f" { j_id } " for j_id in current_jobs ])
118156 self .log .info (f"Jobs from previous session:\n { print_jobs } " )
119157 else :
120158 self .log .info ("No jobs are currently being tracked" )
@@ -127,6 +165,7 @@ def end(self) -> None:
127165 self .log .info ("Terminating DRMAA session" )
128166 self .session .exit ()
129167
168+ @check_started
130169 def sync (self ) -> None :
131170 """
132171 Called periodically by `airflow.executors.base_executor.BaseExecutor`'s
@@ -136,9 +175,8 @@ def sync(self) -> None:
136175 """
137176
138177 # Go through currently running jobs and update state
139- scheduled_jobs = self ._get_or_create_job_ids ()
140- for job_id in scheduled_jobs :
141- drmaa_status = self .jobStatus (job_id )
178+ for job_id , task_instance_key in self .iter_scheduled_jobs ():
179+ drmaa_status = self .session .jobStatus (job_id )
142180 try :
143181 status = JOB_STATE_MAP [drmaa_status ]
144182 except KeyError :
@@ -147,7 +185,8 @@ def sync(self) -> None:
147185 " Cannot be mapped into an Airflow TaskInstance State"
148186 " Will try again in next sync attempt..." )
149187 else :
150- self .change_state (status , None )
188+ # Need taskinstancekey
189+ self .change_state (task_instance_key , status )
151190 self ._drop_from_tracking (job_id )
152191
153192 @check_started
@@ -162,7 +201,7 @@ def execute_async(self,
162201
163202 self .log .info (f"Submitting job { key } with command { command } with"
164203 f" configuration options:\n { executor_config } )" )
165- jt = executor_config .drm2drmaa (self .session .createJobTemplate ())
204+ jt = executor_config .get_drmaa_config (self .session .createJobTemplate ())
166205
167206 # CommandType always begins with "airflow" binary command
168207 jt .remoteCommand = command [0 ]
@@ -174,9 +213,18 @@ def execute_async(self,
174213 job_id = self .session .runJob (jt )
175214
176215 self .log .info (f"Submitted Job { job_id } " )
177- self ._push_to_tracking (job_id )
216+ self ._push_to_tracking (job_id , key )
178217
179218 # Prevent memory leaks on C back-end, running jobs unaffected
180219 # https://drmaa-python.readthedocs.io/en/latest/drmaa.html
181220 self .session .deleteJobTemplate (jt )
182221 self .jobs_submitted += 1
222+
223+
def _taskkey_to_dict(key: TaskInstanceKey) -> _TaskInstanceKeyDict:
    """Convert a TaskInstanceKey into a JSON-serializable dict."""
    fields = ("dag_id", "task_id", "run_id", "try_number")
    return cast(_TaskInstanceKeyDict,
                {name: getattr(key, name) for name in fields})
0 commit comments