Skip to content

Commit 8816361

Browse files
atif93gpengzhi
authored andcommitted
Fixing the Google Drive download bug (#275)
* Fixing the Google Drive download bug * adding docstring * Changing some imports
1 parent e29fb89 commit 8816361

File tree

2 files changed

+15
-2
lines changed

2 files changed

+15
-2
lines changed

texar/torch/data/data_utils.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
"read_words",
3434
"make_vocab",
3535
"count_file_lines",
36+
"get_filename"
3637
]
3738

3839
Py3 = sys.version_info[0] == 3
@@ -139,6 +140,9 @@ def _progress_hook(count, block_size, total_size):
139140
def _extract_google_drive_file_id(url: str) -> str:
140141
# id is between `/d/` and '/'
141142
url_suffix = url[url.find('/d/') + 3:]
143+
if url_suffix.find('/') == -1:
144+
# if there's no trailing '/'
145+
return url_suffix
142146
file_id = url_suffix[:url_suffix.find('/')]
143147
return file_id
144148

@@ -305,3 +309,12 @@ def _count_lines(fn):
305309
filenames = [filenames]
306310
num_lines = np.sum([_count_lines(fn) for fn in filenames]).item()
307311
return num_lines
312+
313+
314+
def get_filename(url: str) -> str:
315+
r"""Extracts the filename of the downloaded checkpoint file from the URL.
316+
"""
317+
if 'drive.google.com' in url:
318+
return _extract_google_drive_file_id(url)
319+
url, filename = os.path.split(url)
320+
return filename or os.path.basename(url)

texar/torch/modules/pretrained/pretrained_base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323
from torch import nn
2424

25-
from texar.torch.data.data_utils import maybe_download
25+
from texar.torch.data.data_utils import maybe_download, get_filename
2626
from texar.torch.hyperparams import HParams
2727
from texar.torch.module_base import ModuleBase
2828
from texar.torch.utils.types import MaybeList
@@ -200,7 +200,7 @@ def download_checkpoint(cls, pretrained_model_name: str,
200200

201201
if not cache_path.exists():
202202
if isinstance(download_path, str):
203-
filename = download_path.split('/')[-1]
203+
filename = get_filename(download_path)
204204
maybe_download(download_path, cache_path, extract=True)
205205

206206
# removing the compressed file

0 commit comments

Comments
 (0)