diff --git a/mlx/io/gguf.cpp b/mlx/io/gguf.cpp
index 7de0ad611..231572145 100644
--- a/mlx/io/gguf.cpp
+++ b/mlx/io/gguf.cpp
@@ -1,5 +1,7 @@
 // Copyright © 2023 Apple Inc.
 
+#include
+
 #include "mlx/ops.h"
 #include "mlx/primitives.h"
 #include "mlx/utils.h"
diff --git a/mlx/transforms.cpp b/mlx/transforms.cpp
index a46171443..bbd9dc5e1 100644
--- a/mlx/transforms.cpp
+++ b/mlx/transforms.cpp
@@ -396,7 +396,10 @@ std::pair<std::vector<array>, std::vector<array>> vjp(
   // products for each primitive
   std::unordered_map<std::uintptr_t, array> cotan_map;
   for (auto [out_idx, cotan_idx] : output_cotan_pairs) {
-    cotan_map.insert({outputs[out_idx].id(), cotans[cotan_idx]});
+    auto& o = outputs[out_idx];
+    auto s = o.has_primitive() ? o.primitive().stream()
+                               : default_stream(default_device());
+    cotan_map.insert({o.id(), astype(cotans[cotan_idx], o.dtype(), s)});
   }
   for (auto it = tape.rbegin(); it != tape.rend(); ++it) {
     auto& a = *it;
@@ -636,9 +639,8 @@ ValueAndGradFn value_and_grad(
     for (auto arg : args) {
      ginputs.push_back(inputs[arg]);
     }
-    // Set the incoming gradient as int32 so that it will be promoted to the
-    // appropriate floating point type op(int, floatXX) -> floatXX for most ops
-    auto [outputs, grads] = vjp(gfun, ginputs, {array(1)});
+    // Set the incoming gradient to float32, vjp will cast it to the output type
+    auto [outputs, grads] = vjp(gfun, ginputs, {array(1.0f)});
     return std::make_pair(outputs, grads);
   };
 }
diff --git a/python/tests/test_autograd.py b/python/tests/test_autograd.py
index 69d8baeea..946210e6e 100644
--- a/python/tests/test_autograd.py
+++ b/python/tests/test_autograd.py
@@ -274,6 +274,47 @@ class TestAutograd(mlx_tests.MLXTestCase):
         mx.eval(state)
         self.assertTrue(mx.allclose(state, mx.ones((2,))))
 
+    def test_scatter_vjp(self):
+        def fun(x, idx):
+            x[idx] = 2.0
+            return x.sum()
+
+        dfdx = mx.grad(fun)(mx.array([1.0, 2.0, 3.0]), mx.array([1]))
+        self.assertTrue(mx.array_equal(dfdx, mx.array([1.0, 0.0, 1.0])))
+        self.assertEqual(dfdx.dtype, mx.float32)
+
+        y = mx.array([0.0, 1.0, 2.0])
+
+        def fun(x, idx):
+            y[idx] = x
+            return y.sum()
+
+        dfdx = mx.grad(fun)(mx.array([2.0]), mx.array([1]))
+        self.assertTrue(mx.array_equal(dfdx, mx.array([1.0])))
+        self.assertEqual(dfdx.dtype, mx.float32)
+
+    def test_vjp_types(self):
+        def fun(x):
+            return x
+
+        for t in [mx.float16, mx.bfloat16, mx.float32]:
+            out = mx.grad(fun)(mx.array(1.0, t))
+            self.assertEqual(out.dtype, t)
+
+        def fun(x):
+            return x.sum()
+
+        for t in [mx.float16, mx.bfloat16, mx.float32]:
+            out = mx.grad(fun)(mx.array(1.0, t))
+            self.assertEqual(out.dtype, t)
+
+        def fun(x, y):
+            return (x + y).sum()
+
+        for t in [mx.float16, mx.bfloat16, mx.float32]:
+            out = mx.grad(fun)(mx.array(1.0, t), mx.array(1.0, t))
+            self.assertEqual(out.dtype, t)
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/tests/autograd_tests.cpp b/tests/autograd_tests.cpp
index 5e4100c75..6c31ee118 100644
--- a/tests/autograd_tests.cpp
+++ b/tests/autograd_tests.cpp
@@ -1202,3 +1202,29 @@ TEST_CASE("test update state") {
   CHECK(state.is_evaled());
   CHECK(array_equal(state, array({1.0, 1.0})).item<bool>());
 }
+
+TEST_CASE("test grad types") {
+  {
+    auto fn = [](array x) { return sum(x); };
+
+    for (auto t : {float16, bfloat16, float32}) {
+      auto x = array(1.0, t);
+      auto dfdx = grad(fn)(x);
+      CHECK_EQ(dfdx.dtype(), t);
+    }
+  }
+
+  {
+    // Check for multi-input grad
+    auto fn = [](std::vector<array> inputs) {
+      return sum(inputs[0] + inputs[1]);
+    };
+
+    for (auto t : {float16, bfloat16, float32}) {
+      auto x = array(1.0, t);
+      auto y = array(1.0, t);
+      auto out = grad(fn)({x, y});
+      CHECK_EQ(out[0].dtype(), t);
+    }
+  }
+}
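
Note: the sketch below is not part of the patch. It is a minimal usage example, assuming the `mlx.core` Python API exercised in the tests above, of the behavior the `transforms.cpp` change targets: seeding `vjp` with `array(1.0f)` and casting each cotangent to its output's dtype means `mx.grad` returns gradients in the function's own precision instead of promoting them to `float32`.

```python
# Minimal sketch (not part of the patch): with the vjp cast in place, the
# gradient of a float16 computation should come back as float16.
import mlx.core as mx

def loss(x):
    # Simple reduction so grad(loss) returns an array shaped like x.
    return (x * x).sum()

x = mx.array([1.0, 2.0, 3.0], mx.float16)
g = mx.grad(loss)(x)
print(g.dtype)  # expected: mlx.core.float16 after this change
```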