This commit is contained in:
paramthakkar123
2025-05-06 09:53:10 +05:30
128 changed files with 2291 additions and 895 deletions

View File

@@ -10,7 +10,7 @@ FetchContent_MakeAvailable(doctest)
add_executable(tests ${PROJECT_SOURCE_DIR}/tests/tests.cpp)
if(MLX_BUILD_METAL)
set(METAL_TEST_SOURCES metal_tests.cpp)
set(METAL_TEST_SOURCES gpu_tests.cpp)
endif()
include(${doctest_SOURCE_DIR}/scripts/cmake/doctest.cmake)

View File

@@ -795,3 +795,12 @@ TEST_CASE("test compile lambda") {
out = cfun2({array(0)});
CHECK_EQ(out[0].item<int>(), 3);
}
TEST_CASE("test compile with no-ops") {
auto fun = [](const std::vector<array>& inputs) {
return std::vector<array>{abs(stop_gradient(abs(inputs[0])))};
};
auto in = array(1.0);
auto out = compile(fun)({in})[0];
CHECK_EQ(out.inputs()[0].id(), in.id());
}

View File

@@ -309,6 +309,7 @@ TEST_CASE("test fft grads") {
CHECK_EQ(vjp_out.shape(), Shape{5, 5});
}
<<<<<<< HEAD
TEST_CASE("test stft and istft") {
int n_fft = 4;
int hop_length = 2;
@@ -381,4 +382,62 @@ TEST_CASE("test stft and istft") {
CHECK_EQ(stft_result.shape(1), n_fft);
}
}
}
== == == = TEST_CASE("test fftshift and ifftshift") {
// Test 1D array with even length
auto x = arange(8);
auto y = fft::fftshift(x);
CHECK_EQ(y.shape(), x.shape());
// print y
CHECK(array_equal(y, array({4, 5, 6, 7, 0, 1, 2, 3})).item<bool>());
// Test 1D array with odd length
x = arange(7);
y = fft::fftshift(x);
CHECK_EQ(y.shape(), x.shape());
CHECK(array_equal(y, array({4, 5, 6, 0, 1, 2, 3})).item<bool>());
// Test 2D array
x = reshape(arange(16), {4, 4});
y = fft::fftshift(x);
auto expected =
array({10, 11, 8, 9, 14, 15, 12, 13, 2, 3, 0, 1, 6, 7, 4, 5}, {4, 4});
CHECK(array_equal(y, expected).item<bool>());
// Test with specific axes
y = fft::fftshift(x, {0});
expected =
array({8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7}, {4, 4});
CHECK(array_equal(y, expected).item<bool>());
y = fft::fftshift(x, {1});
expected =
array({2, 3, 0, 1, 6, 7, 4, 5, 10, 11, 8, 9, 14, 15, 12, 13}, {4, 4});
CHECK(array_equal(y, expected).item<bool>());
// Test ifftshift (inverse operation)
x = arange(8);
y = fft::ifftshift(x);
CHECK_EQ(y.shape(), x.shape());
CHECK(array_equal(y, array({4, 5, 6, 7, 0, 1, 2, 3})).item<bool>());
// Test ifftshift with odd length (different from fftshift)
x = arange(7);
y = fft::ifftshift(x);
CHECK_EQ(y.shape(), x.shape());
CHECK(array_equal(y, array({3, 4, 5, 6, 0, 1, 2})).item<bool>());
// Test 2D ifftshift
x = reshape(arange(16), {4, 4});
y = fft::ifftshift(x);
expected =
array({10, 11, 8, 9, 14, 15, 12, 13, 2, 3, 0, 1, 6, 7, 4, 5}, {4, 4});
CHECK(array_equal(y, expected).item<bool>());
// Test error cases
CHECK_THROWS_AS(fft::fftshift(x, {3}), std::invalid_argument);
CHECK_THROWS_AS(fft::fftshift(x, {-5}), std::invalid_argument);
CHECK_THROWS_AS(fft::ifftshift(x, {3}), std::invalid_argument);
CHECK_THROWS_AS(fft::ifftshift(x, {-5}), std::invalid_argument);
}
>>>>>>> 5a1a5d5ed16f69af7c3ce56dd94e4502661e1565

View File

@@ -1,11 +1,8 @@
// Copyright © 2023-2024 Apple Inc.
#include <array>
#include "doctest/doctest.h"
#include "mlx/backend/metal/allocator.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/metal.h"
#include "doctest/doctest.h"
#include "mlx/mlx.h"
using namespace mlx::core;
@@ -13,13 +10,7 @@ using namespace mlx::core;
static const std::array<Dtype, 5> types =
{bool_, uint32, int32, int64, float32};
TEST_CASE("test metal device") {
// Make sure the device and library can load
CHECK(metal::is_available());
auto& device = metal::device(Device::gpu);
}
TEST_CASE("test metal arange") {
TEST_CASE("test gpu arange") {
for (auto t : types) {
if (t == bool_) {
continue;
@@ -34,7 +25,7 @@ TEST_CASE("test metal arange") {
}
}
TEST_CASE("test metal full") {
TEST_CASE("test gpu full") {
for (auto t : types) {
auto out_cpu = full({4, 4}, 2, t, Device::cpu);
auto out_gpu = full({4, 4}, 2, t, Device::gpu);
@@ -63,7 +54,7 @@ TEST_CASE("test metal full") {
}
}
TEST_CASE("test metal astype") {
TEST_CASE("test gpu astype") {
array x = array({-4, -3, -2, -1, 0, 1, 2, 3});
// Check all types work
for (auto t : types) {
@@ -80,7 +71,7 @@ TEST_CASE("test metal astype") {
}
}
TEST_CASE("test metal reshape") {
TEST_CASE("test gpu reshape") {
array x = array({0, 1, 2, 3, 4, 5, 6, 7});
auto out_cpu = reshape(x, {2, 2, 2});
auto out_gpu = reshape(x, {2, 2, 2}, Device::gpu);
@@ -96,7 +87,7 @@ TEST_CASE("test metal reshape") {
CHECK(array_equal(out_gpu, out_cpu, Device::cpu).item<bool>());
}
TEST_CASE("test metal reduce") {
TEST_CASE("test gpu reduce") {
{
array a(true);
CHECK_EQ(all(a, Device::gpu).item<bool>(), true);
@@ -190,7 +181,7 @@ TEST_CASE("test metal reduce") {
}
}
TEST_CASE("test metal binary ops") {
TEST_CASE("test gpu binary ops") {
// scalar-scalar
{
array a(2.0f);
@@ -338,7 +329,7 @@ TEST_CASE("test metal binary ops") {
}
}
TEST_CASE("test metal unary ops") {
TEST_CASE("test gpu unary ops") {
// contiguous
{
array x({-1.0f, 0.0f, 1.0f});
@@ -392,7 +383,7 @@ TEST_CASE("test metal unary ops") {
}
}
TEST_CASE("test metal random") {
TEST_CASE("test gpu random") {
{
auto key = random::key(0);
auto x = random::bits({}, 4, key, Device::gpu);
@@ -415,7 +406,7 @@ TEST_CASE("test metal random") {
}
}
TEST_CASE("test metal matmul") {
TEST_CASE("test gpu matmul") {
{
auto a = ones({2, 2});
auto b = ones({2, 2});
@@ -440,7 +431,7 @@ TEST_CASE("test metal matmul") {
}
}
TEST_CASE("test metal validation") {
TEST_CASE("test gpu validation") {
// Run this test with Metal validation enabled
// METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 ./tests/tests \
// -tc="test metal validation" \

View File

@@ -3859,6 +3859,9 @@ TEST_CASE("test roll") {
y = roll(x, {1, 2}, {0, 1});
CHECK(array_equal(y, array({8, 9, 5, 6, 7, 3, 4, 0, 1, 2}, {2, 5}))
.item<bool>());
y = roll(array({}), 0, 0);
CHECK(array_equal(y, array({})).item<bool>());
}
TEST_CASE("test contiguous") {
@@ -3911,4 +3914,70 @@ TEST_CASE("test bitwise shift operations") {
CHECK_EQ(right_shift_bool_result.dtype(), uint8);
CHECK(array_equal(right_shift_bool_result, full({4}, 0, uint8)).item<bool>());
}
}
TEST_CASE("test conv_transpose1d with output_padding") {
auto in = array({1.0, 2.0, 3.0}, {1, 1, 3});
auto wt = array({1.0, 1.0, 1.0}, {1, 1, 3});
int stride = 2;
int padding = 0;
int dilation = 1;
int output_padding = 1;
int groups = 1;
auto out = conv_transpose1d(
in, wt, stride, padding, dilation, output_padding, groups);
auto expected = array({6.0, 0.0}, {1, 2, 1});
CHECK(array_equal(out, expected).item<bool>());
}
TEST_CASE("test conv_transpose2d with output_padding") {
auto in = array({1.0, 2.0, 3.0, 4.0}, {1, 1, 2, 2});
auto wt = array({1.0, 1.0, 1.0, 1.0}, {2, 1, 1, 2});
std::pair<int, int> stride{2, 2};
std::pair<int, int> padding{0, 0};
std::pair<int, int> output_padding{1, 1};
std::pair<int, int> dilation{1, 1};
int groups = 1;
auto out = conv_transpose2d(
in, wt, stride, padding, dilation, output_padding, groups);
auto expected = array(
{3.0,
3.0,
0.0,
0.0,
7.0,
7.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0},
{1, 2, 4, 2});
CHECK(array_equal(out, expected).item<bool>());
}
TEST_CASE("test conv_transpose3d with output_padding") {
auto in = array({1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0}, {1, 1, 2, 2, 2});
auto wt = array({1.0, 1.0}, {1, 1, 1, 1, 2});
std::tuple<int, int, int> stride{2, 2, 2};
std::tuple<int, int, int> padding{0, 0, 0};
std::tuple<int, int, int> output_padding{1, 1, 1};
std::tuple<int, int, int> dilation{1, 1, 1};
int groups = 1;
auto out = conv_transpose3d(
in, wt, stride, padding, dilation, output_padding, groups);
auto expected = array(
{3.0, 0.0, 7.0, 0.0, 0.0, 0.0, 0.0, 0.0, 11.0, 0.0, 15.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0},
{1, 2, 4, 4, 1});
CHECK(array_equal(out, expected).item<bool>());
}