diff --git a/xllm/core/framework/model/CMakeLists.txt b/xllm/core/framework/model/CMakeLists.txt index dd19b5baa..947844371 100644 --- a/xllm/core/framework/model/CMakeLists.txt +++ b/xllm/core/framework/model/CMakeLists.txt @@ -29,6 +29,7 @@ cc_library( dit_model.h embedding_lm.h embedding_vlm.h + mm_embedding_vlm.h model_args.h npu_dp_ep_padding.h model_input_params.h diff --git a/xllm/core/framework/model/mm_embedding_vlm.h b/xllm/core/framework/model/mm_embedding_vlm.h new file mode 100644 index 000000000..135cfd5fb --- /dev/null +++ b/xllm/core/framework/model/mm_embedding_vlm.h @@ -0,0 +1,86 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#pragma once + +#include +#include + +#include + +#include "causal_vlm.h" +#include "core/framework/kv_cache/kv_cache.h" +#include "core/framework/quant_args.h" +#include "core/framework/state_dict/state_dict.h" +#include "model_args.h" +#include "model_input_params.h" + +namespace xllm { + +class MMEmbeddingVLM : public CausalVLM { + public: + ~MMEmbeddingVLM() override = default; + + virtual std::vector encode( + const ModelInputParams& input_params) = 0; +}; + +template +class MMEmbeddingVLMImpl : public MMEmbeddingVLM { + public: + MMEmbeddingVLMImpl(Model model, const torch::TensorOptions& options) + : model_(std::move(model)), options_(options) {} + + virtual std::vector encode( + const ModelInputParams& input_params) override { + return model_->encode(input_params); + }; + + virtual torch::Tensor logits(const torch::Tensor& hidden_states, + const torch::Tensor& selected_idxes) { + return torch::Tensor(); + } + + virtual torch::Tensor forward(const torch::Tensor& tokens, + const torch::Tensor& positions, + std::vector& kv_caches, + const ModelInputParams& input_params) { + return torch::Tensor{}; + } + virtual void prepare_expert_weight(int32_t layer_id, + const std::vector& expert_ids) { + return; + } + virtual void update_expert_weight(int32_t layer_id) { return; } + virtual void set_lm_head(layer::LmHead& head) { return; } + virtual layer::LmHead get_lm_head() { return nullptr; } + virtual layer::WordEmbedding get_word_embedding() { return nullptr; } + virtual void set_word_embedding(layer::WordEmbedding& embedding) { return; } + + void load_model(std::unique_ptr loader) override { + model_->load_model(std::move(loader)); + } + + torch::Device device() const override { return options_.device(); } + + const torch::TensorOptions& options() const override { return options_; } + + private: + Model model_; + + torch::TensorOptions options_; +}; + +} // namespace xllm diff --git a/xllm/models/model_registry.cpp b/xllm/models/model_registry.cpp index 65c3353b1..38354fa07 100644 --- a/xllm/models/model_registry.cpp +++ b/xllm/models/model_registry.cpp @@ -115,6 +115,20 @@ void ModelRegistry::register_vlm_embedding_factory( } } +void ModelRegistry::register_mm_embedding_vlm_factory( + const std::string& name, + MMEmbeddingVLMFactory factory) { + ModelRegistry* instance = get_instance(); + + if (instance->model_registry_[name].mm_embedding_vlm_factory != nullptr) { + SAFE_LOG_WARNING("mm embedding vlm factory for " << name + << " already registered."); + } else { + instance->model_registry_[name].mm_embedding_vlm_factory = factory; + instance->model_backend_[name] = "vlm"; + } +} + void ModelRegistry::register_dit_model_factory(const std::string& name, DiTModelFactory factory) { ModelRegistry* instance = get_instance(); @@ -216,6 +230,13 @@ EmbeddingVLMFactory ModelRegistry::get_embeddingvlm_factory( return instance->model_registry_[name].embedding_vlm_factory; } +MMEmbeddingVLMFactory ModelRegistry::get_mm_embedding_vlm_factory( + const std::string& name) { + ModelRegistry* instance = get_instance(); + + return instance->model_registry_[name].mm_embedding_vlm_factory; +} + DiTModelFactory ModelRegistry::get_dit_model_factory(const std::string& name) { ModelRegistry* instance = get_instance(); return instance->model_registry_[name].dit_model_factory; @@ -317,6 +338,21 @@ std::unique_ptr create_vlm_embedding_model( return nullptr; } +std::unique_ptr create_vlm_mm_embedding_model( + const ModelContext& context) { + // get the factory function for the model type from model registry + auto factory = ModelRegistry::get_mm_embedding_vlm_factory( + context.get_model_args().model_type()); + if (factory) { + return factory(context); + } + + LOG(ERROR) << "Unsupported model type: " + << context.get_model_args().model_type(); + + return nullptr; +} + std::unique_ptr create_dit_model(const DiTModelContext& context) { // get the factory function for the model type from model registry auto factory = ModelRegistry::get_dit_model_factory(context.model_type()); diff --git a/xllm/models/model_registry.h b/xllm/models/model_registry.h index 25c34ad0a..9ec52afd6 100644 --- a/xllm/models/model_registry.h +++ b/xllm/models/model_registry.h @@ -26,6 +26,7 @@ limitations under the License. #include "core/framework/model/dit_model.h" #include "core/framework/model/embedding_lm.h" #include "core/framework/model/embedding_vlm.h" +#include "core/framework/model/mm_embedding_vlm.h" #include "core/framework/model_context.h" #include "core/framework/tokenizer/tokenizer_args.h" #include "core/util/json_reader.h" @@ -47,6 +48,9 @@ using EmbeddingLMFactory = using EmbeddingVLMFactory = std::function(const ModelContext& context)>; +using MMEmbeddingVLMFactory = + std::function(const ModelContext& context)>; + using DiTModelFactory = std::function(const DiTModelContext& context)>; @@ -71,6 +75,7 @@ struct ModelMeta { CausalVLMFactory causal_vlm_factory; EmbeddingLMFactory embedding_lm_factory; EmbeddingVLMFactory embedding_vlm_factory; + MMEmbeddingVLMFactory mm_embedding_vlm_factory; DiTModelFactory dit_model_factory; InputProcessorFactory input_processor_factory; ImageProcessorFactory image_processor_factory; @@ -97,6 +102,9 @@ class ModelRegistry { static void register_vlm_embedding_factory(const std::string& name, EmbeddingVLMFactory factory); + static void register_mm_embedding_vlm_factory(const std::string& name, + MMEmbeddingVLMFactory factory); + static void register_dit_model_factory(const std::string& name, DiTModelFactory factory); @@ -122,6 +130,9 @@ class ModelRegistry { static EmbeddingVLMFactory get_embeddingvlm_factory(const std::string& name); + static MMEmbeddingVLMFactory get_mm_embedding_vlm_factory( + const std::string& name); + static DiTModelFactory get_dit_model_factory(const std::string& name); static ModelArgsLoader get_model_args_loader(const std::string& name); @@ -153,6 +164,9 @@ std::unique_ptr create_lm_embedding_model( std::unique_ptr create_vlm_embedding_model( const ModelContext& context); +std::unique_ptr create_vlm_mm_embedding_model( + const ModelContext& context); + std::unique_ptr create_dit_model(const DiTModelContext& context); // Macro to register a model with the ModelRegistry @@ -218,6 +232,22 @@ std::unique_ptr create_dit_model(const DiTModelContext& context); #define REGISTER_EMBEDDING_VLM_MODEL(ModelType, ModelClass) \ REGISTER_EMBEDDING_VLM_MODEL_WITH_VARNAME(ModelType, ModelType, ModelClass) +#define REGISTER_MM_EMBEDDING_VLM_MODEL_WITH_VARNAME( \ + VarName, ModelType, ModelClass) \ + const bool VarName##_registered = []() { \ + ModelRegistry::register_mm_embedding_vlm_factory( \ + #ModelType, [](const ModelContext& context) { \ + ModelClass model(context); \ + model->eval(); \ + return std::make_unique>( \ + std::move(model), context.get_tensor_options()); \ + }); \ + return true; \ + }() + +#define REGISTER_MM_EMBEDDING_VLM_MODEL(ModelType, ModelClass) \ + REGISTER_MM_EMBEDDING_VLM_MODEL_WITH_VARNAME(ModelType, ModelType, ModelClass) + #define REGISTER_DIT_MODEL_WITH_VARNAME(VarName, ModelType, ModelClass) \ const bool VarName##_registered = []() { \ ModelRegistry::register_dit_model_factory( \