Correct types for vjp + tests (#418)

* correct types for vjp + tests

* fix build + comment
This commit is contained in:
Awni Hannun
2024-01-10 13:32:37 -08:00
committed by GitHub
parent b7f905787e
commit 3b4f066dac
4 changed files with 75 additions and 4 deletions

View File

@@ -1,5 +1,7 @@
// Copyright © 2023 Apple Inc.
#include <cstring>
#include "mlx/ops.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"

View File

@@ -396,7 +396,10 @@ std::pair<std::vector<array>, std::vector<array>> vjp(
// products for each primitive
std::unordered_map<std::uintptr_t, array> cotan_map;
for (auto [out_idx, cotan_idx] : output_cotan_pairs) {
cotan_map.insert({outputs[out_idx].id(), cotans[cotan_idx]});
auto& o = outputs[out_idx];
auto s = o.has_primitive() ? o.primitive().stream()
: default_stream(default_device());
cotan_map.insert({o.id(), astype(cotans[cotan_idx], o.dtype(), s)});
}
for (auto it = tape.rbegin(); it != tape.rend(); ++it) {
auto& a = *it;
@@ -636,9 +639,8 @@ ValueAndGradFn value_and_grad(
for (auto arg : args) {
ginputs.push_back(inputs[arg]);
}
// Set the incoming gradient as int32 so that it will be promoted to the
// appropriate floating point type op(int, floatXX) -> floatXX for most ops
auto [outputs, grads] = vjp(gfun, ginputs, {array(1)});
// Set the incoming gradient to int32, vjp will cast it to the output type
auto [outputs, grads] = vjp(gfun, ginputs, {array(1.0f)});
return std::make_pair(outputs, grads);
};
}