Dynamic broadcasting for shapeless compile/export (#1722)

* working towards dynamic broadcast

* shapeless broadcast

* fix build + nits

* use broadcast arrays in quantize matmul

* some cleanup / consistency

* mend

* some comments

* add vjp, jvp for broadcast axes
This commit is contained in:
Awni Hannun
2025-01-09 11:04:24 -08:00
committed by GitHub
parent ec36bfa317
commit 1ccaf80575
20 changed files with 471 additions and 163 deletions

View File

@@ -31,13 +31,13 @@ class Synchronizer : public Primitive {
DEFINE_PRINT(Synchronize);
};
// Initialize the static tracing counter from transforms_impl.h .
// Initialize the static tracing members from transforms_impl.h
//
// This is used to implement the in_tracing() function the returns true if we
// These are used to implement the in_tracing() function the returns true if we
// are currently under a function transformation and the retain_graph()
// function which returns true if we are forced to retain the graph during
// evaluation.
int detail::InTracing::tracing_counter{0};
std::vector<bool> detail::InTracing::trace_stack{};
int detail::RetainGraph::tracing_counter{0};
array eval_impl(std::vector<array> outputs, bool async) {
@@ -434,7 +434,6 @@ std::pair<std::vector<array>, std::vector<array>> vjp(
}
}
}
std::vector<array> vjps;
for (auto& primal : primals_) {
if (auto cotan_it = cotan_map.find(primal.id());
@@ -629,7 +628,7 @@ ValueAndGradFn value_and_grad(
for (auto arg : args) {
ginputs.push_back(inputs[arg]);
}
// Set the incoming gradient to int32, vjp will cast it to the output type
// Set the incoming gradient to float32, vjp will cast it to the output type
auto [outputs, grads] = vjp(gfun, ginputs, {array(1.0f)});
return std::make_pair(outputs, grads);
};