diff --git a/mlx/array.cpp b/mlx/array.cpp index a70cb43a0..2ec9214ab 100644 --- a/mlx/array.cpp +++ b/mlx/array.cpp @@ -6,6 +6,7 @@ #include "mlx/ops.h" #include "mlx/primitives.h" #include "mlx/transforms.h" +#include "mlx/transforms_impl.h" namespace mlx::core { @@ -21,6 +22,12 @@ std::pair> cum_prod(const std::vector& shape) { return {cum_prod, strides}; } +/** Return true if we are currently performing a function transformation in + * order to keep the graph when evaluating tracer arrays. */ +bool in_tracing() { + return detail::InTracing::in_tracing(); +} + } // namespace array::array(const std::complex& val, Dtype dtype /* = complex64 */) @@ -62,8 +69,12 @@ void array::detach() { array_desc_->primitive = nullptr; } -void array::eval(bool retain_graph /* = false */) { - mlx::core::eval({*this}, retain_graph); +void array::eval() { + mlx::core::eval({*this}); +} + +bool array::is_tracer() const { + return array_desc_->is_tracer && in_tracing(); } void array::set_data(allocator::Buffer buffer, deleter_t d) { diff --git a/mlx/array.h b/mlx/array.h index 3801fa1d0..8ea971347 100644 --- a/mlx/array.h +++ b/mlx/array.h @@ -116,11 +116,11 @@ class array { }; /** Evaluate the array. */ - void eval(bool retain_graph = false); + void eval(); /** Get the value from a scalar array. 
*/ template - T item(bool retain_graph = false); + T item(); struct ArrayIterator { using iterator_category = std::random_access_iterator_tag; @@ -265,9 +265,7 @@ class array { array_desc_->is_tracer = is_tracer; } // Check if the array is a tracer array - bool is_tracer() const { - return array_desc_->is_tracer; - } + bool is_tracer() const; void set_data(allocator::Buffer buffer, deleter_t d = allocator::free); @@ -381,11 +379,11 @@ array::array( } template -T array::item(bool retain_graph /* = false */) { +T array::item() { if (size() != 1) { throw std::invalid_argument("item can only be called on arrays of size 1."); } - eval(retain_graph); + eval(); return *data(); } diff --git a/mlx/backend/metal/metal.cpp b/mlx/backend/metal/metal.cpp index 478e57c73..5c0f2d90e 100644 --- a/mlx/backend/metal/metal.cpp +++ b/mlx/backend/metal/metal.cpp @@ -46,39 +46,36 @@ MTL::CommandBuffer* increment_command_buffer(Stream s) { std::function make_task( array& arr, std::vector> deps, - std::shared_ptr> p, - bool retain_graph) { - auto task = - [retain_graph, arr, deps = std::move(deps), p = std::move(p)]() mutable { - auto pool = new_scoped_memory_pool(); - for (auto& d : deps) { - d.wait(); - } - auto s = arr.primitive().stream(); - auto command_buffer = increment_command_buffer(s); - arr.primitive().eval_gpu(arr.inputs(), arr); - if (p) { - metal::device(s.device).end_encoding(s.index); - scheduler::notify_new_task(s); - command_buffer->addCompletedHandler( - [retain_graph, s, arr, p = std::move(p)]( - MTL::CommandBuffer*) mutable { - if (!retain_graph) { - arr.detach(); - } - p->set_value(); - scheduler::notify_task_completion(s); - }); - metal::device(s.device).commit_command_buffer(s.index); - } else { - command_buffer->addCompletedHandler( - [retain_graph, s, arr](MTL::CommandBuffer*) mutable { - if (!retain_graph) { - arr.detach(); - } - }); - } - }; + std::shared_ptr> p) { + auto task = [arr, deps = std::move(deps), p = std::move(p)]() mutable { + auto pool = 
new_scoped_memory_pool(); + for (auto& d : deps) { + d.wait(); + } + auto s = arr.primitive().stream(); + auto command_buffer = increment_command_buffer(s); + arr.primitive().eval_gpu(arr.inputs(), arr); + if (p) { + metal::device(s.device).end_encoding(s.index); + scheduler::notify_new_task(s); + command_buffer->addCompletedHandler( + [s, arr, p = std::move(p)](MTL::CommandBuffer*) mutable { + if (!arr.is_tracer()) { + arr.detach(); + } + p->set_value(); + scheduler::notify_task_completion(s); + }); + metal::device(s.device).commit_command_buffer(s.index); + } else { + command_buffer->addCompletedHandler( + [s, arr](MTL::CommandBuffer*) mutable { + if (!arr.is_tracer()) { + arr.detach(); + } + }); + } + }; return task; } diff --git a/mlx/backend/metal/metal.h b/mlx/backend/metal/metal.h index 99f400956..ad12ef467 100644 --- a/mlx/backend/metal/metal.h +++ b/mlx/backend/metal/metal.h @@ -25,7 +25,6 @@ std::shared_ptr new_scoped_memory_pool(); std::function make_task( array& arr, std::vector> deps, - std::shared_ptr> p, - bool retain_graph); + std::shared_ptr> p); } // namespace mlx::core::metal diff --git a/mlx/backend/no_metal/metal.cpp b/mlx/backend/no_metal/metal.cpp index 212ca2839..0005a65a9 100644 --- a/mlx/backend/no_metal/metal.cpp +++ b/mlx/backend/no_metal/metal.cpp @@ -14,8 +14,7 @@ std::shared_ptr new_scoped_memory_pool() { std::function make_task( array& arr, std::vector> deps, - std::shared_ptr> p, - bool retain_graph) { + std::shared_ptr> p) { throw std::runtime_error( "[metal::make_task] Cannot make GPU task without metal backend"); } diff --git a/mlx/io/load.cpp b/mlx/io/load.cpp index 27c425455..fa3d13377 100644 --- a/mlx/io/load.cpp +++ b/mlx/io/load.cpp @@ -40,11 +40,11 @@ inline bool is_big_endian_() { } // namespace /** Save array to out stream in .npy format */ -void save(std::shared_ptr out_stream, array a, bool retain_graph) { +void save(std::shared_ptr out_stream, array a) { //////////////////////////////////////////////////////// // Check 
array - a.eval(retain_graph); + a.eval(); if (a.nbytes() == 0) { throw std::invalid_argument("[save] cannot serialize an empty array"); @@ -52,7 +52,7 @@ void save(std::shared_ptr out_stream, array a, bool retain_graph) { if (!(a.flags().row_contiguous || a.flags().col_contiguous)) { a = reshape(flatten(a), a.shape()); - a.eval(retain_graph); + a.eval(); } // Check once more in-case the above ops change if (!(a.flags().row_contiguous || a.flags().col_contiguous)) { @@ -127,7 +127,7 @@ void save(std::shared_ptr out_stream, array a, bool retain_graph) { } /** Save array to file in .npy format */ -void save(const std::string& file_, array a, bool retain_graph) { +void save(const std::string& file_, array a) { // Open and check file std::string file = file_; @@ -136,7 +136,7 @@ void save(const std::string& file_, array a, bool retain_graph) { file += ".npy"; // Serialize array - save(std::make_shared(file), a, retain_graph); + save(std::make_shared(file), a); } /** Load array from reader in .npy format */ diff --git a/mlx/io/load.h b/mlx/io/load.h index 1d193392a..8aa80bbb7 100644 --- a/mlx/io/load.h +++ b/mlx/io/load.h @@ -111,4 +111,4 @@ class FileWriter : public Writer { }; } // namespace io -} // namespace mlx::core \ No newline at end of file +} // namespace mlx::core diff --git a/mlx/io/safetensor.cpp b/mlx/io/safetensor.cpp index bb78be797..1ca79441d 100644 --- a/mlx/io/safetensor.cpp +++ b/mlx/io/safetensor.cpp @@ -125,8 +125,7 @@ std::unordered_map load_safetensors( /** Save array to out stream in .npy format */ void save_safetensors( std::shared_ptr out_stream, - std::unordered_map a, - std::optional retain_graph_) { + std::unordered_map a) { //////////////////////////////////////////////////////// // Check file if (!out_stream->good() || !out_stream->is_open()) { @@ -142,8 +141,7 @@ void save_safetensors( }); size_t offset = 0; for (auto& [key, arr] : a) { - auto retain = retain_graph_.value_or(arr.is_tracer()); - arr.eval(retain); + arr.eval(); if 
(arr.nbytes() == 0) { throw std::invalid_argument( "[save_safetensors] cannot serialize an empty array key: " + key); @@ -152,7 +150,7 @@ void save_safetensors( // Try to make it row contiguous if (!arr.flags().row_contiguous) { arr = reshape(flatten(arr), arr.shape()); - arr.eval(retain); + arr.eval(); } // Has to be row-major now but, check one more time in case @@ -181,8 +179,7 @@ void save_safetensors( void save_safetensors( const std::string& file_, - std::unordered_map a, - std::optional retain_graph) { + std::unordered_map a) { // Open and check file std::string file = file_; @@ -192,7 +189,7 @@ void save_safetensors( file += ".safetensors"; // Serialize array - save_safetensors(std::make_shared(file), a, retain_graph); + save_safetensors(std::make_shared(file), a); } } // namespace mlx::core diff --git a/mlx/ops.h b/mlx/ops.h index 31c2eb905..ae7776223 100644 --- a/mlx/ops.h +++ b/mlx/ops.h @@ -1021,13 +1021,10 @@ array conv2d( /** Serialization operations */ /** Save array to out stream in .npy format */ -void save( - std::shared_ptr out_stream, - array a, - bool retain_graph = true); +void save(std::shared_ptr out_stream, array a); /** Save array to file in .npy format */ -void save(const std::string& file, array a, bool retain_graph = true); +void save(const std::string& file, array a); /** Load array from reader in .npy format */ array load(std::shared_ptr in_stream, StreamOrDevice s = {}); @@ -1091,10 +1088,8 @@ std::unordered_map load_safetensors( void save_safetensors( std::shared_ptr in_stream, - std::unordered_map, - std::optional retain_graph = std::nullopt); + std::unordered_map); void save_safetensors( const std::string& file, - std::unordered_map, - std::optional retain_graph = std::nullopt); + std::unordered_map); } // namespace mlx::core diff --git a/mlx/transforms.cpp b/mlx/transforms.cpp index 174d5e374..67e4731ad 100644 --- a/mlx/transforms.cpp +++ b/mlx/transforms.cpp @@ -19,6 +19,12 @@ namespace mlx::core { +// Initialize the static 
tracing counter from transforms_impl.h. +// +// This is used to implement the in_tracing() function that returns true if we +// are currently under a function transformation. +int detail::InTracing::tracing_counter{0}; + void simplify(const std::vector& outputs) { std::function recurse; std::queue tape; @@ -154,16 +160,7 @@ void simplify(const std::vector& outputs) { } } -void eval(const std::vector& outputs, bool retain_graph /* = false */) { - if (!retain_graph) { - for (auto& out : outputs) { - if (out.has_primitive() && out.is_tracer()) { - throw std::invalid_argument( - "[eval] Illegal to eval an array during " - "function transform without graph retention."); - } - } - } +void eval(const std::vector& outputs) { std::function recurse; std::queue tape; std::unordered_set cache; @@ -185,7 +182,7 @@ void eval(const std::vector& outputs, bool retain_graph /* = false */) { } } cache.insert(id); - if (!a.is_evaled() || (!retain_graph && a.has_primitive())) { + if (!a.is_evaled() || (!a.is_tracer() && a.has_primitive())) { if (!a.has_primitive()) { throw std::invalid_argument( "[eval] Attempting to eval an array without a primitive."); @@ -195,7 +192,7 @@ void eval(const std::vector& outputs, bool retain_graph /* = false */) { }; for (auto& arr : outputs) { - if (!arr.is_evaled() || (!retain_graph && arr.has_primitive())) { + if (!arr.is_evaled() || (!arr.is_tracer() && arr.has_primitive())) { recurse(arr); // Insert a dependency for every output to synchronize // with at the end. 
@@ -209,7 +206,7 @@ void eval(const std::vector& outputs, bool retain_graph /* = false */) { auto arr = std::move(tape.front()); tape.pop(); if (arr.is_evaled()) { - if (!retain_graph && arr.has_primitive()) { + if (!arr.is_tracer() && arr.has_primitive()) { arr.detach(); } continue; @@ -233,12 +230,9 @@ void eval(const std::vector& outputs, bool retain_graph /* = false */) { throw std::runtime_error("Metal GPU is not available."); } scheduler::enqueue( - stream, - metal::make_task( - arr, std::move(arr_deps), std::move(p), retain_graph)); + stream, metal::make_task(arr, std::move(arr_deps), std::move(p))); } else { - auto task = [retain_graph, - arr, + auto task = [arr, stream, arr_deps = std::move(arr_deps), p = std::move(p)]() mutable { @@ -247,7 +241,7 @@ void eval(const std::vector& outputs, bool retain_graph /* = false */) { } scheduler::notify_new_task(stream); arr.primitive().eval_cpu(arr.inputs(), arr); - if (!retain_graph) { + if (!arr.is_tracer()) { arr.detach(); } if (p) { @@ -269,6 +263,9 @@ std::pair, std::vector> vjp( const std::function(const std::vector&)>& fun, const std::vector& primals, const std::vector& cotans) { + // Set the global tracing flag. + detail::InTracing in_tracing; + // Make tracers from given primals std::vector primals_; for (auto& p : primals) { @@ -425,6 +422,9 @@ std::pair, std::vector> jvp( } } + // Set the global tracing flag. + detail::InTracing in_tracing; + std::vector primals_; for (auto& p : primals) { auto s = p.has_primitive() ? 
p.primitive().stream() @@ -578,6 +578,9 @@ std::pair, std::vector> vmap_trace( const std::function(const std::vector&)>& fun, const std::vector& inputs, const std::vector& in_axes) { + // Set the global tracing flag. + InTracing in_tracing; + if (in_axes.size() != inputs.size()) { throw std::invalid_argument( "[vmap] The number of in axes must match the number of inputs."); diff --git a/mlx/transforms.h b/mlx/transforms.h index caf648163..813c5f7fd 100644 --- a/mlx/transforms.h +++ b/mlx/transforms.h @@ -14,11 +14,11 @@ void simplify(Arrays... outputs) { simplify(std::vector{std::forward(outputs)...}); } -void eval(const std::vector& outputs, bool retain_graph = false); +void eval(const std::vector& outputs); template void eval(Arrays... outputs) { - eval(std::vector{std::forward(outputs)...}, false); + eval(std::vector{std::forward(outputs)...}); } /** diff --git a/mlx/transforms_impl.h b/mlx/transforms_impl.h index 201e8009b..4b464bafd 100644 --- a/mlx/transforms_impl.h +++ b/mlx/transforms_impl.h @@ -14,4 +14,23 @@ std::vector vmap_replace( const std::vector& in_axes, const std::vector& out_axes); +// Create an InTracing object during tracing operations to signify to the rest +// of the codebase that we are tracing, so evals should not throw away +// the graph. 
+struct InTracing { + InTracing() { + tracing_counter++; + } + ~InTracing() { + tracing_counter--; + } + + static bool in_tracing() { + return tracing_counter > 0; + } + + private: + static int tracing_counter; +}; + } // namespace mlx::core::detail diff --git a/python/src/array.cpp b/python/src/array.cpp index 6ca0ab0ca..dddbcce29 100644 --- a/python/src/array.cpp +++ b/python/src/array.cpp @@ -39,34 +39,33 @@ py::list to_list(array& a, size_t index, int dim) { } auto to_scalar(array& a) { - bool retain_graph = a.is_tracer(); switch (a.dtype()) { case bool_: - return py::cast(a.item(retain_graph)); + return py::cast(a.item()); case uint8: - return py::cast(a.item(retain_graph)); + return py::cast(a.item()); case uint16: - return py::cast(a.item(retain_graph)); + return py::cast(a.item()); case uint32: - return py::cast(a.item(retain_graph)); + return py::cast(a.item()); case uint64: - return py::cast(a.item(retain_graph)); + return py::cast(a.item()); case int8: - return py::cast(a.item(retain_graph)); + return py::cast(a.item()); case int16: - return py::cast(a.item(retain_graph)); + return py::cast(a.item()); case int32: - return py::cast(a.item(retain_graph)); + return py::cast(a.item()); case int64: - return py::cast(a.item(retain_graph)); + return py::cast(a.item()); case float16: - return py::cast(static_cast(a.item(retain_graph))); + return py::cast(static_cast(a.item())); case float32: - return py::cast(a.item(retain_graph)); + return py::cast(a.item()); case bfloat16: - return py::cast(static_cast(a.item(retain_graph))); + return py::cast(static_cast(a.item())); case complex64: - return py::cast(a.item>(retain_graph)); + return py::cast(a.item>()); } } @@ -74,7 +73,7 @@ py::object tolist(array& a) { if (a.ndim() == 0) { return to_scalar(a); } - a.eval(a.is_tracer()); + a.eval(); py::object pl; switch (a.dtype()) { case bool_: @@ -527,7 +526,7 @@ void init_array(py::module_& m) { .def_buffer([](array& a) { // Eval if not already evaled if (!a.is_evaled()) 
{ - eval({a}, a.is_tracer()); + a.eval(); } return pybind11::buffer_info( a.data(), @@ -751,7 +750,7 @@ void init_array(py::module_& m) { "__repr__", [](array& a) { if (!a.is_evaled()) { - a.eval(a.is_tracer()); + a.eval(); } std::ostringstream os; os << a; diff --git a/python/src/load.cpp b/python/src/load.cpp index ccd5b221b..fcc1cc722 100644 --- a/python/src/load.cpp +++ b/python/src/load.cpp @@ -345,19 +345,15 @@ class PyFileWriter : public io::Writer { py::object tell_func_; }; -void mlx_save_helper( - py::object file, - array a, - std::optional retain_graph_) { - bool retain_graph = retain_graph_.value_or(a.is_tracer()); +void mlx_save_helper(py::object file, array a) { if (py::isinstance(file)) { - save(py::cast(file), a, retain_graph); + save(py::cast(file), a); return; } else if (is_ostream_object(file)) { auto writer = std::make_shared(file); { py::gil_scoped_release gil; - save(writer, a, retain_graph); + save(writer, a); } return; @@ -414,26 +410,23 @@ void mlx_savez_helper( auto writer = std::make_shared(py_ostream); { py::gil_scoped_release gil; - save(writer, a, /*retain_graph=*/a.is_tracer()); + save(writer, a); } } return; } -void mlx_save_safetensor_helper( - py::object file, - py::dict d, - std::optional retain_graph) { +void mlx_save_safetensor_helper(py::object file, py::dict d) { auto arrays_map = d.cast>(); if (py::isinstance(file)) { - save_safetensors(py::cast(file), arrays_map, retain_graph); + save_safetensors(py::cast(file), arrays_map); return; } else if (is_ostream_object(file)) { auto writer = std::make_shared(file); { py::gil_scoped_release gil; - save_safetensors(writer, arrays_map, retain_graph); + save_safetensors(writer, arrays_map); } return; diff --git a/python/src/load.h b/python/src/load.h index 4dc6fcda7..d1d8bd59c 100644 --- a/python/src/load.h +++ b/python/src/load.h @@ -17,19 +17,13 @@ using DictOrArray = std::variant>; std::unordered_map mlx_load_safetensor_helper( py::object file, StreamOrDevice s); -void 
mlx_save_safetensor_helper( - py::object file, - py::dict d, - std::optional retain_graph = std::nullopt); +void mlx_save_safetensor_helper(py::object file, py::dict d); DictOrArray mlx_load_helper( py::object file, std::optional format, StreamOrDevice s); -void mlx_save_helper( - py::object file, - array a, - std::optional retain_graph = std::nullopt); +void mlx_save_helper(py::object file, array a); void mlx_savez_helper( py::object file, py::args args, diff --git a/python/src/ops.cpp b/python/src/ops.cpp index f41049b82..50bee45a6 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -2902,20 +2902,14 @@ void init_ops(py::module_& m) { &mlx_save_helper, "file"_a, "arr"_a, - py::pos_only(), - "retain_graph"_a = std::nullopt, - py::kw_only(), R"pbdoc( - save(file: str, arr: array, / , retain_graph: Optional[bool] = None) + save(file: str, arr: array) Save the array to a binary file in ``.npy`` format. Args: file (str): File to which the array is saved arr (array): Array to be saved. - retain_graph (bool, optional): Whether or not to retain the graph - during array evaluation. If left unspecified the graph is retained - only if saving is done in a function transformation. Default: ``None`` )pbdoc"); m.def( "savez", @@ -2999,11 +2993,8 @@ void init_ops(py::module_& m) { &mlx_save_safetensor_helper, "file"_a, "arrays"_a, - py::pos_only(), - "retain_graph"_a = std::nullopt, - py::kw_only(), R"pbdoc( - save_safetensors(file: str, arrays: Dict[str, array], /, retain_graph: Optional[bool] = None) + save_safetensors(file: str, arrays: Dict[str, array]) Save array(s) to a binary file in ``.safetensors`` format. @@ -3012,9 +3003,6 @@ void init_ops(py::module_& m) { Args: file (file, str): File in which the array is saved> arrays (dict(str, array)): The dictionary of names to arrays to be saved. - retain_graph (bool, optional): Whether or not to retain the graph - during array evaluation. 
If left unspecified the graph is retained - only if saving is done in a function transformation. Default: ``None``. )pbdoc"); m.def( "where", diff --git a/python/src/transforms.cpp b/python/src/transforms.cpp index 096d5a486..a1afebef9 100644 --- a/python/src/transforms.cpp +++ b/python/src/transforms.cpp @@ -440,11 +440,10 @@ auto py_vmap( void init_transforms(py::module_& m) { m.def( "eval", - [](const py::args& args, bool retain_graph) { + [](const py::args& args) { std::vector arrays = tree_flatten(args); - eval(arrays, retain_graph); + eval(arrays); }, - "retain_graph"_a = false, R"pbdoc( Evaluate an :class:`array` or tree of :class:`array`. @@ -453,9 +452,6 @@ void init_transforms(py::module_& m) { or a tree of arrays. If a tree is given the nodes can be a Python :class:`list`, :class:`tuple` or :class:`dict` but the leafs must all be an :class:`array`. - retain_graph (bool): Indicate that the graph structure should be - preserved. This option is intended to enable function transforms - which contain control flow based on the value of an array. 
)pbdoc"); m.def( "jvp", diff --git a/python/tests/test_autograd.py b/python/tests/test_autograd.py index 72cc0dc55..69d8baeea 100644 --- a/python/tests/test_autograd.py +++ b/python/tests/test_autograd.py @@ -259,6 +259,21 @@ class TestAutograd(mlx_tests.MLXTestCase): self.assertTrue(mx.allclose(vjps[0], mx.zeros(shape_in))) + def test_update_state(self): + y = mx.array([1.0]) + state = mx.zeros((2,)) + + def fn(y, x): + nonlocal state + x = y * x + state = state + x + return x.sum() + + x = mx.ones((2,)) + mx.grad(fn)(y, x) + mx.eval(state) + self.assertTrue(mx.allclose(state, mx.ones((2,)))) + if __name__ == "__main__": unittest.main() diff --git a/python/tests/test_eval.py b/python/tests/test_eval.py index b5145efd0..6619afa67 100644 --- a/python/tests/test_eval.py +++ b/python/tests/test_eval.py @@ -15,18 +15,13 @@ class TestEval(mlx_tests.MLXTestCase): self.assertEqual(x.tolist(), [[1, 1], [1, 1]]) def test_retain_graph(self): - def fun(x, retain_graph): + def fun(x): y = 3 * x - mx.eval(y, retain_graph=retain_graph) + mx.eval(y) return 2 * y - dfun_dx_1 = mx.grad(partial(fun, retain_graph=False)) - dfun_dx_2 = mx.grad(partial(fun, retain_graph=True)) - - with self.assertRaises(ValueError): - dfun_dx_1(mx.array(1.0)) - - y = dfun_dx_2(mx.array(1.0)) + dfun_dx = mx.grad(fun) + y = dfun_dx(mx.array(1.0)) self.assertEqual(y.item(), 6.0) diff --git a/tests/autograd_tests.cpp b/tests/autograd_tests.cpp index a7b7e7fca..554726363 100644 --- a/tests/autograd_tests.cpp +++ b/tests/autograd_tests.cpp @@ -95,19 +95,14 @@ TEST_CASE("test jvp") { CHECK_EQ(dout[0].item(), 4.0f); } - // Evaling in function without graph retention throws + // Evaling in function while tracing performs graph retention { - auto fun = [](const array& x) { - auto y = 3 * x; - eval(y); - return 2 * y; - }; - CHECK_THROWS(jvp(fun, array(1.0f), array(1.0f))); - - // Ok with graph retention auto fun1 = [](const array& x) { auto y = 3 * x; - eval({y}, true); + eval(y); + CHECK(y.is_evaled()); + 
CHECK(y.has_primitive()); + CHECK(y.is_tracer()); return 2 * y; }; CHECK_EQ(jvp(fun1, array(1.0f), array(1.0f)).second.item(), 6.0f); @@ -251,29 +246,27 @@ TEST_CASE("test grad") { } { - // Evaluating in the middle of the grad function throws - // as it breaks the graph - auto fn = [](array x) { - x = x + 2.0f; - eval(x); - return square(x); - }; - CHECK_THROWS(grad(fn)(array(1.0f))); - - // Ok since the output is independent of y + // No graph retention since the output is independent of y auto y = ones({3, 3}); auto fn1 = [y](array x) { x = x + 2.0f; eval(y); + CHECK(x.is_tracer()); + CHECK(!y.is_tracer()); + CHECK(y.is_evaled()); + CHECK(!y.has_primitive()); return square(x); }; auto dfdx = grad(fn1)(array(1.0f)); CHECK_EQ(dfdx.item(), 6.0f); - // Retain the graph to avoid breaking it + // Graph automatically retained to compute the grad auto fn2 = [](array x) { x = x + 2.0f; - eval({x}, true); + eval(x); + CHECK(x.is_tracer()); + CHECK(x.is_evaled()); + CHECK(x.has_primitive()); return square(x); }; dfdx = grad(fn2)(array(1.0f)); @@ -283,7 +276,8 @@ TEST_CASE("test grad") { // Control flow in grad computation { auto fn = [](array x) { - if (x.item(true) > 1) { + x = x + array(2.0f); + if (x.item() > 3) { return square(x); } else { return 4 * x; @@ -294,7 +288,7 @@ TEST_CASE("test grad") { CHECK_EQ(dfdx.item(), 4.0f); dfdx = grad(fn)(array(1.5f)); - CHECK_EQ(dfdx.item(), 3.0f); + CHECK_EQ(dfdx.item(), 7.0f); } // Grad with multiple inputs @@ -1192,3 +1186,19 @@ TEST_CASE("test scan grads") { CHECK(array_equal(out, expected).item()); } } + +TEST_CASE("test update state") { + auto y = array({1.0}); + auto x = array({1.0, 1.0}); + auto state = array({0.0, 0.0}); + auto fn = [&state, &x](array y) { + x = y * x; + state = state + x; + return sum(x); + }; + grad(fn)(y); + eval(state); + CHECK(!state.has_primitive()); + CHECK(state.is_evaled()); + CHECK(array_equal(state, array({1.0, 1.0})).item()); +} diff --git a/tests/eval_tests.cpp b/tests/eval_tests.cpp index 
dc2e96bbb..1c0ba857f 100644 --- a/tests/eval_tests.cpp +++ b/tests/eval_tests.cpp @@ -48,36 +48,36 @@ TEST_CASE("test eval multiple") { CHECK(array_equal(b, full({10, 10}, 0.0f)).item()); } -TEST_CASE("test eval with tracer") { +TEST_CASE("test eval with tracer when not tracing") { + // Since we are not tracing, it doesn't matter that the arrays are flagged as + // tracers; they will always be detached. auto x = array(1); x.set_tracer(true); - - // Ok, x is not a node + CHECK(!x.is_tracer()); eval(x); + CHECK(!x.has_primitive()); + CHECK(x.is_evaled()); x = ones({2, 3}); x.set_tracer(true); - CHECK_THROWS(eval(x)); - - // Ok retain_graph=true - eval({x}, true); - - // Make sure all arguments are checked - auto y = ones({2, 3}); - CHECK_THROWS(eval(x, y)); + eval(x); + CHECK(!x.has_primitive()); + CHECK(x.is_evaled()); } -TEST_CASE("test eval graph retention") { +TEST_CASE("test eval graph retention when not tracing") { + // Since we are not tracing, it doesn't matter that the arrays are flagged as + // tracers; they will always be detached. 
auto x = array(1); + x.set_tracer(true); auto y = array(2); auto z = x + y; - eval({z}, true); - CHECK(z.has_primitive()); - CHECK(z.is_evaled()); - CHECK_EQ(z.item(true), 3); - CHECK(z.has_primitive()); + eval(z); + CHECK(!z.has_primitive()); CHECK(z.is_evaled()); + CHECK_EQ(z.item(), 3); + z.set_tracer(false); CHECK_EQ(z.item(), 3); CHECK(!z.has_primitive()); CHECK(z.is_evaled()); @@ -85,13 +85,7 @@ TEST_CASE("test eval graph retention") { z = x + y; auto a = z + x; auto b = a + y; - eval({b}, true); - CHECK(z.has_primitive()); - CHECK(z.is_evaled()); - CHECK(a.has_primitive()); - CHECK(a.is_evaled()); - - eval({b}, false); + eval(b); CHECK(!z.has_primitive()); CHECK(z.is_evaled()); CHECK(!a.has_primitive()); diff --git a/tests/vmap_tests.cpp b/tests/vmap_tests.cpp index b30fd0a21..6eba9ef03 100644 --- a/tests/vmap_tests.cpp +++ b/tests/vmap_tests.cpp @@ -183,7 +183,7 @@ TEST_CASE("test vmap with eval") { auto fun2 = [](std::vector inputs) { auto x = inputs[0] + 1; auto y = inputs[1] + 2; - eval({x}, true); + eval(x); auto out = add(x, y); return std::vector{out}; };