20 changes: 16 additions & 4 deletions include/infinicore/nn/linear.hpp
@@ -2,14 +2,17 @@

#include "../ops.hpp"
#include "module.hpp"
#include "quantization.hpp"
#include <infiniccl.h>
#include <optional>

namespace infinicore::nn {

class BaseLinear : public Module {
public:
BaseLinear(size_t in_features, size_t out_features, bool bias = true,
const DataType &dtype = DataType::F32, const Device &device = Device());
const DataType &dtype = DataType::F32, const Device &device = Device(),
const std::optional<QuantScheme> &quant_scheme = std::nullopt);

// Forward pass: output = input @ weight.T + bias
Tensor forward(Tensor &input) const;
@@ -27,12 +30,17 @@ class BaseLinear : public Module {
// Accessors for parameters
Tensor weight() const { return weight_; }
Tensor bias() const { return bias_; }
Tensor weight_scale() const { return weight_scale_; }
Tensor weight_zeros() const { return weight_zeros_; }

protected:
// Parameters
INFINICORE_NN_PARAMETER(weight);
INFINICORE_NN_PARAMETER(bias);

INFINICORE_NN_PARAMETER(weight_scale);
INFINICORE_NN_PARAMETER(weight_zeros);

protected:
// Helper method for common forward computation
Tensor compute_linear(Tensor &input) const;
@@ -41,6 +49,7 @@
size_t out_features_;
bool has_bias_;
DataType dtype_;
const std::optional<QuantScheme> quant_scheme_;
};

} // namespace infinicore::nn
@@ -50,7 +59,8 @@ namespace infinicore::nn {
class Linear : public BaseLinear {
public:
Linear(size_t in_features, size_t out_features, bool bias = true,
const DataType &dtype = DataType::F32, const Device &device = Device());
const DataType &dtype = DataType::F32, const Device &device = Device(),
const std::optional<QuantScheme> &quant_scheme = std::nullopt);

// Forward pass: output = input @ weight.T + bias
Tensor forward(Tensor &input) const;
@@ -63,7 +73,8 @@ class ColumnParallelLinear : public BaseLinear {
public:
ColumnParallelLinear(size_t in_features, size_t out_features, bool bias = true,
const DataType &dtype = DataType::F32, const Device &device = Device(),
Size tp_rank = 0, Size tp_size = 1);
Size tp_rank = 0, Size tp_size = 1,
const std::optional<QuantScheme> &quant_scheme = std::nullopt);

// Forward pass: output = input @ weight.T + bias
Tensor forward(Tensor &input) const;
@@ -80,7 +91,8 @@ class RowParallelLinear : public BaseLinear {
public:
RowParallelLinear(size_t in_features, size_t out_features, bool bias = true,
const DataType &dtype = DataType::F32, const Device &device = Device(),
Size tp_rank = 0, Size tp_size = 1, infinicclComm_t communicator = nullptr);
Size tp_rank = 0, Size tp_size = 1, infinicclComm_t communicator = nullptr,
const std::optional<QuantScheme> &quant_scheme = std::nullopt);

// Forward pass: output = input @ weight.T + bias
Tensor forward(Tensor &input) const;
11 changes: 11 additions & 0 deletions include/infinicore/nn/quantization.hpp
@@ -0,0 +1,11 @@
// quantization.hpp
#pragma once

namespace infinicore::nn {

enum class QuantScheme {
NONE,
COMPRESSED_TENSOR_W8A8I8,
};

} // namespace infinicore::nn
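
Note (not part of the diff): a minimal usage sketch of the new `quant_scheme` parameter together with `QuantScheme`, based only on the constructor signatures shown above. The layer sizes, the include path, and the idea that `weight_`, `weight_scale_` and `weight_zeros_` are later populated from a compressed-tensors checkpoint are illustrative assumptions.

```cpp
// Usage sketch only; sizes, include path, and loading flow are assumptions.
#include <infinicore/nn/linear.hpp>

using namespace infinicore;

void build_layers() {
    // Default behaviour is unchanged: quant_scheme stays std::nullopt.
    nn::Linear dense(4096, 11008, /*bias=*/false);

    // Opt into the compressed-tensors W8A8 int8 path; the weight, weight_scale and
    // weight_zeros parameters are then expected to carry the packed int8 data.
    nn::Linear quantized(4096, 11008, /*bias=*/false,
                         DataType::F32, Device(),
                         nn::QuantScheme::COMPRESSED_TENSOR_W8A8I8);
}
```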
12 changes: 12 additions & 0 deletions include/infinicore/ops/linear_w8a8i8.hpp
@@ -0,0 +1,12 @@
#pragma once

#include "common/op.hpp"
#include <optional>

namespace infinicore::op {

Tensor linear_w8a8i8(Tensor input, Tensor weight_packed, Tensor weight_scale, std::optional<Tensor> bias);

void linear_w8a8i8_(Tensor out, Tensor input, Tensor weight_packed, Tensor weight_scale, std::optional<Tensor> bias);

} // namespace infinicore::op
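
Note (not part of the diff): a sketch of calling the out-of-place overload declared above. The tensors are taken as parameters because this header does not show how they are constructed; the weight layout and per-output-channel scales are inferred from the `output = input @ weight.T + bias` convention in linear.hpp and are assumptions.

```cpp
// Sketch only: forwards already-quantized weights through op::linear_w8a8i8.
#include <optional>

#include <infinicore/ops/linear_w8a8i8.hpp>

infinicore::Tensor run_quantized_linear(infinicore::Tensor input,          // fp activation
                                        infinicore::Tensor weight_packed,  // int8 weights
                                        infinicore::Tensor weight_scale,   // weight scales
                                        std::optional<infinicore::Tensor> bias = std::nullopt) {
    // Out-of-place overload: allocates and returns the output tensor.
    return infinicore::op::linear_w8a8i8(input, weight_packed, weight_scale, bias);
}
```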
15 changes: 15 additions & 0 deletions include/infinicore/ops/per_channel_quant_i8.hpp
@@ -0,0 +1,15 @@
#pragma once
#include "../device.hpp"
#include "common/op.hpp"
#include <optional>

namespace infinicore::op {
class PerChannelQuantI8 {
public:
using schema = void (*)(Tensor, Tensor, Tensor);
static void execute(Tensor x, Tensor x_packed, Tensor x_scale);
static common::OpDispatcher<schema> &dispatcher();
};

void per_channel_quant_i8_(Tensor x, Tensor x_packed, Tensor x_scale);
} // namespace infinicore::op
16 changes: 16 additions & 0 deletions include/infinicore/ops/scaled_mm_i8.hpp
@@ -0,0 +1,16 @@
#pragma once

#include "../device.hpp"
#include "common/op.hpp"
#include <optional>

namespace infinicore::op {
class ScaledMMI8 {
public:
using schema = void (*)(Tensor, Tensor, Tensor, Tensor, Tensor, std::optional<Tensor>);
static void execute(Tensor c, Tensor a_p, Tensor a_s, Tensor b_p, Tensor b_s, std::optional<Tensor> bias);
static common::OpDispatcher<schema> &dispatcher();
};

void scaled_mm_i8_(Tensor c, Tensor a_p, Tensor a_s, Tensor b_p, Tensor b_s, std::optional<Tensor> bias);
} // namespace infinicore::op
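
Note (not part of the diff): `PerChannelQuantI8` and `ScaledMMI8` look like the building blocks of `linear_w8a8i8`, i.e. dynamic int8 quantization of the activation followed by an int8 GEMM rescaled by both scale tensors. The sketch below is a guess at that composition, not this PR's implementation, and the `alloc_*` helpers are hypothetical stand-ins for the real tensor-allocation API.

```cpp
// Hypothetical composition sketch; not the code in this PR.
#include <optional>

namespace infinicore::op {

// Hypothetical allocation helpers standing in for the real Tensor factories.
Tensor alloc_i8_like(Tensor x);                 // int8 tensor with x's shape
Tensor alloc_row_scales(Tensor x);              // one scale per row of x
Tensor alloc_output(Tensor x, Tensor w_packed); // output tensor in x's dtype

Tensor linear_w8a8i8_sketch(Tensor input, Tensor weight_packed, Tensor weight_scale,
                            std::optional<Tensor> bias) {
    // 1. Dynamically quantize the activation to int8 with per-row scales.
    Tensor input_packed = alloc_i8_like(input);
    Tensor input_scale = alloc_row_scales(input);
    per_channel_quant_i8_(input, input_packed, input_scale);

    // 2. int8 GEMM; the kernel rescales by both scale tensors and adds the optional bias.
    Tensor out = alloc_output(input, weight_packed);
    scaled_mm_i8_(out, input_packed, input_scale, weight_packed, weight_scale, bias);
    return out;
}

} // namespace infinicore::op
```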
2 changes: 2 additions & 0 deletions include/infiniop.h
@@ -11,6 +11,7 @@
#include "infiniop/ops/dequantize_awq.h"
#include "infiniop/ops/gelu.h"
#include "infiniop/ops/gemm.h"
#include "infiniop/ops/int8_gemm.h"
#include "infiniop/ops/layer_norm.h"
#include "infiniop/ops/logsoftmax.h"
#include "infiniop/ops/lp_norm.h"
@@ -19,6 +20,7 @@
#include "infiniop/ops/paged_attention.h"
#include "infiniop/ops/paged_attention_prefill.h"
#include "infiniop/ops/paged_caching.h"
#include "infiniop/ops/quant/per_channel_quant_int8.h"
#include "infiniop/ops/random_sample.h"
#include "infiniop/ops/rearrange.h"
#include "infiniop/ops/relu.h"
28 changes: 28 additions & 0 deletions include/infiniop/ops/quant/per_channel_quant_int8.h
@@ -0,0 +1,28 @@
#ifndef __INFINIOP_PER_CHANNEL_QUANT_INT8_API_H__
#define __INFINIOP_PER_CHANNEL_QUANT_INT8_API_H__

#include "../../operator_descriptor.h"

typedef InfiniopDescriptor *infiniopPerChannelQuantI8Descriptor_t;

__C __export infiniStatus_t infiniopCreatePerChannelQuantI8Descriptor(infiniopHandle_t handle,
infiniopPerChannelQuantI8Descriptor_t *desc_ptr,
infiniopTensorDescriptor_t x_packed_desc,
infiniopTensorDescriptor_t x_scale_desc,
infiniopTensorDescriptor_t x_zero_desc,
infiniopTensorDescriptor_t x_desc);

__C __export infiniStatus_t infiniopGetPerChannelQuantI8WorkspaceSize(infiniopPerChannelQuantI8Descriptor_t desc, size_t *size);

__C __export infiniStatus_t infiniopPerChannelQuantI8(infiniopPerChannelQuantI8Descriptor_t desc,
void *workspace,
size_t workspace_size,
void *x_packed,
void *x_scale,
void *x_zero,
const void *x,
void *stream);

__C __export infiniStatus_t infiniopDestroyPerChannelQuantI8Descriptor(infiniopPerChannelQuantI8Descriptor_t desc);

#endif
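
Note (not part of the diff): the usual infiniop call sequence for this operator, using only the four functions declared above. The handle, tensor descriptors, device buffers, workspace and stream are assumed to be created elsewhere, and `INFINI_STATUS_SUCCESS` is assumed to be the library's success code.

```cpp
// Sketch only: create descriptor -> query workspace -> run -> destroy.
#include <infiniop.h>

infiniStatus_t quantize_per_channel_i8(infiniopHandle_t handle,
                                       infiniopTensorDescriptor_t x_packed_desc,
                                       infiniopTensorDescriptor_t x_scale_desc,
                                       infiniopTensorDescriptor_t x_zero_desc,
                                       infiniopTensorDescriptor_t x_desc,
                                       void *x_packed, void *x_scale, void *x_zero,
                                       const void *x,
                                       void *workspace, // assumed at least the reported size
                                       void *stream) {
    infiniopPerChannelQuantI8Descriptor_t desc = nullptr;
    infiniStatus_t status = infiniopCreatePerChannelQuantI8Descriptor(
        handle, &desc, x_packed_desc, x_scale_desc, x_zero_desc, x_desc);
    if (status != INFINI_STATUS_SUCCESS) { // success code name is an assumption
        return status;
    }

    size_t workspace_size = 0;
    status = infiniopGetPerChannelQuantI8WorkspaceSize(desc, &workspace_size);
    if (status == INFINI_STATUS_SUCCESS) {
        status = infiniopPerChannelQuantI8(desc, workspace, workspace_size,
                                           x_packed, x_scale, x_zero, x, stream);
    }

    infiniopDestroyPerChannelQuantI8Descriptor(desc);
    return status;
}
```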
2 changes: 2 additions & 0 deletions python/infinicore/nn/functional/__init__.py
@@ -1,6 +1,7 @@
from .causal_softmax import causal_softmax
from .embedding import embedding
from .linear import linear
from .linear_w8a8i8 import linear_w8a8i8
from .random_sample import random_sample
from .rms_norm import rms_norm
from .rope import RopeAlgo, rope
@@ -17,4 +18,5 @@
"embedding",
"rope",
"RopeAlgo",
"linear_w8a8i8",
]
31 changes: 31 additions & 0 deletions python/infinicore/nn/functional/linear_w8a8i8.py
@@ -0,0 +1,31 @@
from infinicore.lib import _infinicore
from infinicore.tensor import Tensor


def linear_w8a8i8(
input: Tensor,
weight_packed: Tensor,
weight_scale: Tensor,
bias=None,
out=None,
) -> Tensor:
r"""Linear layer with weight quantized to int8 and input quantized to int8 with per-tensor scale."""

if out is None:
return Tensor(
_infinicore.linear_w8a8i8(
input._underlying,
weight_packed._underlying,
weight_scale._underlying,
None if bias is None else bias._underlying,
)
)

_infinicore.linear_w8a8i8_(
out._underlying,
input._underlying,
weight_packed._underlying,
weight_scale._underlying,
None if bias is None else bias._underlying,
)
return out