more cpp tests

This commit is contained in:
Awni Hannun 2024-01-15 22:03:41 -08:00
parent d34d10ebac
commit ecfb72157e

View File

@ -1,8 +1,6 @@
// Copyright © 2023 Apple Inc.
#include <iostream> // TODO
#include "doctest/doctest.h"
#include "mlx/utils.h" // TODO
#include "mlx/mlx.h"
@ -65,20 +63,36 @@ TEST_CASE("test compile inputs with primitive") {
CHECK(array_equal(expected, out).item<bool>());
}
/*std::vector<array> bigger_fun(const std::vector<array>& inputs) {
auto x = inputs[1];
for (int i = 1; i < inputs.size(); ++i) {
w = inputs[i]
x = maximum(matmul(x, w), 0);
}
return take(x, array(3)) - logsumexp(x);
std::vector<array> fun_creats_array(const std::vector<array>& inputs) {
return {inputs[0] + array(1.0f)};
}
TEST_CASE("test bigger graph") {
std::vector<array> inputs;
inputs.push_back(
for (int
for
}*/
TEST_CASE("test compile with created array") {
auto cfun = compile(fun_creats_array);
auto out = cfun({array(2.0f)});
CHECK_EQ(out[0].item<float>(), 3.0f);
TEST_CASE("test nested compile") {}
// Try again
out = cfun({array(2.0f)});
CHECK_EQ(out[0].item<float>(), 3.0f);
}
std::vector<array> inner_fun(const std::vector<array>& inputs) {
return {array(2) * inputs[0]};
}
std::vector<array> outer_fun(const std::vector<array>& inputs) {
auto x = inputs[0] + inputs[1];
auto y = compile(inner_fun)({x})[0];
return {x + y};
}
TEST_CASE("test nested compile") {
auto cfun = compile(outer_fun);
auto out = cfun({array(1), array(2)})[0];
CHECK_EQ(out.item<int>(), 9);
// Try again
out = cfun({array(1), array(2)})[0];
CHECK_EQ(out.item<int>(), 9);
}