10 changes: 10 additions & 0 deletions src/ntops/kernels/__init__.py
@@ -36,6 +36,11 @@
softmax,
sub,
tanh,
acosh,
adaptive_avg_pool2d,
addmv,
argsort,
fmax,
)

__all__ = [
@@ -76,4 +81,9 @@
"softmax",
"sub",
"tanh",
"acosh",
"adaptive_avg_pool2d",
"addmv",
"argsort",
"fmax",
]
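With these re-exports in place, the new kernels are importable from the package root (a usage sketch based on the `__all__` additions above; `fmax`'s kernel file is not shown in this diff):

from ntops.kernels import acosh, adaptive_avg_pool2d, addmv, argsort, fmax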
96 changes: 96 additions & 0 deletions src/ntops/kernels/acosh.py
@@ -0,0 +1,96 @@
import functools

import ninetoothed.language as ntl
from ninetoothed import Tensor

from ntops.kernels.reduction import arrangement


def _sqrt(x, dtype):
    """Numerically stable square root; promotes float16 to float32 internally."""
    sqrt_dtype = dtype if dtype != ntl.float16 else ntl.float32
    return ntl.cast(ntl.sqrt(ntl.cast(x, sqrt_dtype)), dtype)


def _log(x, dtype):
    """Numerically stable logarithm; promotes float16 to float32 internally."""
    log_dtype = dtype if dtype != ntl.float16 else ntl.float32
    return ntl.cast(ntl.log(ntl.cast(x, log_dtype)), dtype)


def application(input, output):
    """
    Compute the inverse hyperbolic cosine: acosh(x) = ln(x + sqrt(x² - 1)).

    Parameters:
        input: input tensor of shape (C // block_size, block_size)
        output: output tensor of shape (C // block_size, block_size)

    Numerical-stability notes:
        1. Near x = 1, x² - 1 approaches 0; factoring it as (x - 1)(x + 1)
           avoids the precision loss of squaring first.
        2. For large x, the factored form also avoids overflow in x².
        3. float16 is promoted to float32 for the sqrt and log.
    """
    dtype = output.dtype.dtype

    for i in range(input.shape[0]):
        # Load the current block.
        input_block = ntl.cast(input[i], dtype)

        # Numerically stable acosh: acosh(x) = ln(x + sqrt(x² - 1)).

        # Factor x² - 1 as (x - 1)(x + 1) to avoid precision loss near x = 1.
        x_minus_one = input_block - ntl.cast(1.0, dtype)
        x_plus_one = input_block + ntl.cast(1.0, dtype)

        # sqrt(x² - 1) = sqrt((x - 1)(x + 1))
        sqrt_term = _sqrt(x_minus_one * x_plus_one, dtype)

        # x + sqrt(x² - 1). No squaring occurs here, so overflow is only
        # possible near the dtype's maximum; acosh's domain is x ≥ 1.
        sum_term = input_block + sqrt_term

        # ln(x + sqrt(x² - 1))
        result = _log(sum_term, dtype)

        # Boundary handling: acosh is undefined for x < 1 (return NaN),
        # and acosh(1) = 0.
        result = ntl.where(
            input_block < ntl.cast(1.0, dtype),
            ntl.cast(float("nan"), dtype),
            ntl.where(
                input_block == ntl.cast(1.0, dtype),
                ntl.cast(0.0, dtype),
                result,
            ),
        )

        # Store the result.
        output[i] = result


def premake(ndim, dim, dtype=None, block_size=None):
    """
    Prepare the acosh kernel.

    Parameters:
        ndim: number of dimensions of the input tensor
        dim: dimension along which to compute acosh
        dtype: data type
        block_size: tile size, used to optimize memory access

    Returns:
        arrangement_: tensor arrangement function
        application: computation function
        tensors: input/output tensor descriptors
    """
    arrangement_ = functools.partial(arrangement, dim=dim, block_size=block_size)

    tensors = (
        Tensor(ndim, dtype=dtype),  # input tensor
        Tensor(ndim, dtype=dtype),  # output tensor
    )

    return arrangement_, application, tensors
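A quick host-side sanity check of the same factored formula, assuming only NumPy (this mirrors the math the kernel uses, not the ninetoothed launch path):

import numpy as np

x = np.linspace(1.0, 10.0, 7, dtype=np.float32)

# Factored form used by the kernel: acosh(x) = ln(x + sqrt((x - 1)(x + 1))).
factored = np.log(x + np.sqrt((x - 1) * (x + 1)))

print(np.allclose(factored, np.arccosh(x)))  # expected: True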
91 changes: 91 additions & 0 deletions src/ntops/kernels/adaptive_avg_pool2d.py
@@ -0,0 +1,91 @@
import functools

import ninetoothed
import ninetoothed.language as ntl
from ninetoothed import Tensor

def _cast_to_f32(x, dtype):
    """Promote float16 to float32 so the accumulation keeps full precision."""
    return ntl.cast(x, ntl.float32) if dtype == ntl.float16 else x

def arrangement(
    input,
    output,
    kernel_size_flatted,
    kernel_size_h,
    kernel_size_w,
    stride_h,
    stride_w,
    block_size,
):
    if block_size is None:
        block_size = ninetoothed.block_size()

    # input: (N, C, H_in, W_in)
    # output: (N, C, H_out, W_out)

    # 1. Tile the input into pooling windows.
    input_arranged = input.tile(
        (1, 1, kernel_size_h, kernel_size_w),
        (1, 1, stride_h, stride_w),
    )
    # => (N, C, H_out, W_out), dtype=(1, 1, k_h, k_w)

    # 2. Flatten and rearrange.
    input_arranged = input_arranged.ravel()
    # => (N, C, H_out, W_out, 1, 1, k_h, k_w)

    input_arranged = input_arranged.flatten(end_dim=4).flatten(start_dim=1)
    # => (N*C*H_out*W_out, k_h*k_w)

    # 3. Pad the window axis to the next power of two (required for the
    # reduction); 1 << (n - 1).bit_length() is the smallest power of two >= n.
    # The padding value is determined by other=0 in premake.
    nearest_pow2 = 1 << (kernel_size_h * kernel_size_w - 1).bit_length()
    input_arranged = input_arranged.tile((1, nearest_pow2))
    # => (..., 1), dtype=(1, nearest_pow2)

    input_arranged.dtype = input_arranged.dtype.squeeze(0)
    input_arranged = input_arranged.tile((block_size, -1))
    input_arranged.dtype = input_arranged.dtype.ravel().squeeze(1)
    # => (..., 1), dtype=(block_size, nearest_pow2)

    # 4. Arrange the output to match.
    output_arranged = output.tile((1, 1, 1, 1))
    output_arranged = output_arranged.ravel()
    output_arranged = output_arranged.flatten(end_dim=4).flatten(start_dim=1)
    output_arranged = output_arranged.tile((block_size, -1))
    output_arranged.dtype = output_arranged.dtype.squeeze(1)

    return input_arranged, output_arranged, kernel_size_flatted

def application(input, output, kernel_size_flatted):
    # input: (block_size, nearest_pow2)
    # output: (block_size,)
    # kernel_size_flatted: scalar tensor (k_h * k_w)

    dtype = input.dtype

    # Promote to float32 so the sum keeps full precision.
    val = _cast_to_f32(input, dtype)

    # Sum over axis 1 (the nearest_pow2 dimension);
    # the zero padding does not affect the result.
    acc = ntl.sum(val, axis=1)

    # Average: sum / window area.
    # Note: kernel_size_flatted is the true window size, not nearest_pow2.
    res = acc / ntl.cast(kernel_size_flatted, acc.dtype)

    # Cast back to the original dtype.
    output = ntl.cast(res, dtype)

def premake(
    ndim,
    kernel_size_h,
    kernel_size_w,
    stride_h,
    stride_w,
    block_size=None,
    dtype=None,
):
    arrangement_ = functools.partial(
        arrangement,
        kernel_size_h=kernel_size_h,
        kernel_size_w=kernel_size_w,
        stride_h=stride_h,
        stride_w=stride_w,
        block_size=block_size,
    )

    tensors = (
        # input: other=0 guarantees the tile-padding values do not affect the sum
        Tensor(ndim, dtype=dtype, other=0),
        Tensor(ndim, dtype=dtype),  # output
        Tensor(0, dtype=dtype),  # kernel_size_flatted (scalar)
    )

return arrangement_, application, tensors
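Since premake takes explicit kernel/stride parameters, the caller is expected to derive them from the input and output sizes. A minimal sketch for the evenly divisible case (derive_pool_params is a hypothetical helper; the general adaptive case uses variable window bounds, which a fixed kernel/stride does not cover):

def derive_pool_params(in_size, out_size):
    # When in_size is divisible by out_size, adaptive average pooling
    # reduces to fixed-window pooling with kernel == stride.
    assert in_size % out_size == 0, "general case needs variable windows"
    stride = in_size // out_size
    return stride, stride  # (kernel, stride)

kernel_h, stride_h = derive_pool_params(32, 8)  # -> (4, 4)
kernel_w, stride_w = derive_pool_params(32, 8)  # -> (4, 4)
kernel_size_flatted = kernel_h * kernel_w  # passed as the scalar tensor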
105 changes: 105 additions & 0 deletions src/ntops/kernels/addmv.py
@@ -0,0 +1,105 @@
import functools

import ninetoothed.language as ntl
from ninetoothed import Tensor

import ntops.kernels.mm as mm


def arrangement(
input,
mat,
vec,
beta,
alpha,
output,
input_precision,
block_size_m=None,
block_size_n=None,
block_size_k=None,
):
if block_size_m is None:
block_size_m = mm.BLOCK_SIZE_M

    # Key point: force block_size_n to 1, since this is a matrix-vector product.
if block_size_n is None:
block_size_n = 1

if block_size_k is None:
block_size_k = mm.BLOCK_SIZE_K

    # mm.arrangement now receives only rank-2 tensors:
    # vec is laid out as a (K, 1) matrix,
    # input as an (M, 1) matrix.
_, _, input_arranged, _ = mm.arrangement(
mat,
vec,
input,
input_precision,
block_size_m=block_size_m,
block_size_n=block_size_n,
block_size_k=block_size_k,
)

mat_arranged, vec_arranged, output_arranged, _ = mm.arrangement(
mat,
vec,
output,
input_precision,
block_size_m=block_size_m,
block_size_n=block_size_n,
block_size_k=block_size_k,
)

input_precision_arranged = input_precision

return (
input_arranged,
mat_arranged,
vec_arranged,
beta,
alpha,
output_arranged,
input_precision_arranged,
)


def application(input, mat, vec, beta, alpha, output, input_precision):
    # output is (M, 1) here, so zeros also creates an (M, 1) accumulator.
mm_output = ntl.zeros(output.shape, dtype=ntl.float32)
mm.application(mat, vec, mm_output, input_precision)

    # Elementwise combination; all shapes are compatible at (M, 1).
output = beta * input + alpha * mm_output


def premake(
input_precision=None,
dtype=None,
block_size_m=None,
block_size_n=None,
block_size_k=None,
):
arrangement_ = functools.partial(
arrangement,
block_size_m=block_size_m,
block_size_n=block_size_n,
block_size_k=block_size_k,
)

    # Fix: declare all tensors as 2-D:
    # input:  (M,)   -> (M, 1)  Tensor(2)
    # mat:    (M, K)            Tensor(2)
    # vec:    (K,)   -> (K, 1)  Tensor(2)
    # output: (M,)   -> (M, 1)  Tensor(2)
tensors = (
Tensor(2, dtype=dtype), # input (bias) treated as column vector
Tensor(2, dtype=dtype), # mat
Tensor(2, dtype=dtype), # vec treated as column vector
Tensor(0, dtype=dtype), # beta
Tensor(0, dtype=dtype), # alpha
Tensor(2, dtype=dtype), # output treated as column vector
Tensor(0, dtype=dtype, constexpr=True, value=input_precision),
)

return arrangement_, application, tensors
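For reference, the semantics this kernel implements, checked with plain NumPy (this is the textbook addmv definition; the (K, 1)/(M, 1) reshapes mirror the rank-2 declarations above, not the ninetoothed launch path):

import numpy as np

M, K = 4, 3
mat = np.random.rand(M, K).astype(np.float32)
vec = np.random.rand(K).astype(np.float32)
bias = np.random.rand(M).astype(np.float32)
beta, alpha = 0.5, 2.0

# addmv: out = beta * input + alpha * (mat @ vec)
expected = beta * bias + alpha * (mat @ vec)

# Column-vector form, as the kernel sees it: vec -> (K, 1), bias/out -> (M, 1).
column = beta * bias[:, None] + alpha * (mat @ vec[:, None])

assert np.allclose(expected, column.squeeze(1))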
75 changes: 75 additions & 0 deletions src/ntops/kernels/argsort.py
@@ -0,0 +1,75 @@
import functools

import ninetoothed.language as ntl
from ninetoothed import Tensor

from ntops.kernels.reduction import arrangement

def application(input, values, indices, dim_size, descending):
    val_block = input[0]

    # Initialize result buffers.
    res_vals = ntl.zeros(val_block.shape, dtype=val_block.dtype)
    res_idxs = ntl.zeros(val_block.shape, dtype=indices.dtype.dtype)

    # Positions used to select the current write slot.
    output_range = ntl.arange(0, val_block.shape[0])
    # Indices of the original elements, used for masking.
    idx_block = ntl.arange(0, val_block.shape[0])

    # Selection direction:
    # argsort defaults to ascending order, i.e. descending=False.
    # If descending=True, select maxima directly;
    # if descending=False, negate the values and select maxima (i.e. minima).
    if descending:
        working_val = val_block
    else:
        working_val = -val_block

    sentinel = float("-inf")

    # Iterate once per element along the sorted dimension.
    for i in range(dim_size):
        # Find the current maximum of working_val and its index.
        current_max_val = ntl.max(working_val, axis=0)
        current_max_idx = ntl.argmax(working_val, axis=0)

        # Recover the true value (undo the negation used for ascending order).
        real_val = -current_max_val if not descending else current_max_val
        real_val = ntl.cast(real_val, res_vals.dtype)

        # Determine the write slot (target_mask is True at exactly one position).
        target_mask = output_range == i

        # Write the selected value and index into the result buffers.
        res_vals = ntl.where(target_mask, real_val, res_vals)
        res_idxs = ntl.where(target_mask, current_max_idx, res_idxs)

        # Mask out the selected element so it cannot be chosen again.
        mask_selected = idx_block == current_max_idx
        updated_working_val = ntl.where(mask_selected, sentinel, working_val)
        working_val = ntl.cast(updated_working_val, working_val.dtype)

    # Write back to the outputs.
    values[0] = res_vals
    indices[0] = res_idxs


def premake(
    ndim, dim, dim_size, descending, dtype=None, indices_dtype=None, block_size=None
):
    # Reuse the reduction arrangement, which handles block partitioning
    # along the reduced dimension.
    arrangement_ = functools.partial(arrangement, dim=dim, block_size=block_size)

    # Padding value (a full argsort rarely needs padding, but this keeps the
    # kernel robust): padded slots must never be selected.
    pad_val = float("-inf") if descending else float("inf")

    tensors = (
        Tensor(ndim, dtype=dtype, other=pad_val),  # input
        Tensor(ndim, dtype=dtype),  # output values (auxiliary; argsort mainly needs indices)
        Tensor(ndim, dtype=indices_dtype),  # output indices
        Tensor(0, constexpr=True, value=dim_size),  # loop bound (dim_size iterations)
        Tensor(0, constexpr=True, value=descending),
    )

    return arrangement_, application, tensors
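The loop above is a selection sort driven by repeated argmax. A plain-NumPy sketch of the same scheme (selection_argsort is illustrative only; the kernel runs this loop with ntl ops over a single block):

import numpy as np

def selection_argsort(x, descending=False):
    # Work on a negated copy for ascending order, so argmax always
    # selects the next element to emit.
    working = x.copy() if descending else -x
    order = np.empty(len(x), dtype=np.int64)
    for i in range(len(x)):
        j = np.argmax(working)  # index of the current extremum
        order[i] = j
        working[j] = -np.inf  # mask it out for later rounds
    return order

x = np.array([3.0, 1.0, 2.0])
print(selection_argsort(x))  # [1 2 0] -> ascending
print(selection_argsort(x, descending=True))  # [0 2 1]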