mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 08:20:15 +08:00

* export with callback * export with callback * Add types, fix kwarg ordering bug + test * cleanup, test, fix * typos
99 lines
2.8 KiB
C++
99 lines
2.8 KiB
C++
// Copyright © 2024 Apple Inc.
|
|
|
|
#include "mlx/io/load.h"
|
|
|
|
#pragma once
|
|
|
|
namespace mlx::core {
|
|
|
|
struct FunctionTable;
|
|
|
|
struct FunctionExporter {
|
|
void operator()(const std::initializer_list<array>& args) {
|
|
this->operator()(Args(args));
|
|
}
|
|
void operator()(const Args& args);
|
|
void operator()(const Kwargs& kwargs);
|
|
void operator()(const Args& args, const Kwargs& kwargs);
|
|
|
|
void close();
|
|
|
|
FunctionExporter(const FunctionExporter&) = delete;
|
|
FunctionExporter& operator=(const FunctionExporter&) = delete;
|
|
FunctionExporter(FunctionExporter&& other) = default;
|
|
|
|
private:
|
|
friend FunctionExporter exporter(
|
|
const std::string&,
|
|
const std::function<std::vector<array>(const Args&)>&,
|
|
bool shapeless);
|
|
|
|
friend FunctionExporter exporter(
|
|
const std::string&,
|
|
const std::function<std::vector<array>(const Kwargs&)>&,
|
|
bool shapeless);
|
|
|
|
friend FunctionExporter exporter(
|
|
const std::string&,
|
|
const std::function<std::vector<array>(const Args&, const Kwargs&)>&,
|
|
bool shapeless);
|
|
|
|
friend FunctionExporter exporter(
|
|
const ExportCallback&,
|
|
const std::function<std::vector<array>(const Args&)>&,
|
|
bool shapeless);
|
|
|
|
friend FunctionExporter exporter(
|
|
const ExportCallback&,
|
|
const std::function<std::vector<array>(const Kwargs&)>&,
|
|
bool shapeless);
|
|
|
|
friend FunctionExporter exporter(
|
|
const ExportCallback&,
|
|
const std::function<std::vector<array>(const Args&, const Kwargs&)>&,
|
|
bool shapeless);
|
|
|
|
FunctionExporter(
|
|
const std::string& file,
|
|
std::function<std::vector<array>(const Args&, const Kwargs&)> fun,
|
|
bool shapeless);
|
|
|
|
FunctionExporter(
|
|
const ExportCallback& callback,
|
|
std::function<std::vector<array>(const Args&, const Kwargs&)> fun,
|
|
bool shapeless);
|
|
|
|
io::FileWriter os;
|
|
ExportCallback callback;
|
|
std::function<std::vector<array>(const Args&, const Kwargs& kwargs)> fun;
|
|
void export_function(const Args& args, const Kwargs& kwargs);
|
|
void export_with_callback(
|
|
const std::vector<array>& inputs,
|
|
const std::vector<array>& outputs,
|
|
const std::vector<array>& tape,
|
|
const std::vector<std::string>& kwarg_keys);
|
|
std::set<std::uintptr_t> constants;
|
|
int count{0};
|
|
bool closed{false};
|
|
std::shared_ptr<FunctionTable> ftable;
|
|
};
|
|
|
|
struct ImportedFunction {
|
|
std::vector<array> operator()(
|
|
const std::initializer_list<array>& args) const {
|
|
return this->operator()(Args(args));
|
|
}
|
|
std::vector<array> operator()(const Args& args) const;
|
|
std::vector<array> operator()(const Kwargs& kwargs) const;
|
|
std::vector<array> operator()(const Args& args, const Kwargs& kwargs) const;
|
|
|
|
private:
|
|
ImportedFunction(const std::string& file);
|
|
friend ImportedFunction import_function(const std::string&);
|
|
ImportedFunction();
|
|
|
|
std::shared_ptr<FunctionTable> ftable;
|
|
};
|
|
|
|
} // namespace mlx::core
|