mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-17 09:18:12 +08:00
Split broadcast so it is always fused in compile (#2318)
This commit is contained in:

committed by
GitHub

parent
656ed7f780
commit
2c11d10f8d
@@ -245,6 +245,30 @@ void merge(array& dst, array& src, ParentsMap& parents_map) {
|
||||
}
|
||||
}
|
||||
|
||||
// Any parent in the divider will continue to refer to `x` but any parent not
|
||||
// in the divider will refer to a copy of the operation.
|
||||
array split_one(
|
||||
const array& x,
|
||||
ParentsMap& parents_map,
|
||||
const std::unordered_set<uintptr_t>& divider) {
|
||||
array y(x.shape(), x.dtype(), x.primitive_ptr(), x.inputs());
|
||||
|
||||
auto& x_parents = parents_map[x.id()];
|
||||
auto& y_parents = parents_map[y.id()];
|
||||
|
||||
for (auto it = x_parents.begin(); it != x_parents.end();) {
|
||||
if (divider.find(it->first.id()) != divider.end()) {
|
||||
it->first.inputs()[it->second] = y;
|
||||
y_parents.emplace_back(std::move(*it));
|
||||
it = x_parents.erase(it);
|
||||
} else {
|
||||
it++;
|
||||
}
|
||||
}
|
||||
|
||||
return std::move(y);
|
||||
}
|
||||
|
||||
template <typename T, typename... U>
|
||||
std::uintptr_t get_function_address(const std::function<T(U...)>& fun) {
|
||||
using FunType = T (*)(U...);
|
||||
@@ -669,10 +693,16 @@ void compile_fuse(
|
||||
}
|
||||
|
||||
// Arrays with a mix of parents outside the compilable section
|
||||
// are not fusable
|
||||
// are not fusable except for broadcast which we can split to avoid
|
||||
// stopping fusion
|
||||
if (!all_parents_in) {
|
||||
// Possible input
|
||||
input_set.insert(a.id());
|
||||
if (a.has_primitive() && is_broadcast(a.primitive())) {
|
||||
array b = split_one(a, parents_map, cache);
|
||||
recurse(b, depth, s, shape);
|
||||
} else {
|
||||
// Possible input
|
||||
input_set.insert(a.id());
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user