mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-29 14:58:11 +08:00
angelos's commit files
This commit is contained in:
224
tests/creations_tests.cpp
Normal file
224
tests/creations_tests.cpp
Normal file
@@ -0,0 +1,224 @@
|
||||
#include "doctest/doctest.h"
|
||||
|
||||
#include "mlx/mlx.h"
|
||||
|
||||
using namespace mlx::core;
|
||||
|
||||
TEST_CASE("test arange") {
|
||||
// Check type is inferred correclty
|
||||
{
|
||||
auto x = arange(10);
|
||||
CHECK_EQ(x.dtype(), int32);
|
||||
|
||||
x = arange(10.0);
|
||||
CHECK_EQ(x.dtype(), float32);
|
||||
|
||||
x = arange(10, float32);
|
||||
CHECK_EQ(x.dtype(), float32);
|
||||
|
||||
x = arange(10.0, int32);
|
||||
CHECK_EQ(x.dtype(), int32);
|
||||
|
||||
x = arange(0, 10);
|
||||
CHECK_EQ(x.dtype(), int32);
|
||||
|
||||
x = arange(0.0, 10.0, int32);
|
||||
CHECK_EQ(x.dtype(), int32);
|
||||
|
||||
x = arange(0.0, 10.0);
|
||||
CHECK_EQ(x.dtype(), float32);
|
||||
|
||||
x = arange(0, 10, float32);
|
||||
CHECK_EQ(x.dtype(), float32);
|
||||
|
||||
x = arange(0, 10, 0.1, float32);
|
||||
CHECK_EQ(x.dtype(), float32);
|
||||
|
||||
x = arange(0.0, 10.0, 0.5, int32);
|
||||
CHECK_EQ(x.dtype(), int32);
|
||||
|
||||
x = arange(10.0, uint32);
|
||||
CHECK_EQ(x.dtype(), uint32);
|
||||
x = arange(0.0, 10.0, uint32);
|
||||
CHECK_EQ(x.dtype(), uint32);
|
||||
x = arange(0.0, 10.0, 0.5, uint32);
|
||||
CHECK_EQ(x.dtype(), uint32);
|
||||
|
||||
// arange unsupported for bool_
|
||||
CHECK_THROWS_AS(arange(10, bool_), std::invalid_argument);
|
||||
}
|
||||
|
||||
// Check correct sizes
|
||||
{
|
||||
auto x = arange(10);
|
||||
CHECK_EQ(x.size(), 10);
|
||||
|
||||
x = arange(0.0, 10.0, 0.5);
|
||||
CHECK_EQ(x.size(), 20);
|
||||
|
||||
x = arange(0.0, 10.0, 0.45);
|
||||
CHECK_EQ(x.size(), 23);
|
||||
|
||||
x = arange(0, 10, 10);
|
||||
CHECK_EQ(x.size(), 1);
|
||||
|
||||
x = arange(0, 10, 9);
|
||||
CHECK_EQ(x.size(), 2);
|
||||
|
||||
x = arange(0, 10, 100);
|
||||
CHECK_EQ(x.size(), 1);
|
||||
|
||||
x = arange(0, -10, 1);
|
||||
CHECK_EQ(x.size(), 0);
|
||||
|
||||
x = arange(0, -10, -1);
|
||||
CHECK_EQ(x.size(), 10);
|
||||
|
||||
x = arange(0, -10, -10);
|
||||
CHECK_EQ(x.size(), 1);
|
||||
}
|
||||
|
||||
// Check values
|
||||
{
|
||||
auto x = arange(0, 3);
|
||||
CHECK(array_equal(x, array({0, 1, 2})).item<bool>());
|
||||
|
||||
x = arange(0, 3, 2);
|
||||
CHECK(array_equal(x, array({0, 2})).item<bool>());
|
||||
|
||||
x = arange(0, 3, 3);
|
||||
CHECK(array_equal(x, array({0})).item<bool>());
|
||||
|
||||
x = arange(0, -3, 1);
|
||||
CHECK(array_equal(x, array({})).item<bool>());
|
||||
|
||||
x = arange(0, 3, -1);
|
||||
CHECK(array_equal(x, array({})).item<bool>());
|
||||
|
||||
x = arange(0, -3, -1);
|
||||
CHECK(array_equal(x, array({0, -1, -2})).item<bool>());
|
||||
|
||||
x = arange(0.0, 5.0, 0.5, int32);
|
||||
CHECK(array_equal(x, zeros({10})).item<bool>());
|
||||
|
||||
x = arange(0.0, 5.0, 1.5, int32);
|
||||
CHECK(array_equal(x, array({0, 1, 2, 3})).item<bool>());
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("test astype") {
|
||||
// Check type conversions
|
||||
{
|
||||
auto x = array(1);
|
||||
auto y = astype(x, float32);
|
||||
CHECK_EQ(y.dtype(), float32);
|
||||
CHECK_EQ(y.item<float>(), 1.0f);
|
||||
|
||||
y = astype(x, int32);
|
||||
CHECK_EQ(y.dtype(), int32);
|
||||
CHECK_EQ(y.item<int>(), 1);
|
||||
|
||||
x = array(-3.0f);
|
||||
y = astype(x, int32);
|
||||
CHECK_EQ(y.dtype(), int32);
|
||||
CHECK_EQ(y.item<int>(), -3);
|
||||
|
||||
y = astype(x, uint32);
|
||||
CHECK_EQ(y.dtype(), uint32);
|
||||
|
||||
// Use std::copy since the result is platform dependent
|
||||
uint32_t v;
|
||||
std::copy(x.data<float>(), x.data<float>() + 1, &v);
|
||||
CHECK_EQ(y.item<uint32_t>(), v);
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("test full") {
|
||||
// Check full works for different types
|
||||
{
|
||||
auto x = full({}, 0);
|
||||
CHECK_EQ(x.dtype(), int32);
|
||||
CHECK_EQ(x.item<int>(), 0);
|
||||
|
||||
x = full({}, 0.0);
|
||||
CHECK_EQ(x.dtype(), float32);
|
||||
CHECK_EQ(x.item<float>(), 0);
|
||||
|
||||
x = full({}, false);
|
||||
CHECK_EQ(x.item<bool>(), false);
|
||||
|
||||
x = full({}, 0, int32);
|
||||
CHECK_EQ(x.item<int>(), 0);
|
||||
|
||||
x = full({}, 0, float32);
|
||||
CHECK_EQ(x.item<float>(), 0);
|
||||
|
||||
x = full({1, 2}, 2, float32);
|
||||
CHECK(array_equal(x, array({2.0, 2.0}, {1, 2})).item<bool>());
|
||||
|
||||
x = full({2, 1}, 2, float32);
|
||||
CHECK(array_equal(x, array({2.0, 2.0}, {2, 1})).item<bool>());
|
||||
|
||||
x = full({2}, false);
|
||||
CHECK_EQ(x.dtype(), bool_);
|
||||
CHECK(array_equal(x, array({false, false})).item<bool>());
|
||||
|
||||
x = full({2}, 1.0, bool_);
|
||||
CHECK_EQ(x.dtype(), bool_);
|
||||
CHECK(array_equal(x, array({true, true})).item<bool>());
|
||||
|
||||
x = full({2}, 1.0, uint32);
|
||||
CHECK_EQ(x.dtype(), uint32);
|
||||
CHECK(array_equal(x, array({1, 1})).item<bool>());
|
||||
|
||||
CHECK_THROWS_AS(full({2}, array({})), std::invalid_argument);
|
||||
}
|
||||
|
||||
// Check broadcasting works
|
||||
{
|
||||
auto x = full({2, 2}, array({3, 4}, {2, 1}));
|
||||
CHECK(array_equal(x, array({3, 3, 4, 4}, {2, 2})).item<bool>());
|
||||
x = full({2, 2}, array({3, 4}, {1, 2}));
|
||||
CHECK(array_equal(x, array({3, 4, 3, 4}, {2, 2})).item<bool>());
|
||||
}
|
||||
|
||||
// Check zeros and ones
|
||||
{
|
||||
auto x = zeros({2, 2}, float32);
|
||||
CHECK_EQ(x.shape(), std::vector<int>{2, 2});
|
||||
CHECK_EQ(x.ndim(), 2);
|
||||
CHECK_EQ(x.dtype(), float32);
|
||||
auto y = array({0.0, 0.0, 0.0, 0.0}, {2, 2});
|
||||
CHECK(array_equal(x, y).item<bool>());
|
||||
|
||||
x = ones({2, 2}, float32);
|
||||
CHECK_EQ(x.shape(), std::vector<int>{2, 2});
|
||||
CHECK_EQ(x.ndim(), 2);
|
||||
CHECK_EQ(x.dtype(), float32);
|
||||
y = array({1.0, 1.0, 1.0, 1.0}, {2, 2});
|
||||
CHECK(array_equal(x, y).item<bool>());
|
||||
|
||||
x = zeros({2, 2}, int32);
|
||||
y = zeros_like(x);
|
||||
CHECK_EQ(y.dtype(), int32);
|
||||
CHECK(array_equal(x, y).item<bool>());
|
||||
|
||||
x = ones({2, 2}, int32);
|
||||
y = ones_like(x);
|
||||
CHECK_EQ(y.dtype(), int32);
|
||||
CHECK(array_equal(x, y).item<bool>());
|
||||
}
|
||||
|
||||
// Works for empty shape and empty array
|
||||
{
|
||||
array x = ones({}, int32);
|
||||
CHECK_EQ(x.shape(), std::vector<int>{});
|
||||
CHECK_EQ(x.item<int>(), 1);
|
||||
|
||||
x = full({0}, array({}));
|
||||
CHECK_EQ(x.shape(), std::vector<int>{0});
|
||||
CHECK_EQ(x.size(), 0);
|
||||
|
||||
CHECK_THROWS_AS(full({}, array({})), std::invalid_argument);
|
||||
}
|
||||
}
|
||||
331
tests/fft_tests.cpp
Normal file
331
tests/fft_tests.cpp
Normal file
@@ -0,0 +1,331 @@
|
||||
#include "doctest/doctest.h"
|
||||
|
||||
#include "mlx/mlx.h"
|
||||
|
||||
using namespace mlx::core;
|
||||
|
||||
TEST_CASE("test fft basics") {
|
||||
auto device = default_device();
|
||||
set_default_device(Device::cpu);
|
||||
array x(1.0);
|
||||
CHECK_THROWS(fft::fft(x));
|
||||
CHECK_THROWS(fft::ifft(x));
|
||||
|
||||
x = array({1.0});
|
||||
auto y = fft::fft(x);
|
||||
CHECK_EQ(y.dtype(), complex64);
|
||||
CHECK_EQ(y.size(), x.size());
|
||||
CHECK_EQ(y.item<complex64_t>(), complex64_t{1.0f, 0.0f});
|
||||
|
||||
y = fft::ifft(x);
|
||||
CHECK_EQ(y.dtype(), complex64);
|
||||
CHECK_EQ(y.size(), x.size());
|
||||
CHECK_EQ(y.item<complex64_t>(), complex64_t{1.0f, 0.0f});
|
||||
|
||||
x = array({complex64_t{1.0f, 1.0f}}, complex64);
|
||||
y = fft::fft(x);
|
||||
CHECK_EQ(y.size(), x.size());
|
||||
CHECK_EQ(y.item<complex64_t>(), complex64_t{1.0f, 1.0f});
|
||||
|
||||
y = fft::ifft(x);
|
||||
CHECK_EQ(y.dtype(), complex64);
|
||||
CHECK_EQ(y.size(), x.size());
|
||||
CHECK_EQ(y.item<complex64_t>(), complex64_t{1.0f, 1.0f});
|
||||
|
||||
{
|
||||
x = array({0.0f, 1.0f, 2.0f, 3.0f});
|
||||
y = fft::fft(x);
|
||||
std::initializer_list<complex64_t> expected = {
|
||||
{6.0, 0.0},
|
||||
{-2.0, 2.0},
|
||||
{-2.0, 0.0},
|
||||
{-2.0, -2.0},
|
||||
};
|
||||
CHECK_EQ(y.size(), x.size());
|
||||
CHECK(array_equal(y, array(expected)).item<bool>());
|
||||
|
||||
y = fft::ifft(x);
|
||||
std::initializer_list<complex64_t> expected_inv = {
|
||||
{1.5, 0.0},
|
||||
{-0.5, -0.5},
|
||||
{-0.5, 0.0},
|
||||
{-0.5, 0.5},
|
||||
};
|
||||
CHECK(array_equal(y, array(expected_inv)).item<bool>());
|
||||
}
|
||||
|
||||
{
|
||||
std::initializer_list<complex64_t> vals = {
|
||||
{1.0f, 1.0f}, {2.0f, 1.0f}, {1.0f, 2.0f}, {2.0f, 2.0f}};
|
||||
x = array(vals);
|
||||
y = fft::fft(x);
|
||||
std::initializer_list<complex64_t> expected = {
|
||||
{6.0, 6.0},
|
||||
{-1.0, -1.0},
|
||||
{-2.0, 0.0},
|
||||
{1.0, -1.0},
|
||||
};
|
||||
CHECK_EQ(y.size(), x.size());
|
||||
CHECK(array_equal(y, array(expected)).item<bool>());
|
||||
CHECK(array_equal(fft::ifft(y), x).item<bool>());
|
||||
}
|
||||
|
||||
// Specify axes
|
||||
{
|
||||
x = array({0.0f, 1.0f, 2.0f, 3.0f}, {2, 2});
|
||||
std::initializer_list<complex64_t> expected_0 = {
|
||||
{2.0, 0.0},
|
||||
{4.0, 0.0},
|
||||
{-2.0, 0.0},
|
||||
{-2.0, 0.0},
|
||||
};
|
||||
y = fft::fft(x, 0);
|
||||
CHECK(array_equal(y, array(expected_0, {2, 2})).item<bool>());
|
||||
CHECK(array_equal(fft::ifft(y, 0), x).item<bool>());
|
||||
std::initializer_list<complex64_t> expected_1 = {
|
||||
{1.0, 0.0},
|
||||
{-1.0, 0.0},
|
||||
{5.0, 0.0},
|
||||
{-1.0, 0.0},
|
||||
};
|
||||
y = fft::fft(x, 1);
|
||||
CHECK(array_equal(y, array(expected_1, {2, 2})).item<bool>());
|
||||
CHECK(array_equal(fft::ifft(y, 1), x).item<bool>());
|
||||
}
|
||||
set_default_device(device);
|
||||
}
|
||||
|
||||
TEST_CASE("test real ffts") {
|
||||
auto device = default_device();
|
||||
set_default_device(Device::cpu);
|
||||
|
||||
auto x = array({1.0});
|
||||
auto y = fft::rfft(x);
|
||||
CHECK_EQ(y.dtype(), complex64);
|
||||
CHECK_EQ(y.size(), x.size());
|
||||
CHECK_EQ(y.item<complex64_t>(), complex64_t{1.0f, 0.0f});
|
||||
|
||||
{
|
||||
x = array({0.0f, 1.0f, 2.0f, 3.0f});
|
||||
y = fft::rfft(x);
|
||||
std::initializer_list<complex64_t> expected = {
|
||||
{6.0, 0.0}, {-2.0, 2.0}, {-2.0, -0.0}};
|
||||
CHECK_EQ(y.size(), x.size() / 2 + 1);
|
||||
CHECK(array_equal(y, array(expected)).item<bool>());
|
||||
}
|
||||
|
||||
x = array(complex64_t{1, 1});
|
||||
CHECK_THROWS(fft::irfft(x));
|
||||
|
||||
x = array({complex64_t{0, 1}, complex64_t{1, 0}});
|
||||
y = fft::irfft(x);
|
||||
CHECK_EQ(y.size(), 2);
|
||||
CHECK_EQ(y.dtype(), float32);
|
||||
CHECK(array_equal(y, array({0.5f, -0.5f})).item<bool>());
|
||||
|
||||
set_default_device(device);
|
||||
}
|
||||
|
||||
TEST_CASE("test fftn") {
|
||||
auto device = default_device();
|
||||
set_default_device(Device::cpu);
|
||||
|
||||
auto x = zeros({5, 5, 5});
|
||||
CHECK_THROWS_AS(fft::fftn(x, {}, {0, 3}), std::invalid_argument);
|
||||
CHECK_THROWS_AS(fft::fftn(x, {}, {0, -4}), std::invalid_argument);
|
||||
CHECK_THROWS_AS(fft::fftn(x, {}, {0, 0}), std::invalid_argument);
|
||||
CHECK_THROWS_AS(fft::fftn(x, {5, 5, 5}, {0}), std::invalid_argument);
|
||||
CHECK_THROWS_AS(fft::fftn(x, {0}, {}, {}), std::invalid_argument);
|
||||
CHECK_THROWS_AS(fft::fftn(x, {1, -1}, {}, {}), std::invalid_argument);
|
||||
|
||||
// Test 2D FFT
|
||||
{
|
||||
x = array({0.0f, 1.0f, 2.0f, 3.0f}, {2, 2});
|
||||
std::initializer_list<complex64_t> expected = {
|
||||
{6.0, 0.0},
|
||||
{-2.0, 0.0},
|
||||
{-4.0, 0.0},
|
||||
{0.0, 0.0},
|
||||
};
|
||||
auto y = fft::fft2(x);
|
||||
CHECK(array_equal(y, array(expected, {2, 2})).item<bool>());
|
||||
CHECK(array_equal(fft::ifft2(y), x).item<bool>());
|
||||
}
|
||||
|
||||
// Test 3D FFT
|
||||
{
|
||||
x = reshape(arange(8, float32), {2, 2, 2});
|
||||
std::initializer_list<complex64_t> expected = {
|
||||
{28.0, 0.0},
|
||||
{-4.0, 0.0},
|
||||
{-8.0, 0.0},
|
||||
{0.0, 0.0},
|
||||
{-16.0, 0.0},
|
||||
{0.0, 0.0},
|
||||
{0.0, 0.0},
|
||||
{0.0, 0.0},
|
||||
};
|
||||
auto y = fft::fftn(x);
|
||||
CHECK(array_equal(y, array(expected, {2, 2, 2})).item<bool>());
|
||||
CHECK(array_equal(fft::ifftn(y), x).item<bool>());
|
||||
|
||||
x = reshape(arange(20, float32), {5, 4});
|
||||
y = fft::rfftn(x);
|
||||
CHECK_EQ(y.shape(), std::vector<int>{5, 3});
|
||||
y = fft::rfftn(x, {1, 0});
|
||||
CHECK_EQ(y.shape(), std::vector<int>{3, 4});
|
||||
|
||||
x = reshape(arange(20, float32), {5, 4});
|
||||
y = fft::irfftn(x);
|
||||
CHECK_EQ(y.shape(), std::vector<int>{5, 6});
|
||||
y = fft::irfftn(x, {1, 0});
|
||||
CHECK_EQ(y.shape(), std::vector<int>{8, 4});
|
||||
}
|
||||
|
||||
// Check the types of real ffts
|
||||
{
|
||||
x = zeros({5, 5}, float32);
|
||||
auto y = fft::rfft2(x);
|
||||
CHECK_EQ(y.shape(), std::vector<int>{5, 3});
|
||||
CHECK_EQ(y.dtype(), complex64);
|
||||
|
||||
y = fft::rfftn(x);
|
||||
CHECK_EQ(y.shape(), std::vector<int>{5, 3});
|
||||
CHECK_EQ(y.dtype(), complex64);
|
||||
|
||||
x = zeros({5, 5}, complex64);
|
||||
y = fft::irfft2(x);
|
||||
CHECK_EQ(y.shape(), std::vector<int>{5, 8});
|
||||
CHECK_EQ(y.dtype(), float32);
|
||||
|
||||
y = fft::irfftn(x);
|
||||
CHECK_EQ(y.shape(), std::vector<int>{5, 8});
|
||||
CHECK_EQ(y.dtype(), float32);
|
||||
}
|
||||
|
||||
set_default_device(device);
|
||||
}
|
||||
|
||||
TEST_CASE("test fft with provided shape") {
|
||||
auto x = ones({5, 5});
|
||||
|
||||
auto y = fft::fft(x, 7, 0);
|
||||
CHECK_EQ(y.shape(), std::vector<int>{7, 5});
|
||||
|
||||
y = fft::fft(x, 3, 0);
|
||||
CHECK_EQ(y.shape(), std::vector<int>{3, 5});
|
||||
|
||||
y = fft::fft(x, 7, 1);
|
||||
CHECK_EQ(y.shape(), std::vector<int>{5, 7});
|
||||
|
||||
y = fft::fft(x, 3, 1);
|
||||
CHECK_EQ(y.shape(), std::vector<int>{5, 3});
|
||||
|
||||
y = fft::rfft(x, 7, 0);
|
||||
CHECK_EQ(y.shape(), std::vector<int>{4, 5});
|
||||
|
||||
y = fft::rfft(x, 3, 0);
|
||||
CHECK_EQ(y.shape(), std::vector<int>{2, 5});
|
||||
|
||||
y = fft::rfft(x, 3, 1);
|
||||
CHECK_EQ(y.shape(), std::vector<int>{5, 2});
|
||||
}
|
||||
|
||||
TEST_CASE("test fft vmap") {
|
||||
auto device = default_device();
|
||||
set_default_device(Device::cpu);
|
||||
|
||||
auto fft_fn = [](array x) { return fft::fft(x); };
|
||||
auto x = reshape(arange(8), {2, 4});
|
||||
auto y = vmap(fft_fn)(x);
|
||||
CHECK(array_equal(y, fft::fft(x)).item<bool>());
|
||||
|
||||
y = vmap(fft_fn, 1, 1)(x);
|
||||
CHECK(array_equal(y, fft::fft(x, 0)).item<bool>());
|
||||
|
||||
auto rfft_fn = [](array x) { return fft::rfft(x); };
|
||||
|
||||
y = vmap(rfft_fn)(x);
|
||||
CHECK(array_equal(y, fft::rfft(x)).item<bool>());
|
||||
|
||||
y = vmap(rfft_fn, 1, 1)(x);
|
||||
CHECK(array_equal(y, fft::rfft(x, 0)).item<bool>());
|
||||
|
||||
set_default_device(device);
|
||||
}
|
||||
|
||||
TEST_CASE("test fft grads") {
|
||||
auto device = default_device();
|
||||
set_default_device(Device::cpu);
|
||||
|
||||
// Regular
|
||||
auto fft_fn = [](array x) { return fft::fft(x); };
|
||||
auto cotangent = astype(arange(10), complex64);
|
||||
auto vjp_out = vjp(fft_fn, zeros_like(cotangent), cotangent).second;
|
||||
CHECK(array_equal(fft::fft(cotangent), vjp_out).item<bool>());
|
||||
|
||||
auto tangent = astype(arange(10), complex64);
|
||||
auto jvp_out = jvp(fft_fn, zeros_like(tangent), tangent).second;
|
||||
CHECK(array_equal(fft::fft(tangent), jvp_out).item<bool>());
|
||||
|
||||
// Inverse
|
||||
auto ifft_fn = [](array x) { return fft::ifft(x); };
|
||||
vjp_out = vjp(ifft_fn, zeros_like(cotangent), cotangent).second;
|
||||
CHECK(array_equal(fft::ifft(cotangent), vjp_out).item<bool>());
|
||||
|
||||
jvp_out = jvp(ifft_fn, zeros_like(tangent), tangent).second;
|
||||
CHECK(array_equal(fft::ifft(tangent), jvp_out).item<bool>());
|
||||
|
||||
// Real
|
||||
auto rfft_fn = [](array x) { return fft::rfft(x); };
|
||||
cotangent = astype(arange(6), complex64);
|
||||
vjp_out = vjp(rfft_fn, zeros({10}), cotangent).second;
|
||||
auto expected = astype(fft::fft(cotangent, 10, 0), float32);
|
||||
CHECK(array_equal(expected, vjp_out).item<bool>());
|
||||
|
||||
tangent = astype(arange(10), float32);
|
||||
jvp_out = jvp(rfft_fn, zeros_like(tangent), tangent).second;
|
||||
CHECK(array_equal(fft::rfft(tangent), jvp_out).item<bool>());
|
||||
|
||||
// Inverse real
|
||||
auto irfft_fn = [](array x) { return fft::irfft(x); };
|
||||
cotangent = astype(arange(10), float32);
|
||||
vjp_out = vjp(irfft_fn, astype(zeros({6}), complex64), cotangent).second;
|
||||
expected = fft::fft(cotangent, 10, 0);
|
||||
auto o_splits = split(vjp_out, {1, 5});
|
||||
auto e_splits = split(expected, {1, 5, 6});
|
||||
CHECK_EQ(e_splits[0].item<complex64_t>(), o_splits[0].item<complex64_t>());
|
||||
CHECK(array_equal(2 * e_splits[1], o_splits[1]).item<bool>());
|
||||
CHECK_EQ(e_splits[2].item<complex64_t>(), o_splits[2].item<complex64_t>());
|
||||
|
||||
tangent = astype(arange(10), complex64);
|
||||
jvp_out = jvp(irfft_fn, zeros_like(tangent), tangent).second;
|
||||
CHECK(array_equal(fft::irfft(tangent), jvp_out).item<bool>());
|
||||
|
||||
// Check ND vjps run properly
|
||||
vjp_out = vjp([](array x) { return fft::fftn(x); },
|
||||
astype(zeros({5, 5}), complex64),
|
||||
astype(zeros({5, 5}), complex64))
|
||||
.second;
|
||||
CHECK_EQ(vjp_out.shape(), std::vector<int>{5, 5});
|
||||
|
||||
vjp_out = vjp([](array x) { return fft::ifftn(x); },
|
||||
astype(zeros({5, 5}), complex64),
|
||||
astype(zeros({5, 5}), complex64))
|
||||
.second;
|
||||
CHECK_EQ(vjp_out.shape(), std::vector<int>{5, 5});
|
||||
|
||||
vjp_out = vjp([](array x) { return fft::rfftn(x); },
|
||||
zeros({5, 9}),
|
||||
astype(zeros({5, 5}), complex64))
|
||||
.second;
|
||||
CHECK_EQ(vjp_out.shape(), std::vector<int>{5, 9});
|
||||
|
||||
vjp_out = vjp([](array x) { return fft::irfftn(x); },
|
||||
astype(zeros({5, 5}), complex64),
|
||||
zeros({5, 8}))
|
||||
.second;
|
||||
CHECK_EQ(vjp_out.shape(), std::vector<int>{5, 5});
|
||||
|
||||
set_default_device(device);
|
||||
}
|
||||
30
tests/graph_optimize_tests.cpp
Normal file
30
tests/graph_optimize_tests.cpp
Normal file
@@ -0,0 +1,30 @@
|
||||
#include "doctest/doctest.h"
|
||||
|
||||
#include "mlx/mlx.h"
|
||||
|
||||
using namespace mlx::core;
|
||||
|
||||
TEST_CASE("test simplify scalars") {
|
||||
auto a = array({-1.0f, 2.0f});
|
||||
auto b = maximum(a, array(0.0f));
|
||||
auto c = maximum(-a, array(0.0f));
|
||||
auto d = b + c;
|
||||
simplify({d});
|
||||
CHECK(b.inputs()[1].id() == c.inputs()[1].id());
|
||||
}
|
||||
|
||||
TEST_CASE("test simplify") {
|
||||
auto a = array({1.0f, 2.0f});
|
||||
auto b = exp(a) + exp(a);
|
||||
simplify(b);
|
||||
eval(b);
|
||||
CHECK(b.inputs()[0].id() == b.inputs()[1].id());
|
||||
}
|
||||
|
||||
TEST_CASE("test no simplify") {
|
||||
auto a = array({1.0f, 2.0f});
|
||||
auto b = cos(a) + sin(a);
|
||||
simplify(b);
|
||||
eval(b);
|
||||
CHECK(b.inputs()[0].id() != b.inputs()[1].id());
|
||||
}
|
||||
Reference in New Issue
Block a user