Metal validation (#432)

* tests clear metal validation

* add cpp test with metal validation to circleci

* nit
This commit is contained in:
Awni Hannun
2024-01-11 11:57:24 -08:00
committed by GitHub
parent 975e265f74
commit c9934fe8a4
10 changed files with 142 additions and 35 deletions

View File

@@ -438,3 +438,36 @@ TEST_CASE("test metal matmul") {
CHECK(array_equal(out, full({3, 3, 2, 2}, 2.0f), Device::cpu).item<bool>());
}
}
TEST_CASE("test metal validation") {
// Run this test with Metal validation enabled
// METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 ./tests/tests \
// -tc="test metal validation" \
auto x = array({});
eval(exp(x));
auto y = array({});
eval(add(x, y));
eval(sum(x));
x = array({1, 2, 3});
y = array(0);
eval(gather(x, y, 0, {0}));
eval(gather(x, y, 0, {2}));
eval(gather(x, y, 0, {0}));
eval(gather(x, y, 0, {2}));
eval(scatter(x, y, array({2}), 0));
x = arange(0, -3, 1);
eval(x);
array_equal(x, array({})).item<bool>();
x = array({1.0, 0.0});
eval(argmax(x));
eval(scatter_max(array(1), {}, array(2), std::vector<int>{}));
}