Dynamic broadcasting for shapeless compile/export (#1722)

* working towards dynamic broadcast

* shapeless broadcast

* fix build + nits

* use broadcast arrays in quantize matmul

* some cleanup / consistency

* mend

* some comments

* add vjp, jvp for broadcast axes
This commit is contained in:
Awni Hannun
2025-01-09 11:04:24 -08:00
committed by GitHub
parent ec36bfa317
commit 1ccaf80575
20 changed files with 471 additions and 163 deletions

View File

@@ -20,19 +20,23 @@ std::vector<array> vmap_replace(
// of the codebase that we are during tracing so evals should not throw away
// the graph.
struct InTracing {
InTracing() {
tracing_counter++;
explicit InTracing(bool dynamic = false) {
trace_stack.push_back(dynamic);
}
~InTracing() {
tracing_counter--;
trace_stack.pop_back();
}
static bool in_tracing() {
return tracing_counter > 0;
return !trace_stack.empty();
}
static bool in_dynamic_tracing() {
// compile is always and only the outer-most transform
return in_tracing() && trace_stack.front();
}
private:
static int tracing_counter;
static std::vector<bool> trace_stack;
};
struct RetainGraph {
@@ -51,4 +55,20 @@ struct RetainGraph {
static int tracing_counter;
};
/** Return true if we are currently performing a function transformation in
* order to keep the graph when evaluating tracer arrays. */
inline bool in_tracing() {
return detail::InTracing::in_tracing();
}
/** Return true if we are in a dynamic (shapeless) trace used for compiling or
* exporting graphs with dynamic shapes. */
inline bool in_dynamic_tracing() {
return detail::InTracing::in_dynamic_tracing();
}
inline bool retain_graph() {
return detail::RetainGraph::retain_graph();
}
} // namespace mlx::core::detail