Skip to content

Commit eef8bfc

Browse files
refatcor: separate the weight loading in the npu layer class.
1 parent 559240b commit eef8bfc

File tree

58 files changed

+5235
-3699
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

58 files changed

+5235
-3699
lines changed

xllm/core/layers/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,11 @@ cc_library(
3737
HDRS
3838
attention_mask.h
3939
base_layer.h
40+
base_loader.h
4041
SRCS
4142
attention_mask.cpp
4243
base_layer.cpp
44+
base_loader.cpp
4345
DEPS
4446
:state_dict
4547
:block

xllm/core/layers/base_layer.h

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,7 @@ limitations under the License.
2727
#include <string>
2828
#include <vector>
2929

30-
#include "framework/kv_cache/kv_cache.h"
31-
#include "framework/model/model_input_params.h"
32-
#include "framework/model_context.h"
33-
#include "framework/state_dict/state_dict.h"
34-
30+
#include "base_loader.h"
3531
namespace xllm {
3632
namespace layer {
3733

@@ -92,14 +88,36 @@ class BaseLayer : public torch::nn::Module {
9288

9389
virtual ~BaseLayer() {};
9490

95-
virtual void load_state_dict(const StateDict& state_dict) {};
91+
virtual void load_state_dict(const StateDict& state_dict) {
92+
if (loader_) {
93+
loader_->load_state_dict(state_dict);
94+
}
95+
};
96+
97+
virtual void verify_loaded_weights() const {
98+
if (loader_) {
99+
loader_->verify_loaded_weights();
100+
}
101+
};
96102

97-
virtual void verify_loaded_weights() const {};
103+
virtual void verify_loaded_weights(const std::string& prefix) const {
104+
if (loader_) {
105+
loader_->verify_loaded_weights(prefix);
106+
}
107+
};
98108

99-
virtual void merge_loaded_weights() {};
109+
virtual void merge_loaded_weights() {
110+
if (loader_) {
111+
loader_->merge_loaded_weights();
112+
}
113+
init_layer();
114+
};
100115

101116
virtual int64_t init_layer() { return 0; };
102117

118+
virtual void run_task(std::string taskName, std::function<int()> task) const {
119+
};
120+
103121
void set_weight(const StateDict& state_dict,
104122
const std::string& tensor_name,
105123
int weight_position,
@@ -116,16 +134,14 @@ class BaseLayer : public torch::nn::Module {
116134
int rank,
117135
int world_size);
118136

119-
virtual void run_task(std::string taskName, std::function<int()> task) const {
120-
};
121-
122137
torch::Dtype string2dtype(const std::string& dtype_str);
123138

124139
void correct_tensor_dtype(torch::Tensor& tensor,
125140
const std::string& tensorName);
126141

127142
protected:
128143
std::vector<at::Tensor> at_weight_tensors_;
144+
std::unique_ptr<BaseLoader> loader_ = nullptr;
129145
at::Device device_;
130146
std::string name_;
131147
torch::ScalarType dtype_;

xllm/core/layers/base_loader.cpp

Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
/* Copyright 2025 The xLLM Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
https://github.com/jd-opensource/xllm/blob/main/LICENSE
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#include "base_loader.h"
17+
18+
namespace xllm {
19+
namespace layer {
20+
21+
BaseLoader::BaseLoader(uint64_t weight_count, const ModelContext& context)
22+
: weight_count_(weight_count),
23+
parallel_args_(context.get_parallel_args()),
24+
device_(context.get_tensor_options().device()) {
25+
auto quant_args = context.get_quant_args();
26+
if (!quant_args.quantize_type().empty()) {
27+
quantize_type_ = quant_args.quantize_type();
28+
}
29+
30+
if (!quant_args.torch_dtype().empty()) {
31+
torch_dtype_ = quant_args.torch_dtype();
32+
}
33+
34+
dp_size_ = parallel_args_.dp_size();
35+
dp_local_tp_size_ = parallel_args_.world_size() / dp_size_;
36+
dp_rank_ = parallel_args_.rank() / dp_local_tp_size_;
37+
CHECK_EQ(parallel_args_.world_size(), dp_size_ * dp_local_tp_size_);
38+
dp_local_tp_rank_ = parallel_args_.rank() % dp_local_tp_size_;
39+
40+
at_weight_tensors_.resize(weight_count_);
41+
}
42+
43+
void BaseLoader::set_weight(const StateDict& state_dict,
44+
const std::string& tensor_name,
45+
int weight_position) {
46+
for (const auto& [name, tensor] : state_dict) {
47+
if (absl::EndsWith(name, tensor_name)) {
48+
at::Tensor mutable_tensor = tensor;
49+
correct_tensor_dtype(mutable_tensor, tensor_name);
50+
at_weight_tensors_[weight_position] = mutable_tensor.to(device_);
51+
}
52+
}
53+
}
54+
55+
void BaseLoader::set_weight(const StateDict& state_dict,
56+
const std::string& tensor_name,
57+
int weight_position,
58+
int dim) {
59+
for (const auto& [name, tensor] : state_dict) {
60+
if (absl::EndsWith(name, tensor_name)) {
61+
if (parallel_args_.world_size() <= 1) {
62+
at::Tensor mutable_tensor = tensor;
63+
correct_tensor_dtype(mutable_tensor, tensor_name);
64+
at_weight_tensors_[weight_position] = mutable_tensor.to(device_);
65+
} else {
66+
at_weight_tensors_[weight_position] =
67+
state_dict
68+
.get_sharded_tensor(tensor_name,
69+
/*dim=*/dim,
70+
/*rank=*/parallel_args_.rank(),
71+
/*world_size=*/parallel_args_.world_size())
72+
.to(device_);
73+
correct_tensor_dtype(at_weight_tensors_[weight_position], tensor_name);
74+
}
75+
}
76+
}
77+
}
78+
79+
void BaseLoader::set_weight(const StateDict& state_dict,
80+
const std::string& tensor_name,
81+
int weight_position,
82+
int dim,
83+
int rank,
84+
int world_size) {
85+
for (const auto& [name, tensor] : state_dict) {
86+
if (absl::EndsWith(name, tensor_name)) {
87+
if (world_size <= 1) {
88+
at::Tensor mutable_tensor = tensor;
89+
correct_tensor_dtype(mutable_tensor, tensor_name);
90+
at_weight_tensors_[weight_position] = mutable_tensor.to(device_);
91+
} else {
92+
at_weight_tensors_[weight_position] =
93+
state_dict
94+
.get_sharded_tensor(tensor_name,
95+
/*dim=*/dim,
96+
/*rank=*/rank,
97+
/*world_size=*/world_size)
98+
.to(device_);
99+
correct_tensor_dtype(at_weight_tensors_[weight_position], tensor_name);
100+
}
101+
}
102+
}
103+
}
104+
105+
void BaseLoader::correct_tensor_dtype(torch::Tensor& tensor,
106+
const std::string& tensorName) {
107+
if (absl::EndsWith(tensorName, "deq_scale") &&
108+
(torch_dtype_.compare("bfloat16") == 0)) {
109+
return;
110+
}
111+
112+
if (tensor.dtype() != torch::kInt8 && tensor.dtype() != torch::kInt32 &&
113+
tensor.dtype() != torch::kInt64) {
114+
torch::Dtype dtype = string2dtype(torch_dtype_);
115+
tensor = tensor.to(dtype);
116+
}
117+
}
118+
119+
torch::Dtype BaseLoader::string2dtype(const std::string& dtype_str) {
120+
if (dtype_str.compare("float16") == 0) {
121+
return torch::kFloat16;
122+
} else if (dtype_str.compare("bfloat16") == 0) {
123+
return torch::kBFloat16;
124+
} else if (dtype_str.compare("float32") == 0) {
125+
return torch::kFloat32;
126+
} else if (dtype_str.compare("float64") == 0) {
127+
return torch::kFloat64;
128+
} else if (dtype_str.compare("int8") == 0) {
129+
return torch::kInt8;
130+
} else if (dtype_str.compare("int16") == 0) {
131+
return torch::kInt16;
132+
} else if (dtype_str.compare("int32") == 0) {
133+
return torch::kInt32;
134+
} else if (dtype_str.compare("int64") == 0) {
135+
return torch::kInt64;
136+
} else if (dtype_str.compare("uint8") == 0) {
137+
return torch::kUInt8;
138+
} else if (dtype_str.compare("bool") == 0) {
139+
return torch::kBool;
140+
}
141+
142+
LOG(FATAL) << "Unsupported dtype string: " << dtype_str;
143+
}
144+
145+
} // namespace layer
146+
} // namespace xllm

xllm/core/layers/base_loader.h

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
/* Copyright 2025 The xLLM Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
https://github.com/jd-opensource/xllm/blob/main/LICENSE
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#pragma once
17+
18+
#include <absl/strings/match.h>
19+
#include <torch/torch.h>
20+
21+
#include "framework/eplb/expert_buffer_manager.h"
22+
#include "framework/kv_cache/kv_cache.h"
23+
#include "framework/model/model_input_params.h"
24+
#include "framework/model_context.h"
25+
#include "framework/state_dict/state_dict.h"
26+
#include "xllm_kernels/core/include/atb_speed/base/hosttensor_binder.h"
27+
#include "xllm_kernels/core/include/atb_speed/base/model.h"
28+
#include "xllm_kernels/core/include/atb_speed/log.h"
29+
#include "xllm_kernels/core/include/atb_speed/utils/model_factory.h"
30+
#include "xllm_kernels/models/qwen2/layer/decoder_layer.h"
31+
#include "xllm_kernels/pytorch/atb_torch/core/include/base_operation.h"
32+
#include "xllm_kernels/pytorch/atb_torch/core/include/graph_operation.h"
33+
34+
namespace xllm {
35+
namespace layer {
36+
37+
class BaseLoader {
38+
public:
39+
explicit BaseLoader(uint64_t weight_count, const ModelContext& context);
40+
virtual ~BaseLoader() = default;
41+
42+
virtual void load_state_dict(const StateDict& state_dict) {};
43+
virtual void verify_loaded_weights() const {};
44+
virtual void verify_loaded_weights(const std::string& prefix) const {};
45+
virtual void merge_loaded_weights() {};
46+
virtual void resize_experts_weights(int num_of_device_experts) {};
47+
torch::Dtype string2dtype(const std::string& dtype_str);
48+
49+
void correct_tensor_dtype(torch::Tensor& tensor,
50+
const std::string& tensorName);
51+
52+
std::vector<at::Tensor>& get_at_weight_tensors() {
53+
return at_weight_tensors_;
54+
}
55+
56+
std::unordered_map<std::string, std::vector<torch::Tensor>>&
57+
get_experts_weight_tensors() {
58+
return experts_weights_;
59+
}
60+
61+
std::unique_ptr<ExpertBufferManager>& get_expert_shared_buffer() {
62+
return shared_buffer_;
63+
}
64+
65+
std::vector<int32_t>& get_device_expert_list() { return device_expert_list_; }
66+
67+
atb_torch::TorchTensorMap& get_weights_map() { return weights_map_; }
68+
69+
protected:
70+
uint64_t weight_count_;
71+
xllm::ParallelArgs parallel_args_;
72+
std::string quantize_type_;
73+
std::string torch_dtype_;
74+
torch::ScalarType dtype_;
75+
torch::TensorOptions options_;
76+
std::vector<at::Tensor> at_weight_tensors_;
77+
std::unique_ptr<ExpertBufferManager> shared_buffer_ = nullptr;
78+
std::unordered_map<std::string, torch::Tensor> shared_experts_weights_;
79+
std::unordered_map<std::string, std::vector<torch::Tensor>> experts_weights_;
80+
std::vector<int32_t> device_expert_list_;
81+
atb_torch::TorchTensorMap weights_map_;
82+
83+
at::Device device_;
84+
int32_t dp_size_;
85+
int32_t dp_local_tp_size_;
86+
int32_t dp_rank_;
87+
int32_t dp_local_tp_rank_;
88+
89+
void set_weight(const StateDict& state_dict,
90+
const std::string& tensor_name,
91+
int weight_position);
92+
93+
void set_weight(const StateDict& state_dict,
94+
const std::string& tensor_name,
95+
int weight_position,
96+
int dim);
97+
98+
void set_weight(const StateDict& state_dict,
99+
const std::string& tensor_name,
100+
int weight_position,
101+
int dim,
102+
int rank,
103+
int world_size);
104+
};
105+
106+
} // namespace layer
107+
} // namespace xllm

xllm/core/layers/npu/CMakeLists.txt

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,19 @@ cc_library(
2525
npu_qwen3_decoder_layer_impl.h
2626
npu_rms_norm_impl.h
2727
npu_siglip_encoder_layer_impl.h
28+
loader/qwen3_decoder_loader.h
29+
loader/qwen2_decoder_loader.h
30+
loader/qwen3_moe_decoder_loader.h
31+
loader/word_embedding_loader.h
32+
loader/lm_head_loader.h
33+
loader/column_parallel_linear_loader.h
34+
loader/deepseek_v2_decoder_loader.h
35+
loader/glm4_moe_decoder_loader.h
36+
loader/llama_decoder_loader.h
37+
loader/qwen2dot5_vision_encoder_loader.h
38+
loader/qwen3_vision_encoder_loader.h
39+
loader/rms_norm_loader.h
40+
loader/siglip_encoder_loader.h
2841
SRCS
2942
npu_word_embedding_impl.cpp
3043
npu_pos_embedding_impl.cpp
@@ -45,6 +58,19 @@ cc_library(
4558
npu_qwen3_decoder_layer_impl.cpp
4659
npu_rms_norm_impl.cpp
4760
npu_siglip_encoder_layer_impl.cpp
61+
loader/qwen3_decoder_loader.cpp
62+
loader/qwen2_decoder_loader.cpp
63+
loader/qwen3_moe_decoder_loader.cpp
64+
loader/word_embedding_loader.cpp
65+
loader/lm_head_loader.cpp
66+
loader/column_parallel_linear_loader.cpp
67+
loader/deepseek_v2_decoder_loader.cpp
68+
loader/glm4_moe_decoder_loader.cpp
69+
loader/llama_decoder_loader.cpp
70+
loader/qwen2dot5_vision_encoder_loader.cpp
71+
loader/qwen3_vision_encoder_loader.cpp
72+
loader/rms_norm_loader.cpp
73+
loader/siglip_encoder_loader.cpp
4874
DEPS
4975
"-Wl,--whole-archive"
5076
"-Wl,--no-whole-archive"

0 commit comments

Comments
 (0)