diff --git a/.github/actions/override-submodules-from-pr/action.yml b/.github/actions/override-submodules-from-pr/action.yml deleted file mode 100644 index fa874672..00000000 --- a/.github/actions/override-submodules-from-pr/action.yml +++ /dev/null @@ -1,58 +0,0 @@ -name: "Override submodules from PR title" -description: >- - Parse the triggering pull request title for `[Depend on @]` - markers and replace matching submodules under `subprojects/` with - the tip of the referenced pull request head. Useful for running CI - against an in-flight dependency PR without committing a submodule bump. - -inputs: - pr-title: - description: "The pull request title to parse." - required: true - -runs: - using: composite - steps: - - name: Override submodules from PR dependency markers - shell: bash - env: - PR_TITLE: ${{ inputs.pr-title }} - run: | - set -euo pipefail - - if [[ -z "${PR_TITLE:-}" ]]; then - echo "No PR title provided; skipping submodule overrides." - exit 0 - fi - - # Extract every `[Depend on @]` marker from the title. - mapfile -t pairs < <( - grep -oE '\[Depend on [A-Za-z0-9._-]+ @[0-9]+\]' <<<"$PR_TITLE" \ - | sed -E 's/^\[Depend on ([A-Za-z0-9._-]+) @([0-9]+)\]$/\1 \2/' - ) - - if [[ ${#pairs[@]} -eq 0 ]]; then - echo "No [Depend on @] markers found in PR title." - exit 0 - fi - - for pair in "${pairs[@]}"; do - name="${pair% *}" - pr_num="${pair##* }" - sub="subprojects/${name}" - - if [[ ! -e "${sub}/.git" ]]; then - echo "::warning::Override requested for '${name}' but '${sub}' is not an initialized submodule; skipping." - continue - fi - - echo "::group::Override ${sub} with PR #${pr_num}" - ( - cd "${sub}" - git fetch --no-tags origin "pull/${pr_num}/head:pr-${pr_num}" - git checkout "pr-${pr_num}" - git log -1 --oneline - git submodule update --init --recursive - ) - echo "::endgroup::" - done diff --git a/.github/codeql/codeql-config.yml b/.github/codeql/codeql-config.yml deleted file mode 100644 index 6b3422e8..00000000 --- a/.github/codeql/codeql-config.yml +++ /dev/null @@ -1,11 +0,0 @@ -name: "Quick.AI CodeQL config" - -# Limit scanning to Quick.AI's own code. The `subprojects/` tree is a -# vendored copy of nntrainer (a separate project with its own CI) and -# contains thousands of C++/Python files that are not part of Quick.AI's -# attack surface. -paths-ignore: - - subprojects/** - - tokenizers-cpp-build/** - - build/** - - builddir/** diff --git a/.gitignore b/.gitignore index a97051ad..deea8ba5 100644 --- a/.gitignore +++ b/.gitignore @@ -1,43 +1,44 @@ -# Build outputs -build/ -builddir/ -*.o -*.a -*.so -*.dylib -*.dll -*.exe +# AI Agent directories +.opencode/plans/ -# Meson subproject cache / auto-generated wrap redirects -subprojects/packagecache/ -subprojects/.wraplock -subprojects/*.wrap -!subprojects/gtest.wrap -# Wrap-extracted source trees (fetched by `meson setup`, never checked in) -subprojects/googletest-*/ +# Proprietary model sources β€” internal only; must never be committed to the +# public remote. The public tree ships a fixed allow-list of model directories; +# every OTHER model subdirectory is ignored so a new proprietary model dropped +# into the tree cannot be `git add`ed by accident. (gitignore never hides files +# that are already tracked, so the allow-listed dirs below stay tracked.) +/src/models/*/ +!/src/models/qnn/ +/src/models/qnn/*/ +!/src/models/qnn/gemma4-e2b-qnn/ -# Test staging (download_qwen3_0.6b.sh) -.test_cache/ -models/qwen3-0.6b-w16a16/ +# Build directories (meson) +builddir/ +builddir_x86/ +builddir_android/ -# NNTrainer runtime log directory, created at test time -logs/ +# NDK build outputs +src/jni/libs/ +src/jni/obj/ +api/jni/libs/ +api/jni/obj/ +api-app/jni/libs/ +api-app/jni/obj/ -# Encoder staging -encoder/ -encoder-*.tar.gz -json.hpp +# Copied headers +include/ +!qnn/jni/qnn/PAL/include/ -# Android build outputs -jni/libs/ -jni/obj/ +# res (model/QNN resource dirs) +res/ +!Android/SampleTestAPP/src/main/res/ -# Editor junk -.vscode/ -.idea/ -.DS_Store -*.swp +cross/android-aarch64.cross -# Python -__pycache__/ -*.pyc +*.dex +*.bin +*.json +*.so +*.log +# Internal-only docs β€” must NOT ship in public releases (kept locally only) +docs/tasks/ +docs/superpowers/ diff --git a/.gitmodules b/.gitmodules index 9a1a18e8..d612595f 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +1,6 @@ -[submodule "subprojects/nntrainer"] - path = subprojects/nntrainer - url = https://github.com/eunjuyang/nntrainer.git +[submodule "nntrainer"] + path = nntrainer + url = https://github.sec.samsung.net/j2z0-lee/nntrainer +[submodule "xgrammar"] + path = xgrammar + url = https://github.com/mlc-ai/xgrammar.git diff --git a/Android/.gitignore b/Android/.gitignore new file mode 100644 index 00000000..aa724b77 --- /dev/null +++ b/Android/.gitignore @@ -0,0 +1,15 @@ +*.iml +.gradle +/local.properties +/.idea/caches +/.idea/libraries +/.idea/modules.xml +/.idea/workspace.xml +/.idea/navEditor.xml +/.idea/assetWizardSettings.xml +.DS_Store +/build +/captures +.externalNativeBuild +.cxx +local.properties diff --git a/Android/Architecture.md b/Android/Architecture.md new file mode 100644 index 00000000..703e3ecf --- /dev/null +++ b/Android/Architecture.md @@ -0,0 +1,168 @@ +# Android Architecture πŸ“± + +This document describes the current Android state of Quick.AI and separates it +from the planned REST/foreground-service layer that older documents described +as if it already existed. + +## βœ… Current Gradle Modules + +The Android build currently includes: + +```text +Android/ +β”œβ”€β”€ QuickDotAI/ # AAR module +└── SampleTestAPP/ # Direct sample app using the AAR +``` + +`Android/settings.gradle.kts` includes only `:QuickDotAI` and +`:SampleTestAPP`. + +## 🧱 QuickDotAI AAR + +`QuickDotAI` exposes the public Kotlin API in +`com.example.quickdotai`. + +Key files: + +| File | Role | +|---|---| +| `QuickDotAI.kt` | Public interface and `BackendResult` / `StreamSink` contracts | +| `Types.kt` | Serializable request/response DTOs, model enums, errors, metrics | +| `NativeQuickDotAI.kt` | Kotlin wrapper around one native `CausalLmHandle` | +| `NativeCausalLm.kt` | Low-level JNI declarations | +| `LiteRTLm.kt` | LiteRT-LM engine wrapper for the `gemma4` (`ModelIds.GEMMA4`) model | +| `NativeChatSession.kt` | Native chat-session helper | +| `LiteRTLmChatSession.kt` | LiteRT-LM chat-session helper | +| `ImageStore.kt` | Per-session image cache | +| `LlavaNextImageProcessor.kt` | Native multimodal preprocessing helper | +| `src/main/cpp/quickai_jni.cpp` | JNI bridge to `quick_dot_ai_api.h` | +| `src/main/cpp/CMakeLists.txt` | Builds `libquickai_jni.so` and links `libquick_dot_ai_api.so` | + +## πŸ”Œ Native Path + +`NativeQuickDotAI` owns one native handle: + +```text +NativeQuickDotAI + └── NativeCausalLm.ensureLoaded() + β”œβ”€β”€ System.loadLibrary("qnn_context") + └── System.loadLibrary("quickai_jni") + └── links/calls libquick_dot_ai_api.so +``` + +The native API surface is declared in `api/quick_dot_ai_api.h`. +The preferred calls are handle-based: + +- `loadModelHandle` +- `runModelHandleWithMessagesStreaming` +- `runModelHandleWithJsonStreaming` +- `runMultimodalHandleStreaming` +- `cancelModelHandle` +- `destroyModelHandle` + +## ModelCatalog + +Model selection in the AAR is driven by the `ModelCatalog` singleton. Models +are identified by string ids rather than an enum. + +### Seeding + +`ModelCatalog` is seeded on first access by calling `nativeQueryCatalog()` +through JNI, which delegates to `getModelCatalogJson()` in +`libquick_dot_ai_api.so`. Hardcoded LiteRT descriptors (e.g., `gemma4`) are +merged in at the Kotlin layer. + +### Key types + +| Type | Role | +|---|---| +| `enum class RuntimeKind { NATIVE, LITERT }` | Selects the engine path | +| `enum class Capability { STREAMING, MESSAGES_API, MULTIMODAL, TOOL_USE, EMBEDDING, MULTI_IMAGE }` | Per-model feature flags | +| `data class ModelDescriptor(id, family, displayName, runtime, backends, capabilities)` | Descriptor from the catalog | +| `object ModelIds` | String constants for well-known model ids | +| `object ModelCatalog` | Singleton: `all()`, `families()`, `selectable()`, `selectableFamilies()`, `runtimesFor(family)`, `backendsFor(family, rt)`, `resolve(family, rt, backend)`, `byId(id)` | + +### 3-axis cascading UI + +`SampleTestAPP` presents a 3-axis cascading UI: + +1. **Family** β€” populated from `ModelCatalog.selectableFamilies()` +2. **Runtime chip row** β€” populated from `ModelCatalog.runtimesFor(selectedFamily)` +3. **Backend chip row** β€” populated from `ModelCatalog.backendsFor(selectedFamily, selectedRuntime)` + +The app lists only **selectable** (generative) models. Embedding-only models +such as `tiny-bert` β€” which expose only the `EMBEDDING` capability and have no +public output path β€” are filtered out by `selectableFamilies()`. They remain in +the AAR catalog and are still reachable through `ModelCatalog.all()` / +`ModelCatalog.byId(...)`. + +The resolved descriptor is obtained via `ModelCatalog.resolve(family, runtime, backend)` +and passed directly to `createEngine()`. + +### Engine factory + +```kotlin +QuickDotAI.createEngine(context, descriptor: ModelDescriptor): QuickDotAI +``` + +`createEngine` dispatches to `NativeQuickDotAI` (for `RuntimeKind.NATIVE`) or +`LiteRTLm` (for `RuntimeKind.LITERT`) based on `descriptor.runtime`. + +### LoadModelRequest + +`LoadModelRequest.modelId` is a `String` catalog id. The cache key is +`"$modelId:${quantization.name}"`. The JNI call dispatched on load is +`loadModelHandleByNameNative`. + +## πŸŒ— LiteRT Runtime Path + +`LiteRTLm` is selected for the `gemma4` (`ModelIds.GEMMA4`) model and takes a `.litertlm` file path +through `LoadModelRequest.modelPath`. `visionBackend != null` enables +multimodal calls for engines/models that support image input. + +## 🧡 Threading Model + +A `QuickDotAI` instance is not internally thread-safe. Host apps should drive a +loaded engine from one worker thread. `SampleTestAPP` follows this pattern with +a background dispatcher. + +Streaming callbacks are delivered to the caller-provided `StreamSink`. +Apps that update UI must marshal callbacks to the main thread. + +## πŸ§ͺ SampleTestAPP + +`SampleTestAPP` is the current runnable Android sample. It links the +`:QuickDotAI` module directly; it does not start a REST service and does not +communicate over sockets. + +## πŸ—ΊοΈ Planned Service Layer + +The following pieces are design targets, not current Gradle modules: + +| Planned component | Status | +|---|---| +| `LauncherApp` foreground-service bootstrap UI | Planned | +| `QuickAIService` remote foreground service | Planned | +| NanoHTTPD loopback REST server | Planned | +| `RequestDispatcher`, `ModelRegistry`, `ModelWorker` | Planned | +| Standalone REST client app | Planned | + +When implemented, the service layer should wrap the same `QuickDotAI` AAR +contract rather than inventing a separate model API. + +## πŸ“¦ Packaging + +`apk-build-install.sh` performs the current full Android workflow: + +1. Build native libraries with `./build.sh --platform=android --enable-qnn --clean`. +2. Install/copy native shared libraries through `apk_install_android.sh`. +3. Copy `.so` files into `Android/QuickDotAI/prebuilt_libs/`. +4. Run Gradle install for `:SampleTestAPP`. + +Set `NDK_ROOT` inside `apk-build-install.sh` before using it on a new machine. + +## πŸ“Ž Related Docs + +- [QuickDotAI AAR API](QuickDotAI/README.md) +- [Android Native Async & Streaming](AsyncAndStreaming.md) +- [Main README](../README.md) diff --git a/Android/AsyncAndStreaming.md b/Android/AsyncAndStreaming.md new file mode 100644 index 00000000..bfd8b43c --- /dev/null +++ b/Android/AsyncAndStreaming.md @@ -0,0 +1,98 @@ +# Android Native Async & Streaming πŸ”„ + +Quick.AI streaming is synchronous at the native C boundary and asynchronous at +the host-app boundary. The native call blocks the worker thread while invoking a +callback for each token delta; the app decides how to dispatch those deltas to +UI or transport layers. + +## 🧭 Scope + +This document covers the current `QuickDotAI` AAR path: + +```text +QuickDotAI.kt + └── NativeQuickDotAI.kt + └── NativeCausalLm.kt + └── quickai_jni.cpp + └── quick_dot_ai_api.h / libquick_dot_ai_api.so +``` + +It does not describe the planned REST/foreground-service layer. + +## 🧡 Streaming Contract + +Native streaming functions are synchronous: + +```c +ErrorCode runModelHandleStreaming( + CausalLmHandle handle, + const char *inputTextPrompt, + CausalLmTokenCallback callback, + void *user_data); +``` + +While the function runs, it calls: + +```c +typedef int (*CausalLmTokenCallback)(const char *delta, void *user_data); +``` + +Returning `0` continues generation. Returning non-zero requests cooperative +cancellation at the next token boundary. + +## πŸ”Œ JNI Bridge + +`quickai_jni.cpp` converts native token callbacks into Kotlin listener calls. +Because callbacks run on the same thread that entered JNI, the bridge can use +the current `JNIEnv *` without attaching a new thread. + +Kotlin then forwards deltas to `StreamSink`: + +```kotlin +interface StreamSink { + fun onDelta(text: String) + fun onReasoningDelta(text: String) {} + fun onDone() + fun onError(error: QuickAiError, message: String?) +} +``` + +## 🧱 QuickDotAI Methods + +Current streaming methods include: + +| Method | Input shape | +|---|---| +| `runModelHandleWithMessagesStreaming()` | `List` | +| `runModelHandleWithJsonStreaming()` | OpenAI-style JSON string | +| `runMultimodalHandleWithMessagesStreaming()` | OpenAI-style messages with image parts | +| `runMultimodalHandleStreaming()` | `List` | +| `runChatModelHandleStreaming()` | Active chat session + text | +| `runChatMultimodalHandleStreaming()` | Active chat session + image/text parts | + +The removed flat methods (`runStreaming`, `runWithMessagesStreaming`, and +`chatRunStreaming`) should not be used in new docs or app code. + +## 🚦 Cancellation + +- `QuickDotAI.cancel()` forwards to the native handle cancel path when the + engine supports it. +- `QuickDotAI.chatCancel()` cancels the active chat session. +- Native cancellation is cooperative and may stop at the next generated token. +- `LiteRTLm` uses its Kotlin-side cancellation flag/session logic. + +## βœ… Failure Semantics + +Each streaming call should emit exactly one terminal sink event: + +- `onDone()` on success +- `onError(error, message)` on failure + +Native non-zero `ErrorCode` values are mapped through `QuickAiError.fromNativeCode`. + +## πŸ“Ž Related Docs + +- [QuickDotAI AAR API](QuickDotAI/README.md) +- [Android Architecture](Architecture.md) +- [Chat and OpenAI Usage Examples](../docs/ChatAndOpenAIUsage.md) +- [C API Reference](../api/README.md) diff --git a/Android/QuickDotAI/.gitignore b/Android/QuickDotAI/.gitignore new file mode 100644 index 00000000..aa724b77 --- /dev/null +++ b/Android/QuickDotAI/.gitignore @@ -0,0 +1,15 @@ +*.iml +.gradle +/local.properties +/.idea/caches +/.idea/libraries +/.idea/modules.xml +/.idea/workspace.xml +/.idea/navEditor.xml +/.idea/assetWizardSettings.xml +.DS_Store +/build +/captures +.externalNativeBuild +.cxx +local.properties diff --git a/Android/QuickDotAI/README.md b/Android/QuickDotAI/README.md new file mode 100644 index 00000000..f0eb8013 --- /dev/null +++ b/Android/QuickDotAI/README.md @@ -0,0 +1,234 @@ +# QuickDotAI AAR API πŸ“± + +`QuickDotAI` is the Android-facing API for Quick.AI. It provides one Kotlin +interface over two engine implementations: + +- `NativeQuickDotAI`: JNI path for nntrainer / QNN models, backed by + `libquickai_jni.so` and the native `quick_dot_ai_api.h` entry points. +- `LiteRTLm`: LiteRT-LM path for Gemma-family `.litertlm` models. + +The current Gradle build includes `:QuickDotAI` and `:SampleTestAPP`. + +## πŸ“¦ Dependency + +```kotlin +dependencies { + implementation(project(":QuickDotAI")) +} +``` + +Only `arm64-v8a` prebuilt native libraries are currently supported. + +## 🧭 API Surface + +Package: `com.example.quickdotai` + +```kotlin +interface QuickDotAI { + val kind: String + val architecture: String? + val chatSessionId: String? + + fun load(req: LoadModelRequest): BackendResult + fun unload(): BackendResult + fun metrics(): BackendResult + fun cancel() + fun close() + + fun runModelHandleWithMessagesStreaming( + messages: List, + sink: StreamSink + ): BackendResult + + fun runMultimodalHandleWithMessagesStreaming( + messages: List, + sink: StreamSink + ): BackendResult + + fun runModelHandleWithJsonStreaming( + jsonRequest: String, + sink: StreamSink + ): BackendResult + + fun runMultimodalHandle(parts: List): BackendResult + + fun runMultimodalHandleStreaming( + parts: List, + sink: StreamSink + ): BackendResult + + fun openChatSession( + config: QuickAiChatSessionConfig? = null + ): BackendResult + + fun closeChatSession(): BackendResult + + fun runChatModelHandleStreaming( + text: String, + sink: StreamSink + ): BackendResult + + fun runChatMultimodalHandleStreaming( + parts: List, + sink: StreamSink + ): BackendResult + + fun chatRebuild(messages: List): BackendResult + fun chatCancel() +} +``` + +Removed APIs: `run()`, `runStreaming()`, `runWithMessages()`, +`runWithMessagesStreaming()`, `chatRun()`, and `chatRunStreaming()`. + +## πŸ€– Engine Selection + +Use the `createEngine` factory with a `ModelDescriptor` from `ModelCatalog`. It +picks the engine from the descriptor's `runtime`: + +```kotlin +val descriptor = ModelCatalog.byId(ModelIds.GEMMA4) ?: return +val engine: QuickDotAI = createEngine(applicationContext, descriptor) +// RuntimeKind.LITERT -> LiteRTLm, RuntimeKind.NATIVE -> NativeQuickDotAI +``` + +`gemma4` (LiteRT) is Kotlin-only and never crosses the JNI boundary. Native +model ids are loaded through `loadModelHandleByName()` in `quick_dot_ai_api.h`. + +## πŸ’¬ OpenAI Message Streaming + +Use `runModelHandleWithMessagesStreaming()` for OpenAI-style message lists and +`runModelHandleWithJsonStreaming()` for full OpenAI JSON requests containing +`tools` or legacy `functions`. + +End-to-end Chat tab and OpenAI tab examples live in +[`../../docs/ChatAndOpenAIUsage.md`](../../docs/ChatAndOpenAIUsage.md). + +## πŸ–ΌοΈ Multimodal Usage + +LiteRT-LM multimodal usage requires `LoadModelRequest.visionBackend`. +Native multimodal usage requires a native model handle whose config loads the +expected vision encoder + LLM sub-models. + +```kotlin +engine.load( + LoadModelRequest( + modelId = ModelIds.GEMMA4, + backend = BackendType.GPU, + visionBackend = BackendType.GPU, + modelPath = "/sdcard/Download/aistudio-mobile/models/gemma-4-E2B-it/gemma-4-E2B-it.litertlm" + ) +) + +engine.runMultimodalHandleWithMessagesStreaming( + listOf( + QuickAiChatMessage( + role = QuickAiChatRole.USER, + parts = listOf( + PromptPart.ImageFile("/sdcard/photo.jpg"), + PromptPart.Text("Describe this picture.") + ) + ) + ), + sink +) +``` + +## 🧡 Chat Sessions + +Chat sessions keep backend-managed conversation state. Use +`openChatSession()` before `runChatModelHandleStreaming()` or +`runChatMultimodalHandleStreaming()`, then call `chatRebuild()` or +`closeChatSession()` when the conversation state changes or ends. Only one chat +session may be active per engine instance. + +See [`../../docs/ChatAndOpenAIUsage.md`](../../docs/ChatAndOpenAIUsage.md) for +complete session examples. + +## 🧱 Core Types + +```kotlin +data class LoadModelRequest( + val backend: BackendType = BackendType.GPU, + val modelId: String, + val quantization: QuantizationType = QuantizationType.W4A32, + val modelPath: String? = null, + val visionBackend: BackendType? = null, + val cacheDir: String? = null, + val maxNumTokens: Int? = null, + val nativeLibDir: String? = null, + val modelBasePath: String? = null, + val htpBackendConfigPath: String? = null, +) + +enum class BackendType { CPU, GPU, NPU } + +// Model ids are plain Strings. Well-known ids are exposed as constants +// (see ModelCatalog.kt); the live list comes from ModelCatalog / the native +// catalog, so the AAR is not recompiled when the model list changes. +object ModelIds { + const val QWEN3_0_6B = "qwen3-0.6b" + const val QWEN3_1_7B_Q40 = "qwen3-1.7b-q40" + const val TINY_BERT = "tiny-bert" + const val FUNCTION_GEMMA = "function-gemma" + const val GEMMA4 = "gemma4" // LiteRT only + const val GEMMA4_CPU = "gemma4-cpu" + const val GEMMA4_E2B_QNN = "gemma4-e2b-qnn" + const val VJEPA_QNN = "vjepa-qnn" // V-JEPA multi-image (QNN) +} + +enum class QuantizationType { UNKNOWN, W4A32, W16A16, W8A16, W32A32 } + +sealed class PromptPart { + data class Text(val text: String) : PromptPart() + data class ImageFile(val absolutePath: String) : PromptPart() + data class ImageBytes(val bytes: ByteArray) : PromptPart() + // Pre-processed CHW pixel values, used by models like V-JEPA that take + // externally preprocessed multi-image input. + data class PreprocessedPixels( + val pixelValues: FloatArray, + val numPatches: Int, + val numImages: Int, + val patchesPerImage: IntArray, + val imageHeights: IntArray, + val imageWidths: IntArray + ) : PromptPart() +} + +data class QuickAiChatMessage( + val role: QuickAiChatRole, + val parts: List +) + +enum class QuickAiChatRole { SYSTEM, USER, ASSISTANT } + +interface StreamSink { + fun onDelta(text: String) + fun onReasoningDelta(text: String) {} + fun onDone() + fun onError(error: QuickAiError, message: String?) +} +``` + +See `Types.kt` for the full DTO set, including `QuickAiChatSessionConfig`, +sampling options, error codes, and metrics. + +For native QNN models, `htpBackendConfigPath` points to +`htp_backend_ext_config.json`. Absolute paths are used as-is. Relative paths are +resolved from the app external files directory, so +`"configs/htp_backend_ext_config.json"` resolves to +`/configs/htp_backend_ext_config.json`. When omitted, +`NativeQuickDotAI` uses `/htp_backend_ext_config.json`. + +## βœ… Rules + +- Call `load()` before any inference call. +- Drive each `QuickDotAI` instance from one worker thread. +- Call `close()` when finished; it closes any active chat session. +- Pass `nativeLibDir` for native/QNN models when the host app can provide + `applicationInfo.nativeLibraryDir`. +- Pass `modelBasePath` for native models when model files live outside the + native default path. +- Pass `htpBackendConfigPath` for QNN models when + `htp_backend_ext_config.json` lives outside the app external files root. +- Pass `modelPath` for `LiteRTLm` / `GEMMA4` models. diff --git a/Android/QuickDotAI/build.gradle.kts b/Android/QuickDotAI/build.gradle.kts new file mode 100644 index 00000000..2fe903c3 --- /dev/null +++ b/Android/QuickDotAI/build.gradle.kts @@ -0,0 +1,139 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright (C) 2026 Samsung Electronics Co., Ltd. All Rights Reserved. +// +// QuickDotAI β€” reusable AAR that bundles the QuickDotAI interface and +// both concrete implementations (LiteRTLm + NativeQuickDotAI) plus the +// JNI shim (libquickai_jni.so) and the CausalLM prebuilt shared +// libraries. Third-party apps can depend on this AAR to run on-device +// LLMs without linking QuickAIService or any of LauncherApp's REST +// plumbing. + +plugins { + alias(libs.plugins.android.library) + alias(libs.plugins.kotlin.android) + alias(libs.plugins.kotlin.serialization) +} + +// Copies the flat prebuilt .so files from QuickDotAI/prebuilt_libs/ into an +// ABI-nested directory (build/generated/jniLibs/arm64-v8a/) so that Android +// Gradle's standard jniLibs machinery can bundle them into the AAR. +val prebuiltNativeLibsDir = + layout.buildDirectory.dir("generated/jniLibs/arm64-v8a") + +val copyPrebuiltNativeLibs = tasks.register("copyPrebuiltNativeLibs") { + from(project.file("prebuilt_libs")) + include("*.so") + include("htp_backend_ext_config.json") + into(prebuiltNativeLibsDir) +} + +android { + namespace = "com.example.quickdotai" + compileSdk = 36 + + packaging { + jniLibs.useLegacyPackaging = true + } + + + defaultConfig { + minSdk = 33 + + ndk { + // Only arm64-v8a is supported by the prebuilt libcausallm_api.so. + abiFilters += listOf("arm64-v8a") + } + + externalNativeBuild { + cmake { + cppFlags += "-std=c++17 -frtti -fexceptions" + } + } + + consumerProguardFiles("consumer-rules.pro") + } + + externalNativeBuild { + cmake { + path = file("src/main/cpp/CMakeLists.txt") + version = "3.22.1" + } + } + + sourceSets { + getByName("main") { + // Pick up the generated/jniLibs//*.so tree produced by + // copyPrebuiltNativeLibs above, alongside any hand-placed files + // in src/main/jniLibs/. + // + // `buildDir` getter was deprecated in Gradle 8 and removed in + // Gradle 9; use the Provider-based layout.buildDirectory API + // so this file is forward-compatible with Gradle 9 if we + // ever roll the wrapper back up. + jniLibs.srcDirs( + "src/main/jniLibs", + layout.buildDirectory.dir("generated/jniLibs").get().asFile + ) + } + } + + buildTypes { + release { + isMinifyEnabled = false + proguardFiles( + getDefaultProguardFile("proguard-android-optimize.txt"), + "consumer-rules.pro" + ) + } + } + compileOptions { + sourceCompatibility = JavaVersion.VERSION_17 + targetCompatibility = JavaVersion.VERSION_17 + } + kotlinOptions { + jvmTarget = "17" + // LiteRT-LM 0.10.0 (the version our mirror serves) was compiled + // with Kotlin 2.3.0 but our compiler is 2.2.21; the flag tells + // kotlinc to accept the newer metadata stamp on the LiteRT-LM + // AAR this module directly links against. See + // libs.versions.toml for the full story. + freeCompilerArgs += "-Xskip-metadata-version-check" + } +} + +// The merge*JniLibFolders task reads android.sourceSets.main.jniLibs and +// stages the native libraries for packaging into the AAR, so make it +// depend on the copy task. ExternalNativeBuild also benefits because the +// CMake link step reads libcausallm_api.so directly from prebuilt_libs. +tasks.matching { + it.name.startsWith("merge") && it.name.endsWith("JniLibFolders") +}.configureEach { + dependsOn(copyPrebuiltNativeLibs) +} +tasks.matching { it.name.startsWith("externalNativeBuild") }.configureEach { + dependsOn(copyPrebuiltNativeLibs) +} + +dependencies { + // kotlinx.serialization is exposed as an `api` dependency because the + // public types (ModelId, BackendType, LoadModelRequest, …) carry + // @Serializable annotations so consumers that want to JSON-ify them + // can do so without pulling the runtime in themselves. + api(libs.kotlinx.serialization.json) + + // LiteRT-LM is the engine used by LiteRTLm.kt for Gemma-family models. + // Exposed as `api` so consumers don't have to redeclare it. + // + // Pinned to an explicit version via the version catalog instead of + // `latest.release`: dynamic versions are non-deterministic (they + // resolve differently depending on what each environment's Maven + // mirror happens to cache) and they caused a hard failure earlier + // when one mirror served 0.10.0 as "latest" while our Kotlin + // compiler was pinned to a version that could not read 0.10.0's + // metadata stamp. See gradle/libs.versions.toml for the rationale + // behind the exact pin. + api(libs.litertlm.android) + + // AndroidX Core for createBitmap and other utility functions + implementation("androidx.core:core-ktx:1.12.0") +} diff --git a/Android/QuickDotAI/consumer-rules.pro b/Android/QuickDotAI/consumer-rules.pro new file mode 100644 index 00000000..ce1ff8de --- /dev/null +++ b/Android/QuickDotAI/consumer-rules.pro @@ -0,0 +1,31 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (C) 2026 Samsung Electronics Co., Ltd. All Rights Reserved. +# +# QuickDotAI consumer ProGuard rules. These rules are automatically +# applied to any app that depends on the QuickDotAI AAR. + +# Keep all JNI entry points β€” they are called from native code and +# renaming them would break System.loadLibrary + external symbols. +-keepclasseswithmembernames class com.example.quickdotai.NativeCausalLm { + native ; +} +-keep class com.example.quickdotai.NativeCausalLm$* { *; } +-keep class com.example.quickdotai.NativeCausalLm { *; } + +# Keep the public QuickDotAI surface so consumers can reference it by +# name after R8 shrinks their app. +-keep class com.example.quickdotai.QuickDotAI { *; } +-keep interface com.example.quickdotai.QuickDotAI { *; } +-keep class com.example.quickdotai.LiteRTLm { *; } +-keep class com.example.quickdotai.NativeQuickDotAI { *; } +-keep class com.example.quickdotai.StreamSink { *; } +-keep interface com.example.quickdotai.StreamSink { *; } +-keep class com.example.quickdotai.BackendResult** { *; } +-keep class com.example.quickdotai.LoadModelRequest { *; } +-keep class com.example.quickdotai.PerformanceMetrics { *; } +-keep class com.example.quickdotai.PromptPart { *; } +-keep class com.example.quickdotai.PromptPart$* { *; } +-keepclassmembers class com.example.quickdotai.ModelId { *; } +-keepclassmembers class com.example.quickdotai.BackendType { *; } +-keepclassmembers class com.example.quickdotai.QuantizationType { *; } +-keepclassmembers class com.example.quickdotai.QuickAiError { *; } diff --git a/Android/QuickDotAI/prebuilt_libs/.gitkeep b/Android/QuickDotAI/prebuilt_libs/.gitkeep new file mode 100644 index 00000000..e69de29b diff --git a/Android/QuickDotAI/src/main/AndroidManifest.xml b/Android/QuickDotAI/src/main/AndroidManifest.xml new file mode 100644 index 00000000..d4788a90 --- /dev/null +++ b/Android/QuickDotAI/src/main/AndroidManifest.xml @@ -0,0 +1,31 @@ + + + + + + + + + + + + + + + + diff --git a/Android/QuickDotAI/src/main/cpp/CMakeLists.txt b/Android/QuickDotAI/src/main/cpp/CMakeLists.txt new file mode 100644 index 00000000..524b86fb --- /dev/null +++ b/Android/QuickDotAI/src/main/cpp/CMakeLists.txt @@ -0,0 +1,52 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (C) 2026 Samsung Electronics Co., Ltd. All Rights Reserved. +# +# QuickDotAI JNI shim build +# +# Produces libquickai_jni.so, which forwards JNI calls from Kotlin +# (com.example.quickdotai.NativeCausalLm) to the C entry points in +# quick_dot_ai_api.h. The causal_lm_api.so itself (plus its transitive +# dependencies libnntrainer.so, libccapi-nntrainer.so, libcausallm_core.so, +# libc++_shared.so) is shipped as a set of prebuilt shared libraries in +# QuickDotAI/prebuilt_libs/ β€” see build_api_lib.sh in Applications/CausalLM +# to regenerate them. API headers are bundled in src/main/cpp/include/ so +# the AAR module is fully self-contained. A Gradle copy task in +# QuickDotAI/build.gradle.kts copies the prebuilt .so files into the AAR's +# jniLibs/arm64-v8a at build time, so they are discoverable at runtime via +# System.loadLibrary from whichever app hosts the AAR. +cmake_minimum_required(VERSION 3.22.1) +project(quickai_jni LANGUAGES CXX) + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) + +# Path to the QuickDotAI API headers so we can #include "quick_dot_ai_api.h". +# Bundled locally in src/main/cpp/include/ so the AAR is self-contained. +set(CAUSALLM_API_DIR + "${CMAKE_CURRENT_SOURCE_DIR}/include") + +# Path to the prebuilt native libraries. Layout is flat: +# QuickDotAI/prebuilt_libs/lib*.so +# (only arm64-v8a is supported at the moment β€” see build_api_lib.sh). +set(PREBUILT_LIBS_DIR + "${CMAKE_CURRENT_SOURCE_DIR}/../../../prebuilt_libs") + +# Register the prebuilt causal_lm_api shared library so the linker can +# resolve its symbols when linking libquickai_jni.so. +add_library(quick_dot_ai_api SHARED IMPORTED) +set_target_properties(quick_dot_ai_api PROPERTIES + IMPORTED_LOCATION + "${PREBUILT_LIBS_DIR}/libquick_dot_ai_api.so") + +add_library(quickai_jni SHARED + quickai_jni.cpp) + + +target_include_directories(quickai_jni PRIVATE + ${CAUSALLM_API_DIR}) + +find_library(log-lib log) + +target_link_libraries(quickai_jni + quick_dot_ai_api + ${log-lib}) diff --git a/Android/QuickDotAI/src/main/cpp/include/callback_streamer.h b/Android/QuickDotAI/src/main/cpp/include/callback_streamer.h new file mode 100644 index 00000000..fa7dd77b --- /dev/null +++ b/Android/QuickDotAI/src/main/cpp/include/callback_streamer.h @@ -0,0 +1,75 @@ +// SPDX-License-Identifier: Apache-2.0 +/** + * Copyright (C) 2026 Samsung Electronics Co., Ltd. All Rights Reserved. + * + * @file callback_streamer.h + * @brief BaseStreamer implementation that forwards every delta to a + * user-supplied C function pointer. + * + * This is the streamer used by the JNI bridge in QuickAI: the Kotlin + * side hands the JNI entry point a listener object, and the JNI entry + * point wraps the listener in a CausalLmTokenCallback + user_data pair + * and pushes a CallbackStreamer onto its own stack frame for the + * duration of the call. + * + * See AsyncAndStreaming.md Β§3.2 at the repo root. + */ +#ifndef __QUICK_DOT_AI_CALLBACK_STREAMER_H__ +#define __QUICK_DOT_AI_CALLBACK_STREAMER_H__ + +#ifndef WIN_EXPORT +#ifdef _WIN32 +#define WIN_EXPORT __declspec(dllexport) +#else +#define WIN_EXPORT +#endif +#endif + +#include "streamer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * @brief Token callback signature. + * + * @param delta UTF-8 text produced for this token boundary. Valid + * only for the duration of the call β€” copy before + * retaining. + * @param user_data Opaque pointer passed through from the + * runModelHandleStreaming() caller. + * @return 0 to continue generation, non-zero to request cancellation. + */ +typedef int (*CausalLmTokenCallback)(const char *delta, void *user_data); + +/** + * @brief A BaseStreamer that forwards every put() to a + * CausalLmTokenCallback. + * + * Layout note: @c base MUST be the first member so that a + * `CallbackStreamer*` can be safely cast to `BaseStreamer*`. + */ +typedef struct { + BaseStreamer base; + CausalLmTokenCallback callback; + void *user_data; + int cancelled; /**< sticky: once set to non-zero, put() becomes a no-op. */ +} CallbackStreamer; + +/** + * @brief Initialize a CallbackStreamer in-place. Does not allocate. + * + * @param self Storage owned by the caller (typically stack). + * @param cb Callback to invoke for every delta. Must be non-NULL. + * @param user_data Opaque pointer forwarded to @c cb. + */ +WIN_EXPORT void callback_streamer_init(CallbackStreamer *self, + CausalLmTokenCallback cb, + void *user_data); + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif // __QUICK_DOT_AI_CALLBACK_STREAMER_H__ \ No newline at end of file diff --git a/Android/QuickDotAI/src/main/cpp/include/quick_dot_ai_api.h b/Android/QuickDotAI/src/main/cpp/include/quick_dot_ai_api.h new file mode 100644 index 00000000..6b6e7514 --- /dev/null +++ b/Android/QuickDotAI/src/main/cpp/include/quick_dot_ai_api.h @@ -0,0 +1,608 @@ +// SPDX-License-Identifier: Apache-2.0 +/** + * Copyright (C) 2026 Samsung Electronics Co., Ltd. All Rights Reserved. + * + * @file quick_dot_ai_api.h + * @date 20 Mar 2026 + * @brief C API for src (extension of CausalLM) + * + * This header is self-contained: if causal_lm_api.h has already + * been included its types are reused; otherwise fallback + * definitions are provided so that this single header is + * sufficient for application code. + * + * @see https://github.com/nntrainer/nntrainer + * @author Eunju Yang + * @bug No known bugs except for NYI items + */ +#ifndef __QUICK_DOT_AI_API_H__ +#define __QUICK_DOT_AI_API_H__ + +/* ── Extended model types (src additions) ────────────────────── */ +#ifdef __CAUSAL_LM_API_H__ +/* Model types already defined from causal_lm_api.h */ +#else /* causal_lm_api.h not included β€” provide full definitions */ + +#define __CAUSAL_LM_API_H__ + +#ifdef _WIN32 +#define WIN_EXPORT __declspec(dllexport) +#else +#define WIN_EXPORT +#endif + +#include "callback_streamer.h" +#include "streamer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +#include + +typedef enum { + CAUSAL_LM_ERROR_NONE = 0, + CAUSAL_LM_ERROR_INVALID_PARAMETER = 1, + CAUSAL_LM_ERROR_MODEL_LOAD_FAILED = 2, + CAUSAL_LM_ERROR_INFERENCE_FAILED = 3, + CAUSAL_LM_ERROR_NOT_INITIALIZED = 4, + CAUSAL_LM_ERROR_INFERENCE_NOT_RUN = 5, + CAUSAL_LM_ERROR_UNSUPPORTED = 6, + CAUSAL_LM_ERROR_UNKNOWN = 99 +} ErrorCode; + +typedef enum { + CAUSAL_LM_BACKEND_CPU = 0, + CAUSAL_LM_BACKEND_GPU = 1, + CAUSAL_LM_BACKEND_NPU = 2, +} BackendType; + +/** + * @deprecated T4: λͺ¨λΈ μ‹λ³„μ˜ 정본은 λ¬Έμžμ—΄ id (loadModelHandleByName). + * 이 enum은 κΈ°μ‘΄ 호좜자 ν˜Έν™˜μš© public-only compat shim. + * λͺ¨λΈμ€ μΉ΄νƒˆλ‘œκ·Έλ‘œ μžλ™ 등둝. + */ +typedef enum { + CAUSAL_LM_MODEL_QWEN3_0_6B = 0, + CAUSAL_LM_MODEL_QWEN3_1_7B_Q40 = 4, /* original ordinal preserved */ + CAUSAL_LM_MODEL_TINY_BERT = 8, /* original */ + CAUSAL_LM_MODEL_FUNCTION_GEMMA = 9, /* original */ + CAUSAL_LM_MODEL_GEMMA4_CPU = 11, /* original */ + CAUSAL_LM_MODEL_GEMMA4_E2B_QNN = 12, /* original */ + CAUSAL_LM_MODEL_VJEPA_QNN = 13 +} ModelType; + +typedef struct { + // Add configuration options here as needed + bool use_chat_template; /// < @brief Whether to apply chat template to input + bool debug_mode; /// < @brief Check model file validity during initialization + bool verbose; /// < @brief Whether to print output during generation + const char + *chat_template_name; /// < @brief Template name to select from array + /// (e.g., "default", "tool_use"). NULL for + /// "default". +} Config; + +WIN_EXPORT ErrorCode setOptions(Config config); + +typedef enum { + CAUSAL_LM_QUANTIZATION_UNKNOWN = 0, + CAUSAL_LM_QUANTIZATION_W4A32 = 1, + CAUSAL_LM_QUANTIZATION_W16A16 = 2, + CAUSAL_LM_QUANTIZATION_W8A16 = 3, + CAUSAL_LM_QUANTIZATION_W32A32 = 4, +} ModelQuantizationType; + +/** + * @brief Chat message structure for chat template formatting + * @note Compatible with HuggingFace apply_chat_template() format + */ +typedef struct { + const char *role; /**< Message role: "system", "user", or "assistant" */ + const char *content; /**< Message content text */ +} CausalLMChatMessage; + +/** + * @brief Load a model + * @param compute Backend compute type + * @param modeltype Model type + * @param quant_type Model quantization type + * @return ErrorCode + */ +WIN_EXPORT ErrorCode loadModel(BackendType compute, ModelType modeltype, + ModelQuantizationType quant_type, + const char *model_base_path); + +typedef struct { + unsigned int prefill_tokens; + double prefill_duration_ms; + unsigned int generation_tokens; + double generation_duration_ms; + double total_duration_ms; + double initialization_duration_ms; + size_t peak_memory_kb; +} PerformanceMetrics; + +WIN_EXPORT ErrorCode getPerformanceMetrics(PerformanceMetrics *metrics); + +WIN_EXPORT ErrorCode saveQnnKvCache(const char *cache_path); +WIN_EXPORT ErrorCode loadQnnKvCache(const char *cache_path); +WIN_EXPORT ErrorCode resetQnnKvCache(void); + +/** + * @brief Apply chat template to messages without running inference + * @param messages Array of chat messages with role and content + * @param num_messages Number of messages in the array + * @param add_generation_prompt Whether to append generation prompt at end + * @param formattedText Buffer to store formatted text (owned by the library) + * @return ErrorCode + */ +WIN_EXPORT ErrorCode applyChatTemplate(const CausalLMChatMessage *messages, + size_t num_messages, + bool add_generation_prompt, + const char **formattedText); +/*============================================================================ + * Handle-based API (for parallel multi-model execution) + * + * The non-handle API above operates on a single process-wide model instance + * protected by one global mutex, which serializes every call and prevents + * loading more than one model at a time. The handle-based API below lets a + * caller load several models simultaneously and run them in parallel from + * different threads, with per-handle state so that different handles never + * block each other. Each handle owns its own model, its own last-output + * buffer, and its own mutex. + * + * A single handle may internally carry multiple sub-models (e.g. vision + * encoder + LLM) when loaded from a top-level nntr_config.json that + * specifies "architectures" and "model_dirs" arrays. The single-model + * run API (runModelHandleWithMessages / runModelHandleStreaming) drives + *models[0] only; the multimodal API (runMultimodalHandle*) drives the full set. + * + * Typical usage: + * CausalLmHandle h = NULL; + * loadModelHandle(CAUSAL_LM_BACKEND_CPU, CAUSAL_LM_MODEL_QWEN3_0_6B, + * CAUSAL_LM_QUANTIZATION_W4A32, NULL, &h); + * const char *out = NULL; + * CausalLMChatMessage msg; + * msg.role = "user"; + * msg.content = "Hello"; + * runModelHandleWithMessages(h, &msg, 1, true, &out); + * // ... use out (owned by h, valid until the next run or destroy) ... + * destroyModelHandle(h); + *============================================================================*/ + +/** + * @brief Opaque handle to a loaded CausalLM model instance. + */ +typedef struct CausalLmModel *CausalLmHandle; + +/** + * @brief Load a model and return a newly-allocated handle. + * + * Calling this multiple times with different parameters returns independent + * handles, each with its own model state. The caller must eventually call + * destroyModelHandle on the returned handle to release resources. + * + * @param compute Backend compute type + * @param modeltype Model type enum + * @param quant_type Quantization type + * @param native_lib_dir Native library directory path (from Android + * ApplicationInfo.nativeLibraryDir). May be NULL. + * @param out_handle Out-parameter that receives the new handle on success + * @return ErrorCode + */ +WIN_EXPORT ErrorCode loadModelHandle(BackendType compute, ModelType modeltype, + ModelQuantizationType quant_type, + const char *native_lib_dir, + const char *model_base_path, + CausalLmHandle *out_handle); + +/** + * @brief Load model by string id (T4 catalog path). + * + * Looks up the descriptor from the registry by @p model_id, validates the + * backend, then loads via the same internal path as loadModelHandle. + * Returns CAUSAL_LM_ERROR_INVALID_PARAMETER if the id is unknown, the + * descriptor has no config_name, or the backend is not in backend_mask. + * + * @param compute Backend compute type + * @param model_id Catalog string id e.g. "Qwen3-0.6B" + * @param quant_type Quantization type + * @param native_lib_dir Native library directory path. May be NULL. + * @param model_base_path Base path for model files. May be NULL. + * @param out_handle Out-parameter receiving the new handle on success + * @return ErrorCode + */ +WIN_EXPORT ErrorCode loadModelHandleByName(BackendType compute, + const char *model_id, + ModelQuantizationType quant_type, + const char *native_lib_dir, + const char *model_base_path, + CausalLmHandle *out_handle); + +/** + * @brief Run inference on a specific handle. + * + * The returned outputText pointer is owned by the handle and remains valid + * until the next runModelHandleWithMessages call on the same handle or until + * the handle is destroyed. Different handles are safe to call concurrently from + * different threads; the same handle is serialized by its own internal + * mutex. + * + * Single-model API: drives models[0] only even when the handle was + * populated with multiple sub-models. Use runMultimodalHandleWithMessages for + * compositions such as vision-encoder + LLM. + * + * @param handle Handle returned by loadModelHandle + * @param messages Array of chat messages with role and content + * @param num_messages Number of messages in the array + * @param add_generation_prompt Whether to append generation prompt at end + * @param outputText Out-parameter that receives a pointer to the output + * @return ErrorCode + */ +WIN_EXPORT ErrorCode runModelHandleWithMessages( + CausalLmHandle handle, const CausalLMChatMessage *messages, + size_t num_messages, bool add_generation_prompt, const char **outputText); + +/** + * @brief Streaming inference with OpenAI message format on a specific handle. + * + * Format the messages array through the chat template, then drive + * generation token-by-token, invoking @p callback for each delta. + * Blocks on the invoking thread until generation finishes or an error + * occurs. Semantics are otherwise identical to runModelHandleStreaming. + * + * @param handle Handle returned by loadModelHandle + * @param messages Array of chat messages with role and content + * @param num_messages Number of messages in the array + * @param add_generation_prompt Whether to append generation prompt at end + * @param callback Token delta callback. Must be non-NULL. + * @param user_data Opaque pointer forwarded to callback + * @return ErrorCode + */ +WIN_EXPORT ErrorCode runModelHandleWithMessagesStreaming( + CausalLmHandle handle, const CausalLMChatMessage *messages, + size_t num_messages, bool add_generation_prompt, + CausalLmTokenCallback callback, void *user_data); + +WIN_EXPORT ErrorCode saveQnnKvCacheHandle(CausalLmHandle handle, + const char *cache_path); +WIN_EXPORT ErrorCode loadQnnKvCacheHandle(CausalLmHandle handle, + const char *cache_path); +WIN_EXPORT ErrorCode resetQnnKvCacheHandle(CausalLmHandle handle); + +/** + * @brief Retrieve performance metrics of the last run for a given handle. + * @param handle Handle returned by loadModelHandle + * @param metrics Pointer to a PerformanceMetrics struct to be filled + * @return ErrorCode + */ +WIN_EXPORT ErrorCode getPerformanceMetricsHandle(CausalLmHandle handle, + PerformanceMetrics *metrics); + +/** + * @brief Release all resources owned by a handle. + * + * Passing a NULL handle is a no-op and returns CAUSAL_LM_ERROR_NONE. + * + * @param handle Handle returned by loadModelHandle + * @return ErrorCode + */ +WIN_EXPORT ErrorCode destroyModelHandle(CausalLmHandle handle); + +/** + * @brief Request cancellation of an in-progress run on a handle. + * + * Sets the stop flag on the model, causing the token generation loop + * to exit at the next token boundary. Thread-safe: can be called from + * any thread (e.g., from a UI cancel button handler). + * + * If no run is in progress, this function is a no-op. + * + * @param handle Handle returned by loadModelHandle + * @return ErrorCode + */ +WIN_EXPORT ErrorCode cancelModelHandle(CausalLmHandle handle); + +/** + * @brief Unload the model from a handle without destroying the handle. + * + * Releases the model weights and internal state but keeps the handle + * struct alive. After a successful unload, the handle's initialized flag + * is cleared and subsequent run / metrics calls will return + * CAUSAL_LM_ERROR_NOT_INITIALIZED. The handle can be destroyed later + * with destroyModelHandle, or (in future) re-loaded. + * + * Passing a NULL handle is a no-op and returns CAUSAL_LM_ERROR_NONE. + * + * @param handle Handle returned by loadModelHandle + * @return ErrorCode + */ +WIN_EXPORT ErrorCode unloadModelHandle(CausalLmHandle handle); + +/** + * @brief Streaming counterpart of runModelHandle. + * + * Synchronously drives inference on @p handle and invokes @p callback + * once per decoded-token boundary with a UTF-8 delta string. The call + * blocks on the invoking thread until generation finishes, hits an EOS + * token, reaches NUM_TO_GENERATE, the callback returns non-zero (which + * requests cancellation at the next token boundary), or an error + * occurs. + * + * The @p delta pointer passed into the callback is owned by the + * streaming runtime and is only valid for the duration of the callback + * invocation. Callers that need to retain the text must copy it. + * + * After a successful return the handle's "last output" buffer holds + * the full concatenated generation (or the partial output on a + * cancelled run), so a subsequent getPerformanceMetricsHandle() call + * returns valid metrics and the same handle can be reused for another + * run β€” identical semantics to runModelHandleWithMessages. + * + * Streaming is currently only supported on models whose underlying + * C++ implementation derives from causallm::CausalLM (all the Qwen + * variants and Llama do; non-CausalLM models return + * CAUSAL_LM_ERROR_UNKNOWN). See AsyncAndStreaming.md Β§3.4 at the repo + * root for the full design. + * + * @param handle Handle returned by loadModelHandle. + * @param inputTextPrompt Input prompt (UTF-8, NUL-terminated). + * @param callback Token delta callback. Must be non-NULL. + * @param user_data Opaque pointer forwarded verbatim to the + * callback on every invocation. May be NULL. + * @return ErrorCode + */ +WIN_EXPORT ErrorCode runModelHandleStreaming(CausalLmHandle handle, + const char *inputTextPrompt, + CausalLmTokenCallback callback, + void *user_data); + +/** + * @brief Encode a single text prompt into a sentence-embedding vector using a + * handle whose models[0] is an embedding model (e.g. Ouro / "ouro"). + * + * On success, *out_embedding points to a freshly allocated array of *out_dim + * floats (the batch-0 embedding). The caller OWNS this buffer and MUST release + * it with freeEmbedding(). On any error, *out_embedding is set to NULL and + * *out_dim to 0. + * + * @param handle Handle from loadModelHandle / loadModelHandleByName + * @param text UTF-8 input text (NUL-terminated) + * @param out_embedding [out] receives a newly allocated float[*out_dim] + * @param out_dim [out] receives the embedding dimension + * @return ErrorCode. CAUSAL_LM_ERROR_UNSUPPORTED if models[0] is not an + * embedding (SentenceTransformer) model. + */ +WIN_EXPORT ErrorCode encodeModelHandle(CausalLmHandle handle, const char *text, + float **out_embedding, int *out_dim); + +/** + * @brief Release a buffer returned by encodeModelHandle(). + * @param embedding Pointer previously returned via out_embedding (may be NULL) + */ +WIN_EXPORT void freeEmbedding(float *embedding); + +/** + * @brief Run inference on a handle with a tool schema for constrained + * generation. + * + * @param handle Handle returned by loadModelHandle + * @param inputTextPrompt Input prompt text + * @param outputText Buffer to store output text (owned by the handle) + * @param tool_name Name of the tool (used to cache the grammar) + * @param tool_schema JSON schema for the tool output format + * @return ErrorCode + */ +WIN_EXPORT ErrorCode runModelHandleWithTool(CausalLmHandle handle, + const char *inputTextPrompt, + const char **outputText, + const char *tool_name, + const char *tool_schema); + +/*============================================================================ + * Multimodal API + * + * These functions extend the handle-based API to support image+text inputs. + * The pixel values are passed as preprocessed FloatArray (CHW format) from + * the Kotlin image processor (LlavaNextImageProcessor). + * + * The handle must have been loaded from a multi-model nntr_config.json + * (architectures[] + model_dirs[]) with at least [vision_encoder, llm]; + * a single-model handle returns CAUSAL_LM_ERROR_UNSUPPORTED. + * + * Vision Encoder integration is planned for future implementation. + * Currently these functions return CAUSAL_LM_ERROR_UNSUPPORTED as stubs + * once the multi-model precondition is satisfied. + *============================================================================*/ + +/** + * @brief Streaming multimodal inference on a specific handle. + * + * @param handle Handle returned by loadModelHandle + * @param prompt Text prompt (UTF-8, NUL-terminated) + * @param pixelValues Preprocessed image patches in CHW format + * @param numPatches Number of image patches + * @param originalHeight Original image height before preprocessing + * @param originalWidth Original image width before preprocessing + * @param callback Token delta callback. Must be non-NULL. + * @param user_data Opaque pointer forwarded to callback + * @return ErrorCode (CAUSAL_LM_ERROR_UNSUPPORTED until Vision Encoder + * implemented) + */ +WIN_EXPORT ErrorCode runMultimodalHandleStreaming( + CausalLmHandle handle, const char *prompt, const float *pixelValues, + int numPatches, int originalHeight, int originalWidth, + CausalLmTokenCallback callback, void *user_data); + +/** + * @brief Blocking multimodal inference with OpenAI message format on a specific + * handle. + * + * @param handle Handle returned by loadModelHandle + * @param messages Array of chat messages with role and content + * (text-only, image via pixelValues) + * @param num_messages Number of messages in the array + * @param add_generation_prompt Whether to append generation prompt at end + * @param pixelValues Preprocessed image patches in CHW format + * @param numPatches Number of image patches + * @param originalHeight Original image height before preprocessing + * @param originalWidth Original image width before preprocessing + * @param outputText Out-parameter that receives a pointer to the output + * @return ErrorCode (CAUSAL_LM_ERROR_UNSUPPORTED until Vision Encoder + * implemented) + */ +WIN_EXPORT ErrorCode runMultimodalHandleWithMessages( + CausalLmHandle handle, const CausalLMChatMessage *messages, + size_t num_messages, bool add_generation_prompt, const float *pixelValues, + int numPatches, int originalHeight, int originalWidth, + const char **outputText); + +/** + * @brief Streaming multimodal inference with OpenAI message format on a + * specific handle. + * + * Format the messages array through the chat template, run the vision + * encoder if needed, then drive LLM generation token-by-token invoking + * @p callback for each delta. Blocks on the invoking thread until + * generation finishes or an error occurs. + * + * @param handle Handle returned by loadModelHandle + * @param messages Array of chat messages with role and content + * (text-only, image via pixelValues) + * @param num_messages Number of messages in the array + * @param add_generation_prompt Whether to append generation prompt at end + * @param pixelValues Preprocessed image patches in CHW format + * @param numPatches Number of image patches + * @param originalHeight Original image height before preprocessing + * @param originalWidth Original image width before preprocessing + * @param callback Token delta callback. Must be non-NULL. + * @param user_data Opaque pointer forwarded to callback + * @return ErrorCode + */ +WIN_EXPORT ErrorCode runMultimodalHandleWithMessagesStreaming( + CausalLmHandle handle, const CausalLMChatMessage *messages, + size_t num_messages, bool add_generation_prompt, const float *pixelValues, + int numPatches, int originalHeight, int originalWidth, + CausalLmTokenCallback callback, void *user_data); + +/*============================================================================ + * Multi-image Multimodal API (V-JEPA) + * + * These functions extend the multimodal API to support multiple images + * (e.g. video frames for V-JEPA). The pixel values for all images are + * concatenated into a single flat array, with per-image metadata + * (patches per image, heights, widths) passed as separate arrays. + * + * The handle must have been loaded with CAUSAL_LM_MODEL_VJEPA_QNN or + * another multi-image-capable model type. + *============================================================================*/ + +/** + * @brief Streaming multi-image multimodal inference on a specific handle. + * + * Designed for models like V-JEPA that accept multiple preprocessed + * image frames (e.g. 16 video frames) as input. + * + * @param handle Handle returned by loadModelHandle + * @param prompt Text prompt (UTF-8, NUL-terminated) + * @param pixelValues Preprocessed image patches in CHW format + * (all images concatenated) + * @param numPatches Total number of image patches across all images + * @param numImages Number of images (e.g. 16 for V-JEPA) + * @param patchesPerImage Array of numImages ints: patches per image + * @param originalHeights Array of numImages ints: original height per image + * @param originalWidths Array of numImages ints: original width per image + * @param callback Token delta callback. Must be non-NULL. + * @param user_data Opaque pointer forwarded to callback + * @return ErrorCode + */ +WIN_EXPORT ErrorCode runMultimodalMultiImageHandleStreaming( + CausalLmHandle handle, const char *prompt, const float *pixelValues, + int numPatches, int numImages, const int *patchesPerImage, + const int *originalHeights, const int *originalWidths, + CausalLmTokenCallback callback, void *user_data); + +/** + * @brief Streaming multi-image multimodal inference with OpenAI message + * format on a specific handle. + * + * @param handle Handle returned by loadModelHandle + * @param messages Array of chat messages with role and content + * @param num_messages Number of messages in the array + * @param add_generation_prompt Whether to append generation prompt at end + * @param pixelValues Preprocessed image patches in CHW format + * (all images concatenated) + * @param numPatches Total number of image patches across all images + * @param numImages Number of images (e.g. 16 for V-JEPA) + * @param patchesPerImage Array of numImages ints: patches per image + * @param originalHeights Array of numImages ints: original height per image + * @param originalWidths Array of numImages ints: original width per image + * @param callback Token delta callback. Must be non-NULL. + * @param user_data Opaque pointer forwarded to callback + * @return ErrorCode + */ +WIN_EXPORT ErrorCode runMultimodalMultiImageHandleWithMessagesStreaming( + CausalLmHandle handle, const CausalLMChatMessage *messages, + size_t num_messages, bool add_generation_prompt, const float *pixelValues, + int numPatches, int numImages, const int *patchesPerImage, + const int *originalHeights, const int *originalWidths, + CausalLmTokenCallback callback, void *user_data); + +/*============================================================================ + * OpenAI JSON streaming API + * + * Accepts a JSON string in OpenAI format and processes it through the + * chat template. Supports messages, tools, functions, and all other + * fields recognized by minja chat template renderer. + * + * Example JSON input: + * { + * "messages": [ + * {"role": "developer", "content": "..."}, + * {"role": "user", "content": "..."} + * ], + * "tools": [ + * {"type": "function", "function": {"name": "call", "description": "..."}} + * ] + * } + *============================================================================*/ + +/** + * @brief Streaming inference with OpenAI JSON format. + * + * Parses the JSON request and applies the chat template, then drives + * generation token-by-token invoking @p callback for each delta. + * + * @param handle Handle returned by loadModelHandle + * @param jsonRequest OpenAI format JSON string (UTF-8, NUL-terminated) + * @param callback Token delta callback. Must be non-NULL. + * @param user_data Opaque pointer forwarded to callback + * @return ErrorCode + */ +WIN_EXPORT ErrorCode runModelHandleWithJsonStreaming( + CausalLmHandle handle, const char *jsonRequest, + CausalLmTokenCallback callback, void *user_data); + +/** + * @brief Return a JSON array of all registered ModelDescriptors. + * + * Returns a NUL-terminated UTF-8 string like: + * [{"id":"...","family":"...","display_name":"...","runtime":0, + * "backend_mask":0,"capabilities":0}, ...] + * + * The registry is empty until tasks that call + * quick_dot_ai::register_model_descriptor() are linked in. + * The returned pointer is valid until the next call to getModelCatalogJson(). + * + * @return const char* JSON array string (never NULL) + */ +WIN_EXPORT const char *getModelCatalogJson(void); + +#ifdef __cplusplus +} +#endif + +#endif /* __CAUSAL_LM_API_H__ */ + +#endif /* __QUICK_DOT_AI_API_H__ */ diff --git a/Android/QuickDotAI/src/main/cpp/include/streamer.h b/Android/QuickDotAI/src/main/cpp/include/streamer.h new file mode 100644 index 00000000..d661716e --- /dev/null +++ b/Android/QuickDotAI/src/main/cpp/include/streamer.h @@ -0,0 +1,94 @@ +// SPDX-License-Identifier: Apache-2.0 +/** + * Copyright (C) 2026 Samsung Electronics Co., Ltd. All Rights Reserved. + * + * @file streamer.h + * @brief Minimal C-callable base streamer used by the handle-based + * `runModelHandleStreaming` entry point in quick_dot_ai_api.h. + * + * This is intentionally a very thin vtable-based polymorphism layer so + * that: + * - the CausalLM inference loop can push decoded tokens through a + * single pointer, + * - concrete streamers (currently only CallbackStreamer) can be + * implemented in plain C without dragging C++ headers into the + * CausalLM internals, + * - the same mechanism is reusable from JNI callers (the JNI bridge + * instantiates a CallbackStreamer on the stack and lets the C API + * drive it). + * + * See AsyncAndStreaming.md Β§3.1 at the repo root for the full design + * rationale. + */ +#ifndef __QUICK_DOT_AI_STREAMER_H__ +#define __QUICK_DOT_AI_STREAMER_H__ + +#ifndef WIN_EXPORT +#ifdef _WIN32 +#define WIN_EXPORT __declspec(dllexport) +#else +#define WIN_EXPORT +#endif +#endif + +#ifdef __cplusplus +extern "C" { +#endif + +typedef struct BaseStreamer BaseStreamer; + +/** + * @brief Vtable for a BaseStreamer. + * + * Both function pointers may be NULL β€” the streamer_put / streamer_end + * helpers below are null-safe so the caller never has to check. + */ +typedef struct { + /** + * @brief Forward one UTF-8 delta string to the streamer. + * + * The pointer is only valid for the duration of the call; the + * streamer implementation must copy if it needs to retain the data. + * + * @return 0 to continue generation, non-zero to request cancellation + * at the next token boundary. + */ + int (*put)(BaseStreamer *self, const char *decoded_utf8); + + /** + * @brief Called exactly once after the last put, regardless of whether + * generation finished normally, was cancelled via the callback + * return value, or ended because an exception propagated out of + * the run loop. + */ + void (*end)(BaseStreamer *self); +} BaseStreamerVTable; + +/** + * @brief Base streamer. Concrete streamers embed this as their first + * field and set @c vtable to a static const instance of + * BaseStreamerVTable. + */ +struct BaseStreamer { + const BaseStreamerVTable *vtable; +}; + +/** + * @brief NULL-safe wrapper around the vtable's put() hook. Returns + * non-zero if the streamer requested cancellation. + */ +WIN_EXPORT int streamer_put(BaseStreamer *self, const char *decoded_utf8); + +/** + * @brief NULL-safe wrapper around the vtable's end() hook. Idempotent + * from the caller's perspective β€” concrete implementations + * should tolerate being called multiple times, but the CausalLM + * inference path calls this at most once. + */ +WIN_EXPORT void streamer_end(BaseStreamer *self); + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif // __QUICK_DOT_AI_STREAMER_H__ \ No newline at end of file diff --git a/Android/QuickDotAI/src/main/cpp/quickai_jni.cpp b/Android/QuickDotAI/src/main/cpp/quickai_jni.cpp new file mode 100644 index 00000000..bef07c45 --- /dev/null +++ b/Android/QuickDotAI/src/main/cpp/quickai_jni.cpp @@ -0,0 +1,1046 @@ +// SPDX-License-Identifier: Apache-2.0 +/* + * Copyright (C) 2026 Samsung Electronics Co., Ltd. All Rights Reserved. + * + * @file quickai_jni.cpp + * @brief JNI shim that forwards calls from Kotlin's + * com.example.quickdotai.NativeCausalLm object to the + * handle-based C entry points declared in + * Applications/CausalLM/api/quick_dot_ai_api.h. + * + * This file contains no business logic β€” only JNI marshalling: + * jstring <-> const char* + * jlong <-> CausalLmHandle + * ErrorCode + struct PerformanceMetrics -> Kotlin data classes. + * + * Higher-level concerns (per-model threading, FIFO queue, Gemma4 + * routing) live one level up in NativeQuickDotAI.kt and in the host + * app's worker (for QuickAIService that is ModelWorker / ModelRegistry; + * SampleTestAPP drives NativeQuickDotAI directly from its own + * background dispatcher). + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "quick_dot_ai_api.h" + +#define LOG_TAG "quickai_jni" +#define LOGE(...) __android_log_print(ANDROID_LOG_ERROR, LOG_TAG, __VA_ARGS__) + +namespace { + +// --------------------------------------------------------------------------- +// Cached global references to the Kotlin result data-classes so we can +// construct them back in JNI. Looked up once in JNI_OnLoad. +// --------------------------------------------------------------------------- + +struct JniCache { + jclass loadResultCls = nullptr; // NativeCausalLm$LoadResult + jmethodID loadResultCtor = nullptr; + + jclass runResultCls = nullptr; // NativeCausalLm$RunResult + jmethodID runResultCtor = nullptr; + + jclass metricsResultCls = nullptr; // NativeCausalLm$MetricsResult + jmethodID metricsResultCtor = nullptr; +}; + +JniCache g_cache; + +jclass find_global(JNIEnv *env, const char *name) { + jclass local = env->FindClass(name); + if (local == nullptr) { + LOGE("FindClass failed: %s", name); + return nullptr; + } + auto *global = reinterpret_cast(env->NewGlobalRef(local)); + env->DeleteLocalRef(local); + return global; +} + +} // namespace + +extern "C" JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM *vm, void * /*reserved*/) { + JNIEnv *env = nullptr; + if (vm->GetEnv(reinterpret_cast(&env), JNI_VERSION_1_6) != JNI_OK || + env == nullptr) { + return JNI_ERR; + } + + g_cache.loadResultCls = + find_global(env, "com/example/quickdotai/NativeCausalLm$LoadResult"); + if (g_cache.loadResultCls != nullptr) { + g_cache.loadResultCtor = + env->GetMethodID(g_cache.loadResultCls, "", "(IJ)V"); + } + + g_cache.runResultCls = + find_global(env, "com/example/quickdotai/NativeCausalLm$RunResult"); + if (g_cache.runResultCls != nullptr) { + g_cache.runResultCtor = env->GetMethodID(g_cache.runResultCls, "", + "(ILjava/lang/String;)V"); + } + + g_cache.metricsResultCls = + find_global(env, "com/example/quickdotai/NativeCausalLm$MetricsResult"); + if (g_cache.metricsResultCls != nullptr) { + g_cache.metricsResultCtor = + env->GetMethodID(g_cache.metricsResultCls, "", "(IIDIDDDJ)V"); + } + + return JNI_VERSION_1_6; +} + +// --------------------------------------------------------------------------- +// setOptions +// --------------------------------------------------------------------------- +extern "C" JNIEXPORT jint JNICALL +Java_com_example_quickdotai_NativeCausalLm_setOptionsNative( + JNIEnv * /*env*/, jobject /*thiz*/, jboolean useChatTemplate, + jboolean debugMode, jboolean verbose) { + Config cfg; + cfg.use_chat_template = (useChatTemplate == JNI_TRUE); + cfg.debug_mode = (debugMode == JNI_TRUE); + cfg.verbose = (verbose == JNI_TRUE); + return static_cast(setOptions(cfg)); +} + +// --------------------------------------------------------------------------- +// chdir +// +// The C API in quick_dot_ai_api.cpp hardcodes the model discovery prefix to +// the relative path "./models/-" (see resolve_model_path()), +// which ties model lookup to the process's current working directory. +// Android apps start with cwd="/" so the only way to point the loader at +// an app-owned directory is to chdir(2) the process before calling +// loadModelHandle. This helper exposes that chdir to Kotlin as a thin +// wrapper β€” returning 0 on success or the POSIX errno on failure, which +// NativeQuickDotAI surfaces as CAUSAL_LM_ERROR_MODEL_LOAD_FAILED. +// --------------------------------------------------------------------------- +extern "C" JNIEXPORT jint JNICALL +Java_com_example_quickdotai_NativeCausalLm_chdirNative(JNIEnv *env, + jobject /*thiz*/, + jstring pathJ) { + if (pathJ == nullptr) { + return EINVAL; + } + const char *path = env->GetStringUTFChars(pathJ, nullptr); + if (path == nullptr) { + return ENOMEM; + } + int rc = chdir(path); + int err = (rc == 0) ? 0 : errno; + env->ReleaseStringUTFChars(pathJ, path); + return static_cast(err); +} + +// --------------------------------------------------------------------------- +// loadModelHandle +// --------------------------------------------------------------------------- +extern "C" JNIEXPORT jobject JNICALL +Java_com_example_quickdotai_NativeCausalLm_loadModelHandleNative( + JNIEnv *env, jobject /*thiz*/, jint backendOrdinal, jint modelOrdinal, + jint quantOrdinal, jstring nativeLibDirJ, jstring modelBasePathJ, + jstring htpBackendConfigPathJ) { + const char *native_lib_dir = nullptr; + if (nativeLibDirJ != nullptr) { + native_lib_dir = env->GetStringUTFChars(nativeLibDirJ, nullptr); + } + + const char *model_base_path = nullptr; + if (modelBasePathJ != nullptr) { + model_base_path = env->GetStringUTFChars(modelBasePathJ, nullptr); + } + + const char *htp_backend_config_path = nullptr; + if (htpBackendConfigPathJ != nullptr) { + htp_backend_config_path = + env->GetStringUTFChars(htpBackendConfigPathJ, nullptr); + } + + const char *previous_htp_backend_config_path = + getenv("QUICK_DOT_AI_QNN_BACKEND_EXT_CONFIG_PATH"); + const bool had_previous_htp_backend_config_path = + previous_htp_backend_config_path != nullptr; + std::string previous_htp_backend_config_path_value; + if (had_previous_htp_backend_config_path) { + previous_htp_backend_config_path_value = previous_htp_backend_config_path; + } + const bool has_htp_backend_config_path = + htp_backend_config_path != nullptr && htp_backend_config_path[0] != '\0'; + if (has_htp_backend_config_path) { + setenv("QUICK_DOT_AI_QNN_BACKEND_EXT_CONFIG_PATH", htp_backend_config_path, + 1); + } + + CausalLmHandle handle = nullptr; + ErrorCode ec = + loadModelHandle(static_cast(backendOrdinal), + static_cast(modelOrdinal), + static_cast(quantOrdinal), + native_lib_dir, model_base_path, &handle); + + if (has_htp_backend_config_path) { + if (had_previous_htp_backend_config_path) { + setenv("QUICK_DOT_AI_QNN_BACKEND_EXT_CONFIG_PATH", + previous_htp_backend_config_path_value.c_str(), 1); + } else { + unsetenv("QUICK_DOT_AI_QNN_BACKEND_EXT_CONFIG_PATH"); + } + } + + if (native_lib_dir != nullptr && nativeLibDirJ != nullptr) { + env->ReleaseStringUTFChars(nativeLibDirJ, native_lib_dir); + } + if (model_base_path != nullptr && modelBasePathJ != nullptr) { + env->ReleaseStringUTFChars(modelBasePathJ, model_base_path); + } + if (htp_backend_config_path != nullptr && htpBackendConfigPathJ != nullptr) { + env->ReleaseStringUTFChars(htpBackendConfigPathJ, htp_backend_config_path); + } + + if (g_cache.loadResultCls == nullptr || g_cache.loadResultCtor == nullptr) { + return nullptr; + } + return env->NewObject(g_cache.loadResultCls, g_cache.loadResultCtor, + static_cast(ec), reinterpret_cast(handle)); +} + +// ---- loadModelHandleByName (T4 string-id path) ---------------------------- +extern "C" JNIEXPORT jlong JNICALL +Java_com_example_quickdotai_NativeCausalLm_loadModelHandleByNameNative( + JNIEnv *env, jobject /*thiz*/, jint backend, jstring modelIdJ, jint quant, + jstring nativeLibDirJ, jstring modelBasePathJ) { + const char *id = env->GetStringUTFChars(modelIdJ, nullptr); + const char *nld = + nativeLibDirJ ? env->GetStringUTFChars(nativeLibDirJ, nullptr) : nullptr; + const char *mbp = + modelBasePathJ ? env->GetStringUTFChars(modelBasePathJ, nullptr) : nullptr; + + CausalLmHandle h = nullptr; + ErrorCode ec = loadModelHandleByName( + static_cast(backend), id, + static_cast(quant), nld, mbp, &h); + + env->ReleaseStringUTFChars(modelIdJ, id); + if (nld) + env->ReleaseStringUTFChars(nativeLibDirJ, nld); + if (mbp) + env->ReleaseStringUTFChars(modelBasePathJ, mbp); + + return (ec == CAUSAL_LM_ERROR_NONE) + ? static_cast(reinterpret_cast(h)) + : 0L; +} + +// ---- nativeQueryCatalog --------------------------------------------------- +extern "C" JNIEXPORT jstring JNICALL +Java_com_example_quickdotai_NativeCausalLm_nativeQueryCatalog( + JNIEnv *env, jobject /*thiz*/) { + return env->NewStringUTF(getModelCatalogJson()); +} + +// ---- encodeModelHandle (embedding vector) --------------------------------- +extern "C" JNIEXPORT jfloatArray JNICALL +Java_com_example_quickdotai_NativeCausalLm_encodeModelHandleNative( + JNIEnv *env, jobject /*thiz*/, jlong handleJlong, jstring textJ) { + auto handle = reinterpret_cast(handleJlong); + if (handle == nullptr || textJ == nullptr) { + return nullptr; + } + const char *text = env->GetStringUTFChars(textJ, nullptr); + + float *vec = nullptr; + int dim = 0; + ErrorCode ec = encodeModelHandle(handle, text, &vec, &dim); + + env->ReleaseStringUTFChars(textJ, text); + + if (ec != CAUSAL_LM_ERROR_NONE || vec == nullptr || dim <= 0) { + if (vec != nullptr) { + freeEmbedding(vec); + } + return nullptr; // null signals failure to the Kotlin layer + } + + jfloatArray arr = env->NewFloatArray(dim); + if (arr != nullptr) { + env->SetFloatArrayRegion(arr, 0, dim, vec); + } + freeEmbedding(vec); + return arr; +} + +// --------------------------------------------------------------------------- +// runModelHandleStreaming +// --------------------------------------------------------------------------- +extern "C" JNIEXPORT jobject JNICALL +Java_com_example_quickdotai_NativeCausalLm_getPerformanceMetricsHandleNative( + JNIEnv *env, jobject /*thiz*/, jlong handleJlong) { + auto handle = reinterpret_cast(handleJlong); + + PerformanceMetrics m{}; + ErrorCode ec = getPerformanceMetricsHandle(handle, &m); + + if (g_cache.metricsResultCls == nullptr || + g_cache.metricsResultCtor == nullptr) { + return nullptr; + } + return env->NewObject(g_cache.metricsResultCls, g_cache.metricsResultCtor, + static_cast(ec), + static_cast(m.prefill_tokens), + static_cast(m.prefill_duration_ms), + static_cast(m.generation_tokens), + static_cast(m.generation_duration_ms), + static_cast(m.total_duration_ms), + static_cast(m.initialization_duration_ms), + static_cast(m.peak_memory_kb)); +} + +// --------------------------------------------------------------------------- +// unloadModelHandle +// --------------------------------------------------------------------------- +extern "C" JNIEXPORT jint JNICALL +Java_com_example_quickdotai_NativeCausalLm_unloadModelHandleNative( + JNIEnv * /*env*/, jobject /*thiz*/, jlong handleJlong) { + auto handle = reinterpret_cast(handleJlong); + return static_cast(unloadModelHandle(handle)); +} + +// --------------------------------------------------------------------------- +// destroyModelHandle +// --------------------------------------------------------------------------- +extern "C" JNIEXPORT jint JNICALL +Java_com_example_quickdotai_NativeCausalLm_destroyModelHandleNative( + JNIEnv * /*env*/, jobject /*thiz*/, jlong handleJlong) { + auto handle = reinterpret_cast(handleJlong); + return static_cast(destroyModelHandle(handle)); +} + +// --------------------------------------------------------------------------- +// cancelModelHandle +// +// Requests cancellation of an in-progress streaming run. Thread-safe: +// can be called from any thread (e.g., UI cancel button handler). +// --------------------------------------------------------------------------- +extern "C" JNIEXPORT jint JNICALL +Java_com_example_quickdotai_NativeCausalLm_cancelModelHandleNative( + JNIEnv * /*env*/, jobject /*thiz*/, jlong handleJlong) { + auto handle = reinterpret_cast(handleJlong); + __android_log_print(ANDROID_LOG_DEBUG, LOG_TAG, + "cancelModelHandleNative: handle=%p", (void *)handle); + auto result = cancelModelHandle(handle); + __android_log_print(ANDROID_LOG_DEBUG, LOG_TAG, + "cancelModelHandleNative: returned %d", result); + return static_cast(result); +} + +// --------------------------------------------------------------------------- +// runModelHandleStreaming +// +// Forwards deltas from the native quick_dot_ai_api streaming callback to a +// Kotlin NativeStreamListener.onDelta(String). See AsyncAndStreaming.md Β§4 +// at the repo root for the design rationale β€” in particular, the +// callback fires on the SAME thread that invoked this JNI entry point +// (the ModelWorker thread), which means we do NOT need AttachCurrentThread: +// the JNIEnv* captured here is still valid throughout every callback. +// --------------------------------------------------------------------------- +namespace { +struct StreamCtx { + JNIEnv *env; + jobject listener; // local ref owned by the JNI entry frame + jmethodID onDelta; // Ljava/lang/String;)V +}; + +inline bool is_valid_utf8(const char *s) { + if (s == nullptr) + return true; + while (*s) { + unsigned char c = static_cast(*s); + int follow = 0; + if (c < 0x80) { + ++s; + continue; + } else if ((c & 0xE0) == 0xC0) + follow = 1; + else if ((c & 0xF0) == 0xE0) + follow = 2; + else if ((c & 0xF8) == 0xF0) + follow = 3; + else + return false; // invalid lead byte + + for (int i = 0; i < follow; ++i) { + ++s; + if (*s == '\0' || (static_cast(*s) & 0xC0) != 0x80) + return false; // missing or invalid continuation byte + } + ++s; + } + return true; +} + +int stream_trampoline(const char *delta, void *user_data) { + auto *ctx = static_cast(user_data); + if (ctx == nullptr || ctx->env == nullptr || ctx->listener == nullptr || + ctx->onDelta == nullptr) { + return 1; // cancel + } + const char *deltaStr = delta != nullptr ? delta : ""; + if (!is_valid_utf8(deltaStr)) { + LOGE("Invalid UTF-8 delta skipped (lead=0x%02X)", + static_cast(deltaStr[0])); + return 0; // skip this delta, keep streaming + } + jstring js = ctx->env->NewStringUTF(deltaStr); + if (js == nullptr) { + // OOM or pending exception; clear and ask the native runner to stop. + if (ctx->env->ExceptionCheck()) { + ctx->env->ExceptionClear(); + } + return 1; + } + ctx->env->CallVoidMethod(ctx->listener, ctx->onDelta, js); + ctx->env->DeleteLocalRef(js); + if (ctx->env->ExceptionCheck()) { + // Surface Kotlin-side errors as cancellation; the Kotlin override + // in NativeQuickDotAI.runStreaming will catch the exception on the + // JNI call's return and report it through StreamSink.onError. + ctx->env->ExceptionDescribe(); + ctx->env->ExceptionClear(); + return 1; + } + return 0; +} +} // namespace + +extern "C" JNIEXPORT jint JNICALL +Java_com_example_quickdotai_NativeCausalLm_runModelHandleStreamingNative( + JNIEnv *env, jobject /*thiz*/, jlong handleJlong, jstring promptJ, + jobject listenerObj) { + if (promptJ == nullptr || listenerObj == nullptr) { + return static_cast(CAUSAL_LM_ERROR_INVALID_PARAMETER); + } + + // Resolve the onDelta(String)V method id per-call. We can't cache + // this globally because NativeStreamListener is a `fun interface` + // and the concrete class of `listenerObj` varies call-to-call. + jclass listenerCls = env->GetObjectClass(listenerObj); + if (listenerCls == nullptr) { + return static_cast(CAUSAL_LM_ERROR_INVALID_PARAMETER); + } + jmethodID onDelta = + env->GetMethodID(listenerCls, "onDelta", "(Ljava/lang/String;)V"); + env->DeleteLocalRef(listenerCls); + if (onDelta == nullptr) { + if (env->ExceptionCheck()) { + env->ExceptionClear(); + } + return static_cast(CAUSAL_LM_ERROR_INVALID_PARAMETER); + } + + const char *prompt = env->GetStringUTFChars(promptJ, nullptr); + if (prompt == nullptr) { + return static_cast(CAUSAL_LM_ERROR_INVALID_PARAMETER); + } + + auto handle = reinterpret_cast(handleJlong); + StreamCtx ctx{env, listenerObj, onDelta}; + ErrorCode ec = + runModelHandleStreaming(handle, prompt, &stream_trampoline, &ctx); + + env->ReleaseStringUTFChars(promptJ, prompt); + return static_cast(ec); +} + +// --------------------------------------------------------------------------- +// runMultimodalHandleStreaming +// +// Multimodal streaming inference that accepts preprocessed image patches +// (as FloatArray) and a text prompt. The pixel values are converted from +// jfloatArray to native float* and passed to the C API. +// --------------------------------------------------------------------------- +extern "C" JNIEXPORT jint JNICALL +Java_com_example_quickdotai_NativeCausalLm_runMultimodalHandleStreamingNative( + JNIEnv *env, jobject /*thiz*/, jlong handleJlong, jstring promptJ, + jfloatArray pixelValuesJ, jint numPatches, jint originalHeight, + jint originalWidth, jobject listenerObj) { + if (promptJ == nullptr || pixelValuesJ == nullptr || listenerObj == nullptr) { + return static_cast(CAUSAL_LM_ERROR_INVALID_PARAMETER); + } + + // Resolve the onDelta(String)V method id + jclass listenerCls = env->GetObjectClass(listenerObj); + if (listenerCls == nullptr) { + return static_cast(CAUSAL_LM_ERROR_INVALID_PARAMETER); + } + jmethodID onDelta = + env->GetMethodID(listenerCls, "onDelta", "(Ljava/lang/String;)V"); + env->DeleteLocalRef(listenerCls); + if (onDelta == nullptr) { + if (env->ExceptionCheck()) { + env->ExceptionClear(); + } + return static_cast(CAUSAL_LM_ERROR_INVALID_PARAMETER); + } + + const char *prompt = env->GetStringUTFChars(promptJ, nullptr); + if (prompt == nullptr) { + return static_cast(CAUSAL_LM_ERROR_INVALID_PARAMETER); + } + + // Get float* from FloatArray + float *pixels = env->GetFloatArrayElements(pixelValuesJ, nullptr); + if (pixels == nullptr) { + env->ReleaseStringUTFChars(promptJ, prompt); + return static_cast(CAUSAL_LM_ERROR_INVALID_PARAMETER); + } + + auto handle = reinterpret_cast(handleJlong); + StreamCtx ctx{env, listenerObj, onDelta}; + + ErrorCode ec = runMultimodalHandleStreaming( + handle, prompt, pixels, numPatches, originalHeight, originalWidth, + &stream_trampoline, &ctx); + + // Release resources + env->ReleaseFloatArrayElements(pixelValuesJ, pixels, JNI_ABORT); + env->ReleaseStringUTFChars(promptJ, prompt); + + return static_cast(ec); +} + +namespace { + +/** + * @brief Convert a Kotlin QuickAiChatMessage to C CausalLMChatMessage. + * + * QuickAiChatMessage structure: + * role: QuickAiChatRole (enum) β€” call name() to get String + * parts: List β€” extract Text parts and concatenate + * + * Returns false if conversion fails (sets JNI exception). + */ +bool convertQuickAiChatMessage(JNIEnv *env, jobject msgObj, + std::string &outRole, std::string &outContent) { + if (msgObj == nullptr) + return false; + + jclass msgCls = env->GetObjectClass(msgObj); + if (msgCls == nullptr) + return false; + + // --- role: QuickAiChatRole enum --- + jfieldID roleFid = + env->GetFieldID(msgCls, "role", "Lcom/example/quickdotai/QuickAiChatRole;"); + if (roleFid == nullptr) { + env->DeleteLocalRef(msgCls); + return false; + } + jobject roleEnum = env->GetObjectField(msgObj, roleFid); + if (roleEnum == nullptr) { + env->DeleteLocalRef(msgCls); + return false; + } + + // Call enum.name() and convert to lowercase for OpenAI API compatibility + jclass enumCls = env->GetObjectClass(roleEnum); + jmethodID nameMid = env->GetMethodID(enumCls, "name", "()Ljava/lang/String;"); + jstring roleNameJ = (jstring)env->CallObjectMethod(roleEnum, nameMid); + if (roleNameJ != nullptr) { + const char *roleName = env->GetStringUTFChars(roleNameJ, nullptr); + outRole = roleName ? roleName : ""; + // Convert to lowercase: "SYSTEM" -> "system", "USER" -> "user", "ASSISTANT" + // -> "assistant" + std::transform(outRole.begin(), outRole.end(), outRole.begin(), ::tolower); + env->ReleaseStringUTFChars(roleNameJ, roleName); + env->DeleteLocalRef(roleNameJ); + } + env->DeleteLocalRef(enumCls); + env->DeleteLocalRef(roleEnum); + + // --- parts: List --- + jfieldID partsFid = env->GetFieldID(msgCls, "parts", "Ljava/util/List;"); + if (partsFid == nullptr) { + env->DeleteLocalRef(msgCls); + return false; + } + jobject partsList = env->GetObjectField(msgObj, partsFid); + if (partsList == nullptr) { + env->DeleteLocalRef(msgCls); + return false; + } + + // Get List.size() and List.get() + jclass listCls = env->GetObjectClass(partsList); + jmethodID sizeMid = env->GetMethodID(listCls, "size", "()I"); + jmethodID getMid = env->GetMethodID(listCls, "get", "(I)Ljava/lang/Object;"); + + jint partsSize = env->CallIntMethod(partsList, sizeMid); + std::string content; + + // PromptPart.Text class + jclass textPartCls = env->FindClass("com/example/quickdotai/PromptPart$Text"); + if (textPartCls == nullptr) { + if (env->ExceptionCheck()) + env->ExceptionClear(); + } else { + jfieldID textFid = + env->GetFieldID(textPartCls, "text", "Ljava/lang/String;"); + if (textFid != nullptr) { + for (jint p = 0; p < partsSize; ++p) { + jobject partObj = env->CallObjectMethod(partsList, getMid, p); + if (partObj == nullptr) + continue; + + if (env->IsInstanceOf(partObj, textPartCls)) { + jstring textJ = (jstring)env->GetObjectField(partObj, textFid); + if (textJ != nullptr) { + const char *text = env->GetStringUTFChars(textJ, nullptr); + if (text != nullptr) { + if (!content.empty()) + content += " "; + content += text; + env->ReleaseStringUTFChars(textJ, text); + } + env->DeleteLocalRef(textJ); + } + } + env->DeleteLocalRef(partObj); + } + } + env->DeleteLocalRef(textPartCls); + } + + outContent = std::move(content); + + env->DeleteLocalRef(listCls); + env->DeleteLocalRef(partsList); + env->DeleteLocalRef(msgCls); + return true; +} + +} // namespace + +// --------------------------------------------------------------------------- +// runModelHandleWithMessagesStreaming +// +// Streaming inference with OpenAI message format on a specific handle. +// Converts jobjectArray of QuickAiChatMessage to C array, applies chat +// template, then drives streaming generation token-by-token. +// --------------------------------------------------------------------------- +extern "C" JNIEXPORT jint JNICALL +Java_com_example_quickdotai_NativeCausalLm_runModelHandleWithMessagesStreamingNative( + JNIEnv *env, jobject /*thiz*/, jlong handleJlong, jobjectArray messagesJ, + jboolean addGenerationPrompt, jobject listenerObj) { + + if (messagesJ == nullptr || listenerObj == nullptr) { + return static_cast(CAUSAL_LM_ERROR_INVALID_PARAMETER); + } + + // Resolve listener method + jclass listenerCls = env->GetObjectClass(listenerObj); + if (listenerCls == nullptr) { + return static_cast(CAUSAL_LM_ERROR_INVALID_PARAMETER); + } + jmethodID onDelta = + env->GetMethodID(listenerCls, "onDelta", "(Ljava/lang/String;)V"); + env->DeleteLocalRef(listenerCls); + if (onDelta == nullptr) { + if (env->ExceptionCheck()) + env->ExceptionClear(); + return static_cast(CAUSAL_LM_ERROR_INVALID_PARAMETER); + } + + // Convert messages + jsize len = env->GetArrayLength(messagesJ); + std::vector msgs; + msgs.reserve(len); + std::vector roleStorage; + std::vector contentStorage; + roleStorage.reserve(len); + contentStorage.reserve(len); + + for (jsize i = 0; i < len; ++i) { + jobject msgObj = env->GetObjectArrayElement(messagesJ, i); + std::string roleStr, contentStr; + if (msgObj != nullptr && + convertQuickAiChatMessage(env, msgObj, roleStr, contentStr)) { + roleStorage.push_back(std::move(roleStr)); + contentStorage.push_back(std::move(contentStr)); + msgs.push_back( + {roleStorage.back().c_str(), contentStorage.back().c_str()}); + } + if (msgObj != nullptr) + env->DeleteLocalRef(msgObj); + } + + auto handle = reinterpret_cast(handleJlong); + StreamCtx ctx{env, listenerObj, onDelta}; + + ErrorCode ec = runModelHandleWithMessagesStreaming( + handle, msgs.data(), msgs.size(), addGenerationPrompt == JNI_TRUE, + &stream_trampoline, &ctx); + + return static_cast(ec); +} + +// --------------------------------------------------------------------------- +// runMultimodalHandleWithMessagesStreaming +// +// Streaming multimodal inference with OpenAI message format on a specific +// handle. +// --------------------------------------------------------------------------- +extern "C" JNIEXPORT jint JNICALL +Java_com_example_quickdotai_NativeCausalLm_runMultimodalHandleWithMessagesStreamingNative( + JNIEnv *env, jobject /*thiz*/, jlong handleJlong, jobjectArray messagesJ, + jboolean addGenerationPrompt, jfloatArray pixelValuesJ, jint numPatches, + jint originalHeight, jint originalWidth, jobject listenerObj) { + + if (messagesJ == nullptr || pixelValuesJ == nullptr || + listenerObj == nullptr) { + return static_cast(CAUSAL_LM_ERROR_INVALID_PARAMETER); + } + + // Resolve listener method + jclass listenerCls = env->GetObjectClass(listenerObj); + if (listenerCls == nullptr) { + return static_cast(CAUSAL_LM_ERROR_INVALID_PARAMETER); + } + jmethodID onDelta = + env->GetMethodID(listenerCls, "onDelta", "(Ljava/lang/String;)V"); + env->DeleteLocalRef(listenerCls); + if (onDelta == nullptr) { + if (env->ExceptionCheck()) + env->ExceptionClear(); + return static_cast(CAUSAL_LM_ERROR_INVALID_PARAMETER); + } + + // Convert messages + jsize len = env->GetArrayLength(messagesJ); + std::vector msgs; + msgs.reserve(len); + std::vector roleStorage; + std::vector contentStorage; + roleStorage.reserve(len); + contentStorage.reserve(len); + + for (jsize i = 0; i < len; ++i) { + jobject msgObj = env->GetObjectArrayElement(messagesJ, i); + std::string roleStr, contentStr; + if (msgObj != nullptr && + convertQuickAiChatMessage(env, msgObj, roleStr, contentStr)) { + roleStorage.push_back(std::move(roleStr)); + contentStorage.push_back(std::move(contentStr)); + msgs.push_back( + {roleStorage.back().c_str(), contentStorage.back().c_str()}); + } + if (msgObj != nullptr) + env->DeleteLocalRef(msgObj); + } + + float *pixels = env->GetFloatArrayElements(pixelValuesJ, nullptr); + if (pixels == nullptr) { + return static_cast(CAUSAL_LM_ERROR_INVALID_PARAMETER); + } + + auto handle = reinterpret_cast(handleJlong); + StreamCtx ctx{env, listenerObj, onDelta}; + + ErrorCode ec = runMultimodalHandleWithMessagesStreaming( + handle, msgs.data(), msgs.size(), addGenerationPrompt == JNI_TRUE, pixels, + numPatches, originalHeight, originalWidth, &stream_trampoline, &ctx); + + env->ReleaseFloatArrayElements(pixelValuesJ, pixels, JNI_ABORT); + + return static_cast(ec); +} + +// --------------------------------------------------------------------------- +// runModelHandleWithJsonStreaming +// +// Streaming inference with OpenAI JSON format on a specific handle. +// Accepts a JSON string containing messages, tools, functions, etc. +// --------------------------------------------------------------------------- +extern "C" JNIEXPORT jint JNICALL +Java_com_example_quickdotai_NativeCausalLm_runModelHandleWithJsonStreamingNative( + JNIEnv *env, jobject /*thiz*/, jlong handleJlong, jstring jsonRequestJ, + jobject listenerObj) { + + if (jsonRequestJ == nullptr || listenerObj == nullptr) { + return static_cast(CAUSAL_LM_ERROR_INVALID_PARAMETER); + } + + // Resolve listener method + jclass listenerCls = env->GetObjectClass(listenerObj); + if (listenerCls == nullptr) { + return static_cast(CAUSAL_LM_ERROR_INVALID_PARAMETER); + } + jmethodID onDelta = + env->GetMethodID(listenerCls, "onDelta", "(Ljava/lang/String;)V"); + env->DeleteLocalRef(listenerCls); + if (onDelta == nullptr) { + if (env->ExceptionCheck()) + env->ExceptionClear(); + return static_cast(CAUSAL_LM_ERROR_INVALID_PARAMETER); + } + + const char *jsonRequest = env->GetStringUTFChars(jsonRequestJ, nullptr); + if (jsonRequest == nullptr) { + return static_cast(CAUSAL_LM_ERROR_INVALID_PARAMETER); + } + + auto handle = reinterpret_cast(handleJlong); + StreamCtx ctx{env, listenerObj, onDelta}; + + ErrorCode ec = runModelHandleWithJsonStreaming(handle, jsonRequest, + &stream_trampoline, &ctx); + + env->ReleaseStringUTFChars(jsonRequestJ, jsonRequest); + + return static_cast(ec); +} + +// --------------------------------------------------------------------------- +// runMultimodalMultiImageStreamingNative +// +// Multimodal streaming inference with multi-image support (V-JEPA). +// Accepts preprocessed pixel values for multiple images along with +// per-image metadata (patches per image, heights, widths). +// --------------------------------------------------------------------------- +extern "C" JNIEXPORT jint JNICALL +Java_com_example_quickdotai_NativeCausalLm_runMultimodalMultiImageStreamingNative( + JNIEnv *env, jobject /*thiz*/, jlong handleJlong, jstring promptJ, + jfloatArray pixelValuesJ, jint numPatches, jint numImages, + jintArray patchesPerImageJ, jintArray originalHeightsJ, + jintArray originalWidthsJ, jobject listenerObj) { + + if (promptJ == nullptr || pixelValuesJ == nullptr || + patchesPerImageJ == nullptr || originalHeightsJ == nullptr || + originalWidthsJ == nullptr || listenerObj == nullptr) { + return static_cast(CAUSAL_LM_ERROR_INVALID_PARAMETER); + } + + // Resolve listener method + jclass listenerCls = env->GetObjectClass(listenerObj); + if (listenerCls == nullptr) { + return static_cast(CAUSAL_LM_ERROR_INVALID_PARAMETER); + } + jmethodID onDelta = + env->GetMethodID(listenerCls, "onDelta", "(Ljava/lang/String;)V"); + env->DeleteLocalRef(listenerCls); + if (onDelta == nullptr) { + if (env->ExceptionCheck()) + env->ExceptionClear(); + return static_cast(CAUSAL_LM_ERROR_INVALID_PARAMETER); + } + + const char *prompt = env->GetStringUTFChars(promptJ, nullptr); + if (prompt == nullptr) { + return static_cast(CAUSAL_LM_ERROR_INVALID_PARAMETER); + } + + // Get float* from FloatArray + float *pixels = env->GetFloatArrayElements(pixelValuesJ, nullptr); + if (pixels == nullptr) { + env->ReleaseStringUTFChars(promptJ, prompt); + return static_cast(CAUSAL_LM_ERROR_INVALID_PARAMETER); + } + + // Get int* from IntArrays + jint *patchesPerImage = env->GetIntArrayElements(patchesPerImageJ, nullptr); + if (patchesPerImage == nullptr) { + env->ReleaseFloatArrayElements(pixelValuesJ, pixels, JNI_ABORT); + env->ReleaseStringUTFChars(promptJ, prompt); + return static_cast(CAUSAL_LM_ERROR_INVALID_PARAMETER); + } + + jint *originalHeights = env->GetIntArrayElements(originalHeightsJ, nullptr); + if (originalHeights == nullptr) { + env->ReleaseIntArrayElements(patchesPerImageJ, patchesPerImage, JNI_ABORT); + env->ReleaseFloatArrayElements(pixelValuesJ, pixels, JNI_ABORT); + env->ReleaseStringUTFChars(promptJ, prompt); + return static_cast(CAUSAL_LM_ERROR_INVALID_PARAMETER); + } + + jint *originalWidths = env->GetIntArrayElements(originalWidthsJ, nullptr); + if (originalWidths == nullptr) { + env->ReleaseIntArrayElements(originalHeightsJ, originalHeights, JNI_ABORT); + env->ReleaseIntArrayElements(patchesPerImageJ, patchesPerImage, JNI_ABORT); + env->ReleaseFloatArrayElements(pixelValuesJ, pixels, JNI_ABORT); + env->ReleaseStringUTFChars(promptJ, prompt); + return static_cast(CAUSAL_LM_ERROR_INVALID_PARAMETER); + } + + auto handle = reinterpret_cast(handleJlong); + StreamCtx ctx{env, listenerObj, onDelta}; + + // Debug: log multi-image metadata + __android_log_print(ANDROID_LOG_DEBUG, LOG_TAG, + "runMultimodalMultiImageStreamingNative: handle=%p, numPatches=%d, " + "numImages=%d, pixelValues[0..4]=%f,%f,%f,%f,%f", + (void *)handle, numPatches, numImages, + pixels[0], (numPatches > 1 ? pixels[1] : 0.0f), + (numPatches > 2 ? pixels[2] : 0.0f), + (numPatches > 3 ? pixels[3] : 0.0f), + (numPatches > 4 ? pixels[4] : 0.0f)); + for (int i = 0; i < numImages; ++i) { + __android_log_print(ANDROID_LOG_DEBUG, LOG_TAG, + " image[%d]: patches=%d, height=%d, width=%d", + i, patchesPerImage[i], originalHeights[i], originalWidths[i]); + } + + ErrorCode ec = runMultimodalMultiImageHandleStreaming( + handle, prompt, pixels, numPatches, numImages, + reinterpret_cast(patchesPerImage), + reinterpret_cast(originalHeights), + reinterpret_cast(originalWidths), + &stream_trampoline, &ctx); + + __android_log_print(ANDROID_LOG_DEBUG, LOG_TAG, + "runMultimodalMultiImageStreamingNative: returned ec=%d", (int)ec); + + // Release resources + env->ReleaseIntArrayElements(originalWidthsJ, originalWidths, JNI_ABORT); + env->ReleaseIntArrayElements(originalHeightsJ, originalHeights, JNI_ABORT); + env->ReleaseIntArrayElements(patchesPerImageJ, patchesPerImage, JNI_ABORT); + env->ReleaseFloatArrayElements(pixelValuesJ, pixels, JNI_ABORT); + env->ReleaseStringUTFChars(promptJ, prompt); + + return static_cast(ec); +} + +// --------------------------------------------------------------------------- +// runMultimodalMultiImageWithMessagesStreamingNative +// +// Streaming multimodal inference with multi-image + messages (V-JEPA). +// --------------------------------------------------------------------------- +extern "C" JNIEXPORT jint JNICALL +Java_com_example_quickdotai_NativeCausalLm_runMultimodalMultiImageWithMessagesStreamingNative( + JNIEnv *env, jobject /*thiz*/, jlong handleJlong, jobjectArray messagesJ, + jboolean addGenerationPrompt, jfloatArray pixelValuesJ, jint numPatches, + jint numImages, jintArray patchesPerImageJ, jintArray originalHeightsJ, + jintArray originalWidthsJ, jobject listenerObj) { + + if (messagesJ == nullptr || pixelValuesJ == nullptr || + patchesPerImageJ == nullptr || originalHeightsJ == nullptr || + originalWidthsJ == nullptr || listenerObj == nullptr) { + return static_cast(CAUSAL_LM_ERROR_INVALID_PARAMETER); + } + + // Resolve listener method + jclass listenerCls = env->GetObjectClass(listenerObj); + if (listenerCls == nullptr) { + return static_cast(CAUSAL_LM_ERROR_INVALID_PARAMETER); + } + jmethodID onDelta = + env->GetMethodID(listenerCls, "onDelta", "(Ljava/lang/String;)V"); + env->DeleteLocalRef(listenerCls); + if (onDelta == nullptr) { + if (env->ExceptionCheck()) + env->ExceptionClear(); + return static_cast(CAUSAL_LM_ERROR_INVALID_PARAMETER); + } + + // Convert messages + jsize len = env->GetArrayLength(messagesJ); + std::vector msgs; + msgs.reserve(len); + std::vector roleStorage; + std::vector contentStorage; + roleStorage.reserve(len); + contentStorage.reserve(len); + + for (jsize i = 0; i < len; ++i) { + jobject msgObj = env->GetObjectArrayElement(messagesJ, i); + std::string roleStr, contentStr; + if (msgObj != nullptr && + convertQuickAiChatMessage(env, msgObj, roleStr, contentStr)) { + roleStorage.push_back(std::move(roleStr)); + contentStorage.push_back(std::move(contentStr)); + msgs.push_back( + {roleStorage.back().c_str(), contentStorage.back().c_str()}); + } + if (msgObj != nullptr) + env->DeleteLocalRef(msgObj); + } + + // Get float* from FloatArray + float *pixels = env->GetFloatArrayElements(pixelValuesJ, nullptr); + if (pixels == nullptr) { + return static_cast(CAUSAL_LM_ERROR_INVALID_PARAMETER); + } + + // Get int* from IntArrays + jint *patchesPerImage = env->GetIntArrayElements(patchesPerImageJ, nullptr); + if (patchesPerImage == nullptr) { + env->ReleaseFloatArrayElements(pixelValuesJ, pixels, JNI_ABORT); + return static_cast(CAUSAL_LM_ERROR_INVALID_PARAMETER); + } + + jint *originalHeights = env->GetIntArrayElements(originalHeightsJ, nullptr); + if (originalHeights == nullptr) { + env->ReleaseIntArrayElements(patchesPerImageJ, patchesPerImage, JNI_ABORT); + env->ReleaseFloatArrayElements(pixelValuesJ, pixels, JNI_ABORT); + return static_cast(CAUSAL_LM_ERROR_INVALID_PARAMETER); + } + + jint *originalWidths = env->GetIntArrayElements(originalWidthsJ, nullptr); + if (originalWidths == nullptr) { + env->ReleaseIntArrayElements(originalHeightsJ, originalHeights, JNI_ABORT); + env->ReleaseIntArrayElements(patchesPerImageJ, patchesPerImage, JNI_ABORT); + env->ReleaseFloatArrayElements(pixelValuesJ, pixels, JNI_ABORT); + return static_cast(CAUSAL_LM_ERROR_INVALID_PARAMETER); + } + + auto handle = reinterpret_cast(handleJlong); + StreamCtx ctx{env, listenerObj, onDelta}; + + // Debug: log multi-image + messages metadata + __android_log_print(ANDROID_LOG_DEBUG, LOG_TAG, + "runMultimodalMultiImageWithMessagesStreamingNative: handle=%p, " + "numMessages=%zu, numPatches=%d, numImages=%d, " + "pixelValues[0..4]=%f,%f,%f,%f,%f", + (void *)handle, msgs.size(), numPatches, numImages, + pixels[0], (numPatches > 1 ? pixels[1] : 0.0f), + (numPatches > 2 ? pixels[2] : 0.0f), + (numPatches > 3 ? pixels[3] : 0.0f), + (numPatches > 4 ? pixels[4] : 0.0f)); + for (int i = 0; i < numImages; ++i) { + __android_log_print(ANDROID_LOG_DEBUG, LOG_TAG, + " image[%d]: patches=%d, height=%d, width=%d", + i, patchesPerImage[i], originalHeights[i], originalWidths[i]); + } + for (size_t i = 0; i < msgs.size(); ++i) { + __android_log_print(ANDROID_LOG_DEBUG, LOG_TAG, + " msg[%zu]: role='%s' contentLen=%zu", + i, msgs[i].role, strlen(msgs[i].content)); + } + + ErrorCode ec = runMultimodalMultiImageHandleWithMessagesStreaming( + handle, msgs.data(), msgs.size(), addGenerationPrompt == JNI_TRUE, pixels, + numPatches, numImages, + reinterpret_cast(patchesPerImage), + reinterpret_cast(originalHeights), + reinterpret_cast(originalWidths), + &stream_trampoline, &ctx); + + __android_log_print(ANDROID_LOG_DEBUG, LOG_TAG, + "runMultimodalMultiImageWithMessagesStreamingNative: returned ec=%d", (int)ec); + + // Release resources + env->ReleaseIntArrayElements(originalWidthsJ, originalWidths, JNI_ABORT); + env->ReleaseIntArrayElements(originalHeightsJ, originalHeights, JNI_ABORT); + env->ReleaseIntArrayElements(patchesPerImageJ, patchesPerImage, JNI_ABORT); + env->ReleaseFloatArrayElements(pixelValuesJ, pixels, JNI_ABORT); + + return static_cast(ec); +} diff --git a/Android/QuickDotAI/src/main/java/com/example/quickdotai/ImageStore.kt b/Android/QuickDotAI/src/main/java/com/example/quickdotai/ImageStore.kt new file mode 100644 index 00000000..84e692a4 --- /dev/null +++ b/Android/QuickDotAI/src/main/java/com/example/quickdotai/ImageStore.kt @@ -0,0 +1,107 @@ +// SPDX-License-Identifier: Apache-2.0 +/* + * Copyright (C) 2026 Samsung Electronics Co., Ltd. All Rights Reserved. + * + * @file ImageStore.kt + * @brief SHA-256-based in-memory image cache for structured chat sessions. + * + * Identical images that arrive via different temporary file paths are + * identified by content hash, so the conversation history can reference + * them stably across turns without depending on filesystem paths. + * + * Each chat session owns its own [ImageStore]. The store is + * cleared when the session is closed or the owning engine is unloaded. + */ +package com.example.quickdotai + +import android.util.Log +import java.io.File +import java.security.MessageDigest +import java.util.concurrent.ConcurrentHashMap + +/** + * @brief In-memory, hash-addressed image cache. + * + * Thread safety: all public methods are safe to call concurrently + * (backed by [ConcurrentHashMap]). In practice the session drives + * the store from a single worker thread, but defensive safety costs + * almost nothing here. + */ +internal class ImageStore { + + private val cache = ConcurrentHashMap() + + /** + * Store raw image [bytes] and return the SHA-256 hex digest. + * If the same content was already stored, no duplicate is created. + */ + fun store(bytes: ByteArray): String { + val hash = sha256Hex(bytes) + cache.putIfAbsent(hash, bytes) + return hash + } + + /** + * Read the image at [absolutePath], store its bytes, and return + * the SHA-256 hex digest. + * + * @throws IllegalArgumentException if the file does not exist or + * is not readable. + */ + fun store(absolutePath: String): String { + val f = File(absolutePath) + require(f.exists() && f.canRead()) { + "ImageStore: file not readable: $absolutePath" + } + return store(f.readBytes()) + } + + /** Retrieve previously-stored bytes by their hash, or null. */ + fun get(hash: String): ByteArray? = cache[hash] + + /** Check whether a hash is present in the store. */ + fun contains(hash: String): Boolean = cache.containsKey(hash) + + /** Number of images currently cached. */ + val size: Int get() = cache.size + + /** + * Remove all cached images and free memory. Called when the owning + * session is closed or the engine is unloaded. + */ + fun clear() { + val n = cache.size + cache.clear() + if (n > 0) { + Log.i(TAG, "clear(): removed $n cached image(s)") + } + } + + /** + * Remove images whose hashes are NOT in [retainHashes]. Used by + * [LiteRTLmChatSession.rebuild] to prune images that are no longer + * referenced by the new history. + */ + fun retainOnly(retainHashes: Set) { + val iter = cache.keys.iterator() + var removed = 0 + while (iter.hasNext()) { + if (iter.next() !in retainHashes) { + iter.remove() + removed++ + } + } + if (removed > 0) { + Log.i(TAG, "retainOnly(): pruned $removed unreferenced image(s)") + } + } + + companion object { + private const val TAG = "ImageStore" + + fun sha256Hex(data: ByteArray): String { + val digest = MessageDigest.getInstance("SHA-256").digest(data) + return digest.joinToString("") { "%02x".format(it) } + } + } +} diff --git a/Android/QuickDotAI/src/main/java/com/example/quickdotai/LiteRTLm.kt b/Android/QuickDotAI/src/main/java/com/example/quickdotai/LiteRTLm.kt new file mode 100644 index 00000000..78802b58 --- /dev/null +++ b/Android/QuickDotAI/src/main/java/com/example/quickdotai/LiteRTLm.kt @@ -0,0 +1,864 @@ +// SPDX-License-Identifier: Apache-2.0 +/* + * Copyright (C) 2026 Samsung Electronics Co., Ltd. All Rights Reserved. + * + * @file LiteRTLm.kt + * @brief QuickDotAI implementation backed by the LiteRT-LM Kotlin API + * (https://github.com/google-ai-edge/LiteRT-LM). + * + * LiteRTLm is the QuickDotAI-level routing target for ModelId.GEMMA4 + * and is typically selected inside the host app's registry. It + * implements the same [QuickDotAI] contract as [NativeQuickDotAI], so + * consumers never need to branch on the concrete implementation. + * + * See how-to-use-litert-lm-guide.md at the repo root for the canonical + * LiteRT-LM Kotlin API surface this code is written against. + */ +package com.example.quickdotai + +import android.content.Context +import android.util.Log +import com.google.ai.edge.litertlm.Backend as LlmBackend +import com.google.ai.edge.litertlm.Content +import com.google.ai.edge.litertlm.Contents +import com.google.ai.edge.litertlm.Conversation +import com.google.ai.edge.litertlm.Engine +import com.google.ai.edge.litertlm.EngineConfig +import com.google.ai.edge.litertlm.Message +import com.google.ai.edge.litertlm.MessageCallback +import java.io.File +import java.util.concurrent.CountDownLatch +import java.util.concurrent.TimeUnit +import java.util.concurrent.atomic.AtomicBoolean + +/** + * @brief LiteRT-LM-backed QuickDotAI implementation for Gemma-family + * models. + * + * Non-thread-safe β€” the host app must drive a single instance from a + * single worker thread. + * + * @param appContext application context, used only to resolve the + * test-mode fallback model path via [Context.getExternalFilesDir] + * when [LoadModelRequest.modelPath] is null. Third-party apps + * are encouraged to always pass an explicit [LoadModelRequest.modelPath] + * and therefore never hit the fallback. + */ +class LiteRTLm( + private val appContext: Context, + /** Base directory for model files, used by [testModelFile] fallback. + * Defaults to `/sdcard/Download/aistudio-mobile/models/`. */ + private val defaultModelBasePath: String = "/sdcard/Download/aistudio-mobile/models/" +) : QuickDotAI { + + override val kind: String = "litert-lm" + + override var architecture: String? = "Gemma4ForCausalLM" + private set + + // LiteRT-LM's Engine is AutoCloseable β€” we hold on to it (and a + // single reusable Conversation) for the entire lifetime of the + // LiteRTLm instance. Closed in [close]. + private var engine: Engine? = null + private var conversation: Conversation? = null + + // True if the engine was loaded with a non-null visionBackend. We + // gate runMultimodal / runMultimodalStreaming on this so callers + // that loaded in text-only mode get a clear UNSUPPORTED error + // instead of a cryptic native failure deep inside LiteRT-LM. + private var visionEnabled: Boolean = false + + /** Signals an in-flight cancel request for one-shot run(). */ + private val cancelRequested = AtomicBoolean(false) + + // Simple wall-clock metrics. LiteRT-LM's Kotlin API does not expose + // token-level prefill/generation timings in the release we target, + // so we record initialization and last-run durations ourselves and + // leave the token counts at 0 for now. + private var initializationDurationMs: Double = 0.0 + private var lastRunDurationMs: Double = 0.0 + + // LiteRT-LM allows only one Conversation per Engine at a time. + // A new session cannot be opened until the active one is closed. + private var activeSession: LiteRTLmChatSession? = null + + override val chatSessionId: String? + get() = activeSession?.sessionId + + override fun load(req: LoadModelRequest): BackendResult { + Log.i( + TAG, + "load() entered: modelId=${req.modelId} backend=${req.backend} " + + "quant=${req.quantization} modelPath=${req.modelPath}" + ) + + // TEST-MODE fallback: during bring-up we want `load(GEMMA4)` to + // Just Work even if the caller forgets to pass model_path. Fall + // back to a known-good on-device path inside the host app's + // external files dir so the pipeline can be end-to-end verified. + // + // /data/local/tmp is NOT app-readable on user builds: it carries + // the shell_data_file SELinux context and the untrusted_app + // domain is denied read access. getExternalFilesDir() (a) + // requires no runtime permissions, (b) is always writable by + // adb, and (c) is already created by the framework for us. + val modelPath = req.modelPath?.takeIf { it.isNotBlank() } + ?: run { + val fallback = testModelFile().absolutePath + Log.w( + TAG, + "load(): model_path not provided, falling back to " + + "test path: $fallback" + ) + fallback + } + + Log.i(TAG, "load(): resolved modelPath=$modelPath") + + val modelFile = File(modelPath) + if (!modelFile.exists()) { + val parentDir = testModelFile().parentFile?.absolutePath ?: "" + val hint = "push it with: adb push $TEST_GEMMA4_FILE_NAME $parentDir/" + Log.e(TAG, "load(): model file not found at $modelPath β€” $hint") + return BackendResult.Err( + QuickAiError.MODEL_LOAD_FAILED, + "model file not found at $modelPath. $hint" + ) + } + Log.i( + TAG, + "load(): model file exists, size=${modelFile.length()} bytes, " + + "canRead=${modelFile.canRead()}" + ) + + val llmBackend: LlmBackend = mapBackend(req.backend) + // Null visionBackend leaves the engine in text-only mode. A + // non-null value enables the multimodal code path and unblocks + // [runMultimodal] / [runMultimodalStreaming] at call time. + val visionLlmBackend: LlmBackend? = req.visionBackend?.let(::mapBackend) + + Log.i( + TAG, + "load(): mapped compute backend ${req.backend} -> " + + "${llmBackend::class.java.simpleName}, vision=${req.visionBackend} -> " + + (visionLlmBackend?.let { it::class.java.simpleName } ?: "") + ) + + val engineConfig = EngineConfig( + modelPath = modelPath, + backend = llmBackend, + visionBackend = visionLlmBackend, + cacheDir = req.cacheDir, + maxNumTokens = req.maxNumTokens, + ) + Log.i( + TAG, + "load(): EngineConfig built (cacheDir=${req.cacheDir}, " + + "maxNumTokens=${req.maxNumTokens}), constructing Engine…" + ) + + return try { + val startNs = System.nanoTime() + val e = Engine(engineConfig) + Log.i(TAG, "load(): Engine() constructed, calling initialize()…") + e.initialize() + Log.i( + TAG, + "load(): Engine.initialize() returned after " + + "${(System.nanoTime() - startNs) / 1_000_000} ms" + ) + + val c = e.createConversation() + Log.i(TAG, "load(): Engine.createConversation() returned") + + initializationDurationMs = (System.nanoTime() - startNs) / 1_000_000.0 + engine = e + conversation = c + visionEnabled = (visionLlmBackend != null) + Log.i( + TAG, + "load(): SUCCESS, total init duration=${initializationDurationMs} ms, " + + "visionEnabled=$visionEnabled" + ) + BackendResult.Ok(Unit) + } catch (t: Throwable) { + Log.e(TAG, "load(): LiteRT-LM engine load failed", t) + // On partial success, make sure we don't leak a half-initialised + // engine into the caller's registry. + closeQuietly() + BackendResult.Err( + QuickAiError.MODEL_LOAD_FAILED, + t.message ?: "LiteRT-LM engine initialization failed" + ) + } + } + + /** + * @brief Multimodal inference β€” blocking. + * + * Builds a LiteRT-LM [Contents] from [parts] and hands it to + * `conversation.sendMessage(contents)`. Returns the decoded text + * of the model's reply on success, or a [BackendResult.Err] on + * failure. Gated on [visionEnabled] so callers that forgot to set + * [LoadModelRequest.visionBackend] get a clear UNSUPPORTED error + * rather than a cryptic native crash. + */ + override fun runMultimodalHandle(parts: List): BackendResult { + val c = conversation + ?: run { + Log.e(TAG, "runMultimodal(): called before load() β€” conversation is null") + return BackendResult.Err( + QuickAiError.NOT_INITIALIZED, + "LiteRTLm has not been loaded yet" + ) + } + if (!visionEnabled) { + Log.e( + TAG, + "runMultimodal(): engine loaded in text-only mode β€” " + + "reload with LoadModelRequest.visionBackend set" + ) + return BackendResult.Err( + QuickAiError.UNSUPPORTED, + "LiteRTLm was loaded without a visionBackend β€” reload with " + + "LoadModelRequest.visionBackend set to a non-null value." + ) + } + if (parts.isEmpty()) { + return BackendResult.Err( + QuickAiError.INVALID_PARAMETER, + "runMultimodal(): parts list is empty" + ) + } + + val contents = try { + toLiteRtContents(parts) + } catch (t: Throwable) { + Log.e(TAG, "runMultimodal(): failed to build Contents", t) + return BackendResult.Err( + QuickAiError.INVALID_PARAMETER, + t.message ?: "failed to build LiteRT-LM Contents from parts" + ) + } + + Log.i(TAG, "runMultimodal(): sending ${parts.size} parts") + return try { + val startNs = System.nanoTime() + val message = c.sendMessage(contents) + lastRunDurationMs = (System.nanoTime() - startNs) / 1_000_000.0 + val output = message.toString() + Log.i( + TAG, + "runMultimodal(): sendMessage returned in ${lastRunDurationMs.toLong()} ms, " + + "output length=${output.length}" + ) + BackendResult.Ok(output) + } catch (t: Throwable) { + Log.e(TAG, "runMultimodal(): LiteRT-LM sendMessage failed", t) + BackendResult.Err( + QuickAiError.INFERENCE_FAILED, + t.message ?: "LiteRT-LM multimodal inference failed" + ) + } + } + + /** + * @brief Multimodal inference β€” streaming. + * + * Same shape as [runStreaming]: drives LiteRT-LM's + * `sendMessageAsync(contents, callback)` and forwards incremental + * deltas to [sink], blocking the caller thread on a + * [CountDownLatch] until the callback signals `onDone` or + * `onError`. The delta-extraction logic is shared with the + * text-only path (see [runStreaming] for the rationale behind + * the `accumulated` StringBuilder defensive handling). + */ + override fun runMultimodalHandleStreaming( + parts: List, + sink: StreamSink + ): BackendResult { + val c = conversation + ?: run { + Log.e(TAG, "runMultimodalStreaming(): called before load()") + val err = BackendResult.Err( + QuickAiError.NOT_INITIALIZED, + "LiteRTLm has not been loaded yet" + ) + sink.onError(err.error, err.message) + return err + } + if (!visionEnabled) { + Log.e(TAG, "runMultimodalStreaming(): engine loaded in text-only mode") + val err = BackendResult.Err( + QuickAiError.UNSUPPORTED, + "LiteRTLm was loaded without a visionBackend β€” reload with " + + "LoadModelRequest.visionBackend set to a non-null value." + ) + sink.onError(err.error, err.message) + return err + } + if (parts.isEmpty()) { + val err = BackendResult.Err( + QuickAiError.INVALID_PARAMETER, + "runMultimodalStreaming(): parts list is empty" + ) + sink.onError(err.error, err.message) + return err + } + + val contents = try { + toLiteRtContents(parts) + } catch (t: Throwable) { + Log.e(TAG, "runMultimodalStreaming(): failed to build Contents", t) + val err = BackendResult.Err( + QuickAiError.INVALID_PARAMETER, + t.message ?: "failed to build LiteRT-LM Contents from parts" + ) + sink.onError(err.error, err.message) + return err + } + + Log.i(TAG, "runMultimodalStreaming(): ${parts.size} parts") + + val latch = CountDownLatch(1) + val accumulated = StringBuilder() + var terminalError: BackendResult.Err? = null + val startNs = System.nanoTime() + + val callback = object : MessageCallback { + override fun onMessage(message: Message) { + try { + val full = message.toString() + val delta = if (full.startsWith(accumulated.toString())) { + full.substring(accumulated.length) + } else { + full + } + if (delta.isNotEmpty()) { + accumulated.append(delta) + sink.onDelta(delta) + } + } catch (t: Throwable) { + Log.w(TAG, "runMultimodalStreaming(): onMessage threw", t) + } + } + + override fun onDone() { + lastRunDurationMs = (System.nanoTime() - startNs) / 1_000_000.0 + Log.i( + TAG, + "runMultimodalStreaming(): onDone after ${lastRunDurationMs.toLong()} ms, " + + "total chars=${accumulated.length}" + ) + try { + sink.onDone() + } finally { + latch.countDown() + } + } + + override fun onError(throwable: Throwable) { + Log.e(TAG, "runMultimodalStreaming(): onError from LiteRT-LM", throwable) + val err = BackendResult.Err( + QuickAiError.INFERENCE_FAILED, + throwable.message ?: "LiteRT-LM multimodal streaming inference failed" + ) + terminalError = err + try { + sink.onError(err.error, err.message) + } finally { + latch.countDown() + } + } + } + + return try { + c.sendMessageAsync(contents, callback) + val finished = latch.await(5, TimeUnit.MINUTES) + if (!finished) { + Log.e(TAG, "runMultimodalStreaming(): timed out waiting for onDone/onError") + val err = BackendResult.Err( + QuickAiError.INFERENCE_FAILED, + "LiteRT-LM multimodal streaming timeout" + ) + sink.onError(err.error, err.message) + return err + } + terminalError ?: BackendResult.Ok(Unit) + } catch (t: Throwable) { + Log.e(TAG, "runMultimodalStreaming(): sendMessageAsync threw", t) + val err = BackendResult.Err( + QuickAiError.INFERENCE_FAILED, + t.message ?: "LiteRT-LM multimodal streaming inference failed" + ) + sink.onError(err.error, err.message) + err + } + } + + // --- chat session management ----------------------------------------- + + override fun openChatSession( + config: QuickAiChatSessionConfig? + ): BackendResult { + val e = engine + ?: return BackendResult.Err( + QuickAiError.NOT_INITIALIZED, + "LiteRTLm has not been loaded yet" + ) + + // LiteRT-LM supports only one Conversation per Engine. Reject + // if a session is already active. + if (activeSession != null) { + Log.w( + TAG, + "openChatSession(): rejected β€” session ${activeSession!!.sessionId} " + + "is still active. Close it first." + ) + return BackendResult.Err( + QuickAiError.BAD_REQUEST, + "A chat session is already active (${activeSession!!.sessionId}). " + + "Close it before opening a new one." + ) + } + + // LiteRT-LM allows only one Conversation per Engine. The flat + // run()/runStreaming() API keeps its own Conversation in + // `this.conversation` β€” close it first so the chat session can + // create a fresh one. It will be recreated when the session is + // closed (see closeActiveSession / closeChatSession). + try { + conversation?.close() + } catch (t: Throwable) { + Log.w(TAG, "openChatSession(): conversation.close() threw", t) + } + conversation = null + + return try { + val session = LiteRTLmChatSession( + engine = e, + config = config, + visionEnabled = visionEnabled, + onSessionClosed = { + // Called when session.close() fires (from any caller). + // Only act if we still own this session β€” closeActiveSession() + // nulls activeSession first so teardown skips restoration. + if (activeSession != null) { + activeSession = null + restoreConversation() + } + } + ) + activeSession = session + Log.i(TAG, "openChatSession(): created session ${session.sessionId}") + BackendResult.Ok(session.sessionId) + } catch (t: Throwable) { + Log.e(TAG, "openChatSession(): failed", t) + // Session creation failed β€” restore the flat-API Conversation + // so run()/runStreaming() remain usable. + restoreConversation() + BackendResult.Err( + QuickAiError.INFERENCE_FAILED, + t.message ?: "failed to create chat session" + ) + } + } + + override fun closeChatSession(): BackendResult { + val session = activeSession + if (session == null) { + Log.w(TAG, "closeChatSession(): no active session") + return BackendResult.Err( + QuickAiError.BAD_REQUEST, + "No active chat session to close" + ) + } + // session.close() fires the onSessionClosed callback which + // nulls activeSession and calls restoreConversation(). + session.close() + Log.i(TAG, "closeChatSession(${session.sessionId}): closed") + return BackendResult.Ok(Unit) + } + + override fun runChatModelHandleStreaming( + text: String, + sink: StreamSink + ): BackendResult { + val session = activeSession + if (session == null) { + val err = BackendResult.Err( + QuickAiError.BAD_REQUEST, + "No active chat session β€” call openChatSession() first" + ) + sink.onError(err.error, err.message) + return err + } + return try { + val messages = listOf( + QuickAiChatMessage(role = QuickAiChatRole.USER, parts = listOf(PromptPart.Text(text))) + ) + session.runStreaming(messages, sink) + } catch (t: Throwable) { + Log.e(TAG, "runChatModelHandleStreaming(): threw", t) + val err = BackendResult.Err( + QuickAiError.INFERENCE_FAILED, + t.message ?: "chat streaming failed" + ) + sink.onError(err.error, err.message) + err + } + } + + override fun runChatMultimodalHandleStreaming( + parts: List, + sink: StreamSink + ): BackendResult { + val session = activeSession + if (session == null) { + val err = BackendResult.Err( + QuickAiError.BAD_REQUEST, + "No active chat session β€” call openChatSession() first" + ) + sink.onError(err.error, err.message) + return err + } + if (!visionEnabled) { + val err = BackendResult.Err( + QuickAiError.UNSUPPORTED, + "LiteRTLm was loaded without a visionBackend" + ) + sink.onError(err.error, err.message) + return err + } + return try { + val messages = listOf( + QuickAiChatMessage(role = QuickAiChatRole.USER, parts = parts) + ) + session.runStreaming(messages, sink) + } catch (t: Throwable) { + Log.e(TAG, "runChatMultimodalHandleStreaming(): threw", t) + val err = BackendResult.Err( + QuickAiError.INFERENCE_FAILED, + t.message ?: "chat multimodal streaming failed" + ) + sink.onError(err.error, err.message) + err + } + } + + // ----- OpenAI messages API (handle-based) -------------------------------- + + /** + * @brief Streaming inference with OpenAI message format. + * + * Accumulates deltas into a single response, then emits it through [sink]. + * LiteRT-LM does not currently support true token-by-token streaming for + * handle-based messages, so this is implemented as blocking + chunk. + */ + override fun runModelHandleWithMessagesStreaming( + messages: List, + sink: StreamSink + ): BackendResult { + val c = conversation + ?: run { + val err = BackendResult.Err(QuickAiError.NOT_INITIALIZED) + sink.onError(err.error, err.message) + return err + } + + val prompt = messages.joinToString("\n") { msg -> + "${msg.role}: ${msg.parts.filterIsInstance().joinToString("") { it.text }}" + } + + return try { + val message = c.sendMessage(prompt) + val text = message.toString() + if (text.isNotEmpty()) sink.onDelta(text) + sink.onDone() + BackendResult.Ok(Unit) + } catch (t: Throwable) { + val err = BackendResult.Err(QuickAiError.INFERENCE_FAILED, t.message) + sink.onError(err.error, err.message) + err + } + } + + /** + * @brief Streaming multimodal inference with OpenAI message format. + * + * Accumulates deltas into a single response, then emits it through [sink]. + */ + override fun runMultimodalHandleWithMessagesStreaming( + messages: List, + sink: StreamSink + ): BackendResult { + val c = conversation + ?: run { + val err = BackendResult.Err(QuickAiError.NOT_INITIALIZED) + sink.onError(err.error, err.message) + return err + } + if (!visionEnabled) { + val err = BackendResult.Err(QuickAiError.UNSUPPORTED, "Vision not enabled") + sink.onError(err.error, err.message) + return err + } + + // Validate image count (1 only) + val imageCount = messages.sumOf { msg -> + msg.parts.count { it is PromptPart.ImageBytes || it is PromptPart.ImageFile } + } + if (imageCount == 0) { + val err = BackendResult.Err(QuickAiError.INVALID_PARAMETER, "No image found") + sink.onError(err.error, err.message) + return err + } + if (imageCount > 1) { + val err = BackendResult.Err(QuickAiError.INVALID_PARAMETER, "Only 1 image is allowed") + sink.onError(err.error, err.message) + return err + } + + val contents = try { + toLiteRtContentsFromMessages(messages) + } catch (t: Throwable) { + val err = BackendResult.Err(QuickAiError.INVALID_PARAMETER, t.message) + sink.onError(err.error, err.message) + return err + } + + return try { + val message = c.sendMessage(contents) + val text = message.toString() + if (text.isNotEmpty()) sink.onDelta(text) + sink.onDone() + BackendResult.Ok(Unit) + } catch (t: Throwable) { + val err = BackendResult.Err(QuickAiError.INFERENCE_FAILED, t.message) + sink.onError(err.error, err.message) + err + } + } + + private fun toLiteRtContentsFromMessages(messages: List): Contents { + val mapped: List = messages.flatMap { msg -> + msg.parts.map { part -> + when (part) { + is PromptPart.Text -> Content.Text(part.text) + is PromptPart.ImageFile -> { + val f = File(part.absolutePath) + require(f.exists() && f.canRead()) { + "PromptPart.ImageFile not readable: ${part.absolutePath}" + } + Content.ImageFile(part.absolutePath) + } + is PromptPart.ImageBytes -> { + require(part.bytes.isNotEmpty()) { + "PromptPart.ImageBytes has empty byte array" + } + Content.ImageBytes(part.bytes) + } + is PromptPart.PreprocessedPixels -> { + throw UnsupportedOperationException( + "PromptPart.PreprocessedPixels is not supported by LiteRTLm. " + + "Use NativeQuickDotAI for V-JEPA multi-image inference." + ) + } + } + } + } + return Contents.of(mapped) + } + + override fun cancel() { + cancelRequested.set(true) + Log.i(TAG, "cancel(): one-shot run cancel requested") + } + + override fun chatCancel() { + activeSession?.cancel() + ?: Log.w(TAG, "chatCancel(): no active session") + } + + override fun chatRebuild( + messages: List + ): BackendResult { + val session = activeSession + ?: return BackendResult.Err( + QuickAiError.BAD_REQUEST, + "No active chat session β€” call openChatSession() first" + ) + return try { + session.rebuild(messages) + } catch (t: Throwable) { + Log.e(TAG, "chatRebuild(): threw", t) + BackendResult.Err(QuickAiError.UNKNOWN, t.message) + } + } + + /** + * Close the active chat session if any. Nulls [activeSession] first + * so the session's onSessionClosed callback sees nothing to restore + * β€” [unload] / [close] call [closeQuietly] right after, which tears + * down the entire Engine. + */ + private fun closeActiveSession() { + val session = activeSession ?: return + Log.i(TAG, "closeActiveSession(): closing ${session.sessionId}") + activeSession = null // detach first β†’ callback skips restore + session.close() + } + + /** + * Recreate `this.conversation` from the engine so the flat run()/ + * runStreaming() API is usable again after a chat session closes. + */ + private fun restoreConversation() { + if (conversation != null || engine == null) return + try { + conversation = engine!!.createConversation() + Log.i(TAG, "restoreConversation(): flat-API Conversation recreated") + } catch (t: Throwable) { + Log.e(TAG, "restoreConversation(): failed to recreate Conversation", t) + } + } + + override fun unload(): BackendResult { + Log.i(TAG, "unload() invoked") + // Cancel any in-flight inference before unloading + cancel() + + closeActiveSession() + closeQuietly() + return BackendResult.Ok(Unit) + } + + override fun metrics(): BackendResult { + if (engine == null) { + return BackendResult.Err( + QuickAiError.NOT_INITIALIZED, + "LiteRTLm has not been loaded yet" + ) + } + // LiteRT-LM does not currently expose token-level counters + // through its Kotlin API, so most fields stay at 0. We still + // publish the wall-clock timings we measured ourselves so + // callers can at least see the load + last-run durations. + return BackendResult.Ok( + PerformanceMetrics( + initializationDurationMs = initializationDurationMs, + totalDurationMs = lastRunDurationMs + ) + ) + } + + override fun close() { + Log.i(TAG, "close() invoked") + closeActiveSession() + closeQuietly() + } + + private fun closeQuietly() { + try { + conversation?.close() + } catch (t: Throwable) { + Log.w(TAG, "conversation.close() threw", t) + } + conversation = null + try { + engine?.close() + } catch (t: Throwable) { + Log.w(TAG, "engine.close() threw", t) + } + engine = null + visionEnabled = false + } + + /** + * @brief Map a QuickDotAI [BackendType] to a LiteRT-LM [LlmBackend]. + * + * Extracted as a helper so both the compute and vision backends + * use exactly the same mapping (including the NPU β†’ CPU fallback) + * and we never drift between the two. + */ + private fun mapBackend(b: BackendType): LlmBackend = when (b) { + BackendType.CPU -> LlmBackend.CPU() + BackendType.GPU -> LlmBackend.GPU() + // LiteRT-LM's NPU backend wants the dir holding the vendor + // native .so files. For an app-bundled setup that is simply + // the APK's nativeLibraryDir, but we don't have a Context + // here β€” fall back to CPU until the caller wires one in. + BackendType.NPU -> LlmBackend.CPU() + } + + /** + * @brief Convert a list of [PromptPart]s into a LiteRT-LM [Contents] + * object ready to hand to `sendMessage` / `sendMessageAsync`. + * + * This is where the AAR-level public types cross the boundary into + * the LiteRT-LM package. We also validate each part eagerly so the + * caller gets a crisp error BEFORE we reach the native layer: + * - ImageFile: file must exist and be readable + * - ImageBytes: bytes must be non-empty + * + * Throws [IllegalArgumentException] on validation failure; the + * calling runMultimodal* wrappers translate that into a + * [QuickAiError.INVALID_PARAMETER] BackendResult. + */ + private fun toLiteRtContents(parts: List): Contents { + val mapped: List = parts.map { p -> + when (p) { + is PromptPart.Text -> Content.Text(p.text) + is PromptPart.ImageFile -> { + val f = File(p.absolutePath) + require(f.exists() && f.canRead()) { + "PromptPart.ImageFile not readable: ${p.absolutePath}" + } + Content.ImageFile(p.absolutePath) + } + is PromptPart.ImageBytes -> { + require(p.bytes.isNotEmpty()) { + "PromptPart.ImageBytes has empty byte array" + } + Content.ImageBytes(p.bytes) + } + is PromptPart.PreprocessedPixels -> { + throw UnsupportedOperationException( + "PromptPart.PreprocessedPixels is not supported by LiteRTLm. " + + "Use NativeQuickDotAI for V-JEPA multi-image inference." + ) + } + } + } + return Contents.of(mapped) + } + + /** + * @brief Build the test-mode fallback model file handle from the shared + * model directory at `/sdcard/Download/aistudio-mobile/models/`. + * This ensures all team apps (AI Studio Mobile, SampleTestAPP, etc.) + * can share the same model file without duplication. + */ + private fun testModelFile(): File { + val baseDir = File(defaultModelBasePath.trimEnd('/')) + val dir = File(baseDir, TEST_GEMMA4_REL_DIR) + if (!dir.exists()) dir.mkdirs() + return File(dir, TEST_GEMMA4_FILE_NAME) + } + + companion object { + private const val TAG = "LiteRTLm" + + /** + * @brief TEST ONLY β€” path components of the Gemma-4 E2B-IT + * `.litertlm` model, relative to the shared model directory. + * The absolute path is resolved at runtime via [testModelFile]. + * + * Push the file with adb before running: + * adb push gemma-4-E2B-it.litertlm \ + * /sdcard/Download/aistudio-mobile/models/gemma-4-E2B-it/ + */ + const val TEST_GEMMA4_REL_DIR: String = "gemma-4-E2B-it" + const val TEST_GEMMA4_FILE_NAME: String = "gemma-4-E2B-it.litertlm" + } +} diff --git a/Android/QuickDotAI/src/main/java/com/example/quickdotai/LiteRTLmChatSession.kt b/Android/QuickDotAI/src/main/java/com/example/quickdotai/LiteRTLmChatSession.kt new file mode 100644 index 00000000..9b731978 --- /dev/null +++ b/Android/QuickDotAI/src/main/java/com/example/quickdotai/LiteRTLmChatSession.kt @@ -0,0 +1,796 @@ +// SPDX-License-Identifier: Apache-2.0 +/* + * Copyright (C) 2026 Samsung Electronics Co., Ltd. All Rights Reserved. + * + * @file LiteRTLmChatSession.kt + * @brief Chat session helper backed by a LiteRT-LM Conversation. + * + * Each session owns a LiteRT-LM [Conversation] and an [ImageStore] that + * caches image bytes keyed by SHA-256 hash. When the same image arrives + * via a different temporary file path, the hash matches and the + * conversation history stays consistent. + * + * History ownership: + * This wrapper does NOT track conversation history itself. The + * underlying LiteRT-LM [Conversation] owns the full KV cache, so prior + * turns are implicitly retained across calls as long as the same + * [Conversation] is reused. Callers are expected to pass only *new* + * messages on each turn β€” not the whole transcript. + * + * Role handling: + * OpenAI-style role-interleaved inputs (including multiple SYSTEM turns, + * e.g. `[SYSTEM, USER, ASSISTANT, SYSTEM, USER]`) are forwarded to + * LiteRT-LM with roles preserved. When such input arrives, the session + * rebuilds the underlying [Conversation] with prior turns passed + * through [ConversationConfig.initialMessages] as a mix of + * [Message.system], [Message.user], and [Message.model] β€” the model's + * embedded chat template then renders the full role-annotated array + * natively. + * + * Fast path: + * When the caller passes a single trailing USER turn (the common + * "continue the dialogue" case), the session reuses the existing + * [Conversation] and simply calls `sendMessage(user)` β€” no close / no + * re-prefill. LiteRT-LM keeps the prior history in its KV cache, so + * this stays O(new tokens) instead of O(all tokens) per turn. + */ +package com.example.quickdotai + +import android.util.Log +import com.google.ai.edge.litertlm.Channel +import com.google.ai.edge.litertlm.Content +import com.google.ai.edge.litertlm.Contents +import com.google.ai.edge.litertlm.Conversation +import com.google.ai.edge.litertlm.ConversationConfig +import com.google.ai.edge.litertlm.Engine +import com.google.ai.edge.litertlm.Message +import com.google.ai.edge.litertlm.MessageCallback +import com.google.ai.edge.litertlm.SamplerConfig +import java.io.File +import java.util.UUID +import java.util.concurrent.CountDownLatch +import java.util.concurrent.TimeUnit +import java.util.concurrent.atomic.AtomicBoolean + +/** + * @brief Internal helper that holds the state of a single LiteRT-LM + * chat session. Not an interface implementation β€” all public + * chat API goes through [QuickDotAI] / [LiteRTLm]. + * + * @param engine the parent LiteRT-LM engine (kept for rebuild) + * @param config session-level sampling + template config + * @param visionEnabled whether the engine was loaded with a vision backend + * @param onSessionClosed callback fired once when [close] is invoked. + * Used by [LiteRTLm] to clear its `activeSession` and restore + * the flat-API Conversation. Skipped when the parent engine is + * tearing down (it nulls out `activeSession` before calling + * close so the callback sees nothing to do). + */ +internal class LiteRTLmChatSession( + private val engine: Engine, + private val config: QuickAiChatSessionConfig?, + private val visionEnabled: Boolean, + private val onSessionClosed: (() -> Unit)? = null, + val sessionId: String = UUID.randomUUID().toString() +) { + + private var conversation: Conversation? = null + internal val imageStore = ImageStore() + + /** Signals an in-flight cancel request. */ + private val cancelRequested = AtomicBoolean(false) + + /** + * Session-level `extraContext` map forwarded to every + * [Conversation.sendMessage] / [Conversation.sendMessageAsync] call. + * + * LiteRT-LM's [ConversationConfig] does not expose template kwargs, + * so we materialize them per-call instead. Built once from + * [QuickAiChatSessionConfig.chatTemplateKwargs]: + * - `enableThinking` β†’ `"enable_thinking" β†’ Boolean` + * + * Empty when no template kwargs are configured. + */ + private val extraContext: Map = + buildExtraContext(config?.chatTemplateKwargs).also { + if (it.isNotEmpty()) { + Log.i(TAG, "LiteRTLmChatSession($sessionId): extraContext=$it") + } + } + + @Volatile + private var closed = false + + private var lastRunDurationMs: Double = 0.0 + + // ----- run (blocking) ------------------------------------------------ + + fun run( + messages: List + ): BackendResult { + if (closed) return errClosed() + cancelRequested.set(false) + + val prep = prepareTurn(messages) ?: return lastPrepError + ?: BackendResult.Err(QuickAiError.INVALID_PARAMETER, "invalid chat input") + + return try { + val c = acquireConversationForTurn(prep, messages) + ?: return BackendResult.Err( + QuickAiError.INFERENCE_FAILED, + "Failed to build Conversation for this turn" + ) + + val startNs = System.nanoTime() + val response = if (hasImages(prep.lastUser)) { + val contents = toChatContents(prep.lastUser) + c.sendMessage(contents, extraContext) + } else { + val text = extractText(prep.lastUser) + c.sendMessage(text, extraContext) + } + lastRunDurationMs = (System.nanoTime() - startNs) / 1_000_000.0 + + val output = response.toString() + val reasoning = response.channels[THOUGHT_CHANNEL_NAME]?.takeIf { it.isNotBlank() } + Log.i(TAG, "run($sessionId): completed in ${lastRunDurationMs.toLong()} ms") + + BackendResult.Ok( + QuickAiChatResult( + content = output, + reasoning = reasoning, + metrics = PerformanceMetrics(totalDurationMs = lastRunDurationMs) + ) + ) + } catch (t: Throwable) { + Log.e(TAG, "run($sessionId): inference failed", t) + BackendResult.Err( + QuickAiError.INFERENCE_FAILED, + t.message ?: "chat inference failed" + ) + } + } + + // ----- runStreaming --------------------------------------------------- + + fun runStreaming( + messages: List, + sink: StreamSink + ): BackendResult { + if (closed) { + val err = errClosed() + sink.onError(err.error, err.message) + return err + } + cancelRequested.set(false) + + val prep = prepareTurn(messages) + if (prep == null) { + val err = lastPrepError ?: BackendResult.Err( + QuickAiError.INVALID_PARAMETER, + "invalid chat input" + ) + sink.onError(err.error, err.message) + return err + } + + val c = acquireConversationForTurn(prep, messages) + if (c == null) { + val err = BackendResult.Err( + QuickAiError.INFERENCE_FAILED, + "Failed to build Conversation for this turn" + ) + sink.onError(err.error, err.message) + return err + } + + val latch = CountDownLatch(1) + val accumulated = StringBuilder() + val reasoningAccumulated = StringBuilder() + var terminalError: BackendResult.Err? = null + val startNs = System.nanoTime() + + val callback = object : MessageCallback { + override fun onMessage(message: Message) { + if (cancelRequested.get()) return + try { + val full = message.toString() + val delta = if (full.startsWith(accumulated.toString())) { + full.substring(accumulated.length) + } else { + full + } + if (delta.isNotEmpty()) { + accumulated.append(delta) + sink.onDelta(delta) + } + val reasoning = message.channels[THOUGHT_CHANNEL_NAME].orEmpty() + val reasoningDelta = if (reasoning.startsWith(reasoningAccumulated.toString())) { + reasoning.substring(reasoningAccumulated.length) + } else { + reasoning + } + if (reasoningDelta.isNotEmpty()) { + reasoningAccumulated.append(reasoningDelta) + sink.onReasoningDelta(reasoningDelta) + } + } catch (t: Throwable) { + Log.w(TAG, "runStreaming($sessionId): onMessage threw", t) + } + } + + override fun onDone() { + lastRunDurationMs = (System.nanoTime() - startNs) / 1_000_000.0 + Log.i( + TAG, + "runStreaming($sessionId): onDone after " + + "${lastRunDurationMs.toLong()} ms" + ) + try { sink.onDone() } finally { latch.countDown() } + } + + override fun onError(throwable: Throwable) { + Log.e(TAG, "runStreaming($sessionId): onError", throwable) + val err = BackendResult.Err( + QuickAiError.INFERENCE_FAILED, + throwable.message ?: "chat streaming failed" + ) + terminalError = err + try { sink.onError(err.error, err.message) } finally { latch.countDown() } + } + } + + return try { + if (hasImages(prep.lastUser)) { + val contents = toChatContents(prep.lastUser) + c.sendMessageAsync(contents, callback, extraContext) + } else { + val text = extractText(prep.lastUser) + c.sendMessageAsync(text, callback, extraContext) + } + + val finished = latch.await(5, TimeUnit.MINUTES) + if (!finished) { + Log.e(TAG, "runStreaming($sessionId): timed out") + val err = BackendResult.Err( + QuickAiError.INFERENCE_FAILED, + "chat streaming timeout" + ) + sink.onError(err.error, err.message) + return err + } + + if (terminalError != null) { + return terminalError!! + } + + val output = accumulated.toString() + BackendResult.Ok( + QuickAiChatResult( + content = output, + reasoning = reasoningAccumulated.toString().takeIf { it.isNotEmpty() }, + metrics = PerformanceMetrics(totalDurationMs = lastRunDurationMs) + ) + ) + } catch (t: Throwable) { + Log.e(TAG, "runStreaming($sessionId): threw", t) + val err = BackendResult.Err( + QuickAiError.INFERENCE_FAILED, + t.message ?: "chat streaming failed" + ) + sink.onError(err.error, err.message) + err + } + } + + // ----- cancel -------------------------------------------------------- + + fun cancel() { + cancelRequested.set(true) + Log.i(TAG, "cancel($sessionId): cancel requested") + } + + // ----- rebuild ------------------------------------------------------- + + fun rebuild( + messages: List + ): BackendResult { + if (closed) return errClosed() + + Log.i( + TAG, + "rebuild($sessionId): reset Conversation and pre-seed with " + + "${messages.size} message(s)" + ) + + // Drop the KV cache carried by the current Conversation so the + // rebuild is a true reset. + try { + conversation?.close() + } catch (t: Throwable) { + Log.w(TAG, "rebuild($sessionId): conversation.close() threw", t) + } + conversation = null + + // Prune images not referenced by the new seed. + val referencedHashes = collectImageHashes(messages) + imageStore.retainOnly(referencedHashes) + + // When the caller provides seed messages, eagerly create a new + // Conversation with them as initialMessages so LiteRT-LM can + // pre-fill its KV cache. If no seed is given, we leave the + // Conversation null and the next run() will lazily create it. + if (messages.isNotEmpty()) { + val initial = try { + messages.map { toLiteRtMessage(it) } + } catch (t: Throwable) { + Log.e(TAG, "rebuild($sessionId): mapping seed failed", t) + return BackendResult.Err( + QuickAiError.INFERENCE_FAILED, + t.message ?: "failed to map seed messages" + ) + } + try { + conversation = createConversationFromConfig(engine, config, initial) + } catch (t: Throwable) { + Log.e(TAG, "rebuild($sessionId): createConversation threw", t) + return BackendResult.Err( + QuickAiError.INFERENCE_FAILED, + t.message ?: "failed to create seeded Conversation" + ) + } + } + + return BackendResult.Ok(Unit) + } + + // ----- close --------------------------------------------------------- + + fun close() { + if (closed) return + closed = true + Log.i(TAG, "close($sessionId)") + try { + conversation?.close() + } catch (t: Throwable) { + Log.w(TAG, "close($sessionId): conversation.close() threw", t) + } + conversation = null + imageStore.clear() + try { + onSessionClosed?.invoke() + } catch (t: Throwable) { + Log.w(TAG, "close($sessionId): onSessionClosed threw", t) + } + } + + // ----- helpers ------------------------------------------------------- + + /** Per-turn data: trailing USER and any prior turns in this call. */ + private data class TurnPrep( + val priorTurns: List, + val lastUser: QuickAiChatMessage, + ) + + /** + * Surface for [run] / [runStreaming] to learn WHY [prepareTurn] + * returned null without resorting to exceptions. Mutated only from + * the caller thread right before prepareTurn is invoked. + */ + @Volatile + private var lastPrepError: BackendResult.Err? = null + + /** + * Validate [messages] (non-empty, trailing USER, vision gating) and + * split off the trailing USER turn that will drive inference from + * any leading turns that need to feed the chat template. + * + * Returns null on validation failure; the concrete error is stashed + * in [lastPrepError] for the caller to surface (and optionally + * forward to a StreamSink). + */ + private fun prepareTurn(messages: List): TurnPrep? { + lastPrepError = null + if (messages.isEmpty()) { + lastPrepError = BackendResult.Err( + QuickAiError.INVALID_PARAMETER, + "messages list is empty" + ) + return null + } + if (messages.last().role != QuickAiChatRole.USER) { + lastPrepError = BackendResult.Err( + QuickAiError.INVALID_PARAMETER, + "last message must have role USER to trigger inference " + + "(got ${messages.last().role})" + ) + return null + } + + val lastUser = messages.last() + val priorTurns = messages.subList(0, messages.size - 1) + + if (!visionEnabled && messages.any { hasImages(it) }) { + lastPrepError = BackendResult.Err( + QuickAiError.UNSUPPORTED, + "Engine loaded in text-only mode β€” cannot process images" + ) + return null + } + + return TurnPrep(priorTurns = priorTurns, lastUser = lastUser) + } + + /** + * Close any previously-held Conversation and build a fresh one for + * this turn, passing [priorTurns] through + * [ConversationConfig.initialMessages] so LiteRT-LM's native chat + * template renders the full role-annotated array. + * + * Returns null on construction failure. + */ + private fun rebuildConversationForTurn( + priorTurns: List + ): Conversation? { + try { + conversation?.close() + } catch (t: Throwable) { + Log.w(TAG, "rebuildConversationForTurn($sessionId): close() threw", t) + } + conversation = null + + val initial = try { + priorTurns.map { toLiteRtMessage(it) } + } catch (t: Throwable) { + Log.e(TAG, "rebuildConversationForTurn($sessionId): mapping failed", t) + return null + } + + return try { + createConversationFromConfig(engine, config, initial).also { + conversation = it + if (initial.isNotEmpty()) { + Log.i( + TAG, + "rebuildConversationForTurn($sessionId): seeded with " + + "${initial.size} initialMessage(s) " + + priorTurns.joinToString(prefix = "[", postfix = "]") { + it.role.name + } + ) + } + } + } catch (t: Throwable) { + Log.e(TAG, "rebuildConversationForTurn($sessionId): createConversation threw", t) + null + } + } + + /** + * Fast path wrapper around [rebuildConversationForTurn]. + * + * Source of truth: LiteRT-LM's [Conversation] owns the KV cache and + * therefore the full conversation history. As long as we keep + * handing the same [Conversation] out, `sendMessage(user)` extends + * the dialogue correctly β€” no replay needed on our side. + * + * When the caller passes only a single trailing USER turn we just + * reuse the existing [Conversation]. This keeps each follow-up turn + * O(new user tokens) instead of O(all history tokens). + * + * Falls back to [rebuildConversationForTurn] whenever: + * - no existing Conversation is held yet (first turn, or post-rebuild), or + * - the caller injects anything other than exactly one USER turn + * (e.g. SYSTEM/ASSISTANT turns, role-interleaved bundles, or + * multi-USER batches) β€” those require the full role-annotated + * initialMessages replay, which drops the KV cache. + */ + private fun acquireConversationForTurn( + prep: TurnPrep, + newMessages: List + ): Conversation? { + val existing = conversation + if (existing != null && + newMessages.size == 1 && + newMessages[0].role == QuickAiChatRole.USER + ) { + Log.i( + TAG, + "acquireConversationForTurn($sessionId): fast path β€” " + + "reusing existing Conversation (KV cache retained)" + ) + return existing + } + return rebuildConversationForTurn(prep.priorTurns) + } + + /** + * Map a [QuickAiChatMessage] to a LiteRT-LM [Message], preserving + * the original role: + * - [QuickAiChatRole.SYSTEM] β†’ [Message.system] + * - [QuickAiChatRole.USER] β†’ [Message.user] + * - [QuickAiChatRole.ASSISTANT] β†’ [Message.model] + * + * Image parts are also stored in the session's [imageStore] so the + * hash-based identity stays stable across rebuilds. + */ + private fun toLiteRtMessage(msg: QuickAiChatMessage): Message { + val contents = toChatContents(msg) + return when (msg.role) { + QuickAiChatRole.SYSTEM -> Message.system(contents) + QuickAiChatRole.USER -> Message.user(contents) + QuickAiChatRole.ASSISTANT -> Message.model(contents = contents) + } + } + + private fun hasImages(msg: QuickAiChatMessage): Boolean = + msg.parts.any { it is PromptPart.ImageFile || it is PromptPart.ImageBytes } + + private fun extractText(msg: QuickAiChatMessage): String = + msg.parts.filterIsInstance().joinToString("") { it.text } + + /** + * Convert a user message's parts into a LiteRT-LM [Contents]. + * Images are stored in [imageStore] along the way. + */ + private fun toChatContents(msg: QuickAiChatMessage): Contents { + val mapped: List = msg.parts.map { p -> + when (p) { + is PromptPart.Text -> Content.Text(p.text) + is PromptPart.ImageFile -> { + val f = File(p.absolutePath) + require(f.exists() && f.canRead()) { + "Image file not readable: ${p.absolutePath}" + } + // Store in ImageStore for stable hash-based identity + imageStore.store(p.absolutePath) + Content.ImageFile(p.absolutePath) + } + is PromptPart.ImageBytes -> { + require(p.bytes.isNotEmpty()) { + "Image bytes are empty" + } + imageStore.store(p.bytes) + Content.ImageBytes(p.bytes) + } + is PromptPart.PreprocessedPixels -> { + throw UnsupportedOperationException( + "PromptPart.PreprocessedPixels is not supported by LiteRTLm. " + + "Use NativeQuickDotAI for V-JEPA multi-image inference." + ) + } + } + } + return Contents.of(mapped) + } + + /** + * Collect all image hashes referenced in a list of messages. + * Used by [rebuild] to determine which cached images to keep. + */ + private fun collectImageHashes(messages: List): Set { + val hashes = mutableSetOf() + for (msg in messages) { + for (part in msg.parts) { + when (part) { + is PromptPart.ImageFile -> { + val f = File(part.absolutePath) + if (f.exists() && f.canRead()) { + hashes.add(ImageStore.sha256Hex(f.readBytes())) + } + } + is PromptPart.ImageBytes -> { + if (part.bytes.isNotEmpty()) { + hashes.add(ImageStore.sha256Hex(part.bytes)) + } + } + is PromptPart.PreprocessedPixels -> { /* not supported by LiteRTLm */ } + is PromptPart.Text -> { /* no image */ } + } + } + } + return hashes + } + + private fun errClosed(): BackendResult.Err = BackendResult.Err( + QuickAiError.NOT_INITIALIZED, + "Chat session $sessionId is closed" + ) + + companion object { + private const val TAG = "LiteRTLmChatSession" + + /** + * Translate a [QuickAiChatTemplateKwargs] into the + * `extraContext: Map` that LiteRT-LM's + * [Conversation.sendMessage] / `sendMessageAsync` accepts. + * + * The native chat template reads these keys directly + * (e.g. Jinja `{% if enable_thinking %}`), so the JSON-style + * snake_case names must be preserved exactly: + * - [QuickAiChatTemplateKwargs.enableThinking] β†’ + * `"enable_thinking" β†’ Boolean` + * + * Null fields are omitted from the map so the chat template's + * own default kicks in. Returns an empty map when [kwargs] is + * null. + */ + private fun buildExtraContext( + kwargs: QuickAiChatTemplateKwargs? + ): Map { + if (kwargs == null) return emptyMap() + val out = mutableMapOf() + kwargs.enableThinking?.let { out["enable_thinking"] = it } + return out + } + + // Neutral fallback values used only when the caller specifies + // *some* sampling fields but not all three required ones. + // LiteRT-LM's SamplerConfig(topK, topP, temperature) are all + // non-nullable with no defaults, so we must supply something + // whenever we construct a SamplerConfig at all. + private const val FALLBACK_TEMPERATURE = 1.0 + private const val FALLBACK_TOP_K = 40 + private const val FALLBACK_TOP_P = 0.95 + private const val THOUGHT_CHANNEL_NAME = "thought" + private const val THOUGHT_CHANNEL_START = "<|channel>thought" + private const val THOUGHT_CHANNEL_END = "" + + /** + * Build a LiteRT-LM [Conversation] from a [QuickAiChatSessionConfig] + * and a list of prior turns that will be forwarded to the native + * chat template via [ConversationConfig.initialMessages]. + * + * Maps: + * - [QuickAiChatSessionConfig.systemInstruction] β†’ + * [ConversationConfig.systemInstruction] + * - [QuickAiChatSamplingConfig] β†’ [SamplerConfig] + * (see [buildSamplerConfig] for per-field behavior) + * - [initialMessages] β†’ [ConversationConfig.initialMessages] + * (role-preserving: SYSTEM/USER/ASSISTANT β†’ system/user/model) + * + * Falls back to the bare `engine.createConversation()` overload + * only when nothing is configured AND there are no prior turns, + * so LiteRT-LM uses its own engine-level defaults. + */ + private fun createConversationFromConfig( + engine: Engine, + config: QuickAiChatSessionConfig?, + initialMessages: List = emptyList(), + ): Conversation { + val sysInstruction = config?.systemInstruction?.takeIf { it.isNotBlank() } + val samplerConfig = buildSamplerConfig(config?.sampling) + val channels = buildConversationChannels(config?.chatTemplateKwargs) + + // Skip ConversationConfig entirely when nothing is configured + // so LiteRT-LM uses its own engine/model defaults. + if (sysInstruction == null && + samplerConfig == null && + channels == null && + initialMessages.isEmpty() + ) { + return engine.createConversation() + } + + val convConfig = ConversationConfig( + systemInstruction = + sysInstruction?.let { Contents.of(it) }, + initialMessages = initialMessages, + samplerConfig = samplerConfig, + channels = channels, + ) + Log.i( + TAG, + "createConversationFromConfig: " + + "sysInstruction=${sysInstruction?.take(60)}, " + + "samplerConfig=$samplerConfig, " + + "initialMessages=${initialMessages.size}, " + + "channels=${channels?.joinToString { it.channelName } ?: "none"}" + ) + return engine.createConversation(convConfig) + } + + private fun buildConversationChannels( + kwargs: QuickAiChatTemplateKwargs? + ): List? { + if (kwargs?.enableThinking != true) { + return null + } + return listOf( + Channel( + channelName = THOUGHT_CHANNEL_NAME, + start = THOUGHT_CHANNEL_START, + end = THOUGHT_CHANNEL_END, + ) + ) + } + + /** + * Map [QuickAiChatSamplingConfig] to LiteRT-LM [SamplerConfig]. + * + * LiteRT-LM's `SamplerConfig(topK: Int, topP: Double, + * temperature: Double, seed: Int = 0)` has three non-nullable + * core fields. That means we cannot express "set only + * temperature, leave topK/topP to the engine default" through + * a partially-populated SamplerConfig β€” any SamplerConfig we + * construct MUST carry all three. + * + * Behavior: + * - Returns `null` when [sampling] is null or all relevant + * fields are null. The caller then passes no samplerConfig + * to ConversationConfig, and LiteRT-LM uses its own + * engine-level defaults (preferred path for best quality). + * - When any of temperature/topK/topP/seed is specified, + * constructs a full SamplerConfig, filling the remaining + * core fields from [FALLBACK_TEMPERATURE]/[FALLBACK_TOP_K]/ + * [FALLBACK_TOP_P]. A warning is logged so partial + * specification is visible in logcat. + * - [QuickAiChatSamplingConfig.minP] and + * [QuickAiChatSamplingConfig.maxTokens] are not supported by + * LiteRT-LM's SamplerConfig; values are ignored and a + * warning is logged. + * + * LiteRT-LM validates ranges in `SamplerConfig.init` + * (topK > 0, topP in [0,1], temperature >= 0) and throws + * [IllegalArgumentException] on violation. That throw + * propagates up through [createConversationFromConfig] and is + * caught in [LiteRTLm.openChatSession], where it is converted + * to a BackendResult.Err. + */ + private fun buildSamplerConfig( + sampling: QuickAiChatSamplingConfig? + ): SamplerConfig? { + if (sampling == null) return null + + val anyCoreSet = sampling.temperature != null || + sampling.topK != null || + sampling.topP != null || + sampling.seed != null + + // Warn about QuickAi fields that LiteRT-LM's SamplerConfig + // does not expose. Doing it up front means the warning + // fires even if no core field is set (i.e. even if we end + // up returning null below). + if (sampling.minP != null) { + Log.w( + TAG, + "buildSamplerConfig: minP=${sampling.minP} is not " + + "supported by LiteRT-LM SamplerConfig β€” ignored" + ) + } + if (sampling.maxTokens != null) { + Log.w( + TAG, + "buildSamplerConfig: maxTokens=${sampling.maxTokens} " + + "is not supported by LiteRT-LM SamplerConfig β€” ignored" + ) + } + + if (!anyCoreSet) return null + + val missing = buildList { + if (sampling.temperature == null) add("temperature") + if (sampling.topK == null) add("topK") + if (sampling.topP == null) add("topP") + } + if (missing.isNotEmpty()) { + Log.w( + TAG, + "buildSamplerConfig: partial sampling config β€” " + + "LiteRT-LM SamplerConfig requires all core fields; " + + "filling ${missing.joinToString()} with fallback " + + "defaults (temperature=$FALLBACK_TEMPERATURE, " + + "topK=$FALLBACK_TOP_K, topP=$FALLBACK_TOP_P). " + + "Specify all three together to avoid this." + ) + } + + return SamplerConfig( + topK = sampling.topK ?: FALLBACK_TOP_K, + topP = sampling.topP ?: FALLBACK_TOP_P, + temperature = sampling.temperature ?: FALLBACK_TEMPERATURE, + seed = sampling.seed ?: 0, + ) + } + } +} diff --git a/Android/QuickDotAI/src/main/java/com/example/quickdotai/LlavaNextImageProcessor.kt b/Android/QuickDotAI/src/main/java/com/example/quickdotai/LlavaNextImageProcessor.kt new file mode 100644 index 00000000..6e7b4cbf --- /dev/null +++ b/Android/QuickDotAI/src/main/java/com/example/quickdotai/LlavaNextImageProcessor.kt @@ -0,0 +1,288 @@ +package com.example.quickdotai + +import android.content.Context +import android.graphics.Bitmap +import android.graphics.Canvas +import android.graphics.ImageDecoder +import android.net.Uri +import android.graphics.Color +import java.io.IOException +import kotlin.math.ceil +import kotlin.math.min +import androidx.core.graphics.createBitmap +import com.example.quickdotai.PillowBilinearResizer + + +/** + * A Kotlin implementation of the Llava-NeXT image processor for Android. + * + * This class transforms a Bitmap into a list of normalized FloatArrays, + * representing image patches, ready for input into a vision model. + * + * @param cropSize The size (height and width) of each square patch. Corresponds to the model's input size. + * @param imageGridPinpoints A list of possible high-resolution grids to select from. + * @param imageMean The mean values for normalization (R, G, B). + * @param imageStd The standard deviation values for normalization (R, G, B). + */ +class LlavaNextImageProcessor( + private val context: Context, + private val cropSize: Int = 512, + private val imageGridPinpoints: List> = listOf( + Pair(512,1024),Pair(512,1536),Pair(512,2048),Pair(512,2560),Pair(512,3072),Pair(512,3584),Pair(512,4096),Pair(512,4608),Pair(512,5120),Pair(512,5632),Pair(512,6144),Pair(1024,512),Pair(1024,1024),Pair(1024,1536),Pair(1024,2048),Pair(1024,2560),Pair(1024,3072),Pair(1536,512),Pair(1536,1024),Pair(1536,1536),Pair(1536,2048),Pair(2048,512),Pair(2048,1024),Pair(2048,1536),Pair(2560,512),Pair(2560,1024),Pair(3072,512),Pair(3072,1024),Pair(3584,512),Pair(4096,512),Pair(4608,512),Pair(5120,512),Pair(5632,512),Pair(6144,512) + ), + private val imageMean: FloatArray = floatArrayOf(0.5f, 0.5f, 0.5f), + private val imageStd: FloatArray = floatArrayOf(0.5f, 0.5f, 0.5f), + private val rescaleFactor: Double = 1.0 / 255.0, + private val patchMergeType: String = "nopad" +) { + + fun resizeImage(inputBitmap: Bitmap, targetWidth: Int, targetHeight: Int): Bitmap { + val width = inputBitmap.width + val height = inputBitmap.height + val pixels = IntArray(width * height) + inputBitmap.getPixels(pixels, 0, width, 0, 0, width, height) +// val resizedPixels = PillowBicubicResizer.resize(pixels, width, height, targetWidth, targetHeight) + val resizedPixels = PillowBilinearResizer.resize(pixels, width, height, targetWidth, targetHeight) + return Bitmap.createBitmap(resizedPixels, targetWidth, targetHeight, Bitmap.Config.ARGB_8888) + +// return inputBitmap.scale(targetWidth, targetHeight, filter = true) + } + + /** + * Loads a Bitmap from a given content URI. + * + * @param imageUri The URI of the image to load. + * @return A Bitmap object, or null if loading fails. + */ + fun loadBitmapFromUri(imageUri: Uri, resize: Boolean = false): Bitmap? { + val bitmap = try { + val source = ImageDecoder.createSource(context.contentResolver, imageUri) + + // Decode the bitmap, ensuring it's mutable and in ARGB_8888 config + ImageDecoder.decodeBitmap(source) { decoder, _, _ -> + decoder.isMutableRequired = true + decoder.allocator = ImageDecoder.ALLOCATOR_SOFTWARE + }.copy(Bitmap.Config.ARGB_8888, true) + } catch (e: IOException) { + e.printStackTrace() + null + } + if (resize) { + return resizeBitmapIfTooLarge(bitmap) + } + return bitmap + } + + fun resizeBitmapIfTooLarge( + originalBitmap: Bitmap?, + maxPixels: Int = 300000 + ): Bitmap? { + if (originalBitmap == null) return null + val currentPixels = originalBitmap.width * originalBitmap.height + + if (currentPixels <= maxPixels) { + // Bitmap is already smaller or equal to the max pixels, no resize needed + return originalBitmap + } + + val aspectRatio = originalBitmap.width.toFloat() / originalBitmap.height.toFloat() + + // Calculate new dimensions + // newHeight = sqrt(maxPixels / aspectRatio) + val newHeight = Math.sqrt(maxPixels / aspectRatio.toDouble()).toInt() + // newWidth = newHeight * aspectRatio + val newWidth = (newHeight * aspectRatio).toInt() + + // Create a new scaled bitmap + return resizeImage(originalBitmap, newWidth, newHeight) + } + + // Represents the final model input for a single image + data class ModelInput(val pixelValues: FloatArray, val originalSize: Pair) + + /** + * @brief Returns the crop size (patch size) used for image preprocessing. + */ + fun getCropSize(): Int = cropSize + + /** + * Preprocesses a single Bitmap image. + * + * @param image The input Bitmap. + * @return A ModelInput object containing a list of float arrays (patches) and original image size. + */ + fun preprocess(image: Bitmap): ModelInput { + val originalSize = Pair(image.height, image.width) + val imagePatches = getImagePatches(image) + val perImagePatchSize = cropSize * cropSize * 3 + val floatValues = FloatArray(imagePatches.size * perImagePatchSize) + + imagePatches.mapIndexed { index, patch -> + // All patches, including the base one, are already at the target cropSize. + // We just need to normalize them. + normalize(patch, floatValues, index * perImagePatchSize) + } + + // TODO: Compare pixelValues to PyTorch's pixelValues for various images + return ModelInput(pixelValues = floatValues, originalSize = originalSize) + } + + /** + * Creates image patches based on the LLaVa-NeXT "any-resolution" strategy. + */ + private fun getImagePatches(image: Bitmap): List { + // 1. Create the base, low-res image (resized to cropSize x cropSize) + val baseImage = resizeImage(image, cropSize, cropSize) + + // 2. Handle high-resolution patching + val bestResolution = selectBestResolution(Pair(image.height, image.width))!! + val resizedForPatching = resizeForPatching(image, bestResolution) + val paddedImage = padToResolution(resizedForPatching, bestResolution) + val highResPatches = divideToPatches(paddedImage, cropSize) + + return listOf(baseImage) + highResPatches + } + + /** + * Selects the best grid resolution from `imageGridPinpoints` that fits the image. + */ + private fun selectBestResolution(originalSize: Pair): Pair? { + if (imageGridPinpoints.isEmpty()) { + return null + } + + val (originalHeight, originalWidth) = originalSize + var bestFit: Pair? = null + var maxEffectiveResolution = -1 + var minWastedResolution = Int.MAX_VALUE + + for (resolution in imageGridPinpoints) { + val (height, width) = resolution + + // Use Double for division to maintain precision + val scale = min( + width.toDouble() / originalWidth, + height.toDouble() / originalHeight + ) + val downscaledWidth = (originalWidth * scale).toInt() + val downscaledHeight = (originalHeight * scale).toInt() + + val effectiveResolution = min( + downscaledWidth * downscaledHeight, + originalWidth * originalHeight + ) + val wastedResolution = (width * height) - effectiveResolution + + if (effectiveResolution > maxEffectiveResolution || + (effectiveResolution == maxEffectiveResolution && wastedResolution < minWastedResolution) + ) { + maxEffectiveResolution = effectiveResolution + minWastedResolution = wastedResolution + bestFit = resolution + } + } + + return bestFit + } + + private fun getPatchOutputSize(image: Bitmap, targetResolution: Pair): Pair{ + val (targetHeight, targetWidth) = targetResolution + val (originalHeight, originalWidth) = image.height to image.width + + val scaleW = targetWidth.toFloat() / originalWidth + val scaleH = targetHeight.toFloat() / originalHeight + + val newWidth: Int + val newHeight: Int + + if (scaleW < scaleH) { + newWidth = targetWidth + newHeight = min(ceil(originalHeight * scaleW).toInt(), targetHeight) + } else { + newHeight = targetHeight + newWidth = min(ceil(originalWidth * scaleH).toInt(), targetWidth) + } + return Pair(newHeight, newWidth); + } + + /** + * Resizes an image to fit within a target resolution while maintaining aspect ratio. + */ + private fun resizeForPatching(image: Bitmap, targetResolution: Pair): Bitmap { + val (targetHeight, targetWidth) = targetResolution + val newWidth: Int + val newHeight: Int + + if (patchMergeType == "nopad") { + newHeight = targetHeight + newWidth = targetWidth + } + else { // spatial_unpad + val (newH, newW) = getPatchOutputSize(image, targetResolution) + newHeight = newH + newWidth = newW + } + + return resizeImage(image, newWidth, newHeight) + } + + /** + * Pads a resized image to the exact target resolution by adding black bars. + */ + private fun padToResolution(image: Bitmap, targetResolution: Pair): Bitmap { + val (targetHeight, targetWidth) = targetResolution + val (imageHeight, imageWidth) = getPatchOutputSize(image, targetResolution) + + if (imageHeight == targetHeight && imageWidth == targetWidth) { + return image + } + + val paddedBitmap = createBitmap(targetWidth, targetHeight, image.config ?: Bitmap.Config.ARGB_8888) + val canvas = Canvas(paddedBitmap) + canvas.drawColor(Color.BLACK) // Pad with black + + val left = (targetWidth - imageWidth) / 2f + val top = (targetHeight - imageHeight) / 2f + + canvas.drawBitmap(image, left, top, null) + return paddedBitmap + } + + /** + * Divides an image into a grid of square patches. + */ + private fun divideToPatches(image: Bitmap, patchSize: Int): List { + val patches = mutableListOf() + val (height, width) = image.height to image.width + + for (i in 0 until height step patchSize) { + for (j in 0 until width step patchSize) { + val patch = Bitmap.createBitmap(image, j, i, patchSize, patchSize) + patches.add(patch) + } + } + return patches + } + + /** + * Rescales pixel values from [0, 255] to [0, 1] and then normalizes them. + * The output is a flattened FloatArray in CHW (Channels, Height, Width) format. + */ + private fun normalize(image: Bitmap, floatValues: FloatArray, offset: Int) { + val width = image.width + val height = image.height + val pixels = IntArray(width * height) + image.getPixels(pixels, 0, width, 0, 0, width, height) + + // HWC format: RGB...RGB...RGB... + + for (i in 0 until (width * height)) { + val pixel = pixels[i] + // Rescale to [0, 1] then normalize + floatValues[offset + i * 3] = ((Color.red(pixel) / 255.0f) - imageMean[0]) / imageStd[0] + floatValues[offset + i * 3 + 1] = ((Color.green(pixel) / 255.0f) - imageMean[1]) / imageStd[1] + floatValues[offset + i * 3 + 2] = ((Color.blue(pixel) / 255.0f) - imageMean[2]) / imageStd[2] + } + } + +} \ No newline at end of file diff --git a/Android/QuickDotAI/src/main/java/com/example/quickdotai/ModelCatalog.kt b/Android/QuickDotAI/src/main/java/com/example/quickdotai/ModelCatalog.kt new file mode 100644 index 00000000..f012859a --- /dev/null +++ b/Android/QuickDotAI/src/main/java/com/example/quickdotai/ModelCatalog.kt @@ -0,0 +1,138 @@ +// SPDX-License-Identifier: Apache-2.0 +/* + * Copyright (C) 2026 Samsung Electronics Co., Ltd. All Rights Reserved. + * + * @file ModelCatalog.kt + * @brief Kotlin model catalog types and the ModelCatalog singleton. + * + * The catalog is seeded from the native C registry via + * [NativeCausalLm.nativeQueryCatalog] (returns a JSON array). LiteRT-only + * descriptors are appended locally because they are never registered in the + * C layer. A static fallback is used when the native call fails (e.g. + * emulator / missing .so). + */ +package com.example.quickdotai + +import android.util.Log +import org.json.JSONArray + +enum class RuntimeKind { NATIVE, LITERT } +enum class Capability { STREAMING, MESSAGES_API, MULTIMODAL, TOOL_USE, EMBEDDING, MULTI_IMAGE } + +data class ModelDescriptor( + val id: String, + val family: String, + val displayName: String, + val runtime: RuntimeKind, + val backends: Set, + val capabilities: Set, +) + +/** String constants for public model ids (migration convenience). */ +object ModelIds { + const val QWEN3_0_6B = "qwen3-0.6b" + const val QWEN3_1_7B_Q40 = "qwen3-1.7b-q40" + const val TINY_BERT = "tiny-bert" + const val FUNCTION_GEMMA = "function-gemma" + const val GEMMA4 = "gemma4" // LiteRT only + const val GEMMA4_CPU = "gemma4-cpu" + const val GEMMA4_E2B_QNN = "gemma4-e2b-qnn" + const val VJEPA_QNN = "vjepa-qnn" // V-JEPA multi-image (QNN) +} + +object ModelCatalog { + private const val TAG = "ModelCatalog" + + // LiteRT-only descriptor (not registered in C). + private val liteRtDescriptors = listOf( + ModelDescriptor( + id = ModelIds.GEMMA4, + family = "gemma4", + displayName = "Gemma4 (LiteRT)", + runtime = RuntimeKind.LITERT, + backends = setOf(BackendType.GPU), + capabilities = setOf(Capability.MULTIMODAL, Capability.MESSAGES_API, Capability.STREAMING), + ) + ) + + // Fallback when native query fails (public NATIVE subset). + private val nativeFallback = listOf( + ModelDescriptor(ModelIds.QWEN3_0_6B, "qwen3-0.6b", "Qwen3 0.6B", + RuntimeKind.NATIVE, setOf(BackendType.CPU, BackendType.GPU), + setOf(Capability.STREAMING, Capability.TOOL_USE)), + ModelDescriptor(ModelIds.GEMMA4_CPU, "gemma4", "Gemma4 (CPU)", + RuntimeKind.NATIVE, setOf(BackendType.CPU), setOf(Capability.STREAMING)), + ) + + private val catalog: List by lazy(LazyThreadSafetyMode.SYNCHRONIZED) { build() } + + fun all(): List = catalog + + private fun build(): List { + val native = if (NativeCausalLm.ensureLoaded()) { + try { + parse(NativeCausalLm.nativeQueryCatalog()) + } catch (t: Throwable) { + Log.e(TAG, "nativeQueryCatalog failed; using fallback", t) + nativeFallback + } + } else { + Log.w(TAG, "native library not loaded; using fallback catalog") + nativeFallback + } + return native + liteRtDescriptors + } + + private fun parse(json: String): List { + val arr = JSONArray(json) + return (0 until arr.length()).map { i -> + val o = arr.getJSONObject(i) + ModelDescriptor( + id = o.getString("id"), + family = o.getString("family"), + displayName = o.optString("display_name", o.getString("id")), + runtime = if (o.getInt("runtime") == 1) RuntimeKind.LITERT else RuntimeKind.NATIVE, + backends = decodeBackends(o.getInt("backend_mask")), + capabilities = decodeCaps(o.getInt("capabilities")), + ) + } + } + + private fun decodeBackends(mask: Int): Set = + BackendType.values().filter { (mask shr it.ordinal) and 1 == 1 }.toSet() + + private fun decodeCaps(bits: Int): Set = buildSet { + if (bits and 0b000001 != 0) add(Capability.STREAMING) + if (bits and 0b000010 != 0) add(Capability.MESSAGES_API) + if (bits and 0b000100 != 0) add(Capability.MULTIMODAL) + if (bits and 0b001000 != 0) add(Capability.TOOL_USE) + if (bits and 0b010000 != 0) add(Capability.EMBEDDING) + if (bits and 0b100000 != 0) add(Capability.MULTI_IMAGE) + } + + fun byId(id: String): ModelDescriptor? = all().firstOrNull { it.id == id } + fun families(): List = all().map { it.family }.distinct() + fun runtimesFor(family: String): Set = + all().filter { it.family == family }.map { it.runtime }.toSet() + fun backendsFor(family: String, rt: RuntimeKind): Set = + all().filter { it.family == family && it.runtime == rt } + .flatMap { it.backends }.toSet() + fun resolve(family: String, rt: RuntimeKind, backend: BackendType): ModelDescriptor? = + all().firstOrNull { it.family == family && it.runtime == rt && backend in it.backends } + + /** μ‚¬μš©μž 선택(생성/μ‹€ν–‰) κ°€λŠ₯ capability. ν•˜λ‚˜λΌλ„ 있으면 피컀에 λ…ΈμΆœ. */ + private val SELECTABLE_CAPS = setOf( + Capability.STREAMING, Capability.MESSAGES_API, + Capability.MULTIMODAL, Capability.TOOL_USE + ) + + /** 생성/μ‹€ν–‰ κ°€λŠ₯ν•œ λͺ¨λΈμΈμ§€(EMBEDDING μ „μš© λͺ¨λΈμ€ false). */ + fun isSelectable(d: ModelDescriptor): Boolean = + d.capabilities.any { it in SELECTABLE_CAPS } + + /** 피컀에 λ…ΈμΆœν•  λͺ¨λΈλ§Œ. all()은 전체λ₯Ό κ·ΈλŒ€λ‘œ μœ μ§€. */ + fun selectable(): List = all().filter { isSelectable(it) } + + /** selectable()μ—μ„œ νŒŒμƒν•œ family λͺ©λ‘(families()와 동일 distinct κ·œμΉ™). */ + fun selectableFamilies(): List = selectable().map { it.family }.distinct() +} diff --git a/Android/QuickDotAI/src/main/java/com/example/quickdotai/NativeCausalLm.kt b/Android/QuickDotAI/src/main/java/com/example/quickdotai/NativeCausalLm.kt new file mode 100644 index 00000000..606e69cd --- /dev/null +++ b/Android/QuickDotAI/src/main/java/com/example/quickdotai/NativeCausalLm.kt @@ -0,0 +1,402 @@ +// SPDX-License-Identifier: Apache-2.0 +/* + * Copyright (C) 2026 Samsung Electronics Co., Ltd. All Rights Reserved. + * + * @file NativeCausalLm.kt + * @brief JNI bindings for libcausallm_api.so (handle-based API only). + * + * All methods here are 1:1 with the handle-based entry points added to + * quick_dot_ai_api.h. Higher-level lifecycle (serialization, registry, + * threading) lives in [NativeQuickDotAI] and in the host app β€” this + * file is only the JNI glue. + */ +package com.example.quickdotai + +/** + * @brief Low-level JNI bridge to libcausallm_api.so. + * + * Loaded libraries (all bundled into the QuickDotAI AAR under + * jniLibs/arm64-v8a/): + * - libquickai_jni.so (JNI shim produced by src/main/cpp) + * - libcausallm_api.so (the C API lib built from Applications/CausalLM) + * - libcausallm_core.so (transitive) + * - libnntrainer.so (transitive) + * - libccapi-nntrainer.so (transitive) + * + * Any non-zero `errorCode` value corresponds to `ErrorCode` in + * quick_dot_ai_api.h β€” see [QuickAiError.fromNativeCode] for the Kotlin + * mapping. + * + * @hide + * + * Implementation detail: this object is `public` rather than `internal` + * because Kotlin's `internal`-visibility name mangling (`$modulename` + * suffix) would interfere with JNI symbol resolution β€” the JNI entry + * points in quickai_jni.cpp use the unmangled `Java_com_example_quickdotai_ + * NativeCausalLm_` names. Treat it as implementation detail and + * always go through [NativeQuickDotAI]. + */ +object NativeCausalLm { + + @Volatile + private var loaded: Boolean = false + + /** + * @brief Must be called once before any other method. Swallows + * UnsatisfiedLinkError so callers can still return a clean + * MODEL_LOAD_FAILED error to their own clients when the native lib + * is missing (e.g. during emulator development without the + * prebuilt .so files). + */ + @Synchronized + fun ensureLoaded(): Boolean { + if (loaded) return true + return try { + // qnn_context must be loaded before quickai_jni + System.loadLibrary("qnn_context") + // quickai_jni dlopens libcausallm_api.so as part of its JNI_OnLoad. + System.loadLibrary("quickai_jni") + loaded = true + true + } catch (t: UnsatisfiedLinkError) { + android.util.Log.e(TAG, "Failed to load libquickai_jni.so: ${t.message}") + false + } + } + + /** + * @brief Result of a loadModel call. [handle] is an opaque pointer + * (packed in a long) that must be passed back to [runModelHandleNative], + * [getPerformanceMetricsHandleNative] and [destroyModelHandleNative]. + */ + data class LoadResult(val errorCode: Int, val handle: Long) + + /** + * @brief Result of a runModel call. + */ + data class RunResult(val errorCode: Int, val output: String?) + + /** + * @brief Result of a metrics call. + */ + data class MetricsResult( + val errorCode: Int, + val prefillTokens: Int, + val prefillDurationMs: Double, + val generationTokens: Int, + val generationDurationMs: Double, + val totalDurationMs: Double, + val initializationDurationMs: Double, + val peakMemoryKb: Long + ) + + + /** + * @brief Result of a multimodal run call. + */ + data class MultimodalRunResult(val errorCode: Int, val output: String?) + + /** + * @brief Multimodal input data for vision encoder. + * + * Supports both single-image (legacy) and multi-image (e.g. V-JEPA + * video frames) scenarios. + * + * @param pixelValues Preprocessed image patches in CHW format. + * Shape: [numPatches * 3 * 512 * 512] (patch size is fixed at 512) + * For multi-image, all images' patches are concatenated. + * @param numPatches Total number of image patches across all images + * @param originalHeight Original image height before preprocessing (first image) + * @param originalWidth Original image width before preprocessing (first image) + * @param numImages Number of images (e.g. 16 for V-JEPA video frames). + * Defaults to 1 for backward compatibility. + * @param patchesPerImage Number of patches per image. Null for single-image. + * @param originalHeights Original height of each image. Null for single-image. + * @param originalWidths Original width of each image. Null for single-image. + */ + data class MultimodalInput( + val pixelValues: FloatArray, + val numPatches: Int, + val originalHeight: Int, + val originalWidth: Int, + val numImages: Int = 1, + val patchesPerImage: IntArray? = null, + val originalHeights: IntArray? = null, + val originalWidths: IntArray? = null + ) + + + /** Forwards to `setOptions` in quick_dot_ai_api.h. */ + external fun setOptionsNative( + useChatTemplate: Boolean, + debugMode: Boolean, + verbose: Boolean + ): Int + + /** + * @brief Thin wrapper around POSIX `chdir(2)`. + * + * The native C API in quick_dot_ai_api.cpp builds its model paths as + * `./models/-` (see `resolve_model_path`), so the + * loader's behaviour depends on the process's current working + * directory. Android apps launch with cwd="/" which is not writable, + * so the host code must chdir the process to an app-owned directory + * (typically `Context.getExternalFilesDir(null)`) before calling + * [loadModelHandleNative]. [NativeQuickDotAI] does this + * automatically when the caller supplies a [LoadModelRequest.modelPath]. + * + * @return 0 on success, or the POSIX errno value on failure. + */ + external fun chdirNative(path: String): Int + + /** + * @brief Forwards to `loadModelHandle` in quick_dot_ai_api.h. + * + * @param nativeLibDir Native library directory path from + * ApplicationInfo.nativeLibraryDir. May be null. + * @param modelBasePath Base directory for model files + * (e.g. "/sdcard/Download/aistudio-mobile/models/"). + * May be null (uses C API default). + */ + external fun loadModelHandleNative( + backendOrdinal: Int, + modelOrdinal: Int, + quantOrdinal: Int, + nativeLibDir: String?, + modelBasePath: String?, + htpBackendConfigPath: String? + ): LoadResult + + /** + * @brief Loads model by string catalog id (T4 path). + * @return Handle as Long, or 0 on failure. + */ + external fun loadModelHandleByNameNative( + backend: Int, + modelId: String, + quant: Int, + nativeLibDir: String?, + modelBasePath: String?, + ): Long + + /** @brief Returns the registered model catalog as a JSON array string. */ + external fun nativeQueryCatalog(): String + + /** + * Encode [text] into a sentence-embedding vector using an embedding handle + * (models[0] must be a SentenceTransformer, e.g. "ouro"). + * + * @return the embedding FloatArray on success, or null on any native error + * (unsupported model, not initialized, inference failure). + */ + external fun encodeModelHandleNative(handle: Long, text: String): FloatArray? + + /** + * @brief Listener invoked by the JNI trampoline once per decoded + * delta during [runModelHandleStreamingNative]. + * + * The method is called **on the same thread that invoked + * runModelHandleStreamingNative** β€” the JNI bridge does NOT attach + * any new thread to the JVM β€” so implementations must be + * non-blocking (deltas arrive back-to-back at decode speed). + */ + fun interface NativeStreamListener { + fun onDelta(text: String) + } + + /** + * @brief Forwards to `runModelHandleStreaming` in quick_dot_ai_api.h. + * + * Blocking: returns only when generation finishes, EOS is emitted, + * NUM_TO_GENERATE is reached, the listener throws, or an error + * occurs. [listener] is invoked synchronously from the same thread + * for every decoded delta; if it throws, the JNI bridge catches + * the exception, asks the native runner to cancel at the next + * token boundary, and propagates a non-zero ErrorCode back here. + * Terminal events (onDone / onError) are synthesized on the Kotlin + * side from the return value β€” see [NativeQuickDotAI.runStreaming]. + * + * @return An `ErrorCode` int; 0 on clean completion. + */ + external fun runModelHandleStreamingNative( + handle: Long, + prompt: String, + listener: NativeStreamListener + ): Int + + /** + * @brief Forwards to `runModelHandleWithMessagesStreaming` in quick_dot_ai_api.h. + * + * Streaming inference with OpenAI message format on a specific handle. + * + * @param handle Handle returned by loadModelHandleNative + * @param messages Array of chat messages + * @param addGenerationPrompt Whether to append generation prompt at end + * @param listener Callback for streaming output + * @return An `ErrorCode` int; 0 on clean completion. + */ + external fun runModelHandleWithMessagesStreamingNative( + handle: Long, + messages: Array< QuickAiChatMessage>, + addGenerationPrompt: Boolean, + listener: NativeStreamListener + ): Int + + /** Forwards to `getPerformanceMetricsHandle` in quick_dot_ai_api.h. */ + external fun getPerformanceMetricsHandleNative(handle: Long): MetricsResult + + /** Forwards to `unloadModelHandle` in quick_dot_ai_api.h. */ + external fun unloadModelHandleNative(handle: Long): Int + + /** Forwards to `destroyModelHandle` in quick_dot_ai_api.h. */ + external fun destroyModelHandleNative(handle: Long): Int + + /** + * @brief Forwards to `cancelModelHandle` in quick_dot_ai_api.h. + * + * Requests cancellation of an in-progress streaming run. Thread-safe: + * can be called from any thread (e.g., UI cancel button handler). + * + * @param handle Handle returned by loadModelHandleNative + * @return An `ErrorCode` int; 0 on success. + */ + external fun cancelModelHandleNative(handle: Long): Int + + /** + * @brief Forwards to `runMultimodalHandleStreaming` in quick_dot_ai_api.h. + * + * Multimodal streaming inference that accepts preprocessed image patches + * and a text prompt. The pixel values are passed as a FloatArray and + * converted to native float* in JNI layer. + * + * @param handle Handle returned by loadModelHandleNative + * @param prompt Text prompt + * @param pixelValues Preprocessed image patches (CHW format) + * @param numPatches Number of image patches + * @param originalHeight Original image height before preprocessing + * @param originalWidth Original image width before preprocessing + * @param listener Callback for streaming output + * @return An `ErrorCode` int; 0 on clean completion. + */ + external fun runMultimodalHandleStreamingNative( + handle: Long, + prompt: String, + pixelValues: FloatArray, + numPatches: Int, + originalHeight: Int, + originalWidth: Int, + listener: NativeStreamListener + ): Int + + /** + * @brief Forwards to `runMultimodalHandleWithMessagesStreaming` in quick_dot_ai_api.h. + * + * Streaming multimodal inference with OpenAI message format on a specific handle. + * + * @param handle Handle returned by loadModelHandleNative + * @param messages Array of chat messages (text-only, image via pixelValues) + * @param addGenerationPrompt Whether to append generation prompt at end + * @param pixelValues Preprocessed image patches (CHW format) + * @param numPatches Number of image patches + * @param originalHeight Original image height before preprocessing + * @param originalWidth Original image width before preprocessing + * @param listener Callback for streaming output + * @return An `ErrorCode` int; 0 on clean completion. + */ + external fun runMultimodalHandleWithMessagesStreamingNative( + handle: Long, + messages: Array, + addGenerationPrompt: Boolean, + pixelValues: FloatArray, + numPatches: Int, + originalHeight: Int, + originalWidth: Int, + listener: NativeStreamListener + ): Int + + /** + * @brief Forwards to `runModelHandleWithJsonStreaming` in quick_dot_ai_api.h. + * + * Streaming inference with OpenAI JSON format on a specific handle. + * Accepts a JSON string containing messages, tools, functions, etc. + * + * Example JSON input: + * ``` + * { + * "messages": [ + * {"role": "developer", "content": "..."}, + * {"role": "user", "content": "..."} + * ], + * "tools": [ + * {"type": "function", "function": {"name": "call", "description": "..."}} + * ] + * } + * ``` + * + * @param handle Handle returned by loadModelHandleNative + * @param jsonRequest OpenAI format JSON string + * @param listener Callback for streaming output + * @return An `ErrorCode` int; 0 on clean completion. + */ + external fun runModelHandleWithJsonStreamingNative( + handle: Long, + jsonRequest: String, + listener: NativeStreamListener + ): Int + + /** + * @brief Multimodal streaming inference with multi-image support (V-JEPA). + * + * @param handle Handle returned by loadModelHandleNative + * @param prompt Text prompt + * @param pixelValues Preprocessed image patches (CHW format, all images concatenated) + * @param numPatches Total number of image patches + * @param numImages Number of images (e.g. 16 for V-JEPA) + * @param patchesPerImage Number of patches per image + * @param originalHeights Original height of each image + * @param originalWidths Original width of each image + * @param listener Callback for streaming output + * @return An `ErrorCode` int; 0 on clean completion. + */ + external fun runMultimodalMultiImageStreamingNative( + handle: Long, + prompt: String, + pixelValues: FloatArray, + numPatches: Int, + numImages: Int, + patchesPerImage: IntArray, + originalHeights: IntArray, + originalWidths: IntArray, + listener: NativeStreamListener + ): Int + + /** + * @brief Multimodal streaming inference with multi-image + messages (V-JEPA). + * + * @param handle Handle returned by loadModelHandleNative + * @param messages Array of chat messages + * @param addGenerationPrompt Whether to append generation prompt at end + * @param pixelValues Preprocessed image patches (CHW format, all images concatenated) + * @param numPatches Total number of image patches + * @param numImages Number of images (e.g. 16 for V-JEPA) + * @param patchesPerImage Number of patches per image + * @param originalHeights Original height of each image + * @param originalWidths Original width of each image + * @param listener Callback for streaming output + * @return An `ErrorCode` int; 0 on clean completion. + */ + external fun runMultimodalMultiImageWithMessagesStreamingNative( + handle: Long, + messages: Array, + addGenerationPrompt: Boolean, + pixelValues: FloatArray, + numPatches: Int, + numImages: Int, + patchesPerImage: IntArray, + originalHeights: IntArray, + originalWidths: IntArray, + listener: NativeStreamListener + ): Int + + private const val TAG = "NativeCausalLm" +} diff --git a/Android/QuickDotAI/src/main/java/com/example/quickdotai/NativeChatSession.kt b/Android/QuickDotAI/src/main/java/com/example/quickdotai/NativeChatSession.kt new file mode 100644 index 00000000..c19a9e7a --- /dev/null +++ b/Android/QuickDotAI/src/main/java/com/example/quickdotai/NativeChatSession.kt @@ -0,0 +1,254 @@ +// SPDX-License-Identifier: Apache-2.0 +/* + * Copyright (C) 2026 Samsung Electronics Co., Ltd. All Rights Reserved. + * + * @file NativeChatSession.kt + * @brief Chat session helper for the native causal_lm backend. + * + * The native engine now manages its own KV cache and chat template, so + * this wrapper no longer tracks conversation history in Kotlin. + * It simply validates input, extracts the trailing USER turn, and + * forwards it to the native handle. The native engine is responsible + * for prompt formatting and state retention across turns. + * + * Lifecycle: + * openChatSession() -> run()/runStreaming() -> closeChatSession() + * + * Rebuild semantics (option D): + * rebuild() is a no-op at the Kotlin layer. The native engine handles + * KV cache resets internally. The system instruction supplied at + * session creation remains in effect. + */ +package com.example.quickdotai + +import android.util.Log +import java.util.UUID +import java.util.concurrent.atomic.AtomicBoolean + +internal class NativeChatSession( + private val handleProvider: () -> Long, + private val config: QuickAiChatSessionConfig? = null, + val sessionId: String = UUID.randomUUID().toString() +) { + + private val cancelRequested = AtomicBoolean(false) + + @Volatile + private var closed = false + + private var lastRunDurationMs: Double = 0.0 + + init { + config?.systemInstruction?.takeIf { it.isNotBlank() }?.let { sys -> + Log.i(TAG, "NativeChatSession($sessionId): system instruction configured (${sys.length} chars)") + } + } + + fun runStreaming( + text: String, + sink: StreamSink + ): BackendResult { + if (closed) { + val err = errClosed() + sink.onError(err.error, err.message) + return err + } + + // Convert raw text to message format for C++ chat template + val messages = listOf( + QuickAiChatMessage( + role = QuickAiChatRole.USER, + parts = listOf(PromptPart.Text(text)) + ) + ) + + cancelRequested.set(false) + + val handle = handleProvider() + if (handle == 0L) { + val err = BackendResult.Err(QuickAiError.NOT_INITIALIZED, "Native handle is not available") + sink.onError(err.error, err.message) + return err + } + + val accumulated = StringBuilder() + val startNs = System.nanoTime() + + return try { + val errorCode = NativeCausalLm.runModelHandleWithMessagesStreamingNative( + handle, + messages.toTypedArray(), + true, + object : NativeCausalLm.NativeStreamListener { + override fun onDelta(text: String) { + if (cancelRequested.get()) return + accumulated.append(text) + sink.onDelta(text) + } + } + ) + + lastRunDurationMs = (System.nanoTime() - startNs) / 1_000_000.0 + + if (errorCode != 0) { + val err = QuickAiError.fromNativeCode(errorCode) + Log.e(TAG, "runStreaming($sessionId): failed with errorCode=$errorCode") + sink.onError(err, "Inference failed (errorCode=$errorCode)") + BackendResult.Err(err, "Inference failed (errorCode=$errorCode)") + } else { + val output = accumulated.toString() + Log.i(TAG, "runStreaming($sessionId): completed in ${lastRunDurationMs.toLong()} ms") + sink.onDone() + BackendResult.Ok( + QuickAiChatResult( + content = output, + metrics = PerformanceMetrics(totalDurationMs = lastRunDurationMs) + ) + ) + } + } catch (t: Throwable) { + Log.e(TAG, "runStreaming($sessionId): threw exception", t) + sink.onError(QuickAiError.INFERENCE_FAILED, t.message) + BackendResult.Err(QuickAiError.INFERENCE_FAILED, t.message) + } + } + + fun runMultimodalStreaming( + parts: List, + sink: StreamSink + ): BackendResult { + if (closed) { + val err = errClosed() + sink.onError(err.error, err.message) + return err + } + + cancelRequested.set(false) + + val handle = handleProvider() + if (handle == 0L) { + val err = BackendResult.Err(QuickAiError.NOT_INITIALIZED, "Native handle is not available") + sink.onError(err.error, err.message) + return err + } + + // Extract text and image from parts + val text = parts.filterIsInstance().joinToString(" ") { it.text } + val imageBytes = parts.filterIsInstance().firstOrNull()?.bytes + + if (imageBytes == null) { + val err = BackendResult.Err(QuickAiError.INVALID_PARAMETER, "No image found in parts") + sink.onError(err.error, err.message) + return err + } + + val accumulated = StringBuilder() + val startNs = System.nanoTime() + + return try { + // For now, use simple text streaming - image processing will be added + val errorCode = NativeCausalLm.runModelHandleStreamingNative( + handle, + text + ) { delta -> + if (cancelRequested.get()) return@runModelHandleStreamingNative + accumulated.append(delta) + sink.onDelta(delta) + } + + lastRunDurationMs = (System.nanoTime() - startNs) / 1_000_000.0 + + if (errorCode != 0) { + val err = QuickAiError.fromNativeCode(errorCode) + Log.e(TAG, "runMultimodalStreaming($sessionId): failed with errorCode=$errorCode") + sink.onError(err, "Inference failed (errorCode=$errorCode)") + BackendResult.Err(err, "Inference failed (errorCode=$errorCode)") + } else { + val output = accumulated.toString() + Log.i(TAG, "runMultimodalStreaming($sessionId): completed in ${lastRunDurationMs.toLong()} ms") + sink.onDone() + BackendResult.Ok( + QuickAiChatResult( + content = output, + metrics = PerformanceMetrics(totalDurationMs = lastRunDurationMs) + ) + ) + } + } catch (t: Throwable) { + Log.e(TAG, "runMultimodalStreaming($sessionId): threw exception", t) + sink.onError(QuickAiError.INFERENCE_FAILED, t.message) + BackendResult.Err(QuickAiError.INFERENCE_FAILED, t.message) + } + } + + fun cancel() { + if (closed) return + cancelRequested.set(true) + val handle = handleProvider() + if (handle != 0L) { + Log.i(TAG, "cancel($sessionId): requesting stop for handle=0x${handle.toString(16)}") + NativeCausalLm.cancelModelHandleNative(handle) + } else { + Log.w(TAG, "cancel($sessionId): no valid handle to cancel") + } + } + + fun rebuild( + messages: List + ): BackendResult { + if (closed) return errClosed() + + Log.i(TAG, "rebuild($sessionId): no-op at Kotlin layer β€” native engine manages KV cache") + + // The native engine owns the KV cache. There is no local history to clear. + // System instruction remains in [config] and is still in effect. + // Callers who need a hard reset can close/open a new session instead. + + return BackendResult.Ok(Unit) + } + + fun close() { + if (closed) return + closed = true + Log.i(TAG, "close($sessionId): session closed") + } + + private data class TurnPrep(val lastUser: QuickAiChatMessage) + + @Volatile + private var lastPrepError: BackendResult.Err? = null + + private fun prepareTurn(messages: List): TurnPrep? { + lastPrepError = null + + if (messages.isEmpty()) { + lastPrepError = BackendResult.Err( + QuickAiError.INVALID_PARAMETER, + "messages list is empty" + ) + return null + } + + if (messages.last().role != QuickAiChatRole.USER) { + lastPrepError = BackendResult.Err( + QuickAiError.INVALID_PARAMETER, + "last message must have role USER to trigger inference (got ${messages.last().role})" + ) + return null + } + + return TurnPrep(lastUser = messages.last()) + } + + private fun extractText(msg: QuickAiChatMessage): String = + msg.parts.filterIsInstance().joinToString("") { it.text } + + private fun errClosed(): BackendResult.Err = BackendResult.Err( + QuickAiError.NOT_INITIALIZED, + "Chat session $sessionId is closed" + ) + + companion object { + private const val TAG = "NativeChatSession" + } +} diff --git a/Android/QuickDotAI/src/main/java/com/example/quickdotai/NativeQuickDotAI.kt b/Android/QuickDotAI/src/main/java/com/example/quickdotai/NativeQuickDotAI.kt new file mode 100644 index 00000000..ec5ff4f1 --- /dev/null +++ b/Android/QuickDotAI/src/main/java/com/example/quickdotai/NativeQuickDotAI.kt @@ -0,0 +1,906 @@ +// SPDX-License-Identifier: Apache-2.0 +/* + * Copyright (C) 2026 Samsung Electronics Co., Ltd. All Rights Reserved. + * + * @file NativeQuickDotAI.kt + * @brief QuickDotAI implementation backed by the handle-based + * quick_dot_ai_api.h (routed through libquickai_jni.so β†’ JNI β†’ + * libcausallm_api.so). + */ +package com.example.quickdotai + +import android.content.Context +import android.graphics.BitmapFactory +import android.util.Log +import java.io.File + +/** + * @brief Kotlin wrapper around a single `CausalLmHandle` in native code. + * + * Non-thread-safe by design β€” the host app must drive a single instance + * from a single worker thread. + * + * @param appContext Application context required for multimodal image processing. + * Must be non-null to enable runMultimodal/runMultimodalStreaming. + */ +class NativeQuickDotAI( + private val appContext: Context +) : QuickDotAI { + + override val kind: String = "native" + + override var architecture: String? = null + private set + + private var handle: Long = 0L + private var loaded: Boolean = false + + // Image processor for multimodal inference + private var imageProcessor: LlavaNextImageProcessor? = null + + // Vision backend type (null = text-only mode) + private var visionBackend: BackendType? = null + + // Currently loaded model ID β€” used to route multi-image vs single-image paths + private var currentModelId: String? = null + + override fun load(req: LoadModelRequest): BackendResult { + Log.i( + TAG, + "load() entered: modelId=${req.modelId} backend=${req.backend} " + + "quant=${req.quantization}" + ) + if (loaded) { + Log.i(TAG, "load(): already loaded, returning Ok") + return BackendResult.Ok(Unit) + } + + if (!req.htpBackendConfigPath.isNullOrBlank()) { + Log.w(TAG, "load(): htpBackendConfigPath='${req.htpBackendConfigPath}' " + + "is not forwarded by the byName load path; " + + "C layer will derive HTP config from modelBasePath.") + } + + if (!NativeCausalLm.ensureLoaded()) { + Log.e(TAG, "load(): native libs unavailable on this device") + return BackendResult.Err( + QuickAiError.MODEL_LOAD_FAILED, + "libquickai_jni.so / libcausallm_api.so not available on this device" + ) + } + + // modelBasePath is passed directly from the caller. The C API uses + // this as the base directory for resolving model directories + // (e.g. "/qwen3-0.6b"). + val modelBasePath = req.modelBasePath + if (modelBasePath == null || modelBasePath.isBlank()) { + Log.w( + TAG, + "load(): modelBasePath is null/blank β€” C API will use its default " + + "fallback path. Specify modelBasePath for shared model access." + ) + } else { + Log.i(TAG, "load(): modelBasePath=$modelBasePath") + } + + return try { + Log.i(TAG, "load(): calling loadModelHandleByNameNative(backend=${req.backend.ordinal}, " + + "modelId=${req.modelId}, quant=${req.quantization.ordinal}, " + + "nativeLibDir=${req.nativeLibDir}, modelBasePath=$modelBasePath)") + val h = NativeCausalLm.loadModelHandleByNameNative( + backend = mapBackend(req.backend), + modelId = req.modelId, + quant = mapQuant(req.quantization), + nativeLibDir = req.nativeLibDir, + modelBasePath = modelBasePath, + ) + if (h == 0L) { + Log.e(TAG, "load(): loadModelHandleByNameNative returned 0 for '${req.modelId}'") + BackendResult.Err( + QuickAiError.MODEL_LOAD_FAILED, + "loadModelHandleByName failed for '${req.modelId}'" + ) + } else { + handle = h + loaded = true + architecture = req.modelId + currentModelId = req.modelId + visionBackend = req.visionBackend + if (req.visionBackend != null) { + imageProcessor = LlavaNextImageProcessor(appContext) + Log.i(TAG, "load(): visionBackend=${req.visionBackend}, image processor initialized") + } + Log.i(TAG, "load(): SUCCESS, handle=0x${h.toString(16)}") + BackendResult.Ok(Unit) + } + } catch (t: Throwable) { + Log.e(TAG, "load(): loadModelHandleByNameNative threw", t) + BackendResult.Err(QuickAiError.MODEL_LOAD_FAILED, t.message) + } + } + + override fun metrics(): BackendResult { + if (!loaded || handle == 0L) { + return BackendResult.Err(QuickAiError.NOT_INITIALIZED) + } + return try { + val m = NativeCausalLm.getPerformanceMetricsHandleNative(handle) + if (m.errorCode != 0) { + BackendResult.Err(QuickAiError.fromNativeCode(m.errorCode)) + } else { + BackendResult.Ok( + PerformanceMetrics( + prefillTokens = m.prefillTokens, + prefillDurationMs = m.prefillDurationMs, + generationTokens = m.generationTokens, + generationDurationMs = m.generationDurationMs, + totalDurationMs = m.totalDurationMs, + initializationDurationMs = m.initializationDurationMs, + peakMemoryKb = m.peakMemoryKb + ) + ) + } + } catch (t: Throwable) { + Log.e(TAG, "getPerformanceMetricsHandleNative threw", t) + BackendResult.Err(QuickAiError.UNKNOWN, t.message) + } + } + + override fun encode(text: String): BackendResult { + if (handle == 0L) { + return BackendResult.Err( + QuickAiError.NOT_INITIALIZED, + "encode(): no model loaded" + ) + } + return try { + val vec = NativeCausalLm.encodeModelHandleNative(handle, text) + if (vec == null || vec.isEmpty()) { + BackendResult.Err( + QuickAiError.INFERENCE_FAILED, + "encode() failed for current model '$currentModelId'" + ) + } else { + BackendResult.Ok(vec) + } + } catch (t: Throwable) { + Log.e(TAG, "encode() threw", t) + BackendResult.Err(QuickAiError.INFERENCE_FAILED, t.message) + } + } + + override fun unload(): BackendResult { + // Cancel any in-flight inference before unloading + cancel() + activeSession?.close() + activeSession = null + + if (!loaded || handle == 0L) { + return BackendResult.Ok(Unit) + } + return try { + val ec = NativeCausalLm.unloadModelHandleNative(handle) + loaded = false + if (ec != 0) { + BackendResult.Err(QuickAiError.fromNativeCode(ec)) + } else { + BackendResult.Ok(Unit) + } + } catch (t: Throwable) { + Log.w(TAG, "unloadModelHandleNative threw", t) + BackendResult.Err(QuickAiError.UNKNOWN, t.message) + } + } + + // --- chat session (dummy) -------------------------------------------- + + private var activeSession: NativeChatSession? = null + + override val chatSessionId: String? + get() = activeSession?.sessionId + + override fun openChatSession( + config: QuickAiChatSessionConfig? + ): BackendResult { + if (!loaded || handle == 0L) { + return BackendResult.Err( + QuickAiError.NOT_INITIALIZED, + "NativeQuickDotAI has not been loaded yet" + ) + } + if (activeSession != null) { + return BackendResult.Err( + QuickAiError.BAD_REQUEST, + "A chat session is already active (${activeSession!!.sessionId}). " + + "Close it before opening a new one." + ) + } + val session = NativeChatSession( + handleProvider = { handle }, + config = config + ) + activeSession = session + Log.i(TAG, "openChatSession(): created session ${session.sessionId} with handle=0x${handle.toString(16)}") + return BackendResult.Ok(session.sessionId) + } + + override fun closeChatSession(): BackendResult { + val session = activeSession + if (session == null) { + return BackendResult.Err( + QuickAiError.BAD_REQUEST, + "No active chat session to close" + ) + } + session.close() + activeSession = null + Log.i(TAG, "closeChatSession(${session.sessionId}): closed") + return BackendResult.Ok(Unit) + } + + override fun runChatModelHandleStreaming( + text: String, + sink: StreamSink + ): BackendResult { + val session = activeSession + if (session == null) { + val err = BackendResult.Err( + QuickAiError.BAD_REQUEST, + "No active chat session β€” call openChatSession() first" + ) + sink.onError(err.error, err.message) + return err + } + return session.runStreaming(text, sink) + } + + override fun runChatMultimodalHandleStreaming( + parts: List, + sink: StreamSink + ): BackendResult { + if (activeSession == null) { + val err = BackendResult.Err( + QuickAiError.BAD_REQUEST, + "No active chat session β€” call openChatSession() first" + ) + sink.onError(err.error, err.message) + return err + } + val accumulated = StringBuilder() + val forwardingSink = object : StreamSink { + override fun onDelta(text: String) { + accumulated.append(text) + sink.onDelta(text) + } + + override fun onReasoningDelta(text: String) { + sink.onReasoningDelta(text) + } + + override fun onDone() { + sink.onDone() + } + + override fun onError(error: QuickAiError, message: String?) { + sink.onError(error, message) + } + } + val messages = listOf( + QuickAiChatMessage(role = QuickAiChatRole.USER, parts = parts) + ) + return when (val r = runMultimodalHandleWithMessagesStreaming(messages, forwardingSink)) { + is BackendResult.Ok -> { + val metrics = when (val m = metrics()) { + is BackendResult.Ok -> m.value + is BackendResult.Err -> null + } + BackendResult.Ok( + QuickAiChatResult( + content = accumulated.toString(), + metrics = metrics + ) + ) + } + is BackendResult.Err -> BackendResult.Err(r.error, r.message) + } + } + + override fun cancel() { + Log.d(TAG, "cancel(): START, handle=0x${handle.toString(16)}") + if (handle != 0L) { + Log.d(TAG, "cancel(): calling NativeCausalLm.cancelModelHandleNative(handle=0x${handle.toString(16)})") + val result = NativeCausalLm.cancelModelHandleNative(handle) + Log.d(TAG, "cancel(): cancelModelHandleNative returned $result") + } else { + Log.w(TAG, "cancel(): no valid handle to cancel") + } + } + + override fun chatCancel() { + activeSession?.cancel() + } + + override fun chatRebuild( + messages: List + ): BackendResult { + val session = activeSession + ?: return BackendResult.Err( + QuickAiError.BAD_REQUEST, + "No active chat session β€” call openChatSession() first" + ) + return session.rebuild(messages) + } + + // --- OpenAI messages API (handle-based) -------------------------------- + + override fun runModelHandleWithMessagesStreaming( + messages: List, + sink: StreamSink + ): BackendResult { + if (!loaded || handle == 0L) { + val err = BackendResult.Err( + QuickAiError.NOT_INITIALIZED, + "NativeQuickDotAI has not been loaded yet" + ) + sink.onError(err.error, err.message) + return err + } + + return try { + val errorCode = NativeCausalLm.runModelHandleWithMessagesStreamingNative( + handle = handle, + messages = messages.toTypedArray(), + addGenerationPrompt = true, + listener = object : NativeCausalLm.NativeStreamListener { + override fun onDelta(text: String) { + sink.onDelta(text) + } + } + ) + if (errorCode != 0) { + val err = QuickAiError.fromNativeCode(errorCode) + sink.onError(err, "runModelHandleWithMessagesStreaming failed (errorCode=$errorCode)") + BackendResult.Err(err, "runModelHandleWithMessagesStreaming failed (errorCode=$errorCode)") + } else { + sink.onDone() + BackendResult.Ok(Unit) + } + } catch (t: Throwable) { + sink.onError(QuickAiError.INFERENCE_FAILED, t.message) + BackendResult.Err(QuickAiError.INFERENCE_FAILED, t.message) + } + } + + /** + * @brief Streaming inference with OpenAI JSON format. + * + * Accepts a JSON string in OpenAI format and processes it through the + * chat template. Supports messages, tools, functions, and all other + * fields recognized by minja chat template renderer. + * + * Example JSON input: + * ``` + * { + * "messages": [ + * {"role": "developer", "content": "..."}, + * {"role": "user", "content": "..."} + * ], + * "tools": [ + * {"type": "function", "function": {"name": "call", "description": "..."}} + * ] + * } + * ``` + * + * @param jsonRequest OpenAI format JSON string + * @param sink StreamSink for receiving streaming output + * @return BackendResult + */ + override fun runModelHandleWithJsonStreaming( + jsonRequest: String, + sink: StreamSink + ): BackendResult { + if (!loaded || handle == 0L) { + val err = BackendResult.Err( + QuickAiError.NOT_INITIALIZED, + "NativeQuickDotAI has not been loaded yet" + ) + sink.onError(err.error, err.message) + return err + } + + return try { + val errorCode = NativeCausalLm.runModelHandleWithJsonStreamingNative( + handle = handle, + jsonRequest = jsonRequest, + listener = object : NativeCausalLm.NativeStreamListener { + override fun onDelta(text: String) { + sink.onDelta(text) + } + } + ) + if (errorCode != 0) { + val err = QuickAiError.fromNativeCode(errorCode) + sink.onError(err, "runModelHandleWithJsonStreaming failed (errorCode=$errorCode)") + BackendResult.Err(err, "runModelHandleWithJsonStreaming failed (errorCode=$errorCode)") + } else { + sink.onDone() + BackendResult.Ok(Unit) + } + } catch (t: Throwable) { + sink.onError(QuickAiError.INFERENCE_FAILED, t.message) + BackendResult.Err(QuickAiError.INFERENCE_FAILED, t.message) + } + } + + /** + * @brief Streaming multimodal inference with OpenAI message format on a specific handle. + */ + override fun runMultimodalHandleWithMessagesStreaming( + messages: List, + sink: StreamSink + ): BackendResult { + if (!loaded || handle == 0L) { + val err = BackendResult.Err( + QuickAiError.NOT_INITIALIZED, + "NativeQuickDotAI has not been loaded yet" + ) + sink.onError(err.error, err.message) + return err + } + + val processor = imageProcessor + if (processor == null) { + val err = BackendResult.Err( + QuickAiError.UNSUPPORTED, + "Vision model not loaded" + ) + sink.onError(err.error, err.message) + return err + } + + // Extract image from messages + val allParts = messages.flatMap { it.parts } + val imageParts = allParts.filter { it is PromptPart.ImageBytes || it is PromptPart.ImageFile || it is PromptPart.PreprocessedPixels } + + if (imageParts.isEmpty()) { + val err = BackendResult.Err( + QuickAiError.INVALID_PARAMETER, + "No image found. Expected parts: [Text, ImageBytes]" + ) + sink.onError(err.error, err.message) + return err + } + + val multimodalInput = prepareMultimodalInput(allParts, processor) + if (multimodalInput == null) { + val err = BackendResult.Err( + QuickAiError.INVALID_PARAMETER, + "Image preprocessing failed" + ) + sink.onError(err.error, err.message) + return err + } + + return try { + val errorCode = if (multimodalInput.numImages > 1 && multimodalInput.patchesPerImage != null) { + // Multi-image path (V-JEPA) + Log.i(TAG, "runMultimodalHandleWithMessagesStreaming(): using multi-image path, numImages=${multimodalInput.numImages}") + NativeCausalLm.runMultimodalMultiImageWithMessagesStreamingNative( + handle = handle, + messages = messages.toTypedArray(), + addGenerationPrompt = true, + pixelValues = multimodalInput.pixelValues, + numPatches = multimodalInput.numPatches, + numImages = multimodalInput.numImages, + patchesPerImage = multimodalInput.patchesPerImage, + originalHeights = multimodalInput.originalHeights ?: IntArray(multimodalInput.numImages) { multimodalInput.originalHeight }, + originalWidths = multimodalInput.originalWidths ?: IntArray(multimodalInput.numImages) { multimodalInput.originalWidth }, + listener = object : NativeCausalLm.NativeStreamListener { + override fun onDelta(text: String) { + sink.onDelta(text) + } + } + ) + } else { + // Single-image path (legacy) + NativeCausalLm.runMultimodalHandleWithMessagesStreamingNative( + handle = handle, + messages = messages.toTypedArray(), + addGenerationPrompt = true, + pixelValues = multimodalInput.pixelValues, + numPatches = multimodalInput.numPatches, + originalHeight = multimodalInput.originalHeight, + originalWidth = multimodalInput.originalWidth, + listener = object : NativeCausalLm.NativeStreamListener { + override fun onDelta(text: String) { + sink.onDelta(text) + } + } + ) + } + if (errorCode != 0) { + val err = QuickAiError.fromNativeCode(errorCode) + sink.onError(err, "runMultimodalHandleWithMessagesStreaming failed (errorCode=$errorCode)") + BackendResult.Err(err, "runMultimodalHandleWithMessagesStreaming failed (errorCode=$errorCode)") + } else { + sink.onDone() + BackendResult.Ok(Unit) + } + } catch (t: Throwable) { + sink.onError(QuickAiError.INFERENCE_FAILED, t.message) + BackendResult.Err(QuickAiError.INFERENCE_FAILED, t.message) + } + } + + override fun close() { + activeSession?.close() + activeSession = null + if (handle != 0L) { + try { + NativeCausalLm.destroyModelHandleNative(handle) + } catch (t: Throwable) { + Log.w(TAG, "destroyModelHandleNative threw", t) + } + handle = 0L + } + loaded = false + } + + // --- multimodal ------------------------------------------------------- + + /** + * @brief Blocking multimodal inference. + * + * Preprocesses images from [parts], combines with text prompt, and + * runs inference through the native engine. + * + * @param parts List of PromptPart containing text and/or images + * @return BackendResult with generated text on success + */ + override fun runMultimodalHandle(parts: List): BackendResult { + if (!loaded || handle == 0L) { + return BackendResult.Err( + QuickAiError.NOT_INITIALIZED, + "NativeQuickDotAI has not been loaded yet" + ) + } + + val processor = imageProcessor + if (processor == null) { + return BackendResult.Err( + QuickAiError.UNSUPPORTED, + "Multimodal not enabled β€” reload with LoadModelRequest.visionBackend set" + ) + } + + // Extract image and text from parts + val multimodalInput = prepareMultimodalInput(parts, processor) + ?: return BackendResult.Err( + QuickAiError.INVALID_PARAMETER, + "No valid image found in parts" + ) + + val textPrompt = extractTextPrompt(parts) + + Log.i( + TAG, + "runMultimodal(): numPatches=${multimodalInput.numPatches}, " + + "originalSize=${multimodalInput.originalHeight}x${multimodalInput.originalWidth}, " + + "prompt length=${textPrompt.length}" + ) + + return try { + val accumulated = StringBuilder() + val errorCode = NativeCausalLm.runMultimodalHandleStreamingNative( + handle, + textPrompt, + multimodalInput.pixelValues, + multimodalInput.numPatches, + multimodalInput.originalHeight, + multimodalInput.originalWidth, + object : NativeCausalLm.NativeStreamListener { + override fun onDelta(text: String) { + accumulated.append(text) + } + } + ) + if (errorCode != 0) { + val err = QuickAiError.fromNativeCode(errorCode) + Log.e(TAG, "runMultimodal(): failed with errorCode=$errorCode") + BackendResult.Err(err, "runMultimodalHandle failed (errorCode=$errorCode)") + } else { + val output = accumulated.toString() + Log.i(TAG, "runMultimodal(): success, output length=${output.length}") + BackendResult.Ok(output) + } + } catch (t: Throwable) { + Log.e(TAG, "runMultimodal(): threw exception", t) + BackendResult.Err(QuickAiError.INFERENCE_FAILED, t.message) + } + } + + /** + * @brief Streaming multimodal inference. + * + * Preprocesses images from [parts], combines with text prompt, and + * runs streaming inference through the native engine. Deltas are + * forwarded to [sink] as they are generated. + * + * @param parts List of PromptPart containing text and/or images + * @param sink StreamSink to receive streaming output + * @return BackendResult on completion + */ + override fun runMultimodalHandleStreaming( + parts: List, + sink: StreamSink + ): BackendResult { + if (!loaded || handle == 0L) { + val err = BackendResult.Err( + QuickAiError.NOT_INITIALIZED, + "NativeQuickDotAI has not been loaded yet" + ) + sink.onError(err.error, err.message) + return err + } + + val processor = imageProcessor + if (processor == null) { + val err = BackendResult.Err( + QuickAiError.UNSUPPORTED, + "MultimodalStreaming not enabled β€” reload with LoadModelRequest.visionBackend set" + ) + sink.onError(err.error, err.message) + return err + } + + // Extract image and text from parts + val multimodalInput = prepareMultimodalInput(parts, processor) + if (multimodalInput == null) { + val err = BackendResult.Err( + QuickAiError.INVALID_PARAMETER, + "No valid image found in parts" + ) + sink.onError(err.error, err.message) + return err + } + + val textPrompt = extractTextPrompt(parts) + + Log.i( + TAG, + "runMultimodalStreaming(): numPatches=${multimodalInput.numPatches}, " + + "originalSize=${multimodalInput.originalHeight}x${multimodalInput.originalWidth}, " + + "prompt length=${textPrompt.length}" + ) + + return try { + val errorCode = if (multimodalInput.numImages > 1 && multimodalInput.patchesPerImage != null) { + // Multi-image path (V-JEPA) + Log.i(TAG, "runMultimodalStreaming(): using multi-image path, numImages=${multimodalInput.numImages}") + NativeCausalLm.runMultimodalMultiImageStreamingNative( + handle, + textPrompt, + multimodalInput.pixelValues, + multimodalInput.numPatches, + multimodalInput.numImages, + multimodalInput.patchesPerImage, + multimodalInput.originalHeights ?: IntArray(multimodalInput.numImages) { multimodalInput.originalHeight }, + multimodalInput.originalWidths ?: IntArray(multimodalInput.numImages) { multimodalInput.originalWidth }, + ) { delta -> + sink.onDelta(delta) + } + } else { + // Single-image path (legacy) + NativeCausalLm.runMultimodalHandleStreamingNative( + handle, + textPrompt, + multimodalInput.pixelValues, + multimodalInput.numPatches, + multimodalInput.originalHeight, + multimodalInput.originalWidth + ) { delta -> + sink.onDelta(delta) + } + } + + if (errorCode != 0) { + val err = QuickAiError.fromNativeCode(errorCode) + Log.e(TAG, "runMultimodalStreaming(): failed with errorCode=$errorCode") + sink.onError(err, "runMultimodalHandleStreaming failed (errorCode=$errorCode)") + BackendResult.Err(err, "runMultimodalHandleStreaming failed (errorCode=$errorCode)") + } else { + Log.i(TAG, "runMultimodalStreaming(): success") + sink.onDone() + BackendResult.Ok(Unit) + } + } catch (t: Throwable) { + Log.e(TAG, "runMultimodalStreaming(): threw exception", t) + sink.onError(QuickAiError.INFERENCE_FAILED, t.message) + BackendResult.Err(QuickAiError.INFERENCE_FAILED, t.message) + } + } + + /** + * @brief Prepare multimodal input from PromptPart list. + * + * Extracts images from parts and preprocesses them using + * LlavaNextImageProcessor. Supports both single-image and + * multi-image (V-JEPA) scenarios: + * - Single image: returns a MultimodalInput with numImages=1 (default) + * - Multiple ImageBytes: preprocesses each image, concatenates pixel + * values, and returns a multi-image MultimodalInput + * - PreprocessedPixels: passed through directly + * + * @return MultimodalInput with preprocessed pixel values, or null if no image found + */ + private fun prepareMultimodalInput( + parts: List, + processor: LlavaNextImageProcessor + ): NativeCausalLm.MultimodalInput? { + // Collect all image parts first + val imageParts = mutableListOf() + for (part in parts) { + when (part) { + is PromptPart.ImageFile -> imageParts.add(part) + is PromptPart.ImageBytes -> imageParts.add(part) + is PromptPart.PreprocessedPixels -> { + // PreprocessedPixels bypass the image processor entirely + return NativeCausalLm.MultimodalInput( + pixelValues = part.pixelValues, + numPatches = part.numPatches, + originalHeight = part.imageHeights.firstOrNull() ?: 0, + originalWidth = part.imageWidths.firstOrNull() ?: 0, + numImages = part.numImages, + patchesPerImage = part.patchesPerImage, + originalHeights = part.imageHeights, + originalWidths = part.imageWidths + ) + } + is PromptPart.Text -> { /* skip text parts */ } + } + } + + if (imageParts.isEmpty()) return null + + // Single image: use the original single-image path + if (imageParts.size == 1) { + return preprocessSingleImage(imageParts[0], processor) + } + + // Multiple images: preprocess each and concatenate + val allPixelValues = mutableListOf() + val patchesPerImageList = mutableListOf() + val heightsList = mutableListOf() + val widthsList = mutableListOf() + var totalPatches = 0 + val cropSize = processor.getCropSize() + val patchSize = cropSize * cropSize * 3 + + for (imgPart in imageParts) { + val bitmap = when (imgPart) { + is PromptPart.ImageFile -> { + val file = File(imgPart.absolutePath) + if (!file.exists() || !file.canRead()) { + Log.w(TAG, "Image file not readable: ${imgPart.absolutePath}") + continue + } + BitmapFactory.decodeFile(imgPart.absolutePath) + } + is PromptPart.ImageBytes -> { + if (imgPart.bytes.isEmpty()) { + Log.w(TAG, "Image bytes are empty") + continue + } + BitmapFactory.decodeByteArray(imgPart.bytes, 0, imgPart.bytes.size) + } + else -> null + } + if (bitmap == null) { + Log.w(TAG, "Failed to decode image in multi-image batch") + continue + } + val modelInput = processor.preprocess(bitmap) + val numPatches = modelInput.pixelValues.size / patchSize + allPixelValues.addAll(modelInput.pixelValues.toList()) + patchesPerImageList.add(numPatches) + heightsList.add(modelInput.originalSize.first) + widthsList.add(modelInput.originalSize.second) + totalPatches += numPatches + } + + if (allPixelValues.isEmpty()) return null + + val numImages = patchesPerImageList.size + Log.i(TAG, "prepareMultimodalInput(): multi-image mode, numImages=$numImages, " + + "totalPatches=$totalPatches, patchesPerImage=$patchesPerImageList") + + return NativeCausalLm.MultimodalInput( + pixelValues = allPixelValues.toFloatArray(), + numPatches = totalPatches, + originalHeight = heightsList.firstOrNull() ?: 0, + originalWidth = widthsList.firstOrNull() ?: 0, + numImages = numImages, + patchesPerImage = patchesPerImageList.toIntArray(), + originalHeights = heightsList.toIntArray(), + originalWidths = widthsList.toIntArray() + ) + } + + /** + * @brief Preprocess a single image part into a MultimodalInput. + */ + private fun preprocessSingleImage( + part: PromptPart, + processor: LlavaNextImageProcessor + ): NativeCausalLm.MultimodalInput? { + when (part) { + is PromptPart.ImageFile -> { + val file = File(part.absolutePath) + if (!file.exists() || !file.canRead()) { + Log.w(TAG, "Image file not readable: ${part.absolutePath}") + return null + } + val bitmap = BitmapFactory.decodeFile(part.absolutePath) + if (bitmap == null) { + Log.w(TAG, "Failed to decode image: ${part.absolutePath}") + return null + } + val modelInput = processor.preprocess(bitmap) + return NativeCausalLm.MultimodalInput( + pixelValues = modelInput.pixelValues, + numPatches = modelInput.pixelValues.size / (processor.getCropSize() * processor.getCropSize() * 3), + originalHeight = modelInput.originalSize.first, + originalWidth = modelInput.originalSize.second + ) + } + is PromptPart.ImageBytes -> { + if (part.bytes.isEmpty()) { + Log.w(TAG, "Image bytes are empty") + return null + } + val bitmap = BitmapFactory.decodeByteArray(part.bytes, 0, part.bytes.size) + if (bitmap == null) { + Log.w(TAG, "Failed to decode image from bytes") + return null + } + val modelInput = processor.preprocess(bitmap) + return NativeCausalLm.MultimodalInput( + pixelValues = modelInput.pixelValues, + numPatches = modelInput.pixelValues.size / (processor.getCropSize() * processor.getCropSize() * 3), + originalHeight = modelInput.originalSize.first, + originalWidth = modelInput.originalSize.second + ) + } + else -> return null + } + } + + /** + * @brief Extract text prompt from PromptPart list. + * + * Concatenates all Text parts into a single prompt string. + */ + private fun extractTextPrompt(parts: List): String { + return parts.filterIsInstance() + .joinToString(" ") { it.text } + .ifEmpty { "Describe this image." } + } + + private fun mapBackend(b: BackendType): Int = when (b) { + BackendType.CPU -> 0 + BackendType.GPU -> 1 + BackendType.NPU -> 2 + } + + private fun mapQuant(q: QuantizationType): Int = when (q) { + QuantizationType.UNKNOWN -> 0 + QuantizationType.W4A32 -> 1 + QuantizationType.W16A16 -> 2 + QuantizationType.W8A16 -> 3 + QuantizationType.W32A32 -> 4 + } + + companion object { + private const val TAG = "NativeQuickDotAI" + } +} diff --git a/Android/QuickDotAI/src/main/java/com/example/quickdotai/PilloBilinearResizer.kt b/Android/QuickDotAI/src/main/java/com/example/quickdotai/PilloBilinearResizer.kt new file mode 100644 index 00000000..faf18812 --- /dev/null +++ b/Android/QuickDotAI/src/main/java/com/example/quickdotai/PilloBilinearResizer.kt @@ -0,0 +1,185 @@ +package com.example.quickdotai + +import kotlin.math.abs +import kotlin.math.ceil + +class PillowBilinearResizer { + companion object { + + // Bilinear kernel + private fun bilinearKernel(x: Double): Double { + val absX = abs(x) + if (absX < 1.0) { + return 1.0 - absX + } + return 0.0 + } + + fun resize( + pixels: IntArray, + width: Int, + height: Int, + newWidth: Int, + newHeight: Int + ): IntArray { + // Pillow uses fixed point arithmetic with 22 bits of precision for weights + val precisionBits = 22 + val halfOne = 1 shl (precisionBits - 1) + + // 1. Horizontal Pass: (width, height) -> (newWidth, height) + val tempPixels = IntArray(newWidth * height) + val xScale = width.toDouble() / newWidth + val filterScaleX = if (xScale < 1.0) 1.0 else xScale + val supportX = 1.0 * filterScaleX // Support is 1.0 for Bilinear + val scaleFactorX = 1.0 / filterScaleX + + // Precompute weights for horizontal pass + val kSizeX = (ceil(supportX).toInt() * 2 + 1) + val boundsX = IntArray(newWidth * 2) + val kkX = IntArray(newWidth * kSizeX) + + for (x in 0 until newWidth) { + val center = (x + 0.5) * xScale + var xMin = (center - supportX + 0.5).toInt() + var xMax = (center + supportX + 0.5).toInt() + + if (xMin < 0) xMin = 0 + if (xMax > width) xMax = width + + val count = xMax - xMin + + boundsX[x * 2] = xMin + boundsX[x * 2 + 1] = count + + var ww = 0.0 + val weights = DoubleArray(count) + for (i in 0 until count) { + val srcX = xMin + i + val w = bilinearKernel((srcX + 0.5 - center) * scaleFactorX) + weights[i] = w + ww += w + } + + // Normalize and convert to fixed point + for (i in 0 until count) { + if (ww != 0.0) weights[i] /= ww + val fw = if (weights[i] < 0) { + (weights[i] * (1 shl precisionBits) - 0.5).toInt() + } else { + (weights[i] * (1 shl precisionBits) + 0.5).toInt() + } + kkX[x * kSizeX + i] = fw + } + } + + for (y in 0 until height) { + for (x in 0 until newWidth) { + val xMin = boundsX[x * 2] + val count = boundsX[x * 2 + 1] + + var r = halfOne + var g = halfOne + var b = halfOne + var a = halfOne + + for (i in 0 until count) { + val weight = kkX[x * kSizeX + i] + val srcX = xMin + i + val pixel = pixels[y * width + srcX] + + a += ((pixel shr 24) and 0xFF) * weight + r += ((pixel shr 16) and 0xFF) * weight + g += ((pixel shr 8) and 0xFF) * weight + b += (pixel and 0xFF) * weight + } + + val rInt = (r shr precisionBits).coerceIn(0, 255) + val gInt = (g shr precisionBits).coerceIn(0, 255) + val bInt = (b shr precisionBits).coerceIn(0, 255) + val aInt = (a shr precisionBits).coerceIn(0, 255) + + tempPixels[y * newWidth + x] = + (aInt shl 24) or (rInt shl 16) or (gInt shl 8) or bInt + } + } + + // 2. Vertical Pass: (newWidth, height) -> (newWidth, newHeight) + val finalPixels = IntArray(newWidth * newHeight) + val yScale = height.toDouble() / newHeight + val filterScaleY = if (yScale < 1.0) 1.0 else yScale + val supportY = 1.0 * filterScaleY // Support is 1.0 for Bilinear + val scaleFactorY = 1.0 / filterScaleY + + val kSizeY = (ceil(supportY).toInt() * 2 + 1) + val boundsY = IntArray(newHeight * 2) + val kkY = IntArray(newHeight * kSizeY) + + for (y in 0 until newHeight) { + val center = (y + 0.5) * yScale + var yMin = (center - supportY + 0.5).toInt() + var yMax = (center + supportY + 0.5).toInt() + + if (yMin < 0) yMin = 0 + if (yMax > height) yMax = height + + val count = yMax - yMin + + boundsY[y * 2] = yMin + boundsY[y * 2 + 1] = count + + var ww = 0.0 + val weights = DoubleArray(count) + for (i in 0 until count) { + val srcY = yMin + i + val w = bilinearKernel((srcY + 0.5 - center) * scaleFactorY) + weights[i] = w + ww += w + } + + for (i in 0 until count) { + if (ww != 0.0) weights[i] /= ww + val fw = if (weights[i] < 0) { + (weights[i] * (1 shl precisionBits) - 0.5).toInt() + } else { + (weights[i] * (1 shl precisionBits) + 0.5).toInt() + } + kkY[y * kSizeY + i] = fw + } + } + + for (x in 0 until newWidth) { + for (y in 0 until newHeight) { + val yMin = boundsY[y * 2] + val count = boundsY[y * 2 + 1] + + var r = halfOne + var g = halfOne + var b = halfOne + var a = halfOne + + for (i in 0 until count) { + val weight = kkY[y * kSizeY + i] + val srcY = yMin + i + val pixel = tempPixels[srcY * newWidth + x] + + a += ((pixel shr 24) and 0xFF) * weight + r += ((pixel shr 16) and 0xFF) * weight + g += ((pixel shr 8) and 0xFF) * weight + b += (pixel and 0xFF) * weight + } + + val rInt = (r shr precisionBits).coerceIn(0, 255) + val gInt = (g shr precisionBits).coerceIn(0, 255) + val bInt = (b shr precisionBits).coerceIn(0, 255) + val aInt = (a shr precisionBits).coerceIn(0, 255) + + finalPixels[y * newWidth + x] = + (aInt shl 24) or (rInt shl 16) or (gInt shl 8) or bInt + } + } + + return finalPixels + } + + } +} diff --git a/Android/QuickDotAI/src/main/java/com/example/quickdotai/QuickDotAI.kt b/Android/QuickDotAI/src/main/java/com/example/quickdotai/QuickDotAI.kt new file mode 100644 index 00000000..09e623ee --- /dev/null +++ b/Android/QuickDotAI/src/main/java/com/example/quickdotai/QuickDotAI.kt @@ -0,0 +1,395 @@ +// SPDX-License-Identifier: Apache-2.0 +/* + * Copyright (C) 2026 Samsung Electronics Co., Ltd. All Rights Reserved. + * + * @file QuickDotAI.kt + * @brief Public surface of the QuickDotAI AAR. + * + * QuickDotAI is a thin abstraction over a single loaded on-device + * language model. Two concrete implementations are shipped in this AAR: + * + * - [NativeQuickDotAI] β€” routes non-Gemma models through JNI to + * libcausallm_api.so, the handle-based C API built from + * Applications/CausalLM. + * - [LiteRTLm] β€” routes Gemma-family models through the + * LiteRT-LM Kotlin API. + * + * Both implementations satisfy the same [QuickDotAI] contract so a host + * app can pick an engine once at load time and then drive it through a + * interface for handle-based inference (OpenAI tab), session-based + * chat (Chat tab), and lifecycle management (load / unload / close). + * + * Threading: a [QuickDotAI] instance is NOT internally thread-safe. The + * expectation is that the host app owns exactly one instance per loaded + * model and drives it from a single worker thread β€” the same contract + * that QuickAIService's ModelWorker implements, and the one the sample + * app (SampleTestAPP) follows from its background dispatcher. + */ +package com.example.quickdotai + +import android.content.Context + +/** + * @brief Outcome of a QuickDotAI call. + * + * Every public method returns a [BackendResult] so errors never + * propagate out as exceptions across the AAR boundary. [Ok] carries the + * successful value; [Err] carries a [QuickAiError] code and an optional + * human-readable message. + */ +sealed class BackendResult { + data class Ok(val value: T) : BackendResult() + data class Err( + val error: QuickAiError, + val message: String? = null + ) : BackendResult() +} + +/** + * @brief Where a [QuickDotAI] implementation pushes streamed output + * during [QuickDotAI.runStreaming]. + * + * The contract is: + * - zero or more [onDelta] calls carrying newly-generated text, + * followed by + * - exactly one terminal call β€” either [onDone] on success or + * [onError] on failure. + * + * Implementations may be invoked from an implementation-internal + * thread (LiteRT-LM for example dispatches MessageCallback on its own + * worker thread). Host code that wants to marshal events back to the UI + * thread must do that bridging itself β€” the AAR does not assume any + * particular threading model on the consumer side. + */ +interface StreamSink { + fun onDelta(text: String) + fun onReasoningDelta(text: String) { + } + fun onDone() + fun onError(error: QuickAiError, message: String?) +} + +/** + * @brief Common interface implemented by every QuickDotAI engine. + * + * Lifecycle: [load] exactly once, then inference calls, then [close] + * exactly once. Calling any inference method before [load] returns a + * [BackendResult.Err] with [QuickAiError.NOT_INITIALIZED]. + * + * **Chat session lifecycle:** [openChatSession] β†’ [runChatModelHandleStreaming] / + * [runChatMultimodalHandleStreaming] / [chatCancel] / [chatRebuild] β†’ [closeChatSession]. + * Only one session may be active at a time. + */ +interface QuickDotAI { + /** @return a short identifier like "native" or "litert-lm". */ + val kind: String + + /** @return the architecture string reported by the engine, if any. */ + val architecture: String? + + /** + * @return the sessionId of the currently active chat session, or + * null if no session is open. + */ + val chatSessionId: String? + get() = null + + /** + * @brief Load the model described by [req]. Must be called exactly + * once before any inference call. + */ + fun load(req: LoadModelRequest): BackendResult + + /** + * @brief Blocking multimodal inference β€” accepts a sequence of + * [PromptPart]s that may interleave text and image inputs. + * + * The default implementation returns [QuickAiError.UNSUPPORTED] + * because not every engine can handle non-text inputs. Concrete + * implementations backed by multimodal-capable models (currently + * [LiteRTLm] with a multimodal Gemma loaded through a non-null + * [LoadModelRequest.visionBackend]) override this to do the real + * work. [NativeQuickDotAI] inherits the UNSUPPORTED default, so + * consumers get a clear error message instead of a silent failure + * when they aim an image prompt at the text-only native engine. + * + * Contract: + * - [parts] must be non-empty; an empty list returns + * [QuickAiError.INVALID_PARAMETER]. + * - Parts may appear in any order. The canonical Gemma-4 / + * Gemma3n convention is one or more image parts followed by a + * single trailing text instruction. + * - Must be called only after a successful [load]; calling it + * before [load] returns [QuickAiError.NOT_INITIALIZED]. + * + * Example: + * ``` + * val reply = engine.runMultimodalHandle(listOf( + * PromptPart.ImageFile("/sdcard/photo.jpg"), + * PromptPart.Text("What is happening in this picture?"), + * )) + * ``` + */ + fun runMultimodalHandle(parts: List): BackendResult = + BackendResult.Err( + QuickAiError.UNSUPPORTED, + "runMultimodalHandle is not supported by engine '$kind'. " + + "Load a multimodal-capable model (e.g. GEMMA4) with " + + "LoadModelRequest.visionBackend set to a non-null value." + ) + + /** + * @brief Streaming variant of [runMultimodalHandle]. + * + * The default implementation returns [QuickAiError.UNSUPPORTED] + * and delivers a single terminal [StreamSink.onError] before + * returning, so callers can rely on the same StreamSink contract + * as text-only streaming regardless of which engine they targeted. + */ + fun runMultimodalHandleStreaming( + parts: List, + sink: StreamSink + ): BackendResult { + val err = BackendResult.Err( + QuickAiError.UNSUPPORTED, + "runMultimodalHandleStreaming is not supported by engine '$kind'. " + + "Load a multimodal-capable model (e.g. GEMMA4) with " + + "LoadModelRequest.visionBackend set to a non-null value." + ) + sink.onError(err.error, err.message) + return err + } + + /** + * @brief Unload the model weights without destroying the engine. + * + * After a successful unload the engine is in a "not initialized" state + * β€” subsequent [run] / [runStreaming] / [metrics] calls will return + * [QuickAiError.NOT_INITIALIZED]. The instance can still be [close]d + * normally (and must be, to release any remaining resources). + * + * Implementations that do not support partial unload may treat this as + * a full [close] or return [BackendResult.Ok] as a no-op. + */ + fun unload(): BackendResult + + /** + * @brief Fetch performance metrics for the most recent run. + */ + fun metrics(): BackendResult + + /** + * Encode [text] into a sentence-embedding vector. Only embedding models + * (e.g. the Ouro family) support this; other engines return + * [QuickAiError.INFERENCE_FAILED] by default. + * + * @return [BackendResult.Ok] with the embedding FloatArray, or + * [BackendResult.Err] on failure. + */ + fun encode(text: String): BackendResult = + BackendResult.Err( + QuickAiError.INFERENCE_FAILED, + "encode() is not supported by this engine" + ) + + // ----- Chat session API ------------------------------------------------ + // All chat operations go through this interface so the app never needs + // to interact with chat session classes directly. + + /** + * @brief Open a new structured chat session on this engine. + * + * Only **one** session may be active at a time (LiteRT-LM allows a + * single Conversation per Engine). If a session is already open, + * this method returns [QuickAiError.BAD_REQUEST]. Returns the + * session ID on success. + */ + fun openChatSession( + config: QuickAiChatSessionConfig? = null + ): BackendResult = + BackendResult.Err( + QuickAiError.UNSUPPORTED, + "openChatSession is not supported by engine '$kind'." + ) + + /** + * @brief Close the active chat session, releasing its resources + * (conversation handle, cached images, etc.). After closing, the + * flat [run] / [runStreaming] APIs become usable again. + */ + fun closeChatSession(): BackendResult = + BackendResult.Err( + QuickAiError.UNSUPPORTED, + "closeChatSession is not supported by engine '$kind'." + ) + + /** + * @brief Send a chat message in a session with streaming response. + * + * Requires an active session opened via [openChatSession]. + * The text message is converted internally to a structured format + * and sent to the native engine. + * + * @param text Raw text input from the user + * @param sink StreamSink to receive streaming output + * @return BackendResult containing the chat result or an error + */ + fun runChatModelHandleStreaming( + text: String, + sink: StreamSink + ): BackendResult { + val err = BackendResult.Err( + QuickAiError.UNSUPPORTED, + "runChatModelHandleStreaming is not supported by engine '$kind'." + ) + sink.onError(err.error, err.message) + return err + } + + /** + * @brief Send a multimodal chat message (with image) in a session + * with streaming response. + * + * Requires an active session opened via [openChatSession]. + * + * @param parts List of PromptPart containing text and/or images + * @param sink StreamSink to receive streaming output + * @return BackendResult containing the chat result or an error + */ + fun runChatMultimodalHandleStreaming( + parts: List, + sink: StreamSink + ): BackendResult { + val err = BackendResult.Err( + QuickAiError.UNSUPPORTED, + "runChatMultimodalHandleStreaming is not supported by engine '$kind'." + ) + sink.onError(err.error, err.message) + return err + } + + /** + * @brief Cancel an in-flight [runStreaming] or [runMultimodalStreaming]. + * Safe to call from any thread. No-op if no generation is running. + */ + fun cancel() { /* no-op by default */ } + + /** + * @brief Cancel an in-flight [chatRun] or [chatRunStreaming]. + * Safe to call from any thread. No-op if no generation is running. + */ + fun chatCancel() { /* no-op by default */ } + + /** + * @brief Reset the active session: drop the backend's KV cache and + * optionally pre-seed a fresh conversation with [messages] as + * initial turns. Pass `emptyList()` to simply clear the session. + * Use this after history edits, sampling changes, or to recover + * from a failed/cancelled turn. + */ + fun chatRebuild( + messages: List + ): BackendResult = + BackendResult.Err( + QuickAiError.UNSUPPORTED, + "chatRebuild is not supported by engine '$kind'." + ) + + // ----- Handle-based OpenAI messages API (streaming only) ---------- + + /** + * @brief Streaming inference with OpenAI message format on a specific handle. + * + * @param messages List of chat messages with role (system/user/assistant) and content + * @param sink StreamSink to receive streaming output + * @return BackendResult on completion + */ + fun runModelHandleWithMessagesStreaming( + messages: List, + sink: StreamSink + ): BackendResult { + val err = BackendResult.Err( + QuickAiError.UNSUPPORTED, + "runModelHandleWithMessagesStreaming is not supported by engine '$kind'." + ) + sink.onError(err.error, err.message) + return err + } + + /** + * @brief Streaming multimodal inference with OpenAI message format on a specific handle. + * + * @param messages List of chat messages. Image should be included as ImageBytes part. + * @param sink StreamSink to receive streaming output + * @return BackendResult on completion + */ + fun runMultimodalHandleWithMessagesStreaming( + messages: List, + sink: StreamSink + ): BackendResult { + val err = BackendResult.Err( + QuickAiError.UNSUPPORTED, + "runMultimodalHandleWithMessagesStreaming is not supported by engine '$kind'." + ) + sink.onError(err.error, err.message) + return err + } + + /** + * @brief Streaming inference with OpenAI JSON format. + * + * Accepts a JSON string in OpenAI format and processes it through the + * chat template. Supports messages, tools, functions, and all other + * fields recognized by minja chat template renderer. + * + * Example JSON input: + * ``` + * { + * "messages": [ + * {"role": "developer", "content": "..."}, + * {"role": "user", "content": "..."} + * ], + * "tools": [ + * {"type": "function", "function": {"name": "call", "description": "..."}} + * ] + * } + * ``` + * + * @param jsonRequest OpenAI format JSON string + * @param sink StreamSink to receive streaming output + * @return BackendResult on completion + */ + fun runModelHandleWithJsonStreaming( + jsonRequest: String, + sink: StreamSink + ): BackendResult { + val err = BackendResult.Err( + QuickAiError.UNSUPPORTED, + "runModelHandleWithJsonStreaming is not supported by engine '$kind'." + ) + sink.onError(err.error, err.message) + return err + } + + /** + * @brief Release all resources. Idempotent β€” safe to call more + * than once. + */ + fun close() +} + +/** + * Factory: create the right engine for a [ModelDescriptor]. + */ +fun createEngine( + context: Context, + descriptor: ModelDescriptor, + modelBasePath: String? = null +): QuickDotAI = + when (descriptor.runtime) { + RuntimeKind.LITERT -> LiteRTLm( + context, + defaultModelBasePath = modelBasePath ?: "/sdcard/Download/aistudio-mobile/models/" + ) + RuntimeKind.NATIVE -> NativeQuickDotAI(context) + } diff --git a/Android/QuickDotAI/src/main/java/com/example/quickdotai/Types.kt b/Android/QuickDotAI/src/main/java/com/example/quickdotai/Types.kt new file mode 100644 index 00000000..de255294 --- /dev/null +++ b/Android/QuickDotAI/src/main/java/com/example/quickdotai/Types.kt @@ -0,0 +1,372 @@ +// SPDX-License-Identifier: Apache-2.0 +/* + * Copyright (C) 2026 Samsung Electronics Co., Ltd. All Rights Reserved. + * + * @file Types.kt + * @brief Value types shared by the QuickDotAI interface and its + * implementations. + * + * The enums mirror the C enums in Applications/CausalLM/api/quick_dot_ai_api.h. + * Model identifiers are plain Strings (see [ModelIds] in ModelCatalog.kt) + * so the AAR is not re-compiled whenever the model list changes. + * + * Every public class in this file carries `@Serializable` so host apps + * that want to JSON-ify requests/responses (for example QuickAIService's + * REST layer) can do so without redefining the types β€” the AAR exposes + * kotlinx-serialization-json as an `api` dependency for that purpose. + */ +package com.example.quickdotai + +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable + +/** + * @brief Compute backend. Mirrors BackendType in quick_dot_ai_api.h. + */ +@Serializable +enum class BackendType { + CPU, + GPU, + NPU +} + + +/** + * @brief Quantization type. Mirrors ModelQuantizationType in + * quick_dot_ai_api.h. + */ +@Serializable +enum class QuantizationType { + UNKNOWN, + W4A32, + W16A16, + W8A16, + W32A32 +} + +/** + * @brief Error code. Mirrors ErrorCode in quick_dot_ai_api.h plus a few + * Kotlin-level additions for out-of-band conditions. + */ +@Serializable +enum class QuickAiError(val code: Int) { + NONE(0), + INVALID_PARAMETER(1), + MODEL_LOAD_FAILED(2), + INFERENCE_FAILED(3), + NOT_INITIALIZED(4), + INFERENCE_NOT_RUN(5), + UNKNOWN(99), + + // Kotlin-only conditions returned by higher layers (QuickAIService + // worker / dispatcher). The AAR itself only surfaces the native + // codes and NOT_INITIALIZED, but it is convenient for the host app + // to have the full enum in one place. + QUEUE_FULL(100), + MODEL_NOT_FOUND(101), + UNSUPPORTED(102), + BAD_REQUEST(103); + + companion object { + fun fromNativeCode(code: Int): QuickAiError = + entries.firstOrNull { it.code == code } ?: UNKNOWN + } +} + +/** + * @brief Descriptor passed to [QuickDotAI.load]. + * + * [modelPath] is required by [LiteRTLm] (which takes an explicit path + * to a `.litertlm` file) and ignored by [NativeQuickDotAI] (which + * discovers its model assets through the native C API's internal + * model-directory resolution). + * + * [visionBackend] and [cacheDir] are optional knobs used only by + * multimodal-capable engines ([LiteRTLm] today). They are ignored + * by [NativeQuickDotAI]. + */ +@Serializable +data class LoadModelRequest( + val backend: BackendType = BackendType.GPU, + @SerialName("model_id") val modelId: String, + val quantization: QuantizationType = QuantizationType.W4A32, + @SerialName("model_path") val modelPath: String? = null, + + /** + * Compute backend for the model's vision encoder when loading a + * multimodal-capable model (e.g. Gemma-4 / Gemma3n). Null means + * the engine is loaded in text-only mode β€” in that case + * [QuickDotAI.runMultimodal] returns [QuickAiError.UNSUPPORTED] + * even on backends that would otherwise support images. + * + * Only honored by [LiteRTLm]; [NativeQuickDotAI] ignores it. + */ + @SerialName("vision_backend") val visionBackend: BackendType? = null, + + /** + * Writable directory for engine on-disk caches. Populating this + * field materially speeds up the second and subsequent loads of + * the same model. Maps to LiteRT-LM's EngineConfig.cacheDir. + * Null = engine default. + * + * Only honored by [LiteRTLm]; [NativeQuickDotAI] ignores it. + * + * (Note: the LiteRT-LM 0.10.x EngineConfig surface we compile + * against does not yet expose a per-prompt `maxNumImages` cap β€” + * that is a 1.0+ feature. Once we roll forward past 1.0 we can + * add the corresponding field back to this request.) + */ + @SerialName("cache_dir") val cacheDir: String? = null, + + /** + * Maximum number of tokens the engine should allocate for the KV + * cache / context window. This is a load-sensitive setting β€” it must + * be known at engine-construction time and cannot be changed per + * request. Null = engine default. + * + * Honored by [LiteRTLm] (maps to EngineConfig.maxNumTokens) and + * ignored by [NativeQuickDotAI]. + */ + @SerialName("max_num_tokens") val maxNumTokens: Int? = null, + + /** + * Native library directory path from ApplicationInfo.nativeLibraryDir. + * Used by the native engine to locate shared libraries for loading. + * + * Only honored by [NativeQuickDotAI]; [LiteRTLm] ignores it. + */ + @SerialName("native_lib_dir") val nativeLibDir: String? = null, + + /** + * Base directory path for model files. The native C API uses this + * as the prefix when resolving model directories (e.g. + * "$modelBasePath/qwen3-0.6b"). When null, the C API falls back + * to its built-in default path. + * + * Only honored by [NativeQuickDotAI]; [LiteRTLm] ignores it. + */ + @SerialName("model_base_path") val modelBasePath: String? = null, + + /** + * Path to the HTP backend extension config JSON file used by QNN + * models. Absolute paths are used as-is. Relative paths are resolved + * from ``. When null, the native engine falls back to + * `/htp_backend_ext_config.json`. + * + * Only honored by [NativeQuickDotAI]; [LiteRTLm] ignores it. + */ + @SerialName("htp_backend_config_path") val htpBackendConfigPath: String? = null, +) { + /** + * Canonical key shared across the stack: one worker/handle per + * (model, quantization) pair. + */ + val modelKey: String get() = "$modelId:${quantization.name}" +} + +/** + * @brief One part of a multimodal prompt passed to + * [QuickDotAI.runMultimodal] / [QuickDotAI.runMultimodalStreaming]. + * + * The concrete backend (currently [LiteRTLm] for Gemma-family models + * loaded with a non-null [LoadModelRequest.visionBackend]) translates + * each part to its native content representation. The ordering in the + * list is preserved β€” the canonical Gemma-4 / Gemma3n convention is + * one or more image parts followed by a single trailing text + * instruction, e.g. + * ``` + * runMultimodal(listOf( + * PromptPart.ImageFile("/sdcard/.../photo.jpg"), + * PromptPart.Text("Describe this picture in one sentence."), + * )) + * ``` + * + * PromptPart is intentionally NOT @Serializable: `ImageBytes.bytes` + * would serialize as a JSON array of ints, which is the wrong wire + * format for a REST layer. Consumers that need to carry multimodal + * prompts over the wire (e.g. LauncherApp's HTTP server) should + * define their own Base64-flavored DTO and convert at the boundary. + */ +sealed class PromptPart { + /** A chunk of text β€” typically the user's question or instruction. */ + data class Text(val text: String) : PromptPart() + + /** + * A local image file. [absolutePath] must point to a readable file + * on the device β€” the engine opens it directly from the native + * layer, so relative paths are NOT supported. Mirrors LiteRT-LM's + * parameter naming for clarity. + * + * Supported formats depend on the underlying engine but generally + * include JPEG and PNG. + */ + data class ImageFile(val absolutePath: String) : PromptPart() + + /** + * Image bytes already held in memory. Useful when the image comes + * from an in-process source (camera buffer, bundled asset, HTTP + * download) and the caller does not want to materialize it to a + * temporary file first. + * + * The byte layout must be the raw file contents of an encoded + * image (JPEG / PNG / …), NOT a decoded pixel array. + */ + data class ImageBytes(val bytes: ByteArray) : PromptPart() { + override fun equals(other: Any?): Boolean { + if (this === other) return true + if (other !is ImageBytes) return false + return bytes.contentEquals(other.bytes) + } + override fun hashCode(): Int = bytes.contentHashCode() + } + + /** + * Pre-processed pixel values already in CHW float format. + * Used by models like V-JEPA where the caller has already performed + * image preprocessing externally and wants to pass the result + * directly to the native vision encoder without any further + * transformation on the Kotlin side. + * + * @param pixelValues Flattened pixel values in CHW format + * (all images concatenated) + * @param numPatches Total number of patches across all images + * @param numImages Number of images (e.g. video frames) + * @param patchesPerImage Number of patches per image + * @param imageHeights Original height of each image before preprocessing + * @param imageWidths Original width of each image before preprocessing + */ + data class PreprocessedPixels( + val pixelValues: FloatArray, + val numPatches: Int, + val numImages: Int, + val patchesPerImage: IntArray, + val imageHeights: IntArray, + val imageWidths: IntArray + ) : PromptPart() { + override fun equals(other: Any?): Boolean { + if (this === other) return true + if (other !is PreprocessedPixels) return false + return pixelValues.contentEquals(other.pixelValues) + } + override fun hashCode(): Int = pixelValues.contentHashCode() + } +} + +/** + * @brief Performance metrics for the most recent run. + * + * Not every engine fills every field: + * - [NativeQuickDotAI] fills prefill_* / generation_* / peak_memory_kb + * from the C API's PerformanceMetrics struct. + * - [LiteRTLm] currently only fills [initializationDurationMs] and + * [totalDurationMs] because the LiteRT-LM Kotlin API does not expose + * token-level counters in the release we target. + */ +/* -------------------------------------------------------------------- */ +/* Structured Chat / Session types (request-mail1 Β§1–§5, mail2 Β§6) */ +/* -------------------------------------------------------------------- */ + +/** + * @brief Role within a structured chat conversation. + */ +enum class QuickAiChatRole { + SYSTEM, + USER, + ASSISTANT +} + +/** + * @brief Sampling configuration applied to a chat session. + * + * Field support by backend: + * - [LiteRTLm] maps [temperature], [topK], [topP], and [seed] to + * LiteRT-LM's `SamplerConfig`. [minP] and [maxTokens] are NOT + * supported by LiteRT-LM's SamplerConfig and are silently ignored + * (a warning is logged). + * + * Partial specification: + * - Leaving every field null is equivalent to passing no sampling + * config at all β€” LiteRT-LM uses its engine/model default. + * - Specifying any of [temperature] / [topK] / [topP] requires the + * wrapper to build a full `SamplerConfig`, which in LiteRT-LM has + * non-nullable core fields. Unspecified core fields fall back to + * temperature=1.0, topK=40, topP=0.95, and a warning is logged. + * To avoid surprises, specify all three together. + * + * Validation: + * - LiteRT-LM requires topK > 0, topP in [0, 1], temperature >= 0. + * Violations throw from the underlying engine and surface as a + * [BackendResult.Err]. + */ +@Serializable +data class QuickAiChatSamplingConfig( + val temperature: Double? = null, + @SerialName("top_k") val topK: Int? = null, + @SerialName("top_p") val topP: Double? = null, + @SerialName("min_p") val minP: Double? = null, + @SerialName("max_tokens") val maxTokens: Int? = null, + val seed: Int? = null +) + +/** + * @brief Template keyword arguments forwarded to the chat template + * renderer. [enableThinking] controls whether the model's "thinking" + * prompt preamble is activated. + */ +@Serializable +data class QuickAiChatTemplateKwargs( + @SerialName("enable_thinking") val enableThinking: Boolean? = null +) + +/** + * @brief Configuration for a new chat session, passed to + * [QuickDotAI.openChatSession]. + * + * [systemInstruction] maps to LiteRT-LM's + * `ConversationConfig.systemInstruction` and is applied once when the + * conversation is created β€” equivalent to the `"system"` role in + * OpenAI-style message lists. + */ +@Serializable +data class QuickAiChatSessionConfig( + @SerialName("system_instruction") val systemInstruction: String? = null, + val sampling: QuickAiChatSamplingConfig? = null, + @SerialName("chat_template_kwargs") val chatTemplateKwargs: QuickAiChatTemplateKwargs? = null +) + +/** + * @brief One message in a structured chat conversation. + * + * The [parts] list may contain text, image files, or raw image bytes + * in any order β€” the backend preserves insertion order. For text-only + * turns, a single [PromptPart.Text] suffices. + */ +data class QuickAiChatMessage( + val role: QuickAiChatRole, + val parts: List +) + +/** + * @brief Result returned by [QuickDotAI.chatRun] / + * [QuickDotAI.chatRunStreaming]. + */ +data class QuickAiChatResult( + val content: String, + val reasoning: String? = null, + val metrics: PerformanceMetrics? = null +) + +/* -------------------------------------------------------------------- */ +/* Performance metrics */ +/* -------------------------------------------------------------------- */ + +@Serializable +data class PerformanceMetrics( + @SerialName("prefill_tokens") val prefillTokens: Int = 0, + @SerialName("prefill_duration_ms") val prefillDurationMs: Double = 0.0, + @SerialName("generation_tokens") val generationTokens: Int = 0, + @SerialName("generation_duration_ms") val generationDurationMs: Double = 0.0, + @SerialName("total_duration_ms") val totalDurationMs: Double = 0.0, + @SerialName("initialization_duration_ms") val initializationDurationMs: Double = 0.0, + @SerialName("peak_memory_kb") val peakMemoryKb: Long = 0 +) diff --git a/Android/SampleTestAPP/.gitignore b/Android/SampleTestAPP/.gitignore new file mode 100644 index 00000000..aa724b77 --- /dev/null +++ b/Android/SampleTestAPP/.gitignore @@ -0,0 +1,15 @@ +*.iml +.gradle +/local.properties +/.idea/caches +/.idea/libraries +/.idea/modules.xml +/.idea/workspace.xml +/.idea/navEditor.xml +/.idea/assetWizardSettings.xml +.DS_Store +/build +/captures +.externalNativeBuild +.cxx +local.properties diff --git a/Android/SampleTestAPP/build.gradle.kts b/Android/SampleTestAPP/build.gradle.kts new file mode 100644 index 00000000..cbc30450 --- /dev/null +++ b/Android/SampleTestAPP/build.gradle.kts @@ -0,0 +1,73 @@ +plugins { + alias(libs.plugins.android.application) + alias(libs.plugins.kotlin.android) + alias(libs.plugins.kotlin.serialization) +} + +android { + namespace = "com.example.sampletestapp" + compileSdk = 36 + + packaging { + jniLibs.useLegacyPackaging = true + } + defaultConfig { + applicationId = "com.example.sampletestapp" + minSdk = 33 + targetSdk = 36 + versionCode = 1 + versionName = "1.0" + + testInstrumentationRunner = "androidx.test.runner.AndroidJUnitRunner" + + ndk { + // SampleTestAPP hosts the QuickDotAI AAR directly (no remote + // :remote process) so it packages the AAR's arm64-v8a + // jniLibs. Restrict to the matching ABI to avoid empty + // armv7/x86_64 slices. + abiFilters += listOf("arm64-v8a") + } + } + + buildTypes { + release { + isMinifyEnabled = false + proguardFiles( + getDefaultProguardFile("proguard-android-optimize.txt"), + "proguard-rules.pro" + ) + } + } + compileOptions { + sourceCompatibility = JavaVersion.VERSION_17 + targetCompatibility = JavaVersion.VERSION_17 + } + kotlinOptions { + jvmTarget = "17" + // LiteRT-LM 0.10.0 (the version our mirror serves) was compiled + // with Kotlin 2.3.0 but our compiler is 2.2.21; the flag tells + // kotlinc to accept the newer metadata stamp on transitive + // artifacts. See libs.versions.toml for the full story. + freeCompilerArgs += "-Xskip-metadata-version-check" + } +} + +dependencies { + // The whole point of SampleTestAPP: depend on the :QuickDotAI AAR + // directly and drive LiteRTLm / NativeQuickDotAI in-process, without + // QuickAIService. The AAR re-exports kotlinx-serialization-json and + // the LiteRT-LM Kotlin runtime as `api` dependencies so we get them + // transitively. + implementation(project(":QuickDotAI")) + + implementation(libs.androidx.core.ktx) + implementation(libs.androidx.appcompat) + implementation(libs.androidx.activity) + implementation(libs.material) + implementation(libs.androidx.lifecycle.runtime.ktx) + implementation(libs.kotlinx.coroutines.android) + + testImplementation(libs.junit) + androidTestImplementation(libs.androidx.junit) + androidTestImplementation(libs.androidx.espresso.core) +} diff --git a/Android/SampleTestAPP/proguard-rules.pro b/Android/SampleTestAPP/proguard-rules.pro new file mode 100644 index 00000000..161a6dc7 --- /dev/null +++ b/Android/SampleTestAPP/proguard-rules.pro @@ -0,0 +1,9 @@ +# Add project specific ProGuard rules here. +# You can control the set of applied configuration files using the +# proguardFiles setting in build.gradle. +# +# For more details, see +# http://developer.android.com/guide/developing/tools/proguard.html + +# R8/ProGuard rules for the QuickDotAI AAR itself are contributed via +# consumer-rules.pro in that module and applied automatically here. diff --git a/Android/SampleTestAPP/src/main/AndroidManifest.xml b/Android/SampleTestAPP/src/main/AndroidManifest.xml new file mode 100644 index 00000000..c7c6f0f7 --- /dev/null +++ b/Android/SampleTestAPP/src/main/AndroidManifest.xml @@ -0,0 +1,37 @@ + + + + + + + + + + + + + + + + + + diff --git a/Android/SampleTestAPP/src/main/java/com/example/sampletestapp/MainActivity.kt b/Android/SampleTestAPP/src/main/java/com/example/sampletestapp/MainActivity.kt new file mode 100644 index 00000000..4734c717 --- /dev/null +++ b/Android/SampleTestAPP/src/main/java/com/example/sampletestapp/MainActivity.kt @@ -0,0 +1,2843 @@ +// SPDX-License-Identifier: Apache-2.0 +/* + * Copyright (C) 2026 Samsung Electronics Co., Ltd. All Rights Reserved. + * + * @file MainActivity.kt + * @brief Standalone sample that drives the :QuickDotAI AAR directly β€” + * no QuickAIService, no REST, no remote process. + * + * Engine wiring (unchanged from the original sample): + * + * 1. Instantiates [LiteRTLm] for GEMMA4 and [NativeQuickDotAI] for every + * other model, both against a single-thread Executor so all calls + * touching a given engine are serialised on the same worker thread + * (the [QuickDotAI] interface is not internally thread-safe). + * 2. Calls [QuickDotAI.load] once per chosen (model, quant) pair. For + * GEMMA4 it auto-populates [LoadModelRequest.visionBackend] so the + * multimodal path is armed from load time. + * 3. Drives [QuickDotAI.runStreaming] (text-only) or + * [QuickDotAI.runMultimodalStreaming] (when an image is selected) + * with an in-memory StreamSink that appends each delta to the output + * view on the main thread. + * + * UI (M3 Expressive redesign β€” see Applications/QuickAI/QuickDotAI/QuickDotAI.html + * design bundle): the screen is rebuilt as a tabbed Material 3 interface + * with a custom top bar, hero status pill, collapsible Model section + * with chip-group backend / quantization pickers, and a terminal-styled + * output panel shared across the Run / Chat / OpenAI tabs. A light/dark + * toggle in the top bar swaps the full M3 token set at runtime. + */ +package com.example.sampletestapp + +import android.content.res.ColorStateList +import android.graphics.Color +import android.graphics.PorterDuff +import android.graphics.Typeface +import android.graphics.drawable.GradientDrawable +import android.graphics.drawable.RippleDrawable +import android.content.Intent +import android.net.Uri +import android.os.Bundle +import android.os.Environment +import android.provider.Settings +import android.text.Editable +import android.text.InputType +import android.text.TextWatcher +import android.util.TypedValue +import android.view.Gravity +import android.view.View +import android.view.ViewGroup +import android.view.ViewGroup.LayoutParams.MATCH_PARENT +import android.view.ViewGroup.LayoutParams.WRAP_CONTENT +import android.widget.Button +import android.widget.EditText +import android.widget.FrameLayout +import android.widget.HorizontalScrollView +import android.widget.LinearLayout +import android.widget.PopupMenu +import android.widget.ScrollView +import android.widget.TextView +import androidx.activity.result.PickVisualMediaRequest +import androidx.activity.result.contract.ActivityResultContracts +import androidx.appcompat.app.AppCompatActivity +import androidx.core.widget.NestedScrollView +import android.content.ClipData +import android.content.ClipboardManager +import android.content.Context +import com.example.quickdotai.BackendResult +import com.example.quickdotai.BackendType +import com.example.quickdotai.Capability +import com.example.quickdotai.LoadModelRequest +import com.example.quickdotai.ModelCatalog +import com.example.quickdotai.ModelDescriptor +import com.example.quickdotai.ModelIds +import com.example.quickdotai.NativeQuickDotAI +import com.example.quickdotai.RuntimeKind +import com.example.quickdotai.createEngine +import com.example.quickdotai.PerformanceMetrics +import com.example.quickdotai.PromptPart +import com.example.quickdotai.QuickAiChatMessage +import com.example.quickdotai.QuickAiChatRole +import com.example.quickdotai.QuickAiChatSamplingConfig +import com.example.quickdotai.QuickAiChatSessionConfig +import com.example.quickdotai.QuickAiChatTemplateKwargs +import com.example.quickdotai.QuantizationType +import com.example.quickdotai.QuickAiError +import com.example.quickdotai.QuickDotAI +import com.example.quickdotai.StreamSink +import java.io.File +import java.util.concurrent.Executor +import java.util.concurrent.Executors +import kotlinx.serialization.json.Json +import kotlinx.serialization.json.JsonArray +import kotlinx.serialization.json.JsonElement +import kotlinx.serialization.json.JsonObject +import kotlinx.serialization.json.jsonArray +import kotlinx.serialization.json.jsonObject +import kotlinx.serialization.json.jsonPrimitive + +/* ──────────────────────────────────────────────────────────────────────── + * M3 Expressive token set β€” mirrors the LIGHT / DARK objects defined in + * the QuickDotAI.html design bundle. Holding both palettes in a single + * data class lets us swap them atomically when the user taps the dark- + * mode toggle in the top bar. + * ──────────────────────────────────────────────────────────────────────── */ +private data class M3Tokens( + val bg: Int, val surface: Int, val surfaceDim: Int, + val surfaceContainer: Int, val surfaceContainerHigh: Int, + val outline: Int, val outlineVariant: Int, + val onSurface: Int, val onSurfaceVar: Int, + val primary: Int, val onPrimary: Int, + val primaryContainer: Int, val onPrimaryContainer: Int, + val secondary: Int, val secondaryContainer: Int, + val tertiary: Int, val tertiaryContainer: Int, + val error: Int, val errorContainer: Int, + val success: Int, val successContainer: Int, + val codeBg: Int, val codeFg: Int, +) + +private val LIGHT = M3Tokens( + bg = 0xFFFBF8FF.toInt(), + surface = 0xFFFFFFFF.toInt(), + surfaceDim = 0xFFF2EEF8.toInt(), + surfaceContainer = 0xFFF4EFFA.toInt(), + surfaceContainerHigh = 0xFFEDE7F6.toInt(), + outline = 0xFFCAC4D0.toInt(), + outlineVariant = 0xFFE7E0EC.toInt(), + onSurface = 0xFF1C1B1F.toInt(), + onSurfaceVar = 0xFF49454F.toInt(), + primary = 0xFF5B3EBE.toInt(), + onPrimary = 0xFFFFFFFF.toInt(), + primaryContainer = 0xFFE9DDFF.toInt(), + onPrimaryContainer = 0xFF21005D.toInt(), + secondary = 0xFF625B71.toInt(), + secondaryContainer = 0xFFE8DEF8.toInt(), + tertiary = 0xFF7D5260.toInt(), + tertiaryContainer = 0xFFFFD8E4.toInt(), + error = 0xFFB3261E.toInt(), + errorContainer = 0xFFF9DEDC.toInt(), + success = 0xFF146C2E.toInt(), + successContainer = 0xFFD5F5DF.toInt(), + codeBg = 0xFF0F0B1E.toInt(), + codeFg = 0xFFEDE7F6.toInt(), +) + +private val DARK = M3Tokens( + bg = 0xFF121019.toInt(), + surface = 0xFF1B1823.toInt(), + surfaceDim = 0xFF100E17.toInt(), + surfaceContainer = 0xFF211E2B.toInt(), + surfaceContainerHigh = 0xFF2B2834.toInt(), + outline = 0xFF4A4458.toInt(), + outlineVariant = 0xFF2D2A37.toInt(), + onSurface = 0xFFE6E0E9.toInt(), + onSurfaceVar = 0xFFCAC4D0.toInt(), + primary = 0xFFCFBCFF.toInt(), + onPrimary = 0xFF371E73.toInt(), + primaryContainer = 0xFF4A3A8C.toInt(), + onPrimaryContainer = 0xFFE9DDFF.toInt(), + secondary = 0xFFCCC2DC.toInt(), + secondaryContainer = 0xFF4A4458.toInt(), + tertiary = 0xFFEFB8C8.toInt(), + tertiaryContainer = 0xFF633B48.toInt(), + error = 0xFFF2B8B5.toInt(), + errorContainer = 0xFF601410.toInt(), + success = 0xFF6EDB88.toInt(), + successContainer = 0xFF124F24.toInt(), + codeBg = 0xFF06040F.toInt(), + codeFg = 0xFFCFBCFF.toInt(), +) + +class MainActivity : AppCompatActivity() { + + /* ───── Engine plumbing (unchanged from the original sample) ───── */ + + private val engineExecutor: Executor = Executors.newSingleThreadExecutor { r -> + Thread(r, "SampleTestAPP-Engine").apply { isDaemon = true } + } + + private val mainHandler by lazy { android.os.Handler(mainLooper) } + + @Volatile private var engine: QuickDotAI? = null + @Volatile private var loadedKey: String? = null + @Volatile private var selectedImageBytesList: MutableList = mutableListOf() + + /** Convenience: first selected image (or null) β€” for single-image code paths. */ + private val selectedImageBytes: ByteArray? + get() = selectedImageBytesList.firstOrNull() + + /* ───── UI state (preserved across light/dark theme rebuilds) ───── */ + + private var darkMode = false + private var selectedTab: String = "openai" // chat | openai | metrics + private var modelExpanded = true + private var samplingExpanded = false + + private var selFamily: String = ModelIds.GEMMA4 + private var selRuntime: RuntimeKind = RuntimeKind.NATIVE + private var selBackend: BackendType = BackendType.NPU + private val selDescriptor: ModelDescriptor? + get() = ModelCatalog.resolve(selFamily, selRuntime, selBackend) + private var selectedQuant: QuantizationType = QuantizationType.W4A32 + + private var chatSelFamily: String = ModelIds.GEMMA4 + private var chatSelRuntime: RuntimeKind = RuntimeKind.NATIVE + private var chatSelBackend: BackendType = BackendType.NPU + private val chatSelDescriptor: ModelDescriptor? + get() = ModelCatalog.resolve(chatSelFamily, chatSelRuntime, chatSelBackend) + private var chatSelectedQuant: QuantizationType = QuantizationType.W4A32 + + private var modelBasePathText: String = "/sdcard/Download/aistudio-mobile/models/" + private var modelPathText: String = "" + private var promptText: String = "What is rainbow?" + + private var systemPromptText: String = "" + private var temperatureText: String = "" + private var topKText: String = "" + private var topPText: String = "" + private var seedText: String = "" + private var thinkingChoice: String = "default" // default | true | false + private var chatPromptText: String = "I bought a red car" + private var openAiJsonText: String = """[ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Hello!"}, + {"role": "assistant", "content": "Hi! How can I help?"}, + {"role": "system", "content": "Answer in one short sentence."}, + {"role": "user", "content": "Write a short joke about saving RAM."} +]""" + + private var statusText: String = "Idle." + private var outputText: String = "" + private var streaming: Boolean = false + private var loadStatus: String = "idle" // idle | loading | loaded + private var loadedLabel: String = "" + private var sessionIdText: String? = null + private var activeSessionKey: String? = null + private var lastMetrics: PerformanceMetrics? = null + + private var mainScrollY = 0 + private var outputScrollY = 0 + + /* ───── UI refs (re-wired on every rebuildUi() call) ───── */ + + private lateinit var rootHost: FrameLayout + private lateinit var mainScrollView: NestedScrollView + private lateinit var outputScrollView: NestedScrollView + private lateinit var statusView: TextView + private lateinit var outputView: TextView + private lateinit var modelBasePathField: EditText + private lateinit var modelPathField: EditText + private lateinit var promptField: EditText + private lateinit var imageStatusView: TextView + private lateinit var chatSystemPromptField: EditText + private lateinit var chatTemperatureField: EditText + private lateinit var chatTopKField: EditText + private lateinit var chatTopPField: EditText + private lateinit var chatSeedField: EditText + private lateinit var chatPromptField: EditText + private lateinit var openAIMessagesField: EditText + private lateinit var chatModelBasePathField: EditText + private lateinit var chatSessionStatusView: TextView + + /** + * @brief ActivityResult launcher for the Android system photo picker + * (single image). Uses [ActivityResultContracts.PickVisualMedia] which + * does NOT require any runtime permissions. + */ + private val imagePickerLauncher = registerForActivityResult( + ActivityResultContracts.PickVisualMedia() + ) { uri: Uri? -> + if (uri == null) { + setStatus("Image pick cancelled.") + return@registerForActivityResult + } + selectedImageBytesList.clear() + readImageBytesAsync(listOf(uri)) + } + + /** + * @brief ActivityResult launcher for multi-image photo picker. + * Uses [ActivityResultContracts.PickMultipleVisualMedia] which allows + * selecting multiple images for V-JEPA multi-image inference. + */ + private val multiImagePickerLauncher = registerForActivityResult( + ActivityResultContracts.PickMultipleVisualMedia() + ) { uris: List -> + if (uris.isEmpty()) { + setStatus("Image pick cancelled.") + return@registerForActivityResult + } + selectedImageBytesList.clear() + readImageBytesAsync(uris) + } + + override fun onCreate(savedInstanceState: Bundle?) { + super.onCreate(savedInstanceState) + checkAllFilesAccess() + // Seed the path field so a single Load tap works without typing. + modelPathText = defaultModelPathFor(selDescriptor, selectedQuant) ?: "" + rebuildUi() + } + + /** + * @brief Tear down and re-inflate the entire view tree using the + * current [darkMode] palette. EditText contents are preserved by + * round-tripping through the `*Text` state vars, which are kept in + * sync via [TextWatcher]s installed in the field builders below. + */ + private fun rebuildUi(resetModelPath: Boolean = false) { + val tokens = if (darkMode) DARK else LIGHT + // Snapshot any in-flight EditText contents into state vars so the + // theme rebuild does not lose typed input. + if (::promptField.isInitialized) promptText = promptField.text.toString() + if (::modelBasePathField.isInitialized) modelBasePathText = modelBasePathField.text.toString() + if (::chatModelBasePathField.isInitialized) modelBasePathText = chatModelBasePathField.text.toString() + if (::chatSystemPromptField.isInitialized) systemPromptText = chatSystemPromptField.text.toString() + if (::chatTemperatureField.isInitialized) temperatureText = chatTemperatureField.text.toString() + if (::chatTopKField.isInitialized) topKText = chatTopKField.text.toString() + if (::chatTopPField.isInitialized) topPText = chatTopPField.text.toString() + if (::chatSeedField.isInitialized) seedText = chatSeedField.text.toString() + if (::chatPromptField.isInitialized) chatPromptText = chatPromptField.text.toString() + if (::openAIMessagesField.isInitialized) openAiJsonText = openAIMessagesField.text.toString() + + // Save scroll positions before rebuilding + if (::mainScrollView.isInitialized) mainScrollY = mainScrollView.scrollY + if (::outputScrollView.isInitialized) outputScrollY = outputScrollView.scrollY + + rootHost = FrameLayout(this).apply { + layoutParams = ViewGroup.LayoutParams(MATCH_PARENT, MATCH_PARENT) + setBackgroundColor(tokens.bg) + fitsSystemWindows = true + } + val column = LinearLayout(this).apply { + orientation = LinearLayout.VERTICAL + layoutParams = FrameLayout.LayoutParams(MATCH_PARENT, MATCH_PARENT) + } + rootHost.addView(column) + + column.addView(buildTopBar(tokens)) + column.addView(buildHeroCard(tokens)) + column.addView(buildTabBar(tokens)) + + // Main scrolling content area β€” wraps the model section, the + // active tab body, and (for non-metrics tabs) the shared output + // panel. + mainScrollView = NestedScrollView(this).apply { + isFillViewport = false + layoutParams = LinearLayout.LayoutParams(MATCH_PARENT, 0, 1f) + overScrollMode = View.OVER_SCROLL_NEVER + } + val scrollColumn = LinearLayout(this).apply { + orientation = LinearLayout.VERTICAL + setPadding(dp(12), 0, dp(12), dp(120)) + layoutParams = LinearLayout.LayoutParams(MATCH_PARENT, WRAP_CONTENT) + } + mainScrollView.addView(scrollColumn) + column.addView(mainScrollView) + + if (selectedTab != "chat") { + scrollColumn.addView(buildModelSection(tokens)) + spacer(scrollColumn, 10) + } + + scrollColumn.addView(when (selectedTab) { + "chat" -> buildChatTab(tokens) + "openai" -> buildOpenAiTab(tokens) + "metrics" -> buildMetricsTab(tokens) + else -> buildOpenAiTab(tokens) + }) + + if (selectedTab != "metrics") { + spacer(scrollColumn, 10) + scrollColumn.addView(buildOutputPanel(tokens)) + } + + setContentView(rootHost) + + // Restore scroll positions after UI is built + mainScrollView.post { mainScrollView.scrollTo(0, mainScrollY) } + if (::outputScrollView.isInitialized) { + outputScrollView.post { outputScrollView.scrollTo(0, outputScrollY) } + } + } + + /* ════════════════════════════════════════════════════════════════ + * Component builders + * ════════════════════════════════════════════════════════════════ */ + + private fun buildTopBar(t: M3Tokens): View { + val row = LinearLayout(this).apply { + orientation = LinearLayout.HORIZONTAL + gravity = Gravity.CENTER_VERTICAL + setPadding(dp(16), dp(14), dp(16), dp(10)) + } + // Brand mark β€” gradient-filled rounded square with a "✦" glyph, + // approximating the hero icon in the design. + val brand = TextView(this).apply { + text = "✦" + setTextColor(Color.WHITE) + textSize = 18f + typeface = Typeface.DEFAULT_BOLD + gravity = Gravity.CENTER + background = gradient(t.primary, t.tertiary, 12) + layoutParams = LinearLayout.LayoutParams(dp(40), dp(40)) + } + row.addView(brand) + + spacerH(row, 12) + + val titleColumn = LinearLayout(this).apply { + orientation = LinearLayout.VERTICAL + layoutParams = LinearLayout.LayoutParams(0, WRAP_CONTENT, 1f) + } + val title = TextView(this).apply { + text = "QuickDotAI" + setTextColor(t.onSurface) + textSize = 18f + typeface = Typeface.DEFAULT_BOLD + } + val sub = TextView(this).apply { + val tail = if (loadStatus == "loaded") loadedLabel else "no model" + val tailColor = if (loadStatus == "loaded") t.success else t.onSurfaceVar + val full = "In-process AAR Β· $tail" + text = full + textSize = 11f + // Approximate the React design's two-tone subtitle by using + // the success color when a model is loaded, otherwise neutral. + setTextColor(tailColor) + } + titleColumn.addView(title) + titleColumn.addView(sub) + row.addView(titleColumn) + + // Light/dark toggle button. + val toggle = TextView(this).apply { + text = if (darkMode) "β˜€" else "☾" + textSize = 16f + gravity = Gravity.CENTER + setTextColor(t.onSurfaceVar) + background = solid(t.surfaceContainer, 20) + layoutParams = LinearLayout.LayoutParams(dp(40), dp(40)) + setOnClickListener { + darkMode = !darkMode + rebuildUi() + } + } + row.addView(toggle) + return row + } + + private fun buildHeroCard(t: M3Tokens): View { + val tone = statusTone() + val (bgColor, dotColor, fgColor) = when (tone) { + "error" -> Triple(t.errorContainer, t.error, t.error) + "success" -> Triple(t.successContainer, t.success, t.success) + "progress" -> Triple(t.primaryContainer, t.primary, t.onPrimaryContainer) + else -> Triple(t.surfaceContainer, t.outline, t.onSurfaceVar) + } + val outer = FrameLayout(this).apply { + setPadding(dp(12), 0, dp(12), dp(12)) + } + val card = LinearLayout(this).apply { + orientation = LinearLayout.HORIZONTAL + gravity = Gravity.CENTER_VERTICAL + setPadding(dp(14), dp(12), dp(14), dp(12)) + background = solid(bgColor, 20) + layoutParams = FrameLayout.LayoutParams(MATCH_PARENT, WRAP_CONTENT) + } + // Status indicator dot. + val dot = View(this).apply { + background = circle(dotColor) + layoutParams = LinearLayout.LayoutParams(dp(10), dp(10)) + } + card.addView(dot) + spacerH(card, 10) + + statusView = TextView(this).apply { + text = statusText + setTextColor(fgColor) + textSize = 13f + typeface = Typeface.MONOSPACE + ellipsize = android.text.TextUtils.TruncateAt.END + maxLines = 1 + layoutParams = LinearLayout.LayoutParams(0, WRAP_CONTENT, 1f) + } + card.addView(statusView) + + if (streaming) { + val stop = TextView(this).apply { + text = "β–  Stop" + setTextColor(Color.WHITE) + textSize = 12f + typeface = Typeface.DEFAULT_BOLD + background = solid(t.error, 100) + setPadding(dp(12), dp(6), dp(12), dp(6)) + setOnClickListener { onCancelClicked() } + } + card.addView(stop) + } + outer.addView(card) + return outer + } + + private fun buildTabBar(t: M3Tokens): View { + val outer = FrameLayout(this).apply { + setPadding(dp(12), 0, dp(12), dp(10)) + } + val pill = LinearLayout(this).apply { + orientation = LinearLayout.HORIZONTAL + background = solid(t.surfaceContainer, 100) + setPadding(dp(4), dp(4), dp(4), dp(4)) + layoutParams = FrameLayout.LayoutParams(MATCH_PARENT, WRAP_CONTENT) + } + val tabs = listOf("chat" to "⌬ Chat", + "openai" to "{ } OpenAI", "metrics" to "β–€ Metrics") + for ((idx, entry) in tabs.withIndex()) { + val (key, label) = entry + val active = key == selectedTab + val tab = TextView(this).apply { + text = label + gravity = Gravity.CENTER + textSize = 13f + typeface = if (active) Typeface.DEFAULT_BOLD else Typeface.DEFAULT + setTextColor(if (active) t.onPrimary else t.onSurfaceVar) + background = if (active) solid(t.primary, 100) else null + setPadding(dp(6), dp(10), dp(6), dp(10)) + layoutParams = LinearLayout.LayoutParams(0, WRAP_CONTENT, 1f).also { + if (idx > 0) it.leftMargin = dp(4) + } + setOnClickListener { + selectedTab = key + rebuildUi() + } + } + pill.addView(tab) + } + outer.addView(pill) + return outer + } + + private fun buildModelSection(t: M3Tokens): View { + val subtitle = if (loadStatus == "loaded") + "$loadedLabel Β· ${selBackend.name}" + else + "${selDescriptor?.displayName ?: selFamily} Β· ${selRuntime.name} Β· ${selBackend.name}" + val statusDotColor = when (loadStatus) { + "loaded" -> t.success + "loading" -> t.primary + else -> t.outline + } + val card = collapsibleCard( + t = t, + iconGlyph = "β–¦", + iconBg = t.primaryContainer, + iconFg = t.onPrimaryContainer, + title = "Model", + subtitle = subtitle, + rightAdornment = View(this).apply { + background = circle(statusDotColor) + layoutParams = LinearLayout.LayoutParams(dp(10), dp(10)).also { + it.rightMargin = dp(4) + } + }, + expanded = modelExpanded, + onToggle = { + modelExpanded = !modelExpanded + rebuildUi() + } + ) { body -> + // MODEL select β€” 3-axis cascading + body.addView(labelView(t, "FAMILY")) + body.addView(dropdownField(t, ModelCatalog.selectableFamilies(), selFamily) { picked -> + selFamily = picked + selRuntime = ModelCatalog.runtimesFor(selFamily).firstOrNull() ?: selRuntime + selBackend = ModelCatalog.backendsFor(selFamily, selRuntime).firstOrNull() ?: selBackend + modelPathText = defaultModelPathFor(selDescriptor, selectedQuant) ?: "" + rebuildUi(resetModelPath = true) + }) + spacer(body, 12) + + body.addView(labelView(t, "RUNTIME")) + body.addView(chipRow(t, ModelCatalog.runtimesFor(selFamily).map { it.name }, selRuntime.name) { picked -> + selRuntime = RuntimeKind.valueOf(picked) + selBackend = ModelCatalog.backendsFor(selFamily, selRuntime).firstOrNull() ?: selBackend + modelPathText = defaultModelPathFor(selDescriptor, selectedQuant) ?: "" + rebuildUi(resetModelPath = true) + }) + spacer(body, 12) + + body.addView(labelView(t, "BACKEND")) + body.addView(chipRow(t, ModelCatalog.backendsFor(selFamily, selRuntime).map { it.name }, selBackend.name) { picked -> + selBackend = BackendType.valueOf(picked) + modelPathText = defaultModelPathFor(selDescriptor, selectedQuant) ?: "" + rebuildUi(resetModelPath = true) + }) + spacer(body, 12) + + // MODEL BASE PATH β€” editable root directory for model files. + body.addView(labelView(t, "MODEL BASE PATH")) + modelBasePathField = roundedEditText(t, modelBasePathText, mono = true, + onTextChange = { modelBasePathText = it }) + body.addView(modelBasePathField) + spacer(body, 12) + + // MODEL NAME β€” read-only display of default folder name + error if missing. + body.addView(labelView(t, "MODEL NAME")) + body.addView(modelNameView(t, selDescriptor)) + spacer(body, 12) + + // Load / Unload action row. + val actions = LinearLayout(this).apply { orientation = LinearLayout.HORIZONTAL } + val loadLabel = if (loadStatus == "loaded") "↻ Reload model" else "↓ Load model" + actions.addView(filledButton(t, loadLabel, fill = "horizontal") { onLoadClicked() }) + if (loadStatus == "loaded") { + spacerH(actions, 8) + actions.addView(tonalButton(t, "βœ• Unload", danger = true) { onUnloadClicked() }) + } + body.addView(actions) + } + return card + } + + private fun buildRunTab(t: M3Tokens): View { + val card = roundedCard(t, t.surfaceContainer) + val header = sectionHeader(t, "⚑", t.secondaryContainer, t.onSurface, + "One-shot run", "Raw prompt Β· streaming output") + card.addView(header) + spacer(card, 14) + + // PROMPT. + card.addView(labelView(t, "PROMPT")) + promptField = roundedEditText(t, promptText, multiline = true, mono = true, rows = 5, + onTextChange = { promptText = it }) + card.addView(promptField) + spacer(card, 14) + + // IMAGE INPUT. + val imgLabelRow = LinearLayout(this).apply { + orientation = LinearLayout.HORIZONTAL + gravity = Gravity.CENTER_VERTICAL + } + imgLabelRow.addView(labelView(t, "IMAGE INPUT").also { + it.layoutParams = LinearLayout.LayoutParams(WRAP_CONTENT, WRAP_CONTENT) + }) + if (selDescriptor?.let { isMultimodal(it) } != true) { + spacerH(imgLabelRow, 6) + val badge = TextView(this).apply { + text = "GEMMA4 only" + setTextColor(t.tertiary) + textSize = 10f + typeface = Typeface.DEFAULT_BOLD + background = solid(t.tertiaryContainer, 4) + setPadding(dp(6), dp(1), dp(6), dp(1)) + } + imgLabelRow.addView(badge) + } + card.addView(imgLabelRow) + spacer(card, 6) + + if (selectedImageBytes != null) { + val attached = LinearLayout(this).apply { + orientation = LinearLayout.HORIZONTAL + gravity = Gravity.CENTER_VERTICAL + background = solid(t.surfaceContainerHigh, 12) + setPadding(dp(12), dp(12), dp(12), dp(12)) + } + val thumb = TextView(this).apply { + text = "πŸ–Ό" + setTextColor(t.onSurface) + textSize = 22f + gravity = Gravity.CENTER + background = gradient( + blendAlpha(t.primary, 0x55), + blendAlpha(t.tertiary, 0x55), 8 + ) + layoutParams = LinearLayout.LayoutParams(dp(48), dp(48)) + } + attached.addView(thumb) + spacerH(attached, 10) + val info = LinearLayout(this).apply { + orientation = LinearLayout.VERTICAL + layoutParams = LinearLayout.LayoutParams(0, WRAP_CONTENT, 1f) + } + info.addView(TextView(this).apply { + text = "Selected image" + setTextColor(t.onSurface) + textSize = 13f + typeface = Typeface.DEFAULT_BOLD + }) + val totalBytes = selectedImageBytesList.sumOf { it.size } + val modeLabel = if (selectedImageBytesList.size > 1) " Β· multi-image mode" else "" + info.addView(TextView(this).apply { + text = "${totalBytes} bytes Β· raw bytes ready$modeLabel" + setTextColor(t.onSurfaceVar) + textSize = 11f + typeface = Typeface.MONOSPACE + }) + attached.addView(info) + val close = TextView(this).apply { + text = "βœ•" + setTextColor(t.onSurfaceVar) + gravity = Gravity.CENTER + textSize = 14f + layoutParams = LinearLayout.LayoutParams(dp(32), dp(32)) + setOnClickListener { onClearImageClicked(); rebuildUi() } + } + attached.addView(close) + card.addView(attached) + // Keep the legacy imageStatusView reference happy β€” it is + // touched by onClearImageClicked / readImageBytesAsync. + imageStatusView = TextView(this).apply { visibility = View.GONE } + } else { + val pickLabel = if (isMultiImageModel(selDescriptor)) + "+ Pick images for V-JEPA" else "+ Pick image for multimodal run" + val dropzone = TextView(this).apply { + text = pickLabel + setTextColor(t.onSurfaceVar) + textSize = 13f + typeface = Typeface.DEFAULT_BOLD + gravity = Gravity.CENTER + background = dashedBg(t) + setPadding(dp(14), dp(14), dp(14), dp(14)) + setOnClickListener { onPickImageClicked() } + } + card.addView(dropzone) + imageStatusView = TextView(this).apply { visibility = View.GONE } + } + spacer(card, 14) + + // RUN button. + val runLabel = when { + streaming -> "β–  Stop streaming" + selectedImageBytes != null -> "β–Ά Run multimodal (streaming)" + else -> "β–Ά Run (streaming)" + } + val runBtn = filledButton(t, runLabel, fill = "vertical", + danger = streaming) { + if (streaming) onCancelClicked() else onRunClicked() + } + card.addView(runBtn) + return card + } + + private fun buildChatTab(t: M3Tokens): View { + val container = LinearLayout(this).apply { orientation = LinearLayout.VERTICAL } + + // ── Model Selection Card ── + val modelCard = LinearLayout(this).apply { + orientation = LinearLayout.VERTICAL + background = strokedSolid(t.surface, 24, t.outlineVariant, 1) + setPadding(dp(14), dp(14), dp(14), dp(14)) + } + val active = sessionIdText != null && activeSessionKey == loadedKey + + // Top row: status icon + session controls. + val modelHeaderRow = LinearLayout(this).apply { + orientation = LinearLayout.HORIZONTAL + gravity = Gravity.CENTER_VERTICAL + } + val modelIcon = TextView(this).apply { + text = "β–¦" + gravity = Gravity.CENTER + textSize = 18f + setTextColor(if (active) t.success else t.onSurfaceVar) + background = solid(if (active) t.successContainer else t.surfaceContainerHigh, 10) + layoutParams = LinearLayout.LayoutParams(dp(38), dp(38)) + } + modelHeaderRow.addView(modelIcon) + spacerH(modelHeaderRow, 12) + modelHeaderRow.addView(TextView(this).apply { + text = chatSelDescriptor?.let { badgeLabel(it) } ?: chatSelFamily + setTextColor(t.onSurface) + textSize = 13f + typeface = Typeface.MONOSPACE + layoutParams = LinearLayout.LayoutParams(0, WRAP_CONTENT, 1f) + }) + + if (active) { + chatSessionStatusView = TextView(this).apply { + text = "${sessionIdText!!.take(8)}…" + setTextColor(t.success) + textSize = 12f + typeface = Typeface.MONOSPACE + } + spacerH(modelHeaderRow, 8) + modelHeaderRow.addView(chatSessionStatusView) + spacerH(modelHeaderRow, 6) + modelHeaderRow.addView(tonalButton(t, "βœ• Close", danger = true) { onChatCloseClicked() }) + } else { + spacerH(modelHeaderRow, 8) + modelHeaderRow.addView(filledButton(t, "+ Open") { onChatOpenClicked() }) + } + modelCard.addView(modelHeaderRow) + spacer(modelCard, 12) + + // 3-axis cascading selection. + modelCard.addView(labelView(t, "FAMILY")) + modelCard.addView(dropdownField(t, ModelCatalog.selectableFamilies(), chatSelFamily) { picked -> + chatSelFamily = picked + chatSelRuntime = ModelCatalog.runtimesFor(chatSelFamily).firstOrNull() ?: chatSelRuntime + chatSelBackend = ModelCatalog.backendsFor(chatSelFamily, chatSelRuntime).firstOrNull() ?: chatSelBackend + clearChatSessionState() + rebuildUi() + }) + spacer(modelCard, 12) + + modelCard.addView(labelView(t, "RUNTIME")) + modelCard.addView(chipRow(t, ModelCatalog.runtimesFor(chatSelFamily).map { it.name }, chatSelRuntime.name) { picked -> + chatSelRuntime = RuntimeKind.valueOf(picked) + chatSelBackend = ModelCatalog.backendsFor(chatSelFamily, chatSelRuntime).firstOrNull() ?: chatSelBackend + clearChatSessionState() + rebuildUi() + }) + spacer(modelCard, 12) + + modelCard.addView(labelView(t, "BACKEND")) + modelCard.addView(chipRow(t, ModelCatalog.backendsFor(chatSelFamily, chatSelRuntime).map { it.name }, chatSelBackend.name) { picked -> + chatSelBackend = BackendType.valueOf(picked) + clearChatSessionState() + rebuildUi() + }) + spacer(modelCard, 12) + + // MODEL BASE PATH β€” editable root directory for model files. + modelCard.addView(labelView(t, "MODEL BASE PATH")) + chatModelBasePathField = roundedEditText(t, modelBasePathText, mono = true, + onTextChange = { modelBasePathText = it }) + modelCard.addView(chatModelBasePathField) + spacer(modelCard, 12) + + // MODEL NAME β€” read-only display of default folder name + error if missing. + modelCard.addView(labelView(t, "MODEL NAME")) + modelCard.addView(modelNameView(t, chatSelDescriptor)) + container.addView(modelCard) + spacer(container, 10) + + val chatImageCard = roundedCard(t, t.surfaceContainer) + val imgSubtitle = if (isMultiImageModel(chatSelDescriptor)) + "Select multiple images for V-JEPA" else "Attach one image to the next chat message" + chatImageCard.addView(sectionHeader(t, "[ ]", t.secondaryContainer, t.onSurface, + "Image input", imgSubtitle)) + spacer(chatImageCard, 12) + + val chatImgLabelRow = LinearLayout(this).apply { + orientation = LinearLayout.HORIZONTAL + gravity = Gravity.CENTER_VERTICAL + } + chatImgLabelRow.addView(labelView(t, "IMAGE INPUT").also { + it.layoutParams = LinearLayout.LayoutParams(WRAP_CONTENT, WRAP_CONTENT) + }) + if (chatSelDescriptor?.let { isMultimodal(it) } != true) { + spacerH(chatImgLabelRow, 6) + val badge = TextView(this).apply { + text = "Vision model only" + setTextColor(t.tertiary) + textSize = 10f + typeface = Typeface.DEFAULT_BOLD + background = solid(t.tertiaryContainer, 4) + setPadding(dp(6), dp(1), dp(6), dp(1)) + } + chatImgLabelRow.addView(badge) + } + chatImageCard.addView(chatImgLabelRow) + spacer(chatImageCard, 6) + + if (selectedImageBytesList.isNotEmpty()) { + val attached = LinearLayout(this).apply { + orientation = LinearLayout.HORIZONTAL + gravity = Gravity.CENTER_VERTICAL + background = solid(t.surfaceContainerHigh, 12) + setPadding(dp(12), dp(12), dp(12), dp(12)) + } + val thumb = TextView(this).apply { + text = if (selectedImageBytesList.size > 1) "[${selectedImageBytesList.size}]" else "[ ]" + setTextColor(t.onSurface) + textSize = if (selectedImageBytesList.size > 1) 16f else 22f + gravity = Gravity.CENTER + background = gradient( + blendAlpha(t.primary, 0x55), + blendAlpha(t.tertiary, 0x55), 8 + ) + layoutParams = LinearLayout.LayoutParams(dp(48), dp(48)) + } + attached.addView(thumb) + spacerH(attached, 10) + val info = LinearLayout(this).apply { + orientation = LinearLayout.VERTICAL + layoutParams = LinearLayout.LayoutParams(0, WRAP_CONTENT, 1f) + } + val totalBytes = selectedImageBytesList.sumOf { it.size } + info.addView(TextView(this).apply { + text = if (selectedImageBytesList.size == 1) "Selected image" else "${selectedImageBytesList.size} images selected" + setTextColor(t.onSurface) + textSize = 13f + typeface = Typeface.DEFAULT_BOLD + }) + info.addView(TextView(this).apply { + val modeLabel = if (selectedImageBytesList.size > 1) " Β· multi-image mode" else "" + text = "${totalBytes} bytes raw bytes ready$modeLabel" + setTextColor(t.onSurfaceVar) + textSize = 11f + typeface = Typeface.MONOSPACE + }) + attached.addView(info) + val close = TextView(this).apply { + text = "x" + setTextColor(t.onSurfaceVar) + gravity = Gravity.CENTER + textSize = 14f + layoutParams = LinearLayout.LayoutParams(dp(32), dp(32)) + setOnClickListener { onClearImageClicked(); rebuildUi() } + } + attached.addView(close) + chatImageCard.addView(attached) + imageStatusView = TextView(this).apply { visibility = View.GONE } + } else { + val pickLabel = if (isMultiImageModel(chatSelDescriptor)) + "+ Pick images for V-JEPA" else "+ Pick image for multimodal chat" + val dropzone = TextView(this).apply { + text = pickLabel + setTextColor(t.onSurfaceVar) + textSize = 13f + typeface = Typeface.DEFAULT_BOLD + gravity = Gravity.CENTER + background = dashedBg(t) + setPadding(dp(14), dp(14), dp(14), dp(14)) + setOnClickListener { onPickImageClicked() } + } + chatImageCard.addView(dropzone) + imageStatusView = TextView(this).apply { visibility = View.GONE } + } + container.addView(chatImageCard) + spacer(container, 10) + + // ── Collapsible session config ── + val configCard = collapsibleCard( + t = t, + iconGlyph = "βš™", + iconBg = t.tertiaryContainer, + iconFg = t.tertiary, + title = "Session config", + subtitle = "System prompt Β· sampling Β· thinking mode", + rightAdornment = null, + expanded = samplingExpanded, + onToggle = { + samplingExpanded = !samplingExpanded + rebuildUi() + } + ) { body -> + body.addView(labelView(t, "SYSTEM PROMPT")) + chatSystemPromptField = roundedEditText(t, systemPromptText, + multiline = true, rows = 2, + placeholder = "You are a helpful assistant.", + onTextChange = { systemPromptText = it }) + body.addView(chatSystemPromptField) + spacer(body, 10) + + // 2x2 grid of numeric fields. + val grid1 = LinearLayout(this).apply { orientation = LinearLayout.HORIZONTAL } + chatTemperatureField = roundedEditText(t, temperatureText, mono = true, numeric = true, + placeholder = "0.7", onTextChange = { temperatureText = it }) + chatTopKField = roundedEditText(t, topKText, mono = true, numeric = true, + placeholder = "40", onTextChange = { topKText = it }) + grid1.addView(labeledColumn(t, "TEMPERATURE", chatTemperatureField, weight = 1f)) + spacerH(grid1, 10) + grid1.addView(labeledColumn(t, "TOP_K", chatTopKField, weight = 1f)) + body.addView(grid1) + spacer(body, 10) + + val grid2 = LinearLayout(this).apply { orientation = LinearLayout.HORIZONTAL } + chatTopPField = roundedEditText(t, topPText, mono = true, numeric = true, + placeholder = "0.95", onTextChange = { topPText = it }) + chatSeedField = roundedEditText(t, seedText, mono = true, numeric = true, + placeholder = "random", onTextChange = { seedText = it }) + grid2.addView(labeledColumn(t, "TOP_P", chatTopPField, weight = 1f)) + spacerH(grid2, 10) + grid2.addView(labeledColumn(t, "SEED", chatSeedField, weight = 1f)) + body.addView(grid2) + spacer(body, 12) + + body.addView(labelView(t, "ENABLE_THINKING")) + body.addView(chipRow(t, listOf("default", "true", "false"), + thinkingChoice) { picked -> + thinkingChoice = picked + rebuildUi() + }) + } + container.addView(configCard) + spacer(container, 10) + + // ── Composer ── + val composer = LinearLayout(this).apply { + orientation = LinearLayout.VERTICAL + background = solid(t.surfaceContainer, 20) + setPadding(dp(14), dp(14), dp(14), dp(14)) + } + composer.addView(labelView(t, "CHAT MESSAGE")) + spacer(composer, 6) + val composerRow = LinearLayout(this).apply { + orientation = LinearLayout.HORIZONTAL + gravity = Gravity.CENTER_VERTICAL + background = solid(t.surfaceContainerHigh, 24) + setPadding(dp(6), dp(6), dp(6), dp(6)) + } + chatPromptField = EditText(this).apply { + setText(chatPromptText) + hint = "Type a chat message…" + setHintTextColor(t.onSurfaceVar) + setTextColor(t.onSurface) + textSize = 14f + background = null + minLines = 2 + maxLines = 6 + inputType = InputType.TYPE_CLASS_TEXT or InputType.TYPE_TEXT_FLAG_MULTI_LINE + setPadding(dp(14), dp(10), dp(14), dp(10)) + layoutParams = LinearLayout.LayoutParams(0, WRAP_CONTENT, 1f) + addTextChangedListener(simpleWatcher { chatPromptText = it }) + } + composerRow.addView(chatPromptField) + + val canSend = sessionIdText != null && activeSessionKey == loadedKey && !streaming + val sendFab = TextView(this).apply { + text = "β–Ά" + gravity = Gravity.CENTER + textSize = 18f + setTextColor(if (canSend) t.onPrimary else t.onSurfaceVar) + background = solid(if (canSend) t.primary else t.outlineVariant, 22) + layoutParams = LinearLayout.LayoutParams(dp(44), dp(44)).also { + it.bottomMargin = dp(2); it.rightMargin = dp(2) + } + isEnabled = canSend + setOnClickListener { onChatRunStreamingClicked() } + } + composerRow.addView(sendFab) + composer.addView(composerRow) + container.addView(composer) + return container + } + + private fun buildOpenAiTab(t: M3Tokens): View { + val card = roundedCard(t, t.surfaceContainer) + card.addView(sectionHeader(t, "{ }", t.secondaryContainer, t.onSurface, + "OpenAI messages", + "Role-interleaved array Β· forwarded to chat template")) + spacer(card, 12) + + // Parsed preview. + var parseErr: String? = null + val parsed: List? = try { + parseOpenAIMessages(openAiJsonText) + } catch (e: Throwable) { parseErr = e.message; null } + + if (parsed != null) { + val previewWrap = LinearLayout(this).apply { + orientation = LinearLayout.VERTICAL + background = solid(t.surfaceContainerHigh, 16) + setPadding(dp(10), dp(10), dp(10), dp(10)) + } + for (msg in parsed) { + val row = LinearLayout(this).apply { + orientation = LinearLayout.HORIZONTAL + gravity = Gravity.TOP + setPadding(0, dp(3), 0, dp(3)) + } + val (badgeBg, badgeFg, badgeLabel) = when (msg.role) { + QuickAiChatRole.SYSTEM -> Triple(t.tertiaryContainer, t.tertiary, "SYSTEM") + QuickAiChatRole.USER -> Triple(t.primaryContainer, t.onPrimaryContainer, "USER") + QuickAiChatRole.ASSISTANT -> Triple(t.secondaryContainer, t.onSurface, "ASSISTANT") + } + val badge = TextView(this).apply { + text = badgeLabel + setTextColor(badgeFg) + textSize = 10f + typeface = Typeface.MONOSPACE + gravity = Gravity.CENTER + background = solid(badgeBg, 100) + setPadding(dp(8), dp(2), dp(8), dp(2)) + layoutParams = LinearLayout.LayoutParams(dp(80), WRAP_CONTENT) + } + row.addView(badge) + spacerH(row, 8) + val content = TextView(this).apply { + val txtPart = msg.parts.firstOrNull() as? PromptPart.Text + text = txtPart?.text ?: "" + setTextColor(t.onSurface) + textSize = 13f + layoutParams = LinearLayout.LayoutParams(0, WRAP_CONTENT, 1f) + } + row.addView(content) + previewWrap.addView(row) + } + card.addView(previewWrap) + spacer(card, 12) + } + if (parseErr != null) { + val errBox = TextView(this).apply { + text = "β“˜ $parseErr" + setTextColor(t.error) + textSize = 12f + typeface = Typeface.MONOSPACE + background = solid(t.errorContainer, 10) + setPadding(dp(12), dp(8), dp(12), dp(8)) + } + card.addView(errBox) + spacer(card, 12) + } + + val openAiImageCard = roundedCard(t, t.surfaceContainerHigh) + openAiImageCard.addView(sectionHeader(t, "[ ]", t.secondaryContainer, t.onSurface, + "Image input", "Attach the selected image to the last user message")) + spacer(openAiImageCard, 12) + + val imageLabelRow = LinearLayout(this).apply { + orientation = LinearLayout.HORIZONTAL + gravity = Gravity.CENTER_VERTICAL + } + imageLabelRow.addView(labelView(t, "IMAGE INPUT").also { + it.layoutParams = LinearLayout.LayoutParams(WRAP_CONTENT, WRAP_CONTENT) + }) + if (selDescriptor?.let { isMultimodal(it) } != true) { + spacerH(imageLabelRow, 6) + imageLabelRow.addView(TextView(this).apply { + text = "Vision model only" + setTextColor(t.tertiary) + textSize = 10f + typeface = Typeface.DEFAULT_BOLD + background = solid(t.tertiaryContainer, 4) + setPadding(dp(6), dp(1), dp(6), dp(1)) + }) + } + openAiImageCard.addView(imageLabelRow) + spacer(openAiImageCard, 6) + + if (selectedImageBytesList.isNotEmpty()) { + val attached = LinearLayout(this).apply { + orientation = LinearLayout.HORIZONTAL + gravity = Gravity.CENTER_VERTICAL + background = solid(t.surfaceContainer, 12) + setPadding(dp(12), dp(12), dp(12), dp(12)) + } + attached.addView(TextView(this).apply { + text = if (selectedImageBytesList.size > 1) "[${selectedImageBytesList.size}]" else "[ ]" + setTextColor(t.onSurface) + textSize = if (selectedImageBytesList.size > 1) 16f else 22f + gravity = Gravity.CENTER + background = gradient( + blendAlpha(t.primary, 0x55), + blendAlpha(t.tertiary, 0x55), 8 + ) + layoutParams = LinearLayout.LayoutParams(dp(48), dp(48)) + }) + spacerH(attached, 10) + val info = LinearLayout(this).apply { + orientation = LinearLayout.VERTICAL + layoutParams = LinearLayout.LayoutParams(0, WRAP_CONTENT, 1f) + } + val totalBytes = selectedImageBytesList.sumOf { it.size } + info.addView(TextView(this).apply { + text = if (selectedImageBytesList.size == 1) "Selected image" else "${selectedImageBytesList.size} images selected" + setTextColor(t.onSurface) + textSize = 13f + typeface = Typeface.DEFAULT_BOLD + }) + info.addView(TextView(this).apply { + val modeLabel = if (selectedImageBytesList.size > 1) " Β· multi-image mode" else "" + text = "${totalBytes} bytes raw bytes ready$modeLabel" + setTextColor(t.onSurfaceVar) + textSize = 11f + typeface = Typeface.MONOSPACE + }) + attached.addView(info) + attached.addView(TextView(this).apply { + text = "x" + setTextColor(t.onSurfaceVar) + gravity = Gravity.CENTER + textSize = 14f + layoutParams = LinearLayout.LayoutParams(dp(32), dp(32)) + setOnClickListener { onClearImageClicked(); rebuildUi() } + }) + openAiImageCard.addView(attached) + imageStatusView = TextView(this).apply { visibility = View.GONE } + } else { + val pickLabel = if (isMultiImageModel(selDescriptor)) + "+ Pick images for V-JEPA" else "+ Pick image for OpenAI multimodal" + openAiImageCard.addView(TextView(this).apply { + text = pickLabel + setTextColor(t.onSurfaceVar) + textSize = 13f + typeface = Typeface.DEFAULT_BOLD + gravity = Gravity.CENTER + background = dashedBg(t) + setPadding(dp(14), dp(14), dp(14), dp(14)) + setOnClickListener { onPickImageClicked() } + }) + imageStatusView = TextView(this).apply { visibility = View.GONE } + } + card.addView(openAiImageCard) + spacer(card, 12) + + card.addView(labelView(t, "MESSAGES JSON")) + openAIMessagesField = roundedEditText(t, openAiJsonText, multiline = true, mono = true, rows = 8, + onTextChange = { + val same = it == openAiJsonText + openAiJsonText = it + // Re-render the preview so role pills track edits live. + if (!same) rebuildUi() + }) + card.addView(openAIMessagesField) + spacer(card, 12) + + val actions = LinearLayout(this).apply { orientation = LinearLayout.HORIZONTAL } + val runBtn = filledButton(t, "β–Ά Run (streaming)", fill = "horizontal") { + onOpenAIMessagesRunClicked() + }.apply { isEnabled = !streaming && parseErr == null } + actions.addView(runBtn) + spacerH(actions, 8) + val blockingBtn = tonalButton(t, "Blocking") { + onOpenAIMessagesRunBlockingClicked() + }.apply { isEnabled = !streaming && parseErr == null } + actions.addView(blockingBtn) + card.addView(actions) + return card + } + + private fun buildMetricsTab(t: M3Tokens): View { + val column = LinearLayout(this).apply { orientation = LinearLayout.VERTICAL } + val m = lastMetrics + if (m == null) { + val empty = roundedCard(t, t.surfaceContainer).apply { + gravity = Gravity.CENTER + setPadding(dp(20), dp(32), dp(20), dp(32)) + } + val icon = TextView(this).apply { + text = "β–€" + gravity = Gravity.CENTER + textSize = 28f + setTextColor(t.onSurfaceVar) + background = solid(t.surfaceContainerHigh, 28) + layoutParams = LinearLayout.LayoutParams(dp(56), dp(56)).also { + it.bottomMargin = dp(12); it.gravity = Gravity.CENTER_HORIZONTAL + } + } + empty.addView(icon) + empty.addView(TextView(this).apply { + text = "No metrics yet" + setTextColor(t.onSurface) + textSize = 15f + typeface = Typeface.DEFAULT_BOLD + gravity = Gravity.CENTER + }) + empty.addView(TextView(this).apply { + text = "Run a prompt and tap Fetch metrics in the Run tab to populate counters." + setTextColor(t.onSurfaceVar) + textSize = 13f + gravity = Gravity.CENTER + setPadding(0, dp(4), 0, 0) + }) + spacer(empty, 12) + empty.addView(filledButton(t, "Fetch metrics", fill = "vertical") { onMetricsClicked() }) + column.addView(empty) + return column + } + + val tps = if (m.generationDurationMs > 0) + String.format("%.1f", m.generationTokens / (m.generationDurationMs / 1000.0)) + else "β€”" + val ttft = String.format("%.0f", m.prefillDurationMs) + + // Big stat tile: tokens/sec. + column.addView(metricTile(t, "TOKENS PER SECOND", tps, "tok/s", big = true)) + spacer(column, 10) + + val grid = LinearLayout(this).apply { orientation = LinearLayout.HORIZONTAL } + grid.addView(metricTile(t, "TTFT", ttft, "ms", weight = 1f)) + spacerH(grid, 10) + grid.addView(metricTile(t, "TOTAL", String.format("%.2f", m.totalDurationMs / 1000.0), "s", weight = 1f)) + column.addView(grid) + spacer(column, 10) + + val grid2 = LinearLayout(this).apply { orientation = LinearLayout.HORIZONTAL } + grid2.addView(metricTile(t, "PREFILL TOKENS", m.prefillTokens.toString(), "tok", weight = 1f)) + spacerH(grid2, 10) + grid2.addView(metricTile(t, "GEN TOKENS", m.generationTokens.toString(), "tok", weight = 1f)) + column.addView(grid2) + spacer(column, 10) + + val grid3 = LinearLayout(this).apply { orientation = LinearLayout.HORIZONTAL } + grid3.addView(metricTile(t, "INIT", String.format("%.0f", m.initializationDurationMs), "ms", weight = 1f)) + spacerH(grid3, 10) + grid3.addView(metricTile(t, "PEAK MEMORY", + String.format("%.1f", m.peakMemoryKb / 1024.0), "MB", weight = 1f)) + column.addView(grid3) + spacer(column, 10) + + // Prefill β–Έ Gen bar. + val barCard = roundedCard(t, t.surfaceContainer) + barCard.addView(TextView(this).apply { + text = "PREFILL β–Έ GENERATION" + setTextColor(t.onSurfaceVar) + textSize = 11f + typeface = Typeface.DEFAULT_BOLD + setPadding(0, 0, 0, dp(10)) + }) + val total = (m.totalDurationMs).coerceAtLeast(1.0) + val prefillFrac = (m.prefillDurationMs / total).coerceIn(0.0, 1.0).toFloat() + val bar = LinearLayout(this).apply { + orientation = LinearLayout.HORIZONTAL + background = solid(t.surfaceContainerHigh, 4) + layoutParams = LinearLayout.LayoutParams(MATCH_PARENT, dp(8)) + } + bar.addView(View(this).apply { + background = solid(t.tertiary, 0) + layoutParams = LinearLayout.LayoutParams(0, MATCH_PARENT, prefillFrac) + }) + bar.addView(View(this).apply { + background = solid(t.primary, 0) + layoutParams = LinearLayout.LayoutParams(0, MATCH_PARENT, 1f - prefillFrac) + }) + barCard.addView(bar) + spacer(barCard, 8) + val barLegend = LinearLayout(this).apply { orientation = LinearLayout.HORIZONTAL } + barLegend.addView(TextView(this).apply { + text = "● prefill ${ttft}ms" + setTextColor(t.tertiary) + textSize = 11f + typeface = Typeface.MONOSPACE + layoutParams = LinearLayout.LayoutParams(0, WRAP_CONTENT, 1f) + }) + barLegend.addView(TextView(this).apply { + text = "● gen ${String.format("%.0f", m.generationDurationMs)}ms" + setTextColor(t.primary) + textSize = 11f + typeface = Typeface.MONOSPACE + gravity = Gravity.RIGHT + layoutParams = LinearLayout.LayoutParams(0, WRAP_CONTENT, 1f) + }) + barCard.addView(barLegend) + column.addView(barCard) + spacer(column, 10) + column.addView(tonalButton(t, "↻ Refresh metrics") { onMetricsClicked() }) + return column + } + + private fun buildOutputPanel(t: M3Tokens): View { + val wrap = LinearLayout(this).apply { + orientation = LinearLayout.VERTICAL + background = solid(t.codeBg, 20) + } + // Title bar with macOS-style traffic-light dots. + val titleBar = LinearLayout(this).apply { + orientation = LinearLayout.HORIZONTAL + gravity = Gravity.CENTER_VERTICAL + setPadding(dp(14), dp(10), dp(14), dp(10)) + } + val dotRow = LinearLayout(this).apply { orientation = LinearLayout.HORIZONTAL } + for (color in intArrayOf(0xFFFF5F57.toInt(), 0xFFFEBC2E.toInt(), 0xFF28C840.toInt())) { + val d = View(this).apply { + background = circle(color) + layoutParams = LinearLayout.LayoutParams(dp(10), dp(10)).also { + it.rightMargin = dp(4) + } + } + dotRow.addView(d) + } + titleBar.addView(dotRow) + titleBar.addView(TextView(this).apply { + text = "output Β· stream.kt" + setTextColor(0x80FFFFFF.toInt()) + textSize = 11f + typeface = Typeface.MONOSPACE + gravity = Gravity.CENTER + layoutParams = LinearLayout.LayoutParams(0, WRAP_CONTENT, 1f) + }) + val copy = TextView(this).apply { + text = "πŸ“‹" + setTextColor(0x80FFFFFF.toInt()) + gravity = Gravity.CENTER + textSize = 12f + layoutParams = LinearLayout.LayoutParams(dp(24), dp(24)) + setOnClickListener { + if (outputText.isNotEmpty()) { + val clipboard = getSystemService(Context.CLIPBOARD_SERVICE) as ClipboardManager + val clip = ClipData.newPlainText("output", outputText) + clipboard.setPrimaryClip(clip) + } + } + } + titleBar.addView(copy) + spacerH(titleBar, 8) + val clear = TextView(this).apply { + text = "πŸ—‘" + setTextColor(0x80FFFFFF.toInt()) + gravity = Gravity.CENTER + textSize = 12f + layoutParams = LinearLayout.LayoutParams(dp(24), dp(24)) + setOnClickListener { + outputText = "" + outputView.text = "" + } + } + titleBar.addView(clear) + wrap.addView(titleBar) + wrap.addView(View(this).apply { + setBackgroundColor(0x10FFFFFF) + layoutParams = LinearLayout.LayoutParams(MATCH_PARENT, 1) + }) + + outputScrollView = NestedScrollView(this).apply { + layoutParams = LinearLayout.LayoutParams(MATCH_PARENT, dp(220)) + } + outputView = TextView(this).apply { + text = if (outputText.isEmpty()) + "// streaming output appears here…" else outputText + setTextColor(if (outputText.isEmpty()) 0x4DFFFFFF else 0xFFEDE7F6.toInt()) + textSize = 13f + typeface = Typeface.MONOSPACE + setPadding(dp(16), dp(12), dp(16), dp(12)) + layoutParams = FrameLayout.LayoutParams(MATCH_PARENT, WRAP_CONTENT) + } + outputScrollView.addView(outputView) + wrap.addView(outputScrollView) + return wrap + } + + /* ════════════════════════════════════════════════════════════════ + * Reusable UI primitives (drawables, buttons, fields, chips, …) + * ════════════════════════════════════════════════════════════════ */ + + private fun roundedCard(t: M3Tokens, color: Int): LinearLayout { + return LinearLayout(this).apply { + orientation = LinearLayout.VERTICAL + background = solid(color, 20) + setPadding(dp(16), dp(16), dp(16), dp(16)) + } + } + + private fun collapsibleCard( + t: M3Tokens, + iconGlyph: String, iconBg: Int, iconFg: Int, + title: String, subtitle: String, + rightAdornment: View?, + expanded: Boolean, + onToggle: () -> Unit, + body: (LinearLayout) -> Unit, + ): View { + val card = LinearLayout(this).apply { + orientation = LinearLayout.VERTICAL + background = strokedSolid(t.surface, 24, t.outlineVariant, 1) + setPadding(dp(12), dp(12), dp(12), dp(12)) + } + // Header row. + val header = LinearLayout(this).apply { + orientation = LinearLayout.HORIZONTAL + gravity = Gravity.CENTER_VERTICAL + setPadding(dp(4), dp(6), dp(4), dp(6)) + isClickable = true + setOnClickListener { onToggle() } + } + val iconBox = TextView(this).apply { + text = iconGlyph + gravity = Gravity.CENTER + textSize = 16f + typeface = Typeface.DEFAULT_BOLD + setTextColor(iconFg) + background = solid(iconBg, 10) + layoutParams = LinearLayout.LayoutParams(dp(36), dp(36)) + } + header.addView(iconBox) + spacerH(header, 12) + val titleCol = LinearLayout(this).apply { + orientation = LinearLayout.VERTICAL + layoutParams = LinearLayout.LayoutParams(0, WRAP_CONTENT, 1f) + } + titleCol.addView(TextView(this).apply { + text = title + setTextColor(t.onSurface) + textSize = 15f + typeface = Typeface.DEFAULT_BOLD + }) + titleCol.addView(TextView(this).apply { + text = subtitle + setTextColor(t.onSurfaceVar) + textSize = 12f + }) + header.addView(titleCol) + if (rightAdornment != null) header.addView(rightAdornment) + header.addView(TextView(this).apply { + text = if (expanded) "β–²" else "β–Ό" + setTextColor(t.onSurfaceVar) + textSize = 12f + setPadding(dp(6), 0, 0, 0) + }) + card.addView(header) + if (expanded) { + val bodyContainer = LinearLayout(this).apply { + orientation = LinearLayout.VERTICAL + setPadding(dp(4), dp(12), dp(4), dp(4)) + } + body(bodyContainer) + card.addView(bodyContainer) + } + return card + } + + private fun sectionHeader( + t: M3Tokens, glyph: String, iconBg: Int, iconFg: Int, + title: String, subtitle: String + ): View { + val row = LinearLayout(this).apply { + orientation = LinearLayout.HORIZONTAL + gravity = Gravity.CENTER_VERTICAL + } + val icon = TextView(this).apply { + text = glyph + gravity = Gravity.CENTER + textSize = 16f + typeface = Typeface.DEFAULT_BOLD + setTextColor(iconFg) + background = solid(iconBg, 10) + layoutParams = LinearLayout.LayoutParams(dp(32), dp(32)) + } + row.addView(icon) + spacerH(row, 8) + val col = LinearLayout(this).apply { orientation = LinearLayout.VERTICAL } + col.addView(TextView(this).apply { + text = title + setTextColor(t.onSurface) + textSize = 15f + typeface = Typeface.DEFAULT_BOLD + }) + col.addView(TextView(this).apply { + text = subtitle + setTextColor(t.onSurfaceVar) + textSize = 12f + }) + row.addView(col) + return row + } + + private fun labelView(t: M3Tokens, text: String): TextView = TextView(this).apply { + this.text = text + setTextColor(t.onSurfaceVar) + textSize = 12f + typeface = Typeface.DEFAULT_BOLD + setPadding(dp(4), 0, 0, dp(6)) + } + + private fun roundedEditText( + t: M3Tokens, value: String, + multiline: Boolean = false, + mono: Boolean = false, + numeric: Boolean = false, + rows: Int = 1, + placeholder: String? = null, + onTextChange: (String) -> Unit, + ): EditText { + val baseBg = strokedSolid(t.surfaceContainer, 12, Color.TRANSPARENT, 0) + val focusBg = strokedSolid(t.surfaceContainer, 12, t.primary, 2) + val field = EditText(this).apply { + setText(value) + if (placeholder != null) { + hint = placeholder + setHintTextColor(t.onSurfaceVar) + } + setTextColor(t.onSurface) + textSize = 14f + background = baseBg + setPadding(dp(14), dp(12), dp(14), dp(12)) + if (mono) typeface = Typeface.MONOSPACE + if (multiline) { + inputType = InputType.TYPE_CLASS_TEXT or InputType.TYPE_TEXT_FLAG_MULTI_LINE + minLines = rows + maxLines = rows + 4 + setHorizontallyScrolling(false) + gravity = Gravity.TOP + } + if (numeric) { + inputType = InputType.TYPE_CLASS_NUMBER or + InputType.TYPE_NUMBER_FLAG_DECIMAL or + InputType.TYPE_NUMBER_FLAG_SIGNED + } + layoutParams = LinearLayout.LayoutParams(MATCH_PARENT, WRAP_CONTENT) + setOnFocusChangeListener { _, hasFocus -> + background = if (hasFocus) focusBg else baseBg + setPadding(dp(14), dp(12), dp(14), dp(12)) + } + addTextChangedListener(simpleWatcher(onTextChange)) + } + return field + } + + private fun labeledColumn(t: M3Tokens, label: String, field: View, weight: Float): View { + val col = LinearLayout(this).apply { + orientation = LinearLayout.VERTICAL + layoutParams = LinearLayout.LayoutParams(0, WRAP_CONTENT, weight) + } + col.addView(labelView(t, label)) + col.addView(field) + return col + } + + private fun chipRow(t: M3Tokens, options: List, selected: String, + onPick: (String) -> Unit): View { + val scroll = HorizontalScrollView(this).apply { + isHorizontalScrollBarEnabled = false + overScrollMode = View.OVER_SCROLL_NEVER + } + val row = LinearLayout(this).apply { orientation = LinearLayout.HORIZONTAL } + for ((i, opt) in options.withIndex()) { + val active = opt == selected + val chip = TextView(this).apply { + text = if (active) "βœ“ $opt" else opt + setTextColor(if (active) t.onSurface else t.onSurfaceVar) + textSize = 13f + typeface = if (active) Typeface.DEFAULT_BOLD else Typeface.DEFAULT + background = if (active) + solid(t.secondaryContainer, 8) + else + strokedSolid(Color.TRANSPARENT, 8, t.outline, 1) + setPadding(dp(12), dp(6), dp(12), dp(6)) + gravity = Gravity.CENTER + layoutParams = LinearLayout.LayoutParams(WRAP_CONTENT, WRAP_CONTENT).also { + if (i > 0) it.leftMargin = dp(6) + } + setOnClickListener { onPick(opt) } + } + row.addView(chip) + } + scroll.addView(row) + return scroll + } + + /** + * @brief Read-only view showing the default model name (folder name) + * derived from the model descriptor. Displays an error message if the + * expected folder does not exist under MODEL BASE PATH. + */ + private fun modelNameView(t: M3Tokens, descriptor: ModelDescriptor?): View { + val wrapper = LinearLayout(this).apply { + orientation = LinearLayout.VERTICAL + layoutParams = LinearLayout.LayoutParams(MATCH_PARENT, WRAP_CONTENT) + } + val relPath = descriptor?.let { modelPathById[it.id] } ?: descriptor?.id ?: "β€”" + val folderName = relPath.substringBefore('/') + val fullPath = "${modelBasePathText.trimEnd('/')}/$folderName" + val folderExists = File(fullPath).exists() + + val nameView = TextView(this).apply { + text = folderName + setTextColor(t.onSurface) + textSize = 14f + typeface = Typeface.DEFAULT_BOLD + background = strokedSolid(t.surfaceContainer, 8, t.outline, 1) + setPadding(dp(14), dp(12), dp(14), dp(12)) + layoutParams = LinearLayout.LayoutParams(MATCH_PARENT, WRAP_CONTENT) + } + wrapper.addView(nameView) + + if (!folderExists) { + wrapper.addView(TextView(this).apply { + text = "⚠ Folder not found: $fullPath" + setTextColor(t.error) + textSize = 11f + typeface = Typeface.MONOSPACE + setPadding(dp(4), dp(4), 0, 0) + layoutParams = LinearLayout.LayoutParams(MATCH_PARENT, WRAP_CONTENT) + }) + } + return wrapper + } + + /** + * @brief Dropdown field that lists subdirectories inside [basePath] + * and updates [modelPathText] with the full path when a model is picked. + * Falls back to the current [currentPath] as the displayed value if + * the directory is not readable or has no subdirectories. + */ + private fun modelNameDropdown(t: M3Tokens, currentPath: String, basePath: String): View { + // Scan the base path for subdirectories + val baseDir = File(basePath.trimEnd('/')) + val subDirs: List = if (baseDir.exists() && baseDir.isDirectory) { + baseDir.listFiles() + ?.filter { it.isDirectory } + ?.map { it.name } + ?.sorted() + ?: emptyList() + } else { + emptyList() + } + + // Determine display name: if currentPath starts with basePath, show the relative name + val displayPath = currentPath.trimEnd('/') + val displayValue = if (displayPath.startsWith(basePath.trimEnd('/')) && displayPath.length > basePath.trimEnd('/').length) { + displayPath.substring(basePath.trimEnd('/').length + 1) + } else { + displayPath.substringAfterLast('/') + } + + val field = LinearLayout(this).apply { + orientation = LinearLayout.HORIZONTAL + gravity = Gravity.CENTER_VERTICAL + background = strokedSolid(t.surfaceContainer, 8, t.outline, 1) + setPadding(dp(12), dp(10), dp(12), dp(10)) + layoutParams = LinearLayout.LayoutParams(MATCH_PARENT, WRAP_CONTENT) + } + val valueView = TextView(this).apply { + text = displayValue.ifEmpty { "β€” tap to select β€”" } + setTextColor(if (displayValue.isNotEmpty()) t.onSurface else t.onSurfaceVar) + textSize = 14f + typeface = Typeface.DEFAULT_BOLD + layoutParams = LinearLayout.LayoutParams(0, WRAP_CONTENT, 1f) + } + val arrow = TextView(this).apply { + text = "β–Ύ" + setTextColor(t.onSurfaceVar) + textSize = 14f + } + field.addView(valueView) + field.addView(arrow) + + field.setOnClickListener { anchor -> + // Build options: subdirectories + a "Custom…" option + val options = if (subDirs.isNotEmpty()) subDirs + "Custom…" else listOf("Custom…") + val menu = PopupMenu(this, anchor) + options.forEachIndexed { i, opt -> + menu.menu.add(0, i, i, opt).apply { + isCheckable = true + isChecked = opt == displayValue + } + } + menu.setOnMenuItemClickListener { item -> + val picked = options[item.itemId] + if (picked == "Custom…") { + // Switch to a free-form EditText for custom model name entry + modelPathField = roundedEditText(t, modelPathText, mono = true, + onTextChange = { modelPathText = it }) + // Replace the dropdown with the edit text in the parent + val parent = field.parent as? ViewGroup + val index = parent?.indexOfChild(field) ?: -1 + if (parent != null && index >= 0) { + parent.removeView(field) + parent.addView(modelPathField, index) + } + } else { + // Build the full path: basePath/picked + val newPath = "${basePath.trimEnd('/')}/$picked" + modelPathText = newPath + valueView.text = picked + setStatus("Model name: $picked") + } + true + } + menu.show() + } + return field + } + + private fun dropdownField(t: M3Tokens, options: List, selected: String, + onPick: (String) -> Unit): View { + val enabled = options.isNotEmpty() + val field = LinearLayout(this).apply { + orientation = LinearLayout.HORIZONTAL + gravity = Gravity.CENTER_VERTICAL + background = strokedSolid( + if (enabled) t.surfaceContainer else Color.TRANSPARENT, 8, t.outline, 1) + setPadding(dp(12), dp(10), dp(12), dp(10)) + layoutParams = LinearLayout.LayoutParams(MATCH_PARENT, WRAP_CONTENT) + } + val valueView = TextView(this).apply { + text = if (enabled) selected else "β€”" + setTextColor(if (enabled) t.onSurface else t.onSurfaceVar) + textSize = 14f + typeface = Typeface.DEFAULT_BOLD + layoutParams = LinearLayout.LayoutParams(0, WRAP_CONTENT, 1f) + } + val arrow = TextView(this).apply { + text = "β–Ύ" + setTextColor(t.onSurfaceVar) + textSize = 14f + } + field.addView(valueView) + field.addView(arrow) + if (enabled) { + field.setOnClickListener { anchor -> + val menu = PopupMenu(this, anchor) + options.forEachIndexed { i, opt -> + menu.menu.add(0, i, i, opt).apply { + isCheckable = true + isChecked = opt == selected + } + } + menu.setOnMenuItemClickListener { item -> + onPick(options[item.itemId]) + true + } + menu.show() + } + } + return field + } + + private fun filledButton(t: M3Tokens, label: String, fill: String? = null, + danger: Boolean = false, onClick: () -> Unit): Button { + return Button(this).apply { + text = label + isAllCaps = false + setTextColor(if (danger) Color.WHITE else t.onPrimary) + textSize = 14f + typeface = Typeface.DEFAULT_BOLD + stateListAnimator = null + background = solid(if (danger) t.error else t.primary, 100) + setPadding(dp(20), dp(12), dp(20), dp(12)) + layoutParams = when (fill) { + "vertical" -> LinearLayout.LayoutParams(MATCH_PARENT, WRAP_CONTENT) + "horizontal" -> LinearLayout.LayoutParams(0, WRAP_CONTENT, 1f) + else -> LinearLayout.LayoutParams(WRAP_CONTENT, WRAP_CONTENT) + } + setOnClickListener { onClick() } + } + } + + private fun tonalButton(t: M3Tokens, label: String, danger: Boolean = false, + onClick: () -> Unit): Button { + return Button(this).apply { + text = label + isAllCaps = false + setTextColor(if (danger) t.error else t.onSurface) + textSize = 13f + typeface = Typeface.DEFAULT_BOLD + stateListAnimator = null + background = solid(if (danger) t.errorContainer else t.secondaryContainer, 100) + setPadding(dp(14), dp(8), dp(14), dp(8)) + layoutParams = LinearLayout.LayoutParams(WRAP_CONTENT, WRAP_CONTENT) + setOnClickListener { onClick() } + } + } + + private fun outlinedButton(t: M3Tokens, label: String, + onClick: () -> Unit): Button { + return Button(this).apply { + text = label + isAllCaps = false + setTextColor(t.primary) + textSize = 13f + typeface = Typeface.DEFAULT_BOLD + stateListAnimator = null + background = strokedSolid(Color.TRANSPARENT, 100, t.outline, 1) + setPadding(dp(14), dp(8), dp(14), dp(8)) + layoutParams = LinearLayout.LayoutParams(WRAP_CONTENT, WRAP_CONTENT) + setOnClickListener { onClick() } + } + } + + private fun metricTile(t: M3Tokens, label: String, value: String, unit: String, + big: Boolean = false, weight: Float = 0f): View { + val tile = LinearLayout(this).apply { + orientation = LinearLayout.VERTICAL + background = solid(if (big) t.primaryContainer else t.surfaceContainer, 20) + setPadding(dp(if (big) 20 else 14), + dp(if (big) 20 else 14), + dp(if (big) 20 else 14), + dp(if (big) 20 else 14)) + layoutParams = if (weight > 0) + LinearLayout.LayoutParams(0, WRAP_CONTENT, weight) + else + LinearLayout.LayoutParams(MATCH_PARENT, WRAP_CONTENT) + } + val fg = if (big) t.onPrimaryContainer else t.onSurface + tile.addView(TextView(this).apply { + text = label + setTextColor(fg and 0x00FFFFFF or (0xB3 shl 24)) + textSize = 11f + typeface = Typeface.DEFAULT_BOLD + }) + val valueRow = LinearLayout(this).apply { + orientation = LinearLayout.HORIZONTAL + gravity = Gravity.BOTTOM + setPadding(0, dp(4), 0, 0) + } + valueRow.addView(TextView(this).apply { + text = value + setTextColor(fg) + textSize = if (big) 38f else 24f + typeface = Typeface.MONOSPACE + }) + valueRow.addView(TextView(this).apply { + text = " $unit" + setTextColor(fg and 0x00FFFFFF or (0x99 shl 24)) + textSize = if (big) 14f else 11f + typeface = Typeface.DEFAULT_BOLD + setPadding(dp(4), 0, 0, dp(if (big) 6 else 3)) + }) + tile.addView(valueRow) + return tile + } + + /* ───── Drawable / dimension helpers ───── */ + + private fun solid(color: Int, radiusDp: Int): GradientDrawable = GradientDrawable().apply { + setColor(color) + cornerRadius = dpf(radiusDp) + } + + private fun strokedSolid(fill: Int, radiusDp: Int, strokeColor: Int, strokeDp: Int): GradientDrawable = + GradientDrawable().apply { + setColor(fill) + cornerRadius = dpf(radiusDp) + if (strokeDp > 0) setStroke(dp(strokeDp), strokeColor) + } + + private fun circle(color: Int): GradientDrawable = GradientDrawable().apply { + shape = GradientDrawable.OVAL + setColor(color) + } + + private fun gradient(c1: Int, c2: Int, radiusDp: Int): GradientDrawable = + GradientDrawable(GradientDrawable.Orientation.TL_BR, intArrayOf(c1, c2)).apply { + cornerRadius = dpf(radiusDp) + } + + private fun dashedBg(t: M3Tokens): GradientDrawable = GradientDrawable().apply { + setColor(Color.TRANSPARENT) + cornerRadius = dpf(12) + setStroke(dp(1), t.outline, dpf(6), dpf(4)) + } + + private fun blendAlpha(color: Int, alpha: Int): Int = + (color and 0x00FFFFFF) or (alpha shl 24) + + private fun dp(v: Int): Int = TypedValue.applyDimension( + TypedValue.COMPLEX_UNIT_DIP, v.toFloat(), resources.displayMetrics + ).toInt() + + private fun dpf(v: Int): Float = TypedValue.applyDimension( + TypedValue.COMPLEX_UNIT_DIP, v.toFloat(), resources.displayMetrics + ) + + private fun spacer(parent: LinearLayout, h: Int) { + parent.addView(View(this).apply { + layoutParams = LinearLayout.LayoutParams(MATCH_PARENT, dp(h)) + }) + } + + private fun spacerH(parent: LinearLayout, w: Int) { + parent.addView(View(this).apply { + layoutParams = LinearLayout.LayoutParams(dp(w), MATCH_PARENT) + }) + } + + private fun simpleWatcher(onChange: (String) -> Unit): TextWatcher = object : TextWatcher { + override fun beforeTextChanged(s: CharSequence?, start: Int, count: Int, after: Int) {} + override fun onTextChanged(s: CharSequence?, start: Int, before: Int, count: Int) {} + override fun afterTextChanged(s: Editable?) { onChange(s?.toString() ?: "") } + } + + private fun statusTone(): String { + val s = statusText.lowercase() + return when { + "fail" in s || "error" in s || "empty" in s || "cancel" in s -> "error" + "done" in s || "loaded" in s || "opened" in s || "ready" in s -> "success" + "load" in s || "running" in s || "open" in s || "stream" in s -> "progress" + else -> "neutral" + } + } + + /** Pretty label including the design's "MULTIMODAL/TEXT/QNN" tag. */ + private fun badgeLabel(d: ModelDescriptor): String = when { + Capability.MULTIMODAL in d.capabilities -> "[MULTIMODAL] ${d.displayName}" + Capability.EMBEDDING in d.capabilities -> "[EMBEDDING] ${d.displayName}" + d.backends == setOf(BackendType.NPU) -> "[QNN] ${d.displayName}" + else -> "[TEXT] ${d.displayName}" + } + + private fun isMultimodal(d: ModelDescriptor) = Capability.MULTIMODAL in d.capabilities + private fun usesMessagesApi(d: ModelDescriptor) = Capability.MESSAGES_API in d.capabilities + + private fun visionBackendFor(d: ModelDescriptor, backend: BackendType): BackendType? = + if (isMultimodal(d)) backend else null + + /* ════════════════════════════════════════════════════════════════ + * Engine handlers (logic preserved from the original sample) + * ════════════════════════════════════════════════════════════════ */ + + override fun onDestroy() { + // Fire-and-forget close on the engine thread so we don't leak the + // native model handle / LiteRT-LM Engine on config changes. + // engine.close() internally closes any active chat session. + val e = engine + engine = null + loadedKey = null + if (e != null) { + engineExecutor.execute { + try { e.close() } catch (_: Throwable) { /* best effort */ } + } + } + super.onDestroy() + } + + private fun onLoadClicked() { + val req = buildLoadRequest() + loadStatus = "loading" + setStatus("Loading ${req.modelKey}… (vision=${req.visionBackend?.name ?: "off"})") + outputText = "" + mainHandler.post { rebuildUi() } + engineExecutor.execute { loadModelInternal(req) } + } + + private fun onCancelClicked() { + val e = engine + if (e != null && e.chatSessionId != null) { + e.chatCancel() + setStatus("Cancel requested.") + } else { + e?.cancel() + streaming = false + setStatus("Cancelled.") + mainHandler.post { rebuildUi() } + } + } + + /** + * @brief Build a [LoadModelRequest] from the current UI state. Must + * be called on the main thread. + */ + private fun buildLoadRequest(): LoadModelRequest { + val d = selDescriptor + val backend = selBackend + val quant = selectedQuant + val modelPath = modelPathText.trim().ifEmpty { null } + val nativeLibDir = applicationContext.applicationInfo.nativeLibraryDir + val basePath = (if (::modelBasePathField.isInitialized) modelBasePathField.text.toString() + else modelBasePathText).trim().ifEmpty { modelBasePathText } + return LoadModelRequest( + backend = backend, + modelId = d?.id ?: selFamily, + quantization = quant, + modelPath = modelPath, + visionBackend = d?.let { visionBackendFor(it, backend) }, + nativeLibDir = nativeLibDir, + modelBasePath = basePath, + ) + } + + private fun buildChatLoadRequest(): LoadModelRequest { + val d = chatSelDescriptor + val backend = chatSelBackend + val quant = chatSelectedQuant + val modelPath = defaultModelPathFor(d, quant) + val nativeLibDir = applicationContext.applicationInfo.nativeLibraryDir + val basePath = (if (::chatModelBasePathField.isInitialized) chatModelBasePathField.text.toString() + else modelBasePathText).trim().ifEmpty { modelBasePathText } + return LoadModelRequest( + backend = backend, + modelId = d?.id ?: chatSelFamily, + quantization = quant, + modelPath = modelPath, + visionBackend = d?.let { visionBackendFor(it, backend) }, + nativeLibDir = nativeLibDir, + modelBasePath = basePath, + ) + } + + /** + * @brief Core model loading logic. Must be called from [engineExecutor]. + * Returns the loaded [QuickDotAI] engine, or null on failure. + */ + private fun loadModelInternal(req: LoadModelRequest): QuickDotAI? { + if (loadedKey != null && loadedKey != req.modelKey) { + try { engine?.close() } catch (_: Throwable) { /* best effort */ } + engine = null + loadedKey = null + clearChatSessionState() + } + if (engine != null && loadedKey == req.modelKey) { + if (activeSessionKey != null && activeSessionKey != req.modelKey) { + clearChatSessionState() + } + loadDefaultOpenAIExampleFor(req.modelId) + loadStatus = "loaded" + loadedLabel = req.modelKey + setStatus("Already loaded: ${req.modelKey}") + mainHandler.post { rebuildUi() } + return engine + } + + val descriptor = ModelCatalog.byId(req.modelId) + val basePath = req.modelBasePath ?: modelBasePathText + val newEngine: QuickDotAI = if (descriptor != null) { + createEngine(applicationContext, descriptor, modelBasePath = basePath) + } else { + NativeQuickDotAI(applicationContext) // fallback + } + return when (val r = newEngine.load(req)) { + is BackendResult.Ok -> { + engine = newEngine + loadedKey = req.modelKey + clearChatSessionState() + loadDefaultOpenAIExampleFor(req.modelId) + loadStatus = "loaded" + loadedLabel = req.modelKey + setStatus("Loaded ${req.modelKey} (${newEngine.kind}, arch=${newEngine.architecture ?: "?"})") + mainHandler.post { rebuildUi() } + newEngine + } + is BackendResult.Err -> { + try { newEngine.close() } catch (_: Throwable) { /* best effort */ } + loadStatus = "idle" + clearChatSessionState() + setStatus("Load failed: [${r.error.name}] ${r.message ?: ""}") + mainHandler.post { rebuildUi() } + null + } + } + } + + private fun onRunClicked() { + val prompt = promptField.text.toString() + if (prompt.isBlank()) { + setStatus("Prompt is empty.") + return + } + val imgBytesList = selectedImageBytesList.toList() + outputText = "" + mainHandler.post { outputView.text = "" } + streaming = true + val imgStatus = if (imgBytesList.isEmpty()) "" + else if (imgBytesList.size == 1) "Running multimodal (${imgBytesList[0].size}B image)…" + else "Running multimodal (${imgBytesList.size} images)…" + setStatus(if (imgStatus.isNotEmpty()) imgStatus else "Running…") + mainHandler.post { rebuildUi() } + + engineExecutor.execute { + val e = engine + if (e == null) { + streaming = false + setStatus("No model loaded β€” tap Load first.") + mainHandler.post { rebuildUi() } + return@execute + } + val sink = object : StreamSink { + override fun onDelta(text: String) { + outputText += text + mainHandler.post { outputView.append(text) } + } + override fun onDone() { + streaming = false + setStatus("Done.") + mainHandler.post { rebuildUi() } + } + override fun onError(error: QuickAiError, message: String?) { + streaming = false + setStatus("Run failed: [${error.name}] ${message ?: ""}") + mainHandler.post { rebuildUi() } + } + } + try { + if (imgBytesList.isNotEmpty()) { + val parts = imgBytesList.map { PromptPart.ImageBytes(it) } + + listOf(PromptPart.Text(prompt)) + e.runMultimodalHandleStreaming(parts, sink) + } else { + // Run tab removed - use Chat or OpenAI tab instead + streaming = false + setStatus("Run tab removed. Use Chat or OpenAI tab.") + mainHandler.post { rebuildUi() } + } + } catch (t: Throwable) { + streaming = false + setStatus("Run threw: ${t.message}") + mainHandler.post { rebuildUi() } + } + } + } + + private fun onMetricsClicked() { + engineExecutor.execute { + val e = engine + if (e == null) { + setStatus("No model loaded.") + return@execute + } + when (val r = e.metrics()) { + is BackendResult.Ok -> { + lastMetrics = r.value + setStatus("Metrics fetched.") + mainHandler.post { rebuildUi() } + } + is BackendResult.Err -> + setStatus("Metrics failed: [${r.error.name}] ${r.message ?: ""}") + } + } + } + + private fun onUnloadClicked() { + val e = engine + e?.cancel() // Immediately cancel any in-flight inference (thread-safe) + + engineExecutor.execute { + val e = engine + if (e == null) { + clearChatSessionState() + setStatus("Nothing to unload.") + return@execute + } + if (e.chatSessionId != null) { + try { e.closeChatSession() } catch (_: Throwable) {} + clearChatSessionState() + } + when (val r = e.unload()) { + is BackendResult.Ok -> { + loadedKey = null + loadStatus = "idle" + loadedLabel = "" + clearChatSessionState() + setStatus("Unloaded.") + mainHandler.post { rebuildUi() } + } + is BackendResult.Err -> + setStatus("Unload failed: [${r.error.name}] ${r.message ?: ""}") + } + } + } + + /* ───── Chat session handlers ───── */ + + private fun onChatOpenClicked() { + val req = buildChatLoadRequest() + val systemPrompt = systemPromptText.trim().ifEmpty { null } + val temperature = temperatureText.trim().ifEmpty { null }?.toDoubleOrNull() + val topK = topKText.trim().ifEmpty { null }?.toIntOrNull() + val topP = topPText.trim().ifEmpty { null }?.toDoubleOrNull() + val seed = seedText.trim().ifEmpty { null }?.toIntOrNull() + + setStatus("Opening chat session…") + engineExecutor.execute { + val e = loadModelInternal(req) + if (e == null) { + setStatus("Cannot open chat session β€” model load failed.") + return@execute + } + if (e.chatSessionId != null) { + try { e.closeChatSession() } catch (_: Throwable) {} + clearChatSessionState() + } + + val sampling = if (temperature != null || topK != null || topP != null || seed != null) { + QuickAiChatSamplingConfig( + temperature = temperature, topK = topK, topP = topP, seed = seed + ) + } else null + + val templateKwargs = when (thinkingChoice) { + "true" -> QuickAiChatTemplateKwargs(enableThinking = true) + "false" -> QuickAiChatTemplateKwargs(enableThinking = false) + else -> null + } + + val config = if (systemPrompt != null || sampling != null || templateKwargs != null) { + QuickAiChatSessionConfig( + systemInstruction = systemPrompt, + sampling = sampling, + chatTemplateKwargs = templateKwargs + ) + } else null + + when (val r = e.openChatSession(config)) { + is BackendResult.Ok -> { + markChatSessionOpened(r.value) + setStatus("Chat session opened: ${r.value.take(8)}…") + mainHandler.post { rebuildUi() } + } + is BackendResult.Err -> { + clearChatSessionState() + setStatus("Chat open failed: [${r.error.name}] ${r.message ?: ""}") + } + } + } + } + + private fun onChatRunStreamingClicked() { + val prompt = normalizeVisionPromptText(chatPromptField.text.toString()) + if (prompt.isBlank()) { setStatus("Chat message is empty."); return } + val imgBytesList = selectedImageBytesList.toList() + if (imgBytesList.isNotEmpty() && chatSelDescriptor?.let { isMultimodal(it) } != true) { + setStatus("Selected chat model does not support image input.") + return + } + outputText = "" + outputView.text = "" + streaming = true + val imgStatus = if (imgBytesList.isEmpty()) "" + else if (imgBytesList.size == 1) "Chat multimodal streaming..." + else "Chat multimodal streaming (${imgBytesList.size} images)..." + setStatus(if (imgStatus.isNotEmpty()) imgStatus else "Chat streaming...") + mainHandler.post { rebuildUi() } + + engineExecutor.execute { + val e = engine + if (e == null) { + streaming = false + setStatus("No model loaded - tap Open first.") + mainHandler.post { rebuildUi() } + return@execute + } + if (imgBytesList.isEmpty() && (e.chatSessionId == null || activeSessionKey != loadedKey)) { + clearChatSessionState() + streaming = false + setStatus("No chat session - tap Open first.") + mainHandler.post { rebuildUi() } + return@execute + } + val sink = object : StreamSink { + override fun onDelta(text: String) { + outputText += text + mainHandler.post { outputView.append(text) } + } + override fun onDone() { + streaming = false + setStatus("Chat done.") + mainHandler.post { rebuildUi() } + } + override fun onError(error: QuickAiError, message: String?) { + streaming = false + setStatus("Chat error: [${error.name}] ${message ?: ""}") + mainHandler.post { rebuildUi() } + } + } + val parts = buildChatParts(prompt, imgBytesList) + if (imgBytesList.isNotEmpty()) { + try { + when (val r = e.runMultimodalHandleStreaming(parts, sink)) { + is BackendResult.Ok -> { + streaming = false + when (val metrics = e.metrics()) { + is BackendResult.Ok -> { + lastMetrics = metrics.value + setStatus("Chat multimodal done. (${metrics.value.totalDurationMs.toLong()} ms)") + } + is BackendResult.Err -> + setStatus("Chat multimodal done.") + } + mainHandler.post { rebuildUi() } + } + is BackendResult.Err -> { + streaming = false + mainHandler.post { rebuildUi() } + } + } + } catch (t: Throwable) { + streaming = false + setStatus("Chat threw: ${t.message}") + mainHandler.post { rebuildUi() } + } + return@execute + } + try { + when (val r = e.runChatModelHandleStreaming(prompt, sink)) { + is BackendResult.Ok -> { + streaming = false + lastMetrics = r.value.metrics ?: lastMetrics + setStatus("Chat done. (${r.value.metrics?.totalDurationMs?.toLong() ?: "?"} ms)") + mainHandler.post { rebuildUi() } + } + is BackendResult.Err -> { + streaming = false + mainHandler.post { rebuildUi() } + } + } + } catch (t: Throwable) { + streaming = false + setStatus("Chat threw: ${t.message}") + mainHandler.post { rebuildUi() } + } + } + } + + private fun onChatRunBlockingClicked() { + val prompt = normalizeVisionPromptText(chatPromptField.text.toString()) + if (prompt.isBlank()) { setStatus("Chat message is empty."); return } + val imgBytes = selectedImageBytes + if (imgBytes != null && chatSelDescriptor?.let { isMultimodal(it) } != true) { + setStatus("Selected chat model does not support image input.") + return + } + outputText = "" + outputView.text = "" + setStatus("Blocking chat API removed. Use streaming API.") + + engineExecutor.execute { + val e = engine + if (e == null || e.chatSessionId == null || activeSessionKey != loadedKey) { + clearChatSessionState() + setStatus("No chat session - tap Open first.") + return@execute + } + setStatus("Blocking chat API removed. Use streaming API.") + } + } + + private fun onChatRebuildClicked() { + setStatus("Rebuilding chat (clear history)…") + engineExecutor.execute { + val e = engine + if (e == null || e.chatSessionId == null || activeSessionKey != loadedKey) { + clearChatSessionState() + setStatus("No active chat session.") + return@execute + } + when (val r = e.chatRebuild(emptyList())) { + is BackendResult.Ok -> + setStatus("Chat history cleared. Session still active.") + is BackendResult.Err -> + setStatus("Chat rebuild failed: [${r.error.name}] ${r.message ?: ""}") + } + } + } + + private fun onChatCloseClicked() { + setStatus("Closing chat session…") + engineExecutor.execute { + val e = engine + if (e == null || e.chatSessionId == null || activeSessionKey != loadedKey) { + clearChatSessionState() + setStatus("No active chat session.") + return@execute + } + when (val r = e.closeChatSession()) { + is BackendResult.Ok -> { + clearChatSessionState() + setStatus("Chat session closed.") + mainHandler.post { rebuildUi() } + } + is BackendResult.Err -> + setStatus("Chat close failed: [${r.error.name}] ${r.message ?: ""}") + } + } + } + + private fun buildChatParts(prompt: String, imgBytesList: List): List { + return if (imgBytesList.isNotEmpty()) { + imgBytesList.map { PromptPart.ImageBytes(it) } + listOf(PromptPart.Text(prompt)) + } else { + listOf(PromptPart.Text(prompt)) + } + } + + /* ───── OpenAI-style messages handlers ───── */ + + private fun parseOpenAIContentParts(contentElement: JsonElement?): List { + if (contentElement == null) return listOf(PromptPart.Text("")) + if (contentElement is JsonArray) { + val parts = mutableListOf() + for (partElement in contentElement) { + val partObj = partElement as? JsonObject ?: continue + val type = partObj["type"]?.jsonPrimitive?.content?.lowercase() ?: continue + when (type) { + "text", "input_text" -> { + val text = normalizeVisionPromptText( + partObj["text"]?.jsonPrimitive?.content.orEmpty() + ) + if (text.isNotEmpty()) parts.add(PromptPart.Text(text)) + } + "image_url", "input_image" -> { + // SampleTestAPP keeps picked images in selectedImageBytes; + // the actual bytes are attached after parsing. + } + } + } + return parts.ifEmpty { listOf(PromptPart.Text("")) } + } + return listOf(PromptPart.Text( + normalizeVisionPromptText(contentElement.jsonPrimitive.content) + )) + } + + private fun parseOpenAIMessages( + jsonString: String, + attachedImageBytesList: List = emptyList() + ): List? { + return try { + val json = Json { ignoreUnknownKeys = true; isLenient = true } + val element = json.parseToJsonElement(jsonString) + + // Support both top-level array and {"messages": [...]} object + val jsonArray = when { + element is JsonArray -> element + element is JsonObject && element.containsKey("messages") -> + element["messages"]!!.jsonArray + else -> return null + } + + val messages = mutableListOf() + for (element in jsonArray) { + val obj = element.jsonObject + val role = obj["role"]?.jsonPrimitive?.content?.lowercase() ?: continue + val quickRole = when (role) { + "system" -> QuickAiChatRole.SYSTEM + "user" -> QuickAiChatRole.USER + "assistant" -> QuickAiChatRole.ASSISTANT + else -> continue + } + messages.add(QuickAiChatMessage( + role = quickRole, + parts = parseOpenAIContentParts(obj["content"]) + )) + } + if (attachedImageBytesList.isNotEmpty()) { + val lastUserIndex = messages.indexOfLast { it.role == QuickAiChatRole.USER } + if (lastUserIndex < 0) return null + val lastUser = messages[lastUserIndex] + val hasImage = lastUser.parts.any { + it is PromptPart.ImageBytes || it is PromptPart.ImageFile + } + if (!hasImage) { + // Attach all images as PromptPart.ImageBytes before existing parts + val imageParts = attachedImageBytesList.map { PromptPart.ImageBytes(it) } + messages[lastUserIndex] = lastUser.copy( + parts = imageParts + lastUser.parts + ) + } + } + messages + } catch (t: Throwable) { null } + } + + private fun onOpenAIMessagesRunClicked() { + val jsonText = openAIMessagesField.text.toString().trim() + if (jsonText.isBlank()) { setStatus("Messages JSON is empty."); return } + val imgBytesList = selectedImageBytesList.toList() + if (imgBytesList.isNotEmpty() && selDescriptor?.let { isMultimodal(it) } != true) { + setStatus("Selected model does not support OpenAI image input.") + return + } + outputText = "" + outputView.text = "" + streaming = true + val imgStatus = if (imgBytesList.isEmpty()) "" + else if (imgBytesList.size == 1) "Running OpenAI multimodal (streaming)..." + else "Running OpenAI multimodal (${imgBytesList.size} images, streaming)..." + setStatus(if (imgStatus.isNotEmpty()) imgStatus else "Running OpenAI JSON (streaming)...") + mainHandler.post { rebuildUi() } + + val req = buildLoadRequest() + engineExecutor.execute { + val e = loadModelInternal(req) + if (e == null) { + streaming = false; setStatus("Model load failed.") + mainHandler.post { rebuildUi() }; return@execute + } + if (e.chatSessionId != null) { + when (val closeResult = e.closeChatSession()) { + is BackendResult.Ok -> clearChatSessionState() + is BackendResult.Err -> { + streaming = false + setStatus("Failed to close chat session: ${closeResult.message ?: closeResult.error.name}") + mainHandler.post { rebuildUi() } + return@execute + } + } + } else if (sessionIdText != null) { + clearChatSessionState() + } + val sink = object : StreamSink { + override fun onDelta(text: String) { + outputText += text + mainHandler.post { outputView.append(text) } + } + override fun onDone() { + streaming = false; setStatus("OpenAI JSON done.") + mainHandler.post { rebuildUi() } + } + override fun onError(error: QuickAiError, message: String?) { + streaming = false; setStatus("OpenAI error: [${error.name}] ${message ?: ""}") + mainHandler.post { rebuildUi() } + } + } + try { + // Route based on model type: + // - LiteRT-LM (GEMMA4) only supports messages-based API. + // - All others use JSON streaming for full OpenAI format support. + if (imgBytesList.isNotEmpty()) { + // Attach all selected images to the last user message + val messages = parseOpenAIMessages(jsonText, attachedImageBytesList = imgBytesList) + if (messages == null) { + streaming = false + setStatus("Failed to parse messages JSON for multimodal API.") + mainHandler.post { rebuildUi() } + return@execute + } + when (val r = e.runMultimodalHandleWithMessagesStreaming(messages, sink)) { + is BackendResult.Ok -> { + streaming = false + setStatus("Done.") + mainHandler.post { rebuildUi() } + } + is BackendResult.Err -> { + streaming = false + setStatus("Failed: [${r.error.name}] ${r.message ?: ""}") + mainHandler.post { rebuildUi() } + } + } + } else if (selDescriptor?.let { usesMessagesApi(it) } == true) { + val messages = parseOpenAIMessages(jsonText) + if (messages == null) { + streaming = false + setStatus("Failed to parse messages JSON for Messages API.") + mainHandler.post { rebuildUi() } + return@execute + } + when (val r = e.runModelHandleWithMessagesStreaming(messages, sink)) { + is BackendResult.Ok -> { + streaming = false + setStatus("Done.") + mainHandler.post { rebuildUi() } + } + is BackendResult.Err -> { + streaming = false + setStatus("Failed: [${r.error.name}] ${r.message ?: ""}") + mainHandler.post { rebuildUi() } + } + } + } else { + // Use runWithJsonStreaming for full OpenAI format support + when (val r = e.runModelHandleWithJsonStreaming(jsonText, sink)) { + is BackendResult.Ok -> { + streaming = false + setStatus("Done.") + mainHandler.post { rebuildUi() } + } + is BackendResult.Err -> { + streaming = false + setStatus("Failed: [${r.error.name}] ${r.message ?: ""}") + mainHandler.post { rebuildUi() } + } + } + } + } catch (t: Throwable) { + streaming = false; setStatus("Threw: ${t.message}") + mainHandler.post { rebuildUi() } + } + } + } + + private fun onOpenAIMessagesRunBlockingClicked() { + val jsonText = openAIMessagesField.text.toString().trim() + if (jsonText.isBlank()) { setStatus("Messages JSON is empty."); return } + val messages = parseOpenAIMessages(jsonText) + if (messages == null) { setStatus("Failed to parse messages JSON. Check format."); return } + if (messages.isEmpty()) { setStatus("No messages found in JSON."); return } + if (messages.last().role != QuickAiChatRole.USER) { + setStatus("Last message must be role=\"user\" to trigger inference.") + return + } + outputText = ""; outputView.text = "" + setStatus("Running OpenAI messages (blocking)…") + + // Blocking API removed - use streaming API instead + setStatus("Blocking API removed. Use streaming API.") + // val req = buildLoadRequest() + // engineExecutor.execute { + // val e = loadModelInternal(req) + // if (e == null) { setStatus("Model load failed."); return@execute } + // try { + // when (val r = e.runModelHandleWithMessages(messages)) { + // is BackendResult.Ok -> { + // outputText = r.value + // mainHandler.post { outputView.text = r.value } + // setStatus("Done.") + // } + // is BackendResult.Err -> + // setStatus("Failed: [${r.error.name}] ${r.message ?: ""}") + // } + // } catch (t: Throwable) { setStatus("Threw: ${t.message}") } + // } + } + + /** Whether the currently selected model supports multi-image (V-JEPA). */ + private fun isMultiImageModel(d: ModelDescriptor?): Boolean = + d != null && Capability.MULTI_IMAGE in d.capabilities + + /* ───── Image picker handlers ───── */ + + private fun onPickImageClicked() { + // Use multi-image picker for V-JEPA models, single for others + if (isMultiImageModel(selDescriptor) || isMultiImageModel(chatSelDescriptor)) { + multiImagePickerLauncher.launch( + PickVisualMediaRequest(ActivityResultContracts.PickVisualMedia.ImageOnly) + ) + } else { + imagePickerLauncher.launch( + PickVisualMediaRequest(ActivityResultContracts.PickVisualMedia.ImageOnly) + ) + } + setStatus("Opening photo picker…") + } + + private fun onClearImageClicked() { + selectedImageBytesList.clear() + if (::imageStatusView.isInitialized) { + mainHandler.post { imageStatusView.text = "Image: none" } + } + setStatus("Image cleared.") + } + + private fun readImageBytesAsync(uris: List) { + val total = uris.size + setStatus(if (total == 1) "Reading image…" else "Reading $total images…") + Thread({ + try { + val bytesRead = mutableListOf() + for ((i, uri) in uris.withIndex()) { + val bytes = contentResolver.openInputStream(uri)?.use { it.readBytes() } + if (bytes == null || bytes.isEmpty()) { + setStatus("Image ${i + 1} read failed or empty."); return@Thread + } + bytesRead.add(bytes) + } + selectedImageBytesList.clear() + selectedImageBytesList.addAll(bytesRead) + val totalBytes = bytesRead.sumOf { it.size } + val statusMsg = if (bytesRead.size == 1) { + "Image loaded ($totalBytes bytes). Ready for Run or Send." + } else { + "${bytesRead.size} images loaded ($totalBytes bytes total). Multi-image mode." + } + setStatus(statusMsg) + mainHandler.post { rebuildUi() } + } catch (t: Throwable) { + setStatus("Failed to read image: ${t.message}") + } + }, "SampleTestAPP-ImageRead").apply { isDaemon = true }.start() + } + + /* ───── Misc helpers ───── */ + + private fun defaultVisionPrompt(): String = + "이미지 μ„€λͺ…ν•΄μ€˜<|image_start|><|image|><|image_end|>" + + private fun normalizeVisionPromptText(text: String): String = + text.replace("<|image_strart|>", "<|image_start|>") + + private val exampleByModelId: Map = mapOf( + ModelIds.GEMMA4 to """[ + {"role": "system", "content": "You are a concise vision assistant."}, + {"role": "user", "content": [{"type": "text", "text": "Describe this image."}, {"type": "image_url", "image_url": {"url": "sampletestapp://selected-image"}}]} +]""", + ModelIds.FUNCTION_GEMMA to """{ + "messages": [ + {"role": "system", "content": "You can call tools when they are useful."}, + {"role": "user", "content": "01012345678 번호둜 상담 μ˜ˆμ•½ 확인 문자λ₯Ό λ³΄λ‚΄μ€˜."} + ], + "tools": [{"type": "function", "function": {"name": "send_sms", "description": "Send a text message to a phone number.", "parameters": {"type": "object", "properties": {"phone_number": {"type": "string"}, "message": {"type": "string"}}, "required": ["phone_number", "message"]}}}] +}""", + ModelIds.TINY_BERT to """{ + "messages": [{"role": "user", "content": "Explain what text embeddings are in one sentence."}] +}""", + ) + // Shared template for MESSAGES_API models + private val messagesApiExample = """[ + {"role": "system", "content": "You are a helpful assistant. Answer briefly."}, + {"role": "user", "content": "Summarize why on-device language models are useful."} +]""" + private val defaultToolExample = """{ + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Write a short checklist for testing an Android API wrapper."} + ] +}""" + + private fun defaultOpenAIExampleForId(modelId: String): String { + exampleByModelId[modelId]?.let { return it } + val d = ModelCatalog.byId(modelId) + return when { + d != null && Capability.MESSAGES_API in d.capabilities -> messagesApiExample + else -> defaultToolExample + } + } + + private fun loadDefaultOpenAIExampleFor(modelId: String) { + openAiJsonText = defaultOpenAIExampleForId(modelId) + if (::openAIMessagesField.isInitialized) { + mainHandler.post { openAIMessagesField.setText(openAiJsonText) } + } + } + + private fun clearChatSessionState() { + sessionIdText = null + activeSessionKey = null + } + + private fun markChatSessionOpened(sessionId: String) { + sessionIdText = sessionId + activeSessionKey = loadedKey + } + + private fun setStatus(text: String) { + statusText = text + mainHandler.post { + if (::statusView.isInitialized) statusView.text = text + } + } + + /** + * @brief Builds the default on-device model path for the given + * (model, quantization) pair rooted in this app's external files + * dir, so the path lines up with the native C API's hardcoded + * `./models/-` prefix (resolve_model_path() in + * quick_dot_ai_api.cpp). + */ + private fun defaultModelPathFor(d: ModelDescriptor?, quant: QuantizationType): String? { + if (d == null) return null + val base = modelBasePathText.trimEnd('/') + return modelPathById[d.id]?.let { "$base/$it" } + ?: "$base/${d.id}" + } + + private val modelPathById: Map = mapOf( + ModelIds.GEMMA4 to "gemma-4-E2B-it/gemma-4-E2B-it.litertlm", + ModelIds.QWEN3_0_6B to "qwen3-0.6b", + ModelIds.QWEN3_1_7B_Q40 to "qwen3-1.7b-q40-arm", + ModelIds.TINY_BERT to "tiny-bert", + ModelIds.FUNCTION_GEMMA to "function_gemma", + ModelIds.GEMMA4_CPU to "gemma4_cpu", + ModelIds.GEMMA4_E2B_QNN to "gemma-4-e2b-qnn", + ModelIds.VJEPA_QNN to "vjepa-qnn", + ) + + private fun checkAllFilesAccess() { + if (Environment.isExternalStorageManager()) return + val intent = Intent( + Settings.ACTION_MANAGE_APP_ALL_FILES_ACCESS_PERMISSION, + Uri.parse("package:${packageName}") + ) + startActivity(intent) + } + + private fun quantizationSuffix(quant: QuantizationType): String = when (quant) { + QuantizationType.W4A32 -> "-w4a32" + QuantizationType.W16A16 -> "-w16a16" + QuantizationType.W8A16 -> "-w8a16" + QuantizationType.W32A32 -> "-w32a32" + QuantizationType.UNKNOWN -> "-w4a32" + } +} diff --git a/Android/SampleTestAPP/src/main/res/mipmap-hdpi/ic_launcher.png b/Android/SampleTestAPP/src/main/res/mipmap-hdpi/ic_launcher.png new file mode 100755 index 00000000..7d9fd438 Binary files /dev/null and b/Android/SampleTestAPP/src/main/res/mipmap-hdpi/ic_launcher.png differ diff --git a/Android/SampleTestAPP/src/main/res/mipmap-mdpi/ic_launcher.png b/Android/SampleTestAPP/src/main/res/mipmap-mdpi/ic_launcher.png new file mode 100755 index 00000000..dc1525a4 Binary files /dev/null and b/Android/SampleTestAPP/src/main/res/mipmap-mdpi/ic_launcher.png differ diff --git a/Android/SampleTestAPP/src/main/res/mipmap-xhdpi/ic_launcher.png b/Android/SampleTestAPP/src/main/res/mipmap-xhdpi/ic_launcher.png new file mode 100755 index 00000000..8cdc57aa Binary files /dev/null and b/Android/SampleTestAPP/src/main/res/mipmap-xhdpi/ic_launcher.png differ diff --git a/Android/SampleTestAPP/src/main/res/mipmap-xxhdpi/ic_launcher.png b/Android/SampleTestAPP/src/main/res/mipmap-xxhdpi/ic_launcher.png new file mode 100755 index 00000000..cd651b02 Binary files /dev/null and b/Android/SampleTestAPP/src/main/res/mipmap-xxhdpi/ic_launcher.png differ diff --git a/Android/SampleTestAPP/src/main/res/mipmap-xxxhdpi/ic_launcher.png b/Android/SampleTestAPP/src/main/res/mipmap-xxxhdpi/ic_launcher.png new file mode 100755 index 00000000..5a2aa5dc Binary files /dev/null and b/Android/SampleTestAPP/src/main/res/mipmap-xxxhdpi/ic_launcher.png differ diff --git a/Android/SampleTestAPP/src/main/res/values/strings.xml b/Android/SampleTestAPP/src/main/res/values/strings.xml new file mode 100644 index 00000000..5da7c27a --- /dev/null +++ b/Android/SampleTestAPP/src/main/res/values/strings.xml @@ -0,0 +1,4 @@ + + + Quick.AI + diff --git a/Android/SampleTestAPP/src/main/res/values/themes.xml b/Android/SampleTestAPP/src/main/res/values/themes.xml new file mode 100644 index 00000000..f355f587 --- /dev/null +++ b/Android/SampleTestAPP/src/main/res/values/themes.xml @@ -0,0 +1,31 @@ + + + + + + diff --git a/Android/build.gradle.kts b/Android/build.gradle.kts new file mode 100644 index 00000000..55f06703 --- /dev/null +++ b/Android/build.gradle.kts @@ -0,0 +1,7 @@ +// Top-level build file where you can add configuration options common to all sub-projects/modules. +plugins { + alias(libs.plugins.android.application) apply false + alias(libs.plugins.android.library) apply false + alias(libs.plugins.kotlin.android) apply false + alias(libs.plugins.kotlin.serialization) apply false +} diff --git a/Android/gradle.properties b/Android/gradle.properties new file mode 100644 index 00000000..42acdd7f --- /dev/null +++ b/Android/gradle.properties @@ -0,0 +1,34 @@ +# Project-wide Gradle settings. +# IDE (e.g. Android Studio) users: +# Gradle settings configured through the IDE *will override* +# any settings specified in this file. +# For more details on how to configure your build environment visit +# http://www.gradle.org/docs/current/userguide/build_environment.html +# Specifies the JVM arguments used for the daemon process. +# The setting is particularly useful for tweaking memory settings. +org.gradle.jvmargs=-Xmx2048m -Dfile.encoding=UTF-8 +# When configured, Gradle will run in incubating parallel mode. +# This option should only be used with decoupled projects. For more details, visit +# https://developer.android.com/r/tools/gradle-multi-project-decoupled-projects +# org.gradle.parallel=true +# Kotlin code style for this project: "official" or "obsolete": +kotlin.code.style=official + +# AndroidX is the namespace every modern AndroidX / material / activity +# dependency in this build lives under. AGP 9.x had it on by default; +# AGP 8.x still requires this flag to be set explicitly or the +# :checkDebugAarMetadata task fails with a "Configuration contains +# AndroidX dependencies, but the `android.useAndroidX` property is +# not enabled" error. +android.useAndroidX=true +# Non-transitive R classes are the modern best practice β€” each module +# gets its own R class containing only its own resources, which +# speeds up incremental builds and makes resource ownership explicit. +# Off by default on AGP 8.x, on by default on AGP 9.x. +android.nonTransitiveRClass=true +# AGP 8.9.0's internal "tested compileSdk" table caps out at 35, and +# emits a loud WARNING for compileSdk = 36 even though API 36 actually +# works fine on 8.9.x. Suppress the warning so CI logs stay readable. +# This line can be removed once we roll forward to AGP 8.10+ which +# was tested against API 36 directly. +android.suppressUnsupportedCompileSdk=36 \ No newline at end of file diff --git a/Android/gradle/libs.versions.toml b/Android/gradle/libs.versions.toml new file mode 100644 index 00000000..6dd83080 --- /dev/null +++ b/Android/gradle/libs.versions.toml @@ -0,0 +1,72 @@ +[versions] +# Toolchain pinned to versions proven to be resolvable from the +# target environment's Maven mirror, determined empirically from +# build failure logs: +# +# - kotlin-stdlib/kotlin-reflect 2.2.21 appeared in ~/.gradle/caches +# as transitive deps of litertlm-android, proving 2.2.21 is in the +# mirror. Kotlin 2.2.21 also matches what most Kotlin 2.2-compiled +# LiteRT-LM artifacts expect, so we pin the Kotlin compiler here. +# +# - litertlm-android 0.10.0 is the concrete version the mirror serves +# when we ask for `latest.release`. LiteRT-LM 1.0.1 (which we +# originally targeted) is NOT in the mirror. 0.10.0's public API +# matches the one we code against (Content / Contents / Backend / +# EngineConfig with visionBackend / sendMessage(Contents) / etc), +# so pinning down a minor version does not require any code +# changes. +# +# - AGP 8.9.1 is the last AGP that ships the old Kotlin serialization +# plugin default wiring and is compatible with Kotlin 2.2.21 (the +# combination emits a "Kotlin version tested up to 2.1" warning +# but the build succeeds β€” suppressed via +# kotlin.compiler.suppressExperimentalICOptimizationsWarning). +# +# - Gradle 8.11.1 is the paired Gradle for AGP 8.9.x. +# +# LiteRT-LM 0.10.0 itself was compiled with Kotlin 2.3.0; Kotlin +# 2.2.21's compiler cannot natively read 2.3.0-stamped metadata, so +# every module adds `-Xskip-metadata-version-check` to +# kotlinOptions.freeCompilerArgs. The flag is safe here because 2.3 +# only added new metadata fields on top of 2.2, and every LiteRT-LM +# symbol we actually call predates 2.3. +agp = "8.9.1" +kotlin = "2.2.21" +kotlinxSerialization = "1.7.3" +kotlinxCoroutines = "1.9.0" +okhttp = "4.12.0" +nanohttpd = "2.3.1" +litertlm = "0.10.0" +coreKtx = "1.18.0" +junit = "4.13.2" +junitVersion = "1.3.0" +espressoCore = "3.7.0" +appcompat = "1.7.1" +material = "1.13.0" +constraintlayout = "2.2.1" +activity = "1.13.0" +lifecycleRuntimeKtx = "2.9.0" + +[libraries] +androidx-core-ktx = { group = "androidx.core", name = "core-ktx", version.ref = "coreKtx" } +junit = { group = "junit", name = "junit", version.ref = "junit" } +androidx-junit = { group = "androidx.test.ext", name = "junit", version.ref = "junitVersion" } +androidx-espresso-core = { group = "androidx.test.espresso", name = "espresso-core", version.ref = "espressoCore" } +androidx-appcompat = { group = "androidx.appcompat", name = "appcompat", version.ref = "appcompat" } +material = { group = "com.google.android.material", name = "material", version.ref = "material" } +androidx-constraintlayout = { group = "androidx.constraintlayout", name = "constraintlayout", version.ref = "constraintlayout" } +androidx-activity = { group = "androidx.activity", name = "activity", version.ref = "activity" } +androidx-lifecycle-runtime-ktx = { group = "androidx.lifecycle", name = "lifecycle-runtime-ktx", version.ref = "lifecycleRuntimeKtx" } + +# QuickAI additions +nanohttpd = { group = "org.nanohttpd", name = "nanohttpd", version.ref = "nanohttpd" } +kotlinx-serialization-json = { group = "org.jetbrains.kotlinx", name = "kotlinx-serialization-json", version.ref = "kotlinxSerialization" } +kotlinx-coroutines-android = { group = "org.jetbrains.kotlinx", name = "kotlinx-coroutines-android", version.ref = "kotlinxCoroutines" } +okhttp = { group = "com.squareup.okhttp3", name = "okhttp", version.ref = "okhttp" } +litertlm-android = { group = "com.google.ai.edge.litertlm", name = "litertlm-android", version.ref = "litertlm" } + +[plugins] +android-application = { id = "com.android.application", version.ref = "agp" } +android-library = { id = "com.android.library", version.ref = "agp" } +kotlin-android = { id = "org.jetbrains.kotlin.android", version.ref = "kotlin" } +kotlin-serialization = { id = "org.jetbrains.kotlin.plugin.serialization", version.ref = "kotlin" } diff --git a/Android/gradle/wrapper/gradle-wrapper.jar b/Android/gradle/wrapper/gradle-wrapper.jar new file mode 100644 index 00000000..8bdaf60c Binary files /dev/null and b/Android/gradle/wrapper/gradle-wrapper.jar differ diff --git a/Android/gradle/wrapper/gradle-wrapper.properties b/Android/gradle/wrapper/gradle-wrapper.properties new file mode 100644 index 00000000..e2847c82 --- /dev/null +++ b/Android/gradle/wrapper/gradle-wrapper.properties @@ -0,0 +1,7 @@ +distributionBase=GRADLE_USER_HOME +distributionPath=wrapper/dists +distributionUrl=https\://services.gradle.org/distributions/gradle-8.11.1-bin.zip +networkTimeout=10000 +validateDistributionUrl=true +zipStoreBase=GRADLE_USER_HOME +zipStorePath=wrapper/dists diff --git a/Android/gradlew b/Android/gradlew new file mode 100755 index 00000000..ef07e016 --- /dev/null +++ b/Android/gradlew @@ -0,0 +1,251 @@ +#!/bin/sh + +# +# Copyright Β© 2015 the original authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 +# + +############################################################################## +# +# Gradle start up script for POSIX generated by Gradle. +# +# Important for running: +# +# (1) You need a POSIX-compliant shell to run this script. If your /bin/sh is +# noncompliant, but you have some other compliant shell such as ksh or +# bash, then to run this script, type that shell name before the whole +# command line, like: +# +# ksh Gradle +# +# Busybox and similar reduced shells will NOT work, because this script +# requires all of these POSIX shell features: +# * functions; +# * expansions Β«$varΒ», Β«${var}Β», Β«${var:-default}Β», Β«${var+SET}Β», +# Β«${var#prefix}Β», Β«${var%suffix}Β», and Β«$( cmd )Β»; +# * compound commands having a testable exit status, especially Β«caseΒ»; +# * various built-in commands including Β«commandΒ», Β«setΒ», and Β«ulimitΒ». +# +# Important for patching: +# +# (2) This script targets any POSIX shell, so it avoids extensions provided +# by Bash, Ksh, etc; in particular arrays are avoided. +# +# The "traditional" practice of packing multiple parameters into a +# space-separated string is a well documented source of bugs and security +# problems, so this is (mostly) avoided, by progressively accumulating +# options in "$@", and eventually passing that to Java. +# +# Where the inherited environment variables (DEFAULT_JVM_OPTS, JAVA_OPTS, +# and GRADLE_OPTS) rely on word-splitting, this is performed explicitly; +# see the in-line comments for details. +# +# There are tweaks for specific operating systems such as AIX, CygWin, +# Darwin, MinGW, and NonStop. +# +# (3) This script is generated from the Groovy template +# https://github.com/gradle/gradle/blob/HEAD/platforms/jvm/plugins-application/src/main/resources/org/gradle/api/internal/plugins/unixStartScript.txt +# within the Gradle project. +# +# You can find Gradle at https://github.com/gradle/gradle/. +# +############################################################################## + +# Attempt to set APP_HOME + +# Resolve links: $0 may be a link +app_path=$0 + +# Need this for daisy-chained symlinks. +while + APP_HOME=${app_path%"${app_path##*/}"} # leaves a trailing /; empty if no leading path + [ -h "$app_path" ] +do + ls=$( ls -ld "$app_path" ) + link=${ls#*' -> '} + case $link in #( + /*) app_path=$link ;; #( + *) app_path=$APP_HOME$link ;; + esac +done + +# This is normally unused +# shellcheck disable=SC2034 +APP_BASE_NAME=${0##*/} +# Discard cd standard output in case $CDPATH is set (https://github.com/gradle/gradle/issues/25036) +APP_HOME=$( cd -P "${APP_HOME:-./}" > /dev/null && printf '%s\n' "$PWD" ) || exit + +# Use the maximum available, or set MAX_FD != -1 to use that value. +MAX_FD=maximum + +warn () { + echo "$*" +} >&2 + +die () { + echo + echo "$*" + echo + exit 1 +} >&2 + +# OS specific support (must be 'true' or 'false'). +cygwin=false +msys=false +darwin=false +nonstop=false +case "$( uname )" in #( + CYGWIN* ) cygwin=true ;; #( + Darwin* ) darwin=true ;; #( + MSYS* | MINGW* ) msys=true ;; #( + NONSTOP* ) nonstop=true ;; +esac + +CLASSPATH="\\\"\\\"" + + +# Determine the Java command to use to start the JVM. +if [ -n "$JAVA_HOME" ] ; then + if [ -x "$JAVA_HOME/jre/sh/java" ] ; then + # IBM's JDK on AIX uses strange locations for the executables + JAVACMD=$JAVA_HOME/jre/sh/java + else + JAVACMD=$JAVA_HOME/bin/java + fi + if [ ! -x "$JAVACMD" ] ; then + die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME + +Please set the JAVA_HOME variable in your environment to match the +location of your Java installation." + fi +else + JAVACMD=java + if ! command -v java >/dev/null 2>&1 + then + die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. + +Please set the JAVA_HOME variable in your environment to match the +location of your Java installation." + fi +fi + +# Increase the maximum file descriptors if we can. +if ! "$cygwin" && ! "$darwin" && ! "$nonstop" ; then + case $MAX_FD in #( + max*) + # In POSIX sh, ulimit -H is undefined. That's why the result is checked to see if it worked. + # shellcheck disable=SC2039,SC3045 + MAX_FD=$( ulimit -H -n ) || + warn "Could not query maximum file descriptor limit" + esac + case $MAX_FD in #( + '' | soft) :;; #( + *) + # In POSIX sh, ulimit -n is undefined. That's why the result is checked to see if it worked. + # shellcheck disable=SC2039,SC3045 + ulimit -n "$MAX_FD" || + warn "Could not set maximum file descriptor limit to $MAX_FD" + esac +fi + +# Collect all arguments for the java command, stacking in reverse order: +# * args from the command line +# * the main class name +# * -classpath +# * -D...appname settings +# * --module-path (only if needed) +# * DEFAULT_JVM_OPTS, JAVA_OPTS, and GRADLE_OPTS environment variables. + +# For Cygwin or MSYS, switch paths to Windows format before running java +if "$cygwin" || "$msys" ; then + APP_HOME=$( cygpath --path --mixed "$APP_HOME" ) + CLASSPATH=$( cygpath --path --mixed "$CLASSPATH" ) + + JAVACMD=$( cygpath --unix "$JAVACMD" ) + + # Now convert the arguments - kludge to limit ourselves to /bin/sh + for arg do + if + case $arg in #( + -*) false ;; # don't mess with options #( + /?*) t=${arg#/} t=/${t%%/*} # looks like a POSIX filepath + [ -e "$t" ] ;; #( + *) false ;; + esac + then + arg=$( cygpath --path --ignore --mixed "$arg" ) + fi + # Roll the args list around exactly as many times as the number of + # args, so each arg winds up back in the position where it started, but + # possibly modified. + # + # NB: a `for` loop captures its iteration list before it begins, so + # changing the positional parameters here affects neither the number of + # iterations, nor the values presented in `arg`. + shift # remove old arg + set -- "$@" "$arg" # push replacement arg + done +fi + + +# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. +DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"' + +# Collect all arguments for the java command: +# * DEFAULT_JVM_OPTS, JAVA_OPTS, and optsEnvironmentVar are not allowed to contain shell fragments, +# and any embedded shellness will be escaped. +# * For example: A user cannot expect ${Hostname} to be expanded, as it is an environment variable and will be +# treated as '${Hostname}' itself on the command line. + +set -- \ + "-Dorg.gradle.appname=$APP_BASE_NAME" \ + -classpath "$CLASSPATH" \ + -jar "$APP_HOME/gradle/wrapper/gradle-wrapper.jar" \ + "$@" + +# Stop when "xargs" is not available. +if ! command -v xargs >/dev/null 2>&1 +then + die "xargs is not available" +fi + +# Use "xargs" to parse quoted args. +# +# With -n1 it outputs one arg per line, with the quotes and backslashes removed. +# +# In Bash we could simply go: +# +# readarray ARGS < <( xargs -n1 <<<"$var" ) && +# set -- "${ARGS[@]}" "$@" +# +# but POSIX shell has neither arrays nor command substitution, so instead we +# post-process each arg (as a line of input to sed) to backslash-escape any +# character that might be a shell metacharacter, then use eval to reverse +# that process (while maintaining the separation between arguments), and wrap +# the whole thing up as a single "set" statement. +# +# This will of course break if any of these variables contains a newline or +# an unmatched quote. +# + +eval "set -- $( + printf '%s\n' "$DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS" | + xargs -n1 | + sed ' s~[^-[:alnum:]+,./:=@_]~\\&~g; ' | + tr '\n' ' ' + )" '"$@"' + +exec "$JAVACMD" "$@" diff --git a/Android/gradlew.bat b/Android/gradlew.bat new file mode 100644 index 00000000..5eed7ee8 --- /dev/null +++ b/Android/gradlew.bat @@ -0,0 +1,94 @@ +@rem +@rem Copyright 2015 the original author or authors. +@rem +@rem Licensed under the Apache License, Version 2.0 (the "License"); +@rem you may not use this file except in compliance with the License. +@rem You may obtain a copy of the License at +@rem +@rem https://www.apache.org/licenses/LICENSE-2.0 +@rem +@rem Unless required by applicable law or agreed to in writing, software +@rem distributed under the License is distributed on an "AS IS" BASIS, +@rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +@rem See the License for the specific language governing permissions and +@rem limitations under the License. +@rem +@rem SPDX-License-Identifier: Apache-2.0 +@rem + +@if "%DEBUG%"=="" @echo off +@rem ########################################################################## +@rem +@rem Gradle startup script for Windows +@rem +@rem ########################################################################## + +@rem Set local scope for the variables with windows NT shell +if "%OS%"=="Windows_NT" setlocal + +set DIRNAME=%~dp0 +if "%DIRNAME%"=="" set DIRNAME=. +@rem This is normally unused +set APP_BASE_NAME=%~n0 +set APP_HOME=%DIRNAME% + +@rem Resolve any "." and ".." in APP_HOME to make it shorter. +for %%i in ("%APP_HOME%") do set APP_HOME=%%~fi + +@rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. +set DEFAULT_JVM_OPTS="-Xmx64m" "-Xms64m" + +@rem Find java.exe +if defined JAVA_HOME goto findJavaFromJavaHome + +set JAVA_EXE=java.exe +%JAVA_EXE% -version >NUL 2>&1 +if %ERRORLEVEL% equ 0 goto execute + +echo. 1>&2 +echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. 1>&2 +echo. 1>&2 +echo Please set the JAVA_HOME variable in your environment to match the 1>&2 +echo location of your Java installation. 1>&2 + +goto fail + +:findJavaFromJavaHome +set JAVA_HOME=%JAVA_HOME:"=% +set JAVA_EXE=%JAVA_HOME%/bin/java.exe + +if exist "%JAVA_EXE%" goto execute + +echo. 1>&2 +echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME% 1>&2 +echo. 1>&2 +echo Please set the JAVA_HOME variable in your environment to match the 1>&2 +echo location of your Java installation. 1>&2 + +goto fail + +:execute +@rem Setup the command line + +set CLASSPATH= + + +@rem Execute Gradle +"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" -jar "%APP_HOME%\gradle\wrapper\gradle-wrapper.jar" %* + +:end +@rem End local scope for the variables with windows NT shell +if %ERRORLEVEL% equ 0 goto mainEnd + +:fail +rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of +rem the _cmd.exe /c_ return code! +set EXIT_CODE=%ERRORLEVEL% +if %EXIT_CODE% equ 0 set EXIT_CODE=1 +if not ""=="%GRADLE_EXIT_CONSOLE%" exit %EXIT_CODE% +exit /b %EXIT_CODE% + +:mainEnd +if "%OS%"=="Windows_NT" endlocal + +:omega diff --git a/Android/local.properties b/Android/local.properties new file mode 100644 index 00000000..0a737e31 --- /dev/null +++ b/Android/local.properties @@ -0,0 +1,8 @@ +## This file must *NOT* be checked into Version Control Systems, +# as it contains information specific to your local configuration. +# +# Location of the SDK. This is only used by Gradle. +# For customization when using a Version Control System, please read the +# header note. +#Fri Apr 10 10:52:48 KST 2026 +sdk.dir=/home/junbong/progra/Android/Sdk diff --git a/Android/settings.gradle.kts b/Android/settings.gradle.kts new file mode 100644 index 00000000..87d321f1 --- /dev/null +++ b/Android/settings.gradle.kts @@ -0,0 +1,59 @@ +pluginManagement { + repositories { + // Google Maven hosts the real AGP classpath artifact + // (`com.android.tools.build:gradle`) as well as androidx and + // com.google.* tooling. No includeGroupByRegex filter here: + // the filter was excluding the transitive toolchain resolution + // for some plugin markers in certain environments, and the + // performance cost of a broader repo is negligible compared to + // the maintenance cost of a fragile allow-list. + google() + mavenCentral() + gradlePluginPortal() + } + // Map plugin IDs directly to their canonical artifact coordinates + // so Gradle can resolve them from Google Maven / Maven Central + // without going through the Gradle Plugin Portal's marker + // redirection. This matters in environments where the Plugin + // Portal (plugins.gradle.org) is unreachable or where its marker + // artifacts for a given version have not propagated to the local + // mirror yet β€” the symptom is a "could not resolve plugin artifact + // ...gradle.plugin:" failure even though the underlying + // tools artifact is sitting right there on Google Maven. The + // mapping is purely additive: if the marker is reachable, Gradle + // will use it; if not, eachPlugin kicks in and rewrites the + // request to the full Maven coordinate. + resolutionStrategy { + eachPlugin { + val id = requested.id.id + val version = requested.version + when { + id == "com.android.application" || id == "com.android.library" -> + useModule("com.android.tools.build:gradle:$version") + id == "org.jetbrains.kotlin.android" -> + useModule("org.jetbrains.kotlin:kotlin-gradle-plugin:$version") + id == "org.jetbrains.kotlin.plugin.serialization" -> + useModule("org.jetbrains.kotlin:kotlin-serialization:$version") + } + } + } +} +// No java { toolchain } / kotlin { jvmToolchain } DSL is used anywhere +// in this build β€” every module pins its Java version through +// `android { compileOptions { sourceCompatibility / targetCompatibility } }` +// instead, and the Gradle daemon JVM is provisioned via +// gradle/gradle-daemon-jvm.properties (foojay URLs are hard-coded +// there, no plugin needed). The foojay-resolver-convention plugin +// was scaffolded originally but is dead code here; re-add it if a +// module later adopts the toolchain DSL. +dependencyResolutionManagement { + repositoriesMode.set(RepositoriesMode.FAIL_ON_PROJECT_REPOS) + repositories { + google() + mavenCentral() + } +} + +rootProject.name = "QuickAI" +include(":QuickDotAI") +include(":SampleTestAPP") diff --git a/CLAUDE.md b/CLAUDE.md deleted file mode 100644 index ed73d712..00000000 --- a/CLAUDE.md +++ /dev/null @@ -1,324 +0,0 @@ -# CLAUDE.md - -This file briefs Claude Code (and any new contributor) on Quick.AI's -conventions before making changes. Read it top-to-bottom β€” it is -deliberately short. - -The bulk of this document is contributor-facing guidance that applies -to humans and AI coding agents alike. Agent-specific rules (Claude -Code, other coding agents) live in [a single dedicated section at the -bottom](#for-ai-coding-agents). - ---- - -## Project at a glance - -Quick.AI is a production-grade **on-device causal-LM inference engine** -built on top of [NNTrainer](https://github.com/nntrainer/nntrainer). It -targets Linux and Android (arm64-v8a) with hand-tuned ARMv8.2-a (FP16, -dotprod, i8mm) and AVX2 kernels, and runs MoE models (Qwen3-MoE 30B, -GPT-OSS 20B/120B) on a phone via Flash Storage Utilization (FSU) β€” -experts stream from disk only when their tokens fire. - -| Item | Value | -|---|---| -| Language | C++17 (and a little C for the public API) | -| Build system | Meson + Ninja (`meson_version >= 0.55.0`) | -| Submodule | `subprojects/nntrainer/` (pinned commit, meson subproject) | -| C++ namespace | `quick_dot_ai` (do **not** reintroduce the old `causallm`) | -| Brand spelling | "Quick.AI" in human copy, `quick_dot_ai` in identifiers | -| License | Apache-2.0 | - -Sources of truth that go deeper than this file: - -- [`README.md`](README.md) β€” user-facing intro, demos, quick start. -- [`docs/architecture.md`](docs/architecture.md) β€” Mermaid diagram, per-binary / per-plugin breakdown, design rationale. -- [`models/README.md`](models/README.md) β€” model author guide. -- [`api/README.md`](api/README.md) β€” C API reference. -- [`benchmarks/README.md`](benchmarks/README.md) β€” Android benchmark tooling. - ---- - -## Repository map - -``` -api/ Stable C API surface (libquick_dot_ai_api.so). - ABI-stable β€” do NOT rename symbols or change enum values. -layers/ Per-layer plugin .so's (rms_norm, swiglu, qkv, - mha_core, lm_head, tie_word_embedding, embedding_*, - reshaped_rms_norm). Each builds as its own - libquick_dot_ai__layer.so. -models/ causal_lm + transformer base classes; per-family - causal LMs (qwen2, qwen3, qwen3_moe, gpt_oss, - gemma3). The *_cached_slim and *_slim_moe variants - enable FSU. -factory.h Model registry. Every new family must be wired here - so loadModel can dispatch by ModelType. -main.cpp quick_dot_ai_run executable. -quantize.cpp quick_dot_ai_quantize executable - (Q4_0 / Q4_K / Q6_K / FP16). -huggingface_tokenizer.{cpp,h} Tokenizer adapter over tokenizers-cpp. -llm_util.{cpp,hpp} Generation-loop helpers. -jni/ Android.mk + prepare_encoder.{sh,ps1}. -build_android.sh Core Android build (NDK + Rust target). -build_api_lib.sh libquick_dot_ai_api.so for Android. -build_test_app.sh quick_dot_ai_test_api for Android. -install_android.sh adb push to /data/local/tmp/quick_dot_ai/. -benchmarks/ Android perf tooling (benchmark_android.py et al.). -docs/ architecture.md + demo GIFs. -res/ Drop model directories here (config.json, - tokenizer.json, nntr_config.json, weight .bin, …). -.clang-format clang-format 14 style β€” CI fails on diffs. -.github/workflows/ ci-linux, ci-android, cpp-linter, codeql, - check_count, labeler. -subprojects/nntrainer/ Vendored NNTrainer, built lean - (enable-app=false, enable-test=false, - enable-tflite-{backbone,interpreter}=false). -meson.build / meson_options.txt Top-level build wiring. -``` - ---- - -## Build & verify - -### Linux (mirrors `ci-linux.yml`) - -```bash -sudo apt-get install -y libopenblas-dev libflatbuffers-dev \ - flatbuffers-compiler libiniparser-dev libomp-dev cmake \ - build-essential pkg-config -pip install meson ninja - -meson setup build -Denable-fp16=true -Dthread-backend=omp \ - -Domp-num-threads=4 -ninja -C build -``` - -Expected artifacts under `build/`: - -- `libquick_dot_ai.so` -- `quick_dot_ai_run`, `quick_dot_ai_quantize`, `quick_dot_ai_test_api` -- `layers/libquick_dot_ai_*_layer.so` (one per layer plugin) - -Smoke test the runner with a model directory under `res//`: - -```bash -export OMP_NUM_THREADS=4 OMP_WAIT_POLICY=active \ - OMP_PROC_BIND=true OMP_PLACES=cores -./build/quick_dot_ai_run ./res/qwen3/qwen3-4b/ -``` - -### Android (mirrors `ci-android.yml`) - -Prereqs: NDK r26d (export `ANDROID_NDK`), CMake, Rust with the -`aarch64-linux-android` target, `adb`. - -```bash -export ANDROID_NDK=/path/to/android-ndk -./build_android.sh # core: libquick_dot_ai_core.so + binaries -./build_api_lib.sh # libquick_dot_ai_api.so -./build_test_app.sh # quick_dot_ai_test_api -./install_android.sh # adb push to /data/local/tmp/quick_dot_ai/ -``` - -### Format every changed C/C++ file before committing - -```bash -clang-format-14 -i path/to/changed.cpp path/to/changed.h -``` - -The `cpp-linter.yml` workflow runs `clang-format 14` against -`.clang-format` and gates the PR; `subprojects/` is excluded. - ---- - -## Commit rules (HARD β€” CI / DCO will block violations) - -Every commit MUST end with a `Signed-off-by:` trailer. Use a real name -and an email reachable by the author. The trailer is the project's -DCO sign-off and is non-negotiable. - -### Format - -``` -[] - - - -Signed-off-by: Your Name -``` - -Always sign off with `git commit -s` (or include the trailer manually -when using the GitHub API). Co-authored work gets one additional -`Signed-off-by:` line per author. - -### Subject conventions - -- **Imperative mood**: "add", "fix", "rename" β€” not "added" / "adds". -- **Component prefix** when the change is local to a subsystem. - Prefixes already used in the history (re-use, don't invent): - `[CausalLM]`, `[api]`, `[Android.mk]`, `[neuralnet]`, `[script]`, - `[Docs]`, `ci`, `ci(android)`, `ci(codeql)`, `ci(linux)`. A bare - imperative subject (no bracket) is also accepted for repo-wide - changes β€” see commit `526f361` ("Rename project to Quick.AI; …"). -- **<= 72 chars**, no trailing period. -- Keep brand spelling consistent: human copy β†’ "Quick.AI", code - identifiers β†’ `quick_dot_ai`. - -### Body conventions - -- Explain *why* (motivating bug, requirement, constraint), not - *what the diff does*. -- Wrap at ~72 chars. One blank line between subject and body, and - between body and the trailers. -- Bullets start with `- ` and stay short. -- Call out explicitly when the change touches CI, the build system, - the public API/ABI, or the on-device install layout. - -### Examples drawn from the history - -``` -[CausalLM] fix mmap read in tie-word-embedding -ci(android): pass user-writable prefix to nntrainer package build -ci(codeql): pin source-root and exclude vendored trees from scan -Rename project to Quick.AI; unify identifiers to quick_dot_ai -``` - -### Don'ts - -- Do **not** use `--no-verify` to skip hooks. -- Do **not** `--amend` an already-pushed commit; create a new one. -- Do **not** force-push to `main` or any shared branch. -- Do **not** add `Co-Authored-By:` trailers that imply authorship a - tool does not have. - ---- - -## Branching & PR workflow - -- Work on a topic branch; never commit directly to `main`. Branch - naming patterns observed in the repo: - `feat/...`, `bugfix/...`, `ci/...`, `unittest/...`, `claude/...`. -- A PR is gated by: - | Workflow | What it does | - |---|---| - | `ci-linux.yml` | Meson + Ninja on Ubuntu 22.04 & 24.04. | - | `ci-android.yml` | NDK r26d, arm64-v8a, Rust `aarch64-linux-android`. | - | `cpp-linter.yml` | clang-format 14 against `.clang-format` (subprojects/ ignored). | - | `codeql.yml` | CodeQL c-cpp + python; vendored trees excluded by `.github/codeql/codeql-config.yml`. | -- `check_count.yml` + `labeler.yml` toggle `Need Review` and - `PR/READY2MERGE` based on **2 approving reviewers** (Quick.AI's - threshold; nntrainer uses 3). - ---- - -## Code style essentials - -Read `.clang-format` for the canonical rules. The non-obvious bits: - -- 2-space indent, **80-column** hard limit, tabs forbidden. -- `BreakBeforeBraces: Attach`, `PointerAlignment: Right` - (`int *p`, not `int* p`). -- `SortIncludes: CaseSensitive` β€” keep includes sorted; let - clang-format do it. -- C++17 (`-std=c++17`), C `gnu89` for the public C API. -- Keep new C++ symbols inside `namespace quick_dot_ai`. - -### ABI stability - -`api/causal_lm_api.h` is the integration seam used by Android JNI, -iOS, and server embedders. The following are **frozen**: - -- Symbols: `loadModel`, `runModel`, `getPerformanceMetrics`. -- Enums: `BackendType`, `ModelType`, `ModelQuantizationType` - (and the `CAUSAL_LM_*` enumerator names / values). - -Don't rename them, don't reorder enumerators, don't change numeric -values. If you need a new entry, append at the end. - -### Layers as plugins - -Each transformer building block under `layers/` builds as its own -`shared_library` (named `libquick_dot_ai__layer.so`). New -layers must: - -1. Add `.{cpp,h}` under `layers/`. -2. Declare the build target in `layers/meson.build`. -3. Append the resulting `quick_dot_ai__dep` to - `quick_dot_ai_layer_dependencies` in the top-level `meson.build`. - ---- - -## Adding a new model family - -1. `models//_causallm.{h,cpp}` deriving from the - appropriate base in `models/causal_lm.{h,cpp}` / - `models/transformer.{h,cpp}`. -2. `models//meson.build`; append the family's sources to - `models/meson.build` (and to `quick_dot_ai_src` / - `quick_dot_ai_inc` at the top level if needed). -3. Register the family + a new `ModelType` enumerator in - [`factory.h`](factory.h) so `loadModel` can dispatch to it. -4. Optional: implement custom layers under `layers/` and wire their - deps into the top-level `meson.build`. -5. Drop a runnable model directory under `res///` - containing `config.json`, `generation_config.json`, - `tokenizer.json`, `tokenizer_config.json`, `vocab.json`, - `nntr_config.json`, and the NNTrainer `.bin` weight file - referenced from `nntr_config.json`. - -Verify: `meson setup build && ninja -C build && \ -./build/quick_dot_ai_run ./res///`. - ---- - -## Anti-patterns / things to NOT do - -- Do not reintroduce the old `causallm` C++ namespace, the - `nntr_causallm` / `nntr_quantize` / `test_api` binary names, the - `libcausallm.so` library name, or the - `/data/local/tmp/nntrainer/causallm` install path. They were all - renamed in commit `526f361` and the rename is enforced by CI. -- Do not add system deps Quick.AI doesn't actually use (e.g. - `tensorflow2-lite-dev`, `nnstreamer-dev`). NNTrainer is pulled in - with `enable-tflite-backbone=false`, `enable-tflite-interpreter=false`, - `enable-app=false`, `enable-test=false` β€” keep that surface lean. -- Do not commit meson auto-generated wrap redirects under - `subprojects/*.wrap` β€” they are regeneration artifacts and are - already gitignored. -- Do not relax `BreakBeforeBraces`, `IndentWidth`, or `ColumnLimit` - in `.clang-format` to make a diff fit; reformat the diff instead. -- Do not assume Q4_0 quantized files are portable across - architectures. Quantize on the same ISA (ARM vs x86_64) you serve - from. -- Do not push to `main` directly, and do not force-push to any - shared branch. - ---- - -## For AI coding agents - -This section collects rules that apply specifically to AI coding -agents (Claude Code, and any other tool that drives commits/PRs on -behalf of a human). It supplements β€” does not replace β€” every other -rule above. If you are a human reading this file, you can skip it. - -- **Stay on the topic branch the user named for the session.** Do - not open new branches, rebase shared branches, or rewrite history - on branches the user did not authorize. -- **Do not open a pull request unless the user explicitly asks for - one.** Pushing the branch is fine; opening / merging the PR is the - user's call. -- **Never force-push to `main` or any shared branch**, and avoid - force-push on user-visible feature branches unless the user asks. -- **Keep the DCO trailer honest.** Sign off as the human running the - session (or as instructed). Do not invent `Co-Authored-By:` - trailers. -- **Treat `api/causal_lm_api.h` and `factory.h` as load-bearing.** - Surface any change to those files explicitly in the PR / commit - message so a human reviewer can sanity-check the ABI impact. -- **Run the local format / build smoke tests** described above - before pushing, even when CI will rerun them β€” failing CI on - trivial whitespace wastes a review round trip. diff --git a/README.md b/README.md index dc673ea1..26116e59 100644 --- a/README.md +++ b/README.md @@ -1,365 +1,234 @@ -
- -

β˜„οΈ Quick.AI

- -

The fastest way to run an LLM on the device in your hand.

- -

-Production-grade causal-LM inference on top of NNTrainer β€”
-Qwen 3, GPT-OSS, Gemma 3, Llama and more, with MoE on phones via on-the-fly expert streaming. -

- -

- Linux - Android - Format - CodeQL -
- License - C++17 - Android - Platform - Offline -

- -

- Quick start Β· - Demos Β· - Models Β· - Android Β· - Quantization Β· - Chat Template Β· - Architecture -

- -
- ---- - -
- -### Quick.AI in three numbers - -| Peak RAM | Library size | Network use | -|:---:|:---:|:---:| -| **16.5 GB β†’ 1.3 GB** | **~13 MB** | **0 bytes** | -| Peak RAM for Qwen3-MoE 30B with FSU | Single core inference library | Sent over the network at runtime | - -
- ---- - -## Why Quick.AI? - - - - - - -
- -### MoE that fits in your pocket -Run **30B-parameter Mixture-of-Experts** models in **~1.3 GB of RAM** with Flash Storage Utilization (FSU) β€” experts stream in from disk only when their tokens fire. - -### Tuned for the metal -Hand-written kernels for **ARMv8.2-a** (FP16, dotprod, i8mm) and **AVX2** on x86_64. Multi-threaded with OpenMP, NEON-vectorized hot paths. - -### Offline by design -Weights, prompts, and activations stay on the device. No telemetry, no Python runtime at inference time. - - - -### Pluggable layers -Each transformer building block (RMSNorm, SwiGLU, QKV, MHA core, tied embeddings…) ships as an **independently loadable `.so`** β€” drop in your own without recompiling the world. - -### Embed anywhere -Native **C and C++ APIs** plus a clean Android JNI build. Same source tree builds for desktop, server, and mobile. - -### Zero‑install quantizer -`quick_dot_ai_quantize` shrinks an FP32 checkpoint to **Q4_0 / Q4_K / Q6_K / FP16** in one command. - -
- ---- - -## See it in action - -
- -#### MoE inference on a phone - - - - - - - - - - -
GPT-OSS 20BQwen3-MoE 30B-A3B
- -#### FSU: the same model, the same machine, a 12Γ— memory cut - - - - - - - - - - - - - - -
Load whole model
Qwen3-30B-A3B
Load experts on the fly
Quick.AI / FSU
Memory: 16.5 GBMemory: 1.3 GB
- -
- ---- +# Quick.AI ⚑ + +Quick.AI is an on-device LLM stack built around nntrainer CausalLM extensions. +It provides self-registering C++ model plugins, a handle-based C API, Qualcomm +QNN integration, and an Android AAR (`QuickDotAI`) with native and LiteRT-LM +backends. + +## πŸ“š Table of Contents + +- [Features](#-features) +- [Supported Models](#-supported-models) +- [Quick Start](#-quick-start) +- [Prerequisites](#-prerequisites) +- [Building](#-building) +- [How to Create a Custom Model](#-how-to-create-a-custom-model) +- [Architecture](#-architecture) +- [Directory Structure](#-directory-structure) +- [Documentation](#-documentation) + +## ✨ Features + +- **Self-registering model plugins**: add custom `CausalLM` models without + changing nntrainer source files. +- **Handle-based C API**: load independent model handles, stream tokens, cancel + in-flight runs, collect metrics, and use OpenAI-style messages. +- **Android AAR**: `QuickDotAI` exposes `NativeQuickDotAI` for nntrainer/QNN + models and `LiteRTLm` for Gemma-family `.litertlm` models. +- **Structured generation**: XGrammar-backed tool/schema constrained output via + `runModelHandleWithTool()`. +- **Chat templates**: OpenAI-compatible `messages`, `tools`, and `functions` + formatting through model-local `chat_template.jinja` or + `tokenizer_config.json`. +- **Multimodal paths**: LiteRT-LM image input for Gemma-family models and native + QNN vision paths where the loaded model supplies vision + LLM sub-models. + +## πŸ€– Supported Models + +Models are identified by a **string model id**. The public model catalog is +provided by `getModelCatalogJson()` (C API) and `ModelCatalog` (Android AAR). +Each model self-registers its descriptor at load time β€” see +[`docs/Architecture.md`](docs/Architecture.md) for how the registry works. + +| Model id | Runtime | Backends | Capabilities | +|------------------|---------|----------|------------------------------------| +| `qwen3-0.6b` | NATIVE | CPU, GPU | Streaming, Tool use | +| `qwen3-1.7b-q40` | NATIVE | CPU, GPU | Streaming, Tool use | +| `tiny-bert` | NATIVE | CPU | Embedding | +| `function-gemma` | NATIVE | CPU, GPU | Tool use | +| `gemma4-cpu` | NATIVE | CPU | Streaming | +| `gemma4-e2b-qnn` | NATIVE | QNN | Messages API | +| `vjepa-qnn` | NATIVE | QNN | Multimodal, Multi-image, Messages API | +| `gemma4` | LiteRT | GPU | Streaming, Multimodal, Messages API | + +> QNN model ids (`gemma4-e2b-qnn`, `vjepa-qnn`) only appear in the catalog on +> Android builds compiled with `--enable-qnn`. + +The `ModelType` C enum (`CAUSAL_LM_MODEL_*` constants) is a **deprecated +compatibility shim**. Prefer string model ids and `loadModelHandleByName()` for +new code. + +Model configuration files are placed under `src/res/` and model +implementations live under `src/models/`. + +## πŸš€ Quick Start + +### Android AAR + +Quick.AI currently ships the `QuickDotAI` AAR module and the direct +`SampleTestAPP` sample. The REST/foreground-service layer described in +older plans is not part of the current Gradle build. + +```kotlin +dependencies { + implementation(project(":QuickDotAI")) +} +``` -## Supported models +See [`docs/ChatAndOpenAIUsage.md`](docs/ChatAndOpenAIUsage.md) for Chat tab, +OpenAI tab, JSON streaming, and XGrammar examples. See +[`Android/QuickDotAI/README.md`](Android/QuickDotAI/README.md) for the full AAR +API. -| Family | Variants | Notes | -|---|---|---| -| **Llama** | 1B / 3B / 7B-class | reference architecture | -| **Qwen 2** | 0.5B – 7B | causal LM | -| **Qwen 3** | 0.6B Β· 1.7B Β· 4B Β· 8B Β· 14B Β· 32B | [HF: Qwen3-4B](https://huggingface.co/Qwen/Qwen3-4B) | -| **Qwen 3-MoE** | 30B-A3B | [HF: Qwen3-30B-A3B](https://huggingface.co/Qwen/Qwen3-30B-A3B-Instruct-2507) Β· **FSU** | -| **GPT-OSS** | MoE 20B Β· 120B | [HF: gpt-oss-20b](https://huggingface.co/openai/gpt-oss-20b) Β· **FSU** | -| **Gemma 3** | all causal variants | + sentence-embedding head | +### C API -> **Bring your own**: subclass the causal-LM template under `models//` and the [factory](factory.h) wires it in. See the [model author guide](models/README.md). +```cpp +#include "quick_dot_ai_api.h" ---- +// Preferred: load by string model id +CausalLmHandle handle = nullptr; +loadModelHandleByName(CAUSAL_LM_BACKEND_CPU, "qwen3-0.6b", + CAUSAL_LM_QUANTIZATION_W4A32, + nullptr, "/models", &handle); -## Quick start +runModelHandleStreaming(handle, "Hello!", [](const char *delta, void *) { + std::cout << delta << std::flush; + return 0; +}, nullptr); -```bash -# 1 Β· Clone (with submodules β€” NNTrainer rides along) -git clone --recursive https://github.com/nntrainer/Quick.AI.git -cd Quick.AI - -# 2 Β· System deps (Ubuntu 22.04 / 24.04) -sudo apt-get install -y libopenblas-dev libflatbuffers-dev flatbuffers-compiler \ - build-essential pkg-config -pip install meson ninja - -# 3 Β· Build (~1 min on a modern laptop) -meson setup build -Dnntrainer:enable-fp16=true -Dnntrainer:thread-backend=omp -Dnntrainer:omp-num-threads=4 -ninja -C build - -# 4 Β· Generate -export OMP_NUM_THREADS=4 OMP_WAIT_POLICY=active OMP_PROC_BIND=true OMP_PLACES=cores -./build/quick_dot_ai_run ./res/qwen3/qwen3-4b/ +destroyModelHandle(handle); ``` -> **Model layout** β€” drop a model into `res//` containing -> `config.json`, `generation_config.json`, `tokenizer.json`, `tokenizer_config.json`, -> `vocab.json`, `nntr_config.json`, and the NNTrainer `.bin` weight file referenced from `nntr_config.json`. +Use `getModelCatalogJson()` to enumerate all available model descriptors at +runtime. See [`api/README.md`](api/README.md) for the complete C API reference. ---- +## 🧰 Prerequisites -## Android build +- C++17 compiler +- [Meson](https://mesonbuild.com/) >= 0.55.0 +- [Ninja](https://ninja-build.org/) +- Android NDK for Android builds +- OpenBLAS for x86 builds (`apt install libopenblas-dev`) +- nntrainer submodule dependencies +- Qualcomm QNN and Hexagon SDK for `--enable-qnn` Android builds -
-Click to expand the modular Android pipeline +## πŸ—οΈ Building -
+All native builds go through the root `build.sh`. -**Prerequisites:** Android NDK r21d+, CMake, [Rust](https://rustup.rs) (for `tokenizers-cpp`), `adb`. +### x86 / Linux ```bash -export ANDROID_NDK=/path/to/android-ndk -./build_android.sh # libquick_dot_ai_core.so Β· quick_dot_ai Β· quick_dot_ai_quantize -./build_api_lib.sh # (optional) libquick_dot_ai_api.so -./build_test_app.sh # (optional) quick_dot_ai_test_api -./install_android.sh # adb push to /data/local/tmp/quick_dot_ai/ +./build.sh +./build.sh --target=src +./build.sh --clean ``` -| Script | Output(s) | Depends on | -|---|---|---| -| `build_android.sh` | `libquick_dot_ai_core.so`, `quick_dot_ai`, `quick_dot_ai_quantize` | NDK + Rust | -| `build_api_lib.sh` | `libquick_dot_ai_api.so` | core lib | -| `build_test_app.sh` | `quick_dot_ai_test_api` | core + api lib | -| `install_android.sh` | `/data/local/tmp/quick_dot_ai/*` | adb device | - -Run on the phone: +Run the standalone executable: ```bash -adb shell /data/local/tmp/quick_dot_ai/run_causallm.sh -adb shell /data/local/tmp/quick_dot_ai/run_quantize.sh -adb shell /data/local/tmp/quick_dot_ai/run_test_api.sh "" +LD_LIBRARY_PATH=nntrainer/builddir_x86/nntrainer:nntrainer/builddir_x86/api/ccapi:builddir_x86/src:builddir_x86/api \ + builddir_x86/src/quick_dot_ai /path/to/model "Your prompt" ``` -All artifacts land under `jni/libs/arm64-v8a/`. - -
- ---- - -## Quantization +Plugin mode is still available for the original `nntr_causallm` executable: ```bash -# Default: FC β†’ Q4_0, embedding β†’ FP32 -./build/quick_dot_ai_quantize /path/to/qwen3-4b - -# Mix dtypes per layer family -./build/quick_dot_ai_quantize /path/to/qwen3-4b \ - --fc_dtype Q4_0 --embd_dtype Q6_K --lmhead_dtype FP16 - -# Write into a separate output directory -./build/quick_dot_ai_quantize /path/to/qwen3-4b -o /out/qwen3-4b-q40 +LD_PRELOAD=$(pwd)/builddir_x86/src/libquick_dot_ai.so nntr_causallm /path/to/model ``` -| dtype | bits | typical use | -|---|---|---| -| `FP32` | 32 | embedding, LM head (default) | -| `FP16` | 16 | LM head when memory matters | -| `Q4_0` | 4 | FC layers (default), fastest path | -| `Q4_K` | 4 | FC layers, K-quant accuracy | -| `Q6_K` | 6 | embedding when 4-bit hurts quality | - -> **Q4_0 is ISA-specific** β€” an x86-quantized Q4_0 binary is not byte-compatible with ARM. Quantize on the same architecture you serve from. - -After quantization, point `quick_dot_ai_run` at the quantized directory (or `mv nntr_config_quantized.json nntr_config.json` in place and rerun). - ---- - -## Chat Template - -Quick.AI supports automatic chat template formatting by reading the `chat_template` field from HuggingFace's `tokenizer_config.json`. This eliminates the need for hardcoded per-model chat formatting. - -### How it works - -Most HuggingFace models include a `tokenizer_config.json` with a `chat_template` field (Jinja2 format) that defines how to format conversations. Quick.AI includes a built-in mini Jinja2 renderer that processes these templates at runtime. - -When a `tokenizer_config.json` is present in the model directory: -- **CLI (`quick_dot_ai_run`)**: Raw user input provided as a command-line argument is automatically wrapped with the chat template. -- **C API**: The `apply_chat_template()` function uses the dynamic template instead of hardcoded formats. - -If `tokenizer_config.json` is absent or does not contain a `chat_template` field, a warning is printed and the system falls back to hardcoded per-architecture templates (Llama, Qwen, Gemma3). - -### Supported template features - -The built-in Jinja2 renderer supports the following constructs commonly used in HuggingFace chat templates: - -| Feature | Example | -|---------|---------| -| For loops | `{% for message in messages %}...{% endfor %}` | -| Conditionals | `{% if %}...{% elif %}...{% else %}...{% endif %}` | -| Output expressions | `{{ bos_token }}` | -| Variable assignment | `{% set offset = 1 %}` | -| Dict/array access | `message['role']`, `messages[0]` | -| String concatenation | `'<\|im_start\|>' + message['role']` | -| Comparison operators | `==`, `!=`, `>`, `<`, `>=`, `<=` | -| Boolean operators | `and`, `or`, `not` | -| Loop variables | `loop.first`, `loop.last`, `loop.index`, `loop.index0` | -| Filters | `\| trim`, `\| length`, `\| tojson` | -| String methods | `.strip()`, `.startswith()`, `.upper()`, `.split()` | -| Containment test | `'keyword' in message['content']` | -| Namespace | `namespace()` for cross-scope variable mutation | -| Whitespace control | `{%- -%}`, `{{- -}}` | - -### Required files - -To use chat templates, ensure `tokenizer_config.json` is in your model directory alongside the other config files. This file is included by default when downloading models from HuggingFace. - -### Example +### Android Arm64 Build ```bash -# With tokenizer_config.json present, raw input is auto-formatted: -./build/quick_dot_ai_run /path/to/model "What is machine learning?" +export ANDROID_NDK=/path/to/android-ndk -# The input will be automatically wrapped, e.g. for Qwen3: -# <|im_start|>user -# What is machine learning?<|im_end|> -# <|im_start|>assistant +./build.sh --platform=android +./build.sh --platform=android --enable-qnn +./install_android.sh ``` -### Multi-turn conversations (API) +To build native libraries, copy them into `Android/QuickDotAI/prebuilt_libs/`, +and install `SampleTestAPP`: -The C API supports multi-turn conversations through `ChatMessage`: - -```cpp -#include "chat_template.h" - -quick_dot_ai::ChatTemplate tmpl = quick_dot_ai::ChatTemplate::fromFile("tokenizer_config.json"); - -std::vector messages = { - {"system", "You are a helpful assistant."}, - {"user", "Hello!"}, - {"assistant", "Hi there!"}, - {"user", "How are you?"} -}; - -std::string formatted = tmpl.apply(messages); +```bash +export ANDROID_NDK=/path/to/android-ndk +./apk-build-install.sh ``` ---- - -## Continuous integration - -Every PR is gated by: - -| Check | What it does | -|---|---| -| **Linux build** | Meson + Ninja on Ubuntu 22.04 & 24.04 | -| **Android build** | `arm64-v8a`, NDK r26d, Rust `aarch64-linux-android` | -| **C++ format** | clang-format 14 against `.clang-format` | -| **CodeQL** | security & quality static analysis | - -Workflows live under [`.github/workflows/`](.github/workflows/). - ---- +Before running `apk-build-install.sh`, set `NDK_ROOT` inside the script to your +local Android NDK path. -## Further reading +### Build Options -- [Architecture deep-dive](docs/architecture.md) β€” layered diagram, module-by-module breakdown, design choices -- [Model implementation guide](models/README.md) -- [C API reference](api/README.md) -- [Benchmark tooling](benchmarks/README.md) -- Talks & papers: - - [Memory-Efficient LLM Inference on Edge Devices with NNTrainer](https://youtu.be/J2tUmi4bwMY?si=rJyiXkwr5iFrMhIK) β€” Open Source Summit 2025 Seoul - - [A New Frontier of AI: On-Device AI Training and Personalization](https://dl.acm.org/doi/abs/10.1145/3639477.3639716) β€” ICSE-SEIP 2024 - - [NNTrainer: Light-Weight On-Device Training Framework](https://arxiv.org/pdf/2206.04688.pdf) β€” arXiv 2022 - ---- - -## Contributing - -We love PRs. Before opening one: - -1. `meson setup build && ninja -C build` β€” the same command CI runs. -2. `clang-format -i` on any changed C/C++ files (config in `.clang-format`). -3. Adding a new model family? Drop it under `models//`, wire it into `models/meson.build`, and register it in [`factory.h`](factory.h). - -## License +| Option | Default | Description | +|---|---|---| +| `--platform=x86|android` | `x86` | Target platform | +| `--target=src,api,api-test,qnn` | `all` | Comma-separated target set | +| `--enable-qnn` | off | Enable QNN integration (Android only) | +| `--clean` | off | Clean rebuild | -Quick.AI is released under the [Apache License 2.0](LICENSE). NNTrainer, bundled as a submodule, is also Apache-2.0. +Meson options are declared in [`meson_options.txt`](meson_options.txt). -## Citation +## 🧩 How to Create a Custom Model -If Quick.AI is useful for your research, please cite the NNTrainer paper it builds on: +1. Add your implementation under `src/models//`. +2. Inherit from `causallm::CausalLM` or the appropriate Quick.AI model base. +3. Register the architecture in your `.cpp` file: -```bibtex -@inproceedings{10.1145/3639477.3639716, - author = {Moon, Jijoong and Lee, Hyeonseok and Chu, Jiho and Park, Donghak and Hong, Seungbaek and Seo, Hyungjun and Jeong, Donghyeon and Kong, Sungsik and Ham, Myungjoo}, - title = {A New Frontier of AI: On-Device AI Training and Personalization}, - booktitle = {Proceedings of the 46th International Conference on Software Engineering: Software Engineering in Practice}, - series = {ICSE-SEIP '24}, - year = {2024}, - pages = {323--333}, - doi = {10.1145/3639477.3639716} +```cpp +__attribute__((constructor)) static void register_my_models() { + causallm::Factory::Instance().registerModel( + "MyModelForCausalLM", + [](causallm::json cfg, causallm::json generation_cfg, + causallm::json nntr_cfg) { + return std::make_unique( + cfg, generation_cfg, nntr_cfg); + }); } ``` -
- ---- +4. Add model config files under `src/res//`. +5. Add `src/models//meson.build` and include it from + `src/models/meson.build`. + +## πŸ›οΈ Architecture + +Quick.AI uses nntrainer's CausalLM application and `Factory` registration, while +Quick.AI-specific models are compiled into `quick_dot_ai` and +`libquick_dot_ai.so`. The deployable C API is `libquick_dot_ai_api.so`. + +See [`docs/Architecture.md`](docs/Architecture.md) for native architecture and +[`Android/Architecture.md`](Android/Architecture.md) for Android module status. + +## πŸ—‚οΈ Directory Structure + +```text +project-root/ +β”œβ”€β”€ nntrainer/ # nntrainer submodule +β”œβ”€β”€ xgrammar/ # XGrammar submodule +β”œβ”€β”€ src/ # Native CausalLM extensions and model configs +β”œβ”€β”€ api/ # libquick_dot_ai_api.so public C API +β”œβ”€β”€ qnn/ # Android QNN context library +β”œβ”€β”€ Android/ +β”‚ β”œβ”€β”€ QuickDotAI/ # Android AAR +β”‚ └── SampleTestAPP/ # Direct sample app +β”œβ”€β”€ docs/ # Canonical project documentation +β”œβ”€β”€ gemma_python/ # Gemma4 quantization-oriented Python package +β”œβ”€β”€ build.sh +β”œβ”€β”€ install_android.sh +└── apk-build-install.sh +``` -Built on top of NNTrainer. +## πŸ“– Documentation -
+| Document | Audience | Content | +|---|---|---| +| [`docs/Guides.md`](docs/Guides.md) | All users | Entry points by platform and goal | +| [`docs/ChatAndOpenAIUsage.md`](docs/ChatAndOpenAIUsage.md) | App/API users | Chat tab, OpenAI tab, JSON streaming, and XGrammar examples | +| [`docs/Architecture.md`](docs/Architecture.md) | Native contributors | Plugin, build, and C API architecture | +| [`api/README.md`](api/README.md) | C/C++ users | C API reference | +| [`Android/QuickDotAI/README.md`](Android/QuickDotAI/README.md) | Android users | AAR API surface and types | +| [`Android/Architecture.md`](Android/Architecture.md) | Android contributors | Current modules and planned service layer | +| [`docs/ChatTemplate.md`](docs/ChatTemplate.md) | Model/API users | Chat template discovery and JSON request handling | +| [`docs/XGrammarReference.md`](docs/XGrammarReference.md) | Tool-calling users | XGrammar internals, toolsets, cache behavior, and native API notes | +| [`qnn/README.md`](qnn/README.md) | QNN developers | QNN context development guide | diff --git a/api-app/jni/Android.mk b/api-app/jni/Android.mk new file mode 100644 index 00000000..11c625bf --- /dev/null +++ b/api-app/jni/Android.mk @@ -0,0 +1,56 @@ +LOCAL_PATH := $(call my-dir) + +# ── Path configuration ────────────────────────────────────────────────── +ifndef NNTRAINER_ROOT +NNTRAINER_ROOT := $(LOCAL_PATH)/../../nntrainer +endif + +API_ROOT := $(LOCAL_PATH)/../../api +INCLUDE_ROOT := $(LOCAL_PATH)/../include + +# ── Prebuilt nntrainer libraries (needed at runtime) ───────────────────── +include $(CLEAR_VARS) +LOCAL_MODULE := nntrainer +LOCAL_SRC_FILES := $(NNTRAINER_ROOT)/builddir/android_build_result/lib/$(TARGET_ARCH_ABI)/libnntrainer.so +include $(PREBUILT_SHARED_LIBRARY) + +include $(CLEAR_VARS) +LOCAL_MODULE := ccapi-nntrainer +LOCAL_SRC_FILES := $(NNTRAINER_ROOT)/builddir/android_build_result/lib/$(TARGET_ARCH_ABI)/libccapi-nntrainer.so +include $(PREBUILT_SHARED_LIBRARY) + +# ── Prebuilt causallm library (needed at runtime) ──────────────────────── +include $(CLEAR_VARS) +LOCAL_MODULE := causallm +LOCAL_SRC_FILES := $(LOCAL_PATH)/../../src/jni/libs/$(TARGET_ARCH_ABI)/libcausallm.so +include $(PREBUILT_SHARED_LIBRARY) + +# ── Prebuilt API library ──────────────────────────────────────────────── +include $(CLEAR_VARS) +LOCAL_MODULE := quick_dot_ai_api +LOCAL_SRC_FILES := $(API_ROOT)/jni/libs/$(TARGET_ARCH_ABI)/libquick_dot_ai_api.so +include $(PREBUILT_SHARED_LIBRARY) + +# ══════════════════════════════════════════════════════════════════════════ +# Module: quick_dot_ai_test (test executable) +# +# Only includes quick_dot_ai_api.h, links only libquick_dot_ai_api.so. +# Runtime dependencies (nntrainer, causallm) are resolved transitively. +# ══════════════════════════════════════════════════════════════════════════ +include $(CLEAR_VARS) + +LOCAL_ARM_NEON := true +LOCAL_CFLAGS += -std=c++17 +LOCAL_CXXFLAGS += -std=c++17 -frtti +LOCAL_MODULE := quick_dot_ai_test +LOCAL_LDLIBS := -llog -landroid + +LOCAL_SRC_FILES := ../test_api.cpp + +# API headers (copied to include/ by build script) +LOCAL_C_INCLUDES += $(INCLUDE_ROOT) + +# Link against the API library; runtime deps are pulled transitively +LOCAL_SHARED_LIBRARIES := quick_dot_ai_api nntrainer ccapi-nntrainer causallm + +include $(BUILD_EXECUTABLE) diff --git a/api-app/jni/Application.mk b/api-app/jni/Application.mk new file mode 100644 index 00000000..b00bb435 --- /dev/null +++ b/api-app/jni/Application.mk @@ -0,0 +1,4 @@ +APP_ABI := arm64-v8a +APP_PLATFORM := android-29 +APP_STL := c++_shared +NDK_TOOLCHAIN_VERSION := clang \ No newline at end of file diff --git a/api-app/meson.build b/api-app/meson.build new file mode 100644 index 00000000..dda52e8a --- /dev/null +++ b/api-app/meson.build @@ -0,0 +1,11 @@ +# api-app/meson.build β€” builds quick_dot_ai_test executable + +if not get_option('enable-api-test') + subdir_done() +endif + +executable('quick_dot_ai_test', + 'test_api.cpp', + dependencies: [quick_dot_ai_api_dep], + install: true, +) diff --git a/api-app/test_api.cpp b/api-app/test_api.cpp new file mode 100644 index 00000000..f49c0b70 --- /dev/null +++ b/api-app/test_api.cpp @@ -0,0 +1,411 @@ +// SPDX-License-Identifier: Apache-2.0 +/** + * @file test_api.cpp + * @brief Test application for src C API + * @note This file only includes quick_dot_ai_api.h and standard headers. + * It links against libquick_dot_ai_api.so. + */ + +#include "quick_dot_ai_api.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +// ── ANSI color codes ───────────────────────────────────────────────────────── +namespace clr { +const char *reset = "\033[0m"; +const char *bold = "\033[1m"; +const char *dim = "\033[2m"; +const char *cyan = "\033[36m"; +const char *green = "\033[32m"; +const char *yellow = "\033[33m"; +const char *red = "\033[31m"; +const char *magenta = "\033[35m"; +const char *blue = "\033[34m"; +const char *bold_cyan = "\033[1;36m"; +const char *bold_green = "\033[1;32m"; +const char *bold_yellow = "\033[1;33m"; +const char *bold_red = "\033[1;31m"; +const char *bold_magenta = "\033[1;35m"; +const char *bold_blue = "\033[1;34m"; +const char *bold_white = "\033[1;37m"; +} // namespace clr + +// ── ASCII art banner ───────────────────────────────────────────────────────── +static void print_banner() { + std::cout << clr::bold_cyan; + std::cout << R"( + ____ _ _ _ ___ + / __ \ (_) | | / \ |_ _| + | | | |_ _ _ ___| | __ / _ \ | | + | | | | | | | |/ __| |/ / / ___ \ | | + | |__| | |_| | | (__| < _ / / \ \| |_ + \___\_\\__,_|_|\___|_|\_(_)_/ \_\___| )" + << "\n"; + std::cout << clr::reset; + std::cout << clr::dim << " ─────────────────────────────────────────────\n" + << " On-Device LLM Inference Engine | API Test\n" + << " ─────────────────────────────────────────────" << clr::reset + << "\n\n"; +} + +// ── Box-drawing helpers ────────────────────────────────────────────────────── +static void print_section(const char *title, const char *color) { + std::cout << color << clr::bold << "β”Œβ”€ " << title << " "; + int pad = 50 - static_cast(strlen(title)); + for (int i = 0; i < pad; i++) + std::cout << "─"; + std::cout << "┐" << clr::reset << "\n"; +} + +static void print_section_end(const char *color) { + std::cout << color << "β””"; + for (int i = 0; i < 53; i++) + std::cout << "─"; + std::cout << "β”˜" << clr::reset << "\n\n"; +} + +static void print_kv(const char *key, const std::string &value, + const char *color) { + std::cout << color << "β”‚" << clr::reset << " " << clr::dim << std::left + << std::setw(18) << key << clr::reset << clr::bold_white << value + << clr::reset << "\n"; +} + +static void print_status(const char *msg, const char *icon, const char *color) { + std::cout << color << icon << " " << msg << clr::reset << "\n"; +} + +static void print_error(const std::string &msg) { + std::cout << clr::bold_red << " ERROR " << clr::reset << clr::red << msg + << clr::reset << "\n"; +} + +// ── Usage ──────────────────────────────────────────────────────────────────── +static void print_usage(const char *prog) { + print_banner(); + print_section("Usage", clr::yellow); + std::cout << clr::yellow << "β”‚" << clr::reset << " " << clr::bold_white + << prog << clr::reset << " [prompt] [chat_tpl] [quant] " + << "[verbose]\n"; + std::cout << clr::yellow << "β”‚" << clr::reset << "\n"; + print_kv("model", "qwen3-0.6b | gemma4-cpu | gemma4-e2b-qnn | function_gemma", + clr::yellow); + print_kv("prompt", "\"Hello, how are you?\"", clr::yellow); + print_kv("chat_tpl", "true | false (default: true)", clr::yellow); + print_kv("quant", "W4A32 | W16A16 | W8A16 | W32A32", clr::yellow); + print_kv("verbose", "true | false (default: true)", clr::yellow); + print_kv("model_base_path", + "Base directory for models (or set QUICKAI_MODEL_BASE_PATH)", + clr::yellow); + print_section_end(clr::yellow); +} + +int main(int argc, char *argv[]) { + if (argc < 2) { + print_error("Missing required argument: "); + print_usage(argv[0]); + return 1; + } + + // ── Parse arguments ────────────────────────────────────────────────────── + const char *model_name = argv[1]; + const char *prompt = (argc >= 3) ? argv[2] : "Hello, how are you?"; + + bool use_chat_template = true; + if (argc >= 4) { + std::string arg(argv[3]); + std::transform(arg.begin(), arg.end(), arg.begin(), ::tolower); + use_chat_template = (arg == "true" || arg == "1"); + } + + ModelQuantizationType quant_type = CAUSAL_LM_QUANTIZATION_W4A32; + std::string quant_str = "W4A32"; + if (argc >= 5) { + quant_str = std::string(argv[4]); + if (quant_str == "W4A32") { + quant_type = CAUSAL_LM_QUANTIZATION_W4A32; + } else if (quant_str == "W16A16") { + quant_type = CAUSAL_LM_QUANTIZATION_W16A16; + } else if (quant_str == "W8A16") { + quant_type = CAUSAL_LM_QUANTIZATION_W8A16; + } else if (quant_str == "W32A32") { + quant_type = CAUSAL_LM_QUANTIZATION_W32A32; + } + } + + bool verbose = true; + if (argc >= 6) { + std::string arg(argv[5]); + verbose = (arg == "1" || arg == "true"); + } + + // Model base path: CLI arg > env var > nullptr (uses C API default) + const char *model_base_path = nullptr; + std::string model_base_path_storage; + if (argc >= 7) { + model_base_path_storage = argv[6]; + model_base_path = model_base_path_storage.c_str(); + } else { + const char *env_path = std::getenv("QUICKAI_MODEL_BASE_PATH"); + if (env_path != nullptr && strlen(env_path) > 0) { + model_base_path_storage = env_path; + model_base_path = model_base_path_storage.c_str(); + } + } + + // ── Banner ───────────────────────────────────────────────────────────── + print_banner(); + + // ── Configuration ────────────────────────────────────────────────────── + print_section("Configuration", clr::cyan); + print_kv("Model", model_name, clr::cyan); + print_kv("Prompt", std::string("\"") + prompt + "\"", clr::cyan); + print_kv("Chat Template", use_chat_template ? "Yes" : "No", clr::cyan); + print_kv("Quantization", quant_str, clr::cyan); + print_kv("Verbose", verbose ? "Yes" : "No", clr::cyan); + print_kv("Model Base Path", + model_base_path ? model_base_path : "(C API default)", clr::cyan); + print_section_end(clr::cyan); + + // ── Set options ──────────────────────────────────────────────────────── + Config config; + config.use_chat_template = use_chat_template; + config.debug_mode = verbose; + config.verbose = verbose; + + ErrorCode err = setOptions(config); + if (err != CAUSAL_LM_ERROR_NONE) { + print_error("Failed to set options (code " + std::to_string(err) + ")"); + return 1; + } + + // ── Resolve model type ───────────────────────────────────────────────── + // Public models use the compat enum path (loadModelHandle). Any other name + // is treated as a catalog id and loaded via loadModelHandleByName β€” this is + // how proprietary model plugins (registered in the catalog, not the enum) + // are selected without per-model edits here. + std::string model_name_str(model_name); + std::transform(model_name_str.begin(), model_name_str.end(), + model_name_str.begin(), ::tolower); + + ModelType model_type = + CAUSAL_LM_MODEL_QWEN3_0_6B; // default, overridden below + std::string catalog_id; // non-empty β†’ use loadModelHandleByName + bool use_by_name = false; + + if (model_name_str == "qwen3-0.6b") { + model_type = CAUSAL_LM_MODEL_QWEN3_0_6B; + } else if (model_name_str == "qwen3-1.7b-q40" || + model_name_str == "qwen3_1.7b_q40") { + model_type = CAUSAL_LM_MODEL_QWEN3_1_7B_Q40; + } else if (model_name_str == "tiny_bert" || model_name_str == "tiny-bert") { + model_type = CAUSAL_LM_MODEL_TINY_BERT; + } else if (model_name_str == "function_gemma" || + model_name_str == "function-gemma") { + model_type = CAUSAL_LM_MODEL_FUNCTION_GEMMA; + } else if (model_name_str == "gemma4_cpu" || model_name_str == "gemma4-cpu") { + model_type = CAUSAL_LM_MODEL_GEMMA4_CPU; + } else if (model_name_str == "gemma4_e2b_qnn" || + model_name_str == "gemma4-e2b-qnn") { + model_type = CAUSAL_LM_MODEL_GEMMA4_E2B_QNN; + } else { + // Not a built-in enum model: treat the given name as a catalog id and + // load it by name. Proprietary model plugins register themselves in the + // catalog, so they are reachable here without being named explicitly. + catalog_id = model_name_str; + use_by_name = true; + } + + // ── Load/Unload Stress Test ──────────────────────────────────────────── + const int STRESS_CYCLES = 1; + print_section("Load/Unload Stress Test", clr::blue); + + for (int i = 0; i < STRESS_CYCLES; ++i) { + std::cout << clr::blue << "β”‚" << clr::reset << " " << clr::bold_white + << "Cycle " << (i + 1) << "/" << STRESS_CYCLES << clr::reset + << ": "; + + // Load + CausalLmHandle cycle_handle = nullptr; + if (use_by_name) { + err = loadModelHandleByName(CAUSAL_LM_BACKEND_CPU, catalog_id.c_str(), + quant_type, nullptr, model_base_path, + &cycle_handle); + } else { + err = loadModelHandle(CAUSAL_LM_BACKEND_CPU, model_type, quant_type, + nullptr, model_base_path, &cycle_handle); + } + if (err != CAUSAL_LM_ERROR_NONE) { + print_error("loadModel failed at cycle " + std::to_string(i + 1) + + " (code " + std::to_string(err) + ")"); + return 1; + } + std::cout << clr::bold_green << "LOAD OK" << clr::reset; + + // Unload + err = unloadModelHandle(cycle_handle); + if (err != CAUSAL_LM_ERROR_NONE) { + print_error("unloadModelHandle failed at cycle " + std::to_string(i + 1) + + " (code " + std::to_string(err) + ")"); + destroyModelHandle(cycle_handle); + return 1; + } + std::cout << clr::dim << " β†’ " << clr::reset; + std::cout << clr::bold_yellow << "UNLOAD OK" << clr::reset; + + // Destroy handle (unload keeps struct alive, must destroy to avoid leak) + destroyModelHandle(cycle_handle); + std::cout << clr::dim << " β†’ " << clr::reset; + std::cout << clr::dim << "DESTROY OK" << clr::reset << "\n"; + } + + std::cout << clr::blue << "β”‚" << clr::reset << "\n"; + std::cout << clr::blue << "β”‚" << clr::reset << " " << clr::bold_white + << "Final load (#" << (STRESS_CYCLES + 1) << "):" << clr::reset + << " Loading " << clr::bold_white << model_name << clr::reset + << " (" << quant_str << ") ...\n"; + + CausalLmHandle handle = nullptr; + if (use_by_name) { + err = loadModelHandleByName(CAUSAL_LM_BACKEND_CPU, catalog_id.c_str(), + quant_type, nullptr, model_base_path, &handle); + } else { + err = loadModelHandle(CAUSAL_LM_BACKEND_CPU, model_type, quant_type, + nullptr, model_base_path, &handle); + } + if (err != CAUSAL_LM_ERROR_NONE) { + print_error("Final loadModel failed (code " + std::to_string(err) + ")"); + return 1; + } + + std::cout << clr::blue << "β”‚" << clr::reset << " "; + print_status("Model loaded successfully", ">>", clr::bold_green); + print_section_end(clr::blue); + + // ── Inference ────────────────────────────────────────────────────────── + print_section("Inference", clr::green); + std::cout << clr::green << "β”‚" << clr::reset << " " << clr::dim + << "Input: " << clr::reset << clr::bold_white << prompt + << clr::reset << "\n"; + std::cout << clr::green << "β”‚" << clr::reset << "\n"; + + const char *outputText = nullptr; + // CausalLMChatMessage msg; + // msg.role = "user"; + // msg.content = prompt; + // err = runModelHandleWithMessages(handle, &msg, 1, true, &outputText); + + // XGrammar Test + auto tool_name = "web_search"; + auto schema = + "{\"type\": \"object\",\"properties\": {\"query\": {\"type\": \"string\", " + "\"description\": \"Search query in the most effective language for " + "results (use Korean for Korean local info, English for global " + "topics)\"},\"count\": {\"type\": \"integer\", \"description\": \"Number " + "of results to return (default 5, max 10)\"}},\"required\": [\"query\"]}"; + err = runModelHandleWithTool(handle, prompt, &outputText, tool_name, schema); + + if (err != CAUSAL_LM_ERROR_NONE) { + print_error("Inference failed (code " + std::to_string(err) + ")"); + return 1; + } + + if (outputText) { + std::cout << clr::green << "β”‚" << clr::reset << " " << clr::dim + << "Output:" << clr::reset << "\n"; + std::cout << clr::green << "β”‚" << clr::reset << " " << clr::bold_white + << outputText << clr::reset << "\n"; + } + print_section_end(clr::green); + + // ── Performance Metrics ──────────────────────────────────────────────── + print_section("Performance Metrics", clr::magenta); + + PerformanceMetrics metrics; + memset(&metrics, 0, sizeof(metrics)); + err = getPerformanceMetricsHandle(handle, &metrics); + if (err == CAUSAL_LM_ERROR_NONE) { + double prefill_tps = + metrics.prefill_duration_ms > 0 + ? metrics.prefill_tokens / metrics.prefill_duration_ms * 1000.0 + : 0.0; + double gen_tps = + metrics.generation_duration_ms > 0 + ? metrics.generation_tokens / metrics.generation_duration_ms * 1000.0 + : 0.0; + + std::ostringstream oss; + + oss << metrics.initialization_duration_ms << " ms"; + print_kv("Init", oss.str(), clr::magenta); + + oss.str(""); + oss << metrics.prefill_tokens << " tokens / " << metrics.prefill_duration_ms + << " ms (" << std::fixed << std::setprecision(1) << prefill_tps + << " tok/s)"; + print_kv("Prefill", oss.str(), clr::magenta); + + oss.str(""); + oss << metrics.generation_tokens << " tokens / " + << metrics.generation_duration_ms << " ms (" << std::fixed + << std::setprecision(1) << gen_tps << " tok/s)"; + print_kv("Generation", oss.str(), clr::magenta); + + oss.str(""); + oss << metrics.total_duration_ms << " ms"; + print_kv("Total", oss.str(), clr::magenta); + + oss.str(""); + oss << metrics.peak_memory_kb << " KB"; + print_kv("Peak Memory", oss.str(), clr::magenta); + + // ── Metric validation ──────────────────────────────────────────────── + bool metrics_ok = true; + if (metrics.prefill_tokens == 0) { + std::cout << clr::magenta << "β”‚" << clr::reset << " " << clr::yellow + << "⚠ Warning: prefill_tokens is zero" << clr::reset << "\n"; + metrics_ok = false; + } + if (metrics.generation_tokens == 0) { + std::cout << clr::magenta << "β”‚" << clr::reset << " " << clr::yellow + << "⚠ Warning: generation_tokens is zero" << clr::reset << "\n"; + metrics_ok = false; + } + if (metrics.generation_duration_ms <= 0) { + std::cout << clr::magenta << "β”‚" << clr::reset << " " << clr::yellow + << "⚠ Warning: generation_duration_ms is zero/negative" + << clr::reset << "\n"; + metrics_ok = false; + } + if (metrics.total_duration_ms < + metrics.prefill_duration_ms + metrics.generation_duration_ms - 0.001) { + std::cout << clr::magenta << "β”‚" << clr::reset << " " << clr::yellow + << "⚠ Warning: total_duration_ms < prefill + generation" + << clr::reset << "\n"; + metrics_ok = false; + } + if (metrics_ok) { + std::cout << clr::magenta << "β”‚" << clr::reset << " " << clr::green + << "βœ“ All metric sanity checks passed" << clr::reset << "\n"; + } + } else { + std::cout << clr::magenta << "β”‚" << clr::reset << " " << clr::dim + << "(metrics not available)" << clr::reset << "\n"; + } + print_section_end(clr::magenta); + + // ── Cleanup ──────────────────────────────────────────────────────────── + destroyModelHandle(handle); + + // ── Done ─────────────────────────────────────────────────────────────── + std::cout << clr::bold_green << " Done." << clr::reset << "\n\n"; + + return 0; +} diff --git a/api/README.md b/api/README.md index 7bb344e3..75b31a2e 100644 --- a/api/README.md +++ b/api/README.md @@ -1,161 +1,417 @@ -# CausalLM API +# Quick.AI C API Reference 🧠 -This directory contains the C API for CausalLM application, designed to provide a simple interface for loading and running Large Language Models (LLMs) on various backends, including Android. +The Quick.AI C API is declared in `quick_dot_ai_api.h` and implemented by +`libquick_dot_ai_api.so`. It exposes nntrainer-backed model loading, +handle-based inference, streaming callbacks, multimodal paths, QNN KV-cache +helpers, chat templates, and XGrammar structured generation. -## Overview +## πŸ“š Contents -The API provides functionality to: -- Initialize and configure the CausalLM environment. -- Load pre-trained models with specific quantization settings. -- Run inference (text generation) given a prompt. -- Retrieve performance metrics (token counts, duration). +- [Model Catalog](#-model-catalog) +- [Model Enums](#-model-enums) +- [Core Types](#-core-types) +- [Global Options](#-global-options) +- [Legacy Single Model API](#-legacy-single-model-api) +- [Handle Based API](#-handle-based-api) +- [Streaming](#-streaming) +- [Multimodal](#-multimodal) +- [XGrammar](#-xgrammar) +- [OpenAI JSON Streaming](#-openai-json-streaming) +- [Error Codes](#-error-codes) -## Build & Integration +## Model Catalog -The CausalLM API is built as a separate shared library `libcausallm_api.so`, which depends on the core logic in `libcausallm_core.so`. +The preferred model identification mechanism is a **string model id** routed +through a self-registering descriptor catalog. -### Build Artifacts (Android) +### loadModelHandleByName -- **`libcausallm_core.so`**: Contains the core LLM implementation (Model, Layers, etc.). -- **`libcausallm_api.so`**: Contains the C API implementation (`causal_lm_api.cpp`) and configuration helpers (`model_config.cpp`). +```c +ErrorCode loadModelHandleByName(BackendType compute, + const char *model_id, + ModelQuantizationType quant_type, + const char *native_lib_dir, + const char *model_base_path, + CausalLmHandle *out_handle); +``` -### Linking +`model_id` is a string such as `"qwen3-0.6b"` or `"gemma4-cpu"`. The function +looks up the descriptor in the process-global registry, resolves the config, +and loads the model. This is the **preferred load path** for all new code. -When integrating this API into your application (e.g., via JNI), you must link against both libraries: -1. `libcausallm_api.so` -2. `libcausallm_core.so` -3. `libnntrainer.so` (Dependency) +### getModelCatalogJson -## Directory Structure & Model Loading +```c +const char *getModelCatalogJson(void); +``` -The API strictly relies on registered model types and quantization settings to locate model files. There are two modes of loading, depending on how the model is registered within the library. +Returns a JSON array of all model descriptors registered in the current +process. The returned pointer is valid for the lifetime of the process. Example +output: + +```json +[ + { + "id": "qwen3-0.6b", + "family": "qwen3-0.6b", + "display_name": "Qwen3 0.6B", + "runtime": 0, + "backend_mask": 3, + "capabilities": 9, + "config_name": "qwen3_0_6b", + "arch_string": "Qwen3ForCausalLM" + }, + { + "id": "gemma4-cpu", + "family": "gemma4-cpu", + "display_name": "Gemma4 CPU", + "runtime": 0, + "backend_mask": 1, + "capabilities": 1, + "config_name": "gemma4_cpu", + "arch_string": "Gemma3ForCausalLM" + } +] +``` -**Path Convention:** `./models/{ModelKey}{QuantizationSuffix}/` +`runtime` is `0` for `NATIVE` and `1` for `LITERT`. `backend_mask` is a +bitmask where bit 0 = CPU, bit 1 = GPU, bit 2 = NPU/QNN. `capabilities` is a +bitmask using the `CapabilityFlag` enum flags (bit 0 = STREAMING, bit 1 = +MESSAGES_API, bit 2 = MULTIMODAL, bit 3 = TOOL_USE, bit 4 = EMBEDDING, bit 5 = +MULTI_IMAGE, bit 6 = VISION_ENCODER). -- **ModelKey**: Derived from `ModelType` (e.g., `qwen3-0.6b`). -- **QuantizationSuffix**: Derived from `ModelQuantizationType` (e.g., `-w16a16`). +### ModelDescriptor struct -### 1. Internal/Embedded Configuration (Pre-configured) +```c +typedef struct { + const char *id; + const char *family; + const char *display_name; + int runtime; // 0 = NATIVE, 1 = LITERT + uint32_t backend_mask; + uint32_t capabilities; + const char *config_name; + const char *arch_string; +} ModelDescriptor; +``` -Some model configurations (including architecture, tokenizer settings, and generation parameters) are embedded directly into the CausalLM library code (via `model_config.cpp`). This protects the model specifications and simplifies deployment. +## πŸ€– Model Enums -- **Required Files:** - - **Weight Binary File**: The actual model weights (e.g., `qwen3-0.6b-fp32.bin`). The filename is hardcoded in the internal configuration. - - **Tokenizer Files**: `tokenizer.json` / `vocab.json` (if required by the specific tokenizer implementation). +> **Deprecated.** The `ModelType` enum is a compatibility shim maintained for +> ABI stability. All new code should use string model ids with +> `loadModelHandleByName()` instead. Ordinals are preserved for ABI stability, +> so the values are not contiguous. -- **Ignored Files:** - - `config.json`, `nntr_config.json`, `generation_config.json` are **NOT** loaded from the disk even if they exist. +| Enum | Value | Notes | +|---|---:|---| +| `CAUSAL_LM_MODEL_QWEN3_0_6B` | 0 | Qwen3 0.6B | +| `CAUSAL_LM_MODEL_QWEN3_1_7B_Q40` | 4 | Qwen3 1.7B Q40 | +| `CAUSAL_LM_MODEL_TINY_BERT` | 8 | TinyBERT | +| `CAUSAL_LM_MODEL_FUNCTION_GEMMA` | 9 | Function-calling Gemma | +| `CAUSAL_LM_MODEL_GEMMA4_CPU` | 11 | Gemma4 CPU | +| `CAUSAL_LM_MODEL_GEMMA4_E2B_QNN` | 12 | Gemma4 E2B QNN | +| `CAUSAL_LM_MODEL_VJEPA_QNN` | 13 | V-JEPA QNN (multi-image) | -### 2. External/File-based Configuration +QNN models require Android builds with `--enable-qnn`. -For registered model types that do not have a specific hardcoded configuration for the requested quantization, the API falls back to loading configuration files from the directory. +## 🧱 Core Types -- **Required Files:** - - **`config.json`**: Model architecture configuration (HuggingFace format). - - **`nntr_config.json`**: NNTrainer specific configuration. - - Must contain `"model_file_name"` field pointing to the binary weight file. - - **Weight Binary File**: The file specified in `nntr_config.json`. - - **`generation_config.json`**: (Optional) Generation parameters. - - **Tokenizer Files**: `tokenizer.json` / `vocab.json`. +```c +typedef struct CausalLmModel *CausalLmHandle; + +typedef enum { + CAUSAL_LM_BACKEND_CPU = 0, + CAUSAL_LM_BACKEND_GPU = 1, + CAUSAL_LM_BACKEND_NPU = 2, +} BackendType; + +typedef enum { + CAUSAL_LM_QUANTIZATION_UNKNOWN = 0, + CAUSAL_LM_QUANTIZATION_W4A32 = 1, + CAUSAL_LM_QUANTIZATION_W16A16 = 2, + CAUSAL_LM_QUANTIZATION_W8A16 = 3, + CAUSAL_LM_QUANTIZATION_W32A32 = 4, +} ModelQuantizationType; + +typedef struct { + const char *role; + const char *content; +} CausalLMChatMessage; + +typedef int (*CausalLmTokenCallback)(const char *delta, void *user_data); +``` -**Note:** When `debug_mode` is enabled in `setOptions`, the API will attempt to validate the existence of the required files for the resolved mode during initialization. +`delta` passed to `CausalLmTokenCallback` is valid only during the callback. +Copy it if you need to keep it. -## API Reference +## βš™οΈ Global Options -The main header file is `causal_lm_api.h`. +```c +typedef struct { + bool use_chat_template; + bool debug_mode; + bool verbose; + const char *chat_template_name; +} Config; + +ErrorCode setOptions(Config config); +``` -### Enums +`setOptions()` affects global chat-template/debug behavior for subsequent API +calls in the current process. -#### `ErrorCode` -Return codes for API functions. -- `CAUSAL_LM_ERROR_NONE`: Success. -- `CAUSAL_LM_ERROR_INVALID_PARAMETER`: Invalid argument provided. -- `CAUSAL_LM_ERROR_MODEL_LOAD_FAILED`: Failed to load the model. -- `CAUSAL_LM_ERROR_INFERENCE_FAILED`: Inference execution failed. -- `CAUSAL_LM_ERROR_NOT_INITIALIZED`: API not initialized or model not loaded. -- `CAUSAL_LM_ERROR_INFERENCE_NOT_RUN`: Metrics requested before inference run. +## πŸ•°οΈ Legacy Single Model API -#### `BackendType` -Target backend for computation. -- `CAUSAL_LM_BACKEND_CPU`: CPU execution (default). -- `CAUSAL_LM_BACKEND_GPU`: GPU execution (Planned). -- `CAUSAL_LM_BACKEND_NPU`: NPU execution (Planned). +These functions operate on one process-wide default handle. Prefer the +handle-based API for new code. -#### `ModelType` -Supported pre-defined model types. -- `CAUSAL_LM_MODEL_QWEN3_0_6B`: Qwen3 0.6B model. +```c +ErrorCode loadModel(BackendType compute, ModelType modeltype, + ModelQuantizationType quant_type, + const char *model_base_path); -#### `ModelQuantizationType` -Supported quantization formats. -- `CAUSAL_LM_QUANTIZATION_W4A32`: 4-bit weights, 32-bit activations. -- `CAUSAL_LM_QUANTIZATION_W16A16`: 16-bit weights, 16-bit activations (FP16). -- `CAUSAL_LM_QUANTIZATION_W8A16`: 8-bit weights, 16-bit activations. -- `CAUSAL_LM_QUANTIZATION_W32A32`: 32-bit weights, 32-bit activations (FP32). +ErrorCode getPerformanceMetrics(PerformanceMetrics *metrics); -### Functions +ErrorCode applyChatTemplate(const CausalLMChatMessage *messages, + size_t num_messages, + bool add_generation_prompt, + const char **formattedText); -#### `ErrorCode setOptions(Config config)` -Sets global configuration options. -- **config**: Structure containing options like `use_chat_template` and `debug_mode`. +ErrorCode saveQnnKvCache(const char *cache_path); +ErrorCode loadQnnKvCache(const char *cache_path); +ErrorCode resetQnnKvCache(void); +``` -#### `ErrorCode loadModel(BackendType compute, ModelType modeltype, ModelQuantizationType quant_type)` -Loads a registered model. -- **compute**: Backend to use. -- **modeltype**: Specific model enum. -- **quant_type**: Quantization type. +## 🧩 Handle Based API -#### `ErrorCode runModel(const char *inputTextPrompt, const char **outputText)` -Runs inference on the loaded model. -- **inputTextPrompt**: The input text/prompt. -- **outputText**: Pointer to store the result string. +```c +ErrorCode loadModelHandle(BackendType compute, ModelType modeltype, + ModelQuantizationType quant_type, + const char *native_lib_dir, + const char *model_base_path, + CausalLmHandle *out_handle); + +ErrorCode destroyModelHandle(CausalLmHandle handle); +ErrorCode unloadModelHandle(CausalLmHandle handle); +ErrorCode cancelModelHandle(CausalLmHandle handle); + +ErrorCode getPerformanceMetricsHandle(CausalLmHandle handle, + PerformanceMetrics *metrics); + +ErrorCode saveQnnKvCacheHandle(CausalLmHandle handle, const char *cache_path); +ErrorCode loadQnnKvCacheHandle(CausalLmHandle handle, const char *cache_path); +ErrorCode resetQnnKvCacheHandle(CausalLmHandle handle); +``` + +`native_lib_dir` is mainly used by Android/QNN flows to locate shared +libraries. `model_base_path` is the base directory for model files. -#### `ErrorCode getPerformanceMetrics(PerformanceMetrics *metrics)` -Retrieves performance metrics of the last run. -- **metrics**: Pointer to `PerformanceMetrics` struct to be filled. -- `prefill_tokens`, `prefill_duration_ms`: Stats for prompt processing. -- `generation_tokens`, `generation_duration_ms`: Stats for token generation. -- `total_duration_ms`: Total execution time from start to finish. -- `peak_memory_kb`: Peak resident set size (memory usage) in KB. +For native QNN backend extensions, set +`QUICK_DOT_AI_QNN_BACKEND_EXT_CONFIG_PATH` before loading the model to override +the default `htp_backend_ext_config.json` location. Absolute values are used +as-is. Relative values are resolved by `QNNContext` from +`QUICK_DOT_AI_BASE_DIR` when set, otherwise from the process current working +directory. If no override is set, the C API uses +`/htp_backend_ext_config.json`. -## Usage Example +## πŸ”„ Streaming ```c -#include "causal_lm_api.h" -#include +ErrorCode runModelHandleStreaming(CausalLmHandle handle, + const char *inputTextPrompt, + CausalLmTokenCallback callback, + void *user_data); + +ErrorCode runModelHandleWithMessages( + CausalLmHandle handle, + const CausalLMChatMessage *messages, + size_t num_messages, + bool add_generation_prompt, + const char **outputText); + +ErrorCode runModelHandleWithMessagesStreaming( + CausalLmHandle handle, + const CausalLMChatMessage *messages, + size_t num_messages, + bool add_generation_prompt, + CausalLmTokenCallback callback, + void *user_data); +``` + +Streaming calls are synchronous. They block the calling thread until generation +finishes, fails, or is cancelled, while progressively invoking `callback`. + +### Minimal Streaming Example + +```cpp +#include "quick_dot_ai_api.h" +#include + +int on_token(const char *delta, void *) { + std::cout << delta << std::flush; + return 0; +} int main() { - // 1. Set Options - Config config; - config.use_chat_template = true; - config.debug_mode = false; - setOptions(config); - - // 2. Load Model - // Automatically looks for files in "./models/qwen3-0.6b-w16a16/" - ErrorCode err = loadModel(CAUSAL_LM_BACKEND_CPU, - CAUSAL_LM_MODEL_QWEN3_0_6B, - CAUSAL_LM_QUANTIZATION_W16A16); - - if (err != CAUSAL_LM_ERROR_NONE) { - printf("Failed to load model\n"); - return -1; - } - - // 3. Run Inference - const char* output = NULL; - err = runModel("Hello, how are you?", &output); - - if (err == CAUSAL_LM_ERROR_NONE) { - printf("Response: %s\n", output); - } - - // 4. Check Metrics - PerformanceMetrics metrics; - if (getPerformanceMetrics(&metrics) == CAUSAL_LM_ERROR_NONE) { - printf("Generated %d tokens in %.2f ms\n", - metrics.generation_tokens, metrics.generation_duration_ms); - } - - return 0; + CausalLmHandle handle = nullptr; + ErrorCode err = loadModelHandle(CAUSAL_LM_BACKEND_NPU, + CAUSAL_LM_MODEL_GEMMA4_E2B_QNN, + CAUSAL_LM_QUANTIZATION_W4A32, + nullptr, + "/models", + &handle); + if (err != CAUSAL_LM_ERROR_NONE) return err; + + err = runModelHandleStreaming(handle, "Hello!", on_token, nullptr); + destroyModelHandle(handle); + return err; } ``` + +## πŸ–ΌοΈ Multimodal + +```c +ErrorCode runMultimodalHandleStreaming( + CausalLmHandle handle, + const char *prompt, + const float *pixelValues, + int numPatches, + int originalHeight, + int originalWidth, + CausalLmTokenCallback callback, + void *user_data); + +ErrorCode runMultimodalHandleWithMessages( + CausalLmHandle handle, + const CausalLMChatMessage *messages, + size_t num_messages, + bool add_generation_prompt, + const float *pixelValues, + int numPatches, + int originalHeight, + int originalWidth, + const char **outputText); + +ErrorCode runMultimodalHandleWithMessagesStreaming( + CausalLmHandle handle, + const CausalLMChatMessage *messages, + size_t num_messages, + bool add_generation_prompt, + const float *pixelValues, + int numPatches, + int originalHeight, + int originalWidth, + CausalLmTokenCallback callback, + void *user_data); +``` + +The handle must be loaded from a model configuration that supplies the expected +vision encoder + LLM sub-models. Unsupported handles return +`CAUSAL_LM_ERROR_UNSUPPORTED`. + +### Pairing a vision encoder with an LLM + +`loadMultimodalHandleByName()` builds one multimodal handle from two catalog +ids: a vision/embedding model and an LLM. The resulting handle has +`models[0]` = embedding producer and `models[1]` = LLM, driven together by the +multimodal run path. + +```c +ErrorCode loadMultimodalHandleByName(BackendType compute, + const char *embedding_model_id, + const char *llm_model_id, + ModelQuantizationType quant_type, + const char *native_lib_dir, + const char *model_base_path, + CausalLmHandle *out_handle); +``` + +Returns `CAUSAL_LM_ERROR_UNSUPPORTED` if the pair is incompatible (for example, +the chosen LLM exposes no embedding table). + +### Multi-image input + +For models that accept several images in one prompt (e.g. `vjepa-qnn`), use the +multi-image streaming functions: + +```c +ErrorCode runMultimodalMultiImageHandleStreaming( + CausalLmHandle handle, + const char *prompt, + const float *pixelValues, + int numPatches, + int numImages, + const int *patchesPerImage, + const int *originalHeights, + const int *originalWidths, + CausalLmTokenCallback callback, + void *user_data); + +ErrorCode runMultimodalMultiImageHandleWithMessagesStreaming( + CausalLmHandle handle, + const CausalLMChatMessage *messages, + size_t num_messages, + bool add_generation_prompt, + const float *pixelValues, + int numPatches, + int numImages, + const int *patchesPerImage, + const int *originalHeights, + const int *originalWidths, + CausalLmTokenCallback callback, + void *user_data); +``` + +## 🧰 XGrammar + +```c +ErrorCode runModelHandleWithTool(CausalLmHandle handle, + const char *inputTextPrompt, + const char **outputText, + const char *tool_name, + const char *tool_schema); +``` + +If a model directory contains `Toolset.json`, tools are precompiled at model +load. For dynamic schemas, pass `tool_schema` on first use. + +See [`../docs/ChatAndOpenAIUsage.md`](../docs/ChatAndOpenAIUsage.md) for usage +examples and [`../docs/XGrammarReference.md`](../docs/XGrammarReference.md) +for XGrammar internals. + +## πŸ“‘ OpenAI JSON Streaming + +```c +ErrorCode runModelHandleWithJsonStreaming(CausalLmHandle handle, + const char *jsonRequest, + CausalLmTokenCallback callback, + void *user_data); +``` + +`jsonRequest` accepts OpenAI-style request JSON, including `messages`, `tools`, +and legacy `functions`. A chat template must be available from the loaded model +directory or the call returns `CAUSAL_LM_ERROR_UNSUPPORTED`. + +See [`../docs/ChatAndOpenAIUsage.md`](../docs/ChatAndOpenAIUsage.md) for usage +examples and request routing details. + +## ❌ Error Codes + +| Code | Constant | Meaning | +|---:|---|---| +| 0 | `CAUSAL_LM_ERROR_NONE` | Success | +| 1 | `CAUSAL_LM_ERROR_INVALID_PARAMETER` | Null pointer, invalid request, bad JSON, or invalid argument | +| 2 | `CAUSAL_LM_ERROR_MODEL_LOAD_FAILED` | Model/config/weight load failed | +| 3 | `CAUSAL_LM_ERROR_INFERENCE_FAILED` | Runtime inference failure | +| 4 | `CAUSAL_LM_ERROR_NOT_INITIALIZED` | Handle/model is not loaded | +| 5 | `CAUSAL_LM_ERROR_INFERENCE_NOT_RUN` | Metrics requested before inference | +| 6 | `CAUSAL_LM_ERROR_UNSUPPORTED` | Feature unsupported by this build/model/handle | +| 99 | `CAUSAL_LM_ERROR_UNKNOWN` | Unknown internal failure | + +## πŸ“Ž Related Docs + +- [Main README](../README.md) +- [Android AAR API](../Android/QuickDotAI/README.md) +- [Chat and OpenAI Usage Examples](../docs/ChatAndOpenAIUsage.md) +- [Chat Templates](../docs/ChatTemplate.md) +- [XGrammar Reference](../docs/XGrammarReference.md) diff --git a/api/callback_streamer.cpp b/api/callback_streamer.cpp new file mode 100644 index 00000000..e315a63c --- /dev/null +++ b/api/callback_streamer.cpp @@ -0,0 +1,56 @@ +// SPDX-License-Identifier: Apache-2.0 +/** + * Copyright (C) 2026 Samsung Electronics Co., Ltd. All Rights Reserved. + * + * @file callback_streamer.cpp + * @brief Implementation of the CallbackStreamer vtable, which routes + * every decoded-token delta to a user-supplied callback. + * See AsyncAndStreaming.md Β§3.2 at the repo root. + */ + +#include "callback_streamer.h" + +extern "C" { + +static int callback_streamer_put(BaseStreamer *self, const char *decoded_utf8) { + CallbackStreamer *cs = reinterpret_cast(self); + if (cs == nullptr || cs->callback == nullptr) { + return 0; + } + // Once the user has asked us to cancel, keep returning the sticky + // cancellation flag β€” this protects against the (cheap but real) + // race where the CausalLM generation loop emits one extra token + // between setting stop_requested_ and actually breaking out. + if (cs->cancelled != 0) { + return cs->cancelled; + } + int rc = cs->callback(decoded_utf8, cs->user_data); + if (rc != 0) { + cs->cancelled = rc; + } + return rc; +} + +static void callback_streamer_end(BaseStreamer * /*self*/) { + // Intentionally empty. Stream termination is reported to the caller + // through the return value of runModelHandleStreaming(); there is no + // "done" payload to forward here. +} + +static const BaseStreamerVTable kCallbackStreamerVTable = { + /*.put =*/&callback_streamer_put, + /*.end =*/&callback_streamer_end, +}; + +void callback_streamer_init(CallbackStreamer *self, CausalLmTokenCallback cb, + void *user_data) { + if (self == nullptr) { + return; + } + self->base.vtable = &kCallbackStreamerVTable; + self->callback = cb; + self->user_data = user_data; + self->cancelled = 0; +} + +} // extern "C" \ No newline at end of file diff --git a/api/callback_streamer.h b/api/callback_streamer.h new file mode 100644 index 00000000..fa7dd77b --- /dev/null +++ b/api/callback_streamer.h @@ -0,0 +1,75 @@ +// SPDX-License-Identifier: Apache-2.0 +/** + * Copyright (C) 2026 Samsung Electronics Co., Ltd. All Rights Reserved. + * + * @file callback_streamer.h + * @brief BaseStreamer implementation that forwards every delta to a + * user-supplied C function pointer. + * + * This is the streamer used by the JNI bridge in QuickAI: the Kotlin + * side hands the JNI entry point a listener object, and the JNI entry + * point wraps the listener in a CausalLmTokenCallback + user_data pair + * and pushes a CallbackStreamer onto its own stack frame for the + * duration of the call. + * + * See AsyncAndStreaming.md Β§3.2 at the repo root. + */ +#ifndef __QUICK_DOT_AI_CALLBACK_STREAMER_H__ +#define __QUICK_DOT_AI_CALLBACK_STREAMER_H__ + +#ifndef WIN_EXPORT +#ifdef _WIN32 +#define WIN_EXPORT __declspec(dllexport) +#else +#define WIN_EXPORT +#endif +#endif + +#include "streamer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * @brief Token callback signature. + * + * @param delta UTF-8 text produced for this token boundary. Valid + * only for the duration of the call β€” copy before + * retaining. + * @param user_data Opaque pointer passed through from the + * runModelHandleStreaming() caller. + * @return 0 to continue generation, non-zero to request cancellation. + */ +typedef int (*CausalLmTokenCallback)(const char *delta, void *user_data); + +/** + * @brief A BaseStreamer that forwards every put() to a + * CausalLmTokenCallback. + * + * Layout note: @c base MUST be the first member so that a + * `CallbackStreamer*` can be safely cast to `BaseStreamer*`. + */ +typedef struct { + BaseStreamer base; + CausalLmTokenCallback callback; + void *user_data; + int cancelled; /**< sticky: once set to non-zero, put() becomes a no-op. */ +} CallbackStreamer; + +/** + * @brief Initialize a CallbackStreamer in-place. Does not allocate. + * + * @param self Storage owned by the caller (typically stack). + * @param cb Callback to invoke for every delta. Must be non-NULL. + * @param user_data Opaque pointer forwarded to @c cb. + */ +WIN_EXPORT void callback_streamer_init(CallbackStreamer *self, + CausalLmTokenCallback cb, + void *user_data); + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif // __QUICK_DOT_AI_CALLBACK_STREAMER_H__ \ No newline at end of file diff --git a/api/causal_lm_api.cpp b/api/causal_lm_api.cpp deleted file mode 100644 index 45166934..00000000 --- a/api/causal_lm_api.cpp +++ /dev/null @@ -1,627 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -/** - * Copyright (C) 2026 Samsung Electronics Co., Ltd. All Rights Reserved. - * - * @file causal_lm_api.cpp - * @date 21 Jan 2026 - * @brief This is a C API for CausalLM application - * @see https://github.com/nntrainer/nntrainer - * @author Eunju Yang - * @bug No known bugs except for NYI items - */ - -#include "causal_lm_api.h" -#include -#include -#include -#include -#include -#include -#include -#include - -#include "causal_lm.h" -#include "chat_template.h" -#include "gemma3_causallm.h" -#include "gptoss_cached_slim_causallm.h" -#include "gptoss_causallm.h" -#include "json.hpp" -#include "model_config_internal.h" -#include "qwen2_causallm.h" -#include "qwen3_cached_slim_moe_causallm.h" -#include "qwen3_causallm.h" -#include "qwen3_moe_causallm.h" -#include "qwen3_slim_moe_causallm.h" -#include -#include -#include -#include - -using json = nlohmann::json; - -static std::unique_ptr g_model; -static std::mutex g_mutex; -static bool g_initialized = false; -static std::string g_architecture = ""; -static bool g_use_chat_template = false; -static bool g_verbose = false; -static std::string g_last_output = ""; -static double g_initialization_duration_ms = 0.0; -static quick_dot_ai::ChatTemplate g_chat_template; - -static std::map g_model_path_map = { - {"QWEN3-0.6B", "qwen3-0.6b"}, -}; - -/** - * @brief RegisteredModel - */ -struct RegisteredModel { - std::string arch_name; - ModelRuntimeConfig config; -}; -static std::map g_model_registry; -static std::map g_arch_config_map; - -// Helper to register models (similar to main.cpp) -// ensuring factory is populated. -// @note: Factory registration is singleton and persistent, but we do it once -// here to be sure. Since main.cpp is not linked, we must duplicate registration -// or share it. Assuming this lib is used independently of main.cpp. -static void register_models() { - static std::once_flag flag; - std::call_once(flag, []() { - quick_dot_ai::Factory::Instance().registerModel( - "LlamaForCausalLM", [](json cfg, json generation_cfg, json nntr_cfg) { - return std::make_unique(cfg, generation_cfg, - nntr_cfg); - }); - quick_dot_ai::Factory::Instance().registerModel( - "Qwen2ForCausalLM", [](json cfg, json generation_cfg, json nntr_cfg) { - return std::make_unique(cfg, generation_cfg, - nntr_cfg); - }); - quick_dot_ai::Factory::Instance().registerModel( - "Qwen3ForCausalLM", [](json cfg, json generation_cfg, json nntr_cfg) { - return std::make_unique(cfg, generation_cfg, - nntr_cfg); - }); - quick_dot_ai::Factory::Instance().registerModel( - "Qwen3MoeForCausalLM", [](json cfg, json generation_cfg, json nntr_cfg) { - return std::make_unique(cfg, generation_cfg, - nntr_cfg); - }); - quick_dot_ai::Factory::Instance().registerModel( - "Qwen3SlimMoeForCausalLM", - [](json cfg, json generation_cfg, json nntr_cfg) { - return std::make_unique( - cfg, generation_cfg, nntr_cfg); - }); - quick_dot_ai::Factory::Instance().registerModel( - "Qwen3CachedSlimMoeForCausalLM", - [](json cfg, json generation_cfg, json nntr_cfg) { - return std::make_unique( - cfg, generation_cfg, nntr_cfg); - }); - quick_dot_ai::Factory::Instance().registerModel( - "GptOssForCausalLM", [](json cfg, json generation_cfg, json nntr_cfg) { - return std::make_unique( - cfg, generation_cfg, nntr_cfg); - }); - quick_dot_ai::Factory::Instance().registerModel( - "GptOssCachedSlimCausalLM", - [](json cfg, json generation_cfg, json nntr_cfg) { - return std::make_unique( - cfg, generation_cfg, nntr_cfg); - }); - quick_dot_ai::Factory::Instance().registerModel( - "Gemma3ForCausalLM", [](json cfg, json generation_cfg, json nntr_cfg) { - return std::make_unique(cfg, generation_cfg, - nntr_cfg); - }); - - // Register built-in configurations - register_builtin_model_configs(); - }); -} - -static const char *get_model_name_from_type(ModelType type) { - switch (type) { - case CAUSAL_LM_MODEL_QWEN3_0_6B: - return "QWEN3-0.6B"; - default: - return nullptr; - } -} - -static std::string apply_chat_template(const std::string &architecture, - const std::string &input) { - // Use dynamic chat template from tokenizer_config.json if available - if (g_chat_template.isAvailable()) { - return g_chat_template.apply(input); - } - - // Fallback: hardcoded per-architecture templates - if (architecture == "LlamaForCausalLM") { - // Llama 2/3 chat format: [INST] {prompt} [/INST] - return "[INST] " + input + " [/INST]"; - } else if (architecture == "Qwen2ForCausalLM" || - architecture == "Qwen3ForCausalLM" || - architecture == "Qwen3MoeForCausalLM" || - architecture == "Qwen3SlimMoeForCausalLM" || - architecture == "Qwen3CachedSlimMoeForCausalLM") { - // Qwen chat format - // <|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n - return "<|im_start|>user\n" + input + "<|im_end|>\n<|im_start|>assistant\n"; - } else if (architecture == "Gemma3ForCausalLM") { - // Gemma chat format: - // user\n{prompt}\nmodel\n - return "user\n" + input + - "\nmodel\n"; - } - return input; -} - -static std::string get_quantization_suffix(ModelQuantizationType type) { - switch (type) { - case CAUSAL_LM_QUANTIZATION_W4A32: - return "-w4a32"; - case CAUSAL_LM_QUANTIZATION_W16A16: - return "-w16a16"; - case CAUSAL_LM_QUANTIZATION_W8A16: - return "-w8a16"; - case CAUSAL_LM_QUANTIZATION_W32A32: - return "-w32a32"; - default: // W4A32 by default - return "-w4a32"; - } -} - -static std::string resolve_model_path(const std::string &model_key, - ModelQuantizationType quant_type) { - std::string path_upper = model_key; - std::transform(path_upper.begin(), path_upper.end(), path_upper.begin(), - ::toupper); - - std::string base_dir_name = ""; - - // 1. Try to find base directory name from map - if (g_model_path_map.find(path_upper) != g_model_path_map.end()) { - base_dir_name = g_model_path_map[path_upper]; - } else { - // Fallback: use lowercased key as base dir name if not found in map - // or just return empty? For restricted API, we should probably fail - // earlier, but here we can return constructed path. - base_dir_name = path_upper; - std::transform(base_dir_name.begin(), base_dir_name.end(), - base_dir_name.begin(), ::tolower); - } - - std::string model_path = - "./models/" + base_dir_name + get_quantization_suffix(quant_type); - - return model_path; -} - -static bool check_file_exists(const std::string &path) { - struct stat buffer; - return (stat(path.c_str(), &buffer) == 0); -} - -static void validate_models() { - std::cout << "[DEBUG] Validating model files..." << std::endl; - // Iterate over all known model names in map - for (auto const &[key, val] : g_model_path_map) { - // We want to check for each Quantization Type if it exists - // List of quant types to check: UNKNOWN (default), W4A32, W16A16, W32A32 - std::vector quant_types = { - CAUSAL_LM_QUANTIZATION_UNKNOWN, CAUSAL_LM_QUANTIZATION_W4A32, - CAUSAL_LM_QUANTIZATION_W16A16, CAUSAL_LM_QUANTIZATION_W32A32}; - - for (auto qt : quant_types) { - std::string quant_suffix = get_quantization_suffix(qt); - - std::string lookup_key = key; - if (qt != CAUSAL_LM_QUANTIZATION_UNKNOWN) { - std::transform(quant_suffix.begin(), quant_suffix.end(), - quant_suffix.begin(), ::toupper); // "-W4A32" - lookup_key += quant_suffix; - } - - // Resolve path for this combination - std::string resolved_path = resolve_model_path(key, qt); - - if (g_model_registry.find(lookup_key) != g_model_registry.end()) { - // CASE 1: Configuration is registered in model_config.cpp - // For these models, we only check if the binary weight file exists. - // The configurations (config.json, etc.) are embedded in the library. - RegisteredModel &rm = g_model_registry[lookup_key]; - std::string bin_file_name = rm.config.model_file_name; - std::string full_path = resolved_path + "/" + bin_file_name; - - if (check_file_exists(full_path)) { - std::cout << " [OK] Reg Config: " << lookup_key << " -> " - << full_path << std::endl; - } else { - std::cout << " [FAIL] Reg Config: " << lookup_key - << " -> Missing binary: " << full_path << std::endl; - } - - } else { - // CASE 2: No internal config, but model type exists (via map - // iteration). For these models, we require external configuration files - // (config.json, nntr_config.json) to be present in the directory. - if (check_file_exists(resolved_path)) { - bool has_config = check_file_exists(resolved_path + "/config.json"); - bool has_nntr = - check_file_exists(resolved_path + "/nntr_config.json"); - - if (has_config && has_nntr) { - std::cout << " [OK] External Config: " << lookup_key << " -> " - << resolved_path << std::endl; - // Optional: Parse nntr_config to check bin - try { - json nntr = - quick_dot_ai::LoadJsonFile(resolved_path + "/nntr_config.json"); - if (nntr.contains("model_file_name")) { - std::string bin = nntr["model_file_name"]; - if (check_file_exists(resolved_path + "/" + bin)) { - std::cout << " (Binary confirmed: " << bin << ")" - << std::endl; - } else { - std::cout << " (MISSING BINARY: " << bin << ")" - << std::endl; - } - } - } catch (...) { - } - } else { - std::cout << " [FAIL] External Config: " << lookup_key - << " -> Missing configs in " << resolved_path - << std::endl; - } - } - } - } - } -} - -ErrorCode setOptions(Config config) { - // Currently no options are being handled - g_use_chat_template = config.use_chat_template; - g_verbose = config.verbose; - if (config.debug_mode) { - // Ensure models are registered so we can validate them - register_models(); - validate_models(); - } - return CAUSAL_LM_ERROR_NONE; -} - -ErrorCode registerModelArchitecture(const char *arch_name, - ModelArchConfig config) { - if (arch_name == nullptr) - return CAUSAL_LM_ERROR_INVALID_PARAMETER; - std::lock_guard lock(g_mutex); - std::string name(arch_name); - std::transform(name.begin(), name.end(), name.begin(), ::toupper); - g_arch_config_map[name] = config; - return CAUSAL_LM_ERROR_NONE; -} - -ErrorCode registerModel(const char *model_name, const char *arch_name, - ModelRuntimeConfig config) { - if (model_name == nullptr || arch_name == nullptr) - return CAUSAL_LM_ERROR_INVALID_PARAMETER; - std::lock_guard lock(g_mutex); - std::string name(model_name); - std::transform(name.begin(), name.end(), name.begin(), ::toupper); - - std::string aname(arch_name); - std::transform(aname.begin(), aname.end(), aname.begin(), ::toupper); - - g_model_registry[name] = {aname, config}; - return CAUSAL_LM_ERROR_NONE; -} - -ErrorCode loadModel(BackendType compute, ModelType modeltype, - ModelQuantizationType quant_type) { - - auto start_init = std::chrono::high_resolution_clock::now(); - - const char *target_model_name = get_model_name_from_type(modeltype); - if (target_model_name == nullptr) { - return CAUSAL_LM_ERROR_INVALID_PARAMETER; - } - - // Ensure models/configs are registered (thread-safe via call_once) - register_models(); - - std::lock_guard lock(g_mutex); - try { - - // Check if it's a registered in-memory config - std::string input_name = std::string(target_model_name); - std::string input_name_upper = input_name; - std::transform(input_name_upper.begin(), input_name_upper.end(), - input_name_upper.begin(), ::toupper); - - std::string quant_suffix = ""; - switch (quant_type) { - case CAUSAL_LM_QUANTIZATION_W4A32: - quant_suffix = "-W4A32"; - break; - case CAUSAL_LM_QUANTIZATION_W16A16: - quant_suffix = "-W16A16"; - break; - case CAUSAL_LM_QUANTIZATION_W8A16: - quant_suffix = "-W8A16"; - break; - case CAUSAL_LM_QUANTIZATION_W32A32: - quant_suffix = "-W32A32"; - break; - default: - break; - } - std::string lookup_name = input_name_upper + quant_suffix; - - json cfg; - json generation_cfg; - json nntr_cfg; - std::string model_dir_path; - - // Check in-memory map first - if (g_model_registry.find(lookup_name) != g_model_registry.end()) { - // ------------------------------------------------------------------------ - // CASE 1: Model Configuration is Internal (Registered in - // model_config.cpp) - // ------------------------------------------------------------------------ - // In this case, we do NOT load config.json or nntr_config.json from disk. - // We only locate the binary weight file. - RegisteredModel &rm = g_model_registry[lookup_name]; - - // Find architecture config - if (g_arch_config_map.find(rm.arch_name) == g_arch_config_map.end()) { - std::cerr << "Architecture '" << rm.arch_name - << "' not found for model '" << lookup_name << "'" - << std::endl; - return CAUSAL_LM_ERROR_MODEL_LOAD_FAILED; - } - ModelArchConfig &ac = g_arch_config_map[rm.arch_name]; - ModelRuntimeConfig &rc = rm.config; - - // Strategy: Resolve path to find the weight file - model_dir_path = resolve_model_path(target_model_name, quant_type); - - // Populate JSONs from Arch Struct - cfg["vocab_size"] = ac.vocab_size; - cfg["hidden_size"] = ac.hidden_size; - cfg["intermediate_size"] = ac.intermediate_size; - cfg["num_hidden_layers"] = ac.num_hidden_layers; - cfg["num_attention_heads"] = ac.num_attention_heads; - cfg["head_dim"] = ac.head_dim; - cfg["num_key_value_heads"] = ac.num_key_value_heads > 0 - ? ac.num_key_value_heads - : ac.num_attention_heads; - cfg["max_position_embeddings"] = ac.max_position_embeddings; - cfg["rope_theta"] = ac.rope_theta; - cfg["rms_norm_eps"] = ac.rms_norm_eps; - cfg["tie_word_embeddings"] = ac.tie_word_embeddings; - if (ac.sliding_window != UINT_MAX) { - cfg["sliding_window"] = ac.sliding_window; - } else { - cfg["sliding_window"] = nullptr; - } - cfg["sliding_window_pattern"] = ac.sliding_window_pattern; - cfg["architectures"] = {std::string(ac.architecture)}; - - if (ac.num_eos_token_ids > 0) { - std::vector eos_ids; - for (unsigned int i = 0; i < ac.num_eos_token_ids; ++i) - eos_ids.push_back(ac.eos_token_ids[i]); - generation_cfg["eos_token_id"] = eos_ids; - } - generation_cfg["bos_token_id"] = ac.bos_token_id; - - // Populate JSONs from Runtime Struct - generation_cfg["top_k"] = rc.top_k; - generation_cfg["top_p"] = rc.top_p; - generation_cfg["temperature"] = rc.temperature; - generation_cfg["do_sample"] = false; - - nntr_cfg["batch_size"] = rc.batch_size; - nntr_cfg["model_type"] = std::string(rc.model_type); - nntr_cfg["model_tensor_type"] = std::string(rc.model_tensor_type); - nntr_cfg["init_seq_len"] = rc.init_seq_len; - nntr_cfg["max_seq_len"] = rc.max_seq_len; - nntr_cfg["num_to_generate"] = rc.num_to_generate; - nntr_cfg["fsu"] = rc.fsu; - nntr_cfg["fsu_lookahead"] = rc.fsu_lookahead; - nntr_cfg["embedding_dtype"] = std::string(rc.embedding_dtype); - nntr_cfg["fc_layer_dtype"] = std::string(rc.fc_layer_dtype); - nntr_cfg["model_file_name"] = std::string(rc.model_file_name); - - std::string t_file = rc.tokenizer_file; - nntr_cfg["tokenizer_file"] = model_dir_path + "/" + t_file; - - if (strlen(rc.lmhead_dtype) > 0) { - nntr_cfg["lmhead_dtype"] = std::string(rc.lmhead_dtype); - } - - std::vector bad_ids; - for (unsigned int i = 0; i < rc.num_bad_word_ids; ++i) - bad_ids.push_back(rc.bad_word_ids[i]); - nntr_cfg["bad_word_ids"] = bad_ids; - - } else { - // -------------------------------------------------- - // CASE 2: External Model Configuration (File-based) - // -------------------------------------------------- - // The model type is registered (enum), but specific configuration for - // this quantization is not in memory. We must load config.json and - // nntr_config.json from the model directory - model_dir_path = resolve_model_path(target_model_name, quant_type); - - // Load configuration files - cfg = quick_dot_ai::LoadJsonFile(model_dir_path + "/config.json"); - generation_cfg = - quick_dot_ai::LoadJsonFile(model_dir_path + "/generation_config.json"); - nntr_cfg = quick_dot_ai::LoadJsonFile(model_dir_path + "/nntr_config.json"); - - if (nntr_cfg.contains("tokenizer_file")) { - std::string t_file = nntr_cfg["tokenizer_file"]; - nntr_cfg["tokenizer_file"] = model_dir_path + "/" + t_file; - } - } - - // Load chat template from tokenizer_config.json if available - std::string tc_path = model_dir_path + "/tokenizer_config.json"; - if (check_file_exists(tc_path)) { - g_chat_template = quick_dot_ai::ChatTemplate::fromFile(tc_path); - if (g_chat_template.isAvailable()) { - std::cout << "[Info] Chat template loaded from tokenizer_config.json" - << std::endl; - } else { - std::cerr - << "[Warning] tokenizer_config.json found but chat template could " - "not be loaded. Falling back to hardcoded templates." - << std::endl; - } - } else { - g_chat_template = quick_dot_ai::ChatTemplate(); - std::cerr << "[Warning] tokenizer_config.json not found in " - << model_dir_path << ". Using hardcoded chat templates." - << std::endl; - } - - // Construct weight file path - std::string weight_file_name; - if (nntr_cfg.contains("model_file_name")) { - weight_file_name = nntr_cfg["model_file_name"].get(); - } else { - weight_file_name = - "pytorch_model.bin"; // Default fallback if not specified - } - - const std::string weight_file = model_dir_path + "/" + weight_file_name; - - // Determine architecture from config or ModelType - // Priority: Config file architecture > ModelType mapping (fallback) - std::string architecture; - if (cfg.contains("architectures") && cfg["architectures"].is_array() && - !cfg["architectures"].empty()) { - architecture = cfg["architectures"].get>()[0]; - } else { - // No fallback mapping from specific ModelType instances to generic - // architecture strings for now, as specific types should have config or - // be loaded from valid file with config.json - return CAUSAL_LM_ERROR_INVALID_PARAMETER; - } - - g_model = quick_dot_ai::Factory::Instance().create(architecture, cfg, - generation_cfg, nntr_cfg); - if (!g_model) { - return CAUSAL_LM_ERROR_MODEL_LOAD_FAILED; - } - - g_model->initialize(); - g_model->load_weight(weight_file); - - g_initialized = true; - g_architecture = architecture; - - auto finish_init = std::chrono::high_resolution_clock::now(); - auto init_duration = std::chrono::duration_cast( - finish_init - start_init); - g_initialization_duration_ms = init_duration.count(); - - } catch (const std::exception &e) { - std::cerr << "Exception in loadModel: " << e.what() << std::endl; - return CAUSAL_LM_ERROR_MODEL_LOAD_FAILED; - } catch (...) { - std::cerr << "Unknown exception in loadModel" << std::endl; - return CAUSAL_LM_ERROR_MODEL_LOAD_FAILED; - } - - return CAUSAL_LM_ERROR_NONE; -} - -ErrorCode runModel(const char *inputTextPrompt, const char **outputText) { - if (!g_initialized || !g_model) { - return CAUSAL_LM_ERROR_NOT_INITIALIZED; - } - if (inputTextPrompt == nullptr || outputText == nullptr) { - return CAUSAL_LM_ERROR_INVALID_PARAMETER; - } - - try { - std::lock_guard lock(g_mutex); - - std::string input(inputTextPrompt); - - if (g_use_chat_template) { - input = apply_chat_template(g_architecture, input); - } - -// We assume single batch request for this API -#if defined(_WIN32) - g_model->run(std::wstring(input.begin(), input.end()), L"", L"", nullptr, - g_verbose); -#else - g_model->run(input, "", "", nullptr, g_verbose); -#endif - - auto causal_lm_model = dynamic_cast(g_model.get()); - g_last_output = ""; // Reset last output - if (causal_lm_model) { - g_last_output = causal_lm_model->getOutput(0); - } - - *outputText = g_last_output.c_str(); - - } catch (const std::exception &e) { - std::cerr << "Exception in runModel: " << e.what() << std::endl; - return CAUSAL_LM_ERROR_INFERENCE_FAILED; - } - - return CAUSAL_LM_ERROR_NONE; -} - -ErrorCode getPerformanceMetrics(PerformanceMetrics *metrics) { - if (!g_initialized || !g_model) { - return CAUSAL_LM_ERROR_NOT_INITIALIZED; - } - if (metrics == nullptr) { - return CAUSAL_LM_ERROR_INVALID_PARAMETER; - } - - try { - std::lock_guard lock(g_mutex); - auto causal_lm_model = dynamic_cast(g_model.get()); - - if (causal_lm_model) { - if (!causal_lm_model->hasRun()) { - return CAUSAL_LM_ERROR_INFERENCE_NOT_RUN; - } - auto internal_metrics = causal_lm_model->getPerformanceMetrics(); - metrics->prefill_tokens = internal_metrics.prefill_tokens; - metrics->prefill_duration_ms = internal_metrics.prefill_duration_ms; - metrics->generation_tokens = internal_metrics.generation_tokens; - metrics->generation_duration_ms = internal_metrics.generation_duration_ms; - metrics->total_duration_ms = internal_metrics.total_duration_ms; - metrics->peak_memory_kb = internal_metrics.peak_memory_kb; - - // Overwrite init duration with the one measured in loadModel API - metrics->initialization_duration_ms = g_initialization_duration_ms; - } else { - return CAUSAL_LM_ERROR_UNKNOWN; - } - - } catch (const std::exception &e) { - std::cerr << "Exception in getPerformanceMetrics: " << e.what() - << std::endl; - return CAUSAL_LM_ERROR_UNKNOWN; - } - - return CAUSAL_LM_ERROR_NONE; -} diff --git a/api/causal_lm_api.h b/api/causal_lm_api.h deleted file mode 100644 index 6eb0eafe..00000000 --- a/api/causal_lm_api.h +++ /dev/null @@ -1,128 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -/** - * Copyright (C) 2026 Samsung Electronics Co., Ltd. All Rights Reserved. - * - * @file causal_lm_api.h - * @date 21 Jan 2026 - * @brief This is a C API for CausalLM application - * @see https://github.com/nntrainer/nntrainer - * @author Eunju Yang - * @bug No known bugs except for NYI items - */ -#ifndef __CAUSAL_LM_API_H__ -#define __CAUSAL_LM_API_H__ - -#ifdef _WIN32 -#define WIN_EXPORT __declspec(dllexport) -#else -#define WIN_EXPORT -#endif - -#ifdef __cplusplus -extern "C" { -#endif - -#include - -/** - * @brief Performance Metrics - */ -typedef struct { - unsigned int prefill_tokens; - double prefill_duration_ms; - unsigned int generation_tokens; - double generation_duration_ms; - double total_duration_ms; - double initialization_duration_ms; - size_t peak_memory_kb; -} PerformanceMetrics; - -/** - * @brief Error codes - */ -typedef enum { - CAUSAL_LM_ERROR_NONE = 0, - CAUSAL_LM_ERROR_INVALID_PARAMETER = 1, - CAUSAL_LM_ERROR_MODEL_LOAD_FAILED = 2, - CAUSAL_LM_ERROR_INFERENCE_FAILED = 3, - CAUSAL_LM_ERROR_NOT_INITIALIZED = 4, - CAUSAL_LM_ERROR_INFERENCE_NOT_RUN = 5, - CAUSAL_LM_ERROR_UNKNOWN = 99 -} ErrorCode; - -/** - * @brief Backend compute type - */ -typedef enum { - CAUSAL_LM_BACKEND_CPU = 0, - CAUSAL_LM_BACKEND_GPU = 1, /// < @todo: support gpu - CAUSAL_LM_BACKEND_NPU = 2, /// < @todo: support npu -} BackendType; - -/** - * @brief Model type - * @note Enable only when your library supports the model - */ -typedef enum { - CAUSAL_LM_MODEL_QWEN3_0_6B = 0, -} ModelType; - -/** - * @brief Configuration structure - */ -typedef struct { - // Add configuration options here as needed - bool use_chat_template; /// < @brief Whether to apply chat template to input - bool debug_mode; /// < @brief Check model file validity during initialization - bool verbose; /// < @brief Whether to print output during generation -} Config; - -/** - * @brief Set global options - * @param config Configuration object - * @return ErrorCode - */ -WIN_EXPORT ErrorCode setOptions(Config config); - -/** - * @brief Model Quantization type - */ -typedef enum { - CAUSAL_LM_QUANTIZATION_UNKNOWN = 0, - CAUSAL_LM_QUANTIZATION_W4A32 = 1, ///< 4-bit weights, 32-bit activations - CAUSAL_LM_QUANTIZATION_W16A16 = 2, ///< 16-bit weights, 16-bit activations - CAUSAL_LM_QUANTIZATION_W8A16 = 3, ///< 8-bit weights, 16-bit activations - CAUSAL_LM_QUANTIZATION_W32A32 = 4, ///< 32-bit weights, 32-bit activations -} ModelQuantizationType; - -/** - * @brief Load a model - * @param compute Backend compute type - * @param modeltype Model type - * @param quant_type Model quantization type - * @return ErrorCode - */ -WIN_EXPORT ErrorCode loadModel(BackendType compute, ModelType modeltype, - ModelQuantizationType quant_type); - -/** - * @brief Get performance metrics of the last run - * @param metrics Pointer to PerformanceMetrics struct to be filled - * @return ErrorCode - */ -WIN_EXPORT ErrorCode getPerformanceMetrics(PerformanceMetrics *metrics); - -/** - * @brief Run inference - * @param inputTextPrompt Input prompt - * @param outputText Buffer to store output text - * @return ErrorCode - */ -WIN_EXPORT ErrorCode runModel(const char *inputTextPrompt, - const char **outputText); - -#ifdef __cplusplus -} -#endif - -#endif // __CAUSAL_LM_API_H__ diff --git a/api/jni/Android.mk b/api/jni/Android.mk new file mode 100644 index 00000000..4a549939 --- /dev/null +++ b/api/jni/Android.mk @@ -0,0 +1,86 @@ +LOCAL_PATH := $(call my-dir) + +# ── Path configuration ────────────────────────────────────────────────── +ifndef NNTRAINER_ROOT +NNTRAINER_ROOT := $(LOCAL_PATH)/../../nntrainer +endif + +CAUSALLM_ROOT := $(NNTRAINER_ROOT)/Applications/CausalLM +QUICK_DOT_AI_ROOT := $(LOCAL_PATH)/../../src + +NNTRAINER_INCLUDES := $(NNTRAINER_ROOT)/builddir/android_build_result/include/nntrainer + +# ── Prebuilt nntrainer libraries ───────────────────────────────────────── +include $(CLEAR_VARS) +LOCAL_MODULE := nntrainer +LOCAL_SRC_FILES := $(NNTRAINER_ROOT)/builddir/android_build_result/lib/$(TARGET_ARCH_ABI)/libnntrainer.so +include $(PREBUILT_SHARED_LIBRARY) + +include $(CLEAR_VARS) +LOCAL_MODULE := ccapi-nntrainer +LOCAL_SRC_FILES := $(NNTRAINER_ROOT)/builddir/android_build_result/lib/$(TARGET_ARCH_ABI)/libccapi-nntrainer.so +include $(PREBUILT_SHARED_LIBRARY) + +# ── Prebuilt causallm library (from src Android build) ────────── +include $(CLEAR_VARS) +LOCAL_MODULE := causallm +LOCAL_SRC_FILES := $(QUICK_DOT_AI_ROOT)/jni/libs/$(TARGET_ARCH_ABI)/libcausallm.so +include $(PREBUILT_SHARED_LIBRARY) + +# ── Prebuilt quick_dot_ai library (from src Android build) ────── +include $(CLEAR_VARS) +LOCAL_MODULE := quick_dot_ai +LOCAL_SRC_FILES := $(QUICK_DOT_AI_ROOT)/jni/libs/$(TARGET_ARCH_ABI)/libquick_dot_ai.so +include $(PREBUILT_SHARED_LIBRARY) + +# ── Common flags ───────────────────────────────────────────────────────── +COMMON_CFLAGS := -std=c++17 -Ofast -mcpu=cortex-a53 \ + -DENABLE_FP16=1 -DUSE__FP16=1 -D__ARM_NEON__=1 \ + -march=armv8.2-a+fp16+dotprod+i8mm -DUSE_NEON=1 \ + -mtune=cortex-a76 -O3 -ffast-math \ + -pthread -fexceptions -fopenmp -static-openmp + +COMMON_LDFLAGS := -fexceptions -fopenmp -static-openmp \ + -DENABLE_FP16=1 -DUSE__FP16=1 -D__ARM_NEON__=1 \ + -march=armv8.2-a+fp16+dotprod+i8mm -DUSE_NEON=1 \ + -mtune=cortex-a76 -O3 -ffast-math + +# ── CausalLM includes ─────────────────────────────────────────────────── +CAUSALLM_INCLUDES := $(NNTRAINER_INCLUDES) \ + $(CAUSALLM_ROOT) \ + $(CAUSALLM_ROOT)/layers \ + $(CAUSALLM_ROOT)/models \ + $(sort $(dir $(wildcard $(CAUSALLM_ROOT)/models/*/))) + +# ══════════════════════════════════════════════════════════════════════════ +# Module: libquick_dot_ai_api.so (C API library) +# ══════════════════════════════════════════════════════════════════════════ +include $(CLEAR_VARS) + +LOCAL_ARM_NEON := true +LOCAL_CFLAGS += $(COMMON_CFLAGS) +LOCAL_CXXFLAGS += -std=c++17 -frtti +LOCAL_LDFLAGS += $(COMMON_LDFLAGS) +LOCAL_MODULE := quick_dot_ai_api + +LOCAL_SRC_FILES := \ + ../quick_dot_ai_api.cpp \ + ../model_callbacks.cpp \ + ../model_config.cpp \ + ../streamer.cpp \ + ../callback_streamer.cpp \ + ../model_descriptors_public.cpp + +LOCAL_SHARED_LIBRARIES := nntrainer ccapi-nntrainer causallm quick_dot_ai +LOCAL_LDLIBS := -llog -landroid -ldl + +# Include every model subdirectory present so the API TUs can resolve plugin +# headers. Globbed generically (public tree ships only gemma4-e2b-qnn; any +# proprietary model directory dropped in is picked up automatically). +LOCAL_C_INCLUDES += $(CAUSALLM_INCLUDES) \ + $(LOCAL_PATH)/.. \ + $(QUICK_DOT_AI_ROOT)/models/qnn \ + $(wildcard $(QUICK_DOT_AI_ROOT)/models/*/) \ + $(wildcard $(QUICK_DOT_AI_ROOT)/models/qnn/*/) + +include $(BUILD_SHARED_LIBRARY) diff --git a/api/jni/Application.mk b/api/jni/Application.mk new file mode 100644 index 00000000..b00bb435 --- /dev/null +++ b/api/jni/Application.mk @@ -0,0 +1,4 @@ +APP_ABI := arm64-v8a +APP_PLATFORM := android-29 +APP_STL := c++_shared +NDK_TOOLCHAIN_VERSION := clang \ No newline at end of file diff --git a/api/meson.build b/api/meson.build new file mode 100644 index 00000000..40a1e18d --- /dev/null +++ b/api/meson.build @@ -0,0 +1,42 @@ +# api/meson.build β€” builds libquick_dot_ai_api.so + +if not get_option('enable-api') + subdir_done() +endif + +api_inc = include_directories('.') + +# Extension model headers +quick_dot_ai_inc_args = [ + '-I' + meson.current_source_dir() / '..' / 'xgrammar' / 'include', + '-I' + meson.current_source_dir() / '..' / 'xgrammar' / '3rdparty' / 'picojson', + '-I' + meson.current_source_dir() / '..' / 'xgrammar' / '3rdparty' / 'dlpack' / 'include', + '-I' + meson.current_source_dir() / '..' / 'src' / 'xgrammar', + '-I' + meson.current_source_dir() / '..' / 'src' / 'models', +] + +if enable_qnn +quick_dot_ai_inc_args += [ + '-I' + meson.current_source_dir() / '..' / 'src' / 'models' / 'qnn', + '-I' + meson.current_source_dir() / '..' / 'src' / 'models' / 'qnn' / 'gemma4-e2b-qnn', +] +endif + +quick_dot_ai_api_lib = shared_library('quick_dot_ai_api', + 'quick_dot_ai_api.cpp', + 'model_callbacks.cpp', + 'model_config.cpp', + 'model_descriptors_public.cpp', + 'streamer.cpp', + 'callback_streamer.cpp', + include_directories: [api_inc], + dependencies: [nntrainer_dep, causallm_dep, quick_dot_ai_dep, openmp_dep, thread_dep, log_dep], + cpp_args: quick_dot_ai_inc_args, + install: true, +) + +quick_dot_ai_api_dep = declare_dependency( + link_with: quick_dot_ai_api_lib, + include_directories: api_inc, + dependencies: [causallm_dep], +) diff --git a/api/model_callbacks.cpp b/api/model_callbacks.cpp new file mode 100644 index 00000000..1d8557e5 --- /dev/null +++ b/api/model_callbacks.cpp @@ -0,0 +1,26 @@ +// SPDX-License-Identifier: Apache-2.0 +#include "model_callbacks.h" + +ModelCallbackRegistry &ModelCallbackRegistry::instance() { + static ModelCallbackRegistry reg; + return reg; +} + +void ModelCallbackRegistry::register_for(const std::string &architecture, + ModelCallbacks cb) { + by_arch_[architecture] = std::move(cb); +} + +const ModelCallbacks * +ModelCallbackRegistry::lookup(const std::string &architecture) const { + auto it = by_arch_.find(architecture); + return it != by_arch_.end() ? &it->second : nullptr; +} + +bool ModelCallbackRegistry::any_requires_htp() const { + for (const auto &[arch, cb] : by_arch_) { + if (cb.requires_htp) + return true; + } + return false; +} diff --git a/api/model_callbacks.h b/api/model_callbacks.h new file mode 100644 index 00000000..ae6db521 --- /dev/null +++ b/api/model_callbacks.h @@ -0,0 +1,93 @@ +// SPDX-License-Identifier: Apache-2.0 +#pragma once +#include +#include +#include + +#include "quick_dot_ai_api.h" // ErrorCode, CausalLmTokenCallback, CausalLmHandle + +namespace causallm { +class Transformer; +} + +/** + * Per-architecture callbacks registered by proprietary model plugin TU files. + * When a plugin TU is absent (public build), no callbacks are registered for + * that architecture; callers should fall back to CAUSAL_LM_ERROR_UNSUPPORTED. + */ +struct ModelCallbacks { + /** + * Apply architecture-specific chat template to a raw single-turn input. + * Returns empty string if not registered (caller uses raw input). + */ + std::function format_prompt; + + /** True when this architecture requires an HTP/QNN backend. */ + bool requires_htp = false; + + /** + * Read the current KV-cache length from a loaded transformer. + * Used for incremental-session tracking. + * Returns 0 if not registered. + */ + std::function read_kv_len; + + /** + * Given the full prompt history (already-formatted), extract the latest user + * content and rebuild it as the minimal incremental prompt for next turn. + * Returns empty string if not registered. + */ + std::function incremental_prompt; + + /** + * Streaming multimodal execution. + * `handle` is CausalLmHandle (= CausalLmModel*). + * The plugin TU casts it to CausalLmModel* and accesses h.models[0]/[1]. + */ + std::function + multimodal_streaming; + + /** + * Blocking multimodal execution; appends generated text to *output. + * `handle` is CausalLmHandle (= CausalLmModel*). + */ + std::function + multimodal_blocking; +}; + +/** + * Registry keyed by architecture name string (the model's arch_string). + * Proprietary model plugin TUs call register_for() at static-init time. + * quick_dot_ai_api.cpp calls lookup() at runtime. + */ +class ModelCallbackRegistry { +public: + static ModelCallbackRegistry &instance(); + + /** Register callbacks for one architecture name. */ + void register_for(const std::string &architecture, ModelCallbacks cb); + + /** + * Look up callbacks for the given architecture. + * Returns nullptr if not registered (no plugin TU for this architecture). + */ + const ModelCallbacks *lookup(const std::string &architecture) const; + + /** True if ANY registered architecture has requires_htp = true. */ + bool any_requires_htp() const; + +private: + ModelCallbackRegistry() = default; + ModelCallbackRegistry(const ModelCallbackRegistry &) = delete; + ModelCallbackRegistry &operator=(const ModelCallbackRegistry &) = delete; + + std::unordered_map by_arch_; +}; diff --git a/api/model_config.cpp b/api/model_config.cpp index a84adbcb..3ebb0e6e 100644 --- a/api/model_config.cpp +++ b/api/model_config.cpp @@ -4,15 +4,19 @@ * * @file model_config.cpp * @date 22 Jan 2026 - * @brief This is a sample code for internal regitration of model_config. + * @brief Built-in model configuration registration for api. + * All calls use C++ namespaced functions β€” no extern "C" PLT calls. * @see https://github.com/nntrainer/nntrainer * @author Eunju Yang * @bug No known bugs except for NYI items */ -#include "causal_lm_api.h" #include "model_config_internal.h" +#include "quick_dot_ai_api.h" #include #include +#include + +namespace quick_dot_ai { static void register_qwen3_0_6b() { // 1. Architecture Config @@ -38,7 +42,7 @@ static void register_qwen3_0_6b() { ac.eos_token_ids[0] = 151645; ac.num_eos_token_ids = 1; - registerModelArchitecture("Qwen3-0.6B-Arch", ac); + register_arch("Qwen3-0.6B-Arch", ac); // 2. Runtime Config ModelRuntimeConfig rc; @@ -52,9 +56,9 @@ static void register_qwen3_0_6b() { rc.num_to_generate = 512; rc.fsu = false; rc.fsu_lookahead = 2; - strncpy(rc.embedding_dtype, "Q4_0", sizeof(rc.embedding_dtype) - 1); + strncpy(rc.embedding_dtype, "Q6_K", sizeof(rc.embedding_dtype) - 1); strncpy(rc.fc_layer_dtype, "Q4_0", sizeof(rc.fc_layer_dtype) - 1); - strncpy(rc.model_file_name, "qwen3-0.6b-q40-fp32-arm.bin", + strncpy(rc.model_file_name, "qwen3-0.6b-q6k-q40-q40-fp32-arm.bin", sizeof(rc.model_file_name) - 1); strncpy(rc.tokenizer_file, "tokenizer.json", sizeof(rc.tokenizer_file) - 1); strncpy(rc.lmhead_dtype, "Q4_0", sizeof(rc.lmhead_dtype) - 1); @@ -64,7 +68,7 @@ static void register_qwen3_0_6b() { rc.top_p = 0.95f; rc.temperature = 0.7f; - registerModel("Qwen3-0.6B-W4A32", "Qwen3-0.6B-Arch", rc); + register_model("Qwen3-0.6B-W4A32", "Qwen3-0.6B-Arch", rc); // Example for W32A32 (FP32) ModelRuntimeConfig rc_fp32 = rc; @@ -75,14 +79,15 @@ static void register_qwen3_0_6b() { strncpy(rc_fp32.lmhead_dtype, "FP32", sizeof(rc_fp32.lmhead_dtype) - 1); strncpy(rc_fp32.model_file_name, "qwen3-0.6b-fp32.bin", sizeof(rc_fp32.model_file_name) - 1); - registerModel("Qwen3-0.6B-W32A32", "Qwen3-0.6B-Arch", rc_fp32); + register_model("Qwen3-0.6B-W32A32", "Qwen3-0.6B-Arch", rc_fp32); // Register default alias - registerModel("Qwen3-0.6B", "Qwen3-0.6B-Arch", rc); + register_model("Qwen3-0.6B", "Qwen3-0.6B-Arch", rc); } -int register_builtin_model_configs() { +void register_builtin_configs() { register_qwen3_0_6b(); // Add more models here... - return 0; } + +} // namespace quick_dot_ai diff --git a/api/model_config_internal.h b/api/model_config_internal.h index 9830825e..bedab50e 100644 --- a/api/model_config_internal.h +++ b/api/model_config_internal.h @@ -1,59 +1,47 @@ // SPDX-License-Identifier: Apache-2.0 /** * @file model_config_internal.h - * @brief Internal Structures and Functions for Model Configuration - * This file should NOT be exposed to the public API users. - * @see https://github.com/nntrainer/nntrainer - * @author Eunju Yang - * @bug No known bugs except for NYI items + * @brief Internal structures and registration for api. + * Self-contained β€” does NOT depend on upstream causal_lm_api headers. */ -#ifndef __MODEL_CONFIG_INTERNAL_H__ -#define __MODEL_CONFIG_INTERNAL_H__ +#ifndef __QUICK_DOT_AI_MODEL_CONFIG_INTERNAL_H__ +#define __QUICK_DOT_AI_MODEL_CONFIG_INTERNAL_H__ -#include "causal_lm_api.h" #include #include -#ifdef __cplusplus -extern "C" { -#endif - /** - * @brief Model Architecture Configuration (config.json) + * @brief Model Architecture Configuration (replaces config.json) */ typedef struct { - // config.json parameters unsigned int vocab_size; unsigned int hidden_size; unsigned int intermediate_size; unsigned int num_hidden_layers; unsigned int num_attention_heads; unsigned int head_dim; - unsigned int num_key_value_heads; // if 0, defaults to num_attention_heads + unsigned int num_key_value_heads; unsigned int max_position_embeddings; float rope_theta; float rms_norm_eps; bool tie_word_embeddings; - unsigned int sliding_window; // Use UINT_MAX for null + unsigned int sliding_window; unsigned int sliding_window_pattern; - // generation_config.json (static model properties) unsigned int eos_token_ids[4]; unsigned int num_eos_token_ids; unsigned int bos_token_id; - // architecture identification - char architecture[64]; // e.g., "Qwen3ForCausalLM" + char architecture[64]; } ModelArchConfig; /** - * @brief Model Runtime/Execution Configuration (nntr_config.json) + * @brief Model Runtime Configuration (replaces nntr_config.json) */ typedef struct { - // nntr_config.json parameters unsigned int batch_size; - char model_type[32]; // e.g. "CausalLM" + char model_type[32]; char model_tensor_type[32]; unsigned int init_seq_len; unsigned int max_seq_len; @@ -68,40 +56,29 @@ typedef struct { unsigned int num_bad_word_ids; char lmhead_dtype[32]; - // generation_config.json (runtime parameters) unsigned int top_k; float top_p; float temperature; } ModelRuntimeConfig; +namespace quick_dot_ai { + /** - * @brief Register a model architecture configuration - * @param arch_name Name of the architecture (e.g., "Qwen3-0.6B-Arch") - * @param config Architecture configuration - * @return ErrorCode + * @brief Register a model architecture config (writes to g_arch_config_map) */ -ErrorCode registerModelArchitecture(const char *arch_name, - ModelArchConfig config); +void register_arch(const char *arch_name, ModelArchConfig config); /** - * @brief Register a full model configuration linking runtime config to an - * architecture - * @param model_name Name of the model to register (e.g., "Qwen3-0.6B") - * @param arch_name Name of the registered architecture to use - * @param config Runtime configuration - * @return ErrorCode + * @brief Register a model runtime config (writes to g_model_registry) */ -ErrorCode registerModel(const char *model_name, const char *arch_name, - ModelRuntimeConfig config); +void register_model(const char *model_name, const char *arch_name, + ModelRuntimeConfig config); /** - * @brief Register built-in model configurations (e.g., Qwen3-0.6B) - * @return 0 on success + * @brief Called from register_models() to register all built-in configs */ -int register_builtin_model_configs(); +void register_builtin_configs(); -#ifdef __cplusplus -} -#endif +} // namespace quick_dot_ai -#endif // __MODEL_CONFIG_INTERNAL_H__ +#endif // __QUICK_DOT_AI_MODEL_CONFIG_INTERNAL_H__ diff --git a/api/model_descriptor.h b/api/model_descriptor.h new file mode 100644 index 00000000..72d0f605 --- /dev/null +++ b/api/model_descriptor.h @@ -0,0 +1,61 @@ +// SPDX-License-Identifier: Apache-2.0 +/** + * @file model_descriptor.h + * @brief T4 string-id model catalog schema (pluggable, self-registering). + */ +#ifndef __QUICK_DOT_AI_MODEL_DESCRIPTOR_H__ +#define __QUICK_DOT_AI_MODEL_DESCRIPTOR_H__ + +#ifdef __cplusplus +extern "C" { +#endif + +typedef enum { + QDA_RUNTIME_NATIVE = 0, /**< nntrainer (NativeQuickDotAI) */ + QDA_RUNTIME_LITERT = 1, /**< LiteRT-LM - Kotlin only, not registered in C */ +} RuntimeKind; + +typedef enum { + QDA_CAP_STREAMING = 1u << 0, + QDA_CAP_MESSAGES_API = 1u << 1, /**< requires messages-based API */ + QDA_CAP_MULTIMODAL = 1u << 2, + QDA_CAP_TOOL_USE = 1u << 3, + QDA_CAP_EMBEDDING = 1u << 4, + QDA_CAP_MULTI_IMAGE = 1u << 5, /**< supports multiple images (e.g. V-JEPA) */ + QDA_CAP_VISION_ENCODER = + 1u << 6, /**< standalone vision embedding producer; pair with an LLM */ +} CapabilityFlag; + +/** + * All `const char*` fields must point to storage with lifetime at least as + * long as the process (e.g. string literals or static storage). The registry + * stores pointers, not copies. + */ +typedef struct { + const char *id; /**< "Qwen3-0.6B" (catalog key) */ + const char *family; /**< "qwen3-0.6b" */ + const char *display_name; /**< "Qwen3 0.6B" */ + RuntimeKind runtime; /**< QDA_RUNTIME_NATIVE */ + unsigned int backend_mask; /**< bit i set => BackendType i supported */ + unsigned int capabilities; /**< CapabilityFlag OR */ + const char + *config_name; /**< g_model_registry lookup key e.g. "Qwen3-0.6B-W4A32" */ + const char + *arch_string; /**< causallm::Factory key e.g. "Qwen3ForCausalLM" */ +} ModelDescriptor; + +#ifdef __cplusplus +} // extern "C" +#endif + +#ifdef __cplusplus +namespace quick_dot_ai { +/** + * @brief Register a descriptor into g_descriptor_registry. + * Duplicate id overwrites the existing entry (logs a warning). + */ +void register_model_descriptor(const ModelDescriptor *desc); +} // namespace quick_dot_ai +#endif + +#endif /* __QUICK_DOT_AI_MODEL_DESCRIPTOR_H__ */ diff --git a/api/model_descriptors_public.cpp b/api/model_descriptors_public.cpp new file mode 100644 index 00000000..1f00939a --- /dev/null +++ b/api/model_descriptors_public.cpp @@ -0,0 +1,55 @@ +// SPDX-License-Identifier: Apache-2.0 +/** + * @file model_descriptors_public.cpp + * @brief Public model descriptor self-registration. Proprietary model + * plugins register themselves in their own TUs, not here. + * + * config_name values verified against get_model_name_from_type() in + * quick_dot_ai_api.cpp. arch_string values verified against + * register_models() Factory registrations in the same file. + */ +#include "model_descriptor.h" + +using namespace quick_dot_ai; + +#define B(x) (1u << (unsigned)(x)) /* BackendType: CPU=0, GPU=1, NPU=2 */ + +__attribute__((constructor)) static void register_public_descriptors() { + static const ModelDescriptor kPublic[] = { + {"qwen3-0.6b", "qwen3-0.6b", "Qwen3 0.6B", QDA_RUNTIME_NATIVE, B(0) | B(1), + QDA_CAP_STREAMING | QDA_CAP_TOOL_USE, + "QWEN3-0.6B", /* get_model_name_from_type(CAUSAL_LM_MODEL_QWEN3_0_6B) */ + "Qwen3ForCausalLM"}, + {"qwen3-1.7b-q40", "qwen3-1.7b", "Qwen3 1.7B (Q40)", QDA_RUNTIME_NATIVE, + B(0) | B(1), QDA_CAP_STREAMING | QDA_CAP_TOOL_USE, + "QWEN3-1.7B-Q40", /* get_model_name_from_type(CAUSAL_LM_MODEL_QWEN3_1_7B_Q40) + */ + "Qwen3ForCausalLM"}, + {"tiny-bert", "tiny-bert", "Tiny BERT", QDA_RUNTIME_NATIVE, B(0), + QDA_CAP_EMBEDDING, + "TINY_BERT", /* get_model_name_from_type(CAUSAL_LM_MODEL_TINY_BERT) */ + "MultilingualTinyBert"}, + {"function-gemma", "function-gemma", "Function Gemma", QDA_RUNTIME_NATIVE, + B(0) | B(1), QDA_CAP_TOOL_USE, + "FUNCTION_GEMMA", /* get_model_name_from_type(CAUSAL_LM_MODEL_FUNCTION_GEMMA) + */ + "Gemma3ForCausalLM"}, + {"gemma4-cpu", "gemma4", "Gemma4 (CPU)", QDA_RUNTIME_NATIVE, B(0), + QDA_CAP_STREAMING, + "GEMMA4_CPU", /* get_model_name_from_type(CAUSAL_LM_MODEL_GEMMA4_CPU) */ + "Gemma4ForCausalLM" /* Factory registration pending */}, +#ifdef ENABLE_QNN + {"gemma4-e2b-qnn", "gemma4", "Gemma4 E2B (QNN)", QDA_RUNTIME_NATIVE, B(2), + QDA_CAP_MESSAGES_API, + "GEMMA4-E2B-QNN", /* get_model_name_from_type(CAUSAL_LM_MODEL_GEMMA4_E2B_QNN) + */ + "Gemma4_E2B_QNN"}, + {"vjepa-qnn", "vjepa", "V-JEPA (QNN)", QDA_RUNTIME_NATIVE, B(2), + QDA_CAP_MULTIMODAL | QDA_CAP_MESSAGES_API | QDA_CAP_MULTI_IMAGE, + "VJEPA-QNN", /* get_model_name_from_type(CAUSAL_LM_MODEL_VJEPA_QNN) */ + "VJEPA_QNN"}, +#endif + }; + for (const auto &d : kPublic) + register_model_descriptor(&d); +} diff --git a/api/quick_dot_ai_api.cpp b/api/quick_dot_ai_api.cpp new file mode 100644 index 00000000..19523d73 --- /dev/null +++ b/api/quick_dot_ai_api.cpp @@ -0,0 +1,2844 @@ +// SPDX-License-Identifier: Apache-2.0 +/** + * Copyright (C) 2026 Samsung Electronics Co., Ltd. All Rights Reserved. + * + * @file quick_dot_ai_api.cpp + * @date 21 Jan 2026 + * @brief This is a C API for CausalLM application + * @see https://github.com/nntrainer/nntrainer + * @author Eunju Yang + * @bug No known bugs except for NYI items + */ + +#include "quick_dot_ai_api.h" +#ifdef ENABLE_QNN +#include "quick_dot_ai_qnn.h" +#endif +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "causal_lm.h" +#include "chat_template.h" +#include "sentence_transformer.h" +#include "gemma3_causallm.h" +#include "gemma4_causallm.h" +#include "gptoss_cached_slim_causallm.h" +#include "gptoss_causallm.h" +#include "json.hpp" +#include "model_config_internal.h" +#include "model_descriptor.h" +#include "multilingual_tinybert_16mb.h" +#include "qwen2_causallm.h" +#include "qwen3_cached_slim_moe_causallm.h" +#include "qwen3_causallm.h" +#include "qwen3_moe_causallm.h" +#include "qwen3_slim_moe_causallm.h" +#include "xgrammar_manager.h" +#include "xgrammar_wrapper.h" +#include +#include "model_callbacks.h" +#ifdef ENABLE_QNN +#include "gemma4_e2b_qnn.h" +#include "quick_dot_ai_qnn.h" + +#endif +#include +#include +#include + +#ifdef __ANDROID__ +#include +#define LOG_TAG "QuickAI" +#define LOGD(...) __android_log_print(ANDROID_LOG_DEBUG, LOG_TAG, __VA_ARGS__) +#define LOGE(...) __android_log_print(ANDROID_LOG_ERROR, LOG_TAG, __VA_ARGS__) +#else +#define LOGD(fmt, ...) fprintf(stdout, fmt "\n", ##__VA_ARGS__) +#define LOGE(fmt, ...) fprintf(stderr, fmt "\n", ##__VA_ARGS__) +#endif + +using json = nlohmann::json; +using causallm::multimodal_pointer; + +/** + * @brief Per-handle state for a loaded CausalLM model instance. + * + * Each handle may carry one or more sub-models so that compositions like + * vision-encoder + LLM can live behind a single handle. The vectors are + * kept parallel: models[i] ↔ architectures[i] ↔ model_dirs[i] ↔ + * initialization_duration_ms[i]. The single-model API paths + * (runModelHandleWithMessages / runModelHandleStreaming) operate on models[0] + * and ignore the rest; the multimodal API drives the full set. + * + * Note: the legacy non-handle API (loadModel / ...) is + * implemented on top of a single static "default" instance of this struct + * so that existing callers (e.g. test_api) keep working unchanged. + */ +struct CausalLmModel { + std::mutex mtx; + std::vector> models; + std::vector architectures; + std::vector model_dirs; + std::string last_output; + std::string native_lib_dir; + std::vector initialization_duration_ms; + bool initialized = false; + int kv_len = 0; +}; + +// Globals shared across all handles β€” options set via setOptions() apply +// process-wide regardless of which handle is active. +static std::mutex g_registry_mutex; +static bool g_use_chat_template = true; +static bool g_verbose = false; +static std::string g_last_output = ""; +static std::optional g_chat_template; +static std::string g_formatted_template; +static std::string g_chat_template_name = "default"; + +// Default handle backing the legacy non-handle API. +static CausalLmModel &get_default_handle() { + static CausalLmModel instance; + return instance; +} + +static std::map g_model_path_map = { + {"QWEN3-0.6B", "qwen3-0.6b"}, + {"QWEN3-1.7B-Q40", "qwen3-1.7b-q40-arm"}, + {"TINY_BERT", "tiny_bert"}, + {"FUNCTION_GEMMA", "function_gemma"}, + {"GEMMA4_CPU", "gemma4_cpu"}, +#ifdef ENABLE_QNN + {"GEMMA4-E2B-QNN", "gemma-4-e2b-qnn"}, + {"VJEPA-QNN", "vjepa-qnn"}, +#endif +}; + +/** + * @brief RegisteredModel + */ +struct RegisteredModel { + std::string arch_name; + ModelRuntimeConfig config; +}; +static std::map g_model_registry; +static std::map g_arch_config_map; + +// Internal C++ registration functions β€” called from model_config.cpp +// These bypass extern "C" PLT and write directly to our static maps. +namespace quick_dot_ai { + +void register_arch(const char *arch_name, ModelArchConfig config) { + std::string name(arch_name); + std::transform(name.begin(), name.end(), name.begin(), ::toupper); + g_arch_config_map[name] = config; +} + +void register_model(const char *model_name, const char *arch_name, + ModelRuntimeConfig config) { + std::string name(model_name); + std::transform(name.begin(), name.end(), name.begin(), ::toupper); + std::string aname(arch_name); + std::transform(aname.begin(), aname.end(), aname.begin(), ::toupper); + g_model_registry[name] = {aname, config}; +} + +} // namespace quick_dot_ai + +// --------------------------------------------------------------------------- +// T4: string-id descriptor registry + catalog JSON +// --------------------------------------------------------------------------- + +// Lazily-constructed (Meyers singleton) so cross-library self-registration +// from the model plugin (libquick_dot_ai.so) β€” whose static constructors run +// BEFORE this lib's globals would be constructed β€” lands in a live registry +// instead of an uninitialized one (static-init-order fiasco). +static std::mutex &descriptor_mutex() { + static std::mutex m; + return m; +} +static std::vector &descriptor_registry() { + static std::vector v; + return v; +} + +namespace quick_dot_ai { +void register_model_descriptor(const ModelDescriptor *desc) { + if (!desc || !desc->id) + return; + std::lock_guard lock(descriptor_mutex()); + for (auto &d : descriptor_registry()) { + if (std::strcmp(d.id, desc->id) == 0) { + LOGE("register_model_descriptor: duplicate id '%s', overwriting", + desc->id); + d = *desc; + return; + } + } + descriptor_registry().push_back(*desc); +} +} // namespace quick_dot_ai + +/** Find a descriptor by string id. Returns a copy while locked, or nullopt. */ +static std::optional find_descriptor_by_id(const char *id) { + if (!id) + return std::nullopt; + std::lock_guard lk(descriptor_mutex()); + for (const auto &d : descriptor_registry()) + if (std::strcmp(d.id, id) == 0) + return d; // copy while locked + return std::nullopt; +} + +// Library-owned buffer: rebuilt on every call and returned via c_str(). +// The pointer is valid only until the next call to getModelCatalogJson(). +static std::string g_catalog_json_cache; + +/** + * Returns a pointer to a library-owned buffer containing a JSON array of + * registered model descriptors. The buffer is valid only until the next + * call to getModelCatalogJson(). Callers must copy the contents immediately + * (e.g. via JNI NewStringUTF) and must not hold the pointer across calls. + * Not safe for concurrent access to the returned pointer. + */ +extern "C" const char *getModelCatalogJson(void) { + auto json_escape = [](const char *s) -> std::string { + if (!s) + return ""; + std::string out; + for (; *s; ++s) { + if (*s == '"') + out += "\\\""; + else if (*s == '\\') + out += "\\\\"; + else + out += *s; + } + return out; + }; + + std::lock_guard lock(descriptor_mutex()); + std::ostringstream os; + os << "["; + for (size_t i = 0; i < descriptor_registry().size(); ++i) { + const auto &d = descriptor_registry()[i]; + if (i) + os << ","; + os << "{\"id\":\"" << json_escape(d.id) << "\",\"family\":\"" + << json_escape(d.family) << "\",\"display_name\":\"" + << json_escape(d.display_name ? d.display_name : d.id) + << "\",\"runtime\":" << static_cast(d.runtime) + << ",\"backend_mask\":" << d.backend_mask + << ",\"capabilities\":" << d.capabilities << "}"; + } + os << "]"; + g_catalog_json_cache = os.str(); + return g_catalog_json_cache.c_str(); +} + +// Helper to register models (similar to main.cpp) ensuring factory is +// populated. Factory registration is singleton and persistent, but we do it +// once here to be sure. Since mquiain.cpp is not linked, we must duplicate +// registration or share it. Assuming this lib is used independently of +// main.cpp. +static void register_models() { + static std::once_flag flag; + std::call_once(flag, []() { + causallm::Factory::Instance().registerModel( + "LlamaForCausalLM", [](json cfg, json generation_cfg, json nntr_cfg) { + return std::make_unique(cfg, generation_cfg, + nntr_cfg); + }); + causallm::Factory::Instance().registerModel( + "Qwen2ForCausalLM", [](json cfg, json generation_cfg, json nntr_cfg) { + return std::make_unique(cfg, generation_cfg, + nntr_cfg); + }); + causallm::Factory::Instance().registerModel( + "Qwen3ForCausalLM", [](json cfg, json generation_cfg, json nntr_cfg) { + return std::make_unique(cfg, generation_cfg, + nntr_cfg); + }); + causallm::Factory::Instance().registerModel( + "Qwen3MoeForCausalLM", [](json cfg, json generation_cfg, json nntr_cfg) { + return std::make_unique(cfg, generation_cfg, + nntr_cfg); + }); + causallm::Factory::Instance().registerModel( + "Qwen3SlimMoeForCausalLM", + [](json cfg, json generation_cfg, json nntr_cfg) { + return std::make_unique( + cfg, generation_cfg, nntr_cfg); + }); + causallm::Factory::Instance().registerModel( + "Qwen3CachedSlimMoeForCausalLM", + [](json cfg, json generation_cfg, json nntr_cfg) { + return std::make_unique( + cfg, generation_cfg, nntr_cfg); + }); + causallm::Factory::Instance().registerModel( + "GptOssForCausalLM", [](json cfg, json generation_cfg, json nntr_cfg) { + return std::make_unique( + cfg, generation_cfg, nntr_cfg); + }); + causallm::Factory::Instance().registerModel( + "GptOssCachedSlimCausalLM", + [](json cfg, json generation_cfg, json nntr_cfg) { + return std::make_unique( + cfg, generation_cfg, nntr_cfg); + }); + causallm::Factory::Instance().registerModel( + "Gemma3ForCausalLM", [](json cfg, json generation_cfg, json nntr_cfg) { + return std::make_unique(cfg, generation_cfg, + nntr_cfg); + }); + causallm::Factory::Instance().registerModel( + "Gemma4ForCausalLM", [](json cfg, json generation_cfg, json nntr_cfg) { + return std::make_unique(cfg, generation_cfg, + nntr_cfg); + }); + causallm::Factory::Instance().registerModel( + "MultilingualTinyBert", [](json cfg, json generation_cfg, json nntr_cfg) { + return std::make_unique( + cfg, generation_cfg, nntr_cfg); + }); + +#ifdef ENABLE_QNN + causallm::Factory::Instance().registerModel( + "Gemma4_E2B_QNN", [](json cfg, json generation_cfg, json nntr_cfg) { + return std::make_unique(cfg, generation_cfg, + nntr_cfg); + }); +#endif + // Register built-in configurations + quick_dot_ai::register_builtin_configs(); + }); +} + +static const char *get_model_name_from_type(ModelType type) { + switch (type) { + case CAUSAL_LM_MODEL_QWEN3_0_6B: + return "QWEN3-0.6B"; + case CAUSAL_LM_MODEL_QWEN3_1_7B_Q40: + return "QWEN3-1.7B-Q40"; + case CAUSAL_LM_MODEL_TINY_BERT: + return "TINY_BERT"; + case CAUSAL_LM_MODEL_FUNCTION_GEMMA: + return "FUNCTION_GEMMA"; + case CAUSAL_LM_MODEL_GEMMA4_CPU: + return "GEMMA4_CPU"; +#ifdef ENABLE_QNN + case CAUSAL_LM_MODEL_GEMMA4_E2B_QNN: + return "GEMMA4-E2B-QNN"; + case CAUSAL_LM_MODEL_VJEPA_QNN: + return "VJEPA-QNN"; +#endif + default: + return nullptr; + } +} + +static std::string apply_chat_template(const std::string &architecture, + const std::string &input) { + // Use dynamic chat template from tokenizer_config.json if available + if (g_chat_template) { + nlohmann::json request; + request["messages"] = nlohmann::json::array(); + request["messages"].push_back({{"role", "user"}, {"content", input}}); + request["add_generation_prompt"] = true; + try { + return g_chat_template->apply(request); + } catch (const std::exception &e) { + LOGE("Chat template apply failed: %s", e.what()); + // fallback to hardcoded + } + } + + LOGE("----------------APPLY CHAT FALLBACKS!!!!!!-------------"); + + // Fallback: hardcoded per-architecture templates + if (architecture == "LlamaForCausalLM") { + // Llama 2/3 chat format: [INST] {prompt} [/INST] + return "[INST] " + input + " [/INST]"; + } else if (architecture == "Qwen2ForCausalLM" || + architecture == "Qwen3ForCausalLM" || + architecture == "Qwen3MoeForCausalLM" || + architecture == "Qwen3SlimMoeForCausalLM" || + architecture == "Qwen3CachedSlimMoeForCausalLM") { + // Qwen chat format + // <|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n + return "<|im_start|>user\n" + input + "<|im_end|>\n<|im_start|>assistant\n"; + } else if (architecture == "Gemma3ForCausalLM") { + // Gemma chat format: + // user\n{prompt}\nmodel\n + return "user\n" + input + + "\nmodel\n"; + } else if (architecture == "Gemma4ForCausalLM" || + architecture == "Gemma4_E2B_QNN") { + // Gemma 4 requires the prompt to begin with the token. The model's + // own nntr_config.json sample_input documents the canonical format as + // "<|turn>user\n...\n<|turn>model\n". This model's + // tokenizer.json post_processor does NOT add , and tokenizer_config + // .json carries no chat_template / add_bos_token, so it must be added + // here or the model receives a BOS-less prompt and emits garbage. + // "" is a special added token (id 2) and encodes to that single id. + return "<|turn>user\n" + input + "\n<|turn>model\n"; + } else { + if (const auto *cb = ModelCallbackRegistry::instance().lookup(architecture)) { + if (cb->format_prompt) { + return cb->format_prompt(input); + } + } + } + return input; +} + +static size_t text_generation_model_index(const CausalLmModel &h) { + // Convention: a multi-model handle is [vision producer, text LLM, ...]; + // text generation runs on the LLM at index 1. + return (h.models.size() > 1) ? 1 : 0; +} + +static std::string trim_wrapping_newlines(std::string value) { + while (!value.empty() && (value.front() == '\n' || value.front() == '\r')) { + value.erase(value.begin()); + } + while (!value.empty() && (value.back() == '\n' || value.back() == '\r')) { + value.pop_back(); + } + return value; +} + +static void reset_handle_session_state(CausalLmModel &h) { h.kv_len = 0; } + +static void update_handle_session_after_run(CausalLmModel &h, + size_t model_index) { + if (model_index >= h.models.size() || model_index >= h.architectures.size()) + return; + const auto *cb = ModelCallbackRegistry::instance().lookup(h.architectures[model_index]); + if (!cb || !cb->read_kv_len) + return; + h.kv_len = cb->read_kv_len(h.models[model_index].get()); +} + +#ifdef ENABLE_QNN +static causallm::Quick_Dot_AI_QNN *find_qnn_kv_cache_model(CausalLmModel &h) { + for (auto &m : h.models) { + auto *q = dynamic_cast(m.get()); + if (q && q->supportsKvCachePersistence()) return q; + } + return nullptr; +} +#endif + +static ErrorCode save_qnn_kv_cache_on_handle(CausalLmModel &h, + const char *cache_path) { +#ifndef ENABLE_QNN + (void)h; + (void)cache_path; + return CAUSAL_LM_ERROR_UNSUPPORTED; +#else + if (cache_path == nullptr || cache_path[0] == '\0') { + return CAUSAL_LM_ERROR_INVALID_PARAMETER; + } + + std::lock_guard lock(h.mtx); + if (!h.initialized || h.models.empty()) { + return CAUSAL_LM_ERROR_NOT_INITIALIZED; + } + + auto *model = find_qnn_kv_cache_model(h); + if (model == nullptr) { + return CAUSAL_LM_ERROR_UNSUPPORTED; + } + + try { + model->saveKvCache(cache_path); + h.kv_len = model->getKvLen(); + } catch (const std::exception &e) { + LOGE("saveQnnKvCacheHandle failed: %s", e.what()); + return CAUSAL_LM_ERROR_UNKNOWN; + } + + return CAUSAL_LM_ERROR_NONE; +#endif +} + +static ErrorCode load_qnn_kv_cache_on_handle(CausalLmModel &h, + const char *cache_path) { +#ifndef ENABLE_QNN + (void)h; + (void)cache_path; + return CAUSAL_LM_ERROR_UNSUPPORTED; +#else + if (cache_path == nullptr || cache_path[0] == '\0') { + return CAUSAL_LM_ERROR_INVALID_PARAMETER; + } + + std::lock_guard lock(h.mtx); + if (!h.initialized || h.models.empty()) { + return CAUSAL_LM_ERROR_NOT_INITIALIZED; + } + + auto *model = find_qnn_kv_cache_model(h); + if (model == nullptr) { + return CAUSAL_LM_ERROR_UNSUPPORTED; + } + + try { + model->loadKvCache(cache_path); + h.kv_len = model->getKvLen(); + } catch (const std::exception &e) { + LOGE("loadQnnKvCacheHandle failed: %s", e.what()); + return CAUSAL_LM_ERROR_UNKNOWN; + } + + return CAUSAL_LM_ERROR_NONE; +#endif +} + +static ErrorCode reset_qnn_kv_cache_on_handle(CausalLmModel &h) { +#ifndef ENABLE_QNN + (void)h; + return CAUSAL_LM_ERROR_UNSUPPORTED; +#else + std::lock_guard lock(h.mtx); + if (!h.initialized || h.models.empty()) { + return CAUSAL_LM_ERROR_NOT_INITIALIZED; + } + + auto *model = find_qnn_kv_cache_model(h); + if (model == nullptr) { + return CAUSAL_LM_ERROR_UNSUPPORTED; + } + + try { + model->resetKvCache(); + reset_handle_session_state(h); + } catch (const std::exception &e) { + LOGE("resetQnnKvCacheHandle failed: %s", e.what()); + return CAUSAL_LM_ERROR_UNKNOWN; + } + + return CAUSAL_LM_ERROR_NONE; +#endif +} + +static std::string prepare_input_for_model(CausalLmModel &h, size_t model_index, + const std::string &input, + bool input_already_formatted) { + if (model_index >= h.architectures.size() || !g_use_chat_template) { + return input; + } + + const std::string &architecture = h.architectures[model_index]; + if (h.kv_len > 0) { + const auto *cb = ModelCallbackRegistry::instance().lookup(architecture); + if (cb && cb->incremental_prompt) { + return cb->incremental_prompt(input); + } + } + + if (input_already_formatted) { + return input; + } + + return apply_chat_template(architecture, input); +} + +static std::string get_quantization_suffix(ModelQuantizationType type) { + return ""; + switch (type) { + case CAUSAL_LM_QUANTIZATION_W4A32: + return "-w4a32"; + case CAUSAL_LM_QUANTIZATION_W16A16: + return "-w16a16"; + case CAUSAL_LM_QUANTIZATION_W8A16: + return "-w8a16"; + case CAUSAL_LM_QUANTIZATION_W32A32: + return "-w32a32"; + default: // W4A32 by default + return "-w4a32"; + } +} + +static std::string resolve_model_path(const std::string &model_key, + ModelQuantizationType quant_type) { + std::string path_upper = model_key; + std::transform(path_upper.begin(), path_upper.end(), path_upper.begin(), + ::toupper); + + std::string base_dir_name = ""; + + // 1. Try to find base directory name from map + if (g_model_path_map.find(path_upper) != g_model_path_map.end()) { + base_dir_name = g_model_path_map[path_upper]; + } else { + // Fallback: use lowercased key as base dir name if not found in map + // or just return empty? For restricted API, we should probably fail + // earlier, but here we can return constructed path. + base_dir_name = path_upper; + std::transform(base_dir_name.begin(), base_dir_name.end(), + base_dir_name.begin(), ::tolower); + } + + std::string model_path = + "/" + base_dir_name + get_quantization_suffix(quant_type); + + return model_path; +} + +/** + * @brief Rebase path-like keys of a sub-model nntr_config.json onto @p sub_dir. + * + * Called once per sub-model inside the multi-model branch of + * load_into_handle so that downstream code (Factory::create, load_weight) + * sees absolute paths β€” mirrors the inline fixups the single-model path + * already performs for model_file_name / binary_config_path / ... + * + * Absolute values (leading '/') are left untouched so the caller can + * override a specific file with a system-wide path if they want. + */ +static bool is_absolute_path(const std::string &path) { + return !path.empty() && path[0] == '/'; +} + +static std::string rebase_path(const std::string &path, + const std::string &base_dir) { + if (path.empty() || is_absolute_path(path)) + return path; + return base_dir + "/" + path; +} + +static void fix_paths(json &nntr_cfg, const std::string &sub_dir) { + static const char *kKeys[] = { + "tokenizer_file", "model_file_name", "binary_config_path", + "image_newline_path", "embedding_file_name", "ple_file_name", + }; + for (const char *k : kKeys) { + if (!nntr_cfg.contains(k) || !nntr_cfg[k].is_string()) + continue; + std::string v = nntr_cfg[k].get(); + nntr_cfg[k] = rebase_path(v, sub_dir); + } +} + +static bool check_file_exists(const std::string &path) { + struct stat buffer; + return (stat(path.c_str(), &buffer) == 0); +} + +static void validate_models() { + LOGD("[DEBUG] Validating model files..."); + // Iterate over all known model names in map + for (auto const &[key, val] : g_model_path_map) { + // We want to check for each Quantization Type if it exists + // List of quant types to check: UNKNOWN (default), W4A32, W16A16, W32A32 + std::vector quant_types = { + CAUSAL_LM_QUANTIZATION_UNKNOWN, CAUSAL_LM_QUANTIZATION_W4A32, + CAUSAL_LM_QUANTIZATION_W16A16, CAUSAL_LM_QUANTIZATION_W32A32}; + + for (auto qt : quant_types) { + std::string quant_suffix = get_quantization_suffix(qt); + + std::string lookup_key = key; + if (qt != CAUSAL_LM_QUANTIZATION_UNKNOWN) { + std::transform(quant_suffix.begin(), quant_suffix.end(), + quant_suffix.begin(), ::toupper); // "-W4A32" + lookup_key += quant_suffix; + } + + // Resolve path for this combination + std::string resolved_path = "./models" + resolve_model_path(key, qt); + + if (g_model_registry.find(lookup_key) != g_model_registry.end()) { + // CASE 1: Configuration is registered in model_config.cpp + // For these models, we only check if the binary weight file exists. + // The configurations (config.json, etc.) are embedded in the library. + RegisteredModel &rm = g_model_registry[lookup_key]; + std::string bin_file_name = rm.config.model_file_name; + std::string full_path = resolved_path + "/" + bin_file_name; + + if (check_file_exists(full_path)) { + LOGD(" [OK] Reg Config: %s -> %s", lookup_key.c_str(), + full_path.c_str()); + } else { + LOGD(" [FAIL] Reg Config: %s -> Missing binary: %s", + lookup_key.c_str(), full_path.c_str()); + } + } else { + // CASE 2: No internal config, but model type exists (via map + // iteration). For these models, we require external configuration files + // (config.json, nntr_config.json) to be present in the directory. + if (check_file_exists(resolved_path)) { + bool has_config = check_file_exists(resolved_path + "/config.json"); + bool has_nntr = + check_file_exists(resolved_path + "/nntr_config.json"); + + if (has_config && has_nntr) { + LOGD(" [OK] External Config: %s -> %s", lookup_key.c_str(), + resolved_path.c_str()); + // Optional: Parse nntr_config to check bin + try { + json nntr = + causallm::LoadJsonFile(resolved_path + "/nntr_config.json"); + if (nntr.contains("model_file_name")) { + std::string bin = nntr["model_file_name"]; + if (check_file_exists(resolved_path + "/" + bin)) { + LOGD(" (Binary confirmed: %s)", bin.c_str()); + } else { + LOGD(" (MISSING BINARY: %s)", bin.c_str()); + } + } + } catch (...) { + } + } else { + LOGD(" [FAIL] External Config: %s -> Missing configs in %s", + lookup_key.c_str(), resolved_path.c_str()); + } + } + } + } + } +} + +ErrorCode setOptions(Config config) { + // Currently no options are being handled + g_use_chat_template = config.use_chat_template; + g_verbose = config.verbose; + g_chat_template_name = (config.chat_template_name != nullptr) + ? config.chat_template_name + : "default"; + if (config.debug_mode) { + // Ensure models are registered so we can validate them + register_models(); + validate_models(); + } + return CAUSAL_LM_ERROR_NONE; +} + +ErrorCode loadToolset(const char *toolset_path, + tokenizers::Tokenizer *tokenizer, + unsigned int vocab_size) { + if (toolset_path == nullptr) { + return CAUSAL_LM_ERROR_INVALID_PARAMETER; + } + if (tokenizer == nullptr) { + std::cerr << "Error: Tokenizer is null" << std::endl; + return CAUSAL_LM_ERROR_UNKNOWN; + } + + LOGD("[LoadToolset] load toolset path: %s", toolset_path); + + try { + // Load and pre-compile all tool grammars + bool success = causallm::XGrammarManager::Instance().loadToolset( + std::string(toolset_path), tokenizer, vocab_size); + LOGD("causallm::XGrammarManager::loadToolset() done"); + if (!success) { + return CAUSAL_LM_ERROR_UNKNOWN; + } + } catch (const std::exception &e) { + std::cerr << "Exception in loadToolset: " << e.what() << std::endl; + return CAUSAL_LM_ERROR_UNKNOWN; + } + + return CAUSAL_LM_ERROR_NONE; +} + +ErrorCode registerModelArchitecture(const char *arch_name, + ModelArchConfig config) { + if (arch_name == nullptr) + return CAUSAL_LM_ERROR_INVALID_PARAMETER; + std::string name(arch_name); + std::transform(name.begin(), name.end(), name.begin(), ::toupper); + g_arch_config_map[name] = config; + return CAUSAL_LM_ERROR_NONE; +} + +ErrorCode registerModel(const char *model_name, const char *arch_name, + ModelRuntimeConfig config) { + if (model_name == nullptr || arch_name == nullptr) + return CAUSAL_LM_ERROR_INVALID_PARAMETER; + std::string name(model_name); + std::transform(name.begin(), name.end(), name.begin(), ::toupper); + + std::string aname(arch_name); + std::transform(aname.begin(), aname.end(), aname.begin(), ::toupper); + + g_model_registry[name] = {aname, config}; + return CAUSAL_LM_ERROR_NONE; +} + +/** + * @brief Core loader shared by loadModel and loadModelHandle. + * + * Populates the given handle's model / architecture / init-duration + * vectors on success. Takes the handle's own mutex so two concurrent + * loads on the same handle are serialized, while loads on different + * handles run in parallel. A separate registry mutex protects + * g_model_registry / g_arch_config_map during lookup. + * + * Dispatch in CASE 2 (file-based): + * - If the top-level nntr_config.json has both "architectures" (string + * array) and "model_dirs" (string array) of equal non-zero length, + * loads one sub-model per entry (e.g. vision encoder + LLM). + * - Otherwise loads a single model from the resolved directory using + * the pre-existing flow. + */ +#ifdef ENABLE_QNN +// Point the QNN HTP backend at htp_backend_ext_config.json for ANY QNN arch. +// Must run BEFORE constructing/initializing a QNN model β€” single OR multi-model +// sub-models (e.g. the vision encoder of a multimodal pair) β€” otherwise +// QNNContext falls back to the process cwd ("/" for an installed APK) and QNN +// layer registration throws not_supported ("Unable to load backend extensions +// config"). Honors an externally configured path if already set. +static void ensure_qnn_backend_ext_config(const std::string &base_dir) { + const char *configured = + getenv("QUICK_DOT_AI_QNN_BACKEND_EXT_CONFIG_PATH"); + if (configured != nullptr && configured[0] != '\0') + return; + std::string config_path = base_dir; + // Strip a trailing "/models" so the config resolves next to the + // model-collection root (mirrors the single-model resolution). + while (!config_path.empty() && config_path.back() == '/') + config_path.pop_back(); + if (config_path.length() >= 7 && + config_path.substr(config_path.length() - 7) == "/models") { + config_path = config_path.substr(0, config_path.length() - 7); + } + config_path += "/htp_backend_ext_config.json"; + setenv("QUICK_DOT_AI_QNN_BACKEND_EXT_CONFIG_PATH", config_path.c_str(), 1); +} +#endif + +/** Internal overload: config_name already resolved (T4 byName path). */ +static ErrorCode load_into_handle(CausalLmModel &h, BackendType compute, + const char *target_model_name, + ModelQuantizationType quant_type, + const char *native_lib_dir, + const char *model_base_path) { + LOGD("[DEBUG] load_into_handle: START"); + LOGD("[DEBUG] compute: %d", compute); + LOGD("[DEBUG] target_model_name: %s", + target_model_name ? target_model_name : "(null)"); + LOGD("[DEBUG] quant_type: %d", quant_type); + + auto start_init = std::chrono::high_resolution_clock::now(); + + // Ensure models/configs are registered (thread-safe via call_once) + LOGD("[DEBUG] load_into_handle: Calling register_models..."); + register_models(); + LOGD("[DEBUG] load_into_handle: register_models done"); + + std::lock_guard lock(h.mtx); + try { + h.models.clear(); + h.architectures.clear(); + h.model_dirs.clear(); + h.initialization_duration_ms.clear(); + h.initialized = false; + reset_handle_session_state(h); + + // Check if it's a registered in-memory config + std::string input_name = std::string(target_model_name); + std::string input_name_upper = input_name; + std::transform(input_name_upper.begin(), input_name_upper.end(), + input_name_upper.begin(), ::toupper); + LOGD("[DEBUG] load_into_handle: input_name = %s", input_name.c_str()); + + std::string quant_suffix = ""; + switch (quant_type) { + case CAUSAL_LM_QUANTIZATION_W4A32: + quant_suffix = "-W4A32"; + break; + case CAUSAL_LM_QUANTIZATION_W16A16: + quant_suffix = "-W16A16"; + break; + case CAUSAL_LM_QUANTIZATION_W8A16: + quant_suffix = "-W8A16"; + break; + case CAUSAL_LM_QUANTIZATION_W32A32: + quant_suffix = "-W32A32"; + break; + default: + break; + } + std::string lookup_name = input_name_upper + quant_suffix; + LOGD("[DEBUG] load_into_handle: lookup_name = %s", lookup_name.c_str()); + + json cfg; + json generation_cfg; + json nntr_cfg; + std::string model_dir_path; + std::string abs_model_dir; + std::string base_dir = + (model_base_path != nullptr && strlen(model_base_path) > 0) + ? model_base_path + : "/sdcard/Download/aistudio-mobile/models/"; + +#ifdef ENABLE_QNN + // Set the QNN backend-extensions config path up front so it is in effect + // for BOTH the multi-model sub-model loop and the single-model path. + ensure_qnn_backend_ext_config(base_dir); +#endif + + // Snapshot registry entries under the registry mutex so concurrent + // loads on different handles don't race with each other (or with + // registerModel / registerModelArchitecture). + std::lock_guard reg_lock(g_registry_mutex); + + // Check in-memory map first + // if (g_model_registry.find(lookup_name) != g_model_registry.end()) { + + // always goto case2 + if (0) { + LOGD("[DEBUG] load_into_handle: CASE 1 - Internal config found for %s", + lookup_name.c_str()); + // ------------------------------------------------------------------------ + // CASE 1: Model Configuration is Internal (Registered in + // model_config.cpp) + // ------------------------------------------------------------------------ + // In this case, we do NOT load config.json or nntr_config.json from disk. + // We only locate the binary weight file. + RegisteredModel &rm = g_model_registry[lookup_name]; + + // Find architecture config + if (g_arch_config_map.find(rm.arch_name) == g_arch_config_map.end()) { + LOGE("[DEBUG] load_into_handle: Architecture '%s' not found for model " + "'%s'", + rm.arch_name.c_str(), lookup_name.c_str()); + return CAUSAL_LM_ERROR_MODEL_LOAD_FAILED; + } + LOGD("[DEBUG] load_into_handle: arch_name = %s", rm.arch_name.c_str()); + ModelArchConfig &ac = g_arch_config_map[rm.arch_name]; + ModelRuntimeConfig &rc = rm.config; + + // Strategy: Resolve path to find the weight file + model_dir_path = + "./models" + resolve_model_path(target_model_name, quant_type); + LOGD("[DEBUG] load_into_handle: model_dir_path = %s", + model_dir_path.c_str()); + + // Populate JSONs from Arch Struct + cfg["vocab_size"] = ac.vocab_size; + cfg["hidden_size"] = ac.hidden_size; + cfg["intermediate_size"] = ac.intermediate_size; + cfg["num_hidden_layers"] = ac.num_hidden_layers; + cfg["num_attention_heads"] = ac.num_attention_heads; + cfg["head_dim"] = ac.head_dim; + cfg["num_key_value_heads"] = ac.num_key_value_heads > 0 + ? ac.num_key_value_heads + : ac.num_attention_heads; + cfg["max_position_embeddings"] = ac.max_position_embeddings; + cfg["rope_theta"] = ac.rope_theta; + cfg["rms_norm_eps"] = ac.rms_norm_eps; + cfg["tie_word_embeddings"] = ac.tie_word_embeddings; + if (ac.sliding_window != UINT_MAX) { + cfg["sliding_window"] = ac.sliding_window; + } else { + cfg["sliding_window"] = nullptr; + } + cfg["sliding_window_pattern"] = ac.sliding_window_pattern; + cfg["architectures"] = {std::string(ac.architecture)}; + + if (ac.num_eos_token_ids > 0) { + std::vector eos_ids; + for (unsigned int i = 0; i < ac.num_eos_token_ids; ++i) + eos_ids.push_back(ac.eos_token_ids[i]); + generation_cfg["eos_token_id"] = eos_ids; + } + generation_cfg["bos_token_id"] = ac.bos_token_id; + + // Populate JSONs from Runtime Struct + generation_cfg["top_k"] = rc.top_k; + generation_cfg["top_p"] = rc.top_p; + generation_cfg["temperature"] = rc.temperature; + generation_cfg["do_sample"] = false; + + nntr_cfg["batch_size"] = rc.batch_size; + nntr_cfg["model_type"] = std::string(rc.model_type); + nntr_cfg["model_tensor_type"] = std::string(rc.model_tensor_type); + nntr_cfg["init_seq_len"] = rc.init_seq_len; + nntr_cfg["max_seq_len"] = rc.max_seq_len; + nntr_cfg["num_to_generate"] = rc.num_to_generate; + nntr_cfg["fsu"] = rc.fsu; + nntr_cfg["fsu_lookahead"] = rc.fsu_lookahead; + nntr_cfg["embedding_dtype"] = std::string(rc.embedding_dtype); + nntr_cfg["fc_layer_dtype"] = std::string(rc.fc_layer_dtype); + nntr_cfg["model_file_name"] = std::string(rc.model_file_name); + + // tokenizer_file path is set later from abs_model_dir in the shared + // post-processing block below. + (void)rc.tokenizer_file; + + if (strlen(rc.lmhead_dtype) > 0) { + nntr_cfg["lmhead_dtype"] = std::string(rc.lmhead_dtype); + } + + std::vector bad_ids; + for (unsigned int i = 0; i < rc.num_bad_word_ids; ++i) + bad_ids.push_back(rc.bad_word_ids[i]); + nntr_cfg["bad_word_ids"] = bad_ids; + } else { + LOGD("[DEBUG] load_into_handle: CASE 2 - External config (file-based)"); + // -------------------------------------------------- + // CASE 2: External Model Configuration (File-based) + // -------------------------------------------------- + // The model type is registered (enum), but specific configuration for + // this quantization is not in memory. We must load config.json and + // nntr_config.json from the model directory + model_dir_path = resolve_model_path(target_model_name, quant_type); + LOGD("[DEBUG] load_into_handle: model_dir_path = %s", + model_dir_path.c_str()); + + abs_model_dir = base_dir + model_dir_path; + LOGD("[DEBUG] load_into_handle: abs_model_dir = %s", + abs_model_dir.c_str()); + + // Top-level nntr_config.json is read once and used for both + // (a) multi-model dispatch (architectures[] + model_dirs[]), and + // (b) the single-model fallback below. + json top_nntr = + causallm::LoadJsonFile(abs_model_dir + "/nntr_config.json"); + + LOGD("[DEBUG] load_into_handle: abs_model_dir = %s", + abs_model_dir.c_str()); + + LOGD("[DEBUG] load_into_handle: top_nntr = %s", + (abs_model_dir + "/nntr_config.json").c_str()); + + const bool is_multi = + top_nntr.contains("architectures") && + top_nntr["architectures"].is_array() && + top_nntr.contains("model_dirs") && top_nntr["model_dirs"].is_array() && + !top_nntr["architectures"].empty() && + top_nntr["architectures"].size() == top_nntr["model_dirs"].size(); + + if (top_nntr.contains("use_chat_template")) { + g_use_chat_template = top_nntr["use_chat_template"].get(); + } + + LOGD("[DEBUG] load_into_handle: abs_model_dir = %d %d %d %d %zu %zu", + top_nntr.contains("architectures"), + top_nntr["architectures"].is_array(), + top_nntr.contains("model_dirs"), top_nntr["model_dirs"].is_array(), + top_nntr["architectures"].size(), top_nntr["model_dirs"].size()); + + if (is_multi) { + // ---------------------------------------------------------------- + // Multi-model branch. + // + // top_nntr_config.json: + // { "architectures": ["ArchA", "ArchB"], + // "model_dirs": ["sub_a", "sub_b"] } + // + // Each sub_dir = abs_model_dir + "/" + model_dirs[i] owns its own + // config.json / generation_config.json / nntr_config.json + + // weights. The top-level architectures[i] wins over any + // "architectures" entry inside sub-config β€” one source of truth. + // ---------------------------------------------------------------- + auto archs = top_nntr["architectures"].get>(); + auto dirs = top_nntr["model_dirs"].get>(); + LOGD("[DEBUG] load_into_handle: MULTI-MODEL spec (N=%zu)", + archs.size()); + + for (size_t i = 0; i < archs.size(); ++i) { + const std::string &arch_i = archs[i]; + const std::string sub_dir = abs_model_dir + "/" + dirs[i]; + LOGD("[DEBUG] [%zu] arch=%s dir=%s", i, arch_i.c_str(), + sub_dir.c_str()); + + json sub_cfg = causallm::LoadJsonFile(sub_dir + "/config.json"); + + json sub_gen; + if (check_file_exists(sub_dir + "/generation_config.json")) { + sub_gen = + causallm::LoadJsonFile(sub_dir + "/generation_config.json"); + } else { + sub_gen = json::object(); + } + + json sub_nntr = causallm::LoadJsonFile(sub_dir + "/nntr_config.json"); + + fix_paths(sub_nntr, sub_dir); + + // Optional per-sub-model overrides from the top-level config. + // Lets callers flip flags like uses_embedding / add keys like + // embedding_file_name without duplicating the sub-model's own + // nntr_config.json. fix_paths is run again so any newly + // introduced path-like key (e.g. embedding_file_name) is + // resolved relative to sub_dir just like the native keys. + if (top_nntr.contains("model_options") && + top_nntr["model_options"].is_array() && + i < top_nntr["model_options"].size() && + top_nntr["model_options"][i].is_object()) { + for (auto it = top_nntr["model_options"][i].begin(); + it != top_nntr["model_options"][i].end(); ++it) { + sub_nntr[it.key()] = it.value(); + LOGD("[DEBUG] override sub[%zu] %s", i, it.key().c_str()); + } + fix_paths(sub_nntr, sub_dir); + } + if (sub_nntr.contains("lora_path")) { + LOGD("lora_path : %s", + sub_nntr["lora_path"].get().c_str()); + std::string lora_path = + sub_dir + "/" + sub_nntr["lora_path"].get(); + sub_nntr["lora_path"] = lora_path; + LOGD("lora_path is now %s", lora_path.c_str()); + } + + auto m = causallm::Factory::Instance().create(arch_i, sub_cfg, + sub_gen, sub_nntr); + if (!m) { + LOGE("[DEBUG] load_into_handle: Factory::create returned nullptr " + "for sub-model %zu (arch=%s)", + i, arch_i.c_str()); + return CAUSAL_LM_ERROR_MODEL_LOAD_FAILED; + } + + auto sub_t0 = std::chrono::high_resolution_clock::now(); + if (native_lib_dir != nullptr && strlen(native_lib_dir) > 0) { + setenv("ADSP_LIBRARY_PATH", native_lib_dir, 1); + m->initialize(std::string(native_lib_dir)); + } else { + m->initialize(); + } + + std::string weight_file = + sub_nntr.contains("model_file_name") + ? sub_nntr["model_file_name"].get() + : (sub_dir + "/pytorch_model.bin"); + m->load_weight(weight_file); + auto sub_t1 = std::chrono::high_resolution_clock::now(); + double sub_ms = std::chrono::duration_cast( + sub_t1 - sub_t0) + .count(); + + h.models.push_back(std::move(m)); + h.architectures.push_back(arch_i); + h.model_dirs.push_back(sub_dir); + h.initialization_duration_ms.push_back(sub_ms); + LOGD("[DEBUG] [%zu] loaded (%.1f ms)", i, sub_ms); + + // Load chat template from model directory if available. + if (causallm::ChatTemplate::Exists(sub_dir)) { + try { + g_chat_template = causallm::ChatTemplate::Load(sub_dir); + std::cout << "[Info] Chat template loaded from " << sub_dir + << std::endl; + } catch (const std::exception &e) { + std::cerr << "[Warning] Chat template load failed: " << e.what() + << ". Falling back to hardcoded templates." + << std::endl; + g_chat_template.reset(); + } + } + } + + if (native_lib_dir != nullptr) + h.native_lib_dir = native_lib_dir; + h.initialized = true; + + auto finish_init = std::chrono::high_resolution_clock::now(); + auto e2e = std::chrono::duration_cast( + finish_init - start_init) + .count(); + LOGD("[DEBUG] load_into_handle: MULTI-MODEL SUCCESS " + "(%zu models, %lld ms e2e)", + h.models.size(), e2e); + return CAUSAL_LM_ERROR_NONE; + } + + // -------------------- single-model fallback -------------------- + LOGD("single cfg : %s", (abs_model_dir + "/config.json").c_str()); + cfg = causallm::LoadJsonFile(abs_model_dir + "/config.json"); + + if (check_file_exists(abs_model_dir + "/generation_config.json")) { + generation_cfg = + causallm::LoadJsonFile(abs_model_dir + "/generation_config.json"); + } + + nntr_cfg = std::move(top_nntr); + + if (nntr_cfg.contains("lora_path")) { + nntr_cfg["lora_path"] = ""; + } + + LOGD("single tokenizer : %s", + (abs_model_dir + "/tokenizer.json").c_str()); + + if (nntr_cfg.contains("tokenizer_file")) { + nntr_cfg["tokenizer_file"] = abs_model_dir + "/tokenizer.json"; + } + } + + // Load chat template from model directory if available. + if (causallm::ChatTemplate::Exists(abs_model_dir)) { + try { + g_chat_template = causallm::ChatTemplate::Load(abs_model_dir); + LOGD("[Info] Chat template loaded from %s", abs_model_dir.c_str()); + } catch (const std::exception &e) { + LOGE("[Warning] Chat template load failed: %s. Falling back to " + "hardcoded templates.", + e.what()); + g_chat_template.reset(); + } + } else { + g_chat_template.reset(); + LOGE("[Warning] No chat template found in %s. Using hardcoded chat " + "templates.", + abs_model_dir.c_str()); + } + + // Construct weight file path + std::string weight_file_name; + if (nntr_cfg.contains("model_file_name")) { + weight_file_name = nntr_cfg["model_file_name"].get(); + } else { + weight_file_name = "pytorch_model.bin"; + } + + const std::string weight_file = + rebase_path(weight_file_name, abs_model_dir); + LOGD("[DEBUG] load_into_handle: weight_file = %s", weight_file.c_str()); + std::cout << "-------------------" << abs_model_dir << "/" << std::endl; + + nntr_cfg["model_file_name"] = weight_file; + if (nntr_cfg.contains("binary_config_path")) { + std::string str = nntr_cfg["binary_config_path"].get(); + nntr_cfg["binary_config_path"] = rebase_path(str, abs_model_dir); + LOGD("[DEBUG] bianry config data: file = %s", + nntr_cfg["binary_config_path"].get().c_str()); + } + if (nntr_cfg.contains("image_newline_path")) { + std::string str = nntr_cfg["image_newline_path"].get(); + nntr_cfg["image_newline_path"] = rebase_path(str, abs_model_dir); + LOGD("[DEBUG] new line config data: file = %s", + nntr_cfg["image_newline_path"].get().c_str()); + } + if (nntr_cfg.contains("embedding_file_name")) { + std::string str = nntr_cfg["embedding_file_name"].get(); + nntr_cfg["embedding_file_name"] = rebase_path(str, abs_model_dir); + } + if (nntr_cfg.contains("ple_file_name")) { + std::string str = nntr_cfg["ple_file_name"].get(); + nntr_cfg["ple_file_name"] = rebase_path(str, abs_model_dir); + } + + LOGD("[DEBUG] -------------------------- asdfasdfasdfasdfasdfasdf "); + + // Determine architecture from config or ModelType + // Priority: Config file architecture > ModelType mapping (fallback) + std::string architecture; + if (cfg.contains("architectures") && cfg["architectures"].is_array() && + !cfg["architectures"].empty()) { + architecture = cfg["architectures"].get>()[0]; + } else { + // No fallback mapping from specific ModelType instances to generic + // architecture strings for now, as specific types should have config or + // be loaded from valid file with config.json + LOGE("[DEBUG] load_into_handle: No architecture found in config"); + return CAUSAL_LM_ERROR_INVALID_PARAMETER; + } + LOGD("[DEBUG] load_into_handle: architecture = %s", architecture.c_str()); + + LOGD("[DEBUG] load_into_handle: Creating model via Factory...%s ", + architecture.c_str()); + + auto m = causallm::Factory::Instance().create(architecture, cfg, + generation_cfg, nntr_cfg); + if (!m) { + LOGE("[DEBUG] load_into_handle: Factory::create returned nullptr"); + return CAUSAL_LM_ERROR_MODEL_LOAD_FAILED; + } + LOGD("[DEBUG] load_into_handle: Model created successfully"); + + if (native_lib_dir != nullptr) + h.native_lib_dir = native_lib_dir; + +#ifdef ENABLE_QNN + // Point the QNN HTP backend at htp_backend_ext_config.json. This must run + // for ANY QNN-backed architecture (e.g. Gemma4_E2B_QNN), not only + // callback-registered ones: otherwise QNNContext falls back to the process + // cwd β€” which is "/" for an installed APK β€” and QNN layer registration + // throws not_supported ("Unable to load backend extensions config"). + // Honor an externally configured path if one is already set. + ensure_qnn_backend_ext_config(base_dir); +#endif + + LOGD("[DEBUG] load_into_handle: Calling model->initialize()..."); + if (native_lib_dir != nullptr && strlen(native_lib_dir) > 0) { + setenv("ADSP_LIBRARY_PATH", native_lib_dir, 1); + m->initialize(std::string(native_lib_dir)); + } else { + m->initialize(); + } + LOGD("[DEBUG] load_into_handle: model->initialize() done"); + + LOGD("[DEBUG] load_into_handle: Calling model->load_weight()..."); + m->load_weight(weight_file); + LOGD("[DEBUG] load_into_handle: model->load_weight() done"); + + auto finish_init = std::chrono::high_resolution_clock::now(); + auto init_duration = std::chrono::duration_cast( + finish_init - start_init); + + h.models.push_back(std::move(m)); + h.architectures.push_back(architecture); + h.model_dirs.push_back(abs_model_dir); + h.initialization_duration_ms.push_back( + static_cast(init_duration.count())); + h.initialized = true; + + // XGrammarManager Initalize + auto *tokenizer = h.models[0]->getTokenizer(); + unsigned int vocab_size = h.models[0]->getVocabSize(); + causallm::XGrammarManager::Instance().initialize(tokenizer, vocab_size); + + // XGrammarManager Toolset Load + std::string default_toolset_path = abs_model_dir + "/Toolset.json"; + bool toolset_file_exists = check_file_exists(default_toolset_path); + if (toolset_file_exists) { + loadToolset(default_toolset_path.c_str(), tokenizer, vocab_size); + } + + LOGD("[DEBUG] load_into_handle: SINGLE SUCCESS (init took %lld ms)", + init_duration.count()); + } catch (...) { + // RTTI may not match across shared libraries β€” query the current + // exception's typeinfo directly via the Itanium ABI hook. This + // works even when catching by concrete types fails due to typeinfo + // duplication between libnntrainer.so and libquick_dot_ai_api.so. + const std::type_info *ti = abi::__cxa_current_exception_type(); + const char *raw = ti ? ti->name() : "(null)"; + int status = 0; + char *demangled = (ti != nullptr) + ? abi::__cxa_demangle(raw, nullptr, nullptr, &status) + : nullptr; + LOGE("[DEBUG] load_into_handle: unknown exception, type=%s", + demangled ? demangled : raw); + std::free(demangled); + + // Also try once more via rethrow β€” in case std::exception RTTI does + // match from this catch-site (we already tried above but leaving + // this as a second chance is cheap). + try { + throw; + } catch (const std::exception &e) { + LOGE("[DEBUG] load_into_handle: rethrown std::exception what()=%s", + e.what()); + } catch (const std::invalid_argument &e) { + LOGE("[DEBUG] load_into_handle: rethrown std::exception what()=%s", + e.what()); + } catch (...) { + LOGE("[DEBUG] load_into_handle: rethrown still non-std"); + } + return CAUSAL_LM_ERROR_MODEL_LOAD_FAILED; + } + LOGD("[DEBUG] load_into_handle: END (returning CAUSAL_LM_ERROR_NONE)"); + return CAUSAL_LM_ERROR_NONE; +} + +/** ModelType overload: translates enum β†’ config_name then delegates. */ +static ErrorCode load_into_handle(CausalLmModel &h, BackendType compute, + ModelType modeltype, + ModelQuantizationType quant_type, + const char *native_lib_dir, + const char *model_base_path) { + const char *target_model_name = get_model_name_from_type(modeltype); + if (!target_model_name) { + LOGE("[DEBUG] load_into_handle: Invalid modeltype %d", modeltype); + return CAUSAL_LM_ERROR_INVALID_PARAMETER; + } + LOGD("[DEBUG] load_into_handle: target_model_name = %s %d", target_model_name, + modeltype); + return load_into_handle(h, compute, target_model_name, quant_type, + native_lib_dir, model_base_path); +} + +/** + * @brief Core runner shared by runModelHandleWithMessages. + */ +static ErrorCode run_on_handle(CausalLmModel &h, const char *inputTextPrompt, + const char **outputText, + bool input_already_formatted = false, + size_t model_index = 0) { + if (inputTextPrompt == nullptr || outputText == nullptr) { + return CAUSAL_LM_ERROR_INVALID_PARAMETER; + } + + std::lock_guard lock(h.mtx); + if (!h.initialized || model_index >= h.models.size() || + !h.models[model_index]) { + return CAUSAL_LM_ERROR_NOT_INITIALIZED; + } + + try { + auto &model = *h.models[model_index]; + std::string input = prepare_input_for_model( + h, model_index, std::string(inputTextPrompt), input_already_formatted); + +// We assume single batch request for this API +#if defined(_WIN32) + model.run(std::wstring(input.begin(), input.end()), false, L"", L"", + g_verbose); +#else + model.run(input, false, "", "", g_verbose); +#endif + + h.last_output = model.getOutput(0); + *outputText = h.last_output.c_str(); + update_handle_session_after_run(h, model_index); + } catch (const std::exception &e) { + LOGE("Exception in run_on_handle: %s", e.what()); + return CAUSAL_LM_ERROR_INFERENCE_FAILED; + } + + return CAUSAL_LM_ERROR_NONE; +} + +/** + * @brief Core metrics fetcher shared by getPerformanceMetrics and its + * handle-based counterpart. + * + * Reports models[0] runtime metrics. initialization_duration_ms is the + * sum over all sub-models this handle owns. + */ +static ErrorCode metrics_on_handle(CausalLmModel &h, + PerformanceMetrics *metrics) { + if (metrics == nullptr) { + return CAUSAL_LM_ERROR_INVALID_PARAMETER; + } + + std::lock_guard lock(h.mtx); + size_t metrics_model_idx = text_generation_model_index(h); + if (!h.initialized || h.models.size() <= metrics_model_idx || + !h.models[metrics_model_idx]) { + return CAUSAL_LM_ERROR_NOT_INITIALIZED; + } + + try { + auto *model = h.models[metrics_model_idx].get(); + if (!model->hasRun()) { + return CAUSAL_LM_ERROR_INFERENCE_NOT_RUN; + } + auto im = model->getPerformanceMetrics(); + metrics->prefill_tokens = im.prefill_tokens; + metrics->prefill_duration_ms = im.prefill_duration_ms; + metrics->generation_tokens = im.generation_tokens; + metrics->generation_duration_ms = im.generation_duration_ms; + metrics->total_duration_ms = im.total_duration_ms; + metrics->peak_memory_kb = im.peak_memory_kb; + + double total_init = 0.0; + for (double d : h.initialization_duration_ms) + total_init += d; + metrics->initialization_duration_ms = total_init; + } catch (const std::exception &e) { + LOGE("Exception in getPerformanceMetrics: %s", e.what()); + return CAUSAL_LM_ERROR_UNKNOWN; + } + + return CAUSAL_LM_ERROR_NONE; +} + +/***************************************************************************** + * Chat Template API - role + content message support + *****************************************************************************/ + +// Internal ChatMessage struct for API use +struct ChatMessage { + std::string role; + std::string content; +}; + +static std::vector +convertMessages(const CausalLMChatMessage *messages, size_t num_messages) { + std::vector result; + result.reserve(num_messages); + for (size_t i = 0; i < num_messages; ++i) { + ChatMessage msg; + msg.role = messages[i].role ? messages[i].role : ""; + msg.content = messages[i].content ? messages[i].content : ""; + result.push_back(std::move(msg)); + } + return result; +} + +/** + * @brief Apply chat template to messages with hardcoded fallback + * + * @param model_dir Optional model directory to load tokenizer_config.json + * from if g_chat_template is not already loaded. This ensures + * registered models (and any other model) use their tokenizer's + * chat template when available. + */ +static std::string apply_chat_template_messages( + const std::string &architecture, const std::vector &messages, + bool add_generation_prompt, const std::string &model_dir = "") { + // If g_chat_template is not loaded but a model_dir is provided, + // try loading tokenizer_config.json from that directory at run time. + if (!g_chat_template && !model_dir.empty()) { + std::string tc_path = model_dir + "/tokenizer_config.json"; + if (check_file_exists(tc_path)) { + try { + g_chat_template = causallm::ChatTemplate::Load(model_dir); + if (g_chat_template) { + LOGD("[Info] Chat template loaded on-demand from %s", + model_dir.c_str()); + } else { + LOGE("[Warning] tokenizer_config.json found in %s but could not be " + "loaded.", + model_dir.c_str()); + } + } catch (const std::exception &e) { + LOGE("[Warning] Failed to load chat template from %s: %s", + model_dir.c_str(), e.what()); + } + } else { + LOGE("[Warning] tokenizer_config.json not found in %s", + model_dir.c_str()); + } + } + + // Use Enhanced Chat Template if available + if (g_chat_template) { + nlohmann::json request; + request["messages"] = nlohmann::json::array(); + for (const auto &msg : messages) { + request["messages"].push_back( + {{"role", msg.role}, {"content", msg.content}}); + } + request["add_generation_prompt"] = add_generation_prompt; + + try { + return g_chat_template->apply(request); + } catch (const std::exception &e) { + LOGE("Chat template apply failed: %s", e.what()); + // fallback to hardcoded + } + } + + LOGD("APPLYING HARD CODED FALLBACK"); + std::string result; + + if (architecture == "LlamaForCausalLM") { + for (const auto &msg : messages) { + if (msg.role == "system") { + result += "<>\n" + msg.content + "\n<>\n\n"; + } else if (msg.role == "user") { + result += "[INST] " + msg.content + " [/INST]"; + } else if (msg.role == "assistant") { + result += msg.content + "\n"; + } + } + } else if (architecture == "Qwen2ForCausalLM" || + architecture == "Qwen3ForCausalLM" || + architecture == "Qwen3MoeForCausalLM" || + architecture == "Qwen3SlimMoeForCausalLM" || + architecture == "Qwen3CachedSlimMoeForCausalLM") { + for (const auto &msg : messages) { + result += "<|im_start|>" + msg.role + "\n" + msg.content + "<|im_end|>\n"; + } + if (add_generation_prompt) { + result += "<|im_start|>assistant\n"; + } + } else if (architecture == "Gemma3ForCausalLM") { + for (const auto &msg : messages) { + if (msg.role == "user") { + result += "user\n" + msg.content + "\n"; + } else if (msg.role == "assistant") { + result += "model\n" + msg.content + "\n"; + } + } + if (add_generation_prompt) { + result += "model\n"; + } + } else if (architecture == "Gemma4ForCausalLM" || + architecture == "Gemma4_E2B_QNN") { + // Gemma 4 requires a single leading token (id 2) at the very start + // of the prompt; see the canonical sample_input in the model's + // nntr_config.json. Nothing downstream adds it (the tokenizer.json + // post_processor is empty and tokenizer_config.json has no + // add_bos_token), so prepend it exactly once here. + result += ""; + for (const auto &msg : messages) { + std::string role = msg.role; + if (role == "assistant") { + role = "model"; + } + result += "<|turn>" + role + "\n" + msg.content + "\n"; + } + if (add_generation_prompt) { + result += "<|turn>model\n"; + } + } else { + const auto *reg_cb = ModelCallbackRegistry::instance().lookup(architecture); + if (reg_cb && reg_cb->format_prompt) { + // Build a full multi-message prompt via the registered format_prompt. + // Concatenate all message contents and delegate to the callback. + std::string combined; + for (const auto &msg : messages) { + combined += msg.content + "\n"; + } + result = reg_cb->format_prompt(combined); + return result; + } + // Unknown architecture fallback + for (const auto &msg : messages) { + result += msg.content + "\n"; + } + } + + return result; +} + +ErrorCode applyChatTemplate(const CausalLMChatMessage *messages, + size_t num_messages, bool add_generation_prompt, + const char **formattedText) { + if (messages == nullptr || num_messages == 0 || formattedText == nullptr) { + return CAUSAL_LM_ERROR_INVALID_PARAMETER; + } + + try { + auto &h = get_default_handle(); + std::lock_guard lock(h.mtx); + + // Debug: print messages before convertMessages + LOGD("[DEBUG] applyChatTemplate: num_messages=%zu", num_messages); + for (size_t i = 0; i < num_messages; ++i) { + LOGD("[DEBUG] applyChatTemplate: messages[%zu] role='%s' content='%s'", i, + messages[i].role ? messages[i].role : "(null)", + messages[i].content ? messages[i].content : "(null)"); + } + + auto chat_messages = convertMessages(messages, num_messages); + std::string arch = + h.architectures.empty() ? std::string() : h.architectures[0]; + std::string model_dir = + h.model_dirs.empty() ? std::string() : h.model_dirs[0]; + std::string formattedInput = apply_chat_template_messages( + arch, chat_messages, add_generation_prompt, model_dir); + + g_formatted_template = std::move(formattedInput); + *formattedText = g_formatted_template.c_str(); + } catch (const std::exception &e) { + LOGE("Exception in applyChatTemplate: %s", e.what()); + return CAUSAL_LM_ERROR_UNKNOWN; + } + + return CAUSAL_LM_ERROR_NONE; +} + +ErrorCode runModelHandleWithMessages(CausalLmHandle handle, + const CausalLMChatMessage *messages, + size_t num_messages, + bool add_generation_prompt, + const char **outputText) { + LOGD("[DEBUG] runModelHandleWithMessages: handle=%p", (void *)handle); + + if (handle == nullptr || outputText == nullptr) { + return CAUSAL_LM_ERROR_INVALID_PARAMETER; + } + + auto &h = *handle; + size_t model_index = 0; + std::string formattedInput; + + try { + { + std::lock_guard lock(h.mtx); + if (!h.initialized || h.models.empty()) { + return CAUSAL_LM_ERROR_NOT_INITIALIZED; + } + model_index = text_generation_model_index(h); + if (model_index >= h.models.size() || !h.models[model_index]) { + return CAUSAL_LM_ERROR_NOT_INITIALIZED; + } + + std::string model_dir = h.model_dirs.size() > model_index + ? h.model_dirs[model_index] + : std::string(); + + auto chat_messages = convertMessages(messages, num_messages); + std::string arch = h.architectures.size() > model_index + ? h.architectures[model_index] + : std::string(); + formattedInput = apply_chat_template_messages( + arch, chat_messages, add_generation_prompt, model_dir); + } + + return run_on_handle(h, formattedInput.c_str(), outputText, + /*input_already_formatted=*/true, model_index); + } catch (const std::exception &e) { + LOGE("Exception in runModelHandleWithMessages: %s", e.what()); + return CAUSAL_LM_ERROR_UNKNOWN; + } +} + +ErrorCode runModelHandleWithTool(CausalLmHandle handle, + const char *inputTextPrompt, + const char **outputText, const char *tool_name, + const char *tool_schema) { + if (handle == nullptr) { + return CAUSAL_LM_ERROR_INVALID_PARAMETER; + } + + auto &h = *handle; + + causallm::XGrammar *grammar = nullptr; + // Step 1: Check if tool exists in XGrammarManager + if (causallm::XGrammarManager::Instance().hasTool(tool_name)) { + LOGD("[runModelWithToolHandle] Tool '%s' found in XGrammarManager, using " + "existing grammar", + tool_name); + grammar = causallm::XGrammarManager::Instance().getGrammar(tool_name); + } else { + // Step 2: Tool doesn't exist, create and register it + if (tool_schema == nullptr) { + LOGE("Error: Tool '%s' not found and no schema provided", tool_name); + return CAUSAL_LM_ERROR_INVALID_PARAMETER; + } + + LOGD("[runModelWithToolHandle] Tool '%s' not found, creating new grammar", + tool_name); + bool registered = causallm::XGrammarManager::Instance().registerTool( + tool_name, tool_schema); + + if (!registered) { + LOGE("Error: Failed to register tool '%s'", tool_name); + return CAUSAL_LM_ERROR_UNKNOWN; + } + + grammar = causallm::XGrammarManager::Instance().getGrammar(tool_name); + } + + if (grammar == nullptr) { + LOGE("Error: Failed to get grammar for tool '%s'", tool_name); + return CAUSAL_LM_ERROR_UNKNOWN; + } + + // Run inference using the handle + h.models[0]->setXGrammar(grammar); + ErrorCode err = run_on_handle(*handle, inputTextPrompt, outputText); + h.models[0]->resetXGrammar(); + return err; +} + +/*============================================================================ + * Legacy non-handle API implementation + *============================================================================*/ + +ErrorCode loadModel(BackendType compute, ModelType modeltype, + ModelQuantizationType quant_type, + const char *model_base_path) { + return load_into_handle(get_default_handle(), compute, modeltype, quant_type, + nullptr, model_base_path); +} + +ErrorCode saveQnnKvCache(const char *cache_path) { + return save_qnn_kv_cache_on_handle(get_default_handle(), cache_path); +} + +ErrorCode loadQnnKvCache(const char *cache_path) { + return load_qnn_kv_cache_on_handle(get_default_handle(), cache_path); +} + +ErrorCode resetQnnKvCache(void) { + return reset_qnn_kv_cache_on_handle(get_default_handle()); +} + +ErrorCode getPerformanceMetrics(PerformanceMetrics *metrics) { + return metrics_on_handle(get_default_handle(), metrics); +} + +/*============================================================================ + * Handle-based API implementation + *============================================================================*/ + +ErrorCode loadModelHandle(BackendType compute, ModelType modeltype, + ModelQuantizationType quant_type, + const char *native_lib_dir, + const char *model_base_path, + CausalLmHandle *out_handle) { + LOGD("[DEBUG] loadModelHandle:%d START", __LINE__); + LOGD("[DEBUG] loadModelHandle:%d compute: %d", __LINE__, compute); + LOGD("[DEBUG] loadModelHandle:%d modeltype: %d", __LINE__, modeltype); + LOGD("[DEBUG] loadModelHandle:%d quant_type: %d", __LINE__, quant_type); + LOGD("[DEBUG] loadModelHandle:%d native_lib_dir: %s", __LINE__, + native_lib_dir ? native_lib_dir : "(null)"); + LOGD("[DEBUG] loadModelHandle:%d out_handle ptr: %p", __LINE__, + (void *)out_handle); + + if (out_handle == nullptr) { + LOGE("[DEBUG] loadModelHandle:%d out_handle is nullptr", __LINE__); + return CAUSAL_LM_ERROR_INVALID_PARAMETER; + } + auto *h = new (std::nothrow) CausalLmModel(); + if (h == nullptr) { + LOGE("[DEBUG] loadModelHandle:%d Failed to allocate CausalLmModel", + __LINE__); + return CAUSAL_LM_ERROR_UNKNOWN; + } + LOGD("[DEBUG] loadModelHandle:%d CausalLmModel allocated at %p", __LINE__, + (void *)h); + + LOGD("[DEBUG] loadModelHandle:%d Calling load_into_handle...", __LINE__); + ErrorCode ec = load_into_handle(*h, compute, modeltype, quant_type, + native_lib_dir, model_base_path); + LOGD("[DEBUG] loadModelHandle:%d load_into_handle returned: %d", __LINE__, + ec); + + if (ec != CAUSAL_LM_ERROR_NONE) { + LOGE("[DEBUG] loadModelHandle:%d load_into_handle failed, deleting handle", + __LINE__); + delete h; + *out_handle = nullptr; + return ec; + } + *out_handle = h; + LOGD("[DEBUG] loadModelHandle:%d SUCCESS, handle set to %p", __LINE__, + (void *)h); + return CAUSAL_LM_ERROR_NONE; +} + +ErrorCode loadModelHandleByName(BackendType compute, const char *model_id, + ModelQuantizationType quant_type, + const char *native_lib_dir, + const char *model_base_path, + CausalLmHandle *out_handle) { + if (out_handle == nullptr || model_id == nullptr) + return CAUSAL_LM_ERROR_INVALID_PARAMETER; + + register_models(); // ensure factory + public descriptors registered + + auto d_opt = find_descriptor_by_id(model_id); + if (!d_opt) { + LOGE("loadModelHandleByName: unknown id '%s'", model_id); + return CAUSAL_LM_ERROR_INVALID_PARAMETER; + } + const ModelDescriptor &d = *d_opt; + if (!d.config_name) { + LOGE("loadModelHandleByName: descriptor '%s' has null config_name", + model_id); + return CAUSAL_LM_ERROR_INVALID_PARAMETER; + } + if (((d.backend_mask >> (unsigned)compute) & 1u) == 0u) { + LOGE("loadModelHandleByName: backend %d not in mask 0x%x for '%s'", compute, + d.backend_mask, model_id); + return CAUSAL_LM_ERROR_INVALID_PARAMETER; + } + + auto *h = new (std::nothrow) CausalLmModel(); + if (!h) + return CAUSAL_LM_ERROR_UNKNOWN; + + ErrorCode ec = load_into_handle(*h, compute, d.config_name, quant_type, + native_lib_dir, model_base_path); + if (ec != CAUSAL_LM_ERROR_NONE) { + delete h; + *out_handle = nullptr; + return ec; + } + *out_handle = h; + return CAUSAL_LM_ERROR_NONE; +} + +ErrorCode loadMultimodalHandleByName(BackendType compute, + const char *embedding_model_id, + const char *llm_model_id, + ModelQuantizationType quant_type, + const char *native_lib_dir, + const char *model_base_path, + CausalLmHandle *out_handle) { + if (out_handle == nullptr || embedding_model_id == nullptr || + llm_model_id == nullptr) + return CAUSAL_LM_ERROR_INVALID_PARAMETER; + + register_models(); + + auto ev = find_descriptor_by_id(embedding_model_id); + auto lv = find_descriptor_by_id(llm_model_id); + if (!ev || !lv || !ev->config_name || !lv->config_name) { + LOGE("loadMultimodalHandleByName: unknown id(s) emb='%s' llm='%s'", + embedding_model_id, llm_model_id); + return CAUSAL_LM_ERROR_INVALID_PARAMETER; + } + + // Load each model into its own temporary single-model handle, then move the + // sub-models into the combined handle in order [vision producer, LLM]. + // This reuses the proven single-model load path without modifying it. + CausalLmModel tmp_vision; + CausalLmModel tmp_llm; + ErrorCode ec = load_into_handle(tmp_vision, compute, ev->config_name, + quant_type, native_lib_dir, model_base_path); + if (ec != CAUSAL_LM_ERROR_NONE) { + LOGE("loadMultimodalHandleByName: vision '%s' load failed (%d)", + embedding_model_id, ec); + return ec; + } + ec = load_into_handle(tmp_llm, compute, lv->config_name, quant_type, + native_lib_dir, model_base_path); + if (ec != CAUSAL_LM_ERROR_NONE) { + LOGE("loadMultimodalHandleByName: llm '%s' load failed (%d)", llm_model_id, + ec); + return ec; + } + if (tmp_vision.models.empty() || tmp_llm.models.empty()) + return CAUSAL_LM_ERROR_MODEL_LOAD_FAILED; + + // Compatibility check (R5): the LLM must expose an embedding table so the + // composer can interleave text + image embeddings. + if (tmp_llm.models[0]->embeddingBytesPerToken() == 0) { + LOGE("loadMultimodalHandleByName: LLM '%s' has no embedding table β€” " + "incompatible pair", + llm_model_id); + return CAUSAL_LM_ERROR_UNSUPPORTED; + } + + auto *h = new (std::nothrow) CausalLmModel(); + if (!h) + return CAUSAL_LM_ERROR_UNKNOWN; + + auto move_one = [](CausalLmModel &src, CausalLmModel &dst) { + dst.models.push_back(std::move(src.models[0])); + dst.architectures.push_back(src.architectures.empty() ? std::string() + : src.architectures[0]); + dst.model_dirs.push_back(src.model_dirs.empty() ? std::string() + : src.model_dirs[0]); + if (!src.initialization_duration_ms.empty()) + dst.initialization_duration_ms.push_back( + src.initialization_duration_ms[0]); + }; + move_one(tmp_vision, *h); // index 0 = vision producer + move_one(tmp_llm, *h); // index 1 = LLM consumer + if (native_lib_dir != nullptr) + h->native_lib_dir = native_lib_dir; + h->initialized = true; + + *out_handle = h; + return CAUSAL_LM_ERROR_NONE; +} + +ErrorCode saveQnnKvCacheHandle(CausalLmHandle handle, const char *cache_path) { + if (handle == nullptr) { + return CAUSAL_LM_ERROR_INVALID_PARAMETER; + } + return save_qnn_kv_cache_on_handle(*handle, cache_path); +} + +ErrorCode loadQnnKvCacheHandle(CausalLmHandle handle, const char *cache_path) { + if (handle == nullptr) { + return CAUSAL_LM_ERROR_INVALID_PARAMETER; + } + return load_qnn_kv_cache_on_handle(*handle, cache_path); +} + +ErrorCode resetQnnKvCacheHandle(CausalLmHandle handle) { + if (handle == nullptr) { + return CAUSAL_LM_ERROR_INVALID_PARAMETER; + } + return reset_qnn_kv_cache_on_handle(*handle); +} + +ErrorCode getPerformanceMetricsHandle(CausalLmHandle handle, + PerformanceMetrics *metrics) { + if (handle == nullptr) { + return CAUSAL_LM_ERROR_INVALID_PARAMETER; + } + return metrics_on_handle(*handle, metrics); +} + +/*============================================================================ + * Internal streaming helper + *============================================================================*/ + +static ErrorCode run_model_streaming_on_handle(CausalLmModel &h, + const std::string &raw_input, + CausalLmTokenCallback callback, + void *user_data, + bool input_already_formatted, + size_t model_index) { + if (model_index >= h.models.size() || !h.models[model_index]) { + return CAUSAL_LM_ERROR_NOT_INITIALIZED; + } + auto *m = h.models[model_index].get(); + + CallbackStreamer streamer; + callback_streamer_init(&streamer, callback, user_data); + m->setStreamer(&streamer.base); + + struct Detach { + causallm::Transformer *t; + ~Detach() { t->setStreamer(nullptr); } + } detach_guard{m}; + + try { + std::string input = prepare_input_for_model(h, model_index, raw_input, + input_already_formatted); + + LOGD("[DEBUG] raw input length: %zu", raw_input.length()); + LOGD("[DEBUG] g_use_chat_template: %d", g_use_chat_template); + if (input_already_formatted) { + LOGD("[DEBUG] input_already_formatted=1, using pre-formatted input " + "(length: %zu)", + input.length()); + } else { + LOGD("[DEBUG] input_already_formatted=0, applying chat template"); + LOGD("[DEBUG] model input length: %zu", input.length()); + LOGD("[DEBUG] model input: %s", input.c_str()); + } + +#if defined(_WIN32) + m->run(std::wstring(input.begin(), input.end()), false, L"", L"", + g_verbose); +#else + m->run(input, false, "", "", true); +#endif + + h.last_output = m->getOutput(0); + update_handle_session_after_run(h, model_index); + + if (m->hasRun()) { + auto im = m->getPerformanceMetrics(); + double total_init = 0.0; + for (double d : h.initialization_duration_ms) + total_init += d; + + LOGD("[PERF] Performance Metrics:"); + LOGD("[PERF] prefill_tokens: %u", im.prefill_tokens); + LOGD("[PERF] prefill_duration_ms: %.2f", im.prefill_duration_ms); + LOGD("[PERF] generation_tokens: %u", im.generation_tokens); + LOGD("[PERF] generation_duration_ms: %.2f", im.generation_duration_ms); + LOGD("[PERF] total_duration_ms: %.2f", im.total_duration_ms); + LOGD("[PERF] peak_memory_kb: %zu", im.peak_memory_kb); + LOGD("[PERF] initialization_duration_ms: %.2f", total_init); + + if (im.prefill_duration_ms > 0) { + double tokens_per_sec = + (im.prefill_tokens * 1000.0) / im.prefill_duration_ms; + LOGD("[PERF] prefill_tokens_per_sec: %.2f", tokens_per_sec); + } + if (im.generation_duration_ms > 0) { + double tokens_per_sec = + (im.generation_tokens * 1000.0) / im.generation_duration_ms; + LOGD("[PERF] generation_tokens_per_sec: %.2f", tokens_per_sec); + } + } + } catch (const std::exception &e) { + LOGE("[DEBUG] run_model_streaming_on_handle: Exception caught: %s", + e.what()); + return CAUSAL_LM_ERROR_INFERENCE_FAILED; + } catch (...) { + LOGE("[DEBUG] run_model_streaming_on_handle: Unknown exception caught"); + return CAUSAL_LM_ERROR_INFERENCE_FAILED; + } + + return CAUSAL_LM_ERROR_NONE; +} + +ErrorCode runModelHandleStreaming(CausalLmHandle handle, + const char *inputTextPrompt, + CausalLmTokenCallback callback, + void *user_data) { + LOGD("[DEBUG] runModelHandleStreaming: START"); + LOGD("[DEBUG] handle: %p", (void *)handle); + LOGD("[DEBUG] inputTextPrompt: %.50s%s", + inputTextPrompt ? inputTextPrompt : "(null)", + inputTextPrompt && strlen(inputTextPrompt) > 50 ? "..." : ""); + + if (handle == nullptr || inputTextPrompt == nullptr || callback == nullptr) { + LOGE("[DEBUG] runModelHandleStreaming: INVALID_PARAMETER"); + return CAUSAL_LM_ERROR_INVALID_PARAMETER; + } + + auto &h = *handle; + std::lock_guard lock(h.mtx); + + if (!h.initialized || h.models.empty()) { + LOGE("[DEBUG] runModelHandleStreaming: NOT_INITIALIZED"); + return CAUSAL_LM_ERROR_NOT_INITIALIZED; + } + const size_t model_index = text_generation_model_index(h); + + ErrorCode ec = run_model_streaming_on_handle( + h, std::string(inputTextPrompt), callback, user_data, + /*input_already_formatted=*/false, model_index); + + LOGD("[DEBUG] runModelHandleStreaming: END (errorCode=%d)", ec); + return ec; +} + +ErrorCode encodeModelHandle(CausalLmHandle handle, const char *text, + float **out_embedding, int *out_dim) { + if (handle == nullptr || text == nullptr || out_embedding == nullptr || + out_dim == nullptr) { + return CAUSAL_LM_ERROR_INVALID_PARAMETER; + } + *out_embedding = nullptr; + *out_dim = 0; + + auto &h = *handle; + std::lock_guard lock(h.mtx); + + if (!h.initialized || h.models.empty()) { + return CAUSAL_LM_ERROR_NOT_INITIALIZED; + } + + // Embedding models occupy models[0] (single-model embedding handle). + auto *st = + dynamic_cast(h.models[0].get()); + if (st == nullptr) { + LOGE("encodeModelHandle: models[0] is not a SentenceTransformer"); + return CAUSAL_LM_ERROR_UNSUPPORTED; + } + + try { + const int dim = st->getEmbeddingDim(); + if (dim <= 0) { + return CAUSAL_LM_ERROR_INFERENCE_FAILED; + } + + // WSTR is std::string in this codebase; pass the text directly, + // consistent with runModelHandleStreaming. + std::string s(text); + + std::vector results = st->encode(s); + if (results.empty() || results[0] == nullptr) { + for (auto *p : results) + delete[] p; + return CAUSAL_LM_ERROR_INFERENCE_FAILED; + } + + // Copy the batch-0 embedding (first DIM floats) into a caller-owned buffer. + float *buf = new float[dim]; + std::memcpy(buf, results[0], sizeof(float) * static_cast(dim)); + + // encode() allocates each pointer with new[]; release them all. + for (auto *p : results) + delete[] p; + + *out_embedding = buf; + *out_dim = dim; + return CAUSAL_LM_ERROR_NONE; + } catch (const std::exception &e) { + LOGE("encodeModelHandle: exception: %s", e.what()); + return CAUSAL_LM_ERROR_INFERENCE_FAILED; + } catch (...) { + LOGE("encodeModelHandle: unknown exception"); + return CAUSAL_LM_ERROR_INFERENCE_FAILED; + } +} + +void freeEmbedding(float *embedding) { delete[] embedding; } + +ErrorCode unloadModelHandle(CausalLmHandle handle) { + if (handle == nullptr) { + return CAUSAL_LM_ERROR_NONE; + } + std::lock_guard lock(handle->mtx); + handle->models.clear(); + handle->architectures.clear(); + handle->model_dirs.clear(); + handle->initialization_duration_ms.clear(); + handle->initialized = false; + reset_handle_session_state(*handle); + return CAUSAL_LM_ERROR_NONE; +} + +ErrorCode destroyModelHandle(CausalLmHandle handle) { + if (handle == nullptr) { + return CAUSAL_LM_ERROR_NONE; + } + // Take the mutex to make sure no in-flight call on this handle is still + // running, then release and delete. Any caller that still holds a pointer + // to the output buffer returned by runModelHandleWithMessages is reading + // freed memory after this point β€” documented as "valid until destroy". + { + std::lock_guard lock(handle->mtx); + handle->models.clear(); + handle->architectures.clear(); + handle->model_dirs.clear(); + handle->initialization_duration_ms.clear(); + handle->initialized = false; + reset_handle_session_state(*handle); + } + delete handle; + return CAUSAL_LM_ERROR_NONE; +} + +ErrorCode cancelModelHandle(CausalLmHandle handle) { + LOGD("[DEBUG] cancelModelHandle: handle=%p", (void *)handle); + + if (handle == nullptr) { + LOGE("[DEBUG] cancelModelHandle: handle is nullptr, returning " + "INVALID_PARAMETER"); + return CAUSAL_LM_ERROR_INVALID_PARAMETER; + } + + // NOTE: We intentionally do NOT take the mutex here to avoid blocking + // when run() is holding the lock. The requestStop() method is thread-safe + // (uses atomic), and the models vector is not modified during run() + // (only during load/unload which do take the mutex). This allows immediate + // cancellation from any thread (e.g., UI cancel button handler). + LOGD("[DEBUG] cancelModelHandle: checking state without mutex, " + "initialized=%d, models.size=%zu", + handle->initialized, handle->models.size()); + + if (!handle->initialized || handle->models.empty()) { + LOGE( + "[DEBUG] cancelModelHandle: not initialized, returning NOT_INITIALIZED"); + return CAUSAL_LM_ERROR_NOT_INITIALIZED; + } + + // Set stop flag on all models (primarily affects models[0] for LLM) + for (size_t i = 0; i < handle->models.size(); ++i) { + if (handle->models[i]) { + LOGD("[DEBUG] cancelModelHandle: calling requestStop() on model[%zu]", i); + handle->models[i]->requestStop(); + } + } + + LOGD("[DEBUG] cancelModelHandle: returning NONE (success)"); + return CAUSAL_LM_ERROR_NONE; +} + +/*============================================================================ + * Multimodal API Implementation + * + * Preconditions: the handle must have been loaded from a multi-model + * nntr_config.json carrying at least two sub-models. The first sub-model + * is expected to be the vision encoder and the second the LLM, though + * the concrete integration (vision encoding + embedding fusion + LLM + * generation) is still TODO. Single-model handles return + * CAUSAL_LM_ERROR_UNSUPPORTED. + *============================================================================*/ + +#ifdef ENABLE_QNN +/** + * Model-agnostic multimodal composer. Works through base Transformer virtuals + * only (no concrete-model casts), so any [vision producer, LLM consumer] pair + * (e.g. a vision encoder + an embedding-consuming LLM) is driven identically. + * + * llm: embedding CONSUMER (lookupEmbedding / run_with_embeddings) + * image_embeds: producer output; ownership taken here (freed before return) + */ +static ErrorCode execute_multimodal(CausalLmModel &h, + causallm::Transformer *llm, + causallm::multimodal_pointer image_embeds, + const std::string &prompt, + CausalLmTokenCallback callback, + void *user_data) { + auto *tok = llm->getTokenizer(); + if (tok == nullptr) { + LOGE("[MM] llm has no tokenizer"); + std::free(image_embeds.first); + return CAUSAL_LM_ERROR_UNSUPPORTED; + } + std::vector text_ids = tok->Encode(prompt); + int32_t image_token_id = tok->TokenToId("<|image|>"); + + const size_t bpt = llm->embeddingBytesPerToken(); + if (bpt == 0) { + LOGE("[MM] llm embedding table not loaded (needs uses_embedding=false + " + "embedding_file_name)"); + std::free(image_embeds.first); + return CAUSAL_LM_ERROR_UNSUPPORTED; + } + if (image_embeds.second % bpt != 0) { + LOGE("[MM] image_embeds.size=%zu not a multiple of bpt=%zu", + image_embeds.second, bpt); + std::free(image_embeds.first); + return CAUSAL_LM_ERROR_INFERENCE_FAILED; + } + const size_t n_image = image_embeds.second / bpt; + + auto it_img = (image_token_id >= 0) + ? std::find(text_ids.begin(), text_ids.end(), image_token_id) + : text_ids.end(); + const bool has_placeholder = (it_img != text_ids.end()); + const size_t img_pos = + has_placeholder + ? static_cast(std::distance(text_ids.begin(), it_img)) + : 0; + const size_t n_text_kept = text_ids.size() - (has_placeholder ? 1 : 0); + const size_t n_total = n_text_kept + n_image; + LOGD("[MM] text=%zu image=%zu total=%zu placeholder=%d pos=%zu", + text_ids.size(), n_image, n_total, has_placeholder, img_pos); + + std::vector combined(n_total * bpt); + uint8_t *dst = combined.data(); + auto copy_text_range = [&](size_t start, size_t end) -> bool { + for (size_t i = start; i < end; ++i) { + const void *e = llm->lookupEmbedding(text_ids[i]); + if (e == nullptr) { + LOGE("[MM] lookupEmbedding(%d) null", text_ids[i]); + return false; + } + std::memcpy(dst, e, bpt); + dst += bpt; + } + return true; + }; + if (!copy_text_range(0, img_pos)) { + std::free(image_embeds.first); + return CAUSAL_LM_ERROR_INFERENCE_FAILED; + } + std::memcpy(dst, image_embeds.first, n_image * bpt); + dst += n_image * bpt; + const size_t after_start = has_placeholder ? img_pos + 1 : img_pos; + if (!copy_text_range(after_start, text_ids.size())) { + std::free(image_embeds.first); + return CAUSAL_LM_ERROR_INFERENCE_FAILED; + } + std::free(image_embeds.first); + image_embeds.first = nullptr; + + CallbackStreamer streamer; + callback_streamer_init(&streamer, callback, user_data); + llm->setStreamer(&streamer.base); + struct Detach { + causallm::Transformer *t; + ~Detach() { t->setStreamer(nullptr); } + } detach_guard{llm}; + + try { + llm->run_with_embeddings(combined.data(), n_total, text_ids, + /*do_sample=*/false, /*log_output=*/g_verbose); + h.kv_len = llm->getKvLen(); + } catch (const std::exception &e) { + LOGE("[MM] llm threw: %s", e.what()); + return CAUSAL_LM_ERROR_INFERENCE_FAILED; + } + return CAUSAL_LM_ERROR_NONE; +} + +/** + * Run the vision encoder (producer) on raw pixels and return its image + * embeddings, after matching its quant space to the LLM consumer. + * @return image_embeds (ownership transferred to caller) or {nullptr,0}. + */ +static causallm::multimodal_pointer +run_vision_encoder(CausalLmModel &h, const char *prompt, + const float *pixelValues, int numPatches, + int originalHeight, int originalWidth) { + const int PATCH_SIZE = 512; // pixel layout: numPatches*3*512*512 floats + causallm::Transformer *vision = h.models[0].get(); + causallm::Transformer *llm = h.models[1].get(); + + auto info = llm->get_embedding_info(); + vision->set_quant_param(info.first, info.second); + + const size_t pixel_bytes = static_cast(numPatches) * 3 * PATCH_SIZE * + PATCH_SIZE * sizeof(float); + causallm::multimodal_pointer image_in{const_cast(pixelValues), + pixel_bytes}; + return vision->run_image(std::string(prompt ? prompt : ""), image_in, + originalHeight, originalWidth, /*do_sample=*/false, + "", "", g_verbose); +} +#endif // ENABLE_QNN + +ErrorCode runMultimodalHandleStreaming(CausalLmHandle handle, + const char *prompt, + const float *pixelValues, int numPatches, + int originalHeight, int originalWidth, + CausalLmTokenCallback callback, + void *user_data) { + LOGD("[DEBUG] runMultimodalHandleStreaming: START"); + LOGD("[DEBUG] handle=%p", handle); + LOGD("[DEBUG] prompt=%s", prompt ? prompt : "(null)"); + LOGD("[DEBUG] pixelValues=%p", pixelValues); + LOGD("[DEBUG] numPatches=%d", numPatches); + LOGD("[DEBUG] originalHeight=%d", originalHeight); + LOGD("[DEBUG] originalWidth=%d", originalWidth); + LOGD("[DEBUG] callback=%p", (void *)callback); + LOGD("[DEBUG] user_data=%p", user_data); + + if (handle == nullptr || prompt == nullptr || pixelValues == nullptr || + callback == nullptr) { + LOGE("[DEBUG] runMultimodalHandleStreaming: INVALID_PARAMETER" + " handle=%p prompt=%s pixelValues=%p callback=%p", + handle, prompt, pixelValues, (void *)callback); + return CAUSAL_LM_ERROR_INVALID_PARAMETER; + } + + auto &h = *handle; + std::lock_guard lock(h.mtx); + if (!h.initialized || h.models.empty()) { + LOGE("[DEBUG] runMultimodalHandleStreaming: NOT_INITIALIZED"); + return CAUSAL_LM_ERROR_NOT_INITIALIZED; + } + + // Multimodal expects the handle to be loaded from a multi-model + // nntr_config.json (architectures[] + model_dirs[]) with at least + // [vision_encoder, llm]. A single-model handle cannot drive this path. + if (h.models.size() < 2) { + LOGE("[DEBUG] runMultimodalHandleStreaming: need >=2 sub-models " + "(got %zu). Load with multi-model nntr_config.json.", + h.models.size()); + return CAUSAL_LM_ERROR_UNSUPPORTED; + } + + LOGD("[DEBUG] runMultimodalHandleStreaming: %zu sub-models loaded", + h.models.size()); + for (size_t i = 0; i < h.architectures.size(); ++i) { + LOGD("[DEBUG] models[%zu]: arch=%s dir=%s", i, h.architectures[i].c_str(), + h.model_dirs[i].c_str()); + } + + // Log pixel values summary (first few values) + // Note: patch size is fixed at 512x512 + const int PATCH_SIZE = 512; + long long totalValues = 1LL * numPatches * 3 * PATCH_SIZE * PATCH_SIZE; + LOGD("[DEBUG] totalPixelValues=%lld", totalValues); + if (totalValues > 0 && pixelValues != nullptr) { + LOGD("[DEBUG] pixelValues[0..4]=%f, %f, %f, %f, %f", pixelValues[0], + pixelValues[1], pixelValues[2], + (totalValues > 3 ? pixelValues[3] : 0.0f), + (totalValues > 4 ? pixelValues[4] : 0.0f)); + } + +#ifdef ENABLE_QNN + // Generic path: models[0]=vision producer, models[1]=LLM consumer. + causallm::multimodal_pointer image_embeds{nullptr, 0}; + try { + image_embeds = run_vision_encoder(h, prompt, pixelValues, numPatches, + originalHeight, originalWidth); + } catch (const std::exception &e) { + LOGE("[DEBUG] runMultimodalHandleStreaming: vision threw: %s", e.what()); + return CAUSAL_LM_ERROR_INFERENCE_FAILED; + } + + const std::string raw_input(prompt); + const bool input_already_formatted = + raw_input.find("<|turn_start|>") != std::string::npos || + raw_input.find("<|im_start|>") != std::string::npos || + raw_input.find("") != std::string::npos; + std::string input = + prepare_input_for_model(h, 1, raw_input, input_already_formatted); + + return execute_multimodal(h, h.models[1].get(), image_embeds, input, + callback, user_data); +#else + LOGE("[DEBUG] runMultimodalHandleStreaming: built without ENABLE_QNN"); + return CAUSAL_LM_ERROR_UNSUPPORTED; +#endif +} + +ErrorCode runMultimodalHandleWithMessages( + CausalLmHandle handle, const CausalLMChatMessage *messages, + size_t num_messages, bool add_generation_prompt, const float *pixelValues, + int numPatches, int originalHeight, int originalWidth, + const char **outputText) { + LOGD("[DEBUG] runMultimodalHandleWithMessages: START"); + LOGD("[DEBUG] handle=%p", handle); + LOGD("[DEBUG] messages=%p", messages); + LOGD("[DEBUG] num_messages=%zu", num_messages); + LOGD("[DEBUG] add_generation_prompt=%d", add_generation_prompt); + LOGD("[DEBUG] pixelValues=%p", pixelValues); + LOGD("[DEBUG] numPatches=%d", numPatches); + LOGD("[DEBUG] originalHeight=%d", originalHeight); + LOGD("[DEBUG] originalWidth=%d", originalWidth); + LOGD("[DEBUG] outputText=%p", outputText); + + if (handle == nullptr || messages == nullptr || num_messages == 0 || + pixelValues == nullptr || outputText == nullptr) { + LOGE("[DEBUG] runMultimodalHandleWithMessages: INVALID_PARAMETER" + " handle=%p messages=%p num_messages=%zu pixelValues=%p outputText=%p", + handle, messages, num_messages, pixelValues, outputText); + return CAUSAL_LM_ERROR_INVALID_PARAMETER; + } + + auto &h = *handle; + std::lock_guard lock(h.mtx); + if (!h.initialized || h.models.empty()) { + LOGE("[DEBUG] runMultimodalHandleWithMessages: NOT_INITIALIZED"); + *outputText = nullptr; + return CAUSAL_LM_ERROR_NOT_INITIALIZED; + } + + if (h.models.size() < 2) { + LOGE( + "[DEBUG] runMultimodalHandleWithMessages: need >=2 sub-models (got %zu)", + h.models.size()); + *outputText = nullptr; + return CAUSAL_LM_ERROR_UNSUPPORTED; + } + + LOGD("[DEBUG] runMultimodalHandleWithMessages: %zu sub-models loaded", + h.models.size()); + for (size_t i = 0; i < h.architectures.size(); ++i) { + LOGD("[DEBUG] models[%zu]: arch=%s dir=%s", i, h.architectures[i].c_str(), + h.model_dirs[i].c_str()); + } + + // Apply chat template + auto chat_messages = convertMessages(messages, num_messages); + const size_t llm_index = h.architectures.size() > 1 ? 1 : 0; + std::string arch = h.architectures.size() > llm_index + ? h.architectures[llm_index] + : std::string(); + std::string model_dir = + h.model_dirs.size() > llm_index ? h.model_dirs[llm_index] : std::string(); + std::string prompt = apply_chat_template_messages( + arch, chat_messages, add_generation_prompt, model_dir); + LOGD("[DEBUG] formatted prompt length: %zu", prompt.length()); + LOGD("[DEBUG] formatted prompt preview: %.100s%s", prompt.c_str(), + prompt.length() > 100 ? "..." : ""); + + // Log pixel values summary (first few values) + // Note: patch size is fixed at 512x512 + const int PATCH_SIZE = 512; + long long totalValues = 1LL * numPatches * 3 * PATCH_SIZE * PATCH_SIZE; + LOGD("[DEBUG] totalPixelValues=%lld", totalValues); + if (totalValues > 0 && pixelValues != nullptr) { + LOGD("[DEBUG] pixelValues[0..4]=%f, %f, %f, %f, %f", pixelValues[0], + pixelValues[1], pixelValues[2], + (totalValues > 3 ? pixelValues[3] : 0.0f), + (totalValues > 4 ? pixelValues[4] : 0.0f)); + } + +#ifdef ENABLE_QNN + causallm::multimodal_pointer image_embeds{nullptr, 0}; + try { + image_embeds = run_vision_encoder(h, prompt.c_str(), pixelValues, + numPatches, originalHeight, originalWidth); + } catch (const std::exception &e) { + LOGE("[DEBUG] runMultimodalHandleWithMessages: vision threw: %s", e.what()); + *outputText = nullptr; + return CAUSAL_LM_ERROR_INFERENCE_FAILED; + } + + h.last_output.clear(); + auto accumulate_cb = [](const char *delta, void *ud) -> int { + if (delta) + static_cast(ud)->append(delta); + return 0; + }; + ErrorCode ec = execute_multimodal(h, h.models[1].get(), image_embeds, prompt, + accumulate_cb, &h.last_output); + if (ec != CAUSAL_LM_ERROR_NONE) { + *outputText = nullptr; + return ec; + } +#else + LOGE("[DEBUG] runMultimodalHandleWithMessages: built without ENABLE_QNN"); + *outputText = nullptr; + return CAUSAL_LM_ERROR_UNSUPPORTED; +#endif + *outputText = h.last_output.c_str(); + return CAUSAL_LM_ERROR_NONE; +} + +/*============================================================================ + * OpenAI messages streaming variants + *============================================================================*/ + +extern "C" { + +ErrorCode runModelHandleWithMessagesStreaming( + CausalLmHandle handle, const CausalLMChatMessage *messages, + size_t num_messages, bool add_generation_prompt, + CausalLmTokenCallback callback, void *user_data) { + LOGD("[DEBUG] runModelHandleWithMessagesStreaming: START"); + LOGD("[DEBUG] handle: %p", (void *)handle); + LOGD("[DEBUG] num_messages: %zu", num_messages); + + if (handle == nullptr || messages == nullptr || num_messages == 0 || + callback == nullptr) { + LOGE("[DEBUG] runModelHandleWithMessagesStreaming: INVALID_PARAMETER"); + return CAUSAL_LM_ERROR_INVALID_PARAMETER; + } + + auto &h = *handle; + std::lock_guard lock(h.mtx); + + if (!h.initialized || h.models.empty()) { + LOGE("[DEBUG] runModelHandleWithMessagesStreaming: NOT_INITIALIZED"); + return CAUSAL_LM_ERROR_NOT_INITIALIZED; + } + const size_t model_index = text_generation_model_index(h); + if (model_index >= h.models.size() || !h.models[model_index]) { + LOGE("[DEBUG] runModelHandleWithMessagesStreaming: text model is missing"); + return CAUSAL_LM_ERROR_NOT_INITIALIZED; + } + + try { + LOGD("[DEBUG] runModelHandleWithMessagesStreaming: Formatting messages..."); + + std::string model_dir = h.model_dirs.size() > model_index + ? h.model_dirs[model_index] + : std::string(); + + // Use the *actual* handle's architecture so architecture-specific + // chat template markers are generated. + auto chat_messages = convertMessages(messages, num_messages); + std::string arch = h.architectures.size() > model_index + ? h.architectures[model_index] + : std::string(); + std::string formattedInput = apply_chat_template_messages( + arch, chat_messages, add_generation_prompt, model_dir); + + LOGD("[DEBUG] raw messages count: %zu", num_messages); + LOGD("[DEBUG] formatted input length: %zu", formattedInput.length()); + LOGD("[DEBUG] formatted input: %s", formattedInput.c_str()); + + LOGD("[DEBUG] runModelHandleWithMessagesStreaming: Calling internal helper " + "directly..."); + return run_model_streaming_on_handle(h, formattedInput, callback, user_data, + /*input_already_formatted=*/true, + model_index); + } catch (const std::exception &e) { + LOGE("[DEBUG] runModelHandleWithMessagesStreaming: Exception caught: %s", + e.what()); + return CAUSAL_LM_ERROR_INFERENCE_FAILED; + } catch (...) { + LOGE( + "[DEBUG] runModelHandleWithMessagesStreaming: Unknown exception caught"); + return CAUSAL_LM_ERROR_INFERENCE_FAILED; + } +} + +ErrorCode runMultimodalHandleWithMessagesStreaming( + CausalLmHandle handle, const CausalLMChatMessage *messages, + size_t num_messages, bool add_generation_prompt, const float *pixelValues, + int numPatches, int originalHeight, int originalWidth, + CausalLmTokenCallback callback, void *user_data) { + LOGD("[DEBUG] runMultimodalHandleWithMessagesStreaming: START"); + LOGD("[DEBUG] handle: %p", (void *)handle); + LOGD("[DEBUG] num_messages: %zu", num_messages); + + if (handle == nullptr || messages == nullptr || num_messages == 0 || + pixelValues == nullptr || callback == nullptr) { + LOGE("[DEBUG] runMultimodalHandleWithMessagesStreaming: INVALID_PARAMETER"); + return CAUSAL_LM_ERROR_INVALID_PARAMETER; + } + + try { + LOGD("[DEBUG] runMultimodalHandleWithMessagesStreaming: Formatting " + "messages..."); + + std::string formattedInput; + { + auto &h = *handle; + std::lock_guard lock(h.mtx); + if (!h.initialized) { + LOGE("[DEBUG] runMultimodalHandleWithMessagesStreaming: handle is not " + "initialized for multimodal"); + return CAUSAL_LM_ERROR_NOT_INITIALIZED; + } + if (h.models.size() < 2) { + LOGE("[DEBUG] runMultimodalHandleWithMessagesStreaming: need >=2 " + "sub-models (got %zu)", + h.models.size()); + return CAUSAL_LM_ERROR_UNSUPPORTED; + } + + auto chat_messages = convertMessages(messages, num_messages); + const size_t llm_index = h.architectures.size() > 1 ? 1 : 0; + std::string arch = h.architectures.size() > llm_index + ? h.architectures[llm_index] + : std::string(); + std::string model_dir = h.model_dirs.size() > llm_index + ? h.model_dirs[llm_index] + : std::string(); + formattedInput = apply_chat_template_messages( + arch, chat_messages, add_generation_prompt, model_dir); + } + + LOGD("[DEBUG] raw messages count: %zu", num_messages); + LOGD("[DEBUG] formatted input length: %zu", formattedInput.length()); + LOGD("[DEBUG] formatted input preview: %.100s%s", formattedInput.c_str(), + formattedInput.length() > 100 ? "..." : ""); + + LOGD("[DEBUG] runMultimodalHandleWithMessagesStreaming: Delegating to " + "runMultimodalHandleStreaming..."); + return runMultimodalHandleStreaming(handle, formattedInput.c_str(), + pixelValues, numPatches, originalHeight, + originalWidth, callback, user_data); + } catch (const std::exception &e) { + LOGE( + "[DEBUG] runMultimodalHandleWithMessagesStreaming: Exception caught: %s", + e.what()); + return CAUSAL_LM_ERROR_INFERENCE_FAILED; + } catch (...) { + LOGE("[DEBUG] runMultimodalHandleWithMessagesStreaming: Unknown exception " + "caught"); + return CAUSAL_LM_ERROR_INFERENCE_FAILED; + } +} + +/*============================================================================ + * OpenAI JSON streaming API implementation + *============================================================================*/ + +ErrorCode runModelHandleWithJsonStreaming(CausalLmHandle handle, + const char *jsonRequest, + CausalLmTokenCallback callback, + void *user_data) { + LOGD("[DEBUG] runModelHandleWithJsonStreaming: START"); + LOGD("[DEBUG] handle: %p", (void *)handle); + LOGD("[DEBUG] jsonRequest length: %zu", + jsonRequest ? strlen(jsonRequest) : 0); + + if (handle == nullptr || jsonRequest == nullptr || callback == nullptr) { + LOGE("[DEBUG] runModelHandleWithJsonStreaming: INVALID_PARAMETER"); + return CAUSAL_LM_ERROR_INVALID_PARAMETER; + } + + auto &h = *handle; + std::lock_guard lock(h.mtx); + + if (!h.initialized || h.models.empty()) { + LOGE("[DEBUG] runModelHandleWithJsonStreaming: NOT_INITIALIZED"); + return CAUSAL_LM_ERROR_NOT_INITIALIZED; + } + const size_t model_index = text_generation_model_index(h); + if (model_index >= h.models.size() || !h.models[model_index]) { + LOGE("[DEBUG] runModelHandleWithJsonStreaming: text model is missing"); + return CAUSAL_LM_ERROR_NOT_INITIALIZED; + } + + try { + LOGD("[DEBUG] runModelHandleWithJsonStreaming: Parsing JSON request..."); + + // Parse JSON request + json request = json::parse(jsonRequest); + LOGD("[DEBUG] JSON parsed successfully"); + + // Apply chat template using the existing g_chat_template + // The chat_template.apply() method handles messages, tools, functions, etc. + std::string formattedInput; + if (g_chat_template.has_value()) { + LOGD( + "[DEBUG] runModelHandleWithJsonStreaming: Applying chat template..."); + formattedInput = g_chat_template->apply(request); + LOGD("[DEBUG] Formatted input length: %zu", formattedInput.length()); + LOGD("[DEBUG] Formatted input preview: %.100s%s", + formattedInput.c_str(), formattedInput.length() > 100 ? "..." : ""); + } else { + LOGE( + "[DEBUG] runModelHandleWithJsonStreaming: Chat template not available"); + return CAUSAL_LM_ERROR_UNSUPPORTED; + } + + LOGD("[DEBUG] runModelHandleWithJsonStreaming: Running inference..."); + return run_model_streaming_on_handle(h, formattedInput, callback, user_data, + /*input_already_formatted=*/true, + model_index); + } catch (const json::exception &e) { + LOGE("[DEBUG] runModelHandleWithJsonStreaming: JSON parse error: %s", + e.what()); + return CAUSAL_LM_ERROR_INVALID_PARAMETER; + } catch (const std::exception &e) { + LOGE("[DEBUG] runModelHandleWithJsonStreaming: Exception caught: %s", + e.what()); + return CAUSAL_LM_ERROR_INFERENCE_FAILED; + } catch (...) { + LOGE("[DEBUG] runModelHandleWithJsonStreaming: Unknown exception caught"); + return CAUSAL_LM_ERROR_INFERENCE_FAILED; + } +} + +// --------------------------------------------------------------------------- +// Multi-image Multimodal API (V-JEPA) +// --------------------------------------------------------------------------- + +ErrorCode runMultimodalMultiImageHandleStreaming( + CausalLmHandle handle, const char *prompt, const float *pixelValues, + int numPatches, int numImages, const int *patchesPerImage, + const int *originalHeights, const int *originalWidths, + CausalLmTokenCallback callback, void *user_data) { + + LOGD("[DEBUG] runMultimodalMultiImageHandleStreaming: START"); + LOGD("[DEBUG] handle: %p", (void *)handle); + LOGD("[DEBUG] numPatches: %d, numImages: %d", numPatches, numImages); + + if (handle == nullptr || prompt == nullptr || pixelValues == nullptr || + callback == nullptr || patchesPerImage == nullptr || + originalHeights == nullptr || originalWidths == nullptr) { + LOGE("[DEBUG] runMultimodalMultiImageHandleStreaming: INVALID_PARAMETER"); + return CAUSAL_LM_ERROR_INVALID_PARAMETER; + } + + if (numImages < 1 || numPatches < 1) { + LOGE("[DEBUG] runMultimodalMultiImageHandleStreaming: " + "numImages=%d, numPatches=%d β€” must be >= 1", + numImages, numPatches); + return CAUSAL_LM_ERROR_INVALID_PARAMETER; + } + + // Validate handle and initialization + { + auto &h = *reinterpret_cast(handle); + std::lock_guard lock(h.mtx); + if (!h.initialized || h.models.empty()) { + LOGE("[DEBUG] runMultimodalMultiImageHandleStreaming: NOT_INITIALIZED"); + return CAUSAL_LM_ERROR_NOT_INITIALIZED; + } + } + + // TODO: Implement multi-image (V-JEPA) inference. For now, delegate + // to the single-image path using the first image's metadata, as a + // temporary bridge until the V-JEPA vision encoder is integrated. + LOGD("[DEBUG] runMultimodalMultiImageHandleStreaming: STUB β€” delegating to " + "single-image runMultimodalHandleStreaming"); + + return runMultimodalHandleStreaming(handle, prompt, pixelValues, numPatches, + originalHeights[0], originalWidths[0], + callback, user_data); +} + +ErrorCode runMultimodalMultiImageHandleWithMessagesStreaming( + CausalLmHandle handle, const CausalLMChatMessage *messages, + size_t num_messages, bool add_generation_prompt, const float *pixelValues, + int numPatches, int numImages, const int *patchesPerImage, + const int *originalHeights, const int *originalWidths, + CausalLmTokenCallback callback, void *user_data) { + + LOGD("[DEBUG] runMultimodalMultiImageHandleWithMessagesStreaming: START"); + LOGD("[DEBUG] handle: %p", (void *)handle); + LOGD("[DEBUG] num_messages: %zu, numPatches: %d, numImages: %d", + num_messages, numPatches, numImages); + + if (handle == nullptr || messages == nullptr || pixelValues == nullptr || + callback == nullptr || patchesPerImage == nullptr || + originalHeights == nullptr || originalWidths == nullptr) { + LOGE("[DEBUG] runMultimodalMultiImageHandleWithMessagesStreaming: " + "INVALID_PARAMETER"); + return CAUSAL_LM_ERROR_INVALID_PARAMETER; + } + + if (numImages < 1 || numPatches < 1) { + LOGE("[DEBUG] runMultimodalMultiImageHandleWithMessagesStreaming: " + "numImages=%d, numPatches=%d β€” must be >= 1", + numImages, numPatches); + return CAUSAL_LM_ERROR_INVALID_PARAMETER; + } + + // TODO: Implement multi-image (V-JEPA) inference. For now, delegate + // to the single-image path using the first image's metadata, as a + // temporary bridge until the V-JEPA vision encoder is integrated. + LOGD("[DEBUG] runMultimodalMultiImageHandleWithMessagesStreaming: STUB β€” " + "delegating to single-image runMultimodalHandleWithMessagesStreaming"); + + return runMultimodalHandleWithMessagesStreaming( + handle, messages, num_messages, add_generation_prompt, pixelValues, + numPatches, originalHeights[0], originalWidths[0], callback, user_data); +} + +} // extern "C" diff --git a/api/quick_dot_ai_api.h b/api/quick_dot_ai_api.h new file mode 100644 index 00000000..0c8a9352 --- /dev/null +++ b/api/quick_dot_ai_api.h @@ -0,0 +1,641 @@ +// SPDX-License-Identifier: Apache-2.0 +/** + * Copyright (C) 2026 Samsung Electronics Co., Ltd. All Rights Reserved. + * + * @file quick_dot_ai_api.h + * @date 20 Mar 2026 + * @brief C API for src (extension of CausalLM) + * + * This header is self-contained: if causal_lm_api.h has already + * been included its types are reused; otherwise fallback + * definitions are provided so that this single header is + * sufficient for application code. + * + * @see https://github.com/nntrainer/nntrainer + * @author Eunju Yang + * @bug No known bugs except for NYI items + */ +#ifndef __QUICK_DOT_AI_API_H__ +#define __QUICK_DOT_AI_API_H__ + +/* ── Extended model types (src additions) ────────────────────── */ +#ifdef __CAUSAL_LM_API_H__ +/* Model types already defined from causal_lm_api.h */ +#else /* causal_lm_api.h not included β€” provide full definitions */ + +#define __CAUSAL_LM_API_H__ + +#ifdef _WIN32 +#define WIN_EXPORT __declspec(dllexport) +#else +#define WIN_EXPORT +#endif + +#include "callback_streamer.h" +#include "streamer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +#include + +typedef enum { + CAUSAL_LM_ERROR_NONE = 0, + CAUSAL_LM_ERROR_INVALID_PARAMETER = 1, + CAUSAL_LM_ERROR_MODEL_LOAD_FAILED = 2, + CAUSAL_LM_ERROR_INFERENCE_FAILED = 3, + CAUSAL_LM_ERROR_NOT_INITIALIZED = 4, + CAUSAL_LM_ERROR_INFERENCE_NOT_RUN = 5, + CAUSAL_LM_ERROR_UNSUPPORTED = 6, + CAUSAL_LM_ERROR_UNKNOWN = 99 +} ErrorCode; + +typedef enum { + CAUSAL_LM_BACKEND_CPU = 0, + CAUSAL_LM_BACKEND_GPU = 1, + CAUSAL_LM_BACKEND_NPU = 2, +} BackendType; + +/* causallm::transformer.h defines enum class ModelType at global scope. + * Suppress our deprecated compat shim when that header is already included + * to prevent an ambiguous-name error in translation units that include both. */ +#ifndef __TRANSFORMER_H__ +/** + * @deprecated T4: λͺ¨λΈ μ‹λ³„μ˜ 정본은 λ¬Έμžμ—΄ id (loadModelHandleByName). + * 이 enum은 κΈ°μ‘΄ 호좜자 ν˜Έν™˜μš© public-only compat shim. + * λͺ¨λΈμ€ μΉ΄νƒˆλ‘œκ·Έλ‘œ μžλ™ 등둝. + */ +typedef enum { + CAUSAL_LM_MODEL_QWEN3_0_6B = 0, + CAUSAL_LM_MODEL_QWEN3_1_7B_Q40 = 4, /* original ordinal preserved */ + CAUSAL_LM_MODEL_TINY_BERT = 8, /* original */ + CAUSAL_LM_MODEL_FUNCTION_GEMMA = 9, /* original */ + CAUSAL_LM_MODEL_GEMMA4_CPU = 11, /* original */ + CAUSAL_LM_MODEL_GEMMA4_E2B_QNN = 12, /* original */ + CAUSAL_LM_MODEL_VJEPA_QNN = 13, +} ModelType; +#endif /* __TRANSFORMER_H__ */ + +typedef struct { + // Add configuration options here as needed + bool use_chat_template; /// < @brief Whether to apply chat template to input + bool debug_mode; /// < @brief Check model file validity during initialization + bool verbose; /// < @brief Whether to print output during generation + const char + *chat_template_name; /// < @brief Template name to select from array + /// (e.g., "default", "tool_use"). NULL for + /// "default". +} Config; + +WIN_EXPORT ErrorCode setOptions(Config config); + +typedef enum { + CAUSAL_LM_QUANTIZATION_UNKNOWN = 0, + CAUSAL_LM_QUANTIZATION_W4A32 = 1, + CAUSAL_LM_QUANTIZATION_W16A16 = 2, + CAUSAL_LM_QUANTIZATION_W8A16 = 3, + CAUSAL_LM_QUANTIZATION_W32A32 = 4, +} ModelQuantizationType; + +/** + * @brief Chat message structure for chat template formatting + * @note Compatible with HuggingFace apply_chat_template() format + */ +typedef struct { + const char *role; /**< Message role: "system", "user", or "assistant" */ + const char *content; /**< Message content text */ +} CausalLMChatMessage; + +/** + * @brief Load a model + * @param compute Backend compute type + * @param modeltype Model type + * @param quant_type Model quantization type + * @return ErrorCode + */ +#ifndef __TRANSFORMER_H__ +WIN_EXPORT ErrorCode loadModel(BackendType compute, ModelType modeltype, + ModelQuantizationType quant_type, + const char *model_base_path); +#endif /* __TRANSFORMER_H__ */ + +typedef struct { + unsigned int prefill_tokens; + double prefill_duration_ms; + unsigned int generation_tokens; + double generation_duration_ms; + double total_duration_ms; + double initialization_duration_ms; + size_t peak_memory_kb; +} PerformanceMetrics; + +WIN_EXPORT ErrorCode getPerformanceMetrics(PerformanceMetrics *metrics); + +WIN_EXPORT ErrorCode saveQnnKvCache(const char *cache_path); +WIN_EXPORT ErrorCode loadQnnKvCache(const char *cache_path); +WIN_EXPORT ErrorCode resetQnnKvCache(void); + +/** + * @brief Apply chat template to messages without running inference + * @param messages Array of chat messages with role and content + * @param num_messages Number of messages in the array + * @param add_generation_prompt Whether to append generation prompt at end + * @param formattedText Buffer to store formatted text (owned by the library) + * @return ErrorCode + */ +WIN_EXPORT ErrorCode applyChatTemplate(const CausalLMChatMessage *messages, + size_t num_messages, + bool add_generation_prompt, + const char **formattedText); +/*============================================================================ + * Handle-based API (for parallel multi-model execution) + * + * The non-handle API above operates on a single process-wide model instance + * protected by one global mutex, which serializes every call and prevents + * loading more than one model at a time. The handle-based API below lets a + * caller load several models simultaneously and run them in parallel from + * different threads, with per-handle state so that different handles never + * block each other. Each handle owns its own model, its own last-output + * buffer, and its own mutex. + * + * A single handle may internally carry multiple sub-models (e.g. vision + * encoder + LLM) when loaded from a top-level nntr_config.json that + * specifies "architectures" and "model_dirs" arrays. The single-model + * run API (runModelHandleWithMessages / runModelHandleStreaming) drives + *models[0] only; the multimodal API (runMultimodalHandle*) drives the full set. + * + * Typical usage: + * CausalLmHandle h = NULL; + * loadModelHandle(CAUSAL_LM_BACKEND_CPU, CAUSAL_LM_MODEL_QWEN3_0_6B, + * CAUSAL_LM_QUANTIZATION_W4A32, NULL, &h); + * const char *out = NULL; + * CausalLMChatMessage msg; + * msg.role = "user"; + * msg.content = "Hello"; + * runModelHandleWithMessages(h, &msg, 1, true, &out); + * // ... use out (owned by h, valid until the next run or destroy) ... + * destroyModelHandle(h); + *============================================================================*/ + +/** + * @brief Opaque handle to a loaded CausalLM model instance. + */ +typedef struct CausalLmModel *CausalLmHandle; + +/** + * @brief Load a model and return a newly-allocated handle. + * + * Calling this multiple times with different parameters returns independent + * handles, each with its own model state. The caller must eventually call + * destroyModelHandle on the returned handle to release resources. + * + * @param compute Backend compute type + * @param modeltype Model type enum + * @param quant_type Quantization type + * @param native_lib_dir Native library directory path (from Android + * ApplicationInfo.nativeLibraryDir). May be NULL. + * @param out_handle Out-parameter that receives the new handle on success + * @return ErrorCode + */ +#ifndef __TRANSFORMER_H__ +WIN_EXPORT ErrorCode loadModelHandle(BackendType compute, ModelType modeltype, + ModelQuantizationType quant_type, + const char *native_lib_dir, + const char *model_base_path, + CausalLmHandle *out_handle); +#endif /* __TRANSFORMER_H__ */ + +/** + * @brief Load model by string id (T4 catalog path). + * + * Looks up the descriptor from the registry by @p model_id, validates the + * backend, then loads via the same internal path as loadModelHandle. + * Returns CAUSAL_LM_ERROR_INVALID_PARAMETER if the id is unknown, the + * descriptor has no config_name, or the backend is not in backend_mask. + * + * @param compute Backend compute type + * @param model_id Catalog string id e.g. "Qwen3-0.6B" + * @param quant_type Quantization type + * @param native_lib_dir Native library directory path. May be NULL. + * @param model_base_path Base path for model files. May be NULL. + * @param out_handle Out-parameter receiving the new handle on success + * @return ErrorCode + */ +WIN_EXPORT ErrorCode loadModelHandleByName(BackendType compute, + const char *model_id, + ModelQuantizationType quant_type, + const char *native_lib_dir, + const char *model_base_path, + CausalLmHandle *out_handle); + +/** + * @brief Load a vision-encoder model and an LLM model as one multimodal handle. + * + * Lets the user freely pair an embedding (vision) model with an LLM by catalog + * id. The resulting handle has models[0] = embedding producer (vision) and + * models[1] = consumer (LLM); the multimodal run path drives the pair through + * the generic composer. + * + * @param compute Backend compute type + * @param embedding_model_id Catalog id of the vision encoder + * @param llm_model_id Catalog id of the LLM + * @param quant_type Quantization type + * @param native_lib_dir Native library directory path. May be NULL. + * @param model_base_path Base path for model files. May be NULL. + * @param out_handle Out-parameter receiving the new handle on success + * @return ErrorCode. CAUSAL_LM_ERROR_UNSUPPORTED if the pair is incompatible + * (e.g. the chosen LLM exposes no embedding table). + */ +WIN_EXPORT ErrorCode loadMultimodalHandleByName( + BackendType compute, const char *embedding_model_id, + const char *llm_model_id, ModelQuantizationType quant_type, + const char *native_lib_dir, const char *model_base_path, + CausalLmHandle *out_handle); + +/** + * @brief Run inference on a specific handle. + * + * The returned outputText pointer is owned by the handle and remains valid + * until the next runModelHandleWithMessages call on the same handle or until + * the handle is destroyed. Different handles are safe to call concurrently from + * different threads; the same handle is serialized by its own internal + * mutex. + * + * Single-model API: drives models[0] only even when the handle was + * populated with multiple sub-models. Use runMultimodalHandleWithMessages for + * compositions such as vision-encoder + LLM. + * + * @param handle Handle returned by loadModelHandle + * @param messages Array of chat messages with role and content + * @param num_messages Number of messages in the array + * @param add_generation_prompt Whether to append generation prompt at end + * @param outputText Out-parameter that receives a pointer to the output + * @return ErrorCode + */ +WIN_EXPORT ErrorCode runModelHandleWithMessages( + CausalLmHandle handle, const CausalLMChatMessage *messages, + size_t num_messages, bool add_generation_prompt, const char **outputText); + +/** + * @brief Streaming inference with OpenAI message format on a specific handle. + * + * Format the messages array through the chat template, then drive + * generation token-by-token, invoking @p callback for each delta. + * Blocks on the invoking thread until generation finishes or an error + * occurs. Semantics are otherwise identical to runModelHandleStreaming. + * + * @param handle Handle returned by loadModelHandle + * @param messages Array of chat messages with role and content + * @param num_messages Number of messages in the array + * @param add_generation_prompt Whether to append generation prompt at end + * @param callback Token delta callback. Must be non-NULL. + * @param user_data Opaque pointer forwarded to callback + * @return ErrorCode + */ +WIN_EXPORT ErrorCode runModelHandleWithMessagesStreaming( + CausalLmHandle handle, const CausalLMChatMessage *messages, + size_t num_messages, bool add_generation_prompt, + CausalLmTokenCallback callback, void *user_data); + +WIN_EXPORT ErrorCode saveQnnKvCacheHandle(CausalLmHandle handle, + const char *cache_path); +WIN_EXPORT ErrorCode loadQnnKvCacheHandle(CausalLmHandle handle, + const char *cache_path); +WIN_EXPORT ErrorCode resetQnnKvCacheHandle(CausalLmHandle handle); + +/** + * @brief Retrieve performance metrics of the last run for a given handle. + * @param handle Handle returned by loadModelHandle + * @param metrics Pointer to a PerformanceMetrics struct to be filled + * @return ErrorCode + */ +WIN_EXPORT ErrorCode getPerformanceMetricsHandle(CausalLmHandle handle, + PerformanceMetrics *metrics); + +/** + * @brief Release all resources owned by a handle. + * + * Passing a NULL handle is a no-op and returns CAUSAL_LM_ERROR_NONE. + * + * @param handle Handle returned by loadModelHandle + * @return ErrorCode + */ +WIN_EXPORT ErrorCode destroyModelHandle(CausalLmHandle handle); + +/** + * @brief Request cancellation of an in-progress run on a handle. + * + * Sets the stop flag on the model, causing the token generation loop + * to exit at the next token boundary. Thread-safe: can be called from + * any thread (e.g., from a UI cancel button handler). + * + * If no run is in progress, this function is a no-op. + * + * @param handle Handle returned by loadModelHandle + * @return ErrorCode + */ +WIN_EXPORT ErrorCode cancelModelHandle(CausalLmHandle handle); + +/** + * @brief Unload the model from a handle without destroying the handle. + * + * Releases the model weights and internal state but keeps the handle + * struct alive. After a successful unload, the handle's initialized flag + * is cleared and subsequent run / metrics calls will return + * CAUSAL_LM_ERROR_NOT_INITIALIZED. The handle can be destroyed later + * with destroyModelHandle, or (in future) re-loaded. + * + * Passing a NULL handle is a no-op and returns CAUSAL_LM_ERROR_NONE. + * + * @param handle Handle returned by loadModelHandle + * @return ErrorCode + */ +WIN_EXPORT ErrorCode unloadModelHandle(CausalLmHandle handle); + +/** + * @brief Streaming counterpart of runModelHandle. + * + * Synchronously drives inference on @p handle and invokes @p callback + * once per decoded-token boundary with a UTF-8 delta string. The call + * blocks on the invoking thread until generation finishes, hits an EOS + * token, reaches NUM_TO_GENERATE, the callback returns non-zero (which + * requests cancellation at the next token boundary), or an error + * occurs. + * + * The @p delta pointer passed into the callback is owned by the + * streaming runtime and is only valid for the duration of the callback + * invocation. Callers that need to retain the text must copy it. + * + * After a successful return the handle's "last output" buffer holds + * the full concatenated generation (or the partial output on a + * cancelled run), so a subsequent getPerformanceMetricsHandle() call + * returns valid metrics and the same handle can be reused for another + * run β€” identical semantics to runModelHandleWithMessages. + * + * Streaming is currently only supported on models whose underlying + * C++ implementation derives from causallm::CausalLM (all the Qwen + * variants and Llama do; non-CausalLM models return + * CAUSAL_LM_ERROR_UNKNOWN). See AsyncAndStreaming.md Β§3.4 at the repo + * root for the full design. + * + * @param handle Handle returned by loadModelHandle. + * @param inputTextPrompt Input prompt (UTF-8, NUL-terminated). + * @param callback Token delta callback. Must be non-NULL. + * @param user_data Opaque pointer forwarded verbatim to the + * callback on every invocation. May be NULL. + * @return ErrorCode + */ +WIN_EXPORT ErrorCode runModelHandleStreaming(CausalLmHandle handle, + const char *inputTextPrompt, + CausalLmTokenCallback callback, + void *user_data); + +/** + * @brief Encode a single text prompt into a sentence-embedding vector using a + * handle whose models[0] is an embedding model (e.g. Ouro / "ouro"). + * + * On success, *out_embedding points to a freshly allocated array of *out_dim + * floats (the batch-0 embedding). The caller OWNS this buffer and MUST release + * it with freeEmbedding(). On any error, *out_embedding is set to NULL and + * *out_dim to 0. + * + * @param handle Handle from loadModelHandle / loadModelHandleByName + * @param text UTF-8 input text (NUL-terminated) + * @param out_embedding [out] receives a newly allocated float[*out_dim] + * @param out_dim [out] receives the embedding dimension + * @return ErrorCode. CAUSAL_LM_ERROR_UNSUPPORTED if models[0] is not an + * embedding (SentenceTransformer) model. + */ +WIN_EXPORT ErrorCode encodeModelHandle(CausalLmHandle handle, const char *text, + float **out_embedding, int *out_dim); + +/** + * @brief Release a buffer returned by encodeModelHandle(). + * @param embedding Pointer previously returned via out_embedding (may be NULL) + */ +WIN_EXPORT void freeEmbedding(float *embedding); + +/** + * @brief Run inference on a handle with a tool schema for constrained + * generation. + * + * @param handle Handle returned by loadModelHandle + * @param inputTextPrompt Input prompt text + * @param outputText Buffer to store output text (owned by the handle) + * @param tool_name Name of the tool (used to cache the grammar) + * @param tool_schema JSON schema for the tool output format + * @return ErrorCode + */ +WIN_EXPORT ErrorCode runModelHandleWithTool(CausalLmHandle handle, + const char *inputTextPrompt, + const char **outputText, + const char *tool_name, + const char *tool_schema); + +/*============================================================================ + * Multimodal API + * + * These functions extend the handle-based API to support image+text inputs. + * The pixel values are passed as preprocessed FloatArray (CHW format) from + * the Kotlin image processor (LlavaNextImageProcessor). + * + * The handle must have been loaded from a multi-model nntr_config.json + * (architectures[] + model_dirs[]) with at least [vision_encoder, llm]; + * a single-model handle returns CAUSAL_LM_ERROR_UNSUPPORTED. + * + * Vision Encoder integration is planned for future implementation. + * Currently these functions return CAUSAL_LM_ERROR_UNSUPPORTED as stubs + * once the multi-model precondition is satisfied. + *============================================================================*/ + +/** + * @brief Streaming multimodal inference on a specific handle. + * + * @param handle Handle returned by loadModelHandle + * @param prompt Text prompt (UTF-8, NUL-terminated) + * @param pixelValues Preprocessed image patches in CHW format + * @param numPatches Number of image patches + * @param originalHeight Original image height before preprocessing + * @param originalWidth Original image width before preprocessing + * @param callback Token delta callback. Must be non-NULL. + * @param user_data Opaque pointer forwarded to callback + * @return ErrorCode (CAUSAL_LM_ERROR_UNSUPPORTED until Vision Encoder + * implemented) + */ +WIN_EXPORT ErrorCode runMultimodalHandleStreaming( + CausalLmHandle handle, const char *prompt, const float *pixelValues, + int numPatches, int originalHeight, int originalWidth, + CausalLmTokenCallback callback, void *user_data); + +/** + * @brief Blocking multimodal inference with OpenAI message format on a specific + * handle. + * + * @param handle Handle returned by loadModelHandle + * @param messages Array of chat messages with role and content + * (text-only, image via pixelValues) + * @param num_messages Number of messages in the array + * @param add_generation_prompt Whether to append generation prompt at end + * @param pixelValues Preprocessed image patches in CHW format + * @param numPatches Number of image patches + * @param originalHeight Original image height before preprocessing + * @param originalWidth Original image width before preprocessing + * @param outputText Out-parameter that receives a pointer to the output + * @return ErrorCode (CAUSAL_LM_ERROR_UNSUPPORTED until Vision Encoder + * implemented) + */ +WIN_EXPORT ErrorCode runMultimodalHandleWithMessages( + CausalLmHandle handle, const CausalLMChatMessage *messages, + size_t num_messages, bool add_generation_prompt, const float *pixelValues, + int numPatches, int originalHeight, int originalWidth, + const char **outputText); + +/** + * @brief Streaming multimodal inference with OpenAI message format on a + * specific handle. + * + * Format the messages array through the chat template, run the vision + * encoder if needed, then drive LLM generation token-by-token invoking + * @p callback for each delta. Blocks on the invoking thread until + * generation finishes or an error occurs. + * + * @param handle Handle returned by loadModelHandle + * @param messages Array of chat messages with role and content + * (text-only, image via pixelValues) + * @param num_messages Number of messages in the array + * @param add_generation_prompt Whether to append generation prompt at end + * @param pixelValues Preprocessed image patches in CHW format + * @param numPatches Number of image patches + * @param originalHeight Original image height before preprocessing + * @param originalWidth Original image width before preprocessing + * @param callback Token delta callback. Must be non-NULL. + * @param user_data Opaque pointer forwarded to callback + * @return ErrorCode + */ +WIN_EXPORT ErrorCode runMultimodalHandleWithMessagesStreaming( + CausalLmHandle handle, const CausalLMChatMessage *messages, + size_t num_messages, bool add_generation_prompt, const float *pixelValues, + int numPatches, int originalHeight, int originalWidth, + CausalLmTokenCallback callback, void *user_data); + +/*============================================================================ + * Multi-image Multimodal API (V-JEPA) + * + * These functions extend the multimodal API to support multiple images + * (e.g. video frames for V-JEPA). The pixel values for all images are + * concatenated into a single flat array, with per-image metadata + * (patches per image, heights, widths) passed as separate arrays. + * + * The handle must have been loaded with CAUSAL_LM_MODEL_VJEPA_QNN or + * another multi-image-capable model type. + *============================================================================*/ + +/** + * @brief Streaming multi-image multimodal inference on a specific handle. + * + * Designed for models like V-JEPA that accept multiple preprocessed + * image frames (e.g. 16 video frames) as input. + * + * @param handle Handle returned by loadModelHandle + * @param prompt Text prompt (UTF-8, NUL-terminated) + * @param pixelValues Preprocessed image patches in CHW format + * (all images concatenated) + * @param numPatches Total number of image patches across all images + * @param numImages Number of images (e.g. 16 for V-JEPA) + * @param patchesPerImage Array of numImages ints: patches per image + * @param originalHeights Array of numImages ints: original height per image + * @param originalWidths Array of numImages ints: original width per image + * @param callback Token delta callback. Must be non-NULL. + * @param user_data Opaque pointer forwarded to callback + * @return ErrorCode + */ +WIN_EXPORT ErrorCode runMultimodalMultiImageHandleStreaming( + CausalLmHandle handle, const char *prompt, const float *pixelValues, + int numPatches, int numImages, const int *patchesPerImage, + const int *originalHeights, const int *originalWidths, + CausalLmTokenCallback callback, void *user_data); + +/** + * @brief Streaming multi-image multimodal inference with OpenAI message + * format on a specific handle. + * + * @param handle Handle returned by loadModelHandle + * @param messages Array of chat messages with role and content + * @param num_messages Number of messages in the array + * @param add_generation_prompt Whether to append generation prompt at end + * @param pixelValues Preprocessed image patches in CHW format + * (all images concatenated) + * @param numPatches Total number of image patches across all images + * @param numImages Number of images (e.g. 16 for V-JEPA) + * @param patchesPerImage Array of numImages ints: patches per image + * @param originalHeights Array of numImages ints: original height per image + * @param originalWidths Array of numImages ints: original width per image + * @param callback Token delta callback. Must be non-NULL. + * @param user_data Opaque pointer forwarded to callback + * @return ErrorCode + */ +WIN_EXPORT ErrorCode runMultimodalMultiImageHandleWithMessagesStreaming( + CausalLmHandle handle, const CausalLMChatMessage *messages, + size_t num_messages, bool add_generation_prompt, const float *pixelValues, + int numPatches, int numImages, const int *patchesPerImage, + const int *originalHeights, const int *originalWidths, + CausalLmTokenCallback callback, void *user_data); + +/*============================================================================ + * OpenAI JSON streaming API + * + * Accepts a JSON string in OpenAI format and processes it through the + * chat template. Supports messages, tools, functions, and all other + * fields recognized by minja chat template renderer. + * + * Example JSON input: + * { + * "messages": [ + * {"role": "developer", "content": "..."}, + * {"role": "user", "content": "..."} + * ], + * "tools": [ + * {"type": "function", "function": {"name": "call", "description": "..."}} + * ] + * } + *============================================================================*/ + +/** + * @brief Streaming inference with OpenAI JSON format. + * + * Parses the JSON request and applies the chat template, then drives + * generation token-by-token invoking @p callback for each delta. + * + * @param handle Handle returned by loadModelHandle + * @param jsonRequest OpenAI format JSON string (UTF-8, NUL-terminated) + * @param callback Token delta callback. Must be non-NULL. + * @param user_data Opaque pointer forwarded to callback + * @return ErrorCode + */ +WIN_EXPORT ErrorCode runModelHandleWithJsonStreaming( + CausalLmHandle handle, const char *jsonRequest, + CausalLmTokenCallback callback, void *user_data); + +/** + * @brief Return a JSON array of all registered ModelDescriptors. + * + * Returns a NUL-terminated UTF-8 string like: + * [{"id":"...","family":"...","display_name":"...","runtime":0, + * "backend_mask":0,"capabilities":0}, ...] + * + * The registry is empty until tasks that call + * quick_dot_ai::register_model_descriptor() are linked in. + * The returned pointer is valid until the next call to getModelCatalogJson(). + * + * @return const char* JSON array string (never NULL) + */ +WIN_EXPORT const char *getModelCatalogJson(void); + +#ifdef __cplusplus +} +#endif + +#endif /* __CAUSAL_LM_API_H__ */ + +#endif /* __QUICK_DOT_AI_API_H__ */ diff --git a/api/streamer.cpp b/api/streamer.cpp new file mode 100644 index 00000000..c578e3e9 --- /dev/null +++ b/api/streamer.cpp @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: Apache-2.0 +/** + * Copyright (C) 2026 Samsung Electronics Co., Ltd. All Rights Reserved. + * + * @file streamer.cpp + * @brief Null-safe wrappers around the BaseStreamer vtable declared in + * streamer.h. See AsyncAndStreaming.md Β§3.1. + */ + +#include "streamer.h" + +extern "C" { + +int streamer_put(BaseStreamer *self, const char *decoded_utf8) { + if (self == nullptr || self->vtable == nullptr || + self->vtable->put == nullptr || decoded_utf8 == nullptr) { + return 0; + } + return self->vtable->put(self, decoded_utf8); +} + +void streamer_end(BaseStreamer *self) { + if (self == nullptr || self->vtable == nullptr || + self->vtable->end == nullptr) { + return; + } + self->vtable->end(self); +} + +} // extern "C" \ No newline at end of file diff --git a/api/streamer.h b/api/streamer.h new file mode 100644 index 00000000..d661716e --- /dev/null +++ b/api/streamer.h @@ -0,0 +1,94 @@ +// SPDX-License-Identifier: Apache-2.0 +/** + * Copyright (C) 2026 Samsung Electronics Co., Ltd. All Rights Reserved. + * + * @file streamer.h + * @brief Minimal C-callable base streamer used by the handle-based + * `runModelHandleStreaming` entry point in quick_dot_ai_api.h. + * + * This is intentionally a very thin vtable-based polymorphism layer so + * that: + * - the CausalLM inference loop can push decoded tokens through a + * single pointer, + * - concrete streamers (currently only CallbackStreamer) can be + * implemented in plain C without dragging C++ headers into the + * CausalLM internals, + * - the same mechanism is reusable from JNI callers (the JNI bridge + * instantiates a CallbackStreamer on the stack and lets the C API + * drive it). + * + * See AsyncAndStreaming.md Β§3.1 at the repo root for the full design + * rationale. + */ +#ifndef __QUICK_DOT_AI_STREAMER_H__ +#define __QUICK_DOT_AI_STREAMER_H__ + +#ifndef WIN_EXPORT +#ifdef _WIN32 +#define WIN_EXPORT __declspec(dllexport) +#else +#define WIN_EXPORT +#endif +#endif + +#ifdef __cplusplus +extern "C" { +#endif + +typedef struct BaseStreamer BaseStreamer; + +/** + * @brief Vtable for a BaseStreamer. + * + * Both function pointers may be NULL β€” the streamer_put / streamer_end + * helpers below are null-safe so the caller never has to check. + */ +typedef struct { + /** + * @brief Forward one UTF-8 delta string to the streamer. + * + * The pointer is only valid for the duration of the call; the + * streamer implementation must copy if it needs to retain the data. + * + * @return 0 to continue generation, non-zero to request cancellation + * at the next token boundary. + */ + int (*put)(BaseStreamer *self, const char *decoded_utf8); + + /** + * @brief Called exactly once after the last put, regardless of whether + * generation finished normally, was cancelled via the callback + * return value, or ended because an exception propagated out of + * the run loop. + */ + void (*end)(BaseStreamer *self); +} BaseStreamerVTable; + +/** + * @brief Base streamer. Concrete streamers embed this as their first + * field and set @c vtable to a static const instance of + * BaseStreamerVTable. + */ +struct BaseStreamer { + const BaseStreamerVTable *vtable; +}; + +/** + * @brief NULL-safe wrapper around the vtable's put() hook. Returns + * non-zero if the streamer requested cancellation. + */ +WIN_EXPORT int streamer_put(BaseStreamer *self, const char *decoded_utf8); + +/** + * @brief NULL-safe wrapper around the vtable's end() hook. Idempotent + * from the caller's perspective β€” concrete implementations + * should tolerate being called multiple times, but the CausalLM + * inference path calls this at most once. + */ +WIN_EXPORT void streamer_end(BaseStreamer *self); + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif // __QUICK_DOT_AI_STREAMER_H__ \ No newline at end of file diff --git a/api/test_api.cpp b/api/test_api.cpp deleted file mode 100644 index 8477cfad..00000000 --- a/api/test_api.cpp +++ /dev/null @@ -1,306 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -/** - * Copyright (C) 2025 Eunju Yang - * - * @file test_api.cpp - * @date 21 Jan 2026 - * @see https://github.com/nntrainer/nntrainer - * @author Eunju Yang - * @brief Simple application to test CausalLM API - * @bug No known bugs except for NYI items - * - */ - -#include "causal_lm_api.h" -#include -#include -#include -#include -#include -#include -#include - -namespace { -constexpr const char *COLOR_RESET = "\033[0m"; -constexpr const char *COLOR_BOLD = "\033[1m"; -constexpr const char *COLOR_CYAN = "\033[36m"; -constexpr const char *COLOR_GREEN = "\033[32m"; -constexpr const char *COLOR_YELLOW = "\033[33m"; -constexpr const char *COLOR_BLUE = "\033[34m"; -constexpr const char *COLOR_RED = "\033[31m"; -constexpr const char *COLOR_MAGENTA = "\033[35m"; -constexpr const char *COLOR_GRAY = "\033[90m"; - -void printLine(const std::string &s, int length = 80) { - for (int i = 0; i < length; ++i) - std::cout << s; - std::cout << std::endl; -} - -void printSection(const std::string §ion) { - std::cout << "\n" - << COLOR_BOLD << COLOR_BLUE - << "+-------------------------------------------------------------+" - << COLOR_RESET << "\n"; - std::cout << COLOR_BOLD << COLOR_BLUE << "| " << section - << std::string(58 - section.length(), ' ') << "|" << COLOR_RESET - << "\n"; - std::cout << COLOR_BOLD << COLOR_BLUE - << "+-------------------------------------------------------------+" - << COLOR_RESET << "\n\n"; -} - -void printSuccess(const std::string &msg) { - std::cout << COLOR_GREEN << "βœ“ " << COLOR_BOLD << msg << COLOR_RESET - << "\n\n"; -} - -void printError(const std::string &msg) { - std::cerr << COLOR_RED << "βœ— " << COLOR_BOLD << "Error: " << COLOR_RESET - << msg << "\n"; -} - -void printWarning(const std::string &msg) { - std::cout << COLOR_YELLOW << "⚠ " << msg << COLOR_RESET << "\n"; -} - -void printInfo(const std::string &label, const std::string &value) { - std::cout << COLOR_CYAN << " " << label << ":" << COLOR_RESET << " " << value - << "\n"; -} - -void printLogo() { - std::cout << "\n"; - std::cout << COLOR_BOLD << COLOR_MAGENTA; - std::cout << " β–ˆβ–ˆβ–ˆβ•— β–ˆβ–ˆβ•—β–ˆβ–ˆβ–ˆβ•— β–ˆβ–ˆβ•—β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ•—β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ•— \n"; - std::cout << " β–ˆβ–ˆβ–ˆβ–ˆβ•— β–ˆβ–ˆβ•‘β–ˆβ–ˆβ–ˆβ–ˆβ•— β–ˆβ–ˆβ•‘β•šβ•β•β–ˆβ–ˆβ•”β•β•β•β–ˆβ–ˆβ•”β•β•β–ˆβ–ˆβ•—\n"; - std::cout << " β–ˆβ–ˆβ•”β–ˆβ–ˆβ•— β–ˆβ–ˆβ•‘β–ˆβ–ˆβ•”β–ˆβ–ˆβ•— β–ˆβ–ˆβ•‘ β–ˆβ–ˆβ•‘ β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ•”β•\n"; - std::cout << " β–ˆβ–ˆβ•‘β•šβ–ˆβ–ˆβ•—β–ˆβ–ˆβ•‘β–ˆβ–ˆβ•‘β•šβ–ˆβ–ˆβ•—β–ˆβ–ˆβ•‘ β–ˆβ–ˆβ•‘ β–ˆβ–ˆβ•”β•β•β–ˆβ–ˆβ•—\n"; - std::cout << " β–ˆβ–ˆβ•‘ β•šβ–ˆβ–ˆβ–ˆβ–ˆβ•‘β–ˆβ–ˆβ•‘ β•šβ–ˆβ–ˆβ–ˆβ–ˆβ•‘ β–ˆβ–ˆβ•‘ β–ˆβ–ˆβ•‘ β–ˆβ–ˆβ•‘\n"; - std::cout << " β•šβ•β• β•šβ•β•β•β•β•šβ•β• β•šβ•β•β•β• β•šβ•β• β•šβ•β• β•šβ•β•\n"; - std::cout << COLOR_RESET; - std::cout << COLOR_BOLD << COLOR_CYAN - << " ────────────────────────────────\n"; - std::cout << " Causal Language Model API\n" - << " ────────────────────────────────\n"; - std::cout << COLOR_RESET << "\n"; -} - -void printUsage(const char *program_name) { - std::cout << COLOR_YELLOW << "Usage:" << COLOR_RESET << "\n"; - std::cout << " " << COLOR_BOLD << program_name << COLOR_RESET - << " [prompt] [use_chat_template] [quantization] " - "[verbose] \n\n"; - - std::cout << COLOR_CYAN << "Arguments:" << COLOR_RESET << "\n"; - std::cout << " model_name " << COLOR_BOLD << "REQUIRED" << COLOR_RESET - << " - Model name (e.g., QWEN3-0.6B)\n"; - std::cout << " prompt " << COLOR_GREEN << "OPTIONAL" - << COLOR_RESET - << " - Input prompt (default: 'Hello, how are you?')\n"; - std::cout << " use_chat_template " << COLOR_GREEN << "OPTIONAL" - << COLOR_RESET << " - 0/1 or true/false (default: 1)\n"; - std::cout << " quantization " << COLOR_GREEN << "OPTIONAL" - << COLOR_RESET - << " - W4A32/W16A16/W8A16/W32A32/UNKNOWN (default: UNKNOWN)\n"; - std::cout << " verbose " << COLOR_GREEN << "OPTIONAL" - << COLOR_RESET << " - 0/1 or true/false (default: 0)\n\n"; - - std::cout << COLOR_YELLOW << "Examples:" << COLOR_RESET << "\n"; - std::cout << " " << COLOR_BOLD << program_name << COLOR_RESET - << " QWEN3-0.6B \"Tell me a joke\" 1 W4A32\n"; - std::cout << " " << COLOR_BOLD << program_name << COLOR_RESET - << " QWEN3-0.6B \"Write a poem\" 1 W32A32 1\n\n"; -} -} // namespace - -int main(int argc, char *argv[]) { - printLogo(); - - if (argc < 2) { - printSection("ERROR: Missing Required Arguments"); - printUsage(argv[0]); - return 1; - } - - const char *model_name = argv[1]; - const char *prompt = (argc >= 3) ? argv[2] : "Hello, how are you?"; - bool use_chat_template = true; - if (argc >= 4) { - use_chat_template = - (std::string(argv[3]) == "1" || std::string(argv[3]) == "true"); - } - - std::string quant_str = "UNKNOWN"; - ModelQuantizationType quant_type = CAUSAL_LM_QUANTIZATION_UNKNOWN; - if (argc >= 5) { - quant_str = std::string(argv[4]); - if (quant_str == "W4A32") - quant_type = CAUSAL_LM_QUANTIZATION_W4A32; - else if (quant_str == "W16A16") - quant_type = CAUSAL_LM_QUANTIZATION_W16A16; - else if (quant_str == "W8A16") - quant_type = CAUSAL_LM_QUANTIZATION_W8A16; - else if (quant_str == "W32A32") - quant_type = CAUSAL_LM_QUANTIZATION_W32A32; - } - - bool verbose = true; - if (argc >= 6) { - verbose = (std::string(argv[5]) == "1" || std::string(argv[5]) == "true"); - } - - printSection("Configuration"); - printInfo("Model Name", model_name); - printInfo("Use Chat Template", use_chat_template ? "true" : "false"); - printInfo("Quantization", quant_str); - printInfo("Verbose", verbose ? "true" : "false"); - std::cout << "\n"; - - printSection("Initialization"); - std::cout << COLOR_CYAN << "⏳ " << COLOR_RESET << "Configuring options...\n"; - Config config; - config.use_chat_template = use_chat_template; - config.debug_mode = true; - config.verbose = verbose; - ErrorCode err = setOptions(config); - if (err != CAUSAL_LM_ERROR_NONE) { - printError("Failed to set options"); - std::cerr << " Error code: " << static_cast(err) << "\n"; - return 1; - } - printSuccess("Options configured successfully"); - - printSection("Model Loading"); - std::cout << COLOR_CYAN << "⏳ " << COLOR_RESET - << "Loading model: " << COLOR_BOLD << model_name << COLOR_RESET - << "\n"; - - // Map string to ModelType - ModelType model_type = CAUSAL_LM_MODEL_QWEN3_0_6B; - std::string model_name_str(model_name); - if (model_name_str == "QWEN3-0.6B") { - model_type = CAUSAL_LM_MODEL_QWEN3_0_6B; - } else { - std::cout << COLOR_YELLOW << "⚠ Warning: Unknown model name '" - << model_name_str << "'. Defaulting to QWEN3-0.6B." << COLOR_RESET - << "\n"; - } - - err = loadModel(CAUSAL_LM_BACKEND_CPU, model_type, quant_type); - - if (err != CAUSAL_LM_ERROR_NONE) { - printError("Failed to load model"); - std::cerr << " Error code: " << static_cast(err) << "\n"; - return 1; - } - printSuccess("Model loaded successfully"); - - printSection("Inference"); - std::cout << COLOR_CYAN << "πŸ“ " << COLOR_RESET << "Input Prompt:\n"; - std::cout << COLOR_BOLD << COLOR_YELLOW << " " << prompt << COLOR_RESET - << "\n\n"; - - std::cout << COLOR_CYAN << "⚑ " << COLOR_RESET << "Running inference...\n\n"; - - const char *outputText = nullptr; - - if (verbose) { - std::cout << COLOR_CYAN << "πŸ’¬ " << COLOR_RESET << "Streaming Output:\n"; - std::cout << COLOR_BOLD << COLOR_GRAY; - } - - err = runModel(prompt, &outputText); - - if (verbose) { - std::cout << COLOR_RESET << "\n\n"; - } - - if (err != CAUSAL_LM_ERROR_NONE) { - printError("Failed to run model"); - std::cerr << " Error code: " << static_cast(err) << "\n"; - return 1; - } - - if (outputText) { - std::cout << COLOR_CYAN << "πŸ’¬ " << COLOR_RESET << "Output:\n"; - std::cout << COLOR_BOLD << COLOR_GREEN << " "; - std::string out(outputText); - size_t pos = 0; - while (pos < out.length()) { - size_t newlinePos = out.find('\n', pos); - if (newlinePos == std::string::npos) { - newlinePos = out.length(); - } - std::string line = out.substr(pos, newlinePos - pos); - std::cout << line; - if (newlinePos < out.length()) { - std::cout << "\n "; - pos = newlinePos + 1; - } else { - pos = out.length(); - } - } - std::cout << COLOR_RESET << "\n\n"; - } else { - printWarning("No output generated"); - } - - printSection("Performance Metrics"); - PerformanceMetrics metrics; - err = getPerformanceMetrics(&metrics); - if (err != CAUSAL_LM_ERROR_NONE) { - printWarning("Failed to get metrics"); - std::cout << " Error code: " << static_cast(err) << "\n"; - } else { - double prefill_tps = - metrics.prefill_duration_ms > 0 - ? (metrics.prefill_tokens / metrics.prefill_duration_ms * 1000.0) - : 0.0; - double gen_tps = - metrics.generation_duration_ms > 0 - ? (metrics.generation_tokens / metrics.generation_duration_ms * 1000.0) - : 0.0; - - std::cout << COLOR_CYAN << " πŸ“Š " << COLOR_RESET << COLOR_BOLD - << "Prefill Stage" << COLOR_RESET << "\n"; - std::cout << COLOR_CYAN << " Tokens:" << COLOR_RESET << " " - << metrics.prefill_tokens << "\n"; - std::cout << COLOR_CYAN << " Duration:" << COLOR_RESET << " " - << std::fixed << std::setprecision(2) - << metrics.prefill_duration_ms << " ms\n"; - std::cout << COLOR_CYAN << " Throughput:" << COLOR_RESET << " " - << COLOR_BOLD << COLOR_GREEN << std::fixed << std::setprecision(1) - << prefill_tps << COLOR_RESET << " tokens/sec\n\n"; - - std::cout << COLOR_CYAN << " πŸ“Š " << COLOR_RESET << COLOR_BOLD - << "Generation Stage" << COLOR_RESET << "\n"; - std::cout << COLOR_CYAN << " Tokens:" << COLOR_RESET << " " - << metrics.generation_tokens << "\n"; - std::cout << COLOR_CYAN << " Duration:" << COLOR_RESET << " " - << std::fixed << std::setprecision(2) - << metrics.generation_duration_ms << " ms\n"; - std::cout << COLOR_CYAN << " Throughput:" << COLOR_RESET << " " - << COLOR_BOLD << COLOR_GREEN << std::fixed << std::setprecision(1) - << gen_tps << COLOR_RESET << " tokens/sec\n\n"; - - std::cout << COLOR_CYAN << " πŸ“Š " << COLOR_RESET << COLOR_BOLD - << "Total Stats" << COLOR_RESET << "\n"; - std::cout << COLOR_CYAN << " Init time:" << COLOR_RESET << " " - << std::fixed << std::setprecision(2) - << metrics.initialization_duration_ms << " ms\n"; - std::cout << COLOR_CYAN << " Duration :" << COLOR_RESET << " " - << std::fixed << std::setprecision(2) << metrics.total_duration_ms - << " ms\n"; - std::cout << COLOR_CYAN << " Peak Mem:" << COLOR_RESET << " " - << metrics.peak_memory_kb / 1024 << " MB\n\n"; - } - - printLine("═", 63); - std::cout << COLOR_BOLD << COLOR_GREEN << " βœ“ Test completed successfully!" - << COLOR_RESET << "\n"; - printLine("═", 63); - std::cout << "\n"; - - return 0; -} diff --git a/apk-build-install.sh b/apk-build-install.sh new file mode 100755 index 00000000..57ecefae --- /dev/null +++ b/apk-build-install.sh @@ -0,0 +1,68 @@ +#!/bin/bash +echo "==========================================" +echo " Android Build & Install Script" +echo "==========================================" + +# Exit immediately if any command fails +set -e + +# ========================================================== +# Configuration +# ========================================================== +APK_APPLICATION="SampleTestAPP" + +# ========================================================== +# 1. Configure Environment +# ========================================================== +echo "[1/6] Configuring environment variables..." +export LD_LIBRARY_PATH="${LD_LIBRARY_PATH}:${NDK_ROOT}" +export PATH="${PATH}:${NDK_ROOT}" + +if [ -z "$NDK_ROOT" ]; then + echo "Error: NDK_ROOT environment variable is not set" + echo "Please set NDK_ROOT to your Android NDK installation path" + echo "Example: export NDK_ROOT=/path/to/android-ndk" + exit 1 +fi +export ANDROID_NDK="${NDK_ROOT}" +echo " ANDROID_NDK set to: ${ANDROID_NDK}" + +# ========================================================== +# 2. Build NNTrainer for Android with QNN support +# ========================================================== +echo "[2/6] Building project for Android (with QNN, clean build)..." +./build.sh --platform=android --enable-qnn --clean + +# ========================================================== +# 3. Install Android Libraries for APK +# ========================================================== +echo "[3/6] Installing Android libraries for APK..." +./apk_install_android.sh + +# ========================================================== +# 4. Deploy Prebuilt Libraries +# ========================================================== +echo "[4/6] Copying prebuilt libraries to QuickDotAI project..." +PREBUILT_DIR="./Android/QuickDotAI/prebuilt_libs" + +# Ensure destination directory exists +mkdir -p "${PREBUILT_DIR}" + +# Copy all shared libraries to the project's prebuilt directory +cp ./install_libs/*.so "${PREBUILT_DIR}/" +echo " Libraries copied to: ${PREBUILT_DIR}" + +# ========================================================== +# 5. Build and Install APK +# ========================================================== +echo "[5/6] Building and installing APK..." +cd ./Android/ +./gradlew ":${APK_APPLICATION}:installDebug" + +# ========================================================== +# 6. Completion +# ========================================================== +echo "[6/6] Build and installation complete!" +echo "==========================================" +echo " Success!" +echo "==========================================" diff --git a/apk_install_android.sh b/apk_install_android.sh new file mode 100755 index 00000000..44535253 --- /dev/null +++ b/apk_install_android.sh @@ -0,0 +1,63 @@ +#!/bin/bash +# APK λΉŒλ“œμš© 라이브러리 μ„€μΉ˜ 슀크립트 +# install_libs/ 디렉토리에 라이브러리만 볡사 (adb push μ—†μŒ) +set -e + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +BUILD_DIR="$SCRIPT_DIR/builddir_android" +NNTRAINER_ROOT="$SCRIPT_DIR/nntrainer" +NNTRAINER_ANDROID="$NNTRAINER_ROOT/builddir/android_build_result/lib/arm64-v8a" +INSTALL_LIBS_DIR="$SCRIPT_DIR/install_libs" + +# ── Validate ──────────────────────────────────────────────────────────── +if [ ! -d "$BUILD_DIR" ]; then + echo "Error: Build directory not found: $BUILD_DIR" + echo "Run './build.sh --platform=android' first." + exit 1 +fi + +echo "=== Installing libraries for APK build ===" +echo "Install dir: $INSTALL_LIBS_DIR" +echo "" + +mkdir -p "$INSTALL_LIBS_DIR" + +# ── Copy nntrainer runtime libraries ──────────────────────────────────── +echo "Copying nntrainer libraries..." +[ -f "$NNTRAINER_ANDROID/libnntrainer.so" ] && cp "$NNTRAINER_ANDROID/libnntrainer.so" "$INSTALL_LIBS_DIR/" +[ -f "$NNTRAINER_ANDROID/libccapi-nntrainer.so" ] && cp "$NNTRAINER_ANDROID/libccapi-nntrainer.so" "$INSTALL_LIBS_DIR/" + +# ── Copy built artifacts ──────────────────────────────────────────────── +echo "Copying built artifacts..." + +# src targets +for f in libcausallm.so libquick_dot_ai.so; do + [ -f "$BUILD_DIR/src/$f" ] && cp "$BUILD_DIR/src/$f" "$INSTALL_LIBS_DIR/" +done + +[ -f "$BUILD_DIR/src/quick_dot_ai" ] && cp "$BUILD_DIR/src/quick_dot_ai" "$INSTALL_LIBS_DIR/" + +# api target +[ -f "$BUILD_DIR/api/libquick_dot_ai_api.so" ] && cp "$BUILD_DIR/api/libquick_dot_ai_api.so" "$INSTALL_LIBS_DIR/" + +# api-test target +[ -f "$BUILD_DIR/api-app/quick_dot_ai_test" ] && cp "$BUILD_DIR/api-app/quick_dot_ai_test" "$INSTALL_LIBS_DIR/" + +# qnn target +[ -f "$BUILD_DIR/qnn/libqnn_context.so" ] && cp "$BUILD_DIR/qnn/libqnn_context.so" "$INSTALL_LIBS_DIR/" + +# ── Copy libc++_shared.so from NDK ────────────────────────────────────── +if [ -n "$ANDROID_NDK" ]; then + LIBCXX=$(find "$ANDROID_NDK" -name "libc++_shared.so" -path "*/aarch64*" 2>/dev/null | head -1) + if [ -n "$LIBCXX" ]; then + echo "Copying libc++_shared.so..." + cp "$LIBCXX" "$INSTALL_LIBS_DIR/" + fi +fi + +echo "" +echo "=== Installation completed ===" +echo "Libraries copied to: $INSTALL_LIBS_DIR" +echo "" +echo "Copied files:" +ls -la "$INSTALL_LIBS_DIR/" \ No newline at end of file diff --git a/benchmarks/README.md b/benchmarks/README.md deleted file mode 100644 index 67da3a32..00000000 --- a/benchmarks/README.md +++ /dev/null @@ -1,168 +0,0 @@ -# nntrainer Benchmark Tool - -A benchmark tool for nntrainer CausalLM models on Android devices. - -## Installation - -### Requirements - -- Python 3.6+ -- ADB (Android Debug Bridge) installed and device connected -- nntrainer C++ binary built and deployed to device at `/data/local/tmp/quick_dot_ai/` -- Model files and `nntr_config.json` available on device - -### Python Dependencies - -```bash -pip install tabulate -``` - -## Usage - -### Basic Benchmark - -Run a single trial benchmark: - -```bash -python3 benchmark_android.py \ - -m /data/local/tmp/quick_dot_ai/qwen3-0.6b \ - -p 512 \ - -n 128 \ - -t 4 -``` - -### Options - -- `-m, --model`: Model directory path on device (required) -- `-p, --n-prompt`: Number of prompt tokens (default: 512) -- `-n, --n-gen`: Number of generation tokens (default: 0) -- `-t, --n-threads`: Number of OMP threads (default: 4) -- `-b, --batch-size`: Batch size (default: 1) - -## Output - -The tool outputs TPS (tokens per second) metrics for both prefill and generation phases. - -```bash -python3 benchmark_android.py \ - -m \ # Model directory path on device (required) - -p \ # Number of prompt tokens (default: 512) - -n \ # Number of generation tokens (default: 0) - -r \ # Number of trials (default: 5) - -t \ # Number of OMP threads (default: 4) - -b \ # Batch size (default: 1) - --device-info \ # Device info (auto-detect if not specified) -``` - -Can list arguments accept comma-separated values: -- `-p, --n-prompt`: Comma-separated prompt token counts -- `-n, --n-gen`: Comma-separated generation token counts -- `-t, --n-threads`: Comma-separated thread counts - -The script runs benchmarks for all combinations of the specified parameters (Cartesian product). - - -### Examples - -#### Prefill benchmark -```bash -# Test 512 tokens, no generation -python3 benchmark_android.py \ - -m /data/local/tmp/quick_dot_ai/qwen3-0.6b \ - -p 512 -n 0 -r 5 -t 4 - -# Test 128,256,512,1024 tokens, no generation -python3 benchmark_android.py \ - -m /data/local/tmp/quick_dot_ai/qwen3-0.6b \ - -p 128,256,512,1024 -n 0 -r 5 -t 4 -``` - -#### Generation benchmark -```bash -# Test 128 token generation -python3 benchmark_android.py \ - -m /data/local/tmp/quick_dot_ai/qwen3-0.6b \ - -p 512 -n 128 -r 5 -t 4 - -# Test 128,256,512,1024 token generation -python3 benchmark_android.py \ - -m /data/local/tmp/quick_dot_ai/qwen3-0.6b \ - -p 512 -n 128,256,512,1024 -r 5 -t 4 -``` - -#### Test different thread counts -```bash -# Test with 4 threads -python3 benchmark_android.py \ - -m /data/local/tmp/quick_dot_ai/qwen3-0.6b \ - -t 4 - -# Test with 2,4,8,16 threads -python3 benchmark_android.py \ - -m /data/local/tmp/quick_dot_ai/qwen3-0.6b \ - -t 2,4,8,16 -``` - -### Output Format - -The sweep script outputs a pretty table: - -**Pretty Table:** -``` -BENCHMARK SWEEP RESULTS (Not Real Result) -Model: qwen3-0.6b | Size: 2.30 GiB | Type: CausalLM | Dtype: FP32-FP32 | Device: S25U -+-----------+---------+------+----------------+-----------------+ -| Threads | Prompt | Gen | Prefill TPS | Gen TPS | -+===========+=========+======+================+=================+ -| 1 | 512 | 128 | 200.50 Β± 5.25 | 30.10 Β± 2.10 | -| 2 | 512 | 128 | 350.25 Β± 8.40 | 55.30 Β± 3.20 | -| 4 | 512 | 128 | 620.80 Β± 12.50 | 95.40 Β± 4.80 | -| 8 | 512 | 128 | 750.30 Β± 15.20 | 120.50 Β± 5.80 | -+-----------+---------+------+----------------+-----------------+ -``` - -## How It Works - -1. **Load Configuration**: Pulls `nntr_config.json` from the device via ADB -2. **Backup & Modify**: Creates a backup of the original config on device, modifies it with test parameters, and pushes back to device -3. **Run Trials**: Executes the C++ benchmark binary on the device multiple times via ADB -4. **Collect Metrics**: Parses output to extract TPS values and temperatures -5. **Calculate Statistics**: Computes mean and standard deviation across trials -6. **Restore Configuration**: Restores the original `nntr_config.json` on the device -7. **Output Results**: Prints results in specified format - - -## Troubleshooting - -### ADB device not found -```bash -adb devices -``` -Make sure your device is connected and ADB debugging is enabled. - -### Model file not found on device -The script requires all model files to be on the device. Ensure your model is deployed to `/data/local/tmp/quick_dot_ai/`. - -### nntr_config.json not found on device -The script reads `nntr_config.json` directly from the device. Make sure it exists at the specified model path. - -Example device structure: -``` -/data/local/tmp/quick_dot_ai/ -β”œβ”€β”€ models/qwen3-0.6b/ -β”‚ β”œβ”€β”€ nntr_config.json -β”‚ β”œβ”€β”€ config.json -β”‚ β”œβ”€β”€ generation_config.json -β”‚ └── nntr_qwen3_0.6b_fp32.bin -β”œβ”€β”€ libc++_shared.so -β”œβ”€β”€ libccapi-nntrainer.so -β”œβ”€β”€ libnntrainer.so -β”œβ”€β”€ nntrainer_causallm -└── run_causallm.sh -``` - - - -## License - -Same as nntrainer project license. diff --git a/benchmarks/benchmark_android.py b/benchmarks/benchmark_android.py deleted file mode 100644 index de4216d8..00000000 --- a/benchmarks/benchmark_android.py +++ /dev/null @@ -1,479 +0,0 @@ -#!/usr/bin/env python3 -""" -nntrainer benchmark for CausalLM models with configuration sweeping. - -Usage: - python3 benchmark_android.py -m [options] - -This script can sweep through multiple configurations: - - Different thread counts: -t 1,2,4,8 - - Different generation lengths: -n 128,512,1024 - - Different prompt lengths: -p 256,512,1024 - -Example: - python3 benchmark_android.py -m /data/local/tmp/quick_dot_ai/models/qwen3-0.6b -t 1,2,4,8 -n 128,256 -""" - -import subprocess -import re -import time -import statistics -import sys -import argparse -import json -import tempfile -import os -import shutil -from itertools import product -from tabulate import tabulate -from transformers import AutoTokenizer - -from device_utils import ( - get_thermal_temp, - wait_for_cooling, - get_device_model, - get_model_size, -) - -def generate_sample_input(target_tokens, local_tokenizer_path=None): - """ - Generate sample input that matches target token count. - If transformers is available, use exact tokenizer. Otherwise, use heuristic. - """ - if local_tokenizer_path: - # Load tokenizer from local path - tokenizer = AutoTokenizer.from_pretrained(os.path.dirname(local_tokenizer_path)) - - # Generic base text (repeating pattern) - base_token = 5555 - base_text = tokenizer.decode([base_token]) - - generated_text = base_text * target_tokens - - return generated_text - else: - # Heuristic fallback: assume ~4 chars per token on average - chars_per_token = 4 - target_chars = target_tokens * chars_per_token - - # Use a repeating pattern - base_text = "The quick brown fox jumps over the lazy dog. " - repeats = max(1, target_chars // len(base_text) + 1) - generated_text = base_text * repeats - - # Trim to approximate length - return generated_text[:target_chars] - - -def backup_and_modify_config(model_path, n_prompt, n_gen, batch_size=1): - """ - Backup original nntr_config.json from device and create modified version. - Returns context manager that restores original config on exit. - """ - class ConfigModifier: - def __init__(self, model_path, n_prompt, n_gen, batch_size): - self.n_prompt = n_prompt - self.n_gen = n_gen - self.batch_size = batch_size - self.device_backup = None - self.temp_config_path = None - self.device_config_path = f"{model_path}/nntr_config.json" - - def __enter__(self): - # Backup device config - result = subprocess.run( - ["adb", "shell", "cat", self.device_config_path], - capture_output=True, text=True - ) - - if result.returncode != 0: - raise RuntimeError(f"Could not read config from device: {result.stderr}") - - self.device_backup = result.stdout - - # Create backup on device - subprocess.run( - ["adb", "shell", "cp", self.device_config_path, self.device_config_path + ".benchmark_backup"], - capture_output=True - ) - - # Load and modify config - config = json.loads(self.device_backup) - config["init_seq_len"] = self.n_prompt - config["num_to_generate"] = self.n_gen - config["batch_size"] = self.batch_size - - # Generate sample_input matching target token count - local_tokenizer_path = None - - if "tokenizer_file" in config: - device_tokenizer_path = config["tokenizer_file"] - - # Create local temp directory for tokenizer - temp_dir = tempfile.mkdtemp(prefix="tokenizer_") - - try: - # Extract tokenizer directory name from device path - tokenizer_dir = os.path.dirname(device_tokenizer_path) - tokenizer_filename = os.path.basename(device_tokenizer_path) - - # Pull tokenizer directory from device - print(f" Pulling tokenizer from device...") - result = subprocess.run( - ["adb", "pull", tokenizer_dir + '/' + tokenizer_filename, temp_dir], - capture_output=True, text=True - ) - result = subprocess.run( - ["adb", "pull", tokenizer_dir + '/' + 'config.json', temp_dir], - capture_output=True, text=True - ) - - if result.returncode == 0: - local_tokenizer_path = os.path.join(temp_dir, tokenizer_filename) - else: - print(f" Warning: Could not pull tokenizer, using heuristic") - shutil.rmtree(temp_dir) - temp_dir = None - except Exception as e: - print(f" Warning: Could not pull tokenizer: {e}") - if temp_dir: - shutil.rmtree(temp_dir) - temp_dir = None - - generated_input = generate_sample_input(self.n_prompt, local_tokenizer_path) - config["sample_input"] = generated_input - - if local_tokenizer_path: - print(f"Generated sample_input ({self.n_prompt} token length, using tokenizer)") - else: - print(f"Generated sample_input ({self.n_prompt} token length, heuristic)") - - # Clean up temporary tokenizer directory - if local_tokenizer_path and os.path.exists(os.path.dirname(local_tokenizer_path)): - shutil.rmtree(os.path.dirname(local_tokenizer_path)) - - # Create temporary file with modified config - with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f: - json.dump(config, f, indent=2) - self.temp_config_path = f.name - - # Push modified config to device - result = subprocess.run( - ["adb", "push", self.temp_config_path, self.device_config_path], - capture_output=True, text=True - ) - - if result.returncode != 0: - raise RuntimeError(f"Could not push config to device: {result.stderr}") - - return config - - def __exit__(self, exc_type, exc_val, exc_tb): - # Restore device config from backup - if self.device_backup: - with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f: - f.write(self.device_backup) - temp_backup_path = f.name - - try: - subprocess.run( - ["adb", "push", temp_backup_path, self.device_config_path], - capture_output=True - ) - os.remove(temp_backup_path) - except Exception as e: - print(f"Warning: Could not restore config: {e}") - - # Clean up temporary files - if self.temp_config_path and os.path.exists(self.temp_config_path): - os.remove(self.temp_config_path) - - # Remove backup from device - subprocess.run( - ["adb", "shell", "rm", "-f", self.device_config_path + ".benchmark_backup"], - capture_output=True - ) - - return False - - return ConfigModifier(model_path, n_prompt, n_gen, batch_size) - - -def run_single_trial(model_path, omp_threads=None): - """Run a single benchmark trial and collect metrics.""" - # Build command to run nntrainer C++ binary - # Set OMP_NUM_THREADS as environment variable for the shell command - if omp_threads: - cmd = [ - "adb", "shell", - f"cd /data/local/tmp/quick_dot_ai && OMP_NUM_THREADS={omp_threads} ./run_causallm.sh '{model_path}'" - ] - else: - cmd = [ - "adb", "shell", - f"cd /data/local/tmp/quick_dot_ai && ./run_causallm.sh '{model_path}'" - ] - - # Capture output - result = subprocess.run(cmd, capture_output=True, text=True) - - output = result.stdout + result.stderr - print(output) - - # Parse TPS from output - prefill_match = re.search(r"prefill:.*,\s+([\d\.]+)\s+TPS", output) - - # Parse generation TPS if available - gen_match = re.search(r"generation:.*,\s+([\d\.]+)\s+TPS", output) - - prefill_tps = float(prefill_match.group(2) if prefill_match and len(prefill_match.groups()) > 1 else prefill_match.group(1)) if prefill_match else 0.0 - gen_tps = float(gen_match.group(1)) if gen_match else 0.0 - - return { - "prefill_tps": prefill_tps, - "gen_tps": gen_tps, - "error": result.stderr if result.returncode != 0 else "" - } - - -def calculate_statistics(values): - """Calculate mean and standard deviation.""" - if not values: - return 0.0, 0.0 - - mean = statistics.mean(values) - std = statistics.stdev(values) if len(values) > 1 else 0.0 - - return mean, std - - -def validate_model_path(model_path): - """ - Validate model path to prevent command injection and path traversal. - """ - # Normalize path to remove any '..' sequences - try: - normalized = os.path.normpath(model_path) - except Exception as e: - raise ValueError(f"Invalid path format: {e}") - - # Define allowed prefix (must be within nntrainer causallm directory) - allowed_prefix = "/data/local/tmp/nntrainer/" - - # Ensure path starts with allowed prefix - if not normalized.startswith(allowed_prefix): - raise ValueError( - f"Model path must start with '{allowed_prefix}'. " - f"Got: {model_path}" - ) - - # Validate characters: allow only safe filesystem characters - # Allow: alphanumeric, hyphen, underscore, dot, forward slash, plus sign - safe_chars_pattern = r'^[a-zA-Z0-9_\-./+]+$' - if not re.match(safe_chars_pattern, normalized): - raise ValueError( - f"Model path contains invalid characters. " - f"Only alphanumeric, '-', '_', '.', '/', and '+' are allowed. " - f"Got: {model_path}" - ) - - # Prevent empty path segments (like double slashes) - if '//' in normalized: - raise ValueError( - f"Model path contains empty segments (double slashes). " - f"Got: {model_path}" - ) - - return normalized - - -def output_results_table(all_results, model_name, model_size, model_type, dtype, device): - """Output all benchmark results in a pretty table format.""" - # Prepare table data - headers = ["Threads", "Prompt", "Gen", "Prefill TPS", "Gen TPS"] - table_data = [] - - for result in all_results: - prefill_str = f"{result['prefill_mean']:.2f} Β± {result['prefill_std']:.2f}" if result['prefill_mean'] > 0 else "N/A" - gen_str = f"{result['gen_mean']:.2f} Β± {result['gen_std']:.2f}" if result['gen_mean'] > 0 else "N/A" - - table_data.append([ - result['n_threads'], - result['n_prompt'], - result['n_gen'], - prefill_str, - gen_str - ]) - - print("\n" + "=" * 90) - print("BENCHMARK SWEEP RESULTS") - print("=" * 90) - print(f"Model: {model_name} | Size: {model_size} | Type: {model_type} | Dtype: {dtype} | Device: {device}") - print("=" * 90) - print(tabulate(table_data, headers=headers, tablefmt="grid")) - print("=" * 90) - - -def parse_list_arg(arg_string): - """Parse comma-separated list argument.""" - if not arg_string: - return [] - return [int(x.strip()) for x in arg_string.split(',')] - - -def main(): - parser = argparse.ArgumentParser( - description="nntrainer benchmark with configuration sweeping for nntrainer CausalLM models" - ) - parser.add_argument("-m", "--model", type=str, required=True, - help="Model directory path (on device, e.g., /data/local/tmp/quick_dot_ai/models/qwen3-0.6b-q40)") - parser.add_argument("-p", "--n-prompt", type=str, default="512", - help="Number of prompt tokens, comma-separated (default: 512)") - parser.add_argument("-n", "--n-gen", type=str, default="0", - help="Number of generation tokens, comma-separated (default: 0)") - parser.add_argument("-r", "--n-trials", type=int, default=5, - help="Number of trials per configuration (default: 5)") - parser.add_argument("-t", "--n-threads", type=str, default="4", - help="Number of OMP threads, comma-separated (default: 4)") - parser.add_argument("-b", "--batch-size", type=int, default=1, - help="Batch size (default: 1)") - parser.add_argument("--device-info", type=str, default=None, - help="Device info (auto-detect if not specified)") - parser.add_argument("--cool-to", type=float, default=35.0, - help="Cool device to this temperature before each config (default: 35.0)") - parser.add_argument("--max-cool-wait", type=int, default=300, - help="Maximum wait time for cooling in seconds (default: 300)") - parser.add_argument("--skip-cooling", action="store_true", - help="Skip cooling between configurations") - - args = parser.parse_args() - - # Parse list arguments - n_prompts = parse_list_arg(args.n_prompt) - n_gens = parse_list_arg(args.n_gen) - n_threads_list = parse_list_arg(args.n_threads) - - for n_threads in n_threads_list: - assert n_threads > 0, "Error: Thread counts must be positive integers" - - # Generate all configurations - configs = list(product(n_prompts, n_gens, n_threads_list)) - - # Extract model name from path - model_path = validate_model_path(args.model) - model_name = os.path.basename(model_path) - - print(f"=== nntrainer benchmark sweep ===") - print(f"Model: {model_name}") - print(f"Device path: {model_path}") - print(f"n_prompt values: {n_prompts}") - print(f"n_gen values: {n_gens}") - print(f"n_threads values: {n_threads_list}") - print(f"n_trials per config: {args.n_trials}") - print(f"batch_size: {args.batch_size}") - print(f"Total configurations: {len(configs)}") - print("-" * 50) - - # Load nntr_config.json from device - try: - device_config_path = f"{model_path}/nntr_config.json" - result = subprocess.run( - ["adb", "shell", "cat", device_config_path], - capture_output=True, text=True - ) - - if result.returncode != 0: - raise RuntimeError(f"Could not read nntr_config.json from device: {result.stderr}") - - nntr_cfg = json.loads(result.stdout) - print("Successfully loaded nntr_config.json from device") - except Exception as e: - print(f"Error loading nntr_config.json: {e}") - return - - # Extract model metadata - model_type = nntr_cfg.get("model_type", "Unknown") - dtype = nntr_cfg.get("model_tensor_type", "Unknown") - - # Get model size - model_size = get_model_size(model_path, nntr_cfg) - print(f"Model size: {model_size}") - print(f"Model type: {model_type}") - print(f"Dtype: {dtype}") - - # Get device info - device = args.device_info if args.device_info else get_device_model() - print(f"Device: {device}") - print("-" * 50) - - # Run benchmarks for all configurations - all_results = [] - - for idx, (n_prompt, n_gen, n_threads) in enumerate(configs, 1): - print(f"\n[{idx}/{len(configs)}] Config: n_prompt={n_prompt}, n_gen={n_gen}, n_threads={n_threads}") - print("-" * 50) - - # Wait for cooling before starting next configuration (for fair comparison) - if idx > 1 and not args.skip_cooling: - print("\nWaiting for device cooling...") - wait_for_cooling(args.cool_to, args.max_cool_wait) - time.sleep(2) # Brief pause after cooling - - # Create config modifier for this specific configuration - config_modifier = backup_and_modify_config(model_path, n_prompt, n_gen, args.batch_size) - - try: - # Manually enter context to ensure proper cleanup on interrupt - config_modifier.__enter__() - - results = [] - for i in range(args.n_trials): - res = run_single_trial(model_path, n_threads) - results.append(res) - time.sleep(1) # Brief pause between trials - - # Calculate statistics - prefills = [r["prefill_tps"] for r in results if r["prefill_tps"] > 0] - gens = [r["gen_tps"] for r in results if r["gen_tps"] > 0] - - prefill_mean, prefill_std = calculate_statistics(prefills) - gen_mean, gen_std = calculate_statistics(gens) - - all_results.append({ - 'n_prompt': n_prompt, - 'n_gen': n_gen, - 'n_threads': n_threads, - 'prefill_mean': prefill_mean, - 'prefill_std': prefill_std, - 'gen_mean': gen_mean, - 'gen_std': gen_std - }) - - print(f" Prefill: {prefill_mean:.2f} Β± {prefill_std:.2f} TPS") - print(f" Generation: {gen_mean:.2f} Β± {gen_std:.2f} TPS") - - # Normal completion - restore config - config_modifier.__exit__(None, None, None) - - except KeyboardInterrupt: - print("\nInterrupted by user. Restoring config...") - # Config will be restored in __exit__ even on interrupt - config_modifier.__exit__(KeyboardInterrupt, KeyboardInterrupt(), None) - break - except Exception as e: - print(f"Error in configuration: {e}") - print("Restoring config...") - # Config will be restored in __exit__ even on error - config_modifier.__exit__(type(e), e, e.__traceback__) - continue - - print("\n" + "=" * 50) - print("ALL RESULTS") - print("=" * 50) - - # Output all results in table format - output_results_table(all_results, model_name, model_size, model_type, dtype, device) - - -if __name__ == "__main__": - main() diff --git a/benchmarks/device_utils.py b/benchmarks/device_utils.py deleted file mode 100644 index 75284f1e..00000000 --- a/benchmarks/device_utils.py +++ /dev/null @@ -1,100 +0,0 @@ -""" -Device utilities for nntrainer benchmark. -Device interaction including temperature monitoring, cooling, and device info. -""" - -import subprocess -import time - - -def get_thermal_temp(): - """Get thermal zone temperature in Celsius.""" - try: - cmd = ["adb", "shell", "cat", "/sys/class/thermal/thermal_zone0/temp"] - result = subprocess.run(cmd, capture_output=True, text=True) - if result.returncode == 0: - return float(result.stdout.strip()) / 1000.0 - except Exception as e: - print(f"Error reading temp: {e}") - return 0.0 - - -def wait_for_cooling(target_temp=40.0, max_wait=300, poll_interval=10): - """ - Wait for device to cool down to target temperature. - - Args: - target_temp: Target temperature in Celsius (default: 40.0) - max_wait: Maximum wait time in seconds (default: 300 = 5 min) - poll_interval: Time between temperature checks (default: 10 seconds) - - Returns: - True if target temperature reached, False if timeout - """ - current_temp = get_thermal_temp() - print(f"Current temp: {current_temp:.1f}Β°C, Target: {target_temp:.1f}Β°C") - - if current_temp <= target_temp: - print(f"Temperature already at target ({current_temp:.1f}Β°C ≀ {target_temp:.1f}Β°C)") - return True - - print(f"Cooling down device... (Max wait: {max_wait}s)") - start_time = time.time() - - while (time.time() - start_time) < max_wait: - time.sleep(poll_interval) - current_temp = get_thermal_temp() - print(f" Current temp: {current_temp:.1f}Β°C") - - if current_temp <= target_temp: - print(f"Reached target temperature ({current_temp:.1f}Β°C)") - return True - - print(f"Warning: Timeout waiting for cooling. Current temp: {current_temp:.1f}Β°C") - return False - - -def get_device_model(): - """Get device model name from system properties.""" - try: - cmd = ["adb", "shell", "getprop", "ro.product.model"] - result = subprocess.run(cmd, capture_output=True, text=True) - if result.returncode == 0: - device = result.stdout.strip() - return device - except Exception as e: - print(f"Error reading device model: {e}") - return "Unknown" - - -def get_model_size(model_path, nntr_cfg): - """Get model file size in human-readable format from device.""" - try: - # Get the binary file name from config - model_file = nntr_cfg.get("model_file_name", "model.bin") - - # Always get from device (that's where inference runs) - device_path = f"{model_path}/{model_file}" - - # Get file size via adb - cmd = ["adb", "shell", "wc", "-c", device_path] - result = subprocess.run(cmd, capture_output=True, text=True) - - if result.returncode == 0: - size_bytes = int(result.stdout.strip().split()[0]) - return format_size(size_bytes) - else: - print(f"Warning: Could not get model size from device: {result.stderr}") - except Exception as e: - print(f"Error getting model size: {e}") - return "Unknown" - - -def format_size(size_bytes): - """Convert bytes to human-readable format.""" - for unit in ['B', 'KiB', 'MiB', 'GiB']: - if size_bytes >= 1024.0: - size_bytes /= 1024.0 - else: - break - return f"{size_bytes:.2f} {unit}" diff --git a/build.sh b/build.sh new file mode 100755 index 00000000..d42a060d --- /dev/null +++ b/build.sh @@ -0,0 +1,277 @@ +#!/bin/bash +# Unified build script for causallm-extension +# +# Usage: +# ./build.sh # x86, all targets +# ./build.sh --platform=android # android, all targets +# ./build.sh --platform=android --target=src # android, src only +# ./build.sh --platform=x86 --target=src,api # x86, src + api +# ./build.sh --platform=android --enable-qnn # android with QNN +# ./build.sh --clean # clean rebuild +# +# Environment: +# ANDROID_NDK - required for --platform=android builds +set -e + +# ── Parse arguments ───────────────────────────────────────────────────── +PLATFORM="x86" +CLEAN=false +TARGETS="all" +ENABLE_QNN=false + +for arg in "$@"; do + case "$arg" in + --platform=*) PLATFORM="${arg#*=}" ;; + --target=*) TARGETS="${arg#*=}" ;; + --clean) CLEAN=true ;; + --enable-qnn) ENABLE_QNN=true ;; + --help|-h) + sed -n '2,/^set -e$/p' "$0" | grep '^#' | sed 's/^# \?//' + exit 0 ;; + esac +done + +# Validate platform +if [[ "$PLATFORM" != "x86" && "$PLATFORM" != "android" ]]; then + echo "Error: --platform must be x86 or android (got: $PLATFORM)" + exit 1 +fi + +# QNN is android-only +if [ "$ENABLE_QNN" = true ] && [ "$PLATFORM" != "android" ]; then + echo "Error: --enable-qnn is only supported with --platform=android" + exit 1 +fi + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +NNTRAINER_ROOT="$SCRIPT_DIR/nntrainer" +XGRAMMAR_ROOT="$SCRIPT_DIR/xgrammar" +CAUSALLM_ROOT="$NNTRAINER_ROOT/Applications/CausalLM" + +# Build directory +if [ "$PLATFORM" = "android" ]; then + BUILD_DIR="$SCRIPT_DIR/builddir_android" +else + BUILD_DIR="$SCRIPT_DIR/builddir_x86" +fi + +echo "=== Quick-Dot-AI Unified Build ===" +echo "PLATFORM: $PLATFORM" +echo "TARGETS: $TARGETS" +echo "ENABLE_QNN: $ENABLE_QNN" +echo "BUILD_DIR: $BUILD_DIR" +echo "" + +# ── Step 0: Check submodules ──────────────────────────────────────────── +if [ ! -f "$NNTRAINER_ROOT/meson.build" ]; then + echo "[0] Initializing nntrainer submodule..." + git -C "$SCRIPT_DIR" submodule update --init --recursive --depth 1 +fi + +# xgrammar submodule: src/meson.build compiles xgrammar/cpp/*.cc and the root +# meson.build adds xgrammar/include + 3rdparty/dlpack/include to the include +# path. A missing xgrammar checkout makes meson configuration fail with: +# "ERROR: File .../xgrammar/cpp/compiled_grammar.cc does not exist." +# NOTE: init from the superproject ($SCRIPT_DIR) targeting the xgrammar path, +# not `git -C "$XGRAMMAR_ROOT"` (that dir is empty until checkout, so it is +# not a git repo). Also do NOT check $NNTRAINER_ROOT/meson.build here β€” that +# is the nntrainer guard and is unrelated to xgrammar. +if [ ! -f "$XGRAMMAR_ROOT/cpp/compiled_grammar.cc" ]; then + echo "[0] Initializing xgrammar submodule..." + git -C "$SCRIPT_DIR" submodule update --init xgrammar +fi + +# xgrammar nested submodule: only dlpack is required by the build +# (xgrammar/3rdparty/dlpack/include, used e.g. by grammar_matcher.cc). +# cpptrace is compiled out (guarded by XGRAMMAR_ENABLE_CPPTRACE != 1) and +# googletest is test-only, so we deliberately avoid --recursive to skip those +# large, unnecessary clones. +if [ ! -d "$XGRAMMAR_ROOT/3rdparty/dlpack/include" ]; then + echo "[0] Initializing xgrammar nested submodule (dlpack)..." + git -C "$XGRAMMAR_ROOT" submodule update --init 3rdparty/dlpack +fi + +# Check iniparser submodule +if [ ! -f "$NNTRAINER_ROOT/subprojects/iniparser/src/iniparser.h" ]; then + echo "[0] Initializing nntrainer nested submodules..." + cd "$NNTRAINER_ROOT" + git submodule update --init --recursive --depth 1 + cd "$SCRIPT_DIR" +fi + +# ── Step 1: Pre-build nntrainer ───────────────────────────────────────── +if [ "$PLATFORM" = "android" ]; then + if [ -z "$ANDROID_NDK" ]; then + echo "Error: ANDROID_NDK is not set." + echo "Example: export ANDROID_NDK=/path/to/android-ndk-r21d" + exit 1 + fi + + NNTRAINER_ANDROID_LIB="$NNTRAINER_ROOT/builddir/android_build_result/lib/arm64-v8a/libnntrainer.so" + if [ "$CLEAN" = true ] || [ ! -f "$NNTRAINER_ANDROID_LIB" ]; then + echo "[1] Building nntrainer for Android..." + cd "$NNTRAINER_ROOT" + [ "$CLEAN" = true ] && rm -rf builddir + + # Build nntrainer with QNN support if requested. + # + # -Dprefix: package_android.sh runs `ninja install`. nntrainer's own + # Android artifacts install to an absolute path (builddir/android_build_result, + # see nntrainer/meson.build), independent of the meson prefix. However the + # googletest subproject (pulled in unconditionally when no system gtest/gmock + # is found) still honors the meson prefix and tries to install libgmock.a / + # libgtest.a to the default /usr/local/lib, which fails for a non-root user + # with "PermissionError: [Errno 13] Permission denied". Point the prefix at a + # writable, in-tree staging dir so the install succeeds without root. + NNTRAINER_STAGE_PREFIX="$SCRIPT_DIR/.nntrainer_android_stage" + NNTRAINER_EXTRA_OPTS="-Dmmap-read=false -Dprefix=$NNTRAINER_STAGE_PREFIX" + if [ "$ENABLE_QNN" = true ]; then + NNTRAINER_EXTRA_OPTS="$NNTRAINER_EXTRA_OPTS -Denable-npu=true" + fi + ./tools/package_android.sh $NNTRAINER_EXTRA_OPTS + cd "$SCRIPT_DIR" + else + echo "[1] nntrainer (android) already built. (use --clean to rebuild)" + fi +else + NNTRAINER_BUILD="$NNTRAINER_ROOT/builddir_x86" + NNTRAINER_X86_LIB="$NNTRAINER_BUILD/nntrainer/libnntrainer.so" + if [ "$CLEAN" = true ] || [ ! -f "$NNTRAINER_X86_LIB" ]; then + echo "[1] Building nntrainer for x86..." + cd "$NNTRAINER_ROOT" + if [ "$CLEAN" = true ]; then + rm -rf "$NNTRAINER_BUILD" + fi + if [ ! -f "$NNTRAINER_BUILD/build.ninja" ]; then + meson setup "$NNTRAINER_BUILD" . \ + --buildtype=release \ + -Denable-app=false \ + -Denable-test=false \ + -Denable-transformer=false \ + -Denable-tflite-backbone=false \ + -Denable-tflite-interpreter=false + fi + ninja -C "$NNTRAINER_BUILD" -j $(nproc) + cd "$SCRIPT_DIR" + else + echo "[1] nntrainer (x86) already built. (use --clean to rebuild)" + fi +fi + +# ── Step 2: Prepare json.hpp ──────────────────────────────────────────── +if [ ! -f "$CAUSALLM_ROOT/json.hpp" ]; then + echo "[2] Preparing json.hpp..." + pushd "$NNTRAINER_ROOT" > /dev/null + + if [ "$PLATFORM" = "android" ]; then + "$NNTRAINER_ROOT/jni/prepare_encoder.sh" "$NNTRAINER_ROOT/builddir" "0.2" || true + else + "$NNTRAINER_ROOT/jni/prepare_encoder.sh" "$NNTRAINER_ROOT/builddir_x86" "0.2" || true + fi + + # Fallback: manual copy + if [ ! -f "$CAUSALLM_ROOT/json.hpp" ]; then + for candidate in "$NNTRAINER_ROOT/builddir_x86/json.hpp" "$NNTRAINER_ROOT/builddir/json.hpp"; do + if [ -f "$candidate" ]; then + cp "$candidate" "$CAUSALLM_ROOT/" + break + fi + done + fi + popd > /dev/null + + if [ ! -f "$CAUSALLM_ROOT/json.hpp" ]; then + echo "Error: Failed to prepare json.hpp" + exit 1 + fi +fi + +# ── Step 3: Tokenizer check ──────────────────────────────────────────── +if [ "$PLATFORM" = "android" ]; then + TOKENIZER="$CAUSALLM_ROOT/lib/libtokenizers_android_c.a" +else + TOKENIZER="$CAUSALLM_ROOT/lib/libtokenizers_c.a" +fi + +if [ ! -f "$TOKENIZER" ]; then + echo "Warning: Tokenizer library not found: $TOKENIZER" + if [ -f "$CAUSALLM_ROOT/build_tokenizer_android.sh" ] && [ "$PLATFORM" = "android" ]; then + echo "Building tokenizer..." + cd "$CAUSALLM_ROOT" && ./build_tokenizer_android.sh && cd "$SCRIPT_DIR" + else + echo "Error: Tokenizer library missing. Place it at: $TOKENIZER" + exit 1 + fi +fi + +# ── Step 4: Generate cross file (android) ─────────────────────────────── +CROSS_ARGS="" +if [ "$PLATFORM" = "android" ]; then + CROSS_FILE="$SCRIPT_DIR/cross/android-aarch64.cross" + sed "s|@ANDROID_NDK@|$ANDROID_NDK|g" \ + "$SCRIPT_DIR/cross/android-aarch64.cross.in" > "$CROSS_FILE" + CROSS_ARGS="--cross-file $CROSS_FILE" +fi + +# ── Step 5: Parse targets into meson options ──────────────────────────── +ENABLE_API=false +ENABLE_API_TEST=false + +if [ "$TARGETS" = "all" ]; then + ENABLE_API=true + ENABLE_API_TEST=true +else + IFS=',' read -ra T <<< "$TARGETS" + for t in "${T[@]}"; do + case "$(echo "$t" | tr -d ' ')" in + api) ENABLE_API=true ;; + api-test) ENABLE_API=true; ENABLE_API_TEST=true ;; + src) ;; # src is always built + qnn) ENABLE_QNN=true ;; + esac + done +fi + +# ── Step 6: Meson setup ──────────────────────────────────────────────── +MESON_OPTS=( + --buildtype=release + -Denable-qnn=$ENABLE_QNN + -Denable-api=$ENABLE_API + -Denable-api-test=$ENABLE_API_TEST +) + +if [ "$PLATFORM" = "android" ]; then + MESON_OPTS+=(-Dplatform=android) +else + MESON_OPTS+=(-Dnntrainer_builddir=builddir_x86) +fi + +echo "[3] Configuring meson..." +if [ "$CLEAN" = true ] || [ ! -f "$BUILD_DIR/build.ninja" ]; then + rm -rf "$BUILD_DIR" + meson setup "$BUILD_DIR" "$SCRIPT_DIR" $CROSS_ARGS "${MESON_OPTS[@]}" +else + meson setup "$BUILD_DIR" "$SCRIPT_DIR" --reconfigure $CROSS_ARGS "${MESON_OPTS[@]}" || true +fi + +# ── Step 7: Build ─────────────────────────────────────────────────────── +echo "[4] Building..." +ninja -C "$BUILD_DIR" -j $(nproc) + +echo "" +echo "=== Build completed ===" +echo "Artifacts in: $BUILD_DIR" + +if [ "$PLATFORM" = "x86" ]; then + NNTRAINER_BUILD="${NNTRAINER_BUILD:-$NNTRAINER_ROOT/builddir_x86}" + echo "" + echo "Run executable:" + echo " LD_LIBRARY_PATH=$NNTRAINER_BUILD/nntrainer:$NNTRAINER_BUILD/api/ccapi:$BUILD_DIR/src:$BUILD_DIR/api \\" + echo " $BUILD_DIR/src/quick_dot_ai [input_prompt]" +fi + +if [ "$PLATFORM" = "android" ]; then + echo "" + echo "Install to device:" + echo " ./install_android.sh" +fi diff --git a/build_android.sh b/build_android.sh deleted file mode 100755 index e592547a..00000000 --- a/build_android.sh +++ /dev/null @@ -1,190 +0,0 @@ -#!/bin/bash - -# Build script for CausalLM Android application -# This script builds libquick_dot_ai_core.so and quick_dot_ai executable -set -e - -# Color codes -RED='\033[0;31m' -GREEN='\033[0;32m' -YELLOW='\033[1;33m' -CYAN='\033[0;36m' -BLUE='\033[0;34m' -NC='\033[0m' # No Color - -log_info() { - echo -e "${BLUE}[INFO]${NC} $1" -} - -log_success() { - echo -e "${GREEN}[SUCCESS]${NC} $1" -} - -log_warning() { - echo -e "${YELLOW}[WARNING]${NC} $1" -} - -log_error() { - echo -e "${RED}[ERROR]${NC} $1" -} - -log_header() { - echo -e "\n${CYAN}========================================${NC}" - echo -e "${CYAN} $1 ${NC}" - echo -e "${CYAN}========================================${NC}" -} - -log_step() { - echo -e "\n${YELLOW}[Step $1]${NC} $2" - echo -e "${YELLOW}----------------------------------------${NC}" -} - -# Function to check and fix artifact location -check_artifact() { - local filename=$1 - local libs_path="libs/arm64-v8a/$filename" - local obj_path="obj/local/arm64-v8a/$filename" - - if [ -f "$libs_path" ]; then - size=$(ls -lh "$libs_path" | awk '{print $5}') - echo -e " ${GREEN}[OK]${NC} $filename ($size)" - return 0 - elif [ -f "$obj_path" ]; then - echo -e " ${YELLOW}[WARN]${NC} $filename found in obj but not in libs. Copying..." - mkdir -p "libs/arm64-v8a" - cp "$obj_path" "$libs_path" - if [ -x "$obj_path" ]; then - chmod +x "$libs_path" - fi - size=$(ls -lh "$libs_path" | awk '{print $5}') - echo -e " ${GREEN}[OK]${NC} $filename ($size) (Copied from obj)" - return 0 - else - echo -e " ${RED}[ERROR]${NC} $filename not found!" - log_info " Checked paths:" - log_info " - $libs_path" - log_info " - $obj_path" - return 1 - fi -} - -# Check if NDK path is set -if [ -z "$ANDROID_NDK" ]; then - log_error "ANDROID_NDK is not set. Please set it to your Android NDK path." - log_info "Example: export ANDROID_NDK=/path/to/android-ndk-r21d" - exit 1 -fi - -# Set NNTRAINER_ROOT -SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" -NNTRAINER_ROOT="${NNTRAINER_ROOT:-$(cd "$SCRIPT_DIR/subprojects/nntrainer" && pwd)}" -export NNTRAINER_ROOT - -log_header "Build CausalLM Android Application" -log_info "NNTRAINER_ROOT: $NNTRAINER_ROOT" -log_info "ANDROID_NDK: $ANDROID_NDK" -log_info "Working directory: $(pwd)" - -# Step 1: Build nntrainer for Android if not already built -log_step "1/4" "Build nntrainer for Android" - -if [ ! -f "$NNTRAINER_ROOT/builddir/android_build_result/lib/arm64-v8a/libnntrainer.so" ]; then - log_info "Building nntrainer for Android..." - cd "$NNTRAINER_ROOT" - if [ -d "$NNTRAINER_ROOT/builddir" ]; then - rm -rf builddir - fi - # Use a user-writable prefix so `meson install` does not try to - # escalate privileges to write into /usr/local (fails in CI / non- - # interactive environments). - ./tools/package_android.sh "-Dprefix=$NNTRAINER_ROOT/builddir/install" -else - log_info "nntrainer for Android already built (skipping)" -fi - -# Check if build was successful -if [ ! -f "$NNTRAINER_ROOT/builddir/android_build_result/lib/arm64-v8a/libnntrainer.so" ]; then - log_error "nntrainer build failed. Please check the build logs." - exit 1 -fi -log_success "nntrainer ready" - -# Step 2: Build tokenizer library if not present -log_step "2/4" "Build Tokenizer Library" - -cd "$SCRIPT_DIR" -if [ ! -f "lib/libtokenizers_android_c.a" ]; then - log_warning "libtokenizers_android_c.a not found in lib directory." - log_info "Attempting to build tokenizer library..." - if [ -f "build_tokenizer_android.sh" ]; then - ./build_tokenizer_android.sh - else - log_error "tokenizer library not found and build script is missing." - log_info "Please build or download the tokenizer library for Android arm64-v8a" - log_info "and place it in: $SCRIPT_DIR/lib/libtokenizers_android_c.a" - exit 1 - fi -else - log_info "Tokenizer library already built (skipping)" -fi -log_success "Tokenizer library ready" - -# Step 3: Prepare json.hpp if not present -log_step "3/4" "Prepare json.hpp" - -if [ ! -f "$SCRIPT_DIR/json.hpp" ]; then - log_info "json.hpp not found. Downloading..." - # Use Quick.AI's own prepare_encoder.sh so the downloaded json.hpp - # is copied to the Quick.AI project root (the nntrainer submodule - # ships a legacy variant that drops it into Applications/CausalLM/ - # instead and leaves our expected location empty). - "$SCRIPT_DIR/jni/prepare_encoder.sh" "$NNTRAINER_ROOT/builddir" "0.2" - - if [ ! -f "$SCRIPT_DIR/json.hpp" ]; then - log_error "Failed to download json.hpp" - exit 1 - fi -else - log_info "json.hpp already exists (skipping)" -fi -log_success "json.hpp ready" - -# Step 4: Build CausalLM (libquick_dot_ai_core.so and quick_dot_ai) -log_step "4/4" "Build CausalLM Core (library + executable)" - -cd "$SCRIPT_DIR/jni" - -# Clean previous builds -rm -rf libs obj - -log_info "Building with ndk-build (builds quick_dot_ai_core, quick_dot_ai, quick_dot_ai_quantize)..." -# We explicitly set paths to ensure outputs are predictable -if ndk-build NDK_PROJECT_PATH=. NDK_LIBS_OUT=./libs NDK_OUT=./obj APP_BUILD_SCRIPT=./Android.mk NDK_APPLICATION_MK=./Application.mk quick_dot_ai_core quick_dot_ai quick_dot_ai_quantize -j $(nproc); then - log_success "Build completed successfully" -else - log_error "Build failed" - exit 1 -fi - -# Verify outputs -log_info "Build artifacts:" - -check_artifact "libquick_dot_ai_core.so" || exit 1 -check_artifact "quick_dot_ai" || exit 1 -check_artifact "quick_dot_ai_quantize" || exit 1 - -# Summary -log_header "Build Summary" -log_success "Build completed successfully!" -log_info "Output files are in: $SCRIPT_DIR/jni/libs/arm64-v8a/" -log_info "Executables:" -log_info " - quick_dot_ai (main application), quick_dot_ai_quantize" -log_info "Libraries:" -log_info " - libquick_dot_ai_core.so (CausalLM Core library)" -log_info " - libnntrainer.so (nntrainer library)" -log_info " - libccapi-nntrainer.so (nntrainer C/C API)" -log_info " - libc++_shared.so (C++ runtime)" -log_info "To build API library, run:" -log_info " ./build_api_lib.sh" -log_info "To install and run:" -log_info " ./install_android.sh" diff --git a/build_api_lib.sh b/build_api_lib.sh deleted file mode 100755 index b757fd1b..00000000 --- a/build_api_lib.sh +++ /dev/null @@ -1,128 +0,0 @@ -#!/bin/bash - -# Build script for CausalLM API Library -# This script builds libquick_dot_ai_api.so only -set -e - -# Color codes -RED='\033[0;31m' -GREEN='\033[0;32m' -YELLOW='\033[1;33m' -CYAN='\033[0;36m' -BLUE='\033[0;34m' -NC='\033[0m' # No Color - -log_info() { - echo -e "${BLUE}[INFO]${NC} $1" -} - -log_success() { - echo -e "${GREEN}[SUCCESS]${NC} $1" -} - -log_warning() { - echo -e "${YELLOW}[WARNING]${NC} $1" -} - -log_error() { - echo -e "${RED}[ERROR]${NC} $1" -} - -log_header() { - echo -e "\n${CYAN}========================================${NC}" - echo -e "${CYAN} $1 ${NC}" - echo -e "${CYAN}========================================${NC}" -} - -log_step() { - echo -e "\n${YELLOW}[Step $1]${NC} $2" - echo -e "${YELLOW}----------------------------------------${NC}" -} - -# Function to check and fix artifact location -check_artifact() { - local filename=$1 - local libs_path="libs/arm64-v8a/$filename" - local obj_path="obj/local/arm64-v8a/$filename" - - if [ -f "$libs_path" ]; then - size=$(ls -lh "$libs_path" | awk '{print $5}') - echo -e " ${GREEN}[OK]${NC} $filename ($size)" - return 0 - elif [ -f "$obj_path" ]; then - echo -e " ${YELLOW}[WARN]${NC} $filename found in obj but not in libs. Copying..." - mkdir -p "libs/arm64-v8a" - cp "$obj_path" "$libs_path" - if [ -x "$obj_path" ]; then - chmod +x "$libs_path" - fi - size=$(ls -lh "$libs_path" | awk '{print $5}') - echo -e " ${GREEN}[OK]${NC} $filename ($size) (Copied from obj)" - return 0 - else - echo -e " ${RED}[ERROR]${NC} $filename not found!" - echo " Checked paths:" - echo " - $libs_path" - echo " - $obj_path" - return 1 - fi -} - -# Check if NDK path is set -if [ -z "$ANDROID_NDK" ]; then - log_error "ANDROID_NDK is not set. Please set it to your Android NDK path." - echo "Example: export ANDROID_NDK=/path/to/android-ndk-r21d" - exit 1 -fi - -# Set NNTRAINER_ROOT -SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" -NNTRAINER_ROOT="${NNTRAINER_ROOT:-$(cd "$SCRIPT_DIR/subprojects/nntrainer" && pwd)}" -export NNTRAINER_ROOT - -log_header "Build CausalLM API Library" -echo "NNTRAINER_ROOT: $NNTRAINER_ROOT" -echo "ANDROID_NDK: $ANDROID_NDK" -echo "Working directory: $(pwd)" - -# Check if CausalLM Core is built -log_step "1/2" "Check CausalLM Core" - -if [ ! -f "$SCRIPT_DIR/jni/libs/arm64-v8a/libquick_dot_ai_core.so" ]; then - log_error "libquick_dot_ai_core.so not found." - echo "Please run build_android.sh first to build the core library." - exit 1 -fi -log_success "CausalLM Core found" - -# Step 2: Build CausalLM API -log_step "2/2" "Build CausalLM API Library" - -cd "$SCRIPT_DIR/jni" - -log_info "Building with ndk-build (builds libquick_dot_ai_api.so)..." -if ndk-build NDK_PROJECT_PATH=. NDK_LIBS_OUT=./libs NDK_OUT=./obj APP_BUILD_SCRIPT=./Android.mk NDK_APPLICATION_MK=./Application.mk quick_dot_ai_api -j $(nproc); then - log_success "Build completed successfully" -else - log_error "Build failed" - exit 1 -fi - -# Verify output -echo "" -echo "Build artifacts:" - -check_artifact "libquick_dot_ai_api.so" || exit 1 - -# Summary -log_header "Build Summary" -log_success "Build completed successfully!" -echo "" -echo "Output files are in: $SCRIPT_DIR/jni/libs/arm64-v8a/" -echo "" -echo "Libraries:" -echo " - libquick_dot_ai_api.so (CausalLM API library)" -echo "" -echo "To build test app, run:" -echo " ./build_test_app.sh" -echo "" diff --git a/build_test_app.sh b/build_test_app.sh deleted file mode 100755 index 956e61bb..00000000 --- a/build_test_app.sh +++ /dev/null @@ -1,149 +0,0 @@ -#!/bin/bash - -# Build script for CausalLM Test Application -# This script builds test_api executable only -set -e - -# Color codes -RED='\033[0;31m' -GREEN='\033[0;32m' -YELLOW='\033[1;33m' -CYAN='\033[0;36m' -BLUE='\033[0;34m' -NC='\033[0m' # No Color - -log_info() { - echo -e "${BLUE}[INFO]${NC} $1" -} - -log_success() { - echo -e "${GREEN}[SUCCESS]${NC} $1" -} - -log_warning() { - echo -e "${YELLOW}[WARNING]${NC} $1" -} - -log_error() { - echo -e "${RED}[ERROR]${NC} $1" -} - -log_header() { - echo -e "\n${CYAN}========================================${NC}" - echo -e "${CYAN} $1 ${NC}" - echo -e "${CYAN}========================================${NC}" -} - -log_step() { - echo -e "\n${YELLOW}[Step $1]${NC} $2" - echo -e "${YELLOW}----------------------------------------${NC}" -} - -# Function to check and fix artifact location -check_artifact() { - local filename=$1 - local libs_path="libs/arm64-v8a/$filename" - local obj_path="obj/local/arm64-v8a/$filename" - - if [ -f "$libs_path" ]; then - size=$(ls -lh "$libs_path" | awk '{print $5}') - echo -e " ${GREEN}[OK]${NC} $filename ($size)" - return 0 - elif [ -f "$obj_path" ]; then - echo -e " ${YELLOW}[WARN]${NC} $filename found in obj but not in libs. Copying..." - mkdir -p "libs/arm64-v8a" - cp "$obj_path" "$libs_path" - if [ -x "$obj_path" ]; then - chmod +x "$libs_path" - fi - size=$(ls -lh "$libs_path" | awk '{print $5}') - echo -e " ${GREEN}[OK]${NC} $filename ($size) (Copied from obj)" - return 0 - else - echo -e " ${RED}[ERROR]${NC} $filename not found!" - echo " Checked paths:" - echo " - $libs_path" - echo " - $obj_path" - return 1 - fi -} - -# Check if NDK path is set -if [ -z "$ANDROID_NDK" ]; then - log_error "ANDROID_NDK is not set. Please set it to your Android NDK path." - echo "Example: export ANDROID_NDK=/path/to/android-ndk-r21d" - exit 1 -fi - -# Set NNTRAINER_ROOT -SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" -NNTRAINER_ROOT="${NNTRAINER_ROOT:-$(cd "$SCRIPT_DIR/subprojects/nntrainer" && pwd)}" -export NNTRAINER_ROOT - -log_header "Build CausalLM Test Application" -echo "NNTRAINER_ROOT: $NNTRAINER_ROOT" -echo "ANDROID_NDK: $ANDROID_NDK" -echo "Working directory: $(pwd)" - -# Check required libraries -log_step "1/2" "Check Dependencies" - -MISSING_DEPS=false - -# Check Core Lib -if [ -f "$SCRIPT_DIR/jni/libs/arm64-v8a/libquick_dot_ai_core.so" ]; then - echo -e " ${GREEN}[OK]${NC} libquick_dot_ai_core.so found" -else - echo -e " ${RED}[MISSING]${NC} libquick_dot_ai_core.so" - echo " -> Run ./build_android.sh first" - MISSING_DEPS=true -fi - -# Check API Lib -if [ -f "$SCRIPT_DIR/jni/libs/arm64-v8a/libquick_dot_ai_api.so" ]; then - echo -e " ${GREEN}[OK]${NC} libquick_dot_ai_api.so found" -else - echo -e " ${RED}[MISSING]${NC} libquick_dot_ai_api.so" - echo " -> Run ./build_api_lib.sh first" - MISSING_DEPS=true -fi - -if [ "$MISSING_DEPS" = true ]; then - echo "" - log_error "Missing dependencies. Please build required libraries first." - exit 1 -fi - -log_success "All dependencies found" - -# Step 2: Build Test App -log_step "2/2" "Build Test App" - -cd "$SCRIPT_DIR/jni" - -log_info "Building with ndk-build (builds test_api)..." -if ndk-build NDK_PROJECT_PATH=. NDK_LIBS_OUT=./libs NDK_OUT=./obj APP_BUILD_SCRIPT=./Android.mk NDK_APPLICATION_MK=./Application.mk test_api -j $(nproc); then - log_success "Build completed successfully" -else - log_error "Build failed" - exit 1 -fi - -# Verify output -echo "" -echo "Build artifacts:" - -check_artifact "test_api" || exit 1 - -# Summary -log_header "Build Summary" -log_success "Build completed successfully!" -echo "" -echo "Output files are in: $SCRIPT_DIR/jni/libs/arm64-v8a/" -echo "" -echo "Executables:" -echo " - test_api (API test application)" -echo "" -echo "To install and run:" -echo " ./install_android.sh" -echo "" diff --git a/build_tokenizer_android.sh b/build_tokenizer_android.sh deleted file mode 100755 index 2cca45e9..00000000 --- a/build_tokenizer_android.sh +++ /dev/null @@ -1,248 +0,0 @@ -#!/bin/bash - -# Script to build tokenizers-cpp library for Android -set -e - -# Default target ABI -TARGET_ABI="${1:-arm64-v8a}" - -echo "Building tokenizers-cpp library for Android $TARGET_ABI..." - -# Check prerequisites -if [ -z "$ANDROID_NDK" ]; then - echo "Error: ANDROID_NDK is not set. Please set it to your Android NDK path." - exit 1 -fi - -# Check if cmake is installed -if ! command -v cmake &> /dev/null; then - echo "Error: cmake is not installed. Please install cmake." - exit 1 -fi - -# Check if rust is installed -if ! command -v rustc &> /dev/null || ! command -v cargo &> /dev/null; then - echo "Error: Rust is not installed. Please install Rust from https://rustup.rs/" - exit 1 -fi - -# Map Android ABI to Rust target -case "$TARGET_ABI" in - "arm64-v8a") - RUST_TARGET="aarch64-linux-android" - ;; - "armeabi-v7a") - RUST_TARGET="armv7-linux-androideabi" - ;; - "x86") - RUST_TARGET="i686-linux-android" - ;; - "x86_64") - RUST_TARGET="x86_64-linux-android" - ;; -esac - -# Install Rust target if not already installed -echo "Checking Rust target: $RUST_TARGET" -if ! rustup target list --installed | grep -q "$RUST_TARGET"; then - echo "Installing Rust target: $RUST_TARGET" - rustup target add "$RUST_TARGET" -fi - -# Validate target ABI -case "$TARGET_ABI" in - "arm64-v8a"|"armeabi-v7a"|"x86"|"x86_64") - echo "Target ABI: $TARGET_ABI" - ;; - *) - echo "Error: Invalid target ABI: $TARGET_ABI" - echo "Supported ABIs: arm64-v8a, armeabi-v7a, x86, x86_64" - exit 1 - ;; -esac - -# Set build directory -SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" -BUILD_DIR="$SCRIPT_DIR/tokenizers-cpp-build" - -# Clone tokenizers-cpp repository if not exists -if [ ! -d "$BUILD_DIR/tokenizers-cpp" ]; then - echo "Cloning tokenizers-cpp repository..." - mkdir -p "$BUILD_DIR" - cd "$BUILD_DIR" - git clone https://github.com/mlc-ai/tokenizers-cpp.git -fi - -cd "$BUILD_DIR/tokenizers-cpp" - -# Update submodules -echo "Updating submodules..." -git submodule update --init --recursive - -# Create build directory for specific ABI -mkdir -p "build-android-$TARGET_ABI" -cd "build-android-$TARGET_ABI" - -# Set up Android toolchain variables -ANDROID_PLATFORM="android-29" -ANDROID_STL="c++_static" - -# Detect platform for NDK paths -if [[ "$OSTYPE" == "darwin"* ]]; then - NDK_HOST="darwin-x86_64" -elif [[ "$OSTYPE" == "linux-gnu"* ]]; then - NDK_HOST="linux-x86_64" -elif [[ "$OSTYPE" == "msys" ]] || [[ "$OSTYPE" == "cygwin" ]] || [[ "$OSTYPE" == "win32" ]]; then - NDK_HOST="windows-x86_64" -else - echo "Warning: Unknown platform $OSTYPE, assuming linux-x86_64" - NDK_HOST="linux-x86_64" -fi - -# Set Rust environment variables for cross-compilation -export CARGO_TARGET_DIR="$BUILD_DIR/tokenizers-cpp/build-android-$TARGET_ABI/rust" - -# Additional Rust configuration for Android -export CARGO_BUILD_TARGET="$RUST_TARGET" -export CARGO_TARGET_AARCH64_LINUX_ANDROID_LINKER="$ANDROID_NDK/toolchains/llvm/prebuilt/$NDK_HOST/bin/aarch64-linux-android29-clang" -export CARGO_TARGET_ARMV7_LINUX_ANDROIDEABI_LINKER="$ANDROID_NDK/toolchains/llvm/prebuilt/$NDK_HOST/bin/armv7a-linux-androideabi29-clang" -export CARGO_TARGET_I686_LINUX_ANDROID_LINKER="$ANDROID_NDK/toolchains/llvm/prebuilt/$NDK_HOST/bin/i686-linux-android29-clang" -export CARGO_TARGET_X86_64_LINUX_ANDROID_LINKER="$ANDROID_NDK/toolchains/llvm/prebuilt/$NDK_HOST/bin/x86_64-linux-android29-clang" - -export CC_aarch64_linux_android="$ANDROID_NDK/toolchains/llvm/prebuilt/$NDK_HOST/bin/aarch64-linux-android29-clang" -export CXX_aarch64_linux_android="$ANDROID_NDK/toolchains/llvm/prebuilt/$NDK_HOST/bin/aarch64-linux-android29-clang++" -export AR_aarch64_linux_android="$ANDROID_NDK/toolchains/llvm/prebuilt/$NDK_HOST/bin/llvm-ar" -export CC_armv7_linux_androideabi="$ANDROID_NDK/toolchains/llvm/prebuilt/$NDK_HOST/bin/armv7a-linux-androideabi29-clang" -export CXX_armv7_linux_androideabi="$ANDROID_NDK/toolchains/llvm/prebuilt/$NDK_HOST/bin/armv7a-linux-androideabi29-clang++" -export AR_armv7_linux_androideabi="$ANDROID_NDK/toolchains/llvm/prebuilt/$NDK_HOST/bin/llvm-ar" -export CC_i686_linux_android="$ANDROID_NDK/toolchains/llvm/prebuilt/$NDK_HOST/bin/i686-linux-android29-clang" -export CXX_i686_linux_android="$ANDROID_NDK/toolchains/llvm/prebuilt/$NDK_HOST/bin/i686-linux-android29-clang++" -export AR_i686_linux_android="$ANDROID_NDK/toolchains/llvm/prebuilt/$NDK_HOST/bin/llvm-ar" -export CC_x86_64_linux_android="$ANDROID_NDK/toolchains/llvm/prebuilt/$NDK_HOST/bin/x86_64-linux-android29-clang" -export CXX_x86_64_linux_android="$ANDROID_NDK/toolchains/llvm/prebuilt/$NDK_HOST/bin/x86_64-linux-android29-clang++" -export AR_x86_64_linux_android="$ANDROID_NDK/toolchains/llvm/prebuilt/$NDK_HOST/bin/llvm-ar" - -# Configure with CMake for Android -echo "Configuring CMake for Android $TARGET_ABI..." -cmake .. \ - -DCMAKE_TOOLCHAIN_FILE="$ANDROID_NDK/build/cmake/android.toolchain.cmake" \ - -DANDROID_ABI="$TARGET_ABI" \ - -DANDROID_PLATFORM="$ANDROID_PLATFORM" \ - -DANDROID_STL="$ANDROID_STL" \ - -DCMAKE_BUILD_TYPE=Release \ - -DBUILD_SHARED_LIBS=OFF \ - -DTOKENIZERS_CPP_BUILD_TESTS=OFF \ - -DTOKENIZERS_CPP_BUILD_EXAMPLES=OFF \ - -DCMAKE_VERBOSE_MAKEFILE=ON - -# Build the library -echo "Building tokenizers-cpp..." -cmake --build . -j$(nproc) --verbose - -# Show what was actually built -echo "Build complete. Checking build outputs..." -echo "Contents of build directory:" -ls -la -echo "" -echo "Looking for static libraries (.a files):" -find . -name "*.a" -type f -ls -echo "" - -# Find and copy the built library -echo "Searching for built libraries..." -mkdir -p "$SCRIPT_DIR/lib/$TARGET_ABI" - -# Current directory is build-android-$TARGET_ABI -CURRENT_BUILD_DIR="$BUILD_DIR/tokenizers-cpp/build-android-$TARGET_ABI" - -# Find all the generated libraries -echo "Looking for .a files in build directory..." -find "$CURRENT_BUILD_DIR" -name "*.a" -type f | while read -r lib; do - echo "Found library: $lib" -done - -# Collect all libraries to combine -LIBS_TO_COMBINE="" - -# Search for specific libraries with more flexible paths -for lib_name in "libtokenizers_cpp.a" "libtokenizers_c.a" "libsentencepiece.a"; do - echo "Searching for $lib_name..." - lib_path=$(find "$CURRENT_BUILD_DIR" -name "$lib_name" -type f | head -n 1) - if [ -n "$lib_path" ]; then - echo "Found $lib_name at: $lib_path" - LIBS_TO_COMBINE="$LIBS_TO_COMBINE $lib_path" - fi -done - -# If specific libraries not found, collect all .a files -if [ -z "$LIBS_TO_COMBINE" ]; then - echo "Specific libraries not found. Collecting all .a files..." - LIBS_TO_COMBINE=$(find "$CURRENT_BUILD_DIR" -name "*.a" -type f | grep -v "CMakeFiles" | tr '\n' ' ') -fi - -# Combine all libraries into one -if [ -n "$LIBS_TO_COMBINE" ]; then - echo "Libraries to combine: $LIBS_TO_COMBINE" - - # Create a temporary directory for extracting object files - TEMP_DIR="$BUILD_DIR/temp_objs" - rm -rf "$TEMP_DIR" - mkdir -p "$TEMP_DIR" - cd "$TEMP_DIR" - - # Extract all object files from each library - for lib in $LIBS_TO_COMBINE; do - if [ -f "$lib" ]; then - echo "Extracting from $lib..." - "$ANDROID_NDK/toolchains/llvm/prebuilt/$NDK_HOST/bin/llvm-ar" x "$lib" - else - echo "Warning: Could not find $lib" - fi - done - - # Create the combined library - echo "Creating combined library..." - if ls *.o 1> /dev/null 2>&1; then - "$ANDROID_NDK/toolchains/llvm/prebuilt/$NDK_HOST/bin/llvm-ar" rcs "$SCRIPT_DIR/lib/$TARGET_ABI/libtokenizers_android_c.a" *.o - echo "Combined library created successfully" - else - echo "Error: No object files found to combine" - echo "Checking if any libraries were built..." - - # If no object files, maybe the libraries are header-only or built differently - # Try to copy the first found library as-is - first_lib=$(echo $LIBS_TO_COMBINE | awk '{print $1}') - if [ -f "$first_lib" ]; then - echo "Copying $first_lib as libtokenizers_android_c.a" - cp "$first_lib" "$SCRIPT_DIR/lib/$TARGET_ABI/libtokenizers_android_c.a" - else - cd .. - rm -rf "$TEMP_DIR" - exit 1 - fi - fi - - # Clean up - cd .. - rm -rf "$TEMP_DIR" -else - echo "Error: No libraries found to combine" - echo "Build may have failed. Check the build output above." - exit 1 -fi - -# For backward compatibility, also copy to lib directory for default ABI -if [ "$TARGET_ABI" = "arm64-v8a" ] && [ -f "$SCRIPT_DIR/lib/$TARGET_ABI/libtokenizers_android_c.a" ]; then - cp "$SCRIPT_DIR/lib/$TARGET_ABI/libtokenizers_android_c.a" "$SCRIPT_DIR/lib/libtokenizers_android_c.a" -fi - -if [ -f "$SCRIPT_DIR/lib/$TARGET_ABI/libtokenizers_android_c.a" ]; then - echo "Build completed successfully!" - echo "Library copied to: $SCRIPT_DIR/lib/$TARGET_ABI/libtokenizers_android_c.a" - if [ "$TARGET_ABI" = "arm64-v8a" ]; then - echo "Also copied to: $SCRIPT_DIR/lib/libtokenizers_android_c.a (for backward compatibility)" - fi -else - echo "Error: Failed to build or find the tokenizers library" - exit 1 -fi diff --git a/chat_template.cpp b/chat_template.cpp deleted file mode 100644 index dc91cb66..00000000 --- a/chat_template.cpp +++ /dev/null @@ -1,1941 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -/** - * Copyright (C) 2026 Samsung Electronics Co., Ltd. All Rights Reserved. - * - * @file chat_template.cpp - * @date 10 Apr 2026 - * @brief Chat template implementation with mini Jinja2 renderer - * @see https://github.com/nntrainer/Quick.AI - * @author Eunju Yang - * @bug No known bugs except for NYI items - */ - -#include "chat_template.h" -#include -#include -#include -#include -#include -#include -#include -#include - -namespace quick_dot_ai { - -// ============================================================================ -// Token types for the Jinja2 lexer -// ============================================================================ -/** @brief TokenType - enum class for Jinja2 template processing */ -enum class TokenType { - TEXT, - EXPRESSION_START, // {{ - EXPRESSION_END, // }} - STATEMENT_START, // {% - STATEMENT_END, // %} - STRING, - INTEGER, - FLOAT, - IDENTIFIER, - DOT, - LBRACKET, - RBRACKET, - LPAREN, - RPAREN, - PLUS, - MINUS, - PERCENT, - PIPE, - COMMA, - EQ, // == - NEQ, // != - ASSIGN, // = - NOT, // not - AND, // and - OR, // or - TRUE_LIT, - FALSE_LIT, - NONE_LIT, - IF, - ELIF, - ELSE, - ENDIF, - FOR, - IN, - ENDFOR, - SET, - IS, - TILDE, // ~ - COLON, // : - GT, // > - LT, // < - GTE, // >= - LTE, // <= - END_OF_INPUT, -}; - -/** @brief Lexer token with type, value, and whitespace control flags */ -struct Token { - TokenType type; - std::string value; - bool strip_before = false; // {%- or {{- - bool strip_after = false; // -%} or -}} -}; - -// ============================================================================ -// Lexer -// ============================================================================ -/** @brief Tokenizes Jinja2 template strings into token sequences */ -class Lexer { -public: - /** @brief Construct lexer with input template string */ - explicit Lexer(const std::string &input) : input_(input), pos_(0) {} - - /** @brief Tokenize the input template into a sequence of tokens */ - std::vector tokenize() { - std::vector tokens; - - while (pos_ < input_.size()) { - if (match("{{")) { - bool strip = false; - if (pos_ < input_.size() && input_[pos_] == '-') { - strip = true; - pos_++; - } - Token start; - start.type = TokenType::EXPRESSION_START; - start.strip_before = strip; - tokens.push_back(start); - skipWhitespace(); - tokenizeInside(tokens, TokenType::EXPRESSION_END); - } else if (match("{%")) { - bool strip = false; - if (pos_ < input_.size() && input_[pos_] == '-') { - strip = true; - pos_++; - } - Token start; - start.type = TokenType::STATEMENT_START; - start.strip_before = strip; - tokens.push_back(start); - skipWhitespace(); - tokenizeInside(tokens, TokenType::STATEMENT_END); - } else { - std::string text; - while (pos_ < input_.size()) { - if ((pos_ + 1 < input_.size()) && - ((input_[pos_] == '{' && - (input_[pos_ + 1] == '{' || input_[pos_ + 1] == '%')))) { - break; - } - text += input_[pos_++]; - } - if (!text.empty()) { - Token t; - t.type = TokenType::TEXT; - t.value = text; - tokens.push_back(t); - } - } - } - - Token eof; - eof.type = TokenType::END_OF_INPUT; - tokens.push_back(eof); - return tokens; - } - -private: - /** @brief Try to match a string at current position and advance */ - bool match(const std::string &s) { - if (pos_ + s.size() <= input_.size() && - input_.substr(pos_, s.size()) == s) { - pos_ += s.size(); - return true; - } - return false; - } - - /** @brief Skip whitespace characters at current position */ - void skipWhitespace() { - while (pos_ < input_.size() && std::isspace(input_[pos_])) - pos_++; - } - - /** @brief Tokenize content inside expression or statement tags */ - void tokenizeInside(std::vector &tokens, TokenType end_type) { - while (pos_ < input_.size()) { - skipWhitespace(); - if (pos_ >= input_.size()) - break; - - // Check for closing tag - if (end_type == TokenType::EXPRESSION_END) { - if (pos_ + 1 < input_.size() && input_[pos_] == '-' && - input_[pos_ + 1] == '}' && pos_ + 2 < input_.size() && - input_[pos_ + 2] == '}') { - pos_ += 3; - Token end; - end.type = end_type; - end.strip_after = true; - tokens.push_back(end); - return; - } - if (match("}}")) { - Token end; - end.type = end_type; - tokens.push_back(end); - return; - } - } else if (end_type == TokenType::STATEMENT_END) { - if (pos_ + 1 < input_.size() && input_[pos_] == '-' && - input_[pos_ + 1] == '%' && pos_ + 2 < input_.size() && - input_[pos_ + 2] == '}') { - pos_ += 3; - Token end; - end.type = end_type; - end.strip_after = true; - tokens.push_back(end); - return; - } - if (match("%}")) { - Token end; - end.type = end_type; - tokens.push_back(end); - return; - } - } - - // String literal - if (input_[pos_] == '\'' || input_[pos_] == '"') { - tokens.push_back(readString()); - continue; - } - - // Number - if (std::isdigit(input_[pos_])) { - tokens.push_back(readNumber()); - continue; - } - - // Identifier or keyword - if (std::isalpha(input_[pos_]) || input_[pos_] == '_') { - tokens.push_back(readIdentifier()); - continue; - } - - // Operators and punctuation - Token t; - switch (input_[pos_]) { - case '.': - t.type = TokenType::DOT; - t.value = "."; - pos_++; - break; - case '[': - t.type = TokenType::LBRACKET; - t.value = "["; - pos_++; - break; - case ']': - t.type = TokenType::RBRACKET; - t.value = "]"; - pos_++; - break; - case '(': - t.type = TokenType::LPAREN; - t.value = "("; - pos_++; - break; - case ')': - t.type = TokenType::RPAREN; - t.value = ")"; - pos_++; - break; - case '+': - t.type = TokenType::PLUS; - t.value = "+"; - pos_++; - break; - case '-': - t.type = TokenType::MINUS; - t.value = "-"; - pos_++; - break; - case '%': - t.type = TokenType::PERCENT; - t.value = "%"; - pos_++; - break; - case '|': - t.type = TokenType::PIPE; - t.value = "|"; - pos_++; - break; - case ',': - t.type = TokenType::COMMA; - t.value = ","; - pos_++; - break; - case '=': - if (pos_ + 1 < input_.size() && input_[pos_ + 1] == '=') { - t.type = TokenType::EQ; - t.value = "=="; - pos_ += 2; - } else { - t.type = TokenType::ASSIGN; - t.value = "="; - pos_++; - } - break; - case '!': - if (pos_ + 1 < input_.size() && input_[pos_ + 1] == '=') { - t.type = TokenType::NEQ; - t.value = "!="; - pos_ += 2; - } else { - pos_++; - continue; - } - break; - case '~': - t.type = TokenType::TILDE; - t.value = "~"; - pos_++; - break; - case ':': - t.type = TokenType::COLON; - t.value = ":"; - pos_++; - break; - case '>': - if (pos_ + 1 < input_.size() && input_[pos_ + 1] == '=') { - t.type = TokenType::GTE; - t.value = ">="; - pos_ += 2; - } else { - t.type = TokenType::GT; - t.value = ">"; - pos_++; - } - break; - case '<': - if (pos_ + 1 < input_.size() && input_[pos_ + 1] == '=') { - t.type = TokenType::LTE; - t.value = "<="; - pos_ += 2; - } else { - t.type = TokenType::LT; - t.value = "<"; - pos_++; - } - break; - default: - pos_++; - continue; - } - tokens.push_back(t); - } - } - - /** @brief Read a string literal token */ - Token readString() { - char quote = input_[pos_++]; - std::string value; - while (pos_ < input_.size() && input_[pos_] != quote) { - if (input_[pos_] == '\\' && pos_ + 1 < input_.size()) { - pos_++; - switch (input_[pos_]) { - case 'n': - value += '\n'; - break; - case 't': - value += '\t'; - break; - case '\\': - value += '\\'; - break; - case '\'': - value += '\''; - break; - case '"': - value += '"'; - break; - default: - value += '\\'; - value += input_[pos_]; - break; - } - } else { - value += input_[pos_]; - } - pos_++; - } - if (pos_ < input_.size()) - pos_++; // skip closing quote - - Token t; - t.type = TokenType::STRING; - t.value = value; - return t; - } - - /** @brief Read a numeric literal token */ - Token readNumber() { - std::string value; - bool has_dot = false; - while (pos_ < input_.size() && - (std::isdigit(input_[pos_]) || input_[pos_] == '.')) { - if (input_[pos_] == '.') { - if (has_dot) - break; - has_dot = true; - } - value += input_[pos_++]; - } - Token t; - t.type = has_dot ? TokenType::FLOAT : TokenType::INTEGER; - t.value = value; - return t; - } - - /** @brief Read an identifier or keyword token */ - Token readIdentifier() { - std::string value; - while (pos_ < input_.size() && - (std::isalnum(input_[pos_]) || input_[pos_] == '_')) { - value += input_[pos_++]; - } - - Token t; - t.value = value; - - // Check for keywords - if (value == "if") - t.type = TokenType::IF; - else if (value == "elif") - t.type = TokenType::ELIF; - else if (value == "else") - t.type = TokenType::ELSE; - else if (value == "endif") - t.type = TokenType::ENDIF; - else if (value == "for") - t.type = TokenType::FOR; - else if (value == "in") - t.type = TokenType::IN; - else if (value == "endfor") - t.type = TokenType::ENDFOR; - else if (value == "set") - t.type = TokenType::SET; - else if (value == "not") - t.type = TokenType::NOT; - else if (value == "and") - t.type = TokenType::AND; - else if (value == "or") - t.type = TokenType::OR; - else if (value == "is") - t.type = TokenType::IS; - else if (value == "true" || value == "True") - t.type = TokenType::TRUE_LIT; - else if (value == "false" || value == "False") - t.type = TokenType::FALSE_LIT; - else if (value == "none" || value == "None") - t.type = TokenType::NONE_LIT; - else - t.type = TokenType::IDENTIFIER; - - return t; - } - - std::string input_; - size_t pos_; -}; - -// ============================================================================ -// AST Nodes -// ============================================================================ -/** @brief Base AST node for template parsing */ -struct ASTNode { - virtual ~ASTNode() = default; -}; - -using ASTNodePtr = std::shared_ptr; - -/** @brief Base expression AST node */ -struct ExprNode : ASTNode {}; -using ExprNodePtr = std::shared_ptr; - -/** @brief AST node for literal text output */ -struct TextNode : ASTNode { - std::string text; -}; - -/** @brief AST node for expression output ({{ expr }}) */ -struct OutputNode : ASTNode { - ExprNodePtr expr; - bool strip_before = false; - bool strip_after = false; -}; - -/** @brief Single branch of an if/elif/else block */ -struct IfBranch { - ExprNodePtr condition; // nullptr for else branch - std::vector body; -}; - -/** @brief AST node for if/elif/else conditional blocks */ -struct IfNode : ASTNode { - std::vector branches; - bool strip_before = false; - bool strip_after = false; -}; - -/** @brief AST node for for-loop iteration blocks */ -struct ForNode : ASTNode { - std::string var_name; - ExprNodePtr iterable; - std::vector body; - bool strip_before = false; - bool strip_after = false; -}; - -/** @brief AST node for variable assignment (set statement) */ -struct SetNode : ASTNode { - std::string var_name; - std::string attr_name; // for "set ns.attr = val" (empty if simple set) - ExprNodePtr value; - bool strip_before = false; - bool strip_after = false; -}; - -// Expression nodes -/** @brief Expression node for string literal values */ -struct StringLiteral : ExprNode { - std::string value; -}; - -/** @brief Expression node for integer literal values */ -struct IntegerLiteral : ExprNode { - int value; -}; - -/** @brief Expression node for boolean literal values */ -struct BoolLiteral : ExprNode { - bool value; -}; - -/** @brief Expression node for None/null literal values */ -struct NoneLiteral : ExprNode {}; - -/** @brief Expression node for variable references */ -struct VariableExpr : ExprNode { - std::string name; -}; - -/** @brief Expression node for attribute access (obj.attr) */ -struct AttributeExpr : ExprNode { - ExprNodePtr object; - std::string attribute; -}; - -/** @brief Expression node for index access (obj[key]) */ -struct IndexExpr : ExprNode { - ExprNodePtr object; - ExprNodePtr index; -}; - -/** @brief Expression node for binary operations (+, ==, and, etc.) */ -struct BinaryExpr : ExprNode { - std::string op; // "+", "==", "!=", "and", "or", "%" - ExprNodePtr left; - ExprNodePtr right; -}; - -/** @brief Expression node for unary operations (not) */ -struct UnaryExpr : ExprNode { - std::string op; // "not" - ExprNodePtr operand; -}; - -/** @brief Expression node for filter application (val | filter) */ -struct FilterExpr : ExprNode { - ExprNodePtr value; - std::string filter_name; -}; - -/** @brief Expression node for "is defined" test */ -struct IsDefinedExpr : ExprNode { - ExprNodePtr value; -}; - -/** @brief Expression node for function calls */ -struct FunctionCallExpr : ExprNode { - std::string name; - std::vector args; -}; - -/** @brief Expression node for method calls (obj.method()) */ -struct MethodCallExpr : ExprNode { - ExprNodePtr object; - std::string method; - std::vector args; -}; - -/** @brief Expression node for slice operations (obj[start:stop:step]) */ -struct SliceExpr : ExprNode { - ExprNodePtr object; - ExprNodePtr start; // nullable - ExprNodePtr stop; // nullable - ExprNodePtr step; // nullable -}; - -/** @brief Expression node for "in" containment test */ -struct ContainsExpr : ExprNode { - ExprNodePtr item; - ExprNodePtr container; -}; - -// ============================================================================ -// Parser -// ============================================================================ -/** @brief Parses token sequences into an AST for template rendering */ -class Parser { -public: - explicit Parser(const std::vector &tokens) : - tokens_(tokens), pos_(0) {} - - /** @brief Parse tokens into AST node list */ - std::vector parse() { - std::vector nodes; - parseBody(nodes, {}); - return nodes; - } - -private: - /** @brief Get the current token without advancing */ - const Token ¤t() const { return tokens_[pos_]; } - - /** @brief Advance to next token and return the current one */ - const Token &advance() { return tokens_[pos_++]; } - - /** @brief Peek at the next token without advancing */ - const Token &peek() const { return tokens_[pos_ + 1]; } - - /** @brief Check if current token matches the given type */ - bool check(TokenType type) const { return current().type == type; } - - /** @brief Match current token type and advance if matched */ - bool matchToken(TokenType type) { - if (check(type)) { - advance(); - return true; - } - return false; - } - - /** @brief Expect current token to match type, throw on mismatch */ - Token expect(TokenType type) { - if (!check(type)) { - throw std::runtime_error("ChatTemplate parser: unexpected token '" + - current().value + "', expected type " + - std::to_string(static_cast(type))); - } - return advance(); - } - - /** @brief Parse template body until a stop keyword is found */ - void parseBody(std::vector &nodes, - const std::vector &stop_keywords) { - while (pos_ < tokens_.size() && current().type != TokenType::END_OF_INPUT) { - if (current().type == TokenType::TEXT) { - auto node = std::make_shared(); - node->text = current().value; - nodes.push_back(node); - advance(); - } else if (current().type == TokenType::EXPRESSION_START) { - nodes.push_back(parseOutput()); - } else if (current().type == TokenType::STATEMENT_START) { - // Peek at the keyword after {% - size_t save = pos_; - advance(); // skip STATEMENT_START - - // Check if this is a stop keyword - bool is_stop = false; - for (auto sk : stop_keywords) { - if (check(sk)) { - is_stop = true; - break; - } - } - - if (is_stop) { - pos_ = save; // rewind - return; - } - - pos_ = save; // rewind - parseStatement(nodes); - } else { - advance(); // skip unexpected - } - } - } - - /** @brief Parse an output expression block ({{ expr }}) */ - ASTNodePtr parseOutput() { - auto node = std::make_shared(); - Token start = expect(TokenType::EXPRESSION_START); - node->strip_before = start.strip_before; - node->expr = parseExpression(); - Token end = expect(TokenType::EXPRESSION_END); - node->strip_after = end.strip_after; - return node; - } - - /** @brief Parse a statement block ({% ... %}) */ - void parseStatement(std::vector &nodes) { - Token start = expect(TokenType::STATEMENT_START); - bool strip_before = start.strip_before; - - if (check(TokenType::IF)) { - nodes.push_back(parseIf(strip_before)); - } else if (check(TokenType::FOR)) { - nodes.push_back(parseFor(strip_before)); - } else if (check(TokenType::SET)) { - nodes.push_back(parseSet(strip_before)); - } else { - // Unknown statement - skip to end - while (pos_ < tokens_.size() && - current().type != TokenType::STATEMENT_END) { - advance(); - } - if (check(TokenType::STATEMENT_END)) - advance(); - } - } - - /** @brief Parse an if/elif/else/endif block */ - ASTNodePtr parseIf(bool strip_before) { - auto node = std::make_shared(); - node->strip_before = strip_before; - - // Parse: if %} - expect(TokenType::IF); - IfBranch branch; - branch.condition = parseExpression(); - Token end = expect(TokenType::STATEMENT_END); - node->strip_after = end.strip_after; - - // Parse body until elif/else/endif - parseBody(branch.body, - {TokenType::ELIF, TokenType::ELSE, TokenType::ENDIF}); - node->branches.push_back(branch); - - // Parse elif/else branches - while (pos_ < tokens_.size() && - current().type == TokenType::STATEMENT_START) { - advance(); // skip {% - - if (check(TokenType::ELIF)) { - advance(); // skip elif - IfBranch elif_branch; - elif_branch.condition = parseExpression(); - expect(TokenType::STATEMENT_END); - parseBody(elif_branch.body, - {TokenType::ELIF, TokenType::ELSE, TokenType::ENDIF}); - node->branches.push_back(elif_branch); - } else if (check(TokenType::ELSE)) { - advance(); // skip else - expect(TokenType::STATEMENT_END); - IfBranch else_branch; - else_branch.condition = nullptr; // no condition = else - parseBody(else_branch.body, {TokenType::ENDIF}); - node->branches.push_back(else_branch); - } else if (check(TokenType::ENDIF)) { - advance(); // skip endif - expect(TokenType::STATEMENT_END); - break; - } else { - break; - } - } - - return node; - } - - /** @brief Parse a for/endfor loop block */ - ASTNodePtr parseFor(bool strip_before) { - auto node = std::make_shared(); - node->strip_before = strip_before; - - expect(TokenType::FOR); - node->var_name = expect(TokenType::IDENTIFIER).value; - expect(TokenType::IN); - node->iterable = parseExpression(); - Token end = expect(TokenType::STATEMENT_END); - node->strip_after = end.strip_after; - - parseBody(node->body, {TokenType::ENDFOR}); - - // Consume endfor - expect(TokenType::STATEMENT_START); - expect(TokenType::ENDFOR); - expect(TokenType::STATEMENT_END); - - return node; - } - - /** @brief Parse a set variable assignment statement */ - ASTNodePtr parseSet(bool strip_before) { - auto node = std::make_shared(); - node->strip_before = strip_before; - - expect(TokenType::SET); - node->var_name = expect(TokenType::IDENTIFIER).value; - - // Handle dotted assignment: "set ns.attr = val" - if (check(TokenType::DOT)) { - advance(); - node->attr_name = expect(TokenType::IDENTIFIER).value; - } - - expect(TokenType::ASSIGN); - node->value = parseExpression(); - Token end = expect(TokenType::STATEMENT_END); - node->strip_after = end.strip_after; - - return node; - } - - // Expression parsing with precedence - /** @brief Parse a complete expression with precedence */ - ExprNodePtr parseExpression() { return parseOr(); } - - /** @brief Parse OR boolean expression */ - ExprNodePtr parseOr() { - auto left = parseAnd(); - while (check(TokenType::OR)) { - advance(); - auto right = parseAnd(); - auto node = std::make_shared(); - node->op = "or"; - node->left = left; - node->right = right; - left = node; - } - return left; - } - - /** @brief Parse AND boolean expression */ - ExprNodePtr parseAnd() { - auto left = parseNot(); - while (check(TokenType::AND)) { - advance(); - auto right = parseNot(); - auto node = std::make_shared(); - node->op = "and"; - node->left = left; - node->right = right; - left = node; - } - return left; - } - - /** @brief Parse NOT unary boolean expression */ - ExprNodePtr parseNot() { - if (check(TokenType::NOT)) { - advance(); - auto node = std::make_shared(); - node->op = "not"; - node->operand = parseNot(); - return node; - } - return parseComparison(); - } - - /** @brief Parse comparison and "is" test expressions */ - ExprNodePtr parseComparison() { - auto left = parseContains(); - - if (check(TokenType::EQ) || check(TokenType::NEQ) || check(TokenType::GT) || - check(TokenType::LT) || check(TokenType::GTE) || - check(TokenType::LTE)) { - std::string op = advance().value; - auto right = parseContains(); - auto node = std::make_shared(); - node->op = op; - node->left = left; - node->right = right; - return node; - } - - // "is" tests: "is defined", "is not defined", "is string", "is false", etc. - if (check(TokenType::IS)) { - advance(); - bool negated = false; - if (check(TokenType::NOT)) { - negated = true; - advance(); - } - if (check(TokenType::IDENTIFIER)) { - std::string test_name = current().value; - advance(); - ExprNodePtr result; - if (test_name == "defined") { - auto node = std::make_shared(); - node->value = left; - result = node; - } else if (test_name == "string") { - // "is string" -> check if value is a string type - auto call = std::make_shared(); - call->name = "__is_string"; - call->args.push_back(left); - result = call; - } else if (test_name == "none") { - auto call = std::make_shared(); - call->name = "__is_none"; - call->args.push_back(left); - result = call; - } else if (test_name == "number") { - auto call = std::make_shared(); - call->name = "__is_number"; - call->args.push_back(left); - result = call; - } else { - // Fallback: treat unknown test as comparison - result = left; - } - if (negated) { - auto not_node = std::make_shared(); - not_node->op = "not"; - not_node->operand = result; - return not_node; - } - return result; - } - // "is true" / "is false" / "is none" (keyword forms) - if (check(TokenType::TRUE_LIT)) { - advance(); - auto node = std::make_shared(); - node->op = "=="; - node->left = left; - auto lit = std::make_shared(); - lit->value = true; - node->right = lit; - ExprNodePtr result = node; - if (negated) { - auto not_node = std::make_shared(); - not_node->op = "not"; - not_node->operand = result; - return not_node; - } - return result; - } - if (check(TokenType::FALSE_LIT)) { - advance(); - auto node = std::make_shared(); - node->op = "=="; - node->left = left; - auto lit = std::make_shared(); - lit->value = false; - node->right = lit; - ExprNodePtr result = node; - if (negated) { - auto not_node = std::make_shared(); - not_node->op = "not"; - not_node->operand = result; - return not_node; - } - return result; - } - if (check(TokenType::NONE_LIT)) { - advance(); - auto node = std::make_shared(); - node->op = "=="; - node->left = left; - node->right = std::make_shared(); - ExprNodePtr result = node; - if (negated) { - auto not_node = std::make_shared(); - not_node->op = "not"; - not_node->operand = result; - return not_node; - } - return result; - } - } - - return left; - } - - // "in" / "not in" containment - /** @brief Parse "in" and "not in" containment expressions */ - ExprNodePtr parseContains() { - auto left = parseAddition(); - - bool negated = false; - if (check(TokenType::NOT) && pos_ + 1 < tokens_.size() && - peek().type == TokenType::IN) { - advance(); // skip 'not' - negated = true; - } - - if (check(TokenType::IN)) { - advance(); - auto right = parseAddition(); - auto node = std::make_shared(); - node->item = left; - node->container = right; - if (negated) { - auto not_node = std::make_shared(); - not_node->op = "not"; - not_node->operand = node; - return not_node; - } - return node; - } - - return left; - } - - /** @brief Parse addition, subtraction, and tilde concat */ - ExprNodePtr parseAddition() { - auto left = parseModulo(); - while (check(TokenType::PLUS) || check(TokenType::MINUS) || - check(TokenType::TILDE)) { - std::string op = advance().value; - auto right = parseModulo(); - auto node = std::make_shared(); - node->op = op; - node->left = left; - node->right = right; - left = node; - } - return left; - } - - /** @brief Parse modulo arithmetic expression */ - ExprNodePtr parseModulo() { - auto left = parseFilter(); - while (check(TokenType::PERCENT)) { - advance(); - auto right = parseFilter(); - auto node = std::make_shared(); - node->op = "%"; - node->left = left; - node->right = right; - left = node; - } - return left; - } - - /** @brief Parse pipe filter expression (val | filter) */ - ExprNodePtr parseFilter() { - auto left = parsePostfix(); - while (check(TokenType::PIPE)) { - advance(); - std::string filter_name = expect(TokenType::IDENTIFIER).value; - auto node = std::make_shared(); - node->value = left; - node->filter_name = filter_name; - left = node; - } - return left; - } - - /** @brief Parse postfix operations (dot, index, method, slice) */ - ExprNodePtr parsePostfix() { - auto node = parsePrimary(); - while (true) { - if (check(TokenType::DOT)) { - advance(); - std::string attr = expect(TokenType::IDENTIFIER).value; - // Check for method call: obj.method(args) - if (check(TokenType::LPAREN)) { - advance(); - auto call = std::make_shared(); - call->object = node; - call->method = attr; - if (!check(TokenType::RPAREN)) { - call->args.push_back(parseExpression()); - while (check(TokenType::COMMA)) { - advance(); - call->args.push_back(parseExpression()); - } - } - expect(TokenType::RPAREN); - node = call; - } else { - auto access = std::make_shared(); - access->object = node; - access->attribute = attr; - node = access; - } - } else if (check(TokenType::LBRACKET)) { - advance(); - // Check for slice: obj[start:stop:step] or obj[::step] - if (check(TokenType::COLON)) { - // [:stop] or [::step] - auto slice = std::make_shared(); - slice->object = node; - slice->start = nullptr; - advance(); // skip first ':' - if (check(TokenType::COLON)) { - // [::step] - advance(); - slice->stop = nullptr; - if (!check(TokenType::RBRACKET)) { - slice->step = parseExpression(); - } - } else if (!check(TokenType::RBRACKET)) { - slice->stop = parseExpression(); - if (check(TokenType::COLON)) { - advance(); - if (!check(TokenType::RBRACKET)) { - slice->step = parseExpression(); - } - } - } - expect(TokenType::RBRACKET); - node = slice; - } else { - auto index = parseExpression(); - if (check(TokenType::COLON)) { - // [start:stop] or [start:stop:step] - advance(); - auto slice = std::make_shared(); - slice->object = node; - slice->start = index; - if (check(TokenType::COLON)) { - // [start::step] - advance(); - slice->stop = nullptr; - if (!check(TokenType::RBRACKET)) { - slice->step = parseExpression(); - } - } else if (!check(TokenType::RBRACKET)) { - slice->stop = parseExpression(); - if (check(TokenType::COLON)) { - advance(); - if (!check(TokenType::RBRACKET)) { - slice->step = parseExpression(); - } - } - } - expect(TokenType::RBRACKET); - node = slice; - } else { - expect(TokenType::RBRACKET); - auto access = std::make_shared(); - access->object = node; - access->index = index; - node = access; - } - } - } else { - break; - } - } - return node; - } - - /** @brief Parse primary expression (literals, variables, parens) */ - ExprNodePtr parsePrimary() { - // Unary minus - if (check(TokenType::MINUS)) { - advance(); - auto operand = parsePrimary(); - if (auto *intLit = dynamic_cast(operand.get())) { - intLit->value = -intLit->value; - return operand; - } - auto node = std::make_shared(); - node->op = "-"; - auto zero = std::make_shared(); - zero->value = 0; - node->left = zero; - node->right = operand; - return node; - } - if (check(TokenType::STRING)) { - auto node = std::make_shared(); - node->value = advance().value; - return node; - } - if (check(TokenType::INTEGER)) { - auto node = std::make_shared(); - node->value = std::stoi(advance().value); - return node; - } - if (check(TokenType::TRUE_LIT)) { - advance(); - auto node = std::make_shared(); - node->value = true; - return node; - } - if (check(TokenType::FALSE_LIT)) { - advance(); - auto node = std::make_shared(); - node->value = false; - return node; - } - if (check(TokenType::NONE_LIT)) { - advance(); - return std::make_shared(); - } - if (check(TokenType::IDENTIFIER)) { - std::string name = advance().value; - - // Check for function call - if (check(TokenType::LPAREN)) { - advance(); - // Special case: namespace(key=val, ...) with keyword arguments - if (name == "namespace") { - // Parse keyword arguments and build a JSON object initializer - // We create a FunctionCallExpr where args[0] is a JSON-like init - auto call = std::make_shared(); - call->name = name; - // Parse key=value pairs and store as string literal pairs - json init_pairs = json::object(); - while (!check(TokenType::RPAREN) && pos_ < tokens_.size()) { - if (check(TokenType::IDENTIFIER)) { - std::string key = advance().value; - if (check(TokenType::ASSIGN)) { - advance(); - // We can't fully evaluate here, so skip for now - // Just consume until comma or rparen - auto val_expr = parseExpression(); - // Store key for later reference - } - } - if (check(TokenType::COMMA)) - advance(); - else - break; - } - expect(TokenType::RPAREN); - return call; - } - auto call = std::make_shared(); - call->name = name; - if (!check(TokenType::RPAREN)) { - call->args.push_back(parseExpression()); - while (check(TokenType::COMMA)) { - advance(); - call->args.push_back(parseExpression()); - } - } - expect(TokenType::RPAREN); - return call; - } - - auto node = std::make_shared(); - node->name = name; - return node; - } - if (check(TokenType::LPAREN)) { - advance(); - auto expr = parseExpression(); - expect(TokenType::RPAREN); - return expr; - } - - // Fallback: return an empty string literal - return std::make_shared(); - } - - const std::vector &tokens_; - size_t pos_; -}; - -// ============================================================================ -// Evaluator -// ============================================================================ -/** @brief Evaluates AST nodes to produce rendered template output */ -class Evaluator { -public: - /** @brief Construct evaluator with template context variables */ - explicit Evaluator(const json &context) { scopes_.push_back(context); } - - /** @brief Evaluate AST node list and return rendered string */ - std::string evaluate(const std::vector &nodes) { - std::string result; - for (size_t i = 0; i < nodes.size(); ++i) { - std::string chunk = evalNode(nodes[i].get()); - - // Handle whitespace stripping - if (shouldStripBefore(nodes[i].get())) { - // Strip trailing whitespace from result - while (!result.empty() && - (result.back() == ' ' || result.back() == '\t' || - result.back() == '\n' || result.back() == '\r')) { - result.pop_back(); - } - } - - result += chunk; - - // Handle strip_after: strip leading whitespace of next text - if (shouldStripAfter(nodes[i].get()) && i + 1 < nodes.size()) { - auto *text = dynamic_cast(nodes[i + 1].get()); - if (text) { - size_t start = 0; - while (start < text->text.size() && - (text->text[start] == ' ' || text->text[start] == '\t' || - text->text[start] == '\n' || text->text[start] == '\r')) { - start++; - } - text->text = text->text.substr(start); - } - } - } - return result; - } - -private: - /** @brief Check if node has strip-before whitespace control */ - bool shouldStripBefore(ASTNode *node) { - if (auto *o = dynamic_cast(node)) - return o->strip_before; - if (auto *i = dynamic_cast(node)) - return i->strip_before; - if (auto *f = dynamic_cast(node)) - return f->strip_before; - if (auto *s = dynamic_cast(node)) - return s->strip_before; - return false; - } - - /** @brief Check if node has strip-after whitespace control */ - bool shouldStripAfter(ASTNode *node) { - if (auto *o = dynamic_cast(node)) - return o->strip_after; - if (auto *i = dynamic_cast(node)) - return i->strip_after; - if (auto *f = dynamic_cast(node)) - return f->strip_after; - if (auto *s = dynamic_cast(node)) - return s->strip_after; - return false; - } - - /** @brief Evaluate a single AST node and return its string result */ - std::string evalNode(ASTNode *node) { - if (auto *text = dynamic_cast(node)) { - return text->text; - } - if (auto *output = dynamic_cast(node)) { - json val = evalExpr(output->expr.get()); - return jsonToString(val); - } - if (auto *if_node = dynamic_cast(node)) { - return evalIf(if_node); - } - if (auto *for_node = dynamic_cast(node)) { - return evalFor(for_node); - } - if (auto *set_node = dynamic_cast(node)) { - json val = evalExpr(set_node->value.get()); - if (!set_node->attr_name.empty()) { - // Namespace attribute mutation: "set ns.attr = val" - // Find the namespace object and mutate it in place - for (auto it = scopes_.rbegin(); it != scopes_.rend(); ++it) { - if (it->contains(set_node->var_name)) { - (*it)[set_node->var_name][set_node->attr_name] = val; - break; - } - } - } else { - setVariable(set_node->var_name, val); - } - return ""; - } - return ""; - } - - /** @brief Evaluate an if/elif/else conditional node */ - std::string evalIf(IfNode *node) { - for (auto &branch : node->branches) { - if (!branch.condition || isTruthy(evalExpr(branch.condition.get()))) { - return evaluate(branch.body); - } - } - return ""; - } - - /** @brief Evaluate a for-loop node over an iterable */ - std::string evalFor(ForNode *node) { - json iterable = evalExpr(node->iterable.get()); - if (!iterable.is_array()) - return ""; - - std::string result; - size_t size = iterable.size(); - - for (size_t i = 0; i < size; ++i) { - // Push new scope with loop variable - json scope; - scope[node->var_name] = iterable[i]; - - // Loop context - json loop; - loop["index"] = static_cast(i + 1); - loop["index0"] = static_cast(i); - loop["first"] = (i == 0); - loop["last"] = (i == size - 1); - loop["length"] = static_cast(size); - scope["loop"] = loop; - - scopes_.push_back(scope); - result += evaluate(node->body); - scopes_.pop_back(); - } - - return result; - } - - /** @brief Evaluate an expression node and return JSON value */ - json evalExpr(ExprNode *node) { - if (auto *str = dynamic_cast(node)) { - return str->value; - } - if (auto *num = dynamic_cast(node)) { - return num->value; - } - if (auto *b = dynamic_cast(node)) { - return b->value; - } - if (dynamic_cast(node)) { - return nullptr; - } - if (auto *var = dynamic_cast(node)) { - return lookupVariable(var->name); - } - if (auto *attr = dynamic_cast(node)) { - json obj = evalExpr(attr->object.get()); - if (obj.is_object() && obj.contains(attr->attribute)) { - return obj[attr->attribute]; - } - return nullptr; - } - if (auto *idx = dynamic_cast(node)) { - json obj = evalExpr(idx->object.get()); - json index = evalExpr(idx->index.get()); - if (obj.is_array() && index.is_number_integer()) { - int i = index.get(); - int sz = static_cast(obj.size()); - if (i < 0) - i += sz; - if (i >= 0 && i < sz) - return obj[i]; - } else if (obj.is_object() && index.is_string()) { - std::string key = index.get(); - if (obj.contains(key)) - return obj[key]; - } - return nullptr; - } - if (auto *bin = dynamic_cast(node)) { - return evalBinary(bin); - } - if (auto *unary = dynamic_cast(node)) { - if (unary->op == "not") { - return !isTruthy(evalExpr(unary->operand.get())); - } - } - if (auto *filter = dynamic_cast(node)) { - json val = evalExpr(filter->value.get()); - if (filter->filter_name == "trim" && val.is_string()) { - std::string s = val.get(); - // Trim whitespace - size_t start = s.find_first_not_of(" \t\n\r"); - size_t end = s.find_last_not_of(" \t\n\r"); - if (start == std::string::npos) - return ""; - return s.substr(start, end - start + 1); - } - if (filter->filter_name == "length") { - if (val.is_array()) - return static_cast(val.size()); - if (val.is_string()) - return static_cast(val.get().size()); - return 0; - } - if (filter->filter_name == "tojson") { - return val.dump(); - } - return val; // unknown filter, passthrough - } - if (auto *def = dynamic_cast(node)) { - // Check if the variable exists in any scope - if (auto *var = dynamic_cast(def->value.get())) { - for (auto it = scopes_.rbegin(); it != scopes_.rend(); ++it) { - if (it->contains(var->name)) - return true; - } - return false; - } - // For attribute access, check if the parent exists and has the attr - if (auto *attr = dynamic_cast(def->value.get())) { - json obj = evalExpr(attr->object.get()); - return obj.is_object() && obj.contains(attr->attribute); - } - return false; - } - if (auto *call = dynamic_cast(node)) { - if (call->name == "raise_exception") { - std::string msg = "Template error"; - if (!call->args.empty()) { - json arg = evalExpr(call->args[0].get()); - if (arg.is_string()) - msg = arg.get(); - } - throw std::runtime_error("ChatTemplate: " + msg); - } - if (call->name == "namespace") { - // namespace() creates a mutable object that persists across scopes - // keyword args are not parsed at AST level, so we return empty object - // The actual initialization happens via {% set ns.attr = val %} - return json::object(); - } - // Type-checking built-in tests - if (call->name == "__is_string") { - if (!call->args.empty()) { - json val = evalExpr(call->args[0].get()); - return val.is_string(); - } - return false; - } - if (call->name == "__is_none") { - if (!call->args.empty()) { - json val = evalExpr(call->args[0].get()); - return val.is_null(); - } - return false; - } - if (call->name == "__is_number") { - if (!call->args.empty()) { - json val = evalExpr(call->args[0].get()); - return val.is_number(); - } - return false; - } - return nullptr; - } - if (auto *method = dynamic_cast(node)) { - return evalMethodCall(method); - } - if (auto *slice = dynamic_cast(node)) { - return evalSlice(slice); - } - if (auto *contains = dynamic_cast(node)) { - json item = evalExpr(contains->item.get()); - json container = evalExpr(contains->container.get()); - // String containment: 'x' in 'xyz' - if (item.is_string() && container.is_string()) { - return container.get().find(item.get()) != - std::string::npos; - } - // Array containment - if (container.is_array()) { - for (const auto &elem : container) { - if (elem == item) - return true; - } - return false; - } - // Object key containment - if (container.is_object() && item.is_string()) { - return container.contains(item.get()); - } - return false; - } - return nullptr; - } - - /** @brief Evaluate a method call on a string object */ - json evalMethodCall(MethodCallExpr *node) { - json obj = evalExpr(node->object.get()); - const std::string &method = node->method; - - if (obj.is_string()) { - std::string s = obj.get(); - - if (method == "startswith" && !node->args.empty()) { - json arg = evalExpr(node->args[0].get()); - if (arg.is_string()) { - std::string prefix = arg.get(); - return s.size() >= prefix.size() && - s.compare(0, prefix.size(), prefix) == 0; - } - return false; - } - if (method == "endswith" && !node->args.empty()) { - json arg = evalExpr(node->args[0].get()); - if (arg.is_string()) { - std::string suffix = arg.get(); - return s.size() >= suffix.size() && - s.compare(s.size() - suffix.size(), suffix.size(), suffix) == - 0; - } - return false; - } - if (method == "strip") { - std::string chars = " \t\n\r"; - if (!node->args.empty()) { - json arg = evalExpr(node->args[0].get()); - if (arg.is_string()) - chars = arg.get(); - } - size_t start = s.find_first_not_of(chars); - if (start == std::string::npos) - return std::string(""); - size_t end = s.find_last_not_of(chars); - return s.substr(start, end - start + 1); - } - if (method == "lstrip") { - std::string chars = " \t\n\r"; - if (!node->args.empty()) { - json arg = evalExpr(node->args[0].get()); - if (arg.is_string()) - chars = arg.get(); - } - size_t start = s.find_first_not_of(chars); - if (start == std::string::npos) - return std::string(""); - return s.substr(start); - } - if (method == "rstrip") { - std::string chars = " \t\n\r"; - if (!node->args.empty()) { - json arg = evalExpr(node->args[0].get()); - if (arg.is_string()) - chars = arg.get(); - } - size_t end = s.find_last_not_of(chars); - if (end == std::string::npos) - return std::string(""); - return s.substr(0, end + 1); - } - if (method == "split" && !node->args.empty()) { - json arg = evalExpr(node->args[0].get()); - if (arg.is_string()) { - std::string delimiter = arg.get(); - json result = json::array(); - size_t pos = 0; - size_t found; - while ((found = s.find(delimiter, pos)) != std::string::npos) { - result.push_back(s.substr(pos, found - pos)); - pos = found + delimiter.size(); - } - result.push_back(s.substr(pos)); - return result; - } - } - if (method == "upper") { - std::string result = s; - std::transform(result.begin(), result.end(), result.begin(), ::toupper); - return result; - } - if (method == "lower") { - std::string result = s; - std::transform(result.begin(), result.end(), result.begin(), ::tolower); - return result; - } - } - - return nullptr; - } - - /** @brief Evaluate an array slice expression */ - json evalSlice(SliceExpr *node) { - json obj = evalExpr(node->object.get()); - if (!obj.is_array()) - return json::array(); - - int size = static_cast(obj.size()); - int start = 0, stop = size, step = 1; - - if (node->start) - start = evalExpr(node->start.get()).get(); - if (node->stop) - stop = evalExpr(node->stop.get()).get(); - if (node->step) - step = evalExpr(node->step.get()).get(); - - // Handle negative indices - if (start < 0) - start = std::max(0, size + start); - if (stop < 0) - stop = std::max(0, size + stop); - - // Clamp - start = std::max(0, std::min(start, size)); - stop = std::max(0, std::min(stop, size)); - - json result = json::array(); - if (step > 0) { - for (int i = start; i < stop; i += step) - result.push_back(obj[i]); - } else if (step < 0) { - // Reverse iteration: e.g., [::-1] - if (!node->start) - start = size - 1; - if (!node->stop) - stop = -1; - else if (stop < 0) - stop = std::max(-1, size + stop); - // Re-clamp for reverse - if (start >= size) - start = size - 1; - for (int i = start; i > stop; i += step) { - if (i >= 0 && i < size) - result.push_back(obj[i]); - } - } - - return result; - } - - /** @brief Evaluate a binary operation expression */ - json evalBinary(BinaryExpr *node) { - json left = evalExpr(node->left.get()); - json right = evalExpr(node->right.get()); - - if (node->op == "+" || node->op == "~") { - if (node->op == "~") { - // Tilde always does string concat - return jsonToString(left) + jsonToString(right); - } - if (left.is_string() && right.is_string()) { - return left.get() + right.get(); - } - if (left.is_number() && right.is_number()) { - return left.get() + right.get(); - } - // String + non-string: convert to string - return jsonToString(left) + jsonToString(right); - } - if (node->op == "-") { - if (left.is_number() && right.is_number()) { - return left.get() - right.get(); - } - return 0; - } - if (node->op == "==") { - return left == right; - } - if (node->op == "!=") { - return left != right; - } - if (node->op == ">") { - if (left.is_number() && right.is_number()) - return left.get() > right.get(); - return false; - } - if (node->op == "<") { - if (left.is_number() && right.is_number()) - return left.get() < right.get(); - return false; - } - if (node->op == ">=") { - if (left.is_number() && right.is_number()) - return left.get() >= right.get(); - return false; - } - if (node->op == "<=") { - if (left.is_number() && right.is_number()) - return left.get() <= right.get(); - return false; - } - if (node->op == "%") { - if (left.is_number_integer() && right.is_number_integer()) { - int r = right.get(); - if (r != 0) - return left.get() % r; - } - return 0; - } - if (node->op == "and") { - return isTruthy(left) && isTruthy(right); - } - if (node->op == "or") { - return isTruthy(left) || isTruthy(right); - } - return nullptr; - } - - /** @brief Check if a JSON value is truthy */ - bool isTruthy(const json &val) { - if (val.is_null()) - return false; - if (val.is_boolean()) - return val.get(); - if (val.is_number_integer()) - return val.get() != 0; - if (val.is_string()) - return !val.get().empty(); - if (val.is_array()) - return !val.empty(); - if (val.is_object()) - return !val.empty(); - return false; - } - - /** @brief Convert a JSON value to its string representation */ - std::string jsonToString(const json &val) { - if (val.is_string()) - return val.get(); - if (val.is_null()) - return ""; - if (val.is_boolean()) - return val.get() ? "True" : "False"; - if (val.is_number_integer()) - return std::to_string(val.get()); - if (val.is_number_float()) - return std::to_string(val.get()); - return val.dump(); - } - - /** @brief Look up a variable name in scope chain */ - json lookupVariable(const std::string &name) { - // Search from innermost scope to outermost - for (auto it = scopes_.rbegin(); it != scopes_.rend(); ++it) { - if (it->contains(name)) - return (*it)[name]; - } - return nullptr; - } - - /** @brief Set a variable in the current innermost scope */ - void setVariable(const std::string &name, const json &value) { - // Set in the current (innermost) scope - if (!scopes_.empty()) { - scopes_.back()[name] = value; - } - } - - std::vector scopes_; -}; - -// ============================================================================ -// ChatTemplate Implementation -// ============================================================================ - -ChatTemplate::ChatTemplate() : available_(false) {} - -ChatTemplate ChatTemplate::fromFile(const std::string &tokenizer_config_path) { - ChatTemplate tmpl; - - std::ifstream file(tokenizer_config_path); - if (!file.is_open()) { - std::cerr << "[ChatTemplate] Warning: cannot open " << tokenizer_config_path - << std::endl; - return tmpl; - } - - json config; - try { - file >> config; - } catch (const json::parse_error &e) { - std::cerr << "[ChatTemplate] Warning: JSON parse error in " - << tokenizer_config_path << ": " << e.what() << std::endl; - return tmpl; - } - - // Extract chat_template - if (config.contains("chat_template")) { - if (config["chat_template"].is_string()) { - tmpl.template_str_ = config["chat_template"].get(); - } else if (config["chat_template"].is_array()) { - // Some models have an array of templates; use the first one - for (const auto &entry : config["chat_template"]) { - if (entry.is_object() && entry.contains("template")) { - tmpl.template_str_ = entry["template"].get(); - break; - } - } - } - } - - if (tmpl.template_str_.empty()) { - std::cerr << "[ChatTemplate] Warning: no 'chat_template' field found in " - << tokenizer_config_path << std::endl; - return tmpl; - } - - // Extract bos_token (can be string or object with "content" field) - if (config.contains("bos_token")) { - if (config["bos_token"].is_string()) { - tmpl.bos_token_ = config["bos_token"].get(); - } else if (config["bos_token"].is_object() && - config["bos_token"].contains("content")) { - tmpl.bos_token_ = config["bos_token"]["content"].get(); - } - } - - // Extract eos_token - if (config.contains("eos_token")) { - if (config["eos_token"].is_string()) { - tmpl.eos_token_ = config["eos_token"].get(); - } else if (config["eos_token"].is_object() && - config["eos_token"].contains("content")) { - tmpl.eos_token_ = config["eos_token"]["content"].get(); - } - } - - tmpl.available_ = true; - return tmpl; -} - -std::string ChatTemplate::apply(const std::vector &messages, - bool add_generation_prompt) const { - if (!available_) - return ""; - - // Build context - json context; - json msgs = json::array(); - for (const auto &msg : messages) { - json m; - m["role"] = msg.role; - m["content"] = msg.content; - msgs.push_back(m); - } - context["messages"] = msgs; - context["bos_token"] = bos_token_; - context["eos_token"] = eos_token_; - context["add_generation_prompt"] = add_generation_prompt; - - return render(template_str_, context); -} - -std::string ChatTemplate::apply(const std::string &user_input, - bool add_generation_prompt) const { - std::vector messages = {{"user", user_input}}; - return apply(messages, add_generation_prompt); -} - -bool ChatTemplate::isAvailable() const { return available_; } - -std::string ChatTemplate::getBosToken() const { return bos_token_; } - -std::string ChatTemplate::getEosToken() const { return eos_token_; } - -std::string ChatTemplate::render(const std::string &tmpl, - const json &context) const { - try { - Lexer lexer(tmpl); - auto tokens = lexer.tokenize(); - - Parser parser(tokens); - auto ast = parser.parse(); - - Evaluator evaluator(context); - return evaluator.evaluate(ast); - } catch (const std::exception &e) { - std::cerr << "[ChatTemplate] Render error: " << e.what() << std::endl; - return ""; - } -} - -} // namespace quick_dot_ai diff --git a/chat_template.h b/chat_template.h deleted file mode 100644 index 40b20603..00000000 --- a/chat_template.h +++ /dev/null @@ -1,106 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -/** - * Copyright (C) 2026 Samsung Electronics Co., Ltd. All Rights Reserved. - * - * @file chat_template.h - * @date 10 Apr 2026 - * @brief Chat template support using tokenizer_config.json - * @see https://github.com/nntrainer/Quick.AI - * @author Eunju Yang - * @bug No known bugs except for NYI items - */ - -#ifndef __CHAT_TEMPLATE_H__ -#define __CHAT_TEMPLATE_H__ - -#include -#include - -#include "json.hpp" - -namespace quick_dot_ai { - -using json = nlohmann::json; - -/** - * @brief Chat message structure for multi-turn conversations - */ -struct ChatMessage { - std::string role; // "system", "user", "assistant" - std::string content; // message content -}; - -/** - * @brief Chat template class that reads and applies HuggingFace chat templates - * - * Loads chat_template from tokenizer_config.json and renders it using a - * minimal Jinja2 subset renderer. Supports common constructs used in - * HuggingFace chat templates: for loops, if/elif/else, variable access, - * string operations, loop variables, and filters. - */ -class ChatTemplate { -public: - /** - * @brief Default constructor (no template loaded) - */ - ChatTemplate(); - - /** - * @brief Load chat template from tokenizer_config.json - * @param tokenizer_config_path Path to tokenizer_config.json - * @return ChatTemplate instance - */ - static ChatTemplate fromFile(const std::string &tokenizer_config_path); - - /** - * @brief Apply template to multi-turn messages - * @param messages Vector of ChatMessage (role + content) - * @param add_generation_prompt Whether to add generation prompt at end - * @return Formatted prompt string - */ - std::string apply(const std::vector &messages, - bool add_generation_prompt = true) const; - - /** - * @brief Apply template to a single user input (convenience) - * @param user_input Raw user input string - * @param add_generation_prompt Whether to add generation prompt at end - * @return Formatted prompt string - */ - std::string apply(const std::string &user_input, - bool add_generation_prompt = true) const; - - /** - * @brief Check if a chat template is loaded and available - * @return true if template is available - */ - bool isAvailable() const; - - /** - * @brief Get BOS token - */ - std::string getBosToken() const; - - /** - * @brief Get EOS token - */ - std::string getEosToken() const; - -private: - std::string template_str_; - std::string bos_token_; - std::string eos_token_; - bool available_ = false; - - /** - * @brief Render a Jinja2 template with the given context - * @param tmpl Jinja2 template string - * @param context JSON object with template variables - * @return Rendered string - */ - std::string render(const std::string &tmpl, const json &context) const; -}; - -} // namespace quick_dot_ai - -#endif // __CHAT_TEMPLATE_H__ diff --git a/cross/android-aarch64.cross.in b/cross/android-aarch64.cross.in new file mode 100644 index 00000000..c8ef5865 --- /dev/null +++ b/cross/android-aarch64.cross.in @@ -0,0 +1,21 @@ +# Meson cross file for Android aarch64 (arm64-v8a) +# Generated from template by build.sh β€” do not edit the generated copy. +# @ANDROID_NDK@ is substituted with the ANDROID_NDK environment variable. + +[binaries] +c = '@ANDROID_NDK@/toolchains/llvm/prebuilt/linux-x86_64/bin/aarch64-linux-android29-clang' +cpp = '@ANDROID_NDK@/toolchains/llvm/prebuilt/linux-x86_64/bin/aarch64-linux-android29-clang++' +ar = '@ANDROID_NDK@/toolchains/llvm/prebuilt/linux-x86_64/bin/llvm-ar' +strip = '@ANDROID_NDK@/toolchains/llvm/prebuilt/linux-x86_64/bin/llvm-strip' + +[built-in options] +c_args = ['-march=armv8.2-a+fp16+dotprod+i8mm', '-mcpu=cortex-a53', '-mtune=cortex-a76', '-O3', '-ffast-math'] +cpp_args = ['-march=armv8.2-a+fp16+dotprod+i8mm', '-mcpu=cortex-a53', '-mtune=cortex-a76', '-O3', '-ffast-math', '-frtti', '-fexceptions'] +c_link_args = ['-fexceptions', '-fopenmp', '-static-openmp'] +cpp_link_args = ['-fexceptions', '-fopenmp', '-static-openmp'] + +[host_machine] +system = 'android' +cpu_family = 'aarch64' +cpu = 'armv8' +endian = 'little' diff --git a/docs/Architecture.md b/docs/Architecture.md new file mode 100644 index 00000000..e094ea45 --- /dev/null +++ b/docs/Architecture.md @@ -0,0 +1,162 @@ +# Quick.AI Native Architecture πŸ›οΈ + +Quick.AI extends nntrainer's CausalLM application with custom model +implementations, QNN support, XGrammar structured generation, and a deployable +C API. + +## 🧱 Native Layers + +```text +nntrainer/Applications/CausalLM + β”œβ”€β”€ main.cpp, Factory, tokenizer, ChatTemplate + └── base CausalLM/Transformer implementations + +Quick.AI + β”œβ”€β”€ src/models/ # self-registering model implementations + β”œβ”€β”€ qnn/ # Android QNN context and SDK wrappers + β”œβ”€β”€ src/xgrammar/ # XGrammar manager/wrapper + └── api/quick_dot_ai_api.* # handle-based deployment API +``` + +The native build produces these main artifacts: + +| Artifact | Built from | Purpose | +|---|---|---| +| `builddir_*/src/quick_dot_ai` | nntrainer `main.cpp` + Quick.AI static model archive | Standalone runner | +| `builddir_*/src/libquick_dot_ai.so` | Quick.AI model extension objects | `LD_PRELOAD` plugin mode | +| `builddir_*/api/libquick_dot_ai_api.so` | `api/quick_dot_ai_api.cpp` + model deps | Public C API for apps/JNI | +| `builddir_android/qnn/libqnn_context.so` | `qnn/` | QNN context plugin, Android QNN builds | + +## 🧩 Self Registration + +Model implementations register themselves with nntrainer's CausalLM factory +before `main()` or API load-time execution: + +```cpp +__attribute__((constructor)) static void register_my_models() { + causallm::Factory::Instance().registerModel( + "MyModelForCausalLM", + [](causallm::json cfg, causallm::json generation_cfg, + causallm::json nntr_cfg) { + return std::make_unique( + cfg, generation_cfg, nntr_cfg); + }); +} +``` + +This keeps Quick.AI model additions outside the nntrainer submodule and makes +model implementations independently addable under `src/models/`. + +## 🧠 C API Runtime + +The public API is `api/quick_dot_ai_api.h`. + +- Legacy single-model functions still exist for compatibility. +- New work should use `CausalLmHandle`. +- Each handle owns its model instances, output buffer, metrics, and mutex. +- Different handles can be loaded independently. +- Streaming APIs are synchronous calls that emit token deltas through a callback. + +Important entry points: + +| API | Purpose | +|---|---| +| `loadModelHandle()` | Load one model handle | +| `runModelHandleStreaming()` | Stream a raw prompt | +| `runModelHandleWithMessagesStreaming()` | Stream OpenAI-style messages | +| `runModelHandleWithJsonStreaming()` | Stream full OpenAI JSON requests | +| `runModelHandleWithTool()` | Run XGrammar-constrained structured generation | +| `runMultimodalHandle*()` | Run image + text paths when supported by the handle | +| `cancelModelHandle()` | Request cooperative cancellation | +| `destroyModelHandle()` | Release handle resources | +| `unloadModelHandle()` | Unload model (handle remains valid) | +| `getPerformanceMetricsHandle()` | Get per-handle performance metrics | +| `saveQnnKvCacheHandle()` | Save QNN KV cache | +| `loadQnnKvCacheHandle()` | Load QNN KV cache | +| `resetQnnKvCacheHandle()` | Reset QNN KV cache | + +## Model Registry + +The API layer maintains a **string-keyed self-registering model descriptor +catalog** separate from the nntrainer CausalLM factory. + +### Self-registration + +Each model descriptor translation unit (`src/model_descriptors_.cpp`) +declares a `quick_dot_ai::ModelDescriptor` struct and registers it at load +time: + +```cpp +static quick_dot_ai::ModelDescriptor desc = { + .id = "qwen3-0.6b", + .family = "qwen3-0.6b", + .display_name = "Qwen3 0.6B", + .runtime = 0, // 0 = NATIVE, 1 = LITERT + .backend_mask = /* CPU|GPU bitmask */, + .capabilities = /* STREAMING|TOOL_USE bitmask */, + .config_name = "qwen3_0_6b", + .arch_string = "Qwen3ForCausalLM", +}; + +__attribute__((constructor)) static void register_descriptors() { + quick_dot_ai::register_model_descriptor(&desc); +} +``` + +This runs before `main()` / API first-call, adding the descriptor to a +process-global registry. No central switch statement or header change is +needed β€” just link in the TU. + +### Catalog API + +| Function | Purpose | +|---|---| +| `loadModelHandleByName(backend, model_id, quant, lib_dir, base_path, out)` | Preferred load path β€” routes through the descriptor registry | +| `getModelCatalogJson()` | Returns a JSON array of all registered descriptors | + +`getModelCatalogJson()` returns a JSON array in this shape: + +```json +[ + { + "id": "qwen3-0.6b", + "family": "qwen3-0.6b", + "display_name": "Qwen3 0.6B", + "runtime": 0, + "backend_mask": 3, + "capabilities": 9, + "config_name": "qwen3_0_6b", + "arch_string": "Qwen3ForCausalLM" + } +] +``` + +### ModelType enum status + +The `CAUSAL_LM_MODEL_*` C enum is a **deprecated compatibility shim**. +Ordinals are preserved for ABI compatibility, so the values are not +contiguous. All new code should use string model ids and +`loadModelHandleByName()`. + +## 🧰 Build System + +The root `build.sh` prepares nntrainer, tokenizer assets, Android cross files, +and Meson options. `src/` is always built. `api/`, `api-app/`, and `qnn/` are +enabled by build flags: + +```bash +./build.sh +./build.sh --platform=android --enable-qnn +./build.sh --target=src,api +``` + +Meson options live in `meson_options.txt`. + +## πŸ“Ž Related Docs + +- [Main README](../README.md) +- [C API Reference](../api/README.md) +- [Chat and OpenAI Usage Examples](ChatAndOpenAIUsage.md) +- [Chat Templates](ChatTemplate.md) +- [XGrammar Reference](XGrammarReference.md) +- [QNN Context Guide](../qnn/README.md) diff --git a/docs/ChatAndOpenAIUsage.md b/docs/ChatAndOpenAIUsage.md new file mode 100644 index 00000000..9af6add0 --- /dev/null +++ b/docs/ChatAndOpenAIUsage.md @@ -0,0 +1,338 @@ +# Chat and OpenAI Usage Examples πŸ’¬ + +This is the canonical place for Quick.AI chat, OpenAI-style request, and +structured-output examples. Keep API signatures in the API reference docs; keep +end-to-end usage examples here. + +## 🧭 Choose the Right API + +| Goal | Android API | Native C API | Notes | +|---|---|---|---| +| Chat tab style conversation | `openChatSession()` β†’ `runChatModelHandleStreaming()` β†’ `closeChatSession()` | session is Android-facing | Keeps backend-managed conversation state. | +| OpenAI-style messages | `runModelHandleWithMessagesStreaming()` | `runModelHandleWithMessagesStreaming()` | Best for models that need message-template formatting but not full JSON fields. | +| Full OpenAI JSON request | `runModelHandleWithJsonStreaming()` | `runModelHandleWithJsonStreaming()` | Preserves `messages`, `tools`, legacy `functions`, and template kwargs understood by the chat template. | +| Hard schema-constrained output | not exposed by the AAR today | `runModelHandleWithTool()` | Uses XGrammar token masking. This is different from OpenAI JSON `tools`. | + +`tools` in an OpenAI JSON request are passed into the model's chat template. +They help the model see tool metadata and produce a tool-call shaped answer. +XGrammar is stricter: it masks invalid tokens during decoding so the output +matches a JSON schema. + +## βš™οΈ Common Setup + +Every path needs a loaded engine or handle before inference. + +```kotlin +// Resolve a descriptor from the catalog and let the factory pick the engine. +val descriptor = ModelCatalog.byId(ModelIds.GEMMA4_E2B_QNN) ?: return +val engine: QuickDotAI = createEngine(applicationContext, descriptor) + +val loaded = engine.load( + LoadModelRequest( + modelId = descriptor.id, + backend = BackendType.NPU, + quantization = QuantizationType.W4A32, + nativeLibDir = applicationInfo.nativeLibraryDir, + modelBasePath = "/sdcard/Download/aistudio-mobile/models" + ) +) +``` + +Streaming methods report deltas through `StreamSink`. + +```kotlin +val sink = object : StreamSink { + override fun onDelta(text: String) { + outputView.append(text) + } + + override fun onReasoningDelta(text: String) { + // Called for thinking/reasoning model tokens + reasoningView.append(text) + } + + override fun onDone() { + setStatus("Done.") + } + + override fun onError(error: QuickAiError, message: String?) { + setStatus("Failed: [${error.name}] ${message.orEmpty()}") + } +} +``` + +Message and JSON APIs require a usable chat template. The native loader checks +the model directory for `chat_template.jinja` or +`tokenizer_config.json.chat_template`; see [Chat Templates](ChatTemplate.md). + +## πŸ’¬ Chat Tab Pattern + +The Chat tab is session based. It opens one active session per engine, sends +turns into that session, and closes it when the user leaves or reloads. + +```kotlin +val config = QuickAiChatSessionConfig( + systemInstruction = "You are concise.", + sampling = QuickAiChatSamplingConfig( + temperature = 0.7, + topK = 40, + topP = 0.9, + seed = 42, + minP = null, // Min-P sampling + maxTokens = null // Maximum generation tokens + ), + chatTemplateKwargs = QuickAiChatTemplateKwargs(enableThinking = false) +) + +when (val opened = engine.openChatSession(config)) { + is BackendResult.Ok -> { + val sessionId = opened.value + setStatus("Chat session opened: ${sessionId.take(8)}") + } + is BackendResult.Err -> { + setStatus("Chat open failed: [${opened.error.name}] ${opened.message.orEmpty()}") + } +} +``` + +Send text through the active session: + +```kotlin +when (val result = engine.runChatModelHandleStreaming("Explain KV cache.", sink)) { + is BackendResult.Ok -> { + val metrics = result.value.metrics + setStatus("Chat done. ${metrics?.totalDurationMs?.toLong() ?: "?"} ms") + } + is BackendResult.Err -> { + setStatus("Chat failed: [${result.error.name}] ${result.message.orEmpty()}") + } +} +``` + +Reset or close the session when needed: + +```kotlin +engine.chatRebuild(emptyList()) +engine.closeChatSession() +``` + +For multimodal chat input, build `PromptPart` values. The sample app currently +routes image-attached chat turns through the direct multimodal handle path. + +```kotlin +val parts = listOf( + PromptPart.ImageBytes(imageBytes), + PromptPart.Text("Describe the image.") +) + +engine.runMultimodalHandleStreaming(parts, sink) +``` + +## πŸ“‘ OpenAI Tab Pattern + +The OpenAI tab accepts OpenAI-style JSON from the UI and then routes by the +selected model's **capabilities**. Models that advertise the `MESSAGES_API` +capability need model-specific message formatting, so they use the messages +API; other native models use full JSON streaming. + +```kotlin +private fun usesMessagesApi(d: ModelDescriptor) = + Capability.MESSAGES_API in d.capabilities + +if (usesMessagesApi(descriptor)) { + val messages = parseOpenAIMessages(jsonText) ?: return + engine.runModelHandleWithMessagesStreaming(messages, sink) +} else { + engine.runModelHandleWithJsonStreaming(jsonText, sink) +} +``` + +Use messages streaming when the request is just an ordered chat history: + +```kotlin +val messages = listOf( + QuickAiChatMessage( + role = QuickAiChatRole.SYSTEM, + parts = listOf(PromptPart.Text("You are concise.")) + ), + QuickAiChatMessage( + role = QuickAiChatRole.USER, + parts = listOf(PromptPart.Text("Write a one-line haiku.")) + ) +) + +engine.runModelHandleWithMessagesStreaming(messages, sink) +``` + +Use JSON streaming when you need the full OpenAI request shape: + +```kotlin +val jsonRequest = """ +{ + "messages": [ + {"role": "developer", "content": "You can call tools."}, + {"role": "user", "content": "Set an alarm for 7 AM."} + ], + "tools": [ + { + "type": "function", + "function": { + "name": "set_alarm", + "description": "Create an alarm.", + "parameters": { + "type": "object", + "properties": { + "time": {"type": "string"}, + "label": {"type": "string"} + }, + "required": ["time"] + } + } + } + ] +} +""".trimIndent() + +engine.runModelHandleWithJsonStreaming(jsonRequest, sink) +``` + +Legacy OpenAI `functions` are accepted by the chat-template renderer as raw +function schemas: + +```json +{ + "messages": [ + {"role": "user", "content": "Send a short status email."} + ], + "functions": [ + { + "name": "send_email", + "description": "Send email.", + "parameters": { + "type": "object", + "properties": { + "to": {"type": "string"}, + "body": {"type": "string"} + }, + "required": ["to", "body"] + } + } + ] +} +``` + +## 🧱 Native Call Flow + +The Android wrapper is thin. It converts Kotlin DTOs to native inputs and then +uses the C API. + +| User path | JNI/native path | Formatting step | Inference step | +|---|---|---|---| +| `runModelHandleWithMessagesStreaming()` | `QuickAiChatMessage[]` β†’ `CausalLMChatMessage[]` | `apply_chat_template_messages()` | `run_model_streaming_on_handle(..., input_already_formatted=true)` | +| `runModelHandleWithJsonStreaming()` | JSON string | `g_chat_template->apply(request)` | `run_model_streaming_on_handle(..., input_already_formatted=true)` | +| `runModelHandleWithTool()` | prompt + tool name/schema | `XGrammarManager` attaches grammar | `run_on_handle()` with grammar mask active | + +Native messages example: + +```c +CausalLMChatMessage messages[] = { + {.role = "system", .content = "You are concise."}, + {.role = "user", .content = "Hello!"} +}; + +ErrorCode err = runModelHandleWithMessagesStreaming( + handle, + messages, + 2, + true, + callback, + user_data); +``` + +Native JSON streaming example: + +```c +const char *request = + "{" + "\"messages\":[" + "{\"role\":\"user\",\"content\":\"Summarize Quick.AI.\"}" + "]" + "}"; + +ErrorCode err = runModelHandleWithJsonStreaming( + handle, + request, + callback, + user_data); +``` + +## 🧩 XGrammar Examples + +Use XGrammar when you need the generated text itself to obey a JSON schema. A +model directory can preload grammars from `Toolset.json`. + +```json +{ + "set_alarm": { + "type": "object", + "properties": { + "time": {"type": "string", "pattern": "^\\d{2}:\\d{2}$"}, + "label": {"type": "string"} + }, + "required": ["time"] + } +} +``` + +After `loadModelHandle()`, call the tool path by name: + +```c +const char *output = NULL; + +ErrorCode err = runModelHandleWithTool( + handle, + "Create an alarm for 07:00.", + &output, + "set_alarm", + NULL); +``` + +For an ad-hoc schema, pass the schema on first use: + +```c +const char *schema = + "{" + "\"type\":\"object\"," + "\"properties\":{\"answer\":{\"type\":\"string\"}}," + "\"required\":[\"answer\"]" + "}"; + +ErrorCode err = runModelHandleWithTool( + handle, + "Return the answer as JSON.", + &output, + "answer_schema", + schema); +``` + +The XGrammar manager compiles schemas once and can load/save +`Toolset.json.cache` for faster subsequent loads. See +[XGrammar Reference](XGrammarReference.md) for cache behavior and C++ manager +details. + +## βœ… Troubleshooting + +| Symptom | Likely cause | Fix | +|---|---|---| +| `runModelHandleWithJsonStreaming()` returns `CAUSAL_LM_ERROR_UNSUPPORTED` | The loaded model has no chat template cached. | Add `chat_template.jinja` or `tokenizer_config.json.chat_template` next to the model config. | +| JSON streaming returns `CAUSAL_LM_ERROR_INVALID_PARAMETER` | The request is not valid JSON or a required pointer is null. | Validate the JSON and ensure `messages` is non-empty for normal chat use. | +| OpenAI tab loses `tools` on `MESSAGES_API` models (e.g. `gemma4-e2b-qnn`, `gemma4`) | The sample routes models with the `MESSAGES_API` capability through messages streaming. | Use a model that supports full JSON streaming, or add a dedicated model-specific full JSON path. | +| `tools` are visible to the model but output is not schema-valid | OpenAI JSON `tools` only guide the chat template. | Use `runModelHandleWithTool()` and XGrammar for hard constraints. | +| Chat tab says no active session | `openChatSession()` has not succeeded or the session was closed. | Open a session first, then call `runChatModelHandleStreaming()`. | + +## πŸ“Ž Related Docs + +- [QuickDotAI AAR API](../Android/QuickDotAI/README.md) +- [C API Reference](../api/README.md) +- [Chat Templates](ChatTemplate.md) +- [XGrammar Reference](XGrammarReference.md) diff --git a/docs/ChatTemplate.md b/docs/ChatTemplate.md new file mode 100644 index 00000000..4c772772 --- /dev/null +++ b/docs/ChatTemplate.md @@ -0,0 +1,87 @@ +# Chat Templates πŸ’¬ + +Quick.AI uses chat templates to convert OpenAI-style chat requests into the +model-specific prompt strings expected by nntrainer/LiteRT-LM style models. +The behavior is intentionally close to Hugging Face +`tokenizer.apply_chat_template()`. + +## πŸ”Ž Template Discovery + +For a model directory, Quick.AI looks for a template in this order: + +1. `/chat_template.jinja` +2. `/tokenizer_config.json`, field `chat_template` +3. Built-in fallback formatting for selected architectures + +`tokenizer_config.json.chat_template` may be a string, an object of named +templates, or an array converted into named templates. When named templates are +available, `tool_use` is selected for requests containing tools; otherwise +`default` is preferred. + +Special tokens are loaded from `tokenizer_config.json` and +`special_tokens_map.json` when present. + +## 🧱 Native API Integration + +Chat templates are used by these C API paths: + +| API | Input | +|---|---| +| `applyChatTemplate()` | `CausalLMChatMessage[]` | +| `runModelHandleStreaming()` | Raw prompt string (uses chat template when `g_use_chat_template` is true and input is not already formatted) | +| `runModelHandleWithMessages()` | `CausalLMChatMessage[]` | +| `runModelHandleWithMessagesStreaming()` | `CausalLMChatMessage[]` | +| `runModelHandleWithJsonStreaming()` | OpenAI-style JSON string | + +For JSON streaming, the loaded model must provide a usable chat template. If no +template is available, `runModelHandleWithJsonStreaming()` returns +`CAUSAL_LM_ERROR_UNSUPPORTED`. + +## πŸ“¦ OpenAI JSON Shape + +```json +{ + "messages": [ + { "role": "developer", "content": "You can call tools." }, + { "role": "user", "content": "Call mom." } + ], + "tools": [ + { + "type": "function", + "function": { + "name": "call", + "description": "Make a phone call.", + "parameters": { + "type": "object", + "properties": { + "name": { "type": "string" } + }, + "required": ["name"] + } + } + } + ] +} +``` + +Legacy OpenAI `functions` is accepted as an alias for raw function schemas. + +## πŸ§‘β€πŸ’» Usage Examples + +End-to-end Chat tab, OpenAI tab, native messages, and JSON streaming examples +live in [Chat and OpenAI Usage Examples](ChatAndOpenAIUsage.md). + +## ⚠️ Notes + +- `developer` role handling depends on the template and renderer options. +- `tool` and function-call message formatting depends on the model template. +- Native message APIs use text-only `role/content` pairs; Android + `QuickAiChatMessage` can also carry image parts for multimodal methods. +- Template files should live next to the model config files used by + `loadModelHandle()`. + +## πŸ“Ž Related Docs + +- [Chat and OpenAI Usage Examples](ChatAndOpenAIUsage.md) +- [C API Reference](../api/README.md) +- [QuickDotAI AAR API](../Android/QuickDotAI/README.md) diff --git a/docs/Guides.md b/docs/Guides.md new file mode 100644 index 00000000..cfd54139 --- /dev/null +++ b/docs/Guides.md @@ -0,0 +1,98 @@ +# Quick.AI Guides & Examples 🧭 + +This is the documentation hub for the current Quick.AI repository. Start with +the path that matches what you are building. + +## πŸš€ Quick Start by User Type + +### Android App Developer + +Use the `QuickDotAI` AAR directly from an Android app. The current Gradle build +contains `:QuickDotAI` and `:SampleTestAPP`. + +| Guide | What you get | +|---|---| +| [Chat and OpenAI Usage Examples](ChatAndOpenAIUsage.md) | Chat tab, OpenAI tab, JSON streaming, and XGrammar examples | +| [QuickDotAI AAR API](../Android/QuickDotAI/README.md) | Kotlin API, model loading, streaming, chat sessions | +| [Android Architecture](../Android/Architecture.md) | Current module layout and planned REST/service layer | +| [Android Native Async & Streaming](../Android/AsyncAndStreaming.md) | How JNI and the C streaming callback connect | + +For app-level examples, start with the usage guide and use the AAR API +reference for exact type definitions. + +### C/C++ Developer + +Use the handle-based C API directly from native applications. + +| Guide | What you get | +|---|---| +| [Chat and OpenAI Usage Examples](ChatAndOpenAIUsage.md) | Native messages, JSON streaming, and XGrammar examples | +| [C API Reference](../api/README.md) | Function signatures, enums, and error codes | +| [Build Options](../README.md#-building) | Meson flags and Android/x86 build commands | +| [Chat Templates](ChatTemplate.md) | `messages`, `tools`, and `functions` formatting | + +For native examples, start with the usage guide and use the C API reference for +full signatures and error codes. + +### Model Developer + +Extend Quick.AI with a new CausalLM architecture or QNN model. + +| Guide | What you get | +|---|---| +| [Custom Model Guide](../README.md#-how-to-create-a-custom-model) | Model registration and Meson wiring | +| [Native Architecture](Architecture.md) | Plugin system and build artifacts | +| [QNN Context Guide](../qnn/README.md) | QNN backend/context extension details | + +## Adding a New Model + +Adding a model requires only a new translation unit β€” no changes to the +`ModelType` enum, `loadModelHandle`, or UI code are needed. + +1. Create `src/model_descriptors_.cpp` with a `ModelDescriptor` struct + and an `__attribute__((constructor))` that calls + `quick_dot_ai::register_model_descriptor(&desc)`. The descriptor fields + include `id` (string), `family`, `display_name`, `runtime` (0=NATIVE or + 1=LITERT), `backend_mask`, `capabilities`, `config_name`, and + `arch_string`. +2. Add the new TU to `src/meson.build` so it is linked into + `libquick_dot_ai_api.so`. +3. The C API catalog (`getModelCatalogJson()`) and the Android `ModelCatalog` + singleton will automatically reflect the new model after the library is + rebuilt β€” no additional registration steps are required. +4. If the model needs a new nntrainer architecture, register it with the + CausalLM factory as described in the [Custom Model Guide](../README.md#-how-to-create-a-custom-model). + +See [Native Architecture](Architecture.md) for the full descriptor struct +layout and the `register_model_descriptor` call convention. + +## 🧩 Feature Guides + +| Feature | Guide | +|---|---| +| Chat/OpenAI usage examples | [Chat and OpenAI Usage Examples](ChatAndOpenAIUsage.md) | +| Structured output and tool calling | [XGrammar Reference](XGrammarReference.md) | +| OpenAI JSON request streaming | [Chat and OpenAI Usage Examples](ChatAndOpenAIUsage.md) | +| Chat templates | [Chat Templates](ChatTemplate.md) | +| QNN SDK setup | [How to Install QNN](HowToInstallQNN.md) | + +## πŸ“‹ API References + +- [C API](../api/README.md) +- [Android AAR API](../Android/QuickDotAI/README.md) +- [Kotlin DTO source](../Android/QuickDotAI/src/main/java/com/example/quickdotai/Types.kt) + +## πŸ—οΈ Architecture & Design + +| Document | Topic | +|---|---| +| [Native Architecture](Architecture.md) | Self-registration, model factory, native build outputs | +| [Android Architecture](../Android/Architecture.md) | Current AAR/sample modules and planned REST service | +| [Android Native Async & Streaming](../Android/AsyncAndStreaming.md) | C callback streaming through JNI | + +## πŸ”— Quick Links + +- [Main README](../README.md) +- [Chat and OpenAI Usage Examples](ChatAndOpenAIUsage.md) +- [QuickDotAI AAR](../Android/QuickDotAI/README.md) +- [C API Reference](../api/README.md) diff --git a/docs/HowToInstallQNN.md b/docs/HowToInstallQNN.md new file mode 100644 index 00000000..5829c035 --- /dev/null +++ b/docs/HowToInstallQNN.md @@ -0,0 +1,70 @@ +# How to Install QNN and Hexagon SDK βš™οΈ +:last update 2025-03-25: + +There are various ways to install qnn. +In this doc, we recommend you to use qpm-cli. +If you install the QNN and Hexagon SDK for NNtrainer usage, please follow the versions of: + +- QNN (a.k.a., Qualcomm Neural Processing SDK) version 2.31.0.250130 +- HexagonSDK version 5.5.2.0 + +## Prepare QPM CLI + + +1. Download qpm-cli + +> https://qpm.qualcomm.com/#/main/tools/details/QPM3 +2. Install qpm-cli + +> sudo apt-get install ./QualcommPackageManager3.3.0.117.0.Linux-x86.deb +- In order to use qpm-cli, you may need to login +``` +$ qpm-cli login +``` + +## Install QNN v 2.31 + +Install qnn version 2.31 + +``` +$ qpm-cli --product-list | grep neural + qualcomm_neural_processing_sdk + qualcomm_neural_processing_sdk_public +$ qpm-cli --install qualcomm_neural_processing_sdk --version 2.31.0.250130 +``` + +- Please be sure where the qnn is installed. +- It would be `/opt/qcom/aistack/qairt/2.31.0.250130` by default. + +## Install Hexagon SDK 5.5.2.0 + +``` +$ qpm-cli --license-activate hexagonsdk5.x +$ qpm-cli --install hexagonsdk5.x --version 5.5.2.0 +[Error] : Required dependency criteria not met. HexagonSDK6.x should be installed before installing Compute1.x +[Error] : Required version of component HexagonSDK6.x.Core is not installed on the machine +[Warning] : Compute1.x.1.12.0.Linux-x64.qik was not installed. Reason: ErrorProcessingComponents +[Info] : SUCCESS: Installed HexagonSDK5.x.Core at /local/mnt/workspace/Qualcomm/Hexagon_SDK/5.5.2.0 +``` + +### Trouble shooting + +1. Access to the path is denied + +Phenomenon: +``` +[Info] : Extracting files +[Fatal] : Access to the path '/local/mnt/workspace/Qualcomm/Hexagon_SDK/5.5.2.0' is denied. +[Error] : Installation failed with Exception +``` + +Trouble-shoot: +``` +$ mkdir -p /local/mnt/workspace/Qualcomm/ +$ sudo chmod 777 /local/mnt/workspace/Qualcomm/ +``` + +2. error while loading shared libraries: libtinfo.so.5: cannot open shared object file: No such file or directory +``` +sudo apt install libncurses5 +``` diff --git a/docs/XGrammarReference.md b/docs/XGrammarReference.md new file mode 100644 index 00000000..aff3196b --- /dev/null +++ b/docs/XGrammarReference.md @@ -0,0 +1,181 @@ +# XGrammar Reference 🧩 + +This document is the XGrammar reference for Quick.AI. It explains the runtime +components, toolset files, cache behavior, and native API contract. For +end-to-end Chat tab, OpenAI JSON, and XGrammar usage examples, start with +[Chat and OpenAI Usage Examples](ChatAndOpenAIUsage.md). + +## πŸ”Ž Overview + +[XGrammar](https://github.com/mlc-ai/xgrammar) provides constrained decoding: +Quick.AI masks invalid next tokens during sampling so generated text follows a +target grammar or JSON schema. + +Quick.AI currently uses XGrammar for: + +- JSON-schema-constrained output +- tool/function-call shaped output +- reusable grammars loaded from model-local toolset files + +This is separate from OpenAI JSON `tools`. OpenAI JSON `tools` are rendered +through the model's chat template; XGrammar enforces the output shape during +token selection. + +## 🧱 Architecture + +```text +XGrammarManager + - TokenizerInfo, shared per loaded model + - GrammarCompiler, shared per loaded model + - compiled_grammars_: tool_name -> XGrammar + +XGrammar + - CompiledGrammar + - GrammarMatcher + - bitmask of allowed next tokens + - (optional) own TokenizerInfo and GrammarCompiler + +Model sampler + - asks XGrammar for the next-token mask + - samples from the masked logits + - accepts the sampled token into the GrammarMatcher +``` + +| Component | Responsibility | +|---|---| +| `XGrammarManager` | Owns shared tokenizer/compiler state and cached tool grammars. | +| `TokenizerInfo` | Stores the model vocabulary in the format XGrammar needs. | +| `GrammarCompiler` | Compiles JSON schemas, EBNF grammars, or regex patterns. | +| `XGrammar` | Holds one compiled grammar and its matcher state. May also own `TokenizerInfo` and `GrammarCompiler` for independent (non-shared) usage. | +| `GrammarMatcher` | Tracks grammar progress and produces allowed-token masks. | + +## πŸ”„ Runtime Flow + +At model load time, Quick.AI initializes the XGrammar manager from the loaded +tokenizer. If the model directory contains `Toolset.json`, Quick.AI precompiles +the listed schemas and stores them by tool name. + +At inference time, `runModelHandleWithTool()`: + +1. Looks up the requested tool name in `XGrammarManager`. +2. Registers `tool_schema` dynamically if the tool is missing and a schema was + provided. +3. Attaches the resulting `XGrammar` instance to the model. +4. Runs inference while the sampler applies grammar masks. +5. Resets the model grammar state after generation. + +## πŸ“¦ Toolset Files + +Place `Toolset.json` next to the model files used by `loadModelHandle()`. + +```json +{ + "tool_name": { + "type": "object", + "properties": { + "field": { "type": "string" } + }, + "required": ["field"] + } +} +``` + +The top-level object maps a tool name to a JSON Schema. Tool names are later +passed to `runModelHandleWithTool()`. + +## 🧠 Native API + +```c +ErrorCode runModelHandleWithTool(CausalLmHandle handle, + const char *inputTextPrompt, + const char **outputText, + const char *tool_name, + const char *tool_schema); +``` + +| Parameter | Meaning | +|---|---| +| `handle` | Loaded model handle. | +| `inputTextPrompt` | Prompt to run with grammar constraints. | +| `outputText` | Receives the generated output pointer. | +| `tool_name` | Name of a precompiled or dynamically registered schema. | +| `tool_schema` | JSON Schema string used when `tool_name` is not already registered. | + +Return values follow the common `ErrorCode` contract in the +[C API Reference](../api/README.md). + +## 🧰 C++ Manager API + +Direct C++ integrations can use `causallm::XGrammarManager::Instance()` from +`src/xgrammar/xgrammar_manager.h`. + +| Method | Purpose | +|---|---| +| `initialize(tokenizer, vocab_size)` | Build shared tokenizer/compiler state. | +| `loadToolset(path, tokenizer, vocab_size)` | Load and compile `Toolset.json`. | +| `hasTool(name)` | Check whether a grammar is registered. | +| `getGrammar(name)` | Fetch a compiled `XGrammar`. | +| `registerTool(name, schema)` | Compile and register a schema at runtime. | +| `resetGrammar(name)` | Reset matcher state after a run. | +| `getToolNames()` | List registered tool names. | +| `isInitialized()` | Check whether the manager has been initialized. | +| `clear()` | Drop compiled grammars and shared state. | + +### XGrammar Public Methods + +| Method | Purpose | +|---|---| +| `loadFromCache(serialized, tokenizer_info, vocab_size)` | Load a grammar from a serialized cache. | +| `serialize()` | Serialize the compiled grammar to JSON. | +| `applyGrammarMask(float*, int)` | Apply grammar mask to FP32 logits. | +| `applyGrammarMask(uint16_t*, int, float, int)` | Apply grammar mask to quantized (FP16) logits. | +| `initializeGrammar(type, schema, grammar_compiler, vocab_size)` | Initialize grammar using a shared `GrammarCompiler` (overload). | + +## πŸ—‚οΈ Cache Behavior + +XGrammar compilation can be expensive. Quick.AI uses a sidecar cache: + +```text +model/ + Toolset.json + Toolset.json.cache +``` + +Load behavior: + +1. If `Toolset.json.cache` exists and is valid, grammars are loaded from cache. +2. If the cache is missing or incomplete, Quick.AI compiles from `Toolset.json`. +3. After successful compilation, Quick.AI saves a fresh cache file. + +Delete `Toolset.json.cache` when changing schemas during development. + +## βœ… JSON Schema Support + +XGrammar supports the schema features used by Quick.AI tool definitions: + +- object, string, number, integer, boolean, array +- `properties` +- `required` +- `enum` +- `pattern` +- nested objects and arrays + +Keep schemas narrow and explicit for best constrained-decoding behavior. + +## πŸ› οΈ Troubleshooting + +| Symptom | Cause | Fix | +|---|---|---| +| Tool not found | `tool_name` is absent and `tool_schema` is null. | Add the tool to `Toolset.json` or pass a schema on first use. | +| Compilation failed | Unsupported or malformed JSON Schema. | Validate the schema and start with a smaller object shape. | +| Output is valid JSON but semantically wrong | Grammar only constrains structure. | Add stronger prompt instructions or narrower schema fields. | +| First load is slow | Toolset schemas are being compiled. | Keep the generated `Toolset.json.cache` for later loads. | +| Cache seems stale | The schema changed but cache was reused. | Delete `Toolset.json.cache` and reload the model. | + +## πŸ“Ž Related Docs + +- [Chat and OpenAI Usage Examples](ChatAndOpenAIUsage.md) +- [Chat Templates](ChatTemplate.md) +- [C API Reference](../api/README.md) +- [XGrammar Official Documentation](https://xgrammar.mlc.ai/docs/) +- [XGrammar GitHub Repository](https://github.com/mlc-ai/xgrammar) diff --git a/docs/architecture.md b/docs/architecture.md deleted file mode 100644 index 3dfa182c..00000000 --- a/docs/architecture.md +++ /dev/null @@ -1,158 +0,0 @@ -# Quick.AI Architecture - -This document describes how Quick.AI is layered, what each binary and `.so` is for, and how the build wires everything together. For a top-level intro, see the [project README](../README.md). - ---- - -## Bird's-eye view - -```mermaid -flowchart TB - subgraph Apps["Your application / CLI"] - BIN["quick_dot_ai_run
quick_dot_ai_quantize
quick_dot_ai_test_api"] - end - - subgraph QAI["β˜„οΈ Quick.AI"] - direction TB - API["libquick_dot_ai_api.so
stable C API"] - CORE["libquick_dot_ai.so
causal-LM engine Β· namespace quick_dot_ai"] - subgraph LAYERS["Per-layer plugins (.so)"] - direction LR - L1["rms_norm"] - L2["swiglu"] - L3["mha_core"] - L4["qkv"] - L5["lm_head"] - L6["tied_embed"] - L7["embed_pool"] - L8["…"] - end - end - - subgraph DEPS["Foundations"] - NNT["NNTrainer
subprojects/nntrainer Β· meson subproject"] - SYS["OpenBLAS Β· OpenMP Β· Flatbuffers"] - end - - BIN --> API - BIN --> CORE - API --> CORE - CORE --> LAYERS - CORE --> NNT - LAYERS --> NNT - NNT --> SYS - - classDef app fill:#fff7e6,stroke:#fa8c16,color:#874d00 - classDef qai fill:#e6f4ff,stroke:#1677ff,color:#003a8c - classDef plugin fill:#f6ffed,stroke:#52c41a,color:#135200 - classDef dep fill:#f5f5f5,stroke:#8c8c8c,color:#262626 - class BIN app - class API,CORE qai - class L1,L2,L3,L4,L5,L6,L7,L8 plugin - class NNT,SYS dep -``` - ---- - -## Layers, top to bottom - -### 1. Binaries (`quick_dot_ai_*`) - -| Binary | Source | Purpose | -|---|---|---| -| `quick_dot_ai_run` | `main.cpp` | Interactive / one-shot text generation against a prepared model directory. | -| `quick_dot_ai_quantize` | `quantize.cpp` | Convert an FP32 checkpoint into Q4_0 / Q4_K / Q6_K / FP16 in place or to a new directory. | -| `quick_dot_ai_test_api` | `api/test_api.cpp` | Smoke test exercising the C API end-to-end. | - -All three link against `libquick_dot_ai.so` and (where relevant) `libquick_dot_ai_api.so`. - -### 2. C API (`libquick_dot_ai_api.so`) - -The integration surface for host applications β€” Android JNI, iOS, server processes, anything that wants to call into Quick.AI without taking a C++ dependency. - -- Header: [`api/causal_lm_api.h`](../api/causal_lm_api.h) -- Symbols: `loadModel`, `runModel`, `getPerformanceMetrics`, plus the `BackendType` / `ModelType` / `ModelQuantizationType` enums. -- **ABI policy.** These symbols and enums are **deliberately not renamed** as part of the Quick.AI rebrand β€” existing embedders keep building unmodified. - -### 3. Core engine (`libquick_dot_ai.so`) - -The actual causal-LM runtime. Lives in `namespace quick_dot_ai`. Key entry points: - -- `models/causal_lm.{h,cpp}` β€” base class that drives token generation, KV-cache management, and the inference loop. -- `models/transformer.{h,cpp}` β€” generic transformer assembly used by the per-family models. -- `models//_causallm.{h,cpp}` β€” concrete model classes (Qwen 2/3, GPT-OSS, Gemma 3, …). -- `factory.h` β€” registers every model class so `loadModel` can dispatch by `ModelType`. - -### 4. Per-layer plugins (`build/layers/libquick_dot_ai_*_layer.so`) - -Each transformer building block is its own `shared_library` β€” declared in [`layers/meson.build`](../layers/meson.build) and registered through `causallm_common_properties.h`. - -| Plugin | Role | -|---|---| -| `rms_norm` / `reshaped_rms_norm` | RMSNorm, with an optional reshape friendly to MoE routing. | -| `swiglu` | SwiGLU MLP activation. | -| `qkv` | Fused Q/K/V projection. | -| `mha_core` | Multi-head attention kernel (NEON / AVX2 hot paths). | -| `lm_head` | Final projection to vocabulary logits. | -| `tied_embed` | Tied input/output embedding (mmap-friendly). | -| `embed_layer`, `embed_pool`, `embed_normalize` | Token embedding + sentence-embedding heads. | - -Why plugins? Two reasons: (a) you can drop a new attention kernel into `layers/` and rebuild only that one `.so`; (b) downstream products can ship a slimmer subset by only linking the layers their model needs. - -### 5. NNTrainer - -Pulled in via `subprojects/nntrainer/` as a **Meson subproject**, pinned to a specific commit by the parent submodule. - -The top-level `meson.build` declares it as: - -```meson -nntrainer_proj = subproject('nntrainer', - default_options: [ - 'enable-app=false', # don't build NNTrainer's own Applications - 'enable-test=false', # skip its unit-test suite - 'enable-tflite-backbone=false', - 'enable-tflite-interpreter=false', - 'werror=false', - ], -) -``` - -so only the core engine and the C++ API ride along β€” not the rest of NNTrainer's PR-gated peripherals. - -### 6. System foundations - -| Dependency | Used for | -|---|---| -| OpenBLAS | dense GEMM in NNTrainer's tensor backend | -| OpenMP | thread-pool parallelism across the inference graph | -| Flatbuffers | NNTrainer's serialized model format | - -On Android the same role is filled by NDK-bundled libomp + an in-tree BLAS path; see [`jni/Android.mk`](../jni/Android.mk). - ---- - -## Design choices worth knowing - -### Stable C API as the integration seam -Renaming during a brand change is tempting; we resisted on `api/causal_lm_api.h` because it would silently break every embedder downstream. Quick.AI's C symbols and enum prefixes (`CAUSAL_LM_QUANTIZATION_*`) stay as-is. - -### Lean subproject build of NNTrainer -NNTrainer's own CI exercises a much wider matrix (Tizen / Yocto / Windows / NNStreamer / TFLite). For Quick.AI's purposes we only want `nntrainer_dep` and `nntrainer_ccapi_dep`, so the subproject `default_options` strip almost everything else. This is also why the Quick.AI Linux build only needs `libopenblas-dev`, `libflatbuffers-dev`, and `flatbuffers-compiler` from apt β€” no `tensorflow2-lite-dev`, no `nnstreamer-dev`. - -### Per-layer `.so`s instead of one fat library -Custom transformer pieces are loaded as plugins, so swapping in a new attention kernel doesn't require relinking the whole library. It also keeps debug builds fast and lets distributions ship per-feature subsets. - -### Flash Storage Utilization (FSU) is opt-in per model -The `*-slim` model variants under `models/` use FSU to stream MoE experts from disk. This is what powers the 16.5 GB β†’ 1.3 GB demo on the README β€” it isn't a global runtime mode, but a property of the model graph definition. - ---- - -## Adding a new model family - -1. Create `models//`. -2. Implement `_causallm.{h,cpp}` deriving from the appropriate causal-LM template. -3. Register the new file list in `models//meson.build` and append it to `quick_dot_ai_src` / `quick_dot_ai_inc`. -4. Add the family enum + factory entry in [`factory.h`](../factory.h) so `loadModel` can dispatch to it. -5. (Optional) Implement custom layers under `layers/` and append the resulting deps to `quick_dot_ai_layer_dependencies` in the top-level `meson.build`. - -A model author guide with concrete examples lives in [`models/README.md`](../models/README.md). diff --git a/docs/superpowers/plans/2026-06-02-family-dropdown.md b/docs/superpowers/plans/2026-06-02-family-dropdown.md new file mode 100644 index 00000000..088b7105 --- /dev/null +++ b/docs/superpowers/plans/2026-06-02-family-dropdown.md @@ -0,0 +1,244 @@ +# Model FAMILY Dropdown Implementation Plan + +> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking. + +**Goal:** λͺ¨λΈ FAMILY 선택을 κ°€λ‘œ 슀크둀 μΉ©(`chipRow`)μ—μ„œ Material μŠ€νƒ€μΌ λ“œλ‘­λ‹€μš΄ ν•„λ“œ(`dropdownField`)둜 κ΅μ²΄ν•œλ‹€. Run/OpenAI νƒ­κ³Ό Chat νƒ­ 두 곳에 μ μš©ν•œλ‹€. + +**Architecture:** `chipRow()`와 λ™μΌν•œ μ‹œκ·Έλ‹ˆμ²˜λ₯Ό κ°€μ§„ μ‹ κ·œ 헬퍼 `dropdownField()`λ₯Ό μΆ”κ°€ν•œλ‹€. FAMILY ν˜ΈμΆœλΆ€ 2κ³³μ—μ„œ ν•¨μˆ˜λͺ…λ§Œ `chipRow` β†’ `dropdownField`둜 κ΅μ²΄ν•˜λ©°, FAMILY λ³€κ²½ μ‹œ RUNTIME/BACKENDλ₯Ό μž¬κ³„μ‚°ν•˜λŠ” cascading λžŒλ‹€λŠ” κ·ΈλŒ€λ‘œ μž¬μ‚¬μš©ν•œλ‹€. RUNTIME/BACKEND/QUANTIZATION 및 `ModelCatalog.kt`λŠ” λ³€κ²½ν•˜μ§€ μ•ŠλŠ”λ‹€. + +**Tech Stack:** Kotlin, ν΄λž˜μ‹ Android View (μ½”λ“œλ‘œ 직접 UI 생성), Material 3 토큰 μ‹œμŠ€ν…œ, `android.widget.PopupMenu`. λΉŒλ“œ: Gradle (AGP 8.9.1, JDK 17). + +**μ°Έκ³  μŠ€νŽ™:** `docs/superpowers/specs/2026-06-02-family-dropdown-design.md` + +--- + +## λΉŒλ“œ/검증 ν™˜κ²½ (λ©”λͺ¨λ¦¬ `quickai-android-build-env` μš”μ•½) + +μžλ™ν™” UI ν…ŒμŠ€νŠΈκ°€ μ—†μœΌλ―€λ‘œ 각 μ½”λ“œ νƒœμŠ€ν¬λŠ” **Kotlin 컴파일 톡과**둜 κ²€μ¦ν•˜κ³ , λ§ˆμ§€λ§‰μ— **λ””λ°”μ΄μŠ€ μ„€μΉ˜ ν›„ μˆ˜λ™ 확인**ν•œλ‹€. + +νƒœμŠ€ν¬ μ‹€ν–‰ μ „ 셸에 μ•„λž˜ ν™˜κ²½ λ³€μˆ˜λ₯Ό export ν•΄μ•Ό ν•œλ‹€ (μ—†μœΌλ©΄ λΉŒλ“œ μ‹€νŒ¨): + +```bash +export NDK_ROOT=/home/jiyoung/Android/Sdk/ndk/android-ndk-r26b +export JAVA_HOME=/home/jiyoung/jdks/jdk-17.0.19+10 +export ANDROID_HOME=/home/jiyoung/Android/Sdk +export ANDROID_SERIAL=R3CX80H8Y0F # SM-S936U (S25+). minSdk=33 μΆ©μ‘± 기기만. +``` + +λΉ λ₯Έ 컴파일 검증 λͺ…λ Ή (λ„€μ΄ν‹°λΈŒ λΉŒλ“œ 없이 Kotlin만): + +```bash +./Android/gradlew :SampleTestAPP:compileDebugKotlin +``` + +--- + +## File Structure + +- **Modify only:** `Android/SampleTestAPP/src/main/java/com/example/sampletestapp/MainActivity.kt` + - import 블둝: `android.widget.PopupMenu` μΆ”κ°€ + - `chipRow()` μ •μ˜λΆ€(~1588) 근처: `dropdownField()` μ‹ κ·œ μΆ”κ°€ + - `:565` Run/OpenAI νƒ­ FAMILY 호좜: `chipRow` β†’ `dropdownField` + - `:796` Chat νƒ­ FAMILY 호좜: `chipRow` β†’ `dropdownField` +- **λ³€κ²½ μ—†μŒ:** `Android/QuickDotAI/.../ModelCatalog.kt`, 기타 λͺ¨λ“  파일 + +참고둜 이 νŒŒμΌμ—λŠ” 이미 λ‹€μŒ 헬퍼/μž„ν¬νŠΈκ°€ μ‘΄μž¬ν•˜μ—¬ κ·ΈλŒ€λ‘œ μž¬μ‚¬μš©ν•œλ‹€: +`M3Tokens`(ν•„λ“œ: `surfaceContainer`, `outline`, `onSurface`, `onSurfaceVar`), `solid()`, `strokedSolid()`, `dp()`, μž„ν¬νŠΈλœ `MATCH_PARENT`/`WRAP_CONTENT`/`Gravity`/`Color`. + +--- + +## Task 1: `dropdownField()` 헬퍼 μΆ”κ°€ + +**Files:** +- Modify: `Android/SampleTestAPP/src/main/java/com/example/sampletestapp/MainActivity.kt` (import 블둝 + `chipRow()` μ •μ˜ 직후) + +- [ ] **Step 1: `PopupMenu` import μΆ”κ°€** + +`MainActivity.kt`의 import λΈ”λ‘μ—μ„œ `import android.widget.LinearLayout`(57ν–‰ λΆ€κ·Ό) λ°”λ‘œ λ‹€μŒ 쀄에 μΆ”κ°€ν•œλ‹€ (μ•ŒνŒŒλ²³ μˆœμ„œμƒ LinearLayoutκ³Ό ScrollView 사이): + +```kotlin +import android.widget.PopupMenu +``` + +- [ ] **Step 2: `dropdownField()` ν•¨μˆ˜ μΆ”κ°€** + +`chipRow()` ν•¨μˆ˜ μ •μ˜κ°€ λλ‚˜λŠ” 지점(λ‹«λŠ” `}` λ‹€μŒ, `filledButton()` μ •μ˜ μ•ž, μ•½ 1617~1618ν–‰)에 μ•„λž˜ ν•¨μˆ˜λ₯Ό μΆ”κ°€ν•œλ‹€. μ‹œκ·Έλ‹ˆμ²˜λŠ” `chipRow()`와 λ™μΌν•˜λ‹€. + +```kotlin + private fun dropdownField(t: M3Tokens, options: List, selected: String, + onPick: (String) -> Unit): View { + val enabled = options.isNotEmpty() + val field = LinearLayout(this).apply { + orientation = LinearLayout.HORIZONTAL + gravity = Gravity.CENTER_VERTICAL + background = strokedSolid( + if (enabled) t.surfaceContainer else Color.TRANSPARENT, 8, t.outline, 1) + setPadding(dp(12), dp(10), dp(12), dp(10)) + layoutParams = LinearLayout.LayoutParams(MATCH_PARENT, WRAP_CONTENT) + } + val valueView = TextView(this).apply { + text = if (enabled) selected else "β€”" + setTextColor(if (enabled) t.onSurface else t.onSurfaceVar) + textSize = 14f + typeface = Typeface.DEFAULT_BOLD + layoutParams = LinearLayout.LayoutParams(0, WRAP_CONTENT, 1f) + } + val arrow = TextView(this).apply { + text = "β–Ύ" // β–Ύ + setTextColor(t.onSurfaceVar) + textSize = 14f + } + field.addView(valueView) + field.addView(arrow) + if (enabled) { + field.setOnClickListener { anchor -> + val menu = PopupMenu(this, anchor) + options.forEachIndexed { i, opt -> + menu.menu.add(0, i, i, opt).apply { + isCheckable = true + isChecked = opt == selected + } + } + menu.setOnMenuItemClickListener { item -> + onPick(options[item.itemId]) + true + } + menu.show() + } + } + return field + } +``` + +- [ ] **Step 3: 컴파일 검증** + +Run: `./Android/gradlew :SampleTestAPP:compileDebugKotlin` +Expected: `BUILD SUCCESSFUL`. (`dropdownField`λŠ” 아직 ν˜ΈμΆœλ˜μ§€ μ•Šμ•„ "never used" κ²½κ³ κ°€ λ‚  수 μžˆμœΌλ‚˜ μ—λŸ¬λŠ” μ•„λ‹ˆλ‹€.) + +- [ ] **Step 4: 컀밋** + +```bash +git add Android/SampleTestAPP/src/main/java/com/example/sampletestapp/MainActivity.kt +git commit -m "feat: add dropdownField() Material dropdown helper" +``` + +--- + +## Task 2: Run/OpenAI νƒ­ FAMILYλ₯Ό λ“œλ‘­λ‹€μš΄μœΌλ‘œ ꡐ체 + +**Files:** +- Modify: `Android/SampleTestAPP/src/main/java/com/example/sampletestapp/MainActivity.kt:565` + +- [ ] **Step 1: FAMILY ν˜ΈμΆœλΆ€ ν•¨μˆ˜λͺ… ꡐ체** + +565ν–‰μ˜ `chipRow(` λ₯Ό `dropdownField(` 둜 λ°”κΎΌλ‹€. μΈμžμ™€ λžŒλ‹€λŠ” λ³€κ²½ν•˜μ§€ μ•ŠλŠ”λ‹€. λ³€κ²½ ν›„ 블둝은 λ‹€μŒκ³Ό κ°™μ•„μ•Ό ν•œλ‹€: + +```kotlin + body.addView(labelView(t, "FAMILY")) + body.addView(dropdownField(t, ModelCatalog.families(), selFamily) { picked -> + selFamily = picked + selRuntime = ModelCatalog.runtimesFor(selFamily).firstOrNull() ?: selRuntime + selBackend = ModelCatalog.backendsFor(selFamily, selRuntime).firstOrNull() ?: selBackend + modelPathText = defaultModelPathFor(selDescriptor, selectedQuant) ?: "" + rebuildUi(resetModelPath = true) + }) +``` + +- [ ] **Step 2: 컴파일 검증** + +Run: `./Android/gradlew :SampleTestAPP:compileDebugKotlin` +Expected: `BUILD SUCCESSFUL`. + +- [ ] **Step 3: 컀밋** + +```bash +git add Android/SampleTestAPP/src/main/java/com/example/sampletestapp/MainActivity.kt +git commit -m "feat: use dropdownField for FAMILY in Run/OpenAI tab" +``` + +--- + +## Task 3: Chat νƒ­ FAMILYλ₯Ό λ“œλ‘­λ‹€μš΄μœΌλ‘œ ꡐ체 + +**Files:** +- Modify: `Android/SampleTestAPP/src/main/java/com/example/sampletestapp/MainActivity.kt:796` + +- [ ] **Step 1: FAMILY ν˜ΈμΆœλΆ€ ν•¨μˆ˜λͺ… ꡐ체** + +796ν–‰μ˜ `chipRow(` λ₯Ό `dropdownField(` 둜 λ°”κΎΌλ‹€. μΈμžμ™€ λžŒλ‹€λŠ” λ³€κ²½ν•˜μ§€ μ•ŠλŠ”λ‹€. λ³€κ²½ ν›„ 블둝은 λ‹€μŒκ³Ό κ°™μ•„μ•Ό ν•œλ‹€: + +```kotlin + modelCard.addView(labelView(t, "FAMILY")) + modelCard.addView(dropdownField(t, ModelCatalog.families(), chatSelFamily) { picked -> + chatSelFamily = picked + chatSelRuntime = ModelCatalog.runtimesFor(chatSelFamily).firstOrNull() ?: chatSelRuntime + chatSelBackend = ModelCatalog.backendsFor(chatSelFamily, chatSelRuntime).firstOrNull() ?: chatSelBackend + clearChatSessionState() + rebuildUi() + }) +``` + +- [ ] **Step 2: 컴파일 검증** + +Run: `./Android/gradlew :SampleTestAPP:compileDebugKotlin` +Expected: `BUILD SUCCESSFUL`. `dropdownField`κ°€ 이제 μ‚¬μš©λ˜λ―€λ‘œ "never used" 경고도 사라진닀. + +- [ ] **Step 3: 컀밋** + +```bash +git add Android/SampleTestAPP/src/main/java/com/example/sampletestapp/MainActivity.kt +git commit -m "feat: use dropdownField for FAMILY in Chat tab" +``` + +--- + +## Task 4: λΉŒλ“œΒ·μ„€μΉ˜ ν›„ μˆ˜λ™ 검증 + +**Files:** (μ½”λ“œ λ³€κ²½ μ—†μŒ β€” 검증 μ „μš©) + +- [ ] **Step 1: λ””λ°”μ΄μŠ€ μ—°κ²° 및 ν™˜κ²½ λ³€μˆ˜ 확인** + +```bash +adb devices # R3CX80H8Y0F (SM-S936U)κ°€ 'device' μƒνƒœμΈμ§€ 확인 +echo "$JAVA_HOME $ANDROID_HOME $NDK_ROOT $ANDROID_SERIAL" # 4개 λͺ¨λ‘ 좜λ ₯λ˜λŠ”μ§€ +``` +Expected: λŒ€μƒ κΈ°κΈ°κ°€ `device` μƒνƒœ, 4개 λ³€μˆ˜ λͺ¨λ‘ λΉ„μ–΄μžˆμ§€ μ•ŠμŒ. + +- [ ] **Step 2: APK μ„€μΉ˜ (μ„œλͺ… 좩돌 μ‹œ μž¬μ„€μΉ˜)** + +λ¨Όμ € λΉ λ₯Έ μ„€μΉ˜λ₯Ό μ‹œλ„ν•œλ‹€ (λ„€μ΄ν‹°λΈŒ λΌμ΄λΈŒλŸ¬λ¦¬λŠ” μœ μ§€λ¨): + +```bash +./Android/gradlew ":SampleTestAPP:installDebug" +``` + +`INSTALL_FAILED_UPDATE_INCOMPATIBLE ... signatures do not match` κ°€ λ‚˜μ˜€λ©΄: + +```bash +adb -s "$ANDROID_SERIAL" uninstall com.example.sampletestapp +./Android/gradlew ":SampleTestAPP:installDebug" +``` + +Expected: `BUILD SUCCESSFUL`, 기기에 μ•± μ„€μΉ˜λ¨. + +- [ ] **Step 3: μˆ˜λ™ 확인 (κΈ°κΈ°μ—μ„œ 직접)** + +λ‹€μŒμ„ λͺ¨λ‘ ν™•μΈν•œλ‹€: +1. **Run/OpenAI νƒ­**: FAMILY μ˜μ—­μ΄ μΉ© 행이 μ•„λ‹ˆλΌ λ“œλ‘­λ‹€μš΄ ν•„λ“œ(`ν˜„μž¬κ°’ ... β–Ύ`)둜 ν‘œμ‹œλœλ‹€. +2. **Chat νƒ­**: FAMILY μ˜μ—­μ΄ λ™μΌν•˜κ²Œ λ“œλ‘­λ‹€μš΄ ν•„λ“œλ‘œ ν‘œμ‹œλœλ‹€. +3. **λ“œλ‘­λ‹€μš΄ λ™μž‘**: FAMILY ν•„λ“œλ₯Ό νƒ­ν•˜λ©΄ family λͺ©λ‘ νŒμ—…μ΄ 뜨고, ν˜„μž¬ 선택값에 체크 ν‘œμ‹œκ°€ μžˆλ‹€. +4. **Cascading νšŒκ·€ μ—†μŒ**: λ‹€λ₯Έ familyλ₯Ό μ„ νƒν•˜λ©΄ RUNTIME/BACKEND 칩이 μžλ™μœΌλ‘œ μž¬κ³„μ‚°λ˜μ–΄ 바뀐닀 (κΈ°μ‘΄ μΉ© μ‹œμ ˆκ³Ό 동일 λ™μž‘). MODEL PATH도 κ°±μ‹ λœλ‹€. +5. **λ‹€λ₯Έ μΆ• μœ μ§€**: RUNTIME / BACKEND / QUANTIZATION은 μ—¬μ „νžˆ μΉ© ν–‰μœΌλ‘œ ν‘œμ‹œλœλ‹€. +6. **ν…Œλ§ˆ**: 닀크/라이트 λͺ¨λ“œμ—μ„œ λ“œλ‘­λ‹€μš΄ ν•„λ“œ ν…Œλ‘λ¦¬Β·κΈ€μžμƒ‰μ΄ κΉ¨μ§€μ§€ μ•ŠλŠ”λ‹€. + +- [ ] **Step 4: 검증 κ²°κ³Ό 기둝** + +μœ„ 6개 ν•­λͺ© κ²°κ³Όλ₯Ό μ‚¬μš©μžμ—κ²Œ λ³΄κ³ ν•œλ‹€. λ¬Έμ œκ°€ 있으면 ν•΄λ‹Ή νƒœμŠ€ν¬λ‘œ λŒμ•„κ°€ μˆ˜μ •ν•œλ‹€. λͺ¨λ‘ ν†΅κ³Όν•˜λ©΄ μ™„λ£Œ. + +--- + +## Self-Review λ©”λͺ¨ + +- **μŠ€νŽ™ 컀버리지**: FAMILY만(βœ“ Task 1~3μ—μ„œ λ‹€λ₯Έ μΆ• λ―Έλ³€κ²½), 두 νƒ­(βœ“ Task 2 Run/OpenAI, Task 3 Chat), Material λ“œλ‘­λ‹€μš΄+PopupMenu(βœ“ Task 1), 빈 λͺ©λ‘ λ°©μ–΄(βœ“ Task 1 `enabled` λΆ„κΈ°), 라이트/닀크(βœ“ 토큰 μ‚¬μš© + Task 4 Step 3-6), cascading 보쑴(βœ“ λžŒλ‹€ 무변경) β€” μŠ€νŽ™ λͺ¨λ“  μš”κ΅¬μ‚¬ν•­μ΄ νƒœμŠ€ν¬λ‘œ 맀핑됨. +- **νƒ€μž…/μ‹œκ·Έλ‹ˆμ²˜ 일관성**: `dropdownField(t: M3Tokens, options: List, selected: String, onPick: (String) -> Unit): View` β€” Task 1 μ •μ˜μ™€ Task 2/3 호좜 인자 μˆœμ„œΒ·νƒ€μž… 일치. `chipRow`와 동일 μ‹œκ·Έλ‹ˆμ²˜λΌ ν˜ΈμΆœλΆ€λŠ” μ΄λ¦„λ§Œ ꡐ체. +- **ν”Œλ ˆμ΄μŠ€ν™€λ”**: μ—†μŒ (λͺ¨λ“  μ½”λ“œ/λͺ…λ Ή 전체 기재). diff --git a/docs/superpowers/plans/2026-06-05-lfm2-vl-siglip-pluggable.md b/docs/superpowers/plans/2026-06-05-lfm2-vl-siglip-pluggable.md new file mode 100644 index 00000000..dd372ab2 --- /dev/null +++ b/docs/superpowers/plans/2026-06-05-lfm2-vl-siglip-pluggable.md @@ -0,0 +1,808 @@ +# LFM2-VL(SigLIP + LFM2) pluggable composer 톡합 κ΅¬ν˜„ κ³„νš + +> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking. + +**Goal:** nntrainer의 λͺ¨λ†€λ¦¬μ‹ LFM2-VL-450M을 vision 인코더 + LFM2 LM 두 개의 독립 λ‘œλ“œ κ°€λŠ₯ν•œ λͺ¨λΈλ‘œ λΆ„ν•΄ν•˜μ—¬, Quick.AI의 κΈ°μ‘΄ generic `execute_multimodal` composer(CPU 개방)둜 νŽ˜μ–΄λ§ν•˜κ³ , μ•±μ—μ„œ vision+LLM λ―ΉμŠ€μ•€λ§€μΉ˜λ‘œ 골라 이미지+ν…μŠ€νŠΈ 좔둠을 ν•œλ‹€. + +**Architecture:** nntrainer에 `Lfm2VlVisionEncoder`(ViT+connectorλ₯Ό λ¬Άμ–΄ `run_image`κ°€ LM μž„λ² λ”© 곡간 FP32 1024-dim 좜λ ₯) 래퍼λ₯Ό μ‹ κ·œ μΆ”κ°€ν•˜κ³ , `Lfm2CausalLM`에 composerκ°€ μš”κ΅¬ν•˜λŠ” base 가상(`embeddingBytesPerToken`/`lookupEmbedding(int)const`/`get_embedding_info`)을 μ˜€λ²„λΌμ΄λ“œν•œλ‹€. Quick.AI apiλŠ” 두 λͺ¨λΈμ„ Factory 등둝 + 곡개 descriptor μΆ”κ°€ν•˜κ³ , `#ifdef ENABLE_QNN` λ©€ν‹°λͺ¨λ‹¬ 경둜λ₯Ό CPU둜 κ°œλ°©ν•˜λ©° 이미지 마컀λ₯Ό LFM2(``=396)에 λ§žμΆ˜λ‹€. 앱은 `loadMultimodalHandleByName` JNI λ…ΈμΆœ + λ―ΉμŠ€μ•€λ§€μΉ˜ picker + SigLIP-NaFlex μ „μ²˜λ¦¬(MVP κ³ μ • 256Β²)λ₯Ό μΆ”κ°€ν•œλ‹€. + +**Tech Stack:** C++17 (nntrainer CausalLM, `causallm::Factory`/`Transformer`), Quick.AI C API(`libquick_dot_ai_api.so`), Android NDK JNI(`quickai_jni.cpp`), Kotlin(Jetpack Compose, `QuickDotAI` AAR + `SampleTestAPP`). 검증: ν—€λ“œλ¦¬μŠ€ `quick_dot_ai_test` β†’ APK on-device. + +**μ •λ‹΅ 였라클(oracle):** κΈ°μ‘΄ λͺ¨λ†€λ¦¬μ‹ `Lfm2VlForConditionalGeneration`(`nntrainer/.../lfm2/lfm2-vl/lfm2_vl_model.cpp`)의 `run()` 좜λ ₯. λͺ¨λ“  ν—€λ“œλ¦¬μŠ€ 검증은 동일 이미지/ν”„λ‘¬ν”„νŠΈμ—μ„œ 이 경둜의 생성 토큰열과 λΉ„κ΅ν•œλ‹€. + +**λΉŒλ“œ ν™˜κ²½:** λ©”λͺ¨λ¦¬ `quickai-android-build-env`(JDK/SDK 경둜, `./build.sh --platform=android`, `apk_install_android.sh`, `ANDROID_SERIAL`) 및 `gauss-pluggable-bringup`(APK verify 절차) μ°Έκ³ . + +--- + +## File Structure (λ³€κ²½/생성 파일 λ§΅) + +**nntrainer μ„œλΈŒλͺ¨λ“ˆ (곡개 μ½”λ“œ):** +- Create `nntrainer/Applications/CausalLM/models/lfm2/lfm2-vl/lfm2_vl_vision_encoder.h` β€” `Lfm2VlVisionEncoder` μ„ μ–Έ(ViT+connector μ†Œμœ , `run_image` μ˜€λ²„λΌμ΄λ“œ). +- Create `nntrainer/Applications/CausalLM/models/lfm2/lfm2-vl/lfm2_vl_vision_encoder.cpp` β€” κ΅¬ν˜„. +- Modify `nntrainer/Applications/CausalLM/models/lfm2/lfm2_causallm.h` β€” base 가상 3μ’… μ˜€λ²„λΌμ΄λ“œ μ„ μ–Έ + scratch 멀버. +- Modify `nntrainer/Applications/CausalLM/models/lfm2/lfm2_causallm.cpp` β€” κ·Έ κ΅¬ν˜„. +- Modify `nntrainer/Applications/CausalLM/models/lfm2/lfm2-vl/meson.build` (λ˜λŠ” μƒμœ„ meson) β€” μƒˆ .cpp λΉŒλ“œ 포함. +- Modify `nntrainer/Applications/CausalLM/main.cpp` β€” `Lfm2VlVisionEncoder` Factory 등둝(ν—€λ“œλ¦¬μŠ€ 단독 ν…ŒμŠ€νŠΈμš©). + +**Quick.AI api:** +- Modify `api/quick_dot_ai_api.cpp` β€” (a) `register_models()`에 `Lfm2ForCausalLM`+`Lfm2VlVisionEncoder` 등둝, (b) λ©€ν‹°λͺ¨λ‹¬ 경둜 CPU 개방, (c) 이미지 마컀 LFM2 μ •ν•©, (d) `run_vision_encoder` ν”½μ…€ λ ˆμ΄μ•„μ›ƒ λͺ¨λΈ 주도화. +- Modify `api/model_descriptors_public.cpp` β€” `lfm2-450m`(LM) + `siglip2-vl-encoder`(VISION_ENCODER) descriptor μΆ”κ°€. + +**Android:** +- Modify `Android/QuickDotAI/src/main/cpp/quickai_jni.cpp` β€” `loadMultimodalHandleByNameNative` μΆ”κ°€. +- Modify `Android/QuickDotAI/src/main/java/com/example/quickdotai/NativeCausalLm.kt` β€” external fun + 래퍼. +- Modify `Android/QuickDotAI/src/main/java/com/example/quickdotai/ModelCatalog.kt` β€” `visionEncoders()`/`pairableLlms()`. +- Create `Android/QuickDotAI/src/main/java/com/example/quickdotai/SigLipNaFlexImageProcessor.kt` β€” 256Β² μ „μ²˜λ¦¬. +- Modify `Android/SampleTestAPP/.../MainActivity.kt` β€” λ―ΉμŠ€μ•€λ§€μΉ˜ picker UI + νŽ˜μ–΄ λ‘œλ“œ + μ „μ²˜λ¦¬ λΆ„κΈ°. + +**λ””λ°”μ΄μŠ€ 에셋(λΉŒλ“œ μ‚°μΆœλ¬Ό μ•„λ‹˜, μˆ˜λ™ 배치):** +- `/sdcard/Download/aistudio-mobile/models/siglip2-vl-encoder/` (ViT+connector κ°€μ€‘μΉ˜ + nntr_config.json) +- `/sdcard/Download/aistudio-mobile/models/lfm2-450m/` (LFM2 LM κ°€μ€‘μΉ˜ + tokenizer + μž„λ² λ”© bin + nntr_config.json) + +--- + +## Milestone 1 β€” nntrainer: λͺ¨λΈ λΆ„ν•΄ + +### Task 1.1: `Lfm2CausalLM`에 composer용 base 가상 μ˜€λ²„λΌμ΄λ“œ μΆ”κ°€ + +**λ°°κ²½:** composer(`execute_multimodal`)λŠ” base 가상 `embeddingBytesPerToken()`/`const void* lookupEmbedding(int)const`/`get_embedding_info()`λ₯Ό ν˜ΈμΆœν•œλ‹€. ν˜„μž¬ `Lfm2CausalLM`은 concrete `std::vector lookupEmbedding(unsigned)`만 μžˆμ–΄ base 가상은 κΈ°λ³Έκ°’(0/nullptr)을 λ°˜ν™˜ β†’ composerκ°€ λ™μž‘ λΆˆκ°€. + +**Files:** +- Modify: `nntrainer/Applications/CausalLM/models/lfm2/lfm2_causallm.h` +- Modify: `nntrainer/Applications/CausalLM/models/lfm2/lfm2_causallm.cpp` + +- [ ] **Step 1: 헀더에 base 가상 μ˜€λ²„λΌμ΄λ“œ + scratch 멀버 μ„ μ–Έ μΆ”κ°€** + +`lfm2_causallm.h`의 `class Lfm2CausalLM` public μ˜μ—­(κΈ°μ‘΄ `std::vector lookupEmbedding(unsigned int token_id);` μ„ μ–Έ μ•„λž˜)에 μΆ”κ°€: + +```cpp + // ── Multimodal composer (base Transformer) interface ────────────────── + // The generic api composer drives this LM through base-class virtuals. + // These adapt the concrete FP32 embedding path to the model-agnostic API. + + /** Bytes of one token embedding (FP32 DIM). 0 until weights loaded. */ + size_t embeddingBytesPerToken() const override; + + /** Embedding of @p token_id as a raw FP32 row, or nullptr. Pointer is + * valid until the next call (per-call scratch buffer). */ + const void *lookupEmbedding(int token_id) const override; + + /** FP32 identity quant space: connector emits FP32 1024-dim directly. */ + std::pair get_embedding_info() override { return {1.0f, 0}; } +``` + +그리고 private μ˜μ—­(κΈ°μ‘΄ `embedding_*` μΊμ‹œ 멀버듀 근처)에 μΆ”κ°€: + +```cpp + /** Per-call scratch for the base-virtual lookupEmbedding(int) const. */ + mutable std::vector emb_scratch_; +``` + +- [ ] **Step 2: cpp에 κ΅¬ν˜„ μΆ”κ°€** + +`lfm2_causallm.cpp` 파일 끝(λ˜λŠ” κΈ°μ‘΄ `lookupEmbedding` μ •μ˜ μ•„λž˜)에 μΆ”κ°€. `DIM`은 base `Transformer`의 hidden-size 멀버(κΈ°μ‘΄ concrete `lookupEmbedding`이 μ‚¬μš©ν•˜λŠ” 것과 λ™μΌν•œ 멀버λͺ…을 κ·Έ ν•¨μˆ˜ λ³Έλ¬Έμ—μ„œ 확인해 λ™μΌν•˜κ²Œ μ‚¬μš©ν•œλ‹€ β€” λŒ€κ°œ `DIM`): + +```cpp +size_t Lfm2CausalLM::embeddingBytesPerToken() const { + // FP32 row of width DIM. Returns 0 if embedding weights not yet cached. + if (!embedding_weight_cached_) + return 0; + return static_cast(DIM) * sizeof(float); +} + +const void *Lfm2CausalLM::lookupEmbedding(int token_id) const { + if (!embedding_weight_cached_ || token_id < 0) + return nullptr; + // Reuse the concrete FP32 lookup, but it is non-const; cast away const for + // the cache-only access (no logical state change). The result is copied + // into the per-call scratch so the returned pointer stays valid until the + // next call (the composer memcpy's it immediately). + auto *self = const_cast(this); + emb_scratch_ = self->lookupEmbedding(static_cast(token_id)); + if (emb_scratch_.empty()) + return nullptr; + return emb_scratch_.data(); +} +``` + +> 주의: concrete `lookupEmbedding(unsigned)`와 base `lookupEmbedding(int)const`λŠ” μ‹œκ·Έλ‹ˆμ²˜κ°€ 달라 **μ˜€λ²„λ‘œλ“œ**둜 κ³΅μ‘΄ν•œλ‹€. cpp Step 2의 `self->lookupEmbedding(...)`λŠ” `unsigned int` 인자라 concreteκ°€ μ„ νƒλœλ‹€. 컴파일 μ‹œ λͺ¨ν˜Έμ„± κ²½κ³ κ°€ λ‚˜μ˜€λ©΄ concreteλ₯Ό `lookupEmbeddingVec`둜 개λͺ…ν•˜κ³  μ–‘μͺ½ ν˜ΈμΆœλΆ€(이 ν•¨μˆ˜ + `lfm2_vl_model.cpp`)λ₯Ό λ§žμΆ˜λ‹€. + +- [ ] **Step 3: λΉŒλ“œ 확인 (ν—€λ“œλ¦¬μŠ€ 호슀트 λΉŒλ“œ)** + +Run: `cd nntrainer && ./tools/package_android.sh 2>/dev/null; meson compile -C build 2>&1 | tail -20` *(μ‹€μ œ λΉŒλ“œ λͺ…령은 λ©”λͺ¨λ¦¬ `quickai-android-build-env` κΈ°μ€€μœΌλ‘œ λŒ€μ²΄)* +Expected: `Lfm2CausalLM` 컴파일 성곡, 링크 μ—λŸ¬ μ—†μŒ. + +- [ ] **Step 4: Commit** + +```bash +cd nntrainer && git add Applications/CausalLM/models/lfm2/lfm2_causallm.h Applications/CausalLM/models/lfm2/lfm2_causallm.cpp +git commit -m "feat(lfm2): expose base Transformer embedding virtuals for composer" +``` + +### Task 1.2: `Lfm2VlVisionEncoder` 래퍼 μ‹ κ·œ μž‘μ„± (ViT+connector β†’ run_image) + +**λ°°κ²½:** composerλŠ” `models[0]`(vision)의 base 가상 `run_image()`λ₯Ό ν˜ΈμΆœν•΄ LM μž„λ² λ”© κ³΅κ°„μ˜ μž„λ² λ”©μ„ μ–»λŠ”λ‹€. nntrainerμ—” ViT(`Lfm2VlVisionTransformer`)와 connector(`Lfm2VlConnector`)κ°€ 뢄리돼 있고 μ˜€μΌ€μŠ€νŠΈλ ˆμ΄ν„°λ§Œ λ‘˜μ„ μž‡λŠ”λ‹€. 이λ₯Ό ν•˜λ‚˜μ˜ Factory λͺ¨λΈλ‘œ 감싼닀. ViT의 `run()`은 파일 κΈ°λ°˜μ΄λ―€λ‘œ in-memory 픽셀을 μž„μ‹œ 파일둜 우회(μ˜€μΌ€μŠ€νŠΈλ ˆμ΄ν„°μ™€ λ™μΌν•˜κ²Œ κ²€μ¦λœ 경둜). + +**Files:** +- Create: `nntrainer/Applications/CausalLM/models/lfm2/lfm2-vl/lfm2_vl_vision_encoder.h` +- Create: `nntrainer/Applications/CausalLM/models/lfm2/lfm2-vl/lfm2_vl_vision_encoder.cpp` + +- [ ] **Step 1: 헀더 μž‘μ„±** + +`lfm2_vl_vision_encoder.h`: + +```cpp +// SPDX-License-Identifier: Apache-2.0 +/** + * @file lfm2_vl_vision_encoder.h + * @brief Loadable vision-encoder model for LFM2-VL: wraps the SigLIP2 ViT + * (Lfm2VlVisionTransformer) + pixel-unshuffle + Lfm2VlConnector so a + * single run_image() returns image embeddings already projected into + * the LFM2 LM embedding space (FP32, out_features wide). + * Registered with Factory under "Lfm2VlVisionEncoder". + */ +#ifndef __LFM2_VL_VISION_ENCODER_H__ +#define __LFM2_VL_VISION_ENCODER_H__ + +#include +#include + +#include "lfm2_vl_connector.h" +#include "vision/lfm2_vl_vision_transformer.h" + +namespace causallm { + +class Lfm2VlVisionEncoder : public Transformer { +public: + static constexpr const char *architectures = "Lfm2VlVisionEncoder"; + + Lfm2VlVisionEncoder(json &cfg, json &generation_cfg, json &nntr_cfg); + ~Lfm2VlVisionEncoder() override = default; + + void initialize() override; + void load_weight(const std::string &base_path) override; + + /** + * @brief Encode an in-memory FP32 image into LM-space embeddings. + * @param image multimodal_pointer{ float* pixels, byte_count }. Pixels are + * [3 * IMAGE_SIZE * IMAGE_SIZE] FP32 (CHW), preprocessed. + * @return multimodal_pointer{ malloc'd float* embeds, n_img_tokens * + * out_features * sizeof(float) }. Caller takes ownership (free()). + */ + multimodal_pointer run_image(const WSTR prompt, multimodal_pointer image, + int image_height, int image_width, + bool do_sample, const WSTR system_prompt, + const WSTR tail_prompt, + bool log_output) override; + +private: + json cfg_, generation_cfg_, nntr_cfg_; + unsigned int downsample_factor_{2}; + std::unique_ptr vit_; + std::unique_ptr connector_; + std::string cache_dir_{"/data/local/tmp"}; /**< temp file dir for ViT input */ +}; + +} // namespace causallm + +#endif // __LFM2_VL_VISION_ENCODER_H__ +``` + +- [ ] **Step 2: cpp μž‘μ„± β€” 생성/μ΄ˆκΈ°ν™”/κ°€μ€‘μΉ˜ λ‘œλ“œλŠ” μ˜€μΌ€μŠ€νŠΈλ ˆμ΄ν„° νŒ¨ν„΄ 차용** + +`lfm2_vl_vision_encoder.cpp`. `splitConfig`/ν‚€ 이름은 `lfm2_vl_model.cpp`의 μƒμ„±μžΒ·`load_weight`Β·`run`μ—μ„œ μ‚¬μš©ν•˜λŠ” 것과 **λ™μΌν•˜κ²Œ** λ§žμΆ˜λ‹€(`vision_config`, `image_size`, `patch_size`, `hidden_size`, `vision_model_file`, `connector_model_file`): + +```cpp +// SPDX-License-Identifier: Apache-2.0 +#include "lfm2_vl_vision_encoder.h" + +#include +#include +#include +#include + +namespace causallm { + +static json pick(const json &cfg, const char *key) { + return cfg.contains(key) ? cfg.at(key) : json::object(); +} + +Lfm2VlVisionEncoder::Lfm2VlVisionEncoder(json &cfg, json &generation_cfg, + json &nntr_cfg) + : Transformer(cfg, generation_cfg, nntr_cfg, ModelType::EMBEDDING), + cfg_(cfg), generation_cfg_(generation_cfg), nntr_cfg_(nntr_cfg) { + downsample_factor_ = cfg.value("downsample_factor", 2u); + json vision_cfg = pick(cfg, "vision_config"); + if (vision_cfg.empty()) + vision_cfg = cfg; // flat vision-only config + vit_ = std::make_unique(vision_cfg, generation_cfg, + nntr_cfg); + unsigned int vit_embed = vision_cfg.value("hidden_size", 768u); + unsigned int in_features = + vit_embed * downsample_factor_ * downsample_factor_; + unsigned int hidden = cfg.value("projector_hidden_size", 2560u); + unsigned int out_features = cfg.value("text_hidden_size", 1024u); + connector_ = std::make_unique(in_features, hidden, + out_features); + if (nntr_cfg.contains("cache_dir")) + cache_dir_ = nntr_cfg["cache_dir"].get(); +} + +void Lfm2VlVisionEncoder::initialize() { + vit_->initialize(); + vit_->allocateAndBindVitKVCache(); +} + +void Lfm2VlVisionEncoder::load_weight(const std::string &base_path) { + // ViT weights + if (nntr_cfg_.contains("vision_model_file")) + vit_->load_weight(base_path + "/" + + nntr_cfg_["vision_model_file"].get()); + else + vit_->load_weight(base_path); + // Connector weights + if (nntr_cfg_.contains("connector_model_file")) + connector_->loadWeights( + base_path + "/" + nntr_cfg_["connector_model_file"].get()); + else + throw std::runtime_error( + "Lfm2VlVisionEncoder: connector_model_file missing in nntr_config"); +} + +multimodal_pointer Lfm2VlVisionEncoder::run_image( + const WSTR /*prompt*/, multimodal_pointer image, int /*image_height*/, + int /*image_width*/, bool /*do_sample*/, const WSTR /*system_prompt*/, + const WSTR /*tail_prompt*/, bool log_output) { + + json vision_cfg = pick(cfg_, "vision_config"); + if (vision_cfg.empty()) + vision_cfg = cfg_; + unsigned int img_size = vision_cfg.value("image_size", 256u); + unsigned int patch_size = vision_cfg.value("patch_size", 16u); + unsigned int vit_embed = vision_cfg.value("hidden_size", 768u); + unsigned int ph = img_size / patch_size; + unsigned int pw = img_size / patch_size; + unsigned int n_patches = ph * pw; + + // 1) Write incoming FP32 pixels [3*img*img] to a temp file the ViT reads. + const size_t n_pixels = static_cast(3) * img_size * img_size; + std::string tmp_path = cache_dir_ + "/lfm2vl_vit_input.bin"; + { + std::ofstream ofs(tmp_path, std::ios::binary); + if (!ofs) + throw std::runtime_error("Lfm2VlVisionEncoder: cannot open temp " + + tmp_path); + ofs.write(reinterpret_cast(image.first), + static_cast(n_pixels * sizeof(float))); + } + + // 2) Run ViT; features land in getLastFeatures(). + vit_->run(tmp_path, false, "", "", log_output); + const std::vector &feats = vit_->getLastFeatures(); + if (feats.empty()) + throw std::runtime_error("Lfm2VlVisionEncoder: ViT produced no features"); + + // 3) pixel-unshuffle + connector MLP -> [n_img_tokens * out_features] FP32. + auto unshuffled = + pixelUnshuffle(feats, n_patches, vit_embed, ph, pw, downsample_factor_); + unsigned int n_img_tokens = connector_->outTokens(n_patches); + std::vector embeds = connector_->forward(unshuffled, n_img_tokens); + + // 4) Hand back a malloc'd buffer (composer/api owns + frees it). + const size_t out_bytes = embeds.size() * sizeof(float); + void *out = std::malloc(out_bytes); + if (!out) + throw std::runtime_error("Lfm2VlVisionEncoder: OOM for embeds"); + std::memcpy(out, embeds.data(), out_bytes); + if (log_output) + std::cout << "[Lfm2VlVisionEncoder] img_tokens=" << n_img_tokens + << " out_features=" << connector_->outFeatures() << "\n"; + return multimodal_pointer{out, out_bytes}; +} + +} // namespace causallm +``` + +- [ ] **Step 3: λΉŒλ“œ μ‹œμŠ€ν…œμ— μƒˆ .cpp 등둝** + +`nntrainer/Applications/CausalLM/models/lfm2/lfm2-vl/meson.build`(μ—†μœΌλ©΄ μƒμœ„ `models/meson.build`μ—μ„œ lfm2-vl μ†ŒμŠ€ λͺ©λ‘)을 μ—΄μ–΄ κΈ°μ‘΄ `lfm2_vl_connector.cpp`/`lfm2_vl_model.cpp` ν•­λͺ© μ˜†μ— `lfm2_vl_vision_encoder.cpp`λ₯Ό μΆ”κ°€. Android.mk λΉŒλ“œλ„ μ“°λ©΄ 동일 글둭에 ν¬ν•¨λ˜λŠ”μ§€ 확인(λ©”λͺ¨λ¦¬ `gauss-pluggable-bringup`의 generic `*/*.cpp` 글둭이면 μžλ™ 포함). + +- [ ] **Step 4: λΉŒλ“œ 확인** + +Run: 호슀트 λΉŒλ“œ(λ©”λͺ¨λ¦¬ ν™˜κ²½) β†’ Expected: `lfm2_vl_vision_encoder.cpp` 컴파일·링크 성곡. + +- [ ] **Step 5: Commit** + +```bash +cd nntrainer && git add Applications/CausalLM/models/lfm2/lfm2-vl/lfm2_vl_vision_encoder.{h,cpp} Applications/CausalLM/models/lfm2/lfm2-vl/meson.build +git commit -m "feat(lfm2-vl): add Lfm2VlVisionEncoder (ViT+connector -> run_image)" +``` + +### Task 1.3: ν—€λ“œλ¦¬μŠ€ 단독 ν…ŒμŠ€νŠΈμš© Factory 등둝 (main.cpp) + +**Files:** +- Modify: `nntrainer/Applications/CausalLM/main.cpp` + +- [ ] **Step 1: include + 등둝 μΆ”κ°€** + +`main.cpp` 상단 include에 `#include "models/lfm2/lfm2-vl/lfm2_vl_vision_encoder.h"` μΆ”κ°€. κΈ°μ‘΄ `Lfm2ForCausalLM` registerModel(295ν–‰ λΆ€κ·Ό) λ°”λ‘œ μ•„λž˜μ— μΆ”κ°€: + +```cpp + causallm::Factory::Instance().registerModel( + "Lfm2VlVisionEncoder", [](json cfg, json generation_cfg, json nntr_cfg) { + return std::make_unique( + cfg, generation_cfg, nntr_cfg); + }); +``` + +- [ ] **Step 2: λΉŒλ“œ + 등둝 확인** + +Run: 호슀트 λΉŒλ“œ ν›„ `./quick_dot_ai_test` 인자 없이 μ‹€ν–‰ β†’ Expected: `printRegistered`에 `Lfm2VlVisionEncoder` 포함. + +- [ ] **Step 3: Commit** + +```bash +cd nntrainer && git add Applications/CausalLM/main.cpp +git commit -m "feat(lfm2-vl): register Lfm2VlVisionEncoder in standalone factory" +``` + +--- + +## Milestone 2 β€” Quick.AI api: composer 개방 + 등둝 + 마컀 μ •ν•© + +### Task 2.1: api Factory에 LFM2 LM + vision encoder 등둝 + +**Files:** +- Modify: `api/quick_dot_ai_api.cpp` (`register_models()`, ~257–317) + +- [ ] **Step 1: include μΆ”κ°€** + +파일 상단 λͺ¨λΈ include μ˜μ—­μ—: + +```cpp +#include "lfm2_causallm.h" +#include "lfm2-vl/lfm2_vl_vision_encoder.h" +``` + +(κ²½λ‘œλŠ” api meson의 include_directoriesκ°€ `models/lfm2`λ₯Ό κ°€λ¦¬ν‚€λŠ”μ§€ 확인; μ•„λ‹ˆλ©΄ μƒλŒ€κ²½λ‘œ `../nntrainer/Applications/CausalLM/models/lfm2/...`둜 λ§žμΆ˜λ‹€.) + +- [ ] **Step 2: `register_models()`의 `MultilingualTinyBert` 등둝 λ’€, `#ifdef ENABLE_QNN` μ•žμ— μΆ”κ°€** + +```cpp + causallm::Factory::Instance().registerModel( + "Lfm2ForCausalLM", [](json cfg, json generation_cfg, json nntr_cfg) { + return std::make_unique(cfg, generation_cfg, + nntr_cfg); + }); + causallm::Factory::Instance().registerModel( + "Lfm2VlVisionEncoder", + [](json cfg, json generation_cfg, json nntr_cfg) { + return std::make_unique( + cfg, generation_cfg, nntr_cfg); + }); +``` + +- [ ] **Step 3: api λΉŒλ“œ 확인** + +Run: `./build.sh --platform=android` (λ©”λͺ¨λ¦¬ ν™˜κ²½) β†’ Expected: `libquick_dot_ai_api.so` 링크 성곡(λ―Έμ •μ˜ 심볼 μ—†μŒ). μ‹€νŒ¨ μ‹œ api meson.build의 μ†ŒμŠ€/링크에 lfm2 TUκ°€ libcausallm둜 λ“€μ–΄μ˜€λŠ”μ§€ 확인. + +- [ ] **Step 4: Commit** + +```bash +git add api/quick_dot_ai_api.cpp +git commit -m "feat(api): register Lfm2ForCausalLM + Lfm2VlVisionEncoder in factory" +``` + +### Task 2.2: 곡개 descriptor 2개 μΆ”κ°€ + +**Files:** +- Modify: `api/model_descriptors_public.cpp` + +- [ ] **Step 1: `kPublic[]` λ°°μ—΄μ˜ `gemma4-cpu` ν•­λͺ© λ’€(QNN `#ifdef` μ•ž)에 μΆ”κ°€** + +```cpp + {"lfm2-450m", "lfm2-vl", "LFM2-VL 450M (LM)", QDA_RUNTIME_NATIVE, B(0), + QDA_CAP_STREAMING | QDA_CAP_MULTIMODAL, + "LFM2-450M", /* device dir: /models/lfm2-450m */ + "Lfm2ForCausalLM"}, + {"siglip2-vl-encoder", "lfm2-vl", "SigLIP2 Vision Encoder", + QDA_RUNTIME_NATIVE, B(0), + QDA_CAP_VISION_ENCODER, + "SIGLIP2-VL-ENCODER", /* device dir: /models/siglip2-vl-encoder */ + "Lfm2VlVisionEncoder"}, +``` + +> `config_name`은 μ†Œλ¬Έμžν™” μ‹œ λ””λ°”μ΄μŠ€ 디렉터리λͺ…κ³Ό μΌμΉ˜ν•΄μ•Ό ν•œλ‹€(λ©”λͺ¨λ¦¬: resolve_model_pathκ°€ lowercased config_name을 dir둜 μ‚¬μš©, quant suffixλŠ” dead code). λ”°λΌμ„œ λ””λ°”μ΄μŠ€ dir은 `lfm2-450m`, `siglip2-vl-encoder`. + +- [ ] **Step 2: λΉŒλ“œ + μΉ΄νƒˆλ‘œκ·Έ 확인** + +Run: api λΉŒλ“œ ν›„ 호슀트/λ””λ°”μ΄μŠ€μ—μ„œ `getModelCatalogJson()` 좜λ ₯(λ˜λŠ” `api-app/test_api.cpp`둜 μΉ΄νƒˆλ‘œκ·Έ 덀프) β†’ Expected: `lfm2-450m`(cap 0b0101=STREAMING|MULTIMODAL)κ³Ό `siglip2-vl-encoder`(cap 0b1000000=VISION_ENCODER) λ“±μž₯. + +- [ ] **Step 3: Commit** + +```bash +git add api/model_descriptors_public.cpp +git commit -m "feat(api): add lfm2-450m + siglip2-vl-encoder public descriptors" +``` + +### Task 2.3: λ©€ν‹°λͺ¨λ‹¬ 경둜 CPU 개방 + 이미지 마컀 LFM2 μ •ν•© + ν”½μ…€ λ ˆμ΄μ•„μ›ƒ μΌλ°˜ν™” + +**λ°°κ²½:** `execute_multimodal`/`run_vision_encoder`와 `runMultimodal*` μ§„μž…μ μ΄ `#ifdef ENABLE_QNN`둜 κ°€λ“œλΌ CPU λΉŒλ“œμ—μ„œ λΉ„ν™œμ„±. 또 λ§ˆμ»€κ°€ `<|image|>` ν•˜λ“œμ½”λ”©(LFM2λŠ” ``=396), `run_vision_encoder`κ°€ `PATCH_SIZE=512`(512Β²) ν•˜λ“œμ½”λ”©(LFM2λŠ” 256Β²). + +**Files:** +- Modify: `api/quick_dot_ai_api.cpp` + +- [ ] **Step 1: composer/헬퍼λ₯Ό QNN κ°€λ“œ λ°–μœΌλ‘œ 이동** + +`#ifdef ENABLE_QNN`(2166) ... `#endif`(2289)둜 감싼 `execute_multimodal`와 `run_vision_encoder`λ₯Ό κ°€λ“œ **λ°–**으둜 κΊΌλ‚Έλ‹€(κ°€λ“œ 제거). 단 `run_vision_encoder` 내뢀에 QNN μ „μš© μ½”λ“œκ°€ 있으면 κ·Έ 라인만 `#ifdef ENABLE_QNN`둜 κ΅­μ†Œν™”. 또 `runMultimodalHandleStreaming`(2351), `runMultimodalHandleWithMessages`(2451) λ“±μ—μ„œ `#ifdef ENABLE_QNN ... #else LOGE("built without ENABLE_QNN") #endif` ꡬ쑰의 `#ifdef`/`#else`λ₯Ό μ œκ±°ν•΄ 본문이 항상 컴파일되게 ν•œλ‹€. + +- [ ] **Step 2: 이미지 토큰 id 해석을 LFM2 ν˜Έν™˜μœΌλ‘œ** + +`execute_multimodal`의 2188행을 ꡐ체: + +```cpp + // LFM2 uses "" (id 396); gauss/vjepa use "<|image|>". Try both. + int32_t image_token_id = tok->TokenToId("<|image|>"); + if (image_token_id < 0) + image_token_id = tok->TokenToId(""); +``` + +- [ ] **Step 3: `run_vision_encoder` ν”½μ…€ λ°”μ΄νŠΈ 계산을 λͺ¨λΈ μ£Όλ„λ‘œ** + +2280–2284ν–‰μ˜ `PATCH_SIZE=512` 가정을 μ œκ±°ν•˜κ³ , μ‹€μ œ 이미지 해상도 기반으둜: + +```cpp + // Pixel buffer is [3 * H * W] FP32 (CHW). The vision model interprets its + // own expected resolution; we size the pointer by the given dimensions. + const size_t pixel_bytes = static_cast(3) * + static_cast(originalHeight) * + static_cast(originalWidth) * sizeof(float); +``` + +(κΈ°μ‘΄ numPatches*3*512*512 라인 μ‚­μ œ. `numPatches` μΈμžλŠ” LFM2 κ²½λ‘œμ—μ„œ λ―Έμ‚¬μš©μ΄λ‚˜ μ‹œκ·Έλ‹ˆμ²˜λŠ” μœ μ§€.) + +- [ ] **Step 4: CPU λΉŒλ“œ 확인** + +Run: `./build.sh --platform=android` (ENABLE_QNN 없이도 λΉŒλ“œλ˜λŠ” native λ³€ν˜• λ˜λŠ” QNN λΉŒλ“œ) β†’ Expected: `runMultimodalHandleWithMessagesStreaming` 등이 LOGE μŠ€ν…μ΄ μ•„λ‹Œ μ‹€μ œ 본문으둜 컴파일. + +- [ ] **Step 5: QNN 경둜 λ¬΄νšŒκ·€ 점검(슀λͺ¨ν¬)** + +κΈ°μ‘΄ `vjepa-qnn`/gauss-vision μΉ΄νƒˆλ‘œκ·ΈΒ·λ‘œλ“œ κ²½λ‘œκ°€ μ»΄νŒŒμΌΒ·λ“±λ‘ κ·ΈλŒ€λ‘œμΈμ§€ 확인(κ°€λ“œ μ œκ±°κ°€ QNN μ „μš© μ½”λ“œλ₯Ό κΉ¨μ§€ μ•Šμ•˜λŠ”μ§€). λ””λ°”μ΄μŠ€ QNN λΉŒλ“œμ—μ„œ V-JEPA λ©€ν‹°λͺ¨λ‹¬ 1회 λ‘œλ“œκΉŒμ§€. + +- [ ] **Step 6: Commit** + +```bash +git add api/quick_dot_ai_api.cpp +git commit -m "feat(api): enable multimodal composer on CPU + LFM2 image marker + model-driven pixel layout" +``` + +--- + +## Milestone 3 β€” ν—€λ“œλ¦¬μŠ€ 검증 (였라클 λŒ€μ‘°) + +### Task 3.1: λ””λ°”μ΄μŠ€ 에셋 + nntr_config μ€€λΉ„ + +**Files:** (λ””λ°”μ΄μŠ€, λΉŒλ“œ μ‚°μΆœλ¬Ό μ•„λ‹˜) + +- [ ] **Step 1: LFM2-VL κ°€μ€‘μΉ˜ λ³€ν™˜ + 배치** + +nntrainer `Applications/CausalLM/res/lfm2-vl/`의 λ³€ν™˜ 슀크립트둜 HF LFM2-VL-450M β†’ nntrainer λ°”μ΄λ„ˆλ¦¬ 생성: +- `convert_vision_hf.py` β†’ ViT κ°€μ€‘μΉ˜, `convert_connector.py` β†’ connector κ°€μ€‘μΉ˜, `convert_lm.py` β†’ LM κ°€μ€‘μΉ˜, `convert_embedding.py` β†’ standalone μž„λ² λ”© bin. + +λ””λ°”μ΄μŠ€ 배치: +``` +/sdcard/Download/aistudio-mobile/models/siglip2-vl-encoder/ (ViT.bin, connector.bin, nntr_config.json) +/sdcard/Download/aistudio-mobile/models/lfm2-450m/ (lm.bin, tokenizer, embedding.bin, nntr_config.json) +``` + +- [ ] **Step 2: `nntr_config.json` μž‘μ„±** + +vision: `{"vision_model_file":"vit.bin","connector_model_file":"connector.bin","vision_config":{"image_size":256,"patch_size":16,"hidden_size":768},"projector_hidden_size":2560,"text_hidden_size":1024,"downsample_factor":2}` +LM: κΈ°μ‘΄ LFM2 LM 단독 config에 `embedding_bin_path`/`tokenizer_file`/`model_file_name` 포함(λͺ¨λ†€λ¦¬μ‹ `res/lfm2-vl/README.md` 토큰 id·치수 μ°Έκ³ : image=396,start=498,end=499, LM hidden=1024). + +### Task 3.2: λͺ¨λ†€λ¦¬μ‹ 였라클 좜λ ₯ 캑처 + +- [ ] **Step 1: λͺ¨λ†€λ¦¬μ‹ 경둜둜 μ •λ‹΅ 생성** + +μ „μ²˜λ¦¬λœ 이미지 ν…μ„œ(`naflex_preprocess.py`둜 256Β² FP32 [3,256,256] 생성)와 ν”„λ‘¬ν”„νŠΈλ‘œ standalone μ‹€ν–‰: + +Run (λ””λ°”μ΄μŠ€ `/data/local/tmp/Quick.AI`): `./quick_dot_ai_test`에 `architecture=Lfm2VlForConditionalGeneration`, `image_tensor_path=<...>.bin`, `sample_input="What is in this image?"` μ„€μ •ν•œ nntr_config둜 μ‹€ν–‰. +Expected: 의미 μžˆλŠ” μΊ‘μ…˜/응닡 + 생성 토큰열을 둜그둜 μ €μž₯(였라클). + +### Task 3.3: composer νŽ˜μ–΄ 경둜둜 동일 μž…λ ₯ μž¬ν˜„ + +- [ ] **Step 1: api ν—€λ“œλ¦¬μŠ€λ‘œ vision+LLM νŽ˜μ–΄ λ‘œλ“œ ν›„ λ©€ν‹°λͺ¨λ‹¬ μ‹€ν–‰** + +`api-app/test_api.cpp`(λ˜λŠ” 동등 ν—€λ“œλ¦¬μŠ€ λ“œλΌμ΄λ²„)에 μΌ€μ΄μŠ€ μΆ”κ°€: `loadMultimodalHandleByName("siglip2-vl-encoder", "lfm2-450m", ...)` β†’ μ „μ²˜λ¦¬λœ 동일 이미지(256Β² FP32) + ν”„λ‘¬ν”„νŠΈ(`` 포함)둜 `runMultimodalHandleStreaming` 호좜. + +```cpp +// pseudo-driver (test_api.cpp): exact API per quick_dot_ai_api.h +CausalLmHandle h = nullptr; +loadMultimodalHandleByName(/*compute*/CPU, "siglip2-vl-encoder", "lfm2-450m", + model_base_path, native_lib_dir, &h); +runMultimodalHandleStreaming(h, "What is in this image?", + pixels /*[3*256*256] FP32*/, /*numPatches*/256, + /*H*/256, /*W*/256, on_token, nullptr); +``` + +- [ ] **Step 2: 였라클 λŒ€μ‘° (1μ°¨ 게이트)** + +Run: μœ„ λ“œλΌμ΄λ²„ μ‹€ν–‰. +Expected: composer 경둜 생성 토큰열이 Task 3.2 였라클과 **일치(λ˜λŠ” μ˜λ―Έμƒ 동일)**. 뢈일치 μ‹œ [[gemma4-qnn-garbage-debug]] λ₯˜ ν—€λ“œλ¦¬μŠ€ 마컀둜 디버그: +- (a) `embeddingBytesPerToken`==4096, `lookupEmbedding(396)`이 nullptr μ•„λ‹˜ 확인. +- (b) `[MM] text=.. image=.. total=..` λ‘œκ·Έμ—μ„œ image 토큰 수 == connector `outTokens`(=256/4=64) 확인. +- (c) vision run_image 좜λ ₯ μž„λ² λ”©μ˜ μŠ€μΌ€μΌμ΄ LM μž„λ² λ”©κ³Ό 동일 λ²”μœ„(FP32 identity)인지 β€” connector 좜λ ₯ vs `lookupEmbedding`된 ν…μŠ€νŠΈ μž„λ² λ”©μ˜ norm 비ꡐ. +- (d) 마컀 μœ„μΉ˜: ``κ°€ 396으둜 ν† ν°ν™”λ˜μ–΄ splice μœ„μΉ˜κ°€ μ˜€μΌ€μŠ€νŠΈλ ˆμ΄ν„°(`<|image_start|><|image_end|>`)와 μ •ν•©ν•˜λŠ”μ§€ β€” ν•„μš” μ‹œ ν”„λ‘¬ν”„νŠΈ ν…œν”Œλ¦Ώ/마컀λ₯Ό μ˜€μΌ€μŠ€νŠΈλ ˆμ΄ν„°μ™€ λ™μΌν•˜κ²Œ. + +- [ ] **Step 3: 검증 λ©”λͺ¨ 컀밋(있으면)** + +```bash +git add api-app/test_api.cpp +git commit -m "test(api): headless LFM2-VL composer pair vs monolithic oracle" +``` + +--- + +## Milestone 4 β€” Android μ•±: λ―ΉμŠ€μ•€λ§€μΉ˜ picker + NaFlex μ „μ²˜λ¦¬ + +### Task 4.1: `loadMultimodalHandleByName` JNI λ…ΈμΆœ + +**Files:** +- Modify: `Android/QuickDotAI/src/main/cpp/quickai_jni.cpp` +- Modify: `Android/QuickDotAI/src/main/java/com/example/quickdotai/NativeCausalLm.kt` + +- [ ] **Step 1: JNI ν•¨μˆ˜ μΆ”κ°€ (quickai_jni.cpp)** + +κΈ°μ‘΄ `loadModelHandleByName` JNI κ΅¬ν˜„ νŒ¨ν„΄μ„ 따라 μΆ”κ°€(ν•Έλ“€ long λ°˜ν™˜): + +```cpp +extern "C" JNIEXPORT jlong JNICALL +Java_com_example_quickdotai_NativeCausalLm_loadMultimodalHandleByNameNative( + JNIEnv *env, jobject /*thiz*/, jint compute, jstring emb_id, jstring llm_id, + jstring base_path, jstring native_lib_dir) { + const char *c_emb = env->GetStringUTFChars(emb_id, nullptr); + const char *c_llm = env->GetStringUTFChars(llm_id, nullptr); + const char *c_base = env->GetStringUTFChars(base_path, nullptr); + const char *c_nld = + native_lib_dir ? env->GetStringUTFChars(native_lib_dir, nullptr) : nullptr; + CausalLmHandle handle = nullptr; + ErrorCode ec = loadMultimodalHandleByName( + static_cast(compute), c_emb, c_llm, c_base, c_nld, &handle); + env->ReleaseStringUTFChars(emb_id, c_emb); + env->ReleaseStringUTFChars(llm_id, c_llm); + env->ReleaseStringUTFChars(base_path, c_base); + if (c_nld) env->ReleaseStringUTFChars(native_lib_dir, c_nld); + if (ec != CAUSAL_LM_ERROR_NONE) { + LOGE("loadMultimodalHandleByNameNative failed: %d", ec); + return 0; + } + return reinterpret_cast(handle); +} +``` + +> `loadMultimodalHandleByName`의 μ •ν™•ν•œ C μ‹œκ·Έλ‹ˆμ²˜λŠ” `api/quick_dot_ai_api.h:250` 확인 ν›„ 인자 μˆœμ„œλ₯Ό λ§žμΆ˜λ‹€(compute/emb_id/llm_id/base/native_lib_dir/out_handle). + +- [ ] **Step 2: Kotlin external fun + 래퍼 (NativeCausalLm.kt)** + +κΈ°μ‘΄ λ©€ν‹°λͺ¨λ‹¬ external fun μ˜μ—­(272ν–‰ λΆ€κ·Ό)에: + +```kotlin +external fun loadMultimodalHandleByNameNative( + compute: Int, embId: String, llmId: String, + basePath: String, nativeLibDir: String? +): Long +``` + +그리고 곡개 래퍼(κΈ°μ‘΄ loadModelHandleByName 래퍼 μ˜†): + +```kotlin +fun loadMultimodalHandleByName( + compute: BackendType, embId: String, llmId: String, + basePath: String, nativeLibDir: String? = null +): Long = loadMultimodalHandleByNameNative( + compute.ordinal, embId, llmId, basePath, nativeLibDir) +``` + +- [ ] **Step 3: λΉŒλ“œ 확인** + +Run: `./Android/gradlew :QuickDotAI:assembleDebug` β†’ Expected: BUILD SUCCESSFUL, JNI 심볼 링크. + +- [ ] **Step 4: Commit** + +```bash +git add Android/QuickDotAI/src/main/cpp/quickai_jni.cpp Android/QuickDotAI/src/main/java/com/example/quickdotai/NativeCausalLm.kt +git commit -m "feat(android): expose loadMultimodalHandleByName via JNI" +``` + +### Task 4.2: ModelCatalog 헬퍼 (vision/LLM λΆ„λ₯˜) + +**Files:** +- Modify: `Android/QuickDotAI/src/main/java/com/example/quickdotai/ModelCatalog.kt` + +- [ ] **Step 1: 헬퍼 μΆ”κ°€** + +`ModelCatalog` object에: + +```kotlin +/** Vision-encoder producers selectable as the multimodal "eye". */ +fun visionEncoders(): List = + all().filter { Capability.VISION_ENCODER in it.capabilities } + +/** LLMs that can consume image embeddings (MULTIMODAL-capable). */ +fun pairableLlms(): List = + all().filter { Capability.MULTIMODAL in it.capabilities } +``` + +> `Capability` enum에 `VISION_ENCODER`κ°€ μžˆλŠ”μ§€ 확인. μ—†μœΌλ©΄ `ModelCatalog.kt`의 capability λΉ„νŠΈ νŒŒμ„œ(104–111ν–‰)에 `0b1000000 -> VISION_ENCODER` μΆ”κ°€ + enum 보강(api `QDA_CAP_VISION_ENCODER = 1u<<6`와 μ •ν•©). + +- [ ] **Step 2: λΉŒλ“œ + λ‹¨μœ„ 확인** + +Run: `./Android/gradlew :QuickDotAI:compileDebugKotlin` β†’ Expected: 성곡. (κ°€λŠ₯ν•˜λ©΄ `visionEncoders()`κ°€ `siglip2-vl-encoder`, `pairableLlms()`κ°€ `lfm2-450m` ν¬ν•¨ν•˜λŠ”μ§€ μΉ΄νƒˆλ‘œκ·Έ 덀프 둜그둜 확인.) + +- [ ] **Step 3: Commit** + +```bash +git add Android/QuickDotAI/src/main/java/com/example/quickdotai/ModelCatalog.kt +git commit -m "feat(android): catalog helpers for vision encoders + pairable LLMs" +``` + +### Task 4.3: SigLIP-NaFlex μ „μ²˜λ¦¬(MVP κ³ μ • 256Β²) + +**Files:** +- Create: `Android/QuickDotAI/src/main/java/com/example/quickdotai/SigLipNaFlexImageProcessor.kt` + +- [ ] **Step 1: ν”„λ‘œμ„Έμ„œ μž‘μ„± (256Β² square, mean/std 0.5, CHW)** + +`LlavaNextImageProcessor`의 `ModelInput`/μ •κ·œν™” νŒ¨ν„΄μ„ λ”°λ₯΄λ˜ 256Β²Β·patch16: + +```kotlin +package com.example.quickdotai + +import android.graphics.Bitmap +import android.graphics.Color + +/** + * SigLIP2 (LFM2-VL) preprocessing β€” MVP: fixed 256x256 square resize, no + * NaFlex dynamic resolution. Output: FP32 CHW [3*256*256], normalized to + * (x/255 - 0.5)/0.5. + */ +class SigLipNaFlexImageProcessor { + companion object { + const val IMAGE_SIZE = 256 + const val PATCH_SIZE = 16 + private const val MEAN = 0.5f + private const val STD = 0.5f + } + + /** Returns (pixelValues CHW FP32, numPatches=(IMAGE_SIZE/PATCH_SIZE)^2). */ + fun preprocess(src: Bitmap): NativeCausalLm.MultimodalInput { + val resized = Bitmap.createScaledBitmap(src, IMAGE_SIZE, IMAGE_SIZE, true) + val n = IMAGE_SIZE * IMAGE_SIZE + val out = FloatArray(3 * n) + val px = IntArray(n) + resized.getPixels(px, 0, IMAGE_SIZE, 0, 0, IMAGE_SIZE, IMAGE_SIZE) + for (i in 0 until n) { + val p = px[i] + out[i] = ((Color.red(p) / 255f) - MEAN) / STD // R plane + out[n + i] = ((Color.green(p) / 255f) - MEAN) / STD // G plane + out[2 * n + i] = ((Color.blue(p) / 255f) - MEAN) / STD // B plane + } + val patches = (IMAGE_SIZE / PATCH_SIZE) * (IMAGE_SIZE / PATCH_SIZE) + return NativeCausalLm.MultimodalInput( + pixelValues = out, numPatches = patches, + originalHeight = IMAGE_SIZE, originalWidth = IMAGE_SIZE) + } +} +``` + +> `MultimodalInput` ν•„λ“œλͺ…/νŒ¨ν‚€μ§€λŠ” `NativeCausalLm.kt:117` 확인 ν›„ μ •ν•©. `numPatches`λŠ” ν—€λ“œλ¦¬μŠ€μ—μ„œ κ²€μ¦ν•œ vision encoder κΈ°λŒ€κ°’(=256, ViT μž…λ ₯ 패치 수)κ³Ό 동일 μ˜λ―ΈμΈμ§€ 확인 β€” composerλŠ” numPatchesλ₯Ό LFM2 κ²½λ‘œμ—μ„œ λ―Έμ‚¬μš©ν•˜λ―€λ‘œ 값은 μ§„λ‹¨μš©. + +- [ ] **Step 2: λΉŒλ“œ 확인** + +Run: `./Android/gradlew :QuickDotAI:compileDebugKotlin` β†’ Expected: 성곡. + +- [ ] **Step 3: Commit** + +```bash +git add Android/QuickDotAI/src/main/java/com/example/quickdotai/SigLipNaFlexImageProcessor.kt +git commit -m "feat(android): SigLIP NaFlex image processor (MVP fixed 256)" +``` + +### Task 4.4: λ―ΉμŠ€μ•€λ§€μΉ˜ picker UI + νŽ˜μ–΄ λ‘œλ“œ + μ „μ²˜λ¦¬ λΆ„κΈ° + +**Files:** +- Modify: `Android/SampleTestAPP/.../MainActivity.kt` + +- [ ] **Step 1: vision/LLM λ“œλ‘­λ‹€μš΄ 2개 μΆ”κ°€** + +OpenAI νƒ­(λ©€ν‹°λͺ¨λ‹¬ 경둜)에 κΈ°μ‘΄ FAMILY λ“œλ‘­λ‹€μš΄ νŒ¨ν„΄(`dropdownField`)을 μž¬μ‚¬μš©ν•΄ 두 개 μΆ”κ°€: +- "VISION ENCODER" = `ModelCatalog.visionEncoders().map { it.id }` +- "LLM" = `ModelCatalog.pairableLlms().map { it.id }` +선택값을 μƒνƒœλ‘œ 보관(`var selectedVisionId by remember`, `var selectedLlmId by remember`). + +- [ ] **Step 2: νŽ˜μ–΄ λ‘œλ“œ λ²„νŠΌ λ™μž‘** + +"Load multimodal pair" μ•‘μ…˜μ—μ„œ: + +```kotlin +val handle = NativeCausalLm.loadMultimodalHandleByName( + compute = selectedBackend, embId = selectedVisionId, + llmId = selectedLlmId, basePath = modelBasePath, + nativeLibDir = applicationInfo.nativeLibraryDir) +if (handle == 0L) { /* show error banner */ } else { currentHandle = handle } +``` + +- [ ] **Step 3: μ „μ²˜λ¦¬ λΆ„κΈ° (λͺ¨λΈλ³„)** + +이미지 ν”½ ν›„ ν”„λ‘œμ„Έμ„œ 선택: + +```kotlin +val mmInput = if (selectedVisionId == "siglip2-vl-encoder") + SigLipNaFlexImageProcessor().preprocess(bitmap) +else + LlavaNextImageProcessor(...).process(bitmap) // κΈ°μ‘΄ 경둜 μœ μ§€ +``` + +그리고 `runMultimodalHandleWithMessagesStreaming(handle, messages, ..., mmInput.pixelValues, mmInput.numPatches, mmInput.originalHeight, mmInput.originalWidth, listener)` 호좜(ν”„λ‘¬ν”„νŠΈ/λ©”μ‹œμ§€μ— `` 포함). + +- [ ] **Step 4: λΉŒλ“œ 확인** + +Run: `./Android/gradlew :SampleTestAPP:compileDebugKotlin` β†’ Expected: BUILD SUCCESSFUL. + +- [ ] **Step 5: Commit** + +```bash +git add Android/SampleTestAPP +git commit -m "feat(app): vision+LLM mix-and-match picker + LFM2-VL preprocessing path" +``` + +--- + +## Milestone 5 β€” APK on-device 검증 + +### Task 5.1: APK λΉŒλ“œΒ·μ„€μΉ˜ + νŽ˜μ–΄ λ‘œλ“œΒ·μΆ”λ‘  + +**절차:** λ©”λͺ¨λ¦¬ `quickai-android-build-env`(λΉŒλ“œ/μ„€μΉ˜) + `gauss-pluggable-bringup`(APK verify: 탭별 λ‘œλ”, μΌ€μ΄μŠ€λ§ˆλ‹€ force-stop+relaunch). + +- [ ] **Step 1: λΉŒλ“œΒ·μ„€μΉ˜** + +Run: `./build.sh --platform=android && ./apk_install_android.sh` (λ˜λŠ” λ©”λͺ¨λ¦¬ κΈ°μ€€ λͺ…λ Ή). +Expected: μ„€μΉ˜ 성곡, 16KB μ •λ ¬ κ²½κ³ λŠ” 무해(OK). + +- [ ] **Step 2: λ””λ°”μ΄μŠ€ 에셋 확인** + +`/sdcard/Download/aistudio-mobile/models/siglip2-vl-encoder/`, `/lfm2-450m/`에 κ°€μ€‘μΉ˜+config 쑴재(Task 3.1κ³Ό 동일). + +- [ ] **Step 3: λ―ΉμŠ€μ•€λ§€μΉ˜ λ‘œλ“œ + 이미지 μΆ”λ‘ ** + +OpenAI νƒ­ β†’ VISION="siglip2-vl-encoder", LLM="lfm2-450m" 선택 β†’ "Load multimodal pair" β†’ 이미지 ν”½ β†’ λ©”μ‹œμ§€(`` 포함) β†’ Run(streaming). +Expected: logcat `QuickAI`에 `SINGLE/MULTI-MODEL SUCCESS` + `[MM] text=.. image=64 total=..`, μ½˜μ†”μ— 이미지λ₯Ό μ‹€μ œ μΈμ§€ν•œ 일관 응닡. gauss-vision "이미지 μ•ˆ λ³΄μž„" νšŒκ·€ μ—†λŠ”μ§€ 확인. + +- [ ] **Step 4: ν…μŠ€νŠΈ-only LFM2 LM 단독 점검(νšŒκ·€)** + +LLM만 `lfm2-450m`둜 일반 ν…μŠ€νŠΈ λ‘œλ“œ/생성도 정상인지(μž„λ² λ”© 가상 μΆ”κ°€κ°€ ν…μŠ€νŠΈ 경둜 무영ν–₯) 확인. + +- [ ] **Step 5: κ²°κ³Όλ₯Ό λ©”λͺ¨λ¦¬μ— 기둝** + +검증 맀트릭슀(λ””λ°”μ΄μŠ€/μΌ€μ΄μŠ€/κ²°κ³Ό)λ₯Ό λ©”λͺ¨λ¦¬ `gauss-pluggable-bringup` λ˜λŠ” μ‹ κ·œ `lfm2-vl-bringup` λ©”λͺ¨λ¦¬μ— μΆ”κ°€. + +--- + +## Self-Review (spec λŒ€μ‘°) + +- **λΆ„ν•΄(spec 4.1)** β†’ Task 1.1/1.2/1.3 βœ… (vision encoder 래퍼 + LM base 가상 + Factory). +- **composer CPU 개방(spec 4.2)** β†’ Task 2.3 Step 1/4 βœ…. +- **이미지 마컀 μ •ν•© 396(spec 4.2)** β†’ Task 2.3 Step 2 βœ…. +- **splice 토큰 수 = outTokens(spec 4.2)** β†’ κΈ°μ‘΄ composerκ°€ `image_embeds.size()/bpt`둜 μžλ™ 계산(Task 3.3 Step 2(b)μ—μ„œ 64 검증) βœ…. +- **곡개 descriptor 2개(spec 4.2)** β†’ Task 2.2 βœ…. +- **JNI loadMultimodalHandleByName(spec 4.3)** β†’ Task 4.1 βœ…. +- **λ―ΉμŠ€μ•€λ§€μΉ˜ UI(spec 4.3)** β†’ Task 4.2 + 4.4 βœ…. +- **NaFlex μ „μ²˜λ¦¬ MVP 256Β²(spec 4.3)** β†’ Task 4.3 βœ…. +- **ν—€λ“œλ¦¬μŠ€β†’APK 단계 검증(spec 6)** β†’ Milestone 3 + 5 βœ…. +- **리슀크: μž„λ² λ”© 곡간 μ •ν•©(spec 7)** β†’ Task 3.3 Step 2(a)(c) 게이트 βœ…. +- **리슀크: 마컀 의미둠(spec 7)** β†’ Task 3.3 Step 2(d) 였라클 λŒ€μ‘° βœ…. +- **리슀크: QNN λ¬΄νšŒκ·€(spec 7)** β†’ Task 2.3 Step 5 βœ…. + +**미반영(μ˜λ„μ , spec 8 후속):** full NaFlex 동적 해상도, 닀쀑 이미지, NPU μ–‘μžν™”, ViT in-memory μž…λ ₯ 경둜(MVPλŠ” temp 파일), `ModelDescriptor` image_token μŠ€ν‚€λ§ˆ 정식화. + +**μ•Œλ €μ§„ λΆˆν™•μ •(κ΅¬ν˜„ 쀑 첫 λΉŒλ“œμ—μ„œ ν™•μ •):** (1) `DIM` λ“± base hidden-size 멀버 μ •ν™•λͺ… β€” Task 1.1μ—μ„œ concrete `lookupEmbedding` 본문으둜 확인. (2) lfm2 헀더 include 경둜(api meson include_directories) β€” Task 2.1 Step 1. (3) `MultimodalInput`/`Capability.VISION_ENCODER` 쑴재 β€” Task 4.2/4.3. (4) `loadMultimodalHandleByName` C μ‹œκ·Έλ‹ˆμ²˜ 인자 μˆœμ„œ β€” Task 4.1. diff --git a/docs/superpowers/specs/2026-06-02-family-dropdown-design.md b/docs/superpowers/specs/2026-06-02-family-dropdown-design.md new file mode 100644 index 00000000..be403259 --- /dev/null +++ b/docs/superpowers/specs/2026-06-02-family-dropdown-design.md @@ -0,0 +1,131 @@ +# Model Family 선택을 λ“œλ‘­λ‹€μš΄μœΌλ‘œ λ³€κ²½ β€” 섀계 λ¬Έμ„œ + +- **λ‚ μ§œ**: 2026-06-02 +- **λŒ€μƒ μ•±**: Quick.AI Android (`SampleTestAPP`) +- **μƒνƒœ**: 승인됨, κ΅¬ν˜„ κ³„νš λŒ€κΈ° + +## 1. λͺ©μ  + +ν˜„μž¬ λͺ¨λΈ **FAMILY** 선택은 κ°€λ‘œ 슀크둀 μΉ© ν–‰(`chipRow()`)으둜 κ΅¬ν˜„λ˜μ–΄ μžˆλ‹€. FAMILY ν•­λͺ© μˆ˜κ°€ λŠ˜μ–΄λ‚˜λ©΄μ„œ μΉ© 행이 κΈΈμ–΄μ§€κ³  μ–΄λ–€ 값이 μ„ νƒλλŠ”μ§€ ν•œλˆˆμ— νŒŒμ•…ν•˜κΈ° μ–΄λ ΅λ‹€. FAMILY 선택을 Material μŠ€νƒ€μΌ λ“œλ‘­λ‹€μš΄ ν•„λ“œλ‘œ λ°”κΏ” 선택값 가독성과 ν™”λ©΄ 곡간 νš¨μœ¨μ„ κ°œμ„ ν•œλ‹€. + +## 2. λ²”μœ„ + +### 포함 +- **FAMILY μ„ νƒλ§Œ** λ“œλ‘­λ‹€μš΄μœΌλ‘œ λ³€κ²½. +- Run/OpenAI νƒ­κ³Ό Chat νƒ­ **두 κ³³ λͺ¨λ‘** 적용. + +### μ œμ™Έ (λ³€κ²½ν•˜μ§€ μ•ŠμŒ) +- RUNTIME / BACKEND / QUANTIZATION 선택 β€” κΈ°μ‘΄ `chipRow()` μΉ© κ·ΈλŒ€λ‘œ μœ μ§€. +- FAMILY λ³€κ²½ μ‹œ RUNTIME/BACKENDλ₯Ό μž¬κ³„μ‚°ν•˜λŠ” cascading 둜직 β€” κΈ°μ‘΄ λžŒλ‹€ κ·ΈλŒ€λ‘œ μž¬μ‚¬μš©. +- `ModelCatalog.kt` β€” μ˜΅μ…˜ μ†ŒμŠ€(`ModelCatalog.families()`)λŠ” λ³€κ²½ 없이 κ·ΈλŒ€λ‘œ μ‚¬μš©. +- μžλ™ν™” UI ν…ŒμŠ€νŠΈ λ„μž… β€” ν”„λ‘œμ νŠΈμ— UI ν…ŒμŠ€νŠΈ 인프라가 μ—†μœΌλ―€λ‘œ λ²”μœ„ λ°–. + +## 3. μ•„ν‚€ν…μ²˜ / μ ‘κ·Ό + +`MainActivity.kt`에 μ‹ κ·œ 헬퍼 ν•¨μˆ˜ `dropdownField()`λ₯Ό μΆ”κ°€ν•œλ‹€. μ‹œκ·Έλ‹ˆμ²˜λŠ” κΈ°μ‘΄ `chipRow()`와 **동일**ν•˜κ²Œ λ§žμΆ˜λ‹€: + +```kotlin +private fun dropdownField( + t: M3Tokens, + options: List, + selected: String, + onPick: (String) -> Unit +): View +``` + +μ‹œκ·Έλ‹ˆμ²˜λ₯Ό λ™μΌν•˜κ²Œ λ§žμΆ”λ©΄ FAMILY ν˜ΈμΆœλΆ€ 2κ³³μ—μ„œ ν•¨μˆ˜ μ΄λ¦„λ§Œ `chipRow` β†’ `dropdownField`둜 κ΅μ²΄ν•˜λ©΄ 되고, cascading λ™μž‘μ„ 담은 κΈ°μ‘΄ λžŒλ‹€λŠ” κ·ΈλŒ€λ‘œ μ „λ‹¬λœλ‹€. + +``` +chipRow() ← RUNTIME / BACKEND / QUANTIZATION 계속 μ‚¬μš© (λ³€κ²½ μ—†μŒ) +dropdownField() ← FAMILY μ „μš© (μ‹ κ·œ) +``` + +이 λ°©μ‹μœΌλ‘œ λ³€κ²½ λ²”μœ„λ₯Ό μ΅œμ†Œν™”ν•˜κ³  κΈ°μ‘΄ μΉ© 둜직과의 결합을 λŠλŠ”λ‹€. + +## 4. `dropdownField()` μ»΄ν¬λ„ŒνŠΈ 섀계 + +### μ™Έν˜• +- M3 토큰을 μ‚¬μš©ν•˜λŠ” outlined ν•„λ“œ. +- μ’ŒμΈ‘μ— ν˜„μž¬ 선택값 ν…μŠ€νŠΈ(`selected`), μš°μΈ‘μ— `β–Ύ` μ•„μ΄μ½˜. +- ν…Œλ‘λ¦¬/배경은 κΈ°μ‘΄ 헬퍼(`strokedSolid`, `solid`, `dp`)와 색상 토큰(`onSurface`, `onSurfaceVar`, `outline` λ“±)을 μž¬μ‚¬μš©ν•΄ μΉ©κ³Ό μ‹œκ°μ  일관성 μœ μ§€. + +### μƒν˜Έμž‘μš© +- ν•„λ“œλ₯Ό νƒ­ν•˜λ©΄ `android.widget.PopupMenu`λ₯Ό ν•„λ“œ View에 μ•΅μ»€μ‹œμΌœ `options` λͺ©λ‘μ„ ν‘œμ‹œ. +- 메뉴 ν•­λͺ© 선택 μ‹œ `onPick(opt)` 호좜 β†’ κΈ°μ‘΄ λžŒλ‹€κ°€ μƒνƒœλ₯Ό κ°±μ‹ ν•˜κ³  `rebuildUi()`λ₯Ό 호좜 β†’ ν•„λ“œ 라벨이 μƒˆ μ„ νƒκ°’μœΌλ‘œ λ‹€μ‹œ κ·Έλ €μ§„λ‹€. + +### 선택 ν‘œμ‹œ +- νŒμ—… λ©”λ‰΄μ—μ„œ ν˜„μž¬ μ„ νƒλœ ν•­λͺ©μ— 체크(βœ“) ν‘œμ‹œ(`MenuItem.setChecked`). + +### λ°©μ–΄ 처리 +- `options`κ°€ λΉ„μ–΄ 있으면 ν•„λ“œλ₯Ό λΉ„ν™œμ„±(νšŒμƒ‰) μƒνƒœλ‘œ ν‘œμ‹œν•˜κ³  νƒ­ λ™μž‘μ„ λ§‰λŠ”λ‹€. + +### PopupMenu 선택 이유 +별도 λ ˆμ΄μ•„μ›ƒ/μ–΄λŒ‘ν„° 없이 View에 μ•΅μ»€λ˜λŠ” λ„€μ΄ν‹°λΈŒ νŒμ—…μ΄λΌ μ½”λ“œλŸ‰μ΄ 적닀. Spinner와 달리 M3 색상 토큰과 μΆ©λŒν•˜λŠ” OS κΈ°λ³Έ μŠ€νƒ€μΌ λ°•μŠ€κ°€ μ—†μ–΄ ν…Œλ§ˆ 일관성을 μœ μ§€ν•˜κΈ° 쉽닀. + +## 5. λ³€κ²½ 지점 + +| 파일 / μœ„μΉ˜ | λ³€κ²½ λ‚΄μš© | +|-------------|-----------| +| `Android/SampleTestAPP/src/main/java/com/example/sampletestapp/MainActivity.kt` (`chipRow()` μ •μ˜λΆ€ ~1588 근처) | `dropdownField()` 헬퍼 μ‹ κ·œ μΆ”κ°€ | +| `MainActivity.kt:565` (Run/OpenAI νƒ­ FAMILY) | `chipRow(...)` β†’ `dropdownField(...)` β€” 전달 λžŒλ‹€ 동일 | +| `MainActivity.kt:796` (Chat νƒ­ FAMILY) | `chipRow(...)` β†’ `dropdownField(...)` β€” 전달 λžŒλ‹€ 동일 | + +`ModelCatalog.kt`λŠ” λ³€κ²½ μ—†μŒ. + +### μ°Έκ³ : ν˜„μž¬ ν˜ΈμΆœλΆ€ (λ³€κ²½ ν›„ ν•¨μˆ˜λͺ…λ§Œ ꡐ체) + +Run/OpenAI νƒ­ (565–571): +```kotlin +body.addView(chipRow(t, ModelCatalog.families(), selFamily) { picked -> + selFamily = picked + selRuntime = ModelCatalog.runtimesFor(selFamily).firstOrNull() ?: selRuntime + selBackend = ModelCatalog.backendsFor(selFamily, selRuntime).firstOrNull() ?: selBackend + modelPathText = defaultModelPathFor(selDescriptor, selectedQuant) ?: "" + rebuildUi(resetModelPath = true) +}) +``` + +Chat νƒ­ (796–802): +```kotlin +modelCard.addView(chipRow(t, ModelCatalog.families(), chatSelFamily) { picked -> + chatSelFamily = picked + chatSelRuntime = ModelCatalog.runtimesFor(chatSelFamily).firstOrNull() ?: chatSelRuntime + chatSelBackend = ModelCatalog.backendsFor(chatSelFamily, chatSelRuntime).firstOrNull() ?: chatSelBackend + clearChatSessionState() + rebuildUi() +}) +``` + +각 호좜의 `chipRow` β†’ `dropdownField` ν•œ λ‹¨μ–΄λ§Œ 바뀐닀. + +## 6. 데이터 흐름 + +``` +μ‚¬μš©μž νƒ­ β†’ PopupMenu ν‘œμ‹œ(options = ModelCatalog.families()) + β†’ ν•­λͺ© 선택 β†’ onPick(picked) + β†’ (κΈ°μ‘΄ λžŒλ‹€) selFamily/chatSelFamily κ°±μ‹  + runtime/backend μž¬κ³„μ‚° + β†’ rebuildUi() β†’ dropdownFieldκ°€ μƒˆ selected κ°’μœΌλ‘œ λ‹€μ‹œ λ Œλ” +``` + +μƒνƒœ λ³€μˆ˜(`selFamily`, `chatSelFamily`)와 cascading μž¬κ³„μ‚°μ€ μ „ν˜€ λ°”λ€Œμ§€ μ•ŠλŠ”λ‹€. λ“œλ‘­λ‹€μš΄μ€ μΉ©κ³Ό λ™μΌν•œ μž…λ ₯/좜λ ₯ 계약을 λ”°λ₯΄λŠ” ν‘œν˜„(presentation) λ ˆμ΄μ–΄ ꡐ체일 뿐이닀. + +## 7. 였λ₯˜ 처리 / μ—£μ§€ μΌ€μ΄μŠ€ + +- **빈 μ˜΅μ…˜ λͺ©λ‘**: λΉ„ν™œμ„± ν•„λ“œλ‘œ ν‘œμ‹œ, νƒ­ λ¬΄λ™μž‘. +- **선택값이 μ˜΅μ…˜μ— μ—†μŒ**: `selected` ν…μŠ€νŠΈλ₯Ό κ·ΈλŒ€λ‘œ ν‘œμ‹œ(κΈ°μ‘΄ μΉ©κ³Ό λ™μΌν•˜κ²Œ κ°•μ œ λ³€κ²½ν•˜μ§€ μ•ŠμŒ). cascading κΈ°λ³Έκ°’ 둜직이 이미 유효 값을 보μž₯. +- **ν…Œλ§ˆ**: 라이트/닀크 μ–‘μͺ½μ—μ„œ 토큰 기반 색상 μ‚¬μš©μœΌλ‘œ μžλ™ λŒ€μ‘. + +## 8. ν…ŒμŠ€νŠΈ / 검증 + +ν”„λ‘œμ νŠΈλŠ” UIλ₯Ό μ½”λ“œλ‘œ 직접 μƒμ„±ν•˜λ©° μžλ™ν™” UI ν…ŒμŠ€νŠΈκ°€ μ—†λ‹€. 검증은 **λΉŒλ“œ + λ””λ°”μ΄μŠ€ μ„€μΉ˜ ν›„ μˆ˜λ™ 확인**으둜 μ§„ν–‰ν•œλ‹€ (λΉŒλ“œ ν™˜κ²½μ€ λ©”λͺ¨λ¦¬ `quickai-android-build-env` μ°Έκ³ ). + +검증 ν•­λͺ©: +1. Run/OpenAI 탭에 FAMILY λ“œλ‘­λ‹€μš΄μ΄ ν‘œμ‹œλ˜κ³  ν˜„μž¬ 선택값이 보인닀. +2. Chat 탭에 FAMILY λ“œλ‘­λ‹€μš΄μ΄ ν‘œμ‹œλ˜κ³  ν˜„μž¬ 선택값이 보인닀. +3. λ“œλ‘­λ‹€μš΄μ—μ„œ λ‹€λ₯Έ family 선택 μ‹œ RUNTIME/BACKEND 칩이 cascading으둜 μž¬κ³„μ‚°λœλ‹€ (νšŒκ·€ μ—†μŒ). +4. RUNTIME/BACKEND/QUANTIZATION은 μ—¬μ „νžˆ 칩으둜 ν‘œμ‹œλœλ‹€. +5. 라이트/닀크 ν…Œλ§ˆμ—μ„œ μ™Έν˜•μ΄ κΉ¨μ§€μ§€ μ•ŠλŠ”λ‹€. + +## 9. λ―Έν•΄κ²° / ν–₯ν›„ 과제 + +- ν–₯ν›„ RUNTIME/BACKEND/QUANTIZATION도 λ“œλ‘­λ‹€μš΄ 톡일을 μ›ν•˜λ©΄ 동일 `dropdownField()` μž¬μ‚¬μš©μœΌλ‘œ ν™•μž₯ κ°€λŠ₯ (이번 λ²”μœ„ λ°–). diff --git a/docs/superpowers/specs/2026-06-05-lfm2-vl-siglip-pluggable-design.md b/docs/superpowers/specs/2026-06-05-lfm2-vl-siglip-pluggable-design.md new file mode 100644 index 00000000..acc3c9f1 --- /dev/null +++ b/docs/superpowers/specs/2026-06-05-lfm2-vl-siglip-pluggable-design.md @@ -0,0 +1,170 @@ +# LFM2-VL(SigLIP + LFM2) β€” Quick.AI pluggable composer 톡합 섀계 λ¬Έμ„œ + +- **λ‚ μ§œ**: 2026-06-05 +- **브랜치**: `v0.4.0` (nntrainer μ„œλΈŒλͺ¨λ“ˆμ€ LFM2-VL 컀밋 `0b52d15` 체크아웃 μƒνƒœ) +- **λŒ€μƒ**: nntrainer μ„œλΈŒλͺ¨λ“ˆ Β· Quick.AI public API(`libquick_dot_ai_api.so`) Β· Android(`QuickDotAI` AAR + `SampleTestAPP`) +- **μƒνƒœ**: 승인됨, κ΅¬ν˜„ κ³„νš λŒ€κΈ° +- **λΆ„λ₯˜**: λ‚΄λΆ€ λ¬Έμ„œ (`docs/superpowers/`λŠ” git-ignore) + +## 1. λ°°κ²½ / 문제 + +nntrainer μ„œλΈŒλͺ¨λ“ˆμ— **LFM2-VL-450M**(이미지+ν…μŠ€νŠΈ β†’ ν…μŠ€νŠΈ) λ©€ν‹°λͺ¨λ‹¬ λͺ¨λΈμ΄ +μΆ”κ°€λ˜μ—ˆλ‹€. κ΅¬μ‘°λŠ” λ‹€μŒ 3개 μ»΄ν¬λ„ŒνŠΈλ‘œ 이루어진 **λͺ¨λ†€λ¦¬μ‹ μ˜€μΌ€μŠ€νŠΈλ ˆμ΄ν„°**λ‹€: + +- `Lfm2VlVisionTransformer` (SigLIP2 ViT) β€” `models/lfm2/lfm2-vl/vision/lfm2_vl_vision_transformer.{h,cpp}`, arch `"Lfm2VlVisionTransformer"`. 비인과/RoPE μ—†μŒ, patch16, NaFlex 지원. +- `Lfm2VlConnector` (pixel-unshuffle Γ—2 + LayerNorm + FCβ†’GELUβ†’FC, 3072β†’2560β†’1024) β€” `models/lfm2/lfm2-vl/lfm2_vl_connector.{h,cpp}`. Transformer μ„œλΈŒν΄λž˜μŠ€κ°€ **μ•„λ‹˜**(독립 MLP 클래슀). +- `Lfm2CausalLM` (hybrid conv/attn LM, hidden 1024) β€” `models/lfm2/lfm2_causallm.{h,cpp}`. `lookupEmbedding()`(FP32/Q4_0/Q6_K), `run_with_embeddings()` 이미 μ˜€λ²„λΌμ΄λ“œ. + +이 셋을 λ¬ΆλŠ” `Lfm2VlForConditionalGeneration`(`models/lfm2/lfm2-vl/lfm2_vl_model.{h,cpp}`, +arch `"Lfm2VlForConditionalGeneration"`)이 `run()` μ•ˆμ—μ„œ ViT 인코딩 β†’ +pixelUnshuffle β†’ connector β†’ `image_token_id=396` μœ„μΉ˜μ— μž„λ² λ”© splice β†’ +`run_with_embeddings()`κΉŒμ§€ λͺ¨λ‘ μˆ˜ν–‰ν•œλ‹€. **λ¬Έμ œλŠ” 이 κ²½λ‘œκ°€ nntrainer +standalone `main.cpp`(391–409ν–‰)의 직접 λΆ„κΈ°μ—μ„œλ§Œ λ™μž‘**ν•˜λ©°, Factory 등둝도 +μ—†κ³  Quick.AI API/μ•±μ—μ„œλŠ” **μ „ν˜€ μ‚¬μš©ν•  수 μ—†λ‹€**λŠ” 점이닀. CPU/FP32 λͺ¨λΈμ΄λ‹€. + +ν•œνŽΈ Quick.AI public APIμ—λŠ” 이미 **generic λ©€ν‹°λͺ¨λ‹¬ composer**κ°€ μžˆλ‹€: +`execute_multimodal()`(`api/quick_dot_ai_api.cpp:2175`)κ°€ +`[vision producer = models[0], LLM consumer = models[1]]` νŽ˜μ–΄λ₯Ό λ°›μ•„, ν…μŠ€νŠΈ +토큰 μŠ€νŠΈλ¦Όμ—μ„œ 이미지 마컀λ₯Ό μ°Ύμ•„ vision μž„λ² λ”©μ„ spliceν•˜κ³  +`llm->run_with_embeddings()`둜 μƒμ„±ν•œλ‹€. κ·ΈλŸ¬λ‚˜ 이 composer와 λͺ¨λ“  +`runMultimodal*` μ§„μž…μ μ΄ **`#ifdef ENABLE_QNN`둜 κ°€λ“œ**λ˜μ–΄ μžˆμ–΄(V-JEPAΒ·gauss-vision은 +NPU) CPU μ „μš© LFM2-VLμ—λŠ” κ²½λ‘œκ°€ 열리지 μ•ŠλŠ”λ‹€. λ˜ν•œ composerκ°€ μ°ΎλŠ” λ§ˆμ»€λŠ” +ν•˜λ“œμ½”λ”©λœ `<|image|>`(`:2188`)인데 LFM2λŠ” `image_token_id=396`(+start 498, +end 499)을 μ“΄λ‹€. + +μ•±(`SampleTestAPP`)은 μΉ΄νƒˆλ‘œκ·Έ 기반으둜 λ™μž‘ν•œλ‹€: `nativeQueryCatalog()` JSON β†’ +`ModelCatalog.kt` β†’ `QDA_CAP_MULTIMODAL` λΉ„νŠΈκ°€ 있으면 이미지 피컀 + +`runMultimodal*Streaming` JNI 경둜 μžλ™ ν™œμ„±. 단 μ „μ²˜λ¦¬λŠ” LLaVA-Next식 +(`LlavaNextImageProcessor.kt`, 512Β² 크둭, mean/std 0.5)이라 SigLIP2-NaFlex +(256Β², patch16) 규격과 λ§žμ§€ μ•ŠλŠ”λ‹€. 또 두 λͺ¨λΈμ„ νŽ˜μ–΄λ‘œ λ‘œλ“œν•˜λŠ” +`loadMultimodalHandleByName(emb_id, llm_id)`λŠ” **JNI에 λ…ΈμΆœλ˜μ–΄ μžˆμ§€ μ•Šλ‹€**. + +## 2. λͺ©ν‘œ / λ²”μœ„ + +### 포함 +- λͺ¨λ†€λ¦¬μ‹ LFM2-VL을 **두 개의 독립 Factory λͺ¨λΈ**둜 λΆ„ν•΄: + (A) `Lfm2VlVisionEncoder` = ViT + connectorλ₯Ό λ¬Άμ–΄ `run_image()` 좜λ ₯이 **LM + μž„λ² λ”© 곡간(1024-dim)** 으둜 λ‚˜μ˜€λŠ” vision 인코더, (B) `Lfm2CausalLM` 단독 LM. +- κΈ°μ‘΄ generic `execute_multimodal` composerκ°€ 이 νŽ˜μ–΄λ₯Ό κ΅¬λ™ν•˜λ„λ‘ + **CPU(native) λΉŒλ“œμ—μ„œ 컴파일/λ™μž‘**ν•˜κ²Œ κ°€λ“œ 개방. +- 이미지 마컀 μ •ν•©: composerκ°€ ν•˜λ“œμ½”λ”© `<|image|>` λŒ€μ‹  **LLM ν† ν¬λ‚˜μ΄μ €/λ””μŠ€ν¬λ¦½ν„°μ—μ„œ + image token idλ₯Ό 쑰회**(LFM2=396). +- 앱에 **vision encoder + LLM λ―ΉμŠ€μ•€λ§€μΉ˜ picker** UI μΆ”κ°€, `loadMultimodalHandleByName` + JNI λ…ΈμΆœ. +- **SigLIP2-NaFlex μ „μ²˜λ¦¬** Kotlin ν”„λ‘œμ„Έμ„œ μΆ”κ°€(MVP: κ³ μ • 256Β² square). +- 검증: **ν—€λ“œλ¦¬μŠ€ api ν…ŒμŠ€νŠΈ β†’ APK on-device** 단계적. + +### μ œμ™Έ (이번 λ²”μœ„ λ°–) +- λͺ¨λ†€λ¦¬μ‹ `Lfm2VlForConditionalGeneration` 자체 μ‚­μ œ/λ¦¬νŒ©ν„° β€” **레퍼런슀둜 + 보쑴**(ν—€λ“œλ¦¬μŠ€ μ •λ‹΅ λŒ€μ‘°μš©). +- full NaFlex 동적 해상도 + μœ„μΉ˜ μž„λ² λ”© 보간 β€” **후속 단계**(MVPλŠ” κ³ μ • 해상도). +- QNN/NPU λ²„μ „μ˜ LFM2-VL β€” ν˜„μž¬ λͺ¨λΈμ€ CPU/FP32. NPU μ–‘μžν™”λŠ” 별도 과제. +- 닀쀑 이미지(multi-image) LFM2-VL β€” 단일 이미지 μš°μ„ . + +## 3. μ•„ν‚€ν…μ²˜ / μ ‘κ·Ό + +핡심: **"ν•œ 덩어리 LFM2-VL을 vision(눈+λ³€ν™˜κΈ°) Β· LLM(λ‘λ‡Œ) 두 쑰각으둜 λΆ„ν•΄ β†’ +κΈ°μ‘΄ pluggable composer둜 νŽ˜μ–΄λ§ β†’ μ•±μ—μ„œ 골라 끼움"**. APIλŠ” gauss-agnostic을 +μœ μ§€ν•˜λ©°, LFM2-VL은 **곡개 nntrainer μ½”λ“œ**μ΄λ―€λ‘œ λž˜νΌλŠ” nntrainer μ„œλΈŒλͺ¨λ“ˆμ— +두고 Quick.AIλŠ” **곡개 descriptor만** μΆ”κ°€ν•œλ‹€. + +``` +[App] vision picker (SigLIP) ─┐ + LLM picker (LFM2) ─┴─► loadMultimodalHandleByName(emb_id, llm_id) + β”‚ +[API] composer (CPU 개방) ─────────── models[0]=Lfm2VlVisionEncoder + execute_multimodal β”‚ models[1]=Lfm2CausalLM + β”œβ”€ run_vision_encoder ─────── ↳ run_image β†’ ViTβ†’unshuffleβ†’connector β†’ 1024-dim emb + β”œβ”€ 이미지 토큰 id 쑰회(396) ── ↳ image marker μœ„μΉ˜ splice + └─ llm->run_with_embeddings β”˜ ↳ LFM2 λ””μ½”λ”© +``` + +## 4. λ³€κ²½ 지점 + +### 4.1 nntrainer β€” vision 인코더 래퍼 + LM Factory 등둝 + +- **μ‹ κ·œ `Lfm2VlVisionEncoder`** (`models/lfm2/lfm2-vl/`): 내뢀에 + `Lfm2VlVisionTransformer` + `Lfm2VlConnector`λ₯Ό μ†Œμœ . `run_image(pixels, …)` + μ˜€λ²„λΌμ΄λ“œ β†’ ViT.run() β†’ getLastFeatures() β†’ pixelUnshuffle() β†’ + connector.forward() β†’ **1024-dim `multimodal_pointer` λ°˜ν™˜**. `get_embedding_info()`둜 + LFM2 LM의 μž„λ² λ”© μ–‘μžν™” μŠ€μΌ€μΌ/μ˜€ν”„μ…‹κ³Ό μ •ν•©. κ²°μ • 사항: **κΈ°μ‘΄ + `Lfm2VlVisionTransformer`/`Lfm2VlConnector`λ₯Ό 직접 μˆ˜μ •ν•˜μ§€ μ•Šκ³  μƒˆ λž˜νΌμ—μ„œ + μ‘°ν•©**(κΈ°μ‘΄ λͺ¨λ†€λ¦¬μ‹ 경둜 무손상). Factory 등둝 arch `"Lfm2VlVisionEncoder"`, + capability `QDA_CAP_VISION_ENCODER`. +- **`Lfm2CausalLM` Factory 등둝 μΆ”κ°€**: arch `"Lfm2ForCausalLM"`(λ˜λŠ” config의 + μ‹€μ œ architectures λ¬Έμžμ—΄). 단독 ν…μŠ€νŠΈ LLMμœΌλ‘œλ„ λ‘œλ“œ κ°€λŠ₯. `lookupEmbedding`/ + `run_with_embeddings`λŠ” κ΅¬ν˜„ μ™„λ£Œ β†’ μΆ”κ°€ μž‘μ—… μ—†μŒ. +- μ˜€μΌ€μŠ€νŠΈλ ˆμ΄ν„°(`Lfm2VlForConditionalGeneration`) + `main.cpp` λΆ„κΈ°λŠ” 무변경. + +### 4.2 Quick.AI API β€” composer CPU 개방 + 마컀 μ •ν•© + descriptor + +| μœ„μΉ˜ | λ³€κ²½ | +|------|------| +| `api/quick_dot_ai_api.cpp` `execute_multimodal`/`run_vision_encoder`/`runMultimodal*` | `#ifdef ENABLE_QNN` κ°€λ“œμ—μ„œ 뢄리 β†’ native(CPU) λΉŒλ“œμ—μ„œλ„ ν™œμ„± | +| `api/quick_dot_ai_api.cpp:2188` 이미지 마컀 | ν•˜λ“œμ½”λ”© `<|image|>` β†’ **ν† ν¬λ‚˜μ΄μ €/λ””μŠ€ν¬λ¦½ν„°μ—μ„œ image token id 쑰회**(LFM2=396); start/end(498/499) λ§ˆμ»€λŠ” ν”„λ‘¬ν”„νŠΈ ν…œν”Œλ¦Ώμ—μ„œ 처리 | +| splice 토큰 개수 | `connector.outTokens()`(= n_patches / rΒ²)만큼 슬둯, vision 좜λ ₯ 길이 κ·ΈλŒ€λ‘œ μ‚¬μš© | +| `api/model_descriptors_public.cpp` | 곡개 descriptor 2개 μΆ”κ°€: `siglip2-vl-encoder`(VISION_ENCODER), `lfm2-450m`(LM, MULTIMODAL 짝). config_name = λ””λ°”μ΄μŠ€ dirκ³Ό 일치(μ†Œλ¬Έμž κ·œμΉ™). | + +`ModelDescriptor`에 `image_token`(λ¬Έμžμ—΄ λ˜λŠ” id) ν•„λ“œλ₯Ό μΆ”κ°€ν• μ§€, ν† ν¬λ‚˜μ΄μ € +special tokenμ—μ„œ μœ λ„ν• μ§€λŠ” κ΅¬ν˜„ κ³„νšμ—μ„œ ν™•μ •(ν† ν¬λ‚˜μ΄μ € μœ λ„κ°€ descriptor +μŠ€ν‚€λ§ˆ 변경을 ν”Όν•΄ μš°μ„ ). + +### 4.3 Android AAR / μ•± β€” λ―ΉμŠ€μ•€λ§€μΉ˜ picker + NaFlex μ „μ²˜λ¦¬ + +- **JNI λ…ΈμΆœ** (`Android/QuickDotAI/.../quickai_jni.cpp` + `NativeCausalLm.kt`): + `loadMultimodalHandleByName(emb_id, llm_id)` external fun μΆ”κ°€(ν˜„μž¬ λ―Έλ…ΈμΆœ). +- **λ―ΉμŠ€μ•€λ§€μΉ˜ UI** (`SampleTestAPP MainActivity.kt`): μΉ΄νƒˆλ‘œκ·Έμ—μ„œ capability둜 + ν•„ν„° β€” VISION_ENCODER λͺ©λ‘μ„ vision λ“œλ‘­λ‹€μš΄, 생성 κ°€λŠ₯ LLM λͺ©λ‘μ„ LLM + λ“œλ‘­λ‹€μš΄μœΌλ‘œ λ…ΈμΆœ. 두 κ°’ 선택 μ‹œ `loadMultimodalHandleByName`둜 νŽ˜μ–΄ λ‘œλ“œ. + `ModelCatalog.kt`에 `visionEncoders()` / `pairableLlms()` 헬퍼 μΆ”κ°€. +- **μ‹ κ·œ `SigLipNaFlexImageProcessor.kt`**: SigLIP2 규격(256Γ—256, patch16, + mean/std 0.5, CHW). MVPλŠ” **κ³ μ • 256Β² square λ¦¬μ‚¬μ΄μ¦ˆ**. λͺ¨λΈ id둜 + `LlavaNextImageProcessor`와 λΆ„κΈ°. `MultimodalInput`(pixelValues/numPatches/ + dims)에 맞좰 좜λ ₯. + +## 5. 데이터 흐름 + +``` +[λ””λ°”μ΄μŠ€] /sdcard/Download/aistudio-mobile/models// (ViT+connector κ°€μ€‘μΉ˜) + // (LFM2 LM κ°€μ€‘μΉ˜) +[μ•±] 이미지 ν”½ β†’ SigLipNaFlexImageProcessor β†’ pixelValues(256Β²,CHW) + numPatches + vision=siglip2-vl-encoder, llm=lfm2-450m 선택 + └─ loadMultimodalHandleByName(emb,llm) β†’ ν•œ handle (models[0]=ViT래퍼, models[1]=LM) + └─ runMultimodalHandleWithMessagesStreaming(...) + └─ [API] run_vision_encoder β†’ 1024-dim emb + β†’ 토큰화 ν›„ image token(396) μœ„μΉ˜μ— splice + β†’ llm->run_with_embeddings β†’ 슀트리밍 토큰 콜백 +``` + +## 6. 단계적 검증 + +1. **ν—€λ“œλ¦¬μŠ€ api ν…ŒμŠ€νŠΈ** (`quick_dot_ai_test`, λΉŒλ“œ ν™˜κ²½ λ©”λͺ¨λ¦¬ + `quickai-android-build-env` μ°Έκ³ ): vision 인코더 + LFM2 LM νŽ˜μ–΄ λ‘œλ“œ, + **μ „μ²˜λ¦¬λœ 이미지 ν…μ„œ + ν”„λ‘¬ν”„νŠΈ** μž…λ ₯ β†’ 좜λ ₯이 nntrainer λͺ¨λ†€λ¦¬μ‹ + `Lfm2VlForConditionalGeneration` 레퍼런슀(동일 이미지/ν”„λ‘¬ν”„νŠΈ)와 **토큰 λ‹¨μœ„ + 일치** 확인. connector 좜λ ₯의 μž„λ² λ”© μ–‘μžν™” 정합이 1μ°¨ 게이트. +2. **APK on-device** (S26 Ultra λ“±, 절차 λ©”λͺ¨λ¦¬ `gauss-pluggable-bringup`의 APK + verify μ°Έκ³ ): λ―ΉμŠ€μ•€λ§€μΉ˜λ‘œ νŽ˜μ–΄ λ‘œλ“œ β†’ Chat/OpenAI νƒ­μ—μ„œ 이미지+질문 β†’ + 이미지λ₯Ό μ‹€μ œλ‘œ μΈμ§€ν•œ **μΌκ΄€λœ 응닡** 확인(gauss-vision의 "이미지 μ•ˆ λ³΄μž„" + νšŒκ·€ μ—¬λΆ€ 점검). + +## 7. μ£Όμš” 리슀크 + +- **μž„λ² λ”© 곡간 μ •ν•©**: connector 좜λ ₯이 LFM2 LM μž„λ² λ”©μ˜ μ–‘μžν™” μŠ€μΌ€μΌ/μ˜€ν”„μ…‹κ³Ό + μ •ν™•νžˆ λ§žμ§€ μ•ŠμœΌλ©΄ 의미 손싀(gauss-vision "이미지 μ•ˆ λ³΄μž„"λ₯˜). ν—€λ“œλ¦¬μŠ€ 1μ°¨ + 게이트둜 μ‘°κΈ° 차단. +- **이미지 마컀 의미둠**: 396 μœ„μΉ˜ splice만으둜 λΆ€μ‘±ν•˜κ³  start/end(498/499) λ˜λŠ” + νŠΉμ • ν”„λ‘¬ν”„νŠΈ ν…œν”Œλ¦Ώμ΄ ν•„μš”ν•  수 있음 β€” λͺ¨λ†€λ¦¬μ‹ `run()`의 splice 둜직 + (lfm2_vl_model.cpp:238–254)을 μ •λ‹΅μœΌλ‘œ λŒ€μ‘°. +- **NaFlex μ „μ²˜λ¦¬ 정밀도**: κ³ μ • 256Β² MVPλŠ” 비정사각 μ΄λ―Έμ§€μ—μ„œ 정확도 ν•˜λ½ κ°€λŠ₯ β€” + 후속 동적 해상도 λ‹¨κ³„λ‘œ 뢄리. +- **CPU λ©€ν‹°λͺ¨λ‹¬ κ°€λ“œ 개방의 λΆ€μž‘μš©**: κΈ°μ‘΄ QNN 경둜(V-JEPA/gauss-vision) + λ¬΄νšŒκ·€ 확인 ν•„μš”. + +## 8. λ―Έν•΄κ²° / ν–₯ν›„ 과제 + +- full NaFlex 동적 해상도 + μœ„μΉ˜ μž„λ² λ”© 보간(nntrainer `naflex_preprocess.py`/ + `naflex_interp` μœ λ‹›ν…ŒμŠ€νŠΈ μ°Έκ³ ). +- LFM2-VL 닀쀑 이미지 지원. +- NPU/QNN μ–‘μžν™” 버전. +- `ModelDescriptor`에 image_token/νŽ˜μ–΄ 힌트 μŠ€ν‚€λ§ˆλ₯Ό 정식 μΆ”κ°€ν• μ§€ μ—¬λΆ€. diff --git a/docs/videos/GPT_OSS_20B_Demo.gif b/docs/videos/GPT_OSS_20B_Demo.gif deleted file mode 100644 index 88e68a38..00000000 Binary files a/docs/videos/GPT_OSS_20B_Demo.gif and /dev/null differ diff --git a/docs/videos/Qwen_30B_Demo.gif b/docs/videos/Qwen_30B_Demo.gif deleted file mode 100644 index bdd133af..00000000 Binary files a/docs/videos/Qwen_30B_Demo.gif and /dev/null differ diff --git a/docs/videos/moe-full.gif b/docs/videos/moe-full.gif deleted file mode 100644 index 6c5ed2db..00000000 Binary files a/docs/videos/moe-full.gif and /dev/null differ diff --git a/docs/videos/moe-on-the-fly.gif b/docs/videos/moe-on-the-fly.gif deleted file mode 100644 index 0b85d531..00000000 Binary files a/docs/videos/moe-on-the-fly.gif and /dev/null differ diff --git a/factory.h b/factory.h deleted file mode 100644 index 1c8499b5..00000000 --- a/factory.h +++ /dev/null @@ -1,62 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -/** - * Copyright (C) 2025 Eunju Yang - * - * @file causal_lm.h - * @date 22 July 2025 - * @see https://github.com/nntrainer/nntrainer - * @author Eunju Yang - * @bug No known bugs except for NYI items - * @brief CausalLM Factory to support registration and creation of various - * CausalLM models - */ - -#ifndef __CAUSALLM_FACTORY_H__ -#define __CAUSALLM_FACTORY_H__ - -#include -#include -#include - -namespace quick_dot_ai { - -/** - * @brief Factory class - */ -class Factory { -public: - using Creator = - std::function(json &, json &, json &)>; - - static Factory &Instance() { - static Factory factory; - return factory; - } - - void registerModel(const std::string &key, Creator creator) { - creators[key] = creator; - } - - std::unique_ptr create(const std::string &key, json &cfg, - json &generation_cfg, - json &nntr_cfg) const { - auto it = creators.find(key); - if (it != creators.end()) { - return (it->second)(cfg, generation_cfg, nntr_cfg); - } - return nullptr; - } - - void printRegistered(std::ostream &os) const { - for (const auto &pair : creators) { - os << "\n\t" << pair.first; - } - } - -private: - std::unordered_map creators; -}; - -} // namespace quick_dot_ai - -#endif diff --git a/huggingface_tokenizer.cpp b/huggingface_tokenizer.cpp deleted file mode 100644 index 07173684..00000000 --- a/huggingface_tokenizer.cpp +++ /dev/null @@ -1,132 +0,0 @@ - -/*! - * Copyright (c) 2023 by Contributors - * \file huggingface_tokenizer.cc - * \brief Huggingface tokenizer - */ -#include -#include - -#include - -namespace tokenizers { -/*! - * \brief A simple c++ header of tokenizer via C API. - */ -class HFTokenizer : public Tokenizer { -public: - explicit HFTokenizer(TokenizerHandle handle) : handle_(handle) { -#ifdef COMPILE_WASM_RUNTIME - setenv("TOKENIZERS_PARALLELISM", "false", true); -#endif - } - - HFTokenizer(const HFTokenizer &) = delete; - HFTokenizer(HFTokenizer &&other) { std::swap(other.handle_, handle_); } - - ~HFTokenizer() { - if (handle_ != nullptr) { - tokenizers_free(handle_); - } - } - - // use i32 to be consistent with sentencepiece - std::vector Encode(const std::string &text, - bool add_special_tokens) { - TokenizerEncodeResult result; - tokenizers_encode(handle_, text.data(), text.length(), - static_cast(add_special_tokens), &result); - std::vector ret(result.token_ids, result.token_ids + result.len); - tokenizers_free_encode_results(&result, 1); - return ret; - } - - // use i32 to be consistent with sentencepiece - std::vector Encode(const std::string &text) final { - return Encode(text, false); - } - - std::vector> - EncodeBatch(const std::vector &texts, bool add_special_tokens) { - std::vector texts_raw; - std::vector seq_lens; - size_t num_seqs = texts.size(); - texts_raw.reserve(num_seqs); - seq_lens.reserve(num_seqs); - for (const auto &text : texts) { - texts_raw.push_back(text.data()); - seq_lens.push_back(text.length()); - } - std::vector results(num_seqs); - tokenizers_encode_batch(handle_, texts_raw.data(), seq_lens.data(), - texts.size(), static_cast(add_special_tokens), - results.data()); - std::vector> ret; - ret.reserve(texts.size()); - for (size_t i = 0; i < texts.size(); ++i) { - ret.push_back(std::vector( - results[i].token_ids, results[i].token_ids + results[i].len)); - } - tokenizers_free_encode_results(results.data(), texts.size()); - return ret; - } - - std::vector> - EncodeBatch(const std::vector &texts) final { - return EncodeBatch(texts, false); - } - - // use i32 to be consistent with sentencepiece - std::string Decode(const std::vector &ids, - bool skip_special_tokens) { - tokenizers_decode(handle_, reinterpret_cast(ids.data()), - ids.size(), static_cast(skip_special_tokens)); - const char *data; - size_t len; - tokenizers_get_decode_str(handle_, &data, &len); - return std::string(data, len); - } - - std::string Decode(const std::vector &ids) final { - return Decode(ids, false); - } - - size_t GetVocabSize() final { - size_t size; - tokenizers_get_vocab_size(handle_, &size); - assert(size > 0); - return size; - } - - std::string IdToToken(int32_t id) final { - const char *data; - size_t len; - tokenizers_id_to_token(handle_, static_cast(id), &data, &len); - return std::string(data, len); - } - - int32_t TokenToId(const std::string &token) final { - int32_t id; - tokenizers_token_to_id(handle_, token.data(), token.length(), &id); - return id; - } - -private: - // internal handle - TokenizerHandle handle_{nullptr}; -}; - -std::unique_ptr Tokenizer::FromBlobJSON(const std::string &json) { - return std::make_unique( - tokenizers_new_from_str(json.data(), json.length())); -} - -std::unique_ptr -Tokenizer::FromBlobByteLevelBPE(const std::string &vocab, - const std::string &merges, - const std::string &added_tokens) { - return std::make_unique(byte_level_bpe_tokenizers_new_from_str( - vocab.data(), vocab.length(), merges.data(), merges.length(), - added_tokens.data(), added_tokens.length())); -} -} // namespace tokenizers diff --git a/install_android.sh b/install_android.sh index defc2f10..446e4f24 100755 --- a/install_android.sh +++ b/install_android.sh @@ -1,274 +1,115 @@ #!/bin/bash - -# Installation script for CausalLM Android application +# Unified installation script for Quick-Dot-AI on Android device +# Installs all built artifacts from builddir_android/ to /data/local/tmp/Quick.AI/ set -e -# Configuration -INSTALL_DIR="/data/local/tmp/quick_dot_ai" +INSTALL_DIR="/data/local/tmp/Quick.AI" MODEL_DIR="$INSTALL_DIR/models" -# Set SCRIPT_DIR SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" - -# Color codes -RED='\033[0;31m' -GREEN='\033[0;32m' -YELLOW='\033[1;33m' -CYAN='\033[0;36m' -BLUE='\033[0;34m' -NC='\033[0m' # No Color - -log_info() { - echo -e "${BLUE}[INFO]${NC} $1" -} - -log_success() { - echo -e "${GREEN}[SUCCESS]${NC} $1" -} - -log_warning() { - echo -e "${YELLOW}[WARNING]${NC} $1" -} - -log_error() { - echo -e "${RED}[ERROR]${NC} $1" -} - -log_header() { - echo -e "\n${CYAN}========================================${NC}" - echo -e "${CYAN} $1 ${NC}" - echo -e "${CYAN}========================================${NC}" -} - -log_step() { - echo -e "\n${YELLOW}[Step $1]${NC} $2" - echo -e "${YELLOW}----------------------------------------${NC}" -} - -log_header "Install CausalLM to Android Device" -log_info "INSTALL_DIR: $INSTALL_DIR" -log_info "SCRIPT_DIR: $SCRIPT_DIR" - -# Check if device is connected -log_step "1/3" "Check device connection" -if ! adb devices | grep -q "device$"; then - log_error "No Android device connected. Please connect a device and try again." +BUILD_DIR="$SCRIPT_DIR/builddir_android" +NNTRAINER_ROOT="$SCRIPT_DIR/nntrainer" +NNTRAINER_ANDROID="$NNTRAINER_ROOT/builddir/android_build_result/lib/arm64-v8a" + +# ── Validate ──────────────────────────────────────────────────────────── +if [ ! -d "$BUILD_DIR" ]; then + echo "Error: Build directory not found: $BUILD_DIR" + echo "Run './build.sh --platform=android' first." exit 1 fi -DEVICE_ID=$(adb devices | grep "device$" | head -1 | cut -f1) -log_success "Device connected: $DEVICE_ID" - -# Check if all required files exist -log_step "2/3" "Check build artifacts" -REQUIRED_FILES=( - "$SCRIPT_DIR/jni/libs/arm64-v8a/quick_dot_ai" - "$SCRIPT_DIR/jni/libs/arm64-v8a/libquick_dot_ai_core.so" - "$SCRIPT_DIR/jni/libs/arm64-v8a/quick_dot_ai_quantize" -) - -# Optional dependency files (might not be in libs/arm64-v8a depending on build) -DEP_FILES=( - "$SCRIPT_DIR/jni/libs/arm64-v8a/libnntrainer.so" - "$SCRIPT_DIR/jni/libs/arm64-v8a/libccapi-nntrainer.so" - "$SCRIPT_DIR/jni/libs/arm64-v8a/libc++_shared.so" -) - -# Check main build artifacts -ALL_FOUND=true -for file in "${REQUIRED_FILES[@]}"; do - if [ -f "$file" ]; then - size=$(du -h "$file" | cut -f1) - echo -e " ${GREEN}[OK]${NC} $(basename $file) ($size)" - else - echo -e " ${RED}[MISSING]${NC} $file" - ALL_FOUND=false - fi -done - -if [ "$ALL_FOUND" = false ]; then - log_error "Some required files are missing" - log_info "Please run: ./build_android.sh" +if ! adb devices | grep -q "device$"; then + echo "Error: No Android device connected." exit 1 fi -# Check dependencies with fallback to obj/local/arm64-v8a -for file in "${DEP_FILES[@]}"; do - filename=$(basename "$file") +INSTALL_LIBS_DIR="$SCRIPT_DIR/install_libs" - # Special handling for libc++_shared.so (Try copy from NDK) - if [[ "$filename" == "libc++_shared.so" ]] && [ ! -f "$file" ]; then - if [ -n "$ANDROID_NDK" ]; then - # Attempt to find it in typical NDK locations for aarch64 - NDK_LIBCXX=$(find "$ANDROID_NDK" -name "libc++_shared.so" 2>/dev/null | grep "aarch64" | head -n 1) +echo "=== Installing Quick-Dot-AI to Android device ===" +echo "Install dir: $INSTALL_DIR" +echo "" - if [ -n "$NDK_LIBCXX" ] && [ -f "$NDK_LIBCXX" ]; then - log_warning "libc++_shared.so not found in build dir, copying from NDK..." - cp "$NDK_LIBCXX" "$file" - # Fall through to standard check to confirm copy success - fi - fi - fi +mkdir -p "$INSTALL_LIBS_DIR" - if [ -f "$file" ]; then - size=$(du -h "$file" | cut -f1) - echo -e " ${GREEN}[OK]${NC} $filename ($size)" - else - # Try to find in obj directory - obj_path="$SCRIPT_DIR/jni/obj/local/arm64-v8a/$filename" - if [ -f "$obj_path" ]; then - log_warning "$filename found in obj, copying to libs..." - cp "$obj_path" "$file" - size=$(du -h "$file" | cut -f1) - echo -e " ${GREEN}[OK]${NC} $filename ($size) (Copied)" - else - echo -e " ${RED}[MISSING]${NC} $filename" - log_error "Required dependency not found" - exit 1 - fi - fi -done +adb shell "mkdir -p $INSTALL_DIR $MODEL_DIR" -log_success "All required build artifacts found" +# ── Push nntrainer runtime libraries ──────────────────────────────────── +echo "Pushing nntrainer libraries..." +adb push "$NNTRAINER_ANDROID/libnntrainer.so" $INSTALL_DIR/ && cp "$NNTRAINER_ANDROID/libnntrainer.so" "$INSTALL_LIBS_DIR/" +adb push "$NNTRAINER_ANDROID/libccapi-nntrainer.so" $INSTALL_DIR/ && cp "$NNTRAINER_ANDROID/libccapi-nntrainer.so" "$INSTALL_LIBS_DIR/" -# Check optional files (API and test app) -OPTIONAL_FILES=( - "$SCRIPT_DIR/jni/libs/arm64-v8a/libquick_dot_ai_api.so" - "$SCRIPT_DIR/jni/libs/arm64-v8a/test_api" -) +# ── Push built artifacts ──────────────────────────────────────────────── +echo "Pushing built artifacts..." -for file in "${OPTIONAL_FILES[@]}"; do - if [ -f "$file" ]; then - size=$(du -h "$file" | cut -f1) - echo -e " ${GREEN}[OK]${NC} $(basename $file) ($size) (Optional)" - fi +# src targets +for f in libcausallm.so libquick_dot_ai.so; do + [ -f "$BUILD_DIR/src/$f" ] && adb push "$BUILD_DIR/src/$f" $INSTALL_DIR/ && cp "$BUILD_DIR/src/$f" "$INSTALL_LIBS_DIR/" done -# Create directories on device -log_step "3/3" "Push files to device" -log_info "Creating directories on device..." -adb shell "mkdir -p $INSTALL_DIR" -adb shell "mkdir -p $MODEL_DIR" -log_success "Directories created" - -# Push executables -log_info "Pushing executables..." -adb push "$SCRIPT_DIR/jni/libs/arm64-v8a/quick_dot_ai" "$INSTALL_DIR/" 2>&1 | tail -1 -adb shell "chmod 755 $INSTALL_DIR/quick_dot_ai" -log_success "quick_dot_ai pushed" - -# Push optional test_api if exists -if [ -f "$SCRIPT_DIR/jni/libs/arm64-v8a/test_api" ]; then - log_info "Pushing test_api..." - adb push "$SCRIPT_DIR/jni/libs/arm64-v8a/test_api" "$INSTALL_DIR/" 2>&1 | tail -1 - adb shell "chmod 755 $INSTALL_DIR/test_api" - log_success "test_api pushed" +if [ -f "$BUILD_DIR/src/quick_dot_ai" ]; then + adb push "$BUILD_DIR/src/quick_dot_ai" $INSTALL_DIR/ && cp "$BUILD_DIR/src/quick_dot_ai" "$INSTALL_LIBS_DIR/" + adb shell "chmod 755 $INSTALL_DIR/quick_dot_ai" fi +# api target +[ -f "$BUILD_DIR/api/libquick_dot_ai_api.so" ] && \ + adb push "$BUILD_DIR/api/libquick_dot_ai_api.so" $INSTALL_DIR/ && cp "$BUILD_DIR/api/libquick_dot_ai_api.so" "$INSTALL_LIBS_DIR/" -log_info "Pushing quick_dot_ai_quantize..." -adb push "$SCRIPT_DIR/jni/libs/arm64-v8a/quick_dot_ai_quantize" "$INSTALL_DIR/" 2>&1 | tail -1 -adb shell "chmod 755 $INSTALL_DIR/quick_dot_ai_quantize" -log_success "quick_dot_ai_quantize pushed" - -# Push shared libraries -log_info "Pushing shared libraries..." -log_info " [1/6] libquick_dot_ai_core.so (CausalLM Core library)..." -adb push "$SCRIPT_DIR/jni/libs/arm64-v8a/libquick_dot_ai_core.so" "$INSTALL_DIR/" 2>&1 | tail -1 - -log_info " [2/6] libnntrainer.so (nntrainer library)..." -adb push "$SCRIPT_DIR/jni/libs/arm64-v8a/libnntrainer.so" "$INSTALL_DIR/" 2>&1 | tail -1 - -log_info " [3/6] libccapi-nntrainer.so (nntrainer C/C API)..." -adb push "$SCRIPT_DIR/jni/libs/arm64-v8a/libccapi-nntrainer.so" "$INSTALL_DIR/" 2>&1 | tail -1 +# api-test target +if [ -f "$BUILD_DIR/api-app/quick_dot_ai_test" ]; then + adb push "$BUILD_DIR/api-app/quick_dot_ai_test" $INSTALL_DIR/ && cp "$BUILD_DIR/api-app/quick_dot_ai_test" "$INSTALL_LIBS_DIR/" + adb shell "chmod 755 $INSTALL_DIR/quick_dot_ai_test" +fi -log_info " [4/6] libc++_shared.so (C++ runtime)..." -adb push "$SCRIPT_DIR/jni/libs/arm64-v8a/libc++_shared.so" "$INSTALL_DIR/" 2>&1 | tail -1 +[ -f "$BUILD_DIR/qnn/libqnn_context.so" ] && \ + adb push "$BUILD_DIR/qnn/libqnn_context.so" $INSTALL_DIR/ && cp "$BUILD_DIR/qnn/libqnn_context.so" "$INSTALL_LIBS_DIR/" -log_info " [5/6] libomp.so (OpenMP runtime)..." -if [ -f "$SCRIPT_DIR/jni/libs/arm64-v8a/libomp.so" ]; then - adb push "$SCRIPT_DIR/jni/libs/arm64-v8a/libomp.so" "$INSTALL_DIR/" 2>&1 | tail -1 -else - log_warning "libomp.so not found (skipping)" +# ── Push libc++_shared.so from NDK ────────────────────────────────────── +if [ -n "$ANDROID_NDK" ]; then + LIBCXX=$(find "$ANDROID_NDK" -name "libc++_shared.so" -path "*/aarch64*" 2>/dev/null | head -1) + if [ -n "$LIBCXX" ]; then + echo "Pushing libc++_shared.so..." + adb push "$LIBCXX" $INSTALL_DIR/ && cp "$LIBCXX" "$INSTALL_LIBS_DIR/" + fi fi -log_info " [6/6] libquick_dot_ai_api.so (CausalLM API library)..." -if [ -f "$SCRIPT_DIR/jni/libs/arm64-v8a/libquick_dot_ai_api.so" ]; then - adb push "$SCRIPT_DIR/jni/libs/arm64-v8a/libquick_dot_ai_api.so" "$INSTALL_DIR/" 2>&1 | tail -1 -else - log_warning "libquick_dot_ai_api.so not found (Optional, skipping)" +# ── Push model config files (res/) ────────────────────────────────────── +RES_DIR="$SCRIPT_DIR/src/res" +if [ -d "$RES_DIR" ]; then + echo "Pushing model configs..." + for model_dir in "$RES_DIR"/*/; do + model_name=$(basename "$model_dir") + adb shell "mkdir -p $MODEL_DIR/$model_name" + adb push "$model_dir." "$MODEL_DIR/$model_name/" + mkdir -p "$INSTALL_LIBS_DIR/models/$model_name" + cp -r "$model_dir." "$INSTALL_LIBS_DIR/models/$model_name/" + done fi -log_success "All libraries pushed" - -# Create run script on device -log_info "Creating run script on device..." -adb shell "cat > $INSTALL_DIR/run_causallm.sh << 'EOF' +# ── Create run scripts on device ──────────────────────────────────────── +adb shell "cat > $INSTALL_DIR/run.sh << 'EOF' #!/system/bin/sh -export LD_LIBRARY_PATH=$INSTALL_DIR:\$LD_LIBRARY_PATH -cd $INSTALL_DIR +export LD_LIBRARY_PATH=/data/local/tmp/Quick.AI:\$LD_LIBRARY_PATH +cd /data/local/tmp/Quick.AI +export NNTR_NUM_THREADS=7 ./quick_dot_ai \$@ -EOF -" -adb shell "chmod 755 $INSTALL_DIR/run_causallm.sh" - -# Create quantize run script on device -adb shell "cat > $INSTALL_DIR/run_quantize.sh << 'EOF' -#!/system/bin/sh -export LD_LIBRARY_PATH=$INSTALL_DIR:\$LD_LIBRARY_PATH -cd $INSTALL_DIR -./quick_dot_ai_quantize \$@ EOF" +adb shell "chmod 755 $INSTALL_DIR/run.sh" -adb shell "chmod 755 $INSTALL_DIR/run_quantize.sh" - -# Create test script on device if API lib exists -if [ -f "$SCRIPT_DIR/jni/libs/arm64-v8a/test_api" ]; then - adb shell "cat > $INSTALL_DIR/run_test_api.sh << 'EOF' +adb shell "cat > $INSTALL_DIR/run_test.sh << 'EOF' #!/system/bin/sh -export LD_LIBRARY_PATH=$INSTALL_DIR:\$LD_LIBRARY_PATH -cd $INSTALL_DIR -./test_api \$@ -EOF -" - adb shell "chmod 755 $INSTALL_DIR/run_test_api.sh" - log_info "Run script for test_api created" -fi - -log_success "Run scripts created" - -# Summary -log_header "Installation Complete!" -log_info "Device: $DEVICE_ID" -log_info "Install directory: $INSTALL_DIR" -log_info "Installed files:" -log_info " - quick_dot_ai (executable)" -if [ -f "$SCRIPT_DIR/jni/libs/arm64-v8a/test_api" ]; then - log_info " - test_api (executable)" -fi -log_info " - libquick_dot_ai_core.so (CausalLM Core library)" -if [ -f "$SCRIPT_DIR/jni/libs/arm64-v8a/libquick_dot_ai_api.so" ]; then - log_info " - libquick_dot_ai_api.so (CausalLM API library)" -fi -log_info " - libnntrainer.so" -log_info " - libccapi-nntrainer.so" -log_info " - libc++_shared.so" -log_info " - libomp.so (if available)" -log_header "How to run" -log_info "To run CausalLM on the device:" -log_info " 1. Push your model files to: $MODEL_DIR/" -log_info " Example: adb push res/qwen3/qwen3-4b $MODEL_DIR/qwen3-4b/" -log_info "2. Run the application:" -log_info " adb shell $INSTALL_DIR/run_causallm.sh $MODEL_DIR/qwen3-4b" -log_info "" -log_info "(optional) Run quantization:" -log_info " adb shell $INSTALL_DIR/run_quantize.sh $MODEL_DIR/qwen3-4b --fc_dtype Q4_0" -log_info "" -log_info "For interactive shell:" -log_info " adb shell" -log_info " cd $INSTALL_DIR" -log_info " ./run_causallm.sh $MODEL_DIR/qwen3-4b" -if [ -f "$SCRIPT_DIR/jni/libs/arm64-v8a/test_api" ]; then - log_info "To run API Test on the device:" - log_info " adb shell $INSTALL_DIR/run_test_api.sh [ARGS]" -fi +export LD_LIBRARY_PATH=/data/local/tmp/Quick.AI:\$LD_LIBRARY_PATH +cd /data/local/tmp/Quick.AI +export NNTR_NUM_THREADS=7 +./quick_dot_ai_test \$@ +EOF" +adb shell "chmod 755 $INSTALL_DIR/run_test.sh" + +echo "" +echo "=== Installation completed ===" +echo "" +echo "To run on device:" +echo " adb shell $INSTALL_DIR/run.sh $MODEL_DIR/" +echo "" +echo "To run API test:" +echo " adb shell $INSTALL_DIR/run_test.sh [prompt] [chat_template] [quant] [verbose]" diff --git a/install_libs/quick_dot_ai b/install_libs/quick_dot_ai new file mode 100755 index 00000000..e187ed5d Binary files /dev/null and b/install_libs/quick_dot_ai differ diff --git a/install_libs/quick_dot_ai_test b/install_libs/quick_dot_ai_test new file mode 100755 index 00000000..45237fde Binary files /dev/null and b/install_libs/quick_dot_ai_test differ diff --git a/jni/Android.mk b/jni/Android.mk deleted file mode 100644 index fc94e880..00000000 --- a/jni/Android.mk +++ /dev/null @@ -1,235 +0,0 @@ -LOCAL_PATH := $(call my-dir) - -include $(CLEAR_VARS) - -# ndk path -ifndef ANDROID_NDK -$(error ANDROID_NDK is not defined!) -endif - -ifndef NNTRAINER_ROOT -NNTRAINER_ROOT := $(LOCAL_PATH)/../subprojects/nntrainer -endif - -NNTRAINER_INCLUDES := $(NNTRAINER_ROOT)/builddir/android_build_result/include/nntrainer - -# Common Includes Definition -CAUSALLM_COMMON_INCLUDES := \ - $(LOCAL_PATH)/.. \ - $(LOCAL_PATH)/../layers \ - $(LOCAL_PATH)/../models \ - $(LOCAL_PATH)/../models/gpt_oss \ - $(LOCAL_PATH)/../models/gpt_oss_cached_slim \ - $(LOCAL_PATH)/../models/qwen2 \ - $(LOCAL_PATH)/../models/qwen3 \ - $(LOCAL_PATH)/../models/qwen3_moe \ - $(LOCAL_PATH)/../models/qwen3_slim_moe \ - $(LOCAL_PATH)/../models/qwen3_cached_slim_moe \ - $(LOCAL_PATH)/../models/gemma3 - -# Prebuilt nntrainer libraries -include $(CLEAR_VARS) -LOCAL_MODULE := nntrainer -LOCAL_SRC_FILES := $(NNTRAINER_ROOT)/builddir/android_build_result/lib/$(TARGET_ARCH_ABI)/libnntrainer.so -include $(PREBUILT_SHARED_LIBRARY) - -include $(CLEAR_VARS) -LOCAL_MODULE := ccapi-nntrainer -LOCAL_SRC_FILES := $(NNTRAINER_ROOT)/builddir/android_build_result/lib/$(TARGET_ARCH_ABI)/libccapi-nntrainer.so -include $(PREBUILT_SHARED_LIBRARY) - -# Tokenizer library -include $(CLEAR_VARS) -LOCAL_MODULE := tokenizers_c -LOCAL_SRC_FILES := ../lib/libtokenizers_android_c.a -include $(PREBUILT_STATIC_LIBRARY) - -# Build libquick_dot_ai_core.so (shared library - without api) -include $(CLEAR_VARS) - -LOCAL_ARM_NEON := true -LOCAL_CFLAGS += -std=c++17 -Ofast -mcpu=cortex-a53 -Ilz4-nougat/lib -DENABLE_FP16=1 -DUSE__FP16=1 -D__ARM_NEON__=1 -march=armv8.2-a+fp16+dotprod+i8mm -DUSE_NEON=1 -mtune=cortex-a76 -O3 -ffast-math -LOCAL_LDFLAGS += -Llz4-nougat/lib/obj/local/$(TARGET_ARCH_ABI)/ -LOCAL_CXXFLAGS += -std=c++17 -frtti -LOCAL_CFLAGS += -pthread -fexceptions -fopenmp -static-openmp -DENABLE_FP16=1 -DUSE__FP16=1 -D__ARM_NEON__=1 -march=armv8.2-a+fp16+dotprod+i8mm -DUSE_NEON=1 -mtune=cortex-a76 -O3 -ffast-math -LOCAL_LDFLAGS += -fexceptions -fopenmp -static-openmp -DENABLE_FP16=1 -DUSE__FP16=1 -D__ARM_NEON__=1 -march=armv8.2-a+fp16+dotprod+i8mm -DUSE_NEON=1 -mtune=cortex-a76 -O3 -ffast-math -LOCAL_ARM_MODE := arm -LOCAL_MODULE := quick_dot_ai_core -LOCAL_LDLIBS := -llog -landroid -fopenmp -static-openmp -DENABLE_FP16=1 -DUSE__FP16=1 -D__ARM_NEON__=1 -march=armv8.2-a+fp16+dotprod+i8mm -DUSE_NEON=1 - -LOCAL_SRC_FILES := \ - ../chat_template.cpp \ - ../models/causal_lm.cpp \ - ../models/transformer.cpp \ - ../models/sentence_transformer.cpp \ - ../models/qwen2/qwen2_causallm.cpp \ - ../models/qwen2/qwen2_embedding.cpp \ - ../models/qwen3/qwen3_causallm.cpp \ - ../models/qwen3/qwen3_embedding.cpp \ - ../models/qwen3_moe/qwen3_moe_causallm.cpp \ - ../models/qwen3_slim_moe/qwen3_slim_moe_causallm.cpp \ - ../models/qwen3_cached_slim_moe/qwen3_cached_slim_moe_causallm.cpp \ - ../models/gpt_oss/gptoss_causallm.cpp \ - ../models/gpt_oss_cached_slim/gptoss_cached_slim_causallm.cpp \ - ../huggingface_tokenizer.cpp \ - ../llm_util.cpp \ - ../layers/embedding_layer.cpp \ - ../layers/embedding_pooling_layer.cpp \ - ../layers/embedding_normalize_layer.cpp \ - ../layers/mha_core.cpp \ - ../layers/lm_head.cpp \ - ../models/qwen3_moe/qwen_moe_layer.cpp \ - ../layers/reshaped_rms_norm.cpp \ - ../layers/rms_norm.cpp \ - ../layers/swiglu.cpp \ - ../layers/tie_word_embedding.cpp \ - ../models/qwen3_cached_slim_moe/qwen_moe_layer_cached.cpp \ - ../layers/qkv_layer.cpp \ - ../models/qwen3_slim_moe/qwen_moe_layer_fsu.cpp \ - ../models/gpt_oss/gpt_oss_moe_layer.cpp \ - ../models/gpt_oss_cached_slim/gpt_oss_moe_layer_cached.cpp \ - ../models/gemma3/gemma3_causallm.cpp \ - ../models/gemma3/embedding_gemma.cpp \ - ../models/gemma3/function.cpp \ - -LOCAL_SHARED_LIBRARIES := nntrainer ccapi-nntrainer -LOCAL_STATIC_LIBRARIES := tokenizers_c - -LOCAL_C_INCLUDES += $(NNTRAINER_INCLUDES) $(CAUSALLM_COMMON_INCLUDES) - -include $(BUILD_SHARED_LIBRARY) - -# Build libquick_dot_ai_api.so (shared library - api only) -include $(CLEAR_VARS) - -LOCAL_ARM_NEON := true -LOCAL_CFLAGS += -std=c++17 -Ofast -mcpu=cortex-a53 -DENABLE_FP16=1 -DUSE__FP16=1 -D__ARM_NEON__=1 -march=armv8.2-a+fp16+dotprod+i8mm -DUSE_NEON=1 -mtune=cortex-a76 -O3 -ffast-math -LOCAL_CXXFLAGS += -std=c++17 -frtti -LOCAL_CFLAGS += -pthread -fexceptions -fopenmp -static-openmp -DENABLE_FP16=1 -DUSE__FP16=1 -D__ARM_NEON__=1 -march=armv8.2-a+fp16+dotprod+i8mm -DUSE_NEON=1 -mtune=cortex-a76 -O3 -ffast-math -LOCAL_LDFLAGS += -fexceptions -fopenmp -static-openmp -DENABLE_FP16=1 -DUSE__FP16=1 -D__ARM_NEON__=1 -march=armv8.2-a+fp16+dotprod+i8mm -DUSE_NEON=1 -mtune=cortex-a76 -O3 -ffast-math -LOCAL_ARM_MODE := arm -LOCAL_MODULE := quick_dot_ai_api -LOCAL_LDLIBS := -llog -landroid -fopenmp -static-openmp -DENABLE_FP16=1 -DUSE__FP16=1 -D__ARM_NEON__=1 -march=armv8.2-a+fp16+dotprod+i8mm -DUSE_NEON=1 - -LOCAL_SRC_FILES := \ - ../api/causal_lm_api.cpp \ - ../api/model_config.cpp - -LOCAL_SHARED_LIBRARIES := quick_dot_ai_core nntrainer ccapi-nntrainer -LOCAL_STATIC_LIBRARIES := tokenizers_c - -LOCAL_C_INCLUDES += $(NNTRAINER_INCLUDES) $(CAUSALLM_COMMON_INCLUDES) \ - $(LOCAL_PATH)/../api - -include $(BUILD_SHARED_LIBRARY) - -# Build quick_dot_ai executable -include $(CLEAR_VARS) - -LOCAL_ARM_NEON := true -LOCAL_CFLAGS += -std=c++17 -Ofast -mcpu=cortex-a53 -DENABLE_FP16=1 -DUSE__FP16=1 -D__ARM_NEON__=1 -march=armv8.2-a+fp16+dotprod+i8mm -DUSE_NEON=1 -mtune=cortex-a76 -O3 -ffast-math -LOCAL_CXXFLAGS += -std=c++17 -frtti -LOCAL_CFLAGS += -pthread -fexceptions -fopenmp -static-openmp -DENABLE_FP16=1 -DUSE__FP16=1 -D__ARM_NEON__=1 -march=armv8.2-a+fp16+dotprod+i8mm -DUSE_NEON=1 -mtune=cortex-a76 -O3 -ffast-math -LOCAL_LDFLAGS += -fexceptions -fopenmp -static-openmp -DENABLE_FP16=1 -DUSE__FP16=1 -D__ARM_NEON__=1 -march=armv8.2-a+fp16+dotprod+i8mm -DUSE_NEON=1 -mtune=cortex-a76 -O3 -ffast-math -LOCAL_MODULE_TAGS := optional -LOCAL_ARM_MODE := arm -LOCAL_MODULE := quick_dot_ai -LOCAL_LDLIBS := -llog -landroid -fopenmp -static-openmp -DENABLE_FP16=1 -DUSE__FP16=1 -D__ARM_NEON__=1 -march=armv8.2-a+fp16+dotprod+i8mm -DUSE_NEON=1 - -LOCAL_SRC_FILES := ../main.cpp - -LOCAL_SHARED_LIBRARIES := quick_dot_ai_core nntrainer ccapi-nntrainer -LOCAL_STATIC_LIBRARIES := tokenizers_c - -LOCAL_C_INCLUDES += $(NNTRAINER_INCLUDES) $(CAUSALLM_COMMON_INCLUDES) - -include $(BUILD_EXECUTABLE) - -# Build test_api executable -include $(CLEAR_VARS) - -LOCAL_ARM_NEON := true -LOCAL_CFLAGS += -std=c++17 -Ofast -mcpu=cortex-a53 -DENABLE_FP16=1 -DUSE__FP16=1 -D__ARM_NEON__=1 -march=armv8.2-a+fp16+dotprod+i8mm -DUSE_NEON=1 -mtune=cortex-a76 -O3 -ffast-math -LOCAL_CXXFLAGS += -std=c++17 -frtti -LOCAL_CFLAGS += -pthread -fexceptions -fopenmp -static-openmp -DENABLE_FP16=1 -DUSE__FP16=1 -D__ARM_NEON__=1 -march=armv8.2-a+fp16+dotprod+i8mm -DUSE_NEON=1 -mtune=cortex-a76 -O3 -ffast-math -LOCAL_LDFLAGS += -fexceptions -fopenmp -static-openmp -DENABLE_FP16=1 -DUSE__FP16=1 -D__ARM_NEON__=1 -march=armv8.2-a+fp16+dotprod+i8mm -DUSE_NEON=1 -mtune=cortex-a76 -O3 -ffast-math -LOCAL_MODULE_TAGS := optional -LOCAL_ARM_MODE := arm -LOCAL_MODULE := quick_dot_ai_test_api -LOCAL_LDLIBS := -llog -landroid -fopenmp -static-openmp -DENABLE_FP16=1 -DUSE__FP16=1 -D__ARM_NEON__=1 -march=armv8.2-a+fp16+dotprod+i8mm -DUSE_NEON=1 - -LOCAL_SRC_FILES := ../api/test_api.cpp - -LOCAL_SHARED_LIBRARIES := quick_dot_ai_api quick_dot_ai_core nntrainer ccapi-nntrainer -LOCAL_STATIC_LIBRARIES := tokenizers_c - -LOCAL_C_INCLUDES += $(NNTRAINER_INCLUDES) $(CAUSALLM_COMMON_INCLUDES) \ - $(LOCAL_PATH)/../api - -include $(BUILD_EXECUTABLE) - - -# Build nntr_quantize executable -include $(CLEAR_VARS) - -LOCAL_ARM_NEON := true -LOCAL_CFLAGS += -std=c++17 -Ofast -mcpu=cortex-a53 -Ilz4-nougat/lib -DENABLE_FP16=1 -DUSE__FP16=1 -D__ARM_NEON__=1 -march=armv8.2-a+fp16+dotprod+i8mm -DUSE_NEON=1 -mtune=cortex-a76 -O3 -ffast-math -LOCAL_LDFLAGS += -Llz4-nougat/lib/obj/local/$(TARGET_ARCH_ABI)/ -LOCAL_CXXFLAGS += -std=c++17 -frtti -LOCAL_CFLAGS += -pthread -fexceptions -fopenmp -static-openmp -DENABLE_FP16=1 -DUSE__FP16=1 -D__ARM_NEON__=1 -march=armv8.2-a+fp16+dotprod+i8mm -DUSE_NEON=1 -mtune=cortex-a76 -O3 -ffast-math -LOCAL_LDFLAGS += -fexceptions -fopenmp -static-openmp -DENABLE_FP16=1 -DUSE__FP16=1 -D__ARM_NEON__=1 -march=armv8.2-a+fp16+dotprod+i8mm -DUSE_NEON=1 -mtune=cortex-a76 -O3 -ffast-math -LOCAL_MODULE_TAGS := optional -LOCAL_ARM_MODE := arm -LOCAL_MODULE := quick_dot_ai_quantize -LOCAL_LDLIBS := -llog -landroid -fopenmp -static-openmp -DENABLE_FP16=1 -DUSE__FP16=1 -D__ARM_NEON__=1 -march=armv8.2-a+fp16+dotprod+i8mm -DUSE_NEON=1 - -# Source files -LOCAL_SRC_FILES := ../quantize.cpp \ - ../models/causal_lm.cpp \ - ../models/transformer.cpp \ - ../models/sentence_transformer.cpp \ - ../models/qwen2/qwen2_causallm.cpp \ - ../models/qwen2/qwen2_embedding.cpp \ - ../models/qwen3/qwen3_causallm.cpp \ - ../models/qwen3/qwen3_embedding.cpp \ - ../models/qwen3_moe/qwen3_moe_causallm.cpp \ - ../models/qwen3_slim_moe/qwen3_slim_moe_causallm.cpp \ - ../models/qwen3_cached_slim_moe/qwen3_cached_slim_moe_causallm.cpp \ - ../models/gpt_oss/gptoss_causallm.cpp \ - ../models/gpt_oss_cached_slim/gptoss_cached_slim_causallm.cpp \ - ../llm_util.cpp \ - ../layers/embedding_layer.cpp \ - ../layers/embedding_pooling_layer.cpp \ - ../layers/embedding_normalize_layer.cpp \ - ../layers/mha_core.cpp \ - ../models/qwen3_moe/qwen_moe_layer.cpp \ - ../layers/reshaped_rms_norm.cpp \ - ../layers/rms_norm.cpp \ - ../layers/swiglu.cpp \ - ../layers/tie_word_embedding.cpp\ - ../layers/lm_head.cpp\ - ../models/qwen3_cached_slim_moe/qwen_moe_layer_cached.cpp \ - ../layers/qkv_layer.cpp \ - ../models/qwen3_slim_moe/qwen_moe_layer_fsu.cpp \ - ../models/gpt_oss/gpt_oss_moe_layer.cpp \ - ../models/gpt_oss_cached_slim/gpt_oss_moe_layer_cached.cpp \ - ../models/gemma3/gemma3_causallm.cpp \ - ../models/gemma3/embedding_gemma.cpp \ - -LOCAL_SHARED_LIBRARIES := nntrainer ccapi-nntrainer -LOCAL_STATIC_LIBRARIES := tokenizers_c - -LOCAL_C_INCLUDES += $(NNTRAINER_INCLUDES) \ - $(LOCAL_PATH)/.. \ - $(LOCAL_PATH)/../layers \ - $(LOCAL_PATH)/../models \ - $(LOCAL_PATH)/../models/gpt_oss \ - $(LOCAL_PATH)/../models/gpt_oss_cached_slim \ - $(LOCAL_PATH)/../models/qwen2 \ - $(LOCAL_PATH)/../models/qwen3 \ - $(LOCAL_PATH)/../models/qwen3_moe \ - $(LOCAL_PATH)/../models/qwen3_slim_moe \ - $(LOCAL_PATH)/../models/qwen3_cached_slim_moe \ - $(LOCAL_PATH)/../models/gemma3 \ - -include $(BUILD_EXECUTABLE) diff --git a/jni/prepare_encoder.ps1 b/jni/prepare_encoder.ps1 deleted file mode 100644 index ae0cd45f..00000000 --- a/jni/prepare_encoder.ps1 +++ /dev/null @@ -1,59 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -## -# @file prepare_encoder.ps1 -# @brief Download the encoder archive and place json.hpp at the project root. -# @usage ./prepare_encoder.ps1 - -param ( - [string]$Target, - [string]$TargetVersion -) - -$TarPrefix = "encoder" -$TarName = "$TarPrefix-$TargetVersion.tar.gz" -$Url = "https://github.com/nnstreamer/nnstreamer-android-resource/raw/main/external/$TarName" - -$ScriptDir = Split-Path -Parent $MyInvocation.MyCommand.Definition -$ProjectRoot = Resolve-Path (Join-Path $ScriptDir "..") - -Write-Output "PREPARING Encoder at $Target" - -if (-Not (Test-Path $Target)) { - New-Item -ItemType Directory -Path $Target | Out-Null -} - -Push-Location $Target - -function Download-Encoder { - if (Test-Path $TarName) { - Write-Output "$TarName exists, skip downloading" - return - } - - Write-Output "[Encoder] downloading $TarName" - try { - Invoke-WebRequest -Uri $Url -OutFile $TarName - Write-Output "[Encoder] Finish downloading encoder" - } catch { - Write-Output "[Encoder] Download failed, please check url" - exit 1 - } -} - -function Untar-Encoder { - Write-Output "[Encoder] untar encoder" - tar -zxvf $TarName - Remove-Item $TarName - - if ($TargetVersion -eq "0.2") { - Copy-Item -Path "json.hpp" -Destination (Join-Path $ProjectRoot "json.hpp") -Force - Write-Output "[Encoder] Copied json.hpp to $ProjectRoot" - } -} - -if (-Not (Test-Path "$TarPrefix")) { - Download-Encoder - Untar-Encoder -} - -Pop-Location diff --git a/jni/prepare_encoder.sh b/jni/prepare_encoder.sh deleted file mode 100755 index a9bdfbb4..00000000 --- a/jni/prepare_encoder.sh +++ /dev/null @@ -1,47 +0,0 @@ -#! /bin/bash -# SPDX-License-Identifier: Apache-2.0 -## -# @file prepare_encoder.sh -# @brief Download the encoder archive and place json.hpp at the project root. -# @usage ./prepare_encoder.sh - -set -e -TARGET=$1 -TARGET_VERSION=$2 -TAR_PREFIX=encoder -TAR_NAME=${TAR_PREFIX}-${TARGET_VERSION}.tar.gz -URL="https://github.com/nnstreamer/nnstreamer-android-resource/raw/main/external/${TAR_NAME}" - -SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" -PROJECT_ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)" - -echo "PREPARING Encoder at ${TARGET}" - -[ ! -d "${TARGET}" ] && mkdir -p "${TARGET}" - -pushd "${TARGET}" > /dev/null - -_download_encoder() { - [ -f "$TAR_NAME" ] && echo "${TAR_NAME} exists, skip downloading" && return 0 - echo "[Encoder] downloading ${TAR_NAME}" - if ! wget -q "${URL}"; then - echo "[Encoder] Download failed, please check url" - exit 1 - fi - echo "[Encoder] Finish downloading encoder" -} - -_untar_encoder() { - echo "[Encoder] untar encoder" - tar -zxvf "${TAR_NAME}" -C "${TARGET}" - rm -f "${TAR_NAME}" - - if [ "${TARGET_VERSION}" = "0.2" ]; then - cp -f json.hpp "${PROJECT_ROOT}/" - echo "[Encoder] Copied json.hpp to ${PROJECT_ROOT}/" - fi -} - -[ ! -d "${TAR_PREFIX}" ] && _download_encoder && _untar_encoder - -popd > /dev/null diff --git a/layers/causallm_common_properties.h b/layers/causallm_common_properties.h deleted file mode 100644 index 0c41fead..00000000 --- a/layers/causallm_common_properties.h +++ /dev/null @@ -1,104 +0,0 @@ -/** - * Copyright (C) 2025 Samsung Electronics Co., Ltd. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * http://www.apache.org/licenses/LICENSE-2.0 - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - * - * @file causallm_common_properties.h - * @date 23 July 2025 - * @brief This defines a qwen3 causal language model. - * @see https://github.com/nnstreamer/ - * @author Eunju Yang - * @bug No known bugs except for NYI items - * - */ -#ifndef __CAUSALLM_COMMON_PROPERTIES_H__ -#define __CAUSALLM_COMMON_PROPERTIES_H__ - -#pragma once -#ifdef _WIN32 -#define WIN_EXPORT __declspec(dllexport) -#else -#define WIN_EXPORT -#endif - -#include -#include -#include - -namespace quick_dot_ai { - -namespace props { - -/** - * @brief MoE activation type - */ -class MoEActivation final - : public nntrainer::EnumProperty { -public: - using prop_tag = nntrainer::enum_class_prop_tag; - static constexpr const char *key = "moe_activation"; -}; -/** - * @brief NumExperts, Number of experts property - */ -class NumExperts : public nntrainer::PositiveIntegerProperty { -public: - static constexpr const char *key = "num_experts"; /**< unique key to access */ - using prop_tag = nntrainer::uint_prop_tag; /**< property type */ -}; - -/** - * @brief NumExpertsPerToken, Number of experts per token property - */ -class NumExpertsPerToken : public nntrainer::PositiveIntegerProperty { -public: - static constexpr const char *key = - "num_experts_per_token"; /**< unique key to access */ - using prop_tag = nntrainer::uint_prop_tag; /**< property type */ -}; - -/** - * @brief unit property, unit is used to measure how many weights are there - * - */ -class FeatureSize : public nntrainer::PositiveIntegerProperty { -public: - static constexpr const char *key = - "feature_size"; /**< unique key to access */ - using prop_tag = nntrainer::uint_prop_tag; /**< property type */ -}; - -/** - * @brief RMS_NORM_GAMMA_INIT Initialization Enumeration Information - * - */ -WIN_EXPORT class RMS_NORM_GAMMA_INIT final - : public nntrainer::EnumProperty { -public: - /** - * @brief Construct a CUSTOM_RMS_NORM_GAMMA_INIT object - */ - WIN_EXPORT RMS_NORM_GAMMA_INIT( - nntrainer::Initializer value = nntrainer::Initializer::ONES) { - set(value); - }; - - using prop_tag = nntrainer::enum_class_prop_tag; - static constexpr const char *key = "gamma_initializer"; -}; -}; // namespace props - -WIN_EXPORT enum RMSParams { gamma }; - -} // namespace quick_dot_ai - -#endif diff --git a/layers/embedding_layer.cpp b/layers/embedding_layer.cpp deleted file mode 100644 index cf6427cc..00000000 --- a/layers/embedding_layer.cpp +++ /dev/null @@ -1,257 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -/** - * Copyright (C) 2020 Jijoong Moon - * - * @file embedding.cpp - * @date 04 March 2021 - * @brief This is Embedding Layer Class of Neural Network - * @see https://github.com/nntrainer/nntrainer - * @author Jijoong Moon - * @bug No known bugs except for NYI items - * @note This embedding layer supports FP32/FP16/Q6_K data type only. - */ - -#include -#include -#include -#include -#include -#include - -namespace quick_dot_ai { - -static constexpr size_t SINGLE_INOUT_IDX = 0; - -enum EmbeddingParams { weight }; - -EmbeddingLayer::EmbeddingLayer() : - LayerImpl(), - embedding_props(nntrainer::props::InDim(), nntrainer::props::OutDim(), - nntrainer::props::Scale()), - weight_idx(std::numeric_limits::max()) {} - -void EmbeddingLayer::finalize(nntrainer::InitLayerContext &context) { - NNTR_THROW_IF(context.getNumInputs() != 1, std::invalid_argument) - << "Embedding layer takes only one input"; - - const nntrainer::TensorDim &input_dim = - context.getInputDimensions()[SINGLE_INOUT_IDX]; - NNTR_THROW_IF(input_dim.channel() != 1, std::invalid_argument) - << "Embedding layer takes only one for channel size"; - - NNTR_THROW_IF(input_dim.getDataType() != nntrainer::TensorDim::DataType::FP32, - std::invalid_argument) - << "Embedding layer takes only FP32 input data"; - - auto &weight_regularizer = - std::get(*layer_impl_props); - auto &weight_regularizer_constant = - std::get(*layer_impl_props); - auto weight_initializer = nntrainer::props::InitializerInfo::Enum::NONE; - auto &weight_decay = - std::get(*layer_impl_props); - - size_t in_dim = - static_cast(std::get(embedding_props)); - size_t out_dim = - static_cast(std::get(embedding_props)); - - nntrainer::TensorDim output_dim = input_dim; - - // output_dim expected as hidden x num input (batch size) - output_dim.height(input_dim.width()); - output_dim.width(out_dim); - output_dim.setTensorType( - {context.getFormat(), context.getActivationDataType()}); - context.setOutputDimensions({output_dim}); - - nntrainer::TensorDim dim = output_dim; - - dim.setTensorType({context.getFormat(), context.getWeightDataType()}); - - dim.height(in_dim); - dim.width(out_dim); - dim.batch(1); - - weight_idx = context.requestWeight( - dim, weight_initializer, weight_regularizer, weight_regularizer_constant, - weight_decay, "Embedding", true); -} - -void EmbeddingLayer::setProperty(const std::vector &values) { - auto remain_props = loadProperties(values, embedding_props); - LayerImpl::setProperty(remain_props); -} - -void EmbeddingLayer::forwarding(nntrainer::RunLayerContext &context, - bool training) {} - -void EmbeddingLayer::incremental_forwarding(nntrainer::RunLayerContext &context, - unsigned int from, unsigned int to, - bool training) { - - /// @todo get input and output dimension from input_ and hidden itself - unsigned int in_dim = std::get(embedding_props); - unsigned int out_dim = std::get(embedding_props); - float scale = std::get(embedding_props).empty() - ? 1.0f - : std::get(embedding_props).get(); - unsigned int _from = from; - - nntrainer::Tensor &weight = context.getWeight(weight_idx); - nntrainer::Tensor &hidden_ = context.getOutput(SINGLE_INOUT_IDX); - nntrainer::Tensor &input_ = context.getInput(SINGLE_INOUT_IDX); - - nntrainer::TensorDim out_tensor_dim = - nntrainer::TensorDim({1, 1, 1, out_dim}, hidden_.getTensorType()); - - unsigned int b_size = input_.batch(); - - for (unsigned int b = 0; b < b_size; ++b) { - float *in_data = - input_.getAddress(b * input_.getDim().getFeatureLen()); - nntrainer::Tensor batchsliced_hidden = hidden_.getBatchSlice(b, 1); - - int iter = to - from; - -#pragma omp parallel for - for (int i = 0; i < iter; ++i) { - size_t embed_idx = static_cast(in_data[i]); - if (embed_idx >= in_dim) { - throw std::invalid_argument("input word index is greater than in_dim"); - } - - nntrainer::Tensor cur_weight = - weight.getSharedDataTensor(out_tensor_dim, out_dim * embed_idx); - nntrainer::Tensor out_tensor = - batchsliced_hidden.getSharedDataTensor(out_tensor_dim, out_dim * (i)); - - if (weight.getDataType() == nntrainer::TensorDim::DataType::Q6_K) { - ///@note this should be replaced with quantizer operation - int num_blocks_per_row = (weight.width() + 256 - 1) / 256; - nntrainer::dequantize_row_q6_K( - (void *)((char *)weight.getData() + - (210 * num_blocks_per_row) * embed_idx), - out_tensor.getData(), out_dim); - } else if (weight.getDataType() == nntrainer::TensorDim::DataType::Q4_0) { - ///@note this should be replaced with quantizer operation - int num_blocks_per_row = (weight.width() + 32 - 1) / 32; - nntrainer::dequantize_row_q4_0( - (void *)((char *)weight.getData() + - (18 * num_blocks_per_row) * embed_idx), - out_tensor.getData(), out_dim); - } else { - out_tensor.copyData(cur_weight); - } - - if (scale != 1.0f) { - out_tensor.multiply_i(scale); - } - } - -#ifdef DEBUG - std::cout << context.getName() << " : " - << "\n input:" << input_ << "\n weight: " << weight - << "\n hidden: " << hidden_ << std::endl; -#endif - } -} - -void EmbeddingLayer::calcDerivative(nntrainer::RunLayerContext &context) { - throw nntrainer::exception::not_supported( - "calcDerivative for Embedding layer is not supported"); -} - -void EmbeddingLayer::calcGradient(nntrainer::RunLayerContext &context) {} - -void EmbeddingLayer::exportTo(nntrainer::Exporter &exporter, - const ml::train::ExportMethods &method) const { - LayerImpl::exportTo(exporter, method); - exporter.saveResult(embedding_props, method, this); -} - -void EmbeddingLayer::save(std::ofstream &file, - nntrainer::RunLayerContext &run_context, bool opt_var, - ml::train::ExecutionMode mode, bool trainable, - nntrainer::TensorDim::DataType dtype) const { - // @note shared weights are only be saved at the first access - for (unsigned int i = 0; i < run_context.getNumWeights(); ++i) { - if (run_context.isGradientFirstAccess(i)) { - auto &weight = run_context.getWeight(i); - if (dtype == nntrainer::TensorDim::DataType::NONE || - weight.getDataType() == dtype) - weight.save(file); - else { - NNTR_THROW_IF(weight.getDataType() != - nntrainer::TensorDim::DataType::FP32, - std::runtime_error) - << "Save with quantization only supports for FP32 weight."; - ///@note The codelines below can be replaced with quantizer's - /// quantize() - nntrainer::TensorDim dim = weight.getDim(); - unsigned int K = dim.height(); - unsigned int N = dim.width(); - - if (dtype == nntrainer::TensorDim::DataType::Q4_0) { - - // Skip quantization for bias-like tensors (1D with height == 1) - // as they are not suitable for Q4_0 block quantization - if (K == 1) { - weight.save(file); - } else { - NNTR_THROW_IF(N % 32 != 0 || K % 32 != 0, std::invalid_argument) - << "Q4_0 quantization requires both width and height to be " - "divisible by 32, but got height=" - << K << ", width=" << N; - ////////////////////////////////////////////////////////////////// - ///@note Please note that Embedding layer doesn't need to be - /// transposed! - ////////////////////////////////////////////////////////////////// - nntrainer::Tensor quant_weight(dim.batch(), dim.channel(), K, N, - {nntrainer::Tformat::NCHW, dtype}); - nntrainer::quantize_q4_0(weight.getData(), - quant_weight.getData(), K, N, - nullptr); - quant_weight.save(file); - } - } else if (dtype == nntrainer::TensorDim::DataType::Q6_K) { - ////////////////////////////////////////////////////////////////// - ///@note Please note that Embedding layer doesn't need to be - /// transposed! - ////////////////////////////////////////////////////////////////// - nntrainer::Tensor quant_weight(dim.batch(), dim.channel(), K, N, - {nntrainer::Tformat::NCHW, dtype}); - nntrainer::quantize_q6_K(weight.getData(), - quant_weight.getData(), K, N, - nullptr); - quant_weight.save(file); - } else { - NNTR_THROW_IF(true, std::runtime_error) - << "This dtype is not supported in save with quantization"; - } - } - } - } -} - -#ifdef PLUGGABLE - -nntrainer::Layer *create_embedding_layer() { - auto layer = new EmbeddingLayer(); - std::cout << "embedding layer created\n"; - return layer; -} - -void destroy_embedding_layer(nntrainer::Layer *layer) { - std::cout << "embeddinglayer is deleted\n"; - delete layer; -} - -extern "C" { -nntrainer::LayerPluggable ml_train_layer_pluggable{create_embedding_layer, - destroy_embedding_layer}; -} - -#endif - -} // namespace quick_dot_ai diff --git a/layers/embedding_layer.h b/layers/embedding_layer.h deleted file mode 100644 index ffe5542a..00000000 --- a/layers/embedding_layer.h +++ /dev/null @@ -1,137 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -/** - * Copyright (C) 2021 Jijoong Moon - * - * @file embedding.h - * @date 04 March 2021 - * @brief This is Embedding Layer Class of Neural Network - * @see https://github.com/nntrainer/nntrainer - * @author Jijoong Moon - * @author Eunju Yang - * @bug No known bugs except for NYI items - * - */ - -#ifndef __EMBEDDING_LAYER_H__ -#define __EMBEDDING_LAYER_H__ -#ifdef __cplusplus - -#pragma once -#ifdef _WIN32 -#define WIN_EXPORT __declspec(dllexport) -#else -#define WIN_EXPORT -#endif - -#include -#include - -namespace quick_dot_ai { - -/** - * @class EmbeddingLayer - * @brief EmbeddingLayer - * @todo Support setBatch for EmbeddingLayer - */ -WIN_EXPORT class EmbeddingLayer : public nntrainer::LayerImpl { -public: - /** - * @brief Constructor of Embedding Layer - */ - WIN_EXPORT EmbeddingLayer(); - - /** - * @brief Destructor of Embedding Layer - */ - WIN_EXPORT ~EmbeddingLayer() = default; - - /** - * @brief Move constructor. - * @param[in] EmbeddingLayer && - */ - WIN_EXPORT EmbeddingLayer(EmbeddingLayer &&rhs) noexcept = default; - - /** - * @brief Move assignment operator. - * @parma[in] rhs EmbeddingLayer to be moved. - */ - WIN_EXPORT EmbeddingLayer &operator=(EmbeddingLayer &&rhs) = default; - - /** - * @copydoc Layer::finalize(InitLayerContext &context) - */ - WIN_EXPORT void finalize(nntrainer::InitLayerContext &context) override; - - /** - * @copydoc Layer::forwarding(RunLayerContext &context, bool training) - */ - WIN_EXPORT void forwarding(nntrainer::RunLayerContext &context, - bool training) override; - - /** -οΏΌ * @copydoc Layer::incremental_forwarding(RunLayerContext &context, unsigned -οΏΌ * int from, unsigned int to, bool training) -οΏΌ */ - WIN_EXPORT void incremental_forwarding(nntrainer::RunLayerContext &context, - unsigned int from, unsigned int to, - bool training) override; - - /** - * @copydoc Layer::calcDerivative(RunLayerContext &context) - */ - WIN_EXPORT void calcDerivative(nntrainer::RunLayerContext &context) override; - - /** - * @copydoc Layer::calcGradient(RunLayerContext &context) - */ - WIN_EXPORT void calcGradient(nntrainer::RunLayerContext &context) override; - - /** - * @copydoc Layer::exportTo(Exporter &exporter, ml::train::ExportMethods - * method) - */ - WIN_EXPORT void - exportTo(nntrainer::Exporter &exporter, - const ml::train::ExportMethods &method) const override; - - /** - * @copydoc Layer::getType() - */ - WIN_EXPORT const std::string getType() const override { - return EmbeddingLayer::type; - }; - - /** - * @copydoc Layer::supportBackwarding() - */ - WIN_EXPORT bool supportBackwarding() const override { return false; } - - using Layer::setProperty; - - /** - * @copydoc Layer::setProperty(const PropertyType type, const std::string - * &value) - */ - WIN_EXPORT void setProperty(const std::vector &values) override; - - /** - * @copydic Layer::save() - */ - WIN_EXPORT void save(std::ofstream &file, - nntrainer::RunLayerContext &run_context, bool opt_var, - ml::train::ExecutionMode mode, bool trainable, - nntrainer::TensorDim::DataType dtype = - nntrainer::TensorDim::DataType::NONE) const override; - - inline static const std::string type = "embedding_layer"; - -private: - std::tuple - embedding_props; - unsigned int weight_idx; -}; -} // namespace quick_dot_ai - -#endif /* __cplusplus */ -#endif /* __EMBEDDING_H__ */ diff --git a/layers/embedding_normalize_layer.cpp b/layers/embedding_normalize_layer.cpp deleted file mode 100644 index 0bbf17ff..00000000 --- a/layers/embedding_normalize_layer.cpp +++ /dev/null @@ -1,82 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -/** - * Copyright (C) 2026 Eunju Yang - * - * @file embedding_normalize_layer.cpp - * @date 06 Jan 2026 - * @brief This is Embedding Normalize Layer Class - * @see https://github.com/nnstreamer/nntrainer - * @author Eunju Yang - * @bug No known bugs except for NYI items - * - */ - -#include -#include -#include -#include -#include -#include -#include -#include - -namespace quick_dot_ai { - -static constexpr size_t SINGLE_INOUT_IDX = 0; - -EmbeddingNormalizeLayer::EmbeddingNormalizeLayer() : LayerImpl() {} - -void EmbeddingNormalizeLayer::finalize(nntrainer::InitLayerContext &context) { - NNTR_THROW_IF(context.getNumInputs() != 1, std::invalid_argument) - << "EmbeddingNormalize layer takes only one input"; - - const nntrainer::TensorDim &input_dim = - context.getInputDimensions()[SINGLE_INOUT_IDX]; - - context.setOutputDimensions({input_dim}); -} - -void EmbeddingNormalizeLayer::forwarding(nntrainer::RunLayerContext &context, - bool training) { - nntrainer::Tensor &input = context.getInput(SINGLE_INOUT_IDX); - nntrainer::Tensor &output = context.getOutput(SINGLE_INOUT_IDX); - - // Copy input to output as we will modify output in-place - output.copyData(input); - // Normalize along the last dimension (dim=3) - output.normalization_i(3); -} - -void EmbeddingNormalizeLayer::incremental_forwarding( - nntrainer::RunLayerContext &context, unsigned int from, unsigned int to, - bool training) { - // Incremental forwarding for element-wise/row-wise normalization is typically - // identical to forwarding if the input shape matches the processing chunk. - // However, often incremental_forwarding is used when we process a chunk of - // seq_len. BUT, EmbeddingNormalizeLayer usually comes AFTER Pooling, so - // seq_len is likely 1. In that case, incremental_forwarding might not even be - // called or acts same as forwarding. If we assume this layer is generic, we - // should process 'from' to 'to'. But strictly, this layer is designed for - // pooled output [batch, 1, 1, dim]. So 'from' and 'to' are likely 0 and 1. - - forwarding(context, training); -} - -void EmbeddingNormalizeLayer::calcDerivative( - nntrainer::RunLayerContext &context) { - throw nntrainer::exception::not_supported( - "calcDerivative for EmbeddingNormalize layer is not supported"); -} - -void EmbeddingNormalizeLayer::calcGradient( - nntrainer::RunLayerContext &context) { - throw nntrainer::exception::not_supported( - "calcGradient for EmbeddingNormalize layer is not supported"); -} - -void EmbeddingNormalizeLayer::exportTo( - nntrainer::Exporter &exporter, const ml::train::ExportMethods &method) const { - LayerImpl::exportTo(exporter, method); -} - -} // namespace quick_dot_ai diff --git a/layers/embedding_normalize_layer.h b/layers/embedding_normalize_layer.h deleted file mode 100644 index 6999dedb..00000000 --- a/layers/embedding_normalize_layer.h +++ /dev/null @@ -1,88 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -/** - * Copyright (C) 2026 Eunju Yang - * - * @file embedding_normalize_layer.h - * @date 06 Jan 2026 - * @brief This is Embedding Normalize Layer Class - * @see https://github.com/nntrainer/nntrainer - * @author Eunju Yang - * @bug No known bugs except for NYI items - * - */ - -#ifndef __EMBEDDING_NORMALIZE_LAYER_H__ -#define __EMBEDDING_NORMALIZE_LAYER_H__ - -#include - -namespace quick_dot_ai { - -/** - * @class EmbeddingNormalizeLayer - * @brief Embedding Normalize Layer - */ -class EmbeddingNormalizeLayer : public nntrainer::LayerImpl { -public: - /** - * @brief Constructor of EmbeddingNormalizeLayer - */ - EmbeddingNormalizeLayer(); - - /** - * @brief Destructor of EmbeddingNormalizeLayer - */ - ~EmbeddingNormalizeLayer() = default; - - /** - * @copydoc Layer::finalize(InitLayerContext &context) - */ - void finalize(nntrainer::InitLayerContext &context) override; - - /** - * @copydoc Layer::forwarding(RunLayerContext &context, bool training) - */ - void forwarding(nntrainer::RunLayerContext &context, bool training) override; - - /** - * @copydoc Layer::incremental_forwarding(RunLayerContext &context, unsigned - * int from, unsigned int to, bool training) - */ - void incremental_forwarding(nntrainer::RunLayerContext &context, - unsigned int from, unsigned int to, - bool training) override; - - /** - * @copydoc Layer::calcDerivative(RunLayerContext &context) - */ - void calcDerivative(nntrainer::RunLayerContext &context) override; - - /** - * @copydoc Layer::calcGradient(RunLayerContext &context) - */ - void calcGradient(nntrainer::RunLayerContext &context) override; - - /** - * @copydoc Layer::exportTo(Exporter &exporter, const ExportMethods &method) - */ - void exportTo(nntrainer::Exporter &exporter, - const ml::train::ExportMethods &method) const override; - - /** - * @copydoc Layer::getType() - */ - const std::string getType() const override { - return EmbeddingNormalizeLayer::type; - } - - /** - * @copydoc Layer::supportBackwarding() - */ - bool supportBackwarding() const override { return false; } - - static constexpr const char *type = "embedding_normalize"; -}; - -} // namespace quick_dot_ai - -#endif /* __EMBEDDING_NORMALIZE_LAYER_H__ */ diff --git a/layers/embedding_pooling_layer.cpp b/layers/embedding_pooling_layer.cpp deleted file mode 100644 index 551841cb..00000000 --- a/layers/embedding_pooling_layer.cpp +++ /dev/null @@ -1,171 +0,0 @@ -// SPDX-License-Identifier: Apatche-2.0 -/** - * Copyright (C) 2025 Eunju Yang - * - * @file embedding_pooling_layer.cpp - * @date 02 Jan 2026 - * @brief This is Embedding Pooling Layer Class - * @see https://github.com/nntrainer/nntrainer - * @author Eunju Yang - * @bug No known bugs except for NYI items - */ - -#include -#include -#include -#include -#include -#include - -namespace quick_dot_ai { - -static constexpr size_t SINGLE_INOUT_IDX = 0; - -EmbeddingPoolingLayer::EmbeddingPoolingLayer() : - LayerImpl(), - pooling_props( - props::WordEmbeddingDimension(), props::PoolingModeClsToken(false), - props::PoolingModeMeanTokens(false), props::PoolingModeMaxTokens(false), - props::PoolingModeMeanSqrtLenTokens(false), - props::PoolingModeWeightedMeanTokens(false), - props::PoolingModeLastToken(false), props::IncludePrompt(true)) {} - -void EmbeddingPoolingLayer::finalize(nntrainer::InitLayerContext &context) { - NNTR_THROW_IF(context.getNumInputs() != 1, std::invalid_argument) - << "EmbeddingPooling layer takes only one input"; - - const nntrainer::TensorDim &input_dim = - context.getInputDimensions()[SINGLE_INOUT_IDX]; - - unsigned int word_embed_dim = - std::get(pooling_props); - - if (input_dim.width() != word_embed_dim) { - ml_logw( - "Input dimension width (%d) does not match word_embedding_dimension (%d)", - input_dim.width(), word_embed_dim); - } - - // Output dimension for Pooling is [batch, 1, 1, word_embed_dim] - nntrainer::TensorDim output_dim = input_dim; - output_dim.height(1); - - context.setOutputDimensions({output_dim}); - - bool mode_cls = std::get(pooling_props); - bool mode_mean = std::get(pooling_props); - bool mode_max = std::get(pooling_props); - bool mode_mean_sqrt = - std::get(pooling_props); - bool mode_weighted_mean = - std::get(pooling_props); - - if (mode_cls || mode_max || mode_mean_sqrt || mode_weighted_mean) { - throw nntrainer::exception::not_supported( - "Only pooling_mode_lasttoken and pooling_mode_mean_tokens are currently " - "supported in EmbeddingPoolingLayer"); - } -} - -void EmbeddingPoolingLayer::setProperty( - const std::vector &values) { - auto remain_props = loadProperties(values, pooling_props); - LayerImpl::setProperty(remain_props); -} - -void EmbeddingPoolingLayer::forwarding(nntrainer::RunLayerContext &context, - bool training) { - nntrainer::Tensor &input = context.getInput(SINGLE_INOUT_IDX); - nntrainer::Tensor &output = context.getOutput(SINGLE_INOUT_IDX); - - unsigned int batch = input.batch(); - unsigned int seq_len = input.height(); - unsigned int dim = input.width(); - - bool mode_lasttoken = std::get(pooling_props); - bool mode_mean = std::get(pooling_props); - - if (mode_lasttoken) { - for (unsigned int b = 0; b < batch; ++b) { - // Last token index = seq_len - 1 - nntrainer::Tensor source = input.getSharedDataTensor( - {1, 1, 1, dim}, b * seq_len * dim + (seq_len - 1) * dim); - - nntrainer::Tensor dest = - output.getSharedDataTensor({1, 1, 1, dim}, b * dim); - dest.copyData(source); - } - } else if (mode_mean) { - for (unsigned int b = 0; b < batch; ++b) { - nntrainer::Tensor source = - input.getSharedDataTensor({1, 1, seq_len, dim}, b * seq_len * dim); - nntrainer::Tensor dest = - output.getSharedDataTensor({1, 1, 1, dim}, b * dim); - - // Calculate mean along average dim (height/seq_len) - dest.copyData(source.average(2)); - } - } else { - output.setZero(); - } -} - -void EmbeddingPoolingLayer::incremental_forwarding( - nntrainer::RunLayerContext &context, unsigned int from, unsigned int to, - bool training) { - nntrainer::Tensor &input = context.getInput(SINGLE_INOUT_IDX); - nntrainer::Tensor &output = context.getOutput(SINGLE_INOUT_IDX); - - unsigned int batch = input.batch(); - unsigned int dim = input.width(); - size_t feature_len = input.getDim().getFeatureLen(); // height * width - - bool mode_lasttoken = std::get(pooling_props); - - if (mode_lasttoken) { - for (unsigned int b = 0; b < batch; ++b) { - // Use feature_len for batch stride - // The last token processed is at index `to-1` in the absolute sequence. - size_t offset = static_cast(b) * feature_len + (to - 1) * dim; - - nntrainer::Tensor source = - input.getSharedDataTensor({1, 1, 1, dim}, offset); - nntrainer::Tensor dest = - output.getSharedDataTensor({1, 1, 1, dim}, b * dim); - - dest.copyData(source); - } - } else if (std::get(pooling_props)) { - for (unsigned int b = 0; b < batch; ++b) { - unsigned int len = to - from; - size_t offset = static_cast(b) * feature_len + from * dim; - - nntrainer::Tensor source = - input.getSharedDataTensor({1, 1, len, dim}, offset); - nntrainer::Tensor dest = - output.getSharedDataTensor({1, 1, 1, dim}, b * dim); - - dest.copyData(source.average(2)); - } - } else { - output.setZero(); - } -} - -void EmbeddingPoolingLayer::calcDerivative( - nntrainer::RunLayerContext &context) { - throw nntrainer::exception::not_supported( - "calcDerivative for EmbeddingPooling layer is not supported"); -} - -void EmbeddingPoolingLayer::calcGradient(nntrainer::RunLayerContext &context) { - throw nntrainer::exception::not_supported( - "calcGradient for EmbeddingPooling layer is not supported"); -} - -void EmbeddingPoolingLayer::exportTo( - nntrainer::Exporter &exporter, const ml::train::ExportMethods &method) const { - LayerImpl::exportTo(exporter, method); - exporter.saveResult(pooling_props, method, this); -} -} // namespace quick_dot_ai diff --git a/layers/embedding_pooling_layer.h b/layers/embedding_pooling_layer.h deleted file mode 100644 index 28aaa8d0..00000000 --- a/layers/embedding_pooling_layer.h +++ /dev/null @@ -1,193 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -/** - * Copyright (C) 2025 Eunju Yang - * - * @file embedding_pooling_layer.h - * @date 02 Jan 2026 - * @brief This is Embedding Pooling Layer Class (for sentence-transformer) - * @see https://github.com/nntrainer/nntrainer - * @author Eunju Yang - * @bug No known bugs except for NYI items - */ - -#ifndef __EMBEDDING_POOLING_LAYER_H__ -#define __EMBEDDING_POOLING_LAYER_H__ - -#include -#include -#include - -namespace quick_dot_ai { - -namespace props { - -/** - * @brief WordEmbeddingDimension property class to hold word embedding dimension - */ -class WordEmbeddingDimension : public nntrainer::Property { -public: - static constexpr const char *key = "word_embedding_dimension"; - using prop_tag = nntrainer::uint_prop_tag; - WordEmbeddingDimension(unsigned int value = 0) { set(value); } -}; - -/** - * @brief PoolingModeClsToken property class to hold pooling mode cls token flag - */ -class PoolingModeClsToken : public nntrainer::Property { -public: - static constexpr const char *key = "pooling_mode_cls_token"; - using prop_tag = nntrainer::bool_prop_tag; - PoolingModeClsToken(bool value = false) { set(value); } -}; - -/** - * @brief PoolingModeMeanTokens property class to hold pooling mode mean tokens - * flag - */ -class PoolingModeMeanTokens : public nntrainer::Property { -public: - static constexpr const char *key = "pooling_mode_mean_tokens"; - using prop_tag = nntrainer::bool_prop_tag; - PoolingModeMeanTokens(bool value = false) { set(value); } -}; - -/** - * @brief PoolingModeMaxTokens property class to hold pooling mode max tokens - * flag - */ -class PoolingModeMaxTokens : public nntrainer::Property { -public: - static constexpr const char *key = "pooling_mode_max_tokens"; - using prop_tag = nntrainer::bool_prop_tag; - PoolingModeMaxTokens(bool value = false) { set(value); } -}; - -/** - * @brief PoolingModeMeanSqrtLenTokens property class to hold pooling mode mean - */ -class PoolingModeMeanSqrtLenTokens : public nntrainer::Property { -public: - static constexpr const char *key = "pooling_mode_mean_sqrt_len_tokens"; - using prop_tag = nntrainer::bool_prop_tag; - PoolingModeMeanSqrtLenTokens(bool value = false) { set(value); } -}; - -/** - * @brief PoolingModeWeightedMeanTokens property class to hold pooling mode - * weighted mean tokens flag - */ -class PoolingModeWeightedMeanTokens : public nntrainer::Property { -public: - static constexpr const char *key = "pooling_mode_weightedmean_tokens"; - using prop_tag = nntrainer::bool_prop_tag; - PoolingModeWeightedMeanTokens(bool value = false) { set(value); } -}; - -/** - * @brief PoolingModeLastToken property class to hold pooling mode last token - * flag - */ -class PoolingModeLastToken : public nntrainer::Property { -public: - static constexpr const char *key = "pooling_mode_lasttoken"; - using prop_tag = nntrainer::bool_prop_tag; - PoolingModeLastToken(bool value = false) { set(value); } -}; - -/** - * @brief IncludePrompt property class to hold include prompt flag (default - * true) - */ -class IncludePrompt : public nntrainer::Property { -public: - static constexpr const char *key = "include_prompt"; - using prop_tag = nntrainer::bool_prop_tag; - IncludePrompt(bool value = true) { set(value); } -}; -} // namespace props - -/** - * @brief Embedding Pooling Layer - * @note This layer corresponds to sentence_transformers.models.Pooling. - * Currently, only pooling_mode_lasttoken with include_prompt is fully - * implemented. Other pooling modes are defined as properties but their logic is - * not yet implemented. - */ -class EmbeddingPoolingLayer : public nntrainer::LayerImpl { -public: - /** - * @brief Construct a new Embedding Pooling Layer object - */ - EmbeddingPoolingLayer(); - - /** - * @brief Destroy the Embedding Pooling Layer object - */ - ~EmbeddingPoolingLayer() {} - - /** - * @copydoc Layer::finalize(InitLayerContext &context) - */ - void finalize(nntrainer::InitLayerContext &context) override; - - /** - * @copydoc Layer::forwarding(RunLayerContext &context, bool training) - */ - void forwarding(nntrainer::RunLayerContext &context, bool training) override; - - /** - * @copydoc Layer::incremental_forwarding(RunLayerContext &context, unsigned - * int from, unsigned int to, bool training) - */ - void incremental_forwarding(nntrainer::RunLayerContext &context, - unsigned int from, unsigned int to, - bool training) override; - - /** - * @copydoc Layer::calcDerivative(RunLayerContext &context) - */ - void calcDerivative(nntrainer::RunLayerContext &context) override; - - /** - * @copydoc Layer::calcGradient(RunLayerContext &context) - */ - void calcGradient(nntrainer::RunLayerContext &context) override; - - /** - * @copydoc Layer::exportTo(Exporter &exporter, const ExportMethods &method) - */ - void exportTo(nntrainer::Exporter &exporter, - const ml::train::ExportMethods &method) const override; - - /** - * @copydoc Layer::setProperty(const std::vector &values) - */ - void setProperty(const std::vector &values) override; - - /** - * @copydoc Layer::getType() - */ - const std::string getType() const override { - return EmbeddingPoolingLayer::type; - } - - /** - * @copydoc Layer::supportBackwarding() - */ - bool supportBackwarding() const override { return false; } - - static constexpr const char *type = "embedding_pooling"; - -private: - std::tuple - pooling_props; -}; - -} // namespace quick_dot_ai - -#endif diff --git a/layers/lm_head.cpp b/layers/lm_head.cpp deleted file mode 100644 index 6883b3f9..00000000 --- a/layers/lm_head.cpp +++ /dev/null @@ -1,208 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -/** - * Copyright (C) 2026 Eunju Yang - * - * @file lm_head.cpp - * @date 16 Jan 2026 - * @brief This is lmhead layer - * @see https://github.com/nntrainer/nntrainer - * @author Eunju Yang - * @bug No known bugs except for NYI items - * - */ - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace quick_dot_ai { - -static constexpr size_t SINGLE_INOUT_IDX = 0; - -enum LmHeadParams { - weight, - bias, -}; - -LmHeadLayer::LmHeadLayer() : - LayerImpl(), lmhead_props(nntrainer::props::Unit()) { - weight_idx.fill(std::numeric_limits::max()); -} - -void LmHeadLayer::finalize(nntrainer::InitLayerContext &context) { - auto &weight_regularizer = - std::get(*layer_impl_props); - auto &weight_regularizer_constant = - std::get(*layer_impl_props); - auto weight_initializer = nntrainer::props::InitializerInfo::Enum::NONE; - auto &weight_decay = - std::get(*layer_impl_props); - auto &bias_decay = std::get(*layer_impl_props); - auto &bias_initializer = - std::get(*layer_impl_props); - auto &disable_bias = - std::get(*layer_impl_props); - - auto unit = std::get(lmhead_props).get(); - - NNTR_THROW_IF(context.getNumInputs() != 1, std::invalid_argument) - << "lm head layer takes only one input"; - - std::vector output_dims(1); - - /// @todo fc actaully supports multidimensions. - /// EffDimFlag shouldn't be fixed like this. - context.setEffDimFlagInputDimension(0, 0b1001); - context.setDynDimFlagInputDimension(0, 0b1000); - bool is_nchw = (context.getFormat() == nntrainer::Tformat::NCHW); - - /** set output dimensions */ - ///@note lm_head's output dimension (height is always 1 !) - auto const &in_dim = context.getInputDimensions()[0]; - output_dims[0] = in_dim; - if (is_nchw) - output_dims[0].width(unit); - else - output_dims[0].channel(unit); - output_dims[0].height(1); - - output_dims[0].setTensorType( - {context.getFormat(), context.getActivationDataType()}); - - context.setOutputDimensions(output_dims); - - /** set weight specifications */ - ml::train::TensorDim bias_dim( - 1, is_nchw ? 1 : unit, 1, is_nchw ? unit : 1, - ml::train::TensorDim::TensorType(context.getFormat(), - context.getWeightDataType()), - is_nchw ? 0b0001 : 0b0100); - - ///@note LMHead layer's tensor dim is transposed dim of user-defined - /// dim - /// so it can reuse embedding layer. - ml::train::TensorDim weight_dim( - 1, is_nchw ? 1 : unit, is_nchw ? in_dim.width() : 1, - is_nchw ? unit : in_dim.channel(), - ml::train::TensorDim::TensorType(context.getFormat(), - context.getWeightDataType()), - is_nchw ? 0b0011 : 0b0101); - - weight_idx[LmHeadParams::weight] = context.requestWeight( - weight_dim, weight_initializer, weight_regularizer, - weight_regularizer_constant, weight_decay, "weight", true); - - if (disable_bias.empty() || disable_bias.get() == false) { - weight_idx[LmHeadParams::bias] = context.requestWeight( - bias_dim, bias_initializer, nntrainer::WeightRegularizer::NONE, 1.0f, - bias_decay, "bias", true); - } -} - -void LmHeadLayer::setProperty(const std::vector &values) { - auto remain_props = loadProperties(values, lmhead_props); - LayerImpl::setProperty(remain_props); -} - -void LmHeadLayer::forwarding(nntrainer::RunLayerContext &context, - bool training) { - throw nntrainer::exception::not_supported( - "Forwarding for LMHead layer is not supported"); -} - -void LmHeadLayer::incremental_forwarding(nntrainer::RunLayerContext &context, - unsigned int from, unsigned int to, - bool training) { - - nntrainer::Tensor weight = - context.getWeight(weight_idx[LmHeadParams::weight]); - - nntrainer::Tensor &input_ = context.getInput(SINGLE_INOUT_IDX); - nntrainer::Tensor &hidden_ = context.getOutput(SINGLE_INOUT_IDX); - - ml::train::TensorDim input_dim = input_.getDim(); - ml::train::TensorDim hidden_dim = hidden_.getDim(); - - ml::train::TensorDim input_step_dim = input_dim; - ml::train::TensorDim hidden_step_dim = hidden_dim; - - input_step_dim.batch(1); - input_step_dim.height(1); - hidden_step_dim.batch(1); - - unsigned int b_size = input_dim.batch(); - - for (unsigned int b = 0; b < b_size; ++b) { - nntrainer::Tensor input_step = input_.getSharedDataTensor( - input_step_dim, - b * input_dim.getFeatureLen() + (to - from - 1) * input_.width(), true); - nntrainer::Tensor hidden_step = hidden_.getSharedDataTensor( - hidden_step_dim, b * hidden_dim.getFeatureLen(), true); - - input_step.dot(weight, hidden_step, false, false); - - if (auto &disable_bias = - std::get(*layer_impl_props); - disable_bias.empty() || disable_bias.get() == false) { - nntrainer::Tensor &bias = - context.getWeight(weight_idx[LmHeadParams::bias]); - hidden_step.add_i(bias); - } - } -} - -void LmHeadLayer::calcDerivative(nntrainer::RunLayerContext &context) { - throw nntrainer::exception::not_supported( - "calcDerivative for LMHead layer is not supported"); -} - -void LmHeadLayer::calcGradient(nntrainer::RunLayerContext &context) { - throw nntrainer::exception::not_supported( - "calcGradient for LMHead layer is not supported"); -} - -void LmHeadLayer::exportTo(nntrainer::Exporter &exporter, - const ml::train::ExportMethods &method) const { - LayerImpl::exportTo(exporter, method); - exporter.saveResult(lmhead_props, method, this); -} - -void LmHeadLayer::updateTensorsByInputDimensions( - nntrainer::RunLayerContext &context, - std::vector input_dimensions) { - nntrainer::TensorDim in_dim = context.getInput(SINGLE_INOUT_IDX).getDim(); - - unsigned int height = input_dimensions[0].height(); - - // output dim's height is always 1 ! - in_dim.height(height); - context.updateInput(SINGLE_INOUT_IDX, in_dim); -} - -#ifdef PLUGGABLE - -nntrainer::Layer *create_tie_word_embedding() { - auto layer = new LmHeadLayer(); - std::cout << "embedding layer created\n"; - return layer; -} - -void destroy_tie_word_embedding(nntrainer::Layer *layer) { - std::cout << "embeddinglayer is deleted\n"; - delete layer; -} - -extern "C" { -nntrainer::LayerPluggable ml_train_layer_pluggable{create_tie_word_embedding, - destroy_tie_word_embedding}; -} - -#endif - -} // namespace quick_dot_ai diff --git a/layers/lm_head.h b/layers/lm_head.h deleted file mode 100644 index 9eb836b9..00000000 --- a/layers/lm_head.h +++ /dev/null @@ -1,129 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -/** - * Copyright (C) 2026 Eunju Yang - * - * @file lm_head.h - * @date 16 Jan 2026 - * @brief This is LM_Head Layer Class of Neural Network - * @see https://github.com/nntrainer/nntrainer - * @author Eunju Yang - * @bug No known bugs except for NYI items - * - */ - -#ifndef __LM_HEAD_H__ -#define __LM_HEAD_H__ -#ifdef __cplusplus - -#pragma once -#ifdef _WIN32 -#define WIN_EXPORT __declspec(dllexport) -#else -#define WIN_EXPORT -#endif - -#include -#include -#include - -namespace quick_dot_ai { - -/** - * @class LMHead layer - * @brief LMHead layer - */ -WIN_EXPORT class LmHeadLayer : public nntrainer::LayerImpl { -public: - /** - * @brief Constructor of Embedding Layer - */ - WIN_EXPORT LmHeadLayer(); - - /** - * @brief Destructor of Embedding Layer - */ - WIN_EXPORT ~LmHeadLayer() = default; - - /** - * @brief Move constructor. - * @param[in] LmHeadLayer && - */ - WIN_EXPORT LmHeadLayer(LmHeadLayer &&rhs) noexcept = default; - - /** - * @brief Move assignment operator. - * @parma[in] rhs LmHeadLayer to be moved. - */ - WIN_EXPORT LmHeadLayer &operator=(LmHeadLayer &&rhs) = default; - - /** - * @copydoc Layer::finalize(InitLayerContext &context) - */ - WIN_EXPORT void finalize(nntrainer::InitLayerContext &context) override; - - /** - * @copydoc Layer::forwarding(RunLayerContext &context, bool training) - */ - WIN_EXPORT void forwarding(nntrainer::RunLayerContext &context, - bool training) override; - - /** -οΏΌ * @copydoc Layer::incremental_forwarding(RunLayerContext &context, unsigned -οΏΌ * int from, unsigned int to, bool training) -οΏΌ */ - WIN_EXPORT void incremental_forwarding(nntrainer::RunLayerContext &context, - unsigned int from, unsigned int to, - bool training) override; - - /** - * @copydoc Layer::calcDerivative(RunLayerContext &context) - */ - WIN_EXPORT void calcDerivative(nntrainer::RunLayerContext &context) override; - - /** - * @copydoc Layer::calcGradient(RunLayerContext &context) - */ - WIN_EXPORT void calcGradient(nntrainer::RunLayerContext &context) override; - - /** - * @copydoc Layer::exportTo(Exporter &exporter, ml::train::ExportMethods - * method) - */ - WIN_EXPORT void - exportTo(nntrainer::Exporter &exporter, - const ml::train::ExportMethods &method) const override; - - /** - * @copydoc Layer::getType() - */ - WIN_EXPORT const std::string getType() const override { - return LmHeadLayer::type; - }; - - /** - * @copydoc Layer::supportBackwarding() - */ - WIN_EXPORT bool supportBackwarding() const override { return false; } - - WIN_EXPORT void updateTensorsByInputDimensions( - nntrainer::RunLayerContext &context, - std::vector input_dimensions) override; - - using Layer::setProperty; - - /** - * @copydoc Layer::setProperty(const PropertyType type, const std::string - * &value) - */ - WIN_EXPORT void setProperty(const std::vector &values) override; - - inline static const std::string type = "lm_head"; - -private: - std::tuple lmhead_props; - std::array weight_idx; /**< indices of the weights */ -}; -} // namespace quick_dot_ai - -#endif -#endif diff --git a/layers/meson.build b/layers/meson.build deleted file mode 100644 index 09ca577e..00000000 --- a/layers/meson.build +++ /dev/null @@ -1,146 +0,0 @@ -quick_dot_ai_layer_inc_abs = [meson.current_source_dir()] -quick_dot_ai_layer_inc = [include_directories('.')] - -quick_dot_ai_rms_norm_src_abs = [meson.current_source_dir() / 'rms_norm.cpp'] -quick_dot_ai_swiglu_src_abs = [meson.current_source_dir() / 'swiglu.cpp'] -quick_dot_ai_tie_word_embedding_src_abs = [meson.current_source_dir() / 'tie_word_embedding.cpp'] -quick_dot_ai_lmhead_src_abs = [meson.current_source_dir() / 'lm_head.cpp'] -quick_dot_ai_mha_core_abs = [meson.current_source_dir()/ 'mha_core.cpp'] -quick_dot_ai_embedding_src_abs = [meson.current_source_dir() / 'embedding_layer.cpp'] -quick_dot_ai_reshaped_rms_norm_src_abs = [meson.current_source_dir() / 'reshaped_rms_norm.cpp'] -quick_dot_ai_qkv_layer_src_abs = [meson.current_source_dir() / 'qkv_layer.cpp'] -quick_dot_ai_embedding_pooling_src_abs = [meson.current_source_dir() / 'embedding_pooling_layer.cpp'] -quick_dot_ai_embedding_normalize_src_abs = [meson.current_source_dir() / 'embedding_normalize_layer.cpp'] - -openmp_dep = dependency('openmp') -quick_dot_ai_rms_norm = shared_library( - 'quick_dot_ai_rms_norm_layer', - quick_dot_ai_rms_norm_src_abs, - include_directories: quick_dot_ai_layer_inc, - dependencies: [nntrainer_dep, nntrainer_ccapi_dep], - install: true, - install_dir: application_install_dir -) -quick_dot_ai_rms_norm_dep = declare_dependency( - link_with: quick_dot_ai_rms_norm, - include_directories: quick_dot_ai_layer_inc -) - -quick_dot_ai_swiglu = shared_library( - 'quick_dot_ai_swiglu_layer', - quick_dot_ai_swiglu_src_abs, - include_directories: quick_dot_ai_layer_inc, - dependencies: [nntrainer_dep, nntrainer_ccapi_dep], - install: true, - install_dir: application_install_dir -) -quick_dot_ai_swiglu_dep = declare_dependency( - link_with: quick_dot_ai_swiglu, - include_directories: quick_dot_ai_layer_inc -) - -quick_dot_ai_tie_word_embedding = shared_library( - 'quick_dot_ai_tie_word_embedding_layer', - quick_dot_ai_tie_word_embedding_src_abs, - include_directories: quick_dot_ai_layer_inc, - dependencies: [nntrainer_dep, nntrainer_ccapi_dep, openmp_dep], - install: true, - install_dir: application_install_dir -) -quick_dot_ai_tie_word_embedding_dep = declare_dependency( - link_with: quick_dot_ai_tie_word_embedding, - include_directories: quick_dot_ai_layer_inc -) - -quick_dot_ai_lm_head = shared_library( - 'quick_dot_ai_lm_head', - quick_dot_ai_lmhead_src_abs, - include_directories: quick_dot_ai_layer_inc, - dependencies: [nntrainer_dep, nntrainer_ccapi_dep], - install: true, - install_dir: application_install_dir -) -quick_dot_ai_lm_head_dep = declare_dependency( - link_with: quick_dot_ai_lm_head, - include_directories: quick_dot_ai_layer_inc -) - -quick_dot_ai_mha_core = shared_library( - 'quick_dot_ai_mha_core_layer', - quick_dot_ai_mha_core_abs, - include_directories: quick_dot_ai_layer_inc, - dependencies: [nntrainer_dep, nntrainer_ccapi_dep, openmp_dep], - install: true, - install_dir: application_install_dir -) -quick_dot_ai_mha_core_dep = declare_dependency( - link_with: quick_dot_ai_mha_core, - include_directories: quick_dot_ai_layer_inc -) - -quick_dot_ai_embedding_layer = shared_library( - 'quick_dot_ai_embedding_layer', - quick_dot_ai_embedding_src_abs, - include_directories: quick_dot_ai_layer_inc, - dependencies: [nntrainer_dep, nntrainer_ccapi_dep, openmp_dep], - install: true, - install_dir: application_install_dir -) -quick_dot_ai_embedding_layer_dep = declare_dependency( - link_with: quick_dot_ai_embedding_layer, - include_directories: quick_dot_ai_layer_inc -) - -quick_dot_ai_embedding_pooling_layer = shared_library( - 'quick_dot_ai_embedding_pooling_layer', - quick_dot_ai_embedding_pooling_src_abs, - include_directories: quick_dot_ai_layer_inc, - dependencies: [nntrainer_dep, nntrainer_ccapi_dep, openmp_dep], - install: true, - install_dir: application_install_dir -) -quick_dot_ai_embedding_pooling_layer_dep = declare_dependency( - link_with: quick_dot_ai_embedding_pooling_layer, - include_directories: quick_dot_ai_layer_inc -) - -quick_dot_ai_embedding_normalize_layer = shared_library( - 'quick_dot_ai_embedding_normalize_layer', - quick_dot_ai_embedding_normalize_src_abs, - include_directories: quick_dot_ai_layer_inc, - dependencies: [nntrainer_dep, nntrainer_ccapi_dep], - install: true, - install_dir: application_install_dir -) - -quick_dot_ai_embedding_normalize_layer_dep = declare_dependency( - link_with: quick_dot_ai_embedding_normalize_layer, - include_directories: quick_dot_ai_layer_inc -) - -quick_dot_ai_reshaped_rms_norm = shared_library( - 'quick_dot_ai_reshaped_rms_norm_layer', - quick_dot_ai_reshaped_rms_norm_src_abs, - include_directories: quick_dot_ai_layer_inc, - dependencies: [nntrainer_dep, nntrainer_ccapi_dep], - install: true, - install_dir: application_install_dir -) -quick_dot_ai_reshaped_rms_norm_dep = declare_dependency( - link_with: quick_dot_ai_reshaped_rms_norm, - include_directories: quick_dot_ai_layer_inc -) - -quick_dot_ai_qkv_layer = shared_library( - 'quick_dot_ai_qkv_layer', - quick_dot_ai_qkv_layer_src_abs, - include_directories: quick_dot_ai_layer_inc, - dependencies: [nntrainer_dep, nntrainer_ccapi_dep], - install: true, - install_dir: application_install_dir -) - -quick_dot_ai_qkv_layer_dep = declare_dependency( - link_with: quick_dot_ai_qkv_layer, - include_directories: quick_dot_ai_layer_inc -) diff --git a/layers/mha_core.cpp b/layers/mha_core.cpp deleted file mode 100644 index 23230091..00000000 --- a/layers/mha_core.cpp +++ /dev/null @@ -1,1318 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -/** - * Copyright (C) 2025 Jijoong Moon - * - * @file mha_core.cpp - * @date 11 July 2025 - * @see https://github.com/nntrainer/nntrainer - * https://arxiv.org/abs/1706.03762 - * @author Jijoong Moon - * @bug No known bugs except for NYI items - * @brief This code is based on custom_multi_head_attention_layer.cpp. - * This code is a part of the break down version of the mha layer. - */ -#include -#include -#include -#include -#include -#include - -static std::mutex rope_init_mtx; - -#include -#include -#include -#include -#include -#include - -#include - -inline float convert_scalar(uint16_t h) { - return nntrainer::compute_fp16_to_fp32(h); -} - -namespace quick_dot_ai { - -#define tile_size 4 - -/************************************************************** */ - -/** - * @brief constructor of MHACoreLayer - */ -MHACoreLayer::MHACoreLayer() : - mha_core_props( - nntrainer::props::NumHeads(), props::NumHeads_KV(), - nntrainer::props::ProjectedKeyDim(), nntrainer::props::ProjectedValueDim(), - nntrainer::props::OutputShape(), nntrainer::props::DropOutRate(), - nntrainer::props::ReturnAttentionWeight(), - nntrainer::props::AverageAttentionWeight(), nntrainer::props::MaxTimestep(), - props::SlidingWindow(), props::MaxNewTokens(), props::RopeTheta(), - props::MaxPositionEmbeddings(), props::UseSink(), props::RopeScalingType(), - props::RopeScalingFactor(), props::RopeScalingMaxPositionEmbeddings(), - props::AttnLogitSoftcapping(), props::IsCausal()), - sm(nntrainer::ActivationType::ACT_SOFTMAX), - epsilon(1e-3), - cache_index(0), - num_heads_Q(0), - num_heads_KV(0), - head_dim(0), - cache_shift(false) { - tensor_idx.fill(std::numeric_limits::max()); -} - -MHACoreLayer::~MHACoreLayer() {} - -/************************************************************** */ - -void MHACoreLayer::finalize(nntrainer::InitLayerContext &context) { - - NNTR_THROW_IF(context.getNumInputs() < 3 || context.getNumInputs() > 4, - std::invalid_argument) - << "Multi head Attention layer needs 3 or 4 inputs. (query, key, value and " - "mask is optional)"; - ml::train::TensorDim::TensorType activation_type = { - context.getFormat(), context.getActivationDataType()}; - ml::train::TensorDim empty_dim(activation_type); - - const std::vector &input_dims = - context.getInputDimensions(); - const ml::train::TensorDim &query_dim = input_dims[INOUT_INDEX::QUERY]; - const ml::train::TensorDim &key_dim = input_dims[INOUT_INDEX::KEY]; - - /** max time step of this model */ - const unsigned int max_timestep = - std::get(mha_core_props).get(); - - /** max position embeddings */ - max_position_embeddings = - std::get(mha_core_props).get(); - - /** local window size */ - local_window_size = std::get(mha_core_props).get(); - - /** attention scaling computation */ - rope_scaling_type = std::get(mha_core_props).get(); - scale = std::get(mha_core_props).get(); - if (rope_scaling_type == "yarn") - original_max_position_embeddings = - std::get(mha_core_props).get(); - - /** query_dim = (B, 1, seq_len, H_Q * Head_Dim ) */ - const unsigned int batch_size = query_dim.batch(); - const unsigned int query_width = query_dim.width(); - /** key_dim = (B, 1, max_seq_len, H_KV * Head_Dim ) */ - const unsigned int key_width = key_dim.width(); - - /** - * @note If NumHeads_KV is set, then use the value. Otherwise, - * we initialize num_heads_KV with num_heads_Q. - */ - num_heads_Q = static_cast( - std::get(mha_core_props).get()); - num_heads_KV = - std::get(mha_core_props).empty() - ? num_heads_Q - : static_cast(std::get(mha_core_props).get()); - - // head_dim - head_dim = static_cast(query_width) / num_heads_Q; - NNTR_THROW_IF(head_dim != key_width / num_heads_KV, std::invalid_argument) - << "num_heads_Q and num_heads_KV are not properly given. Please check the " - "num_heads_* are set correctly so that the `head_dim`s are all same for " - "query / key / value"; - - /** Weight for Sink */ - use_sink = std::get(mha_core_props).get(); - if (use_sink) { -#if ENABLE_FP16 && defined(__ANDROID__) - nntrainer::TensorDim sink_dim( - 1, 1, 1, num_heads_Q, - nntrainer::TensorDim::TensorType(context.getFormat(), - ml::train::TensorDim::DataType::FP16)); -#else - nntrainer::TensorDim sink_dim( - 1, 1, 1, num_heads_Q, - nntrainer::TensorDim::TensorType(context.getFormat(), - context.getActivationDataType())); -#endif - sink_idx = context.requestWeight(sink_dim, nntrainer::Initializer::ZEROS, - nntrainer::WeightRegularizer::NONE, 0.0f, - 0.0f, "sink"); - } - - attn_logit_softcapping = - std::get(mha_core_props).get(); - - /** Is Causal */ - is_causal = std::get(mha_core_props).get(); - - /** Tensor for KV-Cache */ -#ifdef ENABLE_FP16 - ml::train::TensorDim cache_key_dim( - {batch_size, 1, max_timestep, num_heads_KV * head_dim}, - {context.getFormat(), ml::train::TensorDim::DataType::FP16}); - ml::train::TensorDim cache_value_dim( - {batch_size, 1, max_timestep, num_heads_KV * head_dim}, - {context.getFormat(), ml::train::TensorDim::DataType::FP16}); -#else - ml::train::TensorDim cache_key_dim( - {batch_size, 1, max_timestep, num_heads_KV * head_dim}, - {context.getFormat(), ml::train::TensorDim::DataType::UINT16}); - ml::train::TensorDim cache_value_dim( - {batch_size, 1, max_timestep, num_heads_KV * head_dim}, - {context.getFormat(), ml::train::TensorDim::DataType::UINT16}); -#endif - - tensor_idx[AttentionParams::cache_key] = context.requestTensor( - cache_key_dim, "cache_key", nntrainer::Initializer::NONE, false, - nntrainer::TensorLifespan::MAX_LIFESPAN); - tensor_idx[AttentionParams::cache_value] = context.requestTensor( - cache_value_dim, "cache_value", nntrainer::Initializer::NONE, false, - nntrainer::TensorLifespan::MAX_LIFESPAN); - - theta = (float)std::get(mha_core_props).get(); - - /** set Output dimension! - one output */ - std::vector output_dims(1); - output_dims[0] = input_dims[0]; - output_dims[0].width(head_dim * num_heads_Q); - output_dims[0].setTensorType( - {context.getFormat(), context.getActivationDataType()}); - context.setOutputDimensions(output_dims); -} - -/************************************************************** */ - -/** - * @note This forwarding function is used for training mode. - * This will be implemented ASAP. - * @date 2024-09-02 - */ -void MHACoreLayer::forwarding(nntrainer::RunLayerContext &context, - bool training) {} - -/** - * @note This incremental_forwarding method is invoked for inference mode. - * Please note that Transformer Decoder's MHA takes only one sequence at a - * step. Incremental forwarding function is used for this. - */ -void MHACoreLayer::incremental_forwarding(nntrainer::RunLayerContext &context, - unsigned int _from, unsigned int _to, - bool training) { - /// @todo replace step_size into input height - unsigned int step_size = _to - _from; - - unsigned int max_timestep = - std::get(mha_core_props).get(); - - unsigned int from = _from; - unsigned int to = _to; - - if (to >= max_timestep) { - // initial forwarding - if (!_from) { - throw std::invalid_argument( - "to shouldn't greater than max_timestep for initial forwarding"); - } else { - throw std::runtime_error("NYI: cache shift is not available"); - // exceeds the kv_cache size - // KV_cache is shifted! - cache_shift = true; - from = max_timestep - 1; - to = max_timestep; - } - } - - // util fn to compute tensor dimension for one step. - auto get_step_dim = [step_size](const ml::train::TensorDim &dim) { - auto step_dim = dim; - step_dim.batch(1); - step_dim.height(step_size); - return step_dim; - }; - - /** incremental forwarding for each batch */ - nntrainer::Tensor &query = - context.getInput(INOUT_INDEX::QUERY); // projected query - nntrainer::Tensor &key = context.getInput(INOUT_INDEX::KEY); // projected key - nntrainer::Tensor &value = - context.getInput(INOUT_INDEX::VALUE); // projected value - nntrainer::Tensor &output = - context.getOutput(INOUT_INDEX::OUTPUT); // output to be projected - - nntrainer::Tensor &cache_key = - context.getTensor(tensor_idx[AttentionParams::cache_key]); - nntrainer::Tensor &cache_value = - context.getTensor(tensor_idx[AttentionParams::cache_value]); - - nntrainer::Tensor sink; - if (use_sink) { - sink = context.getWeight(sink_idx); - } - - ml::train::TensorDim query_dim = - query.getDim(); // (B, 1, seq_len, n_heads_Q * head_dim) - ml::train::TensorDim key_dim = - key.getDim(); // (B, 1, seq_len, n_heads_KV * head_dim) - ml::train::TensorDim value_dim = - value.getDim(); // (B, 1, seq_len, n_heads_KV * head_dim) - ml::train::TensorDim output_dim = - output.getDim(); // (B, 1, seq_len, n_heads_Q * head_dim) - ml::train::TensorDim cache_key_dim = - cache_key.getDim(); // (B, 1, max_timestep, n_heads_KV * head_dim) - ml::train::TensorDim cache_value_dim = - cache_value.getDim(); // (B, 1, max_timestep, n_heads_KV * head_dim) - - ml::train::TensorDim query_step_dim = - get_step_dim(query_dim); // (1, 1, step_size, n_heads_Q * head_dim) - ml::train::TensorDim key_step_dim = get_step_dim(key_dim); - ml::train::TensorDim value_step_dim = get_step_dim(value_dim); - ml::train::TensorDim output_step_dim = - get_step_dim(output_dim); // (1, 1, step_size, n_heads_Q * head_dim) - ml::train::TensorDim cache_key_step_dim = - get_step_dim(cache_key_dim); // (1, 1, step_size, n_heads_KV * head_dim) - - ml::train::TensorDim cache_value_step_dim = - get_step_dim(cache_value_dim); // (1, 1, step_size, n_heads_KV * head_dim) - - unsigned int batch_size = query_dim.batch(); - // do the incremental forwarding - for (unsigned int batch = 0; batch < batch_size; ++batch) { - - // preparing step tensors - nntrainer::Tensor query_step = query.getSharedDataTensor( - query_step_dim, batch * query_dim.getFeatureLen(), true); - nntrainer::Tensor key_step = key.getSharedDataTensor( - key_step_dim, batch * key_dim.getFeatureLen(), true); - nntrainer::Tensor value_step = value.getSharedDataTensor( - value_step_dim, batch * value_dim.getFeatureLen(), true); - nntrainer::Tensor output_step = output.getSharedDataTensor( - output_step_dim, batch * output_dim.getFeatureLen(), true); - - if (query_step.getDataType() == ml::train::TensorDim::DataType::FP32) { -#if ENABLE_FP16 && defined(__ANDROID__) - nntrainer::TensorDim Q_step_dim = query_step_dim; - nntrainer::TensorDim K_step_dim = key_step_dim; - nntrainer::TensorDim V_step_dim = value_step_dim; - nntrainer::TensorDim O_step_dim = output_step_dim; - Q_step_dim.setDataType(ml::train::TensorDim::DataType::FP16); - K_step_dim.setDataType(ml::train::TensorDim::DataType::FP16); - V_step_dim.setDataType(ml::train::TensorDim::DataType::FP16); - O_step_dim.setDataType(ml::train::TensorDim::DataType::FP16); - - nntrainer::Tensor Q_step = nntrainer::Tensor(Q_step_dim, true); - nntrainer::Tensor K_step = nntrainer::Tensor(K_step_dim, true); - nntrainer::Tensor V_step = nntrainer::Tensor(V_step_dim, true); - nntrainer::Tensor O_step = nntrainer::Tensor(O_step_dim, true); - - Q_step.copyData(query_step); - K_step.copyData(key_step); - V_step.copyData(value_step); - if (use_sink) { - one_batch_incremental_forwarding( - batch, _from, from, to, Q_step, K_step, V_step, O_step, cache_key, - cache_value, cache_key_dim, cache_key_step_dim, cache_value_dim, - cache_value_step_dim, sink); - } else { - one_batch_incremental_forwarding(batch, _from, from, to, Q_step, K_step, - V_step, O_step, cache_key, cache_value, - cache_key_dim, cache_key_step_dim, - cache_value_dim, cache_value_step_dim); - } - output_step.copyData(O_step); -#else - if (use_sink) { - one_batch_incremental_forwarding( - batch, _from, from, to, query_step, key_step, value_step, output_step, - cache_key, cache_value, cache_key_dim, cache_key_step_dim, - cache_value_dim, cache_value_step_dim, sink); - } else { - one_batch_incremental_forwarding( - batch, _from, from, to, query_step, key_step, value_step, output_step, - cache_key, cache_value, cache_key_dim, cache_key_step_dim, - cache_value_dim, cache_value_step_dim); - } -#endif - } else { - one_batch_incremental_forwarding( - batch, _from, from, to, query_step, key_step, value_step, output_step, - cache_key, cache_value, cache_key_dim, cache_key_step_dim, - cache_value_dim, cache_value_step_dim); - } - } - - // increase cache size - cache_index += step_size; -} - -/** - * @brief Function to compute Attention Scores using Tensor inputs. Wrapper - * around nntrainer::compute_kcaches with multi-threading support - * - * Expected Input Shapes: - * @param in (Query): [Batch, 1, sequence_len, Num_Heads_Q * Head_Dim] - * @param cache (Key Cache): [Batch, 1, Max_Timestep, Num_Heads_KV * Head_Dim] - * @param out (Attention Score): [Batch, 1, 1, Num_Heads_Q * Context_Len] - * where Context_Len is usually the current timestep 'to'. - * - */ -void MHACoreLayer::compute_kcaches( - nntrainer::Tensor &in, nntrainer::Tensor &cache, nntrainer::Tensor &out, - unsigned int from, size_t sequence_len, unsigned int num_head, - unsigned int group_size, unsigned int head_dim, BS::thread_pool<> &pool) { - - // Dispatch based on data type (FP32 or FP16) - if (in.getDataType() == ml::train::TensorDim::DataType::FP32) { - if (sequence_len == 1) { - // Single token processing (common during generation) - // Parallelize over KV heads for decoding since Q direction is always 1 - int row_to_compute = is_causal ? from + 1 : from + sequence_len; - unsigned int num_cache_head = num_head / group_size; - - // Use OpenMP for lower overhead parallelization during decoding - const float *in_data = in.getData(); - const uint16_t *cache_data = cache.getData(); - float *out_data = out.getData(); - -#pragma omp parallel for schedule(static) - for (unsigned int head_kv = 0; head_kv < num_cache_head; ++head_kv) { - nntrainer::compute_kcaches( - in_data, cache_data, out_data, row_to_compute, num_cache_head, - head_dim, group_size, tile_size, local_window_size, head_kv, - head_kv + 1); - } - - } else { - // Sequence processing (prefill or chunked) - // Parallelize over the sequence length - std::vector> futures; - int seq = - sequence_len < local_window_size ? sequence_len : local_window_size; - - for (int i = 0; i < seq; ++i) { - float *input_addr = in.getData() + num_head * head_dim * i; - uint16_t *cache_addr = cache.getData(); - int row_to_compute = is_causal ? from + i + 1 : from + sequence_len; - // Calculate dynamic offset for the output (triangle optimization) - size_t out_start_row = - is_causal ? calc_attn_index(from + i) - calc_attn_index(from) - : i * (from + sequence_len); - float *output_addr = out.getData() + out_start_row * num_head; - - futures.emplace_back(pool.submit_task([=]() { - nntrainer::compute_kcaches( - input_addr, cache_addr, output_addr, row_to_compute, - num_head / group_size, head_dim, group_size, tile_size, - local_window_size); - })); - } - for (auto &fut : futures) - fut.get(); - } - } else if (in.getDataType() == ml::train::TensorDim::DataType::FP16) { -#ifdef ENABLE_FP16 - if (sequence_len == 1) { - // Single token processing (common during generation) - // Parallelize over KV heads for decoding since Q direction is always 1 - int num_rows = is_causal ? from + 1 : from + sequence_len; - unsigned int num_cache_head = num_head / group_size; - - // Use OpenMP for lower overhead parallelization during decoding - const _FP16 *in_data = in.getData<_FP16>(); - const _FP16 *cache_data = cache.getData<_FP16>(); - _FP16 *out_data = out.getData<_FP16>(); - -#pragma omp parallel for schedule(static) - for (unsigned int head_kv = 0; head_kv < num_cache_head; ++head_kv) { - nntrainer::compute_kcaches( - in_data, cache_data, out_data, num_rows, num_cache_head, head_dim, - group_size, tile_size, local_window_size, head_kv, head_kv + 1); - } - } else { - std::vector> futures; - unsigned int seq_start = - sequence_len < local_window_size ? 0 : sequence_len - local_window_size; - for (unsigned int i = seq_start; i < sequence_len; ++i) { - _FP16 *input_addr = in.getData<_FP16>() + num_head * head_dim * i; - _FP16 *cache_addr = cache.getData<_FP16>(); - int row_to_compute = is_causal ? from + i + 1 : from + sequence_len; - size_t out_start_row = - is_causal ? calc_attn_index(from + i) - calc_attn_index(from) - : i * (from + sequence_len); - - _FP16 *output_addr = out.getData<_FP16>() + out_start_row * num_head; - - futures.emplace_back(pool.submit_task([=]() { - nntrainer::compute_kcaches(input_addr, cache_addr, output_addr, - row_to_compute, num_head / group_size, - head_dim, group_size, tile_size, - local_window_size); - })); - } - for (auto &fut : futures) - fut.get(); - } -#else - NNTR_THROW_IF(true, std::invalid_argument) << "enable-fp16 is not set!"; -#endif - } -} - -void MHACoreLayer::one_batch_incremental_forwarding( - const unsigned int batch, const unsigned int _from, const unsigned int from, - const unsigned int to, nntrainer::Tensor &query_step, - nntrainer::Tensor &key_step, nntrainer::Tensor &value_step, - nntrainer::Tensor &attention_output_step, nntrainer::Tensor &cache_key, - nntrainer::Tensor &cache_value, ml::train::TensorDim &cache_key_dim, - ml::train::TensorDim &cache_key_step_dim, - ml::train::TensorDim &cache_value_dim, - ml::train::TensorDim &cache_value_step_dim) { - - /** - * - * cache_key - * +------------------------------------------+ - * |<--cache_index-->|<--b_cache_value_step-->| - * +------------------------------------------+ - * |<-------key_step------->| - * |<-------------b_cached_key--------------->| - */ - - // Load Input Tensors of this batch : b_ denotes a Tensor for this batch - auto &pool = - nntrainer::Engine::Global().getThreadPoolManager()->getThreadPool(); - - nntrainer::Tensor b_cache_key_step = cache_key.getSharedDataTensor( - cache_key_step_dim, - batch * cache_key_dim.getFeatureLen() + cache_index * cache_key_dim.width(), - true); - nntrainer::Tensor b_cache_value_step = - cache_value.getSharedDataTensor(cache_value_step_dim, - batch * cache_value_dim.getFeatureLen() + - cache_index * cache_value_dim.width(), - true); - - // apply rotary embedding for query - apply_rotary_emb_tensor_v2(query_step, query_step, head_dim, cache_index, - false); - - // append kcache with rotary embedding - apply_rotary_emb_tensor_v2(key_step, b_cache_key_step, head_dim, cache_index, - false); - - // append vcache without rotary embedding - if (query_step.getDataType() == ml::train::TensorDim::DataType::FP32) { - apply_rotary_emb_tensor_v2(value_step, b_cache_value_step, head_dim, - cache_index, true); - } else if (query_step.getDataType() == ml::train::TensorDim::DataType::FP16) { -#ifdef ENABLE_FP16 - b_cache_value_step.copyData(value_step); -#else - NNTR_THROW_IF(true, std::invalid_argument) << "enable-fp16 is not set!"; -#endif - } - - /// @todo replace step_size into input height - unsigned int step_size = to - from; - unsigned int cache_from = cache_index; - unsigned int cache_to = cache_from + step_size; - - ml::train::TensorDim cached_key_dim = cache_key_dim; - ml::train::TensorDim cached_value_dim = cache_value_dim; - cached_key_dim.height(cache_to); - cached_value_dim.height(cache_to); - - nntrainer::Tensor b_cached_key = cache_key.getSharedDataTensor( - cached_key_dim, batch * cache_key_dim.getFeatureLen(), true); - nntrainer::Tensor b_cached_value = cache_value.getSharedDataTensor( - cached_value_dim, batch * cache_value_dim.getFeatureLen(), true); - - // out_ stores the output of Q * K - nntrainer::Tensor out_( - 1, 1, - is_causal ? (calc_attn_index(cache_to) - calc_attn_index(cache_from)) - : (step_size * cache_to), - num_heads_Q, query_step.getTensorType()); - - unsigned int gqa_size = num_heads_Q / num_heads_KV; - - compute_kcaches(query_step, b_cached_key, out_, cache_from, - cache_to - cache_from, num_heads_Q, gqa_size, head_dim, pool); - - softmax_triangle(out_, step_size, num_heads_Q, cache_from, pool); - - compute_fp16vcache_transposed(out_, b_cached_value, attention_output_step, - cache_from, num_heads_KV, gqa_size, head_dim, - cache_to, pool); -} - -void MHACoreLayer::one_batch_incremental_forwarding( - const unsigned int batch, const unsigned int _from, const unsigned int from, - const unsigned int to, nntrainer::Tensor &query_step, - nntrainer::Tensor &key_step, nntrainer::Tensor &value_step, - nntrainer::Tensor &attention_output_step, nntrainer::Tensor &cache_key, - nntrainer::Tensor &cache_value, ml::train::TensorDim &cache_key_dim, - ml::train::TensorDim &cache_key_step_dim, - ml::train::TensorDim &cache_value_dim, - ml::train::TensorDim &cache_value_step_dim, nntrainer::Tensor &sink_step) { - /// @todo replace from, to into cache_index, input height - /// @note currently, only gpt-oss uses this method - - /** - * cache_key - * +--------+ -> - * | | -> - * | | -> - * |........| from -> - * |........| to -> b_cache_key_step -> b_cached_key - * | | - * +--------+ - * - */ - - /** 1. Load Input Tensors of this batch : b_ denotes a Tensor for this batch - * **/ - auto &pool = - nntrainer::Engine::Global().getThreadPoolManager()->getThreadPool(); - - nntrainer::Tensor b_cache_key_step = cache_key.getSharedDataTensor( - cache_key_step_dim, - batch * cache_key_dim.getFeatureLen() + from * cache_key_dim.width(), true); - nntrainer::Tensor b_cache_value_step = cache_value.getSharedDataTensor( - cache_value_step_dim, - batch * cache_value_dim.getFeatureLen() + from * cache_value_dim.width(), - true); - - apply_rotary_emb_tensor_v2(query_step, query_step, head_dim, _from, false); - - apply_rotary_emb_tensor_v2(key_step, b_cache_key_step, head_dim, _from, - false); - - if (query_step.getDataType() == ml::train::TensorDim::DataType::FP32) { - apply_rotary_emb_tensor_v2(value_step, b_cache_value_step, head_dim, _from, - true); - } else if (query_step.getDataType() == ml::train::TensorDim::DataType::FP16) { -#ifdef ENABLE_FP16 - b_cache_value_step.copyData(value_step); -#else - NNTR_THROW_IF(true, std::invalid_argument) << "enable-fp16 is not set!"; -#endif - } - - ml::train::TensorDim cached_key_dim = cache_key_dim; - ml::train::TensorDim cached_value_dim = cache_value_dim; - cached_key_dim.height(to); - cached_value_dim.height(to); - - nntrainer::Tensor b_cached_key = cache_key.getSharedDataTensor( - cached_key_dim, batch * cache_key_dim.getFeatureLen(), true); - nntrainer::Tensor b_cached_value = cache_value.getSharedDataTensor( - cached_value_dim, batch * cache_value_dim.getFeatureLen(), true); - - nntrainer::Tensor out_( - 1, 1, - is_causal - ? (((to - from) == 1) ? to : calc_attn_index(to) - calc_attn_index(from)) - : ((to - from) * to), - num_heads_Q, query_step.getTensorType()); - - unsigned int gqa_size = num_heads_Q / num_heads_KV; - - compute_kcaches(query_step, b_cached_key, out_, _from, to - from, num_heads_Q, - gqa_size, head_dim, pool); - - softmax_triangle(out_, to - from, num_heads_Q, from, pool, sink_step); - - compute_fp16vcache_transposed(out_, b_cached_value, attention_output_step, - from, num_heads_KV, gqa_size, head_dim, to, - pool); -} - -/************************************************************** */ - -/** - * @brief rotary embedding-related member function - * @note seq_len -> max_position_embeddings - */ -void MHACoreLayer::precompute_freqs(int head_dim, unsigned int seq_len, - float theta, bool is_fp16) { - // compute the freqs only when it is the first time to call this function -#ifdef ENABLE_FP16 - if (freqs_cos_fp16 != nullptr && freqs_cos_fp16->size() == seq_len) - return; -#else - if (freqs_cos != nullptr && freqs_cos->size() == seq_len) - return; -#endif - - if (thetas.empty()) { - if (rope_scaling_type == "default") - _compute_default_parameters(head_dim, theta); - else if (rope_scaling_type == "yarn") - _compute_yarn_parameters(head_dim, theta); - else - NNTR_THROW_IF(true, std::invalid_argument) << "Unsupported rope type!"; - } - - unsigned int half_ = head_dim / 2; - - if (!is_fp16) { - // cos / sin - auto cos = new std::vector>(); - cos->assign(seq_len, std::vector(head_dim, 0)); - auto sin = new std::vector>(); - sin->assign(seq_len, std::vector(head_dim, 0)); - - // update cos / sin frequency - for (unsigned int i = 0; i < seq_len; ++i) { - -#ifdef USE_NEON - nntrainer::calc_trigonometric_vals_dup(half_, thetas.data(), - (*cos)[i].data(), (*sin)[i].data(), - i, attention_scaling); -#else - for (unsigned int j = 0; j < half_; ++j) { - float angle = i * thetas[j]; - (*cos)[i][j] = std::cos(angle) * attention_scaling; - (*cos)[i][j + half_] = - std::cos(angle) * attention_scaling; // repeated 2 times - - (*sin)[i][j] = std::sin(angle) * attention_scaling; - (*sin)[i][j + half_] = - std::sin(angle) * attention_scaling; // repeated 2 times - } -#endif - } - freqs_cos = cos; - freqs_sin = sin; - } - -#ifdef ENABLE_FP16 - if (is_fp16) { - // cos / sin for FP16 - auto cos_fp16 = new std::vector>(); - cos_fp16->assign(seq_len, std::vector<_FP16>(head_dim, 0)); - auto sin_fp16 = new std::vector>(); - sin_fp16->assign(seq_len, std::vector<_FP16>(head_dim, 0)); - - std::vector cos_tmp(head_dim); - std::vector sin_tmp(head_dim); - - for (unsigned int i = 0; i < seq_len; ++i) { -#ifdef USE_NEON - nntrainer::calc_trigonometric_vals_dup(half_, thetas.data(), - cos_tmp.data(), sin_tmp.data(), i, - attention_scaling); -#else - for (unsigned int j = 0; j < half_; ++j) { - float angle = i * thetas[j]; - cos_tmp[j] = std::cos(angle) * attention_scaling; - cos_tmp[j + half_] = - std::cos(angle) * attention_scaling; // repeated 2 times - - sin_tmp[j] = std::sin(angle) * attention_scaling; - sin_tmp[j + half_] = - std::sin(angle) * attention_scaling; // repeated 2 times - } -#endif - for (unsigned int j = 0; j < head_dim; ++j) { - (*cos_fp16)[i][j] = (_FP16)cos_tmp[j]; - (*sin_fp16)[i][j] = (_FP16)sin_tmp[j]; - } - } - freqs_cos_fp16 = cos_fp16; - freqs_sin_fp16 = sin_fp16; - } -#endif -}; - -void MHACoreLayer::_compute_default_parameters(int head_dim, float theta) { - - // no attention scaling - attention_scaling = 1.0f; - - // theta_i = 10000^(-2(i-1)/dim) for i = [1, 2, ... , dim/2] - // head_dim should be divisible by 2 - unsigned int half_ = head_dim / 2; - for (unsigned int i = 0; i < half_; ++i) { - thetas.push_back(1.0 / - (std::pow(theta, (2 * i) / static_cast(head_dim)))); - } -} - -void MHACoreLayer::_compute_yarn_parameters(int head_dim, float theta) { - - // Config parameters - ///@todo partial_rotary_factor should be generalized to fully support - /// transformers's implementation - // const float partial_rotary_factor = has_partial_rotary_factor ? - // config_partial_rotary_factor : 1.0f; - const float partial_rotary_factor = 1.0f; - const int dim = static_cast(head_dim * partial_rotary_factor); - const float base = theta; - - // Handle max position embeddings - - // Attention scaling calculation (simplified from Python version) - auto get_mscale = [](float scale, float mscale = 1.0f) { - return (scale <= 1.0f) ? 1.0f : (0.1f * mscale * std::log(scale) + 1.0f); - }; - - ///@todo attention_scaling should be generalized to fully support - /// transformers's implementation - // if (has_mscale && has_mscale_all_dim) { - // attention_scaling = get_mscale(factor, mscale) / get_mscale(factor, - // mscale_all_dim); - // } else { - // attention_scaling = get_mscale(factor); - // } - attention_scaling = get_mscale(scale); - - ///@todo attention_scaling should be generalized to fully support - /// transformers's implementation - // const float beta_fast = has_beta_fast ? config_beta_fast : 32.0f; - // const float beta_slow = has_beta_slow ? config_beta_slow : 1.0f; - // const bool truncate = has_truncate ? config_truncate : true; - // Beta parameters - const float beta_fast = 32.0f; - const float beta_slow = 1.0f; - const bool truncate = false; - - // Helper functions - auto find_correction_dim = [&](float num_rotations) { - return (dim * std::log(original_max_position_embeddings / - (num_rotations * 2 * M_PI))) / - (2 * std::log(base)); - }; - - auto [low, high] = [&]() { - float low_val = find_correction_dim(beta_fast); - float high_val = find_correction_dim(beta_slow); - if (truncate) { - low_val = std::floor(low_val); - high_val = std::ceil(high_val); - } - return std::make_pair(low_val, high_val); - }(); - - // Compute position frequencies - thetas.resize(dim / 2); - - // Compute interpolation and extrapolation frequencies - std::vector inv_freq_interpolation; - std::vector inv_freq_extrapolation; - for (size_t i = 0; i < dim / 2; ++i) { - inv_freq_extrapolation.push_back( - 1.0 / (std::pow(theta, (2 * i) / static_cast(head_dim)))); - inv_freq_interpolation.push_back( - 1.0 / (scale * std::pow(theta, (2 * i) / static_cast(head_dim)))); - } - - auto linear_ramp_factor = [](float min, float max, int size) { - if (min == max) { - max += 0.001f; // Prevent singularity - } - std::vector ramp(size); - for (int i = 0; i < size; ++i) { - float val = (i - min) / (max - min); - ramp[i] = std::clamp(val, 0.0f, 1.0f); - } - return ramp; - }; - - std::vector inv_freq_extrapolation_factor = - linear_ramp_factor(low, high, dim / 2); - for (auto &val : inv_freq_extrapolation_factor) { - val = 1.0f - val; - } - - // Combine frequencies - for (size_t i = 0; i < thetas.size(); ++i) { - thetas[i] = - inv_freq_extrapolation[i] * inv_freq_extrapolation_factor[i] + - inv_freq_interpolation[i] * (1.0f - inv_freq_extrapolation_factor[i]); - } -} - -void MHACoreLayer::apply_rotary_emb_tensor_v2(nntrainer::Tensor &in, - nntrainer::Tensor &out, - unsigned int dim, - unsigned int from, - bool convert_only) { - unsigned int half_ = dim / 2; - unsigned int max_timestep = - std::get(mha_core_props).get(); - - if (in.getDataType() == ml::train::TensorDim::DataType::FP32) { - if (freqs_cos == nullptr) { - const std::lock_guard lock(rope_init_mtx); - if (freqs_cos == nullptr) { - precompute_freqs(head_dim, max_position_embeddings, theta, false); - } - } - std::vector *cos_ = nullptr; - std::vector *sin_ = nullptr; - - for (unsigned int b = 0; b < in.batch(); b++) { - for (unsigned int c = 0; c < in.channel(); c++) { - for (unsigned int h = 0; h < in.height(); h++) { - if (from < max_timestep) { - cos_ = &(*freqs_cos)[from + h]; - sin_ = &(*freqs_sin)[from + h]; - } - float *in_ptr = in.getData() + - b * in.channel() * in.height() * in.width() + - c * in.height() * in.width() + h * in.width(); - - if (out.getDataType() == ml::train::TensorDim::DataType::FP32) { - - nntrainer::compute_rotary_emb_value(in.width(), dim, half_, in_ptr, - nullptr, cos_->data(), - sin_->data(), convert_only); - } else if (out.getDataType() == - ml::train::TensorDim::DataType::UINT16 || - out.getDataType() == - ml::train::TensorDim::DataType::FP16) { - uint16_t *out_ptr = out.getData() + - b * out.channel() * out.height() * out.width() + - c * out.height() * out.width() + - h * out.width(); - - nntrainer::compute_rotary_emb_value(in.width(), dim, half_, in_ptr, - out_ptr, cos_->data(), - sin_->data(), convert_only); - } - } - } - } - } else if (in.getDataType() == ml::train::TensorDim::DataType::FP16) { -#ifdef ENABLE_FP16 - if (freqs_cos_fp16 == nullptr) { - const std::lock_guard lock(rope_init_mtx); - if (freqs_cos_fp16 == nullptr) { - precompute_freqs(head_dim, max_position_embeddings, theta, true); - } - } - std::vector<_FP16> *cos_ = nullptr; - std::vector<_FP16> *sin_ = nullptr; - - for (unsigned int b = 0; b < in.batch(); b++) { - for (unsigned int c = 0; c < in.channel(); c++) { - for (unsigned int h = 0; h < in.height(); h++) { - if (from < max_timestep) { - cos_ = &(*freqs_cos_fp16)[from + h]; - sin_ = &(*freqs_sin_fp16)[from + h]; - } - _FP16 *in_ptr = in.getData<_FP16>() + - b * in.channel() * in.height() * in.width() + - c * in.height() * in.width() + h * in.width(); - _FP16 *out_ptr = out.getData<_FP16>() + - b * out.channel() * out.height() * out.width() + - c * out.height() * out.width() + h * out.width(); - - nntrainer::compute_rotary_emb_value(in.width(), dim, half_, in_ptr, - out_ptr, cos_->data(), - sin_->data()); - } - } - } -#else - NNTR_THROW_IF(true, std::invalid_argument) << "enable-fp16 is not set!"; -#endif - } -} - -void MHACoreLayer::softmax_triangle(nntrainer::Tensor &qk_out, size_t row, - size_t num_head, unsigned int from, - BS::thread_pool<> &pool) { - if (qk_out.getDataType() == ml::train::TensorDim::DataType::FP32) { - float *qk_out_ = qk_out.getData(); - - if (attn_logit_softcapping > 0.0f) { - size_t len = - qk_out.batch() * qk_out.height() * qk_out.width() * qk_out.channel(); - float inv_softcapping = 1.0f / attn_logit_softcapping; - for (size_t i = 0; i < len; ++i) { - qk_out_[i] = - std::tanh(qk_out_[i] * inv_softcapping) * attn_logit_softcapping; - } - } - - if (row == 1) { - size_t start_row = 0; - size_t end_row = 0; - if (is_causal) { - end_row = from < local_window_size ? from + 1 : local_window_size; - } else { - end_row = from + row; // end_row = to - } - nntrainer::softmax_row_inplace(qk_out_, start_row, end_row, num_head); - } else { - std::vector> futures; - int seq = row < local_window_size ? row : local_window_size; - if (!is_causal) - seq = row; - - for (int i = 0; i < seq; ++i) { - size_t start_row, end_row; - if (is_causal) { - start_row = calc_attn_index(from + i) - calc_attn_index(from); - end_row = calc_attn_index(from + i + 1) - calc_attn_index(from); - } else { - unsigned int to = from + row; - start_row = i * to; - end_row = (i + 1) * to; - } - futures.push_back(pool.submit_task([=]() { - nntrainer::softmax_row(qk_out_, start_row, end_row, num_head); - })); - } - for (auto &fut : futures) { - fut.get(); - } - } - } else if (qk_out.getDataType() == ml::train::TensorDim::DataType::FP16) { -#ifdef ENABLE_FP16 - _FP16 *qk_out_ = qk_out.getData<_FP16>(); - - if (attn_logit_softcapping > 0.0f) { - size_t len = - qk_out.batch() * qk_out.height() * qk_out.width() * qk_out.channel(); - float inv_softcapping = 1.0f / attn_logit_softcapping; - for (size_t i = 0; i < len; ++i) { - qk_out_[i] = (_FP16)(std::tanh((float)qk_out_[i] * inv_softcapping) * - attn_logit_softcapping); - } - } - - if (row == 1) { - size_t start_row = 0; - size_t end_row = 0; - if (is_causal) { - end_row = from < local_window_size ? from + 1 : local_window_size; - } else { - end_row = from + row; // end_row = to - } - nntrainer::softmax_row_inplace(qk_out_, start_row, end_row, num_head); - } else { - std::vector> futures; - int seq = row < local_window_size ? row : local_window_size; - if (!is_causal) - seq = row; - - for (int i = 0; i < seq; ++i) { - size_t start_row, end_row; - if (is_causal) { - start_row = calc_attn_index(from + i) - calc_attn_index(from); - end_row = calc_attn_index(from + i + 1) - calc_attn_index(from); - } else { - unsigned int to = from + row; - start_row = i * to; - end_row = (i + 1) * to; - } - futures.push_back(pool.submit_task([=]() { - nntrainer::softmax_row_inplace(qk_out_, start_row, end_row, num_head); - })); - } - for (auto &fut : futures) { - fut.get(); - } - } -#else - NNTR_THROW_IF(true, std::invalid_argument) << "enable-fp16 is not set!"; -#endif - } -} - -void MHACoreLayer::softmax_triangle(nntrainer::Tensor &qk_out, size_t row, - size_t num_head, unsigned int from, - BS::thread_pool<> &pool, - nntrainer::Tensor &sink_step) { - if (qk_out.getDataType() == ml::train::TensorDim::DataType::FP32) { - float *qk_out_ = qk_out.getData(); - - if (attn_logit_softcapping > 0.0f) { - size_t len = - qk_out.batch() * qk_out.height() * qk_out.width() * qk_out.channel(); - float inv_softcapping = 1.0f / attn_logit_softcapping; - for (size_t i = 0; i < len; ++i) { - qk_out_[i] = - std::tanh(qk_out_[i] * inv_softcapping) * attn_logit_softcapping; - } - } - - if (row == 1) { - size_t start_row = 0; - size_t end_row = 0; - if (is_causal) { - end_row = from < local_window_size ? from + 1 : local_window_size; - } else { - unsigned int to = from + row; - end_row = to; - } - nntrainer::softmax_row_inplace(qk_out_, start_row, end_row, num_head, - sink_step.getData()); - } else { - std::vector> futures; - - int seq = row < local_window_size ? row : local_window_size; - if (!is_causal) - seq = row; - - for (int i = 0; i < seq; ++i) { - size_t start_row, end_row; - if (is_causal) { - start_row = calc_attn_index(i + from) - calc_attn_index(from); - end_row = calc_attn_index(from + i + 1) - calc_attn_index(from); - } else { - unsigned int to = from + row; - start_row = i * to; - end_row = (i + 1) * to; - } - futures.push_back(pool.submit_task([=]() { - nntrainer::softmax_row(qk_out_, start_row, end_row, num_head, - sink_step.getData()); - })); - } - for (auto &fut : futures) { - fut.get(); - } - } - } else if (qk_out.getDataType() == ml::train::TensorDim::DataType::FP16) { -#ifdef ENABLE_FP16 - _FP16 *qk_out_ = qk_out.getData<_FP16>(); - _FP16 *sink_step_ = sink_step.getData<_FP16>(); - - if (attn_logit_softcapping > 0.0f) { - size_t len = - qk_out.batch() * qk_out.height() * qk_out.width() * qk_out.channel(); - float inv_softcapping = 1.0f / attn_logit_softcapping; - for (size_t i = 0; i < len; ++i) { - qk_out_[i] = (_FP16)(std::tanh((float)qk_out_[i] * inv_softcapping) * - attn_logit_softcapping); - } - } - - if (row == 1) { - size_t start_row = 0; - size_t end_row = 0; - if (is_causal) { - end_row = from < local_window_size ? from + 1 : local_window_size; - } else { - end_row = from + row; // end_row = to - } - nntrainer::softmax_row_inplace(qk_out_, start_row, end_row, num_head, - sink_step_); - } else { - std::vector> futures; - int seq = row < local_window_size ? row : local_window_size; - if (!is_causal) - seq = row; - - for (int i = 0; i < seq; ++i) { - size_t start_row = calc_attn_index(i + from) - calc_attn_index(from); - size_t end_row = calc_attn_index(from + i + 1) - calc_attn_index(from); - futures.push_back(pool.submit_task([=]() { - nntrainer::softmax_row(qk_out_, start_row, end_row, num_head, - sink_step_); - })); - } - for (auto &fut : futures) { - fut.get(); - } - } -#else - NNTR_THROW_IF(true, std::invalid_argument) << "enable-fp16 is not set!"; -#endif - } -} - -void MHACoreLayer::compute_fp16vcache_transposed( - nntrainer::Tensor &in, nntrainer::Tensor &vcache, nntrainer::Tensor &output, - int from, int num_cache_head, int gqa_size, int head_dim, int to, - BS::thread_pool<> &pool) { - - if (in.getDataType() == ml::train::TensorDim::DataType::FP32) { - if ((to - from) != 1) { - std::vector> futures; - - int seq = (to - from) < local_window_size ? to - from : local_window_size; - // if non-causal, seq is practically to - from. - if (!is_causal) - seq = to - from; - futures.reserve(seq); - - for (int i = 0; i < seq; ++i) { - futures.push_back(pool.submit_task([=]() { - size_t start_idx; - if (is_causal) { - start_idx = - calc_attn_index(to - seq + i) - calc_attn_index(to - seq); - } else { - start_idx = i * to; // linear index - } - const float *input = - in.getData() + start_idx * num_cache_head * gqa_size; - float *out = output.getData() + - i * (num_cache_head * gqa_size * head_dim); - - int row_num = is_causal ? (to - seq + i) : to - 1; - nntrainer::compute_fp16vcache_fp32_transposed( - row_num, input, vcache.getData(), out, num_cache_head, - gqa_size, head_dim, local_window_size); - })); - } - for (auto &fut : futures) - fut.get(); - } else { - // Single token processing (common during generation) - // Parallelize over KV heads for decoding since Q direction is always 1 - int row_num = to - 1; - - // Use OpenMP for lower overhead parallelization during decoding - const float *in_data = in.getData(); - const uint16_t *vcache_data = vcache.getData(); - float *output_data = output.getData(); - -#pragma omp parallel for schedule(static) - for (int head_kv = 0; head_kv < num_cache_head; ++head_kv) { - nntrainer::compute_fp16vcache_fp32_transposed( - row_num, in_data, vcache_data, output_data, num_cache_head, gqa_size, - head_dim, local_window_size, head_kv, head_kv + 1); - } - } - } else if (in.getDataType() == ml::train::TensorDim::DataType::FP16) { -#ifdef ENABLE_FP16 - if ((to - from) != 1) { - std::vector> futures; - int seq = (to - from) < local_window_size ? to - from : local_window_size; - if (!is_causal) - seq = to - from; - futures.reserve(seq); - - for (int i = 0; i < seq; ++i) { - futures.push_back(pool.submit_task([=]() { - size_t start_idx; - if (is_causal) { - start_idx = - calc_attn_index(to - seq + i) - calc_attn_index(to - seq); - } else { - start_idx = i * to; - } - const _FP16 *input = - in.getData<_FP16>() + start_idx * num_cache_head * gqa_size; - _FP16 *out = output.getData<_FP16>() + - i * (num_cache_head * gqa_size * head_dim); - int row_num = is_causal ? (to - seq + i) : to - 1; - nntrainer::compute_fp16vcache_transposed( - row_num, input, vcache.getData<_FP16>(), out, num_cache_head, - gqa_size, head_dim, local_window_size); - })); - } - for (auto &fut : futures) - fut.get(); - } else { - // Single token processing (common during generation) - // Parallelize over KV heads for decoding since Q direction is always 1 - int row_num = to - 1; - - // Use OpenMP for lower overhead parallelization during decoding - const _FP16 *in_data = in.getData<_FP16>(); - const _FP16 *vcache_data = vcache.getData<_FP16>(); - _FP16 *output_data = output.getData<_FP16>(); - -#pragma omp parallel for schedule(static) - for (int head_kv = 0; head_kv < num_cache_head; ++head_kv) { - nntrainer::compute_fp16vcache_transposed( - row_num, in_data, vcache_data, output_data, num_cache_head, gqa_size, - head_dim, local_window_size, head_kv, head_kv + 1); - } - } -#else - NNTR_THROW_IF(true, std::invalid_argument) << "enable-fp16 is not set!"; -#endif - } -} - -void MHACoreLayer::setBatch(nntrainer::RunLayerContext &context, - unsigned int batch) { - - const float dropout_rate = - std::get(mha_core_props).get(); - context.updateTensor(tensor_idx[AttentionParams::cache_key], batch); - context.updateTensor(tensor_idx[AttentionParams::cache_value], batch); - // context.updateTensor(tensor_idx[AttentionParams::attention_weight], batch); - if (dropout_rate > epsilon) { - context.updateTensor(tensor_idx[AttentionParams::dropout_mask], batch); - } -} - -void MHACoreLayer::updateTensorsByInputDimensions( - nntrainer::RunLayerContext &context, - std::vector input_dimensions) { - unsigned int height = input_dimensions[0].height(); - unsigned int &max_timestep = - std::get(mha_core_props).get(); - unsigned int &max_new_tokens = - std::get(mha_core_props).get(); - max_position_embeddings = - std::get(mha_core_props).get(); - max_timestep = height + max_new_tokens; - - ml::train::TensorDim kv_dim = input_dimensions[0]; - kv_dim.width(kv_dim.width() / (num_heads_Q / num_heads_KV)); - - ml::train::TensorDim kv_cache_dim = kv_dim; -#ifdef ENABLE_FP16 - kv_cache_dim.setDataType(ml::train::TensorDim::DataType::FP16); -#else - kv_cache_dim.setDataType(ml::train::TensorDim::DataType::UINT16); -#endif - kv_cache_dim.height(max_timestep); - - context.updateInput(INOUT_INDEX::QUERY, input_dimensions[0]); - context.updateInput(INOUT_INDEX::KEY, kv_dim); - context.updateInput(INOUT_INDEX::VALUE, kv_dim); - context.updateOutput(0, input_dimensions[0]); - - context.updateTensor(tensor_idx[AttentionParams::cache_key], kv_cache_dim); - context.updateTensor(tensor_idx[AttentionParams::cache_value], kv_cache_dim); -} - -void MHACoreLayer::calcDerivative(nntrainer::RunLayerContext &context) {} - -void MHACoreLayer::calcGradient(nntrainer::RunLayerContext &context) {} - -void MHACoreLayer::exportTo(nntrainer::Exporter &exporter, - const ml::train::ExportMethods &method) const { - LayerImpl::exportTo(exporter, method); - exporter.saveResult(mha_core_props, method, this); -} - -void MHACoreLayer::setProperty(const std::vector &values) { - auto remain_props = loadProperties(values, mha_core_props); - LayerImpl::setProperty(remain_props); -} - -size_t MHACoreLayer::calc_attn_index(size_t i) { return (i * (i + 1)) / 2; }; - -#ifdef PLUGGABLE - -nntrainer::Layer *create_mha_core_layer() { - auto layer = new MHACoreLayer(); - return layer; -} - -void destroy_mha_core_layer(nntrainer::Layer *layer) { delete layer; } - -extern "C" { -nntrainer::LayerPluggable ml_train_layer_pluggable{create_mha_core_layer, - destroy_mha_core_layer}; -} - -#endif - -} // namespace quick_dot_ai diff --git a/layers/mha_core.h b/layers/mha_core.h deleted file mode 100644 index 7c3271e0..00000000 --- a/layers/mha_core.h +++ /dev/null @@ -1,453 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -/** - * Copyright (C) 2025 Jijoong Moon - * - * @file mha_core.h - * @date 02 September 2024 - * @see https://github.com/nntrainer/nntrainer - * https://arxiv.org/abs/1706.03762 - * @author Jijoong Moon - * @bug No known bugs except for NYI items - * @brief This is custom_mha_core layer supports - * the work of multi_head_attention. - * @note Unlike custom_multi_head_attention_layer, - * which works all of the attention operations - * in a layer, this layer is attached after Q / K / V - * fully connected layer to post-process them - * including KV-Cache. - * For inference, incremental_forwarding is called, - * which takes inputs of seq_len = 1 via `from` / `to` param. - * For training, forwarding is called, - * which takes all input seqences at once. - */ - -#ifndef __MHA_CORE_H__ -#define __MHA_CORE_H__ - -#pragma once -#ifdef _WIN32 -#define WIN_EXPORT __declspec(dllexport) -#else -#define WIN_EXPORT -#endif - -#include - -#include -#include -#include -#include -#include -#include -#include - -#include - -namespace quick_dot_ai { - -namespace props { - -/** - * @brief NumHeads property, NumHeads is number of head in multi head attention - * of Q - */ -class NumHeads_KV : public nntrainer::PositiveIntegerProperty { -public: - /** - * @brief Construct a new NumHeads object with default value 1 - */ - NumHeads_KV(unsigned int value = 1) { set(value); }; - static constexpr const char *key = - "num_heads_KV"; /**< unique key to access */ - using prop_tag = nntrainer::uint_prop_tag; /**< property type */ -}; - -/** - * @brief SlidingWindow - */ -class SlidingWindow : public nntrainer::Property { -public: - SlidingWindow(unsigned int value = UINT_MAX) { set(value); }; - static constexpr const char *key = - "sliding_window"; /**< unique key to access */ - using prop_tag = nntrainer::uint_prop_tag; /**< property type */ -}; - -/** - * @brief MaxNewTokens - */ -class MaxNewTokens : public nntrainer::Property { -public: - MaxNewTokens(unsigned int value = 1) { set(value); }; - static constexpr const char *key = - "max_new_tokens"; /**< unique key to access */ - using prop_tag = nntrainer::uint_prop_tag; /**< property type */ -}; - -/** - * @brief MaxNewTokens - */ -class MaxPositionEmbeddings : public nntrainer::Property { -public: - MaxPositionEmbeddings(unsigned int value = 40960) { set(value); }; - static constexpr const char *key = - "max_position_embeddings"; /**< unique key to access */ - using prop_tag = nntrainer::uint_prop_tag; /**< property type */ -}; - -/** - * @brief RopeTheta - */ -class RopeTheta : public nntrainer::Property { -public: - RopeTheta(unsigned int value = 500000) { set(value); }; - static constexpr const char *key = "rope_theta"; /**< unique key to access */ - using prop_tag = nntrainer::uint_prop_tag; /**< property type */ -}; - -/** - * @brief UseSink property - */ -class UseSink : public nntrainer::Property { -public: - UseSink(bool value = false) { set(value); }; - static constexpr const char *key = "use_sink"; /**< unique key to access */ - using prop_tag = nntrainer::bool_prop_tag; /**< property type */ -}; - -/** - * @brief AttnLogitSoftcapping - */ -class AttnLogitSoftcapping : public nntrainer::Property { -public: - AttnLogitSoftcapping(float value = 0.0f) { set(value); }; - static constexpr const char *key = - "attn_logit_softcapping"; /**< unique key to access */ - using prop_tag = nntrainer::float_prop_tag; /**< property type */ -}; - -/** - * @brief IsCausal property - */ -class IsCausal : public nntrainer::Property { -public: - IsCausal(bool value = true) { set(value); }; - static constexpr const char *key = "is_causal"; /**< unique key to access */ - using prop_tag = nntrainer::bool_prop_tag; /**< property type */ -}; - -/** - * @brief RopeScalingType - * - default - * - yarn - */ -class RopeScalingType : public nntrainer::Property { -public: - RopeScalingType(std::string value = "default") { set(value); }; - static constexpr const char *key = - "rope_scaling_type"; /**< unique key to access */ - using prop_tag = nntrainer::str_prop_tag; /**< property type */ -}; -/** - * @brief RopeScalingFactor - */ -class RopeScalingFactor : public nntrainer::Property { -public: - RopeScalingFactor(float value = 1.0) { set(value); }; - static constexpr const char *key = - "rope_scaling_factor"; /**< unique key to access */ - using prop_tag = nntrainer::float_prop_tag; /**< property type */ -}; - -/** - * @brief RopeScalingMaxPositionEmbeddings - */ -class RopeScalingMaxPositionEmbeddings - : public nntrainer::Property { -public: - RopeScalingMaxPositionEmbeddings(unsigned int value = 4096) { set(value); }; - static constexpr const char *key = - "rope_scaling_max_position_embeddings"; /**< unique key to access */ - using prop_tag = nntrainer::uint_prop_tag; /**< property type */ -}; - -}; // namespace props - -/** - * @class MHA Core Layer - * @brief Part of Multi-Head-Attention Layer. - * It should be attached after Q / K / V fc layers and before O fc layer. - * custom_mha_core_layer computes attention, while updating KV-cache for - * inference mode. - * - * [ Q ] [ K ] [ V ] - * | | | - * [ mha_core ] - * | - * [ O ] - * - */ -WIN_EXPORT class MHACoreLayer : public nntrainer::LayerImpl { -public: - /** - * @brief Constructor of MhaCore Layer - */ - WIN_EXPORT MHACoreLayer(); - - /** - * @brief Destructor of MhaPost Layer - */ - WIN_EXPORT ~MHACoreLayer(); - - /** - * @brief Move constructor of CustomMultiHeadAttentionLayer. - * @param[in] CustomMultiHeadAttentionLayer && - */ - WIN_EXPORT - MHACoreLayer(MHACoreLayer &&rhs) noexcept = default; - - /** - * @brief Move assignment operator. - * @parma[in] rhs CustomMultiHeadAttentionLayer to be moved. - */ - WIN_EXPORT MHACoreLayer &operator=(MHACoreLayer &&rhs) = default; - - /** - * @brief Finalize funciton of MhaCore Layer - */ - WIN_EXPORT void finalize(nntrainer::InitLayerContext &context) override; - - /** - * @brief forwarding function of MhaCore Layer - * Please note that forwarding function is used only for training. - */ - WIN_EXPORT void forwarding(nntrainer::RunLayerContext &context, - bool training) override; - - void one_batch_incremental_forwarding( - const unsigned int batch, const unsigned int _from, const unsigned int from, - const unsigned int to, nntrainer::Tensor &query_step, - nntrainer::Tensor &key_step, nntrainer::Tensor &value_step, - nntrainer::Tensor &attention_output_step, nntrainer::Tensor &cache_key, - nntrainer::Tensor &cache_value, ml::train::TensorDim &cache_key_dim, - ml::train::TensorDim &cache_key_step_dim, - ml::train::TensorDim &cache_value_dim, - ml::train::TensorDim &cache_value_step_dim); - - void one_batch_incremental_forwarding( - const unsigned int batch, const unsigned int _from, const unsigned int from, - const unsigned int to, nntrainer::Tensor &query_step, - nntrainer::Tensor &key_step, nntrainer::Tensor &value_step, - nntrainer::Tensor &attention_output_step, nntrainer::Tensor &cache_key, - nntrainer::Tensor &cache_value, ml::train::TensorDim &cache_key_dim, - ml::train::TensorDim &cache_key_step_dim, - ml::train::TensorDim &cache_value_dim, - ml::train::TensorDim &cache_value_step_dim, nntrainer::Tensor &sink_step); - /** - * @copydoc Layer::calcDerivative(RunLayerContext &context) - */ - WIN_EXPORT void incremental_forwarding(nntrainer::RunLayerContext &context, - unsigned int from, unsigned int to, - bool training) override; - - /** - * @copydoc Layer::calcDerivative(RunLayerContext &context) - */ - WIN_EXPORT void calcDerivative(nntrainer::RunLayerContext &context) override; - - /** - * @copydoc Layer::calcGradient(RunLayerContext &context) - */ - WIN_EXPORT void calcGradient(nntrainer::RunLayerContext &context) override; - - /** - * @copydoc bool supportBackwarding() const - * @note In current version, we do not support backwarding yet. - * It will be updated ASAP. - */ - WIN_EXPORT bool supportBackwarding() const override { return true; }; - - /** - * @copydoc Layer::setBatch(RunLayerContext &context, unsigned int batch) - */ - WIN_EXPORT void - exportTo(nntrainer::Exporter &exporter, - const ml::train::ExportMethods &method) const override; - - /** - * @copydoc Layer::setProperty(const std::vector &values) - */ - WIN_EXPORT void setProperty(const std::vector &values) override; - - /** - * @copydoc Layer::getType() - */ - WIN_EXPORT const std::string getType() const override { - return MHACoreLayer::type; - }; - - /** - * @copydoc Layer::setBatch(RunLayerContext &context, unsigned int batch) - */ - WIN_EXPORT void setBatch(nntrainer::RunLayerContext &context, - unsigned int batch) override; - - WIN_EXPORT void updateTensorsByInputDimensions( - nntrainer::RunLayerContext &context, - std::vector input_dimensions) override; - - inline static const std::string type = "mha_core"; - -private: - std::tuple< - nntrainer::props::NumHeads, props::NumHeads_KV, - nntrainer::props::ProjectedKeyDim, nntrainer::props::ProjectedValueDim, - nntrainer::props::OutputShape, nntrainer::props::DropOutRate, - nntrainer::props::ReturnAttentionWeight, - nntrainer::props::AverageAttentionWeight, nntrainer::props::MaxTimestep, - props::SlidingWindow, props::MaxNewTokens, props::RopeTheta, - props::MaxPositionEmbeddings, props::UseSink, props::RopeScalingType, - props::RopeScalingFactor, props::RopeScalingMaxPositionEmbeddings, - props::AttnLogitSoftcapping, props::IsCausal> - mha_core_props; /**< mha_core layer properties */ - - /** softmax activation operation */ - nntrainer::ActiFunc sm; - - float epsilon; /** to avoid overflow */ - unsigned int cache_index; /** idx of kv cache */ - - /** intermal info */ - size_t num_heads_Q; - size_t num_heads_KV; - size_t head_dim; - bool cache_shift; - float theta; - size_t local_window_size; - bool use_sink = false; - float attn_logit_softcapping = 0.0f; - bool is_causal; - - enum INOUT_INDEX { - /** input index */ - QUERY = 0, - KEY = 1, - VALUE = 2, - MASK = 3, - - /** output index */ - OUTPUT = 0, - RETURN_ATTENTION_WEIGHT = 1, - }; - - /**< indices of the weights and tensors */ - enum AttentionParams { - cache_key, - cache_value, - projected_key, - projected_value, - /** intended comment for later use of attention_mask */ - // attention_mask, - attention_weight, - dropout_mask, - attention_output, - }; - std::array tensor_idx; - unsigned int sink_idx; - - /** attention parameters */ - unsigned int max_position_embeddings; - - /** rope_scaling parameters */ - std::string rope_scaling_type; - float attention_scaling = 1.0f; - float mscale = 1.0f; - float scale = 1.0f; - unsigned int original_max_position_embeddings = 4096; - - /****************** ROTARY EMBEDDING *****************/ - /** static variable - they are all expected to be initialized once */ - inline static std::vector> *freqs_cos = {}; - inline static std::vector> *freqs_sin = {}; - inline static std::vector thetas; -#ifdef ENABLE_FP16 - inline static std::vector> *freqs_cos_fp16 = {}; - inline static std::vector> *freqs_sin_fp16 = {}; -#endif - - /** - * @brief pre_compute frequencies for Rotary Embedding. - * @note it is expected to be called only once at the finalize. - * @param[in] head_dim dimension of head - * @param[in] seq_len sequence length - * @param[in] theta base of theta (default = 10000) - */ - void precompute_freqs(int head_dim, unsigned int seq_len, - float theta = 10000.0, bool is_fp16 = false); - - /** - * @brief _compute frequency parameters for default ROPE - */ - void _compute_default_parameters(int head_dim, float theta); - - /** - * @brief _compute frequency parameters for default ROPE - */ - void _compute_yarn_parameters(int head_dim, float theta); - - /** - * @brief apply rotary embedding - * @param[in] in input tensor - * @param[out] out output tensor - * @param[in] dim hidden dim size - * @param[in] from sequence order - * @param[in] convert_only - conversion only - */ - void apply_rotary_emb_tensor_v2(nntrainer::Tensor &in, nntrainer::Tensor &out, - unsigned int dim, unsigned int from, - bool convert_only = false); - - template - void compute(const float *A, const BType *B, float *output, int num_rows, - int N, int chunk_size, int group_size, int tile_size, - bool process_all); - - void compute_kcaches(nntrainer::Tensor &in, nntrainer::Tensor &cache, - nntrainer::Tensor &out, unsigned int from, - size_t sequence_len, unsigned int num_heads, - unsigned int group_size, unsigned int head_dim, - BS::thread_pool<> &pool); - - void softmax_triangle(nntrainer::Tensor &qk_out, size_t row, size_t num_heads, - unsigned int from, BS::thread_pool<> &pool); - - void softmax_triangle(nntrainer::Tensor &qk_out, size_t row, size_t num_heads, - unsigned int from, BS::thread_pool<> &pool, - nntrainer::Tensor &sink_step); - - void compute_vcaches(nntrainer::Tensor &in, nntrainer::Tensor &vcache, - nntrainer::Tensor &out, unsigned int from, - size_t sequence_len, unsigned int num_heads, - unsigned int group_size, unsigned int head_dim); - - void compute_fp16vcache_transposed(nntrainer::Tensor &in, - nntrainer::Tensor &vcache, - nntrainer::Tensor &output, int from, - int num_cache_head, int gqa_size, - int head_dim, int to, - BS::thread_pool<> &pool); - - /************** END OF ROTARY EMBEDDING *************/ - - /** - * @brief calculate common derivative - * @param context Context of the layer - */ - void calcCommonDerivative(nntrainer::RunLayerContext &context); - - size_t calc_attn_index(size_t i); - -}; // end of class MHACoreLayer -} // namespace quick_dot_ai - -#endif diff --git a/layers/qkv_layer.cpp b/layers/qkv_layer.cpp deleted file mode 100644 index 067f7ce8..00000000 --- a/layers/qkv_layer.cpp +++ /dev/null @@ -1,206 +0,0 @@ -/** - * Copyright (C) 2020 Samsung Electronics Co., Ltd. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * http://www.apache.org/licenses/LICENSE-2.0 - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - * - * @file qkv_layer.cpp - * @date 14 May 2020 - * @brief This is Fully Connected Layer Class for Neural Network - * @see https://github.com/nntrainer/nntrainer - * @author Eunju Yang - * @bug No known bugs except for NYI items - * - */ - -#include - -#include -#include -#include -#include -#include -#include -#include - -namespace quick_dot_ai { - -static constexpr size_t SINGLE_INOUT_IDX = 0; - -enum QKVParams { Q, K, V }; - -QKVLayer::QKVLayer() : - LayerImpl(), qkv_props(props::QUnit(), props::KUnit(), props::VUnit()) { - weight_idx.fill(std::numeric_limits::max()); -} - -void QKVLayer::finalize(nntrainer::InitLayerContext &context) { - NNTR_THROW_IF(context.getNumInputs() != 1, std::invalid_argument) - << "Fully connected layer takes only one input"; - - auto &weight_regularizer = - std::get(*layer_impl_props); - auto &weight_regularizer_constant = - std::get(*layer_impl_props); - auto weight_initializer = nntrainer::props::InitializerInfo::Enum::NONE; - auto &weight_decay = - std::get(*layer_impl_props); - - const auto &q_unit = std::get(qkv_props).get(); - const auto &k_unit = std::get(qkv_props).get(); - const auto &v_unit = std::get(qkv_props).get(); - - std::vector output_dims(3); - - /// @todo fc actaully supports multidimensions. EffDimFlag shouldn't be fixed - /// like this. - context.setEffDimFlagInputDimension(0, 0b1001); - context.setDynDimFlagInputDimension(0, 0b1000); - - bool is_nchw = (context.getFormat() == nntrainer::Tformat::NCHW); - /** set output dimensions */ - auto const &in_dim = context.getInputDimensions()[0]; - - /** Q out */ - output_dims[QKVParams::Q] = in_dim; - is_nchw ? output_dims[QKVParams::Q].width(q_unit) - : output_dims[QKVParams::Q].channel(q_unit); - output_dims[QKVParams::Q].setTensorType( - {context.getFormat(), context.getActivationDataType()}); - - /** K out */ - output_dims[QKVParams::K] = in_dim; - is_nchw ? output_dims[QKVParams::K].width(k_unit) - : output_dims[QKVParams::K].channel(k_unit); - output_dims[QKVParams::K].setTensorType( - {context.getFormat(), context.getActivationDataType()}); - - /** V out */ - output_dims[QKVParams::V] = in_dim; - is_nchw ? output_dims[QKVParams::V].width(v_unit) - : output_dims[QKVParams::V].channel(v_unit); - output_dims[QKVParams::V].setTensorType( - {context.getFormat(), context.getActivationDataType()}); - - context.setOutputDimensions(output_dims); - - /** Q */ - nntrainer::TensorDim weight_dim( - 1, is_nchw ? 1 : q_unit, is_nchw ? in_dim.width() : 1, - is_nchw ? q_unit : in_dim.channel(), - nntrainer::TensorDim::TensorType(context.getFormat(), - context.getWeightDataType()), - is_nchw ? 0b0011 : 0b0101); - weight_idx[QKVParams::Q] = context.requestWeight( - weight_dim, weight_initializer, weight_regularizer, - weight_regularizer_constant, weight_decay, "qweight", true); - - /** K */ - weight_dim.width(k_unit); - weight_idx[QKVParams::K] = context.requestWeight( - weight_dim, weight_initializer, weight_regularizer, - weight_regularizer_constant, weight_decay, "kweight", true); - - /** V */ - weight_dim.width(v_unit); - weight_idx[QKVParams::V] = context.requestWeight( - weight_dim, weight_initializer, weight_regularizer, - weight_regularizer_constant, weight_decay, "vweight", true); -} - -void QKVLayer::exportTo(nntrainer::Exporter &exporter, - const ml::train::ExportMethods &method) const { - LayerImpl::exportTo(exporter, method); - exporter.saveResult(qkv_props, method, this); -} - -void QKVLayer::setProperty(const std::vector &values) { - auto remain_props = loadProperties(values, qkv_props); - LayerImpl::setProperty(remain_props); -} - -void QKVLayer::forwarding(nntrainer::RunLayerContext &context, bool training) { - return; -} - -void QKVLayer::incremental_forwarding(nntrainer::RunLayerContext &context, - unsigned int from, unsigned int to, - bool training) { - nntrainer::Tensor &Qweight = context.getWeight(weight_idx[QKVParams::Q]); - nntrainer::Tensor &Kweight = context.getWeight(weight_idx[QKVParams::K]); - nntrainer::Tensor &Vweight = context.getWeight(weight_idx[QKVParams::V]); - nntrainer::Tensor &input_ = context.getInput(SINGLE_INOUT_IDX); - nntrainer::Tensor &Qhidden_ = context.getOutput(QKVParams::Q); - nntrainer::Tensor &Khidden_ = context.getOutput(QKVParams::K); - nntrainer::Tensor &Vhidden_ = context.getOutput(QKVParams::V); - - nntrainer::TensorDim input_dim = input_.getDim(); - nntrainer::TensorDim input_step_dim = input_dim; - input_step_dim.batch(1); - input_step_dim.height(to - from); - - auto &pool = - nntrainer::Engine::Global().getThreadPoolManager()->getThreadPool(); - - nntrainer::Tensor input_step = - input_.getSharedDataTensor(input_step_dim, 0, true); - - nntrainer::TensorDim Qhidden_dim = Qhidden_.getDim(); - nntrainer::TensorDim Qhidden_step_dim = Qhidden_.getDim(); - Qhidden_step_dim.batch(1); - Qhidden_step_dim.height(to - from); - nntrainer::Tensor Qhidden_step = - Qhidden_.getSharedDataTensor(Qhidden_step_dim, 0, true); - - nntrainer::TensorDim Khidden_dim = Khidden_.getDim(); - nntrainer::TensorDim Khidden_step_dim = Khidden_.getDim(); - Khidden_step_dim.batch(1); - Khidden_step_dim.height(to - from); - nntrainer::Tensor Khidden_step = - Khidden_.getSharedDataTensor(Khidden_step_dim, 0, true); - - nntrainer::TensorDim Vhidden_dim = Vhidden_.getDim(); - nntrainer::TensorDim Vhidden_step_dim = Vhidden_.getDim(); - Vhidden_step_dim.batch(1); - Vhidden_step_dim.height(to - from); - nntrainer::Tensor Vhidden_step = - Vhidden_.getSharedDataTensor(Vhidden_step_dim, 0, true); - - std::vector Weights({&Qweight, &Kweight, &Vweight}); - std::vector Outputs( - {&Qhidden_step, &Khidden_step, &Vhidden_step}); - - input_step.dot(Weights, Outputs); -} - -void QKVLayer::calcDerivative(nntrainer::RunLayerContext &context) { return; } - -void QKVLayer::calcGradient(nntrainer::RunLayerContext &context) { return; } - -void QKVLayer::updateTensorsByInputDimensions( - nntrainer::RunLayerContext &context, - std::vector input_dimensions) { - ml::train::TensorDim input_dim = context.getInput(SINGLE_INOUT_IDX).getDim(); - ml::train::TensorDim Qoutput_dim = context.getOutput(QKVParams::Q).getDim(); - ml::train::TensorDim Koutput_dim = context.getOutput(QKVParams::K).getDim(); - ml::train::TensorDim Voutput_dim = context.getOutput(QKVParams::V).getDim(); - - input_dim.height(input_dimensions[0].height()); - Qoutput_dim.height(input_dimensions[0].height()); - Koutput_dim.height(input_dimensions[0].height()); - Voutput_dim.height(input_dimensions[0].height()); - - context.updateInput(SINGLE_INOUT_IDX, input_dim); - context.updateOutput(QKVParams::Q, Qoutput_dim); - context.updateOutput(QKVParams::K, Koutput_dim); - context.updateOutput(QKVParams::V, Voutput_dim); -} -} // namespace quick_dot_ai diff --git a/layers/qkv_layer.h b/layers/qkv_layer.h deleted file mode 100644 index c3cdb077..00000000 --- a/layers/qkv_layer.h +++ /dev/null @@ -1,153 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -/** - * Copyright (C) 2020 Jijoong Moone - * - * @file qkv_layer.h - * @date 14 May 2020 - * @brief This is Fully Connected Layer Class of Neural Network - * @see https://github.com/nntrainer/nntrainer - * @author Jijoong Moon - * @author Eunju Yang - * @bug No known bugs except for NYI items - * - */ - -#ifndef __QKV_LAYER_H__ -#define __QKV_LAYER_H__ -#ifdef __cplusplus - -#pragma once -#ifdef _WIN32 -#define WIN_EXPORT __declspec(dllexport) -#else -#define WIN_EXPORT -#endif - -#include -#include - -namespace quick_dot_ai { - -namespace props { - -class QUnit : public nntrainer::PositiveIntegerProperty { -public: - static constexpr const char *key = "q_unit"; - using prop_tag = nntrainer::uint_prop_tag; -}; - -class KUnit : public nntrainer::PositiveIntegerProperty { -public: - static constexpr const char *key = "k_unit"; - using prop_tag = nntrainer::uint_prop_tag; -}; - -class VUnit : public nntrainer::PositiveIntegerProperty { -public: - static constexpr const char *key = "v_unit"; - using prop_tag = nntrainer::uint_prop_tag; -}; - -} // namespace props - -/** - * @class FullyConnecedLayer - * @brief fully connected layer - */ -WIN_EXPORT class QKVLayer : public nntrainer::LayerImpl { -public: - /** - * @brief Constructor of Fully Connected Layer - */ - WIN_EXPORT QKVLayer(); - - /** - * @brief Destructor of Fully Connected Layer - */ - WIN_EXPORT ~QKVLayer() = default; - - /** - * @brief Move constructor. - * @param[in] FullyConnected && - */ - WIN_EXPORT QKVLayer(QKVLayer &&rhs) noexcept = default; - - /** - * @brief Move assignment operator. - * @parma[in] rhs QKVLayer to be moved. - */ - WIN_EXPORT QKVLayer &operator=(QKVLayer &&rhs) = default; - - /** - * @copydoc Layer::finalize(InitLayerContext &context) - */ - WIN_EXPORT void finalize(nntrainer::InitLayerContext &context) override; - - /** - * @copydoc Layer::forwarding(RunLayerContext &context, bool training) - */ - WIN_EXPORT void forwarding(nntrainer::RunLayerContext &context, - bool training) override; - - /** -οΏΌ * @copydoc Layer::incremental_forwarding(RunLayerContext &context, unsigned -οΏΌ * int from, unsigned int to, bool training) -οΏΌ */ - WIN_EXPORT void incremental_forwarding(nntrainer::RunLayerContext &context, - unsigned int from, unsigned int to, - bool training) override; - - /** - * @copydoc Layer::calcDerivative(RunLayerContext &context) - */ - WIN_EXPORT void calcDerivative(nntrainer::RunLayerContext &context) override; - - /** - * @copydoc Layer::calcGradient(RunLayerContext &context) - * @note - * [note for LoRA] implicit calcDerivative is implicitly applied. - * The weight is already updated with the LoRA's (W = W + W_lora) - */ - WIN_EXPORT void calcGradient(nntrainer::RunLayerContext &context) override; - - /** - * @copydoc Layer::exportTo(Exporter &exporter, ml::train::ExportMethods - * method) - */ - WIN_EXPORT void - exportTo(nntrainer::Exporter &exporter, - const ml::train::ExportMethods &method) const override; - - /** - * @copydoc Layer::getType() - */ - WIN_EXPORT const std::string getType() const override { - return QKVLayer::type; - }; - - /** - * @copydoc Layer::supportBackwarding() - */ - WIN_EXPORT bool supportBackwarding() const override { return true; } - - /** - * @copydoc Layer::setProperty(const PropertyType type, const std::string - * &value) - */ - WIN_EXPORT void setProperty(const std::vector &values) override; - - WIN_EXPORT void updateTensorsByInputDimensions( - nntrainer::RunLayerContext &context, - std::vector input_dimensions) override; - - inline static const std::string type = "qkv_layer"; - -private: - std::tuple qkv_props; - std::array weight_idx; /**< indices of the weights */ -}; - -} // namespace quick_dot_ai - -#endif /* __cplusplus */ -#endif /* __QKV_LAYER_H__ */ diff --git a/layers/reshaped_rms_norm.cpp b/layers/reshaped_rms_norm.cpp deleted file mode 100644 index 64413796..00000000 --- a/layers/reshaped_rms_norm.cpp +++ /dev/null @@ -1,139 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -/** - * Copyright (C) 2023 Seungbaek Hong - * - * @file custom_rms_norm.cpp - * @date 19 July 2023 - * @brief Implementation of custom RMS normalization function - * @see https://github.com/nntrainer/nntrainer - * @author Seungbaek Hong - * @bug No known bugs except for NYI items - * - */ - -#include -#include -#include - -namespace quick_dot_ai { - -static constexpr size_t SINGLE_INOUT_IDX = 0; - -void ReshapedRMSNormLayer::finalize(nntrainer::InitLayerContext &context) { - std::vector dim = context.getInputDimensions(); - context.setOutputDimensions(dim); - feature_size = std::get(rms_props); - - NNTR_THROW_IF(dim[0].width() % feature_size != 0, std::invalid_argument) - << "feature size must be a divisor of width"; - - nntrainer::TensorDim gamma_dim( - 1, 1, 1, feature_size, - nntrainer::TensorDim::TensorType(context.getFormat(), - context.getWeightDataType())); - wt_idx[RMSParams::gamma] = context.requestWeight( - gamma_dim, nntrainer::props::InitializerInfo::Enum::NONE, - nntrainer::WeightRegularizer::NONE, 1.0f, 0.0f, "gamma", false); -} - -void ReshapedRMSNormLayer::forwarding(nntrainer::RunLayerContext &context, - bool training) {} - -void ReshapedRMSNormLayer::incremental_forwarding( - nntrainer::RunLayerContext &context, unsigned int from, unsigned int to, - bool training) { - auto &epsilon = std::get(rms_props).get(); - - nntrainer::Tensor &in = context.getInput(SINGLE_INOUT_IDX); - nntrainer::Tensor &out = context.getOutput(SINGLE_INOUT_IDX); - nntrainer::Tensor &gamma = context.getWeight(wt_idx[RMSParams::gamma]); - - ml::train::TensorDim in_dim = in.getDim(); - ml::train::TensorDim out_dim = out.getDim(); - - ml::train::TensorDim in_step_dim = in_dim; - ml::train::TensorDim out_step_dim = out_dim; - - unsigned int _from = from; - - in_step_dim.batch(1); - in_step_dim.height(to - from); - out_step_dim.batch(1); - out_step_dim.height(to - from); - - // set reshaped dim to (1, 1, -1, feature_size) - ml::train::TensorDim step_reshaped_dim = in_step_dim; - - step_reshaped_dim.width(feature_size); - step_reshaped_dim.height(in_step_dim.height() * - (in_dim.width() / feature_size)); - - unsigned int b_size = in_dim.batch(); - - for (unsigned int b = 0; b < b_size; ++b) { - nntrainer::Tensor in_step = - in.getSharedDataTensor(in_step_dim, b * in_dim.getFeatureLen(), true); - nntrainer::Tensor out_step = - out.getSharedDataTensor(out_step_dim, b * out_dim.getFeatureLen(), true); - - // reshape in_step - // reshape out_step - in_step.reshape(step_reshaped_dim); - out_step.reshape(step_reshaped_dim); - - if (in_step.getDataType() == ml::train::TensorDim::DataType::FP32) { - ///@todo rms_norm_wrt_width_something() should be refactored to - /// nntrainer::Tensor operation. -#ifdef ENABLE_FP16 - nntrainer::rms_norm_wrt_width_fp16_intrinsic( - in_step.getData(), out_step.getData(), - in_step.getDim().height(), in_step.getDim().width(), epsilon); -#else - nntrainer::rms_norm_wrt_width_fp32_intrinsic( - in_step.getData(), out_step.getData(), - in_step.getDim().height(), in_step.getDim().width(), epsilon); -#endif - } else { - throw std::invalid_argument( - "Error: not yet implemented for this data type"); - } - out_step.multiply_i(gamma); - - // reshape again out_step - out_step.reshape(out_step_dim); - -#ifdef DEBUG - std::cout << context.getName() << " \n input:" << in_step - << "output:" << out_step << "gamma:" << gamma << std::endl; -#endif - } -} - -void ReshapedRMSNormLayer::updateTensorsByInputDimensions( - nntrainer::RunLayerContext &context, - std::vector input_dimensions) { - context.updateInput(SINGLE_INOUT_IDX, input_dimensions[0]); - context.updateOutput(SINGLE_INOUT_IDX, input_dimensions[0]); -} - -void ReshapedRMSNormLayer::calcDerivative(nntrainer::RunLayerContext &context) { - std::throw_with_nested(std::runtime_error("Training is not supported yet.")); -} - -#ifdef PLUGGABLE - -nntrainer::Layer *create_rms_norm_layer() { - auto layer = new ReshapedRMSNormLayer(); - return layer; -} - -void destroy_rms_norm_layer(nntrainer::Layer *layer) { delete layer; } - -extern "C" { -nntrainer::LayerPluggable ml_train_layer_pluggable{create_rms_norm_layer, - destroy_rms_norm_layer}; -} - -#endif - -} // namespace quick_dot_ai diff --git a/layers/reshaped_rms_norm.h b/layers/reshaped_rms_norm.h deleted file mode 100644 index 3d7470f4..00000000 --- a/layers/reshaped_rms_norm.h +++ /dev/null @@ -1,131 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -/** - * Copyright (C) 2025 Eunju Yang - * - * @file reshaped_rms_norm.h - * @date 15 July 2025 - * @brief Implementation of RMS normalization function with reshaping. - * @see https://github.com/nntrainer/nntrainer - * @author Eunju Yang - * @bug No known bugs except for NYI items - * @note This layer only supports inference mode. - */ - -#ifndef __RMS_NORM_LAYER_H__ -#define __RMS_NORM_LAYER_H__ - -#pragma once -#ifdef _WIN32 -#define WIN_EXPORT __declspec(dllexport) -#else -#define WIN_EXPORT -#endif - -#include -#include -#include -#include - -#include -#include -#include -#include -#include - -namespace quick_dot_ai { - -/** - * @brief A custom Reshaped RMS normalization layer for llama. - * - */ -WIN_EXPORT class ReshapedRMSNormLayer final : public nntrainer::Layer { -public: - /** - * @brief Construct a new custom RMS normalization layer object - * - */ - WIN_EXPORT ReshapedRMSNormLayer() : - Layer(), - rms_props(props::RMS_NORM_GAMMA_INIT(), nntrainer::props::Epsilon(), - props::FeatureSize()), - feature_size(0) { - wt_idx.fill(std::numeric_limits::max()); - } - - /** - * @brief Destroy the custom RMS normalization layer object - * - */ - WIN_EXPORT ~ReshapedRMSNormLayer() {} - - /** - * @copydoc Layer::finalize(InitLayerContext &context) - */ - WIN_EXPORT void finalize(nntrainer::InitLayerContext &context) override; - - /** - * @copydoc Layer::forwarding(RunLayerContext &context, bool training) - */ - WIN_EXPORT void forwarding(nntrainer::RunLayerContext &context, - bool training) override; - - /** - * @copydoc Layer::incremental_forwarding(RunLayerContext &context, unsigned - * int from, unsigned int to, bool training) - */ - WIN_EXPORT void incremental_forwarding(nntrainer::RunLayerContext &context, - unsigned int from, unsigned int to, - bool training) override; - - /** - * @copydoc Layer::calcDerivative(RunLayerContext &context) - */ - WIN_EXPORT void calcDerivative(nntrainer::RunLayerContext &context) override; - - /** - * @copydoc bool supportBackwarding() const - */ - WIN_EXPORT bool supportBackwarding() const override { return false; }; - - /** - * @copydoc Layer::exportTo(Exporter &exporter, ExportMethods method) - */ - WIN_EXPORT void - exportTo(nntrainer::Exporter &exporter, - const ml::train::ExportMethods &method) const override{}; - - /** - * @copydoc Layer::getType() - */ - WIN_EXPORT const std::string getType() const override { - return ReshapedRMSNormLayer::type; - }; - - /** - * @copydoc Layer::setProperty(const std::vector &values) - */ - WIN_EXPORT void setProperty(const std::vector &values) override { - auto remain_props = loadProperties(values, rms_props); - NNTR_THROW_IF(!remain_props.empty(), std::invalid_argument) - << "[rms_norm] Unknown Layer Properties count " + - std::to_string(values.size()); - }; - - WIN_EXPORT void updateTensorsByInputDimensions( - nntrainer::RunLayerContext &context, - std::vector input_dimensions) override; - - inline static const std::string type = "reshaped_rms_norm"; - -private: - std::array wt_idx; - std::tuple - rms_props; - - unsigned int feature_size; -}; - -} // namespace quick_dot_ai - -#endif /* __CAUSALLM_RMS_NORM_LAYER_H__ */ diff --git a/layers/rms_norm.cpp b/layers/rms_norm.cpp deleted file mode 100644 index a28b3b4d..00000000 --- a/layers/rms_norm.cpp +++ /dev/null @@ -1,112 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -/** - * Copyright (C) 2023 Seungbaek Hong - * - * @file custom_rms_norm.cpp - * @date 19 July 2023 - * @brief Implementation of custom RMS normalization function - * @see https://github.com/nntrainer/nntrainer - * @author Seungbaek Hong - * @bug No known bugs except for NYI items - * - */ - -#include -#include - -#include "rms_norm.h" - -namespace quick_dot_ai { - -static constexpr size_t SINGLE_INOUT_IDX = 0; - -void RMSNormLayer::finalize(nntrainer::InitLayerContext &context) { - std::vector dim = context.getInputDimensions(); - context.setOutputDimensions(dim); - nntrainer::TensorDim gamma_dim( - 1, 1, 1, dim[0].width(), - nntrainer::TensorDim::TensorType(context.getFormat(), - context.getWeightDataType())); - wt_idx[RMSParams::gamma] = context.requestWeight( - gamma_dim, nntrainer::props::InitializerInfo::Enum::NONE, - nntrainer::WeightRegularizer::NONE, 1.0f, 0.0f, "gamma", false); -} - -void RMSNormLayer::forwarding(nntrainer::RunLayerContext &context, - bool training) {} - -void RMSNormLayer::incremental_forwarding(nntrainer::RunLayerContext &context, - unsigned int from, unsigned int to, - bool training) { - auto &epsilon = std::get(rms_props).get(); - - nntrainer::Tensor &in = context.getInput(SINGLE_INOUT_IDX); - nntrainer::Tensor &out = context.getOutput(SINGLE_INOUT_IDX); - nntrainer::Tensor &gamma = context.getWeight(wt_idx[RMSParams::gamma]); - - ml::train::TensorDim in_dim = in.getDim(); - ml::train::TensorDim out_dim = out.getDim(); - - ml::train::TensorDim in_step_dim = in_dim; - ml::train::TensorDim out_step_dim = out_dim; - - unsigned int _from = from; - - in_step_dim.batch(1); - in_step_dim.height(to - from); - out_step_dim.batch(1); - out_step_dim.height(to - from); - - unsigned int b_size = in_dim.batch(); - - for (unsigned int b = 0; b < b_size; ++b) { - nntrainer::Tensor in_step = - in.getSharedDataTensor(in_step_dim, b * in_dim.getFeatureLen(), true); - nntrainer::Tensor out_step = - out.getSharedDataTensor(out_step_dim, b * out_dim.getFeatureLen(), true); - - if (in_step.getDataType() == ml::train::TensorDim::DataType::FP32) { - auto t = in_step.multiply(in_step).average(3).add(epsilon); - t.inv_sqrt_i(); - in_step.multiply(t, out_step); - } else { - throw std::invalid_argument( - "Error: not yet implemented for this data type"); - } - out_step.multiply_i(gamma); - -#ifdef DEBUG - std::cout << context.getName() << " \n input:" << in_step - << "output:" << out_step << "gamma:" << gamma << std::endl; -#endif - } -} - -void RMSNormLayer::updateTensorsByInputDimensions( - nntrainer::RunLayerContext &context, - std::vector input_dimensions) { - context.updateInput(SINGLE_INOUT_IDX, input_dimensions[0]); - context.updateOutput(SINGLE_INOUT_IDX, input_dimensions[0]); -} - -void RMSNormLayer::calcDerivative(nntrainer::RunLayerContext &context) { - std::throw_with_nested(std::runtime_error("Training is not supported yet.")); -} - -#ifdef PLUGGABLE - -nntrainer::Layer *create_rms_norm_layer() { - auto layer = new RMSNormLayer(); - return layer; -} - -void destroy_rms_norm_layer(nntrainer::Layer *layer) { delete layer; } - -extern "C" { -nntrainer::LayerPluggable ml_train_layer_pluggable{create_rms_norm_layer, - destroy_rms_norm_layer}; -} - -#endif - -} // namespace quick_dot_ai diff --git a/layers/rms_norm.h b/layers/rms_norm.h deleted file mode 100644 index e49a940d..00000000 --- a/layers/rms_norm.h +++ /dev/null @@ -1,120 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -/** - * Copyright (C) 2023 Seungbaek Hong - * - * @file rms_norm.h - * @date 11 July 2025 - * @brief Implementation of RMS normalization function - * @see https://github.com/nntrainer/nntrainer - * @author Seungbaek Hong - * @bug No known bugs except for NYI items - * @note This layer only supports inference mode. - */ - -#ifndef __RMS_NORM_LAYER_H__ -#define __RMS_NORM_LAYER_H__ - -#pragma once -#ifdef _WIN32 -#define WIN_EXPORT __declspec(dllexport) -#else -#define WIN_EXPORT -#endif - -#include -#include -#include -#include - -#include -#include -#include -#include - -namespace quick_dot_ai { - -/** - * @brief A custom RMS normalization layer for llama. - * - */ -WIN_EXPORT class RMSNormLayer final : public nntrainer::Layer { -public: - /** - * @brief Construct a new custom RMS normalization layer object - * - */ - WIN_EXPORT RMSNormLayer() : Layer(), wt_idx({0}) {} - - /** - * @brief Destroy the custom RMS normalization layer object - * - */ - WIN_EXPORT ~RMSNormLayer() {} - - /** - * @copydoc Layer::finalize(InitLayerContext &context) - */ - WIN_EXPORT void finalize(nntrainer::InitLayerContext &context) override; - - /** - * @copydoc Layer::forwarding(RunLayerContext &context, bool training) - */ - WIN_EXPORT void forwarding(nntrainer::RunLayerContext &context, - bool training) override; - - /** - * @copydoc Layer::incremental_forwarding(RunLayerContext &context, unsigned - * int from, unsigned int to, bool training) - */ - WIN_EXPORT void incremental_forwarding(nntrainer::RunLayerContext &context, - unsigned int from, unsigned int to, - bool training) override; - - /** - * @copydoc Layer::calcDerivative(RunLayerContext &context) - */ - WIN_EXPORT void calcDerivative(nntrainer::RunLayerContext &context) override; - - /** - * @copydoc bool supportBackwarding() const - */ - WIN_EXPORT bool supportBackwarding() const override { return false; }; - - /** - * @copydoc Layer::exportTo(Exporter &exporter, ExportMethods method) - */ - WIN_EXPORT void - exportTo(nntrainer::Exporter &exporter, - const ml::train::ExportMethods &method) const override{}; - - /** - * @copydoc Layer::getType() - */ - WIN_EXPORT const std::string getType() const override { - return RMSNormLayer::type; - }; - - /** - * @copydoc Layer::setProperty(const std::vector &values) - */ - WIN_EXPORT void setProperty(const std::vector &values) override { - auto remain_props = loadProperties(values, rms_props); - NNTR_THROW_IF(!remain_props.empty(), std::invalid_argument) - << "[rms_norm] Unknown Layer Properties count " + - std::to_string(values.size()); - }; - - WIN_EXPORT void updateTensorsByInputDimensions( - nntrainer::RunLayerContext &context, - std::vector input_dimensions) override; - - inline static const std::string type = "rms_norm"; - -private: - std::array wt_idx; - std::tuple rms_props; -}; - -} // namespace quick_dot_ai - -#endif /* __CAUSALLM_RMS_NORM_LAYER_H__ */ diff --git a/layers/swiglu.cpp b/layers/swiglu.cpp deleted file mode 100644 index a672b9aa..00000000 --- a/layers/swiglu.cpp +++ /dev/null @@ -1,117 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -/** - * Copyright (C) 2023 Seungbaek Hong - * - * @file swiglu.cpp - * @date 14 July 2023 - * @brief Implementation of SwiGLU activation function - * @see https://github.com/nntrainer/nntrainer - * @author Seungbaek Hong - * @bug No known bugs except for NYI items - * - */ - -#include - -#include "swiglu.h" - -namespace quick_dot_ai { - -static constexpr size_t OUT_IDX = 0; -static constexpr size_t INPUT_IDX_1 = 0; -static constexpr size_t INPUT_IDX_2 = 1; - -namespace ActivationOp { -/** - * @brief activation function swiglu - * @param x input - * @retval swiglu(x) - */ -float swiglu(float x) { return x / (1 + nntrainer::exp_util(-x)); } -} // namespace ActivationOp - -void SwiGLULayer::finalize(nntrainer::InitLayerContext &context) { - context.setOutputDimensions({context.getInputDimensions()[0]}); -} - -void SwiGLULayer::forwarding(nntrainer::RunLayerContext &context, - bool training) {} - -void SwiGLULayer::incremental_forwarding(nntrainer::RunLayerContext &context, - unsigned int from, unsigned int to, - bool training) { - nntrainer::Tensor &in1 = context.getInput(INPUT_IDX_1); - nntrainer::Tensor &in2 = context.getInput(INPUT_IDX_2); - nntrainer::Tensor &out = context.getOutput(OUT_IDX); - - unsigned int _from = from; - - int iter = to - from; - - if (in1.getDataType() == ml::train::TensorDim::DataType::FP32) { - for (unsigned int b = 0; b < in1.batch(); b++) { - for (unsigned int c = 0; c < in1.channel(); c++) { - for (unsigned int h = 0; h < iter; h++) { - nntrainer::swiglu(in1.width(), - out.getData() + out.getIndex(b, c, h, 0), - in1.getData() + in1.getIndex(b, c, h, 0), - in2.getData() + in2.getIndex(b, c, h, 0)); - } - } - } - } else if (in1.getDataType() == ml::train::TensorDim::DataType::FP16) { -#ifdef ENABLE_FP16 - for (unsigned int b = 0; b < in1.batch(); b++) { - for (unsigned int c = 0; c < in1.channel(); c++) { - for (unsigned int h = 0; h < iter; h++) { - nntrainer::swiglu(in1.width(), - out.getData<_FP16>() + out.getIndex(b, c, h, 0), - in1.getData<_FP16>() + in1.getIndex(b, c, h, 0), - in2.getData<_FP16>() + in2.getIndex(b, c, h, 0)); - } - } - } -#else - NNTR_THROW_IF(true, std::invalid_argument) << "enable-fp16 is not set!"; -#endif - } -} - -void SwiGLULayer::updateTensorsByInputDimensions( - nntrainer::RunLayerContext &context, - std::vector input_dimensions) { - ml::train::TensorDim input_dim1 = context.getInput(INPUT_IDX_1).getDim(); - ml::train::TensorDim input_dim2 = context.getInput(INPUT_IDX_2).getDim(); - ml::train::TensorDim output_dim = context.getOutput(OUT_IDX).getDim(); - - input_dim1.height(input_dimensions[0].height()); - input_dim2.height(input_dimensions[0].height()); - output_dim.height(input_dimensions[0].height()); - - context.updateInput(INPUT_IDX_1, input_dim1); - context.updateInput(INPUT_IDX_2, input_dim2); - context.updateOutput(OUT_IDX, output_dim); -} - -void SwiGLULayer::calcDerivative(nntrainer::RunLayerContext &context) { - // std::throw_with_nested(std::runtime_error("Training is not supported - // yet.")); -} - -#ifdef PLUGGABLE - -nntrainer::Layer *create_swiglu_layer() { - auto layer = new SwiGLULayer(); - return layer; -} - -void destroy_swiglu_layer(nntrainer::Layer *layer) { delete layer; } - -extern "C" { -nntrainer::LayerPluggable ml_train_layer_pluggable{create_swiglu_layer, - destroy_swiglu_layer}; -} - -#endif - -} // namespace quick_dot_ai diff --git a/layers/swiglu.h b/layers/swiglu.h deleted file mode 100644 index 7b5d873a..00000000 --- a/layers/swiglu.h +++ /dev/null @@ -1,107 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -/** - * Copyright (C) 2023 Seungbaek Hong - * - * @file swiglu.h - * @date 14 July 2023 - * @brief Implementation of custom SwiGLU activation function - * @see https://github.com/nntrainer/nntrainer - * @author Seungbaek Hong - * @bug No known bugs except for NYI items - * - */ - -#ifndef __SWIGLU_LAYER_H__ -#define __SWIGLU_LAYER_H__ - -#include -#include -#include -#include - -#pragma once -#ifdef _WIN32 -#define WIN_EXPORT __declspec(dllexport) -#else -#define WIN_EXPORT -#endif - -namespace quick_dot_ai { - -/** - * @brief A SwiGLU layer for llama. - * - */ -WIN_EXPORT class SwiGLULayer final : public nntrainer::Layer { -public: - /** - * @brief Construct a new custom SwiGLU layer object - * - */ - WIN_EXPORT SwiGLULayer() : Layer() {} - - /** - * @brief Destroy the custom SwiGLU layer object - * - */ - WIN_EXPORT ~SwiGLULayer() {} - - /** - * @copydoc Layer::finalize(InitLayerContext &context) - */ - WIN_EXPORT void finalize(nntrainer::InitLayerContext &context) override; - - /** - * @copydoc Layer::forwarding(RunLayerContext &context, bool training) - */ - WIN_EXPORT void forwarding(nntrainer::RunLayerContext &context, - bool training) override; - - /** - * @copydoc Layer::incremental_forwarding(RunLayerContext &context, unsigned - * int from, unsigned int to, bool training) - */ - WIN_EXPORT void incremental_forwarding(nntrainer::RunLayerContext &context, - unsigned int from, unsigned int to, - bool training) override; - - /** - * @copydoc Layer::calcDerivative(RunLayerContext &context) - */ - WIN_EXPORT void calcDerivative(nntrainer::RunLayerContext &context) override; - - /** - * @copydoc bool supportBackwarding() const - */ - WIN_EXPORT bool supportBackwarding() const override { return true; }; - - /** - * @copydoc Layer::exportTo(Exporter &exporter, ExportMethods method) - */ - WIN_EXPORT void - exportTo(nntrainer::Exporter &exporter, - const ml::train::ExportMethods &method) const override{}; - - /** - * @copydoc Layer::getType() - */ - WIN_EXPORT const std::string getType() const override { - return SwiGLULayer::type; - }; - - /** - * @copydoc Layer::setProperty(const std::vector &values) - */ - WIN_EXPORT void - setProperty(const std::vector &values) override{}; - - WIN_EXPORT void updateTensorsByInputDimensions( - nntrainer::RunLayerContext &context, - std::vector input_dimensions) override; - - inline static const std::string type = "swiglu"; -}; - -} // namespace quick_dot_ai - -#endif /* __SWIGLU_LAYER_H__ */ diff --git a/layers/tie_word_embedding.cpp b/layers/tie_word_embedding.cpp deleted file mode 100644 index 6dc55dcb..00000000 --- a/layers/tie_word_embedding.cpp +++ /dev/null @@ -1,448 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -/** - * Copyright (C) 2020 Jijoong Moon - * - * @file tie_word_embedding.cpp - * @date 21 May 2025 - * @brief This is Embedding Layer Class of Neural Network - * @see https://github.com/nntrainer/nntrainer - * @author Eunju Yang - * @bug No known bugs except for NYI items - * - */ - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace quick_dot_ai { - -static constexpr size_t SINGLE_INOUT_IDX = 0; - -enum TieWordEmbeddingParams { - weight, - bias, - candidate_weight, - candidate_hidden_step -}; - -TieWordEmbedding::TieWordEmbedding() : - LayerImpl(), - tieword_embedding_props(nntrainer::props::InDim(), nntrainer::props::OutDim(), - nntrainer::props::Unit(), nntrainer::props::Scale()) { - weight_idx.fill(std::numeric_limits::max()); -} - -void TieWordEmbedding::finalize(nntrainer::InitLayerContext &context) { - mode_ = std::get(tieword_embedding_props).empty() - ? mode::embedding - : mode::lm_head; - if (mode_ == mode::embedding) - finalize_embedding(context); - else if (mode_ == mode::lm_head) - finalize_lmhead(context); -} - -void TieWordEmbedding::finalize_embedding( - nntrainer::InitLayerContext &context) { - - NNTR_THROW_IF(context.getNumInputs() != 1, std::invalid_argument) - << "Embedding layer takes only one input"; - - const nntrainer::TensorDim &input_dim = - context.getInputDimensions()[SINGLE_INOUT_IDX]; - NNTR_THROW_IF(input_dim.channel() != 1, std::invalid_argument) - << "Embedding layer takes only one for channel size"; - - NNTR_THROW_IF(input_dim.getDataType() != nntrainer::TensorDim::DataType::FP32, - std::invalid_argument) - << "Embedding layer takes only FP32 input data"; - - auto &weight_regularizer = - std::get(*layer_impl_props); - auto &weight_regularizer_constant = - std::get(*layer_impl_props); - auto weight_initializer = nntrainer::props::InitializerInfo::Enum::NONE; - auto &weight_decay = - std::get(*layer_impl_props); - - unsigned int in_dim = - std::get(tieword_embedding_props); - unsigned int out_dim = - std::get(tieword_embedding_props); - - nntrainer::TensorDim output_dim = input_dim; - - // output_dim expected as hidden x num input (batch size) - output_dim.height(input_dim.width()); - output_dim.width(out_dim); - output_dim.setTensorType( - {context.getFormat(), context.getActivationDataType()}); - context.setOutputDimensions({output_dim}); - - nntrainer::TensorDim dim = output_dim; - - dim.setTensorType({context.getFormat(), context.getWeightDataType()}); - - dim.height(in_dim); - dim.width(out_dim); - dim.batch(1); - - weight_idx[TieWordEmbeddingParams::weight] = context.requestWeight( - dim, weight_initializer, weight_regularizer, weight_regularizer_constant, - weight_decay, "Embedding", true); -} - -void TieWordEmbedding::finalize_lmhead(nntrainer::InitLayerContext &context) { - auto &weight_regularizer = - std::get(*layer_impl_props); - auto &weight_regularizer_constant = - std::get(*layer_impl_props); - auto weight_initializer = nntrainer::props::InitializerInfo::Enum::NONE; - auto &weight_decay = - std::get(*layer_impl_props); - auto &bias_decay = std::get(*layer_impl_props); - auto &bias_initializer = - std::get(*layer_impl_props); - auto &disable_bias = - std::get(*layer_impl_props); - - auto unit = std::get(tieword_embedding_props).get(); - - NNTR_THROW_IF(context.getNumInputs() != 1, std::invalid_argument) - << "lm head layer takes only one input"; - - std::vector output_dims(1); - - /// @todo fc actaully supports multidimensions. - /// EffDimFlag shouldn't be fixed like this. - context.setEffDimFlagInputDimension(0, 0b1001); - context.setDynDimFlagInputDimension(0, 0b1000); - bool is_nchw = (context.getFormat() == nntrainer::Tformat::NCHW); - - /** set output dimensions */ - auto const &in_dim = context.getInputDimensions()[0]; - output_dims[0] = in_dim; - is_nchw ? output_dims[0].width(unit) : output_dims[0].channel(unit); - output_dims[0].height(1); - - output_dims[0].setTensorType( - {context.getFormat(), context.getActivationDataType()}); - - context.setOutputDimensions(output_dims); - - /** set weight specifications */ - ml::train::TensorDim bias_dim( - 1, is_nchw ? 1 : unit, 1, is_nchw ? unit : 1, - ml::train::TensorDim::TensorType(context.getFormat(), - context.getWeightDataType()), - is_nchw ? 0b0001 : 0b0100); - - ///@note TieWordEmbedding layer's tensor dim is transposed dim of user-defined - /// dim - /// so it can reuse embedding layer. - ml::train::TensorDim weight_dim( - 1, is_nchw ? 1 : in_dim.channel(), is_nchw ? unit : 1, - is_nchw ? in_dim.width() : unit, - ml::train::TensorDim::TensorType(context.getFormat(), - context.getWeightDataType()), - is_nchw ? 0b0011 : 0b0101); - - weight_idx[TieWordEmbeddingParams::weight] = context.requestWeight( - weight_dim, weight_initializer, weight_regularizer, - weight_regularizer_constant, weight_decay, "Embedding", true); - - if (disable_bias.empty() || disable_bias.get() == false) { - weight_idx[TieWordEmbeddingParams::bias] = context.requestWeight( - bias_dim, bias_initializer, nntrainer::WeightRegularizer::NONE, 1.0f, - bias_decay, "bias", true); - } -} - -void TieWordEmbedding::setProperty(const std::vector &values) { - auto remain_props = loadProperties(values, tieword_embedding_props); - LayerImpl::setProperty(remain_props); -} - -void TieWordEmbedding::forwarding(nntrainer::RunLayerContext &context, - bool training) {} - -void TieWordEmbedding::incremental_forwarding( - nntrainer::RunLayerContext &context, unsigned int from, unsigned int to, - bool training) { - - if (mode_ == mode::embedding) - incremental_forwarding_embedding(context, from, to, training); - else if (mode_ == mode::lm_head) - incremental_forwarding_lmhead(context, from, to, training); - else - throw std::invalid_argument("lm_head is not supported yet"); -} - -void TieWordEmbedding::incremental_forwarding_embedding( - nntrainer::RunLayerContext &context, unsigned int from, unsigned int to, - bool training) { - /// @todo get input and output dimension from input_ and hidden itself - unsigned int in_dim = - std::get(tieword_embedding_props); - unsigned int out_dim = - std::get(tieword_embedding_props); - float scale = - std::get(tieword_embedding_props).empty() - ? 1.0f - : std::get(tieword_embedding_props).get(); - unsigned int _from = from; - - nntrainer::Tensor &weight = - context.getWeight(weight_idx[TieWordEmbeddingParams::weight]); - nntrainer::Tensor &hidden_ = context.getOutput(SINGLE_INOUT_IDX); - nntrainer::Tensor &input_ = context.getInput(SINGLE_INOUT_IDX); - - nntrainer::TensorDim out_tensor_dim = - nntrainer::TensorDim({1, 1, 1, out_dim}, hidden_.getTensorType()); - - if (!(weight.getDataType() == nntrainer::TensorDim::DataType::Q6_K || - weight.getDataType() == nntrainer::TensorDim::DataType::FP32)) - throw std::invalid_argument( - "Tieword embedding is not supported yet for the data type"); - - size_t b_size = input_.batch(); - - for (size_t b = 0; b < b_size; ++b) { - float *in_data = - input_.getAddress(b * input_.getDim().getFeatureLen()); - - nntrainer::Tensor batchsliced_hidden = hidden_.getBatchSlice(b, 1); - int iter = to - from; - -#pragma omp parallel for - for (int i = 0; i < iter; ++i) { - unsigned int embed_idx = static_cast(in_data[i]); - if (embed_idx >= in_dim) { - throw std::invalid_argument("input word index is greater than in_dim"); - } - - nntrainer::Tensor cur_weight = - weight.getSharedDataTensor(out_tensor_dim, out_dim * embed_idx); - nntrainer::Tensor out_tensor = - batchsliced_hidden.getSharedDataTensor(out_tensor_dim, out_dim * (i)); - - if (weight.getDataType() == nntrainer::TensorDim::DataType::Q6_K) { - ///@note this should be replaced with quantizer operation - int num_blocks_per_row = (weight.width() + 256 - 1) / 256; - nntrainer::dequantize_row_q6_K( - (void *)((char *)weight.getData() + - (210 * num_blocks_per_row) * embed_idx), - out_tensor.getData(), out_dim); - } else { - out_tensor.copyData(cur_weight); - } - - if (scale != 1.0f) { - out_tensor.multiply_i(scale); - } - } - -#ifdef DEBUG - std::cout << context.getName() << " : " - << "\n input:" << input_ << "\n weight: " << weight - << "\n hidden: " << hidden_ << std::endl; -#endif - } -} - -void TieWordEmbedding::incremental_forwarding_lmhead( - nntrainer::RunLayerContext &context, unsigned int from, unsigned int to, - bool training) { - nntrainer::Tensor weight = - context.getWeight(weight_idx[TieWordEmbeddingParams::weight]); - - nntrainer::Tensor &input_ = context.getInput(SINGLE_INOUT_IDX); - nntrainer::Tensor &hidden_ = context.getOutput(SINGLE_INOUT_IDX); - - ml::train::TensorDim input_dim = input_.getDim(); - ml::train::TensorDim hidden_dim = hidden_.getDim(); - - ml::train::TensorDim input_step_dim = input_dim; - ml::train::TensorDim hidden_step_dim = hidden_dim; - - input_step_dim.batch(1); - input_step_dim.height(1); - hidden_step_dim.batch(1); - - unsigned int b_size = input_dim.batch(); - - for (unsigned int b = 0; b < b_size; ++b) { - nntrainer::Tensor input_step = input_.getSharedDataTensor( - input_step_dim, - b * input_dim.getFeatureLen() + (to - from - 1) * input_.width(), true); - nntrainer::Tensor hidden_step = hidden_.getSharedDataTensor( - hidden_step_dim, b * hidden_dim.getFeatureLen(), true); - - ///@note Since tieword embedding shares the weight with embedding, - /// the weight is transposed. Thus, the dot product should be consider - /// this. - NNTR_THROW_IF(weight.getDataType() == nntrainer::TensorDim::DataType::BCQ, - std::invalid_argument) - << "weight type is not supported for custom tie word embedding layer"; - - input_step.dot(weight, hidden_step, false, true); - - if (auto &disable_bias = - std::get(*layer_impl_props); - disable_bias.empty() || disable_bias.get() == false) { - nntrainer::Tensor &bias = - context.getWeight(weight_idx[TieWordEmbeddingParams::bias]); - hidden_step.add_i(bias); - } - } -} - -void TieWordEmbedding::calcDerivative(nntrainer::RunLayerContext &context) { - throw nntrainer::exception::not_supported( - "calcDerivative for Embedding layer is not supported"); -} - -void TieWordEmbedding::calcGradient(nntrainer::RunLayerContext &context) {} - -void TieWordEmbedding::exportTo(nntrainer::Exporter &exporter, - const ml::train::ExportMethods &method) const { - LayerImpl::exportTo(exporter, method); - exporter.saveResult(tieword_embedding_props, method, this); -} - -void TieWordEmbedding::updateTensorsByInputDimensions( - nntrainer::RunLayerContext &context, - std::vector input_dimensions) { - nntrainer::TensorDim in_dim = context.getInput(SINGLE_INOUT_IDX).getDim(); - nntrainer::TensorDim out_dim = context.getOutput(SINGLE_INOUT_IDX).getDim(); - - unsigned int height = input_dimensions[0].height(); - - if (mode_ == mode::embedding) { - in_dim.width(height); - } else { - in_dim.height(height); - } - out_dim.height(height); - - context.updateInput(SINGLE_INOUT_IDX, in_dim); - context.updateOutput(SINGLE_INOUT_IDX, out_dim); -} - -void TieWordEmbedding::read( - std::ifstream &file, nntrainer::RunLayerContext &context, bool opt_var, - ml::train::ExecutionMode mode, bool trainable, - nntrainer::TensorDim::DataType definedWeightDataType, bool fsu, - size_t start_offset, bool read_from_offset, int file_fd) { - - // Only read when mode is embedding - if (mode_ == mode::embedding) { - for (unsigned int i = 0; i < context.getNumWeights(); ++i) { - /// @note shared weights are only be read at the first acecss - if (context.isGradientFirstAccess(i)) { - context.getWeight(i).read(file, start_offset, read_from_offset); - if (context.isMixedPrecision(i) && trainable && - !context.getWeightFP32(i).empty()) { - context.getWeightFP32(i).copyData(context.getWeight(i)); - } - } - } - } -} - -void TieWordEmbedding::read( - nntrainer::ReadSource src, nntrainer::RunLayerContext &context, bool opt_var, - ml::train::ExecutionMode mode, bool trainable, - nntrainer::TensorDim::DataType definedWeightDataType, bool fsu, - size_t start_offset, bool read_from_offset) { - - // Only read when mode is embedding - if (mode_ == mode::embedding) { - for (unsigned int i = 0; i < context.getNumWeights(); ++i) { - /// @note shared weights are only be read at the first acecss - if (context.isGradientFirstAccess(i)) { - context.getWeight(i).read(src, start_offset, read_from_offset); - if (context.isMixedPrecision(i) && trainable && - !context.getWeightFP32(i).empty()) { - context.getWeightFP32(i).copyData(context.getWeight(i)); - } - } - } - } -} - -void TieWordEmbedding::save(std::ofstream &file, - nntrainer::RunLayerContext &run_context, - bool opt_var, ml::train::ExecutionMode mode, - bool trainable, - nntrainer::TensorDim::DataType dtype) const { - // Only read when mode is embedding - if (mode_ == mode::embedding) { - // @note shared weights are only be saved at the first access - for (unsigned int i = 0; i < run_context.getNumWeights(); ++i) { - if (run_context.isGradientFirstAccess(i)) { - auto &weight = run_context.getWeight(i); - if (dtype == nntrainer::TensorDim::DataType::NONE || - weight.getDataType() == dtype) - weight.save(file); - else { - NNTR_THROW_IF(weight.getDataType() != - nntrainer::TensorDim::DataType::FP32, - std::runtime_error) - << "Save with quantization only supports for FP32 weight."; - ///@note The codelines below can be replaced with quantizer's - /// quantize() - nntrainer::TensorDim dim = weight.getDim(); - unsigned int K = dim.height(); - unsigned int N = dim.width(); - - if (dtype == nntrainer::TensorDim::DataType::Q6_K) { - ////////////////////////////////////////////////////////////////// - ///@note Please note that Embedding layer doesn't need to be - /// transposed! - ////////////////////////////////////////////////////////////////// - nntrainer::Tensor quant_weight(dim.batch(), dim.channel(), K, N, - {nntrainer::Tformat::NCHW, dtype}); - - nntrainer::quantize_q6_K(weight.getData(), - quant_weight.getData(), K, N, - nullptr); - quant_weight.save(file); - } else { - NNTR_THROW_IF(true, std::runtime_error) - << "This dtype is not supported in save with quantization"; - } - } - } - } - } -} - -#ifdef PLUGGABLE - -nntrainer::Layer *create_tie_word_embedding() { - auto layer = new TieWordEmbedding(); - std::cout << "embedding layer created\n"; - return layer; -} - -void destroy_tie_word_embedding(nntrainer::Layer *layer) { - std::cout << "embeddinglayer is deleted\n"; - delete layer; -} - -extern "C" { -nntrainer::LayerPluggable ml_train_layer_pluggable{create_tie_word_embedding, - destroy_tie_word_embedding}; -} - -#endif - -} // namespace quick_dot_ai diff --git a/layers/tie_word_embedding.h b/layers/tie_word_embedding.h deleted file mode 100644 index 65356b0f..00000000 --- a/layers/tie_word_embedding.h +++ /dev/null @@ -1,175 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -/** - * Copyright (C) 2021 Jijoong Moon - * - * @file custom_tie_word_embedding_layer.h - * @date 21 May 2025 - * @brief This is Tie_Word_Embedding Layer Class of Neural Network - * @see https://github.com/nntrainer/nntrainer - * @author Eunju Yang - * @bug No known bugs except for NYI items - * - */ - -#ifndef __CUSTOM_TIE_WORD_EMBEDDING_H__ -#define __CUSTOM_TIE_WORD_EMBEDDING_H__ -#ifdef __cplusplus - -#pragma once -#ifdef _WIN32 -#define WIN_EXPORT __declspec(dllexport) -#else -#define WIN_EXPORT -#endif - -#include -#include -#include - -namespace quick_dot_ai { - -/** - * @class TieWordEmbedding - * @brief TieWordEmbedding - * @todo Support setBatch for TieWordEmbedding - */ -WIN_EXPORT class TieWordEmbedding : public nntrainer::LayerImpl { -public: - /** - * @brief Constructor of Embedding Layer - */ - WIN_EXPORT TieWordEmbedding(); - - /** - * @brief Destructor of Embedding Layer - */ - WIN_EXPORT ~TieWordEmbedding() = default; - - /** - * @brief Move constructor. - * @param[in] TieWordEmbedding && - */ - WIN_EXPORT TieWordEmbedding(TieWordEmbedding &&rhs) noexcept = default; - - /** - * @brief Move assignment operator. - * @parma[in] rhs TieWordEmbedding to be moved. - */ - WIN_EXPORT TieWordEmbedding &operator=(TieWordEmbedding &&rhs) = default; - - /** - * @copydoc Layer::finalize(InitLayerContext &context) - */ - WIN_EXPORT void finalize(nntrainer::InitLayerContext &context) override; - - /** - * @copydoc Layer::forwarding(RunLayerContext &context, bool training) - */ - WIN_EXPORT void forwarding(nntrainer::RunLayerContext &context, - bool training) override; - - /** -οΏΌ * @copydoc Layer::incremental_forwarding(RunLayerContext &context, unsigned -οΏΌ * int from, unsigned int to, bool training) -οΏΌ */ - WIN_EXPORT void incremental_forwarding(nntrainer::RunLayerContext &context, - unsigned int from, unsigned int to, - bool training) override; - - /** - * @copydoc Layer::calcDerivative(RunLayerContext &context) - */ - WIN_EXPORT void calcDerivative(nntrainer::RunLayerContext &context) override; - - /** - * @copydoc Layer::calcGradient(RunLayerContext &context) - */ - WIN_EXPORT void calcGradient(nntrainer::RunLayerContext &context) override; - - /** - * @copydoc Layer::exportTo(Exporter &exporter, ml::train::ExportMethods - * method) - */ - WIN_EXPORT void - exportTo(nntrainer::Exporter &exporter, - const ml::train::ExportMethods &method) const override; - - /** - * @copydoc Layer::getType() - */ - WIN_EXPORT const std::string getType() const override { - return TieWordEmbedding::type; - }; - - /** - * @copydoc Layer::supportBackwarding() - */ - WIN_EXPORT bool supportBackwarding() const override { return false; } - - WIN_EXPORT void updateTensorsByInputDimensions( - nntrainer::RunLayerContext &context, - std::vector input_dimensions) override; - - /** - * @copydoc Layer::read() - */ - WIN_EXPORT void read(std::ifstream &file, nntrainer::RunLayerContext &context, - bool opt_var, ml::train::ExecutionMode mode, - bool trainable, - nntrainer::TensorDim::DataType definedWeightDataType, - bool fsu = false, size_t start_offset = 0, - bool read_from_offset = false, - int file_fd = -1) override; - - /** - * @copydoc Layer::read() (ReadSource/mmap variant) - */ - WIN_EXPORT void read(nntrainer::ReadSource src, - nntrainer::RunLayerContext &context, bool opt_var, - ml::train::ExecutionMode mode, bool trainable, - nntrainer::TensorDim::DataType definedWeightDataType, - bool fsu, size_t start_offset = 0, - bool read_from_offset = false) override; - - /** - * @copydic Layer::save() - */ - WIN_EXPORT void save(std::ofstream &file, - nntrainer::RunLayerContext &run_context, bool opt_var, - ml::train::ExecutionMode mode, bool trainable, - nntrainer::TensorDim::DataType dtype = - nntrainer::TensorDim::DataType::NONE) const override; - - using Layer::setProperty; - - /** - * @copydoc Layer::setProperty(const PropertyType type, const std::string - * &value) - */ - WIN_EXPORT void setProperty(const std::vector &values) override; - - inline static const std::string type = "tie_word_embeddings"; - -private: - std::tuple - tieword_embedding_props; - enum mode { embedding, lm_head }; - enum mode mode_; - std::array weight_idx; /**< indices of the weights */ - - WIN_EXPORT void finalize_embedding(nntrainer::InitLayerContext &context); - WIN_EXPORT void finalize_lmhead(nntrainer::InitLayerContext &context); - WIN_EXPORT void - incremental_forwarding_embedding(nntrainer::RunLayerContext &context, - unsigned int from, unsigned int to, - bool training); - WIN_EXPORT void - incremental_forwarding_lmhead(nntrainer::RunLayerContext &context, - unsigned int from, unsigned int to, - bool training); -}; -} // namespace quick_dot_ai - -#endif /* __cplusplus */ -#endif /* __CUSTOM_TIE_WORD_EMBEDDING_H__ */ diff --git a/lib/libtokenizers_c.a b/lib/libtokenizers_c.a deleted file mode 100644 index 2f38213c..00000000 Binary files a/lib/libtokenizers_c.a and /dev/null differ diff --git a/llm_util.cpp b/llm_util.cpp deleted file mode 100644 index 69c636c1..00000000 --- a/llm_util.cpp +++ /dev/null @@ -1,90 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -/** - * - * @file llm_util.cpp - * @brief util functions for llm (refactored from main.cpp) - * @date 21 August 2024 - * @see https://github.com/nntrainer/nntrainer - * @author Seungbaek Hong - * @author Hyeonseok Lee - * @author Eunju Yang - * @bug No known bugs except for NYI items - */ - -#include - -std::vector generate_multi_tokens( - float *logits, unsigned int NUM_VOCAB, unsigned int NUM_TARGET_TOKENS, - float repetition_penalty, unsigned int *input_ids, unsigned int NUM_INPUT_IDS, - unsigned int *bad_words_ids, unsigned int NUM_BAD_WORDS_IDS) { - - std::vector outputs; - - // apply repetition penalty - if (repetition_penalty != 1 && input_ids != nullptr && NUM_INPUT_IDS != 0) { - applyRepetitionPenalty(logits, input_ids, NUM_INPUT_IDS, - repetition_penalty); - } - - // apply bad words penalty - if (bad_words_ids != nullptr && NUM_BAD_WORDS_IDS != 0) - applyBadWordsPenalty(logits, bad_words_ids, NUM_BAD_WORDS_IDS); - - // Sort and generate multiple tokens - std::vector> top_indices_and_logits; - for (unsigned int i = 0; i < NUM_VOCAB; ++i) { - top_indices_and_logits.push_back({i, logits[i]}); - } - std::partial_sort(top_indices_and_logits.begin(), - top_indices_and_logits.begin() + NUM_TARGET_TOKENS, - top_indices_and_logits.end(), - [](auto &a, auto &b) { return a.second > b.second; }); - - // add sampled words - for (unsigned int i = 0; i < NUM_TARGET_TOKENS; ++i) { - outputs.push_back(top_indices_and_logits[i].first); - } - - return outputs; -} - -void applyRepetitionPenalty(float *logits, unsigned int *input_ids, - unsigned int NUM_INPUT_IDS, - float repetition_penalty) { - for (unsigned int i = 0; i < NUM_INPUT_IDS; ++i) { - if (logits[input_ids[i]] < 0) { - logits[input_ids[i]] *= repetition_penalty; - } else { - logits[input_ids[i]] /= repetition_penalty; - } - } -} - -void applyBadWordsPenalty(float *logits, unsigned int *bad_words_ids, - unsigned int NUM_BAD_WORDS_IDS) { - for (unsigned int i = 0; i < NUM_BAD_WORDS_IDS; ++i) { - logits[bad_words_ids[i]] = -INFINITY; - } -} - -/** - * @brief Apply temperature & top-k & top-p to logits - * @return Max logit for softmax - */ -float applyTKP(float *logits, int len, float temperature, unsigned int top_k, - float top_p) { - - // Apply temperature & Sort logits - std::vector> top_indices_and_logits; - for (int i = 0; i < len; ++i) { - if (temperature > 1e-5) - logits[i] = logits[i] / temperature; - top_indices_and_logits.push_back({i, logits[i]}); - } - std::partial_sort(top_indices_and_logits.begin(), - top_indices_and_logits.begin() + 1, - top_indices_and_logits.end(), - [](auto &a, auto &b) { return a.second > b.second; }); - - return top_indices_and_logits[0].second; -} diff --git a/llm_util.hpp b/llm_util.hpp deleted file mode 100644 index 45fa419e..00000000 --- a/llm_util.hpp +++ /dev/null @@ -1,116 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -/** - * - * @file llm_util.hpp - * @brief util functions for llm (refactored from main.cpp) - * @date 21 August 2024 - * @see https://github.com/nntrainer/nntrainer - * @author Seungbaek Hong - * @author Hyeonseok Lee - * @author Eunju Yang - * @bug No known bugs except for NYI items - */ - -#ifndef __LLM_UTIL_HPP__ -#define __LLM_UTIL_HPP__ __LLM_UTIL_HPP__ - -#include // sort -#include // INFINITY -#include - -#include -#include -#include -#include -/***************** ALAIS *******************/ -using LayerHandle = std::shared_ptr; -using ModelHandle = std::unique_ptr; -using ml::train::createLayer; - -/****************** UTIL *******************/ -/** - * @brief util functio to make "key=value" from key and value - * - * @tparam T type of a value - * @param key key - * @param value value - * @return std::string with "key=value" - */ -template -static std::string withKey(const std::string &key, const T &value) { - std::stringstream ss; - ss << key << "=" << value; - return ss.str(); -} - -/** - * @brief util function to make "key=value1,value2, ..." from key and value - - * @tparam T type of a value - * @param key key - * @param value list of value - * @return std::string with "key=value1, value, ...." - */ -template -static std::string withKey(const std::string &key, - std::initializer_list value) { - if (std::empty(value)) { - throw std::invalid_argument("empty data cannot be converted"); - } - - std::stringstream ss; - ss << key << "="; - - auto iter = value.begin(); - for (; iter != value.end() - 1; ++iter) { - ss << *iter << ','; - } - ss << *iter; - - return ss.str(); -} - -/** - * @brief - */ -template -T unwrap(std::optional &&value, const std::string &error_msg) { - if (value.has_value()) { - return value.value(); - } else { - throw std::runtime_error(error_msg); - } -} - -/** - * @brief generate multi tokens from logits - * @note This function apply repetition penalty, bad words penalty, and sort to - * generate multiple tokens - */ -std::vector generate_multi_tokens( - float *logits, unsigned int NUM_VOCAB = 0, unsigned int NUM_TARGET_TOKENS = 1, - float repetition_penalty = 1, unsigned int *input_ids = nullptr, - unsigned int NUM_INPUT_IDS = 0, unsigned int *bad_words_ids = nullptr, - unsigned int NUM_BAD_WORDS_IDS = 0); - -/** - * @brief Apply repetition penalty to logits - */ -void applyRepetitionPenalty(float *logits, unsigned int *input_ids, - unsigned int NUM_INPUT_IDS, - float repetition_penalty = 1); - -/** - * @brief Apply bad words penalty - */ -void applyBadWordsPenalty(float *logits, unsigned int *bad_words_ids, - unsigned int NUM_BAD_WORDS_IDS); - -/** - * @brief Apply temperature & top-k & top-p to logits - * @return Max logit for softmax - */ -float applyTKP(float *logits, int len, float temperature, unsigned int top_k, - float top_p); - -#endif // __LLM_UTIL_HPP__ diff --git a/main.cpp b/main.cpp deleted file mode 100644 index 19646732..00000000 --- a/main.cpp +++ /dev/null @@ -1,330 +0,0 @@ -/** - * Copyright (C) 2025 Samsung Electronics Co., Ltd. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * http://www.apache.org/licenses/LICENSE-2.0 - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - * - * @file main.cpp - * @date 23 July 2025 - * @brief This is a main file for CausalLM application - * @see https://github.com/nnstreamer/ - * @author Eunju Yang - * @bug No known bugs except for NYI items - * - */ -#include -#include -#include -#include - -#include "json.hpp" -#include -#include - -#include "causal_lm.h" -#include "chat_template.h" -#include "embedding_gemma.h" -#include "gemma3_causallm.h" -#include "gptoss_cached_slim_causallm.h" -#include "gptoss_causallm.h" -#include "qwen2_causallm.h" -#include "qwen2_embedding.h" -#include "qwen3_cached_slim_moe_causallm.h" -#include "qwen3_causallm.h" -#include "qwen3_embedding.h" -#include "qwen3_moe_causallm.h" -#include "qwen3_slim_moe_causallm.h" -#include -#include - -#include -#include -#include -#include - -using json = nlohmann::json; - -std::atomic peak_rss_kb{0}; -std::atomic tracking_enabled{true}; - -void printMemoryUsage() { - struct rusage usage; - getrusage(RUSAGE_SELF, &usage); - std::cout << "Max Resident Set Size: " << usage.ru_maxrss << " KB" - << std::endl; -} - -size_t read_vm_rss_kb() { - std::ifstream status("/proc/self/status"); - std::string line; - while (std::getline(status, line)) { - if (line.find("VmRSS:") == 0) { - size_t kb = 0; - sscanf(line.c_str(), "VmRSS: %zu kB", &kb); - return kb; - } - } - return 0; -} - -size_t read_private_rss_kb() { - std::ifstream smaps("/proc/self/smaps_rollup"); - std::string line; - size_t total = 0; - while (std::getline(smaps, line)) { - if (line.find("Private_Clean:") == 0 || line.find("Private_Dirty:") == 0) { - size_t kb; - sscanf(line.c_str(), "%*s %zu", &kb); - total += kb; - } - } - return total; -} - -void start_peak_tracker() { - std::thread([] { - while (tracking_enabled.load()) { - size_t current = read_private_rss_kb(); - size_t prev = peak_rss_kb.load(); - if (current > prev) { - peak_rss_kb.store(current); - } - std::this_thread::sleep_for(std::chrono::milliseconds(10)); - } - }).detach(); -} - -void stop_and_print_peak() { - tracking_enabled.store(false); - std::this_thread::sleep_for(std::chrono::milliseconds(20)); - std::cout << "Peak memory usage (VmRSS): " << peak_rss_kb.load() << " KB" - << std::endl; -} - -std::string resolve_architecture(std::string model_type, - const std::string &architecture) { - std::transform(model_type.begin(), model_type.end(), model_type.begin(), - [](unsigned char c) { return std::tolower(c); }); - - if (model_type == "embedding") { - if (architecture == "Qwen3ForCausalLM") { - return "Qwen3Embedding"; - } else if (architecture == "Gemma3ForCausalLM" || - architecture == "Gemma3TextModel") { - return "EmbeddingGemma"; - } else if (architecture == "Qwen2Model") { - return "Qwen2Embedding"; - } else { - throw std::invalid_argument( - "Unsupported architecture for embedding model: " + architecture); - } - } - - return architecture; -} - -int main(int argc, char *argv[]) { - - auto start_time = std::chrono::high_resolution_clock::now(); - - /** Register all runnable causallm models to factory */ - quick_dot_ai::Factory::Instance().registerModel( - "LlamaForCausalLM", [](json cfg, json generation_cfg, json nntr_cfg) { - return std::make_unique(cfg, generation_cfg, - nntr_cfg); - }); - quick_dot_ai::Factory::Instance().registerModel( - "Qwen2ForCausalLM", [](json cfg, json generation_cfg, json nntr_cfg) { - return std::make_unique(cfg, generation_cfg, - nntr_cfg); - }); - quick_dot_ai::Factory::Instance().registerModel( - "Qwen2Embedding", [](json cfg, json generation_cfg, json nntr_cfg) { - return std::make_unique(cfg, generation_cfg, - nntr_cfg); - }); - quick_dot_ai::Factory::Instance().registerModel( - "Qwen3ForCausalLM", [](json cfg, json generation_cfg, json nntr_cfg) { - return std::make_unique(cfg, generation_cfg, - nntr_cfg); - }); - quick_dot_ai::Factory::Instance().registerModel( - "Qwen3MoeForCausalLM", [](json cfg, json generation_cfg, json nntr_cfg) { - return std::make_unique(cfg, generation_cfg, - nntr_cfg); - }); - quick_dot_ai::Factory::Instance().registerModel( - "Qwen3SlimMoeForCausalLM", - [](json cfg, json generation_cfg, json nntr_cfg) { - return std::make_unique( - cfg, generation_cfg, nntr_cfg); - }); - quick_dot_ai::Factory::Instance().registerModel( - "Qwen3CachedSlimMoeForCausalLM", - [](json cfg, json generation_cfg, json nntr_cfg) { - return std::make_unique( - cfg, generation_cfg, nntr_cfg); - }); - quick_dot_ai::Factory::Instance().registerModel( - "Qwen3Embedding", [](json cfg, json generation_cfg, json nntr_cfg) { - return std::make_unique(cfg, generation_cfg, - nntr_cfg); - }); - quick_dot_ai::Factory::Instance().registerModel( - "GptOssForCausalLM", [](json cfg, json generation_cfg, json nntr_cfg) { - return std::make_unique(cfg, generation_cfg, - nntr_cfg); - }); - quick_dot_ai::Factory::Instance().registerModel( - "GptOssCachedSlimCausalLM", - [](json cfg, json generation_cfg, json nntr_cfg) { - return std::make_unique( - cfg, generation_cfg, nntr_cfg); - }); - quick_dot_ai::Factory::Instance().registerModel( - "Gemma3ForCausalLM", [](json cfg, json generation_cfg, json nntr_cfg) { - return std::make_unique(cfg, generation_cfg, - nntr_cfg); - }); - quick_dot_ai::Factory::Instance().registerModel( - "EmbeddingGemma", [](json cfg, json generation_cfg, json nntr_cfg) { - return std::make_unique(cfg, generation_cfg, - nntr_cfg); - }); - - // Validate arguments - if (argc < 2) { - std::cerr << "Usage: " << argv[0] << " [input_prompt]\n" - << " : Path to model directory\n" - << " [input_prompt] : Optional input text (uses sample_input or " - "chat_input if omitted)\n"; - return EXIT_FAILURE; - } - - const std::string model_path = argv[1]; - std::string input_text; - std::string system_head_prompt = ""; - std::string system_tail_prompt = ""; - - std::cout << model_path << std::endl; - - try { - // Load configuration files - json cfg = quick_dot_ai::LoadJsonFile(model_path + "/config.json"); - json generation_cfg = - quick_dot_ai::LoadJsonFile(model_path + "/generation_config.json"); - json nntr_cfg = quick_dot_ai::LoadJsonFile(model_path + "/nntr_config.json"); - - if (nntr_cfg.contains("system_prompt")) { - system_head_prompt = - nntr_cfg["system_prompt"]["head_prompt"].get(); - system_tail_prompt = - nntr_cfg["system_prompt"]["tail_prompt"].get(); - } - - // Construct weight file path - const std::string weight_file = - model_path + "/" + nntr_cfg["model_file_name"].get(); - - std::cout << weight_file << std::endl; - - // Initialize and run model - std::string architecture = - cfg["architectures"].get>()[0]; - - if (nntr_cfg.contains("model_type")) { - std::string model_type = nntr_cfg["model_type"].get(); - architecture = resolve_architecture(model_type, architecture); - } - - // Load chat template from tokenizer_config.json (if available) - quick_dot_ai::ChatTemplate chat_tmpl; - std::string tokenizer_config_path = model_path + "/tokenizer_config.json"; - if (std::filesystem::exists(tokenizer_config_path)) { - chat_tmpl = quick_dot_ai::ChatTemplate::fromFile(tokenizer_config_path); - if (chat_tmpl.isAvailable()) { - std::cout << "[Info] Chat template loaded from tokenizer_config.json" - << std::endl; - } else { - std::cerr - << "[Warning] tokenizer_config.json found but chat template could " - "not be loaded. Chat formatting will not be applied to raw input." - << std::endl; - } - } else { - std::cerr - << "[Warning] tokenizer_config.json not found in " << model_path - << ". Chat template will not be available for raw input formatting." - << std::endl; - } - - // Determine input text - if (argc >= 3) { - input_text = argv[2]; - // Apply chat template to raw user input if available - if (chat_tmpl.isAvailable()) { - input_text = chat_tmpl.apply(input_text); - } - } else { - if (nntr_cfg.contains("chat_input")) { - if (architecture == "Gemma3ForCausalLM") { - input_text = quick_dot_ai::gemma3::apply_function_gemma_template( - nntr_cfg["chat_input"]); - } else { - std::cerr << "[Warning] 'chat_input' is set but support for model " - "architecture '" - << architecture - << "' is not implemented. Falling back to 'sample_input'." - << std::endl; - input_text = nntr_cfg["sample_input"].get(); - } - } else { - input_text = nntr_cfg["sample_input"].get(); - } - } - - auto model = quick_dot_ai::Factory::Instance().create(architecture, cfg, - generation_cfg, nntr_cfg); - if (!model) { - std::cerr << "Unknown architecture: " << architecture << std::endl; - std::cerr << "Registered architectures:"; - quick_dot_ai::Factory::Instance().printRegistered(std::cerr); - std::cerr << std::endl; - return EXIT_FAILURE; - } - model->initialize(); - model->load_weight(weight_file); - -#ifdef PROFILE - start_peak_tracker(); -#endif -#if defined(_WIN32) - model->run(input_text.c_str(), system_head_prompt.c_str(), - system_tail_prompt.c_str()); -#else - model->run(input_text, system_head_prompt, system_tail_prompt); -#endif -#ifdef PROFILE - stop_and_print_peak(); -#endif - auto finish_time = std::chrono::high_resolution_clock::now(); - auto e2e_duration = std::chrono::duration_cast( - finish_time - start_time); - std::cout << "[e2e time]: " << e2e_duration.count() << " ms \n"; - printMemoryUsage(); - - } catch (const std::exception &e) { - std::cerr << "\n[!] FATAL ERROR: " << e.what() << "\n"; - return EXIT_FAILURE; - } - - return EXIT_SUCCESS; -} diff --git a/meson.build b/meson.build index 1ff38b5d..ebf25adf 100644 --- a/meson.build +++ b/meson.build @@ -1,138 +1,170 @@ -project('quick_dot_ai', 'c', 'cpp', - version: '0.1.0', +project('quick-dot-ai', 'c', 'cpp', + version: '0.2.0', license: ['apache-2.0'], meson_version: '>=0.55.0', default_options: [ + 'werror=false', 'warning_level=1', 'c_std=gnu89', 'cpp_std=c++17', - 'buildtype=release', - ], -) - -# Pull in nntrainer via the bundled submodule as a meson subproject. -# We disable Applications so that nntrainer's own (now legacy) in-tree -# Applications/CausalLM is not built alongside this standalone one, and -# we skip tests to keep the dependency build lean. -nntrainer_proj = subproject('nntrainer', - default_options: [ - 'enable-app=false', - 'enable-test=false', - 'enable-tflite-backbone=false', - 'enable-tflite-interpreter=false', - 'werror=false', - ], + 'buildtype=release' + ] ) -nntrainer_dep = nntrainer_proj.get_variable('nntrainer_dep') -nntrainer_ccapi_dep = nntrainer_proj.get_variable('nntrainer_ccapi_dep') -application_install_dir = get_option('prefix') / get_option('libdir') / 'quick_dot_ai' - -openmp_dep = dependency('openmp') -powershell_prog = find_program('powershell', required: (build_machine.system() == 'windows')) - -quick_dot_ai_src = [ - meson.current_source_dir() / 'chat_template.cpp', - meson.current_source_dir() / 'huggingface_tokenizer.cpp', - meson.current_source_dir() / 'llm_util.cpp', - meson.current_source_dir() / 'api' / 'causal_lm_api.cpp', - meson.current_source_dir() / 'api' / 'model_config.cpp', -] - -executable_src = [ - meson.current_source_dir() / 'main.cpp', -] - -quick_dot_ai_inc_abs = [meson.current_source_dir()] -quick_dot_ai_inc = [include_directories('.')] - -# Build layer dependency -quick_dot_ai_layer_dependencies = [] -subdir('layers') - -# Add common layers to dependencies -quick_dot_ai_layer_dependencies += [ - quick_dot_ai_rms_norm_dep, - quick_dot_ai_tie_word_embedding_dep, - quick_dot_ai_lm_head_dep, - quick_dot_ai_swiglu_dep, - quick_dot_ai_mha_core_dep, - quick_dot_ai_embedding_layer_dep, - quick_dot_ai_lm_head_dep, - quick_dot_ai_reshaped_rms_norm_dep, - quick_dot_ai_qkv_layer_dep, - quick_dot_ai_embedding_pooling_layer_dep, - quick_dot_ai_embedding_normalize_layer_dep, -] - -subdir('models') - -if (get_option('platform') == 'windows') and (build_machine.system() == 'windows') - run_command(powershell_prog, '-ExecutionPolicy', 'Bypass', '-File', join_paths(meson.current_source_dir(), 'jni', 'prepare_encoder.ps1'), meson.build_root(), '0.2', check: true) -elif get_option('platform') != 'tizen' - run_command([meson.current_source_dir() / 'jni' / 'prepare_encoder.sh', meson.build_root(), '0.2'], check: true) +# ── Compiler setup ────────────────────────────────────────────────────── +cc = meson.get_compiler('c') +cxx = meson.get_compiler('cpp') + +# ── Platform detection ────────────────────────────────────────────────── +_platform_opt = get_option('platform') +if _platform_opt == 'auto' + is_android = host_machine.system() == 'android' +elif _platform_opt == 'android' + is_android = true +else + is_android = false endif -tokenizer_dependencies = [] -tokenizer_path = meson.current_source_dir() / 'lib' / 'libtokenizers_c.a' - -cpp_args = [] - -if get_option('default_library') == 'shared' - cpp_args += [ '-DPLUGGABLE' ] +# ── QNN guard: android only ───────────────────────────────────────────── +enable_qnn = get_option('enable-qnn') +if enable_qnn and not is_android + warning('enable-qnn is only supported on android. Disabling QNN.') + enable_qnn = false endif -quick_dot_ai = shared_library('quick_dot_ai', - quick_dot_ai_src, - dependencies: [ - nntrainer_dep, - nntrainer_ccapi_dep, - quick_dot_ai_layer_dependencies, - tokenizer_dependencies, - openmp_dep - ], - include_directories: quick_dot_ai_inc, - install: true, - install_dir: application_install_dir, - link_args: [tokenizer_path], - cpp_args: cpp_args -) - -quick_dot_ai_dep = declare_dependency( - link_with: [quick_dot_ai], - include_directories: quick_dot_ai_inc -) +# ── Path configuration ───────────────────────────────────────────────── +project_root = meson.current_source_dir() +nntrainer_root = project_root / 'nntrainer' +causallm_root = nntrainer_root / 'Applications' / 'CausalLM' + +# ── Find pre-built nntrainer libraries ────────────────────────────────── +if is_android + _nntr_result = nntrainer_root / 'builddir' / 'android_build_result' + _nntr_libdir = _nntr_result / 'lib' / 'arm64-v8a' + _nntr_incdir = _nntr_result / 'include' / 'nntrainer' + + libnntrainer = cxx.find_library('nntrainer', dirs: [_nntr_libdir]) + libccapi = cxx.find_library('ccapi-nntrainer', dirs: [_nntr_libdir]) + + nntrainer_inc_args = ['-I' + _nntr_incdir] +else + _nntr_builddir = nntrainer_root / get_option('nntrainer_builddir') + + libnntrainer = cxx.find_library('nntrainer', + dirs: [_nntr_builddir / 'nntrainer']) + libccapi = cxx.find_library('ccapi-nntrainer', + dirs: [_nntr_builddir / 'api' / 'ccapi']) + + # nntrainer source-tree includes (x86 build) + # NOTE: Do NOT include the builddir itself β€” ruy/time.h shadows system + nntrainer_inc_args = [] + foreach p : [ + nntrainer_root / 'api' / 'ccapi' / 'include', + nntrainer_root / 'api', + nntrainer_root / 'nntrainer', + nntrainer_root / 'nntrainer' / 'layers', + nntrainer_root / 'nntrainer' / 'layers' / 'loss', + nntrainer_root / 'nntrainer' / 'models', + nntrainer_root / 'nntrainer' / 'graph', + nntrainer_root / 'nntrainer' / 'compiler', + nntrainer_root / 'nntrainer' / 'optimizers', + nntrainer_root / 'nntrainer' / 'tensor', + nntrainer_root / 'nntrainer' / 'tensor' / 'cpu_backend', + nntrainer_root / 'nntrainer' / 'tensor' / 'cpu_backend' / 'fallback', + nntrainer_root / 'nntrainer' / 'tensor' / 'cpu_backend' / 'cblas_interface', + nntrainer_root / 'nntrainer' / 'tensor' / 'cpu_backend' / 'x86', + nntrainer_root / 'nntrainer' / 'tensor' / 'cpu_backend' / 'ggml_interface', + nntrainer_root / 'nntrainer' / 'tensor' / 'cpu_backend' / 'ggml_interface' / 'nntr_ggml_impl', + nntrainer_root / 'nntrainer' / 'utils', + nntrainer_root / 'nntrainer' / 'dataset', + nntrainer_root / 'nntrainer' / 'schema', + ] + nntrainer_inc_args += ['-I' + p] + endforeach +endif -e = executable('quick_dot_ai_run', - executable_src, - include_directories: quick_dot_ai_inc, - dependencies: [nntrainer_dep, nntrainer_ccapi_dep, quick_dot_ai_layer_dependencies, quick_dot_ai_dep], +nntrainer_dep = declare_dependency( + dependencies: [libnntrainer, libccapi], + compile_args: nntrainer_inc_args, ) +# ── OpenMP / threads ──────────────────────────────────────────────────── +openmp_dep = dependency('openmp', required: false) +thread_dep = dependency('threads') -quantize_src = [ - meson.current_source_dir() / 'quantize.cpp', -] - -e_quantize = executable('quick_dot_ai_quantize', - quantize_src, - include_directories: quick_dot_ai_inc, - dependencies: [nntrainer_dep, nntrainer_ccapi_dep, quick_dot_ai_layer_dependencies, quick_dot_ai_dep], -) +# ── Android log library ───────────────────────────────────────────────── +if is_android + log_dep = cxx.find_library('log', required: true) +else + log_dep = dependency('', required: false) +endif +# ── CausalLM includes ────────────────────────────────────────────────── +causallm_auto_inc_dirs = run_command('find', + causallm_root / 'layers', + causallm_root / 'models', + causallm_root / 'res', + '-type', 'd', + check: true).stdout().strip().split('\n') + +causallm_inc_args = ['-I' + causallm_root] +foreach d : causallm_auto_inc_dirs + causallm_inc_args += ['-I' + d] +endforeach + +# Add minja include path for chat template support +causallm_inc_args += ['-I' + causallm_root / 'third_party' / 'minja' / 'include'] + +# ── XGrammar ────────────────────────────────────────────────── +causallm_inc_args += ['-I' + project_root / 'src' / 'xgrammar'] +# Add xgrammar library include paths +causallm_inc_args += ['-I' + project_root / 'xgrammar' / 'include'] +causallm_inc_args += ['-I' + project_root / 'xgrammar' / '3rdparty' / 'picojson'] +causallm_inc_args += ['-I' + project_root / 'xgrammar' / '3rdparty' / 'dlpack' / 'include'] + +# ── XGrammar ────────────────────────────────────────────────── +causallm_inc_args += ['-I' + project_root / 'src' / 'xgrammar'] +# Add xgrammar library include paths +causallm_inc_args += ['-I' + project_root / 'xgrammar' / 'include'] +causallm_inc_args += ['-I' + project_root / 'xgrammar' / '3rdparty' / 'picojson'] +causallm_inc_args += ['-I' + project_root / 'xgrammar' / '3rdparty' / 'dlpack' / 'include'] + +# ── Tokenizer: platform-specific ──────────────────────────────────────── +if is_android + tokenizer_path = causallm_root / 'lib' / 'libtokenizers_android_c.a' +else + tokenizer_path = causallm_root / 'lib' / 'libtokenizers_c.a' +endif -test_api_src = [ - meson.current_source_dir() / 'api' / 'test_api.cpp', -] +# ── Extra defines ─────────────────────────────────────────────────────── +extra_defines = [] +if enable_qnn + extra_defines += '-DENABLE_QNN=1' +endif -test_api = executable('quick_dot_ai_test_api', - test_api_src, - include_directories: quick_dot_ai_inc, - dependencies: [nntrainer_dep, nntrainer_ccapi_dep, quick_dot_ai_layer_dependencies, quick_dot_ai_dep], - build_by_default: true, - install: false, -) +# FP16 is ARM-only (__fp16 type not available on x86) +if get_option('enable-fp16') and is_android + extra_defines += ['-DENABLE_FP16=1', '-DUSE__FP16=1'] +endif -if get_option('enable-test') - subdir('test') +foreach d : extra_defines + add_project_arguments(d, language: ['c', 'cpp']) +endforeach + +# Android-specific ARM flags +if is_android + _arm_args = [ + '-D__ARM_NEON__=1', + '-DUSE_NEON=1', + ] + foreach a : _arm_args + add_project_arguments(a, language: ['c', 'cpp']) + endforeach endif + +# ── Build components ──────────────────────────────────────────────────── +# Each subdir guards itself with its own enable option. +subdir('src') +subdir('qnn') +subdir('api') +subdir('api-app') diff --git a/meson_options.txt b/meson_options.txt index 4d42f1d8..3e0778de 100644 --- a/meson_options.txt +++ b/meson_options.txt @@ -1,3 +1,12 @@ -option('platform', type: 'combo', choices: ['none', 'tizen', 'yocto', 'android', 'windows'], value: 'none') -option('enable-test', type: 'boolean', value: false, - description: 'Build Quick.AI GoogleTest unit tests under test/.') +option('platform', type: 'combo', choices: ['auto', 'x86', 'android'], value: 'auto', + description: 'Target platform (auto detects from cross file)') +option('nntrainer_builddir', type: 'string', value: 'builddir_x86', + description: 'nntrainer build directory name (x86 only)') +option('enable-qnn', type: 'boolean', value: false, + description: 'Build QNN integration - android only (qnn_context lib + qnn-transformer model)') +option('enable-fp16', type: 'boolean', value: true, + description: 'Enable FP16 support') +option('enable-api', type: 'boolean', value: false, + description: 'Build libquick_dot_ai_api.so') +option('enable-api-test', type: 'boolean', value: false, + description: 'Build quick_dot_ai_test executable') \ No newline at end of file diff --git a/models/README.md b/models/README.md deleted file mode 100644 index df30321d..00000000 --- a/models/README.md +++ /dev/null @@ -1,43 +0,0 @@ -# CausalLM Models - -This directory contains the implementations of the Causal Language Models structured into subdirectories. - -## Base Model -- **`causal_lm.cpp/h`**: The base class for all CausalLM implementations, defining the core architecture (Embedding, Decoder Blocks, RMSNorm, LMHead). - -## Available Models - - -Here is the list of supported models. We provide **Standard** implementations and **NNTrainer Variants** optimized for on-device environments. - -| Model Name | Size | Type | Special Features | Description | -| :--- | :---: | :---: | :--- | :--- | -| `causal_lm` | - | Standard | - | Basic implementation of the llama model. | -| `qwen3_causallm` | **0.6B, 1.7B, 4B, 8B, 14B, 32B** | Standard | - | Basic implementation of the Qwen3 model. | -| `qwen3_moe_causallm` | **30B-A3B** | Standard | - | Basic implementation of the Qwen3 MoE model. | -| `qwen3_slim_moe_causallm` | **30B-A3B** | **Variant** | **Slim** | Activated by FSU scheme (On-the-fly expert loading). | -| `qwen3_cached_slim_moe_causallm` | **30B-A3B** | **Variant** | **Cached Slim** | MoE-specific FSU implementation with **expert caching**. | -| `gptoss_causallm` | **20B-A3.6B, 120B-5.1B** | Standard | - | Basic implementation of the GPT-OSS model. | -| `gptoss_cached_slim_causallm` | **20B-A3.6B, 120B-5.1B** | **Variant** | **Cached Slim** | GPT-OSS MoE implementation with **expert caching**. | - -> *Note: **Standard** refers to the basic implementation, while **Variant** refers to models optimized for your device using FSU schemes.* - -### MoE inference support - -#### What is a `slim` model? -The *_slim_* model reduces peak memory usage by loading experts in an on-the-fly manner. - -- **Efficient Initialization**: Instead of loading all model weights at once, the slim model initializes without the heavy expert layers. -- **Dynamic Loading**: Only the activated experts are loaded into memory during runtime, keeping memory usage significantly lower than the original model. -- **Performance Note**: Since the model dynamically maps memory to experts on storage, inference speed relies heavily on the storage read I/O speed. - -#### What is a `cached` model? - -The cached model is a variant of the slim model that caches activated experts. Instead of immediately deactivating experts after use, it delays memory unmapping. This approach reduces repetitive loading overhead, thereby increasing inference speed. - -## Directory Structure -Each model directory typically contains: -- `*_causallm.cpp`: The model implementation class. -- `*_layer.cpp`: (Optional) Model-specific custom layer implementations. -- `meson.build`: Build configuration for the model. -- `README.md`: Specific details about the model. diff --git a/models/causal_lm.cpp b/models/causal_lm.cpp deleted file mode 100644 index d09d1b8f..00000000 --- a/models/causal_lm.cpp +++ /dev/null @@ -1,594 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -/** - * Copyright (C) 2025 Jijoong Moon - * Copyright (C) 2025 Seungback Hong - * Copyright (C) 2025 Hyeonseok Lee - * Copyright (C) 2025 Eunju Yang - * - * @file causal_lm.cpp - * @date 10 July 2025 - * @see https://github.com/nntrainer/nntrainer - * @author Jijoong Moon - * @author Seungbaek Hong - * @author Hyeonseok Lee - * @author Eunju Yang - * @bug No known bugs except for NYI items - * @brief This file defines CausalLM's basic actions - * @note This causal_lm.h constructs a class for Transformer-based Causal - * Language Model (CausalLM). It aims to support AutoModelForCausalLM with - * nntrainer. It supports the following models: - * - Llama - */ - -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include - -#include -#include - -namespace quick_dot_ai { - -CausalLM::CausalLM(json &cfg, json &generation_cfg, json &nntr_cfg) : - Transformer(cfg, generation_cfg, nntr_cfg, ModelType::CAUSALLM) { - setupParameters(cfg, generation_cfg, nntr_cfg); -} - -void CausalLM::setupParameters(json &cfg, json &generation_cfg, - json &nntr_cfg) { - // Initialize output list - for (unsigned int i = 0; i < BATCH_SIZE; ++i) - output_list.push_back(""); - - // allocate memory for the internal buffer - ids_history = (unsigned int *)malloc(static_cast(BATCH_SIZE) * - MAX_SEQ_LEN * sizeof(unsigned int)); - - BAD_WORD_IDS = nntr_cfg["bad_word_ids"].get>(); - NUM_BADWORDS = BAD_WORD_IDS.size(); - - LMHEAD_DTYPE = nntr_cfg.contains("lmhead_dtype") - ? nntr_cfg["lmhead_dtype"] - : nntr_cfg["embedding_dtype"]; - - USE_KVCACHE = false; - PRE_COMPUTED_CACHE_PATH = ""; - SYS_PROMP_LEN = 0; - - if (nntr_cfg.contains("system_prompt") && - nntr_cfg["system_prompt"].contains("kvcache")) { - USE_KVCACHE = true; - PRE_COMPUTED_CACHE_PATH = - nntr_cfg["system_prompt"]["kvcache"]["pre_computed_cache_path"]; - if (nntr_cfg["system_prompt"]["kvcache"].contains("sys_prompt_token_size")) - SYS_PROMP_LEN = - nntr_cfg["system_prompt"]["kvcache"]["sys_prompt_token_size"] - .get(); - } - - if (generation_cfg["eos_token_id"].is_array()) { - EOS_TOKEN_ID = - generation_cfg["eos_token_id"].empty() - ? cfg["eos_token_id"].get>() - : generation_cfg["eos_token_id"].get>(); - } else { - EOS_TOKEN_ID.clear(); - EOS_TOKEN_ID.push_back(generation_cfg["eos_token_id"].get()); - } - BOS_TOKEN_ID = generation_cfg["bos_token_id"].empty() - ? cfg["bos_token_id"].get() - : generation_cfg["bos_token_id"].get(); - TOP_K = generation_cfg.contains("top_k") - ? generation_cfg["top_k"].get() - : 20; - TOP_P = generation_cfg.contains("top_p") - ? generation_cfg["top_p"].get() - : 0.95; - TEMPERATURE = generation_cfg.contains("temperature") - ? generation_cfg["temperature"].get() - : 0.7; - DO_SAMPLE = generation_cfg.contains("do_sample") - ? generation_cfg["do_sample"].get() - : false; - global_token_len = 0; -} - -void CausalLM::constructModel() { - - // It adds all transformer model's block to model - Transformer::constructModel(); - - const std::string lmhead_type = - TIE_WORD_EMBEDDINGS ? "tie_word_embeddings" : "lm_head"; - - // add lmhead - std::vector lmhead_prop = { - withKey("name", "output_of_causallm"), - withKey("unit", NUM_VOCAB), - withKey("disable_bias", "true"), - withKey("input_layers", "output_norm"), - withKey("weight_dtype", LMHEAD_DTYPE), - }; - - if (TIE_WORD_EMBEDDINGS) - lmhead_prop.emplace_back(withKey("shared_from", "embedding0")); - - model->addLayer(createLayer(lmhead_type, lmhead_prop)); -} - -void CausalLM::registerOutputs( - std::unique_ptr &tokenizer, - std::vector ids, unsigned int pos, - const std::vector &eos_list, bool log_output) { - - static const std::vector puncts{',', '!', ':', ';', '?'}; - for (size_t b = 0; b < ids.size(); ++b) { - if (!eos_list[b]) { - pending_ids_.push_back(static_cast(ids[b])); - ids_history[b * MAX_SEQ_LEN + pos] = ids[b]; - std::string decoded_str = tokenizer->Decode(pending_ids_); - - if (std::find(puncts.begin(), puncts.end(), decoded_str.back()) != - puncts.end()) { - // last symbol is a punctuation, hold on - } else if (decoded_str.size() >= 3 && - decoded_str.compare(decoded_str.size() - 3, 3, "") == 0) { - // ends with an incomplete token, hold on - } else { - if (log_output) { -#if defined(_WIN32) - std::wcout << L"" << utf8_to_wstring(decoded_str); - std::wcout.flush(); -#else - std::cout << decoded_str; - std::cout.flush(); -#endif - } - output_list[b].append(decoded_str); - pending_ids_.clear(); - } - } - } -} - -void CausalLM::save_kvcache(std::string path, int to_) { - auto f = nntrainer::checkedOpenStream( - path, std::ios::out | std::ios::binary | std::ios::trunc); - - std::function - fn = [&f](ml::train::Layer &l, nntrainer::RunLayerContext &context, - void *idx) { - if (l.getType() == quick_dot_ai::MHACoreLayer::type) { - int to = static_cast(reinterpret_cast(idx)); - auto k_cache = context.getTensor(0); - auto v_cache = context.getTensor(1); - ml::train::TensorDim k_dim = k_cache.getDim(); - ml::train::TensorDim v_dim = v_cache.getDim(); - k_dim.height(to); - v_dim.height(to); - nntrainer::Tensor k_cache_prompt = - k_cache.getSharedDataTensor(k_dim, 0, true); - nntrainer::Tensor v_cache_prompt = - v_cache.getSharedDataTensor(v_dim, 0, true); - k_cache_prompt.save(f); - v_cache_prompt.save(f); - } - }; - void *arg = reinterpret_cast(static_cast(to_)); - model->forEachLayer(fn, arg); - f.close(); -} - -void CausalLM::load_kvcache(std::string path, int to_) { - auto f = nntrainer::checkedOpenStream( - path, std::ios::in | std::ios::binary); - - model->allocate(ml::train::ExecutionMode::INFERENCE); - - std::function - fn = [&f](ml::train::Layer &l, nntrainer::RunLayerContext &context, - void *idx) { - if (l.getType() == quick_dot_ai::MHACoreLayer::type) { - auto k_cache = context.getTensor(0); - auto v_cache = context.getTensor(1); - int to = static_cast(reinterpret_cast(idx)); - ml::train::TensorDim k_dim = k_cache.getDim(); - ml::train::TensorDim v_dim = v_cache.getDim(); - k_dim.height(to); - v_dim.height(to); - nntrainer::Tensor k_cache_prompt = - k_cache.getSharedDataTensor(k_dim, 0, true); - nntrainer::Tensor v_cache_prompt = - v_cache.getSharedDataTensor(v_dim, 0, true); - k_cache_prompt.read(f); - v_cache_prompt.read(f); - } - }; - void *arg = reinterpret_cast(static_cast(to_)); - model->forEachLayer(fn, arg); - f.close(); -} - -std::vector CausalLM::generate(float *logits, bool do_sample, - float repetition_penalty, - unsigned int *input_ids, - unsigned int NUM_INPUT_IDS) { - - std::vector outputs; - for (unsigned int iteration = 0; iteration < BATCH_SIZE; ++iteration) { - - // apply repetition penalty - if (repetition_penalty != 1 && input_ids != nullptr && NUM_INPUT_IDS != 0) { - applyRepetitionPenalty(logits, input_ids, NUM_INPUT_IDS, - repetition_penalty); - } - - // apply bad words penalty - if (BAD_WORD_IDS.size() != 0 && NUM_BADWORDS != 0) { - applyBadWordsPenalty(logits, BAD_WORD_IDS.data(), NUM_BADWORDS); - } - - // return argmax if do_sample is false - if (do_sample == false) { - unsigned int argmax_idx = - std::distance(logits, std::max_element(logits, logits + NUM_VOCAB)); - outputs.push_back(argmax_idx); - } else { - // apply temperature & top-k & top-p to logits - float max_logits = applyTKP(logits, NUM_VOCAB, TEMPERATURE, TOP_K, TOP_P); - // transform logits to softmax - float sum_exp_logits = 0; - for (unsigned int i = 0; i < NUM_VOCAB; i++) { - float exp_x = exp(logits[i] - max_logits); - sum_exp_logits += exp_x; - logits[i] = exp_x; - } - - for (unsigned int i = 0; i < NUM_VOCAB; ++i) { - logits[i] /= sum_exp_logits; - } - - // sample from final logits - std::discrete_distribution dist(logits, logits + NUM_VOCAB); - unsigned int sampled_idx = dist(rng); - - // add sampled word - outputs.push_back(sampled_idx); - } - - // set batch offset - logits = logits + NUM_VOCAB; - input_ids = input_ids + MAX_SEQ_LEN; - } - - return outputs; -}; - -void CausalLM::registerCustomLayers() { - Transformer::registerCustomLayers(); - const auto &ct_engine = nntrainer::Engine::Global(); - const auto app_context = - static_cast(ct_engine.getRegisteredContext("cpu")); - try { - app_context->registerFactory(nntrainer::createLayer); - } catch (std::invalid_argument &e) { - std::cerr << "failed to register factory, reason: " << e.what() - << std::endl; - } -} - -void CausalLM::run(const WSTR prompt, void *output_buf, bool log_output) { - run(prompt, "", "", output_buf, log_output); -} - -void CausalLM::run(const WSTR prompt, const WSTR system_prompt, - const WSTR tail_prompt, void *output_buf, bool log_output) { - - auto start_total = std::chrono::high_resolution_clock::now(); - if (!is_initialized) { - throw std::runtime_error("CausalLM model is not initialized. Please call " - "initialize() before run()."); - } - - has_run_ = false; - - output_list.clear(); - for (unsigned int b = 0; b < BATCH_SIZE; ++b) { - output_list.push_back(""); - } - - if (MAX_SEQ_LEN < INIT_SEQ_LEN) { - throw std::invalid_argument( - "MAX_SEQ_LEN must be greater than or equal to INIT_SEQ_LEN"); - } - - /** - * Variables for Log - */ - unsigned int generation_cnt = 0; - int64_t total_generation_duration = 0; - - /** - * INPUT PREPARATION - */ - std::vector input; - std::vector label; - - /** - * SAVE_KVCACHE ? - * if USE_KVCACHE && system_prompt is given && but the - * PRE_COMPUTED_CACHE_PATH does not exist - */ - SAVE_KVCACHE = (USE_KVCACHE && system_prompt != "" && - !std::filesystem::exists(PRE_COMPUTED_CACHE_PATH)); - -#if defined(_WIN32) - if (log_output) - std::wcout << L"" << system_prompt << L"" << text_ << std::endl; - std::wstring prompt_ = prompt; - if (!SAVE_KVCACHE) - prompt_ += TAIL_PROMPT; - std::wstring_convert> converter; - auto _input = tokenizer->Encode(converter.to_bytes(prompt_)); -#else - // print input text - if (log_output) - std::cout << system_prompt << prompt << tail_prompt << std::endl; - - // actual prompt to be used in computation - std::string prompt_; - - if (USE_KVCACHE) { - prompt_ = SAVE_KVCACHE ? system_prompt : (prompt + tail_prompt); - } else { - prompt_ = system_prompt + prompt + tail_prompt; - } - - if (USE_KVCACHE && !SAVE_KVCACHE && SYS_PROMP_LEN == 0) - SYS_PROMP_LEN = tokenizer->Encode(system_prompt).size(); - - auto _input = tokenizer->Encode(prompt_); - ///@note insert bos token at the beginning of the input - // _input.insert(_input.begin(), BOS_TOKEN_ID); -#endif - - // | <------------------- MAX_SEQ_LEN -------------------> | - // || || - // |<-- System prompt -->||<-- input -->||<-- generate -->| - - std::vector init_input; - unsigned int _len = _input.size(); - unsigned int num_allow_str = MAX_SEQ_LEN - NUM_TO_GENERATE; - unsigned text_len = _len; - - if (_len > num_allow_str) - text_len = num_allow_str; - - // feed only available length - // if _input is allowed, it feeds all of the _input - // otherwise, feeds only a part of _input - for (unsigned int i = 0; i < text_len; ++i) - init_input.push_back(_input[i]); - - ///@todo currently, the whole sequence may not be fed into the model - /// This should be handled later. - _input.clear(); - - unsigned int init_len = init_input.size(); - float *input_sample = - (float *)malloc(sizeof(float) * BATCH_SIZE * MAX_SEQ_LEN); - std::vector eos_list(BATCH_SIZE, false); - - unsigned int input_len = init_len; - unsigned int token_generation_idx = input_len + 1; - - for (unsigned int b = 0; b < BATCH_SIZE; ++b) { - for (unsigned int i = 0; i < input_len; ++i) { - input_sample[static_cast(b) * MAX_SEQ_LEN + i] = - static_cast(init_input[i]); - ids_history[static_cast(b) * MAX_SEQ_LEN + i] = init_input[i]; - } - } - - /** - * PREFILL - */ - std::vector token_ids; - input.push_back(input_sample); - - ///@note contains possible bug - // std::vector input_dims; - // ml::train::TensorDim input_dim(1, 1, input_len, DIM); - // input_dims.push_back(input_dim); - // model->resetInputDimension(input_dims); - - auto start_prefill = std::chrono::high_resolution_clock::now(); - - std::vector output; - - if (SAVE_KVCACHE) { - //@note This is for the save the kv cache. precomputed kv cache should be - // always located at the begining of the prompt. - // Therefore, it start from 0. and system prompt should be saved in the - // init_input, so that we can compute system prompt size properly - // - // The structure of this precomputed K,V Cache is : - // - // //<-- System Prompt -->/<-- Input Tokens -->/<-- Tail prompt --> // - // //< Precomputed cache >/<--given as input-->/<--- from json ---->// - // - - if (log_output) - std::cout << "\n==============[KV CACHE SAVE MODE]================\n"; - output = model->incremental_inference(BATCH_SIZE, input, label, input_len, - 0 + global_token_len, - input_len + global_token_len, false); - - SYS_PROMP_LEN = input_len; - save_kvcache(PRE_COMPUTED_CACHE_PATH, SYS_PROMP_LEN); - - if (log_output) { - - std::cout - << "kv caches are saved in " << PRE_COMPUTED_CACHE_PATH << std::endl - << "and the size of prompt is " << SYS_PROMP_LEN << ".\n" - << "You may need this prompt lenth to set the \"sys_prompt_token_size\"" - << "\n==================================================\n" - << std::endl; - } - return; - } - - if (USE_KVCACHE) { - load_kvcache(PRE_COMPUTED_CACHE_PATH, SYS_PROMP_LEN); - } else { - SYS_PROMP_LEN = 0; - } - output = model->incremental_inference(BATCH_SIZE, input, label, init_len, - SYS_PROMP_LEN, - SYS_PROMP_LEN + input_len, false); - - // post process of model output - std::vector id_list(generate_multi_tokens( - output[0], NUM_VOCAB, BATCH_SIZE, 1, ids_history, _len)); - - if (init_len < INIT_SEQ_LEN) - registerOutputs(tokenizer, id_list, init_len, eos_list, log_output); - - // output should be deallocated after use - for (auto &out : output) { - delete[] out; - } - - auto finish_prefill = std::chrono::high_resolution_clock::now(); - auto prefill_duration = std::chrono::duration_cast( - finish_prefill - start_prefill); - - /** - * TOKEN GENERATION - */ - - input_len += SYS_PROMP_LEN; - - // Update generated token by prefill as an input - for (unsigned int b = 0; b < BATCH_SIZE; ++b) - input_sample[static_cast(b) * MAX_SEQ_LEN] = - static_cast(id_list[b]); - - auto start_generation = std::chrono::high_resolution_clock::now(); - - for (token_generation_idx = input_len + 1; - token_generation_idx < input_len + 1 + NUM_TO_GENERATE; - ++token_generation_idx) { - - auto output_interval = - model->incremental_inference(BATCH_SIZE, input, label, input_len, - token_generation_idx - 1 + global_token_len, - token_generation_idx + global_token_len); - std::vector ids_list(generate(output_interval[0], DO_SAMPLE)); - if (token_generation_idx < input_len) { - for (unsigned int b = 0; b < BATCH_SIZE; ++b) { - input_sample[static_cast(b) * MAX_SEQ_LEN] = - static_cast(init_input[token_generation_idx - SYS_PROMP_LEN]); - } - registerOutputs(tokenizer, ids_list, token_generation_idx, eos_list, - log_output); - } else { - for (unsigned int b = 0; b < BATCH_SIZE; ++b) { - input_sample[static_cast(b) * MAX_SEQ_LEN] = - static_cast(ids_list[b]); - } - registerOutputs(tokenizer, ids_list, token_generation_idx, eos_list, - log_output); - } - ++generation_cnt; - - // output should be deallocated after use - for (auto out : output_interval) { - delete[] out; - } - - // check FINISH - for (unsigned int j = 0; j < BATCH_SIZE; ++j) { - if (!eos_list[j] && (std::find(EOS_TOKEN_ID.begin(), EOS_TOKEN_ID.end(), - ids_list[j]) != EOS_TOKEN_ID.end())) { - eos_list[j] = true; - } - } - - bool is_finish = true; - for (unsigned int j = 0; j < BATCH_SIZE; ++j) { - if (!eos_list[j]) { - is_finish = false; - break; - } - } - - if (is_finish) { - free(input_sample); - break; - } - } - - global_token_len += (generation_cnt + init_len); - - if (output_buf != nullptr) { - *static_cast *>(output_buf) = output_list; - } - - auto finish_generation = std::chrono::high_resolution_clock::now(); - auto generation_duration = - std::chrono::duration_cast(finish_generation - - start_generation); - - auto finish_total = std::chrono::high_resolution_clock::now(); - auto total_duration = std::chrono::duration_cast( - finish_total - start_total); - size_t peak_memory = getPeakMemoryKb(); - - if (log_output) { - - std::cout << "\n\n"; - std::cout << "=================[ LLM with NNTrainer ]===================\n"; - std::cout << "prefill: " << init_len << " tokens, " - << prefill_duration.count() << " ms, " - << ((double)init_len / prefill_duration.count() * 1000) - << " TPS\n"; - std::cout << "generation: " << generation_cnt << " tokens, " - << generation_duration.count() << " ms, " - << ((double)generation_cnt / generation_duration.count() * 1000) - << " TPS\n"; - std::cout << "total: " << total_duration.count() << " ms\n"; - std::cout << "peak memory: " << peak_memory << " KB\n"; - std::cout << "==========================================================\n"; - } - - performance_metrics.prefill_tokens = init_len; - performance_metrics.prefill_duration_ms = prefill_duration.count(); - performance_metrics.generation_tokens = generation_cnt; - performance_metrics.generation_duration_ms = generation_duration.count(); - performance_metrics.total_duration_ms = total_duration.count(); - performance_metrics.peak_memory_kb = peak_memory; - - has_run_ = true; -} - -std::string CausalLM::getOutput(int batch_idx) const { - if (batch_idx < 0 || batch_idx >= static_cast(output_list.size())) { - return ""; - } - return output_list[batch_idx]; -} - -} // namespace quick_dot_ai diff --git a/models/causal_lm.h b/models/causal_lm.h deleted file mode 100644 index cf4da746..00000000 --- a/models/causal_lm.h +++ /dev/null @@ -1,170 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -/** - * Copyright (C) 2025 Jijoong Moon - * Copyright (C) 2025 Seungback Hong - * Copyright (C) 2025 Hyeonseok Lee - * Copyright (C) 2025 Eunju Yang - * - * @file causal_lm.h - * @date 10 July 2025 - * @see https://github.com/nntrainer/nntrainer - * @author Jijoong Moon - * @author Seungbaek Hong - * @author Hyeonseok Lee - * @author Eunju Yang - * @bug No known bugs except for NYI items - * @note This causal_lm.h constructs a class for Transformer-based Causal - * Language Model (CausalLM). It aims to support AutoModelForCausalLM with - * nntrainer. It supports the following models: - * - Qwen3 - * - Qwen3-MoE - * @note This CausalLM assumes the Decoder-based model, which structure is - * - * [Transformer] - * | - * [LMHead] - */ - -#ifndef __CAUSAL_LM_H__ -#define __CAUSAL_LM_H__ - -#pragma once -#ifdef _WIN32 -#define WIN_EXPORT __declspec(dllexport) -#define WSTR std::wstring -#define WCHAR_P wchar_t * -#else -#define WIN_EXPORT -#define WSTR std::string -#define WCHAR_P std::string & -#endif - -#include - -namespace quick_dot_ai { - -/** - * @brief CausalLM Class - */ -WIN_EXPORT class CausalLM : virtual public Transformer { - -public: - /** - * @brief Construct a new CausalLM object - * @param cfg Configuration for the model (config.json) - * @param generation_cfg Configuration for the generation - * (generation_config.json) - * @param nntr_cfg Configuration for nntrainer (nntrainer_config.json) - */ - CausalLM(json &cfg, json &generation_cfg, json &nntr_cfg); - - /** - * @brief Destroy the CausalLM object - */ - virtual ~CausalLM() { - if (ids_history) - free(ids_history); - } - - /** - * @brief run the CausalLM model (simple) - */ - void run(const WSTR prompt, void *output_buf = nullptr, - bool log_output = true) override; - - /** - * @brief run the CausalLM model (full) - */ - void run(const WSTR prompt, const WSTR system_prompt = "", - const WSTR tail_prompt = "", void *output_buf = nullptr, - bool log_output = true) override; - - /** - * @brief Get the generated output text - * @param batch_idx Index of the batch item - * @return Generated text string - */ - std::string getOutput(int batch_idx = 0) const; - - /** - * @brief get the status of run - */ - bool hasRun() const { return has_run_; } - -protected: - /** - * @brief Setup the parameters for the CausalLM model - */ - virtual void setupParameters(json &cfg, json &generation_cfg, - json &nntr_cfg) override; - - /** - * @brief Construct Model - */ - virtual void constructModel() override; - - /** - * @brief register Outputs - */ - virtual void - registerOutputs(std::unique_ptr &tokenizer, - std::vector ids, unsigned int pos, - const std::vector &eos_list, bool log_output = true); - - /** - * @brief save kv cache - */ - WIN_EXPORT virtual void save_kvcache(std::string path, int to); - - /** - * @brief load kv cache - */ - WIN_EXPORT virtual void load_kvcache(std::string path, int to); - - /** - * @brief generate - */ - std::vector generate(float *logits, bool do_sample, - float repetition_penalty = 1, - unsigned int *input_ids = nullptr, - unsigned int NUM_INPUT_IDS = 0); - - /** - * @brief registerCutomLayers - */ - void registerCustomLayers() override; - - /** internal buffer */ - std::vector - output_list; /**< List of output names for the model */ - unsigned int *ids_history; /**< History of input IDs for the model */ - - std::vector pending_ids_; - - std::string LMHEAD_DTYPE; /** embedding dtype */ - std::vector EOS_TOKEN_ID; - unsigned int BOS_TOKEN_ID; - float TEMPERATURE; - unsigned int TOP_K; - float TOP_P; - - bool DO_SAMPLE = false; /**< Whther to use sampling for generation */ - - std::vector BAD_WORD_IDS; /**< List of bad word IDs */ - unsigned int NUM_BADWORDS; /**< Number of bad words */ - - unsigned int SYS_PROMP_LEN; - std::string PRE_COMPUTED_CACHE_PATH; - std::string TAIL_PROMPT; - bool SAVE_KVCACHE; - bool USE_KVCACHE; - unsigned int global_token_len; - - bool has_run_ = false; - - std::mt19937 rng; /**< Random Number Gen */ -}; - -} // namespace quick_dot_ai - -#endif diff --git a/models/gemma3/embedding_gemma.cpp b/models/gemma3/embedding_gemma.cpp deleted file mode 100644 index ad5a0f25..00000000 --- a/models/gemma3/embedding_gemma.cpp +++ /dev/null @@ -1,28 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -/** - * Copyright (C) 2026 Samsung Electronics Co., Ltd. All Rights Reserved. - * - * @file embedding_gemma.cpp - * @date 11 Jan 2026 - * @see https://github.com/nntrainer/nntrainer - * @author Seungbaek Hong - * @bug No known bugs except for NYI items - * @brief This file defines Gemma3 Embedding model - */ - -#include - -namespace quick_dot_ai { - -void EmbeddingGemma::setupParameters(json &cfg, json &generation_cfg, - json &nntr_cfg) { - Gemma3Transformer::setupParameters(cfg, generation_cfg, nntr_cfg); - SentenceTransformer::setupParameters(cfg, generation_cfg, nntr_cfg); -} - -void EmbeddingGemma::registerCustomLayers() { - SentenceTransformer::registerCustomLayers(); - Gemma3Transformer::registerCustomLayers(); -} - -} // namespace quick_dot_ai diff --git a/models/gemma3/embedding_gemma.h b/models/gemma3/embedding_gemma.h deleted file mode 100644 index d7a606c1..00000000 --- a/models/gemma3/embedding_gemma.h +++ /dev/null @@ -1,63 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -/** - * Copyright (C) 2026 SeungBaek Hong - * - * @file embedding_gemma.h - * @date 11 Jan 2026 - * @see https://github.com/nntrainer/nntrainer - * @author Seungbaek Hong - * @bug No known bugs except for NYI items - * @note This embedding_gemma.h constructs a class for Gemma3-based Embedding - * model. - */ - -#ifndef __EMBEDDING_GEMMA_H__ -#define __EMBEDDING_GEMMA_H__ - -#include -#include - -namespace quick_dot_ai { - -/** - * @brief EmbeddingGemma Class - */ -class EmbeddingGemma : public SentenceTransformer, public Gemma3Transformer { - -public: - static constexpr const char *architectures = "EmbeddingGemma"; - - /** - * @brief Construct a new EmbeddingGemma object - * @param cfg Configuration for the model - * @param generation_cfg Configuration for generation - * @param nntr_cfg Configuration for nntrainer - */ - EmbeddingGemma(json &cfg, json &generation_cfg, json &nntr_cfg) : - Transformer( - Gemma3Transformer::sanitizeConfig(cfg), - Gemma3Transformer::sanitizeGenerationConfig(generation_cfg, cfg), - nntr_cfg, ModelType::EMBEDDING), - SentenceTransformer(cfg, generation_cfg, nntr_cfg), - Gemma3Transformer(cfg, generation_cfg, nntr_cfg) {} - - /** - * @brief Destroy the EmbeddingGemma object - */ - virtual ~EmbeddingGemma() = default; - - /** - * @brief Setup parameters - */ - void setupParameters(json &cfg, json &generation_cfg, - json &nntr_cfg) override; - - /** - * @brief register CustomLayers - */ - void registerCustomLayers() override; -}; - -} // namespace quick_dot_ai - -#endif // __EMBEDDING_GEMMA_H__ diff --git a/models/gemma3/function.cpp b/models/gemma3/function.cpp deleted file mode 100644 index c3366569..00000000 --- a/models/gemma3/function.cpp +++ /dev/null @@ -1,218 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -/** - * Copyright (C) 2025 SeungBaek Hong - * - * @file function.cpp - * @date 19 January 2026 - * @brief This defines a chat format for FunctionGemma - * @see https://github.com/nntrainer/nntrainer - * @author Seungbaek Hong - * @bug No known bugs except for NYI items - */ - -#include "function.h" -#include -#include -#include -#include - -namespace quick_dot_ai { -namespace gemma3 { - -// Helper to escape string values -std::string escape_value(const std::string &value) { - return "" + value + ""; -} - -// Helper to uppercase string -std::string to_upper(std::string s) { - std::transform(s.begin(), s.end(), s.begin(), ::toupper); - return s; -} - -// Recursively format parameters -std::string format_parameters(const json &properties) { - std::stringstream ss; - bool first = true; - for (auto &[key, val] : properties.items()) { - if (!first) - ss << ","; - ss << key << ":{"; - - // inner properties - bool inner_first = true; - if (val.contains("description")) { - if (!inner_first) - ss << ","; - ss << "description:" - << escape_value(val["description"].get()); - inner_first = false; - } - - if (val.contains("type")) { - if (!inner_first) - ss << ","; - ss << "type:" << escape_value(to_upper(val["type"].get())); - inner_first = false; - } - - // Recursion for nested objects - if (val.contains("properties")) { - if (!inner_first) - ss << ","; - ss << "properties:{" << format_parameters(val["properties"]) << "}"; - inner_first = false; - } - - if (val.contains("required")) { - if (!inner_first) - ss << ","; - ss << "required:["; - bool req_first = true; - for (const auto &item : val["required"]) { - if (!req_first) - ss << ","; - ss << escape_value(item.get()); - req_first = false; - } - ss << "]"; - inner_first = false; - } - - ss << "}"; - first = false; - } - return ss.str(); -} - -std::string format_function_declaration(const json &tool) { - std::stringstream ss; - if (tool.contains("function")) { - const auto &func = tool["function"]; - ss << "declaration:" << func.value("name", "") << ","; - ss << "description:" << escape_value(func.value("description", "")) << ","; - - ss << "parameters:{"; - if (func.contains("parameters")) { - const auto ¶ms = func["parameters"]; - - if (params.contains("properties")) { - ss << "properties:{"; - ss << format_parameters(params["properties"]); - ss << "},"; - } - if (params.contains("required")) { - ss << "required:["; - bool first_req = true; - for (const auto &req : params["required"]) { - if (!first_req) - ss << ","; - ss << escape_value(req.get()); - first_req = false; - } - ss << "],"; - } - if (params.contains("type")) { - ss << "type:" << escape_value(to_upper(params.value("type", "object"))); - } - } - ss << "}"; - } - return ss.str(); -} - -// Helper to format a single argument value (for tool calls/responses) -std::string format_argument_value(const json &value) { - if (value.is_string()) { - return value.get(); - } else { - return value.dump(); - } -} - -std::string apply_function_gemma_template(const json &chat_input) { - std::stringstream prompt; - - prompt << ""; - const auto &messages = chat_input["messages"]; - bool tools_inserted = false; - - for (size_t i = 0; i < messages.size(); ++i) { - const auto &message = messages[i]; - std::string role = message.value("role", ""); - if (role == "assistant") - role = "model"; - - // Open turn - if (role != "tool") { - prompt << "" << role << "\n"; - } - - // Content - if (message.contains("content")) { - if (message["content"].is_string()) { - prompt << message["content"].get(); - } - } - - // Insert tools if this is the first message and it is developer/system - if (!tools_inserted && chat_input.contains("tools") && - (role == "developer" || role == "system")) { - for (const auto &tool : chat_input["tools"]) { - prompt << ""; - prompt << format_function_declaration(tool); - prompt << ""; - } - tools_inserted = true; - } - - // Tool calls - if (message.contains("tool_calls")) { - for (const auto &tool_call : message["tool_calls"]) { - const auto &func = tool_call["function"]; - prompt << "call:" - << func["name"].get() << "{"; - // Simplistic argument formatting - if (func.contains("arguments")) { - if (func["arguments"].is_object()) { - bool first = true; - for (auto &[key, val] : func["arguments"].items()) { - if (!first) - prompt << ","; - prompt << key << ":" << format_argument_value(val); - first = false; - } - } else if (func["arguments"].is_string()) { - prompt << func["arguments"].get(); - } - } - prompt << "}"; - } - } - - // End turn - if (role != "tool") { - prompt << "\n"; - } else { - if (message.contains("content")) { - std::string name = message.value("name", ""); - std::string content_str; - if (message["content"].is_string()) - content_str = message["content"].get(); - else - content_str = message["content"].dump(); - - prompt << "response:" << name << "{" - << "value:" << content_str << "}"; - } - } - } - - // Add generation prompt - prompt << "model\n"; - - return prompt.str(); -} - -} // namespace gemma3 -} // namespace quick_dot_ai diff --git a/models/gemma3/function.h b/models/gemma3/function.h deleted file mode 100644 index 9aa75318..00000000 --- a/models/gemma3/function.h +++ /dev/null @@ -1,34 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -/** - * Copyright (C) 2025 SeungBaek Hong - * - * @file function.h - * @date 19 January 2026 - * @brief This defines a chat format for FunctionGemma - * @see https://github.com/nntrainer/nntrainer - * @author Seungbaek Hong - * @bug No known bugs except for NYI items - */ - -#ifndef __GEMMA3_FUNCTION_H__ -#define __GEMMA3_FUNCTION_H__ - -#include -#include - -using json = nlohmann::json; - -namespace quick_dot_ai { -namespace gemma3 { - -/** - * @brief Apply the chat template for FunctionGemma - * @param chat_input The input JSON containing "messages" and optionally "tools" - * @return The formatted prompt string - */ -std::string apply_function_gemma_template(const json &chat_input); - -} // namespace gemma3 -} // namespace quick_dot_ai - -#endif // __GEMMA3_FUNCTION_H__ diff --git a/models/gemma3/gemma3_causallm.cpp b/models/gemma3/gemma3_causallm.cpp deleted file mode 100644 index 58f72e0f..00000000 --- a/models/gemma3/gemma3_causallm.cpp +++ /dev/null @@ -1,295 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -/** - * Copyright (C) 2025 Seungbaek Hong - * - * @file gemma3_causallm.cpp - * @date 24 Dec 2025 - * @brief This defines a gemma3 causal language model. - * @see https://github.com/nnstreamer/ - * @author Seungbaek Hong - * @bug No known bugs except for NYI items - * - */ -#include - -#include -#include -#include -#include - -namespace quick_dot_ai { - -json &Gemma3Transformer::sanitizeConfig(json &cfg) { - if (!cfg.contains("tie_word_embeddings")) { - cfg["tie_word_embeddings"] = true; - } - return cfg; -} - -json &Gemma3Transformer::sanitizeGenerationConfig(json &gen_cfg, - const json &cfg) { - if (!gen_cfg.contains("eos_token_id")) { - if (cfg.contains("eos_token_id")) { - auto eos = cfg["eos_token_id"]; - if (eos.is_number()) { - gen_cfg["eos_token_id"] = - std::vector{eos.get()}; - } else { - gen_cfg["eos_token_id"] = eos; - } - } - } else { - auto eos = gen_cfg["eos_token_id"]; - if (eos.is_number()) { - gen_cfg["eos_token_id"] = - std::vector{eos.get()}; - } - } - - return gen_cfg; -} - -void Gemma3Transformer::setupParameters(json &cfg, json &generation_cfg, - json &nntr_cfg) { - Transformer::setupParameters(cfg, generation_cfg, nntr_cfg); - if (cfg.contains("layer_types")) { - layer_types = cfg["layer_types"].get>(); - } - if (cfg.contains("attn_logit_softcapping") && - !cfg["attn_logit_softcapping"].is_null()) { - ATTN_LOGIT_SOFTCAPPING = cfg["attn_logit_softcapping"].get(); - } -} - -std::vector -Gemma3Transformer::createTransformerDecoderBlock(const int layer_id, - std::string input_name) { - - std::vector layers; - - layers.push_back(createLayer( - "rms_norm", - {withKey("name", "layer" + std::to_string(layer_id) + "_attention_norm"), - withKey("input_layers", input_name), - withKey("epsilon", std::to_string(NORM_EPS)), - withKey("packed", "false")})); - - auto att_layer = - createAttention(layer_id, INIT_SEQ_LEN, NUM_HEADS, HEAD_DIM, - "layer" + std::to_string(layer_id) + "_attention_norm", - "layer" + std::to_string(layer_id) + "_attention_norm", - "layer" + std::to_string(layer_id) + "_attention_norm"); - layers.insert(layers.end(), att_layer.begin(), att_layer.end()); - - layers.push_back(createLayer( - "rms_norm", {withKey("name", "layer" + std::to_string(layer_id) + - "_post_attention_norm"), - withKey("input_layers", - "layer" + std::to_string(layer_id) + "_attention_out"), - withKey("epsilon", std::to_string(NORM_EPS)), - withKey("packed", "false")})); - - layers.push_back(createLayer( - "addition", - {withKey("name", "layer" + std::to_string(layer_id) + "_post_attention"), - withKey("input_layers", input_name + ",layer" + std::to_string(layer_id) + - "_post_attention_norm")})); - - layers.push_back(createLayer( - "rms_norm", - {withKey("name", "layer" + std::to_string(layer_id) + "pre_ffn_norm"), - withKey("input_layers", - "layer" + std::to_string(layer_id) + "_post_attention"), - withKey("epsilon", std::to_string(NORM_EPS)), - withKey("packed", "false")})); - - auto ffn_layer = - createMlp(layer_id, DIM, INTERMEDIATE_SIZE, - "layer" + std::to_string(layer_id) + "pre_ffn_norm"); - layers.insert(layers.end(), ffn_layer.begin(), ffn_layer.end()); - - layers.push_back(createLayer( - "rms_norm", - {withKey("name", "layer" + std::to_string(layer_id) + "post_ffn_norm"), - withKey("epsilon", std::to_string(NORM_EPS)), - withKey("packed", "false")})); - - layers.push_back(createLayer( - "addition", - {withKey("name", "layer" + std::to_string(layer_id) + "_decoder_output"), - withKey("input_layers", "layer" + std::to_string(layer_id) + - "_post_attention,layer" + - std::to_string(layer_id) + "post_ffn_norm")})); - - return layers; -} - -std::vector Gemma3Transformer::createAttention( - const int layer_id, int seq_len, int n_heads, int head_dim, - std::string query_name, std::string key_name, std::string value_name) { - std::vector layers; - - auto Q = "layer" + std::to_string(layer_id) + "_wq"; - auto Q_norm = "layer" + std::to_string(layer_id) + "_q_norm"; - auto K = "layer" + std::to_string(layer_id) + "_wk"; - auto K_norm = "layer" + std::to_string(layer_id) + "_k_norm"; - auto V = "layer" + std::to_string(layer_id) + "_wv"; - auto A = "layer" + std::to_string(layer_id) + "_attention"; - auto O = "layer" + std::to_string(layer_id) + "_attention_out"; - - // Q layer - std::vector q_params = {withKey("name", Q), - withKey("unit", head_dim * n_heads), - withKey("disable_bias", "true"), - withKey("input_layers", query_name), - withKey("weight_initializer", "ones"), - withKey("weight_dtype", FC_LAYER_DTYPE)}; - layers.push_back(createLayer("fully_connected", q_params)); - - // K layer - std::vector k_params = { - withKey("name", K), - withKey("unit", head_dim * n_heads / GQA_SIZE), - withKey("disable_bias", "true"), - withKey("input_layers", key_name), - withKey("weight_initializer", "ones"), - withKey("weight_dtype", FC_LAYER_DTYPE)}; - layers.push_back(createLayer("fully_connected", k_params)); - - // V layer - std::vector v_params = { - withKey("name", V), - withKey("unit", head_dim * n_heads / GQA_SIZE), - withKey("disable_bias", "true"), - withKey("input_layers", value_name), - withKey("weight_initializer", "ones"), - withKey("weight_dtype", FC_LAYER_DTYPE)}; - layers.push_back(createLayer("fully_connected", v_params)); - - // q_norm - std::vector q_norm_params = { - withKey("name", Q_norm), withKey("input_layers", Q), - withKey("packed", "false"), withKey("epsilon", std::to_string(NORM_EPS)), - withKey("feature_size", std::to_string(head_dim))}; - layers.push_back(createLayer("reshaped_rms_norm", q_norm_params)); - - // k_norm - std::vector k_norm_params = { - withKey("name", K_norm), withKey("input_layers", K), - withKey("packed", "false"), withKey("epsilon", std::to_string(NORM_EPS)), - withKey("feature_size", std::to_string(head_dim))}; - layers.push_back(createLayer("reshaped_rms_norm", k_norm_params)); - - // Attention core layer - unsigned int window_size = UINT_MAX; - if (!layer_types.empty()) { - if (layer_id < layer_types.size()) { - if (layer_types[layer_id] == "sliding_attention") { - window_size = SLIDING_WINDOW; - } - } - } else { - window_size = SLIDING_WINDOW; - } - - float rope_theta = ROPE_THETA; // Default global - if (!layer_types.empty() && layer_id < layer_types.size()) { - if (layer_types[layer_id] == "sliding_attention") { - rope_theta = 10000.0f; - } - } - - std::vector a_params = { - withKey("name", A), - withKey("num_heads", n_heads), - withKey("num_heads_kv", n_heads / GQA_SIZE), - withKey("max_timestep", std::to_string(INIT_SEQ_LEN + NUM_TO_GENERATE)), - withKey("sliding_window", window_size), - withKey("rope_theta", std::to_string(rope_theta)), - withKey("max_new_tokens", std::to_string(NUM_TO_GENERATE)), - withKey("attn_logit_softcapping", std::to_string(ATTN_LOGIT_SOFTCAPPING)), - withKey("is_causal", IS_CAUSAL ? "true" : "false"), - withKey("input_layers", {Q_norm, K_norm, V})}; - layers.push_back(createLayer("mha_core", a_params)); - - // O layer - std::vector o_params = {withKey("name", O), - withKey("unit", DIM), - withKey("disable_bias", "true"), - withKey("input_layers", A), - withKey("weight_initializer", "ones"), - withKey("weight_dtype", FC_LAYER_DTYPE)}; - layers.push_back(createLayer("fully_connected", o_params)); - - return layers; -} - -std::vector Gemma3Transformer::createMlp(const int layer_id, - int dim, int hidden_dim, - std::string input_name) { - std::vector layers; - - // Gate projection - layers.push_back(createLayer( - "fully_connected", - {withKey("name", "layer" + std::to_string(layer_id) + "_ffn_gate"), - withKey("unit", hidden_dim), withKey("disable_bias", "true"), - withKey("input_layers", input_name), withKey("weight_initializer", "ones"), - withKey("weight_dtype", FC_LAYER_DTYPE)})); - - // GeLU - layers.push_back(createLayer( - "activation", - {withKey("name", "layer" + std::to_string(layer_id) + "_ffn_gate_gelu"), - withKey("activation", "tanh_gelu"), - withKey("input_layers", - "layer" + std::to_string(layer_id) + "_ffn_gate")})); - - // Up projection - layers.push_back(createLayer( - "fully_connected", - {withKey("name", "layer" + std::to_string(layer_id) + "_ffn_up"), - withKey("unit", hidden_dim), withKey("disable_bias", "true"), - withKey("input_layers", input_name), withKey("weight_initializer", "ones"), - withKey("weight_dtype", FC_LAYER_DTYPE)})); - - // Multiply - layers.push_back(createLayer( - "multiply", - {withKey("name", "layer" + std::to_string(layer_id) + "_ffn_geglu"), - withKey("input_layers", "layer" + std::to_string(layer_id) + - "_ffn_gate_gelu,layer" + - std::to_string(layer_id) + "_ffn_up")})); - - // Down projection - layers.push_back(createLayer( - "fully_connected", - {withKey("name", "layer" + std::to_string(layer_id) + "_ffn_down"), - withKey("unit", dim), withKey("disable_bias", "true"), - withKey("input_layers", "layer" + std::to_string(layer_id) + "_ffn_geglu"), - withKey("weight_initializer", "ones"), - withKey("weight_dtype", FC_LAYER_DTYPE)})); - - return layers; -} - -void Gemma3Transformer::registerCustomLayers() { - auto &ct_engine = nntrainer::Engine::Global(); - auto app_context = - static_cast(ct_engine.getRegisteredContext("cpu")); - - try { - app_context->registerFactory( - nntrainer::createLayer); - } catch (std::invalid_argument &e) { - std::cerr << "failed to register factory, reason: " << e.what() - << std::endl; - } -} - -void Gemma3CausalLM::registerCustomLayers() { - CausalLM::registerCustomLayers(); - Gemma3Transformer::registerCustomLayers(); -} - -} // namespace quick_dot_ai diff --git a/models/gemma3/gemma3_causallm.h b/models/gemma3/gemma3_causallm.h deleted file mode 100644 index 6fab61e8..00000000 --- a/models/gemma3/gemma3_causallm.h +++ /dev/null @@ -1,95 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -/** - * Copyright (C) 2025 SeungBaek Hong - * - * @file gemma3_causallm.h - * @date 24 Dec 2025 - * @see https://github.com/nnstreamer/nntrainer - * @author Seungbaek Hong - * @bug No known bugs except for NYI items - */ - -#ifndef __GEMMA3_CAUSAL_LM_H__ -#define __GEMMA3_CAUSAL_LM_H__ - -#include - -namespace quick_dot_ai { - -/** - * @brief Gemma3Transformer class - */ -class Gemma3Transformer : virtual public Transformer { - -public: - static constexpr const char *architectures = "Gemma3Transformer"; - - Gemma3Transformer(json &cfg, json &generation_cfg, json &nntr_cfg) : - Transformer(sanitizeConfig(cfg), - sanitizeGenerationConfig(generation_cfg, cfg), nntr_cfg) { - if (cfg.contains("layer_types")) { - layer_types = cfg["layer_types"].get>(); - } - EMBEDDING_SCALE = std::sqrt(static_cast(cfg["hidden_size"])); - } - - virtual ~Gemma3Transformer() = default; - -protected: - static json &sanitizeConfig(json &cfg); - static json &sanitizeGenerationConfig(json &gen_cfg, const json &cfg); - - std::vector layer_types; - -public: - std::vector createAttention(const int layer_id, int seq_len, - int n_heads, int head_dim, - std::string query_name, - std::string key_name, - std::string value_name) override; - - std::vector - createTransformerDecoderBlock(const int layer_id, std::string input_name); - - void setupParameters(json &cfg, json &generation_cfg, - json &nntr_cfg) override; - - std::vector createMlp(const int layer_id, int dim, - int hidden_dim, - std::string input_name) override; - - void registerCustomLayers() override; -}; - -/** - * @brief Gemma3CausalLM class - */ -class Gemma3CausalLM : public CausalLM, public Gemma3Transformer { - -public: - static constexpr const char *architectures = "Gemma3ForCausalLM"; - - Gemma3CausalLM(json &cfg, json &generation_cfg, json &nntr_cfg) : - Transformer(sanitizeConfig(cfg), - sanitizeGenerationConfig(generation_cfg, cfg), nntr_cfg), - CausalLM(sanitizeConfig(cfg), sanitizeGenerationConfig(generation_cfg, cfg), - nntr_cfg), - Gemma3Transformer(sanitizeConfig(cfg), - sanitizeGenerationConfig(generation_cfg, cfg), nntr_cfg) { - } - - virtual ~Gemma3CausalLM() = default; - - void setupParameters(json &cfg, json &generation_cfg, - json &nntr_cfg) override { - CausalLM::setupParameters(cfg, generation_cfg, nntr_cfg); - Gemma3Transformer::setupParameters(cfg, generation_cfg, nntr_cfg); - } - - void registerCustomLayers() override; - -private: -}; -} // namespace quick_dot_ai - -#endif /* __GEMMA3_CAUSAL_LM_H__ */ diff --git a/models/gemma3/meson.build b/models/gemma3/meson.build deleted file mode 100644 index 4bb1dac4..00000000 --- a/models/gemma3/meson.build +++ /dev/null @@ -1,10 +0,0 @@ -gemma3_src = [ - meson.current_source_dir() / 'gemma3_causallm.cpp', - meson.current_source_dir() / 'embedding_gemma.cpp', - meson.current_source_dir() / 'function.cpp', -] - -gemma3_inc = include_directories('.') - -quick_dot_ai_src += gemma3_src -quick_dot_ai_inc += gemma3_inc diff --git a/models/gpt_oss/README.md b/models/gpt_oss/README.md deleted file mode 100644 index ac02f341..00000000 --- a/models/gpt_oss/README.md +++ /dev/null @@ -1,7 +0,0 @@ -# GPT-OSS Model - -This directory contains the implementation for GPT-OSS model. - -## Files -- `gptoss_causallm.cpp`: GPT-OSS implementation. -- `gpt_oss_moe_layer.cpp`: GPT-OSS MoE layer implementation. diff --git a/models/gpt_oss/gpt_oss_moe_layer.cpp b/models/gpt_oss/gpt_oss_moe_layer.cpp deleted file mode 100644 index 5f56d295..00000000 --- a/models/gpt_oss/gpt_oss_moe_layer.cpp +++ /dev/null @@ -1,595 +0,0 @@ -/** - * Copyright (C) 2020 Samsung Electronics Co., Ltd. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * http://www.apache.org/licenses/LICENSE-2.0 - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - * - * @file gpt_oss_moe_layer.cpp - * @date 02 Sep 2025 - * @brief This is a Mixture of Expert Layer Class for Gpt-Oss model - * @see https://github.com/nnstreamer/ - * @author Eunju Yang - * @bug No known bugs except for NYI items - * - */ - -#include -#include -#include -#include -#include -#include -#include - -namespace quick_dot_ai { - -static constexpr size_t SINGLE_INOUT_IDX = 0; - -GptOssMoELayer::GptOssMoELayer() : - LayerImpl(), - num_experts(0), - topk(0), - moe_props(props::NumExperts(), props::NumExpertsPerToken(), - nntrainer::props::Unit()), - expert_gate_proj_indices({}), - expert_gate_bias_indices({}), - expert_up_proj_indices({}), - expert_up_bias_indices({}), - expert_down_proj_indices({}), - expert_down_bias_indices({}), - gate_idx(std::numeric_limits::max()), - gate_bias_idx(std::numeric_limits::max()), - router_logits_idx(std::numeric_limits::max()), - expert_mask_idx(std::numeric_limits::max()) {} - -void GptOssMoELayer::finalize(nntrainer::InitLayerContext &context) { - - // 1. Validate input/output dimensions - NNTR_THROW_IF(context.getNumInputs() != 1, std::invalid_argument) - << "MoE layer only supports single input"; - - auto &weight_regularizer = - std::get(*layer_impl_props); - auto &weight_regularizer_constant = - std::get(*layer_impl_props); - auto &weight_initializer = - std::get(*layer_impl_props); - auto &weight_decay = - std::get(*layer_impl_props); - - // 2. Set output dimensions (same as input) - const auto &in_dim = context.getInputDimensions()[SINGLE_INOUT_IDX]; - const bool is_nchw = context.getFormat() == nntrainer::Tformat::NCHW; - std::vector output_dims(1); - output_dims[SINGLE_INOUT_IDX] = in_dim; - context.setOutputDimensions(output_dims); - - // 3. Get MoE properties - num_experts = std::get(moe_props).get(); - topk = std::get(moe_props).get(); - const unsigned int intermediate_size = - std::get(moe_props).get(); - const unsigned int hidden_size = in_dim.width(); // Feature dimension - - // 4. Initialie gate layer (router) - nntrainer::TensorDim gate_dim( - 1, is_nchw ? 1 : num_experts, is_nchw ? hidden_size : 1, - is_nchw ? num_experts : hidden_size, - nntrainer::TensorDim::TensorType(context.getFormat(), - nntrainer::TensorDim::DataType::FP32), - is_nchw ? 0b0011 : 0b0101); - - gate_idx = context.requestWeight( - gate_dim, weight_initializer, weight_regularizer, - weight_regularizer_constant, weight_decay, "gate", true); - - nntrainer::TensorDim gate_bias_dim( - 1, 1, 1, num_experts, - nntrainer::TensorDim::TensorType(context.getFormat(), - context.getActivationDataType())); - gate_bias_idx = - context.requestWeight(gate_bias_dim, weight_initializer, weight_regularizer, - 1.0f, weight_decay, "gate_bias", false); - - // 5. Initializer expert weights - expert_gate_proj_indices.reserve(num_experts); - expert_up_proj_indices.reserve(num_experts); - expert_down_proj_indices.reserve(num_experts); - expert_gate_bias_indices.reserve(num_experts); - expert_up_bias_indices.reserve(num_experts); - expert_down_bias_indices.reserve(num_experts); - - nntrainer::TensorDim expert_gate_dim( - 1, is_nchw ? 1 : intermediate_size, is_nchw ? hidden_size : 1, - is_nchw ? intermediate_size : hidden_size, - nntrainer::TensorDim::TensorType(context.getFormat(), - context.getWeightDataType()), - is_nchw ? 0b0011 : 0b0101); - - nntrainer::TensorDim expert_down_dim( - 1, is_nchw ? 1 : hidden_size, is_nchw ? intermediate_size : 1, - is_nchw ? hidden_size : intermediate_size, - nntrainer::TensorDim::TensorType(context.getFormat(), - context.getWeightDataType()), - is_nchw ? 0b0011 : 0b0101); - - nntrainer::TensorDim expert_gate_bias_dim( - 1, 1, 1, intermediate_size, - nntrainer::TensorDim::TensorType(context.getFormat(), - context.getActivationDataType()), - is_nchw ? 0b0011 : 0b0101); - - nntrainer::TensorDim expert_down_bias_dim( - 1, 1, 1, hidden_size, - nntrainer::TensorDim::TensorType(context.getFormat(), - context.getActivationDataType()), - is_nchw ? 0b0011 : 0b0101); - - for (unsigned int i = 0; i < num_experts; ++i) { - // Up projection - expert_up_proj_indices.push_back(context.requestWeight( - expert_gate_dim, // Same dimensions as gate projection - weight_initializer, weight_regularizer, weight_regularizer_constant, - weight_decay, "expert_up_" + std::to_string(i), false)); - - expert_up_bias_indices.push_back(context.requestWeight( - expert_gate_bias_dim, // Same dimensions as gate projection - weight_initializer, weight_regularizer, weight_regularizer_constant, - weight_decay, "expert_up_bias_" + std::to_string(i), false)); - - // Gate projection - expert_gate_proj_indices.push_back(context.requestWeight( - expert_gate_dim, weight_initializer, weight_regularizer, - weight_regularizer_constant, weight_decay, - "expert_gate_" + std::to_string(i), false)); - - expert_gate_bias_indices.push_back(context.requestWeight( - expert_gate_bias_dim, // Same dimensions as gate projection - weight_initializer, weight_regularizer, weight_regularizer_constant, - weight_decay, "expert_gate_bias_" + std::to_string(i), false)); - - // Down projection - expert_down_proj_indices.push_back(context.requestWeight( - expert_down_dim, weight_initializer, weight_regularizer, - weight_regularizer_constant, weight_decay, - "expert_down_" + std::to_string(i), false)); - - expert_down_bias_indices.push_back(context.requestWeight( - expert_down_bias_dim, // Same dimensions as gate projection - weight_initializer, weight_regularizer, weight_regularizer_constant, - weight_decay, "expert_down_bias_" + std::to_string(i), false)); - } - - // 6. Request intermediate tensors - const unsigned batch_size = in_dim.batch(); - const unsigned seq_len = in_dim.height(); - const unsigned total_tokens = batch_size * seq_len; - - // Router logits : [batch * seq, num_experts] - router_logits_idx = - context.requestTensor({total_tokens, 1, 1, num_experts}, "router_logits", - nntrainer::Initializer::NONE, false, - nntrainer::TensorLifespan::FORWARD_FUNC_LIFESPAN); - - // Expert mask: [num_experts, batch*seq] - expert_mask_idx = - context.requestTensor({num_experts, 1, topk, total_tokens}, "expert_mask", - nntrainer::Initializer::ZEROS, false, - nntrainer::TensorLifespan::FORWARD_FUNC_LIFESPAN); -} - -void GptOssMoELayer::forwarding(nntrainer::RunLayerContext &context, - bool training) { - nntrainer::Tensor &input = context.getInput(SINGLE_INOUT_IDX); - nntrainer::Tensor &output = context.getOutput(SINGLE_INOUT_IDX); - - nntrainer::Tensor &router_logits = context.getTensor(router_logits_idx); - nntrainer::Tensor &expert_mask = context.getTensor(expert_mask_idx); - - const unsigned batch_size = input.batch(); - const unsigned seq_len = input.height(); - const unsigned hidden_size = input.width(); - const unsigned total_tokens = batch_size * seq_len; - - // reshape input: [B,1,S,H] -> [B*S,1,1,H] - input.reshape({total_tokens, 1, 1, hidden_size}); - - // reshape output: [B,1,S,H] -> [B*S,1,1,H] - output.reshape({total_tokens, 1, 1, hidden_size}); - output.setZero(); - - // routing - nntrainer::Tensor &gate_weights = context.getWeight(gate_idx); - input.dot(gate_weights, router_logits); - router_logits.apply(nntrainer::ActiFunc::softmax, router_logits); - auto topk_result = router_logits.topK(topk); - auto topk_values = std::get<0>(topk_result); - auto topk_indices = std::get<1>(topk_result); - - const uint32_t *indices_data = topk_indices.getData(); -#pragma omp parallel for collapse(2) - for (int i = 0; i < static_cast(total_tokens); ++i) { - for (int k = 0; k < static_cast(topk); ++k) { - expert_mask.setValue(indices_data[i * topk + k], 0, k, i, 1.0f); - } - } - - // Pre-compute expert token assignments for better cache locality - std::vector>> expert_assignments( - num_experts); - for (int i = 0; i < static_cast(total_tokens); ++i) { - for (int k = 0; k < static_cast(topk); ++k) { - unsigned expert_idx = indices_data[i * topk + k]; - float weight = topk_values.getValue(i, 0, 0, k); - expert_assignments[expert_idx].emplace_back(i, weight); - } - } - - // Adaptive optimization based on workload - const int active_experts = - std::count_if(expert_assignments.begin(), expert_assignments.end(), - [](const auto &assignments) { return !assignments.empty(); }); - - // Calculate total work (sum of token assignments across all experts) - int total_work = 0; - for (const auto &assignments : expert_assignments) { - total_work += assignments.size(); - } - - // Use parallel processing only when it's beneficial - const bool use_parallel = (total_work > 4) && (active_experts > 1); - - if (use_parallel) { - // Parallel processing for larger workloads -#pragma omp parallel - { -#pragma omp for schedule(dynamic) - for (int expert_idx = 0; expert_idx < static_cast(num_experts); - ++expert_idx) { - const auto &assignments = expert_assignments[expert_idx]; - if (assignments.empty()) - continue; - - // Use optimized expert forward computation without memory copies - compute_expert_forward( - input, output, assignments, - context.getWeight(expert_gate_proj_indices[expert_idx]), - context.getWeight(expert_up_proj_indices[expert_idx]), - context.getWeight(expert_down_proj_indices[expert_idx]), hidden_size); - } - } - } else { - // Sequential processing for smaller workloads - for (int expert_idx = 0; expert_idx < static_cast(num_experts); - ++expert_idx) { - const auto &assignments = expert_assignments[expert_idx]; - if (assignments.empty()) - continue; - - // Use optimized expert forward computation without memory copies - compute_expert_forward( - input, output, assignments, - context.getWeight(expert_gate_proj_indices[expert_idx]), - context.getWeight(expert_up_proj_indices[expert_idx]), - context.getWeight(expert_down_proj_indices[expert_idx]), hidden_size); - } - } - - // reshape output: [B*S,1,1,H] -> [B,1,S,H] - output.reshape({batch_size, 1, seq_len, hidden_size}); -} - -inline void GptOssMoELayer::compute_expert_forward( - const nntrainer::Tensor &input, nntrainer::Tensor &output, - const std::vector> &token_assignments, - const nntrainer::Tensor &gate_proj, const nntrainer::Tensor &up_proj, - const nntrainer::Tensor &down_proj, unsigned int hidden_size) { - - const unsigned intermediate_size = gate_proj.width(); - const unsigned num_tokens = token_assignments.size(); - - if (num_tokens == 0) - return; - - // Create tensor dimensions for single token processing - nntrainer::TensorDim token_input_dim({1, 1, 1, hidden_size}, - input.getTensorType()); - nntrainer::TensorDim intermediate_dim({1, 1, 1, intermediate_size}, - input.getTensorType()); - nntrainer::TensorDim token_output_dim({1, 1, 1, hidden_size}, - input.getTensorType()); - - // Create a temporary output tensor for this expert to avoid critical section - nntrainer::Tensor expert_output(output.batch(), output.channel(), - output.height(), output.width(), - output.getTensorType()); - expert_output.setZero(); - - // Process each token individually to avoid memory copies - for (size_t i = 0; i < num_tokens; ++i) { - const unsigned token_idx = token_assignments[i].first; - const float weight = token_assignments[i].second; - - // Create shared tensor for input token (no memory copy) - size_t token_offset = token_idx * hidden_size; - nntrainer::Tensor token_input = - input.getSharedDataTensor(token_input_dim, token_offset, true); - - // Create intermediate tensors for this token - nntrainer::Tensor gate_out(intermediate_dim); - nntrainer::Tensor acti_out(intermediate_dim); - nntrainer::Tensor up_out(intermediate_dim); - - // Gate projection using optimized dot operation - token_input.dot(gate_proj, gate_out); - - // Up projection using optimized dot operation - token_input.dot(up_proj, up_out); - up_out.add_i(1); - - // Apply activation (silu) - // (up + 1) * (gate * torch.sigmoid(gate * alpha)) - // swiglu : X = Z * (Y / 1 + exp(-alpha * Y)) - // X := acti_out - // Y := gate_out - // Z := up_out + 1 - for (unsigned int b = 0; b < acti_out.batch(); ++b) { - for (unsigned int c = 0; c < acti_out.channel(); ++c) { - for (unsigned int h = 0; h < acti_out.height(); ++h) { - nntrainer::swiglu( - acti_out.width(), - acti_out.getData() + acti_out.getIndex(b, c, h, 0), - gate_out.getData() + gate_out.getIndex(b, c, h, 0), - up_out.getData() + up_out.getIndex(b, c, h, 0), alpha); - } - } - } - - // Down projection using optimized dot operation - nntrainer::Tensor token_expert_output(token_output_dim); - acti_out.dot(down_proj, token_expert_output); - - // Apply weight and accumulate to expert's temporary output - token_expert_output.multiply_i(weight); - size_t output_offset = token_idx * hidden_size; - nntrainer::Tensor token_output = - expert_output.getSharedDataTensor(token_output_dim, output_offset, true); - - token_output.add_i(token_expert_output); - } - - // Add expert's result to final output (no critical section in sequential - // mode) - output.add_i(expert_output); -} - -inline void GptOssMoELayer::compute_expert_forward_no_critical( - const nntrainer::Tensor &input, nntrainer::Tensor &expert_output, - const std::vector> &token_assignments, - const nntrainer::Tensor &gate_proj, const nntrainer::Tensor &up_proj, - const nntrainer::Tensor &down_proj, unsigned int hidden_size, - const nntrainer::Tensor &gate_bias, const nntrainer::Tensor &up_bias, - const nntrainer::Tensor &down_bias) { - - const unsigned intermediate_size = gate_proj.width(); - const unsigned num_tokens = token_assignments.size(); - - if (num_tokens == 0) - return; - - // Create tensor dimensions for single token processing - nntrainer::TensorDim token_input_dim({1, 1, 1, hidden_size}, - input.getTensorType()); - nntrainer::TensorDim intermediate_dim({1, 1, 1, intermediate_size}, - input.getTensorType()); - nntrainer::TensorDim token_output_dim({1, 1, 1, hidden_size}, - input.getTensorType()); - - // Process each token individually to avoid memory copies - for (size_t i = 0; i < num_tokens; ++i) { - const unsigned token_idx = token_assignments[i].first; - const float weight = token_assignments[i].second; - - // Create shared tensor for input token (no memory copy) - size_t token_offset = token_idx * hidden_size; - nntrainer::Tensor token_input = - input.getSharedDataTensor(token_input_dim, token_offset, true); - - // Create intermediate tensors for this token - nntrainer::Tensor gate_out(intermediate_dim); - nntrainer::Tensor acti_out(intermediate_dim); - nntrainer::Tensor up_out(intermediate_dim); - - // Gate projection using optimized dot operation - token_input.dot(gate_proj, gate_out); - gate_out.add(gate_bias, gate_out); - // gate_out.clamp(min=None, max=limit) - - // Up projection using optimized dot operation - token_input.dot(up_proj, up_out); - up_out.add_i(up_bias); - // up_out.clamp(min=-limit, max=limit) - - // Apply activation (silu) - // (up + 1) * (gate * torch.sigmoid(gate * alpha)) - // swiglu : X = Z * (Y / 1 + exp(-alpha * Y)) - // X := acti_out - // Y := gate_out - // Z := up_out + 1 - up_out.add_i(1); -#pragma omp parallel for collapse(3) - for (unsigned int b = 0; b < acti_out.batch(); ++b) { - for (unsigned int c = 0; c < acti_out.channel(); ++c) { - for (unsigned int h = 0; h < acti_out.height(); ++h) { - nntrainer::swiglu( - acti_out.width(), - acti_out.getData() + acti_out.getIndex(b, c, h, 0), - gate_out.getData() + gate_out.getIndex(b, c, h, 0), - up_out.getData() + up_out.getIndex(b, c, h, 0), alpha); - } - } - } - - // Down projection using optimized dot operation - nntrainer::Tensor token_expert_output(token_output_dim); - acti_out.dot(down_proj, token_expert_output); - token_expert_output.add_i(down_bias); - - // Apply weight and accumulate to expert's output (no critical section - // needed) - token_expert_output.multiply_i(weight); - size_t output_offset = token_idx * hidden_size; - nntrainer::Tensor token_output = - expert_output.getSharedDataTensor(token_output_dim, output_offset, true); - - token_output.add_i(token_expert_output); - } -} - -void GptOssMoELayer::incremental_forwarding(nntrainer::RunLayerContext &context, - unsigned int from, unsigned int to, - bool training) { - - nntrainer::Tensor &input_ = context.getInput(SINGLE_INOUT_IDX); - nntrainer::Tensor &output_ = context.getOutput(SINGLE_INOUT_IDX); - - nntrainer::Tensor &router_logits_ = context.getTensor(router_logits_idx); - nntrainer::Tensor &expert_mask = context.getTensor(expert_mask_idx); - - nntrainer::TensorDim input_step_dim = input_.getDim(); - nntrainer::TensorDim output_step_dim = output_.getDim(); - nntrainer::TensorDim router_logits_step_dim = router_logits_.getDim(); - - input_step_dim.batch(1); - output_step_dim.batch(1); - router_logits_step_dim.batch(to - from); - - input_step_dim.height(to - from); - output_step_dim.height(to - from); - - for (unsigned int b = 0; b < input_.batch(); ++b) { - - auto input = input_.getSharedDataTensor( - input_step_dim, b * input_step_dim.getFeatureLen(), true); - auto output = output_.getSharedDataTensor( - output_step_dim, b * output_step_dim.getFeatureLen(), true); - auto router_logits = - router_logits_.getSharedDataTensor(router_logits_step_dim, 0, true); - - const unsigned batch_size = input.batch(); - const unsigned seq_len = input.height(); - const unsigned hidden_size = input.width(); - const unsigned total_tokens = batch_size * seq_len; - - // reshape input: [B,1,S,H] -> [B*S,1,1,H] - input.reshape({total_tokens, 1, 1, hidden_size}); - - // reshape output: [B,1,S,H] -> [B*S,1,1,H] - output.reshape({total_tokens, 1, 1, hidden_size}); - output.setZero(); - expert_mask.setZero(); - - // routing - nntrainer::Tensor &gate_weights = context.getWeight(gate_idx); - input.dot(gate_weights, router_logits); - nntrainer::Tensor gate_bias = context.getWeight(gate_bias_idx); - router_logits.add_i(gate_bias); - - router_logits.apply(nntrainer::ActiFunc::softmax, router_logits); - auto topk_result = router_logits.topK(topk); - auto topk_values = std::get<0>(topk_result); - auto topk_indices = std::get<1>(topk_result); - - // norm_topk_prob - topk_values.divide_i(topk_values.sum(3)); - - const uint32_t *indices_data = topk_indices.getData(); - // Set expert mask - for (int i = 0; i < static_cast(total_tokens); ++i) { - for (int k = 0; k < static_cast(topk); ++k) { - expert_mask.setValue(indices_data[i * topk + k], 0, k, i, 1.0f); - } - } - - // Pre-compute expert token assignments for better performance - std::vector>> expert_assignments( - num_experts); - for (int i = 0; i < static_cast(total_tokens); ++i) { - for (int k = 0; k < static_cast(topk); ++k) { - unsigned expert_idx = indices_data[i * topk + k]; - float weight = topk_values.getValue(i, 0, 0, k); - expert_assignments[expert_idx].emplace_back(i, weight); - } - } - - // Parallel processing for multiple tokens with many active experts - std::vector expert_outputs(num_experts); - for (int expert_idx = 0; expert_idx < static_cast(num_experts); - ++expert_idx) { - if (!expert_assignments[expert_idx].empty()) { - expert_outputs[expert_idx] = nntrainer::Tensor( - total_tokens, 1, 1, hidden_size, output.getTensorType()); - expert_outputs[expert_idx].setZero(); - } - } - - // #pragma omp parallel for schedule(dynamic) - for (int expert_idx = 0; expert_idx < static_cast(num_experts); - ++expert_idx) { - const auto &assignments = expert_assignments[expert_idx]; - if (assignments.empty()) - continue; - - compute_expert_forward_no_critical( - input, expert_outputs[expert_idx], assignments, - context.getWeight(expert_gate_proj_indices[expert_idx]), - context.getWeight(expert_up_proj_indices[expert_idx]), - context.getWeight(expert_down_proj_indices[expert_idx]), hidden_size, - context.getWeight(expert_gate_bias_indices[expert_idx]), - context.getWeight(expert_up_bias_indices[expert_idx]), - context.getWeight(expert_down_bias_indices[expert_idx])); - } - - // Combine expert outputs - for (int expert_idx = 0; expert_idx < static_cast(num_experts); - ++expert_idx) { - if (!expert_assignments[expert_idx].empty()) { - output.add_i(expert_outputs[expert_idx]); - } - } - - // reshape output: [B*S,1,1,H] -> [B,1,S,H] - output.reshape({batch_size, 1, seq_len, hidden_size}); - } -} - -void GptOssMoELayer::setProperty(const std::vector &values) { - auto remain_props = loadProperties(values, moe_props); - nntrainer::LayerImpl::setProperty(remain_props); -} - -void GptOssMoELayer::calcDerivative(nntrainer::RunLayerContext &context) { - // MoE layer does not support derivative calculation - throw std::runtime_error("MoE layer does not support derivative calculation"); -} - -void GptOssMoELayer::calcGradient(nntrainer::RunLayerContext &context) { - // MoE layer does not support gradient calculation - throw std::runtime_error("MoE layer does not support gradient calculation"); -} - -void GptOssMoELayer::exportTo(nntrainer::Exporter &exporter, - const ml::train::ExportMethods &method) const { - nntrainer::LayerImpl::exportTo(exporter, method); - exporter.saveResult(moe_props, method, this); // Save MoE specific properties -} - -} // namespace quick_dot_ai diff --git a/models/gpt_oss/gpt_oss_moe_layer.h b/models/gpt_oss/gpt_oss_moe_layer.h deleted file mode 100644 index f9213f3e..00000000 --- a/models/gpt_oss/gpt_oss_moe_layer.h +++ /dev/null @@ -1,169 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -/** - * Copyright (C) 2025 Eunju Yang - * - * @file gpt_oss_moe_layer.h - * @date 09 June 2025 - * @brief This is Mixture of Expert Layer Class of Neural Network - * @see https://github.com/nntrainer/nntrainer - * @author Eunju Yang - * @bug No known bugs except for NYI items - * @todo This layer does not support backwarding yet. - */ - -#ifndef __GPT_OSS_MOE_LAYER_H__ -#define __GPT_OSS_MOE_LAYER_H__ -#ifdef __cplusplus - -#include -#include -#include -#include - -namespace quick_dot_ai { - -/** - * @class GptOssMoELayer - * @brief Mixture of Expert Layer - */ -class GptOssMoELayer : public nntrainer::LayerImpl { -public: - /** - * @brief Constructor of Mixture of Expert Layer - */ - GptOssMoELayer(); - - /** - * @brief Destructor of Mixture of Expert Layer - */ - ~GptOssMoELayer() = default; - - /** - * @brief Move constructor. - * @param[in] GptOssMoELayer && - */ - GptOssMoELayer(GptOssMoELayer &&rhs) noexcept = default; - - /** - * @brief Move assignment operator. - * @param[in] rhs GptOssMoELayer to be moved. - */ - GptOssMoELayer &operator=(GptOssMoELayer &&rhs) = default; - - /** - * @copydoc Layer::finalize(InitLayerContext &context) - */ - void finalize(nntrainer::InitLayerContext &context) override; - - /** - * @copydoc Layer::forwarding(RunLayerContext &context, bool training) - */ - void forwarding(nntrainer::RunLayerContext &context, bool training) override; - - /** - * @copydoc Layer::incremental_forwarding(RunLayerContext &context, unsigned) - */ - void incremental_forwarding(nntrainer::RunLayerContext &context, - unsigned int from, unsigned int to, - bool training) override; - - /** - * @copydoc Layer::calcDerivative(RunLayerContext &context) - */ - void calcDerivative(nntrainer::RunLayerContext &context) override; - - /** - * @copydoc Layer::calcGradient(RunLayerContext &context) - */ - void calcGradient(nntrainer::RunLayerContext &context) override; - - /** - * @copydoc Layer::setProperty(const std::vector &values) - */ - void setProperty(const std::vector &values) override; - - /** - * @copydoc Layer::exportTo(Exporter &exporter, const ml::train::ExportMethods - * &methods) - */ - void exportTo(nntrainer::Exporter &exporter, - const ml::train::ExportMethods &method) const override; - - /** - * @copydoc Layer::getType() - */ - const std::string getType() const override { return GptOssMoELayer::type; }; - - /** - * @brief Layer::supportBackwarding() - */ - bool supportBackwarding() const override { return false; } - - static constexpr const char *type = "gpt_oss_moe"; /**< type of the layer */ - -private: - unsigned int num_experts; /**< number of experts */ - unsigned int topk; /**< number of experts per token, i.e., topk */ - std::tuple - moe_props; - - // weight indeices - std::vector expert_gate_proj_indices; - std::vector expert_gate_bias_indices; - std::vector expert_up_proj_indices; - std::vector expert_up_bias_indices; - std::vector expert_down_proj_indices; - std::vector expert_down_bias_indices; - unsigned int gate_idx; - unsigned int gate_bias_idx; - - // Intermediate tensor indices - unsigned int router_logits_idx; - unsigned int expert_mask_idx; - bool enable_bias = false; - - float alpha = 1.702; - float limit = 7.0; - - /** - * @brief expert forward computation without memory copies - * @param input Input tensor (reshaped to [total_tokens, 1, 1, hidden_size]) - * @param output Output tensor to accumulate results - * @param token_assignments Vector of (token_index, weight) pairs for this - * expert - * @param gate_proj Gate projection weight tensor - * @param up_proj Up projection weight tensor - * @param down_proj Down projection weight tensor - * @param hidden_size Hidden dimension size - */ - inline void compute_expert_forward( - const nntrainer::Tensor &input, nntrainer::Tensor &output, - const std::vector> &token_assignments, - const nntrainer::Tensor &gate_proj, const nntrainer::Tensor &up_proj, - const nntrainer::Tensor &down_proj, unsigned int hidden_size); - - /** - * @brief expert forward computation without critical section - * @param input Input tensor (reshaped to [total_tokens, 1, 1, hidden_size]) - * @param expert_output Expert-specific output tensor - * @param token_assignments Vector of (token_index, weight) pairs for this - * expert - * @param gate_proj Gate projection weight tensor - * @param up_proj Up projection weight tensor - * @param down_proj Down projection weight tensor - * @param hidden_size Hidden dimension size - */ - inline void compute_expert_forward_no_critical( - const nntrainer::Tensor &input, nntrainer::Tensor &expert_output, - const std::vector> &token_assignments, - const nntrainer::Tensor &gate_proj, const nntrainer::Tensor &up_proj, - const nntrainer::Tensor &down_proj, unsigned int hidden_size, - const nntrainer::Tensor &gate_bias = {}, - const nntrainer::Tensor &up_bias = {}, - const nntrainer::Tensor &down_bias = {}); -}; -} // namespace quick_dot_ai - -#endif /* __cplusplus */ -#endif /* __GPT_OSS_MOE_LAYER_H__ */ diff --git a/models/gpt_oss/gptoss_causallm.cpp b/models/gpt_oss/gptoss_causallm.cpp deleted file mode 100644 index 350a4ca7..00000000 --- a/models/gpt_oss/gptoss_causallm.cpp +++ /dev/null @@ -1,146 +0,0 @@ -/** - * Copyright (C) 2025 Samsung Electronics Co., Ltd. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * http://www.apache.org/licenses/LICENSE-2.0 - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - * - * @file gptoss_causallm.cpp - * @brief This defines a gpt_oss causal language model. - * @date 26 Aug 2025 - * @see https://github.com/nnstreamer/ - * @author Eunju Yang - * @bug No known bugs except for NYI items - * - */ - -#include -#include -#include - -#include -#include -#include - -namespace quick_dot_ai { - -std::vector GptOssForCausalLM::createAttention( - const int layer_id, int seq_len, int n_heads, int head_dim, - std::string query_name, std::string key_name, std::string value_name) { - - std::vector layers; - - ///@note Q/K/V/O has bias! - auto Q = "layer" + std::to_string(layer_id) + "_wq"; - auto K = "layer" + std::to_string(layer_id) + "_wk"; - auto V = "layer" + std::to_string(layer_id) + "_wv"; - auto A = "layer" + std::to_string(layer_id) + "_attention"; - auto O = "layer" + std::to_string(layer_id) + "_attention_out"; - - // Q layer - std::vector q_params = { - withKey("name", Q), withKey("unit", head_dim * n_heads), - withKey("disable_bias", "false"), withKey("input_layers", query_name), - withKey("weight_initializer", "ones")}; - layers.push_back(createLayer("fully_connected", q_params)); - - // K layer - std::vector k_params = { - withKey("name", K), withKey("unit", head_dim * n_heads / GQA_SIZE), - withKey("disable_bias", "false"), withKey("input_layers", key_name), - withKey("weight_initializer", "ones")}; - layers.push_back(createLayer("fully_connected", k_params)); - - // V layer - std::vector v_params = { - withKey("name", V), withKey("unit", head_dim * n_heads / GQA_SIZE), - withKey("disable_bias", "false"), withKey("input_layers", value_name), - withKey("weight_initializer", "ones")}; - layers.push_back(createLayer("fully_connected", v_params)); - - // Attention core layer - // layer_types[layer_id] == "sliding_attention" - // layer_types[layer_id] == "full_attention" - unsigned sliding_window = - (LAYER_TYPES[layer_id] == "sliding_attention") ? SLIDING_WINDOW : UINT_MAX; - // this attention use sink! - std::vector a_params = { - withKey("name", A), - withKey("num_heads", n_heads), - withKey("num_heads_kv", n_heads / GQA_SIZE), - withKey("max_timestep", std::to_string(INIT_SEQ_LEN + NUM_TO_GENERATE)), - withKey("sliding_window", sliding_window), - withKey("rope_theta", ROPE_THETA), - withKey("max_position_embeddings", MAX_POSITION_EMBEDDINGS), - withKey("max_new_tokens", std::to_string(NUM_TO_GENERATE)), - withKey("use_sink", "true"), - withKey("rope_scaling_factor", ATTENTION_ROPE_SCALING_FACTOR), - withKey("rope_scaling_type", "yarn"), - withKey("rope_scaling_max_position_embeddings", 4096), - withKey("input_layers", {Q, K, V})}; - layers.push_back(createLayer("mha_core", a_params)); - - // O layer - std::vector o_params = { - withKey("name", O), withKey("unit", DIM), withKey("disable_bias", "false"), - withKey("input_layers", A), withKey("weight_initializer", "ones")}; - layers.push_back(createLayer("fully_connected", o_params)); - - return layers; -} - -std::vector GptOssForCausalLM::createMlp(const int layer_id, - int dim, int hidden_dim, - std::string input_name) { - - std::vector layers; - layers.push_back(createLayer( - "gpt_oss_moe", - { - withKey("name", "layer" + std::to_string(layer_id) + "_ffn_down"), - withKey("input_layers", input_name), - withKey("unit", hidden_dim), - withKey("num_experts", NUM_EXPERTS), - withKey("num_experts_per_token", NUM_EXPERTS_PER_TOK), - })); - - return layers; -} - -void GptOssForCausalLM::setupParameters(json &cfg, json &generation_cfg, - json &nntr_cfg) { - CausalLM(cfg, generation_cfg, nntr_cfg); - - try { - NUM_EXPERTS = cfg["num_local_experts"].get(); - NUM_EXPERTS_PER_TOK = cfg["num_experts_per_tok"].get(); - LAYER_TYPES = cfg["layer_types"].get>(); - ATTENTION_ROPE_SCALING_FACTOR = cfg["rope_scaling"]["factor"]; - } catch (const std::exception &e) { - throw std::runtime_error("GptOssForCausalLM: config parsing error"); - } -} - -void GptOssForCausalLM::registerCustomLayers() { - CausalLM::registerCustomLayers(); - auto &ct_engine = nntrainer::Engine::Global(); - auto app_context = - static_cast(ct_engine.getRegisteredContext("cpu")); - - try { - app_context->registerFactory( - nntrainer::createLayer); - } catch (std::invalid_argument &e) { - std::cerr << "failed to register factory, reason: " << e.what() - << std::endl; - } -} - -} // namespace quick_dot_ai diff --git a/models/gpt_oss/gptoss_causallm.h b/models/gpt_oss/gptoss_causallm.h deleted file mode 100644 index d55d1614..00000000 --- a/models/gpt_oss/gptoss_causallm.h +++ /dev/null @@ -1,73 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -/** - * Copyright (C) 2025 Eunju Yang - * - * @file gptoss_causallm.h - * @date 26 Aug 2025 - * @see https://github.com/nntrainer/nntrainer - * @author Eunju Yang - * @bug No known bugs except for NYI items - * @note Please refer to the following code : - * https://github.com/huggingface/transformers/blob/e68146fbe7052a6dc8456f48edabe705dc1f7381/src/transformers/models/gpt_oss/modeling_gpt_oss.py - */ - -#ifndef __GPTOSS_CAUSALLM_H__ -#define __GPTOSS_CAUSALLM_H__ __GPTOSS_CAUSALLM_H__ - -#include - -namespace quick_dot_ai { - -/** - * @brief GptOssForCausalLM - */ -class GptOssForCausalLM : public CausalLM { -public: - static constexpr const char *architectures = "GptOssForCausalLM"; - - GptOssForCausalLM(json &cfg, json &generation_cfg, json &nntr_cfg) : - Transformer(cfg, generation_cfg, nntr_cfg, ModelType::CAUSALLM), - CausalLM(cfg, generation_cfg, nntr_cfg) { - setupParameters(cfg, generation_cfg, nntr_cfg); - } - - virtual ~GptOssForCausalLM() = default; - - /** - * @brief createAttention - * @note sink attention with sliding window - */ - std::vector createAttention(const int layer_id, int seq_len, - int n_heads, int head_dim, - std::string query_name, - std::string key_name, - std::string value_name) override; - - /** - * @brief MoE layer - */ - std::vector createMlp(const int layer_id, int dim, - int hidden_dim, - std::string input_name) override; - - /** - * @brief setupParameters - */ - void setupParameters(json &cfg, json &generation_cfg, - json &nntr_cfg) override; - - /** - * @brief registerCutomLayers - */ - void registerCustomLayers() override; - -private: - unsigned int NUM_EXPERTS; - unsigned int NUM_EXPERTS_PER_TOK; - std::vector LAYER_TYPES; - float ATTENTION_ROPE_SCALING_FACTOR; -}; - -} // namespace quick_dot_ai - -#endif /** __GPTOSS_CAUSALLM_H__ */ diff --git a/models/gpt_oss/meson.build b/models/gpt_oss/meson.build deleted file mode 100644 index a2c38e01..00000000 --- a/models/gpt_oss/meson.build +++ /dev/null @@ -1,26 +0,0 @@ -gpt_oss_src = [ - meson.current_source_dir() / 'gptoss_causallm.cpp', -] - -gpt_oss_inc = include_directories('.') - -# Define Layers -causallm_gptoss_moe_layer_src_abs = [meson.current_source_dir() / 'gpt_oss_moe_layer.cpp'] - -causallm_gptoss_moe_layer = shared_library( - 'gptoss_moe_layer', - causallm_gptoss_moe_layer_src_abs, - include_directories: [quick_dot_ai_layer_inc, gpt_oss_inc], - dependencies: [nntrainer_dep, nntrainer_ccapi_dep], - install: true, - install_dir: application_install_dir -) - -casuallm_gptoss_moe_layer_dep = declare_dependency( - link_with: causallm_gptoss_moe_layer, - include_directories: gpt_oss_inc -) - -quick_dot_ai_src += gpt_oss_src -quick_dot_ai_inc += gpt_oss_inc -quick_dot_ai_layer_dependencies += [casuallm_gptoss_moe_layer_dep] diff --git a/models/gpt_oss_cached_slim/README.md b/models/gpt_oss_cached_slim/README.md deleted file mode 100644 index 6787c6b7..00000000 --- a/models/gpt_oss_cached_slim/README.md +++ /dev/null @@ -1,9 +0,0 @@ -# GPT-OSS Cached Slim Model - -This directory contains the implementation for GPT-OSS Cached Slim model. - -> πŸ“Œ **Note** on `Cached-Slim`: This model extends the Slim approach (dynamic loading) by caching active experts. This strategy minimizes storage I/O bottlenecks, offering a sweet spot between low memory footprint and high inference speed. - -## Files -- `gptoss_cached_slim_causallm.cpp`: Cached Slim GPT-OSS implementation. -- `gpt_oss_moe_layer_cached.cpp`: Cached GPT-OSS MoE layer implementation. diff --git a/models/gpt_oss_cached_slim/gpt_oss_moe_layer_cached.cpp b/models/gpt_oss_cached_slim/gpt_oss_moe_layer_cached.cpp deleted file mode 100644 index 46c15f18..00000000 --- a/models/gpt_oss_cached_slim/gpt_oss_moe_layer_cached.cpp +++ /dev/null @@ -1,579 +0,0 @@ -/** - * Copyright (C) 2020 Samsung Electronics Co., Ltd. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * http://www.apache.org/licenses/LICENSE-2.0 - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - * - * @file gpt_oss_moe_layer_cached.cpp - * @date 05 Sep 2025 - * @brief This is a Mixture of Expert Layer Class for Gpt-Oss model - * @see https://github.com/nnstreamer/ - * @author Eunju Yang - * @bug No known bugs except for NYI items - * - */ - -#include -#include -#include -#include -#include -#include -#include - -#include -using std::chrono::duration_cast; -using std::chrono::high_resolution_clock; -using std::chrono::nanoseconds; - -namespace quick_dot_ai { - -static constexpr size_t SINGLE_INOUT_IDX = 0; - -CachedSlimGptOssMoELayer::CachedSlimGptOssMoELayer() : - LayerImpl(), - num_experts(0), - topk(0), - moe_props(props::NumExperts(), props::NumExpertsPerToken(), - nntrainer::props::Unit()), - expert_gate_proj_indices({}), - expert_gate_bias_indices({}), - expert_up_proj_indices({}), - expert_up_bias_indices({}), - expert_down_proj_indices({}), - expert_down_bias_indices({}), - gate_idx(std::numeric_limits::max()), - gate_bias_idx(std::numeric_limits::max()), - loaded_expert_deque({}), - need_load({}), - router_logits_idx(std::numeric_limits::max()), - expert_mask_idx(std::numeric_limits::max()) {} - -void CachedSlimGptOssMoELayer::finalize(nntrainer::InitLayerContext &context) { - - // 1. Validate input/output dimensions - NNTR_THROW_IF(context.getNumInputs() != 1, std::invalid_argument) - << "MoE layer only supports single input"; - - auto &weight_regularizer = - std::get(*layer_impl_props); - auto &weight_regularizer_constant = - std::get(*layer_impl_props); - auto &weight_initializer = - std::get(*layer_impl_props); - auto &weight_decay = - std::get(*layer_impl_props); - - // 2. Set output dimensions (same as input) - const auto &in_dim = context.getInputDimensions()[SINGLE_INOUT_IDX]; - const bool is_nchw = context.getFormat() == nntrainer::Tformat::NCHW; - std::vector output_dims(1); - output_dims[SINGLE_INOUT_IDX] = in_dim; - context.setOutputDimensions(output_dims); - - // 3. Get MoE properties - num_experts = std::get(moe_props).get(); - topk = std::get(moe_props).get(); - const unsigned int intermediate_size = - std::get(moe_props).get(); - const unsigned int hidden_size = in_dim.width(); // Feature dimension - - // 4. Initialie gate layer (router) - nntrainer::TensorDim gate_dim( - 1, is_nchw ? 1 : num_experts, is_nchw ? hidden_size : 1, - is_nchw ? num_experts : hidden_size, - nntrainer::TensorDim::TensorType(context.getFormat(), - nntrainer::TensorDim::DataType::FP32), - is_nchw ? 0b0011 : 0b0101); - - gate_idx = context.requestWeight( - gate_dim, weight_initializer, weight_regularizer, - weight_regularizer_constant, weight_decay, "gate", true); - - // pure tensor - nntrainer::TensorDim gate_bias_dim( - 1, 1, 1, num_experts, - nntrainer::TensorDim::TensorType(context.getFormat(), - context.getActivationDataType())); - // pure tensor - gate_bias_idx = - context.requestWeight(gate_bias_dim, weight_initializer, weight_regularizer, - 1.0f, weight_decay, "gate_bias", false); - - // 5. Initializer expert weights (virtual tensor) - expert_gate_proj_indices.reserve(num_experts); - expert_up_proj_indices.reserve(num_experts); - expert_down_proj_indices.reserve(num_experts); - expert_gate_bias_indices.reserve(num_experts); - expert_up_bias_indices.reserve(num_experts); - expert_down_bias_indices.reserve(num_experts); - - nntrainer::TensorDim expert_gate_dim( - 1, is_nchw ? 1 : intermediate_size, is_nchw ? hidden_size : 1, - is_nchw ? intermediate_size : hidden_size, - nntrainer::TensorDim::TensorType(context.getFormat(), - context.getWeightDataType()), - is_nchw ? 0b0011 : 0b0101); - - nntrainer::TensorDim expert_down_dim( - 1, is_nchw ? 1 : hidden_size, is_nchw ? intermediate_size : 1, - is_nchw ? hidden_size : intermediate_size, - nntrainer::TensorDim::TensorType(context.getFormat(), - context.getWeightDataType()), - is_nchw ? 0b0011 : 0b0101); - - nntrainer::TensorDim expert_gate_bias_dim( - 1, 1, 1, intermediate_size, - nntrainer::TensorDim::TensorType(context.getFormat(), - context.getActivationDataType()), - is_nchw ? 0b0011 : 0b0101); - - nntrainer::TensorDim expert_down_bias_dim( - 1, 1, 1, hidden_size, - nntrainer::TensorDim::TensorType(context.getFormat(), - context.getActivationDataType()), - is_nchw ? 0b0011 : 0b0101); - - for (unsigned int i = 0; i < num_experts; ++i) { - // Up projection - expert_up_proj_indices.push_back(context.requestWeight( - expert_gate_dim, // Same dimensions as gate projection - weight_initializer, weight_regularizer, weight_regularizer_constant, - weight_decay, "expert_up_" + std::to_string(i), false, true)); - - expert_up_bias_indices.push_back(context.requestWeight( - expert_gate_bias_dim, // Same dimensions as gate projection - weight_initializer, weight_regularizer, weight_regularizer_constant, - weight_decay, "expert_up_bias_" + std::to_string(i), false, true)); - - // Gate projection - expert_gate_proj_indices.push_back(context.requestWeight( - expert_gate_dim, weight_initializer, weight_regularizer, - weight_regularizer_constant, weight_decay, - "expert_gate_" + std::to_string(i), false, true)); - - expert_gate_bias_indices.push_back(context.requestWeight( - expert_gate_bias_dim, // Same dimensions as gate projection - weight_initializer, weight_regularizer, weight_regularizer_constant, - weight_decay, "expert_gate_bias_" + std::to_string(i), false, true)); - - // Down projection - expert_down_proj_indices.push_back(context.requestWeight( - expert_down_dim, weight_initializer, weight_regularizer, - weight_regularizer_constant, weight_decay, - "expert_down_" + std::to_string(i), false, true)); - - expert_down_bias_indices.push_back(context.requestWeight( - expert_down_bias_dim, // Same dimensions as gate projection - weight_initializer, weight_regularizer, weight_regularizer_constant, - weight_decay, "expert_down_bias_" + std::to_string(i), false, true)); - - // need_load this expert - need_load.push_back(true); - } - - // 6. Request intermediate tensors - const unsigned batch_size = in_dim.batch(); - const unsigned seq_len = in_dim.height(); - const unsigned total_tokens = batch_size * seq_len; - - // Router logits : [batch * seq, num_experts] - router_logits_idx = - context.requestTensor({total_tokens, 1, 1, num_experts}, "router_logits", - nntrainer::Initializer::NONE, false, - nntrainer::TensorLifespan::FORWARD_FUNC_LIFESPAN); - - // Expert mask: [num_experts, batch*seq] - expert_mask_idx = - context.requestTensor({num_experts, 1, topk, total_tokens}, "expert_mask", - nntrainer::Initializer::ZEROS, false, - nntrainer::TensorLifespan::FORWARD_FUNC_LIFESPAN); -} - -void CachedSlimGptOssMoELayer::forwarding(nntrainer::RunLayerContext &context, - bool training) {} - -void CachedSlimGptOssMoELayer::incremental_forwarding( - nntrainer::RunLayerContext &context, unsigned int from, unsigned int to, - bool training) { -#ifdef DEBUG - auto t1 = high_resolution_clock::now(); -#endif - - nntrainer::Tensor &input_ = context.getInput(SINGLE_INOUT_IDX); - nntrainer::Tensor &output_ = context.getOutput(SINGLE_INOUT_IDX); - - nntrainer::Tensor &router_logits_ = context.getTensor(router_logits_idx); - - nntrainer::TensorDim input_step_dim = input_.getDim(); - nntrainer::TensorDim output_step_dim = output_.getDim(); - nntrainer::TensorDim router_logits_step_dim = router_logits_.getDim(); - - input_step_dim.batch(1); - output_step_dim.batch(1); - router_logits_step_dim.batch(to - from); - - input_step_dim.height(to - from); - output_step_dim.height(to - from); - - for (unsigned int b = 0; b < input_.batch(); ++b) { - - auto input = input_.getSharedDataTensor( - input_step_dim, b * input_step_dim.getFeatureLen(), true); - auto output = output_.getSharedDataTensor( - output_step_dim, b * output_step_dim.getFeatureLen(), true); - auto router_logits = - router_logits_.getSharedDataTensor(router_logits_step_dim, 0, true); - - const unsigned batch_size = input.batch(); - const unsigned seq_len = input.height(); - const unsigned hidden_size = input.width(); - const unsigned total_tokens = batch_size * seq_len; - - // reshape input: [B,1,S,H] -> [B*S,1,1,H] - input.reshape({total_tokens, 1, 1, hidden_size}); - - // reshape output: [B,1,S,H] -> [B*S,1,1,H] - output.reshape({total_tokens, 1, 1, hidden_size}); - - // routing - nntrainer::Tensor &gate_weights = context.getWeight(gate_idx); - input.dot(gate_weights, router_logits); - router_logits.apply(nntrainer::ActiFunc::softmax, router_logits); - - // get extra topK - auto extra_topk_result = router_logits.topK(topk + 3); - auto extra_topk_values = std::get<0>(extra_topk_result); - auto extra_topk_indices = std::get<1>(extra_topk_result); - std::deque extra_top_k = {}; - extra_topk_values.divide_i(extra_topk_values.sum(3)); - const uint32_t *extra_indices_data = extra_topk_indices.getData(); - - // get extra topk - for (int i = static_cast(total_tokens) - 1; i >= 0; --i) { - for (int k = 0; k < static_cast(topk + 3); ++k) { - unsigned expert_idx = extra_indices_data[i * topk + k]; - extra_top_k.push_back(expert_idx); - } - } - - auto topk_result = router_logits.topK(topk); - auto topk_values = std::get<0>(topk_result); - auto topk_indices = std::get<1>(topk_result); - - // norm_topk_prob - topk_values.divide_i(topk_values.sum(3)); - - const uint32_t *indices_data = topk_indices.getData(); - std::vector>> expert_assignments( - num_experts); - // Set expert mask - for (int i = 0; i < static_cast(total_tokens); ++i) { - for (int k = 0; k < static_cast(topk); ++k) { - unsigned expert_idx = indices_data[i * topk + k]; - float weight = topk_values.getValue(i, 0, 0, k); - expert_assignments[expert_idx].emplace_back(i, weight); - } - } - - // Parallel processing for multiple tokens with many active experts - std::vector expert_outputs(num_experts); -#pragma omp parallel for schedule(static) - for (int expert_idx = 0; expert_idx < static_cast(num_experts); - ++expert_idx) { - if (!expert_assignments[expert_idx].empty()) { - expert_outputs[expert_idx] = nntrainer::Tensor( - total_tokens, 1, 1, hidden_size, output.getTensorType()); - } - } - std::vector target_idx_vector; - - for (int expert_idx = 0; expert_idx < static_cast(num_experts); - ++expert_idx) { - const auto &assignments = expert_assignments[expert_idx]; - if (assignments.empty()) - continue; - - target_idx_vector.push_back(expert_idx); - } - - int hit_count = 0; - int miss_count = 0; - -#ifdef DEBUG - auto t1_miss = high_resolution_clock::now(); - auto t2_miss = high_resolution_clock::now(); - auto t1_hit = high_resolution_clock::now(); - auto t2_hit = high_resolution_clock::now(); -#endif - -#pragma omp parallel for schedule(dynamic) - for (int expert_idx : target_idx_vector) { - const auto &assignments = expert_assignments[expert_idx]; - if (need_load[expert_idx]) { - -#ifdef DEBUG - t1_miss = high_resolution_clock::now(); -#endif - - context.getWeight(expert_gate_proj_indices[expert_idx]).activate(); - context.getWeight(expert_up_proj_indices[expert_idx]).activate(); - context.getWeight(expert_down_proj_indices[expert_idx]).activate(); - - context.getWeight(expert_gate_bias_indices[expert_idx]).activate(); - context.getWeight(expert_up_bias_indices[expert_idx]).activate(); - context.getWeight(expert_down_bias_indices[expert_idx]).activate(); - - { - std::lock_guard lock(cache_mutex); - loaded_expert_deque.push_back(expert_idx); - iteration_map[expert_idx] = --loaded_expert_deque.end(); - need_load[expert_idx] = false; - miss_count += 1; - } - - compute_expert_forward( - input, expert_outputs[expert_idx], assignments, - context.getWeight(expert_gate_proj_indices[expert_idx]), - context.getWeight(expert_up_proj_indices[expert_idx]), - context.getWeight(expert_down_proj_indices[expert_idx]), - context.getWeight(expert_gate_bias_indices[expert_idx]), - context.getWeight(expert_up_bias_indices[expert_idx]), - context.getWeight(expert_down_bias_indices[expert_idx]), hidden_size); -#ifdef DEBUG - t2_miss = high_resolution_clock::now(); -#endif - } else { - -#ifdef DEBUG - t1_hit = high_resolution_clock::now(); -#endif - { - std::lock_guard lock(cache_mutex); - hit_count += 1; - } - - compute_expert_forward( - input, expert_outputs[expert_idx], assignments, - context.getWeight(expert_gate_proj_indices[expert_idx]), - context.getWeight(expert_up_proj_indices[expert_idx]), - context.getWeight(expert_down_proj_indices[expert_idx]), - context.getWeight(expert_gate_bias_indices[expert_idx]), - context.getWeight(expert_up_bias_indices[expert_idx]), - context.getWeight(expert_down_bias_indices[expert_idx]), hidden_size); - -#ifdef DEBUG - t2_hit = high_resolution_clock::now(); -#endif - } - } - - for (int i = extra_top_k.size() - 1; i >= 0; i--) { - if (iteration_map.find(extra_top_k[i]) != iteration_map.end()) { - loaded_expert_deque.erase(iteration_map[extra_top_k[i]]); - loaded_expert_deque.push_back(extra_top_k[i]); - iteration_map[extra_top_k[i]] = --loaded_expert_deque.end(); - } - } - -#ifdef DEBUG - auto t1_evict = high_resolution_clock::now(); -#endif - -// Evict experts -#pragma omp parallel - while (loaded_expert_deque.size() > 16) { - int target_idx; - { - std::lock_guard lock(cache_mutex); - target_idx = loaded_expert_deque.front(); - loaded_expert_deque.pop_front(); - iteration_map.erase(target_idx); - need_load[target_idx] = true; - } - context.getWeight(expert_gate_proj_indices[target_idx]).deactivate(); - context.getWeight(expert_up_proj_indices[target_idx]).deactivate(); - context.getWeight(expert_down_proj_indices[target_idx]).deactivate(); - context.getWeight(expert_gate_bias_indices[target_idx]).deactivate(); - context.getWeight(expert_up_bias_indices[target_idx]).deactivate(); - context.getWeight(expert_down_bias_indices[target_idx]).deactivate(); - } - -#ifdef DEBUG - auto t2_evict = high_resolution_clock::now(); -#endif - - // Combine expert outputs - int init = 0; - for (int expert_idx : target_idx_vector) { - if (!init) { - output.copyData(expert_outputs[expert_idx]); - ++init; - } else { - output.add_i(expert_outputs[expert_idx]); - } - } - - // reshape output: [B*S,1,1,H] -> [B,1,S,H] - output.reshape({batch_size, 1, seq_len, hidden_size}); - -#ifdef DEBUG - auto t2 = high_resolution_clock::now(); - auto dt = duration_cast(t2 - t1); - auto dt_miss = duration_cast(t2_miss - t1_miss); - auto dt_hit = duration_cast(t2_hit - t1_hit); - auto dt_evict = duration_cast(t2_evict - t1_evict); - std::cout << context.getName() << " \t| " << dt.count() << " ns " - << "\t| " << dt.count() / 1'000 << " us " - << "\t| " << dt.count() / 1'000'000 << " ms " - << "\t| " - << "hit ratio: " << hit_count / 8.0 << "\t | " - << " miss ratio: " << miss_count / 8.0 << "\t | " - << "hit_compute: " << dt_hit.count() / 1'000'000 << " ms " - << "\t| " - << "miss_compute: " << dt_miss.count() / 1'000'000 << " ms " - << "\t| " - << "evict_time: " << dt_evict.count() / 1'000'000 << " ms " - << "\t| " << std::endl; -#endif - } -} - -inline void CachedSlimGptOssMoELayer::compute_expert_forward( - const nntrainer::Tensor &input, nntrainer::Tensor &expert_output, - const std::vector> &token_assignments, - const nntrainer::Tensor &gate_proj, const nntrainer::Tensor &up_proj, - const nntrainer::Tensor &down_proj, const nntrainer::Tensor &gate_bias, - const nntrainer::Tensor &up_bias, const nntrainer::Tensor &down_bias, - unsigned int hidden_size) { - - const unsigned intermediate_size = gate_proj.width(); - const unsigned num_tokens = token_assignments.size(); - - if (num_tokens == 0) - return; - - // Create tensor dimensions for single token processing - nntrainer::TensorDim token_input_dim({1, 1, num_tokens, hidden_size}, - input.getTensorType()); - nntrainer::TensorDim intermediate_dim({1, 1, num_tokens, intermediate_size}, - input.getTensorType()); - nntrainer::TensorDim token_output_dim({1, 1, num_tokens, hidden_size}, - input.getTensorType()); - nntrainer::TensorDim out_step_dim({1, 1, 1, hidden_size}, - input.getTensorType()); - nntrainer::TensorDim step_dim({1, 1, 1, intermediate_size}, - input.getTensorType()); - // Create intermediate tensors for this token - nntrainer::Tensor gate_out(intermediate_dim); - nntrainer::Tensor acti_out(intermediate_dim); - nntrainer::Tensor up_out(intermediate_dim); - nntrainer::Tensor token_input(token_input_dim); - // Down projection using optimized dot operation - nntrainer::Tensor token_expert_output(token_output_dim); - - unsigned token_idx = token_assignments[0].first; - float weight = token_assignments[0].second; - - if (num_tokens > 1) { - /** if prefill, copy data to make a batch */ -#pragma omp parallel for schedule(static) if (num_tokens > 4) - for (size_t i = 0; i < num_tokens; ++i) { - const unsigned token_idx = token_assignments[i].first; - // Use tensor's optimized copy operation - nntrainer::Tensor src_view = input.getSharedDataTensor( - {1, 1, 1, hidden_size}, token_idx * hidden_size, true); - nntrainer::Tensor dst_view = token_input.getSharedDataTensor( - {1, 1, 1, hidden_size}, i * hidden_size, true); - dst_view.copyData(src_view); - } - } else { - /** if token generation, do not copy but get the shared tensor */ - // Create shared tensor for input token (no memory copy) - size_t token_offset = token_idx * hidden_size; - token_input = - input.getSharedDataTensor(token_input_dim, token_offset, true); - } - - // Gate projection using optimized dot operation - token_input.dot(gate_proj, gate_out); - gate_out.add(gate_bias, gate_out); - // gate_out.clamp(min=None, max=limit) - nntrainer::clamp(gate_out.getData(), gate_out.getData(), - num_tokens * intermediate_size, - std::numeric_limits::lowest(), limit); - - // Up projection using optimized dot operation - token_input.dot(up_proj, up_out); - up_out.add_i(up_bias); - // up_out.clamp(min=-limit, max=limit) - nntrainer::clamp(up_out.getData(), up_out.getData(), - num_tokens * intermediate_size, -limit, limit); - - // Apply activation (silu) - // (up + 1) * (gate * torch.sigmoid(gate * alpha)) - // swiglu : X = Z * (Y / 1 + exp(-alpha * Y)) - // X := acti_out - // Y := gate_out - // Z := up_out + 1 - up_out.add_i(1); -#pragma omp parallel for schedule(static) if (num_tokens > 2) - for (size_t i = 0; i < num_tokens; ++i) { - const unsigned offset = acti_out.getIndex(0, 0, i, 0); - nntrainer::swiglu(acti_out.width(), acti_out.getData() + offset, - gate_out.getData() + offset, - up_out.getData() + offset, alpha); - } - - // Down projection using optimized dot operation - acti_out.dot(down_proj, token_expert_output); - token_expert_output.add_i(down_bias); - - // accumulate to output -#pragma omp parallel for schedule(static) if (num_tokens > 2) - for (size_t i = 0; i < num_tokens; ++i) { - token_idx = token_assignments[i].first; - weight = token_assignments[i].second; - size_t output_offset = token_idx * hidden_size; - nntrainer::Tensor token_output = - expert_output.getSharedDataTensor(out_step_dim, output_offset, true); - nntrainer::Tensor target = token_expert_output.getSharedDataTensor( - out_step_dim, i * hidden_size, true); - target.multiply_i(weight); - token_output.add(target, token_output); - } -} - -void CachedSlimGptOssMoELayer::setProperty( - const std::vector &values) { - auto remain_props = loadProperties(values, moe_props); - nntrainer::LayerImpl::setProperty(remain_props); -} - -void CachedSlimGptOssMoELayer::calcDerivative( - nntrainer::RunLayerContext &context) { - // MoE layer does not support derivative calculation - throw std::runtime_error("MoE layer does not support derivative calculation"); -} - -void CachedSlimGptOssMoELayer::calcGradient( - nntrainer::RunLayerContext &context) { - // MoE layer does not support gradient calculation - throw std::runtime_error("MoE layer does not support gradient calculation"); -} - -void CachedSlimGptOssMoELayer::exportTo( - nntrainer::Exporter &exporter, const ml::train::ExportMethods &method) const { - nntrainer::LayerImpl::exportTo(exporter, method); - exporter.saveResult(moe_props, method, this); // Save MoE specific properties -} - -} // namespace quick_dot_ai \ No newline at end of file diff --git a/models/gpt_oss_cached_slim/gpt_oss_moe_layer_cached.h b/models/gpt_oss_cached_slim/gpt_oss_moe_layer_cached.h deleted file mode 100644 index e417ed16..00000000 --- a/models/gpt_oss_cached_slim/gpt_oss_moe_layer_cached.h +++ /dev/null @@ -1,165 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -/** - * Copyright (C) 2025 Eunju Yang - * - * @file gpt_oss_moe_layer_cached.h - * @date 05 Sep 2025 - * @brief Gpt Oss MoE layer with cached fsu - * @see https://github.com/nntrainer/nntrainer - * @author Eunju Yang - * @bug No known bugs except for NYI items - * @todo This layer does not support backwarding yet. - */ - -#ifndef __GPT_OSS_MOE_LAYER_CACHED_H__ -#define __GPT_OSS_MOE_LAYER_CACHED_H__ -#ifdef __cplusplus - -#include -#include -#include -#include -#include - -namespace quick_dot_ai { - -/** - * @class GptOssMoELayer - * @brief Mixture of Expert Layer - */ -class CachedSlimGptOssMoELayer : public nntrainer::LayerImpl { -public: - /** - * @brief Constructor of Mixture of Expert Layer - */ - CachedSlimGptOssMoELayer(); - - /** - * @brief Destructor of Mixture of Expert Layer - */ - ~CachedSlimGptOssMoELayer() = default; - - /** - * @brief Move constructor. - * @param[in] CachedSlimGptOssMoELayer && - */ - CachedSlimGptOssMoELayer(CachedSlimGptOssMoELayer &&rhs) = delete; - - /** - * @brief Move assignment operator. - * @param[in] rhs CachedSlimGptOssMoELayer to be moved. - */ - CachedSlimGptOssMoELayer &operator=(CachedSlimGptOssMoELayer &&rhs) = delete; - - /** - * @copydoc Layer::finalize(InitLayerContext &context) - */ - void finalize(nntrainer::InitLayerContext &context) override; - - /** - * @copydoc Layer::forwarding(RunLayerContext &context, bool training) - */ - void forwarding(nntrainer::RunLayerContext &context, bool training) override; - - /** - * @copydoc Layer::incremental_forwarding(RunLayerContext &context, unsigned) - */ - void incremental_forwarding(nntrainer::RunLayerContext &context, - unsigned int from, unsigned int to, - bool training) override; - - /** - * @copydoc Layer::calcDerivative(RunLayerContext &context) - */ - void calcDerivative(nntrainer::RunLayerContext &context) override; - - /** - * @copydoc Layer::calcGradient(RunLayerContext &context) - */ - void calcGradient(nntrainer::RunLayerContext &context) override; - - /** - * @copydoc Layer::setProperty(const std::vector &values) - */ - void setProperty(const std::vector &values) override; - - /** - * @copydoc Layer::exportTo(Exporter &exporter, const ml::train::ExportMethods - * &methods) - */ - void exportTo(nntrainer::Exporter &exporter, - const ml::train::ExportMethods &method) const override; - - /** - * @copydoc Layer::getType() - */ - const std::string getType() const override { - return CachedSlimGptOssMoELayer::type; - }; - - /** - * @brief Layer::supportBackwarding() - */ - bool supportBackwarding() const override { return false; } - - static constexpr const char *type = - "gpt_oss_moe_slim_cached"; /**< type of the layer */ - -private: - unsigned int num_experts; /**< number of experts */ - unsigned int topk; /**< number of experts per token, i.e., topk */ - std::tuple - moe_props; - - // weight indeices - std::vector expert_gate_proj_indices; - std::vector expert_gate_bias_indices; - std::vector expert_up_proj_indices; - std::vector expert_up_bias_indices; - std::vector expert_down_proj_indices; - std::vector expert_down_bias_indices; - unsigned int gate_idx; - unsigned int gate_bias_idx; - - std::list loaded_expert_deque; - std::unordered_map::iterator> iteration_map; - std::unordered_map expert_predict_scores; - std::vector need_load; - - // Intermediate tensor indices - unsigned int router_logits_idx; - unsigned int expert_mask_idx; - bool enable_bias = false; - std::mutex cache_mutex; - - float alpha = 1.702; - float limit = 7.0; - - /** - * @brief expert forward computation without critical section - * @param input Input tensor (reshaped to [total_tokens, 1, 1, hidden_size]) - * @param expert_output Expert-specific output tensor - * @param token_assignments Vector of (token_index, weight) pairs for this - * expert - * @param gate_proj Gate projection weight tensor - * @param up_proj Up projection weight tensor - * @param down_proj Down projection weight tensor - * @param gate_bias Gate projection weight tensor - * @param up_bias Up projection weight tensor - * @param down_bias Down projection weight tensor - * @param hidden_size Hidden dimension size - */ - inline void compute_expert_forward( - const nntrainer::Tensor &input, nntrainer::Tensor &expert_output, - const std::vector> &token_assignments, - const nntrainer::Tensor &gate_proj, const nntrainer::Tensor &up_proj, - const nntrainer::Tensor &down_proj, const nntrainer::Tensor &gate_bias, - const nntrainer::Tensor &up_bias, const nntrainer::Tensor &down_bias, - unsigned int hidden_size); -}; - -} // namespace quick_dot_ai - -#endif /** __cplusplus */ -#endif /** __GPT_OSS_MOE_LAYER_CACHED_H__ */ diff --git a/models/gpt_oss_cached_slim/gptoss_cached_slim_causallm.cpp b/models/gpt_oss_cached_slim/gptoss_cached_slim_causallm.cpp deleted file mode 100644 index 05b2dbda..00000000 --- a/models/gpt_oss_cached_slim/gptoss_cached_slim_causallm.cpp +++ /dev/null @@ -1,146 +0,0 @@ -/** - * Copyright (C) 2025 Samsung Electronics Co., Ltd. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * http://www.apache.org/licenses/LICENSE-2.0 - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - * - * @file gptoss_causallm.cpp - * @brief This defines a gpt_oss causal language model. - * @date 26 Aug 2025 - * @see https://github.com/nnstreamer/ - * @author Eunju Yang - * @bug No known bugs except for NYI items - * - */ - -#include -#include -#include - -#include -#include -#include - -namespace quick_dot_ai { - -std::vector GptOssCachedSlimCausalLM::createAttention( - const int layer_id, int seq_len, int n_heads, int head_dim, - std::string query_name, std::string key_name, std::string value_name) { - - std::vector layers; - - ///@note Q/K/V/O has bias! - auto Q = "layer" + std::to_string(layer_id) + "_wq"; - auto K = "layer" + std::to_string(layer_id) + "_wk"; - auto V = "layer" + std::to_string(layer_id) + "_wv"; - auto A = "layer" + std::to_string(layer_id) + "_attention"; - auto O = "layer" + std::to_string(layer_id) + "_attention_out"; - - // V layer - std::vector v_params = { - withKey("name", V), withKey("unit", head_dim * n_heads / GQA_SIZE), - withKey("disable_bias", "false"), withKey("input_layers", value_name), - withKey("weight_initializer", "ones")}; - layers.push_back(createLayer("fully_connected", v_params)); - - // K layer - std::vector k_params = { - withKey("name", K), withKey("unit", head_dim * n_heads / GQA_SIZE), - withKey("disable_bias", "false"), withKey("input_layers", key_name), - withKey("weight_initializer", "ones")}; - layers.push_back(createLayer("fully_connected", k_params)); - - // Q layer - std::vector q_params = { - withKey("name", Q), withKey("unit", head_dim * n_heads), - withKey("disable_bias", "false"), withKey("input_layers", query_name), - withKey("weight_initializer", "ones")}; - layers.push_back(createLayer("fully_connected", q_params)); - - // Attention core layer - // layer_types[layer_id] == "sliding_attention" - // layer_types[layer_id] == "full_attention" - unsigned sliding_window = - (LAYER_TYPES[layer_id] == "sliding_attention") ? SLIDING_WINDOW : UINT_MAX; - // this attention use sink! - std::vector a_params = { - withKey("name", A), - withKey("num_heads", n_heads), - withKey("num_heads_kv", n_heads / GQA_SIZE), - withKey("max_timestep", std::to_string(INIT_SEQ_LEN + NUM_TO_GENERATE)), - withKey("sliding_window", sliding_window), - withKey("rope_theta", ROPE_THETA), - withKey("max_position_embeddings", MAX_POSITION_EMBEDDINGS), - withKey("max_new_tokens", std::to_string(NUM_TO_GENERATE)), - withKey("use_sink", "true"), - withKey("rope_scaling_factor", ATTENTION_ROPE_SCALING_FACTOR), - withKey("rope_scaling_type", "yarn"), - withKey("rope_scaling_max_position_embeddings", 4096), - withKey("input_layers", {Q, K, V})}; - layers.push_back(createLayer("mha_core", a_params)); - - // O layer - std::vector o_params = { - withKey("name", O), withKey("unit", DIM), withKey("disable_bias", "false"), - withKey("input_layers", A), withKey("weight_initializer", "ones")}; - layers.push_back(createLayer("fully_connected", o_params)); - - return layers; -} - -std::vector -GptOssCachedSlimCausalLM::createMlp(const int layer_id, int dim, int hidden_dim, - std::string input_name) { - - std::vector layers; - layers.push_back(createLayer( - "gpt_oss_moe_slim_cached", - { - withKey("name", "layer" + std::to_string(layer_id) + "_ffn_down"), - withKey("input_layers", input_name), - withKey("unit", hidden_dim), - withKey("num_experts", NUM_EXPERTS), - withKey("num_experts_per_token", NUM_EXPERTS_PER_TOK), - })); - - return layers; -} - -void GptOssCachedSlimCausalLM::setupParameters(json &cfg, json &generation_cfg, - json &nntr_cfg) { - CausalLM(cfg, generation_cfg, nntr_cfg); - - try { - NUM_EXPERTS = cfg["num_local_experts"].get(); - NUM_EXPERTS_PER_TOK = cfg["num_experts_per_tok"].get(); - LAYER_TYPES = cfg["layer_types"].get>(); - ATTENTION_ROPE_SCALING_FACTOR = cfg["rope_scaling"]["factor"]; - } catch (const std::exception &e) { - throw std::runtime_error("GptOssCachedSlimCausalLM: config parsing error"); - } -} - -void GptOssCachedSlimCausalLM::registerCustomLayers() { - CausalLM::registerCustomLayers(); - auto &ct_engine = nntrainer::Engine::Global(); - auto app_context = - static_cast(ct_engine.getRegisteredContext("cpu")); - - try { - app_context->registerFactory( - nntrainer::createLayer); - } catch (std::invalid_argument &e) { - std::cerr << "failed to register factory, reason: " << e.what() - << std::endl; - } -} - -} // namespace quick_dot_ai diff --git a/models/gpt_oss_cached_slim/gptoss_cached_slim_causallm.h b/models/gpt_oss_cached_slim/gptoss_cached_slim_causallm.h deleted file mode 100644 index a57ea9ce..00000000 --- a/models/gpt_oss_cached_slim/gptoss_cached_slim_causallm.h +++ /dev/null @@ -1,73 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -/** - * Copyright (C) 2025 Eunju Yang - * - * @file gptoss_causallm.h - * @date 26 Aug 2025 - * @see https://github.com/nntrainer/nntrainer - * @author Eunju Yang - * @bug No known bugs except for NYI items - * @note Please refer to the following code : - * https://github.com/huggingface/transformers/blob/e68146fbe7052a6dc8456f48edabe705dc1f7381/src/transformers/models/gpt_oss/modeling_gpt_oss.py - */ - -#ifndef __GPTOSS_CACHED_SLIM_CAUSALLM_H__ -#define __GPTOSS_CACHED_SLIM_CAUSALLM_H__ - -#include - -namespace quick_dot_ai { - -/** - * @brief GptOssCachedSlimCausalLM - */ -class GptOssCachedSlimCausalLM : public CausalLM { -public: - static constexpr const char *architectures = "GptOssCachedSlimCausalLM"; - - GptOssCachedSlimCausalLM(json &cfg, json &generation_cfg, json &nntr_cfg) : - Transformer(cfg, generation_cfg, nntr_cfg, ModelType::CAUSALLM), - CausalLM(cfg, generation_cfg, nntr_cfg) { - setupParameters(cfg, generation_cfg, nntr_cfg); - } - - virtual ~GptOssCachedSlimCausalLM() = default; - - /** - * @brief createAttention - * @note sink attention with sliding window - */ - std::vector createAttention(const int layer_id, int seq_len, - int n_heads, int head_dim, - std::string query_name, - std::string key_name, - std::string value_name) override; - - /** - * @brief MoE layer - */ - std::vector createMlp(const int layer_id, int dim, - int hidden_dim, - std::string input_name) override; - - /** - * @brief setupParameters - */ - void setupParameters(json &cfg, json &generation_cfg, - json &nntr_cfg) override; - - /** - * @brief registerCutomLayers - */ - void registerCustomLayers() override; - -private: - unsigned int NUM_EXPERTS; - unsigned int NUM_EXPERTS_PER_TOK; - std::vector LAYER_TYPES; - float ATTENTION_ROPE_SCALING_FACTOR; -}; - -} // namespace quick_dot_ai - -#endif /** __GPTOSS_CAUSALLM_H__ */ diff --git a/models/gpt_oss_cached_slim/meson.build b/models/gpt_oss_cached_slim/meson.build deleted file mode 100644 index a94e33d1..00000000 --- a/models/gpt_oss_cached_slim/meson.build +++ /dev/null @@ -1,25 +0,0 @@ -gpt_oss_cached_slim_src = [ - meson.current_source_dir() / 'gptoss_cached_slim_causallm.cpp', -] - -gpt_oss_cached_slim_inc = include_directories('.') - -causallm_gptoss_moe_layer_cached_src_abs = [meson.current_source_dir() / 'gpt_oss_moe_layer_cached.cpp'] - -causallm_cached_slim_gptoss_moe_layer = shared_library( - 'cached_slim_gptoss_moe_layer', - causallm_gptoss_moe_layer_cached_src_abs, - include_directories: [quick_dot_ai_layer_inc, gpt_oss_cached_slim_inc], - dependencies: [nntrainer_dep, nntrainer_ccapi_dep], - install: true, - install_dir: application_install_dir -) - -causallm_cached_slim_gpt_oss_moe_layer_dep = declare_dependency( - link_with: causallm_cached_slim_gptoss_moe_layer, - include_directories: gpt_oss_cached_slim_inc -) - -quick_dot_ai_src += gpt_oss_cached_slim_src -quick_dot_ai_inc += gpt_oss_cached_slim_inc -quick_dot_ai_layer_dependencies += [causallm_cached_slim_gpt_oss_moe_layer_dep] diff --git a/models/meson.build b/models/meson.build deleted file mode 100644 index 9e342b65..00000000 --- a/models/meson.build +++ /dev/null @@ -1,16 +0,0 @@ -quick_dot_ai_src += [ - meson.current_source_dir() / 'causal_lm.cpp', - meson.current_source_dir() / 'transformer.cpp', - meson.current_source_dir() / 'sentence_transformer.cpp', -] - -quick_dot_ai_inc += include_directories('.') - -subdir('gpt_oss') -subdir('gpt_oss_cached_slim') -subdir('qwen2') -subdir('qwen3') -subdir('qwen3_moe') -subdir('qwen3_slim_moe') -subdir('qwen3_cached_slim_moe') -subdir('gemma3') diff --git a/models/performance_metrics.h b/models/performance_metrics.h deleted file mode 100644 index 60e6892d..00000000 --- a/models/performance_metrics.h +++ /dev/null @@ -1,69 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -/** - * Copyright (C) 2026 Samsung Electronics Co., Ltd. All Rights Reserved. - * - * @file performance_metrics.h - * @date 24 Mar 2026 - * @brief Performance metrics definitions shared between models and API layers - * @see https://github.com/nntrainer/nntrainer - * @author Eunju Yang - * @bug No known bugs except for NYI items - */ - -#ifndef __CAUSAL_LM_PERFORMANCE_METRICS_H__ -#define __CAUSAL_LM_PERFORMANCE_METRICS_H__ - -#include - -#ifdef __cplusplus -extern "C" { -#endif - -/** - * @brief Performance Metrics - */ -typedef struct { - unsigned int prefill_tokens; - double prefill_duration_ms; - unsigned int generation_tokens; - double generation_duration_ms; - double total_duration_ms; - double initialization_duration_ms; - size_t peak_memory_kb; -} TransformerPerformanceMetrics; - -#ifdef __cplusplus -} -#endif - -#ifdef __cplusplus - -#ifdef _WIN32 -#include -#include -#else -#include -#endif - -/** - * @brief Get peak memory usage in KB - */ -inline size_t getPeakMemoryKb() { -#if defined(_WIN32) - PROCESS_MEMORY_COUNTERS pmc; - if (GetProcessMemoryInfo(GetCurrentProcess(), &pmc, sizeof(pmc))) { - return (size_t)(pmc.PeakWorkingSetSize / 1024); - } - return 0; -#else - struct rusage rusage; - if (getrusage(RUSAGE_SELF, &rusage) == 0) { - return (size_t)(rusage.ru_maxrss); - } - return 0; -#endif -} - -#endif // __cplusplus - -#endif // __CAUSAL_LM_PERFORMANCE_METRICS_H__ diff --git a/models/qwen2/meson.build b/models/qwen2/meson.build deleted file mode 100644 index 78f15b27..00000000 --- a/models/qwen2/meson.build +++ /dev/null @@ -1,9 +0,0 @@ -qwen2_src = [ - meson.current_source_dir() / 'qwen2_causallm.cpp', - meson.current_source_dir() / 'qwen2_embedding.cpp', -] - -qwen2_inc = include_directories('.') - -quick_dot_ai_inc += qwen2_inc -quick_dot_ai_src += qwen2_src diff --git a/models/qwen2/qwen2_causallm.cpp b/models/qwen2/qwen2_causallm.cpp deleted file mode 100644 index 78c60966..00000000 --- a/models/qwen2/qwen2_causallm.cpp +++ /dev/null @@ -1,76 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -/** - * Copyright (C) 2026 Seunghui Lee - * - * @file qwen2_causallm.h - * @date 6 January 2026 - * @brief This defines a qwen2 causal language model. - * @see https://github.com/nntrainer/nntrainer - * @author Seunghui Lee - * @bug No known bugs except for NYI items - */ -#include -#include -#include - -#include -#include -#include - -namespace quick_dot_ai { - -std::vector Qwen2Transformer::createAttention( - const int layer_id, int seq_len, int n_heads, int head_dim, - std::string query_name, std::string key_name, std::string value_name) { - std::vector layers; - auto Q = "layer" + std::to_string(layer_id) + "_wq"; - auto K = "layer" + std::to_string(layer_id) + "_wk"; - auto V = "layer" + std::to_string(layer_id) + "_wv"; - auto A = "layer" + std::to_string(layer_id) + "_attention"; - auto O = "layer" + std::to_string(layer_id) + "_attention_out"; - - // Q layer - std::vector q_params = { - withKey("name", Q), withKey("unit", head_dim * n_heads), - withKey("disable_bias", "false"), withKey("input_layers", query_name), - withKey("weight_initializer", "ones")}; - layers.push_back(createLayer("fully_connected", q_params)); - - // K layer - std::vector k_params = { - withKey("name", K), withKey("unit", head_dim * n_heads / GQA_SIZE), - withKey("disable_bias", "false"), withKey("input_layers", key_name), - withKey("weight_initializer", "ones")}; - layers.push_back(createLayer("fully_connected", k_params)); - - // V layer - std::vector v_params = { - withKey("name", V), withKey("unit", head_dim * n_heads / GQA_SIZE), - withKey("disable_bias", "false"), withKey("input_layers", value_name), - withKey("weight_initializer", "ones")}; - layers.push_back(createLayer("fully_connected", v_params)); - - // Attention core layer - std::vector a_params = { - withKey("name", A), - withKey("num_heads", n_heads), - withKey("num_heads_kv", n_heads / GQA_SIZE), - withKey("max_timestep", std::to_string(INIT_SEQ_LEN + NUM_TO_GENERATE)), - withKey("sliding_window", SLIDING_WINDOW), - withKey("rope_theta", ROPE_THETA), - withKey("max_position_embeddings", MAX_POSITION_EMBEDDINGS), - withKey("max_new_tokens", std::to_string(NUM_TO_GENERATE)), - withKey("is_causal", IS_CAUSAL ? "true" : "false"), - withKey("input_layers", {Q, K, V})}; - layers.push_back(createLayer("mha_core", a_params)); - - // O layer - std::vector o_params = { - withKey("name", O), withKey("unit", DIM), withKey("disable_bias", "true"), - withKey("input_layers", A), withKey("weight_initializer", "ones")}; - layers.push_back(createLayer("fully_connected", o_params)); - - return layers; -} - -} // namespace quick_dot_ai diff --git a/models/qwen2/qwen2_causallm.h b/models/qwen2/qwen2_causallm.h deleted file mode 100644 index 73ff0ff0..00000000 --- a/models/qwen2/qwen2_causallm.h +++ /dev/null @@ -1,58 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -/** - * Copyright (C) 2026 Seunghui Lee - * - * @file qwen2_causallm.h - * @date 6 January 2026 - * @see https://github.com/nntrainer/nntrainer - * @author Seunghui Lee - * @bug No known bugs except for NYI items - * @note Please refer to the following code : - * https://github.com/huggingface/transformers/blob/v4.37.0/src/transformers/models/qwen2/modeling_qwen2.py - */ - -#ifndef __QWEN2_CAUSAL_LM_H__ -#define __QWEN2_CAUSAL_LM_H__ - -#include - -namespace quick_dot_ai { - -/** - * @brief Qwen2Transformer class - */ -class Qwen2Transformer : virtual public Transformer { - -public: - static constexpr const char *architectures = "Qwen2Transformer"; - - Qwen2Transformer(json &cfg, json &generation_cfg, json &nntr_cfg) : - Transformer(cfg, generation_cfg, nntr_cfg) {} - - virtual ~Qwen2Transformer() = default; - - std::vector createAttention(const int layer_id, int seq_len, - int n_heads, int head_dim, - std::string query_name, - std::string key_name, - std::string value_name) override; -}; - -/** - * @brief Qwen2CausalLM class - */ -class Qwen2CausalLM : public CausalLM, public Qwen2Transformer { - -public: - static constexpr const char *architectures = "Qwen2CausalLM"; - - Qwen2CausalLM(json &cfg, json &generation_cfg, json &nntr_cfg) : - Transformer(cfg, generation_cfg, nntr_cfg, ModelType::CAUSALLM), - CausalLM(cfg, generation_cfg, nntr_cfg), - Qwen2Transformer(cfg, generation_cfg, nntr_cfg) {} - - virtual ~Qwen2CausalLM() = default; -}; -} // namespace quick_dot_ai - -#endif /* __QWEN2_CAUSAL_LM_H__*/ diff --git a/models/qwen2/qwen2_embedding.cpp b/models/qwen2/qwen2_embedding.cpp deleted file mode 100644 index 1ba5e47a..00000000 --- a/models/qwen2/qwen2_embedding.cpp +++ /dev/null @@ -1,14 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -/** - * Copyright (C) 2026 Seunghui Lee - * - * @file qwen2_embedding.cpp - * @date 14 January 2026 - * @see https://github.com/nntrainer/nntrainer - * @author Seunghui Lee - * @bug No known bugs except for NYI items - * @brief This file defines Qwen2 Embedding model - */ -#include "qwen2_embedding.h" - -namespace quick_dot_ai {} // namespace quick_dot_ai diff --git a/models/qwen2/qwen2_embedding.h b/models/qwen2/qwen2_embedding.h deleted file mode 100644 index 071cbb8f..00000000 --- a/models/qwen2/qwen2_embedding.h +++ /dev/null @@ -1,37 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -/** - * Copyright (C) 2026 Seunghui Lee - * - * @file qwen2_embedding.h - * @date 14 January 2026 - * @see https://github.com/nntrainer/nntrainer - * @author Seunghui Lee - * @bug No known bugs except for NYI items - * @note This qwen2_embedding.h constructs a class for Qwen2-based Embedding - * model. - */ -#ifndef __QWEN2_EMBEDDING_H__ -#define __QWEN2_EMBEDDING_H__ - -#include -#include - -namespace quick_dot_ai { - -/** - * @brief Qwen2Embedding class - */ -class Qwen2Embedding : public SentenceTransformer, public Qwen2Transformer { - -public: - Qwen2Embedding(json &cfg, json &generation_cfg, json &nntr_cfg) : - Transformer(cfg, generation_cfg, nntr_cfg, ModelType::EMBEDDING), - SentenceTransformer(cfg, generation_cfg, nntr_cfg), - Qwen2Transformer(cfg, generation_cfg, nntr_cfg) {} - - virtual ~Qwen2Embedding() {} -}; - -} // namespace quick_dot_ai - -#endif /* __QWEN2_EMBEDDING_H__ */ diff --git a/models/qwen3/README.md b/models/qwen3/README.md deleted file mode 100644 index bbc33dd5..00000000 --- a/models/qwen3/README.md +++ /dev/null @@ -1,6 +0,0 @@ -# Qwen3 Model - -This directory contains the implementation for Qwen3 Causal LM. - -## Files -- `qwen3_causallm.cpp`: Qwen3 specific implementation. diff --git a/models/qwen3/meson.build b/models/qwen3/meson.build deleted file mode 100644 index 449c039a..00000000 --- a/models/qwen3/meson.build +++ /dev/null @@ -1,9 +0,0 @@ -qwen3_src = [ - meson.current_source_dir() / 'qwen3_causallm.cpp', - meson.current_source_dir() / 'qwen3_embedding.cpp', -] - -qwen3_inc = include_directories('.') - -quick_dot_ai_src += qwen3_src -quick_dot_ai_inc += qwen3_inc diff --git a/models/qwen3/qwen3_causallm.cpp b/models/qwen3/qwen3_causallm.cpp deleted file mode 100644 index 966d5127..00000000 --- a/models/qwen3/qwen3_causallm.cpp +++ /dev/null @@ -1,125 +0,0 @@ -/** - * Copyright (C) 2025 Samsung Electronics Co., Ltd. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * http://www.apache.org/licenses/LICENSE-2.0 - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - * - * @file qwen3_causallm.cpp - * @date 23 July 2025 - * @brief This defines a qwen3 causal language model. - * @see https://github.com/nnstreamer/ - * @author Eunju Yang - * @bug No known bugs except for NYI items - * - */ -#include -#include -#include - -#include -#include -#include - -namespace quick_dot_ai { - -std::vector Qwen3Transformer::createAttention( - const int layer_id, int seq_len, int n_heads, int head_dim, - std::string query_name, std::string key_name, std::string value_name) { - - std::vector layers; - auto Q = "layer" + std::to_string(layer_id) + "_wq"; - auto Q_norm = "layer" + std::to_string(layer_id) + "_q_norm"; - auto K = "layer" + std::to_string(layer_id) + "_wk"; - auto K_norm = "layer" + std::to_string(layer_id) + "_k_norm"; - auto V = "layer" + std::to_string(layer_id) + "_wv"; - auto A = "layer" + std::to_string(layer_id) + "_attention"; - auto O = "layer" + std::to_string(layer_id) + "_attention_out"; - - // Q layer - std::vector q_params = { - withKey("name", Q), withKey("unit", head_dim * n_heads), - withKey("disable_bias", "true"), withKey("input_layers", query_name), - withKey("weight_initializer", "ones")}; - layers.push_back(createLayer("fully_connected", q_params)); - - // Q-reshaped-norm layer - // q_norm(q_proj.view(hidden_shape)) - std::vector q_norm_params = { - withKey("name", Q_norm), withKey("input_layers", Q), - withKey("packed", "false"), withKey("epsilon", std::to_string(NORM_EPS)), - withKey("feature_size", std::to_string(head_dim))}; - layers.push_back(createLayer("reshaped_rms_norm", q_norm_params)); - - // K layer - std::vector k_params = { - withKey("name", K), withKey("unit", head_dim * n_heads / GQA_SIZE), - withKey("disable_bias", "true"), withKey("input_layers", key_name), - withKey("weight_initializer", "ones")}; - layers.push_back(createLayer("fully_connected", k_params)); - - // K-reshaped-norm layer - // k_norm(k_proj.view(hidden_shape)) - std::vector k_norm_params = { - withKey("name", K_norm), withKey("input_layers", K), - withKey("packed", "false"), withKey("epsilon", std::to_string(NORM_EPS)), - withKey("feature_size", std::to_string(head_dim))}; - layers.push_back(createLayer("reshaped_rms_norm", k_norm_params)); - - // V layer - std::vector v_params = { - withKey("name", V), withKey("unit", head_dim * n_heads / GQA_SIZE), - withKey("disable_bias", "true"), withKey("input_layers", value_name), - withKey("weight_initializer", "ones")}; - layers.push_back(createLayer("fully_connected", v_params)); - - // Attention core layer - std::vector a_params = { - withKey("name", A), - withKey("num_heads", n_heads), - withKey("num_heads_kv", n_heads / GQA_SIZE), - withKey("max_timestep", std::to_string(INIT_SEQ_LEN + NUM_TO_GENERATE)), - withKey("sliding_window", SLIDING_WINDOW), - withKey("rope_theta", ROPE_THETA), - withKey("max_position_embeddings", MAX_POSITION_EMBEDDINGS), - withKey("max_new_tokens", std::to_string(NUM_TO_GENERATE)), - withKey("input_layers", {Q_norm, K_norm, V})}; - layers.push_back(createLayer("mha_core", a_params)); - - // O layer - std::vector o_params = { - withKey("name", O), withKey("unit", DIM), withKey("disable_bias", "true"), - withKey("input_layers", A), withKey("weight_initializer", "ones")}; - layers.push_back(createLayer("fully_connected", o_params)); - - return layers; -} - -void Qwen3Transformer::registerCustomLayers() { - /// - auto &ct_engine = nntrainer::Engine::Global(); - auto app_context = - static_cast(ct_engine.getRegisteredContext("cpu")); - - try { - app_context->registerFactory( - nntrainer::createLayer); - } catch (std::invalid_argument &e) { - std::cerr << "failed to register factory, reason: " << e.what() - << std::endl; - } -} - -void Qwen3CausalLM::registerCustomLayers() { - CausalLM::registerCustomLayers(); - Qwen3Transformer::registerCustomLayers(); -} - -} // namespace quick_dot_ai diff --git a/models/qwen3/qwen3_causallm.h b/models/qwen3/qwen3_causallm.h deleted file mode 100644 index fcf16ef1..00000000 --- a/models/qwen3/qwen3_causallm.h +++ /dev/null @@ -1,63 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -/** - * Copyright (C) 2025 Eunju Yang - * - * @file qwen3_causallm.h - * @date 10 July 2025 - * @see https://github.com/nntrainer/nntrainer - * @author Eunju Yang - * @bug No known bugs except for NYI items - * @note Please refer to the following code : - * https://github.com/huggingface/transformers/blob/v4.52.3/src/transformers/models/qwen3/modeling_qwen3.py - */ - -#ifndef __QWEN_CAUSAL_LM_H__ -#define __QWEN_CAUSAL_LM_H__ __QWEN_CAUSAL_LM_H__ - -#include - -namespace quick_dot_ai { - -/** - * @brief Qwen3Transformer class - */ -class Qwen3Transformer : virtual public Transformer { -public: - static constexpr const char *architectures = "Qwen3Transformer"; - - Qwen3Transformer(json &cfg, json &generation_cfg, json &nntr_cfg) : - Transformer(cfg, generation_cfg, nntr_cfg) {} - - virtual ~Qwen3Transformer() = default; - - std::vector createAttention(const int layer_id, int seq_len, - int n_heads, int head_dim, - std::string query_name, - std::string key_name, - std::string value_name) override; - - void registerCustomLayers() override; -}; - -/** - * @brief Qwen3CausalLM class - */ -class Qwen3CausalLM : public CausalLM, public Qwen3Transformer { - -public: - static constexpr const char *architectures = "Qwen3ForCausalLM"; - - Qwen3CausalLM(json &cfg, json &generation_cfg, json &nntr_cfg) : - Transformer(cfg, generation_cfg, nntr_cfg, ModelType::CAUSALLM), - CausalLM(cfg, generation_cfg, nntr_cfg), - Qwen3Transformer(cfg, generation_cfg, nntr_cfg) {} - - virtual ~Qwen3CausalLM() = default; - - void registerCustomLayers() override; - -private: -}; -} // namespace quick_dot_ai - -#endif /* __QWEN3_CAUSAL_LM_H__ */ diff --git a/models/qwen3/qwen3_embedding.cpp b/models/qwen3/qwen3_embedding.cpp deleted file mode 100644 index 24e21e3d..00000000 --- a/models/qwen3/qwen3_embedding.cpp +++ /dev/null @@ -1,22 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -/** - * Copyright (C) 2025 Samsung Electronics Co., Ltd. All Rights Reserved. - * - * @file qwen3_embedding.cpp - * @date 07 Jan 2026 - * @see https://github.com/nntrainer/nntrainer - * @author Seungbaek Hong - * @bug No known bugs except for NYI items - * @brief This file defines Qwen3 Embedding model - */ - -#include - -namespace quick_dot_ai { - -void Qwen3Embedding::registerCustomLayers() { - SentenceTransformer::registerCustomLayers(); - Qwen3Transformer::registerCustomLayers(); -} - -} // namespace quick_dot_ai diff --git a/models/qwen3/qwen3_embedding.h b/models/qwen3/qwen3_embedding.h deleted file mode 100644 index 9afa462e..00000000 --- a/models/qwen3/qwen3_embedding.h +++ /dev/null @@ -1,54 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -/** - * Copyright (C) 2025 Samsung Electronics Co., Ltd. All Rights Reserved. - * - * @file qwen3_embedding.h - * @date 07 Jan 2026 - * @see https://github.com/nntrainer/nntrainer - * @author Seungbaek Hong - * @bug No known bugs except for NYI items - * @note This qwen3_embedding.h constructs a class for Qwen3-based Embedding - * model. - */ - -#ifndef __QWEN3_EMBEDDING_H__ -#define __QWEN3_EMBEDDING_H__ - -#include -#include - -namespace quick_dot_ai { - -/** - * @brief Qwen3Embedding Class - */ -class Qwen3Embedding : public SentenceTransformer, public Qwen3Transformer { - -public: - static constexpr const char *architectures = "Qwen3Embedding"; - - /** - * @brief Construct a new Qwen3Embedding object - * @param cfg Configuration for the model - * @param generation_cfg Configuration for generation - * @param nntr_cfg Configuration for nntrainer - */ - Qwen3Embedding(json &cfg, json &generation_cfg, json &nntr_cfg) : - Transformer(cfg, generation_cfg, nntr_cfg, ModelType::EMBEDDING), - SentenceTransformer(cfg, generation_cfg, nntr_cfg), - Qwen3Transformer(cfg, generation_cfg, nntr_cfg) {} - - /** - * @brief Destroy the Qwen3Embedding object - */ - virtual ~Qwen3Embedding() = default; - - /** - * @brief register CustomLayers - */ - void registerCustomLayers() override; -}; - -} // namespace quick_dot_ai - -#endif // __QWEN3_EMBEDDING_H__ diff --git a/models/qwen3_cached_slim_moe/README.md b/models/qwen3_cached_slim_moe/README.md deleted file mode 100644 index 40001053..00000000 --- a/models/qwen3_cached_slim_moe/README.md +++ /dev/null @@ -1,9 +0,0 @@ -# Qwen3 Cached Slim MoE Model - -This directory contains the implementation for Qwen3 Slim MoE model with caching support. - -> πŸ“Œ **Note** on `Cached-Slim`: This model extends the Slim approach (dynamic loading) by caching active experts. This strategy minimizes storage I/O bottlenecks, offering a sweet spot between low memory footprint and high inference speed. - -## Files -- `qwen3_cached_slim_moe_causallm.cpp`: Cached Slim MoE implementation. -- `qwen_moe_layer_cached.cpp`: Cached MoE layer implementation. diff --git a/models/qwen3_cached_slim_moe/meson.build b/models/qwen3_cached_slim_moe/meson.build deleted file mode 100644 index 099242f1..00000000 --- a/models/qwen3_cached_slim_moe/meson.build +++ /dev/null @@ -1,25 +0,0 @@ -qwen3_cached_slim_moe_src = [ - meson.current_source_dir() / 'qwen3_cached_slim_moe_causallm.cpp', -] - -qwen3_cached_slim_moe_inc = include_directories('.') - -quick_dot_ai_cached_slim_moe_layer_src_abs = [meson.current_source_dir() / 'qwen_moe_layer_cached.cpp'] - -quick_dot_ai_cached_slim_moe_layer = shared_library( - 'qwen_moe_layer_cached', - quick_dot_ai_cached_slim_moe_layer_src_abs, - include_directories: [quick_dot_ai_layer_inc, qwen3_cached_slim_moe_inc], - dependencies: [nntrainer_dep, nntrainer_ccapi_dep, openmp_dep], - install: true, - install_dir: application_install_dir -) - -quick_dot_ai_cached_slim_moe_layer_dep = declare_dependency( - link_with: quick_dot_ai_cached_slim_moe_layer, - include_directories: qwen3_cached_slim_moe_inc -) - -quick_dot_ai_src += qwen3_cached_slim_moe_src -quick_dot_ai_inc += qwen3_cached_slim_moe_inc -quick_dot_ai_layer_dependencies += [quick_dot_ai_cached_slim_moe_layer_dep] diff --git a/models/qwen3_cached_slim_moe/qwen3_cached_slim_moe_causallm.cpp b/models/qwen3_cached_slim_moe/qwen3_cached_slim_moe_causallm.cpp deleted file mode 100644 index 18914ead..00000000 --- a/models/qwen3_cached_slim_moe/qwen3_cached_slim_moe_causallm.cpp +++ /dev/null @@ -1,81 +0,0 @@ -/** - * Copyright (C) 2025 Samsung Electronics Co., Ltd. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * http://www.apache.org/licenses/LICENSE-2.0 - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - * - * @file qwen3_moe_causallm.cpp - * @date 23 July 2025 - * @brief This defines a qwen3_moe causal language model. - * @see https://github.com/nnstreamer/ - * @author Eunju Yang - * @bug No known bugs except for NYI items - * - */ -#include -#include -#include - -#include -#include -#include - -namespace quick_dot_ai { - -void Qwen3CachedSlimMoECausalLM::setupParameters(json &cfg, - json &generation_cfg, - json &nntr_cfg) { - Qwen3CausalLM(cfg, generation_cfg, nntr_cfg); - - // parameters for Qwen3MoE model - try { - NUM_EXPERTS = cfg["num_experts"]; - NUM_EXPERTS_PER_TOK = cfg["num_experts_per_tok"]; - INTERMEDIATE_SIZE = cfg["moe_intermediate_size"]; - } catch (const std::exception &e) { - throw std::runtime_error("Qwen3MoE: num_experts and num_experts_per_tok " - "are not specified in the config file"); - } -} - -std::vector -Qwen3CachedSlimMoECausalLM::createMlp(const int layer_id, int dim, - int hidden_dim, std::string input_name) { - - std::vector layers; - layers.push_back(createLayer( - "moe_cached_slim", - {withKey("name", "layer" + std::to_string(layer_id) + "_ffn_down"), - withKey("input_layers", input_name), withKey("unit", hidden_dim), - withKey("num_experts", NUM_EXPERTS), - withKey("num_experts_per_token", NUM_EXPERTS_PER_TOK), - withKey("moe_activation", "swish")})); - - return layers; -} - -void Qwen3CachedSlimMoECausalLM::registerCustomLayers() { - - Qwen3CausalLM::registerCustomLayers(); - auto &ct_engine = nntrainer::Engine::Global(); - auto app_context = - static_cast(ct_engine.getRegisteredContext("cpu")); - - try { - app_context->registerFactory( - nntrainer::createLayer); - } catch (std::invalid_argument &e) { - std::cerr << "failed to register factory, reason: " << e.what() - << std::endl; - } -} - -} // namespace quick_dot_ai diff --git a/models/qwen3_cached_slim_moe/qwen3_cached_slim_moe_causallm.h b/models/qwen3_cached_slim_moe/qwen3_cached_slim_moe_causallm.h deleted file mode 100644 index 79d56336..00000000 --- a/models/qwen3_cached_slim_moe/qwen3_cached_slim_moe_causallm.h +++ /dev/null @@ -1,53 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -/** - * Copyright (C) 2025 Eunju Yang - * - * @file qwen3_moe_causallm.h - * @date 15 July 2025 - * @see https://github.com/nntrainer/nntrainer - * @author Eunju Yang - * @bug No known bugs except for NYI items - * - */ - -#ifndef __QWEN_CACHED_SLIM_MOE_CAUSAL_LM_H__ -#define __QWEN_CACHED_SLIM_MOE_CAUSAL_LM_H__ - -#include -#include - -namespace quick_dot_ai { - -/** - * @brief Qwen3CachedSlimMoECausalLM class - * @note This class inherits Qwewn3CaUSALlm - */ -class Qwen3CachedSlimMoECausalLM : public Qwen3CausalLM { - -public: - static constexpr const char *architectures = "Qwen3CachedSlimMoeForCausalLM"; - - Qwen3CachedSlimMoECausalLM(json &cfg, json &generation_cfg, json &nntr_cfg) : - Transformer(cfg, generation_cfg, nntr_cfg, ModelType::CAUSALLM), - Qwen3CausalLM(cfg, generation_cfg, nntr_cfg) { - setupParameters(cfg, generation_cfg, nntr_cfg); - } - - virtual ~Qwen3CachedSlimMoECausalLM() = default; - - std::vector createMlp(const int layer_id, int dim, - int hidden_dim, - std::string input_name) override; - - void setupParameters(json &cfg, json &generation_cfg, - json &nntr_cfg) override; - - void registerCustomLayers() override; - -private: - unsigned int NUM_EXPERTS; - unsigned int NUM_EXPERTS_PER_TOK; -}; -}; // namespace quick_dot_ai - -#endif /* __QWEN_MOE_CAUSAL_LM_H__ */ diff --git a/models/qwen3_cached_slim_moe/qwen_moe_layer_cached.cpp b/models/qwen3_cached_slim_moe/qwen_moe_layer_cached.cpp deleted file mode 100644 index 37db86e6..00000000 --- a/models/qwen3_cached_slim_moe/qwen_moe_layer_cached.cpp +++ /dev/null @@ -1,535 +0,0 @@ -/** - * Copyright (C) 2020 Samsung Electronics Co., Ltd. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * http://www.apache.org/licenses/LICENSE-2.0 - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - * - * @file qwen_moe_layer_fsu.cpp - * @date 09 June 2025 - * @brief This is a Mixture of Expert Layer Class for Neural Network - * @see https://github.com/nnstreamer/ - * @author Eunju Yang - * @bug No known bugs except for NYI items - * @note MoE layer with on-the-fly expert FSU - * - */ - -#include -#include -#include -#include -#include -#include -#include -#include - -#include -using std::chrono::duration_cast; -using std::chrono::high_resolution_clock; -using std::chrono::nanoseconds; - -namespace quick_dot_ai { - -static constexpr size_t SINGLE_INOUT_IDX = 0; - -CachedSlimMoELayer::CachedSlimMoELayer() : - LayerImpl(), - num_experts(0), - topk(0), - moe_props(props::NumExperts(), props::NumExpertsPerToken(), - nntrainer::props::Unit(), props::MoEActivation()), - expert_gate_proj_indices({}), - expert_up_proj_indices({}), - expert_down_proj_indices({}), - loaded_expert_deque({}), - need_load({}), - gate_idx(std::numeric_limits::max()), - router_logits_idx(std::numeric_limits::max()), - expert_mask_idx(std::numeric_limits::max()) {} - -void CachedSlimMoELayer::finalize(nntrainer::InitLayerContext &context) { - - // 1. Validate input/output dimensions - NNTR_THROW_IF(context.getNumInputs() != 1, std::invalid_argument) - << "MoE layer only supports single input"; - - auto &weight_regularizer = - std::get(*layer_impl_props); - auto &weight_regularizer_constant = - std::get(*layer_impl_props); - auto &weight_initializer = - std::get(*layer_impl_props); - auto &weight_decay = - std::get(*layer_impl_props); - - // 2. Set output dimensions (same as input) - const auto &in_dim = context.getInputDimensions()[SINGLE_INOUT_IDX]; - const bool is_nchw = context.getFormat() == nntrainer::Tformat::NCHW; - std::vector output_dims(1); - output_dims[SINGLE_INOUT_IDX] = in_dim; - context.setOutputDimensions(output_dims); - - // 3. Get MoE properties - num_experts = std::get(moe_props).get(); - topk = std::get(moe_props).get(); - const unsigned int intermediate_size = - std::get(moe_props).get(); - const unsigned int hidden_size = in_dim.width(); // Feature dimension - - // activation function - if (std::get(moe_props).empty()) { - throw std::runtime_error("Activation type is not set for MoE layer"); - } - switch (context.getActivationDataType()) { - case ml::train::TensorDim::DataType::FP32: - acti_func.setActiFunc( - std::get(moe_props).get()); - break; - default: - throw std::runtime_error("Unsupported activation data type for MoE layer"); - } - - // 4. Initialie gate layer (router) - nntrainer::TensorDim gate_dim( - 1, is_nchw ? 1 : num_experts, is_nchw ? hidden_size : 1, - is_nchw ? num_experts : hidden_size, - nntrainer::TensorDim::TensorType(context.getFormat(), - nntrainer::TensorDim::DataType::FP32), - is_nchw ? 0b0011 : 0b0101); - - gate_idx = context.requestWeight( - gate_dim, weight_initializer, weight_regularizer, - weight_regularizer_constant, weight_decay, "gate", true); - - // 5. Initializer expert weights - expert_gate_proj_indices.reserve(num_experts); - expert_up_proj_indices.reserve(num_experts); - expert_down_proj_indices.reserve(num_experts); - - nntrainer::TensorDim expert_gate_dim( - 1, is_nchw ? 1 : intermediate_size, is_nchw ? hidden_size : 1, - is_nchw ? intermediate_size : hidden_size, - nntrainer::TensorDim::TensorType(context.getFormat(), - context.getWeightDataType()), - is_nchw ? 0b0011 : 0b0101); - - nntrainer::TensorDim expert_down_dim( - 1, is_nchw ? 1 : hidden_size, is_nchw ? intermediate_size : 1, - is_nchw ? hidden_size : intermediate_size, - nntrainer::TensorDim::TensorType(context.getFormat(), - context.getWeightDataType()), - is_nchw ? 0b0011 : 0b0101); - - for (unsigned int i = 0; i < num_experts; ++i) { - // Up projection - expert_up_proj_indices.push_back(context.requestWeight( - expert_gate_dim, // Same dimensions as gate projection - weight_initializer, weight_regularizer, weight_regularizer_constant, - weight_decay, "expert_up_" + std::to_string(i), false, true)); - - // Gate projection - expert_gate_proj_indices.push_back(context.requestWeight( - expert_gate_dim, weight_initializer, weight_regularizer, - weight_regularizer_constant, weight_decay, - "expert_gate_" + std::to_string(i), false, true)); - - // Down projection - expert_down_proj_indices.push_back(context.requestWeight( - expert_down_dim, weight_initializer, weight_regularizer, - weight_regularizer_constant, weight_decay, - "expert_down_" + std::to_string(i), false, true)); - need_load.push_back(true); - } - - // 6. Request intermediate tensors - const unsigned batch_size = in_dim.batch(); - const unsigned seq_len = in_dim.height(); - const unsigned total_tokens = batch_size * seq_len; - - // Router logits : [batch * seq, num_experts] - router_logits_idx = - context.requestTensor({total_tokens, 1, 1, num_experts}, "router_logits", - nntrainer::Initializer::NONE, false, - nntrainer::TensorLifespan::FORWARD_FUNC_LIFESPAN); - - // Expert mask: [num_experts, batch*seq] - expert_mask_idx = - context.requestTensor({num_experts, 1, topk, total_tokens}, "expert_mask", - nntrainer::Initializer::ZEROS, false, - nntrainer::TensorLifespan::FORWARD_FUNC_LIFESPAN); -} - -void CachedSlimMoELayer::forwarding(nntrainer::RunLayerContext &context, - bool training) {} - -inline void CachedSlimMoELayer::compute_expert_forward( - const nntrainer::Tensor &input, nntrainer::Tensor &output, - const std::vector> &token_assignments, - const nntrainer::Tensor &gate_proj, const nntrainer::Tensor &up_proj, - const nntrainer::Tensor &down_proj, unsigned int hidden_size) { - - const unsigned intermediate_size = gate_proj.width(); - const unsigned num_tokens = token_assignments.size(); - - if (num_tokens == 0) - return; - - // Create tensor dimensions for single token processing - nntrainer::TensorDim token_input_dim({1, 1, num_tokens, hidden_size}, - input.getTensorType()); - nntrainer::TensorDim intermediate_dim({1, 1, num_tokens, intermediate_size}, - input.getTensorType()); - nntrainer::TensorDim token_output_dim({1, 1, num_tokens, hidden_size}, - input.getTensorType()); - nntrainer::TensorDim out_step_dim({1, 1, 1, hidden_size}, - input.getTensorType()); - nntrainer::TensorDim step_dim({1, 1, 1, intermediate_size}, - input.getTensorType()); - // Create intermediate tensors for this token - nntrainer::Tensor gate_out(intermediate_dim); - nntrainer::Tensor acti_out(intermediate_dim); - nntrainer::Tensor up_out(intermediate_dim); - nntrainer::Tensor token_input(token_input_dim); - // Down projection using optimized dot operation - nntrainer::Tensor token_expert_output(token_output_dim); - - unsigned token_idx = token_assignments[0].first; - float weight = token_assignments[0].second; - - if (num_tokens > 1) { - /** if prefill, copy data to make a batch */ -#pragma omp parallel for schedule(static) if (num_tokens > 4) - for (size_t i = 0; i < num_tokens; ++i) { - const unsigned token_idx = token_assignments[i].first; - // Use tensor's optimized copy operation - nntrainer::Tensor src_view = input.getSharedDataTensor( - {1, 1, 1, hidden_size}, token_idx * hidden_size, true); - nntrainer::Tensor dst_view = token_input.getSharedDataTensor( - {1, 1, 1, hidden_size}, i * hidden_size, true); - dst_view.copyData(src_view); - } - } else { - /** if token generation, do not copy but get the shared tensor */ - // Create shared tensor for input token (no memory copy) - size_t token_offset = token_idx * hidden_size; - token_input = - input.getSharedDataTensor(token_input_dim, token_offset, true); - } - - // Gate projection using optimized dot operation - token_input.dot(gate_proj, gate_out); - - // Up projection using optimized dot operation - token_input.dot(up_proj, up_out); - - if (num_tokens == 1) { - // Apply activation (silu) - acti_func.run_fn(gate_out, acti_out); - // Element-wise multiply: silu(gate_out) * up_out - acti_out.multiply_i(up_out); - } else { -#pragma omp parallel for schedule(static) if (num_tokens > 4) - for (size_t i = 0; i < num_tokens; ++i) { - const unsigned offset = acti_out.getIndex(0, 0, i, 0); - nntrainer::swiglu(acti_out.width(), acti_out.getData() + offset, - gate_out.getData() + offset, - up_out.getData() + offset); - } - } - - acti_out.dot(down_proj, token_expert_output); - - // accumulate to output - for (size_t i = 0; i < num_tokens; ++i) { - token_idx = token_assignments[i].first; - weight = token_assignments[i].second; - size_t output_offset = token_idx * hidden_size; - nntrainer::Tensor token_output = - output.getSharedDataTensor(out_step_dim, output_offset, true); - nntrainer::Tensor target = token_expert_output.getSharedDataTensor( - out_step_dim, i * hidden_size, true); - target.multiply_i(weight); - token_output.add(target, token_output); - } -} - -void CachedSlimMoELayer::incremental_forwarding( - nntrainer::RunLayerContext &context, unsigned int from, unsigned int to, - bool training) { - -#ifdef DEBUG - auto t1 = high_resolution_clock::now(); -#endif - - nntrainer::Tensor &input_ = context.getInput(SINGLE_INOUT_IDX); - nntrainer::Tensor &output_ = context.getOutput(SINGLE_INOUT_IDX); - - nntrainer::Tensor &router_logits_ = context.getTensor(router_logits_idx); - - nntrainer::TensorDim input_step_dim = input_.getDim(); - nntrainer::TensorDim output_step_dim = output_.getDim(); - nntrainer::TensorDim router_logits_step_dim = router_logits_.getDim(); - - input_step_dim.batch(1); - output_step_dim.batch(1); - router_logits_step_dim.batch(to - from); - - input_step_dim.height(to - from); - output_step_dim.height(to - from); - - for (unsigned int b = 0; b < input_.batch(); ++b) { - - auto input = input_.getSharedDataTensor( - input_step_dim, b * input_step_dim.getFeatureLen(), true); - auto output = output_.getSharedDataTensor( - output_step_dim, b * output_step_dim.getFeatureLen(), true); - auto router_logits = - router_logits_.getSharedDataTensor(router_logits_step_dim, 0, true); - - const unsigned batch_size = input.batch(); - const unsigned seq_len = input.height(); - const unsigned hidden_size = input.width(); - const unsigned total_tokens = batch_size * seq_len; - - // reshape input: [B,1,S,H] -> [B*S,1,1,H] - input.reshape({total_tokens, 1, 1, hidden_size}); - - // reshape output: [B,1,S,H] -> [B*S,1,1,H] - output.reshape({total_tokens, 1, 1, hidden_size}); - output.setZero(); - - // routing - nntrainer::Tensor &gate_weights = context.getWeight(gate_idx); - input.dot(gate_weights, router_logits); - router_logits.apply(nntrainer::ActiFunc::softmax, router_logits); - - // get extra topK - auto extra_topk_result = router_logits.topK(topk + 5); - auto extra_topk_values = std::get<0>(extra_topk_result); - auto extra_topk_indices = std::get<1>(extra_topk_result); - std::deque extra_top_k = {}; - extra_topk_values.divide_i(extra_topk_values.sum(3)); - const uint32_t *extra_indices_data = extra_topk_indices.getData(); - - // get extra topk - for (int i = static_cast(total_tokens) - 1; i >= 0; --i) { - for (int k = 0; k < static_cast(topk + 5); ++k) { - unsigned expert_idx = extra_indices_data[i * topk + k]; - extra_top_k.push_back(expert_idx); - } - } - - auto topk_result = router_logits.topK(topk); - auto topk_values = std::get<0>(topk_result); - auto topk_indices = std::get<1>(topk_result); - - // norm_topk_prob - topk_values.divide_i(topk_values.sum(3)); - - const uint32_t *indices_data = topk_indices.getData(); - std::vector>> expert_assignments( - num_experts); - // Set expert mask - for (int i = 0; i < static_cast(total_tokens); ++i) { - for (int k = 0; k < static_cast(topk); ++k) { - unsigned expert_idx = indices_data[i * topk + k]; - float weight = topk_values.getValue(i, 0, 0, k); - expert_assignments[expert_idx].emplace_back(i, weight); - } - } - - // Parallel processing for multiple tokens with many active experts - std::vector expert_outputs(num_experts); -#pragma omp parallel for schedule(static) - for (int expert_idx = 0; expert_idx < static_cast(num_experts); - ++expert_idx) { - if (!expert_assignments[expert_idx].empty()) { - expert_outputs[expert_idx] = nntrainer::Tensor( - total_tokens, 1, 1, hidden_size, output.getTensorType()); - } - } - std::vector target_idx_vector; - - for (int expert_idx = 0; expert_idx < static_cast(num_experts); - ++expert_idx) { - const auto &assignments = expert_assignments[expert_idx]; - if (assignments.empty()) - continue; - - target_idx_vector.push_back(expert_idx); - } - - int hit_count = 0; - int miss_count = 0; - -#ifdef DEBUG - auto t1_miss = high_resolution_clock::now(); - auto t2_miss = t1_miss; - auto t1_hit = high_resolution_clock::now(); - auto t2_hit = t1_hit; -#endif - -#pragma omp parallel for schedule(dynamic) - for (int expert_idx : target_idx_vector) { - const auto &assignments = expert_assignments[expert_idx]; - if (need_load[expert_idx]) { - -#ifdef DEBUG - t1_miss = high_resolution_clock::now(); -#endif - - context.getWeight(expert_gate_proj_indices[expert_idx]).activate(); - context.getWeight(expert_up_proj_indices[expert_idx]).activate(); - context.getWeight(expert_down_proj_indices[expert_idx]).activate(); - - { - std::lock_guard lock(cache_mutex); - loaded_expert_deque.push_back(expert_idx); - iteration_map[expert_idx] = --loaded_expert_deque.end(); - need_load[expert_idx] = false; - miss_count += 1; - } - - compute_expert_forward( - input, expert_outputs[expert_idx], assignments, - context.getWeight(expert_gate_proj_indices[expert_idx]), - context.getWeight(expert_up_proj_indices[expert_idx]), - context.getWeight(expert_down_proj_indices[expert_idx]), hidden_size); -#ifdef DEBUG - t2_miss = high_resolution_clock::now(); -#endif - } else { - -#ifdef DEBUG - t1_hit = high_resolution_clock::now(); -#endif - { - std::lock_guard lock(cache_mutex); - hit_count += 1; - } - - compute_expert_forward( - input, expert_outputs[expert_idx], assignments, - context.getWeight(expert_gate_proj_indices[expert_idx]), - context.getWeight(expert_up_proj_indices[expert_idx]), - context.getWeight(expert_down_proj_indices[expert_idx]), hidden_size); - -#ifdef DEBUG - t2_hit = high_resolution_clock::now(); -#endif - } - } - - for (int i = extra_top_k.size() - 1; i >= 0; i--) { - if (iteration_map.find(extra_top_k[i]) != iteration_map.end()) { - loaded_expert_deque.erase(iteration_map[extra_top_k[i]]); - loaded_expert_deque.push_back(extra_top_k[i]); - iteration_map[extra_top_k[i]] = --loaded_expert_deque.end(); - } - } - -#ifdef DEBUG - auto t1_evict = high_resolution_clock::now(); -#endif - -// Evict experts -#pragma omp parallel - while (loaded_expert_deque.size() > 32) { - int target_idx; - { - std::lock_guard lock(cache_mutex); - target_idx = loaded_expert_deque.front(); - loaded_expert_deque.pop_front(); - iteration_map.erase(target_idx); - need_load[target_idx] = true; - } - - context.getWeight(expert_gate_proj_indices[target_idx]).deactivate(); - context.getWeight(expert_up_proj_indices[target_idx]).deactivate(); - context.getWeight(expert_down_proj_indices[target_idx]).deactivate(); - } - -#ifdef DEBUG - auto t2_evict = high_resolution_clock::now(); -#endif - - // Combine expert outputs - int init = 0; - for (int expert_idx : target_idx_vector) { - if (!init) { - output.copyData(expert_outputs[expert_idx]); - ++init; - } else { - output.add_i(expert_outputs[expert_idx]); - } - } - - // reshape output: [B*S,1,1,H] -> [B,1,S,H] - output.reshape({batch_size, 1, seq_len, hidden_size}); - -#ifdef DEBUG - auto t2 = high_resolution_clock::now(); - auto dt = duration_cast(t2 - t1); - auto dt_miss = duration_cast(t2_miss - t1_miss); - auto dt_hit = duration_cast(t2_hit - t1_hit); - auto dt_evict = duration_cast(t2_evict - t1_evict); - std::cout << context.getName() << " \t| " << dt.count() << " ns " - << "\t| " << dt.count() / 1'000 << " us " - << "\t| " << dt.count() / 1'000'000 << " ms " - << "\t| " - << "hit ratio: " << hit_count / 8.0 << "\t | " - << " miss ratio: " << miss_count / 8.0 << "\t | " - << "hit_compute: " << dt_hit.count() / 1'000'000 << " ms " - << "\t| " - << "miss_compute: " << dt_miss.count() / 1'000'000 << " ms " - << "\t| " - << "evict_time: " << dt_evict.count() / 1'000'000 << " ms " - << "\t| " << std::endl; -#endif - } -} - -void CachedSlimMoELayer::setProperty(const std::vector &values) { - auto remain_props = loadProperties(values, moe_props); - nntrainer::LayerImpl::setProperty(remain_props); -} - -void CachedSlimMoELayer::calcDerivative(nntrainer::RunLayerContext &context) { - // MoE layer does not support derivative calculation - throw std::runtime_error("MoE layer does not support derivative calculation"); -} - -void CachedSlimMoELayer::calcGradient(nntrainer::RunLayerContext &context) { - // MoE layer does not support gradient calculation - throw std::runtime_error("MoE layer does not support gradient calculation"); -} - -void CachedSlimMoELayer::exportTo( - nntrainer::Exporter &exporter, const ml::train::ExportMethods &method) const { - nntrainer::LayerImpl::exportTo(exporter, method); - exporter.saveResult(moe_props, method, this); // Save MoE specific properties -} - -void CachedSlimMoELayer::updateTensorsByInputDimensions( - nntrainer::RunLayerContext &context, - std::vector input_dimensions) { - ml::train::TensorDim input_dim = context.getInput(SINGLE_INOUT_IDX).getDim(); - ml::train::TensorDim output_dim = - context.getOutput(SINGLE_INOUT_IDX).getDim(); - - input_dim.height(input_dimensions[0].height()); - output_dim.height(input_dimensions[0].height()); - - context.updateInput(SINGLE_INOUT_IDX, input_dim); - context.updateOutput(SINGLE_INOUT_IDX, output_dim); -} - -} // namespace quick_dot_ai diff --git a/models/qwen3_cached_slim_moe/qwen_moe_layer_cached.h b/models/qwen3_cached_slim_moe/qwen_moe_layer_cached.h deleted file mode 100644 index a679c820..00000000 --- a/models/qwen3_cached_slim_moe/qwen_moe_layer_cached.h +++ /dev/null @@ -1,161 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -/** - * Copyright (C) 2025 Eunju Yang - * - * @file qwen_moe_layer_fsu.h - * @date 09 June 2025 - * @brief This is Mixture of Expert Layer Class of Neural Network - * @see https://github.com/nntrainer/nntrainer - * @author Eunju Yang - * @bug No known bugs except for NYI items - * @note This file is part of the Mixture of Expert Layer implementation. - * It does not support shared experts. - * This layer is implemented based on the LLama-MoE. - * For more information, please refer to the following link: - * https://arxiv.org/pdf/2406.16554 - * @todo This layer does not support backwarding yet. - */ - -#ifndef __MOE_LAYER_H__ -#define __MOE_LAYER_H__ -#ifdef __cplusplus - -#include -#include -#include -#include -#include - -namespace quick_dot_ai { - -/** - * @class SlimMoELayer - * @brief Mixture of Expert Layer - */ -class CachedSlimMoELayer : public nntrainer::LayerImpl { -public: - /** - * @brief Constructor of Mixture of Expert Layer - */ - CachedSlimMoELayer(); - - /** - * @brief Destructor of Mixture of Expert Layer - */ - ~CachedSlimMoELayer() = default; - - /** - * @brief Move constructor. - * @param[in] CachedSlimMoELayer && - */ - CachedSlimMoELayer(CachedSlimMoELayer &&rhs) = delete; - - /** - * @brief Move assignment operator. - * @param[in] rhs CachedSlimMoELayer to be moved. - */ - CachedSlimMoELayer &operator=(CachedSlimMoELayer &&rhs) = delete; - - /** - * @copydoc Layer::finalize(InitLayerContext &context) - */ - void finalize(nntrainer::InitLayerContext &context) override; - - /** - * @copydoc Layer::forwarding(RunLayerContext &context, bool training) - */ - void forwarding(nntrainer::RunLayerContext &context, bool training) override; - - /** - * @copydoc Layer::incremental_forwarding(RunLayerContext &context, unsigned) - */ - void incremental_forwarding(nntrainer::RunLayerContext &context, - unsigned int from, unsigned int to, - bool training) override; - - /** - * @copydoc Layer::calcDerivative(RunLayerContext &context) - */ - void calcDerivative(nntrainer::RunLayerContext &context) override; - - /** - * @copydoc Layer::calcGradient(RunLayerContext &context) - */ - void calcGradient(nntrainer::RunLayerContext &context) override; - - /** - * @copydoc Layer::setProperty(const std::vector &values) - */ - void setProperty(const std::vector &values) override; - - /** - * @copydoc Layer::exportTo(Exporter &exporter, const ml::train::ExportMethods - * &methods) - */ - void exportTo(nntrainer::Exporter &exporter, - const ml::train::ExportMethods &method) const override; - - /** - * @copydoc Layer::getType() - */ - const std::string getType() const override { - return CachedSlimMoELayer::type; - }; - - /** - * @brief Layer::supportBackwarding() - */ - bool supportBackwarding() const override { return false; } - - WIN_EXPORT void updateTensorsByInputDimensions( - nntrainer::RunLayerContext &context, - std::vector input_dimensions) override; - - static constexpr const char *type = - "moe_cached_slim"; /**< type of the layer */ - -private: - unsigned int num_experts; /**< number of experts */ - unsigned int topk; /**< number of experts per token, i.e., topk */ - nntrainer::ActiFunc acti_func; /**< activation function for the expert */ - std::tuple - moe_props; - - // weight indeices - std::vector expert_gate_proj_indices; - std::vector expert_up_proj_indices; - std::vector expert_down_proj_indices; - - std::list loaded_expert_deque; - std::unordered_map::iterator> iteration_map; - std::unordered_map expert_predict_scores; - std::vector need_load; - std::mutex cache_mutex; - - unsigned int gate_idx; - - // Intermediate tensor indices - unsigned int router_logits_idx; - unsigned int expert_mask_idx; - /** - * @brief expert forward computation without memory copies - * @param input Input tensor (reshaped to [total_tokens, 1, 1, hidden_size]) - * @param output Output tensor to accumulate results - * @param token_assignments Vector of (token_index, weight) pairs for this - * expert - * @param gate_proj Gate projection weight tensor - * @param up_proj Up projection weight tensor - * @param down_proj Down projection weight tensor - * @param hidden_size Hidden dimension size - */ - inline void compute_expert_forward( - const nntrainer::Tensor &input, nntrainer::Tensor &output, - const std::vector> &token_assignments, - const nntrainer::Tensor &gate_proj, const nntrainer::Tensor &up_proj, - const nntrainer::Tensor &down_proj, unsigned int hidden_size); -}; -} // namespace quick_dot_ai - -#endif /* __cplusplus */ -#endif /* __MOE_LAYER_H__ */ diff --git a/models/qwen3_moe/README.md b/models/qwen3_moe/README.md deleted file mode 100644 index c901df58..00000000 --- a/models/qwen3_moe/README.md +++ /dev/null @@ -1,7 +0,0 @@ -# Qwen3 MoE Model - -This directory contains the implementation for Qwen3 Mixture of Experts (MoE) model. - -## Files -- `qwen3_moe_causallm.cpp`: Qwen3 MoE implementation. -- `qwen_moe_layer.cpp`: MoE layer implementation used by this model. diff --git a/models/qwen3_moe/meson.build b/models/qwen3_moe/meson.build deleted file mode 100644 index 5d472328..00000000 --- a/models/qwen3_moe/meson.build +++ /dev/null @@ -1,26 +0,0 @@ -qwen3_moe_src = [ - meson.current_source_dir() / 'qwen3_moe_causallm.cpp', -] - -qwen3_moe_inc = include_directories('.') - -# Define Layers -causallm_moe_layer_src_abs = [meson.current_source_dir() / 'qwen_moe_layer.cpp'] - -causallm_moe_layer = shared_library( - 'qwen_moe_layer', - causallm_moe_layer_src_abs, - include_directories: [quick_dot_ai_layer_inc, qwen3_moe_inc], - dependencies: [nntrainer_dep, nntrainer_ccapi_dep, openmp_dep], - install: true, - install_dir: application_install_dir -) - -causallm_moe_layer_dep = declare_dependency( - link_with: causallm_moe_layer, - include_directories: qwen3_moe_inc -) - -quick_dot_ai_src += qwen3_moe_src -quick_dot_ai_inc += qwen3_moe_inc -quick_dot_ai_layer_dependencies += [causallm_moe_layer_dep] diff --git a/models/qwen3_moe/qwen3_moe_causallm.cpp b/models/qwen3_moe/qwen3_moe_causallm.cpp deleted file mode 100644 index 83944dd3..00000000 --- a/models/qwen3_moe/qwen3_moe_causallm.cpp +++ /dev/null @@ -1,79 +0,0 @@ -/** - * Copyright (C) 2025 Samsung Electronics Co., Ltd. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * http://www.apache.org/licenses/LICENSE-2.0 - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - * - * @file qwen3_moe_causallm.cpp - * @date 23 July 2025 - * @brief This defines a qwen3_moe causal language model. - * @see https://github.com/nnstreamer/ - * @author Eunju Yang - * @bug No known bugs except for NYI items - * - */ -#include -#include -#include - -#include -#include -#include - -namespace quick_dot_ai { - -void Qwen3MoECausalLM::setupParameters(json &cfg, json &generation_cfg, - json &nntr_cfg) { - Qwen3CausalLM(cfg, generation_cfg, nntr_cfg); - - // parameters for Qwen3MoE model - try { - NUM_EXPERTS = cfg["num_experts"]; - NUM_EXPERTS_PER_TOK = cfg["num_experts_per_tok"]; - INTERMEDIATE_SIZE = cfg["moe_intermediate_size"]; - } catch (const std::exception &e) { - throw std::runtime_error("Qwen3MoE: num_experts and num_experts_per_tok " - "are not specified in the config file"); - } -} - -std::vector Qwen3MoECausalLM::createMlp(const int layer_id, - int dim, int hidden_dim, - std::string input_name) { - - std::vector layers; - layers.push_back(createLayer( - "qwen_moe", - {withKey("name", "layer" + std::to_string(layer_id) + "_ffn_down"), - withKey("input_layers", input_name), withKey("unit", hidden_dim), - withKey("num_experts", NUM_EXPERTS), - withKey("num_experts_per_token", NUM_EXPERTS_PER_TOK), - withKey("moe_activation", "swish")})); - - return layers; -} - -void Qwen3MoECausalLM::registerCustomLayers() { - - Qwen3CausalLM::registerCustomLayers(); - auto &ct_engine = nntrainer::Engine::Global(); - auto app_context = - static_cast(ct_engine.getRegisteredContext("cpu")); - - try { - app_context->registerFactory(nntrainer::createLayer); - } catch (std::invalid_argument &e) { - std::cerr << "failed to register factory, reason: " << e.what() - << std::endl; - } -} - -} // namespace quick_dot_ai diff --git a/models/qwen3_moe/qwen3_moe_causallm.h b/models/qwen3_moe/qwen3_moe_causallm.h deleted file mode 100644 index 25efbcd1..00000000 --- a/models/qwen3_moe/qwen3_moe_causallm.h +++ /dev/null @@ -1,53 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -/** - * Copyright (C) 2025 Eunju Yang - * - * @file qwen3_moe_causallm.h - * @date 15 July 2025 - * @see https://github.com/nntrainer/nntrainer - * @author Eunju Yang - * @bug No known bugs except for NYI items - * - */ - -#ifndef __QWEN_MOE_CAUSAL_LM_H__ -#define __QWEN_MOE_CAUSAL_LM_H__ - -#include -#include - -namespace quick_dot_ai { - -/** - * @brief Qwen3MoECausalLM class - * @note This class inherits Qwewn3CaUSALlm - */ -class Qwen3MoECausalLM : public Qwen3CausalLM { - -public: - static constexpr const char *architectures = "Qwen3MoeForCausalLM"; - - Qwen3MoECausalLM(json &cfg, json &generation_cfg, json &nntr_cfg) : - Transformer(cfg, generation_cfg, nntr_cfg, ModelType::CAUSALLM), - Qwen3CausalLM(cfg, generation_cfg, nntr_cfg) { - setupParameters(cfg, generation_cfg, nntr_cfg); - } - - virtual ~Qwen3MoECausalLM() = default; - - std::vector createMlp(const int layer_id, int dim, - int hidden_dim, - std::string input_name) override; - - void setupParameters(json &cfg, json &generation_cfg, - json &nntr_cfg) override; - - void registerCustomLayers() override; - -private: - unsigned int NUM_EXPERTS; - unsigned int NUM_EXPERTS_PER_TOK; -}; -}; // namespace quick_dot_ai - -#endif /* __QWEN_MOE_CAUSAL_LM_H__ */ diff --git a/models/qwen3_moe/qwen_moe_layer.cpp b/models/qwen3_moe/qwen_moe_layer.cpp deleted file mode 100644 index 81e807cc..00000000 --- a/models/qwen3_moe/qwen_moe_layer.cpp +++ /dev/null @@ -1,526 +0,0 @@ - -/** - * Copyright (C) 2020 Samsung Electronics Co., Ltd. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * http://www.apache.org/licenses/LICENSE-2.0 - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - * - * @file moe_layer.cpp - * @date 09 June 2025 - * @brief This is a Mixture of Expert Layer Class for Neural Network - * @see https://github.com/nnstreamer/ - * @author Eunju Yang - * @bug No known bugs except for NYI items - * - */ - -#include -#include -#include -#include -#include -#include -#include - -namespace quick_dot_ai { - -static constexpr size_t SINGLE_INOUT_IDX = 0; - -MoELayer::MoELayer() : - LayerImpl(), - num_experts(0), - topk(0), - moe_props(props::NumExperts(), props::NumExpertsPerToken(), - nntrainer::props::Unit(), props::MoEActivation()), - expert_gate_proj_indices({}), - expert_up_proj_indices({}), - expert_down_proj_indices({}), - gate_idx(std::numeric_limits::max()), - router_logits_idx(std::numeric_limits::max()), - expert_mask_idx(std::numeric_limits::max()) {} - -void MoELayer::finalize(nntrainer::InitLayerContext &context) { - - // 1. Validate input/output dimensions - NNTR_THROW_IF(context.getNumInputs() != 1, std::invalid_argument) - << "MoE layer only supports single input"; - - auto &weight_regularizer = - std::get(*layer_impl_props); - auto &weight_regularizer_constant = - std::get(*layer_impl_props); - auto &weight_initializer = - std::get(*layer_impl_props); - auto &weight_decay = - std::get(*layer_impl_props); - - // 2. Set output dimensions (same as input) - const auto &in_dim = context.getInputDimensions()[SINGLE_INOUT_IDX]; - const bool is_nchw = context.getFormat() == nntrainer::Tformat::NCHW; - std::vector output_dims(1); - output_dims[SINGLE_INOUT_IDX] = in_dim; - context.setOutputDimensions(output_dims); - - // 3. Get MoE properties - num_experts = std::get(moe_props).get(); - topk = std::get(moe_props).get(); - const unsigned int intermediate_size = - std::get(moe_props).get(); - const unsigned int hidden_size = in_dim.width(); // Feature dimension - - // activation function - if (std::get(moe_props).empty()) { - throw std::runtime_error("Activation type is not set for MoE layer"); - } - switch (context.getActivationDataType()) { - case ml::train::TensorDim::DataType::FP32: - acti_func.setActiFunc( - std::get(moe_props).get()); - break; - default: - throw std::runtime_error("Unsupported activation data type for MoE layer"); - } - - // 4. Initialie gate layer (router) - nntrainer::TensorDim gate_dim( - 1, is_nchw ? 1 : num_experts, is_nchw ? hidden_size : 1, - is_nchw ? num_experts : hidden_size, - nntrainer::TensorDim::TensorType(context.getFormat(), - nntrainer::TensorDim::DataType::FP32), - is_nchw ? 0b0011 : 0b0101); - - gate_idx = context.requestWeight( - gate_dim, weight_initializer, weight_regularizer, - weight_regularizer_constant, weight_decay, "gate", true); - - // 5. Initializer expert weights - expert_gate_proj_indices.reserve(num_experts); - expert_up_proj_indices.reserve(num_experts); - expert_down_proj_indices.reserve(num_experts); - - nntrainer::TensorDim expert_gate_dim( - 1, is_nchw ? 1 : intermediate_size, is_nchw ? hidden_size : 1, - is_nchw ? intermediate_size : hidden_size, - nntrainer::TensorDim::TensorType(context.getFormat(), - context.getWeightDataType()), - is_nchw ? 0b0011 : 0b0101); - - nntrainer::TensorDim expert_down_dim( - 1, is_nchw ? 1 : hidden_size, is_nchw ? intermediate_size : 1, - is_nchw ? hidden_size : intermediate_size, - nntrainer::TensorDim::TensorType(context.getFormat(), - context.getWeightDataType()), - is_nchw ? 0b0011 : 0b0101); - - for (unsigned int i = 0; i < num_experts; ++i) { - // Up projection - expert_up_proj_indices.push_back(context.requestWeight( - expert_gate_dim, // Same dimensions as gate projection - weight_initializer, weight_regularizer, weight_regularizer_constant, - weight_decay, "expert_up_" + std::to_string(i), false)); - - // Gate projection - expert_gate_proj_indices.push_back(context.requestWeight( - expert_gate_dim, weight_initializer, weight_regularizer, - weight_regularizer_constant, weight_decay, - "expert_gate_" + std::to_string(i), false)); - - // Down projection - expert_down_proj_indices.push_back(context.requestWeight( - expert_down_dim, weight_initializer, weight_regularizer, - weight_regularizer_constant, weight_decay, - "expert_down_" + std::to_string(i), false)); - } - - // 6. Request intermediate tensors - const unsigned batch_size = in_dim.batch(); - const unsigned seq_len = in_dim.height(); - const unsigned total_tokens = batch_size * seq_len; - - // Router logits : [batch * seq, num_experts] - router_logits_idx = - context.requestTensor({total_tokens, 1, 1, num_experts}, "router_logits", - nntrainer::Initializer::NONE, false, - nntrainer::TensorLifespan::FORWARD_FUNC_LIFESPAN); - - // Expert mask: [num_experts, batch*seq] - expert_mask_idx = - context.requestTensor({num_experts, 1, topk, total_tokens}, "expert_mask", - nntrainer::Initializer::ZEROS, false, - nntrainer::TensorLifespan::FORWARD_FUNC_LIFESPAN); -} - -void MoELayer::forwarding(nntrainer::RunLayerContext &context, bool training) { - nntrainer::Tensor &input = context.getInput(SINGLE_INOUT_IDX); - nntrainer::Tensor &output = context.getOutput(SINGLE_INOUT_IDX); - - nntrainer::Tensor &router_logits = context.getTensor(router_logits_idx); - nntrainer::Tensor &expert_mask = context.getTensor(expert_mask_idx); - - const unsigned batch_size = input.batch(); - const unsigned seq_len = input.height(); - const unsigned hidden_size = input.width(); - const unsigned total_tokens = batch_size * seq_len; - - // reshape input: [B,1,S,H] -> [B*S,1,1,H] - input.reshape({total_tokens, 1, 1, hidden_size}); - - // reshape output: [B,1,S,H] -> [B*S,1,1,H] - output.reshape({total_tokens, 1, 1, hidden_size}); - output.setZero(); - - // routing - nntrainer::Tensor &gate_weights = context.getWeight(gate_idx); - input.dot(gate_weights, router_logits); - router_logits.apply(nntrainer::ActiFunc::softmax, router_logits); - auto topk_result = router_logits.topK(topk); - auto topk_values = std::get<0>(topk_result); - auto topk_indices = std::get<1>(topk_result); - - const uint32_t *indices_data = topk_indices.getData(); -#pragma omp parallel for collapse(2) - for (int i = 0; i < static_cast(total_tokens); ++i) { - for (int k = 0; k < static_cast(topk); ++k) { - expert_mask.setValue(indices_data[i * topk + k], 0, k, i, 1.0f); - } - } - - // Pre-compute expert token assignments for better cache locality - std::vector>> expert_assignments( - num_experts); - for (int i = 0; i < static_cast(total_tokens); ++i) { - for (int k = 0; k < static_cast(topk); ++k) { - unsigned expert_idx = indices_data[i * topk + k]; - float weight = topk_values.getValue(i, 0, 0, k); - expert_assignments[expert_idx].emplace_back(i, weight); - } - } - - // Adaptive optimization based on workload - const int active_experts = - std::count_if(expert_assignments.begin(), expert_assignments.end(), - [](const auto &assignments) { return !assignments.empty(); }); - - // Calculate total work (sum of token assignments across all experts) - int total_work = 0; - for (const auto &assignments : expert_assignments) { - total_work += assignments.size(); - } - - // Use parallel processing only when it's beneficial - const bool use_parallel = (total_work > 4) && (active_experts > 1); - - if (use_parallel) { - // Parallel processing for larger workloads -#pragma omp parallel - { -#pragma omp for schedule(dynamic) - for (int expert_idx = 0; expert_idx < static_cast(num_experts); - ++expert_idx) { - const auto &assignments = expert_assignments[expert_idx]; - if (assignments.empty()) - continue; - - // Use optimized expert forward computation without memory copies - compute_expert_forward( - input, output, assignments, - context.getWeight(expert_gate_proj_indices[expert_idx]), - context.getWeight(expert_up_proj_indices[expert_idx]), - context.getWeight(expert_down_proj_indices[expert_idx]), hidden_size); - } - } - } else { - // Sequential processing for smaller workloads - for (int expert_idx = 0; expert_idx < static_cast(num_experts); - ++expert_idx) { - const auto &assignments = expert_assignments[expert_idx]; - if (assignments.empty()) - continue; - - // Use optimized expert forward computation without memory copies - compute_expert_forward( - input, output, assignments, - context.getWeight(expert_gate_proj_indices[expert_idx]), - context.getWeight(expert_up_proj_indices[expert_idx]), - context.getWeight(expert_down_proj_indices[expert_idx]), hidden_size); - } - } - - // reshape output: [B*S,1,1,H] -> [B,1,S,H] - output.reshape({batch_size, 1, seq_len, hidden_size}); -} - -inline void MoELayer::compute_expert_forward( - const nntrainer::Tensor &input, nntrainer::Tensor &output, - const std::vector> &token_assignments, - const nntrainer::Tensor &gate_proj, const nntrainer::Tensor &up_proj, - const nntrainer::Tensor &down_proj, unsigned int hidden_size) { - - const unsigned intermediate_size = gate_proj.width(); - const unsigned num_tokens = token_assignments.size(); - - if (num_tokens == 0) - return; - - // Create tensor dimensions for single token processing - nntrainer::TensorDim token_input_dim({1, 1, 1, hidden_size}, - input.getTensorType()); - nntrainer::TensorDim intermediate_dim({1, 1, 1, intermediate_size}, - input.getTensorType()); - nntrainer::TensorDim token_output_dim({1, 1, 1, hidden_size}, - input.getTensorType()); - - // Create a temporary output tensor for this expert to avoid critical section - nntrainer::Tensor expert_output(output.batch(), output.channel(), - output.height(), output.width(), - output.getTensorType()); - expert_output.setZero(); - - // Process each token individually to avoid memory copies - for (size_t i = 0; i < num_tokens; ++i) { - const unsigned token_idx = token_assignments[i].first; - const float weight = token_assignments[i].second; - - // Create shared tensor for input token (no memory copy) - size_t token_offset = token_idx * hidden_size; - nntrainer::Tensor token_input = - input.getSharedDataTensor(token_input_dim, token_offset, true); - - // Create intermediate tensors for this token - nntrainer::Tensor gate_out(intermediate_dim); - nntrainer::Tensor acti_out(intermediate_dim); - nntrainer::Tensor up_out(intermediate_dim); - - // Gate projection using optimized dot operation - token_input.dot(gate_proj, gate_out); - - // Apply activation (silu) - acti_func.run_fn(gate_out, acti_out); - - // Up projection using optimized dot operation - token_input.dot(up_proj, up_out); - - // Element-wise multiply: silu(gate_out) * up_out - acti_out.multiply_i(up_out); - - // Down projection using optimized dot operation - nntrainer::Tensor token_expert_output(token_output_dim); - acti_out.dot(down_proj, token_expert_output); - - // Apply weight and accumulate to expert's temporary output - token_expert_output.multiply_i(weight); - size_t output_offset = token_idx * hidden_size; - nntrainer::Tensor token_output = - expert_output.getSharedDataTensor(token_output_dim, output_offset, true); - - token_output.add_i(token_expert_output); - } - - // Add expert's result to final output (no critical section in sequential - // mode) - output.add_i(expert_output); -} - -inline void MoELayer::compute_expert_forward_no_critical( - const nntrainer::Tensor &input, nntrainer::Tensor &expert_output, - const std::vector> &token_assignments, - const nntrainer::Tensor &gate_proj, const nntrainer::Tensor &up_proj, - const nntrainer::Tensor &down_proj, unsigned int hidden_size) { - - const unsigned intermediate_size = gate_proj.width(); - const unsigned num_tokens = token_assignments.size(); - - if (num_tokens == 0) - return; - - // Create tensor dimensions for single token processing - nntrainer::TensorDim token_input_dim({1, 1, 1, hidden_size}, - input.getTensorType()); - nntrainer::TensorDim intermediate_dim({1, 1, 1, intermediate_size}, - input.getTensorType()); - nntrainer::TensorDim token_output_dim({1, 1, 1, hidden_size}, - input.getTensorType()); - - // Process each token individually to avoid memory copies - for (size_t i = 0; i < num_tokens; ++i) { - const unsigned token_idx = token_assignments[i].first; - const float weight = token_assignments[i].second; - - // Create shared tensor for input token (no memory copy) - size_t token_offset = token_idx * hidden_size; - nntrainer::Tensor token_input = - input.getSharedDataTensor(token_input_dim, token_offset, true); - - // Create intermediate tensors for this token - nntrainer::Tensor gate_out(intermediate_dim); - nntrainer::Tensor acti_out(intermediate_dim); - nntrainer::Tensor up_out(intermediate_dim); - - // Gate projection using optimized dot operation - token_input.dot(gate_proj, gate_out); - - // Apply activation (silu) - acti_func.run_fn(gate_out, acti_out); - - // Up projection using optimized dot operation - token_input.dot(up_proj, up_out); - - // Element-wise multiply: silu(gate_out) * up_out - acti_out.multiply_i(up_out); - - // Down projection using optimized dot operation - nntrainer::Tensor token_expert_output(token_output_dim); - acti_out.dot(down_proj, token_expert_output); - - // Apply weight and accumulate to expert's output (no critical section - // needed) - token_expert_output.multiply_i(weight); - size_t output_offset = token_idx * hidden_size; - nntrainer::Tensor token_output = - expert_output.getSharedDataTensor(token_output_dim, output_offset, true); - - token_output.add_i(token_expert_output); - } -} - -void MoELayer::incremental_forwarding(nntrainer::RunLayerContext &context, - unsigned int from, unsigned int to, - bool training) { - - nntrainer::Tensor &input_ = context.getInput(SINGLE_INOUT_IDX); - nntrainer::Tensor &output_ = context.getOutput(SINGLE_INOUT_IDX); - - nntrainer::Tensor &router_logits_ = context.getTensor(router_logits_idx); - nntrainer::Tensor &expert_mask = context.getTensor(expert_mask_idx); - - nntrainer::TensorDim input_step_dim = input_.getDim(); - nntrainer::TensorDim output_step_dim = output_.getDim(); - nntrainer::TensorDim router_logits_step_dim = router_logits_.getDim(); - - input_step_dim.batch(1); - output_step_dim.batch(1); - router_logits_step_dim.batch(to - from); - - input_step_dim.height(to - from); - output_step_dim.height(to - from); - - for (unsigned int b = 0; b < input_.batch(); ++b) { - - auto input = input_.getSharedDataTensor( - input_step_dim, b * input_step_dim.getFeatureLen(), true); - auto output = output_.getSharedDataTensor( - output_step_dim, b * output_step_dim.getFeatureLen(), true); - auto router_logits = - router_logits_.getSharedDataTensor(router_logits_step_dim, 0, true); - - const unsigned batch_size = input.batch(); - const unsigned seq_len = input.height(); - const unsigned hidden_size = input.width(); - const unsigned total_tokens = batch_size * seq_len; - - // reshape input: [B,1,S,H] -> [B*S,1,1,H] - input.reshape({total_tokens, 1, 1, hidden_size}); - - // reshape output: [B,1,S,H] -> [B*S,1,1,H] - output.reshape({total_tokens, 1, 1, hidden_size}); - output.setZero(); - expert_mask.setZero(); - - // routing - nntrainer::Tensor &gate_weights = context.getWeight(gate_idx); - input.dot(gate_weights, router_logits); - router_logits.apply(nntrainer::ActiFunc::softmax, router_logits); - auto topk_result = router_logits.topK(topk); - auto topk_values = std::get<0>(topk_result); - auto topk_indices = std::get<1>(topk_result); - - // norm_topk_prob - topk_values.divide_i(topk_values.sum(3)); - - const uint32_t *indices_data = topk_indices.getData(); - // Set expert mask - for (int i = 0; i < static_cast(total_tokens); ++i) { - for (int k = 0; k < static_cast(topk); ++k) { - expert_mask.setValue(indices_data[i * topk + k], 0, k, i, 1.0f); - } - } - - // Pre-compute expert token assignments for better performance - std::vector>> expert_assignments( - num_experts); - for (int i = 0; i < static_cast(total_tokens); ++i) { - for (int k = 0; k < static_cast(topk); ++k) { - unsigned expert_idx = indices_data[i * topk + k]; - float weight = topk_values.getValue(i, 0, 0, k); - expert_assignments[expert_idx].emplace_back(i, weight); - } - } - - // Parallel processing for multiple tokens with many active experts - std::vector expert_outputs(num_experts); - for (int expert_idx = 0; expert_idx < static_cast(num_experts); - ++expert_idx) { - if (!expert_assignments[expert_idx].empty()) { - expert_outputs[expert_idx] = nntrainer::Tensor( - total_tokens, 1, 1, hidden_size, output.getTensorType()); - expert_outputs[expert_idx].setZero(); - } - } - -#pragma omp parallel for schedule(dynamic) - for (int expert_idx = 0; expert_idx < static_cast(num_experts); - ++expert_idx) { - const auto &assignments = expert_assignments[expert_idx]; - if (assignments.empty()) - continue; - - compute_expert_forward_no_critical( - input, expert_outputs[expert_idx], assignments, - context.getWeight(expert_gate_proj_indices[expert_idx]), - context.getWeight(expert_up_proj_indices[expert_idx]), - context.getWeight(expert_down_proj_indices[expert_idx]), hidden_size); - } - - // Combine expert outputs - for (int expert_idx = 0; expert_idx < static_cast(num_experts); - ++expert_idx) { - if (!expert_assignments[expert_idx].empty()) { - output.add_i(expert_outputs[expert_idx]); - } - } - - // reshape output: [B*S,1,1,H] -> [B,1,S,H] - output.reshape({batch_size, 1, seq_len, hidden_size}); - } -} - -void MoELayer::setProperty(const std::vector &values) { - auto remain_props = loadProperties(values, moe_props); - nntrainer::LayerImpl::setProperty(remain_props); -} - -void MoELayer::calcDerivative(nntrainer::RunLayerContext &context) { - // MoE layer does not support derivative calculation - throw std::runtime_error("MoE layer does not support derivative calculation"); -} - -void MoELayer::calcGradient(nntrainer::RunLayerContext &context) { - // MoE layer does not support gradient calculation - throw std::runtime_error("MoE layer does not support gradient calculation"); -} - -void MoELayer::exportTo(nntrainer::Exporter &exporter, - const ml::train::ExportMethods &method) const { - nntrainer::LayerImpl::exportTo(exporter, method); - exporter.saveResult(moe_props, method, this); // Save MoE specific properties -} - -} // namespace quick_dot_ai diff --git a/models/qwen3_moe/qwen_moe_layer.h b/models/qwen3_moe/qwen_moe_layer.h deleted file mode 100644 index 0d5ee2bd..00000000 --- a/models/qwen3_moe/qwen_moe_layer.h +++ /dev/null @@ -1,164 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -/** - * Copyright (C) 2025 Eunju Yang - * - * @file moe_layer.h - * @date 09 June 2025 - * @brief This is Mixture of Expert Layer Class of Neural Network - * @see https://github.com/nntrainer/nntrainer - * @author Eunju Yang - * @bug No known bugs except for NYI items - * @note This file is part of the Mixture of Expert Layer implementation. - * It does not support shared experts. - * This layer is implemented based on the LLama-MoE. - * For more information, please refer to the following link: - * https://arxiv.org/pdf/2406.16554 - * @todo This layer does not support backwarding yet. - */ - -#ifndef __MOE_LAYER_H__ -#define __MOE_LAYER_H__ -#ifdef __cplusplus - -#include -#include -#include -#include - -namespace quick_dot_ai { - -/** - * @class MoELayer - * @brief Mixture of Expert Layer - */ -class MoELayer : public nntrainer::LayerImpl { -public: - /** - * @brief Constructor of Mixture of Expert Layer - */ - MoELayer(); - - /** - * @brief Destructor of Mixture of Expert Layer - */ - ~MoELayer() = default; - - /** - * @brief Move constructor. - * @param[in] MoELayer && - */ - MoELayer(MoELayer &&rhs) noexcept = default; - - /** - * @brief Move assignment operator. - * @param[in] rhs MoELayer to be moved. - */ - MoELayer &operator=(MoELayer &&rhs) = default; - - /** - * @copydoc Layer::finalize(InitLayerContext &context) - */ - void finalize(nntrainer::InitLayerContext &context) override; - - /** - * @copydoc Layer::forwarding(RunLayerContext &context, bool training) - */ - void forwarding(nntrainer::RunLayerContext &context, bool training) override; - - /** - * @copydoc Layer::incremental_forwarding(RunLayerContext &context, unsigned) - */ - void incremental_forwarding(nntrainer::RunLayerContext &context, - unsigned int from, unsigned int to, - bool training) override; - - /** - * @copydoc Layer::calcDerivative(RunLayerContext &context) - */ - void calcDerivative(nntrainer::RunLayerContext &context) override; - - /** - * @copydoc Layer::calcGradient(RunLayerContext &context) - */ - void calcGradient(nntrainer::RunLayerContext &context) override; - - /** - * @copydoc Layer::setProperty(const std::vector &values) - */ - void setProperty(const std::vector &values) override; - - /** - * @copydoc Layer::exportTo(Exporter &exporter, const ml::train::ExportMethods - * &methods) - */ - void exportTo(nntrainer::Exporter &exporter, - const ml::train::ExportMethods &method) const override; - - /** - * @copydoc Layer::getType() - */ - const std::string getType() const override { return MoELayer::type; }; - - /** - * @brief Layer::supportBackwarding() - */ - bool supportBackwarding() const override { return false; } - - static constexpr const char *type = "qwen_moe"; /**< type of the layer */ - -private: - unsigned int num_experts; /**< number of experts */ - unsigned int topk; /**< number of experts per token, i.e., topk */ - nntrainer::ActiFunc acti_func; /**< activation function for the expert */ - std::tuple - moe_props; - - // weight indeices - std::vector expert_gate_proj_indices; - std::vector expert_up_proj_indices; - std::vector expert_down_proj_indices; - unsigned int gate_idx; - - // Intermediate tensor indices - unsigned int router_logits_idx; - unsigned int expert_mask_idx; - - /** - * @brief expert forward computation without memory copies - * @param input Input tensor (reshaped to [total_tokens, 1, 1, hidden_size]) - * @param output Output tensor to accumulate results - * @param token_assignments Vector of (token_index, weight) pairs for this - * expert - * @param gate_proj Gate projection weight tensor - * @param up_proj Up projection weight tensor - * @param down_proj Down projection weight tensor - * @param hidden_size Hidden dimension size - */ - inline void compute_expert_forward( - const nntrainer::Tensor &input, nntrainer::Tensor &output, - const std::vector> &token_assignments, - const nntrainer::Tensor &gate_proj, const nntrainer::Tensor &up_proj, - const nntrainer::Tensor &down_proj, unsigned int hidden_size); - - /** - * @brief expert forward computation without critical section - * @param input Input tensor (reshaped to [total_tokens, 1, 1, hidden_size]) - * @param expert_output Expert-specific output tensor - * @param token_assignments Vector of (token_index, weight) pairs for this - * expert - * @param gate_proj Gate projection weight tensor - * @param up_proj Up projection weight tensor - * @param down_proj Down projection weight tensor - * @param hidden_size Hidden dimension size - */ - inline void compute_expert_forward_no_critical( - const nntrainer::Tensor &input, nntrainer::Tensor &expert_output, - const std::vector> &token_assignments, - const nntrainer::Tensor &gate_proj, const nntrainer::Tensor &up_proj, - const nntrainer::Tensor &down_proj, unsigned int hidden_size); -}; -} // namespace quick_dot_ai - -#endif /* __cplusplus */ -#endif /* __MOE_LAYER_H__ */ diff --git a/models/qwen3_slim_moe/README.md b/models/qwen3_slim_moe/README.md deleted file mode 100644 index 686cbaed..00000000 --- a/models/qwen3_slim_moe/README.md +++ /dev/null @@ -1,13 +0,0 @@ -# Qwen3 Slim MoE Model - -This directory contains the implementation for Qwen3 Slim MoE model. -## About the Slim Model -The **Slim** model is designed to minimize peak memory usage by loading experts in an on-the-fly manner. - -- **Efficient Initialization**: Instead of loading all model weights at once, the model initializes without heavy expert layers. -- **Dynamic Loading**: Only the activated experts are loaded into memory during runtime. -- **Performance Note**: Since the model dynamically maps memory to experts on storage, inference speed relies heavily on storage read I/O performance. - -## Files -- `qwen3_slim_moe_causallm.cpp`: Slim MoE implementation. -- `qwen_moe_layer_fsu.cpp`: FSU optimized layer for Slim MoE. diff --git a/models/qwen3_slim_moe/meson.build b/models/qwen3_slim_moe/meson.build deleted file mode 100644 index a0d1b91a..00000000 --- a/models/qwen3_slim_moe/meson.build +++ /dev/null @@ -1,25 +0,0 @@ -qwen3_slim_moe_src = [ - meson.current_source_dir() / 'qwen3_slim_moe_causallm.cpp', -] - -qwen3_slim_moe_inc = include_directories('.') - -causallm_slim_moe_layer_src_abs = [meson.current_source_dir() / 'qwen_moe_layer_fsu.cpp'] - -causallm_slim_moe_layer = shared_library( - 'qwen_moe_layer_fsu', - causallm_slim_moe_layer_src_abs, - include_directories: [quick_dot_ai_layer_inc, qwen3_slim_moe_inc], - dependencies: [nntrainer_dep, nntrainer_ccapi_dep, openmp_dep], - install: true, - install_dir: application_install_dir -) - -causallm_slim_moe_layer_dep = declare_dependency( - link_with: causallm_slim_moe_layer, - include_directories: qwen3_slim_moe_inc -) - -quick_dot_ai_src += qwen3_slim_moe_src -quick_dot_ai_inc += qwen3_slim_moe_inc -quick_dot_ai_layer_dependencies += [causallm_slim_moe_layer_dep] diff --git a/models/qwen3_slim_moe/qwen3_slim_moe_causallm.cpp b/models/qwen3_slim_moe/qwen3_slim_moe_causallm.cpp deleted file mode 100644 index 5d0b45c7..00000000 --- a/models/qwen3_slim_moe/qwen3_slim_moe_causallm.cpp +++ /dev/null @@ -1,80 +0,0 @@ -/** - * Copyright (C) 2025 Samsung Electronics Co., Ltd. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * http://www.apache.org/licenses/LICENSE-2.0 - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - * - * @file qwen3_moe_causallm.cpp - * @date 23 July 2025 - * @brief This defines a qwen3_moe causal language model. - * @see https://github.com/nnstreamer/ - * @author Eunju Yang - * @bug No known bugs except for NYI items - * - */ -#include -#include -#include - -#include -#include -#include - -namespace quick_dot_ai { - -void Qwen3SlimMoECausalLM::setupParameters(json &cfg, json &generation_cfg, - json &nntr_cfg) { - Qwen3CausalLM(cfg, generation_cfg, nntr_cfg); - - // parameters for Qwen3MoE model - try { - NUM_EXPERTS = cfg["num_experts"]; - NUM_EXPERTS_PER_TOK = cfg["num_experts_per_tok"]; - INTERMEDIATE_SIZE = cfg["moe_intermediate_size"]; - } catch (const std::exception &e) { - throw std::runtime_error("Qwen3MoE: num_experts and num_experts_per_tok " - "are not specified in the config file"); - } -} - -std::vector -Qwen3SlimMoECausalLM::createMlp(const int layer_id, int dim, int hidden_dim, - std::string input_name) { - - std::vector layers; - layers.push_back(createLayer( - "moe_slim", - {withKey("name", "layer" + std::to_string(layer_id) + "_ffn_down"), - withKey("input_layers", input_name), withKey("unit", hidden_dim), - withKey("num_experts", NUM_EXPERTS), - withKey("num_experts_per_token", NUM_EXPERTS_PER_TOK), - withKey("moe_activation", "swish")})); - - return layers; -} - -void Qwen3SlimMoECausalLM::registerCustomLayers() { - - Qwen3CausalLM::registerCustomLayers(); - auto &ct_engine = nntrainer::Engine::Global(); - auto app_context = - static_cast(ct_engine.getRegisteredContext("cpu")); - - try { - app_context->registerFactory( - nntrainer::createLayer); - } catch (std::invalid_argument &e) { - std::cerr << "failed to register factory, reason: " << e.what() - << std::endl; - } -} - -} // namespace quick_dot_ai diff --git a/models/qwen3_slim_moe/qwen3_slim_moe_causallm.h b/models/qwen3_slim_moe/qwen3_slim_moe_causallm.h deleted file mode 100644 index e7d4d23f..00000000 --- a/models/qwen3_slim_moe/qwen3_slim_moe_causallm.h +++ /dev/null @@ -1,53 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -/** - * Copyright (C) 2025 Eunju Yang - * - * @file qwen3_moe_causallm.h - * @date 15 July 2025 - * @see https://github.com/nntrainer/nntrainer - * @author Eunju Yang - * @bug No known bugs except for NYI items - * - */ - -#ifndef __QWEN_SLIM_MOE_CAUSAL_LM_H__ -#define __QWEN_SLIM_MOE_CAUSAL_LM_H__ - -#include -#include - -namespace quick_dot_ai { - -/** - * @brief Qwen3SlimMoECausalLM class - * @note This class inherits Qwewn3CaUSALlm - */ -class Qwen3SlimMoECausalLM : public Qwen3CausalLM { - -public: - static constexpr const char *architectures = "Qwen3SlimMoeForCausalLM"; - - Qwen3SlimMoECausalLM(json &cfg, json &generation_cfg, json &nntr_cfg) : - Transformer(cfg, generation_cfg, nntr_cfg, ModelType::CAUSALLM), - Qwen3CausalLM(cfg, generation_cfg, nntr_cfg) { - setupParameters(cfg, generation_cfg, nntr_cfg); - } - - virtual ~Qwen3SlimMoECausalLM() = default; - - std::vector createMlp(const int layer_id, int dim, - int hidden_dim, - std::string input_name) override; - - void setupParameters(json &cfg, json &generation_cfg, - json &nntr_cfg) override; - - void registerCustomLayers() override; - -private: - unsigned int NUM_EXPERTS; - unsigned int NUM_EXPERTS_PER_TOK; -}; -}; // namespace quick_dot_ai - -#endif /* __QWEN_MOE_CAUSAL_LM_H__ */ diff --git a/models/qwen3_slim_moe/qwen_moe_layer_fsu.cpp b/models/qwen3_slim_moe/qwen_moe_layer_fsu.cpp deleted file mode 100644 index 6c062122..00000000 --- a/models/qwen3_slim_moe/qwen_moe_layer_fsu.cpp +++ /dev/null @@ -1,517 +0,0 @@ -/** - * Copyright (C) 2020 Samsung Electronics Co., Ltd. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * http://www.apache.org/licenses/LICENSE-2.0 - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - * - * @file qwen_moe_layer_fsu.cpp - * @date 09 June 2025 - * @brief This is a Mixture of Expert Layer Class for Neural Network - * @see https://github.com/nnstreamer/ - * @author Eunju Yang - * @bug No known bugs except for NYI items - * @note MoE layer with on-the-fly expert FSU - * - */ - -#include -#include -#include -#include -#include -#include -#include - -namespace quick_dot_ai { - -static constexpr size_t SINGLE_INOUT_IDX = 0; - -SlimMoELayer::SlimMoELayer() : - LayerImpl(), - num_experts(0), - topk(0), - moe_props(props::NumExperts(), props::NumExpertsPerToken(), - nntrainer::props::Unit(), props::MoEActivation()), - expert_gate_proj_indices({}), - expert_up_proj_indices({}), - expert_down_proj_indices({}), - gate_idx(std::numeric_limits::max()), - router_logits_idx(std::numeric_limits::max()), - expert_mask_idx(std::numeric_limits::max()) {} - -void SlimMoELayer::finalize(nntrainer::InitLayerContext &context) { - - // 1. Validate input/output dimensions - NNTR_THROW_IF(context.getNumInputs() != 1, std::invalid_argument) - << "MoE layer only supports single input"; - - auto &weight_regularizer = - std::get(*layer_impl_props); - auto &weight_regularizer_constant = - std::get(*layer_impl_props); - auto &weight_initializer = - std::get(*layer_impl_props); - auto &weight_decay = - std::get(*layer_impl_props); - - // 2. Set output dimensions (same as input) - const auto &in_dim = context.getInputDimensions()[SINGLE_INOUT_IDX]; - const bool is_nchw = context.getFormat() == nntrainer::Tformat::NCHW; - std::vector output_dims(1); - output_dims[SINGLE_INOUT_IDX] = in_dim; - context.setOutputDimensions(output_dims); - - // 3. Get MoE properties - num_experts = std::get(moe_props).get(); - topk = std::get(moe_props).get(); - const unsigned int intermediate_size = - std::get(moe_props).get(); - const unsigned int hidden_size = in_dim.width(); // Feature dimension - - // activation function - if (std::get(moe_props).empty()) { - throw std::runtime_error("Activation type is not set for MoE layer"); - } - switch (context.getActivationDataType()) { - case ml::train::TensorDim::DataType::FP32: - acti_func.setActiFunc( - std::get(moe_props).get()); - break; - default: - throw std::runtime_error("Unsupported activation data type for MoE layer"); - } - - // 4. Initialie gate layer (router) - nntrainer::TensorDim gate_dim( - 1, is_nchw ? 1 : num_experts, is_nchw ? hidden_size : 1, - is_nchw ? num_experts : hidden_size, - nntrainer::TensorDim::TensorType(context.getFormat(), - nntrainer::TensorDim::DataType::FP32), - is_nchw ? 0b0011 : 0b0101); - - gate_idx = context.requestWeight( - gate_dim, weight_initializer, weight_regularizer, - weight_regularizer_constant, weight_decay, "gate", true); - - // 5. Initializer expert weights - expert_gate_proj_indices.reserve(num_experts); - expert_up_proj_indices.reserve(num_experts); - expert_down_proj_indices.reserve(num_experts); - - nntrainer::TensorDim expert_gate_dim( - 1, is_nchw ? 1 : intermediate_size, is_nchw ? hidden_size : 1, - is_nchw ? intermediate_size : hidden_size, - nntrainer::TensorDim::TensorType(context.getFormat(), - context.getWeightDataType()), - is_nchw ? 0b0011 : 0b0101); - - nntrainer::TensorDim expert_down_dim( - 1, is_nchw ? 1 : hidden_size, is_nchw ? intermediate_size : 1, - is_nchw ? hidden_size : intermediate_size, - nntrainer::TensorDim::TensorType(context.getFormat(), - context.getWeightDataType()), - is_nchw ? 0b0011 : 0b0101); - - for (unsigned int i = 0; i < num_experts; ++i) { - // Up projection - expert_up_proj_indices.push_back(context.requestWeight( - expert_gate_dim, // Same dimensions as gate projection - weight_initializer, weight_regularizer, weight_regularizer_constant, - weight_decay, "expert_up_" + std::to_string(i), false, true)); - - // Gate projection - expert_gate_proj_indices.push_back(context.requestWeight( - expert_gate_dim, weight_initializer, weight_regularizer, - weight_regularizer_constant, weight_decay, - "expert_gate_" + std::to_string(i), false, true)); - - // Down projection - expert_down_proj_indices.push_back(context.requestWeight( - expert_down_dim, weight_initializer, weight_regularizer, - weight_regularizer_constant, weight_decay, - "expert_down_" + std::to_string(i), false, true)); - } - - // 6. Request intermediate tensors - const unsigned batch_size = in_dim.batch(); - const unsigned seq_len = in_dim.height(); - const unsigned total_tokens = batch_size * seq_len; - - // Router logits : [batch * seq, num_experts] - router_logits_idx = - context.requestTensor({total_tokens, 1, 1, num_experts}, "router_logits", - nntrainer::Initializer::NONE, false, - nntrainer::TensorLifespan::FORWARD_FUNC_LIFESPAN); - - // Expert mask: [num_experts, batch*seq] - expert_mask_idx = - context.requestTensor({num_experts, 1, topk, total_tokens}, "expert_mask", - nntrainer::Initializer::ZEROS, false, - nntrainer::TensorLifespan::FORWARD_FUNC_LIFESPAN); -} - -void SlimMoELayer::forwarding(nntrainer::RunLayerContext &context, - bool training) { - nntrainer::Tensor &input = context.getInput(SINGLE_INOUT_IDX); - nntrainer::Tensor &output = context.getOutput(SINGLE_INOUT_IDX); - - nntrainer::Tensor &router_logits = context.getTensor(router_logits_idx); - nntrainer::Tensor &expert_mask = context.getTensor(expert_mask_idx); - - const unsigned batch_size = input.batch(); - const unsigned seq_len = input.height(); - const unsigned hidden_size = input.width(); - const unsigned total_tokens = batch_size * seq_len; - - // reshape input: [B,1,S,H] -> [B*S,1,1,H] - input.reshape({total_tokens, 1, 1, hidden_size}); - - // reshape output: [B,1,S,H] -> [B*S,1,1,H] - output.reshape({total_tokens, 1, 1, hidden_size}); - output.setZero(); - - // routing - nntrainer::Tensor &gate_weights = context.getWeight(gate_idx); - input.dot(gate_weights, router_logits); - router_logits.apply(nntrainer::ActiFunc::softmax, router_logits); - auto topk_result = router_logits.topK(topk); - auto topk_values = std::get<0>(topk_result); - auto topk_indices = std::get<1>(topk_result); - - const uint32_t *indices_data = topk_indices.getData(); -#pragma omp parallel for collapse(2) - for (int i = 0; i < static_cast(total_tokens); ++i) { - for (int k = 0; k < static_cast(topk); ++k) { - expert_mask.setValue(indices_data[i * topk + k], 0, k, i, 1.0f); - } - } - - // Pre-compute expert token assignments for better cache locality - std::vector>> expert_assignments( - num_experts); - for (int i = 0; i < static_cast(total_tokens); ++i) { - for (int k = 0; k < static_cast(topk); ++k) { - unsigned expert_idx = indices_data[i * topk + k]; - float weight = topk_values.getValue(i, 0, 0, k); - expert_assignments[expert_idx].emplace_back(i, weight); - } - } - - for (int expert_idx = 0; expert_idx < static_cast(num_experts); - ++expert_idx) { - const auto &assignments = expert_assignments[expert_idx]; - if (assignments.empty()) - continue; - - ///@note load expert layer for the expert_idx - nntrainer::Tensor expert_gate_proj = - context.getWeight(expert_gate_proj_indices[expert_idx]); - nntrainer::Tensor expert_up_proj = - context.getWeight(expert_up_proj_indices[expert_idx]); - nntrainer::Tensor expert_down_proj = - context.getWeight(expert_down_proj_indices[expert_idx]); - - ///@note Please note that expert_gate_proj is virtual tensor, - /// which is not allocated so far. It will be allocated when it is - /// used. `activate(read=true)` will allocate its memory and will read - /// from the original weight. activate is true by default. i.e., mmap - expert_gate_proj.activate(); - expert_up_proj.activate(); - expert_down_proj.activate(); - - // Use optimized expert forward computation without memory copies - compute_expert_forward(input, output, assignments, expert_gate_proj, - expert_up_proj, expert_down_proj, hidden_size); - - ////@note Please note that the virtual tensor is deactivated after usage - //// This will allocate and load data from the storage on-the-fly - //// i.e., unmap - expert_gate_proj.deactivate(); - expert_up_proj.deactivate(); - expert_down_proj.deactivate(); - } - - // reshape output: [B*S,1,1,H] -> [B,1,S,H] - output.reshape({batch_size, 1, seq_len, hidden_size}); -} - -inline void SlimMoELayer::compute_expert_forward( - const nntrainer::Tensor &input, nntrainer::Tensor &output, - const std::vector> &token_assignments, - const nntrainer::Tensor &gate_proj, const nntrainer::Tensor &up_proj, - const nntrainer::Tensor &down_proj, unsigned int hidden_size) { - - const unsigned intermediate_size = gate_proj.width(); - const unsigned num_tokens = token_assignments.size(); - - if (num_tokens == 0) - return; - - // Create tensor dimensions for single token processing - nntrainer::TensorDim token_input_dim({1, 1, 1, hidden_size}, - input.getTensorType()); - nntrainer::TensorDim intermediate_dim({1, 1, 1, intermediate_size}, - input.getTensorType()); - nntrainer::TensorDim token_output_dim({1, 1, 1, hidden_size}, - input.getTensorType()); - - // Create a temporary output tensor for this expert to avoid critical section - nntrainer::Tensor expert_output(output.batch(), output.channel(), - output.height(), output.width(), - output.getTensorType()); - expert_output.setZero(); - - // Process each token individually to avoid memory copies - for (size_t i = 0; i < num_tokens; ++i) { - const unsigned token_idx = token_assignments[i].first; - const float weight = token_assignments[i].second; - - // Create shared tensor for input token (no memory copy) - size_t token_offset = token_idx * hidden_size; - nntrainer::Tensor token_input = - input.getSharedDataTensor(token_input_dim, token_offset, true); - - // Create intermediate tensors for this token - nntrainer::Tensor gate_out(intermediate_dim); - nntrainer::Tensor acti_out(intermediate_dim); - nntrainer::Tensor up_out(intermediate_dim); - - // Gate projection using optimized dot operation - token_input.dot(gate_proj, gate_out); - - // Apply activation (silu) - acti_func.run_fn(gate_out, acti_out); - - // Up projection using optimized dot operation - token_input.dot(up_proj, up_out); - - // Element-wise multiply: silu(gate_out) * up_out - acti_out.multiply_i(up_out); - - // Down projection using optimized dot operation - nntrainer::Tensor token_expert_output(token_output_dim); - acti_out.dot(down_proj, token_expert_output); - - // Apply weight and accumulate to expert's temporary output - token_expert_output.multiply_i(weight); - size_t output_offset = token_idx * hidden_size; - nntrainer::Tensor token_output = - expert_output.getSharedDataTensor(token_output_dim, output_offset, true); - - token_output.add_i(token_expert_output); - } - - // Add expert's result to final output (no critical section in sequential - // mode) - output.add_i(expert_output); -} - -inline void SlimMoELayer::compute_expert_forward_no_critical( - const nntrainer::Tensor &input, nntrainer::Tensor &expert_output, - const std::vector> &token_assignments, - const nntrainer::Tensor &gate_proj, const nntrainer::Tensor &up_proj, - const nntrainer::Tensor &down_proj, unsigned int hidden_size) { - - const unsigned intermediate_size = gate_proj.width(); - const unsigned num_tokens = token_assignments.size(); - - if (num_tokens == 0) - return; - - // Create tensor dimensions for single token processing - nntrainer::TensorDim token_input_dim({1, 1, 1, hidden_size}, - input.getTensorType()); - nntrainer::TensorDim intermediate_dim({1, 1, 1, intermediate_size}, - input.getTensorType()); - nntrainer::TensorDim token_output_dim({1, 1, 1, hidden_size}, - input.getTensorType()); - - // Process each token individually to avoid memory copies - for (size_t i = 0; i < num_tokens; ++i) { - const unsigned token_idx = token_assignments[i].first; - const float weight = token_assignments[i].second; - - // Create shared tensor for input token (no memory copy) - size_t token_offset = token_idx * hidden_size; - nntrainer::Tensor token_input = - input.getSharedDataTensor(token_input_dim, token_offset, true); - - // Create intermediate tensors for this token - nntrainer::Tensor gate_out(intermediate_dim); - nntrainer::Tensor acti_out(intermediate_dim); - nntrainer::Tensor up_out(intermediate_dim); - - // Gate projection using optimized dot operation - token_input.dot(gate_proj, gate_out); - - // Apply activation (silu) - acti_func.run_fn(gate_out, acti_out); - - // Up projection using optimized dot operation - token_input.dot(up_proj, up_out); - - // Element-wise multiply: silu(gate_out) * up_out - acti_out.multiply_i(up_out); - - // Down projection using optimized dot operation - nntrainer::Tensor token_expert_output(token_output_dim); - acti_out.dot(down_proj, token_expert_output); - - // Apply weight and accumulate to expert's output (no critical section - // needed) - token_expert_output.multiply_i(weight); - size_t output_offset = token_idx * hidden_size; - nntrainer::Tensor token_output = - expert_output.getSharedDataTensor(token_output_dim, output_offset, true); - - token_output.add_i(token_expert_output); - } -} - -void SlimMoELayer::incremental_forwarding(nntrainer::RunLayerContext &context, - unsigned int from, unsigned int to, - bool training) { - - nntrainer::Tensor &input_ = context.getInput(SINGLE_INOUT_IDX); - nntrainer::Tensor &output_ = context.getOutput(SINGLE_INOUT_IDX); - - nntrainer::Tensor &router_logits_ = context.getTensor(router_logits_idx); - - nntrainer::TensorDim input_step_dim = input_.getDim(); - nntrainer::TensorDim output_step_dim = output_.getDim(); - nntrainer::TensorDim router_logits_step_dim = router_logits_.getDim(); - - input_step_dim.batch(1); - output_step_dim.batch(1); - router_logits_step_dim.batch(to - from); - - input_step_dim.height(to - from); - output_step_dim.height(to - from); - - for (unsigned int b = 0; b < input_.batch(); ++b) { - - auto input = input_.getSharedDataTensor( - input_step_dim, b * input_step_dim.getFeatureLen(), true); - auto output = output_.getSharedDataTensor( - output_step_dim, b * output_step_dim.getFeatureLen(), true); - auto router_logits = - router_logits_.getSharedDataTensor(router_logits_step_dim, 0, true); - - const unsigned batch_size = input.batch(); - const unsigned seq_len = input.height(); - const unsigned hidden_size = input.width(); - const unsigned total_tokens = batch_size * seq_len; - - // reshape input: [B,1,S,H] -> [B*S,1,1,H] - input.reshape({total_tokens, 1, 1, hidden_size}); - - // reshape output: [B,1,S,H] -> [B*S,1,1,H] - output.reshape({total_tokens, 1, 1, hidden_size}); - output.setZero(); - - // routing - nntrainer::Tensor &gate_weights = context.getWeight(gate_idx); - input.dot(gate_weights, router_logits); - router_logits.apply(nntrainer::ActiFunc::softmax, router_logits); - auto topk_result = router_logits.topK(topk); - auto topk_values = std::get<0>(topk_result); - auto topk_indices = std::get<1>(topk_result); - - // norm_topk_prob - topk_values.divide_i(topk_values.sum(3)); - - const uint32_t *indices_data = topk_indices.getData(); - std::vector>> expert_assignments( - num_experts); - // Set expert mask - for (int i = 0; i < static_cast(total_tokens); ++i) { - for (int k = 0; k < static_cast(topk); ++k) { - unsigned expert_idx = indices_data[i * topk + k]; - float weight = topk_values.getValue(i, 0, 0, k); - expert_assignments[expert_idx].emplace_back(i, weight); - } - } - - // Parallel processing for multiple tokens with many active experts - std::vector expert_outputs(num_experts); - for (int expert_idx = 0; expert_idx < static_cast(num_experts); - ++expert_idx) { - if (!expert_assignments[expert_idx].empty()) { - expert_outputs[expert_idx] = nntrainer::Tensor( - total_tokens, 1, 1, hidden_size, output.getTensorType()); - } - } - -#pragma omp parallel for schedule(dynamic) - for (int expert_idx = 0; expert_idx < static_cast(num_experts); - ++expert_idx) { - const auto &assignments = expert_assignments[expert_idx]; - if (assignments.empty()) - continue; - - ///@note Please note that expert_gate_proj is virtual tensor, - /// which is not allocated so far. It will be allocated when it is - /// used. `activate(read=true)` will allocate its memory and will - /// read from the original weight. activate is true by default. i.e., - /// mmap - context.getWeight(expert_gate_proj_indices[expert_idx]).activate(); - context.getWeight(expert_up_proj_indices[expert_idx]).activate(); - context.getWeight(expert_down_proj_indices[expert_idx]).activate(); - - compute_expert_forward_no_critical( - input, expert_outputs[expert_idx], assignments, - context.getWeight(expert_gate_proj_indices[expert_idx]), - context.getWeight(expert_up_proj_indices[expert_idx]), - context.getWeight(expert_down_proj_indices[expert_idx]), hidden_size); - - ////@note Please note that the virtual tensor is deactivated after usage - //// This will allocate and load data from the storage on-the-fly - //// i.e., unmap - context.getWeight(expert_gate_proj_indices[expert_idx]).deactivate(); - context.getWeight(expert_up_proj_indices[expert_idx]).deactivate(); - context.getWeight(expert_down_proj_indices[expert_idx]).deactivate(); - } - - // Combine expert outputs - for (int expert_idx = 0; expert_idx < static_cast(num_experts); - ++expert_idx) { - if (!expert_assignments[expert_idx].empty()) { - output.add_i(expert_outputs[expert_idx]); - } - } - - // reshape output: [B*S,1,1,H] -> [B,1,S,H] - output.reshape({batch_size, 1, seq_len, hidden_size}); - } -} - -void SlimMoELayer::setProperty(const std::vector &values) { - auto remain_props = loadProperties(values, moe_props); - nntrainer::LayerImpl::setProperty(remain_props); -} - -void SlimMoELayer::calcDerivative(nntrainer::RunLayerContext &context) { - // MoE layer does not support derivative calculation - throw std::runtime_error("MoE layer does not support derivative calculation"); -} - -void SlimMoELayer::calcGradient(nntrainer::RunLayerContext &context) { - // MoE layer does not support gradient calculation - throw std::runtime_error("MoE layer does not support gradient calculation"); -} - -void SlimMoELayer::exportTo(nntrainer::Exporter &exporter, - const ml::train::ExportMethods &method) const { - nntrainer::LayerImpl::exportTo(exporter, method); - exporter.saveResult(moe_props, method, this); // Save MoE specific properties -} - -} // namespace quick_dot_ai diff --git a/models/qwen3_slim_moe/qwen_moe_layer_fsu.h b/models/qwen3_slim_moe/qwen_moe_layer_fsu.h deleted file mode 100644 index 718c6b6d..00000000 --- a/models/qwen3_slim_moe/qwen_moe_layer_fsu.h +++ /dev/null @@ -1,163 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -/** - * Copyright (C) 2025 Eunju Yang - * - * @file qwen_moe_layer_fsu.h - * @date 09 June 2025 - * @brief This is Mixture of Expert Layer Class of Neural Network - * @see https://github.com/nntrainer/nntrainer - * @author Eunju Yang - * @bug No known bugs except for NYI items - * @note This file is part of the Mixture of Expert Layer implementation. - * It does not support shared experts. - * This layer is implemented based on the LLama-MoE. - * For more information, please refer to the following link: - * https://arxiv.org/pdf/2406.16554 - * @todo This layer does not support backwarding yet. - */ - -#ifndef __MOE_LAYER_H__ -#define __MOE_LAYER_H__ -#ifdef __cplusplus - -#include -#include -#include -#include - -namespace quick_dot_ai { - -/** - * @class SlimMoELayer - * @brief Mixture of Expert Layer - */ -class SlimMoELayer : public nntrainer::LayerImpl { -public: - /** - * @brief Constructor of Mixture of Expert Layer - */ - SlimMoELayer(); - - /** - * @brief Destructor of Mixture of Expert Layer - */ - ~SlimMoELayer() = default; - - /** - * @brief Move constructor. - * @param[in] SlimMoELayer && - */ - SlimMoELayer(SlimMoELayer &&rhs) noexcept = default; - - /** - * @brief Move assignment operator. - * @param[in] rhs SlimMoELayer to be moved. - */ - SlimMoELayer &operator=(SlimMoELayer &&rhs) = default; - - /** - * @copydoc Layer::finalize(InitLayerContext &context) - */ - void finalize(nntrainer::InitLayerContext &context) override; - - /** - * @copydoc Layer::forwarding(RunLayerContext &context, bool training) - */ - void forwarding(nntrainer::RunLayerContext &context, bool training) override; - - /** - * @copydoc Layer::incremental_forwarding(RunLayerContext &context, unsigned) - */ - void incremental_forwarding(nntrainer::RunLayerContext &context, - unsigned int from, unsigned int to, - bool training) override; - - /** - * @copydoc Layer::calcDerivative(RunLayerContext &context) - */ - void calcDerivative(nntrainer::RunLayerContext &context) override; - - /** - * @copydoc Layer::calcGradient(RunLayerContext &context) - */ - void calcGradient(nntrainer::RunLayerContext &context) override; - - /** - * @copydoc Layer::setProperty(const std::vector &values) - */ - void setProperty(const std::vector &values) override; - - /** - * @copydoc Layer::exportTo(Exporter &exporter, const ml::train::ExportMethods - * &methods) - */ - void exportTo(nntrainer::Exporter &exporter, - const ml::train::ExportMethods &method) const override; - - /** - * @copydoc Layer::getType() - */ - const std::string getType() const override { return SlimMoELayer::type; }; - - /** - * @brief Layer::supportBackwarding() - */ - bool supportBackwarding() const override { return false; } - - static constexpr const char *type = "moe_slim"; /**< type of the layer */ - -private: - unsigned int num_experts; /**< number of experts */ - unsigned int topk; /**< number of experts per token, i.e., topk */ - nntrainer::ActiFunc acti_func; /**< activation function for the expert */ - std::tuple - moe_props; - - // weight indeices - std::vector expert_gate_proj_indices; - std::vector expert_up_proj_indices; - std::vector expert_down_proj_indices; - unsigned int gate_idx; - - // Intermediate tensor indices - unsigned int router_logits_idx; - unsigned int expert_mask_idx; - /** - * @brief expert forward computation without memory copies - * @param input Input tensor (reshaped to [total_tokens, 1, 1, hidden_size]) - * @param output Output tensor to accumulate results - * @param token_assignments Vector of (token_index, weight) pairs for this - * expert - * @param gate_proj Gate projection weight tensor - * @param up_proj Up projection weight tensor - * @param down_proj Down projection weight tensor - * @param hidden_size Hidden dimension size - */ - inline void compute_expert_forward( - const nntrainer::Tensor &input, nntrainer::Tensor &output, - const std::vector> &token_assignments, - const nntrainer::Tensor &gate_proj, const nntrainer::Tensor &up_proj, - const nntrainer::Tensor &down_proj, unsigned int hidden_size); - - /** - * @brief expert forward computation without critical section - * @param input Input tensor (reshaped to [total_tokens, 1, 1, hidden_size]) - * @param expert_output Expert-specific output tensor - * @param token_assignments Vector of (token_index, weight) pairs for this - * expert - * @param gate_proj Gate projection weight tensor - * @param up_proj Up projection weight tensor - * @param down_proj Down projection weight tensor - * @param hidden_size Hidden dimension size - */ - inline void compute_expert_forward_no_critical( - const nntrainer::Tensor &input, nntrainer::Tensor &expert_output, - const std::vector> &token_assignments, - const nntrainer::Tensor &gate_proj, const nntrainer::Tensor &up_proj, - const nntrainer::Tensor &down_proj, unsigned int hidden_size); -}; -} // namespace quick_dot_ai - -#endif /* __cplusplus */ -#endif /* __MOE_LAYER_H__ */ diff --git a/models/sentence_transformer.cpp b/models/sentence_transformer.cpp deleted file mode 100644 index 270692ec..00000000 --- a/models/sentence_transformer.cpp +++ /dev/null @@ -1,305 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -/** - * Copyright (C) 2025 Eunju Yang - * - * @file sentence_transformer.cpp - * @date 02 Jan 2026 - * @see https://github.com/nntrainer/nntrainer - * @author Eunju Yang - * @bug No known bugs except for NYI items - * @brief This file defines SentenceTransformer's basic actions - */ - -#include -#include -#include -#include -#include - -#include -#include - -namespace quick_dot_ai { - -SentenceTransformer::SentenceTransformer(json &cfg, json &generation_cfg, - json &nntr_cfg) : - Transformer(cfg, generation_cfg, nntr_cfg, ModelType::EMBEDDING) { - setupParameters(cfg, generation_cfg, nntr_cfg); -} - -std::map SentenceTransformer::layer_map = { - {"Pooling", "embedding_pooling"}, - {"Normalize", "embedding_normalize"}, - {"Dense", "fully_connected"}}; - -void SentenceTransformer::setupParameters(json &cfg, json &generation_cfg, - json &nntr_cfg) { - Transformer::setupParameters(cfg, generation_cfg, nntr_cfg); - - std::string modules_config_path = "modules.json"; - if (nntr_cfg.contains("module_config_path")) { - modules_config_path = nntr_cfg["module_config_path"].get(); - } else { - std::cout << "module_config_path is not set. Using default: " - << modules_config_path << std::endl; - } - - // Get the directory containing modules.json to resolve relative paths - std::filesystem::path modules_json_path(modules_config_path); - std::filesystem::path base_dir = modules_json_path.parent_path(); - - try { - // 1. Load modules.json to get the structure and order of layers - json modules_json = LoadJsonFile(modules_config_path); - modules = modules_json.get>(); - - for (auto &module : modules) { - if (module.contains("path")) { - std::string module_path_str = module["path"].get(); - if (module_path_str.empty()) { - // For the first module (Transformer), the path might be empty or "."" - // We generally skip it or handle it if it points to a separate - // config. - continue; - } - - // 2. Resolve config.json path for each module - std::filesystem::path module_dir = base_dir / module_path_str; - - if (std::filesystem::exists(module_dir) && - std::filesystem::is_directory(module_dir)) { - std::filesystem::path config_path = module_dir / "config.json"; - if (std::filesystem::exists(config_path)) { - try { - // 3. Load config.json and store it in module_configs map using - // idx as key - json module_config = LoadJsonFile(config_path.string()); - if (module.contains("idx")) { - int idx = module["idx"].get(); - module_configs[idx] = module_config; - } else { - std::cerr << "Warning: Module does not have idx field" - << std::endl; - } - } catch (const std::exception &e) { - std::cerr << "Failed to load config for module: " - << module_path_str << " Reason: " << e.what() - << std::endl; - } - } else { - // It's possible some modules don't have a config.json - } - } - } - } - } catch (const std::exception &e) { - std::cerr << "Failed to load modules config from: " << modules_config_path - << " Reason: " << e.what() << std::endl; - } -} - -void SentenceTransformer::constructModel() { - for (auto &module : modules) { - if (!module.contains("type")) { - continue; - } - std::string type = module["type"].get(); - std::string component = getLastComponent(type); - - if (component == "Transformer") { - Transformer::constructModel(); - } else { - if (module.contains("idx")) { - int idx = module["idx"].get(); - // Add module layer using properties from loaded config - addModule(type, idx); - } else { - std::cerr << "Warning: Module does not have idx field, skipping: " - << type << std::endl; - } - } - } -} - -void SentenceTransformer::addModule(const std::string &type, int idx) { - json config; - if (module_configs.find(idx) != module_configs.end()) { - config = module_configs[idx]; - } else { - // Config might be empty if no config.json was found. - // This is valid for layers that don't satisfy specific configurations - // (e.g., default behavior) - } - - // Determine the layer type component (e.g., "Pooling" from - // "sentence_transformers.models.Pooling") - std::string component = getLastComponent(type); - std::string layer_name; - auto it = layer_map.find(component); - if (it != layer_map.end()) { - layer_name = it->second; - } - - if (layer_name.empty()) { - std::cerr << "Warning: No layer mapping found for module type: " << type - << " (component: " << component << "). Skipping." << std::endl; - return; - } - - // Convert JSON config to nntrainer property format (key=value strings) - std::vector props; - for (auto &el : config.items()) { - std::string val_str; - if (el.value().is_string()) - val_str = el.value().get(); - else - val_str = el.value().dump(); // convert to string - - if (el.key() == "out_features") { - props.push_back("unit=" + val_str); - } else if (el.key() == "bias") { - if (val_str == "false") { - props.push_back("disable_bias=true"); - } - } else if (el.key() == "activation_function") { - if (val_str.find("Identity") == std::string::npos) { - props.push_back("activation=" + val_str); - } else { - // need to support other activations later on - } - } else if (el.key() == "in_features") { - // Ignore in_features as nntrainer infers it - } else { - props.push_back(el.key() + "=" + val_str); - } - } - - LayerHandle layer = ml::train::createLayer(layer_name, props); - model->addLayer(layer); -} - -void SentenceTransformer::run(const WSTR prompt, void *output_buf, - bool log_output) { - run(prompt, "", "", output_buf, log_output); -} - -void SentenceTransformer::run(const WSTR prompt, const WSTR system_prompt, - const WSTR tail_prompt, void *output_buf, - bool log_output) { - - try { - std::vector results = encode(prompt, system_prompt, tail_prompt); - - if (log_output) { - - std::cout << "Embedding Result (" << BATCH_SIZE - << " batch(es)):" << std::endl; - for (unsigned int b = 0; b < BATCH_SIZE; ++b) { - std::cout << "Batch " << b << ": ["; - // Print first few elements as sample - int print_dim = (DIM > 10) ? 10 : DIM; - for (int i = 0; i < print_dim; ++i) { - std::cout << results[0][b * DIM + i] - << (i == print_dim - 1 ? "" : ", "); - } - if (DIM > 10) - std::cout << ", ..."; - std::cout << "] (Total DIM: " << DIM << ")" << std::endl; - } - } - - if (output_buf != nullptr) { - // Caller is responsible for dellocation - *static_cast *>(output_buf) = results; - } else { - // output should be deallocated after use. - for (auto out : results) { - delete[] out; - } - } - } catch (const std::exception &e) { - std::cerr << "Error during embedding run: " << e.what() << std::endl; - } -} - -std::vector SentenceTransformer::encode(const WSTR prompt, - const WSTR system_prompt, - const WSTR tail_prompt) { - if (!is_initialized) { - throw std::runtime_error( - "SentenceTransformer model is not initialized. Please call " - "initialize() before encode()."); - } - -#if defined(_WIN32) - std::wstring prompt_ = system_prompt + prompt + tail_prompt; - std::wstring_convert> converter; - auto _input = tokenizer->Encode(converter.to_bytes(prompt_), true); -#else - std::string prompt_ = system_prompt + prompt + tail_prompt; - auto _input = tokenizer->Encode(prompt_, true); -#endif - - std::vector init_input; - unsigned int input_len = - std::min((unsigned int)_input.size(), (unsigned int)MAX_SEQ_LEN); - - // feed only available length - for (unsigned int i = 0; i < input_len; ++i) - init_input.push_back(_input[i]); - - float *input_sample = - (float *)malloc(sizeof(float) * BATCH_SIZE * MAX_SEQ_LEN); - - for (unsigned int b = 0; b < BATCH_SIZE; ++b) { - for (unsigned int i = 0; i < input_len; ++i) { - input_sample[static_cast(b) * MAX_SEQ_LEN + i] = - static_cast(init_input[i]); - } - } - - std::vector input; - input.push_back(input_sample); - - std::vector label; // Empty label for inference - - // Run incremental inference for the prefill stage - // start: 0, end: input_len (process all tokens at once) - // This performs a single forward pass for the entire prompt sequence to get - // embeddings. - std::vector output = model->incremental_inference( - BATCH_SIZE, input, label, input_len, 0, input_len, false); - - free(input_sample); - - return output; -} - -std::string SentenceTransformer::getLastComponent(const std::string &type) { - std::string last_component = type; - size_t last_dot_pos = type.find_last_of('.'); - if (last_dot_pos != std::string::npos) { - last_component = type.substr(last_dot_pos + 1); - } - return last_component; -} - -void SentenceTransformer::registerCustomLayers() { - Transformer::registerCustomLayers(); - - const auto &ct_engine = nntrainer::Engine::Global(); - const auto app_context = - static_cast(ct_engine.getRegisteredContext("cpu")); - - try { - app_context->registerFactory( - nntrainer::createLayer); - app_context->registerFactory( - nntrainer::createLayer); - } catch (std::invalid_argument &e) { - std::cerr << "failed to register factory, reason: " << e.what() - << std::endl; - } -} - -} // namespace quick_dot_ai diff --git a/models/sentence_transformer.h b/models/sentence_transformer.h deleted file mode 100644 index 194ca5ff..00000000 --- a/models/sentence_transformer.h +++ /dev/null @@ -1,122 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -/** - * Copyright (C) 2025 Eunju Yang - * - * @file sentence_transformer.h - * @date 02 Jan 2026 - * @see https://github.com/nntrainer/nntrainer - * @author Eunju Yang - * @bug No known bugs except for NYI items - * @note This embedding.h constructs a class for SentenceTransformer model - * which can be a parent of models with embedding (encoder) structure. - */ - -#ifndef __SENTENCE_TRANSFORMER_H__ -#define __SENTENCE_TRANSFORMER_H__ - -#pragma once - -#include -#include - -namespace quick_dot_ai { - -/** - * @brief SentenceTransformer Class - */ -WIN_EXPORT class SentenceTransformer : virtual public Transformer { - -public: - /** - * @brief Construct a new SentenceTransformer object - * @param cfg Configuration for the model (config.json) - * @param generation_cfg Configuration for the generation (generation.json) - * @param nntr_cfg Configuration for nntrainer (nntr_config.json) - */ - SentenceTransformer(json &cfg, json &generation_cfg, json &nntr_cfg); - - /** - * @brief Destroy the SentenceTransformer object - */ - virtual ~SentenceTransformer() {} - - /** - * @brief run the SentenceTransformer model (simple) - */ - void run(const WSTR prompt, void *output_buf = nullptr, - bool log_output = true) override; - - /** - * @brief run the SentenceTransformer model (full) - */ - void run(const WSTR prompt, const WSTR system_prompt = "", - const WSTR tail_prmopt = "", void *output_buf = nullptr, - bool log_output = true) override; - - /** - * @brief Encode the prompt and return the embedding - * @param prompt User prompt - * @param system_prompt System prompt - * @param tail_prompt Tail prompt - * @return SentenceTransformer output from the model - */ - std::vector encode(const WSTR prompt, const WSTR system_prompt = "", - const WSTR tail_prompt = ""); - -protected: - /** - * @brief Setup the parameters for the SentenceTransformer model - */ - void setupParameters(json &cfg, json &generation_dfg, - json &nntr_cfg) override; - - /** - * @brief Construct Model - */ - void constructModel() override; - - /** - * @brief Map of module type suffix to layer type name - * @note This map is used to dynamically resolve the nntrainer layer type from - * the module configuration type suffix. - * Key: Suffix of the module type (e.g., "Pooling") - * Value: Registered layer name in nntrainer (e.g., "embedding_pooling") - * @note All layers in this map correspond to operations defined in - * sentence_transformers/models/ and are prefixed with "embedding_" in - * nntrainer to distinguish with the general layers. - */ - static std::map layer_map; - - /** - * @brief Add Module Layer - * @param config Configuration for the layer - */ - void addModule(const std::string &type, int idx); - - /** - * @brief register CustomLayers - */ - void registerCustomLayers() override; - -private: - /** - * @brief Module metadata list (from modules.json) - */ - std::vector modules; - - /** - * @brief Module property configurations (from Module_name/config.json) - */ - std::map module_configs; - - /** - * @brief Get the last component of the module type string - * @param type Full type string (e.g., "sentence_transformers.models.Pooling") - * @return Last component (e.g., "Pooling") - */ - std::string getLastComponent(const std::string &type); -}; - -} // namespace quick_dot_ai - -#endif diff --git a/models/transformer.cpp b/models/transformer.cpp deleted file mode 100644 index 8984d402..00000000 --- a/models/transformer.cpp +++ /dev/null @@ -1,458 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -/** - * Copyright (C) 2025 Eunju Yang - * - * @file transformer.cpp - * @date 10 July 2025 - * @see https://github.com/nntrainer/nntrainer - * @author Eunju Yang - * @bug No known bugs except for NYI items - * @brief This file defines Transformer's basic actions - */ - -#include - -#include -#include -#include - -#include -#include -#include - -#include -#include -#include -#include -#include - -namespace quick_dot_ai { - -std::string LoadBytesFromFile(const std::string &path) { - std::ifstream file(path, std::ios::binary | std::ios::ate); - if (!file.is_open()) { - throw std::runtime_error("Failed to open file: " + path); - } - std::streamsize size = file.tellg(); - file.seekg(0, std::ios::beg); - - std::string buffer(size, ' '); - if (!file.read(&buffer[0], size)) { - throw std::runtime_error("Failed to read file: " + path); - } - return buffer; -} - -ModelType strToModelType(std::string model_type) { - - std::string model_type_lower = model_type; - std::transform(model_type_lower.begin(), model_type_lower.end(), - model_type_lower.begin(), - [](unsigned char c) { return std::tolower(c); }); - - static const std::unordered_map model_type_map = { - {"model", ModelType::MODEL}, - {"causallm", ModelType::CAUSALLM}, - {"embedding", ModelType::EMBEDDING}}; - - if (model_type_map.find(model_type_lower) == model_type_map.end()) { - return ModelType::UNKNOWN; - } - - return model_type_map.at(model_type_lower); -} - -Transformer::Transformer(json &cfg, json &generation_cfg, json &nntr_cfg, - ModelType model_type) { - - std::string config_model_type_str = "Model"; - if (nntr_cfg.contains("model_type")) { - config_model_type_str = nntr_cfg["model_type"].get(); - } - - ModelType config_model_type = strToModelType(config_model_type_str); - - if (model_type != config_model_type) { - throw std::runtime_error("model_type mismatch. Class Type: " + - std::to_string(static_cast(model_type)) + - ", Config Type: " + config_model_type_str); - } - - // Initialize the model with the provided configurations - // This is where you would set up the model layers, parameters, etc. - setupParameters(cfg, generation_cfg, nntr_cfg); - - // prep tokenizer - tokenizer = tokenizers::Tokenizer::FromBlobJSON( - LoadBytesFromFile(nntr_cfg["tokenizer_file"])); -}; - -void Transformer::setupParameters(json &cfg, json &generation_cfg, - json &nntr_cfg) { - - /** Initialize nntr prameters */ - BATCH_SIZE = nntr_cfg["batch_size"].get(); - MODEL_TENSOR_TYPE = nntr_cfg["model_tensor_type"].get(); - INIT_SEQ_LEN = nntr_cfg["init_seq_len"]; - MAX_SEQ_LEN = nntr_cfg["max_seq_len"]; - NUM_TO_GENERATE = nntr_cfg["num_to_generate"]; - MODEL_TENSOR_TYPE = nntr_cfg["model_tensor_type"]; - MEMORY_SWAP = nntr_cfg.contains("fsu") ? nntr_cfg["fsu"].get() : false; - FSU_LOOKAHEAD = nntr_cfg.contains("fsu_lookahead") - ? nntr_cfg["fsu_lookahead"].get() - : 1; - EMBEDDING_DTYPE = nntr_cfg["embedding_dtype"]; - FC_LAYER_DTYPE = nntr_cfg["fc_layer_dtype"]; - - if (cfg.contains("is_causal")) { - IS_CAUSAL = cfg["is_causal"].get(); - } else if (cfg.contains("use_bidirectional_attention")) { - IS_CAUSAL = !cfg["use_bidirectional_attention"].get(); - } - - NUM_VOCAB = cfg["vocab_size"]; - DIM = cfg["hidden_size"]; - INTERMEDIATE_SIZE = cfg["intermediate_size"]; - NUM_LAYERS = cfg["num_hidden_layers"]; - NUM_HEADS = cfg["num_attention_heads"]; - HEAD_DIM = cfg.contains("head_dim") - ? cfg["head_dim"].get() - : DIM / NUM_HEADS; // default value is hidden_size / num_heads - NUM_KEY_VALUE_HEADS = cfg.contains("num_key_value_heads") - ? cfg["num_key_value_heads"].get() - : NUM_HEADS; - SLIDING_WINDOW = - cfg.contains("sliding_window") && !cfg["sliding_window"].is_null() - ? cfg["sliding_window"].get() - : UINT_MAX; - SLIDING_WINDOW_PATTERN = cfg.contains("sliding_window_pattern") - ? cfg["sliding_window_pattern"].get() - : 1; - MAX_POSITION_EMBEDDINGS = cfg["max_position_embeddings"].get(); - ROPE_THETA = cfg["rope_theta"].get(); - TIE_WORD_EMBEDDINGS = cfg["tie_word_embeddings"].get(); - NORM_EPS = cfg["rms_norm_eps"]; - GQA_SIZE = NUM_HEADS / NUM_KEY_VALUE_HEADS; - - return; -}; - -void Transformer::initialize() { - - // RegisterCustomLayers - registerCustomLayers(); - - // construct causalLM model - constructModel(); - - // setup model property - std::vector model_props = { - withKey("batch_size", BATCH_SIZE), withKey("epochs", "1"), - withKey("model_tensor_type", MODEL_TENSOR_TYPE)}; - if (MEMORY_SWAP) { - model_props.emplace_back(withKey("fsu", "true")); - model_props.emplace_back(withKey("fsu_lookahead", FSU_LOOKAHEAD)); - } - - model->setProperty(model_props); - - if (model->compile(ml::train::ExecutionMode::INFERENCE)) { - throw std::invalid_argument("Model compilation failed."); - } - - if (model->initialize(ml::train::ExecutionMode::INFERENCE)) { - throw std::invalid_argument("Model initialization failed."); - } - - is_initialized = true; - -#ifdef DEBUG - model->summarize(std::cout, ML_TRAIN_SUMMARY_MODEL); -#endif -} - -void Transformer::constructModel() { - - // layers used in the model - std::vector layers; - - // create model - model = ml::train::createModel(ml::train::ModelType::NEURAL_NET); - - // create input layer - layers.push_back(createLayer( - "input", {withKey("name", "input0"), - withKey("input_shape", "1:1:" + std::to_string(INIT_SEQ_LEN))})); - - // create embedding layer - const std::string embedding_type = - TIE_WORD_EMBEDDINGS ? "tie_word_embeddings" : "embedding_layer"; - - layers.push_back(createLayer( - embedding_type, - {"name=embedding0", "in_dim=" + std::to_string(NUM_VOCAB), - "weight_dtype=" + EMBEDDING_DTYPE, "out_dim=" + std::to_string(DIM), - "scale=" + std::to_string(EMBEDDING_SCALE)})); - - // create transformer layers - for (int i = 0; i < NUM_LAYERS; ++i) { - std::vector transformer; - if (i == 0) - transformer = createTransformerDecoderBlock(0, "embedding0"); - else - transformer = createTransformerDecoderBlock( - i, "layer" + std::to_string(i - 1) + "_decoder_output"); - layers.insert(layers.end(), transformer.begin(), transformer.end()); - } - - // create rms_norm - layers.push_back(createLayer( - "rms_norm", - {withKey("name", "output_norm"), - withKey("epsilon", std::to_string(NORM_EPS)), - withKey("input_layers", - "layer" + std::to_string(NUM_LAYERS - 1) + "_decoder_output"), - withKey("packed", "false")})); - - // add created layers into the model - for (auto &layer : layers) { - model->addLayer(layer); - } -}; - -void Transformer::load_weight(const std::string &weight_path) { - - if (!is_initialized) { - throw std::runtime_error( - "Transformer model is not initialized. Please call " - "initialize() before load_weight()."); - } - - try { - model->load(weight_path, ml::train::ModelFormat::MODEL_FORMAT_BIN); - } catch (const std::exception &e) { - throw std::runtime_error("Failed to load model weights: " + - std::string(e.what())); - } -}; - -void Transformer::save_weight(const std::string &weight_path) { - - if (!is_initialized) { - throw std::runtime_error( - "Transformer model is not initialized. Please call " - "initialize() before save_weight()."); - } - - try { - model->save(weight_path, ml::train::ModelFormat::MODEL_FORMAT_BIN); - } catch (const std::exception &e) { - throw std::runtime_error("Failed to save model weights: " + - std::string(e.what())); - } -}; - -void Transformer::save_weight( - const std::string &weight_path, ml::train::TensorDim::DataType dtype, - const std::map - &layer_dtype_map) { - - if (!is_initialized) { - throw std::runtime_error( - "Transformer model is not initialized. Please call " - "initialize() before save_weight()."); - } - - try { - model->save(weight_path, ml::train::ModelFormat::MODEL_FORMAT_BIN, dtype, - layer_dtype_map); - } catch (const std::exception &e) { - throw std::runtime_error("Failed to save model weights with dtype: " + - std::string(e.what())); - } -}; - -void Transformer::run(const WSTR prompt, void *output_buf, bool log_output) { - run(prompt, "", "", output_buf, log_output); -} - -void Transformer::run(const WSTR prompt, const WSTR system_prompt, - const WSTR tail_prompt, void *output_buf, - bool log_output) { - if (!is_initialized) { - throw std::runtime_error( - "Transformer model is not initialized. Please call " - "initialize() before run()."); - } - ///@note This part should be filled in. - /// The run action can be defined by the precedent classes. -} - -std::vector -Transformer::createTransformerDecoderBlock(const int layer_id, - std::string input_name) { - - std::vector layers; - - layers.push_back(createLayer( - "rms_norm", - {withKey("name", "layer" + std::to_string(layer_id) + "_attention_norm"), - withKey("input_layers", input_name), - withKey("epsilon", std::to_string(NORM_EPS)), - withKey("packed", "false")})); - - auto att_layer = - createAttention(layer_id, INIT_SEQ_LEN, NUM_HEADS, HEAD_DIM, - "layer" + std::to_string(layer_id) + "_attention_norm", - "layer" + std::to_string(layer_id) + "_attention_norm", - "layer" + std::to_string(layer_id) + "_attention_norm"); - - layers.insert(layers.end(), att_layer.begin(), att_layer.end()); - - layers.push_back(createLayer( - "addition", - {withKey("name", "layer" + std::to_string(layer_id) + "_decoder_add"), - withKey("input_layers", input_name + ",layer" + std::to_string(layer_id) + - "_attention_out")})); - - layers.push_back(createLayer( - "rms_norm", - {withKey("name", "layer" + std::to_string(layer_id) + "_ffn_norm"), - withKey("input_layers", - "layer" + std::to_string(layer_id) + "_decoder_add"), - withKey("epsilon", std::to_string(NORM_EPS)), - withKey("packed", "false")})); - - auto ffn_layer = createMlp(layer_id, DIM, INTERMEDIATE_SIZE, - "layer" + std::to_string(layer_id) + "_ffn_norm"); - layers.insert(layers.end(), ffn_layer.begin(), ffn_layer.end()); - - layers.push_back(createLayer( - "addition", - {withKey("name", "layer" + std::to_string(layer_id) + "_decoder_output"), - withKey("input_layers", "layer" + std::to_string(layer_id) + - "_decoder_add,layer" + std::to_string(layer_id) + - "_ffn_down")})); - - return layers; -} - -std::vector -Transformer::createAttention(const int layer_id, int seq_len, int n_heads, - int head_dim, std::string query_name, - std::string key_name, std::string value_name) { - - std::vector layers; - - auto Q = "layer" + std::to_string(layer_id) + "_wq"; - auto K = "layer" + std::to_string(layer_id) + "_wk"; - auto V = "layer" + std::to_string(layer_id) + "_wv"; - auto A = "layer" + std::to_string(layer_id) + "_attention"; - auto O = "layer" + std::to_string(layer_id) + "_attention_out"; - - // Q layer - std::vector q_params = { - withKey("name", Q), withKey("unit", head_dim * n_heads), - withKey("disable_bias", "true"), withKey("input_layers", query_name), - withKey("weight_initializer", "ones")}; - layers.push_back(createLayer("fully_connected", q_params)); - - // K layer - std::vector k_params = { - withKey("name", K), withKey("unit", head_dim * n_heads / GQA_SIZE), - withKey("disable_bias", "true"), withKey("input_layers", key_name), - withKey("weight_initializer", "ones")}; - layers.push_back(createLayer("fully_connected", k_params)); - - // V layer - std::vector v_params = { - withKey("name", V), withKey("unit", head_dim * n_heads / GQA_SIZE), - withKey("disable_bias", "true"), withKey("input_layers", value_name), - withKey("weight_initializer", "ones")}; - layers.push_back(createLayer("fully_connected", v_params)); - - // Attention core layer - std::vector a_params = { - withKey("name", A), - withKey("num_heads", n_heads), - withKey("num_heads_kv", n_heads / GQA_SIZE), - withKey("max_timestep", std::to_string(INIT_SEQ_LEN + NUM_TO_GENERATE)), - withKey("sliding_window", (layer_id + 1) % SLIDING_WINDOW_PATTERN - ? SLIDING_WINDOW - : UINT_MAX), - withKey("rope_theta", ROPE_THETA), - withKey("max_new_tokens", std::to_string(NUM_TO_GENERATE)), - withKey("is_causal", IS_CAUSAL ? "true" : "false"), - withKey("input_layers", {Q, K, V})}; - layers.push_back(createLayer("mha_core", a_params)); - - // O layer - std::vector o_params = { - withKey("name", O), withKey("unit", DIM), withKey("disable_bias", "true"), - withKey("input_layers", A), withKey("weight_initializer", "ones")}; - layers.push_back(createLayer("fully_connected", o_params)); - - return layers; -} - -std::vector Transformer::createMlp(const int layer_id, int dim, - int hidden_dim, - std::string input_name) { - - std::vector layers; - - layers.push_back(createLayer( - "fully_connected", - {withKey("name", "layer" + std::to_string(layer_id) + "_ffn_up"), - withKey("unit", hidden_dim), withKey("disable_bias", "true"), - withKey("input_layers", input_name), - withKey("weight_initializer", "ones")})); - layers.push_back(createLayer( - "fully_connected", - {withKey("name", "layer" + std::to_string(layer_id) + "_ffn_gate"), - withKey("unit", hidden_dim), withKey("disable_bias", "true"), - withKey("input_layers", input_name), - withKey("weight_initializer", "ones")})); - - layers.push_back(createLayer( - "swiglu", - {withKey("name", "layer" + std::to_string(layer_id) + "_ffn_swiglu"), - withKey("input_layers", "layer" + std::to_string(layer_id) + "_ffn_gate," + - "layer" + std::to_string(layer_id) + - "_ffn_up")})); - - layers.push_back(createLayer( - "fully_connected", - {withKey("name", "layer" + std::to_string(layer_id) + "_ffn_down"), - withKey("unit", dim), withKey("disable_bias", "true"), - withKey("input_layers", - "layer" + std::to_string(layer_id) + "_ffn_swiglu"), - withKey("weight_initializer", "ones")})); - - return layers; -} - -void Transformer::registerCustomLayers() { - /// - const auto &ct_engine = nntrainer::Engine::Global(); - const auto app_context = - static_cast(ct_engine.getRegisteredContext("cpu")); - - try { - app_context->registerFactory(nntrainer::createLayer); - app_context->registerFactory( - nntrainer::createLayer); - app_context->registerFactory( - nntrainer::createLayer); - app_context->registerFactory( - nntrainer::createLayer); - app_context->registerFactory( - nntrainer::createLayer); - - } catch (std::invalid_argument &e) { - std::cerr << "failed to register factory, reason: " << e.what() - << std::endl; - } -} - -} // namespace quick_dot_ai diff --git a/models/transformer.h b/models/transformer.h deleted file mode 100644 index 0a937c3e..00000000 --- a/models/transformer.h +++ /dev/null @@ -1,170 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -/** - * Copyright (C) 2025 Eunju Yang - * - * @file transformer.h - * @date 31 Dec 2025 - * @see https://github.com/nntrainer/nntrainer - * @author Eunju Yang - * @bug No known bugs except for NYI items - * @note This transformer.h constructs a class for Transformer model which can - * be a parent of CausalLM and Encoder models with transformer structure. - * @note This transformer assumes the following structure : - * - * [Input] - * | - * [Embedding] - * | - * [Decoder Block] (repeated N times) - * | - * [RMSNorm] - * - */ -#ifndef __TRANSFORMER_H__ -#define __TRANSFORMER_H__ - -#pragma once - -#include -#include -#include - -#include - -namespace quick_dot_ai { - -/** - * @brief Transformer Class - */ -WIN_EXPORT class Transformer : virtual public TransformerBase { - -public: - /** - * @brief Construct a new Transformer object - * @param cfg Configuration for the model (config.json) - * @param generation_cfg Configuration for the generation (generation.json) - * @param nntr_cfg Configuration for nntrainer (nntrainer_config.json) - * @param model_type Type of the model (default: ModelType::MODEL) - */ - Transformer(json &cfg, json &generation_cfg, json &nntr_cfg, - ModelType model_type = ModelType::MODEL); - - /** - * @brief Destroy the Transformer object - */ - virtual ~Transformer() {} - - /** - * @brief Initialize and Construct the Transformer model - */ - void initialize() override; - - /** - * @brief Load the model weights from a file - */ - void load_weight(const std::string &weight_path) override; - - /** - * @brief Save the weight to a file - */ - void save_weight(const std::string &weight_path) override; - - /** - * @brief Save the weight to a file with type conversion - * @param weight_path Path to save the weight file - * @param dtype Global target data type for all layers (NONE = keep original) - * @param layer_dtype_map Per-layer data type overrides (layer_name -> dtype) - */ - void save_weight(const std::string &weight_path, - ml::train::TensorDim::DataType dtype, - const std::map - &layer_dtype_map = {}) override; - - /** - * @copydoc TransformerBase::run(const WSTR, void *, bool) - */ - void run(const WSTR prompt, void *output_buf = nullptr, - bool log_output = true) override; - - /** - * @brief TransformerBase::run(const WSTR, const WSTR, const WSTR, void *, - * bool) - */ - void run(const WSTR prompt, const WSTR system_prompt = "", - const WSTR tail_prompt = "", void *output_buf = nullptr, - bool log_output = true) override; - - /** - * @brief Get TransformerPerformanceMetrics - */ - TransformerPerformanceMetrics getPerformanceMetrics() const { - return performance_metrics; - } - -protected: - /** - * @brief Setup the parameters for the Transformer model - */ - virtual void setupParameters(json &cfg, json &generation_cfg, json &nntr_cfg); - - /** - * @brief Construct Model - */ - virtual void constructModel(); - - /** - * @brief create Attention Layer - */ - virtual std::vector - createTransformerDecoderBlock(const int layer_id, std::string input_name); - - /** - * @brief create Attention Layer - */ - virtual std::vector - createAttention(const int layer_id, int seq_len, int n_heads, int head_dim, - std::string query_name, std::string key_name, - std::string value_name); - - /** - * @brief create Feed Forward Layer - */ - virtual std::vector createMlp(const int layer_id, int dim, - int hidden_dim, - std::string input_name); - - /** - * @brief register CustomLayers - */ - virtual void registerCustomLayers(); - - int HEAD_DIM; - int INTERMEDIATE_SIZE; - bool USE_VOCAB_SELECTION; - bool TIE_WORD_EMBEDDINGS; - int NUM_HEADS; - int NUM_KEY_VALUE_HEADS; - std::string MODEL_TENSOR_TYPE; - std::string EMBEDDING_DTYPE; /** embedding dtype */ - std::string FC_LAYER_DTYPE; /** custom_fc_lora */ - - unsigned int SLIDING_WINDOW = UINT_MAX; - unsigned int SLIDING_WINDOW_PATTERN = 5; - unsigned int ROPE_THETA = 10000; /**< RoPE theta value */ - float NORM_EPS = 1e-5; /**< RMSNorm epsilon value */ - float EMBEDDING_SCALE = 1.0f; - int GQA_SIZE; - - unsigned int MAX_POSITION_EMBEDDINGS; /**< max_position embeddings */ - bool MEMORY_SWAP; /**< memory swap option */ - unsigned int FSU_LOOKAHEAD; - float ATTN_LOGIT_SOFTCAPPING = 0.0f; /**< attention logit softcapping */ - bool IS_CAUSAL = true; - - // Performance metrics - TransformerPerformanceMetrics performance_metrics; -}; - -} // namespace quick_dot_ai - -#endif diff --git a/models/transformer_base.h b/models/transformer_base.h deleted file mode 100644 index eba224dc..00000000 --- a/models/transformer_base.h +++ /dev/null @@ -1,170 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -/** - * Copyright (C) 2025 Eunju Yang - * - * @file transformer_base.h - * @date 31 Mar 2026 - * @see https://github.com/nntrainer/nntrainer - * @author Eunju Yang - * @bug No known bugs except for NYI items - * @note This transformer_base.h defines an abstract base class for - * Transformer-based models. It provides the common interface and shared state - * that both NNTrainer-based Transformer and QNN-based Transformer can inherit. - */ - -#ifndef __TRANSFORMER_BASE_H__ -#define __TRANSFORMER_BASE_H__ - -#pragma once -#ifdef _WIN32 -#define WIN_EXPORT __declspec(dllexport) -#define WSTR std::wstring -#define WCHAR_P wchar_t * -#else -#define WIN_EXPORT -#define WSTR std::string -#define WCHAR_P std::string & -#endif - -#include -#include - -#include -#include -#include -#include -#include - -#include "json.hpp" -#include "performance_metrics.h" - -namespace quick_dot_ai { - -/*** ALIAS ****/ -using LayerHandle = std::shared_ptr; -using ModelHandle = std::unique_ptr; - -using json = nlohmann::json; - -/** - * @brief Model Type Enum - */ -enum class ModelType { MODEL, CAUSALLM, EMBEDDING, UNKNOWN }; - -/** - * @brief TransformerBase Abstract Class - * @note This is the common interface for all Transformer-based models. - * Both NNTrainer Transformer and QNN Transformer inherit from this - */ -WIN_EXPORT class TransformerBase { - -public: - /** - * @brief Default constructor - */ - TransformerBase() = default; - - /** - * @brief Destroy the TransformerBase object - */ - virtual ~TransformerBase() = default; - - /** - * @brief Initialize and Construct the Transformer model - */ - virtual void initialize() = 0; - - /** - * @brief Load the model weights from a file - */ - virtual void load_weight(const std::string &weight_path) = 0; - - /** - * @brief Save the weight to a file - */ - virtual void save_weight(const std::string &weight_path) = 0; - - /** - * @brief Save the weight to a file with type conversion - * @param weight_path Path to save the weight file - * @param dtype Global target data type for all layers (NONE = keep original) - * @param layer_dtype_map Per-layer data type overrides (layer_name -> dtype) - * @note Default implementation throws; concrete subclasses that support - * type-converted save (e.g. NNTrainer-based Transformer) override it. - */ - virtual void - save_weight(const std::string &weight_path, - ml::train::TensorDim::DataType dtype, - const std::map - &layer_dtype_map = {}) { - throw std::runtime_error( - "save_weight with type conversion is not implemented for this " - "TransformerBase subclass"); - } - - /** - * @brief run the Transformer model (simple) - * @param prompt User prompt - * @param output_buf Optional output pointer. For CausalLM, pass - * std::vector*. For Sentence Transformer, pass - * std:vector *. nullptr to skip output collection. - * @param log_output Whether to log output to stdout - */ - virtual void run(const WSTR prompt, void *output_buf = nullptr, - bool log_output = true) = 0; - - /** - * @brief run the Transformer model (full) - * @param prompt User prompt - * @param system_prompt System prompt prepended to user prompt - * @param tail_prompt Tail prompt appended to user prompt - * @param output_buf Optional output pointer (see simple overload for types) - * @param log_output Whether to log output to stdout - */ - virtual void run(const WSTR prompt, const WSTR system_prompt = "", - const WSTR tail_prompt = "", void *output_buf = nullptr, - bool log_output = true) = 0; - -protected: - bool is_initialized = false; /**< Flag to check if the model is initialized */ - ModelHandle model; - - /** tokenizer */ - std::unique_ptr tokenizer; - - unsigned int NUM_VOCAB; - int DIM; - int NUM_LAYERS; - - unsigned int MAX_SEQ_LEN; - unsigned int BATCH_SIZE; - unsigned int INIT_SEQ_LEN; - unsigned int NUM_TO_GENERATE; -}; - -/** - * Loads JSON data from a file with detailed error handling - * @param file_path Path to JSON file - * @return JSON object - * @throws std::runtime_error on file open or parse failure - */ -inline json LoadJsonFile(const std::string &file_path) { - std::ifstream file(file_path); - if (!file.is_open()) { - throw std::runtime_error("Failed to open file: " + file_path + - " | Reason: " + std::strerror(errno)); - } - - try { - json data; - file >> data; - return data; - } catch (const json::parse_error &e) { - throw std::runtime_error("JSON parse error in " + file_path + - " | Details: " + e.what()); - } -} - -} // namespace quick_dot_ai - -#endif diff --git a/nntrainer b/nntrainer new file mode 160000 index 00000000..5c07ad06 --- /dev/null +++ b/nntrainer @@ -0,0 +1 @@ +Subproject commit 5c07ad068a65a5591a110e5b2cddbc2baa8e8cab diff --git a/qnn/README.md b/qnn/README.md new file mode 100644 index 00000000..17cfdbb7 --- /dev/null +++ b/qnn/README.md @@ -0,0 +1,416 @@ +# Quick.AI QNN Context Guide βš™οΈ + +> **Quick.AI ν”„λ‘œμ νŠΈ λ¬Έμ„œ** | nntrainer μ„œλΈŒλͺ¨λ“ˆ 기반 QNN λ°±μ—”λ“œ ν™•μž₯ κ°€μ΄λ“œ +> +> 이 λ¬Έμ„œλŠ” [nntrainer](https://github.com/nntrainer/nntrainer) ν”„λ ˆμž„μ›Œν¬μ—μ„œ Qualcomm Neural Network (QNN) λ°±μ—”λ“œλ₯Ό κ΄€λ¦¬ν•˜λŠ” `QNNContext` 클래슀λ₯Ό μ°Έκ³ ν•˜μ—¬, **Quick.AI ν”„λ‘œμ νŠΈ λ‚΄μ—μ„œ μ‚¬μš©μžκ°€ μžμ‹ λ§Œμ˜ μ»€μŠ€ν…€ QNN Contextλ₯Ό λ§Œλ“œλŠ” 방법**을 μ•ˆλ‚΄ν•©λ‹ˆλ‹€. + +## κ°œμš” + +Quick.AI의 QNN (`qnn/`) λ””λ ‰ν† λ¦¬λŠ” Android κΈ°κΈ°μ—μ„œ Qualcomm NPU (HTP)λ₯Ό 톡해 LLM 좔둠을 κ°€μ†ν™”ν•˜λŠ” QNN λ°±μ—”λ“œ μ»΄ν¬λ„ŒνŠΈλ₯Ό ν¬ν•¨ν•©λ‹ˆλ‹€. μ΄λŠ” nntrainer의 QNNContextλ₯Ό 기반으둜 ν•˜λ©°, Quick.AI λΉŒλ“œ μ‹œ `--enable-qnn` μ˜΅μ…˜μœΌλ‘œ ν™œμ„±ν™”λ©λ‹ˆλ‹€. + +`QNNContext`λŠ” λ‹€μŒκ³Ό 같은 역할을 μˆ˜ν–‰ν•©λ‹ˆλ‹€: + +- **QNN λ°±μ—”λ“œ μ΄ˆκΈ°ν™”**: HTP λ°±μ—”λ“œ 라이브러리(`libQnnHtp.so`) λ‘œλ“œ 및 μ„€μ • +- **λ ˆμ΄μ–΄ νŒ©ν† λ¦¬ 관리**: QNN μ „μš© λ ˆμ΄μ–΄(QNNLinear, QNNGraph λ“±)의 생성 νŒ©ν† λ¦¬ 등둝/쑰회 +- **λ©”λͺ¨λ¦¬ 관리**: RPC λ©”λͺ¨λ¦¬ ν• λ‹ΉκΈ°(`QNNRpcManager`)λ₯Ό ν†΅ν•œ QNN ν…μ„œ λ©”λͺ¨λ¦¬ 관리 +- **λ°”μ΄λ„ˆλ¦¬ μ»¨ν…μŠ€νŠΈ λ‘œλ“œ**: QNN λ°”μ΄λ„ˆλ¦¬ 파일(`.bin`)λ‘œλΆ€ν„° κ·Έλž˜ν”„λ₯Ό λ‘œλ“œν•˜κ³  μ‹€ν–‰ μ€€λΉ„ + +### Quick.AIμ—μ„œ QNN μ‚¬μš©ν•˜κΈ° + +```bash +# QNN μ§€μ›μœΌλ‘œ Android λΉŒλ“œ +./build.sh --platform=android --enable-qnn + +# λ˜λŠ” QNN νƒ€κ²Ÿλ§Œ λΉŒλ“œ +./build.sh --platform=android --target=qnn +``` + +QNN λͺ¨λΈμ€ Android (arm64-v8a)μ—μ„œλ§Œ μ§€μ›λ©λ‹ˆλ‹€. μžμ„Έν•œ QNN/Hexagon SDK μ„€μΉ˜ 방법은 [QNN μ„€μΉ˜ κ°€μ΄λ“œ](../docs/HowToInstallQNN.md)λ₯Ό μ°Έμ‘°ν•˜μ„Έμš”. + +--- + +## μ•„ν‚€ν…μ²˜ κ°œμš” + +### 클래슀 상속 ꡬ쑰 + +``` +Context (context.h) Singleton (singleton.h) + β”‚ β”‚ + β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ + β”‚ + QNNContext (qnn_context.h) +``` + +### 데이터 ꡬ쑰 + +``` +ContextData (context.h) + β”‚ +QNNBackendVar (qnn_context_var.h) + β”‚ + └── QNNVar (qnn_context_var.h) ← QNN λ°±μ—”λ“œ ν•Έλ“€, λ””λ°”μ΄μŠ€, μ»¨ν…μŠ€νŠΈ λ§΅ λ“± +``` + +- `Context`: λ ˆμ΄μ–΄/μ˜΅ν‹°λ§ˆμ΄μ € νŒ©ν† λ¦¬μ™€ `ContextData`λ₯Ό κ΄€λ¦¬ν•˜λŠ” 베이슀 클래슀 +- `Singleton`: 싱글톀 νŒ¨ν„΄μ„ μ œκ³΅ν•˜λŠ” ν…œν”Œλ¦Ώ 클래슀 (`Global()` λ©”μ„œλ“œλ‘œ μ ‘κ·Ό) +- `ContextData`: λ°±μ—”λ“œλ³„ 데이터λ₯Ό μ €μž₯ν•˜λŠ” μ»¨ν…Œμ΄λ„ˆ (λ©”λͺ¨λ¦¬ ν• λ‹ΉκΈ° 포함) +- `QNNBackendVar`: `ContextData`λ₯Ό μƒμ†ν•˜μ—¬ `QNNVar`λ₯Ό λž˜ν•‘ +- `QNNVar`: QNN λ°±μ—”λ“œ ν•Έλ“€, λ””λ°”μ΄μŠ€ ν•Έλ“€, ν”„λ‘œνŒŒμΌλ§, λ°”μ΄λ„ˆλ¦¬ μ»¨ν…μŠ€νŠΈ λ§΅ λ“± μ‹€μ œ QNN μƒνƒœλ₯Ό 보관 + +--- + +## μ»€μŠ€ν…€ Context λ§Œλ“€κΈ° + +`QNNContext`λ₯Ό μ°Έκ³ ν•˜μ—¬ μƒˆλ‘œμš΄ λ°±μ—”λ“œ Contextλ₯Ό λ§Œλ“œλŠ” 단계별 κ°€μ΄λ“œμž…λ‹ˆλ‹€. + +### Step 1: ContextData 상속과 λ°±μ—”λ“œ 데이터 클래슀 μž‘μ„± + +λ°±μ—”λ“œμ—μ„œ μ‚¬μš©ν•  μƒνƒœμ™€ 핸듀을 λ‹΄λŠ” 데이터 클래슀λ₯Ό λ§Œλ“­λ‹ˆλ‹€. + +```cpp +// my_backend_var.h +#include + +namespace nntrainer { + +// λ°±μ—”λ“œ 고유 μƒνƒœλ₯Ό λ³΄κ΄€ν•˜λŠ” ꡬ쑰체 +struct MyBackendVar { + void *backend_handle = nullptr; + bool is_initialized = false; + // ... λ°±μ—”λ“œλ³„ ν•Έλ“€, μ„€μ • λ“± +}; + +// ContextDataλ₯Ό μƒμ†ν•˜μ—¬ λž˜ν•‘ +class MyBackendData : public ContextData { +public: + MyBackendData() : data(std::make_shared()) {} + std::shared_ptr &getVar() { return data; } + +private: + std::shared_ptr data; +}; + +} // namespace nntrainer +``` + +> **μ°Έκ³ **: `QNNContext`μ—μ„œλŠ” `QNNBackendVar`κ°€ `ContextData`λ₯Ό μƒμ†ν•˜κ³ , 내뢀에 `QNNVar`λ₯Ό `shared_ptr`둜 λ³΄κ΄€ν•©λ‹ˆλ‹€. (`qnn_context_var.h` μ°Έμ‘°) + +### Step 2: Context 클래슀 상속 + +`Context`λ₯Ό 상속받아 ν•„μˆ˜ λ©”μ„œλ“œλ₯Ό κ΅¬ν˜„ν•©λ‹ˆλ‹€. 싱글톀이 ν•„μš”ν•˜λ©΄ `Singleton`도 ν•¨κ»˜ μƒμ†ν•©λ‹ˆλ‹€. + +```cpp +// my_context.h +#include +#include "singleton.h" +#include "my_backend_var.h" + +namespace nntrainer { + +class MyContext : public Context, public Singleton { +public: + MyContext() : Context(std::make_shared()) {} + ~MyContext(); + + // [ν•„μˆ˜] λ°±μ—”λ“œ μ΄ˆκΈ°ν™” + int init() override; + + // [ν•„μˆ˜] Context 이름 λ°˜ν™˜ (Engineμ—μ„œ μ‹λ³„μš©) + std::string getName() override { return "my_backend"; } + + // [ν•„μˆ˜] λ ˆμ΄μ–΄ 생성 - λ¬Έμžμ—΄ ν‚€ + std::unique_ptr + createLayerObject(const std::string &type, + const std::vector &properties) override; + + // [ν•„μˆ˜] λ ˆμ΄μ–΄ 생성 - μ •μˆ˜ ν‚€ + std::unique_ptr + createLayerObject(const int int_key, + const std::vector &properties = {}) override; + + // [선택] λͺ¨λΈ λ°”μ΄λ„ˆλ¦¬ λ‘œλ“œ + int load(const std::string &file_path) override; + +private: + // Singletonμ—μ„œ ν˜ΈμΆœν•˜λŠ” μ΄ˆκΈ°ν™” λ©”μ„œλ“œ + void initialize() noexcept override; + + // λ ˆμ΄μ–΄ νŒ©ν† λ¦¬ λ§΅ + FactoryMap factory_map; +}; + +} // namespace nntrainer +``` + +### Step 3: ν•„μˆ˜ λ©”μ„œλ“œ κ΅¬ν˜„ + +#### `initialize()`: 졜초 μ΄ˆκΈ°ν™” 및 λ ˆμ΄μ–΄ 등둝 + +`Singleton::Global()` 호좜 μ‹œ `initializeOnce()`λ₯Ό 톡해 ν•œ 번만 μ‹€ν–‰λ©λ‹ˆλ‹€. μ—¬κΈ°μ„œ λ°±μ—”λ“œλ₯Ό μ΄ˆκΈ°ν™”ν•˜κ³  λ ˆμ΄μ–΄ νŒ©ν† λ¦¬λ₯Ό λ“±λ‘ν•©λ‹ˆλ‹€. + +```cpp +// my_context.cpp +#include "my_context.h" +#include + +namespace nntrainer { + +void MyContext::initialize() noexcept { + try { + // 1. λ°±μ—”λ“œ μ΄ˆκΈ°ν™” + init(); + + // 2. λ©”λͺ¨λ¦¬ ν• λ‹ΉκΈ° μ„€μ • (ν•„μš” μ‹œ) + setMemAllocator(std::make_shared()); + + // 3. μ»€μŠ€ν…€ λ ˆμ΄μ–΄ νŒ©ν† λ¦¬ 등둝 + registerFactory(nntrainer::createLayer, + MyCustomLayer::type, + ml::train::LayerType::LAYER_FC); + + } catch (std::exception &e) { + ml_loge("MyContext initialization failed: %s", e.what()); + } +} +``` + +> **μ°Έκ³ **: `QNNContext::initialize()`μ—μ„œλŠ” `QNNLinear`, `WeightLayer`, `TensorLayer`, `QNNGraph` λ„€ κ°€μ§€ λ ˆμ΄μ–΄λ₯Ό λ“±λ‘ν•©λ‹ˆλ‹€. (`qnn_context.cpp`의 `registerFactory` ν˜ΈμΆœλΆ€ μ°Έμ‘°) + +#### `init()`: λ°±μ—”λ“œ μ„ΈλΆ€ μ΄ˆκΈ°ν™” + +λ°±μ—”λ“œ 라이브러리 λ‘œλ“œ, ν•Έλ“€ 생성 λ“± ꡬ체적인 μ΄ˆκΈ°ν™” λ‘œμ§μ„ κ΅¬ν˜„ν•©λ‹ˆλ‹€. + +```cpp +int MyContext::init() { + auto backend_data = std::static_pointer_cast(getContextData()); + auto var = backend_data->getVar(); + + // λ°±μ—”λ“œ 라이브러리 λ‘œλ“œ + // λ””λ°”μ΄μŠ€ ν•Έλ“€ 생성 + // ν”„λ‘œνŒŒμΌλ§ μ„€μ • + // ... + + var->is_initialized = true; + return 0; +} +``` + +#### `createLayerObject()`: λ ˆμ΄μ–΄ 객체 생성 + +νŒ©ν† λ¦¬ λ§΅μ—μ„œ ν‚€λ‘œ κ²€μƒ‰ν•˜μ—¬ λ ˆμ΄μ–΄ 객체λ₯Ό μƒμ„±ν•©λ‹ˆλ‹€. + +```cpp +std::unique_ptr +MyContext::createLayerObject(const std::string &type, + const std::vector &properties) { + return createObject(type, properties); +} + +std::unique_ptr +MyContext::createLayerObject(const int int_key, + const std::vector &properties) { + return createObject(int_key, properties); +} +``` + +#### `load()`: λͺ¨λΈ λ°”μ΄λ„ˆλ¦¬ λ‘œλ“œ (선택) + +QNN λ°”μ΄λ„ˆλ¦¬ λ“± λ°±μ—”λ“œ 고유 λͺ¨λΈ νŒŒμΌμ„ λ‘œλ“œν•©λ‹ˆλ‹€. + +```cpp +int MyContext::load(const std::string &file_path) { + // λ°±μ—”λ“œ 고유 λ°”μ΄λ„ˆλ¦¬ λ‘œλ“œ 둜직 + return 0; +} +``` + +### Step 4: λ ˆμ΄μ–΄ νŒ©ν† λ¦¬ 등둝 + +`registerFactory` λ©”μ„œλ“œλ₯Ό μ‚¬μš©ν•˜μ—¬ μ»€μŠ€ν…€ λ ˆμ΄μ–΄λ₯Ό λ“±λ‘ν•©λ‹ˆλ‹€. + +```cpp +// ν•¨μˆ˜ ν¬μΈν„°λ‘œ 등둝 +registerFactory(nntrainer::createLayer, "my_layer_type", -1); + +// μ •μˆ˜ ν‚€λ₯Ό ν•¨κ»˜ μ§€μ • +registerFactory(nntrainer::createLayer, + MyLayer::type, + ml::train::LayerType::LAYER_FC); +``` + +**νŒŒλΌλ―Έν„° μ„€λͺ…:** + +| νŒŒλΌλ―Έν„° | μ„€λͺ… | +|---------|------| +| `factory` | λ ˆμ΄μ–΄λ₯Ό μƒμ„±ν•˜λŠ” νŒ©ν† λ¦¬ ν•¨μˆ˜ (`std::unique_ptr(const PropsType&)`) | +| `key` | λ¬Έμžμ—΄ ν‚€ (빈 λ¬Έμžμ—΄μ΄λ©΄ `factory({})->getType()`으둜 μžλ™ κ²°μ •) | +| `int_key` | μ •μˆ˜ ν‚€ (`-1`이면 μžλ™ ν• λ‹Ή) | + +> **주의**: 이미 λ“±λ‘λœ ν‚€λ₯Ό λ‹€μ‹œ λ“±λ‘ν•˜λ©΄ `std::invalid_argument` μ˜ˆμ™Έκ°€ λ°œμƒν•©λ‹ˆλ‹€. + +### Step 5: Pluggable Context둜 λΉŒλ“œ + +`PLUGGABLE` 맀크둜λ₯Ό μ •μ˜ν•˜μ—¬ 곡유 라이브러리(`.so`)둜 λΉŒλ“œν•˜λ©΄, λŸ°νƒ€μž„μ— λ™μ μœΌλ‘œ λ‘œλ“œν•  수 μžˆμŠ΅λ‹ˆλ‹€. + +```cpp +// my_context.cpp ν•˜λ‹¨ +#ifdef PLUGGABLE +nntrainer::Context *create_my_context() { + nntrainer::MyContext *ctx = new nntrainer::MyContext(); + ctx->Global(); + return ctx; +} + +void destroy_my_context(nntrainer::Context *ct) { delete ct; } + +extern "C" { +nntrainer::ContextPluggable ml_train_context_pluggable{ + create_my_context, destroy_my_context}; +} +#endif +``` + +> **핡심**: `extern "C"` λΈ”λ‘μ—μ„œ `ml_train_context_pluggable` 심볼을 λ°˜λ“œμ‹œ μ •μ˜ν•΄μ•Ό ν•©λ‹ˆλ‹€. Engine이 이 심볼을 톡해 Contextλ₯Ό λ‘œλ“œν•©λ‹ˆλ‹€. + +--- + +## Engine 등둝 + +μƒμ„±ν•œ Contextλ₯Ό Engine에 λ“±λ‘ν•˜μ—¬ μ‚¬μš©ν•©λ‹ˆλ‹€. + +```cpp +#include "engine.h" + +auto &engine = nntrainer::Engine::Global(); + +// 곡유 라이브러리둜 등둝 (Pluggable 방식) +engine.registerContext("libmy_context.so", ""); + +// λ“±λ‘λœ Context μ‚¬μš© +auto *ctx = engine.getRegisteredContext("my_backend"); +``` + +Engine은 `registerContext()`λ₯Ό ν˜ΈμΆœν•  λ•Œ 곡유 라이브러리λ₯Ό `dlopen`ν•˜κ³ , λ‚΄λΆ€μ˜ `ml_train_context_pluggable` 심볼을 μ°Ύμ•„ Contextλ₯Ό μƒμ„±ν•©λ‹ˆλ‹€. + +λ ˆμ΄μ–΄ 생성 μ‹œ `engine` μ†μ„±μœΌλ‘œ Contextλ₯Ό μ§€μ •ν•  수 μžˆμŠ΅λ‹ˆλ‹€: + +```cpp +auto layer = createLayer("my_layer_type", {withKey("engine", "my_backend")}); +``` + +--- + +## QNNContext μ΄ˆκΈ°ν™” 흐름 + +`QNNContext`의 `init()` λ©”μ„œλ“œκ°€ μ‹€ν–‰λ˜λŠ” λ‚΄λΆ€ νλ¦„μž…λ‹ˆλ‹€. μ»€μŠ€ν…€ Context μž‘μ„± μ‹œ μ°Έκ³ ν•  수 μžˆμŠ΅λ‹ˆλ‹€. + +``` +init() + β”‚ + β”œβ”€β”€ 1. λ‘œκΉ… μ΄ˆκΈ°ν™” (QNN Logger) + β”‚ + β”œβ”€β”€ 2. λ°±μ—”λ“œ 라이브러리 λ‘œλ“œ + β”‚ └── dynamicloadutil::getQnnFunctionPointers("libQnnHtp.so") + β”‚ + β”œβ”€β”€ 3. μ‹œμŠ€ν…œ 라이브러리 λ‘œλ“œ + β”‚ └── dynamicloadutil::getQnnSystemFunctionPointers("libQnnSystem.so") + β”‚ + β”œβ”€β”€ 4. 둜그 ν•Έλ“€ 생성 + β”‚ └── qnnInterface.logCreate() + β”‚ + β”œβ”€β”€ 5. λ°±μ—”λ“œ ν™•μž₯(Backend Extensions) μ„€μ • + β”‚ β”œβ”€β”€ BackendExtensions 생성 (htp_backend_ext_config.json) + β”‚ └── beforeBackendInitialize() β†’ backendCreate() β†’ afterBackendInitialize() + β”‚ + β”œβ”€β”€ 6. λ””λ°”μ΄μŠ€ 생성 + β”‚ └── isDevicePropertySupported() β†’ createDevice() + β”‚ + β”œβ”€β”€ 7. ν”„λ‘œνŒŒμΌλ§ μ΄ˆκΈ°ν™” + β”‚ └── initializeProfiling() (OFF / BASIC / DETAILED) + β”‚ + └── 8. Op νŒ¨ν‚€μ§€ 등둝 + └── registerOpPackages() +``` + +--- + +## λ°”μ΄λ„ˆλ¦¬ μ»¨ν…μŠ€νŠΈ λ‘œλ“œ + +`QNNContext::load(file_path)` 호좜 μ‹œ `QNNVar::makeContext()`κ°€ μ‹€ν–‰λ˜μ–΄ QNN λ°”μ΄λ„ˆλ¦¬λ₯Ό λ©”λͺ¨λ¦¬μ— λ§€ν•‘ν•˜κ³  μ»¨ν…μŠ€νŠΈλ₯Ό μƒμ„±ν•©λ‹ˆλ‹€. + +``` +load(bin_path) + β”‚ + └── QNNVar::makeContext(bin_path) + β”‚ + β”œβ”€β”€ 1. κΈ°μ‘΄ μ»¨ν…μŠ€νŠΈ 확인 (이미 λ‘œλ“œλ˜μ—ˆμœΌλ©΄ μŠ€ν‚΅) + β”‚ + β”œβ”€β”€ 2. λ°±μ—”λ“œ ν™•μž₯: beforeCreateFromBinary() + β”‚ + β”œβ”€β”€ 3. λ°”μ΄λ„ˆλ¦¬ 파일 mmap + β”‚ └── mmapBinaryFile() β†’ mmap(PROT_READ, MAP_PRIVATE) + β”‚ + β”œβ”€β”€ 4. λ°”μ΄λ„ˆλ¦¬ 메타데이터 μΆ”μΆœ + β”‚ β”œβ”€β”€ systemContextCreate() + β”‚ β”œβ”€β”€ systemContextGetBinaryInfo() + β”‚ └── copyMetadataToGraphsInfo() + β”‚ + β”œβ”€β”€ 5. QNN μ»¨ν…μŠ€νŠΈ 생성 + β”‚ └── contextCreateFromBinary() + β”‚ + β”œβ”€β”€ 6. λ°±μ—”λ“œ ν™•μž₯: afterCreateFromBinary() + β”‚ + β”œβ”€β”€ 7. κ·Έλž˜ν”„ 정보 λ§΅ μ„€μ • + β”‚ └── setGraphInfoMap() β†’ graph_map에 κ·Έλž˜ν”„ 이름-정보 λ§€ν•‘ + β”‚ + └── 8. μ»¨ν…μŠ€νŠΈ 맡에 μ €μž₯ + └── ct_map[bin_path] = context_i +``` + +이후 `QNNVar::graphRetrieve(bin_path, graphName)`을 ν˜ΈμΆœν•˜μ—¬ νŠΉμ • κ·Έλž˜ν”„λ₯Ό κ°€μ Έμ˜¬ 수 μžˆμŠ΅λ‹ˆλ‹€. λ°”μ΄λ„ˆλ¦¬ 파일 ν•˜λ‚˜μ— μ—¬λŸ¬ κ·Έλž˜ν”„κ°€ 포함될 수 있으며, `ct_map`을 톡해 λ°”μ΄λ„ˆλ¦¬ νŒŒμΌλ³„λ‘œ μ»¨ν…μŠ€νŠΈλ₯Ό κ΄€λ¦¬ν•©λ‹ˆλ‹€. + +--- + +## HTP backend extension config 경둜 + +`QNNContext`λŠ” backend extension loaderλ₯Ό 톡해 +`htp_backend_ext_config.json`을 λ‘œλ“œν•©λ‹ˆλ‹€. 경둜 탐색 μš°μ„ μˆœμœ„λŠ” λ‹€μŒκ³Ό +κ°™μŠ΅λ‹ˆλ‹€. + +1. `QNNContext::setBackendExtConfigPath()` / + `setDefaultBackendExtConfigPath()`둜 μ§€μ •ν•œ κ°’ +2. `QUICK_DOT_AI_QNN_BACKEND_EXT_CONFIG_PATH` +3. `QUICK_DOT_AI_BASE_DIR/htp_backend_ext_config.json` +4. ν˜„μž¬ μž‘μ—… λ””λ ‰ν„°λ¦¬μ˜ `htp_backend_ext_config.json` + +μ ˆλŒ€ κ²½λ‘œλŠ” κ·ΈλŒ€λ‘œ μ‚¬μš©ν•˜κ³ , μƒλŒ€ κ²½λ‘œλŠ” `QUICK_DOT_AI_BASE_DIR`이 있으면 +κ·Έ 디렉터리 κΈ°μ€€μœΌλ‘œ, μ—†μœΌλ©΄ ν˜„μž¬ μž‘μ—… 디렉터리 κΈ°μ€€μœΌλ‘œ ν•΄μ„ν•©λ‹ˆλ‹€. +Android `QuickDotAI`의 `htpBackendConfigPath`λŠ” μƒλŒ€ 경둜λ₯Ό μ•± external +files 디렉터리 κΈ°μ€€μœΌλ‘œ λ¨Όμ € μ ˆλŒ€ κ²½λ‘œν™”ν•œ λ’€ native layer에 μ „λ‹¬ν•©λ‹ˆλ‹€. + +--- + +## κ΄€λ ¨ 파일 λͺ©λ‘ + +Quick.AI ν”„λ‘œμ νŠΈ 루트 κΈ°μ€€ 경둜: + +| 파일 | μ„€λͺ… | +|------|------| +| `qnn/qnn_context.h` | Quick.AI QNNContext 클래슀 μ„ μ–Έ | +| `qnn/qnn_context.cpp` | Quick.AI QNNContext κ΅¬ν˜„ | +| `qnn/meson.build` | QNN 라이브러리 λΉŒλ“œ μ„€μ • | +| `qnn/jni/qnn_context_var.h` | QNNVar, QNNBackendVar 데이터 ꡬ쑰체 | +| `qnn/jni/qnn_rpc_manager.h` | QNN RPC λ©”λͺ¨λ¦¬ κ΄€λ¦¬μž | +| `nntrainer/nntrainer/qnn_context.h` | nntrainer 베이슀 Context 클래슀 (μ„œλΈŒλͺ¨λ“ˆ) | +| `nntrainer/nntrainer/context.h` | nntrainer 베이슀 Context 클래슀 (μ„œλΈŒλͺ¨λ“ˆ) | +| `nntrainer/nntrainer/engine.h` | Engine 클래슀 (μ„œλΈŒλͺ¨λ“ˆ) | +| `nntrainer/nntrainer/utils/singleton.h` | Singleton ν…œν”Œλ¦Ώ (μ„œλΈŒλͺ¨λ“ˆ) | + +> **μ°Έκ³ **: `nntrainer/` λ””λ ‰ν† λ¦¬λŠ” Quick.AI의 Git μ„œλΈŒλͺ¨λ“ˆμž…λ‹ˆλ‹€. μœ„ κ²½λ‘œλŠ” Quick.AI 루트 κΈ°μ€€ μƒλŒ€ κ²½λ‘œμž…λ‹ˆλ‹€. + +## κ΄€λ ¨ λ¬Έμ„œ + +| λ¬Έμ„œ | μ„€λͺ… | +|------|------| +| [QNN μ„€μΉ˜ κ°€μ΄λ“œ](../docs/HowToInstallQNN.md) | QNN SDK 및 Hexagon SDK μ„€μΉ˜ 방법 | +| [API λ¬Έμ„œ](../api/README.md) | C APIμ—μ„œ QNN λͺ¨λΈ νƒ€μž… (`GEMMA4_E2B_QNN`, `VJEPA_QNN` λ“±) μ°Έμ‘° | diff --git a/qnn/jni/Android.mk.in b/qnn/jni/Android.mk.in new file mode 100644 index 00000000..ce7d1f11 --- /dev/null +++ b/qnn/jni/Android.mk.in @@ -0,0 +1,45 @@ +LOCAL_PATH := $(call my-dir) + +include $(CLEAR_VARS) + +# ndk path +ifndef ANDROID_NDK +$(error ANDROID_NDK is not defined!) +endif + +NNTRAINER_ROOT := $(LOCAL_PATH)/../../../nntrainer/builddir/android_build_result/ + +NNTRAINER_INCLUDES := $(NNTRAINER_ROOT)/include/nntrainer + +LOCAL_MODULE := nntrainer +LOCAL_SRC_FILES := $(NNTRAINER_ROOT)/lib/$(TARGET_ARCH_ABI)/libnntrainer.so +LOCAL_EXPORT_C_INCLUDES := $(NNTRAINER_INCLUDES) + +include $(PREBUILT_SHARED_LIBRARY) + +include $(CLEAR_VARS) + +LOCAL_MODULE := ccapi-nntrainer +LOCAL_SRC_FILES := $(NNTRAINER_ROOT)/lib/$(TARGET_ARCH_ABI)/libccapi-nntrainer.so +LOCAL_EXPORT_C_INCLUDES := $(NNTRAINER_INCLUDES) $(NNTRAINER_INCLUDES)/nntrainer + +include $(PREBUILT_SHARED_LIBRARY) + +include $(CLEAR_VARS) + +LOCAL_MODULE := qnn_context +LOCAL_SRC_FILES := @MESON_QNN_CONTEXT_SRCS@ +LOCAL_C_INCLUDES := $(NNTRAINER_INCLUDES) @MESON_QNN_CONTEXT_INC@ +LOCAL_SHARED_LIBRARIES := nntrainer ccapi-nntrainer + +LOCAL_ARM_NEON := true +LOCAL_CFLAGS += -pthread -fexceptions -fopenmp -static-openmp @MESON_CFLAGS@ +LOCAL_CXXFLAGS += -std=c++17 -frtti -fexceptions @MESON_CXXFLAGS@ -DENABLE_FP16=1 -DUSE__PF16=1 -DENABLE_QNN=1 -DPLUGGABLE=1 +LOCAL_MODULE_TAGS := optional + +LOCAL_LDLIBS := -llog -landroid -fopenmp -static-openmp +LOCAL_LDFLAGS += "-Wl,-z,max-page-size=16384" + +LOCAL_SHARED_LIBRARIES += nntrainer ccapi-nntrainer + +include $(BUILD_SHARED_LIBRARY) diff --git a/qnn/jni/Application.mk b/qnn/jni/Application.mk new file mode 100644 index 00000000..497a6d27 --- /dev/null +++ b/qnn/jni/Application.mk @@ -0,0 +1,5 @@ +APP_ABI := arm64-v8a +LIBCXX_USE_GABIXX := true +APP_STL := c++_shared +APP_PLATFORM := android-29 +APP_SUPPORT_FLEXIBLE_PAGE_SIZES := true diff --git a/qnn/jni/iotensor_wrapper.hpp b/qnn/jni/iotensor_wrapper.hpp new file mode 100644 index 00000000..3415407d --- /dev/null +++ b/qnn/jni/iotensor_wrapper.hpp @@ -0,0 +1,353 @@ +#ifndef __QNN_IOTENSOR_WRAPPER_H__ +#define __QNN_IOTENSOR_WRAPPER_H__ +#include "DataUtil.hpp" +#include "IOTensor.hpp" +#include "PAL/StringOp.hpp" + +#include +#include + +namespace nntrainer { + +using namespace qnn::tools; +using namespace qnn::tools::iotensor; + +class IOTensorWrapper { +public: + StatusCode + setupInputAndOutputTensors(Qnn_Tensor_t **inputs, Qnn_Tensor_t **outputs, + qnn_wrapper_api::GraphInfo_t graphInfo) { + auto returnStatus = StatusCode::SUCCESS; + if (StatusCode::SUCCESS != setupTensorsNoCopy(inputs, + graphInfo.numInputTensors, + (graphInfo.inputTensors))) { + ml_loge("Failure in setting up input tensors"); + returnStatus = StatusCode::FAILURE; + } + if (StatusCode::SUCCESS != setupTensorsNoCopy(outputs, + graphInfo.numOutputTensors, + (graphInfo.outputTensors))) { + ml_loge("Failure in setting up output tensors"); + returnStatus = StatusCode::FAILURE; + } + if (StatusCode::SUCCESS != returnStatus) { + ml_loge("Failure in setupInputAndOutputTensors, cleaning up resources"); + if (nullptr != *inputs) { + QNN_DEBUG("cleaning up input tensors"); + tearDownTensors(*inputs, graphInfo.numInputTensors); + *inputs = nullptr; + } + if (nullptr != *outputs) { + QNN_DEBUG("cleaning up output tensors"); + tearDownTensors(*outputs, graphInfo.numOutputTensors); + *outputs = nullptr; + } + ml_loge( + "Failure in setupInputAndOutputTensors, done cleaning up resources"); + } + return returnStatus; + } + + StatusCode populateInputTensor(uint8_t *buffer, Qnn_Tensor_t *input, + InputDataType inputDataType) { + if (nullptr == input) { + ml_loge("input is nullptr"); + return StatusCode::FAILURE; + } + std::vector dims; + fillDims(dims, QNN_TENSOR_GET_DIMENSIONS(input), + QNN_TENSOR_GET_RANK(input)); + if (inputDataType == InputDataType::FLOAT && + QNN_TENSOR_GET_DATA_TYPE(input) != QNN_DATATYPE_FLOAT_32) { + QNN_DEBUG("Received FLOAT input, but model needs non-float input"); + if (StatusCode::SUCCESS != + copyFromFloatToNative(reinterpret_cast(buffer), input)) { + QNN_DEBUG("copyFromFloatToNative failure"); + return StatusCode::FAILURE; + } + } else { + size_t length; + datautil::StatusCode returnStatus; + std::tie(returnStatus, length) = + datautil::calculateLength(dims, QNN_TENSOR_GET_DATA_TYPE(input)); + if (datautil::StatusCode::SUCCESS != returnStatus) { + return StatusCode::FAILURE; + } + pal::StringOp::memscpy( + reinterpret_cast(QNN_TENSOR_GET_CLIENT_BUF(input).data), + length, buffer, length); + } + return StatusCode::SUCCESS; + } + + StatusCode populateInputTensor(uint16_t *buffer, Qnn_Tensor_t *input, + InputDataType inputDataType) { + if (nullptr == input) { + ml_loge("input is nullptr"); + return StatusCode::FAILURE; + } + + std::vector dims; + fillDims(dims, QNN_TENSOR_GET_DIMENSIONS(input), + QNN_TENSOR_GET_RANK(input)); + + if (inputDataType == InputDataType::FLOAT && + QNN_TENSOR_GET_DATA_TYPE(input) != QNN_DATATYPE_FLOAT_32) { + QNN_DEBUG("Received FLOAT input, but model needs non-float input"); + if (StatusCode::SUCCESS != + copyFromFloatToNative(reinterpret_cast(buffer), input)) { + QNN_DEBUG("copyFromFloatToNative failure"); + return StatusCode::FAILURE; + } + } else { + size_t length; + datautil::StatusCode returnStatus; + std::tie(returnStatus, length) = + datautil::calculateLength(dims, QNN_TENSOR_GET_DATA_TYPE(input)); + + if (datautil::StatusCode::SUCCESS != returnStatus) { + return StatusCode::FAILURE; + } + pal::StringOp::memscpy( + reinterpret_cast(QNN_TENSOR_GET_CLIENT_BUF(input).data), + length, buffer, length); + } + return StatusCode::SUCCESS; + } + + StatusCode populateInputTensor(float *buffer, Qnn_Tensor_t *input, + InputDataType inputDataType) { + if (nullptr == input) { + ml_loge("input is nullptr"); + return StatusCode::FAILURE; + } + std::vector dims; + fillDims(dims, QNN_TENSOR_GET_DIMENSIONS(input), + QNN_TENSOR_GET_RANK(input)); + if (inputDataType == InputDataType::FLOAT && + QNN_TENSOR_GET_DATA_TYPE(input) != QNN_DATATYPE_FLOAT_32) { + QNN_DEBUG("Received FLOAT input, but model needs non-float input"); + if (StatusCode::SUCCESS != + copyFromFloatToNative(reinterpret_cast(buffer), input)) { + QNN_DEBUG("copyFromFloatToNative failure"); + return StatusCode::FAILURE; + } + } else { + size_t length; + datautil::StatusCode returnStatus; + std::tie(returnStatus, length) = + datautil::calculateLength(dims, QNN_TENSOR_GET_DATA_TYPE(input)); + if (datautil::StatusCode::SUCCESS != returnStatus) { + return StatusCode::FAILURE; + } + pal::StringOp::memscpy( + reinterpret_cast(QNN_TENSOR_GET_CLIENT_BUF(input).data), + length, buffer, length); + } + return StatusCode::SUCCESS; + } + +private: + StatusCode setupTensorsNoCopy(Qnn_Tensor_t **tensors, uint32_t tensorCount, + Qnn_Tensor_t *tensorWrappers) { + if (nullptr == tensorWrappers) { + ml_loge("tensorWrappers is nullptr"); + return StatusCode::FAILURE; + } + if (0 == tensorCount) { + QNN_INFO("tensor count is 0. Nothing to setup."); + return StatusCode::SUCCESS; + } + auto returnStatus = StatusCode::SUCCESS; + *tensors = (Qnn_Tensor_t *)calloc(1, tensorCount * sizeof(Qnn_Tensor_t)); + // std::cout << "tensorCount: "< dims; + fillDims(dims, QNN_TENSOR_GET_DIMENSIONS(wrapperTensor), + QNN_TENSOR_GET_RANK(wrapperTensor)); + if (StatusCode::SUCCESS == returnStatus) { + QNN_DEBUG("allocateBuffer successful"); + (*tensors)[tensorIdx] = QNN_TENSOR_INIT; + returnStatus = (qnn::tools::sample_app::deepCopyQnnTensorInfo( + ((*tensors) + tensorIdx), &wrapperTensor) == true + ? StatusCode::SUCCESS + : StatusCode::FAILURE); + } + if (StatusCode::SUCCESS == returnStatus) { + QNN_DEBUG("deepCopyQnnTensorInfo successful"); + QNN_TENSOR_SET_MEM_TYPE(((*tensors) + tensorIdx), + QNN_TENSORMEMTYPE_MEMHANDLE); + } + } + return returnStatus; + } + + // Clean up all tensors related data after execution. + StatusCode tearDownTensors(Qnn_Tensor_t *tensors, uint32_t tensorCount) { + for (size_t tensorIdx = 0; tensorIdx < tensorCount; tensorIdx++) { + QNN_DEBUG("freeing resources for tensor: %d", tensorIdx); + if (nullptr != QNN_TENSOR_GET_DIMENSIONS(tensors[tensorIdx])) { + QNN_DEBUG("freeing dimensions"); + free(QNN_TENSOR_GET_DIMENSIONS(tensors[tensorIdx])); + } + if (nullptr != QNN_TENSOR_GET_CLIENT_BUF(tensors[tensorIdx]).data) { + QNN_DEBUG("freeing clientBuf.data"); + free(QNN_TENSOR_GET_CLIENT_BUF(tensors[tensorIdx]).data); + } + } + free(tensors); + return StatusCode::SUCCESS; + } + + StatusCode fillDims(std::vector &dims, uint32_t *inDimensions, + uint32_t rank) { + if (nullptr == inDimensions) { + QNN_ERROR("input dimensions is nullptr"); + return StatusCode::FAILURE; + } + for (size_t r = 0; r < rank; r++) { + dims.push_back(inDimensions[r]); + } + return StatusCode::SUCCESS; + } + + // Helper method to copy a float buffer, quantize it, and copy + // it to a tensor (Qnn_Tensor_t) buffer. + StatusCode copyFromFloatToNative(float *floatBuffer, Qnn_Tensor_t *tensor) { + if (nullptr == floatBuffer || nullptr == tensor) { + QNN_ERROR("copyFromFloatToNative(): received a nullptr"); + return StatusCode::FAILURE; + } + + StatusCode returnStatus = StatusCode::SUCCESS; + std::vector dims; + fillDims(dims, QNN_TENSOR_GET_DIMENSIONS(tensor), + QNN_TENSOR_GET_RANK(tensor)); + + switch (QNN_TENSOR_GET_DATA_TYPE(tensor)) { + case QNN_DATATYPE_UFIXED_POINT_8: + datautil::floatToTfN( + static_cast(QNN_TENSOR_GET_CLIENT_BUF(tensor).data), + floatBuffer, + QNN_TENSOR_GET_QUANT_PARAMS(tensor).scaleOffsetEncoding.offset, + QNN_TENSOR_GET_QUANT_PARAMS(tensor).scaleOffsetEncoding.scale, + datautil::calculateElementCount(dims)); + break; + + case QNN_DATATYPE_UFIXED_POINT_16: + datautil::floatToTfN( + static_cast(QNN_TENSOR_GET_CLIENT_BUF(tensor).data), + floatBuffer, + QNN_TENSOR_GET_QUANT_PARAMS(tensor).scaleOffsetEncoding.offset, + QNN_TENSOR_GET_QUANT_PARAMS(tensor).scaleOffsetEncoding.scale, + datautil::calculateElementCount(dims)); + break; + + case QNN_DATATYPE_UINT_8: + if (datautil::StatusCode::SUCCESS != + datautil::castFromFloat( + static_cast(QNN_TENSOR_GET_CLIENT_BUF(tensor).data), + floatBuffer, datautil::calculateElementCount(dims))) { + QNN_ERROR("failure in castFromFloat"); + returnStatus = StatusCode::FAILURE; + } + break; + + case QNN_DATATYPE_UINT_16: + if (datautil::StatusCode::SUCCESS != + datautil::castFromFloat( + static_cast(QNN_TENSOR_GET_CLIENT_BUF(tensor).data), + floatBuffer, datautil::calculateElementCount(dims))) { + QNN_ERROR("failure in castFromFloat"); + returnStatus = StatusCode::FAILURE; + } + break; + + case QNN_DATATYPE_UINT_32: + if (datautil::StatusCode::SUCCESS != + datautil::castFromFloat( + static_cast(QNN_TENSOR_GET_CLIENT_BUF(tensor).data), + floatBuffer, datautil::calculateElementCount(dims))) { + QNN_ERROR("failure in castFromFloat"); + returnStatus = StatusCode::FAILURE; + } + break; + + case QNN_DATATYPE_UINT_64: + if (datautil::StatusCode::SUCCESS != + datautil::castFromFloat( + static_cast(QNN_TENSOR_GET_CLIENT_BUF(tensor).data), + floatBuffer, datautil::calculateElementCount(dims))) { + QNN_ERROR("failure in castFromFloat"); + returnStatus = StatusCode::FAILURE; + } + break; + + case QNN_DATATYPE_INT_8: + if (datautil::StatusCode::SUCCESS != + datautil::castFromFloat( + static_cast(QNN_TENSOR_GET_CLIENT_BUF(tensor).data), + floatBuffer, datautil::calculateElementCount(dims))) { + QNN_ERROR("failure in castFromFloat"); + returnStatus = StatusCode::FAILURE; + } + break; + + case QNN_DATATYPE_INT_16: + if (datautil::StatusCode::SUCCESS != + datautil::castFromFloat( + static_cast(QNN_TENSOR_GET_CLIENT_BUF(tensor).data), + floatBuffer, datautil::calculateElementCount(dims))) { + QNN_ERROR("failure in castFromFloat"); + returnStatus = StatusCode::FAILURE; + } + break; + + case QNN_DATATYPE_INT_32: + if (datautil::StatusCode::SUCCESS != + datautil::castFromFloat( + static_cast(QNN_TENSOR_GET_CLIENT_BUF(tensor).data), + floatBuffer, datautil::calculateElementCount(dims))) { + QNN_ERROR("failure in castFromFloat"); + returnStatus = StatusCode::FAILURE; + } + break; + + case QNN_DATATYPE_INT_64: + if (datautil::StatusCode::SUCCESS != + datautil::castFromFloat( + static_cast(QNN_TENSOR_GET_CLIENT_BUF(tensor).data), + floatBuffer, datautil::calculateElementCount(dims))) { + QNN_ERROR("failure in castFromFloat"); + returnStatus = StatusCode::FAILURE; + } + break; + + case QNN_DATATYPE_BOOL_8: + if (datautil::StatusCode::SUCCESS != + datautil::castFromFloat( + static_cast(QNN_TENSOR_GET_CLIENT_BUF(tensor).data), + floatBuffer, datautil::calculateElementCount(dims))) { + QNN_ERROR("failure in castFromFloat"); + returnStatus = StatusCode::FAILURE; + } + break; + + default: + QNN_ERROR("Datatype not supported yet!"); + returnStatus = StatusCode::FAILURE; + break; + } + return returnStatus; + } +}; +} // namespace nntrainer + +#endif diff --git a/qnn/jni/qnn-api/BackendExtensions.cpp b/qnn/jni/qnn-api/BackendExtensions.cpp new file mode 100644 index 00000000..e1dc445f --- /dev/null +++ b/qnn/jni/qnn-api/BackendExtensions.cpp @@ -0,0 +1,81 @@ +//============================================================================== +// +// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +// All rights reserved. +// Confidential and Proprietary - Qualcomm Technologies, Inc. +// +//============================================================================== + +#include + +#include "BackendExtensions.hpp" +#include "dlwrap.hpp" +#include "PAL/DynamicLoading.hpp" +#include "qualla/detail/Log.hpp" + +BackendExtensions::BackendExtensions(BackendExtensionsConfigs backendExtensionsConfig, + void* backendLibHandle, + bool debug_qnn, + QnnLog_Callback_t registeredLogCallback, + QnnLog_Level_t qnnLogLevel) + : m_backendInterface(nullptr), m_destroyBackendInterfaceFn(nullptr) { + QNN_DEBUG("DEBUG: backendExtensionsConfig.sharedLibraryPath=%s\n", + backendExtensionsConfig.sharedLibraryPath.c_str()); + if (backendExtensionsConfig.sharedLibraryPath.empty()) { + throw std::runtime_error("Empty backend extensions library path."); + } + + QNN_DEBUG("DEBUG: backendExtensionsConfig.configFilePath=%s\n", + backendExtensionsConfig.configFilePath.c_str()); + if (backendExtensionsConfig.configFilePath.empty()) { + throw std::runtime_error("Empty backend extensions config path."); + } + + void* libHandle = + pal::dynamicloading::dlOpen(backendExtensionsConfig.sharedLibraryPath.c_str(), + pal::dynamicloading::DL_NOW | pal::dynamicloading::DL_LOCAL); + if (nullptr == libHandle) { + const char* msg = pal::dynamicloading::dlError(); + QNN_ERROR("Unable to load backend extensions lib: [%s]. dlerror(): [%s]", + backendExtensionsConfig.sharedLibraryPath.c_str(), + msg ? msg : "Unknown error"); + throw std::runtime_error("Unable to open backend extension library."); + } + + auto createBackendInterfaceFn = + reinterpret_cast( + pal::dynamicloading::dlSym(libHandle, "createBackendInterface")); + if (nullptr == createBackendInterfaceFn) { + throw std::runtime_error("Unable to resolve createBackendInterface."); + } + + m_destroyBackendInterfaceFn = + reinterpret_cast( + pal::dynamicloading::dlSym(libHandle, "destroyBackendInterface")); + if (nullptr == m_destroyBackendInterfaceFn) { + throw std::runtime_error("Unable to resolve destroyBackendInterface."); + } + + m_backendInterface = createBackendInterfaceFn(); + if (nullptr == m_backendInterface) { + throw std::runtime_error("Unable to load backend extensions interface."); + } + + if (debug_qnn) { + if (!(m_backendInterface->setupLogging(registeredLogCallback, qnnLogLevel))) { + throw std::runtime_error("Unable to initialize logging in backend extensions."); + } + } + + if (!m_backendInterface->initialize(backendLibHandle)) { + throw std::runtime_error("Unable to initialize backend extensions interface."); + } + + if (!m_backendInterface->loadConfig(backendExtensionsConfig.configFilePath)) { + throw std::runtime_error("Unable to load backend extensions config. " + backendExtensionsConfig.configFilePath); + } +} + +BackendExtensions::~BackendExtensions() { m_destroyBackendInterfaceFn(m_backendInterface); } + +qnn::tools::netrun::IBackend* BackendExtensions::interface() { return m_backendInterface; } diff --git a/qnn/jni/qnn-api/BackendExtensions.hpp b/qnn/jni/qnn-api/BackendExtensions.hpp new file mode 100644 index 00000000..3a6548d2 --- /dev/null +++ b/qnn/jni/qnn-api/BackendExtensions.hpp @@ -0,0 +1,27 @@ +//============================================================================== +// +// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +// All Rights Reserved. +// Confidential and Proprietary - Qualcomm Technologies, Inc. +// +//============================================================================== + +#pragma once + +#include "IBackend.hpp" +#include "QnnConfig.hpp" + +class BackendExtensions final { + public: + BackendExtensions(BackendExtensionsConfigs backendExtensionsConfig, + void* backendLibHandle, + bool debug_qnn = false, + QnnLog_Callback_t registeredLogCallback = nullptr, + QnnLog_Level_t qnnLogLevel = QNN_LOG_LEVEL_ERROR); + ~BackendExtensions(); + qnn::tools::netrun::IBackend* interface(); + + private: + qnn::tools::netrun::IBackend* m_backendInterface; + qnn::tools::netrun::DestroyBackendInterfaceFnType_t m_destroyBackendInterfaceFn; +}; diff --git a/qnn/jni/qnn-api/IBackend.hpp b/qnn/jni/qnn-api/IBackend.hpp new file mode 100644 index 00000000..90acb5ff --- /dev/null +++ b/qnn/jni/qnn-api/IBackend.hpp @@ -0,0 +1,202 @@ +//============================================================================== +// +// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +// All rights reserved. +// Confidential and Proprietary - Qualcomm Technologies, Inc. +// +//============================================================================== + +#pragma once + +#include +#include +#include + +#include "QnnBackend.h" +#include "QnnContext.h" +#include "QnnDevice.h" +#include "QnnGraph.h" +#include "QnnLog.h" +#include "QnnProfile.h" +#include "QnnTypes.h" +#include "QnnWrapperUtils.hpp" + +namespace qnn { +namespace commandline2 { +class ICommandLineManager; +} +namespace tools { +namespace iotensor { +class IBufferAlloc; +} +namespace netrun { + +const uint32_t g_profilingLevelNotSet = 0; + +enum class PerfProfile { + LOW_BALANCED, + BALANCED, + DEFAULT, + HIGH_PERFORMANCE, + SUSTAINED_HIGH_PERFORMANCE, + BURST, + EXTREME_POWER_SAVER, + LOW_POWER_SAVER, + POWER_SAVER, + HIGH_POWER_SAVER, + SYSTEM_SETTINGS, + NO_USER_INPUT, + CUSTOM, + INVALID +}; + +enum class AppType { + QNN_APP_NETRUN = 0, + QNN_APP_CONTEXT_BINARY_GENERATOR = 1, + // Value selected to ensure 32 bits. + QNN_APP_UNKNOWN = 0x7FFFFFFF +}; + +// This is the interface that enables backend specific extensions in qnn-net-run. +// It is designed as hooks in the timeline of various events in NetRun. +// Backends that intend to implement custom features through qnn-net-run will have +// to implement this interface and add functionality in appropriate methods depending +// on where/when the custom functionality needs to be exercised. +// These functions/hooks will be called through the IBackend interface from within +// qnn-net-run wherever necessary. +class IBackend { + public: + virtual ~IBackend() {} + + virtual bool setupLogging(QnnLog_Callback_t callback, QnnLog_Level_t maxLogLevel) = 0; + + virtual bool initialize(void* backendLibHandle) = 0; + + virtual bool setPerfProfile(PerfProfile perfProfile) = 0; + + virtual QnnProfile_Level_t getProfilingLevel() = 0; + + virtual bool loadConfig(std::string configFile) = 0; + + virtual bool loadCommandLineArgs( + std::shared_ptr clManager) = 0; + + virtual bool beforeBackendInitialize(QnnBackend_Config_t*** customConfigs, + uint32_t* configCount) = 0; + + virtual bool afterBackendInitialize() = 0; + + virtual bool beforeContextCreate(QnnContext_Config_t*** customConfigs, uint32_t* configCount) = 0; + + virtual bool afterContextCreate() = 0; + + virtual bool beforeComposeGraphs(qnn_wrapper_api::GraphConfigInfo_t*** customGraphConfigs, + uint32_t* graphCount) = 0; + + virtual bool afterComposeGraphs() = 0; + + virtual bool beforeGraphFinalizeUpdateConfig(const char* graphName, + Qnn_GraphHandle_t graphHandle, + QnnGraph_Config_t*** customConfigs, + uint32_t* configCount) = 0; + + virtual bool beforeGraphFinalize() = 0; + + virtual bool afterGraphFinalize() = 0; + + virtual bool beforeRegisterOpPackages() = 0; + + virtual bool afterRegisterOpPackages() = 0; + + virtual bool beforeExecute(const char* graphName, + QnnGraph_Config_t*** customConfigs, + uint32_t* configCount) = 0; + + virtual bool afterExecute() = 0; + + virtual bool beforeContextFree(const std::vector& contextHandle) = 0; + + virtual bool afterContextFree() = 0; + + virtual bool beforeBackendTerminate() = 0; + + virtual bool afterBackendTerminate() = 0; + + virtual bool beforeCreateFromBinary(QnnContext_Config_t*** customConfigs, + uint32_t* configCount) = 0; + + virtual bool afterCreateFromBinary() = 0; + + virtual bool beforeCreateContextsFromBinaryList( + std::map>* + contextKeyToCustomConfigsMap, + QnnContext_Config_t*** commonCustomConfigs, + uint32_t* commonConfigCount) = 0; + + virtual bool afterCreateContextsFromBinaryList() = 0; + + virtual bool beforeCreateDevice(QnnDevice_Config_t*** deviceConfigs, + uint32_t* configCount, + uint32_t socModel) = 0; + + virtual bool afterCreateDevice() = 0; + + virtual bool beforeFreeDevice() = 0; + + virtual bool afterFreeDevice() = 0; + + virtual bool beforeActivateContext(QnnContext_Config_t*** customConfigs, + uint32_t* configCount) = 0; + + virtual bool afterActivateContext() = 0; + + virtual bool beforeDeactivateContext(QnnContext_Config_t*** customConfigs, + uint32_t* configCount) = 0; + + virtual bool afterDeactivateContext() = 0; + + virtual std::unique_ptr allocateBinaryBuffer(uint32_t bufferSize) = 0; + + virtual void releaseBinaryBuffer(std::unique_ptr buffer) = 0; + + virtual std::unique_ptr getBufferAllocator() = 0; + + virtual bool setParentAppType(AppType appType) = 0; + + virtual bool beforeContextApplyBinarySection() = 0; + + virtual bool afterContextApplyBinarySection() = 0; + + virtual bool isOpMappingsRequired() = 0; + + virtual bool prepareSoc(std::int32_t curDeviceId, + std::string dspArch, + int vtcmMem, + std::string name) = 0; + + virtual bool allocateExternalBuffers(void* contextHandle, + int64_t scratchBuffer, + int64_t weightsBuffer) = 0; + + virtual void provideOpMappings(Qnn_OpMapping_t* opMappings, uint32_t numOpMappings) = 0; + + virtual bool detachableBuffersEnabled() = 0; + + virtual bool detachBuffers(Qnn_ContextHandle_t contextHandle) = 0; + + virtual bool attachBuffers(Qnn_ContextHandle_t contextHandle) = 0; +}; + +// These are the function types that the backend extensions shared library is +// expected to expose. The first function helps NetRun obtain a valid implementation +// of IBackend interface and the second is used to destroy the same interface at the end. +// The function names themselves are expected to be these strings: +// 1. "createBackendInterface" +// 2. "destroyBackendInterface" +// These functions need to be tagged with extern "C" and their symbols need to be exposed. +typedef IBackend* (*CreateBackendInterfaceFnType_t)(); +typedef void (*DestroyBackendInterfaceFnType_t)(IBackend*); + +} // namespace netrun +} // namespace tools +} // namespace qnn \ No newline at end of file diff --git a/qnn/jni/qnn-api/QnnConfig.hpp b/qnn/jni/qnn-api/QnnConfig.hpp new file mode 100644 index 00000000..6c4c3e4d --- /dev/null +++ b/qnn/jni/qnn-api/QnnConfig.hpp @@ -0,0 +1,39 @@ +//============================================================================== +// +// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +// All Rights Reserved. +// Confidential and Proprietary - Qualcomm Technologies, Inc. +// +//============================================================================== + +#pragma once + +#include +#include + +#include "QnnGraph.h" +#include "QnnTypes.h" + +struct BackendExtensionsConfigs { + std::string sharedLibraryPath; + std::string configFilePath; + + BackendExtensionsConfigs() : sharedLibraryPath(""), configFilePath("") {} + + BackendExtensionsConfigs(const std::string& _sharedLibraryPath, + const std::string& _configFilePath) + : sharedLibraryPath(_sharedLibraryPath), configFilePath(_configFilePath) {} +}; + +struct GraphConfigs { + std::string graphName; + bool priorityPresent; + Qnn_Priority_t priority; + GraphConfigs() : graphName(), priorityPresent(false), priority(QNN_PRIORITY_UNDEFINED) {} +}; + +struct ConfigOptions { + BackendExtensionsConfigs backendExtensionsConfigs; + std::vector graphConfigs; + ConfigOptions() : backendExtensionsConfigs(), graphConfigs() {} +}; diff --git a/qnn/jni/qnn-api/dlwrap.cpp b/qnn/jni/qnn-api/dlwrap.cpp new file mode 100644 index 00000000..9d9fb101 --- /dev/null +++ b/qnn/jni/qnn-api/dlwrap.cpp @@ -0,0 +1,66 @@ +//============================================================================== +// +// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +// All Rights Reserved. +// Confidential and Proprietary - Qualcomm Technologies, Inc. +// +//============================================================================== + +#ifdef _WIN32 + +#pragma warning(disable : 4133 4996) + +#include +#include +#include +#include +#include +#include + +#include "dlwrap.hpp" + +static const char* last_func; +static long last_err; + +void* dlopen(const char* dll, int flags) { + HINSTANCE h = LoadLibraryA(dll); + if (h == NULL) { + last_err = GetLastError(); + last_func = "dlopen"; + } + + return h; +} + +int dlclose(void* h) { + if (!FreeLibrary((HINSTANCE)h)) { + last_err = GetLastError(); + last_func = "dlclose"; + return -1; + } + + return 0; +} + +void* dlsym(void* h, const char* name) { + FARPROC p = GetProcAddress((HINSTANCE)h, name); + if (!p) { + last_err = GetLastError(); + last_func = "dlsym"; + } + return (void*)(intptr_t)p; +} + +const char* dlerror(void) { + static char str[88]; + + if (!last_err) return NULL; + + sprintf(str, "%s error #%ld", last_func, last_err); + last_err = 0; + last_func = NULL; + + return str; +} + +#endif // _WIN32 diff --git a/qnn/jni/qnn-api/dlwrap.hpp b/qnn/jni/qnn-api/dlwrap.hpp new file mode 100644 index 00000000..49444ab0 --- /dev/null +++ b/qnn/jni/qnn-api/dlwrap.hpp @@ -0,0 +1,33 @@ +//============================================================================== +// +// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +// All Rights Reserved. +// Confidential and Proprietary - Qualcomm Technologies, Inc. +// +//============================================================================== + +#ifndef DLWRAP_HPP +#define DLWRAP_HPP + +#ifndef _WIN32 + +// Just include regular dlfcn +#include + +#else // _WIN32 + +// Define basic set dl functions and flags + +#define RTLD_GLOBAL 0x100 +#define RTLD_LOCAL 0x000 +#define RTLD_LAZY 0x000 +#define RTLD_NOW 0x001 + +void* dlopen(const char* filename, int flag); +int dlclose(void* handle); +void* dlsym(void* handle, const char* name); +const char* dlerror(void); + +#endif // _WIN32 + +#endif // DLWRAP_HPP diff --git a/qnn/jni/qnn/Log/LogUtils.cpp b/qnn/jni/qnn/Log/LogUtils.cpp new file mode 100644 index 00000000..f136eff7 --- /dev/null +++ b/qnn/jni/qnn/Log/LogUtils.cpp @@ -0,0 +1,45 @@ +//============================================================================== +// +// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +// All rights reserved. +// Confidential and Proprietary - Qualcomm Technologies, Inc. +// +//============================================================================== + +#include "LogUtils.hpp" + +void qnn::log::utils::logDefaultCallback(const char* fmt, + QnnLog_Level_t level, + uint64_t timestamp, + va_list argp) { + const char* levelStr = ""; + switch (level) { + case QNN_LOG_LEVEL_ERROR: + levelStr = " ERROR "; + break; + case QNN_LOG_LEVEL_WARN: + levelStr = "WARNING"; + break; + case QNN_LOG_LEVEL_INFO: + levelStr = " INFO "; + break; + case QNN_LOG_LEVEL_DEBUG: + levelStr = " DEBUG "; + break; + case QNN_LOG_LEVEL_VERBOSE: + levelStr = "VERBOSE"; + break; + case QNN_LOG_LEVEL_MAX: + levelStr = "UNKNOWN"; + break; + } + + double ms = (double)timestamp / 1000000.0; + // To avoid interleaved messages + { + std::lock_guard lock(sg_logUtilMutex); + fprintf(stdout, "%8.1fms [%-7s] ", ms, levelStr); + vfprintf(stdout, fmt, argp); + fprintf(stdout, "\n"); + } +} \ No newline at end of file diff --git a/qnn/jni/qnn/Log/LogUtils.hpp b/qnn/jni/qnn/Log/LogUtils.hpp new file mode 100644 index 00000000..1c6ab28b --- /dev/null +++ b/qnn/jni/qnn/Log/LogUtils.hpp @@ -0,0 +1,29 @@ +//============================================================================== +// +// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +// All rights reserved. +// Confidential and Proprietary - Qualcomm Technologies, Inc. +// +//============================================================================== + +#pragma once + +#include +#include +#include +#include + +#include "QnnLog.h" + +namespace qnn { +namespace log { +namespace utils { + +// In non-hexagon app stdout is used and for hexagon farf logging is used +void logDefaultCallback(const char* fmt, QnnLog_Level_t level, uint64_t timestamp, va_list argp); + +static std::mutex sg_logUtilMutex; + +} // namespace utils +} // namespace log +} // namespace qnn diff --git a/qnn/jni/qnn/Log/Logger.cpp b/qnn/jni/qnn/Log/Logger.cpp new file mode 100644 index 00000000..38f15fb2 --- /dev/null +++ b/qnn/jni/qnn/Log/Logger.cpp @@ -0,0 +1,105 @@ +//============================================================================== +// +// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +// All rights reserved. +// Confidential and Proprietary - Qualcomm Technologies, Inc. +// +//============================================================================== + +#include +#include +#include +#include + +#include "LogUtils.hpp" +#include "Logger.hpp" + +using namespace qnn::log; + +std::shared_ptr Logger::s_logger = nullptr; + +std::mutex Logger::s_logMutex; + +std::shared_ptr Logger::createLogger(QnnLog_Callback_t callback, + QnnLog_Level_t maxLevel, + QnnLog_Error_t* status) { + std::lock_guard lock(s_logMutex); + if ((maxLevel > QNN_LOG_LEVEL_VERBOSE) || (maxLevel == 0)) { + if (status) { + *status = QNN_LOG_ERROR_INVALID_ARGUMENT; + } + return nullptr; + } + if (!s_logger) { + s_logger = std::shared_ptr(new (std::nothrow) Logger(callback, maxLevel, status)); + } + *status = QNN_LOG_NO_ERROR; + return s_logger; +} + +Logger::Logger(QnnLog_Callback_t callback, QnnLog_Level_t maxLevel, QnnLog_Error_t* status) + : m_callback(callback), m_maxLevel(maxLevel), m_epoch(getTimestamp()) { + if (!callback) { + m_callback = utils::logDefaultCallback; + } +} + +void Logger::log(QnnLog_Level_t level, const char* file, long line, const char* fmt, ...) { + if (m_callback) { + if (level > m_maxLevel.load(std::memory_order_seq_cst)) { + return; + } + va_list argp; + va_start(argp, fmt); + std::string logString(fmt); + std::ignore = file; + std::ignore = line; + (*m_callback)(logString.c_str(), level, getTimestamp() - m_epoch, argp); + va_end(argp); + } +} + +uint64_t Logger::getTimestamp() const { + return std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()) + .count(); +} + +std::shared_ptr<::qnn::log::Logger> g_logger{nullptr}; + +bool qnn::log::initializeLogging() { + QnnLog_Level_t logLevel; + QnnLog_Error_t status; +#ifdef QNN_ENABLE_DEBUG + logLevel = QNN_LOG_LEVEL_DEBUG; +#else + logLevel = QNN_LOG_LEVEL_INFO; +#endif + // Default log stream is enabled in Core/Logger component + g_logger = ::qnn::log::Logger::createLogger(nullptr, logLevel, &status); + if (QNN_LOG_NO_ERROR != status || !g_logger) { + return false; + } + return true; +} + +QnnLog_Callback_t qnn::log::getLogCallback() { return g_logger->getLogCallback(); } + +QnnLog_Level_t qnn::log::getLogLevel() { return g_logger->getMaxLevel(); } + +bool qnn::log::isLogInitialized() { + if (g_logger == nullptr) { + return false; + } + return true; +} + +bool qnn::log::setLogLevel(QnnLog_Level_t maxLevel) { + if (!::qnn::log::Logger::isValid() || + !(maxLevel >= QNN_LOG_LEVEL_ERROR && maxLevel <= QNN_LOG_LEVEL_DEBUG)) { + return false; + } + + g_logger->setMaxLevel(maxLevel); + return true; +} diff --git a/qnn/jni/qnn/Log/Logger.hpp b/qnn/jni/qnn/Log/Logger.hpp new file mode 100644 index 00000000..69389232 --- /dev/null +++ b/qnn/jni/qnn/Log/Logger.hpp @@ -0,0 +1,106 @@ +//============================================================================== +// +// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +// All rights reserved. +// Confidential and Proprietary - Qualcomm Technologies, Inc. +// +//============================================================================== + +#pragma once + +#include +#include +#include +#include +#include + +#include "QnnLog.h" + +#define __FILENAME__ (strrchr(__FILE__, '/') + 1) + +/** + * @brief Log something with the current logger. Always valid to call, though + * it won't do something if no logger has been set. + */ + +#define QNN_LOG_LEVEL(level, fmt, ...) \ + do { \ + auto logger = ::qnn::log::Logger::getLogger(); \ + if (logger) { \ + logger->log(level, __FILENAME__, __LINE__, fmt, ##__VA_ARGS__); \ + } \ + } while (0) + +#define QNN_ERROR(fmt, ...) QNN_LOG_LEVEL(QNN_LOG_LEVEL_ERROR, fmt, ##__VA_ARGS__) + +#define QNN_ERROR_EXIT(fmt, ...) \ + { \ + QNN_ERROR(fmt, ##__VA_ARGS__); \ + exit(EXIT_FAILURE); \ + } + +#define QNN_WARN(fmt, ...) QNN_LOG_LEVEL(QNN_LOG_LEVEL_WARN, fmt, ##__VA_ARGS__) + +#define QNN_INFO(fmt, ...) QNN_LOG_LEVEL(QNN_LOG_LEVEL_INFO, fmt, ##__VA_ARGS__) + +#define QNN_DEBUG(fmt, ...) QNN_LOG_LEVEL(QNN_LOG_LEVEL_DEBUG, fmt, ##__VA_ARGS__) + +#define QNN_VERBOSE(fmt, ...) QNN_LOG_LEVEL(QNN_LOG_LEVEL_VERBOSE, fmt, ##__VA_ARGS__) + +#define QNN_FUNCTION_ENTRY_LOG QNN_LOG_LEVEL(QNN_LOG_LEVEL_VERBOSE, "Entering %s", __func__) + +#define QNN_FUNCTION_EXIT_LOG QNN_LOG_LEVEL(QNN_LOG_LEVEL_VERBOSE, "Returning from %s", __func__) + +namespace qnn { +namespace log { + +bool initializeLogging(); + +QnnLog_Callback_t getLogCallback(); + +QnnLog_Level_t getLogLevel(); + +bool isLogInitialized(); + +bool setLogLevel(QnnLog_Level_t maxLevel); + +class Logger final { + public: + Logger(const Logger&) = delete; + Logger& operator=(const Logger&) = delete; + Logger(Logger&&) = delete; + Logger& operator=(Logger&&) = delete; + + void setMaxLevel(QnnLog_Level_t maxLevel) { + m_maxLevel.store(maxLevel, std::memory_order_seq_cst); + } + + QnnLog_Level_t getMaxLevel() { return m_maxLevel.load(std::memory_order_seq_cst); } + + QnnLog_Callback_t getLogCallback() { return m_callback; } + + void log(QnnLog_Level_t level, const char* file, long line, const char* fmt, ...); + + static std::shared_ptr createLogger(QnnLog_Callback_t callback, + QnnLog_Level_t maxLevel, + QnnLog_Error_t* status); + + static bool isValid() { return (s_logger != nullptr); } + + static std::shared_ptr getLogger() { return s_logger; } + + static void reset() { s_logger = nullptr; } + uint64_t getTimestamp() const; + + private: + Logger(QnnLog_Callback_t callback, QnnLog_Level_t maxLevel, QnnLog_Error_t* status); + + QnnLog_Callback_t m_callback; + std::atomic m_maxLevel; + uint64_t m_epoch; + static std::shared_ptr s_logger; + static std::mutex s_logMutex; +}; + +} // namespace log +} // namespace qnn diff --git a/qnn/jni/qnn/Log/meson.build b/qnn/jni/qnn/Log/meson.build new file mode 100644 index 00000000..60607c96 --- /dev/null +++ b/qnn/jni/qnn/Log/meson.build @@ -0,0 +1,19 @@ +log_sources = [ + 'Logger.cpp', + 'LogUtils.cpp', +] + +log_headers = [ + 'Logger.hpp', + 'LogUtils.hpp' +] + +foreach s : log_sources + qnn_sources += meson.current_source_dir() / s +endforeach + +qnn_include_dir = quick_dot_ai_prefix / 'include' / 'nntrainer' / 'npu' / 'qnn' + +install_subdir ('Log', install_dir : qnn_include_dir ) + +install_headers(log_headers, install_dir : qnn_include_dir / 'Log') diff --git a/qnn/jni/qnn/PAL/include/PAL/Debug.hpp b/qnn/jni/qnn/PAL/include/PAL/Debug.hpp new file mode 100644 index 00000000..cbdd0fe5 --- /dev/null +++ b/qnn/jni/qnn/PAL/include/PAL/Debug.hpp @@ -0,0 +1,21 @@ +//============================================================================== +// +// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +// All rights reserved. +// Confidential and Proprietary - Qualcomm Technologies, Inc. +// +//============================================================================== + +#pragma once + +#define DEBUG_ON 0 + +#if DEBUG_ON +#define DEBUG_MSG(...) \ + { \ + fprintf(stderr, __VA_ARGS__); \ + fprintf(stderr, "\n"); \ + } +#else +#define DEBUG_MSG(...) +#endif diff --git a/qnn/jni/qnn/PAL/include/PAL/Directory.hpp b/qnn/jni/qnn/PAL/include/PAL/Directory.hpp new file mode 100644 index 00000000..233255f6 --- /dev/null +++ b/qnn/jni/qnn/PAL/include/PAL/Directory.hpp @@ -0,0 +1,80 @@ +//============================================================================== +// +// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +// All rights reserved. +// Confidential and Proprietary - Qualcomm Technologies, Inc. +// +//============================================================================== + +//--------------------------------------------------------------------------- +/// @file +/// This file includes APIs for directory operations on supported platforms +//--------------------------------------------------------------------------- + +#pragma once + +#include + +#include "PAL/FileOp.hpp" + +namespace pal { +class Directory; +} + +class pal::Directory { + public: + using DirMode = pal::FileOp::FileMode; + //--------------------------------------------------------------------------- + /// @brief + /// Creates a directory in the file system. + /// @param path + /// Name of directory to create. + /// @param dirmode + /// Directory mode + /// @return + /// True if + /// 1. create a directory successfully + /// 2. or directory exist already + /// False otherwise + /// + /// For example: + /// + /// - Create a directory in default. + /// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + /// pal::Directory::Create(path, pal::Directory::DirMode::S_DEFAULT_); + /// pal::Directory::Create(path); + /// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + /// + /// - Create a directory with specific permission. + /// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + /// pal::Directory::Create(path, pal::Directory::DirMode::S_IRWXU_| + /// pal::Directory::DirMode::S_IRWXG_| + /// pal::Directory::DirMode::S_IRWXO_); + /// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + /// + /// @note For windows, dirmode is not used. + /// @note For linux, dirmode is used to set the permission of the folder. + //--------------------------------------------------------------------------- + static bool create(const std::string &path, + pal::Directory::DirMode dirmode = pal::Directory::DirMode::S_DEFAULT_); + + //--------------------------------------------------------------------------- + /// @brief + /// Removes the entire directory whether it's empty or not. + /// @param path + /// Name of directory to delete. + /// @return + /// True if the directory was successfully deleted, false otherwise. + //--------------------------------------------------------------------------- + static bool remove(const std::string &path); + + //--------------------------------------------------------------------------- + /// @brief + /// Creates a directory and all parent directories required. + /// @param path + /// Path of directory to create. + /// @return + /// True if the directory was successfully created, false otherwise. + //--------------------------------------------------------------------------- + static bool makePath(const std::string &path); +}; diff --git a/qnn/jni/qnn/PAL/include/PAL/Dsp.hpp b/qnn/jni/qnn/PAL/include/PAL/Dsp.hpp new file mode 100644 index 00000000..ff93feef --- /dev/null +++ b/qnn/jni/qnn/PAL/include/PAL/Dsp.hpp @@ -0,0 +1,40 @@ +//============================================================================== +// +// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +// All rights reserved. +// Confidential and Proprietary - Qualcomm Technologies, Inc. +// +//============================================================================== + +//--------------------------------------------------------------------------- +/// @file +/// This file includes APIs related to DSP on supported platforms +//--------------------------------------------------------------------------- + +#ifndef DSP_HPP +#define DSP_HPP + +#include +#include + +namespace pal { +class Dsp; +} + +class pal::Dsp { + public: + //--------------------------------------------------------------------------- + /// @brief + /// This API is only for Windows platform. + /// Get the absolute location of DSP driver library (libcdsprpc.so/dll). + /// @return + /// On success, return location of DSP driver library. + /// On error, return an empty string. + //--------------------------------------------------------------------------- + static std::string getDspDriverPath(); + + private: + static std::mutex s_mutex; +}; + +#endif // DSP_HPP diff --git a/qnn/jni/qnn/PAL/include/PAL/DynamicLoading.hpp b/qnn/jni/qnn/PAL/include/PAL/DynamicLoading.hpp new file mode 100644 index 00000000..ed65ceb0 --- /dev/null +++ b/qnn/jni/qnn/PAL/include/PAL/DynamicLoading.hpp @@ -0,0 +1,100 @@ +//============================================================================== +// +// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +// All rights reserved. +// Confidential and Proprietary - Qualcomm Technologies, Inc. +// +//============================================================================== + +//--------------------------------------------------------------------------- +/// @file +/// This file includes APIs for dynamic loading on supported platforms +//--------------------------------------------------------------------------- + +#pragma once + +#include + +namespace pal { +namespace dynamicloading { +// we only support subset of POSIX of dlopen/dlsym/dladdr/dlerror/dlclose +// except the following flags for dlopen, others should be done only +// when we really need them +// DL_NOW is MUST +// DL_LOCAL is enabled if not specified +enum { + DL_NOW = 0x0001, + DL_LOCAL = 0x0002, + DL_GLOBAL = 0x0004, + DL_NOLOAD = 0x0008 +}; + +// specify this address to distingiush from NULL pointer +#define DL_DEFAULT (void *)(0x4) + +//--------------------------------------------------------------------------- +/// @brief +/// Loads the dynamic shared object +/// @param filename +/// If contains path separators, treat it as relative or absolute pathname +/// or search it for the rule of dynamic linker +/// @param flags +/// - DL_NOW: resolve undefined symbols before return. MUST be specified. +/// - DL_LOCAL: optional, but the default specified. Symbols defined in this +/// shared object are not made available to resolve references in subsequently +/// loaded shared objects +/// - DL_GLOBAL: optional, resolve symbol globally +/// @return +/// On success, a non-NULL handle for the loaded library. +/// On error, NULL +//--------------------------------------------------------------------------- +void *dlOpen(const char *filename, int flags); + +//--------------------------------------------------------------------------- +/// @brief +/// Obtain address of a symbol in a shared object or executable +/// @param handle +/// A handle of a dynamic loaded shared object returned by dlopen +/// @param symbol +/// A null-terminated symbol name +/// @return +/// On success, return the address associated with symbol +/// On error, NULL +//--------------------------------------------------------------------------- +void *dlSym(void *handle, const char *symbol); + +//--------------------------------------------------------------------------- +/// @brief +/// Translate the address of a symbol to the path of the belonging shared object +/// @param addr +/// Address of symbol in a shared object +/// @param path +/// Full name of shared object that contains address, usually it is an absolute path +/// @return +/// On success, return a non-zero value +/// On error, return 0 +//--------------------------------------------------------------------------- +int dlAddrToLibName(void *addr, std::string &name); + +//--------------------------------------------------------------------------- +/// @brief +/// Decrements the reference count on the dynamically loaded shared object +/// referred to by handle. If the reference count drops to 0, then the +/// object is unloaded. +/// @return +/// On success, 0; on error, a nonzero value +//--------------------------------------------------------------------------- +int dlClose(void *handle); + +//--------------------------------------------------------------------------- +/// @brief +/// Obtain error diagnostic for functions in the dl-family APIs. +/// @return +/// Returns a human-readable, null-terminated string describing the most +/// recent error that occurred from a call to one of the functions in the +/// dl-family APIs. +//--------------------------------------------------------------------------- +char *dlError(void); + +} // namespace dynamicloading +} // namespace pal diff --git a/qnn/jni/qnn/PAL/include/PAL/FileOp.hpp b/qnn/jni/qnn/PAL/include/PAL/FileOp.hpp new file mode 100644 index 00000000..07c1e802 --- /dev/null +++ b/qnn/jni/qnn/PAL/include/PAL/FileOp.hpp @@ -0,0 +1,310 @@ +//============================================================================== +// +// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +// All rights reserved. +// Confidential and Proprietary - Qualcomm Technologies, Inc. +// +//============================================================================== + +//------------------------------------------------------------------------------ +/// @file +/// This file includes APIs for file operations on the supported platforms +//------------------------------------------------------------------------------ + +#pragma once + +#include + +#include +#include +#include + +namespace pal { +class FileOp; +} + +//------------------------------------------------------------------------------ +/// @brief +/// FileOp contains OS Specific file system functionality. +//------------------------------------------------------------------------------ +class pal::FileOp { + public: + enum class AccessMode : int32_t { + O_RDONLY_ = O_RDONLY, // File access flag: Read only + O_WRONLY_ = O_WRONLY, // File access flag: Write only + O_RDWR_ = O_RDWR, // File access flag: Read and write + O_CREAT_ = O_CREAT, // File creation flag: Create file or open existing + O_EXCL_ = O_EXCL, // File creation flag: Opens file, creates if DNE (use with O_CREAT) + O_TRUNC_ = O_TRUNC, // File creation flag: Truncate file on open + O_APPEND_ = O_APPEND // File status flag: Open file and shift fp to end of file + }; + + friend AccessMode operator&(AccessMode lhs, AccessMode rhs) { + return static_cast(static_cast(lhs) & static_cast(rhs)); + } + friend AccessMode operator|(AccessMode lhs, AccessMode rhs) { + return static_cast(static_cast(lhs) | static_cast(rhs)); + } + static AccessMode getFileAccessMode(AccessMode mode) { + return mode & (AccessMode::O_RDONLY_ | AccessMode::O_WRONLY_ | AccessMode::O_RDWR_); + } + + // enum for symbolic constants mode, strictly follow linux usage + // windows or another OS user should transfer the usage + // ref : http://man7.org/linux/man-pages/man2/open.2.html + enum class FileMode : uint32_t { + S_DEFAULT_ = 0777, + S_IRWXU_ = 0700, + S_IRUSR_ = 0400, + S_IWUSR_ = 0200, + S_IXUSR_ = 0100, + S_IRWXG_ = 0070, + S_IRGRP_ = 0040, + S_IWGRP_ = 0020, + S_IXGRP_ = 0010, + S_IRWXO_ = 0007, + S_IROTH_ = 0004, + S_IWOTH_ = 0002, + S_IXOTH_ = 0001 + }; + + friend FileMode operator&(FileMode lhs, FileMode rhs) { + return static_cast(static_cast(lhs) & static_cast(rhs)); + } + friend FileMode operator|(FileMode lhs, FileMode rhs) { + return static_cast(static_cast(lhs) | static_cast(rhs)); + } + + //--------------------------------------------------------------------------- + /// @brief + /// Open a file + /// @param path, flags + /// Path to check, flags to set permissions + /// @return + /// Returns a file descriptor or -1 to indicate a failure + /// + /// For examples: + /// -# open a write/read file: + /// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + /// int32_t fd = pal::FileOp::open(path, pal::FileOp::AccessMode::O_RDWR_); + /// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + /// -# open a read-only file: + /// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + /// int32_t fd = pal::FileOp::open(path, pal::FileOp::AccessMode::O_RDONLY_); + /// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + /// -# open a write only file with append: + /// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + /// int32_t fd = pal::FileOp::open(path, pal::FileOp::AccessMode::O_WRONLY_ | + /// pal::FileOp::AccessMode::O_APPEND); + /// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + /// -# open/create a new file with user write/read/exec + other read only + /// + group read only + /// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + /// int32_t fd = pal::FileOp::open(path, pal::FileOp::AccessMode::O_CREAT_ | + /// pal::FileOp::AccessMode::O_RDWR_, pal::FileOp::FileMode::S_IRWXU_ | + /// pal::FileOp::FileMode::S_IRGRP_ | pal::FileOp::FileMode::S_IROTH_); + /// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + /// In this case, linux can work as expected. Windows will creat a file with + /// both read and write since at least one of three kinds can read. + //-------------------------------------------------------------------------- + static int32_t open(const std::string &path, AccessMode flags, FileMode mode = FileMode::S_DEFAULT_); + + //--------------------------------------------------------------------------- + /// @brief + /// Closes a file + /// @param fd + /// File descriptor + /// @return + /// Returns 0 if successful, -1 otherwise + //--------------------------------------------------------------------------- + static int32_t close(int32_t fd); + + //--------------------------------------------------------------------------- + /// @brief + /// Copies a file from one location to another, overwrites if the + /// destination already exists. + /// @param source + /// File name of the source file. + /// @param target + /// File name of the target file. + /// @return + /// True on success, otherwise false. + //--------------------------------------------------------------------------- + static bool copyOverFile(const std::string &source, const std::string &target); + + //--------------------------------------------------------------------------- + /// @brief + /// Checks whether the file exists or not. + /// @param fileName + /// File name of the source file, including its complete path. + /// @return + /// True on success, otherwise false. + //--------------------------------------------------------------------------- + static bool checkFileExists(const std::string &fileName); + + //--------------------------------------------------------------------------- + /// @brief + /// Renames an existing file. If the file with target name exists, this call + /// overwrites it with the file with source name. + /// @param source + /// Current File name. + /// @param target + /// New name of the file. + /// @param overwrite + /// Flag indicating to overwrite existing file with newName + /// @return + /// True if successful, otherwise false. + /// @warning + /// Does not work if source and target are on different filesystems. + //--------------------------------------------------------------------------- + static bool move(const std::string &source, const std::string &target, bool overwrite); + + //--------------------------------------------------------------------------- + /// @brief + /// Delete an existing file + /// @param fileName + /// File name of the file to be deleted. + /// @return + /// True if successful, otherwise false. + //--------------------------------------------------------------------------- + static bool deleteFile(const std::string &fileName); + + //--------------------------------------------------------------------------- + /// @brief + /// Check if path is a directory or not + /// @param path + /// Path to check + /// @return + /// True if successful, otherwise false. + //--------------------------------------------------------------------------- + static bool checkIsDir(const std::string &path); + + //--------------------------------------------------------------------------- + /// @brief Data type representing parts of a filename + //--------------------------------------------------------------------------- + typedef struct { + //--------------------------------------------------------------------------- + /// @brief Name of the file without the extension (i.e., basename) + //--------------------------------------------------------------------------- + std::string basename; + + //--------------------------------------------------------------------------- + /// @brief Name of the file extension (i.e., .txt or .hlnd, .html) + //--------------------------------------------------------------------------- + std::string extension; + + //--------------------------------------------------------------------------- + /// @brief + /// Location of the file (i.e., /abc/xyz/foo.bar <-- /abc/xyz/). + /// If the file name has no location then the Directory points to + /// empty string + //--------------------------------------------------------------------------- + std::string directory; + } FilenamePartsType_t; + + //--------------------------------------------------------------------------- + /// @brief + /// Determines the components of a given filename, being the directory, + /// basename and extension. If the file has no location or extension, these + /// components remain empty + /// @param filename + /// Path of the file for which the components are to be determined + /// @param filenameParts + /// Will contain the file name components when this function returns + /// @return + /// True if successful, false otherwise + //--------------------------------------------------------------------------- + static bool getFileInfo(const std::string &filename, FilenamePartsType_t &filenameParts); + + //--------------------------------------------------------------------------- + /// @brief + /// Typedef for a vector of FilenamePartsType_t + //--------------------------------------------------------------------------- + typedef std::vector FilenamePartsListType_t; + + //--------------------------------------------------------------------------- + /// @brief + /// Typedef for a vector of FilenamePartsType_t const iterator + //--------------------------------------------------------------------------- + typedef std::vector::const_iterator FilenamePartsListTypeIter_t; + + //--------------------------------------------------------------------------- + /// @brief + /// Returns a vector of FilenamePartsType_t objects for a given directory + /// @param path + /// Path to scan for files + /// @return + /// True if successful, false otherwise + //--------------------------------------------------------------------------- + static bool getFileInfoList(const std::string &path, FilenamePartsListType_t &filenamePartsList); + + //--------------------------------------------------------------------------- + /// @brief + /// Returns a vector of FilenamePartsType_t objects for a given directory + /// and the child directories inside. + /// @param path + /// Path to directory to scan for files for + /// @note if path is not a directory - the function will return false + /// @param filenamePartList + /// List to append to + /// @param ignoreDirs + /// If this flag is set to true, directories (and symbolic links to directories) + /// are not included in the list. Only actual files below the specified + /// directory path will be appended. + /// @return True if successful, false otherwise + /// @note Directories in list only populate Directory member variable of the struct. + /// That is Basename and Extension will be empty strings. + /// @note Symbolic links to directories are not followed. This is to avoid possible + /// infinite recursion. However the initial call to this method can have + /// path to be a symbolic link to a directory. If ignoreDirs is true, + /// symbolic links to directories are also ignored. + /// @note The order in which the files/directories are listed is platform + /// dependent. However files inside a directory always come before the + /// directory itself. + //--------------------------------------------------------------------------- + static bool getFileInfoListRecursive(const std::string &path, + FilenamePartsListType_t &filenamePartsList, + const bool ignoreDirs); + + //--------------------------------------------------------------------------- + /// @brief + /// Create an absolute path from the supplied path + /// @param path + /// Path should not contain trailing '/' or '\\' + /// @return + /// Return absolute path without trailing '/' or '\\' + //--------------------------------------------------------------------------- + static std::string getAbsolutePath(const std::string &path); + + //--------------------------------------------------------------------------- + /// @brief Get the file name from a path + //--------------------------------------------------------------------------- + static std::string getFileName(const std::string &file); + + //--------------------------------------------------------------------------- + /// @brief Get the directory path to a file + //--------------------------------------------------------------------------- + static std::string getDirectory(const std::string &file); + + //--------------------------------------------------------------------------- + /// @brief Get the current working directory. + /// @returns The absolute CWD or empty string if the path could not be + /// retrieved (because it was too long or deleted for example). + //--------------------------------------------------------------------------- + static std::string getCurrentWorkingDirectory(); + + //--------------------------------------------------------------------------- + /// @brief Set the current working directory + //--------------------------------------------------------------------------- + static bool setCurrentWorkingDirectory(const std::string &workingDir); + + //--------------------------------------------------------------------------- + /// @brief Returns true if the file contains any extension or false. + //--------------------------------------------------------------------------- + static bool hasFileExtension(const std::string &file); + + //--------------------------------------------------------------------------- + /// @brief Returns full path of file, Directory/Basename(.Extension, if any) + //--------------------------------------------------------------------------- + static std::string partsToString(const FilenamePartsType_t &filenameParts); +}; diff --git a/qnn/jni/qnn/PAL/include/PAL/GetOpt.hpp b/qnn/jni/qnn/PAL/include/PAL/GetOpt.hpp new file mode 100644 index 00000000..b1ca35f1 --- /dev/null +++ b/qnn/jni/qnn/PAL/include/PAL/GetOpt.hpp @@ -0,0 +1,93 @@ +//============================================================================== +// +// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +// All rights reserved. +// Confidential and Proprietary - Qualcomm Technologies, Inc. +// +//============================================================================== + +//-------------------------------------------------------------------------------- +/// @file +/// This file includes APIs for the command line parsing on supported platforms +//-------------------------------------------------------------------------------- + +#pragma once + +namespace pal { +// we implement a similar API for POSIX.2 +// so that some global var are necessary + +extern const char *g_optArg; +extern int g_optInd; + +enum { + no_argument = 0, + required_argument = 1, + optional_argument = 2, +}; + +//-------------------------------------------------------------------------------------------------- +/// @brief +/// This structure describes a single long option name for the sake of getopt_long. The argument +/// longopts must be an array of these structures, one for each long option. Terminate the array +/// with an element containing all zeros. +//-------------------------------------------------------------------------------------------------- +struct Option { + //-------------------------------------------------------------------------------------------------- + /// @brief The name of the long option. + //-------------------------------------------------------------------------------------------------- + const char *name; + + //-------------------------------------------------------------------------------------------------- + /// @brief + /// If the option does not take an argument, no_argument (or 0). + /// If the option requires an argument, required_argument (or 1). + //-------------------------------------------------------------------------------------------------- + int hasArg; + + //-------------------------------------------------------------------------------------------------- + /// @brief + /// Specifies how results are returned for a long option. + /// If flag is NULL, then GetOptLongOnly() returns val. Otherwise, it returns 0, and flag + /// points to a variable which is set to val if the option is found, but + /// left unchanged if the option is not found. + //-------------------------------------------------------------------------------------------------- + int *flag; + + //-------------------------------------------------------------------------------------------------- + /// @brief + /// The value to return, or to load into the variable pointed to by flag. + /// The last element of the array has to be filled with zeros. + //-------------------------------------------------------------------------------------------------- + int val; +}; + +//-------------------------------------------------------------------------------------------------- +/// @brief +/// This parses command-line options as POSIX getopt_long_only() +/// but we don't support optstring and optonal_argument now +/// @param argc +/// Argument count +/// @param argv +/// Argument array +/// @param optstring +/// Legitimate option characters, short options, don't support now +/// @param longopts +/// A pointer to the first element of an array of struct option, +/// has_arg field in the struct option indicates 3 possibilities, +/// no_argument, required_argument or optional_argument. we don't +/// support optional_argument now +/// @param longindex +/// If longindex is not NULL, it points to a variable which is set +/// to the index of the long option relative to longopts +/// @return +/// -1 for parsing done, '?' for non-recognized arguments, 0 for +/// flag in longopts is not NULL and saved the val to it +//-------------------------------------------------------------------------------------------------- +int getOptLongOnly(int argc, + const char *const argv[], + const char *optstring, + const struct Option *longopts, + int *longindex); + +} // namespace pal diff --git a/qnn/jni/qnn/PAL/include/PAL/Path.hpp b/qnn/jni/qnn/PAL/include/PAL/Path.hpp new file mode 100644 index 00000000..5f6e6749 --- /dev/null +++ b/qnn/jni/qnn/PAL/include/PAL/Path.hpp @@ -0,0 +1,51 @@ +//============================================================================== +// +// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +// All rights reserved. +// Confidential and Proprietary - Qualcomm Technologies, Inc. +// +//============================================================================== + +//------------------------------------------------------------------------------ +/// @file +/// The file includes APIs for path related operations on supported platforms +//------------------------------------------------------------------------------ + +#pragma once + +#include +#include + +namespace pal { +class Path; +} + +class pal::Path { + public: + //--------------------------------------------------------------------------- + /// @brief Returns path separator for the system + //--------------------------------------------------------------------------- + static char getSeparator(); + + //--------------------------------------------------------------------------- + /// @brief Concatenate s1 and s2 + //--------------------------------------------------------------------------- + static std::string combine(const std::string &s1, const std::string &s2); + + //--------------------------------------------------------------------------- + /// @brief Get the directory name + //--------------------------------------------------------------------------- + static std::string getDirectoryName(const std::string &path); + + //--------------------------------------------------------------------------- + /// @brief Get absolute path + //--------------------------------------------------------------------------- + static std::string getAbsolute(const std::string &path); + + //--------------------------------------------------------------------------- + /// @brief Check if the input path is absolute path + //--------------------------------------------------------------------------- + static bool isAbsolute(const std::string &path); + + private: +}; diff --git a/qnn/jni/qnn/PAL/include/PAL/StringOp.hpp b/qnn/jni/qnn/PAL/include/PAL/StringOp.hpp new file mode 100644 index 00000000..40699328 --- /dev/null +++ b/qnn/jni/qnn/PAL/include/PAL/StringOp.hpp @@ -0,0 +1,60 @@ +//============================================================================== +// +// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +// All rights reserved. +// Confidential and Proprietary - Qualcomm Technologies, Inc. +// +//============================================================================== + +//----------------------------------------------------------------------------- +/// @file +/// The file inludes APIs for string operations on supported platforms +//----------------------------------------------------------------------------- + +#pragma once + +#include + +namespace pal { +class StringOp; +} + +//------------------------------------------------------------------------------ +/// @brief +/// FileOp contains OS Specific file system functionality. +//------------------------------------------------------------------------------ +class pal::StringOp { + public: + //--------------------------------------------------------------------------- + /// @brief + /// Copy copy_size bytes from buffer src to buffer dst. Behaviour of the + /// function is undefined if src and dst overlap. + /// @param dst + /// Destination buffer + /// @param dst_size + /// Size of destination buffer + /// @param src + /// Source buffer + /// @param copy_size + /// Number of bytes to copy + /// @return + /// Number of bytes copied + //--------------------------------------------------------------------------- + static size_t memscpy(void *dst, size_t dstSize, const void *src, size_t copySize); + + //--------------------------------------------------------------------------- + /// @brief + /// Returns a pointer to a null-terminated byte string, which contains copies + /// of at most size bytes from the string pointed to by str. If the null + /// terminator is not encountered in the first size bytes, it is added to the + /// duplicated string. + /// @param source + /// Source string + /// @param maxlen + /// Max number of bytes to copy from str + /// @return + /// A pointer to the newly allocated string, or a null pointer if an error + /// occurred. + //--------------------------------------------------------------------------- + static char *strndup(const char *source, size_t maxlen); +}; diff --git a/qnn/jni/qnn/PAL/meson.build b/qnn/jni/qnn/PAL/meson.build new file mode 100644 index 00000000..2fd419fa --- /dev/null +++ b/qnn/jni/qnn/PAL/meson.build @@ -0,0 +1,30 @@ +pal_headers = [ + 'include/PAL/Debug.hpp', + 'include/PAL/Directory.hpp', + 'include/PAL/DynamicLoading.hpp', + 'include/PAL/FileOp.hpp', + 'include/PAL/GetOpt.hpp', + 'include/PAL/Path.hpp', + 'include/PAL/StringOp.hpp' +] + +pal_sources = [ + 'src/common/GetOpt.cpp', + 'src/common/StringOp.cpp', + 'src/linux/Directory.cpp', + 'src/linux/DynamicLoading.cpp', + 'src/linux/FileOp.cpp', + 'src/linux/Path.cpp' +] + +foreach s : pal_sources + qnn_sources += meson.current_source_dir() / s +endforeach + +qnn_inc_abs += meson.current_source_dir() / 'include' + +qnn_include_dir = quick_dot_ai_prefix / 'include' / 'nntrainer' / 'npu' / 'qnn' + +install_subdir ('PAL', install_dir : qnn_include_dir ) + +install_headers(pal_headers, install_dir : qnn_include_dir / 'PAL') \ No newline at end of file diff --git a/qnn/jni/qnn/PAL/src/common/GetOpt.cpp b/qnn/jni/qnn/PAL/src/common/GetOpt.cpp new file mode 100644 index 00000000..01e01316 --- /dev/null +++ b/qnn/jni/qnn/PAL/src/common/GetOpt.cpp @@ -0,0 +1,154 @@ +//============================================================================== +// +// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +// All rights reserved. +// Confidential and Proprietary - Qualcomm Technologies, Inc. +// +//============================================================================== + +#include + +#include + +#include "PAL/GetOpt.hpp" + +using namespace std; + +namespace pal { + +const char *g_optArg = nullptr; +int g_optInd = 1; + +static const struct Option *findOpt(const string str, + const struct Option *longopts, + int *longindex) { + const struct Option *opt = nullptr; + int idx = 0; + size_t searchEnd = str.find_first_of("="); + + for (opt = longopts; opt->name && strlen(opt->name) > 0; opt++, idx++) { + if (str.substr(0, searchEnd) == opt->name) { + if (longindex) { + *longindex = idx; + } + break; + } + } + // if not found, opt would point to the last element of longopts + // whose name MUST be empty + return opt->name ? opt : nullptr; +} + +int getOptLongOnly(int argc, + const char *const argv[], + const char *, + const struct Option *longopts, + int *longindex) { + const struct Option *opt; + int argLen = 0; + bool isShort = false; + const char *arg = ""; + + g_optArg = nullptr; + // no arg, means the end of command + if (g_optInd >= argc) { + return -1; + } + + arg = argv[g_optInd]; + + if (arg[0] != '-') { + g_optInd += 1; + return '?'; + } + + argLen = strlen(arg); + + if (argLen < 2) { + g_optInd += 1; + return '?'; + } + + if (!longopts) { + g_optInd += 1; + return '?'; + } + + // check short options with this form, -a arg + if (argLen == 2) { + isShort = true; + // check short options with this form, -a=arg + } else if (argLen > 3 && arg[2] == '=') { + isShort = true; + // check for long options, can be used for both forms + } else if (argLen > 2 && arg[1] != '=') { + if (arg[1] != '-') { + g_optInd += 1; + return '?'; + } + isShort = false; + } + + // start after -- to find the option + const char *const optStr = isShort ? &arg[1] : &arg[2]; + opt = findOpt(optStr, longopts, longindex); + if (!opt) { + g_optInd += 1; + return '?'; + } + + if (opt->hasArg == no_argument) { + g_optInd += 1; + + if (!opt->flag) { + return opt->val; + } else { + *(opt->flag) = opt->val; + return 0; + } + } + + if (opt->hasArg == required_argument) { + string optStr = argv[g_optInd]; + size_t assignIdx = optStr.find_first_of("="); + bool advance = (assignIdx == string::npos); + + // if it is --opt arg form, this will be true, + // so we need to advance one step to get arg + // otherwise, need to stop advance step & extract arg from argv[g_optInd] + if (advance) { + g_optInd += 1; + } + + if (g_optInd >= argc) { + return '?'; + } else { + // if advance, means it is the form --opt arg + // otherwise, the form, --opt=arg + if (advance) { + // since g_optInd is advanced, g_optArg can be assigned directly + g_optArg = argv[g_optInd]; + } else { + if (assignIdx == optStr.size()) { + return '?'; + } + // for not advanced form, + // g_optArg should point to the address right after "=" + g_optArg = &argv[g_optInd][assignIdx + 1]; + } + // OK, now we are ready to handle the next pair + g_optInd += 1; + + if (!opt->flag) { + return opt->val; + } else { + *(opt->flag) = opt->val; + return 0; + } + } + } + + return '?'; +} // end of getOptLongOnly + +} // namespace pal diff --git a/qnn/jni/qnn/PAL/src/common/StringOp.cpp b/qnn/jni/qnn/PAL/src/common/StringOp.cpp new file mode 100644 index 00000000..333e4b66 --- /dev/null +++ b/qnn/jni/qnn/PAL/src/common/StringOp.cpp @@ -0,0 +1,63 @@ +//============================================================================== +// +// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +// All rights reserved. +// Confidential and Proprietary - Qualcomm Technologies, Inc. +// +//============================================================================== + +#include +#include + +#include "PAL/StringOp.hpp" + +//--------------------------------------------------------------------------- +// pal::StringOp::memscpy +//--------------------------------------------------------------------------- +size_t pal::StringOp::memscpy(void *dst, size_t dstSize, const void *src, size_t copySize) { + if (!dst || !src || !dstSize || !copySize) return 0; + + size_t minSize = dstSize < copySize ? dstSize : copySize; + + memcpy(dst, src, minSize); + + return minSize; +} + +#ifdef __hexagon__ +size_t strnlen(const char *s, size_t n) { + size_t i; + for (i = 0; i < n && s[i] != '\0'; i++) continue; + return i; +} +#endif + +//--------------------------------------------------------------------------- +// pal::StringOp::strndup +//--------------------------------------------------------------------------- +char *pal::StringOp::strndup(const char *source, size_t maxlen) { +#ifdef _WIN32 + size_t length = ::strnlen(source, maxlen); + + char *destination = (char *)malloc((length + 1) * sizeof(char)); + if (destination == nullptr) return nullptr; + + // copy length bytes to destination and leave destination[length] to be + // null terminator + strncpy_s(destination, length + 1, source, length); + + return destination; +#elif __hexagon__ + size_t length = strnlen(source, maxlen); + + char *destination = (char *)malloc((length + 1) * sizeof(char)); + if (destination == nullptr) return nullptr; + // copy length bytes to destination and leave destination[length] to be + // null terminator + strncpy(destination, source, length); + destination[length] = '\0'; + return destination; +#else + return ::strndup(source, maxlen); +#endif +} diff --git a/qnn/jni/qnn/PAL/src/linux/Directory.cpp b/qnn/jni/qnn/PAL/src/linux/Directory.cpp new file mode 100644 index 00000000..8d7ec679 --- /dev/null +++ b/qnn/jni/qnn/PAL/src/linux/Directory.cpp @@ -0,0 +1,153 @@ +//============================================================================== +// +// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +// All rights reserved. +// Confidential and Proprietary - Qualcomm Technologies, Inc. +// +//============================================================================== + +#include +#include +#include +#ifndef __QNXNTO__ +#include +#endif +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "PAL/Directory.hpp" +#include "PAL/FileOp.hpp" +#include "PAL/Path.hpp" + +//------------------------------------------------------------------------------ +//------------------------------------------------------------------------------ +#ifdef __QNXNTO__ +static bool is_qnx_dir(const struct dirent *ep) { + struct dirent_extra *exp; + bool is_dir = false; + + for (exp = _DEXTRA_FIRST(ep); _DEXTRA_VALID(exp, ep); exp = _DEXTRA_NEXT(exp)) { + if (exp->d_type == _DTYPE_STAT || exp->d_type == _DTYPE_LSTAT) { + struct stat *statbuff = &((dirent_extra_stat *)exp)->d_stat; + if (statbuff && S_ISDIR(statbuff->st_mode)) { + is_dir = true; + break; + } + } + } + return is_dir; +} +#endif + +// ------------------------------------------------------------------------------ +// pal::Directory::create +// ------------------------------------------------------------------------------ +bool pal::Directory::create(const std::string &path, pal::Directory::DirMode dirmode) { + struct stat st; + int status = 0; + if (stat(path.c_str(), &st) != 0) { + // Directory does not exist + status = mkdir(path.c_str(), static_cast(dirmode)); + } else if (!S_ISDIR(st.st_mode)) { + errno = ENOTDIR; + status = -1; + } + return (status == 0); +} + +//------------------------------------------------------------------------------ +//------------------------------------------------------------------------------ +bool pal::Directory::remove(const std::string &dirName) { + DIR *dir; + struct dirent *entry; + + dir = opendir(dirName.c_str()); + if (dir == nullptr) { + // If the directory doesn't exist then just return true. + if (errno == ENOENT) { + return true; + } + return false; + } + +#ifdef __QNXNTO__ + if (dircntl(dir, D_SETFLAG, D_FLAG_STAT) == -1) { + return false; + } +#endif + + // Recursively traverse the directory tree. + while ((entry = readdir(dir)) != nullptr) { + if (strcmp(entry->d_name, ".") && strcmp(entry->d_name, "..")) { + std::stringstream ss; + ss << dirName << Path::getSeparator() << entry->d_name; + std::string path = ss.str(); +#ifdef __QNXNTO__ + if (is_qnx_dir(entry)) +#else + if (entry->d_type == DT_DIR) +#endif + { + // It's a directory so we need to drill down into it and delete + // its contents. + if (!remove(path)) { + return false; + } + } else { + if (::remove(path.c_str())) { + return false; + } + } + } + } + + closedir(dir); + + if (::remove(dirName.c_str())) { + return false; + } + + return true; +} + +bool pal::Directory::makePath(const std::string &path) { + struct stat st; + bool rc = false; + + if (path == ".") { + rc = true; + } else if (stat(path.c_str(), &st) == 0) { + if (st.st_mode & S_IFDIR) { + rc = true; + } + } else { + size_t offset = path.find_last_of(Path::getSeparator()); + if (offset != std::string::npos) { + std::string newPath = path.substr(0, offset); + if (!makePath(newPath)) { + return false; + } + } + + // There is a possible race condition, where a file/directory can be + // created in between the stat() above, and the mkdir() call here. + // So, ignore the return code from the mkdir() call, and then re-check + // for existence of the directory after it. Ensure both that it exists + // and that it is a directory - just like above. + mkdir(path.c_str(), 0777); + + if ((stat(path.c_str(), &st) == 0) && (st.st_mode & S_IFDIR)) { + rc = true; + } + } + + return rc; +} diff --git a/qnn/jni/qnn/PAL/src/linux/DynamicLoading.cpp b/qnn/jni/qnn/PAL/src/linux/DynamicLoading.cpp new file mode 100644 index 00000000..5a563570 --- /dev/null +++ b/qnn/jni/qnn/PAL/src/linux/DynamicLoading.cpp @@ -0,0 +1,75 @@ +//============================================================================== +// +// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +// All rights reserved. +// Confidential and Proprietary - Qualcomm Technologies, Inc. +// +//============================================================================== + +#include +#include + +#include "PAL/Debug.hpp" +#include "PAL/DynamicLoading.hpp" + +void *pal::dynamicloading::dlOpen(const char *filename, int flags) { + int realFlags = 0; + + if (flags & DL_NOW) { + realFlags |= RTLD_NOW; + } + + if (flags & DL_LOCAL) { + realFlags |= RTLD_LOCAL; + } + + if (flags & DL_GLOBAL) { + realFlags |= RTLD_GLOBAL; + } + + return ::dlopen(filename, realFlags); +} + +void *pal::dynamicloading::dlSym(void *handle, const char *symbol) { + if (handle == DL_DEFAULT) { + return ::dlsym(RTLD_DEFAULT, symbol); + } + + return ::dlsym(handle, symbol); +} + +int pal::dynamicloading::dlAddrToLibName(void *addr, std::string &name) { + // Clean the output buffer + name = std::string(); + + // If the address is empty, return zero as treating failure + if (!addr) { + DEBUG_MSG("Input address is nullptr."); + return 0; + } + + // Dl_info do not maintain the lifetime of its string members, + // it would be maintained by dlopen() and dlclose(), + // so we do not need to release it manually + Dl_info info; + int result = ::dladdr(addr, &info); + + // If dladdr() successes, set name to the library name + if (result) { + name = std::string(info.dli_fname); + } else { + DEBUG_MSG("Input address could not be matched to a shared object."); + } + + return result; +} + +int pal::dynamicloading::dlClose(void *handle) { + if (!handle) { + return 0; + } + + return ::dlclose(handle); +} + +char *pal::dynamicloading::dlError(void) { return ::dlerror(); } diff --git a/qnn/jni/qnn/PAL/src/linux/FileOp.cpp b/qnn/jni/qnn/PAL/src/linux/FileOp.cpp new file mode 100644 index 00000000..5ace0f49 --- /dev/null +++ b/qnn/jni/qnn/PAL/src/linux/FileOp.cpp @@ -0,0 +1,369 @@ +//============================================================================== +// +// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +// All rights reserved. +// Confidential and Proprietary - Qualcomm Technologies, Inc. +// +//============================================================================== + +#include +#include +#include +#ifndef __QNXNTO__ +#include +#endif +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "PAL/Debug.hpp" +#include "PAL/FileOp.hpp" +#include "PAL/Path.hpp" + +typedef struct stat Stat_t; + +//--------------------------------------------------------------------------- +// pal::FileOp::HasFileExtension +//--------------------------------------------------------------------------- +bool pal::FileOp::checkFileExists(const std::string& fileName) { + Stat_t sb; + + if (stat(fileName.c_str(), &sb) == -1) { + return false; + } else { + return true; + } +} + +//--------------------------------------------------------------------------- +// pal::FileOp::move +//--------------------------------------------------------------------------- +bool pal::FileOp::move(const std::string& currentName, const std::string& newName, bool overwrite) { + if (overwrite) { + remove(newName.c_str()); + } + return (rename(currentName.c_str(), newName.c_str()) == 0); +} + +//--------------------------------------------------------------------------- +// pal::FileOp::deleteFile +//--------------------------------------------------------------------------- +bool pal::FileOp::deleteFile(const std::string& fileName) { + return (remove(fileName.c_str()) == 0); +} + +//------------------------------------------------------------------------------ +// pal::FileOp::checkIsDir +//------------------------------------------------------------------------------ +bool pal::FileOp::checkIsDir(const std::string& fileName) { + bool retVal = false; + Stat_t sb; + if (stat(fileName.c_str(), &sb) == 0) { + if (sb.st_mode & S_IFDIR) { + retVal = true; + } + } + return retVal; +} + +//------------------------------------------------------------------------------ +// pal::FileOp::getFileInfo +//------------------------------------------------------------------------------ +bool pal::FileOp::getFileInfo(const std::string& filename, + pal::FileOp::FilenamePartsType_t& filenameParts) { + std::string name; + + // Clear the result + filenameParts.basename.clear(); + filenameParts.extension.clear(); + filenameParts.directory.clear(); + + size_t lastPathSeparator = filename.find_last_of(Path::getSeparator()); + if (lastPathSeparator == std::string::npos) { + // No directory + name = filename; + } else { + // has a directory part + filenameParts.directory = filename.substr(0, lastPathSeparator); + name = filename.substr(lastPathSeparator + 1); + } + + size_t ext = name.find_last_of("."); + if (ext == std::string::npos) { + // no extension + filenameParts.basename = name; + } else { + // has extension + filenameParts.basename = name.substr(0, ext); + filenameParts.extension = name.substr(ext + 1); + } + + return true; +} + +//--------------------------------------------------------------------------- +// pal::FileOp::copyOverFile +//--------------------------------------------------------------------------- +bool pal::FileOp::copyOverFile(const std::string& fromFile, const std::string& toFile) { + bool rc = false; + int readFd; + int writeFd; + struct stat statBuf; + + // Open the input file. + readFd = ::open(fromFile.c_str(), O_RDONLY); + if (readFd == -1) { + close(readFd); + return false; + } + + // Stat the input file to obtain its size. */ + if (fstat(readFd, &statBuf) != 0) { + close(readFd); + return false; + } + + // Open the output file for writing, with the same permissions as the input + writeFd = ::open(toFile.c_str(), O_WRONLY | O_CREAT | O_TRUNC, statBuf.st_mode); + if (writeFd == -1) { + close(readFd); + return false; + } + + // Copy the file in a non-kernel specific way */ + char fileBuf[8192]; + ssize_t rBytes, wBytes; + while (true) { + rBytes = read(readFd, fileBuf, sizeof(fileBuf)); + + if (!rBytes) { + rc = true; + break; + } + + if (rBytes < 0) { + rc = false; + break; + } + + wBytes = write(writeFd, fileBuf, (size_t)rBytes); + + if (!wBytes) { + rc = true; + break; + } + + if (wBytes < 0) { + rc = false; + break; + } + } + + /* Close up. */ + close(readFd); + close(writeFd); + return rc; +} + +static bool getFileInfoListRecursiveImpl(const std::string& path, + pal::FileOp::FilenamePartsListType_t& filenamePartsList, + const bool ignoreDirs, + size_t maxDepth) { + struct dirent** namelist = nullptr; + int entryCount = 0; + + // Base case + if (maxDepth == 0) { + return true; + } + +#ifdef __ANDROID__ + // android dirent.h has the wrong signature for alphasort so it had to be disabled or fixed + entryCount = scandir(path.c_str(), &namelist, 0, 0); +#else + entryCount = scandir(path.c_str(), &namelist, 0, alphasort); +#endif + if (entryCount < 0) { + return false; + } else { + while (entryCount--) { + const std::string dName(namelist[entryCount]->d_name); + free(namelist[entryCount]); + + // skip current directory, prev directory and empty string + if (dName.empty() || dName == "." || dName == "..") { + continue; + } + + std::string curPath = path; + curPath += pal::Path::getSeparator(); + curPath += dName; + + // recurse if directory but avoid symbolic links to directories + if (pal::FileOp::checkIsDir(curPath)) { + Stat_t sb; + if (lstat(curPath.c_str(), &sb) == 0 && S_ISDIR(sb.st_mode)) { + if (!getFileInfoListRecursiveImpl(curPath, filenamePartsList, ignoreDirs, maxDepth - 1)) { + return false; + } + } + + if (ignoreDirs) { + continue; + } + + // Append training / to make this path look like a directory for + // getFileInfo() + if (curPath.back() != pal::Path::getSeparator()) { + curPath += pal::Path::getSeparator(); + } + } + + // add to vector + pal::FileOp::FilenamePartsType_t filenameParts; + if (pal::FileOp::getFileInfo(curPath, filenameParts)) { + filenamePartsList.push_back(filenameParts); + } + } + + free(namelist); + } + + return true; +} + +//--------------------------------------------------------------------------- +// pal::FileOp::getFileInfoList +//--------------------------------------------------------------------------- +bool pal::FileOp::getFileInfoList(const std::string& path, + FilenamePartsListType_t& filenamePartsList) { + return getFileInfoListRecursiveImpl(path, filenamePartsList, false, 1); +} + +//--------------------------------------------------------------------------- +// pal::FileOp::getFileInfoListRecursive +//--------------------------------------------------------------------------- +bool pal::FileOp::getFileInfoListRecursive(const std::string& path, + FilenamePartsListType_t& filenamePartsList, + const bool ignoreDirs) { + return getFileInfoListRecursiveImpl( + path, filenamePartsList, ignoreDirs, std::numeric_limits::max()); +} + +//--------------------------------------------------------------------------- +// pal::FileOp::getAbsolutePath +//--------------------------------------------------------------------------- +std::string pal::FileOp::getAbsolutePath(const std::string& path) { + // NOTE: This implementation is broken currently when a path with + // non-existant components is passed! NEO-19723 was created to address. + char absPath[PATH_MAX + 1] = {0}; + + if (realpath(path.c_str(), absPath) == NULL) { + DEBUG_MSG("GetAbsolute path fail! Error code : %d", errno); + return std::string(); + } + return std::string(absPath); +} + +//--------------------------------------------------------------------------- +// pal::FileOp::setCWD +//--------------------------------------------------------------------------- +bool pal::FileOp::setCurrentWorkingDirectory(const std::string& workingDir) { + return chdir(workingDir.c_str()) == 0; +} + +//--------------------------------------------------------------------------- +// pal::FileOp::getDirectory +//--------------------------------------------------------------------------- +std::string pal::FileOp::getDirectory(const std::string& file) { + std::string rc = file; + size_t offset = file.find_last_of(Path::getSeparator()); + if (offset != std::string::npos) { + rc = file.substr(0, offset); + } + return rc; +} + +//--------------------------------------------------------------------------- +// pal::FileOp::getFileName +//--------------------------------------------------------------------------- +std::string pal::FileOp::getFileName(const std::string& file) { + std::string rc = file; + size_t offset = file.find_last_of(Path::getSeparator()); + if (offset != std::string::npos) { + rc = file.substr(offset + 1); // +1 to skip path separator + } + return rc; +} + +//--------------------------------------------------------------------------- +// pal::FileOp::hasFileExtension +//--------------------------------------------------------------------------- +bool pal::FileOp::hasFileExtension(const std::string& file) { + FilenamePartsType_t parts; + getFileInfo(file, parts); + + return !parts.extension.empty(); +} + +//--------------------------------------------------------------------------- +// pal::FileOp::getCWD +//--------------------------------------------------------------------------- +std::string pal::FileOp::getCurrentWorkingDirectory() { + char buffer[PATH_MAX + 1]; + buffer[0] = '\0'; + + // If there is any failure return empty string. It is technically possible + // to handle paths exceeding PATH_MAX on some flavors of *nix but platforms + // like Android (Bionic) do no provide such capability. For consistency we + // will not handle extra long path names. + if (nullptr == getcwd(buffer, PATH_MAX)) { + return std::string(); + } else { + return std::string(buffer); + } +} + +//--------------------------------------------------------------------------- +// pal::FileOp::partsToString +//--------------------------------------------------------------------------- +std::string pal::FileOp::partsToString(const FilenamePartsType_t& filenameParts) { + std::string path; + + if (!filenameParts.directory.empty()) { + path += filenameParts.directory; + path += Path::getSeparator(); + } + if (!filenameParts.basename.empty()) { + path += filenameParts.basename; + } + if (!filenameParts.extension.empty()) { + path += "."; + path += filenameParts.extension; + } + return path; +} + + +//--------------------------------------------------------------------------- +// pal::FileOp::open +//--------------------------------------------------------------------------- +int32_t pal::FileOp::open(const std::string& path, const AccessMode flags, FileMode mode) { + return ::open(path.c_str(), static_cast(flags), static_cast(mode)); +} + +//--------------------------------------------------------------------------- +// pal::FileOp::close +//--------------------------------------------------------------------------- +int32_t pal::FileOp::close(const int32_t fd) { return ::close(fd); } diff --git a/qnn/jni/qnn/PAL/src/linux/Path.cpp b/qnn/jni/qnn/PAL/src/linux/Path.cpp new file mode 100644 index 00000000..ceda4e34 --- /dev/null +++ b/qnn/jni/qnn/PAL/src/linux/Path.cpp @@ -0,0 +1,48 @@ +//============================================================================== +// +// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +// All rights reserved. +// Confidential and Proprietary - Qualcomm Technologies, Inc. +// +//============================================================================== + +#include + +#include +#ifndef PATH_MAX +#include +#endif + +#include "PAL/FileOp.hpp" +#include "PAL/Path.hpp" + +char pal::Path::getSeparator() { return '/'; } + +std::string pal::Path::combine(const std::string &s1, const std::string &s2) { + std::stringstream ss; + ss << s1; + if (s1.size() > 0 && s1[s1.size() - 1] != getSeparator()) { + ss << getSeparator(); + } + ss << s2; + return ss.str(); +} + +std::string pal::Path::getDirectoryName(const std::string &path) { + std::string rc = path; + size_t index = path.find_last_of(pal::Path::getSeparator()); + if (index != std::string::npos) { + rc = path.substr(0, index); + } + return rc; +} + +std::string pal::Path::getAbsolute(const std::string &path) { + // Functionality was duplicated of function in FileOp + // Just call that function directly instead + return pal::FileOp::getAbsolutePath(path); +} + +bool pal::Path::isAbsolute(const std::string &path) { + return path.size() > 0 && path[0] == getSeparator(); +} diff --git a/qnn/jni/qnn/QNN.hpp b/qnn/jni/qnn/QNN.hpp new file mode 100644 index 00000000..3f61030b --- /dev/null +++ b/qnn/jni/qnn/QNN.hpp @@ -0,0 +1,37 @@ +//============================================================================== +// +// Copyright (c) 2020-2023 Qualcomm Technologies, Inc. +// All Rights Reserved. +// Confidential and Proprietary - Qualcomm Technologies, Inc. +// +//============================================================================== + +#pragma once + +#include "QnnInterface.h" +#include "System/QnnSystemInterface.h" +#include "WrapperUtils/QnnWrapperUtils.hpp" + +namespace qnn { +namespace tools { +namespace sample_app { + +// Graph Related Function Handle Types +typedef qnn_wrapper_api::ModelError_t (*ComposeGraphsFnHandleType_t)( + Qnn_BackendHandle_t, QNN_INTERFACE_VER_TYPE, Qnn_ContextHandle_t, + const qnn_wrapper_api::GraphConfigInfo_t **, const uint32_t, + qnn_wrapper_api::GraphInfo_t ***, uint32_t *, bool, QnnLog_Callback_t, + QnnLog_Level_t); +typedef qnn_wrapper_api::ModelError_t (*FreeGraphInfoFnHandleType_t)( + qnn_wrapper_api::GraphInfo_t ***, uint32_t); + +typedef struct QnnFunctionPointers { + ComposeGraphsFnHandleType_t composeGraphsFnHandle; + FreeGraphInfoFnHandleType_t freeGraphInfoFnHandle; + QNN_INTERFACE_VER_TYPE qnnInterface; + QNN_SYSTEM_INTERFACE_VER_TYPE qnnSystemInterface; +} QnnFunctionPointers; + +} // namespace sample_app +} // namespace tools +} // namespace qnn diff --git a/qnn/jni/qnn/QNN/GPU.unused/QnnGpuBackend.h b/qnn/jni/qnn/QNN/GPU.unused/QnnGpuBackend.h new file mode 100644 index 00000000..cf6ad7d9 --- /dev/null +++ b/qnn/jni/qnn/QNN/GPU.unused/QnnGpuBackend.h @@ -0,0 +1,77 @@ +//============================================================================== +// +// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +// All Rights Reserved. +// Confidential and Proprietary - Qualcomm Technologies, Inc. +// +//============================================================================== + +/** + * @file + * @brief A header which defines the QNN GPU specialization of the QnnBackend.h interface. + */ + +#ifndef QNN_GPU_BACKEND_H +#define QNN_GPU_BACKEND_H + +#ifdef __cplusplus +#include +#else +#include +#endif + +#include "QnnBackend.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * @brief This enum defines QNN GPU custom Backend config options. + */ +typedef enum { + /// If non-zero, tuning mode will be enabled + QNN_GPU_BACKEND_CONFIG_OPTION_ENABLE_TUNING_MODE = 0, + /// The Performance cache directory. Must be non-null + QNN_GPU_BACKEND_CONFIG_OPTION_PERFORMANCE_CACHE_DIR = 1, + /// If non-zero, the performance cache will be ignored when initializing + QNN_GPU_BACKEND_CONFIG_OPTION_INVALIDATE_PERFORMANCE_CACHE = 2, + /// If non-zero, weight sharing is disabled + QNN_GPU_BACKEND_CONFIG_OPTION_WEIGHT_SHARING_ENABLED = 3, + /// If non-zero, kernels will not be profiled in tuning mode + QNN_GPU_BACKEND_CONFIG_OPTION_DISABLE_KERNEL_PROFILING = 4, + /// Unused, present to ensure 32 bits. + QNN_GPU_BACKEND_CONFIG_OPTION_UNDEFINED = 0x7FFFFFFF, +} QnnGpuBackend_ConfigOption_t; + +/** + * @brief A struct which defines the QNN GPU Backend custom configuration options. + * Objects of this type are to be referenced through QnnBackend_CustomConfig_t. + */ +typedef struct { + QnnGpuBackend_ConfigOption_t option; + union UNNAMED { + uint8_t enableTuningMode; + const char* performanceCacheDir; + uint8_t invalidatePerformanceCache; + uint8_t weightSharingEnabled; + uint8_t disableKernelProfiling; + }; +} QnnGpuBackend_CustomConfig_t; + +// clang-format off +/// QnnGpuBackend_CustomConfig_t initializer macro +#define QNN_GPU_BACKEND_CUSTOM_CONFIG_INIT \ + { \ + QNN_GPU_BACKEND_CONFIG_OPTION_UNDEFINED, /*option*/ \ + { \ + false /*enableTuningMode*/ \ + } \ + } +// clang-format on + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif diff --git a/qnn/jni/qnn/QNN/GPU.unused/QnnGpuCommon.h b/qnn/jni/qnn/QNN/GPU.unused/QnnGpuCommon.h new file mode 100644 index 00000000..906e33e0 --- /dev/null +++ b/qnn/jni/qnn/QNN/GPU.unused/QnnGpuCommon.h @@ -0,0 +1,49 @@ +//============================================================================== +// +// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +// All Rights Reserved. +// Confidential and Proprietary - Qualcomm Technologies, Inc. +// +//============================================================================== + +/** + * @file + * @brief A header which defines common QNN GPU macros. + */ + +#ifndef QNN_GPU_COMMON_H +#define QNN_GPU_COMMON_H + +#include "QnnCommon.h" + +/// GPU Backend identifier +#define QNN_BACKEND_ID_GPU 4 + +/// GPU interface provider +#define QNN_GPU_INTERFACE_PROVIDER_NAME "GPU_QTI_AISW" + +// GPU API Version values +#define QNN_GPU_API_VERSION_MAJOR 3 +#define QNN_GPU_API_VERSION_MINOR 12 +#define QNN_GPU_API_VERSION_PATCH 0 + +// clang-format off + +/// Macro to set Qnn_ApiVersion_t for GPU backend +#define QNN_GPU_API_VERSION_INIT \ + { \ + { \ + QNN_API_VERSION_MAJOR, /*coreApiVersion.major*/ \ + QNN_API_VERSION_MINOR, /*coreApiVersion.major*/ \ + QNN_API_VERSION_PATCH /*coreApiVersion.major*/ \ + }, \ + { \ + QNN_GPU_API_VERSION_MAJOR, /*backendApiVersion.major*/ \ + QNN_GPU_API_VERSION_MINOR, /*backendApiVersion.minor*/ \ + QNN_GPU_API_VERSION_PATCH /*backendApiVersion.patch*/ \ + } \ + } + +// clang-format on + +#endif // QNN_GPU_COMMON_H diff --git a/qnn/jni/qnn/QNN/GPU.unused/QnnGpuContext.h b/qnn/jni/qnn/QNN/GPU.unused/QnnGpuContext.h new file mode 100644 index 00000000..42599e42 --- /dev/null +++ b/qnn/jni/qnn/QNN/GPU.unused/QnnGpuContext.h @@ -0,0 +1,78 @@ +//============================================================================== +// +// Copyright (c) 2021-2023 Qualcomm Technologies, Inc. +// All Rights Reserved. +// Confidential and Proprietary - Qualcomm Technologies, Inc. +// +//============================================================================== + +/** + * @file + * @brief A header which defines the QNN GPU specialization of the QnnContext.h interface. + */ + +#ifndef QNN_GPU_CONTEXT_H +#define QNN_GPU_CONTEXT_H + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * @brief This enum defines QNN GPU custom context config options. + */ +typedef enum { + /// Sets performance hint options via QnnGpuContext_PerfHint_t + QNN_GPU_CONTEXT_CONFIG_OPTION_PERF_HINT = 0, + /// If non-zero, OpenGL buffers will be used + QNN_GPU_CONTEXT_CONFIG_OPTION_USE_GL_BUFFERS = 1, + /// The kernel disk cache directory. Must be non-null + QNN_GPU_CONTEXT_CONFIG_OPTION_KERNEL_REPO_DIR = 2, + /// If non-zero, the kernel disk cache will be ignored when initializing + QNN_GPU_CONTEXT_CONFIG_OPTION_INVALIDATE_KERNEL_REPO = 3, + /// Unused, present to ensure 32 bits. + QNN_GPU_CONTEXT_CONFIG_OPTION_UNDEFINED = 0x7FFFFFFF +} QnnGpuContext_ConfigOption_t; + +/** + * @brief An enum which defines the different GPU performance hint options. + */ +typedef enum { + /// Sets the GPU performance hint to high performance, this is the default + QNN_GPU_CONTEXT_PERF_HINT_HIGH = 0, + /// Sets the GPU performance hint to normal performance + QNN_GPU_CONTEXT_PERF_HINT_NORMAL = 1, + /// Sets the GPU performance hint to low performance + QNN_GPU_CONTEXT_PERF_HINT_LOW = 2 +} QnnGpuContext_PerfHint_t; + +/** + * @brief A struct which defines the QNN GPU context custom configuration options. + * Objects of this type are to be referenced through QnnContext_CustomConfig_t. + */ +typedef struct { + QnnGpuContext_ConfigOption_t option; + union UNNAMED { + QnnGpuContext_PerfHint_t perfHint; + uint8_t useGLBuffers; + const char* kernelRepoDir; + uint8_t invalidateKernelRepo; + }; +} QnnGpuContext_CustomConfig_t; + +// clang-format off +/// QnnGpuContext_CustomConfig_t initializer macro +#define QNN_GPU_CONTEXT_CUSTOM_CONFIG_INIT \ + { \ + QNN_GPU_CONTEXT_CONFIG_OPTION_UNDEFINED, /*option*/ \ + { \ + QNN_GPU_CONTEXT_PERF_HINT_HIGH /*perfHint*/ \ + } \ + } +// clang-format on + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif diff --git a/qnn/jni/qnn/QNN/GPU.unused/QnnGpuGraph.h b/qnn/jni/qnn/QNN/GPU.unused/QnnGpuGraph.h new file mode 100644 index 00000000..e0652d44 --- /dev/null +++ b/qnn/jni/qnn/QNN/GPU.unused/QnnGpuGraph.h @@ -0,0 +1,72 @@ +//============================================================================== +// +// Copyright (c) 2020-2021 Qualcomm Technologies, Inc. +// All Rights Reserved. +// Confidential and Proprietary - Qualcomm Technologies, Inc. +// +//============================================================================== + +/** + * @file + * @brief A header which defines the QNN GPU specialization of the QnnGraph.h interface. + */ + +#ifndef QNN_GPU_GRAPH_H +#define QNN_GPU_GRAPH_H + +#ifdef __cplusplus +#include +#else +#include +#endif + +#include "QnnGraph.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * @brief An enum which defines the different tensor optimization options. A + * tensor may be optimized to the specified QnnGpu_Precision_t when it + * is a graph tensor that is not a graph input or a graph output and + * does not connect two operations from different op packages. + */ +typedef enum { + /// Sets the precision mode to floating point 32-bit (FP32) + QNN_GPU_PRECISION_FP32 = 0, + /// Sets the precision mode to floating point 16-bit (FP16) + QNN_GPU_PRECISION_FP16 = 1, + /// Sets the precision mode to FP16 for storage and FP32 for calculations + QNN_GPU_PRECISION_HYBRID = 2, + /// Uses the tensor data type provided by the user (default) + QNN_GPU_PRECISION_USER_PROVIDED = 3, +} QnnGpu_Precision_t; + +/** + * @brief A struct which defines the QNN GPU graph custom configuration options. + * Objects of this type are to be referenced through QnnGraph_CustomConfig_t. + */ +typedef struct { + QnnGpu_Precision_t precision; + uint8_t disableMemoryOptimizations; + uint8_t disableNodeOptimizations; + uint8_t disableQueueRecording; +} QnnGpuGraph_CustomConfig_t; + +// clang-format off +/// QnnGpuGraph_CustomConfig_t initializer macro +#define QNN_GPU_GRAPH_CUSTOM_CONFIG_INIT \ + { \ + QNN_GPU_PRECISION_USER_PROVIDED, /*precision*/ \ + 0u, /*disableMemoryOptimizations*/ \ + 0u, /*disableNodeOptimizations*/ \ + 0u /*disableQueueRecording*/ \ + } +// clang-format on + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif diff --git a/qnn/jni/qnn/QNN/GPU.unused/QnnGpuMem.h b/qnn/jni/qnn/QNN/GPU.unused/QnnGpuMem.h new file mode 100644 index 00000000..1c6cd5c3 --- /dev/null +++ b/qnn/jni/qnn/QNN/GPU.unused/QnnGpuMem.h @@ -0,0 +1,52 @@ +//============================================================================== +// +// Copyright (c) 2024 Qualcomm Technologies, Inc. +// All Rights Reserved. +// Confidential and Proprietary - Qualcomm Technologies, Inc. +// +//============================================================================== + +/** + * @file + * @brief A header which defines the QNN GPU specialization of the QnnMem.h interface. + */ + +#ifndef QNN_GPU_MEM_H +#define QNN_GPU_MEM_H + +#ifdef __cplusplus +extern "C" { +#endif + +typedef void* QnnGpuMem_Buffer_t; + +/** + * @brief This enum defines QNN GPU memory type + */ +typedef enum { QNN_GPU_MEM_OPENCL = 0, QNN_GPU_MEM_UNDEFINED = 0x7FFFFFF } QnnGpu_MemType_t; + +/** + * @brief A struct which defines the QNN GPU memory preallocated by the client. + * Objects of this type are to be referenced through Qnn_MemInfoCustom_t. + */ +typedef struct { + QnnGpu_MemType_t memType; + union { + QnnGpuMem_Buffer_t buffer; + }; +} QnnGpu_MemInfoCustom_t; + +// clang-format off +/// QnnGpu_MemInfoCustom_t initializer macro +#define QNN_GPU_MEMINFO_CUSTOM_INIT \ + { \ + QNN_GPU_MEM_UNDEFINED, /*memType*/ \ + NULL /* buffer*/ \ + } +// clang-format on + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif diff --git a/qnn/jni/qnn/QNN/GPU.unused/QnnGpuOpPackage.h b/qnn/jni/qnn/QNN/GPU.unused/QnnGpuOpPackage.h new file mode 100644 index 00000000..7a6a1e34 --- /dev/null +++ b/qnn/jni/qnn/QNN/GPU.unused/QnnGpuOpPackage.h @@ -0,0 +1,703 @@ +//============================================================================== +// +// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +// All Rights Reserved. +// Confidential and Proprietary - Qualcomm Technologies, Inc. +// +//============================================================================== + +/** + * @file + * @brief A header which defines the QNN GPU specialization of the QnnOpPackage.h interface. + */ + +#ifndef QNN_GPU_OP_PACKAGE_H +#define QNN_GPU_OP_PACKAGE_H + +#ifdef __cplusplus +#include +#else +#include +#endif + +#include "GPU/QnnGpuCommon.h" +#include "GPU/QnnGpuGraph.h" +#include "QnnOpPackage.h" +#include "QnnTypes.h" + +#ifdef __cplusplus +extern "C" { +#endif + +//============================================================================= +// QnnOpPackage_GlobalInfrastructure_t specialization. +//============================================================================= + +/** + * @brief A struct which is used to communicate device constant properties + */ +typedef struct { + /// GPU device version string + char deviceVersion[128]; + /// GPU driver interface version {major, minor} + uint32_t interfaceVersion[2]; + /// GPU Adreno(TM) tier string + char tierName[8]; + /// GPU driver version {product, major, minor, patch} + uint32_t compilerVersion[4]; + /// GPU device max work group size + size_t maxWorkGroupSize; + /// GPU device image 1D max width + size_t image1dBufferMaxWidth; + /// GPU device image 2D max width + size_t image2dMaxWidth; + /// GPU device image 2D max height + size_t image2dMaxHeight; + /// GPU device max memory allocation size + size_t maxBufferAllocSize; + /// GPU device addr alignment in bits + uint32_t baseAddrAlignment; + /// GPU device image 2D Array max width + size_t image2dArrayMaxWidth; + /// GPU device image 2D Array max height + size_t image2dArrayMaxHeight; + /// GPU device image 2D Array max depth + size_t image2dArrayMaxDepth; + /// GPU compiler predicate clobber type + bool predicateClobberFullRegister; + /// GPU local memory support type + bool isLocalMemorySupported; + /// GPU compiler vector64 support + bool vector64Support; +} QnnGpu_DeviceProperties_t; + +/** + * @brief A QNN GPU struct specializing QnnOpPackage_GlobalInfrastructure_t + */ +typedef struct _QnnOpPackage_GlobalInfrastructure_t { + /// GPU backend version (as returned by QnnBackend_getApiVersion()) + const Qnn_ApiVersion_t* sdkApiVersion; + /// GPU device properties + const QnnGpu_DeviceProperties_t* deviceProperties; + /// Null terminated path to the OpenCL driver used by the backend + const char* driverPath; +} QnnGpuOpPackage_GlobalInfrastructure_t; + +//============================================================================= +// QnnOpPackage_PackageInfo_t specialization. +//============================================================================= + +/** + * @brief A struct having op package specific information + */ +typedef struct _QnnOpPackage_PackageInfo_t { + /// Null terminated hash key string of all kernel sources + const char* kernelRepoHash; +} QnnGpuOpPackage_PackageInfo_t; + +//============================================================================= +// QnnOpPackage_Optimization_t specialization. +//============================================================================= + +/** + * @brief An enum to specify the QNN GPU optimization type + * + */ +typedef enum { + /// Undefined option only used for QNN_GPU_OP_PACKAGE_OPTIMIZATION_INIT + QNN_GPU_OPTIMIZATION_TYPE_UNDEFINED = 0, + /// Super node optimization + QNN_GPU_OPTIMIZATION_TYPE_SUPER_NODE = 2, +} QnnGpuOpPackage_OptimizationType_t; + +/** + * @brief A struct representing a super node connection constraint. + */ +typedef struct { + /// Producer node corresponding to QnnGpuOpPackage_SuperNodeOptimization_t::operations + uint32_t producer; + /// Output tensor index corresponding to the producer node + uint32_t producerOutputIndex; + /// Consumer node corresponding to QnnGpuOpPackage_SuperNodeOptimization_t::operations + uint32_t consumer; + /// Output tensor index corresponding to the consumer node + uint32_t consumerInputIndex; +} QnnGpuOpPackage_SuperNodeConnectionConstraint_t; + +/** + * @brief An enum to specify the source of a tensor in an op def for a tensor constraint. + * + */ +typedef enum { + /// Tensor is an op def output + QNN_GPU_OPTIMIZATION_SUPER_NODE_TENSOR_SOURCE_OUTPUT = 1, + QNN_GPU_OPTIMIZATION_SUPER_NODE_TENSOR_SOURCE_INPUT = 2, +} QnnGpuOpPackage_TensorConstraintSource_t; + +/** + * @brief An enum to specify the tensor constraint type. + * + */ +typedef enum { + /// Add a Qnn_DataType_t to the whitelist of allowable types. + /// If no data type constraint is present for a tensor, all data types are allowed. + QNN_GPU_OPTIMIZATION_SUPER_NODE_TENSOR_CONSTRAINT_DATA_TYPE = 1, + /// Tensor must match it's rank + QNN_GPU_OPTIMIZATION_SUPER_NODE_TENSOR_CONSTRAINT_RANK = 2, + /// Tensor must match one of it's dimensions + QNN_GPU_OPTIMIZATION_SUPER_NODE_TENSOR_CONSTRAINT_DIMENSION = 3, + /// Add a Qnn_TensorType_t to the whitelist of allowable tensor types. + /// If no tensor type constraint is present for a tensor, all types are allowed. + QNN_GPU_OPTIMIZATION_SUPER_NODE_TENSOR_CONSTRAINT_TENSOR_TYPE = 4, +} QnnGpuOpPackage_TensorConstraintType_t; + +/** + * @brief A struct representing a tensor constraint. + */ +typedef struct { + /// Operation corresponding to QnnGpuOpPackage_SuperNodeOptimization_t::operations + uint32_t operationIndex; + /// Source of the tensor in the Qnn_OpConfig_t + QnnGpuOpPackage_TensorConstraintSource_t source; + union { + /// Tensor index in the Qnn_OpConfig_t, used only for inputs and outputs + uint32_t index; + /// Tensor parameter name in the Qnn_OpConfig_t, used only for parameters + const char* name; + }; + /// Type of tensor constraint + QnnGpuOpPackage_TensorConstraintType_t type; + union { + /// Tensor data type for Qnn_DataType_t constraints + Qnn_DataType_t dataType; + /// Tensor type for Qnn_TensorType_t constraints + Qnn_TensorType_t tensorType; + /// Tensor rank for rank constraints + uint32_t rank; + struct { + /// Tensor dimension index for dimension constraints + uint32_t index; + /// Tensor dimension size for dimension constraints + uint32_t size; + } dimension; + }; +} QnnGpuOpPackage_TensorConstraint_t; + +typedef struct { + /// Null-terminated array of comma separated lists of operations used for matching super node ops. + /// An asterisk (*) may be used to represent any operation type. + const char** operations; + /// Null-terminated array of pointers to super node connection constraints + QnnGpuOpPackage_SuperNodeConnectionConstraint_t** connectionConstraints; + /// Null-terminated array of pointers to super node tensor constraints + QnnGpuOpPackage_TensorConstraint_t** tensorConstraints; +} QnnGpuOpPackage_SuperNodeOptimization_t; + +// clang-format off +/// QnnGpuOpPackage_SuperNodeOptimization_t initializer macro +#define QNN_GPU_OP_PACKAGE_SUPER_NODE_OPTIMIZATION_INIT \ + { \ + NULL, /*operations*/ \ + NULL, /*connectionConstraints*/ \ + NULL, /*tensorConstraints*/ \ + } +// clang-format on + +/** + * @brief A struct representing a QNN GPU optimization. + */ +typedef struct _QnnOpPackage_Optimization_t { + /// Type of optimization + QnnGpuOpPackage_OptimizationType_t type; + /// Op package assigned name of the optimization + const char* name; + union { + /// Super node optimization, used when type is QNN_GPU_OPTIMIZATION_TYPE_SUPER_NODE + const QnnGpuOpPackage_SuperNodeOptimization_t* superNode; + }; +} QnnGpuOpPackage_Optimization_t; + +/// QnnGpuOpPackage_Optimization_t initializer macro +#define QNN_GPU_OP_PACKAGE_OPTIMIZATION_INIT \ + { \ + QNN_GPU_OPTIMIZATION_TYPE_UNDEFINED, NULL, { NULL } \ + } + +//============================================================================= +// QnnOpPackage_GraphInfrastructure_t specialization. +//============================================================================= + +/** + * @brief A QNN GPU struct specializing QnnOpPackage_GraphInfrastructure_t + */ +typedef struct _QnnOpPackage_GraphInfrastructure_t { + /// GPU precision mode, user-supplied hint used for optimal kernel selection + QnnGpu_Precision_t precisionMode; + /// GPU device properties + const QnnGpu_DeviceProperties_t* deviceProperties; +} QnnGpuOpPackage_GraphInfrastructure_t; + +//============================================================================= +// QNN GPU Memory Object +//============================================================================= + +/** + * @brief An enum to specify the QNN GPU memory object type + * + */ +typedef enum { + /// Host memory, only used for Qnn_Param_t tensors + QNN_GPU_MEM_OBJ_TYPE_HOST = 0, + /// GPU driver buffer memory object + QNN_GPU_MEM_OBJ_TYPE_BUFFER = 1, + /// GPU driver image 2D memory object + QNN_GPU_MEM_OBJ_TYPE_IMAGE2D = 2, + /// GPU driver image 2D array memory object + QNN_GPU_MEM_OBJ_TYPE_IMAGE2D_ARRAY = 3, + /// Aggregation of GPU driver image 2D memory objects + QNN_GPU_MEM_OBJ_TYPE_AGGREGATED_IMAGE2D = 4, + /// Aggregation of GPU driver image 2D array memory objects + QNN_GPU_MEM_OBJ_TYPE_AGGREGATED_IMAGE2D_ARRAY = 5, + /// Memory type is unclaimed and can be specified by the op package via the \n + /// QnnGpu_OutputClaim_t struct + QNN_GPU_MEM_OBJ_TYPE_UNCLAIMED = 6, + /// GPU driver image 1D memory object + QNN_GPU_MEM_OBJ_TYPE_IMAGE1D_BUFFER = 7, +} QnnGpu_MemoryObjectType_t; + +/** + * @brief An enum to specify the QNN GPU memory layout + * + */ +typedef enum { + /// HWC layout + QNN_GPU_MEM_LAYOUT_HWC = 0, + /// HCW layout + QNN_GPU_MEM_LAYOUT_HCW = 1, + /// CHW layout + QNN_GPU_MEM_LAYOUT_CHW = 2, + /// C_HWC4 layout + QNN_GPU_MEM_LAYOUT_C_HWC4 = 3, + /// DHWC layout + QNN_GPU_MEM_LAYOUT_DHWC = 4, + /// CDHW layout + QNN_GPU_MEM_LAYOUT_CDHW = 5, + /// Undefined + QNN_GPU_MEM_LAYOUT_UNDEFINED = 0x7FFFFFFF, +} QnnGpu_MemoryLayout_t; + +/** + * @brief A struct to specify blockSize for weight Tensor and tensorId for weight Param tensor + */ +typedef struct { + // Block Quantization, block Sizes + uint32_t* bqBlockSize; + /// Tensor Id for Quantization encodings + uint32_t bqEncodingTensorId; +} QnnGpu_BlockEncodingInfo_t; + +// clang-format off +/// QnnGpu_MemoryObject_t initializer macro +#define QNN_GPU_BLOCK_ENCODING_INFO_INIT \ + { \ + NULL, /*bqBlockSize*/ \ + 0u /*bqEncodingTensorId*/ \ + } +// clang-format on + +/** + * @brief A QNN GPU struct specifying a memory object + * This struct is used with the following kernel argument types: + * - QNN_GPU_KERNEL_ARG_TYPE_OP_INPUT_READ + * - QNN_GPU_KERNEL_ARG_TYPE_OP_INPUT_READWRITE + * - QNN_GPU_KERNEL_ARG_TYPE_OP_OUTPUT_WRITE + * - QNN_GPU_KERNEL_ARG_TYPE_INTERNAL_READ + * - QNN_GPU_KERNEL_ARG_TYPE_INTERNAL_READWRITE + * - QNN_GPU_KERNEL_ARG_TYPE_INTERNAL_WRITE + */ +typedef struct { + /// Type of memory object + QnnGpu_MemoryObjectType_t type; + /// Data type of the memory object + Qnn_DataType_t dataType; + /// Memory object dimensions \n + /// Size is numDimensions. Uses the following type dependent format: \n + /// QNN_GPU_MEM_OBJ_TYPE_BUFFER -> {numElements} \n + /// QNN_GPU_MEM_OBJ_TYPE_IMAGE2D -> {height,width} \n + /// QNN_GPU_MEM_OBJ_TYPE_IMAGE2D_ARRAY -> {height,width,array_size} \n + /// QNN_GPU_MEM_OBJ_TYPE_AGGREGATED_IMAGE2D -> {num_batches,height,width} \n + /// QNN_GPU_MEM_OBJ_TYPE_AGGREGATED_IMAGE2D_ARRAY -> {num_batches,height,width,array_size} + uint32_t* dimensions; + /// Memory object offsets \n + /// Size is numDimensions. \n + /// Indicates where the data store starts in the memory object. \n + uint32_t* offsets; + /// Number of dimensions in memory object \n + /// Size is numDimensions. Has the following type dependent size: \n + /// QNN_GPU_MEM_OBJ_TYPE_BUFFER -> 1 \n + /// QNN_GPU_MEM_OBJ_TYPE_IMAGE2D -> 2 \n + /// QNN_GPU_MEM_OBJ_TYPE_IMAGE2D_ARRAY -> 3 \n + /// QNN_GPU_MEM_OBJ_TYPE_AGGREGATED_IMAGE2D -> 3 \n + /// QNN_GPU_MEM_OBJ_TYPE_AGGREGATED_IMAGE2D_ARRAY -> 4 + uint32_t numDimensions; + /// Memory object layout \n + /// Op package specific layout identifier \n + /// Default is QNN_GPU_MEM_LAYOUT_UNDEFINED if not already specified by a prior operation + QnnGpu_MemoryLayout_t layout; + /// Block Quantization Tensor Information + QnnGpu_BlockEncodingInfo_t blockEncodingInfo; + /// Memory object name used to propagate the tensor name to Backend + const char* name; +} QnnGpu_MemoryObject_t; + +// clang-format off +/// QnnGpu_MemoryObject_t initializer macro +#define QNN_GPU_MEMORY_OBJECT_INIT \ + { \ + QNN_GPU_MEM_OBJ_TYPE_UNCLAIMED, /*type*/ \ + QNN_DATATYPE_UNDEFINED, /*dataType*/ \ + NULL, /*dimensions*/ \ + NULL, /*offsets*/ \ + 0u, /*numDimensions*/ \ + QNN_GPU_MEM_LAYOUT_UNDEFINED, /*layout*/ \ + QNN_GPU_BLOCK_ENCODING_INFO_INIT, /*blockEncodingInfo*/ \ + NULL /*name*/ \ + } +// clang-format on + +//============================================================================= +// QnnOpPackage_Node_t specialization. +//============================================================================= + +/** + * @brief A QNN GPU struct specifying a storage tensor + */ +typedef struct { + /// Tensor ID + uint32_t id; + /// Tensor's associated memory object + const QnnGpu_MemoryObject_t* memoryObject; +} QnnGpu_TensorStorageType_t; + +// clang-format off +/// QnnGpu_TensorStorageType_t initializer macro +#define QNN_GPU_TENSOR_STORAGE_TYPE_INIT \ + { \ + 0u, /*id*/ \ + NULL /*memoryObject*/ \ + } +// clang-format on + +/** + * @brief A QNN GPU struct specializing QnnOpPackage_Node_t + */ +typedef struct _QnnOpPackage_Node_t { + /// Optimization index, see QnnOpPackage_Info_t, ignore when only one op config provided + uint32_t optimization; + /// Null-terminated array of operation config pointers + /// Only one pointer provided when no optimizations performed + const Qnn_OpConfig_t** configs; + /// Null-terminated array of tensor storage type pointers called out in the config + const QnnGpu_TensorStorageType_t** storageTypes; + /// Kernel variant index, if set then used by OpPackage to determine kernel selection + int32_t kernelVariant; +} QnnGpuOpPackage_Node_t; + +//============================================================================= +// QnnOpPackage_OpImpl_t specialization. +//============================================================================= + +/** + * @brief A QNN GPU struct specifying an output tensor claim. Using the principle + * of least work, operations must output a memory object type that is most + * convenient for itself. Only QNN_TENSOR_TYPE_NATIVE tensor types may + * be claimed. + */ +typedef struct { + /// Index into the Qnn_OpConfig_t provided in QnnGpuOpPackage_Node_t + uint32_t opConfigIndex; + /// Index into the operation outputs to identify the tensor + uint32_t outputIndex; + /// Specification of the claimed memory object + const QnnGpu_MemoryObject_t* memoryObject; +} QnnGpu_OutputClaim_t; + +// clang-format off +/// QnnGpu_OutputClaim_t initializer macro +#define QNN_GPU_OUTPUT_CLAIM_INIT \ + { \ + 0u, /*opConfigIndex*/ \ + 0u, /*outputIndex*/ \ + NULL /*memoryObject*/ \ + } +// clang-format on + +/** + * @brief An enum to specify the kernel argument type. + * + */ +typedef enum { + /// Operation input tensor used as kernel input + QNN_GPU_KERNEL_ARG_TYPE_OP_INPUT_READ = 0, + /// Operation input tensor used as kernel output + QNN_GPU_KERNEL_ARG_TYPE_OP_INPUT_READWRITE = 1, + /// Operation output tensor used as kernel output + QNN_GPU_KERNEL_ARG_TYPE_OP_OUTPUT_WRITE = 2, + /// Operation internal tensor used as kernel input + QNN_GPU_KERNEL_ARG_TYPE_INTERNAL_READ = 3, + /// Operation internal tensor used as kernel input/output + QNN_GPU_KERNEL_ARG_TYPE_INTERNAL_READWRITE = 4, + /// Operation internal tensor used as kernel output + QNN_GPU_KERNEL_ARG_TYPE_INTERNAL_WRITE = 5, + /// Plain old data kernel argument + QNN_GPU_KERNEL_ARG_TYPE_DATA = 6, + /// Local memory kernel argument + QNN_GPU_KERNEL_ARG_TYPE_LOCAL = 7, + /// Null pointer kernel argument + QNN_GPU_KERNEL_ARG_TYPE_NULL_PTR = 8, + /// Operation tensor parameter used as kernel input + QNN_GPU_KERNEL_ARG_TYPE_OP_TENSOR_PARAM = 9, +} QnnGpu_KernelArgType_t; + +/** + * @brief A QNN GPU struct specifying a kernel argument corresponding to a tensor. + * This struct is used with the following kernel argument types: + * - QNN_GPU_KERNEL_ARG_TYPE_OP_INPUT_READ + * - QNN_GPU_KERNEL_ARG_TYPE_OP_INPUT_READWRITE + * - QNN_GPU_KERNEL_ARG_TYPE_OP_OUTPUT_WRITE + * - QNN_GPU_KERNEL_ARG_TYPE_INTERNAL_READ + * - QNN_GPU_KERNEL_ARG_TYPE_INTERNAL_READWRITE + * - QNN_GPU_KERNEL_ARG_TYPE_INTERNAL_WRITE + */ +typedef struct { + /// Index into the Qnn_OpConfig_t provided in QnnGpuOpPackage_Node_t, ignored for INTERNAL types + uint32_t opConfigIndex; + /// Index into the operation input ot output list or the internal tensor list + uint32_t tensorIndex; + /// Batch element index for aggregated tensor types + uint32_t element; +} QnnGpu_TensorKernelArg_t; + +// clang-format off +/// QnnGpu_TensorKernelArg_t initializer macro +#define QNN_GPU_TENSOR_KERNEL_ARG_INIT \ + { \ + 0u, /*opConfigIndex*/ \ + 0u, /*tensorIndex*/ \ + 0u /*element*/ \ + } +// clang-format on + +/** + * @brief An enum to specify the kernel data argument type. + * + */ +typedef enum { + QNN_GPU_KERNEL_ARG_CL_TYPE_CHAR = 0, + QNN_GPU_KERNEL_ARG_CL_TYPE_UCHAR = 1, + QNN_GPU_KERNEL_ARG_CL_TYPE_SHORT = 2, + QNN_GPU_KERNEL_ARG_CL_TYPE_USHORT = 3, + QNN_GPU_KERNEL_ARG_CL_TYPE_INT = 4, + QNN_GPU_KERNEL_ARG_CL_TYPE_UINT = 5, + QNN_GPU_KERNEL_ARG_CL_TYPE_LONG = 6, + QNN_GPU_KERNEL_ARG_CL_TYPE_ULONG = 7, + QNN_GPU_KERNEL_ARG_CL_TYPE_FLOAT = 8, + QNN_GPU_KERNEL_ARG_CL_TYPE_DOUBLE = 9, +} QnnGpu_DataKernelArgType_t; + +/** + * @brief A QNN GPU struct specifying a kernel argument corresponding to a plain old data. + * This struct is used only with the QNN_GPU_KERNEL_ARG_TYPE_DATA arg type. + */ +typedef struct { + /// Data type of the data + QnnGpu_DataKernelArgType_t type; + union { + /// Used with QNN_GPU_KERNEL_ARG_CL_TYPE_CHAR + int8_t qnnChar; + /// Used with QNN_GPU_KERNEL_ARG_CL_TYPE_UCHAR + uint8_t qnnUChar; + /// Used with QNN_GPU_KERNEL_ARG_CL_TYPE_SHORT + int16_t qnnShort; + /// Used with QNN_GPU_KERNEL_ARG_CL_TYPE_USHORT + uint16_t qnnUShort; + /// Used with QNN_GPU_KERNEL_ARG_CL_TYPE_INT + int32_t qnnInt; + /// Used with QNN_GPU_KERNEL_ARG_CL_TYPE_UINT + uint32_t qnnUInt; + /// Used with QNN_GPU_KERNEL_ARG_CL_TYPE_LONG + int64_t qnnLong; + /// Used with QNN_GPU_KERNEL_ARG_CL_TYPE_ULONG + uint64_t qnnULong; + /// Used with QNN_GPU_KERNEL_ARG_CL_TYPE_FLOAT + float qnnFloat; + /// Used with QNN_GPU_KERNEL_ARG_CL_TYPE_DOUBLE + double qnnDouble; + }; +} QnnGpu_DataKernelArg_t; + +/// QnnGpu_DataKernelArg_t initializer macro +#define QNN_GPU_DATA_KERNEL_ARG_INIT \ + { \ + QNN_GPU_KERNEL_ARG_CL_TYPE_CHAR, /*type*/ \ + { \ + 0 /*qnnChar*/ \ + } \ + } + +/** + * @brief A QNN GPU struct specifying a kernel argument corresponding to a local memory type. + * This struct is used only with the QNN_GPU_KERNEL_ARG_TYPE_LOCAL arg type. + */ +typedef struct { + /// Size of the memory requested in bytes + uint32_t size; +} QnnGpu_LocalKernelArg_t; + +/// QnnGpu_LocalKernelArg_t initializer macro +#define QNN_GPU_LOCAL_KERNEL_ARG_INIT \ + { 0u /*size*/ } + +/** + * @brief A QNN GPU struct specifying a kernel argument. + * Note that the QNN_GPU_KERNEL_ARG_TYPE_NULL_PTR type does not have an entry in + * the union. + */ +typedef struct { + /// Type of kernel argument + QnnGpu_KernelArgType_t type; + union { + /// Tensor type argument + QnnGpu_TensorKernelArg_t tensor; + /// Plain old data argument + QnnGpu_DataKernelArg_t data; + /// Local memory argument + QnnGpu_LocalKernelArg_t local; + }; +} QnnGpu_KernelArg_t; + +/// QnnGpu_KernelArg_t initializer macro +#define QNN_GPU_KERNEL_ARG_INIT \ + { \ + QNN_GPU_KERNEL_ARG_TYPE_NULL_PTR, /*type*/ \ + { \ + QNN_GPU_TENSOR_KERNEL_ARG_INIT /*tensor*/ \ + } \ + } + +/** + * @brief An enum to specify the kernel source type. + * + */ +typedef enum { + QNN_GPU_KERNEL_SOURCE_TYPE_TEXT = 0, + QNN_GPU_KERNEL_SOURCE_TYPE_BINARY = 1, +} QnnGpu_KernelSourceType_t; + +/** + * @brief This enum defines QNN GPU kernel tuning options. + */ +typedef enum { + /// local work size tuning + QNN_GPU_KERNEL_TUNING_LOCAL_WORK_SIZE = 0, + QNN_GPU_KERNEL_TUNING_UNDEFINED = 0x7FFFFFFF +} QnnGpu_KernelTuningOption_t; + +/** + * @brief This struct provides local-work-size tuning configuration. + */ +typedef struct { + uint32_t minValue[3]; + uint32_t maxValue[3]; + uint32_t stepSize[3]; +} QnnGpu_KernelLocalWorkSizeTuning_t; + +/** + * @brief This struct provides QNN GPU kernel tuning configuration. + */ +typedef struct { + QnnGpu_KernelTuningOption_t option; + union UNNAMED { + QnnGpu_KernelLocalWorkSizeTuning_t lws; + }; +} QnnGpu_KernelTuningConfig_t; + +/** + * @brief A QNN GPU struct specifying a kernel. + */ +typedef struct { + /// Kernel source code or binary + const void* kernelSource; + /// Length of kernel source/binary in bytes + size_t sourceLength; + /// Type of kernel source + QnnGpu_KernelSourceType_t sourceType; + /// Null terminated build options string used for kernel compilation + const char* buildOptions; + /// Rank of the globalWorkSizes + size_t globalWorkDim; + /// Global work sizes used by enqueuing the kernel + size_t globalWorkSizes[3]; + /// Rank of the localWorkSizes + size_t localWorkDim; + /// Local work sizes used by enqueuing the kernel + size_t localWorkSizes[3]; + /// Null-terminated array of kernel arguments in the order they appear in the kernel function + QnnGpu_KernelArg_t** args; + /// Null terminated name of the kernel + const char* name; + /// If non-zero, kernel will be enqueued during execute even if it is static + uint32_t isDynamic; + /// Null-terminated array to provide kernel tuning configurations. + QnnGpu_KernelTuningConfig_t** tuningConfigs; + /// Reserved field, must be null + void* reserved; +} QnnGpu_Kernel_t; + +// clang-format off +/// QnnGpu_Kernel_t initializer macro +#define QNN_GPU_KERNEL_INIT \ + { \ + NULL, /*kernelSource*/ \ + 0u, /*sourceLength*/ \ + QNN_GPU_KERNEL_SOURCE_TYPE_TEXT, /*sourceType*/ \ + NULL, /*buildOptions*/ \ + 0u, /*globalWorkDim*/ \ + {0u}, /*globalWorkSizes*/ \ + 0u, /*localWorkDim*/ \ + {0u}, /*localWorkSizes*/ \ + NULL, /*args*/ \ + NULL, /*name*/ \ + 0u, /*isDynamic*/ \ + NULL, /*tuningConfigs*/ \ + NULL /*reserved*/ \ + } +// clang-format on + +/** + * @brief A QNN GPU struct specifying an operation. + */ +typedef struct _QnnOpPackage_OpImpl_t { + /// Null-terminated array of output claims + QnnGpu_OutputClaim_t** outputClaims; + /// Null-terminated array of tensor requests + QnnGpu_MemoryObject_t** memoryObjects; + /// Null-terminated array of kernels + QnnGpu_Kernel_t** kernels; +} QnnGpu_Operation_t; + +// clang-format off +/// QnnGpu_Operation_t initializer macro +#define QNN_GPU_OPERATION_INIT \ + { \ + NULL, /*outputClaims*/ \ + NULL, /*memoryObjects*/ \ + NULL, /*kernels*/ \ + } +// clang-format on + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif diff --git a/qnn/jni/qnn/QNN/GPU.unused/meson.build b/qnn/jni/qnn/QNN/GPU.unused/meson.build new file mode 100644 index 00000000..8550c460 --- /dev/null +++ b/qnn/jni/qnn/QNN/GPU.unused/meson.build @@ -0,0 +1,13 @@ +gpu_headers=[ +'QnnGpuBackend.h', +'QnnGpuCommon.h', +'QnnGpuContext.h', +'QnnGpuGraph.h', +'QnnGpuOpPackage.h', +] + +qnn_dir_include_dir = quick_dot_ai_prefix / 'include' / 'nntrainer' / 'npu' / 'qnn' / 'QNN' + +install_subdir ('GPU', install_dir : qnn_dir_include_dir ) + +install_headers(gpu_headers, install_dir : qnn_dir_include_dir / 'GPU') \ No newline at end of file diff --git a/qnn/jni/qnn/QNN/HTP/QnnHtpCommon.h b/qnn/jni/qnn/QNN/HTP/QnnHtpCommon.h new file mode 100644 index 00000000..ac193199 --- /dev/null +++ b/qnn/jni/qnn/QNN/HTP/QnnHtpCommon.h @@ -0,0 +1,98 @@ +//============================================================================= +// +// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +// All Rights Reserved. +// Confidential and Proprietary - Qualcomm Technologies, Inc. +// +//============================================================================= + +/** @file + * @brief QNN HTP Common components + * + * This file defines versioning and other identification details + * and supplements QnnCommon.h for HTP backend + */ + +#ifndef QNN_HTP_COMMON_H +#define QNN_HTP_COMMON_H + +#include "QnnCommon.h" + +/// HTP Backend identifier +#define QNN_BACKEND_ID_HTP 6 + +/// HTP interface provider +#define QNN_HTP_INTERFACE_PROVIDER_NAME "HTP_QTI_AISW" + +// HTP API Version values +#define QNN_HTP_API_VERSION_MAJOR 5 +#define QNN_HTP_API_VERSION_MINOR 41 +#define QNN_HTP_API_VERSION_PATCH 0 + +// clang-format off + +/// Macro to set Qnn_ApiVersion_t for HTP backend +#define QNN_HTP_API_VERSION_INIT \ + { \ + { \ + QNN_API_VERSION_MAJOR, /*coreApiVersion.major*/ \ + QNN_API_VERSION_MINOR, /*coreApiVersion.major*/ \ + QNN_API_VERSION_PATCH /*coreApiVersion.major*/ \ + }, \ + { \ + QNN_HTP_API_VERSION_MAJOR, /*backendApiVersion.major*/ \ + QNN_HTP_API_VERSION_MINOR, /*backendApiVersion.minor*/ \ + QNN_HTP_API_VERSION_PATCH /*backendApiVersion.patch*/ \ + } \ + } + +// clang-format on + +// DSP Context blob Version values +#define QNN_HTP_CONTEXT_BLOB_VERSION_MAJOR 3 +#define QNN_HTP_CONTEXT_BLOB_VERSION_MINOR 3 +#define QNN_HTP_CONTEXT_BLOB_VERSION_PATCH 4 + +/* ==== CDSP Security Library Versioning ==== */ +/* ==== This information is only intended for OEMs ==== */ + +/* Security versioning for DSP libraries is supported V73 onwards */ +#define QNN_HTP_NATIVE_LIB_SECURITY_VERSIONING_MIN_ARCH 73 + +/* Here we will define CDSP library versions for different targets + * Version is increased whenever there is a security fix from CDSP + * The versioning will start from 1.0.0 for each new target + * */ + +/* V73 Security Issues: + * List of security issues fixed for V73 and the fixed version + * */ +#define QNN_HTP_V73_NATIVE_LIB_SECURITY_VERSION_MAJOR 1 +#define QNN_HTP_V73_NATIVE_LIB_SECURITY_VERSION_MINOR 0 +#define QNN_HTP_V73_NATIVE_LIB_SECURITY_VERSION_PATCH 0 + +/* V75 Security Issues: + * List of security issues fixed for V75 and the fixed version + * */ +// HTP Native library version values for V75 +#define QNN_HTP_V75_NATIVE_LIB_SECURITY_VERSION_MAJOR 1 +#define QNN_HTP_V75_NATIVE_LIB_SECURITY_VERSION_MINOR 0 +#define QNN_HTP_V75_NATIVE_LIB_SECURITY_VERSION_PATCH 0 + +/* V79 Security Issues: + * List of security issues fixed for V79 and the fixed version + * */ +// HTP Native library version values for V79 +#define QNN_HTP_V79_NATIVE_LIB_SECURITY_VERSION_MAJOR 1 +#define QNN_HTP_V79_NATIVE_LIB_SECURITY_VERSION_MINOR 0 +#define QNN_HTP_V79_NATIVE_LIB_SECURITY_VERSION_PATCH 0 + +/* V81 Security Issues: + * List of security issues fixed for V81 and the fixed version + * */ +// HTP Native library version values for V81 +#define QNN_HTP_V81_NATIVE_LIB_SECURITY_VERSION_MAJOR 1 +#define QNN_HTP_V81_NATIVE_LIB_SECURITY_VERSION_MINOR 0 +#define QNN_HTP_V81_NATIVE_LIB_SECURITY_VERSION_PATCH 0 + +#endif // QNN_HTP_COMMON_H diff --git a/qnn/jni/qnn/QNN/HTP/QnnHtpContext.h b/qnn/jni/qnn/QNN/HTP/QnnHtpContext.h new file mode 100644 index 00000000..d14cd563 --- /dev/null +++ b/qnn/jni/qnn/QNN/HTP/QnnHtpContext.h @@ -0,0 +1,246 @@ +//============================================================================== +// +// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +// All rights reserved. +// Confidential and Proprietary - Qualcomm Technologies, Inc.s +// +//============================================================================== + +/** + * @file + * @brief QNN HTP component Context API. + * + * The interfaces in this file work with the top level QNN + * API and supplements QnnContext.h for HTP backend + */ + +#ifndef QNN_HTP_CONTEXT_H +#define QNN_HTP_CONTEXT_H + +#include "QnnContext.h" + +#ifdef __cplusplus +extern "C" { +#endif + +//============================================================================= +// Macros +//============================================================================= + +//============================================================================= +// Data Types +//============================================================================= + +/** + * @brief This enum provides different HTP context configuration + * options associated with QnnContext + */ +typedef enum { + QNN_HTP_CONTEXT_CONFIG_OPTION_WEIGHT_SHARING_ENABLED = 1, + QNN_HTP_CONTEXT_CONFIG_OPTION_REGISTER_MULTI_CONTEXTS = 2, + QNN_HTP_CONTEXT_CONFIG_OPTION_FILE_READ_MEMORY_BUDGET = 3, + QNN_HTP_CONTEXT_CONFIG_OPTION_DSP_MEMORY_PROFILING_ENABLED = 4, + QNN_HTP_CONTEXT_CONFIG_OPTION_SHARE_RESOURCES = 5, + QNN_HTP_CONTEXT_CONFIG_OPTION_IO_MEM_ESTIMATION = 6, + QNN_HTP_CONTEXT_CONFIG_OPTION_PREPARE_ONLY = 7, + QNN_HTP_CONTEXT_CONFIG_OPTION_INIT_ACCELERATION = 8, + QNN_HTP_CONTEXT_CONFIG_OPTION_SKIP_VALIDATION_ON_BINARY_SECTION = 9, + QNN_HTP_CONTEXT_CONFIG_OPTION_SHARE_RESOURCES_OPTIMIZATION_TYPE = 10, + QNN_HTP_CONTEXT_CONFIG_OPTION_USE_EXTENDED_UDMA = 11, + QNN_HTP_CONTEXT_CONFIG_OPTION_REGISTER_CONCURRENT_RESOURCE_SHARING = 12, + QNN_HTP_CONTEXT_CONFIG_OPTION_LORA_WEIGHT_SHARING_ENABLED = 13, + QNN_HTP_CONTEXT_CONFIG_OPTION_RESERVED_14 = 14, + QNN_HTP_CONTEXT_CONFIG_OPTION_RESERVED_15 = 15, + QNN_HTP_CONTEXT_CONFIG_OPTION_UNKNOWN = 0x7fffffff +} QnnHtpContext_ConfigOption_t; + +typedef struct { + // Handle referring to the first context associated to a group. When a new + // group is to be registered, the following value must be 0. + Qnn_ContextHandle_t firstGroupHandle; + // Max spill-fill buffer to be allocated for the group of context in bytes. + // The value that is passed during the registration of the first context to + // a group is taken. Subsequent configuration of this value is disregarded. + uint64_t maxSpillFillBuffer; +} QnnHtpContext_GroupRegistration_t; + +// This enum is supported only with the QnnContext_createFromBinaryListAsync API, when +// shareResources is true; otherwise, it is ignored. This enumeration allows users to specify how +// graphs are going to be executed, providing QNN with hints for optimizing memory. +typedef enum { + // Default value if no user input is provided. + // This type is used for sequential graph execution, optimizing both VA and memory. + SEQUENTIAL_WITH_VA_OPTIMIZATION, + // This type is used for sequential graph execution, optimizing memory. + SEQUENTIAL_WITHOUT_VA_OPTIMIZATION, + // This type is used for concurrent resource sharing, optimizing memory by sharing + // resources across contexts with the same priority level. + CONCURRENT_OPTIMIZATION, +} QnnHtpContext_ShareResourcesOptimizationType_t; + +//============================================================================= +// Public Functions +//============================================================================= + +//------------------------------------------------------------------------------ +// Implementation Definition +//------------------------------------------------------------------------------ + +// clang-format off + +/** + * @brief Structure describing the set of configurations supported by context. + * Objects of this type are to be referenced through QnnContext_CustomConfig_t. + * + * The struct has two fields - option and a union of config values + * Based on the option corresponding item in the union can be used to specify + * config. + * + * Below is the Map between QnnHtpContext_CustomConfig_t and config value + * + * \verbatim embed:rst:leading-asterisk + * +----+---------------------------------------------------------------------+--------------------------------------------------+ + * | # | Config Option | Configuration Struct/value | + * +====+=====================================================================+==================================================+ + * | 1 | QNN_HTP_CONTEXT_CONFIG_OPTION_WEIGHT_SHARING_ENABLED | bool | + * +====+=====================================================================+==================================================+ + * | 2 | QNN_HTP_CONTEXT_CONFIG_OPTION_REGISTER_MULTI_CONTEXTS | QnnHtpContext_GroupRegistration_t | + * +====+=====================================================================+==================================================+ + * | 3 | QNN_HTP_CONTEXT_CONFIG_OPTION_FILE_READ_MEMORY_BUDGET | uint64_t | + * +====+=====================================================================+==================================================+ + * | 4 | QNN_HTP_CONTEXT_CONFIG_OPTION_DSP_MEMORY_PROFILING_ENABLED | bool | + * +====+=====================================================================+==================================================+ + * | 5 | QNN_HTP_CONTEXT_CONFIG_OPTION_SHARE_RESOURCES | bool | + * +----+---------------------------------------------------------------------+--------------------------------------------------+ + * | 6 | QNN_HTP_CONTEXT_CONFIG_OPTION_IO_MEM_ESTIMATION | bool | + * +----+---------------------------------------------------------------------+--------------------------------------------------+ + * | 7 | QNN_HTP_CONTEXT_CONFIG_OPTION_PREPARE_ONLY | bool | + * +----+---------------------------------------------------------------------+--------------------------------------------------+ + * | 8 | QNN_HTP_CONTEXT_CONFIG_OPTION_INIT_ACCELERATION | bool | + * +----+---------------------------------------------------------------------+--------------------------------------------------+ + * | 9 | QNN_HTP_CONTEXT_CONFIG_OPTION_SKIP_VALIDATION_ON_BINARY_SECTION | bool | + * +----+---------------------------------------------------------------------+--------------------------------------------------+ + * | 10 | QNN_HTP_CONTEXT_CONFIG_OPTION_SHARE_RESOURCES_OPTIMIZATION_TYPE | QnnHtpContext_ShareResourcesOptimizationType_t | + * +----+---------------------------------------------------------------------+--------------------------------------------------+ + * | 11 | QNN_HTP_CONTEXT_CONFIG_OPTION_USE_EXTENDED_UDMA | bool | + * +----+---------------------------------------------------------------------+--------------------------------------------------+ + * | 12 | QNN_HTP_CONTEXT_CONFIG_OPTION_REGISTER_CONCURRENT_RESOURCE_SHARING | QnnHtpContext_GroupRegistration_t | + * +----+---------------------------------------------------------------------+--------------------------------------------------+ + * | 13 | QNN_HTP_CONTEXT_CONFIG_OPTION_LORA_WEIGHT_SHARING_ENABLED | bool | + * +----+---------------------------------------------------------------------+--------------------------------------------------+ + * \endverbatim + */ +typedef struct QnnHtpContext_CustomConfig { + QnnHtpContext_ConfigOption_t option; + union UNNAMED { + // This field sets the weight sharing which is by default false + bool weightSharingEnabled; + QnnHtpContext_GroupRegistration_t groupRegistration; + // - Init time may be impacted depending the value set below + // - Value should be grather than 0 and less than or equal to the file size + // - If set to 0, the feature is not utilized + // - If set to greater than file size, min(fileSize, fileReadMemoryBudgetInMb) is used + // - As an example, if value 2 is passed, it would translate to (2 * 1024 * 1024) bytes + uint64_t fileReadMemoryBudgetInMb; + bool dspMemoryProfilingEnabled; + // This field enables resource optimization. When it is set to true optimizations are + // done based on QnnHtpContext_ShareResourcesOptimizationType_t setting. + // Note This configuration option is only supported when using QnnContext_createFromBinaryListAsync API. + bool shareResources; + // This field enables I/O memory estimation during QnnContext_createFromBinary API when multiple + // PDs are available. When enabled, it estimates the total size of the I/O tensors required by + // the context to ensure sufficient space on the PD before deserialization. This feature helps + // with memory registration failures in large models. + // Note that enabling this feature increases peak RAM usage during context initialization phase + // in QnnContext_createFromBinary, but sustained RAM remains unaffected. + bool ioMemEstimation; + // This field enables model preparation without mapping its content on the DSP side. It is + // useful when a model needs to be prepared on the device but executed through a serialized + // binary method. This prevents extra mapping onto the DSP VA space. Set this flag only when + // creating the context. + bool isPrepareOnly; + // This field enables initialization acceleration, which is disabled by default. + // If set to true, the DSP will utilize all hardware threads to accelerate deserialization. + // It is not recommended to execute graphs simultaneously, as this will significantly degrade + // performance. + // Note that this feature may not be effective for small graphs with a few number of ops. + bool initAcceleration; + // This field enables crc32 check skip in Lora super adapter apply, which is disabled by default. + // If set to true, crc32 check for non-base adapter in super adapter apply use case will be + // skipped to improve time cost. + // Note that base adapter in super adaper never do crc32 check, therefore, their apply time cost + // won't improve by turning this config option on. + bool skipValidationOnBinarySection; + // If shareResources is true: + // shareResOptType is read. If no value is set by the user, + // the default value of QnnHtpContext_ShareResourcesOptimizationType_t is used. + // If shareResources is false: + // shareResOptType is ignored. + // Note: This configuration option is only supported when using the QnnContext_createFromBinaryListAsync API. + QnnHtpContext_ShareResourcesOptimizationType_t shareResOptType; + // This field enables preparing graphs, associated with this context, with far-mapping enabled so that weights + // and spill/fill buffer are mapped to the far region of the DSP which is helpful if PD's limited VA space is + // exhausted. Total RAM usage may increase if used together with shared weights. Only available for Hexagon + // arch v81 and above. + bool useExtendedUdma; + // This field enables concurrent resource sharing among graphs with the same priority level + // during the QnnContext_createFromBinary API on devices that support this capability. + QnnHtpContext_GroupRegistration_t concurrentGroupRegistration; + // This field sets the lora weight sharing. When it is set to true, one additional replaceable weight blob + // that contains the RP shared by all graphs will be generated and maintained. It is disabled by default. + bool loraWeightSharingEnabled; + }; +} QnnHtpContext_CustomConfig_t; + +/// QnnHtpContext_CustomConfig_t initializer macro +#define QNN_HTP_CONTEXT_CUSTOM_CONFIG_INIT \ + { \ + QNN_HTP_CONTEXT_CONFIG_OPTION_UNKNOWN, /*option*/ \ + { \ + false /*weightsharing*/\ + } \ + } + +/** + * @brief Structure describing the set of properties supported by context. + * Objects of this type are to be referenced through QnnContext_CustomProperty_t. + * Used by QnnContext_getProperty. + */ +typedef enum { + // get the alignment requirement of persistent buffers + QNN_HTP_CONTEXT_GET_PROP_BUFFER_START_ALIGNMENT = 1, + // get the size requirement of spill/fill buffer + QNN_HTP_CONTEXT_GET_PROP_MAX_SPILLFILL_BUFFER_SIZE = 2, + // get the size requirement of persistent weights buffer + QNN_HTP_CONTEXT_GET_PROP_WEIGHTS_BUFFER_SIZE = 3, + QNN_HTP_CONTEXT_GET_PROP_RESERVED_4 = 4, + QNN_HTP_CONTEXT_GET_PROP_RESERVED_5 = 5, + // Unused, present to ensure 32 bits. + QNN_HTP_CONTEXT_GET_PROP_UNDEFINED = 0x7fffffff +} QnnHtpContext_GetPropertyOption_t; + +// used by QnnContext_getProperty +typedef struct { + QnnHtpContext_GetPropertyOption_t option; + union UNNAMED { + uint64_t bufferStartAlignment; + uint64_t spillfillBufferSize; + uint64_t weightsBufferSize; + }; +} QnnHtpContext_CustomProperty_t; + +// clang-format off +/// QnnHtpContext_CustomProperty_t initializer macro +#define QNN_HTP_CONTEXT_CUSTOM_PROPERTY_INIT \ + { \ + QNN_HTP_CONTEXT_GET_PROP_UNDEFINED, /*option*/ \ + 0 /*scratchBufferSize*/ \ + } +// clang-format on + +// clang-format on +#ifdef __cplusplus +} // extern "C" +#endif + +#endif diff --git a/qnn/jni/qnn/QNN/HTP/QnnHtpDevice.h b/qnn/jni/qnn/QNN/HTP/QnnHtpDevice.h new file mode 100644 index 00000000..a76b24a3 --- /dev/null +++ b/qnn/jni/qnn/QNN/HTP/QnnHtpDevice.h @@ -0,0 +1,183 @@ +//============================================================================= +// +// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +// All Rights Reserved. +// Confidential and Proprietary - Qualcomm Technologies, Inc. +// +//============================================================================= + +/** @file + * @brief QNN HTP Device components + * + * This file defines structures and supplements QnnDevice.h for QNN HTP device + */ + +#pragma once + +#include "QnnCommon.h" +#include "QnnDevice.h" +#include "QnnHtpPerfInfrastructure.h" +#include "QnnTypes.h" +#ifdef __cplusplus +extern "C" { +#endif + +/** + * This is used to represent the HTP hardware architecture + * Since QnnDevice only supports V68 or newer, using legacy ARCH will result in error + */ +typedef enum { + QNN_HTP_DEVICE_ARCH_NONE = 0, + QNN_HTP_DEVICE_ARCH_V68 = 68, + QNN_HTP_DEVICE_ARCH_V69 = 69, + QNN_HTP_DEVICE_ARCH_V73 = 73, + QNN_HTP_DEVICE_ARCH_V75 = 75, + QNN_HTP_DEVICE_ARCH_V79 = 79, + QNN_HTP_DEVICE_ARCH_V81 = 81, + QNN_HTP_DEVICE_ARCH_V85 = 85, + QNN_HTP_DEVICE_ARCH_UNKNOWN = 0x7fffffff +} QnnHtpDevice_Arch_t; + +/** + * data struture to configure a device to set the minimum HTP Arch + * the driver will use ops that compatible to this HTP Arch + */ +typedef struct { + uint32_t deviceId; + QnnHtpDevice_Arch_t arch; +} QnnHtpDevice_Minimum_Arch_t; + +/** + * data struture to configure a device to running in Signed/unsigned Domain. + */ +typedef struct { + uint32_t deviceId; + bool useSignedProcessDomain; +} QnnHtpDevice_UseSignedProcessDomain_t; + +/** + * data struture to configure a device to running in Secure/normal Domain. + * running in secure process domain (SecurePD) is only supported in V81 and SecurePD is part of add-on SDK. + */ +typedef struct { + uint32_t deviceId; + bool useSecureProcessDomain; +} QnnHtpDevice_UseSecureProcessDomain_t; + +/** + * enum to list what custom configure is available. + */ +typedef enum { + QNN_HTP_DEVICE_CONFIG_OPTION_SOC = 0, + QNN_HTP_DEVICE_CONFIG_OPTION_ARCH = 1, + QNN_HTP_DEVICE_CONFIG_OPTION_SIGNEDPD = 2, + QNN_HTP_DEVICE_CONFIG_OPTION_SECUREPD = 3, + QNN_HTP_DEVICE_CONFIG_OPTION_UNKNOWN = 0x7fffffff +} QnnHtpDevice_ConfigOption_t; + +/** + * Data structure for custom configure. + */ +typedef struct { + QnnHtpDevice_ConfigOption_t option; + union UNNAMED { + // This field set the SoC Model + uint32_t socModel; + // This field update the minimum HTP arch + QnnHtpDevice_Minimum_Arch_t arch; + // This structure is used for enabling/disabling Signed/unsigned PD + QnnHtpDevice_UseSignedProcessDomain_t useSignedProcessDomain; + // This structure is used for enabling Secure PD + QnnHtpDevice_UseSecureProcessDomain_t useSecureProcessDomain; + }; +} QnnHtpDevice_CustomConfig_t; + +// For deviceType in QnnDevice_HardwareDeviceInfoV1_t +typedef enum { + QNN_HTP_DEVICE_TYPE_ON_CHIP = 0, // HTP cores are inside SoC + QNN_HTP_DEVICE_TYPE_UNKNOWN = 0x7fffffff +} QnnHtpDevice_DeviceType_t; + +/** + * @brief QNN HTP Device core type + * This enumeration provides information about the core type inside the SOC. + * + * For online operation, the caller should retrieve this information from + * `QnnDevice_getPlatformInfo`. For offline operation, the caller needs to create a + * `QnnDevice_CoreInfo_t` with the correct core type, and then use it to create the + * `QnnDevice_PlatformInfo_t`. + */ +typedef enum { + QNN_HTP_CORE_TYPE_NSP = 0, + QNN_HTP_CORE_TYPE_HPASS = 1, + + // supported coreType are < QNN_CORE_TYPE_MAX + QNN_HTP_CORE_TYPE_MAX, + QNN_HTP_CORE_TYPE_UNKNOWN = 0x7fffffff +} QnnHtpDevice_CoreType_t; + +/** + * This structure provides info about the NSP device inside SoC + * For online operation, caller should get these info from QnnDevice_getPlatformInfo + * For offline operation, caller need to create this structure and filling the correct information + * for QnnDevice_create + */ +typedef struct { + size_t vtcmSize; // The VTCM for this device in Mega Byte + // user could not request VTCM size exceed this value + uint32_t socModel; // An enum value defined in Qnn Header that represent SoC model + bool signedPdSupport; // This field is true if the device supports Signed PD + bool dlbcSupport; // This field is true if the device supports DLBC + QnnHtpDevice_Arch_t arch; // This field shows the Architecture of this device +} QnnHtpDevice_OnChipDeviceInfoExtension_t; + +/** + * This structure is being used in QnnDevice_HardwareDeviceInfoV1_t + * QnnDevice_getPlatformInfo use this structure to list the supported device features/info + */ +typedef struct _QnnDevice_DeviceInfoExtension_t { + QnnHtpDevice_DeviceType_t devType; + union UNNAMED { + QnnHtpDevice_OnChipDeviceInfoExtension_t onChipDevice; + }; +} QnnHtpDevice_DeviceInfoExtension_t; + +/** + * @brief QNN HTP Device PerfInfrastructure specialization structure. + * Objects of this type are to be referenced through QnnDevice_getInfrastructure. + * + * Contains function pointers for each interface method for + * Htp PerfInfrastructure. + */ +typedef struct { + QnnHtpPerfInfrastructure_CreatePowerConfigIdFn_t createPowerConfigId; + QnnHtpPerfInfrastructure_DestroyPowerConfigIdFn_t destroyPowerConfigId; + QnnHtpPerfInfrastructure_SetPowerConfigFn_t setPowerConfig; + QnnHtpPerfInfrastructure_SetMemoryConfigFn_t setMemoryConfig; +} QnnHtpDevice_PerfInfrastructure_t; + +/// QnnHtpDevice_PerfInfrastructure_t initializer macro +#define QNN_HTP_DEVICE_PERF_INFRASTRUCTURE_INIT \ + { \ + NULL, /*createPowerConfigId*/ \ + NULL, /*destroyPowerConfigId*/ \ + NULL, /*setPowerConfig*/ \ + NULL /*setMemoryConfig*/ \ + } + +typedef enum { + QNN_HTP_DEVICE_INFRASTRUCTURE_TYPE_PERF = 0, + QNN_HTP_DEVICE_INFRASTRUCTURE_TYPE_UNKNOWN = 0x7fffffff +} QnnHtpDevice_InfrastructureType_t; + +typedef struct _QnnDevice_Infrastructure_t { + QnnHtpDevice_InfrastructureType_t infraType; + union UNNAMED { + QnnHtpDevice_PerfInfrastructure_t perfInfra; + }; +} QnnHtpDevice_Infrastructure_t; + +// clang-format on +#ifdef __cplusplus +} // extern "C" +#endif diff --git a/qnn/jni/qnn/QNN/HTP/QnnHtpGraph.h b/qnn/jni/qnn/QNN/HTP/QnnHtpGraph.h new file mode 100644 index 00000000..fdeddcc7 --- /dev/null +++ b/qnn/jni/qnn/QNN/HTP/QnnHtpGraph.h @@ -0,0 +1,314 @@ +//============================================================================= +// +// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +// All Rights Reserved. +// Confidential and Proprietary - Qualcomm Technologies, Inc. +// +//============================================================================= + +/** + * @file + * @brief QNN HTP component Graph API. + * + * The interfaces in this file work with the top level QNN + * API and supplements QnnGraph.h for HTP backend + */ + +#ifndef QNN_HTP_GRAPH_H +#define QNN_HTP_GRAPH_H + +#include "QnnGraph.h" + +#ifdef __cplusplus +extern "C" { +#endif + +//============================================================================= +// Macros +//============================================================================= +/** + * @brief QnnHtpGraph config value macro. Represents to use the maximum + * available number of the resource. + * + * Currently only applicable for QNN_HTP_GRAPH_CONFIG_OPTION_VTCM_SIZE. + */ +#define QNN_HTP_GRAPH_CONFIG_OPTION_MAX 0 + +//============================================================================= +// Data Types +//============================================================================= + +/** + * @brief This enum provides different HTP graph optimization + * options that can be used to finalize the graph + * for optimum performance. + */ +typedef enum { + QNN_HTP_GRAPH_OPTIMIZATION_TYPE_SCHEDULE_THRESHOLD = 1, + QNN_HTP_GRAPH_OPTIMIZATION_TYPE_FINALIZE_RETRIES = 2, + QNN_HTP_GRAPH_OPTIMIZATION_TYPE_FINALIZE_OPTIMIZATION_FLAG = 3, + QNN_HTP_GRAPH_OPTIMIZATION_TYPE_ENABLE_DLBC = 4, + QNN_HTP_GRAPH_OPTIMIZATION_TYPE_ENABLE_DLBC_WEIGHTS = 5, + QNN_HTP_GRAPH_OPTIMIZATION_TYPE_ENABLE_SPARSE_WEIGHTS_COMPRESSION = 6, + QNN_HTP_GRAPH_OPTIMIZATION_TYPE_ENABLE_SLC_ALLOCATOR = 7, + QNN_HTP_GRAPH_OPTIMIZATION_TYPE_UNKNOWN = 0x7fffffff +} QnnHtpGraph_OptimizationType_t; + +// clang-format off + +/** + * @brief Struct describing the set of optimization types + * and the values associated with each optimization type. + * + * Below is the Map between QnnHtpGraph_OptimizationType_t and allowable values: + * + * \verbatim embed:rst:leading-asterisk + * +----+--------------------------------------------------------------------+---------------------------------------------------------------------+ + * | # | OptimizationType option | Allowable values | + * +====+====================================================================+=====================================================================+ + * | 1 | QNN_HTP_GRAPH_OPTIMIZATION_TYPE_SCHEDULE_THRESHOLD | Reserved | + * +----+--------------------------------------------------------------------+---------------------------------------------------------------------+ + * | 2 | QNN_HTP_GRAPH_OPTIMIZATION_TYPE_FINALIZE_RETRIES | Reserved | + * +----+--------------------------------------------------------------------+---------------------------------------------------------------------+ + * | 3 | QNN_HTP_GRAPH_OPTIMIZATION_TYPE_FINALIZE_OPTIMIZATION_FLAG | Defines the optimization strategy used by the HTP backend | + * +----+--------------------------------------------------------------------+---------------------------------------------------------------------+ + * | 4 | QNN_HTP_GRAPH_OPTIMIZATION_TYPE_ENABLE_DLBC | Reserved | + * +----+--------------------------------------------------------------------+---------------------------------------------------------------------+ + * | 5 | QNN_HTP_GRAPH_OPTIMIZATION_TYPE_ENABLE_DLBC_WEIGHTS | Enables DLBC weights compression | + * +----+--------------------------------------------------------------------+---------------------------------------------------------------------+ + * | 6 | QNN_HTP_GRAPH_OPTIMIZATION_TYPE_ENABLE_SPARSE_WEIGHTS_COMPRESSION | Enables Weight Sparsity Compression | + * +----+--------------------------------------------------------------------+---------------------------------------------------------------------+ + * | 7 | QNN_HTP_GRAPH_OPTIMIZATION_TYPE_ENABLE_SLC_ALLOCATOR | Enables System Level Cache Allocator usage | + * +----+--------------------------------------------------------------------+---------------------------------------------------------------------+ + * \endverbatim + */ +typedef struct { + QnnHtpGraph_OptimizationType_t type; + float floatValue; +} QnnHtpGraph_OptimizationOption_t; + +/** + * @brief This struct encapsulates all the VTCM configurations for parallel graph execution. + * + * @code + * |<-- (1) 8MB Total Hardware VTCM -->| + * |<-- (2) 7MB Addressable -->| + * +------+------+------+------+------+------+------+------+ + * | CV | | | | | | | | + * +------+------+------+------+------+------+------+------+ + * |<-- (4) Graph A -->|<-- (4) Graph B -->| + * + * A |> 0 MB (3) Graph Offset + * B |-------------------> 3 MB + * @endcode + */ +typedef struct { + /// (4) above, the amount of VTCM used by a graph + uint32_t sizeInBytes; + /// (3) above, where in the addressable region to start VTCM. + /// Note: (3) + (4) <= (2) + uint32_t offsetInBytes; + /// (2) Addressable portion of VTCM. + /// Set to less than hardware size so Graph(s) can coexist with other VTCM clients. + uint32_t sizeTotalInBytes; + + // For ABI compatibility in the future. + // Set to 0 for now. + uint32_t reserved[3]; +} QnnHtpGraph_VtcmConfig_t; + +/** + * @brief This enum defines whether graph concurrency (i.e. multiple graphs running concurrently) + * is possible, and how to behave when circumstances for concurrency aren't possible. + */ +typedef enum { + /// This graph will not be able to run concurrently with other graphs. + QNN_HTP_GRAPH_CONCURRENCY_OPTION_NONE = 0, + QNN_HTP_GRAPH_CONCURRENCY_OPTION_DEFAULT = QNN_HTP_GRAPH_CONCURRENCY_OPTION_NONE, + /// Graph will try to run concurrently, sharing all resources on the DSP (VTCM, HMX, HVX, etc). + QNN_HTP_GRAPH_CONCURRENCY_OPTION_ALL_SHARED = 1, + // Unused, present to ensure 32 bits. + QNN_HTP_GRAPH_CONCURRENCY_OPTION_UNKNOWN = 0x7fffffff +} QnnHtpGraph_ConcurrencyOption_t; + +/** + * @brief This struct encapsulates all the configurations for parallel graph execution. + */ +typedef struct { + QnnHtpGraph_ConcurrencyOption_t concurrency; + QnnHtpGraph_VtcmConfig_t vtcmConfig; + + // For ABI compatibility in the future. + // Set to 0 for now. + uint32_t reserved[4]; +} QnnHtpGraph_ParallelGraphExecutionConfig_t; +/// The settings in this struct is only applicable +/// for DSP architectures >= V81. +/// Use on other SOCs will return an error. +/// +/// Values will be defaulted to their SOC's TURBO frequency +/// (SOC as identified by Qnn_DeviceHandle_t). +/// +/// On automotive SDKs HMX OP Bounding will be enabled by default. +/// +/// On non-automotive SDKs using this setting will enable +/// HMX OP Bounding. It is off by default. +typedef struct QnnHtp_HmxBoundingInfo { + /// Target HMX freq in Hz. + /// Can be derived from sysMonApp (HexagonSDK) or QProfiler. + float targetHmxFreqHz; + /// Target DSP Core freq in Hz. + /// Can be derived from sysMonApp (HexagonSDK) or QProfiler. + float targetDspCoreFreq; +} QnnHtp_HmxBoundingInfo_t; + +/// QnnHtpGraph_OptimizationOption_t initializer macro +#define QNN_HTP_GRAPH_OPTIMIZATION_OPTION_INIT \ + { \ + QNN_HTP_GRAPH_OPTIMIZATION_TYPE_UNKNOWN, /*type*/ \ + 0.0f /*floatValue*/ \ + } +// clang-format on + +/** + * @brief This enum provides different HTP graph configuration + * options associated with QnnGraph + */ +typedef enum { + QNN_HTP_GRAPH_CONFIG_OPTION_OPTIMIZATION = 1, + QNN_HTP_GRAPH_CONFIG_OPTION_PRECISION = 2, + QNN_HTP_GRAPH_CONFIG_OPTION_VTCM_SIZE_IN_MB = 3, + QNN_HTP_GRAPH_CONFIG_OPTION_VTCM_SIZE = QNN_HTP_GRAPH_CONFIG_OPTION_VTCM_SIZE_IN_MB, + QNN_HTP_GRAPH_CONFIG_OPTION_FOLD_RELU_ACTIVATION_INTO_CONV_OFF = 4, + QNN_HTP_GRAPH_CONFIG_OPTION_SHORT_DEPTH_CONV_ON_HMX_OFF = 5, + QNN_HTP_GRAPH_CONFIG_OPTION_NUM_HVX_THREADS = 6, + QNN_HTP_GRAPH_CONFIG_OPTION_FINALIZE_CONFIG = 7, + QNN_HTP_GRAPH_CONFIG_OPTION_NUM_CORES = 8, + QNN_HTP_GRAPH_CONFIG_OPTION_PARALLEL_GRAPH_EXECUTION_CONFIG = 9, + QNN_HTP_GRAPH_CONFIG_OPTION_VTCM_SIZE_IN_BYTES = 10, + QNN_HTP_GRAPH_CONFIG_OPTION_HMX_BOUNDING = 11, + QNN_HTP_GRAPH_CONFIG_OPTION_WEIGHTS_PACKING = 12, + QNN_HTP_GRAPH_CONFIG_OPTION_ASSUME_SAME_QUANT = 13, + QNN_HTP_GRAPH_CONFIG_OPTION_SHARE_IO_BUFFER = 14, + QNN_HTP_GRAPH_CONFIG_OPTION_ADVANCED_ACTIVATION_FUSION = 15, + QNN_HTP_GRAPH_CONFIG_OPTION_HIGH_PRECISION_SIGMOID = 16, + QNN_HTP_GRAPH_CONFIG_OPTION_MONOLITHIC_LSTM = 17, + QNN_HTP_GRAPH_CONFIG_OPTION_RESERVED = 0x7fff0000, + QNN_HTP_GRAPH_CONFIG_OPTION_UNKNOWN = 0x7fffffff +} QnnHtpGraph_ConfigOption_t; + +//============================================================================= +// Public Functions +//============================================================================= + +//------------------------------------------------------------------------------ +// Implementation Definition +//------------------------------------------------------------------------------ + +/** + * @brief A struct for different config parameters in a key value format. + */ +typedef struct { + const char* key; + Qnn_Scalar_t value; +} QnnHtpGraph_FinalizeConfig_t; + +/** + * @brief Structure describing the set of configurations supported by graph. + * Objects of this type are to be referenced through QnnGraph_CustomConfig_t. + * + * The struct has two fields - option and a union of corresponding config values + * Based on the option corresponding item in the union can be used to specify + * config. + * + * Below is the Map between QnnHtpGraph_ConfigOption_t and config value + * + * \verbatim embed:rst:leading-asterisk + * +----+-------------------------------------------------------------------------------------+------------------------------------------------+ + * | # | Config Option | Configuration Struct/value | + * +====+=====================================================================================+================================================+ + * | 1 | QNN_HTP_GRAPH_CONFIG_OPTION_OPTIMIZATION | QnnHtpGraph_OptimizationOption_t + * | + * +----+-------------------------------------------------------------------------------------+------------------------------------------------+ + * | 2 | QNN_HTP_GRAPH_CONFIG_OPTION_PRECISION | Qnn_Precision_t | + * +----+-------------------------------------------------------------------------------------+------------------------------------------------+ + * | 3 | + * QNN_HTP_GRAPH_CONFIG_OPTION_VTCM_SIZE_IN_MB/QNN_HTP_GRAPH_CONFIG_OPTION_VTCM_SIZE | uint32_t | + * +----+-------------------------------------------------------------------------------------+------------------------------------------------+ + * | 4 | QNN_HTP_GRAPH_CONFIG_OPTION_FOLD_RELU_ACTIVATION_INTO_CONV_OFF | bool | + * +----+-------------------------------------------------------------------------------------+------------------------------------------------+ + * | 5 | QNN_HTP_GRAPH_CONFIG_OPTION_SHORT_DEPTH_CONV_ON_HMX_OFF | bool | + * +----+-------------------------------------------------------------------------------------+------------------------------------------------+ + * | 6 | QNN_HTP_GRAPH_CONFIG_OPTION_NUM_HVX_THREADS | uint32_t | + * +----+-------------------------------------------------------------------------------------+------------------------------------------------+ + * | 7 | QNN_HTP_GRAPH_CONFIG_OPTION_FINALIZE_CONFIG | QnnHtpGraph_FinalizeConfig_t | + * +----+-------------------------------------------------------------------------------------+------------------------------------------------+ + * | 8 | QNN_HTP_GRAPH_CONFIG_OPTION_NUM_CORES | uint32_t | + * +----+-------------------------------------------------------------------------------------+------------------------------------------------+ + * | 9 | QNN_HTP_GRAPH_CONFIG_OPTION_PARALLEL_GRAPH_EXECUTION_CONFIG | + * QnnHtpGraph_ParallelGraphExecutionConfig_t | + * +----+-------------------------------------------------------------------------------------+------------------------------------------------+ + * | 10 | QNN_HTP_GRAPH_CONFIG_OPTION_VTCM_SIZE_IN_BYTES | uint32_t | + * +----+-------------------------------------------------------------------------------------+------------------------------------------------+ + * | 11 | QNN_HTP_GRAPH_CONFIG_OPTION_HMX_BOUNDING | uint32_t | + * +----+-------------------------------------------------------------------------------------+------------------------------------------------+ + * | 12 | QNN_HTP_GRAPH_CONFIG_OPTION_WEIGHTS_PACKING | bool | + * +----+-------------------------------------------------------------------------------------+------------------------------------------------+ + * | 13 | QNN_HTP_GRAPH_CONFIG_OPTION_ASSUME_SAME_QUANT | bool | + * +----+-------------------------------------------------------------------------------------+------------------------------------------------+ + * | 14 | QNN_HTP_GRAPH_CONFIG_OPTION_SHARE_IO_BUFFER | bool | + * +----+-------------------------------------------------------------------------------------+------------------------------------------------+ + * | 15 | QNN_HTP_GRAPH_CONFIG_OPTION_ADVANCED_ACTIVATION_FUSION | bool | + * +----+-------------------------------------------------------------------------------------+------------------------------------------------+ + * | 16 | QNN_HTP_GRAPH_CONFIG_OPTION_HIGH_PRECISION_SIGMOID | bool | + * +----+-------------------------------------------------------------------------------------+------------------------------------------------+ + * | 17 | QNN_HTP_GRAPH_CONFIG_OPTION_MONOLITHIC_LSTM | bool | + * +-------------------------+----------------------------------------------------------------+------------------------------------------------+ + * | 0x7fff0000 - 0x7ffffffe | QNN_HTP_GRAPH_CONFIG_OPTION_RESERVED | These are + * reserved for internal purposes | + * +-------------------------+----------------------------------------------------------------+------------------------------------------------+ + * \endverbatim + * + * NOTE: Option #6 (i.e. QNN_HTP_GRAPH_CONFIG_OPTION_NUM_HVX_THREADS), can only be + * set prior to the first execution of the graph. Proceeding executions will not use + * the updated value if user does change it after the first execution. + */ +typedef struct { + QnnHtpGraph_ConfigOption_t option; + union { + QnnHtpGraph_OptimizationOption_t optimizationOption; + Qnn_Precision_t precision; + uint32_t vtcmSizeInMB; + bool foldReluActivationIntoConvOff; + bool shortDepthConvOnHmxOff; + uint64_t numHvxThreads; + void* reserved; + QnnHtpGraph_FinalizeConfig_t finalizeConfig; + uint32_t numCores; + QnnHtpGraph_ParallelGraphExecutionConfig_t parallelGraphExecutionConfig; + uint32_t vtcmSizeInBytes; + QnnHtp_HmxBoundingInfo_t hmxBoundingInfo; + bool weightsPacking; + bool assumeSameQuant; + bool shareIOBuffer; + bool advancedActivationFusion; + bool highPrecisionSigmoid; + bool monolithicLstm; + }; +} QnnHtpGraph_CustomConfig_t; + +// clang-format on +/// QnnHtpGraph_CustomConfig_t initializer macro +#define QNN_HTP_GRAPH_CUSTOM_CONFIG_INIT \ + { \ + QNN_HTP_GRAPH_CONFIG_OPTION_UNKNOWN, /*option*/ \ + { \ + QNN_HTP_GRAPH_OPTIMIZATION_OPTION_INIT /*optimizationOption*/ \ + } \ + } + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif diff --git a/qnn/jni/qnn/QNN/HTP/QnnHtpMem.h b/qnn/jni/qnn/QNN/HTP/QnnHtpMem.h new file mode 100644 index 00000000..d8442591 --- /dev/null +++ b/qnn/jni/qnn/QNN/HTP/QnnHtpMem.h @@ -0,0 +1,89 @@ +//============================================================================== +// +// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +// All Rights Reserved. +// Confidential and Proprietary - Qualcomm Technologies, Inc. +// +//============================================================================== + +#ifndef QNN_HTP_MEMORY_INFRASTRUCTURE_2_H +#define QNN_HTP_MEMORY_INFRASTRUCTURE_2_H + +#include "QnnCommon.h" + +/** + * @file + * @brief QNN HTP Memory Infrastructure component API. + */ + +#ifdef __cplusplus +extern "C" { +#endif + +//============================================================================= +// VTCM +//============================================================================= + +// clang-format off + +/** + * @brief Raw memory address that exists ONLY on the QURT + * side. + */ +typedef uint32_t QnnHtpMem_QurtAddress_t; + +/** + * @brief Configuration for custom shared buffer memory type + * This shared buffer is a contiguous chunk of memory identified + * by a single file descriptor which will be used by multiple tensors + * based on the offset provided + * Each QnnMem_register call with different offset will return a + * unique memory handle + */ +typedef struct { + // File descriptor for memory, must be set to QNN_MEM_INVALID_FD if not applicable + int32_t fd; + // Offset to be used in contiguous shared buffer + uint64_t offset; +} QnnHtpMem_SharedBufferConfig_t; + +// clang-format off + +/** + * @brief QNN Memory Type + */ +typedef enum { + QNN_HTP_MEM_QURT = 0, + QNN_HTP_MEM_SHARED_BUFFER = 1, + QNN_HTP_MEM_WEIGHTS_BUFFER = 2, + QNN_HTP_MEM_SHARED_SPILLFILL_BUFFER = 3, + QNN_HTP_MEM_UNDEFINED = 0x7FFFFFFF +} QnnHtpMem_Type_t; + +// clang-format off + +/** + * @brief descriptor used for the QNN API + */ +typedef struct { + // Memory type identified by QnnHtpMem_Type_t + QnnHtpMem_Type_t type; + // Total size of the buffer + // For memory type QURT, it would be size of a tensor + // For memory type SHARED BUFFER, it would be the total size of the buffer + uint64_t size; + + union { + QnnHtpMem_QurtAddress_t qurtAddress; + QnnHtpMem_SharedBufferConfig_t sharedBufferConfig; + QnnHtpMem_SharedBufferConfig_t weightsBufferConfig; + QnnHtpMem_SharedBufferConfig_t sharedSpillfillBufferConfig; + }; +} QnnMemHtp_Descriptor_t; + +// clang-format on +#ifdef __cplusplus +} // extern "C" +#endif + +#endif diff --git a/qnn/jni/qnn/QNN/HTP/QnnHtpPerfInfrastructure.h b/qnn/jni/qnn/QNN/HTP/QnnHtpPerfInfrastructure.h new file mode 100644 index 00000000..d6fbebc1 --- /dev/null +++ b/qnn/jni/qnn/QNN/HTP/QnnHtpPerfInfrastructure.h @@ -0,0 +1,543 @@ +//============================================================================== +// +// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +// All rights reserved. +// Confidential and Proprietary - Qualcomm Technologies, Inc. +// +//============================================================================== + +/** @file + * @brief QNN HTP component Performance Infrastructure API + * + * Provides interface to the client to control performance and system + * settings of the QNN HTP Accelerator + */ + +#ifndef QNN_HTP_PERF_INFRASTRUCTURE_H +#define QNN_HTP_PERF_INFRASTRUCTURE_H + +#include "QnnCommon.h" +#include "QnnTypes.h" + +#ifdef __cplusplus +extern "C" { +#endif + +// max rpc polling time allowed - 9999 us +#define QNN_HTP_PERF_INFRASTRUCTURE_POWER_CONFIG_MAX_RPC_POLLING_TIME 9999 + +//============================================================================= +// Data Types +//============================================================================= + +/** + * @brief QNN HTP PerfInfrastructure API result / error codes. + * + */ +typedef enum { + QNN_HTP_PERF_INFRASTRUCTURE_MIN_ERROR = QNN_MIN_ERROR_PERF_INFRASTRUCTURE, + //////////////////////////////////////////////////////////////////////// + + QNN_HTP_PERF_INFRASTRUCTURE_NO_ERROR = QNN_SUCCESS, + QNN_HTP_PERF_INFRASTRUCTURE_ERROR_INVALID_HANDLE_PTR = QNN_MIN_ERROR_PERF_INFRASTRUCTURE + 0, + QNN_HTP_PERF_INFRASTRUCTURE_ERROR_INVALID_INPUT = QNN_MIN_ERROR_PERF_INFRASTRUCTURE + 1, + QNN_HTP_PERF_INFRASTRUCTURE_ERROR_UNSUPPORTED_CONFIG = QNN_MIN_ERROR_PERF_INFRASTRUCTURE + 2, + QNN_HTP_PERF_INFRASTRUCTURE_ERROR_TRANSPORT = QNN_MIN_ERROR_PERF_INFRASTRUCTURE + 3, + QNN_HTP_PERF_INFRASTRUCTURE_ERROR_UNSUPPORTED = QNN_MIN_ERROR_PERF_INFRASTRUCTURE + 4, + QNN_HTP_PERF_INFRASTRUCTURE_ERROR_MEM_ALLOC = QNN_MIN_ERROR_PERF_INFRASTRUCTURE + 5, + QNN_HTP_PERF_INFRASTRUCTURE_ERROR_FAILED = QNN_MIN_ERROR_PERF_INFRASTRUCTURE + 6, + + //////////////////////////////////////////////////////////////////////// + QNN_HTP_PERF_INFRASTRUCTURE_MAX_ERROR = QNN_MAX_ERROR_PERF_INFRASTRUCTURE, + /// UNDEFINED value that must not be used by client + QNN_HTP_PERF_INFRASTRUCTURE_ERROR_UNDEFINED = 0x7fffffff +} QnnHtpPerfInfrastructure_Error_t; + +/** + * @brief Allows client to consider (non-zero value) DCVS enable/disable + * and option parameters, otherwise (zero value) + * + */ +typedef uint32_t QnnHtpPerfInfrastructure_SetDcvsEnable_t; + +/** + * @brief Allows client to start (non-zero value) or stop (zero value) + * participating in DCVS + * + */ +typedef uint32_t QnnHtpPerfInfrastructure_DcvsEnable_t; + +/** + * @brief Allows client to consider (non-zero value) latency parameter, + * otherwise (zero value) + * + */ +typedef uint32_t QnnHtpPerfInfrastructure_SetSleepLatency_t; + +/** + * @brief Allows client to set up the sleep latency in microseconds + * + */ +typedef uint32_t QnnHtpPerfInfrastructure_SleepLatency_t; + +/** + * @brief Allows client to consider (non-zero value) sleep disable + * parameter, otherwise (zero value) + * + */ +typedef uint32_t QnnHtpPerfInfrastructure_SetSleepDisable_t; + +/** + * @brief Allows client to disable sleep or low power modes. + * Pass a non-zero value to disable sleep in HTP + * + */ +typedef uint32_t QnnHtpPerfInfrastructure_SleepDisable_t; + +/** + * @brief Allows client to consider (non-zero value) bus clock + * params, otherwise (zero value) + * + */ +typedef uint32_t QnnHtpPerfInfrastructure_SetBusParams_t; + +/** + * @brief Allows client consider (non-zero value) core clock + * params, otherwise (zero value) + * + */ +typedef uint32_t QnnHtpPerfInfrastructure_SetCoreParams_t; + +/** + * @brief Allows client to set up the RPC control latency in microseconds + * + */ +typedef uint32_t QnnHtpPerfInfrastructure_RpcControlLatency_t; + +/** + * @brief Allows client to set up the RPC polling time in microseconds + */ +typedef uint32_t QnnHtpPerfInfrastructure_RpcPollingTime_t; + +/** + * @brief Allows client to set up the adaptive polling time in microseconds + */ +typedef uint32_t QnnHtpPerfInfrastructure_AdaptivePollingTime_t; + +/** + * @brief Allows client to enable (non-zero value) or disable (zero value) + * DDR performance mode + */ +typedef uint32_t QnnHtpPerfInfrastructure_DdrPerfMode_t; + +/** + * @brief Allows client to set up the HMX timeout interval in microseconds + */ +typedef uint32_t QnnHtpPerfInfrastructure_HmxTimeoutIntervalUs_t; + +/** + * @brief sets the minimum size by which user heap should grow + * when heap is exhausted. This API is expected to be + * called only once per backend and has a process wide impact + * + * Grow size provided in bytes and defaults to 16MB + */ +typedef uint32_t QnnHtpPerfInfrastructure_MemGrowSize_t; + +/** + * @brief Allows client to set default values for HMX frequency. + * If enabled 1 HMX vote will scale with DCVS Corner if 0 HMX vote + * needs to be specified manually. + * + */ +typedef uint32_t QnnHtpPerfInfrastructure_HmxDefault_Vote_t; + +/** + * @brief Perf modes to specify clock frequency level within + * target voltage corner currently applies only for HMX config. + */ +typedef enum { + // To select max frequency at target voltage corner. + QNN_HTP_PERF_INFRASTRUCTURE_CLK_PERF_HIGH = 0, + // To select min frequency at target voltage corner. + QNN_HTP_PERF_INFRASTRUCTURE_CLK_PERF_LOW, + /// UNKNOWN value that must not be used by client + QNN_HTP_PERF_INFRASTRUCTURE_CLK_PERF_UNKNOWN = 0x7fffffff +} QnnHtpPerfInfrastructure_ClkPerfMode_t; + +/** + * @brief These are the different voltage corners that can + * be requested by the client to influence the voting scheme + * for DCVS + * + */ +typedef enum { + /// Maps to HAP_DCVS_VCORNER_DISABLE. + /// Disable setting up voltage corner + DCVS_VOLTAGE_CORNER_DISABLE = 0x10, + /// Maps to HAP_DCVS_VCORNER_SVS2. + /// Set voltage corner to minimum value supported on platform + DCVS_VOLTAGE_VCORNER_MIN_VOLTAGE_CORNER = 0x20, + /// Maps to HAP_DCVS_VCORNER_SVS2. + /// Set voltage corner to SVS2 value for the platform + DCVS_VOLTAGE_VCORNER_SVS2 = 0x30, + /// Maps to HAP_DCVS_VCORNER_SVS. + /// Set voltage corner to SVS value for the platform + DCVS_VOLTAGE_VCORNER_SVS = 0x40, + /// Maps to HAP_DCVS_VCORNER_SVS_PLUS. + /// Set voltage corner to SVS_PLUS value for the platform + DCVS_VOLTAGE_VCORNER_SVS_PLUS = 0x50, + /// Maps to HAP_DCVS_VCORNER_NOM. + /// Set voltage corner to NOMINAL value for the platform + DCVS_VOLTAGE_VCORNER_NOM = 0x60, + /// Maps to HAP_DCVS_VCORNER_NOM_PLUS. + /// Set voltage corner to NOMINAL_PLUS value for the platform + DCVS_VOLTAGE_VCORNER_NOM_PLUS = 0x70, + /// Maps to HAP_DCVS_VCORNER_TURBO. + /// Set voltage corner to TURBO value for the platform + DCVS_VOLTAGE_VCORNER_TURBO = 0x80, + /// Maps to HAP_DCVS_VCORNER_TURBO_PLUS. + /// Set voltage corner to TURBO_PLUS value for the platform + DCVS_VOLTAGE_VCORNER_TURBO_PLUS = 0x90, + /// Maps to HAP_DCVS_VCORNER_TURBO_L2. + /// Set voltage corner to TURBO_L2 value for the platform + DCVS_VOLTAGE_VCORNER_TURBO_L2 = 0x92, + /// Maps to HAP_DCVS_VCORNER_TURBO_L3. + /// Set voltage corner to TURBO_L3 value for the platform + DCVS_VOLTAGE_VCORNER_TURBO_L3 = 0x93, + /// Maps to HAP_DCVS_VCORNER_TURBO_L4. + /// Set voltage corner to TURBO_L4 value for the platform + DCVS_VOLTAGE_VCORNER_TURBO_L4 = 0x94, + /// Maps to HAP_DCVS_VCORNER_TURBO_L5. + /// Set voltage corner to TURBO_L5 value for the platform + DCVS_VOLTAGE_VCORNER_TURBO_L5 = 0x95, + /// Maps to HAP_DCVS_VCORNER_MAX. + /// Set voltage corner to maximum value supported on the platform + DCVS_VOLTAGE_VCORNER_MAX_VOLTAGE_CORNER = 0xA0, + /// UNKNOWN value that must not be used by client + DCVS_VOLTAGE_VCORNER_UNKNOWN = 0x7fffffff +} QnnHtpPerfInfrastructure_VoltageCorner_t; + +/** + * @brief These are the expanded voltage corners that can + * be requested by the client to influence the voting scheme + * for DCVS + * + */ +typedef enum { + /// Maps to HAP_DCVS_EXP_VCORNER_DISABLE. + /// Disable setting up voltage corner + DCVS_EXP_VCORNER_DISABLE = 0, + /// Maps to HAP_DCVS_EXP_VCORNER_MIN. + /// Set voltage corner to minimum value supported on platform + DCVS_EXP_VCORNER_MIN = 0x100, + /// Maps to HAP_DCVS_EXP_VCORNER_LOW_SVS_D2. + /// Set voltage corner to LOWSVS_D2 value for the platform + DCVS_EXP_VCORNER_LOW_SVS_D2 = 0x134, + /// Maps to HAP_DCVS_EXP_VCORNER_LOW_SVS_D1. + /// Set voltage corner to LOWSVS_D1 value for the platform + DCVS_EXP_VCORNER_LOW_SVS_D1 = 0x138, + /// Maps to HAP_DCVS_EXP_VCORNER_LOW_SVS. + /// Set voltage corner to LOWSVS value for the platform + DCVS_EXP_VCORNER_LOW_SVS = 0x140, + /// Maps to HAP_DCVS_EXP_VCORNER_LOW_SVS_L0. + /// Set voltage corner to LOWSVS_L0 value for the platform + DCVS_EXP_VCORNER_LOW_SVS_L0 = 0x14C, + /// Maps to HAP_DCVS_EXP_VCORNER_SVS. + /// Set voltage corner to SVS value for the platform + DCVS_EXP_VCORNER_SVS = 0x180, + /// Maps to HAP_DCVS_EXP_VCORNER_SVS_L0. + /// Set voltage corner to SVS_L0 value for the platform + DCVS_EXP_VCORNER_SVS_L0 = 0x190, + /// Maps to HAP_DCVS_EXP_VCORNER_SVS_L1. + /// Set voltage corner to SVS_L1 value for the platform + DCVS_EXP_VCORNER_SVS_L1 = 0x1C0, + /// Maps to HAP_DCVS_EXP_VCORNER_SVS_L2. + /// Set voltage corner to SVS_L2 value for the platform + DCVS_EXP_VCORNER_SVS_L2 = 0x1E0, + /// Maps to HAP_DCVS_EXP_VCORNER_NOM. + /// Set voltage corner to NOM value for the platform + DCVS_EXP_VCORNER_NOM = 0x200, + /// Maps to HAP_DCVS_EXP_VCORNER_NOM_L1. + /// Set voltage corner to NOM_L1 value for the platform + DCVS_EXP_VCORNER_NOM_L1 = 0x240, + /// Maps to HAP_DCVS_EXP_VCORNER_NOM_L2. + /// Set voltage corner to NOM_L2 value for the platform + DCVS_EXP_VCORNER_NOM_L2 = 0x250, + /// Maps to HAP_DCVS_EXP_VCORNER_TUR. + /// Set voltage corner to TURBO value for the platform + DCVS_EXP_VCORNER_TUR = 0x280, + /// Maps to HAP_DCVS_EXP_VCORNER_TUR_L1. + /// Set voltage corner to TURBO_L1 value for the platform + DCVS_EXP_VCORNER_TUR_L1 = 0x2A0, + /// Maps to HAP_DCVS_EXP_VCORNER_TUR_L2. + /// Set voltage corner to TURBO_L2 value for the platform + DCVS_EXP_VCORNER_TUR_L2 = 0x2B0, + /// Maps to HAP_DCVS_EXP_VCORNER_TUR_L3. + /// Set voltage corner to TURBO_L3 value for the platform + DCVS_EXP_VCORNER_TUR_L3 = 0x2C0, + /// Maps to HAP_DCVS_EXP_VCORNER_MAX. + /// Selects the maximum voltage corner defined for the chipset + DCVS_EXP_VCORNER_MAX = 0xFFFF, + /// UNKNOWN value that must not be used by client + DCVS_EXP_VCORNER_UNKNOWN = 0x7fffffff +} QnnHtpPerfInfrastructure_ExpVoltageCorner_t; + +/** + * @brief This enum defines all the possible power mode + * that a client can set to influence DCVS mode + */ +typedef enum { + /// Maps to HAP_DCVS_V2_ADJUST_UP_DOWN. + /// Allows for DCVS to adjust up and down + QNN_HTP_PERF_INFRASTRUCTURE_POWERMODE_ADJUST_UP_DOWN = 0x1, + /// Maps to HAP_DCVS_V2_ADJUST_ONLY_UP. + /// Allows for DCVS to adjust up only + QNN_HTP_PERF_INFRASTRUCTURE_POWERMODE_ADJUST_ONLY_UP = 0x2, + /// Maps to HAP_DCVS_V2_POWER_SAVER_MODE. + /// Higher thresholds for power efficiency + QNN_HTP_PERF_INFRASTRUCTURE_POWERMODE_POWER_SAVER_MODE = 0x4, + /// Maps to HAP_DCVS_V2_POWER_SAVER_AGGRESSIVE_MODE. + /// Higher thresholds for power efficiency with faster ramp down + QNN_HTP_PERF_INFRASTRUCTURE_POWERMODE_POWER_SAVER_AGGRESSIVE_MODE = 0x8, + /// Maps to HAP_DCVS_V2_PERFORMANCE_MODE. + /// Lower thresholds for maximum performance + QNN_HTP_PERF_INFRASTRUCTURE_POWERMODE_PERFORMANCE_MODE = 0x10, + /// Maps to HAP_DCVS_V2_DUTY_CYCLE_MODE. + /// The below value applies only for HVX clients: + /// - For streaming class clients: + /// - detects periodicity based on HVX usage + /// - lowers clocks in the no HVX activity region of each period. + /// - For compute class clients: + /// - Lowers clocks on no HVX activity detects and brings clocks up on detecting HVX activity + /// again. + /// - Latency involved in bringing up the clock will be at max 1 to 2 ms. + QNN_HTP_PERF_INFRASTRUCTURE_POWERMODE_DUTY_CYCLE_MODE = 0x20, + /// UNKNOWN value that must not be used by client + QNN_HTP_PERF_INFRASTRUCTURE_POWERMODE_UNKNOWN = 0x7fffffff +} QnnHtpPerfInfrastructure_PowerMode_t; + +/** + * @brief This struct provides performance infrastructure configuration + * associated with setting up of DcvsV3 which allows to select + * bus and core operating corners separately + */ +typedef struct { + uint32_t contextId; + QnnHtpPerfInfrastructure_SetDcvsEnable_t setDcvsEnable; + QnnHtpPerfInfrastructure_DcvsEnable_t dcvsEnable; + QnnHtpPerfInfrastructure_PowerMode_t powerMode; + QnnHtpPerfInfrastructure_SetSleepLatency_t setSleepLatency; + QnnHtpPerfInfrastructure_SleepLatency_t sleepLatency; + QnnHtpPerfInfrastructure_SetSleepDisable_t setSleepDisable; + QnnHtpPerfInfrastructure_SleepDisable_t sleepDisable; + QnnHtpPerfInfrastructure_SetBusParams_t setBusParams; + QnnHtpPerfInfrastructure_VoltageCorner_t busVoltageCornerMin; + QnnHtpPerfInfrastructure_VoltageCorner_t busVoltageCornerTarget; + QnnHtpPerfInfrastructure_VoltageCorner_t busVoltageCornerMax; + QnnHtpPerfInfrastructure_SetCoreParams_t setCoreParams; + QnnHtpPerfInfrastructure_VoltageCorner_t coreVoltageCornerMin; + QnnHtpPerfInfrastructure_VoltageCorner_t coreVoltageCornerTarget; + QnnHtpPerfInfrastructure_VoltageCorner_t coreVoltageCornerMax; +} QnnHtpPerfInfrastructure_DcvsV3_t; + +/** + * @brief This struct provides performance infrastructure configuration + * associated with setting up of hmxv2 which allows to select + * hmx corner separately. If hmxPickDefault is 1 all voltage corner + * params will be ignored. Ensure to use same contextID as used for + * DCVS vote. + */ +typedef struct { + QnnHtpPerfInfrastructure_HmxDefault_Vote_t hmxPickDefault; + QnnHtpPerfInfrastructure_ExpVoltageCorner_t hmxVoltageCornerMin; + QnnHtpPerfInfrastructure_ExpVoltageCorner_t hmxVoltageCornerTarget; + QnnHtpPerfInfrastructure_ExpVoltageCorner_t hmxVoltageCornerMax; + QnnHtpPerfInfrastructure_ClkPerfMode_t hmxPerfMode; +} QnnHtpPerfInfrastructure_HmxV2_t; + +/** + * @brief This enum defines all the possible performance + * options in Htp Performance Infrastructure that + * relate to setting up of power levels + */ +typedef enum { + /// config enum implies the usage of Dcvs v3 + QNN_HTP_PERF_INFRASTRUCTURE_POWER_CONFIGOPTION_DCVS_V3 = 1, + /// config enum implies the usage of rpcControlLatencyConfig struct + QNN_HTP_PERF_INFRASTRUCTURE_POWER_CONFIGOPTION_RPC_CONTROL_LATENCY = 2, + /// config enum implies the usage of rpcPollingTimeConfig struct + /// this config is only supported on V69 and later + /// if enabled, this config is applied to entire process + /// max allowed is QNN_HTP_PERF_INFRASTRUCTURE_POWER_CONFIG_MAX_RPC_POLLING_TIME us + QNN_HTP_PERF_INFRASTRUCTURE_POWER_CONFIGOPTION_RPC_POLLING_TIME = 3, + /// config HMX timeout interval in us. The HMX is turned off after the set interval + /// time if no interaction with it after an inference is finished. + QNN_HTP_PERF_INFRASTRUCTURE_POWER_CONFIGOPTION_HMX_TIMEOUT_INTERVAL_US = 4, + /// config HMX V2 voting parameters only on supported chips + QNN_HTP_PERF_INFRASTRUCTURE_POWER_CONFIGOPTION_HMX_V2 = 5, + /// config enum implies the usage of adaptivePollingTime struct + /// this config can only be enabled in the RPC polling mode + /// if enabled, this config is applied to the entire process + QNN_HTP_PERF_INFRASTRUCTURE_POWER_CONFIGOPTION_ADAPTIVE_POLLING_TIME = 6, + /// config enum implies the usage of DDR performance mode + /// this config can only be enabled under the following conditions: + /// 1. The SoC must support DDR performance mode (e.g. V81) + /// 2. The RPC polling mode is turned on + /// 3. Currently, it can only be used by LLM on Android V81 + /// 4. Currently, it can only be used when bus voltage corner is voted to maximum level + QNN_HTP_PERF_INFRASTRUCTURE_POWER_CONFIGOPTION_DDR_PERF_MODE = 7, + /// UNKNOWN config option which must not be used + QNN_HTP_PERF_INFRASTRUCTURE_POWER_CONFIGOPTION_UNKNOWN = 0x7fffffff +} QnnHtpPerfInfrastructure_PowerConfigOption_t; + +/** + * @brief This struct provides performance infrastructure configuration + * associated with setting up of power levels + */ +typedef struct { + QnnHtpPerfInfrastructure_PowerConfigOption_t option; + union UNNAMED { + QnnHtpPerfInfrastructure_DcvsV3_t dcvsV3Config; + QnnHtpPerfInfrastructure_RpcControlLatency_t rpcControlLatencyConfig; + QnnHtpPerfInfrastructure_RpcPollingTime_t rpcPollingTimeConfig; + QnnHtpPerfInfrastructure_HmxTimeoutIntervalUs_t hmxTimeoutIntervalUsConfig; + QnnHtpPerfInfrastructure_HmxV2_t hmxV2Config; + QnnHtpPerfInfrastructure_AdaptivePollingTime_t adaptivePollingTimeConfig; + QnnHtpPerfInfrastructure_DdrPerfMode_t ddrPerfModeConfig; + }; +} QnnHtpPerfInfrastructure_PowerConfig_t; + +/// QnnHtpPerfInfrastructure_PowerConfig_t initializer macro +#define QNN_HTP_PERF_INFRASTRUCTURE_POWER_CONFIG_INIT \ + { \ + QNN_HTP_PERF_INFRASTRUCTURE_POWER_CONFIGOPTION_UNKNOWN, /*config*/ \ + { \ + 0 /*dcvsV3Config*/ \ + } \ + } + +/** + * @brief This enum defines all the possible performance + * options in Htp Performance Infrastructure that + * relate to system memory settings + */ +typedef enum { + /// sets memory grow size + QNN_HTP_PERF_INFRASTRUCTURE_MEMORY_CONFIGOPTION_GROW_SIZE = 1, + /// UNKNOWN config option that must not be used + QNN_HTP_PERF_INFRASTRUCTURE_MEMORY_CONFIGOPTION_UNKNOWN = 0x7fffffff +} QnnHtpPerfInfrastructure_MemoryConfigOption_t; + +/** + * @brief Provides performance infrastructure configuration + * options that are memory specific + */ +typedef struct { + QnnHtpPerfInfrastructure_MemoryConfigOption_t option; + union UNNAMED { + QnnHtpPerfInfrastructure_MemGrowSize_t memGrowSizeConfig; + }; +} QnnHtpPerfInfrastructure_MemoryConfig_t; + +/// QnnHtpPerfInfrastructure_MemoryConfig_t initializer macro +#define QNN_HTP_PERF_INFRASTRUCTURE_MEMORY_CONFIG_INIT \ + { \ + QNN_HTP_PERF_INFRASTRUCTURE_MEMORY_CONFIGOPTION_UNKNOWN, /*config*/ \ + { \ + 0 /*memGrowSizeConfig*/ \ + } \ + } + +//============================================================================= +// API Methods +//============================================================================= + +/** + * @brief This API allows client to create power configuration id that + * has to be used to set different performance modes. + * Power configuration id has to be destroyed by client when not needed. + * + * @param[in] deviceId Hardware Device on which this config id needs to be created. + * + * @param[in] coreId Core/NSP on which this config id needs to be created. + * + * @param[out] powerConfigId Pointer to power configuration id to be created. + * + * @return Error code + * \n QNN_SUCCESS: No error encountered + * \n QNN_HTP_PERF_INFRASTRUCTURE_ERROR_INVALID_INPUT if deviceId/coreId + * or power configuration id is NULL + */ +typedef Qnn_ErrorHandle_t (*QnnHtpPerfInfrastructure_CreatePowerConfigIdFn_t)( + uint32_t deviceId, uint32_t coreId, uint32_t* powerConfigId); + +/** + * @brief This API allows client to destroy power configuration id. + * + * @param[in] powerConfigId A power configuration id to be destroyed. + * + * @return Error code + * \n QNN_SUCCESS: No error encountered + * \n QNN_HTP_PERF_INFRASTRUCTURE_ERROR_INVALID_INPUT if power configuration + * id does not exist + * \n QNN_COMMON_ERROR_SYSTEM_COMMUNICATION: SSR occurence (successful recovery) + * \n QNN_COMMON_ERROR_SYSTEM_COMMUNICATION_FATAL: SSR occurence (unsuccessful recovery) + */ +typedef Qnn_ErrorHandle_t (*QnnHtpPerfInfrastructure_DestroyPowerConfigIdFn_t)( + uint32_t powerConfigId); + +/** + * @brief This API allows client to set up system power configuration that + * will enable different performance modes. This API uses + * HAP_power_dcvs_v3_payload struct to config HAP power parameters. + * Detailed HAP power parameters description please refer to Hexagon + * SDK HAP_power_dcvs_v3_payload documentation. + * + * @param[in] powerConfigId A power client id to associate calls to system + * power settings. A value of 0 implies NULL power client id + * and can override every other setting the user process. To + * enable power settings for multiple clients in the same + * process, use a non-zero power client id. + * + * @param[in] config Pointer to a NULL terminated array + * of config option for performance configuration. + * NULL is allowed and indicates no config options are provided. + * + * @return Error code + * \n QNN_SUCCESS: No error encountered + * \n QNN_HTP_PERF_INFRASTRUCTURE_ERROR_INVALID_INPUT if power configuration + * does not exist + * \n QNN_COMMON_ERROR_SYSTEM_COMMUNICATION: SSR occurence (successful recovery) + * \n QNN_COMMON_ERROR_SYSTEM_COMMUNICATION_FATAL: SSR occurence (unsuccessful recovery) + */ +typedef Qnn_ErrorHandle_t (*QnnHtpPerfInfrastructure_SetPowerConfigFn_t)( + uint32_t powerConfigId, const QnnHtpPerfInfrastructure_PowerConfig_t** config); + +/** + * @brief This API allows clients to set up configuration associated with + * system memory on a specific device + * + * @param[in] deviceId Hardware Device on which this config needs to be applied. + * + * @param[in] coreId Core/NSP on which this config needs to be applied. + * + * @param[in] config Pointer to a NULL terminated array + * of config option for system memory configuration. + * NULL is allowed and indicates no config options are provided. + * + * @return Error code + * \n QNN_SUCCESS: No error encountered + * \n QNN_HTP_PERF_INFRASTRUCTURE_ERROR_INVALID_INPUT if deviceId/coreId + * or memory configuration does not exist + * \n QNN_COMMON_ERROR_SYSTEM_COMMUNICATION: SSR occurence (successful recovery) + * \n QNN_COMMON_ERROR_SYSTEM_COMMUNICATION_FATAL: SSR occurence (unsuccessful recovery) + */ +typedef Qnn_ErrorHandle_t (*QnnHtpPerfInfrastructure_SetMemoryConfigFn_t)( + uint32_t deviceId, uint32_t coreId, const QnnHtpPerfInfrastructure_MemoryConfig_t** config); + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif // QNN_HTP_PERF_INFRASTRUCTURE_H diff --git a/qnn/jni/qnn/QNN/HTP/QnnHtpProfile.h b/qnn/jni/qnn/QNN/HTP/QnnHtpProfile.h new file mode 100644 index 00000000..92381d17 --- /dev/null +++ b/qnn/jni/qnn/QNN/HTP/QnnHtpProfile.h @@ -0,0 +1,567 @@ +//============================================================================== +// +// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +// All Rights Reserved. +// Confidential and Proprietary - Qualcomm Technologies, Inc. +// +//============================================================================== + +/** + * @file + * @brief QNN HTP Profile component API. + * + * Requires HTP backend to be initialized. + * Should be used with the QnnProfile API but has HTP backend + * specific definition for different QnnProfile data structures + * + */ + +#ifndef QNN_HTP_PROFILE_H +#define QNN_HTP_PROFILE_H + +#include "QnnProfile.h" + +#ifdef __cplusplus +extern "C" { +#endif + +//============================================================================= +// Macros +//============================================================================= +/** + * @brief QnnProfile_EventType_t definition to get profile information + * that corresponds to the remote procedure call on the ARM processor + * when client invokes QnnContext_createFromBinary. The value + * returned is time in microseconds. + * + * @note context load binary host rpc time maybe available on both + * QNN_PROFILE_LEVEL_BASIC and QNN_PROFILE_LEVEL_DETAILED levels + */ +#define QNN_HTP_PROFILE_EVENTTYPE_CONTEXT_LOAD_BIN_HOST_RPC_TIME_MICROSEC 1002 + +/** + * @brief QnnProfile_EventType_t definition to get profile information + * that corresponds to the remote procedure call on the HTP processor + * when client invokes QnnContext_createFromBinary. The value + * returned is time in microseconds. + * + * @note context load binary htp rpc time maybe available on both + * QNN_PROFILE_LEVEL_BASIC and QNN_PROFILE_LEVEL_DETAILED levels + */ +#define QNN_HTP_PROFILE_EVENTTYPE_CONTEXT_LOAD_BIN_HTP_RPC_TIME_MICROSEC 1003 + +/** + * @brief QnnProfile_EventType_t definition to get profile information + * that corresponds to the time taken to create the context on the + * accelerator when client invokes QnnContext_createFromBinary. + * The value returned is time in microseconds. + * + * @note context load binary accelerator time maybe available on both + * QNN_PROFILE_LEVEL_BASIC and QNN_PROFILE_LEVEL_DETAILED levels + */ +#define QNN_HTP_PROFILE_EVENTTYPE_CONTEXT_LOAD_BIN_ACCEL_TIME_MICROSEC 1004 + +/** + * @brief QnnProfile_EventType_t definition to get profile information + * that corresponds to the remote procedure call on the ARM processor + * when client invokes QnnGraph_finalize. + * The value returned is time in microseconds. + * + * @note graph finalize host rpc time maybe available on both + * QNN_PROFILE_LEVEL_BASIC and QNN_PROFILE_LEVEL_DETAILED levels + */ +#define QNN_HTP_PROFILE_EVENTTYPE_GRAPH_FINALIZE_HOST_RPC_TIME_MICROSEC 2001 + +/** + * @brief QnnProfile_EventType_t definition to get profile information + * that corresponds to the remote procedure call on the HTP processor + * when client invokes QnnGraph_finalize. + * The value returned is time in microseconds. + * + * @note graph finalize htp rpc time maybe available on both + * QNN_PROFILE_LEVEL_BASIC and QNN_PROFILE_LEVEL_DETAILED levels + */ +#define QNN_HTP_PROFILE_EVENTTYPE_GRAPH_FINALIZE_HTP_RPC_TIME_MICROSEC 2002 + +/** + * @brief QnnProfile_EventType_t definition to get profile information + * that corresponds to finalize the graph on the accelerator + * when client invokes QnnGraph_finalize. + * The value returned is time in microseconds. + * + * @note graph finalize accelerator time maybe available on both + * QNN_PROFILE_LEVEL_BASIC and QNN_PROFILE_LEVEL_DETAILED levels + */ +#define QNN_HTP_PROFILE_EVENTTYPE_GRAPH_FINALIZE_ACCEL_TIME_MICROSEC 2003 + +/* Graph Performance Estimate Support + * + **/ +/** + * @brief QnnProfile_EventType_t definition to get profile information + * that corresponds to Performance Estimates for the graph + * when client invokes QnnGraph_finalize. + * This is just a dummy event which will print only the heading + * with no value or unit. + * @note HTP Performance Estimates maybe available on both + * QNN_PROFILE_LEVEL_BASIC and QNN_PROFILE_LEVEL_DETAILED levels + */ +#define QNN_HTP_PROFILE_EVENTTYPE_GRAPH_FINALIZE_PERF_ESTIMATE 2004 + +/** + * @brief QnnProfile_EventType_t definition to get perf mode at which + * the perf estimates are collected during QnnGraph_finalize. + * The value returned is the perf mode in string with no unit. + * + * @note Perf mode maybe available on both + * QNN_PROFILE_LEVEL_BASIC and QNN_PROFILE_LEVEL_DETAILED levels + */ +#define QNN_HTP_PROFILE_EVENTTYPE_GRAPH_FINALIZE_PERF_ESTIMATE_MODE 2005 + +/** + * @brief QnnProfile_EventType_t definition to get profile information + * that corresponds to simulated execution cycles during + * QnnGraph_finalize. + * The value returned is number of cycles. + * + * @note Simulated execution cycles maybe available on both + * QNN_PROFILE_LEVEL_BASIC and QNN_PROFILE_LEVEL_DETAILED levels + */ +#define QNN_HTP_PROFILE_EVENTTYPE_GRAPH_FINALIZE_PERF_ESTIMATE_SIM_EXEC_CYCLES 2006 + +/** + * @brief QnnProfile_EventType_t definition to get profile information + * that corresponds to a lower estimate of simulated execution + * cycles during QnnGraph_finalize. + * The value returned is number of cycles. + * + * @note Simulated execution cycles lower estimate maybe available on both + * QNN_PROFILE_LEVEL_BASIC and QNN_PROFILE_LEVEL_DETAILED levels + */ +#define QNN_HTP_PROFILE_EVENTTYPE_GRAPH_FINALIZE_PERF_ESTIMATE_SIM_EXEC_LOWER_CYCLES 2007 + +/** + * @brief QnnProfile_EventType_t definition to get profile information + * that corresponds to a upper estimate of simulated execution + * cycles during QnnGraph_finalize. + * The value returned is number of cycles. + * + * @note Simulated execution cycles upper estimate maybe available on both + * QNN_PROFILE_LEVEL_BASIC and QNN_PROFILE_LEVEL_DETAILED levels + */ +#define QNN_HTP_PROFILE_EVENTTYPE_GRAPH_FINALIZE_PERF_ESTIMATE_SIM_EXEC_UPPER_CYCLES 2008 + +/** + * @brief QnnProfile_EventType_t definition to get profile information + * that corresponds to DDR information for each HTP during + * QnnGraph_finalize. + * This is just a dummy event which will print only the heading + * with no value or unit. + * + * @note DDR Information for each HTP maybe available on both + * QNN_PROFILE_LEVEL_BASIC and QNN_PROFILE_LEVEL_DETAILED levels + */ +#define QNN_HTP_PROFILE_EVENTTYPE_GRAPH_FINALIZE_PERF_ESTIMATE_BANDWIDTH_STATS 2009 + +/** + * @brief QnnProfile_EventType_t definition to get profile information + * that corresponds to the HTP ID on chip during QnnGraph_finalize. + * The value returned is the HTP ID with no unit. + * + * @note HTP ID's maybe available on both + * QNN_PROFILE_LEVEL_BASIC and QNN_PROFILE_LEVEL_DETAILED levels + */ +#define QNN_HTP_PROFILE_EVENTTYPE_GRAPH_FINALIZE_PERF_ESTIMATE_BANDWIDTH_STATS_HTP_ID 2010 + +/** + * @brief QnnProfile_EventType_t definition to get profile information + * that corresponds to the Graph defined inputs or the total reads + * (in bytes) from DDR for graph input related tensors (weights, + * bias, activations) which do not have predecessors. + * The value returned is the num of blocks in bytes. + * + * @note Graph defined inputs for each HTP maybe available on both + * QNN_PROFILE_LEVEL_BASIC and QNN_PROFILE_LEVEL_DETAILED levels + */ +#define QNN_HTP_PROFILE_EVENTTYPE_GRAPH_FINALIZE_PERF_ESTIMATE_INPUT_FILL 2011 + +/** + * @brief QnnProfile_EventType_t definition to get profile information + * that corresponds to the total reads (in bytes) from DDR for + * compiler generated fill operators which have predecessors and + * successors and originate on the same HTP. + * The value returned is the num of blocks in bytes. + * + * @note Intermediate Fill Information for each HTP maybe available on both + * QNN_PROFILE_LEVEL_BASIC and QNN_PROFILE_LEVEL_DETAILED levels + */ +#define QNN_HTP_PROFILE_EVENTTYPE_GRAPH_FINALIZE_PERF_ESTIMATE_INTERMEDIATE_FILL 2012 + +/** + * @brief QnnProfile_EventType_t definition to get profile information + * that corresponds to the total writes (in bytes) from DDR for + * compiler generated fill operators which have predecessors and + * successors and originate on the same HTP. + * The value returned is the num of blocks in bytes. + * + * @note Intermediate Spill Information for each HTP maybe available on both + * QNN_PROFILE_LEVEL_BASIC and QNN_PROFILE_LEVEL_DETAILED levels + */ +#define QNN_HTP_PROFILE_EVENTTYPE_GRAPH_FINALIZE_PERF_ESTIMATE_INTERMEDIATE_SPILL 2013 + +/** + * @brief QnnProfile_EventType_t definition to get profile information + * that corresponds to the total reads (in bytes) from DDR for + * fills which were generated by a different HTP core and do not + * have a predecessor, but have a successor. + * The value returned is the num of blocks in bytes. + * + * @note Inter HTP Fill Information for each HTP maybe available on both + * QNN_PROFILE_LEVEL_BASIC and QNN_PROFILE_LEVEL_DETAILED levels + */ +#define QNN_HTP_PROFILE_EVENTTYPE_GRAPH_FINALIZE_PERF_ESTIMATE_INTER_HTP_FILL 2014 + +/** + * @brief QnnProfile_EventType_t definition to get profile information + * that corresponds to the total writes (in bytes) from DDR for + * fills which were generated by a different HTP core and do not + * have a successor, but have a predecessor. + * The value returned is the num of blocks in bytes. + * + * @note Inter HTP Spill Information for each HTP maybe available on both + * QNN_PROFILE_LEVEL_BASIC and QNN_PROFILE_LEVEL_DETAILED levels + */ +#define QNN_HTP_PROFILE_EVENTTYPE_GRAPH_FINALIZE_PERF_ESTIMATE_INTER_HTP_SPILL 2015 + +/** + * @brief QnnProfile_EventType_t definition to get profile information + * that corresponds to the total writes (in bytes) to DDR for + * graph output related tensors which do not have successors. + * The value returned is the num of blocks in bytes. + * + * @note Graph output related tensors for each HTP maybe available on both + * QNN_PROFILE_LEVEL_BASIC and QNN_PROFILE_LEVEL_DETAILED levels + */ +#define QNN_HTP_PROFILE_EVENTTYPE_GRAPH_FINALIZE_PERF_ESTIMATE_OUTPUT_SPILL 2016 + +/** + * @brief QnnProfile_EventType_t definition to get profile information + * that corresponds to the total number of missing ops which do + * not have any cost associated with them while getting the graph + * performance estimates. + * The value returned is the num of missing ops with no unit. + * + * @note Number of missing cost ops maybe available on both + * QNN_PROFILE_LEVEL_BASIC and QNN_PROFILE_LEVEL_DETAILED levels + */ +#define QNN_HTP_PROFILE_EVENTTYPE_GRAPH_FINALIZE_PERF_ESTIMATE_MISSING_COST_OPS 2017 + +/** + * @brief QnnProfile_EventType_t definition to get profile information + * that corresponds to the op ids of the missing ops which do + * not have any cost associated with them while getting the graph + * performance estimates. + * The value returned is the opname along with the op id (decimal + * format) of the ops which does not have any costs associated + * with them. + * + * @note Opname and Op ids of missing cost ops are available only with + * QNN_PROFILE_LEVEL_DETAILED levels + */ +#define QNN_HTP_PROFILE_EVENTTYPE_GRAPH_FINALIZE_PERF_ESTIMATE_MISSING_COST_OPID 2018 + +/** + * @brief QnnProfile_EventType_t definition to get profile information + * that corresponds to the remote procedure call on the ARM processor + * when client invokes QnnGraph_execute or QnnGraph_executeAsync. + * The value returned is time in microseconds. + * + * @note graph execute host rpc time maybe available on both + * QNN_PROFILE_LEVEL_BASIC and QNN_PROFILE_LEVEL_DETAILED levels + */ +#define QNN_HTP_PROFILE_EVENTTYPE_GRAPH_EXECUTE_HOST_RPC_TIME_MICROSEC 3001 + +/** + * @brief QnnProfile_EventType_t definition to get profile information + * that corresponds to the remote procedure call on the HTP processor + * when client invokes QnnGraph_execute or QnnGraph_executeAsync. + * The value returned is time in microseconds. + * + * @note graph execute htp rpc time maybe available on both + * QNN_PROFILE_LEVEL_BASIC and QNN_PROFILE_LEVEL_DETAILED levels + */ +#define QNN_HTP_PROFILE_EVENTTYPE_GRAPH_EXECUTE_HTP_RPC_TIME_MICROSEC 3002 + +/** + * @brief QnnProfile_EventType_t definition to get profile information + * that corresponds to execute the graph on the accelerator + * when client invokes QnnGraph_execute or QnnGraph_executeAsync. + * The value returned is number of processor cycles taken. + * + * @note graph execute accelerator time maybe available only on + * QNN_PROFILE_LEVEL_DETAILED levels + * + * @note When QNN_PROFILE_LEVEL_DETAILED is used, this event can have + * multiple sub-events of type QNN_PROFILE_EVENTTYPE_NODE. + * There will be a sub-event for each node that was added to the graph + */ +#define QNN_HTP_PROFILE_EVENTTYPE_GRAPH_EXECUTE_ACCEL_TIME_CYCLE 3003 + +/** + * @brief QnnProfile_EventType_t definition to get profile information + * that corresponds to execute the graph on the accelerator + * when client invokes QnnGraph_execute or QnnGraph_executeAsync. + * The value indicates execute including wait/resource acquisition + * time on the accelerator, if applicable in multi-threaded scenarios. + * The value returned is time taken in microseconds + * + * @note graph execute accelerator time maybe available on both + * QNN_PROFILE_LEVEL_BASIC and QNN_PROFILE_LEVEL_DETAILED levels + * + * @note When QNN_PROFILE_LEVEL_DETAILED is used, this event can have + * multiple sub-events of type QNN_PROFILE_EVENTTYPE_NODE / QNN_PROFILE_EVENTUNIT_MICROSEC + * There will be a sub-event for each node that was added to the graph + */ +#define QNN_HTP_PROFILE_EVENTTYPE_GRAPH_EXECUTE_ACCEL_TIME_MICROSEC 3004 + +/** + * @brief QnnProfile_EventType_t definition to get profile information + * that corresponds to time taken for miscellaneous work i.e. time + * that cannot be attributed to a node but are still needed to + * execute the graph on the accelerator. This occurs when client invokes + * QnnGraph_execute or QnnGraph_executeAsync. + * The value returned is time taken in microseconds + * + * @note graph execute misc accelerator time is available only on + * QNN_PROFILE_LEVEL_DETAILED levels + */ +#define QNN_HTP_PROFILE_EVENTTYPE_GRAPH_EXECUTE_MISC_ACCEL_TIME_MICROSEC 3005 + +/** + * @brief QnnProfile_EventType_t definition to get profile information + * that corresponds to time taken for a graph yield instance to + * release all its resources to the other graph. + * The value returned is time taken in microseconds. + */ +#define QNN_HTP_PROFILE_EVENTTYPE_GRAPH_EXECUTE_YIELD_INSTANCE_RELEASE_TIME 3006 + +/** + * @brief QnnProfile_EventType_t definition to get profile information + * that corresponds to time a graph spends waiting for a higher + * priority graph to finish execution. + * The value returned is time taken in microseconds + */ +#define QNN_HTP_PROFILE_EVENTTYPE_GRAPH_EXECUTE_YIELD_INSTANCE_WAIT_TIME 3007 + +/** + * @brief QnnProfile_EventType_t definition to get profile information + * that corresponds to time a graph spends re-acquiring resources + * and restoring vtcm. + * The value returned is time taken in microseconds + */ +#define QNN_HTP_PROFILE_EVENTTYPE_GRAPH_EXECUTE_YIELD_INSTANCE_RESTORE_TIME 3008 + +/** + * @brief QnnProfile_EventType_t definition to get profile information + * that corresponds to the number of times that a yield occured + * during execution + */ +#define QNN_HTP_PROFILE_EVENTTYPE_GRAPH_EXECUTE_YIELD_COUNT 3009 + +/** + * @brief QnnProfile_EventType_t definition for time a graph waits to get + * VTCM. This should be constant UNLESS we need another graph to yield. + * The value returned is time taken in microseconds. + */ +#define QNN_HTP_PROFILE_EVENTTYPE_GRAPH_EXECUTE_VTCM_ACQUIRE_TIME 3010 + +/** + * @brief QnnProfile_EventType_t definition for time a graph waits to get + * HMX + HVX, and turn them all on. + * The value returned is time taken in microseconds. + */ +#define QNN_HTP_PROFILE_EVENTTYPE_GRAPH_EXECUTE_RESOURCE_POWER_UP_TIME 3011 + +/** + * @brief QnnProfile_EventType_t definition to get profile information + * that corresponds to execute the graph on the accelerator + * when client invokes QnnGraph_execute or QnnGraph_executeAsync. + * The value indicates execute excluding wait/resource acquisition + * time on the accelerator, if applicable in multi-threaded scenarios. + * The value returned is time taken in microseconds + * + * @note graph execute accelerator time maybe available on both + * QNN_PROFILE_LEVEL_BASIC and QNN_PROFILE_LEVEL_DETAILED levels + * + * @note When QNN_PROFILE_LEVEL_DETAILED is used, this event can have + * multiple sub-events of type QNN_PROFILE_EVENTTYPE_NODE / QNN_PROFILE_EVENTUNIT_MICROSEC + * There will be a sub-event for each node that was added to the graph + */ +#define QNN_HTP_PROFILE_EVENTTYPE_GRAPH_EXECUTE_ACCEL_EXCL_WAIT_TIME_MICROSEC 3012 + +/** + * @brief QnnProfile_EventType_t definition to get profile information + * that corresponds to the remote procedure call on the ARM processor + * when client invokes QnnContext_free which in consequence deinit graph. + * The value returned is time in microseconds. + * + * @note graph deinit host rpc time maybe available on both + * QNN_PROFILE_LEVEL_BASIC and QNN_PROFILE_LEVEL_DETAILED levels + */ +#define QNN_HTP_PROFILE_EVENTTYPE_GRAPH_DEINIT_HOST_RPC_TIME_MICROSEC 4001 + +/** + * @brief QnnProfile_EventType_t definition to get profile information + * that corresponds to the remote procedure call on the HTP processor + * when client invokes QnnContext_free which in consequence deinit graph. + * The value returned is time in microseconds. + * + * @note graph deinit htp rpc time maybe available on both + * QNN_PROFILE_LEVEL_BASIC and QNN_PROFILE_LEVEL_DETAILED levels + */ +#define QNN_HTP_PROFILE_EVENTTYPE_GRAPH_DEINIT_HTP_RPC_TIME_MICROSEC 4002 + +/** + * @brief QnnProfile_EventType_t definition to get profile information + * that corresponds to the time taken to deinit graph on the + * accelerator when client invokes QnnContext_free which in consequence + * deinit graph. The value returned is time in microseconds. + * + * @note graph deinit accelerator time maybe available on both + * QNN_PROFILE_LEVEL_BASIC and QNN_PROFILE_LEVEL_DETAILED levels + */ +#define QNN_HTP_PROFILE_EVENTTYPE_GRAPH_DEINIT_ACCEL_TIME_MICROSEC 4003 + +/** + * @brief QnnProfile_EventType_t definition to get data related to execution of + * an operation. This value represents the amount of time an op spends + * waiting for execution on the main thread since the last op on the main + * thread due to scheduling and can be interpreted appropriately in + * conjunction with the unit. + * + * @note node wait information is available on QNN_HTP_PROFILE_LEVEL_LINTING level + */ +#define QNN_HTP_PROFILE_EVENTTYPE_NODE_WAIT 5001 + +/** + * @brief QnnProfile_EventType_t definition to get data related to execution of + * an operation. This value represents the amount of time at least one + * background op is running during the execution of an op on the main thread + * and can be interpreted appropriately in conjunction with the unit. + * + * @note node overlap information is available on QNN_HTP_PROFILE_LEVEL_LINTING level + */ +#define QNN_HTP_PROFILE_EVENTTYPE_NODE_OVERLAP 5002 + +/** + * @brief QnnProfile_EventType_t definition to get data related to execution of + * an operation. This value represents the amount of time at least one + * background op that is not being waited upon to finish is running during + * the wait period of an op on the main thread and can be interpreted + * appropriately in conjunction with the unit. + * + * @note node wait overlap information is available on QNN_HTP_PROFILE_LEVEL_LINTING + * level + */ +#define QNN_HTP_PROFILE_EVENTTYPE_NODE_WAIT_OVERLAP 5003 + +/** + * @brief QnnProfile_EventType_t definition to get data related to execution of + * an operation. This value represents a bitmask denoting the resources + * an op uses. + * + * @note node specific information is available on QNN_HTP_PROFILE_LEVEL_LINTING level + */ +#define QNN_HTP_PROFILE_EVENTTYPE_NODE_RESOURCEMASK 5004 + +/** + * @brief QnnProfile_EventType_t definition to get data related to execution of + * an operation. This value represents the ID of an op running in parallel to + * an op running on the main thread or on HMX. + * + * @note node specific information is available on QNN_HTP_PROFILE_LEVEL_LINTING level + */ +#define QNN_HTP_PROFILE_EVENTTYPE_NODE_CRITICAL_BG_OP_ID 5005 + +/** + * @brief QnnProfile_EventType_t definition to get data related to execution of + * an operation. This value represents the ID of an op running on threads other + * than the main or the HMX thread when the main and the HMX threads are not + * executing any op. + * + * @note node specific information is available on QNN_HTP_PROFILE_LEVEL_LINTING level + */ +#define QNN_HTP_PROFILE_EVENTTYPE_NODE_WAIT_BG_OP_ID 5006 + +/** + * @brief QnnProfile_EventType_t definition to get profile information + * that corresponds to execute the graph's critical path on the accelerator + * when client invokes QnnGraph_execute or QnnGraph_executeAsync. + * The value returned is number of processor cycles taken. + * + * @note graph execute accelerator time maybe available only on + * QNN_HTP_PROFILE_LEVEL_LINTING levels + * + * @note When QNN_HTP_PROFILE_LEVEL_LINTING is used, this event can have + * multiple sub-events of type QNN_PROFILE_EVENTTYPE_NODE. + * There will be a sub-event for each node that was added to the graph + */ +#define QNN_HTP_PROFILE_EVENTTYPE_GRAPH_EXECUTE_CRITICAL_ACCEL_TIME_CYCLE 6001 + +/** + * @brief Linting QnnProfile_Level_t definition that allows collecting in-depth + * performance metrics for each op in the graph including main thread + * execution time and time spent on parallel background ops. + */ +#define QNN_HTP_PROFILE_LEVEL_LINTING 7001 + +/** + * @brief QnnProfile_EventType_t definition to get number of HVX threads + * configured by a graph. Different graphs can have a different + * value. + */ +#define QNN_HTP_PROFILE_EVENTTYPE_GRAPH_NUMBER_OF_HVX_THREADS 8001 + +/** + * @brief QnnProfile_EventType_t definition to get profile information + * that corresponds to applying binary section for updatable tensors + * when client invokes QnnContext_ApplyBinarySection. + * It refers to the total time the entire API takes. + * The value returned is time taken in microseconds. + */ +#define QNN_HTP_PROFILE_EVENTTYPE_GRAPH_APPLY_BINARY_SECTION_QNN 9001 + +/** + * @brief QnnProfile_EventType_t definition to get profile information + * that corresponds to applying binary section for updatable tensors + * when client invokes QnnContext_ApplyBinarySection. + * It refers to the time of callTransport. + * The value returned is time taken in microseconds. + */ +#define QNN_HTP_PROFILE_EVENTTYPE_GRAPH_APPLY_BINARY_SECTION_RPC 9002 + +/** + * @brief QnnProfile_EventType_t definition to get profile information + * that corresponds to applying binary section for updatable tensors + * when client invokes QnnContext_ApplyBinarySection. + * It refers to the remote procedure call on the HTP processor. + * The value returned is time taken in microseconds. + */ +#define QNN_HTP_PROFILE_EVENTTYPE_GRAPH_APPLY_BINARY_SECTION_QNN_ACC 9003 + +/** + * @brief QnnProfile_EventType_t definition to get profile information + * that corresponds to applying binary section for updatable tensors + * when client invokes QnnContext_ApplyBinarySection. + * It refers to the Hexnn call + * The value returned is time taken in microseconds. + */ +#define QNN_HTP_PROFILE_EVENTTYPE_GRAPH_APPLY_BINARY_SECTION_ACC 9004 + + + +#ifdef __cplusplus +} +#endif + +#endif // QNN_HTP_PROFILE_H diff --git a/qnn/jni/qnn/QNN/HTP/QnnHtpProperty.h b/qnn/jni/qnn/QNN/HTP/QnnHtpProperty.h new file mode 100644 index 00000000..51440061 --- /dev/null +++ b/qnn/jni/qnn/QNN/HTP/QnnHtpProperty.h @@ -0,0 +1,30 @@ +//============================================================================== +// +// Copyright (c) 2022 Qualcomm Technologies, Inc. +// All Rights Reserved. +// Confidential and Proprietary - Qualcomm Technologies, Inc. +// +//============================================================================== + +#ifndef QNN_HTP_PROPERTY_H +#define QNN_HTP_PROPERTY_H + +#include "QnnProperty.h" + +#ifdef __cplusplus +extern "C" { +#endif + +//============================================================================= +// Macros +//============================================================================= +/** + * @brief Property key for determining whether a backend supports unsigned pd. + */ +#define QNN_PROPERTY_CUSTOM_HTP_UNSIGNED_PD_SUPPORT QNN_PROPERTY_GROUP_CUSTOM + 1 + +#ifdef __cplusplus +} +#endif + +#endif // QNN_HTP_PROPERTY_H diff --git a/qnn/jni/qnn/QNN/HTP/QnnHtpSystemContext.h b/qnn/jni/qnn/QNN/HTP/QnnHtpSystemContext.h new file mode 100644 index 00000000..76d4d182 --- /dev/null +++ b/qnn/jni/qnn/QNN/HTP/QnnHtpSystemContext.h @@ -0,0 +1,176 @@ +//============================================================================== +// +// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +// All rights reserved. +// Confidential and Proprietary - Qualcomm Technologies, Inc. +// +//============================================================================== + +/** + * @file + * @brief QNN HTP component System Context API. + * + * The interfaces in this file work with the top level QNN + * API and supplements QnnSystemContext.h for HTP backend + */ + +#ifndef QNN_HTP_SYSTEM_CONTEXT_H +#define QNN_HTP_SYSTEM_CONTEXT_H + +#ifdef __cplusplus +extern "C" { +#endif + +//============================================================================= +// Macros +//============================================================================= +typedef enum { + // Following version with hwInfoBlobVersion as: + // - Major 0, Minor: 0, Patch: 1 + QNN_SYSTEM_CONTEXT_HTP_HW_INFO_BLOB_VERSION_V1 = 0x01, + // Unused, present to ensure 32 bits. + QNN_SYSTEM_CONTEXT_HTP_HW_INFO_BLOB_UNDEFINED = 0x7FFFFFFF +} QnnHtpSystemContext_HwInfoBlobVersion_t; + +// This struct is gets populated within a binary blob as part of hwInfoBlob in +// QnnSystemContext_BinaryInfoV#_t struct in QnnSystemContext.h +typedef struct QnnHtpSystemContext_HwBlobInfoV1 { + // This value represents the index of the list of graphs registered + // to this context as specified in QnnSystemContext_GraphInfo_t* + uint32_t graphListIndex; + // Stores the spill-fill buffer size used by each of the graphs + uint64_t spillFillBufferSize; +} QnnHtpSystemContext_HwBlobInfoV1_t; + +typedef struct { + QnnHtpSystemContext_HwInfoBlobVersion_t version; + union UNNAMED { + QnnHtpSystemContext_HwBlobInfoV1_t contextBinaryHwInfoBlobV1_t; + }; +} QnnHtpSystemContext_HwBlobInfo_t; + +typedef enum { + // Following version with GraphInfoBlobVersion as: + // - Major 0, Minor: 0, Patch: 1 + QNN_SYSTEM_CONTEXT_HTP_GRAPH_INFO_BLOB_VERSION_V1 = 0x01, + QNN_SYSTEM_CONTEXT_HTP_GRAPH_INFO_BLOB_VERSION_V2 = 0x02, + // Unused, present to ensure 32 bits. + QNN_SYSTEM_CONTEXT_HTP_GRAPH_INFO_BLOB_UNDEFINED = 0x7FFFFFFF +} QnnHtpSystemContext_GraphInfoBlobVersion_t; + +// This struct is gets populated within a binary blob as part of GraphInfoBlob in +// QnnSystemContext_BinaryInfoV1_t struct in QnnSystemContext.h +typedef struct { + // Stores the spill-fill buffer size used by each of the graphs + uint64_t spillFillBufferSize; + // HTP vtcm size (MB) + uint32_t vtcmSize; + // Optimization level + uint32_t optimizationLevel; + // Htp Dlbc + uint8_t htpDlbc; + // Number of HVX Threads to reserve; + uint64_t numHvxThreads; +} QnnHtpSystemContext_GraphBlobInfoV1_t; + +// This struct gets populated within a binary blob as part of GraphInfoBlob in +// QnnSystemContext_BinaryInfoV2_t struct in QnnSystemContext.h +/* +Note: This chart is for illustrative purposes only. ++-----------------------------+--------------------------+ +| 256G (Far Mem) | | +| | Shared (name) : 6G | +| +--------------------------+ +| | Non-Shared (name) : 1M | +| +--------------------------+ +| | I/O | +| | K Cache (name) : 100M | +| | V Cache (name) : 100M | +| +--------------------------+ +| | Free Memory : 100M | +| +--------------------------+ +| | | +| | (Far Memory starts at 4G)| ++-----------------------------+--------------------------+ +| 3.5G (Near Mem) | | +| | Shared (name) : 1G | +| +--------------------------+ +| | Non-Shared (Const): 100K | +| +--------------------------+ +| | Op Data : 100K | +| | | ++-----------------------------+--------------------------+ + + Total Memory Used: 6.701G +*/ +typedef struct { + QnnHtpSystemContext_GraphInfoBlobVersion_t version; + // Stores the nativeK channel tile size used by each of the graphs (bytes) + uint64_t nativeKChannelSize; + // Stores the nativeV channel tile size used by each of the graphs (bytes) + uint64_t nativeVChannelSize; + // The field name IsSafeShareIO indicates if it is safe to share the + // buffer between inputs and outputs. 1: True, 0: False + // Client is responsible for ensuring no clash between input and output + // when flag is set. + uint32_t isSafeShareIO; + // Stores graph input/output tensors size (bytes) + uint64_t ioTensorSize; + // Stores opdata memory/meta data size associated with ops that will be executed, + // inlcuding op data like runlists (bytes) + uint64_t opDataSize; + // Stores size of const data in the graph (bytes) + uint64_t constSize; + // Stores size of DDR-tensor (bytes) + uint64_t ddrTensorSize; + // Stores shared weights size (bytes) + uint64_t sharedWeightsSize; + +} QnnHtpSystemContext_GraphBlobInfoV2_t; + +typedef struct { + QnnHtpSystemContext_GraphInfoBlobVersion_t version; + union UNNAMED { + QnnHtpSystemContext_GraphBlobInfoV1_t contextBinaryGraphBlobInfoV1; + }; +} QnnHtpSystemContext_GraphBlobInfo_t; + +typedef enum { + // Following version with ContextInfoBlobVersion as: + // - Major 0, Minor: 0, Patch: 1 + QNN_SYSTEM_CONTEXT_HTP_CONTEXT_INFO_BLOB_VERSION_V1 = 0x01, + // Unused, present to ensure 32 bits. + QNN_SYSTEM_CONTEXT_HTP_CONTEXT_INFO_BLOB_UNDEFINED = 0x7FFFFFFF +} QnnHtpSystemContext_ContextInfoBlobVersion_t; + +typedef struct { + /// An integer representation of SocUtility::DspArch + uint32_t dspArch; +} QnnHtpSystemContext_ContextBlobInfoV1_t; + +typedef struct { + QnnHtpSystemContext_ContextInfoBlobVersion_t version; + union UNNAMED { + QnnHtpSystemContext_ContextBlobInfoV1_t contextBinaryContextBlobInfoV1; + }; +} QnnHtpSystemContext_ContextBlobInfo_t; + + +//============================================================================= +// Data Types +//============================================================================= + +//============================================================================= +// Public Functions +//============================================================================= + +//============================================================================= +// Implementation Definition +//============================================================================= + +// clang-format on +#ifdef __cplusplus +} // extern "C" +#endif + +#endif \ No newline at end of file diff --git a/qnn/jni/qnn/QNN/HTP/core/afuncs.h b/qnn/jni/qnn/QNN/HTP/core/afuncs.h new file mode 100644 index 00000000..0f17913a --- /dev/null +++ b/qnn/jni/qnn/QNN/HTP/core/afuncs.h @@ -0,0 +1,430 @@ +//============================================================================== +// +// Copyright (c) 2018, 2023 Qualcomm Technologies, Inc. +// All Rights Reserved. +// Confidential and Proprietary - Qualcomm Technologies, Inc. +// +//============================================================================== + +#ifndef AFUNCS_H +#define AFUNCS_H 1 + +#include +#include +#include "dtype.h" +#ifndef __hexagon__ +#include // for memcpy etc +#endif +// #include "asm_define.h" +#include "builtin_intrinsics.h" +#include "macros_attribute.h" + +struct tile_data { + uint8_t **addr; + uint32_t offset_t_col; + uint32_t offset_t_row; + uint32_t width; + uint32_t height; + uint32_t depth; +}; + +// Define order: .addr, .offset_t_col, .offset_t_row, .width, .height, .depth +#define TILEDATA(adrtab, next_tab_col, next_tab_row, h, w, d) \ + { \ + (uint8_t **)(adrtab), static_cast(next_tab_col), static_cast(next_tab_row), \ + static_cast(w), static_cast(h), static_cast(d) \ + } + +/*=======================================*/ +/* Auxiliary functions */ +/*=======================================*/ +#if defined(__hexagon__) +inline int32_t max_i32(int32_t a, int32_t b) +{ + return Q6_R_max_RR(a, b); +} +inline int32_t min_i32(int32_t a, int32_t b) +{ + return Q6_R_min_RR(a, b); +} +inline uint32_t max_u32(uint32_t a, uint32_t b) +{ + return Q6_R_maxu_RR(a, b); +} +inline uint32_t min_u32(uint32_t a, uint32_t b) +{ + return Q6_R_minu_RR(a, b); +} +#else +inline int32_t max_i32(int32_t a, int32_t b) +{ + return (a < b) ? b : a; +} +inline int32_t min_i32(int32_t a, int32_t b) +{ + return (a < b) ? a : b; +} +inline uint32_t max_u32(uint32_t a, uint32_t b) +{ + return (a < b) ? b : a; +} +inline uint32_t min_u32(uint32_t a, uint32_t b) +{ + return (a < b) ? a : b; +} +#endif + +[[maybe_unused]] inline ALWAYSINLINE int64_t roundf_i64(float val) +{ + // add 0.5 (with same sign as val) and then conversion to int truncates toward 0. + // values exactly halfway will round away from 0 (like roundf). + + return (int64_t)(val + copysignf(0.5f, val)); +} + +[[maybe_unused]] inline ALWAYSINLINE NN_INT32_T roundf_i32(float val) +{ + // add 0.5 (with same sign as val) and then conversion to int truncates toward 0. + // values exactly halfway will round away from 0 (like roundf). + + return (int)(val + copysignf(0.5f, val)); +} +// same thing for rounding to unsigned range; -ve inputs will give 0. +// +[[maybe_unused]] inline ALWAYSINLINE uint32_t roundf_u32(float val) +{ + // add 0.5f and then convert to uint (trunc towards 0; -ve values are clipped to 0). +#ifdef __hexagon__ + // use intrinsic since conv of -ve float to unsigned is 'undefined behaviour' in C. + return Q6_R_convert_sf2uw_R_chop(val + 0.5f); +#else + return (val < 0.5f) ? 0 : (uint32_t)(val + 0.5f); +#endif +} + +[[maybe_unused]] inline ALWAYSINLINE NN_INT32_T roundd_i32(double val) +{ + // add 0.5 (with same sign as val) and then conversion to int truncates toward 0. + // values exactly halfway will round away from 0 (like round). + + return (int)(val + copysign(0.5, val)); +} + +[[maybe_unused]] inline ALWAYSINLINE NN_INT32_T saturate_u8(NN_INT32_T val) +{ +#ifdef __hexagon__ + return Q6_R_satub_R(val); +#else + return (val < 0) ? 0 : ((val > 255) ? 255 : val); +#endif +} + +[[maybe_unused]] inline ALWAYSINLINE NN_INT32_T saturate_u16(NN_INT32_T val) +{ +#ifdef __hexagon__ + return Q6_R_satuh_R(val); +#else + return (val < 0) ? 0 : ((val > 65535) ? 65535 : val); +#endif +} + +[[maybe_unused]] static inline ALWAYSINLINE NN_INT32_T saturate_i16(NN_INT32_T val) +{ +#ifdef __hexagon__ + return Q6_R_sath_R(val); +#else + return (val < -32768) ? -32768 : ((val > 32767) ? 32767 : val); +#endif +} + +/** + * @brief low-cost frexpf (but only the exponent result); + * Generates only a few instructions on hexagon. + * + * Input must not be inf,nan, zero, or denormal. + * + * returns: + * -1 if abs(x) is in range 0.25 ... 0.249999 + * 0 if abs(x) is in range 0.5 ... 0.99999 + * 1 if abs(x) is in range 1.0 .. 1.9999 + * etc + * + * If the value -126 is returned, x is a zero or denormal; + * 129 is returned for inf or NaN. for other cases the value is the same + * as what frexpf (in math.h) generates for the exponent. + */ +[[maybe_unused]] inline ALWAYSINLINE constexpr int flt_getexp(float x) +{ + union { + float f; + uint32_t u32; + } const uu = {x}; + return ((uu.u32 >> 23u) & 0xFFu) - 126; +} + +// Specialized flt_getexp for the case where you want the value +// converted to exponent, and mantissa, with the mantissa normalized +// to a specific number of bits. + +// e.g. for 15-bit case: +// int e = flt_getexp_for_frac<15>(scale); +// float m_f = flt_ldexp(scale, 15-e); +// int m_i = roundf_i32(m_f); +// +// If you use 'flt_getexp(scale)', the value of m_f will be >= 16384.0, < 32768.0 +// *but* it could fall in range >= 32767.5, < 32768.0; in these cases m_i will be 32768. +// By using flt_getexp_for_frac<16> instead, you will get an an exponent which is larger +// by 1 for those specific cases, and then m_f will be in range > 16383.75, < 16384, +// so for those cases m_i will be rounded up to 16384; this is a more accurate representation +// than you get by saturating the result to 32767. +// +// This can be used for x values that may be negative; but in that case +// W does not count the sign. for W=15, with the +// above code you should have a value 16384 <= abs(m_i) <= 32767 +// If you want to normalize over the full signed range (i.e. -32768 <= m_i <= -16385, +// when scale < 0), use flt_getexp_for_signed_frac<16>(scale). +// +// ** In general **: +// Given an normal float x, > 0, return an exponent e such that +// x * 2^(W-e) +// .. rounded to nearest, is 'normalized' in W unsigned bits: +// >= (1<= 24, result is always the same as flt_getexp(x) +// +template // +inline int flt_getexp_for_frac(float x) +{ + static_assert(W >= 3 && W < 32); + // We want to return exponent larger by 1 + // if, and only if, the upper W bits of the mantissa are all 1 (not including the hidden + // bit) - so, add a 1 to bit 23-W of the 'uint32' image of the value; it will carry + // into the exponent field if and only if all those bits are 1. + union { + float f; + uint32_t u32; + } uu = {x}; + uu.u32 += (1u << 23u) >> W; + return flt_getexp(uu.f); +} +// This is like flt_getexp_for_frac, but for cases where you want +// a fully normalized 'signed; mantissa; e.g. +// +// int e = flt_getexp_for_signed_frac<16>(scale); +// float m_f = flt_ldexp(scale, 15-e); +// int m_i = roundf_i32(m_f); +// m_i = saturate_i16(m_i); // see note below; could be m_i = std::max(m_i, -32768) +// +// m_i will always be -32768.. -16385 (for scale < 0) +// or 16384..32767 (for scale > 0) +// +// For x > 0, the result is always the same as flt_getexp_frac(x). +// for x < 0, it is usually the same as flt_getexp(x), but sometimes +// one smaller: this happens when -x is exactly a power of two, or marginally larger. +// Those are cases where want the 'most negative' signed value -32768. +// NOTE: the saturate_i16 is needed since the rounded m_i result could be -32769; +// in such cases -32768 is sill the best available representation (better than +// -16385 with a larger +1 exponent) +// +template // W includes sign bit +inline int flt_getexp_for_signed_frac(float x) +{ + static_assert(W >= 4 && W <= 25); + // for x > 0, same effect as flt_getexp_for_frac; add 1 in bit (24-(W-1)) + // for x < 0 we subtract (1<<(24-(W-1))) + 1, so it will carry to exponent + // field and reduce by 1 in applicable cases. + // Equivalent is to add ~(1<<(24-(W-1))) modularly. + // (this 'modular add' is defence against 'sanitize' detecting overflow) + auto modular_add_u32 = [](uint32_t a, uint32_t b) -> uint32_t { + uint64_t const sum = uint64_t(a) + uint64_t(b); + return uint32_t(sum); + }; + union { + float f; + uint32_t u32; + } uu = {x}; + uint32_t constexpr fudge_bit = (1u << 23u) >> (W - 1); + uu.u32 = modular_add_u32(uu.u32, (uu.u32 & (1u << 31u)) ? ~fudge_bit : fudge_bit); + return flt_getexp(uu.f); +} + +/** + * @brief low-cost frexpf (but only the 'fraction' result); + * Generates only a few instructions on hexagon. + * + * Input must not be inf,nan, zero, or denormal. + * + * returns a value in the range [0.5, 1.0) (or in (-1.0,-0.5] when x < 0) + * such that x = flt_getmant(x) * powf2(2.0, flt_getexp(x)) + * + */ +[[maybe_unused]] inline ALWAYSINLINE constexpr float flt_getmant(float x) +{ + union { + float f; + uint32_t u32; + } uu = {x}; + uu.u32 = (uu.u32 & 0x807fffffu) | (uint32_t(126) << 23u); // force exponent = 126 + return uu.f; +} + +/** + * @brief returns the mantissa of x, as a 24-bit number + * in the range 0x800000 .. 0xFFFFFF + * + * Input must not be inf,nan, zero, or denormal. + * + * Sign is discarded. same as powf(2,24) * flt_getmant(fabsf(x)). + */ +[[maybe_unused]] inline ALWAYSINLINE constexpr int32_t flt_getfrac(float x) +{ + union { + float f; + uint32_t u32; + } const uu = {x}; + int32_t const m = (uu.u32 & 0x007fffffu) | (uint32_t(1) << 23u); + return m; +} + +// +// This 'normalizes' a float to 0.5 .. 0.9999 (sign is retained) +// Same result as the return value from frexpf, without using a function call +// Results are not valid if x is 0, denormal, or inf/nan +// +[[maybe_unused]] inline ALWAYSINLINE float flt_getfrac_norm(float x) +{ + union { + float f; + uint32_t u32; + } uu = {x}; + uu.u32 = (uu.u32 & 0x807fffffu) | (uint32_t(126) << 23u); // force exponent = 126 + return uu.f; +} +/** + * @brief low-cost 2.0*n for integer n. + * Same as powf(2.0f, iexpo) without a function call; + * + * Constraint: iexpo must be in range -126..127 + */ +[[maybe_unused]] inline ALWAYSINLINE constexpr float flt_power2(uint32_t const iexpo) +{ + uint32_t const a = (iexpo + 127) & 0xFFu; + union { + uint32_t u32; + float f; + } const uu = {a << 23u}; + return uu.f; +} +/** + * @brief low-cost ldexpf + * Same as ldexpf(val, iexpo) without a function call; + * + * Constraint: iexpo must be in range -126..127 + */ +[[maybe_unused]] inline ALWAYSINLINE constexpr float flt_ldexp(float val, int iexpo) +{ + return val * flt_power2(iexpo); +} +/** + * @brief low-cost 2.0*n for integer n. + * Same as pow(2.0d, iexpo) without a function call; + * + * Constraint: iexpo must be in range -1022..1023 + */ +[[maybe_unused]] inline ALWAYSINLINE constexpr double double_power2(uint32_t const iexpo) +{ + uint64_t const a = (iexpo + 1023) & 0x7FFu; + union { + uint64_t u64; + double d; + } const uu = {a << 52u}; + return uu.d; +} +/** + * @brief low-cost ldexpf + * Same as ldexp(val, iexpo) without a function call; + * + * Constraint: iexpo must be in range -1022..1023 + */ +[[maybe_unused]] inline ALWAYSINLINE constexpr double double_ldexp(double val, int iexpo) +{ + return val * double_power2(iexpo); +} + +/** + * @brief returns the exponent and mantissa of x, as a n-bit number + * + * Constraint: iexpo must be in range -126..127 + * Input must not be negative, inf,nan, zero, or denormal. + */ +template inline constexpr std::pair get_scalefactor(float x) +{ + union { + float f; + uint32_t u32; + } const uu = {x}; + + uint32_t inval = uu.u32; + uint32_t const mask = hnnx::safe_lshift(1, MBITS) - 1; + inval = hnnx::safe_rshift(inval + hnnx::safe_lshift(1, (24 - MBITS - 1)), + (24 - MBITS)); // possibly overflows into exponent, but that's OK. + uint32_t const m = ((inval & mask) | hnnx::safe_lshift(1u, (MBITS - 1))); + int32_t const e = int32_t(hnnx::safe_rshift(inval, (MBITS - 1)) & 0xFFu) - 126; + return {e, m}; +} + +/** + * @brief returns the parameters for scaling. + * bit 31-24: left shift amount + * bit 23-16: right shift amout + * bit 15- 0: scale factor + * + * Input must not be inf,nan, zero, negative or denormal. + * + */ +[[maybe_unused]] inline ALWAYSINLINE constexpr uint32_t get_scaling_params(float x, int max_sl, int max_sr) +{ + auto [e, m] = get_scalefactor<15>(x); + // Set a sl or sr amount to perform a multiply of 2^exponent by mantissa. + int sl = (e > 0) ? e : 0; + int sr = (e > 0) ? 0 : -e; + // The max_sl allows the addition of extra left shifts when working with small numbers having negative exponents. + // For every extra left shift, there is an offsetting right shift added so that the net right shift amount + // required from the exponent stays the same. The max_sr parameter provides a ceiling to the required offsetting + // right shifts, preventing the total right shift requirement from being large enough to erase data through shifting. + if (sl == 0 && sr > 0) { + sl = min_i32(max_sl, max_i32(max_sr - sr, 0)); + sr = sr + sl; + } + return ((uint32_t(sl) & 0x0FFu) << 24u) | ((uint32_t(sr) & 0x0FFu) << 16u) | uint32_t(m); +} + +/** + * @brief given a scale in float and a recip shift amount + * return a quantized scale multiplier and change recip shamt inplace + * + */ +inline uint32_t get_quantized_multipiler(const float scale_f, int &recip_shamt) +{ + recip_shamt = (scale_f <= 1.0f) ? 0 : flt_getexp(scale_f); + uint32_t scale = static_cast(roundf(flt_ldexp(scale_f, (31 - recip_shamt)))); + scale = (scale < 0x7fffffffu) ? scale : 0x7FFFFFFFu; + return scale; +} + +/** + * @brief given a scale in float and a recip shift amount + * return a quantized scale multiplier and change recip shamt inplace + * + */ +//Now with corrected spelling +inline uint32_t get_quantized_multiplier(const float scale_f, int &recip_shamt) +{ + return get_quantized_multipiler(scale_f, recip_shamt); +} +#endif /*AFUNCS_H*/ diff --git a/qnn/jni/qnn/QNN/HTP/core/allocator.h b/qnn/jni/qnn/QNN/HTP/core/allocator.h new file mode 100644 index 00000000..93c4b195 --- /dev/null +++ b/qnn/jni/qnn/QNN/HTP/core/allocator.h @@ -0,0 +1,237 @@ +//============================================================================== +// +// Copyright (c) Qualcomm Technologies, Inc. +// All Rights Reserved. +// Confidential and Proprietary - Qualcomm Technologies, Inc. +// +//============================================================================== + +#ifndef ALLOCATOR_H +#define ALLOCATOR_H 1 + +#include +#include +#include +#include "dtype_enum.h" +#include "weak_linkage.h" +#include "macros_attribute.h" +#include "forward_classes.h" +#include "hexagon_nn_types.h" + +enum class MemoryClass { + Plain, + TCM, + UnCached, // for spill/fill DDR + XXX_LAST_MEMORY_TYPE, + Default = Plain +}; + +PUSH_VISIBILITY(default) + +extern bool TrackedAllocError; + +class Graph; +class HexagonNNEnv; +namespace fa { +struct PoolDesc; +struct BigBuff; +struct RuntimeAllocator; +} // namespace fa +namespace hnnx { + +class Serializer; +class Deserializer; + +// some options flags (powers of 2) for calls to Tensor::allocate +enum AllocOptions { + AllocOpts_packed = 0x1 // allocation will be packed +}; + +/* + * Maybe FIXME: It seems like FancyAllocator has just about all the same interfaces as Allocator, + * is all this pimpl stuff needed, or could we just inherit Allocator and have a unique_ptr + * in our graph? + */ + +class Allocator { + public: + // MIN_ALIGN, MAX_ALIGN: + // - both must be powers of 2 + // - 8 <= MIN_ALIGN <= MAX_ALIGN + // All allocations will be aligned to at least MIN_ALIGN, both start and end of each region. + // This includes sub-allocations in memory pools. + // Alignment requests > MAX_ALIGN may be treated as MAX_ALIGN if allocated in DDR. + // + static constexpr unsigned MIN_ALIGN = 256; + static constexpr unsigned MAX_ALIGN = 256; + + // The alignment used by TCM allocation; >= MIN_ALIGN + static constexpr unsigned TCM_ALLOC_ALIGN = 2048; + + static void *vacant() { return (void *)2; } // special value for 'vacant' slot. + enum Mode { AllocVirtual, AllocPhysical, AllocTemp, AllocTempEnd, AllocComplete, LastMode = AllocComplete }; + + // AllocTemp/AllocTempEnd are used in Virtual mode, to set a 'Temp Physical' mode + // where allocation is done to physical memory, but into memory blocks which + // are discarded when we return via AllocTempEnd (So, AllocTempEnd is not possible as an actual + // current mode). + // This is intended to support nesting (multiple levels of AllocTemp; each + // AllocTempEnd discards all allocs since the matching AllocTemp; but + // currently nesting is not supported, so AllocTemp must be followed by AllocTempEnd, + // which actually takes you back to AllocVirtual + // AllocComplete allows no further allocations. A deserialized allocator + // is in this state. + + API_EXPORT Allocator(Mode mode_in, Graph &graph_in) : graph(graph_in), mode(mode_in){}; + API_EXPORT virtual ~Allocator() = 0; + + Graph &graph; + + // Either allocates enough, or dips into a buffer (and changes the buffer pointer and size parameter accordingly). + // al is an alignment parameter; it must be a power of 2 or the code below won't work. + API_EXPORT void *tracked_aligned_alloc(size_t al, size_t bytes, fa::BigBuff *const bb = nullptr); + API_EXPORT void tracked_free(void *aligned_ptr) noexcept; + + API_EXPORT virtual void allocate_n(void **arrp, size_t n, size_t block_size, size_t alignment, MemoryClass memclass, + unsigned options, DType dtype); + + // options for allocate_persistent_blocks. + // if 'allnew' is *not* present, it is assumed that all of the pointers + // are either null, or point to existing persistent blocks. The 'null' ones + // are replaced with new allocations, and the ref counts are increased in both cases. + // with 'allnew': pointers are assumed to contain garbage. Equivalent to zeroing the + // pointer table first. + // + // zoneB: with this, ref counts are update in 'B' zone instead of A + // + // incref: ovverides 'allnew'; all of the existing pointers are required to be valid persistent + // blocks; the ref counts are increased by 1 + // decref: overrides 'incref and allnew'; all of the pointers are required to be valid persistent + // blocks; the ref counts are reduced by 1. If total refs are zero, block is freed. + // the pointer table is not updated. + // + // infinite: newly alloc'd blocks get refcount set to a huge number, instead of 1. + // Currently this is used when deserializing, since we can't free things immediately when in Crate. + // + enum persistent_options { + allnew = 1u, // assume existing pointers are garbage, allocate them all. + zoneB = 2u, // reference count in zone B instead of A. + incref = 4u, // enforce that all existing are persistnent; incref them. + decref = 8u, + infinite = 16u, // refcounts on new blocks, set to a huge # instead of 1. + }; + + // allocate n 'persistent' blocks of the given size/alignment, and update the table. + API_EXPORT virtual void allocate_persistent_blocks(void **table, size_t nblocks, size_t block_size, + size_t alignment, unsigned options); + + API_EXPORT inline void *allocate(const void *oldval, size_t block_size, size_t alignment, MemoryClass memclass, + unsigned options, DType dtype) + { + PUSH_WARNING() + DISABLE_WARNING("-Wcast-qual", MSVC_NO_EQUIV) + void *tmp = const_cast(oldval); + POP_WARNING() + allocate_n(&tmp, 1, block_size, alignment, memclass, options, dtype); + return tmp; + } + + API_EXPORT Mode get_mode() const { return mode; } + API_EXPORT virtual void set_mode(Mode new_mode); + + API_EXPORT virtual void set_tcm_pool(void *base, size_t size); + + API_EXPORT virtual void set_largest_memory_alloc_size(size_t size); + + /* + * Serialize all the internal data for the allocator. + * Memory regions / pools, etc. + */ + API_EXPORT virtual void serialize(Serializer &) const; + /* + * Deserialize the allocator, restore internal data from buffer. + */ + API_EXPORT virtual void deserialize(HexagonNNEnv &env, Deserializer &dctx, + hexagon_nn_wide_address_const_t params_weights = 0U, + const size_t params_weights_length = 0, + hexagon_nn_wide_iovec_t const &weights = NULL_IOVEC); + + API_EXPORT virtual int find_replaceable_mempool(unsigned const replaceable_pool_seq, + fa::PoolDesc &found_pool) const; + + // LCOV_EXCL_START [SAFTYSWCCB-1542] + API_EXPORT static inline constexpr size_t fixup_alignment(size_t align) + { + static_assert(MIN_ALIGN >= 8 && (MIN_ALIGN & (MIN_ALIGN - 1)) == 0, "bad MIN_ALIGN"); + static_assert(MAX_ALIGN >= MIN_ALIGN && (MAX_ALIGN & (MAX_ALIGN - 1)) == 0, "bad MAX_ALIGN"); + if (MIN_ALIGN < MAX_ALIGN) { + return std::max(MIN_ALIGN, std::min(MAX_ALIGN, align)); + } else { + return MIN_ALIGN; + } + } + // LCOV_EXCL_STOP + + API_EXPORT static inline constexpr size_t round_up_align(size_t n, size_t align) + { + return (n + (align - 1)) & ~(align - 1); + } + template API_EXPORT static inline T *round_up_align(T *p, size_t align) + { + return (T *)round_up_align((size_t)p, align); + } + + protected: + Mode mode = AllocVirtual; +}; + +// +// this is s 'shim' class to help in making dummy allocators. It defines overrides +// for all of the pure-virtual methods, so you don't need to +// +class FakeAllocator : public Allocator { + public: + API_EXPORT FakeAllocator(Allocator::Mode mode_in, Graph &graph_in) : Allocator(mode_in, graph_in){}; + API_EXPORT virtual ~FakeAllocator(); +}; + +// this is an accessor which is used by the Dma 'Fill' operation +// to get a source pointer for reading const, based on (pool_id, offset). +// It also holds the base pointer for ddr spill area. +// Maybe other things could be added later. + +class MemPoolRunTimeAccessor { + hexagon_nn_wide_address_t spill_area; + fa::PoolDesc const *pool_table; // pool_table[0] is for poolid=1 + unsigned max_pool_id; + + public: + API_EXPORT MemPoolRunTimeAccessor(hexagon_nn_wide_address_const_t spill_area_in, fa::PoolDesc const *const pt, + unsigned const pt_size) + : spill_area(spill_area_in), pool_table(pt), max_pool_id(pt_size) + { + } + API_EXPORT MemPoolRunTimeAccessor() : spill_area(0), pool_table(nullptr), max_pool_id(0) {} + API_EXPORT MemPoolRunTimeAccessor(MemPoolRunTimeAccessor const &) = default; + API_EXPORT MemPoolRunTimeAccessor &operator=(MemPoolRunTimeAccessor const &) = default; + + // pool ids are >= 1, <= num_pools + API_EXPORT constexpr unsigned num_pools() const { return max_pool_id; } //LCOV_EXCL_LINE [SAFTYSWCCB-1542] + // map pool_id to base address of the data, for persistent pool; also get 'is_weights' flag. + // implementation in runtime_alloc.h + std::pair get_persistent_pool_base_iswts(unsigned pool_id) const; + API_EXPORT hexagon_nn_wide_address_t get_spill_area() const { return spill_area; } + + // used to construct the ConstExtentDescriptor during prep + // implementation in fa_alloc.h + API_EXPORT fa::PoolDesc const *get_descriptor(unsigned pool_id) const; + + // get the id of first DDR mempool + API_EXPORT unsigned get_first_ddr_pool_id() const; +}; + +} // namespace hnnx + +POP_VISIBILITY() + +#endif diff --git a/qnn/jni/qnn/QNN/HTP/core/bake_defs.h b/qnn/jni/qnn/QNN/HTP/core/bake_defs.h new file mode 100644 index 00000000..6666dcd1 --- /dev/null +++ b/qnn/jni/qnn/QNN/HTP/core/bake_defs.h @@ -0,0 +1,244 @@ +//============================================================================== +// +// Copyright (c) Qualcomm Technologies, Inc. +// All Rights Reserved. +// Confidential and Proprietary - Qualcomm Technologies, Inc. +// +//============================================================================== + +#ifndef BAKE_DEFS +#define BAKE_DEFS 1 +#include +#include +#include +#include + +#include "executable.h" + +// Contains defs for host-side and target side, so try not +// to add too many 'host only' things. + +#ifdef __hexagon__ +#define HNNX_ARCH_CAN_RUN_BAKED 1 +#endif + +namespace hnnx { + +namespace bake { + +using tgt_ptr_word = unsigned; +using tgt_sizet_word = unsigned; +static constexpr unsigned tgt_ptr_bytes = sizeof(tgt_ptr_word); +static constexpr unsigned tgt_sizet_bytes = sizeof(tgt_sizet_word); +static constexpr bool op_has_graphp = false; +static constexpr unsigned tensor_uptr_ptrs = 2; +static constexpr unsigned max_opaquet_align = 1024; // must be power of 2 + +// This should be OK as a first approx: includes hexagon and x86-32 +static constexpr bool host_can_run_baked = sizeof(void *) == tgt_ptr_bytes; + +inline unsigned constexpr round_up(unsigned x, unsigned m) +{ + return ((x + (m - 1)) / m) * m; +} + +// functions to calculate size, align of various things. They +// are included in target build so we can static_assert that sizes are what we think they are. +// (all must be constexpr). + +// {size, alignment} of typical_op +inline constexpr std::pair typical_op_tgt_size_align(unsigned n_in, unsigned n_out) +{ + // 1 pointer per input, plus tensor_uptr_ptrs per output; but if n_in = n_out == 0, it's 1 pointer. + // (for a 'fill' byte). + unsigned num_io_ptrs = n_in + n_out * tensor_uptr_ptrs; + if (num_io_ptrs == 0) num_io_ptrs = 1; // n_in = n_out = 0 case + return {tgt_ptr_bytes * ((op_has_graphp ? 2 : 1) // vptr, and maybe Graph * + + num_io_ptrs), // inputs and outputs + tgt_ptr_bytes}; // align +} + +// 'tensor_op_tgt_size_align is used for crate accounting of ShapeWrapperOp, ConstWrapperOp, DummyOp +// In a proper 'baked graph' we don't need to insert these, just the tensors... + +inline constexpr std::pair tensor_op_tgt_size_align(unsigned n_out) +{ + // happens to be the same as TypicalOp with no inputs... + return typical_op_tgt_size_align(0, n_out); +} + +// {size, alignment, extra} of typical_op_with_compiler +// extra_len is the len of the extra data +// extra_align is its alignment. +// The 3rd return value is the offset of the 'extra' within the image. +// +inline constexpr std::tuple +typical_op_extra_tgt_size_align(unsigned n_in, unsigned n_out, unsigned extra_len, unsigned extra_align) +{ + std::pair base_size = typical_op_tgt_size_align(n_in, n_out); + unsigned extra_offs = base_size.first; + if (extra_len > 0) { + extra_align = std::max(extra_align, base_size.second); + extra_len = round_up(extra_len, extra_align); + extra_offs = round_up(extra_offs, extra_align); + base_size.first = extra_offs + extra_len; + base_size.second = extra_align; + } + return {base_size.first, base_size.second, extra_offs}; +} + +// {size, alignment} of variadic op (without the in, out array contents)! +constexpr std::pair variadic_op_tgt_size_align(unsigned n_in, unsigned n_out) +{ + const unsigned cratevec_words = 2; + return {tgt_ptr_bytes * (1 // vptr + + (op_has_graphp ? 1 : 0) // Graph * + + 2 * cratevec_words), // two cratevecs + tgt_ptr_bytes}; // align +} +// {size, alignment} of simple_op_wrapper (without the in, out array contents)! +constexpr std::pair simplewrap_op_tgt_size_align(unsigned n_in, unsigned n_out) +{ + // this is just one more pointer than a variadic op... + const auto var_result = variadic_op_tgt_size_align(n_in, n_out); + return {var_result.first + tgt_ptr_bytes, var_result.second}; +} + +// {size, alignment} of a ChunkPreloadOp +constexpr std::pair chunk_preload_op_tgt_size_align() +{ + return {tgt_ptr_bytes * (1 // vptr + + (op_has_graphp ? 1 : 0) // Graph * + + 2), // ptr, len; + tgt_ptr_bytes}; // align +} + +// +// {size_align} of Shape object +// +constexpr std::pair shape_tgt_size_align(unsigned rank) +{ + // tgt_sizet_bytes * (1 + 1 + 2 * rank) = + // vtable ptr + // shapeflag flags + padding[] + // std::array dims + // std::array max_dims + // + rank = std::array pad + return {round_up(tgt_sizet_bytes * (1 + 1 + 1 + 2 * rank) + rank, tgt_sizet_bytes), tgt_sizet_bytes}; +} + +// +// {size_align} of DynamicShape object +// +constexpr std::pair dynamic_shape_tgt_size_align(const unsigned rank) +{ + // std::array dims == tgt_sizet_bytes * rank + // (shapeflag flags + padding[]) + vtable ptr + dynamic_state = (3 * tgt_sizet_bytes) + return {round_up(tgt_sizet_bytes * rank + (4 * tgt_sizet_bytes), tgt_sizet_bytes), tgt_sizet_bytes}; +} + +// +// {size_align} of interface object (may or may not be quantized) +// +constexpr std::pair interface_tgt_size_align(bool is_quantized) +{ + return {tgt_sizet_bytes + (is_quantized ? round_up(3 * 4, tgt_sizet_bytes) : 0), tgt_sizet_bytes}; +} + +// {size_align} of Tensors, of three different forms: +// +// 'general' tensor +// +constexpr std::pair tensor_general_tgt_size_align() +{ + return {tgt_sizet_bytes * 4 + 2 * tgt_ptr_bytes, tgt_sizet_bytes}; +} + +// 'shape' tensor, of given rank. +// +constexpr std::pair tensor_shape_tgt_size_align(unsigned rank) +{ + return {tgt_sizet_bytes * ((rank == 0 ? 1 : rank) + 1), tgt_sizet_bytes}; +} + +// 'scalar' tensor, need to know if the interface is 'quantized' or not +// Note, this assumes all value are <= size_t bytes. +// +constexpr std::pair tensor_scalar_tgt_size_align(bool is_quantized) +{ + const unsigned ifc_size = interface_tgt_size_align(is_quantized).first; + return {tgt_sizet_bytes * 2 + ifc_size, tgt_sizet_bytes}; +} +// sizeof OpExtraInfo on target: {long long, 2 * unsigned, char *, 4 * padbyte} +constexpr std::pair OpExtraInfo_size_align = {24, 8}; + +// The size of a SliceDispatchOp for the given number of slices. +// Currently it's always the same regardless of 'nslices'; We may introduce 'right-sized' +// value, in which case 'exact=true' will get the 'real' size; but exact = false will always +// give the full size. +constexpr std::pair slice_dispatch_op_size_align(unsigned const nslices, bool const exact = false) +{ + return {tgt_sizet_bytes * ((op_has_graphp ? 7 : 6) + Executable::MAX_OP_SLICES), tgt_sizet_bytes}; +} + +// The size of a Predicated Op +constexpr std::pair pred_op_size_align() +{ + return {tgt_sizet_bytes * ((op_has_graphp ? 5 : 4) + 3), tgt_sizet_bytes}; +} + +// this is used in e.g. +// if constexpr(host_can_run_baked) static_assert(size_align_matches(N_IN, N_OUT)); + +template constexpr bool size_align_matches(SZAL sz) +{ + return sizeof(T) == std::get<0>(sz) && alignof(T) == std::get<1>(sz); +} + +// This is a utility to check that a type T has a given size and aligment, using static_assert; +// Just need to include a call to 'do-nothing' bake::check_size_align::template check(); +// The static assert is *disabled* unless compiling on hexagon (or compatible host). +// +// It's more complex than it needs to be, since it's designed to make sure the type and +// numbers wind up in the error message, e.g. you could end up with +// error: static_assert failed due to requirement 'claimed(40) == actual(48)' "size not as claimed" +// static_assert(claimed(CLAIMED_SIZE) == actual(ACTUAL_SIZE), "size not as claimed"); +// ... note: in instantiation of function template specialization 'check_szal::check_size_align<..., ...>' +// +template struct check_size_align { + static constexpr int claimed(int K) { return K; } + static constexpr int actual(int K) { return K; } + template static constexpr bool check_size() + { + static_assert(claimed(CLAIMED_SIZE) == actual(ACTUAL_SIZE), "size not as claimed"); + return CLAIMED_SIZE == ACTUAL_SIZE; + } + template static constexpr bool check_align() + { + static_assert(claimed(CLAIMED_ALIGN) == actual(ACTUAL_ALIGN), "align not as claimed"); + return CLAIMED_ALIGN == ACTUAL_ALIGN; + } + + template static constexpr bool check() + { + bool result = true; + if constexpr (host_can_run_baked) { + result = check_size() && check_align(); + } + return result; + } +}; + +} // namespace bake + +// +// op_opaque_tgt_info must be specialized for each OpaqueT used in TypicalOpWithCompiler +// +template struct op_opaque_tgt_info { + // static constexpr unsigned length = ..; // length of the struct on target CPU + // static constexpr unsigned alignment = ... // aligbment on target CPU +}; + +} // namespace hnnx + +#endif // BAKE_DEFS diff --git a/qnn/jni/qnn/QNN/HTP/core/bfloat16.h b/qnn/jni/qnn/QNN/HTP/core/bfloat16.h new file mode 100644 index 00000000..f76321be --- /dev/null +++ b/qnn/jni/qnn/QNN/HTP/core/bfloat16.h @@ -0,0 +1,377 @@ +//============================================================================== +// +// Copyright (c) Qualcomm Technologies, Inc. +// All Rights Reserved. +// Confidential and Proprietary - Qualcomm Technologies, Inc. +// +//============================================================================== + +#ifndef BFLOAT16_H +#define BFLOAT16_H + +#include +#include +#include +#include +#include + +#include "builtin_intrinsics.h" + +#include "weak_linkage.h" +#include "macros_attribute.h" + +PUSH_VISIBILITY(default) + +struct API_EXPORT BFloat16 { + public: + constexpr BFloat16() : d(0) {} + constexpr BFloat16(float f); + constexpr BFloat16(const BFloat16 &f) : d(f.d) {} + constexpr BFloat16 &operator=(const BFloat16 &f); + constexpr BFloat16(BFloat16 &&f) = default; + constexpr BFloat16 &operator=(BFloat16 &&f) = default; + ~BFloat16() = default; + + constexpr bool is_zero() const; + constexpr bool is_neg() const; + constexpr bool is_inf() const; + constexpr bool is_nan() const; + constexpr bool is_subnorm() const; + constexpr bool is_norm() const; + constexpr bool is_finite() const; + + constexpr int16_t exp() const; + constexpr uint16_t frac() const; + constexpr uint16_t raw() const { return d; } + + static constexpr int exp_max() { return 127; } + static constexpr int exp_min() { return -126; } + static constexpr int16_t bias() { return 127; } + + static constexpr BFloat16 zero(bool neg = false); + static constexpr BFloat16 qnan(); + static constexpr BFloat16 snan(); + static constexpr BFloat16 inf(bool neg = false); + + static constexpr BFloat16 from_raw(uint16_t v); + + constexpr operator float() const; + + private: + union { + uint16_t d; + struct { + uint16_t mantissa : 7; + uint16_t exponent : 8; + uint16_t sign : 1; + }; + }; + + constexpr uint16_t sign_bit() const; + constexpr uint16_t exp_bits() const; + constexpr uint16_t frac_bits() const; + + static constexpr uint16_t make(int sign, int exp, uint32_t frac); + static constexpr uint32_t round(uint32_t v, unsigned s); + + friend API_FUNC_EXPORT BFloat16 operator-(BFloat16 a); + friend API_FUNC_EXPORT BFloat16 operator+(BFloat16 a, BFloat16 b); + friend API_FUNC_EXPORT BFloat16 operator-(BFloat16 a, BFloat16 b); + friend API_FUNC_EXPORT BFloat16 operator*(BFloat16 a, BFloat16 b); + friend API_FUNC_EXPORT BFloat16 operator/(BFloat16 a, BFloat16 b); +}; + +POP_VISIBILITY() + +inline constexpr BFloat16 BFloat16::from_raw(uint16_t v) +{ + BFloat16 f; + f.d = v; + return f; +} + +inline constexpr BFloat16::BFloat16(float f) : d(0) +{ + union U { + constexpr U(float f) : as_f32(f) {} + float as_f32; + uint32_t as_u32; + } const u(f); + + // Preserve NaN values + // The only potential NaN values that can be lost are the ones that have an exp=0xFF and a non 0 bit in the 16 lsb + if ((u.as_u32 & 0x7F80FFFF) > 0x7F800000) { + d = 0x7FA0u; // qnan + return; + } + + // BFloat uses round to nearest even + bool const neg = u.as_u32 & (uint32_t(1u) << 31u); + int const exp_extract = (u.as_u32 >> 23u) & 0xFFu; + uint32_t const frac_bits = u.as_u32 & 0x7FFFFFu; + + int const exp = exp_extract - 127; + uint32_t frac = round(frac_bits | (uint32_t(1) << 23u), 23 - 7); + d = make(neg, exp, frac); +} + +inline constexpr BFloat16 &BFloat16::operator=(const BFloat16 &f) +{ + d = f.d; + return *this; +} + +inline constexpr uint16_t BFloat16::sign_bit() const +{ + return d & 0x8000u; +} + +inline constexpr uint16_t BFloat16::exp_bits() const +{ + return d & 0x7F80u; +} + +inline constexpr uint16_t BFloat16::frac_bits() const +{ + return d & 0x7Fu; +} + +inline constexpr bool BFloat16::is_zero() const +{ + return (exp_bits() | frac_bits()) == 0x0000; +} + +inline constexpr bool BFloat16::is_neg() const +{ + return sign_bit(); +} + +inline constexpr BFloat16 BFloat16::zero(bool neg) +{ + return BFloat16::from_raw((neg) ? 0x8000u : 0x0); +} + +inline constexpr BFloat16 BFloat16::qnan() +{ + return BFloat16::from_raw(0x7FA0u); +} + +inline constexpr BFloat16 BFloat16::snan() +{ + return BFloat16::from_raw(0x7FC0u); // impl defined +} + +inline constexpr BFloat16 BFloat16::inf(bool neg) +{ + return BFloat16::from_raw((neg) ? 0xFF80u : 0x7F80u); +} + +inline constexpr BFloat16::operator float() const +{ + union U { + constexpr U(uint32_t u) : as_u32(u) {} + float as_f32; + uint32_t as_u32; + } u(static_cast(raw()) << 16); + return u.as_f32; +} + +inline constexpr bool BFloat16::is_norm() const +{ + return is_zero() || (!is_inf() && !is_nan() && !is_subnorm()); +} + +inline constexpr bool BFloat16::is_inf() const +{ + return exp_bits() == 0x7F80u && frac_bits() == 0x0u; +} + +inline constexpr bool BFloat16::is_nan() const +{ + return exp_bits() == 0x7F80u && frac_bits() != 0x0u; +} + +inline constexpr bool BFloat16::is_subnorm() const +{ + return exp_bits() == 0x0000 && frac_bits() != 0x0000; +} + +inline constexpr bool BFloat16::is_finite() const +{ + return is_norm() || is_subnorm(); +} + +inline constexpr uint16_t BFloat16::make(int sign, int exp, uint32_t frac) +{ + assert(frac > 0); +#if defined(_MSC_VER) + // HEX_COUNT_LEADING_ZERO as defined for MSVC is not a constexpr. + // This logic is testing in gtest test_bfloat16.cc if changing this code please update test. + unsigned clz = 32u; + for (unsigned i = 0; i < 32; i++) { + if (frac & (1u << (31u - i))) { + clz = i; + break; + } + } +#else + unsigned const clz = static_cast(HEX_COUNT_LEADING_ZERO(frac)); +#endif // _MSC_VER + // For a finite, normalized non-zero number, clz should be 16+(16-8) = 24. + int exp_inc = 24u - clz; + if (exp + exp_inc > exp_max()) { + // Number has a magnitude that is too large. + return BFloat16::inf(sign).raw(); + } + if (exp + exp_inc < exp_min()) { + // This number can become subnormal or zero. + // safe_rshift will hit an assert if the shift is out of range + // If we had an out of range shift, then we should just clip it to the range + // Which should cause the frac to become 0 in either case + int mask = static_cast(hnnx::get_safe_shift_mask()); + int shift_amount = exp_min() - exp - exp_inc - 1; + shift_amount = (shift_amount > mask) ? mask : shift_amount; + frac = hnnx::safe_rshift(static_cast(frac), shift_amount); + return (static_cast(sign) << 15u) | (static_cast(frac) & 0x007Fu); // 0 exp bits + } + + if (exp_inc > 0) { // exp_inc < 0 not expected for float32 to bfloat16 casting + frac = round(frac, static_cast(exp_inc)); + // Rounding can change the most significant bit, so check it again. + // unsigned const clzr = HEX_COUNT_LEADING_ZERO(frac); + // assert(clzr == 24); + // clzr can only be 24 here because this make function is only called in the instantiation of a BFloat16 from a float. + // In the current code path, there is a rounding of the fractional bits before it is passed into this make function. + // As a result, there can be at most 8 significant bits in frac variable passed to the make function, which means that clzr can only be 24 + // However, if this function is called in other places where there are different limits on the number of significant bits in the input frac + // then clzr may be 23 or 24 and that will need to be accounted for here. + } + exp += exp_inc; + exp += bias(); + return (static_cast(sign) << 15u) | (static_cast(exp) << 7u) | + (static_cast(frac) & 0x007Fu); +} + +inline constexpr uint32_t BFloat16::round(uint32_t v, unsigned s) +{ + assert(s > 0); + unsigned const out_msb = hnnx::safe_lshift(1u, (s - 1)); + if ((v & out_msb) == 0) { + // Round down. + return hnnx::safe_rshift(v, s); + } + if ((v & (out_msb - 1)) == 0) { + // It's a tie, round to even. + v = hnnx::safe_rshift(v, s); + return v & 1u ? v + 1 : v; + } + // Round up. + return hnnx::safe_rshift(v, s) + 1; +} + +inline constexpr uint16_t BFloat16::frac() const +{ + if (is_zero()) { + return 0x0u; + } + uint16_t f = frac_bits(); + if (is_norm()) f |= 1u << 7u; + return f; +} + +inline constexpr int16_t BFloat16::exp() const +{ + int16_t const e = static_cast(exp_bits() >> 7u); + return e != 0 ? e - bias() : exp_min(); +} + +PUSH_VISIBILITY(default) +template <> class API_EXPORT std::numeric_limits { + public: + static constexpr bool is_specialized = true; + static constexpr bool is_signed = true; + static constexpr bool is_integer = false; + static constexpr bool is_exact = false; + static constexpr bool has_infinity = true; + static constexpr bool has_quiet_NaN = true; + static constexpr bool has_signaling_NaN = true; + static constexpr auto has_denorm = std::denorm_present; + static constexpr bool has_denorm_loss = false; // libc++ + static constexpr auto round_style = std::round_to_nearest; + static constexpr bool is_iec559 = false; + static constexpr bool is_bounded = true; + static constexpr bool is_modulo = false; + static constexpr int digits = 8; + static constexpr int digits10 = 2; // floor((digits-1) * log10(2)) + static constexpr int max_digits10 = 4; // ceil(digits * log10(2) + 1) + static constexpr int radix = 2; + static constexpr int min_exponent = -126; + static constexpr int min_exponent10 = -37; // float32 value + static constexpr int max_exponent = 127; + static constexpr int max_exponent10 = 38; // largest finite val = 3.3895314E38 + static constexpr bool traps = false; + static constexpr bool tinyness_before = false; // libc++ + + static constexpr BFloat16 min() noexcept; // returns min positive normal + static constexpr BFloat16 lowest() noexcept; // returns true min + static constexpr BFloat16 max() noexcept; // max positive + static constexpr BFloat16 epsilon() noexcept; // step at 1.0 + static constexpr BFloat16 round_error() noexcept; // 0.5 + static constexpr BFloat16 infinity() noexcept; + static constexpr BFloat16 quiet_NaN() noexcept; + static constexpr BFloat16 signaling_NaN() noexcept; + static constexpr BFloat16 denorm_min() noexcept; // min positive denorm +}; + +POP_VISIBILITY() + +constexpr BFloat16 std::numeric_limits::min() noexcept +{ + // 0 0000 0001 0000000 + return BFloat16::from_raw(0x80u); +} + +constexpr BFloat16 std::numeric_limits::lowest() noexcept +{ + // -2^127 * (1.9921875) ; 1 1111 1110 1111 111 + return BFloat16::from_raw(0xFF7Fu); // -3.3895314E38 +} + +constexpr BFloat16 std::numeric_limits::max() noexcept +{ + return BFloat16::from_raw(0x7f7fu); +} + +constexpr BFloat16 std::numeric_limits::epsilon() noexcept +{ + // 2^-7 * (1) ; 0 01111000 0000000 + return BFloat16::from_raw(0x3C00u); // next_after_1.0 - 1.0 +} + +constexpr BFloat16 std::numeric_limits::round_error() noexcept +{ + // 2^-1 * (1) ; 0 01111110 0000000 + return BFloat16::from_raw(0x3F00u); // 0.5 +} + +constexpr BFloat16 std::numeric_limits::infinity() noexcept +{ + return BFloat16::inf(false); +} + +constexpr BFloat16 std::numeric_limits::quiet_NaN() noexcept +{ + return BFloat16::qnan(); +} + +constexpr BFloat16 std::numeric_limits::signaling_NaN() noexcept +{ + return BFloat16::snan(); +} + +constexpr BFloat16 std::numeric_limits::denorm_min() noexcept +{ + return BFloat16::from_raw(0x0001u); +} + +#endif // BFLOAT16_H diff --git a/qnn/jni/qnn/QNN/HTP/core/build_options_pub.h b/qnn/jni/qnn/QNN/HTP/core/build_options_pub.h new file mode 100644 index 00000000..2b71cc59 --- /dev/null +++ b/qnn/jni/qnn/QNN/HTP/core/build_options_pub.h @@ -0,0 +1,39 @@ +//============================================================================== +// +// Copyright (c) 2024 Qualcomm Technologies, Inc. +// All Rights Reserved. +// Confidential and Proprietary - Qualcomm Technologies, Inc. +// +//============================================================================== +#ifndef BUILD_OPTIONS_PUB_H +#define BUILD_OPTIONS_PUB_H 1 + +namespace build_options_pub { + +#ifdef WITH_OPT_DEBUG +#ifndef DEFOPT_LOG +#define DEFOPT_LOG 1 +#endif +#endif + +#ifdef DEFOPT_LOG +constexpr bool DefOptLog = true; +#else +constexpr bool DefOptLog = false; +#endif + +#ifdef DEBUG_REGISTRY +constexpr bool DebugRegistry = true; +#else +constexpr bool DebugRegistry = false; +#endif + +#ifdef PREPARE_DISABLED +static constexpr bool WITH_PREPARE = false; +#else +static constexpr bool WITH_PREPARE = true; +#endif + +} // namespace build_options_pub + +#endif // BUILD_OPTIONS_PUB_H diff --git a/qnn/jni/qnn/QNN/HTP/core/builtin_intrinsics.h b/qnn/jni/qnn/QNN/HTP/core/builtin_intrinsics.h new file mode 100644 index 00000000..3496b792 --- /dev/null +++ b/qnn/jni/qnn/QNN/HTP/core/builtin_intrinsics.h @@ -0,0 +1,247 @@ +//============================================================================== +// +// Copyright (c) 2023 Qualcomm Technologies, Inc. +// All Rights Reserved. +// Confidential and Proprietary - Qualcomm Technologies, Inc. +// +//============================================================================== + +// Compiler builtin intrinsic functions should be specified in this file + +#ifndef BUILTIN_INTRINSICS_H_ +#define BUILTIN_INTRINSICS_H_ + +#include +#include +#include +#include + +// Branch prediction +#if defined(__clang__) + +#define HEX_LIKELY(x) __builtin_expect(!!(x), 1) +#define HEX_UNLIKELY(x) __builtin_expect(!!(x), 0) + +#define HEX_ASSUME __builtin_assume +#define HEX_UNREACHABLE __builtin_unreachable + +#elif defined(_MSC_VER) + +#define HEX_LIKELY(x) (x) +#define HEX_UNLIKELY(x) (x) + +#define HEX_ASSUME __assume +#define HEX_UNREACHABLE() __assume(0) + +#elif defined(__GNUC__) +//No equivalent __builtin_assume in GNUC. Hence leaving empty. +#define HEX_ASSUME(cond) + +#define HEX_LIKELY(x) __builtin_expect(!!(x), 1) +#define HEX_UNLIKELY(x) __builtin_expect(!!(x), 0) +#define HEX_UNREACHABLE __builtin_unreachable + +#endif // defined(__clang__) + +// Overflow detection +#if defined(__clang__) || defined(__GNUC__) + +#define HEX_ADD_OVERFLOW __builtin_add_overflow +#define HEX_MUL_OVERFLOW __builtin_mul_overflow + +#elif defined(_MSC_VER) + +#include + +template static inline bool HEX_ADD_OVERFLOW(_T a, _T b, _T *out) +{ + *out = a + b; + return ((b > 0) && (a > std::numeric_limits<_T>::max() - b)) || + ((b < 0) && (a < std::numeric_limits<_T>::min() - b)); +} + +template static inline bool HEX_MUL_OVERFLOW(_T a, _T b, _T *out) +{ + *out = a * b; + return ((b > 0) && (a > std::numeric_limits<_T>::max() / b || a < std::numeric_limits<_T>::min() / b)) || + ((b < 0) && (a > std::numeric_limits<_T>::min() / b || a < std::numeric_limits<_T>::max() / b)); +} + +#endif // __clang__ + +// Count bits + +#include + +template static inline int HEX_COUNT_ONE_BIT(_T x) +{ + return std::bitset(x).count(); +} + +#define HEX_COUNT_ONE_BIT_ULL HEX_COUNT_ONE_BIT +#define HEX_COUNT_ONE_BIT_UL HEX_COUNT_ONE_BIT + +#if defined(__clang__) || defined(__GNUC__) + +#define HEX_COUNT_LEADING_ZERO __builtin_clz +#define HEX_COUNT_LEADING_ZERO_UL __builtin_clzl +#define HEX_COUNT_LEADING_ZERO_ULL __builtin_clzll + +#define HEX_COUNT_TRAILING_ZERO __builtin_ctz +#define HEX_COUNT_TRAILING_ZERO_UL __builtin_ctzl +#define HEX_COUNT_TRAILING_ZERO_ULL __builtin_ctzll + +#elif defined(_MSC_VER) + +#include + +// Returns the number of leading 0-bits in x, starting at the most significant +// bit position. If x is 0, the result is undefined. +static inline int HEX_COUNT_LEADING_ZERO_ULL(unsigned long long x) +{ + unsigned long where; + if (_BitScanReverse64(&where, x)) return static_cast(63 - where); + return 64; // Undefined behavior +} + +static inline int HEX_COUNT_LEADING_ZERO(unsigned int x) +{ + unsigned long where; + if (_BitScanReverse(&where, x)) return static_cast(31 - where); + return 32; // Undefined Behavior. +} + +static inline int HEX_COUNT_LEADING_ZERO_UL(unsigned long x) +{ + return sizeof(x) == 8 ? HEX_COUNT_LEADING_ZERO_ULL(x) : HEX_COUNT_LEADING_ZERO(static_cast(x)); +} + +// Returns the number of trailing 0-bits in x, starting at the least significant +// bit position. If x is 0, the result is undefined. +static inline int HEX_COUNT_TRAILING_ZERO_ULL(unsigned long long x) +{ + unsigned long where; + if (_BitScanForward64(&where, x)) return static_cast(where); + return 64; // Undefined Behavior. +} + +static inline int HEX_COUNT_TRAILING_ZERO(unsigned int x) +{ + unsigned long where; + if (_BitScanForward(&where, x)) return static_cast(where); + return 32; // Undefined Behavior. +} + +static inline int HEX_COUNT_TRAILING_ZERO_UL(unsigned long x) +{ + return sizeof(x) == 8 ? HEX_COUNT_TRAILING_ZERO_ULL(x) : HEX_COUNT_TRAILING_ZERO(static_cast(x)); +} + +#endif // defined(__clang__) + +// Atomic operation + +#if defined(__clang__) || defined(__GNUC__) + +#define HEX_ATOMIC_FETCH_AND_ADD __sync_fetch_and_add + +#define HEX_ATOMIC_FETCH_AND_AND __sync_fetch_and_and +#define HEX_ATOMIC_FETCH_AND_OR __sync_fetch_and_or + +#define HEX_ATOMIC_VAL_COMPARE_AND_SWAP __sync_val_compare_and_swap +#define HEX_ATOMIC_BOOL_COMPARE_AND_SWAP __sync_bool_compare_and_swap + +#elif defined(_MSC_VER) + +#include + +#define HEX_ATOMIC_FETCH_AND_ADD(_p, _v) \ + (sizeof *(_p) == sizeof(__int64) ? _InterlockedExchangeAdd64((__int64 *)(_p), (__int64)(_v)) \ + : _InterlockedExchangeAdd((long *)(_p), (long)(_v))) + +template static inline _T HEX_ATOMIC_FETCH_AND_AND(_T volatile *_p, _T _v) +{ + _InterlockedAnd((long *)_p, (long)_v); + return static_cast<_T>(*_p); +} + +template static inline _T HEX_ATOMIC_FETCH_AND_OR(_T volatile *_p, _T _v) +{ + _InterlockedOr((long *)_p, (long)_v); + return static_cast<_T>(*_p); +} + +#define HEX_ATOMIC_VAL_COMPARE_AND_SWAP(_p, _old, _new) \ + (sizeof *(_p) == sizeof(__int64) \ + ? _InterlockedCompareExchange64((__int64 *)(_p), (__int64)(_new), (__int64)(_old)) \ + : _InterlockedCompareExchange((long *)(_p), (long)(_new), (long)(_old))) + +#define HEX_ATOMIC_BOOL_COMPARE_AND_SWAP(_p, _old, _new) (HEX_ATOMIC_VAL_COMPARE_AND_SWAP(_p, _old, _new) == (_old)) + +#endif // defined(__clang__) + +namespace hnnx { + +/** + * @brief promote_shift_operand reflects the integral promotions for small integer types. + * safe_lshift/safe_rshift must be aware of these promotions, since the C++ standard only + * defines the behavior for shift operations where the RHS is between 0 and + * 1 less than the bit-width of the *promoted* type of the LHS. + */ +template struct promote_shift_operand { + typedef T type; +}; + +template <> struct promote_shift_operand { + using type = int; +}; +template <> struct promote_shift_operand { + using type = int; +}; +template <> struct promote_shift_operand { + using type = int; +}; +template <> struct promote_shift_operand { + using type = int; +}; +template <> struct promote_shift_operand { + using type = int; +}; + +template using promote_shift_operand_t = typename promote_shift_operand::type; + +// The following portable template functions are replacements for the +// built-in shift operations, << and >>, that provide the following guarantees: +// +// 1. Both the left and right operands of the shift will be treated as unsigned. +// This, by construction, prevents any undefined or implementation-defined +// behavior that may arise when shifting negative-valued expressions. +// 2. The right operand will be bit-masked in a way that guarantees +// that its value is in the range [0, bitwidth(promoted_left_operand) - 1] + +template constexpr unsigned get_safe_shift_mask() +{ + return unsigned(CHAR_BIT * sizeof(promote_shift_operand_t>>) - 1); +} + +template ()> +constexpr auto safe_lshift(T const value, S const shift_amount) +{ + static_assert(std::is_integral::value && std::is_integral::value, + "safe_lshift only makes sense for integral parameters"); + assert((static_cast(shift_amount) & ~mask) == 0 && "shift_amount is out of range"); + return value << shift_amount; +} + +template ()> +constexpr auto safe_rshift(T const value, S const shift_amount) +{ + static_assert(std::is_integral::value && std::is_integral::value, + "safe_rshift only makes sense for integral parameters"); + assert((static_cast(shift_amount) & ~mask) == 0 && "shift_amount is out of range"); + return value >> shift_amount; +} + +} // namespace hnnx + +#endif /* BUILTIN_INTRINSICS_H_ */ diff --git a/qnn/jni/qnn/QNN/HTP/core/c_tricks.h b/qnn/jni/qnn/QNN/HTP/core/c_tricks.h new file mode 100644 index 00000000..05316250 --- /dev/null +++ b/qnn/jni/qnn/QNN/HTP/core/c_tricks.h @@ -0,0 +1,21 @@ +//============================================================================== +// +// Copyright (c) 2020 Qualcomm Technologies, Inc. +// All Rights Reserved. +// Confidential and Proprietary - Qualcomm Technologies, Inc. +// +//============================================================================== + +#ifndef C_TRICKS_H +#define C_TRICKS_H 1 + +#define CTRICKS_PASTER2(A, B) A##B +#define CTRICKS_PASTER(A, B) CTRICKS_PASTER2(A, B) + +#define STRINGIFY(x) #x +#define TOSTRING(x) STRINGIFY(x) + +#define PROBABLY(x) __builtin_expect(!(!(x)), 1) +#define YEAHRIGHT(x) __builtin_expect(!(!(x)), 1) + +#endif diff --git a/qnn/jni/qnn/QNN/HTP/core/cc_pp.h b/qnn/jni/qnn/QNN/HTP/core/cc_pp.h new file mode 100644 index 00000000..c4363d8c --- /dev/null +++ b/qnn/jni/qnn/QNN/HTP/core/cc_pp.h @@ -0,0 +1,26 @@ +//============================================================================== +// +// Copyright (c) 2020 Qualcomm Technologies, Inc. +// All Rights Reserved. +// Confidential and Proprietary - Qualcomm Technologies, Inc. +// +//============================================================================== + +#ifndef CC_PP_H +#define CC_PP_H 1 + +/* + * C++ Preprocessor Definitions + */ + +#ifdef __cplusplus +#define EXTERN_C_BEGIN extern "C" { +#define EXTERN_C_END \ + } \ + ; +#else +#define EXTERN_C_BEGIN /* NOTHING */ +#define EXTERN_C_END /* NOTHING */ +#endif + +#endif diff --git a/qnn/jni/qnn/QNN/HTP/core/check_hvx.h b/qnn/jni/qnn/QNN/HTP/core/check_hvx.h new file mode 100644 index 00000000..bd12354b --- /dev/null +++ b/qnn/jni/qnn/QNN/HTP/core/check_hvx.h @@ -0,0 +1,35 @@ +//============================================================================== +// +// Copyright (c) 2022-2023 Qualcomm Technologies, Inc. +// All Rights Reserved. +// Confidential and Proprietary - Qualcomm Technologies, Inc. +// +//============================================================================== + +#include "cc_pp.h" +#include "macros_attribute.h" +#include "weak_linkage.h" + +#ifndef CHECK_HVX_H +#define CHECK_HVX_H 1 + +// +// This makes sure that we have an HVX context (or not). Does nothing on H2 or +// QuRT, but on x86, makes use of a TLS variable to do the check. +// + +#ifdef __hexagon__ + +static inline void check_hvx() {} +static inline void check_not_hvx() {} + +#else + +PUSH_VISIBILITY(default) +API_EXPORT void check_hvx(); +API_EXPORT void check_not_hvx(); +POP_VISIBILITY() + +#endif + +#endif // CHECK_HVX_H diff --git a/qnn/jni/qnn/QNN/HTP/core/conditional_default_deleter.h b/qnn/jni/qnn/QNN/HTP/core/conditional_default_deleter.h new file mode 100644 index 00000000..79ceb826 --- /dev/null +++ b/qnn/jni/qnn/QNN/HTP/core/conditional_default_deleter.h @@ -0,0 +1,85 @@ +#pragma once +/// +/// @file conditional_default_deleter.h +/// @brief Implementation of a conditional (i.e. to destroy or to not destroy +/// managed object) deleter for use with smart pointers +/// +/// Copyright (c) 2025 Qualcomm Technologies, Inc. All Rights Reserved. +/// Confidential and Proprietary - Qualcomm Technologies, Inc. +/// + +#include +namespace hnnx { +/// +/// @brief Conditional deleter for use with C++ smart pointers +/// @tparam T Type being managed by associated smart pointer +/// +template struct conditional_default_deleter { + /// + /// @brief Constructor + /// @param destroy Should instance of managed object be really destroyed via + /// standard deallocator - delete or delete[]? + /// + constexpr conditional_default_deleter(bool destroy) : _must_destroy(destroy) {} + + /// + /// @brief Copy constructor + /// @param rhs Conditional deleter instance to copy from + /// + conditional_default_deleter(conditional_default_deleter const &from) : _must_destroy(from._must_destroy) {} + + /// + /// @brief Move constructor + /// @param [in] from Instance to move from + /// @warning Required by static analyzer + /// + conditional_default_deleter(conditional_default_deleter &&from) = default; + + /// + /// @brief Copy assignment operator + /// @param [in] from Instance to copy from + /// @warning Required by static analyzer + /// + conditional_default_deleter &operator=(conditional_default_deleter const &from) = default; + + /// + /// @brief Move assignment operator + /// @details Not implemented! + /// @warning Required by static analyzer + /// + conditional_default_deleter &operator=(conditional_default_deleter &&from) = default; + + /// + /// @brief Destructor + /// + ~conditional_default_deleter() = default; + + /// + /// @brief Function operator + /// @param [in] ptr Pointer to be deleted + /// + void operator()(T *ptr) const + { + if (_must_destroy) { + delete ptr; + } + } + + /// + /// @brief Function operator + /// @param [in] ptr Array to be deleted + /// + void operator()(T ptr) const + { + if (_must_destroy) { + delete[] ptr; + } else { + } + } + + /// + /// @brief Should object managed by smart pointer be destroyed? + /// + bool const _must_destroy; +}; +}; // namespace hnnx diff --git a/qnn/jni/qnn/QNN/HTP/core/const_extent_descriptor.h b/qnn/jni/qnn/QNN/HTP/core/const_extent_descriptor.h new file mode 100644 index 00000000..0a9a8d0e --- /dev/null +++ b/qnn/jni/qnn/QNN/HTP/core/const_extent_descriptor.h @@ -0,0 +1,267 @@ +//============================================================================== +// +// Copyright (c) 2023 Qualcomm Technologies, Inc. +// All Rights Reserved. +// Confidential and Proprietary - Qualcomm Technologies, Inc. +// +//============================================================================== + +#ifndef CONST_EXTENT_DESCRIPTOR_H +#define CONST_EXTENT_DESCRIPTOR_H 1 + +#include +#include +#include +#include +#include "forward_classes.h" +#include "serialize_defs.h" +#include "pickle_header_tags.h" +#include "const_extent_shared.h" + +namespace hnnx { + +// This class is used, on both encoder and decoder, to contain a 'const extent descriptor' in its raw form, (just an array of uint32) +// and provide higher-level access to the contents. + +class ConstExtentDesc { + protected: + using table_t = std::vector; + // The 'table' may or may not contain the 'padding' section at the end; this is not accessed, + // and the serialize method will always generate the required padding. + table_t table; + // some values broken out from the header... + unsigned extab_n = 0, extab_idx = 0; // number of extents, and word index where they start + unsigned mptab_n = 0, mptab_idx = 0; // number of memory pools, and word index where they start. + unsigned desc_len = 0; // length of the entire descriptor in bytes (0 if invalid descriptor) + + bool scan_table(); // sanity check, and unpacks the above; returns true if OK. + + public: + /// + /// @brief Header + /// @details Composition of header of constant extent section ... + /// + /// 33222222 22221111 111111 + /// 10987654 32109876 54321098 76543210 + /// +--------+--------+--------+--------+ + /// | magic | 0 + /// +--------+--------+--------+--------+ + /// |hlen/4W | desc_len/64B | 1 + /// +--------+--------+--------+--------+ + /// |reserved| flags | num_extents | 2 + /// +--------+--------+--------+--------+ + /// |reserved| num_mempools | 3 + /// +--------+--------------------------+ + /// + + /// + /// @brief LSB and width of various bitfields in header + /// @warning It MUST MATCH the ASCII art of the header above! + /// + static size_t constexpr HEADER_DESC_LEN_BITFIELD_LSB = 0u; + static size_t constexpr HEADER_DESC_LEN_BITFIELD_WIDTH = 24u; + static size_t constexpr HEADER_LEN_BITFIELD_LSB = 24u; + static size_t constexpr HEADER_LEN_BITFIELD_WIDTH = 8u; + static size_t constexpr HEADER_NUM_EXTENTS_BITFIELD_LSB = 0u; + static size_t constexpr HEADER_NUM_EXTENTS_BITFIELD_WIDTH = 16u; + static size_t constexpr HEADER_FLAGS_BITFIELD_LSB = 16u; + static size_t constexpr HEADER_FLAGS_BITFIELD_WIDTH = 8u; + static size_t constexpr HEADER_NUM_MEMPOOLS_BITFIELD_LSB = 0u; + static size_t constexpr HEADER_NUM_MEMPOOLS_BITFIELD_WIDTH = 24u; + + /// + /// @brief Values for 8b flags in constant extent header + /// + static uint8_t constexpr HEADER_FLAG_RESERVED_0 = (1 << 0); + static uint8_t constexpr HEADER_FLAG_RESERVED_1 = (1 << 1); + static uint8_t constexpr HEADER_FLAG_RESERVED_2 = (1 << 2); + static uint8_t constexpr HEADER_FLAG_IS_REPLACEABLE = (1 << 3); ///< Contents are replaceable weights + static uint8_t constexpr HEADER_FLAG_IS_FAR_HINT = (1 << 4); ///< Contents maybe far + static uint8_t constexpr HEADER_FLAG_RESERVED_5 = (1 << 5); + static uint8_t constexpr HEADER_FLAG_RESERVED_6 = (1 << 6); + static uint8_t constexpr HEADER_FLAG_RESERVED_7 = (1 << 7); + + static uint8_t constexpr EXTENT_FLAGS_BITFIELD_LSB = 8u; + static uint8_t constexpr EXTENT_FLAGS_BITFIELD_WIDTH = 8u; + + /// + /// @brief Values for 8b flags in extent record + /// + static uint8_t constexpr EXTENT_FLAG_RESERVED_0 = (1 << 0); + static uint8_t constexpr EXTENT_FLAG_RESERVED_1 = (1 << 1); + static uint8_t constexpr EXTENT_FLAG_RESERVED_2 = (1 << 2); + static uint8_t constexpr EXTENT_FLAG_RESERVED_3 = (1 << 3); + static uint8_t constexpr EXTENT_FLAG_IS_FAR_HINT = (1 << 4); ///< Contents maybe far + static uint8_t constexpr EXTENT_FLAG_RESERVED_5 = (1 << 5); + static uint8_t constexpr EXTENT_FLAG_RESERVED_6 = (1 << 6); + static uint8_t constexpr EXTENT_FLAG_RESERVED_7 = (1 << 7); + + // Return from 'extent_info'. + struct extab_entry { + uint32_t extent_flags; + uint32_t align; // a power of 2, >= 64 + uint64_t offset; // offset, in bytes, from the start of the descriptor, to where the data is. + uint64_t length; // length of the data in bytes. + }; + // Return from 'mempool_info'. + // Note: if 'adjust_offset' is true, the 'offset' field from the containing extent will be added to offset, + // so that the offset is from the start of the descriptor, instead of the start of the containing extent. + struct mempool_entry { + uint32_t mempool_id; // a mempool id >=2 indicating a const mempool + uint32_t extent_id; // an extent_id, >=1 + uint64_t offset; // offset in bytes of the data from the start of the extent (see note above) + uint64_t length; // length in bytes of the data + }; + // optional name of the const_extent this descriptor corresponds to. Used for matching in weight_sharing. + std::string name = std::string{}; + + /// + /// @brief Various options for adjusting offset in mempool_info() + /// + enum offset_adjust_t { + OFFSET_ADJUST_DESC_REL = 0, ///< Adjust offset relative to descriptor (default) + OFFSET_ADJUST_FALSE = + OFFSET_ADJUST_DESC_REL, ///< Alias to descriptor-relative address (default) - i.e. dont adjust + OFFSET_ADJUST_EXTENT_REL = 1, ///< Adjust offset relative to extent + OFFSET_ADJUST_TRUE = OFFSET_ADJUST_EXTENT_REL, ///< Alias to extent-relative offset - i.e. adjust + OFFSET_ADJUST_IF_FAR, ///< Offset relative to extent if containing extent is far + }; + + ConstExtentDesc() {} + ConstExtentDesc(table_t &&table_in); + void serialize(Serializer &) const; + inline bool load_table(table_t &&table_in) + { + table = std::move(table_in); + return scan_table(); + } + + constexpr bool is_valid() const { return desc_len != 0; } + + constexpr unsigned descriptor_length() const { return desc_len; } + + constexpr unsigned num_extents() const { return extab_n; } + constexpr unsigned num_mempools() const { return mptab_n; } + + // unpack a row of the extent table + // NOTE: extent_id is 1-based, must be 1 .. num_extents() + extab_entry extent_info(unsigned extent_id) const; + + /// + /// @brief Get/unpack a mempool entry from mempool table in this constant extent + /// descriptor + /// @param [in] idx ID (1-based!) of the mempool entry to get. It is expected + /// to be in range [1...num_mempools()] + /// @param [in] adjust_offset Option to adjust offset + /// @return Valid mempool entry + /// + mempool_entry mempool_info(unsigned idx, offset_adjust_t adjust_offset = OFFSET_ADJUST_FALSE) const; + + // The ordering of the data and the descriptors is such that: + // + // (1) extent_info(1).offset >= descriptor_length() + // mempool_info(1,true).offset >= descriptor_length() + // (2) for i >=2, + // extent_info(i).offset >= extent_info(i+1).offset + extent_info(i+1).length + // mempool_info(i,true).offset >= mempool_info(1-1,true).offset + mempool_info(1-1).length + // + +#if !defined(PREPARE_DISABLED) + /// + /// @brief Memory pool record iterator + /// @details Use to iterator over records in memory pool table in constant + /// extent descriptor + /// + class mempool_iterator { + public: + using iterator_category = std::input_iterator_tag; + using value_type = ConstExtentDesc::mempool_entry; + using difference_type = std::ptrdiff_t; + using pointer = value_type *; + using reference = value_type &; + + /// + /// @brief Constructor + /// @param [in] cedesc A valid constant extent descriptor instance + /// @param [in] index Record index (zero-based!) + /// + explicit mempool_iterator(ConstExtentDesc const &cedesc, uint32_t index) : _cedesc(cedesc), _index(index) {} + + /// + /// @brief Increment record + /// @return Iterator + /// + mempool_iterator &operator++() + { + // Increment IFF valid constant extent descriptor and mempool record + // index within range + _index += (_cedesc.is_valid() && (_index < _cedesc.mptab_n)) ? 1 : 0; + return *this; + } + + /// + /// @brief Equality operator + /// @return true if iterators are equal + /// + bool operator==(mempool_iterator const &other) const { return _index == other._index; } + + /// + /// @brief Inequality operator + /// @return true if iterators are not equal + /// + bool operator!=(mempool_iterator const &other) const { return !(*this == other); } + + /// + /// @brief Dereference iterator + /// + reference operator*(); + + private: + /// + /// @brief Reference to a constant extent descriptor instance + /// @details It contains the blob representing constant extent segment + /// + ConstExtentDesc const &_cedesc; + + /// + /// @brief Current index + /// + uint32_t _index; + + /// + /// @brief Mempool record entry + /// @details It is assigned when on iterator dereference + /// + value_type _entry; + }; + + /// + /// @brief Return mempool iterator initialized to the first record + /// @return Mempool iterator + /// + mempool_iterator begin() { return mempool_iterator(*this, 0); } + + /// + /// @brief Return mempool iterator beyond the last record + /// @warning Intended to be used as a sentinel + /// @return Mempool iterator + /// + mempool_iterator end() { return mempool_iterator(*this, mptab_n); } +#endif +}; +#ifndef PREPARE_DISABLED +// Called at the end of serializing a graph, if 'const extent' mode is enabled. +// See comment in const_extent_descriptor.cc for full details. +// LCOV_EXCL_START [SAFTYSWCCB-1542] +size_t write_aligned_const_info(Graph const &gr, Serializer &sctx, unsigned buried_aux_n_words = 0); +#else +inline constexpr size_t write_aligned_const_info(Graph const &gr, Serializer const &sctx, unsigned = 0) +{ + return 0; +} +// LCOV_EXCL_STOP +#endif + +} // namespace hnnx + +#endif // CONST_EXTENT_DESCRIPTOR_H diff --git a/qnn/jni/qnn/QNN/HTP/core/const_extent_shared.h b/qnn/jni/qnn/QNN/HTP/core/const_extent_shared.h new file mode 100644 index 00000000..39c95e26 --- /dev/null +++ b/qnn/jni/qnn/QNN/HTP/core/const_extent_shared.h @@ -0,0 +1,81 @@ +//============================================================================== +// +// Copyright (c) 2024 Qualcomm Technologies, Inc. +// All Rights Reserved. +// Confidential and Proprietary - Qualcomm Technologies, Inc. +// +//============================================================================== + +#ifndef CONST_EXTENT_SHARED_H_ +#define CONST_EXTENT_SHARED_H_ + +namespace hnnx { +// definitions pertaining to the 'const extent descriptor'. + +constexpr unsigned CONST_EXTENT_DESC_MAGIC = 0x71c43c9b; +// if a const extent descriptor has a 'cbname' in it, the last 32-bit slot +// is this value. The 0x3e, 0x00 is the ">\0" at the end of the cbname +constexpr unsigned CONST_EXTENT_CBNAME_TAG = 0xebbe003e; + +// This must be a power of 2, and >= 64. +// This is effectively a 'quiet' minimum on options.serialize_const_alignment, which sets +// the actual alignment. +// It is not necessary for the decoder to know what value of alignment was used in the encoder. +constexpr unsigned CONST_EXTENT_MIN_ALIGN = 256; +// +// this is a (non-quiet) maximum on options.serialize_const_alignment +constexpr unsigned CONST_EXTENT_MAX_ALIGN = 1024 * 1024; + +/// +/// @brief Size of const extent descriptor header +/// +constexpr unsigned CONST_EXTENT_HEADER_SIZE_WORDS = 4u; +constexpr unsigned CONST_EXTENT_HEADER_SIZE_BYTES = CONST_EXTENT_HEADER_SIZE_WORDS * 4u; + +/// +/// @brief Size of an extent record +/// @details Const extent descriptor contains a table of such records +/// +constexpr unsigned CONST_EXTENT_RECORD_SIZE_WORDS = 4u; +constexpr unsigned CONST_EXTENT_RECORD_SIZE_BYTES = CONST_EXTENT_RECORD_SIZE_WORDS * 4u; + +/// +/// @brief Offset of extent record table relative to const extent descriptor +/// @details Both byte and words offsets are listed +/// +constexpr unsigned CONST_EXTENT_RECORD_TAB_OFFSET_WORDS = 4u; +constexpr unsigned CONST_EXTENT_RECORD_TAB_OFFSET_BYTES = CONST_EXTENT_RECORD_TAB_OFFSET_WORDS * 4u; + +/// +/// @brief Size of mempool record in a const extent descriptor +/// @details Both byte and word sizes are provided +/// +constexpr unsigned CONST_EXTENT_MEMPOOL_RECORD_SIZE_WORDS = 4u; +constexpr unsigned CONST_EXTENT_MEMPOOL_RECORD_SIZE_BYTES = CONST_EXTENT_MEMPOOL_RECORD_SIZE_WORDS * 4u; + +// This function is used by deserializer to help it extract the extent-desc table (as a vector) from some +// arbitrary point down the pickle. Parameter is a pointer to the first 4 words; the return value is +// 0 if the first two words do not look like CEDesc header; +// n otherwise (where 'n' is the number of 32-bit words to extract). +// +inline unsigned const_extent_hdr_check(uint32_t const *const hdrp) +{ + if (hdrp[0] != CONST_EXTENT_DESC_MAGIC) return 0; + const unsigned word0 = hdrp[1]; + const unsigned hdr_len16 = word0 >> 24u; // units of 16 bytes + const unsigned desc_len64 = word0 & 0xFFFFFFu; // units of 64 bytes + const unsigned n_extent = hdrp[2] & 0xFFFFFFu; + const unsigned n_mempool = hdrp[3] & 0xFFFFFFu; + // no. of words actually needed + const unsigned desc_words = 4 * (hdr_len16 + n_extent + n_mempool); + + // note, n_extent == n_mempool == 0 is allowed. + if (hdr_len16 == 0 || desc_len64 == 0 || n_extent > n_mempool || desc_words > desc_len64 * 16) { + return -1; + } + return desc_words; +} + +} // namespace hnnx + +#endif diff --git a/qnn/jni/qnn/QNN/HTP/core/constraints.h b/qnn/jni/qnn/QNN/HTP/core/constraints.h new file mode 100644 index 00000000..7ecdbade --- /dev/null +++ b/qnn/jni/qnn/QNN/HTP/core/constraints.h @@ -0,0 +1,146 @@ +//============================================================================== +// +// Copyright (c) 2020 Qualcomm Technologies, Inc. +// All Rights Reserved. +// Confidential and Proprietary - Qualcomm Technologies, Inc. +// +//============================================================================== + +#ifndef CONSTRAINTS_H +#define CONSTRAINTS_H + +#include "interface_defs.h" +#include "op_def.h" + +#include +#include + +namespace constraint_lib { + +/** \defgroup OptConstraint Constraint Expressions for Optimization Rules + * \ingroup OptimizationFuncs + * + * @{ + */ +//! Find the chunksize of a given tensor type in a given dimension (a constant). +/// For instance, LAYOUT_CHUNKSIZE(QUint8CroutonTensor,3) gives size_t(32) +/// +#define LAYOUT_CHUNKSIZE(TYPENAME, IDX) (TYPENAME::layout.ChunkSizes[(IDX)]) + +// some convenience wrappers... + +//! IS_FLOAT16("operand") -> bool (true if operand has Float16 output) +#define IS_FLOAT16(X) EQ(DTYPE_OF(X), DType::Float16) + +//! IS_FLOAT32("operand") -> bool (true if operand has float output) +#define IS_FLOAT32(X) EQ(DTYPE_OF(X), DType::Float32) + +//! IS_FLOAT("operand") -> bool (alias of IS_FLOAT32) +#define IS_FLOAT(X) IS_FLOAT32(X) + +//! IS_BFLOAT16("operand") -> bool (true if operand has BFloat16 output) +#define IS_BFLOAT16(X) EQ(DTYPE_OF(X), DType::BFloat16) + +//! IS_QUINT8("operand") -> bool (true if operand has 'QUInt8' output) +#define IS_QUINT8(X) EQ(DTYPE_OF(X), DType::QUInt8) + +//! IS_QINT8("operand") -> bool (true if operand has 'QInt8' output) +#define IS_QINT8(X) EQ(DTYPE_OF(X), DType::QInt8) + +//! IS_QINT16("operand") -> bool (true if operand has 'QInt16' output) +#define IS_QINT16(X) EQ(DTYPE_OF(X), DType::QInt16) + +//! IS_QUINT16("operand") -> bool (true if operand has 'QUInt16' output) +#define IS_QUINT16(X) EQ(DTYPE_OF(X), DType::QUInt16) + +//! IS_QINT32("operand") -> bool (true if operand has 'QInt32' output) +#define IS_QINT32(X) EQ(DTYPE_OF(X), DType::QInt32) +//! IS_INT32("operand") -> bool (true if operand has 'Int32' output) +#define IS_INT32(X) EQ(DTYPE_OF(X), DType::Int32) + +//! IS_INT64("operand") -> bool (true if operand has 'Int64' output) +#define IS_INT64(X) EQ(DTYPE_OF(X), DType::Int64) + +//! IS_QUANT_TYPE("operand") -> bool (true if operand has 'Quantized' output) +#define IS_QUANT_TYPE(X) OR(IS_QUINT8(X), IS_QINT8(X), IS_QINT16(X), IS_QUINT16(X), IS_QINT32(X)) +//! IS_QUANT_SIGNED("operand") -> bool (true if operand has 'Signed Quantized' output) +#define IS_QUANT_SIGNED(X) OR(IS_QINT32(X), IS_QINT16(X), IS_QINT8(X)) +//! IS_SIGNED_SYMM("operand") -> bool (true if operand has 'Signed Quantized' output with offset == 0) +#define IS_SIGNED_SYMM(X) AND(IS_QUANT_SIGNED(X), EQ(ZERO_OFFSET_OF(X), 0)) + +// The problem with IS_SIGNED_SYMM is that it tends to get used as +// AND( IS_QINT8(X), IS_SIGNED_SYMM(X)) +// which expands to X.dtype==qint8 && ( (X.dtype ==qint32 || X.dtype == .. ) && X.zero_offs == 0) +// So, use IS_QINT8_SYMM(X) etc instead. + +//! IS_QINT8_SYMM("operand") -> bool (true if operand has QINT8 output with offset == 0) +#define IS_QINT8_SYMM(X) AND(IS_QINT8(X), EQ(ZERO_OFFSET_OF(X), 0)) +//! IS_QINT16_SYMM("operand") -> bool (true if operand has QINT16 output with offset == 0) +#define IS_QINT16_SYMM(X) AND(IS_QINT16(X), EQ(ZERO_OFFSET_OF(X), 0)) +//! IS_QINT32_SYMM("operand") -> bool (true if operand has QINT32 output with offset == 0) +#define IS_QINT32_SYMM(X) AND(IS_QINT32(X), EQ(ZERO_OFFSET_OF(X), 0)) + +//! IS_FULLY_CONNECT_WEIGHT("operand") -> bool (true if operand is QUInt8 or (QInt8 and symmetrically quantized)) +#define IS_FULLY_CONNECT_WEIGHT(X) OR(IS_QUINT8(X), IS_QINT8_SYMM(X)) + +//! IS_FLOAT16_BOTH("operand", "operand") -> bool (true if both operands are FP16 type) +#define IS_FLOAT16_BOTH(X, Y) AND(IS_FLOAT16(X), IS_FLOAT16(Y)) +//! IS_FLOAT16_ALL("operand", ...) -> bool (true if all operands are FP16 type) +#define IS_FLOAT16_ALL(...) IS_DTYPE_ALL(DType::Float16, __VA_ARGS__) +//! IS_BFLOAT16_ALL("operand", ...) -> bool (true if all operands are BF16 type) +#define IS_BFLOAT16_ALL(...) IS_DTYPE_ALL(DType::BFloat16, __VA_ARGS__) +//! IS_FLOAT32_ALL("operand", ...) -> bool (true if all operands are FP32 type) +#define IS_FLOAT32_ALL(...) IS_DTYPE_ALL(DType::Float32, __VA_ARGS__) + +//! DIM_CHANNEL("operand") -> unsigned (extract depth dimension, #4) +#define DIM_CHANNEL(X) DIM_OF(X, 4) +//! DIM_DEPTH("operand") -> unsigned (extract depth dimension, #3) +#define DIM_DEPTH(X) DIM_OF(X, 3) +//! DIM_WIDTH("operand") -> unsigned (extract width dimension, #2) +#define DIM_WIDTH(X) DIM_OF(X, 2) +//! DIM_HEIGHT("operand") -> unsigned (extract height dimension, #1) +#define DIM_HEIGHT(X) DIM_OF(X, 1) +//! DIM_BATCHES("operand") -> unsigned (extract batches dimension, #0) +#define DIM_BATCHES(X) DIM_OF(X, 0) + +//! DIM_NFILTS("operand") -> unsigned (extract 'output depth' dimension from filter weights, #3) +#define DIM_NFILTS(X) DIM_OF(X, 3) +//! DIM_FILTDEPTH("operand") -> unsigned (extract 'input depth' dimension from filter weights, #2) +#define DIM_FILTDEPTH(X) DIM_OF(X, 2) +//! DIM_FILTWIDTH("operand") -> unsigned (extract 'filter width' dimension from filter weights, #1) +#define DIM_FILTWIDTH(X) DIM_OF(X, 1) +//! DIM_FILTHEIGHT("operand") -> unsigned (extract 'filter height' dimension from filter weights, #0) +#define DIM_FILTHEIGHT(X) DIM_OF(X, 0) + +#define MAX_SPARSE_ELEMENTS(X) DIM_OF(X, (MAX_DIMENSIONS - 1)) + +//! IS_EMPTY_DIM("operand", dim) -> bool (true if size of dim is 0) +#define IS_EMPTY_DIM(X, DIM) EQ(DIM_OF(X, DIM), 0) + +//! IS_EMPTY("operand") -> bool (true if size of all dims is 0) +#define IS_EMPTY(X) AND(IS_EMPTY_DIM(X, 0), IS_EMPTY_DIM(X, 1), IS_EMPTY_DIM(X, 2), IS_EMPTY_DIM(X, 3)) + +#define IS_FP16_BF16_1(X) OR(EQ(DTYPE_OF(X), DType::Float16), EQ(DTYPE_OF(X), DType::BFloat16)) + +#define IS_FP16_BF16_2(X, Y) \ + OR(AND(EQ(DTYPE_OF(X), DType::Float16), EQ(DTYPE_OF(Y), DType::Float16)), \ + AND(EQ(DTYPE_OF(X), DType::BFloat16), EQ(DTYPE_OF(Y), DType::BFloat16))) + +#define IS_FP16_BF16_3(X, Y, Z) \ + OR(AND(EQ(DTYPE_OF(X), DType::Float16), EQ(DTYPE_OF(Y), DType::Float16), EQ(DTYPE_OF(Z), DType::Float16)), \ + AND(EQ(DTYPE_OF(X), DType::BFloat16), EQ(DTYPE_OF(Y), DType::BFloat16), EQ(DTYPE_OF(Z), DType::BFloat16))) + +#define IS_FP16_BF16_4(W, X, Y, Z) \ + OR(AND(EQ(DTYPE_OF(W), DType::Float16), EQ(DTYPE_OF(X), DType::Float16), EQ(DTYPE_OF(Y), DType::Float16), \ + EQ(DTYPE_OF(Z), DType::Float16)), \ + AND(EQ(DTYPE_OF(W), DType::BFloat16), EQ(DTYPE_OF(X), DType::BFloat16), EQ(DTYPE_OF(Y), DType::BFloat16), \ + EQ(DTYPE_OF(Z), DType::BFloat16))) + +#define IS_FP16_BF16_N(_1, _2, _3, _4, NAME, ...) NAME +#define IS_FP16_BF16(...) \ + IS_FP16_BF16_N(__VA_ARGS__, IS_FP16_BF16_4, IS_FP16_BF16_3, IS_FP16_BF16_2, IS_FP16_BF16_1)(__VA_ARGS__) + +} // namespace constraint_lib +/** @} */ + +#endif diff --git a/qnn/jni/qnn/QNN/HTP/core/conversions.h b/qnn/jni/qnn/QNN/HTP/core/conversions.h new file mode 100644 index 00000000..4cb348c6 --- /dev/null +++ b/qnn/jni/qnn/QNN/HTP/core/conversions.h @@ -0,0 +1,609 @@ +//============================================================================== +// +// Copyright (c) 2018 Qualcomm Technologies, Inc. +// All Rights Reserved. +// Confidential and Proprietary - Qualcomm Technologies, Inc. +// +//============================================================================== + +#ifndef CONVERSIONS_H +#define CONVERSIONS_H + +#include +#include +#include +#include +#include + +#include "builtin_intrinsics.h" + +#ifdef __hexagon__ +#include "hexagon_protos.h" +#endif + +#include "float16.h" + +#if defined(__clang__) +#define ATTR_NO_SANITIZE(CATEGORY) __attribute__((no_sanitize(CATEGORY))) +#else +#define ATTR_NO_SANITIZE(CATEGORY) /*empty */ +#endif + +namespace hnnx { + +namespace scast { + +// for a given floating type F, and a integer type TI, +// intrange_within_float::max() +// generates the largest value representable in type F which will fit into TI without overflow. +// in many cases this is F(std::numeric_limits::max()), +// but there are exceptions when the mantissa of F is narrower than TI; in those cases we +// want the representable value which is smaller than the integer's max value, not the nearest: +// F TI +// Float16 int16 32752.0 (0x7ff0) +// Float15 uint16 65504.0 (0xffe0) +// float int32 2147483520.0 (0x7fffff80) +// float uint32 4294967040.0 (0xFFFFFF00) +// float int64 9.223371487e18 (0x7fff_ff80_0000_0000) +// float uint64 1.844674297e+19 (0xFFFF_FF00__0000_0000) +// double int64 9223372036854774784.0 (0x7FFF_FFFF_FFFF_FC00) +// double uint64 18446744073709549568.0 (0xFFFF_FFFF_FFFF_F800) +// +// All of the 'min' limits are zero or powers of 2, so those can be converted +// directly from std::numeric_limits::min() +// +// +template struct intrange_within_float { +}; + +// LCOV_EXCL_START [SAFTYSWCCB-1736] constexprs resolved during compile time +template struct intrange_within_float { + static_assert(std::numeric_limits::is_integer); + static inline constexpr Float16 max() + { + if constexpr (sizeof(TI) < 2) { + return Float16(std::numeric_limits::max()); + } else if constexpr (sizeof(TI) == 2) { + return std::numeric_limits::is_signed ? Float16(32752.0f) : Float16(65504.0f); + } else { + return std::numeric_limits::is_signed ? Float16(-65504.0f) : Float16(65504.0f); + } + } + // 'min' value of integer range is always exactly representable + static inline constexpr Float16 min() { return Float16(std::numeric_limits::min()); } +}; + +template struct intrange_within_float { + static_assert(std::numeric_limits::is_integer); + static inline constexpr float max() + { + if constexpr (sizeof(TI) < 4) { + return float(std::numeric_limits::max()); + } else if constexpr (sizeof(TI) == 4) { + return std::numeric_limits::is_signed ? 2147483520.0f : 4294967040.0f; + } else { + static_assert(sizeof(TI) == 8); + return std::numeric_limits::is_signed ? 9.223371487e18f : 1.844674297e+19f; + } + } + // 'min' value of integer range is always exactly representable + static inline constexpr float min() { return float(std::numeric_limits::min()); } +}; + +template struct intrange_within_float { + static_assert(std::numeric_limits::is_integer); + static inline constexpr double max() + { + if constexpr (sizeof(TI) < 8) { + return double(std::numeric_limits::max()); + } else { + static_assert(sizeof(TI) == 8); + return std::numeric_limits::is_signed ? 9223372036854774784.0 : 18446744073709549568.0; + } + } + // 'min' value of integer range is always exactly representable + static inline constexpr float min() { return double(std::numeric_limits::min()); } +}; +// LCOV_EXCL_STOP + +template struct satcast_helper { + static_assert(std::numeric_limits::is_specialized && std::numeric_limits::is_specialized); + static inline TOUT constexpr op(TIN val) + { + if constexpr (!std::numeric_limits::is_integer) { // convert to a float + return TOUT(val); + } else { + constexpr bool OUTS = std::numeric_limits::is_signed; + if constexpr (std::numeric_limits::is_integer) { + // integer to integer. + // widening? or same width, same signedness? + constexpr bool INS = std::numeric_limits::is_signed; + if (sizeof(TOUT) > sizeof(TIN) || (sizeof(TOUT) == sizeof(TIN) && OUTS == INS)) { + // if the output is unsigned and the input < 0, return 0 + // otherwise it's a normal cast. + return (!OUTS && INS && val < 0) ? TOUT(0) : TOUT(val); + } else if (sizeof(TOUT) == sizeof(TIN)) { + if (!OUTS) { // same size, different signs + return (val < 0) ? (TOUT)0 : (TOUT)val; // signed->unsigned + } else { + constexpr TIN lim = std::numeric_limits::max(); + return (val > lim) ? (TOUT)lim : (TOUT)val; + } + } else { + // narrowing conversion + if (!OUTS) { + constexpr TIN m = std::numeric_limits::max(); + return (val < 0) ? TOUT(0) : (val > m) ? TOUT(m) : TOUT(val); + } else { + constexpr TIN mn = INS ? std::numeric_limits::min() : 0; + constexpr TIN mx = std::numeric_limits::max(); + return (val < mn) ? TOUT(mn) : (val > mx) ? TOUT(mx) : TOUT(val); + } + } + } else { // float to integer + if constexpr (sizeof(TOUT) <= sizeof(int32_t)) { + if constexpr (OUTS) { + constexpr TIN loval = intrange_within_float::min(); + constexpr TIN hival = intrange_within_float::max(); + int32_t const tmp = (int32_t)std::max(loval, std::min(hival, val)); + return satcast_helper::op(tmp); + } else { + constexpr TIN loval = 0.0; + constexpr TIN hival = intrange_within_float::max(); + uint32_t const tmp = (uint32_t)std::max(loval, std::min(hival, val)); + return satcast_helper::op(tmp); + } + } else { // 64-bit output assumed + constexpr TIN loval = intrange_within_float::min(); + constexpr TIN hival = intrange_within_float::max(); + return (TOUT)std::max(loval, std::min(hival, val)); + } + } + } + } +}; +// specialize for conversion to same +template struct satcast_helper { + static_assert(std::numeric_limits::is_specialized); + static inline TT constexpr op(TT val) { return val; } +}; + +#ifdef __hexagon__ + +// saturate to types <= int. +template struct q6_sat_int { +}; +template <> struct q6_sat_int { + static inline int op(int x) { return Q6_R_satb_R(x); } +}; +template <> struct q6_sat_int { + static inline int op(int x) { return Q6_R_satub_R(x); } +}; +template <> struct q6_sat_int { + static inline int op(int x) { return Q6_R_sath_R(x); } +}; +template <> struct q6_sat_int { + static inline int op(int x) { return Q6_R_satuh_R(x); } +}; + +// TODO: these should be done again for 'long' if long is also 32 bits. +#if 0 // NOTE: we can't really do this unless intrinsics are constexpr +template <> struct satcast_helper { + static inline uint8_t /*constexpr*/ op(int val) + { + return Q6_R_satub_R(val); + } +}; +template <> struct satcast_helper { + static inline int8_t /*constexpr*/ op(int val) { return Q6_R_satb_R(val); } +}; +template <> struct satcast_helper { + static inline uint16_t /*constexpr*/ op(int val) + { + return Q6_R_satuh_R(val); + } +}; +template <> struct satcast_helper { + static inline int16_t /*constexpr*/ op(int val) { return Q6_R_sath_R(val); } +}; +#endif + +#endif +} // end namespace scast + +} // namespace hnnx + +/** + * @brief saturate_cast( TIN val ) will work on any two numeric types; + * if the input is outside the numeric range of the output type, it + * will be range-limited. + * + * it works as follows: + * * if TOUT is a floating type, the operation is the same as the C++ cast. + * * if TOUT is integer and TIN is float, the input is first converted + * to one of int32,uint32, int64, uint64 ensuring that out-of-range values + * are clipped; and then converted to the output type as below (if it is smaller + * than 32 bits) (The 2-step conversion is intended to work well when things + * are specialized to support native hexagon ops). + * * Otherwise they are both integers. + * - If the output width is larger than the input (or if they are the same size + * and of the same signedness): + * * if the output is unsigned, and the input is < 0, the result is zero + * * otherwise the result is the same as a C++ cast (all values representable) + * - Otherwise, it is a saturating cast; values are limited to the range of TOUT. + */ +template inline constexpr TOUT saturate_cast(TIN val) +{ + return hnnx::scast::satcast_helper::op(val); +} + +/** + * @brief T saturate_round( float val ) + * round val to nearest int, and saturate to range of T. + * + * T must be an integer type, at most 32 bits. + */ +// For general C platform, we need to clip the range before converting to int; +// for hexagon the conversions saturate. +// +#ifndef __hexagon__ +template inline TOUT saturate_round(float val) +{ + static_assert(sizeof(TOUT) <= 8 && std::numeric_limits::is_integer); + return saturate_cast(std::nearbyintf(val)); +} + +#else +template inline TOUT saturate_round(float val) +{ + static_assert(sizeof(TOUT) <= 8 && std::numeric_limits::is_integer); + if constexpr ((sizeof(TOUT) == 8) && !std::numeric_limits::is_signed) { + // convert to unsigned u64, rounding, saturating + return Q6_P_convert_sf2ud_R(val); + } else if constexpr ((sizeof(TOUT) == 8) && std::numeric_limits::is_signed) { + // convert to int64, rounding + return Q6_P_convert_sf2d_R(val); + } else if constexpr ((sizeof(TOUT) == 4) && !std::numeric_limits::is_signed) { + // convert to unsigned u32, rounding, saturating + return Q6_R_convert_sf2uw_R(val); + } else { + // convert to int32,rounding; + int const r = Q6_R_convert_sf2w_R(val); + if constexpr (sizeof(TOUT) < 4) return static_cast(hnnx::scast::q6_sat_int::op(r)); + return static_cast(r); // LCOV_EXCL_LINE [SAFTYSWCCB-1736] + } +} +#endif + +namespace hnnx { + +/** + * @brief 'proper' compare of any two integer types + * proper_gt( a, b) => a > b; + * E.g. if a is unsigned and b is signed, the operation checks to see if b is < 0; + * if so, the result is true; otherwise an unsigned compare is done: a > (unsigned)b + * + */ +namespace prpercmp { + +/** + * @brief if both A and B are either *int*, or smaller than int, + * then promote them both to int and compare them. + * + * otherwise, if TA is wider than TB, (or the same, with TA unsigned): + * promote b to TA, and then compare them. + * Exception, if TA is unsigned and TB is signed and b < 0; then a struct proper_cmp_helper { + static_assert(std::numeric_limits::is_integer && std::numeric_limits::is_integer); + static const bool ASIGNED = std::numeric_limits::is_signed; + static const bool BSIGNED = std::numeric_limits::is_signed; + + // compare by promoting both to int, when... + static const bool CMP_AS_INT = (sizeof(TA) < sizeof(int) || (sizeof(TA) == sizeof(int) && ASIGNED)) && + (sizeof(TB) < sizeof(int) || (sizeof(TB) == sizeof(int) && BSIGNED)); + // otherwise, compare by promoting B to A when ... + static const bool B_TO_A = sizeof(TA) > sizeof(TB) || (sizeof(TA) == sizeof(TB) && !ASIGNED); + // otherwise, compare by promoting A to B + + static inline bool constexpr eq(TA a, TB b) + { + if (CMP_AS_INT) { + return (int)a == (int)b; + } else if (B_TO_A) { + if (!ASIGNED && BSIGNED && b < 0) return false; + return a == (TA)b; + } else { + if (!BSIGNED && ASIGNED && a < 0) return false; + return (TB)a == b; + } + } + static inline bool constexpr lt(TA a, TB b) + { + if (CMP_AS_INT) { + return (int)a < (int)b; + } else if (B_TO_A) { + if (!ASIGNED && BSIGNED && b < 0) return false; // a < b always false if b<0 + return a < (TA)b; + } else { + if (!BSIGNED && ASIGNED && a < 0) return true; // a < b always true if a<0 + return (TB)a < b; + } + } +}; +/** + * @brief specialize for comparison to same type + */ +template struct proper_cmp_helper { + static_assert(std::numeric_limits::is_integer); + static inline bool constexpr eq(T a, T b) { return a == b; } + static inline bool constexpr lt(T a, T b) { return a < b; } +}; + +} // end namespace prpercmp + +} // namespace hnnx + +/** + * @brief 'proper' compare of any two integer types, respecting signedness and actual numeric value. + * proper_eq(a,b) => a == b; + * + * E.g. if a is signed and <0, and b is unsigned, result will always be false. + * + */ + +template inline bool constexpr proper_eq(TA a, TB b) +{ + return hnnx::prpercmp::proper_cmp_helper::eq(a, b); +} +/** + * @brief 'proper' compare of any two integer types, respecting signedness and actual numeric value + * proper_ne(a,b) => !proper_eq(a,b); + */ +template inline bool constexpr proper_ne(TA a, TB b) +{ + return !hnnx::prpercmp::proper_cmp_helper::eq(a, b); +} +/** + * @brief 'proper' compare of any two integer types, respecting signedness and actual numeric value + * proper_lt(a,b) => a inline bool constexpr proper_lt(TA a, TB b) +{ + return hnnx::prpercmp::proper_cmp_helper::lt(a, b); +} +/** + * @brief 'proper' compare of any two integer types, respecting signedness and actual numeric value + * proper_ge(a,b) => a>=b; + */ +template inline bool constexpr proper_ge(TA a, TB b) +{ + return !hnnx::prpercmp::proper_cmp_helper::lt(a, b); +} +/** + * @brief 'proper' compare of any two integer types, respecting signedness and actual numeric value + * proper_gt(a,b) => a>b; + */ +template inline bool constexpr proper_gt(TA a, TB b) +{ + return hnnx::prpercmp::proper_cmp_helper::lt(b, a); +} +/** + * @brief 'proper' compare of any two integer types, respecting signedness and actual numeric value + * proper_le(a,b) => a<=b; + */ +template inline bool constexpr proper_le(TA a, TB b) +{ + return !hnnx::prpercmp::proper_cmp_helper::lt(b, a); +} +/** + * @brief x >= lo && x < limit, using proper compares + */ +template inline bool constexpr proper_inrange(TA x, TB lo, TC limit) +{ + return proper_ge(x, lo) && proper_lt(x, limit); +} + +/** + * @brief x >= lo && x <= hi, using proper compares + */ +template inline bool constexpr proper_inrange_closed(TA x, TB lo, TC hi) +{ + return proper_ge(x, lo) && proper_le(x, hi); +} + +/** + * @brief find the 'width' of an unsigned value (# of bits needed to contain it) + * this is floor( log2(x))+1 + * (and 0 when x = 0) + * + */ +inline int constexpr binary_bitwidth(unsigned x) +{ + return (x == 0) ? 0 : (sizeof(unsigned) * 8 - HEX_COUNT_LEADING_ZERO(x)); +} +/** + * @brief find the 'width' of an unsigned long value (# of bits needed to contain it) + * this is floor( log2(x))+1 + * (and 0 when x = 0) + * + */ +inline int constexpr binary_bitwidth(unsigned long x) +{ + return (x == 0) ? 0 : (sizeof(unsigned long) * 8 - HEX_COUNT_LEADING_ZERO_UL(x)); +} +/** + * @brief find the 'width' of an unsigned long long value (# of bits needed to contain it) + * this is floor( log2(x))+1 + * (and 0 when x = 0) + * + */ +inline int constexpr binary_bitwidth(unsigned long long x) +{ + return (x == 0) ? 0 : (sizeof(unsigned long long) * 8 - HEX_COUNT_LEADING_ZERO_ULL(x)); +} +/** + * @brief saturating u32+u32 add + */ +inline uint32_t /*constexpr*/ addu32_sat(uint32_t a, uint32_t b) +{ + uint64_t const prod = (uint64_t)a + b; + return saturate_cast(prod); +} + +/** + * @brief saturating i32+i32 add + */ +inline int32_t /*constexpr*/ addi32_sat(int32_t a, int32_t b) +{ +#ifdef __hexagon__ + return Q6_R_add_RR_sat(a, b); +#else + int64_t prod = (int64_t)a + b; + return saturate_cast(prod); +#endif +} + +/** + * @brief saturating u32xu32 multiply + */ +inline uint32_t constexpr mulu32_sat(uint32_t a, uint32_t b) +{ + uint64_t const prod = (uint64_t)a * b; + return saturate_cast(prod); +} + +/** + * @brief saturating i32xi32 multiply + */ +inline int32_t constexpr muli32_sat(int32_t a, int32_t b) +{ + int64_t const prod = (int64_t)a * b; + return saturate_cast(prod); +} + +/** + * @brief saturating u64xu64 multiply + */ +inline uint64_t /*constexpr*/ mulu64_sat(uint64_t a, uint64_t b) +{ + uint64_t prod = 0; + if (HEX_MUL_OVERFLOW(a, b, &prod)) { + prod = std::numeric_limits::max(); + } + return prod; +} + +/** + * @brief saturating i64xi64 multiply + */ +inline int64_t /*constexpr*/ muli64_sat(int64_t a, int64_t b) +{ + int64_t prod = 0; + if (HEX_MUL_OVERFLOW(a, b, &prod)) { + prod = (int64_t(uint64_t(a) ^ uint64_t(b)) >= 0) ? std::numeric_limits::max() + : std::numeric_limits::min(); + } + return prod; +} +/** + * @brief add unsigned+unsigned->unsigned, escaping 'unsigned overflow' checks + */ +ATTR_NO_SANITIZE("unsigned-integer-overflow") +inline unsigned constexpr addu32_modular(unsigned a, unsigned b) +{ + return a + b; +} +/** + * @brief subtract unsigned-unsigned->unsigned, escaping 'unsigned overflow' checks + * For '-unsigned_var', use subu32_modular(0,unsigned_var) + */ +ATTR_NO_SANITIZE("unsigned-integer-overflow") +inline unsigned constexpr subu32_modular(unsigned a, unsigned b) +{ + return a - b; +} +/** + * @brief multiply unsigned*unsigned->unsigned, escaping 'unsigned overflow' checks + */ +ATTR_NO_SANITIZE("unsigned-integer-overflow") +inline unsigned constexpr mulu32_modular(unsigned a, unsigned b) +{ + return a * b; +} +/** + * @brief mul-add u32*u32+u32->u32, escaping 'unsigned overflow' checks + */ +ATTR_NO_SANITIZE("unsigned-integer-overflow") +inline unsigned constexpr muladdu32_modular(unsigned a, unsigned b, unsigned c) +{ + return a * b + c; +} + +/** + * @brief add u64+u64->u64, escaping 'unsigned overflow' checks + */ +ATTR_NO_SANITIZE("unsigned-integer-overflow") +inline uint64_t constexpr addu64_modular(uint64_t a, uint64_t b) +{ + return a + b; +} + +/** + * @brief subtract u64-u64->u64, escaping 'unsigned overflow' checks + */ +ATTR_NO_SANITIZE("unsigned-integer-overflow") +inline uint64_t constexpr subu64_modular(uint64_t a, uint64_t b) +{ + return a - b; +} +/** + * @brief mul u64*u64->u64, escaping 'unsigned overflow' checks + */ +ATTR_NO_SANITIZE("unsigned-integer-overflow") +inline uint64_t constexpr mulu64_modular(uint64_t a, uint64_t b) +{ + return a * b; +} + +/** + * @brief 'image' conversion from TIN to TOUT (which must be the same size) + * e.g. image_convert( 1.25f) -> 0x3fa00000 + */ + +template inline constexpr TOUT image_convert(TIN x) +{ + static_assert(sizeof(TOUT) == sizeof(TIN)); + static_assert(std::is_trivially_copyable_v); + static_assert(std::is_trivially_copyable_v); + static_assert(std::is_trivially_constructible_v); + TOUT out; + std::memcpy(&out, &x, sizeof(TOUT)); + return out; +} + +// round up A to a multiple of B. +// b is expected to be > 0 even if signed. + +template inline constexpr size_t round_up(size_t a, TD b) +{ + static_assert(std::is_integral_v, "round_up can only apply to integer types"); + // for b being a power of 2, this should compile as (a+(b-1)) &~(b-1) + return b * ((a + (b - 1)) / b); +} +// for int, b is expected to be > 0; +// this will work for negative a, e.g. round_up(-53,10) -> -50 +template inline constexpr size_t round_up(int a, TD b) +{ + static_assert(std::is_integral_v, "round_up can only apply to integer types"); + int const bi = b; + int const tmp = a + ((a > 0) ? (bi - 1) : 0); + return bi * (tmp / bi); +} + +#endif /*CONVERSIONS_H*/ diff --git a/qnn/jni/qnn/QNN/HTP/core/cost.h b/qnn/jni/qnn/QNN/HTP/core/cost.h new file mode 100644 index 00000000..8f0b21cc --- /dev/null +++ b/qnn/jni/qnn/QNN/HTP/core/cost.h @@ -0,0 +1,38 @@ +//============================================================================== +// +// Copyright (c) 2020 Qualcomm Technologies, Inc. +// All Rights Reserved. +// Confidential and Proprietary - Qualcomm Technologies, Inc. +// +//============================================================================== + +#ifndef COST_H +#define COST_H 1 + +// NOTE: WHATCOST may be something like SNAIL/128 +#define COST_OF(FUNC, WHATCOST) COST_OF_OP(typename DerivedType<(FUNC)>::type, WHATCOST) +#define COST_OF_F(FUNC, WHATCOSTFN) COST_OF_OP_F(typename DerivedType<(FUNC)>::type, WHATCOSTFN) + +#ifdef PREPARE_DISABLED +#define COST_OF_OP(OP, WHATCOST) +#define COST_OF_OP_F(OP, WHATCOSTFN) +#else +#define COST_OF_OP(OP, WHATCOST) \ + template <> [[maybe_unused]] constexpr hnnx::cost_function_t hnnx::get_costf() \ + { \ + return hnnx::cost_function_t(float(StandardCosts::WHATCOST)); \ + } + +#define COST_OF_OP_F(OP, WHATCOSTFN) \ + template <> \ + float hnnx::cost_function_t::cfunc(hnnx::cost_function_t const &, const Graph &graph_in, const Op *op) \ + { \ + return WHATCOSTFN(graph_in, op); \ + } \ + template <> [[maybe_unused]] constexpr hnnx::cost_function_t hnnx::get_costf() \ + { \ + return hnnx::cost_function_t(hnnx::cost_function_t::cfunc, 1.0); \ + } +#endif + +#endif diff --git a/qnn/jni/qnn/QNN/HTP/core/cost_funcs.h b/qnn/jni/qnn/QNN/HTP/core/cost_funcs.h new file mode 100644 index 00000000..286945b9 --- /dev/null +++ b/qnn/jni/qnn/QNN/HTP/core/cost_funcs.h @@ -0,0 +1,56 @@ +//============================================================================= +// +// Copyright (c) 2020 Qualcomm Technologies, Inc. +// All Rights Reserved. +// Confidential and Proprietary - Qualcomm Technologies, Inc. +// +//============================================================================ + +#ifndef COST_FUNCS_H +#define COST_FUNCS_H +#include +#include +#include "weak_linkage.h" +#include "macros_attribute.h" +PUSH_VISIBILITY(default) + +class Graph; +class Op; + +namespace hnnx { + +class API_EXPORT cost_function_t { + using inner_func_t = float (*)(cost_function_t const &, const Graph &, Op const *); + inner_func_t funcp; + float val; + + public: + cost_function_t(cost_function_t const &) = default; + cost_function_t &operator=(cost_function_t const &) = default; + constexpr explicit cost_function_t(float val_in) : funcp(simple_cost_function), val(val_in) {} + constexpr cost_function_t(inner_func_t f, float val_in) : funcp(f), val(val_in) {} + constexpr cost_function_t() noexcept : funcp(simple_cost_function), val(0.0f) {} + + inline float operator()(const Graph &graph_in, Op const *op) const { return (*funcp)(*this, graph_in, op); } + static float simple_cost_function(cost_function_t const &self, const Graph &, Op const *) + { + return self.val; + } // just returns val; + + float get_val() const { return val; } + + // unreliable compare for two cost func: returns -1,0,1 if this cost + // is <,=,> than rhs cost, with the second result being true; or <0,false> + // if it can't tell. + std::pair compare(cost_function_t const &rhs) const; + + template static float cfunc(cost_function_t const &, const Graph &, Op const *); +}; + +API_EXPORT cost_function_t cost_func_from_str(std::string_view); + +} // namespace hnnx + +POP_VISIBILITY() + +#endif diff --git a/qnn/jni/qnn/QNN/HTP/core/crate.h b/qnn/jni/qnn/QNN/HTP/core/crate.h new file mode 100644 index 00000000..83b3d301 --- /dev/null +++ b/qnn/jni/qnn/QNN/HTP/core/crate.h @@ -0,0 +1,480 @@ +//============================================================================== +// +// Copyright (c) Qualcomm Technologies, Inc. +// All Rights Reserved. +// Confidential and Proprietary - Qualcomm Technologies, Inc. +// +//============================================================================== + +/* + * crate.h + * + * Created on: Aug 1, 2019 + * Author: smithg + */ + +#ifndef CRATE_H_ +#define CRATE_H_ +#include +#include +#include +#include +#include +#include +#include +#include + +#include "is_detected.h" +#include "forward_classes.h" +#include "macros_attribute.h" +#include "weak_linkage.h" +#include "size_align_code.h" + +PUSH_VISIBILITY(default) + +class Graph; +class Tensor; + +/// @brief A 'Crate' allows construction of some number of different data types, +/// contiguously packed into a few large memory blocks. +/// +/// Example: +/// +/// Crate crt; +/// Thing tp* = crt.emplace( ... ctor parms for Thing ... ) +/// AnotherThing tp2* = crt.emplace( ... ctor parms for AnotherThing ... ) +/// +/// When the crate is destroyed, all of the contained objects are destroyed in the reverse +/// order. You cannot 'remove' a single entry using +/// +/// crt.erase has been deprecated +/// +/// However, this is likely not going to free any memory; it will just call the dtor of the +/// object (and make sure it doesn't get called later, when the Crate is cleared or destroyed). +/// +/// You can also emplace variable-sized arrays of trivially-destructable objects. +/// +/// alloc_array does not initialize: +/// +/// float * farr = crt.alloc_array(n); +/// +/// alloc_array_zero does zero-initializing: +/// +/// int * farr = crt.alloc_array_zero(n); +/// +/// If an allocation needs space larger than CHUNKBYTES, it will get its own chunk. +/// +// Each record containing an object has a non-null 'dtor' field; if the object is trivially destructible, +// this will be (dtor_funcp)1, and the object is not on the linked-list. +// +// note: +// A constructor may emplace additional records in the crate recursively. Likewise, +// it's OK if the dtors call erase() on other objects. If this happens during a 'clear', +// the erase calls are ignored since the other objects are going to get dtor'd anyhow (if they have not +// been already). +// Important: if object A's constructor places B into the crate, then B will very likely get destroyed +// first when the crate is cleared. Thus, A's destructor can't look at B (it can erase B, which is ignored +// as described above). + +// +// new 'raw' mode: +// - when the crate is in 'raw' mode, no destructors are registered. inserting an object +// increases 'alloc_count' in the chunk header, but does not increment 'nrec', nor any +// does it increase Crate::m_records. +// - raw mode is entered by enable_raw_mode(size_needed): +// which does this in addition to enabling raw mode: +// - if there is no current chunk, or if the current chunk doesn't have room for 'size_needed' bytes, +// a new chunk is added which does. +// - enable_raw_mode(size_needed) returns a chunk handle. +// +// Internally, raw_mode causes add_record_slot() to do the same thing, but it only moves alloc_count, it does +// not assign a slot index, and 'idx' is -1 in the returned struct. +// All callers of add_record_slot() *must* check for raw mode (can be done by checking idx < 0), and then avoid +// adding an dtor or doing '++m_records'. +// +// it's also possible to call .enable_raw_mode(), disable_raw_mode() +// but .enable_raw_mode() does nothing if there isn't at least one chunk allocated. +// + +namespace hnnx { + +// +// This is used to statically determine whether a type T has a clear(Graph&) +// method. This is used as an additional destructor which takes a Graph +// reference. +// + +template using clear_t = decltype(std::declval().clear(std::declval())); + +template constexpr bool has_clear = is_detected_v; + +class Deserz; +class DCrate; + +class Crate { + API_EXPORT static constexpr size_t CHUNKBYTES = (1 << 16); + static_assert(CHUNKBYTES % 8 == 0 && CHUNKBYTES >= 128); + typedef void (*dtor_funcp)(Graph *graph_in, void *); + API_EXPORT static dtor_funcp DTOR_TRIVIAL() { return (dtor_funcp)1; } + API_EXPORT static dtor_funcp DTOR_IN_PROCESS() { return (dtor_funcp)2; } + + //! A record in the index of a chunk + struct index_rec { + unsigned loc; ///< offset in bytes to the object + dtor_funcp + dtor; ///< pointer to dtor function (null if empty record; (DTOR_TRIVIAL if the object is trivial dtor) + }; + //! A chunk record in the crate. + /// + /// Each chunk is created as an array of uint64_t, via make_unique + /// The memory in a chunk has a chunkhdr, which is followed by: + /// + /// [Objects][Objects][Objects]--> free space <--[Index records] + /// + /// 'alloc_count' is the next offset available to be allocated. + /// index records are entered in reverse order from the end. So, the last nrec*sizeof(index_rec) + /// bytes of the area, are the index. + /// + typedef std::unique_ptr uptr_chunk_t; + struct chunkhdr; + API_EXPORT static chunkhdr *hdr_of(uptr_chunk_t &p) { return reinterpret_cast(p.get()); } + API_EXPORT static chunkhdr const *hdr_of(uptr_chunk_t const &p) + { + return reinterpret_cast(p.get()); + } + /// The chunkhdr is the first portion of the chunk, and is immediately followed + /// by data_len bytes, which is a multiple of 8. + struct API_EXPORT alignas(8) chunkhdr { + unsigned data_len; ///< length of the data area following header, bytes (>=CHUNKBYTES). + unsigned nrec; ///< records in use (including deleted ones) + unsigned alloc_count; ///< offset of first byte in 'free space' + // init to a given length (header not included) + void init(unsigned length) + { + data_len = length; + nrec = 0; + alloc_count = 0; + } + // reset (preserve data_len) + void init() + { + nrec = 0; + alloc_count = 0; + } + // pointer to 'offs ' within data area + inline uint8_t *get_ptr(unsigned offs) { return (uint8_t *)(this + 1) + offs; } + // pointer to end of the allocation + inline uint8_t *get_end_ptr() { return (uint8_t *)(this + 1) + data_len; } + // amount of space remaining + inline size_t space_avail() const { return data_len - alloc_count - nrec * sizeof(index_rec); } + // get pointer to an index record. + // record 0 is the last (oldest) one. + index_rec *index_p(int idx) { return (index_rec *)get_end_ptr() - (idx + 1); } + static uptr_chunk_t allocate(unsigned len); + }; + std::vector m_chunks; /// < chunks with data + std::vector m_free; /// < chunks without + typedef std::vector::iterator chunk_iter; + + bool m_rawmode = false; + bool m_clearing = false; ///< set while clearing. + size_t m_allrecords = 0; ///< includes removed and 'padding' records + size_t m_records = 0; ///< only actual, non-erased records. + + //! Returned from add_record_slot (which is used to create a new record) + struct recposn { + chunkhdr *chunkp; ///< the chunk in which it was found + void *objp; ///< pointer to the object + int idx; ///< index within the chunk (= -1 if insert was done in raw mode) + }; + API_EXPORT recposn add_record_slot(size_t bytes, size_t align); + API_EXPORT void recover_ctor_throw(recposn const &) noexcept; + API_EXPORT void install_dtor(recposn const &, dtor_funcp dtor_func); + API_EXPORT void move_to_free(chunk_iter chunk_to_free); + + public: + class ChunkHandle { + friend class Crate; + chunkhdr *chunkp; + + protected: + ChunkHandle(chunkhdr *cp) : chunkp(cp){}; + + public: + ChunkHandle() : chunkp(nullptr) {} // null handle may only be assigned-to + ChunkHandle(ChunkHandle const &) = default; + ChunkHandle &operator=(ChunkHandle const &) = default; + friend inline bool operator==(ChunkHandle const &a, ChunkHandle const &b) { return a.chunkp == b.chunkp; } + std::pair get_memory_extent() const + { + size_t const len = chunkp->get_ptr(chunkp->alloc_count) - (uint8_t *)chunkp; + return {chunkp, len}; + } + }; + + API_EXPORT Crate(); ///< Construct a new Crate + Crate(Crate const &) = delete; + Crate &operator=(Crate const &) = delete; + + // get the preload handle for the first chunk + ChunkHandle first_chunk_handle() const + { + return ChunkHandle(m_chunks.empty() ? nullptr : hdr_of(const_cast(*this).m_chunks.front())); + } + // get the preload handle for the most recent chunk + ChunkHandle last_chunk_handle() const + { + return ChunkHandle(m_chunks.empty() ? nullptr : hdr_of(const_cast(*this).m_chunks.back())); + } + // 'raw mode' + ChunkHandle enable_raw_mode(unsigned bytes_needed); + API_EXPORT void enable_raw_mode(); + void disable_raw_mode() { m_rawmode = false; } + bool raw_mode() const { return m_rawmode; } + + // Note that the destructor doesn't do anything. You have to call clear() manually. + API_EXPORT ~Crate(); + //! The number of objects in the crate. + size_t size() const { return m_records; } + //! The number of chunks in use + size_t chunk_count() const { return m_chunks.size(); } + //! The size of crate used cross all recorded chunks + unsigned get_crate_used() const + { + unsigned total_crate = 0; + for (auto const &chunk_ptr : m_chunks) { + total_crate += hdr_of(chunk_ptr)->data_len; + } + return total_crate; + } + //! The amount of space left in the current chunk, approximately. + /// DO NOT CALL unless chunk_count() > 0 + size_t current_chunk_space_remain() const { return hdr_of(this->m_chunks.back())->space_avail(); } + //! Delete all objects. Does not necessarily free all storage to the + /// system; but all retained storage is availabe for re-use in the crate. + /// Note that this is no longer called by the destructor- it must be called explicitly. + API_EXPORT void clear(Graph *graph_in); + // Special entry for deserialzing in segments. + // If it is possible to allocate, in current raw-mode chunk, everything from offset 'start' + // up to but not including 'limit', this is done, and the base address of that region is returned. + // otherwise does nothing and returns null. + API_EXPORT void *allocate_bulk(size_t start, size_t limit); + + //! Construct an object of type T into the crate, using the + /// parameters of any constructor of T. It is acceptable for the + /// constructor to call the emplace method to add other objects to + /// the crate. + template API_HIDDEN T *emplace(Args &&...args) + { + recposn const pos = add_record_slot(sizeof(T), alignof(T)); + // construct the object + if constexpr (std::is_nothrow_constructible::value) { + new (pos.objp) T(std::forward(args)...); + } else { + try { + new (pos.objp) T(std::forward(args)...); + } catch (const std::exception &e) { + recover_ctor_throw(pos); + throw; + } + } + if (pos.idx >= 0) { + // register destructor + if constexpr (!std::is_trivially_destructible::value) { + // Obtain a callable '~T()' function. + // this typically generates a jump, or a small inline; lambda can + // be implicitly converted to a function pointer since it has no state. + auto dtor_func = [](Graph *graph_in, void *obj) { + if constexpr (has_clear) { + static_cast(obj)->clear(graph_in); + } + static_cast(obj)->~T(); + }; + install_dtor(pos, dtor_func); + } else { + ++m_records; // note, install_dtor does this too. + } + } + return static_cast(pos.objp); + } + + using deserialize_op_func = void *(*)(void *, Deserz &); + using deserialize_dtor_func = void (*)(Graph *, void *); + + // Alternate interface to cut down on template instantations: + // init_func is used to initialize the memory, and dtor_func + // is is used to register the destructor. It's up to the user + // to provide the correct size and alignment. + + API_EXPORT void *emplace_explicit(Deserz &dctx, deserialize_op_func init_func, deserialize_dtor_func dtor_func, + size_align_code_t size_al); + + //! Allocate 'n' of type T in the crate. + /// Will initially be garbage; T must be trivially destructable (unless waived) + template T *alloc_array(size_t n) + { + static_assert(DTOR_OK || std::is_trivially_destructible::value); + if (n == 0) return nullptr; + recposn const pos = add_record_slot(sizeof(T) * n, alignof(T)); + if (pos.idx >= 0) m_records++; + return static_cast(pos.objp); + } + //! Allocate 'n' of type T in the crate. + /// Will be zero-filled; T must be trivially destructable. + template T *alloc_array_zero(size_t n) + { + T *const res = alloc_array(n); + if (n != 0) ::memset(res, 0, sizeof(T) * n); + return res; + } + //! Allocate 'n' of type T in the crate. + /// Will be "value constructed"; in case of things like int and pointer, + /// this means they will be zeroed. + /// + /// T must be trivially destructable. + template T *alloc_array_value(size_t n) + { + T *res = alloc_array(n); + if (n != 0) std::uninitialized_value_construct_n(res, n); + return res; + } +}; + +/* + * EJP: This seems silly, but I don't know how to get visibility into Graph into a templated Tensor because of include hell. + */ + +API_EXPORT Crate *graph_crate(Graph &graph_in); + +// +// replacement for vector, for use in ops; + +// +// limited options for constructor: +// (1) copy, or move, from vector - need Graph *; +// (2) create with a given size, null-initialized; - need Graph *; +// (3) create empty, and then fill in later +// using init( Graph* , std::vector const &) +// or init( Graph* , std::vector &&) +// or init( Graph *, size ) +// or init( Graph *, T const *ptr, size ); +// or init_move( Graph *, T *ptr, size ); + +// With option 3, it assumed that the 'init' is done during the constructor of +// a host object - this is needed during deserialize, for instance. +// the 'len' is 32 bits so this type occupies 2 pointers, vs. 3 for std::vector. +// +// If 'T' has a destructor, the cratevec's destructor will invoke that on +// each element of the vector, in reverse order. +// when the 'move-from' mechanisms to init from 'std::vector && are used, +// the supplied vector will not be cleared; but its elements will all be +// 'moved-from'. + +template class cratevec { + T *m_ptr; + unsigned m_len; + using vec_t = std::vector; + static constexpr bool need_dtor = !std::is_trivially_destructible::value; + + public: + using iterator = T *; + using const_iterator = T const *; + using value_type = T; + using size_type = size_t; + using difference_type = ptrdiff_t; + using reference = T &; + using const_reference = T const &; + + cratevec() : m_ptr(nullptr), m_len(0) {} + cratevec(Graph *g, vec_t const &v) : cratevec() + { + if (!v.empty()) init(g, v.data(), v.size()); + } + cratevec(Graph *g, vec_t &&v) : cratevec() + { + if (!v.empty()) init_move(g, v.data(), v.size()); + } + cratevec(Graph *g, size_t n) : cratevec() { init(g, n); } + cratevec(cratevec const &) = delete; + cratevec(cratevec &&) = delete; + ~cratevec() + { + if constexpr (need_dtor) { + if (m_len > 0) { + T *const ptr0 = m_ptr; + T *ptr = ptr0 + m_len; + do { + ptr--; + ptr->~T(); + } while (ptr > ptr0); + } + } + } + + cratevec &operator=(cratevec const &) = delete; + cratevec &operator=(cratevec &&) = delete; + + void init(Graph *g, T const *data, size_t n) + { + assert(m_len == 0); + if (n) { + m_ptr = graph_crate(*g)->alloc_array(n); + std::uninitialized_copy_n(data, n, m_ptr); + m_len = n; + } + } + void init_move(Graph *g, T *data, size_t n) + { + assert(m_len == 0); + if (n) { + m_ptr = graph_crate(*g)->alloc_array(n); + std::uninitialized_move_n(data, n, m_ptr); + m_len = n; + } + } + // these methods get used during deserialize, so allow it to pass crate in directly. + void init(hnnx::Crate *const crate_p, size_t const n) + { + assert(m_len == 0); + if (n) { + m_ptr = crate_p->alloc_array(n); + std::uninitialized_value_construct_n(m_ptr, n); + m_len = n; + } + } + // The DCrate version is defined in dcrate_inlines.h + void init(hnnx::DCrate *crate_p, size_t n); + + void init(Graph *const g, size_t const n) { init(graph_crate(*g), n); } + void init(Graph *const g, vec_t const &v) { init(g, v.data(), v.size()); } + void init(Graph *const g, vec_t &&v) { init_move(g, v.data(), v.size()); } + + iterator begin() noexcept { return m_ptr; } + iterator end() noexcept { return m_ptr + m_len; } + const_iterator begin() const noexcept { return m_ptr; } + const_iterator end() const noexcept { return m_ptr + m_len; } + const_iterator cbegin() const noexcept { return m_ptr; } + const_iterator cend() const noexcept { return m_ptr + m_len; } + size_type size() const noexcept { return m_len; } + T *data() noexcept { return m_ptr; } + T const *data() const noexcept { return m_ptr; } + bool empty() const noexcept { return m_len == 0; } + reference operator[](size_type idx) { return m_ptr[idx]; } + const_reference operator[](size_type idx) const { return m_ptr[idx]; } + reference at(size_type idx) + { + if (idx >= m_len) throw std::range_error("cratevec"); + return m_ptr[idx]; + } + const_reference at(size_type idx) const { return const_cast(*this).at(idx); } + reference front() { return m_ptr[0]; } + const_reference front() const { return m_ptr[0]; } + reference back() { return m_ptr[m_len - 1]; } + const_reference back() const { return m_ptr[m_len - 1]; } +}; + +} // namespace hnnx + +POP_VISIBILITY() + +#endif /* CRATE_H_ */ diff --git a/qnn/jni/qnn/QNN/HTP/core/ctor_hook.h b/qnn/jni/qnn/QNN/HTP/core/ctor_hook.h new file mode 100644 index 00000000..32723299 --- /dev/null +++ b/qnn/jni/qnn/QNN/HTP/core/ctor_hook.h @@ -0,0 +1,51 @@ +//============================================================================== +// +// Copyright (c) 2020 Qualcomm Technologies, Inc. +// All Rights Reserved. +// Confidential and Proprietary - Qualcomm Technologies, Inc. +// +//============================================================================== + +#ifndef CTOR_HOOK_H +#define CTOR_HOOK_H 1 + +class Graph; + +namespace hnnx { +class OpIoPtrs; + +template inline void ctor_hook(Graph &, T &ref) +{ + return; +} + +// the 'pre-hook' can install an OpHookBase pointer into the op_io_ptrs +// default is to do nothing. +template inline void ctor_ophook(OpIoPtrs const &op_io_ptrs) +{ + return; +} + +} // namespace hnnx + +#ifdef PREPARE_DISABLED +#define CTOR_HOOK(FUNC, VAR, CODE) +#else +#define CTOR_HOOK(FUNC, VAR, CODE) \ + template <> \ + [[maybe_unused]] inline void hnnx::ctor_hook(Graph &graph_in, typename DerivedType<(&FUNC)>::type &VAR) \ + { \ + CODE \ + } +#endif + +// maybe we could add more than one ophook... just define this with different #'s of parms. +// 'HOOKCLASS' must be a subclass of OpHookBase, which defines the hook. +#define CTOR_OPHOOK(FUNC, HOOKCLASS) \ + template <> inline void hnnx::ctor_ophook::type>(OpIoPtrs const &op_io_ptrs) \ + { \ + static constexpr HOOKCLASS hook; \ + const_cast(op_io_ptrs).add_ophook(&hook); \ + } + +#endif diff --git a/qnn/jni/qnn/QNN/HTP/core/dcrate_inlines.h b/qnn/jni/qnn/QNN/HTP/core/dcrate_inlines.h new file mode 100644 index 00000000..89d48a75 --- /dev/null +++ b/qnn/jni/qnn/QNN/HTP/core/dcrate_inlines.h @@ -0,0 +1,101 @@ +//============================================================================== +// +// Copyright (c) Qualcomm Technologies, Inc. +// All Rights Reserved. +// Confidential and Proprietary - Qualcomm Technologies, Inc. +// +//============================================================================== + +#ifndef DCRATE_INLINES_H +#define DCRATE_INLINES_H 1 + +#include +#include +#include + +#include "macros_attribute.h" +#include "deser_concurrent.h" +#include "crate.h" + +namespace hnnx { + +// alloc 'amount' bytes with given alignment. +inline void *DCrate::do_alloc(const size_t align, const size_t amount) +{ + size_t basep = size_t(nextp); + if (align > 4) { + basep = (basep + (align - 1)) & ~(align - 1); + } + size_t const next_base = basep + amount; + if (next_base > (size_t)limitp) hnnx::throw_dcrate_seg_overflow(); + nextp = (void *)next_base; // update 'nextp' ... + return (void *)basep; +} + +template inline T *DCrate::alloc_array(const size_t n) +{ + if (nextp != nullptr) { + void *const allocp = do_alloc(alignof(T), sizeof(T) * n); + if (allocp) return (T *)allocp; + } + return cratep->alloc_array(n); +} + +template inline T *DCrate::emplace(Args &&...args) +{ + if (nextp != nullptr) { + void *const allocp = do_alloc(alignof(T), sizeof(T)); + if (allocp) { + new (allocp) T(std::forward(args)...); + return (T *)allocp; + } + } + return cratep->emplace(std::forward(args)...); +} + +template <> +inline void *DCrate::emplace_explicit(Deserz &dctx, deserialize_op_func const init_func, + deserialize_dtor_func const dtor_func, size_align_code_t const size_al) +{ + if (nextp != nullptr) { + void *const allocp = do_alloc(size_al.align(), size_al.size()); + if (allocp) { + init_func(allocp, dctx); + return allocp; + } + } + return cratep->emplace_explicit(dctx, init_func, dtor_func, size_al); +} + +// this will be used in place of 'emplace' when the constructor parms +// are just 'Deserz &' +template inline T *DCrate::emplace0(Deserz &dctx) +{ + deserialize_op_func const ctor = [](void *const ptr, Deserz &dctx) -> void * { + new (ptr) T(dctx); + return ptr; + }; + if (nextp != nullptr) { + void *const allocp = do_alloc(alignof(T), sizeof(T)); + if (allocp) { + (ctor)(allocp, dctx); + return (T *)allocp; + } + } + return (T *)cratep->emplace_explicit(dctx, ctor, nullptr, size_align_code_t::for_type()); +} +// init method of cratevec using 'Dcrate' is declared here to avoid header inclusion madness. +// +template inline void hnnx::cratevec::init(hnnx::DCrate *crate_p, size_t n) +{ + assert(m_len == 0); + if (n) { + m_ptr = crate_p->alloc_array(n); + std::uninitialized_value_construct_n(m_ptr, n); + m_len = n; + } +} + +} // namespace hnnx + +#endif // DCRATE_INLINES_H diff --git a/qnn/jni/qnn/QNN/HTP/core/deser_concurrent.h b/qnn/jni/qnn/QNN/HTP/core/deser_concurrent.h new file mode 100644 index 00000000..700b66c2 --- /dev/null +++ b/qnn/jni/qnn/QNN/HTP/core/deser_concurrent.h @@ -0,0 +1,302 @@ +//============================================================================== +// +// Copyright (c) Qualcomm Technologies, Inc. +// All Rights Reserved. +// Confidential and Proprietary - Qualcomm Technologies, Inc. +// +//============================================================================== + +#ifndef DESER_CONCURRENT_H +#define DESER_CONCURRENT_H 1 + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "deser_concurrent_defs.h" + +// this is intended to be included only in "deserialize.h" + +struct PreloadInfo; + +namespace hnnx { +struct runlist_seg_descriptor; +class Crate; +class Deserz; +class fixup_supplemental_recs; +class InitTimeSchedule; + +// describes a 'span' of the deserialized data +struct deser_segment_span { + void *base; + void *limit; +}; + +// This describes a partially-decoded segment; includes fixups. +// This should stay small so we can place it inside Deserz, and std::move it +// out (to keep the fixup list) when done with the segment. +struct runlist_fixup_state { + unsigned segno = 0; + size_t *crate_begin = nullptr; // where the data starts in the crate + runlist_seg_descriptor *seg_desc = nullptr; // Corresponding 'runlist_seg_descriptor' for reference. + // The next three are copied from the runlist_auxdata_seg_desc + uint32_t base_tensor_index = 0; // first tensor index defined this segment + uint32_t base_blocktable_index = 0; // first blocktable index defined in this segment + uint32_t base_sharedobj_index = 0; // first 'shared_object' index defined in this segment + // fixup data + size_t *fixup_list_head = nullptr; // head of the 'fixup list', or null if none. + fixup_supplemental_recs *fixup_supplemental; // supplemental fixup list + + runlist_fixup_state() = default; + ~runlist_fixup_state() = default; + runlist_fixup_state(runlist_fixup_state const &) = default; + // *Some* implementations of c++lib require this to have operator= (non-move) + // in order for std::vector containing it to be constructed via resize. + runlist_fixup_state &operator=(runlist_fixup_state const &) = default; + // the move-ctor and move-assign must leave the source with no fixup list, + // and segno = 0. + runlist_fixup_state(runlist_fixup_state &&from) { do_move_from(std::move(from)); } + runlist_fixup_state &operator=(runlist_fixup_state &&from) + { + do_move_from(std::move(from)); + return *this; + } + + private: + // this is used in move-constructor and move-assign; it will always leave 'from' + // with certain 'null' values to trap cases where we're using the wrong instance. + void do_move_from(runlist_fixup_state &&from) + { + segno = from.segno; + crate_begin = from.crate_begin; + seg_desc = from.seg_desc; + base_tensor_index = from.base_tensor_index; + base_blocktable_index = from.base_blocktable_index; + base_sharedobj_index = from.base_sharedobj_index; + fixup_list_head = from.fixup_list_head; + fixup_supplemental = from.fixup_supplemental; + from.segno = 0; + from.seg_desc = nullptr; + from.fixup_list_head = nullptr; + } +}; +// +// This contains 'supplemental' fixup records for a segment; there is one instance in each runlist_seg_descriptor, +// and a pointer to in the runlist_fixup_state. When the 'runlist_fixup_state' is moved in or out of the Deserz, +// the pointer to this remains. +// To avoid the overhead of vec_push_back, this // has a static array into which values are recorded; +// when this is full (or near full), all the records within are appended to the vector in a single operation. +// At the end of the operation, any remaining records are appended to the vector, but only if the vector +// is not empty (we can read the records out of the fixed array, if they all fit). +// +// The append() is not safe unless 'ensure_room_for' is checked first; you can e.g. do ensure_room_for(3) +// ahead of doing up to 3 append +// It is best to use a constant as parameter to ensure_room_for, i.e. ahead of code which may append +// *up to* 4 values, use ensure_room_for(4); this simplifies the inline expansion of 'ensure_room_for', +// and makes very little difference to performance compared to using the exact value. +// +class fixup_supplemental_recs { + static constexpr unsigned ARR_SIZE = 64; + unsigned num_in_arr = 0; + uint32_t fixed_arr[ARR_SIZE]; + std::vector var_arr; + unsigned n_vec = 0; // = var_arr.size() + + public: + void clear(); + unsigned constexpr size() const { return num_in_arr + n_vec; } + void reserve(unsigned const n) { var_arr.reserve(n); } + inline void ensure_room_for(unsigned const n) + { + assert(n <= ARR_SIZE); + if (num_in_arr > ARR_SIZE - n) flush_to_vec(); + } + // append allowed only when preceded by 'ensure_room_for' + inline void append(uint32_t const val) + { + assert(num_in_arr < ARR_SIZE); + fixed_arr[num_in_arr++] = val; + } + // use instead of 'ensure_room_for(1); push_back(n)' + inline void push_back(uint32_t const val) + { + if (num_in_arr > ARR_SIZE - 1) flush_to_vec(); + fixed_arr[num_in_arr++] = val; + } + // After all push_back() done, do a 'finish' + // and then get_limits() can be used to traverse the data. + void finish(); // flushes, but only if the vec is not empty. + std::pair get_limits() const; + + protected: + void flush_to_vec(); +}; + +// An array of these (size N+1) is used to hold the +// information used in deserializing each each segment. +// The [N+1] is partially used; some operations may use +// e.g. arr[i+1].auxinfo.some_field to find out where something +// ends for the current segment, using the start of the next segment; +// so N-1 entry needs a next. + +struct runlist_seg_descriptor { + runlist_auxdata_seg_desc auxinfo; // the data from the 'aux_data' record for this segment + runlist_fixup_state segfixup; // the deserialization state (moved in and out of Deserz as needed) + fixup_supplemental_recs fixup_supp; // fixup supplemental recs. + deser_segment_span span_to_deser = {}; + // These are used to configure the last preload in each segment, which preloads a region + // which is either partially, or entirely, in the next segment. So, the first two entries + // below are actually set at the end of deserialization of the previous segment; the end_preload + // is set by the current segment deserialize. + // The information stored in [N] is for configuring + // the last preload in the last segment, with end_preload set to 'end of crate'; in this case + // start_preload could be <= the end of the crate, and then we don't configure it. + // likewise the information in [0] is only 'end_preload', which can be used to configure + // 'Graph::m_initial_preload' (it should go from start-of-crate to seg[0].end_preload). + // In some cases (hopefully, only in testing) we may have segments with no preloads in them, + // in which case null pointers will appear in some of these; the ChunkPreload ops need to + // configured by getting info from adjacent segments. + PreloadInfo *prev_seg_final_preload{}; // points to the prev segments' final PreloadInfo + char *start_preload{}; // the preload start address for prev seg's final preload + char *end_preload{}; // end address for prev seg's final preload +}; + +// One instance of this is in Deserializer, called segments. +// It is created 'empty', and populated when we encounter the valid +// Aux Data record. +// +class DeserSegDescs { + unsigned n_segs = 0; + // points to an array of n_seg + 1, if n_segs > 0 + std::unique_ptr seg_arr; + + public: + DeserSegDescs() = default; + ~DeserSegDescs() = default; + DeserSegDescs(DeserSegDescs const &) = delete; + DeserSegDescs(DeserSegDescs &&) = default; + DeserSegDescs &operator=(DeserSegDescs const &) = delete; + DeserSegDescs &operator=(DeserSegDescs &&) = default; + + // these two are used to create the array + void set_size(unsigned const n); // used to create sized, empty array + runlist_seg_descriptor *data() { return seg_arr.get(); } + + constexpr unsigned num_segs() const { return n_segs; } + constexpr bool is_active() const { return n_segs != 0; } + // note: 'i' may be 0 .. num_segs(); only can use when 'is_active'. + runlist_seg_descriptor &operator[](unsigned const i) { return seg_arr[i]; } + runlist_seg_descriptor const &operator[](unsigned const i) const { return seg_arr[i]; } + + // We can add other data in here, to manage the concurrent deserialization. + unsigned n_threads = 0; // set when allocating the 'Deserz' array + std::vector deserz_arr; // sized as 'n_threads'. + + // start-of-crate, rounded to a multiple of 32; Calculated before any multi-thread + // operations. Use to configure Graph::m_initial_preload. + void *crate_preload_start_boundary; + // end-of-crate, rounded up to multiple of 32. Calculated before any multi-thread + // operations. No 'ChunkPreloadOp' will exceed this. + void *crate_preload_final_boundary; + + InitTimeSchedule *initSchedule; +}; + +class dcrate_seg_overflow_error : public std::exception { + public: + dcrate_seg_overflow_error() noexcept {} //LCOV_EXCL_LINE [SAFTYSWCCB-1753] + ~dcrate_seg_overflow_error() {} //LCOV_EXCL_LINE [SAFTYSWCCB-1753] + dcrate_seg_overflow_error(dcrate_seg_overflow_error const &) = default; + dcrate_seg_overflow_error(dcrate_seg_overflow_error &&) = default; + dcrate_seg_overflow_error &operator=(dcrate_seg_overflow_error const &) = default; + dcrate_seg_overflow_error &operator=(dcrate_seg_overflow_error &&) = default; + + char const *what() const noexcept override; +}; + +// A 'DCrate' is a proxy object stored within Deserz. +// It has some of the same methods as Crate; but if nextp is not null, +// it will allocated into the space at 'nextp', limited by 'limitp' +// Otherwise it will use the Crate. +// Most methods are defined as inlines in dcrate_inlines,h +// +class DCrate { + // these are either both null, or both non-null and 4-aligned. + void *nextp = nullptr; + void *limitp = nullptr; + Crate *cratep = nullptr; + + public: + DCrate() {} + ~DCrate() {} + DCrate(DCrate const &) = default; + DCrate(DCrate &&) = default; + DCrate &operator=(DCrate const &) = default; + DCrate &operator=(DCrate &&) = default; + explicit DCrate(Crate &c) : cratep(&c) {} + void set_crate(Crate &c) { cratep = &c; } + Crate *crate() { return cratep; } + bool is_active() const { return nextp != nullptr; } + + constexpr size_t bytes_remaining() const { return (char *)limitp - (char *)nextp; } + char *next_loc() { return (char *)nextp; } + std::pair range_remain() { return {(char *)nextp, (char *)limitp}; } + + void set_memory_range(void *base, unsigned len) + { + nextp = base; + limitp = (void *)((char *)base + len); + } + void remove_memory_range() + { + nextp = nullptr; + limitp = nullptr; + } + + // Methods of Crate we want to support (See crate.h for more more detail). + // Note that the constructors invoked in 'emplace' and 'emplace_explicit' + // can and will recursively call 'emplace' to construct their sub-objects. + template T *emplace(Args &&...args); + // variant of 'emplace' which can use the 'emplace_explicit' call to avoid + // instantiating the constructor twice + template T *emplace0(Deserz &dctx); + // (this is defined with 'template' args, only so it can be declared here without + // forward refs. All are pass-by-value. Only one specialization will be defined). + template void *emplace_explicit(Deserz &dctx, FI, FD, SA); + // array allocation, used to make all arrays in crate during deserialize. + template T *alloc_array(size_t n); + + private: + // reserve the specified data in the range, and return pointer to start; or + // return null if not possible. + void *do_alloc(size_t align, size_t amount); +}; + +// defines the encoding in the upper 3 bits of the last word of a 'multi-word' supplemental record +// all must be 4..7, since a 0 in the msb indicates a 'short' record. + +constexpr unsigned SUPPFIXUP_CAT_tensor = 4; +constexpr unsigned SUPPFIXUP_CAT_sharedobj = 5; +constexpr unsigned SUPPFIXUP_CAT_blocktable = 6; // with indices packed in one word +constexpr unsigned SUPPFIXUP_CAT_blocktable_full = 7; // .. in two words +constexpr unsigned SUPPFIXUP_CAT_SHIFT = 29u; + +bool fixup_encode_for_blocktable(runlist_fixup_state &seginfo, uint32_t idx, uint32_t table_offs, void **ptrloc); + +// high-level operations in the 'deserialize by segments' code. + +GraphStatus do_multiseg_deser(Deserializer &dctx, size_t ref_deser_pos); +GraphStatus segmentjob_deserialize_ops(Deserializer &dctx, unsigned segno, unsigned threadno); +GraphStatus segmentjob_process_fixups(Deserializer &dctx, unsigned segno, unsigned threadno); +GraphStatus segmentjob_compile_ops(Deserializer &dctx, unsigned segno, unsigned threadno); +void resolve_chunk_preload_after_multiseg_deser(Deserializer &dctx); +[[noreturn]] NOINLINE void throw_dcrate_seg_overflow(); + +} // namespace hnnx + +#endif // DESER_CONCURRENT_H diff --git a/qnn/jni/qnn/QNN/HTP/core/deser_concurrent_defs.h b/qnn/jni/qnn/QNN/HTP/core/deser_concurrent_defs.h new file mode 100644 index 00000000..ddfa4296 --- /dev/null +++ b/qnn/jni/qnn/QNN/HTP/core/deser_concurrent_defs.h @@ -0,0 +1,102 @@ +//============================================================================== +// +// Copyright (c) Qualcomm Technologies, Inc. +// All Rights Reserved. +// Confidential and Proprietary - Qualcomm Technologies, Inc. +// +//============================================================================== + +#ifndef DESER_CONCURRENT_DEFS_H +#define DESER_CONCURRENT_DEFS_H 1 + +#include +#include + +namespace hnnx { + +// NOTE: this file contains defs for concurrent deserialize which are needed on both decode and prepare +// side; mostly just the format of the Aux Data records. +// Defs needed only on decode side are in 'deser_concurrent.h', which #includes this file. + +constexpr unsigned DesConcur_MIN_SEGMENTS = 8; // can't have less than this number. + +// This is the number of runlist slots in the runlist_auxdata_seg_desc format. +// It must be >= the actual number. This number is coded into the start of the AuxData +// payload. If the number gets bigger, the reader of the aux-data +// record will need to be able to cope with the older, smaller value. + +constexpr unsigned DesConcur_MAX_RUNLISTS = 4; + +// This is a identifer string for Backward Compatible Concurrent deserialization's measure. +// This is string is used both by hexnn and QNN to understand that a particular graph has no segments +// And therefore it should not try to measure it. +constexpr const char *NoSegments_Identifier = "no_segments"; + +// The 'Aux Data' record describing the runlist partition has a payload formed of +// a runlist_auxdata_header, followed immediately by N+1 of runlist_auxdata_seg_desc. +// The number N is in the header; there may be additional words after, which can be +// ignored +// +// Aux Data header record. +// The 'record_version' is reserved to flag changes in the format, so that +// if it changes, new skel can understand old records. +// Currently, It has this format; most changes will expand one of the fields +// so following this may be adequate to capture version changes; if it is not, +// add flags in the upper bits. +// bits 31 ..13 : reserved, 0 +// bit 12: set of crate sizes are calculated based on 'dynamic tensor' sizes +// bits 11..8 length of the header in uint32's +// bits 7..3 length of 'segment' record, in uint32's +// bits 2..0 .. value of DesConcur_MAX_RUNLISTS +// +struct runlist_auxdata_header { + unsigned record_version; // see above + unsigned numsegs : 16; // number of segments; >= 8, likely <= 64 but who knows + unsigned hdrflags : 16; // reserved for flags + unsigned runlist_offset; // see below +}; + +// 'runlist_offset' is the offset, in u32's units, from the 'num_in_tensors' word +// to the 'n_ops_total' word. This is needed by 'weight share' processing in order to +// adjust the deser_offset values to accommodate changes in the encoding length of pointers. + +// The N segments are described by an array of N+1 of runlist_auxdata_seg_desc; +// segment i is defined by arr[i] (start) and arr[i+1] (end). +// An exception is 'crate_seg_len'- this may be less than arr[i+1].crate_offset - arr[i].crate_offset +// due to padding. +// In the final record arr[N]: +// - crate_seg_len is not used (0) +// - The *_list_posn records are the total length of the runlists +// - the four 'base_*_index' values are all 1 greater than any index used in the graph +// +struct runlist_auxdata_seg_desc { + uint32_t deser_offset; // where the input (pickle) data begins - reference point is the start of 'Runlist' as + // // defined in docs/pickle_format.md, i.e. the location of 'n_ops_total' word + uint32_t crate_offset; // offset in crate + uint32_t crate_seg_len; // crate length needed (not used in final entry) + uint32_t runlist_posn[DesConcur_MAX_RUNLISTS]; // where the segment starts in Op* runlist + uint32_t execlist_posn[DesConcur_MAX_RUNLISTS]; // where the segment starts in 'execlist' + uint32_t base_opseq_index; // first 'op_sequence_marker' index used in the segment. + uint32_t base_tensor_index; // first tensor index defined this segment + uint32_t base_blocktable_index; // first blocktable index defined in this segment + uint32_t base_sharedobj_index; // first 'shared_object' index defined in this segment +}; + +// Bit in the header version indicating crate sizes allow for 'dynamic shapes'. +// NOTE: if that gets backed out later, leave this here but remove it from DesConcur_AUXDATA_REC_VERSION +// +constexpr unsigned DesConcur_AUXDATA_REC_VERSION_DYNSHAPE_SIZES = 4096; + +constexpr unsigned DesConcur_AUXDATA_REC_VERSION = // composed of: + ((sizeof(runlist_auxdata_header) / sizeof(uint32_t)) * 256 // header size + + (sizeof(runlist_auxdata_seg_desc) / sizeof(uint32_t)) * 8 // seg desc len + + DesConcur_MAX_RUNLISTS) | + DesConcur_AUXDATA_REC_VERSION_DYNSHAPE_SIZES; + +// values to be used to 'grow' old crate estimate to compensate for 'dyn shape' mismatch +constexpr unsigned DesConcur_CrateGrowPerTensor = 2; // number of words per 'tensor' +constexpr unsigned DesConcur_CrateGrowPerShared = 2; // number of words per 'shared object' + +} // namespace hnnx + +#endif // DESER_CONCURRENT_DEFS_H diff --git a/qnn/jni/qnn/QNN/HTP/core/deserialize_tensors.h b/qnn/jni/qnn/QNN/HTP/core/deserialize_tensors.h new file mode 100644 index 00000000..43f14039 --- /dev/null +++ b/qnn/jni/qnn/QNN/HTP/core/deserialize_tensors.h @@ -0,0 +1,68 @@ +//============================================================================== +// +// Copyright (c) 2021-2023 Qualcomm Technologies, Inc. +// All Rights Reserved. +// Confidential and Proprietary - Qualcomm Technologies, Inc. +// +//============================================================================== + +#ifndef DESERIALIZE_TENSORS_H +#define DESERIALIZE_TENSORS_H 1 + +#include +#include +#include +#include +#include +#include +#include "limits.h" +#include "log.h" + +#include "forward_classes.h" +#include "serdes_tensors.h" + +namespace hnnx { + +// see comment in serdes_tensors.h for overview of how this works. + +class Deserializer; + +class DeserTensorConn : public SerTensorConnDefs { + typedef unsigned tensor_idx; + typedef Tensor const *ptr_type; + + // this collects all of the tensor_def we have seen. index is seq_index-1. + std::vector defined_tensors; + + public: + DeserTensorConn() {} + // process a tensor definition + void tensor_def(Deserz &, ptr_type); + // process n tensor refs. + void tensor_refs(Deserz &, ptr_type *ptrs, unsigned num); + // process a tensor ref + void tensor_ref(Deserz &dctx, ptr_type &ptr) { tensor_refs(dctx, &ptr, 1); } + + // TODO: remove these two, we don't use them, and should not. + // read an identity (for use in subsequent need_fixup) + tensor_idx read_identity(Deserz &); + // apply the identity to 'fix' a tensor pointer (usually now, sometimes later + void need_fixup(tensor_idx ident, ptr_type *dst); + + // 'reserve' the defined tensors to avoid allocation overhead... + inline void reserve_tensors(const size_t n) { defined_tensors.reserve(n); } + // resize the 'defined tensors' table to its full capacity (specified). + // Used only in multi-thread deserialize, prior to deserializing the runlist. + inline void resize_tensordef_table(const size_t n) { defined_tensors.resize(n); } + + // this is for use by 'reference fixup' code, in concurrent deserialize. + std::vector const &get_defined_tensors() const { return defined_tensors; } + + protected: + tensor_idx read_identity_inline(Deserz &); + void apply_fixup_inline(tensor_idx idx, ptr_type *dst); +}; + +} // namespace hnnx + +#endif // DESERIALIZE_TENSORS_H diff --git a/qnn/jni/qnn/QNN/HTP/core/deserializer.h b/qnn/jni/qnn/QNN/HTP/core/deserializer.h new file mode 100644 index 00000000..6d6af905 --- /dev/null +++ b/qnn/jni/qnn/QNN/HTP/core/deserializer.h @@ -0,0 +1,766 @@ +//============================================================================== +// +// Copyright (c) Qualcomm Technologies, Inc. +// All Rights Reserved. +// Confidential and Proprietary - Qualcomm Technologies, Inc. +// +//============================================================================== + +#ifndef DESERIALIZER_H +#define DESERIALIZER_H 1 + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "limits.h" +#include "dtype.h" +#include "log.h" +#include "allocator.h" +#include "op_extra_info.h" + +#include "serialize_defs.h" +#include "forward_classes.h" +#include "deserialize_tensors.h" +#include "macros_attribute.h" +#include "const_extent_descriptor.h" +#include "weak_linkage.h" +#include "size_align_code.h" +#include "deser_concurrent.h" +#include "hexagon_nn_types.h" +#include "conditional_default_deleter.h" + +namespace hnnx { +class DMA_Manager; +class Crate; +/** + * @brief \ref Serializer and \ref Deserializer modules that provides + * a mechanism to flatten (serialize) and reconstruct (deserialize) + * primitive and user-defined data types. The initial objective + * was to create an in-memory representation of the optimized + * \ref Graph on x86 which can then be reconstructed and executed on + * a qdsp target, essentially, a means to Graph caching. + * + */ +using tensor_deserializer_fn = uptr_Tensor (*)(Deserz &); + +using deserialize_op_func = void *(*)(void *, Deserz &); // Allocation function +using deserialize_dtor_func = void (*)(Graph *, void *); // Deallocation function + +struct op_deserializer_fn { + op_deserializer_fn(deserialize_op_func init_func_in, const size_align_code_t sizeal_in) + : init_func(init_func_in), size_align_code(sizeal_in) + { + } + op_deserializer_fn(deserialize_op_func init_func_in, deserialize_dtor_func dtor_func_in, + const size_align_code_t sizeal_in) + : dtor_func(dtor_func_in), init_func(init_func_in), size_align_code(sizeal_in){}; + op_deserializer_fn(const op_deserializer_fn &) = default; + op_deserializer_fn(op_deserializer_fn &&) = default; + op_deserializer_fn &operator=(const op_deserializer_fn &) = delete; + deserialize_dtor_func dtor_func = nullptr; + deserialize_op_func init_func = nullptr; + const size_align_code_t size_align_code{}; + inline constexpr size_t get_size() const { return size_align_code.size(); } + inline constexpr size_t get_align() const { return size_align_code.align(); } +}; + +// here's a quick and dirty way to make these maps go faster: compare string_view starting with len; +// and if the len is the same, then compare the middle character, and if that's the same, +// use memcmp. This avoids getting slowed down by a lot of long common prefixes in the type names. +// and we don't care about the weird ordering it generates. +// +struct trick_stringview_lt { + bool operator()(std::string_view const &a, std::string_view const &b) const + { + unsigned const na = a.size(); + unsigned const nb = b.size(); + if (na != nb) return na < nb; + char const *const pa = a.data(); + char const *const pb = b.data(); + if (pa == pb || na == 0) return false; // pa==pb is a common case. + unsigned const char_a = pa[na >> 1]; + unsigned const char_b = pb[na >> 1]; + if (char_a != char_b) return char_a < char_b; + return ::memcmp(pa, pb, na) < 0; + } +}; + +using op_deserializer_map_t = std::map, trick_stringview_lt>; +using op_filename_map_t = std::map; +using tensor_deserializer_map_t = std::map; +using cexdesc_deserializer_map = std::map; + +using const_extent_t = std::pair; +using weight_buf_deserializer_map = std::map; + +/** + * @brief Deserializer class to reverse the serialization + * process and reconstruct the data for specific types + * + */ +class Deserz : public DeSerError { + friend class Deserializer; // weirdly, sometimes a derived class needs to be a friend. + friend class DeserTensorConn; + + protected: + Deserz(Deserializer *full_deser, char const *p, size_t n, Graph *g = nullptr); + + public: + // I want to make this protected, but can't. + // Even code which has access to a protected copy_ctor + // of foo can't invoke .resize(n, foo_inst) on a std::vector. This + // seems like a defect in C++. Applies to various 'emplace' methods too; + // the 'emplace' can only ever use public ctors. + Deserz(Deserz const &) = default; + + public: + virtual ~Deserz(); // please keep this as first virtual method declared. + + // These three ONLY TO BE USED when setting up a Deserz to start processing a segment. + void setup_source_span(deser_segment_span const &); + void setup_dcrate_out(void *base, size_t len); + void setup_next_tensor_index(unsigned const idx) { next_tensordef_index = idx; } + + typedef uint32_t object_identity_type; + + // Note, various accessor methods are defined as inlines below 'class Deserializer'. + // true if this Deserz is really an instance of Deserializer. + constexpr bool is_base_deser() const; + + using op_deserialize_fn_list_t = std::vector; + using tensor_deserialize_fn_list_t = std::vector; + + op_deserialize_fn_list_t &get_op_deserialize_fn_list(); + tensor_deserialize_fn_list_t &get_tensor_deserialize_fn_list(); + std::vector &get_blocktable_link_table(); + // when deserializing an op: + // - call deserialize_tensor_ref (or _refs) on all the input tensor pointers + // - pass all output tensor addresses to deserialize_tensor_def + // Sequence must match serialization; note that the deserialize-ctor of Tensor + // calls deserialize_tensor_def on itself; so there is no need to call it elsewhere, + // except for specialized types which are constructed otherwise during depickle (e.g., + // types embedded in the Op). + // + // Some ops have multiple copies of some input tensor pointers; for these, it's possible + // serialize just one reference, and the deserialize it using + // auto id = deserialize_object_identity() // <- corresponds to serialize_tensor_ref + // need_tensor_fixup( id, &first_tensor_pointer); + // (other deserialize activity can happen here) + // need_tensor_fixup( id, &second_tensor_pointer); + + void deserialize_tensor_def(Tensor const *tensor_ptr); + void deserialize_tensor_ref(Tensor const *&where); + void deserialize_tensor_refs(Tensor const **ptrs, unsigned n); + template void deserialize_tensor_ref(T const *&where); + template void deserialize_tensor_refs(T const **ptrs, unsigned n); + object_identity_type deserialize_object_identity(); + void need_tensor_fixup(object_identity_type oid, Tensor const **where); + + Graph &graph() const { return *graph_ptr; } + Crate *crate() { return d_crate.crate(); } + DCrate *dcrate() { return &d_crate; } + DeserSegDescs const &get_segments() const; // gets ref to associated 'segments' object + op_deserializer_map_t const &get_op_deser_map() const { return *op_deserializer_map; } + + bool is_aligned_const_format() const; + bool has_pending_tensor_updates(); + + bool is_shared_dynamic_tensor_shape_format() const; + + fa::RuntimeAllocator *allocator; + DCrate d_crate; // contains a crate pointer + + protected: + // hoist pointers to these maps into Deserializer to avoid static lock overhead + op_deserializer_map_t const *op_deserializer_map; + tensor_deserializer_map_t const *tensor_deserializer_map; + Graph *graph_ptr{}; + Deserializer *full_deser; + + char const *bufstart; // start of current buffer + char const *bufend; // first byte we can't read + char const *bufp; // next to read + char const *buf_limit; // <= bufend; where 'fill_buffer' needs to be called. + size_t bytes_filled; // bytes previously filled + + uint32_t op_flags; + OpExtraInfo op_extra_info; + + unsigned next_tensordef_index = 1; // belongs to 'tensorconn' but needs to be in Deserz. + // 'format version'. Currently only ones used are 0 = classic, 1 = July/2023 + // Only access through methods like .classic_format(); + // This is changed to non-zero value based on seeing certain Aux Data records + // (which must appear before the allocator). + int format_version = 0; + + // this is used in multi-thread decoding. It is important that + // it remains null-constructed if the object is really a base of Deserializer; + // it is only used in 'segment' Deserz instances. + runlist_fixup_state seg_fixup_state{}; + + /** + * @brief throws an error since deserializer detected + * deserialization on insufficient bytes i.e. an underflow + * + */ + API_EXPORT virtual char const *fill_buffer(); // called for underflow on short operation + + /** + * @brief Deserialize data of specified length and write into + * buffer provided by caller + * + * @param[out] p buffer to write to + * @param[in] len length of the \ref bufp to read from + * @param[in] align if true, skip input bytes to a boundary of 4 + */ + API_EXPORT virtual void deserialize_fread(void *p, size_t len, bool align); + + /** + * @brief Get current position of buffer from which next data will be read + * + * @return size_t offset from buffer start + */ + size_t buffer_offset() const { return bufp - bufstart; } + /** + * @brief Available buffer size remaining for deserialization + * + * @return size_t remaining bytes size + */ + size_t buffer_remain() const { return bufend - bufp; } + + /** + * @brief deserialize buffer for type T + * + * @retval T returs the deserialized value of type T + * + * Note: This is the templated API called by deserialize_T() functions + * + * Note: Cannot be used for more than 4 bytes, there is a specialized version to read u64. + */ + template T simple_deserialize() + { + static_assert(sizeof(T) <= 4, "can only read sizeof(T) <= 4"); + constexpr size_t W = 4; + char const *curr_p = bufp; + if (curr_p >= buf_limit) { + curr_p = fill_buffer(); + } + T const val = *(T const *)(curr_p); + bufp = curr_p + W; + return val; + } + // see comment above deserialize_shared_obj. + API_EXPORT std::pair deserialize_shared_obj_func(void const **ptrloc); + API_EXPORT uint64_t deser_u64_slowpath(); + void initial_l2fetch(); // called only from ctor + + public: + inline constexpr bool classic_format() const { return format_version == 0; } + /** + * @brief deserialize data of type which calls simple_deserialize + * + * @param val data to deserialize + * + * Note: the below are the only types supported for deserialize_type + */ + API_EXPORT uint64_t deserialize_uint64(); // inline later + inline float deserialize_float() { return simple_deserialize(); } + inline uint32_t deserialize_uint32() { return simple_deserialize(); } + inline NN_INT32_T deserialize_int32() { return simple_deserialize(); } + inline int16_t deserialize_int16() { return simple_deserialize(); } + inline uint16_t deserialize_uint16() { return simple_deserialize(); } + inline int8_t deserialize_int8() { return simple_deserialize(); } + inline uint8_t deserialize_uint8() { return simple_deserialize(); } + + inline uint64_t deserialize_namesig() { return deserialize_uint64(); } + + // note, this is defined as an inline in deserializer.cc and not available elsewhere + tensor_deserializer_fn deserialize_tensor_identification(unsigned tensor_class_index); + + // deserialize string + // **NOTE** will throe runtime error if called in a Deserz which is not really a Deserialize. + API_EXPORT std::string_view deserialize_str(); + + uint32_t get_op_flags() const { return op_flags; }; + void clear_op_flags() { op_flags = 0; }; + void set_op_flags(uint32_t f) { op_flags = f; }; + + const OpExtraInfo &get_op_extra_info() const { return op_extra_info; }; + void clear_extra_info() { op_extra_info.clear(); }; + void set_op_extra_info(OpExtraInfo in_op_extra_info) { op_extra_info = in_op_extra_info; }; + + /** + * @brief deserialize buffer for specified size + * + * @param[in] alloc_size number of bytes to read from \ref bufp + * @param[out] ptr destination buffer for the read bytes + * @return size_t number of bytes actually read + */ + API_EXPORT size_t deserialize_buf(size_t alloc_size, void *ptr); + /** + * @brief similar to deserialize_buf but first deserialize a + * uint32_t size of bytes that should match the alloc_size + * + * @param[in] alloc_size number of bytes to read from \ref bufp + * @param[out] ptr destination buffer for the read bytes + * @return size_t number of bytes actually read + */ + API_EXPORT size_t deserialize_buf_withlen(size_t alloc_size, void *ptr); + // deserialize a pointer as 64 bits + inline void *deserialize_ptr() { return (void *)size_t(deserialize_uint64()); } + + template T deserialize_type(); + + template std::array deserialize_array(); + + /** + * @brief convernience wrappers for deserialize fuctions that + * take in different number of arguments of uint32_t type + * + * @return std::tuple (first, second) uint32_t data deserialized + */ + // convenience wrappers (to reduce inlined code size w/o much loss of speed) + API_EXPORT std::tuple deserialize_uint32_x2(); + API_EXPORT std::tuple deserialize_uint32_x3(); + API_EXPORT std::tuple deserialize_uint32_x4(); + + API_EXPORT void deserialize_uint32_arr(uint32_t *p, size_t N); + + // to reduce code size in the templates, we can deserialize arrays of + // N uint32 to sizet + API_EXPORT void deserialize_uint32_arr_sizet(size_t *p, size_t N); + + /** + * @brief deserialize array containing uint32_t type date + * + * @tparam N size of the array + * @return std::array array containing the deserialized values + */ + template std::array deserialize_uint32_array_sizet() + { + std::array res; + deserialize_uint32_arr_sizet(&res[0], N); + return res; + } + + // + // This is used for shared objects like Shape and Interface. + // it deserializes the index, and decides if it's the first instance. + // - must always pass the address which needs to point to it; though it + // will be not be set by this function. + // - if retval.second is null, then the object was previously deserialized, + // and return.first is the pointer to it. + // - otherwise, caller must deserialize the instance, and store the pointer + // at *retval.second. retval.first will be null in this case. + // In scenarios where delayed resolution is used, the return may be {token,null} + // where 'token' is actually delayed resolution token. + // + template + std::pair // see above + deserialize_shared_obj(T const **const loc) + { + auto const res = deserialize_shared_obj_func((void const **)loc); + return {(T const *)res.first, (T const **)res.second}; + } + + // Increment tue current read position of internal buffer without reading anything + void deserialize_skip_words(size_t nwords); + + // Apply the 'pointer fixups' contained within seg_info. This can + // be called with 'this' being any Deserz or Deserializer associated + // with the operation (it is only used to access tables in Deserializer). + // This can only be done on a given segment when all previous have + // been deserialized; so if we have one Deserz per thread, we need + // to 'move' the seg_info object out of it after completing the segment, + // and use this later to do the fixups. + // Returns true if ok, false if failed. + // Will leave the fixup list empty on success. + bool apply_segment_fixups(runlist_fixup_state &seg_info) const; + + // Methods to move 'seg_fixup_state' object in or out. + void install_seg_fixup_state(runlist_fixup_state &&src) { seg_fixup_state = std::move(src); } + runlist_fixup_state extract_seg_fixup_state() { return std::move(seg_fixup_state); } + void extract_seg_fixup_state_to(runlist_fixup_state &dest) { dest = std::move(seg_fixup_state); } + + // and a read_only accessor + runlist_fixup_state const &fixup_state() const { return seg_fixup_state; } + + // for Tensor::deserialize_blocktable + inline bool fixup_encode_for_blocktable(uint32_t const idx, uint32_t const table_offs, void **const ptrloc) + { + return hnnx::fixup_encode_for_blocktable(seg_fixup_state, idx, table_offs, ptrloc); + } +}; + +///////////////// + +class Deserializer : public Deserz { + friend class Deserz; + + public: + /** + * @brief Construct a new Deserializer object + * + * @param[in] p buffer that needs to be deserialized + * @param[in] n length of the buffer + * @param[in] g pointer Graph object to deserialize (usually null, since object + * is being passed to the Graph::Graph ctor to deserialize; that ctor + * must immediately call dctx.set_graph(*this) ) + */ + API_EXPORT Deserializer(char const *p, size_t n, Graph *g = nullptr); + API_EXPORT virtual ~Deserializer(); // please keep this as first virtual method declared. + + void set_graph(Graph &g); + + inline void deserialize_tensor_def(Tensor const *tensor_ptr) { tensorconn.tensor_def(*this, tensor_ptr); } + inline void deserialize_tensor_ref(Tensor const *&where) { tensorconn.tensor_ref(*this, where); } + inline void deserialize_tensor_refs(Tensor const **ptrs, unsigned n) { tensorconn.tensor_refs(*this, ptrs, n); } + template inline void deserialize_tensor_ref(T const *&where) + { + static_assert(std::is_base_of::value); + tensorconn.tensor_ref(*this, *(Tensor const **)&where); + } + template void deserialize_tensor_refs(T const **ptrs, unsigned n) + { + static_assert(std::is_base_of::value); + tensorconn.tensor_refs(*this, (Tensor const **)ptrs, n); + } + inline object_identity_type deserialize_object_identity() { return tensorconn.read_identity(*this); } + + inline void need_tensor_fixup(object_identity_type oid, Tensor const **where) { tensorconn.need_fixup(oid, where); } + inline void resolve_fixups() + { + [[maybe_unused]] const object_identity_type newval = tensorconn.read_identity(*this); + assert(newval == 0); + } + + constexpr bool is_aligned_const_format() const { return aligned_const_format_flag; } + void set_aligned_const_format(const bool v = true) { aligned_const_format_flag = v; } + + constexpr bool is_shared_dynamic_tensor_shape_format() const { return shared_dynamic_tensor_shape; } + void set_shared_dynamic_tensor_shape_format(const bool v = true) { shared_dynamic_tensor_shape = v; } + + void set_shared_io_buffer(const bool v = true) { shared_io_buffer = v; } + + PUSH_WARNING() + DISABLE_WARNING("-Wcast-qual", MSVC_NO_EQUIV) + // valid when the entire pickle, in const_extent format, is loaded as a single, persistent dma buffer + inline unsigned char *get_weight_pointer() { return ((unsigned char *)bufstart) + (4 * pickle_len_words); }; + POP_WARNING() + inline size_t get_weight_size() { return (bufend - bufstart) - (4 * pickle_len_words); }; + + inline op_deserialize_fn_list_t &get_op_deserialize_fn_list() { return op_deserialize_fn_list; } + inline tensor_deserialize_fn_list_t &get_tensor_deserialize_fn_list() { return tensor_deserialize_fn_list; } + + // Next 4 methods are used to support 'deserialize_by_segments'. + // 'get_forward_span' returns a 'deser_segment_span' (pair of pointers) for a region of deserialized data + // from 'ref + start' up to 'ref + end', where start and end (0 <= start < end) are byte offsets + // relative to some position 'ref' in the deserialized data, and 'ref' is the value which bytes_consumed() + // returned at that reference point. All should be multiples of 4. + deser_segment_span get_forward_span(size_t ref, size_t start, size_t end); + // used to get a reference point for bytes_consumed + size_t bytes_consumed() const { return bufp - bufstart; } + // used to skip past the last 'get_forward_span' we did + void skip_to_after_span(deser_segment_span const &); + // resize tables: tensor, shared_obj, linktable, according to info in final_segdesc + void resize_object_tables(runlist_auxdata_seg_desc const &final_desc); + + uint32_t crate_size_according_to_segments() const; + + protected: + /// + /// @brief Type for a unique readonly block-of-bytes (32b array) + /// + typedef std::unique_ptr> unique_readonly_blob_t; + + std::vector objindex; // index of pointers to shape, etc. + // the state of the 'tensor connectivity' deserialize engine. + DeserTensorConn tensorconn; + bool aligned_const_format_flag = false; + bool shared_dynamic_tensor_shape = false; + bool shared_io_buffer = false; + + // this is used in 'deserialize_str', so it ideally should be in Deserz; but + // it's pretty large; so, put it here and forbid calling deserialize_str + // on a Derserz which not really a Deserialize. We only use it to decode + // 'classic' pickles, so this is ok. + char name_buf[4096]; // used for string view + + // do the reference fixups on a segment. Return true if OK. + // See Deserz::apply_segment_fixups for public API. + static bool do_segment_fixups(runlist_fixup_state &seginfo, Deserz const &dctx0); + + /// + /// @brief Function to load header part of constant extent section + /// @param [in] ptr Pointer to constant extent section + /// @return Unique readonly blob pointing to header + /// + unique_readonly_blob_t load_header(hexagon_nn_wide_address_const_t const addr); + + public: + inline constexpr bool classic_format() const { return format_version == 0; } + inline void set_format_2307() { format_version = 1; } + + // This is called when a 'class index' Aux Data is encountered. + // It must deserialize exactly the indicated number of payload words. + // is_tensor = false for "Co" (op class index), and true for "Ct" (tensor class index) + API_EXPORT void auxdata_class_index(unsigned payload_words, bool is_tensor); + // + // called when an 'Nt' Aux data is encountered, which provides some array sizes for the + // deserialization. + // It must deserialize exactly the indicated number of payload words. + API_EXPORT void auxdata_temparr_sizes(unsigned payload_words); + // Called when a 'AuxTag_deserializeSegments' is encountered. If it likes + // the record, it will set up the 'segments' object. + API_EXPORT void auxdata_deserialize_segments(unsigned payload_words); + + // called when a 'KS' Aux data is encountered, which provides a const_extent_descriptor + // It must deserialize exactly the indicated number of payload words. + API_EXPORT int auxdata_read_const_extent_descriptor(const unsigned payload_words); + // helper for above. payload_words is the length WITH PADDING + API_EXPORT int extract_const_extent_name(const unsigned payload_words, std::string &retVal); + + // Extract a std::vector containing the 'const extent descriptor table, + // from a given offset (in units of 32-bit words) relative to the start of the pickle. + // or separate pointer (if separate buffer for the weights was passed in). + // This does not affect the current position. + // If there is a problem, it returns an empty vector; caller *must* check and report. + // This uses hnnx::const_extent_hdr_check to understand how much it should read, + // and to do basic check. + API_EXPORT std::vector extract_const_extent_table(size_t posn_in_words); + std::vector extract_const_extent_table(hexagon_nn_wide_address_const_t weight_data, + const uint64_t weight_size); + // given a destination char pointer, prefilled with \null, fills it in with the name of the const_extent + // caller must provide destination of sufficient length + std::string name_from_weight_data(hexagon_nn_wide_address_const_t weight_data, const uint64_t weight_length); + + // helper func for above. return -1 if name not present. + std::string get_name(hexagon_nn_wide_address_const_t weight_data, const uint64_t weight_length); + // give a vector of weight_data buffers, stores them all in the appropriate map + void store_named_weight_bufs(const hexagon_nn_wide_address_const_t *const buffers, const uint64_t *const lengths, + const unsigned num_buffers); + void store_named_weight_bufs(std::vector const &named_weights); + // + // copy 'len' bytes of data at offset offs_bytes in the pickle into location dstp. + // returns true if it's possible. You can maybe pass a DMA_Manager to have it queued... + // offs_bytes defined as uint64_t to support possible 'far' data on hexagon. + API_EXPORT bool extract_const_extent_data(uint64_t offs_bytes, size_t len, void *dstp, DMA_Manager *dma = nullptr); + // same, using an external const_extent + bool extract_const_extent_data(uint64_t offs_bytes, size_t len, void *dstp, + hexagon_nn_wide_address_const_t weight_data, const uint64_t weight_length); + + // This extracts the 'objindex', when it is needed e.g. to 'patch' interfaces. + // Must be done only after deserializing, and can only be done once. + std::vector extract_objindex() { return std::move(objindex); } + + DeserSegDescs segments; // array of runlist_seg_descriptor, empty if not doing multiseg. + + // this is used to pass the offset of the const-extent-descriptor (recorded as pickle_len) + // to the alloc->deserialize. + size_t pickle_len_words; + + // OPTIONAL maps from weight buffer names to the descriptors and the buffers, respectively + cexdesc_deserializer_map named_cexdescs; + weight_buf_deserializer_map named_weight_bufs; + + void *uncached_ptr; + uint32_t uncached_len; + + std::vector op_deserialize_fn_list; + std::vector tensor_deserialize_fn_list; + + // used to 'link' shared blocktables during deser. + std::vector blocktable_link_table; +}; + +///////////////// + +// true if this Deserz is really an instance of Deserializer. +inline constexpr bool Deserz::is_base_deser() const +{ + return static_cast(full_deser) == this; +} + +inline bool Deserz::is_aligned_const_format() const +{ + return full_deser->aligned_const_format_flag; +} +inline bool Deserz::is_shared_dynamic_tensor_shape_format() const +{ + return full_deser->shared_dynamic_tensor_shape; +} +inline Deserz::op_deserialize_fn_list_t &Deserz::get_op_deserialize_fn_list() +{ + return full_deser->op_deserialize_fn_list; +} +inline Deserz::tensor_deserialize_fn_list_t &Deserz::get_tensor_deserialize_fn_list() +{ + return full_deser->tensor_deserialize_fn_list; +} +inline std::vector &Deserz::get_blocktable_link_table() +{ + return full_deser->blocktable_link_table; +} +// For these in Deserz, we must call the corresponding methods on the +// tensorconn in 'full_deser', but must pass 'this' as first parameter. +inline void Deserz::deserialize_tensor_def(Tensor const *const tensor_ptr) +{ + full_deser->tensorconn.tensor_def(*this, tensor_ptr); +} +inline void Deserz::deserialize_tensor_ref(Tensor const *&where) +{ + full_deser->tensorconn.tensor_ref(*this, where); +} +inline void Deserz::deserialize_tensor_refs(Tensor const **const ptrs, const unsigned n) +{ + full_deser->tensorconn.tensor_refs(*this, ptrs, n); +} +inline DeserSegDescs const &Deserz::get_segments() const +{ + return full_deser->segments; +} + +// unaligned read of 64-bits (two 32-bit aligned reads) +template <> inline uint64_t Deserz::simple_deserialize() +{ + char const *const curr_p = bufp; + if (curr_p + 8u > buf_limit) { + return deser_u64_slowpath(); + } + uint32_t const *const p = (uint32_t const *)(curr_p); + bufp = curr_p + 8u; + return p[0] + ((uint64_t)p[1] << 32); +} +inline uint64_t Deserz::deserialize_uint64() +{ + return simple_deserialize(); +} + +template <> inline uint64_t Deserz::deserialize_type() +{ + return deserialize_uint64(); +} +template <> inline float Deserz::deserialize_type() +{ + return deserialize_float(); +} +// sometimes uint32_t is unsigned long, sometimes it's unsigned +// sometimes unsigned long is uint64. Hopefully this should cover it all. +#if ULONG_MAX == UINT_MAX +template <> inline unsigned long Deserz::deserialize_type() +{ + return deserialize_uint32(); +} +template <> inline long Deserz::deserialize_type() +{ + return deserialize_int32(); +} +#endif +template <> inline unsigned Deserz::deserialize_type() +{ + return deserialize_uint32(); +} +template <> inline int Deserz::deserialize_type() +{ + return deserialize_int32(); +} +template <> inline int16_t Deserz::deserialize_type() +{ + return deserialize_int16(); +} +template <> inline uint16_t Deserz::deserialize_type() +{ + return deserialize_uint16(); +} +template <> inline int8_t Deserz::deserialize_type() +{ + return deserialize_int8(); +} +template <> inline uint8_t Deserz::deserialize_type() +{ + return deserialize_uint8(); +} + +// assert( dctx.deserialize_uint32() == SOME_CONST ); +// is not safe, since if you turn off asserts, it will no longer read the 4 bytes. This is to allow that to work +#define DESERIALIZE_ASSERT_UINT32(DCTX, VAL) \ + do { \ + uint32_t const tmp [[gnu::unused]] = (DCTX).deserialize_uint32(); \ + assert(tmp == (VAL)); \ + } while (0) + +#include "weak_linkage.h" +PUSH_VISIBILITY(default) + +/** + * @brief register the deserialization function for each \ref Op + * TypicalOp and VariadicOp derived classes are instantiated via + * template and hence the need to create a map of deserialize functions + * for each Op when they are generated at library initialization + * + * @param[in] tinf Op type_info that is used to key the map + * @param[in] fn Deserialize function + */ +API_EXPORT void deserialize_op_register(std::type_info const *tinf, const std::string_view type_tag, + const op_deserializer_fn &fn, bool is_external = false, + std::string_view filename = ""); +/** + * @brief register the deserialization function for each \ref Tensor + * Since \ref Tensor derived classes are instantiated via templates, there + * is a need to create a map of deserialize function for each Tensor at runtime + * + * @param[in] type_tag Tensor type tag that is used to key the map + * @param[in] fn Deserialize function + */ +API_FUNC_EXPORT void deserialize_tensor_register(std::type_info const &tinf, const char *type_tag, + tensor_deserializer_fn fn); + +POP_VISIBILITY() + +// this is fully defined in serialize_register.h +template struct deserialize_tensor_using_constructor; + +// this is fully defined in serialize_register.h +template struct alloc_func_for_op; +template struct dealloc_func_for_op; + +////////////////////// +// Forward decls of things defined in template_help.h +// +// contains_type< tuple, x >::value: true if x is in a,b,c ... +// no 'remove ref' etc is done. +template struct contains_type; +template struct not_contains_type; +template