mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
Limit compile buffers (#1887)
* limit compile buffers
* maybe not flaky test
This commit is contained in:
parent
78ba24c37d
commit
c707b2b0a6
@ -15,6 +15,7 @@
|
|||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
constexpr int max_compile_depth = 11;
|
constexpr int max_compile_depth = 11;
|
||||||
|
constexpr int max_compile_arrays = 24;
|
||||||
|
|
||||||
bool is_unary(const Primitive& p) {
|
bool is_unary(const Primitive& p) {
|
||||||
return (
|
return (
|
||||||
@ -570,6 +571,7 @@ void compile_fuse(
|
|||||||
|
|
||||||
std::function<void(const array&, int, const Stream&, const Shape&)> recurse;
|
std::function<void(const array&, int, const Stream&, const Shape&)> recurse;
|
||||||
std::unordered_set<uintptr_t> cache;
|
std::unordered_set<uintptr_t> cache;
|
||||||
|
std::unordered_set<uintptr_t> input_set;
|
||||||
recurse = [&](const array& a,
|
recurse = [&](const array& a,
|
||||||
int depth,
|
int depth,
|
||||||
const Stream& s,
|
const Stream& s,
|
||||||
@ -587,6 +589,8 @@ void compile_fuse(
|
|||||||
if (depth >= max_compile_depth || !a.has_primitive() ||
|
if (depth >= max_compile_depth || !a.has_primitive() ||
|
||||||
a.primitive().stream() != s || !is_fusable(a.primitive()) ||
|
a.primitive().stream() != s || !is_fusable(a.primitive()) ||
|
||||||
(output_map.find(a.id()) != output_map.end() && a.shape() != shape)) {
|
(output_map.find(a.id()) != output_map.end() && a.shape() != shape)) {
|
||||||
|
// Possible input
|
||||||
|
input_set.insert(a.id());
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -607,9 +611,20 @@ void compile_fuse(
|
|||||||
// Arrays with a mix of parents outside the compilable section
|
// Arrays with a mix of parents outside the compilable section
|
||||||
// are not fusable
|
// are not fusable
|
||||||
if (!all_parents_in) {
|
if (!all_parents_in) {
|
||||||
|
// Possible input
|
||||||
|
input_set.insert(a.id());
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (output_map.find(a.id()) != output_map.end()) {
|
||||||
|
input_set.insert(a.id());
|
||||||
|
} else {
|
||||||
|
// Not an input anymore since fusing it
|
||||||
|
input_set.erase(a.id());
|
||||||
|
}
|
||||||
|
if (input_set.size() >= max_compile_arrays) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
cache.insert({a.id()});
|
cache.insert({a.id()});
|
||||||
|
|
||||||
for (auto& in : a.inputs()) {
|
for (auto& in : a.inputs()) {
|
||||||
@ -630,7 +645,7 @@ void compile_fuse(
|
|||||||
|
|
||||||
// Recurse a second time to build the tape in the right
|
// Recurse a second time to build the tape in the right
|
||||||
// order and collect the inputs
|
// order and collect the inputs
|
||||||
std::unordered_set<uintptr_t> input_set;
|
input_set.clear();
|
||||||
std::vector<array> inputs;
|
std::vector<array> inputs;
|
||||||
std::vector<array> fused_tape;
|
std::vector<array> fused_tape;
|
||||||
std::unordered_set<uintptr_t> tape_set;
|
std::unordered_set<uintptr_t> tape_set;
|
||||||
|
@ -177,6 +177,7 @@ class TestDistributed(mlx_tests.MLXTestCase):
|
|||||||
def test_donation(self):
|
def test_donation(self):
|
||||||
x = mx.random.normal((1024,))
|
x = mx.random.normal((1024,))
|
||||||
mx.eval(x)
|
mx.eval(x)
|
||||||
|
mx.synchronize(mx.default_stream(mx.default_device()))
|
||||||
|
|
||||||
mx.metal.reset_peak_memory()
|
mx.metal.reset_peak_memory()
|
||||||
scale = mx.array(2.0)
|
scale = mx.array(2.0)
|
||||||
|
@ -815,6 +815,31 @@ class TestCompile(mlx_tests.MLXTestCase):
|
|||||||
out = fun(*inputs)
|
out = fun(*inputs)
|
||||||
self.assertTrue(mx.allclose(out, mx.full((2, 2), 20)))
|
self.assertTrue(mx.allclose(out, mx.full((2, 2), 20)))
|
||||||
|
|
||||||
|
@mx.compile
|
||||||
|
def fun(arrs):
|
||||||
|
for _ in range(6):
|
||||||
|
arrs = [x + y for x, y in zip(arrs[::2], arrs[1::2])]
|
||||||
|
return arrs[0]
|
||||||
|
|
||||||
|
arrs = [mx.array([1.0, 2.0]) for _ in range(64)]
|
||||||
|
out = fun(arrs)
|
||||||
|
self.assertTrue(mx.allclose(out, mx.array([64.0, 128.0])))
|
||||||
|
|
||||||
|
def test_compile_many_outputs(self):
|
||||||
|
|
||||||
|
@mx.compile
|
||||||
|
def fun(arr):
|
||||||
|
arrs = [arr] * 64
|
||||||
|
first_arrs = None
|
||||||
|
for _ in range(6):
|
||||||
|
arrs = [x + y for x, y in zip(arrs[::2], arrs[1::2])]
|
||||||
|
if first_arrs is None:
|
||||||
|
first_arrs = arrs
|
||||||
|
return arrs[0], first_arrs
|
||||||
|
|
||||||
|
out = fun(mx.array([1.0, 2.0]))
|
||||||
|
self.assertTrue(mx.allclose(out[0], mx.array([64.0, 128.0])))
|
||||||
|
|
||||||
def test_shapeless_compile_matmul(self):
|
def test_shapeless_compile_matmul(self):
|
||||||
a = mx.array([0.0, 1.0, 2.0])
|
a = mx.array([0.0, 1.0, 2.0])
|
||||||
b = mx.array([0.0, 1.0, 2.0])
|
b = mx.array([0.0, 1.0, 2.0])
|
||||||
|
@ -385,6 +385,7 @@ class TestLoad(mlx_tests.MLXTestCase):
|
|||||||
mx.eval(x)
|
mx.eval(x)
|
||||||
save_file = os.path.join(self.test_dir, "donation.npy")
|
save_file = os.path.join(self.test_dir, "donation.npy")
|
||||||
mx.save(save_file, x)
|
mx.save(save_file, x)
|
||||||
|
mx.synchronize(mx.default_stream(mx.default_device()))
|
||||||
|
|
||||||
mx.metal.reset_peak_memory()
|
mx.metal.reset_peak_memory()
|
||||||
scale = mx.array(2.0)
|
scale = mx.array(2.0)
|
||||||
|
Loading…
Reference in New Issue
Block a user