diff --git a/CMakeLists.txt b/CMakeLists.txt index f18bca2ff..1bdcb25e5 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -32,20 +32,20 @@ if(USE_NPU) if(DEVICE_TYPE STREQUAL "USE_A3") message("downloading a3 arm xllm kernels") file(DOWNLOAD - "https://9n-das-tools.s3.cn-north-1.jdcloud-oss.com/xllm-ai/xllm_kernels/0.7.0/xllm_kernels-1.3.3-Linux.a3.arm.rpm" + "https://9n-das-tools.s3.cn-north-1.jdcloud-oss.com/xllm-ai/xllm_kernels/0.7.0/xllm_kernels-1.3.4-Linux.a3.arm.rpm" "${CMAKE_BINARY_DIR}/xllm_kernels.rpm" ) else() if(DEVICE_ARCH STREQUAL "ARM") message("downloading a2 arm xllm_kernels") file(DOWNLOAD - "https://9n-das-tools.s3.cn-north-1.jdcloud-oss.com/xllm-ai/xllm_kernels/0.7.0/xllm_kernels-1.3.3-Linux.a2.arm.rpm" + "https://9n-das-tools.s3.cn-north-1.jdcloud-oss.com/xllm-ai/xllm_kernels/0.7.0/xllm_kernels-1.3.4-Linux.a2.arm.rpm" "${CMAKE_BINARY_DIR}/xllm_kernels.rpm" ) else() message("downloading a2 x86 xllm_kernels") file(DOWNLOAD - "https://9n-das-tools.s3.cn-north-1.jdcloud-oss.com/xllm-ai/xllm_kernels/0.7.0/xllm_kernels-1.3.3-Linux.a2.x86.rpm" + "https://9n-das-tools.s3.cn-north-1.jdcloud-oss.com/xllm-ai/xllm_kernels/0.7.0/xllm_kernels-1.3.4-Linux.a2.x86.rpm" "${CMAKE_BINARY_DIR}/xllm_kernels.rpm" ) endif() diff --git a/vcpkg.json b/vcpkg.json old mode 100644 new mode 100755 index 3600c2dab..bc61fca13 --- a/vcpkg.json +++ b/vcpkg.json @@ -98,9 +98,10 @@ "version>=": "1.3.1" }, { - "name": "opencv", + "name": "opencv4", "version>=": "4.7.0", - "default-features": false + "default-features": false, + "features": ["ffmpeg", "jpeg", "png","tiff","webp","openexr","quirc"] }, { "name": "yaml-cpp", diff --git a/xllm/core/framework/batch/mposition.cpp b/xllm/core/framework/batch/mposition.cpp index d4421aad5..ef4f0ee64 100755 --- a/xllm/core/framework/batch/mposition.cpp +++ b/xllm/core/framework/batch/mposition.cpp @@ -15,10 +15,34 @@ limitations under the License. 
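Note on the hunk below: get_positions_glm first labels every prompt token as text, image, or video, then groupByTokenType collapses those labels into contiguous (type, start, end) runs so each run can be assigned its own block of 3D (t, h, w) rope positions. The template arguments of the helper are not visible in this rendering of the diff; a minimal standalone sketch, assuming the group element type is std::tuple<std::string, int, int> (which matches how std::get<0>/<1>/<2> are used on each group later in the hunk):

#include <string>
#include <tuple>
#include <vector>

// Collapse {"text","text","image","image","text"} into
// {("text",0,2), ("image",2,4), ("text",4,5)} -- half-open [start, end) runs.
std::vector<std::tuple<std::string, int, int>> groupByTokenType(
    const std::vector<std::string>& token_types) {
  std::vector<std::tuple<std::string, int, int>> groups;
  if (token_types.empty()) return groups;

  std::string current_key = token_types[0];
  int start = 0;
  for (int i = 1; i < static_cast<int>(token_types.size()); ++i) {
    if (token_types[i] != current_key) {
      groups.emplace_back(current_key, start, i);
      current_key = token_types[i];
      start = i;
    }
  }
  groups.emplace_back(current_key, start, static_cast<int>(token_types.size()));
  return groups;
}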
#include "mposition.h" +#include + #include "framework/model/model_args.h" #include "framework/request/sequence.h" + namespace xllm { +namespace { +std::vector> groupByTokenType( + const std::vector& token_types) { + std::vector> groups; + if (token_types.empty()) return groups; + + std::string current_key = token_types[0]; + int start = 0; + + for (int i = 1; i < token_types.size(); ++i) { + if (token_types[i] != current_key) { + groups.emplace_back(current_key, start, i); + current_key = token_types[i]; + start = i; + } + } + groups.emplace_back(current_key, start, static_cast(token_types.size())); + return groups; +} +} // namespace + torch::Tensor MPositionHelper::get_positions() { // if (seq_.is_chunked_prefill_stage()) { if (seq_.kv_state().kv_cache_tokens_num() < seq_.num_prompt_tokens()) { @@ -35,16 +59,128 @@ torch::Tensor MPositionHelper::get_positions() { torch::Tensor second_per_grid_ts; if (auto res = mm_data.get("second_per_grid_ts")) second_per_grid_ts = res.value(); - auto res = - get_positions_p(image_grid_thw, video_grid_thw, second_per_grid_ts); + std::tuple res; + if (!absl::StartsWith(args_.model_type(), "glm4v")) { + res = get_positions_p(image_grid_thw, video_grid_thw, second_per_grid_ts); + } else { + res = get_positions_glm(image_grid_thw, video_grid_thw); + } seq_.set_mrope_position_delta(std::get<1>(res)); - return std::get<0>(res); } else { return get_positions_d(); } } +std::tuple MPositionHelper::get_positions_glm( + torch::Tensor image_grid_thw, + torch::Tensor video_grid_thw) { + auto input_tokens = seq_.tokens(); + auto spatial_merge_size = args_.mm_spatial_merge_size(); + auto image_token_id = args_.image_token_id(); + auto video_token_id = args_.video_token_id(); + auto video_start_token_id = args_.video_start_token_id(); + auto video_end_token_id = args_.video_end_token_id(); + + auto dtype = torch::kInt32; + + std::vector input_token_type; + bool in_video = false; + int num_tokens = input_tokens.size(); + + for (int index = 0; index < num_tokens; ++index) { + auto token = input_tokens[index]; + if (token == video_start_token_id) { + in_video = true; + } else if (token == video_end_token_id) { + in_video = false; + } + + if (token == image_token_id && !in_video) { + input_token_type.push_back("image"); + } else if (token == image_token_id && in_video) { + input_token_type.push_back("video"); + } else { + input_token_type.push_back("text"); + } + } + auto input_type_group = groupByTokenType(input_token_type); + int image_index = 0; + int video_index = 0; + int video_group_index = 0; + + std::vector llm_pos_ids_list; + int video_frame_num = 1; + for (const auto& group : input_type_group) { + const auto& modality_type = std::get<0>(group); + int start_idx = std::get<1>(group); + int end_idx = std::get<2>(group); + int st_idx = 0; + if (!llm_pos_ids_list.empty()) { + st_idx = llm_pos_ids_list.back().max().item() + 1; + } + + if (modality_type == "image") { + auto grid = image_grid_thw[image_index]; + int t = grid[0].item(); + int h = grid[1].item() / spatial_merge_size; + int w = grid[2].item() / spatial_merge_size; + + auto t_arange = + torch::arange(t, dtype).view({-1, 1}).expand({-1, h * w}).flatten(); + auto h_arange = + torch::arange(h, dtype).view({1, -1, 1}).expand({t, -1, w}).flatten(); + auto w_arange = + torch::arange(w, dtype).view({1, 1, -1}).expand({t, h, -1}).flatten(); + + auto pos = torch::stack({t_arange, h_arange, w_arange}) + st_idx; + llm_pos_ids_list.push_back(pos); + video_frame_num = 1; + image_index++; + } else if (modality_type == 
"video") { + int t = video_frame_num; + int h = video_grid_thw[video_index][1].item() / spatial_merge_size; + int w = video_grid_thw[video_index][2].item() / spatial_merge_size; + + for (int t_idx = 0; t_idx < t; ++t_idx) { + auto t_tensor = torch::full({1, h * w}, t_idx, dtype).flatten(); + auto h_tensor = torch::arange(h, dtype) + .view({1, -1, 1}) + .expand({1, -1, w}) + .flatten(); + auto w_tensor = torch::arange(w, dtype) + .view({1, 1, -1}) + .expand({1, h, -1}) + .flatten(); + + auto pos = torch::stack({t_tensor, h_tensor, w_tensor}) + st_idx; + llm_pos_ids_list.push_back(pos); + } + + video_group_index++; + if (video_group_index >= video_grid_thw[video_index][0].item()) { + video_index++; + video_group_index = 0; + } + video_frame_num++; + } else { // text + int text_len = end_idx - start_idx; + auto arange = + torch::arange(text_len, dtype).view({1, -1}).expand({3, -1}) + st_idx; + llm_pos_ids_list.push_back(arange); + video_frame_num = 1; + } + } + + torch::Tensor llm_positions = + torch::cat(llm_pos_ids_list, /*dim=*/1).reshape({3, -1}); + llm_positions = llm_positions; + int mrope_position_delta = + (llm_positions.max().item() + 1 - input_tokens.size()); + + return std::make_pair(llm_positions, mrope_position_delta); +} + std::tuple MPositionHelper::get_positions_p( torch::Tensor image_grid_thw, torch::Tensor video_grid_thw, diff --git a/xllm/core/framework/batch/mposition.h b/xllm/core/framework/batch/mposition.h index c4575526c..466660baa 100644 --- a/xllm/core/framework/batch/mposition.h +++ b/xllm/core/framework/batch/mposition.h @@ -37,6 +37,10 @@ class MPositionHelper { torch::Tensor image_grid_thw, torch::Tensor video_grid_thw, torch::Tensor second_per_grid_ts); + std::tuple get_positions_glm( + torch::Tensor image_grid_thw, + torch::Tensor video_grid_thw); + torch::Tensor get_positions_d(); private: diff --git a/xllm/core/framework/chat_template/jinja_chat_template.cpp b/xllm/core/framework/chat_template/jinja_chat_template.cpp index f206cc0f0..fcd1f2166 100644 --- a/xllm/core/framework/chat_template/jinja_chat_template.cpp +++ b/xllm/core/framework/chat_template/jinja_chat_template.cpp @@ -141,6 +141,15 @@ nlohmann::ordered_json JinjaChatTemplate::get_mm_content( if (item.type == "text") { item_json["text"] = item.text; + } else if (item.type == "video_url") { + item_json["video"] = "mm place holder"; + item_json["video_url"] = "mm place holder"; + } else if (item.type == "image_url") { + item_json["image"] = "mm place holder"; + item_json["image_url"] = "mm place holder"; + } else if (item.type == "audio_url") { + item_json["audio"] = "mm place holder"; + item_json["audio_url"] = "mm place holder"; } else { item_json[item.type] = "mm place holder"; } diff --git a/xllm/core/framework/hf_model_loader.cpp b/xllm/core/framework/hf_model_loader.cpp index e5fd7c348..23738289b 100644 --- a/xllm/core/framework/hf_model_loader.cpp +++ b/xllm/core/framework/hf_model_loader.cpp @@ -102,6 +102,11 @@ bool HFModelLoader::load_args(const std::string& model_weights_path) { return false; } + if (!load_video_preprocessor_args(model_weights_path)) { + LOG(ERROR) << "Failed to load video preprocess args from " + << model_weights_path; + return false; + } // Some hacky logics to support loading of old models // always use float16 for quantization // TODO: support quantization for other data types @@ -416,4 +421,24 @@ bool HFModelLoader::load_image_preprocessor_args( return true; } +bool HFModelLoader::load_video_preprocessor_args( + const std::string& model_weights_path) { + // image 
preprocessor args + JsonReader video_preprocess_reader; + const std::string video_preprocess_file_path = + model_weights_path + "/video_preprocessor_config.json"; + if (video_preprocess_reader.parse(video_preprocess_file_path)) { + LOG(INFO) << "Success to parse video preprocess args file: " + << video_preprocess_file_path; + + args_.mm_video_shortest_edge() = + video_preprocess_reader.value_or("size.shortest_edge", 0); + + args_.mm_video_longest_edge() = + video_preprocess_reader.value_or("size.longest_edge", 0); + } + + return true; +} + } // namespace xllm diff --git a/xllm/core/framework/hf_model_loader.h b/xllm/core/framework/hf_model_loader.h index bb2401b0f..f98939adb 100644 --- a/xllm/core/framework/hf_model_loader.h +++ b/xllm/core/framework/hf_model_loader.h @@ -40,6 +40,7 @@ class HFModelLoader : public ModelLoader { bool load_quant_args(const std::string& model_weights_path); bool load_tokenizer_args(const std::string& model_weights_path); bool load_image_preprocessor_args(const std::string& model_weights_path); + bool load_video_preprocessor_args(const std::string& model_weights_path); std::string model_weights_path() const override { return model_weights_path_; } diff --git a/xllm/core/framework/model/model_args.h b/xllm/core/framework/model/model_args.h index 168565e89..2f3ce9d96 100644 --- a/xllm/core/framework/model/model_args.h +++ b/xllm/core/framework/model/model_args.h @@ -136,6 +136,12 @@ struct ModelArgs { PROPERTY(int32_t, image_token_id) = 0; PROPERTY(int32_t, video_token_id) = 0; + // glm4v moe + PROPERTY(int32_t, image_start_token_id) = 0; + PROPERTY(int32_t, image_end_token_id) = 0; + PROPERTY(int32_t, video_start_token_id) = 0; + PROPERTY(int32_t, video_end_token_id) = 0; + PROPERTY(std::string, vision_custom_adapter); PROPERTY(int32_t, vision_max_slice_nums) = 0; @@ -297,6 +303,10 @@ struct ModelArgs { PROPERTY(int64_t, mm_image_shortest_edge) = 0; PROPERTY(int64_t, mm_image_longest_edge) = 0; + // GLM + PROPERTY(int64_t, mm_video_shortest_edge) = 0; + PROPERTY(int64_t, mm_video_longest_edge) = 0; + PROPERTY(int, mm_image_patch_size) = 0; PROPERTY(int, mm_image_temporal_patch_size) = 0; PROPERTY(int, mm_image_merge_size) = 0; diff --git a/xllm/core/framework/request/mm_codec.cpp b/xllm/core/framework/request/mm_codec.cpp index cdb1abc1c..e862b76d4 100644 --- a/xllm/core/framework/request/mm_codec.cpp +++ b/xllm/core/framework/request/mm_codec.cpp @@ -13,7 +13,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ - +extern "C" { +#include +#include +#include +} #include "mm_codec.h" namespace xllm { @@ -73,4 +77,213 @@ bool OpenCVImageEncoder::valid(const torch::Tensor& t) { return true; } +bool OpenCVVideoDecoder::decode(const std::string& raw_data, + torch::Tensor& t, + VideoMetadata& metadata) { + struct MemCtx { + const uint8_t* p; + size_t sz; + size_t off; + }; + + struct Reader { + static int read(void* opaque, uint8_t* buf, int buf_size) { + auto* mc = static_cast(opaque); + size_t remain = mc->sz - mc->off; + int n = (int)std::min(remain, (size_t)buf_size); + if (n <= 0) return AVERROR_EOF; + memcpy(buf, mc->p + mc->off, n); + mc->off += (size_t)n; + return n; + } + + static int64_t seek(void* opaque, int64_t offset, int whence) { + auto* mc = static_cast(opaque); + + if (whence == AVSEEK_SIZE) { + return (int64_t)mc->sz; + } + + int64_t pos = 0; + switch (whence) { + case SEEK_SET: + pos = offset; + break; + case SEEK_CUR: + pos = (int64_t)mc->off + offset; + break; + case SEEK_END: + pos = (int64_t)mc->sz + offset; + break; + default: + return AVERROR(EINVAL); + } + + if (pos < 0 || pos > (int64_t)mc->sz) return AVERROR_EOF; + + mc->off = (size_t)pos; + return pos; + } + }; + + AVFormatContext* fmt = avformat_alloc_context(); + const int avio_buf_sz = 1 << 16; + uint8_t* avio_buf = (uint8_t*)av_malloc(avio_buf_sz); + if (!fmt || !avio_buf) { + if (fmt) avformat_free_context(fmt); + if (avio_buf) av_free(avio_buf); + return false; + } + + MemCtx mc{(const uint8_t*)raw_data.data(), raw_data.size(), 0}; + + AVIOContext* avio = avio_alloc_context( + avio_buf, avio_buf_sz, 0, &mc, &Reader::read, nullptr, &Reader::seek); + if (!avio) { + av_free(avio_buf); + avformat_free_context(fmt); + return false; + } + + avio->seekable = AVIO_SEEKABLE_NORMAL; + + fmt->pb = avio; + fmt->flags |= AVFMT_FLAG_CUSTOM_IO; + + fmt->probesize = std::min(raw_data.size(), 20 * 1024 * 1024); + fmt->max_analyze_duration = 5LL * AV_TIME_BASE; + + bool ok = false; + + AVDictionary* opts = nullptr; + av_dict_set(&opts, "probesize", "20000000", 0); + av_dict_set(&opts, "analyzeduration", "5000000", 0); + + int ret = avformat_open_input(&fmt, nullptr, nullptr, &opts); + av_dict_free(&opts); + + if (ret < 0) { + av_freep(&avio->buffer); + avio_context_free(&avio); + avformat_free_context(fmt); + return false; + } + + ret = avformat_find_stream_info(fmt, nullptr); + if (ret < 0) { + av_freep(&avio->buffer); + avio_context_free(&avio); + avformat_close_input(&fmt); + return false; + } + + int vs = av_find_best_stream(fmt, AVMEDIA_TYPE_VIDEO, -1, -1, nullptr, 0); + if (vs < 0) { + av_freep(&avio->buffer); + avio_context_free(&avio); + avformat_close_input(&fmt); + return false; + } + + AVStream* st = fmt->streams[vs]; + AVCodecParameters* par = st->codecpar; + const AVCodec* dec = avcodec_find_decoder(par->codec_id); + if (!dec) { + av_freep(&avio->buffer); + avio_context_free(&avio); + avformat_close_input(&fmt); + return false; + } + + AVCodecContext* cc = avcodec_alloc_context3(dec); + if (!cc) { + av_freep(&avio->buffer); + avio_context_free(&avio); + avformat_close_input(&fmt); + return false; + } + + if (avcodec_parameters_to_context(cc, par) < 0 || + avcodec_open2(cc, dec, nullptr) < 0) { + avcodec_free_context(&cc); + av_freep(&avio->buffer); + avio_context_free(&avio); + avformat_close_input(&fmt); + return false; + } + + AVRational r = st->avg_frame_rate.num ? st->avg_frame_rate : st->r_frame_rate; + double fps = (r.num && r.den) ? 
av_q2d(r) : 0.0; + metadata.fps = fps; + + SwsContext* sws = nullptr; + AVPacket* pkt = av_packet_alloc(); + AVFrame* frm = av_frame_alloc(); + std::vector frames; + + auto push_frame = [&](AVFrame* f) -> bool { + if (!sws) { + sws = sws_getContext(f->width, + f->height, + (AVPixelFormat)f->format, + f->width, + f->height, + AV_PIX_FMT_RGB24, + SWS_BILINEAR, + nullptr, + nullptr, + nullptr); + if (!sws) return false; + } + + torch::Tensor rgb = torch::empty({f->height, f->width, 3}, torch::kUInt8); + uint8_t* dst_data[4] = {rgb.data_ptr(), nullptr, nullptr, nullptr}; + int dst_linesize[4] = {(int)rgb.stride(0), 0, 0, 0}; + + sws_scale(sws, f->data, f->linesize, 0, f->height, dst_data, dst_linesize); + + frames.emplace_back(rgb.permute({2, 0, 1}).clone()); // [C,H,W] + return true; + }; + + while (av_read_frame(fmt, pkt) >= 0) { + if (pkt->stream_index == vs) { + if (avcodec_send_packet(cc, pkt) == 0) { + while (avcodec_receive_frame(cc, frm) == 0) { + if (!push_frame(frm)) break; + } + } + } + av_packet_unref(pkt); + } + + // flush + avcodec_send_packet(cc, nullptr); + while (avcodec_receive_frame(cc, frm) == 0) { + if (!push_frame(frm)) break; + } + + if (!frames.empty()) { + t = torch::stack(frames); // [T,C,H,W] + metadata.total_num_frames = static_cast(frames.size()); + if (metadata.fps > 0.0) { + metadata.duration = metadata.total_num_frames / metadata.fps; + } else { + metadata.duration = 0.0; + } + ok = true; + } + + if (sws) sws_freeContext(sws); + av_frame_free(&frm); + av_packet_free(&pkt); + avcodec_free_context(&cc); + + av_freep(&avio->buffer); + avio_context_free(&avio); + avformat_close_input(&fmt); + + return ok; +} + } // namespace xllm diff --git a/xllm/core/framework/request/mm_codec.h b/xllm/core/framework/request/mm_codec.h index eea7d9d32..92dd0f001 100644 --- a/xllm/core/framework/request/mm_codec.h +++ b/xllm/core/framework/request/mm_codec.h @@ -20,6 +20,8 @@ limitations under the License. 
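Usage note for the FFmpeg-backed OpenCVVideoDecoder implemented above: decode() takes the raw container bytes, fills a [T, C, H, W] uint8 RGB frame tensor, and reports fps, frame count, and duration through VideoMetadata. A minimal sketch of calling it, assuming the bytes are read from a local file (in the server they arrive through the data:/http handlers in mm_handler.cpp); the helper name decode_video_file is illustrative:

#include <fstream>
#include <sstream>
#include <string>

#include <torch/torch.h>

#include "mm_codec.h"  // OpenCVVideoDecoder, VideoMetadata

bool decode_video_file(const std::string& path,
                       torch::Tensor& frames,
                       xllm::VideoMetadata& meta) {
  // Read the whole container (e.g. an mp4) into memory.
  std::ifstream in(path, std::ios::binary);
  if (!in) return false;
  std::ostringstream buf;
  buf << in.rdbuf();

  xllm::OpenCVVideoDecoder decoder;
  if (!decoder.decode(buf.str(), frames, meta)) return false;

  // frames: [T, C, H, W] uint8 RGB; frames.size(0) == meta.total_num_frames.
  // meta.duration is total_num_frames / fps when the stream reports a rate.
  return true;
}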
#include #include +#include "mm_data.h" + namespace xllm { class OpenCVImageDecoder { @@ -41,4 +43,13 @@ class OpenCVImageEncoder { bool valid(const torch::Tensor& t); }; +class OpenCVVideoDecoder { + public: + OpenCVVideoDecoder() = default; + ~OpenCVVideoDecoder() = default; + + bool decode(const std::string& raw_data, + torch::Tensor& t, + VideoMetadata& meta); +}; } // namespace xllm diff --git a/xllm/core/framework/request/mm_data.h b/xllm/core/framework/request/mm_data.h index a6bca46c3..12d4dff61 100644 --- a/xllm/core/framework/request/mm_data.h +++ b/xllm/core/framework/request/mm_data.h @@ -52,6 +52,15 @@ class MMType { Value value = Value::NONE; }; +struct VideoMetadata { + double fps = 0.0; // original fps + int64_t total_num_frames = 0; // original frames + double duration = 0.0; + double sampled_fps = 0.0; + torch::Tensor frame_indices; + std::vector timestamps; +}; + using MMKey = std::string; using MMValue = std::variant>; using MMDict = std::unordered_map; @@ -133,11 +142,22 @@ struct MMData { void debug_print() const; + const std::vector& get_video_metadata() const { + return video_metadata_; + } + + void set_video_metadata(const std::vector& meta) { + video_metadata_ = meta; + } + static MMData to(const MMData& mm_data, const torch::Device& device); static MMData batch(const std::vector& mm_datas); uint32_t ty_ = MMType::NONE; MMDict data_; + + private: + std::vector video_metadata_; }; } // namespace xllm diff --git a/xllm/core/framework/request/mm_handler.cpp b/xllm/core/framework/request/mm_handler.cpp index 7a93dfd5e..bb2397b69 100644 --- a/xllm/core/framework/request/mm_handler.cpp +++ b/xllm/core/framework/request/mm_handler.cpp @@ -83,9 +83,36 @@ bool ImageHandler::decode(MMInputItem& input) { return decoder.decode(input.raw_data_, input.decode_data_); } +bool VideoHandler::load(const MMContent& content, MMInputItem& input) { + input.clear(); + + const auto& video_url = content.video_url; + const auto& url = video_url.url; + + if (url.compare(0, dataurl_prefix_.size(), dataurl_prefix_) == + 0) { // data url + + input.type_ = MMType::VIDEO; + return this->load_from_dataurl(url, input.raw_data_); + } else if (url.compare(0, httpurl_prefix_.size(), httpurl_prefix_) == + 0) { // http url + + input.type_ = MMType::VIDEO; + return this->load_from_http(url, input.raw_data_); + } else { + LOG(ERROR) << " video url is invalid, url is " << url; + return false; + } +} + +bool VideoHandler::decode(MMInputItem& input) { + OpenCVVideoDecoder decoder; + return decoder.decode(input.raw_data_, input.decode_data_, input.video_meta_); +} + MMHandlerSet::MMHandlerSet() { handlers_["image_url"] = std::make_unique(); - // handlers_["video_url"] = std::make_unique(); + handlers_["video_url"] = std::make_unique(); // handlers_["audio_url"] = std::make_unique(); } diff --git a/xllm/core/framework/request/mm_handler.h b/xllm/core/framework/request/mm_handler.h index db6d8ac1d..ff8d55c9c 100644 --- a/xllm/core/framework/request/mm_handler.h +++ b/xllm/core/framework/request/mm_handler.h @@ -59,6 +59,18 @@ class ImageHandler : public MMHandlerBase { std::string dataurl_prefix_{"data:image"}; }; +class VideoHandler : public MMHandlerBase { + public: + VideoHandler() = default; + ~VideoHandler() = default; + + virtual bool load(const MMContent& content, MMInputItem& input) override; + virtual bool decode(MMInputItem& input) override; + + private: + std::string dataurl_prefix_{"data:video"}; +}; + class MMHandlerSet { public: MMHandlerSet(); diff --git a/xllm/core/framework/request/mm_input.h 
b/xllm/core/framework/request/mm_input.h index 32deea294..9f2d3237c 100644 --- a/xllm/core/framework/request/mm_input.h +++ b/xllm/core/framework/request/mm_input.h @@ -35,6 +35,8 @@ struct MMInputItem { std::string raw_data_; // binary torch::Tensor decode_data_; // image: rgb, [c,h,w], uint8 + + VideoMetadata video_meta_; }; struct MMInput { @@ -56,6 +58,15 @@ struct MMInput { return std::move(vec); } + std::vector get_video_metadata() const { + std::vector metas; + metas.reserve(items_.size()); + for (auto& item : items_) { + metas.push_back(item.video_meta_); + } + return metas; + } + std::vector items_; }; diff --git a/xllm/core/layers/CMakeLists.txt b/xllm/core/layers/CMakeLists.txt index 6d56faccd..bb7d67ab9 100644 --- a/xllm/core/layers/CMakeLists.txt +++ b/xllm/core/layers/CMakeLists.txt @@ -61,8 +61,10 @@ cc_library( qwen2_decoder_layer.h qwen2dot5_vision_decode_layer.h qwen3_vision_encode_layer.h + glm4_vision_encode_layer.h qwen3_decoder_layer.h qwen3_moe_decoder_layer.h + glm4_decoder_layer.h rms_norm.h siglip_encoder_layer.h pos_embedding.h diff --git a/xllm/core/layers/glm4_decoder_layer.h b/xllm/core/layers/glm4_decoder_layer.h new file mode 100644 index 000000000..8fd399450 --- /dev/null +++ b/xllm/core/layers/glm4_decoder_layer.h @@ -0,0 +1,45 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#pragma once + +#include "npu/npu_glm4_decoder_layer_impl.h" + +namespace xllm { +namespace layer { + +#if defined(USE_NPU) +class Glm4DecoderLayer + : public torch::nn::ModuleHolder { + public: + using torch::nn::ModuleHolder::ModuleHolder; + using Impl __attribute__((__unused__)) = NpuGlm4DecoderLayerImpl; + + Glm4DecoderLayer(const ModelContext& context) + : ModuleHolder(std::make_shared(context)) {} +}; +#else +class Glm4DecoderLayer : public torch::nn::ModuleHolder { + public: + using torch::nn::ModuleHolder::ModuleHolder; + using Impl __attribute__((__unused__)) = Qwen2DecoderImpl; + + Glm4DecoderLayer(const ModelContext& context) + : ModuleHolder(std::make_shared(context)) {} +}; +#endif + +} // namespace layer +} // namespace xllm diff --git a/xllm/core/layers/glm4_vision_encode_layer.h b/xllm/core/layers/glm4_vision_encode_layer.h new file mode 100644 index 000000000..792700dfb --- /dev/null +++ b/xllm/core/layers/glm4_vision_encode_layer.h @@ -0,0 +1,39 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#pragma once + +#if defined(USE_NPU) +#include "npu/npu_glm4_vision_encoder_layer_impl.h" +#endif + +namespace xllm { +namespace layer { + +#if defined(USE_NPU) +class Glm4VisionEncoderLayer + : public torch::nn::ModuleHolder { + public: + using torch::nn::ModuleHolder::ModuleHolder; + using Impl __attribute__((__unused__)) = NpuGlm4VisionEncoderLayerImpl; + + Glm4VisionEncoderLayer(const ModelContext& context) + : ModuleHolder( + std::make_shared(context)) {} +}; +#endif + +} // namespace layer +} // namespace xllm \ No newline at end of file diff --git a/xllm/core/layers/npu/CMakeLists.txt b/xllm/core/layers/npu/CMakeLists.txt index 61f7759d9..8bae3ec2f 100644 --- a/xllm/core/layers/npu/CMakeLists.txt +++ b/xllm/core/layers/npu/CMakeLists.txt @@ -18,11 +18,13 @@ cc_library( buffer/atb_workspace.h npu_base_layer.h npu_column_parallel_linear_impl.h + npu_glm4_vision_encoder_layer_impl.h npu_glm4_moe_decoder_layer.h npu_deepseek_v2_decoder_layer_impl.h npu_llama_decoder_layer_impl.h npu_qwen2_decoder_layer_impl.h npu_qwen3_decoder_layer_impl.h + npu_glm4_decoder_layer_impl.h npu_rms_norm_impl.h npu_siglip_encoder_layer_impl.h SRCS @@ -38,11 +40,13 @@ cc_library( buffer/atb_workspace.cpp npu_base_layer.cpp npu_column_parallel_linear_impl.cpp + npu_glm4_vision_encoder_layer_impl.cpp npu_glm4_moe_decoder_layer.cpp npu_deepseek_v2_decoder_layer_impl.cpp npu_llama_decoder_layer_impl.cpp npu_qwen2_decoder_layer_impl.cpp npu_qwen3_decoder_layer_impl.cpp + npu_glm4_decoder_layer_impl.cpp npu_rms_norm_impl.cpp npu_siglip_encoder_layer_impl.cpp DEPS diff --git a/xllm/core/layers/npu/npu_glm4_decoder_layer_impl.cpp b/xllm/core/layers/npu/npu_glm4_decoder_layer_impl.cpp new file mode 100644 index 000000000..26a87bae6 --- /dev/null +++ b/xllm/core/layers/npu/npu_glm4_decoder_layer_impl.cpp @@ -0,0 +1,395 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "npu_glm4_decoder_layer_impl.h" + +#include +#include + +#include + +#include "common/global_flags.h" +#include "torch_npu/csrc/core/npu/NPUCachingAllocator.h" +#include "torch_npu/csrc/core/npu/NPUException.h" + +namespace xllm { +namespace layer { + +enum DecoderLayerTensorId : int { + IN_NORM_WEIGHT = 0, // weight + IN_NORM_BIAS = 1, // bias + IN_NORM_NEW_WEIGHT = 2, // new weight + IN_NORM_NEW_BIAS = 3, // new bias + + IN_Q_WEIGHT = 4, // weight + IN_Q_BIAS = 5, // bias + IN_Q_DEQSCALE = 6, // deq_scale + IN_Q_OFFSET = 7, // offset + IN_Q_SCALE = 8, // scale + IN_Q_COMPRESS_IDX = 9, + + IN_K_WEIGHT = 10, // weight + IN_K_BIAS = 11, // bias + IN_K_DEQSCALE = 12, // deq_scale + IN_K_OFFSET = 13, // offset + IN_K_SCALE = 14, // scale + IN_K_COMPRESS_IDX = 15, + + IN_V_WEIGHT = 16, // weight + IN_V_BIAS = 17, // bias + IN_V_DEQSCALE = 18, // deq_scale + IN_V_OFFSET = 19, // offset + IN_V_SCALE = 20, // scale + IN_V_COMPRESS_IDX = 21, + + IN_ATTENTION_OUT_WEIGHT = 22, // weight + IN_ATTENTION_OUT_BIAS = 23, // bias + IN_ATTENTION_OUT_DEQSCALE = 24, // deq_scale + IN_ATTENTION_OUT_OFFSET = 25, // offset + IN_ATTENTION_OUT_SCALE = 26, // scale + IN_ATTENTION_OUT_COMPRESS_IDX = 27, + + IN_SELFOUT_NORM_WEIGHT = 28, // weight + IN_SELFOUT_NORM_BIAS = 29, // bias + IN_SELFOUT_NORM_NEW_WEIGHT = 30, // new weight + IN_SELFOUT_NORM_NEW_BIAS = 31, // new bias + + IN_MLP_GATEUP_WEIGHT = 32, // weight + IN_MLP_GATEUP_BIAS = 33, // bias + IN_MLP_GATEUP_DEQSCALE = 34, // deq_scale + IN_MLP_GATEUP_OFFSET = 35, // offset + IN_MLP_GATEUP_SCALE = 36, // scale + IN_MLP_GATEUP_COMPRESS_IDX = 37, + + IN_MLP_W1_WEIGHT = 38, // weight + IN_MLP_W1_BIAS = 39, // bias + IN_MLP_W1_DEQSCALE = 40, // deq_scale + IN_MLP_W1_OFFSET = 41, // offset + IN_MLP_W1_SCALE = 42, // scale + IN_MLP_W1_COMPRESS_IDX = 43, + + IN_MLP_CPROJ_WEIGHT = 44, // weight + IN_MLP_CPROJ_BIAS = 45, // bias + IN_MLP_CPROJ_DEQSCALE = 46, // deq_scale + IN_MLP_CPROJ_OFFSET = 47, // offset + IN_MLP_CPROJ_SCALE = 48, // scale + IN_MLP_CPROJ_COMPRESS_IDX = 49, + + IN_SELFIN_NORM_WEIGHT = 50, + IN_MLPOUT_NORM_WEIGHT = 51 +}; + +const uint64_t WEIGHT_COUNT_PER_LAYER = 52; + +static std::unordered_map WEIGHT_MAPPING = { + {"input_layernorm.weight", IN_NORM_WEIGHT}, + + {"self_attn.q_proj.weight", IN_Q_WEIGHT}, + {"self_attn.q_proj.bias", IN_Q_BIAS}, + + {"self_attn.k_proj.weight", IN_K_WEIGHT}, + {"self_attn.k_proj.bias", IN_K_BIAS}, + + {"self_attn.v_proj.weight", IN_V_WEIGHT}, + {"self_attn.v_proj.bias", IN_V_BIAS}, + + {"self_attn.o_proj.weight", IN_ATTENTION_OUT_WEIGHT}, + + {"post_attention_layernorm.weight", IN_SELFOUT_NORM_WEIGHT}, + + // mlp + {"mlp.gate_up_proj.weight", IN_MLP_GATEUP_WEIGHT}, + + {"mlp.down_proj.weight", IN_MLP_CPROJ_WEIGHT}, + + {"post_self_attn_layernorm.weight", IN_SELFIN_NORM_WEIGHT}, + {"post_mlp_layernorm.weight", IN_MLPOUT_NORM_WEIGHT} + +}; + +static std::map WEIGHT_SHARD = {{IN_Q_WEIGHT, 0}, + {IN_Q_BIAS, 0}, + {IN_K_WEIGHT, 0}, + {IN_K_BIAS, 0}, + {IN_V_WEIGHT, 0}, + {IN_V_BIAS, 0}, + {IN_ATTENTION_OUT_WEIGHT, 1}, + {IN_MLP_GATEUP_WEIGHT, 0}, + {IN_MLP_CPROJ_WEIGHT, 1}}; + +void NpuGlm4DecoderLayerImpl::param_from_args( + atb_speed::chatglm::ChatglmLayerParam& param, + const ModelArgs& args, + const ParallelArgs& parallel_args, + bool isPrefill) { + param.isFA = false; + param.enableSwiGLU = true; + + param.enableLcoc = false; + param.rmsnormQKNorm = false; + param.isPrefill = isPrefill; + param.isBF16 = args.dtype() == 
"bfloat16"; + param.enableSplitFuse = FLAGS_enable_chunked_prefill && isPrefill; + param.loraEnableGMM = false; + + param.linearTransposeType = {1, -1, -1, 1, 1, -1, 1}; // TODO + param.quantGroupSize = 0; + param.normEps = args.rms_norm_eps(); + param.numAttentionHeadsPerRank = args.n_heads() / parallel_args.world_size(); + param.hiddenSizePerAttentionHead = args.head_dim(); + std::optional optionalValue = args.n_kv_heads(); + param.numKeyValueHeadsPerRank = + static_cast(optionalValue.value()) / parallel_args.world_size(); + param.backend = FLAGS_communication_backend; + param.tensorParallelInfo = {parallel_args.rank(), + parallel_args.world_size(), + FLAGS_communication_backend}; + param.linearHasBias = {true, false, false, false}; + param.useQKNorm = false; + + param.numHiddenLayers = args.n_layers(); + param.usePostSelfAttnLayerNorm = true; + param.usePostMlpLayerNorm = true; + initialize_quantization_parameters(param); +} +void NpuGlm4DecoderLayerImpl::initialize_quantization_parameters( + atb_speed::chatglm::ChatglmLayerParam& param) { + param.linearDescs = {static_cast(LinearTypeV2::INVALID), + static_cast(LinearTypeV2::INVALID), + static_cast(LinearTypeV2::INVALID), + static_cast(LinearTypeV2::INVALID), + static_cast(LinearTypeV2::INVALID), + static_cast(LinearTypeV2::INVALID), + static_cast(LinearTypeV2::INVALID)}; + param.packQuantType = {static_cast(PackType::ALL_FP), + static_cast(PackType::ALL_FP)}; + param.linearQuantType = {static_cast(LinearType::FP), + static_cast(LinearType::INVALID), + static_cast(LinearType::INVALID), + static_cast(LinearType::FP), + static_cast(LinearType::FP), + static_cast(LinearType::INVALID), + static_cast(LinearType::FP)}; +} + +NpuGlm4DecoderLayerImpl::NpuGlm4DecoderLayerImpl(const ModelContext& context) + : NpuBaseLayer(context) { + auto model_args = context.get_model_args(); + auto parallel_args = context.get_parallel_args(); + auto options = context.get_tensor_options(); + + param_from_args(prefill_param_, model_args, parallel_args, true); + param_from_args(decode_param_, model_args, parallel_args, false); + at_weight_tensors_.resize(WEIGHT_COUNT_PER_LAYER); + atb_weight_tensors_.resize(WEIGHT_COUNT_PER_LAYER); + placeholder_vec_ = {1}; + dtype_ = c10::typeMetaToScalarType(options.dtype()); + rank_id_ = parallel_args.rank(); + placeholder_ = atb_speed::Utils::AtTensor2Tensor( + torch::zeros({1}).to(device_).to(dtype_)); + at_placeholder_ = torch::zeros({1}).to(device_).to(dtype_); + for (int i = 0; i < WEIGHT_COUNT_PER_LAYER; ++i) { + at_weight_tensors_[i] = torch::zeros({1}).to(options); + } +} +void NpuGlm4DecoderLayerImpl::verify_loaded_weights() const { + for (const auto& [name, index] : WEIGHT_MAPPING) { + CHECK(at_weight_tensors_[index].sizes() != std::vector({1})) + << "weight is not loaded for " << name; + } +} + +void NpuGlm4DecoderLayerImpl::merge_loaded_weights() { + at_weight_tensors_[IN_Q_WEIGHT] = + torch::cat({at_weight_tensors_[IN_Q_WEIGHT], + at_weight_tensors_[IN_K_WEIGHT], + at_weight_tensors_[IN_V_WEIGHT]}, + 0) + .contiguous(); + at_weight_tensors_[IN_Q_BIAS] = torch::cat({at_weight_tensors_[IN_Q_BIAS], + at_weight_tensors_[IN_K_BIAS], + at_weight_tensors_[IN_V_BIAS]}, + 0) + .contiguous(); + + for (auto idx : + {IN_MLP_W1_WEIGHT, IN_K_WEIGHT, IN_V_WEIGHT, IN_K_BIAS, IN_V_BIAS}) { + at_weight_tensors_[idx] = at_placeholder_; + } + + c10_npu::NPUCachingAllocator::emptyCache(); + for (int i = 0; i < WEIGHT_COUNT_PER_LAYER; ++i) { + atb_weight_tensors_[i] = + atb_speed::Utils::AtTensor2Tensor(at_weight_tensors_[i]); + } + + 
init_layer(); +} + +void NpuGlm4DecoderLayerImpl::load_state_dict(const StateDict& state_dict) { + for (const auto& [name, index] : WEIGHT_MAPPING) { + if (WEIGHT_SHARD.find(index) != WEIGHT_SHARD.end()) { + set_weight(state_dict, name, index, WEIGHT_SHARD[index]); + } else { + set_weight(state_dict, name, index); + } + } +} + +int64_t NpuGlm4DecoderLayerImpl::init_layer() { + init_attn_mask(); + name_ = "glm4_decoder_layer"; + model_name_ = "glm4"; + CHECK_OPERATION_STATUS_RETURN(init_node(prefill_node_, prefill_param_)); + CHECK_OPERATION_STATUS_RETURN(init_node(decode_node_, decode_param_)); + + return atb::NO_ERROR; +} + +int64_t NpuGlm4DecoderLayerImpl::init_attn_mask() { + torch::Dtype dtype = + prefill_param_.isBF16 ? torch::kBFloat16 : torch::kFloat16; + decode_attn_mask_ = torch::zeros({1}).to(device_).to(dtype); + + return atb::NO_ERROR; +} + +int64_t NpuGlm4DecoderLayerImpl::init_node( + atb_speed::Model::Node& node, + atb_speed::chatglm::ChatglmLayerParam& param) { + atb::Operation* operation = nullptr; + atb_speed::chatglm::ChatglmDecoderLayer decoder_layer(param); + decoder_layer.BuildGraph(&operation); + node.operation.reset(operation); + if (node.operation == nullptr) { + LOG(ERROR) << "node.operation is null"; + return -1; + } + if (node.operation->GetInputNum() < 1) { + LOG(ERROR) << "Can not resize number which is smaller than 1"; + return -1; + } + node.inTensors.resize(node.operation->GetInputNum()); + node.outTensors.resize(1); + size_t inTensorId = 1; + + for (size_t weightTensorId = 0; weightTensorId < WEIGHT_COUNT_PER_LAYER; + ++weightTensorId) { + node.inTensors.at(weightTensorId) = &atb_weight_tensors_[weightTensorId]; + } + node.variantPack.inTensors.reserve(node.inTensors.size()); + node.variantPack.inTensors.resize(node.inTensors.size()); + node.variantPack.outTensors.reserve(1); + node.variantPack.outTensors.resize(1); + + return atb::NO_ERROR; +} + +torch::Tensor NpuGlm4DecoderLayerImpl::forward(torch::Tensor& x, + torch::Tensor& cos_pos, + torch::Tensor& sin_pos, + torch::Tensor& attn_mask, + KVCache& kv_cache, + ModelInputParams& input_params, + aclrtEvent* event, + std::atomic* event_flag, + int node_id) { + atb::Status st; + if (input_params.decode_seq_range.second != + input_params.q_seq_lens.size(0) - 1) { + // if (input_params.empty_kv_cache) { + // mstxRangeId id = mstxRangeStartA("prefill build variant", nullptr); + build_node_variant_pack(prefill_node_, + x, + cos_pos, + sin_pos, + attn_mask, + kv_cache, + input_params, + true); + // mstxRangeEnd(id); + st = execute_node(prefill_node_, node_id, event, event_flag); + LOG_IF(FATAL, st != 0) << model_name_ + << "excute prefill layer fail, error code: " << st; + } else { + build_node_variant_pack(decode_node_, + x, + cos_pos, + sin_pos, + decode_attn_mask_, + kv_cache, + input_params, + false); + st = execute_node(decode_node_, node_id + 1000, event, event_flag); + LOG_IF(FATAL, st != 0) << model_name_ + << "excute decode layer fail, error code: " << st; + } + + return at_placeholder_; +} + +void NpuGlm4DecoderLayerImpl::build_node_variant_pack( + atb_speed::Model::Node& node, + torch::Tensor& x, + torch::Tensor& cos_pos, + torch::Tensor& sin_pos, + at::Tensor& attn_mask, + KVCache& kv_cache, + ModelInputParams& input_params, + bool is_prefill) { + internal_tensors_ = atb_speed::Utils::AtTensor2Tensor(x); + // std::cout<<"node.variantPack.inTensors.size:"< +#include +#else +#include +#include +#endif + +#include + +#include + +#include "atb/atb_infer.h" +#include "framework/kv_cache/kv_cache.h" 
+#include "framework/model/model_input_params.h" +#include "framework/model_context.h" +#include "framework/state_dict/state_dict.h" +#include "nlohmann/json.hpp" +#include "npu_base_layer.h" +#include "pytorch/adapter/utils/utils.h" +#include "xllm_kernels/core/include/atb_speed/base/hosttensor_binder.h" +#include "xllm_kernels/core/include/atb_speed/base/model.h" +#include "xllm_kernels/core/include/atb_speed/log.h" +#include "xllm_kernels/core/include/atb_speed/utils/model_factory.h" +#include "xllm_kernels/models/glm/layer/decoder_layer.h" + +namespace xllm { +namespace layer { + +class NpuGlm4DecoderLayerImpl : public NpuBaseLayer { + public: + explicit NpuGlm4DecoderLayerImpl(const ModelContext& context); + + ~NpuGlm4DecoderLayerImpl() {}; + + virtual void load_state_dict(const StateDict& state_dict) override; + + virtual void verify_loaded_weights() const override; + + virtual void merge_loaded_weights() override; + + virtual int64_t init_layer() override; + + torch::Tensor forward(torch::Tensor& x, + torch::Tensor& cos_pos, + torch::Tensor& sin_pos, + torch::Tensor& attn_mask, + KVCache& kv_cache, + ModelInputParams& input_params, + aclrtEvent* event = nullptr, + std::atomic* event_flag = nullptr, + int node_id = 0); + + private: + void param_from_args(atb_speed::chatglm::ChatglmLayerParam& param, + const ModelArgs& args, + const ParallelArgs& parallel_args, + bool isPrefill); + + void build_node_variant_pack(atb_speed::Model::Node& node, + torch::Tensor& x, + torch::Tensor& cos_pos, + torch::Tensor& sin_pos, + torch::Tensor& attn_mask, + KVCache& kv_cache, + ModelInputParams& input_params, + bool is_prefill); + + void initialize_quantization_parameters( + atb_speed::chatglm::ChatglmLayerParam& param); + + int64_t init_node(atb_speed::Model::Node& node, + atb_speed::chatglm::ChatglmLayerParam& param); + + int64_t init_attn_mask(); + + atb_speed::Model::Node prefill_node_; + atb_speed::Model::Node decode_node_; + std::string model_name_; + atb_speed::chatglm::ChatglmLayerParam prefill_param_; + atb_speed::chatglm::ChatglmLayerParam decode_param_; + atb::Tensor internal_tensors_; + atb::Tensor placeholder_; + + at::Tensor decode_attn_mask_; + + at::Tensor at_placeholder_; + + int device_id_; + int32_t layer_id_; + int rank_id_; +}; + +} // namespace layer +} // namespace xllm diff --git a/xllm/core/layers/npu/npu_glm4_vision_encoder_layer_impl.cpp b/xllm/core/layers/npu/npu_glm4_vision_encoder_layer_impl.cpp new file mode 100644 index 000000000..866ec9c77 --- /dev/null +++ b/xllm/core/layers/npu/npu_glm4_vision_encoder_layer_impl.cpp @@ -0,0 +1,263 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +// copy from qwen3 vl, please follow its modifications +#include "npu_glm4_vision_encoder_layer_impl.h" + +#include +#include + +#include +#include + +#include "torch_npu/csrc/core/npu/NPUCachingAllocator.h" +#include "torch_npu/csrc/core/npu/NPUException.h" +#include "xllm_kernels/models/glm4v/glm4v_encoder.h" + +namespace xllm { +namespace layer { + +enum Glm4VisionEncoderLayerTensorId : int { + IN_INPUT_NORM_WEIGHT = 0, + IN_POST_NORM_WEIGHT, + IN_QKV_WEIGHT, + IN_ATTN_PROJ_WEIGHT, + IN_LINEAR_GATE_UP_WEIGHT, + IN_LINEAR_DOWN_WEIGHT, + IN_LINEAR_UP_WEIGHT, + IN_LINEAR_GATE_WEIGHT +}; + +const uint64_t WEIGHT_COUNT_PER_LAYER = 8; + +static std::vector> WEIGHT_MAPPING = { + {IN_INPUT_NORM_WEIGHT, "norm1.weight"}, + {IN_POST_NORM_WEIGHT, "norm2.weight"}, + {IN_QKV_WEIGHT, "attn.qkv.weight"}, + {IN_ATTN_PROJ_WEIGHT, "attn.proj.weight"}, + {IN_LINEAR_GATE_WEIGHT, "mlp.gate_proj.weight"}, + {IN_LINEAR_UP_WEIGHT, "mlp.up_proj.weight"}, + {IN_LINEAR_DOWN_WEIGHT, "mlp.down_proj.weight"}}; + +// {weight,dim} +// IN_QKV_WEIGHT SHARD artificially in merge_loaded_weights +static std::map WEIGHT_SHARD = { + {IN_ATTN_PROJ_WEIGHT, 1}, + {IN_LINEAR_UP_WEIGHT, 0}, + {IN_LINEAR_GATE_WEIGHT, 0}, + {IN_LINEAR_DOWN_WEIGHT, 1}}; +// TODO: check shape with atb log -- HW pxy + +void NpuGlm4VisionEncoderLayerImpl::param_from_args( + atb_speed::glm::VisionEncoderLayerParam& param, + const ModelArgs& args, + const ParallelArgs& parallel_args) { + param.isBF16 = args.dtype() == "bfloat16"; + param.supportLcoc = false; + param.supportLora = false; + param.loraEnableGMM = false; + param.enableLogN = false; + param.backend = "hccl"; + param.rank = parallel_args.rank(); + param.worldSize = parallel_args.world_size(); + + param.quantType = 0; + param.quantGroupSize = 64; + + param.numAttentionHeadsPerRank = + args.mm_num_attention_heads() / param.worldSize; + param.hiddenSizePerAttentionHead = + args.mm_hidden_size() / args.mm_num_attention_heads(); + std::optional optionalValue = args.mm_num_attention_heads(); + param.numKeyValueHeadsPerRank = + static_cast(optionalValue.value()) / param.worldSize; + + param.rmsNormEps = args.rms_norm_eps(); +} + +NpuGlm4VisionEncoderLayerImpl::NpuGlm4VisionEncoderLayerImpl( + const ModelContext& context) + : NpuBaseLayer(context) { + auto model_args = context.get_model_args(); + auto parallel_args = context.get_parallel_args(); + auto options = context.get_tensor_options(); + param_from_args(encode_param_, model_args, parallel_args); + at_weight_tensors_.resize(WEIGHT_COUNT_PER_LAYER); + atb_weight_tensors_.resize(WEIGHT_COUNT_PER_LAYER); + dtype_ = c10::typeMetaToScalarType(options.dtype()); + device_id_ = options.device().index(); + placeholder_ = + atb_speed::Utils::AtTensor2Tensor(torch::zeros({1}).to(device_).to( + dtype_)); // seems not to be used -- HW pxy + at_placeholder_ = torch::zeros({1}).to(device_).to(dtype_); + for (int i = 0; i < WEIGHT_COUNT_PER_LAYER; ++i) { + at_weight_tensors_[i] = torch::zeros({1}).to(options); + } +} + +void NpuGlm4VisionEncoderLayerImpl::verify_loaded_weights() const { + for (const auto& [index, name] : WEIGHT_MAPPING) { + CHECK(at_weight_tensors_[index].sizes() != std::vector({1})) + << "weight is not loaded for " << name; + } +} + +void NpuGlm4VisionEncoderLayerImpl::merge_loaded_weights() { + if (encode_param_.worldSize > 1) { + // spilt pack qkv weight when enable tp + get_weights_col_packed_qkv(); + } + + at_weight_tensors_[IN_LINEAR_GATE_UP_WEIGHT] = 
torch::cat({ + at_weight_tensors_[IN_LINEAR_GATE_WEIGHT], + at_weight_tensors_[IN_LINEAR_UP_WEIGHT]}, + 0); + at_weight_tensors_[IN_LINEAR_GATE_WEIGHT] = torch::empty({}, device_); + at_weight_tensors_[IN_LINEAR_UP_WEIGHT] = torch::empty({}, device_); + + c10_npu::NPUCachingAllocator::emptyCache(); + for (int i = 0; i < WEIGHT_COUNT_PER_LAYER; ++i) { + atb_weight_tensors_[i] = + atb_speed::Utils::AtTensor2Tensor(at_weight_tensors_[i]); + } + + init_layer(); +} + +// tp spilt weight +void NpuGlm4VisionEncoderLayerImpl::get_weights_col_packed_qkv() { + int rank = encode_param_.rank; + int worldSize = encode_param_.worldSize; + // split qkv weight + auto qkv_weight = torch::chunk(at_weight_tensors_[IN_QKV_WEIGHT], 3, 0); + // get local weight and merge + auto new_qkv_weight = torch::cat({(qkv_weight[0].chunk(worldSize, 0))[rank], + (qkv_weight[1].chunk(worldSize, 0))[rank], + (qkv_weight[2].chunk(worldSize, 0))[rank]}, + 0); + at_weight_tensors_[IN_QKV_WEIGHT] = new_qkv_weight; +} + +void NpuGlm4VisionEncoderLayerImpl::load_state_dict( + const StateDict& state_dict) { + for (const auto& [index, name] : WEIGHT_MAPPING) { + if (WEIGHT_SHARD.find(index) != WEIGHT_SHARD.end()) { + set_weight(state_dict, name, index, WEIGHT_SHARD[index]); + } else { + set_weight(state_dict, name, index); + } + } +} + +int64_t NpuGlm4VisionEncoderLayerImpl::init_layer() { + name_ = "glm4_vision_encoder_layer"; + model_name_ = "glm4v"; + CHECK_OPERATION_STATUS_RETURN(init_node(encode_node_, encode_param_)); + return atb::NO_ERROR; +} + +int64_t NpuGlm4VisionEncoderLayerImpl::init_node( + atb_speed::Model::Node& node, + atb_speed::glm::VisionEncoderLayerParam& param) { + atb::Operation* operation = nullptr; + atb_speed::glm::Glm4v_EncoderLayer(param, &operation); + node.operation.reset(operation); + if (node.operation == nullptr) { + LOG(ERROR) << "node.operation is null"; + return -1; + } + if (node.operation->GetInputNum() < 1) { + LOG(ERROR) << "Can not resize number which is smaller than 1"; + return -1; + } + node.inTensors.resize(node.operation->GetInputNum()); + node.outTensors.resize(1); + size_t inTensorId = 1; + + for (size_t weightTensorId = 0; weightTensorId < WEIGHT_COUNT_PER_LAYER; + ++weightTensorId) { + node.inTensors.at(weightTensorId) = &atb_weight_tensors_[weightTensorId]; + } + + node.variantPack.inTensors.reserve(node.inTensors.size()); + node.variantPack.inTensors.resize(node.inTensors.size()); + node.variantPack.outTensors.reserve(1); + node.variantPack.outTensors.resize(1); + return atb::NO_ERROR; +} + +torch::Tensor NpuGlm4VisionEncoderLayerImpl::forward( + torch::Tensor& x, + torch::Tensor& cos_pos, + torch::Tensor& sin_pos, + torch::Tensor& cu_seqlen, + std::vector& cu_seqlen_vec, + ModelInputParams& input_params, + int node_id, + aclrtEvent* event, + std::atomic* event_flag) { + atb::Status st; + + build_node_variant_pack(encode_node_, + x, + cos_pos, + sin_pos, + cu_seqlen, + cu_seqlen_vec, + input_params, + true); + // mstxRangeEnd(id); + st = execute_node(encode_node_, node_id); + LOG_IF(FATAL, st != 0) << model_name_ + << "excute encode layer fail, error code: " << st; + return x; +} + +void NpuGlm4VisionEncoderLayerImpl::build_node_variant_pack( + atb_speed::Model::Node& node, + torch::Tensor& x, + torch::Tensor& cos_pos, + torch::Tensor& sin_pos, + torch::Tensor& cu_seqlen, + std::vector& cu_seqlen_vec, + ModelInputParams& input_params, + bool is_prefill) { + internal_tensors_ = atb_speed::Utils::AtTensor2Tensor(x); + + auto actual_weight_num = WEIGHT_COUNT_PER_LAYER - 2; + for 
(size_t i = 0; i < actual_weight_num; ++i) { + CHECK_THROW(node.inTensors.at(i) == nullptr, + model_name_ << "inTensor " << i << "is NULL"); + node.variantPack.inTensors.at(i) = *node.inTensors.at(i); + // LOG(INFO) << model_name_ << "inTensors[" << i << "]:" + // << atb_speed::TensorUtil::TensorToString( + // node.variantPack.inTensors.at(i)); + } + node.variantPack.inTensors.at(actual_weight_num) = internal_tensors_; + node.variantPack.inTensors.at(actual_weight_num + 1) = + atb_speed::Utils::AtTensor2Tensor(cos_pos); + node.variantPack.inTensors.at(actual_weight_num + 2) = + atb_speed::Utils::AtTensor2Tensor(sin_pos); + node.variantPack.inTensors.at(actual_weight_num + 3) = + atb_speed::Utils::AtTensor2Tensor(cu_seqlen); + node.variantPack.inTensors.at(actual_weight_num + 3).hostData = + cu_seqlen_vec.data(); + + + node.variantPack.outTensors.at(0) = internal_tensors_; +} + +} // namespace layer +} // namespace xllm \ No newline at end of file diff --git a/xllm/core/layers/npu/npu_glm4_vision_encoder_layer_impl.h b/xllm/core/layers/npu/npu_glm4_vision_encoder_layer_impl.h new file mode 100644 index 000000000..75f72aadf --- /dev/null +++ b/xllm/core/layers/npu/npu_glm4_vision_encoder_layer_impl.h @@ -0,0 +1,121 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#pragma once +#ifdef TORCH_HIGHER_THAN_PTA6 +#include +#include +#else +#include +#include +#endif + +#include + +#include + +#include "atb/atb_infer.h" +#include "atb_speed/base/hosttensor_binder.h" +#include "atb_speed/base/model.h" +#include "atb_speed/log.h" +#include "atb_speed/utils/model_factory.h" +#include "core/framework/model/model_args.h" +#include "core/framework/model/model_input_params.h" +#include "core/framework/state_dict/state_dict.h" +#include "nlohmann/json.hpp" +#include "npu_base_layer.h" +#include "pytorch/adapter/utils/utils.h" +#include "xllm_kernels/models/glm4v/glm4v_encoder.h" + +namespace xllm { +namespace layer { + +// copy from qwen3 vl, please follow its modifications +class NpuGlm4VisionEncoderLayerImpl : public NpuBaseLayer { + public: + explicit NpuGlm4VisionEncoderLayerImpl(const ModelContext& context); + + ~NpuGlm4VisionEncoderLayerImpl() {}; + + void load_state_dict(const StateDict& state_dict) override; + + void verify_loaded_weights() const override; + + void merge_loaded_weights() override; + + int64_t init_layer() override; + + torch::Tensor forward(torch::Tensor& x, + torch::Tensor& cos_pos, + torch::Tensor& sin_pos, + torch::Tensor& cu_seqlen, + std::vector& cu_seqlen_vec, + ModelInputParams& input_params, + int node_id = 0, + aclrtEvent* event = nullptr, + std::atomic* event_flag = nullptr); + void build_node_variant_pack(atb_speed::Model::Node& node, + torch::Tensor& x, + torch::Tensor& cos_pos, + torch::Tensor& sin_pos, + torch::Tensor& cu_seqlen, + std::vector& cu_seqlen_vec, + ModelInputParams& input_params, + bool is_prefill); + + void get_weights_col_packed_qkv(); + + void param_from_args(atb_speed::glm::VisionEncoderLayerParam& param, + const ModelArgs& args, + const ParallelArgs& parallel_args); + + int64_t init_node(atb_speed::Model::Node& node, + atb_speed::glm::VisionEncoderLayerParam& param); + + void pad_qkv_weights(); + + void pad_mlp_weights(); + + torch::Tensor pad_tensor(const torch::Tensor& tensor, + int64_t target_shape, + int64_t dim = 0) { + int64_t pad_size = target_shape - tensor.size(dim); + if (tensor.dim() == 1) { + return torch::nn::functional::pad( + tensor, torch::nn::functional::PadFuncOptions({0, pad_size})); + } else if (tensor.dim() == 2) { + if (1 == dim) + return torch::nn::functional::pad( + tensor, torch::nn::functional::PadFuncOptions({0, pad_size, 0, 0})); + else + return torch::nn::functional::pad( + tensor, torch::nn::functional::PadFuncOptions({0, 0, 0, pad_size})); + } + return tensor; + } + + atb_speed::Model::Node encode_node_; + std::string model_name_; + + atb_speed::glm::VisionEncoderLayerParam encode_param_; + atb::Tensor internal_tensors_; + atb::Tensor placeholder_; + at::Tensor cu_seqlen_; + at::Tensor at_placeholder_; + int device_id_; +}; + +} // namespace layer +} // namespace xllm \ No newline at end of file diff --git a/xllm/core/runtime/vlm_master.cpp b/xllm/core/runtime/vlm_master.cpp index a031f2aa7..6eb12f1eb 100755 --- a/xllm/core/runtime/vlm_master.cpp +++ b/xllm/core/runtime/vlm_master.cpp @@ -220,7 +220,11 @@ void VLMMaster::handle_request(const std::vector& messages, "Image processor process failed."); return; } - + if (const auto& res = mm_data.get("image_grid_thw")) + { + auto image_grid_thw = res.value(); + LOG(INFO)<<"image_grid_thw:"<handle_request(messages, mm_data, sp, callback); } @@ -307,7 +311,6 @@ std::shared_ptr VLMMaster::generate_request(std::string prompt, } Timer timer; 
input_processor_->process(prompt, mm_data); - std::vector prompt_tokens; if (!tokenizer_->encode(prompt, &prompt_tokens)) { LOG(ERROR) << "Failed to encode prompt: " << prompt; diff --git a/xllm/models/llm/glm4.h b/xllm/models/llm/glm4.h new file mode 100644 index 000000000..3fae6a20d --- /dev/null +++ b/xllm/models/llm/glm4.h @@ -0,0 +1,219 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#pragma once + +#include "core/layers/glm4_decoder_layer.h" +#include "llm_model_base.h" + +namespace xllm { + +class Glm4DecoderLayerImpl + : public LlmDecoderLayerImplBase { + public: + Glm4DecoderLayerImpl(const ModelContext& context) + : LlmDecoderLayerImplBase(context) {} +}; +TORCH_MODULE(Glm4DecoderLayer); + +class Glm4ModelImpl : public LlmModelImplBase { + public: + Glm4ModelImpl(const ModelContext& context) + : LlmModelImplBase("glm4", context.get_model_args()) { + // register submodules + auto model_args = context.get_model_args(); + auto options = context.get_tensor_options(); + auto parallel_args = context.get_parallel_args(); + auto dp_local_tp_size = + parallel_args.world_size() / parallel_args.dp_size(); + dp_rank_ = parallel_args.rank() / dp_local_tp_size; + + blocks_ = register_module("layers", torch::nn::ModuleList()); + layers_.reserve(model_args.n_layers()); + norm_ = register_module("norm", layer::RmsNorm(context)); + embed_tokens_ = + register_module("embed_tokens", layer::WordEmbedding(context)); +#if defined(USE_NPU) + atb_pos_emb_ = layer::PosEmbedding(context); +#endif + cos_sin_ = + get_chatglm_rotary_embedding(64, + model_args.max_position_embeddings(), + model_args.rope_theta(), + options); +#if defined(USE_NPU) + int32_t mask_value = FLAGS_enable_chunked_prefill ? 
-9984 : 1; + attn_mask_ = layer::AttentionMask(options.device(), + options.dtype().toScalarType(), + /*mask_value=*/mask_value); +#endif + + for (int32_t i = 0; i < model_args.n_layers(); i++) { + auto block = Glm4DecoderLayer(context); + layers_.push_back(block); + blocks_->push_back(block); + } + } + + virtual torch::Tensor forward(torch::Tensor tokens, + torch::Tensor positions, + std::vector& kv_caches, + const ModelInputParams& input_params) { + ModelInputParams& input_params_new = + const_cast(input_params); + + if (tokens.numel() == 0) { + tokens = torch::tensor({1}).to(torch::kInt32).to(tokens.device()); + positions = torch::tensor({0}).to(torch::kInt32).to(tokens.device()); + } + auto inputs_embeds = input_params.input_embedding; + torch::Tensor h; + if (inputs_embeds.defined()) { + h = inputs_embeds; + } else { + h = embed_tokens_(tokens, 0); + } + auto target_cos_sin = atb_pos_emb_(cos_sin_, positions, 0); + auto target_cos_sin_chunks = target_cos_sin.chunk(/*chunks=*/2, /*dim=*/-1); + auto cos_pos = target_cos_sin_chunks[0].contiguous(); + + auto sin_pos = target_cos_sin_chunks[1].contiguous(); + + if (positions.dim() == 2) { // mrope + auto apply = [this](torch::Tensor x) { + auto sections = mrope_section_; + sections.insert(sections.end(), sections.begin(), sections.end()); + + auto vec = x.split(sections, -1); + std::vector selects; + selects.reserve(vec.size()); + + for (int64_t i = 0; i < vec.size(); ++i) { + auto m = vec[i]; + selects.push_back(m[i % mrope_section_.size()]); + } + return torch::cat(selects, -1); + }; + cos_pos = apply(cos_pos.reshape( + {positions.sizes().front(), -1, cos_pos.sizes().back()})); + sin_pos = apply(sin_pos.reshape( + {positions.sizes().front(), -1, sin_pos.sizes().back()})); + } + cos_pos = cos_pos.reshape({-1, cos_pos.sizes().back() /2, 2}); + sin_pos = sin_pos.reshape({-1, sin_pos.sizes().back() /2, 2}); + torch::Tensor attn_mask; + if (FLAGS_enable_chunked_prefill) { + int max_kv_seq = input_params.kv_max_seq_len; + int num_sequences = input_params.num_sequences; + if (num_sequences > 0) { + std::vector req_mask_vec; + req_mask_vec.reserve(num_sequences); + + for (int j = 0; j < num_sequences; j++) { + auto mask = + attn_mask_.gen_append_mask(input_params.q_seq_lens_vec[j], + input_params.kv_seq_lens_vec[j], + max_kv_seq, + cos_pos.dtype().toScalarType(), + cos_pos.device()); + req_mask_vec.emplace_back(mask); + } + attn_mask = torch::cat(req_mask_vec, 0); + } + } else { + if (FLAGS_num_speculative_tokens == 0 || + input_params.global_empty_kv_cache) { + attn_mask = attn_mask_.get_attn_mask( + 128, cos_pos.dtype().toScalarType(), cos_pos.device()); + } else { + attn_mask = attn_mask_.gen_free_mask(FLAGS_num_speculative_tokens + 1, + cos_pos.dtype().toScalarType(), + cos_pos.device()); + } + } + + for (size_t i = 0; i < layers_.size(); i++) { + aclrtEvent* event{nullptr}; + std::atomic* event_flag{nullptr}; + + if (input_params.layer_synchronizer != nullptr) { + event = input_params.layer_synchronizer->get_event(i); + event_flag = input_params.layer_synchronizer->get_event_flag(i); + } + if (input_params.layer_wise_load_synchronizer != nullptr) { + if (!input_params.layer_wise_load_synchronizer->synchronize_layer(i)) { + return torch::Tensor(); + } + } + + auto& layer = layers_[i]; + + layer(h, + cos_pos, + sin_pos, + attn_mask, + kv_caches[i], + input_params_new, + i, + event, + event_flag); + } + return norm_(h, 0); + } + + private: + torch::Tensor viusal_pos_mask_; +}; +TORCH_MODULE(Glm4Model); + +class Glm4ForCausalLMImpl : public 
LlmForCausalLMImplBase { + public: + Glm4ForCausalLMImpl(const ModelContext& context) + : LlmForCausalLMImplBase(context) {} +}; +TORCH_MODULE(Glm4ForCausalLM); + +// register the causal model +REGISTER_CAUSAL_MODEL(glm4, Glm4ForCausalLM); + +// register the model args +REGISTER_MODEL_ARGS(glm4, [&] { + LOAD_ARG_OR(model_type, "model_type", "glm4"); + + LOAD_ARG_OR(dtype, "torch_dtype", ""); + LOAD_ARG_OR(attention_bias, "attention_bias", true); + LOAD_ARG_OR(attention_dropout, "attention_dropout", 0.0f); + LOAD_ARG_OR(eos_token_id_vec, "eos_token_id", std::vector{151329}); + LOAD_ARG_OR(head_dim, "head_dim", 128); + LOAD_ARG_OR(hidden_act, "hidden_act", "silu"); + LOAD_ARG_OR(hidden_size, "hidden_size", 4096); + LOAD_ARG_OR(initializer_range, "initializer_range", 0.02f); + LOAD_ARG_OR(intermediate_size, "intermediate_size", 13696); + LOAD_ARG_OR(max_position_embeddings, "max_position_embeddings", 32768); + LOAD_ARG_OR(n_heads, "num_attention_heads", 32); + LOAD_ARG_OR(n_layers, "num_hidden_layers", 40); + LOAD_ARG_OR(n_kv_heads, "num_key_value_heads", 2); + LOAD_ARG_OR(pad_token_id, "pad_token_id", 151329); + LOAD_ARG_OR(rms_norm_eps, "rms_norm_eps", 1e-5); + LOAD_ARG_OR(rope_theta, "rope_theta", 10000.0f); + LOAD_ARG_OR(tie_word_embeddings, "tie_word_embeddings", false); + LOAD_ARG_OR(vocab_size, "vocab_size", 151552); + + SET_ARG(stop_token_ids, + std::unordered_set(args->eos_token_id_vec().begin(), + args->eos_token_id_vec().end())); +}); + +} // namespace xllm diff --git a/xllm/models/llm/glm4_moe.h b/xllm/models/llm/glm4_moe.h index f9e582838..401931ceb 100644 --- a/xllm/models/llm/glm4_moe.h +++ b/xllm/models/llm/glm4_moe.h @@ -117,6 +117,14 @@ class Glm4MoeModelImpl : public torch::nn::Module { } } + torch::Tensor get_input_embeddings(torch::Tensor input_ids) { +#if defined(USE_NPU) + return embed_tokens_(input_ids, 0); +#else + return embed_tokens_(input_ids); +#endif + } + // tokens: [num_tokens] // positions: [num_tokens] token pos in the sequence torch::Tensor forward(torch::Tensor tokens, @@ -260,6 +268,10 @@ class Glm4MoeForCausalLMImpl : public torch::nn::Module { lm_head_ = register_module("lm_head", layer::LmHead(context)); } + torch::Tensor get_input_embeddings(torch::Tensor input_ids) { + return model_->get_input_embeddings(input_ids); + } + // tokens: [num_tokens] // positions: [num_tokens] token pos in the sequence // returns: [num_tokens, hidden_size] @@ -280,14 +292,15 @@ class Glm4MoeForCausalLMImpl : public torch::nn::Module { return lm_head_(hidden_states, seleted_idxes, 0); } - void load_model(std::unique_ptr loader) { + void load_model(std::unique_ptr loader, + std::string prefix = "model." /*llm model weight prefix*/) { for (const auto& state_dict : loader->get_state_dicts()) { - model_->load_state_dict(state_dict->get_dict_with_prefix("model.")); + model_->load_state_dict(state_dict->get_dict_with_prefix(prefix)); lm_head_->load_state_dict(state_dict->get_dict_with_prefix("lm_head.")); } // verify - model_->verify_loaded_weights("model."); + model_->verify_loaded_weights(prefix); lm_head_->verify_loaded_weights("lm_head."); model_->merge_loaded_weights(); diff --git a/xllm/models/llm/llm_model_base.h b/xllm/models/llm/llm_model_base.h index 441b52004..6c380cde0 100644 --- a/xllm/models/llm/llm_model_base.h +++ b/xllm/models/llm/llm_model_base.h @@ -50,10 +50,11 @@ limitations under the License. 
namespace xllm { -torch::Tensor get_concat_rotary_embedding(int64_t dim, - int64_t seq_len, - double rope_theta, - const torch::TensorOptions& options) { +torch::Tensor compute_rotary_embedding(int64_t dim, + int64_t seq_len, + double rope_theta, + const torch::TensorOptions& options, + bool use_cat) { auto options_new = torch::device(options.device()).dtype(at::ScalarType::Double); auto inv_freq = @@ -62,7 +63,13 @@ torch::Tensor get_concat_rotary_embedding(int64_t dim, auto seq_idx = torch::arange(seq_len, options_new); auto freqs = torch::ger(seq_idx, inv_freq).to(torch::kFloat32); - auto emb = torch::cat({freqs, freqs}, -1); + torch::Tensor emb; + if (use_cat) { + emb = torch::cat({freqs, freqs}, -1); + } else { + emb = torch::stack({freqs, freqs}, -1); + emb = emb.reshape({seq_len, dim}); + } auto rope_cos = torch::cos(emb); auto rope_sin = torch::sin(emb); @@ -81,6 +88,21 @@ torch::Tensor get_concat_rotary_embedding(int64_t dim, return torch::cat(cos_sin, -1); } +torch::Tensor get_concat_rotary_embedding(int64_t dim, + int64_t seq_len, + double rope_theta, + const torch::TensorOptions& options) { + return compute_rotary_embedding(dim, seq_len, rope_theta, options, true); +} + +torch::Tensor get_chatglm_rotary_embedding( + int64_t dim, + int64_t seq_len, + double rope_theta, + const torch::TensorOptions& options) { + return compute_rotary_embedding(dim, seq_len, rope_theta, options, false); +} + template class LlmDecoderLayerImplBase : public torch::nn::Module { public: diff --git a/xllm/models/models.h b/xllm/models/models.h index 0460d6ff5..9fbc3ea45 100644 --- a/xllm/models/models.h +++ b/xllm/models/models.h @@ -27,12 +27,15 @@ limitations under the License. #include "llm/deepseek_v2.h" // IWYU pragma: keep #include "llm/deepseek_v2_mtp.h" // IWYU pragma: keep #include "llm/deepseek_v3.h" // IWYU pragma: keep +#include "llm/glm4.h" // IWYU pragma: keep #include "llm/glm4_moe.h" // IWYU pragma: keep #include "llm/glm4_moe_mtp.h" // IWYU pragma: keep #include "llm/kimi_k2.h" // IWYU pragma: keep #include "llm/llama.h" // IWYU pragma: keep #include "llm/llama3.h" // IWYU pragma: keep #include "llm/qwen3_embedding.h" // IWYU pragma: keep +#include "vlm/glm4v.h" // IWYU pragma: keep +#include "vlm/glm4v_moe.h" // IWYU pragma: keep #include "vlm/minicpmv.h" // IWYU pragma: keep #include "vlm/qwen2_5_vl.h" // IWYU pragma: keep #include "vlm/qwen3_vl.h" // IWYU pragma: keep diff --git a/xllm/models/vlm/glm4v.h b/xllm/models/vlm/glm4v.h new file mode 100644 index 000000000..8ef661a3c --- /dev/null +++ b/xllm/models/vlm/glm4v.h @@ -0,0 +1,958 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License.
+==============================================================================*/ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include "core/framework/kv_cache/kv_cache.h" +#include "core/framework/model/model_input_params.h" +#include "core/layers/lm_head.h" +#include "models/model_registry.h" +#include "processors/input_processor.h" +#include "xllm_kernels/core/include/atb_speed/log.h" +#include "models/llm/glm4.h" +#include "xllm/core/layers/glm4_vision_encode_layer.h" +#include "torch_npu/csrc/aten/CustomFunctions.h" + + +namespace xllm { + +class GLM4_6_VLInputProcessor : public InputProcessor { + enum class TokenType { + INVALID, + IMAGE, + VIDEO, + }; + + public: + GLM4_6_VLInputProcessor(const ModelArgs& args) { + merge_size_ = args.mm_image_merge_size(); + } + + void process(std::string& prompt, const MMData& mm_data) override { + torch::Tensor image_grid_thw; + if (auto res = mm_data.get("image_grid_thw")) + image_grid_thw = res.value(); + + torch::Tensor video_grid_thw; + if (auto res = mm_data.get("video_grid_thw")) + video_grid_thw = res.value(); + + if (!image_grid_thw.defined() && !video_grid_thw.defined()) return; + + const auto& video_metadata = mm_data.get_video_metadata(); + if (video_metadata.size() > 0) { + CHECK(video_metadata.size() == + static_cast(video_grid_thw.sizes()[0])); + } + + auto merge_length = merge_size_ * merge_size_; + int total_image_token = 0; + + if (image_grid_thw.defined()) { + auto count = image_grid_thw.sizes()[0]; + for (int idx = 0; idx < count; ++idx) + total_image_token += + image_grid_thw[idx].prod().item() / merge_length; + } + + int total_video_token = 0; + if (video_grid_thw.defined()) { + auto count = video_grid_thw.sizes()[0]; + for (int idx = 0; idx < count; ++idx) + total_video_token += video_grid_thw[idx].prod().item() / + merge_length / video_grid_thw[idx][0].item(); + } + + size_t total_token_len = total_image_token * image_token_.size() + + total_video_token * image_token_.size(); + std::string data; + data.reserve(prompt.size() + total_token_len); + + int image_index = 0; + int video_index = 0; + + size_t begin = 0; + auto pair = find_vision_token(prompt, begin); + + while (pair.second != std::string::npos) { + data.append(prompt, begin, pair.second - begin); + + if (pair.first == TokenType::IMAGE) { + auto token_num = + image_grid_thw[image_index].prod().item() / merge_length; + while (token_num--) data.append(image_token_); + + image_index++; + begin = pair.second + image_token_.size(); + } else if (pair.first == TokenType::VIDEO) { + auto num_frames = video_grid_thw[video_index][0].item(); + auto timestamps = video_metadata[video_index].timestamps; + CHECK(!timestamps.empty()); + + auto selected = build_timestamps(timestamps, num_frames); + auto token_num = video_grid_thw[video_index].prod().item() / + merge_length / num_frames; + + for (size_t idx = 0; idx < num_frames; ++idx) { + data.append(begin_of_image_token_); + + auto num = token_num; + while (num--) data.append(image_token_); + + data.append(end_of_image_token_); + data.append(format_timestamp_str(selected[idx])); + } + + video_index++; + begin = pair.second + video_token_.size(); + } else { + assert(false); + } + + pair = find_vision_token(prompt, begin); + } + + if (begin < prompt.size()) data.append(prompt, begin, std::string::npos); + + prompt = std::move(data); + } + + private: + std::pair find_vision_token(const std::string& prompt, + size_t begin) { + auto img_pos = prompt.find(image_token_, begin); + auto vid_pos = 
prompt.find(video_token_, begin); + + if (img_pos == std::string::npos && vid_pos == std::string::npos) + return {TokenType::INVALID, std::string::npos}; + else if (vid_pos == std::string::npos) + return {TokenType::IMAGE, img_pos}; + else if (img_pos == std::string::npos) + return {TokenType::VIDEO, vid_pos}; + else + return img_pos < vid_pos ? std::make_pair(TokenType::IMAGE, img_pos) + : std::make_pair(TokenType::VIDEO, vid_pos); + } + + std::vector build_timestamps(const std::vector& timestamps, + size_t num_frames) { + std::vector vec; + vec.reserve(num_frames); + + for (size_t i = 0; i < timestamps.size(); i += 2) { + vec.push_back(timestamps[i]); + if (vec.size() == num_frames) break; + } + + while (vec.size() < num_frames) { + vec.push_back(vec.back()); + } + + return vec; + } + + std::string format_timestamp_str(double timestamp) { + char buffer[32]; + sprintf(buffer, "%.1f seconds", timestamp); + return buffer; + } + + private: + const std::string image_token_ = "<|image|>"; + const std::string video_token_ = "<|video|>"; + + const std::string begin_of_image_token_ = "<|begin_of_image|>"; + const std::string end_of_image_token_ = "<|end_of_image|>"; + + int merge_size_ = 0; +}; + +class Glm4VisionRmsNormImpl : public torch::nn::Module { + public: + torch::Tensor weight; + Glm4VisionRmsNormImpl(const ModelContext& context){ + auto model_args = context.get_model_args(); + auto options = context.get_tensor_options(); + weight = torch::empty({model_args.mm_hidden_size()}, options); + epsilon_ = 1e-5; + } + + torch::Tensor forward(torch::Tensor& x){ + auto results = at_npu::native::custom_ops::npu_rms_norm(x, weight, epsilon_); + return std::get<0>(results); + } + private: + double epsilon_; +}; +TORCH_MODULE(Glm4VisionRmsNorm); + +class Glm4VisionPatchEmbedImpl : public torch::nn::Module { + public: + Glm4VisionPatchEmbedImpl(const ModelContext& context) { + auto model_args = context.get_model_args(); + auto options = context.get_tensor_options(); + + auto in_features = model_args.mm_num_channels() * + model_args.mm_temporal_patch_size() * + model_args.mm_patch_size() * model_args.mm_patch_size(); + + auto out_features = model_args.mm_hidden_size(); + + proj_ = register_module( + "proj", + torch::nn::Linear( + torch::nn::LinearOptions(in_features, out_features).bias(true))); + + proj_->weight.set_data(proj_->weight.to(options)); + proj_->bias.set_data(proj_->bias.to(options)); + } + + torch::Tensor forward(torch::Tensor x) { return proj_(x); } + + void load_state_dict(const StateDict& state_dict) { + auto weight = state_dict.get_tensor("proj.weight"); + if (weight.defined()) { + weight = weight.reshape({weight.size(0), -1}); + DCHECK_EQ(proj_->weight.sizes(), weight.sizes()) + << "proj weight size mismatch for " << name(); + proj_->weight.data().copy_(weight); + proj_weight_loaded_ = true; + } + auto bias = state_dict.get_tensor("proj.bias"); + if (bias.defined()) { + bias = bias.reshape({bias.size(0)}); + DCHECK_EQ(proj_->bias.sizes(), bias.sizes()) + << "proj bias size mismatch for " << name(); + proj_->bias.data().copy_(bias); + proj_bias_loaded_ = true; + } + } + + void verify_loaded_weights(const std::string& prefix) const { + CHECK(proj_weight_loaded_) + << "weight is not loaded for " << prefix + "proj.weight"; + CHECK(proj_bias_loaded_) + << "bias is not loaded for " << prefix + "proj.bias"; + } + + private: + bool proj_weight_loaded_ = false; + bool proj_bias_loaded_ = false; + torch::nn::Linear proj_{nullptr}; +}; +TORCH_MODULE(Glm4VisionPatchEmbed); + +class 
Glm4_VisionBlockImpl : public torch::nn::Module { + public: + Glm4_VisionBlockImpl(const ModelContext& context) { + // register submodules + encoder_layer_ = register_module("encoder_layer", + layer::Glm4VisionEncoderLayer(context)); + } + //TO DO + torch::Tensor forward(torch::Tensor& x, + torch::Tensor& m_cos_pos, + torch::Tensor& m_sin_pos, + torch::Tensor& cu_seq_len, + std::vector& cu_seq_len_vec, + ModelInputParams& input_params, + int node_id) { + return encoder_layer_(x, + m_cos_pos, + m_sin_pos, + cu_seq_len, + cu_seq_len_vec, + input_params, + node_id); + } + + // load the weight from the checkpoint + void load_state_dict(const StateDict& state_dict) { + // call each submodule's load_state_dict function + encoder_layer_->load_state_dict(state_dict); + } + + void verify_loaded_weights(const std::string& prefix) const { + encoder_layer_->verify_loaded_weights(); + } + void merge_loaded_weights() { encoder_layer_->merge_loaded_weights(); } + + private: + layer::Glm4VisionEncoderLayer encoder_layer_{nullptr}; +}; +TORCH_MODULE(Glm4_VisionBlock); + +class Glm4VisionRotaryEmbeddingImpl : public torch::nn::Module { + public: + Glm4VisionRotaryEmbeddingImpl(const ModelContext& context) { + auto model_args = context.get_model_args(); + auto options = context.get_tensor_options(); + + dim_ = model_args.mm_head_dim() / 2; + theta_ = 10000.0; + + auto opts = options.dtype(torch::kFloat32); + auto inv_freq = + 1.0 / torch::pow(theta_, torch::arange(0, dim_, 2, opts) / dim_); + inv_freq_ = register_buffer("inv_freq", inv_freq); + } + + void update_freqs_cache(int64_t seqlen) { + if (seqlen <= seq_len_cached_) return; + + seqlen *= 2; + seq_len_cached_ = seqlen; + + auto options = torch::TensorOptions() + .dtype(torch::kFloat32) + .device(inv_freq_.device()); + inv_freq_ = + 1.0 / torch::pow(theta_, torch::arange(0, dim_, 2, options) / dim_); + auto seq = torch::arange(seqlen, options); + freqs_cached_ = torch::outer(seq, inv_freq_); + } + + torch::Tensor forward(int seqlen) { + update_freqs_cache(seqlen); + return freqs_cached_.slice(0, 0, seqlen); + } + + private: + int dim_ = 0; + double theta_ = 0.0; + + int64_t seq_len_cached_ = 0; + torch::Tensor inv_freq_; + torch::Tensor freqs_cached_; +}; +TORCH_MODULE(Glm4VisionRotaryEmbedding); + +class Glm4vVisionEmbeddingsImpl : public torch::nn::Module { + public: + Glm4vVisionEmbeddingsImpl(const ModelContext& context) { + auto model_args = context.get_model_args(); + auto options = context.get_tensor_options(); + embed_dim_ = model_args.mm_hidden_size(); + image_size_ = model_args.mm_image_size(); + patch_size_ = model_args.mm_patch_size(); + num_positions_ = image_size_ / patch_size_; + num_positions_ = num_positions_ * num_positions_; + position_embedding_ = register_module( + "position_embedding", + torch::nn::Embedding(num_positions_, embed_dim_) + ); + position_embedding_->weight.set_data(position_embedding_->weight.to(options)); + } + torch::Tensor forward( + torch::Tensor x, + std::vector lengths, + torch::Tensor image_shapes, + torch::Tensor h_coords, + torch::Tensor w_coords + ) { + const auto& pos_embed_weight = position_embedding_->weight; + const int64_t hidden_size = pos_embed_weight.size(1); + const int64_t total_seq = x.size(0); + const auto device = pos_embed_weight.device(); + const auto dtype = pos_embed_weight.dtype(); + + image_shapes = image_shapes.to(device); + h_coords = h_coords.to(device); + w_coords = w_coords.to(device); + x = x.to(device, dtype); + + torch::Tensor adapted_pos_embed; + if (total_seq == 0) { + 
adapted_pos_embed = torch::empty( + {0, hidden_size}, + torch::TensorOptions().device(device).dtype(dtype) + ); + } else { + const int64_t batch_size = static_cast(lengths.size()); + const int64_t orig_size_sq = pos_embed_weight.size(0); + const int64_t orig_size = static_cast(std::sqrt(orig_size_sq)); + auto pos_embed_2d = pos_embed_weight + .view({orig_size, orig_size, hidden_size}) + .permute({2, 0, 1}) + .unsqueeze(0) + .to(torch::kFloat32); + + std::vector target_h_list; + std::vector target_w_list; + target_h_list.reserve(batch_size); + target_w_list.reserve(batch_size); + LOG(INFO) << " Glm4vVisionEmbeddingsImpl forward batch_size: " << batch_size << "image_shapes " << image_shapes; + for (int64_t i = 0; i < batch_size; ++i) { + const int64_t seq_len = lengths[i]; + const auto img_h = image_shapes.index({i, 1}).to(torch::kFloat32); + const auto img_w = image_shapes.index({i, 2}).to(torch::kFloat32); + LOG(INFO) << " Glm4vVisionEmbeddingsImpl forward batch_size idx " << i; + target_h_list.push_back(img_h.repeat({seq_len})); + target_w_list.push_back(img_w.repeat({seq_len})); + } + + auto target_h = torch::cat(target_h_list, 0); + auto target_w = torch::cat(target_w_list, 0); + + auto h_coords_fp32 = h_coords.to(torch::kFloat32); + auto w_coords_fp32 = w_coords.to(torch::kFloat32); + + const auto norm_w = ((w_coords_fp32 + 0.5f) / target_w) * 2.0f - 1.0f; + const auto norm_h = ((h_coords_fp32 + 0.5f) / target_h) * 2.0f - 1.0f; + LOG(INFO) << " Glm4vVisionEmbeddingsImpl stack"; + auto grid = torch::stack({norm_w, norm_h}, -1) + .unsqueeze(0) + .unsqueeze(2); + LOG(INFO) << " Glm4vVisionEmbeddingsImpl stack after"; + namespace F = torch::nn::functional; + auto interpolated_embed = F::grid_sample( + pos_embed_2d, + grid, + F::GridSampleFuncOptions().mode(torch::kBicubic).padding_mode(torch::kBorder).align_corners(false)); + LOG(INFO) << " Glm4vVisionEmbeddingsImpl interpolated_embed"; + adapted_pos_embed = interpolated_embed + .squeeze(0) + .squeeze(-1) + .permute({1, 0}) + .to(dtype); + } + + return x + adapted_pos_embed; + } + + void load_state_dict(const StateDict& state_dict) { + auto weight = state_dict.get_tensor("position_embedding.weight"); + if (weight.defined()) { + position_embedding_->weight.data().copy_(weight); + position_embedding_weight_loaded_ = true; + } + } + + void verify_loaded_weights(const std::string& prefix) const { + CHECK(position_embedding_weight_loaded_) + << "weight is not loaded for " << prefix + "position_embedding.weight"; + } + private: + int64_t embed_dim_ = 0; + int64_t image_size_ = 0; + int64_t patch_size_ = 0 ; + int64_t num_positions_ = 0; + bool position_embedding_weight_loaded_ = false; + torch::nn::Embedding position_embedding_{nullptr}; +}; +TORCH_MODULE(Glm4vVisionEmbeddings); + +class Glm4_VisionPatchMergerImpl : public torch::nn::Module { + public: + Glm4_VisionPatchMergerImpl(const ModelContext& context) { + auto model_args = context.get_model_args(); + options_ = context.get_tensor_options(); + auto parallel_args = context.get_parallel_args(); + int64_t dim = model_args.mm_projection_dim(); + int64_t context_dim = model_args.mm_intermediate_size(); + norm_ = register_module("norm", torch::nn::LayerNorm(torch::nn::LayerNormOptions({dim}))); + norm_->weight.set_data(norm_->weight.to(options_)); + norm_->bias.set_data(norm_->bias.to(options_)); + proj_ = register_module( + "proj", + torch::nn::Linear(torch::nn::LinearOptions(dim, dim).bias(false))); + proj_->weight.set_data(proj_->weight.to(options_)); + act_ = register_module("act", 
torch::nn::GELU()); + silu_ = register_module("silu", torch::nn::SiLU()); + + gate_ = register_module( + "gate", + torch::nn::Linear(torch::nn::LinearOptions(dim, context_dim).bias(false))); + gate_->weight.set_data(gate_->weight.to(options_)); + up_ = register_module( + "up", + torch::nn::Linear(torch::nn::LinearOptions(dim, context_dim).bias(false))); + up_->weight.set_data(up_->weight.to(options_)); + down_ = register_module( + "down", + torch::nn::Linear(torch::nn::LinearOptions(context_dim, dim).bias(false))); + down_->weight.set_data(down_->weight.to(options_)); + } + + torch::Tensor forward(torch::Tensor x) { + LOG(INFO) << " Glm4_VisionPatchMergerImpl forward beging " << x.device() << "options_.device() : " << options_.device(); + x = proj_(x); + x = act_(norm_(x)); + x = down_(torch::mul(silu_((gate_(x))), up_(x))); + return x; + } + + void load_state_dict(const StateDict& state_dict) { + // norm + const auto& norm_dict = state_dict.get_dict_with_prefix("post_projection_norm."); + const auto& norm_weight = norm_dict.get_tensor("weight"); + if (norm_weight.defined()) { + CHECK_EQ(norm_->weight.sizes(), norm_weight.sizes()) + << "weight size mismatch for " << name(); + norm_->weight.data().copy_(norm_weight); + is_norm_weight_loaded = true; + } + const auto norm_bias = norm_dict.get_tensor("bias"); + if (norm_bias.defined()) { + CHECK_EQ(norm_->bias.sizes(), norm_bias.sizes()) + << "bias size mismatch for " << name(); + norm_->bias.data().copy_(norm_bias); + is_norm_bias_loaded = true; + } + + const auto& proj_dict = state_dict.get_dict_with_prefix("proj."); + const auto& proj_weight = proj_dict.get_tensor("weight"); + if (proj_weight.defined()) { + proj_->weight.data().copy_(proj_weight); + is_proj_weight_loaded = true; + } + + const auto& up_dict = state_dict.get_dict_with_prefix("up_proj."); + const auto& up_weight = up_dict.get_tensor("weight"); + if (up_weight.defined()) { + up_->weight.data().copy_(up_weight); + is_up_weight_loaded = true; + } + + const auto& down_dict = state_dict.get_dict_with_prefix("down_proj."); + const auto& down_weight = down_dict.get_tensor("weight"); + if (down_weight.defined()) { + down_->weight.data().copy_(down_weight); + is_down_weight_loaded = true; + } + + const auto& gate_dict = state_dict.get_dict_with_prefix("gate_proj."); + const auto& gate_weight = gate_dict.get_tensor("weight"); + if (gate_weight.defined()) { + gate_->weight.data().copy_(gate_weight); + is_gate_weight_loaded = true; + } + } + + void verify_loaded_weights(const std::string& prefix) const { + CHECK(is_proj_weight_loaded) + << "weight is not loaded for " << prefix + "proj_weight" + ".weight"; + CHECK(is_up_weight_loaded) + << "weight is not loaded for " << prefix + "up_weight" + ".weight"; + CHECK(is_down_weight_loaded) + << "weight is not loaded for " << prefix + "down_weight" + ".weight"; + CHECK(is_gate_weight_loaded) + << "weight is not loaded for " << prefix + "gate_weight" + ".weight"; + CHECK(is_norm_weight_loaded) + << "weight is not loaded for " << prefix + "norm" + ".weight"; + CHECK(is_norm_bias_loaded) + << "bias is not loaded for " << prefix + "norm" + ".bias"; + } + + private: + torch::nn::LayerNorm norm_{nullptr}; + torch::nn::Linear proj_{nullptr}; + torch::nn::Linear up_{nullptr}; + torch::nn::Linear gate_{nullptr}; + torch::nn::Linear down_{nullptr}; + torch::nn::GELU act_{nullptr}; + torch::nn::SiLU silu_{nullptr}; + torch::TensorOptions options_; + + + bool is_proj_weight_loaded = false; + bool is_up_weight_loaded = false; + bool is_down_weight_loaded = 
false; + bool is_gate_weight_loaded = false; + bool is_norm_weight_loaded = false; + bool is_norm_bias_loaded = false; +}; +TORCH_MODULE(Glm4_VisionPatchMerger); + +class Glm4VisionTransformerImpl : public torch::nn::Module { + public: + Glm4VisionTransformerImpl(const ModelContext& context): options_(context.get_tensor_options()) { + auto model_args = context.get_model_args(); + spatial_merge_size_ = model_args.mm_spatial_merge_size(); + hidden_size_ = model_args.mm_hidden_size(); + out_hidden_size_ = model_args.mm_projection_dim(); + + patch_embed_ = + register_module("patch_embed", Glm4VisionPatchEmbed(context)); + rotary_pos_emb_ = + register_module("rotary_pos_emb", Glm4VisionRotaryEmbedding(context)); + post_conv_layernorm_ = register_module("post_conv_layernorm", Glm4VisionRmsNorm(context)); + + embeddings_ = register_module("embeddings", Glm4vVisionEmbeddings(context)); + + blocks_ = register_module("blocks", torch::nn::ModuleList()); + + for (int32_t idx = 0; idx < model_args.mm_num_hidden_layers(); idx++) { + auto block = Glm4_VisionBlock(context); + blocks_->push_back(block); + layers_.push_back(block); + } + post_layernorm_ = register_module("post_layernorm", Glm4VisionRmsNorm(context)); + + downsample_ = register_module("downsample", torch::nn::Conv2d(torch::nn::Conv2dOptions(hidden_size_, out_hidden_size_, spatial_merge_size_) + .stride(spatial_merge_size_).bias(true).padding(0))); + downsample_->weight.set_data(downsample_->weight.to(options_)); + downsample_->bias.set_data(downsample_->bias.to(options_)); + merger_ = register_module("merger", Glm4_VisionPatchMerger(context)); + + } + std::tuple rot_pos_emb(torch::Tensor grid_thw) { + std::vector pos_ids_vec; + auto count = grid_thw.sizes()[0]; + pos_ids_vec.reserve(count); + auto options = + torch::TensorOptions().dtype(torch::kLong).device(grid_thw.device()); + + auto grid_thw_cpu = grid_thw.cpu(); + for (int idx = 0; idx < count; ++idx) { + auto t = grid_thw_cpu[idx][0].item(); + auto h = grid_thw_cpu[idx][1].item(); + auto w = grid_thw_cpu[idx][2].item(); + auto hpos_ids = torch::arange(h, options).unsqueeze(1).expand({-1, w}); + hpos_ids = hpos_ids.reshape({h / spatial_merge_size_, + spatial_merge_size_, + w / spatial_merge_size_, + spatial_merge_size_}).permute({0, 2, 1, 3}).flatten(); + auto wpos_ids = torch::arange(w, options).unsqueeze(0).expand({h, -1}); + wpos_ids = wpos_ids.reshape({h / spatial_merge_size_, + spatial_merge_size_, + w / spatial_merge_size_, + spatial_merge_size_}).permute({0, 2, 1, 3}).flatten(); + pos_ids_vec.push_back(torch::stack({hpos_ids, wpos_ids}, -1).repeat({t, 1})); + } + auto pos_ids = torch::cat(pos_ids_vec, 0); + auto max_grid_size = grid_thw.index({torch::indexing::Slice(), + torch::indexing::Slice(1, torch::indexing::None)}).max(); + auto rotary_pos_emb_full = rotary_pos_emb_(max_grid_size.item()); + auto rotary_pos_emb = rotary_pos_emb_full.index({pos_ids}).flatten(1); + + return std::make_tuple(rotary_pos_emb, pos_ids); + } + + torch::Tensor forward( + torch::Tensor hidden_states, + torch::Tensor grid_thw, + const ModelInputParams& input_params) { + LOG(INFO) << " Glm4VisionTransformerImpl forward beging "; + hidden_states = patch_embed_(hidden_states); + LOG(INFO) << " Glm4VisionTransformerImpl patch_embed_ beging "; + hidden_states = post_conv_layernorm_(hidden_states); + LOG(INFO) << " Glm4VisionTransformerImpl post_conv_layernorm_ beging "; + + auto [rotary_pos_emb, image_type_ids] = rot_pos_emb(grid_thw); + auto emb = torch::cat({rotary_pos_emb, rotary_pos_emb}, -1); + auto 
m_cos = emb.cos().type_as(hidden_states); + auto m_sin = emb.sin().type_as(hidden_states); + + auto device = grid_thw.device(); + auto grid_t = grid_thw.index_select(1, torch::tensor({0}, torch::TensorOptions().dtype(torch::kInt).device(device))); + auto grid_h = grid_thw.index_select(1, torch::tensor({1}, torch::TensorOptions().dtype(torch::kInt).device(device))); + auto grid_w = grid_thw.index_select(1, torch::tensor({2}, torch::TensorOptions().dtype(torch::kInt).device(device))); + auto h_times_w = (grid_h * grid_w).squeeze(1); + auto repeats = grid_t.squeeze(1); + auto repeated = torch::repeat_interleave(h_times_w, repeats, 0); + c10::optional cumsum_dtype; + + cumsum_dtype = torch::kInt32; + auto cu_seqlens = torch::cumsum(repeated, 0, cumsum_dtype); + namespace F = torch::nn::functional; + cu_seqlens = F::pad(cu_seqlens, F::PadFuncOptions({1, 0}).mode(torch::kConstant).value(0)); + cu_seqlens = torch::diff(cu_seqlens).cpu().to(torch::kInt); + std::vector seqlens; + seqlens.assign(cu_seqlens.data_ptr(),cu_seqlens.data_ptr() + cu_seqlens.numel()); + + hidden_states = embeddings_(hidden_states, seqlens, grid_thw, image_type_ids.select(1, 0), image_type_ids.select(1, 1)); + ModelInputParams& input_params_new = const_cast(input_params); + torch::Tensor cu_seqlens_cpu = cu_seqlens.cpu(); + std::vector cu_seqlens_vec( + cu_seqlens_cpu.data_ptr(), + cu_seqlens_cpu.data_ptr() + cu_seqlens_cpu.numel()); + cu_seqlens = cu_seqlens.to(hidden_states.device()); + for (int idx = 0; idx < blocks_->size(); ++idx) { + hidden_states = layers_[idx](hidden_states, m_cos, m_sin, cu_seqlens, cu_seqlens_vec, input_params_new, idx); + } + hidden_states = post_layernorm_(hidden_states); + hidden_states = hidden_states.view({-1, spatial_merge_size_, spatial_merge_size_, hidden_states.size(-1)}); + hidden_states = hidden_states.permute({0, 3, 1, 2}); + hidden_states = downsample_(hidden_states).view({-1, out_hidden_size_}); + hidden_states = merger_(hidden_states); + return hidden_states; + }; + + void load_state_dict(const StateDict& state_dict) { + patch_embed_->load_state_dict( + state_dict.get_dict_with_prefix("patch_embed.")); + embeddings_->load_state_dict(state_dict.get_dict_with_prefix("embeddings.")); + const auto& norm_weight = state_dict.get_dict_with_prefix("post_conv_layernorm.").get_tensor("weight"); + if (norm_weight.defined()) { + CHECK_EQ(post_conv_layernorm_->weight.sizes(), norm_weight.sizes()) + << "weight size mismatch for " << name(); + post_conv_layernorm_->weight.data().copy_(norm_weight); + is_post_conv_layernorm_weight_loaded = true; + } + for (int idx = 0; idx < layers_.size(); ++idx) { + layers_[idx]->load_state_dict(state_dict.get_dict_with_prefix( + "blocks." 
+ std::to_string(idx) + ".")); + } + + const auto& post_norm_weight = state_dict.get_dict_with_prefix("post_layernorm.").get_tensor("weight"); + if (post_norm_weight.defined()) { + CHECK_EQ(post_layernorm_->weight.sizes(), post_norm_weight.sizes()) + << "weight size mismatch for " << name(); + post_layernorm_->weight.data().copy_(post_norm_weight); + is_post_layernorm_weight_loaded = true; + } + const auto& downsample_dict = state_dict.get_dict_with_prefix("downsample."); + const auto& downsample_weight = downsample_dict.get_tensor("weight"); + const auto& downsample_bias = downsample_dict.get_tensor("bias"); + if (downsample_weight.defined()) { + downsample_->weight.data().copy_(downsample_weight); + is_downsample_weight_loaded_ = true; + } + if (downsample_bias.defined()) { + downsample_->bias.data().copy_(downsample_bias); + is_downsample_bias_loaded_ = true; + } + merger_->load_state_dict(state_dict.get_dict_with_prefix("merger.")); + } + + void verify_loaded_weights(const std::string& prefix) const { + patch_embed_->verify_loaded_weights(prefix + "patch_embed."); + embeddings_->verify_loaded_weights(prefix + "embeddings."); + CHECK(is_post_conv_layernorm_weight_loaded) + << "weight is not loaded for " << prefix + "post_conv_layernorm.weight"; + for (int idx = 0; idx < blocks_->size(); ++idx) { + layers_[idx]->verify_loaded_weights(prefix + "blocks." + + std::to_string(idx) + "."); + } + CHECK(is_post_layernorm_weight_loaded) + << "weight is not loaded for " << prefix + "post_layernorm.weight"; + merger_->verify_loaded_weights(prefix + "merger."); + + CHECK(is_downsample_weight_loaded_) + << "weight is not loaded for " << prefix + "downsample.weight"; + CHECK(is_downsample_bias_loaded_) + << "bias is not loaded for " << prefix + "downsample.bias"; + } + + void merge_loaded_weights() { + for (int idx = 0; idx < layers_.size(); ++idx) { + layers_[idx]->merge_loaded_weights(); + } + } + private: + int hidden_size_ = 0; + int out_hidden_size_ = 0; + int spatial_merge_size_ = 0; + + Glm4VisionPatchEmbed patch_embed_{nullptr}; + Glm4VisionRotaryEmbedding rotary_pos_emb_{nullptr}; + torch::nn::ModuleList blocks_{nullptr}; + Glm4vVisionEmbeddings embeddings_{nullptr}; + Glm4VisionRmsNorm post_conv_layernorm_{nullptr}; + Glm4VisionRmsNorm post_layernorm_{nullptr}; + torch::nn::Conv2d downsample_{nullptr}; + std::vector layers_; + Glm4_VisionPatchMerger merger_{nullptr}; + torch::TensorOptions options_; + bool is_post_conv_layernorm_weight_loaded = false; + bool is_post_layernorm_weight_loaded = false; + bool is_downsample_weight_loaded_ = false; + bool is_downsample_bias_loaded_ = false; + torch::Tensor m_cos; + torch::Tensor m_sin; +}; +TORCH_MODULE(Glm4VisionTransformer); + +struct Glm4VImageInputs { + torch::Tensor pixel_values; + torch::Tensor image_grid_thw; +}; + +struct Glm4VVideoInputs { + torch::Tensor pixel_values_videos; + torch::Tensor video_grid_thw; + torch::Tensor second_per_grid_ts; +}; + +class Glm4vForConditionalGenerationImpl : public torch::nn::Module { + public: + Glm4vForConditionalGenerationImpl(const ModelContext& context) + : model_args_(context.get_model_args()), + options_(context.get_tensor_options()) { + visual_ = register_module("visual", Glm4VisionTransformer(context)); + + language_model_ = + register_module("language_model", Glm4ForCausalLM(context)); + } + + torch::Tensor get_input_embeddings( + torch::Tensor input_ids, + const std::optional& image_input, + const std::optional& video_input, + const ModelInputParams& input_params) { + auto inputs_embeds = 
language_model_->get_input_embeddings(input_ids); + if (image_input) { + auto image_embeds = + visual_(image_input->pixel_values.to(options_), + image_input->image_grid_thw, + input_params); + auto is_multimodal = torch::isin(input_ids, + model_args_.image_token_id()); input_params.visual_pos_masks = + is_multimodal; inputs_embeds.index_put_({is_multimodal}, image_embeds); + } + return inputs_embeds; + } + + torch::Tensor forward(const torch::Tensor& tokens, + const torch::Tensor& positions, + std::vector& kv_caches, + const ModelInputParams& input_params) { + torch::NoGradGuard no_grad; + const auto& mm_data = input_params.mm_data; + torch::Tensor pixel_values; + if (const auto& res = mm_data.get("pixel_values")) + pixel_values = res.value(); + + torch::Tensor image_grid_thw; + if (const auto& res = mm_data.get("image_grid_thw")) + image_grid_thw = res.value(); + std::optional image_inputs; + std::optional video_inputs; + + if (pixel_values.defined() && image_grid_thw.defined()) + image_inputs = Glm4VImageInputs{pixel_values, image_grid_thw}; + auto inputs_embeds = get_input_embeddings(tokens, image_inputs, video_inputs, input_params); + input_params.input_embedding = inputs_embeds; + auto emb = language_model_(tokens, positions, kv_caches, input_params); + + return emb; + } + + torch::Tensor logits(const torch::Tensor& hidden_states, + const torch::Tensor& seleted_idxes) { + return language_model_->logits(hidden_states, seleted_idxes); + } + + void load_model(std::unique_ptr loader) { + for (const auto& state_dict : loader->get_state_dicts()) { + visual_->load_state_dict( + state_dict->get_dict_with_prefix("model.visual.")); + } + visual_->verify_loaded_weights("model.visual."); + visual_->merge_loaded_weights(); + if (!model_args_.image_embedding_mode()) { + language_model_->load_model(std::move(loader), "model.language_model."); + } + } + + layer::LmHead get_lm_head() { return language_model_->get_lm_head(); } + void set_lm_head(layer::LmHead& head) { language_model_->set_lm_head(head); } + + layer::WordEmbedding get_word_embedding() { + return language_model_->get_word_embedding(); + } + + void set_word_embedding(layer::WordEmbedding& word_embedding) { + language_model_->set_word_embedding(word_embedding); + } + + private: + ModelArgs model_args_; + torch::TensorOptions options_; + Glm4VisionTransformer visual_{nullptr}; + Glm4ForCausalLM language_model_{nullptr}; +}; +TORCH_MODULE(Glm4vForConditionalGeneration); + +REGISTER_INPUT_PROCESSOR(glm4v, GLM4_6_VLInputProcessor); +REGISTER_CAUSAL_VLM_MODEL(glm4v, Glm4vForConditionalGeneration); +REGISTER_IMAGE_PROCESSOR(glm4v, Glm4VImageProcessor); +// register the model args +REGISTER_MODEL_ARGS(glm4v, [&] { + LOAD_ARG_OR(model_type, "model_type", "glm4v"); + LOAD_ARG_OR(image_start_token_id, "image_start_token_id", 151339); + LOAD_ARG_OR(image_end_token_id, "image_end_token_id", 151340); + LOAD_ARG_OR(video_start_token_id, "video_start_token_id", 151341); + LOAD_ARG_OR(video_end_token_id, "video_end_token_id", 151342); + LOAD_ARG_OR(image_token_id, "image_token_id", 151363); + LOAD_ARG_OR(video_token_id, "video_token_id", 151364); + LOAD_ARG_OR(tie_word_embeddings, "tie_word_embeddings", false); + + // text config + LOAD_ARG_OR(vocab_size, "text_config.vocab_size", 151552); + // LOAD_ARG_OR(pad_token_id, "text_config.pad_token_id", 151329); + LOAD_ARG_OR( + eos_token_id_vec, "text_config.eos_token_id", std::vector{151329}); + LOAD_ARG_OR(attention_bias, "text_config.attention_bias", true); + LOAD_ARG_OR(attention_dropout, 
"text_config.attention_dropout", 0.0f); + LOAD_ARG_OR(first_k_dense_replace, "text_config.first_k_dense_replace", 1); + LOAD_ARG_OR(hidden_act, "text_config.hidden_act", "silu"); + LOAD_ARG_OR(hidden_size, "text_config.hidden_size", 4096); + LOAD_ARG_OR(initializer_range, "text_config.initializer_range", 0.02); + LOAD_ARG_OR(intermediate_size, "text_config.intermediate_size", 10944); + LOAD_ARG_OR( + max_position_embeddings, "text_config.max_position_embeddings", 131072); + LOAD_ARG_OR(n_heads, "text_config.num_attention_heads", 96); + LOAD_ARG_OR_FUNC(head_dim, "text_config.head_dim", [&] { + return args->hidden_size() / args->n_heads(); + }); + LOAD_ARG_OR(num_experts_per_tok, "text_config.num_experts_per_tok", 8); + LOAD_ARG_OR(n_layers, "text_config.num_hidden_layers", 46); + LOAD_ARG_OR(n_kv_heads, "text_config.num_key_value_heads", 8); + // LOAD_ARG_OR(partial_rotary_factor, "text_config.partial_rotary_factor", + // 0.5); + LOAD_ARG_OR(rms_norm_eps, "text_config.rms_norm_eps", 1e-05); + LOAD_ARG_OR(dtype, "text_config.dtype", "bfloat16"); + LOAD_ARG_OR(rope_scaling_rope_type, "text_config.rope_scaling.type", "mrope"); + LOAD_ARG(rope_scaling_mrope_section, + "text_config.rope_scaling.mrope_section"); + LOAD_ARG_OR(rope_theta, "text_config.rope_theta", 500000.0f); + LOAD_ARG_OR(routed_scaling_factor, "text_config.routed_scaling_factor", 1.0); + LOAD_ARG_OR(topk_group, "text_config.topk_group", 1); + // LOAD_ARG_OR(use_cache, "text_config.use_cache", true); + LOAD_ARG_OR(use_qk_norm, "text_config.use_qk_norm", false); + + // vision config + // LOAD_ARG_OR(mm_attention_bias, "vision_config.attention_bias", false); + // LOAD_ARG_OR(mm_attention_dropout, "vision_config.attention_dropout", 0.0f); + LOAD_ARG_OR(mm_num_hidden_layers, "vision_config.depth", 24); + LOAD_ARG_OR(mm_hidden_act, "vision_config.hidden_act", "silu"); + LOAD_ARG_OR(mm_hidden_size, "vision_config.hidden_size", 1536); + LOAD_ARG_OR(mm_image_size, "vision_config.image_size", 336); + LOAD_ARG_OR(mm_num_channels, "vision_config.in_channels", 3); + LOAD_ARG_OR(mm_initializer_range, "vision_config.initializer_range", 0.02); + LOAD_ARG_OR(mm_intermediate_size, "vision_config.intermediate_size", 10944); + LOAD_ARG_OR(mm_num_attention_heads, "vision_config.num_heads", 12); + LOAD_ARG_OR(mm_projection_dim, "vision_config.out_hidden_size", 4096); + LOAD_ARG_OR(mm_patch_size, "vision_config.patch_size", 14); + // LOAD_ARG_OR(mm_rms_norm_eps, "vision_config.rms_norm_eps", 1e-05); + LOAD_ARG_OR(mm_spatial_merge_size, "vision_config.spatial_merge_size", 2); + LOAD_ARG_OR(mm_temporal_patch_size, "vision_config.temporal_patch_size", 2); + LOAD_ARG_OR_FUNC(mm_head_dim, "head_dim", [&] { + return args->mm_hidden_size() / args->mm_num_attention_heads(); + }); + + SET_ARG(stop_token_ids, + std::unordered_set(args->eos_token_id_vec().begin(), + args->eos_token_id_vec().end())); +}); +} // namespace xllm \ No newline at end of file diff --git a/xllm/models/vlm/glm4v_moe.h b/xllm/models/vlm/glm4v_moe.h new file mode 100644 index 000000000..e772a04c9 --- /dev/null +++ b/xllm/models/vlm/glm4v_moe.h @@ -0,0 +1,220 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#pragma once + +#include +#include +#include +#include + +#include +#include + +#include "core/framework/kv_cache/kv_cache.h" +#include "core/framework/model/model_input_params.h" +#include "core/framework/model_context.h" +#include "core/layers/lm_head.h" +#include "core/layers/rms_norm.h" +#include "models/llm/glm4_moe.h" +#include "models/model_registry.h" +#include "processors/input_processor.h" +#include "processors/glm4v_image_processor.h" +#include "xllm_kernels/core/include/atb_speed/log.h" +#include "models/vlm/glm4v.h" + +namespace xllm { + +class Glm4vMoeForConditionalGenerationImpl : public torch::nn::Module { + public: + Glm4vMoeForConditionalGenerationImpl(const ModelContext& context) + : model_args_(context.get_model_args()), + options_(context.get_tensor_options()) { + std::cout << "----------------Glm4vMoeForConditionalGenerationImpl init begin ----------------- " << std::endl; + visual_ = register_module("visual", Glm4VisionTransformer(context)); + + language_model_ = + register_module("language_model", Glm4MoeForCausalLM(context)); + } + + torch::Tensor get_input_embeddings( + torch::Tensor input_ids, + const std::optional& image_input, + const std::optional& video_input, + const ModelInputParams& input_params) { + // visual + LOG(INFO) << " Glm4vMoeForConditionalGenerationImpl forward get_input_embeddings visual_ begin "; + torch::Tensor pixel = image_input->pixel_values.to(options_); + LOG(INFO) << " Glm4vMoeForConditionalGenerationImpl forward get_input_embeddings pixel aft "; + auto image_embeds = + visual_(pixel, + image_input->image_grid_thw.to(pixel.device()), + input_params); + LOG(INFO) << " Glm4vMoeForConditionalGenerationImpl forward get_input_embeddings visual_ end "; + auto inputs_embeds = language_model_->get_input_embeddings(input_ids); + // merge + auto is_multimodal = torch::isin(input_ids, + model_args_.image_token_id()); input_params.visual_pos_masks = + is_multimodal; inputs_embeds.index_put_({is_multimodal}, image_embeds); + return inputs_embeds; + } + + torch::Tensor forward(const torch::Tensor& tokens, + const torch::Tensor& positions, + std::vector& kv_caches, + const ModelInputParams& input_params) { + std::cout << "----------------Glm4vMoeForConditionalGenerationImpl beging ----------------- " << std::endl; + LOG(INFO) << " Glm4vMoeForConditionalGenerationImpl beging "; + torch::NoGradGuard no_grad; + const auto& mm_data = input_params.mm_data; + torch::Tensor pixel_values; + if (const auto& res = mm_data.get("pixel_values")) + pixel_values = res.value(); + + torch::Tensor image_grid_thw; + if (const auto& res = mm_data.get("image_grid_thw")) + image_grid_thw = res.value(); + std::optional image_inputs; + std::optional video_inputs; + LOG(INFO) << " Glm4vMoeForConditionalGenerationImpl forward Glm4VImageInputs beging "; + if (pixel_values.defined() && image_grid_thw.defined()) + image_inputs = Glm4VImageInputs{pixel_values, image_grid_thw}; + else + LOG(FATAL) << "Pixel value or image grid thw is null"; + + LOG(INFO) << " 
Glm4vMoeForConditionalGenerationImpl forward get_input_embeddings beging "; + + auto inputs_embeds = get_input_embeddings(tokens, image_inputs, video_inputs, input_params); + input_params.input_embedding = inputs_embeds; + auto emb = language_model_(tokens, positions, kv_caches, input_params); + + return emb; + } + + torch::Tensor logits(const torch::Tensor& hidden_states, + const torch::Tensor& seleted_idxes) { + return language_model_->logits(hidden_states, seleted_idxes); + } + + void load_model(std::unique_ptr loader) { + for (const auto& state_dict : loader->get_state_dicts()) { + visual_->load_state_dict( + state_dict->get_dict_with_prefix("model.visual.")); + } + // verify + visual_->verify_loaded_weights("model.visual."); + visual_->merge_loaded_weights(); + if (!model_args_.image_embedding_mode()) { + language_model_->load_model(std::move(loader), "model.language_model."); + } + } + + layer::LmHead get_lm_head() { return language_model_->get_lm_head(); } + void set_lm_head(layer::LmHead& head) { language_model_->set_lm_head(head); } + + layer::WordEmbedding get_word_embedding() { + return language_model_->get_word_embedding(); + } + + void set_word_embedding(layer::WordEmbedding& word_embedding) { + language_model_->set_word_embedding(word_embedding); + } + + private: + ModelArgs model_args_; + torch::TensorOptions options_; + Glm4VisionTransformer visual_{nullptr}; + Glm4MoeForCausalLM language_model_{nullptr}; +}; +TORCH_MODULE(Glm4vMoeForConditionalGeneration); + +REGISTER_INPUT_PROCESSOR(glm4v_moe, GLM4_6_VLInputProcessor); +REGISTER_CAUSAL_VLM_MODEL(glm4v_moe, Glm4vMoeForConditionalGeneration); +REGISTER_IMAGE_PROCESSOR(glm4v_moe, Glm4VImageProcessor); +// register the model args +REGISTER_MODEL_ARGS(glm4v_moe, [&] { + LOAD_ARG_OR(model_type, "model_type", "glm4v_moe"); + LOAD_ARG_OR(image_start_token_id, "image_start_token_id", 151339); + LOAD_ARG_OR(image_end_token_id, "image_end_token_id", 151340); + LOAD_ARG_OR(video_start_token_id, "video_start_token_id", 151341); + LOAD_ARG_OR(video_end_token_id, "video_end_token_id", 151342); + LOAD_ARG_OR(image_token_id, "image_token_id", 151363); + LOAD_ARG_OR(video_token_id, "video_token_id", 151364); + LOAD_ARG_OR(tie_word_embeddings, "tie_word_embeddings", false); + + // text config + LOAD_ARG_OR(vocab_size, "text_config.vocab_size", 151552); + // LOAD_ARG_OR(pad_token_id, "text_config.pad_token_id", 151329); + LOAD_ARG_OR( + eos_token_id_vec, "text_config.eos_token_id", std::vector{151329}); + LOAD_ARG_OR_FUNC(head_dim, "text_config.head_dim", [&] { + return args->hidden_size() / args->n_heads(); + }); + LOAD_ARG_OR(attention_bias, "text_config.attention_bias", true); + LOAD_ARG_OR(attention_dropout, "text_config.attention_dropout", 0.0f); + LOAD_ARG_OR(first_k_dense_replace, "text_config.first_k_dense_replace", 1); + LOAD_ARG_OR(hidden_act, "text_config.hidden_act", "silu"); + LOAD_ARG_OR(hidden_size, "text_config.hidden_size", 4096); + LOAD_ARG_OR(initializer_range, "text_config.initializer_range", 0.02); + LOAD_ARG_OR(intermediate_size, "text_config.intermediate_size", 10944); + LOAD_ARG_OR( + max_position_embeddings, "text_config.max_position_embeddings", 131072); + LOAD_ARG_OR(moe_intermediate_size, "text_config.moe_intermediate_size", 1408); + LOAD_ARG_OR(n_group, "text_config.n_group", 1); + LOAD_ARG_OR(num_experts, "text_config.n_routed_experts", 128); + LOAD_ARG_OR(n_shared_experts, "text_config.n_shared_experts", 1); + LOAD_ARG_OR(norm_topk_prob, "text_config.norm_topk_prob", true); + LOAD_ARG_OR(n_heads, 
"text_config.num_attention_heads", 96); + LOAD_ARG_OR(num_experts_per_tok, "text_config.num_experts_per_tok", 8); + LOAD_ARG_OR(n_layers, "text_config.num_hidden_layers", 46); + LOAD_ARG_OR(n_kv_heads, "text_config.num_key_value_heads", 8); + // LOAD_ARG_OR(partial_rotary_factor, "text_config.partial_rotary_factor", + // 0.5); + LOAD_ARG_OR(rms_norm_eps, "text_config.rms_norm_eps", 1e-05); + LOAD_ARG_OR(dtype, "text_config.dtype", "bfloat16"); + LOAD_ARG_OR(rope_scaling_rope_type, "text_config.rope_scaling.type", "mrope"); + LOAD_ARG(rope_scaling_mrope_section, + "text_config.rope_scaling.mrope_section"); + LOAD_ARG_OR(rope_theta, "text_config.rope_theta", 500000.0f); + LOAD_ARG_OR(routed_scaling_factor, "text_config.routed_scaling_factor", 1.0); + LOAD_ARG_OR(topk_group, "text_config.topk_group", 1); + // LOAD_ARG_OR(use_cache, "text_config.use_cache", true); + LOAD_ARG_OR(use_qk_norm, "text_config.use_qk_norm", false); + + // vision config + // LOAD_ARG_OR(mm_attention_bias, "vision_config.attention_bias", false); + // LOAD_ARG_OR(mm_attention_dropout, "vision_config.attention_dropout", 0.0f); + LOAD_ARG_OR(mm_num_hidden_layers, "vision_config.depth", 24); + LOAD_ARG_OR(mm_hidden_act, "vision_config.hidden_act", "silu"); + LOAD_ARG_OR(mm_hidden_size, "vision_config.hidden_size", 1536); + LOAD_ARG_OR(mm_image_size, "vision_config.image_size", 336); + LOAD_ARG_OR(mm_num_channels, "vision_config.in_channels", 3); + LOAD_ARG_OR(mm_initializer_range, "vision_config.initializer_range", 0.02); + LOAD_ARG_OR(mm_intermediate_size, "vision_config.intermediate_size", 10944); + LOAD_ARG_OR(mm_num_attention_heads, "vision_config.num_heads", 12); + LOAD_ARG_OR(mm_projection_dim, "vision_config.out_hidden_size", 4096); + LOAD_ARG_OR(mm_patch_size, "vision_config.patch_size", 14); + // LOAD_ARG_OR(mm_rms_norm_eps, "vision_config.rms_norm_eps", 1e-05); + LOAD_ARG_OR(mm_spatial_merge_size, "vision_config.spatial_merge_size", 2); + LOAD_ARG_OR(mm_temporal_patch_size, "vision_config.temporal_patch_size", 2); + LOAD_ARG_OR_FUNC(mm_head_dim, "head_dim", [&] { + return args->mm_hidden_size() / args->mm_num_attention_heads(); + }); + + SET_ARG(stop_token_ids, + std::unordered_set(args->eos_token_id_vec().begin(), + args->eos_token_id_vec().end())); +}); +} // namespace xllm \ No newline at end of file diff --git a/xllm/models/vlm/qwen2_5_vl.h b/xllm/models/vlm/qwen2_5_vl.h index ec6e6aa4a..48683a5ad 100644 --- a/xllm/models/vlm/qwen2_5_vl.h +++ b/xllm/models/vlm/qwen2_5_vl.h @@ -697,6 +697,15 @@ class Qwen2_5_VLForConditionalGenerationImpl : public torch::nn::Module { auto is_multimodal = torch::isin(input_ids, model_args_.image_token_id()); inputs_embeds.index_put_({is_multimodal}, image_embeds); } + if (video_input) { + // visual + auto video_embeds = visual_(video_input->pixel_values_videos.to(options_), + video_input->video_grid_thw, + input_params); + // merge + auto is_multimodal = torch::isin(input_ids, model_args_.video_token_id()); + inputs_embeds.index_put_({is_multimodal}, video_embeds); + } return inputs_embeds; } @@ -715,11 +724,29 @@ class Qwen2_5_VLForConditionalGenerationImpl : public torch::nn::Module { if (const auto& res = mm_data.get("image_grid_thw")) image_grid_thw = res.value(); + torch::Tensor pixel_values_videos; + if (const auto& res = mm_data.get("pixel_values_videos")) + pixel_values_videos = res.value(); + + torch::Tensor video_grid_thw; + if (const auto& res = mm_data.get("video_grid_thw")) + video_grid_thw = res.value(); + + torch::Tensor second_per_grid_ts; + if (const auto& res 
= mm_data.get("second_per_grid_ts")) + second_per_grid_ts = res.value(); + std::optional image_inputs; std::optional video_inputs; if (pixel_values.defined() && image_grid_thw.defined()) image_inputs = Qwen2_5_VLImageInputs{pixel_values, image_grid_thw}; + + if (pixel_values_videos.defined() && video_grid_thw.defined() && + second_per_grid_ts.defined()) + video_inputs = Qwen2_5_VLVideoInputs{ + pixel_values_videos, video_grid_thw, second_per_grid_ts}; + auto inputs_embeds = get_input_embeddings(tokens, image_inputs, video_inputs, input_params); input_params.input_embedding = inputs_embeds; diff --git a/xllm/processors/CMakeLists.txt b/xllm/processors/CMakeLists.txt index 27365efe8..fe24f2f4d 100755 --- a/xllm/processors/CMakeLists.txt +++ b/xllm/processors/CMakeLists.txt @@ -21,6 +21,7 @@ cc_library( processors HDRS image_processor.h + glm4v_image_processor.h clip_image_processor.h minicpmv_image_processor.h qwen2_vl_image_processor.h @@ -28,6 +29,7 @@ cc_library( input_processor.h SRCS image_processor.cpp + glm4v_image_processor.cpp clip_image_processor.cpp minicpmv_image_processor.cpp qwen2_vl_image_processor.cpp diff --git a/xllm/processors/glm4v_image_processor.cpp b/xllm/processors/glm4v_image_processor.cpp new file mode 100644 index 000000000..6e00b7e2b --- /dev/null +++ b/xllm/processors/glm4v_image_processor.cpp @@ -0,0 +1,478 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "glm4v_image_processor.h" + +namespace xllm { + +namespace { + +using Size = std::pair; + +std::optional smart_resize(int num_frames, + int height, + int width, + int temporal_factor, + int factor = 28, + int min_pixels = 56 * 56, + int max_pixels = 14 * 14 * 4 * 1280) { + if (height < factor || width < factor) { + LOG(ERROR) << "Height or width must be larger than factor"; + return std::nullopt; + } + if (num_frames < temporal_factor) { + LOG(ERROR) << "t:{num_frames} must be larger than " + "temporal_factor:{temporal_factor}"; + return std::nullopt; + } + + if (static_cast(std::max(height, width)) / std::min(height, width) > + 200) { + LOG(ERROR) << "Absolute aspect ratio must be smaller than 200"; + return std::nullopt; + } + int t_bar = static_cast(std::round( + num_frames / static_cast(temporal_factor))) * + temporal_factor; + int h_bar = + static_cast(std::round(height / static_cast(factor))) * + factor; + int w_bar = + static_cast(std::round(width / static_cast(factor))) * + factor; + + if (t_bar * h_bar * w_bar > max_pixels) { + double beta = std::sqrt((num_frames * height * width) / + static_cast(max_pixels)); + h_bar = static_cast( + std::floor(height / beta / static_cast(factor))) * + factor; + w_bar = static_cast( + std::floor(width / beta / static_cast(factor))) * + factor; + } else if (t_bar * h_bar * w_bar < min_pixels) { + double beta = std::sqrt(min_pixels / + static_cast(height * width * num_frames)); + h_bar = static_cast( + std::ceil(height * beta / static_cast(factor))) * + factor; + w_bar = static_cast( + std::ceil(width * beta / static_cast(factor))) * + factor; + } + + return std::make_pair(h_bar, w_bar); +} +} // namespace + +torch::Tensor Glm4VImageProcessor::sample_frames(const VideoMetadata& metadata, + int temporal_patch_size) { + // video: [T, C, H, W] + const int total_frames = metadata.total_num_frames; + if (total_frames <= 0) { + return torch::empty({0}, torch::dtype(torch::kLong)); + } + + if (metadata.fps <= 0.0) { + LOG(FATAL) << "invalid metadata.fps <= 0"; + } + + const int max_frame_idx = total_frames - 1; + + // duration = metadata.duration or round(max_idx / fps) + 1 + double duration = metadata.duration; + if (duration <= 0.0) { + duration = + std::round(static_cast(max_frame_idx) / metadata.fps) + 1.0; + } + + constexpr double DYN_FPS_30 = 3.0; + constexpr double DYN_FPS_300 = 1.0; + constexpr double DYN_FPS_2400 = 0.5; + constexpr int MAX_FRAME_COUNT_DYNAMIC = 640; + constexpr double MAX_DURATION = 2400.0; + + const double effective_duration = std::min(duration, MAX_DURATION); + + double target_fps = 0.0; + if (effective_duration <= 30.0) { + target_fps = DYN_FPS_30; + } else if (effective_duration <= 300.0) { + target_fps = DYN_FPS_300; + } else { + target_fps = DYN_FPS_2400; + } + + // extract_t = int(effective_duration * target_fps * temporal_patch_size) + int extract_t = static_cast(effective_duration * target_fps * + static_cast(temporal_patch_size)); + extract_t = std::min(extract_t, MAX_FRAME_COUNT_DYNAMIC); + + const double duration_per_frame = 1.0 / metadata.fps; + std::vector timestamps(total_frames); + for (int i = 0; i < total_frames; ++i) { + timestamps[i] = static_cast(i) * duration_per_frame; + } + const int max_second = static_cast(duration); + + torch::Tensor frame_indices; + + if (total_frames < extract_t) { + frame_indices = torch::linspace( + 0, total_frames - 1, extract_t, torch::dtype(torch::kLong)); + } else { + std::vector tmp; + 
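+    // Example (assumed numbers, for illustration only): the loop below keeps
+    // roughly one frame per 1 / (temporal_patch_size * target_fps) seconds of
+    // timeline. A 60 s clip with metadata.fps = 30 and temporal_patch_size = 2
+    // falls into the <= 300 s bucket (target_fps = DYN_FPS_300 = 1.0), so
+    // extract_t = int(60 * 1.0 * 2) = 120 frames are targeted and a frame is
+    // accepted every 0.5 s (inv_fps = 1 / (2 * 1.0)).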
tmp.reserve(static_cast(total_frames)); + double current_second = 0.0; + const double inv_fps = + 1.0 / (static_cast(temporal_patch_size) * target_fps); + + for (int frame_index = 0; frame_index < total_frames; frame_index++) { + if (timestamps[frame_index] >= current_second) { + current_second += inv_fps; + tmp.push_back(frame_index); + if (current_second >= static_cast(max_second)) { + break; + } + } + } + frame_indices = + torch::tensor(tmp, torch::TensorOptions().dtype(torch::kLong)); + } + int64_t len = frame_indices.size(0); + if (len < extract_t) { + int64_t start, end; + if (len == 0) { + start = 0; + end = std::max(total_frames - 1, 0); + } else { + start = frame_indices[0].item(); + end = frame_indices[len - 1].item(); + } + frame_indices = + torch::linspace(start, end, extract_t, torch::dtype(torch::kLong)); + } else if (len > extract_t) { + frame_indices = torch::linspace( + 0, total_frames - 1, extract_t, torch::dtype(torch::kLong)); + } + + len = frame_indices.size(0); + std::unordered_set seen; + seen.reserve(static_cast(len) * 2); + std::vector uniq; + uniq.reserve(static_cast(len)); + + for (int64_t i = 0; i < len; ++i) { + auto idx = frame_indices[i].item(); + if (seen.insert(idx).second) { + uniq.push_back(idx); + } + } + + if (!uniq.empty() && (uniq.size() & 1)) { + uniq.push_back(uniq.back()); + } + + return torch::tensor(uniq, torch::TensorOptions().dtype(torch::kLong)); +} + +Glm4VImageProcessor::Glm4VImageProcessor(const ModelArgs& args) { + image_mean_ = args.mm_image_normalize_mean(); + image_std_ = args.mm_image_normalize_std(); + + if (args.mm_image_max_pixels() && args.mm_image_min_pixels()) { + min_pixels_ = args.mm_image_min_pixels(); + max_pixels_ = args.mm_image_max_pixels(); + } else if (args.mm_image_shortest_edge() && args.mm_image_longest_edge()) { + min_pixels_ = args.mm_image_shortest_edge(); + max_pixels_ = args.mm_image_longest_edge(); + } + + patch_size_ = args.mm_image_patch_size(); + temporal_patch_size_ = args.mm_image_temporal_patch_size(); + + merge_size_ = args.mm_image_merge_size(); + size_ = {{"longest_edge", 12845056}, {"shortest_edge", 3136}}; + + // fuse image mean/std and rescale_factor + if (do_rescale_ && do_normalize_) { + for (auto& item : image_mean_) { + item = item * (1.0 / rescale_factor_); + } + + for (auto& item : image_std_) { + item = item * (1.0 / rescale_factor_); + } + + do_rescale_ = false; + } +} + +bool Glm4VImageProcessor::process(const MMInput& inputs, MMData& datas) { + std::vector images = inputs.get_decode_data(MMType::IMAGE); + std::vector videos = inputs.get_decode_data(MMType::VIDEO); + std::vector video_meta_list = inputs.get_video_metadata(); + + if (images.empty() && (videos.empty() || video_meta_list.empty())) { + LOG(ERROR) << "no image/video tensor found."; + return false; + } + + if (!images.empty()) { + if (!this->process_images(images, datas)) { + LOG(ERROR) << " process image failed."; + return false; + } + } + + if (!videos.empty()) { + if (!this->process_videos(videos, video_meta_list, datas)) { + LOG(ERROR) << " process video failed."; + return false; + } + } + + return true; +} + +bool Glm4VImageProcessor::process_images(std::vector images, + MMData& mm_datas) { + std::vector pixel_values; + std::vector grids; + + for (const auto& img : images) { + if (!this->process_image(img, pixel_values, grids)) { + return false; + } + } + + auto values = torch::cat(pixel_values); + auto thw = torch::tensor(grids); + + thw = thw.clone().reshape({-1, 3}); + mm_datas.add(MMType::IMAGE, "image_grid_thw", thw); + 
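+  // Illustrative note (assumed sizes, not from the original patch): `thw`
+  // holds one (grid_t, grid_h, grid_w) triple per image. A single 448x448
+  // image with patch_size_ = 14 and merge_size_ = 2 yields (1, 32, 32), so
+  // `values` contributes 1 * 32 * 32 = 1024 rows, each of length
+  // channel * temporal_patch_size_ * patch_size_ * patch_size_ = 3 * 2 * 14 * 14 = 1176.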
mm_datas.add(MMType::IMAGE, "pixel_values", values); + + return true; +} + +bool Glm4VImageProcessor::process_image( + torch::Tensor image, + std::vector& pixel_values, + std::vector& grids) { + auto shape = image.sizes(); + + auto resized_height = shape[1]; + auto resized_width = shape[2]; + + // do_convert_rgb + + // resize + if (do_resize_) { + auto size = smart_resize(temporal_patch_size_, + resized_height, + resized_width, + temporal_patch_size_, + patch_size_ * merge_size_, + min_pixels_, + max_pixels_); + if (!size) { + return false; + } + + std::tie(resized_height, resized_width) = *size; + image = + this->resize(image, {resized_height, resized_width}, resample_, true); + } + + // normalize + if (do_normalize_) { + image = this->normalize(image, image_mean_, image_std_); + } + + // rescale + if (do_rescale_) { + image = this->rescale(image, rescale_factor_); + } + + auto patches = torch::stack({image}, 0); + + auto repeats = patches[-1].unsqueeze(0).repeat( + /*{temporal_patch_size_ - (shape[0] % temporal_patch_size_)*/ { + temporal_patch_size_ - 1, 1, 1, 1}); + patches = torch::cat({patches, repeats}, 0); + shape = patches.sizes(); + auto channel = shape[1]; + auto grid_t = shape[0] / temporal_patch_size_; + + auto grid_h = resized_height / patch_size_; + auto grid_w = resized_width / patch_size_; + + patches = patches.view({grid_t, + temporal_patch_size_, + channel, + grid_h / merge_size_, + merge_size_, + patch_size_, + grid_w / merge_size_, + merge_size_, + patch_size_}); + patches = patches.permute({0, 3, 6, 4, 7, 2, 1, 5, 8}); + patches = patches.reshape( + {grid_t * grid_h * grid_w, + channel * temporal_patch_size_ * patch_size_ * patch_size_}); + + pixel_values.emplace_back(patches); + grids.insert(grids.end(), {grid_t, grid_h, grid_w}); + + return true; +} + +bool Glm4VImageProcessor::process_videos( + std::vector videos, + std::vector video_meta_list, + MMData& mm_datas) { + std::vector pixel_values; + std::vector grids; + + const size_t video_size = videos.size(); + for (size_t i = 0; i < video_size; ++i) { + auto& vid = videos[i]; + auto& metadata = video_meta_list[i]; + if (!this->process_video(vid, metadata, pixel_values, grids)) { + return false; + } + } + mm_datas.set_video_metadata(video_meta_list); + + auto values = torch::cat(pixel_values); + auto thw = torch::tensor(grids).clone().reshape({-1, 3}); + mm_datas.add(MMType::VIDEO, "video_grid_thw", thw); + mm_datas.add(MMType::VIDEO, "pixel_values_videos", values); + + return true; +} + +bool Glm4VImageProcessor::process_video( + torch::Tensor origin_video, + VideoMetadata& metadata, + std::vector& pixel_values, + std::vector& grids) { + if (origin_video.dim() != 4) { + LOG(FATAL) << "video must be TCHW"; + } + + torch::Tensor indices; + if (do_sample_frame_) { + indices = this->sample_frames(metadata, temporal_patch_size_); + } else { + indices = torch::arange(0, + static_cast(origin_video.size(0)), + torch::TensorOptions().dtype(torch::kLong)); + } + auto video = origin_video.index_select(/*dim=*/0, indices); + int64_t sampled_total_frames = video.size(0); + + metadata.frame_indices = indices; + metadata.timestamps.clear(); + metadata.timestamps.reserve(static_cast(sampled_total_frames)); + double fps_for_ts = (metadata.fps > 0.0) ? 
metadata.fps : 24.0; + for (int64_t i = 0; i < sampled_total_frames; ++i) { + int64_t frame_idx = metadata.frame_indices[i].item(); + metadata.timestamps.push_back(static_cast(frame_idx) / fps_for_ts); + } + + if (metadata.total_num_frames > 0 && metadata.fps > 0.0) { + metadata.sampled_fps = double(sampled_total_frames) / + double(metadata.total_num_frames) * metadata.fps; + } else { + metadata.sampled_fps = fps_for_ts; + } + + auto shape = video.sizes(); + auto time_len = shape[0]; + auto channel = shape[1]; + auto resized_height = shape[2]; + auto resized_width = shape[3]; + + if (do_resize_) { + auto size = smart_resize(temporal_patch_size_, + resized_height, + resized_width, + temporal_patch_size_, + patch_size_ * merge_size_, + min_pixels_, + max_pixels_); + if (!size) { + return false; + } + std::tie(resized_height, resized_width) = *size; + } + + std::vector out_frames; + out_frames.reserve(time_len); + // for each frame + auto frames = video.unbind(0); + + for (auto& frame : frames) { + // resize + if (do_resize_) + frame = + this->resize(frame, {resized_height, resized_width}, resample_, true); + // normalize + if (do_normalize_) frame = this->normalize(frame, image_mean_, image_std_); + // rescale + if (do_rescale_) frame = this->rescale(frame, rescale_factor_); + out_frames.push_back(frame); + } + + auto out_video = torch::stack(out_frames); // [T,C,H,W] + + if (out_video.size(0) % temporal_patch_size_) { + auto last = out_video.index({time_len - 1}) + .unsqueeze(0) + .repeat({temporal_patch_size_ - 1, 1, 1, 1}); + out_video = torch::cat({out_video, last}, 0); + } + + shape = out_video.sizes(); + auto grid_h = resized_height / patch_size_; + auto grid_w = resized_width / patch_size_; + auto grid_t = shape[0] / temporal_patch_size_; + + out_video = out_video.contiguous(); + + auto patches = out_video.view({grid_t, + temporal_patch_size_, + channel, + grid_h / merge_size_, + merge_size_, + patch_size_, + grid_w / merge_size_, + merge_size_, + patch_size_}); + + patches = patches.permute({0, 3, 6, 4, 7, 2, 1, 5, 8}); + patches = patches.reshape( + {grid_t * grid_h * grid_w, + channel * temporal_patch_size_ * patch_size_ * patch_size_}); + + pixel_values.emplace_back(patches); + + grids.insert(grids.end(), {grid_t, grid_h, grid_w}); + return true; +} + +} // namespace xllm \ No newline at end of file diff --git a/xllm/processors/glm4v_image_processor.h b/xllm/processors/glm4v_image_processor.h new file mode 100644 index 000000000..2313fb9bf --- /dev/null +++ b/xllm/processors/glm4v_image_processor.h @@ -0,0 +1,76 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#pragma once + +#include +#include +#include + +#include "image_processor.h" + +namespace xllm { + +class Glm4VImageProcessor : public ImageProcessor { + public: + Glm4VImageProcessor(const ModelArgs&); + ~Glm4VImageProcessor() override = default; + + bool process(const MMInput& mm_inputs, MMData& mm_datas) override; + + private: + bool process_images(std::vector images, MMData& mm_datas); + bool process_image(torch::Tensor image, + std::vector& pixel_values, + std::vector& grids); + bool process_videos(std::vector videos, + std::vector video_meta_list, + MMData& mm_datas); + bool process_video(torch::Tensor video, + VideoMetadata& metadata, + std::vector& pixel_values, + std::vector& grids); + torch::Tensor sample_frames(const VideoMetadata& metadata, + int temporal_patch_size); + + private: + bool do_convert_rgb_ = true; + bool do_normalize_ = true; + + bool do_rescale_ = true; + bool do_resize_ = true; + + std::vector image_mean_; + std::vector image_std_; + + int max_pixels_ = 12845056; + int min_pixels_ = 3136; + + int merge_size_ = 2; + int patch_size_ = 14; + + int resample_ = 3; + double rescale_factor_ = 0.00392156862745098; + + std::unordered_map size_; + int temporal_patch_size_ = 2; + + bool do_sample_frame_ = true; + + int min_frames_ = 4; + int max_frames_ = 768; +}; + +} // namespace xllm diff --git a/xllm/processors/image_processor.cpp b/xllm/processors/image_processor.cpp index f77f82b24..82195f36a 100644 --- a/xllm/processors/image_processor.cpp +++ b/xllm/processors/image_processor.cpp @@ -118,9 +118,8 @@ torch::Tensor ImageProcessor::normalize(const torch::Tensor& image, result = image.to(torch::kFloat32); } - auto dtype = image.dtype(); auto device = image.device(); - auto options = torch::dtype(dtype).device(device); + auto options = torch::dtype(torch::kFloat32).device(device); auto m_tensor = torch::tensor(mean, options).reshape({-1, 1, 1}); auto s_tensor = torch::tensor(std, options).reshape({-1, 1, 1}); diff --git a/xllm/processors/qwen2_vl_image_processor.cpp b/xllm/processors/qwen2_vl_image_processor.cpp index 16adc17d4..cd30d8146 100644 --- a/xllm/processors/qwen2_vl_image_processor.cpp +++ b/xllm/processors/qwen2_vl_image_processor.cpp @@ -60,6 +60,72 @@ std::optional smart_resize(int height, } } // namespace +torch::Tensor Qwen2VLImageProcessor::sample_frames( + const VideoMetadata& metadata, + int temporal_patch_size, + int min_frames, + int max_frames, + int num_frames, + double set_fps) { + if (set_fps > 0.0 && num_frames > 0) { + LOG(FATAL) << "num_frames and fps are mutually exclusive arguments, please " + "use only one!"; + } + + double fps = set_fps; + + int total_num_frames = metadata.total_num_frames; + + if (num_frames > 0) { + double double_num_frames = + std::round(static_cast(num_frames) / temporal_patch_size) * + temporal_patch_size; + num_frames = static_cast(double_num_frames); + } else if (fps > 0.0) { + if (metadata.fps <= 0.0) { + LOG(FATAL) + << "Asked to sample `fps` frames per second but no video metadata " + "was provided which is required when sampling with `fps`. 
"; + } + + max_frames = + (std::min(max_frames, total_num_frames) / temporal_patch_size) * + temporal_patch_size; + double double_num_frames = + static_cast(total_num_frames) / metadata.fps * fps; + double_num_frames = std::min( + std::min(std::max(double_num_frames, static_cast(min_frames)), + static_cast(max_frames)), + static_cast(total_num_frames)); + double_num_frames = std::floor(double_num_frames / temporal_patch_size) * + temporal_patch_size; + + num_frames = static_cast(double_num_frames); + } + + if (num_frames > total_num_frames) { + LOG(FATAL) << "Video can't be sampled. The inferred num_frames=" + << num_frames << " exceeds total_num_frames=" << total_num_frames + << "."; + } + + if (num_frames > 0) { + std::vector indices; + indices.reserve(num_frames); + for (int i = 0; i < num_frames; ++i) { + int64_t k = static_cast( + (static_cast(i) * total_num_frames) / num_frames); + if (k >= total_num_frames) k = total_num_frames - 1; + indices.push_back(k); + } + return torch::tensor(indices, torch::TensorOptions().dtype(torch::kLong)); + } else { + return torch::arange(0, + static_cast(total_num_frames), + torch::TensorOptions().dtype(torch::kLong)); + } +} + Qwen2VLImageProcessor::Qwen2VLImageProcessor(const ModelArgs& args) { image_mean_ = args.mm_image_normalize_mean(); image_std_ = args.mm_image_normalize_std(); @@ -92,14 +158,26 @@ Qwen2VLImageProcessor::Qwen2VLImageProcessor(const ModelArgs& args) { bool Qwen2VLImageProcessor::process(const MMInput& inputs, MMData& datas) { std::vector images = inputs.get_decode_data(MMType::IMAGE); - if (images.empty()) { - LOG(ERROR) << " image tensor not found."; + std::vector videos = inputs.get_decode_data(MMType::VIDEO); + std::vector video_meta_list = inputs.get_video_metadata(); + + if (images.empty() && (videos.empty() || video_meta_list.empty())) { + LOG(ERROR) << "no image/video tensor found."; return false; } - if (!this->process_images(images, datas)) { - LOG(ERROR) << " process image failed."; - return false; + if (!images.empty()) { + if (!this->process_images(images, datas)) { + LOG(ERROR) << " process image failed."; + return false; + } + } + + if (!videos.empty()) { + if (!this->process_videos(videos, video_meta_list, datas)) { + LOG(ERROR) << " process video failed."; + return false; + } } return true; @@ -120,9 +198,9 @@ bool Qwen2VLImageProcessor::process_images(std::vector images, auto thw = torch::tensor(grids); thw = thw.clone().reshape({-1, 3}); - mm_datas = std::move(MMData( - MMType::IMAGE, {{"image_grid_thw", thw}, {"pixel_values", values}})); + mm_datas.add(MMType::IMAGE, "image_grid_thw", thw); + mm_datas.add(MMType::IMAGE, "pixel_values", values); return true; } @@ -198,4 +276,157 @@ bool Qwen2VLImageProcessor::process_image( return true; } +bool Qwen2VLImageProcessor::process_videos( + std::vector videos, + std::vector video_meta_list, + MMData& mm_datas) { + std::vector pixel_values; + std::vector grids; + + const size_t video_size = videos.size(); + for (size_t i = 0; i < video_size; ++i) { + auto& vid = videos[i]; + auto& metadata = video_meta_list[i]; + if (!this->process_video(vid, metadata, pixel_values, grids)) { + return false; + } + } + + auto values = torch::cat(pixel_values); + auto thw = torch::tensor(grids).clone().reshape({-1, 3}); + + const size_t num_videos = videos.size(); + std::vector second_per_grid; + second_per_grid.reserve(num_videos); + for (size_t i = 0; i < num_videos; ++i) { + const auto& metadata = video_meta_list[i]; + double fps = + metadata.sampled_fps > 0.0 ? 
metadata.sampled_fps : metadata.fps; + double seconds_per_grid = static_cast(temporal_patch_size_) / fps; + second_per_grid.push_back(seconds_per_grid); + } + mm_datas.set_video_metadata(video_meta_list); + + auto opts = torch::TensorOptions().dtype(torch::kFloat32); + auto second_per_grid_ts = torch::tensor(second_per_grid, opts); + + mm_datas.add(MMType::VIDEO, "video_grid_thw", thw); + mm_datas.add(MMType::VIDEO, "pixel_values_videos", values); + mm_datas.add(MMType::VIDEO, "second_per_grid_ts", second_per_grid_ts); + + return true; +} + +bool Qwen2VLImageProcessor::process_video( + torch::Tensor origin_video, + VideoMetadata& metadata, + std::vector& pixel_values, + std::vector& grids) { + if (origin_video.dim() != 4) { + LOG(FATAL) << "video must be TCHW"; + } + + torch::Tensor indices; + if (do_sample_frame_) { + indices = this->sample_frames(metadata, + temporal_patch_size_, + min_frames_, + max_frames_, + /*num_frames=*/-1, + /*set_fps=*/2.0); + } else { + indices = torch::arange(0, + static_cast(origin_video.size(0)), + torch::TensorOptions().dtype(torch::kLong)); + } + auto video = origin_video.index_select(/*dim=*/0, indices); + int64_t sampled_total_frames = video.size(0); + + metadata.frame_indices = indices; + metadata.timestamps.clear(); + metadata.timestamps.reserve(static_cast(sampled_total_frames)); + double fps_for_ts = (metadata.fps > 0.0) ? metadata.fps : 24.0; + for (int64_t i = 0; i < sampled_total_frames; ++i) { + int64_t frame_idx = metadata.frame_indices[i].item(); + metadata.timestamps.push_back(static_cast(frame_idx) / fps_for_ts); + } + + if (metadata.total_num_frames > 0 && metadata.fps > 0.0) { + metadata.sampled_fps = double(sampled_total_frames) / + double(metadata.total_num_frames) * metadata.fps; + } else { + metadata.sampled_fps = fps_for_ts; + } + + auto shape = video.sizes(); + auto time_len = shape[0]; + auto channel = shape[1]; + auto resized_height = shape[2]; + auto resized_width = shape[3]; + + if (do_resize_) { + auto size = smart_resize(resized_height, + resized_width, + patch_size_ * merge_size_, + size_["shortest_edge"], + size_["longest_edge"]); + if (!size) { + return false; + } + std::tie(resized_height, resized_width) = *size; + } + + std::vector out_frames; + out_frames.reserve(time_len); + // for each frame + auto frames = video.unbind(0); + for (auto& frame : frames) { + // resize + if (do_resize_) + frame = + this->resize(frame, {resized_height, resized_width}, resample_, true); + // normalize + if (do_normalize_) frame = this->normalize(frame, image_mean_, image_std_); + // rescale + if (do_rescale_) frame = this->rescale(frame, rescale_factor_); + out_frames.push_back(frame); + } + + auto out_video = torch::stack(out_frames); // [T,C,H,W] + + auto pad_t = (temporal_patch_size_ - (time_len % temporal_patch_size_)) % + temporal_patch_size_; + if (pad_t > 0) { + auto last = + out_video.index({time_len - 1}).unsqueeze(0).repeat({pad_t, 1, 1, 1}); + out_video = torch::cat({out_video, last}, 0); + } + + shape = out_video.sizes(); + auto grid_h = resized_height / patch_size_; + auto grid_w = resized_width / patch_size_; + auto grid_t = shape[0] / temporal_patch_size_; + + out_video = out_video.contiguous(); + + auto patches = out_video.view({grid_t, + temporal_patch_size_, + channel, + grid_h / merge_size_, + merge_size_, + patch_size_, + grid_w / merge_size_, + merge_size_, + patch_size_}); + + patches = patches.permute({0, 3, 6, 4, 7, 2, 1, 5, 8}); + patches = patches.reshape( + {grid_t * grid_h * grid_w, + channel * temporal_patch_size_ 
* patch_size_ * patch_size_}); + + pixel_values.emplace_back(patches); + grids.insert(grids.end(), {grid_t, grid_h, grid_w}); + return true; +} + } // namespace xllm diff --git a/xllm/processors/qwen2_vl_image_processor.h b/xllm/processors/qwen2_vl_image_processor.h index 0cdd4c0d9..3e35ac501 100644 --- a/xllm/processors/qwen2_vl_image_processor.h +++ b/xllm/processors/qwen2_vl_image_processor.h @@ -36,6 +36,20 @@ class Qwen2VLImageProcessor : public ImageProcessor { std::vector& pixel_values, std::vector& grids); + bool process_videos(std::vector videos, + std::vector video_meta_list, + MMData& mm_datas); + bool process_video(torch::Tensor video, + VideoMetadata& metadata, + std::vector& pixel_values, + std::vector& grids); + torch::Tensor sample_frames(const VideoMetadata& metadata, + int temporal_patch_size, + int min_frames, + int max_frames, + int num_frames = -1, + double set_fps = -1.0); + private: bool do_convert_rgb_ = true; bool do_normalize_ = true; @@ -57,6 +71,11 @@ class Qwen2VLImageProcessor : public ImageProcessor { std::unordered_map size_; int temporal_patch_size_ = 2; + + bool do_sample_frame_ = true; + + int min_frames_ = 4; + int max_frames_ = 768; }; } // namespace xllm diff --git a/xllm/pybind/CMakeLists.txt b/xllm/pybind/CMakeLists.txt index 964678802..6eccea0c9 100644 --- a/xllm/pybind/CMakeLists.txt +++ b/xllm/pybind/CMakeLists.txt @@ -24,6 +24,7 @@ pybind_extension( torch c10 ) +target_link_options(xllm_export PRIVATE -Wl,-Bsymbolic) target_link_libraries(common PRIVATE leveldb::leveldb ZLIB::ZLIB OpenSSL::SSL OpenSSL::Crypto protobuf::libprotobuf) add_dependencies(common brpc-static)
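As a reference for the fps-based sampling path added to Qwen2VLImageProcessor above, the following standalone sketch (assumed clip parameters and a hypothetical helper name, not part of the patch) reproduces the evenly spaced index selection and the resulting `second_per_grid_ts` value:

#include <cmath>
#include <cstdint>
#include <vector>

// Hypothetical helper mirroring the index loop in sample_frames:
// k = floor(i * total_num_frames / num_frames), clamped to the last frame.
std::vector<int64_t> evenly_spaced_indices(int total_num_frames, int num_frames) {
  std::vector<int64_t> indices;
  indices.reserve(static_cast<size_t>(num_frames));
  for (int i = 0; i < num_frames; ++i) {
    int64_t k = (static_cast<int64_t>(i) * total_num_frames) / num_frames;
    if (k >= total_num_frames) k = total_num_frames - 1;
    indices.push_back(k);
  }
  return indices;
}

// Assumed inputs: a 10 s clip with total_num_frames = 240, metadata.fps = 24,
// set_fps = 2.0, temporal_patch_size = 2, min_frames = 4, max_frames = 768.
// Then num_frames = floor(min(max(240 / 24 * 2.0, 4), 240, 240) / 2) * 2 = 20,
// so evenly_spaced_indices(240, 20) yields 0, 12, 24, ..., 228, the sampled
// fps is 20 / 240 * 24 = 2.0, and second_per_grid_ts = temporal_patch_size /
// sampled_fps = 2 / 2.0 = 1.0 s of video per temporal grid step.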