Removes the retain_graph flag (#385)

* Adds global tracing flag
* Removes retain_graph in favor of is_tracer
This commit is contained in:
Angelos Katharopoulos
2024-01-07 15:16:51 -08:00
committed by GitHub
parent 449b43762e
commit a611b0bc82
22 changed files with 209 additions and 207 deletions

View File

@@ -6,6 +6,7 @@
#include "mlx/ops.h"
#include "mlx/primitives.h"
#include "mlx/transforms.h"
#include "mlx/transforms_impl.h"
namespace mlx::core {
@@ -21,6 +22,12 @@ std::pair<size_t, std::vector<size_t>> cum_prod(const std::vector<int>& shape) {
return {cum_prod, strides};
}
/** Return true if we are currently performing a function transformation in
* order to keep the graph when evaluating tracer arrays. */
bool in_tracing() {
return detail::InTracing::in_tracing();
}
} // namespace
array::array(const std::complex<float>& val, Dtype dtype /* = complex64 */)
@@ -62,8 +69,12 @@ void array::detach() {
array_desc_->primitive = nullptr;
}
void array::eval(bool retain_graph /* = false */) {
mlx::core::eval({*this}, retain_graph);
void array::eval() {
mlx::core::eval({*this});
}
bool array::is_tracer() const {
return array_desc_->is_tracer && in_tracing();
}
void array::set_data(allocator::Buffer buffer, deleter_t d) {

View File

@@ -116,11 +116,11 @@ class array {
};
/** Evaluate the array. */
void eval(bool retain_graph = false);
void eval();
/** Get the value from a scalar array. */
template <typename T>
T item(bool retain_graph = false);
T item();
struct ArrayIterator {
using iterator_category = std::random_access_iterator_tag;
@@ -265,9 +265,7 @@ class array {
array_desc_->is_tracer = is_tracer;
}
// Check if the array is a tracer array
bool is_tracer() const {
return array_desc_->is_tracer;
}
bool is_tracer() const;
void set_data(allocator::Buffer buffer, deleter_t d = allocator::free);
@@ -381,11 +379,11 @@ array::array(
}
template <typename T>
T array::item(bool retain_graph /* = false */) {
T array::item() {
if (size() != 1) {
throw std::invalid_argument("item can only be called on arrays of size 1.");
}
eval(retain_graph);
eval();
return *data<T>();
}

View File

@@ -46,39 +46,36 @@ MTL::CommandBuffer* increment_command_buffer(Stream s) {
std::function<void()> make_task(
array& arr,
std::vector<std::shared_future<void>> deps,
std::shared_ptr<std::promise<void>> p,
bool retain_graph) {
auto task =
[retain_graph, arr, deps = std::move(deps), p = std::move(p)]() mutable {
auto pool = new_scoped_memory_pool();
for (auto& d : deps) {
d.wait();
}
auto s = arr.primitive().stream();
auto command_buffer = increment_command_buffer(s);
arr.primitive().eval_gpu(arr.inputs(), arr);
if (p) {
metal::device(s.device).end_encoding(s.index);
scheduler::notify_new_task(s);
command_buffer->addCompletedHandler(
[retain_graph, s, arr, p = std::move(p)](
MTL::CommandBuffer*) mutable {
if (!retain_graph) {
arr.detach();
}
p->set_value();
scheduler::notify_task_completion(s);
});
metal::device(s.device).commit_command_buffer(s.index);
} else {
command_buffer->addCompletedHandler(
[retain_graph, s, arr](MTL::CommandBuffer*) mutable {
if (!retain_graph) {
arr.detach();
}
});
}
};
std::shared_ptr<std::promise<void>> p) {
auto task = [arr, deps = std::move(deps), p = std::move(p)]() mutable {
auto pool = new_scoped_memory_pool();
for (auto& d : deps) {
d.wait();
}
auto s = arr.primitive().stream();
auto command_buffer = increment_command_buffer(s);
arr.primitive().eval_gpu(arr.inputs(), arr);
if (p) {
metal::device(s.device).end_encoding(s.index);
scheduler::notify_new_task(s);
command_buffer->addCompletedHandler(
[s, arr, p = std::move(p)](MTL::CommandBuffer*) mutable {
if (!arr.is_tracer()) {
arr.detach();
}
p->set_value();
scheduler::notify_task_completion(s);
});
metal::device(s.device).commit_command_buffer(s.index);
} else {
command_buffer->addCompletedHandler(
[s, arr](MTL::CommandBuffer*) mutable {
if (!arr.is_tracer()) {
arr.detach();
}
});
}
};
return task;
}

View File

@@ -25,7 +25,6 @@ std::shared_ptr<void> new_scoped_memory_pool();
std::function<void()> make_task(
array& arr,
std::vector<std::shared_future<void>> deps,
std::shared_ptr<std::promise<void>> p,
bool retain_graph);
std::shared_ptr<std::promise<void>> p);
} // namespace mlx::core::metal

View File

@@ -14,8 +14,7 @@ std::shared_ptr<void> new_scoped_memory_pool() {
std::function<void()> make_task(
array& arr,
std::vector<std::shared_future<void>> deps,
std::shared_ptr<std::promise<void>> p,
bool retain_graph) {
std::shared_ptr<std::promise<void>> p) {
throw std::runtime_error(
"[metal::make_task] Cannot make GPU task without metal backend");
}

View File

@@ -40,11 +40,11 @@ inline bool is_big_endian_() {
} // namespace
/** Save array to out stream in .npy format */
void save(std::shared_ptr<io::Writer> out_stream, array a, bool retain_graph) {
void save(std::shared_ptr<io::Writer> out_stream, array a) {
////////////////////////////////////////////////////////
// Check array
a.eval(retain_graph);
a.eval();
if (a.nbytes() == 0) {
throw std::invalid_argument("[save] cannot serialize an empty array");
@@ -52,7 +52,7 @@ void save(std::shared_ptr<io::Writer> out_stream, array a, bool retain_graph) {
if (!(a.flags().row_contiguous || a.flags().col_contiguous)) {
a = reshape(flatten(a), a.shape());
a.eval(retain_graph);
a.eval();
}
// Check once more in-case the above ops change
if (!(a.flags().row_contiguous || a.flags().col_contiguous)) {
@@ -127,7 +127,7 @@ void save(std::shared_ptr<io::Writer> out_stream, array a, bool retain_graph) {
}
/** Save array to file in .npy format */
void save(const std::string& file_, array a, bool retain_graph) {
void save(const std::string& file_, array a) {
// Open and check file
std::string file = file_;
@@ -136,7 +136,7 @@ void save(const std::string& file_, array a, bool retain_graph) {
file += ".npy";
// Serialize array
save(std::make_shared<io::FileWriter>(file), a, retain_graph);
save(std::make_shared<io::FileWriter>(file), a);
}
/** Load array from reader in .npy format */

View File

@@ -111,4 +111,4 @@ class FileWriter : public Writer {
};
} // namespace io
} // namespace mlx::core
} // namespace mlx::core

View File

@@ -125,8 +125,7 @@ std::unordered_map<std::string, array> load_safetensors(
/** Save array to out stream in .npy format */
void save_safetensors(
std::shared_ptr<io::Writer> out_stream,
std::unordered_map<std::string, array> a,
std::optional<bool> retain_graph_) {
std::unordered_map<std::string, array> a) {
////////////////////////////////////////////////////////
// Check file
if (!out_stream->good() || !out_stream->is_open()) {
@@ -142,8 +141,7 @@ void save_safetensors(
});
size_t offset = 0;
for (auto& [key, arr] : a) {
auto retain = retain_graph_.value_or(arr.is_tracer());
arr.eval(retain);
arr.eval();
if (arr.nbytes() == 0) {
throw std::invalid_argument(
"[save_safetensors] cannot serialize an empty array key: " + key);
@@ -152,7 +150,7 @@ void save_safetensors(
// Try to make it row contiguous
if (!arr.flags().row_contiguous) {
arr = reshape(flatten(arr), arr.shape());
arr.eval(retain);
arr.eval();
}
// Has to be row-major now but, check one more time in case
@@ -181,8 +179,7 @@ void save_safetensors(
void save_safetensors(
const std::string& file_,
std::unordered_map<std::string, array> a,
std::optional<bool> retain_graph) {
std::unordered_map<std::string, array> a) {
// Open and check file
std::string file = file_;
@@ -192,7 +189,7 @@ void save_safetensors(
file += ".safetensors";
// Serialize array
save_safetensors(std::make_shared<io::FileWriter>(file), a, retain_graph);
save_safetensors(std::make_shared<io::FileWriter>(file), a);
}
} // namespace mlx::core

View File

@@ -1021,13 +1021,10 @@ array conv2d(
/** Serialization operations */
/** Save array to out stream in .npy format */
void save(
std::shared_ptr<io::Writer> out_stream,
array a,
bool retain_graph = true);
void save(std::shared_ptr<io::Writer> out_stream, array a);
/** Save array to file in .npy format */
void save(const std::string& file, array a, bool retain_graph = true);
void save(const std::string& file, array a);
/** Load array from reader in .npy format */
array load(std::shared_ptr<io::Reader> in_stream, StreamOrDevice s = {});
@@ -1091,10 +1088,8 @@ std::unordered_map<std::string, array> load_safetensors(
void save_safetensors(
std::shared_ptr<io::Writer> in_stream,
std::unordered_map<std::string, array>,
std::optional<bool> retain_graph = std::nullopt);
std::unordered_map<std::string, array>);
void save_safetensors(
const std::string& file,
std::unordered_map<std::string, array>,
std::optional<bool> retain_graph = std::nullopt);
std::unordered_map<std::string, array>);
} // namespace mlx::core

View File

@@ -19,6 +19,12 @@
namespace mlx::core {
// Initialize the static tracing counter from transforms_impl.h .
//
// This is used to implement the in_tracing() function the returns true if we
// are currently under a function transformation.
int detail::InTracing::tracing_counter{0};
void simplify(const std::vector<array>& outputs) {
std::function<void(const array&)> recurse;
std::queue<array> tape;
@@ -154,16 +160,7 @@ void simplify(const std::vector<array>& outputs) {
}
}
void eval(const std::vector<array>& outputs, bool retain_graph /* = false */) {
if (!retain_graph) {
for (auto& out : outputs) {
if (out.has_primitive() && out.is_tracer()) {
throw std::invalid_argument(
"[eval] Illegal to eval an array during "
"function transform without graph retention.");
}
}
}
void eval(const std::vector<array>& outputs) {
std::function<void(const array&)> recurse;
std::queue<array> tape;
std::unordered_set<std::uintptr_t> cache;
@@ -185,7 +182,7 @@ void eval(const std::vector<array>& outputs, bool retain_graph /* = false */) {
}
}
cache.insert(id);
if (!a.is_evaled() || (!retain_graph && a.has_primitive())) {
if (!a.is_evaled() || (!a.is_tracer() && a.has_primitive())) {
if (!a.has_primitive()) {
throw std::invalid_argument(
"[eval] Attempting to eval an array without a primitive.");
@@ -195,7 +192,7 @@ void eval(const std::vector<array>& outputs, bool retain_graph /* = false */) {
};
for (auto& arr : outputs) {
if (!arr.is_evaled() || (!retain_graph && arr.has_primitive())) {
if (!arr.is_evaled() || (!arr.is_tracer() && arr.has_primitive())) {
recurse(arr);
// Insert a dependency for every output to synchronize
// with at the end.
@@ -209,7 +206,7 @@ void eval(const std::vector<array>& outputs, bool retain_graph /* = false */) {
auto arr = std::move(tape.front());
tape.pop();
if (arr.is_evaled()) {
if (!retain_graph && arr.has_primitive()) {
if (!arr.is_tracer() && arr.has_primitive()) {
arr.detach();
}
continue;
@@ -233,12 +230,9 @@ void eval(const std::vector<array>& outputs, bool retain_graph /* = false */) {
throw std::runtime_error("Metal GPU is not available.");
}
scheduler::enqueue(
stream,
metal::make_task(
arr, std::move(arr_deps), std::move(p), retain_graph));
stream, metal::make_task(arr, std::move(arr_deps), std::move(p)));
} else {
auto task = [retain_graph,
arr,
auto task = [arr,
stream,
arr_deps = std::move(arr_deps),
p = std::move(p)]() mutable {
@@ -247,7 +241,7 @@ void eval(const std::vector<array>& outputs, bool retain_graph /* = false */) {
}
scheduler::notify_new_task(stream);
arr.primitive().eval_cpu(arr.inputs(), arr);
if (!retain_graph) {
if (!arr.is_tracer()) {
arr.detach();
}
if (p) {
@@ -269,6 +263,9 @@ std::pair<std::vector<array>, std::vector<array>> vjp(
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
const std::vector<array>& primals,
const std::vector<array>& cotans) {
// Set the global tracing flag.
detail::InTracing in_tracing;
// Make tracers from given primals
std::vector<array> primals_;
for (auto& p : primals) {
@@ -425,6 +422,9 @@ std::pair<std::vector<array>, std::vector<array>> jvp(
}
}
// Set the global tracing flag.
detail::InTracing in_tracing;
std::vector<array> primals_;
for (auto& p : primals) {
auto s = p.has_primitive() ? p.primitive().stream()
@@ -578,6 +578,9 @@ std::pair<std::vector<array>, std::vector<array>> vmap_trace(
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
const std::vector<array>& inputs,
const std::vector<int>& in_axes) {
// Set the global tracing flag
InTracing in_tracing;
if (in_axes.size() != inputs.size()) {
throw std::invalid_argument(
"[vmap] The number of in axes must match the number of inputs.");

View File

@@ -14,11 +14,11 @@ void simplify(Arrays... outputs) {
simplify(std::vector<array>{std::forward<Arrays>(outputs)...});
}
void eval(const std::vector<array>& outputs, bool retain_graph = false);
void eval(const std::vector<array>& outputs);
template <typename... Arrays>
void eval(Arrays... outputs) {
eval(std::vector<array>{std::forward<Arrays>(outputs)...}, false);
eval(std::vector<array>{std::forward<Arrays>(outputs)...});
}
/**

View File

@@ -14,4 +14,23 @@ std::vector<array> vmap_replace(
const std::vector<int>& in_axes,
const std::vector<int>& out_axes);
// Create an InTracing object during tracing operations to signify to the rest
// of the codebase that we are during tracing so evals should not throw away
// the graph.
struct InTracing {
InTracing() {
tracing_counter++;
}
~InTracing() {
tracing_counter--;
}
static bool in_tracing() {
return tracing_counter > 0;
}
private:
static int tracing_counter;
};
} // namespace mlx::core::detail