Commit aaa9676

refactor: separate the weight loading in the npu layer class.
1 parent 559240b commit aaa9676

File tree: 57 files changed (+5240 −3699 lines)

Note: this is a large commit, so only part of the diff is shown below.

xllm/core/layers/base_layer.h

Lines changed: 39 additions & 11 deletions
@@ -27,11 +27,9 @@ limitations under the License.
 #include <string>
 #include <vector>
 
-#include "framework/kv_cache/kv_cache.h"
-#include "framework/model/model_input_params.h"
-#include "framework/model_context.h"
-#include "framework/state_dict/state_dict.h"
-
+#if defined(USE_NPU)
+#include "npu/loader/base_loader.h"
+#endif
 namespace xllm {
 namespace layer {
 
@@ -92,14 +90,44 @@ class BaseLayer : public torch::nn::Module {
 
   virtual ~BaseLayer() {};
 
-  virtual void load_state_dict(const StateDict& state_dict) {};
+  virtual void load_state_dict(const StateDict& state_dict) {
+#if defined(USE_NPU)
+    if (loader_) {
+      loader_->load_state_dict(state_dict);
+    }
+#endif
+  };
+
+  virtual void verify_loaded_weights() const {
+#if defined(USE_NPU)
+    if (loader_) {
+      loader_->verify_loaded_weights();
+    }
+#endif
+  };
 
-  virtual void verify_loaded_weights() const {};
+  virtual void verify_loaded_weights(const std::string& prefix) const {
+#if defined(USE_NPU)
+    if (loader_) {
+      loader_->verify_loaded_weights(prefix);
+    }
+#endif
+  };
 
-  virtual void merge_loaded_weights() {};
+  virtual void merge_loaded_weights() {
+#if defined(USE_NPU)
+    if (loader_) {
+      loader_->merge_loaded_weights();
+    }
+    init_layer();
+#endif
+  };
 
   virtual int64_t init_layer() { return 0; };
 
+  virtual void run_task(std::string taskName, std::function<int()> task) const {
+  };
+
   void set_weight(const StateDict& state_dict,
                   const std::string& tensor_name,
                   int weight_position,
@@ -116,16 +144,16 @@ class BaseLayer : public torch::nn::Module {
                   int rank,
                   int world_size);
 
-  virtual void run_task(std::string taskName, std::function<int()> task) const {
-  };
-
   torch::Dtype string2dtype(const std::string& dtype_str);
 
   void correct_tensor_dtype(torch::Tensor& tensor,
                             const std::string& tensorName);
 
  protected:
   std::vector<at::Tensor> at_weight_tensors_;
+#if defined(USE_NPU)
+  std::unique_ptr<BaseLoader> loader_ = nullptr;
+#endif
   at::Device device_;
   std::string name_;
   torch::ScalarType dtype_;
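With these defaults in place, a concrete NPU layer gets loading, verification, and merging for free once it installs a loader_. A minimal sketch of the pattern; the constructor signature and weight count here are assumptions for illustration, not part of this diff:

// Sketch only (not in this commit): RmsNormLoader is one of the loaders
// added under xllm/core/layers/npu/loader/, but its constructor signature
// and the layer's own construction details are assumed here.
class NpuRmsNormImpl : public BaseLayer {
 public:
  explicit NpuRmsNormImpl(const ModelContext& context) {
#if defined(USE_NPU)
    // Install the matching loader; the inherited load_state_dict(),
    // verify_loaded_weights(), and merge_loaded_weights() delegate to it,
    // and merge_loaded_weights() finishes by calling init_layer().
    loader_ = std::make_unique<RmsNormLoader>(/*weight_count=*/1, context);
#endif
  }
};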

xllm/core/layers/npu/CMakeLists.txt

Lines changed: 28 additions & 0 deletions
@@ -25,6 +25,20 @@ cc_library(
     npu_qwen3_decoder_layer_impl.h
     npu_rms_norm_impl.h
     npu_siglip_encoder_layer_impl.h
+    loader/qwen3_decoder_loader.h
+    loader/qwen2_decoder_loader.h
+    loader/qwen3_moe_decoder_loader.h
+    loader/word_embedding_loader.h
+    loader/lm_head_loader.h
+    loader/column_parallel_linear_loader.h
+    loader/deepseek_v2_decoder_loader.h
+    loader/glm4_moe_decoder_loader.h
+    loader/llama_decoder_loader.h
+    loader/qwen2dot5_vision_encoder_loader.h
+    loader/qwen3_vision_encoder_loader.h
+    loader/rms_norm_loader.h
+    loader/siglip_encoder_loader.h
+    loader/base_loader.h
   SRCS
     npu_word_embedding_impl.cpp
     npu_pos_embedding_impl.cpp
@@ -45,6 +59,20 @@ cc_library(
     npu_qwen3_decoder_layer_impl.cpp
     npu_rms_norm_impl.cpp
     npu_siglip_encoder_layer_impl.cpp
+    loader/qwen3_decoder_loader.cpp
+    loader/qwen2_decoder_loader.cpp
+    loader/qwen3_moe_decoder_loader.cpp
+    loader/word_embedding_loader.cpp
+    loader/lm_head_loader.cpp
+    loader/column_parallel_linear_loader.cpp
+    loader/deepseek_v2_decoder_loader.cpp
+    loader/glm4_moe_decoder_loader.cpp
+    loader/llama_decoder_loader.cpp
+    loader/qwen2dot5_vision_encoder_loader.cpp
+    loader/qwen3_vision_encoder_loader.cpp
+    loader/rms_norm_loader.cpp
+    loader/siglip_encoder_loader.cpp
+    loader/base_loader.cpp
   DEPS
     "-Wl,--whole-archive"
     "-Wl,--no-whole-archive"
xllm/core/layers/npu/loader/base_loader.cpp (new file)

Lines changed: 146 additions & 0 deletions
@@ -0,0 +1,146 @@
+/* Copyright 2025 The xLLM Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    https://github.com/jd-opensource/xllm/blob/main/LICENSE
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "base_loader.h"
+
+namespace xllm {
+namespace layer {
+
+BaseLoader::BaseLoader(uint64_t weight_count, const ModelContext& context)
+    : weight_count_(weight_count),
+      parallel_args_(context.get_parallel_args()),
+      device_(context.get_tensor_options().device()) {
+  auto quant_args = context.get_quant_args();
+  if (!quant_args.quantize_type().empty()) {
+    quantize_type_ = quant_args.quantize_type();
+  }
+
+  if (!quant_args.torch_dtype().empty()) {
+    torch_dtype_ = quant_args.torch_dtype();
+  }
+
+  dp_size_ = parallel_args_.dp_size();
+  dp_local_tp_size_ = parallel_args_.world_size() / dp_size_;
+  dp_rank_ = parallel_args_.rank() / dp_local_tp_size_;
+  CHECK_EQ(parallel_args_.world_size(), dp_size_ * dp_local_tp_size_);
+  dp_local_tp_rank_ = parallel_args_.rank() % dp_local_tp_size_;
+
+  at_weight_tensors_.resize(weight_count_);
+}
+
+void BaseLoader::set_weight(const StateDict& state_dict,
+                            const std::string& tensor_name,
+                            int weight_position) {
+  for (const auto& [name, tensor] : state_dict) {
+    if (absl::EndsWith(name, tensor_name)) {
+      at::Tensor mutable_tensor = tensor;
+      correct_tensor_dtype(mutable_tensor, tensor_name);
+      at_weight_tensors_[weight_position] = mutable_tensor.to(device_);
+    }
+  }
+}
+
+void BaseLoader::set_weight(const StateDict& state_dict,
+                            const std::string& tensor_name,
+                            int weight_position,
+                            int dim) {
+  for (const auto& [name, tensor] : state_dict) {
+    if (absl::EndsWith(name, tensor_name)) {
+      if (parallel_args_.world_size() <= 1) {
+        at::Tensor mutable_tensor = tensor;
+        correct_tensor_dtype(mutable_tensor, tensor_name);
+        at_weight_tensors_[weight_position] = mutable_tensor.to(device_);
+      } else {
+        at_weight_tensors_[weight_position] =
+            state_dict
+                .get_sharded_tensor(tensor_name,
+                                    /*dim=*/dim,
+                                    /*rank=*/parallel_args_.rank(),
+                                    /*world_size=*/parallel_args_.world_size())
+                .to(device_);
+        correct_tensor_dtype(at_weight_tensors_[weight_position], tensor_name);
+      }
+    }
+  }
+}
+
+void BaseLoader::set_weight(const StateDict& state_dict,
+                            const std::string& tensor_name,
+                            int weight_position,
+                            int dim,
+                            int rank,
+                            int world_size) {
+  for (const auto& [name, tensor] : state_dict) {
+    if (absl::EndsWith(name, tensor_name)) {
+      if (world_size <= 1) {
+        at::Tensor mutable_tensor = tensor;
+        correct_tensor_dtype(mutable_tensor, tensor_name);
+        at_weight_tensors_[weight_position] = mutable_tensor.to(device_);
+      } else {
+        at_weight_tensors_[weight_position] =
+            state_dict
+                .get_sharded_tensor(tensor_name,
+                                    /*dim=*/dim,
+                                    /*rank=*/rank,
+                                    /*world_size=*/world_size)
+                .to(device_);
+        correct_tensor_dtype(at_weight_tensors_[weight_position], tensor_name);
+      }
+    }
+  }
+}
+
+void BaseLoader::correct_tensor_dtype(torch::Tensor& tensor,
+                                      const std::string& tensorName) {
+  if (absl::EndsWith(tensorName, "deq_scale") &&
+      (torch_dtype_.compare("bfloat16") == 0)) {
+    return;
+  }
+
+  if (tensor.dtype() != torch::kInt8 && tensor.dtype() != torch::kInt32 &&
+      tensor.dtype() != torch::kInt64) {
+    torch::Dtype dtype = string2dtype(torch_dtype_);
+    tensor = tensor.to(dtype);
+  }
+}
+
+torch::Dtype BaseLoader::string2dtype(const std::string& dtype_str) {
+  if (dtype_str.compare("float16") == 0) {
+    return torch::kFloat16;
+  } else if (dtype_str.compare("bfloat16") == 0) {
+    return torch::kBFloat16;
+  } else if (dtype_str.compare("float32") == 0) {
+    return torch::kFloat32;
+  } else if (dtype_str.compare("float64") == 0) {
+    return torch::kFloat64;
+  } else if (dtype_str.compare("int8") == 0) {
+    return torch::kInt8;
+  } else if (dtype_str.compare("int16") == 0) {
+    return torch::kInt16;
+  } else if (dtype_str.compare("int32") == 0) {
+    return torch::kInt32;
+  } else if (dtype_str.compare("int64") == 0) {
+    return torch::kInt64;
+  } else if (dtype_str.compare("uint8") == 0) {
+    return torch::kUInt8;
+  } else if (dtype_str.compare("bool") == 0) {
+    return torch::kBool;
+  }
+
+  LOG(FATAL) << "Unsupported dtype string: " << dtype_str;
+}
+
+} // namespace layer
+} // namespace xllm
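A quick, standalone check of the constructor's data-parallel arithmetic above; the world size, DP size, and rank are invented for illustration:

// Standalone mirror of BaseLoader's rank bookkeeping (values invented).
#include <cassert>

int main() {
  const int world_size = 8, dp_size = 2, rank = 5;
  const int dp_local_tp_size = world_size / dp_size;     // 8 / 2 = 4
  const int dp_rank = rank / dp_local_tp_size;           // 5 / 4 = 1
  const int dp_local_tp_rank = rank % dp_local_tp_size;  // 5 % 4 = 1
  // Same invariant as the CHECK_EQ in the constructor: the world must
  // factor exactly into dp_size groups of dp_local_tp_size TP ranks.
  assert(world_size == dp_size * dp_local_tp_size);
  return 0;
}

So global rank 5 lands in the second data-parallel group (dp_rank = 1) as that group's second tensor-parallel rank.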
xllm/core/layers/npu/loader/base_loader.h (new file)

Lines changed: 102 additions & 0 deletions
@@ -0,0 +1,102 @@
+/* Copyright 2025 The xLLM Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    https://github.com/jd-opensource/xllm/blob/main/LICENSE
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#pragma once
+
+#include <absl/strings/match.h>
+#include <torch/torch.h>
+
+#include "framework/eplb/expert_buffer_manager.h"
+#include "framework/kv_cache/kv_cache.h"
+#include "framework/model/model_input_params.h"
+#include "framework/model_context.h"
+#include "framework/state_dict/state_dict.h"
+#include "xllm_kernels/pytorch/atb_torch/core/include/base_operation.h"
+#include "xllm_kernels/pytorch/atb_torch/core/include/graph_operation.h"
+
+namespace xllm {
+namespace layer {
+
+class BaseLoader {
+ public:
+  explicit BaseLoader(uint64_t weight_count, const ModelContext& context);
+  virtual ~BaseLoader() = default;
+
+  virtual void load_state_dict(const StateDict& state_dict) {};
+  virtual void verify_loaded_weights() const {};
+  virtual void verify_loaded_weights(const std::string& prefix) const {};
+  virtual void merge_loaded_weights() {};
+  virtual void resize_experts_weights(int num_of_device_experts) {};
+  torch::Dtype string2dtype(const std::string& dtype_str);
+
+  void correct_tensor_dtype(torch::Tensor& tensor,
+                            const std::string& tensorName);
+
+  std::vector<at::Tensor>& get_at_weight_tensors() {
+    return at_weight_tensors_;
+  }
+
+  std::unordered_map<std::string, std::vector<torch::Tensor>>&
+  get_experts_weight_tensors() {
+    return experts_weights_;
+  }
+
+  std::unique_ptr<ExpertBufferManager>& get_expert_shared_buffer() {
+    return shared_buffer_;
+  }
+
+  std::vector<int32_t>& get_device_expert_list() { return device_expert_list_; }
+
+  atb_torch::TorchTensorMap& get_weights_map() { return weights_map_; }
+
+ protected:
+  uint64_t weight_count_;
+  xllm::ParallelArgs parallel_args_;
+  std::string quantize_type_;
+  std::string torch_dtype_;
+  torch::ScalarType dtype_;
+  torch::TensorOptions options_;
+  std::vector<at::Tensor> at_weight_tensors_;
+  std::unique_ptr<ExpertBufferManager> shared_buffer_ = nullptr;
+  std::unordered_map<std::string, torch::Tensor> shared_experts_weights_;
+  std::unordered_map<std::string, std::vector<torch::Tensor>> experts_weights_;
+  std::vector<int32_t> device_expert_list_;
+  atb_torch::TorchTensorMap weights_map_;
+
+  at::Device device_;
+  int32_t dp_size_;
+  int32_t dp_local_tp_size_;
+  int32_t dp_rank_;
+  int32_t dp_local_tp_rank_;
+
+  void set_weight(const StateDict& state_dict,
+                  const std::string& tensor_name,
+                  int weight_position);
+
+  void set_weight(const StateDict& state_dict,
+                  const std::string& tensor_name,
+                  int weight_position,
+                  int dim);
+
+  void set_weight(const StateDict& state_dict,
+                  const std::string& tensor_name,
+                  int weight_position,
+                  int dim,
+                  int rank,
+                  int world_size);
+};
+
+} // namespace layer
+} // namespace xllm
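Concrete loaders subclass BaseLoader and override these hooks. A hypothetical sketch (the class name, tensor suffix, and weight slot are invented) showing how the protected set_weight helpers are intended to be used:

// Hypothetical loader for a single replicated weight (names invented).
class ExampleNormLoader : public BaseLoader {
 public:
  explicit ExampleNormLoader(uint64_t weight_count,
                             const ModelContext& context)
      : BaseLoader(weight_count, context) {}

  void load_state_dict(const StateDict& state_dict) override {
    // Suffix-match "weight" in the state dict and copy it, unsharded,
    // into slot 0 (the three-argument set_weight overload).
    set_weight(state_dict, "weight", /*weight_position=*/0);
  }

  void verify_loaded_weights(const std::string& prefix) const override {
    CHECK(at_weight_tensors_[0].defined())
        << "weight is not loaded for " << prefix;
  }
};

A tensor-parallel weight would use the four-argument overload instead, e.g. set_weight(state_dict, "weight", 0, /*dim=*/0), which shards dimension 0 across parallel_args_.world_size() ranks via StateDict::get_sharded_tensor.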
