Skip to content

Commit 1e9eb06

Browse files
authored
improve _get_files in SSH context (#473)
1 parent 6251619 commit 1e9eb06

File tree

2 files changed

+62
-49
lines changed

2 files changed

+62
-49
lines changed

dpdispatcher/contexts/ssh_context.py

Lines changed: 52 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -295,7 +295,11 @@ def exec_command(self, cmd):
295295
assert self.ssh is not None
296296
try:
297297
return self.ssh.exec_command(cmd)
298-
except (paramiko.ssh_exception.SSHException, socket.timeout, EOFError) as e:
298+
except (
299+
paramiko.ssh_exception.SSHException,
300+
socket.timeout,
301+
EOFError,
302+
) as e:
299303
# SSH session not active
300304
# retry for up to 3 times
301305
# ensure alive
@@ -355,10 +359,18 @@ def arginfo():
355359
),
356360
Argument("timeout", int, optional=True, default=10, doc=doc_timeout),
357361
Argument(
358-
"totp_secret", str, optional=True, default=None, doc=doc_totp_secret
362+
"totp_secret",
363+
str,
364+
optional=True,
365+
default=None,
366+
doc=doc_totp_secret,
359367
),
360368
Argument(
361-
"tar_compress", bool, optional=True, default=True, doc=doc_tar_compress
369+
"tar_compress",
370+
bool,
371+
optional=True,
372+
default=True,
373+
doc=doc_tar_compress,
362374
),
363375
Argument(
364376
"look_for_keys",
@@ -603,7 +615,10 @@ def upload(
603615
directory_list,
604616
)
605617
self._walk_directory(
606-
submission.forward_common_files, self.local_root, file_list, directory_list
618+
submission.forward_common_files,
619+
self.local_root,
620+
file_list,
621+
directory_list,
607622
)
608623

609624
# convert to relative path to local_root
@@ -736,7 +751,8 @@ def download(
736751
file_list.extend(submission.backward_common_files)
737752
if len(file_list) > 0:
738753
self._get_files(
739-
file_list, tar_compress=self.remote_profile.get("tar_compress", None)
754+
file_list,
755+
tar_compress=self.remote_profile.get("tar_compress", None),
740756
)
741757

742758
def block_checkcall(self, cmd, asynchronously=False, stderr_whitelist=None):
@@ -793,18 +809,23 @@ def write_file(self, fname, write_str):
793809
fname = pathlib.PurePath(os.path.join(self.remote_root, fname)).as_posix()
794810
# to prevent old file from being overwritten but cancelled, create a temporary file first
795811
# when it is fully written, rename it to the original file name
796-
with self.sftp.open(fname + "~", "w") as fp:
797-
fp.write(write_str)
812+
temp_fname = fname + "_tmp"
813+
try:
814+
with self.sftp.open(temp_fname, "w") as fp:
815+
fp.write(write_str)
816+
# Rename the temporary file
817+
self.block_checkcall(f"mv {shlex.quote(temp_fname)} {shlex.quote(fname)}")
798818
# sftp.rename may throw OSError
799-
self.block_checkcall(
800-
"mv {} {}".format(shlex.quote(fname + "~"), shlex.quote(fname))
801-
)
819+
except OSError as e:
820+
dlog.exception(f"Error writing to file {fname}")
821+
raise e
802822

803823
def read_file(self, fname):
804824
assert self.remote_root is not None
805825
self.ssh_session.ensure_alive()
806826
with self.sftp.open(
807-
pathlib.PurePath(os.path.join(self.remote_root, fname)).as_posix(), "r"
827+
pathlib.PurePath(os.path.join(self.remote_root, fname)).as_posix(),
828+
"r",
808829
) as fp:
809830
ret = fp.read().decode("utf-8")
810831
return ret
@@ -945,36 +966,28 @@ def _get_files(self, files, tar_compress=True):
945966
per_nfile = 100
946967
ntar = len(files) // per_nfile + 1
947968
if ntar <= 1:
948-
try:
949-
self.block_checkcall(
950-
"tar {} {} {}".format(
951-
tar_command,
952-
shlex.quote(of),
953-
" ".join([shlex.quote(file) for file in files]),
954-
)
955-
)
956-
except RuntimeError as e:
957-
if "No such file or directory" in str(e):
958-
raise FileNotFoundError(
959-
"Any of the backward files does not exist in the remote directory."
960-
) from e
961-
raise e
969+
file_list = " ".join([shlex.quote(file) for file in files])
970+
tar_cmd = f"tar {tar_command} {shlex.quote(of)} {file_list}"
962971
else:
963-
file_list_file = os.path.join(
964-
self.remote_root, ".tmp.tar." + str(uuid.uuid4())
965-
)
972+
file_list_file = pathlib.PurePath(
973+
os.path.join(self.remote_root, f".tmp_tar_{uuid.uuid4()}")
974+
).as_posix()
966975
self.write_file(file_list_file, "\n".join(files))
967-
try:
968-
self.block_checkcall(
969-
f"tar {tar_command} {shlex.quote(of)} -T {shlex.quote(file_list_file)}"
970-
)
971-
except RuntimeError as e:
972-
if "No such file or directory" in str(e):
973-
raise FileNotFoundError(
974-
"Any of the backward files does not exist in the remote directory."
975-
) from e
976-
raise e
977-
# trans
976+
tar_cmd = (
977+
f"tar {tar_command} {shlex.quote(of)} -T {shlex.quote(file_list_file)}"
978+
)
979+
980+
# Execute the tar command remotely
981+
try:
982+
self.block_checkcall(tar_cmd)
983+
except RuntimeError as e:
984+
if "No such file or directory" in str(e):
985+
raise FileNotFoundError(
986+
"Backward files do not exist in the remote directory."
987+
) from e
988+
raise e
989+
990+
# Transfer the archive from remote to local
978991
from_f = pathlib.PurePath(os.path.join(self.remote_root, of)).as_posix()
979992
to_f = pathlib.PurePath(os.path.join(self.local_root, of)).as_posix()
980993
if os.path.isfile(to_f):

dpdispatcher/machines/pbs.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -215,12 +215,13 @@ def gen_script_header(self, job):
215215
# resources.number_node is not used in SGE
216216
resources = job.resources
217217
job_name = resources.kwargs.get("job_name", "wDPjob")
218-
sge_pe_name = resources.kwargs.get("sge_pe_name", "mpi")
218+
pe_name = resources.kwargs.get("pe_name", "mpi")
219219
sge_script_header_dict = {}
220220
sge_script_header_dict["select_node_line"] = f"#$ -N {job_name}\n"
221221
sge_script_header_dict["select_node_line"] += (
222-
f"#$ -pe {sge_pe_name} {resources.cpu_per_node}\n"
222+
f"#$ -pe {pe_name} {resources.cpu_per_node}\n"
223223
)
224+
224225
if resources.queue_name != "":
225226
sge_script_header_dict["select_node_line"] += (
226227
f"#$ -q {resources.queue_name}"
@@ -266,8 +267,7 @@ def check_status(self, job):
266267
err_str = stderr.read().decode("utf-8")
267268
if ret != 0:
268269
raise RuntimeError(
269-
"status command qstat fails to execute. erro info: %s return code %d"
270-
% (err_str, ret)
270+
f"status command qstat fails to execute. erro info: {err_str} return code {ret}"
271271
)
272272
status_text_list = stdout.read().decode("utf-8").split("\n")
273273
for txt in status_text_list:
@@ -280,8 +280,7 @@ def check_status(self, job):
280280
if self.check_finish_tag(job=job):
281281
return JobStatus.finished
282282
dlog.info(
283-
"not tag_finished detected, execute sync command and wait. count "
284-
+ str(count)
283+
f"not tag_finished detected, execute sync command and wait. count {count}"
285284
)
286285
self.context.block_call("sync")
287286
import time
@@ -307,15 +306,15 @@ def check_finish_tag(self, job):
307306
def resources_subfields(cls) -> List[Argument]:
308307
"""Generate the resources subfields.
309308
310-
sge_pe_name : str
309+
pe_name : str
311310
The parallel environment name of SGE.
312311
313312
Returns
314313
-------
315314
list[Argument]
316315
resources subfields
317316
"""
318-
doc_sge_pe_name = "The parallel environment name of SGE."
317+
doc_pe_name = "The parallel environment name of SGE system."
319318
doc_job_name = "The name of SGE's job."
320319

321320
return [
@@ -324,11 +323,12 @@ def resources_subfields(cls) -> List[Argument]:
324323
dict,
325324
[
326325
Argument(
327-
"sge_pe_name",
326+
"pe_name",
328327
str,
329328
optional=True,
330329
default="mpi",
331-
doc=doc_sge_pe_name,
330+
doc=doc_pe_name,
331+
alias=["sge_pe_name"],
332332
),
333333
Argument(
334334
"job_name",

0 commit comments

Comments
 (0)