mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Dynamic broadcasting for shapeless compile/export (#1722)
* working towards dynamic broadcast * shapeless broadcast * fix build + nits * use broadcast arrays in quantize matmul * some cleanup / consistency * mend * some comments * add vjp, jvp for broadcast axes
This commit is contained in:
@@ -31,13 +31,13 @@ class Synchronizer : public Primitive {
|
||||
DEFINE_PRINT(Synchronize);
|
||||
};
|
||||
|
||||
// Initialize the static tracing counter from transforms_impl.h .
|
||||
// Initialize the static tracing members from transforms_impl.h
|
||||
//
|
||||
// This is used to implement the in_tracing() function the returns true if we
|
||||
// These are used to implement the in_tracing() function the returns true if we
|
||||
// are currently under a function transformation and the retain_graph()
|
||||
// function which returns true if we are forced to retain the graph during
|
||||
// evaluation.
|
||||
int detail::InTracing::tracing_counter{0};
|
||||
std::vector<bool> detail::InTracing::trace_stack{};
|
||||
int detail::RetainGraph::tracing_counter{0};
|
||||
|
||||
array eval_impl(std::vector<array> outputs, bool async) {
|
||||
@@ -434,7 +434,6 @@ std::pair<std::vector<array>, std::vector<array>> vjp(
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<array> vjps;
|
||||
for (auto& primal : primals_) {
|
||||
if (auto cotan_it = cotan_map.find(primal.id());
|
||||
@@ -629,7 +628,7 @@ ValueAndGradFn value_and_grad(
|
||||
for (auto arg : args) {
|
||||
ginputs.push_back(inputs[arg]);
|
||||
}
|
||||
// Set the incoming gradient to int32, vjp will cast it to the output type
|
||||
// Set the incoming gradient to float32, vjp will cast it to the output type
|
||||
auto [outputs, grads] = vjp(gfun, ginputs, {array(1.0f)});
|
||||
return std::make_pair(outputs, grads);
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user