Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "rustyface"
version = "0.1.2"
version = "0.1.3"
authors = ["Xinyu Bao <baoxinyuworks@163.com>"]
edition = "2021"
description = "A Huggingface downloading CLI tool written in Rust."
Expand Down
7 changes: 7 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,13 @@ rustyface --repository sentence-transformers/all-MiniLM-L6-v2 --tasks 4
- `--repository` is followed by the `repo_id` of the repository that you want to download from HuggingFace.
- `--tasks` is followed by the number of concurrent downloads. For example, 4 means downloading 4 files at once. It is recommended to use a lower number if your network conditions do not support higher concurrency.

# Configurations

To use an alternative mirror, or the official HuggingFace base URL, set the `HF_ENDPOINT` environment variable as shown below. If it is not set, `https://hf-mirror.com` is used.

```bash
export HF_ENDPOINT=https://your-endpoint.com
```

# Feedback & Further Development
Any participation is appreciated! Feel free to submit an issue, discussion or pull request. You can find me on WeChat: `baoxinyu2007` or Discord: `https://discord.gg/UYfZeuPy`
Expand Down
2 changes: 2 additions & 0 deletions src/constants.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
/// Base URL used as the fallback when the environment variable named by
/// `BASE_URL_ENV_VAR` cannot be read. Note that it carries no trailing slash.
pub const DEFAULT_BASE_URL: &str = "https://hf-mirror.com";
/// Name of the environment variable that is looked up to override the base URL.
pub const BASE_URL_ENV_VAR: &str = "HF_ENDPOINT";
38 changes: 19 additions & 19 deletions src/download.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ use indicatif;
use log::{debug, error, info, warn};
use sha2::Digest;

use crate::constants::{BASE_URL_ENV_VAR, DEFAULT_BASE_URL};

#[derive(Parser, Debug)]
#[command(version, about, long_about = None)]
pub struct DownloadArguments {
Expand All @@ -25,17 +27,11 @@ pub struct DownloadArguments {
impl DownloadArguments {
pub fn clone_repository(&mut self) -> Result<Repository, Box<dyn std::error::Error>> {
info!("Attempting to clone the repository: {}", &self.repository);
fn ensure_trailing_slash(s: &str) -> String {
if !s.ends_with('/') {
format!("{}{}", s, '/')
} else {
s.to_string()
}
}

// set the url with a base url
let mut url =
ensure_trailing_slash(option_env!("HF_ENDPOINT").unwrap_or("https://hf-mirror.com/"));
let mut url: String = ensure_trailing_slash(
&std::env::var(BASE_URL_ENV_VAR).unwrap_or(DEFAULT_BASE_URL.to_string()),
);
url.push_str(self.repository.as_str());

let path_to_join = std::path::Path::new(&self.repository);
Expand Down Expand Up @@ -84,9 +80,10 @@ impl DownloadArguments {
match entry {
Ok(result) => {
lfs_files.push(
result.strip_prefix(
self.repository_local_path.clone().unwrap()
)?.to_string_lossy().to_string()
result
.strip_prefix(self.repository_local_path.clone().unwrap())?
.to_string_lossy()
.to_string(),
);
debug!("LFS filepath extracted: {:?}", result);
}
Expand Down Expand Up @@ -137,9 +134,7 @@ impl DownloadArguments {
);
debug!("Constructed URL: {}", &url);

large_file_information.push(
LargeFileInformation::new(url, oid)
);
large_file_information.push(LargeFileInformation::new(url, oid));
} else {
debug!("OID not found in pointer file: {}", lfs_file);
}
Expand Down Expand Up @@ -360,9 +355,14 @@ pub struct LargeFileInformation {

impl LargeFileInformation {
fn new(url: String, sha256: String) -> Self {
return LargeFileInformation {
url: url,
sha256: sha256,
};
return LargeFileInformation { url, sha256 };
}
}

/// Returns an owned copy of `s` that is guaranteed to end with `/`.
///
/// A single `/` is appended only when `s` does not already end with one;
/// otherwise the text is copied unchanged. An empty input yields `"/"`.
pub fn ensure_trailing_slash(s: &str) -> String {
    let mut normalized = s.to_owned();
    let already_terminated = normalized.ends_with('/');
    if !already_terminated {
        normalized.push('/');
    }
    normalized
}
5 changes: 4 additions & 1 deletion src/main.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
use clap::Parser;
use constants::{BASE_URL_ENV_VAR, DEFAULT_BASE_URL};
use download::ensure_trailing_slash;
use log::{debug, error, info};

mod constants;
mod download;
mod utilities;

Expand Down Expand Up @@ -46,7 +49,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
Ok(lfs_files) => match arguments.extract_lfs_urls(
&result.path().parent().unwrap().to_path_buf(),
lfs_files,
&"https://hf-mirror.com".to_string(),
&std::env::var(BASE_URL_ENV_VAR).unwrap_or(DEFAULT_BASE_URL.to_string()),
) {
Ok(large_file_information) => {
arguments.download_files(large_file_information).await?
Expand Down
Loading