Skip to content

Commit f5690a7

Browse files
authored
enable strict check for arguments (#183)
* enable strict check for arguments
* remove arguments in input_data
1 parent 9438565 commit f5690a7

File tree

6 files changed

+118
-30
lines changed

6 files changed

+118
-30
lines changed

dpdispatcher/dp_cloud_server_context.py

Lines changed: 1 addition & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -251,25 +251,7 @@ def machine_subfields(cls) -> List[Argument]:
251251
Argument("email", str, optional=False, doc="Email"),
252252
Argument("password", str, optional=False, doc="Password"),
253253
Argument("program_id", int, optional=False, doc="Program ID"),
254-
Argument("input_data", dict, [
255-
Argument("job_name", str, optional=True, doc="Job name"),
256-
Argument("image_name", str, optional=True, doc="Name of the image which run the job, optional "
257-
"when platform is not ali/oss."),
258-
Argument("disk_size", str, optional=True, doc="disk size (GB), optional "
259-
"when platform is not ali/oss."),
260-
Argument("scass_type", str, optional=False, doc="machine configuration."),
261-
Argument("platform", str, optional=False, doc="Job run in which platform."),
262-
Argument("log_file", str, optional=True, doc="location of log file."),
263-
Argument('checkpoint_files', [str, list], optional=True, doc="location of checkpoint files when "
264-
"it is list type. record file "
265-
"changes when it is string value "
266-
"'sync_files'"),
267-
Argument('checkpoint_time', int, optional=True, default=15, doc='interval of checkpoint data been '
268-
'stored minimum 15.'),
269-
Argument('backward_files', list, optional=True, doc='which files to be uploaded to remote '
270-
'resources. Upload all the files when it is '
271-
'None or empty.')
272-
], optional=False, doc="Configuration of job"),
254+
Argument("input_data", dict, optional=False, doc="Configuration of job"),
273255
], doc=doc_remote_profile)]
274256

275257

dpdispatcher/machine.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,11 @@ def load_from_dict(cls, machine_dict):
120120
except KeyError as e:
121121
dlog.error(f"KeyError:batch_type; cls.subclasses_dict:{cls.subclasses_dict}; batch_type:{batch_type};")
122122
raise e
123+
# check dict
124+
base = cls.arginfo()
125+
machine_dict = base.normalize_value(machine_dict, trim_pattern="_*")
126+
base.check_value(machine_dict, strict=True)
127+
123128
context = BaseContext.load_from_dict(machine_dict)
124129
machine = machine_class(context=context)
125130
return machine

dpdispatcher/submission.py

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -439,6 +439,15 @@ def get_hash(self):
439439
def load_from_json(cls, json_file):
440440
with open(json_file, 'r') as f:
441441
task_dict = json.load(f)
442+
return cls.load_from_dict(task_dict)
443+
444+
@classmethod
445+
def load_from_dict(cls, task_dict: dict) -> "Task":
446+
# check dict
447+
base = cls.arginfo()
448+
task_dict = base.normalize_value(task_dict, trim_pattern="_*")
449+
base.check_value(task_dict, strict=True)
450+
442451
task = cls.deserialize(task_dict=task_dict)
443452
return task
444453

@@ -793,10 +802,15 @@ def load_from_json(cls, json_file):
793802

794803
@classmethod
795804
def load_from_dict(cls, resources_dict):
805+
# check dict
806+
base = cls.arginfo(detail_kwargs=False)
807+
resources_dict = base.normalize_value(resources_dict, trim_pattern="_*")
808+
base.check_value(resources_dict, strict=True)
809+
796810
return cls.deserialize(resources_dict=resources_dict)
797811

798812
@staticmethod
799-
def arginfo():
813+
def arginfo(detail_kwargs=True):
800814
doc_number_node = 'The number of node need for each `job`'
801815
doc_cpu_per_node = 'cpu numbers of each node assigned to each job.'
802816
doc_gpu_per_node = 'gpu numbers of each node assigned to each job.'
@@ -837,14 +851,20 @@ def arginfo():
837851
Argument("wait_time", [int, float], optional=True, doc=doc_wait_time, default=0)
838852
]
839853

840-
batch_variant = Variant(
841-
"batch_type",
842-
[machine.resources_arginfo() for machine in set(Machine.subclasses_dict.values())],
843-
optional=False,
844-
doc='The batch job system type loaded from machine/batch_type.',
845-
)
854+
if detail_kwargs:
855+
batch_variant = Variant(
856+
"batch_type",
857+
[machine.resources_arginfo() for machine in set(Machine.subclasses_dict.values())],
858+
optional=False,
859+
doc='The batch job system type loaded from machine/batch_type.',
860+
)
846861

847-
resources_format = Argument("resources", dict, resources_args, [batch_variant])
862+
resources_format = Argument("resources", dict, resources_args, [batch_variant])
863+
else:
864+
resources_args.append(
865+
Argument("kwargs", dict, optional=True, doc="Vary by different machines.")
866+
)
867+
resources_format = Argument("resources", dict, resources_args)
848868
return resources_format
849869

850870

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
with open(path.join(NAME, '_date.py'), 'w') as fp :
2020
fp.write('date = \'%s\'' % today)
2121

22-
install_requires=['paramiko', 'dargs>=0.2.6', 'requests', 'tqdm']
22+
install_requires=['paramiko', 'dargs>=0.2.9', 'requests', 'tqdm']
2323

2424
setuptools.setup(
2525
name=NAME,

tests/test_argcheck.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
import os
2+
import sys
3+
import unittest
4+
5+
sys.path.insert(0, os.path.abspath(
6+
os.path.join(os.path.dirname(__file__), '..')))
7+
__package__ = 'tests'
8+
9+
from .context import setUpModule
10+
from .context import Task, Resources, Machine
11+
12+
class TestJob(unittest.TestCase):
13+
14+
def test_machine_argcheck(self):
15+
norm_dict = Machine.load_from_dict({
16+
"batch_type": "slurm",
17+
"context_type": "local",
18+
"local_root": "./",
19+
"remote_root": "/some/path",
20+
}).serialize()
21+
expected_dict = {
22+
'batch_type': 'Slurm',
23+
'context_type': 'LocalContext',
24+
'local_root': './',
25+
'remote_root': '/some/path',
26+
'remote_profile': {},
27+
}
28+
self.assertDictEqual(norm_dict, expected_dict)
29+
30+
def test_resources_argcheck(self):
31+
norm_dict = Resources.load_from_dict({
32+
"number_node": 1,
33+
"cpu_per_node": 2,
34+
"gpu_per_node": 0,
35+
"queue_name": "haha",
36+
"group_size": 1,
37+
"envs": {
38+
"aa": "bb",
39+
},
40+
"kwargs": {
41+
"cc": True,
42+
}
43+
}).serialize()
44+
expected_dict = {'cpu_per_node': 2,
45+
'custom_flags': [],
46+
'envs': {'aa': 'bb'},
47+
'gpu_per_node': 0,
48+
'group_size': 1,
49+
'kwargs': {
50+
"cc": True,
51+
},
52+
'module_list': [],
53+
'module_purge': False,
54+
'module_unload_list': [],
55+
'number_node': 1,
56+
'para_deg': 1,
57+
'queue_name': 'haha',
58+
'source_list': [],
59+
'strategy': {'if_cuda_multi_devices': False, 'ratio_unfinished': 0.0},
60+
'wait_time': 0,
61+
}
62+
self.assertDictEqual(norm_dict, expected_dict)
63+
64+
def test_task_argcheck(self):
65+
norm_dict = Task.load_from_dict({
66+
"command": "ls",
67+
"task_work_path": "./",
68+
"forward_files": [],
69+
"backward_files": [],
70+
"outlog": "out",
71+
"errlog": "err",
72+
}).serialize()
73+
expected_dict = {
74+
"command": "ls",
75+
"task_work_path": "./",
76+
"forward_files": [],
77+
"backward_files": [],
78+
"outlog": "out",
79+
"errlog": "err",
80+
}
81+
self.assertDictEqual(norm_dict, expected_dict)

tests/test_class_machine_dispatch.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from .context import dargs
1717
from .context import DistributedShell, HDFSContext
1818
from .sample_class import SampleClass
19-
from dargs.dargs import ArgumentKeyError
19+
from dargs.dargs import ArgumentValueError
2020

2121
class TestMachineDispatch(unittest.TestCase):
2222
def setUp(self):
@@ -136,7 +136,7 @@ def test_context_err(self):
136136
'context_type': 'foo'
137137
}
138138
# with self.assertRaises(KeyError):
139-
with self.assertRaises(KeyError):
139+
with self.assertRaises(ArgumentValueError):
140140
Machine.load_from_dict(machine_dict=machine_dict)
141141

142142
def test_pbs(self):

0 commit comments

Comments
 (0)