mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-15 17:39:05 +08:00
Custom transforms (#1246)
This commit is contained in:
committed by
GitHub
parent
a3c287354f
commit
5c1fa64fb0
@@ -17,6 +17,10 @@ bool in_tracing() {
|
||||
return detail::InTracing::in_tracing();
|
||||
}
|
||||
|
||||
bool retain_graph() {
|
||||
return detail::RetainGraph::retain_graph();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
array::array(const std::complex<float>& val, Dtype dtype /* = complex64 */)
|
||||
@@ -102,7 +106,7 @@ void array::eval() {
|
||||
}
|
||||
|
||||
bool array::is_tracer() const {
|
||||
return array_desc_->is_tracer && in_tracing();
|
||||
return array_desc_->is_tracer && in_tracing() || retain_graph();
|
||||
}
|
||||
|
||||
void array::set_data(allocator::Buffer buffer, deleter_t d) {
|
||||
|
||||
Reference in New Issue
Block a user