mirror of
https://github.com/ml-explore/mlx.git
synced 2025-11-03 18:18:15 +08:00
basic python tests
This commit is contained in:
@@ -1,6 +1,8 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#include <iostream> // TODO
|
||||
#include "doctest/doctest.h"
|
||||
#include "mlx/utils.h" // TODO
|
||||
|
||||
#include "mlx/mlx.h"
|
||||
|
||||
@@ -33,17 +35,50 @@ TEST_CASE("test simple compile") {
|
||||
CHECK(array_equal(out, array({3.0f, 4.0f})).item<bool>());
|
||||
}
|
||||
|
||||
std::vector<array> fun1(const std::vector<array>& inputs) {
|
||||
std::vector<array> grad_fun(const std::vector<array>& inputs) {
|
||||
auto loss = [](std::vector<array> ins) { return exp(ins[0] + ins[1]); };
|
||||
return grad(loss)(inputs);
|
||||
return grad(loss, {0, 1})(inputs);
|
||||
}
|
||||
|
||||
TEST_CASE("test compile with grad") {
|
||||
auto x = array(1.0f);
|
||||
auto y = array(1.0f);
|
||||
auto grads_expected = fun1({x, y});
|
||||
auto grads_compile = compile(fun1)({x, y});
|
||||
auto grads_expected = grad_fun({x, y});
|
||||
auto grads_compile = compile(grad_fun)({x, y});
|
||||
CHECK_EQ(grads_compile[0].item<float>(), grads_expected[0].item<float>());
|
||||
CHECK_EQ(grads_compile[1].item<float>(), grads_expected[1].item<float>());
|
||||
}
|
||||
|
||||
TEST_CASE("test compile inputs with primitive") {
|
||||
auto [k1, k2] = random::split(random::key(0));
|
||||
auto x = random::uniform({5, 5}, k1);
|
||||
auto y = random::uniform({5, 5}, k2);
|
||||
auto expected = simple_fun({x, y})[0];
|
||||
|
||||
x = random::uniform({5, 5}, k1);
|
||||
y = random::uniform({5, 5}, k2);
|
||||
auto out = compile(simple_fun)({x, y})[0];
|
||||
CHECK(array_equal(expected, out).item<bool>());
|
||||
|
||||
// Same thing twice
|
||||
out = compile(simple_fun)({x, y})[0];
|
||||
CHECK(array_equal(expected, out).item<bool>());
|
||||
}
|
||||
|
||||
/*std::vector<array> bigger_fun(const std::vector<array>& inputs) {
|
||||
auto x = inputs[1];
|
||||
for (int i = 1; i < inputs.size(); ++i) {
|
||||
w = inputs[i]
|
||||
x = maximum(matmul(x, w), 0);
|
||||
}
|
||||
return take(x, array(3)) - logsumexp(x);
|
||||
}
|
||||
|
||||
TEST_CASE("test bigger graph") {
|
||||
std::vector<array> inputs;
|
||||
inputs.push_back(
|
||||
for (int
|
||||
for
|
||||
}*/
|
||||
|
||||
TEST_CASE("test nested compile") {}
|
||||
|
||||
Reference in New Issue
Block a user