diff --git a/Cargo.toml b/Cargo.toml index 55b1f6a..f533ddb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "rustyface" -version = "0.1.2" +version = "0.1.3" authors = ["Xinyu Bao "] edition = "2021" description = "A Huggingface downloading CLI tool written in Rust." diff --git a/README.md b/README.md index 7d9c89d..3d12ac8 100644 --- a/README.md +++ b/README.md @@ -85,6 +85,13 @@ rustyface --repository sentence-transformers/all-MiniLM-L6-v2 --tasks 4 - `--repository` is followed by the `repo_id` of the repository that you want to download from HuggingFace. - `--tasks` is followed by the number of concurrent downloads. For example, 4 means downloading 4 files at once. It is recommended to use a lower number if your network conditions do not support higher concurrency. +# Configurations + +If you would like to use an alternative mirror, or the HuggingFace official base url, you may configure your env like below: + +```bash +export HF_ENDPOINT=https://your-endpoint.com +``` # Feedback & Further Development Any participation is appreciated! Feel free to submit an issue, discussion or pull request. You can find me on WeChat: `baoxinyu2007` or Discord: `https://discord.gg/UYfZeuPy` diff --git a/src/constants.rs b/src/constants.rs new file mode 100644 index 0000000..f0b5ddb --- /dev/null +++ b/src/constants.rs @@ -0,0 +1,2 @@ +pub const DEFAULT_BASE_URL: &str = "https://hf-mirror.com"; +pub const BASE_URL_ENV_VAR: &str = "HF_ENDPOINT"; diff --git a/src/download.rs b/src/download.rs index c7bbbf1..ecfe964 100644 --- a/src/download.rs +++ b/src/download.rs @@ -8,6 +8,8 @@ use indicatif; use log::{debug, error, info, warn}; use sha2::Digest; +use crate::constants::{BASE_URL_ENV_VAR, DEFAULT_BASE_URL}; + #[derive(Parser, Debug)] #[command(version, about, long_about = None)] pub struct DownloadArguments { @@ -25,17 +27,11 @@ pub struct DownloadArguments { impl DownloadArguments { pub fn clone_repository(&mut self) -> Result> { info!("Attempting to clone the repository: {}", &self.repository); - fn ensure_trailing_slash(s: &str) -> String { - if !s.ends_with('/') { - format!("{}{}", s, '/') - } else { - s.to_string() - } - } // set the url with a base url - let mut url = - ensure_trailing_slash(option_env!("HF_ENDPOINT").unwrap_or("https://hf-mirror.com/")); + let mut url: String = ensure_trailing_slash( + &std::env::var(BASE_URL_ENV_VAR).unwrap_or(DEFAULT_BASE_URL.to_string()), + ); url.push_str(self.repository.as_str()); let path_to_join = std::path::Path::new(&self.repository); @@ -84,9 +80,10 @@ impl DownloadArguments { match entry { Ok(result) => { lfs_files.push( - result.strip_prefix( - self.repository_local_path.clone().unwrap() - )?.to_string_lossy().to_string() + result + .strip_prefix(self.repository_local_path.clone().unwrap())? + .to_string_lossy() + .to_string(), ); debug!("LFS filepath extracted: {:?}", result); } @@ -137,9 +134,7 @@ impl DownloadArguments { ); debug!("Constructed URL: {}", &url); - large_file_information.push( - LargeFileInformation::new(url, oid) - ); + large_file_information.push(LargeFileInformation::new(url, oid)); } else { debug!("OID not found in pointer file: {}", lfs_file); } @@ -360,9 +355,14 @@ pub struct LargeFileInformation { impl LargeFileInformation { fn new(url: String, sha256: String) -> Self { - return LargeFileInformation { - url: url, - sha256: sha256, - }; + return LargeFileInformation { url, sha256 }; + } +} + +pub fn ensure_trailing_slash(s: &str) -> String { + if !s.ends_with('/') { + format!("{}{}", s, '/') + } else { + s.to_string() } } diff --git a/src/main.rs b/src/main.rs index 9b1ab97..f19ee70 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,6 +1,9 @@ use clap::Parser; +use constants::{BASE_URL_ENV_VAR, DEFAULT_BASE_URL}; +use download::ensure_trailing_slash; use log::{debug, error, info}; +mod constants; mod download; mod utilities; @@ -46,7 +49,7 @@ async fn main() -> Result<(), Box> { Ok(lfs_files) => match arguments.extract_lfs_urls( &result.path().parent().unwrap().to_path_buf(), lfs_files, - &"https://hf-mirror.com".to_string(), + &std::env::var(BASE_URL_ENV_VAR).unwrap_or(DEFAULT_BASE_URL.to_string()), ) { Ok(large_file_information) => { arguments.download_files(large_file_information).await?