From e7a2a3dcd143e0d128b8300803375c6d463a4222 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Tue, 28 May 2024 15:18:18 -0700
Subject: [PATCH] Fix a couple bugs (#1161)

* fix jit reduce for RMS norm
* make strides a single buffer
* better eval error message
* fix compiling with inf and bf16
* fix cpu compile with bf16
---
 mlx/backend/common/make_compiled_preamble.sh |  1 +
 mlx/backend/metal/compiled.cpp               | 43 +++++++++++++-------
 mlx/backend/metal/jit_kernels.cpp            |  5 ++-
 mlx/backend/metal/kernels.h                  |  1 +
 mlx/backend/metal/kernels/reduce.metal       |  4 +-
 mlx/backend/metal/nojit_kernels.cpp          |  1 +
 mlx/backend/metal/reduce.cpp                 | 14 +++----
 mlx/transforms.cpp                           | 10 +++--
 python/tests/test_compile.py                 |  7 ++++
 9 files changed, 59 insertions(+), 27 deletions(-)

diff --git a/mlx/backend/common/make_compiled_preamble.sh b/mlx/backend/common/make_compiled_preamble.sh
index 050fce25e..93eca9ca1 100644
--- a/mlx/backend/common/make_compiled_preamble.sh
+++ b/mlx/backend/common/make_compiled_preamble.sh
@@ -28,6 +28,7 @@ const char* get_kernel_preamble() {
 return R"preamble(
 $INCLUDES
 $CONTENT
+using namespace mlx::core;
 using namespace mlx::core::detail;
 )preamble";
 }
diff --git a/mlx/backend/metal/compiled.cpp b/mlx/backend/metal/compiled.cpp
index 0bfb177a2..b6bb055f2 100644
--- a/mlx/backend/metal/compiled.cpp
+++ b/mlx/backend/metal/compiled.cpp
@@ -56,12 +56,15 @@ inline void build_kernel(
     } else {
       add_indices = true;
       os << "  device const " << get_type_string(x.dtype()) << "* " << xname
-         << " [[buffer(" << cnt++ << ")]]," << std::endl
-         << "  constant const size_t* " << xname << "_strides [[buffer("
-         << cnt++ << ")]]," << std::endl;
+         << " [[buffer(" << cnt++ << ")]]," << std::endl;
     }
   }
 
+  if (add_indices) {
+    os << "  constant const size_t* in_strides [[buffer(" << cnt++
+       << ")]],\n";
+  }
+
   // Add the output arguments
   for (auto& x : outputs) {
     os << "  device " << get_type_string(x.dtype()) << "* "
@@ -110,13 +113,17 @@
   }
 
   // Read the inputs in tmps
-  for (auto& x : inputs) {
+  int nc_in_count = 0;
+  for (int i = 0; i < inputs.size(); ++i) {
+    auto& x = inputs[i];
     auto& xname = namer.get_name(x);
 
     if (is_constant(x)) {
-      os << "  " << get_type_string(x.dtype()) << " tmp_" << xname << " = ";
+      auto type_str = get_type_string(x.dtype());
+      os << "  auto tmp_" << xname << " = static_cast<"
+         << type_str << ">(";
       print_constant(os, x);
-      os << ";" << std::endl;
+      os << ");" << std::endl;
     } else if (is_scalar(x)) {
       os << "  " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
          << xname << "[0];" << std::endl;
@@ -124,17 +131,20 @@
       os << "  " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
          << xname << "[index];" << std::endl;
     } else if (!dynamic_dims) {
+      int offset = nc_in_count * ndim;
       os << "  " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
          << xname << "[";
-      os << "index_0 * " << xname << "_strides[0]";
+      os << "index_0 * " << "in_strides[" << offset << "]";
       for (int i = 1; i < ndim; i++) {
-        os << " + index_" << i << " * " << xname << "_strides[" << i << "]";
+        os << " + index_" << i << " * " << "in_strides[" << offset + i << "]";
       }
       os << "];" << std::endl;
+      nc_in_count++;
     } else {
       os << "  " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
-         << xname << "[elem_to_loc(index, output_shape, " << xname
-         << "_strides, ndim)];" << std::endl;
+         << xname << "[elem_to_loc(index, output_shape, in_strides + "
+         << nc_in_count * ndim << ", ndim)];" << std::endl;
+      nc_in_count++;
     }
   }
 
@@ -296,6 +306,7 @@ void Compiled::eval_gpu(
   // Put the inputs in
   int cnt = 0;
   int stride_idx = 1; // idx 0 is the output strides
+  std::vector<size_t> in_strides;
   for (int i = 0; i < inputs.size(); i++) {
     if (constant_ids_.find(inputs_[i].id()) != constant_ids_.end()) {
       continue;
@@ -303,13 +314,17 @@ void Compiled::eval_gpu(
     }
     auto& x = inputs[i];
     compute_encoder.set_input_array(x, cnt++);
     if (!contiguous && !is_scalar(x)) {
-      compute_encoder->setBytes(
-          strides[stride_idx].data(),
-          strides[stride_idx].size() * sizeof(size_t),
-          cnt++);
+      in_strides.insert(
+          in_strides.end(),
+          strides[stride_idx].begin(),
+          strides[stride_idx].end());
       stride_idx++;
     }
   }
+  if (!in_strides.empty()) {
+    compute_encoder->setBytes(
+        in_strides.data(), in_strides.size() * sizeof(size_t), cnt++);
+  }
   compiled_allocate_outputs(
       inputs, outputs, inputs_, constant_ids_, contiguous, true);
diff --git a/mlx/backend/metal/jit_kernels.cpp b/mlx/backend/metal/jit_kernels.cpp
index 813b5c392..4616f0cd6 100644
--- a/mlx/backend/metal/jit_kernels.cpp
+++ b/mlx/backend/metal/jit_kernels.cpp
@@ -259,11 +259,14 @@ MTL::ComputePipelineState* get_reduce_init_kernel(
 MTL::ComputePipelineState* get_reduce_kernel(
     metal::Device& d,
     const std::string& kernel_name,
+    const std::string& op_name,
     const array& in,
     const array& out) {
   std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
   auto lib = d.get_library(lib_name);
   if (lib == nullptr) {
+    std::string op_type = op_name;
+    op_type[0] = std::toupper(op_name[0]);
     bool non_atomic = out.dtype() == int64 || out.dtype() == uint64;
     std::ostringstream kernel_source;
     kernel_source << metal::utils() << metal::reduce_utils() << metal::reduce()
@@ -273,7 +276,7 @@ MTL::ComputePipelineState* get_reduce_kernel(
         lib_name,
         get_type_string(in.dtype()),
         get_type_string(out.dtype()),
-        op_name(out));
+        op_type);
     lib = d.get_library(lib_name, kernel_source.str());
   }
   return d.get_kernel(kernel_name, lib);
diff --git a/mlx/backend/metal/kernels.h b/mlx/backend/metal/kernels.h
index e1df0521b..ed1181e4a 100644
--- a/mlx/backend/metal/kernels.h
+++ b/mlx/backend/metal/kernels.h
@@ -76,6 +76,7 @@ MTL::ComputePipelineState* get_reduce_init_kernel(
 MTL::ComputePipelineState* get_reduce_kernel(
     metal::Device& d,
     const std::string& kernel_name,
+    const std::string& op_name,
     const array& in,
     const array& out);
 
diff --git a/mlx/backend/metal/kernels/reduce.metal b/mlx/backend/metal/kernels/reduce.metal
index b8d94df06..0e4650015 100644
--- a/mlx/backend/metal/kernels/reduce.metal
+++ b/mlx/backend/metal/kernels/reduce.metal
@@ -38,8 +38,8 @@
 #define instantiate_reduce_ops(inst_f, type_f) \
   type_f(inst_f, sum, Sum) \
   type_f(inst_f, prod, Prod) \
-  type_f(inst_f, min_, Min) \
-  type_f(inst_f, max_, Max)
+  type_f(inst_f, min, Min) \
+  type_f(inst_f, max, Max)
 
 // Special case for bool reductions
 #define instantiate_reduce_from_types_helper( \
diff --git a/mlx/backend/metal/nojit_kernels.cpp b/mlx/backend/metal/nojit_kernels.cpp
index e14d099d3..f91db5a53 100644
--- a/mlx/backend/metal/nojit_kernels.cpp
+++ b/mlx/backend/metal/nojit_kernels.cpp
@@ -98,6 +98,7 @@ MTL::ComputePipelineState* get_reduce_init_kernel(
 MTL::ComputePipelineState* get_reduce_kernel(
     metal::Device& d,
     const std::string& kernel_name,
+    const std::string&,
     const array&,
     const array&) {
   return d.get_kernel(kernel_name);
diff --git a/mlx/backend/metal/reduce.cpp b/mlx/backend/metal/reduce.cpp
index 73b166658..2d110a5d0 100644
--- a/mlx/backend/metal/reduce.cpp
+++ b/mlx/backend/metal/reduce.cpp
@@ -46,7 +46,7 @@ void all_reduce_dispatch(
     kernel_name += "NoAtomics";
"NoAtomics"; } kernel_name += "_reduce_" + op_name + type_to_name(in); - auto kernel = get_reduce_kernel(d, kernel_name, in, out); + auto kernel = get_reduce_kernel(d, kernel_name, op_name, in, out); compute_encoder->setComputePipelineState(kernel); @@ -175,7 +175,7 @@ void row_reduce_general_dispatch( kname << "rowGeneral" << small_desc << "_reduce_" << op_name << type_to_name(in); - auto kernel = get_reduce_kernel(d, kname.str(), in, out); + auto kernel = get_reduce_kernel(d, kname.str(), op_name, in, out); compute_encoder->setComputePipelineState(kernel); // Get dispatch grid dims @@ -342,7 +342,7 @@ void strided_reduce_general_dispatch( if (reduction_size * non_col_reductions < 16) { // Select kernel auto kernel = get_reduce_kernel( - d, "colSmall_reduce_" + op_name + type_to_name(in), in, out); + d, "colSmall_reduce_" + op_name + type_to_name(in), op_name, in, out); compute_encoder->setComputePipelineState(kernel); // Select block dims @@ -384,7 +384,7 @@ void strided_reduce_general_dispatch( kernel_name += "NoAtomics"; } kernel_name += "_reduce_" + op_name + type_to_name(in); - auto kernel = get_reduce_kernel(d, kernel_name, in, out); + auto kernel = get_reduce_kernel(d, kernel_name, op_name, in, out); compute_encoder->setComputePipelineState(kernel); @@ -501,7 +501,7 @@ void strided_reduce_general_dispatch( std::string kernel_name = "rowGeneralNoAtomics_reduce_" + op_name + type_to_name(intermediate); auto row_reduce_kernel = - get_reduce_kernel(d, kernel_name, intermediate, out); + get_reduce_kernel(d, kernel_name, op_name, intermediate, out); compute_encoder->setComputePipelineState(row_reduce_kernel); compute_encoder.set_input_array(intermediate, 0); @@ -573,10 +573,10 @@ void Reduce::eval_gpu(const std::vector& inputs, array& out) { op_name = out.dtype() == bool_ ? "and" : "prod"; break; case Reduce::Min: - op_name = out.dtype() == bool_ ? "and" : "min_"; + op_name = out.dtype() == bool_ ? "and" : "min"; break; case Reduce::Max: - op_name = out.dtype() == bool_ ? "or" : "max_"; + op_name = out.dtype() == bool_ ? "or" : "max"; break; } diff --git a/mlx/transforms.cpp b/mlx/transforms.cpp index 5798dfffa..3951abce0 100644 --- a/mlx/transforms.cpp +++ b/mlx/transforms.cpp @@ -88,9 +88,13 @@ array eval_impl(std::vector outputs, bool async) { " transformations like compile or vmap is not allowed."); } throw std::runtime_error( - "[eval] Attempting to eval an array without a primitive. " - "This may be a bug, please file an issue here: " - " https://github.com/ml-explore/mlx/issues."); + "[eval] Attempting to eval an array without a primitive.\n" + "If you are compiling a function, make sure all the inputs " + "and outputs are captured:\n" + "https://ml-explore.github.io/mlx/build/html/usage/compile.html#pure-functions.\n" + "If you are not using compile, this may be a bug. " + "Please file an issue here:\n" + "https://github.com/ml-explore/mlx/issues."); } if (a.primitive().stream() != in.primitive().stream()) { needs_signal.insert(in.id()); diff --git a/python/tests/test_compile.py b/python/tests/test_compile.py index 20ca62101..5018e47f4 100644 --- a/python/tests/test_compile.py +++ b/python/tests/test_compile.py @@ -704,6 +704,13 @@ class TestCompile(mlx_tests.MLXTestCase): self.assertEqual(y1.item(), y2.item()) self.assertEqual(y1.item(), 6) + def test_inf_constant(self): + def fn(x): + return mx.where(mx.isinf(x), 0, 1) + + x = mx.array([0, float("inf"), 1], dtype=mx.bfloat16) + self.assertTrue(mx.array_equal(mx.compile(fn)(x), fn(x))) + if __name__ == "__main__": unittest.main()