mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-05 08:41:13 +08:00
Removes the retain_graph
flag (#385)
* Adds global tracing flag * Removes retain_graph in favor of is_tracer
This commit is contained in:
parent
449b43762e
commit
a611b0bc82
@ -6,6 +6,7 @@
|
||||
#include "mlx/ops.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/transforms.h"
|
||||
#include "mlx/transforms_impl.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
@ -21,6 +22,12 @@ std::pair<size_t, std::vector<size_t>> cum_prod(const std::vector<int>& shape) {
|
||||
return {cum_prod, strides};
|
||||
}
|
||||
|
||||
/** Return true if we are currently performing a function transformation in
|
||||
* order to keep the graph when evaluating tracer arrays. */
|
||||
bool in_tracing() {
|
||||
return detail::InTracing::in_tracing();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
array::array(const std::complex<float>& val, Dtype dtype /* = complex64 */)
|
||||
@ -62,8 +69,12 @@ void array::detach() {
|
||||
array_desc_->primitive = nullptr;
|
||||
}
|
||||
|
||||
void array::eval(bool retain_graph /* = false */) {
|
||||
mlx::core::eval({*this}, retain_graph);
|
||||
void array::eval() {
|
||||
mlx::core::eval({*this});
|
||||
}
|
||||
|
||||
bool array::is_tracer() const {
|
||||
return array_desc_->is_tracer && in_tracing();
|
||||
}
|
||||
|
||||
void array::set_data(allocator::Buffer buffer, deleter_t d) {
|
||||
|
12
mlx/array.h
12
mlx/array.h
@ -116,11 +116,11 @@ class array {
|
||||
};
|
||||
|
||||
/** Evaluate the array. */
|
||||
void eval(bool retain_graph = false);
|
||||
void eval();
|
||||
|
||||
/** Get the value from a scalar array. */
|
||||
template <typename T>
|
||||
T item(bool retain_graph = false);
|
||||
T item();
|
||||
|
||||
struct ArrayIterator {
|
||||
using iterator_category = std::random_access_iterator_tag;
|
||||
@ -265,9 +265,7 @@ class array {
|
||||
array_desc_->is_tracer = is_tracer;
|
||||
}
|
||||
// Check if the array is a tracer array
|
||||
bool is_tracer() const {
|
||||
return array_desc_->is_tracer;
|
||||
}
|
||||
bool is_tracer() const;
|
||||
|
||||
void set_data(allocator::Buffer buffer, deleter_t d = allocator::free);
|
||||
|
||||
@ -381,11 +379,11 @@ array::array(
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
T array::item(bool retain_graph /* = false */) {
|
||||
T array::item() {
|
||||
if (size() != 1) {
|
||||
throw std::invalid_argument("item can only be called on arrays of size 1.");
|
||||
}
|
||||
eval(retain_graph);
|
||||
eval();
|
||||
return *data<T>();
|
||||
}
|
||||
|
||||
|
@ -46,10 +46,8 @@ MTL::CommandBuffer* increment_command_buffer(Stream s) {
|
||||
std::function<void()> make_task(
|
||||
array& arr,
|
||||
std::vector<std::shared_future<void>> deps,
|
||||
std::shared_ptr<std::promise<void>> p,
|
||||
bool retain_graph) {
|
||||
auto task =
|
||||
[retain_graph, arr, deps = std::move(deps), p = std::move(p)]() mutable {
|
||||
std::shared_ptr<std::promise<void>> p) {
|
||||
auto task = [arr, deps = std::move(deps), p = std::move(p)]() mutable {
|
||||
auto pool = new_scoped_memory_pool();
|
||||
for (auto& d : deps) {
|
||||
d.wait();
|
||||
@ -61,9 +59,8 @@ std::function<void()> make_task(
|
||||
metal::device(s.device).end_encoding(s.index);
|
||||
scheduler::notify_new_task(s);
|
||||
command_buffer->addCompletedHandler(
|
||||
[retain_graph, s, arr, p = std::move(p)](
|
||||
MTL::CommandBuffer*) mutable {
|
||||
if (!retain_graph) {
|
||||
[s, arr, p = std::move(p)](MTL::CommandBuffer*) mutable {
|
||||
if (!arr.is_tracer()) {
|
||||
arr.detach();
|
||||
}
|
||||
p->set_value();
|
||||
@ -72,8 +69,8 @@ std::function<void()> make_task(
|
||||
metal::device(s.device).commit_command_buffer(s.index);
|
||||
} else {
|
||||
command_buffer->addCompletedHandler(
|
||||
[retain_graph, s, arr](MTL::CommandBuffer*) mutable {
|
||||
if (!retain_graph) {
|
||||
[s, arr](MTL::CommandBuffer*) mutable {
|
||||
if (!arr.is_tracer()) {
|
||||
arr.detach();
|
||||
}
|
||||
});
|
||||
|
@ -25,7 +25,6 @@ std::shared_ptr<void> new_scoped_memory_pool();
|
||||
std::function<void()> make_task(
|
||||
array& arr,
|
||||
std::vector<std::shared_future<void>> deps,
|
||||
std::shared_ptr<std::promise<void>> p,
|
||||
bool retain_graph);
|
||||
std::shared_ptr<std::promise<void>> p);
|
||||
|
||||
} // namespace mlx::core::metal
|
||||
|
@ -14,8 +14,7 @@ std::shared_ptr<void> new_scoped_memory_pool() {
|
||||
std::function<void()> make_task(
|
||||
array& arr,
|
||||
std::vector<std::shared_future<void>> deps,
|
||||
std::shared_ptr<std::promise<void>> p,
|
||||
bool retain_graph) {
|
||||
std::shared_ptr<std::promise<void>> p) {
|
||||
throw std::runtime_error(
|
||||
"[metal::make_task] Cannot make GPU task without metal backend");
|
||||
}
|
||||
|
@ -40,11 +40,11 @@ inline bool is_big_endian_() {
|
||||
} // namespace
|
||||
|
||||
/** Save array to out stream in .npy format */
|
||||
void save(std::shared_ptr<io::Writer> out_stream, array a, bool retain_graph) {
|
||||
void save(std::shared_ptr<io::Writer> out_stream, array a) {
|
||||
////////////////////////////////////////////////////////
|
||||
// Check array
|
||||
|
||||
a.eval(retain_graph);
|
||||
a.eval();
|
||||
|
||||
if (a.nbytes() == 0) {
|
||||
throw std::invalid_argument("[save] cannot serialize an empty array");
|
||||
@ -52,7 +52,7 @@ void save(std::shared_ptr<io::Writer> out_stream, array a, bool retain_graph) {
|
||||
|
||||
if (!(a.flags().row_contiguous || a.flags().col_contiguous)) {
|
||||
a = reshape(flatten(a), a.shape());
|
||||
a.eval(retain_graph);
|
||||
a.eval();
|
||||
}
|
||||
// Check once more in-case the above ops change
|
||||
if (!(a.flags().row_contiguous || a.flags().col_contiguous)) {
|
||||
@ -127,7 +127,7 @@ void save(std::shared_ptr<io::Writer> out_stream, array a, bool retain_graph) {
|
||||
}
|
||||
|
||||
/** Save array to file in .npy format */
|
||||
void save(const std::string& file_, array a, bool retain_graph) {
|
||||
void save(const std::string& file_, array a) {
|
||||
// Open and check file
|
||||
std::string file = file_;
|
||||
|
||||
@ -136,7 +136,7 @@ void save(const std::string& file_, array a, bool retain_graph) {
|
||||
file += ".npy";
|
||||
|
||||
// Serialize array
|
||||
save(std::make_shared<io::FileWriter>(file), a, retain_graph);
|
||||
save(std::make_shared<io::FileWriter>(file), a);
|
||||
}
|
||||
|
||||
/** Load array from reader in .npy format */
|
||||
|
@ -125,8 +125,7 @@ std::unordered_map<std::string, array> load_safetensors(
|
||||
/** Save array to out stream in .npy format */
|
||||
void save_safetensors(
|
||||
std::shared_ptr<io::Writer> out_stream,
|
||||
std::unordered_map<std::string, array> a,
|
||||
std::optional<bool> retain_graph_) {
|
||||
std::unordered_map<std::string, array> a) {
|
||||
////////////////////////////////////////////////////////
|
||||
// Check file
|
||||
if (!out_stream->good() || !out_stream->is_open()) {
|
||||
@ -142,8 +141,7 @@ void save_safetensors(
|
||||
});
|
||||
size_t offset = 0;
|
||||
for (auto& [key, arr] : a) {
|
||||
auto retain = retain_graph_.value_or(arr.is_tracer());
|
||||
arr.eval(retain);
|
||||
arr.eval();
|
||||
if (arr.nbytes() == 0) {
|
||||
throw std::invalid_argument(
|
||||
"[save_safetensors] cannot serialize an empty array key: " + key);
|
||||
@ -152,7 +150,7 @@ void save_safetensors(
|
||||
// Try to make it row contiguous
|
||||
if (!arr.flags().row_contiguous) {
|
||||
arr = reshape(flatten(arr), arr.shape());
|
||||
arr.eval(retain);
|
||||
arr.eval();
|
||||
}
|
||||
|
||||
// Has to be row-major now but, check one more time in case
|
||||
@ -181,8 +179,7 @@ void save_safetensors(
|
||||
|
||||
void save_safetensors(
|
||||
const std::string& file_,
|
||||
std::unordered_map<std::string, array> a,
|
||||
std::optional<bool> retain_graph) {
|
||||
std::unordered_map<std::string, array> a) {
|
||||
// Open and check file
|
||||
std::string file = file_;
|
||||
|
||||
@ -192,7 +189,7 @@ void save_safetensors(
|
||||
file += ".safetensors";
|
||||
|
||||
// Serialize array
|
||||
save_safetensors(std::make_shared<io::FileWriter>(file), a, retain_graph);
|
||||
save_safetensors(std::make_shared<io::FileWriter>(file), a);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
13
mlx/ops.h
13
mlx/ops.h
@ -1021,13 +1021,10 @@ array conv2d(
|
||||
/** Serialization operations */
|
||||
|
||||
/** Save array to out stream in .npy format */
|
||||
void save(
|
||||
std::shared_ptr<io::Writer> out_stream,
|
||||
array a,
|
||||
bool retain_graph = true);
|
||||
void save(std::shared_ptr<io::Writer> out_stream, array a);
|
||||
|
||||
/** Save array to file in .npy format */
|
||||
void save(const std::string& file, array a, bool retain_graph = true);
|
||||
void save(const std::string& file, array a);
|
||||
|
||||
/** Load array from reader in .npy format */
|
||||
array load(std::shared_ptr<io::Reader> in_stream, StreamOrDevice s = {});
|
||||
@ -1091,10 +1088,8 @@ std::unordered_map<std::string, array> load_safetensors(
|
||||
|
||||
void save_safetensors(
|
||||
std::shared_ptr<io::Writer> in_stream,
|
||||
std::unordered_map<std::string, array>,
|
||||
std::optional<bool> retain_graph = std::nullopt);
|
||||
std::unordered_map<std::string, array>);
|
||||
void save_safetensors(
|
||||
const std::string& file,
|
||||
std::unordered_map<std::string, array>,
|
||||
std::optional<bool> retain_graph = std::nullopt);
|
||||
std::unordered_map<std::string, array>);
|
||||
} // namespace mlx::core
|
||||
|
@ -19,6 +19,12 @@
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
// Initialize the static tracing counter from transforms_impl.h .
|
||||
//
|
||||
// This is used to implement the in_tracing() function the returns true if we
|
||||
// are currently under a function transformation.
|
||||
int detail::InTracing::tracing_counter{0};
|
||||
|
||||
void simplify(const std::vector<array>& outputs) {
|
||||
std::function<void(const array&)> recurse;
|
||||
std::queue<array> tape;
|
||||
@ -154,16 +160,7 @@ void simplify(const std::vector<array>& outputs) {
|
||||
}
|
||||
}
|
||||
|
||||
void eval(const std::vector<array>& outputs, bool retain_graph /* = false */) {
|
||||
if (!retain_graph) {
|
||||
for (auto& out : outputs) {
|
||||
if (out.has_primitive() && out.is_tracer()) {
|
||||
throw std::invalid_argument(
|
||||
"[eval] Illegal to eval an array during "
|
||||
"function transform without graph retention.");
|
||||
}
|
||||
}
|
||||
}
|
||||
void eval(const std::vector<array>& outputs) {
|
||||
std::function<void(const array&)> recurse;
|
||||
std::queue<array> tape;
|
||||
std::unordered_set<std::uintptr_t> cache;
|
||||
@ -185,7 +182,7 @@ void eval(const std::vector<array>& outputs, bool retain_graph /* = false */) {
|
||||
}
|
||||
}
|
||||
cache.insert(id);
|
||||
if (!a.is_evaled() || (!retain_graph && a.has_primitive())) {
|
||||
if (!a.is_evaled() || (!a.is_tracer() && a.has_primitive())) {
|
||||
if (!a.has_primitive()) {
|
||||
throw std::invalid_argument(
|
||||
"[eval] Attempting to eval an array without a primitive.");
|
||||
@ -195,7 +192,7 @@ void eval(const std::vector<array>& outputs, bool retain_graph /* = false */) {
|
||||
};
|
||||
|
||||
for (auto& arr : outputs) {
|
||||
if (!arr.is_evaled() || (!retain_graph && arr.has_primitive())) {
|
||||
if (!arr.is_evaled() || (!arr.is_tracer() && arr.has_primitive())) {
|
||||
recurse(arr);
|
||||
// Insert a dependency for every output to synchronize
|
||||
// with at the end.
|
||||
@ -209,7 +206,7 @@ void eval(const std::vector<array>& outputs, bool retain_graph /* = false */) {
|
||||
auto arr = std::move(tape.front());
|
||||
tape.pop();
|
||||
if (arr.is_evaled()) {
|
||||
if (!retain_graph && arr.has_primitive()) {
|
||||
if (!arr.is_tracer() && arr.has_primitive()) {
|
||||
arr.detach();
|
||||
}
|
||||
continue;
|
||||
@ -233,12 +230,9 @@ void eval(const std::vector<array>& outputs, bool retain_graph /* = false */) {
|
||||
throw std::runtime_error("Metal GPU is not available.");
|
||||
}
|
||||
scheduler::enqueue(
|
||||
stream,
|
||||
metal::make_task(
|
||||
arr, std::move(arr_deps), std::move(p), retain_graph));
|
||||
stream, metal::make_task(arr, std::move(arr_deps), std::move(p)));
|
||||
} else {
|
||||
auto task = [retain_graph,
|
||||
arr,
|
||||
auto task = [arr,
|
||||
stream,
|
||||
arr_deps = std::move(arr_deps),
|
||||
p = std::move(p)]() mutable {
|
||||
@ -247,7 +241,7 @@ void eval(const std::vector<array>& outputs, bool retain_graph /* = false */) {
|
||||
}
|
||||
scheduler::notify_new_task(stream);
|
||||
arr.primitive().eval_cpu(arr.inputs(), arr);
|
||||
if (!retain_graph) {
|
||||
if (!arr.is_tracer()) {
|
||||
arr.detach();
|
||||
}
|
||||
if (p) {
|
||||
@ -269,6 +263,9 @@ std::pair<std::vector<array>, std::vector<array>> vjp(
|
||||
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotans) {
|
||||
// Set the global tracing flag.
|
||||
detail::InTracing in_tracing;
|
||||
|
||||
// Make tracers from given primals
|
||||
std::vector<array> primals_;
|
||||
for (auto& p : primals) {
|
||||
@ -425,6 +422,9 @@ std::pair<std::vector<array>, std::vector<array>> jvp(
|
||||
}
|
||||
}
|
||||
|
||||
// Set the global tracing flag.
|
||||
detail::InTracing in_tracing;
|
||||
|
||||
std::vector<array> primals_;
|
||||
for (auto& p : primals) {
|
||||
auto s = p.has_primitive() ? p.primitive().stream()
|
||||
@ -578,6 +578,9 @@ std::pair<std::vector<array>, std::vector<array>> vmap_trace(
|
||||
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& in_axes) {
|
||||
// Set the global tracing flag
|
||||
InTracing in_tracing;
|
||||
|
||||
if (in_axes.size() != inputs.size()) {
|
||||
throw std::invalid_argument(
|
||||
"[vmap] The number of in axes must match the number of inputs.");
|
||||
|
@ -14,11 +14,11 @@ void simplify(Arrays... outputs) {
|
||||
simplify(std::vector<array>{std::forward<Arrays>(outputs)...});
|
||||
}
|
||||
|
||||
void eval(const std::vector<array>& outputs, bool retain_graph = false);
|
||||
void eval(const std::vector<array>& outputs);
|
||||
|
||||
template <typename... Arrays>
|
||||
void eval(Arrays... outputs) {
|
||||
eval(std::vector<array>{std::forward<Arrays>(outputs)...}, false);
|
||||
eval(std::vector<array>{std::forward<Arrays>(outputs)...});
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -14,4 +14,23 @@ std::vector<array> vmap_replace(
|
||||
const std::vector<int>& in_axes,
|
||||
const std::vector<int>& out_axes);
|
||||
|
||||
// Create an InTracing object during tracing operations to signify to the rest
|
||||
// of the codebase that we are during tracing so evals should not throw away
|
||||
// the graph.
|
||||
struct InTracing {
|
||||
InTracing() {
|
||||
tracing_counter++;
|
||||
}
|
||||
~InTracing() {
|
||||
tracing_counter--;
|
||||
}
|
||||
|
||||
static bool in_tracing() {
|
||||
return tracing_counter > 0;
|
||||
}
|
||||
|
||||
private:
|
||||
static int tracing_counter;
|
||||
};
|
||||
|
||||
} // namespace mlx::core::detail
|
||||
|
@ -39,34 +39,33 @@ py::list to_list(array& a, size_t index, int dim) {
|
||||
}
|
||||
|
||||
auto to_scalar(array& a) {
|
||||
bool retain_graph = a.is_tracer();
|
||||
switch (a.dtype()) {
|
||||
case bool_:
|
||||
return py::cast(a.item<bool>(retain_graph));
|
||||
return py::cast(a.item<bool>());
|
||||
case uint8:
|
||||
return py::cast(a.item<uint8_t>(retain_graph));
|
||||
return py::cast(a.item<uint8_t>());
|
||||
case uint16:
|
||||
return py::cast(a.item<uint16_t>(retain_graph));
|
||||
return py::cast(a.item<uint16_t>());
|
||||
case uint32:
|
||||
return py::cast(a.item<uint32_t>(retain_graph));
|
||||
return py::cast(a.item<uint32_t>());
|
||||
case uint64:
|
||||
return py::cast(a.item<uint64_t>(retain_graph));
|
||||
return py::cast(a.item<uint64_t>());
|
||||
case int8:
|
||||
return py::cast(a.item<int8_t>(retain_graph));
|
||||
return py::cast(a.item<int8_t>());
|
||||
case int16:
|
||||
return py::cast(a.item<int16_t>(retain_graph));
|
||||
return py::cast(a.item<int16_t>());
|
||||
case int32:
|
||||
return py::cast(a.item<int32_t>(retain_graph));
|
||||
return py::cast(a.item<int32_t>());
|
||||
case int64:
|
||||
return py::cast(a.item<int64_t>(retain_graph));
|
||||
return py::cast(a.item<int64_t>());
|
||||
case float16:
|
||||
return py::cast(static_cast<float>(a.item<float16_t>(retain_graph)));
|
||||
return py::cast(static_cast<float>(a.item<float16_t>()));
|
||||
case float32:
|
||||
return py::cast(a.item<float>(retain_graph));
|
||||
return py::cast(a.item<float>());
|
||||
case bfloat16:
|
||||
return py::cast(static_cast<float>(a.item<bfloat16_t>(retain_graph)));
|
||||
return py::cast(static_cast<float>(a.item<bfloat16_t>()));
|
||||
case complex64:
|
||||
return py::cast(a.item<std::complex<float>>(retain_graph));
|
||||
return py::cast(a.item<std::complex<float>>());
|
||||
}
|
||||
}
|
||||
|
||||
@ -74,7 +73,7 @@ py::object tolist(array& a) {
|
||||
if (a.ndim() == 0) {
|
||||
return to_scalar(a);
|
||||
}
|
||||
a.eval(a.is_tracer());
|
||||
a.eval();
|
||||
py::object pl;
|
||||
switch (a.dtype()) {
|
||||
case bool_:
|
||||
@ -527,7 +526,7 @@ void init_array(py::module_& m) {
|
||||
.def_buffer([](array& a) {
|
||||
// Eval if not already evaled
|
||||
if (!a.is_evaled()) {
|
||||
eval({a}, a.is_tracer());
|
||||
a.eval();
|
||||
}
|
||||
return pybind11::buffer_info(
|
||||
a.data<void>(),
|
||||
@ -751,7 +750,7 @@ void init_array(py::module_& m) {
|
||||
"__repr__",
|
||||
[](array& a) {
|
||||
if (!a.is_evaled()) {
|
||||
a.eval(a.is_tracer());
|
||||
a.eval();
|
||||
}
|
||||
std::ostringstream os;
|
||||
os << a;
|
||||
|
@ -345,19 +345,15 @@ class PyFileWriter : public io::Writer {
|
||||
py::object tell_func_;
|
||||
};
|
||||
|
||||
void mlx_save_helper(
|
||||
py::object file,
|
||||
array a,
|
||||
std::optional<bool> retain_graph_) {
|
||||
bool retain_graph = retain_graph_.value_or(a.is_tracer());
|
||||
void mlx_save_helper(py::object file, array a) {
|
||||
if (py::isinstance<py::str>(file)) {
|
||||
save(py::cast<std::string>(file), a, retain_graph);
|
||||
save(py::cast<std::string>(file), a);
|
||||
return;
|
||||
} else if (is_ostream_object(file)) {
|
||||
auto writer = std::make_shared<PyFileWriter>(file);
|
||||
{
|
||||
py::gil_scoped_release gil;
|
||||
save(writer, a, retain_graph);
|
||||
save(writer, a);
|
||||
}
|
||||
|
||||
return;
|
||||
@ -414,26 +410,23 @@ void mlx_savez_helper(
|
||||
auto writer = std::make_shared<PyFileWriter>(py_ostream);
|
||||
{
|
||||
py::gil_scoped_release gil;
|
||||
save(writer, a, /*retain_graph=*/a.is_tracer());
|
||||
save(writer, a);
|
||||
}
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
void mlx_save_safetensor_helper(
|
||||
py::object file,
|
||||
py::dict d,
|
||||
std::optional<bool> retain_graph) {
|
||||
void mlx_save_safetensor_helper(py::object file, py::dict d) {
|
||||
auto arrays_map = d.cast<std::unordered_map<std::string, array>>();
|
||||
if (py::isinstance<py::str>(file)) {
|
||||
save_safetensors(py::cast<std::string>(file), arrays_map, retain_graph);
|
||||
save_safetensors(py::cast<std::string>(file), arrays_map);
|
||||
return;
|
||||
} else if (is_ostream_object(file)) {
|
||||
auto writer = std::make_shared<PyFileWriter>(file);
|
||||
{
|
||||
py::gil_scoped_release gil;
|
||||
save_safetensors(writer, arrays_map, retain_graph);
|
||||
save_safetensors(writer, arrays_map);
|
||||
}
|
||||
|
||||
return;
|
||||
|
@ -17,19 +17,13 @@ using DictOrArray = std::variant<array, std::unordered_map<std::string, array>>;
|
||||
std::unordered_map<std::string, array> mlx_load_safetensor_helper(
|
||||
py::object file,
|
||||
StreamOrDevice s);
|
||||
void mlx_save_safetensor_helper(
|
||||
py::object file,
|
||||
py::dict d,
|
||||
std::optional<bool> retain_graph = std::nullopt);
|
||||
void mlx_save_safetensor_helper(py::object file, py::dict d);
|
||||
|
||||
DictOrArray mlx_load_helper(
|
||||
py::object file,
|
||||
std::optional<std::string> format,
|
||||
StreamOrDevice s);
|
||||
void mlx_save_helper(
|
||||
py::object file,
|
||||
array a,
|
||||
std::optional<bool> retain_graph = std::nullopt);
|
||||
void mlx_save_helper(py::object file, array a);
|
||||
void mlx_savez_helper(
|
||||
py::object file,
|
||||
py::args args,
|
||||
|
@ -2902,20 +2902,14 @@ void init_ops(py::module_& m) {
|
||||
&mlx_save_helper,
|
||||
"file"_a,
|
||||
"arr"_a,
|
||||
py::pos_only(),
|
||||
"retain_graph"_a = std::nullopt,
|
||||
py::kw_only(),
|
||||
R"pbdoc(
|
||||
save(file: str, arr: array, / , retain_graph: Optional[bool] = None)
|
||||
save(file: str, arr: array)
|
||||
|
||||
Save the array to a binary file in ``.npy`` format.
|
||||
|
||||
Args:
|
||||
file (str): File to which the array is saved
|
||||
arr (array): Array to be saved.
|
||||
retain_graph (bool, optional): Whether or not to retain the graph
|
||||
during array evaluation. If left unspecified the graph is retained
|
||||
only if saving is done in a function transformation. Default: ``None``
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"savez",
|
||||
@ -2999,11 +2993,8 @@ void init_ops(py::module_& m) {
|
||||
&mlx_save_safetensor_helper,
|
||||
"file"_a,
|
||||
"arrays"_a,
|
||||
py::pos_only(),
|
||||
"retain_graph"_a = std::nullopt,
|
||||
py::kw_only(),
|
||||
R"pbdoc(
|
||||
save_safetensors(file: str, arrays: Dict[str, array], /, retain_graph: Optional[bool] = None)
|
||||
save_safetensors(file: str, arrays: Dict[str, array])
|
||||
|
||||
Save array(s) to a binary file in ``.safetensors`` format.
|
||||
|
||||
@ -3012,9 +3003,6 @@ void init_ops(py::module_& m) {
|
||||
Args:
|
||||
file (file, str): File in which the array is saved>
|
||||
arrays (dict(str, array)): The dictionary of names to arrays to be saved.
|
||||
retain_graph (bool, optional): Whether or not to retain the graph
|
||||
during array evaluation. If left unspecified the graph is retained
|
||||
only if saving is done in a function transformation. Default: ``None``.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"where",
|
||||
|
@ -440,11 +440,10 @@ auto py_vmap(
|
||||
void init_transforms(py::module_& m) {
|
||||
m.def(
|
||||
"eval",
|
||||
[](const py::args& args, bool retain_graph) {
|
||||
[](const py::args& args) {
|
||||
std::vector<array> arrays = tree_flatten(args);
|
||||
eval(arrays, retain_graph);
|
||||
eval(arrays);
|
||||
},
|
||||
"retain_graph"_a = false,
|
||||
R"pbdoc(
|
||||
Evaluate an :class:`array` or tree of :class:`array`.
|
||||
|
||||
@ -453,9 +452,6 @@ void init_transforms(py::module_& m) {
|
||||
or a tree of arrays. If a tree is given the nodes can be a Python
|
||||
:class:`list`, :class:`tuple` or :class:`dict` but the leafs must all be
|
||||
an :class:`array`.
|
||||
retain_graph (bool): Indicate that the graph structure should be
|
||||
preserved. This option is intended to enable function transforms
|
||||
which contain control flow based on the value of an array.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"jvp",
|
||||
|
@ -259,6 +259,21 @@ class TestAutograd(mlx_tests.MLXTestCase):
|
||||
|
||||
self.assertTrue(mx.allclose(vjps[0], mx.zeros(shape_in)))
|
||||
|
||||
def test_update_state(self):
|
||||
y = mx.array([1.0])
|
||||
state = mx.zeros((2,))
|
||||
|
||||
def fn(y, x):
|
||||
nonlocal state
|
||||
x = y * x
|
||||
state = state + x
|
||||
return x.sum()
|
||||
|
||||
x = mx.ones((2,))
|
||||
mx.grad(fn)(y, x)
|
||||
mx.eval(state)
|
||||
self.assertTrue(mx.allclose(state, mx.ones((2,))))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
@ -15,18 +15,13 @@ class TestEval(mlx_tests.MLXTestCase):
|
||||
self.assertEqual(x.tolist(), [[1, 1], [1, 1]])
|
||||
|
||||
def test_retain_graph(self):
|
||||
def fun(x, retain_graph):
|
||||
def fun(x):
|
||||
y = 3 * x
|
||||
mx.eval(y, retain_graph=retain_graph)
|
||||
mx.eval(y)
|
||||
return 2 * y
|
||||
|
||||
dfun_dx_1 = mx.grad(partial(fun, retain_graph=False))
|
||||
dfun_dx_2 = mx.grad(partial(fun, retain_graph=True))
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
dfun_dx_1(mx.array(1.0))
|
||||
|
||||
y = dfun_dx_2(mx.array(1.0))
|
||||
dfun_dx = mx.grad(fun)
|
||||
y = dfun_dx(mx.array(1.0))
|
||||
self.assertEqual(y.item(), 6.0)
|
||||
|
||||
|
||||
|
@ -95,19 +95,14 @@ TEST_CASE("test jvp") {
|
||||
CHECK_EQ(dout[0].item<float>(), 4.0f);
|
||||
}
|
||||
|
||||
// Evaling in function without graph retention throws
|
||||
// Evaling in function while tracing performs graph retention
|
||||
{
|
||||
auto fun = [](const array& x) {
|
||||
auto y = 3 * x;
|
||||
eval(y);
|
||||
return 2 * y;
|
||||
};
|
||||
CHECK_THROWS(jvp(fun, array(1.0f), array(1.0f)));
|
||||
|
||||
// Ok with graph retention
|
||||
auto fun1 = [](const array& x) {
|
||||
auto y = 3 * x;
|
||||
eval({y}, true);
|
||||
eval(y);
|
||||
CHECK(y.is_evaled());
|
||||
CHECK(y.has_primitive());
|
||||
CHECK(y.is_tracer());
|
||||
return 2 * y;
|
||||
};
|
||||
CHECK_EQ(jvp(fun1, array(1.0f), array(1.0f)).second.item<float>(), 6.0f);
|
||||
@ -251,29 +246,27 @@ TEST_CASE("test grad") {
|
||||
}
|
||||
|
||||
{
|
||||
// Evaluating in the middle of the grad function throws
|
||||
// as it breaks the graph
|
||||
auto fn = [](array x) {
|
||||
x = x + 2.0f;
|
||||
eval(x);
|
||||
return square(x);
|
||||
};
|
||||
CHECK_THROWS(grad(fn)(array(1.0f)));
|
||||
|
||||
// Ok since the output is independent of y
|
||||
// No graph retention since the output is independent of y
|
||||
auto y = ones({3, 3});
|
||||
auto fn1 = [y](array x) {
|
||||
x = x + 2.0f;
|
||||
eval(y);
|
||||
CHECK(x.is_tracer());
|
||||
CHECK(!y.is_tracer());
|
||||
CHECK(y.is_evaled());
|
||||
CHECK(!y.has_primitive());
|
||||
return square(x);
|
||||
};
|
||||
auto dfdx = grad(fn1)(array(1.0f));
|
||||
CHECK_EQ(dfdx.item<float>(), 6.0f);
|
||||
|
||||
// Retain the graph to avoid breaking it
|
||||
// Graph automatically retained to compute the grad
|
||||
auto fn2 = [](array x) {
|
||||
x = x + 2.0f;
|
||||
eval({x}, true);
|
||||
eval(x);
|
||||
CHECK(x.is_tracer());
|
||||
CHECK(x.is_evaled());
|
||||
CHECK(x.has_primitive());
|
||||
return square(x);
|
||||
};
|
||||
dfdx = grad(fn2)(array(1.0f));
|
||||
@ -283,7 +276,8 @@ TEST_CASE("test grad") {
|
||||
// Control flow in grad computation
|
||||
{
|
||||
auto fn = [](array x) {
|
||||
if (x.item<float>(true) > 1) {
|
||||
x = x + array(2.0f);
|
||||
if (x.item<float>() > 3) {
|
||||
return square(x);
|
||||
} else {
|
||||
return 4 * x;
|
||||
@ -294,7 +288,7 @@ TEST_CASE("test grad") {
|
||||
CHECK_EQ(dfdx.item<float>(), 4.0f);
|
||||
|
||||
dfdx = grad(fn)(array(1.5f));
|
||||
CHECK_EQ(dfdx.item<float>(), 3.0f);
|
||||
CHECK_EQ(dfdx.item<float>(), 7.0f);
|
||||
}
|
||||
|
||||
// Grad with multiple inputs
|
||||
@ -1192,3 +1186,19 @@ TEST_CASE("test scan grads") {
|
||||
CHECK(array_equal(out, expected).item<bool>());
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("test update state") {
|
||||
auto y = array({1.0});
|
||||
auto x = array({1.0, 1.0});
|
||||
auto state = array({0.0, 0.0});
|
||||
auto fn = [&state, &x](array y) {
|
||||
x = y * x;
|
||||
state = state + x;
|
||||
return sum(x);
|
||||
};
|
||||
grad(fn)(y);
|
||||
eval(state);
|
||||
CHECK(!state.has_primitive());
|
||||
CHECK(state.is_evaled());
|
||||
CHECK(array_equal(state, array({1.0, 1.0})).item<bool>());
|
||||
}
|
||||
|
@ -48,36 +48,36 @@ TEST_CASE("test eval multiple") {
|
||||
CHECK(array_equal(b, full({10, 10}, 0.0f)).item<bool>());
|
||||
}
|
||||
|
||||
TEST_CASE("test eval with tracer") {
|
||||
TEST_CASE("test eval with tracer when not tracing") {
|
||||
// Since we are not tracing it doesn't matter that the array flags are
|
||||
// tracers they will always be detached.
|
||||
auto x = array(1);
|
||||
x.set_tracer(true);
|
||||
|
||||
// Ok, x is not a node
|
||||
CHECK(!x.is_tracer());
|
||||
eval(x);
|
||||
CHECK(!x.has_primitive());
|
||||
CHECK(x.is_evaled());
|
||||
|
||||
x = ones({2, 3});
|
||||
x.set_tracer(true);
|
||||
CHECK_THROWS(eval(x));
|
||||
|
||||
// Ok retain_graph=true
|
||||
eval({x}, true);
|
||||
|
||||
// Make sure all arguments are checked
|
||||
auto y = ones({2, 3});
|
||||
CHECK_THROWS(eval(x, y));
|
||||
eval(x);
|
||||
CHECK(!x.has_primitive());
|
||||
CHECK(x.is_evaled());
|
||||
}
|
||||
|
||||
TEST_CASE("test eval graph retention") {
|
||||
TEST_CASE("test eval graph retention when not tracing") {
|
||||
// Since we are not tracing it doesn't matter that the array flags are
|
||||
// tracers they will always be detached.
|
||||
auto x = array(1);
|
||||
x.set_tracer(true);
|
||||
auto y = array(2);
|
||||
auto z = x + y;
|
||||
eval({z}, true);
|
||||
CHECK(z.has_primitive());
|
||||
CHECK(z.is_evaled());
|
||||
CHECK_EQ(z.item<int>(true), 3);
|
||||
CHECK(z.has_primitive());
|
||||
eval(z);
|
||||
CHECK(!z.has_primitive());
|
||||
CHECK(z.is_evaled());
|
||||
CHECK_EQ(z.item<int>(), 3);
|
||||
|
||||
z.set_tracer(false);
|
||||
CHECK_EQ(z.item<int>(), 3);
|
||||
CHECK(!z.has_primitive());
|
||||
CHECK(z.is_evaled());
|
||||
@ -85,13 +85,7 @@ TEST_CASE("test eval graph retention") {
|
||||
z = x + y;
|
||||
auto a = z + x;
|
||||
auto b = a + y;
|
||||
eval({b}, true);
|
||||
CHECK(z.has_primitive());
|
||||
CHECK(z.is_evaled());
|
||||
CHECK(a.has_primitive());
|
||||
CHECK(a.is_evaled());
|
||||
|
||||
eval({b}, false);
|
||||
eval(b);
|
||||
CHECK(!z.has_primitive());
|
||||
CHECK(z.is_evaled());
|
||||
CHECK(!a.has_primitive());
|
||||
|
@ -183,7 +183,7 @@ TEST_CASE("test vmap with eval") {
|
||||
auto fun2 = [](std::vector<array> inputs) {
|
||||
auto x = inputs[0] + 1;
|
||||
auto y = inputs[1] + 2;
|
||||
eval({x}, true);
|
||||
eval(x);
|
||||
auto out = add(x, y);
|
||||
return std::vector<array>{out};
|
||||
};
|
||||
|
Loading…
Reference in New Issue
Block a user