From 08ef9408b52f6aba8bba2e5a94a587dad62676dc Mon Sep 17 00:00:00 2001 From: Cheng Date: Wed, 30 Apr 2025 15:06:00 +0900 Subject: [PATCH] Remove metal-only tests --- tests/CMakeLists.txt | 2 +- tests/{metal_tests.cpp => gpu_tests.cpp} | 31 +++++++++--------------- 2 files changed, 12 insertions(+), 21 deletions(-) rename tests/{metal_tests.cpp => gpu_tests.cpp} (95%) diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index be4479e70..cf0ba3d5d 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -10,7 +10,7 @@ FetchContent_MakeAvailable(doctest) add_executable(tests ${PROJECT_SOURCE_DIR}/tests/tests.cpp) if(MLX_BUILD_METAL) - set(METAL_TEST_SOURCES metal_tests.cpp) + set(METAL_TEST_SOURCES gpu_tests.cpp) endif() include(${doctest_SOURCE_DIR}/scripts/cmake/doctest.cmake) diff --git a/tests/metal_tests.cpp b/tests/gpu_tests.cpp similarity index 95% rename from tests/metal_tests.cpp rename to tests/gpu_tests.cpp index 7aabdf36d..f0ef969cf 100644 --- a/tests/metal_tests.cpp +++ b/tests/gpu_tests.cpp @@ -1,11 +1,8 @@ // Copyright © 2023-2024 Apple Inc. #include -#include "doctest/doctest.h" -#include "mlx/backend/metal/allocator.h" -#include "mlx/backend/metal/device.h" -#include "mlx/backend/metal/metal.h" +#include "doctest/doctest.h" #include "mlx/mlx.h" using namespace mlx::core; @@ -13,13 +10,7 @@ using namespace mlx::core; static const std::array types = {bool_, uint32, int32, int64, float32}; -TEST_CASE("test metal device") { - // Make sure the device and library can load - CHECK(metal::is_available()); - auto& device = metal::device(Device::gpu); -} - -TEST_CASE("test metal arange") { +TEST_CASE("test gpu arange") { for (auto t : types) { if (t == bool_) { continue; @@ -34,7 +25,7 @@ TEST_CASE("test metal arange") { } } -TEST_CASE("test metal full") { +TEST_CASE("test gpu full") { for (auto t : types) { auto out_cpu = full({4, 4}, 2, t, Device::cpu); auto out_gpu = full({4, 4}, 2, t, Device::gpu); @@ -63,7 +54,7 @@ TEST_CASE("test metal full") { } } -TEST_CASE("test metal astype") { +TEST_CASE("test gpu astype") { array x = array({-4, -3, -2, -1, 0, 1, 2, 3}); // Check all types work for (auto t : types) { @@ -80,7 +71,7 @@ TEST_CASE("test metal astype") { } } -TEST_CASE("test metal reshape") { +TEST_CASE("test gpu reshape") { array x = array({0, 1, 2, 3, 4, 5, 6, 7}); auto out_cpu = reshape(x, {2, 2, 2}); auto out_gpu = reshape(x, {2, 2, 2}, Device::gpu); @@ -96,7 +87,7 @@ TEST_CASE("test metal reshape") { CHECK(array_equal(out_gpu, out_cpu, Device::cpu).item()); } -TEST_CASE("test metal reduce") { +TEST_CASE("test gpu reduce") { { array a(true); CHECK_EQ(all(a, Device::gpu).item(), true); @@ -190,7 +181,7 @@ TEST_CASE("test metal reduce") { } } -TEST_CASE("test metal binary ops") { +TEST_CASE("test gpu binary ops") { // scalar-scalar { array a(2.0f); @@ -338,7 +329,7 @@ TEST_CASE("test metal binary ops") { } } -TEST_CASE("test metal unary ops") { +TEST_CASE("test gpu unary ops") { // contiguous { array x({-1.0f, 0.0f, 1.0f}); @@ -392,7 +383,7 @@ TEST_CASE("test metal unary ops") { } } -TEST_CASE("test metal random") { +TEST_CASE("test gpu random") { { auto key = random::key(0); auto x = random::bits({}, 4, key, Device::gpu); @@ -415,7 +406,7 @@ TEST_CASE("test metal random") { } } -TEST_CASE("test metal matmul") { +TEST_CASE("test gpu matmul") { { auto a = ones({2, 2}); auto b = ones({2, 2}); @@ -440,7 +431,7 @@ TEST_CASE("test metal matmul") { } } -TEST_CASE("test metal validation") { +TEST_CASE("test gpu validation") { // Run this test with Metal validation enabled // METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 ./tests/tests \ // -tc="test metal validation" \