mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-21 16:51:15 +08:00
more cpp tests
This commit is contained in:
parent
d34d10ebac
commit
ecfb72157e
@ -1,8 +1,6 @@
|
|||||||
// Copyright © 2023 Apple Inc.
|
// Copyright © 2023 Apple Inc.
|
||||||
|
|
||||||
#include <iostream> // TODO
|
|
||||||
#include "doctest/doctest.h"
|
#include "doctest/doctest.h"
|
||||||
#include "mlx/utils.h" // TODO
|
|
||||||
|
|
||||||
#include "mlx/mlx.h"
|
#include "mlx/mlx.h"
|
||||||
|
|
||||||
@ -65,20 +63,36 @@ TEST_CASE("test compile inputs with primitive") {
|
|||||||
CHECK(array_equal(expected, out).item<bool>());
|
CHECK(array_equal(expected, out).item<bool>());
|
||||||
}
|
}
|
||||||
|
|
||||||
/*std::vector<array> bigger_fun(const std::vector<array>& inputs) {
|
std::vector<array> fun_creats_array(const std::vector<array>& inputs) {
|
||||||
auto x = inputs[1];
|
return {inputs[0] + array(1.0f)};
|
||||||
for (int i = 1; i < inputs.size(); ++i) {
|
|
||||||
w = inputs[i]
|
|
||||||
x = maximum(matmul(x, w), 0);
|
|
||||||
}
|
|
||||||
return take(x, array(3)) - logsumexp(x);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_CASE("test bigger graph") {
|
TEST_CASE("test compile with created array") {
|
||||||
std::vector<array> inputs;
|
auto cfun = compile(fun_creats_array);
|
||||||
inputs.push_back(
|
auto out = cfun({array(2.0f)});
|
||||||
for (int
|
CHECK_EQ(out[0].item<float>(), 3.0f);
|
||||||
for
|
|
||||||
}*/
|
|
||||||
|
|
||||||
TEST_CASE("test nested compile") {}
|
// Try again
|
||||||
|
out = cfun({array(2.0f)});
|
||||||
|
CHECK_EQ(out[0].item<float>(), 3.0f);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<array> inner_fun(const std::vector<array>& inputs) {
|
||||||
|
return {array(2) * inputs[0]};
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<array> outer_fun(const std::vector<array>& inputs) {
|
||||||
|
auto x = inputs[0] + inputs[1];
|
||||||
|
auto y = compile(inner_fun)({x})[0];
|
||||||
|
return {x + y};
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE("test nested compile") {
|
||||||
|
auto cfun = compile(outer_fun);
|
||||||
|
auto out = cfun({array(1), array(2)})[0];
|
||||||
|
CHECK_EQ(out.item<int>(), 9);
|
||||||
|
|
||||||
|
// Try again
|
||||||
|
out = cfun({array(1), array(2)})[0];
|
||||||
|
CHECK_EQ(out.item<int>(), 9);
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user