Correct types for vjp + tests (#418)
* correct types for vjp + tests
* fix build + comment
commit 3b4f066dac
parent b7f905787e
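The gist of the change: the seed cotangent fed into vjp used to be float32, which promoted gradients of reduced-precision functions up to float32; it is now cast to each output's dtype. A minimal check of the behavior this commit guarantees, mirroring the new Python tests below (assumes mlx is installed and imported as mx):

import mlx.core as mx

def fun(x):
    return x.sum()

# With this commit, the gradient dtype follows the input dtype
# instead of being promoted to float32 by the seed cotangent.
for t in [mx.float16, mx.bfloat16, mx.float32]:
    out = mx.grad(fun)(mx.array(1.0, t))
    assert out.dtype == t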
@@ -1,5 +1,7 @@
 // Copyright © 2023 Apple Inc.
 
+#include <cstring>
+
 #include "mlx/ops.h"
 #include "mlx/primitives.h"
 #include "mlx/utils.h"
@@ -396,7 +396,10 @@ std::pair<std::vector<array>, std::vector<array>> vjp(
   // products for each primitive
   std::unordered_map<std::uintptr_t, array> cotan_map;
   for (auto [out_idx, cotan_idx] : output_cotan_pairs) {
-    cotan_map.insert({outputs[out_idx].id(), cotans[cotan_idx]});
+    auto& o = outputs[out_idx];
+    auto s = o.has_primitive() ? o.primitive().stream()
+                               : default_stream(default_device());
+    cotan_map.insert({o.id(), astype(cotans[cotan_idx], o.dtype(), s)});
   }
   for (auto it = tape.rbegin(); it != tape.rend(); ++it) {
     auto& a = *it;
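What this hunk does: instead of keying the raw cotangent into cotan_map, vjp now casts it to the paired output's dtype with astype, scheduled on the output's stream (falling back to the default stream for arrays with no primitive). A rough Python analogue of that bookkeeping — id(), .astype(), and the dict are stand-ins for the C++ array::id(), astype(), and std::unordered_map above, not mlx internals:

import mlx.core as mx

def seed_cotangents(outputs, cotans, output_cotan_pairs):
    # Mirror of the C++ loop: key cotangents by output id, cast to
    # the output's dtype so later vjps see matching types.
    cotan_map = {}
    for out_idx, cotan_idx in output_cotan_pairs:
        o = outputs[out_idx]
        cotan_map[id(o)] = cotans[cotan_idx].astype(o.dtype)
    return cotan_map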
@@ -636,9 +639,8 @@ ValueAndGradFn value_and_grad(
     for (auto arg : args) {
       ginputs.push_back(inputs[arg]);
     }
-    // Set the incoming gradient as int32 so that it will be promoted to the
-    // appropriate floating point type op(int, floatXX) -> floatXX for most ops
-    auto [outputs, grads] = vjp(gfun, ginputs, {array(1.0f)});
+    // Set the incoming gradient to int32, vjp will cast it to the output type
+    auto [outputs, grads] = vjp(gfun, ginputs, {array(1)});
     return std::make_pair(outputs, grads);
   };
 }
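Downstream effect of the two hunks together: value_and_grad seeds vjp with a dtype-neutral int32 one, and vjp immediately casts that seed to each output's dtype, so gradients come back in the function's own precision. A minimal sketch using the public Python API (mlx.core imported as mx):

import mlx.core as mx

def loss(x):
    return (x * x).sum()

x = mx.array(1.0, mx.bfloat16)
value, dloss_dx = mx.value_and_grad(loss)(x)
# The int32 seed is cast to the output dtype inside vjp,
# so the gradient comes back in the input precision.
assert dloss_dx.dtype == mx.bfloat16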
@@ -274,6 +274,47 @@ class TestAutograd(mlx_tests.MLXTestCase):
         mx.eval(state)
         self.assertTrue(mx.allclose(state, mx.ones((2,))))
 
+    def test_scatter_vjp(self):
+        def fun(x, idx):
+            x[idx] = 2.0
+            return x.sum()
+
+        dfdx = mx.grad(fun)(mx.array([1.0, 2.0, 3.0]), mx.array([1]))
+        self.assertTrue(mx.array_equal(dfdx, mx.array([1.0, 0.0, 1.0])))
+        self.assertEqual(dfdx.dtype, mx.float32)
+
+        y = mx.array([0.0, 1.0, 2.0])
+
+        def fun(x, idx):
+            y[idx] = x
+            return y.sum()
+
+        dfdx = mx.grad(fun)(mx.array([2.0]), mx.array([1]))
+        self.assertTrue(mx.array_equal(dfdx, mx.array([1.0])))
+        self.assertEqual(dfdx.dtype, mx.float32)
+
+    def test_vjp_types(self):
+        def fun(x):
+            return x
+
+        for t in [mx.float16, mx.bfloat16, mx.float32]:
+            out = mx.grad(fun)(mx.array(1.0, t))
+            self.assertEqual(out.dtype, t)
+
+        def fun(x):
+            return x.sum()
+
+        for t in [mx.float16, mx.bfloat16, mx.float32]:
+            out = mx.grad(fun)(mx.array(1.0, t))
+            self.assertEqual(out.dtype, t)
+
+        def fun(x, y):
+            return (x + y).sum()
+
+        for t in [mx.float16, mx.bfloat16, mx.float32]:
+            out = mx.grad(fun)(mx.array(1.0, t), mx.array(1.0, t))
+            self.assertEqual(out.dtype, t)
+
 
 if __name__ == "__main__":
     unittest.main()
@@ -1202,3 +1202,29 @@ TEST_CASE("test update state") {
   CHECK(state.is_evaled());
   CHECK(array_equal(state, array({1.0, 1.0})).item<bool>());
 }
+
+TEST_CASE("test grad types") {
+  {
+    auto fn = [](array x) { return sum(x); };
+
+    for (auto t : {float16, bfloat16, float32}) {
+      auto x = array(1.0, t);
+      auto dfdx = grad(fn)(x);
+      CHECK_EQ(dfdx.dtype(), t);
+    }
+  }
+
+  {
+    // Check for multi-input grad
+    auto fn = [](std::vector<array> inputs) {
+      return sum(inputs[0] + inputs[1]);
+    };
+
+    for (auto t : {float16, bfloat16, float32}) {
+      auto x = array(1.0, t);
+      auto y = array(1.0, t);
+      auto out = grad(fn)({x, y});
+      CHECK_EQ(out[0].dtype(), t);
+    }
+  }
+}