mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-04 16:21:14 +08:00
Removes the retain_graph
flag (#385)
* Adds global tracing flag * Removes retain_graph in favor of is_tracer
This commit is contained in:
parent
449b43762e
commit
a611b0bc82
@ -6,6 +6,7 @@
|
|||||||
#include "mlx/ops.h"
|
#include "mlx/ops.h"
|
||||||
#include "mlx/primitives.h"
|
#include "mlx/primitives.h"
|
||||||
#include "mlx/transforms.h"
|
#include "mlx/transforms.h"
|
||||||
|
#include "mlx/transforms_impl.h"
|
||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
@ -21,6 +22,12 @@ std::pair<size_t, std::vector<size_t>> cum_prod(const std::vector<int>& shape) {
|
|||||||
return {cum_prod, strides};
|
return {cum_prod, strides};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/** Return true if we are currently performing a function transformation in
|
||||||
|
* order to keep the graph when evaluating tracer arrays. */
|
||||||
|
bool in_tracing() {
|
||||||
|
return detail::InTracing::in_tracing();
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
array::array(const std::complex<float>& val, Dtype dtype /* = complex64 */)
|
array::array(const std::complex<float>& val, Dtype dtype /* = complex64 */)
|
||||||
@ -62,8 +69,12 @@ void array::detach() {
|
|||||||
array_desc_->primitive = nullptr;
|
array_desc_->primitive = nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
void array::eval(bool retain_graph /* = false */) {
|
void array::eval() {
|
||||||
mlx::core::eval({*this}, retain_graph);
|
mlx::core::eval({*this});
|
||||||
|
}
|
||||||
|
|
||||||
|
bool array::is_tracer() const {
|
||||||
|
return array_desc_->is_tracer && in_tracing();
|
||||||
}
|
}
|
||||||
|
|
||||||
void array::set_data(allocator::Buffer buffer, deleter_t d) {
|
void array::set_data(allocator::Buffer buffer, deleter_t d) {
|
||||||
|
12
mlx/array.h
12
mlx/array.h
@ -116,11 +116,11 @@ class array {
|
|||||||
};
|
};
|
||||||
|
|
||||||
/** Evaluate the array. */
|
/** Evaluate the array. */
|
||||||
void eval(bool retain_graph = false);
|
void eval();
|
||||||
|
|
||||||
/** Get the value from a scalar array. */
|
/** Get the value from a scalar array. */
|
||||||
template <typename T>
|
template <typename T>
|
||||||
T item(bool retain_graph = false);
|
T item();
|
||||||
|
|
||||||
struct ArrayIterator {
|
struct ArrayIterator {
|
||||||
using iterator_category = std::random_access_iterator_tag;
|
using iterator_category = std::random_access_iterator_tag;
|
||||||
@ -265,9 +265,7 @@ class array {
|
|||||||
array_desc_->is_tracer = is_tracer;
|
array_desc_->is_tracer = is_tracer;
|
||||||
}
|
}
|
||||||
// Check if the array is a tracer array
|
// Check if the array is a tracer array
|
||||||
bool is_tracer() const {
|
bool is_tracer() const;
|
||||||
return array_desc_->is_tracer;
|
|
||||||
}
|
|
||||||
|
|
||||||
void set_data(allocator::Buffer buffer, deleter_t d = allocator::free);
|
void set_data(allocator::Buffer buffer, deleter_t d = allocator::free);
|
||||||
|
|
||||||
@ -381,11 +379,11 @@ array::array(
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
T array::item(bool retain_graph /* = false */) {
|
T array::item() {
|
||||||
if (size() != 1) {
|
if (size() != 1) {
|
||||||
throw std::invalid_argument("item can only be called on arrays of size 1.");
|
throw std::invalid_argument("item can only be called on arrays of size 1.");
|
||||||
}
|
}
|
||||||
eval(retain_graph);
|
eval();
|
||||||
return *data<T>();
|
return *data<T>();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -46,10 +46,8 @@ MTL::CommandBuffer* increment_command_buffer(Stream s) {
|
|||||||
std::function<void()> make_task(
|
std::function<void()> make_task(
|
||||||
array& arr,
|
array& arr,
|
||||||
std::vector<std::shared_future<void>> deps,
|
std::vector<std::shared_future<void>> deps,
|
||||||
std::shared_ptr<std::promise<void>> p,
|
std::shared_ptr<std::promise<void>> p) {
|
||||||
bool retain_graph) {
|
auto task = [arr, deps = std::move(deps), p = std::move(p)]() mutable {
|
||||||
auto task =
|
|
||||||
[retain_graph, arr, deps = std::move(deps), p = std::move(p)]() mutable {
|
|
||||||
auto pool = new_scoped_memory_pool();
|
auto pool = new_scoped_memory_pool();
|
||||||
for (auto& d : deps) {
|
for (auto& d : deps) {
|
||||||
d.wait();
|
d.wait();
|
||||||
@ -61,9 +59,8 @@ std::function<void()> make_task(
|
|||||||
metal::device(s.device).end_encoding(s.index);
|
metal::device(s.device).end_encoding(s.index);
|
||||||
scheduler::notify_new_task(s);
|
scheduler::notify_new_task(s);
|
||||||
command_buffer->addCompletedHandler(
|
command_buffer->addCompletedHandler(
|
||||||
[retain_graph, s, arr, p = std::move(p)](
|
[s, arr, p = std::move(p)](MTL::CommandBuffer*) mutable {
|
||||||
MTL::CommandBuffer*) mutable {
|
if (!arr.is_tracer()) {
|
||||||
if (!retain_graph) {
|
|
||||||
arr.detach();
|
arr.detach();
|
||||||
}
|
}
|
||||||
p->set_value();
|
p->set_value();
|
||||||
@ -72,8 +69,8 @@ std::function<void()> make_task(
|
|||||||
metal::device(s.device).commit_command_buffer(s.index);
|
metal::device(s.device).commit_command_buffer(s.index);
|
||||||
} else {
|
} else {
|
||||||
command_buffer->addCompletedHandler(
|
command_buffer->addCompletedHandler(
|
||||||
[retain_graph, s, arr](MTL::CommandBuffer*) mutable {
|
[s, arr](MTL::CommandBuffer*) mutable {
|
||||||
if (!retain_graph) {
|
if (!arr.is_tracer()) {
|
||||||
arr.detach();
|
arr.detach();
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
@ -25,7 +25,6 @@ std::shared_ptr<void> new_scoped_memory_pool();
|
|||||||
std::function<void()> make_task(
|
std::function<void()> make_task(
|
||||||
array& arr,
|
array& arr,
|
||||||
std::vector<std::shared_future<void>> deps,
|
std::vector<std::shared_future<void>> deps,
|
||||||
std::shared_ptr<std::promise<void>> p,
|
std::shared_ptr<std::promise<void>> p);
|
||||||
bool retain_graph);
|
|
||||||
|
|
||||||
} // namespace mlx::core::metal
|
} // namespace mlx::core::metal
|
||||||
|
@ -14,8 +14,7 @@ std::shared_ptr<void> new_scoped_memory_pool() {
|
|||||||
std::function<void()> make_task(
|
std::function<void()> make_task(
|
||||||
array& arr,
|
array& arr,
|
||||||
std::vector<std::shared_future<void>> deps,
|
std::vector<std::shared_future<void>> deps,
|
||||||
std::shared_ptr<std::promise<void>> p,
|
std::shared_ptr<std::promise<void>> p) {
|
||||||
bool retain_graph) {
|
|
||||||
throw std::runtime_error(
|
throw std::runtime_error(
|
||||||
"[metal::make_task] Cannot make GPU task without metal backend");
|
"[metal::make_task] Cannot make GPU task without metal backend");
|
||||||
}
|
}
|
||||||
|
@ -40,11 +40,11 @@ inline bool is_big_endian_() {
|
|||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
/** Save array to out stream in .npy format */
|
/** Save array to out stream in .npy format */
|
||||||
void save(std::shared_ptr<io::Writer> out_stream, array a, bool retain_graph) {
|
void save(std::shared_ptr<io::Writer> out_stream, array a) {
|
||||||
////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////
|
||||||
// Check array
|
// Check array
|
||||||
|
|
||||||
a.eval(retain_graph);
|
a.eval();
|
||||||
|
|
||||||
if (a.nbytes() == 0) {
|
if (a.nbytes() == 0) {
|
||||||
throw std::invalid_argument("[save] cannot serialize an empty array");
|
throw std::invalid_argument("[save] cannot serialize an empty array");
|
||||||
@ -52,7 +52,7 @@ void save(std::shared_ptr<io::Writer> out_stream, array a, bool retain_graph) {
|
|||||||
|
|
||||||
if (!(a.flags().row_contiguous || a.flags().col_contiguous)) {
|
if (!(a.flags().row_contiguous || a.flags().col_contiguous)) {
|
||||||
a = reshape(flatten(a), a.shape());
|
a = reshape(flatten(a), a.shape());
|
||||||
a.eval(retain_graph);
|
a.eval();
|
||||||
}
|
}
|
||||||
// Check once more in-case the above ops change
|
// Check once more in-case the above ops change
|
||||||
if (!(a.flags().row_contiguous || a.flags().col_contiguous)) {
|
if (!(a.flags().row_contiguous || a.flags().col_contiguous)) {
|
||||||
@ -127,7 +127,7 @@ void save(std::shared_ptr<io::Writer> out_stream, array a, bool retain_graph) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/** Save array to file in .npy format */
|
/** Save array to file in .npy format */
|
||||||
void save(const std::string& file_, array a, bool retain_graph) {
|
void save(const std::string& file_, array a) {
|
||||||
// Open and check file
|
// Open and check file
|
||||||
std::string file = file_;
|
std::string file = file_;
|
||||||
|
|
||||||
@ -136,7 +136,7 @@ void save(const std::string& file_, array a, bool retain_graph) {
|
|||||||
file += ".npy";
|
file += ".npy";
|
||||||
|
|
||||||
// Serialize array
|
// Serialize array
|
||||||
save(std::make_shared<io::FileWriter>(file), a, retain_graph);
|
save(std::make_shared<io::FileWriter>(file), a);
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Load array from reader in .npy format */
|
/** Load array from reader in .npy format */
|
||||||
|
@ -125,8 +125,7 @@ std::unordered_map<std::string, array> load_safetensors(
|
|||||||
/** Save array to out stream in .npy format */
|
/** Save array to out stream in .npy format */
|
||||||
void save_safetensors(
|
void save_safetensors(
|
||||||
std::shared_ptr<io::Writer> out_stream,
|
std::shared_ptr<io::Writer> out_stream,
|
||||||
std::unordered_map<std::string, array> a,
|
std::unordered_map<std::string, array> a) {
|
||||||
std::optional<bool> retain_graph_) {
|
|
||||||
////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////
|
||||||
// Check file
|
// Check file
|
||||||
if (!out_stream->good() || !out_stream->is_open()) {
|
if (!out_stream->good() || !out_stream->is_open()) {
|
||||||
@ -142,8 +141,7 @@ void save_safetensors(
|
|||||||
});
|
});
|
||||||
size_t offset = 0;
|
size_t offset = 0;
|
||||||
for (auto& [key, arr] : a) {
|
for (auto& [key, arr] : a) {
|
||||||
auto retain = retain_graph_.value_or(arr.is_tracer());
|
arr.eval();
|
||||||
arr.eval(retain);
|
|
||||||
if (arr.nbytes() == 0) {
|
if (arr.nbytes() == 0) {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
"[save_safetensors] cannot serialize an empty array key: " + key);
|
"[save_safetensors] cannot serialize an empty array key: " + key);
|
||||||
@ -152,7 +150,7 @@ void save_safetensors(
|
|||||||
// Try to make it row contiguous
|
// Try to make it row contiguous
|
||||||
if (!arr.flags().row_contiguous) {
|
if (!arr.flags().row_contiguous) {
|
||||||
arr = reshape(flatten(arr), arr.shape());
|
arr = reshape(flatten(arr), arr.shape());
|
||||||
arr.eval(retain);
|
arr.eval();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Has to be row-major now but, check one more time in case
|
// Has to be row-major now but, check one more time in case
|
||||||
@ -181,8 +179,7 @@ void save_safetensors(
|
|||||||
|
|
||||||
void save_safetensors(
|
void save_safetensors(
|
||||||
const std::string& file_,
|
const std::string& file_,
|
||||||
std::unordered_map<std::string, array> a,
|
std::unordered_map<std::string, array> a) {
|
||||||
std::optional<bool> retain_graph) {
|
|
||||||
// Open and check file
|
// Open and check file
|
||||||
std::string file = file_;
|
std::string file = file_;
|
||||||
|
|
||||||
@ -192,7 +189,7 @@ void save_safetensors(
|
|||||||
file += ".safetensors";
|
file += ".safetensors";
|
||||||
|
|
||||||
// Serialize array
|
// Serialize array
|
||||||
save_safetensors(std::make_shared<io::FileWriter>(file), a, retain_graph);
|
save_safetensors(std::make_shared<io::FileWriter>(file), a);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
13
mlx/ops.h
13
mlx/ops.h
@ -1021,13 +1021,10 @@ array conv2d(
|
|||||||
/** Serialization operations */
|
/** Serialization operations */
|
||||||
|
|
||||||
/** Save array to out stream in .npy format */
|
/** Save array to out stream in .npy format */
|
||||||
void save(
|
void save(std::shared_ptr<io::Writer> out_stream, array a);
|
||||||
std::shared_ptr<io::Writer> out_stream,
|
|
||||||
array a,
|
|
||||||
bool retain_graph = true);
|
|
||||||
|
|
||||||
/** Save array to file in .npy format */
|
/** Save array to file in .npy format */
|
||||||
void save(const std::string& file, array a, bool retain_graph = true);
|
void save(const std::string& file, array a);
|
||||||
|
|
||||||
/** Load array from reader in .npy format */
|
/** Load array from reader in .npy format */
|
||||||
array load(std::shared_ptr<io::Reader> in_stream, StreamOrDevice s = {});
|
array load(std::shared_ptr<io::Reader> in_stream, StreamOrDevice s = {});
|
||||||
@ -1091,10 +1088,8 @@ std::unordered_map<std::string, array> load_safetensors(
|
|||||||
|
|
||||||
void save_safetensors(
|
void save_safetensors(
|
||||||
std::shared_ptr<io::Writer> in_stream,
|
std::shared_ptr<io::Writer> in_stream,
|
||||||
std::unordered_map<std::string, array>,
|
std::unordered_map<std::string, array>);
|
||||||
std::optional<bool> retain_graph = std::nullopt);
|
|
||||||
void save_safetensors(
|
void save_safetensors(
|
||||||
const std::string& file,
|
const std::string& file,
|
||||||
std::unordered_map<std::string, array>,
|
std::unordered_map<std::string, array>);
|
||||||
std::optional<bool> retain_graph = std::nullopt);
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
@ -19,6 +19,12 @@
|
|||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
|
// Initialize the static tracing counter from transforms_impl.h .
|
||||||
|
//
|
||||||
|
// This is used to implement the in_tracing() function the returns true if we
|
||||||
|
// are currently under a function transformation.
|
||||||
|
int detail::InTracing::tracing_counter{0};
|
||||||
|
|
||||||
void simplify(const std::vector<array>& outputs) {
|
void simplify(const std::vector<array>& outputs) {
|
||||||
std::function<void(const array&)> recurse;
|
std::function<void(const array&)> recurse;
|
||||||
std::queue<array> tape;
|
std::queue<array> tape;
|
||||||
@ -154,16 +160,7 @@ void simplify(const std::vector<array>& outputs) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void eval(const std::vector<array>& outputs, bool retain_graph /* = false */) {
|
void eval(const std::vector<array>& outputs) {
|
||||||
if (!retain_graph) {
|
|
||||||
for (auto& out : outputs) {
|
|
||||||
if (out.has_primitive() && out.is_tracer()) {
|
|
||||||
throw std::invalid_argument(
|
|
||||||
"[eval] Illegal to eval an array during "
|
|
||||||
"function transform without graph retention.");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
std::function<void(const array&)> recurse;
|
std::function<void(const array&)> recurse;
|
||||||
std::queue<array> tape;
|
std::queue<array> tape;
|
||||||
std::unordered_set<std::uintptr_t> cache;
|
std::unordered_set<std::uintptr_t> cache;
|
||||||
@ -185,7 +182,7 @@ void eval(const std::vector<array>& outputs, bool retain_graph /* = false */) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
cache.insert(id);
|
cache.insert(id);
|
||||||
if (!a.is_evaled() || (!retain_graph && a.has_primitive())) {
|
if (!a.is_evaled() || (!a.is_tracer() && a.has_primitive())) {
|
||||||
if (!a.has_primitive()) {
|
if (!a.has_primitive()) {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
"[eval] Attempting to eval an array without a primitive.");
|
"[eval] Attempting to eval an array without a primitive.");
|
||||||
@ -195,7 +192,7 @@ void eval(const std::vector<array>& outputs, bool retain_graph /* = false */) {
|
|||||||
};
|
};
|
||||||
|
|
||||||
for (auto& arr : outputs) {
|
for (auto& arr : outputs) {
|
||||||
if (!arr.is_evaled() || (!retain_graph && arr.has_primitive())) {
|
if (!arr.is_evaled() || (!arr.is_tracer() && arr.has_primitive())) {
|
||||||
recurse(arr);
|
recurse(arr);
|
||||||
// Insert a dependency for every output to synchronize
|
// Insert a dependency for every output to synchronize
|
||||||
// with at the end.
|
// with at the end.
|
||||||
@ -209,7 +206,7 @@ void eval(const std::vector<array>& outputs, bool retain_graph /* = false */) {
|
|||||||
auto arr = std::move(tape.front());
|
auto arr = std::move(tape.front());
|
||||||
tape.pop();
|
tape.pop();
|
||||||
if (arr.is_evaled()) {
|
if (arr.is_evaled()) {
|
||||||
if (!retain_graph && arr.has_primitive()) {
|
if (!arr.is_tracer() && arr.has_primitive()) {
|
||||||
arr.detach();
|
arr.detach();
|
||||||
}
|
}
|
||||||
continue;
|
continue;
|
||||||
@ -233,12 +230,9 @@ void eval(const std::vector<array>& outputs, bool retain_graph /* = false */) {
|
|||||||
throw std::runtime_error("Metal GPU is not available.");
|
throw std::runtime_error("Metal GPU is not available.");
|
||||||
}
|
}
|
||||||
scheduler::enqueue(
|
scheduler::enqueue(
|
||||||
stream,
|
stream, metal::make_task(arr, std::move(arr_deps), std::move(p)));
|
||||||
metal::make_task(
|
|
||||||
arr, std::move(arr_deps), std::move(p), retain_graph));
|
|
||||||
} else {
|
} else {
|
||||||
auto task = [retain_graph,
|
auto task = [arr,
|
||||||
arr,
|
|
||||||
stream,
|
stream,
|
||||||
arr_deps = std::move(arr_deps),
|
arr_deps = std::move(arr_deps),
|
||||||
p = std::move(p)]() mutable {
|
p = std::move(p)]() mutable {
|
||||||
@ -247,7 +241,7 @@ void eval(const std::vector<array>& outputs, bool retain_graph /* = false */) {
|
|||||||
}
|
}
|
||||||
scheduler::notify_new_task(stream);
|
scheduler::notify_new_task(stream);
|
||||||
arr.primitive().eval_cpu(arr.inputs(), arr);
|
arr.primitive().eval_cpu(arr.inputs(), arr);
|
||||||
if (!retain_graph) {
|
if (!arr.is_tracer()) {
|
||||||
arr.detach();
|
arr.detach();
|
||||||
}
|
}
|
||||||
if (p) {
|
if (p) {
|
||||||
@ -269,6 +263,9 @@ std::pair<std::vector<array>, std::vector<array>> vjp(
|
|||||||
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
|
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
|
||||||
const std::vector<array>& primals,
|
const std::vector<array>& primals,
|
||||||
const std::vector<array>& cotans) {
|
const std::vector<array>& cotans) {
|
||||||
|
// Set the global tracing flag.
|
||||||
|
detail::InTracing in_tracing;
|
||||||
|
|
||||||
// Make tracers from given primals
|
// Make tracers from given primals
|
||||||
std::vector<array> primals_;
|
std::vector<array> primals_;
|
||||||
for (auto& p : primals) {
|
for (auto& p : primals) {
|
||||||
@ -425,6 +422,9 @@ std::pair<std::vector<array>, std::vector<array>> jvp(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Set the global tracing flag.
|
||||||
|
detail::InTracing in_tracing;
|
||||||
|
|
||||||
std::vector<array> primals_;
|
std::vector<array> primals_;
|
||||||
for (auto& p : primals) {
|
for (auto& p : primals) {
|
||||||
auto s = p.has_primitive() ? p.primitive().stream()
|
auto s = p.has_primitive() ? p.primitive().stream()
|
||||||
@ -578,6 +578,9 @@ std::pair<std::vector<array>, std::vector<array>> vmap_trace(
|
|||||||
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
|
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
|
||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
const std::vector<int>& in_axes) {
|
const std::vector<int>& in_axes) {
|
||||||
|
// Set the global tracing flag
|
||||||
|
InTracing in_tracing;
|
||||||
|
|
||||||
if (in_axes.size() != inputs.size()) {
|
if (in_axes.size() != inputs.size()) {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
"[vmap] The number of in axes must match the number of inputs.");
|
"[vmap] The number of in axes must match the number of inputs.");
|
||||||
|
@ -14,11 +14,11 @@ void simplify(Arrays... outputs) {
|
|||||||
simplify(std::vector<array>{std::forward<Arrays>(outputs)...});
|
simplify(std::vector<array>{std::forward<Arrays>(outputs)...});
|
||||||
}
|
}
|
||||||
|
|
||||||
void eval(const std::vector<array>& outputs, bool retain_graph = false);
|
void eval(const std::vector<array>& outputs);
|
||||||
|
|
||||||
template <typename... Arrays>
|
template <typename... Arrays>
|
||||||
void eval(Arrays... outputs) {
|
void eval(Arrays... outputs) {
|
||||||
eval(std::vector<array>{std::forward<Arrays>(outputs)...}, false);
|
eval(std::vector<array>{std::forward<Arrays>(outputs)...});
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -14,4 +14,23 @@ std::vector<array> vmap_replace(
|
|||||||
const std::vector<int>& in_axes,
|
const std::vector<int>& in_axes,
|
||||||
const std::vector<int>& out_axes);
|
const std::vector<int>& out_axes);
|
||||||
|
|
||||||
|
// Create an InTracing object during tracing operations to signify to the rest
|
||||||
|
// of the codebase that we are during tracing so evals should not throw away
|
||||||
|
// the graph.
|
||||||
|
struct InTracing {
|
||||||
|
InTracing() {
|
||||||
|
tracing_counter++;
|
||||||
|
}
|
||||||
|
~InTracing() {
|
||||||
|
tracing_counter--;
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool in_tracing() {
|
||||||
|
return tracing_counter > 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
static int tracing_counter;
|
||||||
|
};
|
||||||
|
|
||||||
} // namespace mlx::core::detail
|
} // namespace mlx::core::detail
|
||||||
|
@ -39,34 +39,33 @@ py::list to_list(array& a, size_t index, int dim) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
auto to_scalar(array& a) {
|
auto to_scalar(array& a) {
|
||||||
bool retain_graph = a.is_tracer();
|
|
||||||
switch (a.dtype()) {
|
switch (a.dtype()) {
|
||||||
case bool_:
|
case bool_:
|
||||||
return py::cast(a.item<bool>(retain_graph));
|
return py::cast(a.item<bool>());
|
||||||
case uint8:
|
case uint8:
|
||||||
return py::cast(a.item<uint8_t>(retain_graph));
|
return py::cast(a.item<uint8_t>());
|
||||||
case uint16:
|
case uint16:
|
||||||
return py::cast(a.item<uint16_t>(retain_graph));
|
return py::cast(a.item<uint16_t>());
|
||||||
case uint32:
|
case uint32:
|
||||||
return py::cast(a.item<uint32_t>(retain_graph));
|
return py::cast(a.item<uint32_t>());
|
||||||
case uint64:
|
case uint64:
|
||||||
return py::cast(a.item<uint64_t>(retain_graph));
|
return py::cast(a.item<uint64_t>());
|
||||||
case int8:
|
case int8:
|
||||||
return py::cast(a.item<int8_t>(retain_graph));
|
return py::cast(a.item<int8_t>());
|
||||||
case int16:
|
case int16:
|
||||||
return py::cast(a.item<int16_t>(retain_graph));
|
return py::cast(a.item<int16_t>());
|
||||||
case int32:
|
case int32:
|
||||||
return py::cast(a.item<int32_t>(retain_graph));
|
return py::cast(a.item<int32_t>());
|
||||||
case int64:
|
case int64:
|
||||||
return py::cast(a.item<int64_t>(retain_graph));
|
return py::cast(a.item<int64_t>());
|
||||||
case float16:
|
case float16:
|
||||||
return py::cast(static_cast<float>(a.item<float16_t>(retain_graph)));
|
return py::cast(static_cast<float>(a.item<float16_t>()));
|
||||||
case float32:
|
case float32:
|
||||||
return py::cast(a.item<float>(retain_graph));
|
return py::cast(a.item<float>());
|
||||||
case bfloat16:
|
case bfloat16:
|
||||||
return py::cast(static_cast<float>(a.item<bfloat16_t>(retain_graph)));
|
return py::cast(static_cast<float>(a.item<bfloat16_t>()));
|
||||||
case complex64:
|
case complex64:
|
||||||
return py::cast(a.item<std::complex<float>>(retain_graph));
|
return py::cast(a.item<std::complex<float>>());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -74,7 +73,7 @@ py::object tolist(array& a) {
|
|||||||
if (a.ndim() == 0) {
|
if (a.ndim() == 0) {
|
||||||
return to_scalar(a);
|
return to_scalar(a);
|
||||||
}
|
}
|
||||||
a.eval(a.is_tracer());
|
a.eval();
|
||||||
py::object pl;
|
py::object pl;
|
||||||
switch (a.dtype()) {
|
switch (a.dtype()) {
|
||||||
case bool_:
|
case bool_:
|
||||||
@ -527,7 +526,7 @@ void init_array(py::module_& m) {
|
|||||||
.def_buffer([](array& a) {
|
.def_buffer([](array& a) {
|
||||||
// Eval if not already evaled
|
// Eval if not already evaled
|
||||||
if (!a.is_evaled()) {
|
if (!a.is_evaled()) {
|
||||||
eval({a}, a.is_tracer());
|
a.eval();
|
||||||
}
|
}
|
||||||
return pybind11::buffer_info(
|
return pybind11::buffer_info(
|
||||||
a.data<void>(),
|
a.data<void>(),
|
||||||
@ -751,7 +750,7 @@ void init_array(py::module_& m) {
|
|||||||
"__repr__",
|
"__repr__",
|
||||||
[](array& a) {
|
[](array& a) {
|
||||||
if (!a.is_evaled()) {
|
if (!a.is_evaled()) {
|
||||||
a.eval(a.is_tracer());
|
a.eval();
|
||||||
}
|
}
|
||||||
std::ostringstream os;
|
std::ostringstream os;
|
||||||
os << a;
|
os << a;
|
||||||
|
@ -345,19 +345,15 @@ class PyFileWriter : public io::Writer {
|
|||||||
py::object tell_func_;
|
py::object tell_func_;
|
||||||
};
|
};
|
||||||
|
|
||||||
void mlx_save_helper(
|
void mlx_save_helper(py::object file, array a) {
|
||||||
py::object file,
|
|
||||||
array a,
|
|
||||||
std::optional<bool> retain_graph_) {
|
|
||||||
bool retain_graph = retain_graph_.value_or(a.is_tracer());
|
|
||||||
if (py::isinstance<py::str>(file)) {
|
if (py::isinstance<py::str>(file)) {
|
||||||
save(py::cast<std::string>(file), a, retain_graph);
|
save(py::cast<std::string>(file), a);
|
||||||
return;
|
return;
|
||||||
} else if (is_ostream_object(file)) {
|
} else if (is_ostream_object(file)) {
|
||||||
auto writer = std::make_shared<PyFileWriter>(file);
|
auto writer = std::make_shared<PyFileWriter>(file);
|
||||||
{
|
{
|
||||||
py::gil_scoped_release gil;
|
py::gil_scoped_release gil;
|
||||||
save(writer, a, retain_graph);
|
save(writer, a);
|
||||||
}
|
}
|
||||||
|
|
||||||
return;
|
return;
|
||||||
@ -414,26 +410,23 @@ void mlx_savez_helper(
|
|||||||
auto writer = std::make_shared<PyFileWriter>(py_ostream);
|
auto writer = std::make_shared<PyFileWriter>(py_ostream);
|
||||||
{
|
{
|
||||||
py::gil_scoped_release gil;
|
py::gil_scoped_release gil;
|
||||||
save(writer, a, /*retain_graph=*/a.is_tracer());
|
save(writer, a);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
void mlx_save_safetensor_helper(
|
void mlx_save_safetensor_helper(py::object file, py::dict d) {
|
||||||
py::object file,
|
|
||||||
py::dict d,
|
|
||||||
std::optional<bool> retain_graph) {
|
|
||||||
auto arrays_map = d.cast<std::unordered_map<std::string, array>>();
|
auto arrays_map = d.cast<std::unordered_map<std::string, array>>();
|
||||||
if (py::isinstance<py::str>(file)) {
|
if (py::isinstance<py::str>(file)) {
|
||||||
save_safetensors(py::cast<std::string>(file), arrays_map, retain_graph);
|
save_safetensors(py::cast<std::string>(file), arrays_map);
|
||||||
return;
|
return;
|
||||||
} else if (is_ostream_object(file)) {
|
} else if (is_ostream_object(file)) {
|
||||||
auto writer = std::make_shared<PyFileWriter>(file);
|
auto writer = std::make_shared<PyFileWriter>(file);
|
||||||
{
|
{
|
||||||
py::gil_scoped_release gil;
|
py::gil_scoped_release gil;
|
||||||
save_safetensors(writer, arrays_map, retain_graph);
|
save_safetensors(writer, arrays_map);
|
||||||
}
|
}
|
||||||
|
|
||||||
return;
|
return;
|
||||||
|
@ -17,19 +17,13 @@ using DictOrArray = std::variant<array, std::unordered_map<std::string, array>>;
|
|||||||
std::unordered_map<std::string, array> mlx_load_safetensor_helper(
|
std::unordered_map<std::string, array> mlx_load_safetensor_helper(
|
||||||
py::object file,
|
py::object file,
|
||||||
StreamOrDevice s);
|
StreamOrDevice s);
|
||||||
void mlx_save_safetensor_helper(
|
void mlx_save_safetensor_helper(py::object file, py::dict d);
|
||||||
py::object file,
|
|
||||||
py::dict d,
|
|
||||||
std::optional<bool> retain_graph = std::nullopt);
|
|
||||||
|
|
||||||
DictOrArray mlx_load_helper(
|
DictOrArray mlx_load_helper(
|
||||||
py::object file,
|
py::object file,
|
||||||
std::optional<std::string> format,
|
std::optional<std::string> format,
|
||||||
StreamOrDevice s);
|
StreamOrDevice s);
|
||||||
void mlx_save_helper(
|
void mlx_save_helper(py::object file, array a);
|
||||||
py::object file,
|
|
||||||
array a,
|
|
||||||
std::optional<bool> retain_graph = std::nullopt);
|
|
||||||
void mlx_savez_helper(
|
void mlx_savez_helper(
|
||||||
py::object file,
|
py::object file,
|
||||||
py::args args,
|
py::args args,
|
||||||
|
@ -2902,20 +2902,14 @@ void init_ops(py::module_& m) {
|
|||||||
&mlx_save_helper,
|
&mlx_save_helper,
|
||||||
"file"_a,
|
"file"_a,
|
||||||
"arr"_a,
|
"arr"_a,
|
||||||
py::pos_only(),
|
|
||||||
"retain_graph"_a = std::nullopt,
|
|
||||||
py::kw_only(),
|
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
save(file: str, arr: array, / , retain_graph: Optional[bool] = None)
|
save(file: str, arr: array)
|
||||||
|
|
||||||
Save the array to a binary file in ``.npy`` format.
|
Save the array to a binary file in ``.npy`` format.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
file (str): File to which the array is saved
|
file (str): File to which the array is saved
|
||||||
arr (array): Array to be saved.
|
arr (array): Array to be saved.
|
||||||
retain_graph (bool, optional): Whether or not to retain the graph
|
|
||||||
during array evaluation. If left unspecified the graph is retained
|
|
||||||
only if saving is done in a function transformation. Default: ``None``
|
|
||||||
)pbdoc");
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"savez",
|
"savez",
|
||||||
@ -2999,11 +2993,8 @@ void init_ops(py::module_& m) {
|
|||||||
&mlx_save_safetensor_helper,
|
&mlx_save_safetensor_helper,
|
||||||
"file"_a,
|
"file"_a,
|
||||||
"arrays"_a,
|
"arrays"_a,
|
||||||
py::pos_only(),
|
|
||||||
"retain_graph"_a = std::nullopt,
|
|
||||||
py::kw_only(),
|
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
save_safetensors(file: str, arrays: Dict[str, array], /, retain_graph: Optional[bool] = None)
|
save_safetensors(file: str, arrays: Dict[str, array])
|
||||||
|
|
||||||
Save array(s) to a binary file in ``.safetensors`` format.
|
Save array(s) to a binary file in ``.safetensors`` format.
|
||||||
|
|
||||||
@ -3012,9 +3003,6 @@ void init_ops(py::module_& m) {
|
|||||||
Args:
|
Args:
|
||||||
file (file, str): File in which the array is saved>
|
file (file, str): File in which the array is saved>
|
||||||
arrays (dict(str, array)): The dictionary of names to arrays to be saved.
|
arrays (dict(str, array)): The dictionary of names to arrays to be saved.
|
||||||
retain_graph (bool, optional): Whether or not to retain the graph
|
|
||||||
during array evaluation. If left unspecified the graph is retained
|
|
||||||
only if saving is done in a function transformation. Default: ``None``.
|
|
||||||
)pbdoc");
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"where",
|
"where",
|
||||||
|
@ -440,11 +440,10 @@ auto py_vmap(
|
|||||||
void init_transforms(py::module_& m) {
|
void init_transforms(py::module_& m) {
|
||||||
m.def(
|
m.def(
|
||||||
"eval",
|
"eval",
|
||||||
[](const py::args& args, bool retain_graph) {
|
[](const py::args& args) {
|
||||||
std::vector<array> arrays = tree_flatten(args);
|
std::vector<array> arrays = tree_flatten(args);
|
||||||
eval(arrays, retain_graph);
|
eval(arrays);
|
||||||
},
|
},
|
||||||
"retain_graph"_a = false,
|
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
Evaluate an :class:`array` or tree of :class:`array`.
|
Evaluate an :class:`array` or tree of :class:`array`.
|
||||||
|
|
||||||
@ -453,9 +452,6 @@ void init_transforms(py::module_& m) {
|
|||||||
or a tree of arrays. If a tree is given the nodes can be a Python
|
or a tree of arrays. If a tree is given the nodes can be a Python
|
||||||
:class:`list`, :class:`tuple` or :class:`dict` but the leafs must all be
|
:class:`list`, :class:`tuple` or :class:`dict` but the leafs must all be
|
||||||
an :class:`array`.
|
an :class:`array`.
|
||||||
retain_graph (bool): Indicate that the graph structure should be
|
|
||||||
preserved. This option is intended to enable function transforms
|
|
||||||
which contain control flow based on the value of an array.
|
|
||||||
)pbdoc");
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"jvp",
|
"jvp",
|
||||||
|
@ -259,6 +259,21 @@ class TestAutograd(mlx_tests.MLXTestCase):
|
|||||||
|
|
||||||
self.assertTrue(mx.allclose(vjps[0], mx.zeros(shape_in)))
|
self.assertTrue(mx.allclose(vjps[0], mx.zeros(shape_in)))
|
||||||
|
|
||||||
|
def test_update_state(self):
|
||||||
|
y = mx.array([1.0])
|
||||||
|
state = mx.zeros((2,))
|
||||||
|
|
||||||
|
def fn(y, x):
|
||||||
|
nonlocal state
|
||||||
|
x = y * x
|
||||||
|
state = state + x
|
||||||
|
return x.sum()
|
||||||
|
|
||||||
|
x = mx.ones((2,))
|
||||||
|
mx.grad(fn)(y, x)
|
||||||
|
mx.eval(state)
|
||||||
|
self.assertTrue(mx.allclose(state, mx.ones((2,))))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
@ -15,18 +15,13 @@ class TestEval(mlx_tests.MLXTestCase):
|
|||||||
self.assertEqual(x.tolist(), [[1, 1], [1, 1]])
|
self.assertEqual(x.tolist(), [[1, 1], [1, 1]])
|
||||||
|
|
||||||
def test_retain_graph(self):
|
def test_retain_graph(self):
|
||||||
def fun(x, retain_graph):
|
def fun(x):
|
||||||
y = 3 * x
|
y = 3 * x
|
||||||
mx.eval(y, retain_graph=retain_graph)
|
mx.eval(y)
|
||||||
return 2 * y
|
return 2 * y
|
||||||
|
|
||||||
dfun_dx_1 = mx.grad(partial(fun, retain_graph=False))
|
dfun_dx = mx.grad(fun)
|
||||||
dfun_dx_2 = mx.grad(partial(fun, retain_graph=True))
|
y = dfun_dx(mx.array(1.0))
|
||||||
|
|
||||||
with self.assertRaises(ValueError):
|
|
||||||
dfun_dx_1(mx.array(1.0))
|
|
||||||
|
|
||||||
y = dfun_dx_2(mx.array(1.0))
|
|
||||||
self.assertEqual(y.item(), 6.0)
|
self.assertEqual(y.item(), 6.0)
|
||||||
|
|
||||||
|
|
||||||
|
@ -95,19 +95,14 @@ TEST_CASE("test jvp") {
|
|||||||
CHECK_EQ(dout[0].item<float>(), 4.0f);
|
CHECK_EQ(dout[0].item<float>(), 4.0f);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Evaling in function without graph retention throws
|
// Evaling in function while tracing performs graph retention
|
||||||
{
|
{
|
||||||
auto fun = [](const array& x) {
|
|
||||||
auto y = 3 * x;
|
|
||||||
eval(y);
|
|
||||||
return 2 * y;
|
|
||||||
};
|
|
||||||
CHECK_THROWS(jvp(fun, array(1.0f), array(1.0f)));
|
|
||||||
|
|
||||||
// Ok with graph retention
|
|
||||||
auto fun1 = [](const array& x) {
|
auto fun1 = [](const array& x) {
|
||||||
auto y = 3 * x;
|
auto y = 3 * x;
|
||||||
eval({y}, true);
|
eval(y);
|
||||||
|
CHECK(y.is_evaled());
|
||||||
|
CHECK(y.has_primitive());
|
||||||
|
CHECK(y.is_tracer());
|
||||||
return 2 * y;
|
return 2 * y;
|
||||||
};
|
};
|
||||||
CHECK_EQ(jvp(fun1, array(1.0f), array(1.0f)).second.item<float>(), 6.0f);
|
CHECK_EQ(jvp(fun1, array(1.0f), array(1.0f)).second.item<float>(), 6.0f);
|
||||||
@ -251,29 +246,27 @@ TEST_CASE("test grad") {
|
|||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
// Evaluating in the middle of the grad function throws
|
// No graph retention since the output is independent of y
|
||||||
// as it breaks the graph
|
|
||||||
auto fn = [](array x) {
|
|
||||||
x = x + 2.0f;
|
|
||||||
eval(x);
|
|
||||||
return square(x);
|
|
||||||
};
|
|
||||||
CHECK_THROWS(grad(fn)(array(1.0f)));
|
|
||||||
|
|
||||||
// Ok since the output is independent of y
|
|
||||||
auto y = ones({3, 3});
|
auto y = ones({3, 3});
|
||||||
auto fn1 = [y](array x) {
|
auto fn1 = [y](array x) {
|
||||||
x = x + 2.0f;
|
x = x + 2.0f;
|
||||||
eval(y);
|
eval(y);
|
||||||
|
CHECK(x.is_tracer());
|
||||||
|
CHECK(!y.is_tracer());
|
||||||
|
CHECK(y.is_evaled());
|
||||||
|
CHECK(!y.has_primitive());
|
||||||
return square(x);
|
return square(x);
|
||||||
};
|
};
|
||||||
auto dfdx = grad(fn1)(array(1.0f));
|
auto dfdx = grad(fn1)(array(1.0f));
|
||||||
CHECK_EQ(dfdx.item<float>(), 6.0f);
|
CHECK_EQ(dfdx.item<float>(), 6.0f);
|
||||||
|
|
||||||
// Retain the graph to avoid breaking it
|
// Graph automatically retained to compute the grad
|
||||||
auto fn2 = [](array x) {
|
auto fn2 = [](array x) {
|
||||||
x = x + 2.0f;
|
x = x + 2.0f;
|
||||||
eval({x}, true);
|
eval(x);
|
||||||
|
CHECK(x.is_tracer());
|
||||||
|
CHECK(x.is_evaled());
|
||||||
|
CHECK(x.has_primitive());
|
||||||
return square(x);
|
return square(x);
|
||||||
};
|
};
|
||||||
dfdx = grad(fn2)(array(1.0f));
|
dfdx = grad(fn2)(array(1.0f));
|
||||||
@ -283,7 +276,8 @@ TEST_CASE("test grad") {
|
|||||||
// Control flow in grad computation
|
// Control flow in grad computation
|
||||||
{
|
{
|
||||||
auto fn = [](array x) {
|
auto fn = [](array x) {
|
||||||
if (x.item<float>(true) > 1) {
|
x = x + array(2.0f);
|
||||||
|
if (x.item<float>() > 3) {
|
||||||
return square(x);
|
return square(x);
|
||||||
} else {
|
} else {
|
||||||
return 4 * x;
|
return 4 * x;
|
||||||
@ -294,7 +288,7 @@ TEST_CASE("test grad") {
|
|||||||
CHECK_EQ(dfdx.item<float>(), 4.0f);
|
CHECK_EQ(dfdx.item<float>(), 4.0f);
|
||||||
|
|
||||||
dfdx = grad(fn)(array(1.5f));
|
dfdx = grad(fn)(array(1.5f));
|
||||||
CHECK_EQ(dfdx.item<float>(), 3.0f);
|
CHECK_EQ(dfdx.item<float>(), 7.0f);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Grad with multiple inputs
|
// Grad with multiple inputs
|
||||||
@ -1192,3 +1186,19 @@ TEST_CASE("test scan grads") {
|
|||||||
CHECK(array_equal(out, expected).item<bool>());
|
CHECK(array_equal(out, expected).item<bool>());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_CASE("test update state") {
|
||||||
|
auto y = array({1.0});
|
||||||
|
auto x = array({1.0, 1.0});
|
||||||
|
auto state = array({0.0, 0.0});
|
||||||
|
auto fn = [&state, &x](array y) {
|
||||||
|
x = y * x;
|
||||||
|
state = state + x;
|
||||||
|
return sum(x);
|
||||||
|
};
|
||||||
|
grad(fn)(y);
|
||||||
|
eval(state);
|
||||||
|
CHECK(!state.has_primitive());
|
||||||
|
CHECK(state.is_evaled());
|
||||||
|
CHECK(array_equal(state, array({1.0, 1.0})).item<bool>());
|
||||||
|
}
|
||||||
|
@ -48,36 +48,36 @@ TEST_CASE("test eval multiple") {
|
|||||||
CHECK(array_equal(b, full({10, 10}, 0.0f)).item<bool>());
|
CHECK(array_equal(b, full({10, 10}, 0.0f)).item<bool>());
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_CASE("test eval with tracer") {
|
TEST_CASE("test eval with tracer when not tracing") {
|
||||||
|
// Since we are not tracing it doesn't matter that the array flags are
|
||||||
|
// tracers they will always be detached.
|
||||||
auto x = array(1);
|
auto x = array(1);
|
||||||
x.set_tracer(true);
|
x.set_tracer(true);
|
||||||
|
CHECK(!x.is_tracer());
|
||||||
// Ok, x is not a node
|
|
||||||
eval(x);
|
eval(x);
|
||||||
|
CHECK(!x.has_primitive());
|
||||||
|
CHECK(x.is_evaled());
|
||||||
|
|
||||||
x = ones({2, 3});
|
x = ones({2, 3});
|
||||||
x.set_tracer(true);
|
x.set_tracer(true);
|
||||||
CHECK_THROWS(eval(x));
|
eval(x);
|
||||||
|
CHECK(!x.has_primitive());
|
||||||
// Ok retain_graph=true
|
CHECK(x.is_evaled());
|
||||||
eval({x}, true);
|
|
||||||
|
|
||||||
// Make sure all arguments are checked
|
|
||||||
auto y = ones({2, 3});
|
|
||||||
CHECK_THROWS(eval(x, y));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_CASE("test eval graph retention") {
|
TEST_CASE("test eval graph retention when not tracing") {
|
||||||
|
// Since we are not tracing it doesn't matter that the array flags are
|
||||||
|
// tracers they will always be detached.
|
||||||
auto x = array(1);
|
auto x = array(1);
|
||||||
|
x.set_tracer(true);
|
||||||
auto y = array(2);
|
auto y = array(2);
|
||||||
auto z = x + y;
|
auto z = x + y;
|
||||||
eval({z}, true);
|
eval(z);
|
||||||
CHECK(z.has_primitive());
|
CHECK(!z.has_primitive());
|
||||||
CHECK(z.is_evaled());
|
|
||||||
CHECK_EQ(z.item<int>(true), 3);
|
|
||||||
CHECK(z.has_primitive());
|
|
||||||
CHECK(z.is_evaled());
|
CHECK(z.is_evaled());
|
||||||
|
CHECK_EQ(z.item<int>(), 3);
|
||||||
|
|
||||||
|
z.set_tracer(false);
|
||||||
CHECK_EQ(z.item<int>(), 3);
|
CHECK_EQ(z.item<int>(), 3);
|
||||||
CHECK(!z.has_primitive());
|
CHECK(!z.has_primitive());
|
||||||
CHECK(z.is_evaled());
|
CHECK(z.is_evaled());
|
||||||
@ -85,13 +85,7 @@ TEST_CASE("test eval graph retention") {
|
|||||||
z = x + y;
|
z = x + y;
|
||||||
auto a = z + x;
|
auto a = z + x;
|
||||||
auto b = a + y;
|
auto b = a + y;
|
||||||
eval({b}, true);
|
eval(b);
|
||||||
CHECK(z.has_primitive());
|
|
||||||
CHECK(z.is_evaled());
|
|
||||||
CHECK(a.has_primitive());
|
|
||||||
CHECK(a.is_evaled());
|
|
||||||
|
|
||||||
eval({b}, false);
|
|
||||||
CHECK(!z.has_primitive());
|
CHECK(!z.has_primitive());
|
||||||
CHECK(z.is_evaled());
|
CHECK(z.is_evaled());
|
||||||
CHECK(!a.has_primitive());
|
CHECK(!a.has_primitive());
|
||||||
|
@ -183,7 +183,7 @@ TEST_CASE("test vmap with eval") {
|
|||||||
auto fun2 = [](std::vector<array> inputs) {
|
auto fun2 = [](std::vector<array> inputs) {
|
||||||
auto x = inputs[0] + 1;
|
auto x = inputs[0] + 1;
|
||||||
auto y = inputs[1] + 2;
|
auto y = inputs[1] + 2;
|
||||||
eval({x}, true);
|
eval(x);
|
||||||
auto out = add(x, y);
|
auto out = add(x, y);
|
||||||
return std::vector<array>{out};
|
return std::vector<array>{out};
|
||||||
};
|
};
|
||||||
|
Loading…
Reference in New Issue
Block a user