forked from nntrainer/Quick.AI
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathllm_util.cpp
More file actions
90 lines (77 loc) · 3.04 KB
/
Copy pathllm_util.cpp
File metadata and controls
90 lines (77 loc) · 3.04 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
// SPDX-License-Identifier: Apache-2.0
/**
*
* @file llm_util.cpp
* @brief util functions for llm (refactored from main.cpp)
* @date 21 August 2024
* @see https://github.com/nntrainer/nntrainer
* @author Seungbaek Hong <sb92.hong@samsung.com>
* @author Hyeonseok Lee <hs89.lee@samsung.com>
* @author Eunju Yang <ej.yang@samsung.com>
* @bug No known bugs except for NYI items
*/
#include <llm_util.hpp>
/**
 * @brief Select the highest-scoring token ids from a logits array.
 *
 * Optionally applies a repetition penalty (for ids in @p input_ids) and a
 * bad-words ban (for ids in @p bad_words_ids) to @p logits IN PLACE, then
 * returns the ids of the largest logits in descending logit order. Equal
 * logits are ordered by ascending token id so the result is deterministic.
 *
 * @param logits             logits indexed by token id; modified in place
 *                           by the penalties
 * @param NUM_VOCAB          number of entries in @p logits
 * @param NUM_TARGET_TOKENS  number of ids requested; clamped to NUM_VOCAB
 * @param repetition_penalty penalty factor; 1 disables it
 * @param input_ids          previously seen ids to penalize (may be nullptr)
 * @param NUM_INPUT_IDS      number of entries in @p input_ids
 * @param bad_words_ids      ids to ban (may be nullptr)
 * @param NUM_BAD_WORDS_IDS  number of entries in @p bad_words_ids
 * @return up to min(NUM_TARGET_TOKENS, NUM_VOCAB) token ids
 */
std::vector<unsigned int> generate_multi_tokens(
  float *logits, unsigned int NUM_VOCAB, unsigned int NUM_TARGET_TOKENS,
  float repetition_penalty, unsigned int *input_ids, unsigned int NUM_INPUT_IDS,
  unsigned int *bad_words_ids, unsigned int NUM_BAD_WORDS_IDS) {
  std::vector<unsigned int> outputs;

  // apply repetition penalty
  if (repetition_penalty != 1 && input_ids != nullptr && NUM_INPUT_IDS != 0) {
    applyRepetitionPenalty(logits, input_ids, NUM_INPUT_IDS,
                           repetition_penalty);
  }

  // apply bad words penalty
  if (bad_words_ids != nullptr && NUM_BAD_WORDS_IDS != 0)
    applyBadWordsPenalty(logits, bad_words_ids, NUM_BAD_WORDS_IDS);

  // Never ask partial_sort for more elements than exist: a middle iterator
  // past end() (or reading past the vector afterwards) is undefined behavior.
  const unsigned int num_outputs = std::min(NUM_TARGET_TOKENS, NUM_VOCAB);

  std::vector<std::pair<unsigned int, float>> top_indices_and_logits;
  top_indices_and_logits.reserve(NUM_VOCAB);
  for (unsigned int i = 0; i < NUM_VOCAB; ++i) {
    top_indices_and_logits.emplace_back(i, logits[i]);
  }

  // Sort only the leading num_outputs entries: descending by logit, with
  // ties broken by the smaller token id (partial_sort is not stable).
  std::partial_sort(top_indices_and_logits.begin(),
                    top_indices_and_logits.begin() + num_outputs,
                    top_indices_and_logits.end(),
                    [](const auto &a, const auto &b) {
                      if (a.second != b.second)
                        return a.second > b.second;
                      return a.first < b.first;
                    });

  // add sampled words
  outputs.reserve(num_outputs);
  for (unsigned int i = 0; i < num_outputs; ++i) {
    outputs.push_back(top_indices_and_logits[i].first);
  }
  return outputs;
}
/**
 * @brief Lower the logits of tokens that already occurred in the input.
 *
 * For every id in @p input_ids, a negative logit is multiplied by
 * @p repetition_penalty and any other logit is divided by it. An id that
 * occurs several times in @p input_ids is penalized once per occurrence.
 *
 * @param logits             logits indexed by token id; modified in place
 * @param input_ids          token ids to penalize; each must be a valid
 *                           index into @p logits (not checked here)
 * @param NUM_INPUT_IDS      number of entries in @p input_ids
 * @param repetition_penalty penalty factor
 */
void applyRepetitionPenalty(float *logits, unsigned int *input_ids,
                            unsigned int NUM_INPUT_IDS,
                            float repetition_penalty) {
  const unsigned int *const ids_end = input_ids + NUM_INPUT_IDS;
  for (const unsigned int *id = input_ids; id != ids_end; ++id) {
    float &logit = logits[*id];
    // Written as !(logit < 0) so NaN takes the division branch, as before.
    if (!(logit < 0)) {
      logit /= repetition_penalty;
    } else {
      logit *= repetition_penalty;
    }
  }
}
/**
 * @brief Ban tokens by forcing their logits to negative infinity.
 *
 * @param logits            logits indexed by token id; modified in place
 * @param bad_words_ids     token ids to ban; each must be a valid index
 *                          into @p logits (not checked here)
 * @param NUM_BAD_WORDS_IDS number of entries in @p bad_words_ids
 */
void applyBadWordsPenalty(float *logits, unsigned int *bad_words_ids,
                          unsigned int NUM_BAD_WORDS_IDS) {
  const unsigned int *banned = bad_words_ids;
  unsigned int remaining = NUM_BAD_WORDS_IDS;
  while (remaining > 0) {
    logits[*banned] = -INFINITY;
    ++banned;
    --remaining;
  }
}
/**
 * @brief Apply temperature scaling to logits and return the maximum logit.
 *
 * Each logit is divided by @p temperature in place, unless temperature is
 * at most 1e-5, in which case the logits are left unscaled. The returned
 * maximum is the value to shift by before softmax.
 *
 * NOTE(review): despite the function name, top-k and top-p filtering are
 * NOT implemented here; @p top_k and @p top_p are accepted for interface
 * compatibility and currently ignored. TODO: implement or remove.
 *
 * @param logits      array of @p len logits; scaled in place
 * @param len         number of logits
 * @param temperature sampling temperature
 * @param top_k       currently unused
 * @param top_p       currently unused
 * @return the largest (scaled) logit, or -INFINITY when len <= 0 or
 *         logits is nullptr
 */
float applyTKP(float *logits, int len, float temperature, unsigned int top_k,
               float top_p) {
  (void)top_k;
  (void)top_p;

  // Previously an empty input reached partial_sort with a middle iterator
  // past end() and then read element [0] of an empty vector.
  if (logits == nullptr || len <= 0)
    return -INFINITY;

  // Single pass, no allocation: scale in place and track the maximum.
  const bool scale = temperature > 1e-5;
  float max_logit = 0.0f;
  for (int i = 0; i < len; ++i) {
    if (scale)
      logits[i] = logits[i] / temperature;
    if (i == 0 || logits[i] > max_logit)
      max_logit = logits[i];
  }
  return max_logit;
}