mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-30 21:51:25 +08:00
Split broadcast so it is always fused in compile (#2318)
This commit is contained in:
parent
656ed7f780
commit
2c11d10f8d
@ -245,6 +245,30 @@ void merge(array& dst, array& src, ParentsMap& parents_map) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Any parent in the divider will continue to refer to `x` but any parent not
|
||||||
|
// in the divider will refer to a copy of the operation.
|
||||||
|
array split_one(
|
||||||
|
const array& x,
|
||||||
|
ParentsMap& parents_map,
|
||||||
|
const std::unordered_set<uintptr_t>& divider) {
|
||||||
|
array y(x.shape(), x.dtype(), x.primitive_ptr(), x.inputs());
|
||||||
|
|
||||||
|
auto& x_parents = parents_map[x.id()];
|
||||||
|
auto& y_parents = parents_map[y.id()];
|
||||||
|
|
||||||
|
for (auto it = x_parents.begin(); it != x_parents.end();) {
|
||||||
|
if (divider.find(it->first.id()) != divider.end()) {
|
||||||
|
it->first.inputs()[it->second] = y;
|
||||||
|
y_parents.emplace_back(std::move(*it));
|
||||||
|
it = x_parents.erase(it);
|
||||||
|
} else {
|
||||||
|
it++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return std::move(y);
|
||||||
|
}
|
||||||
|
|
||||||
template <typename T, typename... U>
|
template <typename T, typename... U>
|
||||||
std::uintptr_t get_function_address(const std::function<T(U...)>& fun) {
|
std::uintptr_t get_function_address(const std::function<T(U...)>& fun) {
|
||||||
using FunType = T (*)(U...);
|
using FunType = T (*)(U...);
|
||||||
@ -669,10 +693,16 @@ void compile_fuse(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Arrays with a mix of parents outside the compilable section
|
// Arrays with a mix of parents outside the compilable section
|
||||||
// are not fusable
|
// are not fusable except for broadcast which we can split to avoid
|
||||||
|
// stopping fusion
|
||||||
if (!all_parents_in) {
|
if (!all_parents_in) {
|
||||||
// Possible input
|
if (a.has_primitive() && is_broadcast(a.primitive())) {
|
||||||
input_set.insert(a.id());
|
array b = split_one(a, parents_map, cache);
|
||||||
|
recurse(b, depth, s, shape);
|
||||||
|
} else {
|
||||||
|
// Possible input
|
||||||
|
input_set.insert(a.id());
|
||||||
|
}
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -5,6 +5,7 @@ import io
|
|||||||
import math
|
import math
|
||||||
import unittest
|
import unittest
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
from io import StringIO
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx_tests
|
import mlx_tests
|
||||||
@ -991,6 +992,28 @@ class TestCompile(mlx_tests.MLXTestCase):
|
|||||||
y_compiled = mx.compile(fun)(x).item()
|
y_compiled = mx.compile(fun)(x).item()
|
||||||
self.assertEqual(y, y_compiled)
|
self.assertEqual(y, y_compiled)
|
||||||
|
|
||||||
|
def test_shared_broadcast(self):
|
||||||
|
def fun(x, y, z):
|
||||||
|
yy = mx.broadcast_to(y, z.shape)
|
||||||
|
return (x + yy * z), yy.sum()
|
||||||
|
|
||||||
|
a = mx.random.normal((10, 10))
|
||||||
|
b = mx.array(0.1)
|
||||||
|
c = mx.random.normal((10, 10))
|
||||||
|
mx.eval(a, b, c)
|
||||||
|
fc = mx.compile(fun)
|
||||||
|
d = fc(a, b, c)
|
||||||
|
|
||||||
|
s = StringIO()
|
||||||
|
mx.export_to_dot(s, a=a, b=b, c=c, d1=d[0], d2=d[1])
|
||||||
|
s.seek(0)
|
||||||
|
s = s.read()
|
||||||
|
|
||||||
|
self.assertTrue("CompiledBroadcastMultiplyAdd" in s)
|
||||||
|
d_hat = fun(a, b, c)
|
||||||
|
self.assertTrue(mx.allclose(d[0], d_hat[0]))
|
||||||
|
self.assertTrue(mx.allclose(d[1], d_hat[1]))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
mlx_tests.MLXTestRunner()
|
mlx_tests.MLXTestRunner()
|
||||||
|
Loading…
Reference in New Issue
Block a user