mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Dynamic broadcasting for shapeless compile/export (#1722)
* working towards dynamic broadcast * shapeless broadcast * fix build + nits * use broadcast arrays in quantize matmul * some cleanup / consistency * mend * some comments * add vjp, jvp for broadcast axes
This commit is contained in:
@@ -20,19 +20,23 @@ std::vector<array> vmap_replace(
|
||||
// of the codebase that we are during tracing so evals should not throw away
|
||||
// the graph.
|
||||
struct InTracing {
|
||||
InTracing() {
|
||||
tracing_counter++;
|
||||
explicit InTracing(bool dynamic = false) {
|
||||
trace_stack.push_back(dynamic);
|
||||
}
|
||||
~InTracing() {
|
||||
tracing_counter--;
|
||||
trace_stack.pop_back();
|
||||
}
|
||||
|
||||
static bool in_tracing() {
|
||||
return tracing_counter > 0;
|
||||
return !trace_stack.empty();
|
||||
}
|
||||
static bool in_dynamic_tracing() {
|
||||
// compile is always and only the outer-most transform
|
||||
return in_tracing() && trace_stack.front();
|
||||
}
|
||||
|
||||
private:
|
||||
static int tracing_counter;
|
||||
static std::vector<bool> trace_stack;
|
||||
};
|
||||
|
||||
struct RetainGraph {
|
||||
@@ -51,4 +55,20 @@ struct RetainGraph {
|
||||
static int tracing_counter;
|
||||
};
|
||||
|
||||
/** Return true if we are currently performing a function transformation in
|
||||
* order to keep the graph when evaluating tracer arrays. */
|
||||
inline bool in_tracing() {
|
||||
return detail::InTracing::in_tracing();
|
||||
}
|
||||
|
||||
/** Return true if we are in a dynamic (shapeless) trace used for compiling or
|
||||
* exporting graphs with dynamic shapes. */
|
||||
inline bool in_dynamic_tracing() {
|
||||
return detail::InTracing::in_dynamic_tracing();
|
||||
}
|
||||
|
||||
inline bool retain_graph() {
|
||||
return detail::RetainGraph::retain_graph();
|
||||
}
|
||||
|
||||
} // namespace mlx::core::detail
|
||||
|
||||
Reference in New Issue
Block a user