Correct types for vjp + tests (#418)

* correct types for vjp + tests

* fix build + comment
This commit is contained in:
Awni Hannun
2024-01-10 13:32:37 -08:00
committed by GitHub
parent b7f905787e
commit 3b4f066dac
4 changed files with 75 additions and 4 deletions

View File

@@ -1202,3 +1202,29 @@ TEST_CASE("test update state") {
CHECK(state.is_evaled());
CHECK(array_equal(state, array({1.0, 1.0})).item<bool>());
}
TEST_CASE("test grad types") {
{
auto fn = [](array x) { return sum(x); };
for (auto t : {float16, bfloat16, float32}) {
auto x = array(1.0, t);
auto dfdx = grad(fn)(x);
CHECK_EQ(dfdx.dtype(), t);
}
}
{
// Check for multi-input grad
auto fn = [](std::vector<array> inputs) {
return sum(inputs[0] + inputs[1]);
};
for (auto t : {float16, bfloat16, float32}) {
auto x = array(1.0, t);
auto y = array(1.0, t);
auto out = grad(fn)({x, y});
CHECK_EQ(out[0].dtype(), t);
}
}
}