mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
Export / import functions to / from a file (#1642)
* export and import functions * refactor + works for few primitives * nit * allow primitives with state * nit * nit * simplify serialize / deserialize * fix for constants * python bindings * maybe fix serialize failure case * add example * more primitives, training kind of works * same result for python and c++ * some fixes * fix export * template it up * some simplificatoin * rebase * allow kwargs and multiple functions * exporter * more primitives for exporting * deal with endianness * handle invalid stream * add docstring
This commit is contained in:
parent
935c8c4bb1
commit
4ba0c24a8f
@ -27,6 +27,7 @@ option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
|
|||||||
if(NOT MLX_VERSION)
|
if(NOT MLX_VERSION)
|
||||||
set(MLX_VERSION 0.21.1)
|
set(MLX_VERSION 0.21.1)
|
||||||
endif()
|
endif()
|
||||||
|
add_compile_definitions("MLX_VERSION=${MLX_VERSION}")
|
||||||
|
|
||||||
# --------------------- Processor tests -------------------------
|
# --------------------- Processor tests -------------------------
|
||||||
|
|
||||||
|
@ -61,6 +61,7 @@ are the CPU and GPU.
|
|||||||
python/array
|
python/array
|
||||||
python/data_types
|
python/data_types
|
||||||
python/devices_and_streams
|
python/devices_and_streams
|
||||||
|
python/export
|
||||||
python/ops
|
python/ops
|
||||||
python/random
|
python/random
|
||||||
python/transforms
|
python/transforms
|
||||||
|
14
docs/src/python/export.rst
Normal file
14
docs/src/python/export.rst
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
.. _export:
|
||||||
|
|
||||||
|
Export Functions
|
||||||
|
================
|
||||||
|
|
||||||
|
.. currentmodule:: mlx.core
|
||||||
|
|
||||||
|
.. autosummary::
|
||||||
|
:toctree: _autosummary
|
||||||
|
|
||||||
|
export_function
|
||||||
|
import_function
|
||||||
|
exporter
|
||||||
|
export_to_dot
|
27
examples/export/CMakeLists.txt
Normal file
27
examples/export/CMakeLists.txt
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
cmake_minimum_required(VERSION 3.27)

project(import_mlx LANGUAGES CXX)

# ----------------------------- Setup -----------------------------
# Build as C++17 (required, no fallback) with position independent code.
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_POSITION_INDEPENDENT_CODE ON)

# ----------------------------- Dependencies -----------------------------
# The Python interpreter is used below to ask pip where MLX is installed.
find_package(Python 3.9 REQUIRED COMPONENTS Interpreter Development.Module)

# Set MLX_ROOT to "<4th field of the matching `pip show mlx` line>/mlx".
# NOTE(review): `grep location` is case-sensitive, so it looks like it only
# matches an "Editable project location: <path>" line and not the regular
# "Location: <path>" line -- confirm MLX_ROOT is non-empty for a plain
# `pip install mlx`.
execute_process(
  COMMAND "${Python_EXECUTABLE}" -m pip show mlx
  COMMAND grep location
  COMMAND awk "{print $4 \"/mlx\"}"
  OUTPUT_STRIP_TRAILING_WHITESPACE
  OUTPUT_VARIABLE MLX_ROOT)
find_package(MLX CONFIG REQUIRED)

# One executable per example, each linked against mlx.
add_executable(eval_mlp eval_mlp.cpp)
target_link_libraries(eval_mlp PRIVATE mlx)

add_executable(train_mlp train_mlp.cpp)
target_link_libraries(train_mlp PRIVATE mlx)
|
49
examples/export/README.md
Normal file
49
examples/export/README.md
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
## Setup
|
||||||
|
|
||||||
|
Install MLX:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install "mlx>=0.22"
|
||||||
|
```
|
||||||
|
|
||||||
|
Build the C++ examples:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cmake -B build -DCMAKE_BUILD_TYPE=Release
|
||||||
|
cmake --build build
|
||||||
|
```
|
||||||
|
|
||||||
|
## Run
|
||||||
|
|
||||||
|
### Eval MLP
|
||||||
|
|
||||||
|
Run the Python script to export the eval function:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python eval_mlp.py
|
||||||
|
```
|
||||||
|
|
||||||
|
Then run the C++ program to import and run the function:
|
||||||
|
|
||||||
|
```
|
||||||
|
./build/eval_mlp
|
||||||
|
```
|
||||||
|
|
||||||
|
The Python and C++ programs should output the same result.
|
||||||
|
|
||||||
|
### Train MLP
|
||||||
|
|
||||||
|
Run the Python script to export the model initialization and training
|
||||||
|
functions:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python train_mlp.py
|
||||||
|
```
|
||||||
|
|
||||||
|
Then run the C++ program to import and run the functions:
|
||||||
|
|
||||||
|
```
|
||||||
|
./build/train_mlp
|
||||||
|
```
|
||||||
|
|
||||||
|
The Python and C++ programs should output the same results.
|
25
examples/export/eval_mlp.cpp
Normal file
25
examples/export/eval_mlp.cpp
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
// Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
|
#include <mlx/mlx.h>
|
||||||
|
#include <iostream>
|
||||||
|
|
||||||
|
using namespace mlx::core;
|
||||||
|
|
||||||
|
int main() {
|
||||||
|
int batch_size = 8;
|
||||||
|
int input_dim = 32;
|
||||||
|
|
||||||
|
// Make the input
|
||||||
|
random::seed(42);
|
||||||
|
auto example_x = random::uniform({batch_size, input_dim});
|
||||||
|
|
||||||
|
// Import the function
|
||||||
|
auto forward = import_function("eval_mlp.mlxfn");
|
||||||
|
|
||||||
|
// Call the imported function
|
||||||
|
auto out = forward({example_x})[0];
|
||||||
|
|
||||||
|
std::cout << out << std::endl;
|
||||||
|
|
||||||
|
return 0;
|
||||||
|
}
|
52
examples/export/eval_mlp.py
Normal file
52
examples/export/eval_mlp.py
Normal file
@ -0,0 +1,52 @@
|
|||||||
|
# Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
import mlx.utils
|
||||||
|
|
||||||
|
|
||||||
|
class MLP(nn.Module):
    """A simple MLP: a stack of ``nn.Linear`` layers with ReLU in between.

    The stack has ``num_layers + 1`` linear layers going from ``input_dim``
    through ``num_layers`` hidden widths of ``hidden_dim`` to ``output_dim``.
    """

    def __init__(
        self, num_layers: int, input_dim: int, hidden_dim: int, output_dim: int
    ):
        super().__init__()
        dims = [input_dim, *([hidden_dim] * num_layers), output_dim]
        # One linear layer per pair of consecutive widths, in order.
        self.layers = [nn.Linear(dims[i], dims[i + 1]) for i in range(len(dims) - 1)]

    def __call__(self, x):
        # ReLU follows every layer except the last one.
        *hidden, last = self.layers
        for layer in hidden:
            x = nn.relu(layer(x))
        return last(x)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    batch_size = 8
    input_dim = 32
    output_dim = 10

    # Build the model from a fixed seed and evaluate it before exporting.
    mx.random.seed(0)  # Seed for params
    model = MLP(num_layers=5, input_dim=input_dim, hidden_dim=64, output_dim=output_dim)
    mx.eval(model)

    # `forward` closes over `model`. Note, the model parameters are saved in
    # the export function.
    def forward(x):
        return model(x)

    # Example input used to trace the function for export.
    mx.random.seed(42)  # Seed for input
    example_x = mx.random.uniform(shape=(batch_size, input_dim))

    mx.export_function("eval_mlp.mlxfn", forward, example_x)

    # Re-import in Python and check it agrees with the original function.
    reloaded = mx.import_function("eval_mlp.mlxfn")
    expected = forward(example_x)
    [actual] = reloaded(example_x)
    assert mx.allclose(expected, actual)
    print(actual)
|
35
examples/export/train_mlp.cpp
Normal file
35
examples/export/train_mlp.cpp
Normal file
@ -0,0 +1,35 @@
|
|||||||
|
// Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
|
#include <mlx/mlx.h>
|
||||||
|
#include <iostream>
|
||||||
|
|
||||||
|
using namespace mlx::core;
|
||||||
|
|
||||||
|
int main() {
|
||||||
|
int batch_size = 8;
|
||||||
|
int input_dim = 32;
|
||||||
|
int output_dim = 10;
|
||||||
|
|
||||||
|
auto state = import_function("init_mlp.mlxfn")({});
|
||||||
|
|
||||||
|
// Make the input
|
||||||
|
random::seed(42);
|
||||||
|
auto example_X = random::normal({batch_size, input_dim});
|
||||||
|
auto example_y = random::randint(0, output_dim, {batch_size});
|
||||||
|
|
||||||
|
// Import the function
|
||||||
|
auto step = import_function("train_mlp.mlxfn");
|
||||||
|
|
||||||
|
// Call the imported function
|
||||||
|
for (int it = 0; it < 100; ++it) {
|
||||||
|
state.insert(state.end(), {example_X, example_y});
|
||||||
|
state = step(state);
|
||||||
|
eval(state);
|
||||||
|
auto loss = state.back();
|
||||||
|
state.pop_back();
|
||||||
|
if (it % 10 == 0) {
|
||||||
|
std::cout << "Loss " << loss.item<float>() << std::endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return 0;
|
||||||
|
}
|
76
examples/export/train_mlp.py
Normal file
76
examples/export/train_mlp.py
Normal file
@ -0,0 +1,76 @@
|
|||||||
|
# Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
import mlx.optimizers as optim
|
||||||
|
import mlx.utils
|
||||||
|
|
||||||
|
|
||||||
|
class MLP(nn.Module):
    """A simple MLP: a stack of ``nn.Linear`` layers with ReLU in between.

    The stack has ``num_layers + 1`` linear layers going from ``input_dim``
    through ``num_layers`` hidden widths of ``hidden_dim`` to ``output_dim``.
    """

    def __init__(
        self, num_layers: int, input_dim: int, hidden_dim: int, output_dim: int
    ):
        super().__init__()
        dims = [input_dim, *([hidden_dim] * num_layers), output_dim]
        # One linear layer per pair of consecutive widths, in order.
        self.layers = [nn.Linear(dims[i], dims[i + 1]) for i in range(len(dims) - 1)]

    def __call__(self, x):
        # ReLU follows every layer except the last one.
        *hidden, last = self.layers
        for layer in hidden:
            x = nn.relu(layer(x))
        return last(x)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    batch_size = 8
    input_dim = 32
    output_dim = 10

    def init():
        """Create the model and optimizer and flatten their state.

        Returns the model, the optimizer, the flattened tree keys and the
        flattened state values (parameters followed by optimizer state).
        """
        # Seed for the parameter initialization
        mx.random.seed(0)
        model = MLP(
            num_layers=3, input_dim=input_dim, hidden_dim=64, output_dim=output_dim
        )
        optimizer = optim.SGD(learning_rate=1e-1)
        optimizer.init(model.parameters())
        flat = mlx.utils.tree_flatten([model.parameters(), optimizer.state])
        tree_structure, state = zip(*flat)
        return model, optimizer, tree_structure, state

    # Export the model parameter initialization: the exported function takes
    # no inputs and returns only the flattened state.
    model, optimizer, tree_structure, state = init()
    mx.eval(state)
    mx.export_function("init_mlp.mlxfn", lambda: init()[-1])

    def loss_fn(params, X, y):
        model.update(params)
        return nn.losses.cross_entropy(model(X), y, reduction="mean")

    def step(*inputs):
        """One optimizer step on flat inputs ``(*state, X, y)``.

        Returns the updated flat state followed by the loss.
        """
        *flat_state, X, y = inputs
        # Rebuild the nested parameters / optimizer state from the flat list.
        params, opt_state = mlx.utils.tree_unflatten(
            list(zip(tree_structure, flat_state))
        )
        optimizer.state = opt_state
        loss, grads = mx.value_and_grad(loss_fn)(params, X, y)
        params = optimizer.apply_gradients(grads, params)
        _, new_state = zip(*mlx.utils.tree_flatten([params, optimizer.state]))
        return (*new_state, loss)

    # Make some random data
    mx.random.seed(42)
    example_X = mx.random.normal(shape=(batch_size, input_dim))
    example_y = mx.random.randint(low=0, high=output_dim, shape=(batch_size,))

    # Export one step of SGD
    mx.export_function("train_mlp.mlxfn", step, *state, example_X, example_y)

    # Import the exported step and run it for 100 iterations
    imported_step = mx.import_function("train_mlp.mlxfn")

    for it in range(100):
        *state, loss = imported_step(*state, example_X, example_y)
        if it % 10 == 0:
            print(f"Loss {loss.item():.6}")
|
@ -5,6 +5,7 @@ target_sources(
|
|||||||
${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/export.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/einsum.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/einsum.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
|
||||||
|
@ -156,9 +156,6 @@ CompileMode& compile_mode() {
|
|||||||
return compile_mode_;
|
return compile_mode_;
|
||||||
}
|
}
|
||||||
|
|
||||||
using ParentsMap =
|
|
||||||
std::unordered_map<std::uintptr_t, std::vector<std::pair<array, int>>>;
|
|
||||||
|
|
||||||
// Helper like below but only merges the two provided arrays. If the src has
|
// Helper like below but only merges the two provided arrays. If the src has
|
||||||
// siblings then these won't be merged to the dst.
|
// siblings then these won't be merged to the dst.
|
||||||
void merge_one(array& dst, array& src, ParentsMap& parents_map) {
|
void merge_one(array& dst, array& src, ParentsMap& parents_map) {
|
||||||
@ -732,10 +729,15 @@ std::vector<array> compile_replace(
|
|||||||
trace_to_real.insert({trace_inputs[i].id(), inputs[i]});
|
trace_to_real.insert({trace_inputs[i].id(), inputs[i]});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
auto is_load = [](const Primitive& p) { return typeid(p) == typeid(Load); };
|
||||||
|
|
||||||
for (auto& a : tape) {
|
for (auto& a : tape) {
|
||||||
// Arrays in the tape without primitives are constants
|
// Arrays in the tape without primitives are either:
|
||||||
// and can be used directly
|
// - inputs, which are already in the map
|
||||||
if (!a.has_primitive()) {
|
// - constants, which can be used directly
|
||||||
|
// - a load primitive which has no inputs and will become a constant
|
||||||
|
// after the first eval
|
||||||
|
if (!a.has_primitive() || is_load(a.primitive())) {
|
||||||
trace_to_real.insert({a.id(), a});
|
trace_to_real.insert({a.id(), a});
|
||||||
} else {
|
} else {
|
||||||
// Find real inputs
|
// Find real inputs
|
||||||
|
@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include "mlx/device.h"
|
#include "mlx/array.h"
|
||||||
|
|
||||||
namespace mlx::core::detail {
|
namespace mlx::core::detail {
|
||||||
|
|
||||||
@ -22,4 +22,35 @@ void compile_erase(std::uintptr_t fun_id);
|
|||||||
void compile_clear_cache();
|
void compile_clear_cache();
|
||||||
|
|
||||||
bool compile_available_for_device(const Device& device);
|
bool compile_available_for_device(const Device& device);
|
||||||
|
|
||||||
|
std::pair<std::vector<array>, std::vector<array>> compile_trace(
|
||||||
|
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
|
||||||
|
const std::vector<array>& inputs);
|
||||||
|
|
||||||
|
using ParentsMap =
|
||||||
|
std::unordered_map<std::uintptr_t, std::vector<std::pair<array, int>>>;
|
||||||
|
|
||||||
|
// Traverses the graph to build a tape and a map of array ids to their parents
|
||||||
|
std::pair<std::vector<array>, ParentsMap> compile_dfs(
|
||||||
|
const std::vector<array>& inputs,
|
||||||
|
const std::vector<array>& outputs,
|
||||||
|
const std::vector<array>& original_inputs);
|
||||||
|
|
||||||
|
// Simplify the tape. Note, this function modifies in-place both the tape and
|
||||||
|
// the parents map to remove orphaned arrays
|
||||||
|
void compile_simplify(
|
||||||
|
std::vector<array>& tape,
|
||||||
|
ParentsMap& parents_map,
|
||||||
|
const std::vector<array>& outputs,
|
||||||
|
int passes);
|
||||||
|
|
||||||
|
std::vector<array> compile_replace(
|
||||||
|
const std::vector<array>& tape,
|
||||||
|
const std::vector<array>& trace_inputs,
|
||||||
|
const std::vector<array>& trace_outputs,
|
||||||
|
const std::vector<array>& inputs,
|
||||||
|
bool shapeless);
|
||||||
|
|
||||||
|
void compile_validate_shapeless(const std::vector<array>& tape);
|
||||||
|
|
||||||
} // namespace mlx::core::detail
|
} // namespace mlx::core::detail
|
||||||
|
867
mlx/export.cpp
Normal file
867
mlx/export.cpp
Normal file
@ -0,0 +1,867 @@
|
|||||||
|
// Copyright © 2024 Apple Inc.
|
||||||
|
#include "mlx/export.h"
|
||||||
|
#include "mlx/compile_impl.h"
|
||||||
|
#include "mlx/fast_primitives.h"
|
||||||
|
#include "mlx/primitives.h"
|
||||||
|
#include "mlx/utils.h"
|
||||||
|
|
||||||
|
// Two-level stringification: TOSTRING expands its argument (for example a
// macro such as MLX_VERSION) before STRINGIFY turns it into a string literal.
#define STRINGIFY(x) #x
#define TOSTRING(x) STRINGIFY(x)

// Expands to one {name, PrimitiveSerializer} initializer for `primitive`:
//  - the name is the stringified primitive type,
//  - the first callback forwards to serialize_primitive<primitive>,
//  - the second callback forwards to deserialize_primitive<primitive>,
//  - the optional trailing arguments are additional string keys stored with
//    the entry.
// Uses the standard `...` / __VA_ARGS__ form rather than a named variadic
// parameter, which is a compiler extension.
// clang-format off
#define SERIALIZE_PRIMITIVE(primitive, ...)                \
  {                                                        \
    #primitive, {                                          \
      [](Writer& os, const Primitive& p) {                 \
        serialize_primitive<primitive>(os, p);             \
      },                                                   \
      [](Reader& is, Stream s) {                           \
        return deserialize_primitive<primitive>(is, s);    \
      },                                                   \
      {__VA_ARGS__}                                        \
    }                                                      \
  }
// clang-format on
|
||||||
|
|
||||||
|
// Host byte-order probe: looks at the first byte in memory of an int holding
// 1 and reports big endian when that byte is not 1.
bool is_big_endian() {
  int one = 1;
  const char first_byte = *reinterpret_cast<char*>(&one);
  return first_byte != 1;
}
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
using namespace mlx::core::fast;
|
||||||
|
|
||||||
|
using Reader = io::ParallelFileReader;
|
||||||
|
using Writer = io::FileWriter;
|
||||||
|
|
||||||
|
struct PrimitiveSerializer {
|
||||||
|
using Serializer = std::function<void(Writer&, const Primitive&)>;
|
||||||
|
using Deserializer =
|
||||||
|
std::function<std::shared_ptr<Primitive>(Reader&, Stream s)>;
|
||||||
|
PrimitiveSerializer(
|
||||||
|
Serializer serialize,
|
||||||
|
Deserializer deserialize,
|
||||||
|
std::vector<std::string> keys = {})
|
||||||
|
: serialize(std::move(serialize)),
|
||||||
|
deserialize(std::move(deserialize)),
|
||||||
|
keys(std::move(keys)) {};
|
||||||
|
Serializer serialize;
|
||||||
|
Deserializer deserialize;
|
||||||
|
std::vector<std::string> keys;
|
||||||
|
};
|
||||||
|
|
||||||
|
// is_iterable<T>: true when T has both begin() and end() member functions.
template <typename T, typename Enable = void>
constexpr bool is_iterable = false;

template <typename Container>
constexpr bool is_iterable<
    Container,
    std::void_t<
        decltype(std::declval<Container>().begin()),
        decltype(std::declval<Container>().end())>> = true;

// is_specialization_of<Tmpl, U>: true when U is Tmpl<...> for some type
// arguments.
template <template <typename...> class Tmpl, typename U>
constexpr bool is_specialization_of = false;

template <template <typename...> class Tmpl, typename... Args>
constexpr bool is_specialization_of<Tmpl, Tmpl<Args...>> = true;

// is_pair<T> / is_tuple<T>: true when T, after decay, is a std::pair or a
// std::tuple respectively.
template <typename T>
constexpr bool is_pair = is_specialization_of<std::pair, std::decay_t<T>>;

template <typename T>
constexpr bool is_tuple = is_specialization_of<std::tuple, std::decay_t<T>>;

// Always false, but dependent on a template parameter so that a
// static_assert using it only fires when the enclosing template is
// instantiated.
template <typename T>
constexpr bool dependent_false = false;

// Instantiating either of these for a type T fails to compile with the
// corresponding message.
template <typename T>
struct NotSerializable {
  static_assert(dependent_false<T>, "Type is not serializable.");
};

template <typename T>
struct NotDeserializable {
  static_assert(dependent_false<T>, "Type is not deserializable.");
};
|
||||||
|
|
||||||
|
// Reverses, in place, the order of the bytes that make up `data`.
template <typename T>
void reverse_bytes(T& data) {
  auto* front = reinterpret_cast<uint8_t*>(&data);
  auto* back = front + sizeof(T) - 1;
  // Walk inwards from both ends, swapping as we go.
  for (; front < back; ++front, --back) {
    std::swap(*front, *back);
  }
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void serialize(Writer& os, T v) {
|
||||||
|
if constexpr (std::is_arithmetic_v<T>) {
|
||||||
|
if (is_big_endian()) {
|
||||||
|
reverse_bytes(v);
|
||||||
|
}
|
||||||
|
os.write(reinterpret_cast<const char*>(&v), sizeof(T));
|
||||||
|
} else if constexpr (std::is_enum_v<T>) {
|
||||||
|
serialize(os, static_cast<int>(v));
|
||||||
|
} else if constexpr (std::is_same_v<T, std::nullptr_t>) {
|
||||||
|
} else if constexpr (is_iterable<T>) {
|
||||||
|
serialize(os, static_cast<uint64_t>(v.size()));
|
||||||
|
for (const auto& t : v) {
|
||||||
|
serialize(os, t);
|
||||||
|
}
|
||||||
|
} else if constexpr (is_pair<T> || is_tuple<T>) {
|
||||||
|
std::apply([&os](auto&... x) { (..., serialize(os, x)); }, v);
|
||||||
|
} else {
|
||||||
|
NotSerializable<T>();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, std::size_t... I>
|
||||||
|
decltype(auto) deserialize_tuple(Reader& is, std::index_sequence<I...>);
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
T deserialize(Reader& is) {
|
||||||
|
if constexpr (std::is_arithmetic_v<T>) {
|
||||||
|
T v;
|
||||||
|
is.read(reinterpret_cast<char*>(&v), sizeof(T));
|
||||||
|
if (is_big_endian()) {
|
||||||
|
reverse_bytes(v);
|
||||||
|
}
|
||||||
|
return v;
|
||||||
|
} else if constexpr (std::is_enum_v<T>) {
|
||||||
|
return static_cast<T>(deserialize<int>(is));
|
||||||
|
} else if constexpr (std::is_same_v<T, std::nullptr_t>) {
|
||||||
|
return nullptr;
|
||||||
|
} else if constexpr (is_iterable<T>) {
|
||||||
|
T v;
|
||||||
|
auto size = deserialize<uint64_t>(is);
|
||||||
|
v.reserve(size);
|
||||||
|
for (int i = 0; i < size; ++i) {
|
||||||
|
v.push_back(deserialize<typename T::value_type>(is));
|
||||||
|
}
|
||||||
|
return v;
|
||||||
|
} else if constexpr (is_pair<T> || is_tuple<T>) {
|
||||||
|
return deserialize_tuple<T>(
|
||||||
|
is, std::make_index_sequence<std::tuple_size_v<std::decay_t<T>>>{});
|
||||||
|
} else {
|
||||||
|
NotDeserializable<T>();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, std::size_t... I>
|
||||||
|
decltype(auto) deserialize_tuple(Reader& is, std::index_sequence<I...>) {
|
||||||
|
return T{deserialize<std::tuple_element_t<I, T>>(is)...};
|
||||||
|
};
|
||||||
|
|
||||||
|
void serialize(Writer& os, const Stream& s) {
|
||||||
|
serialize(os, s.index);
|
||||||
|
serialize(os, s.device.type);
|
||||||
|
serialize(os, s.device.index);
|
||||||
|
}
|
||||||
|
template <>
|
||||||
|
Stream deserialize(Reader& is) {
|
||||||
|
auto stream_index = deserialize<int>(is);
|
||||||
|
auto device_type = deserialize<Device::DeviceType>(is);
|
||||||
|
auto device_index = deserialize<int>(is);
|
||||||
|
return Stream(stream_index, Device(device_type, device_index));
|
||||||
|
}
|
||||||
|
|
||||||
|
void serialize(Writer& os, const Dtype& t) {
|
||||||
|
serialize(os, t.val());
|
||||||
|
serialize(os, t.size());
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
Dtype deserialize(Reader& is) {
|
||||||
|
auto val = deserialize<Dtype::Val>(is);
|
||||||
|
auto size = deserialize<uint8_t>(is);
|
||||||
|
return Dtype(val, size);
|
||||||
|
};
|
||||||
|
|
||||||
|
void serialize(Writer& os, const array& arr) {
|
||||||
|
serialize(os, arr.shape());
|
||||||
|
serialize(os, arr.dtype());
|
||||||
|
}
|
||||||
|
template <>
|
||||||
|
array deserialize(Reader& is) {
|
||||||
|
auto shape = deserialize<std::vector<int>>(is);
|
||||||
|
auto type = deserialize<Dtype>(is);
|
||||||
|
return array(std::move(shape), type, nullptr, std::vector<array>{});
|
||||||
|
}
|
||||||
|
|
||||||
|
// has_state<T>: true when T has a state() member function callable without
// arguments.
template <typename T, typename Enable = void>
constexpr bool has_state = false;

template <typename P>
constexpr bool has_state<P, std::void_t<decltype(std::declval<P>().state())>> =
    true;
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void serialize_primitive(Writer& os, const Primitive& p) {
|
||||||
|
if constexpr (has_state<T>) {
|
||||||
|
serialize(os, static_cast<const T&>(p).state());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
std::shared_ptr<T> deserialize_primitive(Reader& is, Stream s) {
|
||||||
|
if constexpr (has_state<T>) {
|
||||||
|
auto args = deserialize<decltype(std::declval<T>().state())>(is);
|
||||||
|
if constexpr (is_pair<decltype(args)> || is_tuple<decltype(args)>) {
|
||||||
|
auto fn = [s](auto&&... args) {
|
||||||
|
return std::make_shared<T>(s, std::move(args)...);
|
||||||
|
};
|
||||||
|
return std::apply(fn, std::move(args));
|
||||||
|
} else {
|
||||||
|
return std::make_shared<T>(s, std::move(args));
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return std::make_shared<T>(s);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct PrimitiveFactory {
|
||||||
|
std::unordered_map<std::string, PrimitiveSerializer> factory = {
|
||||||
|
SERIALIZE_PRIMITIVE(Abs),
|
||||||
|
SERIALIZE_PRIMITIVE(Add),
|
||||||
|
SERIALIZE_PRIMITIVE(AddMM),
|
||||||
|
SERIALIZE_PRIMITIVE(Arange),
|
||||||
|
SERIALIZE_PRIMITIVE(ArcCos),
|
||||||
|
SERIALIZE_PRIMITIVE(ArcCosh),
|
||||||
|
SERIALIZE_PRIMITIVE(ArcSin),
|
||||||
|
SERIALIZE_PRIMITIVE(ArcSinh),
|
||||||
|
SERIALIZE_PRIMITIVE(ArcTan),
|
||||||
|
SERIALIZE_PRIMITIVE(ArcTan2),
|
||||||
|
SERIALIZE_PRIMITIVE(ArcTanh),
|
||||||
|
SERIALIZE_PRIMITIVE(ArgPartition),
|
||||||
|
SERIALIZE_PRIMITIVE(ArgReduce),
|
||||||
|
SERIALIZE_PRIMITIVE(ArgSort),
|
||||||
|
SERIALIZE_PRIMITIVE(AsType),
|
||||||
|
SERIALIZE_PRIMITIVE(AsStrided),
|
||||||
|
SERIALIZE_PRIMITIVE(
|
||||||
|
BitwiseBinary,
|
||||||
|
"BitwiseAnd",
|
||||||
|
"BitwiseOr",
|
||||||
|
"BitwiseXor",
|
||||||
|
"LeftShift",
|
||||||
|
"RightShift"),
|
||||||
|
SERIALIZE_PRIMITIVE(BlockMaskedMM),
|
||||||
|
SERIALIZE_PRIMITIVE(Broadcast),
|
||||||
|
SERIALIZE_PRIMITIVE(Ceil),
|
||||||
|
SERIALIZE_PRIMITIVE(Concatenate),
|
||||||
|
SERIALIZE_PRIMITIVE(Conjugate),
|
||||||
|
SERIALIZE_PRIMITIVE(Convolution),
|
||||||
|
SERIALIZE_PRIMITIVE(Copy),
|
||||||
|
SERIALIZE_PRIMITIVE(Cos),
|
||||||
|
SERIALIZE_PRIMITIVE(Cosh),
|
||||||
|
SERIALIZE_PRIMITIVE(Depends),
|
||||||
|
SERIALIZE_PRIMITIVE(Divide),
|
||||||
|
SERIALIZE_PRIMITIVE(DivMod),
|
||||||
|
SERIALIZE_PRIMITIVE(Equal, "NaNEqual"),
|
||||||
|
SERIALIZE_PRIMITIVE(Erf),
|
||||||
|
SERIALIZE_PRIMITIVE(ErfInv),
|
||||||
|
SERIALIZE_PRIMITIVE(Exp),
|
||||||
|
SERIALIZE_PRIMITIVE(Expm1),
|
||||||
|
SERIALIZE_PRIMITIVE(ExpandDims),
|
||||||
|
SERIALIZE_PRIMITIVE(FFT),
|
||||||
|
SERIALIZE_PRIMITIVE(Flatten),
|
||||||
|
SERIALIZE_PRIMITIVE(Floor),
|
||||||
|
SERIALIZE_PRIMITIVE(Full),
|
||||||
|
SERIALIZE_PRIMITIVE(Gather),
|
||||||
|
SERIALIZE_PRIMITIVE(GatherMM),
|
||||||
|
SERIALIZE_PRIMITIVE(Greater),
|
||||||
|
SERIALIZE_PRIMITIVE(GreaterEqual),
|
||||||
|
SERIALIZE_PRIMITIVE(Hadamard),
|
||||||
|
SERIALIZE_PRIMITIVE(Imag),
|
||||||
|
SERIALIZE_PRIMITIVE(Less),
|
||||||
|
SERIALIZE_PRIMITIVE(LessEqual),
|
||||||
|
SERIALIZE_PRIMITIVE(Log, "Log2", "Log10"),
|
||||||
|
SERIALIZE_PRIMITIVE(Log1p),
|
||||||
|
SERIALIZE_PRIMITIVE(LogicalNot),
|
||||||
|
SERIALIZE_PRIMITIVE(LogicalAnd),
|
||||||
|
SERIALIZE_PRIMITIVE(LogicalOr),
|
||||||
|
SERIALIZE_PRIMITIVE(LogAddExp),
|
||||||
|
SERIALIZE_PRIMITIVE(Matmul),
|
||||||
|
SERIALIZE_PRIMITIVE(Maximum),
|
||||||
|
SERIALIZE_PRIMITIVE(Minimum),
|
||||||
|
SERIALIZE_PRIMITIVE(Multiply),
|
||||||
|
SERIALIZE_PRIMITIVE(Negative),
|
||||||
|
SERIALIZE_PRIMITIVE(NotEqual),
|
||||||
|
SERIALIZE_PRIMITIVE(Reshape),
|
||||||
|
SERIALIZE_PRIMITIVE(NumberOfElements),
|
||||||
|
SERIALIZE_PRIMITIVE(Pad),
|
||||||
|
SERIALIZE_PRIMITIVE(Partition),
|
||||||
|
SERIALIZE_PRIMITIVE(Power),
|
||||||
|
SERIALIZE_PRIMITIVE(QuantizedMatmul),
|
||||||
|
SERIALIZE_PRIMITIVE(GatherQMM),
|
||||||
|
SERIALIZE_PRIMITIVE(RandomBits),
|
||||||
|
SERIALIZE_PRIMITIVE(Real),
|
||||||
|
SERIALIZE_PRIMITIVE(Remainder),
|
||||||
|
SERIALIZE_PRIMITIVE(Reshape),
|
||||||
|
SERIALIZE_PRIMITIVE(Reduce, "And", "Or", "Sum", "Prod", "Min", "Max"),
|
||||||
|
SERIALIZE_PRIMITIVE(Round),
|
||||||
|
SERIALIZE_PRIMITIVE(Scan, "CumSum", "CumProd", "CumMin", "CumMax"),
|
||||||
|
SERIALIZE_PRIMITIVE(Scatter),
|
||||||
|
SERIALIZE_PRIMITIVE(Select),
|
||||||
|
SERIALIZE_PRIMITIVE(Sigmoid),
|
||||||
|
SERIALIZE_PRIMITIVE(Sign),
|
||||||
|
SERIALIZE_PRIMITIVE(Sin),
|
||||||
|
SERIALIZE_PRIMITIVE(Sinh),
|
||||||
|
SERIALIZE_PRIMITIVE(Slice),
|
||||||
|
SERIALIZE_PRIMITIVE(SliceUpdate),
|
||||||
|
SERIALIZE_PRIMITIVE(Softmax),
|
||||||
|
SERIALIZE_PRIMITIVE(Sort),
|
||||||
|
SERIALIZE_PRIMITIVE(Split),
|
||||||
|
SERIALIZE_PRIMITIVE(Square),
|
||||||
|
SERIALIZE_PRIMITIVE(Squeeze),
|
||||||
|
SERIALIZE_PRIMITIVE(Sqrt, "Rsqrt", "Sqrt"),
|
||||||
|
SERIALIZE_PRIMITIVE(StopGradient),
|
||||||
|
SERIALIZE_PRIMITIVE(Subtract),
|
||||||
|
SERIALIZE_PRIMITIVE(Tan),
|
||||||
|
SERIALIZE_PRIMITIVE(Tanh),
|
||||||
|
SERIALIZE_PRIMITIVE(View),
|
||||||
|
SERIALIZE_PRIMITIVE(Transpose),
|
||||||
|
SERIALIZE_PRIMITIVE(Unflatten),
|
||||||
|
SERIALIZE_PRIMITIVE(QRF),
|
||||||
|
SERIALIZE_PRIMITIVE(SVD),
|
||||||
|
SERIALIZE_PRIMITIVE(Inverse),
|
||||||
|
SERIALIZE_PRIMITIVE(Cholesky),
|
||||||
|
SERIALIZE_PRIMITIVE(Eigh),
|
||||||
|
SERIALIZE_PRIMITIVE(AffineQuantize),
|
||||||
|
SERIALIZE_PRIMITIVE(RMSNorm),
|
||||||
|
SERIALIZE_PRIMITIVE(RMSNormVJP),
|
||||||
|
SERIALIZE_PRIMITIVE(LayerNorm),
|
||||||
|
SERIALIZE_PRIMITIVE(LayerNormVJP),
|
||||||
|
SERIALIZE_PRIMITIVE(RoPE),
|
||||||
|
SERIALIZE_PRIMITIVE(ScaledDotProductAttention)};
|
||||||
|
std::unordered_map<std::string, std::string> name_remap;
|
||||||
|
|
||||||
|
PrimitiveFactory() {
|
||||||
|
for (auto& [n, f] : factory) {
|
||||||
|
for (auto& k : f.keys) {
|
||||||
|
name_remap[k] = n;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void save(Writer& os, const std::shared_ptr<Primitive>& p) {
|
||||||
|
serialize(os, p->stream());
|
||||||
|
std::ostringstream pout;
|
||||||
|
p->print(pout);
|
||||||
|
auto name = pout.str();
|
||||||
|
name = name.substr(0, name.find(' '));
|
||||||
|
if (auto it = name_remap.find(name); it != name_remap.end()) {
|
||||||
|
name = it->second;
|
||||||
|
}
|
||||||
|
serialize(os, name);
|
||||||
|
if (auto it = factory.find(name); it != factory.end()) {
|
||||||
|
it->second.serialize(os, *p);
|
||||||
|
} else {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"[export_function] Unable to serialize primitive " + name);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
std::shared_ptr<Primitive> load(Reader& is) {
|
||||||
|
auto stream = deserialize<Stream>(is);
|
||||||
|
if (get_stream(stream.index) != stream) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[import_function] Invalid stream encountered " << stream << ".";
|
||||||
|
throw std::invalid_argument(msg.str());
|
||||||
|
}
|
||||||
|
auto name = deserialize<std::string>(is);
|
||||||
|
if (auto it = factory.find(name); it != factory.end()) {
|
||||||
|
return it->second.deserialize(is, stream);
|
||||||
|
} else {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"[import_function] Unable to deserialize primitive " + name);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
// Writes the file header: the MLX version string, then `count`, then the
// `shapeless` flag.
void write_header(Writer& os, int count, bool shapeless) {
  const std::string version(TOSTRING(MLX_VERSION));
  serialize(os, version);
  serialize(os, count);
  serialize(os, shapeless);
}
|
||||||
|
|
||||||
|
// A struct to hold and retrieve the graphs that are exported / imported
|
||||||
|
struct FunctionTable {
|
||||||
|
FunctionTable(bool shapeless = false) : shapeless(shapeless) {};
|
||||||
|
struct Function {
|
||||||
|
Function(
|
||||||
|
std::vector<std::string> kwarg_keys,
|
||||||
|
std::vector<array> inputs,
|
||||||
|
std::vector<array> outputs,
|
||||||
|
std::vector<array> tape)
|
||||||
|
: kwarg_keys(std::move(kwarg_keys)),
|
||||||
|
inputs(std::move(inputs)),
|
||||||
|
outputs(std::move(outputs)),
|
||||||
|
tape(std::move(tape)) {}
|
||||||
|
|
||||||
|
std::vector<std::string> kwarg_keys;
|
||||||
|
std::vector<array> inputs;
|
||||||
|
std::vector<array> outputs;
|
||||||
|
std::vector<array> tape;
|
||||||
|
Function(const Function&) = delete;
|
||||||
|
Function& operator=(const Function&) = delete;
|
||||||
|
Function(Function&&) = default;
|
||||||
|
Function() = default;
|
||||||
|
};
|
||||||
|
bool shapeless;
|
||||||
|
std::unordered_map<int, std::vector<Function>> table;
|
||||||
|
Function* find(const Args& args, const Kwargs& kwargs);
|
||||||
|
std::pair<Function&, bool> emplace(const Args& args, const Kwargs& kwargs);
|
||||||
|
void insert(
|
||||||
|
std::vector<std::string> kwarg_keys,
|
||||||
|
std::vector<array> inputs,
|
||||||
|
std::vector<array> outputs,
|
||||||
|
std::vector<array> tape) {
|
||||||
|
auto [it, _] = table.emplace(inputs.size(), std::vector<Function>{});
|
||||||
|
it->second.emplace_back(
|
||||||
|
std::move(kwarg_keys),
|
||||||
|
std::move(inputs),
|
||||||
|
std::move(outputs),
|
||||||
|
std::move(tape));
|
||||||
|
}
|
||||||
|
|
||||||
|
void print_functions(std::ostream& os) {
|
||||||
|
int n = 1;
|
||||||
|
for (auto& [_, vec] : table) {
|
||||||
|
for (auto& fun : vec) {
|
||||||
|
auto npos = fun.inputs.size() - fun.kwarg_keys.size();
|
||||||
|
os << " " << n++ << ". Function with " << npos
|
||||||
|
<< " positional inputs and " << fun.kwarg_keys.size()
|
||||||
|
<< " keyword inputs:\n";
|
||||||
|
for (int j = 0; j < fun.inputs.size(); ++j) {
|
||||||
|
auto& in = fun.inputs[j];
|
||||||
|
if (j < npos) {
|
||||||
|
os << " " << j + 1 << ": ";
|
||||||
|
} else {
|
||||||
|
os << " \"" << fun.kwarg_keys[j - npos] << "\": ";
|
||||||
|
}
|
||||||
|
os << in.shape() << " " << in.dtype() << "\n";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
bool match(const Args& args, const Kwargs& kwargs, const Function& fun);
|
||||||
|
};
|
||||||
|
|
||||||
|
bool FunctionTable::match(
|
||||||
|
const Args& args,
|
||||||
|
const Kwargs& kwargs,
|
||||||
|
const Function& fun) {
|
||||||
|
for (auto& k : fun.kwarg_keys) {
|
||||||
|
if (kwargs.find(k) == kwargs.end()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
auto match_inputs = [shapeless = this->shapeless](
|
||||||
|
const array& x, const array& y) {
|
||||||
|
if (x.dtype() != y.dtype()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (!shapeless && x.shape() != y.shape()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
};
|
||||||
|
|
||||||
|
int i = 0;
|
||||||
|
for (; i < args.size(); ++i) {
|
||||||
|
if (!match_inputs(args[i], fun.inputs[i])) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (auto& [_, in] : kwargs) {
|
||||||
|
if (!match_inputs(in, fun.inputs[i++])) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::pair<FunctionTable::Function&, bool> FunctionTable::emplace(
|
||||||
|
const Args& args,
|
||||||
|
const Kwargs& kwargs) {
|
||||||
|
auto n_inputs = args.size() + kwargs.size();
|
||||||
|
auto [it, _] = table.emplace(n_inputs, std::vector<Function>{});
|
||||||
|
auto& funs_vec = it->second;
|
||||||
|
|
||||||
|
for (auto& fun : funs_vec) {
|
||||||
|
if (match(args, kwargs, fun)) {
|
||||||
|
return {fun, false};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
funs_vec.emplace_back();
|
||||||
|
return {funs_vec.back(), true};
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return a pointer to the registered function whose signature matches the
// given call, or nullptr when there is none.
FunctionTable::Function* FunctionTable::find(
    const Args& args,
    const Kwargs& kwargs) {
  Function* found = nullptr;
  auto bucket = table.find(args.size() + kwargs.size());
  if (bucket != table.end()) {
    for (auto& candidate : bucket->second) {
      if (match(args, kwargs, candidate)) {
        found = &candidate;
        break;
      }
    }
  }
  return found;
}
|
||||||
|
|
||||||
|
// Open `file` for writing and emit the initial header. Throws
// std::runtime_error when the file cannot be opened.
FunctionExporter::FunctionExporter(
    const std::string& file,
    std::function<std::vector<array>(const Args&, const Kwargs&)> fun,
    bool shapeless)
    : os(file),
      fun(std::move(fun)),
      ftable(std::make_shared<FunctionTable>(shapeless)) {
  const bool opened = os.is_open();
  if (!opened) {
    throw std::runtime_error("[export_function] Failed to open " + file);
  }
  // Nothing has been exported yet; export_function rewrites the header with
  // the updated count after each export.
  write_header(os, count, shapeless);
}
|
||||||
|
|
||||||
|
// Mark the exporter as closed. Any later export attempt throws. This only
// sets the flag; it does not itself touch the underlying file.
void FunctionExporter::close() {
  closed = true;
}
|
||||||
|
|
||||||
|
void FunctionExporter::export_function(const Args& args, const Kwargs& kwargs) {
|
||||||
|
if (closed) {
|
||||||
|
throw std::runtime_error(
|
||||||
|
"[export_function] Attempting to write after exporting is closed.");
|
||||||
|
}
|
||||||
|
auto [fentry, inserted] = ftable->emplace(args, kwargs);
|
||||||
|
if (!inserted) {
|
||||||
|
throw std::runtime_error(
|
||||||
|
"[export_function] Attempting to export a function twice with "
|
||||||
|
"the same signature is not allowed.");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Flatten the inputs to the function for tracing
|
||||||
|
std::vector<std::string> kwarg_keys;
|
||||||
|
auto inputs = args;
|
||||||
|
for (auto& [k, v] : kwargs) {
|
||||||
|
kwarg_keys.push_back(k);
|
||||||
|
inputs.push_back(v);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto flat_fun = [this, &kwarg_keys](const Args& flat_args) {
|
||||||
|
auto args = Args(flat_args.begin(), flat_args.end() - kwarg_keys.size());
|
||||||
|
Kwargs kwargs;
|
||||||
|
auto it = flat_args.end() - kwarg_keys.size();
|
||||||
|
;
|
||||||
|
for (auto& k : kwarg_keys) {
|
||||||
|
kwargs.insert({k, *it++});
|
||||||
|
}
|
||||||
|
return fun(args, kwargs);
|
||||||
|
};
|
||||||
|
|
||||||
|
// Trace to build the graph
|
||||||
|
auto [trace_inputs, trace_outputs] = detail::compile_trace(flat_fun, inputs);
|
||||||
|
|
||||||
|
// DFS the graph and get the tape
|
||||||
|
auto [tape, parents_map] =
|
||||||
|
detail::compile_dfs(trace_inputs, trace_outputs, inputs);
|
||||||
|
|
||||||
|
detail::compile_simplify(tape, parents_map, trace_outputs, /* passes */ 3);
|
||||||
|
|
||||||
|
// Update header
|
||||||
|
count++;
|
||||||
|
|
||||||
|
// Overwrite the header
|
||||||
|
auto pos = os.tell();
|
||||||
|
os.seek(0);
|
||||||
|
write_header(os, count, ftable->shapeless);
|
||||||
|
os.seek(pos);
|
||||||
|
serialize(os, kwarg_keys);
|
||||||
|
|
||||||
|
auto arrays_to_ids = [](const std::vector<array>& arrs) {
|
||||||
|
std::vector<uint64_t> ids;
|
||||||
|
for (auto& arr : arrs) {
|
||||||
|
ids.push_back(arr.id());
|
||||||
|
}
|
||||||
|
return ids;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Inputs and outputs
|
||||||
|
auto trace_input_ids = arrays_to_ids(trace_inputs);
|
||||||
|
serialize(os, trace_input_ids);
|
||||||
|
serialize(os, trace_inputs);
|
||||||
|
serialize(os, arrays_to_ids(trace_outputs));
|
||||||
|
|
||||||
|
// Update the table entry
|
||||||
|
fentry.kwarg_keys = std::move(kwarg_keys);
|
||||||
|
fentry.inputs = std::move(trace_inputs);
|
||||||
|
|
||||||
|
std::unordered_set<std::uintptr_t> input_set(
|
||||||
|
trace_input_ids.begin(), trace_input_ids.end());
|
||||||
|
|
||||||
|
// Tape
|
||||||
|
auto factory = PrimitiveFactory();
|
||||||
|
serialize(os, static_cast<uint64_t>(tape.size()));
|
||||||
|
for (auto& arr : tape) {
|
||||||
|
serialize(os, static_cast<uint64_t>(arr.id()));
|
||||||
|
if (arr.has_primitive()) {
|
||||||
|
serialize(os, true);
|
||||||
|
serialize(os, arrays_to_ids(arr.inputs()));
|
||||||
|
factory.save(os, arr.primitive_ptr());
|
||||||
|
serialize(os, static_cast<uint64_t>(arr.siblings().size()));
|
||||||
|
if (arr.siblings().empty()) {
|
||||||
|
serialize(os, arr.shape());
|
||||||
|
serialize(os, arr.dtype());
|
||||||
|
} else {
|
||||||
|
auto outputs = arr.outputs();
|
||||||
|
serialize(os, arrays_to_ids(outputs));
|
||||||
|
|
||||||
|
std::vector<std::vector<int>> shapes;
|
||||||
|
std::vector<Dtype> dtypes;
|
||||||
|
for (auto& o : outputs) {
|
||||||
|
shapes.push_back(o.shape());
|
||||||
|
dtypes.push_back(o.dtype());
|
||||||
|
}
|
||||||
|
serialize(os, shapes);
|
||||||
|
serialize(os, dtypes);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
serialize(os, false);
|
||||||
|
if (input_set.find(arr.id()) == input_set.end()) {
|
||||||
|
serialize(os, true);
|
||||||
|
// Save constant data if not already saved
|
||||||
|
if (constants.insert(arr.id()).second) {
|
||||||
|
serialize(os, arr.shape());
|
||||||
|
serialize(os, arr.dtype());
|
||||||
|
os.write(arr.data<char>(), arr.nbytes());
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
serialize(os, false);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Export a trace of the function called with positional inputs only.
void FunctionExporter::operator()(const Args& args) {
  export_function(args, Kwargs{});
}
|
||||||
|
|
||||||
|
// Export a trace of the function called with keyword inputs only.
void FunctionExporter::operator()(const Kwargs& kwargs) {
  export_function(Args{}, kwargs);
}
|
||||||
|
|
||||||
|
// Export a trace of the function called with both positional and keyword
// inputs.
void FunctionExporter::operator()(const Args& args, const Kwargs& kwargs) {
  this->export_function(args, kwargs);
}
|
||||||
|
|
||||||
|
// Make an exporter for a function which takes positional inputs only.
FunctionExporter exporter(
    const std::string& file,
    const std::function<std::vector<array>(const Args&)>& fun,
    bool shapeless /* = false */) {
  // Adapt to the (Args, Kwargs) form stored by the exporter; the keyword
  // inputs are ignored.
  auto positional_only = [fun](const Args& args, const Kwargs&) {
    return fun(args);
  };
  return FunctionExporter{file, std::move(positional_only), shapeless};
}
|
||||||
|
|
||||||
|
// Make an exporter for a function which takes keyword inputs only.
FunctionExporter exporter(
    const std::string& file,
    const std::function<std::vector<array>(const Kwargs&)>& fun,
    bool shapeless /* = false */) {
  // Adapt to the (Args, Kwargs) form stored by the exporter; the positional
  // inputs are ignored. The keyword map is taken by reference so it is not
  // copied on every call.
  return FunctionExporter{
      file,
      [fun](const Args&, const Kwargs& kwargs) { return fun(kwargs); },
      shapeless};
}
|
||||||
|
|
||||||
|
// Make an exporter for a function which takes both positional and keyword
// inputs.
FunctionExporter exporter(
    const std::string& file,
    const std::function<std::vector<array>(const Args&, const Kwargs&)>& fun,
    bool shapeless /* = false */) {
  return FunctionExporter(file, fun, shapeless);
}
|
||||||
|
|
||||||
|
void export_function(
|
||||||
|
const std::string& file,
|
||||||
|
const std::function<std::vector<array>(const Args&)>& fun,
|
||||||
|
const Args& args,
|
||||||
|
bool shapeless /* = false */) {
|
||||||
|
exporter(file, fun, shapeless)(args);
|
||||||
|
}
|
||||||
|
|
||||||
|
void export_function(
|
||||||
|
const std::string& file,
|
||||||
|
const std::function<std::vector<array>(const Kwargs&)>& fun,
|
||||||
|
const Kwargs& kwargs,
|
||||||
|
bool shapeless /* = false */) {
|
||||||
|
exporter(file, fun, shapeless)(kwargs);
|
||||||
|
}
|
||||||
|
|
||||||
|
void export_function(
|
||||||
|
const std::string& file,
|
||||||
|
const std::function<std::vector<array>(const Args&, const Kwargs&)>& fun,
|
||||||
|
const Args& args,
|
||||||
|
const Kwargs& kwargs,
|
||||||
|
bool shapeless /* = false */) {
|
||||||
|
exporter(file, fun, shapeless)(args, kwargs);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Call the imported function with keyword inputs only.
std::vector<array> ImportedFunction::operator()(const Kwargs& kwargs) const {
  return (*this)(Args{}, kwargs);
}
|
||||||
|
|
||||||
|
// Call the imported function with positional inputs only.
std::vector<array> ImportedFunction::operator()(const Args& args) const {
  return (*this)(args, Kwargs{});
}
|
||||||
|
|
||||||
|
std::vector<array> ImportedFunction::operator()(
|
||||||
|
const Args& args,
|
||||||
|
const Kwargs& kwargs) const {
|
||||||
|
auto* fun = ftable->find(args, kwargs);
|
||||||
|
if (fun == nullptr) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[import_function::call] No imported function found which matches "
|
||||||
|
<< "the given positional and keyword arguments. Possible functions include:\n";
|
||||||
|
ftable->print_functions(msg);
|
||||||
|
msg << "\nReceived function with " << args.size()
|
||||||
|
<< " positional inputs and " << kwargs.size() << " keyword inputs:\n";
|
||||||
|
for (int i = 0; i < args.size(); ++i) {
|
||||||
|
auto& in = args[i];
|
||||||
|
msg << " " << i + 1 << ": " << in.shape() << " " << in.dtype() << "\n";
|
||||||
|
}
|
||||||
|
for (auto& [k, in] : kwargs) {
|
||||||
|
msg << " \"" << k << "\": " << in.shape() << " " << in.dtype() << "\n";
|
||||||
|
}
|
||||||
|
throw std::invalid_argument(msg.str());
|
||||||
|
}
|
||||||
|
|
||||||
|
auto inputs = args;
|
||||||
|
for (auto& [_, v] : kwargs) {
|
||||||
|
inputs.push_back(v);
|
||||||
|
}
|
||||||
|
return detail::compile_replace(
|
||||||
|
fun->tape, fun->inputs, fun->outputs, inputs, ftable->shapeless);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Import the function(s) stored in `file`.
ImportedFunction import_function(const std::string& file) {
  return ImportedFunction(file);
}
|
||||||
|
|
||||||
|
// Read every function stored in `file` and register it in the function
// table.
//
// Throws std::runtime_error if the file cannot be opened or if the id lists
// stored in it are inconsistent with the arrays they describe.
ImportedFunction::ImportedFunction(const std::string& file)
    : ftable(std::make_shared<FunctionTable>()) {
  // The reader is shared with the Load primitives created for constants
  // below, which are handed `is_ptr` together with a file offset.
  auto is_ptr = std::make_shared<Reader>(file);
  auto& is = *is_ptr;
  if (!is.is_open()) {
    throw std::runtime_error("[import_function] Failed to open " + file);
  }

  // Parse header. The version is currently not inspected but has to be read
  // to advance past it.
  [[maybe_unused]] auto mlx_version = deserialize<std::string>(is);
  auto function_count = deserialize<int>(is);
  ftable->shapeless = deserialize<bool>(is);

  // Constants are stored once per file and shared between functions.
  std::unordered_map<std::uintptr_t, array> constants;

  auto import_one = [&]() {
    auto kwarg_keys = deserialize<std::vector<std::string>>(is);

    // Maps the ids stored in the file to the arrays rebuilt here.
    std::unordered_map<uint64_t, array> array_map;
    auto trace_input_ids = deserialize<std::vector<uint64_t>>(is);
    auto trace_inputs = deserialize<std::vector<array>>(is);
    if (trace_input_ids.size() != trace_inputs.size()) {
      throw std::runtime_error(
          "[import_function] Malformed file " + file +
          ": mismatched number of input ids and inputs.");
    }
    for (size_t i = 0; i < trace_inputs.size(); ++i) {
      array_map.emplace(trace_input_ids[i], trace_inputs[i]);
    }
    auto trace_output_ids = deserialize<std::vector<uint64_t>>(is);

    std::vector<array> tape;
    auto tape_size = deserialize<uint64_t>(is);
    tape.reserve(tape_size);

    auto factory = PrimitiveFactory();
    for (size_t i = 0; i < tape_size; ++i) {
      auto id = deserialize<uint64_t>(is);
      if (deserialize<bool>(is)) {
        // The array has a primitive: rebuild it from its inputs.
        auto input_ids = deserialize<std::vector<uint64_t>>(is);
        std::vector<array> inputs;
        inputs.reserve(input_ids.size());
        for (auto input_id : input_ids) {
          inputs.push_back(array_map.at(input_id));
        }
        std::shared_ptr<Primitive> prim = factory.load(is);
        auto num_siblings = deserialize<uint64_t>(is);
        if (num_siblings == 0) {
          auto shape = deserialize<std::vector<int>>(is);
          auto type = deserialize<Dtype>(is);
          tape.emplace_back(
              std::move(shape), type, std::move(prim), std::move(inputs));
          array_map.emplace(id, tape.back());
        } else {
          // Multi output primitive: rebuild all outputs, keep the one with
          // this tape entry's id on the tape and register every output.
          auto output_ids = deserialize<std::vector<uint64_t>>(is);
          auto shapes = deserialize<std::vector<std::vector<int>>>(is);
          auto types = deserialize<std::vector<Dtype>>(is);
          auto arrays = array::make_arrays(
              std::move(shapes),
              std::move(types),
              std::move(prim),
              std::move(inputs));
          if (output_ids.size() != arrays.size()) {
            throw std::runtime_error(
                "[import_function] Malformed file " + file +
                ": mismatched number of output ids and outputs.");
          }
          for (size_t j = 0; j < arrays.size(); ++j) {
            auto sid = output_ids[j];
            if (sid == id) {
              tape.push_back(arrays[j]);
            }
            array_map.emplace(sid, arrays[j]);
          }
        }
      } else {
        if (deserialize<bool>(is)) {
          // Load constant
          if (auto it = constants.find(id); it != constants.end()) {
            tape.push_back(it->second);
          } else {
            auto shape = deserialize<std::vector<int>>(is);
            auto type = deserialize<Dtype>(is);
            size_t offset = is.tell();
            tape.push_back(array(
                std::move(shape),
                type,
                std::make_shared<Load>(
                    default_stream(default_device()), is_ptr, offset),
                {}));
            // Skip over the constant's raw bytes.
            is.seek(offset + tape.back().nbytes());
            constants.insert({id, tape.back()});
          }
          array_map.emplace(id, tape.back());
        } else {
          // Function inputs are in the map
          tape.push_back(array_map.at(id));
        }
      }
    }

    std::vector<array> trace_outputs;
    trace_outputs.reserve(trace_output_ids.size());
    for (auto output_id : trace_output_ids) {
      trace_outputs.push_back(array_map.at(output_id));
    }
    ftable->insert(
        std::move(kwarg_keys),
        std::move(trace_inputs),
        std::move(trace_outputs),
        std::move(tape));
  };

  for (int i = 0; i < function_count; ++i) {
    import_one();
  }
}
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
66
mlx/export.h
Normal file
66
mlx/export.h
Normal file
@ -0,0 +1,66 @@
|
|||||||
|
// Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <map>
|
||||||
|
#include <set>
|
||||||
|
#include "mlx/array.h"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
using Args = std::vector<array>;
|
||||||
|
using Kwargs = std::map<std::string, array>;
|
||||||
|
|
||||||
|
struct FunctionExporter;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Make an exporter to save multiple traces of a given function to
|
||||||
|
* the same file.
|
||||||
|
*/
|
||||||
|
FunctionExporter exporter(
|
||||||
|
const std::string& file,
|
||||||
|
const std::function<std::vector<array>(const Args&)>& fun,
|
||||||
|
bool shapeless = false);
|
||||||
|
|
||||||
|
FunctionExporter exporter(
|
||||||
|
const std::string& file,
|
||||||
|
const std::function<std::vector<array>(const Kwargs&)>& fun,
|
||||||
|
bool shapeless = false);
|
||||||
|
|
||||||
|
FunctionExporter exporter(
|
||||||
|
const std::string& path,
|
||||||
|
const std::function<std::vector<array>(const Args&, const Kwargs&)>& fun,
|
||||||
|
bool shapeless = false);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Export a function to a file.
|
||||||
|
*/
|
||||||
|
void export_function(
|
||||||
|
const std::string& file,
|
||||||
|
const std::function<std::vector<array>(const Args&)>& fun,
|
||||||
|
const Args& args,
|
||||||
|
bool shapeless = false);
|
||||||
|
|
||||||
|
void export_function(
|
||||||
|
const std::string& file,
|
||||||
|
const std::function<std::vector<array>(const Kwargs&)>& fun,
|
||||||
|
const Kwargs& kwargs,
|
||||||
|
bool shapeless = false);
|
||||||
|
|
||||||
|
void export_function(
|
||||||
|
const std::string& file,
|
||||||
|
const std::function<std::vector<array>(const Args&, const Kwargs&)>& fun,
|
||||||
|
const Args& args,
|
||||||
|
const Kwargs& kwargs,
|
||||||
|
bool shapeless = false);
|
||||||
|
|
||||||
|
struct ImportedFunction;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Import a function from a file.
|
||||||
|
*/
|
||||||
|
ImportedFunction import_function(const std::string& file);
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
||||||
|
|
||||||
|
#include "mlx/export_impl.h"
|
71
mlx/export_impl.h
Normal file
71
mlx/export_impl.h
Normal file
@ -0,0 +1,71 @@
|
|||||||
|
// Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/io/load.h"
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
struct FunctionTable;
|
||||||
|
|
||||||
|
struct FunctionExporter {
|
||||||
|
void operator()(const std::initializer_list<array>& args) {
|
||||||
|
this->operator()(Args(args));
|
||||||
|
}
|
||||||
|
void operator()(const Args& args);
|
||||||
|
void operator()(const Kwargs& kwargs);
|
||||||
|
void operator()(const Args& args, const Kwargs& kwargs);
|
||||||
|
|
||||||
|
void close();
|
||||||
|
|
||||||
|
FunctionExporter(const FunctionExporter&) = delete;
|
||||||
|
FunctionExporter& operator=(const FunctionExporter&) = delete;
|
||||||
|
FunctionExporter(FunctionExporter&& other) = default;
|
||||||
|
|
||||||
|
private:
|
||||||
|
friend FunctionExporter exporter(
|
||||||
|
const std::string&,
|
||||||
|
const std::function<std::vector<array>(const Args&)>&,
|
||||||
|
bool shapeless);
|
||||||
|
|
||||||
|
friend FunctionExporter exporter(
|
||||||
|
const std::string&,
|
||||||
|
const std::function<std::vector<array>(const Kwargs&)>&,
|
||||||
|
bool shapeless);
|
||||||
|
|
||||||
|
friend FunctionExporter exporter(
|
||||||
|
const std::string&,
|
||||||
|
const std::function<std::vector<array>(const Args&, const Kwargs&)>&,
|
||||||
|
bool shapeless);
|
||||||
|
|
||||||
|
FunctionExporter(
|
||||||
|
const std::string& file,
|
||||||
|
std::function<std::vector<array>(const Args&, const Kwargs&)> fun,
|
||||||
|
bool shapeless);
|
||||||
|
io::FileWriter os;
|
||||||
|
std::function<std::vector<array>(const Args&, const Kwargs& kwargs)> fun;
|
||||||
|
void export_function(const Args& args, const Kwargs& kwargs);
|
||||||
|
std::set<std::uintptr_t> constants;
|
||||||
|
int count{0};
|
||||||
|
bool closed{false};
|
||||||
|
std::shared_ptr<FunctionTable> ftable;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct ImportedFunction {
|
||||||
|
std::vector<array> operator()(
|
||||||
|
const std::initializer_list<array>& args) const {
|
||||||
|
return this->operator()(Args(args));
|
||||||
|
}
|
||||||
|
std::vector<array> operator()(const Args& args) const;
|
||||||
|
std::vector<array> operator()(const Kwargs& kwargs) const;
|
||||||
|
std::vector<array> operator()(const Args& args, const Kwargs& kwargs) const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
ImportedFunction(const std::string& file);
|
||||||
|
friend ImportedFunction import_function(const std::string&);
|
||||||
|
ImportedFunction();
|
||||||
|
|
||||||
|
std::shared_ptr<FunctionTable> ftable;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
@ -697,8 +697,7 @@ array scaled_dot_product_attention(
|
|||||||
return array(
|
return array(
|
||||||
std::move(out_shape),
|
std::move(out_shape),
|
||||||
final_type,
|
final_type,
|
||||||
std::make_shared<ScaledDotProductAttention>(
|
std::make_shared<ScaledDotProductAttention>(stream, fallback, scale),
|
||||||
stream, fallback, scale, false),
|
|
||||||
{q, k, v});
|
{q, k, v});
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -712,7 +711,7 @@ array scaled_dot_product_attention(
|
|||||||
bool ScaledDotProductAttention::is_equivalent(const Primitive& other) const {
|
bool ScaledDotProductAttention::is_equivalent(const Primitive& other) const {
|
||||||
const ScaledDotProductAttention& a_other =
|
const ScaledDotProductAttention& a_other =
|
||||||
static_cast<const ScaledDotProductAttention&>(other);
|
static_cast<const ScaledDotProductAttention&>(other);
|
||||||
return needs_mask_ == a_other.needs_mask_ && scale_ == a_other.scale_;
|
return scale_ == a_other.scale_;
|
||||||
}
|
}
|
||||||
|
|
||||||
array pack_and_quantize(
|
array pack_and_quantize(
|
||||||
|
@ -60,6 +60,10 @@ class RMSNorm : public Custom {
|
|||||||
bool is_equivalent(const Primitive& other) const override;
|
bool is_equivalent(const Primitive& other) const override;
|
||||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||||
|
|
||||||
|
auto state() const {
|
||||||
|
return std::make_pair(nullptr, eps_);
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::function<std::vector<array>(std::vector<array>)> fallback_;
|
std::function<std::vector<array>(std::vector<array>)> fallback_;
|
||||||
float eps_;
|
float eps_;
|
||||||
@ -82,6 +86,9 @@ class RMSNormVJP : public Custom {
|
|||||||
|
|
||||||
DEFINE_PRINT(RMSNormVJP)
|
DEFINE_PRINT(RMSNormVJP)
|
||||||
bool is_equivalent(const Primitive& other) const override;
|
bool is_equivalent(const Primitive& other) const override;
|
||||||
|
auto state() const {
|
||||||
|
return std::make_pair(nullptr, eps_);
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::function<std::vector<array>(std::vector<array>)> fallback_;
|
std::function<std::vector<array>(std::vector<array>)> fallback_;
|
||||||
@ -112,6 +119,9 @@ class LayerNorm : public Custom {
|
|||||||
DEFINE_PRINT(LayerNorm)
|
DEFINE_PRINT(LayerNorm)
|
||||||
bool is_equivalent(const Primitive& other) const override;
|
bool is_equivalent(const Primitive& other) const override;
|
||||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||||
|
auto state() const {
|
||||||
|
return std::make_pair(nullptr, eps_);
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::function<std::vector<array>(std::vector<array>)> fallback_;
|
std::function<std::vector<array>(std::vector<array>)> fallback_;
|
||||||
@ -135,6 +145,9 @@ class LayerNormVJP : public Custom {
|
|||||||
|
|
||||||
DEFINE_PRINT(LayerNormVJP)
|
DEFINE_PRINT(LayerNormVJP)
|
||||||
bool is_equivalent(const Primitive& other) const override;
|
bool is_equivalent(const Primitive& other) const override;
|
||||||
|
auto state() const {
|
||||||
|
return std::make_pair(nullptr, eps_);
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::function<std::vector<array>(std::vector<array>)> fallback_;
|
std::function<std::vector<array>(std::vector<array>)> fallback_;
|
||||||
@ -174,6 +187,10 @@ class RoPE : public Custom {
|
|||||||
DEFINE_PRINT(RoPE)
|
DEFINE_PRINT(RoPE)
|
||||||
bool is_equivalent(const Primitive& other) const override;
|
bool is_equivalent(const Primitive& other) const override;
|
||||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||||
|
auto state() const {
|
||||||
|
return std::make_tuple(
|
||||||
|
nullptr, dims_, traditional_, base_, scale_, forward_);
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::function<std::vector<array>(std::vector<array>)> fallback_;
|
std::function<std::vector<array>(std::vector<array>)> fallback_;
|
||||||
@ -189,9 +206,8 @@ class ScaledDotProductAttention : public Custom {
|
|||||||
explicit ScaledDotProductAttention(
|
explicit ScaledDotProductAttention(
|
||||||
Stream stream,
|
Stream stream,
|
||||||
std::function<std::vector<array>(std::vector<array>)> fallback,
|
std::function<std::vector<array>(std::vector<array>)> fallback,
|
||||||
const float scale,
|
const float scale)
|
||||||
const bool needs_mask)
|
: Custom(stream, fallback), scale_(scale) {}
|
||||||
: Custom(stream, fallback), scale_(scale), needs_mask_(needs_mask) {}
|
|
||||||
|
|
||||||
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||||
override {
|
override {
|
||||||
@ -208,11 +224,13 @@ class ScaledDotProductAttention : public Custom {
|
|||||||
|
|
||||||
DEFINE_PRINT(ScaledDotProductAttention);
|
DEFINE_PRINT(ScaledDotProductAttention);
|
||||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||||
|
auto state() const {
|
||||||
|
return std::make_pair(nullptr, scale_);
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::function<std::vector<array>(std::vector<array>)> fallback_;
|
std::function<std::vector<array>(std::vector<array>)> fallback_;
|
||||||
float scale_;
|
float scale_;
|
||||||
bool needs_mask_;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
class AffineQuantize : public Custom {
|
class AffineQuantize : public Custom {
|
||||||
@ -238,6 +256,9 @@ class AffineQuantize : public Custom {
|
|||||||
|
|
||||||
bool is_equivalent(const Primitive& other) const override;
|
bool is_equivalent(const Primitive& other) const override;
|
||||||
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
|
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
|
||||||
|
auto state() const {
|
||||||
|
return std::make_tuple(nullptr, group_size_, bits_, dequantize_);
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::function<std::vector<array>(std::vector<array>)> fallback_;
|
std::function<std::vector<array>(std::vector<array>)> fallback_;
|
||||||
|
@ -78,8 +78,15 @@ class ParallelFileReader : public Reader {
|
|||||||
return lseek(fd_, 0, SEEK_CUR);
|
return lseek(fd_, 0, SEEK_CUR);
|
||||||
}
|
}
|
||||||
|
|
||||||
void seek(int64_t, std::ios_base::seekdir = std::ios_base::beg) override {
|
// Warning: do not use this function from multiple threads as
|
||||||
throw std::runtime_error("[ParallelFileReader::seek] Not allowed");
|
// it advances the file descriptor
|
||||||
|
void seek(int64_t off, std::ios_base::seekdir way = std::ios_base::beg)
|
||||||
|
override {
|
||||||
|
if (way == std::ios_base::beg) {
|
||||||
|
lseek(fd_, off, 0);
|
||||||
|
} else {
|
||||||
|
lseek(fd_, off, SEEK_CUR);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Warning: do not use this function from multiple threads as
|
// Warning: do not use this function from multiple threads as
|
||||||
@ -108,9 +115,17 @@ class FileWriter : public Writer {
|
|||||||
0644)),
|
0644)),
|
||||||
label_(std::move(file_path)) {}
|
label_(std::move(file_path)) {}
|
||||||
|
|
||||||
|
FileWriter(const FileWriter&) = delete;
|
||||||
|
FileWriter& operator=(const FileWriter&) = delete;
|
||||||
|
FileWriter(FileWriter&& other) {
|
||||||
|
std::swap(fd_, other.fd_);
|
||||||
|
}
|
||||||
|
|
||||||
~FileWriter() override {
|
~FileWriter() override {
|
||||||
|
if (fd_ != 0) {
|
||||||
close(fd_);
|
close(fd_);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
bool is_open() const override {
|
bool is_open() const override {
|
||||||
return fd_ >= 0;
|
return fd_ >= 0;
|
||||||
@ -151,7 +166,7 @@ class FileWriter : public Writer {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
int fd_;
|
int fd_{0};
|
||||||
std::string label_;
|
std::string label_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -9,6 +9,7 @@
|
|||||||
#include "mlx/distributed/distributed.h"
|
#include "mlx/distributed/distributed.h"
|
||||||
#include "mlx/distributed/ops.h"
|
#include "mlx/distributed/ops.h"
|
||||||
#include "mlx/einsum.h"
|
#include "mlx/einsum.h"
|
||||||
|
#include "mlx/export.h"
|
||||||
#include "mlx/fast.h"
|
#include "mlx/fast.h"
|
||||||
#include "mlx/fft.h"
|
#include "mlx/fft.h"
|
||||||
#include "mlx/io.h"
|
#include "mlx/io.h"
|
||||||
|
148
mlx/primitives.h
148
mlx/primitives.h
@ -203,6 +203,9 @@ class AddMM : public UnaryPrimitive {
|
|||||||
DEFINE_PRINT(AddMM)
|
DEFINE_PRINT(AddMM)
|
||||||
|
|
||||||
bool is_equivalent(const Primitive& other) const override;
|
bool is_equivalent(const Primitive& other) const override;
|
||||||
|
std::pair<float, float> state() const {
|
||||||
|
return {alpha_, beta_};
|
||||||
|
};
|
||||||
|
|
||||||
private:
|
private:
|
||||||
const float alpha_;
|
const float alpha_;
|
||||||
@ -220,6 +223,9 @@ class Arange : public UnaryPrimitive {
|
|||||||
DEFINE_PRINT(Arange)
|
DEFINE_PRINT(Arange)
|
||||||
bool is_equivalent(const Primitive& other) const override;
|
bool is_equivalent(const Primitive& other) const override;
|
||||||
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
|
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
|
||||||
|
std::tuple<double, double, double> state() const {
|
||||||
|
return {start_, stop_, step_};
|
||||||
|
};
|
||||||
|
|
||||||
private:
|
private:
|
||||||
double start_;
|
double start_;
|
||||||
@ -361,6 +367,9 @@ class ArgPartition : public UnaryPrimitive {
|
|||||||
DEFINE_PRINT(ArgPartition)
|
DEFINE_PRINT(ArgPartition)
|
||||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||||
bool is_equivalent(const Primitive& other) const override;
|
bool is_equivalent(const Primitive& other) const override;
|
||||||
|
std::pair<int, int> state() const {
|
||||||
|
return {kth_, axis_};
|
||||||
|
};
|
||||||
|
|
||||||
private:
|
private:
|
||||||
int kth_;
|
int kth_;
|
||||||
@ -387,6 +396,9 @@ class ArgReduce : public UnaryPrimitive {
|
|||||||
DEFINE_PRINT(ArgReduce)
|
DEFINE_PRINT(ArgReduce)
|
||||||
bool is_equivalent(const Primitive& other) const override;
|
bool is_equivalent(const Primitive& other) const override;
|
||||||
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
|
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
|
||||||
|
std::pair<ReduceType, int> state() const {
|
||||||
|
return {reduce_type_, axis_};
|
||||||
|
};
|
||||||
|
|
||||||
private:
|
private:
|
||||||
ReduceType reduce_type_;
|
ReduceType reduce_type_;
|
||||||
@ -407,6 +419,9 @@ class ArgSort : public UnaryPrimitive {
|
|||||||
DEFINE_PRINT(ArgSort)
|
DEFINE_PRINT(ArgSort)
|
||||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||||
bool is_equivalent(const Primitive& other) const override;
|
bool is_equivalent(const Primitive& other) const override;
|
||||||
|
int state() const {
|
||||||
|
return axis_;
|
||||||
|
};
|
||||||
|
|
||||||
private:
|
private:
|
||||||
int axis_;
|
int axis_;
|
||||||
@ -427,6 +442,9 @@ class AsType : public UnaryPrimitive {
|
|||||||
DEFINE_PRINT(AsType)
|
DEFINE_PRINT(AsType)
|
||||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||||
bool is_equivalent(const Primitive& other) const override;
|
bool is_equivalent(const Primitive& other) const override;
|
||||||
|
// Returns the stored dtype_.
Dtype state() const {
  return dtype_;
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
Dtype dtype_;
|
Dtype dtype_;
|
||||||
@ -448,6 +466,9 @@ class AsStrided : public UnaryPrimitive {
|
|||||||
DEFINE_GRADS()
|
DEFINE_GRADS()
|
||||||
DEFINE_PRINT(AsStrided)
|
DEFINE_PRINT(AsStrided)
|
||||||
bool is_equivalent(const Primitive& other) const override;
|
bool is_equivalent(const Primitive& other) const override;
|
||||||
|
auto state() const {
|
||||||
|
return std::make_tuple(shape_, strides_, offset_);
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
Shape shape_;
|
Shape shape_;
|
||||||
@ -472,6 +493,9 @@ class BitwiseBinary : public UnaryPrimitive {
|
|||||||
bool is_equivalent(const Primitive& other) const override;
|
bool is_equivalent(const Primitive& other) const override;
|
||||||
void print(std::ostream& os) override;
|
void print(std::ostream& os) override;
|
||||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||||
|
auto state() const {
|
||||||
|
return op_;
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
Op op_;
|
Op op_;
|
||||||
@ -493,6 +517,9 @@ class BlockMaskedMM : public UnaryPrimitive {
|
|||||||
|
|
||||||
DEFINE_PRINT(BlockMaskedMM)
|
DEFINE_PRINT(BlockMaskedMM)
|
||||||
bool is_equivalent(const Primitive& other) const override;
|
bool is_equivalent(const Primitive& other) const override;
|
||||||
|
auto state() const {
|
||||||
|
return block_size_;
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
int block_size_;
|
int block_size_;
|
||||||
@ -532,6 +559,9 @@ class Broadcast : public UnaryPrimitive {
|
|||||||
DEFINE_GRADS()
|
DEFINE_GRADS()
|
||||||
DEFINE_PRINT(Broadcast)
|
DEFINE_PRINT(Broadcast)
|
||||||
bool is_equivalent(const Primitive& other) const override;
|
bool is_equivalent(const Primitive& other) const override;
|
||||||
|
std::vector<int> state() const {
|
||||||
|
return shape_;
|
||||||
|
};
|
||||||
|
|
||||||
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
|
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
|
||||||
|
|
||||||
@ -613,6 +643,9 @@ class Concatenate : public UnaryPrimitive {
|
|||||||
DEFINE_PRINT(Concatenate)
|
DEFINE_PRINT(Concatenate)
|
||||||
bool is_equivalent(const Primitive& other) const override;
|
bool is_equivalent(const Primitive& other) const override;
|
||||||
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
|
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
|
||||||
|
auto state() const {
|
||||||
|
return axis_;
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
int axis_;
|
int axis_;
|
||||||
@ -684,6 +717,15 @@ class Convolution : public UnaryPrimitive {
|
|||||||
|
|
||||||
DEFINE_PRINT(Convolution)
|
DEFINE_PRINT(Convolution)
|
||||||
bool is_equivalent(const Primitive& other) const override;
|
bool is_equivalent(const Primitive& other) const override;
|
||||||
|
auto state() const {
|
||||||
|
return std::make_tuple(
|
||||||
|
padding_,
|
||||||
|
kernel_strides_,
|
||||||
|
kernel_dilation_,
|
||||||
|
input_dilation_,
|
||||||
|
groups_,
|
||||||
|
flip_);
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::vector<int> padding_;
|
std::vector<int> padding_;
|
||||||
@ -912,6 +954,9 @@ class Equal : public UnaryPrimitive {
|
|||||||
os << "Equal";
|
os << "Equal";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
auto state() const {
|
||||||
|
return equal_nan_;
|
||||||
|
};
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void eval(const std::vector<array>& inputs, array& out);
|
void eval(const std::vector<array>& inputs, array& out);
|
||||||
@ -1001,6 +1046,9 @@ class ExpandDims : public UnaryPrimitive {
|
|||||||
bool is_equivalent(const Primitive& other) const override;
|
bool is_equivalent(const Primitive& other) const override;
|
||||||
|
|
||||||
static Shape output_shape(const array& input, const std::vector<int>& axes);
|
static Shape output_shape(const array& input, const std::vector<int>& axes);
|
||||||
|
auto state() const {
|
||||||
|
return axes_;
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void eval(const std::vector<array>& inputs, array& out);
|
void eval(const std::vector<array>& inputs, array& out);
|
||||||
@ -1024,6 +1072,9 @@ class FFT : public UnaryPrimitive {
|
|||||||
DEFINE_PRINT(FFT)
|
DEFINE_PRINT(FFT)
|
||||||
|
|
||||||
bool is_equivalent(const Primitive& other) const override;
|
bool is_equivalent(const Primitive& other) const override;
|
||||||
|
auto state() const {
|
||||||
|
return std::make_tuple(axes_, inverse_, real_);
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::vector<size_t> axes_;
|
std::vector<size_t> axes_;
|
||||||
@ -1048,6 +1099,9 @@ class Flatten : public UnaryPrimitive {
|
|||||||
bool is_equivalent(const Primitive& other) const override;
|
bool is_equivalent(const Primitive& other) const override;
|
||||||
|
|
||||||
static Shape output_shape(const array& input, int start_axis, int end_axis);
|
static Shape output_shape(const array& input, int start_axis, int end_axis);
|
||||||
|
auto state() const {
|
||||||
|
return std::make_pair(start_axis_, end_axis_);
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
int start_axis_;
|
int start_axis_;
|
||||||
@ -1103,6 +1157,9 @@ class Gather : public UnaryPrimitive {
|
|||||||
DEFINE_PRINT(Gather)
|
DEFINE_PRINT(Gather)
|
||||||
bool is_equivalent(const Primitive& other) const override;
|
bool is_equivalent(const Primitive& other) const override;
|
||||||
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
|
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
|
||||||
|
std::pair<std::vector<int>, std::vector<int>> state() const {
|
||||||
|
return {axes_, slice_sizes_};
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void eval(const std::vector<array>& inputs, array& out);
|
void eval(const std::vector<array>& inputs, array& out);
|
||||||
@ -1158,6 +1215,9 @@ class Hadamard : public UnaryPrimitive {
|
|||||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||||
|
|
||||||
bool is_equivalent(const Primitive& other) const override;
|
bool is_equivalent(const Primitive& other) const override;
|
||||||
|
auto state() const {
|
||||||
|
return scale_;
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
float scale_;
|
float scale_;
|
||||||
@ -1260,6 +1320,10 @@ class Log : public UnaryPrimitive {
|
|||||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||||
|
|
||||||
|
// Returns the stored logarithm base selector base_.
Base state() const {
  return base_;
}
|
||||||
|
|
||||||
void print(std::ostream& os) override {
|
void print(std::ostream& os) override {
|
||||||
switch (base_) {
|
switch (base_) {
|
||||||
case e:
|
case e:
|
||||||
@ -1488,6 +1552,9 @@ class NumberOfElements : public UnaryPrimitive {
|
|||||||
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override {
|
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override {
|
||||||
return {{}};
|
return {{}};
|
||||||
}
|
}
|
||||||
|
std::tuple<std::vector<int>, bool, Dtype> state() const {
|
||||||
|
return {axes_, inverted_, dtype_};
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::vector<int> axes_;
|
std::vector<int> axes_;
|
||||||
@ -1516,6 +1583,9 @@ class Pad : public UnaryPrimitive {
|
|||||||
DEFINE_GRADS()
|
DEFINE_GRADS()
|
||||||
DEFINE_PRINT(Pad)
|
DEFINE_PRINT(Pad)
|
||||||
bool is_equivalent(const Primitive& other) const override;
|
bool is_equivalent(const Primitive& other) const override;
|
||||||
|
auto state() const {
|
||||||
|
return std::make_tuple(axes_, low_pad_size_, high_pad_size_);
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::vector<int> axes_;
|
std::vector<int> axes_;
|
||||||
@ -1538,6 +1608,9 @@ class Partition : public UnaryPrimitive {
|
|||||||
DEFINE_PRINT(Partition)
|
DEFINE_PRINT(Partition)
|
||||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||||
bool is_equivalent(const Primitive& other) const override;
|
bool is_equivalent(const Primitive& other) const override;
|
||||||
|
auto state() const {
|
||||||
|
return std::make_pair(kth_, axis_);
|
||||||
|
};
|
||||||
|
|
||||||
private:
|
private:
|
||||||
int kth_;
|
int kth_;
|
||||||
@ -1583,6 +1656,9 @@ class QuantizedMatmul : public UnaryPrimitive {
|
|||||||
DEFINE_PRINT(QuantizedMatmul)
|
DEFINE_PRINT(QuantizedMatmul)
|
||||||
bool is_equivalent(const Primitive& other) const override;
|
bool is_equivalent(const Primitive& other) const override;
|
||||||
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
|
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
|
||||||
|
auto state() const {
|
||||||
|
return std::make_tuple(group_size_, bits_, transpose_);
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
int group_size_;
|
int group_size_;
|
||||||
@ -1607,6 +1683,9 @@ class GatherQMM : public UnaryPrimitive {
|
|||||||
DEFINE_GRADS()
|
DEFINE_GRADS()
|
||||||
DEFINE_PRINT(GatherQMM)
|
DEFINE_PRINT(GatherQMM)
|
||||||
bool is_equivalent(const Primitive& other) const override;
|
bool is_equivalent(const Primitive& other) const override;
|
||||||
|
auto state() const {
|
||||||
|
return std::make_tuple(group_size_, bits_, transpose_);
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
int group_size_;
|
int group_size_;
|
||||||
@ -1627,6 +1706,9 @@ class RandomBits : public UnaryPrimitive {
|
|||||||
DEFINE_VMAP()
|
DEFINE_VMAP()
|
||||||
DEFINE_PRINT(RandomBits)
|
DEFINE_PRINT(RandomBits)
|
||||||
bool is_equivalent(const Primitive& other) const override;
|
bool is_equivalent(const Primitive& other) const override;
|
||||||
|
std::pair<std::vector<int>, int> state() const {
|
||||||
|
return {shape_, width_};
|
||||||
|
};
|
||||||
|
|
||||||
private:
|
private:
|
||||||
Shape shape_;
|
Shape shape_;
|
||||||
@ -1661,6 +1743,9 @@ class Reshape : public UnaryPrimitive {
|
|||||||
DEFINE_GRADS()
|
DEFINE_GRADS()
|
||||||
DEFINE_PRINT(Reshape)
|
DEFINE_PRINT(Reshape)
|
||||||
bool is_equivalent(const Primitive& other) const override;
|
bool is_equivalent(const Primitive& other) const override;
|
||||||
|
std::vector<int> state() const {
|
||||||
|
return shape_;
|
||||||
|
};
|
||||||
|
|
||||||
private:
|
private:
|
||||||
Shape shape_;
|
Shape shape_;
|
||||||
@ -1712,6 +1797,9 @@ class Reduce : public UnaryPrimitive {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
bool is_equivalent(const Primitive& other) const override;
|
bool is_equivalent(const Primitive& other) const override;
|
||||||
|
std::pair<ReduceType, std::vector<int>> state() const {
|
||||||
|
return {reduce_type_, axes_};
|
||||||
|
};
|
||||||
|
|
||||||
private:
|
private:
|
||||||
ReduceType reduce_type_;
|
ReduceType reduce_type_;
|
||||||
@ -1777,6 +1865,9 @@ class Scan : public UnaryPrimitive {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
bool is_equivalent(const Primitive& other) const override;
|
bool is_equivalent(const Primitive& other) const override;
|
||||||
|
auto state() const {
|
||||||
|
return std::make_tuple(reduce_type_, axis_, reverse_, inclusive_);
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
ReduceType reduce_type_;
|
ReduceType reduce_type_;
|
||||||
@ -1823,6 +1914,9 @@ class Scatter : public UnaryPrimitive {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
bool is_equivalent(const Primitive& other) const override;
|
bool is_equivalent(const Primitive& other) const override;
|
||||||
|
std::pair<ReduceType, std::vector<int>> state() const {
|
||||||
|
return {reduce_type_, axes_};
|
||||||
|
};
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void eval(const std::vector<array>& inputs, array& out);
|
void eval(const std::vector<array>& inputs, array& out);
|
||||||
@ -1917,6 +2011,9 @@ class Slice : public UnaryPrimitive {
|
|||||||
DEFINE_GRADS()
|
DEFINE_GRADS()
|
||||||
DEFINE_PRINT(Slice)
|
DEFINE_PRINT(Slice)
|
||||||
bool is_equivalent(const Primitive& other) const override;
|
bool is_equivalent(const Primitive& other) const override;
|
||||||
|
auto state() const {
|
||||||
|
return std::make_tuple(start_indices_, end_indices_, strides_);
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
Shape start_indices_;
|
Shape start_indices_;
|
||||||
@ -1946,6 +2043,9 @@ class SliceUpdate : public UnaryPrimitive {
|
|||||||
DEFINE_PRINT(SliceUpdate)
|
DEFINE_PRINT(SliceUpdate)
|
||||||
bool is_equivalent(const Primitive& other) const override;
|
bool is_equivalent(const Primitive& other) const override;
|
||||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||||
|
auto state() const {
|
||||||
|
return std::make_tuple(start_indices_, end_indices_, strides_);
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
Shape start_indices_;
|
Shape start_indices_;
|
||||||
@ -1969,6 +2069,9 @@ class Softmax : public UnaryPrimitive {
|
|||||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||||
|
|
||||||
bool is_equivalent(const Primitive& other) const override;
|
bool is_equivalent(const Primitive& other) const override;
|
||||||
|
auto state() const {
|
||||||
|
return precise_;
|
||||||
|
};
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void eval(const std::vector<array>& inputs, array& out);
|
void eval(const std::vector<array>& inputs, array& out);
|
||||||
@ -1988,6 +2091,9 @@ class Sort : public UnaryPrimitive {
|
|||||||
DEFINE_PRINT(Sort)
|
DEFINE_PRINT(Sort)
|
||||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||||
bool is_equivalent(const Primitive& other) const override;
|
bool is_equivalent(const Primitive& other) const override;
|
||||||
|
auto state() const {
|
||||||
|
return axis_;
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
int axis_;
|
int axis_;
|
||||||
@ -2009,6 +2115,9 @@ class Split : public Primitive {
|
|||||||
DEFINE_GRADS()
|
DEFINE_GRADS()
|
||||||
DEFINE_PRINT(Split)
|
DEFINE_PRINT(Split)
|
||||||
bool is_equivalent(const Primitive& other) const override;
|
bool is_equivalent(const Primitive& other) const override;
|
||||||
|
std::pair<std::vector<int>, int> state() const {
|
||||||
|
return {indices_, axis_};
|
||||||
|
};
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
|
void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
|
||||||
@ -2046,6 +2155,9 @@ class Sqrt : public UnaryPrimitive {
|
|||||||
DEFINE_GRADS()
|
DEFINE_GRADS()
|
||||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||||
bool is_equivalent(const Primitive& other) const override;
|
bool is_equivalent(const Primitive& other) const override;
|
||||||
|
auto state() const {
|
||||||
|
return recip_;
|
||||||
|
}
|
||||||
|
|
||||||
void print(std::ostream& os) override {
|
void print(std::ostream& os) override {
|
||||||
if (recip_) {
|
if (recip_) {
|
||||||
@ -2109,6 +2221,9 @@ class Squeeze : public UnaryPrimitive {
|
|||||||
bool is_equivalent(const Primitive& other) const override;
|
bool is_equivalent(const Primitive& other) const override;
|
||||||
|
|
||||||
static Shape output_shape(const array& input, const std::vector<int>& axes);
|
static Shape output_shape(const array& input, const std::vector<int>& axes);
|
||||||
|
auto state() const {
|
||||||
|
return axes_;
|
||||||
|
};
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void eval(const std::vector<array>& inputs, array& out);
|
void eval(const std::vector<array>& inputs, array& out);
|
||||||
@ -2164,6 +2279,9 @@ class Unflatten : public UnaryPrimitive {
|
|||||||
bool is_equivalent(const Primitive& other) const override;
|
bool is_equivalent(const Primitive& other) const override;
|
||||||
|
|
||||||
static Shape output_shape(const array& input, int axis, const Shape& shape);
|
static Shape output_shape(const array& input, int axis, const Shape& shape);
|
||||||
|
auto state() const {
|
||||||
|
return std::make_pair(axis_, shape_);
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
int axis_;
|
int axis_;
|
||||||
@ -2171,21 +2289,6 @@ class Unflatten : public UnaryPrimitive {
|
|||||||
void eval(const std::vector<array>& inputs, array& out);
|
void eval(const std::vector<array>& inputs, array& out);
|
||||||
};
|
};
|
||||||
|
|
||||||
class Uniform : public UnaryPrimitive {
|
|
||||||
public:
|
|
||||||
explicit Uniform(Stream stream) : UnaryPrimitive(stream) {}
|
|
||||||
|
|
||||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
|
||||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
|
||||||
|
|
||||||
DEFINE_VMAP()
|
|
||||||
DEFINE_PRINT(Uniform)
|
|
||||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
|
||||||
|
|
||||||
private:
|
|
||||||
void eval(const std::vector<array>& inputs, array& out);
|
|
||||||
};
|
|
||||||
|
|
||||||
class View : public UnaryPrimitive {
|
class View : public UnaryPrimitive {
|
||||||
public:
|
public:
|
||||||
explicit View(Stream stream, Dtype dtype)
|
explicit View(Stream stream, Dtype dtype)
|
||||||
@ -2197,6 +2300,9 @@ class View : public UnaryPrimitive {
|
|||||||
DEFINE_VMAP()
|
DEFINE_VMAP()
|
||||||
void print(std::ostream& os) override;
|
void print(std::ostream& os) override;
|
||||||
bool is_equivalent(const Primitive& other) const override;
|
bool is_equivalent(const Primitive& other) const override;
|
||||||
|
auto state() const {
|
||||||
|
return dtype_;
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
Dtype dtype_;
|
Dtype dtype_;
|
||||||
@ -2215,6 +2321,9 @@ class Transpose : public UnaryPrimitive {
|
|||||||
DEFINE_PRINT(Transpose)
|
DEFINE_PRINT(Transpose)
|
||||||
bool is_equivalent(const Primitive& other) const override;
|
bool is_equivalent(const Primitive& other) const override;
|
||||||
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
|
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
|
||||||
|
std::vector<int> state() const {
|
||||||
|
return axes_;
|
||||||
|
};
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::vector<int> axes_;
|
std::vector<int> axes_;
|
||||||
@ -2266,6 +2375,9 @@ class Inverse : public UnaryPrimitive {
|
|||||||
|
|
||||||
DEFINE_VMAP()
|
DEFINE_VMAP()
|
||||||
DEFINE_PRINT(Inverse)
|
DEFINE_PRINT(Inverse)
|
||||||
|
auto state() const {
|
||||||
|
return std::make_pair(tri_, upper_);
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void eval(const std::vector<array>& inputs, array& output);
|
void eval(const std::vector<array>& inputs, array& output);
|
||||||
@ -2280,6 +2392,9 @@ class Cholesky : public UnaryPrimitive {
|
|||||||
|
|
||||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||||
|
auto state() const {
|
||||||
|
return upper_;
|
||||||
|
}
|
||||||
|
|
||||||
DEFINE_VMAP()
|
DEFINE_VMAP()
|
||||||
DEFINE_PRINT(Cholesky)
|
DEFINE_PRINT(Cholesky)
|
||||||
@ -2307,6 +2422,9 @@ class Eigh : public Primitive {
|
|||||||
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
|
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
|
||||||
|
|
||||||
bool is_equivalent(const Primitive& other) const override;
|
bool is_equivalent(const Primitive& other) const override;
|
||||||
|
auto state() const {
|
||||||
|
return std::make_pair(uplo_, compute_eigenvectors_);
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
|
void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
|
||||||
|
@ -21,6 +21,10 @@ void set_default_stream(Stream s) {
|
|||||||
return scheduler::scheduler().set_default_stream(s);
|
return scheduler::scheduler().set_default_stream(s);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Look up the stream with the given index by delegating to the scheduler.
Stream get_stream(int index) {
  auto&& sched = scheduler::scheduler();
  return sched.get_stream(index);
}
|
||||||
|
|
||||||
Stream new_stream(Device d) {
|
Stream new_stream(Device d) {
|
||||||
if (!metal::is_available() && d == Device::gpu) {
|
if (!metal::is_available() && d == Device::gpu) {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
|
@ -96,6 +96,9 @@ class Scheduler {
|
|||||||
Stream get_default_stream(const Device& d) const {
|
Stream get_default_stream(const Device& d) const {
|
||||||
return default_streams_.at(d.type);
|
return default_streams_.at(d.type);
|
||||||
}
|
}
|
||||||
|
// Return the stream stored at position `index` of streams_.
// The lookup goes through `at()`, so an invalid index is reported by the
// container rather than read out of bounds.
// NOTE(review): the exact exception depends on the type of streams_, which
// is declared outside this view.
Stream get_stream(int index) const {
  const auto& entry = streams_.at(index);
  return entry->stream;
}
|
||||||
|
|
||||||
void set_default_stream(const Stream& s) {
|
void set_default_stream(const Stream& s) {
|
||||||
default_streams_.at(s.device.type) = s;
|
default_streams_.at(s.device.type) = s;
|
||||||
|
@ -21,6 +21,9 @@ void set_default_stream(Stream s);
|
|||||||
/** Make a new stream on the given device. */
|
/** Make a new stream on the given device. */
|
||||||
Stream new_stream(Device d);
|
Stream new_stream(Device d);
|
||||||
|
|
||||||
|
/** Get the stream with the given index. */
|
||||||
|
Stream get_stream(int index);
|
||||||
|
|
||||||
inline bool operator==(const Stream& lhs, const Stream& rhs) {
|
inline bool operator==(const Stream& lhs, const Stream& rhs) {
|
||||||
return lhs.index == rhs.index;
|
return lhs.index == rhs.index;
|
||||||
}
|
}
|
||||||
|
@ -10,6 +10,11 @@ namespace mlx::core {
|
|||||||
|
|
||||||
void async_eval(std::vector<array> outputs);
|
void async_eval(std::vector<array> outputs);
|
||||||
|
|
||||||
|
template <typename... Arrays, typename = enable_for_arrays_t<Arrays...>>
|
||||||
|
void async_eval(Arrays&&... outputs) {
|
||||||
|
async_eval(std::vector<array>{std::forward<Arrays>(outputs)...});
|
||||||
|
}
|
||||||
|
|
||||||
void eval(std::vector<array> outputs);
|
void eval(std::vector<array> outputs);
|
||||||
|
|
||||||
template <typename... Arrays, typename = enable_for_arrays_t<Arrays...>>
|
template <typename... Arrays, typename = enable_for_arrays_t<Arrays...>>
|
||||||
|
@ -11,6 +11,7 @@ nanobind_add_module(
|
|||||||
${CMAKE_CURRENT_SOURCE_DIR}/convert.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/convert.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/export.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
|
||||||
|
277
python/src/export.cpp
Normal file
277
python/src/export.cpp
Normal file
@ -0,0 +1,277 @@
|
|||||||
|
// Copyright © 2024 Apple Inc.
|
||||||
|
#include <nanobind/nanobind.h>
|
||||||
|
#include <nanobind/stl/map.h>
|
||||||
|
#include <nanobind/stl/optional.h>
|
||||||
|
#include <nanobind/stl/string.h>
|
||||||
|
#include <nanobind/stl/vector.h>
|
||||||
|
|
||||||
|
#include <fstream>
|
||||||
|
|
||||||
|
#include "mlx/array.h"
|
||||||
|
#include "mlx/export.h"
|
||||||
|
#include "mlx/graph_utils.h"
|
||||||
|
#include "python/src/trees.h"
|
||||||
|
|
||||||
|
namespace mx = mlx::core;
|
||||||
|
namespace nb = nanobind;
|
||||||
|
using namespace nb::literals;
|
||||||
|
|
||||||
|
// Convert the Python call arguments into a (positional arrays, keyword
// arrays) pair.
//
// Accepted call forms:
//   - no positional arguments: `kwargs` must convert to a map of string keys
//     to arrays;
//   - first positional argument is an array: all of `args` must convert to
//     arrays and `kwargs` to a map of string keys to arrays;
//   - exactly one (non-array) positional argument: it is either a tuple or
//     list of arrays, or a dict with string keys and array values; `kwargs`
//     must be empty;
//   - exactly two positional arguments: a tuple or list of arrays followed by
//     a dict with string keys and array values; `kwargs` must be empty.
//
// Any other combination throws std::invalid_argument with a message that
// starts with `prefix`.
std::pair<std::vector<mx::array>, std::map<std::string, mx::array>>
validate_and_extract_inputs(
    const nb::args& args,
    const nb::kwargs& kwargs,
    const std::string& prefix) {
  auto maybe_throw = [&prefix](bool valid) {
    if (!valid) {
      throw std::invalid_argument(
          prefix +
          " Inputs can either be a variable "
          "number of positional and keyword arrays or a single tuple "
          "and/or dictionary of arrays.");
    }
  };
  std::vector<mx::array> args_;
  std::map<std::string, mx::array> kwargs_;
  if (args.size() == 0) {
    // No args so kwargs must be keyword arrays
    maybe_throw(nb::try_cast(kwargs, kwargs_));
  } else if (nb::isinstance<mx::array>(args[0])) {
    // Args are positional arrays and kwargs are keyword arrays
    maybe_throw(nb::try_cast(args, args_));
    maybe_throw(nb::try_cast(kwargs, kwargs_));
  } else if (args.size() == 1) {
    // - args[0] can be a tuple or list of arrays or a dict
    //   with string keys and array values
    // - kwargs should be empty
    maybe_throw(kwargs.size() == 0);
    if (!nb::try_cast(args[0], args_)) {
      maybe_throw(nb::try_cast(args[0], kwargs_));
    }
  } else if (args.size() == 2) {
    // - args[0] can be a tuple or list of arrays
    // - args[1] can be a dict of string keys with array values.
    // - kwargs should be empty
    maybe_throw(kwargs.size() == 0);
    maybe_throw(nb::try_cast(args[0], args_));
    maybe_throw(nb::try_cast(args[1], kwargs_));
  } else {
    maybe_throw(false);
  }
  // Move the containers into the result instead of copying them.
  return {std::move(args_), std::move(kwargs_)};
}
|
||||||
|
|
||||||
|
// Adapt a Python callable to a C++ callable taking (positional arrays,
// keyword arrays) and returning a vector of arrays.
//
// The returned lambda calls `fun(*args, **kwargs)` and accepts either a
// single array or anything castable to a vector of arrays as the result;
// any other return value throws std::invalid_argument.
auto wrap_export_function(const nb::callable& fun) {
  return [fun](
             const std::vector<mx::array>& args_,
             const std::map<std::string, mx::array>& kwargs_) {
    auto kwargs = nb::dict();
    kwargs.update(nb::cast(kwargs_));
    auto args = nb::tuple(nb::cast(args_));
    auto outputs = fun(*args, **kwargs);
    std::vector<mx::array> outputs_;
    if (nb::isinstance<mx::array>(outputs)) {
      outputs_.push_back(nb::cast<mx::array>(outputs));
    } else if (!nb::try_cast(outputs, outputs_)) {
      throw std::invalid_argument(
          "[export_function] Outputs can be either a single array "
          "or a tuple or list of arrays.");
    }
    return outputs_;
  };
}
|
||||||
|
|
||||||
|
void init_export(nb::module_& m) {
|
||||||
|
m.def(
|
||||||
|
"export_function",
|
||||||
|
[](const std::string& file,
|
||||||
|
const nb::callable& fun,
|
||||||
|
const nb::args& args,
|
||||||
|
bool shapeless,
|
||||||
|
const nb::kwargs& kwargs) {
|
||||||
|
auto [args_, kwargs_] =
|
||||||
|
validate_and_extract_inputs(args, kwargs, "[export_function]");
|
||||||
|
mx::export_function(
|
||||||
|
file, wrap_export_function(fun), args_, kwargs_, shapeless);
|
||||||
|
},
|
||||||
|
"file"_a,
|
||||||
|
"fun"_a,
|
||||||
|
"args"_a,
|
||||||
|
nb::kw_only(),
|
||||||
|
"shapeless"_a = false,
|
||||||
|
"kwargs"_a,
|
||||||
|
R"pbdoc(
|
||||||
|
Export a function to a file.
|
||||||
|
|
||||||
|
Example input arrays must be provided to export a function. The example
|
||||||
|
inputs can be variable ``*args`` and ``**kwargs`` or a tuple of arrays
|
||||||
|
and/or dictionary of string keys with array values.
|
||||||
|
|
||||||
|
.. warning::
|
||||||
|
|
||||||
|
This is part of an experimental API which is likely to
|
||||||
|
change in future versions of MLX. Functions exported with older
|
||||||
|
versions of MLX may not be compatible with future versions.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file (str): File path to export the function to.
|
||||||
|
fun (Callable): A function which takes as input zero or more
|
||||||
|
:class:`array` and returns one or more :class:`array`.
|
||||||
|
*args (array): Example array inputs to the function.
|
||||||
|
shapeless (bool, optional): Whether or not the function allows
|
||||||
|
inputs with variable shapes. Default: ``False``.
|
||||||
|
**kwargs (array): Additional example keyword array inputs to the
|
||||||
|
function.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
def fun(x, y):
|
||||||
|
return x + y
|
||||||
|
|
||||||
|
x = mx.array(1)
|
||||||
|
y = mx.array([1, 2, 3])
|
||||||
|
mx.export_function("fun.mlxfn", fun, x, y=y)
|
||||||
|
)pbdoc");
|
||||||
|
m.def(
|
||||||
|
"import_function",
|
||||||
|
[](const std::string& file) {
|
||||||
|
return nb::cpp_function(
|
||||||
|
[fn = mx::import_function(file)](
|
||||||
|
const nb::args& args, const nb::kwargs& kwargs) {
|
||||||
|
auto [args_, kwargs_] = validate_and_extract_inputs(
|
||||||
|
args, kwargs, "[import_function::call]");
|
||||||
|
return nb::tuple(nb::cast(fn(args_, kwargs_)));
|
||||||
|
});
|
||||||
|
},
|
||||||
|
"file"_a,
|
||||||
|
nb::sig("def import_function(file: str) -> Callable"),
|
||||||
|
R"pbdoc(
|
||||||
|
Import a function from a file.
|
||||||
|
|
||||||
|
The imported function can be called either with ``*args`` and
|
||||||
|
``**kwargs`` or with a tuple of arrays and/or dictionary of string
|
||||||
|
keys with array values. Imported functions always return a tuple of
|
||||||
|
arrays.
|
||||||
|
|
||||||
|
.. warning::
|
||||||
|
|
||||||
|
This is part of an experimental API which is likely to
|
||||||
|
change in future versions of MLX. Functions exported with older
|
||||||
|
versions of MLX may not be compatible with future versions.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file (str): The file path to import the function from.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Callable: The imported function.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> fn = mx.import_function("function.mlxfn")
|
||||||
|
>>> out = fn(a, b, x=x, y=y)[0]
|
||||||
|
>>>
|
||||||
|
>>> out = fn((a, b), {"x": x, "y": y}[0]
|
||||||
|
)pbdoc");
|
||||||
|
|
||||||
|
nb::class_<mx::FunctionExporter>(
|
||||||
|
m,
|
||||||
|
"FunctionExporter",
|
||||||
|
R"pbdoc(
|
||||||
|
A context managing class for exporting multiple traces of the same
|
||||||
|
function to a file.
|
||||||
|
|
||||||
|
Make an instance of this class by calling fun:`mx.exporter`.
|
||||||
|
)pbdoc")
|
||||||
|
.def("close", &mx::FunctionExporter::close)
|
||||||
|
.def(
|
||||||
|
"__enter__", [](mx::FunctionExporter& exporter) { return &exporter; })
|
||||||
|
.def(
|
||||||
|
"__exit__",
|
||||||
|
[](mx::FunctionExporter& exporter,
|
||||||
|
const std::optional<nb::object>&,
|
||||||
|
const std::optional<nb::object>&,
|
||||||
|
const std::optional<nb::object>&) { exporter.close(); },
|
||||||
|
"exc_type"_a = nb::none(),
|
||||||
|
"exc_value"_a = nb::none(),
|
||||||
|
"traceback"_a = nb::none())
|
||||||
|
.def(
|
||||||
|
"__call__",
|
||||||
|
[](mx::FunctionExporter& exporter,
|
||||||
|
const nb::args& args,
|
||||||
|
const nb::kwargs& kwargs) {
|
||||||
|
auto [args_, kwargs_] =
|
||||||
|
validate_and_extract_inputs(args, kwargs, "[export_function]");
|
||||||
|
exporter(args_, kwargs_);
|
||||||
|
});
|
||||||
|
|
||||||
|
m.def(
|
||||||
|
"exporter",
|
||||||
|
[](const std::string& file, const nb::callable& fun, bool shapeless) {
|
||||||
|
return mx::exporter(file, wrap_export_function(fun), shapeless);
|
||||||
|
},
|
||||||
|
"file"_a,
|
||||||
|
"fun"_a,
|
||||||
|
nb::kw_only(),
|
||||||
|
"shapeless"_a = false,
|
||||||
|
R"pbdoc(
|
||||||
|
Make a callable object to export multiple traces of a function to a file.
|
||||||
|
|
||||||
|
.. warning::
|
||||||
|
|
||||||
|
This is part of an experimental API which is likely to
|
||||||
|
change in future versions of MLX. Functions exported with older
|
||||||
|
versions of MLX may not be compatible with future versions.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file (str): File path to export the function to.
|
||||||
|
shapeless (bool, optional): Whether or not the function allows
|
||||||
|
inputs with variable shapes. Default: ``False``.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
def fun(*args):
|
||||||
|
return sum(args)
|
||||||
|
|
||||||
|
with mx.exporter("fun.mlxfn", fun) as exporter:
|
||||||
|
exporter(mx.array(1))
|
||||||
|
exporter(mx.array(1), mx.array(2))
|
||||||
|
exporter(mx.array(1), mx.array(2), mx.array(3))
|
||||||
|
)pbdoc");
|
||||||
|
m.def(
|
||||||
|
"export_to_dot",
|
||||||
|
[](nb::object file, const nb::args& args) {
|
||||||
|
std::vector<mx::array> arrays = tree_flatten(args);
|
||||||
|
if (nb::isinstance<nb::str>(file)) {
|
||||||
|
std::ofstream out(nb::cast<std::string>(file));
|
||||||
|
mx::export_to_dot(out, arrays);
|
||||||
|
} else if (nb::hasattr(file, "write")) {
|
||||||
|
std::ostringstream out;
|
||||||
|
mx::export_to_dot(out, arrays);
|
||||||
|
auto write = file.attr("write");
|
||||||
|
write(out.str());
|
||||||
|
} else {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"[export_to_dot] Accepts file-like objects or strings "
|
||||||
|
"to be used as filenames.");
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"file"_a,
|
||||||
|
"args"_a,
|
||||||
|
R"pbdoc(
|
||||||
|
Export a graph to DOT format for visualization.
|
||||||
|
|
||||||
|
A variable number of output arrays can be provided for exporting.
|
||||||
|
The graph exported will recursively include all unevaluated inputs of
|
||||||
|
the provided outputs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file (str or file-like): The file path, or a file-like object with a ``write`` method, to export to.
|
||||||
|
*args (array): The output arrays.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> a = mx.array(1) + mx.array(2)
|
||||||
|
>>> mx.export_to_dot("graph.dot", a)
|
||||||
|
)pbdoc");
|
||||||
|
}
|
@ -19,6 +19,7 @@ void init_linalg(nb::module_&);
|
|||||||
void init_constants(nb::module_&);
|
void init_constants(nb::module_&);
|
||||||
void init_fast(nb::module_&);
|
void init_fast(nb::module_&);
|
||||||
void init_distributed(nb::module_&);
|
void init_distributed(nb::module_&);
|
||||||
|
void init_export(nb::module_&);
|
||||||
|
|
||||||
NB_MODULE(core, m) {
|
NB_MODULE(core, m) {
|
||||||
m.doc() = "mlx: A framework for machine learning on Apple silicon.";
|
m.doc() = "mlx: A framework for machine learning on Apple silicon.";
|
||||||
@ -39,6 +40,7 @@ NB_MODULE(core, m) {
|
|||||||
init_constants(m);
|
init_constants(m);
|
||||||
init_fast(m);
|
init_fast(m);
|
||||||
init_distributed(m);
|
init_distributed(m);
|
||||||
|
init_export(m);
|
||||||
|
|
||||||
m.attr("__version__") = TOSTRING(_VERSION_);
|
m.attr("__version__") = TOSTRING(_VERSION_);
|
||||||
}
|
}
|
||||||
|
@ -2898,7 +2898,7 @@ void init_ops(nb::module_& m) {
|
|||||||
Generate multidimensional coordinate grids from 1-D coordinate arrays
|
Generate multidimensional coordinate grids from 1-D coordinate arrays
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
arrays (array): Input arrays.
|
*arrays (array): Input arrays.
|
||||||
sparse (bool, optional): If ``True``, a sparse grid is returned in which each output
|
sparse (bool, optional): If ``True``, a sparse grid is returned in which each output
|
||||||
array has a single non-zero element. If ``False``, a dense grid is returned.
|
array has a single non-zero element. If ``False``, a dense grid is returned.
|
||||||
Defaults to ``False``.
|
Defaults to ``False``.
|
||||||
@ -3840,8 +3840,8 @@ void init_ops(nb::module_& m) {
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
file (file, str): Path to file to which the arrays are saved.
|
file (file, str): Path to file to which the arrays are saved.
|
||||||
args (arrays): Arrays to be saved.
|
*args (arrays): Arrays to be saved.
|
||||||
kwargs (arrays): Arrays to be saved. Each array will be saved
|
**kwargs (arrays): Arrays to be saved. Each array will be saved
|
||||||
with the associated keyword as the output file name.
|
with the associated keyword as the output file name.
|
||||||
)pbdoc");
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
|
@ -8,14 +8,12 @@
|
|||||||
#include <nanobind/stl/vector.h>
|
#include <nanobind/stl/vector.h>
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <fstream>
|
|
||||||
#include <numeric>
|
#include <numeric>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
|
|
||||||
#include "mlx/array.h"
|
#include "mlx/array.h"
|
||||||
#include "mlx/compile.h"
|
#include "mlx/compile.h"
|
||||||
#include "mlx/compile_impl.h"
|
#include "mlx/compile_impl.h"
|
||||||
#include "mlx/graph_utils.h"
|
|
||||||
#include "mlx/transforms.h"
|
#include "mlx/transforms.h"
|
||||||
#include "mlx/transforms_impl.h"
|
#include "mlx/transforms_impl.h"
|
||||||
#include "mlx/utils.h"
|
#include "mlx/utils.h"
|
||||||
@ -945,7 +943,7 @@ void init_transforms(nb::module_& m) {
|
|||||||
Note, all custom transformations are optional. Undefined transformations
|
Note, all custom transformations are optional. Undefined transformations
|
||||||
fall back to the default behaviour.
|
fall back to the default behaviour.
|
||||||
|
|
||||||
Example usage:
|
Example:
|
||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
@ -1313,25 +1311,6 @@ void init_transforms(nb::module_& m) {
|
|||||||
Returns:
|
Returns:
|
||||||
Callable: The vectorized function.
|
Callable: The vectorized function.
|
||||||
)pbdoc");
|
)pbdoc");
|
||||||
m.def(
|
|
||||||
"export_to_dot",
|
|
||||||
[](nb::object file, const nb::args& args) {
|
|
||||||
std::vector<mx::array> arrays = tree_flatten(args);
|
|
||||||
if (nb::isinstance<nb::str>(file)) {
|
|
||||||
std::ofstream out(nb::cast<std::string>(file));
|
|
||||||
export_to_dot(out, arrays);
|
|
||||||
} else if (nb::hasattr(file, "write")) {
|
|
||||||
std::ostringstream out;
|
|
||||||
export_to_dot(out, arrays);
|
|
||||||
auto write = file.attr("write");
|
|
||||||
write(out.str());
|
|
||||||
} else {
|
|
||||||
throw std::invalid_argument(
|
|
||||||
"export_to_dot accepts file-like objects or strings to be used as filenames");
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"file"_a,
|
|
||||||
"args"_a);
|
|
||||||
m.def(
|
m.def(
|
||||||
"compile",
|
"compile",
|
||||||
[](const nb::callable& fun,
|
[](const nb::callable& fun,
|
||||||
|
244
python/tests/test_export_import.py
Normal file
244
python/tests/test_export_import.py
Normal file
@ -0,0 +1,244 @@
|
|||||||
|
# Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx_tests
|
||||||
|
|
||||||
|
|
||||||
|
class TestExportImport(mlx_tests.MLXTestCase):
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
cls.test_dir_fid = tempfile.TemporaryDirectory()
|
||||||
|
cls.test_dir = cls.test_dir_fid.name
|
||||||
|
if not os.path.isdir(cls.test_dir):
|
||||||
|
os.mkdir(cls.test_dir)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def tearDownClass(cls):
|
||||||
|
cls.test_dir_fid.cleanup()
|
||||||
|
|
||||||
|
def test_basic_export_import(self):
|
||||||
|
path = os.path.join(self.test_dir, "fn.mlxfn")
|
||||||
|
|
||||||
|
# Function with no inputs
|
||||||
|
def fun():
|
||||||
|
return mx.zeros((3, 3))
|
||||||
|
|
||||||
|
mx.export_function(path, fun)
|
||||||
|
imported = mx.import_function(path)
|
||||||
|
|
||||||
|
expected = fun()
|
||||||
|
(out,) = imported()
|
||||||
|
self.assertTrue(mx.array_equal(out, expected))
|
||||||
|
|
||||||
|
# Simple function with inputs
|
||||||
|
def fun(x):
|
||||||
|
return mx.abs(mx.sin(x))
|
||||||
|
|
||||||
|
inputs = mx.array([1.0, 2.0, 3.0, 4.0, 5.0])
|
||||||
|
|
||||||
|
mx.export_function(path, fun, inputs)
|
||||||
|
imported = mx.import_function(path)
|
||||||
|
|
||||||
|
expected = fun(inputs)
|
||||||
|
(out,) = imported(inputs)
|
||||||
|
self.assertTrue(mx.allclose(out, expected))
|
||||||
|
|
||||||
|
# Inputs in a list or tuple
|
||||||
|
def fun(x):
|
||||||
|
x = mx.abs(mx.sin(x))
|
||||||
|
return x
|
||||||
|
|
||||||
|
mx.export_function(path, fun, [inputs])
|
||||||
|
imported = mx.import_function(path)
|
||||||
|
|
||||||
|
expected = fun(inputs)
|
||||||
|
(out,) = imported([inputs])
|
||||||
|
self.assertTrue(mx.allclose(out, expected))
|
||||||
|
|
||||||
|
(out,) = imported(inputs)
|
||||||
|
self.assertTrue(mx.allclose(out, expected))
|
||||||
|
|
||||||
|
mx.export_function(path, fun, (inputs,))
|
||||||
|
imported = mx.import_function(path)
|
||||||
|
(out,) = imported((inputs,))
|
||||||
|
self.assertTrue(mx.allclose(out, expected))
|
||||||
|
|
||||||
|
# Outputs in a list
|
||||||
|
def fun(x):
|
||||||
|
return [mx.abs(mx.sin(x))]
|
||||||
|
|
||||||
|
mx.export_function(path, fun, inputs)
|
||||||
|
imported = mx.import_function(path)
|
||||||
|
(out,) = imported(inputs)
|
||||||
|
self.assertTrue(mx.allclose(out, expected))
|
||||||
|
|
||||||
|
# Outputs in a tuple
|
||||||
|
def fun(x):
|
||||||
|
return (mx.abs(mx.sin(x)),)
|
||||||
|
|
||||||
|
mx.export_function(path, fun, inputs)
|
||||||
|
imported = mx.import_function(path)
|
||||||
|
(out,) = imported(inputs)
|
||||||
|
self.assertTrue(mx.allclose(out, expected))
|
||||||
|
|
||||||
|
# Check throws on invalid inputs / outputs
|
||||||
|
def fun(x):
|
||||||
|
return mx.abs(x)
|
||||||
|
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
mx.export_function(path, fun, "hi")
|
||||||
|
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
mx.export_function(path, fun, mx.array(1.0), "hi")
|
||||||
|
|
||||||
|
def fun(x):
|
||||||
|
return mx.abs(x[0][0])
|
||||||
|
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
mx.export_function(path, fun, [[mx.array(1.0)]])
|
||||||
|
|
||||||
|
def fun():
|
||||||
|
return (mx.zeros((3, 3)), 1)
|
||||||
|
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
mx.export_function(path, fun)
|
||||||
|
|
||||||
|
def fun():
|
||||||
|
return (mx.zeros((3, 3)), [mx.zeros((3, 3))])
|
||||||
|
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
mx.export_function(path, fun)
|
||||||
|
|
||||||
|
def fun(x, y):
|
||||||
|
return x + y
|
||||||
|
|
||||||
|
mx.export_function(path, fun, mx.array(1.0), mx.array(1.0))
|
||||||
|
imported = mx.import_function(path)
|
||||||
|
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
imported(mx.array(1.0), 1.0)
|
||||||
|
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
imported(mx.array(1.0), mx.array(1.0), mx.array(1.0))
|
||||||
|
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
imported(mx.array(1.0), [mx.array(1.0)])
|
||||||
|
|
||||||
|
def test_export_random_sample(self):
|
||||||
|
path = os.path.join(self.test_dir, "fn.mlxfn")
|
||||||
|
|
||||||
|
mx.random.seed(5)
|
||||||
|
|
||||||
|
def fun():
|
||||||
|
return mx.random.uniform(shape=(3,))
|
||||||
|
|
||||||
|
mx.export_function(path, fun)
|
||||||
|
imported = mx.import_function(path)
|
||||||
|
|
||||||
|
(out,) = imported()
|
||||||
|
|
||||||
|
mx.random.seed(5)
|
||||||
|
expected = fun()
|
||||||
|
|
||||||
|
self.assertTrue(mx.array_equal(out, expected))
|
||||||
|
|
||||||
|
def test_export_with_kwargs(self):
|
||||||
|
path = os.path.join(self.test_dir, "fn.mlxfn")
|
||||||
|
|
||||||
|
def fun(x, z=None):
|
||||||
|
out = x
|
||||||
|
if z is not None:
|
||||||
|
out += z
|
||||||
|
return out
|
||||||
|
|
||||||
|
x = mx.array([1, 2, 3])
|
||||||
|
y = mx.array([1, 1, 0])
|
||||||
|
z = mx.array([2, 2, 2])
|
||||||
|
|
||||||
|
mx.export_function(path, fun, (x,), {"z": z})
|
||||||
|
imported_fun = mx.import_function(path)
|
||||||
|
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
imported_fun(x, z)
|
||||||
|
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
imported_fun(x, y=z)
|
||||||
|
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
imported_fun((x,), {"y": z})
|
||||||
|
|
||||||
|
out = imported_fun(x, z=z)[0]
|
||||||
|
self.assertTrue(mx.array_equal(out, mx.array([3, 4, 5])))
|
||||||
|
|
||||||
|
out = imported_fun((x,), {"z": z})[0]
|
||||||
|
self.assertTrue(mx.array_equal(out, mx.array([3, 4, 5])))
|
||||||
|
|
||||||
|
mx.export_function(path, fun, x, z=z)
|
||||||
|
imported_fun = mx.import_function(path)
|
||||||
|
out = imported_fun(x, z=z)[0]
|
||||||
|
self.assertTrue(mx.array_equal(out, mx.array([3, 4, 5])))
|
||||||
|
|
||||||
|
out = imported_fun((x,), {"z": z})[0]
|
||||||
|
self.assertTrue(mx.array_equal(out, mx.array([3, 4, 5])))
|
||||||
|
|
||||||
|
# Only specify kwargs
|
||||||
|
mx.export_function(path, fun, x=x, z=z)
|
||||||
|
imported_fun = mx.import_function(path)
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
out = imported_fun(x, z=z)[0]
|
||||||
|
|
||||||
|
out = imported_fun(x=x, z=z)[0]
|
||||||
|
self.assertTrue(mx.array_equal(out, mx.array([3, 4, 5])))
|
||||||
|
|
||||||
|
out = imported_fun({"x": x, "z": z})[0]
|
||||||
|
self.assertTrue(mx.array_equal(out, mx.array([3, 4, 5])))
|
||||||
|
|
||||||
|
def test_export_variable_inputs(self):
|
||||||
|
path = os.path.join(self.test_dir, "fn.mlxfn")
|
||||||
|
|
||||||
|
def fun(x, y, z=None):
|
||||||
|
out = x + y
|
||||||
|
if z is not None:
|
||||||
|
out += z
|
||||||
|
return out
|
||||||
|
|
||||||
|
with mx.exporter(path, fun) as exporter:
|
||||||
|
exporter(mx.array([1, 2, 3]), mx.array([1, 1, 1]))
|
||||||
|
exporter(mx.array([1, 2, 3]), mx.array([1, 1, 1]), z=mx.array([2]))
|
||||||
|
|
||||||
|
with self.assertRaises(RuntimeError):
|
||||||
|
exporter(mx.array([1, 2, 3, 4]), mx.array([1, 1, 1, 1]))
|
||||||
|
|
||||||
|
imported_fun = mx.import_function(path)
|
||||||
|
out = imported_fun(mx.array([1, 2, 3]), mx.array([1, 1, 1]))[0]
|
||||||
|
self.assertTrue(mx.array_equal(out, mx.array([2, 3, 4])))
|
||||||
|
|
||||||
|
out = imported_fun(mx.array([1, 2, 3]), mx.array([1, 1, 1]), z=mx.array([2]))[0]
|
||||||
|
self.assertTrue(mx.array_equal(out, mx.array([4, 5, 6])))
|
||||||
|
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
imported_fun(mx.array([1, 2, 3, 4]), mx.array([1, 1, 1, 1]))
|
||||||
|
|
||||||
|
# A function with a large constant
|
||||||
|
constant = mx.zeros((16, 2048))
|
||||||
|
mx.eval(constant)
|
||||||
|
|
||||||
|
def fun(*args):
|
||||||
|
return constant + sum(args)
|
||||||
|
|
||||||
|
with mx.exporter(path, fun) as exporter:
|
||||||
|
for i in range(5):
|
||||||
|
exporter(*[mx.array(1)] * i)
|
||||||
|
|
||||||
|
# Check the exported file size < constant size + small amount
|
||||||
|
constants_size = constant.nbytes + 8192
|
||||||
|
self.assertTrue(os.path.getsize(path) < constants_size)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
@ -28,15 +28,14 @@ class TestLoad(mlx_tests.MLXTestCase):
|
|||||||
def setUpClass(cls):
|
def setUpClass(cls):
|
||||||
cls.test_dir_fid = tempfile.TemporaryDirectory()
|
cls.test_dir_fid = tempfile.TemporaryDirectory()
|
||||||
cls.test_dir = cls.test_dir_fid.name
|
cls.test_dir = cls.test_dir_fid.name
|
||||||
|
if not os.path.isdir(cls.test_dir):
|
||||||
|
os.mkdir(cls.test_dir)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def tearDownClass(cls):
|
def tearDownClass(cls):
|
||||||
cls.test_dir_fid.cleanup()
|
cls.test_dir_fid.cleanup()
|
||||||
|
|
||||||
def test_save_and_load(self):
|
def test_save_and_load(self):
|
||||||
if not os.path.isdir(self.test_dir):
|
|
||||||
os.mkdir(self.test_dir)
|
|
||||||
|
|
||||||
for dt in self.dtypes:
|
for dt in self.dtypes:
|
||||||
with self.subTest(dtype=dt):
|
with self.subTest(dtype=dt):
|
||||||
for i, shape in enumerate([(1,), (23,), (1024, 1024), (4, 6, 3, 1, 2)]):
|
for i, shape in enumerate([(1,), (23,), (1024, 1024), (4, 6, 3, 1, 2)]):
|
||||||
@ -64,9 +63,6 @@ class TestLoad(mlx_tests.MLXTestCase):
|
|||||||
self.assertTrue(np.array_equal(load_arr_mlx_npy, save_arr_npy))
|
self.assertTrue(np.array_equal(load_arr_mlx_npy, save_arr_npy))
|
||||||
|
|
||||||
def test_save_and_load_safetensors(self):
|
def test_save_and_load_safetensors(self):
|
||||||
if not os.path.isdir(self.test_dir):
|
|
||||||
os.mkdir(self.test_dir)
|
|
||||||
|
|
||||||
test_file = os.path.join(self.test_dir, "test.safetensors")
|
test_file = os.path.join(self.test_dir, "test.safetensors")
|
||||||
with self.assertRaises(Exception):
|
with self.assertRaises(Exception):
|
||||||
mx.save_safetensors(test_file, {"a": mx.ones((4, 4))}, {"testing": 0})
|
mx.save_safetensors(test_file, {"a": mx.ones((4, 4))}, {"testing": 0})
|
||||||
@ -330,9 +326,6 @@ class TestLoad(mlx_tests.MLXTestCase):
|
|||||||
self.assertTrue(np.array_equal(save_arrs_npy[k], v))
|
self.assertTrue(np.array_equal(save_arrs_npy[k], v))
|
||||||
|
|
||||||
def test_non_contiguous(self):
|
def test_non_contiguous(self):
|
||||||
if not os.path.isdir(self.test_dir):
|
|
||||||
os.mkdir(self.test_dir)
|
|
||||||
|
|
||||||
a = mx.broadcast_to(mx.array([1, 2]), [4, 2])
|
a = mx.broadcast_to(mx.array([1, 2]), [4, 2])
|
||||||
|
|
||||||
save_file = os.path.join(self.test_dir, "a.npy")
|
save_file = os.path.join(self.test_dir, "a.npy")
|
||||||
|
@ -24,6 +24,7 @@ target_sources(
|
|||||||
creations_tests.cpp
|
creations_tests.cpp
|
||||||
device_tests.cpp
|
device_tests.cpp
|
||||||
einsum_tests.cpp
|
einsum_tests.cpp
|
||||||
|
export_import_tests.cpp
|
||||||
eval_tests.cpp
|
eval_tests.cpp
|
||||||
fft_tests.cpp
|
fft_tests.cpp
|
||||||
load_tests.cpp
|
load_tests.cpp
|
||||||
|
165
tests/export_import_tests.cpp
Normal file
165
tests/export_import_tests.cpp
Normal file
@ -0,0 +1,165 @@
|
|||||||
|
// Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
|
#include <filesystem>
|
||||||
|
#include <stdexcept>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "doctest/doctest.h"
|
||||||
|
|
||||||
|
#include "mlx/export.h"
|
||||||
|
#include "mlx/mlx.h"
|
||||||
|
|
||||||
|
using namespace mlx::core;
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
std::string get_temp_file(const std::string& name) {
|
||||||
|
return std::filesystem::temp_directory_path().append(name);
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
TEST_CASE("test export basic functions") {
|
||||||
|
std::string file_path = get_temp_file("model.mlxfn");
|
||||||
|
|
||||||
|
auto fun = [](std::vector<array> x) -> std::vector<array> {
|
||||||
|
return {negative(exp(x[0]))};
|
||||||
|
};
|
||||||
|
|
||||||
|
export_function(file_path, fun, {array({1.0, 2.0})});
|
||||||
|
|
||||||
|
auto imported_fun = import_function(file_path);
|
||||||
|
|
||||||
|
// Check num inputs mismatch throws
|
||||||
|
CHECK_THROWS_AS(
|
||||||
|
imported_fun({array({1.0}), array({2.0})}), std::invalid_argument);
|
||||||
|
|
||||||
|
// Check shape mismatch throws
|
||||||
|
CHECK_THROWS_AS(imported_fun({array({1.0})}), std::invalid_argument);
|
||||||
|
|
||||||
|
// Check type mismatch throws
|
||||||
|
CHECK_THROWS_AS(imported_fun({array({1.0}, float16)}), std::invalid_argument);
|
||||||
|
|
||||||
|
auto expected = fun({array({1.0, -1.0})});
|
||||||
|
auto out = imported_fun({array({1.0, -1.0})});
|
||||||
|
CHECK(allclose(expected[0], out[0]).item<bool>());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE("test export function with no inputs") {
|
||||||
|
auto fun = [](std::vector<array> x) -> std::vector<array> {
|
||||||
|
return {zeros({2, 2})};
|
||||||
|
};
|
||||||
|
|
||||||
|
std::string file_path = get_temp_file("model.mlxfn");
|
||||||
|
|
||||||
|
export_function(file_path, fun, {});
|
||||||
|
|
||||||
|
auto imported_fun = import_function(file_path);
|
||||||
|
|
||||||
|
auto expected = fun({});
|
||||||
|
auto out = imported_fun({});
|
||||||
|
CHECK(allclose(expected[0], out[0]).item<bool>());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE("test export multi output primitives") {
|
||||||
|
std::string file_path = get_temp_file("model.mlxfn");
|
||||||
|
|
||||||
|
auto fun = [](std::vector<array> x) -> std::vector<array> {
|
||||||
|
return {divmod(x[0], x[1])};
|
||||||
|
};
|
||||||
|
|
||||||
|
auto inputs = std::vector<array>{array({5.0, -10.0}), array({3.0, -2.0})};
|
||||||
|
export_function(file_path, fun, inputs);
|
||||||
|
|
||||||
|
auto imported_fun = import_function(file_path);
|
||||||
|
|
||||||
|
auto expected = fun(inputs);
|
||||||
|
auto out = imported_fun(inputs);
|
||||||
|
CHECK(allclose(expected[0], out[0]).item<bool>());
|
||||||
|
CHECK(allclose(expected[1], out[1]).item<bool>());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE("test export primitives with state") {
|
||||||
|
std::string file_path = get_temp_file("model.mlxfn");
|
||||||
|
|
||||||
|
auto fun = [](std::vector<array> x) -> std::vector<array> {
|
||||||
|
return {argpartition(x[0], 2, 0)};
|
||||||
|
};
|
||||||
|
|
||||||
|
auto x = array({1, 3, 2, 4, 5, 7, 6, 8}, {4, 2});
|
||||||
|
export_function(file_path, fun, {x});
|
||||||
|
|
||||||
|
auto imported_fun = import_function(file_path);
|
||||||
|
|
||||||
|
auto expected = fun({x});
|
||||||
|
auto out = imported_fun({x});
|
||||||
|
CHECK(allclose(expected[0], out[0]).item<bool>());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE("test export functions with kwargs") {
|
||||||
|
std::string file_path = get_temp_file("model.mlxfn");
|
||||||
|
|
||||||
|
auto fun =
|
||||||
|
[](const std::map<std::string, array>& kwargs) -> std::vector<array> {
|
||||||
|
return {kwargs.at("x") + kwargs.at("y")};
|
||||||
|
};
|
||||||
|
|
||||||
|
export_function(file_path, fun, {{"x", array(1)}, {"y", array(2)}});
|
||||||
|
auto fn = import_function(file_path);
|
||||||
|
|
||||||
|
// Must use kwargs
|
||||||
|
CHECK_THROWS(fn({array(1), array(2)}));
|
||||||
|
|
||||||
|
// Wrong number of keys
|
||||||
|
CHECK_THROWS(fn({{"x", array(1)}, {"y", array(2)}, {"z", array(3)}}));
|
||||||
|
|
||||||
|
// Wrong keys
|
||||||
|
CHECK_THROWS(fn({{"a", array(1)}, {"b", array(2)}}));
|
||||||
|
|
||||||
|
// Works
|
||||||
|
auto out = fn({{"x", array(1)}, {"y", array(2)}})[0];
|
||||||
|
CHECK_EQ(out.item<int>(), 3);
|
||||||
|
out = fn({}, {{"x", array(1)}, {"y", array(2)}})[0];
|
||||||
|
CHECK_EQ(out.item<int>(), 3);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE("test export function with variable inputs") {
|
||||||
|
std::string file_path = get_temp_file("model.mlxfn");
|
||||||
|
|
||||||
|
auto fun = [](const std::vector<array>& args) -> std::vector<array> {
|
||||||
|
auto out = array({1, 1, 1, 1});
|
||||||
|
for (auto x : args) {
|
||||||
|
out = out + x;
|
||||||
|
}
|
||||||
|
return {out};
|
||||||
|
};
|
||||||
|
|
||||||
|
{
|
||||||
|
auto fn_exporter = exporter(file_path, fun);
|
||||||
|
fn_exporter({array(0), array(0)});
|
||||||
|
fn_exporter({array(0), array(0), array(0)});
|
||||||
|
}
|
||||||
|
|
||||||
|
auto imported_fun = import_function(file_path);
|
||||||
|
|
||||||
|
// Call with two inputs
|
||||||
|
auto out = imported_fun({array(1), array(2)})[0];
|
||||||
|
|
||||||
|
CHECK(array_equal(out, array({4, 4, 4, 4})).item<bool>());
|
||||||
|
|
||||||
|
// Call with three inputs
|
||||||
|
out = imported_fun({array(1), array(2), array(3)})[0];
|
||||||
|
CHECK(array_equal(out, array({7, 7, 7, 7})).item<bool>());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE("test export function on different stream") {
|
||||||
|
std::string file_path = get_temp_file("model.mlxfn");
|
||||||
|
|
||||||
|
// Caller is responsible for setting up streams before
|
||||||
|
// importing functions
|
||||||
|
auto fun = [](const std::vector<array>& args) -> std::vector<array> {
|
||||||
|
return {abs(args[0], Stream(1000, Device::cpu))};
|
||||||
|
};
|
||||||
|
|
||||||
|
export_function(file_path, fun, {array({0, 1, 2})});
|
||||||
|
|
||||||
|
CHECK_THROWS(import_function(file_path));
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user