mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-18 10:26:56 +08:00
Metal validation (#432)
* tests clear metal validation * add cpp test with metal validation to circleci * nit
This commit is contained in:
@@ -438,3 +438,36 @@ TEST_CASE("test metal matmul") {
|
||||
CHECK(array_equal(out, full({3, 3, 2, 2}, 2.0f), Device::cpu).item<bool>());
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("test metal validation") {
|
||||
// Run this test with Metal validation enabled
|
||||
// METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 ./tests/tests \
|
||||
// -tc="test metal validation" \
|
||||
|
||||
auto x = array({});
|
||||
eval(exp(x));
|
||||
|
||||
auto y = array({});
|
||||
eval(add(x, y));
|
||||
|
||||
eval(sum(x));
|
||||
|
||||
x = array({1, 2, 3});
|
||||
y = array(0);
|
||||
eval(gather(x, y, 0, {0}));
|
||||
eval(gather(x, y, 0, {2}));
|
||||
|
||||
eval(gather(x, y, 0, {0}));
|
||||
eval(gather(x, y, 0, {2}));
|
||||
|
||||
eval(scatter(x, y, array({2}), 0));
|
||||
|
||||
x = arange(0, -3, 1);
|
||||
eval(x);
|
||||
array_equal(x, array({})).item<bool>();
|
||||
|
||||
x = array({1.0, 0.0});
|
||||
eval(argmax(x));
|
||||
|
||||
eval(scatter_max(array(1), {}, array(2), std::vector<int>{}));
|
||||
}
|
||||
|
Reference in New Issue
Block a user