diff --git a/api/causal_lm_api.cpp b/api/causal_lm_api.cpp index 1e16a447..0124bb69 100644 --- a/api/causal_lm_api.cpp +++ b/api/causal_lm_api.cpp @@ -39,7 +39,7 @@ using json = nlohmann::json; -static std::unique_ptr g_model; +static std::unique_ptr g_model; static std::mutex g_mutex; static bool g_initialized = false; static std::string g_architecture = ""; @@ -565,10 +565,10 @@ ErrorCode runModel(const char *inputTextPrompt, const char **outputText) { // We assume single batch request for this API #if defined(_WIN32) - g_model->run(std::wstring(input.begin(), input.end()), false, L"", L"", + g_model->run(std::wstring(input.begin(), input.end()), L"", L"", nullptr, g_verbose); #else - g_model->run(input, false, "", "", g_verbose); + g_model->run(input, "", "", nullptr, g_verbose); #endif auto causal_lm_model = dynamic_cast(g_model.get()); diff --git a/factory.h b/factory.h index e6f3a59b..1c8499b5 100644 --- a/factory.h +++ b/factory.h @@ -15,7 +15,7 @@ #define __CAUSALLM_FACTORY_H__ #include -#include +#include #include namespace quick_dot_ai { @@ -26,7 +26,7 @@ namespace quick_dot_ai { class Factory { public: using Creator = - std::function(json &, json &, json &)>; + std::function(json &, json &, json &)>; static Factory &Instance() { static Factory factory; @@ -37,9 +37,9 @@ class Factory { creators[key] = creator; } - std::unique_ptr create(const std::string &key, json &cfg, - json &generation_cfg, - json &nntr_cfg) const { + std::unique_ptr create(const std::string &key, json &cfg, + json &generation_cfg, + json &nntr_cfg) const { auto it = creators.find(key); if (it != creators.end()) { return (it->second)(cfg, generation_cfg, nntr_cfg); diff --git a/main.cpp b/main.cpp index 0e2fa09c..19646732 100644 --- a/main.cpp +++ b/main.cpp @@ -303,16 +303,14 @@ int main(int argc, char *argv[]) { model->initialize(); model->load_weight(weight_file); - bool do_sample = generation_cfg.value("do_sample", false); - #ifdef PROFILE start_peak_tracker(); #endif #if defined(_WIN32) - model->run(input_text.c_str(), do_sample, system_head_prompt.c_str(), + model->run(input_text.c_str(), system_head_prompt.c_str(), system_tail_prompt.c_str()); #else - model->run(input_text, do_sample, system_head_prompt, system_tail_prompt); + model->run(input_text, system_head_prompt, system_tail_prompt); #endif #ifdef PROFILE stop_and_print_peak(); diff --git a/models/causal_lm.cpp b/models/causal_lm.cpp index 63e0ad99..d09d1b8f 100644 --- a/models/causal_lm.cpp +++ b/models/causal_lm.cpp @@ -98,6 +98,9 @@ void CausalLM::setupParameters(json &cfg, json &generation_cfg, TEMPERATURE = generation_cfg.contains("temperature") ? generation_cfg["temperature"].get() : 0.7; + DO_SAMPLE = generation_cfg.contains("do_sample") + ? generation_cfg["do_sample"].get() + : false; global_token_len = 0; } @@ -285,8 +288,12 @@ void CausalLM::registerCustomLayers() { } } -void CausalLM::run(const WSTR prompt, bool do_sample, const WSTR system_prompt, - const WSTR tail_prompt, bool log_output) { +void CausalLM::run(const WSTR prompt, void *output_buf, bool log_output) { + run(prompt, "", "", output_buf, log_output); +} + +void CausalLM::run(const WSTR prompt, const WSTR system_prompt, + const WSTR tail_prompt, void *output_buf, bool log_output) { auto start_total = std::chrono::high_resolution_clock::now(); if (!is_initialized) { @@ -489,7 +496,7 @@ void CausalLM::run(const WSTR prompt, bool do_sample, const WSTR system_prompt, model->incremental_inference(BATCH_SIZE, input, label, input_len, token_generation_idx - 1 + global_token_len, token_generation_idx + global_token_len); - std::vector ids_list(generate(output_interval[0], do_sample)); + std::vector ids_list(generate(output_interval[0], DO_SAMPLE)); if (token_generation_idx < input_len) { for (unsigned int b = 0; b < BATCH_SIZE; ++b) { input_sample[static_cast(b) * MAX_SEQ_LEN] = @@ -536,6 +543,10 @@ void CausalLM::run(const WSTR prompt, bool do_sample, const WSTR system_prompt, global_token_len += (generation_cnt + init_len); + if (output_buf != nullptr) { + *static_cast *>(output_buf) = output_list; + } + auto finish_generation = std::chrono::high_resolution_clock::now(); auto generation_duration = std::chrono::duration_cast(finish_generation - diff --git a/models/causal_lm.h b/models/causal_lm.h index 9de7b103..cf4da746 100644 --- a/models/causal_lm.h +++ b/models/causal_lm.h @@ -67,10 +67,16 @@ WIN_EXPORT class CausalLM : virtual public Transformer { } /** - * @brief run the CausalLM model + * @brief run the CausalLM model (simple) */ - void run(const WSTR prompt, bool do_sample = false, - const WSTR system_prompt = "", const WSTR tail_prompt = "", + void run(const WSTR prompt, void *output_buf = nullptr, + bool log_output = true) override; + + /** + * @brief run the CausalLM model (full) + */ + void run(const WSTR prompt, const WSTR system_prompt = "", + const WSTR tail_prompt = "", void *output_buf = nullptr, bool log_output = true) override; /** @@ -142,6 +148,8 @@ WIN_EXPORT class CausalLM : virtual public Transformer { unsigned int TOP_K; float TOP_P; + bool DO_SAMPLE = false; /**< Whther to use sampling for generation */ + std::vector BAD_WORD_IDS; /**< List of bad word IDs */ unsigned int NUM_BADWORDS; /**< Number of bad words */ diff --git a/models/sentence_transformer.cpp b/models/sentence_transformer.cpp index a9b7bcd2..270692ec 100644 --- a/models/sentence_transformer.cpp +++ b/models/sentence_transformer.cpp @@ -178,8 +178,13 @@ void SentenceTransformer::addModule(const std::string &type, int idx) { model->addLayer(layer); } -void SentenceTransformer::run(const WSTR prompt, bool do_sample, - const WSTR system_prompt, const WSTR tail_prompt, +void SentenceTransformer::run(const WSTR prompt, void *output_buf, + bool log_output) { + run(prompt, "", "", output_buf, log_output); +} + +void SentenceTransformer::run(const WSTR prompt, const WSTR system_prompt, + const WSTR tail_prompt, void *output_buf, bool log_output) { try { @@ -203,9 +208,14 @@ void SentenceTransformer::run(const WSTR prompt, bool do_sample, } } - // output should be deallocated after use. - for (auto out : results) { - delete[] out; + if (output_buf != nullptr) { + // Caller is responsible for dellocation + *static_cast *>(output_buf) = results; + } else { + // output should be deallocated after use. + for (auto out : results) { + delete[] out; + } } } catch (const std::exception &e) { std::cerr << "Error during embedding run: " << e.what() << std::endl; diff --git a/models/sentence_transformer.h b/models/sentence_transformer.h index 9511d7f4..194ca5ff 100644 --- a/models/sentence_transformer.h +++ b/models/sentence_transformer.h @@ -41,10 +41,16 @@ WIN_EXPORT class SentenceTransformer : virtual public Transformer { virtual ~SentenceTransformer() {} /** - * @brief run the SentenceTransformer model + * @brief run the SentenceTransformer model (simple) */ - void run(const WSTR prompt, bool do_sample = false, - const WSTR system_prompt = "", const WSTR tail_prmopt = "", + void run(const WSTR prompt, void *output_buf = nullptr, + bool log_output = true) override; + + /** + * @brief run the SentenceTransformer model (full) + */ + void run(const WSTR prompt, const WSTR system_prompt = "", + const WSTR tail_prmopt = "", void *output_buf = nullptr, bool log_output = true) override; /** diff --git a/models/transformer.cpp b/models/transformer.cpp index c8324758..8984d402 100644 --- a/models/transformer.cpp +++ b/models/transformer.cpp @@ -272,8 +272,12 @@ void Transformer::save_weight( } }; -void Transformer::run(const WSTR prompt, bool do_sample, - const WSTR system_prompt, const WSTR tail_prompt, +void Transformer::run(const WSTR prompt, void *output_buf, bool log_output) { + run(prompt, "", "", output_buf, log_output); +} + +void Transformer::run(const WSTR prompt, const WSTR system_prompt, + const WSTR tail_prompt, void *output_buf, bool log_output) { if (!is_initialized) { throw std::runtime_error( diff --git a/models/transformer.h b/models/transformer.h index 333a33a2..f0b35372 100644 --- a/models/transformer.h +++ b/models/transformer.h @@ -24,46 +24,19 @@ #define __TRANSFORMER_H__ #pragma once -#ifdef _WIN32 -#define WIN_EXPORT __declspec(dllexport) -#define WSTR std::wstring -#define WCHAR_P wchar_t * -#else -#define WIN_EXPORT -#define WSTR std::string -#define WCHAR_P std::string & -#endif -#include +#include #include -#include #include -#include - -#include "json.hpp" -#include "performance_metrics.h" -#include -#include -#include +#include namespace quick_dot_ai { -/*** ALIAS ****/ -using LayerHandle = std::shared_ptr; -using ModelHandle = std::unique_ptr; - -using json = nlohmann::json; - -/** - * @brief Model Type Enum - */ -enum class ModelType { MODEL, CAUSALLM, EMBEDDING, UNKNOWN }; - /** * @brief Transformer Class */ -WIN_EXPORT class Transformer { +WIN_EXPORT class Transformer : virtual public TransformerBase { public: /** @@ -84,17 +57,17 @@ WIN_EXPORT class Transformer { /** * @brief Initialize and Construct the Transformer model */ - virtual void initialize(); + void initialize() override; /** * @brief Load the model weights from a file */ - virtual void load_weight(const std::string &weight_path); + void load_weight(const std::string &weight_path) override; /** * @brief Save the weight to a file */ - virtual void save_weight(const std::string &weight_path); + void save_weight(const std::string &weight_path) override; /** * @brief Save the weight to a file with type conversion @@ -102,18 +75,24 @@ WIN_EXPORT class Transformer { * @param dtype Global target data type for all layers (NONE = keep original) * @param layer_dtype_map Per-layer data type overrides (layer_name -> dtype) */ - virtual void - save_weight(const std::string &weight_path, - ml::train::TensorDim::DataType dtype, - const std::map - &layer_dtype_map = {}); + void save_weight(const std::string &weight_path, + ml::train::TensorDim::DataType dtype, + const std::map + &layer_dtype_map = {}) override; + + /** + * @copydoc TransformerBase::run(const WSTR, void *, bool) + */ + void run(const WSTR prompt, void *output_buf = nullptr, + bool log_output = true) override; /** - * @brief run the Transformer model + * @brief TransformerBase::run(const WSTR, const WSTR, const WSTR, void *, + * bool) */ - virtual void run(const WSTR prompt, bool do_sample = false, - const WSTR system_prompt = "", const WSTR tail_prompt = "", - bool log_output = true); + void run(const WSTR prompt, const WSTR system_prompt = "", + const WSTR tail_prompt = "", void *output_buf = nullptr, + bool log_output = true) override; /** * @brief Get PerformanceMetrics @@ -159,26 +138,12 @@ WIN_EXPORT class Transformer { */ virtual void registerCustomLayers(); - /** - * @brief register Outputs - */ - bool is_initialized = false; /**< Flag to check if the model is initialized */ - ModelHandle model; - - /** tokenizer */ - std::unique_ptr tokenizer; - - unsigned int NUM_VOCAB; - int DIM; int HEAD_DIM; int INTERMEDIATE_SIZE; - int NUM_LAYERS; bool USE_VOCAB_SELECTION; bool TIE_WORD_EMBEDDINGS; - unsigned int MAX_SEQ_LEN; int NUM_HEADS; int NUM_KEY_VALUE_HEADS; - int NUM_TO_GENERATE; std::string MODEL_TENSOR_TYPE; std::string EMBEDDING_DTYPE; /** embedding dtype */ std::string FC_LAYER_DTYPE; /** custom_fc_lora */ @@ -190,8 +155,6 @@ WIN_EXPORT class Transformer { float EMBEDDING_SCALE = 1.0f; int GQA_SIZE; - unsigned int BATCH_SIZE; /**< Batch size for the model */ - unsigned int INIT_SEQ_LEN; /**< Initial sequence length */ unsigned int MAX_POSITION_EMBEDDINGS; /**< max_position embeddings */ bool MEMORY_SWAP; /**< memory swap option */ unsigned int FSU_LOOKAHEAD; @@ -201,28 +164,7 @@ WIN_EXPORT class Transformer { // Performance metrics PerformanceMetrics performance_metrics; }; -/** - * Loads JSON data from a file with detailed error handling - * @param file_path Path to JSON file - * @return JSON object - * @throws std::runtime_error on file open or parse failure - */ -inline json LoadJsonFile(const std::string &file_path) { - std::ifstream file(file_path); - if (!file.is_open()) { - throw std::runtime_error("Failed to open file: " + file_path + - " | Reason: " + std::strerror(errno)); - } - try { - json data; - file >> data; - return data; - } catch (const json::parse_error &e) { - throw std::runtime_error("JSON parse error in " + file_path + - " | Details: " + e.what()); - } -} } // namespace quick_dot_ai #endif diff --git a/models/transformer_base.h b/models/transformer_base.h new file mode 100644 index 00000000..eba224dc --- /dev/null +++ b/models/transformer_base.h @@ -0,0 +1,170 @@ +// SPDX-License-Identifier: Apache-2.0 +/** + * Copyright (C) 2025 Eunju Yang + * + * @file transformer_base.h + * @date 31 Mar 2026 + * @see https://github.com/nntrainer/nntrainer + * @author Eunju Yang + * @bug No known bugs except for NYI items + * @note This transformer_base.h defines an abstract base class for + * Transformer-based models. It provides the common interface and shared state + * that both NNTrainer-based Transformer and QNN-based Transformer can inherit. + */ + +#ifndef __TRANSFORMER_BASE_H__ +#define __TRANSFORMER_BASE_H__ + +#pragma once +#ifdef _WIN32 +#define WIN_EXPORT __declspec(dllexport) +#define WSTR std::wstring +#define WCHAR_P wchar_t * +#else +#define WIN_EXPORT +#define WSTR std::string +#define WCHAR_P std::string & +#endif + +#include +#include + +#include +#include +#include +#include +#include + +#include "json.hpp" +#include "performance_metrics.h" + +namespace quick_dot_ai { + +/*** ALIAS ****/ +using LayerHandle = std::shared_ptr; +using ModelHandle = std::unique_ptr; + +using json = nlohmann::json; + +/** + * @brief Model Type Enum + */ +enum class ModelType { MODEL, CAUSALLM, EMBEDDING, UNKNOWN }; + +/** + * @brief TransformerBase Abstract Class + * @note This is the common interface for all Transformer-based models. + * Both NNTrainer Transformer and QNN Transformer inherit from this + */ +WIN_EXPORT class TransformerBase { + +public: + /** + * @brief Default constructor + */ + TransformerBase() = default; + + /** + * @brief Destroy the TransformerBase object + */ + virtual ~TransformerBase() = default; + + /** + * @brief Initialize and Construct the Transformer model + */ + virtual void initialize() = 0; + + /** + * @brief Load the model weights from a file + */ + virtual void load_weight(const std::string &weight_path) = 0; + + /** + * @brief Save the weight to a file + */ + virtual void save_weight(const std::string &weight_path) = 0; + + /** + * @brief Save the weight to a file with type conversion + * @param weight_path Path to save the weight file + * @param dtype Global target data type for all layers (NONE = keep original) + * @param layer_dtype_map Per-layer data type overrides (layer_name -> dtype) + * @note Default implementation throws; concrete subclasses that support + * type-converted save (e.g. NNTrainer-based Transformer) override it. + */ + virtual void + save_weight(const std::string &weight_path, + ml::train::TensorDim::DataType dtype, + const std::map + &layer_dtype_map = {}) { + throw std::runtime_error( + "save_weight with type conversion is not implemented for this " + "TransformerBase subclass"); + } + + /** + * @brief run the Transformer model (simple) + * @param prompt User prompt + * @param output_buf Optional output pointer. For CausalLM, pass + * std::vector*. For Sentence Transformer, pass + * std:vector *. nullptr to skip output collection. + * @param log_output Whether to log output to stdout + */ + virtual void run(const WSTR prompt, void *output_buf = nullptr, + bool log_output = true) = 0; + + /** + * @brief run the Transformer model (full) + * @param prompt User prompt + * @param system_prompt System prompt prepended to user prompt + * @param tail_prompt Tail prompt appended to user prompt + * @param output_buf Optional output pointer (see simple overload for types) + * @param log_output Whether to log output to stdout + */ + virtual void run(const WSTR prompt, const WSTR system_prompt = "", + const WSTR tail_prompt = "", void *output_buf = nullptr, + bool log_output = true) = 0; + +protected: + bool is_initialized = false; /**< Flag to check if the model is initialized */ + ModelHandle model; + + /** tokenizer */ + std::unique_ptr tokenizer; + + unsigned int NUM_VOCAB; + int DIM; + int NUM_LAYERS; + + unsigned int MAX_SEQ_LEN; + unsigned int BATCH_SIZE; + unsigned int INIT_SEQ_LEN; + unsigned int NUM_TO_GENERATE; +}; + +/** + * Loads JSON data from a file with detailed error handling + * @param file_path Path to JSON file + * @return JSON object + * @throws std::runtime_error on file open or parse failure + */ +inline json LoadJsonFile(const std::string &file_path) { + std::ifstream file(file_path); + if (!file.is_open()) { + throw std::runtime_error("Failed to open file: " + file_path + + " | Reason: " + std::strerror(errno)); + } + + try { + json data; + file >> data; + return data; + } catch (const json::parse_error &e) { + throw std::runtime_error("JSON parse error in " + file_path + + " | Details: " + e.what()); + } +} + +} // namespace quick_dot_ai + +#endif