diff --git a/mlx/backend/common/compiled.cpp b/mlx/backend/common/compiled.cpp index 8c95eb9ca..762939b1e 100644 --- a/mlx/backend/common/compiled.cpp +++ b/mlx/backend/common/compiled.cpp @@ -126,4 +126,102 @@ std::string build_lib_name( return os.str(); } +bool compiled_check_contiguity( + const std::vector& inputs, + const std::vector& shape) { + bool contiguous = true; + bool all_contig = true; + bool all_row_contig = true; + bool all_col_contig = true; + int non_scalar_inputs = 0; + for (const auto& x : inputs) { + if (is_scalar(x)) { + continue; + } + non_scalar_inputs++; + bool shape_eq = x.shape() == shape; + all_contig &= (x.flags().contiguous && shape_eq); + all_row_contig &= (x.flags().row_contiguous && shape_eq); + all_col_contig &= (x.flags().col_contiguous && shape_eq); + } + if (non_scalar_inputs > 1 && !all_row_contig && !all_col_contig) { + contiguous = false; + } else if (non_scalar_inputs == 1 && !all_contig) { + contiguous = false; + } else if (non_scalar_inputs == 0 && !shape.empty()) { + contiguous = false; + } + return contiguous; +} + +void compiled_allocate_outputs( + const std::vector& inputs, + std::vector& outputs, + const std::vector& inputs_, + const std::unordered_set& constant_ids_, + bool contiguous, + bool move_buffers /* = false */) { + if (contiguous) { + int o = 0; + std::vector strides; + size_t data_size; + array::Flags flags; + for (int i = 0; i < inputs.size() && o < outputs.size(); ++i) { + auto& in = inputs[i]; + // Conditions for donation + // - Correct size + // - Not a scalar + // - Donatable + // - Not a constant + if (in.itemsize() == outputs[o].itemsize() && !is_scalar(in) && + in.is_donatable() && + constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) { + if (move_buffers) { + outputs[o++].move_shared_buffer(in); + } else { + outputs[o++].copy_shared_buffer(in); + } + } + // Get representative input flags to properly set non-donated outputs + if (strides.empty() && in.size() == outputs[0].size()) { + strides = in.strides(); + flags = in.flags(); + data_size = in.data_size(); + } + } + for (; o < outputs.size(); ++o) { + outputs[o].set_data( + allocator::malloc_or_wait(data_size * outputs[o].itemsize()), + data_size, + strides, + flags); + } + } else { + int o = 0; + for (int i = 0; i < inputs.size() && o < outputs.size(); ++i) { + auto& in = inputs[i]; + // Conditions for donation + // - Row contiguous + // - Donatable + // - Correct size + // - Not a constant + if (in.flags().row_contiguous && in.nbytes() == outputs[o].nbytes() && + in.is_donatable() && + constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) { + if (move_buffers) { + outputs[o].move_shared_buffer( + in, outputs[o].strides(), in.flags(), in.data_size()); + } else { + outputs[o].copy_shared_buffer( + in, outputs[o].strides(), in.flags(), in.data_size()); + } + o++; + } + } + for (; o < outputs.size(); ++o) { + outputs[o].set_data(allocator::malloc_or_wait(outputs[o].nbytes())); + } + } +} + } // namespace mlx::core diff --git a/mlx/backend/common/compiled.h b/mlx/backend/common/compiled.h index d01fe4fdc..a08a53e68 100644 --- a/mlx/backend/common/compiled.h +++ b/mlx/backend/common/compiled.h @@ -53,4 +53,18 @@ inline bool is_scalar(const array& x) { return x.ndim() == 0; } +// Check if we can use a contiguous operation given inputs and the output shape +bool compiled_check_contiguity( + const std::vector& inputs, + const std::vector& shape); + +// Allocate space for the outputs possibly with input donation +void compiled_allocate_outputs( + const std::vector& inputs, + std::vector& outputs, + const std::vector& inputs_, + const std::unordered_set& constant_ids_, + bool contiguous, + bool move_buffers = false); + } // namespace mlx::core diff --git a/mlx/backend/common/compiled_cpu.cpp b/mlx/backend/common/compiled_cpu.cpp index f3b136bfa..40acb74b5 100644 --- a/mlx/backend/common/compiled_cpu.cpp +++ b/mlx/backend/common/compiled_cpu.cpp @@ -52,8 +52,25 @@ void* compile( return nullptr; } + std::string kernel_file_name; + + // Deal with long kernel names. Maximum length for files on macOS is 255 + // characters. Clip file name with a little extra room and append a 16 + // character hash. + constexpr int max_file_name_length = 245; + if (kernel_name.size() > max_file_name_length) { + std::ostringstream file_name; + file_name + << std::string_view(kernel_name).substr(0, max_file_name_length - 16); + auto file_id = std::hash{}(kernel_name); + file_name << "_" << std::hex << std::setw(16) << file_id << std::dec; + kernel_file_name = file_name.str(); + } else { + kernel_file_name = kernel_name; + } + std::ostringstream shared_lib_name; - shared_lib_name << "lib" << kernel_name << ".so"; + shared_lib_name << "lib" << kernel_file_name << ".so"; auto shared_lib_path = get_temp_file(shared_lib_name.str()); bool lib_exists = false; { @@ -64,7 +81,7 @@ void* compile( if (!lib_exists) { // Open source file and write source code to it std::ostringstream source_file_name; - source_file_name << kernel_name << ".cpp"; + source_file_name << kernel_file_name << ".cpp"; auto source_file_path = get_temp_file(source_file_name.str()); std::ofstream source_file(source_file_path); @@ -248,28 +265,7 @@ void Compiled::eval_cpu( // Figure out which kernel we are using auto& shape = outputs[0].shape(); - bool contiguous = true; - { - bool all_contig = true; - bool all_row_contig = true; - bool all_col_contig = true; - int non_scalar_inputs = 0; - for (auto& x : inputs) { - if (is_scalar(x)) { - continue; - } - non_scalar_inputs++; - bool shape_eq = x.shape() == shape; - all_contig &= (x.flags().contiguous && shape_eq); - all_row_contig &= (x.flags().row_contiguous && shape_eq); - all_col_contig &= (x.flags().col_contiguous && shape_eq); - } - if (non_scalar_inputs > 1 && !all_row_contig && !all_col_contig) { - contiguous = false; - } else if (non_scalar_inputs == 1 && !all_contig) { - contiguous = false; - } - } + bool contiguous = compiled_check_contiguity(inputs, shape); // Handle all broadcasting and collect function input arguments std::vector args; @@ -342,58 +338,8 @@ void Compiled::eval_cpu( fn_ptr = compile(kernel_name, kernel.str()); } - // Allocate space for the outputs possibly with input donation - if (contiguous) { - int o = 0; - std::vector strides; - size_t data_size; - array::Flags flags; - for (int i = 0; i < inputs.size() && o < outputs.size(); ++i) { - auto& in = inputs[i]; - // Conditions for donation - // - Contiguous - // - Donatable - // - Correct size - // - Not a constant - if (in.flags().contiguous && !is_scalar(in) && in.is_donatable() && - constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) { - outputs[o++].copy_shared_buffer(in); - } - // Get representative input flags to properly set non-donated outputs - if (strides.empty() && in.size() == outputs[0].size()) { - strides = in.strides(); - flags = in.flags(); - data_size = in.data_size(); - } - } - for (; o < outputs.size(); ++o) { - outputs[o].set_data( - allocator::malloc_or_wait(data_size * outputs[o].itemsize()), - data_size, - strides, - flags); - } - } else { - int o = 0; - for (int i = 0; i < inputs.size() && o < outputs.size(); ++i) { - auto& in = inputs[i]; - // Conditions for donation - // - Row contiguous - // - Donatable - // - Correct size - // - Not a constant - if (in.flags().row_contiguous && in.nbytes() == outputs[o].nbytes() && - in.is_donatable() && - constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) { - outputs[o].copy_shared_buffer( - in, outputs[o].strides(), in.flags(), in.data_size()); - o++; - } - } - for (; o < outputs.size(); ++o) { - outputs[o].set_data(allocator::malloc_or_wait(outputs[o].nbytes())); - } - } + compiled_allocate_outputs( + inputs, outputs, inputs_, constant_ids_, contiguous, false); for (auto& x : outputs) { args.push_back(x.data()); diff --git a/mlx/backend/metal/compiled.cpp b/mlx/backend/metal/compiled.cpp index f59d0582a..b7327348a 100644 --- a/mlx/backend/metal/compiled.cpp +++ b/mlx/backend/metal/compiled.cpp @@ -229,14 +229,7 @@ void Compiled::eval_gpu( // Figure out which kernel we are using auto& output_shape = outputs[0].shape(); - bool contiguous = true; - for (auto& x : inputs) { - if ((!x.flags().row_contiguous || x.shape() != output_shape) && - !is_scalar(x)) { - contiguous = false; - break; - } - } + bool contiguous = compiled_check_contiguity(inputs, output_shape); // Collapse contiguous dims to route to a faster kernel if possible. Also // handle all broadcasting. @@ -317,28 +310,8 @@ void Compiled::eval_gpu( } } - // Allocate space for the outputs possibly with input donation - { - int o = 0; - for (int i = 0; i < inputs.size() && o < outputs.size(); ++i) { - auto& in = inputs[i]; - // Conditions for donation - // - Row contiguous - // - Donatable - // - Correct size - // - Not a constant - if (in.flags().row_contiguous && in.nbytes() == outputs[o].nbytes() && - in.is_donatable() && - constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) { - outputs[o].move_shared_buffer( - in, outputs[o].strides(), in.flags(), in.data_size()); - o++; - } - } - for (; o < outputs.size(); ++o) { - outputs[o].set_data(allocator::malloc_or_wait(outputs[o].nbytes())); - } - } + compiled_allocate_outputs( + inputs, outputs, inputs_, constant_ids_, contiguous, true); // Put the outputs in for (auto& x : outputs) { diff --git a/python/tests/test_compile.py b/python/tests/test_compile.py index cfdd334cb..d7cce31f5 100644 --- a/python/tests/test_compile.py +++ b/python/tests/test_compile.py @@ -671,6 +671,26 @@ class TestCompile(mlx_tests.MLXTestCase): out = cmean(x) self.assertTrue(mx.allclose(out, mean(x))) + def test_compile_broadcast_only(self): + def fn(a): + a = mx.broadcast_to(a, (1,)) + return a + a + + out = mx.compile(fn)(mx.array(2.0)) + # Make sure repr can be called + self.assertTrue(repr(out) is not None) + self.assertTrue(mx.array_equal(out, mx.array([4.0]))) + + def test_compile_with_long_name(self): + def fn(a, b): + for _ in range(10): + a = a - 1.0 + b = b - 1.0 + return a + b + + out = mx.compile(fn)(mx.array(10.0), mx.array(20.0)) + self.assertEqual(out.item(), 10.0) + if __name__ == "__main__": unittest.main()