mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-29 18:26:37 +08:00
export and run llama in C++
This commit is contained in:
parent
b8f0cacfa8
commit
761b2c9886
1
llms/export/.gitignore
vendored
Normal file
1
llms/export/.gitignore
vendored
Normal file
@ -0,0 +1 @@
|
||||
build/
|
42
llms/export/CMakeLists.txt
Normal file
42
llms/export/CMakeLists.txt
Normal file
@ -0,0 +1,42 @@
|
||||
cmake_minimum_required(VERSION 3.27)
|
||||
|
||||
project(mlxlm LANGUAGES CXX)
|
||||
|
||||
set(CMAKE_CXX_STANDARD 17)
|
||||
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
||||
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
|
||||
|
||||
find_package(
|
||||
Python 3.9
|
||||
COMPONENTS Interpreter Development.Module
|
||||
REQUIRED)
|
||||
execute_process(
|
||||
COMMAND "${Python_EXECUTABLE}" -m mlx --cmake-dir
|
||||
OUTPUT_STRIP_TRAILING_WHITESPACE
|
||||
OUTPUT_VARIABLE MLX_ROOT)
|
||||
find_package(MLX CONFIG REQUIRED)
|
||||
|
||||
add_library(mlxlm)
|
||||
target_link_libraries(mlxlm PUBLIC mlx)
|
||||
|
||||
include(FetchContent)
|
||||
|
||||
FetchContent_Declare(
|
||||
json
|
||||
URL https://github.com/nlohmann/json/releases/download/v3.11.3/json.tar.xz)
|
||||
FetchContent_MakeAvailable(json)
|
||||
target_include_directories(
|
||||
mlxlm PRIVATE $<BUILD_INTERFACE:${json_SOURCE_DIR}/single_include/nlohmann>)
|
||||
|
||||
target_sources(mlxlm
|
||||
PRIVATE
|
||||
mlxlm.cpp
|
||||
tokenizer.cpp
|
||||
unicode.cpp
|
||||
unicode_data.cpp)
|
||||
|
||||
add_executable(main main.cpp)
|
||||
target_link_libraries(main PRIVATE mlxlm)
|
||||
|
||||
add_executable(test test.cpp)
|
||||
target_link_libraries(test PRIVATE mlxlm)
|
34
llms/export/README.md
Normal file
34
llms/export/README.md
Normal file
@ -0,0 +1,34 @@
|
||||
# Export LLMs to C++
|
||||
|
||||
Export language model inference from Python to run directly in C++.
|
||||
|
||||
To run, first install the requirements:
|
||||
|
||||
```bash
|
||||
pip install -U mlx-lm
|
||||
```
|
||||
|
||||
Then generate text from Python with:
|
||||
|
||||
```bash
|
||||
python export.py generate "How tall is K2?"
|
||||
```
|
||||
|
||||
To export the generation function run:
|
||||
|
||||
```bash
|
||||
python export.py export
|
||||
```
|
||||
|
||||
Then build the C++ code (requires CMake):
|
||||
|
||||
```bash
|
||||
cmake -B build -DCMAKE_BUILD_TYPE=Release
|
||||
cmake --build build
|
||||
```
|
||||
|
||||
And run the generation from C++ with:
|
||||
|
||||
```bash
|
||||
./build/main lama3.1-instruct-4bit "How tall is K2?"
|
||||
```
|
171
llms/export/export.py
Normal file
171
llms/export/export.py
Normal file
@ -0,0 +1,171 @@
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import fire
|
||||
import mlx.core as mx
|
||||
from mlx_lm import load
|
||||
|
||||
|
||||
class ExportableCache:
|
||||
|
||||
def __init__(self, keys=None, values=None, offset=0):
|
||||
self.offset = offset
|
||||
self.keys = keys
|
||||
self.values = values
|
||||
|
||||
def update_and_fetch(self, keys, values):
|
||||
if self.keys is not None:
|
||||
self.keys = mx.slice_update(self.keys, keys, self.offset, axes=(2,))
|
||||
self.values = mx.slice_update(self.values, values, self.offset, axes=(2,))
|
||||
else:
|
||||
self.keys = keys
|
||||
self.values = values
|
||||
return self.keys, self.values
|
||||
|
||||
@property
|
||||
def state(self):
|
||||
return self.keys, self.values
|
||||
|
||||
|
||||
def expand(cache, mask=None, cache_step_size=256):
|
||||
cache_size = cache[0].shape[-2]
|
||||
new_size = cache_step_size * ((cache_size + cache_step_size) // cache_step_size)
|
||||
|
||||
def expand_kv(x):
|
||||
B, n_heads, _, head_dim = x.shape
|
||||
new_x = mx.zeros((B, n_heads, new_size, head_dim), x.dtype)
|
||||
new_x[..., : x.shape[2], :] = x
|
||||
return new_x
|
||||
|
||||
cache = [expand_kv(c) for c in cache]
|
||||
if mask is None:
|
||||
mask = mx.full(new_size, False)
|
||||
mask[:cache_size] = True
|
||||
else:
|
||||
mask = mx.concatenate([mask, mx.full(cache_step_size, False)])
|
||||
return cache, mask
|
||||
|
||||
|
||||
def causal_mask(N):
|
||||
idx = mx.arange(N)
|
||||
return idx[:, None] >= idx
|
||||
|
||||
|
||||
def step(model, y, *state):
|
||||
mask = state[-1]
|
||||
if len(state) > 1:
|
||||
cache, offset = state[:-2], state[-2]
|
||||
cache = [
|
||||
ExportableCache(keys, values, offset)
|
||||
for keys, values in zip(cache[::2], cache[1::2])
|
||||
]
|
||||
else:
|
||||
cache = [ExportableCache() for i in range(len(model.model.layers))]
|
||||
logits = model(y, cache=cache, mask=mask)
|
||||
cache = [y for x in cache for y in x.state]
|
||||
return logits, *cache
|
||||
|
||||
|
||||
def generate_step(prompt, model, max_tokens):
|
||||
mx.eval(model)
|
||||
|
||||
compiled_step = mx.compile(lambda *args: step(model, *args), shapeless=True)
|
||||
|
||||
def _step(*args):
|
||||
logits, *cache = compiled_step(*args)
|
||||
return mx.argmax(logits[:, -1], axis=-1), *cache
|
||||
|
||||
y, *cache = _step(prompt, causal_mask(prompt.size))
|
||||
mx.async_eval(y)
|
||||
offset = mx.array(prompt.size, mx.uint32)
|
||||
cache, mask = expand(cache)
|
||||
n = 0
|
||||
while True:
|
||||
if n < max_tokens - 1:
|
||||
if mask.size <= (prompt.size + n):
|
||||
cache, mask = expand(cache, mask)
|
||||
mask[prompt.size + n] = True
|
||||
next_y, *cache = _step(y[None], *cache, offset, mask)
|
||||
mx.async_eval(next_y)
|
||||
offset += 1
|
||||
n += 1
|
||||
yield y.item()
|
||||
if n == max_tokens:
|
||||
break
|
||||
y = next_y
|
||||
|
||||
|
||||
def export(
|
||||
model="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit",
|
||||
path="llama3.1-instruct-4bit",
|
||||
):
|
||||
model, tokenizer = load(model)
|
||||
|
||||
mx.eval(model)
|
||||
|
||||
tokenizer.save_pretrained(path)
|
||||
|
||||
_step = lambda *args: step(model, *args)
|
||||
|
||||
# Make example inputs
|
||||
y_prompt = mx.array([[0, 0]], mx.uint32)
|
||||
y_gen = mx.array([[0]], mx.uint32)
|
||||
offset = mx.array([0], mx.uint32)
|
||||
|
||||
mask = causal_mask(y_prompt.size)
|
||||
_, *cache = _step(y_prompt, mask)
|
||||
|
||||
model_path = str(Path(path) / "model.mlxfn")
|
||||
with mx.exporter(model_path, _step, shapeless=True) as exporter:
|
||||
exporter(y_prompt, mask)
|
||||
cache, mask = expand(cache)
|
||||
exporter(y_gen, *cache, offset, mask)
|
||||
|
||||
|
||||
def generate(
|
||||
prompt,
|
||||
model="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit",
|
||||
max_tokens=128,
|
||||
):
|
||||
print("[INFO] Loading model from disk.")
|
||||
model, tokenizer = load(model)
|
||||
prompt = tokenizer.apply_chat_template(
|
||||
[{"role": "user", "content": prompt}],
|
||||
add_generation_prompt=True,
|
||||
return_tensors="mlx",
|
||||
)
|
||||
|
||||
print("[INFO] Starting generation...")
|
||||
tic = time.time()
|
||||
tokens = []
|
||||
|
||||
detokenizer = tokenizer.detokenizer
|
||||
detokenizer.reset()
|
||||
|
||||
for n, token in enumerate(generate_step(prompt, model, max_tokens)):
|
||||
if n == 0:
|
||||
prompt_tps = prompt.size / (time.time() - tic)
|
||||
tic = time.time()
|
||||
|
||||
if token in tokenizer.eos_token_ids:
|
||||
break
|
||||
detokenizer.add_token(token)
|
||||
print(detokenizer.last_segment, end="", flush=True)
|
||||
|
||||
detokenizer.finalize()
|
||||
print(detokenizer.last_segment, flush=True)
|
||||
gen_tps = (n + 1) / (time.time() - tic)
|
||||
peak_memory = mx.metal.get_peak_memory() / 1e9
|
||||
print("=" * 10)
|
||||
print(f"Prompt: {prompt_tps:.3f} tokens-per-sec")
|
||||
print(f"Generation: {gen_tps:.3f} tokens-per-sec")
|
||||
print(f"Peak RAM: {peak_memory:.3f} GB")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(
|
||||
{
|
||||
"generate": generate,
|
||||
"export": export,
|
||||
}
|
||||
)
|
18
llms/export/main.cpp
Normal file
18
llms/export/main.cpp
Normal file
@ -0,0 +1,18 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "mlxlm.h"
|
||||
|
||||
int main(int argc, char *argv[]) {
|
||||
if (argc < 3) {
|
||||
std::cerr << "Must provide the model path and prompt." << std::endl;
|
||||
return 1;
|
||||
}
|
||||
auto path = std::string(argv[1]);
|
||||
auto prompt = std::string(argv[2]);
|
||||
|
||||
auto model = load_model(path + "/model.mlxfn");
|
||||
auto tokenizer = load_tokenizer(path);
|
||||
generate(model, tokenizer, prompt);
|
||||
}
|
119
llms/export/mlxlm.cpp
Normal file
119
llms/export/mlxlm.cpp
Normal file
@ -0,0 +1,119 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include <chrono>
|
||||
#include <iomanip>
|
||||
#include <iostream>
|
||||
|
||||
#include "mlxlm.h"
|
||||
|
||||
namespace mx = mlx::core;
|
||||
|
||||
#define seconds(x) \
|
||||
(std::chrono::duration_cast<std::chrono::nanoseconds>(x).count() / 1e9)
|
||||
#define time_now() std::chrono::high_resolution_clock::now()
|
||||
|
||||
// Maybe compile
|
||||
std::function<mx::Args(mx::Args)> load_model(const std::string& path) {
|
||||
return mx::compile(mx::import_function(path), /* shapeless = */ true);
|
||||
}
|
||||
|
||||
// Maybe make tokenizer virtual
|
||||
BPETokenizer load_tokenizer(const std::string& path) {
|
||||
return BPETokenizer(path);
|
||||
}
|
||||
|
||||
void generate(
|
||||
const std::function<mx::Args(mx::Args)>& model,
|
||||
const BPETokenizer& tokenizer,
|
||||
const std::string& prompt,
|
||||
int max_tokens /* = 256 */) {
|
||||
|
||||
auto prompt_tokens = tokenizer.encode(prompt);
|
||||
int prompt_size = prompt_tokens.size();
|
||||
auto y = mx::array(prompt_tokens.data(), {1, prompt_size}, mx::uint32);
|
||||
|
||||
auto create_causal_mask = [](int N) {
|
||||
auto indices = mx::arange(N);
|
||||
return mx::expand_dims(indices, 1) >= indices;
|
||||
};
|
||||
|
||||
// Helper to expand the cache and mask
|
||||
auto expand = [](auto& args, auto& mask) {
|
||||
constexpr int cache_step_size = 256;
|
||||
int cache_size = args[1].shape(-2);
|
||||
int new_size = cache_step_size * ((cache_size + cache_step_size) / cache_step_size);
|
||||
for (auto it = args.begin() + 1; it != args.end(); ++it) {
|
||||
auto& x = *it;
|
||||
auto shape = x.shape();
|
||||
shape[2] = new_size;
|
||||
auto new_x = mx::zeros(shape, x.dtype());
|
||||
shape[2] = cache_size;
|
||||
*it = mx::slice_update(new_x, x, mx::Shape(x.ndim(), 0), std::move(shape));
|
||||
}
|
||||
mask = mx::slice_update(mx::full({new_size}, false), mask, {0}, {cache_size});
|
||||
};
|
||||
|
||||
auto tic = time_now();
|
||||
float prompt_time;
|
||||
int n = 0;
|
||||
|
||||
mx::Args args;
|
||||
{
|
||||
args = model({y, create_causal_mask(y.size())});
|
||||
auto logits = args[0];
|
||||
logits = slice(logits, {0, -1, 0}, logits.shape());
|
||||
y = argmax(logits, -1);
|
||||
async_eval(y);
|
||||
}
|
||||
|
||||
auto offset = mx::array(prompt_size, mx::uint32);
|
||||
std::vector<int> tokens;
|
||||
|
||||
auto mask = mx::full({prompt_size}, true);
|
||||
expand(args, mask);
|
||||
|
||||
for (; n < max_tokens; ++n) {
|
||||
// Start next token decoding if needed
|
||||
if (n < max_tokens - 1) {
|
||||
args[0] = y;
|
||||
auto m = prompt_size + n;
|
||||
if (mask.size() <= m) {
|
||||
expand(args, mask);
|
||||
}
|
||||
mask = mx::slice_update(mask, mx::array(true), {m}, {m + 1});
|
||||
args.push_back(offset);
|
||||
args.push_back(mask);
|
||||
args = model(args);
|
||||
args[0] = argmax(args[0], -1);
|
||||
offset = offset + 1u;
|
||||
async_eval(args[0]);
|
||||
}
|
||||
|
||||
auto token = y.item<int>();
|
||||
if (token == tokenizer.eos_token_id()) {
|
||||
break;
|
||||
}
|
||||
tokens.push_back(token);
|
||||
auto [result, complete] = tokenizer.try_decode(tokens);
|
||||
if (complete) {
|
||||
std::cout << result << std::flush;
|
||||
tokens.clear();
|
||||
}
|
||||
if (n == 0) {
|
||||
prompt_time = seconds(time_now() - tic);
|
||||
tic = time_now();
|
||||
}
|
||||
|
||||
if (n < max_tokens - 1) {
|
||||
y = args[0];
|
||||
}
|
||||
}
|
||||
auto result = tokenizer.decode(tokens);
|
||||
std::cout << result << std::flush;
|
||||
|
||||
auto gen_time = seconds(time_now() - tic);
|
||||
std::cout << std::endl;
|
||||
std::cout << std::setprecision(5) << "Prompt toks/sec "
|
||||
<< prompt_size / prompt_time << "\nGeneration toks/sec "
|
||||
<< (n + 1) / gen_time << std::endl;
|
||||
}
|
20
llms/export/mlxlm.h
Normal file
20
llms/export/mlxlm.h
Normal file
@ -0,0 +1,20 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include <mlx/mlx.h>
|
||||
|
||||
#include "tokenizer.h"
|
||||
|
||||
namespace mx = mlx::core;
|
||||
|
||||
std::function<mx::Args(mx::Args)> load_model(const std::string& path);
|
||||
|
||||
BPETokenizer load_tokenizer(const std::string& path);
|
||||
|
||||
struct GenerationResponse {
|
||||
};
|
||||
|
||||
void generate(
|
||||
const std::function<mx::Args(mx::Args)>& model,
|
||||
const BPETokenizer& tokenizer,
|
||||
const std::string& prompt,
|
||||
int max_tokens = 256);
|
23
llms/export/test.cpp
Normal file
23
llms/export/test.cpp
Normal file
@ -0,0 +1,23 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include "tokenizer.h"
|
||||
#include <iostream>
|
||||
|
||||
template <typename T, typename U = T> void check(const T &x, const U &y) {
|
||||
if (x != y) {
|
||||
std::cerr << "Mismatch" << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
void test_tokenizer(const std::string &path) {
|
||||
BPETokenizer tokenizer(path);
|
||||
check(tokenizer.encode("hello world!"), {128000, 15339, 1917, 0});
|
||||
check(tokenizer.decode({15339}), "hello");
|
||||
check(tokenizer.decode({0}), "!");
|
||||
check(tokenizer.decode({1917}), " world");
|
||||
check(tokenizer.encode("we'd see you say 世界你好真实好的很啊"),
|
||||
{128000, 906, 4265, 220, 1518, 256, 499, 2019, 127365, 57668, 53901,
|
||||
89151, 41073, 110085, 101600, 102856});
|
||||
}
|
||||
|
||||
int main(int argc, char *argv[]) { test_tokenizer("."); }
|
189
llms/export/tokenizer.cpp
Normal file
189
llms/export/tokenizer.cpp
Normal file
@ -0,0 +1,189 @@
|
||||
|
||||
#include <fstream>
|
||||
#include <filesystem>
|
||||
#include <locale>
|
||||
#include <codecvt>
|
||||
#include <json.hpp>
|
||||
|
||||
#include "tokenizer.h"
|
||||
#include "unicode.h"
|
||||
|
||||
using json = nlohmann::json;
|
||||
|
||||
#pragma GCC diagnostic push
|
||||
#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
|
||||
std::pair<std::wstring, int> utf8_to_utf16(const std::string& s) {
|
||||
static std::string replace_str = std::string(1, 0xFF);
|
||||
static std::wstring replace_wstr = std::wstring(1, 0xFFFD);
|
||||
std::wstring_convert<std::codecvt_utf8_utf16<wchar_t>> cvt(replace_str, replace_wstr);
|
||||
auto out = cvt.from_bytes(s);
|
||||
return {out, cvt.converted()};
|
||||
}
|
||||
#pragma GCC diagnostic pop
|
||||
|
||||
auto make_byte_decoder() {
|
||||
std::unordered_map<uint16_t, char> byte_decoder;
|
||||
std::vector<uint16_t> limits = {0, '!', '~' + 1, L'¡', L'¬' + 1, L'®', L'ÿ' + 1};
|
||||
char n = 0;
|
||||
for (int i = 0; i < limits.size() - 1; ++i) {
|
||||
auto start = limits[i];
|
||||
auto stop = limits[i + 1];
|
||||
if (i % 2 == 0) {
|
||||
for (int b = start; b < stop; ++b) {
|
||||
byte_decoder[256 + n++] = b;
|
||||
}
|
||||
} else {
|
||||
for (int b = start; b < stop; ++b) {
|
||||
byte_decoder[b] = b;
|
||||
}
|
||||
}
|
||||
}
|
||||
return byte_decoder;
|
||||
}
|
||||
|
||||
auto BPETokenizer::byte_decoder_ = make_byte_decoder();
|
||||
|
||||
BPETokenizer::BPETokenizer(const std::string& path_) {
|
||||
auto path = std::filesystem::path(path_);
|
||||
std::ifstream ifs(path / "tokenizer.json");
|
||||
auto tokenizer = json::parse(ifs);
|
||||
auto model = tokenizer["model"];
|
||||
token_to_id_ = model["vocab"];
|
||||
id_to_token_.resize(token_to_id_.size());
|
||||
for (auto& [s, id] : token_to_id_) {
|
||||
if (id >= id_to_token_.size()) {
|
||||
id_to_token_.resize(id + 1);
|
||||
}
|
||||
id_to_token_[id] = s;
|
||||
}
|
||||
std::string type = model["type"];
|
||||
auto merges = model["merges"];
|
||||
for (auto& s : merges) {
|
||||
if (s.is_string()) {
|
||||
merges_.emplace(s, merges_.size());
|
||||
} else {
|
||||
std::string s1 = s[0];
|
||||
std::string s2 = s[1];
|
||||
merges_.emplace(s1 + " " + s2, merges_.size());
|
||||
}
|
||||
}
|
||||
|
||||
auto added_tokens = tokenizer["added_tokens"];
|
||||
for (auto& added_token : added_tokens) {
|
||||
int id = added_token["id"];
|
||||
if (id >= id_to_token_.size()) {
|
||||
id_to_token_.resize(id + 1);
|
||||
}
|
||||
id_to_token_[id] = added_token["content"];
|
||||
if (id_to_token_[id] == "<|begin_of_text|>") {
|
||||
bos_id_ = id;
|
||||
} else if (id_to_token_[id] == "<|eot_id|>") {
|
||||
eos_id_ = id;
|
||||
}
|
||||
}
|
||||
|
||||
// Currently hardcoded to Llama3 BPE regex
|
||||
pre_tokenizer_regex_ = {"(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"};
|
||||
}
|
||||
|
||||
std::vector<int> BPETokenizer::encode(std::string text) const {
|
||||
|
||||
auto segments = unicode_regex_split(text, pre_tokenizer_regex_);
|
||||
|
||||
auto one_step_merge = [this](std::string segment, std::vector<int>& splits) {
|
||||
int merge_idx;
|
||||
int rank = INT32_MAX;
|
||||
for (int i = 0; i < splits.size() - 2; ++i) {
|
||||
auto start = splits[i];
|
||||
auto mid = splits[i + 1];
|
||||
auto end = splits[i + 2];
|
||||
std::string candidate = segment.substr(start, mid - start);
|
||||
candidate += " ";
|
||||
candidate += segment.substr(mid, end - mid);
|
||||
if (auto it = merges_.find(candidate); it != merges_.end()) {
|
||||
if (it->second < rank) {
|
||||
merge_idx = i;
|
||||
rank = it->second;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (rank == INT32_MAX) {
|
||||
return false;
|
||||
}
|
||||
auto start = splits[merge_idx];
|
||||
auto mid = splits[merge_idx + 1];
|
||||
auto end = splits[merge_idx + 2];
|
||||
std::string merge_l = segment.substr(start, mid - start);
|
||||
std::string merge_r = segment.substr(mid, end - mid);
|
||||
for (int i = splits.size() - 2; i >= 0; --i) {
|
||||
auto start = splits[i];
|
||||
auto mid = splits[i + 1];
|
||||
auto end = splits[i + 2];
|
||||
if (segment.substr(start, mid - start) == merge_l && segment.substr(mid, end - mid) == merge_r) {
|
||||
splits.erase(splits.begin() + i + 1);
|
||||
i -= 1;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
};
|
||||
|
||||
std::vector<int> ids;
|
||||
ids.push_back(bos_id_);
|
||||
|
||||
// Initialize merges to integer list
|
||||
auto merge_segment = [&ids, &one_step_merge, this](const std::string& segment) {
|
||||
|
||||
std::vector<int> splits;
|
||||
for (int i = 0; i < segment.size(); ++i) {
|
||||
splits.push_back(i);
|
||||
if (static_cast<unsigned char>(segment[i]) > 128) {
|
||||
i++;
|
||||
}
|
||||
}
|
||||
splits.push_back(segment.size());
|
||||
|
||||
while (one_step_merge(segment, splits)) { };
|
||||
for (int i = 0; i < splits.size() - 1; ++i) {
|
||||
auto start = splits[i];
|
||||
auto end = splits[i + 1];
|
||||
std::string s = segment.substr(start, end - start);
|
||||
if (auto it = token_to_id_.find(s); it != token_to_id_.end()) {
|
||||
ids.push_back(it->second);
|
||||
} else {
|
||||
throw std::runtime_error("UNK ENCOUNTERED");
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
for (auto& segment : segments) {
|
||||
merge_segment(segment);
|
||||
}
|
||||
return ids;
|
||||
}
|
||||
|
||||
std::string BPETokenizer::id_to_bytes(int id) const {
|
||||
std::string token;
|
||||
auto [wide_token, _] = utf8_to_utf16(id_to_token_[id]);
|
||||
token.resize(wide_token.size());
|
||||
for (int i = 0; i < wide_token.size(); ++i) {
|
||||
token[i] = byte_decoder_[wide_token[i]];
|
||||
}
|
||||
return token;
|
||||
}
|
||||
|
||||
std::pair<std::string, bool> BPETokenizer::try_decode(const std::vector<int>& ids) const {
|
||||
std::string text;
|
||||
for (auto id : ids) {
|
||||
text += id_to_bytes(id);
|
||||
}
|
||||
auto [_, converted] = utf8_to_utf16(text);
|
||||
bool complete = converted == text.size();
|
||||
text.resize(converted);
|
||||
return {text, complete};
|
||||
}
|
||||
|
||||
std::string BPETokenizer::decode(const std::vector<int>& ids) const {
|
||||
return try_decode(ids).first;
|
||||
}
|
||||
|
||||
int BPETokenizer::eos_token_id() const { return eos_id_; }
|
37
llms/export/tokenizer.h
Normal file
37
llms/export/tokenizer.h
Normal file
@ -0,0 +1,37 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#pragma once
|
||||
|
||||
/** BPE Tokenizer API */
|
||||
class BPETokenizer {
|
||||
public:
|
||||
BPETokenizer(const std::string& path);
|
||||
|
||||
/** Encode a string of text to token integer ids. */
|
||||
std::vector<int> encode(std::string text) const;
|
||||
|
||||
/** Try to decode the vector of ids to text. The text is truncated to
|
||||
* include only the fully decodable tokens. */
|
||||
std::string decode(const std::vector<int>& ids) const;
|
||||
|
||||
/** Try to decode the vector of ids to text. The second return value
|
||||
* indicates if the decoding completed. The text is truncated to include
|
||||
* only the fully decodable tokens. */
|
||||
std::pair<std::string, bool> try_decode(const std::vector<int>& ids) const;
|
||||
|
||||
int eos_token_id() const;
|
||||
|
||||
private:
|
||||
std::unordered_map<std::string, int> token_to_id_;
|
||||
std::vector<std::string> id_to_token_;
|
||||
std::unordered_map<std::string, int> merges_;
|
||||
int bos_id_;
|
||||
int eos_id_;
|
||||
static std::unordered_map<uint16_t, char> byte_decoder_;
|
||||
std::string id_to_bytes(int id) const;
|
||||
std::vector<std::string> pre_tokenizer_regex_;
|
||||
};
|
842
llms/export/unicode.cpp
Normal file
842
llms/export/unicode.cpp
Normal file
@ -0,0 +1,842 @@
|
||||
#if defined(_MSC_VER)
|
||||
#define _SILENCE_CXX17_CODECVT_HEADER_DEPRECATION_WARNING
|
||||
#endif
|
||||
|
||||
#include "unicode.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <map>
|
||||
#include <regex>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include <locale>
|
||||
#include <codecvt>
|
||||
|
||||
size_t unicode_len_utf8(char src) {
|
||||
const size_t lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
|
||||
uint8_t highbits = static_cast<uint8_t>(src) >> 4;
|
||||
return lookup[highbits];
|
||||
}
|
||||
|
||||
static std::string unicode_cpts_to_utf8(const std::vector<uint32_t> & cps) {
|
||||
std::string result;
|
||||
for (size_t i = 0; i < cps.size(); ++i) {
|
||||
result.append(unicode_cpt_to_utf8(cps[i]));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
uint32_t unicode_cpt_from_utf8(const std::string & utf8, size_t & offset) {
|
||||
assert(offset < utf8.size());
|
||||
if (!(utf8[offset + 0] & 0x80)) {
|
||||
auto result = utf8[offset + 0];
|
||||
offset += 1;
|
||||
return result;
|
||||
}
|
||||
if (!(utf8[offset + 0] & 0x40)) {
|
||||
throw std::invalid_argument("invalid character");
|
||||
}
|
||||
if (!(utf8[offset + 0] & 0x20)) {
|
||||
if (offset + 1 >= utf8.size() || ! ((utf8[offset + 1] & 0xc0) == 0x80)) {
|
||||
throw std::invalid_argument("invalid character");
|
||||
}
|
||||
auto result = ((utf8[offset + 0] & 0x1f) << 6) | (utf8[offset + 1] & 0x3f);
|
||||
offset += 2;
|
||||
return result;
|
||||
}
|
||||
if (!(utf8[offset + 0] & 0x10)) {
|
||||
if (offset + 2 >= utf8.size() || ! ((utf8[offset + 1] & 0xc0) == 0x80) || ! ((utf8[offset + 2] & 0xc0) == 0x80)) {
|
||||
throw std::invalid_argument("invalid character");
|
||||
}
|
||||
auto result = ((utf8[offset + 0] & 0x0f) << 12) | ((utf8[offset + 1] & 0x3f) << 6) | (utf8[offset + 2] & 0x3f);
|
||||
offset += 3;
|
||||
return result;
|
||||
}
|
||||
if (!(utf8[offset + 0] & 0x08)) {
|
||||
if (offset + 3 >= utf8.size() || ! ((utf8[offset + 1] & 0xc0) == 0x80) || ! ((utf8[offset + 2] & 0xc0) == 0x80) || !((utf8[offset + 3] & 0xc0) == 0x80)) {
|
||||
throw std::invalid_argument("invalid character");
|
||||
}
|
||||
auto result = ((utf8[offset + 0] & 0x07) << 18) | ((utf8[offset + 1] & 0x3f) << 12) | ((utf8[offset + 2] & 0x3f) << 6) | (utf8[offset + 3] & 0x3f);
|
||||
offset += 4;
|
||||
return result;
|
||||
}
|
||||
throw std::invalid_argument("failed to convert utf8 to codepoint");
|
||||
}
|
||||
|
||||
//static std::vector<uint16_t> unicode_cpt_to_utf16(uint32_t cpt) {
|
||||
// std::vector<uint16_t> result;
|
||||
// if (/* 0x0000 <= cpt && */ cpt <= 0xffff) {
|
||||
// result.emplace_back(cpt);
|
||||
// return result;
|
||||
// }
|
||||
// if (0x10000 <= cpt && cpt <= 0x10ffff) {
|
||||
// result.emplace_back(0xd800 | ((cpt - 0x10000) >> 10));
|
||||
// result.emplace_back(0xdc00 | ((cpt - 0x10000) & 0x03ff));
|
||||
// return result;
|
||||
// }
|
||||
// throw std::invalid_argument("failed to convert codepoint to utf16");
|
||||
//}
|
||||
|
||||
//static std::vector<uint16_t> unicode_cpts_to_utf16(const std::vector<uint32_t> & cps) {
|
||||
// std::vector<uint16_t> result;
|
||||
// for (size_t i = 0; i < cps.size(); ++i) {
|
||||
// auto temp = unicode_cpt_to_utf16(cps[i]);
|
||||
// result.insert(result.end(), temp.begin(), temp.end());
|
||||
// }
|
||||
// return result;
|
||||
//}
|
||||
|
||||
//static uint32_t unicode_cpt_from_utf16(const std::vector<uint16_t> & utf16, size_t & offset) {
|
||||
// assert(offset < utf16.size());
|
||||
// if (((utf16[0] >> 10) << 10) != 0xd800) {
|
||||
// auto result = utf16[offset + 0];
|
||||
// offset += 1;
|
||||
// return result;
|
||||
// }
|
||||
//
|
||||
// if (offset + 1 >= utf16.size() || !((utf16[1] & 0xdc00) == 0xdc00)) {
|
||||
// throw std::invalid_argument("invalid character");
|
||||
// }
|
||||
//
|
||||
// auto result = 0x10000 + (((utf16[0] & 0x03ff) << 10) | (utf16[1] & 0x03ff));
|
||||
// offset += 2;
|
||||
// return result;
|
||||
//}
|
||||
|
||||
//static std::vector<uint32_t> unicode_cpts_from_utf16(const std::vector<uint16_t> & utf16) {
|
||||
// std::vector<uint32_t> result;
|
||||
// size_t offset = 0;
|
||||
// while (offset < utf16.size()) {
|
||||
// result.push_back(unicode_cpt_from_utf16(utf16, offset));
|
||||
// }
|
||||
// return result;
|
||||
//}
|
||||
|
||||
static std::vector<unicode_cpt_flags> unicode_cpt_flags_array() {
|
||||
std::vector<unicode_cpt_flags> cpt_flags(MAX_CODEPOINTS, unicode_cpt_flags::UNDEFINED);
|
||||
|
||||
assert (unicode_ranges_flags.begin()[0].first == 0);
|
||||
assert (unicode_ranges_flags.begin()[unicode_ranges_flags.size()-1].first == MAX_CODEPOINTS);
|
||||
for (size_t i = 1; i < unicode_ranges_flags.size(); ++i) {
|
||||
const auto range_ini = unicode_ranges_flags.begin()[i-1]; // codepoint_ini, flags
|
||||
const auto range_end = unicode_ranges_flags.begin()[i]; // codepoint_end, flags
|
||||
for (uint32_t cpt = range_ini.first; cpt < range_end.first; ++cpt) {
|
||||
cpt_flags[cpt] = range_ini.second;
|
||||
}
|
||||
}
|
||||
|
||||
for (auto cpt : unicode_set_whitespace) {
|
||||
cpt_flags[cpt].is_whitespace = true;
|
||||
}
|
||||
|
||||
for (auto p : unicode_map_lowercase) {
|
||||
cpt_flags[p.second].is_lowercase = true;
|
||||
}
|
||||
|
||||
for (auto p : unicode_map_uppercase) {
|
||||
cpt_flags[p.second].is_uppercase = true;
|
||||
}
|
||||
|
||||
for (auto &range : unicode_ranges_nfd) { // start, last, nfd
|
||||
cpt_flags[range.nfd].is_nfd = true;
|
||||
}
|
||||
|
||||
return cpt_flags;
|
||||
}
|
||||
|
||||
static std::unordered_map<uint8_t, std::string> unicode_byte_to_utf8_map() {
|
||||
std::unordered_map<uint8_t, std::string> map;
|
||||
for (int ch = 0x21; ch <= 0x7E; ++ch) { // u'!' to u'~'
|
||||
assert(0 <= ch && ch < 256);
|
||||
map[ch] = unicode_cpt_to_utf8(ch);
|
||||
}
|
||||
for (int ch = 0xA1; ch <= 0xAC; ++ch) { // u'¡' to u'¬'
|
||||
assert(0 <= ch && ch < 256);
|
||||
map[ch] = unicode_cpt_to_utf8(ch);
|
||||
}
|
||||
for (int ch = 0xAE; ch <= 0xFF; ++ch) { // u'®' to u'ÿ'
|
||||
assert(0 <= ch && ch < 256);
|
||||
map[ch] = unicode_cpt_to_utf8(ch);
|
||||
}
|
||||
auto n = 0;
|
||||
for (int ch = 0; ch < 256; ++ch) {
|
||||
if (map.find(ch) == map.end()) {
|
||||
map[ch] = unicode_cpt_to_utf8(256 + n);
|
||||
++n;
|
||||
}
|
||||
}
|
||||
return map;
|
||||
}
|
||||
|
||||
static std::unordered_map<std::string, uint8_t> unicode_utf8_to_byte_map() {
|
||||
std::unordered_map<std::string, uint8_t> map;
|
||||
for (int ch = 0x21; ch <= 0x7E; ++ch) { // u'!' to u'~'
|
||||
assert(0 <= ch && ch < 256);
|
||||
map[unicode_cpt_to_utf8(ch)] = ch;
|
||||
}
|
||||
for (int ch = 0xA1; ch <= 0xAC; ++ch) { // u'¡' to u'¬'
|
||||
assert(0 <= ch && ch < 256);
|
||||
map[unicode_cpt_to_utf8(ch)] = ch;
|
||||
}
|
||||
for (int ch = 0xAE; ch <= 0xFF; ++ch) { // u'®' to u'ÿ'
|
||||
assert(0 <= ch && ch < 256);
|
||||
map[unicode_cpt_to_utf8(ch)] = ch;
|
||||
}
|
||||
auto n = 0;
|
||||
for (int ch = 0; ch < 256; ++ch) {
|
||||
if (map.find(unicode_cpt_to_utf8(ch)) == map.end()) {
|
||||
map[unicode_cpt_to_utf8(256 + n)] = ch;
|
||||
++n;
|
||||
}
|
||||
}
|
||||
return map;
|
||||
}
|
||||
|
||||
static inline std::wstring unicode_wstring_from_utf8(const std::string & s) {
|
||||
#if defined(__clang__)
|
||||
// disable C++17 deprecation warning for std::codecvt_utf8
|
||||
# pragma clang diagnostic push
|
||||
# pragma clang diagnostic ignored "-Wdeprecated-declarations"
|
||||
#endif
|
||||
|
||||
std::wstring_convert<std::codecvt_utf8<wchar_t>> conv;
|
||||
|
||||
#if defined(__clang__)
|
||||
# pragma clang diagnostic pop
|
||||
#endif
|
||||
|
||||
return conv.from_bytes(s);
|
||||
}
|
||||
|
||||
static std::vector<std::string> unicode_byte_encoding_process(const std::vector<std::string> & bpe_words) {
|
||||
std::vector<std::string> bpe_encoded_words;
|
||||
for (const auto & word : bpe_words) {
|
||||
std::string text_utf;
|
||||
auto utf_word = unicode_cpts_from_utf8(word);
|
||||
for (size_t i = 0; i < utf_word.size(); ++i) {
|
||||
text_utf += unicode_cpt_to_utf8(utf_word[i]);
|
||||
}
|
||||
|
||||
std::string encoded_token;
|
||||
for (char & c : text_utf) {
|
||||
encoded_token += unicode_byte_to_utf8(c);
|
||||
}
|
||||
bpe_encoded_words.emplace_back(encoded_token);
|
||||
}
|
||||
return bpe_encoded_words;
|
||||
}
|
||||
|
||||
// GPT2 system regex: 's|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+
|
||||
static std::vector<size_t> unicode_regex_split_custom_gpt2(const std::string & text, const std::vector<size_t> & offsets) {
|
||||
std::vector<size_t> bpe_offsets; // store the offset of each word
|
||||
bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size
|
||||
|
||||
const auto cpts = unicode_cpts_from_utf8(text);
|
||||
|
||||
size_t start = 0;
|
||||
for (auto offset : offsets) {
|
||||
const size_t offset_ini = start;
|
||||
const size_t offset_end = start + offset;
|
||||
assert(offset_end <= cpts.size());
|
||||
start = offset_end;
|
||||
|
||||
static const uint32_t OUT_OF_RANGE = 0xFFFFFFFF;
|
||||
auto _get_cpt = [&] (const size_t pos) -> uint32_t {
|
||||
return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : OUT_OF_RANGE;
|
||||
};
|
||||
|
||||
auto _get_flags = [&] (const size_t pos) -> unicode_cpt_flags {
|
||||
return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags_from_cpt(cpts[pos]) : unicode_cpt_flags{};
|
||||
};
|
||||
|
||||
size_t _prev_end = offset_ini;
|
||||
auto _add_token = [&] (const size_t end) -> size_t {
|
||||
assert(_prev_end <= end && end <= offset_end);
|
||||
size_t len = end - _prev_end;
|
||||
if (len > 0) {
|
||||
bpe_offsets.push_back(len);
|
||||
}
|
||||
_prev_end = end;
|
||||
//if (len > 0) {
|
||||
// std::string s = "";
|
||||
// for(size_t p = end-len; p < end; p++)
|
||||
// s += unicode_cpt_to_utf8(cpts[p]);
|
||||
// printf(">>> '%s'\n", s.c_str());
|
||||
//}
|
||||
return len;
|
||||
};
|
||||
|
||||
for (size_t pos = offset_ini; pos < offset_end; /*pos++*/ ) {
|
||||
const uint32_t cpt = _get_cpt(pos);
|
||||
const auto flags = _get_flags(pos);
|
||||
|
||||
// regex: 's|'t|'re|'ve|'m|'ll|'d
|
||||
if (cpt == '\'' && pos+1 < offset_end) {
|
||||
uint32_t cpt_next = _get_cpt(pos+1);
|
||||
if (cpt_next == 's' || cpt_next == 't' || cpt_next == 'm' || cpt_next == 'd') {
|
||||
pos += _add_token(pos+2);
|
||||
continue;
|
||||
}
|
||||
if (pos+2 < offset_end) {
|
||||
uint32_t cpt_next_next = _get_cpt(pos+2);
|
||||
if ((cpt_next == 'r' && cpt_next_next == 'e') ||
|
||||
(cpt_next == 'v' && cpt_next_next == 'e') ||
|
||||
(cpt_next == 'l' && cpt_next_next == 'l')) {
|
||||
pos += _add_token(pos+3);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
auto flags2 = (cpt == ' ' ? _get_flags(pos+1) : flags);
|
||||
// regex: <space>?\p{L}+
|
||||
if (flags2.is_letter) {
|
||||
pos += (cpt == ' ');
|
||||
while (flags2.is_letter) {
|
||||
flags2 = _get_flags(++pos);
|
||||
}
|
||||
_add_token(pos);
|
||||
continue;
|
||||
}
|
||||
// regex: <space>?\p{N}+
|
||||
if (flags2.is_number) {
|
||||
pos += (cpt == ' ');
|
||||
while (flags2.is_number) {
|
||||
flags2 = _get_flags(++pos);
|
||||
}
|
||||
_add_token(pos);
|
||||
continue;
|
||||
}
|
||||
// regex: <space>?[^\s\p{L}\p{N}]+
|
||||
if (!(flags2.is_whitespace | flags2.is_letter | flags2.is_number) && flags2.as_uint()) {
|
||||
pos += (cpt == ' ');
|
||||
while (!(flags2.is_whitespace | flags2.is_letter | flags2.is_number) && flags2.as_uint()) {
|
||||
flags2 = _get_flags(++pos);
|
||||
}
|
||||
_add_token(pos);
|
||||
continue;
|
||||
}
|
||||
|
||||
size_t num_whitespaces = 0;
|
||||
while (_get_flags(pos+num_whitespaces).is_whitespace) {
|
||||
num_whitespaces++;
|
||||
}
|
||||
|
||||
// regex: \s+(?!\S)
|
||||
if (num_whitespaces > 1 && _get_cpt(pos+num_whitespaces) != OUT_OF_RANGE) {
|
||||
pos += num_whitespaces - 1;
|
||||
_add_token(pos);
|
||||
continue;
|
||||
}
|
||||
|
||||
// regex: \s+
|
||||
if (num_whitespaces > 0) {
|
||||
pos += num_whitespaces;
|
||||
_add_token(pos);
|
||||
continue;
|
||||
}
|
||||
|
||||
// no matches
|
||||
_add_token(++pos);
|
||||
}
|
||||
}
|
||||
|
||||
return bpe_offsets;
|
||||
}
|
||||
|
||||
// LLAMA3 system regex: "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"
|
||||
static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string & text, const std::vector<size_t> & offsets) {
|
||||
std::vector<size_t> bpe_offsets; // store the offset of each word
|
||||
bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size
|
||||
|
||||
const auto cpts = unicode_cpts_from_utf8(text);
|
||||
|
||||
size_t start = 0;
|
||||
for (auto offset : offsets) {
|
||||
const size_t offset_ini = start;
|
||||
const size_t offset_end = start + offset;
|
||||
assert(offset_end <= cpts.size());
|
||||
start = offset_end;
|
||||
|
||||
static const uint32_t OUT_OF_RANGE = 0xFFFFFFFF;
|
||||
auto _get_cpt = [&] (const size_t pos) -> uint32_t {
|
||||
return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : OUT_OF_RANGE;
|
||||
};
|
||||
|
||||
auto _get_flags = [&] (const size_t pos) -> unicode_cpt_flags {
|
||||
return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags_from_cpt(cpts[pos]) : unicode_cpt_flags{};
|
||||
};
|
||||
|
||||
size_t _prev_end = offset_ini;
|
||||
auto _add_token = [&] (const size_t end) -> size_t {
|
||||
assert(_prev_end <= end && end <= offset_end);
|
||||
size_t len = end - _prev_end;
|
||||
if (len > 0) {
|
||||
bpe_offsets.push_back(len);
|
||||
}
|
||||
_prev_end = end;
|
||||
//if (len > 0) {
|
||||
// std::string s = "";
|
||||
// for(size_t p = end-len; p < end; p++)
|
||||
// s += unicode_cpt_to_utf8(cpts[p]);
|
||||
// printf(">>> '%s'\n", s.c_str());
|
||||
//}
|
||||
return len;
|
||||
};
|
||||
|
||||
for (size_t pos = offset_ini; pos < offset_end; /*pos++*/ ) {
|
||||
const uint32_t cpt = _get_cpt(pos);
|
||||
const auto flags = _get_flags(pos);
|
||||
|
||||
// regex: (?i:'s|'t|'re|'ve|'m|'ll|'d) // case insensitive
|
||||
if (cpt == '\'' && pos+1 < offset_end) {
|
||||
uint32_t cpt_next = unicode_tolower(_get_cpt(pos+1));
|
||||
if (cpt_next == 's' || cpt_next == 't' || cpt_next == 'm' || cpt_next == 'd') {
|
||||
pos += _add_token(pos+2);
|
||||
continue;
|
||||
}
|
||||
if (pos+2 < offset_end) {
|
||||
uint32_t cpt_next_next = unicode_tolower(_get_cpt(pos+2));
|
||||
if ((cpt_next == 'r' && cpt_next_next == 'e') ||
|
||||
(cpt_next == 'v' && cpt_next_next == 'e') ||
|
||||
(cpt_next == 'l' && cpt_next_next == 'l')) {
|
||||
pos += _add_token(pos+3);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// regex: [^\r\n\p{L}\p{N}]?\p{L}+
|
||||
if (!(cpt == '\r' || cpt == '\n' || flags.is_number)) {
|
||||
if (flags.is_letter || _get_flags(pos+1).is_letter) { // one or more letters
|
||||
pos++;
|
||||
while (_get_flags(pos).is_letter) {
|
||||
pos++;
|
||||
}
|
||||
_add_token(pos);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
// regex: \p{N}{1,3}
|
||||
if (flags.is_number) {
|
||||
size_t ini = pos;
|
||||
while (_get_flags(pos).is_number) {
|
||||
if (++pos - ini >= 3 ) {
|
||||
_add_token(pos);
|
||||
ini = pos;
|
||||
}
|
||||
}
|
||||
_add_token(pos);
|
||||
continue;
|
||||
}
|
||||
|
||||
// regex: <space>?[^\s\p{L}\p{N}]+[\r\n]*
|
||||
auto flags2 = (cpt == ' ' ? _get_flags(pos+1) : flags);
|
||||
if (!(flags2.is_whitespace | flags2.is_letter | flags2.is_number) && flags.as_uint()) {
|
||||
pos += (cpt == ' ');
|
||||
while (!(flags2.is_whitespace | flags2.is_letter | flags2.is_number) && flags2.as_uint()) {
|
||||
flags2 = _get_flags(++pos);
|
||||
}
|
||||
uint32_t cpt2 = _get_cpt(pos);
|
||||
while (cpt2 == '\r' || cpt2 == '\n') {
|
||||
cpt2 = _get_cpt(++pos);
|
||||
}
|
||||
_add_token(pos);
|
||||
continue;
|
||||
}
|
||||
|
||||
size_t num_whitespaces = 0;
|
||||
size_t last_end_r_or_n = 0;
|
||||
while (_get_flags(pos+num_whitespaces).is_whitespace) {
|
||||
uint32_t cpt2 = _get_cpt(pos+num_whitespaces);
|
||||
if (cpt2 == '\r' || cpt2 == '\n') {
|
||||
last_end_r_or_n = pos + num_whitespaces + 1;
|
||||
}
|
||||
num_whitespaces++;
|
||||
}
|
||||
|
||||
// regex: \s*[\r\n]+
|
||||
if (last_end_r_or_n > 0) {
|
||||
pos = last_end_r_or_n;
|
||||
_add_token(pos);
|
||||
continue;
|
||||
}
|
||||
|
||||
// regex: \s+(?!\S)
|
||||
if (num_whitespaces > 1 && _get_cpt(pos+num_whitespaces) != OUT_OF_RANGE) {
|
||||
pos += num_whitespaces - 1;
|
||||
_add_token(pos);
|
||||
continue;
|
||||
}
|
||||
|
||||
// regex: \s+
|
||||
if (num_whitespaces > 0) {
|
||||
pos += num_whitespaces;
|
||||
_add_token(pos);
|
||||
continue;
|
||||
}
|
||||
|
||||
// no matches
|
||||
_add_token(++pos);
|
||||
}
|
||||
}
|
||||
|
||||
return bpe_offsets;
|
||||
}
|
||||
|
||||
// use std::wregex to split the text
|
||||
static std::vector<size_t> unicode_regex_split_stl(const std::wstring & wtext, const std::wstring & regex_expr, const std::vector<size_t> & offsets) {
|
||||
std::wregex expr(regex_expr);
|
||||
std::vector<size_t> bpe_offsets; // store the offset of each word
|
||||
bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size
|
||||
size_t start = 0;
|
||||
for (auto offset : offsets) {
|
||||
std::wcregex_iterator it(wtext.data() + start, wtext.data() + start + offset, expr);
|
||||
std::wcregex_iterator end;
|
||||
|
||||
int64_t start_idx = 0;
|
||||
while (it != end) {
|
||||
std::wcmatch match = *it;
|
||||
if (match.position() > start_idx) {
|
||||
bpe_offsets.emplace_back(match.position() - start_idx);
|
||||
}
|
||||
bpe_offsets.emplace_back(match.length());
|
||||
start_idx = match.position() + match.length();
|
||||
++it;
|
||||
}
|
||||
|
||||
if (start_idx < (int64_t) offset) {
|
||||
bpe_offsets.emplace_back(offset - start_idx);
|
||||
}
|
||||
start += offset;
|
||||
}
|
||||
|
||||
return bpe_offsets;
|
||||
}
|
||||
|
||||
// use std::regex to split the text
|
||||
static std::vector<size_t> unicode_regex_split_stl(const std::string & text, const std::string & regex_expr, const std::vector<size_t> & offsets) {
|
||||
std::regex expr(regex_expr);
|
||||
std::vector<size_t> bpe_offsets; // store the offset of each word
|
||||
bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size
|
||||
size_t start = 0;
|
||||
for (auto offset : offsets) {
|
||||
std::cregex_iterator it(text.data() + start, text.data() + start + offset, expr);
|
||||
std::cregex_iterator end;
|
||||
|
||||
int64_t start_idx = 0;
|
||||
while (it != end) {
|
||||
std::cmatch match = *it;
|
||||
if (match.position() > start_idx) {
|
||||
bpe_offsets.emplace_back(match.position() - start_idx);
|
||||
}
|
||||
bpe_offsets.emplace_back(match.length());
|
||||
start_idx = match.position() + match.length();
|
||||
++it;
|
||||
}
|
||||
|
||||
if (start_idx < (int64_t) offset) {
|
||||
bpe_offsets.emplace_back(offset - start_idx);
|
||||
}
|
||||
start += offset;
|
||||
}
|
||||
|
||||
return bpe_offsets;
|
||||
}
|
||||
|
||||
static std::vector<size_t> unicode_regex_split_custom(const std::string & text, const std::string & regex_expr, const std::vector<size_t> & offsets) {
|
||||
std::vector<size_t> bpe_offsets;
|
||||
|
||||
if (regex_expr == "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)") {
|
||||
bpe_offsets = unicode_regex_split_custom_gpt2(text, offsets);
|
||||
} else if (
|
||||
regex_expr == "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+" ||
|
||||
regex_expr == "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+") {
|
||||
|
||||
bpe_offsets = unicode_regex_split_custom_llama3(text, offsets);
|
||||
}
|
||||
|
||||
return bpe_offsets;
|
||||
}
|
||||
|
||||
//
|
||||
// interface
|
||||
//
|
||||
|
||||
std::string unicode_cpt_to_utf8(uint32_t cpt) {
|
||||
std::string result;
|
||||
|
||||
if (/* 0x00 <= cpt && */ cpt <= 0x7f) {
|
||||
result.push_back(cpt);
|
||||
return result;
|
||||
}
|
||||
if (0x80 <= cpt && cpt <= 0x7ff) {
|
||||
result.push_back(0xc0 | ((cpt >> 6) & 0x1f));
|
||||
result.push_back(0x80 | (cpt & 0x3f));
|
||||
return result;
|
||||
}
|
||||
if (0x800 <= cpt && cpt <= 0xffff) {
|
||||
result.push_back(0xe0 | ((cpt >> 12) & 0x0f));
|
||||
result.push_back(0x80 | ((cpt >> 6) & 0x3f));
|
||||
result.push_back(0x80 | (cpt & 0x3f));
|
||||
return result;
|
||||
}
|
||||
if (0x10000 <= cpt && cpt <= 0x10ffff) {
|
||||
result.push_back(0xf0 | ((cpt >> 18) & 0x07));
|
||||
result.push_back(0x80 | ((cpt >> 12) & 0x3f));
|
||||
result.push_back(0x80 | ((cpt >> 6) & 0x3f));
|
||||
result.push_back(0x80 | (cpt & 0x3f));
|
||||
return result;
|
||||
}
|
||||
|
||||
throw std::invalid_argument("invalid codepoint");
|
||||
}
|
||||
|
||||
std::vector<uint32_t> unicode_cpts_normalize_nfd(const std::vector<uint32_t> & cpts) {
|
||||
auto comp = [] (const uint32_t cpt, const range_nfd & range) {
|
||||
return cpt < range.first;
|
||||
};
|
||||
std::vector<uint32_t> result(cpts.size());
|
||||
for (size_t i = 0; i < cpts.size(); ++i) {
|
||||
const uint32_t cpt = cpts[i];
|
||||
auto it = std::upper_bound(unicode_ranges_nfd.begin(), unicode_ranges_nfd.end(), cpt, comp) - 1;
|
||||
result[i] = (it->first <= cpt && cpt <= it->last) ? it->nfd : cpt;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
std::vector<uint32_t> unicode_cpts_from_utf8(const std::string & utf8) {
|
||||
std::vector<uint32_t> result;
|
||||
result.reserve(utf8.size());
|
||||
size_t offset = 0;
|
||||
while (offset < utf8.size()) {
|
||||
result.push_back(unicode_cpt_from_utf8(utf8, offset));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
unicode_cpt_flags unicode_cpt_flags_from_cpt(const uint32_t cpt) {
|
||||
static const unicode_cpt_flags undef(unicode_cpt_flags::UNDEFINED);
|
||||
static const auto cpt_flags = unicode_cpt_flags_array();
|
||||
return cpt < cpt_flags.size() ? cpt_flags[cpt] : undef;
|
||||
}
|
||||
|
||||
unicode_cpt_flags unicode_cpt_flags_from_utf8(const std::string & utf8) {
|
||||
static const unicode_cpt_flags undef(unicode_cpt_flags::UNDEFINED);
|
||||
if (utf8.empty()) {
|
||||
return undef; // undefined
|
||||
}
|
||||
size_t offset = 0;
|
||||
return unicode_cpt_flags_from_cpt(unicode_cpt_from_utf8(utf8, offset));
|
||||
}
|
||||
|
||||
std::string unicode_byte_to_utf8(uint8_t byte) {
|
||||
static std::unordered_map<uint8_t, std::string> map = unicode_byte_to_utf8_map();
|
||||
return map.at(byte);
|
||||
}
|
||||
|
||||
uint8_t unicode_utf8_to_byte(const std::string & utf8) {
|
||||
static std::unordered_map<std::string, uint8_t> map = unicode_utf8_to_byte_map();
|
||||
return map.at(utf8);
|
||||
}
|
||||
|
||||
uint32_t unicode_tolower(uint32_t cpt) {
|
||||
// binary search
|
||||
auto it = std::lower_bound(unicode_map_lowercase.begin(), unicode_map_lowercase.end(), cpt,
|
||||
[](const std::pair<uint32_t, uint32_t> & pair, uint32_t value) {
|
||||
return pair.first < value;
|
||||
});
|
||||
if (it != unicode_map_lowercase.end() && it->first == cpt) {
|
||||
return it->second;
|
||||
}
|
||||
return cpt; // Return the original code point if no lowercase mapping is found
|
||||
}
|
||||
|
||||
std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::string> & regex_exprs) {
|
||||
// unicode categories
|
||||
static const std::map<std::string, int> k_ucat_enum = {
|
||||
{ "\\p{N}", unicode_cpt_flags::NUMBER },
|
||||
{ "\\p{L}", unicode_cpt_flags::LETTER },
|
||||
{ "\\p{P}", unicode_cpt_flags::PUNCTUATION },
|
||||
{ "\\p{M}", unicode_cpt_flags::ACCENT_MARK },
|
||||
{ "\\p{S}", unicode_cpt_flags::SYMBOL },
|
||||
};
|
||||
|
||||
static const std::map<int, int> k_ucat_cpt = {
|
||||
{ unicode_cpt_flags::NUMBER, 0xD1 },
|
||||
{ unicode_cpt_flags::LETTER, 0xD2 },
|
||||
{ unicode_cpt_flags::PUNCTUATION, 0xD3 },
|
||||
{ unicode_cpt_flags::ACCENT_MARK, 0xD4 },
|
||||
{ unicode_cpt_flags::SYMBOL, 0xD5 },
|
||||
};
|
||||
|
||||
static const std::map<int, std::string> k_ucat_map = {
|
||||
{ unicode_cpt_flags::NUMBER, "\x30-\x39" }, // 0-9
|
||||
{ unicode_cpt_flags::LETTER, "\x41-\x5A\x61-\x7A" }, // A-Za-z
|
||||
{ unicode_cpt_flags::PUNCTUATION, "\x21-\x23\x25-\x2A\x2C-\x2F\x3A-\x3B\x3F-\x40\\\x5B-\\\x5D\x5F\\\x7B\\\x7D" }, // !-#%-*,-/:-;?-@\[-\]_\{\}
|
||||
{ unicode_cpt_flags::ACCENT_MARK, "" }, // no sub-128 codepoints
|
||||
{ unicode_cpt_flags::SYMBOL, "\\\x24\\\x2B\x3C-\x3E\x5E\x60\\\x7C" }, // $+<=>^`|
|
||||
};
|
||||
|
||||
// compute collapsed codepoints only if needed by at least one regex
|
||||
bool need_collapse = false;
|
||||
for (const auto & regex_expr : regex_exprs) {
|
||||
// search for unicode categories
|
||||
for (const auto & ucat : k_ucat_enum) {
|
||||
if (std::string::npos != regex_expr.find(ucat.first)) {
|
||||
need_collapse = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const auto cpts = unicode_cpts_from_utf8(text);
|
||||
|
||||
// generate a "collapsed" representation of the text, where all codepoints are replaced by a single byte
|
||||
// ref: https://github.com/ggerganov/llama.cpp/pull/6920#issuecomment-2081479935
|
||||
std::string text_collapsed;
|
||||
if (need_collapse) {
|
||||
// collapse all unicode categories
|
||||
text_collapsed.resize(cpts.size());
|
||||
|
||||
for (size_t i = 0; i < cpts.size(); ++i) {
|
||||
// keep single-byte codepoints as is
|
||||
if (cpts[i] < 128) {
|
||||
text_collapsed[i] = cpts[i];
|
||||
continue;
|
||||
}
|
||||
|
||||
const auto flags = unicode_cpt_flags_from_cpt(cpts[i]);
|
||||
|
||||
if (flags.is_whitespace) {
|
||||
//NOTE: C++ std::regex \s does not mach 0x85, Rust and Python regex does.
|
||||
//text_collapsed[i] = (char) 0x85; // <Next Line> as whitespace fallback
|
||||
text_collapsed[i] = (char) 0x0B; // <vertical tab> as whitespace fallback
|
||||
} else if (k_ucat_cpt.find(flags.category_flag()) != k_ucat_cpt.end()) {
|
||||
text_collapsed[i] = k_ucat_cpt.at(flags.category_flag());
|
||||
} else {
|
||||
text_collapsed[i] = (char) 0xD0; // fallback
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<size_t> bpe_offsets = { cpts.size() };
|
||||
|
||||
for (const auto & regex_expr : regex_exprs) {
|
||||
// first, see if we have an efficient custom regex implementation
|
||||
auto tmp = unicode_regex_split_custom(text, regex_expr, bpe_offsets);
|
||||
|
||||
if (!tmp.empty()) {
|
||||
bpe_offsets = std::move(tmp);
|
||||
continue;
|
||||
}
|
||||
|
||||
// fallback to general-purpose std::regex / std::wregex
|
||||
try {
|
||||
// if a unicode category is used in the regex, we use the collapsed text and replace the unicode category
|
||||
// with the corresponding collapsed representation
|
||||
bool use_collapsed = false;
|
||||
for (const auto & ucat : k_ucat_enum) {
|
||||
if (std::string::npos != regex_expr.find(ucat.first)) {
|
||||
use_collapsed = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (use_collapsed) {
|
||||
// sanity-check that the original regex does not contain any non-ASCII characters
|
||||
const auto cpts_regex = unicode_cpts_from_utf8(regex_expr);
|
||||
for (size_t i = 0; i < cpts_regex.size(); ++i) {
|
||||
if (cpts_regex[i] >= 128) {
|
||||
throw std::runtime_error("Regex includes both unicode categories and non-ASCII characters - not supported");
|
||||
}
|
||||
}
|
||||
|
||||
// generate a collapsed representation of the regex
|
||||
std::string regex_expr_collapsed;
|
||||
|
||||
// track if we are inside [], because nested [] are not allowed
|
||||
bool inside = false;
|
||||
for (size_t i = 0; i < regex_expr.size(); ++i) {
|
||||
if (regex_expr[i] == '[' && (i == 0 || regex_expr[i - 1] != '\\')) {
|
||||
regex_expr_collapsed += '[';
|
||||
inside = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (inside && regex_expr[i] == ']' && regex_expr[i - 1] != '\\') {
|
||||
regex_expr_collapsed += ']';
|
||||
inside = false;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (regex_expr[i + 0] == '\\' && i + 4 < regex_expr.size() &&
|
||||
regex_expr[i + 1] == 'p' &&
|
||||
regex_expr[i + 2] == '{' &&
|
||||
regex_expr[i + 4] == '}') {
|
||||
const std::string pat = regex_expr.substr(i, 5);
|
||||
if (k_ucat_enum.find(pat) != k_ucat_enum.end()) {
|
||||
if (!inside) {
|
||||
regex_expr_collapsed += '[';
|
||||
}
|
||||
regex_expr_collapsed += k_ucat_cpt.at(k_ucat_enum.at(pat));
|
||||
regex_expr_collapsed += k_ucat_map.at(k_ucat_enum.at(pat));
|
||||
if (!inside) {
|
||||
regex_expr_collapsed += ']';
|
||||
}
|
||||
i += 4;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
regex_expr_collapsed += regex_expr[i];
|
||||
}
|
||||
|
||||
//printf("text_collapsed: %s\n", text_collapsed.c_str());
|
||||
//printf("regex_expr_collapsed: %s\n", regex_expr_collapsed.c_str());
|
||||
bpe_offsets = unicode_regex_split_stl(text_collapsed, regex_expr_collapsed, bpe_offsets);
|
||||
} else {
|
||||
// no unicode category used, we can use std::wregex directly
|
||||
const std::wstring wregex_expr = unicode_wstring_from_utf8(regex_expr);
|
||||
|
||||
// std::wregex \s does not mach non-ASCII whitespaces, using 0x0B as fallback
|
||||
std::wstring wtext(cpts.begin(), cpts.end());
|
||||
for (size_t i = 0; i < wtext.size(); ++i) {
|
||||
if (wtext[i] > 0x7F && unicode_cpt_flags_from_cpt(wtext[i]).is_whitespace) {
|
||||
wtext[i] = 0x0B;
|
||||
}
|
||||
}
|
||||
|
||||
//printf("text: %s\n", text.c_str());
|
||||
//printf("regex_expr: %s\n", regex_expr.c_str());
|
||||
bpe_offsets = unicode_regex_split_stl(wtext, wregex_expr, bpe_offsets);
|
||||
}
|
||||
} catch (std::regex_error & e) {
|
||||
fprintf(stderr, "Failed to process regex: '%s'\n", regex_expr.c_str());
|
||||
fprintf(stderr, "Regex error: %s\n", e.what());
|
||||
throw std::runtime_error("Failed to process regex");
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<std::string> bpe_words;
|
||||
bpe_words.reserve(bpe_offsets.size()); // reserve memory for the approximate size
|
||||
|
||||
size_t start = 0;
|
||||
for (size_t & offset : bpe_offsets) {
|
||||
bpe_words.emplace_back();
|
||||
for (size_t i = start; i < start + offset; ++i) {
|
||||
bpe_words.back() += unicode_cpt_to_utf8(cpts[i]);
|
||||
}
|
||||
start += offset;
|
||||
}
|
||||
|
||||
return unicode_byte_encoding_process(bpe_words);
|
||||
}
|
94
llms/export/unicode.h
Normal file
94
llms/export/unicode.h
Normal file
@ -0,0 +1,94 @@
|
||||
/**
|
||||
* The following unicode files:
|
||||
* - unicode.h
|
||||
* - unicode_data.cpp
|
||||
* - unicode.cpp
|
||||
* are copied from llama.cpp with minor modifications:
|
||||
* https://github.com/ggerganov/llama.cpp/
|
||||
* Commit hash: 8d59d911711b8f1ba9ec57c4b192ccd2628af033
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <cstdint>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
|
||||
/* unicode-data.h */
|
||||
struct range_nfd {
|
||||
uint32_t first;
|
||||
uint32_t last;
|
||||
uint32_t nfd;
|
||||
};
|
||||
|
||||
static const uint32_t MAX_CODEPOINTS = 0x110000;
|
||||
|
||||
extern const std::initializer_list<std::pair<uint32_t, uint16_t>> unicode_ranges_flags;
|
||||
extern const std::unordered_set<uint32_t> unicode_set_whitespace;
|
||||
extern const std::initializer_list<std::pair<uint32_t, uint32_t>> unicode_map_lowercase;
|
||||
extern const std::initializer_list<std::pair<uint32_t, uint32_t>> unicode_map_uppercase;
|
||||
extern const std::initializer_list<range_nfd> unicode_ranges_nfd;
|
||||
|
||||
/* original unicode.h */
|
||||
|
||||
struct unicode_cpt_flags {
|
||||
enum {
|
||||
UNDEFINED = 0x0001,
|
||||
NUMBER = 0x0002, // regex: \p{N}
|
||||
LETTER = 0x0004, // regex: \p{L}
|
||||
SEPARATOR = 0x0008, // regex: \p{Z}
|
||||
ACCENT_MARK = 0x0010, // regex: \p{M}
|
||||
PUNCTUATION = 0x0020, // regex: \p{P}
|
||||
SYMBOL = 0x0040, // regex: \p{S}
|
||||
CONTROL = 0x0080, // regex: \p{C}
|
||||
MASK_CATEGORIES = 0x00FF,
|
||||
};
|
||||
|
||||
// codepoint type
|
||||
uint16_t is_undefined : 1;
|
||||
uint16_t is_number : 1; // regex: \p{N}
|
||||
uint16_t is_letter : 1; // regex: \p{L}
|
||||
uint16_t is_separator : 1; // regex: \p{Z}
|
||||
uint16_t is_accent_mark : 1; // regex: \p{M}
|
||||
uint16_t is_punctuation : 1; // regex: \p{P}
|
||||
uint16_t is_symbol : 1; // regex: \p{S}
|
||||
uint16_t is_control : 1; // regex: \p{C}
|
||||
// helper flags
|
||||
uint16_t is_whitespace : 1; // regex: \s
|
||||
uint16_t is_lowercase : 1;
|
||||
uint16_t is_uppercase : 1;
|
||||
uint16_t is_nfd : 1;
|
||||
|
||||
// decode from uint16
|
||||
inline unicode_cpt_flags(const uint16_t flags = 0) {
|
||||
*reinterpret_cast<uint16_t*>(this) = flags;
|
||||
}
|
||||
|
||||
inline uint16_t as_uint() const {
|
||||
return *reinterpret_cast<const uint16_t*>(this);
|
||||
}
|
||||
|
||||
inline uint16_t category_flag() const {
|
||||
return this->as_uint() & MASK_CATEGORIES;
|
||||
}
|
||||
};
|
||||
|
||||
size_t unicode_len_utf8(char src);
|
||||
|
||||
std::string unicode_cpt_to_utf8 (uint32_t cpt);
|
||||
uint32_t unicode_cpt_from_utf8(const std::string & utf8, size_t & offset);
|
||||
|
||||
std::vector<uint32_t> unicode_cpts_from_utf8(const std::string & utf8);
|
||||
|
||||
std::vector<uint32_t> unicode_cpts_normalize_nfd(const std::vector<uint32_t> & cpts);
|
||||
|
||||
unicode_cpt_flags unicode_cpt_flags_from_cpt (uint32_t cpt);
|
||||
unicode_cpt_flags unicode_cpt_flags_from_utf8(const std::string & utf8);
|
||||
|
||||
std::string unicode_byte_to_utf8(uint8_t byte);
|
||||
uint8_t unicode_utf8_to_byte(const std::string & utf8);
|
||||
|
||||
uint32_t unicode_tolower(uint32_t cpt);
|
||||
|
||||
std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::string> & regex_exprs);
|
7034
llms/export/unicode_data.cpp
Normal file
7034
llms/export/unicode_data.cpp
Normal file
File diff suppressed because it is too large
Load Diff
@ -74,9 +74,9 @@ class Attention(nn.Module):
|
||||
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
|
||||
|
||||
# Prepare the queries, keys and values for the attention computation
|
||||
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
|
||||
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
queries = mx.unflatten(queries, -1, (self.n_heads, -1)).transpose(0, 2, 1, 3)
|
||||
keys = mx.unflatten(keys, -1, (self.n_kv_heads, -1)).transpose(0, 2, 1, 3)
|
||||
values = mx.unflatten(values, -1, (self.n_kv_heads, -1)).transpose(0, 2, 1, 3)
|
||||
|
||||
if cache is not None:
|
||||
queries = self.rope(queries, offset=cache.offset)
|
||||
@ -90,7 +90,7 @@ class Attention(nn.Module):
|
||||
queries, keys, values, cache=cache, scale=self.scale, mask=mask
|
||||
)
|
||||
|
||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
output = output.transpose(0, 2, 1, 3).flatten(-2, -1)
|
||||
return self.o_proj(output)
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user