Mirror of https://github.com/ml-explore/mlx.git, synced 2025-09-22 05:08:08 +08:00
Correct types for vjp + tests (#418)
* correct types for vjp + tests
* fix build + comment
@@ -1,5 +1,7 @@
 // Copyright © 2023 Apple Inc.
 
 #include <cstring>
 
+#include "mlx/ops.h"
 #include "mlx/primitives.h"
+#include "mlx/utils.h"
@@ -396,7 +396,10 @@ std::pair<std::vector<array>, std::vector<array>> vjp(
   // products for each primitive
   std::unordered_map<std::uintptr_t, array> cotan_map;
   for (auto [out_idx, cotan_idx] : output_cotan_pairs) {
-    cotan_map.insert({outputs[out_idx].id(), cotans[cotan_idx]});
+    auto& o = outputs[out_idx];
+    auto s = o.has_primitive() ? o.primitive().stream()
+                               : default_stream(default_device());
+    cotan_map.insert({o.id(), astype(cotans[cotan_idx], o.dtype(), s)});
   }
   for (auto it = tape.rbegin(); it != tape.rend(); ++it) {
     auto& a = *it;
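The substantive change in this hunk: instead of inserting the user-supplied cotangent as-is, vjp now casts it to the corresponding output's dtype (on that output's stream) before seeding the reverse pass. Below is a minimal sketch of the resulting behaviour, assuming the mlx public C++ API (the mlx/mlx.h umbrella header, namespace mlx::core); the function and values are illustrative and not taken from the commit's tests.

// Sketch: a float32 cotangent seeded against a float16 output is cast to
// float16 inside vjp (via the astype above) before the pullback runs.
#include <cassert>
#include <vector>

#include "mlx/mlx.h"

using namespace mlx::core;

int main() {
  // A function whose single output is float16.
  auto fun = [](const std::vector<array>& inputs) {
    return std::vector<array>{astype(square(inputs[0]), float16)};
  };

  std::vector<array> primals = {array(3.0f)}; // float32 primal
  std::vector<array> cotans = {array(1.0f)};  // float32 cotangent seed

  // vjp casts the seed to the output dtype (float16) before propagating it.
  auto [outs, grads] = vjp(fun, primals, cotans);
  assert(outs[0].dtype() == float16);

  eval(outs);
  eval(grads);
  return 0;
}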
@@ -636,9 +639,8 @@ ValueAndGradFn value_and_grad(
     for (auto arg : args) {
       ginputs.push_back(inputs[arg]);
     }
-    // Set the incoming gradient as int32 so that it will be promoted to the
-    // appropriate floating point type op(int, floatXX) -> floatXX for most ops
-    auto [outputs, grads] = vjp(gfun, ginputs, {array(1)});
+    // Set the incoming gradient to float32, vjp will cast it to the output type
+    auto [outputs, grads] = vjp(gfun, ginputs, {array(1.0f)});
     return std::make_pair(outputs, grads);
   };
 }
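A companion sketch for the value_and_grad change: the gradient seed is now a float32 scalar (array(1.0f)) rather than an int32 one, and the cast inside vjp aligns it with the output dtype, so low-precision outputs no longer rely on op(int, floatXX) promotion. Again this assumes the mlx public C++ API; the function and shapes are illustrative, not from the commit's tests.

// Sketch: value_and_grad on a function with a float16 output. The float32
// seed created by value_and_grad is cast to float16 by vjp, so the reverse
// pass stays in the output's precision instead of promoting via an int seed.
#include <cassert>
#include <vector>

#include "mlx/mlx.h"

using namespace mlx::core;

int main() {
  auto fun = [](const std::vector<array>& inputs) {
    // Reduce to a float16 scalar loss.
    return std::vector<array>{astype(sum(inputs[0] * inputs[0]), float16)};
  };

  auto vg = value_and_grad(fun, std::vector<int>{0});

  std::vector<array> inputs = {ones({4}, float16)};
  auto [values, grads] = vg(inputs);

  assert(values[0].dtype() == float16);
  eval(values);
  eval(grads);
  return 0;
}

Design-wise, casting the seed at the top of vjp keeps the choice of seed value in value_and_grad independent of the output's precision, rather than depending on implicit type promotion inside each primitive's vjp.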