mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-26 02:33:21 +08:00
Remove metal-only tests
This commit is contained in:
parent
bb6565ef14
commit
08ef9408b5
@ -10,7 +10,7 @@ FetchContent_MakeAvailable(doctest)
|
|||||||
add_executable(tests ${PROJECT_SOURCE_DIR}/tests/tests.cpp)
|
add_executable(tests ${PROJECT_SOURCE_DIR}/tests/tests.cpp)
|
||||||
|
|
||||||
if(MLX_BUILD_METAL)
|
if(MLX_BUILD_METAL)
|
||||||
set(METAL_TEST_SOURCES metal_tests.cpp)
|
set(METAL_TEST_SOURCES gpu_tests.cpp)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
include(${doctest_SOURCE_DIR}/scripts/cmake/doctest.cmake)
|
include(${doctest_SOURCE_DIR}/scripts/cmake/doctest.cmake)
|
||||||
|
@ -1,11 +1,8 @@
|
|||||||
// Copyright © 2023-2024 Apple Inc.
|
// Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
#include <array>
|
#include <array>
|
||||||
#include "doctest/doctest.h"
|
|
||||||
|
|
||||||
#include "mlx/backend/metal/allocator.h"
|
#include "doctest/doctest.h"
|
||||||
#include "mlx/backend/metal/device.h"
|
|
||||||
#include "mlx/backend/metal/metal.h"
|
|
||||||
#include "mlx/mlx.h"
|
#include "mlx/mlx.h"
|
||||||
|
|
||||||
using namespace mlx::core;
|
using namespace mlx::core;
|
||||||
@ -13,13 +10,7 @@ using namespace mlx::core;
|
|||||||
static const std::array<Dtype, 5> types =
|
static const std::array<Dtype, 5> types =
|
||||||
{bool_, uint32, int32, int64, float32};
|
{bool_, uint32, int32, int64, float32};
|
||||||
|
|
||||||
TEST_CASE("test metal device") {
|
TEST_CASE("test gpu arange") {
|
||||||
// Make sure the device and library can load
|
|
||||||
CHECK(metal::is_available());
|
|
||||||
auto& device = metal::device(Device::gpu);
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_CASE("test metal arange") {
|
|
||||||
for (auto t : types) {
|
for (auto t : types) {
|
||||||
if (t == bool_) {
|
if (t == bool_) {
|
||||||
continue;
|
continue;
|
||||||
@ -34,7 +25,7 @@ TEST_CASE("test metal arange") {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_CASE("test metal full") {
|
TEST_CASE("test gpu full") {
|
||||||
for (auto t : types) {
|
for (auto t : types) {
|
||||||
auto out_cpu = full({4, 4}, 2, t, Device::cpu);
|
auto out_cpu = full({4, 4}, 2, t, Device::cpu);
|
||||||
auto out_gpu = full({4, 4}, 2, t, Device::gpu);
|
auto out_gpu = full({4, 4}, 2, t, Device::gpu);
|
||||||
@ -63,7 +54,7 @@ TEST_CASE("test metal full") {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_CASE("test metal astype") {
|
TEST_CASE("test gpu astype") {
|
||||||
array x = array({-4, -3, -2, -1, 0, 1, 2, 3});
|
array x = array({-4, -3, -2, -1, 0, 1, 2, 3});
|
||||||
// Check all types work
|
// Check all types work
|
||||||
for (auto t : types) {
|
for (auto t : types) {
|
||||||
@ -80,7 +71,7 @@ TEST_CASE("test metal astype") {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_CASE("test metal reshape") {
|
TEST_CASE("test gpu reshape") {
|
||||||
array x = array({0, 1, 2, 3, 4, 5, 6, 7});
|
array x = array({0, 1, 2, 3, 4, 5, 6, 7});
|
||||||
auto out_cpu = reshape(x, {2, 2, 2});
|
auto out_cpu = reshape(x, {2, 2, 2});
|
||||||
auto out_gpu = reshape(x, {2, 2, 2}, Device::gpu);
|
auto out_gpu = reshape(x, {2, 2, 2}, Device::gpu);
|
||||||
@ -96,7 +87,7 @@ TEST_CASE("test metal reshape") {
|
|||||||
CHECK(array_equal(out_gpu, out_cpu, Device::cpu).item<bool>());
|
CHECK(array_equal(out_gpu, out_cpu, Device::cpu).item<bool>());
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_CASE("test metal reduce") {
|
TEST_CASE("test gpu reduce") {
|
||||||
{
|
{
|
||||||
array a(true);
|
array a(true);
|
||||||
CHECK_EQ(all(a, Device::gpu).item<bool>(), true);
|
CHECK_EQ(all(a, Device::gpu).item<bool>(), true);
|
||||||
@ -190,7 +181,7 @@ TEST_CASE("test metal reduce") {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_CASE("test metal binary ops") {
|
TEST_CASE("test gpu binary ops") {
|
||||||
// scalar-scalar
|
// scalar-scalar
|
||||||
{
|
{
|
||||||
array a(2.0f);
|
array a(2.0f);
|
||||||
@ -338,7 +329,7 @@ TEST_CASE("test metal binary ops") {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_CASE("test metal unary ops") {
|
TEST_CASE("test gpu unary ops") {
|
||||||
// contiguous
|
// contiguous
|
||||||
{
|
{
|
||||||
array x({-1.0f, 0.0f, 1.0f});
|
array x({-1.0f, 0.0f, 1.0f});
|
||||||
@ -392,7 +383,7 @@ TEST_CASE("test metal unary ops") {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_CASE("test metal random") {
|
TEST_CASE("test gpu random") {
|
||||||
{
|
{
|
||||||
auto key = random::key(0);
|
auto key = random::key(0);
|
||||||
auto x = random::bits({}, 4, key, Device::gpu);
|
auto x = random::bits({}, 4, key, Device::gpu);
|
||||||
@ -415,7 +406,7 @@ TEST_CASE("test metal random") {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_CASE("test metal matmul") {
|
TEST_CASE("test gpu matmul") {
|
||||||
{
|
{
|
||||||
auto a = ones({2, 2});
|
auto a = ones({2, 2});
|
||||||
auto b = ones({2, 2});
|
auto b = ones({2, 2});
|
||||||
@ -440,7 +431,7 @@ TEST_CASE("test metal matmul") {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_CASE("test metal validation") {
|
TEST_CASE("test gpu validation") {
|
||||||
// Run this test with Metal validation enabled
|
// Run this test with Metal validation enabled
|
||||||
// METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 ./tests/tests \
|
// METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 ./tests/tests \
|
||||||
// -tc="test metal validation" \
|
// -tc="test metal validation" \
|
Loading…
Reference in New Issue
Block a user