From 7c441600feee3311e25b24fdd1719be71f4ef53b Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Mon, 11 Mar 2024 06:31:31 -0700 Subject: [PATCH] Compile stride bug (#812) * fix compile stride bug * revert sdpa fix * fix cpu * fix bug with simplifying outputs --- mlx/array.cpp | 21 ++++++++++++++----- mlx/array.h | 7 +++++++ mlx/backend/common/compiled_cpu.cpp | 4 +++- mlx/backend/metal/compiled.cpp | 4 +++- .../scaled_dot_product_attention.metal | 5 ++--- .../metal/scaled_dot_product_attention.cpp | 2 ++ mlx/compile.cpp | 4 ++-- python/tests/test_compile.py | 8 +++++++ tests/compile_tests.cpp | 15 +++++++++++++ 9 files changed, 58 insertions(+), 12 deletions(-) diff --git a/mlx/array.cpp b/mlx/array.cpp index add6be279..864c6973d 100644 --- a/mlx/array.cpp +++ b/mlx/array.cpp @@ -162,12 +162,23 @@ void array::copy_shared_buffer(const array& other) { copy_shared_buffer(other, other.strides(), other.flags(), other.data_size()); } -void array::move_shared_buffer(array other) { +void array::move_shared_buffer( + array other, + const std::vector& strides, + Flags flags, + size_t data_size, + size_t offset /* = 0 */) { array_desc_->data = std::move(other.array_desc_->data); - array_desc_->strides = other.strides(); - array_desc_->flags = other.flags(); - array_desc_->data_size = other.data_size(); - array_desc_->data_ptr = other.array_desc_->data_ptr; + array_desc_->strides = strides; + array_desc_->flags = flags; + array_desc_->data_size = data_size; + auto char_offset = sizeof(char) * itemsize() * offset; + array_desc_->data_ptr = static_cast( + static_cast(other.array_desc_->data_ptr) + char_offset); +} + +void array::move_shared_buffer(array other) { + move_shared_buffer(other, other.strides(), other.flags(), other.data_size()); } array::ArrayDesc::ArrayDesc(const std::vector& shape, Dtype dtype) diff --git a/mlx/array.h b/mlx/array.h index 740e69886..7325a1e76 100644 --- a/mlx/array.h +++ b/mlx/array.h @@ -339,6 +339,13 @@ class array { void copy_shared_buffer(const array& other); + void move_shared_buffer( + array other, + const std::vector& strides, + Flags flags, + size_t data_size, + size_t offset = 0); + void move_shared_buffer(array other); void overwrite_descriptor(const array& other) { diff --git a/mlx/backend/common/compiled_cpu.cpp b/mlx/backend/common/compiled_cpu.cpp index 8ab409886..f3b136bfa 100644 --- a/mlx/backend/common/compiled_cpu.cpp +++ b/mlx/backend/common/compiled_cpu.cpp @@ -385,7 +385,9 @@ void Compiled::eval_cpu( if (in.flags().row_contiguous && in.nbytes() == outputs[o].nbytes() && in.is_donatable() && constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) { - outputs[o++].copy_shared_buffer(in); + outputs[o].copy_shared_buffer( + in, outputs[o].strides(), in.flags(), in.data_size()); + o++; } } for (; o < outputs.size(); ++o) { diff --git a/mlx/backend/metal/compiled.cpp b/mlx/backend/metal/compiled.cpp index 3b1ee116a..0879e623d 100644 --- a/mlx/backend/metal/compiled.cpp +++ b/mlx/backend/metal/compiled.cpp @@ -329,7 +329,9 @@ void Compiled::eval_gpu( if (in.flags().row_contiguous && in.nbytes() == outputs[o].nbytes() && in.is_donatable() && constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) { - outputs[o++].move_shared_buffer(in); + outputs[o].move_shared_buffer( + in, outputs[o].strides(), in.flags(), in.data_size()); + o++; } } for (; o < outputs.size(); ++o) { diff --git a/mlx/backend/metal/kernels/scaled_dot_product_attention.metal b/mlx/backend/metal/kernels/scaled_dot_product_attention.metal index 16ab5975b..fb9f0a111 100644 --- a/mlx/backend/metal/kernels/scaled_dot_product_attention.metal +++ b/mlx/backend/metal/kernels/scaled_dot_product_attention.metal @@ -13,12 +13,10 @@ templatesetThreadgroupMemoryLength(tgroupMemorySize, 0); compute_encoder->dispatchThreadgroups(grid_dims, group_dims); { diff --git a/mlx/compile.cpp b/mlx/compile.cpp index 678f6e8ea..0ede320e5 100644 --- a/mlx/compile.cpp +++ b/mlx/compile.cpp @@ -439,7 +439,8 @@ void compile_simplify( } auto& src = parents->second[j].first; auto& dst = parents->second[i].first; - if (src.id() != dst.id() && array_equivalent(src, dst)) { + if (src.id() != dst.id() && array_equivalent(src, dst) && + output_set.find(src.id()) == output_set.end()) { merge(dst, src, parents_map); mask[j] = true; } @@ -456,7 +457,6 @@ void compile_simplify( return output_set.find(a.id()) == output_set.end(); } }; - bool discard = maybe_merge_parents(arr); for (auto& s : arr.siblings()) { discard &= maybe_merge_parents(s); diff --git a/python/tests/test_compile.py b/python/tests/test_compile.py index 18f523211..98d7f7276 100644 --- a/python/tests/test_compile.py +++ b/python/tests/test_compile.py @@ -605,6 +605,14 @@ class TestCompile(mlx_tests.MLXTestCase): with self.assertRaises(ValueError): out = fun(mx.array(0.0), y=MyClass()) + def test_compile_create_list(self): + @mx.compile + def fun(): + return [0.1 * mx.zeros((2,)), 0.1 * mx.zeros((2,))] + + out = fun() + mx.eval(out) + if __name__ == "__main__": unittest.main() diff --git a/tests/compile_tests.cpp b/tests/compile_tests.cpp index 569ab0913..c8c308c99 100644 --- a/tests/compile_tests.cpp +++ b/tests/compile_tests.cpp @@ -703,3 +703,18 @@ TEST_CASE("test shapeless compile") { CHECK_NE(out.inputs()[1].id(), out2.inputs()[1].id()); } } + +auto compile_broadcast_add(const std::vector& inputs) { + auto b = zeros({8, 8}); + return std::vector{inputs[0] + b}; +} + +TEST_CASE("test compile strides") { + { + auto cfun = compile(compile_broadcast_add); + auto a = zeros({1, 8, 8}); + auto out = cfun({a})[0]; + eval(out); + CHECK_EQ(out.strides().size(), 3); + } +}