Skip to content

Commit 364eef9

Browse files
authored
Merge pull request #125 from felix5572/master
2 parents 6b264c7 + 12b276a commit 364eef9

File tree

7 files changed

+109
-17
lines changed

7 files changed

+109
-17
lines changed

dpdispatcher/submission.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -156,14 +156,19 @@ def run_submission(self, *, exit_on_submit=False, clean=True):
156156
if not self.belonging_jobs:
157157
self.generate_jobs()
158158
self.try_recover_from_json()
159+
self.update_submission_state()
159160
if self.check_all_finished():
160161
dlog.info('info:check_all_finished: True')
161162
else:
162163
dlog.info('info:check_all_finished: False')
163164
self.upload_jobs()
164165
self.handle_unexpected_submission_state()
165166
self.submission_to_json()
166-
time.sleep(1)
167+
time.sleep(1)
168+
self.update_submission_state()
169+
self.check_all_finished()
170+
self.handle_unexpected_submission_state()
171+
167172
while not self.check_all_finished():
168173
if exit_on_submit is True:
169174
dlog.info(f"submission succeeded: {self.submission_hash}")
@@ -179,6 +184,7 @@ def run_submission(self, *, exit_on_submit=False, clean=True):
179184
dlog.debug(self.serialize())
180185
raise e
181186
else:
187+
self.update_submission_state()
182188
self.handle_unexpected_submission_state()
183189
finally:
184190
pass
@@ -189,7 +195,7 @@ def run_submission(self, *, exit_on_submit=False, clean=True):
189195
self.clean_jobs()
190196
return self.serialize()
191197

192-
def get_submission_state(self):
198+
def update_submission_state(self):
193199
"""check whether all the jobs in the submission.
194200
195201
Notes
@@ -201,7 +207,7 @@ def get_submission_state(self):
201207
# finished job will be finished for ever, skip
202208
continue
203209
job.get_job_state()
204-
dlog.debug(f"debug:get_submission_state: job: {job.job_hash}, {job.job_id}, {repr(job.job_state)}")
210+
dlog.debug(f"debug:update_submission_state: job: {job.job_hash}, {job.job_id}, {job.job_state}")
205211
# self.submission_to_json()
206212

207213
def handle_unexpected_submission_state(self):
@@ -217,9 +223,10 @@ def handle_unexpected_submission_state(self):
217223
self.submission_to_json()
218224
raise RuntimeError(
219225
f"Meet errors will handle unexpected submission state.\n"
220-
f"Debug information: remote_root=={self.remote_root}.\n"
226+
f"Debug information: remote_root=={self.machine.context.remote_root}.\n"
221227
f"Debug information: submission_hash=={self.submission_hash}.\n"
222228
f"Please check the dirs and scripts in remote_root"
229+
f"The job information mentioned above may help"
223230
) from e
224231

225232
# not used here, submitting job is in handle_unexpected_submission_state.
@@ -231,14 +238,16 @@ def handle_unexpected_submission_state(self):
231238
# job.submit_job()
232239
# self.get_submission_state()
233240

241+
# def update_submi
242+
234243
def check_all_finished(self):
235244
"""check whether all the jobs in the submission.
236245
237246
Notes
238247
-----
239248
This method will not handle unexpected job state in the submission.
240249
"""
241-
self.get_submission_state()
250+
# self.update_submission_state()
242251
if any( (job.job_state in [JobStatus.terminated, JobStatus.unknown] ) for job in self.belonging_jobs):
243252
self.submission_to_json()
244253
if any( (job.job_state in [JobStatus.running,
@@ -294,7 +303,7 @@ def clean_jobs(self):
294303
self.machine.context.clean()
295304

296305
def submission_to_json(self):
297-
# self.get_submission_state()
306+
# self.update_submission_state()
298307
write_str = json.dumps(self.serialize(), indent=4, default=str)
299308
submission_file_name = "{submission_hash}.json".format(submission_hash=self.submission_hash)
300309
self.machine.context.write_file(submission_file_name, write_str=write_str)
@@ -532,11 +541,10 @@ def handle_unexpected_job_state(self):
532541
raise RuntimeError(f"job:{self.job_hash} {self.job_id} failed {self.fail_count} times.job_detail:{self}")
533542
self.submit_job()
534543
dlog.info("job:{job_hash} re-submit after terminated; new job_id is {job_id}".format(job_hash=self.job_hash, job_id=self.job_id))
544+
time.sleep(0.2)
535545
self.get_job_state()
536-
dlog.info("job:{job_hash} job_id:{job_id} after re-submitting; the state now is {job_state}".format(
537-
job_hash=self.job_hash,
538-
job_id=self.job_id,
539-
job_state=JobStatus(self.job_state)))
546+
dlog.info(f"job:{self.job_hash} job_id:{self.job_id} after re-submitting; the state now is {repr(self.job_state)}")
547+
self.handle_unexpected_job_state()
540548

541549
if job_state == JobStatus.unsubmitted:
542550
dlog.info(f"job: {self.job_hash} unsubmitted; submit it")
@@ -758,3 +766,5 @@ def arginfo():
758766
resources_format = Argument("resources", dict, resources_args)
759767
return resources_format
760768

769+
770+
# %%

tests/graph.pb

-1.55 MB
Binary file not shown.
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
{
2+
"machine":{
3+
"batch_type": "Shell",
4+
"context_type": "LazyLocalContext",
5+
"local_root": "./test_shell_trival_dir"
6+
},
7+
"resources":{
8+
"number_node": 1,
9+
"cpu_per_node": 4,
10+
"gpu_per_node": 0,
11+
"queue_name": "CPU",
12+
"group_size": 2
13+
}
14+
}

tests/test_class_submission.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,9 @@ def test_submission_to_json(self):
5151
pass
5252

5353
@patch('dpdispatcher.Submission.submission_to_json')
54-
@patch('dpdispatcher.Submission.get_submission_state')
55-
def test_check_all_finished(self, patch_get_submission_state, patch_submission_to_json):
56-
patch_get_submission_state = MagicMock(return_value=None)
54+
@patch('dpdispatcher.Submission.update_submission_state')
55+
def test_check_all_finished(self, patch_update_submission_state, patch_submission_to_json):
56+
patch_update_submission_state = MagicMock(return_value=None)
5757
patch_submission_to_json = MagicMock(return_value=None)
5858

5959
self.submission.belonging_jobs[0].job_state = JobStatus.running

tests/test_shell_trival.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,41 @@ def test_shell_trival(self):
4848
f2 = os.path.join('test_shell_trival_dir/', 'parent_dir/', dir, 'out.txt')
4949
self.assertEqual(get_file_md5(f1), get_file_md5(f2))
5050

51+
def test_shell_fail(self):
52+
with open('jsons/machine_local_shell.json', 'r') as f:
53+
machine_dict = json.load(f)
54+
55+
machine = Machine(**machine_dict['machine'])
56+
resources = Resources(**machine_dict['resources'])
57+
58+
task = Task(command='cat mock_fail_task.txt && exit 1',
59+
task_work_path='./',
60+
forward_files=[],
61+
backward_files=['out.txt'],
62+
outlog='out.txt')
63+
64+
task_list = [task,]
65+
66+
submission = Submission(work_base='fail_dir/',
67+
machine=machine,
68+
resources=resources,
69+
forward_common_files=[],
70+
backward_common_files=[],
71+
task_list=task_list
72+
)
73+
with self.assertRaises(RuntimeError):
74+
submission.run_submission()
75+
76+
def test_shell_recover(self):
77+
with open('jsons/machine_lazylocal_shell.json', 'r') as f:
78+
machine_dict = json.load(f)
79+
80+
machine = Machine(**machine_dict['machine'])
81+
resources = Resources(**machine_dict['resources'])
82+
83+
pass
84+
85+
5186
@classmethod
5287
def tearDownClass(cls):
5388
shutil.rmtree('tmp_shell_trival_dir/')
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# mock file for unittest; test when dpdispatcher meets fail task

tests/test_ssh_context.py

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,47 @@
11
import os,sys,json,glob,shutil,uuid,getpass
22
import unittest
3-
from pathlib import Path
3+
import pathlib
4+
from paramiko.ssh_exception import NoValidConnectionsError
5+
from paramiko.ssh_exception import SSHException
46

57
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
68
__package__ = 'tests'
79
from .context import SSHContext, SSHSession
8-
from .context import setUpModule
10+
from .context import Machine
11+
from .sample_class import SampleClass
912

1013
class TestSSHContext(unittest.TestCase):
14+
@classmethod
15+
def setUpClass(cls):
16+
with open('jsons/machine_ali_ehpc.json', 'r') as f:
17+
mdata = json.load(f)
18+
try:
19+
cls.machine = Machine.load_from_dict(mdata['machine'])
20+
except SSHException:
21+
raise unittest.SkipTest("SSHException ssh cannot connect")
22+
cls.submission = SampleClass.get_sample_submission()
23+
cls.submission.bind_machine(cls.machine)
24+
cls.submission_hash = cls.submission.submission_hash
25+
file_list = ['bct-1/log.lammps', 'bct-2/log.lammps', 'bct-3/log.lammps', 'bct-4/log.lammps']
26+
for file in file_list:
27+
cls.machine.context.write_file(file, '# mock log')
28+
1129
def setUp(self):
12-
self.tmp_local_root = 'test_context_dir/'
13-
self.tmp_remote_root = 'tmp_ssh_context_dir/'
30+
self.context = self.__class__.machine.context
31+
32+
def test_ssh_session(self):
33+
self.assertIsInstance(
34+
self.__class__.machine.context.ssh_session, SSHSession
35+
)
36+
37+
def test_upload(self):
38+
self.context.upload(self.__class__.submission)
39+
check_file_list = ['graph.pb', 'bct-1/conf.lmp', 'bct-4/input.lammps']
40+
for file in check_file_list:
41+
self.assertTrue(self.context.check_file_exists(os.path.join(self.context.remote_root, file)))
42+
43+
def test_download(self):
44+
self.context.download(self.__class__.submission)
45+
1446

1547

0 commit comments

Comments
 (0)