diff --git a/include/infinicore/nn/linear.hpp b/include/infinicore/nn/linear.hpp index e77a432c2..511b666dd 100644 --- a/include/infinicore/nn/linear.hpp +++ b/include/infinicore/nn/linear.hpp @@ -2,14 +2,17 @@ #include "../ops.hpp" #include "module.hpp" +#include "quantization.hpp" #include +#include namespace infinicore::nn { class BaseLinear : public Module { public: BaseLinear(size_t in_features, size_t out_features, bool bias = true, - const DataType &dtype = DataType::F32, const Device &device = Device()); + const DataType &dtype = DataType::F32, const Device &device = Device(), + const std::optional &quant_scheme = std::nullopt); // Forward pass: output = input @ weight.T + bias Tensor forward(Tensor &input) const; @@ -27,12 +30,17 @@ class BaseLinear : public Module { // Accessors for parameters Tensor weight() const { return weight_; } Tensor bias() const { return bias_; } + Tensor weight_scale() const { return weight_scale_; } + Tensor weight_zeros() const { return weight_zeros_; } protected: // Parameters INFINICORE_NN_PARAMETER(weight); INFINICORE_NN_PARAMETER(bias); + INFINICORE_NN_PARAMETER(weight_scale); + INFINICORE_NN_PARAMETER(weight_zeros); + protected: // Helper method for common forward computation Tensor compute_linear(Tensor &input) const; @@ -41,6 +49,7 @@ class BaseLinear : public Module { size_t out_features_; bool has_bias_; DataType dtype_; + const std::optional quant_scheme_; }; } // namespace infinicore::nn @@ -50,7 +59,8 @@ namespace infinicore::nn { class Linear : public BaseLinear { public: Linear(size_t in_features, size_t out_features, bool bias = true, - const DataType &dtype = DataType::F32, const Device &device = Device()); + const DataType &dtype = DataType::F32, const Device &device = Device(), + const std::optional &quant_scheme = std::nullopt); // Forward pass: output = input @ weight.T + bias Tensor forward(Tensor &input) const; @@ -63,7 +73,8 @@ class ColumnParallelLinear : public BaseLinear { public: ColumnParallelLinear(size_t in_features, size_t out_features, bool bias = true, const DataType &dtype = DataType::F32, const Device &device = Device(), - Size tp_rank = 0, Size tp_size = 1); + Size tp_rank = 0, Size tp_size = 1, + const std::optional &quant_scheme = std::nullopt); // Forward pass: output = input @ weight.T + bias Tensor forward(Tensor &input) const; @@ -80,7 +91,8 @@ class RowParallelLinear : public BaseLinear { public: RowParallelLinear(size_t in_features, size_t out_features, bool bias = true, const DataType &dtype = DataType::F32, const Device &device = Device(), - Size tp_rank = 0, Size tp_size = 1, infinicclComm_t communicator = nullptr); + Size tp_rank = 0, Size tp_size = 1, infinicclComm_t communicator = nullptr, + const std::optional &quant_scheme = std::nullopt); // Forward pass: output = input @ weight.T + bias Tensor forward(Tensor &input) const; diff --git a/include/infinicore/nn/quantization.hpp b/include/infinicore/nn/quantization.hpp new file mode 100644 index 000000000..e1c51f629 --- /dev/null +++ b/include/infinicore/nn/quantization.hpp @@ -0,0 +1,11 @@ +// quant.hpp +#pragma once + +namespace infinicore::nn { + +enum class QuantScheme { + NONE, + COMPRESSED_TENSOR_W8A8I8, +}; + +} // namespace infinicore::nn diff --git a/include/infinicore/ops/linear_w8a8i8.hpp b/include/infinicore/ops/linear_w8a8i8.hpp new file mode 100644 index 000000000..b29b7ba11 --- /dev/null +++ b/include/infinicore/ops/linear_w8a8i8.hpp @@ -0,0 +1,12 @@ +#pragma once + +#include "common/op.hpp" +#include + +namespace infinicore::op { + 
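Reviewer sketch (not part of the patch): how a caller opts into the new quantized path on the nn linear modules above. The `<QuantScheme>` template argument on the optional constructor parameter is inferred from `quantization.hpp`; the layer sizes, include root, and variable names here are illustrative only.

```cpp
#include "infinicore/nn/linear.hpp" // include root assumed to be include/

void build_layers(const infinicore::Device &device) {
    using namespace infinicore;

    // Existing behaviour is unchanged: quant_scheme defaults to std::nullopt.
    nn::Linear fp_proj(4096, 4096, /*bias=*/false, DataType::F32, device);

    // New behaviour: the weight is allocated as I8 with a per-output-channel
    // F32 weight_scale, and forward() dispatches to op::linear_w8a8i8.
    nn::Linear q_proj(4096, 4096, /*bias=*/false, DataType::F32, device,
                      nn::QuantScheme::COMPRESSED_TENSOR_W8A8I8);
}
```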
+Tensor linear_w8a8i8(Tensor input, Tensor weight_packed, Tensor weight_scale, std::optional<Tensor> bias);
+
+void linear_w8a8i8_(Tensor out, Tensor input, Tensor weight_packed, Tensor weight_scale, std::optional<Tensor> bias);
+
+} // namespace infinicore::op
diff --git a/include/infinicore/ops/per_channel_quant_i8.hpp b/include/infinicore/ops/per_channel_quant_i8.hpp
new file mode 100644
index 000000000..f5055e376
--- /dev/null
+++ b/include/infinicore/ops/per_channel_quant_i8.hpp
@@ -0,0 +1,15 @@
+#pragma once
+#include "../device.hpp"
+#include "common/op.hpp"
+#include
+
+namespace infinicore::op {
+class PerChannelQuantI8 {
+public:
+    using schema = void (*)(Tensor, Tensor, Tensor);
+    static void execute(Tensor x, Tensor x_packed, Tensor x_scale);
+    static common::OpDispatcher &dispatcher();
+};
+
+void per_channel_quant_i8_(Tensor x, Tensor x_packed, Tensor x_scale);
+} // namespace infinicore::op
diff --git a/include/infinicore/ops/scaled_mm_i8.hpp b/include/infinicore/ops/scaled_mm_i8.hpp
new file mode 100644
index 000000000..747f84653
--- /dev/null
+++ b/include/infinicore/ops/scaled_mm_i8.hpp
@@ -0,0 +1,16 @@
+#pragma once
+
+#include "../device.hpp"
+#include "common/op.hpp"
+#include <optional>
+
+namespace infinicore::op {
+class ScaledMMI8 {
+public:
+    using schema = void (*)(Tensor, Tensor, Tensor, Tensor, Tensor, std::optional<Tensor>);
+    static void execute(Tensor c, Tensor a_p, Tensor a_s, Tensor b_p, Tensor b_s, std::optional<Tensor> bias);
+    static common::OpDispatcher &dispatcher();
+};
+
+void scaled_mm_i8_(Tensor c, Tensor a_p, Tensor a_s, Tensor b_p, Tensor b_s, std::optional<Tensor> bias);
+} // namespace infinicore::op
diff --git a/include/infiniop.h b/include/infiniop.h
index c0a09fcb4..61eb23cf9 100644
--- a/include/infiniop.h
+++ b/include/infiniop.h
@@ -11,6 +11,7 @@
 #include "infiniop/ops/dequantize_awq.h"
 #include "infiniop/ops/gelu.h"
 #include "infiniop/ops/gemm.h"
+#include "infiniop/ops/int8_gemm.h"
 #include "infiniop/ops/layer_norm.h"
 #include "infiniop/ops/logsoftmax.h"
 #include "infiniop/ops/lp_norm.h"
@@ -19,6 +20,7 @@
 #include "infiniop/ops/paged_attention.h"
 #include "infiniop/ops/paged_attention_prefill.h"
 #include "infiniop/ops/paged_caching.h"
+#include "infiniop/ops/quant/per_channel_quant_int8.h"
 #include "infiniop/ops/random_sample.h"
 #include "infiniop/ops/rearrange.h"
 #include "infiniop/ops/relu.h"
diff --git a/include/infiniop/ops/quant/per_channel_quant_int8.h b/include/infiniop/ops/quant/per_channel_quant_int8.h
new file mode 100644
index 000000000..ce21f4556
--- /dev/null
+++ b/include/infiniop/ops/quant/per_channel_quant_int8.h
@@ -0,0 +1,28 @@
+#ifndef __INFINIOP_PER_CHANNEL_QUANT_INT8_API_H__
+#define __INFINIOP_PER_CHANNEL_QUANT_INT8_API_H__
+
+#include "../../operator_descriptor.h"
+
+typedef InfiniopDescriptor *infiniopPerChannelQuantI8Descriptor_t;
+
+__C __export infiniStatus_t infiniopCreatePerChannelQuantI8Descriptor(infiniopHandle_t handle,
+                                                                      infiniopPerChannelQuantI8Descriptor_t *desc_ptr,
+                                                                      infiniopTensorDescriptor_t x_packed_desc,
+                                                                      infiniopTensorDescriptor_t x_scale_desc,
+                                                                      infiniopTensorDescriptor_t x_zero_desc,
+                                                                      infiniopTensorDescriptor_t x_desc);
+
+__C __export infiniStatus_t infiniopGetPerChannelQuantI8WorkspaceSize(infiniopPerChannelQuantI8Descriptor_t desc, size_t *size);
+
+__C __export infiniStatus_t infiniopPerChannelQuantI8(infiniopPerChannelQuantI8Descriptor_t desc,
+                                                      void *workspace,
+                                                      size_t workspace_size,
+                                                      void *x_packed,
+                                                      void *x_scale,
+                                                      void *x_zero,
+                                                      const void *x,
+                                                      void *stream);
+
+__C __export infiniStatus_t
infiniopDestroyPerChannelQuantI8Descriptor(infiniopPerChannelQuantI8Descriptor_t desc); + +#endif diff --git a/python/infinicore/nn/functional/__init__.py b/python/infinicore/nn/functional/__init__.py index 255079790..53a0c9d9f 100644 --- a/python/infinicore/nn/functional/__init__.py +++ b/python/infinicore/nn/functional/__init__.py @@ -1,6 +1,7 @@ from .causal_softmax import causal_softmax from .embedding import embedding from .linear import linear +from .linear_w8a8i8 import linear_w8a8i8 from .random_sample import random_sample from .rms_norm import rms_norm from .rope import RopeAlgo, rope @@ -17,4 +18,5 @@ "embedding", "rope", "RopeAlgo", + "linear_w8a8i8", ] diff --git a/python/infinicore/nn/functional/linear_w8a8i8.py b/python/infinicore/nn/functional/linear_w8a8i8.py new file mode 100644 index 000000000..33cb59b0e --- /dev/null +++ b/python/infinicore/nn/functional/linear_w8a8i8.py @@ -0,0 +1,31 @@ +from infinicore.lib import _infinicore +from infinicore.tensor import Tensor + + +def linear_w8a8i8( + input: Tensor, + weight_packed: Tensor, + weight_scale: Tensor, + bias=None, + out=None, +) -> Tensor: + r"""Linear layer with weight quantized to int8 and input quantized to int8 with per-tensor scale.""" + + if out is None: + return Tensor( + _infinicore.linear_w8a8i8( + input._underlying, + weight_packed._underlying, + weight_scale._underlying, + None if bias is None else bias._underlying, + ) + ) + + _infinicore.linear_w8a8i8_( + out._underlying, + input._underlying, + weight_packed._underlying, + weight_scale._underlying, + None if bias is None else bias._underlying, + ) + return out diff --git a/python/infinicore/nn/modules/linear_w8a8.py b/python/infinicore/nn/modules/linear_w8a8.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/infinicore/nn/linear.cc b/src/infinicore/nn/linear.cc index bb4fc29b1..024d18e92 100644 --- a/src/infinicore/nn/linear.cc +++ b/src/infinicore/nn/linear.cc @@ -2,36 +2,56 @@ #include "../utils.hpp" #include "infinicore/ops.hpp" #include "infinicore/ops/linear.hpp" +#include "infinicore/ops/linear_w8a8i8.hpp" #include #include +#include + namespace infinicore::nn { BaseLinear::BaseLinear(size_t in_features, size_t out_features, bool bias, - const DataType &dtype, const Device &device) + const DataType &dtype, const Device &device, + const std::optional &quant_scheme) : in_features_(in_features), out_features_(out_features), has_bias_(bias), - dtype_(dtype) { + dtype_(dtype), + quant_scheme_(quant_scheme) { device_ = device; } Tensor BaseLinear::compute_linear(Tensor &input) const { - // Ensure input is contiguous before creating views (required for matmul) - // This prevents hanging when input tensor has non-contiguous memory layout - Tensor input_contiguous = input->is_contiguous() ? input : input->contiguous(); + switch (this->quant_scheme_.value_or(QuantScheme::NONE)) { + case infinicore::nn::QuantScheme::COMPRESSED_TENSOR_W8A8I8: { + Tensor input_contiguous = input->is_contiguous() ? input : input->contiguous(); - // Use ops::linear_ directly to match Python backend's exact code path - // This ensures identical computation and numerical results - // Parameter inherits from Tensor, so we cast to Tensor explicitly - Tensor weight_tensor = static_cast(weight_); - std::optional bias_opt = has_bias_ ? std::make_optional(static_cast(bias_)) : std::nullopt; + Tensor weight_packed_tensor = static_cast(weight_); + Tensor weight_scale_tensor = static_cast(weight_scale_); + // weight_packed should be transposed and non-contiguous. 
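A small sketch of what the comment above means for the packed weight's layout; it is illustrative only and uses the Tensor view API (`shape()`, `permute()`) that `linear_w8a8i8_` uses further down in this patch. The Tensor header include is omitted because its path is not shown in this diff.

```cpp
void describe_weight_layout(infinicore::Tensor weight_packed) {
    // Stored row-major as [out_features, in_features], strides {in_features, 1}.
    auto shape = weight_packed->shape();
    // scaled_mm_i8_ receives it as the B operand via permute({1, 0}):
    // shape {in_features, out_features}, strides {1, in_features}. Same data,
    // now a transposed, non-contiguous view, so no contiguous() copy is made.
    auto b_view = weight_packed->permute({1, 0});
    (void)shape;
    (void)b_view;
}
```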
+ std::optional bias_opt = has_bias_ ? std::make_optional(static_cast(bias_)) : std::nullopt; - auto output = infinicore::op::linear(input_contiguous->contiguous(), weight_tensor->contiguous(), bias_opt); - return output; -} + auto output = infinicore::op::linear_w8a8i8(input_contiguous->contiguous(), weight_packed_tensor, weight_scale_tensor, bias_opt); + return output; + } + default: { + // Ensure input is contiguous before creating views (required for matmul) + // This prevents hanging when input tensor has non-contiguous memory layout + Tensor input_contiguous = input->is_contiguous() ? input : input->contiguous(); + + // Use ops::linear_ directly to match Python backend's exact code path + // This ensures identical computation and numerical results + // Parameter inherits from Tensor, so we cast to Tensor explicitly + Tensor weight_tensor = static_cast(weight_); + std::optional bias_opt = has_bias_ ? std::make_optional(static_cast(bias_)) : std::nullopt; + + auto output = infinicore::op::linear(input_contiguous->contiguous(), weight_tensor->contiguous(), bias_opt); + return output; + } + } +} // namespace infinicore::nn Tensor BaseLinear::forward(Tensor &input) const { return compute_linear(input); @@ -51,23 +71,40 @@ Tensor BaseLinear::forward(Tensor &input, Tensor &residual) const { namespace infinicore::nn { Linear::Linear(size_t in_features, size_t out_features, bool bias, - const DataType &dtype, const Device &device) - : BaseLinear(in_features, out_features, bias, dtype, device_) { + const DataType &dtype, const Device &device, + const std::optional &quant_scheme) + : BaseLinear(in_features, out_features, bias, dtype, device_, quant_scheme) { device_ = device; - // Initialize parameters using macro - INFINICORE_NN_PARAMETER_INIT(weight, ({out_features, in_features}, dtype_, device)); - - // Register bias parameter if requested - if (bias) { - INFINICORE_NN_PARAMETER_INIT(bias, ({out_features}, dtype_, device)); - } else { - bias_ = Parameter(); // Default constructed empty parameter + switch (this->quant_scheme_.value_or(QuantScheme::NONE)) { + case infinicore::nn::QuantScheme::COMPRESSED_TENSOR_W8A8I8: { + INFINICORE_NN_PARAMETER_INIT(weight, ({out_features, in_features}, infinicore::DataType::I8, device)); + INFINICORE_NN_PARAMETER_INIT(weight_scale, ({out_features, 1}, infinicore::DataType::F32, device)); + + if (bias) { + INFINICORE_NN_PARAMETER_INIT(bias, ({out_features}, dtype_, device)); + } else { + bias_ = Parameter(); + } + break; + } + default: { + // Initialize parameters using macro + INFINICORE_NN_PARAMETER_INIT(weight, ({out_features, in_features}, dtype_, device)); + + // Register bias parameter if requested + if (bias) { + INFINICORE_NN_PARAMETER_INIT(bias, ({out_features}, dtype_, device)); + } else { + bias_ = Parameter(); // Default constructed empty parameter + } + + // SPDLOG_DEBUG("Created Linear module: in_features={}, out_features={}, bias={}, dtype={}", + // in_features, out_features, bias, static_cast(dtype_)); + break; + } } - - // SPDLOG_DEBUG("Created Linear module: in_features={}, out_features={}, bias={}, dtype={}", - // in_features, out_features, bias, static_cast(dtype_)); } Tensor Linear::forward(Tensor &input) const { @@ -84,23 +121,41 @@ namespace infinicore::nn { ColumnParallelLinear::ColumnParallelLinear(size_t in_features, size_t out_features, bool bias, const DataType &dtype, const Device &device, - Size tp_rank, Size tp_size) - : BaseLinear(in_features, out_features, bias, dtype, device_), + Size tp_rank, Size tp_size, + const 
std::optional &quant_scheme) + : BaseLinear(in_features, out_features, bias, dtype, device_, quant_scheme), tp_rank_(tp_rank), tp_size_(tp_size) { device_ = device; - // Initialize parameters using macro - INFINICORE_NN_PARAMETER_INIT(weight, ({out_features, in_features}, dtype_, device, - 0, tp_rank_, tp_size_)); + switch (this->quant_scheme_.value_or(QuantScheme::NONE)) { + case infinicore::nn::QuantScheme::COMPRESSED_TENSOR_W8A8I8: { + + INFINICORE_NN_PARAMETER_INIT(weight, ({out_features, in_features}, infinicore::DataType::I8, device, 0, tp_rank_, tp_size_)); + INFINICORE_NN_PARAMETER_INIT(weight_scale, ({out_features, 1}, infinicore::DataType::F32, device, 0, tp_rank_, tp_size_)); - // Register bias parameter if requested - if (bias) { - INFINICORE_NN_PARAMETER_INIT(bias, ({out_features}, dtype_, device, - 0, tp_rank_, tp_size_)); - } else { - bias_ = Parameter(); // Default constructed empty parameter + if (bias) { + INFINICORE_NN_PARAMETER_INIT(bias, ({out_features}, dtype_, device, 0, 0, 1)); + } else { + bias_ = Parameter(); + } + break; + } + default: { + // Initialize parameters using macro + INFINICORE_NN_PARAMETER_INIT(weight, ({out_features, in_features}, dtype_, device, + 0, tp_rank_, tp_size_)); + + // Register bias parameter if requested + if (bias) { + INFINICORE_NN_PARAMETER_INIT(bias, ({out_features}, dtype_, device, + 0, tp_rank_, tp_size_)); + } else { + bias_ = Parameter(); // Default constructed empty parameter + } + break; + } } // SPDLOG_DEBUG("Created ColumnParallelLinear module: in_features={}, out_features={}, bias={}, dtype={}", @@ -121,26 +176,43 @@ namespace infinicore::nn { RowParallelLinear::RowParallelLinear(size_t in_features, size_t out_features, bool bias, const DataType &dtype, const Device &device, - Size tp_rank, Size tp_size, infinicclComm_t communicator) - : BaseLinear(in_features, out_features, bias, dtype, device_), + Size tp_rank, Size tp_size, infinicclComm_t communicator, + const std::optional &quant_scheme) + : BaseLinear(in_features, out_features, bias, dtype, device_, quant_scheme), tp_rank_(tp_rank), tp_size_(tp_size), communicator_(communicator) { device_ = device; - // Initialize parameters using macro - INFINICORE_NN_PARAMETER_INIT(weight, ({out_features, in_features}, dtype_, device, - 1, tp_rank_, tp_size_)); - - // Register bias parameter if requested - if (bias && (0 == tp_rank_)) { - INFINICORE_NN_PARAMETER_INIT(bias, ({out_features}, dtype_, device, 0, 0, 1)); - } else { - bias_ = Parameter(); // Default constructed empty parameter + switch (this->quant_scheme_.value_or(QuantScheme::NONE)) { + case infinicore::nn::QuantScheme::COMPRESSED_TENSOR_W8A8I8: { + INFINICORE_NN_PARAMETER_INIT(weight, ({out_features, in_features}, infinicore::DataType::I8, device, 1, tp_rank_, tp_size_)); + INFINICORE_NN_PARAMETER_INIT(weight_scale, ({out_features, 1}, infinicore::DataType::F32, device, 0, 0, 1)); + + if (bias) { + INFINICORE_NN_PARAMETER_INIT(bias, ({out_features}, dtype_, device, 0, tp_rank_, tp_size_)); + } else { + bias_ = Parameter(); + } + break; + } + default: { + // Initialize parameters using macro + INFINICORE_NN_PARAMETER_INIT(weight, ({out_features, in_features}, dtype_, device, + 1, tp_rank_, tp_size_)); + + // Register bias parameter if requested + if (bias && (0 == tp_rank_)) { + INFINICORE_NN_PARAMETER_INIT(bias, ({out_features}, dtype_, device, 0, 0, 1)); + } else { + bias_ = Parameter(); // Default constructed empty parameter + } + + // SPDLOG_DEBUG("Created RowParallelLinear module: in_features={}, out_features={}, 
bias={}, dtype={}", + // in_features, out_features, bias, static_cast(dtype_)); + break; + } } - - // SPDLOG_DEBUG("Created RowParallelLinear module: in_features={}, out_features={}, bias={}, dtype={}", - // in_features, out_features, bias, static_cast(dtype_)); } Tensor RowParallelLinear::forward(Tensor &input) const { diff --git a/src/infinicore/nn/module.cc b/src/infinicore/nn/module.cc index 89207e1e9..b4be82058 100644 --- a/src/infinicore/nn/module.cc +++ b/src/infinicore/nn/module.cc @@ -1,4 +1,5 @@ #include "infinicore/nn/module.hpp" +#include #include #include @@ -72,6 +73,7 @@ Tensor Module::register_buffer(const std::string &name, Parameter buffer) { void Module::load_state_dict_recursively(const std::unordered_map &_state_dict, const std::string &prefix) { // Load direct parameters with the given prefix + for (const auto &[param_name, param] : parameters_) { std::string full_name = prefix.empty() ? param_name : prefix + "." + param_name; auto it = _state_dict.find(full_name); diff --git a/src/infinicore/ops/linear_w8a8i8/linear_w8a8i8.cc b/src/infinicore/ops/linear_w8a8i8/linear_w8a8i8.cc new file mode 100644 index 000000000..c0c7ce28f --- /dev/null +++ b/src/infinicore/ops/linear_w8a8i8/linear_w8a8i8.cc @@ -0,0 +1,66 @@ +#include "infinicore/ops/linear_w8a8i8.hpp" +#include "infinicore/ops/per_channel_quant_i8.hpp" +#include "infinicore/ops/scaled_mm_i8.hpp" +#include +namespace infinicore::op { + +Tensor linear_w8a8i8(Tensor input, + Tensor weight_packed, + Tensor weight_scale, + std::optional bias) { + + // Input is of shape [M, K], Weight_packed is of shape [N, K],stirdes is [N, 1] + Size ndim = input->ndim(); + Size out_features = weight_packed->shape()[0]; + + // Assign memory to out variables + auto output_shape = input->shape(); + output_shape[ndim - 1] = out_features; + auto out = Tensor::empty(output_shape, input->dtype(), input->device()); + + // Inplace Calculate + linear_w8a8i8_(out, input, weight_packed, weight_scale, bias); + return out; +} + +void linear_w8a8i8_(Tensor out, + Tensor input, + Tensor weight_packed, + Tensor weight_scale, + std::optional bias) { + + auto weight_packed_shape = weight_packed->shape(); + Size out_features = weight_packed_shape[0]; + Size in_features = weight_packed_shape[1]; + + Size ndim = input->ndim(); + assert(out->ndim() == ndim); + + Size N = 1; + auto input_shape = input->shape(); + for (size_t i = 0; i < ndim - 1; ++i) { + N *= input_shape[i]; + } + + auto input_packed = Tensor::empty( + {N, input_shape[ndim - 1]}, + DataType::I8, + input->device()); + auto input_scale = Tensor::empty( + {N, 1}, + DataType::F32, + input->device()); + op::per_channel_quant_i8_(input->view({N, in_features}), input_packed, input_scale); + if (bias.has_value()) { + bias = std::make_optional(bias.value()->as_strided({N, out_features}, {0, 1})); + } + op::scaled_mm_i8_( + out->view({N, out_features}), + input_packed, + input_scale, + weight_packed->permute({1, 0}), + weight_scale, + bias); +} + +} // namespace infinicore::op diff --git a/src/infinicore/ops/per_channel_quant_i8/per_channel_quant_i8.cc b/src/infinicore/ops/per_channel_quant_i8/per_channel_quant_i8.cc new file mode 100644 index 000000000..4a0862eaf --- /dev/null +++ b/src/infinicore/ops/per_channel_quant_i8/per_channel_quant_i8.cc @@ -0,0 +1,49 @@ +#include "infinicore/ops/per_channel_quant_i8.hpp" +#include "../../utils.hpp" +#include "infinicore/common/hash.hpp" +#include "infinicore/ops/common/cache.hpp" +#include +#include + +namespace 
infinicore::op::per_channel_quant_i8_impl::infiniop { + +thread_local common::OpCache caches( + 100, // capacity + [](infiniopPerChannelQuantI8Descriptor_t &desc) { + if (desc != nullptr) { + INFINICORE_CHECK_ERROR(infiniopDestroyPerChannelQuantI8Descriptor(desc)); + desc = nullptr; + } + }); + +void calculate(Tensor x, Tensor x_packed, Tensor x_scale) { + size_t seed = hash_combine(x, x_packed, x_scale); + + auto device = context::getDevice(); + auto &cache = caches.getCache(device); + + auto desc_opt = cache.get(seed); + infiniopGemmDescriptor_t desc = nullptr; + if (!desc_opt) { + INFINICORE_CHECK_ERROR(infiniopCreatePerChannelQuantI8Descriptor( + context::getInfiniopHandle(device), &desc, + x_packed->desc(), x_scale->desc(), nullptr, x->desc())); + cache.put(seed, desc); + } else { + desc = *desc_opt; + } + + size_t workspace_size = 0; + INFINICORE_CHECK_ERROR(infiniopGetPerChannelQuantI8WorkspaceSize(desc, &workspace_size)); + std::shared_ptr workspace = context::allocateMemory(workspace_size); + INFINICORE_CHECK_ERROR(infiniopPerChannelQuantI8( + desc, workspace->data(), workspace_size, + x_packed->data(), x_scale->data(), nullptr, x->data(), context::getStream())); +} + +static bool registered = []() { + PerChannelQuantI8::dispatcher().registerAll(&calculate, false); + return true; +}(); + +} // namespace infinicore::op::per_channel_quant_i8_impl::infiniop diff --git a/src/infinicore/ops/per_channel_quant_i8/per_channel_quant_i8_infiniop.cc b/src/infinicore/ops/per_channel_quant_i8/per_channel_quant_i8_infiniop.cc new file mode 100644 index 000000000..fd282724f --- /dev/null +++ b/src/infinicore/ops/per_channel_quant_i8/per_channel_quant_i8_infiniop.cc @@ -0,0 +1,22 @@ +#include "infinicore/ops/per_channel_quant_i8.hpp" + +#include "../../utils.hpp" +#include + +namespace infinicore::op { + +common::OpDispatcher &PerChannelQuantI8::dispatcher() { + static common::OpDispatcher dispatcher_; + return dispatcher_; +}; + +void PerChannelQuantI8::execute(Tensor x, Tensor x_packed, Tensor x_scale) { + INFINICORE_ASSERT_TENSORS_SAME_DEVICE(x, x_packed, x_scale); + infinicore::context::setDevice(x->device()); + dispatcher().lookup(x->device().getType())(x, x_packed, x_scale); +} + +void per_channel_quant_i8_(Tensor x, Tensor x_packed, Tensor x_scale) { + PerChannelQuantI8::execute(x, x_packed, x_scale); +} +} // namespace infinicore::op diff --git a/src/infinicore/ops/scaled_mm_i8/scaled_mm_i8.cc b/src/infinicore/ops/scaled_mm_i8/scaled_mm_i8.cc new file mode 100644 index 000000000..5b9485cb8 --- /dev/null +++ b/src/infinicore/ops/scaled_mm_i8/scaled_mm_i8.cc @@ -0,0 +1,50 @@ +#include "infinicore/ops/scaled_mm_i8.hpp" +#include "../../utils.hpp" +#include "infinicore/common/hash.hpp" +#include "infinicore/ops/common/cache.hpp" +#include + +namespace infinicore::op::scaled_mm_i8_impl::infiniop { + +thread_local common::OpCache caches( + 100, // capacity + [](infiniopI8GemmDescriptor_t &desc) { + if (desc != nullptr) { + INFINICORE_CHECK_ERROR(infiniopDestroyI8GemmDescriptor(desc)); + desc = nullptr; + } + }); + +void calculate(Tensor c, Tensor a_p, Tensor a_s, Tensor b_p, Tensor b_s, std::optional bias) { + size_t seed = hash_combine(c, a_p, a_s, b_p, b_s); + + auto device = context::getDevice(); + auto &cache = caches.getCache(device); + + auto desc_opt = cache.get(seed); + infiniopGemmDescriptor_t desc = nullptr; + + if (!desc_opt) { + INFINICORE_CHECK_ERROR(infiniopCreateI8GemmDescriptor( + context::getInfiniopHandle(device), &desc, + c->desc(), bias.has_value() ? 
bias.value()->desc() : nullptr, a_p->desc(), a_s->desc(), b_p->desc(), b_s->desc())); + cache.put(seed, desc); + } else { + desc = *desc_opt; + } + + size_t workspace_size = 0; + INFINICORE_CHECK_ERROR(infiniopGetI8GemmWorkspaceSize(desc, &workspace_size)); + std::shared_ptr workspace = context::allocateMemory(workspace_size); + + INFINICORE_CHECK_ERROR(infiniopI8Gemm( + desc, workspace->data(), workspace_size, + c->data(), bias.has_value() ? bias.value()->data() : nullptr, a_p->data(), a_s->data(), b_p->data(), b_s->data(), context::getStream())); +} + +static bool registered = []() { + ScaledMMI8::dispatcher().registerAll(&calculate, false); + return true; +}(); + +} // namespace infinicore::op::scaled_mm_i8_impl::infiniop diff --git a/src/infinicore/ops/scaled_mm_i8/scaled_mm_i8_infiniop.cc b/src/infinicore/ops/scaled_mm_i8/scaled_mm_i8_infiniop.cc new file mode 100644 index 000000000..47827213f --- /dev/null +++ b/src/infinicore/ops/scaled_mm_i8/scaled_mm_i8_infiniop.cc @@ -0,0 +1,22 @@ +#include "infinicore/ops/scaled_mm_i8.hpp" + +#include "../../utils.hpp" + +namespace infinicore::op { + +common::OpDispatcher &ScaledMMI8::dispatcher() { + static common::OpDispatcher dispatcher_; + return dispatcher_; +}; + +void ScaledMMI8::execute(Tensor c, Tensor a_p, Tensor a_s, Tensor b_p, Tensor b_s, std::optional bias) { + INFINICORE_ASSERT_TENSORS_SAME_DEVICE(c, a_p, a_s, b_p, b_s); + infinicore::context::setDevice(c->device()); + dispatcher().lookup(c->device().getType())(c, a_p, a_s, b_p, b_s, bias); +} + +void scaled_mm_i8_(Tensor c, Tensor a_p, Tensor a_s, Tensor b_p, Tensor b_s, std::optional bias) { + ScaledMMI8::execute(c, a_p, a_s, b_p, b_s, bias); +} + +} // namespace infinicore::op diff --git a/src/infinicore/pybind11/ops.hpp b/src/infinicore/pybind11/ops.hpp index 3d6ebe79a..8b6e100cd 100644 --- a/src/infinicore/pybind11/ops.hpp +++ b/src/infinicore/pybind11/ops.hpp @@ -8,6 +8,7 @@ #include "ops/causal_softmax.hpp" #include "ops/embedding.hpp" #include "ops/linear.hpp" +#include "ops/linear_w8a8i8.hpp" #include "ops/matmul.hpp" #include "ops/mul.hpp" #include "ops/paged_attention.hpp" @@ -42,6 +43,7 @@ inline void bind(py::module &m) { bind_swiglu(m); bind_rope(m); bind_embedding(m); + bind_linear_w8a8i8(m); } } // namespace infinicore::ops diff --git a/src/infinicore/pybind11/ops/linear_w8a8i8.hpp b/src/infinicore/pybind11/ops/linear_w8a8i8.hpp new file mode 100644 index 000000000..926d554b1 --- /dev/null +++ b/src/infinicore/pybind11/ops/linear_w8a8i8.hpp @@ -0,0 +1,54 @@ +#pragma once + +#include + +#include "infinicore/ops/linear_w8a8i8.hpp" + +namespace py = pybind11; + +namespace infinicore::ops { + +Tensor py_linear_w8a8i8(Tensor input, + Tensor weight_packed, + Tensor weight_scale, + pybind11::object bias) { + std::optional bias_tensor = std::nullopt; + if (!bias.is_none()) { + bias_tensor = bias.cast(); + } + return op::linear_w8a8i8(input, weight_packed, weight_scale, bias_tensor); +} + +void py_linear_w8a8i8_(Tensor out, + Tensor input, + Tensor weight_packed, + Tensor weight_scale, + pybind11::object bias) { + + std::optional bias_tensor = std::nullopt; + if (!bias.is_none()) { + bias_tensor = bias.cast(); + } + + op::linear_w8a8i8_(out, input, weight_packed, weight_scale, bias_tensor); +} + +inline void bind_linear_w8a8i8(py::module &m) { + m.def("linear_w8a8i8", + &ops::py_linear_w8a8i8, + py::arg("input"), + py::arg("weight_packed"), + py::arg("weight_scale"), + py::arg("bias") = py::none(), + R"doc(linear_w8a8i8.)doc"); + m.def("linear_w8a8i8_", + 
&ops::py_linear_w8a8i8_, + py::arg("out"), + py::arg("input"), + py::arg("weight_packed"), + py::arg("weight_scale"), + py::arg("bias") = py::none(), + R"doc(linear_w8a8i8_.)doc"); +} + +} // namespace infinicore::ops diff --git a/src/infinicore/pybind11/ops/per_channel_quant_i8.hpp b/src/infinicore/pybind11/ops/per_channel_quant_i8.hpp new file mode 100644 index 000000000..da6f9f592 --- /dev/null +++ b/src/infinicore/pybind11/ops/per_channel_quant_i8.hpp @@ -0,0 +1,20 @@ +#pragma once + +#include + +#include "infinicore/ops/per_channel_quant_i8.hpp" + +namespace py = pybind11; + +namespace infinicore::ops { + +inline void bind_per_channel_quant_i8(py::module &m) { + m.def("per_channel_quant_i8_", + &op::per_channel_quant_i8_, + py::arg("x"), + py::arg("x_packed"), + py::arg("x_scale"), + R"doc(Per-channel quantization of a tensor.)doc"); +} + +} // namespace infinicore::ops diff --git a/src/infinicore/pybind11/ops/scaled_mm_i8.hpp b/src/infinicore/pybind11/ops/scaled_mm_i8.hpp new file mode 100644 index 000000000..c3d46d9df --- /dev/null +++ b/src/infinicore/pybind11/ops/scaled_mm_i8.hpp @@ -0,0 +1,30 @@ +#pragma once + +#include + +#include "infinicore/ops/scaled_mm_i8.hpp" + +namespace py = pybind11; + +namespace infinicore::ops { + +inline void bind_scaled_mm_i8(py::module &m) { + m.def("scaled_mm_i8", + &op::scaled_mm_i8, + py::arg("a_p"), + py::arg("a_s"), + py::arg("b_p"), + py::arg("b_s"), + py::arg("bias"), + R"doc(Scaled matrix multiplication of two tensors.)doc"); + + m.def("scaled_mm_i8_", + &op::scaled_mm_i8_, + py::arg("a"), + py::arg("b"), + py::arg("a_scale"), + py::arg("b_scale"), + R"doc(In-place Scaled matrix multiplication of two tensors.)doc"); +} + +} // namespace infinicore::ops diff --git a/src/infinicore/tensor/debug.cc b/src/infinicore/tensor/debug.cc index 0ae1946e3..b57b00a52 100644 --- a/src/infinicore/tensor/debug.cc +++ b/src/infinicore/tensor/debug.cc @@ -95,6 +95,20 @@ void print_data_bf16(const uint16_t *data, const Shape &shape, const Strides &st } } +// Function for printing I8 data +void print_data_i8(const int8_t *data, const Shape &shape, const Strides &strides, size_t dim) { + if (dim == shape.size() - 1) { + for (size_t i = 0; i < shape[dim]; i++) { + std::cout << static_cast(data[i * strides[dim]]) << " "; + } + std::cout << std::endl; + } else if (dim < shape.size() - 1) { + for (size_t i = 0; i < shape[dim]; i++) { + print_data_i8(data + i * strides[dim], shape, strides, dim + 1); + } + } +} + // Template function for writing data recursively to binary file (handles non-contiguous tensors) template void write_binary_data(std::ofstream &out, const T *data, const Shape &shape, const Strides &strides, size_t dim) { @@ -181,8 +195,8 @@ void TensorImpl::debug(const std::string &filename) const { cpu_tensor->shape(), cpu_tensor->strides(), 0); break; case DataType::I8: - print_data(reinterpret_cast(cpu_data), - cpu_tensor->shape(), cpu_tensor->strides(), 0); + print_data_i8(reinterpret_cast(cpu_data), + cpu_tensor->shape(), cpu_tensor->strides(), 0); break; case DataType::BF16: print_data_bf16(reinterpret_cast(cpu_data), diff --git a/src/infiniop/ops/quant/per_channel_quant_int8/cuda/kernel.cuh b/src/infiniop/ops/quant/per_channel_quant_int8/cuda/kernel.cuh new file mode 100644 index 000000000..8cf5d6db9 --- /dev/null +++ b/src/infiniop/ops/quant/per_channel_quant_int8/cuda/kernel.cuh @@ -0,0 +1,267 @@ +#ifndef __PERCHANNEL_QUANTINT8_KERNEL_CUH__ +#define __PERCHANNEL_QUANTINT8_KERNEL_CUH__ + +#include +__device__ inline int 
round_half_away_from_zero(float x) { + float ax = fabsf(x); + float r = floorf(ax + 0.5f); + return (x >= 0.0f) ? (int)r : -(int)r; +} + +template +__device__ void blockPerChannelQuantI8Kernel( + int8_t *x_packed, float *x_scale, float *x_zero, const Tdata *x, + int M, int K) { + int row = blockIdx.x; + int tid = row * K; + + // ---- 1. reduce max ---- + float local_max = op::common_cuda::reduce_op::max( + x + tid, K); + + __shared__ float global_max_f; + if (threadIdx.x == 0) { + global_max_f = local_max; + } + __syncthreads(); + + typedef cub::BlockReduce BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + + // ---- 2. reduce min ---- + float thread_min = __FLT_MAX__; + for (int ind = threadIdx.x; ind < K; ind += BLOCK_SIZE) { + thread_min = fminf(thread_min, (float)x[tid + ind]); + } + float local_min = BlockReduce(temp_storage).Reduce(thread_min, cub::Min()); + + __shared__ float global_min_f; + if (threadIdx.x == 0) { + global_min_f = local_min; + } + __syncthreads(); + + float global_max = global_max_f; + float global_min = global_min_f; + + float scale = (global_max - global_min) / 255.0f; + if (scale < 1e-8f) { + scale = 1e-8f; + } + + float inv_scale = 1.0f / scale; + float zero = -global_min * inv_scale - 128.0f; + + x_scale[row] = (Tdata)scale; + x_zero[row] = (Tdata)zero; + + for (int ind = threadIdx.x; ind < K; ind += BLOCK_SIZE) { + + float v = (float)x[tid + ind]; + float qf = v * inv_scale + zero; + + int q = round_half_away_from_zero(qf); + + if (q > 127) { + q = 127; + } + if (q < -128) { + q = -128; + } + + x_packed[tid + ind] = (int8_t)q; + } +} + +template +__device__ void blockPerChannelQuantI8SymKernel( + int8_t *x_packed, float *x_scale, const Tdata *x, + int M, int K) { + int row = blockIdx.x; + int tid = row * K; + + typedef cub::BlockReduce BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + + // ---- 2. reduce min ---- + float thread_max = -__FLT_MAX__; + for (int ind = threadIdx.x; ind < K; ind += BLOCK_SIZE) { + thread_max = fmaxf(thread_max, fabs((float)x[tid + ind])); + } + float local_max = BlockReduce(temp_storage).Reduce(thread_max, cub::Max()); + + __shared__ float global_max_f; + if (threadIdx.x == 0) { + global_max_f = local_max; + } + __syncthreads(); + + float global_max = global_max_f; + + float scale = global_max / 127.0f; + if (scale < 1e-8f) { + scale = 1e-8f; + } + + float inv_scale = 1.0f / scale; + + x_scale[row] = (Tdata)scale; + + for (int ind = threadIdx.x; ind < K; ind += BLOCK_SIZE) { + + float v = (float)x[tid + ind]; + float qf = v * inv_scale; + + int q = round_half_away_from_zero(qf); + + if (q > 127) { + q = 127; + } + if (q < -127) { + q = -127; + } + + x_packed[tid + ind] = (int8_t)q; + } +} + +template +struct MaxOp { + __device__ __forceinline__ T operator()(const T &a, const T &b) const { + return max(a, b); + } +}; +template +struct MinOp { + __device__ __forceinline__ T operator()(const T &a, const T &b) const { + return min(a, b); + } +}; +template
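For checking the numerics of the kernels above against a host implementation, here is a self-contained CPU reference sketch of the W8A8I8 path: symmetric per-row int8 quantization (round half away from zero, clamp to [-127, 127], scale floored at 1e-8, as in `blockPerChannelQuantI8SymKernel`), followed by an int32-accumulated matmul dequantized with the product of the per-row activation scale and the per-output-channel weight scale. That dequantization convention is assumed from the usual compressed-tensors W8A8 scheme; the infiniop I8Gemm implementation itself is not shown in this diff.

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Matches round_half_away_from_zero in the CUDA kernel above.
static int round_away(float x) {
    float r = std::floor(std::fabs(x) + 0.5f);
    return x >= 0.0f ? static_cast<int>(r) : -static_cast<int>(r);
}

// Symmetric per-row quantization of x[M, K]: scale = max(|row|) / 127, floored at 1e-8.
static void quant_rows_i8(const std::vector<float> &x, int M, int K,
                          std::vector<int8_t> &q, std::vector<float> &scale) {
    q.resize(static_cast<size_t>(M) * K);
    scale.resize(M);
    for (int m = 0; m < M; ++m) {
        float amax = 0.0f;
        for (int k = 0; k < K; ++k) {
            amax = std::max(amax, std::fabs(x[m * K + k]));
        }
        float s = std::max(amax / 127.0f, 1e-8f);
        scale[m] = s;
        for (int k = 0; k < K; ++k) {
            int v = round_away(x[m * K + k] / s);
            q[m * K + k] = static_cast<int8_t>(std::clamp(v, -127, 127));
        }
    }
}

// Reference for linear_w8a8i8: y[m, n] = s_x[m] * s_w[n] * sum_k qx[m, k] * qw[n, k] (+ bias[n]).
// qw is the int8 weight stored as [N, K]; s_w is its per-output-channel scale.
static std::vector<float> linear_w8a8i8_ref(const std::vector<float> &x, int M, int K,
                                            const std::vector<int8_t> &qw,
                                            const std::vector<float> &s_w, int N,
                                            const float *bias /* nullable, length N */) {
    std::vector<int8_t> qx;
    std::vector<float> s_x;
    quant_rows_i8(x, M, K, qx, s_x);

    std::vector<float> y(static_cast<size_t>(M) * N);
    for (int m = 0; m < M; ++m) {
        for (int n = 0; n < N; ++n) {
            int32_t acc = 0;
            for (int k = 0; k < K; ++k) {
                acc += static_cast<int32_t>(qx[m * K + k]) * static_cast<int32_t>(qw[n * K + k]);
            }
            y[m * N + n] = s_x[m] * s_w[n] * static_cast<float>(acc) + (bias ? bias[n] : 0.0f);
        }
    }
    return y;
}
```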
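The intended call sequence for the new infiniop C API, as a sketch that uses only the four functions declared in `per_channel_quant_int8.h`. Error handling and device workspace allocation are elided, and `x_zero_desc`/`x_zero` are passed as `nullptr` for the symmetric path, mirroring how the infinicore-side wrapper in this patch calls it. The function name is hypothetical.

```cpp
void run_per_channel_quant_i8(infiniopHandle_t handle,
                              infiniopTensorDescriptor_t x_packed_desc,
                              infiniopTensorDescriptor_t x_scale_desc,
                              infiniopTensorDescriptor_t x_desc,
                              void *x_packed, void *x_scale, const void *x,
                              void *stream) {
    infiniopPerChannelQuantI8Descriptor_t desc = nullptr;
    infiniopCreatePerChannelQuantI8Descriptor(handle, &desc,
                                              x_packed_desc, x_scale_desc,
                                              /*x_zero_desc=*/nullptr, x_desc);

    size_t workspace_size = 0;
    infiniopGetPerChannelQuantI8WorkspaceSize(desc, &workspace_size);
    void *workspace = nullptr; // allocate at least workspace_size bytes of device memory here

    infiniopPerChannelQuantI8(desc, workspace, workspace_size,
                              x_packed, x_scale, /*x_zero=*/nullptr, x, stream);
    infiniopDestroyPerChannelQuantI8Descriptor(desc);
}
```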
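One note on the pybind bindings earlier in the patch: `bind_scaled_mm_i8` registers `op::scaled_mm_i8`, which this diff does not appear to declare, and the in-place binding lists four arguments while `scaled_mm_i8_` takes six. A binding consistent with the declared in-place signature, handling `bias = None` the same way the `linear_w8a8i8` binding does, might look like the following sketch (it assumes the same includes and namespaces as the existing `scaled_mm_i8.hpp` binding header).

```cpp
inline void bind_scaled_mm_i8(py::module &m) {
    m.def(
        "scaled_mm_i8_",
        [](Tensor c, Tensor a_p, Tensor a_s, Tensor b_p, Tensor b_s, py::object bias) {
            std::optional<Tensor> bias_tensor = std::nullopt;
            if (!bias.is_none()) {
                bias_tensor = bias.cast<Tensor>();
            }
            op::scaled_mm_i8_(c, a_p, a_s, b_p, b_s, bias_tensor);
        },
        py::arg("c"),
        py::arg("a_p"),
        py::arg("a_s"),
        py::arg("b_p"),
        py::arg("b_s"),
        py::arg("bias") = py::none(),
        R"doc(In-place scaled int8 matrix multiplication: c = (a_p * a_s) @ (b_p * b_s) + bias.)doc");
}
```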