Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
335 changes: 146 additions & 189 deletions docs/precision_checker_guide.md

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions example/gpt2/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#endif
#include "infini_train/include/nn/parallel/utils.h"
#include "infini_train/include/utils/precision_check_config.h"
#include "infini_train/include/utils/precision_checker.h"

#include "example/common/tiny_shakespeare_dataset.h"
#include "example/common/tokenizer.h"
Expand Down Expand Up @@ -257,6 +258,9 @@ void Train(const nn::parallel::Rank &rank) {
LOG(INFO) << "start training";

for (int step = 0; step < FLAGS_num_iteration + 1; ++step) {
// Reset the precision-check tensor counters at the start of each iteration so that file indices restart and this iteration's output overwrites the previous iteration's files
utils::PrecisionChecker::ResetCounters();

const bool last_step = step == FLAGS_num_iteration;

const auto iter_start = std::chrono::high_resolution_clock::now();
Expand Down
4 changes: 4 additions & 0 deletions example/llama3/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include "infini_train/include/nn/parallel/process_group.h"
#include "infini_train/include/nn/parallel/utils.h"
#include "infini_train/include/utils/precision_check_config.h"
#include "infini_train/include/utils/precision_checker.h"

#include "example/common/tiny_shakespeare_dataset.h"
#include "example/common/tokenizer.h"
Expand Down Expand Up @@ -232,6 +233,9 @@ void Train(const nn::parallel::Rank &rank) {
LOG(INFO) << "Rank " << rank.GlobalRank() << ": start training";

for (int step = 0; step < FLAGS_num_iteration + 1; ++step) {
// Reset the precision-check tensor counters at the start of each iteration so that file indices restart and this iteration's output overwrites the previous iteration's files
utils::PrecisionChecker::ResetCounters();

const bool last_step = step == FLAGS_num_iteration;

const auto iter_start = std::chrono::high_resolution_clock::now();
Expand Down
14 changes: 10 additions & 4 deletions infini_train/include/utils/precision_check_config.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#pragma once

#include <string>
#include <unordered_map>

namespace infini_train {
namespace utils {
Expand All @@ -9,10 +10,9 @@ enum class PrecisionCheckLevel { OFF = 0, MODULE = 1, FUNCTION = 2 };

struct PrecisionCheckConfig {
PrecisionCheckLevel level = PrecisionCheckLevel::OFF;
std::string output_path = ""; // empty=console(rank0), non-empty=file(all ranks)
bool output_md5 = false; // output MD5 hash or tensor values
std::string format = "simple"; // "simple" or "table"
std::string baseline_path = ""; // baseline file path for comparison
std::string output_path = "./precision_check"; // Output path (default)
std::string format = "simple"; // "simple" or "md5"
bool save_tensors = false; // Whether to output .npy file

// Parse from "key=value,key=value" string
static PrecisionCheckConfig Parse(const std::string &config_str);
Expand All @@ -23,10 +23,16 @@ class PrecisionCheckEnv {
static PrecisionCheckEnv &Instance();
void Init(const PrecisionCheckConfig &config);
const PrecisionCheckConfig &GetConfig() const;
const std::string &GetOutputPath() const;

// Tensor counter management for file overwrite across iterations (thread-local)
static int GetAndIncrementCounter(const std::string &key);
static void ResetCounters();

private:
PrecisionCheckEnv() = default;
PrecisionCheckConfig config_;
std::string timestamped_path_; // Actual output path (with timestamp)
};

} // namespace utils
Expand Down
3 changes: 3 additions & 0 deletions infini_train/include/utils/precision_checker.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ class PrecisionChecker {
static void RegisterForModule(nn::Module *module, const std::string &name = "",
const Config &config = DefaultConfig());

// Reset tensor counters (call at start of each iteration for file overwrite)
static void ResetCounters();

private:
static void CheckTensors(const std::string &stage, const std::string &name,
const std::vector<std::shared_ptr<Tensor>> &tensors, const Config &config);
Expand Down
45 changes: 33 additions & 12 deletions infini_train/src/utils/precision_check_config.cc
Original file line number Diff line number Diff line change
@@ -1,10 +1,17 @@
#include "infini_train/include/utils/precision_check_config.h"

#include <chrono>
#include <filesystem>
#include <sstream>
#include <unordered_map>

namespace infini_train::utils {

namespace {
// Thread-local tensor counter for precision check file indexing.
// Maps a caller-supplied key to the number of times that key has been requested on the
// current thread. It is read and incremented by PrecisionCheckEnv::GetAndIncrementCounter
// and emptied by PrecisionCheckEnv::ResetCounters. Because the map is thread_local, each
// thread sees its own independent counts.
thread_local std::unordered_map<std::string, int> g_tensor_counter;
} // namespace

PrecisionCheckConfig PrecisionCheckConfig::Parse(const std::string &config_str) {
PrecisionCheckConfig config;
if (config_str.empty()) {
Expand All @@ -25,20 +32,14 @@ PrecisionCheckConfig PrecisionCheckConfig::Parse(const std::string &config_str)
int level_int = std::stoi(kv_map["level"]);
config.level = static_cast<PrecisionCheckLevel>(level_int);
}
if (kv_map.count("output_path")) {
config.output_path = kv_map["output_path"];
}
if (kv_map.count("output_md5")) {
config.output_md5 = (kv_map["output_md5"] == "true" || kv_map["output_md5"] == "1");
}
if (kv_map.count("baseline")) {
config.baseline_path = kv_map["baseline"];
if (kv_map.count("path")) {
config.output_path = kv_map["path"];
}
if (kv_map.count("format")) {
config.format = kv_map["format"];
} else if (!config.baseline_path.empty()) {
// Default to table format when baseline is specified
config.format = "table";
}
if (kv_map.count("save_tensors")) {
config.save_tensors = (kv_map["save_tensors"] == "true" || kv_map["save_tensors"] == "1");
}
return config;
}
Expand All @@ -48,8 +49,28 @@ PrecisionCheckEnv &PrecisionCheckEnv::Instance() {
return instance;
}

void PrecisionCheckEnv::Init(const PrecisionCheckConfig &config) { config_ = config; }
// Stores `config` and, when precision checking is enabled (level != OFF), creates the
// output directory for this run: <config.output_path>/<YYYYMMDD_HHMMSS>, where the stamp is
// the local wall-clock time at which Init runs. The resulting path is what GetOutputPath()
// returns afterwards.
//
// When the level is OFF no directory is created and the stored output path is left empty.
//
// std::filesystem::create_directories is called through its throwing overload, so a failure
// to create the directory surfaces as std::filesystem::filesystem_error.
void PrecisionCheckEnv::Init(const PrecisionCheckConfig &config) {
    config_ = config;

    // Forget any directory chosen by an earlier Init so that re-initialising with level OFF
    // does not leave GetOutputPath() pointing at a stale location.
    timestamped_path_.clear();

    if (config_.level == PrecisionCheckLevel::OFF) {
        return;
    }

    // Create timestamped subdirectory: output_path/YYYYMMDD_HHMMSS/
    const auto now_c = std::chrono::system_clock::to_time_t(std::chrono::system_clock::now());

    // Zero-initialise both buffers: if localtime_r or strftime fails, the result is a
    // deterministic (if meaningless) stamp instead of a read of indeterminate memory.
    std::tm local_tm{};
    localtime_r(&now_c, &local_tm);
    char stamp[32] = {};
    std::strftime(stamp, sizeof(stamp), "%Y%m%d_%H%M%S", &local_tm);

    timestamped_path_ = config_.output_path + "/" + stamp;
    std::filesystem::create_directories(timestamped_path_);
}

// Returns a read-only reference to the configuration most recently stored by Init.
// If Init has not been called, this is the default-constructed PrecisionCheckConfig.
const PrecisionCheckConfig &PrecisionCheckEnv::GetConfig() const { return config_; }

// Returns the timestamped output directory (output_path/YYYYMMDD_HHMMSS) chosen by Init.
// Init assigns this only when the configured level is not OFF; before that it is an empty string.
const std::string &PrecisionCheckEnv::GetOutputPath() const { return timestamped_path_; }

// Returns how many times `key` has already been requested on the calling thread, then
// records this request. The first call for a given key yields 0, because the map's
// operator[] value-initialises a missing entry.
int PrecisionCheckEnv::GetAndIncrementCounter(const std::string &key) {
    int &slot = g_tensor_counter[key];
    const int previous = slot;
    ++slot;
    return previous;
}

// Removes every entry from the calling thread's counter map, so the next
// GetAndIncrementCounter call for any key starts again from 0. Counters held by other
// threads are not touched, since the map is thread_local.
void PrecisionCheckEnv::ResetCounters() { g_tensor_counter.clear(); }

} // namespace infini_train::utils
Loading