diff --git a/mlx/backend/cuda/device/utils.cuh b/mlx/backend/cuda/device/utils.cuh index 7ebc5d654..2e017ba21 100644 --- a/mlx/backend/cuda/device/utils.cuh +++ b/mlx/backend/cuda/device/utils.cuh @@ -1,6 +1,6 @@ // Copyright © 2025 Apple Inc. -// This file must not include any host-only code, utilies that work under both +// This file must not include any host-only code, utilities that work under both // host and device can be put here. // // See more about the requirements at: @@ -202,7 +202,7 @@ struct Limits< } }; -// CUDA 11 does not have host side arithmatic operators for half types. +// CUDA 11 does not have host side arithmetic operators for half types. template struct Limits< T, diff --git a/mlx/backend/cuda/kernel_utils.cuh b/mlx/backend/cuda/kernel_utils.cuh index 9fca29116..d889cd590 100644 --- a/mlx/backend/cuda/kernel_utils.cuh +++ b/mlx/backend/cuda/kernel_utils.cuh @@ -1,8 +1,8 @@ // Copyright © 2025 Apple Inc. -// This file includes host-only utilies for writing CUDA kernels, the difference -// from backend/cuda/device/utils.cuh is that the latter file only include -// device-only code. +// This file includes host-only utilities for writing CUDA kernels, the +// difference from backend/cuda/device/utils.cuh is that the latter file only +// include device-only code. #pragma once diff --git a/mlx/backend/cuda/utils.h b/mlx/backend/cuda/utils.h index 81b19e346..9e95a84ef 100644 --- a/mlx/backend/cuda/utils.h +++ b/mlx/backend/cuda/utils.h @@ -1,6 +1,6 @@ // Copyright © 2025 Apple Inc. -// This file include utilies that are used by C++ code (i.e. .cpp files). +// This file include utilities that are used by C++ code (i.e. .cpp files). #pragma once diff --git a/mlx/export.cpp b/mlx/export.cpp index ca7e316b3..3448178e2 100644 --- a/mlx/export.cpp +++ b/mlx/export.cpp @@ -3,6 +3,7 @@ #include #include "mlx/compile_impl.h" #include "mlx/fast_primitives.h" +#include "mlx/graph_utils.h" #include "mlx/primitives.h" #include "mlx/utils.h" #include "mlx/version.h" @@ -13,6 +14,7 @@ #primitive, { \ serialize_primitive, \ deserialize_primitive, \ + primitive_state, \ {__VA_ARGS__} \ } \ } @@ -34,15 +36,20 @@ struct PrimitiveSerializer { using Serializer = std::function; using Deserializer = std::function(Reader&, Stream s)>; + using StateExtractor = std::function(const Primitive&)>; + PrimitiveSerializer( Serializer serialize, Deserializer deserialize, + StateExtractor extract_state, std::vector keys = {}) : serialize(std::move(serialize)), deserialize(std::move(deserialize)), + extract_state(std::move(extract_state)), keys(std::move(keys)) {}; Serializer serialize; Deserializer deserialize; + StateExtractor extract_state; std::vector keys; }; @@ -198,6 +205,32 @@ void serialize_primitive(Writer& os, const Primitive& p) { } } +template +void extract_state(const T state, std::vector& unpacked_state) { + if constexpr (std::is_arithmetic_v) { + unpacked_state.push_back(state); + } else if constexpr (std::is_enum_v) { + unpacked_state.push_back(static_cast(state)); + } else if constexpr (is_iterable) { + unpacked_state.push_back(state); + } else if constexpr (is_pair || is_tuple) { + std::apply( + [&unpacked_state](auto&... x) { + (..., extract_state(x, unpacked_state)); + }, + state); + } +} + +template +std::vector primitive_state(const Primitive& p) { + std::vector state; + if constexpr (has_state) { + extract_state(static_cast(p).state(), state); + } + return state; +} + template std::shared_ptr deserialize_primitive(Reader& is, Stream s) { if constexpr (has_state) { @@ -383,6 +416,23 @@ struct PrimitiveFactory { "[import_function] Unable to deserialize primitive " + name); } }; + + std::pair> extract_state( + const std::shared_ptr& p) { + std::string name = p->name(); + name = name.substr(0, name.find(' ')); + if (auto it = name_remap.find(name); it != name_remap.end()) { + name = it->second; + } + + if (auto it = factory.find(name); it != factory.end()) { + auto state = it->second.extract_state(*p); + return {name, state}; + } else { + throw std::invalid_argument( + "[export_function] Unable to get state for primitive " + name); + } + }; }; void write_header(Writer& os, int count, bool shapeless) { @@ -416,8 +466,10 @@ struct FunctionTable { }; bool shapeless; std::unordered_map> table; - Function* find(const Args& args, const Kwargs& kwargs); - std::pair emplace(const Args& args, const Kwargs& kwargs); + Function* find(const Args& args, const std::map& kwargs); + std::pair emplace( + const Args& args, + const std::map& kwargs); void insert( std::vector kwarg_keys, std::vector inputs, @@ -453,12 +505,15 @@ struct FunctionTable { } private: - bool match(const Args& args, const Kwargs& kwargs, const Function& fun); + bool match( + const Args& args, + const std::map& kwargs, + const Function& fun); }; bool FunctionTable::match( const Args& args, - const Kwargs& kwargs, + const std::map& kwargs, const Function& fun) { for (auto& k : fun.kwarg_keys) { if (kwargs.find(k) == kwargs.end()) { @@ -486,9 +541,7 @@ bool FunctionTable::match( return false; } } - auto sorted_kwargs = - std::map(kwargs.begin(), kwargs.end()); - for (auto& [_, in] : sorted_kwargs) { + for (auto& [_, in] : kwargs) { if (!match_inputs(in, fun.inputs[i++])) { return false; } @@ -499,7 +552,7 @@ bool FunctionTable::match( std::pair FunctionTable::emplace( const Args& args, - const Kwargs& kwargs) { + const std::map& kwargs) { auto n_inputs = args.size() + kwargs.size(); auto [it, _] = table.emplace(n_inputs, std::vector{}); auto& funs_vec = it->second; @@ -516,7 +569,7 @@ std::pair FunctionTable::emplace( FunctionTable::Function* FunctionTable::find( const Args& args, - const Kwargs& kwargs) { + const std::map& kwargs) { auto n_inputs = args.size() + kwargs.size(); auto it = table.find(n_inputs); if (it == table.end()) { @@ -545,16 +598,86 @@ FunctionExporter::FunctionExporter( write_header(os, count, shapeless); } +FunctionExporter::FunctionExporter( + const ExportCallback& callback, + std::function(const Args&, const Kwargs&)> fun, + bool shapeless) + : callback(callback), + fun(std::move(fun)), + ftable(std::make_shared(shapeless)) {} + void FunctionExporter::close() { closed = true; }; +void FunctionExporter::export_with_callback( + const std::vector& inputs, + const std::vector& outputs, + const std::vector& tape, + const std::vector& kwarg_keys) { + NodeNamer namer{}; + auto to_vector_data = [&namer](const auto& arrays) { + std::vector> data; + for (auto& a : arrays) { + data.emplace_back(namer.get_name(a), a.shape(), a.dtype()); + } + return data; + }; + + // Callback on the inputs + callback({{"type", "inputs"}, {"inputs", to_vector_data(inputs)}}); + std::vector> keyword_inputs; + for (int i = inputs.size() - kwarg_keys.size(), j = 0; i < inputs.size(); + ++i, ++j) { + keyword_inputs.emplace_back(kwarg_keys[j], namer.get_name(inputs[i])); + } + callback({{"type", "keyword_inputs"}, {"keywords", keyword_inputs}}); + + // Callback on the outputs + callback({{"type", "outputs"}, {"outputs", to_vector_data(outputs)}}); + + // Callback on the constants + { + std::unordered_set input_set; + for (auto& in : inputs) { + input_set.insert(in.id()); + } + std::vector> new_constants; + for (auto& arr : tape) { + if (arr.has_primitive() || input_set.find(arr.id()) != input_set.end()) { + continue; + } + if (constants.insert(arr.id()).second) { + new_constants.emplace_back(namer.get_name(arr), arr); + } + } + callback({{"type", "constants"}, {"constants", new_constants}}); + } + auto factory = PrimitiveFactory(); + + // Callback for each primitive in the tape + for (auto& arr : tape) { + if (!arr.has_primitive()) { + continue; + } + auto [name, state] = factory.extract_state(arr.primitive_ptr()); + callback( + {{"type", "primitive"}, + {"inputs", to_vector_data(arr.inputs())}, + {"outputs", to_vector_data(arr.outputs())}, + {"name", name}, + {"arguments", state}}); + } +} + void FunctionExporter::export_function(const Args& args, const Kwargs& kwargs) { if (closed) { throw std::runtime_error( "[export_function] Attempting to write after exporting is closed."); } - auto [fentry, inserted] = ftable->emplace(args, kwargs); + auto sorted_kwargs = + std::map(kwargs.begin(), kwargs.end()); + auto [fentry, inserted] = ftable->emplace(args, sorted_kwargs); if (!inserted) { throw std::runtime_error( "[export_function] Attempting to export a function twice with " @@ -564,8 +687,6 @@ void FunctionExporter::export_function(const Args& args, const Kwargs& kwargs) { // Flatten the inputs to the function for tracing std::vector kwarg_keys; auto inputs = args; - auto sorted_kwargs = - std::map(kwargs.begin(), kwargs.end()); for (auto& [k, v] : sorted_kwargs) { kwarg_keys.push_back(k); inputs.push_back(v); @@ -592,10 +713,18 @@ void FunctionExporter::export_function(const Args& args, const Kwargs& kwargs) { detail::compile_simplify(tape, parents_map, trace_outputs, /* passes */ 3); - // Update header + // Update the table entry + fentry.kwarg_keys = kwarg_keys; + fentry.inputs = trace_inputs; + count++; - // Overwrite the header + if (callback) { + export_with_callback(trace_inputs, trace_outputs, tape, kwarg_keys); + return; + } + + // Update the header auto pos = os.tell(); os.seek(0); write_header(os, count, ftable->shapeless); @@ -616,10 +745,6 @@ void FunctionExporter::export_function(const Args& args, const Kwargs& kwargs) { serialize(os, trace_inputs); serialize(os, arrays_to_ids(trace_outputs)); - // Update the table entry - fentry.kwarg_keys = std::move(kwarg_keys); - fentry.inputs = std::move(trace_inputs); - std::unordered_set input_set( trace_input_ids.begin(), trace_input_ids.end()); @@ -730,6 +855,58 @@ void export_function( exporter(file, fun, shapeless)(args, kwargs); } +FunctionExporter exporter( + const ExportCallback& callback, + const std::function(const Args&)>& fun, + bool shapeless /* = false */) { + return FunctionExporter{ + callback, + [fun](const Args& args, const Kwargs&) { return fun(args); }, + shapeless}; +} + +FunctionExporter exporter( + const ExportCallback& callback, + const std::function(const Kwargs&)>& fun, + bool shapeless /* = false */) { + return exporter( + callback, + [fun](const Args&, const Kwargs kwargs) { return fun(kwargs); }, + shapeless); +} + +FunctionExporter exporter( + const ExportCallback& callback, + const std::function(const Args&, const Kwargs&)>& fun, + bool shapeless /* = false */) { + return FunctionExporter{callback, fun, shapeless}; +} + +void export_function( + const ExportCallback& callback, + const std::function(const Args&)>& fun, + const Args& args, + bool shapeless /* = false */) { + exporter(callback, fun, shapeless)(args); +} + +void export_function( + const ExportCallback& callback, + const std::function(const Kwargs&)>& fun, + const Kwargs& kwargs, + bool shapeless /* = false */) { + exporter(callback, fun, shapeless)(kwargs); +} + +void export_function( + const ExportCallback& callback, + const std::function(const Args&, const Kwargs&)>& fun, + const Args& args, + const Kwargs& kwargs, + bool shapeless /* = false */) { + exporter(callback, fun, shapeless)(args, kwargs); +} + std::vector ImportedFunction::operator()(const Kwargs& kwargs) const { return this->operator()({}, kwargs); } @@ -741,7 +918,9 @@ std::vector ImportedFunction::operator()(const Args& args) const { std::vector ImportedFunction::operator()( const Args& args, const Kwargs& kwargs) const { - auto* fun = ftable->find(args, kwargs); + auto sorted_kwargs = + std::map(kwargs.begin(), kwargs.end()); + auto* fun = ftable->find(args, sorted_kwargs); if (fun == nullptr) { std::ostringstream msg; msg << "[import_function::call] No imported function found which matches " @@ -760,7 +939,7 @@ std::vector ImportedFunction::operator()( } auto inputs = args; - for (auto& [_, v] : kwargs) { + for (auto& [_, v] : sorted_kwargs) { inputs.push_back(v); } return detail::compile_replace( diff --git a/mlx/export.h b/mlx/export.h index c6859c6d8..715dac2c8 100644 --- a/mlx/export.h +++ b/mlx/export.h @@ -4,6 +4,7 @@ #include #include +#include #include "mlx/array.h" namespace mlx::core { @@ -11,6 +12,30 @@ namespace mlx::core { using Args = std::vector; using Kwargs = std::unordered_map; +// Possible types for a Primitive's state +using StateT = std::variant< + bool, + int, + size_t, + float, + double, + Dtype, + Shape, + Strides, + std::vector, + std::vector, + std::string>; + +using ExportCallbackInput = std::unordered_map< + std::string, + std::variant< + std::vector>, + std::vector>, + std::vector>, + std::vector, + std::string>>; +using ExportCallback = std::function; + struct FunctionExporter; /** @@ -61,6 +86,47 @@ struct ImportedFunction; */ ImportedFunction import_function(const std::string& file); +/** + * Make an exporter to export multiple traces of a given function with the same + * callback. + */ +FunctionExporter exporter( + const ExportCallback& callback, + const std::function(const Args&)>& fun, + bool shapeless = false); + +FunctionExporter exporter( + const ExportCallback& callback, + const std::function(const Kwargs&)>& fun, + bool shapeless = false); + +FunctionExporter exporter( + const ExportCallback& callback, + const std::function(const Args&, const Kwargs&)>& fun, + bool shapeless = false); + +/** + * Export a function with a callback. + */ +void export_function( + const ExportCallback& callback, + const std::function(const Args&)>& fun, + const Args& args, + bool shapeless = false); + +void export_function( + const ExportCallback& callback, + const std::function(const Kwargs&)>& fun, + const Kwargs& kwargs, + bool shapeless = false); + +void export_function( + const ExportCallback& callback, + const std::function(const Args&, const Kwargs&)>& fun, + const Args& args, + const Kwargs& kwargs, + bool shapeless = false); + } // namespace mlx::core #include "mlx/export_impl.h" diff --git a/mlx/export_impl.h b/mlx/export_impl.h index 74a1c35c1..0e7818981 100644 --- a/mlx/export_impl.h +++ b/mlx/export_impl.h @@ -38,13 +38,40 @@ struct FunctionExporter { const std::function(const Args&, const Kwargs&)>&, bool shapeless); + friend FunctionExporter exporter( + const ExportCallback&, + const std::function(const Args&)>&, + bool shapeless); + + friend FunctionExporter exporter( + const ExportCallback&, + const std::function(const Kwargs&)>&, + bool shapeless); + + friend FunctionExporter exporter( + const ExportCallback&, + const std::function(const Args&, const Kwargs&)>&, + bool shapeless); + FunctionExporter( const std::string& file, std::function(const Args&, const Kwargs&)> fun, bool shapeless); + + FunctionExporter( + const ExportCallback& callback, + std::function(const Args&, const Kwargs&)> fun, + bool shapeless); + io::FileWriter os; + ExportCallback callback; std::function(const Args&, const Kwargs& kwargs)> fun; void export_function(const Args& args, const Kwargs& kwargs); + void export_with_callback( + const std::vector& inputs, + const std::vector& outputs, + const std::vector& tape, + const std::vector& kwarg_keys); std::set constants; int count{0}; bool closed{false}; diff --git a/mlx/io/load.h b/mlx/io/load.h index 8b5dd95b6..0efcb367b 100644 --- a/mlx/io/load.h +++ b/mlx/io/load.h @@ -108,6 +108,7 @@ class ParallelFileReader : public Reader { class FileWriter : public Writer { public: + explicit FileWriter() {} explicit FileWriter(std::string file_path) : fd_(open( file_path.c_str(), diff --git a/python/src/export.cpp b/python/src/export.cpp index 4428e7cc8..3d3ffdf6b 100644 --- a/python/src/export.cpp +++ b/python/src/export.cpp @@ -1,8 +1,11 @@ // Copyright © 2024 Apple Inc. #include #include +#include #include +#include #include +#include #include #include @@ -131,24 +134,38 @@ auto wrap_export_function(nb::callable fun) { void init_export(nb::module_& m) { m.def( "export_function", - [](const std::string& file, + [](nb::object& file_or_callback, const nb::callable& fun, const nb::args& args, bool shapeless, const nb::kwargs& kwargs) { auto [args_, kwargs_] = validate_and_extract_inputs(args, kwargs, "[export_function]"); - mx::export_function( - file, wrap_export_function(fun), args_, kwargs_, shapeless); + if (nb::isinstance(file_or_callback)) { + mx::export_function( + nb::cast(file_or_callback), + wrap_export_function(fun), + args_, + kwargs_, + shapeless); + } else { + auto callback = nb::cast(file_or_callback); + auto wrapped_callback = + [callback](const mx::ExportCallbackInput& input) { + return callback(input); + }; + mx::export_function( + callback, wrap_export_function(fun), args_, kwargs_, shapeless); + } }, - "file"_a, + nb::arg(), "fun"_a, "args"_a, nb::kw_only(), "shapeless"_a = false, "kwargs"_a, R"pbdoc( - Export a function to a file. + Export an MLX function. Example input arrays must be provided to export a function. The example inputs can be variable ``*args`` and ``**kwargs`` or a tuple of arrays @@ -161,7 +178,8 @@ void init_export(nb::module_& m) { versions of MLX may not be compatible with future versions. Args: - file (str): File path to export the function to. + file (str or Callable): Either a file path to export the function + to or a callback. fun (Callable): A function which takes as input zero or more :class:`array` and returns one or more :class:`array`. *args (array): Example array inputs to the function. diff --git a/python/tests/test_compile.py b/python/tests/test_compile.py index b78673027..d8900d16d 100644 --- a/python/tests/test_compile.py +++ b/python/tests/test_compile.py @@ -319,7 +319,7 @@ class TestCompile(mlx_tests.MLXTestCase): # Check the state is unchanged self.assertEqual(state["y"], 2) - # Check the udpated state is used + # Check the updated state is used state["y"] = mx.array(3) out = test_state(mx.array(1)) self.assertEqual(out.item(), 4) diff --git a/python/tests/test_export_import.py b/python/tests/test_export_import.py index 1d8af8509..4a4ca82a5 100644 --- a/python/tests/test_export_import.py +++ b/python/tests/test_export_import.py @@ -485,6 +485,52 @@ class TestExportImport(mlx_tests.MLXTestCase): mx.array_equal(imported_fn(input_data)[0], model(input_data)) ) + def test_export_kwarg_ordering(self): + path = os.path.join(self.test_dir, "fun.mlxfn") + + def fn(x, y): + return x - y + + mx.export_function(path, fn, x=mx.array(1.0), y=mx.array(1.0)) + imported = mx.import_function(path) + out = imported(x=mx.array(2.0), y=mx.array(3.0))[0] + self.assertEqual(out.item(), -1.0) + out = imported(y=mx.array(2.0), x=mx.array(3.0))[0] + self.assertEqual(out.item(), 1.0) + + def test_export_with_callback(self): + + def fn(x, y): + return mx.log(mx.abs(x - y)) + + n_in = None + n_out = None + n_const = None + keywords = None + primitives = [] + + def callback(args): + nonlocal n_in, n_out, n_const, keywords, primitives + t = args["type"] + if t == "inputs": + n_in = len(args["inputs"]) + elif args["type"] == "outputs": + n_out = len(args["outputs"]) + elif args["type"] == "keyword_inputs": + keywords = args["keywords"] + elif t == "constants": + n_const = len(args["constants"]) + elif t == "primitive": + primitives.append(args["name"]) + + mx.export_function(callback, fn, mx.array(1.0), y=mx.array(1.0)) + self.assertEqual(n_in, 2) + self.assertEqual(n_out, 1) + self.assertEqual(n_const, 0) + self.assertEqual(len(keywords), 1) + self.assertEqual(keywords[0][0], "y") + self.assertEqual(primitives, ["Subtract", "Abs", "Log"]) + if __name__ == "__main__": mlx_tests.MLXTestRunner()