// BPE tokenizer (from mlx-examples: https://github.com/ml-explore/mlx-examples.git)
#include <codecvt>
#include <cstdint>    // INT32_MAX
#include <filesystem>
#include <fstream>
#include <locale>
#include <stdexcept>  // std::runtime_error

#include <json.hpp>

#include "tokenizer.h"
#include "third_party/unicode.h"

using json = nlohmann::json;
// std::wstring_convert and std::codecvt_utf8_utf16 are deprecated since C++17
// (with no standard replacement), so silence the warning locally.
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wdeprecated-declarations"

// Convert UTF-8 to UTF-16. Returns the converted string and the number of
// input bytes that were successfully converted; the replacement strings make
// from_bytes() return U+FFFD on invalid input instead of throwing
// std::range_error.
std::pair<std::wstring, int> utf8_to_utf16(const std::string& s) {
  static std::string replace_str = std::string(1, 0xFF);
  static std::wstring replace_wstr = std::wstring(1, 0xFFFD);
  std::wstring_convert<std::codecvt_utf8_utf16<wchar_t>> cvt(replace_str, replace_wstr);
  auto out = cvt.from_bytes(s);
  return {out, cvt.converted()};
}
#pragma GCC diagnostic pop
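
// For example, "é" is the two bytes 0xC3 0xA9 in UTF-8, so converting "a\xC3"
// (with the second byte truncated) gives converted() == 1. try_decode() below
// uses exactly this to hold back an incomplete trailing character while
// streaming.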

// Build the inverse of the GPT-2 style byte-to-unicode map. Printable bytes
// map to themselves; the remaining bytes were assigned codepoints starting at
// 256 so that every byte has a visible, non-whitespace representation in the
// vocabulary.
auto make_byte_decoder() {
  std::unordered_map<uint16_t, char> byte_decoder;
  // Alternating [start, stop) boundaries of remapped and printable ranges.
  std::vector<uint16_t> limits = {0, '!', '~' + 1, L'¡', L'¬' + 1, L'®', L'ÿ' + 1};
  char n = 0;
  for (int i = 0; i + 1 < static_cast<int>(limits.size()); ++i) {
    auto start = limits[i];
    auto stop = limits[i + 1];
    if (i % 2 == 0) {
      // Remapped range: codepoint 256 + n decodes to byte b.
      for (int b = start; b < stop; ++b) {
        byte_decoder[256 + n++] = b;
      }
    } else {
      // Printable range: the codepoint is the byte itself.
      for (int b = start; b < stop; ++b) {
        byte_decoder[b] = b;
      }
    }
  }
  return byte_decoder;
}

auto BPETokenizer::byte_decoder_ = make_byte_decoder();
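
// A couple of concrete entries from the standard GPT-2 byte encoding that
// this table inverts:
//   ' '  (0x20) <-> "Ġ" (U+0120 = 256 + 32)
//   '\n' (0x0A) <-> "Ċ" (U+010A = 256 + 10)
// so a vocab entry like "Ġthe" decodes to the bytes " the".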

BPETokenizer::BPETokenizer(const std::string& path_) {
  auto path = std::filesystem::path(path_);
  std::ifstream ifs(path / "tokenizer.json");
  auto tokenizer = json::parse(ifs);
  auto model = tokenizer["model"];
  token_to_id_ = model["vocab"];

  // Build the reverse mapping, growing the table if ids are not contiguous.
  id_to_token_.resize(token_to_id_.size());
  for (auto& [s, id] : token_to_id_) {
    if (id >= id_to_token_.size()) {
      id_to_token_.resize(id + 1);
    }
    id_to_token_[id] = s;
  }

  std::string type = model["type"];
  auto merges = model["merges"];
  for (auto& s : merges) {
    if (s.is_string()) {
      // Some tokenizer.json files store merges as "left right" strings...
      merges_.emplace(s, merges_.size());
    } else {
      // ...others store them as ["left", "right"] pairs.
      std::string s1 = s[0];
      std::string s2 = s[1];
      merges_.emplace(s1 + " " + s2, merges_.size());
    }
  }

  auto added_tokens = tokenizer["added_tokens"];
  for (auto& added_token : added_tokens) {
    int id = added_token["id"];
    if (id >= id_to_token_.size()) {
      id_to_token_.resize(id + 1);
    }
    id_to_token_[id] = added_token["content"];
    if (id_to_token_[id] == "<|begin_of_text|>") {
      bos_id_ = id;
    } else if (id_to_token_[id] == "<|eot_id|>") {
      eos_id_ = id;
    }
  }

  // Currently hardcoded to the Llama 3 BPE pre-tokenizer regex.
  pre_tokenizer_regex_ = {"(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"};
}
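
// For reference, the parts of tokenizer.json read above follow the Hugging
// Face tokenizers layout and look roughly like this (values illustrative):
//
//   {
//     "added_tokens": [{"id": 128000, "content": "<|begin_of_text|>"}, ...],
//     "model": {
//       "type": "BPE",
//       "vocab": {"Ġthe": 270, ...},
//       "merges": ["Ġ t", ...]        // or [["Ġ", "t"], ...]
//     }
//   }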

std::vector<int> BPETokenizer::encode(std::string text) const {
  auto segments = unicode_regex_split(text, pre_tokenizer_regex_);

  // Find the lowest-ranked (earliest-learned) merge among all adjacent pairs
  // in `splits` and apply it everywhere it occurs. Returns false once no
  // adjacent pair is in the merge table.
  auto one_step_merge = [this](const std::string& segment, std::vector<int>& splits) {
    int merge_idx = 0;
    int rank = INT32_MAX;
    for (int i = 0; i + 2 < static_cast<int>(splits.size()); ++i) {
      auto start = splits[i];
      auto mid = splits[i + 1];
      auto end = splits[i + 2];
      // Candidate pairs are keyed as "left right", matching merges_.
      std::string candidate = segment.substr(start, mid - start);
      candidate += " ";
      candidate += segment.substr(mid, end - mid);
      if (auto it = merges_.find(candidate); it != merges_.end()) {
        if (it->second < rank) {
          merge_idx = i;
          rank = it->second;
        }
      }
    }
    if (rank == INT32_MAX) {
      return false;
    }

    auto start = splits[merge_idx];
    auto mid = splits[merge_idx + 1];
    auto end = splits[merge_idx + 2];
    std::string merge_l = segment.substr(start, mid - start);
    std::string merge_r = segment.substr(mid, end - mid);

    // Apply the merge at every occurrence, scanning right to left so that
    // erasing a split point never shifts the indices still to be visited.
    for (int i = static_cast<int>(splits.size()) - 3; i >= 0; --i) {
      auto start = splits[i];
      auto mid = splits[i + 1];
      auto end = splits[i + 2];
      if (segment.substr(start, mid - start) == merge_l &&
          segment.substr(mid, end - mid) == merge_r) {
        splits.erase(splits.begin() + i + 1);
        i -= 1;
      }
    }
    return true;
  };
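
  // A tiny worked example (hypothetical ranks, not from a real merges list):
  // with merges_ = {"l o": 0, "lo w": 1} and segment "low",
  //   splits {0,1,2,3} -> "l o" merges first -> splits {0,2,3} ("lo", "w")
  //   splits {0,2,3}   -> "lo w" merges next -> splits {0,3}   ("low")
  // after which one_step_merge finds no candidate and returns false.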

  std::vector<int> ids;
  ids.push_back(bos_id_);

  // Tokenize one pre-tokenized segment: set up the initial split points, run
  // BPE merges to a fixed point, then look up the id of each resulting piece.
  auto merge_segment = [&ids, &one_step_merge, this](const std::string& segment) {
    // One split per character. A byte above 0x80 starts a two-byte UTF-8
    // sequence here (the byte-level encoding only uses codepoints below
    // U+0800), so skip its continuation byte.
    std::vector<int> splits;
    for (int i = 0; i < static_cast<int>(segment.size()); ++i) {
      splits.push_back(i);
      if (static_cast<unsigned char>(segment[i]) > 128) {
        i++;
      }
    }
    splits.push_back(segment.size());
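
    // For instance, " hello" in vocab form is "Ġhello" (bytes C4 A0 68 65 6C
    // 6C 6F), and the loop above yields splits {0, 2, 3, 4, 5, 6, 7}, i.e.
    // the pieces "Ġ", "h", "e", "l", "l", "o".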

    while (one_step_merge(segment, splits)) { }

    for (int i = 0; i + 1 < static_cast<int>(splits.size()); ++i) {
      auto start = splits[i];
      auto end = splits[i + 1];
      std::string s = segment.substr(start, end - start);
      if (auto it = token_to_id_.find(s); it != token_to_id_.end()) {
        ids.push_back(it->second);
      } else {
        throw std::runtime_error("BPETokenizer: unknown token '" + s + "'");
      }
    }
  };

  for (auto& segment : segments) {
    merge_segment(segment);
  }
  return ids;
}

// Map a token id back to the raw bytes it represents by inverting the byte
// encoding one UTF-16 code unit at a time.
std::string BPETokenizer::id_to_bytes(int id) const {
  std::string token;
  auto [wide_token, _] = utf8_to_utf16(id_to_token_[id]);
  token.resize(wide_token.size());
  for (int i = 0; i < static_cast<int>(wide_token.size()); ++i) {
    token[i] = byte_decoder_[wide_token[i]];
  }
  return token;
}

// Decode ids to text. The second element of the returned pair is false when
// the byte stream ends in the middle of a UTF-8 character (as can happen
// while streaming tokens); the incomplete trailing bytes are dropped.
std::pair<std::string, bool> BPETokenizer::try_decode(const std::vector<int>& ids) const {
  std::string text;
  for (auto id : ids) {
    text += id_to_bytes(id);
  }
  // The conversion is only used to count how many leading bytes of `text`
  // form complete UTF-8 characters.
  auto [_, converted] = utf8_to_utf16(text);
  bool complete = static_cast<size_t>(converted) == text.size();
  text.resize(converted);
  return {text, complete};
}

std::string BPETokenizer::decode(const std::vector<int>& ids) const {
  return try_decode(ids).first;
}

int BPETokenizer::eos_token_id() const { return eos_id_; }
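
// Example usage (a sketch: the model directory path and the TOKENIZER_EXAMPLE
// guard are placeholders, not part of this file's build; it assumes only the
// BPETokenizer interface declared in tokenizer.h).
#ifdef TOKENIZER_EXAMPLE
#include <iostream>

int main() {
  // Directory containing a Hugging Face tokenizer.json, e.g. a Llama 3 export.
  BPETokenizer tok("path/to/model");
  auto ids = tok.encode("hello world");
  for (auto id : ids) {
    std::cout << id << " ";
  }
  std::cout << "\n" << tok.decode(ids) << std::endl;
  return 0;
}
#endif  // TOKENIZER_EXAMPLE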