export and run llama in C++

This commit is contained in:
Awni Hannun 2025-01-08 16:17:45 -08:00
parent b8f0cacfa8
commit 761b2c9886
14 changed files with 8628 additions and 4 deletions

1
llms/export/.gitignore vendored Normal file
View File

@ -0,0 +1 @@
build/

View File

@ -0,0 +1,42 @@
cmake_minimum_required(VERSION 3.27)
project(mlxlm LANGUAGES CXX)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
find_package(
Python 3.9
COMPONENTS Interpreter Development.Module
REQUIRED)
execute_process(
COMMAND "${Python_EXECUTABLE}" -m mlx --cmake-dir
OUTPUT_STRIP_TRAILING_WHITESPACE
OUTPUT_VARIABLE MLX_ROOT)
find_package(MLX CONFIG REQUIRED)
add_library(mlxlm)
target_link_libraries(mlxlm PUBLIC mlx)
include(FetchContent)
FetchContent_Declare(
json
URL https://github.com/nlohmann/json/releases/download/v3.11.3/json.tar.xz)
FetchContent_MakeAvailable(json)
target_include_directories(
mlxlm PRIVATE $<BUILD_INTERFACE:${json_SOURCE_DIR}/single_include/nlohmann>)
target_sources(mlxlm
PRIVATE
mlxlm.cpp
tokenizer.cpp
unicode.cpp
unicode_data.cpp)
add_executable(main main.cpp)
target_link_libraries(main PRIVATE mlxlm)
add_executable(test test.cpp)
target_link_libraries(test PRIVATE mlxlm)

34
llms/export/README.md Normal file
View File

@ -0,0 +1,34 @@
# Export LLMs to C++
Export language model inference from Python to run directly in C++.
To run, first install the requirements:
```bash
pip install -U mlx-lm
```
Then generate text from Python with:
```bash
python export.py generate "How tall is K2?"
```
To export the generation function run:
```bash
python export.py export
```
Then build the C++ code (requires CMake):
```bash
cmake -B build -DCMAKE_BUILD_TYPE=Release
cmake --build build
```
And run the generation from C++ with:
```bash
./build/main lama3.1-instruct-4bit "How tall is K2?"
```

171
llms/export/export.py Normal file
View File

@ -0,0 +1,171 @@
import time
from pathlib import Path
import fire
import mlx.core as mx
from mlx_lm import load
class ExportableCache:
def __init__(self, keys=None, values=None, offset=0):
self.offset = offset
self.keys = keys
self.values = values
def update_and_fetch(self, keys, values):
if self.keys is not None:
self.keys = mx.slice_update(self.keys, keys, self.offset, axes=(2,))
self.values = mx.slice_update(self.values, values, self.offset, axes=(2,))
else:
self.keys = keys
self.values = values
return self.keys, self.values
@property
def state(self):
return self.keys, self.values
def expand(cache, mask=None, cache_step_size=256):
cache_size = cache[0].shape[-2]
new_size = cache_step_size * ((cache_size + cache_step_size) // cache_step_size)
def expand_kv(x):
B, n_heads, _, head_dim = x.shape
new_x = mx.zeros((B, n_heads, new_size, head_dim), x.dtype)
new_x[..., : x.shape[2], :] = x
return new_x
cache = [expand_kv(c) for c in cache]
if mask is None:
mask = mx.full(new_size, False)
mask[:cache_size] = True
else:
mask = mx.concatenate([mask, mx.full(cache_step_size, False)])
return cache, mask
def causal_mask(N):
idx = mx.arange(N)
return idx[:, None] >= idx
def step(model, y, *state):
mask = state[-1]
if len(state) > 1:
cache, offset = state[:-2], state[-2]
cache = [
ExportableCache(keys, values, offset)
for keys, values in zip(cache[::2], cache[1::2])
]
else:
cache = [ExportableCache() for i in range(len(model.model.layers))]
logits = model(y, cache=cache, mask=mask)
cache = [y for x in cache for y in x.state]
return logits, *cache
def generate_step(prompt, model, max_tokens):
mx.eval(model)
compiled_step = mx.compile(lambda *args: step(model, *args), shapeless=True)
def _step(*args):
logits, *cache = compiled_step(*args)
return mx.argmax(logits[:, -1], axis=-1), *cache
y, *cache = _step(prompt, causal_mask(prompt.size))
mx.async_eval(y)
offset = mx.array(prompt.size, mx.uint32)
cache, mask = expand(cache)
n = 0
while True:
if n < max_tokens - 1:
if mask.size <= (prompt.size + n):
cache, mask = expand(cache, mask)
mask[prompt.size + n] = True
next_y, *cache = _step(y[None], *cache, offset, mask)
mx.async_eval(next_y)
offset += 1
n += 1
yield y.item()
if n == max_tokens:
break
y = next_y
def export(
model="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit",
path="llama3.1-instruct-4bit",
):
model, tokenizer = load(model)
mx.eval(model)
tokenizer.save_pretrained(path)
_step = lambda *args: step(model, *args)
# Make example inputs
y_prompt = mx.array([[0, 0]], mx.uint32)
y_gen = mx.array([[0]], mx.uint32)
offset = mx.array([0], mx.uint32)
mask = causal_mask(y_prompt.size)
_, *cache = _step(y_prompt, mask)
model_path = str(Path(path) / "model.mlxfn")
with mx.exporter(model_path, _step, shapeless=True) as exporter:
exporter(y_prompt, mask)
cache, mask = expand(cache)
exporter(y_gen, *cache, offset, mask)
def generate(
prompt,
model="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit",
max_tokens=128,
):
print("[INFO] Loading model from disk.")
model, tokenizer = load(model)
prompt = tokenizer.apply_chat_template(
[{"role": "user", "content": prompt}],
add_generation_prompt=True,
return_tensors="mlx",
)
print("[INFO] Starting generation...")
tic = time.time()
tokens = []
detokenizer = tokenizer.detokenizer
detokenizer.reset()
for n, token in enumerate(generate_step(prompt, model, max_tokens)):
if n == 0:
prompt_tps = prompt.size / (time.time() - tic)
tic = time.time()
if token in tokenizer.eos_token_ids:
break
detokenizer.add_token(token)
print(detokenizer.last_segment, end="", flush=True)
detokenizer.finalize()
print(detokenizer.last_segment, flush=True)
gen_tps = (n + 1) / (time.time() - tic)
peak_memory = mx.metal.get_peak_memory() / 1e9
print("=" * 10)
print(f"Prompt: {prompt_tps:.3f} tokens-per-sec")
print(f"Generation: {gen_tps:.3f} tokens-per-sec")
print(f"Peak RAM: {peak_memory:.3f} GB")
if __name__ == "__main__":
fire.Fire(
{
"generate": generate,
"export": export,
}
)

18
llms/export/main.cpp Normal file
View File

@ -0,0 +1,18 @@
// Copyright © 2024 Apple Inc.
#include <iostream>
#include "mlxlm.h"
int main(int argc, char *argv[]) {
if (argc < 3) {
std::cerr << "Must provide the model path and prompt." << std::endl;
return 1;
}
auto path = std::string(argv[1]);
auto prompt = std::string(argv[2]);
auto model = load_model(path + "/model.mlxfn");
auto tokenizer = load_tokenizer(path);
generate(model, tokenizer, prompt);
}

119
llms/export/mlxlm.cpp Normal file
View File

@ -0,0 +1,119 @@
// Copyright © 2024 Apple Inc.
#include <chrono>
#include <iomanip>
#include <iostream>
#include "mlxlm.h"
namespace mx = mlx::core;
#define seconds(x) \
(std::chrono::duration_cast<std::chrono::nanoseconds>(x).count() / 1e9)
#define time_now() std::chrono::high_resolution_clock::now()
// Maybe compile
std::function<mx::Args(mx::Args)> load_model(const std::string& path) {
return mx::compile(mx::import_function(path), /* shapeless = */ true);
}
// Maybe make tokenizer virtual
BPETokenizer load_tokenizer(const std::string& path) {
return BPETokenizer(path);
}
void generate(
const std::function<mx::Args(mx::Args)>& model,
const BPETokenizer& tokenizer,
const std::string& prompt,
int max_tokens /* = 256 */) {
auto prompt_tokens = tokenizer.encode(prompt);
int prompt_size = prompt_tokens.size();
auto y = mx::array(prompt_tokens.data(), {1, prompt_size}, mx::uint32);
auto create_causal_mask = [](int N) {
auto indices = mx::arange(N);
return mx::expand_dims(indices, 1) >= indices;
};
// Helper to expand the cache and mask
auto expand = [](auto& args, auto& mask) {
constexpr int cache_step_size = 256;
int cache_size = args[1].shape(-2);
int new_size = cache_step_size * ((cache_size + cache_step_size) / cache_step_size);
for (auto it = args.begin() + 1; it != args.end(); ++it) {
auto& x = *it;
auto shape = x.shape();
shape[2] = new_size;
auto new_x = mx::zeros(shape, x.dtype());
shape[2] = cache_size;
*it = mx::slice_update(new_x, x, mx::Shape(x.ndim(), 0), std::move(shape));
}
mask = mx::slice_update(mx::full({new_size}, false), mask, {0}, {cache_size});
};
auto tic = time_now();
float prompt_time;
int n = 0;
mx::Args args;
{
args = model({y, create_causal_mask(y.size())});
auto logits = args[0];
logits = slice(logits, {0, -1, 0}, logits.shape());
y = argmax(logits, -1);
async_eval(y);
}
auto offset = mx::array(prompt_size, mx::uint32);
std::vector<int> tokens;
auto mask = mx::full({prompt_size}, true);
expand(args, mask);
for (; n < max_tokens; ++n) {
// Start next token decoding if needed
if (n < max_tokens - 1) {
args[0] = y;
auto m = prompt_size + n;
if (mask.size() <= m) {
expand(args, mask);
}
mask = mx::slice_update(mask, mx::array(true), {m}, {m + 1});
args.push_back(offset);
args.push_back(mask);
args = model(args);
args[0] = argmax(args[0], -1);
offset = offset + 1u;
async_eval(args[0]);
}
auto token = y.item<int>();
if (token == tokenizer.eos_token_id()) {
break;
}
tokens.push_back(token);
auto [result, complete] = tokenizer.try_decode(tokens);
if (complete) {
std::cout << result << std::flush;
tokens.clear();
}
if (n == 0) {
prompt_time = seconds(time_now() - tic);
tic = time_now();
}
if (n < max_tokens - 1) {
y = args[0];
}
}
auto result = tokenizer.decode(tokens);
std::cout << result << std::flush;
auto gen_time = seconds(time_now() - tic);
std::cout << std::endl;
std::cout << std::setprecision(5) << "Prompt toks/sec "
<< prompt_size / prompt_time << "\nGeneration toks/sec "
<< (n + 1) / gen_time << std::endl;
}

20
llms/export/mlxlm.h Normal file
View File

@ -0,0 +1,20 @@
// Copyright © 2024 Apple Inc.
#include <mlx/mlx.h>
#include "tokenizer.h"
namespace mx = mlx::core;
std::function<mx::Args(mx::Args)> load_model(const std::string& path);
BPETokenizer load_tokenizer(const std::string& path);
struct GenerationResponse {
};
void generate(
const std::function<mx::Args(mx::Args)>& model,
const BPETokenizer& tokenizer,
const std::string& prompt,
int max_tokens = 256);

23
llms/export/test.cpp Normal file
View File

@ -0,0 +1,23 @@
// Copyright © 2024 Apple Inc.
#include "tokenizer.h"
#include <iostream>
template <typename T, typename U = T> void check(const T &x, const U &y) {
if (x != y) {
std::cerr << "Mismatch" << std::endl;
}
}
void test_tokenizer(const std::string &path) {
BPETokenizer tokenizer(path);
check(tokenizer.encode("hello world!"), {128000, 15339, 1917, 0});
check(tokenizer.decode({15339}), "hello");
check(tokenizer.decode({0}), "!");
check(tokenizer.decode({1917}), " world");
check(tokenizer.encode("we'd see you say 世界你好真实好的很啊"),
{128000, 906, 4265, 220, 1518, 256, 499, 2019, 127365, 57668, 53901,
89151, 41073, 110085, 101600, 102856});
}
int main(int argc, char *argv[]) { test_tokenizer("."); }

189
llms/export/tokenizer.cpp Normal file
View File

@ -0,0 +1,189 @@
#include <fstream>
#include <filesystem>
#include <locale>
#include <codecvt>
#include <json.hpp>
#include "tokenizer.h"
#include "unicode.h"
using json = nlohmann::json;
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
std::pair<std::wstring, int> utf8_to_utf16(const std::string& s) {
static std::string replace_str = std::string(1, 0xFF);
static std::wstring replace_wstr = std::wstring(1, 0xFFFD);
std::wstring_convert<std::codecvt_utf8_utf16<wchar_t>> cvt(replace_str, replace_wstr);
auto out = cvt.from_bytes(s);
return {out, cvt.converted()};
}
#pragma GCC diagnostic pop
auto make_byte_decoder() {
std::unordered_map<uint16_t, char> byte_decoder;
std::vector<uint16_t> limits = {0, '!', '~' + 1, L'¡', L'¬' + 1, L'®', L'ÿ' + 1};
char n = 0;
for (int i = 0; i < limits.size() - 1; ++i) {
auto start = limits[i];
auto stop = limits[i + 1];
if (i % 2 == 0) {
for (int b = start; b < stop; ++b) {
byte_decoder[256 + n++] = b;
}
} else {
for (int b = start; b < stop; ++b) {
byte_decoder[b] = b;
}
}
}
return byte_decoder;
}
auto BPETokenizer::byte_decoder_ = make_byte_decoder();
BPETokenizer::BPETokenizer(const std::string& path_) {
auto path = std::filesystem::path(path_);
std::ifstream ifs(path / "tokenizer.json");
auto tokenizer = json::parse(ifs);
auto model = tokenizer["model"];
token_to_id_ = model["vocab"];
id_to_token_.resize(token_to_id_.size());
for (auto& [s, id] : token_to_id_) {
if (id >= id_to_token_.size()) {
id_to_token_.resize(id + 1);
}
id_to_token_[id] = s;
}
std::string type = model["type"];
auto merges = model["merges"];
for (auto& s : merges) {
if (s.is_string()) {
merges_.emplace(s, merges_.size());
} else {
std::string s1 = s[0];
std::string s2 = s[1];
merges_.emplace(s1 + " " + s2, merges_.size());
}
}
auto added_tokens = tokenizer["added_tokens"];
for (auto& added_token : added_tokens) {
int id = added_token["id"];
if (id >= id_to_token_.size()) {
id_to_token_.resize(id + 1);
}
id_to_token_[id] = added_token["content"];
if (id_to_token_[id] == "<|begin_of_text|>") {
bos_id_ = id;
} else if (id_to_token_[id] == "<|eot_id|>") {
eos_id_ = id;
}
}
// Currently hardcoded to Llama3 BPE regex
pre_tokenizer_regex_ = {"(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"};
}
std::vector<int> BPETokenizer::encode(std::string text) const {
auto segments = unicode_regex_split(text, pre_tokenizer_regex_);
auto one_step_merge = [this](std::string segment, std::vector<int>& splits) {
int merge_idx;
int rank = INT32_MAX;
for (int i = 0; i < splits.size() - 2; ++i) {
auto start = splits[i];
auto mid = splits[i + 1];
auto end = splits[i + 2];
std::string candidate = segment.substr(start, mid - start);
candidate += " ";
candidate += segment.substr(mid, end - mid);
if (auto it = merges_.find(candidate); it != merges_.end()) {
if (it->second < rank) {
merge_idx = i;
rank = it->second;
}
}
}
if (rank == INT32_MAX) {
return false;
}
auto start = splits[merge_idx];
auto mid = splits[merge_idx + 1];
auto end = splits[merge_idx + 2];
std::string merge_l = segment.substr(start, mid - start);
std::string merge_r = segment.substr(mid, end - mid);
for (int i = splits.size() - 2; i >= 0; --i) {
auto start = splits[i];
auto mid = splits[i + 1];
auto end = splits[i + 2];
if (segment.substr(start, mid - start) == merge_l && segment.substr(mid, end - mid) == merge_r) {
splits.erase(splits.begin() + i + 1);
i -= 1;
}
}
return true;
};
std::vector<int> ids;
ids.push_back(bos_id_);
// Initialize merges to integer list
auto merge_segment = [&ids, &one_step_merge, this](const std::string& segment) {
std::vector<int> splits;
for (int i = 0; i < segment.size(); ++i) {
splits.push_back(i);
if (static_cast<unsigned char>(segment[i]) > 128) {
i++;
}
}
splits.push_back(segment.size());
while (one_step_merge(segment, splits)) { };
for (int i = 0; i < splits.size() - 1; ++i) {
auto start = splits[i];
auto end = splits[i + 1];
std::string s = segment.substr(start, end - start);
if (auto it = token_to_id_.find(s); it != token_to_id_.end()) {
ids.push_back(it->second);
} else {
throw std::runtime_error("UNK ENCOUNTERED");
}
}
};
for (auto& segment : segments) {
merge_segment(segment);
}
return ids;
}
std::string BPETokenizer::id_to_bytes(int id) const {
std::string token;
auto [wide_token, _] = utf8_to_utf16(id_to_token_[id]);
token.resize(wide_token.size());
for (int i = 0; i < wide_token.size(); ++i) {
token[i] = byte_decoder_[wide_token[i]];
}
return token;
}
std::pair<std::string, bool> BPETokenizer::try_decode(const std::vector<int>& ids) const {
std::string text;
for (auto id : ids) {
text += id_to_bytes(id);
}
auto [_, converted] = utf8_to_utf16(text);
bool complete = converted == text.size();
text.resize(converted);
return {text, complete};
}
std::string BPETokenizer::decode(const std::vector<int>& ids) const {
return try_decode(ids).first;
}
int BPETokenizer::eos_token_id() const { return eos_id_; }

37
llms/export/tokenizer.h Normal file
View File

@ -0,0 +1,37 @@
// Copyright © 2024 Apple Inc.
#include <string>
#include <unordered_map>
#include <vector>
#pragma once
/** BPE Tokenizer API */
class BPETokenizer {
public:
BPETokenizer(const std::string& path);
/** Encode a string of text to token integer ids. */
std::vector<int> encode(std::string text) const;
/** Try to decode the vector of ids to text. The text is truncated to
* include only the fully decodable tokens. */
std::string decode(const std::vector<int>& ids) const;
/** Try to decode the vector of ids to text. The second return value
* indicates if the decoding completed. The text is truncated to include
* only the fully decodable tokens. */
std::pair<std::string, bool> try_decode(const std::vector<int>& ids) const;
int eos_token_id() const;
private:
std::unordered_map<std::string, int> token_to_id_;
std::vector<std::string> id_to_token_;
std::unordered_map<std::string, int> merges_;
int bos_id_;
int eos_id_;
static std::unordered_map<uint16_t, char> byte_decoder_;
std::string id_to_bytes(int id) const;
std::vector<std::string> pre_tokenizer_regex_;
};

842
llms/export/unicode.cpp Normal file
View File

@ -0,0 +1,842 @@
#if defined(_MSC_VER)
#define _SILENCE_CXX17_CODECVT_HEADER_DEPRECATION_WARNING
#endif
#include "unicode.h"
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <map>
#include <regex>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include <locale>
#include <codecvt>
size_t unicode_len_utf8(char src) {
const size_t lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
uint8_t highbits = static_cast<uint8_t>(src) >> 4;
return lookup[highbits];
}
static std::string unicode_cpts_to_utf8(const std::vector<uint32_t> & cps) {
std::string result;
for (size_t i = 0; i < cps.size(); ++i) {
result.append(unicode_cpt_to_utf8(cps[i]));
}
return result;
}
uint32_t unicode_cpt_from_utf8(const std::string & utf8, size_t & offset) {
assert(offset < utf8.size());
if (!(utf8[offset + 0] & 0x80)) {
auto result = utf8[offset + 0];
offset += 1;
return result;
}
if (!(utf8[offset + 0] & 0x40)) {
throw std::invalid_argument("invalid character");
}
if (!(utf8[offset + 0] & 0x20)) {
if (offset + 1 >= utf8.size() || ! ((utf8[offset + 1] & 0xc0) == 0x80)) {
throw std::invalid_argument("invalid character");
}
auto result = ((utf8[offset + 0] & 0x1f) << 6) | (utf8[offset + 1] & 0x3f);
offset += 2;
return result;
}
if (!(utf8[offset + 0] & 0x10)) {
if (offset + 2 >= utf8.size() || ! ((utf8[offset + 1] & 0xc0) == 0x80) || ! ((utf8[offset + 2] & 0xc0) == 0x80)) {
throw std::invalid_argument("invalid character");
}
auto result = ((utf8[offset + 0] & 0x0f) << 12) | ((utf8[offset + 1] & 0x3f) << 6) | (utf8[offset + 2] & 0x3f);
offset += 3;
return result;
}
if (!(utf8[offset + 0] & 0x08)) {
if (offset + 3 >= utf8.size() || ! ((utf8[offset + 1] & 0xc0) == 0x80) || ! ((utf8[offset + 2] & 0xc0) == 0x80) || !((utf8[offset + 3] & 0xc0) == 0x80)) {
throw std::invalid_argument("invalid character");
}
auto result = ((utf8[offset + 0] & 0x07) << 18) | ((utf8[offset + 1] & 0x3f) << 12) | ((utf8[offset + 2] & 0x3f) << 6) | (utf8[offset + 3] & 0x3f);
offset += 4;
return result;
}
throw std::invalid_argument("failed to convert utf8 to codepoint");
}
//static std::vector<uint16_t> unicode_cpt_to_utf16(uint32_t cpt) {
// std::vector<uint16_t> result;
// if (/* 0x0000 <= cpt && */ cpt <= 0xffff) {
// result.emplace_back(cpt);
// return result;
// }
// if (0x10000 <= cpt && cpt <= 0x10ffff) {
// result.emplace_back(0xd800 | ((cpt - 0x10000) >> 10));
// result.emplace_back(0xdc00 | ((cpt - 0x10000) & 0x03ff));
// return result;
// }
// throw std::invalid_argument("failed to convert codepoint to utf16");
//}
//static std::vector<uint16_t> unicode_cpts_to_utf16(const std::vector<uint32_t> & cps) {
// std::vector<uint16_t> result;
// for (size_t i = 0; i < cps.size(); ++i) {
// auto temp = unicode_cpt_to_utf16(cps[i]);
// result.insert(result.end(), temp.begin(), temp.end());
// }
// return result;
//}
//static uint32_t unicode_cpt_from_utf16(const std::vector<uint16_t> & utf16, size_t & offset) {
// assert(offset < utf16.size());
// if (((utf16[0] >> 10) << 10) != 0xd800) {
// auto result = utf16[offset + 0];
// offset += 1;
// return result;
// }
//
// if (offset + 1 >= utf16.size() || !((utf16[1] & 0xdc00) == 0xdc00)) {
// throw std::invalid_argument("invalid character");
// }
//
// auto result = 0x10000 + (((utf16[0] & 0x03ff) << 10) | (utf16[1] & 0x03ff));
// offset += 2;
// return result;
//}
//static std::vector<uint32_t> unicode_cpts_from_utf16(const std::vector<uint16_t> & utf16) {
// std::vector<uint32_t> result;
// size_t offset = 0;
// while (offset < utf16.size()) {
// result.push_back(unicode_cpt_from_utf16(utf16, offset));
// }
// return result;
//}
static std::vector<unicode_cpt_flags> unicode_cpt_flags_array() {
std::vector<unicode_cpt_flags> cpt_flags(MAX_CODEPOINTS, unicode_cpt_flags::UNDEFINED);
assert (unicode_ranges_flags.begin()[0].first == 0);
assert (unicode_ranges_flags.begin()[unicode_ranges_flags.size()-1].first == MAX_CODEPOINTS);
for (size_t i = 1; i < unicode_ranges_flags.size(); ++i) {
const auto range_ini = unicode_ranges_flags.begin()[i-1]; // codepoint_ini, flags
const auto range_end = unicode_ranges_flags.begin()[i]; // codepoint_end, flags
for (uint32_t cpt = range_ini.first; cpt < range_end.first; ++cpt) {
cpt_flags[cpt] = range_ini.second;
}
}
for (auto cpt : unicode_set_whitespace) {
cpt_flags[cpt].is_whitespace = true;
}
for (auto p : unicode_map_lowercase) {
cpt_flags[p.second].is_lowercase = true;
}
for (auto p : unicode_map_uppercase) {
cpt_flags[p.second].is_uppercase = true;
}
for (auto &range : unicode_ranges_nfd) { // start, last, nfd
cpt_flags[range.nfd].is_nfd = true;
}
return cpt_flags;
}
static std::unordered_map<uint8_t, std::string> unicode_byte_to_utf8_map() {
std::unordered_map<uint8_t, std::string> map;
for (int ch = 0x21; ch <= 0x7E; ++ch) { // u'!' to u'~'
assert(0 <= ch && ch < 256);
map[ch] = unicode_cpt_to_utf8(ch);
}
for (int ch = 0xA1; ch <= 0xAC; ++ch) { // u'¡' to u'¬'
assert(0 <= ch && ch < 256);
map[ch] = unicode_cpt_to_utf8(ch);
}
for (int ch = 0xAE; ch <= 0xFF; ++ch) { // u'®' to u'ÿ'
assert(0 <= ch && ch < 256);
map[ch] = unicode_cpt_to_utf8(ch);
}
auto n = 0;
for (int ch = 0; ch < 256; ++ch) {
if (map.find(ch) == map.end()) {
map[ch] = unicode_cpt_to_utf8(256 + n);
++n;
}
}
return map;
}
static std::unordered_map<std::string, uint8_t> unicode_utf8_to_byte_map() {
std::unordered_map<std::string, uint8_t> map;
for (int ch = 0x21; ch <= 0x7E; ++ch) { // u'!' to u'~'
assert(0 <= ch && ch < 256);
map[unicode_cpt_to_utf8(ch)] = ch;
}
for (int ch = 0xA1; ch <= 0xAC; ++ch) { // u'¡' to u'¬'
assert(0 <= ch && ch < 256);
map[unicode_cpt_to_utf8(ch)] = ch;
}
for (int ch = 0xAE; ch <= 0xFF; ++ch) { // u'®' to u'ÿ'
assert(0 <= ch && ch < 256);
map[unicode_cpt_to_utf8(ch)] = ch;
}
auto n = 0;
for (int ch = 0; ch < 256; ++ch) {
if (map.find(unicode_cpt_to_utf8(ch)) == map.end()) {
map[unicode_cpt_to_utf8(256 + n)] = ch;
++n;
}
}
return map;
}
static inline std::wstring unicode_wstring_from_utf8(const std::string & s) {
#if defined(__clang__)
// disable C++17 deprecation warning for std::codecvt_utf8
# pragma clang diagnostic push
# pragma clang diagnostic ignored "-Wdeprecated-declarations"
#endif
std::wstring_convert<std::codecvt_utf8<wchar_t>> conv;
#if defined(__clang__)
# pragma clang diagnostic pop
#endif
return conv.from_bytes(s);
}
static std::vector<std::string> unicode_byte_encoding_process(const std::vector<std::string> & bpe_words) {
std::vector<std::string> bpe_encoded_words;
for (const auto & word : bpe_words) {
std::string text_utf;
auto utf_word = unicode_cpts_from_utf8(word);
for (size_t i = 0; i < utf_word.size(); ++i) {
text_utf += unicode_cpt_to_utf8(utf_word[i]);
}
std::string encoded_token;
for (char & c : text_utf) {
encoded_token += unicode_byte_to_utf8(c);
}
bpe_encoded_words.emplace_back(encoded_token);
}
return bpe_encoded_words;
}
// GPT2 system regex: 's|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+
static std::vector<size_t> unicode_regex_split_custom_gpt2(const std::string & text, const std::vector<size_t> & offsets) {
std::vector<size_t> bpe_offsets; // store the offset of each word
bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size
const auto cpts = unicode_cpts_from_utf8(text);
size_t start = 0;
for (auto offset : offsets) {
const size_t offset_ini = start;
const size_t offset_end = start + offset;
assert(offset_end <= cpts.size());
start = offset_end;
static const uint32_t OUT_OF_RANGE = 0xFFFFFFFF;
auto _get_cpt = [&] (const size_t pos) -> uint32_t {
return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : OUT_OF_RANGE;
};
auto _get_flags = [&] (const size_t pos) -> unicode_cpt_flags {
return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags_from_cpt(cpts[pos]) : unicode_cpt_flags{};
};
size_t _prev_end = offset_ini;
auto _add_token = [&] (const size_t end) -> size_t {
assert(_prev_end <= end && end <= offset_end);
size_t len = end - _prev_end;
if (len > 0) {
bpe_offsets.push_back(len);
}
_prev_end = end;
//if (len > 0) {
// std::string s = "";
// for(size_t p = end-len; p < end; p++)
// s += unicode_cpt_to_utf8(cpts[p]);
// printf(">>> '%s'\n", s.c_str());
//}
return len;
};
for (size_t pos = offset_ini; pos < offset_end; /*pos++*/ ) {
const uint32_t cpt = _get_cpt(pos);
const auto flags = _get_flags(pos);
// regex: 's|'t|'re|'ve|'m|'ll|'d
if (cpt == '\'' && pos+1 < offset_end) {
uint32_t cpt_next = _get_cpt(pos+1);
if (cpt_next == 's' || cpt_next == 't' || cpt_next == 'm' || cpt_next == 'd') {
pos += _add_token(pos+2);
continue;
}
if (pos+2 < offset_end) {
uint32_t cpt_next_next = _get_cpt(pos+2);
if ((cpt_next == 'r' && cpt_next_next == 'e') ||
(cpt_next == 'v' && cpt_next_next == 'e') ||
(cpt_next == 'l' && cpt_next_next == 'l')) {
pos += _add_token(pos+3);
continue;
}
}
}
auto flags2 = (cpt == ' ' ? _get_flags(pos+1) : flags);
// regex: <space>?\p{L}+
if (flags2.is_letter) {
pos += (cpt == ' ');
while (flags2.is_letter) {
flags2 = _get_flags(++pos);
}
_add_token(pos);
continue;
}
// regex: <space>?\p{N}+
if (flags2.is_number) {
pos += (cpt == ' ');
while (flags2.is_number) {
flags2 = _get_flags(++pos);
}
_add_token(pos);
continue;
}
// regex: <space>?[^\s\p{L}\p{N}]+
if (!(flags2.is_whitespace | flags2.is_letter | flags2.is_number) && flags2.as_uint()) {
pos += (cpt == ' ');
while (!(flags2.is_whitespace | flags2.is_letter | flags2.is_number) && flags2.as_uint()) {
flags2 = _get_flags(++pos);
}
_add_token(pos);
continue;
}
size_t num_whitespaces = 0;
while (_get_flags(pos+num_whitespaces).is_whitespace) {
num_whitespaces++;
}
// regex: \s+(?!\S)
if (num_whitespaces > 1 && _get_cpt(pos+num_whitespaces) != OUT_OF_RANGE) {
pos += num_whitespaces - 1;
_add_token(pos);
continue;
}
// regex: \s+
if (num_whitespaces > 0) {
pos += num_whitespaces;
_add_token(pos);
continue;
}
// no matches
_add_token(++pos);
}
}
return bpe_offsets;
}
// LLAMA3 system regex: "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"
static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string & text, const std::vector<size_t> & offsets) {
std::vector<size_t> bpe_offsets; // store the offset of each word
bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size
const auto cpts = unicode_cpts_from_utf8(text);
size_t start = 0;
for (auto offset : offsets) {
const size_t offset_ini = start;
const size_t offset_end = start + offset;
assert(offset_end <= cpts.size());
start = offset_end;
static const uint32_t OUT_OF_RANGE = 0xFFFFFFFF;
auto _get_cpt = [&] (const size_t pos) -> uint32_t {
return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : OUT_OF_RANGE;
};
auto _get_flags = [&] (const size_t pos) -> unicode_cpt_flags {
return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags_from_cpt(cpts[pos]) : unicode_cpt_flags{};
};
size_t _prev_end = offset_ini;
auto _add_token = [&] (const size_t end) -> size_t {
assert(_prev_end <= end && end <= offset_end);
size_t len = end - _prev_end;
if (len > 0) {
bpe_offsets.push_back(len);
}
_prev_end = end;
//if (len > 0) {
// std::string s = "";
// for(size_t p = end-len; p < end; p++)
// s += unicode_cpt_to_utf8(cpts[p]);
// printf(">>> '%s'\n", s.c_str());
//}
return len;
};
for (size_t pos = offset_ini; pos < offset_end; /*pos++*/ ) {
const uint32_t cpt = _get_cpt(pos);
const auto flags = _get_flags(pos);
// regex: (?i:'s|'t|'re|'ve|'m|'ll|'d) // case insensitive
if (cpt == '\'' && pos+1 < offset_end) {
uint32_t cpt_next = unicode_tolower(_get_cpt(pos+1));
if (cpt_next == 's' || cpt_next == 't' || cpt_next == 'm' || cpt_next == 'd') {
pos += _add_token(pos+2);
continue;
}
if (pos+2 < offset_end) {
uint32_t cpt_next_next = unicode_tolower(_get_cpt(pos+2));
if ((cpt_next == 'r' && cpt_next_next == 'e') ||
(cpt_next == 'v' && cpt_next_next == 'e') ||
(cpt_next == 'l' && cpt_next_next == 'l')) {
pos += _add_token(pos+3);
continue;
}
}
}
// regex: [^\r\n\p{L}\p{N}]?\p{L}+
if (!(cpt == '\r' || cpt == '\n' || flags.is_number)) {
if (flags.is_letter || _get_flags(pos+1).is_letter) { // one or more letters
pos++;
while (_get_flags(pos).is_letter) {
pos++;
}
_add_token(pos);
continue;
}
}
// regex: \p{N}{1,3}
if (flags.is_number) {
size_t ini = pos;
while (_get_flags(pos).is_number) {
if (++pos - ini >= 3 ) {
_add_token(pos);
ini = pos;
}
}
_add_token(pos);
continue;
}
// regex: <space>?[^\s\p{L}\p{N}]+[\r\n]*
auto flags2 = (cpt == ' ' ? _get_flags(pos+1) : flags);
if (!(flags2.is_whitespace | flags2.is_letter | flags2.is_number) && flags.as_uint()) {
pos += (cpt == ' ');
while (!(flags2.is_whitespace | flags2.is_letter | flags2.is_number) && flags2.as_uint()) {
flags2 = _get_flags(++pos);
}
uint32_t cpt2 = _get_cpt(pos);
while (cpt2 == '\r' || cpt2 == '\n') {
cpt2 = _get_cpt(++pos);
}
_add_token(pos);
continue;
}
size_t num_whitespaces = 0;
size_t last_end_r_or_n = 0;
while (_get_flags(pos+num_whitespaces).is_whitespace) {
uint32_t cpt2 = _get_cpt(pos+num_whitespaces);
if (cpt2 == '\r' || cpt2 == '\n') {
last_end_r_or_n = pos + num_whitespaces + 1;
}
num_whitespaces++;
}
// regex: \s*[\r\n]+
if (last_end_r_or_n > 0) {
pos = last_end_r_or_n;
_add_token(pos);
continue;
}
// regex: \s+(?!\S)
if (num_whitespaces > 1 && _get_cpt(pos+num_whitespaces) != OUT_OF_RANGE) {
pos += num_whitespaces - 1;
_add_token(pos);
continue;
}
// regex: \s+
if (num_whitespaces > 0) {
pos += num_whitespaces;
_add_token(pos);
continue;
}
// no matches
_add_token(++pos);
}
}
return bpe_offsets;
}
// use std::wregex to split the text
static std::vector<size_t> unicode_regex_split_stl(const std::wstring & wtext, const std::wstring & regex_expr, const std::vector<size_t> & offsets) {
std::wregex expr(regex_expr);
std::vector<size_t> bpe_offsets; // store the offset of each word
bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size
size_t start = 0;
for (auto offset : offsets) {
std::wcregex_iterator it(wtext.data() + start, wtext.data() + start + offset, expr);
std::wcregex_iterator end;
int64_t start_idx = 0;
while (it != end) {
std::wcmatch match = *it;
if (match.position() > start_idx) {
bpe_offsets.emplace_back(match.position() - start_idx);
}
bpe_offsets.emplace_back(match.length());
start_idx = match.position() + match.length();
++it;
}
if (start_idx < (int64_t) offset) {
bpe_offsets.emplace_back(offset - start_idx);
}
start += offset;
}
return bpe_offsets;
}
// use std::regex to split the text
static std::vector<size_t> unicode_regex_split_stl(const std::string & text, const std::string & regex_expr, const std::vector<size_t> & offsets) {
std::regex expr(regex_expr);
std::vector<size_t> bpe_offsets; // store the offset of each word
bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size
size_t start = 0;
for (auto offset : offsets) {
std::cregex_iterator it(text.data() + start, text.data() + start + offset, expr);
std::cregex_iterator end;
int64_t start_idx = 0;
while (it != end) {
std::cmatch match = *it;
if (match.position() > start_idx) {
bpe_offsets.emplace_back(match.position() - start_idx);
}
bpe_offsets.emplace_back(match.length());
start_idx = match.position() + match.length();
++it;
}
if (start_idx < (int64_t) offset) {
bpe_offsets.emplace_back(offset - start_idx);
}
start += offset;
}
return bpe_offsets;
}
static std::vector<size_t> unicode_regex_split_custom(const std::string & text, const std::string & regex_expr, const std::vector<size_t> & offsets) {
std::vector<size_t> bpe_offsets;
if (regex_expr == "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)") {
bpe_offsets = unicode_regex_split_custom_gpt2(text, offsets);
} else if (
regex_expr == "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+" ||
regex_expr == "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+") {
bpe_offsets = unicode_regex_split_custom_llama3(text, offsets);
}
return bpe_offsets;
}
//
// interface
//
std::string unicode_cpt_to_utf8(uint32_t cpt) {
std::string result;
if (/* 0x00 <= cpt && */ cpt <= 0x7f) {
result.push_back(cpt);
return result;
}
if (0x80 <= cpt && cpt <= 0x7ff) {
result.push_back(0xc0 | ((cpt >> 6) & 0x1f));
result.push_back(0x80 | (cpt & 0x3f));
return result;
}
if (0x800 <= cpt && cpt <= 0xffff) {
result.push_back(0xe0 | ((cpt >> 12) & 0x0f));
result.push_back(0x80 | ((cpt >> 6) & 0x3f));
result.push_back(0x80 | (cpt & 0x3f));
return result;
}
if (0x10000 <= cpt && cpt <= 0x10ffff) {
result.push_back(0xf0 | ((cpt >> 18) & 0x07));
result.push_back(0x80 | ((cpt >> 12) & 0x3f));
result.push_back(0x80 | ((cpt >> 6) & 0x3f));
result.push_back(0x80 | (cpt & 0x3f));
return result;
}
throw std::invalid_argument("invalid codepoint");
}
std::vector<uint32_t> unicode_cpts_normalize_nfd(const std::vector<uint32_t> & cpts) {
auto comp = [] (const uint32_t cpt, const range_nfd & range) {
return cpt < range.first;
};
std::vector<uint32_t> result(cpts.size());
for (size_t i = 0; i < cpts.size(); ++i) {
const uint32_t cpt = cpts[i];
auto it = std::upper_bound(unicode_ranges_nfd.begin(), unicode_ranges_nfd.end(), cpt, comp) - 1;
result[i] = (it->first <= cpt && cpt <= it->last) ? it->nfd : cpt;
}
return result;
}
std::vector<uint32_t> unicode_cpts_from_utf8(const std::string & utf8) {
std::vector<uint32_t> result;
result.reserve(utf8.size());
size_t offset = 0;
while (offset < utf8.size()) {
result.push_back(unicode_cpt_from_utf8(utf8, offset));
}
return result;
}
unicode_cpt_flags unicode_cpt_flags_from_cpt(const uint32_t cpt) {
static const unicode_cpt_flags undef(unicode_cpt_flags::UNDEFINED);
static const auto cpt_flags = unicode_cpt_flags_array();
return cpt < cpt_flags.size() ? cpt_flags[cpt] : undef;
}
unicode_cpt_flags unicode_cpt_flags_from_utf8(const std::string & utf8) {
static const unicode_cpt_flags undef(unicode_cpt_flags::UNDEFINED);
if (utf8.empty()) {
return undef; // undefined
}
size_t offset = 0;
return unicode_cpt_flags_from_cpt(unicode_cpt_from_utf8(utf8, offset));
}
std::string unicode_byte_to_utf8(uint8_t byte) {
static std::unordered_map<uint8_t, std::string> map = unicode_byte_to_utf8_map();
return map.at(byte);
}
uint8_t unicode_utf8_to_byte(const std::string & utf8) {
static std::unordered_map<std::string, uint8_t> map = unicode_utf8_to_byte_map();
return map.at(utf8);
}
uint32_t unicode_tolower(uint32_t cpt) {
// binary search
auto it = std::lower_bound(unicode_map_lowercase.begin(), unicode_map_lowercase.end(), cpt,
[](const std::pair<uint32_t, uint32_t> & pair, uint32_t value) {
return pair.first < value;
});
if (it != unicode_map_lowercase.end() && it->first == cpt) {
return it->second;
}
return cpt; // Return the original code point if no lowercase mapping is found
}
std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::string> & regex_exprs) {
// unicode categories
static const std::map<std::string, int> k_ucat_enum = {
{ "\\p{N}", unicode_cpt_flags::NUMBER },
{ "\\p{L}", unicode_cpt_flags::LETTER },
{ "\\p{P}", unicode_cpt_flags::PUNCTUATION },
{ "\\p{M}", unicode_cpt_flags::ACCENT_MARK },
{ "\\p{S}", unicode_cpt_flags::SYMBOL },
};
static const std::map<int, int> k_ucat_cpt = {
{ unicode_cpt_flags::NUMBER, 0xD1 },
{ unicode_cpt_flags::LETTER, 0xD2 },
{ unicode_cpt_flags::PUNCTUATION, 0xD3 },
{ unicode_cpt_flags::ACCENT_MARK, 0xD4 },
{ unicode_cpt_flags::SYMBOL, 0xD5 },
};
static const std::map<int, std::string> k_ucat_map = {
{ unicode_cpt_flags::NUMBER, "\x30-\x39" }, // 0-9
{ unicode_cpt_flags::LETTER, "\x41-\x5A\x61-\x7A" }, // A-Za-z
{ unicode_cpt_flags::PUNCTUATION, "\x21-\x23\x25-\x2A\x2C-\x2F\x3A-\x3B\x3F-\x40\\\x5B-\\\x5D\x5F\\\x7B\\\x7D" }, // !-#%-*,-/:-;?-@\[-\]_\{\}
{ unicode_cpt_flags::ACCENT_MARK, "" }, // no sub-128 codepoints
{ unicode_cpt_flags::SYMBOL, "\\\x24\\\x2B\x3C-\x3E\x5E\x60\\\x7C" }, // $+<=>^`|
};
// compute collapsed codepoints only if needed by at least one regex
bool need_collapse = false;
for (const auto & regex_expr : regex_exprs) {
// search for unicode categories
for (const auto & ucat : k_ucat_enum) {
if (std::string::npos != regex_expr.find(ucat.first)) {
need_collapse = true;
break;
}
}
}
const auto cpts = unicode_cpts_from_utf8(text);
// generate a "collapsed" representation of the text, where all codepoints are replaced by a single byte
// ref: https://github.com/ggerganov/llama.cpp/pull/6920#issuecomment-2081479935
std::string text_collapsed;
if (need_collapse) {
// collapse all unicode categories
text_collapsed.resize(cpts.size());
for (size_t i = 0; i < cpts.size(); ++i) {
// keep single-byte codepoints as is
if (cpts[i] < 128) {
text_collapsed[i] = cpts[i];
continue;
}
const auto flags = unicode_cpt_flags_from_cpt(cpts[i]);
if (flags.is_whitespace) {
//NOTE: C++ std::regex \s does not mach 0x85, Rust and Python regex does.
//text_collapsed[i] = (char) 0x85; // <Next Line> as whitespace fallback
text_collapsed[i] = (char) 0x0B; // <vertical tab> as whitespace fallback
} else if (k_ucat_cpt.find(flags.category_flag()) != k_ucat_cpt.end()) {
text_collapsed[i] = k_ucat_cpt.at(flags.category_flag());
} else {
text_collapsed[i] = (char) 0xD0; // fallback
}
}
}
std::vector<size_t> bpe_offsets = { cpts.size() };
for (const auto & regex_expr : regex_exprs) {
// first, see if we have an efficient custom regex implementation
auto tmp = unicode_regex_split_custom(text, regex_expr, bpe_offsets);
if (!tmp.empty()) {
bpe_offsets = std::move(tmp);
continue;
}
// fallback to general-purpose std::regex / std::wregex
try {
// if a unicode category is used in the regex, we use the collapsed text and replace the unicode category
// with the corresponding collapsed representation
bool use_collapsed = false;
for (const auto & ucat : k_ucat_enum) {
if (std::string::npos != regex_expr.find(ucat.first)) {
use_collapsed = true;
break;
}
}
if (use_collapsed) {
// sanity-check that the original regex does not contain any non-ASCII characters
const auto cpts_regex = unicode_cpts_from_utf8(regex_expr);
for (size_t i = 0; i < cpts_regex.size(); ++i) {
if (cpts_regex[i] >= 128) {
throw std::runtime_error("Regex includes both unicode categories and non-ASCII characters - not supported");
}
}
// generate a collapsed representation of the regex
std::string regex_expr_collapsed;
// track if we are inside [], because nested [] are not allowed
bool inside = false;
for (size_t i = 0; i < regex_expr.size(); ++i) {
if (regex_expr[i] == '[' && (i == 0 || regex_expr[i - 1] != '\\')) {
regex_expr_collapsed += '[';
inside = true;
continue;
}
if (inside && regex_expr[i] == ']' && regex_expr[i - 1] != '\\') {
regex_expr_collapsed += ']';
inside = false;
continue;
}
if (regex_expr[i + 0] == '\\' && i + 4 < regex_expr.size() &&
regex_expr[i + 1] == 'p' &&
regex_expr[i + 2] == '{' &&
regex_expr[i + 4] == '}') {
const std::string pat = regex_expr.substr(i, 5);
if (k_ucat_enum.find(pat) != k_ucat_enum.end()) {
if (!inside) {
regex_expr_collapsed += '[';
}
regex_expr_collapsed += k_ucat_cpt.at(k_ucat_enum.at(pat));
regex_expr_collapsed += k_ucat_map.at(k_ucat_enum.at(pat));
if (!inside) {
regex_expr_collapsed += ']';
}
i += 4;
continue;
}
}
regex_expr_collapsed += regex_expr[i];
}
//printf("text_collapsed: %s\n", text_collapsed.c_str());
//printf("regex_expr_collapsed: %s\n", regex_expr_collapsed.c_str());
bpe_offsets = unicode_regex_split_stl(text_collapsed, regex_expr_collapsed, bpe_offsets);
} else {
// no unicode category used, we can use std::wregex directly
const std::wstring wregex_expr = unicode_wstring_from_utf8(regex_expr);
// std::wregex \s does not mach non-ASCII whitespaces, using 0x0B as fallback
std::wstring wtext(cpts.begin(), cpts.end());
for (size_t i = 0; i < wtext.size(); ++i) {
if (wtext[i] > 0x7F && unicode_cpt_flags_from_cpt(wtext[i]).is_whitespace) {
wtext[i] = 0x0B;
}
}
//printf("text: %s\n", text.c_str());
//printf("regex_expr: %s\n", regex_expr.c_str());
bpe_offsets = unicode_regex_split_stl(wtext, wregex_expr, bpe_offsets);
}
} catch (std::regex_error & e) {
fprintf(stderr, "Failed to process regex: '%s'\n", regex_expr.c_str());
fprintf(stderr, "Regex error: %s\n", e.what());
throw std::runtime_error("Failed to process regex");
}
}
std::vector<std::string> bpe_words;
bpe_words.reserve(bpe_offsets.size()); // reserve memory for the approximate size
size_t start = 0;
for (size_t & offset : bpe_offsets) {
bpe_words.emplace_back();
for (size_t i = start; i < start + offset; ++i) {
bpe_words.back() += unicode_cpt_to_utf8(cpts[i]);
}
start += offset;
}
return unicode_byte_encoding_process(bpe_words);
}

94
llms/export/unicode.h Normal file
View File

@ -0,0 +1,94 @@
/**
* The following unicode files:
* - unicode.h
* - unicode_data.cpp
* - unicode.cpp
* are copied from llama.cpp with minor modifications:
* https://github.com/ggerganov/llama.cpp/
* Commit hash: 8d59d911711b8f1ba9ec57c4b192ccd2628af033
*/
#pragma once
#include <cstdint>
#include <string>
#include <vector>
#include <unordered_map>
#include <unordered_set>
/* unicode-data.h */
struct range_nfd {
uint32_t first;
uint32_t last;
uint32_t nfd;
};
static const uint32_t MAX_CODEPOINTS = 0x110000;
extern const std::initializer_list<std::pair<uint32_t, uint16_t>> unicode_ranges_flags;
extern const std::unordered_set<uint32_t> unicode_set_whitespace;
extern const std::initializer_list<std::pair<uint32_t, uint32_t>> unicode_map_lowercase;
extern const std::initializer_list<std::pair<uint32_t, uint32_t>> unicode_map_uppercase;
extern const std::initializer_list<range_nfd> unicode_ranges_nfd;
/* original unicode.h */
struct unicode_cpt_flags {
enum {
UNDEFINED = 0x0001,
NUMBER = 0x0002, // regex: \p{N}
LETTER = 0x0004, // regex: \p{L}
SEPARATOR = 0x0008, // regex: \p{Z}
ACCENT_MARK = 0x0010, // regex: \p{M}
PUNCTUATION = 0x0020, // regex: \p{P}
SYMBOL = 0x0040, // regex: \p{S}
CONTROL = 0x0080, // regex: \p{C}
MASK_CATEGORIES = 0x00FF,
};
// codepoint type
uint16_t is_undefined : 1;
uint16_t is_number : 1; // regex: \p{N}
uint16_t is_letter : 1; // regex: \p{L}
uint16_t is_separator : 1; // regex: \p{Z}
uint16_t is_accent_mark : 1; // regex: \p{M}
uint16_t is_punctuation : 1; // regex: \p{P}
uint16_t is_symbol : 1; // regex: \p{S}
uint16_t is_control : 1; // regex: \p{C}
// helper flags
uint16_t is_whitespace : 1; // regex: \s
uint16_t is_lowercase : 1;
uint16_t is_uppercase : 1;
uint16_t is_nfd : 1;
// decode from uint16
inline unicode_cpt_flags(const uint16_t flags = 0) {
*reinterpret_cast<uint16_t*>(this) = flags;
}
inline uint16_t as_uint() const {
return *reinterpret_cast<const uint16_t*>(this);
}
inline uint16_t category_flag() const {
return this->as_uint() & MASK_CATEGORIES;
}
};
size_t unicode_len_utf8(char src);
std::string unicode_cpt_to_utf8 (uint32_t cpt);
uint32_t unicode_cpt_from_utf8(const std::string & utf8, size_t & offset);
std::vector<uint32_t> unicode_cpts_from_utf8(const std::string & utf8);
std::vector<uint32_t> unicode_cpts_normalize_nfd(const std::vector<uint32_t> & cpts);
unicode_cpt_flags unicode_cpt_flags_from_cpt (uint32_t cpt);
unicode_cpt_flags unicode_cpt_flags_from_utf8(const std::string & utf8);
std::string unicode_byte_to_utf8(uint8_t byte);
uint8_t unicode_utf8_to_byte(const std::string & utf8);
uint32_t unicode_tolower(uint32_t cpt);
std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::string> & regex_exprs);

7034
llms/export/unicode_data.cpp Normal file

File diff suppressed because it is too large Load Diff

View File

@ -74,9 +74,9 @@ class Attention(nn.Module):
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
# Prepare the queries, keys and values for the attention computation # Prepare the queries, keys and values for the attention computation
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) queries = mx.unflatten(queries, -1, (self.n_heads, -1)).transpose(0, 2, 1, 3)
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) keys = mx.unflatten(keys, -1, (self.n_kv_heads, -1)).transpose(0, 2, 1, 3)
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) values = mx.unflatten(values, -1, (self.n_kv_heads, -1)).transpose(0, 2, 1, 3)
if cache is not None: if cache is not None:
queries = self.rope(queries, offset=cache.offset) queries = self.rope(queries, offset=cache.offset)
@ -90,7 +90,7 @@ class Attention(nn.Module):
queries, keys, values, cache=cache, scale=self.scale, mask=mask queries, keys, values, cache=cache, scale=self.scale, mask=mask
) )
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) output = output.transpose(0, 2, 1, 3).flatten(-2, -1)
return self.o_proj(output) return self.o_proj(output)