2 changes: 1 addition & 1 deletion CMakeLists.txt
@@ -64,7 +64,7 @@ set_ifndef(CUDA_DIR /usr/local/cuda-${CUDA_CTK_VERSION})
if(NOT DEFINED AARCH64_BUILD)
set(CMAKE_CUDA_ARCHITECTURES 80;86;89)
if(CUDA_CTK_VERSION VERSION_GREATER_EQUAL 12.8)
-        list(APPEND CMAKE_CUDA_ARCHITECTURES 100 120)
+        list(APPEND CMAKE_CUDA_ARCHITECTURES 100 110 120)
endif()
endif()

2 changes: 1 addition & 1 deletion cpp/CMakeLists.txt
@@ -98,7 +98,7 @@ add_library(
${RUNTIME_CPP_SRCS}
${PROFILER_CPP_SRCS})
target_include_directories(edgellmCore PRIVATE ${COMMON_INCLUDE_DIRS})
-target_link_libraries(edgellmCore PRIVATE dl)
+target_link_libraries(edgellmCore PRIVATE dl curand)
# Apply FMHA SM exclusion definitions
target_compile_definitions(edgellmCore PRIVATE ${FMHA_EXCLUDE_DEFINITIONS})
# Enable separable compilation for MoE Marlin kernel templates
2 changes: 1 addition & 1 deletion cpp/common/inputLimits.h
@@ -42,7 +42,7 @@ constexpr int kReasonableMaxBatchSize = 16;
// Validation limits for message parsing.
constexpr size_t kMaxMessageContentSizeBytes = 128 * 1024; // 128KB per content item
constexpr size_t kMaxMessagesPerRequest = 64;
-constexpr size_t kMaxContentItemsPerMessage = 16;
+constexpr size_t kMaxContentItemsPerMessage = 64; // Alpamayo uses 4 cameras x 4 frames = 16 image items, plus text items; 64 leaves headroom

} // namespace security

118 changes: 118 additions & 0 deletions cpp/kernels/alpamayoExpertKernels/alpamayoExpertKernels.cu
@@ -0,0 +1,118 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

/*
* CUDA kernels for AlpamayoExpertRunner.
*/

#include "alpamayoExpertKernels.h"

namespace trt_edgellm
{
namespace kernel
{

namespace
{

__global__ void kvCacheReshapeRepeatKernel(float* __restrict__ dst, half const* __restrict__ src, int32_t numLayers,
int32_t numKVHeads, int32_t seqLen, int32_t headDim, int32_t numCandidates, int64_t totalElements)
{
int64_t const idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
if (idx >= totalElements)
return;

// Decompose: dst[layer2, cand, head, s, d]
int64_t rem = idx;
int32_t const d = rem % headDim;
rem /= headDim;
int32_t const s = rem % seqLen;
rem /= seqLen;
int32_t const head = rem % numKVHeads;
rem /= numKVHeads;
rem /= numCandidates; // skip cand dimension (broadcast)
int32_t const layer2 = rem;

int32_t const layer = layer2 / 2;
int32_t const kv = layer2 % 2;

// src layout: [layer, kv, head, s, d] contiguous
int64_t const hsd = static_cast<int64_t>(numKVHeads) * seqLen * headDim;
int64_t const srcIdx = (static_cast<int64_t>(layer) * 2 + kv) * hsd + static_cast<int64_t>(head) * seqLen * headDim
+ static_cast<int64_t>(s) * headDim + d;

dst[idx] = __half2float(src[srcIdx]);
}

__global__ void buildPositionIdsKernel(int64_t* __restrict__ posIds, int32_t total, int32_t numTokens, int64_t basePos)
{
int32_t const idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= total)
return;
posIds[idx] = basePos + (idx % numTokens);
}

__global__ void fillTimestepKernel(float* __restrict__ dst, int32_t n, float val)
{
int32_t const idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < n)
dst[idx] = val;
}

__global__ void eulerUpdateKernel(float* __restrict__ x, float const* __restrict__ v, float dt, int32_t n)
{
int32_t const idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < n)
x[idx] += v[idx] * dt;
}

} // anonymous namespace

void kvCacheReshapeRepeat(float* dst, half const* src, int32_t numLayers, int32_t numKVHeads, int32_t seqLen,
int32_t headDim, int32_t numCandidates, cudaStream_t stream)
{
int64_t const totalElements = static_cast<int64_t>(numLayers) * 2 * numCandidates * numKVHeads * seqLen * headDim;
int32_t const blockSize = 256;
int32_t const numBlocks = static_cast<int32_t>((totalElements + blockSize - 1) / blockSize);
kvCacheReshapeRepeatKernel<<<numBlocks, blockSize, 0, stream>>>(
dst, src, numLayers, numKVHeads, seqLen, headDim, numCandidates, totalElements);
}

void buildPositionIds(int64_t* posIds, int32_t numCandidates, int32_t numTokens, int64_t basePos, cudaStream_t stream)
{
int32_t const total = 3 * numCandidates * numTokens;
int32_t const blockSize = 256;
int32_t const numBlocks = (total + blockSize - 1) / blockSize;
buildPositionIdsKernel<<<numBlocks, blockSize, 0, stream>>>(posIds, total, numTokens, basePos);
}

void fillTimestep(float* dst, int32_t numCandidates, float tVal, cudaStream_t stream)
{
int32_t const blockSize = 256;
int32_t const numBlocks = (numCandidates + blockSize - 1) / blockSize;
fillTimestepKernel<<<numBlocks, blockSize, 0, stream>>>(dst, numCandidates, tVal);
}

void eulerUpdate(float* x, float const* v, float dt, int32_t n, cudaStream_t stream)
{
int32_t const blockSize = 256;
int32_t const numBlocks = (n + blockSize - 1) / blockSize;
eulerUpdateKernel<<<numBlocks, blockSize, 0, stream>>>(x, v, dt, n);
}

} // namespace kernel
} // namespace trt_edgellm
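
For review, a minimal host-side sketch of the same index mapping (layer2 = layer * 2 + kv, with the candidate dimension broadcast from the source) may be easier to verify against. cpuKvCacheReshapeRepeat is a hypothetical helper, not part of this PR, and uses plain float on both sides for brevity where the kernel reads half and writes float:

#include <cstdint>
#include <vector>

// Host-side reference of kvCacheReshapeRepeatKernel's index mapping.
// src is [numLayers, 2, numKVHeads, seqLen, headDim];
// dst is [numLayers * 2, numCandidates, numKVHeads, seqLen, headDim].
void cpuKvCacheReshapeRepeat(std::vector<float>& dst, std::vector<float> const& src, int32_t numLayers,
    int32_t numKVHeads, int32_t seqLen, int32_t headDim, int32_t numCandidates)
{
    int64_t const hsd = static_cast<int64_t>(numKVHeads) * seqLen * headDim;
    for (int32_t layer2 = 0; layer2 < numLayers * 2; ++layer2)
    {
        for (int32_t cand = 0; cand < numCandidates; ++cand)
        {
            // Every candidate receives an identical copy of the source slice (broadcast).
            int64_t const srcBase = static_cast<int64_t>(layer2) * hsd;
            int64_t const dstBase = (static_cast<int64_t>(layer2) * numCandidates + cand) * hsd;
            for (int64_t i = 0; i < hsd; ++i)
            {
                dst[dstBase + i] = src[srcBase + i];
            }
        }
    }
}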
43 changes: 43 additions & 0 deletions cpp/kernels/alpamayoExpertKernels/alpamayoExpertKernels.h
@@ -0,0 +1,43 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

/*
* CUDA kernels for AlpamayoExpertRunner.
*/

#pragma once

#include <cstdint>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

namespace trt_edgellm
{
namespace kernel
{

void kvCacheReshapeRepeat(float* dst, half const* src, int32_t numLayers, int32_t numKVHeads, int32_t seqLen,
int32_t headDim, int32_t numCandidates, cudaStream_t stream);

void buildPositionIds(int64_t* posIds, int32_t numCandidates, int32_t numTokens, int64_t basePos, cudaStream_t stream);

void fillTimestep(float* dst, int32_t numCandidates, float tVal, cudaStream_t stream);

void eulerUpdate(float* x, float const* v, float dt, int32_t n, cudaStream_t stream);

} // namespace kernel
} // namespace trt_edgellm
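
To illustrate how these primitives might compose, below is a sketch of a fixed-step Euler ODE-integration loop of the kind these signatures suggest. The loop shape, eulerIntegrate, and runVelocityModel are assumptions for illustration only; the real call sites live in AlpamayoExpertRunner, which is not part of this diff:

#include <cstdint>
#include <cuda_runtime.h>

#include "alpamayoExpertKernels.h"

// Placeholder for whatever produces the velocity field v(x, t); not part of this PR.
void runVelocityModel(float* velocity, float const* state, float const* timestep, cudaStream_t stream);

// Hypothetical fixed-step Euler integration loop composed from the declared kernels.
void eulerIntegrate(float* stateDev, float* velocityDev, float* timestepDev, int32_t numCandidates, int32_t n,
    int32_t numSteps, cudaStream_t stream)
{
    float const dt = 1.0f / static_cast<float>(numSteps);
    for (int32_t step = 0; step < numSteps; ++step)
    {
        // Broadcast the current timestep, one slot per candidate.
        trt_edgellm::kernel::fillTimestep(timestepDev, numCandidates, static_cast<float>(step) * dt, stream);
        runVelocityModel(velocityDev, stateDev, timestepDev, stream);
        // x <- x + v * dt over all n state elements.
        trt_edgellm::kernel::eulerUpdate(stateDev, velocityDev, dt, n, stream);
    }
}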
10 changes: 10 additions & 0 deletions cpp/multimodal/multimodalRunner.cpp
@@ -133,5 +133,15 @@ bool MultimodalRunner::preprocessSystemPrompt([[maybe_unused]] std::string const
return true;
}

bool MultimodalRunner::preprocessPreparedVisual([[maybe_unused]] rt::LLMGenerationRequest const& request,
[[maybe_unused]] std::vector<std::vector<int32_t>>& batchedInputIds,
[[maybe_unused]] tokenizer::Tokenizer const* tokenizer, [[maybe_unused]] rt::Tensor& ropeRotaryCosSinDevice,
[[maybe_unused]] rt::Tensor const& pixelValues, [[maybe_unused]] rt::Tensor const& imageGridTHW,
[[maybe_unused]] cudaStream_t stream)
{
LOG_ERROR("preprocessPreparedVisual is not implemented for this multimodal runner");
return false;
}

} // namespace rt
} // namespace trt_edgellm
20 changes: 20 additions & 0 deletions cpp/multimodal/multimodalRunner.h
@@ -90,6 +90,26 @@ class MultimodalRunner
tokenizer::Tokenizer const* tokenizer, [[maybe_unused]] rt::Tensor& ropeRotaryCosSinDevice, cudaStream_t stream)
= 0;

/*!
* @brief Preprocess using already prepared visual inputs from an external pipeline.
*
* This bypasses image decoding / resizing / patchification in the TRT runtime and instead
* consumes tensors equivalent to processor outputs such as `pixel_values` and `image_grid_thw`.
*
* @param request Generation request with text/messages
* @param batchedInputIds Output batched input token IDs
* @param tokenizer Tokenizer instance
* @param ropeRotaryCosSinDevice RoPE cache tensor
* @param pixelValues Preprocessed visual input tensor [num_patches, input_dim]
* @param imageGridTHW Image grid tensor [num_images, 3]
* @param stream CUDA stream
* @return True on success, false on failure
*/
virtual bool preprocessPreparedVisual(rt::LLMGenerationRequest const& request,
std::vector<std::vector<int32_t>>& batchedInputIds, tokenizer::Tokenizer const* tokenizer,
[[maybe_unused]] rt::Tensor& ropeRotaryCosSinDevice, [[maybe_unused]] rt::Tensor const& pixelValues,
[[maybe_unused]] rt::Tensor const& imageGridTHW, cudaStream_t stream);

/*!
* @brief Used for KVCache saving where we need to conduct the tokenization of the system prompt and generate
* ND-Rope parameters for the system prompt.
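
A note on the contract implied by these parameters: the rows of `pixel_values` and the entries of `image_grid_thw` must be mutually consistent, since each image contributes t * h * w patch rows. A minimal sketch of the check a caller could run, with expectedNumPatches as an assumed illustrative helper, not part of this PR:

#include <cstdint>
#include <vector>

// Each image contributes t * h * w rows to pixel_values, so the row count of
// pixelValues must equal the sum of t * h * w over all rows of imageGridTHW.
int64_t expectedNumPatches(std::vector<int64_t> const& gridTHWHost) // flattened [num_images, 3]
{
    int64_t total = 0;
    for (size_t i = 0; i + 2 < gridTHWHost.size(); i += 3)
    {
        total += gridTHWHost[i] * gridTHWHost[i + 1] * gridTHWHost[i + 2];
    }
    return total; // compare against pixelValues.getShape()[0]
}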
130 changes: 130 additions & 0 deletions cpp/multimodal/qwenViTRunner.cpp
@@ -768,6 +768,136 @@ bool QwenViTRunner::preprocess(rt::LLMGenerationRequest const& request,
return true;
}

bool QwenViTRunner::preprocessPreparedVisual(rt::LLMGenerationRequest const& request,
std::vector<std::vector<int32_t>>& batchedInputIds, tokenizer::Tokenizer const* tokenizer,
rt::Tensor& ropeRotaryCosSinDevice, rt::Tensor const& pixelValues, rt::Tensor const& imageGridTHW,
cudaStream_t stream)
{
try
{
check::check(pixelValues.getShape().getNumDims() == 2, "pixelValues must be 2D [num_patches, input_dim]");
check::check(pixelValues.getDataType() == nvinfer1::DataType::kHALF, "pixelValues must be FP16");
check::check(
imageGridTHW.getShape().getNumDims() == 2 && imageGridTHW.getShape()[1] == 3, "imageGridTHW must be [N, 3]");

int64_t const numImages = imageGridTHW.getShape()[0];
std::vector<int64_t> imageGridTHWHostVec(numImages * 3);
CUDA_CHECK(cudaMemcpyAsync(imageGridTHWHostVec.data(), imageGridTHW.rawPointer(),
sizeof(int64_t) * numImages * 3, cudaMemcpyDeviceToHost, stream));
CUDA_CHECK(cudaStreamSynchronize(stream));

std::vector<std::vector<int64_t>> imageGridTHWs;
imageGridTHWs.reserve(numImages);
std::vector<int64_t> imageTokenLengths;
imageTokenLengths.reserve(numImages);
std::vector<int64_t> numImagesPerRequest;
numImagesPerRequest.reserve(request.requests.size());

for (int64_t i = 0; i < numImages; ++i)
{
int64_t t = imageGridTHWHostVec[i * 3 + 0];
int64_t h = imageGridTHWHostVec[i * 3 + 1];
int64_t w = imageGridTHWHostVec[i * 3 + 2];
imageGridTHWs.push_back({t, h, w});
imageTokenLengths.push_back(t * h * w / (mConfig.mergeSize * mConfig.mergeSize));
}

// Count image items per request from message contents.
// For prepared_visual_input, the request may intentionally omit raw image content and
// provide a text-only prompt plus externally prepared pixel_values/image_grid_thw.
// In that case, infer the per-request image count from imageGridTHW for the single
// request we currently support.
int64_t totalImageCount = 0;
for (auto const& req : request.requests)
{
int64_t requestImageCount = 0;
for (auto const& msg : req.messages)
{
for (auto const& content : msg.contents)
{
if (content.type == "image")
{
requestImageCount++;
}
}
}
numImagesPerRequest.push_back(requestImageCount);
totalImageCount += requestImageCount;
}
if (totalImageCount == 0)
{
check::check(request.requests.size() == 1,
"prepared_visual_input without raw image messages currently supports a single request only");
numImagesPerRequest[0] = numImages;
totalImageCount = numImages;
}
check::check(totalImageCount == numImages, "imageGridTHW image count does not match request image count");

// Populate mVitInput directly from the prepared processor output.
int64_t const totalSeqLength = pixelValues.getShape()[0];
check::check(pixelValues.getShape()[1] == mConfig.inputDim, "pixelValues input dim mismatch");
check::check(mVitInput.reshape({totalSeqLength, mConfig.inputDim}), "Tensor reshape failed");
CUDA_CHECK(cudaMemcpyAsync(mVitInput.rawPointer(), pixelValues.rawPointer(),
totalSeqLength * mConfig.inputDim * sizeof(half), cudaMemcpyDeviceToDevice, stream));

// Build cu_seqlens and auxiliary tensors exactly as the tail of imagePreprocess does.
int32_t* cuSeqlensData = mCuSeqlensHost.dataPointer<int32_t>();
cuSeqlensData[0] = 0;
int64_t cuSeqlensSize = 1;
int64_t maxSeqLen = 0;
for (auto const& grid : imageGridTHWs)
{
int64_t curSeqLength = grid[0] * grid[1] * grid[2];
int32_t prevCuSeqlen = cuSeqlensData[cuSeqlensSize - 1];
cuSeqlensData[cuSeqlensSize++] = static_cast<int32_t>(prevCuSeqlen + curSeqLength);
maxSeqLen = std::max(maxSeqLen, curSeqLength);
}

int64_t const totalImageTokens = totalSeqLength / (mConfig.mergeSize * mConfig.mergeSize);
check::check(mOutputEmbedding.reshape({totalImageTokens, mConfig.outHiddenSize}), "Tensor reshape failed");
check::check(mMaxSeqLenCarrier.reshape({maxSeqLen}), "Tensor reshape failed");

check::check(mCuSeqlens.reshape({cuSeqlensSize}), "Tensor reshape failed");
CUDA_CHECK(cudaMemcpyAsync(mCuSeqlens.rawPointer(), mCuSeqlensHost.rawPointer(),
cuSeqlensSize * sizeof(int32_t), cudaMemcpyHostToDevice, stream));

check::check(mRotaryPosEmb.reshape({totalSeqLength, mConfig.vitPosEmbDim}), "Tensor reshape failed");
for (size_t i = 0; i < imageGridTHWs.size(); ++i)
{
kernel::initRotaryPosEmbQwenViT(
mRotaryPosEmb, imageGridTHWs[i], mConfig.mergeSize, cuSeqlensData[i], 10000.0f, 1.0f, stream);
}

if (mModelType == multimodal::ModelType::QWEN3_VL
|| mModelType == multimodal::ModelType::QWEN3_OMNI_VISION_ENCODER)
{
check::check(mFastPosEmbIdx.reshape({4, totalSeqLength}), "Tensor reshape failed");
check::check(mFastPosEmbWeight.reshape({4, totalSeqLength}), "Tensor reshape failed");
for (size_t i = 0; i < imageGridTHWs.size(); ++i)
{
kernel::initFastPosEmbedQwenViT(mFastPosEmbIdx, mFastPosEmbWeight, imageGridTHWs[i], mConfig.mergeSize,
mConfig.numGridPerSide, cuSeqlensData[i], stream);
}
for (int64_t i = 0; i < mConfig.numDeepstackFeatures; ++i)
{
check::check(
mDeepstackFeatures[i].reshape({totalImageTokens, mConfig.outHiddenSize}), "Tensor reshape failed");
}
}
mLastImageGridTHWs = imageGridTHWs;

textPreprocess(request, batchedInputIds, numImagesPerRequest, imageTokenLengths, tokenizer);
generateMropeParams(batchedInputIds, imageGridTHWs, ropeRotaryCosSinDevice, stream);
}
catch (std::exception const& e)
{
LOG_ERROR("QwenViTRunner::preprocessPreparedVisual() failed: %s", e.what());
return false;
}

return true;
}

bool QwenViTRunner::preprocessSystemPrompt(std::string const& systemPrompt, tokenizer::Tokenizer const* tokenizer,
rt::Tensor& ropeRotaryCosSinDevice, cudaStream_t stream)
{
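
To make the patch/token bookkeeping in preprocessPreparedVisual concrete, a small worked example; the grid values are chosen for illustration and are not taken from the PR:

#include <cstdint>

// Worked example of the bookkeeping above: one image whose grid_thw is
// (1, 34, 46) under mergeSize = 2.
constexpr int64_t t = 1, h = 34, w = 46;
constexpr int64_t mergeSize = 2;
constexpr int64_t patches = t * h * w;                             // 1564 rows of pixelValues
constexpr int64_t imageTokens = patches / (mergeSize * mergeSize); // 391 image tokens in the prompt
static_assert(patches == 1564 && imageTokens == 391, "arithmetic check");
// cu_seqlens for this single image would be {0, 1564}; a second identical
// image would extend it to {0, 1564, 3128}.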
5 changes: 5 additions & 0 deletions cpp/multimodal/qwenViTRunner.h
@@ -84,6 +84,11 @@ class QwenViTRunner : public MultimodalRunner
bool preprocess(rt::LLMGenerationRequest const& request, std::vector<std::vector<int32_t>>& batchedInputIds,
tokenizer::Tokenizer const* tokenizer, rt::Tensor& ropeRotaryCosSinDevice, cudaStream_t stream) override;

bool preprocessPreparedVisual(rt::LLMGenerationRequest const& request,
std::vector<std::vector<int32_t>>& batchedInputIds, tokenizer::Tokenizer const* tokenizer,
rt::Tensor& ropeRotaryCosSinDevice, rt::Tensor const& pixelValues, rt::Tensor const& imageGridTHW,
cudaStream_t stream) override;

//! \brief Encode the system prompt and generate ND-RoPE parameters for the system prompt for KVCache saving.
//! \param[in] systemPrompt System prompt string
//! \param[in] tokenizer Tokenizer for text processing