// mlx-examples/llms/export/mlxlm.cpp

// Copyright © 2024 Apple Inc.

#include <chrono>
#include <iomanip>
#include <iostream>

#include "mlxlm.h"

namespace mx = mlx::core;

#define seconds(x) \
  (std::chrono::duration_cast<std::chrono::nanoseconds>(x).count() / 1e9)
#define time_now() std::chrono::high_resolution_clock::now()

// Maybe compile
std::function<mx::Args(mx::Args)> load_model(const std::string &path) {
  return mx::compile(mx::import_function(path), /* shapeless = */ true);
}
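
// Sketch (an assumption about the workflow, not taken from this file): the
// function file loaded above is produced ahead of time with MLX's export
// API, roughly
//
//   mx::export_function("model.mlxfn", forward_fn, example_inputs,
//                       /* shapeless = */ true);
//
// where forward_fn and example_inputs are placeholders for the model's
// forward pass and sample arguments used to trace it. mx::import_function
// then reloads the traced graph here without rebuilding the model in code.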

// Maybe make tokenizer virtual
BPETokenizer load_tokenizer(const std::string &path) {
  return BPETokenizer(path);
}

void generate(const std::function<mx::Args(mx::Args)> &model,
              const BPETokenizer &tokenizer, const std::string &prompt,
              int max_tokens /* = 256 */) {
  auto prompt_tokens = tokenizer.encode(prompt);
  int prompt_size = prompt_tokens.size();
  auto y = mx::array(prompt_tokens.data(), {1, prompt_size}, mx::uint32);
  auto create_causal_mask = [](int N) {
    auto indices = mx::arange(N);
    return mx::expand_dims(indices, 1) >= indices;
  };
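  // For example, create_causal_mask(3) is the lower-triangular boolean matrix
  //   {{true, false, false},
  //    {true, true,  false},
  //    {true, true,  true }}
  // so query position i attends only to key positions j <= i.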
  // Helper to expand the cache and mask
  auto expand = [](auto &args, auto &mask) {
    constexpr int cache_step_size = 256;
    int cache_size = args[1].shape(-2);
    int new_size =
        cache_step_size * ((cache_size + cache_step_size) / cache_step_size);
    for (auto it = args.begin() + 1; it != args.end(); ++it) {
      auto &x = *it;
      auto shape = x.shape();
      shape[2] = new_size;
      auto new_x = mx::zeros(shape, x.dtype());
      shape[2] = cache_size;
      *it =
          mx::slice_update(new_x, x, mx::Shape(x.ndim(), 0), std::move(shape));
    }
    mask =
        mx::slice_update(mx::full({new_size}, false), mask, {0}, {cache_size});
  };
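  // Note: expand rounds the cache length up in fixed 256-token steps, so the
  // copy above runs once per cache_step_size generated tokens instead of on
  // every step; the shapeless-compiled model accepts the changing sequence
  // length without retracing.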
  auto tic = time_now();
  float prompt_time = 0.0f;
  int n = 0;
  mx::Args args;
  {
    args = model({y, create_causal_mask(y.size())});
    auto logits = args[0];
    logits = slice(logits, {0, -1, 0}, logits.shape());
    y = argmax(logits, -1);
    async_eval(y);
  }
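  // async_eval launches the prompt-processing graph without blocking, so the
  // host can keep building the next step; synchronization happens only when
  // a concrete value is read (y.item<int>() in the loop below).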
  auto offset = mx::array(prompt_size, mx::uint32);
  std::vector<int> tokens;
  auto mask = mx::full({prompt_size}, true);
  expand(args, mask);
  for (; n < max_tokens; ++n) {
    // Start next token decoding if needed
    if (n < max_tokens - 1) {
      args[0] = y;
      auto m = prompt_size + n;
      if (mask.size() <= m) {
        expand(args, mask);
      }
      mask = mx::slice_update(mask, mx::array(true), {m}, {m + 1});
      args.push_back(offset);
      args.push_back(mask);
      args = model(args);
      args[0] = argmax(args[0], -1);
      offset = offset + 1u;
      async_eval(args[0]);
    }
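    // y still holds the token produced on the previous step (or the prompt
    // prefill), so the detokenization and printing below overlap with the
    // decode step launched above.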
    auto token = y.item<int>();
    if (token == tokenizer.eos_token_id()) {
      break;
    }
    tokens.push_back(token);
    auto [result, complete] = tokenizer.try_decode(tokens);
    if (complete) {
      std::cout << result << std::flush;
      tokens.clear();
    }
    if (n == 0) {
      prompt_time = seconds(time_now() - tic);
      tic = time_now();
    }
    if (n < max_tokens - 1) {
      y = args[0];
    }
  }
  auto result = tokenizer.decode(tokens);
  std::cout << result << std::flush;
  auto gen_time = seconds(time_now() - tic);
  std::cout << std::endl;
  std::cout << std::setprecision(5) << "Prompt toks/sec "
            << prompt_size / prompt_time << "\nGeneration toks/sec "
            << (n + 1) / gen_time << std::endl;
}
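
// A minimal driver sketch, not part of the original file: the model and
// tokenizer paths are placeholders, and the guard keeps this from clashing
// with any real entry point. Compile with -DMLXLM_DEMO_MAIN to try it.
#ifdef MLXLM_DEMO_MAIN
int main() {
  auto model = load_model("model.mlxfn");            // placeholder path
  auto tokenizer = load_tokenizer("tokenizer.json"); // placeholder path
  generate(model, tokenizer, "Write a story about Einstein.");
  return 0;
}
#endif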