From 6690b61756e78c5472e4e80c7604bd22fc2e1522 Mon Sep 17 00:00:00 2001 From: chen Date: Thu, 22 Jan 2026 07:49:10 +0000 Subject: [PATCH] fix: resolve multi-iteration tensor file overwrite and simplify precision checker Counter mechanism: - Add ResetCounters() to clear tensor counter at iteration boundaries - Move counter management to PrecisionCheckEnv with thread_local storage - Call ResetCounters() at start of each training step in gpt2/llama3 Precision checker refactoring: - Remove baseline comparison functionality (use separate script instead) - Remove table format output, keep only simple and md5 formats - Add TensorStats struct with min/max/mean/nan_count/inf_count - Add SaveNpy() function for NPY file saving with rank subdirectories - Simplify log output format with dtype, shape, stats, and first 6 values - Change stage names from "Module Forward/Backward Output" to "Forward/Backward Output" - Use std::filesystem instead of sys/stat.h for directory creation Documentation and scripts: - Update docs/precision_checker_guide.md with current implementation - Add precision_compare.py for offline NPY comparison - Add run_precision_check_gpt2.sh and run_precision_check_llama3.sh Co-Authored-By: Claude Opus 4.5 --- docs/precision_checker_guide.md | 335 ++++++++--------- example/gpt2/main.cc | 4 + example/llama3/main.cc | 4 + .../include/utils/precision_check_config.h | 14 +- .../include/utils/precision_checker.h | 3 + .../src/utils/precision_check_config.cc | 45 ++- infini_train/src/utils/precision_checker.cc | 343 ++++++------------ scripts/precision_check/precision_compare.py | 153 ++++++++ .../run_precision_check_gpt2.sh | 90 +++++ .../run_precision_check_llama3.sh | 90 +++++ 10 files changed, 637 insertions(+), 444 deletions(-) create mode 100755 scripts/precision_check/precision_compare.py create mode 100644 scripts/precision_check/run_precision_check_gpt2.sh create mode 100644 scripts/precision_check/run_precision_check_llama3.sh diff --git a/docs/precision_checker_guide.md b/docs/precision_checker_guide.md index 0ea4ec8f..164c99ac 100644 --- a/docs/precision_checker_guide.md +++ b/docs/precision_checker_guide.md @@ -1,17 +1,18 @@ # Precision Checker 使用指南 -精度检查工具,用于检测模型训练过程中的数值稳定性问题(NaN、Inf 等),支持 MD5 哈希对比和多种输出格式。 +精度检查工具,用于检测模型训练过程中的数值稳定性问题(NaN、Inf 等),支持 tensor 统计信息输出、MD5 哈希对比和 NPY 文件保存。 ## 功能特性 - **自动检测 NaN/Inf**:在前向和反向传播过程中自动检测异常值 - **多级别检查**:支持 Module 级别和 Function 级别的精度检查 - **灵活配置**:通过 key=value 字符串配置所有选项 -- **MD5 哈希**:支持输出 tensor 的 MD5 值用于对比 -- **表格格式**:支持表格化输出,便于查看和对比 -- **基准对比**:支持加载基准文件进行自动对比 +- **统计信息**:输出 tensor 的 min、max、mean 等统计值 +- **MD5 哈希**:支持输出 tensor 的 MD5 值用于快速对比 +- **NPY 保存**:支持保存 tensor 为 .npy 文件,便于离线分析 - **上下文追踪**:支持 GAS(梯度累积步)和层号追踪 -- **性能优化**:仅在需要时计算 MD5,避免冗余计算 +- **多卡支持**:每个 rank 独立输出到 rank_N 目录 +- **多 iter 覆盖**:同一次运行中,后续 iteration 的文件会覆盖前一个 ## 配置方式 @@ -19,148 +20,158 @@ ```cpp struct PrecisionCheckConfig { - int level = 0; // 0=关闭, 1=MODULE级别, 2=FUNCTION级别 - std::string output_path = ""; // 空=控制台(仅rank0), 非空=文件(所有rank) - bool output_md5 = false; // 输出 MD5 还是 tensor 值 - std::string format = "simple"; // "simple" 或 "table" - std::string baseline_path = ""; // 基准文件路径(用于对比),指定后默认开启 format=table + PrecisionCheckLevel level = PrecisionCheckLevel::OFF; // 0=关闭, 1=MODULE, 2=FUNCTION + std::string output_path = "./precision_check"; // 输出目录 + std::string format = "simple"; // "simple" 或 "md5" + bool save_tensors = false; // 是否保存 .npy 文件 }; ``` -### 配置字符串格式 - -使用 `key=value,key=value` 格式: - -```cpp -auto config = utils::PrecisionCheckConfig::Parse("level=2,format=table,output_md5=true"); -nn::parallel::global::InitAllEnv(nthread, tp_size, sp_enabled, pp_size, vpp_size, config); -``` - ### 配置选项说明 | 选项 | 类型 | 默认值 | 说明 | |------|------|--------|------| | `level` | int | 0 | 0=关闭, 1=MODULE级别, 2=FUNCTION级别 | -| `output_path` | string | "" | 空=控制台(仅rank0), 非空=文件路径(所有rank) | -| `output_md5` | bool | false | true=输出MD5哈希, false=输出tensor值 | -| `format` | string | "simple" | "simple"=简单格式, "table"=表格格式 | -| `baseline` | string | "" | 基准文件路径,用于对比 | - -## 使用方法 +| `path` | string | `./precision_check` | 输出目录(自动创建时间戳子目录) | +| `format` | string | `simple` | `simple`=统计信息+前6个值, `md5`=MD5哈希 | +| `save_tensors` | bool | false | 是否保存 tensor 为 .npy 文件 | -### 1. 基本用法(简单格式) +### 配置字符串格式 -```cpp -#include "infini_train/include/nn/parallel/global.h" -#include "infini_train/include/utils/precision_check_config.h" +使用 `key=value,key=value` 格式: -// 启用 Function 级别检查,输出 tensor 值 -auto config = utils::PrecisionCheckConfig::Parse("level=2"); -nn::parallel::global::InitAllEnv(1, 1, false, 1, 1, config); +```bash +--precision_check "level=1,path=./my_output,format=simple,save_tensors=true" +``` -// 创建并运行模型 -auto x = std::make_shared(std::vector{2, 3}, DataType::kFLOAT32); -x->Fill(2.0f); -x->RequiresGrad(); +## 输出格式 -auto y = x->Mul(x); -auto loss = y->Sum(0, false); -loss->Backward(); -``` +### 目录结构 -输出示例: ``` -I0113 06:44:10.575 [Rank 0][PrecisionCheck] Forward Input MulFunction tensor[0]: [2, 2, 2, 2, 2, 2] -I0113 06:44:10.575 [Rank 0][PrecisionCheck] Forward Output MulFunction tensor[0]: [4, 4, 4, 4, 4, 4] +precision_check/ +└── 20260122_143052/ # 时间戳子目录 (YYYYMMDD_HHMMSS) + ├── precision_check_rank_0.log # 文本日志 + ├── rank_0/ # NPY 文件目录 (save_tensors=true) + │ ├── Block_0_forward.npy + │ ├── Block_1_forward.npy + │ ├── Block_0_backward.npy + │ └── ... + └── rank_1/ # 多卡时每个 rank 独立目录 + └── ... ``` -### 2. MD5 哈希输出 - -```cpp -// 输出 MD5 而不是 tensor 值 -auto config = utils::PrecisionCheckConfig::Parse("level=2,output_md5=true"); -nn::parallel::global::InitAllEnv(1, 1, false, 1, 1, config); -``` +### Simple 格式 (format=simple) -输出示例: ``` -I0113 06:44:37.751 [Rank 0][PrecisionCheck] Forward Input MulFunction tensor[0]: md5=522b4223c3a2f0dd964caa87cb6eab65 -I0113 06:44:37.751 [Rank 0][PrecisionCheck] Forward Output MulFunction tensor[0]: md5=91d1e78bf226d8735a3bc0ca6968339c +[GAS-0] [L-0] Block_0_Forward Output tensor[0]: dtype=float32 shape=(2,1024,768) min=-2.34 max=3.56 mean=0.12 [1.23, 4.56, 7.89, ...] [NaN:0 Inf:0] +[GAS-0] [L-0] Block_0_Forward Output tensor[0]: dtype=float32 shape=(2,1024,768) min=-2.34 max=3.56 mean=0.12 [1.23, NaN, ...] [NaN:5 Inf:0] <- ERROR ``` -### 3. 表格格式输出 - -```cpp -// 使用表格格式,便于查看和对比 -auto config = utils::PrecisionCheckConfig::Parse("level=2,format=table,output_md5=true"); -nn::parallel::global::InitAllEnv(1, 1, false, 1, 1, config); -``` +### MD5 格式 (format=md5) -输出示例: ``` -+--------------------------------------------------+-------+------------------+---------------+----------+----------+ -| key | level | shape | dtype | same_hash| diff_order| -+--------------------------------------------------+-------+------------------+---------------+----------+----------+ -| [GAS-0] [L-0] Forward Input MulFunction | 2 | (2, 3) | float32 | True | 0 | -| [GAS-0] [L-0] Forward Output MulFunction | 2 | (2, 3) | float32 | True | 0 | +[GAS-0] [L-0] Block_0_Forward Output tensor[0]: dtype=float32 shape=(2,1024,768) md5=a1b2c3d4e5f6... ``` -### 4. 基准对比 +### NPY 文件命名规则 -```cpp -// 第一次运行:生成基准文件 -auto config1 = utils::PrecisionCheckConfig::Parse("level=2,output_md5=true,output_path=./baseline"); -nn::parallel::global::InitAllEnv(1, 1, false, 1, 1, config1); -// ... 运行模型 ... -// 生成文件: ./baseline/precision_check_rank_0.log - -// 第二次运行:与基准对比 -auto config2 = utils::PrecisionCheckConfig::Parse("level=2,format=table,baseline=./baseline/precision_check_rank_0.log"); -nn::parallel::global::InitAllEnv(1, 1, false, 1, 1, config2); -// ... 运行模型 ... -// 输出会显示 same_hash 列,标识是否与基准一致 -``` +文件名格式:`{ModuleName}_{idx}_{stage}.npy` -### 5. 文件输出(所有 rank) +- `ModuleName`: 模块名称(如 Block、LayerNorm) +- `idx`: 同名模块在当前 iteration 内的执行顺序索引 +- `stage`: `forward` 或 `backward` -```cpp -// 输出到文件,所有 rank 都会输出 -auto config = utils::PrecisionCheckConfig::Parse("level=2,output_path=./logs"); -nn::parallel::global::InitAllEnv(8, 2, false, 2, 1, config); -// 生成文件: ./logs/precision_check_rank_0.log, ./logs/precision_check_rank_1.log, ... -``` +**多 iteration 行为**:每个 iteration 开始时索引重置为 0,文件会被覆盖。最终只保留最后一个 iteration 的数据。 ## 命令行使用 ### GPT2 示例 ```bash -# 基本检查(简单格式,输出 tensor 值) -./gpt2 --precision_check "level=2" +# 基本检查(Simple 格式,输出到文件) +./build/gpt2 --device cuda \ + --input_bin /path/to/data.bin \ + --llmc_filepath /path/to/model.bin \ + --precision_check "level=1" \ + --num_iteration 1 + +# 保存 NPY 文件 +./build/gpt2 --device cuda \ + --input_bin /path/to/data.bin \ + --llmc_filepath /path/to/model.bin \ + --precision_check "level=1,save_tensors=true" \ + --num_iteration 1 + +# MD5 格式(用于快速对比) +./build/gpt2 --device cuda \ + --input_bin /path/to/data.bin \ + --llmc_filepath /path/to/model.bin \ + --precision_check "level=1,format=md5" \ + --num_iteration 1 + +# 自定义输出路径 +./build/gpt2 --device cuda \ + --input_bin /path/to/data.bin \ + --llmc_filepath /path/to/model.bin \ + --precision_check "level=1,path=./my_precision_check,save_tensors=true" \ + --num_iteration 1 +``` + +### LLaMA3 示例 -# 输出 MD5 哈希 -./gpt2 --precision_check "level=2,output_md5=true" +```bash +./build/llama3 --device cuda \ + --input_bin /path/to/data.bin \ + --llmc_filepath /path/to/model.bin \ + --precision_check "level=1,save_tensors=true" \ + --num_iteration 1 +``` -# 表格格式 -./gpt2 --precision_check "level=2,format=table,output_md5=true" +## 离线对比工具 -# 生成基准文件 -./gpt2 --precision_check "level=2,output_md5=true,output_path=./baseline" +### precision_compare.py -# 与基准对比 -./gpt2 --precision_check "level=2,format=table,baseline=./baseline/precision_check_rank_0.log" +用于对比两次运行的 NPY 文件: + +```bash +python scripts/precision_check/precision_compare.py \ + --dir1 ./precision_check/20260122_143052 \ + --dir2 ./precision_check/20260122_143105 \ + --atol 1e-5 \ + --rtol 1e-3 ``` -### LLaMA3 示例 +输出示例: +``` +Comparing Block_0_forward.npy: + Shape: (2, 1024, 768) vs (2, 1024, 768) ✓ + Dtype: float32 vs float32 ✓ + Max abs diff: 1.23e-06 ✓ + Max rel diff: 2.34e-07 ✓ + +Summary: 433/433 files passed +``` + +### 验证脚本 + +提供了完整的验证脚本: ```bash -# 基本检查 -./llama3 --precision_check "level=2" +# GPT2 验证 +bash scripts/precision_check/run_precision_check_gpt2.sh -# 表格格式 + MD5 -./llama3 --precision_check "level=2,format=table,output_md5=true" +# LLaMA3 验证 +bash scripts/precision_check/run_precision_check_llama3.sh ``` +验证内容: +1. 单卡测试 - Simple 格式 +2. 单卡测试 - MD5 格式 +3. 多 iter 覆盖测试 +4. 两次运行对比测试 +5. 多卡测试(如果环境支持) + ## 上下文追踪 使用 `PrecisionCheckContext` 设置 GAS(梯度累积步)和层号信息: @@ -168,132 +179,78 @@ nn::parallel::global::InitAllEnv(8, 2, false, 2, 1, config); ```cpp #include "infini_train/include/utils/precision_check_context.h" -// 在训练循环中设置上下文 for (int gas_step = 0; gas_step < grad_accum_steps; ++gas_step) { PrecisionCheckContext::Instance().SetGAS(gas_step); for (int layer = 0; layer < num_layers; ++layer) { PrecisionCheckContext::Instance().SetLayer(layer); - PrecisionCheckContext::Instance().SetLayerName("transformer_block"); - // 运行该层的前向传播 // 输出会包含 [GAS-X] [L-Y] 前缀 } } ``` -输出示例: -``` -[GAS-0] [L-0] Forward Input MulFunction -[GAS-0] [L-1] Forward Input MulFunction -[GAS-1] [L-0] Forward Input MulFunction -``` - -## 性能优化 - -### MD5 计算优化 - -MD5 仅在以下情况计算: -- `output_md5=true` 时 -- `baseline_path` 非空时(需要对比) - -默认情况下(`output_md5=false` 且无基准对比),不会计算 MD5,避免性能开销。 - -### 使用建议 - -| 场景 | 推荐配置 | -|------|----------| -| 快速调试 | `level=2` | -| 详细调试 | `level=2,format=table` | -| 生成基准 | `level=2,output_md5=true,output_path=./baseline` | -| 对比测试 | `level=2,format=table,baseline=./baseline/...` | -| 生产环境 | `level=0`(关闭) | - -## 输出格式对比 - -### Simple 格式 - -``` -I0113 06:44:10.575 [Rank 0][PrecisionCheck] Forward Input MulFunction tensor[0]: [2, 2, 2, 2, 2, 2] -``` - -优点:紧凑,易于阅读 -缺点:不便于对比多个 tensor - -### Table 格式 - -``` -+--------------------------------------------------+-------+------------------+---------------+----------+----------+ -| key | level | shape | dtype | same_hash| diff_order| -+--------------------------------------------------+-------+------------------+---------------+----------+----------+ -| [GAS-0] [L-0] Forward Input MulFunction | 2 | (2, 3) | float32 | True | 0 | -``` - -优点:结构化,便于对比和分析 -缺点:占用更多空间 - ## 手动注册(高级用法) -除了通过 `InitAllEnv` 自动注册,也可以手动为特定模块注册: +除了通过命令行自动注册,也可以手动为特定模块注册: ```cpp #include "infini_train/include/utils/precision_checker.h" -// 配置精度检查器 utils::PrecisionChecker::Config config; config.check_nan = true; config.check_inf = true; -config.print_stats = true; config.abort_on_error = false; // 为特定模块注册 utils::PrecisionChecker::RegisterForModule(model.get(), "MyModel", config); - -// 为特定 Function 注册 -utils::PrecisionChecker::RegisterForFunction(my_function.get(), "MyFunction", config); ``` ## 实现原理 -精度检查器通过 Hook 机制实现: +### Hook 机制 -1. **Forward Pre-Hook**:检查输入 tensor -2. **Forward Post-Hook**:检查输出 tensor -3. **Backward Hooks**:自动检查梯度 +精度检查器通过 Hook 机制实现: -检查流程: ``` Forward Pass: - ├─> Pre-Hook: 检查输入 - ├─> Forward: 执行计算 - └─> Post-Hook: 检查输出 + └─> Post-Hook: 检查输出 tensor Backward Pass: - ├─> Backward Pre-Hook: 检查梯度输入 - ├─> Backward: 执行梯度计算 - └─> Backward Post-Hook: 检查梯度输出 + └─> Post-Hook: 检查梯度 tensor ``` -## 示例代码 +### Counter 机制 -参见: -- `test/hook/test_precision_check.cc` - 完整使用示例 -- `infini_train/include/utils/precision_checker.h` - API 文档 -- `infini_train/include/utils/precision_check_config.h` - 配置结构 -- `infini_train/include/utils/precision_check_context.h` - 上下文追踪 +为了支持多 iteration 文件覆盖,使用 `thread_local` 计数器: + +```cpp +// 每个 iteration 开始时重置 +PrecisionChecker::ResetCounters(); -## 测试 +// 每次 CheckTensors 时递增 +int idx = PrecisionCheckEnv::GetAndIncrementCounter(counter_key); +// 文件名: Block_{idx}_forward.npy +``` -```bash -# 运行测试(默认:简单格式) -./test_precision_check +这确保了: +- 同一 iteration 内,同名模块有不同的索引(Block_0, Block_1, ...) +- 不同 iteration 之间,索引重置,文件被覆盖 -# Function 级别 + MD5 -./test_precision_check "level=2,output_md5=true" +## 使用建议 + +| 场景 | 推荐配置 | +|------|----------| +| 快速调试 | `level=1` | +| 详细分析 | `level=1,save_tensors=true` | +| 快速对比 | `level=1,format=md5` | +| 生产环境 | `level=0`(关闭) | -# 表格格式 -./test_precision_check "level=2,format=table,output_md5=true" +## 相关文件 -# Module 级别 -./test_precision_check "level=1" -``` +- `infini_train/include/utils/precision_checker.h` - API 定义 +- `infini_train/include/utils/precision_check_config.h` - 配置结构 +- `infini_train/include/utils/precision_check_context.h` - 上下文追踪 +- `scripts/precision_check/precision_compare.py` - 离线对比工具 +- `scripts/precision_check/run_precision_check_gpt2.sh` - GPT2 验证脚本 +- `scripts/precision_check/run_precision_check_llama3.sh` - LLaMA3 验证脚本 \ No newline at end of file diff --git a/example/gpt2/main.cc b/example/gpt2/main.cc index a1a58ed5..3e4ffa48 100644 --- a/example/gpt2/main.cc +++ b/example/gpt2/main.cc @@ -27,6 +27,7 @@ #endif #include "infini_train/include/nn/parallel/utils.h" #include "infini_train/include/utils/precision_check_config.h" +#include "infini_train/include/utils/precision_checker.h" #include "example/common/tiny_shakespeare_dataset.h" #include "example/common/tokenizer.h" @@ -257,6 +258,9 @@ void Train(const nn::parallel::Rank &rank) { LOG(INFO) << "start training"; for (int step = 0; step < FLAGS_num_iteration + 1; ++step) { + // Reset precision check counters at start of each iteration for file overwrite + utils::PrecisionChecker::ResetCounters(); + const bool last_step = step == FLAGS_num_iteration; const auto iter_start = std::chrono::high_resolution_clock::now(); diff --git a/example/llama3/main.cc b/example/llama3/main.cc index 3a4e5053..ad55c568 100644 --- a/example/llama3/main.cc +++ b/example/llama3/main.cc @@ -26,6 +26,7 @@ #include "infini_train/include/nn/parallel/process_group.h" #include "infini_train/include/nn/parallel/utils.h" #include "infini_train/include/utils/precision_check_config.h" +#include "infini_train/include/utils/precision_checker.h" #include "example/common/tiny_shakespeare_dataset.h" #include "example/common/tokenizer.h" @@ -232,6 +233,9 @@ void Train(const nn::parallel::Rank &rank) { LOG(INFO) << "Rank " << rank.GlobalRank() << ": start training"; for (int step = 0; step < FLAGS_num_iteration + 1; ++step) { + // Reset precision check counters at start of each iteration for file overwrite + utils::PrecisionChecker::ResetCounters(); + const bool last_step = step == FLAGS_num_iteration; const auto iter_start = std::chrono::high_resolution_clock::now(); diff --git a/infini_train/include/utils/precision_check_config.h b/infini_train/include/utils/precision_check_config.h index 25524fb7..5a83405a 100644 --- a/infini_train/include/utils/precision_check_config.h +++ b/infini_train/include/utils/precision_check_config.h @@ -1,6 +1,7 @@ #pragma once #include +#include namespace infini_train { namespace utils { @@ -9,10 +10,9 @@ enum class PrecisionCheckLevel { OFF = 0, MODULE = 1, FUNCTION = 2 }; struct PrecisionCheckConfig { PrecisionCheckLevel level = PrecisionCheckLevel::OFF; - std::string output_path = ""; // empty=console(rank0), non-empty=file(all ranks) - bool output_md5 = false; // output MD5 hash or tensor values - std::string format = "simple"; // "simple" or "table" - std::string baseline_path = ""; // baseline file path for comparison + std::string output_path = "./precision_check"; // Output path (default) + std::string format = "simple"; // "simple" or "md5" + bool save_tensors = false; // Whether to output .npy file // Parse from "key=value,key=value" string static PrecisionCheckConfig Parse(const std::string &config_str); @@ -23,10 +23,16 @@ class PrecisionCheckEnv { static PrecisionCheckEnv &Instance(); void Init(const PrecisionCheckConfig &config); const PrecisionCheckConfig &GetConfig() const; + const std::string &GetOutputPath() const; + + // Tensor counter management for file overwrite across iterations (thread-local) + static int GetAndIncrementCounter(const std::string &key); + static void ResetCounters(); private: PrecisionCheckEnv() = default; PrecisionCheckConfig config_; + std::string timestamped_path_; // Actual output path (with timestamp) }; } // namespace utils diff --git a/infini_train/include/utils/precision_checker.h b/infini_train/include/utils/precision_checker.h index 060ccb98..7c961f1f 100644 --- a/infini_train/include/utils/precision_checker.h +++ b/infini_train/include/utils/precision_checker.h @@ -39,6 +39,9 @@ class PrecisionChecker { static void RegisterForModule(nn::Module *module, const std::string &name = "", const Config &config = DefaultConfig()); + // Reset tensor counters (call at start of each iteration for file overwrite) + static void ResetCounters(); + private: static void CheckTensors(const std::string &stage, const std::string &name, const std::vector> &tensors, const Config &config); diff --git a/infini_train/src/utils/precision_check_config.cc b/infini_train/src/utils/precision_check_config.cc index d37cbdb0..a5460a95 100644 --- a/infini_train/src/utils/precision_check_config.cc +++ b/infini_train/src/utils/precision_check_config.cc @@ -1,10 +1,17 @@ #include "infini_train/include/utils/precision_check_config.h" +#include +#include #include #include namespace infini_train::utils { +namespace { +// Thread-local tensor counter for precision check file indexing +thread_local std::unordered_map g_tensor_counter; +} // namespace + PrecisionCheckConfig PrecisionCheckConfig::Parse(const std::string &config_str) { PrecisionCheckConfig config; if (config_str.empty()) { @@ -25,20 +32,14 @@ PrecisionCheckConfig PrecisionCheckConfig::Parse(const std::string &config_str) int level_int = std::stoi(kv_map["level"]); config.level = static_cast(level_int); } - if (kv_map.count("output_path")) { - config.output_path = kv_map["output_path"]; - } - if (kv_map.count("output_md5")) { - config.output_md5 = (kv_map["output_md5"] == "true" || kv_map["output_md5"] == "1"); - } - if (kv_map.count("baseline")) { - config.baseline_path = kv_map["baseline"]; + if (kv_map.count("path")) { + config.output_path = kv_map["path"]; } if (kv_map.count("format")) { config.format = kv_map["format"]; - } else if (!config.baseline_path.empty()) { - // Default to table format when baseline is specified - config.format = "table"; + } + if (kv_map.count("save_tensors")) { + config.save_tensors = (kv_map["save_tensors"] == "true" || kv_map["save_tensors"] == "1"); } return config; } @@ -48,8 +49,28 @@ PrecisionCheckEnv &PrecisionCheckEnv::Instance() { return instance; } -void PrecisionCheckEnv::Init(const PrecisionCheckConfig &config) { config_ = config; } +void PrecisionCheckEnv::Init(const PrecisionCheckConfig &config) { + config_ = config; + if (config_.level != PrecisionCheckLevel::OFF) { + // Create timestamped subdirectory: output_path/YYYYMMDD_HHMMSS/ + auto now = std::chrono::system_clock::now(); + auto time_t = std::chrono::system_clock::to_time_t(now); + std::tm tm; + localtime_r(&time_t, &tm); + char buf[32]; + std::strftime(buf, sizeof(buf), "%Y%m%d_%H%M%S", &tm); + + timestamped_path_ = config_.output_path + "/" + buf; + std::filesystem::create_directories(timestamped_path_); + } +} const PrecisionCheckConfig &PrecisionCheckEnv::GetConfig() const { return config_; } +const std::string &PrecisionCheckEnv::GetOutputPath() const { return timestamped_path_; } + +int PrecisionCheckEnv::GetAndIncrementCounter(const std::string &key) { return g_tensor_counter[key]++; } + +void PrecisionCheckEnv::ResetCounters() { g_tensor_counter.clear(); } + } // namespace infini_train::utils diff --git a/infini_train/src/utils/precision_checker.cc b/infini_train/src/utils/precision_checker.cc index 60301aa9..6bb93b66 100644 --- a/infini_train/src/utils/precision_checker.cc +++ b/infini_train/src/utils/precision_checker.cc @@ -5,14 +5,13 @@ #include #include #include +#include #include #include #include #include #include #include -#include -#include #include #include "infini_train/include/autograd/function.h" @@ -153,123 +152,29 @@ std::string ComputeMD5(const void *data, size_t size) { return md5.Finalize(); } -// Baseline storage -std::unordered_map &GetBaseline() { - static std::unordered_map baseline; - static bool loaded = false; - static std::mutex load_mutex; - - if (!loaded) { - std::lock_guard lock(load_mutex); - if (!loaded) { - const auto &config = PrecisionCheckEnv::Instance().GetConfig(); - if (!config.baseline_path.empty()) { - std::ifstream file(config.baseline_path); - if (!file.is_open()) { - std::cerr << "[PrecisionCheck] Failed to open baseline file: " << config.baseline_path << std::endl; - } else { - std::string line; - while (std::getline(file, line)) { - // Try format 1: key|md5 - auto pipe_pos = line.rfind('|'); - if (pipe_pos != std::string::npos) { - std::string key = line.substr(0, pipe_pos); - std::string md5 = line.substr(pipe_pos + 1); - baseline[key] = md5; - } else { - // Try format 2: simple log format with "md5=" - auto md5_pos = line.find("md5="); - if (md5_pos != std::string::npos) { - // Extract md5 value - std::string md5 = line.substr(md5_pos + 4); - - // Extract key: find text between "][PrecisionCheck] " and ": md5=" - auto check_pos = line.find("][PrecisionCheck] "); - if (check_pos != std::string::npos) { - size_t key_start = check_pos + 18; // length of "][PrecisionCheck] " - size_t key_end = line.find(": md5=", key_start); - if (key_end != std::string::npos) { - std::string key = line.substr(key_start, key_end - key_start); - baseline[key] = md5; - } - } - } - } - } - std::cout << "[PrecisionCheck] Loaded " << baseline.size() << " baseline entries from " - << config.baseline_path << std::endl; - } - } - loaded = true; - } - } - return baseline; -} - -// Table header printed flag -bool &TableHeaderPrinted() { - thread_local bool printed = false; - return printed; -} - std::ostream &GetLogStream() { thread_local std::ofstream log_file; thread_local std::mutex init_mutex; thread_local bool initialized = false; - thread_local bool use_console = false; if (!initialized) { std::lock_guard lock(init_mutex); if (!initialized) { - const auto &config = PrecisionCheckEnv::Instance().GetConfig(); - - if (config.output_path.empty()) { - use_console = true; + const auto &output_path = PrecisionCheckEnv::Instance().GetOutputPath(); + int global_rank = nn::parallel::global::thread_global_rank; + std::string filename = output_path + "/precision_check_rank_" + std::to_string(global_rank) + ".log"; + log_file.open(filename, std::ios::out | std::ios::trunc); + if (!log_file.is_open()) { + std::cerr << "[Rank " << global_rank << "] Failed to open precision check log file: " << filename + << std::endl; } else { - // Create output directory if it doesn't exist - mkdir(config.output_path.c_str(), 0755); - - int global_rank = nn::parallel::global::thread_global_rank; - std::string filename - = config.output_path + "/precision_check_rank_" + std::to_string(global_rank) + ".log"; - log_file.open(filename, std::ios::out | std::ios::trunc); - if (!log_file.is_open()) { - std::cerr << "[Rank " << global_rank << "] Failed to open precision check log file: " << filename - << std::endl; - use_console = true; - } else { - use_console = false; - std::cout << "[Rank " << global_rank << "] Precision check output: " << filename << std::endl; - } + std::cout << "[Rank " << global_rank << "] Precision check output: " << filename << std::endl; } initialized = true; } } - return use_console ? std::cout : log_file; -} - -bool ShouldPrint() { - const auto &config = PrecisionCheckEnv::Instance().GetConfig(); - if (!config.output_path.empty()) { - return true; - } - return nn::parallel::global::GlobalEnv::Instance().global_proc_rank() == 0; -} - -std::string GetTimestamp() { - auto now = std::chrono::system_clock::now(); - auto time_t = std::chrono::system_clock::to_time_t(now); - auto ms = std::chrono::duration_cast(now.time_since_epoch()) % 1000; - - std::tm tm; - localtime_r(&time_t, &tm); - - std::ostringstream oss; - oss << std::setfill('0') << std::setw(2) << (tm.tm_mon + 1) << std::setw(2) << tm.tm_mday << ' ' << std::setw(2) - << tm.tm_hour << ':' << std::setw(2) << tm.tm_min << ':' << std::setw(2) << tm.tm_sec << '.' << std::setw(3) - << ms.count(); - return oss.str(); + return log_file.is_open() ? log_file : std::cout; } std::string FormatShape(const std::vector &shape) { @@ -277,7 +182,7 @@ std::string FormatShape(const std::vector &shape) { oss << "("; for (size_t i = 0; i < shape.size(); ++i) { if (i > 0) { - oss << ", "; + oss << ","; } oss << shape[i]; } @@ -302,63 +207,70 @@ std::string DataTypeToString(DataType dtype) { } } -void PrintTableHeader(std::ostream &os) { - if (TableHeaderPrinted()) { - return; +struct TensorStats { + float min_val = 0; + float max_val = 0; + float mean_val = 0; + int nan_count = 0; + int inf_count = 0; +}; + +TensorStats ComputeStats(const float *data, size_t num_elements) { + TensorStats stats; + if (num_elements == 0) { + return stats; } - TableHeaderPrinted() = true; - - os << "+" << std::string(50, '-') << "+" << std::string(7, '-') << "+" << std::string(18, '-') << "+" - << std::string(15, '-') << "+" << std::string(10, '-') << "+\n"; - os << "| " << std::left << std::setw(49) << "key" - << "| " << std::setw(6) << "level" - << "| " << std::setw(17) << "shape" - << "| " << std::setw(14) << "dtype" - << "| " << std::setw(9) << "same_hash" - << "|\n"; - os << "+" << std::string(50, '-') << "+" << std::string(7, '-') << "+" << std::string(18, '-') << "+" - << std::string(15, '-') << "+" << std::string(10, '-') << "+\n"; -} -void PrintTableRow(std::ostream &os, const std::string &key, int level, const std::string &shape, - const std::string &dtype, const std::string &same_hash) { - os << "| " << std::left << std::setw(49) << key.substr(0, 49) << "| " << std::setw(6) << level << "| " - << std::setw(17) << shape.substr(0, 17) << "| " << std::setw(14) << dtype << "| " << std::setw(9) << same_hash - << "|\n"; -} + stats.min_val = std::numeric_limits::max(); + stats.max_val = std::numeric_limits::lowest(); + double sum = 0; -// Calculate diff order between two tensors (returns string like "1e-3" or "0") -std::string CalculateDiffOrder(const float *data1, const float *data2, size_t size) { - if (!data1 || !data2 || size == 0) { - return "N/A"; + for (size_t i = 0; i < num_elements; ++i) { + float val = data[i]; + if (std::isnan(val)) { + stats.nan_count++; + continue; + } + if (std::isinf(val)) { + stats.inf_count++; + continue; + } + stats.min_val = std::min(stats.min_val, val); + stats.max_val = std::max(stats.max_val, val); + sum += val; } - double max_diff = 0.0; - for (size_t i = 0; i < size; ++i) { - double diff = std::abs(static_cast(data1[i]) - static_cast(data2[i])); - max_diff = std::max(max_diff, diff); - } + size_t valid_count = num_elements - stats.nan_count - stats.inf_count; + stats.mean_val = valid_count > 0 ? static_cast(sum / valid_count) : 0; - if (max_diff == 0.0) { - return "0"; - } + return stats; +} - int order = static_cast(std::floor(std::log10(max_diff))); - return "1e" + std::to_string(order); +void SaveNpy(const std::shared_ptr &tensor, const std::string &name, int idx, const std::string &stage, + int rank) { + const auto &output_path = PrecisionCheckEnv::Instance().GetOutputPath(); + std::string dir = output_path + "/rank_" + std::to_string(rank); + std::filesystem::create_directories(dir); + std::string filename = dir + "/" + name + "_" + std::to_string(idx) + "_" + stage + ".npy"; + + if (tensor->Dtype() == DataType::kFLOAT32) { + tensor->SaveAsNpy(filename); + } else { + auto float_tensor = tensor->To(DataType::kFLOAT32); + float_tensor.SaveAsNpy(filename); + } } } // namespace void PrecisionChecker::CheckTensors(const std::string &stage, const std::string &name, const std::vector> &tensors, const Config &config) { - if (!ShouldPrint()) { + const auto &global_config = PrecisionCheckEnv::Instance().GetConfig(); + if (global_config.level == PrecisionCheckLevel::OFF) { return; } - const auto &global_config = PrecisionCheckEnv::Instance().GetConfig(); const int rank = nn::parallel::global::thread_global_rank; - const auto level = global_config.level; - auto &baseline = GetBaseline(); for (size_t i = 0; i < tensors.size(); ++i) { if (!tensors[i]) { @@ -376,110 +288,61 @@ void PrecisionChecker::CheckTensors(const std::string &stage, const std::string cpu_tensor = tensor; } - const void *data = cpu_tensor->DataPtr(); + const float *float_data = static_cast(cpu_tensor->DataPtr()); const size_t byte_size = cpu_tensor->SizeInBytes(); const size_t num_elements = cpu_tensor->NumElements(); - // Build key + // Build context key const std::string context_key = PrecisionCheckContext::Instance().GetKey(); - const std::string full_key = context_key.empty() ? (stage + " " + name + " tensor[" + std::to_string(i) + "]") - : (context_key + " " + stage + " " + name); - - // Only compute MD5 if needed (for output or baseline comparison) - const bool need_md5 = global_config.output_md5 || !baseline.empty(); - std::string md5; - if (need_md5) { - md5 = ComputeMD5(data, byte_size); - } + const std::string stage_short = (stage.find("Forward") != std::string::npos) ? "forward" : "backward"; - // Check baseline - const bool has_baseline = !baseline.empty(); - bool same_hash = true; - if (has_baseline) { - auto it = baseline.find(full_key); - if (it == baseline.end() && !context_key.empty()) { - // Try without context: "stage name tensor[i]" - std::string key_without_context = stage + " " + name + " tensor[" + std::to_string(i) + "]"; - it = baseline.find(key_without_context); - } - if (it != baseline.end()) { - same_hash = (it->second == md5); - } + // Get tensor index for this (name, stage) combination + std::string counter_key = name + "_" + stage_short; + int idx = PrecisionCheckEnv::GetAndIncrementCounter(counter_key); + + // Save NPY if enabled + if (global_config.save_tensors) { + SaveNpy(cpu_tensor, name, idx, stage_short, rank); } + // Output to log auto &log_stream = GetLogStream(); - if (global_config.format == "table") { - thread_local bool header_printed = false; - if (!header_printed) { - PrintTableHeader(log_stream); - header_printed = true; - } - std::string same_hash_str = has_baseline ? (same_hash ? "True" : "False") : "--"; - PrintTableRow(log_stream, full_key, static_cast(level), FormatShape(cpu_tensor->Dims()), - DataTypeToString(cpu_tensor->Dtype()), same_hash_str); - - // Save to baseline file if output_path is set and output_md5 is true - if (!global_config.output_path.empty() && global_config.output_md5) { - log_stream << full_key << "|" << md5 << std::endl; - } + if (global_config.format == "md5") { + // MD5 format + std::string md5 = ComputeMD5(float_data, byte_size); + log_stream << context_key << " " << name << "_" << idx << "_" << stage << " tensor[" << i << "]: " + << "dtype=" << DataTypeToString(cpu_tensor->Dtype()) << " " + << "shape=" << FormatShape(cpu_tensor->Dims()) << " " + << "md5=" << md5 << std::endl; } else { - // Simple format - const float *float_data = static_cast(data); - - bool has_nan = false; - bool has_inf = false; - for (size_t j = 0; j < num_elements; ++j) { - float val = float_data[j]; - if (std::isnan(val)) { - has_nan = true; - } - if (std::isinf(val)) { - has_inf = true; + // Simple format (default) + TensorStats stats = ComputeStats(float_data, num_elements); + + const bool has_error + = (config.check_nan && stats.nan_count > 0) || (config.check_inf && stats.inf_count > 0); + const std::string error_marker = has_error ? " <- ERROR" : ""; + + log_stream << context_key << " " << name << "_" << idx << "_" << stage << " tensor[" << i << "]: " + << "dtype=" << DataTypeToString(cpu_tensor->Dtype()) << " " + << "shape=" << FormatShape(cpu_tensor->Dims()) << " " + << "min=" << stats.min_val << " " + << "max=" << stats.max_val << " " + << "mean=" << stats.mean_val << " ["; + + // Print first 6 values + constexpr size_t max_print = 6; + for (size_t j = 0; j < std::min(num_elements, max_print); ++j) { + if (j > 0) { + log_stream << ", "; } + log_stream << float_data[j]; } - - const bool has_error = (config.check_nan && has_nan) || (config.check_inf && has_inf); - - // When output_path is set, always write to file; otherwise only write on error or if print_stats is enabled - const bool should_output = !global_config.output_path.empty() || has_error || config.print_stats; - - if (should_output) { - const std::string log_level = has_error ? "E" : "I"; - - log_stream << log_level << GetTimestamp() << " [Rank " << rank << "][PrecisionCheck] " << stage << " " - << name << " tensor[" << i << "]: "; - - if (global_config.output_md5) { - log_stream << "md5=" << md5; - if (!same_hash) { - log_stream << " (MISMATCH)"; - } - } else { - log_stream << "["; - if (has_nan) { - log_stream << " NaN detected!"; - } - if (has_inf) { - log_stream << " Inf detected!"; - } - - if (config.print_stats) { - constexpr size_t max_print = 6; - for (size_t j = 0; j < std::min(num_elements, max_print); ++j) { - if (j > 0) { - log_stream << ", "; - } - log_stream << float_data[j]; - } - if (num_elements > max_print) { - log_stream << ", ..."; - } - } - log_stream << "]"; - } - log_stream << std::endl; + if (num_elements > max_print) { + log_stream << ", ..."; } + log_stream << "] [NaN:" << stats.nan_count << " Inf:" << stats.inf_count << "]" << error_marker + << std::endl; if (has_error && config.abort_on_error) { std::cerr << "Precision check failed, aborting!" << std::endl; @@ -520,14 +383,16 @@ void PrecisionChecker::RegisterForModule(nn::Module *module, const std::string & module->RegisterForwardPostHook([module_name, config](nn::Module *, const std::vector> &, const std::vector> &outputs) { - CheckTensors("Module Forward Output", module_name, outputs, config); + CheckTensors("Forward Output", module_name, outputs, config); }); module->RegisterBackwardPostHook([module_name, config](nn::Module *, const std::vector> &grad_inputs, const std::vector> &) { - CheckTensors("Module Backward Output", module_name, grad_inputs, config); + CheckTensors("Backward Output", module_name, grad_inputs, config); }); } +void PrecisionChecker::ResetCounters() { PrecisionCheckEnv::ResetCounters(); } + } // namespace infini_train::utils diff --git a/scripts/precision_check/precision_compare.py b/scripts/precision_check/precision_compare.py new file mode 100755 index 00000000..40c91308 --- /dev/null +++ b/scripts/precision_check/precision_compare.py @@ -0,0 +1,153 @@ +#!/usr/bin/env python3 +""" +Precision comparison tool for InfiniTrain tensor outputs. + +Usage: + python precision_compare.py --dir1 ./run1 --dir2 ./run2 [--atol 1e-5] [--rtol 1e-3] + +Compares .npy files between two directories and reports differences. +""" + +import argparse +import os +import sys +from pathlib import Path + +import numpy as np + + +def find_npy_files(directory: str) -> dict[str, Path]: + """Find all .npy files in directory (recursively).""" + files = {} + for path in Path(directory).rglob("*.npy"): + rel_path = path.relative_to(directory) + files[str(rel_path)] = path + return files + + +def compare_tensors(file1: Path, file2: Path, atol: float, rtol: float) -> dict: + """Compare two tensor files and return comparison results.""" + arr1 = np.load(file1) + arr2 = np.load(file2) + + result = { + "file": str(file1.name), + "shape1": arr1.shape, + "shape2": arr2.shape, + "dtype1": str(arr1.dtype), + "dtype2": str(arr2.dtype), + "match": False, + "error": None, + } + + if arr1.shape != arr2.shape: + result["error"] = f"Shape mismatch: {arr1.shape} vs {arr2.shape}" + return result + + if arr1.dtype != arr2.dtype: + result["error"] = f"Dtype mismatch: {arr1.dtype} vs {arr2.dtype}" + return result + + arr1_flat = arr1.astype(np.float64).flatten() + arr2_flat = arr2.astype(np.float64).flatten() + + abs_diff = np.abs(arr1_flat - arr2_flat) + max_abs_diff = np.max(abs_diff) + mean_abs_diff = np.mean(abs_diff) + + with np.errstate(divide="ignore", invalid="ignore"): + rel_diff = abs_diff / (np.abs(arr2_flat) + 1e-12) + rel_diff = np.where(np.isfinite(rel_diff), rel_diff, 0) + max_rel_diff = np.max(rel_diff) + mean_rel_diff = np.mean(rel_diff) + + result["max_abs_diff"] = float(max_abs_diff) + result["mean_abs_diff"] = float(mean_abs_diff) + result["max_rel_diff"] = float(max_rel_diff) + result["mean_rel_diff"] = float(mean_rel_diff) + result["match"] = np.allclose(arr1, arr2, atol=atol, rtol=rtol) + + return result + + +def main(): + parser = argparse.ArgumentParser(description="Compare precision check outputs") + parser.add_argument("--dir1", required=True, help="First directory") + parser.add_argument("--dir2", required=True, help="Second directory") + parser.add_argument("--atol", type=float, default=1e-5, help="Absolute tolerance") + parser.add_argument("--rtol", type=float, default=1e-3, help="Relative tolerance") + parser.add_argument("--verbose", "-v", action="store_true", help="Verbose output") + args = parser.parse_args() + + if not os.path.isdir(args.dir1): + print(f"Error: {args.dir1} is not a directory") + sys.exit(1) + if not os.path.isdir(args.dir2): + print(f"Error: {args.dir2} is not a directory") + sys.exit(1) + + files1 = find_npy_files(args.dir1) + files2 = find_npy_files(args.dir2) + + print(f"Directory 1: {args.dir1} ({len(files1)} files)") + print(f"Directory 2: {args.dir2} ({len(files2)} files)") + print(f"Tolerance: atol={args.atol}, rtol={args.rtol}") + print() + + only_in_1 = set(files1.keys()) - set(files2.keys()) + only_in_2 = set(files2.keys()) - set(files1.keys()) + common = set(files1.keys()) & set(files2.keys()) + + if only_in_1: + print(f"Files only in dir1 ({len(only_in_1)}):") + for f in sorted(only_in_1): + print(f" {f}") + print() + + if only_in_2: + print(f"Files only in dir2 ({len(only_in_2)}):") + for f in sorted(only_in_2): + print(f" {f}") + print() + + if not common: + print("No common files to compare") + sys.exit(1) + + print(f"Comparing {len(common)} common files...") + print() + + passed = 0 + failed = 0 + errors = 0 + + for rel_path in sorted(common): + result = compare_tensors(files1[rel_path], files2[rel_path], args.atol, args.rtol) + + if result["error"]: + errors += 1 + print(f"ERROR: {rel_path}") + print(f" {result['error']}") + elif result["match"]: + passed += 1 + if args.verbose: + print(f"PASS: {rel_path}") + print(f" max_abs={result['max_abs_diff']:.2e} max_rel={result['max_rel_diff']:.2e}") + else: + failed += 1 + print(f"FAIL: {rel_path}") + print(f" shape={result['shape1']} dtype={result['dtype1']}") + print(f" max_abs={result['max_abs_diff']:.2e} mean_abs={result['mean_abs_diff']:.2e}") + print(f" max_rel={result['max_rel_diff']:.2e} mean_rel={result['mean_rel_diff']:.2e}") + + print() + print("=" * 50) + print(f"Summary: {passed} passed, {failed} failed, {errors} errors") + print(f"Missing: {len(only_in_1)} in dir1 only, {len(only_in_2)} in dir2 only") + + if failed > 0 or errors > 0: + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/scripts/precision_check/run_precision_check_gpt2.sh b/scripts/precision_check/run_precision_check_gpt2.sh new file mode 100644 index 00000000..a0b3cab9 --- /dev/null +++ b/scripts/precision_check/run_precision_check_gpt2.sh @@ -0,0 +1,90 @@ +#!/bin/bash +# InfiniTrain Precision Checker - GPT2 +set -e + +# Configuration +BIN="./build/gpt2" +MODEL_ARGS="--device cuda --input_bin /data/shared/InfiniTrain-dev/data/llmc/gpt2/tinyshakespeare/tiny_shakespeare_train.bin --llmc_filepath /data/shared/InfiniTrain-dev/data/llmc/gpt2/gpt2_124M.bin" +OUTPUT_DIR="./log_precision_check_gpt2" +COMPARE_SCRIPT="scripts/precision_check/precision_compare.py" + +echo "=== InfiniTrain Precision Checker - GPT2 ===" + +if [ ! -f "$BIN" ]; then + echo "Error: $BIN not found. Please build the project first." + exit 1 +fi + +if [ ! -f "$COMPARE_SCRIPT" ]; then + echo "Error: $COMPARE_SCRIPT not found." + exit 1 +fi + +# Clean test directory +rm -rf "$OUTPUT_DIR" +mkdir -p "$OUTPUT_DIR" + +# 1. Single-rank test - Simple format +echo "" +echo "=== 1. Single-rank test (Simple format) ===" +CMD="$BIN $MODEL_ARGS --precision_check \"level=1,path=$OUTPUT_DIR/test1,format=simple,save_tensors=true\" --num_iteration 1" +echo "Running: $CMD" +eval $CMD + +TIMESTAMP_DIR=$(ls -t "$OUTPUT_DIR/test1" | head -1) +echo "Timestamp directory: $TIMESTAMP_DIR" +NPY_COUNT=$(ls "$OUTPUT_DIR/test1/$TIMESTAMP_DIR/rank_0"/*.npy 2>/dev/null | wc -l) +echo "Rank 0 NPY files: $NPY_COUNT" +LOG_FILE=$(ls "$OUTPUT_DIR/test1/$TIMESTAMP_DIR"/*.log 2>/dev/null | head -1) +echo "Log file: $LOG_FILE" + +# 2. Single-rank test - MD5 format +echo "" +echo "=== 2. Single-rank test (MD5 format) ===" +CMD="$BIN $MODEL_ARGS --precision_check \"level=1,path=$OUTPUT_DIR/test2,format=md5\" --num_iteration 1" +echo "Running: $CMD" +eval $CMD + +TIMESTAMP_DIR=$(ls -t "$OUTPUT_DIR/test2" | head -1) + +# 3. Multi-iter overwrite test +echo "" +echo "=== 3. Multi-iter overwrite test ===" +CMD="$BIN $MODEL_ARGS --precision_check \"level=1,path=$OUTPUT_DIR/test3,save_tensors=true\" --num_iteration 3" +echo "Running: $CMD" +eval $CMD + +TIMESTAMP_DIR=$(ls -t "$OUTPUT_DIR/test3" | head -1) +FILE_COUNT=$(ls "$OUTPUT_DIR/test3/$TIMESTAMP_DIR/rank_0"/*.npy 2>/dev/null | wc -l) +echo "Files after 3 iters: $FILE_COUNT (should be same as 1 iter - files overwritten)" + +# 4. Comparison test +echo "" +echo "=== 4. Comparison test ===" +CMD="$BIN $MODEL_ARGS --precision_check \"level=1,path=$OUTPUT_DIR/run1,save_tensors=true\" --num_iteration 1" +echo "Running: $CMD" +eval $CMD +sleep 2 +CMD="$BIN $MODEL_ARGS --precision_check \"level=1,path=$OUTPUT_DIR/run2,save_tensors=true\" --num_iteration 1" +echo "Running: $CMD" +eval $CMD + +RUN1_DIR="$OUTPUT_DIR/run1/$(ls -t "$OUTPUT_DIR/run1" | head -1)" +RUN2_DIR="$OUTPUT_DIR/run2/$(ls -t "$OUTPUT_DIR/run2" | head -1)" + +echo "Comparing directories:" +echo " Run1: $RUN1_DIR" +echo " Run2: $RUN2_DIR" + +python "$COMPARE_SCRIPT" --dir1 "$RUN1_DIR" --dir2 "$RUN2_DIR" --atol 1e-5 --rtol 1e-3 || true + +# 5. Multi-rank test (if available) +echo "" +echo "=== 5. Multi-rank test ===" +CMD="$BIN $MODEL_ARGS --nthread_per_process 8 --tensor_parallel 4 --pipeline_parallel 2 --precision_check \"level=1,path=$OUTPUT_DIR/test_multi,save_tensors=true\" --num_iteration 1" +echo "Running: $CMD" +eval $CMD + +echo "" +echo "=== Verification Complete ===" +echo "Test output directory: $OUTPUT_DIR" diff --git a/scripts/precision_check/run_precision_check_llama3.sh b/scripts/precision_check/run_precision_check_llama3.sh new file mode 100644 index 00000000..d535ea15 --- /dev/null +++ b/scripts/precision_check/run_precision_check_llama3.sh @@ -0,0 +1,90 @@ +#!/bin/bash +# InfiniTrain Precision Checker - Llama3 +set -e + +# Configuration +BIN="./build/llama3" +MODEL_ARGS="--device cuda --input_bin /data/shared/InfiniTrain-dev/data/llmc/llama3/tinyshakespeare/tiny_shakespeare_train.bin --llmc_filepath /data/shared/InfiniTrain-dev/data/llmc/llama3/llama3.2_1B_fp32.bin" +OUTPUT_DIR="./log_precision_check_llama3" +COMPARE_SCRIPT="scripts/precision_check/precision_compare.py" + +echo "=== InfiniTrain Precision Checker - Llama3 ===" + +if [ ! -f "$BIN" ]; then + echo "Error: $BIN not found. Please build the project first." + exit 1 +fi + +if [ ! -f "$COMPARE_SCRIPT" ]; then + echo "Error: $COMPARE_SCRIPT not found." + exit 1 +fi + +# Clean test directory +rm -rf "$OUTPUT_DIR" +mkdir -p "$OUTPUT_DIR" + +# 1. Single-rank test - Simple format +echo "" +echo "=== 1. Single-rank test (Simple format) ===" +CMD="$BIN $MODEL_ARGS --precision_check \"level=1,path=$OUTPUT_DIR/test1,format=simple,save_tensors=true\" --num_iteration 1" +echo "Running: $CMD" +eval $CMD + +TIMESTAMP_DIR=$(ls -t "$OUTPUT_DIR/test1" | head -1) +echo "Timestamp directory: $TIMESTAMP_DIR" +NPY_COUNT=$(ls "$OUTPUT_DIR/test1/$TIMESTAMP_DIR/rank_0"/*.npy 2>/dev/null | wc -l) +echo "Rank 0 NPY files: $NPY_COUNT" +LOG_FILE=$(ls "$OUTPUT_DIR/test1/$TIMESTAMP_DIR"/*.log 2>/dev/null | head -1) +echo "Log file: $LOG_FILE" + +# 2. Single-rank test - MD5 format +echo "" +echo "=== 2. Single-rank test (MD5 format) ===" +CMD="$BIN $MODEL_ARGS --precision_check \"level=1,path=$OUTPUT_DIR/test2,format=md5\" --num_iteration 1" +echo "Running: $CMD" +eval $CMD + +TIMESTAMP_DIR=$(ls -t "$OUTPUT_DIR/test2" | head -1) + +# 3. Multi-iter overwrite test +echo "" +echo "=== 3. Multi-iter overwrite test ===" +CMD="$BIN $MODEL_ARGS --precision_check \"level=1,path=$OUTPUT_DIR/test3,save_tensors=true\" --num_iteration 3" +echo "Running: $CMD" +eval $CMD + +TIMESTAMP_DIR=$(ls -t "$OUTPUT_DIR/test3" | head -1) +FILE_COUNT=$(ls "$OUTPUT_DIR/test3/$TIMESTAMP_DIR/rank_0"/*.npy 2>/dev/null | wc -l) +echo "Files after 3 iters: $FILE_COUNT (should be same as 1 iter - files overwritten)" + +# 4. Comparison test +echo "" +echo "=== 4. Comparison test ===" +CMD="$BIN $MODEL_ARGS --precision_check \"level=1,path=$OUTPUT_DIR/run1,save_tensors=true\" --num_iteration 1" +echo "Running: $CMD" +eval $CMD +sleep 2 +CMD="$BIN $MODEL_ARGS --precision_check \"level=1,path=$OUTPUT_DIR/run2,save_tensors=true\" --num_iteration 1" +echo "Running: $CMD" +eval $CMD + +RUN1_DIR="$OUTPUT_DIR/run1/$(ls -t "$OUTPUT_DIR/run1" | head -1)" +RUN2_DIR="$OUTPUT_DIR/run2/$(ls -t "$OUTPUT_DIR/run2" | head -1)" + +echo "Comparing directories:" +echo " Run1: $RUN1_DIR" +echo " Run2: $RUN2_DIR" + +python "$COMPARE_SCRIPT" --dir1 "$RUN1_DIR" --dir2 "$RUN2_DIR" --atol 1e-5 --rtol 1e-3 || true + +# 5. Multi-rank test (if available) +echo "" +echo "=== 5. Multi-rank test ===" +CMD="$BIN $MODEL_ARGS --nthread_per_process 8 --tensor_parallel 4 --pipeline_parallel 2 --precision_check \"level=1,path=$OUTPUT_DIR/test_multi,save_tensors=true\" --num_iteration 1" +echo "Running: $CMD" +eval $CMD + +echo "" +echo "=== Verification Complete ===" +echo "Test output directory: $OUTPUT_DIR" \ No newline at end of file