Removes the retain_graph flag (#385)

* Adds global tracing flag
* Removes retain_graph in favor of is_tracer
This commit is contained in:
Angelos Katharopoulos 2024-01-07 15:16:51 -08:00 committed by GitHub
parent 449b43762e
commit a611b0bc82
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
22 changed files with 209 additions and 207 deletions

View File

@ -6,6 +6,7 @@
#include "mlx/ops.h" #include "mlx/ops.h"
#include "mlx/primitives.h" #include "mlx/primitives.h"
#include "mlx/transforms.h" #include "mlx/transforms.h"
#include "mlx/transforms_impl.h"
namespace mlx::core { namespace mlx::core {
@ -21,6 +22,12 @@ std::pair<size_t, std::vector<size_t>> cum_prod(const std::vector<int>& shape) {
return {cum_prod, strides}; return {cum_prod, strides};
} }
/** Return true if we are currently performing a function transformation in
* order to keep the graph when evaluating tracer arrays. */
bool in_tracing() {
return detail::InTracing::in_tracing();
}
} // namespace } // namespace
array::array(const std::complex<float>& val, Dtype dtype /* = complex64 */) array::array(const std::complex<float>& val, Dtype dtype /* = complex64 */)
@ -62,8 +69,12 @@ void array::detach() {
array_desc_->primitive = nullptr; array_desc_->primitive = nullptr;
} }
void array::eval(bool retain_graph /* = false */) { void array::eval() {
mlx::core::eval({*this}, retain_graph); mlx::core::eval({*this});
}
bool array::is_tracer() const {
return array_desc_->is_tracer && in_tracing();
} }
void array::set_data(allocator::Buffer buffer, deleter_t d) { void array::set_data(allocator::Buffer buffer, deleter_t d) {

View File

@ -116,11 +116,11 @@ class array {
}; };
/** Evaluate the array. */ /** Evaluate the array. */
void eval(bool retain_graph = false); void eval();
/** Get the value from a scalar array. */ /** Get the value from a scalar array. */
template <typename T> template <typename T>
T item(bool retain_graph = false); T item();
struct ArrayIterator { struct ArrayIterator {
using iterator_category = std::random_access_iterator_tag; using iterator_category = std::random_access_iterator_tag;
@ -265,9 +265,7 @@ class array {
array_desc_->is_tracer = is_tracer; array_desc_->is_tracer = is_tracer;
} }
// Check if the array is a tracer array // Check if the array is a tracer array
bool is_tracer() const { bool is_tracer() const;
return array_desc_->is_tracer;
}
void set_data(allocator::Buffer buffer, deleter_t d = allocator::free); void set_data(allocator::Buffer buffer, deleter_t d = allocator::free);
@ -381,11 +379,11 @@ array::array(
} }
template <typename T> template <typename T>
T array::item(bool retain_graph /* = false */) { T array::item() {
if (size() != 1) { if (size() != 1) {
throw std::invalid_argument("item can only be called on arrays of size 1."); throw std::invalid_argument("item can only be called on arrays of size 1.");
} }
eval(retain_graph); eval();
return *data<T>(); return *data<T>();
} }

View File

@ -46,10 +46,8 @@ MTL::CommandBuffer* increment_command_buffer(Stream s) {
std::function<void()> make_task( std::function<void()> make_task(
array& arr, array& arr,
std::vector<std::shared_future<void>> deps, std::vector<std::shared_future<void>> deps,
std::shared_ptr<std::promise<void>> p, std::shared_ptr<std::promise<void>> p) {
bool retain_graph) { auto task = [arr, deps = std::move(deps), p = std::move(p)]() mutable {
auto task =
[retain_graph, arr, deps = std::move(deps), p = std::move(p)]() mutable {
auto pool = new_scoped_memory_pool(); auto pool = new_scoped_memory_pool();
for (auto& d : deps) { for (auto& d : deps) {
d.wait(); d.wait();
@ -61,9 +59,8 @@ std::function<void()> make_task(
metal::device(s.device).end_encoding(s.index); metal::device(s.device).end_encoding(s.index);
scheduler::notify_new_task(s); scheduler::notify_new_task(s);
command_buffer->addCompletedHandler( command_buffer->addCompletedHandler(
[retain_graph, s, arr, p = std::move(p)]( [s, arr, p = std::move(p)](MTL::CommandBuffer*) mutable {
MTL::CommandBuffer*) mutable { if (!arr.is_tracer()) {
if (!retain_graph) {
arr.detach(); arr.detach();
} }
p->set_value(); p->set_value();
@ -72,8 +69,8 @@ std::function<void()> make_task(
metal::device(s.device).commit_command_buffer(s.index); metal::device(s.device).commit_command_buffer(s.index);
} else { } else {
command_buffer->addCompletedHandler( command_buffer->addCompletedHandler(
[retain_graph, s, arr](MTL::CommandBuffer*) mutable { [s, arr](MTL::CommandBuffer*) mutable {
if (!retain_graph) { if (!arr.is_tracer()) {
arr.detach(); arr.detach();
} }
}); });

View File

@ -25,7 +25,6 @@ std::shared_ptr<void> new_scoped_memory_pool();
std::function<void()> make_task( std::function<void()> make_task(
array& arr, array& arr,
std::vector<std::shared_future<void>> deps, std::vector<std::shared_future<void>> deps,
std::shared_ptr<std::promise<void>> p, std::shared_ptr<std::promise<void>> p);
bool retain_graph);
} // namespace mlx::core::metal } // namespace mlx::core::metal

View File

@ -14,8 +14,7 @@ std::shared_ptr<void> new_scoped_memory_pool() {
std::function<void()> make_task( std::function<void()> make_task(
array& arr, array& arr,
std::vector<std::shared_future<void>> deps, std::vector<std::shared_future<void>> deps,
std::shared_ptr<std::promise<void>> p, std::shared_ptr<std::promise<void>> p) {
bool retain_graph) {
throw std::runtime_error( throw std::runtime_error(
"[metal::make_task] Cannot make GPU task without metal backend"); "[metal::make_task] Cannot make GPU task without metal backend");
} }

View File

@ -40,11 +40,11 @@ inline bool is_big_endian_() {
} // namespace } // namespace
/** Save array to out stream in .npy format */ /** Save array to out stream in .npy format */
void save(std::shared_ptr<io::Writer> out_stream, array a, bool retain_graph) { void save(std::shared_ptr<io::Writer> out_stream, array a) {
//////////////////////////////////////////////////////// ////////////////////////////////////////////////////////
// Check array // Check array
a.eval(retain_graph); a.eval();
if (a.nbytes() == 0) { if (a.nbytes() == 0) {
throw std::invalid_argument("[save] cannot serialize an empty array"); throw std::invalid_argument("[save] cannot serialize an empty array");
@ -52,7 +52,7 @@ void save(std::shared_ptr<io::Writer> out_stream, array a, bool retain_graph) {
if (!(a.flags().row_contiguous || a.flags().col_contiguous)) { if (!(a.flags().row_contiguous || a.flags().col_contiguous)) {
a = reshape(flatten(a), a.shape()); a = reshape(flatten(a), a.shape());
a.eval(retain_graph); a.eval();
} }
// Check once more in-case the above ops change // Check once more in-case the above ops change
if (!(a.flags().row_contiguous || a.flags().col_contiguous)) { if (!(a.flags().row_contiguous || a.flags().col_contiguous)) {
@ -127,7 +127,7 @@ void save(std::shared_ptr<io::Writer> out_stream, array a, bool retain_graph) {
} }
/** Save array to file in .npy format */ /** Save array to file in .npy format */
void save(const std::string& file_, array a, bool retain_graph) { void save(const std::string& file_, array a) {
// Open and check file // Open and check file
std::string file = file_; std::string file = file_;
@ -136,7 +136,7 @@ void save(const std::string& file_, array a, bool retain_graph) {
file += ".npy"; file += ".npy";
// Serialize array // Serialize array
save(std::make_shared<io::FileWriter>(file), a, retain_graph); save(std::make_shared<io::FileWriter>(file), a);
} }
/** Load array from reader in .npy format */ /** Load array from reader in .npy format */

View File

@ -125,8 +125,7 @@ std::unordered_map<std::string, array> load_safetensors(
/** Save array to out stream in .npy format */ /** Save array to out stream in .npy format */
void save_safetensors( void save_safetensors(
std::shared_ptr<io::Writer> out_stream, std::shared_ptr<io::Writer> out_stream,
std::unordered_map<std::string, array> a, std::unordered_map<std::string, array> a) {
std::optional<bool> retain_graph_) {
//////////////////////////////////////////////////////// ////////////////////////////////////////////////////////
// Check file // Check file
if (!out_stream->good() || !out_stream->is_open()) { if (!out_stream->good() || !out_stream->is_open()) {
@ -142,8 +141,7 @@ void save_safetensors(
}); });
size_t offset = 0; size_t offset = 0;
for (auto& [key, arr] : a) { for (auto& [key, arr] : a) {
auto retain = retain_graph_.value_or(arr.is_tracer()); arr.eval();
arr.eval(retain);
if (arr.nbytes() == 0) { if (arr.nbytes() == 0) {
throw std::invalid_argument( throw std::invalid_argument(
"[save_safetensors] cannot serialize an empty array key: " + key); "[save_safetensors] cannot serialize an empty array key: " + key);
@ -152,7 +150,7 @@ void save_safetensors(
// Try to make it row contiguous // Try to make it row contiguous
if (!arr.flags().row_contiguous) { if (!arr.flags().row_contiguous) {
arr = reshape(flatten(arr), arr.shape()); arr = reshape(flatten(arr), arr.shape());
arr.eval(retain); arr.eval();
} }
// Has to be row-major now but, check one more time in case // Has to be row-major now but, check one more time in case
@ -181,8 +179,7 @@ void save_safetensors(
void save_safetensors( void save_safetensors(
const std::string& file_, const std::string& file_,
std::unordered_map<std::string, array> a, std::unordered_map<std::string, array> a) {
std::optional<bool> retain_graph) {
// Open and check file // Open and check file
std::string file = file_; std::string file = file_;
@ -192,7 +189,7 @@ void save_safetensors(
file += ".safetensors"; file += ".safetensors";
// Serialize array // Serialize array
save_safetensors(std::make_shared<io::FileWriter>(file), a, retain_graph); save_safetensors(std::make_shared<io::FileWriter>(file), a);
} }
} // namespace mlx::core } // namespace mlx::core

View File

@ -1021,13 +1021,10 @@ array conv2d(
/** Serialization operations */ /** Serialization operations */
/** Save array to out stream in .npy format */ /** Save array to out stream in .npy format */
void save( void save(std::shared_ptr<io::Writer> out_stream, array a);
std::shared_ptr<io::Writer> out_stream,
array a,
bool retain_graph = true);
/** Save array to file in .npy format */ /** Save array to file in .npy format */
void save(const std::string& file, array a, bool retain_graph = true); void save(const std::string& file, array a);
/** Load array from reader in .npy format */ /** Load array from reader in .npy format */
array load(std::shared_ptr<io::Reader> in_stream, StreamOrDevice s = {}); array load(std::shared_ptr<io::Reader> in_stream, StreamOrDevice s = {});
@ -1091,10 +1088,8 @@ std::unordered_map<std::string, array> load_safetensors(
void save_safetensors( void save_safetensors(
std::shared_ptr<io::Writer> in_stream, std::shared_ptr<io::Writer> in_stream,
std::unordered_map<std::string, array>, std::unordered_map<std::string, array>);
std::optional<bool> retain_graph = std::nullopt);
void save_safetensors( void save_safetensors(
const std::string& file, const std::string& file,
std::unordered_map<std::string, array>, std::unordered_map<std::string, array>);
std::optional<bool> retain_graph = std::nullopt);
} // namespace mlx::core } // namespace mlx::core

View File

@ -19,6 +19,12 @@
namespace mlx::core { namespace mlx::core {
// Initialize the static tracing counter from transforms_impl.h .
//
// This is used to implement the in_tracing() function the returns true if we
// are currently under a function transformation.
int detail::InTracing::tracing_counter{0};
void simplify(const std::vector<array>& outputs) { void simplify(const std::vector<array>& outputs) {
std::function<void(const array&)> recurse; std::function<void(const array&)> recurse;
std::queue<array> tape; std::queue<array> tape;
@ -154,16 +160,7 @@ void simplify(const std::vector<array>& outputs) {
} }
} }
void eval(const std::vector<array>& outputs, bool retain_graph /* = false */) { void eval(const std::vector<array>& outputs) {
if (!retain_graph) {
for (auto& out : outputs) {
if (out.has_primitive() && out.is_tracer()) {
throw std::invalid_argument(
"[eval] Illegal to eval an array during "
"function transform without graph retention.");
}
}
}
std::function<void(const array&)> recurse; std::function<void(const array&)> recurse;
std::queue<array> tape; std::queue<array> tape;
std::unordered_set<std::uintptr_t> cache; std::unordered_set<std::uintptr_t> cache;
@ -185,7 +182,7 @@ void eval(const std::vector<array>& outputs, bool retain_graph /* = false */) {
} }
} }
cache.insert(id); cache.insert(id);
if (!a.is_evaled() || (!retain_graph && a.has_primitive())) { if (!a.is_evaled() || (!a.is_tracer() && a.has_primitive())) {
if (!a.has_primitive()) { if (!a.has_primitive()) {
throw std::invalid_argument( throw std::invalid_argument(
"[eval] Attempting to eval an array without a primitive."); "[eval] Attempting to eval an array without a primitive.");
@ -195,7 +192,7 @@ void eval(const std::vector<array>& outputs, bool retain_graph /* = false */) {
}; };
for (auto& arr : outputs) { for (auto& arr : outputs) {
if (!arr.is_evaled() || (!retain_graph && arr.has_primitive())) { if (!arr.is_evaled() || (!arr.is_tracer() && arr.has_primitive())) {
recurse(arr); recurse(arr);
// Insert a dependency for every output to synchronize // Insert a dependency for every output to synchronize
// with at the end. // with at the end.
@ -209,7 +206,7 @@ void eval(const std::vector<array>& outputs, bool retain_graph /* = false */) {
auto arr = std::move(tape.front()); auto arr = std::move(tape.front());
tape.pop(); tape.pop();
if (arr.is_evaled()) { if (arr.is_evaled()) {
if (!retain_graph && arr.has_primitive()) { if (!arr.is_tracer() && arr.has_primitive()) {
arr.detach(); arr.detach();
} }
continue; continue;
@ -233,12 +230,9 @@ void eval(const std::vector<array>& outputs, bool retain_graph /* = false */) {
throw std::runtime_error("Metal GPU is not available."); throw std::runtime_error("Metal GPU is not available.");
} }
scheduler::enqueue( scheduler::enqueue(
stream, stream, metal::make_task(arr, std::move(arr_deps), std::move(p)));
metal::make_task(
arr, std::move(arr_deps), std::move(p), retain_graph));
} else { } else {
auto task = [retain_graph, auto task = [arr,
arr,
stream, stream,
arr_deps = std::move(arr_deps), arr_deps = std::move(arr_deps),
p = std::move(p)]() mutable { p = std::move(p)]() mutable {
@ -247,7 +241,7 @@ void eval(const std::vector<array>& outputs, bool retain_graph /* = false */) {
} }
scheduler::notify_new_task(stream); scheduler::notify_new_task(stream);
arr.primitive().eval_cpu(arr.inputs(), arr); arr.primitive().eval_cpu(arr.inputs(), arr);
if (!retain_graph) { if (!arr.is_tracer()) {
arr.detach(); arr.detach();
} }
if (p) { if (p) {
@ -269,6 +263,9 @@ std::pair<std::vector<array>, std::vector<array>> vjp(
const std::function<std::vector<array>(const std::vector<array>&)>& fun, const std::function<std::vector<array>(const std::vector<array>&)>& fun,
const std::vector<array>& primals, const std::vector<array>& primals,
const std::vector<array>& cotans) { const std::vector<array>& cotans) {
// Set the global tracing flag.
detail::InTracing in_tracing;
// Make tracers from given primals // Make tracers from given primals
std::vector<array> primals_; std::vector<array> primals_;
for (auto& p : primals) { for (auto& p : primals) {
@ -425,6 +422,9 @@ std::pair<std::vector<array>, std::vector<array>> jvp(
} }
} }
// Set the global tracing flag.
detail::InTracing in_tracing;
std::vector<array> primals_; std::vector<array> primals_;
for (auto& p : primals) { for (auto& p : primals) {
auto s = p.has_primitive() ? p.primitive().stream() auto s = p.has_primitive() ? p.primitive().stream()
@ -578,6 +578,9 @@ std::pair<std::vector<array>, std::vector<array>> vmap_trace(
const std::function<std::vector<array>(const std::vector<array>&)>& fun, const std::function<std::vector<array>(const std::vector<array>&)>& fun,
const std::vector<array>& inputs, const std::vector<array>& inputs,
const std::vector<int>& in_axes) { const std::vector<int>& in_axes) {
// Set the global tracing flag
InTracing in_tracing;
if (in_axes.size() != inputs.size()) { if (in_axes.size() != inputs.size()) {
throw std::invalid_argument( throw std::invalid_argument(
"[vmap] The number of in axes must match the number of inputs."); "[vmap] The number of in axes must match the number of inputs.");

View File

@ -14,11 +14,11 @@ void simplify(Arrays... outputs) {
simplify(std::vector<array>{std::forward<Arrays>(outputs)...}); simplify(std::vector<array>{std::forward<Arrays>(outputs)...});
} }
void eval(const std::vector<array>& outputs, bool retain_graph = false); void eval(const std::vector<array>& outputs);
template <typename... Arrays> template <typename... Arrays>
void eval(Arrays... outputs) { void eval(Arrays... outputs) {
eval(std::vector<array>{std::forward<Arrays>(outputs)...}, false); eval(std::vector<array>{std::forward<Arrays>(outputs)...});
} }
/** /**

View File

@ -14,4 +14,23 @@ std::vector<array> vmap_replace(
const std::vector<int>& in_axes, const std::vector<int>& in_axes,
const std::vector<int>& out_axes); const std::vector<int>& out_axes);
// Create an InTracing object during tracing operations to signify to the rest
// of the codebase that we are during tracing so evals should not throw away
// the graph.
struct InTracing {
InTracing() {
tracing_counter++;
}
~InTracing() {
tracing_counter--;
}
static bool in_tracing() {
return tracing_counter > 0;
}
private:
static int tracing_counter;
};
} // namespace mlx::core::detail } // namespace mlx::core::detail

View File

@ -39,34 +39,33 @@ py::list to_list(array& a, size_t index, int dim) {
} }
auto to_scalar(array& a) { auto to_scalar(array& a) {
bool retain_graph = a.is_tracer();
switch (a.dtype()) { switch (a.dtype()) {
case bool_: case bool_:
return py::cast(a.item<bool>(retain_graph)); return py::cast(a.item<bool>());
case uint8: case uint8:
return py::cast(a.item<uint8_t>(retain_graph)); return py::cast(a.item<uint8_t>());
case uint16: case uint16:
return py::cast(a.item<uint16_t>(retain_graph)); return py::cast(a.item<uint16_t>());
case uint32: case uint32:
return py::cast(a.item<uint32_t>(retain_graph)); return py::cast(a.item<uint32_t>());
case uint64: case uint64:
return py::cast(a.item<uint64_t>(retain_graph)); return py::cast(a.item<uint64_t>());
case int8: case int8:
return py::cast(a.item<int8_t>(retain_graph)); return py::cast(a.item<int8_t>());
case int16: case int16:
return py::cast(a.item<int16_t>(retain_graph)); return py::cast(a.item<int16_t>());
case int32: case int32:
return py::cast(a.item<int32_t>(retain_graph)); return py::cast(a.item<int32_t>());
case int64: case int64:
return py::cast(a.item<int64_t>(retain_graph)); return py::cast(a.item<int64_t>());
case float16: case float16:
return py::cast(static_cast<float>(a.item<float16_t>(retain_graph))); return py::cast(static_cast<float>(a.item<float16_t>()));
case float32: case float32:
return py::cast(a.item<float>(retain_graph)); return py::cast(a.item<float>());
case bfloat16: case bfloat16:
return py::cast(static_cast<float>(a.item<bfloat16_t>(retain_graph))); return py::cast(static_cast<float>(a.item<bfloat16_t>()));
case complex64: case complex64:
return py::cast(a.item<std::complex<float>>(retain_graph)); return py::cast(a.item<std::complex<float>>());
} }
} }
@ -74,7 +73,7 @@ py::object tolist(array& a) {
if (a.ndim() == 0) { if (a.ndim() == 0) {
return to_scalar(a); return to_scalar(a);
} }
a.eval(a.is_tracer()); a.eval();
py::object pl; py::object pl;
switch (a.dtype()) { switch (a.dtype()) {
case bool_: case bool_:
@ -527,7 +526,7 @@ void init_array(py::module_& m) {
.def_buffer([](array& a) { .def_buffer([](array& a) {
// Eval if not already evaled // Eval if not already evaled
if (!a.is_evaled()) { if (!a.is_evaled()) {
eval({a}, a.is_tracer()); a.eval();
} }
return pybind11::buffer_info( return pybind11::buffer_info(
a.data<void>(), a.data<void>(),
@ -751,7 +750,7 @@ void init_array(py::module_& m) {
"__repr__", "__repr__",
[](array& a) { [](array& a) {
if (!a.is_evaled()) { if (!a.is_evaled()) {
a.eval(a.is_tracer()); a.eval();
} }
std::ostringstream os; std::ostringstream os;
os << a; os << a;

View File

@ -345,19 +345,15 @@ class PyFileWriter : public io::Writer {
py::object tell_func_; py::object tell_func_;
}; };
void mlx_save_helper( void mlx_save_helper(py::object file, array a) {
py::object file,
array a,
std::optional<bool> retain_graph_) {
bool retain_graph = retain_graph_.value_or(a.is_tracer());
if (py::isinstance<py::str>(file)) { if (py::isinstance<py::str>(file)) {
save(py::cast<std::string>(file), a, retain_graph); save(py::cast<std::string>(file), a);
return; return;
} else if (is_ostream_object(file)) { } else if (is_ostream_object(file)) {
auto writer = std::make_shared<PyFileWriter>(file); auto writer = std::make_shared<PyFileWriter>(file);
{ {
py::gil_scoped_release gil; py::gil_scoped_release gil;
save(writer, a, retain_graph); save(writer, a);
} }
return; return;
@ -414,26 +410,23 @@ void mlx_savez_helper(
auto writer = std::make_shared<PyFileWriter>(py_ostream); auto writer = std::make_shared<PyFileWriter>(py_ostream);
{ {
py::gil_scoped_release gil; py::gil_scoped_release gil;
save(writer, a, /*retain_graph=*/a.is_tracer()); save(writer, a);
} }
} }
return; return;
} }
void mlx_save_safetensor_helper( void mlx_save_safetensor_helper(py::object file, py::dict d) {
py::object file,
py::dict d,
std::optional<bool> retain_graph) {
auto arrays_map = d.cast<std::unordered_map<std::string, array>>(); auto arrays_map = d.cast<std::unordered_map<std::string, array>>();
if (py::isinstance<py::str>(file)) { if (py::isinstance<py::str>(file)) {
save_safetensors(py::cast<std::string>(file), arrays_map, retain_graph); save_safetensors(py::cast<std::string>(file), arrays_map);
return; return;
} else if (is_ostream_object(file)) { } else if (is_ostream_object(file)) {
auto writer = std::make_shared<PyFileWriter>(file); auto writer = std::make_shared<PyFileWriter>(file);
{ {
py::gil_scoped_release gil; py::gil_scoped_release gil;
save_safetensors(writer, arrays_map, retain_graph); save_safetensors(writer, arrays_map);
} }
return; return;

View File

@ -17,19 +17,13 @@ using DictOrArray = std::variant<array, std::unordered_map<std::string, array>>;
std::unordered_map<std::string, array> mlx_load_safetensor_helper( std::unordered_map<std::string, array> mlx_load_safetensor_helper(
py::object file, py::object file,
StreamOrDevice s); StreamOrDevice s);
void mlx_save_safetensor_helper( void mlx_save_safetensor_helper(py::object file, py::dict d);
py::object file,
py::dict d,
std::optional<bool> retain_graph = std::nullopt);
DictOrArray mlx_load_helper( DictOrArray mlx_load_helper(
py::object file, py::object file,
std::optional<std::string> format, std::optional<std::string> format,
StreamOrDevice s); StreamOrDevice s);
void mlx_save_helper( void mlx_save_helper(py::object file, array a);
py::object file,
array a,
std::optional<bool> retain_graph = std::nullopt);
void mlx_savez_helper( void mlx_savez_helper(
py::object file, py::object file,
py::args args, py::args args,

View File

@ -2902,20 +2902,14 @@ void init_ops(py::module_& m) {
&mlx_save_helper, &mlx_save_helper,
"file"_a, "file"_a,
"arr"_a, "arr"_a,
py::pos_only(),
"retain_graph"_a = std::nullopt,
py::kw_only(),
R"pbdoc( R"pbdoc(
save(file: str, arr: array, / , retain_graph: Optional[bool] = None) save(file: str, arr: array)
Save the array to a binary file in ``.npy`` format. Save the array to a binary file in ``.npy`` format.
Args: Args:
file (str): File to which the array is saved file (str): File to which the array is saved
arr (array): Array to be saved. arr (array): Array to be saved.
retain_graph (bool, optional): Whether or not to retain the graph
during array evaluation. If left unspecified the graph is retained
only if saving is done in a function transformation. Default: ``None``
)pbdoc"); )pbdoc");
m.def( m.def(
"savez", "savez",
@ -2999,11 +2993,8 @@ void init_ops(py::module_& m) {
&mlx_save_safetensor_helper, &mlx_save_safetensor_helper,
"file"_a, "file"_a,
"arrays"_a, "arrays"_a,
py::pos_only(),
"retain_graph"_a = std::nullopt,
py::kw_only(),
R"pbdoc( R"pbdoc(
save_safetensors(file: str, arrays: Dict[str, array], /, retain_graph: Optional[bool] = None) save_safetensors(file: str, arrays: Dict[str, array])
Save array(s) to a binary file in ``.safetensors`` format. Save array(s) to a binary file in ``.safetensors`` format.
@ -3012,9 +3003,6 @@ void init_ops(py::module_& m) {
Args: Args:
file (file, str): File in which the array is saved> file (file, str): File in which the array is saved>
arrays (dict(str, array)): The dictionary of names to arrays to be saved. arrays (dict(str, array)): The dictionary of names to arrays to be saved.
retain_graph (bool, optional): Whether or not to retain the graph
during array evaluation. If left unspecified the graph is retained
only if saving is done in a function transformation. Default: ``None``.
)pbdoc"); )pbdoc");
m.def( m.def(
"where", "where",

View File

@ -440,11 +440,10 @@ auto py_vmap(
void init_transforms(py::module_& m) { void init_transforms(py::module_& m) {
m.def( m.def(
"eval", "eval",
[](const py::args& args, bool retain_graph) { [](const py::args& args) {
std::vector<array> arrays = tree_flatten(args); std::vector<array> arrays = tree_flatten(args);
eval(arrays, retain_graph); eval(arrays);
}, },
"retain_graph"_a = false,
R"pbdoc( R"pbdoc(
Evaluate an :class:`array` or tree of :class:`array`. Evaluate an :class:`array` or tree of :class:`array`.
@ -453,9 +452,6 @@ void init_transforms(py::module_& m) {
or a tree of arrays. If a tree is given the nodes can be a Python or a tree of arrays. If a tree is given the nodes can be a Python
:class:`list`, :class:`tuple` or :class:`dict` but the leafs must all be :class:`list`, :class:`tuple` or :class:`dict` but the leafs must all be
an :class:`array`. an :class:`array`.
retain_graph (bool): Indicate that the graph structure should be
preserved. This option is intended to enable function transforms
which contain control flow based on the value of an array.
)pbdoc"); )pbdoc");
m.def( m.def(
"jvp", "jvp",

View File

@ -259,6 +259,21 @@ class TestAutograd(mlx_tests.MLXTestCase):
self.assertTrue(mx.allclose(vjps[0], mx.zeros(shape_in))) self.assertTrue(mx.allclose(vjps[0], mx.zeros(shape_in)))
def test_update_state(self):
y = mx.array([1.0])
state = mx.zeros((2,))
def fn(y, x):
nonlocal state
x = y * x
state = state + x
return x.sum()
x = mx.ones((2,))
mx.grad(fn)(y, x)
mx.eval(state)
self.assertTrue(mx.allclose(state, mx.ones((2,))))
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()

View File

@ -15,18 +15,13 @@ class TestEval(mlx_tests.MLXTestCase):
self.assertEqual(x.tolist(), [[1, 1], [1, 1]]) self.assertEqual(x.tolist(), [[1, 1], [1, 1]])
def test_retain_graph(self): def test_retain_graph(self):
def fun(x, retain_graph): def fun(x):
y = 3 * x y = 3 * x
mx.eval(y, retain_graph=retain_graph) mx.eval(y)
return 2 * y return 2 * y
dfun_dx_1 = mx.grad(partial(fun, retain_graph=False)) dfun_dx = mx.grad(fun)
dfun_dx_2 = mx.grad(partial(fun, retain_graph=True)) y = dfun_dx(mx.array(1.0))
with self.assertRaises(ValueError):
dfun_dx_1(mx.array(1.0))
y = dfun_dx_2(mx.array(1.0))
self.assertEqual(y.item(), 6.0) self.assertEqual(y.item(), 6.0)

View File

@ -95,19 +95,14 @@ TEST_CASE("test jvp") {
CHECK_EQ(dout[0].item<float>(), 4.0f); CHECK_EQ(dout[0].item<float>(), 4.0f);
} }
// Evaling in function without graph retention throws // Evaling in function while tracing performs graph retention
{ {
auto fun = [](const array& x) {
auto y = 3 * x;
eval(y);
return 2 * y;
};
CHECK_THROWS(jvp(fun, array(1.0f), array(1.0f)));
// Ok with graph retention
auto fun1 = [](const array& x) { auto fun1 = [](const array& x) {
auto y = 3 * x; auto y = 3 * x;
eval({y}, true); eval(y);
CHECK(y.is_evaled());
CHECK(y.has_primitive());
CHECK(y.is_tracer());
return 2 * y; return 2 * y;
}; };
CHECK_EQ(jvp(fun1, array(1.0f), array(1.0f)).second.item<float>(), 6.0f); CHECK_EQ(jvp(fun1, array(1.0f), array(1.0f)).second.item<float>(), 6.0f);
@ -251,29 +246,27 @@ TEST_CASE("test grad") {
} }
{ {
// Evaluating in the middle of the grad function throws // No graph retention since the output is independent of y
// as it breaks the graph
auto fn = [](array x) {
x = x + 2.0f;
eval(x);
return square(x);
};
CHECK_THROWS(grad(fn)(array(1.0f)));
// Ok since the output is independent of y
auto y = ones({3, 3}); auto y = ones({3, 3});
auto fn1 = [y](array x) { auto fn1 = [y](array x) {
x = x + 2.0f; x = x + 2.0f;
eval(y); eval(y);
CHECK(x.is_tracer());
CHECK(!y.is_tracer());
CHECK(y.is_evaled());
CHECK(!y.has_primitive());
return square(x); return square(x);
}; };
auto dfdx = grad(fn1)(array(1.0f)); auto dfdx = grad(fn1)(array(1.0f));
CHECK_EQ(dfdx.item<float>(), 6.0f); CHECK_EQ(dfdx.item<float>(), 6.0f);
// Retain the graph to avoid breaking it // Graph automatically retained to compute the grad
auto fn2 = [](array x) { auto fn2 = [](array x) {
x = x + 2.0f; x = x + 2.0f;
eval({x}, true); eval(x);
CHECK(x.is_tracer());
CHECK(x.is_evaled());
CHECK(x.has_primitive());
return square(x); return square(x);
}; };
dfdx = grad(fn2)(array(1.0f)); dfdx = grad(fn2)(array(1.0f));
@ -283,7 +276,8 @@ TEST_CASE("test grad") {
// Control flow in grad computation // Control flow in grad computation
{ {
auto fn = [](array x) { auto fn = [](array x) {
if (x.item<float>(true) > 1) { x = x + array(2.0f);
if (x.item<float>() > 3) {
return square(x); return square(x);
} else { } else {
return 4 * x; return 4 * x;
@ -294,7 +288,7 @@ TEST_CASE("test grad") {
CHECK_EQ(dfdx.item<float>(), 4.0f); CHECK_EQ(dfdx.item<float>(), 4.0f);
dfdx = grad(fn)(array(1.5f)); dfdx = grad(fn)(array(1.5f));
CHECK_EQ(dfdx.item<float>(), 3.0f); CHECK_EQ(dfdx.item<float>(), 7.0f);
} }
// Grad with multiple inputs // Grad with multiple inputs
@ -1192,3 +1186,19 @@ TEST_CASE("test scan grads") {
CHECK(array_equal(out, expected).item<bool>()); CHECK(array_equal(out, expected).item<bool>());
} }
} }
TEST_CASE("test update state") {
auto y = array({1.0});
auto x = array({1.0, 1.0});
auto state = array({0.0, 0.0});
auto fn = [&state, &x](array y) {
x = y * x;
state = state + x;
return sum(x);
};
grad(fn)(y);
eval(state);
CHECK(!state.has_primitive());
CHECK(state.is_evaled());
CHECK(array_equal(state, array({1.0, 1.0})).item<bool>());
}

View File

@ -48,36 +48,36 @@ TEST_CASE("test eval multiple") {
CHECK(array_equal(b, full({10, 10}, 0.0f)).item<bool>()); CHECK(array_equal(b, full({10, 10}, 0.0f)).item<bool>());
} }
TEST_CASE("test eval with tracer") { TEST_CASE("test eval with tracer when not tracing") {
// Since we are not tracing it doesn't matter that the array flags are
// tracers they will always be detached.
auto x = array(1); auto x = array(1);
x.set_tracer(true); x.set_tracer(true);
CHECK(!x.is_tracer());
// Ok, x is not a node
eval(x); eval(x);
CHECK(!x.has_primitive());
CHECK(x.is_evaled());
x = ones({2, 3}); x = ones({2, 3});
x.set_tracer(true); x.set_tracer(true);
CHECK_THROWS(eval(x)); eval(x);
CHECK(!x.has_primitive());
// Ok retain_graph=true CHECK(x.is_evaled());
eval({x}, true);
// Make sure all arguments are checked
auto y = ones({2, 3});
CHECK_THROWS(eval(x, y));
} }
TEST_CASE("test eval graph retention") { TEST_CASE("test eval graph retention when not tracing") {
// Since we are not tracing it doesn't matter that the array flags are
// tracers they will always be detached.
auto x = array(1); auto x = array(1);
x.set_tracer(true);
auto y = array(2); auto y = array(2);
auto z = x + y; auto z = x + y;
eval({z}, true); eval(z);
CHECK(z.has_primitive()); CHECK(!z.has_primitive());
CHECK(z.is_evaled());
CHECK_EQ(z.item<int>(true), 3);
CHECK(z.has_primitive());
CHECK(z.is_evaled()); CHECK(z.is_evaled());
CHECK_EQ(z.item<int>(), 3);
z.set_tracer(false);
CHECK_EQ(z.item<int>(), 3); CHECK_EQ(z.item<int>(), 3);
CHECK(!z.has_primitive()); CHECK(!z.has_primitive());
CHECK(z.is_evaled()); CHECK(z.is_evaled());
@ -85,13 +85,7 @@ TEST_CASE("test eval graph retention") {
z = x + y; z = x + y;
auto a = z + x; auto a = z + x;
auto b = a + y; auto b = a + y;
eval({b}, true); eval(b);
CHECK(z.has_primitive());
CHECK(z.is_evaled());
CHECK(a.has_primitive());
CHECK(a.is_evaled());
eval({b}, false);
CHECK(!z.has_primitive()); CHECK(!z.has_primitive());
CHECK(z.is_evaled()); CHECK(z.is_evaled());
CHECK(!a.has_primitive()); CHECK(!a.has_primitive());

View File

@ -183,7 +183,7 @@ TEST_CASE("test vmap with eval") {
auto fun2 = [](std::vector<array> inputs) { auto fun2 = [](std::vector<array> inputs) {
auto x = inputs[0] + 1; auto x = inputs[0] + 1;
auto y = inputs[1] + 2; auto y = inputs[1] + 2;
eval({x}, true); eval(x);
auto out = add(x, y); auto out = add(x, y);
return std::vector<array>{out}; return std::vector<array>{out};
}; };