Skip to content

Commit 825b2c8

Browse files
committed
TYPMAINT: Type fixes #1
1 parent 8e081b4 commit 825b2c8

1 file changed

Lines changed: 70 additions & 22 deletions

File tree

executors/drmaa_executor.py

Lines changed: 70 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,7 @@
22
import drmaa
33
from drmaa_executor_plugin.drmaa_patches import PatchedSession as drmaaSession
44

5-
from typing import TYPE_CHECKING
6-
from typing import Optional, List
5+
from typing import TYPE_CHECKING, Optional, TypedDict, Generator, Dict, cast
76

87
from functools import wraps
98

@@ -28,6 +27,16 @@
2827
drmaa.JobState.FAILED: State.FAILED
2928
}
3029

30+
JobID = int
31+
_TaskInstanceKeyDict = TypedDict('_TaskInstanceKeyDict', {
32+
'dag_id': str,
33+
'task_id': str,
34+
'run_id': str,
35+
'try_number': int
36+
})
37+
38+
JobTrackingType = Dict[JobID, _TaskInstanceKeyDict]
39+
3140

3241
def check_started(method):
3342
'''
@@ -58,49 +67,78 @@ def __init__(self, max_concurrent_jobs: Optional[int] = None):
5867
# Not yet implemented
5968
self.max_concurrent_jobs: Optional[int] = max_concurrent_jobs
6069

70+
def iter_scheduled_jobs(
        self) -> "Generator[tuple[JobID, TaskInstanceKey], None, None]":
    """Yield ``(job_id, TaskInstanceKey)`` pairs for every tracked job."""
    tracked = self._get_or_create_job_ids()
    for job_id in tracked:
        yield job_id, TaskInstanceKey(**tracked[job_id])
77+
6178
@property
def active_jobs(self) -> int:
    """Number of jobs currently tracked in the `scheduler_job_ids` Variable."""
    return len(self._get_or_create_job_ids())


@active_jobs.setter
def active_jobs(self, val: int) -> None:
    """Ignore assignments to the derived ``active_jobs`` property.

    BUG FIX: the previous implementation did ``self.active_jobs = val``,
    which re-invokes this setter and recurses until ``RecursionError``.
    The value is always recomputed from the tracking Variable, so writes
    (e.g. from the base executor) are intentionally discarded.
    """
    # Intentional no-op: state lives in the `scheduler_job_ids` Variable.
85+
6586
# TODO: Make `scheduler_job_ids` configurable under [executor]
@provide_session
def _get_or_create_job_ids(
        self, session: "Optional[Session]" = None) -> "JobTrackingType":
    """Fetch the job-tracking mapping, creating the Variable if absent.

    Returns the mapping of scheduler job ids to serialized task-instance
    keys stored in the `scheduler_job_ids` Airflow Variable.

    NOTE(review): ``deserialize_json=True`` round-trips dict keys as
    strings, so the returned keys may be ``str`` rather than ``JobID``
    ints — confirm callers tolerate this.
    """
    tracked = Variable.get("scheduler_job_ids",
                           default_var=None,
                           deserialize_json=True)
    if tracked:
        return tracked

    # First run (or empty variable): initialize the tracking Variable.
    tracked = {}
    self.log.info("Setting up job tracking Airflow variable...")
    Variable.set("scheduler_job_ids",
                 tracked,
                 description="Scheduler Job ID tracking",
                 serialize_json=True,
                 session=session)
    self.log.info("Created `scheduler_job_ids` variable")
    return tracked
81106

82107
@provide_session
def _update_job_tracker(self,
                        jobs: "JobTrackingType",
                        session: "Optional[Session]" = None) -> None:
    """Persist the full job-tracking mapping to `scheduler_job_ids`."""
    Variable.update(
        "scheduler_job_ids", jobs, serialize_json=True, session=session)
91115

92-
def _drop_from_tracking(self, job_id: "JobID") -> None:
    """Remove ``job_id`` from the `scheduler_job_ids` tracking Variable.

    Logs an error (without raising) when the job is not being tracked.
    """
    self.log.info(
        f"Removing Job {job_id} from tracking variable `scheduler_job_ids`"
    )

    jobs = self._get_or_create_job_ids()
    if job_id not in jobs:
        # NOTE(review): JSON deserialization may yield str keys while
        # job_id is an int — confirm key types match in practice.
        self.log.error(f"Failed to remove {job_id}, job was not"
                       " being tracked by Airflow!")
        return

    del jobs[job_id]
    self._update_job_tracker(jobs)
    self.log.info(
        f"Successfully removed {job_id} from `scheduler_job_ids`")
98131

99-
def _push_to_tracking(self, job_id: "JobID", key: "TaskInstanceKey") -> None:
    """Register a submitted job in the `scheduler_job_ids` Variable.

    Args:
        job_id: Scheduler job id returned by DRMAA on submission.
        key: Airflow ``TaskInstanceKey`` identifying the submitted task.
    """
    # BUG FIX: both log messages lacked the f-prefix and logged the
    # literal text "{job_id}" instead of the actual job id.
    self.log.info(
        f"Adding Job {job_id} to tracking variable `scheduler_job_ids`")

    # Convert TaskInstanceKey to serializable form
    key_dict = _taskkey_to_dict(key)

    current_jobs = self._get_or_create_job_ids()
    current_jobs[job_id] = key_dict
    self._update_job_tracker(current_jobs)
    self.log.info(f"Successfully added {job_id} to `scheduler_job_ids`")
106144

@@ -114,7 +152,7 @@ def start(self) -> None:
114152
current_jobs = self._get_or_create_job_ids()
115153

116154
if current_jobs:
117-
print_jobs = "\n".join([j_id for j_id in current_jobs["jobs"]])
155+
print_jobs = "\n".join([f"{j_id}" for j_id in current_jobs])
118156
self.log.info(f"Jobs from previous session:\n{print_jobs}")
119157
else:
120158
self.log.info("No jobs are currently being tracked")
@@ -127,6 +165,7 @@ def end(self) -> None:
127165
self.log.info("Terminating DRMAA session")
128166
self.session.exit()
129167

168+
@check_started
130169
def sync(self) -> None:
131170
"""
132171
Called periodically by `airflow.executors.base_executor.BaseExecutor`'s
@@ -136,9 +175,8 @@ def sync(self) -> None:
136175
"""
137176

138177
# Go through currently running jobs and update state
139-
scheduled_jobs = self._get_or_create_job_ids()
140-
for job_id in scheduled_jobs:
141-
drmaa_status = self.jobStatus(job_id)
178+
for job_id, task_instance_key in self.iter_scheduled_jobs():
179+
drmaa_status = self.session.jobStatus(job_id)
142180
try:
143181
status = JOB_STATE_MAP[drmaa_status]
144182
except KeyError:
@@ -147,7 +185,8 @@ def sync(self) -> None:
147185
" Cannot be mapped into an Airflow TaskInstance State"
148186
" Will try again in next sync attempt...")
149187
else:
150-
self.change_state(status, None)
188+
# Need taskinstancekey
189+
self.change_state(task_instance_key, status)
151190
self._drop_from_tracking(job_id)
152191

153192
@check_started
@@ -162,7 +201,7 @@ def execute_async(self,
162201

163202
self.log.info(f"Submitting job {key} with command {command} with"
164203
f" configuration options:\n{executor_config})")
165-
jt = executor_config.drm2drmaa(self.session.createJobTemplate())
204+
jt = executor_config.get_drmaa_config(self.session.createJobTemplate())
166205

167206
# CommandType always begins with "airflow" binary command
168207
jt.remoteCommand = command[0]
@@ -174,9 +213,18 @@ def execute_async(self,
174213
job_id = self.session.runJob(jt)
175214

176215
self.log.info(f"Submitted Job {job_id}")
177-
self._push_to_tracking(job_id)
216+
self._push_to_tracking(job_id, key)
178217

179218
# Prevent memory leaks on C back-end, running jobs unaffected
180219
# https://drmaa-python.readthedocs.io/en/latest/drmaa.html
181220
self.session.deleteJobTemplate(jt)
182221
self.jobs_submitted += 1
222+
223+
224+
def _taskkey_to_dict(key: TaskInstanceKey) -> _TaskInstanceKeyDict:
225+
return {
226+
"dag_id": key.dag_id,
227+
"task_id": key.task_id,
228+
"run_id": key.run_id,
229+
"try_number": key.try_number
230+
}

0 commit comments

Comments
 (0)