Skip to content

Commit b8ee22c

Browse files
authored
common : add minimalist multi-thread progress bar (#17602)
Signed-off-by: Adrien Gallouët <angt@huggingface.co>
1 parent 2eaa2c6 commit b8ee22c

File tree

1 file changed

+69
-25
lines changed

1 file changed

+69
-25
lines changed

common/download.cpp

Lines changed: 69 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
#include <filesystem>
1313
#include <fstream>
1414
#include <future>
15+
#include <map>
16+
#include <mutex>
1517
#include <regex>
1618
#include <string>
1719
#include <thread>
@@ -472,36 +474,79 @@ std::pair<long, std::vector<char>> common_remote_get_content(const std::string &
472474

473475
#elif defined(LLAMA_USE_HTTPLIB)
474476

475-
static bool is_output_a_tty() {
477+
class ProgressBar {
478+
static inline std::mutex mutex;
479+
static inline std::map<const ProgressBar *, int> lines;
480+
static inline int max_line = 0;
481+
482+
static void cleanup(const ProgressBar * line) {
483+
lines.erase(line);
484+
if (lines.empty()) {
485+
max_line = 0;
486+
}
487+
}
488+
489+
static bool is_output_a_tty() {
476490
#if defined(_WIN32)
477-
return _isatty(_fileno(stdout));
491+
return _isatty(_fileno(stdout));
478492
#else
479-
return isatty(1);
493+
return isatty(1);
480494
#endif
481-
}
495+
}
482496

483-
static void print_progress(size_t current, size_t total) {
484-
if (!is_output_a_tty()) {
485-
return;
497+
public:
498+
ProgressBar() = default;
499+
500+
~ProgressBar() {
501+
std::lock_guard<std::mutex> lock(mutex);
502+
cleanup(this);
486503
}
487504

488-
if (!total) {
489-
return;
505+
void update(size_t current, size_t total) {
506+
if (!is_output_a_tty()) {
507+
return;
508+
}
509+
510+
if (!total) {
511+
return;
512+
}
513+
514+
std::lock_guard<std::mutex> lock(mutex);
515+
516+
if (lines.find(this) == lines.end()) {
517+
lines[this] = max_line++;
518+
std::cout << "\n";
519+
}
520+
int lines_up = max_line - lines[this];
521+
522+
size_t width = 50;
523+
size_t pct = (100 * current) / total;
524+
size_t pos = (width * current) / total;
525+
526+
std::cout << "\033[s";
527+
528+
if (lines_up > 0) {
529+
std::cout << "\033[" << lines_up << "A";
530+
}
531+
std::cout << "\033[2K\r["
532+
<< std::string(pos, '=')
533+
<< (pos < width ? ">" : "")
534+
<< std::string(width - pos, ' ')
535+
<< "] " << std::setw(3) << pct << "% ("
536+
<< current / (1024 * 1024) << " MB / "
537+
<< total / (1024 * 1024) << " MB) "
538+
<< "\033[u";
539+
540+
std::cout.flush();
541+
542+
if (current == total) {
543+
cleanup(this);
544+
}
490545
}
491546

492-
size_t width = 50;
493-
size_t pct = (100 * current) / total;
494-
size_t pos = (width * current) / total;
495-
496-
std::cout << "["
497-
<< std::string(pos, '=')
498-
<< (pos < width ? ">" : "")
499-
<< std::string(width - pos, ' ')
500-
<< "] " << std::setw(3) << pct << "% ("
501-
<< current / (1024 * 1024) << " MB / "
502-
<< total / (1024 * 1024) << " MB)\r";
503-
std::cout.flush();
504-
}
547+
ProgressBar(const ProgressBar &) = delete;
548+
ProgressBar & operator=(const ProgressBar &) = delete;
549+
};
505550

506551
static bool common_pull_file(httplib::Client & cli,
507552
const std::string & resolve_path,
@@ -523,6 +568,7 @@ static bool common_pull_file(httplib::Client & cli,
523568
const char * func = __func__; // avoid __func__ inside a lambda
524569
size_t downloaded = existing_size;
525570
size_t progress_step = 0;
571+
ProgressBar bar;
526572

527573
auto res = cli.Get(resolve_path, headers,
528574
[&](const httplib::Response &response) {
@@ -554,16 +600,14 @@ static bool common_pull_file(httplib::Client & cli,
554600
progress_step += len;
555601

556602
if (progress_step >= total_size / 1000 || downloaded == total_size) {
557-
print_progress(downloaded, total_size);
603+
bar.update(downloaded, total_size);
558604
progress_step = 0;
559605
}
560606
return true;
561607
},
562608
nullptr
563609
);
564610

565-
std::cout << "\n";
566-
567611
if (!res) {
568612
LOG_ERR("%s: error during download. Status: %d\n", __func__, res ? res->status : -1);
569613
return false;

0 commit comments

Comments
 (0)