mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 10:38:10 +08:00 
			
		
		
		
	basic python tests
This commit is contained in:
		@@ -1,5 +1,4 @@
 | 
				
			|||||||
// Copyright © 2023 Apple Inc.
 | 
					// Copyright © 2023 Apple Inc.
 | 
				
			||||||
 | 
					 | 
				
			||||||
#include <pybind11/functional.h>
 | 
					#include <pybind11/functional.h>
 | 
				
			||||||
#include <pybind11/pybind11.h>
 | 
					#include <pybind11/pybind11.h>
 | 
				
			||||||
#include <pybind11/stl.h>
 | 
					#include <pybind11/stl.h>
 | 
				
			||||||
@@ -163,6 +162,19 @@ py::object tree_unflatten(
 | 
				
			|||||||
  });
 | 
					  });
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					py::object tree_unflatten_none(
 | 
				
			||||||
 | 
					    py::object tree,
 | 
				
			||||||
 | 
					    const std::vector<array>& values,
 | 
				
			||||||
 | 
					    int index = 0) {
 | 
				
			||||||
 | 
					  return tree_map(tree, [&](py::handle obj) {
 | 
				
			||||||
 | 
					    if (py::isinstance<py::none>(obj)) {
 | 
				
			||||||
 | 
					      return py::cast(values[index++]);
 | 
				
			||||||
 | 
					    } else {
 | 
				
			||||||
 | 
					      return py::cast<py::object>(obj);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					  });
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
auto validate_argnums_argnames(
 | 
					auto validate_argnums_argnames(
 | 
				
			||||||
    const std::optional<IntOrVec>& argnums,
 | 
					    const std::optional<IntOrVec>& argnums,
 | 
				
			||||||
    const StrOrVec& argnames) {
 | 
					    const StrOrVec& argnames) {
 | 
				
			||||||
@@ -438,30 +450,36 @@ auto py_vmap(
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
auto py_compile(const py::function& fun) {
 | 
					auto py_compile(const py::function& fun) {
 | 
				
			||||||
 | 
					  // This map is used to Cache the tree structure of the outputs
 | 
				
			||||||
 | 
					  static std::unordered_map<size_t, py::object> tree_cache;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  return [fun](const py::args& args) {
 | 
					  return [fun](const py::args& args) {
 | 
				
			||||||
    // Inputs must be array or tree of arrays
 | 
					    // Inputs must be array or tree of arrays
 | 
				
			||||||
    auto inputs = tree_flatten(args, true);
 | 
					    auto inputs = tree_flatten(args, true);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    // py_value_out will hold the output of the python function in order to be
 | 
					    // TODO, awni, I think this cast is ok??
 | 
				
			||||||
    // able to reconstruct the python tree of extra return values
 | 
					    size_t fun_id = reinterpret_cast<size_t>(fun.ptr());
 | 
				
			||||||
    py::object py_outputs;
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    auto compile_fun =
 | 
					    auto compile_fun = [fun_id, &fun, &args, &inputs](
 | 
				
			||||||
        [&fun, &args, &inputs, &py_outputs](const std::vector<array>& a) {
 | 
					                           const std::vector<array>& a) {
 | 
				
			||||||
      // Call the python function
 | 
					      // Call the python function
 | 
				
			||||||
          py_outputs = fun(*tree_unflatten(args, a));
 | 
					      py::object py_outputs = fun(*tree_unflatten(args, a));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
      // Flatten the outputs
 | 
					      // Flatten the outputs
 | 
				
			||||||
          return tree_flatten(py_outputs, true);
 | 
					      auto outputs = tree_flatten(py_outputs, true);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      py_outputs =
 | 
				
			||||||
 | 
					          tree_map(py_outputs, [](const py::handle& x) { return py::none(); });
 | 
				
			||||||
 | 
					      tree_cache.insert({fun_id, py_outputs});
 | 
				
			||||||
 | 
					      return outputs;
 | 
				
			||||||
    };
 | 
					    };
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    // Compile and call
 | 
					    // Compile and call
 | 
				
			||||||
    // TODO, awni, I think this cast is ok??
 | 
					 | 
				
			||||||
    size_t fun_id = reinterpret_cast<size_t>(fun.ptr());
 | 
					 | 
				
			||||||
    auto outputs = detail::compile(compile_fun, fun_id)(inputs);
 | 
					    auto outputs = detail::compile(compile_fun, fun_id)(inputs);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    // Put the outputs back in the container
 | 
					    // Put the outputs back in the container
 | 
				
			||||||
    return tree_unflatten(py_outputs, outputs);
 | 
					    py::object py_outputs = tree_cache.at(fun_id);
 | 
				
			||||||
 | 
					    return tree_unflatten_none(py_outputs, outputs);
 | 
				
			||||||
  };
 | 
					  };
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -15,7 +15,12 @@ class TestCompile(mlx_tests.MLXTestCase):
 | 
				
			|||||||
        compiled_fn = mx.compile(fun)
 | 
					        compiled_fn = mx.compile(fun)
 | 
				
			||||||
        x = mx.array(1.0)
 | 
					        x = mx.array(1.0)
 | 
				
			||||||
        y = mx.array(1.0)
 | 
					        y = mx.array(1.0)
 | 
				
			||||||
        # out = compiled_fn(x, y)
 | 
					        out = compiled_fn(x, y)
 | 
				
			||||||
 | 
					        self.assertEqual(out.item(), 2.0)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Try again
 | 
				
			||||||
 | 
					        out = compiled_fn(x, y)
 | 
				
			||||||
 | 
					        self.assertEqual(out.item(), 2.0)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
if __name__ == "__main__":
 | 
					if __name__ == "__main__":
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,6 +1,8 @@
 | 
				
			|||||||
// Copyright © 2023 Apple Inc.
 | 
					// Copyright © 2023 Apple Inc.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include <iostream> // TODO
 | 
				
			||||||
#include "doctest/doctest.h"
 | 
					#include "doctest/doctest.h"
 | 
				
			||||||
 | 
					#include "mlx/utils.h" // TODO
 | 
				
			||||||
 | 
					
 | 
				
			||||||
#include "mlx/mlx.h"
 | 
					#include "mlx/mlx.h"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -33,17 +35,50 @@ TEST_CASE("test simple compile") {
 | 
				
			|||||||
  CHECK(array_equal(out, array({3.0f, 4.0f})).item<bool>());
 | 
					  CHECK(array_equal(out, array({3.0f, 4.0f})).item<bool>());
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
std::vector<array> fun1(const std::vector<array>& inputs) {
 | 
					std::vector<array> grad_fun(const std::vector<array>& inputs) {
 | 
				
			||||||
  auto loss = [](std::vector<array> ins) { return exp(ins[0] + ins[1]); };
 | 
					  auto loss = [](std::vector<array> ins) { return exp(ins[0] + ins[1]); };
 | 
				
			||||||
  return grad(loss)(inputs);
 | 
					  return grad(loss, {0, 1})(inputs);
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
TEST_CASE("test compile with grad") {
 | 
					TEST_CASE("test compile with grad") {
 | 
				
			||||||
  auto x = array(1.0f);
 | 
					  auto x = array(1.0f);
 | 
				
			||||||
  auto y = array(1.0f);
 | 
					  auto y = array(1.0f);
 | 
				
			||||||
  auto grads_expected = fun1({x, y});
 | 
					  auto grads_expected = grad_fun({x, y});
 | 
				
			||||||
  auto grads_compile = compile(fun1)({x, y});
 | 
					  auto grads_compile = compile(grad_fun)({x, y});
 | 
				
			||||||
  CHECK_EQ(grads_compile[0].item<float>(), grads_expected[0].item<float>());
 | 
					  CHECK_EQ(grads_compile[0].item<float>(), grads_expected[0].item<float>());
 | 
				
			||||||
 | 
					  CHECK_EQ(grads_compile[1].item<float>(), grads_expected[1].item<float>());
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					TEST_CASE("test compile inputs with primitive") {
 | 
				
			||||||
 | 
					  auto [k1, k2] = random::split(random::key(0));
 | 
				
			||||||
 | 
					  auto x = random::uniform({5, 5}, k1);
 | 
				
			||||||
 | 
					  auto y = random::uniform({5, 5}, k2);
 | 
				
			||||||
 | 
					  auto expected = simple_fun({x, y})[0];
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  x = random::uniform({5, 5}, k1);
 | 
				
			||||||
 | 
					  y = random::uniform({5, 5}, k2);
 | 
				
			||||||
 | 
					  auto out = compile(simple_fun)({x, y})[0];
 | 
				
			||||||
 | 
					  CHECK(array_equal(expected, out).item<bool>());
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // Same thing twice
 | 
				
			||||||
 | 
					  out = compile(simple_fun)({x, y})[0];
 | 
				
			||||||
 | 
					  CHECK(array_equal(expected, out).item<bool>());
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/*std::vector<array> bigger_fun(const std::vector<array>& inputs) {
 | 
				
			||||||
 | 
					  auto x = inputs[1];
 | 
				
			||||||
 | 
					  for (int i = 1; i < inputs.size(); ++i) {
 | 
				
			||||||
 | 
					    w = inputs[i]
 | 
				
			||||||
 | 
					    x = maximum(matmul(x, w), 0);
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  return take(x, array(3)) - logsumexp(x);
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					TEST_CASE("test bigger graph") {
 | 
				
			||||||
 | 
					  std::vector<array> inputs;
 | 
				
			||||||
 | 
					  inputs.push_back(
 | 
				
			||||||
 | 
					  for (int
 | 
				
			||||||
 | 
					  for
 | 
				
			||||||
 | 
					}*/
 | 
				
			||||||
 | 
					
 | 
				
			||||||
TEST_CASE("test nested compile") {}
 | 
					TEST_CASE("test nested compile") {}
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user