mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-19 15:41:13 +08:00
more cpp tests
This commit is contained in:
parent
d34d10ebac
commit
ecfb72157e
@ -1,8 +1,6 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#include <iostream> // TODO
|
||||
#include "doctest/doctest.h"
|
||||
#include "mlx/utils.h" // TODO
|
||||
|
||||
#include "mlx/mlx.h"
|
||||
|
||||
@ -65,20 +63,36 @@ TEST_CASE("test compile inputs with primitive") {
|
||||
CHECK(array_equal(expected, out).item<bool>());
|
||||
}
|
||||
|
||||
/*std::vector<array> bigger_fun(const std::vector<array>& inputs) {
|
||||
auto x = inputs[1];
|
||||
for (int i = 1; i < inputs.size(); ++i) {
|
||||
w = inputs[i]
|
||||
x = maximum(matmul(x, w), 0);
|
||||
}
|
||||
return take(x, array(3)) - logsumexp(x);
|
||||
std::vector<array> fun_creats_array(const std::vector<array>& inputs) {
|
||||
return {inputs[0] + array(1.0f)};
|
||||
}
|
||||
|
||||
TEST_CASE("test bigger graph") {
|
||||
std::vector<array> inputs;
|
||||
inputs.push_back(
|
||||
for (int
|
||||
for
|
||||
}*/
|
||||
TEST_CASE("test compile with created array") {
|
||||
auto cfun = compile(fun_creats_array);
|
||||
auto out = cfun({array(2.0f)});
|
||||
CHECK_EQ(out[0].item<float>(), 3.0f);
|
||||
|
||||
TEST_CASE("test nested compile") {}
|
||||
// Try again
|
||||
out = cfun({array(2.0f)});
|
||||
CHECK_EQ(out[0].item<float>(), 3.0f);
|
||||
}
|
||||
|
||||
std::vector<array> inner_fun(const std::vector<array>& inputs) {
|
||||
return {array(2) * inputs[0]};
|
||||
}
|
||||
|
||||
std::vector<array> outer_fun(const std::vector<array>& inputs) {
|
||||
auto x = inputs[0] + inputs[1];
|
||||
auto y = compile(inner_fun)({x})[0];
|
||||
return {x + y};
|
||||
}
|
||||
|
||||
TEST_CASE("test nested compile") {
|
||||
auto cfun = compile(outer_fun);
|
||||
auto out = cfun({array(1), array(2)})[0];
|
||||
CHECK_EQ(out.item<int>(), 9);
|
||||
|
||||
// Try again
|
||||
out = cfun({array(1), array(2)})[0];
|
||||
CHECK_EQ(out.item<int>(), 9);
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user