diff --git a/mlx/backend/metal/CMakeLists.txt b/mlx/backend/metal/CMakeLists.txt
index 8f842259c..16225e181 100644
--- a/mlx/backend/metal/CMakeLists.txt
+++ b/mlx/backend/metal/CMakeLists.txt
@@ -33,10 +33,11 @@ make_jit_source(unary_ops kernels/erf.h kernels/expm1f.h)
 make_jit_source(binary_ops)
 make_jit_source(ternary_ops)
 make_jit_source(reduce_utils kernels/atomic.h kernels/reduction/ops.h)
-make_jit_source(scatter kernels/indexing.h)
-make_jit_source(gather kernels/indexing.h)
-make_jit_source(gather_axis)
-make_jit_source(scatter_axis)
+make_jit_source(indexing/scatter kernels/indexing/indexing.h)
+make_jit_source(indexing/gather kernels/indexing/indexing.h)
+make_jit_source(indexing/gather_front kernels/indexing/indexing.h)
+make_jit_source(indexing/gather_axis)
+make_jit_source(indexing/scatter_axis)
 make_jit_source(hadamard)
 
 if(MLX_METAL_JIT)
diff --git a/mlx/backend/metal/indexing.cpp b/mlx/backend/metal/indexing.cpp
index 13ce88a62..972b458d7 100644
--- a/mlx/backend/metal/indexing.cpp
+++ b/mlx/backend/metal/indexing.cpp
@@ -52,8 +52,10 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
   auto& s = stream();
   auto& d = metal::device(s.device);
 
-  int idx_ndim = nidx ? inputs[1].ndim() : 0;
-  size_t ndim = src.ndim();
+  size_t slice_size = 1;
+  for (auto s : slice_sizes_) {
+    slice_size *= s;
+  }
 
   bool large_index = nidx && inputs[1].size() > INT32_MAX;
   bool large_src = src.size() > INT32_MAX;
@@ -61,6 +63,55 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
   bool large = large_index || large_src || large_out;
 
   std::string idx_type_name = nidx ? type_to_name(inputs[1]) : "";
+
+  if (src.flags().row_contiguous && nidx == 1 && axes_[0] == 0 &&
+      inputs[1].flags().row_contiguous && slice_size == src.strides()[0]) {
+    int work_per_thread = (slice_size > 8 && src.dtype().size() < 4) ? 2 : 1;
+    auto& indices = inputs[1];
+    std::string kernel_name = fmt::format(
+        "gather_front{0}_{1}_{2}_{3}",
+        type_to_name(out),
+        idx_type_name,
+        large ? "int64_t" : "int",
+        work_per_thread);
+    std::string lib_name = kernel_name;
+
+    auto lib = d.get_library(lib_name, [&]() {
+      std::string kernel_source = metal::utils();
+      kernel_source += metal::gather_front();
+      kernel_source += get_template_definition(
+          kernel_name,
+          "gather_front",
+          get_type_string(out.dtype()),
+          get_type_string(indices.dtype()),
+          large ? "int64_t" : "int",
+          work_per_thread);
+
+      return kernel_source;
+    });
+
+    auto& compute_encoder = d.get_command_encoder(s.index);
+    auto kernel = d.get_kernel(kernel_name, lib);
+    compute_encoder.set_compute_pipeline_state(kernel);
+
+    size_t dim_x = (slice_size + work_per_thread - 1) / work_per_thread;
+    size_t dim_y = indices.size();
+    auto group_dims = get_block_dims(dim_x, dim_y, 1);
+    MTL::Size grid_dims = MTL::Size(dim_x, dim_y, 1);
+
+    compute_encoder.set_input_array(src, 0);
+    compute_encoder.set_input_array(indices, 1);
+    compute_encoder.set_output_array(out, 2);
+    compute_encoder.set_bytes(slice_size, 3);
+    compute_encoder.set_bytes(src.shape(0), 4);
+    compute_encoder.dispatch_threads(grid_dims, group_dims);
+
+    return;
+  }
+
+  int idx_ndim = nidx ? inputs[1].ndim() : 0;
+  size_t ndim = src.ndim();
+
   std::string kernel_name = fmt::format(
       "gather{0}{1}_{2}_{3}_{4}",
       type_to_name(out),
@@ -96,11 +147,6 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
   auto kernel = d.get_kernel(kernel_name, lib);
   compute_encoder.set_compute_pipeline_state(kernel);
 
-  size_t slice_size = 1;
-  for (auto s : slice_sizes_) {
-    slice_size *= s;
-  }
-
   // Launch 3D grid of threads
   // First two dimensions for the indices, the last one for the slice
   size_t dim0 = 1;
diff --git a/mlx/backend/metal/jit/includes.h b/mlx/backend/metal/jit/includes.h
index 068a3ff08..f3b57c7f9 100644
--- a/mlx/backend/metal/jit/includes.h
+++ b/mlx/backend/metal/jit/includes.h
@@ -19,6 +19,7 @@ const char* binary_two();
 const char* copy();
 const char* fft();
 const char* gather_axis();
+const char* gather_front();
 const char* hadamard();
 const char* logsumexp();
 const char* quantized_utils();
diff --git a/mlx/backend/metal/kernels/gather.h b/mlx/backend/metal/kernels/indexing/gather.h
similarity index 96%
rename from mlx/backend/metal/kernels/gather.h
rename to mlx/backend/metal/kernels/indexing/gather.h
index 472e497c0..8b93c0167 100644
--- a/mlx/backend/metal/kernels/gather.h
+++ b/mlx/backend/metal/kernels/indexing/gather.h
@@ -2,7 +2,7 @@
 
 #pragma once
 
-#include "mlx/backend/metal/kernels/indexing.h"
+#include "mlx/backend/metal/kernels/indexing/indexing.h"
 
 template <typename T, typename IdxT, int NIDX, typename LocT>
 METAL_FUNC void gather_impl(
diff --git a/mlx/backend/metal/kernels/gather_axis.h b/mlx/backend/metal/kernels/indexing/gather_axis.h
similarity index 100%
rename from mlx/backend/metal/kernels/gather_axis.h
rename to mlx/backend/metal/kernels/indexing/gather_axis.h
diff --git a/mlx/backend/metal/kernels/indexing/gather_front.h b/mlx/backend/metal/kernels/indexing/gather_front.h
new file mode 100644
index 000000000..1389e4c62
--- /dev/null
+++ b/mlx/backend/metal/kernels/indexing/gather_front.h
@@ -0,0 +1,24 @@
+// Copyright © 2025 Apple Inc.
+
+#pragma once
+
+#include "mlx/backend/metal/kernels/indexing/indexing.h"
+
+template <typename T, typename IdxT, typename LocT, int N>
+[[kernel]] void gather_front(
+    const device T* src,
+    const device IdxT* indices,
+    device T* out,
+    const constant int64_t& stride,
+    const constant int& size,
+    uint2 index [[thread_position_in_grid]],
+    uint2 grid_dim [[threads_per_grid]]) {
+  auto idx = offset_neg_idx(indices[index.y], size);
+  LocT src_idx = static_cast<LocT>(stride) * idx;
+  LocT out_idx = static_cast<LocT>(stride) * index.y;
+
+  int s_idx = N * index.x;
+  for (int i = 0; i < N && s_idx < stride; ++i, ++s_idx) {
+    out[out_idx + s_idx] = src[src_idx + s_idx];
+  }
+}
diff --git a/mlx/backend/metal/kernels/indexing.h b/mlx/backend/metal/kernels/indexing/indexing.h
similarity index 100%
rename from mlx/backend/metal/kernels/indexing.h
rename to mlx/backend/metal/kernels/indexing/indexing.h
diff --git a/mlx/backend/metal/kernels/scatter.h b/mlx/backend/metal/kernels/indexing/scatter.h
similarity index 96%
rename from mlx/backend/metal/kernels/scatter.h
rename to mlx/backend/metal/kernels/indexing/scatter.h
index d96eca3db..f0217b336 100644
--- a/mlx/backend/metal/kernels/scatter.h
+++ b/mlx/backend/metal/kernels/indexing/scatter.h
@@ -2,7 +2,7 @@
 
 #pragma once
 
-#include "mlx/backend/metal/kernels/indexing.h"
+#include "mlx/backend/metal/kernels/indexing/indexing.h"
 
 template <
     typename T,
diff --git a/mlx/backend/metal/kernels/scatter_axis.h b/mlx/backend/metal/kernels/indexing/scatter_axis.h
similarity index 100%
rename from mlx/backend/metal/kernels/scatter_axis.h
rename to mlx/backend/metal/kernels/indexing/scatter_axis.h
diff --git a/python/mlx/nn/layers/quantized.py b/python/mlx/nn/layers/quantized.py
index 669162e68..c308e884b 100644
--- a/python/mlx/nn/layers/quantized.py
+++ b/python/mlx/nn/layers/quantized.py
@@ -98,11 +98,10 @@ class QuantizedEmbedding(Module):
         # Initialize the quantized weight
         scale = math.sqrt(1 / dims)
         weight = mx.random.normal(shape=(num_embeddings, dims), scale=scale)
-        self.weight, *scales_biases = mx.quantize(weight, group_size, bits, mode=mode)
-        if mode == "affine":
-            self.scales, self.biases = scales_biases
-        else:
-            (self.scales,) = scales_biases
+        self.weight, self.scales, *biases = mx.quantize(
+            weight, group_size, bits, mode=mode
+        )
+        self.biases = biases[0] if biases else None
         self.num_embeddings = num_embeddings
         self.dims = dims
 
@@ -155,16 +154,13 @@
         """Create a :obj:`QuantizedEmbedding` layer from an :obj:`Embedding` layer."""
         embedding_dims, dims = embedding_layer.weight.shape
         ql = cls(embedding_dims, dims, group_size, bits, mode=mode)
-        ql.weight, *scales_biases = mx.quantize(
+        ql.weight, ql.scales, *biases = mx.quantize(
             embedding_layer.weight,
             group_size,
             bits,
             mode=mode,
         )
-        if mode == "affine":
-            ql.scales, ql.biases = scales_biases
-        else:
-            (ql.scales,) = scales_biases
+        ql.biases = biases[0] if biases else None
         return ql
 
 
@@ -214,11 +210,10 @@ class QuantizedLinear(Module):
             high=scale,
             shape=(output_dims, input_dims),
         )
-        self.weight, *scales_biases = mx.quantize(weight, group_size, bits, mode=mode)
-        if mode == "affine":
-            self.scales, self.biases = scales_biases
-        else:
-            (self.scales,) = scales_biases
+        self.weight, self.scales, *biases = mx.quantize(
+            weight, group_size, bits, mode=mode
+        )
+        self.biases = biases[0] if biases else None
 
         # And bias if needed
         if bias:
@@ -261,16 +256,13 @@ class QuantizedLinear(Module):
         """Create a :obj:`QuantizedLinear` layer from a :obj:`Linear` layer."""
         output_dims, input_dims = linear_layer.weight.shape
         ql = cls(input_dims, output_dims, False, group_size, bits, mode=mode)
-        ql.weight, *scales_biases = mx.quantize(
+        ql.weight, ql.scales, *biases = mx.quantize(
             linear_layer.weight,
             group_size,
             bits,
             mode=mode,
         )
-        if mode == "affine":
-            ql.scales, ql.biases = scales_biases
-        else:
-            (ql.scales,) = scales_biases
+        ql.biases = biases[0] if biases else None
         if "bias" in linear_layer:
             ql.bias = linear_layer.bias
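
The fast path added in indexing.cpp only fires when the source is row-contiguous, there is a single row-contiguous index array, axes_[0] == 0, and the slice spans a full leading row (slice_size == src.strides()[0]). As a rough, hypothetical illustration (not part of the patch), a plain mx.take along axis 0 should satisfy those conditions, assuming it lowers to a single-index Gather with full-row slice sizes; the shapes and the timing harness below are invented for the sketch.

# Hypothetical sketch: a whole-row gather along the leading axis of a
# row-contiguous matrix, i.e. the case gather_front is written for.
import time
import mlx.core as mx

src = mx.random.normal((50_000, 4096))             # row-contiguous by construction
ids = mx.random.randint(0, src.shape[0], (8192,))  # single 1-D index array
mx.eval(src, ids)

tic = time.perf_counter()
out = mx.take(src, ids, axis=0)                    # gather full rows on axis 0
mx.eval(out)
print(f"take(axis=0): {(time.perf_counter() - tic) * 1e3:.2f} ms, shape {out.shape}")

With the host-side heuristic in the patch (work_per_thread is 2 only when slice_size > 8 and the element size is under 4 bytes), a float32 row like this one is copied one element per thread, while a float16 or int8 source would get two elements per thread.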