mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
33 lines
685 B
C++
33 lines
685 B
C++
// Copyright © 2023 Apple Inc.
|
|
|
|
#include "doctest/doctest.h"
|
|
|
|
#include "mlx/mlx.h"
|
|
|
|
using namespace mlx::core;
|
|
|
|
TEST_CASE("test simplify scalars") {
|
|
auto a = array({-1.0f, 2.0f});
|
|
auto b = maximum(a, array(0.0f));
|
|
auto c = maximum(-a, array(0.0f));
|
|
auto d = b + c;
|
|
simplify({d});
|
|
CHECK(b.inputs()[1].id() == c.inputs()[1].id());
|
|
}
|
|
|
|
TEST_CASE("test simplify") {
|
|
auto a = array({1.0f, 2.0f});
|
|
auto b = exp(a) + exp(a);
|
|
simplify(b);
|
|
eval(b);
|
|
CHECK(b.inputs()[0].id() == b.inputs()[1].id());
|
|
}
|
|
|
|
TEST_CASE("test no simplify") {
|
|
auto a = array({1.0f, 2.0f});
|
|
auto b = cos(a) + sin(a);
|
|
simplify(b);
|
|
eval(b);
|
|
CHECK(b.inputs()[0].id() != b.inputs()[1].id());
|
|
}
|