mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
cleanup, test, fix
This commit is contained in:
@@ -222,7 +222,6 @@ void extract_state(const T state, std::vector<StateT>& unpacked_state) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// std::vector<StateT> extract_state(const Primitive& p) {
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
std::vector<StateT> primitive_state(const Primitive& p) {
|
std::vector<StateT> primitive_state(const Primitive& p) {
|
||||||
std::vector<StateT> state;
|
std::vector<StateT> state;
|
||||||
@@ -628,8 +627,9 @@ void FunctionExporter::export_with_callback(
|
|||||||
// Callback on the inputs
|
// Callback on the inputs
|
||||||
callback({{"type", "inputs"}, {"inputs", to_vector_data(inputs)}});
|
callback({{"type", "inputs"}, {"inputs", to_vector_data(inputs)}});
|
||||||
std::vector<std::pair<std::string, std::string>> keyword_inputs;
|
std::vector<std::pair<std::string, std::string>> keyword_inputs;
|
||||||
for (int i = inputs.size() - kwarg_keys.size(); i < inputs.size(); ++i) {
|
for (int i = inputs.size() - kwarg_keys.size(), j = 0; i < inputs.size();
|
||||||
keyword_inputs.emplace_back(kwarg_keys[i], namer.get_name(inputs[i]));
|
++i, ++j) {
|
||||||
|
keyword_inputs.emplace_back(kwarg_keys[j], namer.get_name(inputs[i]));
|
||||||
}
|
}
|
||||||
callback({{"type", "keyword_inputs"}, {"keywords", keyword_inputs}});
|
callback({{"type", "keyword_inputs"}, {"keywords", keyword_inputs}});
|
||||||
|
|
||||||
|
|||||||
@@ -498,6 +498,39 @@ class TestExportImport(mlx_tests.MLXTestCase):
|
|||||||
out = imported(y=mx.array(2.0), x=mx.array(3.0))[0]
|
out = imported(y=mx.array(2.0), x=mx.array(3.0))[0]
|
||||||
self.assertEqual(out.item(), 1.0)
|
self.assertEqual(out.item(), 1.0)
|
||||||
|
|
||||||
|
def test_export_with_callback(self):
|
||||||
|
|
||||||
|
def fn(x, y):
|
||||||
|
return mx.log(mx.abs(x - y))
|
||||||
|
|
||||||
|
n_in = None
|
||||||
|
n_out = None
|
||||||
|
n_const = None
|
||||||
|
keywords = None
|
||||||
|
primitives = []
|
||||||
|
|
||||||
|
def callback(args):
|
||||||
|
nonlocal n_in, n_out, n_const, keywords, primitives
|
||||||
|
t = args["type"]
|
||||||
|
if t == "inputs":
|
||||||
|
n_in = len(args["inputs"])
|
||||||
|
elif args["type"] == "outputs":
|
||||||
|
n_out = len(args["outputs"])
|
||||||
|
elif args["type"] == "keyword_inputs":
|
||||||
|
keywords = args["keywords"]
|
||||||
|
elif t == "constants":
|
||||||
|
n_const = len(args["constants"])
|
||||||
|
elif t == "primitive":
|
||||||
|
primitives.append(args["name"])
|
||||||
|
|
||||||
|
mx.export_function(callback, fn, mx.array(1.0), y=mx.array(1.0))
|
||||||
|
self.assertEqual(n_in, 2)
|
||||||
|
self.assertEqual(n_out, 1)
|
||||||
|
self.assertEqual(n_const, 0)
|
||||||
|
self.assertEqual(len(keywords), 1)
|
||||||
|
self.assertEqual(keywords[0][0], "y")
|
||||||
|
self.assertEqual(primitives, ["Subtract", "Abs", "Log"])
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
mlx_tests.MLXTestRunner()
|
mlx_tests.MLXTestRunner()
|
||||||
|
|||||||
Reference in New Issue
Block a user