Dynamic broadcasting for shapeless compile/export (#1722)

* working towards dynamic broadcast

* shapeless broadcast

* fix build + nits

* use broadcast arrays in quantize matmul

* some cleanup / consistency

* mend

* some comments

* add vjp, jvp for broadcast axes
This commit is contained in:
Awni Hannun
2025-01-09 11:04:24 -08:00
committed by GitHub
parent ec36bfa317
commit 1ccaf80575
20 changed files with 471 additions and 163 deletions

View File

@@ -10,20 +10,6 @@
namespace mlx::core {
namespace {
/** Return true if we are currently performing a function transformation in
* order to keep the graph when evaluating tracer arrays. */
bool in_tracing() {
return detail::InTracing::in_tracing();
}
bool retain_graph() {
return detail::RetainGraph::retain_graph();
}
} // namespace
array::array(const std::complex<float>& val, Dtype dtype /* = complex64 */)
: array_desc_(std::make_shared<ArrayDesc>(Shape{}, dtype)) {
auto cval = static_cast<complex64_t>(val);
@@ -119,7 +105,8 @@ void array::eval() {
}
bool array::is_tracer() const {
return (array_desc_->is_tracer && in_tracing()) || retain_graph();
return (array_desc_->is_tracer && detail::in_tracing()) ||
detail::retain_graph();
}
void array::set_data(allocator::Buffer buffer, Deleter d) {