mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 01:17:26 +08:00
Export / import functions to / from a file (#1642)
* export and import functions * refactor + works for few primitives * nit * allow primitives with state * nit * nit * simplify serialize / deserialize * fix for constants * python bindings * maybe fix serialize failure case * add example * more primitives, training kind of works * same result for python and c++ * some fixes * fix export * template it up * some simplificatoin * rebase * allow kwargs and multiple functions * exporter * more primitives for exporting * deal with endianness * handle invalid stream * add docstring
This commit is contained in:
parent
935c8c4bb1
commit
4ba0c24a8f
@ -27,6 +27,7 @@ option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
|
||||
if(NOT MLX_VERSION)
|
||||
set(MLX_VERSION 0.21.1)
|
||||
endif()
|
||||
add_compile_definitions("MLX_VERSION=${MLX_VERSION}")
|
||||
|
||||
# --------------------- Processor tests -------------------------
|
||||
|
||||
|
@ -61,6 +61,7 @@ are the CPU and GPU.
|
||||
python/array
|
||||
python/data_types
|
||||
python/devices_and_streams
|
||||
python/export
|
||||
python/ops
|
||||
python/random
|
||||
python/transforms
|
||||
|
14
docs/src/python/export.rst
Normal file
14
docs/src/python/export.rst
Normal file
@ -0,0 +1,14 @@
|
||||
.. _export:
|
||||
|
||||
Export Functions
|
||||
================
|
||||
|
||||
.. currentmodule:: mlx.core
|
||||
|
||||
.. autosummary::
|
||||
:toctree: _autosummary
|
||||
|
||||
export_function
|
||||
import_function
|
||||
exporter
|
||||
export_to_dot
|
27
examples/export/CMakeLists.txt
Normal file
27
examples/export/CMakeLists.txt
Normal file
@ -0,0 +1,27 @@
|
||||
cmake_minimum_required(VERSION 3.27)
|
||||
|
||||
project(import_mlx LANGUAGES CXX)
|
||||
|
||||
# ----------------------------- Setup -----------------------------
|
||||
set(CMAKE_CXX_STANDARD 17)
|
||||
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
||||
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
|
||||
|
||||
# ----------------------------- Dependencies -----------------------------
|
||||
find_package(
|
||||
Python 3.9
|
||||
COMPONENTS Interpreter Development.Module
|
||||
REQUIRED)
|
||||
execute_process(
|
||||
COMMAND "${Python_EXECUTABLE}" -m pip show mlx
|
||||
COMMAND grep location
|
||||
COMMAND awk "{print $4 \"/mlx\"}"
|
||||
OUTPUT_STRIP_TRAILING_WHITESPACE
|
||||
OUTPUT_VARIABLE MLX_ROOT)
|
||||
find_package(MLX CONFIG REQUIRED)
|
||||
|
||||
add_executable(eval_mlp eval_mlp.cpp)
|
||||
target_link_libraries(eval_mlp PRIVATE mlx)
|
||||
|
||||
add_executable(train_mlp train_mlp.cpp)
|
||||
target_link_libraries(train_mlp PRIVATE mlx)
|
49
examples/export/README.md
Normal file
49
examples/export/README.md
Normal file
@ -0,0 +1,49 @@
|
||||
## Setup
|
||||
|
||||
Install MLX:
|
||||
|
||||
```bash
|
||||
pip install mlx>=0.22
|
||||
```
|
||||
|
||||
Build the C++ examples:
|
||||
|
||||
```bash
|
||||
cmake -B build -DCMAKE_BUILD_TYPE=Release
|
||||
cmake --build build
|
||||
```
|
||||
|
||||
## Run
|
||||
|
||||
### Eval MLP
|
||||
|
||||
Run the Python script to export the eval function:
|
||||
|
||||
```bash
|
||||
python eval_mlp.py
|
||||
```
|
||||
|
||||
Then run the C++ program to import and run the function:
|
||||
|
||||
```
|
||||
./build/eval_mlp
|
||||
```
|
||||
|
||||
The Python and C++ programs should output the same result.
|
||||
|
||||
### Train MLP
|
||||
|
||||
Run the Python script to export the model initialization and training
|
||||
functions:
|
||||
|
||||
```bash
|
||||
python train_mlp.py
|
||||
```
|
||||
|
||||
Then run the C++ program to import and run the functions:
|
||||
|
||||
```
|
||||
./build/train_mlp
|
||||
```
|
||||
|
||||
The Python and C++ programs should output the same results.
|
25
examples/export/eval_mlp.cpp
Normal file
25
examples/export/eval_mlp.cpp
Normal file
@ -0,0 +1,25 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include <mlx/mlx.h>
|
||||
#include <iostream>
|
||||
|
||||
using namespace mlx::core;
|
||||
|
||||
int main() {
|
||||
int batch_size = 8;
|
||||
int input_dim = 32;
|
||||
|
||||
// Make the input
|
||||
random::seed(42);
|
||||
auto example_x = random::uniform({batch_size, input_dim});
|
||||
|
||||
// Import the function
|
||||
auto forward = import_function("eval_mlp.mlxfn");
|
||||
|
||||
// Call the imported function
|
||||
auto out = forward({example_x})[0];
|
||||
|
||||
std::cout << out << std::endl;
|
||||
|
||||
return 0;
|
||||
}
|
52
examples/export/eval_mlp.py
Normal file
52
examples/export/eval_mlp.py
Normal file
@ -0,0 +1,52 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import mlx.utils
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
"""A simple MLP."""
|
||||
|
||||
def __init__(
|
||||
self, num_layers: int, input_dim: int, hidden_dim: int, output_dim: int
|
||||
):
|
||||
super().__init__()
|
||||
layer_sizes = [input_dim] + [hidden_dim] * num_layers + [output_dim]
|
||||
self.layers = [
|
||||
nn.Linear(idim, odim)
|
||||
for idim, odim in zip(layer_sizes[:-1], layer_sizes[1:])
|
||||
]
|
||||
|
||||
def __call__(self, x):
|
||||
for l in self.layers[:-1]:
|
||||
x = nn.relu(l(x))
|
||||
return self.layers[-1](x)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
batch_size = 8
|
||||
input_dim = 32
|
||||
output_dim = 10
|
||||
|
||||
# Load the model
|
||||
mx.random.seed(0) # Seed for params
|
||||
model = MLP(num_layers=5, input_dim=input_dim, hidden_dim=64, output_dim=output_dim)
|
||||
mx.eval(model)
|
||||
|
||||
# Note, the model parameters are saved in the export function
|
||||
def forward(x):
|
||||
return model(x)
|
||||
|
||||
mx.random.seed(42) # Seed for input
|
||||
example_x = mx.random.uniform(shape=(batch_size, input_dim))
|
||||
|
||||
mx.export_function("eval_mlp.mlxfn", forward, example_x)
|
||||
|
||||
# Import in Python
|
||||
imported_forward = mx.import_function("eval_mlp.mlxfn")
|
||||
expected = forward(example_x)
|
||||
(out,) = imported_forward(example_x)
|
||||
assert mx.allclose(expected, out)
|
||||
print(out)
|
35
examples/export/train_mlp.cpp
Normal file
35
examples/export/train_mlp.cpp
Normal file
@ -0,0 +1,35 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include <mlx/mlx.h>
|
||||
#include <iostream>
|
||||
|
||||
using namespace mlx::core;
|
||||
|
||||
int main() {
|
||||
int batch_size = 8;
|
||||
int input_dim = 32;
|
||||
int output_dim = 10;
|
||||
|
||||
auto state = import_function("init_mlp.mlxfn")({});
|
||||
|
||||
// Make the input
|
||||
random::seed(42);
|
||||
auto example_X = random::normal({batch_size, input_dim});
|
||||
auto example_y = random::randint(0, output_dim, {batch_size});
|
||||
|
||||
// Import the function
|
||||
auto step = import_function("train_mlp.mlxfn");
|
||||
|
||||
// Call the imported function
|
||||
for (int it = 0; it < 100; ++it) {
|
||||
state.insert(state.end(), {example_X, example_y});
|
||||
state = step(state);
|
||||
eval(state);
|
||||
auto loss = state.back();
|
||||
state.pop_back();
|
||||
if (it % 10 == 0) {
|
||||
std::cout << "Loss " << loss.item<float>() << std::endl;
|
||||
}
|
||||
}
|
||||
return 0;
|
||||
}
|
76
examples/export/train_mlp.py
Normal file
76
examples/export/train_mlp.py
Normal file
@ -0,0 +1,76 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import mlx.optimizers as optim
|
||||
import mlx.utils
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
"""A simple MLP."""
|
||||
|
||||
def __init__(
|
||||
self, num_layers: int, input_dim: int, hidden_dim: int, output_dim: int
|
||||
):
|
||||
super().__init__()
|
||||
layer_sizes = [input_dim] + [hidden_dim] * num_layers + [output_dim]
|
||||
self.layers = [
|
||||
nn.Linear(idim, odim)
|
||||
for idim, odim in zip(layer_sizes[:-1], layer_sizes[1:])
|
||||
]
|
||||
|
||||
def __call__(self, x):
|
||||
for l in self.layers[:-1]:
|
||||
x = nn.relu(l(x))
|
||||
return self.layers[-1](x)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
batch_size = 8
|
||||
input_dim = 32
|
||||
output_dim = 10
|
||||
|
||||
def init():
|
||||
# Seed for the parameter initialization
|
||||
mx.random.seed(0)
|
||||
model = MLP(
|
||||
num_layers=3, input_dim=input_dim, hidden_dim=64, output_dim=output_dim
|
||||
)
|
||||
optimizer = optim.SGD(learning_rate=1e-1)
|
||||
optimizer.init(model.parameters())
|
||||
state = [model.parameters(), optimizer.state]
|
||||
tree_structure, state = zip(*mlx.utils.tree_flatten(state))
|
||||
return model, optimizer, tree_structure, state
|
||||
|
||||
# Export the model parameter initialization
|
||||
model, optimizer, tree_structure, state = init()
|
||||
mx.eval(state)
|
||||
mx.export_function("init_mlp.mlxfn", lambda: init()[-1])
|
||||
|
||||
def loss_fn(params, X, y):
|
||||
model.update(params)
|
||||
return nn.losses.cross_entropy(model(X), y, reduction="mean")
|
||||
|
||||
def step(*inputs):
|
||||
*state, X, y = inputs
|
||||
params, opt_state = mlx.utils.tree_unflatten(list(zip(tree_structure, state)))
|
||||
optimizer.state = opt_state
|
||||
loss, grads = mx.value_and_grad(loss_fn)(params, X, y)
|
||||
params = optimizer.apply_gradients(grads, params)
|
||||
_, state = zip(*mlx.utils.tree_flatten([params, optimizer.state]))
|
||||
return *state, loss
|
||||
|
||||
# Make some random data
|
||||
mx.random.seed(42)
|
||||
example_X = mx.random.normal(shape=(batch_size, input_dim))
|
||||
example_y = mx.random.randint(low=0, high=output_dim, shape=(batch_size,))
|
||||
mx.export_function("train_mlp.mlxfn", step, *state, example_X, example_y)
|
||||
|
||||
# Export one step of SGD
|
||||
imported_step = mx.import_function("train_mlp.mlxfn")
|
||||
|
||||
for it in range(100):
|
||||
*state, loss = imported_step(*state, example_X, example_y)
|
||||
if it % 10 == 0:
|
||||
print(f"Loss {loss.item():.6}")
|
@ -5,6 +5,7 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/export.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/einsum.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
|
||||
|
@ -156,9 +156,6 @@ CompileMode& compile_mode() {
|
||||
return compile_mode_;
|
||||
}
|
||||
|
||||
using ParentsMap =
|
||||
std::unordered_map<std::uintptr_t, std::vector<std::pair<array, int>>>;
|
||||
|
||||
// Helper like below but only merges the two provided arrays. If the src has
|
||||
// siblings then these won't be merged to the dst.
|
||||
void merge_one(array& dst, array& src, ParentsMap& parents_map) {
|
||||
@ -732,10 +729,15 @@ std::vector<array> compile_replace(
|
||||
trace_to_real.insert({trace_inputs[i].id(), inputs[i]});
|
||||
}
|
||||
|
||||
auto is_load = [](const Primitive& p) { return typeid(p) == typeid(Load); };
|
||||
|
||||
for (auto& a : tape) {
|
||||
// Arrays in the tape without primitives are constants
|
||||
// and can be used directly
|
||||
if (!a.has_primitive()) {
|
||||
// Arrays in the tape without primitives are either:
|
||||
// - inputs, which are already in the map
|
||||
// - constants, which can be used directly
|
||||
// - a load primitive which has no inputs and will become a constant
|
||||
// after the first eval
|
||||
if (!a.has_primitive() || is_load(a.primitive())) {
|
||||
trace_to_real.insert({a.id(), a});
|
||||
} else {
|
||||
// Find real inputs
|
||||
|
@ -2,7 +2,7 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/device.h"
|
||||
#include "mlx/array.h"
|
||||
|
||||
namespace mlx::core::detail {
|
||||
|
||||
@ -22,4 +22,35 @@ void compile_erase(std::uintptr_t fun_id);
|
||||
void compile_clear_cache();
|
||||
|
||||
bool compile_available_for_device(const Device& device);
|
||||
|
||||
std::pair<std::vector<array>, std::vector<array>> compile_trace(
|
||||
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
|
||||
const std::vector<array>& inputs);
|
||||
|
||||
using ParentsMap =
|
||||
std::unordered_map<std::uintptr_t, std::vector<std::pair<array, int>>>;
|
||||
|
||||
// Traverses the graph to build a tape and a map of array ids to their parents
|
||||
std::pair<std::vector<array>, ParentsMap> compile_dfs(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<array>& outputs,
|
||||
const std::vector<array>& original_inputs);
|
||||
|
||||
// Simplify the tape. Note, this function modifies in-place both the tape and
|
||||
// the parents map to remove orphaned arrays
|
||||
void compile_simplify(
|
||||
std::vector<array>& tape,
|
||||
ParentsMap& parents_map,
|
||||
const std::vector<array>& outputs,
|
||||
int passes);
|
||||
|
||||
std::vector<array> compile_replace(
|
||||
const std::vector<array>& tape,
|
||||
const std::vector<array>& trace_inputs,
|
||||
const std::vector<array>& trace_outputs,
|
||||
const std::vector<array>& inputs,
|
||||
bool shapeless);
|
||||
|
||||
void compile_validate_shapeless(const std::vector<array>& tape);
|
||||
|
||||
} // namespace mlx::core::detail
|
||||
|
867
mlx/export.cpp
Normal file
867
mlx/export.cpp
Normal file
@ -0,0 +1,867 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
#include "mlx/export.h"
|
||||
#include "mlx/compile_impl.h"
|
||||
#include "mlx/fast_primitives.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
#define STRINGIFY(x) #x
|
||||
#define TOSTRING(x) STRINGIFY(x)
|
||||
|
||||
// clang-format off
|
||||
#define SERIALIZE_PRIMITIVE(primitive, keys...) \
|
||||
{ \
|
||||
#primitive, { \
|
||||
[](Writer& os, const Primitive& p) { \
|
||||
serialize_primitive<primitive>(os, p); \
|
||||
}, \
|
||||
[](Reader& is, Stream s) { \
|
||||
return deserialize_primitive<primitive>(is, s); \
|
||||
}, \
|
||||
{keys} \
|
||||
} \
|
||||
}
|
||||
// clang-format on
|
||||
|
||||
bool is_big_endian() {
|
||||
int num = 1;
|
||||
return *reinterpret_cast<char*>(&num) != 1;
|
||||
}
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
using namespace mlx::core::fast;
|
||||
|
||||
using Reader = io::ParallelFileReader;
|
||||
using Writer = io::FileWriter;
|
||||
|
||||
struct PrimitiveSerializer {
|
||||
using Serializer = std::function<void(Writer&, const Primitive&)>;
|
||||
using Deserializer =
|
||||
std::function<std::shared_ptr<Primitive>(Reader&, Stream s)>;
|
||||
PrimitiveSerializer(
|
||||
Serializer serialize,
|
||||
Deserializer deserialize,
|
||||
std::vector<std::string> keys = {})
|
||||
: serialize(std::move(serialize)),
|
||||
deserialize(std::move(deserialize)),
|
||||
keys(std::move(keys)) {};
|
||||
Serializer serialize;
|
||||
Deserializer deserialize;
|
||||
std::vector<std::string> keys;
|
||||
};
|
||||
|
||||
template <typename, typename = void>
|
||||
constexpr bool is_iterable = false;
|
||||
|
||||
template <typename T>
|
||||
constexpr bool is_iterable<
|
||||
T,
|
||||
std::void_t<
|
||||
decltype(std::declval<T>().begin()),
|
||||
decltype(std::declval<T>().end())>> = true;
|
||||
|
||||
template <template <typename...> class T, typename U>
|
||||
constexpr bool is_specialization_of = false;
|
||||
|
||||
template <template <typename...> class T, typename... Us>
|
||||
constexpr bool is_specialization_of<T, T<Us...>> = true;
|
||||
|
||||
template <typename T>
|
||||
constexpr bool is_pair = is_specialization_of<std::pair, std::decay_t<T>>;
|
||||
|
||||
template <typename T>
|
||||
constexpr bool is_tuple = is_specialization_of<std::tuple, std::decay_t<T>>;
|
||||
|
||||
template <typename>
|
||||
constexpr bool dependent_false = false;
|
||||
|
||||
template <typename T>
|
||||
struct NotSerializable {
|
||||
static_assert(dependent_false<T>, "Type is not serializable.");
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct NotDeserializable {
|
||||
static_assert(dependent_false<T>, "Type is not deserializable.");
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
void reverse_bytes(T& data) {
|
||||
auto* bytes = reinterpret_cast<uint8_t*>(&data);
|
||||
for (size_t j = 0; j < (sizeof(T) / 2); j++) {
|
||||
std::swap(bytes[j], bytes[sizeof(T) - j - 1]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void serialize(Writer& os, T v) {
|
||||
if constexpr (std::is_arithmetic_v<T>) {
|
||||
if (is_big_endian()) {
|
||||
reverse_bytes(v);
|
||||
}
|
||||
os.write(reinterpret_cast<const char*>(&v), sizeof(T));
|
||||
} else if constexpr (std::is_enum_v<T>) {
|
||||
serialize(os, static_cast<int>(v));
|
||||
} else if constexpr (std::is_same_v<T, std::nullptr_t>) {
|
||||
} else if constexpr (is_iterable<T>) {
|
||||
serialize(os, static_cast<uint64_t>(v.size()));
|
||||
for (const auto& t : v) {
|
||||
serialize(os, t);
|
||||
}
|
||||
} else if constexpr (is_pair<T> || is_tuple<T>) {
|
||||
std::apply([&os](auto&... x) { (..., serialize(os, x)); }, v);
|
||||
} else {
|
||||
NotSerializable<T>();
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, std::size_t... I>
|
||||
decltype(auto) deserialize_tuple(Reader& is, std::index_sequence<I...>);
|
||||
|
||||
template <typename T>
|
||||
T deserialize(Reader& is) {
|
||||
if constexpr (std::is_arithmetic_v<T>) {
|
||||
T v;
|
||||
is.read(reinterpret_cast<char*>(&v), sizeof(T));
|
||||
if (is_big_endian()) {
|
||||
reverse_bytes(v);
|
||||
}
|
||||
return v;
|
||||
} else if constexpr (std::is_enum_v<T>) {
|
||||
return static_cast<T>(deserialize<int>(is));
|
||||
} else if constexpr (std::is_same_v<T, std::nullptr_t>) {
|
||||
return nullptr;
|
||||
} else if constexpr (is_iterable<T>) {
|
||||
T v;
|
||||
auto size = deserialize<uint64_t>(is);
|
||||
v.reserve(size);
|
||||
for (int i = 0; i < size; ++i) {
|
||||
v.push_back(deserialize<typename T::value_type>(is));
|
||||
}
|
||||
return v;
|
||||
} else if constexpr (is_pair<T> || is_tuple<T>) {
|
||||
return deserialize_tuple<T>(
|
||||
is, std::make_index_sequence<std::tuple_size_v<std::decay_t<T>>>{});
|
||||
} else {
|
||||
NotDeserializable<T>();
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, std::size_t... I>
|
||||
decltype(auto) deserialize_tuple(Reader& is, std::index_sequence<I...>) {
|
||||
return T{deserialize<std::tuple_element_t<I, T>>(is)...};
|
||||
};
|
||||
|
||||
void serialize(Writer& os, const Stream& s) {
|
||||
serialize(os, s.index);
|
||||
serialize(os, s.device.type);
|
||||
serialize(os, s.device.index);
|
||||
}
|
||||
template <>
|
||||
Stream deserialize(Reader& is) {
|
||||
auto stream_index = deserialize<int>(is);
|
||||
auto device_type = deserialize<Device::DeviceType>(is);
|
||||
auto device_index = deserialize<int>(is);
|
||||
return Stream(stream_index, Device(device_type, device_index));
|
||||
}
|
||||
|
||||
void serialize(Writer& os, const Dtype& t) {
|
||||
serialize(os, t.val());
|
||||
serialize(os, t.size());
|
||||
}
|
||||
|
||||
template <>
|
||||
Dtype deserialize(Reader& is) {
|
||||
auto val = deserialize<Dtype::Val>(is);
|
||||
auto size = deserialize<uint8_t>(is);
|
||||
return Dtype(val, size);
|
||||
};
|
||||
|
||||
void serialize(Writer& os, const array& arr) {
|
||||
serialize(os, arr.shape());
|
||||
serialize(os, arr.dtype());
|
||||
}
|
||||
template <>
|
||||
array deserialize(Reader& is) {
|
||||
auto shape = deserialize<std::vector<int>>(is);
|
||||
auto type = deserialize<Dtype>(is);
|
||||
return array(std::move(shape), type, nullptr, std::vector<array>{});
|
||||
}
|
||||
|
||||
template <typename, typename = void>
|
||||
constexpr bool has_state = false;
|
||||
|
||||
template <typename T>
|
||||
constexpr bool has_state<T, std::void_t<decltype(std::declval<T>().state())>> =
|
||||
true;
|
||||
|
||||
template <typename T>
|
||||
void serialize_primitive(Writer& os, const Primitive& p) {
|
||||
if constexpr (has_state<T>) {
|
||||
serialize(os, static_cast<const T&>(p).state());
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::shared_ptr<T> deserialize_primitive(Reader& is, Stream s) {
|
||||
if constexpr (has_state<T>) {
|
||||
auto args = deserialize<decltype(std::declval<T>().state())>(is);
|
||||
if constexpr (is_pair<decltype(args)> || is_tuple<decltype(args)>) {
|
||||
auto fn = [s](auto&&... args) {
|
||||
return std::make_shared<T>(s, std::move(args)...);
|
||||
};
|
||||
return std::apply(fn, std::move(args));
|
||||
} else {
|
||||
return std::make_shared<T>(s, std::move(args));
|
||||
}
|
||||
} else {
|
||||
return std::make_shared<T>(s);
|
||||
}
|
||||
}
|
||||
|
||||
struct PrimitiveFactory {
|
||||
std::unordered_map<std::string, PrimitiveSerializer> factory = {
|
||||
SERIALIZE_PRIMITIVE(Abs),
|
||||
SERIALIZE_PRIMITIVE(Add),
|
||||
SERIALIZE_PRIMITIVE(AddMM),
|
||||
SERIALIZE_PRIMITIVE(Arange),
|
||||
SERIALIZE_PRIMITIVE(ArcCos),
|
||||
SERIALIZE_PRIMITIVE(ArcCosh),
|
||||
SERIALIZE_PRIMITIVE(ArcSin),
|
||||
SERIALIZE_PRIMITIVE(ArcSinh),
|
||||
SERIALIZE_PRIMITIVE(ArcTan),
|
||||
SERIALIZE_PRIMITIVE(ArcTan2),
|
||||
SERIALIZE_PRIMITIVE(ArcTanh),
|
||||
SERIALIZE_PRIMITIVE(ArgPartition),
|
||||
SERIALIZE_PRIMITIVE(ArgReduce),
|
||||
SERIALIZE_PRIMITIVE(ArgSort),
|
||||
SERIALIZE_PRIMITIVE(AsType),
|
||||
SERIALIZE_PRIMITIVE(AsStrided),
|
||||
SERIALIZE_PRIMITIVE(
|
||||
BitwiseBinary,
|
||||
"BitwiseAnd",
|
||||
"BitwiseOr",
|
||||
"BitwiseXor",
|
||||
"LeftShift",
|
||||
"RightShift"),
|
||||
SERIALIZE_PRIMITIVE(BlockMaskedMM),
|
||||
SERIALIZE_PRIMITIVE(Broadcast),
|
||||
SERIALIZE_PRIMITIVE(Ceil),
|
||||
SERIALIZE_PRIMITIVE(Concatenate),
|
||||
SERIALIZE_PRIMITIVE(Conjugate),
|
||||
SERIALIZE_PRIMITIVE(Convolution),
|
||||
SERIALIZE_PRIMITIVE(Copy),
|
||||
SERIALIZE_PRIMITIVE(Cos),
|
||||
SERIALIZE_PRIMITIVE(Cosh),
|
||||
SERIALIZE_PRIMITIVE(Depends),
|
||||
SERIALIZE_PRIMITIVE(Divide),
|
||||
SERIALIZE_PRIMITIVE(DivMod),
|
||||
SERIALIZE_PRIMITIVE(Equal, "NaNEqual"),
|
||||
SERIALIZE_PRIMITIVE(Erf),
|
||||
SERIALIZE_PRIMITIVE(ErfInv),
|
||||
SERIALIZE_PRIMITIVE(Exp),
|
||||
SERIALIZE_PRIMITIVE(Expm1),
|
||||
SERIALIZE_PRIMITIVE(ExpandDims),
|
||||
SERIALIZE_PRIMITIVE(FFT),
|
||||
SERIALIZE_PRIMITIVE(Flatten),
|
||||
SERIALIZE_PRIMITIVE(Floor),
|
||||
SERIALIZE_PRIMITIVE(Full),
|
||||
SERIALIZE_PRIMITIVE(Gather),
|
||||
SERIALIZE_PRIMITIVE(GatherMM),
|
||||
SERIALIZE_PRIMITIVE(Greater),
|
||||
SERIALIZE_PRIMITIVE(GreaterEqual),
|
||||
SERIALIZE_PRIMITIVE(Hadamard),
|
||||
SERIALIZE_PRIMITIVE(Imag),
|
||||
SERIALIZE_PRIMITIVE(Less),
|
||||
SERIALIZE_PRIMITIVE(LessEqual),
|
||||
SERIALIZE_PRIMITIVE(Log, "Log2", "Log10"),
|
||||
SERIALIZE_PRIMITIVE(Log1p),
|
||||
SERIALIZE_PRIMITIVE(LogicalNot),
|
||||
SERIALIZE_PRIMITIVE(LogicalAnd),
|
||||
SERIALIZE_PRIMITIVE(LogicalOr),
|
||||
SERIALIZE_PRIMITIVE(LogAddExp),
|
||||
SERIALIZE_PRIMITIVE(Matmul),
|
||||
SERIALIZE_PRIMITIVE(Maximum),
|
||||
SERIALIZE_PRIMITIVE(Minimum),
|
||||
SERIALIZE_PRIMITIVE(Multiply),
|
||||
SERIALIZE_PRIMITIVE(Negative),
|
||||
SERIALIZE_PRIMITIVE(NotEqual),
|
||||
SERIALIZE_PRIMITIVE(Reshape),
|
||||
SERIALIZE_PRIMITIVE(NumberOfElements),
|
||||
SERIALIZE_PRIMITIVE(Pad),
|
||||
SERIALIZE_PRIMITIVE(Partition),
|
||||
SERIALIZE_PRIMITIVE(Power),
|
||||
SERIALIZE_PRIMITIVE(QuantizedMatmul),
|
||||
SERIALIZE_PRIMITIVE(GatherQMM),
|
||||
SERIALIZE_PRIMITIVE(RandomBits),
|
||||
SERIALIZE_PRIMITIVE(Real),
|
||||
SERIALIZE_PRIMITIVE(Remainder),
|
||||
SERIALIZE_PRIMITIVE(Reshape),
|
||||
SERIALIZE_PRIMITIVE(Reduce, "And", "Or", "Sum", "Prod", "Min", "Max"),
|
||||
SERIALIZE_PRIMITIVE(Round),
|
||||
SERIALIZE_PRIMITIVE(Scan, "CumSum", "CumProd", "CumMin", "CumMax"),
|
||||
SERIALIZE_PRIMITIVE(Scatter),
|
||||
SERIALIZE_PRIMITIVE(Select),
|
||||
SERIALIZE_PRIMITIVE(Sigmoid),
|
||||
SERIALIZE_PRIMITIVE(Sign),
|
||||
SERIALIZE_PRIMITIVE(Sin),
|
||||
SERIALIZE_PRIMITIVE(Sinh),
|
||||
SERIALIZE_PRIMITIVE(Slice),
|
||||
SERIALIZE_PRIMITIVE(SliceUpdate),
|
||||
SERIALIZE_PRIMITIVE(Softmax),
|
||||
SERIALIZE_PRIMITIVE(Sort),
|
||||
SERIALIZE_PRIMITIVE(Split),
|
||||
SERIALIZE_PRIMITIVE(Square),
|
||||
SERIALIZE_PRIMITIVE(Squeeze),
|
||||
SERIALIZE_PRIMITIVE(Sqrt, "Rsqrt", "Sqrt"),
|
||||
SERIALIZE_PRIMITIVE(StopGradient),
|
||||
SERIALIZE_PRIMITIVE(Subtract),
|
||||
SERIALIZE_PRIMITIVE(Tan),
|
||||
SERIALIZE_PRIMITIVE(Tanh),
|
||||
SERIALIZE_PRIMITIVE(View),
|
||||
SERIALIZE_PRIMITIVE(Transpose),
|
||||
SERIALIZE_PRIMITIVE(Unflatten),
|
||||
SERIALIZE_PRIMITIVE(QRF),
|
||||
SERIALIZE_PRIMITIVE(SVD),
|
||||
SERIALIZE_PRIMITIVE(Inverse),
|
||||
SERIALIZE_PRIMITIVE(Cholesky),
|
||||
SERIALIZE_PRIMITIVE(Eigh),
|
||||
SERIALIZE_PRIMITIVE(AffineQuantize),
|
||||
SERIALIZE_PRIMITIVE(RMSNorm),
|
||||
SERIALIZE_PRIMITIVE(RMSNormVJP),
|
||||
SERIALIZE_PRIMITIVE(LayerNorm),
|
||||
SERIALIZE_PRIMITIVE(LayerNormVJP),
|
||||
SERIALIZE_PRIMITIVE(RoPE),
|
||||
SERIALIZE_PRIMITIVE(ScaledDotProductAttention)};
|
||||
std::unordered_map<std::string, std::string> name_remap;
|
||||
|
||||
PrimitiveFactory() {
|
||||
for (auto& [n, f] : factory) {
|
||||
for (auto& k : f.keys) {
|
||||
name_remap[k] = n;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void save(Writer& os, const std::shared_ptr<Primitive>& p) {
|
||||
serialize(os, p->stream());
|
||||
std::ostringstream pout;
|
||||
p->print(pout);
|
||||
auto name = pout.str();
|
||||
name = name.substr(0, name.find(' '));
|
||||
if (auto it = name_remap.find(name); it != name_remap.end()) {
|
||||
name = it->second;
|
||||
}
|
||||
serialize(os, name);
|
||||
if (auto it = factory.find(name); it != factory.end()) {
|
||||
it->second.serialize(os, *p);
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"[export_function] Unable to serialize primitive " + name);
|
||||
}
|
||||
};
|
||||
|
||||
std::shared_ptr<Primitive> load(Reader& is) {
|
||||
auto stream = deserialize<Stream>(is);
|
||||
if (get_stream(stream.index) != stream) {
|
||||
std::ostringstream msg;
|
||||
msg << "[import_function] Invalid stream encountered " << stream << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
auto name = deserialize<std::string>(is);
|
||||
if (auto it = factory.find(name); it != factory.end()) {
|
||||
return it->second.deserialize(is, stream);
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"[import_function] Unable to deserialize primitive " + name);
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
void write_header(Writer& os, int count, bool shapeless) {
|
||||
serialize(os, std::string(TOSTRING(MLX_VERSION)));
|
||||
serialize(os, count);
|
||||
serialize(os, shapeless);
|
||||
}
|
||||
|
||||
// A struct to hold and retrieve the graphs that are exported / imported
|
||||
struct FunctionTable {
|
||||
FunctionTable(bool shapeless = false) : shapeless(shapeless) {};
|
||||
struct Function {
|
||||
Function(
|
||||
std::vector<std::string> kwarg_keys,
|
||||
std::vector<array> inputs,
|
||||
std::vector<array> outputs,
|
||||
std::vector<array> tape)
|
||||
: kwarg_keys(std::move(kwarg_keys)),
|
||||
inputs(std::move(inputs)),
|
||||
outputs(std::move(outputs)),
|
||||
tape(std::move(tape)) {}
|
||||
|
||||
std::vector<std::string> kwarg_keys;
|
||||
std::vector<array> inputs;
|
||||
std::vector<array> outputs;
|
||||
std::vector<array> tape;
|
||||
Function(const Function&) = delete;
|
||||
Function& operator=(const Function&) = delete;
|
||||
Function(Function&&) = default;
|
||||
Function() = default;
|
||||
};
|
||||
bool shapeless;
|
||||
std::unordered_map<int, std::vector<Function>> table;
|
||||
Function* find(const Args& args, const Kwargs& kwargs);
|
||||
std::pair<Function&, bool> emplace(const Args& args, const Kwargs& kwargs);
|
||||
void insert(
|
||||
std::vector<std::string> kwarg_keys,
|
||||
std::vector<array> inputs,
|
||||
std::vector<array> outputs,
|
||||
std::vector<array> tape) {
|
||||
auto [it, _] = table.emplace(inputs.size(), std::vector<Function>{});
|
||||
it->second.emplace_back(
|
||||
std::move(kwarg_keys),
|
||||
std::move(inputs),
|
||||
std::move(outputs),
|
||||
std::move(tape));
|
||||
}
|
||||
|
||||
void print_functions(std::ostream& os) {
|
||||
int n = 1;
|
||||
for (auto& [_, vec] : table) {
|
||||
for (auto& fun : vec) {
|
||||
auto npos = fun.inputs.size() - fun.kwarg_keys.size();
|
||||
os << " " << n++ << ". Function with " << npos
|
||||
<< " positional inputs and " << fun.kwarg_keys.size()
|
||||
<< " keyword inputs:\n";
|
||||
for (int j = 0; j < fun.inputs.size(); ++j) {
|
||||
auto& in = fun.inputs[j];
|
||||
if (j < npos) {
|
||||
os << " " << j + 1 << ": ";
|
||||
} else {
|
||||
os << " \"" << fun.kwarg_keys[j - npos] << "\": ";
|
||||
}
|
||||
os << in.shape() << " " << in.dtype() << "\n";
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
bool match(const Args& args, const Kwargs& kwargs, const Function& fun);
|
||||
};
|
||||
|
||||
bool FunctionTable::match(
|
||||
const Args& args,
|
||||
const Kwargs& kwargs,
|
||||
const Function& fun) {
|
||||
for (auto& k : fun.kwarg_keys) {
|
||||
if (kwargs.find(k) == kwargs.end()) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
auto match_inputs = [shapeless = this->shapeless](
|
||||
const array& x, const array& y) {
|
||||
if (x.dtype() != y.dtype()) {
|
||||
return false;
|
||||
}
|
||||
if (!shapeless && x.shape() != y.shape()) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
};
|
||||
|
||||
int i = 0;
|
||||
for (; i < args.size(); ++i) {
|
||||
if (!match_inputs(args[i], fun.inputs[i])) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
for (auto& [_, in] : kwargs) {
|
||||
if (!match_inputs(in, fun.inputs[i++])) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
std::pair<FunctionTable::Function&, bool> FunctionTable::emplace(
|
||||
const Args& args,
|
||||
const Kwargs& kwargs) {
|
||||
auto n_inputs = args.size() + kwargs.size();
|
||||
auto [it, _] = table.emplace(n_inputs, std::vector<Function>{});
|
||||
auto& funs_vec = it->second;
|
||||
|
||||
for (auto& fun : funs_vec) {
|
||||
if (match(args, kwargs, fun)) {
|
||||
return {fun, false};
|
||||
}
|
||||
}
|
||||
|
||||
funs_vec.emplace_back();
|
||||
return {funs_vec.back(), true};
|
||||
}
|
||||
|
||||
FunctionTable::Function* FunctionTable::find(
|
||||
const Args& args,
|
||||
const Kwargs& kwargs) {
|
||||
auto n_inputs = args.size() + kwargs.size();
|
||||
auto it = table.find(n_inputs);
|
||||
if (it == table.end()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
for (auto& fun : it->second) {
|
||||
if (match(args, kwargs, fun)) {
|
||||
return &fun;
|
||||
}
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
FunctionExporter::FunctionExporter(
|
||||
const std::string& file,
|
||||
std::function<std::vector<array>(const Args&, const Kwargs&)> fun,
|
||||
bool shapeless)
|
||||
: os(file),
|
||||
fun(std::move(fun)),
|
||||
ftable(std::make_shared<FunctionTable>(shapeless)) {
|
||||
if (!os.is_open()) {
|
||||
throw std::runtime_error("[export_function] Failed to open " + file);
|
||||
}
|
||||
write_header(os, count, shapeless);
|
||||
}
|
||||
|
||||
void FunctionExporter::close() {
|
||||
closed = true;
|
||||
};
|
||||
|
||||
void FunctionExporter::export_function(const Args& args, const Kwargs& kwargs) {
|
||||
if (closed) {
|
||||
throw std::runtime_error(
|
||||
"[export_function] Attempting to write after exporting is closed.");
|
||||
}
|
||||
auto [fentry, inserted] = ftable->emplace(args, kwargs);
|
||||
if (!inserted) {
|
||||
throw std::runtime_error(
|
||||
"[export_function] Attempting to export a function twice with "
|
||||
"the same signature is not allowed.");
|
||||
}
|
||||
|
||||
// Flatten the inputs to the function for tracing
|
||||
std::vector<std::string> kwarg_keys;
|
||||
auto inputs = args;
|
||||
for (auto& [k, v] : kwargs) {
|
||||
kwarg_keys.push_back(k);
|
||||
inputs.push_back(v);
|
||||
}
|
||||
|
||||
auto flat_fun = [this, &kwarg_keys](const Args& flat_args) {
|
||||
auto args = Args(flat_args.begin(), flat_args.end() - kwarg_keys.size());
|
||||
Kwargs kwargs;
|
||||
auto it = flat_args.end() - kwarg_keys.size();
|
||||
;
|
||||
for (auto& k : kwarg_keys) {
|
||||
kwargs.insert({k, *it++});
|
||||
}
|
||||
return fun(args, kwargs);
|
||||
};
|
||||
|
||||
// Trace to build the graph
|
||||
auto [trace_inputs, trace_outputs] = detail::compile_trace(flat_fun, inputs);
|
||||
|
||||
// DFS the graph and get the tape
|
||||
auto [tape, parents_map] =
|
||||
detail::compile_dfs(trace_inputs, trace_outputs, inputs);
|
||||
|
||||
detail::compile_simplify(tape, parents_map, trace_outputs, /* passes */ 3);
|
||||
|
||||
// Update header
|
||||
count++;
|
||||
|
||||
// Overwrite the header
|
||||
auto pos = os.tell();
|
||||
os.seek(0);
|
||||
write_header(os, count, ftable->shapeless);
|
||||
os.seek(pos);
|
||||
serialize(os, kwarg_keys);
|
||||
|
||||
auto arrays_to_ids = [](const std::vector<array>& arrs) {
|
||||
std::vector<uint64_t> ids;
|
||||
for (auto& arr : arrs) {
|
||||
ids.push_back(arr.id());
|
||||
}
|
||||
return ids;
|
||||
};
|
||||
|
||||
// Inputs and outputs
|
||||
auto trace_input_ids = arrays_to_ids(trace_inputs);
|
||||
serialize(os, trace_input_ids);
|
||||
serialize(os, trace_inputs);
|
||||
serialize(os, arrays_to_ids(trace_outputs));
|
||||
|
||||
// Update the table entry
|
||||
fentry.kwarg_keys = std::move(kwarg_keys);
|
||||
fentry.inputs = std::move(trace_inputs);
|
||||
|
||||
std::unordered_set<std::uintptr_t> input_set(
|
||||
trace_input_ids.begin(), trace_input_ids.end());
|
||||
|
||||
// Tape
|
||||
auto factory = PrimitiveFactory();
|
||||
serialize(os, static_cast<uint64_t>(tape.size()));
|
||||
for (auto& arr : tape) {
|
||||
serialize(os, static_cast<uint64_t>(arr.id()));
|
||||
if (arr.has_primitive()) {
|
||||
serialize(os, true);
|
||||
serialize(os, arrays_to_ids(arr.inputs()));
|
||||
factory.save(os, arr.primitive_ptr());
|
||||
serialize(os, static_cast<uint64_t>(arr.siblings().size()));
|
||||
if (arr.siblings().empty()) {
|
||||
serialize(os, arr.shape());
|
||||
serialize(os, arr.dtype());
|
||||
} else {
|
||||
auto outputs = arr.outputs();
|
||||
serialize(os, arrays_to_ids(outputs));
|
||||
|
||||
std::vector<std::vector<int>> shapes;
|
||||
std::vector<Dtype> dtypes;
|
||||
for (auto& o : outputs) {
|
||||
shapes.push_back(o.shape());
|
||||
dtypes.push_back(o.dtype());
|
||||
}
|
||||
serialize(os, shapes);
|
||||
serialize(os, dtypes);
|
||||
}
|
||||
} else {
|
||||
serialize(os, false);
|
||||
if (input_set.find(arr.id()) == input_set.end()) {
|
||||
serialize(os, true);
|
||||
// Save constant data if not already saved
|
||||
if (constants.insert(arr.id()).second) {
|
||||
serialize(os, arr.shape());
|
||||
serialize(os, arr.dtype());
|
||||
os.write(arr.data<char>(), arr.nbytes());
|
||||
}
|
||||
} else {
|
||||
serialize(os, false);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void FunctionExporter::operator()(const Args& args) {
|
||||
export_function(args, {});
|
||||
}
|
||||
|
||||
void FunctionExporter::operator()(const Kwargs& kwargs) {
|
||||
export_function({}, kwargs);
|
||||
}
|
||||
|
||||
void FunctionExporter::operator()(const Args& args, const Kwargs& kwargs) {
|
||||
export_function(args, kwargs);
|
||||
}
|
||||
|
||||
FunctionExporter exporter(
|
||||
const std::string& file,
|
||||
const std::function<std::vector<array>(const Args&)>& fun,
|
||||
bool shapeless /* = false */) {
|
||||
return FunctionExporter{
|
||||
file,
|
||||
[fun](const Args& args, const Kwargs&) { return fun(args); },
|
||||
shapeless};
|
||||
}
|
||||
|
||||
FunctionExporter exporter(
|
||||
const std::string& file,
|
||||
const std::function<std::vector<array>(const Kwargs&)>& fun,
|
||||
bool shapeless /* = false */) {
|
||||
return exporter(
|
||||
file,
|
||||
[fun](const Args&, const Kwargs kwargs) { return fun(kwargs); },
|
||||
shapeless);
|
||||
}
|
||||
|
||||
FunctionExporter exporter(
|
||||
const std::string& file,
|
||||
const std::function<std::vector<array>(const Args&, const Kwargs&)>& fun,
|
||||
bool shapeless /* = false */) {
|
||||
return FunctionExporter{file, fun, shapeless};
|
||||
}
|
||||
|
||||
void export_function(
|
||||
const std::string& file,
|
||||
const std::function<std::vector<array>(const Args&)>& fun,
|
||||
const Args& args,
|
||||
bool shapeless /* = false */) {
|
||||
exporter(file, fun, shapeless)(args);
|
||||
}
|
||||
|
||||
void export_function(
|
||||
const std::string& file,
|
||||
const std::function<std::vector<array>(const Kwargs&)>& fun,
|
||||
const Kwargs& kwargs,
|
||||
bool shapeless /* = false */) {
|
||||
exporter(file, fun, shapeless)(kwargs);
|
||||
}
|
||||
|
||||
void export_function(
|
||||
const std::string& file,
|
||||
const std::function<std::vector<array>(const Args&, const Kwargs&)>& fun,
|
||||
const Args& args,
|
||||
const Kwargs& kwargs,
|
||||
bool shapeless /* = false */) {
|
||||
exporter(file, fun, shapeless)(args, kwargs);
|
||||
}
|
||||
|
||||
std::vector<array> ImportedFunction::operator()(const Kwargs& kwargs) const {
|
||||
return this->operator()({}, kwargs);
|
||||
}
|
||||
|
||||
std::vector<array> ImportedFunction::operator()(const Args& args) const {
|
||||
return this->operator()(args, {});
|
||||
}
|
||||
|
||||
std::vector<array> ImportedFunction::operator()(
|
||||
const Args& args,
|
||||
const Kwargs& kwargs) const {
|
||||
auto* fun = ftable->find(args, kwargs);
|
||||
if (fun == nullptr) {
|
||||
std::ostringstream msg;
|
||||
msg << "[import_function::call] No imported function found which matches "
|
||||
<< "the given positional and keyword arguments. Possible functions include:\n";
|
||||
ftable->print_functions(msg);
|
||||
msg << "\nReceived function with " << args.size()
|
||||
<< " positional inputs and " << kwargs.size() << " keyword inputs:\n";
|
||||
for (int i = 0; i < args.size(); ++i) {
|
||||
auto& in = args[i];
|
||||
msg << " " << i + 1 << ": " << in.shape() << " " << in.dtype() << "\n";
|
||||
}
|
||||
for (auto& [k, in] : kwargs) {
|
||||
msg << " \"" << k << "\": " << in.shape() << " " << in.dtype() << "\n";
|
||||
}
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
auto inputs = args;
|
||||
for (auto& [_, v] : kwargs) {
|
||||
inputs.push_back(v);
|
||||
}
|
||||
return detail::compile_replace(
|
||||
fun->tape, fun->inputs, fun->outputs, inputs, ftable->shapeless);
|
||||
}
|
||||
|
||||
ImportedFunction import_function(const std::string& file) {
|
||||
return ImportedFunction{file};
|
||||
}
|
||||
|
||||
ImportedFunction::ImportedFunction(const std::string& file)
|
||||
: ftable(std::make_shared<FunctionTable>()) {
|
||||
auto is_ptr = std::make_shared<Reader>(file);
|
||||
auto& is = *is_ptr;
|
||||
if (!is.is_open()) {
|
||||
throw std::runtime_error("[import_function] Failed to open " + file);
|
||||
}
|
||||
|
||||
// Parse header
|
||||
auto mlx_version = deserialize<std::string>(is);
|
||||
auto function_count = deserialize<int>(is);
|
||||
ftable->shapeless = deserialize<bool>(is);
|
||||
std::unordered_map<std::uintptr_t, array> constants;
|
||||
|
||||
auto import_one = [&]() {
|
||||
auto kwarg_keys = deserialize<std::vector<std::string>>(is);
|
||||
|
||||
std::unordered_map<uint64_t, array> array_map;
|
||||
auto trace_input_ids = deserialize<std::vector<uint64_t>>(is);
|
||||
auto trace_inputs = deserialize<std::vector<array>>(is);
|
||||
for (int i = 0; i < trace_inputs.size(); ++i) {
|
||||
array_map.emplace(trace_input_ids[i], trace_inputs[i]);
|
||||
}
|
||||
auto trace_output_ids = deserialize<std::vector<uint64_t>>(is);
|
||||
|
||||
std::vector<array> tape;
|
||||
auto tape_size = deserialize<uint64_t>(is);
|
||||
tape.reserve(tape_size);
|
||||
|
||||
auto factory = PrimitiveFactory();
|
||||
for (size_t i = 0; i < tape_size; ++i) {
|
||||
auto id = deserialize<uint64_t>(is);
|
||||
if (deserialize<bool>(is)) {
|
||||
auto input_ids = deserialize<std::vector<uint64_t>>(is);
|
||||
std::vector<array> inputs;
|
||||
inputs.reserve(input_ids.size());
|
||||
for (auto id : input_ids) {
|
||||
inputs.push_back(array_map.at(id));
|
||||
}
|
||||
std::shared_ptr<Primitive> prim = factory.load(is);
|
||||
auto num_siblings = deserialize<uint64_t>(is);
|
||||
if (num_siblings == 0) {
|
||||
auto shape = deserialize<std::vector<int>>(is);
|
||||
auto type = deserialize<Dtype>(is);
|
||||
tape.emplace_back(
|
||||
std::move(shape), type, std::move(prim), std::move(inputs));
|
||||
array_map.emplace(id, tape.back());
|
||||
} else {
|
||||
auto ids = deserialize<std::vector<uint64_t>>(is);
|
||||
auto shapes = deserialize<std::vector<std::vector<int>>>(is);
|
||||
auto types = deserialize<std::vector<Dtype>>(is);
|
||||
auto arrays = array::make_arrays(
|
||||
std::move(shapes),
|
||||
std::move(types),
|
||||
std::move(prim),
|
||||
std::move(inputs));
|
||||
for (int i = 0; i < arrays.size(); ++i) {
|
||||
auto sid = ids[i];
|
||||
if (sid == id) {
|
||||
tape.push_back(arrays[i]);
|
||||
}
|
||||
array_map.emplace(sid, arrays[i]);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if (deserialize<bool>(is)) {
|
||||
// Load constant
|
||||
if (auto it = constants.find(id); it != constants.end()) {
|
||||
tape.push_back(it->second);
|
||||
} else {
|
||||
auto shape = deserialize<std::vector<int>>(is);
|
||||
auto type = deserialize<Dtype>(is);
|
||||
size_t offset = is.tell();
|
||||
tape.push_back(array(
|
||||
std::move(shape),
|
||||
type,
|
||||
std::make_shared<Load>(
|
||||
default_stream(default_device()), is_ptr, offset),
|
||||
{}));
|
||||
is.seek(offset + tape.back().nbytes());
|
||||
constants.insert({id, tape.back()});
|
||||
}
|
||||
array_map.emplace(id, tape.back());
|
||||
} else {
|
||||
// Function inputs are in the map
|
||||
tape.push_back(array_map.at(id));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<array> trace_outputs;
|
||||
trace_outputs.reserve(trace_output_ids.size());
|
||||
for (auto id : trace_output_ids) {
|
||||
trace_outputs.push_back(array_map.at(id));
|
||||
}
|
||||
ftable->insert(
|
||||
std::move(kwarg_keys),
|
||||
std::move(trace_inputs),
|
||||
std::move(trace_outputs),
|
||||
std::move(tape));
|
||||
};
|
||||
|
||||
for (int i = 0; i < function_count; ++i) {
|
||||
import_one();
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
66
mlx/export.h
Normal file
66
mlx/export.h
Normal file
@ -0,0 +1,66 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include "mlx/array.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
using Args = std::vector<array>;
|
||||
using Kwargs = std::map<std::string, array>;
|
||||
|
||||
struct FunctionExporter;
|
||||
|
||||
/**
|
||||
* Make an exporter to save multiple traces of a given function to
|
||||
* the same file.
|
||||
*/
|
||||
FunctionExporter exporter(
|
||||
const std::string& file,
|
||||
const std::function<std::vector<array>(const Args&)>& fun,
|
||||
bool shapeless = false);
|
||||
|
||||
FunctionExporter exporter(
|
||||
const std::string& file,
|
||||
const std::function<std::vector<array>(const Kwargs&)>& fun,
|
||||
bool shapeless = false);
|
||||
|
||||
FunctionExporter exporter(
|
||||
const std::string& path,
|
||||
const std::function<std::vector<array>(const Args&, const Kwargs&)>& fun,
|
||||
bool shapeless = false);
|
||||
|
||||
/**
|
||||
* Export a function to a file.
|
||||
*/
|
||||
void export_function(
|
||||
const std::string& file,
|
||||
const std::function<std::vector<array>(const Args&)>& fun,
|
||||
const Args& args,
|
||||
bool shapeless = false);
|
||||
|
||||
void export_function(
|
||||
const std::string& file,
|
||||
const std::function<std::vector<array>(const Kwargs&)>& fun,
|
||||
const Kwargs& kwargs,
|
||||
bool shapeless = false);
|
||||
|
||||
void export_function(
|
||||
const std::string& file,
|
||||
const std::function<std::vector<array>(const Args&, const Kwargs&)>& fun,
|
||||
const Args& args,
|
||||
const Kwargs& kwargs,
|
||||
bool shapeless = false);
|
||||
|
||||
struct ImportedFunction;
|
||||
|
||||
/**
|
||||
* Import a function from a file.
|
||||
*/
|
||||
ImportedFunction import_function(const std::string& file);
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
#include "mlx/export_impl.h"
|
71
mlx/export_impl.h
Normal file
71
mlx/export_impl.h
Normal file
@ -0,0 +1,71 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include "mlx/io/load.h"
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
struct FunctionTable;
|
||||
|
||||
struct FunctionExporter {
|
||||
void operator()(const std::initializer_list<array>& args) {
|
||||
this->operator()(Args(args));
|
||||
}
|
||||
void operator()(const Args& args);
|
||||
void operator()(const Kwargs& kwargs);
|
||||
void operator()(const Args& args, const Kwargs& kwargs);
|
||||
|
||||
void close();
|
||||
|
||||
FunctionExporter(const FunctionExporter&) = delete;
|
||||
FunctionExporter& operator=(const FunctionExporter&) = delete;
|
||||
FunctionExporter(FunctionExporter&& other) = default;
|
||||
|
||||
private:
|
||||
friend FunctionExporter exporter(
|
||||
const std::string&,
|
||||
const std::function<std::vector<array>(const Args&)>&,
|
||||
bool shapeless);
|
||||
|
||||
friend FunctionExporter exporter(
|
||||
const std::string&,
|
||||
const std::function<std::vector<array>(const Kwargs&)>&,
|
||||
bool shapeless);
|
||||
|
||||
friend FunctionExporter exporter(
|
||||
const std::string&,
|
||||
const std::function<std::vector<array>(const Args&, const Kwargs&)>&,
|
||||
bool shapeless);
|
||||
|
||||
FunctionExporter(
|
||||
const std::string& file,
|
||||
std::function<std::vector<array>(const Args&, const Kwargs&)> fun,
|
||||
bool shapeless);
|
||||
io::FileWriter os;
|
||||
std::function<std::vector<array>(const Args&, const Kwargs& kwargs)> fun;
|
||||
void export_function(const Args& args, const Kwargs& kwargs);
|
||||
std::set<std::uintptr_t> constants;
|
||||
int count{0};
|
||||
bool closed{false};
|
||||
std::shared_ptr<FunctionTable> ftable;
|
||||
};
|
||||
|
||||
struct ImportedFunction {
|
||||
std::vector<array> operator()(
|
||||
const std::initializer_list<array>& args) const {
|
||||
return this->operator()(Args(args));
|
||||
}
|
||||
std::vector<array> operator()(const Args& args) const;
|
||||
std::vector<array> operator()(const Kwargs& kwargs) const;
|
||||
std::vector<array> operator()(const Args& args, const Kwargs& kwargs) const;
|
||||
|
||||
private:
|
||||
ImportedFunction(const std::string& file);
|
||||
friend ImportedFunction import_function(const std::string&);
|
||||
ImportedFunction();
|
||||
|
||||
std::shared_ptr<FunctionTable> ftable;
|
||||
};
|
||||
|
||||
} // namespace mlx::core
|
@ -697,8 +697,7 @@ array scaled_dot_product_attention(
|
||||
return array(
|
||||
std::move(out_shape),
|
||||
final_type,
|
||||
std::make_shared<ScaledDotProductAttention>(
|
||||
stream, fallback, scale, false),
|
||||
std::make_shared<ScaledDotProductAttention>(stream, fallback, scale),
|
||||
{q, k, v});
|
||||
}
|
||||
|
||||
@ -712,7 +711,7 @@ array scaled_dot_product_attention(
|
||||
bool ScaledDotProductAttention::is_equivalent(const Primitive& other) const {
|
||||
const ScaledDotProductAttention& a_other =
|
||||
static_cast<const ScaledDotProductAttention&>(other);
|
||||
return needs_mask_ == a_other.needs_mask_ && scale_ == a_other.scale_;
|
||||
return scale_ == a_other.scale_;
|
||||
}
|
||||
|
||||
array pack_and_quantize(
|
||||
|
@ -60,6 +60,10 @@ class RMSNorm : public Custom {
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
|
||||
auto state() const {
|
||||
return std::make_pair(nullptr, eps_);
|
||||
}
|
||||
|
||||
private:
|
||||
std::function<std::vector<array>(std::vector<array>)> fallback_;
|
||||
float eps_;
|
||||
@ -82,6 +86,9 @@ class RMSNormVJP : public Custom {
|
||||
|
||||
DEFINE_PRINT(RMSNormVJP)
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
auto state() const {
|
||||
return std::make_pair(nullptr, eps_);
|
||||
}
|
||||
|
||||
private:
|
||||
std::function<std::vector<array>(std::vector<array>)> fallback_;
|
||||
@ -112,6 +119,9 @@ class LayerNorm : public Custom {
|
||||
DEFINE_PRINT(LayerNorm)
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
auto state() const {
|
||||
return std::make_pair(nullptr, eps_);
|
||||
}
|
||||
|
||||
private:
|
||||
std::function<std::vector<array>(std::vector<array>)> fallback_;
|
||||
@ -135,6 +145,9 @@ class LayerNormVJP : public Custom {
|
||||
|
||||
DEFINE_PRINT(LayerNormVJP)
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
auto state() const {
|
||||
return std::make_pair(nullptr, eps_);
|
||||
}
|
||||
|
||||
private:
|
||||
std::function<std::vector<array>(std::vector<array>)> fallback_;
|
||||
@ -174,6 +187,10 @@ class RoPE : public Custom {
|
||||
DEFINE_PRINT(RoPE)
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
auto state() const {
|
||||
return std::make_tuple(
|
||||
nullptr, dims_, traditional_, base_, scale_, forward_);
|
||||
}
|
||||
|
||||
private:
|
||||
std::function<std::vector<array>(std::vector<array>)> fallback_;
|
||||
@ -189,9 +206,8 @@ class ScaledDotProductAttention : public Custom {
|
||||
explicit ScaledDotProductAttention(
|
||||
Stream stream,
|
||||
std::function<std::vector<array>(std::vector<array>)> fallback,
|
||||
const float scale,
|
||||
const bool needs_mask)
|
||||
: Custom(stream, fallback), scale_(scale), needs_mask_(needs_mask) {}
|
||||
const float scale)
|
||||
: Custom(stream, fallback), scale_(scale) {}
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override {
|
||||
@ -208,11 +224,13 @@ class ScaledDotProductAttention : public Custom {
|
||||
|
||||
DEFINE_PRINT(ScaledDotProductAttention);
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
auto state() const {
|
||||
return std::make_pair(nullptr, scale_);
|
||||
}
|
||||
|
||||
private:
|
||||
std::function<std::vector<array>(std::vector<array>)> fallback_;
|
||||
float scale_;
|
||||
bool needs_mask_;
|
||||
};
|
||||
|
||||
class AffineQuantize : public Custom {
|
||||
@ -238,6 +256,9 @@ class AffineQuantize : public Custom {
|
||||
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
|
||||
auto state() const {
|
||||
return std::make_tuple(nullptr, group_size_, bits_, dequantize_);
|
||||
}
|
||||
|
||||
private:
|
||||
std::function<std::vector<array>(std::vector<array>)> fallback_;
|
||||
|
@ -78,8 +78,15 @@ class ParallelFileReader : public Reader {
|
||||
return lseek(fd_, 0, SEEK_CUR);
|
||||
}
|
||||
|
||||
void seek(int64_t, std::ios_base::seekdir = std::ios_base::beg) override {
|
||||
throw std::runtime_error("[ParallelFileReader::seek] Not allowed");
|
||||
// Warning: do not use this function from multiple threads as
|
||||
// it advances the file descriptor
|
||||
void seek(int64_t off, std::ios_base::seekdir way = std::ios_base::beg)
|
||||
override {
|
||||
if (way == std::ios_base::beg) {
|
||||
lseek(fd_, off, 0);
|
||||
} else {
|
||||
lseek(fd_, off, SEEK_CUR);
|
||||
}
|
||||
}
|
||||
|
||||
// Warning: do not use this function from multiple threads as
|
||||
@ -108,8 +115,16 @@ class FileWriter : public Writer {
|
||||
0644)),
|
||||
label_(std::move(file_path)) {}
|
||||
|
||||
FileWriter(const FileWriter&) = delete;
|
||||
FileWriter& operator=(const FileWriter&) = delete;
|
||||
FileWriter(FileWriter&& other) {
|
||||
std::swap(fd_, other.fd_);
|
||||
}
|
||||
|
||||
~FileWriter() override {
|
||||
close(fd_);
|
||||
if (fd_ != 0) {
|
||||
close(fd_);
|
||||
}
|
||||
}
|
||||
|
||||
bool is_open() const override {
|
||||
@ -151,7 +166,7 @@ class FileWriter : public Writer {
|
||||
}
|
||||
|
||||
private:
|
||||
int fd_;
|
||||
int fd_{0};
|
||||
std::string label_;
|
||||
};
|
||||
|
||||
|
@ -9,6 +9,7 @@
|
||||
#include "mlx/distributed/distributed.h"
|
||||
#include "mlx/distributed/ops.h"
|
||||
#include "mlx/einsum.h"
|
||||
#include "mlx/export.h"
|
||||
#include "mlx/fast.h"
|
||||
#include "mlx/fft.h"
|
||||
#include "mlx/io.h"
|
||||
|
@ -4486,7 +4486,7 @@ std::pair<std::vector<array>, std::vector<int>> View::vmap(
|
||||
}
|
||||
|
||||
void View::print(std::ostream& os) {
|
||||
os << "View" << dtype_;
|
||||
os << "View " << dtype_;
|
||||
}
|
||||
|
||||
bool View::is_equivalent(const Primitive& other) const {
|
||||
|
148
mlx/primitives.h
148
mlx/primitives.h
@ -203,6 +203,9 @@ class AddMM : public UnaryPrimitive {
|
||||
DEFINE_PRINT(AddMM)
|
||||
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
std::pair<float, float> state() const {
|
||||
return {alpha_, beta_};
|
||||
};
|
||||
|
||||
private:
|
||||
const float alpha_;
|
||||
@ -220,6 +223,9 @@ class Arange : public UnaryPrimitive {
|
||||
DEFINE_PRINT(Arange)
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
|
||||
std::tuple<double, double, double> state() const {
|
||||
return {start_, stop_, step_};
|
||||
};
|
||||
|
||||
private:
|
||||
double start_;
|
||||
@ -361,6 +367,9 @@ class ArgPartition : public UnaryPrimitive {
|
||||
DEFINE_PRINT(ArgPartition)
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
std::pair<int, int> state() const {
|
||||
return {kth_, axis_};
|
||||
};
|
||||
|
||||
private:
|
||||
int kth_;
|
||||
@ -387,6 +396,9 @@ class ArgReduce : public UnaryPrimitive {
|
||||
DEFINE_PRINT(ArgReduce)
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
|
||||
std::pair<ReduceType, int> state() const {
|
||||
return {reduce_type_, axis_};
|
||||
};
|
||||
|
||||
private:
|
||||
ReduceType reduce_type_;
|
||||
@ -407,6 +419,9 @@ class ArgSort : public UnaryPrimitive {
|
||||
DEFINE_PRINT(ArgSort)
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
int state() const {
|
||||
return axis_;
|
||||
};
|
||||
|
||||
private:
|
||||
int axis_;
|
||||
@ -427,6 +442,9 @@ class AsType : public UnaryPrimitive {
|
||||
DEFINE_PRINT(AsType)
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
Dtype state() const {
|
||||
return dtype_;
|
||||
};
|
||||
|
||||
private:
|
||||
Dtype dtype_;
|
||||
@ -448,6 +466,9 @@ class AsStrided : public UnaryPrimitive {
|
||||
DEFINE_GRADS()
|
||||
DEFINE_PRINT(AsStrided)
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
auto state() const {
|
||||
return std::make_tuple(shape_, strides_, offset_);
|
||||
}
|
||||
|
||||
private:
|
||||
Shape shape_;
|
||||
@ -472,6 +493,9 @@ class BitwiseBinary : public UnaryPrimitive {
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
void print(std::ostream& os) override;
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
auto state() const {
|
||||
return op_;
|
||||
}
|
||||
|
||||
private:
|
||||
Op op_;
|
||||
@ -493,6 +517,9 @@ class BlockMaskedMM : public UnaryPrimitive {
|
||||
|
||||
DEFINE_PRINT(BlockMaskedMM)
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
auto state() const {
|
||||
return block_size_;
|
||||
}
|
||||
|
||||
private:
|
||||
int block_size_;
|
||||
@ -532,6 +559,9 @@ class Broadcast : public UnaryPrimitive {
|
||||
DEFINE_GRADS()
|
||||
DEFINE_PRINT(Broadcast)
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
std::vector<int> state() const {
|
||||
return shape_;
|
||||
};
|
||||
|
||||
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
|
||||
|
||||
@ -613,6 +643,9 @@ class Concatenate : public UnaryPrimitive {
|
||||
DEFINE_PRINT(Concatenate)
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
|
||||
auto state() const {
|
||||
return axis_;
|
||||
}
|
||||
|
||||
private:
|
||||
int axis_;
|
||||
@ -684,6 +717,15 @@ class Convolution : public UnaryPrimitive {
|
||||
|
||||
DEFINE_PRINT(Convolution)
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
auto state() const {
|
||||
return std::make_tuple(
|
||||
padding_,
|
||||
kernel_strides_,
|
||||
kernel_dilation_,
|
||||
input_dilation_,
|
||||
groups_,
|
||||
flip_);
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<int> padding_;
|
||||
@ -912,6 +954,9 @@ class Equal : public UnaryPrimitive {
|
||||
os << "Equal";
|
||||
}
|
||||
}
|
||||
auto state() const {
|
||||
return equal_nan_;
|
||||
};
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
@ -1001,6 +1046,9 @@ class ExpandDims : public UnaryPrimitive {
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
|
||||
static Shape output_shape(const array& input, const std::vector<int>& axes);
|
||||
auto state() const {
|
||||
return axes_;
|
||||
}
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
@ -1024,6 +1072,9 @@ class FFT : public UnaryPrimitive {
|
||||
DEFINE_PRINT(FFT)
|
||||
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
auto state() const {
|
||||
return std::make_tuple(axes_, inverse_, real_);
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<size_t> axes_;
|
||||
@ -1048,6 +1099,9 @@ class Flatten : public UnaryPrimitive {
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
|
||||
static Shape output_shape(const array& input, int start_axis, int end_axis);
|
||||
auto state() const {
|
||||
return std::make_pair(start_axis_, end_axis_);
|
||||
}
|
||||
|
||||
private:
|
||||
int start_axis_;
|
||||
@ -1103,6 +1157,9 @@ class Gather : public UnaryPrimitive {
|
||||
DEFINE_PRINT(Gather)
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
|
||||
std::pair<std::vector<int>, std::vector<int>> state() const {
|
||||
return {axes_, slice_sizes_};
|
||||
}
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
@ -1158,6 +1215,9 @@ class Hadamard : public UnaryPrimitive {
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
auto state() const {
|
||||
return scale_;
|
||||
}
|
||||
|
||||
private:
|
||||
float scale_;
|
||||
@ -1260,6 +1320,10 @@ class Log : public UnaryPrimitive {
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
|
||||
Base state() const {
|
||||
return base_;
|
||||
};
|
||||
|
||||
void print(std::ostream& os) override {
|
||||
switch (base_) {
|
||||
case e:
|
||||
@ -1488,6 +1552,9 @@ class NumberOfElements : public UnaryPrimitive {
|
||||
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override {
|
||||
return {{}};
|
||||
}
|
||||
std::tuple<std::vector<int>, bool, Dtype> state() const {
|
||||
return {axes_, inverted_, dtype_};
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<int> axes_;
|
||||
@ -1516,6 +1583,9 @@ class Pad : public UnaryPrimitive {
|
||||
DEFINE_GRADS()
|
||||
DEFINE_PRINT(Pad)
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
auto state() const {
|
||||
return std::make_tuple(axes_, low_pad_size_, high_pad_size_);
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<int> axes_;
|
||||
@ -1538,6 +1608,9 @@ class Partition : public UnaryPrimitive {
|
||||
DEFINE_PRINT(Partition)
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
auto state() const {
|
||||
return std::make_pair(kth_, axis_);
|
||||
};
|
||||
|
||||
private:
|
||||
int kth_;
|
||||
@ -1583,6 +1656,9 @@ class QuantizedMatmul : public UnaryPrimitive {
|
||||
DEFINE_PRINT(QuantizedMatmul)
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
|
||||
auto state() const {
|
||||
return std::make_tuple(group_size_, bits_, transpose_);
|
||||
}
|
||||
|
||||
private:
|
||||
int group_size_;
|
||||
@ -1607,6 +1683,9 @@ class GatherQMM : public UnaryPrimitive {
|
||||
DEFINE_GRADS()
|
||||
DEFINE_PRINT(GatherQMM)
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
auto state() const {
|
||||
return std::make_tuple(group_size_, bits_, transpose_);
|
||||
}
|
||||
|
||||
private:
|
||||
int group_size_;
|
||||
@ -1627,6 +1706,9 @@ class RandomBits : public UnaryPrimitive {
|
||||
DEFINE_VMAP()
|
||||
DEFINE_PRINT(RandomBits)
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
std::pair<std::vector<int>, int> state() const {
|
||||
return {shape_, width_};
|
||||
};
|
||||
|
||||
private:
|
||||
Shape shape_;
|
||||
@ -1661,6 +1743,9 @@ class Reshape : public UnaryPrimitive {
|
||||
DEFINE_GRADS()
|
||||
DEFINE_PRINT(Reshape)
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
std::vector<int> state() const {
|
||||
return shape_;
|
||||
};
|
||||
|
||||
private:
|
||||
Shape shape_;
|
||||
@ -1712,6 +1797,9 @@ class Reduce : public UnaryPrimitive {
|
||||
}
|
||||
}
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
std::pair<ReduceType, std::vector<int>> state() const {
|
||||
return {reduce_type_, axes_};
|
||||
};
|
||||
|
||||
private:
|
||||
ReduceType reduce_type_;
|
||||
@ -1777,6 +1865,9 @@ class Scan : public UnaryPrimitive {
|
||||
}
|
||||
}
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
auto state() const {
|
||||
return std::make_tuple(reduce_type_, axis_, reverse_, inclusive_);
|
||||
}
|
||||
|
||||
private:
|
||||
ReduceType reduce_type_;
|
||||
@ -1823,6 +1914,9 @@ class Scatter : public UnaryPrimitive {
|
||||
}
|
||||
}
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
std::pair<ReduceType, std::vector<int>> state() const {
|
||||
return {reduce_type_, axes_};
|
||||
};
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
@ -1917,6 +2011,9 @@ class Slice : public UnaryPrimitive {
|
||||
DEFINE_GRADS()
|
||||
DEFINE_PRINT(Slice)
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
auto state() const {
|
||||
return std::make_tuple(start_indices_, end_indices_, strides_);
|
||||
}
|
||||
|
||||
private:
|
||||
Shape start_indices_;
|
||||
@ -1946,6 +2043,9 @@ class SliceUpdate : public UnaryPrimitive {
|
||||
DEFINE_PRINT(SliceUpdate)
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
auto state() const {
|
||||
return std::make_tuple(start_indices_, end_indices_, strides_);
|
||||
}
|
||||
|
||||
private:
|
||||
Shape start_indices_;
|
||||
@ -1969,6 +2069,9 @@ class Softmax : public UnaryPrimitive {
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
auto state() const {
|
||||
return precise_;
|
||||
};
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
@ -1988,6 +2091,9 @@ class Sort : public UnaryPrimitive {
|
||||
DEFINE_PRINT(Sort)
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
auto state() const {
|
||||
return axis_;
|
||||
}
|
||||
|
||||
private:
|
||||
int axis_;
|
||||
@ -2009,6 +2115,9 @@ class Split : public Primitive {
|
||||
DEFINE_GRADS()
|
||||
DEFINE_PRINT(Split)
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
std::pair<std::vector<int>, int> state() const {
|
||||
return {indices_, axis_};
|
||||
};
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
|
||||
@ -2046,6 +2155,9 @@ class Sqrt : public UnaryPrimitive {
|
||||
DEFINE_GRADS()
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
auto state() const {
|
||||
return recip_;
|
||||
}
|
||||
|
||||
void print(std::ostream& os) override {
|
||||
if (recip_) {
|
||||
@ -2109,6 +2221,9 @@ class Squeeze : public UnaryPrimitive {
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
|
||||
static Shape output_shape(const array& input, const std::vector<int>& axes);
|
||||
auto state() const {
|
||||
return axes_;
|
||||
};
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
@ -2164,6 +2279,9 @@ class Unflatten : public UnaryPrimitive {
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
|
||||
static Shape output_shape(const array& input, int axis, const Shape& shape);
|
||||
auto state() const {
|
||||
return std::make_pair(axis_, shape_);
|
||||
}
|
||||
|
||||
private:
|
||||
int axis_;
|
||||
@ -2171,21 +2289,6 @@ class Unflatten : public UnaryPrimitive {
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
class Uniform : public UnaryPrimitive {
|
||||
public:
|
||||
explicit Uniform(Stream stream) : UnaryPrimitive(stream) {}
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
|
||||
DEFINE_VMAP()
|
||||
DEFINE_PRINT(Uniform)
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
class View : public UnaryPrimitive {
|
||||
public:
|
||||
explicit View(Stream stream, Dtype dtype)
|
||||
@ -2197,6 +2300,9 @@ class View : public UnaryPrimitive {
|
||||
DEFINE_VMAP()
|
||||
void print(std::ostream& os) override;
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
auto state() const {
|
||||
return dtype_;
|
||||
}
|
||||
|
||||
private:
|
||||
Dtype dtype_;
|
||||
@ -2215,6 +2321,9 @@ class Transpose : public UnaryPrimitive {
|
||||
DEFINE_PRINT(Transpose)
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
|
||||
std::vector<int> state() const {
|
||||
return axes_;
|
||||
};
|
||||
|
||||
private:
|
||||
std::vector<int> axes_;
|
||||
@ -2266,6 +2375,9 @@ class Inverse : public UnaryPrimitive {
|
||||
|
||||
DEFINE_VMAP()
|
||||
DEFINE_PRINT(Inverse)
|
||||
auto state() const {
|
||||
return std::make_pair(tri_, upper_);
|
||||
}
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& output);
|
||||
@ -2280,6 +2392,9 @@ class Cholesky : public UnaryPrimitive {
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
auto state() const {
|
||||
return upper_;
|
||||
}
|
||||
|
||||
DEFINE_VMAP()
|
||||
DEFINE_PRINT(Cholesky)
|
||||
@ -2307,6 +2422,9 @@ class Eigh : public Primitive {
|
||||
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
|
||||
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
auto state() const {
|
||||
return std::make_pair(uplo_, compute_eigenvectors_);
|
||||
}
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
|
||||
|
@ -21,6 +21,10 @@ void set_default_stream(Stream s) {
|
||||
return scheduler::scheduler().set_default_stream(s);
|
||||
}
|
||||
|
||||
Stream get_stream(int index) {
|
||||
return scheduler::scheduler().get_stream(index);
|
||||
}
|
||||
|
||||
Stream new_stream(Device d) {
|
||||
if (!metal::is_available() && d == Device::gpu) {
|
||||
throw std::invalid_argument(
|
||||
|
@ -96,6 +96,9 @@ class Scheduler {
|
||||
Stream get_default_stream(const Device& d) const {
|
||||
return default_streams_.at(d.type);
|
||||
}
|
||||
Stream get_stream(int index) const {
|
||||
return streams_.at(index)->stream;
|
||||
}
|
||||
|
||||
void set_default_stream(const Stream& s) {
|
||||
default_streams_.at(s.device.type) = s;
|
||||
|
@ -21,6 +21,9 @@ void set_default_stream(Stream s);
|
||||
/** Make a new stream on the given device. */
|
||||
Stream new_stream(Device d);
|
||||
|
||||
/** Get the stream with the given index. */
|
||||
Stream get_stream(int index);
|
||||
|
||||
inline bool operator==(const Stream& lhs, const Stream& rhs) {
|
||||
return lhs.index == rhs.index;
|
||||
}
|
||||
|
@ -10,6 +10,11 @@ namespace mlx::core {
|
||||
|
||||
void async_eval(std::vector<array> outputs);
|
||||
|
||||
template <typename... Arrays, typename = enable_for_arrays_t<Arrays...>>
|
||||
void async_eval(Arrays&&... outputs) {
|
||||
async_eval(std::vector<array>{std::forward<Arrays>(outputs)...});
|
||||
}
|
||||
|
||||
void eval(std::vector<array> outputs);
|
||||
|
||||
template <typename... Arrays, typename = enable_for_arrays_t<Arrays...>>
|
||||
|
@ -11,6 +11,7 @@ nanobind_add_module(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/convert.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/export.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
|
||||
|
@ -331,7 +331,7 @@ PyScalarT validate_shape(
|
||||
t = pycomplex;
|
||||
} else {
|
||||
std::ostringstream msg;
|
||||
msg << "Invalid type " << nb::type_name(l.type()).c_str()
|
||||
msg << "Invalid type " << nb::type_name(l.type()).c_str()
|
||||
<< " received in array initialization.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
277
python/src/export.cpp
Normal file
277
python/src/export.cpp
Normal file
@ -0,0 +1,277 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
#include <nanobind/nanobind.h>
|
||||
#include <nanobind/stl/map.h>
|
||||
#include <nanobind/stl/optional.h>
|
||||
#include <nanobind/stl/string.h>
|
||||
#include <nanobind/stl/vector.h>
|
||||
|
||||
#include <fstream>
|
||||
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/export.h"
|
||||
#include "mlx/graph_utils.h"
|
||||
#include "python/src/trees.h"
|
||||
|
||||
namespace mx = mlx::core;
|
||||
namespace nb = nanobind;
|
||||
using namespace nb::literals;
|
||||
|
||||
std::pair<std::vector<mx::array>, std::map<std::string, mx::array>>
|
||||
validate_and_extract_inputs(
|
||||
const nb::args& args,
|
||||
const nb::kwargs& kwargs,
|
||||
const std::string& prefix) {
|
||||
auto maybe_throw = [&prefix](bool valid) {
|
||||
if (!valid) {
|
||||
throw std::invalid_argument(
|
||||
prefix +
|
||||
" Inputs can either be a variable "
|
||||
"number of positional and keyword arrays or a single tuple "
|
||||
"and/or dictionary of arrays.");
|
||||
}
|
||||
};
|
||||
std::vector<mx::array> args_;
|
||||
std::map<std::string, mx::array> kwargs_;
|
||||
if (args.size() == 0) {
|
||||
// No args so kwargs must be keyword arrays
|
||||
maybe_throw(nb::try_cast(kwargs, kwargs_));
|
||||
} else if (args.size() > 0 && nb::isinstance<mx::array>(args[0])) {
|
||||
// Args are positional arrays and kwargs are keyword arrays
|
||||
maybe_throw(nb::try_cast(args, args_));
|
||||
maybe_throw(nb::try_cast(kwargs, kwargs_));
|
||||
} else if (args.size() == 1) {
|
||||
// - args[0] can be a tuple or list or arrays or a dict
|
||||
// with string keys and array values
|
||||
// - kwargs should be empty
|
||||
maybe_throw(kwargs.size() == 0);
|
||||
if (!nb::try_cast(args[0], args_)) {
|
||||
maybe_throw(nb::try_cast(args[0], kwargs_));
|
||||
}
|
||||
} else if (args.size() == 2) {
|
||||
// - args[0] can be a tuple or list of arrays
|
||||
// - args[1] can be a dict of string keys with array values.
|
||||
// - kwargs should be empty
|
||||
maybe_throw(kwargs.size() == 0);
|
||||
maybe_throw(nb::try_cast(args[0], args_));
|
||||
maybe_throw(nb::try_cast(args[1], kwargs_));
|
||||
} else {
|
||||
maybe_throw(false);
|
||||
}
|
||||
return {args_, kwargs_};
|
||||
}
|
||||
|
||||
auto wrap_export_function(const nb::callable& fun) {
|
||||
return [fun](
|
||||
const std::vector<mx::array>& args_,
|
||||
const std::map<std::string, mx::array>& kwargs_) {
|
||||
auto kwargs = nb::dict();
|
||||
kwargs.update(nb::cast(kwargs_));
|
||||
auto args = nb::tuple(nb::cast(args_));
|
||||
auto outputs = fun(*args, **kwargs);
|
||||
std::vector<mx::array> outputs_;
|
||||
if (nb::isinstance<mx::array>(outputs)) {
|
||||
outputs_.push_back(nb::cast<mx::array>(outputs));
|
||||
} else if (!nb::try_cast(outputs, outputs_)) {
|
||||
throw std::invalid_argument(
|
||||
"[export_function] Outputs can be either a single array "
|
||||
"a tuple or list of arrays.");
|
||||
}
|
||||
return outputs_;
|
||||
};
|
||||
}
|
||||
|
||||
void init_export(nb::module_& m) {
|
||||
m.def(
|
||||
"export_function",
|
||||
[](const std::string& file,
|
||||
const nb::callable& fun,
|
||||
const nb::args& args,
|
||||
bool shapeless,
|
||||
const nb::kwargs& kwargs) {
|
||||
auto [args_, kwargs_] =
|
||||
validate_and_extract_inputs(args, kwargs, "[export_function]");
|
||||
mx::export_function(
|
||||
file, wrap_export_function(fun), args_, kwargs_, shapeless);
|
||||
},
|
||||
"file"_a,
|
||||
"fun"_a,
|
||||
"args"_a,
|
||||
nb::kw_only(),
|
||||
"shapeless"_a = false,
|
||||
"kwargs"_a,
|
||||
R"pbdoc(
|
||||
Export a function to a file.
|
||||
|
||||
Example input arrays must be provided to export a function. The example
|
||||
inputs can be variable ``*args`` and ``**kwargs`` or a tuple of arrays
|
||||
and/or dictionary of string keys with array values.
|
||||
|
||||
.. warning::
|
||||
|
||||
This is part of an experimental API which is likely to
|
||||
change in future versions of MLX. Functions exported with older
|
||||
versions of MLX may not be compatible with future versions.
|
||||
|
||||
Args:
|
||||
file (str): File path to export the function to.
|
||||
fun (Callable): A function which takes as input zero or more
|
||||
:class:`array` and returns one or more :class:`array`.
|
||||
*args (array): Example array inputs to the function.
|
||||
shapeless (bool, optional): Whether or not the function allows
|
||||
inputs with variable shapes. Default: ``False``.
|
||||
**kwargs (array): Additional example keyword array inputs to the
|
||||
function.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
def fun(x, y):
|
||||
return x + y
|
||||
|
||||
x = mx.array(1)
|
||||
y = mx.array([1, 2, 3])
|
||||
mx.export_function("fun.mlxfn", fun, x, y=y)
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"import_function",
|
||||
[](const std::string& file) {
|
||||
return nb::cpp_function(
|
||||
[fn = mx::import_function(file)](
|
||||
const nb::args& args, const nb::kwargs& kwargs) {
|
||||
auto [args_, kwargs_] = validate_and_extract_inputs(
|
||||
args, kwargs, "[import_function::call]");
|
||||
return nb::tuple(nb::cast(fn(args_, kwargs_)));
|
||||
});
|
||||
},
|
||||
"file"_a,
|
||||
nb::sig("def import_function(file: str) -> Callable"),
|
||||
R"pbdoc(
|
||||
Import a function from a file.
|
||||
|
||||
The imported function can be called either with ``*args`` and
|
||||
``**kwargs`` or with a tuple of arrays and/or dictionary of string
|
||||
keys with array values. Imported functions always return a tuple of
|
||||
arrays.
|
||||
|
||||
.. warning::
|
||||
|
||||
This is part of an experimental API which is likely to
|
||||
change in future versions of MLX. Functions exported with older
|
||||
versions of MLX may not be compatible with future versions.
|
||||
|
||||
Args:
|
||||
file (str): The file path to import the function from.
|
||||
|
||||
Returns:
|
||||
Callable: The imported function.
|
||||
|
||||
Example:
|
||||
>>> fn = mx.import_function("function.mlxfn")
|
||||
>>> out = fn(a, b, x=x, y=y)[0]
|
||||
>>>
|
||||
>>> out = fn((a, b), {"x": x, "y": y}[0]
|
||||
)pbdoc");
|
||||
|
||||
nb::class_<mx::FunctionExporter>(
|
||||
m,
|
||||
"FunctionExporter",
|
||||
R"pbdoc(
|
||||
A context managing class for exporting multiple traces of the same
|
||||
function to a file.
|
||||
|
||||
Make an instance of this class by calling fun:`mx.exporter`.
|
||||
)pbdoc")
|
||||
.def("close", &mx::FunctionExporter::close)
|
||||
.def(
|
||||
"__enter__", [](mx::FunctionExporter& exporter) { return &exporter; })
|
||||
.def(
|
||||
"__exit__",
|
||||
[](mx::FunctionExporter& exporter,
|
||||
const std::optional<nb::object>&,
|
||||
const std::optional<nb::object>&,
|
||||
const std::optional<nb::object>&) { exporter.close(); },
|
||||
"exc_type"_a = nb::none(),
|
||||
"exc_value"_a = nb::none(),
|
||||
"traceback"_a = nb::none())
|
||||
.def(
|
||||
"__call__",
|
||||
[](mx::FunctionExporter& exporter,
|
||||
const nb::args& args,
|
||||
const nb::kwargs& kwargs) {
|
||||
auto [args_, kwargs_] =
|
||||
validate_and_extract_inputs(args, kwargs, "[export_function]");
|
||||
exporter(args_, kwargs_);
|
||||
});
|
||||
|
||||
m.def(
|
||||
"exporter",
|
||||
[](const std::string& file, const nb::callable& fun, bool shapeless) {
|
||||
return mx::exporter(file, wrap_export_function(fun), shapeless);
|
||||
},
|
||||
"file"_a,
|
||||
"fun"_a,
|
||||
nb::kw_only(),
|
||||
"shapeless"_a = false,
|
||||
R"pbdoc(
|
||||
Make a callable object to export multiple traces of a function to a file.
|
||||
|
||||
.. warning::
|
||||
|
||||
This is part of an experimental API which is likely to
|
||||
change in future versions of MLX. Functions exported with older
|
||||
versions of MLX may not be compatible with future versions.
|
||||
|
||||
Args:
|
||||
file (str): File path to export the function to.
|
||||
shapeless (bool, optional): Whether or not the function allows
|
||||
inputs with variable shapes. Default: ``False``.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
def fun(*args):
|
||||
return sum(args)
|
||||
|
||||
with mx.exporter("fun.mlxfn", fun) as exporter:
|
||||
exporter(mx.array(1))
|
||||
exporter(mx.array(1), mx.array(2))
|
||||
exporter(mx.array(1), mx.array(2), mx.array(3))
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"export_to_dot",
|
||||
[](nb::object file, const nb::args& args) {
|
||||
std::vector<mx::array> arrays = tree_flatten(args);
|
||||
if (nb::isinstance<nb::str>(file)) {
|
||||
std::ofstream out(nb::cast<std::string>(file));
|
||||
mx::export_to_dot(out, arrays);
|
||||
} else if (nb::hasattr(file, "write")) {
|
||||
std::ostringstream out;
|
||||
mx::export_to_dot(out, arrays);
|
||||
auto write = file.attr("write");
|
||||
write(out.str());
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"[export_to_dot] Accepts file-like objects or strings "
|
||||
"to be used as filenames.");
|
||||
}
|
||||
},
|
||||
"file"_a,
|
||||
"args"_a,
|
||||
R"pbdoc(
|
||||
Export a graph to DOT format for visualization.
|
||||
|
||||
A variable number of output arrays can be provided for exporting
|
||||
The graph exported will recursively include all enevaluated inputs of
|
||||
the provided outputs.
|
||||
|
||||
Args:
|
||||
file (str): The file path to export to.
|
||||
*args (array): The output arrays.
|
||||
|
||||
Example:
|
||||
>>> a = mx.array(1) + mx.array(2)
|
||||
>>> mx.export_to_dot("graph.dot", a)
|
||||
)pbdoc");
|
||||
}
|
@ -19,6 +19,7 @@ void init_linalg(nb::module_&);
|
||||
void init_constants(nb::module_&);
|
||||
void init_fast(nb::module_&);
|
||||
void init_distributed(nb::module_&);
|
||||
void init_export(nb::module_&);
|
||||
|
||||
NB_MODULE(core, m) {
|
||||
m.doc() = "mlx: A framework for machine learning on Apple silicon.";
|
||||
@ -39,6 +40,7 @@ NB_MODULE(core, m) {
|
||||
init_constants(m);
|
||||
init_fast(m);
|
||||
init_distributed(m);
|
||||
init_export(m);
|
||||
|
||||
m.attr("__version__") = TOSTRING(_VERSION_);
|
||||
}
|
||||
|
@ -2898,7 +2898,7 @@ void init_ops(nb::module_& m) {
|
||||
Generate multidimensional coordinate grids from 1-D coordinate arrays
|
||||
|
||||
Args:
|
||||
arrays (array): Input arrays.
|
||||
*arrays (array): Input arrays.
|
||||
sparse (bool, optional): If ``True``, a sparse grid is returned in which each output
|
||||
array has a single non-zero element. If ``False``, a dense grid is returned.
|
||||
Defaults to ``False``.
|
||||
@ -3840,8 +3840,8 @@ void init_ops(nb::module_& m) {
|
||||
|
||||
Args:
|
||||
file (file, str): Path to file to which the arrays are saved.
|
||||
args (arrays): Arrays to be saved.
|
||||
kwargs (arrays): Arrays to be saved. Each array will be saved
|
||||
*args (arrays): Arrays to be saved.
|
||||
**kwargs (arrays): Arrays to be saved. Each array will be saved
|
||||
with the associated keyword as the output file name.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
|
@ -8,14 +8,12 @@
|
||||
#include <nanobind/stl/vector.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <fstream>
|
||||
#include <numeric>
|
||||
#include <sstream>
|
||||
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/compile.h"
|
||||
#include "mlx/compile_impl.h"
|
||||
#include "mlx/graph_utils.h"
|
||||
#include "mlx/transforms.h"
|
||||
#include "mlx/transforms_impl.h"
|
||||
#include "mlx/utils.h"
|
||||
@ -945,34 +943,34 @@ void init_transforms(nb::module_& m) {
|
||||
Note, all custom transformations are optional. Undefined transformations
|
||||
fall back to the default behaviour.
|
||||
|
||||
Example usage:
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
.. code-block:: python
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.core as mx
|
||||
|
||||
@mx.custom_function
|
||||
def f(x, y):
|
||||
return mx.sin(x) * y
|
||||
@mx.custom_function
|
||||
def f(x, y):
|
||||
return mx.sin(x) * y
|
||||
|
||||
@f.vjp
|
||||
def f_vjp(primals, cotangent, output):
|
||||
@f.vjp
|
||||
def f_vjp(primals, cotangent, output):
|
||||
x, y = primals
|
||||
return cotan * mx.cos(x) * y, cotan * mx.sin(x)
|
||||
|
||||
@f.jvp
|
||||
def f_jvp(primals, tangents):
|
||||
x, y = primals
|
||||
return cotan * mx.cos(x) * y, cotan * mx.sin(x)
|
||||
dx, dy = tangents
|
||||
return dx * mx.cos(x) * y + dy * mx.sin(x)
|
||||
|
||||
@f.jvp
|
||||
def f_jvp(primals, tangents):
|
||||
x, y = primals
|
||||
dx, dy = tangents
|
||||
return dx * mx.cos(x) * y + dy * mx.sin(x)
|
||||
|
||||
@f.vmap
|
||||
def f_vmap(inputs, axes):
|
||||
x, y = inputs
|
||||
ax, ay = axes
|
||||
if ay != ax and ax is not None:
|
||||
y = y.swapaxes(ay, ax)
|
||||
return mx.sin(x) * y, (ax or ay)
|
||||
@f.vmap
|
||||
def f_vmap(inputs, axes):
|
||||
x, y = inputs
|
||||
ax, ay = axes
|
||||
if ay != ax and ax is not None:
|
||||
y = y.swapaxes(ay, ax)
|
||||
return mx.sin(x) * y, (ax or ay)
|
||||
)pbdoc")
|
||||
.def(
|
||||
nb::init<nb::callable>(),
|
||||
@ -1313,25 +1311,6 @@ void init_transforms(nb::module_& m) {
|
||||
Returns:
|
||||
Callable: The vectorized function.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"export_to_dot",
|
||||
[](nb::object file, const nb::args& args) {
|
||||
std::vector<mx::array> arrays = tree_flatten(args);
|
||||
if (nb::isinstance<nb::str>(file)) {
|
||||
std::ofstream out(nb::cast<std::string>(file));
|
||||
export_to_dot(out, arrays);
|
||||
} else if (nb::hasattr(file, "write")) {
|
||||
std::ostringstream out;
|
||||
export_to_dot(out, arrays);
|
||||
auto write = file.attr("write");
|
||||
write(out.str());
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"export_to_dot accepts file-like objects or strings to be used as filenames");
|
||||
}
|
||||
},
|
||||
"file"_a,
|
||||
"args"_a);
|
||||
m.def(
|
||||
"compile",
|
||||
[](const nb::callable& fun,
|
||||
|
244
python/tests/test_export_import.py
Normal file
244
python/tests/test_export_import.py
Normal file
@ -0,0 +1,244 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx_tests
|
||||
|
||||
|
||||
class TestExportImport(mlx_tests.MLXTestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.test_dir_fid = tempfile.TemporaryDirectory()
|
||||
cls.test_dir = cls.test_dir_fid.name
|
||||
if not os.path.isdir(cls.test_dir):
|
||||
os.mkdir(cls.test_dir)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
cls.test_dir_fid.cleanup()
|
||||
|
||||
def test_basic_export_import(self):
|
||||
path = os.path.join(self.test_dir, "fn.mlxfn")
|
||||
|
||||
# Function with no inputs
|
||||
def fun():
|
||||
return mx.zeros((3, 3))
|
||||
|
||||
mx.export_function(path, fun)
|
||||
imported = mx.import_function(path)
|
||||
|
||||
expected = fun()
|
||||
(out,) = imported()
|
||||
self.assertTrue(mx.array_equal(out, expected))
|
||||
|
||||
# Simple function with inputs
|
||||
def fun(x):
|
||||
return mx.abs(mx.sin(x))
|
||||
|
||||
inputs = mx.array([1.0, 2.0, 3.0, 4.0, 5.0])
|
||||
|
||||
mx.export_function(path, fun, inputs)
|
||||
imported = mx.import_function(path)
|
||||
|
||||
expected = fun(inputs)
|
||||
(out,) = imported(inputs)
|
||||
self.assertTrue(mx.allclose(out, expected))
|
||||
|
||||
# Inputs in a list or tuple
|
||||
def fun(x):
|
||||
x = mx.abs(mx.sin(x))
|
||||
return x
|
||||
|
||||
mx.export_function(path, fun, [inputs])
|
||||
imported = mx.import_function(path)
|
||||
|
||||
expected = fun(inputs)
|
||||
(out,) = imported([inputs])
|
||||
self.assertTrue(mx.allclose(out, expected))
|
||||
|
||||
(out,) = imported(inputs)
|
||||
self.assertTrue(mx.allclose(out, expected))
|
||||
|
||||
mx.export_function(path, fun, (inputs,))
|
||||
imported = mx.import_function(path)
|
||||
(out,) = imported((inputs,))
|
||||
self.assertTrue(mx.allclose(out, expected))
|
||||
|
||||
# Outputs in a list
|
||||
def fun(x):
|
||||
return [mx.abs(mx.sin(x))]
|
||||
|
||||
mx.export_function(path, fun, inputs)
|
||||
imported = mx.import_function(path)
|
||||
(out,) = imported(inputs)
|
||||
self.assertTrue(mx.allclose(out, expected))
|
||||
|
||||
# Outputs in a tuple
|
||||
def fun(x):
|
||||
return (mx.abs(mx.sin(x)),)
|
||||
|
||||
mx.export_function(path, fun, inputs)
|
||||
imported = mx.import_function(path)
|
||||
(out,) = imported(inputs)
|
||||
self.assertTrue(mx.allclose(out, expected))
|
||||
|
||||
# Check throws on invalid inputs / outputs
|
||||
def fun(x):
|
||||
return mx.abs(x)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
mx.export_function(path, fun, "hi")
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
mx.export_function(path, fun, mx.array(1.0), "hi")
|
||||
|
||||
def fun(x):
|
||||
return mx.abs(x[0][0])
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
mx.export_function(path, fun, [[mx.array(1.0)]])
|
||||
|
||||
def fun():
|
||||
return (mx.zeros((3, 3)), 1)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
mx.export_function(path, fun)
|
||||
|
||||
def fun():
|
||||
return (mx.zeros((3, 3)), [mx.zeros((3, 3))])
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
mx.export_function(path, fun)
|
||||
|
||||
def fun(x, y):
|
||||
return x + y
|
||||
|
||||
mx.export_function(path, fun, mx.array(1.0), mx.array(1.0))
|
||||
imported = mx.import_function(path)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
imported(mx.array(1.0), 1.0)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
imported(mx.array(1.0), mx.array(1.0), mx.array(1.0))
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
imported(mx.array(1.0), [mx.array(1.0)])
|
||||
|
||||
def test_export_random_sample(self):
|
||||
path = os.path.join(self.test_dir, "fn.mlxfn")
|
||||
|
||||
mx.random.seed(5)
|
||||
|
||||
def fun():
|
||||
return mx.random.uniform(shape=(3,))
|
||||
|
||||
mx.export_function(path, fun)
|
||||
imported = mx.import_function(path)
|
||||
|
||||
(out,) = imported()
|
||||
|
||||
mx.random.seed(5)
|
||||
expected = fun()
|
||||
|
||||
self.assertTrue(mx.array_equal(out, expected))
|
||||
|
||||
def test_export_with_kwargs(self):
|
||||
path = os.path.join(self.test_dir, "fn.mlxfn")
|
||||
|
||||
def fun(x, z=None):
|
||||
out = x
|
||||
if z is not None:
|
||||
out += z
|
||||
return out
|
||||
|
||||
x = mx.array([1, 2, 3])
|
||||
y = mx.array([1, 1, 0])
|
||||
z = mx.array([2, 2, 2])
|
||||
|
||||
mx.export_function(path, fun, (x,), {"z": z})
|
||||
imported_fun = mx.import_function(path)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
imported_fun(x, z)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
imported_fun(x, y=z)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
imported_fun((x,), {"y": z})
|
||||
|
||||
out = imported_fun(x, z=z)[0]
|
||||
self.assertTrue(mx.array_equal(out, mx.array([3, 4, 5])))
|
||||
|
||||
out = imported_fun((x,), {"z": z})[0]
|
||||
self.assertTrue(mx.array_equal(out, mx.array([3, 4, 5])))
|
||||
|
||||
mx.export_function(path, fun, x, z=z)
|
||||
imported_fun = mx.import_function(path)
|
||||
out = imported_fun(x, z=z)[0]
|
||||
self.assertTrue(mx.array_equal(out, mx.array([3, 4, 5])))
|
||||
|
||||
out = imported_fun((x,), {"z": z})[0]
|
||||
self.assertTrue(mx.array_equal(out, mx.array([3, 4, 5])))
|
||||
|
||||
# Only specify kwargs
|
||||
mx.export_function(path, fun, x=x, z=z)
|
||||
imported_fun = mx.import_function(path)
|
||||
with self.assertRaises(ValueError):
|
||||
out = imported_fun(x, z=z)[0]
|
||||
|
||||
out = imported_fun(x=x, z=z)[0]
|
||||
self.assertTrue(mx.array_equal(out, mx.array([3, 4, 5])))
|
||||
|
||||
out = imported_fun({"x": x, "z": z})[0]
|
||||
self.assertTrue(mx.array_equal(out, mx.array([3, 4, 5])))
|
||||
|
||||
def test_export_variable_inputs(self):
|
||||
path = os.path.join(self.test_dir, "fn.mlxfn")
|
||||
|
||||
def fun(x, y, z=None):
|
||||
out = x + y
|
||||
if z is not None:
|
||||
out += z
|
||||
return out
|
||||
|
||||
with mx.exporter(path, fun) as exporter:
|
||||
exporter(mx.array([1, 2, 3]), mx.array([1, 1, 1]))
|
||||
exporter(mx.array([1, 2, 3]), mx.array([1, 1, 1]), z=mx.array([2]))
|
||||
|
||||
with self.assertRaises(RuntimeError):
|
||||
exporter(mx.array([1, 2, 3, 4]), mx.array([1, 1, 1, 1]))
|
||||
|
||||
imported_fun = mx.import_function(path)
|
||||
out = imported_fun(mx.array([1, 2, 3]), mx.array([1, 1, 1]))[0]
|
||||
self.assertTrue(mx.array_equal(out, mx.array([2, 3, 4])))
|
||||
|
||||
out = imported_fun(mx.array([1, 2, 3]), mx.array([1, 1, 1]), z=mx.array([2]))[0]
|
||||
self.assertTrue(mx.array_equal(out, mx.array([4, 5, 6])))
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
imported_fun(mx.array([1, 2, 3, 4]), mx.array([1, 1, 1, 1]))
|
||||
|
||||
# A function with a large constant
|
||||
constant = mx.zeros((16, 2048))
|
||||
mx.eval(constant)
|
||||
|
||||
def fun(*args):
|
||||
return constant + sum(args)
|
||||
|
||||
with mx.exporter(path, fun) as exporter:
|
||||
for i in range(5):
|
||||
exporter(*[mx.array(1)] * i)
|
||||
|
||||
# Check the exported file size < constant size + small amount
|
||||
constants_size = constant.nbytes + 8192
|
||||
self.assertTrue(os.path.getsize(path) < constants_size)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
@ -28,15 +28,14 @@ class TestLoad(mlx_tests.MLXTestCase):
|
||||
def setUpClass(cls):
|
||||
cls.test_dir_fid = tempfile.TemporaryDirectory()
|
||||
cls.test_dir = cls.test_dir_fid.name
|
||||
if not os.path.isdir(cls.test_dir):
|
||||
os.mkdir(cls.test_dir)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
cls.test_dir_fid.cleanup()
|
||||
|
||||
def test_save_and_load(self):
|
||||
if not os.path.isdir(self.test_dir):
|
||||
os.mkdir(self.test_dir)
|
||||
|
||||
for dt in self.dtypes:
|
||||
with self.subTest(dtype=dt):
|
||||
for i, shape in enumerate([(1,), (23,), (1024, 1024), (4, 6, 3, 1, 2)]):
|
||||
@ -64,9 +63,6 @@ class TestLoad(mlx_tests.MLXTestCase):
|
||||
self.assertTrue(np.array_equal(load_arr_mlx_npy, save_arr_npy))
|
||||
|
||||
def test_save_and_load_safetensors(self):
|
||||
if not os.path.isdir(self.test_dir):
|
||||
os.mkdir(self.test_dir)
|
||||
|
||||
test_file = os.path.join(self.test_dir, "test.safetensors")
|
||||
with self.assertRaises(Exception):
|
||||
mx.save_safetensors(test_file, {"a": mx.ones((4, 4))}, {"testing": 0})
|
||||
@ -330,9 +326,6 @@ class TestLoad(mlx_tests.MLXTestCase):
|
||||
self.assertTrue(np.array_equal(save_arrs_npy[k], v))
|
||||
|
||||
def test_non_contiguous(self):
|
||||
if not os.path.isdir(self.test_dir):
|
||||
os.mkdir(self.test_dir)
|
||||
|
||||
a = mx.broadcast_to(mx.array([1, 2]), [4, 2])
|
||||
|
||||
save_file = os.path.join(self.test_dir, "a.npy")
|
||||
|
@ -24,6 +24,7 @@ target_sources(
|
||||
creations_tests.cpp
|
||||
device_tests.cpp
|
||||
einsum_tests.cpp
|
||||
export_import_tests.cpp
|
||||
eval_tests.cpp
|
||||
fft_tests.cpp
|
||||
load_tests.cpp
|
||||
|
165
tests/export_import_tests.cpp
Normal file
165
tests/export_import_tests.cpp
Normal file
@ -0,0 +1,165 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include <filesystem>
|
||||
#include <stdexcept>
|
||||
#include <vector>
|
||||
|
||||
#include "doctest/doctest.h"
|
||||
|
||||
#include "mlx/export.h"
|
||||
#include "mlx/mlx.h"
|
||||
|
||||
using namespace mlx::core;
|
||||
|
||||
namespace {
|
||||
std::string get_temp_file(const std::string& name) {
|
||||
return std::filesystem::temp_directory_path().append(name);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
TEST_CASE("test export basic functions") {
|
||||
std::string file_path = get_temp_file("model.mlxfn");
|
||||
|
||||
auto fun = [](std::vector<array> x) -> std::vector<array> {
|
||||
return {negative(exp(x[0]))};
|
||||
};
|
||||
|
||||
export_function(file_path, fun, {array({1.0, 2.0})});
|
||||
|
||||
auto imported_fun = import_function(file_path);
|
||||
|
||||
// Check num inputs mismatch throws
|
||||
CHECK_THROWS_AS(
|
||||
imported_fun({array({1.0}), array({2.0})}), std::invalid_argument);
|
||||
|
||||
// Check shape mismatch throws
|
||||
CHECK_THROWS_AS(imported_fun({array({1.0})}), std::invalid_argument);
|
||||
|
||||
// Check type mismatch throws
|
||||
CHECK_THROWS_AS(imported_fun({array({1.0}, float16)}), std::invalid_argument);
|
||||
|
||||
auto expected = fun({array({1.0, -1.0})});
|
||||
auto out = imported_fun({array({1.0, -1.0})});
|
||||
CHECK(allclose(expected[0], out[0]).item<bool>());
|
||||
}
|
||||
|
||||
TEST_CASE("test export function with no inputs") {
|
||||
auto fun = [](std::vector<array> x) -> std::vector<array> {
|
||||
return {zeros({2, 2})};
|
||||
};
|
||||
|
||||
std::string file_path = get_temp_file("model.mlxfn");
|
||||
|
||||
export_function(file_path, fun, {});
|
||||
|
||||
auto imported_fun = import_function(file_path);
|
||||
|
||||
auto expected = fun({});
|
||||
auto out = imported_fun({});
|
||||
CHECK(allclose(expected[0], out[0]).item<bool>());
|
||||
}
|
||||
|
||||
TEST_CASE("test export multi output primitives") {
|
||||
std::string file_path = get_temp_file("model.mlxfn");
|
||||
|
||||
auto fun = [](std::vector<array> x) -> std::vector<array> {
|
||||
return {divmod(x[0], x[1])};
|
||||
};
|
||||
|
||||
auto inputs = std::vector<array>{array({5.0, -10.0}), array({3.0, -2.0})};
|
||||
export_function(file_path, fun, inputs);
|
||||
|
||||
auto imported_fun = import_function(file_path);
|
||||
|
||||
auto expected = fun(inputs);
|
||||
auto out = imported_fun(inputs);
|
||||
CHECK(allclose(expected[0], out[0]).item<bool>());
|
||||
CHECK(allclose(expected[1], out[1]).item<bool>());
|
||||
}
|
||||
|
||||
TEST_CASE("test export primitives with state") {
|
||||
std::string file_path = get_temp_file("model.mlxfn");
|
||||
|
||||
auto fun = [](std::vector<array> x) -> std::vector<array> {
|
||||
return {argpartition(x[0], 2, 0)};
|
||||
};
|
||||
|
||||
auto x = array({1, 3, 2, 4, 5, 7, 6, 8}, {4, 2});
|
||||
export_function(file_path, fun, {x});
|
||||
|
||||
auto imported_fun = import_function(file_path);
|
||||
|
||||
auto expected = fun({x});
|
||||
auto out = imported_fun({x});
|
||||
CHECK(allclose(expected[0], out[0]).item<bool>());
|
||||
}
|
||||
|
||||
TEST_CASE("test export functions with kwargs") {
|
||||
std::string file_path = get_temp_file("model.mlxfn");
|
||||
|
||||
auto fun =
|
||||
[](const std::map<std::string, array>& kwargs) -> std::vector<array> {
|
||||
return {kwargs.at("x") + kwargs.at("y")};
|
||||
};
|
||||
|
||||
export_function(file_path, fun, {{"x", array(1)}, {"y", array(2)}});
|
||||
auto fn = import_function(file_path);
|
||||
|
||||
// Must use kwargs
|
||||
CHECK_THROWS(fn({array(1), array(2)}));
|
||||
|
||||
// Wrong number of keys
|
||||
CHECK_THROWS(fn({{"x", array(1)}, {"y", array(2)}, {"z", array(3)}}));
|
||||
|
||||
// Wrong keys
|
||||
CHECK_THROWS(fn({{"a", array(1)}, {"b", array(2)}}));
|
||||
|
||||
// Works
|
||||
auto out = fn({{"x", array(1)}, {"y", array(2)}})[0];
|
||||
CHECK_EQ(out.item<int>(), 3);
|
||||
out = fn({}, {{"x", array(1)}, {"y", array(2)}})[0];
|
||||
CHECK_EQ(out.item<int>(), 3);
|
||||
}
|
||||
|
||||
TEST_CASE("test export function with variable inputs") {
|
||||
std::string file_path = get_temp_file("model.mlxfn");
|
||||
|
||||
auto fun = [](const std::vector<array>& args) -> std::vector<array> {
|
||||
auto out = array({1, 1, 1, 1});
|
||||
for (auto x : args) {
|
||||
out = out + x;
|
||||
}
|
||||
return {out};
|
||||
};
|
||||
|
||||
{
|
||||
auto fn_exporter = exporter(file_path, fun);
|
||||
fn_exporter({array(0), array(0)});
|
||||
fn_exporter({array(0), array(0), array(0)});
|
||||
}
|
||||
|
||||
auto imported_fun = import_function(file_path);
|
||||
|
||||
// Call with two inputs
|
||||
auto out = imported_fun({array(1), array(2)})[0];
|
||||
|
||||
CHECK(array_equal(out, array({4, 4, 4, 4})).item<bool>());
|
||||
|
||||
// Call with three inputs
|
||||
out = imported_fun({array(1), array(2), array(3)})[0];
|
||||
CHECK(array_equal(out, array({7, 7, 7, 7})).item<bool>());
|
||||
}
|
||||
|
||||
TEST_CASE("test export function on different stream") {
|
||||
std::string file_path = get_temp_file("model.mlxfn");
|
||||
|
||||
// Caller is responsible for setting up streams before
|
||||
// importing functoins
|
||||
auto fun = [](const std::vector<array>& args) -> std::vector<array> {
|
||||
return {abs(args[0], Stream(1000, Device::cpu))};
|
||||
};
|
||||
|
||||
export_function(file_path, fun, {array({0, 1, 2})});
|
||||
|
||||
CHECK_THROWS(import_function(file_path));
|
||||
}
|
Loading…
Reference in New Issue
Block a user