From 979abf462bbc5ac3ce6f0e52161de9a12f8782ff Mon Sep 17 00:00:00 2001
From: Ronan Collobert
Date: Fri, 31 Oct 2025 09:43:29 -0700
Subject: [PATCH] WIP (metal)

Warning cleanup in the Metal backend: use std::ssize() for signed loop
bounds, widen size_t counters to int64_t, comment out unused parameter
names, and (void)-cast arguments that are only used in debug builds.

---
 mlx/backend/metal/compiled.cpp      | 14 +++++------
 mlx/backend/metal/conv.cpp          |  2 +-
 mlx/backend/metal/custom_kernel.cpp | 14 +++++------
 mlx/backend/metal/device.cpp        |  5 +++-
 mlx/backend/metal/device.h          |  2 +-
 mlx/backend/metal/fence.cpp         |  8 +++---
 mlx/backend/metal/fft.cpp           | 22 ++++++++---------
 mlx/backend/metal/hadamard.cpp      |  2 +-
 mlx/backend/metal/indexing.cpp      | 38 ++++++++++++++---------------
 mlx/backend/metal/matmul.cpp        |  2 +-
 mlx/backend/metal/nojit_kernels.cpp |  4 +--
 mlx/backend/metal/normalization.cpp |  6 ++---
 mlx/backend/metal/primitives.cpp    | 30 +++++++++++++----------
 mlx/backend/metal/reduce.cpp        | 14 +++++------
 mlx/backend/metal/resident.cpp      |  2 +-
 mlx/backend/metal/slicing.cpp       |  2 +-
 mlx/backend/metal/utils.h           |  7 ++++++
 17 files changed, 94 insertions(+), 80 deletions(-)

diff --git a/mlx/backend/metal/compiled.cpp b/mlx/backend/metal/compiled.cpp
index eb51ab750..e2173dc87 100644
--- a/mlx/backend/metal/compiled.cpp
+++ b/mlx/backend/metal/compiled.cpp
@@ -109,7 +109,7 @@ inline void build_kernel(
 
   // Read constant / contiguous inputs in tmps
   std::vector<array> nc_inputs;
-  for (int i = 0; i < inputs.size(); ++i) {
+  for (int i = 0; i < std::ssize(inputs); ++i) {
     auto& x = inputs[i];
     auto& xname = namer.get_name(x);
 
@@ -134,7 +134,7 @@ inline void build_kernel(
   }
 
   // Initialize the indices for non-contiguous inputs
-  for (int i = 0; i < nc_inputs.size(); ++i) {
+  for (int i = 0; i < std::ssize(nc_inputs); ++i) {
     auto& xname = namer.get_name(nc_inputs[i]);
     os += fmt::format(" {0} index_{1} = ", idx_type, xname);
     if (ndim == 1) {
@@ -174,7 +174,7 @@ inline void build_kernel(
       os += fmt::format(" for (int d = {0}; d >= 0; --d) {{\n", ndim - 3);
     }
     os += " uint l = zpos % output_shape[d];\n";
-    for (int i = 0; i < nc_inputs.size(); ++i) {
+    for (int i = 0; i < std::ssize(nc_inputs); ++i) {
       auto& xname = namer.get_name(nc_inputs[i]);
       os += fmt::format(" index_{0} += ", xname);
       if (dynamic_dims) {
@@ -195,7 +195,7 @@ inline void build_kernel(
   }
 
   // Read non-contiguous inputs into tmps
-  for (int i = 0; i < nc_inputs.size(); ++i) {
+  for (int i = 0; i < std::ssize(nc_inputs); ++i) {
     auto& x = nc_inputs[i];
     auto& xname = namer.get_name(x);
     os += fmt::format(
@@ -214,7 +214,7 @@ inline void build_kernel(
     } else {
       os += x.primitive().name();
       os += "()(";
-      for (int i = 0; i < x.inputs().size() - 1; i++) {
+      for (int i = 0; i < std::ssize(x.inputs()) - 1; i++) {
         os += fmt::format("tmp_{0}, ", namer.get_name(x.inputs()[i]));
       }
       os += fmt::format("tmp_{0});\n", namer.get_name(x.inputs().back()));
@@ -227,7 +227,7 @@ inline void build_kernel(
   }
   // Increment indices and close per thread loop
   if (work_per_thread > 1) {
-    for (int i = 0; i < nc_inputs.size(); ++i) {
+    for (int i = 0; i < std::ssize(nc_inputs); ++i) {
       auto& x = nc_inputs[i];
       auto& xname = namer.get_name(x);
       if (!dynamic_dims) {
@@ -396,7 +396,7 @@ void Compiled::eval_gpu(
   int cnt = 0;
   int stride_idx = 1; // idx 0 is the output strides
   Strides in_strides;
-  for (int i = 0; i < inputs.size(); i++) {
+  for (int i = 0; i < std::ssize(inputs); i++) {
     if (is_constant_(i)) {
       continue;
     }
diff --git a/mlx/backend/metal/conv.cpp b/mlx/backend/metal/conv.cpp
index b4a674ff0..e09a3175c 100644
--- a/mlx/backend/metal/conv.cpp
+++ b/mlx/backend/metal/conv.cpp
@@ -990,7 +990,7 @@ void conv_3D_gpu(
     const std::vector<int>& wt_dilation,
     const std::vector<int>& in_dilation,
     bool flip,
-    std::vector<array>& copies) {
+    std::vector<array>& /* copies */) {
   // Make conv params
   MLXConvParams<3> conv_params{
       /* const int N = */ static_cast<int>(in.shape(0)),
diff --git a/mlx/backend/metal/custom_kernel.cpp b/mlx/backend/metal/custom_kernel.cpp
index c48b93c91..deaf1f0f6 100644
--- a/mlx/backend/metal/custom_kernel.cpp
+++ b/mlx/backend/metal/custom_kernel.cpp
@@ -68,7 +68,7 @@ std::string write_signature(
   int index = 0;
   constexpr int max_constant_array_size = 8;
   // Add inputs
-  for (int i = 0; i < inputs.size(); ++i) {
+  for (int i = 0; i < std::ssize(inputs); ++i) {
     const auto& name = input_names[i];
     const auto& arr = inputs[i];
     auto dtype = get_type_string(arr.dtype());
@@ -109,7 +109,7 @@ std::string write_signature(
     }
   }
   // Add outputs
-  for (int i = 0; i < output_names.size(); ++i) {
+  for (int i = 0; i < std::ssize(output_names); ++i) {
     const auto& name = output_names[i];
     const auto& dtype = output_dtypes[i];
     kernel_source += " device ";
@@ -126,8 +126,8 @@ std::string write_signature(
     kernel_source += " [[buffer(";
     kernel_source += std::to_string(index);
     kernel_source += ")]]";
-    if (index < inputs.size() + output_names.size() - 1 ||
-        attributes.size() > 0) {
+    if (index < std::ssize(inputs) + std::ssize(output_names) - 1 ||
+        std::ssize(attributes) > 0) {
       kernel_source += ",\n";
     } else {
       kernel_source += ") {\n";
@@ -138,7 +138,7 @@ std::string write_signature(
   index = 0;
   for (const auto& attr : attributes) {
     kernel_source += attr;
-    if (index < attributes.size() - 1) {
+    if (index < std::ssize(attributes) - 1) {
       kernel_source += ",\n";
     } else {
       kernel_source += ") {\n";
@@ -381,7 +381,7 @@ void CustomKernel::eval_gpu(
   auto& compute_encoder = d.get_command_encoder(s.index);
   compute_encoder.set_compute_pipeline_state(kernel);
   int index = 0;
-  for (int i = 0; i < checked_inputs.size(); i++) {
+  for (int i = 0; i < std::ssize(checked_inputs); i++) {
     const array& in = checked_inputs[i];
     auto& shape_info = shape_infos_[i];
     compute_encoder.set_input_array(in, index);
@@ -408,7 +408,7 @@ void CustomKernel::eval_gpu(
   }
 
   const auto [tx, ty, tz] = threadgroup_;
-  auto tg_size = tx * ty * tz;
+  unsigned long tg_size = tx * ty * tz;
   auto max_tg_size = kernel->maxTotalThreadsPerThreadgroup();
   if (tg_size > max_tg_size) {
     std::ostringstream msg;
diff --git a/mlx/backend/metal/device.cpp b/mlx/backend/metal/device.cpp
index e82d734a2..5465603df 100644
--- a/mlx/backend/metal/device.cpp
+++ b/mlx/backend/metal/device.cpp
@@ -127,6 +127,9 @@ std::pair<MTL::Library*, NS::Error*> load_swiftpm_library(
       }
     }
   }
+#else
+  (void)device;
+  (void)lib_name;
 #endif
   return {nullptr, nullptr};
 }
@@ -713,7 +716,7 @@ MTL::LinkedFunctions* Device::get_linked_functions_(
   auto lfuncs = MTL::LinkedFunctions::linkedFunctions();
 
   std::vector<NS::Object*> objs(funcs.size());
-  for (int i = 0; i < funcs.size(); i++) {
+  for (int i = 0; i < std::ssize(funcs); i++) {
     objs[i] = funcs[i];
   }
 
diff --git a/mlx/backend/metal/device.h b/mlx/backend/metal/device.h
index fefb7cdc0..663e04cd8 100644
--- a/mlx/backend/metal/device.h
+++ b/mlx/backend/metal/device.h
@@ -137,7 +137,7 @@ struct DeviceStream {
   // Data updated between command buffers
   MTL::CommandBuffer* buffer{nullptr};
   int buffer_ops{0};
-  size_t buffer_sizes{0};
+  int64_t buffer_sizes{0};
 
   // The command encoder, fence, and temporaries are updated between command
   // encoders
diff --git a/mlx/backend/metal/fence.cpp b/mlx/backend/metal/fence.cpp
index 5abdf7309..8068e84a4 100644
--- a/mlx/backend/metal/fence.cpp
+++ b/mlx/backend/metal/fence.cpp
@@ -76,7 +76,7 @@ void Fence::wait(Stream stream, const array& x) {
     auto command_buffer = d.get_command_buffer(idx);
     command_buffer->encodeWait(static_cast<MTL::Event*>(f.fence), f.count);
     command_buffer->addCompletedHandler(
-        [fence_ = fence_](MTL::CommandBuffer* cbuf) {});
+        [fence_ = fence_](MTL::CommandBuffer* /* cbuf */) {});
     return;
   }
 
@@ -96,7 +96,7 @@ void Fence::wait(Stream stream, const array& x) {
   compute_encoder.dispatch_threads(kernel_dims, kernel_dims);
 
   d.get_command_buffer(idx)->addCompletedHandler(
-      [fence_ = fence_](MTL::CommandBuffer* cbuf) {});
+      [fence_ = fence_](MTL::CommandBuffer* /* cbuf */) {});
 }
 
 void Fence::update(Stream stream, const array& x) {
@@ -124,7 +124,7 @@ void Fence::update(Stream stream, const array& x) {
     command_buffer->encodeSignalEvent(
         static_cast<MTL::Event*>(f.fence), f.count);
     command_buffer->addCompletedHandler(
-        [fence_ = fence_](MTL::CommandBuffer* cbuf) {});
+        [fence_ = fence_](MTL::CommandBuffer* /* cbuf */) {});
     return;
   }
 
@@ -154,7 +154,7 @@ void Fence::update(Stream stream, const array& x) {
   compute_encoder.dispatch_threads(kernel_dims, kernel_dims);
 
   d.get_command_buffer(idx)->addCompletedHandler(
-      [fence_ = fence_](MTL::CommandBuffer* cbuf) {});
+      [fence_ = fence_](MTL::CommandBuffer* /* cbuf */) {});
 }
 
 } // namespace mlx::core
diff --git a/mlx/backend/metal/fft.cpp b/mlx/backend/metal/fft.cpp
index d329a4685..74165910c 100644
--- a/mlx/backend/metal/fft.cpp
+++ b/mlx/backend/metal/fft.cpp
@@ -60,7 +60,7 @@ struct FourStepParams {
 void fft_op(
     const array& in,
     array& out,
-    size_t axis,
+    int64_t axis,
     bool inverse,
     bool real,
     const FourStepParams four_step_params,
@@ -93,7 +93,7 @@ std::vector<int> plan_stockham_fft(int n) {
   if (n == 1) {
     return plan;
   }
-  for (int i = 0; i < radices.size(); i++) {
+  for (int i = 0; i < std::ssize(radices); i++) {
     int radix = radices[i];
     // Manually tuned radices for powers of 2
     if (is_power_of_2(orig_n) && orig_n < 512 && radix > 4) {
@@ -181,7 +181,7 @@ int compute_elems_per_thread(FFTPlan plan) {
   steps.insert(steps.end(), plan.stockham.begin(), plan.stockham.end());
   steps.insert(steps.end(), plan.rader.begin(), plan.rader.end());
   std::set<int> used_radices;
-  for (int i = 0; i < steps.size(); i++) {
+  for (int i = 0; i < std::ssize(steps); i++) {
     int radix = radices[i % radices.size()];
     if (steps[i] > 0) {
       used_radices.insert(radix);
@@ -260,7 +260,7 @@ int primitive_root(int n) {
 
 std::tuple<array, array, array> compute_raders_constants(
     int rader_n,
-    const Stream& s) {
+    const Stream& /* s */) {
   int proot = primitive_root(rader_n);
   // Fermat's little theorem
   int inv = mod_exp(proot, rader_n - 2, rader_n);
@@ -508,7 +508,7 @@ void four_step_fft(
 void fft_op(
     const array& in,
     array& out,
-    size_t axis,
+    int64_t axis,
     bool inverse,
     bool real,
     const FourStepParams four_step_params,
@@ -612,11 +612,11 @@ void fft_op(
 
   // Start of radix/rader step constants
   int index = 4;
-  for (int i = 0; i < plan.stockham.size(); i++) {
+  for (int i = 0; i < std::ssize(plan.stockham); i++) {
     func_consts.push_back(make_int(&plan.stockham[i], index));
     index += 1;
   }
-  for (int i = 0; i < plan.rader.size(); i++) {
+  for (int i = 0; i < std::ssize(plan.rader); i++) {
     func_consts.push_back(make_int(&plan.rader[i], index));
     index += 1;
   }
@@ -771,8 +771,8 @@ void nd_fft_op(
   array temp1(temp_shape, complex64, nullptr, {});
   array temp2(temp_shape, complex64, nullptr, {});
   std::vector<array> temp_arrs = {temp1, temp2};
-  for (int i = axes.size() - 1; i >= 0; i--) {
-    int reverse_index = axes.size() - i - 1;
+  for (int i = std::ssize(axes) - 1; i >= 0; i--) {
+    int reverse_index = std::ssize(axes) - i - 1;
     // For 5D and above, we don't want to reallocate our two temporary arrays
     bool inplace = reverse_index >= 3 && i != 0;
     // Opposite order for fft vs ifft
@@ -780,8 +780,8 @@ void nd_fft_op(
     size_t axis = axes[index];
     // Mirror np.fft.(i)rfftn and perform a real transform
     // only on the final axis.
-    bool step_real = (real && index == axes.size() - 1);
-    const array& in_arr = i == axes.size() - 1 ? in : temp_arrs[1 - i % 2];
+    bool step_real = (real && index == std::ssize(axes) - 1);
+    const array& in_arr = i == std::ssize(axes) - 1 ? in : temp_arrs[1 - i % 2];
     array& out_arr = i == 0 ? out : temp_arrs[i % 2];
     fft_op(in_arr, out_arr, axis, inverse, step_real, inplace, s);
   }
diff --git a/mlx/backend/metal/hadamard.cpp b/mlx/backend/metal/hadamard.cpp
index 65a877151..bf115c630 100644
--- a/mlx/backend/metal/hadamard.cpp
+++ b/mlx/backend/metal/hadamard.cpp
@@ -43,7 +43,7 @@ std::string gen_hadamard_codelet(int m) {
   while (end != std::string_view::npos) {
     source << " tmp[" << index << "] = ";
     auto row = matrix.substr(start, end - start);
-    for (int i = 0; i < row.length(); i++) {
+    for (int i = 0; i < std::ssize(row); i++) {
       source << " " << row[i] << " x[" << i << "]";
     }
     source << ";" << std::endl;
diff --git a/mlx/backend/metal/indexing.cpp b/mlx/backend/metal/indexing.cpp
index 025098757..8a215267b 100644
--- a/mlx/backend/metal/indexing.cpp
+++ b/mlx/backend/metal/indexing.cpp
@@ -52,7 +52,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
   auto& s = stream();
   auto& d = metal::device(s.device);
 
-  size_t slice_size = 1;
+  int64_t slice_size = 1;
   for (auto s : slice_sizes_) {
     slice_size *= s;
   }
@@ -94,8 +94,8 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
     auto kernel = d.get_kernel(kernel_name, lib);
     compute_encoder.set_compute_pipeline_state(kernel);
 
-    size_t dim_x = (slice_size + work_per_thread - 1) / work_per_thread;
-    size_t dim_y = indices.size();
+    int64_t dim_x = (slice_size + work_per_thread - 1) / work_per_thread;
+    int64_t dim_y = indices.size();
     auto group_dims = get_block_dims(dim_x, dim_y, 1);
     MTL::Size grid_dims = MTL::Size(dim_x, dim_y, 1);
 
@@ -110,7 +110,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
   }
 
   int idx_ndim = nidx ? inputs[1].ndim() : 0;
-  size_t ndim = src.ndim();
+  int64_t ndim = src.ndim();
 
   std::string kernel_name = fmt::format(
       "gather{0}{1}_{2}_{3}_{4}",
@@ -149,8 +149,8 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
 
   // Launch 3D grid of threads
   // First two dimensions for the indices, the last one for the slice
-  size_t dim0 = 1;
-  size_t dim1 = 1;
+  int64_t dim0 = 1;
+  int64_t dim1 = 1;
   if (nidx) {
     if (inputs[1].ndim() >= 1) {
       dim0 = inputs[1].shape(0);
@@ -159,13 +159,13 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
       dim1 = inputs[1].size() / dim0;
     }
   }
-  size_t dim2 = slice_size;
+  int64_t dim2 = slice_size;
   auto group_dims = get_block_dims(dim0, dim1, dim2);
   MTL::Size grid_dims = MTL::Size(dim0, dim1, dim2);
 
   // Collect all idx shapes and strides into one place
   std::vector<int> idx_shapes;
-  std::vector<size_t> idx_strides;
+  std::vector<int64_t> idx_strides;
   std::vector<char> idx_contigs;
   for (int i = 0; i < nidx; ++i) {
     idx_shapes.insert(
@@ -246,7 +246,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
   auto& d = metal::device(s.device);
 
   int idx_ndim = nidx ? inputs[1].ndim() : 0;
-  size_t idx_size = nidx ? inputs[1].size() : 1;
+  int64_t idx_size = nidx ? inputs[1].size() : 1;
 
   auto idx_to_out = idx_size / out.size();
   int nwork;
@@ -345,7 +345,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
   auto& compute_encoder = d.get_command_encoder(s.index);
   auto kernel = d.get_kernel(kernel_name, lib);
 
-  size_t nthreads = upd.size();
+  int64_t nthreads = upd.size();
 
   compute_encoder.set_compute_pipeline_state(kernel);
 
@@ -354,8 +354,8 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
   compute_encoder.set_output_array(out, 2);
 
   // Set update info
-  size_t upd_ndim = upd.ndim();
-  size_t upd_size = 1;
+  int64_t upd_ndim = upd.ndim();
+  int64_t upd_size = 1;
   for (int i = idx_ndim; i < upd.ndim(); ++i) {
     upd_size *= upd.shape(i);
   }
@@ -391,7 +391,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
   compute_encoder.set_bytes(upd_size, 6);
 
   // Set output info
-  size_t out_ndim = out.ndim();
+  int64_t out_ndim = out.ndim();
   if (out_ndim == 0) {
     // Need placeholders so Metal doesn't complain
     int shape_ = 0;
@@ -448,7 +448,7 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
   auto& s = stream();
   auto& d = metal::device(s.device);
 
-  size_t ndim = src.ndim();
+  int64_t ndim = src.ndim();
 
   bool large = idx.size() > INT32_MAX || src.size() > INT32_MAX;
 
@@ -486,8 +486,8 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
   compute_encoder.set_compute_pipeline_state(kernel);
 
   // Grid [size post, index size, size pre]
-  size_t size_pre = 1;
-  size_t size_post = 1;
+  int64_t size_pre = 1;
+  int64_t size_post = 1;
   for (int i = 0; i < axis_; ++i) {
     size_pre *= idx.shape(i);
   }
@@ -541,7 +541,7 @@ void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
   auto& s = stream();
   auto& d = metal::device(s.device);
 
-  size_t ndim = src.ndim();
+  int64_t ndim = src.ndim();
 
   bool large = idx.size() > INT32_MAX || src.size() > INT32_MAX;
 
@@ -602,8 +602,8 @@ void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
   compute_encoder.set_compute_pipeline_state(kernel);
 
   // Grid [size post, index size, size pre]
-  size_t size_pre = 1;
-  size_t size_post = 1;
+  int64_t size_pre = 1;
+  int64_t size_post = 1;
   for (int i = 0; i < axis_; ++i) {
     size_pre *= idx.shape(i);
   }
diff --git a/mlx/backend/metal/matmul.cpp b/mlx/backend/metal/matmul.cpp
index d6bee651d..4a8f3d77d 100644
--- a/mlx/backend/metal/matmul.cpp
+++ b/mlx/backend/metal/matmul.cpp
@@ -344,7 +344,7 @@ void steel_gemm_splitk_axpby(
     int M,
     int N,
     int K,
-    int batch_size_out,
+    int /* batch_size_out */,
     int lda,
     int ldb,
     bool transpose_a,
diff --git a/mlx/backend/metal/nojit_kernels.cpp b/mlx/backend/metal/nojit_kernels.cpp
index 109dd8df7..46d2cb5e0 100644
--- a/mlx/backend/metal/nojit_kernels.cpp
+++ b/mlx/backend/metal/nojit_kernels.cpp
@@ -179,8 +179,8 @@ MTL::ComputePipelineState* get_steel_gemm_masked_kernel(
     metal::Device& d,
     const std::string& kernel_name,
     const array&,
-    const std::optional<array>& mask_out,
-    const std::optional<array>& mask_op,
+    const std::optional<array>& /* mask_out */,
+    const std::optional<array>& /* mask_op */,
     bool,
     bool,
     int,
diff --git a/mlx/backend/metal/normalization.cpp b/mlx/backend/metal/normalization.cpp
index da0160b24..4c277c52e 100644
--- a/mlx/backend/metal/normalization.cpp
+++ b/mlx/backend/metal/normalization.cpp
@@ -134,7 +134,7 @@ void RMSNormVJP::eval_gpu(
     d.add_temporary(g, s.index);
   }
 
-  auto axis_size = static_cast<uint32_t>(x.shape().back());
+  auto axis_size = x.shape().back();
   int n_rows = x.data_size() / axis_size;
 
   // Allocate the gradient accumulator gw and a temporary to store the
@@ -246,7 +246,7 @@ void LayerNorm::eval_gpu(
   const array& w = inputs[1];
   const array& b = inputs[2];
 
-  auto axis_size = static_cast<uint32_t>(x.shape().back());
+  auto axis_size = x.shape().back();
   int n_rows = x.data_size() / axis_size;
 
   int simd_size = 32;
@@ -344,7 +344,7 @@ void LayerNormVJP::eval_gpu(
     d.add_temporary(g, s.index);
   }
 
-  auto axis_size = static_cast<uint32_t>(x.shape().back());
+  auto axis_size = x.shape().back();
   int n_rows = x.data_size() / axis_size;
 
   // Allocate a temporary to store the gradients for w and allocate the output
diff --git a/mlx/backend/metal/primitives.cpp b/mlx/backend/metal/primitives.cpp
index 5f6376c5e..930114ecc 100644
--- a/mlx/backend/metal/primitives.cpp
+++ b/mlx/backend/metal/primitives.cpp
@@ -152,7 +152,7 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
   }
 }
 
-void Load::eval_gpu(const std::vector<array>& inputs, array& out) {
+void Load::eval_gpu(const std::vector<array>& /* inputs */, array& /* out */) {
   throw std::runtime_error("[Load::eval_gpu] Not implemented.");
 }
 
@@ -201,41 +201,45 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
 }
 
 void QRF::eval_gpu(
-    const std::vector<array>& inputs,
-    std::vector<array>& outputs) {
+    const std::vector<array>& /* inputs */,
+    std::vector<array>& /* outputs */) {
   throw std::runtime_error("[QRF::eval_gpu] Metal QR factorization NYI.");
 }
 
 void SVD::eval_gpu(
-    const std::vector<array>& inputs,
-    std::vector<array>& outputs) {
+    const std::vector<array>& /* inputs */,
+    std::vector<array>& /* outputs */) {
   throw std::runtime_error("[SVD::eval_gpu] Metal SVD NYI.");
 }
 
-void Inverse::eval_gpu(const std::vector<array>& inputs, array& output) {
+void Inverse::eval_gpu(
+    const std::vector<array>& /* inputs */,
+    array& /* output */) {
   throw std::runtime_error("[Inverse::eval_gpu] Metal inversion NYI.");
 }
 
-void Cholesky::eval_gpu(const std::vector<array>& inputs, array& out) {
+void Cholesky::eval_gpu(
+    const std::vector<array>& /* inputs */,
+    array& /* out */) {
   throw std::runtime_error(
       "[Cholesky::eval_gpu] Metal Cholesky decomposition NYI.");
 }
 
 void Eig::eval_gpu(
-    const std::vector<array>& inputs,
-    std::vector<array>& outputs) {
+    const std::vector<array>& /* inputs */,
+    std::vector<array>& /* outputs */) {
   throw std::runtime_error("[Eig::eval_gpu] Metal Eig NYI.");
 }
 
 void Eigh::eval_gpu(
-    const std::vector<array>& inputs,
-    std::vector<array>& outputs) {
+    const std::vector<array>& /* inputs */,
+    std::vector<array>& /* outputs */) {
   throw std::runtime_error("[Eigh::eval_gpu] Metal Eigh NYI.");
 }
 
 void LUF::eval_gpu(
-    const std::vector<array>& inputs,
-    std::vector<array>& outputs) {
+    const std::vector<array>& /* inputs */,
+    std::vector<array>& /* outputs */) {
   throw std::runtime_error("[LUF::eval_gpu] Metal LU factorization NYI.");
 }
diff --git a/mlx/backend/metal/reduce.cpp b/mlx/backend/metal/reduce.cpp
index 504943d82..2443bf96c 100644
--- a/mlx/backend/metal/reduce.cpp
+++ b/mlx/backend/metal/reduce.cpp
@@ -291,7 +291,7 @@ void init_reduce(
     const std::string& op_name,
     CommandEncoder& compute_encoder,
     metal::Device& d,
-    const Stream& s) {
+    const Stream& /* s */) {
   auto [_, out_type] = remap_reduce_types(out, op_name);
   const std::string func_name = "init_reduce";
   std::string kname = func_name;
@@ -397,7 +397,7 @@ void row_reduce_small(
     RowReduceArgs& args,
     CommandEncoder& compute_encoder,
     metal::Device& d,
-    const Stream& s) {
+    const Stream& /* s */) {
   // Set the kernel
   int n = get_kernel_reduce_ndim(args.reduce_ndim);
   auto [in_type, out_type] = remap_reduce_types(in, op_name);
@@ -453,7 +453,7 @@ void row_reduce_simple(
     RowReduceArgs& args,
     CommandEncoder& compute_encoder,
     metal::Device& d,
-    const Stream& s) {
+    const Stream& /* s */) {
   // Set the kernel
   auto [in_type, out_type] = remap_reduce_types(in, op_name);
   const std::string func_name = "row_reduce_simple";
@@ -493,7 +493,7 @@ void row_reduce_looped(
     RowReduceArgs& args,
     CommandEncoder& compute_encoder,
     metal::Device& d,
-    const Stream& s) {
+    const Stream& /* s */) {
   auto [in_type, out_type] = remap_reduce_types(in, op_name);
 
   // Set the kernel
@@ -570,7 +570,7 @@ void strided_reduce_small(
     ColReduceArgs& args,
     CommandEncoder& compute_encoder,
     metal::Device& d,
-    const Stream& s) {
+    const Stream& /* s */) {
   auto [in_type, out_type] = remap_reduce_types(in, op_name);
 
   // Figure out the grid dims
@@ -747,7 +747,7 @@ void strided_reduce_looped(
     ColReduceArgs& args,
     CommandEncoder& compute_encoder,
     metal::Device& d,
-    const Stream& s) {
+    const Stream& /* s */) {
   auto [in_type, out_type] = remap_reduce_types(in, op_name);
 
   // Prepare the arguments for the kernel
@@ -959,7 +959,7 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
     // Continue with reduction operation
     // Minimum of 4 bytes since we use size 4 structs for all reduce
     // and metal will complain o/w
-    size_t min_bytes = std::max(out.nbytes(), 4ul);
+    size_t min_bytes = std::max<size_t>(out.nbytes(), 4);
     out.set_data(allocator::malloc(min_bytes));
     std::string op_name;
     switch (reduce_type_) {
diff --git a/mlx/backend/metal/resident.cpp b/mlx/backend/metal/resident.cpp
index 798824c2f..cc7e6af08 100644
--- a/mlx/backend/metal/resident.cpp
+++ b/mlx/backend/metal/resident.cpp
@@ -80,7 +80,7 @@ void ResidencySet::resize(size_t size) {
     // Remove wired allocations until under capacity
    auto allocations = wired_set_->allAllocations();
     auto num_allocations = wired_set_->allocationCount();
-    for (int i = 0; i < num_allocations && current_size > size; ++i) {
+    for (size_t i = 0; i < num_allocations && current_size > size; ++i) {
      auto buf = static_cast<MTL::Buffer*>(allocations->object(i));
       wired_set_->removeAllocation(buf);
       current_size -= buf->allocatedSize();
diff --git a/mlx/backend/metal/slicing.cpp b/mlx/backend/metal/slicing.cpp
index 1e14c35c8..97087c256 100644
--- a/mlx/backend/metal/slicing.cpp
+++ b/mlx/backend/metal/slicing.cpp
@@ -33,7 +33,7 @@ void concatenate_gpu(
   auto& d = metal::device(s.device);
   auto& compute_encoder = d.get_command_encoder(s.index);
   auto concurrent_ctx = compute_encoder.start_concurrent();
-  for (int i = 0; i < inputs.size(); i++) {
+  for (int i = 0; i < std::ssize(inputs); i++) {
     array out_slice(inputs[i].shape(), out.dtype(), nullptr, {});
     size_t data_offset = strides[axis] * sizes[i];
     out_slice.copy_shared_buffer(
diff --git a/mlx/backend/metal/utils.h b/mlx/backend/metal/utils.h
index e7784e599..74d2fc244 100644
--- a/mlx/backend/metal/utils.h
+++ b/mlx/backend/metal/utils.h
@@ -29,6 +29,10 @@ inline void debug_set_stream_queue_label(MTL::CommandQueue* queue, int index) {
   std::ostringstream label;
   label << "Stream " << index;
   queue->setLabel(make_string(label));
+#else
+  // appease warnings
+  (void)queue;
+  (void)index;
 #endif
 }
 
@@ -42,6 +46,9 @@ inline void debug_set_primitive_buffer_label(
   }
   label << primitive.name();
   command_buffer->setLabel(make_string(label));
+#else
+  (void)command_buffer;
+  (void)primitive;
 #endif
 }