Dynamic broadcasting for shapeless compile/export (#1722)

* working towards dynamic broadcast

* shapeless broadcast

* fix build + nits

* use broadcast arrays in quantize matmul

* some cleanup / consistency

* mend

* some comments

* add vjp, jvp for broadcast axes
This commit is contained in:
Awni Hannun
2025-01-09 11:04:24 -08:00
committed by GitHub
parent ec36bfa317
commit 1ccaf80575
20 changed files with 471 additions and 163 deletions

View File

@@ -284,9 +284,10 @@ CompilerCache& compiler_cache() {
std::pair<std::vector<array>, std::vector<array>> compile_trace(
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
const std::vector<array>& inputs) {
const std::vector<array>& inputs,
bool shapeless) {
// Set the global tracing flag.
detail::InTracing in_tracing;
detail::InTracing in_tracing{shapeless};
// Run the function on placeholder inputs
// to get compute graph
@@ -824,7 +825,8 @@ std::function<std::vector<array>(const std::vector<array>&)> compile(
// Set the constants
entry.constants = std::move(constants);
// Trace to build the graph
std::tie(entry.inputs, entry.outputs) = compile_trace(fun, inputs);
std::tie(entry.inputs, entry.outputs) =
compile_trace(fun, inputs, shapeless);
// DFS the graph and get a tape, and a map of array id to (parent,
// position in parent inputs)