mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-15 09:29:26 +08:00
Removes the retain_graph flag (#385)
* Adds global tracing flag * Removes retain_graph in favor of is_tracer
This commit is contained in:
committed by
GitHub
parent
449b43762e
commit
a611b0bc82
@@ -6,6 +6,7 @@
|
||||
#include "mlx/ops.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/transforms.h"
|
||||
#include "mlx/transforms_impl.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
@@ -21,6 +22,12 @@ std::pair<size_t, std::vector<size_t>> cum_prod(const std::vector<int>& shape) {
|
||||
return {cum_prod, strides};
|
||||
}
|
||||
|
||||
/** Return true if we are currently performing a function transformation in
|
||||
* order to keep the graph when evaluating tracer arrays. */
|
||||
bool in_tracing() {
|
||||
return detail::InTracing::in_tracing();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
array::array(const std::complex<float>& val, Dtype dtype /* = complex64 */)
|
||||
@@ -62,8 +69,12 @@ void array::detach() {
|
||||
array_desc_->primitive = nullptr;
|
||||
}
|
||||
|
||||
void array::eval(bool retain_graph /* = false */) {
|
||||
mlx::core::eval({*this}, retain_graph);
|
||||
void array::eval() {
|
||||
mlx::core::eval({*this});
|
||||
}
|
||||
|
||||
bool array::is_tracer() const {
|
||||
return array_desc_->is_tracer && in_tracing();
|
||||
}
|
||||
|
||||
void array::set_data(allocator::Buffer buffer, deleter_t d) {
|
||||
|
||||
Reference in New Issue
Block a user