refactor: separate the weight loading in the npu layer class. #489
base: main
Conversation
eef8bfc to aaa9676
    int weight_position) {
  for (const auto& [name, tensor] : state_dict) {
    if (absl::EndsWith(name, tensor_name)) {
      at::Tensor mutable_tensor = tensor;
change all at:: to torch::.
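For illustration, a minimal sketch of the rename being asked for; torch::Tensor is an alias of at::Tensor, so only the spelling changes:

    torch::Tensor mutable_tensor = tensor;  // was: at::Tensor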
void BaseLoader::set_weight(const StateDict& state_dict,
                            const std::string& tensor_name,
                            int weight_position,
clarify whether weight_position should be int32_t or int64_t instead of plain int.
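A sketch of the suggested signature, picking int32_t purely for illustration (the project may prefer int64_t); trailing parameters from the original overload are omitted here:

    void set_weight(const StateDict& state_dict,
                    const std::string& tensor_name,
                    int32_t weight_position);  // fixed-width instead of int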
torch::Dtype BaseLoader::string2dtype(const std::string& dtype_str) {
  if (dtype_str.compare("float16") == 0) {
use switch
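C++ cannot switch directly on a std::string, so one common way to realize this suggestion and drop the compare() chain is a static lookup table; the dtype set below is assumed rather than taken from the patch, and <unordered_map> would need to be included:

    torch::Dtype BaseLoader::string2dtype(const std::string& dtype_str) {
      // Sketch only: map dtype names to scalar types in one place.
      static const std::unordered_map<std::string, torch::Dtype> kDtypeMap = {
          {"float16", torch::kFloat16},
          {"bfloat16", torch::kBFloat16},
          {"float32", torch::kFloat32},
      };
      const auto it = kDtypeMap.find(dtype_str);
      CHECK(it != kDtypeMap.end()) << "unsupported dtype: " << dtype_str;
      return it->second;
    }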
  if (tensor.dtype() != torch::kInt8 && tensor.dtype() != torch::kInt32 &&
      tensor.dtype() != torch::kInt64) {
    torch::Dtype dtype = string2dtype(torch_dtype_);
replace torch::Dtype with torch::ScalarType.
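A one-line sketch of the rename; torch::Dtype is itself an alias of the scalar-type enum, so behavior does not change:

    torch::ScalarType dtype = string2dtype(torch_dtype_);  // was: torch::Dtype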
                              const ModelContext& context)
    : BaseLoader(weight_count, context) {
  auto options = context.get_tensor_options();
  dtype_ = c10::typeMetaToScalarType(options.dtype());
replace c10:: with torch::.
namespace layer {
class ColumParallelLinearLoader : public BaseLoader {
 public:
  explicit ColumParallelLinearLoader(uint64_t weight_count,
no need to mark the constructor explicit when it takes two parameters.
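A sketch of the suggested declaration; explicit mainly guards single-argument implicit conversions, so it adds little on a two-parameter constructor:

    ColumParallelLinearLoader(uint64_t weight_count, const ModelContext& context);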
void ColumParallelLinearLoader::verify_loaded_weights(
    const std::string& weight_str) const {
  CHECK(at_weight_tensors_[0].sizes() != std::vector<int64_t>({1}))
nit: CHECK_EQ
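As a sketch of the nit: the comparison-specific CHECK macros print both operands on failure, which makes a bad placeholder weight easier to diagnose. The concrete condition and message below are assumptions, not the patch's logic; for the != comparison in the original, CHECK_NE is the direct counterpart:

    // Sketch only: reject the 1-element placeholder tensor (assumed semantics).
    CHECK_NE(at_weight_tensors_[0].numel(), 1)
        << "weight " << weight_str << " is not loaded";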
#include <torch_npu/csrc/core/npu/NPUCachingAllocator.h>
#include <torch_npu/csrc/core/npu/NPUException.h>

#include <map>
try #include <unordered_map> instead of <map>.
Maybe we can move BaseLayer to the npu dir, or merge BaseLayer and NpuBaseLayer, since no other platform will use BaseLayer.
No description provided.