mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Dynamic broadcasting for shapeless compile/export (#1722)
* working towards dynamic broadcast * shapeless broadcast * fix build + nits * use broadcast arrays in quantize matmul * some cleanup / consistency * mend * some comments * add vjp, jvp for broadcast axes
This commit is contained in:
@@ -284,9 +284,10 @@ CompilerCache& compiler_cache() {
|
||||
|
||||
std::pair<std::vector<array>, std::vector<array>> compile_trace(
|
||||
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
|
||||
const std::vector<array>& inputs) {
|
||||
const std::vector<array>& inputs,
|
||||
bool shapeless) {
|
||||
// Set the global tracing flag.
|
||||
detail::InTracing in_tracing;
|
||||
detail::InTracing in_tracing{shapeless};
|
||||
|
||||
// Run the function on placeholder inputs
|
||||
// to get compute graph
|
||||
@@ -824,7 +825,8 @@ std::function<std::vector<array>(const std::vector<array>&)> compile(
|
||||
// Set the constants
|
||||
entry.constants = std::move(constants);
|
||||
// Trace to build the graph
|
||||
std::tie(entry.inputs, entry.outputs) = compile_trace(fun, inputs);
|
||||
std::tie(entry.inputs, entry.outputs) =
|
||||
compile_trace(fun, inputs, shapeless);
|
||||
|
||||
// DFS the graph and get a tape, and a map of array id to (parent,
|
||||
// position in parent inputs)
|
||||
|
||||
Reference in New Issue
Block a user