mlx/mlx/export.h
Awni Hannun 4ba0c24a8f
Export / import functions to / from a file (#1642)
* export and import functions

* refactor + works for few primitives

* nit

* allow primitives with state

* nit

* nit

* simplify serialize / deserialize

* fix for constants

* python bindings

* maybe fix serialize failure case

* add example

* more primitives, training kind of works

* same result for python and c++

* some fixes

* fix export

* template it up

* some simplification

* rebase

* allow kwargs and multiple functions

* exporter

* more primitives for exporting

* deal with endianness

* handle invalid stream

* add docstring
2024-12-24 11:19:13 -08:00

67 lines
1.5 KiB
C++

// Copyright © 2024 Apple Inc.
#pragma once
#include <functional>
#include <map>
#include <set>
#include <string>
#include <vector>

#include "mlx/array.h"
namespace mlx::core {

/** Positional array inputs to an exported function, in call order. */
using Args = std::vector<array>;

/** Keyword array inputs to an exported function, keyed by argument name. */
using Kwargs = std::map<std::string, array>;

// Only forward-declared here; the definition is not part of this header.
// NOTE(review): presumably provided by mlx/export_impl.h, which is included
// at the end of this file -- confirm.
struct FunctionExporter;

/**
 * Make an exporter to save multiple traces of a given function to
 * the same file.
 *
 * The three overloads differ only in how `fun` receives its inputs:
 * positional arguments (`Args`), keyword arguments (`Kwargs`), or both.
 * In every case `fun` returns its outputs as a vector of arrays.
 *
 * @param file Destination file that the traces are saved to.
 * @param fun The function to trace.
 * @param shapeless Defaults to false. NOTE(review): presumably, when true,
 *     the saved traces are not specialized to the input shapes -- confirm
 *     against the implementation.
 * @return A FunctionExporter bound to `file` and `fun`.
 */
FunctionExporter exporter(
    const std::string& file,
    const std::function<std::vector<array>(const Args&)>& fun,
    bool shapeless = false);
FunctionExporter exporter(
    const std::string& file,
    const std::function<std::vector<array>(const Kwargs&)>& fun,
    bool shapeless = false);
FunctionExporter exporter(
    const std::string& file,
    const std::function<std::vector<array>(const Args&, const Kwargs&)>& fun,
    bool shapeless = false);

/**
 * Export a function to a file.
 *
 * The three overloads differ only in how `fun` receives its inputs:
 * positional arguments, keyword arguments, or both. The matching
 * `args` / `kwargs` are the inputs supplied for this export.
 * NOTE(review): presumably `fun` is traced with these inputs -- confirm.
 *
 * @param file Destination file.
 * @param fun The function to export.
 * @param args Positional inputs for `fun`.
 * @param kwargs Keyword inputs for `fun`.
 * @param shapeless Defaults to false. NOTE(review): presumably, when true,
 *     the export is not specialized to the input shapes -- confirm.
 */
void export_function(
    const std::string& file,
    const std::function<std::vector<array>(const Args&)>& fun,
    const Args& args,
    bool shapeless = false);
void export_function(
    const std::string& file,
    const std::function<std::vector<array>(const Kwargs&)>& fun,
    const Kwargs& kwargs,
    bool shapeless = false);
void export_function(
    const std::string& file,
    const std::function<std::vector<array>(const Args&, const Kwargs&)>& fun,
    const Args& args,
    const Kwargs& kwargs,
    bool shapeless = false);

// Only forward-declared here; the definition is not part of this header.
// NOTE(review): presumably provided by mlx/export_impl.h -- confirm.
struct ImportedFunction;

/**
 * Import a function from a file.
 *
 * @param file File to read. NOTE(review): presumably one written by
 *     `export_function` or by a FunctionExporter -- confirm.
 * @return The imported function.
 */
ImportedFunction import_function(const std::string& file);

} // namespace mlx::core
#include "mlx/export_impl.h"