mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Dynamic broadcasting for shapeless compile/export (#1722)
* working towards dynamic broadcast * shapeless broadcast * fix build + nits * use broadcast arrays in quantize matmul * some cleanup / consistency * mend * some comments * add vjp, jvp for broadcast axes
This commit is contained in:
@@ -10,20 +10,6 @@
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
/** Return true if we are currently performing a function transformation in
|
||||
* order to keep the graph when evaluating tracer arrays. */
|
||||
bool in_tracing() {
|
||||
return detail::InTracing::in_tracing();
|
||||
}
|
||||
|
||||
bool retain_graph() {
|
||||
return detail::RetainGraph::retain_graph();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
array::array(const std::complex<float>& val, Dtype dtype /* = complex64 */)
|
||||
: array_desc_(std::make_shared<ArrayDesc>(Shape{}, dtype)) {
|
||||
auto cval = static_cast<complex64_t>(val);
|
||||
@@ -119,7 +105,8 @@ void array::eval() {
|
||||
}
|
||||
|
||||
bool array::is_tracer() const {
|
||||
return (array_desc_->is_tracer && in_tracing()) || retain_graph();
|
||||
return (array_desc_->is_tracer && detail::in_tracing()) ||
|
||||
detail::retain_graph();
|
||||
}
|
||||
|
||||
void array::set_data(allocator::Buffer buffer, Deleter d) {
|
||||
|
||||
Reference in New Issue
Block a user