mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Dynamic broadcasting for shapeless compile/export (#1722)
* working towards dynamic broadcast * shapeless broadcast * fix build + nits * use broadcast arrays in quantize matmul * some cleanup / consistency * mend * some comments * add vjp, jvp for broadcast axes
This commit is contained in:
@@ -243,6 +243,7 @@ struct PrimitiveFactory {
|
||||
"RightShift"),
|
||||
SERIALIZE_PRIMITIVE(BlockMaskedMM),
|
||||
SERIALIZE_PRIMITIVE(Broadcast),
|
||||
SERIALIZE_PRIMITIVE(BroadcastAxes),
|
||||
SERIALIZE_PRIMITIVE(Ceil),
|
||||
SERIALIZE_PRIMITIVE(Concatenate),
|
||||
SERIALIZE_PRIMITIVE(Conjugate),
|
||||
@@ -568,7 +569,8 @@ void FunctionExporter::export_function(const Args& args, const Kwargs& kwargs) {
|
||||
};
|
||||
|
||||
// Trace to build the graph
|
||||
auto [trace_inputs, trace_outputs] = detail::compile_trace(flat_fun, inputs);
|
||||
auto [trace_inputs, trace_outputs] =
|
||||
detail::compile_trace(flat_fun, inputs, ftable->shapeless);
|
||||
|
||||
// DFS the graph and get the tape
|
||||
auto [tape, parents_map] =
|
||||
|
||||
Reference in New Issue
Block a user