mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-16 22:11:15 +08:00
basic python tests
This commit is contained in:
parent
9739c72781
commit
0005cfe053
@ -1,5 +1,4 @@
|
|||||||
// Copyright © 2023 Apple Inc.
|
// Copyright © 2023 Apple Inc.
|
||||||
|
|
||||||
#include <pybind11/functional.h>
|
#include <pybind11/functional.h>
|
||||||
#include <pybind11/pybind11.h>
|
#include <pybind11/pybind11.h>
|
||||||
#include <pybind11/stl.h>
|
#include <pybind11/stl.h>
|
||||||
@ -163,6 +162,19 @@ py::object tree_unflatten(
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
py::object tree_unflatten_none(
|
||||||
|
py::object tree,
|
||||||
|
const std::vector<array>& values,
|
||||||
|
int index = 0) {
|
||||||
|
return tree_map(tree, [&](py::handle obj) {
|
||||||
|
if (py::isinstance<py::none>(obj)) {
|
||||||
|
return py::cast(values[index++]);
|
||||||
|
} else {
|
||||||
|
return py::cast<py::object>(obj);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
auto validate_argnums_argnames(
|
auto validate_argnums_argnames(
|
||||||
const std::optional<IntOrVec>& argnums,
|
const std::optional<IntOrVec>& argnums,
|
||||||
const StrOrVec& argnames) {
|
const StrOrVec& argnames) {
|
||||||
@ -438,30 +450,36 @@ auto py_vmap(
|
|||||||
}
|
}
|
||||||
|
|
||||||
auto py_compile(const py::function& fun) {
|
auto py_compile(const py::function& fun) {
|
||||||
|
// This map is used to Cache the tree structure of the outputs
|
||||||
|
static std::unordered_map<size_t, py::object> tree_cache;
|
||||||
|
|
||||||
return [fun](const py::args& args) {
|
return [fun](const py::args& args) {
|
||||||
// Inputs must be array or tree of arrays
|
// Inputs must be array or tree of arrays
|
||||||
auto inputs = tree_flatten(args, true);
|
auto inputs = tree_flatten(args, true);
|
||||||
|
|
||||||
// py_value_out will hold the output of the python function in order to be
|
|
||||||
// able to reconstruct the python tree of extra return values
|
|
||||||
py::object py_outputs;
|
|
||||||
|
|
||||||
auto compile_fun =
|
|
||||||
[&fun, &args, &inputs, &py_outputs](const std::vector<array>& a) {
|
|
||||||
// Call the python function
|
|
||||||
py_outputs = fun(*tree_unflatten(args, a));
|
|
||||||
|
|
||||||
// Flatten the outputs
|
|
||||||
return tree_flatten(py_outputs, true);
|
|
||||||
};
|
|
||||||
|
|
||||||
// Compile and call
|
|
||||||
// TODO, awni, I think this cast is ok??
|
// TODO, awni, I think this cast is ok??
|
||||||
size_t fun_id = reinterpret_cast<size_t>(fun.ptr());
|
size_t fun_id = reinterpret_cast<size_t>(fun.ptr());
|
||||||
|
|
||||||
|
auto compile_fun = [fun_id, &fun, &args, &inputs](
|
||||||
|
const std::vector<array>& a) {
|
||||||
|
// Call the python function
|
||||||
|
py::object py_outputs = fun(*tree_unflatten(args, a));
|
||||||
|
|
||||||
|
// Flatten the outputs
|
||||||
|
auto outputs = tree_flatten(py_outputs, true);
|
||||||
|
|
||||||
|
py_outputs =
|
||||||
|
tree_map(py_outputs, [](const py::handle& x) { return py::none(); });
|
||||||
|
tree_cache.insert({fun_id, py_outputs});
|
||||||
|
return outputs;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Compile and call
|
||||||
auto outputs = detail::compile(compile_fun, fun_id)(inputs);
|
auto outputs = detail::compile(compile_fun, fun_id)(inputs);
|
||||||
|
|
||||||
// Put the outputs back in the container
|
// Put the outputs back in the container
|
||||||
return tree_unflatten(py_outputs, outputs);
|
py::object py_outputs = tree_cache.at(fun_id);
|
||||||
|
return tree_unflatten_none(py_outputs, outputs);
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -15,7 +15,12 @@ class TestCompile(mlx_tests.MLXTestCase):
|
|||||||
compiled_fn = mx.compile(fun)
|
compiled_fn = mx.compile(fun)
|
||||||
x = mx.array(1.0)
|
x = mx.array(1.0)
|
||||||
y = mx.array(1.0)
|
y = mx.array(1.0)
|
||||||
# out = compiled_fn(x, y)
|
out = compiled_fn(x, y)
|
||||||
|
self.assertEqual(out.item(), 2.0)
|
||||||
|
|
||||||
|
# Try again
|
||||||
|
out = compiled_fn(x, y)
|
||||||
|
self.assertEqual(out.item(), 2.0)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -1,6 +1,8 @@
|
|||||||
// Copyright © 2023 Apple Inc.
|
// Copyright © 2023 Apple Inc.
|
||||||
|
|
||||||
|
#include <iostream> // TODO
|
||||||
#include "doctest/doctest.h"
|
#include "doctest/doctest.h"
|
||||||
|
#include "mlx/utils.h" // TODO
|
||||||
|
|
||||||
#include "mlx/mlx.h"
|
#include "mlx/mlx.h"
|
||||||
|
|
||||||
@ -33,17 +35,50 @@ TEST_CASE("test simple compile") {
|
|||||||
CHECK(array_equal(out, array({3.0f, 4.0f})).item<bool>());
|
CHECK(array_equal(out, array({3.0f, 4.0f})).item<bool>());
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<array> fun1(const std::vector<array>& inputs) {
|
std::vector<array> grad_fun(const std::vector<array>& inputs) {
|
||||||
auto loss = [](std::vector<array> ins) { return exp(ins[0] + ins[1]); };
|
auto loss = [](std::vector<array> ins) { return exp(ins[0] + ins[1]); };
|
||||||
return grad(loss)(inputs);
|
return grad(loss, {0, 1})(inputs);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_CASE("test compile with grad") {
|
TEST_CASE("test compile with grad") {
|
||||||
auto x = array(1.0f);
|
auto x = array(1.0f);
|
||||||
auto y = array(1.0f);
|
auto y = array(1.0f);
|
||||||
auto grads_expected = fun1({x, y});
|
auto grads_expected = grad_fun({x, y});
|
||||||
auto grads_compile = compile(fun1)({x, y});
|
auto grads_compile = compile(grad_fun)({x, y});
|
||||||
CHECK_EQ(grads_compile[0].item<float>(), grads_expected[0].item<float>());
|
CHECK_EQ(grads_compile[0].item<float>(), grads_expected[0].item<float>());
|
||||||
|
CHECK_EQ(grads_compile[1].item<float>(), grads_expected[1].item<float>());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_CASE("test compile inputs with primitive") {
|
||||||
|
auto [k1, k2] = random::split(random::key(0));
|
||||||
|
auto x = random::uniform({5, 5}, k1);
|
||||||
|
auto y = random::uniform({5, 5}, k2);
|
||||||
|
auto expected = simple_fun({x, y})[0];
|
||||||
|
|
||||||
|
x = random::uniform({5, 5}, k1);
|
||||||
|
y = random::uniform({5, 5}, k2);
|
||||||
|
auto out = compile(simple_fun)({x, y})[0];
|
||||||
|
CHECK(array_equal(expected, out).item<bool>());
|
||||||
|
|
||||||
|
// Same thing twice
|
||||||
|
out = compile(simple_fun)({x, y})[0];
|
||||||
|
CHECK(array_equal(expected, out).item<bool>());
|
||||||
|
}
|
||||||
|
|
||||||
|
/*std::vector<array> bigger_fun(const std::vector<array>& inputs) {
|
||||||
|
auto x = inputs[1];
|
||||||
|
for (int i = 1; i < inputs.size(); ++i) {
|
||||||
|
w = inputs[i]
|
||||||
|
x = maximum(matmul(x, w), 0);
|
||||||
|
}
|
||||||
|
return take(x, array(3)) - logsumexp(x);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE("test bigger graph") {
|
||||||
|
std::vector<array> inputs;
|
||||||
|
inputs.push_back(
|
||||||
|
for (int
|
||||||
|
for
|
||||||
|
}*/
|
||||||
|
|
||||||
TEST_CASE("test nested compile") {}
|
TEST_CASE("test nested compile") {}
|
||||||
|
Loading…
Reference in New Issue
Block a user