mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-19 19:38:16 +08:00
Removes the retain_graph
flag (#385)
* Adds global tracing flag * Removes retain_graph in favor of is_tracer
This commit is contained in:

committed by
GitHub

parent
449b43762e
commit
a611b0bc82
@@ -6,6 +6,7 @@
|
||||
#include "mlx/ops.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/transforms.h"
|
||||
#include "mlx/transforms_impl.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
@@ -21,6 +22,12 @@ std::pair<size_t, std::vector<size_t>> cum_prod(const std::vector<int>& shape) {
|
||||
return {cum_prod, strides};
|
||||
}
|
||||
|
||||
/** Return true if we are currently performing a function transformation in
|
||||
* order to keep the graph when evaluating tracer arrays. */
|
||||
bool in_tracing() {
|
||||
return detail::InTracing::in_tracing();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
array::array(const std::complex<float>& val, Dtype dtype /* = complex64 */)
|
||||
@@ -62,8 +69,12 @@ void array::detach() {
|
||||
array_desc_->primitive = nullptr;
|
||||
}
|
||||
|
||||
void array::eval(bool retain_graph /* = false */) {
|
||||
mlx::core::eval({*this}, retain_graph);
|
||||
void array::eval() {
|
||||
mlx::core::eval({*this});
|
||||
}
|
||||
|
||||
bool array::is_tracer() const {
|
||||
return array_desc_->is_tracer && in_tracing();
|
||||
}
|
||||
|
||||
void array::set_data(allocator::Buffer buffer, deleter_t d) {
|
||||
|
12
mlx/array.h
12
mlx/array.h
@@ -116,11 +116,11 @@ class array {
|
||||
};
|
||||
|
||||
/** Evaluate the array. */
|
||||
void eval(bool retain_graph = false);
|
||||
void eval();
|
||||
|
||||
/** Get the value from a scalar array. */
|
||||
template <typename T>
|
||||
T item(bool retain_graph = false);
|
||||
T item();
|
||||
|
||||
struct ArrayIterator {
|
||||
using iterator_category = std::random_access_iterator_tag;
|
||||
@@ -265,9 +265,7 @@ class array {
|
||||
array_desc_->is_tracer = is_tracer;
|
||||
}
|
||||
// Check if the array is a tracer array
|
||||
bool is_tracer() const {
|
||||
return array_desc_->is_tracer;
|
||||
}
|
||||
bool is_tracer() const;
|
||||
|
||||
void set_data(allocator::Buffer buffer, deleter_t d = allocator::free);
|
||||
|
||||
@@ -381,11 +379,11 @@ array::array(
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
T array::item(bool retain_graph /* = false */) {
|
||||
T array::item() {
|
||||
if (size() != 1) {
|
||||
throw std::invalid_argument("item can only be called on arrays of size 1.");
|
||||
}
|
||||
eval(retain_graph);
|
||||
eval();
|
||||
return *data<T>();
|
||||
}
|
||||
|
||||
|
@@ -46,39 +46,36 @@ MTL::CommandBuffer* increment_command_buffer(Stream s) {
|
||||
std::function<void()> make_task(
|
||||
array& arr,
|
||||
std::vector<std::shared_future<void>> deps,
|
||||
std::shared_ptr<std::promise<void>> p,
|
||||
bool retain_graph) {
|
||||
auto task =
|
||||
[retain_graph, arr, deps = std::move(deps), p = std::move(p)]() mutable {
|
||||
auto pool = new_scoped_memory_pool();
|
||||
for (auto& d : deps) {
|
||||
d.wait();
|
||||
}
|
||||
auto s = arr.primitive().stream();
|
||||
auto command_buffer = increment_command_buffer(s);
|
||||
arr.primitive().eval_gpu(arr.inputs(), arr);
|
||||
if (p) {
|
||||
metal::device(s.device).end_encoding(s.index);
|
||||
scheduler::notify_new_task(s);
|
||||
command_buffer->addCompletedHandler(
|
||||
[retain_graph, s, arr, p = std::move(p)](
|
||||
MTL::CommandBuffer*) mutable {
|
||||
if (!retain_graph) {
|
||||
arr.detach();
|
||||
}
|
||||
p->set_value();
|
||||
scheduler::notify_task_completion(s);
|
||||
});
|
||||
metal::device(s.device).commit_command_buffer(s.index);
|
||||
} else {
|
||||
command_buffer->addCompletedHandler(
|
||||
[retain_graph, s, arr](MTL::CommandBuffer*) mutable {
|
||||
if (!retain_graph) {
|
||||
arr.detach();
|
||||
}
|
||||
});
|
||||
}
|
||||
};
|
||||
std::shared_ptr<std::promise<void>> p) {
|
||||
auto task = [arr, deps = std::move(deps), p = std::move(p)]() mutable {
|
||||
auto pool = new_scoped_memory_pool();
|
||||
for (auto& d : deps) {
|
||||
d.wait();
|
||||
}
|
||||
auto s = arr.primitive().stream();
|
||||
auto command_buffer = increment_command_buffer(s);
|
||||
arr.primitive().eval_gpu(arr.inputs(), arr);
|
||||
if (p) {
|
||||
metal::device(s.device).end_encoding(s.index);
|
||||
scheduler::notify_new_task(s);
|
||||
command_buffer->addCompletedHandler(
|
||||
[s, arr, p = std::move(p)](MTL::CommandBuffer*) mutable {
|
||||
if (!arr.is_tracer()) {
|
||||
arr.detach();
|
||||
}
|
||||
p->set_value();
|
||||
scheduler::notify_task_completion(s);
|
||||
});
|
||||
metal::device(s.device).commit_command_buffer(s.index);
|
||||
} else {
|
||||
command_buffer->addCompletedHandler(
|
||||
[s, arr](MTL::CommandBuffer*) mutable {
|
||||
if (!arr.is_tracer()) {
|
||||
arr.detach();
|
||||
}
|
||||
});
|
||||
}
|
||||
};
|
||||
return task;
|
||||
}
|
||||
|
||||
|
@@ -25,7 +25,6 @@ std::shared_ptr<void> new_scoped_memory_pool();
|
||||
std::function<void()> make_task(
|
||||
array& arr,
|
||||
std::vector<std::shared_future<void>> deps,
|
||||
std::shared_ptr<std::promise<void>> p,
|
||||
bool retain_graph);
|
||||
std::shared_ptr<std::promise<void>> p);
|
||||
|
||||
} // namespace mlx::core::metal
|
||||
|
@@ -14,8 +14,7 @@ std::shared_ptr<void> new_scoped_memory_pool() {
|
||||
std::function<void()> make_task(
|
||||
array& arr,
|
||||
std::vector<std::shared_future<void>> deps,
|
||||
std::shared_ptr<std::promise<void>> p,
|
||||
bool retain_graph) {
|
||||
std::shared_ptr<std::promise<void>> p) {
|
||||
throw std::runtime_error(
|
||||
"[metal::make_task] Cannot make GPU task without metal backend");
|
||||
}
|
||||
|
@@ -40,11 +40,11 @@ inline bool is_big_endian_() {
|
||||
} // namespace
|
||||
|
||||
/** Save array to out stream in .npy format */
|
||||
void save(std::shared_ptr<io::Writer> out_stream, array a, bool retain_graph) {
|
||||
void save(std::shared_ptr<io::Writer> out_stream, array a) {
|
||||
////////////////////////////////////////////////////////
|
||||
// Check array
|
||||
|
||||
a.eval(retain_graph);
|
||||
a.eval();
|
||||
|
||||
if (a.nbytes() == 0) {
|
||||
throw std::invalid_argument("[save] cannot serialize an empty array");
|
||||
@@ -52,7 +52,7 @@ void save(std::shared_ptr<io::Writer> out_stream, array a, bool retain_graph) {
|
||||
|
||||
if (!(a.flags().row_contiguous || a.flags().col_contiguous)) {
|
||||
a = reshape(flatten(a), a.shape());
|
||||
a.eval(retain_graph);
|
||||
a.eval();
|
||||
}
|
||||
// Check once more in-case the above ops change
|
||||
if (!(a.flags().row_contiguous || a.flags().col_contiguous)) {
|
||||
@@ -127,7 +127,7 @@ void save(std::shared_ptr<io::Writer> out_stream, array a, bool retain_graph) {
|
||||
}
|
||||
|
||||
/** Save array to file in .npy format */
|
||||
void save(const std::string& file_, array a, bool retain_graph) {
|
||||
void save(const std::string& file_, array a) {
|
||||
// Open and check file
|
||||
std::string file = file_;
|
||||
|
||||
@@ -136,7 +136,7 @@ void save(const std::string& file_, array a, bool retain_graph) {
|
||||
file += ".npy";
|
||||
|
||||
// Serialize array
|
||||
save(std::make_shared<io::FileWriter>(file), a, retain_graph);
|
||||
save(std::make_shared<io::FileWriter>(file), a);
|
||||
}
|
||||
|
||||
/** Load array from reader in .npy format */
|
||||
|
@@ -111,4 +111,4 @@ class FileWriter : public Writer {
|
||||
};
|
||||
|
||||
} // namespace io
|
||||
} // namespace mlx::core
|
||||
} // namespace mlx::core
|
||||
|
@@ -125,8 +125,7 @@ std::unordered_map<std::string, array> load_safetensors(
|
||||
/** Save array to out stream in .npy format */
|
||||
void save_safetensors(
|
||||
std::shared_ptr<io::Writer> out_stream,
|
||||
std::unordered_map<std::string, array> a,
|
||||
std::optional<bool> retain_graph_) {
|
||||
std::unordered_map<std::string, array> a) {
|
||||
////////////////////////////////////////////////////////
|
||||
// Check file
|
||||
if (!out_stream->good() || !out_stream->is_open()) {
|
||||
@@ -142,8 +141,7 @@ void save_safetensors(
|
||||
});
|
||||
size_t offset = 0;
|
||||
for (auto& [key, arr] : a) {
|
||||
auto retain = retain_graph_.value_or(arr.is_tracer());
|
||||
arr.eval(retain);
|
||||
arr.eval();
|
||||
if (arr.nbytes() == 0) {
|
||||
throw std::invalid_argument(
|
||||
"[save_safetensors] cannot serialize an empty array key: " + key);
|
||||
@@ -152,7 +150,7 @@ void save_safetensors(
|
||||
// Try to make it row contiguous
|
||||
if (!arr.flags().row_contiguous) {
|
||||
arr = reshape(flatten(arr), arr.shape());
|
||||
arr.eval(retain);
|
||||
arr.eval();
|
||||
}
|
||||
|
||||
// Has to be row-major now but, check one more time in case
|
||||
@@ -181,8 +179,7 @@ void save_safetensors(
|
||||
|
||||
void save_safetensors(
|
||||
const std::string& file_,
|
||||
std::unordered_map<std::string, array> a,
|
||||
std::optional<bool> retain_graph) {
|
||||
std::unordered_map<std::string, array> a) {
|
||||
// Open and check file
|
||||
std::string file = file_;
|
||||
|
||||
@@ -192,7 +189,7 @@ void save_safetensors(
|
||||
file += ".safetensors";
|
||||
|
||||
// Serialize array
|
||||
save_safetensors(std::make_shared<io::FileWriter>(file), a, retain_graph);
|
||||
save_safetensors(std::make_shared<io::FileWriter>(file), a);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
13
mlx/ops.h
13
mlx/ops.h
@@ -1021,13 +1021,10 @@ array conv2d(
|
||||
/** Serialization operations */
|
||||
|
||||
/** Save array to out stream in .npy format */
|
||||
void save(
|
||||
std::shared_ptr<io::Writer> out_stream,
|
||||
array a,
|
||||
bool retain_graph = true);
|
||||
void save(std::shared_ptr<io::Writer> out_stream, array a);
|
||||
|
||||
/** Save array to file in .npy format */
|
||||
void save(const std::string& file, array a, bool retain_graph = true);
|
||||
void save(const std::string& file, array a);
|
||||
|
||||
/** Load array from reader in .npy format */
|
||||
array load(std::shared_ptr<io::Reader> in_stream, StreamOrDevice s = {});
|
||||
@@ -1091,10 +1088,8 @@ std::unordered_map<std::string, array> load_safetensors(
|
||||
|
||||
void save_safetensors(
|
||||
std::shared_ptr<io::Writer> in_stream,
|
||||
std::unordered_map<std::string, array>,
|
||||
std::optional<bool> retain_graph = std::nullopt);
|
||||
std::unordered_map<std::string, array>);
|
||||
void save_safetensors(
|
||||
const std::string& file,
|
||||
std::unordered_map<std::string, array>,
|
||||
std::optional<bool> retain_graph = std::nullopt);
|
||||
std::unordered_map<std::string, array>);
|
||||
} // namespace mlx::core
|
||||
|
@@ -19,6 +19,12 @@
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
// Initialize the static tracing counter from transforms_impl.h .
|
||||
//
|
||||
// This is used to implement the in_tracing() function the returns true if we
|
||||
// are currently under a function transformation.
|
||||
int detail::InTracing::tracing_counter{0};
|
||||
|
||||
void simplify(const std::vector<array>& outputs) {
|
||||
std::function<void(const array&)> recurse;
|
||||
std::queue<array> tape;
|
||||
@@ -154,16 +160,7 @@ void simplify(const std::vector<array>& outputs) {
|
||||
}
|
||||
}
|
||||
|
||||
void eval(const std::vector<array>& outputs, bool retain_graph /* = false */) {
|
||||
if (!retain_graph) {
|
||||
for (auto& out : outputs) {
|
||||
if (out.has_primitive() && out.is_tracer()) {
|
||||
throw std::invalid_argument(
|
||||
"[eval] Illegal to eval an array during "
|
||||
"function transform without graph retention.");
|
||||
}
|
||||
}
|
||||
}
|
||||
void eval(const std::vector<array>& outputs) {
|
||||
std::function<void(const array&)> recurse;
|
||||
std::queue<array> tape;
|
||||
std::unordered_set<std::uintptr_t> cache;
|
||||
@@ -185,7 +182,7 @@ void eval(const std::vector<array>& outputs, bool retain_graph /* = false */) {
|
||||
}
|
||||
}
|
||||
cache.insert(id);
|
||||
if (!a.is_evaled() || (!retain_graph && a.has_primitive())) {
|
||||
if (!a.is_evaled() || (!a.is_tracer() && a.has_primitive())) {
|
||||
if (!a.has_primitive()) {
|
||||
throw std::invalid_argument(
|
||||
"[eval] Attempting to eval an array without a primitive.");
|
||||
@@ -195,7 +192,7 @@ void eval(const std::vector<array>& outputs, bool retain_graph /* = false */) {
|
||||
};
|
||||
|
||||
for (auto& arr : outputs) {
|
||||
if (!arr.is_evaled() || (!retain_graph && arr.has_primitive())) {
|
||||
if (!arr.is_evaled() || (!arr.is_tracer() && arr.has_primitive())) {
|
||||
recurse(arr);
|
||||
// Insert a dependency for every output to synchronize
|
||||
// with at the end.
|
||||
@@ -209,7 +206,7 @@ void eval(const std::vector<array>& outputs, bool retain_graph /* = false */) {
|
||||
auto arr = std::move(tape.front());
|
||||
tape.pop();
|
||||
if (arr.is_evaled()) {
|
||||
if (!retain_graph && arr.has_primitive()) {
|
||||
if (!arr.is_tracer() && arr.has_primitive()) {
|
||||
arr.detach();
|
||||
}
|
||||
continue;
|
||||
@@ -233,12 +230,9 @@ void eval(const std::vector<array>& outputs, bool retain_graph /* = false */) {
|
||||
throw std::runtime_error("Metal GPU is not available.");
|
||||
}
|
||||
scheduler::enqueue(
|
||||
stream,
|
||||
metal::make_task(
|
||||
arr, std::move(arr_deps), std::move(p), retain_graph));
|
||||
stream, metal::make_task(arr, std::move(arr_deps), std::move(p)));
|
||||
} else {
|
||||
auto task = [retain_graph,
|
||||
arr,
|
||||
auto task = [arr,
|
||||
stream,
|
||||
arr_deps = std::move(arr_deps),
|
||||
p = std::move(p)]() mutable {
|
||||
@@ -247,7 +241,7 @@ void eval(const std::vector<array>& outputs, bool retain_graph /* = false */) {
|
||||
}
|
||||
scheduler::notify_new_task(stream);
|
||||
arr.primitive().eval_cpu(arr.inputs(), arr);
|
||||
if (!retain_graph) {
|
||||
if (!arr.is_tracer()) {
|
||||
arr.detach();
|
||||
}
|
||||
if (p) {
|
||||
@@ -269,6 +263,9 @@ std::pair<std::vector<array>, std::vector<array>> vjp(
|
||||
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotans) {
|
||||
// Set the global tracing flag.
|
||||
detail::InTracing in_tracing;
|
||||
|
||||
// Make tracers from given primals
|
||||
std::vector<array> primals_;
|
||||
for (auto& p : primals) {
|
||||
@@ -425,6 +422,9 @@ std::pair<std::vector<array>, std::vector<array>> jvp(
|
||||
}
|
||||
}
|
||||
|
||||
// Set the global tracing flag.
|
||||
detail::InTracing in_tracing;
|
||||
|
||||
std::vector<array> primals_;
|
||||
for (auto& p : primals) {
|
||||
auto s = p.has_primitive() ? p.primitive().stream()
|
||||
@@ -578,6 +578,9 @@ std::pair<std::vector<array>, std::vector<array>> vmap_trace(
|
||||
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& in_axes) {
|
||||
// Set the global tracing flag
|
||||
InTracing in_tracing;
|
||||
|
||||
if (in_axes.size() != inputs.size()) {
|
||||
throw std::invalid_argument(
|
||||
"[vmap] The number of in axes must match the number of inputs.");
|
||||
|
@@ -14,11 +14,11 @@ void simplify(Arrays... outputs) {
|
||||
simplify(std::vector<array>{std::forward<Arrays>(outputs)...});
|
||||
}
|
||||
|
||||
void eval(const std::vector<array>& outputs, bool retain_graph = false);
|
||||
void eval(const std::vector<array>& outputs);
|
||||
|
||||
template <typename... Arrays>
|
||||
void eval(Arrays... outputs) {
|
||||
eval(std::vector<array>{std::forward<Arrays>(outputs)...}, false);
|
||||
eval(std::vector<array>{std::forward<Arrays>(outputs)...});
|
||||
}
|
||||
|
||||
/**
|
||||
|
@@ -14,4 +14,23 @@ std::vector<array> vmap_replace(
|
||||
const std::vector<int>& in_axes,
|
||||
const std::vector<int>& out_axes);
|
||||
|
||||
// Create an InTracing object during tracing operations to signify to the rest
|
||||
// of the codebase that we are during tracing so evals should not throw away
|
||||
// the graph.
|
||||
struct InTracing {
|
||||
InTracing() {
|
||||
tracing_counter++;
|
||||
}
|
||||
~InTracing() {
|
||||
tracing_counter--;
|
||||
}
|
||||
|
||||
static bool in_tracing() {
|
||||
return tracing_counter > 0;
|
||||
}
|
||||
|
||||
private:
|
||||
static int tracing_counter;
|
||||
};
|
||||
|
||||
} // namespace mlx::core::detail
|
||||
|
Reference in New Issue
Block a user