mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-20 09:33:16 +08:00
Correct types for vjp + tests (#418)
* correct types for vjp + tests * fix build + comment
This commit is contained in:
@@ -1202,3 +1202,29 @@ TEST_CASE("test update state") {
|
||||
CHECK(state.is_evaled());
|
||||
CHECK(array_equal(state, array({1.0, 1.0})).item<bool>());
|
||||
}
|
||||
|
||||
TEST_CASE("test grad types") {
|
||||
{
|
||||
auto fn = [](array x) { return sum(x); };
|
||||
|
||||
for (auto t : {float16, bfloat16, float32}) {
|
||||
auto x = array(1.0, t);
|
||||
auto dfdx = grad(fn)(x);
|
||||
CHECK_EQ(dfdx.dtype(), t);
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
// Check for multi-input grad
|
||||
auto fn = [](std::vector<array> inputs) {
|
||||
return sum(inputs[0] + inputs[1]);
|
||||
};
|
||||
|
||||
for (auto t : {float16, bfloat16, float32}) {
|
||||
auto x = array(1.0, t);
|
||||
auto y = array(1.0, t);
|
||||
auto out = grad(fn)({x, y});
|
||||
CHECK_EQ(out[0].dtype(), t);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user