Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
218 changes: 218 additions & 0 deletions catgrad-llm/src/helpers/conv.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,218 @@
use crate::helpers::tensors::*;
use catgrad::prelude::ops::*;
use catgrad::prelude::*;

/// Depthwise 1D convolution without bias.
///
/// Loads the convolution weight (expected shape `H x 1 x K`, per the
/// parameterized variants below) from `weight_path/weight`, prepends
/// `padding_size` zeros along the sequence axis, and convolves `x`.
pub fn depthwise_conv1d_no_bias(
    builder: &Builder,
    weight_path: Path,
    kernel_size: usize,
    x: Var,
    padding_size: usize,
) -> Var {
    let weight = param(builder, &weight_path.extend(["weight"]).unwrap());
    depthwise_conv1d_no_bias_param(builder, weight, kernel_size, padding_size, x)
}

/// Depthwise 1D convolution without bias for already-padded inputs.
///
/// `x_padded` is expected to carry `K-1` leading zeros along the sequence
/// axis; `output_len` is the desired unpadded sequence length. This variant
/// exists because Nats cannot be subtracted to recover the unpadded length.
pub fn padded_depthwise_conv1d_no_bias(
    builder: &Builder,
    weight_path: Path,
    kernel_size: usize,
    x_padded: Var,
    output_len: Var,
) -> Var {
    let weight = param(builder, &weight_path.extend(["weight"]).unwrap());
    padded_depthwise_conv1d_no_bias_param(builder, weight, kernel_size, x_padded, output_len)
}

/// Parameterized depthwise 1D convolution with optional bias.
///
/// `conv_weight` is expected to have shape `H x 1 x K`;
/// `conv_bias` (when present) is expected to have shape `H`.
/// Prepends `padding_size` zeros along the sequence axis before convolving,
/// so the output keeps the input's sequence length.
pub fn depthwise_conv1d_param(
    builder: &Builder,
    conv_weight: Var,
    conv_bias: Option<Var>,
    kernel_size: usize,
    x: Var,
    padding_size: usize,
) -> Var {
    let [batch, channels, seq_len] = unpack::<3>(builder, shape(builder, x.clone()));

    // Causal padding: a block of zeros prepended along the sequence axis.
    let padded = match padding_size {
        0 => x,
        p => {
            let zeros_shape = shape!(builder, batch, channels, p);
            let leading = constant(builder, 0.0, &zeros_shape);
            concat(builder, 2, leading, x)
        }
    };

    depthwise_conv1d_param_padded(builder, conv_weight, conv_bias, kernel_size, padded, seq_len)
}

/// Parameterized depthwise 1D convolution without bias.
///
/// `conv_weight` is expected to have shape `H x 1 x K`.
/// NOTE(review): the argument order here (`padding_size` before `x`) differs
/// from `depthwise_conv1d_param`; kept as-is for caller compatibility.
pub fn depthwise_conv1d_no_bias_param(
    builder: &Builder,
    conv_weight: Var,
    kernel_size: usize,
    padding_size: usize,
    x: Var,
) -> Var {
    depthwise_conv1d_param(
        builder,
        conv_weight,
        None,
        kernel_size,
        x,
        padding_size,
    )
}

/// Parameterized depthwise 1D convolution without bias for already-padded inputs.
///
/// `conv_weight` is expected to have shape `H x 1 x K`; `x_padded` carries
/// `K-1` leading zeros along the sequence axis and `output_len` is the
/// unpadded sequence length of the result.
pub fn padded_depthwise_conv1d_no_bias_param(
    builder: &Builder,
    conv_weight: Var,
    kernel_size: usize,
    x_padded: Var,
    output_len: Var,
) -> Var {
    depthwise_conv1d_param_padded(builder, conv_weight, None, kernel_size, x_padded, output_len)
}

/// Core depthwise 1D convolution over an already-padded input, with
/// optional bias.
///
/// `conv_weight` has shape `H x 1 x K`; `x_padded` carries `K-1` extra
/// leading positions along the sequence axis; `output_len` is the length of
/// the result along that axis. The convolution is expressed as a sum over
/// `K` shifted slices of the input, each scaled by one kernel tap:
/// `out[..., t] = sum_k w[:, k] * x_padded[..., t + k]`.
fn depthwise_conv1d_param_padded(
    builder: &Builder,
    conv_weight: Var,
    conv_bias: Option<Var>,
    kernel_size: usize,
    x_padded: Var,
    output_len: Var,
) -> Var {
    // Drop the singleton middle axis: `H x 1 x K` -> `H x K`.
    let weight = squeeze::<3, 2>(builder, 1, conv_weight);

    // Accumulate the K tap contributions.
    let conv_out = (0..kernel_size)
        .map(|tap| {
            let window = slice(builder, 2, tap, output_len.clone(), x_padded.clone());
            let coeff = slice(builder, 1, tap, 1, weight.clone());
            let coeff = unsqueeze::<2, 3>(builder, 0, coeff);
            let coeff = broadcast(builder, shape(builder, window.clone()), coeff);
            window * coeff
        })
        .reduce(|acc, term| acc + term)
        .expect("kernel_size must be positive");

    // Optionally add the per-channel bias, broadcast over batch and sequence.
    match conv_bias {
        None => conv_out,
        Some(bias) => {
            let bias = unsqueeze::<1, 2>(builder, 0, bias);
            let bias = unsqueeze::<2, 3>(builder, 2, bias);
            let bias = broadcast(builder, shape(builder, conv_out.clone()), bias);
            conv_out + bias
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use catgrad::abstract_interpreter::Value as TypeValue;
    use catgrad::category::core::Shape;
    use catgrad::interpreter::backend::Backend;
    use catgrad::interpreter::backend::ndarray::NdArrayBackend;
    use catgrad::interpreter::{
        Interpreter, Parameters, TaggedTensor, TaggedTensorTuple, Value, tensor,
    };
    use catgrad::stdlib::{Module, stdlib};
    use catgrad::typecheck::value_types::*;

    // Module wrapping `depthwise_conv1d_no_bias_param` with fixed shapes:
    // input `1 x 2 x 4`, weight `2 x 1 x 3`, output `1 x 2 x 4`.
    struct DepthwiseConv1dTest;

    impl Module<2, 1> for DepthwiseConv1dTest {
        fn ty(&self) -> ([Type; 2], [Type; 1]) {
            // Build an f32 rank-3 tensor type from three dimension constants.
            let tensor_ty = |a, b, c| {
                TypeValue::Tensor(TypeExpr::NdArrayType(NdArrayType {
                    dtype: DtypeExpr::Constant(Dtype::F32),
                    shape: ShapeExpr::Shape(vec![
                        NatExpr::Constant(a),
                        NatExpr::Constant(b),
                        NatExpr::Constant(c),
                    ]),
                }))
            };
            (
                [tensor_ty(1, 2, 4), tensor_ty(2, 1, 3)],
                [tensor_ty(1, 2, 4)],
            )
        }

        fn path(&self) -> Path {
            path(vec!["test", "depthwise_conv1d"]).unwrap()
        }

        fn def(&self, builder: &Builder, [x, w]: [Var; 2]) -> [Var; 1] {
            // kernel_size = 3, padding_size = 2 (causal).
            [depthwise_conv1d_no_bias_param(builder, w, 3, 2, x)]
        }
    }

    #[test]
    fn test_depthwise_conv1d_no_bias_param_matches_reference_values() {
        let typed_term = DepthwiseConv1dTest.term().unwrap();
        let backend = NdArrayBackend;
        let interpreter = Interpreter::new(backend, stdlib(), Parameters::default());

        // Two channels of four samples each.
        let input = tensor(
            &interpreter.backend,
            Shape(vec![1, 2, 4]),
            vec![1.0f32, 2.0, 3.0, 4.0, 10.0, 20.0, 30.0, 40.0],
        )
        .unwrap();
        // One 3-tap kernel per channel.
        let kernel = tensor(
            &interpreter.backend,
            Shape(vec![2, 1, 3]),
            vec![1.0f32, 2.0, 3.0, 0.5, -1.0, 2.0],
        )
        .unwrap();

        let mut results = interpreter
            .run(typed_term.term, vec![input, kernel])
            .unwrap();
        let output = results.pop().expect("missing output");

        // Hand-computed causal convolution reference.
        let reference = tensor(
            &interpreter.backend,
            Shape(vec![1, 2, 4]),
            vec![3.0f32, 8.0, 14.0, 20.0, 20.0, 30.0, 45.0, 60.0],
        )
        .unwrap();

        match (output, reference) {
            (
                Value::Tensor(TaggedTensor::F32([got])),
                Value::Tensor(TaggedTensor::F32([want])),
            ) => {
                assert!(
                    interpreter
                        .backend
                        .compare(TaggedTensorTuple::F32([got, want])),
                    "depthwise conv output should match expected reference values"
                );
            }
            _ => panic!("expected f32 tensor outputs"),
        }
    }
}
5 changes: 5 additions & 0 deletions catgrad-llm/src/helpers/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
mod conv;
mod tensors;

pub use conv::*;
pub use tensors::*;

mod rope;
Expand All @@ -17,6 +20,7 @@ pub struct Cache {
pub in_kv_cache: Vec<(Var, Var)>,
pub out_kv_cache: Vec<(Var, Var)>,
pub linear_state: Option<Vec<Var>>,
pub recurrent_state: Option<Vec<Var>>,
}

impl Cache {
Expand Down Expand Up @@ -69,6 +73,7 @@ impl Cache {
in_kv_cache,
out_kv_cache,
linear_state: None,
recurrent_state: None,
}
}

Expand Down
4 changes: 2 additions & 2 deletions catgrad-llm/src/helpers/tensors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,6 @@ pub fn cumsum<const N: usize>(builder: &Builder, x: Var) -> Var {
matmul(builder, x, lower)
}

pub fn zeros(builder: &Builder, shape: Var) -> Var {
constant(builder, 0.0, &shape)
/// Returns a tensor of zeros with the given (runtime) shape.
pub fn zeros(builder: &Builder, shape: &Var) -> Var {
    constant(builder, 0.0, shape)
}
61 changes: 18 additions & 43 deletions catgrad-llm/src/models/lfm2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -244,37 +244,6 @@ impl Lfm2Model {
linear_no_bias(builder, dim, dim, p.extend(["out_proj"]).unwrap(), attn)
}

fn depthwise_conv1d(&self, builder: &Builder, p: &Path, x_padded: Var, s: Var) -> Var {
let k = self.config.conv_l_cache;

let conv_weight = param(builder, &p.extend(["conv", "weight"]).unwrap());
let conv_weight = squeeze::<3, 2>(builder, 1, conv_weight);

let mut conv_out: Option<Var> = None;
for offset in 0..k {
let x_slice = slice(builder, 2, offset, s.clone(), x_padded.clone());
let w_slice = slice(builder, 1, offset, 1, conv_weight.clone());
let w_slice = unsqueeze::<2, 3>(builder, 0, w_slice);
let w_slice = broadcast(builder, shape(builder, x_slice.clone()), w_slice);
let term = x_slice * w_slice;
conv_out = Some(match conv_out {
Some(acc) => acc + term,
None => term,
});
}
let mut conv_out = conv_out.expect("conv_l_cache must be positive");

if self.config.conv_bias {
let conv_bias = param(builder, &p.extend(["conv", "bias"]).unwrap());
let conv_bias = unsqueeze::<1, 2>(builder, 0, conv_bias);
let conv_bias = unsqueeze::<2, 3>(builder, 2, conv_bias);
let conv_bias = broadcast(builder, shape(builder, conv_out.clone()), conv_bias);
conv_out = conv_out + conv_bias;
}

conv_out
}

fn short_conv(
&self,
builder: &Builder,
Expand Down Expand Up @@ -326,7 +295,7 @@ impl Lfm2Model {
builder,
is_decode,
|b, args: Vec<Var>| {
let [bx, s, batch_size, hidden_dim, conv_state, pos_clamped_u32] =
let [bx, s, _batch_size, _hidden_dim, conv_state, pos_clamped_u32] =
args.try_into().unwrap();

// `conv_state = conv_state.roll(shifts=-1, dims=-1)`
Expand All @@ -349,26 +318,32 @@ impl Lfm2Model {
let bx_decode = broadcast(b, sh_state, bx_decode);
let out_linear_state_decode = where_broadcast(b, one_hot, bx_decode, rolled_state);

// `conv_out = torch.sum(conv_state * self.conv.weight[:, 0, :], dim=-1).unsqueeze(-1)`
let zeros_decode_tail = zeros(b, shape!(b, batch_size, hidden_dim, s));
let x_padded_decode =
concat(b, 2, out_linear_state_decode.clone(), zeros_decode_tail);
let conv_out_decode = self.depthwise_conv1d(b, &p, x_padded_decode, s);
// Use the helper for decoding: pass out_linear_state_decode with 0 padding.
let conv_out_decode = padded_depthwise_conv1d_no_bias(
b,
p.extend(["conv"]).unwrap(),
cache_len,
out_linear_state_decode.clone(),
s,
);

vec![conv_out_decode, out_linear_state_decode]
},
|b, args: Vec<Var>| {
let [bx, s, batch_size, hidden_dim, _conv_state, _pos_clamped_u32] =
args.try_into().unwrap();

// `conv_out = self.conv(Bx)[..., :seqlen]`
let zeros_prefill_conv = zeros(b, shape!(b, batch_size, hidden_dim, cache_len - 1));
let x_padded_prefill_conv = concat(b, 2, zeros_prefill_conv, bx.clone());
let conv_out_prefill =
self.depthwise_conv1d(b, &p, x_padded_prefill_conv, s.clone());
// Use the helper for prefill: pass bx with causal padding.
let conv_out_prefill = depthwise_conv1d_no_bias(
b,
p.extend(["conv"]).unwrap(),
cache_len,
bx.clone(),
cache_len - 1,
);

// `conv_state = nn.functional.pad(Bx, (self.L_cache - Bx.shape[-1], 0))`
let zeros_prefill_state = zeros(b, shape!(b, batch_size, hidden_dim, cache_len));
let zeros_prefill_state = zeros(b, &shape!(b, batch_size, hidden_dim, cache_len));
let x_padded_prefill_state = concat(b, 2, zeros_prefill_state, bx);
let out_linear_state_prefill = slice(b, 2, s, cache_len, x_padded_prefill_state);

Expand Down
Loading