clang format

This commit is contained in:
Awni Hannun 2025-01-09 15:21:17 -08:00
parent 2797c438bb
commit 5e8f88d079
4 changed files with 53 additions and 47 deletions

View File

@ -22,10 +22,8 @@ BPETokenizer load_tokenizer(const std::string& path) {
return BPETokenizer(path); return BPETokenizer(path);
} }
void generate( void generate(const std::function<mx::Args(mx::Args)> &model,
const std::function<mx::Args(mx::Args)>& model, const BPETokenizer &tokenizer, const std::string &prompt,
const BPETokenizer& tokenizer,
const std::string& prompt,
int max_tokens /* = 256 */) { int max_tokens /* = 256 */) {
auto prompt_tokens = tokenizer.encode(prompt); auto prompt_tokens = tokenizer.encode(prompt);
@ -41,16 +39,19 @@ void generate(
auto expand = [](auto &args, auto &mask) { auto expand = [](auto &args, auto &mask) {
constexpr int cache_step_size = 256; constexpr int cache_step_size = 256;
int cache_size = args[1].shape(-2); int cache_size = args[1].shape(-2);
int new_size = cache_step_size * ((cache_size + cache_step_size) / cache_step_size); int new_size =
cache_step_size * ((cache_size + cache_step_size) / cache_step_size);
for (auto it = args.begin() + 1; it != args.end(); ++it) { for (auto it = args.begin() + 1; it != args.end(); ++it) {
auto &x = *it; auto &x = *it;
auto shape = x.shape(); auto shape = x.shape();
shape[2] = new_size; shape[2] = new_size;
auto new_x = mx::zeros(shape, x.dtype()); auto new_x = mx::zeros(shape, x.dtype());
shape[2] = cache_size; shape[2] = cache_size;
*it = mx::slice_update(new_x, x, mx::Shape(x.ndim(), 0), std::move(shape)); *it =
mx::slice_update(new_x, x, mx::Shape(x.ndim(), 0), std::move(shape));
} }
mask = mx::slice_update(mx::full({new_size}, false), mask, {0}, {cache_size}); mask =
mx::slice_update(mx::full({new_size}, false), mask, {0}, {cache_size});
}; };
auto tic = time_now(); auto tic = time_now();

View File

@ -10,11 +10,8 @@ std::function<mx::Args(mx::Args)> load_model(const std::string& path);
BPETokenizer load_tokenizer(const std::string &path); BPETokenizer load_tokenizer(const std::string &path);
struct GenerationResponse { struct GenerationResponse {};
};
void generate( void generate(const std::function<mx::Args(mx::Args)> &model,
const std::function<mx::Args(mx::Args)>& model, const BPETokenizer &tokenizer, const std::string &prompt,
const BPETokenizer& tokenizer,
const std::string& prompt,
int max_tokens = 256); int max_tokens = 256);

View File

@ -1,12 +1,12 @@
#include <fstream>
#include <filesystem>
#include <locale>
#include <codecvt> #include <codecvt>
#include <filesystem>
#include <fstream>
#include <json.hpp> #include <json.hpp>
#include <locale>
#include "tokenizer.h"
#include "third_party/unicode.h" #include "third_party/unicode.h"
#include "tokenizer.h"
using json = nlohmann::json; using json = nlohmann::json;
@ -15,7 +15,8 @@ using json = nlohmann::json;
std::pair<std::wstring, int> utf8_to_utf16(const std::string &s) { std::pair<std::wstring, int> utf8_to_utf16(const std::string &s) {
static std::string replace_str = std::string(1, 0xFF); static std::string replace_str = std::string(1, 0xFF);
static std::wstring replace_wstr = std::wstring(1, 0xFFFD); static std::wstring replace_wstr = std::wstring(1, 0xFFFD);
std::wstring_convert<std::codecvt_utf8_utf16<wchar_t>> cvt(replace_str, replace_wstr); std::wstring_convert<std::codecvt_utf8_utf16<wchar_t>> cvt(replace_str,
replace_wstr);
auto out = cvt.from_bytes(s); auto out = cvt.from_bytes(s);
return {out, cvt.converted()}; return {out, cvt.converted()};
} }
@ -23,7 +24,8 @@ std::pair<std::wstring, int> utf8_to_utf16(const std::string& s) {
auto make_byte_decoder() { auto make_byte_decoder() {
std::unordered_map<uint16_t, char> byte_decoder; std::unordered_map<uint16_t, char> byte_decoder;
std::vector<uint16_t> limits = {0, '!', '~' + 1, L'¡', L'¬' + 1, L'®', L'ÿ' + 1}; std::vector<uint16_t> limits = {0, '!', '~' + 1, L'¡',
L'¬' + 1, L'®', L'ÿ' + 1};
char n = 0; char n = 0;
for (int i = 0; i < limits.size() - 1; ++i) { for (int i = 0; i < limits.size() - 1; ++i) {
auto start = limits[i]; auto start = limits[i];
@ -83,7 +85,10 @@ BPETokenizer::BPETokenizer(const std::string& path_) {
} }
// Currently hardcoded to Llama3 BPE regex // Currently hardcoded to Llama3 BPE regex
pre_tokenizer_regex_ = {"(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"}; pre_tokenizer_regex_ = {
"(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}"
"\\p{N}]?\\p{L}+|\\p{N}{1,3}| "
"?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"};
} }
std::vector<int> BPETokenizer::encode(std::string text) const { std::vector<int> BPETokenizer::encode(std::string text) const {
@ -119,7 +124,8 @@ std::vector<int> BPETokenizer::encode(std::string text) const {
auto start = splits[i]; auto start = splits[i];
auto mid = splits[i + 1]; auto mid = splits[i + 1];
auto end = splits[i + 2]; auto end = splits[i + 2];
if (segment.substr(start, mid - start) == merge_l && segment.substr(mid, end - mid) == merge_r) { if (segment.substr(start, mid - start) == merge_l &&
segment.substr(mid, end - mid) == merge_r) {
splits.erase(splits.begin() + i + 1); splits.erase(splits.begin() + i + 1);
i -= 1; i -= 1;
} }
@ -131,18 +137,19 @@ std::vector<int> BPETokenizer::encode(std::string text) const {
ids.push_back(bos_id_); ids.push_back(bos_id_);
// Initialize merges to integer list // Initialize merges to integer list
auto merge_segment = [&ids, &one_step_merge, this](const std::string& segment) { auto merge_segment = [&ids, &one_step_merge,
this](const std::string &segment) {
std::vector<int> splits; std::vector<int> splits;
for (int i = 0; i < segment.size(); ++i) { for (int i = 0; i < segment.size(); ++i) {
splits.push_back(i); splits.push_back(i);
if (static_cast<unsigned char>(segment[i]) > 128) { if (static_cast<unsigned char>(segment[i]) >= 128) {
i++; i++;
} }
} }
splits.push_back(segment.size()); splits.push_back(segment.size());
while (one_step_merge(segment, splits)) { }; while (one_step_merge(segment, splits)) {
};
for (int i = 0; i < splits.size() - 1; ++i) { for (int i = 0; i < splits.size() - 1; ++i) {
auto start = splits[i]; auto start = splits[i];
auto end = splits[i + 1]; auto end = splits[i + 1];
@ -171,7 +178,8 @@ std::string BPETokenizer::id_to_bytes(int id) const {
return token; return token;
} }
std::pair<std::string, bool> BPETokenizer::try_decode(const std::vector<int>& ids) const { std::pair<std::string, bool>
BPETokenizer::try_decode(const std::vector<int> &ids) const {
std::string text; std::string text;
for (auto id : ids) { for (auto id : ids) {
text += id_to_bytes(id); text += id_to_bytes(id);