Dynamic broadcasting for shapeless compile/export (#1722)

* working towards dynamic broadcast

* shapeless broadcast

* fix build + nits

* use broadcast arrays in quantize matmul

* some cleanup / consistency

* mend

* some comments

* add vjp, jvp for broadcast axes
This commit is contained in:
Awni Hannun
2025-01-09 11:04:24 -08:00
committed by GitHub
parent ec36bfa317
commit 1ccaf80575
20 changed files with 471 additions and 163 deletions

View File

@@ -243,6 +243,7 @@ struct PrimitiveFactory {
"RightShift"),
SERIALIZE_PRIMITIVE(BlockMaskedMM),
SERIALIZE_PRIMITIVE(Broadcast),
SERIALIZE_PRIMITIVE(BroadcastAxes),
SERIALIZE_PRIMITIVE(Ceil),
SERIALIZE_PRIMITIVE(Concatenate),
SERIALIZE_PRIMITIVE(Conjugate),
@@ -568,7 +569,8 @@ void FunctionExporter::export_function(const Args& args, const Kwargs& kwargs) {
};
// Trace to build the graph
auto [trace_inputs, trace_outputs] = detail::compile_trace(flat_fun, inputs);
auto [trace_inputs, trace_outputs] =
detail::compile_trace(flat_fun, inputs, ftable->shapeless);
// DFS the graph and get the tape
auto [tape, parents_map] =