|
4 | 4 | from dpdispatcher.base_context import BaseContext |
5 | 5 | import os, paramiko, tarfile, time |
6 | 6 | import uuid |
| 7 | +import shutil |
| 8 | +from functools import lru_cache |
7 | 9 | from glob import glob |
8 | 10 | from dpdispatcher import dlog |
9 | 11 | from dargs.dargs import Argument |
10 | 12 | from typing import List |
11 | 13 | import pathlib |
12 | 14 | # from dpdispatcher.submission import Machine |
13 | | -from dpdispatcher.utils import get_sha256, generate_totp |
| 15 | +from dpdispatcher.utils import get_sha256, generate_totp, rsync |
14 | 16 |
|
15 | 17 | class SSHSession (object): |
16 | 18 | def __init__(self, |
@@ -175,6 +177,27 @@ def arginfo(): |
175 | 177 | ssh_remote_profile_format = Argument("ssh_session", dict, ssh_remote_profile_args) |
176 | 178 | return ssh_remote_profile_format |
177 | 179 |
|
| 180 | + def put(self, from_f, to_f): |
| 181 | + if self.rsync_available: |
| 182 | + return rsync(from_f, self.remote + ":" + to_f) |
| 183 | + return self.sftp.put(from_f, to_f) |
| 184 | + |
| 185 | + def get(self, from_f, to_f): |
| 186 | + if self.rsync_available: |
| 187 | + return rsync(self.remote + ":" + from_f, to_f) |
| 188 | + return self.sftp.get(from_f, to_f) |
| 189 | + |
| 190 | + @property |
| 191 | + @lru_cache(maxsize=None) |
| 192 | + def rsync_available(self) -> bool: |
| 193 | + return (shutil.which("rsync") is not None and self.password is None |
| 194 | + and self.port == 22 and self.key_filename is None |
| 195 | + and self.passphrase is None) |
| 196 | + |
| 197 | + @property |
| 198 | + def remote(self) -> str: |
| 199 | + return "%s@%s" % (self.username, self.hostname) |
| 200 | + |
178 | 201 |
|
179 | 202 | class SSHContext(BaseContext): |
180 | 203 | def __init__ (self, |
@@ -519,7 +542,7 @@ def _put_files(self, |
519 | 542 | from_f = pathlib.PurePath(os.path.join(self.local_root, of)).as_posix() |
520 | 543 | to_f = pathlib.PurePath(os.path.join(self.remote_root, of)).as_posix() |
521 | 544 | try: |
522 | | - self.sftp.put(from_f, to_f) |
| 545 | + self.ssh_session.put(from_f, to_f) |
523 | 546 | except FileNotFoundError: |
524 | 547 | raise FileNotFoundError("from %s to %s @ %s : %s Error!"%(from_f, self.ssh_session.username, self.ssh_session.hostname, to_f)) |
525 | 548 | # remote extract |
@@ -547,7 +570,7 @@ def _get_files(self, |
547 | 570 | to_f = pathlib.PurePath(os.path.join(self.local_root, of)).as_posix() |
548 | 571 | if os.path.isfile(to_f) : |
549 | 572 | os.remove(to_f) |
550 | | - self.sftp.get(from_f, to_f) |
| 573 | + self.ssh_session.get(from_f, to_f) |
551 | 574 | # extract |
552 | 575 | cwd = os.getcwd() |
553 | 576 | os.chdir(self.local_root) |
|
0 commit comments