mlx/mlx/export.h
Awni Hannun dc4eada7f0
Use unordered map for kwargs in export/import (#2087)
* use unordered map for kwargs in export/import

* comment
2025-04-21 07:17:22 -07:00

67 lines
1.5 KiB
C++

// Copyright © 2024 Apple Inc.
#pragma once
#include <functional>
#include <set>
#include <string>
#include <unordered_map>
#include <vector>
#include "mlx/array.h"
namespace mlx::core {
/** Ordered list of arrays passed to a function as its positional inputs. */
using Args = std::vector<array>;
/**
* Arrays passed to a function by name, keyed by a string. The container is
* unordered, so callers must not rely on its iteration order.
*/
using Kwargs = std::unordered_map<std::string, array>;
// Forward declaration only: the type is returned by value from exporter()
// below, but its definition is not in this part of the header.
// NOTE(review): presumably defined in "mlx/export_impl.h", which is included
// at the bottom of this file -- confirm.
struct FunctionExporter;
/**
* Make an exporter to save multiple traces of a given function to
* the same file.
*
* Three overloads are provided. They differ only in how @p fun receives its
* inputs: positional arguments (Args), keyword arguments (Kwargs), or both.
*
* @param file The file the traces are saved to.
* @param fun The function to export; it returns a vector of arrays.
* @param shapeless Defaults to false. NOTE(review): presumably allows an
*   exported trace to be used with inputs of other shapes -- confirm against
*   the implementation.
* @return The exporter for @p fun.
*/
FunctionExporter exporter(
const std::string& file,
const std::function<std::vector<array>(const Args&)>& fun,
bool shapeless = false);
FunctionExporter exporter(
const std::string& file,
const std::function<std::vector<array>(const Kwargs&)>& fun,
bool shapeless = false);
// The first parameter is named `file` here as well, matching the two
// overloads above and the doc comment.
FunctionExporter exporter(
const std::string& file,
const std::function<std::vector<array>(const Args&, const Kwargs&)>& fun,
bool shapeless = false);
/**
* Export a function to a file.
*
* The overloads mirror the signature of @p fun: positional arguments only,
* keyword arguments only, or both.
*
* @param file The file the function is exported to.
* @param fun The function to export; it returns a vector of arrays.
* @param args Positional inputs for @p fun (overloads whose callable takes
*   Args). NOTE(review): presumably the example inputs the function is traced
*   with -- confirm against the implementation.
* @param kwargs Keyword inputs for @p fun (overloads whose callable takes
*   Kwargs); same caveat as @p args.
* @param shapeless Defaults to false. NOTE(review): presumably allows the
*   exported function to be used with inputs of other shapes -- confirm.
*/
void export_function(
const std::string& file,
const std::function<std::vector<array>(const Args&)>& fun,
const Args& args,
bool shapeless = false);
void export_function(
const std::string& file,
const std::function<std::vector<array>(const Kwargs&)>& fun,
const Kwargs& kwargs,
bool shapeless = false);
void export_function(
const std::string& file,
const std::function<std::vector<array>(const Args&, const Kwargs&)>& fun,
const Args& args,
const Kwargs& kwargs,
bool shapeless = false);
// Forward declaration only: the type is returned by value from
// import_function() below, but its definition is not in this part of the
// header.
// NOTE(review): presumably defined in "mlx/export_impl.h", which is included
// at the bottom of this file -- confirm.
struct ImportedFunction;
/**
* Import a function from a file.
*
* @param file The file to import the function from. NOTE(review): presumably
*   a file previously written by export_function() or an exporter -- confirm.
* @return The imported function.
*/
ImportedFunction import_function(const std::string& file);
} // namespace mlx::core
#include "mlx/export_impl.h"