Changes from all commits (235 commits)
3e9b9f1
Implementation of linear_ layer for neural networks. This layer provi…
Cydral Apr 28, 2025
93ead3d
Minor change
Cydral May 2, 2025
bf1b805
Update dlib/dnn/layers.h
davisking May 3, 2025
49bfbc6
Merge branch 'davisking:master' into master
Cydral May 6, 2025
f234faa
Add reshape_to and flatten layers to Dlib's DNN module
Cydral May 6, 2025
26a2960
Missing update to "visitors.h"
Cydral May 22, 2025
c9a1ee4
format fixing for reshape_to
Cydral May 22, 2025
02e62d8
Update dlib/test/dnn.cpp
davisking May 23, 2025
394dee8
Merge branch 'davisking:master' into master
Cydral May 29, 2025
778bfc1
Vocabulary size fixed for learning, and function added for transforma…
Cydral May 29, 2025
03aafc2
Added a new example for learning a “complex” Transformer model.
Cydral May 29, 2025
22c2561
Added a new example for learning a “complex” Transformer model.
Cydral May 29, 2025
01cd0b2
Updated example for training a Transformer model.
Cydral May 29, 2025
6b63e55
fix for gcc/ffmpeg compilation
Cydral May 30, 2025
ad1f757
Fix a warning message for Ubuntu compilation.
Cydral May 30, 2025
c91c45a
Update for Linux environment.
Cydral May 30, 2025
6fcc0aa
Fix batch building
Cydral May 31, 2025
5a1773e
Slight improvement in model definition.
Cydral Jun 3, 2025
10d7c59
linear_ layer implementation improvement
Cydral Jun 7, 2025
d4bf94b
finalizing the example
Cydral Jun 7, 2025
a4dac0b
Fixing break condition in training method.
Cydral Jun 8, 2025
63454e3
Fixing declaration order of variables.
Cydral Jun 8, 2025
87ed70a
bpe_tokenizer improvements.
Cydral Jun 8, 2025
061c673
Example updated.
Cydral Jun 16, 2025
f6c8526
bpe_tokenizer class refactoring.
Cydral Jun 16, 2025
2db56f5
Example updated.
Cydral Jun 16, 2025
d4eeb2d
bpe_tokenizer class updated.
Cydral Jun 16, 2025
dcb5963
Decoding part of the bpe_tokenizer updated.
Cydral Jun 17, 2025
b81b502
Network definition update
Cydral Jun 27, 2025
80a6e0e
Merge branch 'davisking:master' into master
Cydral Aug 25, 2025
d520c2a
Add Adaptive Computation Time (ACT) layer with CPU/CUDA support
Cydral Aug 27, 2025
b089d58
Fixes
Cydral Aug 29, 2025
1a904f2
Update comments for params
Sep 8, 2025
ab29fc4
Fixes and improvements
Cydral Sep 13, 2025
f16f743
Disabling enable_depth_scaling, which obviously affects the result of…
Cydral Sep 13, 2025
4d95752
Merge branch 'davisking:master' into master
Cydral Sep 29, 2025
9b3e7dc
Implementation of ARC-AGI dataset loader
Oct 6, 2025
f2ac8b9
Merge branch 'davisking:master' into master
Cydral Oct 15, 2025
97f9368
Transformer structures and models integration
Cydral Oct 24, 2025
18b7dfc
Transformer integration (update)
Oct 27, 2025
0368056
Transformer integration (update)
Oct 27, 2025
0c0fcfa
Transformer integration (update)
Oct 27, 2025
7311279
Transformer integration (update)
Oct 29, 2025
cd1d054
Transformer integration (update)
Oct 29, 2025
ba7a9f6
Transformer integration (update)
Oct 29, 2025
4c67a0b
Transformer integration (update)
Oct 29, 2025
a9fbbb5
Transformer integration (update)
Oct 29, 2025
9c9df71
Update
Oct 29, 2025
892f0d9
Update
Oct 30, 2025
a7e8a6e
Update
Oct 30, 2025
56d1e45
Update
Oct 30, 2025
eb21183
Update
Oct 30, 2025
5a80efc
Update
Oct 31, 2025
c1cad3c
Update
Oct 31, 2025
8a5aabf
Update
Oct 31, 2025
3369358
Update
Oct 31, 2025
2e05a7b
Update
Cydral Nov 1, 2025
c3b7b8d
Update
Cydral Nov 2, 2025
3722ff3
Update
Cydral Nov 2, 2025
3fe82c1
Update
Cydral Nov 2, 2025
5256d70
Update
Cydral Nov 2, 2025
7ab4224
Update
Cydral Nov 2, 2025
e41b061
Update
Cydral Nov 2, 2025
a80d267
Update
Cydral Nov 2, 2025
14a113b
Update
Cydral Nov 2, 2025
c0c23e8
Update
Cydral Nov 2, 2025
8b4db5f
Update
Cydral Nov 2, 2025
08323b5
Update
Cydral Nov 2, 2025
3b88e35
Update
Cydral Nov 2, 2025
2de5697
Update
Cydral Nov 2, 2025
a7a30ce
Update
Cydral Nov 2, 2025
66a0581
Update
Nov 3, 2025
c96097a
Update
Nov 3, 2025
9b2375f
Update
Nov 3, 2025
13c5313
Update
Cydral Nov 3, 2025
d724242
Update
Nov 4, 2025
e8f1cec
Update
Nov 4, 2025
9c09117
Update
Nov 4, 2025
f745721
Update
Nov 4, 2025
2901dde
Update
Nov 4, 2025
13e9ca8
Update
Nov 4, 2025
6b11c3b
Update
Nov 4, 2025
20b21f8
Update
Cydral Nov 3, 2025
eccfcd7
Update
Cydral Nov 5, 2025
dfaf59f
Update
Cydral Nov 5, 2025
40da84e
Update
Nov 6, 2025
90a3d69
Update
Nov 6, 2025
10a57cf
Update
Nov 6, 2025
cb97eec
Update
Cydral Nov 6, 2025
303d501
Update
Cydral Nov 6, 2025
e9fe817
Update
Cydral Nov 6, 2025
18be2cb
Update
Cydral Nov 6, 2025
e4e283d
Update
Cydral Nov 7, 2025
c7e5e1c
Update
Cydral Nov 7, 2025
33bf45c
Update
Cydral Nov 8, 2025
0fdcbdf
Update
Cydral Nov 8, 2025
4062cd6
Update
Cydral Nov 8, 2025
99b558e
Update
Cydral Nov 8, 2025
7ad9c2f
Update
Cydral Nov 8, 2025
3c1fb72
Update
Cydral Nov 8, 2025
f38d9b4
Update
Cydral Nov 8, 2025
9ce3e1e
Update
Cydral Nov 9, 2025
4d29b02
Update
Cydral Nov 9, 2025
0c8a411
Update
Nov 10, 2025
57cc916
Update
Nov 10, 2025
4c20ab9
Update
Cydral Nov 9, 2025
806c204
Update
Cydral Nov 10, 2025
9adea62
Update
Cydral Nov 10, 2025
b5f430c
Update
Cydral Nov 10, 2025
a1e0696
Update
Cydral Nov 10, 2025
f3732f7
Update
Cydral Nov 11, 2025
48edb1b
Update
Cydral Nov 11, 2025
a2bd297
Update
Cydral Nov 11, 2025
b617d15
Update
Cydral Nov 11, 2025
ad46857
Update
Cydral Nov 11, 2025
91897e8
Update
Cydral Nov 11, 2025
28547fb
Update
Cydral Nov 11, 2025
6c73ffc
Update
Cydral Nov 11, 2025
6f8167b
Update
Nov 12, 2025
22ea770
Update
Nov 12, 2025
5140fd1
Update
Nov 12, 2025
491f380
Update
Nov 12, 2025
1132007
Update
Nov 12, 2025
a66ec7b
Update
Nov 12, 2025
2ad0dc2
Update
Cydral Nov 12, 2025
6ea54e6
Update
Cydral Nov 12, 2025
1cdba1a
Update
Nov 13, 2025
d858103
Update
Nov 13, 2025
4735cb7
Update
Nov 13, 2025
22f28d1
Update
Nov 13, 2025
966c8dc
Update
Nov 13, 2025
1232fd7
Update
Nov 13, 2025
871a7fd
Update
Nov 13, 2025
33f5162
Update
Nov 14, 2025
4961282
Update
Cydral Nov 14, 2025
a6178f0
Update
Cydral Nov 15, 2025
e885ab8
Update
Cydral Nov 15, 2025
7dc8ce0
Update
Cydral Nov 15, 2025
65e24b2
Update
Cydral Nov 15, 2025
58436fc
Update
Cydral Nov 15, 2025
7e095db
Update
Cydral Nov 15, 2025
27fa8a2
Update
Cydral Nov 15, 2025
d6f46f0
Update
Cydral Nov 15, 2025
d7ae7de
Update
Cydral Nov 15, 2025
d13d8d8
Update
Cydral Nov 15, 2025
e29b2f1
Update
Cydral Nov 15, 2025
39c5f9b
Update
Cydral Nov 15, 2025
5db8473
Update
Cydral Nov 15, 2025
76370e7
Update
Cydral Nov 16, 2025
c837549
Update
Cydral Nov 17, 2025
38d803e
Update
Cydral Nov 17, 2025
1479e74
Update
Cydral Nov 17, 2025
b1a4112
Update
Cydral Nov 20, 2025
179db86
Update
Cydral Nov 20, 2025
9b416ae
Update
Cydral Nov 20, 2025
13c2b84
Update
Cydral Nov 20, 2025
b2e64cf
Update
Cydral Nov 20, 2025
26d0648
Update
Cydral Nov 20, 2025
01dcece
Update
Cydral Nov 20, 2025
f1a7085
Update
Cydral Nov 20, 2025
eae5a42
Update
Cydral Nov 20, 2025
40eab2b
Update
Cydral Nov 20, 2025
e81b369
Update
Cydral Nov 20, 2025
f08d678
Update
Cydral Nov 20, 2025
cc88afc
Update
Cydral Nov 20, 2025
2fe42df
Update
Cydral Nov 20, 2025
0c2d183
Update
Cydral Nov 21, 2025
272f41c
Update
Cydral Nov 21, 2025
b3dd372
Update
Cydral Nov 23, 2025
3267a01
Update
Cydral Nov 23, 2025
5bc8428
Update
Cydral Nov 23, 2025
b0f5aa3
Update
Cydral Nov 23, 2025
76a043e
Update
Cydral Nov 24, 2025
3ea781f
Update
Cydral Nov 25, 2025
6d90742
Update
Cydral Nov 25, 2025
d3e143e
Update
Cydral Nov 25, 2025
726af51
Update
Cydral Nov 25, 2025
8b3219c
Update
Cydral Nov 25, 2025
c105ccf
Update
Cydral Nov 25, 2025
caac93b
Update
Cydral Nov 25, 2025
504d551
Update
Cydral Nov 26, 2025
ee33e9f
Update
Cydral Nov 26, 2025
3e83339
Update
Cydral Nov 26, 2025
62cf36b
Update
Cydral Nov 27, 2025
e36645a
Update
Cydral Nov 27, 2025
04ab704
Update
Cydral Nov 27, 2025
3f19899
Update
Cydral Nov 28, 2025
70ec30a
Update
Cydral Nov 28, 2025
d48164a
Update
Cydral Nov 28, 2025
40dd868
Update
Cydral Dec 1, 2025
44830f7
Update
Cydral Dec 2, 2025
85eb2c9
Update
Cydral Dec 5, 2025
5a83f2c
Update
Cydral Dec 5, 2025
78aae5b
Update
Cydral Dec 5, 2025
7cce339
Update
Cydral Dec 6, 2025
e8c6950
Update
Cydral Dec 7, 2025
e496c7d
Update
Cydral Dec 7, 2025
6460e81
Update
Cydral Dec 7, 2025
c6f6979
Update
Cydral Dec 7, 2025
952513c
Update
Cydral Dec 7, 2025
5a00bd8
Update
Cydral Dec 7, 2025
b302fb8
Update
Cydral Dec 8, 2025
b9be752
Update
Cydral Dec 9, 2025
8dfbdef
Update
Cydral Dec 13, 2025
17f859b
Update
Cydral Dec 15, 2025
f66369b
Update
Cydral Dec 16, 2025
84bd433
Update
Cydral Dec 16, 2025
c82213d
Update
Cydral Dec 17, 2025
8d1f4ea
Update
Cydral Dec 17, 2025
def6359
Update
Cydral Dec 18, 2025
9e76ed5
Update
Cydral Dec 18, 2025
426857c
Update
Cydral Dec 18, 2025
3fe4ee1
Update
Cydral Dec 19, 2025
c4086bc
New example
Cydral Dec 19, 2025
1fc065d
New example added
Cydral Dec 19, 2025
7b2c4ef
Update
Cydral Dec 19, 2025
9c86229
Update
Cydral Dec 20, 2025
0c730f9
Update
Cydral Dec 20, 2025
f028608
Fix bug in cuda code for act layer
Cydral Dec 23, 2025
e2c229d
Embeddings class improvement
Cydral Dec 24, 2025
d74b2f9
Update
Cydral Dec 26, 2025
9396527
Update
Cydral Dec 26, 2025
114dab9
Update
Cydral Dec 28, 2025
306b1d4
Update
Cydral Dec 28, 2025
da591d3
Fix tril_padding_context multiple definition linker errors with Meyer…
Cydral Dec 29, 2025
b69d284
Add lr_mult_visitor for visit_layers_range
Cydral Dec 29, 2025
1a6494f
Removed used var in patch_embeddings_/backward
Cydral Dec 29, 2025
1ca58b1
Updated slm_mixture_of_experts_ex.cpp example
Cydral Dec 29, 2025
d7a4ebe
Updated slm_chatbot_ex.cpp example
Cydral Dec 29, 2025
e61623f
Update
Cydral Dec 29, 2025
e6bb9ed
Remove (old) french comment
Cydral Jan 2, 2026
b40ed81
Fix typo
Cydral Jan 2, 2026
0d15d7b
New static signal handler using
Cydral Jan 2, 2026
86dcfcf
Add "atomic" header
Cydral Jan 2, 2026
07a847b
Remove HRM example to stabilize first the new version of the adaptive…
Cydral Jan 3, 2026
84 changes: 77 additions & 7 deletions dlib/cuda/cpu_dlib.cpp
@@ -1494,7 +1494,6 @@ namespace dlib
}
p_scale[n] = 1.0f / std::sqrt(p_scale[n] / (ks * num) + static_cast<float>(eps));
}
-scale.host();
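(Editor's note: the removed scale.host() call appears to have been redundant; p_scale above already references the scale tensor's host buffer, so the extra call did no useful work.)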

// Apply RMS normalization
p_src = src.host();
@@ -1648,14 +1647,22 @@ namespace dlib
for (long k = 0; k < num_channels; ++k)
max_val = std::max(max_val, ss[k * num_locations]);

-float sum = 0.0f;
-for (long k = 0; k < num_channels; ++k)
-{
-    dd[k * num_locations] = std::exp(ss[k * num_locations] - max_val);
-    sum += dd[k * num_locations];
-}
-for (long k = 0; k < num_channels; ++k)
-    dd[k * num_locations] /= sum;
+if (max_val == -std::numeric_limits<float>::infinity())
+{
+    for (long k = 0; k < num_channels; ++k)
+        dd[k * num_locations] = 0.0f;
+}
+else
+{
+    float sum = 0.0f;
+    for (long k = 0; k < num_channels; ++k)
+    {
+        dd[k * num_locations] = std::exp(ss[k * num_locations] - max_val);
+        sum += dd[k * num_locations];
+    }
+    for (long k = 0; k < num_channels; ++k)
+        dd[k * num_locations] /= sum;
+}

++ss;
++dd;
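Editor's note on the change above: when an entire channel column is masked to -inf (as causal attention masks do for padded or future positions), max_val is itself -inf and the old code evaluated std::exp(-inf - (-inf)), which is NaN; the NaN then poisons the sum and every output. The new branch instead defines the softmax of a fully masked column as all zeros. A minimal standalone repro of the failure mode, assuming nothing from dlib:

#include <cmath>
#include <cstdio>
#include <limits>

int main()
{
    const float ninf = -std::numeric_limits<float>::infinity();
    const float logit   = ninf;  // every entry of a fully masked column
    const float max_val = ninf;  // so the column maximum is also -inf
    const float e = std::exp(logit - max_val);  // -inf - (-inf) == NaN
    std::printf("%f\n", e);      // prints nan; sum and quotient follow suit
    return 0;
}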
@@ -3366,6 +3373,69 @@ namespace dlib
}
}

// ------------------------------------------------------------------------------------

void apply_rotary_positional_embedding(
bool is_backward,
resizable_tensor& data,
const resizable_tensor& cos_cache,
const resizable_tensor& sin_cache)
{
const long batch_size = data.num_samples();
const long num_heads = data.k();
const long seq_len = data.nr();
const long d_head = data.nc();
const long half_d = d_head / 2;

DLIB_CASSERT(cos_cache.nr() == seq_len, "cos_cache rows must match seq_len");
DLIB_CASSERT(cos_cache.nc() == half_d, "cos_cache cols must be d_head/2");
DLIB_CASSERT(sin_cache.nr() == seq_len, "sin_cache rows must match seq_len");
DLIB_CASSERT(sin_cache.nc() == half_d, "sin_cache cols must be d_head/2");

const bool is_odd = (d_head % 2 != 0);
const long rot_dim = is_odd ? d_head - 1 : d_head;

float* data_ptr = data.host();
const float* cos_ptr = cos_cache.host();
const float* sin_ptr = sin_cache.host();

const size_t total_elements = batch_size * num_heads * seq_len * half_d;

parallel_for(0, total_elements, [&](long idx)
{
const long pair_idx = idx % half_d;
const long pos = (idx / half_d) % seq_len;
const long head = (idx / (half_d * seq_len)) % num_heads;
const long batch = idx / (half_d * seq_len * num_heads);

const long dim_i = pair_idx * 2;
if (dim_i >= rot_dim) return;

const long data_offset = ((batch * num_heads + head) * seq_len + pos) * d_head + dim_i;
const long trig_offset = pos * half_d + pair_idx;

const float c = cos_ptr[trig_offset];
const float s = sin_ptr[trig_offset];
const float x0 = data_ptr[data_offset];
const float x1 = data_ptr[data_offset + 1];

if (!is_backward)
{
// Forward: [cos -sin] [x0]
// [sin cos] [x1]
data_ptr[data_offset] = x0 * c - x1 * s;
data_ptr[data_offset + 1] = x0 * s + x1 * c;
}
else
{
// Backward (inverse rotation): [cos sin] [x0]
// [-sin cos] [x1]
data_ptr[data_offset] = x0 * c + x1 * s;
data_ptr[data_offset + 1] = -x0 * s + x1 * c;
}
});
}

// ------------------------------------------------------------------------------------

}
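Editor's note: a sketch of how the cos/sin caches this routine expects could be built. The (seq_len x d_head/2) shapes follow the DLIB_CASSERTs above, but the 10000^(-2i/d_head) frequency schedule is the standard RoPE choice and an assumption here; the PR's actual cache builder is not shown in this excerpt.

// Hypothetical cache builder -- an assumption, not code from this PR.
#include <cmath>
#include <dlib/dnn.h>

void build_rope_caches(long seq_len, long d_head,
                       dlib::resizable_tensor& cos_cache,
                       dlib::resizable_tensor& sin_cache)
{
    const long half_d = d_head / 2;
    cos_cache.set_size(1, 1, seq_len, half_d);  // matches the asserts above
    sin_cache.set_size(1, 1, seq_len, half_d);
    float* c = cos_cache.host();
    float* s = sin_cache.host();
    for (long pos = 0; pos < seq_len; ++pos)
    {
        for (long i = 0; i < half_d; ++i)
        {
            // Standard RoPE frequencies: theta_i = 10000^(-2i/d_head).
            const double angle = pos * std::pow(10000.0, -2.0 * i / d_head);
            c[pos * half_d + i] = static_cast<float>(std::cos(angle));
            s[pos * half_d + i] = static_cast<float>(std::sin(angle));
        }
    }
}

// Usage: rotate tensors of shape (batch, heads, seq_len, d_head) in place;
// pass true on the backward pass to apply the inverse rotation:
//   dlib::cpu::apply_rotary_positional_embedding(false, queries, cos_cache, sin_cache);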
141 changes: 141 additions & 0 deletions dlib/cuda/cpu_dlib.h
@@ -584,6 +584,15 @@ namespace dlib
float scale_factor
);

// -----------------------------------------------------------------------------------

void apply_rotary_positional_embedding(
bool is_backward,
resizable_tensor& data,
const resizable_tensor& cos_cache,
const resizable_tensor& sin_cache
);

// -----------------------------------------------------------------------------------

class pooling
@@ -761,6 +770,138 @@ namespace dlib

// -----------------------------------------------------------------------------------

class compute_loss_cross_entropy_per_logit
{
/*!
Computes the cross-entropy loss for causal language modeling.
Every sequence position contributes to the loss: each position t
predicts the token at position t+1, so positions 0..seq_len-2 take
their targets from the input itself and the final position takes
its target from the truth label.
!*/
public:
compute_loss_cross_entropy_per_logit() {}

template <typename const_label_iterator>
void operator()(
const_label_iterator truth,
const tensor& input_tensor,
const tensor& output_tensor,
tensor& grad,
double& loss,
long ignore_index
) const
{
DLIB_CASSERT(output_tensor.k() == 1);
DLIB_CASSERT(input_tensor.k() == 1);
DLIB_CASSERT(input_tensor.nc() == 1);

const long batch_size = output_tensor.num_samples();
const long seq_len = output_tensor.nr();
const long vocab_size = output_tensor.nc();

const float* out_data = output_tensor.host();
const float* in_data = input_tensor.host();
float* g = grad.host();

std::fill(g, g + grad.size(), 0.0f);

long valid_tokens = 0;

if (ignore_index < 0)
{
valid_tokens = batch_size * seq_len;
}
else {
for (long i = 0; i < batch_size; ++i)
{
for (long t = 0; t < seq_len; ++t)
{
unsigned long target_class;
if (t < seq_len - 1) {
target_class = static_cast<unsigned long>(
in_data[tensor_index(input_tensor, i, 0, t + 1, 0)]
);
}
else
target_class = *(truth + i);

if (static_cast<long>(target_class) != ignore_index)
valid_tokens++;
}
}
}
if (valid_tokens == 0)
{
loss = 0.0;
return;
}

const double scale = 1.0 / valid_tokens;
loss = 0.0;

for (long i = 0; i < batch_size; ++i)
{
// Loop over all positions (0 to seq_len-1)
for (long t = 0; t < seq_len; ++t)
{
unsigned long target_class;

// Extract target token
if (t < seq_len - 1) {
// For positions 0 to seq_len-2: target from input_tensor[t+1]
target_class = static_cast<unsigned long>(
in_data[tensor_index(input_tensor, i, 0, t + 1, 0)]
);
} else {
// For last position (seq_len-1): target from truth
target_class = *(truth + i);
}

if (ignore_index >= 0 && static_cast<long>(target_class) == ignore_index)
continue;

DLIB_CASSERT(target_class < static_cast<unsigned long>(vocab_size));

// Find max logit for numerical stability
float max_val = out_data[tensor_index(output_tensor, i, 0, t, 0)];
for (long c = 1; c < vocab_size; ++c)
{
const float val = out_data[tensor_index(output_tensor, i, 0, t, c)];
max_val = std::max(max_val, val);
}

// Compute softmax denominator
float sum_exp = 0.0f;
for (long c = 0; c < vocab_size; ++c)
{
const unsigned long idx = tensor_index(output_tensor, i, 0, t, c);
const float exp_val = std::exp(out_data[idx] - max_val);
g[idx] = exp_val;
sum_exp += exp_val;
}

// Compute loss and gradients
for (long c = 0; c < vocab_size; ++c)
{
const unsigned long idx = tensor_index(output_tensor, i, 0, t, c);
const float softmax_val = g[idx] / sum_exp;

if (static_cast<unsigned long>(c) == target_class)
{
loss += scale * (-std::log(std::max(softmax_val, 1e-10f)));
g[idx] = scale * (softmax_val - 1.0f);
}
else
{
g[idx] = scale * softmax_val;
}
}
}
}
}
};
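Editor's summary of the functor above, in equations: with logits z_{t,c} at position t, shifted targets y_t (the input token at t+1 for t < T-1, the supplied truth label at t = T-1), and N the number of targets different from ignore_index,

\[
p_{t,c} = \frac{e^{z_{t,c} - m_t}}{\sum_{c'} e^{z_{t,c'} - m_t}},
\qquad m_t = \max_{c'} z_{t,c'},
\]
\[
\text{loss} = -\frac{1}{N} \sum_{t\,:\,y_t \ne \text{ignore}} \log p_{t,y_t},
\qquad
\frac{\partial\,\text{loss}}{\partial z_{t,c}} = \frac{1}{N}\left(p_{t,c} - \mathbf{1}[c = y_t]\right),
\]

which matches the scale = 1/N factor and the two g[idx] assignments in the gradient loop.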

// -----------------------------------------------------------------------------------

class compute_loss_binary_log_per_pixel
{

43 changes: 26 additions & 17 deletions dlib/cuda/cublas_dlibapi.cpp
@@ -159,40 +159,43 @@ namespace dlib
const auto transa = trans_lhs ? CUBLAS_OP_T : CUBLAS_OP_N;
const auto transb = trans_rhs ? CUBLAS_OP_T : CUBLAS_OP_N;

-long num_samples = std::min({ lhs.num_samples(), rhs.num_samples(), dest.num_samples() });
-long num_channels = std::min({ lhs.k(), rhs.k(), dest.k() });
-
-auto is_matrix = [](const auto& tensor) {
-    return ((tensor.num_samples() * tensor.k() == 1 && tensor.nr() * tensor.nc() > 1) ||
-            (tensor.num_samples() * tensor.k() > 1 && tensor.nr() * tensor.nc() == 1));
-};
-const bool lhs_is_matrix = is_matrix(lhs), rhs_is_matrix = is_matrix(rhs), dest_is_matrix = is_matrix(dest);
-
-if (lhs_is_matrix && rhs_is_matrix && dest_is_matrix) num_samples = num_channels = 1;
-
-const size_t lhs_plane_size = lhs.nr() * lhs.nc();
-const size_t rhs_plane_size = rhs.nr() * rhs.nc();
-const size_t dest_plane_size = dest.nr() * dest.nc();
+const bool lhs_is_matrix = is_2d_matrix(lhs);
+const bool rhs_is_matrix = is_2d_matrix(rhs);
+const bool dest_is_matrix = is_2d_matrix(dest);
+
+long num_samples, num_channels = std::min({ lhs.k(), rhs.k(), dest.k() });
+if (lhs_is_matrix && rhs_is_matrix && dest_is_matrix)
+    num_samples = 1;
+else if (!lhs_is_matrix && rhs_is_matrix)
+    num_samples = lhs.num_samples();
+else
+    num_samples = std::min({ lhs.num_samples(), rhs.num_samples(), dest.num_samples() });

size_t lhs_rows = lhs.nr();
size_t lhs_cols = lhs.nc();
if (lhs_is_matrix && (lhs.num_samples() > 1 || lhs.k() > 1)) {
lhs_rows = lhs.num_samples();
lhs_cols = lhs.k();
}

size_t rhs_rows = rhs.nr();
size_t rhs_cols = rhs.nc();
if (rhs_is_matrix && (rhs.num_samples() > 1 || rhs.k() > 1)) {
rhs_rows = rhs.num_samples();
rhs_cols = rhs.k();
}

size_t dest_rows = dest.nr();
size_t dest_cols = dest.nc();
if (dest_is_matrix && (dest.num_samples() > 1 || dest.k() > 1)) {
dest_rows = dest.num_samples();
dest_cols = dest.k();
}

+const size_t lhs_plane_size = lhs_rows * lhs_cols;
+const size_t rhs_plane_size = rhs_rows * rhs_cols;
+const size_t dest_plane_size = dest_rows * dest_cols;

for (long b = 0; b < num_samples; ++b)
{
for (long c = 0; c < num_channels; ++c)
@@ -203,12 +206,18 @@ namespace dlib
rhs.device() + (b * num_channels + c) * rhs_plane_size;
auto dest_slice = dest_is_matrix ? dest.device() :
dest.device() + (b * num_channels + c) * dest_plane_size;

const int k = trans_rhs ? rhs_cols : rhs_rows;

CHECK_CUBLAS(cublasSgemm(
-    context(), transb, transa, dest_cols, dest_rows, k,
-    &alpha, rhs_slice, rhs_cols, lhs_slice, lhs_cols,
-    &beta, dest_slice, dest_cols
+    context(),
+    transb, transa,
+    dest_cols, dest_rows, k,
+    &alpha,
+    rhs_slice, rhs_cols,
+    lhs_slice, lhs_cols,
+    &beta,
+    dest_slice, dest_cols
));
}
}
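Editor's note on the call above: cuBLAS is column-major while dlib tensors are row-major, and a row-major matrix reinterpreted as column-major is its transpose. To produce the row-major product dest = lhs * rhs, the call therefore computes

\[
\text{dest}^{\top} = \text{rhs}^{\top} \cdot \text{lhs}^{\top},
\]

which is why rhs_slice is passed before lhs_slice, the transpose flags are given as (transb, transa), the m/n arguments are (dest_cols, dest_rows), and each leading dimension is that matrix's row-major column count.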