Export / import functions to / from a file (#1642)

* export and import functions

* refactor + works for few primitives

* nit

* allow primitives with state

* nit

* nit

* simplify serialize / deserialize

* fix for constants

* python bindings

* maybe fix serialize failure case

* add example

* more primitives, training kind of works

* same result for python and c++

* some fixes

* fix export

* template it up

* some simplification

* rebase

* allow kwargs and multiple functions

* exporter

* more primitives for exporting

* deal with endianness

* handle invalid stream

* add docstring

Author: Awni Hannun
Date: 2024-12-24 11:19:13 -08:00
Commit: 4ba0c24a8f (parent: 935c8c4bb1)
35 changed files with 2239 additions and 90 deletions
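
In short: the new API traces a function on example inputs, serializes the resulting tape to a `.mlxfn` file, and replays it later, including across the Python / C++ boundary. A minimal round-trip sketch using the C++ API added in `mlx/export.h` below (illustrative only; error handling omitted):

```cpp
// Minimal sketch of the API added in this commit (see mlx/export.h below).
#include <mlx/mlx.h>
#include <iostream>

using namespace mlx::core;

int main() {
  // Trace `fun` on example inputs and write the tape to a file.
  auto fun = [](const Args& args) {
    return std::vector<array>{args[0] + args[1]};
  };
  export_function("fun.mlxfn", fun, {array(1.0f), array(2.0f)});

  // Load the function back; calls with matching dtypes/shapes replay the tape.
  auto imported = import_function("fun.mlxfn");
  auto out = imported({array(1.0f), array(2.0f)})[0];
  std::cout << out << std::endl;  // array(3, dtype=float32)
  return 0;
}
```

The same file can be produced from Python with `mx.export_function` and consumed from C++, which is what the `examples/export` programs in this diff demonstrate.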

@@ -27,6 +27,7 @@ option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
if(NOT MLX_VERSION)
set(MLX_VERSION 0.21.1)
endif()
add_compile_definitions("MLX_VERSION=${MLX_VERSION}")
# --------------------- Processor tests -------------------------

@@ -61,6 +61,7 @@ are the CPU and GPU.
python/array
python/data_types
python/devices_and_streams
python/export
python/ops
python/random
python/transforms

@@ -0,0 +1,14 @@
.. _export:

Export Functions
================

.. currentmodule:: mlx.core

.. autosummary::
   :toctree: _autosummary

   export_function
   import_function
   exporter
   export_to_dot

@@ -0,0 +1,27 @@
cmake_minimum_required(VERSION 3.27)

project(import_mlx LANGUAGES CXX)

# ----------------------------- Setup -----------------------------
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_POSITION_INDEPENDENT_CODE ON)

# ----------------------------- Dependencies -----------------------------
find_package(
  Python 3.9
  COMPONENTS Interpreter Development.Module
  REQUIRED)
execute_process(
  COMMAND "${Python_EXECUTABLE}" -m pip show mlx
  COMMAND grep location
  COMMAND awk "{print $4 \"/mlx\"}"
  OUTPUT_STRIP_TRAILING_WHITESPACE
  OUTPUT_VARIABLE MLX_ROOT)
find_package(MLX CONFIG REQUIRED)

add_executable(eval_mlp eval_mlp.cpp)
target_link_libraries(eval_mlp PRIVATE mlx)

add_executable(train_mlp train_mlp.cpp)
target_link_libraries(train_mlp PRIVATE mlx)

examples/export/README.md (new file, 49 lines)

@@ -0,0 +1,49 @@
## Setup

Install MLX:

```bash
pip install "mlx>=0.22"
```

Build the C++ examples:

```bash
cmake -B build -DCMAKE_BUILD_TYPE=Release
cmake --build build
```

## Run

### Eval MLP

Run the Python script to export the eval function:

```bash
python eval_mlp.py
```

Then run the C++ program to import and run the function:

```bash
./build/eval_mlp
```

The Python and C++ programs should output the same result.

### Train MLP

Run the Python script to export the model initialization and training
functions:

```bash
python train_mlp.py
```

Then run the C++ program to import and run the functions:

```bash
./build/train_mlp
```

The Python and C++ programs should output the same results.

@@ -0,0 +1,25 @@
// Copyright © 2024 Apple Inc.

#include <mlx/mlx.h>
#include <iostream>

using namespace mlx::core;

int main() {
  int batch_size = 8;
  int input_dim = 32;

  // Make the input
  random::seed(42);
  auto example_x = random::uniform({batch_size, input_dim});

  // Import the function
  auto forward = import_function("eval_mlp.mlxfn");

  // Call the imported function
  auto out = forward({example_x})[0];

  std::cout << out << std::endl;

  return 0;
}

@@ -0,0 +1,52 @@
# Copyright © 2024 Apple Inc.

import mlx.core as mx
import mlx.nn as nn
import mlx.utils


class MLP(nn.Module):
    """A simple MLP."""

    def __init__(
        self, num_layers: int, input_dim: int, hidden_dim: int, output_dim: int
    ):
        super().__init__()
        layer_sizes = [input_dim] + [hidden_dim] * num_layers + [output_dim]
        self.layers = [
            nn.Linear(idim, odim)
            for idim, odim in zip(layer_sizes[:-1], layer_sizes[1:])
        ]

    def __call__(self, x):
        for l in self.layers[:-1]:
            x = nn.relu(l(x))
        return self.layers[-1](x)


if __name__ == "__main__":
    batch_size = 8
    input_dim = 32
    output_dim = 10

    # Load the model
    mx.random.seed(0)  # Seed for params
    model = MLP(num_layers=5, input_dim=input_dim, hidden_dim=64, output_dim=output_dim)
    mx.eval(model)

    # Note: the model parameters are saved with the exported function
    def forward(x):
        return model(x)

    mx.random.seed(42)  # Seed for input
    example_x = mx.random.uniform(shape=(batch_size, input_dim))

    mx.export_function("eval_mlp.mlxfn", forward, example_x)

    # Import in Python
    imported_forward = mx.import_function("eval_mlp.mlxfn")

    expected = forward(example_x)
    (out,) = imported_forward(example_x)
    assert mx.allclose(expected, out)
    print(out)

@@ -0,0 +1,35 @@
// Copyright © 2024 Apple Inc.

#include <mlx/mlx.h>
#include <iostream>

using namespace mlx::core;

int main() {
  int batch_size = 8;
  int input_dim = 32;
  int output_dim = 10;

  auto state = import_function("init_mlp.mlxfn")({});

  // Make the input
  random::seed(42);
  auto example_X = random::normal({batch_size, input_dim});
  auto example_y = random::randint(0, output_dim, {batch_size});

  // Import the function
  auto step = import_function("train_mlp.mlxfn");

  // Call the imported function
  for (int it = 0; it < 100; ++it) {
    state.insert(state.end(), {example_X, example_y});
    state = step(state);
    eval(state);
    auto loss = state.back();
    state.pop_back();
    if (it % 10 == 0) {
      std::cout << "Loss " << loss.item<float>() << std::endl;
    }
  }
  return 0;
}

@@ -0,0 +1,76 @@
# Copyright © 2024 Apple Inc.

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import mlx.utils


class MLP(nn.Module):
    """A simple MLP."""

    def __init__(
        self, num_layers: int, input_dim: int, hidden_dim: int, output_dim: int
    ):
        super().__init__()
        layer_sizes = [input_dim] + [hidden_dim] * num_layers + [output_dim]
        self.layers = [
            nn.Linear(idim, odim)
            for idim, odim in zip(layer_sizes[:-1], layer_sizes[1:])
        ]

    def __call__(self, x):
        for l in self.layers[:-1]:
            x = nn.relu(l(x))
        return self.layers[-1](x)


if __name__ == "__main__":
    batch_size = 8
    input_dim = 32
    output_dim = 10

    def init():
        # Seed for the parameter initialization
        mx.random.seed(0)
        model = MLP(
            num_layers=3, input_dim=input_dim, hidden_dim=64, output_dim=output_dim
        )
        optimizer = optim.SGD(learning_rate=1e-1)
        optimizer.init(model.parameters())
        state = [model.parameters(), optimizer.state]
        tree_structure, state = zip(*mlx.utils.tree_flatten(state))
        return model, optimizer, tree_structure, state

    # Export the model parameter initialization
    model, optimizer, tree_structure, state = init()
    mx.eval(state)
    mx.export_function("init_mlp.mlxfn", lambda: init()[-1])

    def loss_fn(params, X, y):
        model.update(params)
        return nn.losses.cross_entropy(model(X), y, reduction="mean")

    def step(*inputs):
        *state, X, y = inputs
        params, opt_state = mlx.utils.tree_unflatten(list(zip(tree_structure, state)))
        optimizer.state = opt_state
        loss, grads = mx.value_and_grad(loss_fn)(params, X, y)
        params = optimizer.apply_gradients(grads, params)
        _, state = zip(*mlx.utils.tree_flatten([params, optimizer.state]))
        return *state, loss

    # Make some random data
    mx.random.seed(42)
    example_X = mx.random.normal(shape=(batch_size, input_dim))
    example_y = mx.random.randint(low=0, high=output_dim, shape=(batch_size,))

    # Export one step of SGD
    mx.export_function("train_mlp.mlxfn", step, *state, example_X, example_y)

    imported_step = mx.import_function("train_mlp.mlxfn")

    for it in range(100):
        *state, loss = imported_step(*state, example_X, example_y)
        if it % 10 == 0:
            print(f"Loss {loss.item():.6}")

@@ -5,6 +5,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp
${CMAKE_CURRENT_SOURCE_DIR}/export.cpp
${CMAKE_CURRENT_SOURCE_DIR}/einsum.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp

@@ -156,9 +156,6 @@ CompileMode& compile_mode() {
return compile_mode_;
}
using ParentsMap =
std::unordered_map<std::uintptr_t, std::vector<std::pair<array, int>>>;
// Helper like below but only merges the two provided arrays. If the src has
// siblings then these won't be merged to the dst.
void merge_one(array& dst, array& src, ParentsMap& parents_map) {
@@ -732,10 +729,15 @@ std::vector<array> compile_replace(
trace_to_real.insert({trace_inputs[i].id(), inputs[i]});
}
auto is_load = [](const Primitive& p) { return typeid(p) == typeid(Load); };
for (auto& a : tape) {
// Arrays in the tape without primitives are constants
// and can be used directly
if (!a.has_primitive()) {
// Arrays in the tape without primitives are either:
// - inputs, which are already in the map
// - constants, which can be used directly
// - a load primitive which has no inputs and will become a constant
// after the first eval
if (!a.has_primitive() || is_load(a.primitive())) {
trace_to_real.insert({a.id(), a});
} else {
// Find real inputs

@@ -2,7 +2,7 @@
#pragma once
#include "mlx/device.h"
#include "mlx/array.h"
namespace mlx::core::detail {
@@ -22,4 +22,35 @@ void compile_erase(std::uintptr_t fun_id);
void compile_clear_cache();
bool compile_available_for_device(const Device& device);
std::pair<std::vector<array>, std::vector<array>> compile_trace(
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
const std::vector<array>& inputs);
using ParentsMap =
std::unordered_map<std::uintptr_t, std::vector<std::pair<array, int>>>;
// Traverses the graph to build a tape and a map of array ids to their parents
std::pair<std::vector<array>, ParentsMap> compile_dfs(
const std::vector<array>& inputs,
const std::vector<array>& outputs,
const std::vector<array>& original_inputs);
// Simplify the tape. Note, this function modifies in-place both the tape and
// the parents map to remove orphaned arrays
void compile_simplify(
std::vector<array>& tape,
ParentsMap& parents_map,
const std::vector<array>& outputs,
int passes);
std::vector<array> compile_replace(
const std::vector<array>& tape,
const std::vector<array>& trace_inputs,
const std::vector<array>& trace_outputs,
const std::vector<array>& inputs,
bool shapeless);
void compile_validate_shapeless(const std::vector<array>& tape);
} // namespace mlx::core::detail
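
For orientation, export is built directly on these compile internals: `compile_trace` records a graph by running the function on placeholder inputs, `compile_dfs` linearizes it into a tape, `compile_simplify` prunes it, and `compile_replace` splices real arrays back in when an imported function is called. A rough sketch of how they compose (mirroring `FunctionExporter::export_function` and `ImportedFunction::operator()` in `mlx/export.cpp` below; internal API, subject to change):

```cpp
// Sketch: how export/import compose the compile internals (not public API).
#include "mlx/compile_impl.h"

using namespace mlx::core;

std::vector<array> trace_and_replay(
    const std::function<std::vector<array>(const std::vector<array>&)>& fun,
    const std::vector<array>& inputs) {
  // 1. Record the graph by tracing `fun` on placeholder inputs.
  auto [trace_inputs, trace_outputs] = detail::compile_trace(fun, inputs);
  // 2. Linearize the graph into a tape and simplify it.
  auto [tape, parents_map] =
      detail::compile_dfs(trace_inputs, trace_outputs, inputs);
  detail::compile_simplify(tape, parents_map, trace_outputs, /* passes */ 3);
  // 3. Export would serialize the tape here; import deserializes it back.
  // 4. Splice the real inputs into the tape and return the real outputs.
  return detail::compile_replace(
      tape, trace_inputs, trace_outputs, inputs, /* shapeless */ false);
}
```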

mlx/export.cpp (new file, 867 lines)

@@ -0,0 +1,867 @@
// Copyright © 2024 Apple Inc.
#include "mlx/export.h"
#include "mlx/compile_impl.h"
#include "mlx/fast_primitives.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
#define STRINGIFY(x) #x
#define TOSTRING(x) STRINGIFY(x)
// clang-format off
#define SERIALIZE_PRIMITIVE(primitive, keys...) \
{ \
#primitive, { \
[](Writer& os, const Primitive& p) { \
serialize_primitive<primitive>(os, p); \
}, \
[](Reader& is, Stream s) { \
return deserialize_primitive<primitive>(is, s); \
}, \
{keys} \
} \
}
// clang-format on
bool is_big_endian() {
int num = 1;
return *reinterpret_cast<char*>(&num) != 1;
}
namespace mlx::core {
using namespace mlx::core::fast;
using Reader = io::ParallelFileReader;
using Writer = io::FileWriter;
struct PrimitiveSerializer {
using Serializer = std::function<void(Writer&, const Primitive&)>;
using Deserializer =
std::function<std::shared_ptr<Primitive>(Reader&, Stream s)>;
PrimitiveSerializer(
Serializer serialize,
Deserializer deserialize,
std::vector<std::string> keys = {})
: serialize(std::move(serialize)),
deserialize(std::move(deserialize)),
keys(std::move(keys)) {};
Serializer serialize;
Deserializer deserialize;
std::vector<std::string> keys;
};
template <typename, typename = void>
constexpr bool is_iterable = false;
template <typename T>
constexpr bool is_iterable<
T,
std::void_t<
decltype(std::declval<T>().begin()),
decltype(std::declval<T>().end())>> = true;
template <template <typename...> class T, typename U>
constexpr bool is_specialization_of = false;
template <template <typename...> class T, typename... Us>
constexpr bool is_specialization_of<T, T<Us...>> = true;
template <typename T>
constexpr bool is_pair = is_specialization_of<std::pair, std::decay_t<T>>;
template <typename T>
constexpr bool is_tuple = is_specialization_of<std::tuple, std::decay_t<T>>;
template <typename>
constexpr bool dependent_false = false;
template <typename T>
struct NotSerializable {
static_assert(dependent_false<T>, "Type is not serializable.");
};
template <typename T>
struct NotDeserializable {
static_assert(dependent_false<T>, "Type is not deserializable.");
};
template <typename T>
void reverse_bytes(T& data) {
auto* bytes = reinterpret_cast<uint8_t*>(&data);
for (size_t j = 0; j < (sizeof(T) / 2); j++) {
std::swap(bytes[j], bytes[sizeof(T) - j - 1]);
}
}
template <typename T>
void serialize(Writer& os, T v) {
if constexpr (std::is_arithmetic_v<T>) {
if (is_big_endian()) {
reverse_bytes(v);
}
os.write(reinterpret_cast<const char*>(&v), sizeof(T));
} else if constexpr (std::is_enum_v<T>) {
serialize(os, static_cast<int>(v));
} else if constexpr (std::is_same_v<T, std::nullptr_t>) {
} else if constexpr (is_iterable<T>) {
serialize(os, static_cast<uint64_t>(v.size()));
for (const auto& t : v) {
serialize(os, t);
}
} else if constexpr (is_pair<T> || is_tuple<T>) {
std::apply([&os](auto&... x) { (..., serialize(os, x)); }, v);
} else {
NotSerializable<T>();
}
}
template <typename T, std::size_t... I>
decltype(auto) deserialize_tuple(Reader& is, std::index_sequence<I...>);
template <typename T>
T deserialize(Reader& is) {
if constexpr (std::is_arithmetic_v<T>) {
T v;
is.read(reinterpret_cast<char*>(&v), sizeof(T));
if (is_big_endian()) {
reverse_bytes(v);
}
return v;
} else if constexpr (std::is_enum_v<T>) {
return static_cast<T>(deserialize<int>(is));
} else if constexpr (std::is_same_v<T, std::nullptr_t>) {
return nullptr;
} else if constexpr (is_iterable<T>) {
T v;
auto size = deserialize<uint64_t>(is);
v.reserve(size);
for (int i = 0; i < size; ++i) {
v.push_back(deserialize<typename T::value_type>(is));
}
return v;
} else if constexpr (is_pair<T> || is_tuple<T>) {
return deserialize_tuple<T>(
is, std::make_index_sequence<std::tuple_size_v<std::decay_t<T>>>{});
} else {
NotDeserializable<T>();
}
}
template <typename T, std::size_t... I>
decltype(auto) deserialize_tuple(Reader& is, std::index_sequence<I...>) {
return T{deserialize<std::tuple_element_t<I, T>>(is)...};
};
void serialize(Writer& os, const Stream& s) {
serialize(os, s.index);
serialize(os, s.device.type);
serialize(os, s.device.index);
}
template <>
Stream deserialize(Reader& is) {
auto stream_index = deserialize<int>(is);
auto device_type = deserialize<Device::DeviceType>(is);
auto device_index = deserialize<int>(is);
return Stream(stream_index, Device(device_type, device_index));
}
void serialize(Writer& os, const Dtype& t) {
serialize(os, t.val());
serialize(os, t.size());
}
template <>
Dtype deserialize(Reader& is) {
auto val = deserialize<Dtype::Val>(is);
auto size = deserialize<uint8_t>(is);
return Dtype(val, size);
};
void serialize(Writer& os, const array& arr) {
serialize(os, arr.shape());
serialize(os, arr.dtype());
}
template <>
array deserialize(Reader& is) {
auto shape = deserialize<std::vector<int>>(is);
auto type = deserialize<Dtype>(is);
return array(std::move(shape), type, nullptr, std::vector<array>{});
}
template <typename, typename = void>
constexpr bool has_state = false;
template <typename T>
constexpr bool has_state<T, std::void_t<decltype(std::declval<T>().state())>> =
true;
template <typename T>
void serialize_primitive(Writer& os, const Primitive& p) {
if constexpr (has_state<T>) {
serialize(os, static_cast<const T&>(p).state());
}
}
template <typename T>
std::shared_ptr<T> deserialize_primitive(Reader& is, Stream s) {
if constexpr (has_state<T>) {
auto args = deserialize<decltype(std::declval<T>().state())>(is);
if constexpr (is_pair<decltype(args)> || is_tuple<decltype(args)>) {
auto fn = [s](auto&&... args) {
return std::make_shared<T>(s, std::move(args)...);
};
return std::apply(fn, std::move(args));
} else {
return std::make_shared<T>(s, std::move(args));
}
} else {
return std::make_shared<T>(s);
}
}
struct PrimitiveFactory {
std::unordered_map<std::string, PrimitiveSerializer> factory = {
SERIALIZE_PRIMITIVE(Abs),
SERIALIZE_PRIMITIVE(Add),
SERIALIZE_PRIMITIVE(AddMM),
SERIALIZE_PRIMITIVE(Arange),
SERIALIZE_PRIMITIVE(ArcCos),
SERIALIZE_PRIMITIVE(ArcCosh),
SERIALIZE_PRIMITIVE(ArcSin),
SERIALIZE_PRIMITIVE(ArcSinh),
SERIALIZE_PRIMITIVE(ArcTan),
SERIALIZE_PRIMITIVE(ArcTan2),
SERIALIZE_PRIMITIVE(ArcTanh),
SERIALIZE_PRIMITIVE(ArgPartition),
SERIALIZE_PRIMITIVE(ArgReduce),
SERIALIZE_PRIMITIVE(ArgSort),
SERIALIZE_PRIMITIVE(AsType),
SERIALIZE_PRIMITIVE(AsStrided),
SERIALIZE_PRIMITIVE(
BitwiseBinary,
"BitwiseAnd",
"BitwiseOr",
"BitwiseXor",
"LeftShift",
"RightShift"),
SERIALIZE_PRIMITIVE(BlockMaskedMM),
SERIALIZE_PRIMITIVE(Broadcast),
SERIALIZE_PRIMITIVE(Ceil),
SERIALIZE_PRIMITIVE(Concatenate),
SERIALIZE_PRIMITIVE(Conjugate),
SERIALIZE_PRIMITIVE(Convolution),
SERIALIZE_PRIMITIVE(Copy),
SERIALIZE_PRIMITIVE(Cos),
SERIALIZE_PRIMITIVE(Cosh),
SERIALIZE_PRIMITIVE(Depends),
SERIALIZE_PRIMITIVE(Divide),
SERIALIZE_PRIMITIVE(DivMod),
SERIALIZE_PRIMITIVE(Equal, "NaNEqual"),
SERIALIZE_PRIMITIVE(Erf),
SERIALIZE_PRIMITIVE(ErfInv),
SERIALIZE_PRIMITIVE(Exp),
SERIALIZE_PRIMITIVE(Expm1),
SERIALIZE_PRIMITIVE(ExpandDims),
SERIALIZE_PRIMITIVE(FFT),
SERIALIZE_PRIMITIVE(Flatten),
SERIALIZE_PRIMITIVE(Floor),
SERIALIZE_PRIMITIVE(Full),
SERIALIZE_PRIMITIVE(Gather),
SERIALIZE_PRIMITIVE(GatherMM),
SERIALIZE_PRIMITIVE(Greater),
SERIALIZE_PRIMITIVE(GreaterEqual),
SERIALIZE_PRIMITIVE(Hadamard),
SERIALIZE_PRIMITIVE(Imag),
SERIALIZE_PRIMITIVE(Less),
SERIALIZE_PRIMITIVE(LessEqual),
SERIALIZE_PRIMITIVE(Log, "Log2", "Log10"),
SERIALIZE_PRIMITIVE(Log1p),
SERIALIZE_PRIMITIVE(LogicalNot),
SERIALIZE_PRIMITIVE(LogicalAnd),
SERIALIZE_PRIMITIVE(LogicalOr),
SERIALIZE_PRIMITIVE(LogAddExp),
SERIALIZE_PRIMITIVE(Matmul),
SERIALIZE_PRIMITIVE(Maximum),
SERIALIZE_PRIMITIVE(Minimum),
SERIALIZE_PRIMITIVE(Multiply),
SERIALIZE_PRIMITIVE(Negative),
SERIALIZE_PRIMITIVE(NotEqual),
SERIALIZE_PRIMITIVE(NumberOfElements),
SERIALIZE_PRIMITIVE(Pad),
SERIALIZE_PRIMITIVE(Partition),
SERIALIZE_PRIMITIVE(Power),
SERIALIZE_PRIMITIVE(QuantizedMatmul),
SERIALIZE_PRIMITIVE(GatherQMM),
SERIALIZE_PRIMITIVE(RandomBits),
SERIALIZE_PRIMITIVE(Real),
SERIALIZE_PRIMITIVE(Remainder),
SERIALIZE_PRIMITIVE(Reshape),
SERIALIZE_PRIMITIVE(Reduce, "And", "Or", "Sum", "Prod", "Min", "Max"),
SERIALIZE_PRIMITIVE(Round),
SERIALIZE_PRIMITIVE(Scan, "CumSum", "CumProd", "CumMin", "CumMax"),
SERIALIZE_PRIMITIVE(Scatter),
SERIALIZE_PRIMITIVE(Select),
SERIALIZE_PRIMITIVE(Sigmoid),
SERIALIZE_PRIMITIVE(Sign),
SERIALIZE_PRIMITIVE(Sin),
SERIALIZE_PRIMITIVE(Sinh),
SERIALIZE_PRIMITIVE(Slice),
SERIALIZE_PRIMITIVE(SliceUpdate),
SERIALIZE_PRIMITIVE(Softmax),
SERIALIZE_PRIMITIVE(Sort),
SERIALIZE_PRIMITIVE(Split),
SERIALIZE_PRIMITIVE(Square),
SERIALIZE_PRIMITIVE(Squeeze),
SERIALIZE_PRIMITIVE(Sqrt, "Rsqrt", "Sqrt"),
SERIALIZE_PRIMITIVE(StopGradient),
SERIALIZE_PRIMITIVE(Subtract),
SERIALIZE_PRIMITIVE(Tan),
SERIALIZE_PRIMITIVE(Tanh),
SERIALIZE_PRIMITIVE(View),
SERIALIZE_PRIMITIVE(Transpose),
SERIALIZE_PRIMITIVE(Unflatten),
SERIALIZE_PRIMITIVE(QRF),
SERIALIZE_PRIMITIVE(SVD),
SERIALIZE_PRIMITIVE(Inverse),
SERIALIZE_PRIMITIVE(Cholesky),
SERIALIZE_PRIMITIVE(Eigh),
SERIALIZE_PRIMITIVE(AffineQuantize),
SERIALIZE_PRIMITIVE(RMSNorm),
SERIALIZE_PRIMITIVE(RMSNormVJP),
SERIALIZE_PRIMITIVE(LayerNorm),
SERIALIZE_PRIMITIVE(LayerNormVJP),
SERIALIZE_PRIMITIVE(RoPE),
SERIALIZE_PRIMITIVE(ScaledDotProductAttention)};
std::unordered_map<std::string, std::string> name_remap;
PrimitiveFactory() {
for (auto& [n, f] : factory) {
for (auto& k : f.keys) {
name_remap[k] = n;
}
}
}
void save(Writer& os, const std::shared_ptr<Primitive>& p) {
serialize(os, p->stream());
std::ostringstream pout;
p->print(pout);
auto name = pout.str();
name = name.substr(0, name.find(' '));
if (auto it = name_remap.find(name); it != name_remap.end()) {
name = it->second;
}
serialize(os, name);
if (auto it = factory.find(name); it != factory.end()) {
it->second.serialize(os, *p);
} else {
throw std::invalid_argument(
"[export_function] Unable to serialize primitive " + name);
}
};
std::shared_ptr<Primitive> load(Reader& is) {
auto stream = deserialize<Stream>(is);
if (get_stream(stream.index) != stream) {
std::ostringstream msg;
msg << "[import_function] Invalid stream encountered " << stream << ".";
throw std::invalid_argument(msg.str());
}
auto name = deserialize<std::string>(is);
if (auto it = factory.find(name); it != factory.end()) {
return it->second.deserialize(is, stream);
} else {
throw std::invalid_argument(
"[import_function] Unable to deserialize primitive " + name);
}
};
};
void write_header(Writer& os, int count, bool shapeless) {
serialize(os, std::string(TOSTRING(MLX_VERSION)));
serialize(os, count);
serialize(os, shapeless);
}
// A struct to hold and retrieve the graphs that are exported / imported
struct FunctionTable {
FunctionTable(bool shapeless = false) : shapeless(shapeless) {};
struct Function {
Function(
std::vector<std::string> kwarg_keys,
std::vector<array> inputs,
std::vector<array> outputs,
std::vector<array> tape)
: kwarg_keys(std::move(kwarg_keys)),
inputs(std::move(inputs)),
outputs(std::move(outputs)),
tape(std::move(tape)) {}
std::vector<std::string> kwarg_keys;
std::vector<array> inputs;
std::vector<array> outputs;
std::vector<array> tape;
Function(const Function&) = delete;
Function& operator=(const Function&) = delete;
Function(Function&&) = default;
Function() = default;
};
bool shapeless;
std::unordered_map<int, std::vector<Function>> table;
Function* find(const Args& args, const Kwargs& kwargs);
std::pair<Function&, bool> emplace(const Args& args, const Kwargs& kwargs);
void insert(
std::vector<std::string> kwarg_keys,
std::vector<array> inputs,
std::vector<array> outputs,
std::vector<array> tape) {
auto [it, _] = table.emplace(inputs.size(), std::vector<Function>{});
it->second.emplace_back(
std::move(kwarg_keys),
std::move(inputs),
std::move(outputs),
std::move(tape));
}
void print_functions(std::ostream& os) {
int n = 1;
for (auto& [_, vec] : table) {
for (auto& fun : vec) {
auto npos = fun.inputs.size() - fun.kwarg_keys.size();
os << " " << n++ << ". Function with " << npos
<< " positional inputs and " << fun.kwarg_keys.size()
<< " keyword inputs:\n";
for (int j = 0; j < fun.inputs.size(); ++j) {
auto& in = fun.inputs[j];
if (j < npos) {
os << " " << j + 1 << ": ";
} else {
os << " \"" << fun.kwarg_keys[j - npos] << "\": ";
}
os << in.shape() << " " << in.dtype() << "\n";
}
}
}
}
private:
bool match(const Args& args, const Kwargs& kwargs, const Function& fun);
};
bool FunctionTable::match(
const Args& args,
const Kwargs& kwargs,
const Function& fun) {
for (auto& k : fun.kwarg_keys) {
if (kwargs.find(k) == kwargs.end()) {
return false;
}
}
auto match_inputs = [shapeless = this->shapeless](
const array& x, const array& y) {
if (x.dtype() != y.dtype()) {
return false;
}
if (!shapeless && x.shape() != y.shape()) {
return false;
}
return true;
};
int i = 0;
for (; i < args.size(); ++i) {
if (!match_inputs(args[i], fun.inputs[i])) {
return false;
}
}
for (auto& [_, in] : kwargs) {
if (!match_inputs(in, fun.inputs[i++])) {
return false;
}
}
return true;
}
std::pair<FunctionTable::Function&, bool> FunctionTable::emplace(
const Args& args,
const Kwargs& kwargs) {
auto n_inputs = args.size() + kwargs.size();
auto [it, _] = table.emplace(n_inputs, std::vector<Function>{});
auto& funs_vec = it->second;
for (auto& fun : funs_vec) {
if (match(args, kwargs, fun)) {
return {fun, false};
}
}
funs_vec.emplace_back();
return {funs_vec.back(), true};
}
FunctionTable::Function* FunctionTable::find(
const Args& args,
const Kwargs& kwargs) {
auto n_inputs = args.size() + kwargs.size();
auto it = table.find(n_inputs);
if (it == table.end()) {
return nullptr;
}
for (auto& fun : it->second) {
if (match(args, kwargs, fun)) {
return &fun;
}
}
return nullptr;
}
FunctionExporter::FunctionExporter(
const std::string& file,
std::function<std::vector<array>(const Args&, const Kwargs&)> fun,
bool shapeless)
: os(file),
fun(std::move(fun)),
ftable(std::make_shared<FunctionTable>(shapeless)) {
if (!os.is_open()) {
throw std::runtime_error("[export_function] Failed to open " + file);
}
write_header(os, count, shapeless);
}
void FunctionExporter::close() {
closed = true;
};
void FunctionExporter::export_function(const Args& args, const Kwargs& kwargs) {
if (closed) {
throw std::runtime_error(
"[export_function] Attempting to write after exporting is closed.");
}
auto [fentry, inserted] = ftable->emplace(args, kwargs);
if (!inserted) {
throw std::runtime_error(
"[export_function] Attempting to export a function twice with "
"the same signature is not allowed.");
}
// Flatten the inputs to the function for tracing
std::vector<std::string> kwarg_keys;
auto inputs = args;
for (auto& [k, v] : kwargs) {
kwarg_keys.push_back(k);
inputs.push_back(v);
}
auto flat_fun = [this, &kwarg_keys](const Args& flat_args) {
auto args = Args(flat_args.begin(), flat_args.end() - kwarg_keys.size());
Kwargs kwargs;
auto it = flat_args.end() - kwarg_keys.size();
for (auto& k : kwarg_keys) {
kwargs.insert({k, *it++});
}
return fun(args, kwargs);
};
// Trace to build the graph
auto [trace_inputs, trace_outputs] = detail::compile_trace(flat_fun, inputs);
// DFS the graph and get the tape
auto [tape, parents_map] =
detail::compile_dfs(trace_inputs, trace_outputs, inputs);
detail::compile_simplify(tape, parents_map, trace_outputs, /* passes */ 3);
// Update header
count++;
// Overwrite the header
auto pos = os.tell();
os.seek(0);
write_header(os, count, ftable->shapeless);
os.seek(pos);
serialize(os, kwarg_keys);
auto arrays_to_ids = [](const std::vector<array>& arrs) {
std::vector<uint64_t> ids;
for (auto& arr : arrs) {
ids.push_back(arr.id());
}
return ids;
};
// Inputs and outputs
auto trace_input_ids = arrays_to_ids(trace_inputs);
serialize(os, trace_input_ids);
serialize(os, trace_inputs);
serialize(os, arrays_to_ids(trace_outputs));
// Update the table entry
fentry.kwarg_keys = std::move(kwarg_keys);
fentry.inputs = std::move(trace_inputs);
std::unordered_set<std::uintptr_t> input_set(
trace_input_ids.begin(), trace_input_ids.end());
// Tape
auto factory = PrimitiveFactory();
serialize(os, static_cast<uint64_t>(tape.size()));
for (auto& arr : tape) {
serialize(os, static_cast<uint64_t>(arr.id()));
if (arr.has_primitive()) {
serialize(os, true);
serialize(os, arrays_to_ids(arr.inputs()));
factory.save(os, arr.primitive_ptr());
serialize(os, static_cast<uint64_t>(arr.siblings().size()));
if (arr.siblings().empty()) {
serialize(os, arr.shape());
serialize(os, arr.dtype());
} else {
auto outputs = arr.outputs();
serialize(os, arrays_to_ids(outputs));
std::vector<std::vector<int>> shapes;
std::vector<Dtype> dtypes;
for (auto& o : outputs) {
shapes.push_back(o.shape());
dtypes.push_back(o.dtype());
}
serialize(os, shapes);
serialize(os, dtypes);
}
} else {
serialize(os, false);
if (input_set.find(arr.id()) == input_set.end()) {
serialize(os, true);
// Save constant data if not already saved
if (constants.insert(arr.id()).second) {
serialize(os, arr.shape());
serialize(os, arr.dtype());
os.write(arr.data<char>(), arr.nbytes());
}
} else {
serialize(os, false);
}
}
}
}
void FunctionExporter::operator()(const Args& args) {
export_function(args, {});
}
void FunctionExporter::operator()(const Kwargs& kwargs) {
export_function({}, kwargs);
}
void FunctionExporter::operator()(const Args& args, const Kwargs& kwargs) {
export_function(args, kwargs);
}
FunctionExporter exporter(
const std::string& file,
const std::function<std::vector<array>(const Args&)>& fun,
bool shapeless /* = false */) {
return FunctionExporter{
file,
[fun](const Args& args, const Kwargs&) { return fun(args); },
shapeless};
}
FunctionExporter exporter(
const std::string& file,
const std::function<std::vector<array>(const Kwargs&)>& fun,
bool shapeless /* = false */) {
return exporter(
file,
[fun](const Args&, const Kwargs kwargs) { return fun(kwargs); },
shapeless);
}
FunctionExporter exporter(
const std::string& file,
const std::function<std::vector<array>(const Args&, const Kwargs&)>& fun,
bool shapeless /* = false */) {
return FunctionExporter{file, fun, shapeless};
}
void export_function(
const std::string& file,
const std::function<std::vector<array>(const Args&)>& fun,
const Args& args,
bool shapeless /* = false */) {
exporter(file, fun, shapeless)(args);
}
void export_function(
const std::string& file,
const std::function<std::vector<array>(const Kwargs&)>& fun,
const Kwargs& kwargs,
bool shapeless /* = false */) {
exporter(file, fun, shapeless)(kwargs);
}
void export_function(
const std::string& file,
const std::function<std::vector<array>(const Args&, const Kwargs&)>& fun,
const Args& args,
const Kwargs& kwargs,
bool shapeless /* = false */) {
exporter(file, fun, shapeless)(args, kwargs);
}
std::vector<array> ImportedFunction::operator()(const Kwargs& kwargs) const {
return this->operator()({}, kwargs);
}
std::vector<array> ImportedFunction::operator()(const Args& args) const {
return this->operator()(args, {});
}
std::vector<array> ImportedFunction::operator()(
const Args& args,
const Kwargs& kwargs) const {
auto* fun = ftable->find(args, kwargs);
if (fun == nullptr) {
std::ostringstream msg;
msg << "[import_function::call] No imported function found which matches "
<< "the given positional and keyword arguments. Possible functions include:\n";
ftable->print_functions(msg);
msg << "\nReceived function with " << args.size()
<< " positional inputs and " << kwargs.size() << " keyword inputs:\n";
for (int i = 0; i < args.size(); ++i) {
auto& in = args[i];
msg << " " << i + 1 << ": " << in.shape() << " " << in.dtype() << "\n";
}
for (auto& [k, in] : kwargs) {
msg << " \"" << k << "\": " << in.shape() << " " << in.dtype() << "\n";
}
throw std::invalid_argument(msg.str());
}
auto inputs = args;
for (auto& [_, v] : kwargs) {
inputs.push_back(v);
}
return detail::compile_replace(
fun->tape, fun->inputs, fun->outputs, inputs, ftable->shapeless);
}
ImportedFunction import_function(const std::string& file) {
return ImportedFunction{file};
}
ImportedFunction::ImportedFunction(const std::string& file)
: ftable(std::make_shared<FunctionTable>()) {
auto is_ptr = std::make_shared<Reader>(file);
auto& is = *is_ptr;
if (!is.is_open()) {
throw std::runtime_error("[import_function] Failed to open " + file);
}
// Parse header
auto mlx_version = deserialize<std::string>(is);
auto function_count = deserialize<int>(is);
ftable->shapeless = deserialize<bool>(is);
std::unordered_map<std::uintptr_t, array> constants;
auto import_one = [&]() {
auto kwarg_keys = deserialize<std::vector<std::string>>(is);
std::unordered_map<uint64_t, array> array_map;
auto trace_input_ids = deserialize<std::vector<uint64_t>>(is);
auto trace_inputs = deserialize<std::vector<array>>(is);
for (int i = 0; i < trace_inputs.size(); ++i) {
array_map.emplace(trace_input_ids[i], trace_inputs[i]);
}
auto trace_output_ids = deserialize<std::vector<uint64_t>>(is);
std::vector<array> tape;
auto tape_size = deserialize<uint64_t>(is);
tape.reserve(tape_size);
auto factory = PrimitiveFactory();
for (size_t i = 0; i < tape_size; ++i) {
auto id = deserialize<uint64_t>(is);
if (deserialize<bool>(is)) {
auto input_ids = deserialize<std::vector<uint64_t>>(is);
std::vector<array> inputs;
inputs.reserve(input_ids.size());
for (auto id : input_ids) {
inputs.push_back(array_map.at(id));
}
std::shared_ptr<Primitive> prim = factory.load(is);
auto num_siblings = deserialize<uint64_t>(is);
if (num_siblings == 0) {
auto shape = deserialize<std::vector<int>>(is);
auto type = deserialize<Dtype>(is);
tape.emplace_back(
std::move(shape), type, std::move(prim), std::move(inputs));
array_map.emplace(id, tape.back());
} else {
auto ids = deserialize<std::vector<uint64_t>>(is);
auto shapes = deserialize<std::vector<std::vector<int>>>(is);
auto types = deserialize<std::vector<Dtype>>(is);
auto arrays = array::make_arrays(
std::move(shapes),
std::move(types),
std::move(prim),
std::move(inputs));
for (int i = 0; i < arrays.size(); ++i) {
auto sid = ids[i];
if (sid == id) {
tape.push_back(arrays[i]);
}
array_map.emplace(sid, arrays[i]);
}
}
} else {
if (deserialize<bool>(is)) {
// Load constant
if (auto it = constants.find(id); it != constants.end()) {
tape.push_back(it->second);
} else {
auto shape = deserialize<std::vector<int>>(is);
auto type = deserialize<Dtype>(is);
size_t offset = is.tell();
tape.push_back(array(
std::move(shape),
type,
std::make_shared<Load>(
default_stream(default_device()), is_ptr, offset),
{}));
is.seek(offset + tape.back().nbytes());
constants.insert({id, tape.back()});
}
array_map.emplace(id, tape.back());
} else {
// Function inputs are in the map
tape.push_back(array_map.at(id));
}
}
}
std::vector<array> trace_outputs;
trace_outputs.reserve(trace_output_ids.size());
for (auto id : trace_output_ids) {
trace_outputs.push_back(array_map.at(id));
}
ftable->insert(
std::move(kwarg_keys),
std::move(trace_inputs),
std::move(trace_outputs),
std::move(tape));
};
for (int i = 0; i < function_count; ++i) {
import_one();
}
}
} // namespace mlx::core
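
The serializer above is a small trait-driven scheme: arithmetic values are written raw (byte-reversed on big-endian hosts), enums as `int`, iterables as a `uint64_t` length followed by their elements, and pairs/tuples element-wise, so `std::string` and nested containers round-trip for free. A sketch of what that buys (these helpers are file-local to `mlx/export.cpp`, so this would only compile there; the file path is illustrative):

```cpp
// Round-trip sketch for the trait-driven serializer (file-local helpers).
void roundtrip_demo() {
  std::vector<std::pair<std::string, int>> v = {{"a", 1}, {"b", 2}};
  {
    Writer os("demo.bin");
    serialize(os, v);  // writes a u64 size, then each (string, int) pair
  }  // FileWriter closes the file on destruction
  Reader is("demo.bin");
  auto w = deserialize<std::vector<std::pair<std::string, int>>>(is);
  // w == v on both little- and big-endian hosts
}
```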

mlx/export.h (new file, 66 lines)

@@ -0,0 +1,66 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include <map>
#include <set>
#include "mlx/array.h"
namespace mlx::core {
using Args = std::vector<array>;
using Kwargs = std::map<std::string, array>;
struct FunctionExporter;
/**
* Make an exporter to save multiple traces of a given function to
* the same file.
*/
FunctionExporter exporter(
const std::string& file,
const std::function<std::vector<array>(const Args&)>& fun,
bool shapeless = false);
FunctionExporter exporter(
const std::string& file,
const std::function<std::vector<array>(const Kwargs&)>& fun,
bool shapeless = false);
FunctionExporter exporter(
const std::string& path,
const std::function<std::vector<array>(const Args&, const Kwargs&)>& fun,
bool shapeless = false);
/**
* Export a function to a file.
*/
void export_function(
const std::string& file,
const std::function<std::vector<array>(const Args&)>& fun,
const Args& args,
bool shapeless = false);
void export_function(
const std::string& file,
const std::function<std::vector<array>(const Kwargs&)>& fun,
const Kwargs& kwargs,
bool shapeless = false);
void export_function(
const std::string& file,
const std::function<std::vector<array>(const Args&, const Kwargs&)>& fun,
const Args& args,
const Kwargs& kwargs,
bool shapeless = false);
struct ImportedFunction;
/**
* Import a function from a file.
*/
ImportedFunction import_function(const std::string& file);
} // namespace mlx::core
#include "mlx/export_impl.h"

mlx/export_impl.h (new file, 71 lines)

@@ -0,0 +1,71 @@
// Copyright © 2024 Apple Inc.
#include "mlx/io/load.h"
#pragma once
namespace mlx::core {
struct FunctionTable;
struct FunctionExporter {
void operator()(const std::initializer_list<array>& args) {
this->operator()(Args(args));
}
void operator()(const Args& args);
void operator()(const Kwargs& kwargs);
void operator()(const Args& args, const Kwargs& kwargs);
void close();
FunctionExporter(const FunctionExporter&) = delete;
FunctionExporter& operator=(const FunctionExporter&) = delete;
FunctionExporter(FunctionExporter&& other) = default;
private:
friend FunctionExporter exporter(
const std::string&,
const std::function<std::vector<array>(const Args&)>&,
bool shapeless);
friend FunctionExporter exporter(
const std::string&,
const std::function<std::vector<array>(const Kwargs&)>&,
bool shapeless);
friend FunctionExporter exporter(
const std::string&,
const std::function<std::vector<array>(const Args&, const Kwargs&)>&,
bool shapeless);
FunctionExporter(
const std::string& file,
std::function<std::vector<array>(const Args&, const Kwargs&)> fun,
bool shapeless);
io::FileWriter os;
std::function<std::vector<array>(const Args&, const Kwargs& kwargs)> fun;
void export_function(const Args& args, const Kwargs& kwargs);
std::set<std::uintptr_t> constants;
int count{0};
bool closed{false};
std::shared_ptr<FunctionTable> ftable;
};
struct ImportedFunction {
std::vector<array> operator()(
const std::initializer_list<array>& args) const {
return this->operator()(Args(args));
}
std::vector<array> operator()(const Args& args) const;
std::vector<array> operator()(const Kwargs& kwargs) const;
std::vector<array> operator()(const Args& args, const Kwargs& kwargs) const;
private:
ImportedFunction(const std::string& file);
friend ImportedFunction import_function(const std::string&);
ImportedFunction();
std::shared_ptr<FunctionTable> ftable;
};
} // namespace mlx::core

@@ -697,8 +697,7 @@ array scaled_dot_product_attention(
return array(
std::move(out_shape),
final_type,
std::make_shared<ScaledDotProductAttention>(
stream, fallback, scale, false),
std::make_shared<ScaledDotProductAttention>(stream, fallback, scale),
{q, k, v});
}
@@ -712,7 +711,7 @@
bool ScaledDotProductAttention::is_equivalent(const Primitive& other) const {
const ScaledDotProductAttention& a_other =
static_cast<const ScaledDotProductAttention&>(other);
return needs_mask_ == a_other.needs_mask_ && scale_ == a_other.scale_;
return scale_ == a_other.scale_;
}
array pack_and_quantize(

@@ -60,6 +60,10 @@ class RMSNorm : public Custom {
bool is_equivalent(const Primitive& other) const override;
DEFINE_INPUT_OUTPUT_SHAPE()
auto state() const {
return std::make_pair(nullptr, eps_);
}
private:
std::function<std::vector<array>(std::vector<array>)> fallback_;
float eps_;
@@ -82,6 +86,9 @@ class RMSNormVJP : public Custom {
DEFINE_PRINT(RMSNormVJP)
bool is_equivalent(const Primitive& other) const override;
auto state() const {
return std::make_pair(nullptr, eps_);
}
private:
std::function<std::vector<array>(std::vector<array>)> fallback_;
@@ -112,6 +119,9 @@ class LayerNorm : public Custom {
DEFINE_PRINT(LayerNorm)
bool is_equivalent(const Primitive& other) const override;
DEFINE_INPUT_OUTPUT_SHAPE()
auto state() const {
return std::make_pair(nullptr, eps_);
}
private:
std::function<std::vector<array>(std::vector<array>)> fallback_;
@@ -135,6 +145,9 @@ class LayerNormVJP : public Custom {
DEFINE_PRINT(LayerNormVJP)
bool is_equivalent(const Primitive& other) const override;
auto state() const {
return std::make_pair(nullptr, eps_);
}
private:
std::function<std::vector<array>(std::vector<array>)> fallback_;
@@ -174,6 +187,10 @@ class RoPE : public Custom {
DEFINE_PRINT(RoPE)
bool is_equivalent(const Primitive& other) const override;
DEFINE_INPUT_OUTPUT_SHAPE()
auto state() const {
return std::make_tuple(
nullptr, dims_, traditional_, base_, scale_, forward_);
}
private:
std::function<std::vector<array>(std::vector<array>)> fallback_;
@@ -189,9 +206,8 @@ class ScaledDotProductAttention : public Custom {
explicit ScaledDotProductAttention(
Stream stream,
std::function<std::vector<array>(std::vector<array>)> fallback,
const float scale,
const bool needs_mask)
: Custom(stream, fallback), scale_(scale), needs_mask_(needs_mask) {}
const float scale)
: Custom(stream, fallback), scale_(scale) {}
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override {
@@ -208,11 +224,13 @@ class ScaledDotProductAttention : public Custom {
DEFINE_PRINT(ScaledDotProductAttention);
DEFINE_INPUT_OUTPUT_SHAPE()
auto state() const {
return std::make_pair(nullptr, scale_);
}
private:
std::function<std::vector<array>(std::vector<array>)> fallback_;
float scale_;
bool needs_mask_;
};
class AffineQuantize : public Custom {
@@ -238,6 +256,9 @@ class AffineQuantize : public Custom {
bool is_equivalent(const Primitive& other) const override;
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
auto state() const {
return std::make_tuple(nullptr, group_size_, bits_, dequantize_);
}
private:
std::function<std::vector<array>(std::vector<array>)> fallback_;

@@ -78,8 +78,15 @@ class ParallelFileReader : public Reader {
return lseek(fd_, 0, SEEK_CUR);
}
void seek(int64_t, std::ios_base::seekdir = std::ios_base::beg) override {
throw std::runtime_error("[ParallelFileReader::seek] Not allowed");
// Warning: do not use this function from multiple threads as
// it advances the file descriptor
void seek(int64_t off, std::ios_base::seekdir way = std::ios_base::beg)
override {
if (way == std::ios_base::beg) {
lseek(fd_, off, 0);
} else {
lseek(fd_, off, SEEK_CUR);
}
}
// Warning: do not use this function from multiple threads as
@@ -108,8 +115,16 @@ class FileWriter : public Writer {
0644)),
label_(std::move(file_path)) {}
FileWriter(const FileWriter&) = delete;
FileWriter& operator=(const FileWriter&) = delete;
FileWriter(FileWriter&& other) {
std::swap(fd_, other.fd_);
}
~FileWriter() override {
close(fd_);
if (fd_ != 0) {
close(fd_);
}
}
bool is_open() const override {
@@ -151,7 +166,7 @@
}
private:
int fd_;
int fd_{0};
std::string label_;
};

@@ -9,6 +9,7 @@
#include "mlx/distributed/distributed.h"
#include "mlx/distributed/ops.h"
#include "mlx/einsum.h"
#include "mlx/export.h"
#include "mlx/fast.h"
#include "mlx/fft.h"
#include "mlx/io.h"

@@ -4486,7 +4486,7 @@ std::pair<std::vector<array>, std::vector<int>> View::vmap(
}
void View::print(std::ostream& os) {
os << "View" << dtype_;
os << "View " << dtype_;
}
bool View::is_equivalent(const Primitive& other) const {

@@ -203,6 +203,9 @@ class AddMM : public UnaryPrimitive {
DEFINE_PRINT(AddMM)
bool is_equivalent(const Primitive& other) const override;
std::pair<float, float> state() const {
return {alpha_, beta_};
};
private:
const float alpha_;
@@ -220,6 +223,9 @@ class Arange : public UnaryPrimitive {
DEFINE_PRINT(Arange)
bool is_equivalent(const Primitive& other) const override;
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
std::tuple<double, double, double> state() const {
return {start_, stop_, step_};
};
private:
double start_;
@@ -361,6 +367,9 @@ class ArgPartition : public UnaryPrimitive {
DEFINE_PRINT(ArgPartition)
DEFINE_INPUT_OUTPUT_SHAPE()
bool is_equivalent(const Primitive& other) const override;
std::pair<int, int> state() const {
return {kth_, axis_};
};
private:
int kth_;
@@ -387,6 +396,9 @@ class ArgReduce : public UnaryPrimitive {
DEFINE_PRINT(ArgReduce)
bool is_equivalent(const Primitive& other) const override;
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
std::pair<ReduceType, int> state() const {
return {reduce_type_, axis_};
};
private:
ReduceType reduce_type_;
@@ -407,6 +419,9 @@ class ArgSort : public UnaryPrimitive {
DEFINE_PRINT(ArgSort)
DEFINE_INPUT_OUTPUT_SHAPE()
bool is_equivalent(const Primitive& other) const override;
int state() const {
return axis_;
};
private:
int axis_;
@@ -427,6 +442,9 @@ class AsType : public UnaryPrimitive {
DEFINE_PRINT(AsType)
DEFINE_INPUT_OUTPUT_SHAPE()
bool is_equivalent(const Primitive& other) const override;
Dtype state() const {
return dtype_;
};
private:
Dtype dtype_;
@@ -448,6 +466,9 @@ class AsStrided : public UnaryPrimitive {
DEFINE_GRADS()
DEFINE_PRINT(AsStrided)
bool is_equivalent(const Primitive& other) const override;
auto state() const {
return std::make_tuple(shape_, strides_, offset_);
}
private:
Shape shape_;
@@ -472,6 +493,9 @@ class BitwiseBinary : public UnaryPrimitive {
bool is_equivalent(const Primitive& other) const override;
void print(std::ostream& os) override;
DEFINE_INPUT_OUTPUT_SHAPE()
auto state() const {
return op_;
}
private:
Op op_;
@@ -493,6 +517,9 @@ class BlockMaskedMM : public UnaryPrimitive {
DEFINE_PRINT(BlockMaskedMM)
bool is_equivalent(const Primitive& other) const override;
auto state() const {
return block_size_;
}
private:
int block_size_;
@@ -532,6 +559,9 @@ class Broadcast : public UnaryPrimitive {
DEFINE_GRADS()
DEFINE_PRINT(Broadcast)
bool is_equivalent(const Primitive& other) const override;
std::vector<int> state() const {
return shape_;
};
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
@@ -613,6 +643,9 @@ class Concatenate : public UnaryPrimitive {
DEFINE_PRINT(Concatenate)
bool is_equivalent(const Primitive& other) const override;
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
auto state() const {
return axis_;
}
private:
int axis_;
@@ -684,6 +717,15 @@ class Convolution : public UnaryPrimitive {
DEFINE_PRINT(Convolution)
bool is_equivalent(const Primitive& other) const override;
auto state() const {
return std::make_tuple(
padding_,
kernel_strides_,
kernel_dilation_,
input_dilation_,
groups_,
flip_);
}
private:
std::vector<int> padding_;
@@ -912,6 +954,9 @@ class Equal : public UnaryPrimitive {
os << "Equal";
}
}
auto state() const {
return equal_nan_;
};
private:
void eval(const std::vector<array>& inputs, array& out);
@@ -1001,6 +1046,9 @@ class ExpandDims : public UnaryPrimitive {
bool is_equivalent(const Primitive& other) const override;
static Shape output_shape(const array& input, const std::vector<int>& axes);
auto state() const {
return axes_;
}
private:
void eval(const std::vector<array>& inputs, array& out);
@@ -1024,6 +1072,9 @@ class FFT : public UnaryPrimitive {
DEFINE_PRINT(FFT)
bool is_equivalent(const Primitive& other) const override;
auto state() const {
return std::make_tuple(axes_, inverse_, real_);
}
private:
std::vector<size_t> axes_;
@@ -1048,6 +1099,9 @@ class Flatten : public UnaryPrimitive {
bool is_equivalent(const Primitive& other) const override;
static Shape output_shape(const array& input, int start_axis, int end_axis);
auto state() const {
return std::make_pair(start_axis_, end_axis_);
}
private:
int start_axis_;
@@ -1103,6 +1157,9 @@ class Gather : public UnaryPrimitive {
DEFINE_PRINT(Gather)
bool is_equivalent(const Primitive& other) const override;
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
std::pair<std::vector<int>, std::vector<int>> state() const {
return {axes_, slice_sizes_};
}
private:
void eval(const std::vector<array>& inputs, array& out);
@@ -1158,6 +1215,9 @@ class Hadamard : public UnaryPrimitive {
DEFINE_INPUT_OUTPUT_SHAPE()
bool is_equivalent(const Primitive& other) const override;
auto state() const {
return scale_;
}
private:
float scale_;
@@ -1260,6 +1320,10 @@ class Log : public UnaryPrimitive {
DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE()
Base state() const {
return base_;
};
void print(std::ostream& os) override {
switch (base_) {
case e:
@@ -1488,6 +1552,9 @@ class NumberOfElements : public UnaryPrimitive {
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override {
return {{}};
}
std::tuple<std::vector<int>, bool, Dtype> state() const {
return {axes_, inverted_, dtype_};
}
private:
std::vector<int> axes_;
@@ -1516,6 +1583,9 @@ class Pad : public UnaryPrimitive {
DEFINE_GRADS()
DEFINE_PRINT(Pad)
bool is_equivalent(const Primitive& other) const override;
auto state() const {
return std::make_tuple(axes_, low_pad_size_, high_pad_size_);
}
private:
std::vector<int> axes_;
@@ -1538,6 +1608,9 @@ class Partition : public UnaryPrimitive {
DEFINE_PRINT(Partition)
DEFINE_INPUT_OUTPUT_SHAPE()
bool is_equivalent(const Primitive& other) const override;
auto state() const {
return std::make_pair(kth_, axis_);
};
private:
int kth_;
@@ -1583,6 +1656,9 @@ class QuantizedMatmul : public UnaryPrimitive {
DEFINE_PRINT(QuantizedMatmul)
bool is_equivalent(const Primitive& other) const override;
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
auto state() const {
return std::make_tuple(group_size_, bits_, transpose_);
}
private:
int group_size_;
@@ -1607,6 +1683,9 @@ class GatherQMM : public UnaryPrimitive {
DEFINE_GRADS()
DEFINE_PRINT(GatherQMM)
bool is_equivalent(const Primitive& other) const override;
auto state() const {
return std::make_tuple(group_size_, bits_, transpose_);
}
private:
int group_size_;
@@ -1627,6 +1706,9 @@ class RandomBits : public UnaryPrimitive {
DEFINE_VMAP()
DEFINE_PRINT(RandomBits)
bool is_equivalent(const Primitive& other) const override;
std::pair<std::vector<int>, int> state() const {
return {shape_, width_};
};
private:
Shape shape_;
@@ -1661,6 +1743,9 @@ class Reshape : public UnaryPrimitive {
DEFINE_GRADS()
DEFINE_PRINT(Reshape)
bool is_equivalent(const Primitive& other) const override;
std::vector<int> state() const {
return shape_;
};
private:
Shape shape_;
@@ -1712,6 +1797,9 @@ class Reduce : public UnaryPrimitive {
}
}
bool is_equivalent(const Primitive& other) const override;
std::pair<ReduceType, std::vector<int>> state() const {
return {reduce_type_, axes_};
};
private:
ReduceType reduce_type_;
@@ -1777,6 +1865,9 @@ class Scan : public UnaryPrimitive {
}
}
bool is_equivalent(const Primitive& other) const override;
auto state() const {
return std::make_tuple(reduce_type_, axis_, reverse_, inclusive_);
}
private:
ReduceType reduce_type_;
@@ -1823,6 +1914,9 @@ class Scatter : public UnaryPrimitive {
}
}
bool is_equivalent(const Primitive& other) const override;
std::pair<ReduceType, std::vector<int>> state() const {
return {reduce_type_, axes_};
};
private:
void eval(const std::vector<array>& inputs, array& out);
@@ -1917,6 +2011,9 @@ class Slice : public UnaryPrimitive {
DEFINE_GRADS()
DEFINE_PRINT(Slice)
bool is_equivalent(const Primitive& other) const override;
auto state() const {
return std::make_tuple(start_indices_, end_indices_, strides_);
}
private:
Shape start_indices_;
@@ -1946,6 +2043,9 @@ class SliceUpdate : public UnaryPrimitive {
DEFINE_PRINT(SliceUpdate)
bool is_equivalent(const Primitive& other) const override;
DEFINE_INPUT_OUTPUT_SHAPE()
auto state() const {
return std::make_tuple(start_indices_, end_indices_, strides_);
}
private:
Shape start_indices_;
@@ -1969,6 +2069,9 @@ class Softmax : public UnaryPrimitive {
DEFINE_INPUT_OUTPUT_SHAPE()
bool is_equivalent(const Primitive& other) const override;
auto state() const {
return precise_;
};
private:
void eval(const std::vector<array>& inputs, array& out);
@@ -1988,6 +2091,9 @@ class Sort : public UnaryPrimitive {
DEFINE_PRINT(Sort)
DEFINE_INPUT_OUTPUT_SHAPE()
bool is_equivalent(const Primitive& other) const override;
auto state() const {
return axis_;
}
private:
int axis_;
@@ -2009,6 +2115,9 @@ class Split : public Primitive {
DEFINE_GRADS()
DEFINE_PRINT(Split)
bool is_equivalent(const Primitive& other) const override;
std::pair<std::vector<int>, int> state() const {
return {indices_, axis_};
};
private:
void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
@@ -2046,6 +2155,9 @@ class Sqrt : public UnaryPrimitive {
DEFINE_GRADS()
DEFINE_INPUT_OUTPUT_SHAPE()
bool is_equivalent(const Primitive& other) const override;
auto state() const {
return recip_;
}
void print(std::ostream& os) override {
if (recip_) {
@@ -2109,6 +2221,9 @@ class Squeeze : public UnaryPrimitive {
bool is_equivalent(const Primitive& other) const override;
static Shape output_shape(const array& input, const std::vector<int>& axes);
auto state() const {
return axes_;
};
private:
void eval(const std::vector<array>& inputs, array& out);
@@ -2164,6 +2279,9 @@ class Unflatten : public UnaryPrimitive {
bool is_equivalent(const Primitive& other) const override;
static Shape output_shape(const array& input, int axis, const Shape& shape);
auto state() const {
return std::make_pair(axis_, shape_);
}
private:
int axis_;
@@ -2171,21 +2289,6 @@
void eval(const std::vector<array>& inputs, array& out);
};
class Uniform : public UnaryPrimitive {
public:
explicit Uniform(Stream stream) : UnaryPrimitive(stream) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_VMAP()
DEFINE_PRINT(Uniform)
DEFINE_DEFAULT_IS_EQUIVALENT()
private:
void eval(const std::vector<array>& inputs, array& out);
};
class View : public UnaryPrimitive {
public:
explicit View(Stream stream, Dtype dtype)
@@ -2197,6 +2300,9 @@ class View : public UnaryPrimitive {
DEFINE_VMAP()
void print(std::ostream& os) override;
bool is_equivalent(const Primitive& other) const override;
auto state() const {
return dtype_;
}
private:
Dtype dtype_;
@@ -2215,6 +2321,9 @@ class Transpose : public UnaryPrimitive {
DEFINE_PRINT(Transpose)
bool is_equivalent(const Primitive& other) const override;
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
std::vector<int> state() const {
return axes_;
};
private:
std::vector<int> axes_;
@@ -2266,6 +2375,9 @@ class Inverse : public UnaryPrimitive {
DEFINE_VMAP()
DEFINE_PRINT(Inverse)
auto state() const {
return std::make_pair(tri_, upper_);
}
private:
void eval(const std::vector<array>& inputs, array& output);
@@ -2280,6 +2392,9 @@ class Cholesky : public UnaryPrimitive {
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
auto state() const {
return upper_;
}
DEFINE_VMAP()
DEFINE_PRINT(Cholesky)
@@ -2307,6 +2422,9 @@ class Eigh : public Primitive {
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
bool is_equivalent(const Primitive& other) const override;
auto state() const {
return std::make_pair(uplo_, compute_eigenvectors_);
}
private:
void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
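
The pattern behind all of these additions: a primitive becomes exportable by returning its constructor arguments (everything after the `Stream`) from a `state()` accessor; the generic `serialize_primitive` / `deserialize_primitive` in `mlx/export.cpp` round-trip that value, and the only other requirement is a `SERIALIZE_PRIMITIVE` entry in the factory. A hypothetical primitive following the convention (`Tile` is illustrative, not part of this commit):

```cpp
// Hypothetical primitive following the state() convention (illustrative).
class Tile : public UnaryPrimitive {
 public:
  explicit Tile(Stream stream, std::vector<int> reps)
      : UnaryPrimitive(stream), reps_(std::move(reps)) {}

  // state() returns exactly the constructor arguments after the stream;
  // deserialize_primitive reconstructs with make_shared<Tile>(stream, state).
  std::vector<int> state() const {
    return reps_;
  }

  // eval_cpu / eval_gpu / print / is_equivalent elided.

 private:
  std::vector<int> reps_;
};

// And the registration in PrimitiveFactory (mlx/export.cpp) would be:
//   SERIALIZE_PRIMITIVE(Tile),
```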

@@ -21,6 +21,10 @@ void set_default_stream(Stream s) {
return scheduler::scheduler().set_default_stream(s);
}
Stream get_stream(int index) {
return scheduler::scheduler().get_stream(index);
}
Stream new_stream(Device d) {
if (!metal::is_available() && d == Device::gpu) {
throw std::invalid_argument(

@@ -96,6 +96,9 @@ class Scheduler {
Stream get_default_stream(const Device& d) const {
return default_streams_.at(d.type);
}
Stream get_stream(int index) const {
return streams_.at(index)->stream;
}
void set_default_stream(const Stream& s) {
default_streams_.at(s.device.type) = s;

@@ -21,6 +21,9 @@ void set_default_stream(Stream s);
/** Make a new stream on the given device. */
Stream new_stream(Device d);
/** Get the stream with the given index. */
Stream get_stream(int index);
inline bool operator==(const Stream& lhs, const Stream& rhs) {
return lhs.index == rhs.index;
}

@@ -10,6 +10,11 @@ namespace mlx::core {
void async_eval(std::vector<array> outputs);
template <typename... Arrays, typename = enable_for_arrays_t<Arrays...>>
void async_eval(Arrays&&... outputs) {
async_eval(std::vector<array>{std::forward<Arrays>(outputs)...});
}
void eval(std::vector<array> outputs);
template <typename... Arrays, typename = enable_for_arrays_t<Arrays...>>

@@ -11,6 +11,7 @@ nanobind_add_module(
${CMAKE_CURRENT_SOURCE_DIR}/convert.cpp
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp
${CMAKE_CURRENT_SOURCE_DIR}/export.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp

@@ -331,7 +331,7 @@ PyScalarT validate_shape(
t = pycomplex;
} else {
std::ostringstream msg;
msg << "Invalid type " << nb::type_name(l.type()).c_str()
msg << "Invalid type " << nb::type_name(l.type()).c_str()
<< " received in array initialization.";
throw std::invalid_argument(msg.str());
}

python/src/export.cpp (new file, 277 lines)

@@ -0,0 +1,277 @@
// Copyright © 2024 Apple Inc.
#include <nanobind/nanobind.h>
#include <nanobind/stl/map.h>
#include <nanobind/stl/optional.h>
#include <nanobind/stl/string.h>
#include <nanobind/stl/vector.h>
#include <fstream>
#include "mlx/array.h"
#include "mlx/export.h"
#include "mlx/graph_utils.h"
#include "python/src/trees.h"
namespace mx = mlx::core;
namespace nb = nanobind;
using namespace nb::literals;
std::pair<std::vector<mx::array>, std::map<std::string, mx::array>>
validate_and_extract_inputs(
const nb::args& args,
const nb::kwargs& kwargs,
const std::string& prefix) {
auto maybe_throw = [&prefix](bool valid) {
if (!valid) {
throw std::invalid_argument(
prefix +
" Inputs can either be a variable "
"number of positional and keyword arrays or a single tuple "
"and/or dictionary of arrays.");
}
};
std::vector<mx::array> args_;
std::map<std::string, mx::array> kwargs_;
if (args.size() == 0) {
// No args so kwargs must be keyword arrays
maybe_throw(nb::try_cast(kwargs, kwargs_));
} else if (args.size() > 0 && nb::isinstance<mx::array>(args[0])) {
// Args are positional arrays and kwargs are keyword arrays
maybe_throw(nb::try_cast(args, args_));
maybe_throw(nb::try_cast(kwargs, kwargs_));
} else if (args.size() == 1) {
// - args[0] can be a tuple or list of arrays or a dict
// with string keys and array values
// - kwargs should be empty
maybe_throw(kwargs.size() == 0);
if (!nb::try_cast(args[0], args_)) {
maybe_throw(nb::try_cast(args[0], kwargs_));
}
} else if (args.size() == 2) {
// - args[0] can be a tuple or list of arrays
// - args[1] can be a dict of string keys with array values.
// - kwargs should be empty
maybe_throw(kwargs.size() == 0);
maybe_throw(nb::try_cast(args[0], args_));
maybe_throw(nb::try_cast(args[1], kwargs_));
} else {
maybe_throw(false);
}
return {args_, kwargs_};
}
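For orientation (not part of the diff), a minimal Python sketch of the four input conventions this validator accepts, written against the `export_function` binding defined below; `fun` and the file name are placeholders:

```python
import mlx.core as mx

def fun(x, y):
    return x + y

x, y = mx.array(1), mx.array(2)

# The four input conventions accepted by validate_and_extract_inputs:
mx.export_function("fn.mlxfn", fun, x, y)            # positional arrays
mx.export_function("fn.mlxfn", fun, x=x, y=y)        # keyword arrays
mx.export_function("fn.mlxfn", fun, (x, y))          # a single tuple
mx.export_function("fn.mlxfn", fun, (x,), {"y": y})  # a tuple and a dict
```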
auto wrap_export_function(const nb::callable& fun) {
return [fun](
const std::vector<mx::array>& args_,
const std::map<std::string, mx::array>& kwargs_) {
auto kwargs = nb::dict();
kwargs.update(nb::cast(kwargs_));
auto args = nb::tuple(nb::cast(args_));
auto outputs = fun(*args, **kwargs);
std::vector<mx::array> outputs_;
if (nb::isinstance<mx::array>(outputs)) {
outputs_.push_back(nb::cast<mx::array>(outputs));
} else if (!nb::try_cast(outputs, outputs_)) {
throw std::invalid_argument(
"[export_function] Outputs can be either a single array "
"a tuple or list of arrays.");
}
return outputs_;
};
}
void init_export(nb::module_& m) {
m.def(
"export_function",
[](const std::string& file,
const nb::callable& fun,
const nb::args& args,
bool shapeless,
const nb::kwargs& kwargs) {
auto [args_, kwargs_] =
validate_and_extract_inputs(args, kwargs, "[export_function]");
mx::export_function(
file, wrap_export_function(fun), args_, kwargs_, shapeless);
},
"file"_a,
"fun"_a,
"args"_a,
nb::kw_only(),
"shapeless"_a = false,
"kwargs"_a,
R"pbdoc(
Export a function to a file.
Example input arrays must be provided to export a function. The example
inputs can be variable ``*args`` and ``**kwargs`` or a tuple of arrays
and/or dictionary of string keys with array values.
.. warning::
This is part of an experimental API which is likely to
change in future versions of MLX. Functions exported with older
versions of MLX may not be compatible with future versions.
Args:
file (str): File path to export the function to.
fun (Callable): A function which takes as input zero or more
:class:`array` and returns one or more :class:`array`.
*args (array): Example array inputs to the function.
shapeless (bool, optional): Whether or not the function allows
inputs with variable shapes. Default: ``False``.
**kwargs (array): Additional example keyword array inputs to the
function.
Example:
.. code-block:: python
def fun(x, y):
return x + y
x = mx.array(1)
y = mx.array([1, 2, 3])
mx.export_function("fun.mlxfn", fun, x, y=y)
)pbdoc");
m.def(
"import_function",
[](const std::string& file) {
return nb::cpp_function(
[fn = mx::import_function(file)](
const nb::args& args, const nb::kwargs& kwargs) {
auto [args_, kwargs_] = validate_and_extract_inputs(
args, kwargs, "[import_function::call]");
return nb::tuple(nb::cast(fn(args_, kwargs_)));
});
},
"file"_a,
nb::sig("def import_function(file: str) -> Callable"),
R"pbdoc(
Import a function from a file.
The imported function can be called either with ``*args`` and
``**kwargs`` or with a tuple of arrays and/or dictionary of string
keys with array values. Imported functions always return a tuple of
arrays.
.. warning::
This is part of an experimental API which is likely to
change in future versions of MLX. Functions exported with older
versions of MLX may not be compatible with future versions.
Args:
file (str): The file path to import the function from.
Returns:
Callable: The imported function.
Example:
>>> fn = mx.import_function("function.mlxfn")
>>> out = fn(a, b, x=x, y=y)[0]
>>>
>>> out = fn((a, b), {"x": x, "y": y}[0]
)pbdoc");
nb::class_<mx::FunctionExporter>(
m,
"FunctionExporter",
R"pbdoc(
A context manager for exporting multiple traces of the same
function to a file.
Make an instance of this class by calling :func:`mx.exporter`.
)pbdoc")
.def("close", &mx::FunctionExporter::close)
.def(
"__enter__", [](mx::FunctionExporter& exporter) { return &exporter; })
.def(
"__exit__",
[](mx::FunctionExporter& exporter,
const std::optional<nb::object>&,
const std::optional<nb::object>&,
const std::optional<nb::object>&) { exporter.close(); },
"exc_type"_a = nb::none(),
"exc_value"_a = nb::none(),
"traceback"_a = nb::none())
.def(
"__call__",
[](mx::FunctionExporter& exporter,
const nb::args& args,
const nb::kwargs& kwargs) {
auto [args_, kwargs_] =
validate_and_extract_inputs(args, kwargs, "[export_function]");
exporter(args_, kwargs_);
});
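For illustration, a sketch of driving the exporter object directly instead of via ``with``, mirroring the ``close`` and ``__call__`` bindings above (the file name is a placeholder):

```python
import mlx.core as mx

def fun(*args):
    return sum(args)

exp = mx.exporter("fn.mlxfn", fun)
exp(mx.array(1.0))                 # record a trace with one input
exp(mx.array(1.0), mx.array(2.0))  # record a trace with two inputs
exp.close()                        # finalize the file, as __exit__ would
```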
m.def(
"exporter",
[](const std::string& file, const nb::callable& fun, bool shapeless) {
return mx::exporter(file, wrap_export_function(fun), shapeless);
},
"file"_a,
"fun"_a,
nb::kw_only(),
"shapeless"_a = false,
R"pbdoc(
Make a callable object to export multiple traces of a function to a file.
.. warning::
This is part of an experimental API which is likely to
change in future versions of MLX. Functions exported with older
versions of MLX may not be compatible with future versions.
Args:
file (str): File path to export the function to.
fun (Callable): A function which takes as input zero or more
:class:`array` and returns one or more :class:`array`.
shapeless (bool, optional): Whether or not the function allows
inputs with variable shapes. Default: ``False``.
Example:
.. code-block:: python
def fun(*args):
return sum(args)
with mx.exporter("fun.mlxfn", fun) as exporter:
exporter(mx.array(1))
exporter(mx.array(1), mx.array(2))
exporter(mx.array(1), mx.array(2), mx.array(3))
)pbdoc");
m.def(
"export_to_dot",
[](nb::object file, const nb::args& args) {
std::vector<mx::array> arrays = tree_flatten(args);
if (nb::isinstance<nb::str>(file)) {
std::ofstream out(nb::cast<std::string>(file));
mx::export_to_dot(out, arrays);
} else if (nb::hasattr(file, "write")) {
std::ostringstream out;
mx::export_to_dot(out, arrays);
auto write = file.attr("write");
write(out.str());
} else {
throw std::invalid_argument(
"[export_to_dot] Accepts file-like objects or strings "
"to be used as filenames.");
}
},
"file"_a,
"args"_a,
R"pbdoc(
Export a graph to DOT format for visualization.
A variable number of output arrays can be provided for export. The
exported graph will recursively include all unevaluated inputs of
the provided outputs.
Args:
file (str): The file path to export to. A file-like object with a
``write`` method is also accepted.
*args (array): The output arrays.
Example:
>>> a = mx.array(1) + mx.array(2)
>>> mx.export_to_dot("graph.dot", a)
)pbdoc");
}
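A small sketch of the file-like path through the ``export_to_dot`` binding above (any object with a ``write`` method is accepted), using an in-memory buffer:

```python
import io

import mlx.core as mx

a = mx.array(1) + mx.array(2)

buf = io.StringIO()
mx.export_to_dot(buf, a)  # writes the DOT text via buf.write(...)
print(buf.getvalue())
```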

View File

@ -19,6 +19,7 @@ void init_linalg(nb::module_&);
void init_constants(nb::module_&);
void init_fast(nb::module_&);
void init_distributed(nb::module_&);
void init_export(nb::module_&);
NB_MODULE(core, m) {
m.doc() = "mlx: A framework for machine learning on Apple silicon.";
@ -39,6 +40,7 @@ NB_MODULE(core, m) {
init_constants(m);
init_fast(m);
init_distributed(m);
init_export(m);
m.attr("__version__") = TOSTRING(_VERSION_);
}

View File

@ -2898,7 +2898,7 @@ void init_ops(nb::module_& m) {
Generate multidimensional coordinate grids from 1-D coordinate arrays
Args:
arrays (array): Input arrays.
*arrays (array): Input arrays.
sparse (bool, optional): If ``True``, a sparse grid is returned in which each output
array has a single non-zero element. If ``False``, a dense grid is returned.
Defaults to ``False``.
@ -3840,8 +3840,8 @@ void init_ops(nb::module_& m) {
Args:
file (file, str): Path to file to which the arrays are saved.
args (arrays): Arrays to be saved.
kwargs (arrays): Arrays to be saved. Each array will be saved
*args (arrays): Arrays to be saved.
**kwargs (arrays): Arrays to be saved. Each array will be saved
with the associated keyword as the output file name.
)pbdoc");
m.def(

View File

@ -8,14 +8,12 @@
#include <nanobind/stl/vector.h>
#include <algorithm>
#include <fstream>
#include <numeric>
#include <sstream>
#include "mlx/array.h"
#include "mlx/compile.h"
#include "mlx/compile_impl.h"
#include "mlx/graph_utils.h"
#include "mlx/transforms.h"
#include "mlx/transforms_impl.h"
#include "mlx/utils.h"
@ -945,34 +943,34 @@ void init_transforms(nb::module_& m) {
Note, all custom transformations are optional. Undefined transformations
fall back to the default behaviour.
Example:
.. code-block:: python
  import mlx.core as mx

  @mx.custom_function
  def f(x, y):
      return mx.sin(x) * y

  @f.vjp
  def f_vjp(primals, cotangent, output):
      x, y = primals
      return cotangent * mx.cos(x) * y, cotangent * mx.sin(x)

  @f.jvp
  def f_jvp(primals, tangents):
      x, y = primals
      dx, dy = tangents
      return dx * mx.cos(x) * y + dy * mx.sin(x)

  @f.vmap
  def f_vmap(inputs, axes):
      x, y = inputs
      ax, ay = axes
      if ay != ax and ax is not None:
          y = y.swapaxes(ay, ax)
      return mx.sin(x) * y, (ax or ay)
)pbdoc")
.def(
nb::init<nb::callable>(),
@ -1313,25 +1311,6 @@ void init_transforms(nb::module_& m) {
Returns:
Callable: The vectorized function.
)pbdoc");
m.def(
"export_to_dot",
[](nb::object file, const nb::args& args) {
std::vector<mx::array> arrays = tree_flatten(args);
if (nb::isinstance<nb::str>(file)) {
std::ofstream out(nb::cast<std::string>(file));
export_to_dot(out, arrays);
} else if (nb::hasattr(file, "write")) {
std::ostringstream out;
export_to_dot(out, arrays);
auto write = file.attr("write");
write(out.str());
} else {
throw std::invalid_argument(
"export_to_dot accepts file-like objects or strings to be used as filenames");
}
},
"file"_a,
"args"_a);
m.def(
"compile",
[](const nb::callable& fun,

View File

@ -0,0 +1,244 @@
# Copyright © 2024 Apple Inc.
import os
import tempfile
import unittest
import mlx.core as mx
import mlx_tests
class TestExportImport(mlx_tests.MLXTestCase):
@classmethod
def setUpClass(cls):
cls.test_dir_fid = tempfile.TemporaryDirectory()
cls.test_dir = cls.test_dir_fid.name
if not os.path.isdir(cls.test_dir):
os.mkdir(cls.test_dir)
@classmethod
def tearDownClass(cls):
cls.test_dir_fid.cleanup()
def test_basic_export_import(self):
path = os.path.join(self.test_dir, "fn.mlxfn")
# Function with no inputs
def fun():
return mx.zeros((3, 3))
mx.export_function(path, fun)
imported = mx.import_function(path)
expected = fun()
(out,) = imported()
self.assertTrue(mx.array_equal(out, expected))
# Simple function with inputs
def fun(x):
return mx.abs(mx.sin(x))
inputs = mx.array([1.0, 2.0, 3.0, 4.0, 5.0])
mx.export_function(path, fun, inputs)
imported = mx.import_function(path)
expected = fun(inputs)
(out,) = imported(inputs)
self.assertTrue(mx.allclose(out, expected))
# Inputs in a list or tuple
def fun(x):
x = mx.abs(mx.sin(x))
return x
mx.export_function(path, fun, [inputs])
imported = mx.import_function(path)
expected = fun(inputs)
(out,) = imported([inputs])
self.assertTrue(mx.allclose(out, expected))
(out,) = imported(inputs)
self.assertTrue(mx.allclose(out, expected))
mx.export_function(path, fun, (inputs,))
imported = mx.import_function(path)
(out,) = imported((inputs,))
self.assertTrue(mx.allclose(out, expected))
# Outputs in a list
def fun(x):
return [mx.abs(mx.sin(x))]
mx.export_function(path, fun, inputs)
imported = mx.import_function(path)
(out,) = imported(inputs)
self.assertTrue(mx.allclose(out, expected))
# Outputs in a tuple
def fun(x):
return (mx.abs(mx.sin(x)),)
mx.export_function(path, fun, inputs)
imported = mx.import_function(path)
(out,) = imported(inputs)
self.assertTrue(mx.allclose(out, expected))
# Check throws on invalid inputs / outputs
def fun(x):
return mx.abs(x)
with self.assertRaises(ValueError):
mx.export_function(path, fun, "hi")
with self.assertRaises(ValueError):
mx.export_function(path, fun, mx.array(1.0), "hi")
def fun(x):
return mx.abs(x[0][0])
with self.assertRaises(ValueError):
mx.export_function(path, fun, [[mx.array(1.0)]])
def fun():
return (mx.zeros((3, 3)), 1)
with self.assertRaises(ValueError):
mx.export_function(path, fun)
def fun():
return (mx.zeros((3, 3)), [mx.zeros((3, 3))])
with self.assertRaises(ValueError):
mx.export_function(path, fun)
def fun(x, y):
return x + y
mx.export_function(path, fun, mx.array(1.0), mx.array(1.0))
imported = mx.import_function(path)
with self.assertRaises(ValueError):
imported(mx.array(1.0), 1.0)
with self.assertRaises(ValueError):
imported(mx.array(1.0), mx.array(1.0), mx.array(1.0))
with self.assertRaises(ValueError):
imported(mx.array(1.0), [mx.array(1.0)])
def test_export_random_sample(self):
path = os.path.join(self.test_dir, "fn.mlxfn")
mx.random.seed(5)
def fun():
return mx.random.uniform(shape=(3,))
mx.export_function(path, fun)
imported = mx.import_function(path)
(out,) = imported()
mx.random.seed(5)
expected = fun()
self.assertTrue(mx.array_equal(out, expected))
def test_export_with_kwargs(self):
path = os.path.join(self.test_dir, "fn.mlxfn")
def fun(x, z=None):
out = x
if z is not None:
out += z
return out
x = mx.array([1, 2, 3])
y = mx.array([1, 1, 0])
z = mx.array([2, 2, 2])
mx.export_function(path, fun, (x,), {"z": z})
imported_fun = mx.import_function(path)
with self.assertRaises(ValueError):
imported_fun(x, z)
with self.assertRaises(ValueError):
imported_fun(x, y=z)
with self.assertRaises(ValueError):
imported_fun((x,), {"y": z})
out = imported_fun(x, z=z)[0]
self.assertTrue(mx.array_equal(out, mx.array([3, 4, 5])))
out = imported_fun((x,), {"z": z})[0]
self.assertTrue(mx.array_equal(out, mx.array([3, 4, 5])))
mx.export_function(path, fun, x, z=z)
imported_fun = mx.import_function(path)
out = imported_fun(x, z=z)[0]
self.assertTrue(mx.array_equal(out, mx.array([3, 4, 5])))
out = imported_fun((x,), {"z": z})[0]
self.assertTrue(mx.array_equal(out, mx.array([3, 4, 5])))
# Only specify kwargs
mx.export_function(path, fun, x=x, z=z)
imported_fun = mx.import_function(path)
with self.assertRaises(ValueError):
out = imported_fun(x, z=z)[0]
out = imported_fun(x=x, z=z)[0]
self.assertTrue(mx.array_equal(out, mx.array([3, 4, 5])))
out = imported_fun({"x": x, "z": z})[0]
self.assertTrue(mx.array_equal(out, mx.array([3, 4, 5])))
def test_export_variable_inputs(self):
path = os.path.join(self.test_dir, "fn.mlxfn")
def fun(x, y, z=None):
out = x + y
if z is not None:
out += z
return out
with mx.exporter(path, fun) as exporter:
exporter(mx.array([1, 2, 3]), mx.array([1, 1, 1]))
exporter(mx.array([1, 2, 3]), mx.array([1, 1, 1]), z=mx.array([2]))
with self.assertRaises(RuntimeError):
exporter(mx.array([1, 2, 3, 4]), mx.array([1, 1, 1, 1]))
imported_fun = mx.import_function(path)
out = imported_fun(mx.array([1, 2, 3]), mx.array([1, 1, 1]))[0]
self.assertTrue(mx.array_equal(out, mx.array([2, 3, 4])))
out = imported_fun(mx.array([1, 2, 3]), mx.array([1, 1, 1]), z=mx.array([2]))[0]
self.assertTrue(mx.array_equal(out, mx.array([4, 5, 6])))
with self.assertRaises(ValueError):
imported_fun(mx.array([1, 2, 3, 4]), mx.array([1, 1, 1, 1]))
# A function with a large constant
constant = mx.zeros((16, 2048))
mx.eval(constant)
def fun(*args):
return constant + sum(args)
with mx.exporter(path, fun) as exporter:
for i in range(5):
exporter(*[mx.array(1)] * i)
# Check the exported file size < constant size + small amount
constants_size = constant.nbytes + 8192
self.assertTrue(os.path.getsize(path) < constants_size)
if __name__ == "__main__":
unittest.main()

View File

@ -28,15 +28,14 @@ class TestLoad(mlx_tests.MLXTestCase):
def setUpClass(cls):
cls.test_dir_fid = tempfile.TemporaryDirectory()
cls.test_dir = cls.test_dir_fid.name
if not os.path.isdir(cls.test_dir):
os.mkdir(cls.test_dir)
@classmethod
def tearDownClass(cls):
cls.test_dir_fid.cleanup()
def test_save_and_load(self):
if not os.path.isdir(self.test_dir):
os.mkdir(self.test_dir)
for dt in self.dtypes:
with self.subTest(dtype=dt):
for i, shape in enumerate([(1,), (23,), (1024, 1024), (4, 6, 3, 1, 2)]):
@ -64,9 +63,6 @@ class TestLoad(mlx_tests.MLXTestCase):
self.assertTrue(np.array_equal(load_arr_mlx_npy, save_arr_npy))
def test_save_and_load_safetensors(self):
if not os.path.isdir(self.test_dir):
os.mkdir(self.test_dir)
test_file = os.path.join(self.test_dir, "test.safetensors")
with self.assertRaises(Exception):
mx.save_safetensors(test_file, {"a": mx.ones((4, 4))}, {"testing": 0})
@ -330,9 +326,6 @@ class TestLoad(mlx_tests.MLXTestCase):
self.assertTrue(np.array_equal(save_arrs_npy[k], v))
def test_non_contiguous(self):
if not os.path.isdir(self.test_dir):
os.mkdir(self.test_dir)
a = mx.broadcast_to(mx.array([1, 2]), [4, 2])
save_file = os.path.join(self.test_dir, "a.npy")

View File

@ -24,6 +24,7 @@ target_sources(
creations_tests.cpp
device_tests.cpp
einsum_tests.cpp
export_import_tests.cpp
eval_tests.cpp
fft_tests.cpp
load_tests.cpp

View File

@ -0,0 +1,165 @@
// Copyright © 2024 Apple Inc.
#include <filesystem>
#include <stdexcept>
#include <vector>
#include "doctest/doctest.h"
#include "mlx/export.h"
#include "mlx/mlx.h"
using namespace mlx::core;
namespace {
std::string get_temp_file(const std::string& name) {
return std::filesystem::temp_directory_path().append(name);
}
} // namespace
TEST_CASE("test export basic functions") {
std::string file_path = get_temp_file("model.mlxfn");
auto fun = [](std::vector<array> x) -> std::vector<array> {
return {negative(exp(x[0]))};
};
export_function(file_path, fun, {array({1.0, 2.0})});
auto imported_fun = import_function(file_path);
// Check num inputs mismatch throws
CHECK_THROWS_AS(
imported_fun({array({1.0}), array({2.0})}), std::invalid_argument);
// Check shape mismatch throws
CHECK_THROWS_AS(imported_fun({array({1.0})}), std::invalid_argument);
// Check type mismatch throws
CHECK_THROWS_AS(imported_fun({array({1.0}, float16)}), std::invalid_argument);
auto expected = fun({array({1.0, -1.0})});
auto out = imported_fun({array({1.0, -1.0})});
CHECK(allclose(expected[0], out[0]).item<bool>());
}
TEST_CASE("test export function with no inputs") {
auto fun = [](std::vector<array> x) -> std::vector<array> {
return {zeros({2, 2})};
};
std::string file_path = get_temp_file("model.mlxfn");
export_function(file_path, fun, {});
auto imported_fun = import_function(file_path);
auto expected = fun({});
auto out = imported_fun({});
CHECK(allclose(expected[0], out[0]).item<bool>());
}
TEST_CASE("test export multi output primitives") {
std::string file_path = get_temp_file("model.mlxfn");
auto fun = [](std::vector<array> x) -> std::vector<array> {
return {divmod(x[0], x[1])};
};
auto inputs = std::vector<array>{array({5.0, -10.0}), array({3.0, -2.0})};
export_function(file_path, fun, inputs);
auto imported_fun = import_function(file_path);
auto expected = fun(inputs);
auto out = imported_fun(inputs);
CHECK(allclose(expected[0], out[0]).item<bool>());
CHECK(allclose(expected[1], out[1]).item<bool>());
}
TEST_CASE("test export primitives with state") {
std::string file_path = get_temp_file("model.mlxfn");
auto fun = [](std::vector<array> x) -> std::vector<array> {
return {argpartition(x[0], 2, 0)};
};
auto x = array({1, 3, 2, 4, 5, 7, 6, 8}, {4, 2});
export_function(file_path, fun, {x});
auto imported_fun = import_function(file_path);
auto expected = fun({x});
auto out = imported_fun({x});
CHECK(allclose(expected[0], out[0]).item<bool>());
}
TEST_CASE("test export functions with kwargs") {
std::string file_path = get_temp_file("model.mlxfn");
auto fun =
[](const std::map<std::string, array>& kwargs) -> std::vector<array> {
return {kwargs.at("x") + kwargs.at("y")};
};
export_function(file_path, fun, {{"x", array(1)}, {"y", array(2)}});
auto fn = import_function(file_path);
// Must use kwargs
CHECK_THROWS(fn({array(1), array(2)}));
// Wrong number of keys
CHECK_THROWS(fn({{"x", array(1)}, {"y", array(2)}, {"z", array(3)}}));
// Wrong keys
CHECK_THROWS(fn({{"a", array(1)}, {"b", array(2)}}));
// Works
auto out = fn({{"x", array(1)}, {"y", array(2)}})[0];
CHECK_EQ(out.item<int>(), 3);
out = fn({}, {{"x", array(1)}, {"y", array(2)}})[0];
CHECK_EQ(out.item<int>(), 3);
}
TEST_CASE("test export function with variable inputs") {
std::string file_path = get_temp_file("model.mlxfn");
auto fun = [](const std::vector<array>& args) -> std::vector<array> {
auto out = array({1, 1, 1, 1});
for (auto x : args) {
out = out + x;
}
return {out};
};
{
auto fn_exporter = exporter(file_path, fun);
fn_exporter({array(0), array(0)});
fn_exporter({array(0), array(0), array(0)});
}
auto imported_fun = import_function(file_path);
// Call with two inputs
auto out = imported_fun({array(1), array(2)})[0];
CHECK(array_equal(out, array({4, 4, 4, 4})).item<bool>());
// Call with three inputs
out = imported_fun({array(1), array(2), array(3)})[0];
CHECK(array_equal(out, array({7, 7, 7, 7})).item<bool>());
}
TEST_CASE("test export function on different stream") {
std::string file_path = get_temp_file("model.mlxfn");
// Caller is responsible for setting up streams before
// importing functions
auto fun = [](const std::vector<array>& args) -> std::vector<array> {
return {abs(args[0], Stream(1000, Device::cpu))};
};
export_function(file_path, fun, {array({0, 1, 2})});
CHECK_THROWS(import_function(file_path));
}
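To make the stream test above concrete, a hedged Python sketch of the same constraint (assumes `mx.new_stream` and the stream-index lookup added to the scheduler in this diff; the exact behavior depends on stream creation order in the importing process):

```python
import mlx.core as mx

s = mx.new_stream(mx.cpu)

def fun(x):
    # Ops recorded on a non-default stream are exported with that
    # stream's index, which must exist at import time.
    return mx.abs(x, stream=s)

mx.export_function("fn.mlxfn", fun, mx.array([0, 1, 2]))

# In a fresh process, create streams in the same order before calling
# mx.import_function; otherwise the import raises an error.
```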