Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions api/causal_lm_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@

using json = nlohmann::json;

static std::unique_ptr<quick_dot_ai::Transformer> g_model;
static std::unique_ptr<quick_dot_ai::TransformerBase> g_model;
static std::mutex g_mutex;
static bool g_initialized = false;
static std::string g_architecture = "";
Expand Down Expand Up @@ -565,10 +565,10 @@ ErrorCode runModel(const char *inputTextPrompt, const char **outputText) {

// We assume single batch request for this API
#if defined(_WIN32)
g_model->run(std::wstring(input.begin(), input.end()), false, L"", L"",
g_model->run(std::wstring(input.begin(), input.end()), L"", L"", nullptr,
g_verbose);
#else
g_model->run(input, false, "", "", g_verbose);
g_model->run(input, "", "", nullptr, g_verbose);
#endif

auto causal_lm_model = dynamic_cast<quick_dot_ai::CausalLM *>(g_model.get());
Expand Down
10 changes: 5 additions & 5 deletions factory.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
#define __CAUSALLM_FACTORY_H__

#include <ostream>
#include <transformer.h>
#include <transformer_base.h>
#include <unordered_map>

namespace quick_dot_ai {
Expand All @@ -26,7 +26,7 @@ namespace quick_dot_ai {
class Factory {
public:
using Creator =
std::function<std::unique_ptr<Transformer>(json &, json &, json &)>;
std::function<std::unique_ptr<TransformerBase>(json &, json &, json &)>;

static Factory &Instance() {
static Factory factory;
Expand All @@ -37,9 +37,9 @@ class Factory {
creators[key] = creator;
}

std::unique_ptr<Transformer> create(const std::string &key, json &cfg,
json &generation_cfg,
json &nntr_cfg) const {
std::unique_ptr<TransformerBase> create(const std::string &key, json &cfg,
json &generation_cfg,
json &nntr_cfg) const {
auto it = creators.find(key);
if (it != creators.end()) {
return (it->second)(cfg, generation_cfg, nntr_cfg);
Expand Down
6 changes: 2 additions & 4 deletions main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -303,16 +303,14 @@ int main(int argc, char *argv[]) {
model->initialize();
model->load_weight(weight_file);

bool do_sample = generation_cfg.value("do_sample", false);

#ifdef PROFILE
start_peak_tracker();
#endif
#if defined(_WIN32)
model->run(input_text.c_str(), do_sample, system_head_prompt.c_str(),
model->run(input_text.c_str(), system_head_prompt.c_str(),
system_tail_prompt.c_str());
#else
model->run(input_text, do_sample, system_head_prompt, system_tail_prompt);
model->run(input_text, system_head_prompt, system_tail_prompt);
#endif
#ifdef PROFILE
stop_and_print_peak();
Expand Down
17 changes: 14 additions & 3 deletions models/causal_lm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,9 @@ void CausalLM::setupParameters(json &cfg, json &generation_cfg,
TEMPERATURE = generation_cfg.contains("temperature")
? generation_cfg["temperature"].get<float>()
: 0.7;
DO_SAMPLE = generation_cfg.contains("do_sample")
? generation_cfg["do_sample"].get<bool>()
: false;
global_token_len = 0;
}

Expand Down Expand Up @@ -285,8 +288,12 @@ void CausalLM::registerCustomLayers() {
}
}

void CausalLM::run(const WSTR prompt, bool do_sample, const WSTR system_prompt,
const WSTR tail_prompt, bool log_output) {
void CausalLM::run(const WSTR prompt, void *output_buf, bool log_output) {
run(prompt, "", "", output_buf, log_output);
}

void CausalLM::run(const WSTR prompt, const WSTR system_prompt,
const WSTR tail_prompt, void *output_buf, bool log_output) {

auto start_total = std::chrono::high_resolution_clock::now();
if (!is_initialized) {
Expand Down Expand Up @@ -489,7 +496,7 @@ void CausalLM::run(const WSTR prompt, bool do_sample, const WSTR system_prompt,
model->incremental_inference(BATCH_SIZE, input, label, input_len,
token_generation_idx - 1 + global_token_len,
token_generation_idx + global_token_len);
std::vector<unsigned int> ids_list(generate(output_interval[0], do_sample));
std::vector<unsigned int> ids_list(generate(output_interval[0], DO_SAMPLE));
if (token_generation_idx < input_len) {
for (unsigned int b = 0; b < BATCH_SIZE; ++b) {
input_sample[static_cast<size_t>(b) * MAX_SEQ_LEN] =
Expand Down Expand Up @@ -536,6 +543,10 @@ void CausalLM::run(const WSTR prompt, bool do_sample, const WSTR system_prompt,

global_token_len += (generation_cnt + init_len);

if (output_buf != nullptr) {
*static_cast<std::vector<std::string> *>(output_buf) = output_list;
}

auto finish_generation = std::chrono::high_resolution_clock::now();
auto generation_duration =
std::chrono::duration_cast<std::chrono::milliseconds>(finish_generation -
Expand Down
14 changes: 11 additions & 3 deletions models/causal_lm.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,16 @@ WIN_EXPORT class CausalLM : virtual public Transformer {
}

/**
* @brief run the CausalLM model
* @brief run the CausalLM model (simple)
*/
void run(const WSTR prompt, bool do_sample = false,
const WSTR system_prompt = "", const WSTR tail_prompt = "",
void run(const WSTR prompt, void *output_buf = nullptr,
bool log_output = true) override;

/**
* @brief run the CausalLM model (full)
*/
void run(const WSTR prompt, const WSTR system_prompt = "",
const WSTR tail_prompt = "", void *output_buf = nullptr,
bool log_output = true) override;

/**
Expand Down Expand Up @@ -142,6 +148,8 @@ WIN_EXPORT class CausalLM : virtual public Transformer {
unsigned int TOP_K;
float TOP_P;

bool DO_SAMPLE = false; /**< Whether to use sampling for generation */

std::vector<unsigned int> BAD_WORD_IDS; /**< List of bad word IDs */
unsigned int NUM_BADWORDS; /**< Number of bad words */

Expand Down
20 changes: 15 additions & 5 deletions models/sentence_transformer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -178,8 +178,13 @@ void SentenceTransformer::addModule(const std::string &type, int idx) {
model->addLayer(layer);
}

void SentenceTransformer::run(const WSTR prompt, bool do_sample,
const WSTR system_prompt, const WSTR tail_prompt,
void SentenceTransformer::run(const WSTR prompt, void *output_buf,
bool log_output) {
run(prompt, "", "", output_buf, log_output);
}

void SentenceTransformer::run(const WSTR prompt, const WSTR system_prompt,
const WSTR tail_prompt, void *output_buf,
bool log_output) {

try {
Expand All @@ -203,9 +208,14 @@ void SentenceTransformer::run(const WSTR prompt, bool do_sample,
}
}

// output should be deallocated after use.
for (auto out : results) {
delete[] out;
if (output_buf != nullptr) {
// Caller is responsible for deallocation
*static_cast<std::vector<float *> *>(output_buf) = results;
} else {
// output should be deallocated after use.
for (auto out : results) {
delete[] out;
}
}
} catch (const std::exception &e) {
std::cerr << "Error during embedding run: " << e.what() << std::endl;
Expand Down
12 changes: 9 additions & 3 deletions models/sentence_transformer.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,16 @@ WIN_EXPORT class SentenceTransformer : virtual public Transformer {
virtual ~SentenceTransformer() {}

/**
* @brief run the SentenceTransformer model
* @brief run the SentenceTransformer model (simple)
*/
void run(const WSTR prompt, bool do_sample = false,
const WSTR system_prompt = "", const WSTR tail_prmopt = "",
void run(const WSTR prompt, void *output_buf = nullptr,
bool log_output = true) override;

/**
* @brief run the SentenceTransformer model (full)
*/
void run(const WSTR prompt, const WSTR system_prompt = "",
const WSTR tail_prmopt = "", void *output_buf = nullptr,
bool log_output = true) override;

/**
Expand Down
8 changes: 6 additions & 2 deletions models/transformer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -272,8 +272,12 @@ void Transformer::save_weight(
}
};

void Transformer::run(const WSTR prompt, bool do_sample,
const WSTR system_prompt, const WSTR tail_prompt,
void Transformer::run(const WSTR prompt, void *output_buf, bool log_output) {
run(prompt, "", "", output_buf, log_output);
}

void Transformer::run(const WSTR prompt, const WSTR system_prompt,
const WSTR tail_prompt, void *output_buf,
bool log_output) {
if (!is_initialized) {
throw std::runtime_error(
Expand Down
100 changes: 21 additions & 79 deletions models/transformer.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,46 +24,19 @@
#define __TRANSFORMER_H__

#pragma once
#ifdef _WIN32
#define WIN_EXPORT __declspec(dllexport)
#define WSTR std::wstring
#define WCHAR_P wchar_t *
#else
#define WIN_EXPORT
#define WSTR std::string
#define WCHAR_P std::string &
#endif

#include <layer.h>
#include <limits.h>
#include <map>
#include <model.h>
#include <random>

#include <limits.h>

#include "json.hpp"
#include "performance_metrics.h"
#include <fstream>
#include <tokenizers_c.h>
#include <tokenizers_cpp.h>
#include <transformer_base.h>

namespace quick_dot_ai {

/*** ALIAS ****/
using LayerHandle = std::shared_ptr<ml::train::Layer>;
using ModelHandle = std::unique_ptr<ml::train::Model>;

using json = nlohmann::json;

/**
* @brief Model Type Enum
*/
enum class ModelType { MODEL, CAUSALLM, EMBEDDING, UNKNOWN };

/**
* @brief Transformer Class
*/
WIN_EXPORT class Transformer {
WIN_EXPORT class Transformer : virtual public TransformerBase {

public:
/**
Expand All @@ -84,36 +57,42 @@ WIN_EXPORT class Transformer {
/**
* @brief Initialize and Construct the Transformer model
*/
virtual void initialize();
void initialize() override;

/**
* @brief Load the model weights from a file
*/
virtual void load_weight(const std::string &weight_path);
void load_weight(const std::string &weight_path) override;

/**
* @brief Save the weight to a file
*/
virtual void save_weight(const std::string &weight_path);
void save_weight(const std::string &weight_path) override;

/**
* @brief Save the weight to a file with type conversion
* @param weight_path Path to save the weight file
* @param dtype Global target data type for all layers (NONE = keep original)
* @param layer_dtype_map Per-layer data type overrides (layer_name -> dtype)
*/
virtual void
save_weight(const std::string &weight_path,
ml::train::TensorDim::DataType dtype,
const std::map<std::string, ml::train::TensorDim::DataType>
&layer_dtype_map = {});
void save_weight(const std::string &weight_path,
ml::train::TensorDim::DataType dtype,
const std::map<std::string, ml::train::TensorDim::DataType>
&layer_dtype_map = {}) override;

/**
* @copydoc TransformerBase::run(const WSTR, void *, bool)
*/
void run(const WSTR prompt, void *output_buf = nullptr,
bool log_output = true) override;

/**
* @brief run the Transformer model
* @copydoc TransformerBase::run(const WSTR, const WSTR, const WSTR, void *, bool)
*/
virtual void run(const WSTR prompt, bool do_sample = false,
const WSTR system_prompt = "", const WSTR tail_prompt = "",
bool log_output = true);
void run(const WSTR prompt, const WSTR system_prompt = "",
const WSTR tail_prompt = "", void *output_buf = nullptr,
bool log_output = true) override;

/**
* @brief Get PerformanceMetrics
Expand Down Expand Up @@ -159,26 +138,12 @@ WIN_EXPORT class Transformer {
*/
virtual void registerCustomLayers();

/**
* @brief register Outputs
*/
bool is_initialized = false; /**< Flag to check if the model is initialized */
ModelHandle model;

/** tokenizer */
std::unique_ptr<tokenizers::Tokenizer> tokenizer;

unsigned int NUM_VOCAB;
int DIM;
int HEAD_DIM;
int INTERMEDIATE_SIZE;
int NUM_LAYERS;
bool USE_VOCAB_SELECTION;
bool TIE_WORD_EMBEDDINGS;
unsigned int MAX_SEQ_LEN;
int NUM_HEADS;
int NUM_KEY_VALUE_HEADS;
int NUM_TO_GENERATE;
std::string MODEL_TENSOR_TYPE;
std::string EMBEDDING_DTYPE; /** embedding dtype */
std::string FC_LAYER_DTYPE; /** custom_fc_lora */
Expand All @@ -190,8 +155,6 @@ WIN_EXPORT class Transformer {
float EMBEDDING_SCALE = 1.0f;
int GQA_SIZE;

unsigned int BATCH_SIZE; /**< Batch size for the model */
unsigned int INIT_SEQ_LEN; /**< Initial sequence length */
unsigned int MAX_POSITION_EMBEDDINGS; /**< max_position embeddings */
bool MEMORY_SWAP; /**< memory swap option */
unsigned int FSU_LOOKAHEAD;
Expand All @@ -201,28 +164,7 @@ WIN_EXPORT class Transformer {
// Performance metrics
PerformanceMetrics performance_metrics;
};
/**
 * @brief Load and parse a JSON document from a file
 * @param file_path Path to the JSON file
 * @return Parsed JSON object
 * @throws std::runtime_error if the file cannot be opened or if its content
 *         fails to parse as JSON
 */
inline json LoadJsonFile(const std::string &file_path) {
  std::ifstream file(file_path);
  if (!file.is_open()) {
    // Snapshot errno before building the message: the operands of the string
    // concatenation below are evaluated in an unspecified order, and any
    // library call made while concatenating is allowed to modify errno.
    const int open_errno = errno;
    throw std::runtime_error("Failed to open file: " + file_path +
                             " | Reason: " + std::strerror(open_errno));
  }

  try {
    json data;
    file >> data;
    return data;
  } catch (const json::parse_error &e) {
    // Re-throw with the file path attached so the caller can tell which
    // document was malformed.
    throw std::runtime_error("JSON parse error in " + file_path +
                             " | Details: " + e.what());
  }
}
} // namespace quick_dot_ai

#endif
Loading
Loading