mirror of https://github.com/ml-explore/mlx-examples.git
synced 2025-08-29 06:54:39 +08:00

clang format

commit 5e8f88d079
parent 2797c438bb
@@ -13,20 +13,18 @@ namespace mx = mlx::core;
 #define time_now() std::chrono::high_resolution_clock::now()
 
 // Maybe compile
-std::function<mx::Args(mx::Args)> load_model(const std::string& path) {
+std::function<mx::Args(mx::Args)> load_model(const std::string &path) {
   return mx::compile(mx::import_function(path), /* shapeless = */ true);
 }
 
 // Maybe make tokenizer virtual
-BPETokenizer load_tokenizer(const std::string& path) {
+BPETokenizer load_tokenizer(const std::string &path) {
   return BPETokenizer(path);
 }
 
-void generate(
-    const std::function<mx::Args(mx::Args)>& model,
-    const BPETokenizer& tokenizer,
-    const std::string& prompt,
-    int max_tokens /* = 256 */) {
+void generate(const std::function<mx::Args(mx::Args)> &model,
+              const BPETokenizer &tokenizer, const std::string &prompt,
+              int max_tokens /* = 256 */) {
 
   auto prompt_tokens = tokenizer.encode(prompt);
   int prompt_size = prompt_tokens.size();
@@ -38,19 +36,22 @@ void generate(
   };
 
   // Helper to expand the cache and mask
-  auto expand = [](auto& args, auto& mask) {
+  auto expand = [](auto &args, auto &mask) {
     constexpr int cache_step_size = 256;
     int cache_size = args[1].shape(-2);
-    int new_size = cache_step_size * ((cache_size + cache_step_size) / cache_step_size);
+    int new_size =
+        cache_step_size * ((cache_size + cache_step_size) / cache_step_size);
     for (auto it = args.begin() + 1; it != args.end(); ++it) {
-      auto& x = *it;
+      auto &x = *it;
       auto shape = x.shape();
       shape[2] = new_size;
       auto new_x = mx::zeros(shape, x.dtype());
       shape[2] = cache_size;
-      *it = mx::slice_update(new_x, x, mx::Shape(x.ndim(), 0), std::move(shape));
+      *it =
+          mx::slice_update(new_x, x, mx::Shape(x.ndim(), 0), std::move(shape));
     }
-    mask = mx::slice_update(mx::full({new_size}, false), mask, {0}, {cache_size});
+    mask =
+        mx::slice_update(mx::full({new_size}, false), mask, {0}, {cache_size});
   };
 
   auto tic = time_now();
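Note (not part of the commit): the wrapped expression in the expand lambda above rounds the cache size up in fixed steps. A minimal standalone sketch of that arithmetic, assuming the step of 256 used in the code; new_size is the smallest multiple of the step strictly greater than cache_size, so there is always room for at least one more token.

    // Hypothetical helper mirroring the formula above; not code from the commit.
    constexpr int grow(int cache_size, int cache_step_size = 256) {
      return cache_step_size * ((cache_size + cache_step_size) / cache_step_size);
    }
    static_assert(grow(255) == 256, "below a full step: round up to the next multiple");
    static_assert(grow(256) == 512, "exactly full: grow by a whole extra step");
    static_assert(grow(300) == 512, "partially into the next step");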
@@ -6,15 +6,12 @@
 
 namespace mx = mlx::core;
 
-std::function<mx::Args(mx::Args)> load_model(const std::string& path);
+std::function<mx::Args(mx::Args)> load_model(const std::string &path);
 
-BPETokenizer load_tokenizer(const std::string& path);
+BPETokenizer load_tokenizer(const std::string &path);
 
-struct GenerationResponse {
-};
+struct GenerationResponse {};
 
-void generate(
-    const std::function<mx::Args(mx::Args)>& model,
-    const BPETokenizer& tokenizer,
-    const std::string& prompt,
-    int max_tokens = 256);
+void generate(const std::function<mx::Args(mx::Args)> &model,
+              const BPETokenizer &tokenizer, const std::string &prompt,
+              int max_tokens = 256);
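For context, a hedged usage sketch of the declarations in this header hunk. The include name, file paths, and prompt are illustrative assumptions, not taken from the commit.

    // Hypothetical driver; "mlxlm.h" is an assumed name for the header above.
    #include "mlxlm.h"

    int main() {
      auto model = load_model("model.mlxfn");           // exported function, compiled shapeless
      auto tokenizer = load_tokenizer("tokenizer_dir"); // directory containing tokenizer.json
      generate(model, tokenizer, "Hello", /* max_tokens = */ 64);
      return 0;
    }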
@@ -1,21 +1,22 @@
 
-#include <fstream>
-#include <filesystem>
-#include <locale>
 #include <codecvt>
+#include <filesystem>
+#include <fstream>
 #include <json.hpp>
+#include <locale>
 
-#include "tokenizer.h"
 #include "third_party/unicode.h"
+#include "tokenizer.h"
 
 using json = nlohmann::json;
 
 #pragma GCC diagnostic push
 #pragma GCC diagnostic ignored "-Wdeprecated-declarations"
-std::pair<std::wstring, int> utf8_to_utf16(const std::string& s) {
+std::pair<std::wstring, int> utf8_to_utf16(const std::string &s) {
   static std::string replace_str = std::string(1, 0xFF);
   static std::wstring replace_wstr = std::wstring(1, 0xFFFD);
-  std::wstring_convert<std::codecvt_utf8_utf16<wchar_t>> cvt(replace_str, replace_wstr);
+  std::wstring_convert<std::codecvt_utf8_utf16<wchar_t>> cvt(replace_str,
+                                                             replace_wstr);
   auto out = cvt.from_bytes(s);
   return {out, cvt.converted()};
 }
@@ -23,7 +24,8 @@ std::pair<std::wstring, int> utf8_to_utf16(const std::string& s) {
 
 auto make_byte_decoder() {
   std::unordered_map<uint16_t, char> byte_decoder;
-  std::vector<uint16_t> limits = {0, '!', '~' + 1, L'¡', L'¬' + 1, L'®', L'ÿ' + 1};
+  std::vector<uint16_t> limits = {0, '!', '~' + 1, L'¡',
                                  L'¬' + 1, L'®', L'ÿ' + 1};
   char n = 0;
   for (int i = 0; i < limits.size() - 1; ++i) {
     auto start = limits[i];
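The limits reformatted above mark byte ranges that map to themselves in a GPT-2-style byte-level vocabulary. A hedged sketch of the forward mapping this decoder presumably inverts; this is an assumption about intent, since the rest of make_byte_decoder is not shown in this hunk.

    // Assumed GPT-2-style byte-to-code-point map (illustration only): bytes in
    // ['!','~'], ['¡','¬'] and ['®','ÿ'] keep their value, every other byte is
    // assigned 256 + n in order of appearance.
    #include <cstdint>
    #include <unordered_map>

    std::unordered_map<uint8_t, uint16_t> make_byte_encoder() {
      std::unordered_map<uint8_t, uint16_t> enc;
      int n = 0;
      for (int b = 0; b < 256; ++b) {
        bool printable = (b >= '!' && b <= '~') || (b >= 0xA1 && b <= 0xAC) ||
                         (b >= 0xAE && b <= 0xFF);
        enc[b] = printable ? b : 256 + n++;
      }
      return enc;
    }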
@@ -43,14 +45,14 @@ auto make_byte_decoder() {
 
 auto BPETokenizer::byte_decoder_ = make_byte_decoder();
 
-BPETokenizer::BPETokenizer(const std::string& path_) {
+BPETokenizer::BPETokenizer(const std::string &path_) {
   auto path = std::filesystem::path(path_);
   std::ifstream ifs(path / "tokenizer.json");
   auto tokenizer = json::parse(ifs);
   auto model = tokenizer["model"];
   token_to_id_ = model["vocab"];
   id_to_token_.resize(token_to_id_.size());
-  for (auto& [s, id] : token_to_id_) {
+  for (auto &[s, id] : token_to_id_) {
     if (id >= id_to_token_.size()) {
       id_to_token_.resize(id + 1);
     }
@@ -58,7 +60,7 @@ BPETokenizer::BPETokenizer(const std::string& path_) {
   }
   std::string type = model["type"];
   auto merges = model["merges"];
-  for (auto& s : merges) {
+  for (auto &s : merges) {
     if (s.is_string()) {
       merges_.emplace(s, merges_.size());
     } else {
@@ -69,7 +71,7 @@ BPETokenizer::BPETokenizer(const std::string& path_) {
   }
 
   auto added_tokens = tokenizer["added_tokens"];
-  for (auto& added_token : added_tokens) {
+  for (auto &added_token : added_tokens) {
     int id = added_token["id"];
     if (id >= id_to_token_.size()) {
       id_to_token_.resize(id + 1);
@@ -83,14 +85,17 @@ BPETokenizer::BPETokenizer(const std::string& path_) {
   }
 
   // Currently hardcoded to Llama3 BPE regex
-  pre_tokenizer_regex_ = {"(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"};
+  pre_tokenizer_regex_ = {
+      "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}"
+      "\\p{N}]?\\p{L}+|\\p{N}{1,3}| "
+      "?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"};
 }
 
 std::vector<int> BPETokenizer::encode(std::string text) const {
 
   auto segments = unicode_regex_split(text, pre_tokenizer_regex_);
 
-  auto one_step_merge = [this](std::string segment, std::vector<int>& splits) {
+  auto one_step_merge = [this](std::string segment, std::vector<int> &splits) {
     int merge_idx;
     int rank = INT32_MAX;
     for (int i = 0; i < splits.size() - 2; ++i) {
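The regex literal is now split across adjacent string literals, including in the middle of an escape-heavy stretch. A small illustration (not from the commit) of why that wrapping is safe: the compiler concatenates adjacent string literals before the pattern ever reaches the regex engine.

    #include <string_view>

    // Adjacent literals are merged at compile time, so the wrapped form above
    // denotes exactly the same pattern as the original single-line string.
    static_assert(std::string_view("abc"
                                   "def") == std::string_view("abcdef"));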
@@ -119,7 +124,8 @@ std::vector<int> BPETokenizer::encode(std::string text) const {
       auto start = splits[i];
       auto mid = splits[i + 1];
       auto end = splits[i + 2];
-      if (segment.substr(start, mid - start) == merge_l && segment.substr(mid, end - mid) == merge_r) {
+      if (segment.substr(start, mid - start) == merge_l &&
+          segment.substr(mid, end - mid) == merge_r) {
         splits.erase(splits.begin() + i + 1);
         i -= 1;
       }
@@ -131,18 +137,19 @@ std::vector<int> BPETokenizer::encode(std::string text) const {
   ids.push_back(bos_id_);
 
   // Initialize merges to integer list
-  auto merge_segment = [&ids, &one_step_merge, this](const std::string& segment) {
+  auto merge_segment = [&ids, &one_step_merge,
                        this](const std::string &segment) {
     std::vector<int> splits;
     for (int i = 0; i < segment.size(); ++i) {
       splits.push_back(i);
-      if (static_cast<unsigned char>(segment[i]) > 128) {
+      if (static_cast<unsigned char>(segment[i]) >= 128) {
         i++;
       }
     }
     splits.push_back(segment.size());
 
-    while (one_step_merge(segment, splits)) { };
+    while (one_step_merge(segment, splits)) {
+    };
     for (int i = 0; i < splits.size() - 1; ++i) {
       auto start = splits[i];
       auto end = splits[i + 1];
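For readers of this hunk: splits appears to hold the byte offsets of the current token boundaries, so erasing an interior offset merges the two adjacent pieces. A small worked illustration of that representation (an assumption about intent, not code from the commit):

    #include <cassert>
    #include <string>
    #include <vector>

    int main() {
      std::string segment = "abcd";
      std::vector<int> splits = {0, 1, 2, 4};  // pieces: "a" | "b" | "cd"
      splits.erase(splits.begin() + 1);        // drop the boundary between "a" and "b"
      assert(segment.substr(splits[0], splits[1] - splits[0]) == "ab");
      assert(segment.substr(splits[1], splits[2] - splits[1]) == "cd");
      return 0;
    }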
@@ -155,7 +162,7 @@ std::vector<int> BPETokenizer::encode(std::string text) const {
     }
   };
 
-  for (auto& segment : segments) {
+  for (auto &segment : segments) {
     merge_segment(segment);
   }
   return ids;
@@ -171,7 +178,8 @@ std::string BPETokenizer::id_to_bytes(int id) const {
   return token;
 }
 
-std::pair<std::string, bool> BPETokenizer::try_decode(const std::vector<int>& ids) const {
+std::pair<std::string, bool>
+BPETokenizer::try_decode(const std::vector<int> &ids) const {
   std::string text;
   for (auto id : ids) {
     text += id_to_bytes(id);
@@ -182,7 +190,7 @@ std::pair<std::string, bool> BPETokenizer::try_decode(const std::vector<int>& id
   return {text, complete};
 }
 
-std::string BPETokenizer::decode(const std::vector<int>& ids) const {
+std::string BPETokenizer::decode(const std::vector<int> &ids) const {
   return try_decode(ids).first;
 }
 
@@ -8,24 +8,24 @@
 
 /** BPE Tokenizer API */
 class BPETokenizer {
  public:
-  BPETokenizer(const std::string& path);
+  BPETokenizer(const std::string &path);
 
   /** Encode a string of text to token integer ids. */
   std::vector<int> encode(std::string text) const;
 
   /** Try to decode the vector of ids to text. The text is truncated to
    * include only the fully decodable tokens. */
-  std::string decode(const std::vector<int>& ids) const;
+  std::string decode(const std::vector<int> &ids) const;
 
   /** Try to decode the vector of ids to text. The second return value
    * indicates if the decoding completed. The text is truncated to include
    * only the fully decodable tokens. */
-  std::pair<std::string, bool> try_decode(const std::vector<int>& ids) const;
+  std::pair<std::string, bool> try_decode(const std::vector<int> &ids) const;
 
   int eos_token_id() const;
 
  private:
   std::unordered_map<std::string, int> token_to_id_;
   std::vector<std::string> id_to_token_;
   std::unordered_map<std::string, int> merges_;
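A hedged round-trip sketch of the class declared above; the tokenizer path and input text are assumptions, not part of the commit.

    #include <cassert>
    #include "tokenizer.h"

    int main() {
      // Assumed path: a directory containing tokenizer.json.
      BPETokenizer tokenizer("path/to/tokenizer_dir");
      auto ids = tokenizer.encode("hello world");
      auto [text, complete] = tokenizer.try_decode(ids);
      (void)complete;  // whether every id decoded to complete UTF-8 text
      // decode is a convenience wrapper that drops the completeness flag.
      assert(tokenizer.decode(ids) == text);
      return 0;
    }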