diff --git a/.circleci/config.yml b/.circleci/config.yml
index ee2cb8c5b..6dc7ec4df 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -251,7 +251,7 @@ jobs:
           name: Install Python package
           command: |
             source env/bin/activate
-            MACOSX_DEPLOYMENT_TARGET="" DEV_RELEASE=1 \
+            env -u MACOSX_DEPLOYMENT_TARGET DEV_RELEASE=1 \
               CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
               pip install . -v
       - run:
diff --git a/benchmarks/python/gather_mm_bench.py b/benchmarks/python/gather_mm_bench.py
new file mode 100644
index 000000000..ffeb73487
--- /dev/null
+++ b/benchmarks/python/gather_mm_bench.py
@@ -0,0 +1,74 @@
+# Copyright © 2025 Apple Inc.
+
+import mlx.core as mx
+from time_utils import time_fn
+
+N = 1024
+D = 1024
+M = 1024
+E = 32
+I = 4
+
+
+def gather_sort(x, indices):
+    N, M = indices.shape
+    indices = indices.flatten()
+    order = mx.argsort(indices)
+    inv_order = mx.argsort(order)
+    return x.flatten(0, -3)[order // M], indices[order], inv_order
+
+
+def scatter_unsort(x, inv_order, shape=None):
+    x = x[inv_order]
+    if shape is not None:
+        x = mx.unflatten(x, 0, shape)
+    return x
+
+
+def gather_mm_simulate(x, w, indices):
+    x, idx, inv_order = gather_sort(x, indices)
+    for i in range(2):
+        y = mx.concatenate([x[i] @ w[j].T for i, j in enumerate(idx.tolist())], axis=0)
+        x = y[:, None]
+    x = scatter_unsort(x, inv_order, indices.shape)
+    return x
+
+
+def time_gather_mm():
+    x = mx.random.normal((N, 1, 1, D)) / 1024**0.5
+    w1 = mx.random.normal((E, M, D)) / 1024**0.5
+    w2 = mx.random.normal((E, D, M)) / 1024**0.5
+    indices = (mx.random.uniform(shape=(N, I)) * E).astype(mx.uint32)
+    sorted_indices = mx.sort(indices.flatten()).reshape(N, I)
+    mx.eval(x, w1, w2, indices, sorted_indices)
+
+    def gather_mm(x, w1, w2, indices, sort):
+        idx = indices
+        inv_order = None
+        if sort:
+            x, idx, inv_order = gather_sort(x, indices)
+        x = mx.gather_mm(x, w1.swapaxes(-1, -2), rhs_indices=idx, sorted_indices=sort)
+        x = mx.gather_mm(x, w2.swapaxes(-1, -2), rhs_indices=idx, sorted_indices=sort)
+        if sort:
+            x = scatter_unsort(x, inv_order, indices.shape)
+        return x
+
+    time_fn(gather_mm, x, w1, w2, indices, False)
+    time_fn(gather_mm, x, w1, w2, sorted_indices, False)
+    time_fn(gather_mm, x, w1, w2, indices, True)
+
+    x = mx.random.normal((N * I, D)) / 1024**0.5
+    w1 = mx.random.normal((M, D)) / 1024**0.5
+    w2 = mx.random.normal((D, M)) / 1024**0.5
+    mx.eval(x, w1, w2)
+
+    def equivalent_matmul(x, w1, w2):
+        x = x @ w1.T
+        x = x @ w2.T
+        return x
+
+    time_fn(equivalent_matmul, x, w1, w2)
+
+
+if __name__ == "__main__":
+    time_gather_mm()
diff --git a/benchmarks/python/gather_qmm_bench.py b/benchmarks/python/gather_qmm_bench.py
new file mode 100644
index 000000000..17c06d57d
--- /dev/null
+++ b/benchmarks/python/gather_qmm_bench.py
@@ -0,0 +1,84 @@
+# Copyright © 2025 Apple Inc.
+
+import mlx.core as mx
+from time_utils import time_fn
+
+N = 1024
+D = 1024
+M = 1024
+E = 32
+I = 4
+
+
+def gather_sort(x, indices):
+    N, M = indices.shape
+    indices = indices.flatten()
+    order = mx.argsort(indices)
+    inv_order = mx.argsort(order)
+    return x.flatten(0, -3)[order // M], indices[order], inv_order
+
+
+def scatter_unsort(x, inv_order, shape=None):
+    x = x[inv_order]
+    if shape is not None:
+        x = mx.unflatten(x, 0, shape)
+    return x
+
+
+def gather_mm_simulate(x, w, indices):
+    x, idx, inv_order = gather_sort(x, indices)
+    for i in range(2):
+        y = mx.concatenate(
+            [
+                mx.quantized_matmul(x[i], w[0][j], w[1][j], w[2][j], transpose=True)
+                for i, j in enumerate(idx.tolist())
+            ],
+            axis=0,
+        )
+        x = y[:, None]
+    x = scatter_unsort(x, inv_order, indices.shape)
+    return x
+
+
+def time_gather_qmm():
+    x = mx.random.normal((N, 1, 1, D)) / 1024**0.5
+    w1 = mx.random.normal((E, M, D)) / 1024**0.5
+    w2 = mx.random.normal((E, D, M)) / 1024**0.5
+    w1 = mx.quantize(w1)
+    w2 = mx.quantize(w2)
+    indices = (mx.random.uniform(shape=(N, I)) * E).astype(mx.uint32)
+    sorted_indices = mx.sort(indices.flatten()).reshape(N, I)
+    mx.eval(x, w1, w2, indices, sorted_indices)
+
+    def gather_mm(x, w1, w2, indices, sort):
+        idx = indices
+        inv_order = None
+        if sort:
+            x, idx, inv_order = gather_sort(x, indices)
+        x = mx.gather_qmm(x, *w1, transpose=True, rhs_indices=idx, sorted_indices=sort)
+        x = mx.gather_qmm(x, *w2, transpose=True, rhs_indices=idx, sorted_indices=sort)
+        if sort:
+            x = scatter_unsort(x, inv_order, indices.shape)
+        return x
+
+    time_fn(gather_mm, x, w1, w2, indices, False)
+    time_fn(gather_mm, x, w1, w2, sorted_indices, False)
+    time_fn(gather_mm, x, w1, w2, indices, True)
+
+    x = mx.random.normal((N * I, D)) / 1024**0.5
+    w1 = mx.random.normal((M, D)) / 1024**0.5
+    w2 = mx.random.normal((D, M)) / 1024**0.5
+    w1 = mx.quantize(w1)
+    w2 = mx.quantize(w2)
+    mx.eval(x, w1, w2)
+
+    def equivalent_matmul(x, w1, w2):
+        x = mx.quantized_matmul(x, *w1, transpose=True)
+        x = mx.quantized_matmul(x, *w2, transpose=True)
+        return x
+
+    time_fn(equivalent_matmul, x, w1, w2)
+
+
+if __name__ == "__main__":
+    time_gather_qmm()
diff --git a/docs/src/python/array.rst b/docs/src/python/array.rst
index 532bb45c9..7e1c3339d 100644
--- a/docs/src/python/array.rst
+++ b/docs/src/python/array.rst
@@ -38,6 +38,7 @@ Array
    array.log10
    array.log1p
    array.log2
+   array.logcumsumexp
    array.logsumexp
    array.max
    array.mean
diff --git a/docs/src/python/ops.rst b/docs/src/python/ops.rst
index 66c5764ed..55fc1f534 100644
--- a/docs/src/python/ops.rst
+++ b/docs/src/python/ops.rst
@@ -103,6 +103,7 @@ Operations
    log10
    log1p
    logaddexp
+   logcumsumexp
    logical_not
    logical_and
    logical_or
diff --git a/mlx/CMakeLists.txt b/mlx/CMakeLists.txt
index 76fe389d4..abf46a7d5 100644
--- a/mlx/CMakeLists.txt
+++ b/mlx/CMakeLists.txt
@@ -5,6 +5,7 @@ target_sources(
           ${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp
           ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
           ${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp
+          ${CMAKE_CURRENT_SOURCE_DIR}/dtype_utils.cpp
           ${CMAKE_CURRENT_SOURCE_DIR}/export.cpp
           ${CMAKE_CURRENT_SOURCE_DIR}/einsum.cpp
           ${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp
diff --git a/mlx/array.h b/mlx/array.h
index d690dcd97..66a4702a6 100644
--- a/mlx/array.h
+++ b/mlx/array.h
@@ -339,11 +339,11 @@ class array {
     return allocator::allocator().size(buffer());
   }
 
-  // Return a copy of the shared pointer
-  // to the array::Data struct
-  std::shared_ptr<Data> data_shared_ptr() const {
+  // Return the shared pointer to the array::Data struct
+  const std::shared_ptr<Data>& data_shared_ptr() const {
     return array_desc_->data;
   }
 
+  // Return a raw pointer to the arrays data
   template <typename T>
   T* data() {
diff --git a/mlx/backend/common/CMakeLists.txt b/mlx/backend/common/CMakeLists.txt
index 82e6eef84..6c4e25067 100644
--- a/mlx/backend/common/CMakeLists.txt
+++ b/mlx/backend/common/CMakeLists.txt
@@ -1,6 +1,7 @@
 target_sources(
   mlx
-  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
+  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/broadcasting.cpp
+          ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
           ${CMAKE_CURRENT_SOURCE_DIR}/common.cpp
           ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
           ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
diff --git a/mlx/backend/common/broadcasting.cpp b/mlx/backend/common/broadcasting.cpp
new file mode 100644
index 000000000..49bc75b8f
--- /dev/null
+++ b/mlx/backend/common/broadcasting.cpp
@@ -0,0 +1,24 @@
+// Copyright © 2024 Apple Inc.
+
+#include "mlx/backend/common/utils.h"
+
+namespace mlx::core {
+
+void broadcast(const array& in, array& out) {
+  if (out.size() == 0) {
+    out.set_data(nullptr);
+    return;
+  }
+  Strides strides(out.ndim(), 0);
+  int diff = out.ndim() - in.ndim();
+  for (int i = in.ndim() - 1; i >= 0; --i) {
+    strides[i + diff] = (in.shape()[i] == 1) ? 0 : in.strides()[i];
+  }
+  auto flags = in.flags();
+  if (out.size() > in.size()) {
+    flags.row_contiguous = flags.col_contiguous = false;
+  }
+  out.copy_shared_buffer(in, strides, flags, in.data_size());
+}
+
+} // namespace mlx::core
diff --git a/mlx/backend/common/broadcasting.h b/mlx/backend/common/broadcasting.h
new file mode 100644
index 000000000..29651e909
--- /dev/null
+++ b/mlx/backend/common/broadcasting.h
@@ -0,0 +1,11 @@
+// Copyright © 2024 Apple Inc.
+
+#pragma once
+
+#include "mlx/array.h"
+
+namespace mlx::core {
+
+void broadcast(const array& in, array& out);
+
+} // namespace mlx::core
diff --git a/mlx/backend/common/common.cpp b/mlx/backend/common/common.cpp
index 57813e062..2cda88a31 100644
--- a/mlx/backend/common/common.cpp
+++ b/mlx/backend/common/common.cpp
@@ -1,6 +1,7 @@
 // Copyright © 2024 Apple Inc.
 
 #include <cassert>
 
+#include "mlx/backend/common/broadcasting.h"
 #include "mlx/backend/common/utils.h"
 #include "mlx/primitives.h"
@@ -42,23 +43,6 @@ void AsStrided::eval(const std::vector<array>& inputs, array& out) {
   return out.copy_shared_buffer(in, strides_, flags, data_size, offset_);
 }
 
-void broadcast(const array& in, array& out) {
-  if (out.size() == 0) {
-    out.set_data(nullptr);
-    return;
-  }
-  Strides strides(out.ndim(), 0);
-  int diff = out.ndim() - in.ndim();
-  for (int i = in.ndim() - 1; i >= 0; --i) {
-    strides[i + diff] = (in.shape()[i] == 1) ? 0 : in.strides()[i];
-  }
-  auto flags = in.flags();
-  if (out.size() > in.size()) {
-    flags.row_contiguous = flags.col_contiguous = false;
-  }
-  out.copy_shared_buffer(in, strides, flags, in.data_size());
-}
-
 void Broadcast::eval(const std::vector<array>& inputs, array& out) {
   broadcast(inputs[0], out);
 }
diff --git a/mlx/backend/cpu/scan.cpp b/mlx/backend/cpu/scan.cpp
index 1a44ebd39..199dbab35 100644
--- a/mlx/backend/cpu/scan.cpp
+++ b/mlx/backend/cpu/scan.cpp
@@ -3,6 +3,7 @@
 #include <cassert>
 
 #include "mlx/backend/common/utils.h"
+#include "mlx/backend/cpu/binary_ops.h"
 #include "mlx/backend/cpu/copy.h"
 #include "mlx/backend/cpu/encoder.h"
 #include "mlx/backend/cpu/simd/simd.h"
@@ -226,6 +227,16 @@ void scan_dispatch(
       scan_op<T, U>(in, out, axis, reverse, inclusive, op, init);
       break;
    }
+    case Scan::LogAddExp: {
+      auto op = [](U a, T b) {
+        return detail::LogAddExp{}(a, static_cast<U>(b));
+      };
+      auto init = (issubdtype(in.dtype(), floating)) ?
static_cast(-std::numeric_limits::infinity()) + : std::numeric_limits::min(); + scan_op(in, out, axis, reverse, inclusive, op, init); + break; + } } } diff --git a/mlx/backend/metal/CMakeLists.txt b/mlx/backend/metal/CMakeLists.txt index 7985396c4..332c560f8 100644 --- a/mlx/backend/metal/CMakeLists.txt +++ b/mlx/backend/metal/CMakeLists.txt @@ -61,6 +61,7 @@ if(MLX_METAL_JIT) kernels/steel/gemm/transforms.h) make_jit_source(steel/gemm/kernels/steel_gemm_fused) make_jit_source(steel/gemm/kernels/steel_gemm_masked kernels/steel/defines.h) + make_jit_source(steel/gemm/kernels/steel_gemm_gather) make_jit_source(steel/gemm/kernels/steel_gemm_splitk) make_jit_source( steel/conv/conv diff --git a/mlx/backend/metal/jit/includes.h b/mlx/backend/metal/jit/includes.h index 921ce50ce..27ae22d05 100644 --- a/mlx/backend/metal/jit/includes.h +++ b/mlx/backend/metal/jit/includes.h @@ -33,6 +33,7 @@ const char* gemm(); const char* steel_gemm_fused(); const char* steel_gemm_masked(); const char* steel_gemm_splitk(); +const char* steel_gemm_gather(); const char* conv(); const char* steel_conv(); const char* steel_conv_general(); diff --git a/mlx/backend/metal/jit_kernels.cpp b/mlx/backend/metal/jit_kernels.cpp index 204bb14e7..5206c9b54 100644 --- a/mlx/backend/metal/jit_kernels.cpp +++ b/mlx/backend/metal/jit_kernels.cpp @@ -584,6 +584,44 @@ MTL::ComputePipelineState* get_steel_gemm_masked_kernel( return d.get_kernel(kernel_name, lib); } +MTL::ComputePipelineState* get_steel_gemm_gather_kernel( + metal::Device& d, + const std::string& kernel_name, + const std::string& hash_name, + const metal::MTLFCList& func_consts, + const array& out, + bool transpose_a, + bool transpose_b, + int bm, + int bn, + int bk, + int wm, + int wn, + bool rhs) { + const auto& lib_name = kernel_name; + auto lib = d.get_library(lib_name, [&]() { + std::string kernel_source; + concatenate( + kernel_source, + metal::utils(), + metal::gemm(), + metal::steel_gemm_gather(), + get_template_definition( + lib_name, + rhs ? 
"gather_mm_rhs" : "gather_mm", + get_type_string(out.dtype()), + bm, + bn, + bk, + wm, + wn, + transpose_a, + transpose_b)); + return kernel_source; + }); + return d.get_kernel(kernel_name, lib, hash_name, func_consts); +} + MTL::ComputePipelineState* get_gemv_masked_kernel( metal::Device& d, const std::string& kernel_name, @@ -714,4 +752,43 @@ MTL::ComputePipelineState* get_quantized_kernel( return d.get_kernel(kernel_name, lib); } +MTL::ComputePipelineState* get_gather_qmm_kernel( + metal::Device& d, + const std::string& kernel_name, + const std::string& hash_name, + const metal::MTLFCList& func_consts, + const array& x, + int group_size, + int bits, + int bm, + int bn, + int bk, + int wm, + int wn, + bool transpose) { + const auto& lib_name = kernel_name; + auto lib = d.get_library(lib_name, [&]() { + std::string kernel_source; + concatenate( + kernel_source, + metal::utils(), + metal::gemm(), + metal::quantized(), + get_template_definition( + lib_name, + "gather_qmm_rhs", + get_type_string(x.dtype()), + group_size, + bits, + bm, + bn, + bk, + wm, + wn, + transpose)); + return kernel_source; + }); + return d.get_kernel(kernel_name, lib, hash_name, func_consts); +} + } // namespace mlx::core diff --git a/mlx/backend/metal/kernels.h b/mlx/backend/metal/kernels.h index 1638a4496..6d8864385 100644 --- a/mlx/backend/metal/kernels.h +++ b/mlx/backend/metal/kernels.h @@ -160,6 +160,21 @@ MTL::ComputePipelineState* get_steel_gemm_masked_kernel( bool mn_aligned, bool k_aligned); +MTL::ComputePipelineState* get_steel_gemm_gather_kernel( + metal::Device& d, + const std::string& kernel_name, + const std::string& hash_name, + const metal::MTLFCList& func_consts, + const array& out, + bool transpose_a, + bool transpose_b, + int bm, + int bn, + int bk, + int wm, + int wn, + bool rhs); + MTL::ComputePipelineState* get_steel_conv_kernel( metal::Device& d, const std::string& kernel_name, @@ -209,6 +224,21 @@ MTL::ComputePipelineState* get_quantized_kernel( const std::string& kernel_name, const std::string& template_def); +MTL::ComputePipelineState* get_gather_qmm_kernel( + metal::Device& d, + const std::string& kernel_name, + const std::string& hash_name, + const metal::MTLFCList& func_consts, + const array& x, + int group_size, + int bits, + int bm, + int bn, + int bk, + int wm, + int wn, + bool transpose); + // Create a GPU kernel template definition for JIT compilation template std::string diff --git a/mlx/backend/metal/kernels/CMakeLists.txt b/mlx/backend/metal/kernels/CMakeLists.txt index 309a840f8..3ee88ca46 100644 --- a/mlx/backend/metal/kernels/CMakeLists.txt +++ b/mlx/backend/metal/kernels/CMakeLists.txt @@ -69,6 +69,7 @@ set(STEEL_HEADERS steel/gemm/loader.h steel/gemm/transforms.h steel/gemm/kernels/steel_gemm_fused.h + steel/gemm/kernels/steel_gemm_gather.h steel/gemm/kernels/steel_gemm_masked.h steel/gemm/kernels/steel_gemm_splitk.h steel/utils/type_traits.h @@ -116,6 +117,7 @@ if(NOT MLX_METAL_JIT) build_kernel(steel/conv/kernels/steel_conv ${STEEL_HEADERS}) build_kernel(steel/conv/kernels/steel_conv_general ${STEEL_HEADERS}) build_kernel(steel/gemm/kernels/steel_gemm_fused ${STEEL_HEADERS}) + build_kernel(steel/gemm/kernels/steel_gemm_gather ${STEEL_HEADERS}) build_kernel(steel/gemm/kernels/steel_gemm_masked ${STEEL_HEADERS}) build_kernel(steel/gemm/kernels/steel_gemm_splitk ${STEEL_HEADERS}) build_kernel(gemv_masked steel/utils.h) diff --git a/mlx/backend/metal/kernels/complex.h b/mlx/backend/metal/kernels/complex.h index fe8ec5c0f..c88002cb3 100644 --- 
a/mlx/backend/metal/kernels/complex.h +++ b/mlx/backend/metal/kernels/complex.h @@ -104,10 +104,22 @@ constexpr bool operator==(complex64_t a, complex64_t b) { constexpr complex64_t operator+(complex64_t a, complex64_t b) { return {a.real + b.real, a.imag + b.imag}; } +constexpr complex64_t operator+(float a, complex64_t b) { + return {a + b.real, b.imag}; +} +constexpr complex64_t operator+(complex64_t a, float b) { + return {a.real + b, a.imag}; +} constexpr complex64_t operator-(complex64_t a, complex64_t b) { return {a.real - b.real, a.imag - b.imag}; } +constexpr complex64_t operator-(float a, complex64_t b) { + return {a - b.real, -b.imag}; +} +constexpr complex64_t operator-(complex64_t a, float b) { + return {a.real - b, a.imag}; +} constexpr complex64_t operator*(complex64_t a, complex64_t b) { return {a.real * b.real - a.imag * b.imag, a.real * b.imag + a.imag * b.real}; @@ -120,6 +132,13 @@ constexpr complex64_t operator/(complex64_t a, complex64_t b) { return {x / denom, y / denom}; } +constexpr complex64_t operator/(float a, complex64_t b) { + auto denom = b.real * b.real + b.imag * b.imag; + auto x = a * b.real; + auto y = -a * b.imag; + return {x / denom, y / denom}; +} + constexpr complex64_t operator%(complex64_t a, complex64_t b) { auto real = a.real - (b.real * static_cast(a.real / b.real)); auto imag = a.imag - (b.imag * static_cast(a.imag / b.imag)); diff --git a/mlx/backend/metal/kernels/quantized.h b/mlx/backend/metal/kernels/quantized.h index af9d7860e..b2b0d8d8f 100644 --- a/mlx/backend/metal/kernels/quantized.h +++ b/mlx/backend/metal/kernels/quantized.h @@ -3,6 +3,10 @@ #include #include +constant bool align_M [[function_constant(200)]]; +constant bool align_N [[function_constant(201)]]; +constant bool align_K [[function_constant(202)]]; + using namespace metal; #define MLX_MTL_CONST static constant constexpr const @@ -1686,26 +1690,26 @@ template < } template -[[kernel]] void bs_qmv_fast( +[[kernel]] void gather_qmv_fast( const device uint32_t* w [[buffer(0)]], const device T* scales [[buffer(1)]], const device T* biases [[buffer(2)]], const device T* x [[buffer(3)]], - device T* y [[buffer(4)]], - const constant int& in_vec_size [[buffer(5)]], - const constant int& out_vec_size [[buffer(6)]], - const constant int& x_batch_ndims [[buffer(7)]], - const constant int* x_shape [[buffer(8)]], - const constant int64_t* x_strides [[buffer(9)]], - const constant int& w_batch_ndims [[buffer(10)]], - const constant int* w_shape [[buffer(11)]], - const constant int64_t* w_strides [[buffer(12)]], - const constant int64_t* s_strides [[buffer(13)]], - const constant int64_t* b_strides [[buffer(14)]], - const constant int& batch_ndims [[buffer(15)]], - const constant int* batch_shape [[buffer(16)]], - const device uint32_t* lhs_indices [[buffer(17)]], - const device uint32_t* rhs_indices [[buffer(18)]], + const device uint32_t* lhs_indices [[buffer(4)]], + const device uint32_t* rhs_indices [[buffer(5)]], + device T* y [[buffer(6)]], + const constant int& in_vec_size [[buffer(7)]], + const constant int& out_vec_size [[buffer(8)]], + const constant int& x_batch_ndims [[buffer(9)]], + const constant int* x_shape [[buffer(10)]], + const constant int64_t* x_strides [[buffer(11)]], + const constant int& w_batch_ndims [[buffer(12)]], + const constant int* w_shape [[buffer(13)]], + const constant int64_t* w_strides [[buffer(14)]], + const constant int64_t* s_strides [[buffer(15)]], + const constant int64_t* b_strides [[buffer(16)]], + const constant int& batch_ndims [[buffer(17)]], + 
const constant int* batch_shape [[buffer(18)]], const constant int64_t* lhs_strides [[buffer(19)]], const constant int64_t* rhs_strides [[buffer(20)]], uint3 tid [[threadgroup_position_in_grid]], @@ -1748,26 +1752,26 @@ template } template -[[kernel]] void bs_qmv( +[[kernel]] void gather_qmv( const device uint32_t* w [[buffer(0)]], const device T* scales [[buffer(1)]], const device T* biases [[buffer(2)]], const device T* x [[buffer(3)]], - device T* y [[buffer(4)]], - const constant int& in_vec_size [[buffer(5)]], - const constant int& out_vec_size [[buffer(6)]], - const constant int& x_batch_ndims [[buffer(7)]], - const constant int* x_shape [[buffer(8)]], - const constant int64_t* x_strides [[buffer(9)]], - const constant int& w_batch_ndims [[buffer(10)]], - const constant int* w_shape [[buffer(11)]], - const constant int64_t* w_strides [[buffer(12)]], - const constant int64_t* s_strides [[buffer(13)]], - const constant int64_t* b_strides [[buffer(14)]], - const constant int& batch_ndims [[buffer(15)]], - const constant int* batch_shape [[buffer(16)]], - const device uint32_t* lhs_indices [[buffer(17)]], - const device uint32_t* rhs_indices [[buffer(18)]], + const device uint32_t* lhs_indices [[buffer(4)]], + const device uint32_t* rhs_indices [[buffer(5)]], + device T* y [[buffer(6)]], + const constant int& in_vec_size [[buffer(7)]], + const constant int& out_vec_size [[buffer(8)]], + const constant int& x_batch_ndims [[buffer(9)]], + const constant int* x_shape [[buffer(10)]], + const constant int64_t* x_strides [[buffer(11)]], + const constant int& w_batch_ndims [[buffer(12)]], + const constant int* w_shape [[buffer(13)]], + const constant int64_t* w_strides [[buffer(14)]], + const constant int64_t* s_strides [[buffer(15)]], + const constant int64_t* b_strides [[buffer(16)]], + const constant int& batch_ndims [[buffer(17)]], + const constant int* batch_shape [[buffer(18)]], const constant int64_t* lhs_strides [[buffer(19)]], const constant int64_t* rhs_strides [[buffer(20)]], uint3 tid [[threadgroup_position_in_grid]], @@ -1810,26 +1814,26 @@ template } template -[[kernel]] void bs_qvm( +[[kernel]] void gather_qvm( const device uint32_t* w [[buffer(0)]], const device T* scales [[buffer(1)]], const device T* biases [[buffer(2)]], const device T* x [[buffer(3)]], - device T* y [[buffer(4)]], - const constant int& in_vec_size [[buffer(5)]], - const constant int& out_vec_size [[buffer(6)]], - const constant int& x_batch_ndims [[buffer(7)]], - const constant int* x_shape [[buffer(8)]], - const constant int64_t* x_strides [[buffer(9)]], - const constant int& w_batch_ndims [[buffer(10)]], - const constant int* w_shape [[buffer(11)]], - const constant int64_t* w_strides [[buffer(12)]], - const constant int64_t* s_strides [[buffer(13)]], - const constant int64_t* b_strides [[buffer(14)]], - const constant int& batch_ndims [[buffer(15)]], - const constant int* batch_shape [[buffer(16)]], - const device uint32_t* lhs_indices [[buffer(17)]], - const device uint32_t* rhs_indices [[buffer(18)]], + const device uint32_t* lhs_indices [[buffer(4)]], + const device uint32_t* rhs_indices [[buffer(5)]], + device T* y [[buffer(6)]], + const constant int& in_vec_size [[buffer(7)]], + const constant int& out_vec_size [[buffer(8)]], + const constant int& x_batch_ndims [[buffer(9)]], + const constant int* x_shape [[buffer(10)]], + const constant int64_t* x_strides [[buffer(11)]], + const constant int& w_batch_ndims [[buffer(12)]], + const constant int* w_shape [[buffer(13)]], + const constant int64_t* 
w_strides [[buffer(14)]], + const constant int64_t* s_strides [[buffer(15)]], + const constant int64_t* b_strides [[buffer(16)]], + const constant int& batch_ndims [[buffer(17)]], + const constant int* batch_shape [[buffer(18)]], const constant int64_t* lhs_strides [[buffer(19)]], const constant int64_t* rhs_strides [[buffer(20)]], uint3 tid [[threadgroup_position_in_grid]], @@ -1879,27 +1883,27 @@ template < const int BM = 32, const int BK = 32, const int BN = 32> -[[kernel]] void bs_qmm_t( +[[kernel]] void gather_qmm_t( const device uint32_t* w [[buffer(0)]], const device T* scales [[buffer(1)]], const device T* biases [[buffer(2)]], const device T* x [[buffer(3)]], - device T* y [[buffer(4)]], - const constant int& K [[buffer(5)]], - const constant int& N [[buffer(6)]], - const constant int& M [[buffer(7)]], - const constant int& x_batch_ndims [[buffer(8)]], - const constant int* x_shape [[buffer(9)]], - const constant int64_t* x_strides [[buffer(10)]], - const constant int& w_batch_ndims [[buffer(11)]], - const constant int* w_shape [[buffer(12)]], - const constant int64_t* w_strides [[buffer(13)]], - const constant int64_t* s_strides [[buffer(14)]], - const constant int64_t* b_strides [[buffer(15)]], - const constant int& batch_ndims [[buffer(16)]], - const constant int* batch_shape [[buffer(17)]], - const device uint32_t* lhs_indices [[buffer(18)]], - const device uint32_t* rhs_indices [[buffer(19)]], + const device uint32_t* lhs_indices [[buffer(4)]], + const device uint32_t* rhs_indices [[buffer(5)]], + device T* y [[buffer(6)]], + const constant int& K [[buffer(7)]], + const constant int& N [[buffer(8)]], + const constant int& M [[buffer(9)]], + const constant int& x_batch_ndims [[buffer(10)]], + const constant int* x_shape [[buffer(11)]], + const constant int64_t* x_strides [[buffer(12)]], + const constant int& w_batch_ndims [[buffer(13)]], + const constant int* w_shape [[buffer(14)]], + const constant int64_t* w_strides [[buffer(15)]], + const constant int64_t* s_strides [[buffer(16)]], + const constant int64_t* b_strides [[buffer(17)]], + const constant int& batch_ndims [[buffer(18)]], + const constant int* batch_shape [[buffer(19)]], const constant int64_t* lhs_strides [[buffer(20)]], const constant int64_t* rhs_strides [[buffer(21)]], uint3 tid [[threadgroup_position_in_grid]], @@ -1946,27 +1950,27 @@ template < const int BM = 32, const int BK = 32, const int BN = 32> -[[kernel]] void bs_qmm_n( +[[kernel]] void gather_qmm_n( const device uint32_t* w [[buffer(0)]], const device T* scales [[buffer(1)]], const device T* biases [[buffer(2)]], const device T* x [[buffer(3)]], - device T* y [[buffer(4)]], - const constant int& K [[buffer(5)]], - const constant int& N [[buffer(6)]], - const constant int& M [[buffer(7)]], - const constant int& x_batch_ndims [[buffer(8)]], - const constant int* x_shape [[buffer(9)]], - const constant int64_t* x_strides [[buffer(10)]], - const constant int& w_batch_ndims [[buffer(11)]], - const constant int* w_shape [[buffer(12)]], - const constant int64_t* w_strides [[buffer(13)]], - const constant int64_t* s_strides [[buffer(14)]], - const constant int64_t* b_strides [[buffer(15)]], - const constant int& batch_ndims [[buffer(16)]], - const constant int* batch_shape [[buffer(17)]], - const device uint32_t* lhs_indices [[buffer(18)]], - const device uint32_t* rhs_indices [[buffer(19)]], + const device uint32_t* lhs_indices [[buffer(4)]], + const device uint32_t* rhs_indices [[buffer(5)]], + device T* y [[buffer(6)]], + const constant int& K 
[[buffer(7)]], + const constant int& N [[buffer(8)]], + const constant int& M [[buffer(9)]], + const constant int& x_batch_ndims [[buffer(10)]], + const constant int* x_shape [[buffer(11)]], + const constant int64_t* x_strides [[buffer(12)]], + const constant int& w_batch_ndims [[buffer(13)]], + const constant int* w_shape [[buffer(14)]], + const constant int64_t* w_strides [[buffer(15)]], + const constant int64_t* s_strides [[buffer(16)]], + const constant int64_t* b_strides [[buffer(17)]], + const constant int& batch_ndims [[buffer(18)]], + const constant int* batch_shape [[buffer(19)]], const constant int64_t* lhs_strides [[buffer(20)]], const constant int64_t* rhs_strides [[buffer(21)]], uint3 tid [[threadgroup_position_in_grid]], @@ -2007,6 +2011,289 @@ template < w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid); } +template +METAL_FUNC void gemm_loop_aligned( + threadgroup T* As, + threadgroup T* Bs, + thread mma_t& mma_op, + thread loader_a_t& loader_a, + thread loader_b_t& loader_b, + const int k_iterations) { + for (int k = 0; k < k_iterations; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Load elements into threadgroup memory + loader_a.load_unsafe(); + loader_b.load_unsafe(); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } +} + +template < + bool rows_aligned, + bool cols_aligned, + bool transpose, + typename T, + typename mma_t, + typename loader_a_t, + typename loader_b_t> +METAL_FUNC void gemm_loop_unaligned( + threadgroup T* As, + threadgroup T* Bs, + thread mma_t& mma_op, + thread loader_a_t& loader_a, + thread loader_b_t& loader_b, + const int k_iterations, + const short tgp_bm, + const short tgp_bn, + const short tgp_bk) { + for (int k = 0; k < k_iterations; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Load elements into threadgroup memory + if (rows_aligned) { + loader_a.load_unsafe(); + } else { + loader_a.load_safe(short2(tgp_bk, tgp_bm)); + } + if (cols_aligned) { + loader_b.load_unsafe(); + } else { + loader_b.load_safe( + transpose ? short2(tgp_bk, tgp_bn) : short2(tgp_bn, tgp_bk)); + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } +} + +template +METAL_FUNC void gemm_loop_finalize( + threadgroup T* As, + threadgroup T* Bs, + thread mma_t& mma_op, + thread loader_a_t& loader_a, + thread loader_b_t& loader_b, + const short2 tile_a, + const short2 tile_b) { + loader_a.load_safe(tile_a); + loader_b.load_safe(tile_b); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(As, Bs); +} + +template < + typename T, + int group_size, + int bits, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose> +[[kernel]] void gather_qmm_rhs( + const device T* x [[buffer(0)]], + const device uint32_t* w [[buffer(1)]], + const device T* scales [[buffer(2)]], + const device T* biases [[buffer(3)]], + const device uint32_t* indices [[buffer(4)]], + device T* y [[buffer(5)]], + const constant int& M [[buffer(6)]], + const constant int& N [[buffer(7)]], + const constant int& K [[buffer(8)]], + uint3 tid [[threadgroup_position_in_grid]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]]) { + constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 
4 : 8 / bits; + constexpr int BK_padded = (BK + 16 / sizeof(T)); + constexpr int BN_padded = (BN + 16 / sizeof(T)); + constexpr int power_of_2_bits = (bits & (bits - 1)) == 0; + constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3; + + using mma_t = mlx::steel::BlockMMA< + T, + T, + BM, + BN, + BK, + WM, + WN, + false, + transpose, + BK_padded, + transpose ? BK_padded : BN_padded>; + using loader_x_t = + mlx::steel::BlockLoader; + using loader_w_t = QuantizedBlockLoader< + T, + transpose ? BN : BK, + transpose ? BK : BN, + transpose ? BK_padded : BN_padded, + transpose, + WM * WN * SIMD_SIZE, + group_size, + bits>; + + threadgroup T Xs[BM * BK_padded]; + threadgroup T Ws[transpose ? BN * BK_padded : BK * BN_padded]; + + // Compute the block + const int K_w = K * bytes_per_pack / pack_factor; + const int K_g = K / group_size; + const int N_w = N * bytes_per_pack / pack_factor; + const int N_g = N / group_size; + const int K_it = K / BK; + const size_t stride_w = transpose ? N * K_w : K * N_w; + const size_t stride_s = transpose ? N * K_g : K * N_g; + const int y_row = tid.y * BM; + const int y_col = tid.x * BN; + const size_t y_row_long = size_t(y_row); + const size_t y_col_long = size_t(y_col); + + // Prepare threadgroup bounds + const short tgp_bm = align_M ? BM : short(min(BM, M - y_row)); + const short tgp_bn = align_N ? BN : short(min(BN, N - y_col)); + + // Calculate the final tiles in the case that K is not aligned + const int k_remain = K - K_it * BK; + const short2 tile_x = short2(k_remain, tgp_bm); + const short2 tile_w = + transpose ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain); + + // Move x and output to the correct block + auto wl = (const device uint8_t*)w; + x += y_row_long * K; + y += y_row_long * N + y_col_long; + wl += transpose ? y_col_long * K_w : y_col * bytes_per_pack / pack_factor; + scales += transpose ? y_col_long * K_g : y_col / group_size; + biases += transpose ? y_col_long * K_g : y_col / group_size; + + // Do as many matmuls as necessary + uint32_t index; + short offset; + uint32_t index_next = indices[y_row]; + short offset_next = 0; + int n = 0; + while (n < tgp_bm) { + n++; + offset = offset_next; + index = index_next; + offset_next = tgp_bm; + for (; n < tgp_bm; n++) { + if (indices[y_row + n] != index) { + offset_next = n; + index_next = indices[y_row + n]; + break; + } + } + threadgroup_barrier(mem_flags::mem_none); + + // Prepare threadgroup mma operation + thread mma_t mma_op(simd_group_id, simd_lane_id); + + // Prepare threadgroup loading operations + thread loader_x_t loader_x(x, K, Xs, simd_group_id, simd_lane_id); + thread loader_w_t loader_w( + wl + index * stride_w, + scales + index * stride_s, + biases + index * stride_s, + transpose ? 
K : N, + Ws, + simd_group_id, + simd_lane_id); + + // Matrices are all aligned check nothing + if (align_M && align_N) { + gemm_loop_aligned(Xs, Ws, mma_op, loader_x, loader_w, K_it); + if (!align_K) { + threadgroup_barrier(mem_flags::mem_threadgroup); + gemm_loop_finalize(Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w); + } + + // Store results to device memory + if (offset_next - offset == BM) { + mma_op.store_result(y, N); + } else { + mma_op.store_result_slice( + y, N, short2(0, offset), short2(BN, offset_next)); + } + } else { + // Tile aligned so check outside of the hot loop + if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) { + gemm_loop_aligned(Xs, Ws, mma_op, loader_x, loader_w, K_it); + if (!align_K) { + threadgroup_barrier(mem_flags::mem_threadgroup); + gemm_loop_finalize( + Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w); + } + + // Store results to device memory + if (offset_next - offset == BM) { + mma_op.store_result(y, N); + } else { + mma_op.store_result_slice( + y, N, short2(0, offset), short2(BN, offset_next)); + } + } + + // Tile partially aligned check rows + else if (align_N || tgp_bn == BN) { + gemm_loop_unaligned( + Xs, Ws, mma_op, loader_x, loader_w, K_it, tgp_bm, tgp_bn, BK); + if (!align_K) { + threadgroup_barrier(mem_flags::mem_threadgroup); + gemm_loop_finalize( + Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w); + } + mma_op.store_result_slice( + y, N, short2(0, offset), short2(BN, offset_next)); + } + + // Tile partially aligned check cols + else if (align_M || tgp_bm == BM) { + gemm_loop_unaligned( + Xs, Ws, mma_op, loader_x, loader_w, K_it, tgp_bm, tgp_bn, BK); + if (!align_K) { + threadgroup_barrier(mem_flags::mem_threadgroup); + gemm_loop_finalize( + Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w); + } + mma_op.store_result_slice( + y, N, short2(0, offset), short2(tgp_bn, offset_next)); + } + + // Nothing aligned so check both rows and cols + else { + gemm_loop_unaligned( + Xs, Ws, mma_op, loader_x, loader_w, K_it, tgp_bm, tgp_bn, BK); + if (!align_K) { + threadgroup_barrier(mem_flags::mem_threadgroup); + gemm_loop_finalize( + Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w); + } + mma_op.store_result_slice( + y, N, short2(0, offset), short2(tgp_bn, offset_next)); + } + } + } +} + template [[kernel]] void affine_quantize( const device T* w [[buffer(0)]], diff --git a/mlx/backend/metal/kernels/quantized.metal b/mlx/backend/metal/kernels/quantized.metal index 7af554437..11cd8421b 100644 --- a/mlx/backend/metal/kernels/quantized.metal +++ b/mlx/backend/metal/kernels/quantized.metal @@ -60,6 +60,20 @@ bits, \ split_k) +#define instantiate_gather_qmm_rhs(func, name, type, group_size, bits, bm, bn, bk, wm, wn, transpose) \ + instantiate_kernel( \ + #name "_" #type "_gs_" #group_size "_b_" #bits "_bm_" #bm "_bn_" #bn "_bk_" #bk "_wm_" #wm "_wn_" #wn, \ + func, \ + type, \ + group_size, \ + bits, \ + bm, \ + bn, \ + bk, \ + wm, \ + wn, \ + transpose) + #define instantiate_quantized_batched_wrap(name, type, group_size, bits) \ instantiate_quantized_batched(name, type, group_size, bits, 1) \ instantiate_quantized_batched(name, type, group_size, bits, 0) @@ -73,14 +87,14 @@ #define instantiate_quantized_all_single(type, group_size, bits) \ instantiate_quantized(affine_quantize, type, group_size, bits) \ instantiate_quantized(affine_dequantize, type, group_size, bits) \ - instantiate_quantized(bs_qmv_fast, type, group_size, bits) \ - instantiate_quantized(bs_qmv, type, group_size, bits) \ - instantiate_quantized(bs_qvm, type, group_size, 
bits) \ - instantiate_quantized(bs_qmm_n, type, group_size, bits) + instantiate_quantized(gather_qmv_fast, type, group_size, bits) \ + instantiate_quantized(gather_qmv, type, group_size, bits) \ + instantiate_quantized(gather_qvm, type, group_size, bits) \ + instantiate_quantized(gather_qmm_n, type, group_size, bits) #define instantiate_quantized_all_aligned(type, group_size, bits) \ - instantiate_quantized_aligned(bs_qmm_t, type, group_size, bits, true) \ - instantiate_quantized_aligned(bs_qmm_t, type, group_size, bits, false) \ + instantiate_quantized_aligned(gather_qmm_t, type, group_size, bits, true) \ + instantiate_quantized_aligned(gather_qmm_t, type, group_size, bits, false) \ instantiate_quantized_aligned_batched(qmm_t, type, group_size, bits, true, 1) \ instantiate_quantized_aligned_batched(qmm_t, type, group_size, bits, true, 0) \ instantiate_quantized_aligned_batched(qmm_t, type, group_size, bits, false, 1) \ @@ -96,12 +110,17 @@ instantiate_quantized_split_k(qvm_split_k, type, group_size, bits, 8) \ instantiate_quantized_split_k(qvm_split_k, type, group_size, bits, 32) +#define instantiate_quantized_all_rhs(type, group_size, bits) \ + instantiate_gather_qmm_rhs(gather_qmm_rhs, gather_qmm_rhs_nt, type, group_size, bits, 16, 32, 32, 1, 2, true) \ + instantiate_gather_qmm_rhs(gather_qmm_rhs, gather_qmm_rhs_nn, type, group_size, bits, 16, 32, 32, 1, 2, false) + #define instantiate_quantized_funcs(type, group_size, bits) \ instantiate_quantized_all_single(type, group_size, bits) \ instantiate_quantized_all_batched(type, group_size, bits) \ instantiate_quantized_all_aligned(type, group_size, bits) \ instantiate_quantized_all_quad(type, group_size, bits) \ - instantiate_quantized_all_splitk(type, group_size, bits) + instantiate_quantized_all_splitk(type, group_size, bits) \ + instantiate_quantized_all_rhs(type, group_size, bits) #define instantiate_quantized_types(group_size, bits) \ instantiate_quantized_funcs(float, group_size, bits) \ diff --git a/mlx/backend/metal/kernels/scan.h b/mlx/backend/metal/kernels/scan.h index cfa84c04c..cb5147558 100644 --- a/mlx/backend/metal/kernels/scan.h +++ b/mlx/backend/metal/kernels/scan.h @@ -2,6 +2,8 @@ #pragma once +#include "mlx/backend/metal/kernels/binary_ops.h" + #define DEFINE_SIMD_SCAN() \ template = true> \ T simd_scan(T val) { \ @@ -139,6 +141,29 @@ struct CumMin { } }; +template +struct CumLogaddexp { + static constexpr constant U init = Limits::min; + + template + U operator()(U a, T b) { + return LogAddExp{}(a, static_cast(b)); + } + + U simd_scan(U x) { + for (int i = 1; i <= 16; i *= 2) { + U other = simd_shuffle_and_fill_up(x, init, i); + x = LogAddExp{}(x, other); + } + return x; + } + + U simd_exclusive_scan(U x) { + x = simd_scan(x); + return simd_shuffle_and_fill_up(x, init, 1); + } +}; + template inline void load_unsafe(U values[N_READS], const device T* input) { if (reverse) { diff --git a/mlx/backend/metal/kernels/scan.metal b/mlx/backend/metal/kernels/scan.metal index 6aa36f5a3..8fcd7f61b 100644 --- a/mlx/backend/metal/kernels/scan.metal +++ b/mlx/backend/metal/kernels/scan.metal @@ -101,4 +101,7 @@ instantiate_scan_helper(min_int64_int64, int64_t, int64_t, CumMi instantiate_scan_helper(min_float16_float16, half, half, CumMin, 4) instantiate_scan_helper(min_float32_float32, float, float, CumMin, 4) instantiate_scan_helper(min_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumMin, 4) -instantiate_scan_helper(min_complex64_complex64, complex64_t, complex64_t, CumMin, 2) // clang-format on 
+instantiate_scan_helper(min_complex64_complex64, complex64_t, complex64_t, CumMin, 2) +instantiate_scan_helper(logaddexp_float16_float16, half, half, CumLogaddexp, 4) +instantiate_scan_helper(logaddexp_float32_float32, float, float, CumLogaddexp, 4) +instantiate_scan_helper(logaddexp_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumLogaddexp, 4) // clang-format on diff --git a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h index bcc585bbe..add495d93 100644 --- a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +++ b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h @@ -15,10 +15,6 @@ constant bool align_M [[function_constant(200)]]; constant bool align_N [[function_constant(201)]]; constant bool align_K [[function_constant(202)]]; -constant bool do_gather [[function_constant(300)]]; - -constant bool gather_bias = do_gather && use_out_source; - // clang-format off template < typename T, @@ -39,12 +35,6 @@ template < const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]], const constant int* batch_shape [[buffer(6)]], const constant int64_t* batch_strides [[buffer(7)]], - const constant uint32_t* lhs_indices [[buffer(10), function_constant(do_gather)]], - const constant uint32_t* rhs_indices [[buffer(11), function_constant(do_gather)]], - const constant uint32_t* C_indices [[buffer(12), function_constant(gather_bias)]], - const constant int* operand_shape [[buffer(13), function_constant(do_gather)]], - const constant int64_t* operand_strides [[buffer(14), function_constant(do_gather)]], - const constant packed_int3& operand_batch_ndim [[buffer(15), function_constant(do_gather)]], uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_group_id [[simdgroup_index_in_threadgroup]], uint3 tid [[threadgroup_position_in_grid]], @@ -81,84 +71,26 @@ template < } // Adjust for batch + if (has_batch) { + const constant auto* A_bstrides = batch_strides; + const constant auto* B_bstrides = batch_strides + params->batch_ndim; - // Handle gather - if (do_gather) { - // Read indices - uint32_t indx_A, indx_B, indx_C; + ulong2 batch_offsets = elem_to_loc_broadcast( + tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim); - if (has_batch) { - const constant auto* indx_A_bstrides = batch_strides; - const constant auto* indx_B_bstrides = batch_strides + params->batch_ndim; - - ulong2 indx_offsets = elem_to_loc_broadcast( - tid.z, - batch_shape, - indx_A_bstrides, - indx_B_bstrides, - params->batch_ndim); - indx_A = lhs_indices[indx_offsets.x]; - indx_B = rhs_indices[indx_offsets.y]; - - if (use_out_source) { - const constant auto* indx_C_bstrides = - indx_B_bstrides + params->batch_ndim; - auto indx_offset_C = elem_to_loc( - tid.z, batch_shape, indx_C_bstrides, params->batch_ndim); - indx_C = C_indices[indx_offset_C]; - } - } else { - indx_A = lhs_indices[params->batch_stride_a * tid.z]; - indx_B = rhs_indices[params->batch_stride_b * tid.z]; - - if (use_out_source) { - indx_C = C_indices[addmm_params->batch_stride_c * tid.z]; - } - } - - // Translate indices to offsets - int batch_ndim_A = operand_batch_ndim.x; - const constant int* batch_shape_A = operand_shape; - const constant auto* batch_strides_A = operand_strides; - A += elem_to_loc(indx_A, batch_shape_A, batch_strides_A, batch_ndim_A); - - int batch_ndim_B = operand_batch_ndim.y; - const constant int* batch_shape_B = batch_shape_A + batch_ndim_A; - const constant auto* batch_strides_B = 
batch_strides_A + batch_ndim_A; - B += elem_to_loc(indx_B, batch_shape_B, batch_strides_B, batch_ndim_B); + A += batch_offsets.x; + B += batch_offsets.y; if (use_out_source) { - int batch_ndim_C = operand_batch_ndim.z; - const constant int* batch_shape_C = batch_shape_B + batch_ndim_B; - const constant auto* batch_strides_C = batch_strides_B + batch_ndim_B; - C += elem_to_loc(indx_C, batch_shape_C, batch_strides_C, batch_ndim_C); + const constant auto* C_bstrides = B_bstrides + params->batch_ndim; + C += elem_to_loc(tid.z, batch_shape, C_bstrides, params->batch_ndim); } + } else { + A += params->batch_stride_a * tid.z; + B += params->batch_stride_b * tid.z; - } - - // Handle regular batch - else { - if (has_batch) { - const constant auto* A_bstrides = batch_strides; - const constant auto* B_bstrides = batch_strides + params->batch_ndim; - - ulong2 batch_offsets = elem_to_loc_broadcast( - tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim); - - A += batch_offsets.x; - B += batch_offsets.y; - - if (use_out_source) { - const constant auto* C_bstrides = B_bstrides + params->batch_ndim; - C += elem_to_loc(tid.z, batch_shape, C_bstrides, params->batch_ndim); - } - } else { - A += params->batch_stride_a * tid.z; - B += params->batch_stride_b * tid.z; - - if (use_out_source) { - C += addmm_params->batch_stride_c * tid.z; - } + if (use_out_source) { + C += addmm_params->batch_stride_c * tid.z; } } diff --git a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h new file mode 100644 index 000000000..4493375c1 --- /dev/null +++ b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h @@ -0,0 +1,459 @@ +// Copyright © 2024 Apple Inc. + +using namespace mlx::steel; + +constant bool has_batch [[function_constant(10)]]; +constant bool align_M [[function_constant(200)]]; +constant bool align_N [[function_constant(201)]]; +constant bool align_K [[function_constant(202)]]; + +template < + typename T, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose_a, + bool transpose_b, + typename AccumType = float> +[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void gather_mm_rhs( + const device T* A [[buffer(0)]], + const device T* B [[buffer(1)]], + const device uint32_t* rhs_indices [[buffer(2)]], + device T* C [[buffer(3)]], + const constant GEMMParams* params [[buffer(4)]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint3 tid [[threadgroup_position_in_grid]]) { + using gemm_kernel = GEMMKernel< + T, + T, + BM, + BN, + BK, + WM, + WN, + transpose_a, + transpose_b, + true, + true, + AccumType>; + + using loader_a_t = typename gemm_kernel::loader_a_t; + using loader_b_t = typename gemm_kernel::loader_b_t; + using mma_t = typename gemm_kernel::mma_t; + + if (params->tiles_n <= static_cast(tid.x) || + params->tiles_m <= static_cast(tid.y)) { + return; + } + + // Prepare threadgroup memory + threadgroup T As[gemm_kernel::tgp_mem_size_a]; + threadgroup T Bs[gemm_kernel::tgp_mem_size_b]; + + // Find the block in A, B, C + const int c_row = tid.y * BM; + const int c_col = tid.x * BN; + const size_t c_row_long = size_t(c_row); + const size_t c_col_long = size_t(c_col); + + // Prepare threadgroup bounds + const short tgp_bm = align_M ? BM : short(min(BM, params->M - c_row)); + const short tgp_bn = align_N ? BN : short(min(BN, params->N - c_col)); + + A += transpose_a ? c_row_long : c_row_long * params->lda; + B += transpose_b ? 
c_col_long * params->ldb : c_col_long; + C += c_row_long * params->ldd + c_col_long; + + // Do as many matmuls as necessary + uint32_t index; + short offset; + uint32_t index_next = rhs_indices[c_row]; + short offset_next = 0; + int n = 0; + while (n < tgp_bm) { + n++; + offset = offset_next; + index = index_next; + offset_next = tgp_bm; + for (; n < tgp_bm; n++) { + if (rhs_indices[c_row + n] != index) { + offset_next = n; + index_next = rhs_indices[c_row + n]; + break; + } + } + threadgroup_barrier(mem_flags::mem_none); + + // Prepare threadgroup mma operation + thread mma_t mma_op(simd_group_id, simd_lane_id); + + // Prepare threadgroup loading operations + thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id); + thread loader_b_t loader_b( + B + index * params->batch_stride_b, + params->ldb, + Bs, + simd_group_id, + simd_lane_id); + + // Prepare iterations + const int gemm_k_iterations = params->gemm_k_iterations_aligned; + + // Do unaligned K iterations first + if (!align_K) { + const int k_last = params->gemm_k_iterations_aligned * BK; + const int k_remain = params->K - k_last; + const size_t k_jump_a = + transpose_a ? params->lda * size_t(k_last) : size_t(k_last); + const size_t k_jump_b = + transpose_b ? size_t(k_last) : params->ldb * size_t(k_last); + + // Move loader source ahead to end + loader_a.src += k_jump_a; + loader_b.src += k_jump_b; + + // Load tile + const short2 tile_dims_A = + transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm); + const short2 tile_dims_B = + transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain); + + loader_a.load_safe(tile_dims_A); + loader_b.load_safe(tile_dims_B); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Do matmul + mma_op.mma(As, Bs); + + // Reset source back to start + loader_a.src -= k_jump_a; + loader_b.src -= k_jump_b; + } + + // Matrix level aligned never check + if (align_M && align_N) { + for (int k = 0; k < gemm_k_iterations; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Load elements into threadgroup + loader_a.load_unsafe(); + loader_b.load_unsafe(); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + + // Store results to device memory + if (offset_next - offset == BM) { + mma_op.store_result(C, params->ldd); + } else { + mma_op.store_result_slice( + C, params->ldd, short2(0, offset), short2(BN, offset_next)); + } + } else { + const short lbk = 0; + + // Tile aligned don't check + if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) { + gemm_kernel::gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + lbk, + LoopAlignment{}); + if (offset_next - offset == BM) { + mma_op.store_result(C, params->ldd); + } else { + mma_op.store_result_slice( + C, params->ldd, short2(0, offset), short2(BN, offset_next)); + } + } + + // Tile partially aligned check rows + else if (align_N || tgp_bn == BN) { + gemm_kernel::gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + lbk, + LoopAlignment{}); + mma_op.store_result_slice( + C, params->ldd, short2(0, offset), short2(BN, offset_next)); + } + + // Tile partially aligned check cols + else if (align_M || tgp_bm == BM) { + gemm_kernel::gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + lbk, + LoopAlignment{}); + 
mma_op.store_result_slice( + C, params->ldd, short2(0, offset), short2(tgp_bn, offset_next)); + } + + // Nothing aligned so check both rows and cols + else { + gemm_kernel::gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + lbk, + LoopAlignment{}); + mma_op.store_result_slice( + C, params->ldd, short2(0, offset), short2(tgp_bn, offset_next)); + } + } + } +} + +template < + typename T, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose_a, + bool transpose_b, + typename AccumType = float> +[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void gather_mm( + const device T* A [[buffer(0)]], + const device T* B [[buffer(1)]], + const device uint32_t* lhs_indices [[buffer(2)]], + const device uint32_t* rhs_indices [[buffer(3)]], + device T* C [[buffer(4)]], + const constant GEMMParams* params [[buffer(5)]], + const constant int* indices_shape [[buffer(6)]], + const constant int64_t* lhs_strides [[buffer(7)]], + const constant int64_t* rhs_strides [[buffer(8)]], + const constant int& batch_ndim_a [[buffer(9)]], + const constant int* batch_shape_a [[buffer(10)]], + const constant int64_t* batch_strides_a [[buffer(11)]], + const constant int& batch_ndim_b [[buffer(12)]], + const constant int* batch_shape_b [[buffer(13)]], + const constant int64_t* batch_strides_b [[buffer(14)]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint3 tid [[threadgroup_position_in_grid]]) { + using gemm_kernel = GEMMKernel< + T, + T, + BM, + BN, + BK, + WM, + WN, + transpose_a, + transpose_b, + true, + true, + AccumType>; + + using loader_a_t = typename gemm_kernel::loader_a_t; + using loader_b_t = typename gemm_kernel::loader_b_t; + using mma_t = typename gemm_kernel::mma_t; + + if (params->tiles_n <= static_cast(tid.x) || + params->tiles_m <= static_cast(tid.y)) { + return; + } + + // Move A and B to the locations pointed by lhs_indices and rhs_indices. + uint32_t indx_A, indx_B; + if (has_batch) { + ulong2 indices_offsets = elem_to_loc_broadcast( + tid.z, indices_shape, lhs_strides, rhs_strides, params->batch_ndim); + indx_A = lhs_indices[indices_offsets.x]; + indx_B = rhs_indices[indices_offsets.y]; + } else { + indx_A = lhs_indices[params->batch_stride_a * tid.z]; + indx_B = rhs_indices[params->batch_stride_b * tid.z]; + } + A += elem_to_loc(indx_A, batch_shape_a, batch_strides_a, batch_ndim_a); + B += elem_to_loc(indx_B, batch_shape_b, batch_strides_b, batch_ndim_b); + C += params->batch_stride_d * tid.z; + + // Prepare threadgroup memory + threadgroup T As[gemm_kernel::tgp_mem_size_a]; + threadgroup T Bs[gemm_kernel::tgp_mem_size_b]; + + // Just make sure everybody's finished with the indexing math above. + threadgroup_barrier(mem_flags::mem_none); + + // Find block in A, B, C + const int c_row = tid.y * BM; + const int c_col = tid.x * BN; + const size_t c_row_long = size_t(c_row); + const size_t c_col_long = size_t(c_col); + + A += transpose_a ? c_row_long : c_row_long * params->lda; + B += transpose_b ? c_col_long * params->ldb : c_col_long; + C += c_row_long * params->ldd + c_col_long; + + // Prepare threadgroup mma operation + thread mma_t mma_op(simd_group_id, simd_lane_id); + + // Prepare threadgroup loading operations + thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id); + thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id); + + // Prepare threadgroup bounds + const short tgp_bm = align_M ? 
BM : short(min(BM, params->M - c_row)); + const short tgp_bn = align_N ? BN : short(min(BN, params->N - c_col)); + + // Prepare iterations + int gemm_k_iterations = params->gemm_k_iterations_aligned; + + // Do unaligned K iterations first + if (!align_K) { + const int k_last = params->gemm_k_iterations_aligned * BK; + const int k_remain = params->K - k_last; + const size_t k_jump_a = + transpose_a ? params->lda * size_t(k_last) : size_t(k_last); + const size_t k_jump_b = + transpose_b ? size_t(k_last) : params->ldb * size_t(k_last); + + // Move loader source ahead to end + loader_a.src += k_jump_a; + loader_b.src += k_jump_b; + + // Load tile + const short2 tile_dims_A = + transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm); + const short2 tile_dims_B = + transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain); + + loader_a.load_safe(tile_dims_A); + loader_b.load_safe(tile_dims_B); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Do matmul + mma_op.mma(As, Bs); + + // Reset source back to start + loader_a.src -= k_jump_a; + loader_b.src -= k_jump_b; + } + + // Matrix level aligned never check + if (align_M && align_N) { + for (int k = 0; k < gemm_k_iterations; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Load elements into threadgroup + loader_a.load_unsafe(); + loader_b.load_unsafe(); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + + // Store results to device memory + mma_op.store_result(C, params->ldd); + } else { + const short lbk = 0; + + // Tile aligned don't check + if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) { + gemm_kernel::gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + lbk, + LoopAlignment{}); + mma_op.store_result(C, params->ldd); + } + + // Tile partially aligned check rows + else if (align_N || tgp_bn == BN) { + gemm_kernel::gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + lbk, + LoopAlignment{}); + mma_op.store_result_safe(C, params->ldd, short2(tgp_bn, tgp_bm)); + } + + // Tile partially aligned check cols + else if (align_M || tgp_bm == BM) { + gemm_kernel::gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + lbk, + LoopAlignment{}); + mma_op.store_result_safe(C, params->ldd, short2(tgp_bn, tgp_bm)); + } + + // Nothing aligned so check both rows and cols + else { + gemm_kernel::gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + lbk, + LoopAlignment{}); + mma_op.store_result_safe(C, params->ldd, short2(tgp_bn, tgp_bm)); + } + } +} diff --git a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.metal b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.metal new file mode 100644 index 000000000..f8e5a2a37 --- /dev/null +++ b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.metal @@ -0,0 +1,59 @@ +// Copyright © 2024 Apple Inc. 
+ +// clang-format off +#include "mlx/backend/metal/kernels/utils.h" +#include "mlx/backend/metal/kernels/steel/gemm/gemm.h" +#include "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h" + +#define instantiate_gather_mm_rhs(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ + instantiate_kernel( \ + "steel_gather_mm_rhs_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn \ + "_bk" #bk "_wm" #wm "_wn" #wn, \ + gather_mm_rhs, \ + itype, \ + bm, \ + bn, \ + bk, \ + wm, \ + wn, \ + trans_a, \ + trans_b, \ + float) + +#define instantiate_gather_mm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ + instantiate_kernel( \ + "steel_gather_mm_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn \ + "_bk" #bk "_wm" #wm "_wn" #wn, \ + gather_mm, \ + itype, \ + bm, \ + bn, \ + bk, \ + wm, \ + wn, \ + trans_a, \ + trans_b, \ + float) + +#define instantiate_gather_mm_rhs_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \ + instantiate_gather_mm_rhs(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ + instantiate_gather_mm_rhs(nt, false, true, iname, itype, oname, otype, bm, bn, bk, wm, wn) + +#define instantiate_gather_mm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \ + instantiate_gather_mm(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ + instantiate_gather_mm(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \ + instantiate_gather_mm(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ + instantiate_gather_mm(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn) + +#define instantiate_gather_mm_shapes_helper(iname, itype, oname, otype) \ + instantiate_gather_mm_rhs_transpose_helper(iname, itype, oname, otype, 16, 64, 16, 1, 2) \ + instantiate_gather_mm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) \ + instantiate_gather_mm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 1, 2) \ + instantiate_gather_mm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \ + instantiate_gather_mm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 1, 2) \ + instantiate_gather_mm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) +// clang-format on + +instantiate_gather_mm_shapes_helper(float16, half, float16, half); +instantiate_gather_mm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t); +instantiate_gather_mm_shapes_helper(float32, float, float32, float); diff --git a/mlx/backend/metal/kernels/steel/gemm/mma.h b/mlx/backend/metal/kernels/steel/gemm/mma.h index aea235abb..64b87655e 100644 --- a/mlx/backend/metal/kernels/steel/gemm/mma.h +++ b/mlx/backend/metal/kernels/steel/gemm/mma.h @@ -142,6 +142,42 @@ struct BaseMMAFrag { } } + template < + typename DstPtrType, + typename StrX, + typename StrY, + typename StartX, + typename StopX, + typename StartY, + typename StopY, + typename OffX, + typename OffY> + METAL_FUNC static constexpr void store_slice( + const thread frag_type& src, + DstPtrType dst, + StrX str_x, + StrY str_y, + StartX start_x, + StopX stop_x, + StartY start_y, + StopY stop_y, + OffX off_x = Int<0>{}, + OffY off_y = Int<0>{}) { + using U = pointer_element_t; + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + if ((off_x + i) < stop_x && (off_x + i) >= start_x && + (off_y + j) < stop_y && (off_y + j) >= start_y) { + dst[(off_x + i) * str_x + (off_y + j) * str_y] = + static_cast(src[i * 
kElemCols + j]); + } + } + } + } + METAL_FUNC static constexpr void mma( thread frag_type& D, thread frag_type& A, @@ -335,6 +371,31 @@ struct MMATile { } } } + + template + METAL_FUNC void store_slice( + device U* dst, + const int ld, + const short2 start, + const short2 stop) const { + STEEL_PRAGMA_UNROLL + for (int i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < kTileCols; ++j) { + MMAFrag_t::store_slice( + frag_at(i, j), + dst, + ld, + Int<1>{}, + start.y, + stop.y, + start.x, + stop.x, + (i * kFragRows) * w_x, + (j * kFragCols) * w_y); + } + } + } }; template @@ -474,6 +535,26 @@ struct BlockMMA { Ctile.template store(D, ldd); } + METAL_FUNC void + store_result_slice(device U* D, const int ldd, short2 start, short2 stop) { + // Apply epilogue + STEEL_PRAGMA_UNROLL + for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) { + Ctile.elems()[i] = Epilogue::apply(Ctile.elems()[i]); + } + + D += sm * ldd + sn; + start -= short2(sn, sm); + stop -= short2(sn, sm); + + // TODO: Check the start as well + if (stop.y <= 0 || stop.x <= 0) { + return; + } + + Ctile.template store_slice(D, ldd, start, stop); + } + METAL_FUNC void store_result_safe(device U* D, const int ldd, short2 dst_tile_dims) { // Apply epilogue diff --git a/mlx/backend/metal/kernels/unary.metal b/mlx/backend/metal/kernels/unary.metal index 2209b0665..d34c5a7ec 100644 --- a/mlx/backend/metal/kernels/unary.metal +++ b/mlx/backend/metal/kernels/unary.metal @@ -69,6 +69,9 @@ instantiate_unary_float(Round) instantiate_unary_int(BitwiseInvert) instantiate_unary_all_same(Abs, complex64, complex64_t) +instantiate_unary_all_same(ArcCos, complex64, complex64_t) +instantiate_unary_all_same(ArcSin, complex64, complex64_t) +instantiate_unary_all_same(ArcTan, complex64, complex64_t) instantiate_unary_all_same(Conjugate, complex64, complex64_t) instantiate_unary_all_same(Cos, complex64, complex64_t) instantiate_unary_all_same(Cosh, complex64, complex64_t) @@ -80,6 +83,9 @@ instantiate_unary_all_same(Negative, complex64, complex64_t) instantiate_unary_all_same(Sign, complex64, complex64_t) instantiate_unary_all_same(Sin, complex64, complex64_t) instantiate_unary_all_same(Sinh, complex64, complex64_t) +instantiate_unary_all_same(Square, complex64, complex64_t) +instantiate_unary_all_same(Sqrt, complex64, complex64_t) +instantiate_unary_all_same(Rsqrt, complex64, complex64_t) instantiate_unary_all_same(Tan, complex64, complex64_t) instantiate_unary_all_same(Tanh, complex64, complex64_t) instantiate_unary_all_same(Round, complex64, complex64_t) diff --git a/mlx/backend/metal/kernels/unary_ops.h b/mlx/backend/metal/kernels/unary_ops.h index 52e126b40..09d9f6605 100644 --- a/mlx/backend/metal/kernels/unary_ops.h +++ b/mlx/backend/metal/kernels/unary_ops.h @@ -17,27 +17,21 @@ struct Abs { T operator()(T x) { return metal::abs(x); }; - template <> uint8_t operator()(uint8_t x) { return x; }; - template <> uint16_t operator()(uint16_t x) { return x; }; - template <> uint32_t operator()(uint32_t x) { return x; }; - template <> uint64_t operator()(uint64_t x) { return x; }; - template <> bool operator()(bool x) { return x; }; - template <> complex64_t operator()(complex64_t x) { return {metal::precise::sqrt(x.real * x.real + x.imag * x.imag), 0}; }; @@ -48,6 +42,8 @@ struct ArcCos { T operator()(T x) { return metal::precise::acos(x); }; + + complex64_t operator()(complex64_t x); }; struct ArcCosh { @@ -62,6 +58,8 @@ struct ArcSin { T operator()(T x) { return metal::precise::asin(x); }; + + complex64_t operator()(complex64_t 
x); }; struct ArcSinh { @@ -76,6 +74,8 @@ struct ArcTan { T operator()(T x) { return metal::precise::atan(x); }; + + complex64_t operator()(complex64_t x); }; struct ArcTanh { @@ -97,39 +97,30 @@ struct Ceil { T operator()(T x) { return metal::ceil(x); }; - template <> int8_t operator()(int8_t x) { return x; }; - template <> int16_t operator()(int16_t x) { return x; }; - template <> int32_t operator()(int32_t x) { return x; }; - template <> int64_t operator()(int64_t x) { return x; }; - template <> uint8_t operator()(uint8_t x) { return x; }; - template <> uint16_t operator()(uint16_t x) { return x; }; - template <> uint32_t operator()(uint32_t x) { return x; }; - template <> uint64_t operator()(uint64_t x) { return x; }; - template <> bool operator()(bool x) { return x; }; @@ -141,7 +132,6 @@ struct Cos { return metal::precise::cos(x); }; - template <> complex64_t operator()(complex64_t x) { return { metal::precise::cos(x.real) * metal::precise::cosh(x.imag), @@ -155,7 +145,6 @@ struct Cosh { return metal::precise::cosh(x); }; - template <> complex64_t operator()(complex64_t x) { return { metal::precise::cosh(x.real) * metal::precise::cos(x.imag), @@ -188,7 +177,6 @@ struct Exp { T operator()(T x) { return metal::precise::exp(x); }; - template <> complex64_t operator()(complex64_t x) { auto m = metal::precise::exp(x.real); return {m * metal::precise::cos(x.imag), m * metal::precise::sin(x.imag)}; @@ -207,39 +195,30 @@ struct Floor { T operator()(T x) { return metal::floor(x); }; - template <> int8_t operator()(int8_t x) { return x; }; - template <> int16_t operator()(int16_t x) { return x; }; - template <> int32_t operator()(int32_t x) { return x; }; - template <> int64_t operator()(int64_t x) { return x; }; - template <> uint8_t operator()(uint8_t x) { return x; }; - template <> uint16_t operator()(uint16_t x) { return x; }; - template <> uint32_t operator()(uint32_t x) { return x; }; - template <> uint64_t operator()(uint64_t x) { return x; }; - template <> bool operator()(bool x) { return x; }; @@ -258,7 +237,6 @@ struct Log { return metal::precise::log(x); }; - template <> complex64_t operator()(complex64_t x) { auto r = metal::precise::log(Abs{}(x).real); auto i = metal::precise::atan2(x.imag, x.real); @@ -272,7 +250,6 @@ struct Log2 { return metal::precise::log2(x); }; - template <> complex64_t operator()(complex64_t x) { auto y = Log{}(x); return {y.real / M_LN2_F, y.imag / M_LN2_F}; @@ -285,7 +262,6 @@ struct Log10 { return metal::precise::log10(x); }; - template <> complex64_t operator()(complex64_t x) { auto y = Log{}(x); return {y.real / M_LN10_F, y.imag / M_LN10_F}; @@ -325,7 +301,6 @@ struct Round { T operator()(T x) { return metal::rint(x); }; - template <> complex64_t operator()(complex64_t x) { return {metal::rint(x.real), metal::rint(x.imag)}; }; @@ -344,11 +319,9 @@ struct Sign { T operator()(T x) { return (x > T(0)) - (x < T(0)); }; - template <> uint32_t operator()(uint32_t x) { return x != 0; }; - template <> complex64_t operator()(complex64_t x) { if (x == complex64_t(0)) { return x; @@ -364,7 +337,6 @@ struct Sin { return metal::precise::sin(x); }; - template <> complex64_t operator()(complex64_t x) { return { metal::precise::sin(x.real) * metal::precise::cosh(x.imag), @@ -378,7 +350,6 @@ struct Sinh { return metal::precise::sinh(x); }; - template <> complex64_t operator()(complex64_t x) { return { metal::precise::sinh(x.real) * metal::precise::cos(x.imag), @@ -398,6 +369,17 @@ struct Sqrt { T operator()(T x) { return metal::precise::sqrt(x); }; + + complex64_t 
operator()(complex64_t x) { + if (x.real == 0.0 && x.imag == 0.0) { + return {0.0, 0.0}; + } + auto r = Abs{}(x).real; + auto a = metal::precise::sqrt((r + x.real) / 2.0); + auto b_abs = metal::precise::sqrt((r - x.real) / 2.0); + auto b = metal::copysign(b_abs, x.imag); + return {a, b}; + } }; struct Rsqrt { @@ -405,6 +387,10 @@ struct Rsqrt { T operator()(T x) { return metal::precise::rsqrt(x); }; + + complex64_t operator()(complex64_t x) { + return 1.0 / Sqrt{}(x); + } }; struct Tan { @@ -413,7 +399,6 @@ struct Tan { return metal::precise::tan(x); }; - template <> complex64_t operator()(complex64_t x) { float tan_a = metal::precise::tan(x.real); float tanh_b = metal::precise::tanh(x.imag); @@ -429,7 +414,6 @@ struct Tanh { return metal::precise::tanh(x); }; - template <> complex64_t operator()(complex64_t x) { float tanh_a = metal::precise::tanh(x.real); float tan_b = metal::precise::tan(x.imag); @@ -438,3 +422,21 @@ struct Tanh { return {(tanh_a + tan_b * t1) / denom, (tan_b - tanh_a * t1) / denom}; }; }; + +complex64_t ArcCos::operator()(complex64_t x) { + auto i = complex64_t{0.0, 1.0}; + auto y = Log{}(x + i * Sqrt{}(1.0 - x * x)); + return {y.imag, -y.real}; +}; + +complex64_t ArcSin::operator()(complex64_t x) { + auto i = complex64_t{0.0, 1.0}; + auto y = Log{}(i * x + Sqrt{}(1.0 - x * x)); + return {y.imag, -y.real}; +}; + +complex64_t ArcTan::operator()(complex64_t x) { + auto i = complex64_t{0.0, 1.0}; + auto ix = i * x; + return (1.0 / complex64_t{0.0, 2.0}) * Log{}((1.0 + ix) / (1.0 - ix)); +}; diff --git a/mlx/backend/metal/matmul.cpp b/mlx/backend/metal/matmul.cpp index 3f736505f..f55d20c9f 100644 --- a/mlx/backend/metal/matmul.cpp +++ b/mlx/backend/metal/matmul.cpp @@ -5,6 +5,7 @@ #include #include +#include "mlx/backend/common/broadcasting.h" #include "mlx/backend/common/utils.h" #include "mlx/backend/metal/copy.h" #include "mlx/backend/metal/device.h" @@ -102,6 +103,47 @@ std::tuple check_transpose( } }; +inline array +ensure_row_contiguous(const array& x, metal::Device& d, const Stream& s) { + if (!x.flags().row_contiguous) { + array x_copy(x.shape(), x.dtype(), nullptr, {}); + copy_gpu(x, x_copy, CopyType::General, s); + d.add_temporary(x_copy, s.index); + return x_copy; + } else { + return x; + } +} + +inline std::tuple +ensure_batch_contiguous(const array& x, metal::Device& d, const Stream& s) { + if (x.flags().row_contiguous) { + return std::make_tuple(false, x.strides()[x.ndim() - 2], x); + } + + bool rc = true; + for (int i = 0; i < x.ndim() - 3; i++) { + rc &= x.strides()[i + 1] * x.shape(i) == x.strides()[i]; + } + if (rc) { + auto stx = x.strides()[x.ndim() - 2]; + auto sty = x.strides()[x.ndim() - 1]; + auto K = x.shape(-2); + auto N = x.shape(-1); + if (sty == 1 && (N != 1 || stx == N)) { + return std::make_tuple(false, stx, x); + } + if (stx == 1 && (N != 1 || sty == K)) { + return std::make_tuple(true, sty, x); + } + } + + array x_copy(x.shape(), x.dtype(), nullptr, {}); + copy_gpu(x, x_copy, CopyType::General, s); + d.add_temporary(x_copy, s.index); + return std::make_tuple(false, x_copy.strides()[x_copy.ndim() - 2], x_copy); +} + } // namespace /////////////////////////////////////////////////////////////////////////////// @@ -230,7 +272,6 @@ void steel_matmul_regular( const bool align_M = (M % bm) == 0; const bool align_N = (N % bn) == 0; const bool align_K = (K % bk) == 0; - const bool do_gather = false; metal::MTLFCList func_consts = { {&has_batch, MTL::DataType::DataTypeBool, 10}, @@ -239,7 +280,6 @@ void steel_matmul_regular( {&align_M, 
MTL::DataType::DataTypeBool, 200}, {&align_N, MTL::DataType::DataTypeBool, 201}, {&align_K, MTL::DataType::DataTypeBool, 202}, - {&do_gather, MTL::DataType::DataTypeBool, 300}, }; // clang-format off @@ -248,8 +288,7 @@ void steel_matmul_regular( << "_do_axpby_" << (do_axpby ? 't' : 'n') << "_align_M_" << (align_M ? 't' : 'n') << "_align_N_" << (align_N ? 't' : 'n') - << "_align_K_" << (align_K ? 't' : 'n') - << "_do_gather_" << (do_gather ? 't' : 'n'); // clang-format on + << "_align_K_" << (align_K ? 't' : 'n'); // clang-format on std::string hash_name = kname.str(); @@ -975,7 +1014,6 @@ void AddMM::eval_gpu(const std::vector& inputs, array& out) { const bool align_M = (M % bm) == 0; const bool align_N = (N % bn) == 0; const bool align_K = (K % bk) == 0; - const bool do_gather = false; metal::MTLFCList func_consts = { {&has_batch, MTL::DataType::DataTypeBool, 10}, @@ -984,7 +1022,6 @@ void AddMM::eval_gpu(const std::vector& inputs, array& out) { {&align_M, MTL::DataType::DataTypeBool, 200}, {&align_N, MTL::DataType::DataTypeBool, 201}, {&align_K, MTL::DataType::DataTypeBool, 202}, - {&do_gather, MTL::DataType::DataTypeBool, 300}, }; // clang-format off @@ -993,8 +1030,7 @@ void AddMM::eval_gpu(const std::vector& inputs, array& out) { << "_do_axpby_" << (do_axpby ? 't' : 'n') << "_align_M_" << (align_M ? 't' : 'n') << "_align_N_" << (align_N ? 't' : 'n') - << "_align_K_" << (align_K ? 't' : 'n') - << "_do_gather_" << (do_gather ? 't' : 'n'); // clang-format on + << "_align_K_" << (align_K ? 't' : 'n'); // clang-format on std::string hash_name = kname.str(); @@ -1464,267 +1500,337 @@ void BlockMaskedMM::eval_gpu(const std::vector& inputs, array& out) { d.add_temporaries(std::move(copies), s.index); } -void GatherMM::eval_gpu(const std::vector& inputs, array& out) { - using namespace mlx::steel; - // assert(inputs.size() == 2); - if (!issubdtype(out.dtype(), floating)) { - throw std::runtime_error( - "[GatherMM] Does not yet support non-floating point types."); - } - auto& s = stream(); - auto& d = metal::device(s.device); +void gather_mm_rhs( + const array& a_, + const array& b_, + const array& indices_, + array& out, + metal::Device& d, + const Stream& s) { + array indices = ensure_row_contiguous(indices_, d, s); + auto [transpose_b, ldb, b] = ensure_batch_contiguous(b_, d, s); - auto& a_pre = inputs[0]; - auto& b_pre = inputs[1]; - // Return 0s if either input is empty - if (a_pre.size() == 0 || b_pre.size() == 0) { - array zero = array(0, a_pre.dtype()); - fill_gpu(zero, out, s); - d.add_temporary(std::move(zero), s.index); - return; - } + // Broadcast a with indices. If we are here that means lhs_indices were not + // provided so the lhs_indices are implied to be the shape of a broadcasted + // with rhs_indices. We need only broadcast a and copy it as if applying the + // lhs_indices. 
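
Concretely, the implied broadcast only touches the batch dimensions: the target shape is rhs_indices.shape() with the trailing matrix dimensions of a appended, so every index gets its own row block and M becomes the total number of gathered rows. A minimal standalone sketch of that shape arithmetic (plain std::vector shapes, not the mlx::core::array API):

// Sketch of the implied-lhs broadcast target shape (illustrative only).
#include <cstdint>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

// Target shape = rhs_indices.shape() + the trailing matrix dims of a.
std::vector<int> broadcast_target_shape(
    const std::vector<int>& indices_shape,
    const std::vector<int>& a_shape) {
  std::vector<int> out = indices_shape;
  out.push_back(a_shape[a_shape.size() - 2]); // rows per gathered lookup
  out.push_back(a_shape[a_shape.size() - 1]); // K
  return out;
}

int main() {
  // e.g. a: (N, 1, 1, D) activations, rhs_indices: (N, I) expert ids
  std::vector<int> a_shape = {1024, 1, 1, 1024};
  std::vector<int> idx_shape = {1024, 4};
  auto target = broadcast_target_shape(idx_shape, a_shape);
  int K = target.back();
  int64_t size = std::accumulate(
      target.begin(), target.end(), int64_t{1}, std::multiplies<int64_t>());
  int M = int(size / K); // total rows seen by the kernel: 1024 * 4
  std::cout << "M = " << M << ", K = " << K << "\n";
  return 0;
}

After the copy the data is row contiguous, so lda is simply K, matching the comment on the stride extraction that follows.
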
+ auto broadcast_with_indices = [&d, &s, &indices](const array& x) { + if (x.size() / x.shape(-2) / x.shape(-1) == indices.size()) { + return ensure_row_contiguous(x, d, s); + } - out.set_data(allocator::malloc(out.nbytes())); + auto x_shape = indices.shape(); + x_shape.push_back(x.shape(-2)); + x_shape.push_back(x.shape(-1)); + array new_x(std::move(x_shape), x.dtype(), nullptr, {}); + broadcast(x, new_x); + return ensure_row_contiguous(new_x, d, s); + }; + array a = broadcast_with_indices(a_); - ///////////////////////////////////////////////////////////////////////////// - // Init checks and prep + // Extract the matmul shapes + int K = a.shape(-1); + int M = a.size() / K; + int N = b.shape(-1); + int lda = a.strides()[a.ndim() - 2]; // should be K - int M = a_pre.shape(-2); - int N = b_pre.shape(-1); - int K = a_pre.shape(-1); + // Define the dispatch blocks + int bm = 16, bn = 64, bk = 16; + int wm = 1, wn = 2; - // Keep a vector with copies to be cleared in the completed buffer to release - // the arrays - std::vector copies; - auto [transpose_a, a_cols, a] = check_transpose(copies, s, a_pre, M == 1); - auto [transpose_b, b_cols, b] = check_transpose(copies, s, b_pre, N == 1); + const bool align_M = (M % bm) == 0; + const bool align_N = (N % bn) == 0; + const bool align_K = (K % bk) == 0; - int lda = a_cols; - int ldb = b_cols; + // Define the kernel name + std::string base_name; + base_name.reserve(64); + concatenate( + base_name, + "steel_gather_mm_rhs_n", + transpose_b ? 't' : 'n', + '_', + type_to_name(a), + '_', + type_to_name(out), + "_bm", + bm, + "_bn", + bn, + "_bk", + bk, + "_wm", + wm, + "_wn", + wn); - ///////////////////////////////////////////////////////////////////////////// - // Check and collapse batch dimensions - - auto get_batch_dims = [](const auto& v) { - return decltype(v){v.begin(), v.end() - 2}; + metal::MTLFCList func_consts = { + {&align_M, MTL::DataType::DataTypeBool, 200}, + {&align_N, MTL::DataType::DataTypeBool, 201}, + {&align_K, MTL::DataType::DataTypeBool, 202}, }; - auto& lhs_indices = inputs[2]; - auto& rhs_indices = inputs[3]; + // And the kernel hash that includes the function constants + std::string hash_name; + hash_name.reserve(128); + concatenate( + hash_name, + base_name, + "_align_M_", + align_M ? 't' : 'n', + "_align_N_", + align_N ? 't' : 'n', + "_align_K_", + align_K ? 't' : 'n'); - Shape batch_shape = get_batch_dims(out.shape()); - Strides batch_strides; + // Get and set the kernel + auto& compute_encoder = d.get_command_encoder(s.index); + auto kernel = get_steel_gemm_gather_kernel( + d, + base_name, + hash_name, + func_consts, + out, + false, + transpose_b, + bm, + bn, + bk, + wm, + wn, + true); + compute_encoder.set_compute_pipeline_state(kernel); - batch_strides.insert( - batch_strides.end(), - lhs_indices.strides().begin(), - lhs_indices.strides().end()); - auto lhs_indices_str = batch_strides.empty() ? 0 : batch_strides.back(); + // Prepare the matmul params + auto batch_stride_b = b.ndim() > 2 ? 
b.strides()[b.ndim() - 3] : b.size(); + steel::GEMMParams params{ + /* const int M = */ M, + /* const int N = */ N, + /* const int K = */ K, + /* const int lda = */ lda, + /* const int ldb = */ static_cast(ldb), + /* const int ldd = */ N, + /* const int tiles_n = */ (N + bn - 1) / bn, + /* const int tiles_m = */ (M + bm - 1) / bm, + /* const int64_t batch_stride_a = */ 0, + /* const int64_t batch_stride_b = */ static_cast(batch_stride_b), + /* const int64_t batch_stride_d = */ 0, + /* const int swizzle_log = */ 0, + /* const int gemm_k_iterations_aligned = */ (K / bk), + /* const int batch_ndim = */ 0}; - batch_strides.insert( - batch_strides.end(), - rhs_indices.strides().begin(), - rhs_indices.strides().end()); - auto rhs_indices_str = batch_strides.empty() ? 0 : batch_strides.back(); + // Prepare the grid + MTL::Size group_dims = MTL::Size(32, wn, wm); + MTL::Size grid_dims = MTL::Size(params.tiles_n, params.tiles_m, 1); - int batch_ndim = batch_shape.size(); + // Launch kernel + compute_encoder.set_input_array(a, 0); + compute_encoder.set_input_array(b, 1); + compute_encoder.set_input_array(indices, 2); + compute_encoder.set_output_array(out, 3); + compute_encoder.set_bytes(params, 4); - if (batch_ndim == 0) { - batch_shape = {1}; - batch_strides = {0}; - } + compute_encoder.dispatch_threadgroups(grid_dims, group_dims); +} - int batch_ndim_A = a.ndim() - 2; - int batch_ndim_B = b.ndim() - 2; - std::vector operand_batch_ndim = {batch_ndim_A, batch_ndim_B}; +void gather_mv( + const array& mat_, + const array& vec_, + const array& mat_indices_, + const array& vec_indices_, + array& out, + int N, + int K, + bool is_mv, + metal::Device& d, + const Stream& s) { + // Copy if needed + std::vector copies; + auto [transpose_mat, mat_cols, mat] = + check_transpose(copies, s, mat_, N == 1); + auto [transpose_vec, vec_cols, vec] = check_transpose(copies, s, vec_, true); + d.add_temporaries(std::move(copies), s.index); - Shape batch_shape_A = get_batch_dims(a.shape()); - Strides batch_strides_A = get_batch_dims(a.strides()); - Shape batch_shape_B = get_batch_dims(b.shape()); - Strides batch_strides_B = get_batch_dims(b.strides()); + // If we are doing vector matrix instead of matrix vector we need to flip the + // matrix transposition. Basically m @ v = v @ m.T assuming that v is treated + // as a one dimensional array. + transpose_mat = (!is_mv) ^ transpose_mat; - if (batch_ndim_A == 0) { - batch_shape_A = {1}; - batch_strides_A = {0}; - } + // Define some shapes + int in_vector_len = K; + int out_vector_len = N; + int mat_ld = mat_cols; - if (batch_ndim_B == 0) { - batch_shape_B = {1}; - batch_strides_B = {0}; - } + int batch_size_out = out.size() / N; + int batch_ndim = out.ndim() - 2; + int batch_ndim_mat = mat.ndim() - 2; + int batch_ndim_vec = vec.ndim() - 2; + Strides index_strides = vec_indices_.strides(); + index_strides.insert( + index_strides.end(), + mat_indices_.strides().begin(), + mat_indices_.strides().end()); - auto matrix_stride_out = static_cast(M) * N; - auto batch_size_out = out.size() / matrix_stride_out; - - ///////////////////////////////////////////////////////////////////////////// - // Gemv specialization - - // Route to gemv if needed - if (std::min(M, N) == 1) { - // Collect problem info - bool is_b_matrix = N != 1; - - auto& mat = is_b_matrix ? b : a; - auto& vec = is_b_matrix ? a : b; - bool transpose_mat = is_b_matrix ? !transpose_b : transpose_a; - int in_vector_len = K; - int out_vector_len = is_b_matrix ? N : M; - - int mat_cols = transpose_mat ? 
out_vector_len : in_vector_len; - int mat_rows = transpose_mat ? in_vector_len : out_vector_len; - int mat_ld = is_b_matrix ? b_cols : a_cols; - - auto batch_strides_mat = is_b_matrix ? batch_strides_B : batch_strides_A; - auto batch_strides_vec = is_b_matrix ? batch_strides_A : batch_strides_B; - - auto batch_shape_mat = is_b_matrix ? batch_shape_B : batch_shape_A; - auto batch_shape_vec = is_b_matrix ? batch_shape_A : batch_shape_B; - - if (!is_b_matrix) { - batch_strides = rhs_indices.strides(); - batch_strides.insert( - batch_strides.end(), - lhs_indices.strides().begin(), - lhs_indices.strides().end()); - } - - int batch_ndim = batch_shape.size(); - - // Determine dispatch kernel - int tm = 4, tn = 4; - int sm = 1, sn = 32; - int bm = 1, bn = 1; - int n_out_per_tgp; - std::ostringstream kname; - - if (transpose_mat) { - if (in_vector_len >= 8192 && out_vector_len >= 2048) { - sm = 4; - sn = 8; - } else { - sm = 8; - sn = 4; - } - - if (out_vector_len >= 2048) { - bn = 16; - } else if (out_vector_len >= 512) { - bn = 4; - } else { - bn = 2; - } - - // Specialized kernel for very small outputs - tn = out_vector_len < tn ? 1 : tn; - - n_out_per_tgp = bn * sn * tn; - kname << "gemv_t_gather_" << type_to_name(out); + // Determine dispatch kernel + int tm = 4, tn = 4; + int sm = 1, sn = 32; + int bm = 1, bn = 1; + int n_out_per_tgp; + std::ostringstream kname; + if (transpose_mat) { + if (in_vector_len >= 8192 && out_vector_len >= 2048) { + sm = 4; + sn = 8; } else { - bm = out_vector_len >= 4096 ? 8 : 4; - sn = 32; - - // Specialized kernel for very small outputs - tm = out_vector_len < tm ? 1 : tm; - - n_out_per_tgp = bm * sm * tm; - kname << "gemv_gather_" << type_to_name(out); + sm = 8; + sn = 4; } - kname << "_bm" << bm << "_bn" << bn << "_sm" << sm << "_sn" << sn << "_tm" - << tm << "_tn" << tn; + if (out_vector_len >= 2048) { + bn = 16; + } else if (out_vector_len >= 512) { + bn = 4; + } else { + bn = 2; + } - // Encode and dispatch kernel - auto& compute_encoder = d.get_command_encoder(s.index); - auto kernel = d.get_kernel(kname.str()); - compute_encoder.set_compute_pipeline_state(kernel); + // Specialized kernel for very small outputs + tn = out_vector_len < tn ? 1 : tn; - int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp; - MTL::Size group_dims = MTL::Size(32, bn, bm); - MTL::Size grid_dims = MTL::Size(n_tgp, 1, batch_size_out); + n_out_per_tgp = bn * sn * tn; + kname << "gemv_t_gather_" << type_to_name(out); - compute_encoder.set_input_array(mat, 0); - compute_encoder.set_input_array(vec, 1); - compute_encoder.set_output_array(out, 3); + } else { + bm = out_vector_len >= 4096 ? 8 : 4; + sn = 32; - compute_encoder.set_bytes(in_vector_len, 4); - compute_encoder.set_bytes(out_vector_len, 5); - compute_encoder.set_bytes(mat_ld, 6); + // Specialized kernel for very small outputs + tm = out_vector_len < tm ? 
1 : tm; - compute_encoder.set_bytes(batch_ndim, 9); - compute_encoder.set_vector_bytes(batch_shape, 10); - compute_encoder.set_vector_bytes(batch_strides, 11); - - int batch_ndim_vec = batch_shape_vec.size(); - compute_encoder.set_bytes(batch_ndim_vec, 12); - compute_encoder.set_vector_bytes(batch_shape_vec, 13); - compute_encoder.set_vector_bytes(batch_strides_vec, 14); - - int batch_ndim_mat = batch_shape_mat.size(); - compute_encoder.set_bytes(batch_ndim_mat, 15); - compute_encoder.set_vector_bytes(batch_shape_mat, 16); - compute_encoder.set_vector_bytes(batch_strides_mat, 17); - - compute_encoder.set_input_array(lhs_indices, 18 + int(!is_b_matrix)); - compute_encoder.set_input_array(rhs_indices, 18 + int(is_b_matrix)); - - compute_encoder.dispatch_threadgroups(grid_dims, group_dims); - - d.add_temporaries(std::move(copies), s.index); - return; + n_out_per_tgp = bm * sm * tm; + kname << "gemv_gather_" << type_to_name(out); } - ///////////////////////////////////////////////////////////////////////////// - // Regular kernel dispatch + kname << "_bm" << bm << "_bn" << bn << "_sm" << sm << "_sn" << sn << "_tm" + << tm << "_tn" << tn; + + // Encode and dispatch kernel + auto& compute_encoder = d.get_command_encoder(s.index); + auto kernel = d.get_kernel(kname.str()); + compute_encoder.set_compute_pipeline_state(kernel); + + int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp; + MTL::Size group_dims = MTL::Size(32, bn, bm); + MTL::Size grid_dims = MTL::Size(n_tgp, 1, batch_size_out); + + compute_encoder.set_input_array(mat, 0); + compute_encoder.set_input_array(vec, 1); + compute_encoder.set_output_array(out, 3); + + compute_encoder.set_bytes(in_vector_len, 4); + compute_encoder.set_bytes(out_vector_len, 5); + compute_encoder.set_bytes(mat_ld, 6); + + compute_encoder.set_bytes(batch_ndim, 9); + compute_encoder.set_vector_bytes(out.shape(), 10); + compute_encoder.set_vector_bytes(index_strides, 11); + + compute_encoder.set_bytes(batch_ndim_vec, 12); + compute_encoder.set_vector_bytes(vec.shape(), 13); + compute_encoder.set_vector_bytes(vec.strides(), 14); + + compute_encoder.set_bytes(batch_ndim_mat, 15); + compute_encoder.set_vector_bytes(mat.shape(), 16); + compute_encoder.set_vector_bytes(mat.strides(), 17); + + compute_encoder.set_input_array(vec_indices_, 18); + compute_encoder.set_input_array(mat_indices_, 19); + + compute_encoder.dispatch_threadgroups(grid_dims, group_dims); +} + +void gather_mm( + const array& a_, + const array& b_, + const array& lhs_indices, + const array& rhs_indices, + array& out, + int M, + int N, + int K, + metal::Device& d, + const Stream& s) { + // Copy if needed + std::vector copies; + auto [transpose_a, lda, a] = check_transpose(copies, s, a_, false); + auto [transpose_b, ldb, b] = check_transpose(copies, s, b_, false); + d.add_temporaries(std::move(copies), s.index); // Determine dispatch kernel int bm = 64, bn = 64, bk = 16; int wm = 2, wn = 2; + size_t batch_size_out = out.size() / M / N; + int batch_ndim = out.ndim() - 2; + int batch_ndim_a = a.ndim() - 2; + int batch_ndim_b = b.ndim() - 2; char devc = d.get_architecture().back(); GEMM_TPARAM_MACRO(devc) - // Prepare kernel name - std::ostringstream kname; - kname << "steel_gemm_fused_" << (transpose_a ? 't' : 'n') - << (transpose_b ? 
't' : 'n') << "_" << type_to_name(a) << "_" - << type_to_name(out) << "_bm" << bm << "_bn" << bn << "_bk" << bk - << "_wm" << wm << "_wn" << wn; - - std::string base_name = kname.str(); - const bool has_batch = batch_ndim > 1; - const bool use_out_source = false; - const bool do_axpby = false; const bool align_M = (M % bm) == 0; const bool align_N = (N % bn) == 0; const bool align_K = (K % bk) == 0; - const bool do_gather = true; + + // Define the kernel name + std::string base_name; + base_name.reserve(128); + concatenate( + base_name, + "steel_gather_mm_", + transpose_a ? 't' : 'n', + transpose_b ? 't' : 'n', + "_", + type_to_name(a), + "_", + type_to_name(out), + "_bm", + bm, + "_bn", + bn, + "_bk", + bk, + "_wm", + wm, + "_wn", + wn); metal::MTLFCList func_consts = { {&has_batch, MTL::DataType::DataTypeBool, 10}, - {&use_out_source, MTL::DataType::DataTypeBool, 100}, - {&do_axpby, MTL::DataType::DataTypeBool, 110}, {&align_M, MTL::DataType::DataTypeBool, 200}, {&align_N, MTL::DataType::DataTypeBool, 201}, {&align_K, MTL::DataType::DataTypeBool, 202}, - {&do_gather, MTL::DataType::DataTypeBool, 300}, }; - // clang-format off - kname << "_has_batch_" << (has_batch ? 't' : 'n') - << "_use_out_source_" << (use_out_source ? 't' : 'n') - << "_do_axpby_" << (do_axpby ? 't' : 'n') - << "_align_M_" << (align_M ? 't' : 'n') - << "_align_N_" << (align_N ? 't' : 'n') - << "_align_K_" << (align_K ? 't' : 'n') - << "_do_gather_" << (do_gather ? 't' : 'n'); // clang-format on + // And the kernel hash that includes the function constants + std::string hash_name; + hash_name.reserve(128); + concatenate( + hash_name, + base_name, + "_has_batch_", + has_batch ? 't' : 'n', + "_align_M_", + align_M ? 't' : 'n', + "_align_N_", + align_N ? 't' : 'n', + "_align_K_", + align_K ? 't' : 'n'); - std::string hash_name = kname.str(); - - // Encode and dispatch kernel + // Get and set the kernel auto& compute_encoder = d.get_command_encoder(s.index); - auto kernel = get_steel_gemm_fused_kernel( + auto kernel = get_steel_gemm_gather_kernel( d, base_name, hash_name, @@ -1736,72 +1842,96 @@ void GatherMM::eval_gpu(const std::vector& inputs, array& out) { bn, bk, wm, - wn); - + wn, + false); compute_encoder.set_compute_pipeline_state(kernel); - // Use problem size to determine threadblock swizzle - int tn = (N + bn - 1) / bn; - int tm = (M + bm - 1) / bm; - - // TODO: Explore device-based tuning for swizzle - int swizzle_log = 0; // tm >= 6 ? 3 : (tm <= 3 ? 0 : 2); - - // Prepare steel matmul params - GEMMParams params{ + // Prepare the matmul params + steel::GEMMParams params{ /* const int M = */ M, /* const int N = */ N, /* const int K = */ K, - /* const int lda = */ lda, - /* const int ldb = */ ldb, + /* const int lda = */ static_cast(lda), + /* const int ldb = */ static_cast(ldb), /* const int ldd = */ N, - /* const int tiles_n = */ tn, - /* const int tiles_m = */ tm, - /* const int64_t batch_stride_a = */ lhs_indices_str, - /* const int64_t batch_stride_b = */ rhs_indices_str, - /* const int64_t batch_stride_d = */ matrix_stride_out, - /* const int swizzle_log = */ swizzle_log, + /* const int tiles_n = */ (N + bn - 1) / bn, + /* const int tiles_m = */ (M + bm - 1) / bm, + /* const int64_t batch_stride_a = */ + (batch_ndim > 0) ? lhs_indices.strides()[0] : 0, + /* const int64_t batch_stride_b = */ + (batch_ndim > 0) ? 
rhs_indices.strides()[0] : 0, + /* const int64_t batch_stride_d = */ M * N, + /* const int swizzle_log = */ 0, /* const int gemm_k_iterations_aligned = */ (K / bk), /* const int batch_ndim = */ batch_ndim}; - // Prepare launch grid params - int tile = 1 << swizzle_log; - tm = (tm + tile - 1) / tile; - tn = tn * tile; - + // Prepare the grid MTL::Size group_dims = MTL::Size(32, wn, wm); - MTL::Size grid_dims = MTL::Size(tn, tm, batch_size_out); + MTL::Size grid_dims = + MTL::Size(params.tiles_n, params.tiles_m, batch_size_out); // Launch kernel compute_encoder.set_input_array(a, 0); compute_encoder.set_input_array(b, 1); - compute_encoder.set_output_array(out, 3); - - compute_encoder.set_bytes(params, 4); - - compute_encoder.set_vector_bytes(batch_shape, 6); - compute_encoder.set_vector_bytes(batch_strides, 7); - - compute_encoder.set_input_array(lhs_indices, 10); - compute_encoder.set_input_array(rhs_indices, 11); - - std::vector operand_shape = batch_shape_A; - operand_shape.insert( - operand_shape.end(), batch_shape_B.begin(), batch_shape_B.end()); - - std::vector operand_strides = batch_strides_A; - operand_strides.insert( - operand_strides.end(), batch_strides_B.begin(), batch_strides_B.end()); - - operand_batch_ndim.push_back(0); - - compute_encoder.set_vector_bytes(operand_shape, 13); - compute_encoder.set_vector_bytes(operand_strides, 14); - compute_encoder.set_vector_bytes(operand_batch_ndim, 15); - + compute_encoder.set_input_array(lhs_indices, 2); + compute_encoder.set_input_array(rhs_indices, 3); + compute_encoder.set_output_array(out, 4); + compute_encoder.set_bytes(params, 5); + compute_encoder.set_vector_bytes(lhs_indices.shape(), 6); + compute_encoder.set_vector_bytes(lhs_indices.strides(), 7); + compute_encoder.set_vector_bytes(rhs_indices.strides(), 8); + compute_encoder.set_bytes(batch_ndim_a, 9); + compute_encoder.set_vector_bytes(a.shape(), 10); + compute_encoder.set_vector_bytes(a.strides(), 11); + compute_encoder.set_bytes(batch_ndim_b, 12); + compute_encoder.set_vector_bytes(b.shape(), 13); + compute_encoder.set_vector_bytes(b.strides(), 14); compute_encoder.dispatch_threadgroups(grid_dims, group_dims); +} - d.add_temporaries(std::move(copies), s.index); +void GatherMM::eval_gpu(const std::vector& inputs, array& out) { + auto& s = stream(); + auto& d = metal::device(s.device); + + auto& a = inputs[0]; + auto& b = inputs[1]; + auto& lhs_indices = inputs[2]; + auto& rhs_indices = inputs[3]; + + // Return 0s if either input is empty + if (a.size() == 0 || b.size() == 0) { + array zero = array(0, a.dtype()); + fill_gpu(zero, out, s); + d.add_temporary(std::move(zero), s.index); + return; + } + + out.set_data(allocator::malloc(out.nbytes())); + + // Extract shapes from inputs. + int M = a.shape(-2); + int N = b.shape(-1); + int K = a.shape(-1); + + // We are walking a in order and b is also in order so we can batch up the + // matmuls and reuse reading a and b. 
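
The routing below can be summarized as a small decision function keyed on M, N, and whether the rhs indices are known to be sorted (right_sorted_ is assumed here to mirror the sorted_indices flag of the op). A sketch with illustrative names:

// Sketch of the dispatch choices made below (illustrative, not the MLX API).
#include <iostream>
#include <string>

std::string gather_mm_route(int M, int N, bool rhs_sorted) {
  if (M == 1 && rhs_sorted) {
    return "gather_mm_rhs"; // steel_gather_mm_rhs: rows batched per expert
  }
  if (M == 1 || N == 1) {
    return "gather_mv";     // gemv_gather / gemv_t_gather specialization
  }
  return "gather_mm";       // general steel_gather_mm kernel
}

int main() {
  std::cout << gather_mm_route(1, 1024, true) << "\n";  // gather_mm_rhs
  std::cout << gather_mm_route(1, 1024, false) << "\n"; // gather_mv
  std::cout << gather_mm_route(64, 1, false) << "\n";   // gather_mv
  std::cout << gather_mm_route(64, 64, false) << "\n";  // gather_mm
  return 0;
}
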
+ if (M == 1 && right_sorted_ == true) { + gather_mm_rhs(a, b, rhs_indices, out, d, s); + return; + } + + // Route to gather gemv if any of a or b are vectors + if (M == 1) { + gather_mv(b, a, rhs_indices, lhs_indices, out, N, K, false, d, s); + return; + } + if (N == 1) { + gather_mv(a, b, lhs_indices, rhs_indices, out, M, K, true, d, s); + return; + } + + // Route to non specialized gather mm + gather_mm(a, b, lhs_indices, rhs_indices, out, M, N, K, d, s); } } // namespace mlx::core diff --git a/mlx/backend/metal/nojit_kernels.cpp b/mlx/backend/metal/nojit_kernels.cpp index 2d6077ed1..8da147971 100644 --- a/mlx/backend/metal/nojit_kernels.cpp +++ b/mlx/backend/metal/nojit_kernels.cpp @@ -193,6 +193,23 @@ MTL::ComputePipelineState* get_steel_gemm_masked_kernel( return d.get_kernel(kernel_name); } +MTL::ComputePipelineState* get_steel_gemm_gather_kernel( + metal::Device& d, + const std::string& kernel_name, + const std::string& hash_name, + const metal::MTLFCList& func_consts, + const array&, + bool, + bool, + int, + int, + int, + int, + int, + bool) { + return d.get_kernel(kernel_name, "mlx", hash_name, func_consts); +} + MTL::ComputePipelineState* get_gemv_masked_kernel( metal::Device& d, const std::string& kernel_name, @@ -252,4 +269,21 @@ MTL::ComputePipelineState* get_quantized_kernel( return d.get_kernel(kernel_name); } +MTL::ComputePipelineState* get_gather_qmm_kernel( + metal::Device& d, + const std::string& kernel_name, + const std::string& hash_name, + const metal::MTLFCList& func_consts, + const array&, + int, + int, + int, + int, + int, + int, + int, + bool) { + return d.get_kernel(kernel_name, "mlx", hash_name, func_consts); +} + } // namespace mlx::core diff --git a/mlx/backend/metal/quantized.cpp b/mlx/backend/metal/quantized.cpp index 8d1d176c4..6f5807543 100644 --- a/mlx/backend/metal/quantized.cpp +++ b/mlx/backend/metal/quantized.cpp @@ -2,6 +2,7 @@ #include +#include "mlx/backend/common/broadcasting.h" #include "mlx/backend/common/compiled.h" #include "mlx/backend/metal/copy.h" #include "mlx/backend/metal/device.h" @@ -14,93 +15,168 @@ namespace mlx::core { -void launch_qmm( - std::string name, - const std::vector& inputs, +namespace { + +inline array +ensure_row_contiguous(const array& x, metal::Device& d, const Stream& s) { + if (!x.flags().row_contiguous) { + array x_copy(x.shape(), x.dtype(), nullptr, {}); + copy_gpu(x, x_copy, CopyType::General, s); + d.add_temporary(x_copy, s.index); + return x_copy; + } else { + return x; + } +} + +inline array ensure_row_contiguous_matrix( + const array& x, + metal::Device& d, + const Stream& s) { + auto stride_0 = x.strides()[x.ndim() - 2]; + auto stride_1 = x.strides()[x.ndim() - 1]; + if (stride_0 == x.shape(-1) && stride_1 == 1) { + return x; + } else { + array x_copy(x.shape(), x.dtype(), nullptr, {}); + copy_gpu(x, x_copy, CopyType::General, s); + d.add_temporary(x_copy, s.index); + return x_copy; + } +} + +inline int get_qmv_batch_limit(int D, int O, metal::Device& d) { + auto arch = d.get_architecture(); + auto arch_size = arch.back(); + auto arch_gen = arch.substr(arch.size() - 3, 2); + if (arch_gen == "13" || arch_gen == "14") { + switch (arch_size) { + case 'd': + if (D <= 2048 && O <= 2048) { + return 32; + } else if (D <= 4096 && O <= 4096) { + return 18; + } else { + return 12; + } + default: + if (D <= 2048 && O <= 2048) { + return 14; + } else if (D <= 4096 && O <= 4096) { + return 10; + } else { + return 6; + } + } + } else { + switch (arch_size) { + case 'd': + if (D <= 2048 && O <= 2048) { + return 32; + } 
else if (D <= 4096 && O <= 4096) { + return 18; + } else { + return 12; + } + default: + if (D <= 2048 && O <= 2048) { + return 18; + } else if (D <= 4096 && O <= 4096) { + return 12; + } else { + return 10; + } + } + } +} + +inline int add_strides_and_shapes( + CommandEncoder& compute_encoder, + bool skip, + const array& x, + const array& w, + const array& scales, + const array& biases, + int offset) { + if (skip) { + return 0; + } + + // TODO: Collapse batch dimensions + + int x_batch_ndims = x.ndim() - 2; + int w_batch_ndims = w.ndim() - 2; + compute_encoder.set_bytes(x_batch_ndims, offset); + compute_encoder.set_vector_bytes(x.shape(), offset + 1); + compute_encoder.set_vector_bytes(x.strides(), offset + 2); + compute_encoder.set_bytes(w_batch_ndims, offset + 3); + compute_encoder.set_vector_bytes(w.shape(), offset + 4); + compute_encoder.set_vector_bytes(w.strides(), offset + 5); + compute_encoder.set_vector_bytes(scales.strides(), offset + 6); + compute_encoder.set_vector_bytes(biases.strides(), offset + 7); + + return 8; +} + +inline int add_gather_strides_and_shapes( + CommandEncoder& compute_encoder, + const array& lhs_indices, + const array& rhs_indices, + int offset) { + auto [shape, strides] = collapse_contiguous_dims( + lhs_indices.shape(), {lhs_indices.strides(), rhs_indices.strides()}); + int ndims = shape.size(); + + compute_encoder.set_bytes(ndims, offset); + compute_encoder.set_vector_bytes(shape, offset + 1); + compute_encoder.set_vector_bytes(strides[0], offset + 2); + compute_encoder.set_vector_bytes(strides[1], offset + 3); + + return 4; +} + +} // namespace + +void qmv_quad( + const array& x, + const array& w, + const array& scales, + const array& biases, array& out, int group_size, int bits, - int D, - int O, - int B, + int M, int N, - MTL::Size& group_dims, - MTL::Size& grid_dims, - bool batched, - bool matrix, - bool gather, - bool aligned, - bool quad, + int K, + metal::Device& d, const Stream& s) { - auto& x_pre = inputs[0]; - auto& w_pre = inputs[1]; - auto& scales_pre = inputs[2]; - auto& biases_pre = inputs[3]; + int B = out.size() / M / N; - // Ensure that the last two dims are row contiguous. - // TODO: Check if we really need this for x as well... 
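
The limit returned by get_qmv_batch_limit above is what later decides, in QuantizedMatmul::eval_gpu, whether a quantized matmul goes to the matrix (qmm) kernel or to one of the vector kernels. A sketch of that routing (illustrative helper; the power-of-two bits check for qmv_quad is omitted):

// Sketch of how the batch limit is consumed when routing a quantized matmul.
#include <iostream>
#include <string>

std::string qmm_route(int M, int N, int K, bool transpose, int qmv_batch_limit) {
  // Transposed weights use the hardware/shape dependent limit above;
  // non-transposed weights use a small fixed limit.
  int vector_limit = transpose ? qmv_batch_limit : 4;
  if (M >= vector_limit) {
    return "qmm";           // matrix-matrix kernel
  }
  if (transpose && (K == 128 || K == 64)) {
    return "qmv_quad";      // small inner dimension
  }
  if (transpose) {
    return "qmv";           // matrix-vector, transposed weights
  }
  if (K < 1024) {
    return "qvm";           // vector-matrix
  }
  return "qvm_split_k";     // large K: split-K for more parallelism
}

int main() {
  std::cout << qmm_route(1, 4096, 4096, true, 10) << "\n";  // qmv
  std::cout << qmm_route(32, 4096, 4096, true, 10) << "\n"; // qmm
  std::cout << qmm_route(1, 4096, 8192, false, 10) << "\n"; // qvm_split_k
  return 0;
}
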
- std::vector copies; - auto ensure_row_contiguous_last_dims = [&copies, &s](const array& arr) { - auto stride_0 = arr.strides()[arr.ndim() - 2]; - auto stride_1 = arr.strides()[arr.ndim() - 1]; - if (stride_0 == arr.shape(-1) && stride_1 == 1) { - return arr; - } else { - array arr_copy(arr.shape(), arr.dtype(), nullptr, {}); - copy_gpu(arr, arr_copy, CopyType::General, s); - copies.push_back(arr_copy); - return arr_copy; - } - }; - auto x = ensure_row_contiguous_last_dims(x_pre); - auto w = ensure_row_contiguous_last_dims(w_pre); - auto scales = ensure_row_contiguous_last_dims(scales_pre); - auto biases = ensure_row_contiguous_last_dims(biases_pre); + constexpr int quads_per_simd = 8; + constexpr int results_per_quadgroup = 8; + int bn = quads_per_simd * results_per_quadgroup; + int simdgroup_size = 32; + MTL::Size group_dims(simdgroup_size, 1, 1); + MTL::Size grid_dims(M, (N + bn - 1) / bn, B); - int x_batch_ndims = x.ndim() - 2; - auto& x_shape = x.shape(); - auto& x_strides = x.strides(); - int w_batch_ndims = w.ndim() - 2; - auto& w_shape = w.shape(); - auto& w_strides = w.strides(); - auto& s_strides = scales.strides(); - auto& b_strides = biases.strides(); + std::string kname; + kname.reserve(64); + std::string type_string = get_type_string(x.dtype()); + concatenate( + kname, + "qmv_quad_", + type_string, + "_gs_", + group_size, + "_b_", + bits, + "_d_", + K, + B > 1 ? "_batch_1" : "_batch_0"); + auto template_def = get_template_definition( + kname, "qmv_quad", type_string, group_size, bits, K, B > 1); - std::string aligned_n = (O % 32) == 0 ? "true" : "false"; - - std::ostringstream kname; - auto type_string = get_type_string(x.dtype()); - kname << name << "_" << type_string << "_gs_" << group_size << "_b_" << bits; - if (quad) { - kname << "_d_" << D; - } - if (aligned) { - kname << "_alN_" << aligned_n; - } - if (!gather) { - kname << "_batch_" << batched; - } - - // Encode and dispatch kernel - std::string template_def; - if (quad) { - template_def = get_template_definition( - kname.str(), name, type_string, group_size, bits, D, batched); - } else if (aligned && !gather) { - template_def = get_template_definition( - kname.str(), name, type_string, group_size, bits, aligned_n, batched); - } else if (!gather && !aligned) { - template_def = get_template_definition( - kname.str(), name, type_string, group_size, bits, batched); - } else if (aligned && gather) { - template_def = get_template_definition( - kname.str(), name, type_string, group_size, bits, aligned_n); - } else { - template_def = get_template_definition( - kname.str(), name, type_string, group_size, bits); - } - auto& d = metal::device(s.device); - auto kernel = get_quantized_kernel(d, kname.str(), template_def); + auto kernel = get_quantized_kernel(d, kname, template_def); auto& compute_encoder = d.get_command_encoder(s.index); compute_encoder.set_compute_pipeline_state(kernel); @@ -109,90 +185,87 @@ void launch_qmm( compute_encoder.set_input_array(biases, 2); compute_encoder.set_input_array(x, 3); compute_encoder.set_output_array(out, 4); - compute_encoder.set_bytes(D, 5); - compute_encoder.set_bytes(O, 6); - - int offset = 7; - if (matrix) { - compute_encoder.set_bytes(B, 7); - offset += 1; - } - - if (batched || gather) { - compute_encoder.set_bytes(x_batch_ndims, offset); - compute_encoder.set_vector_bytes(x_shape, offset + 1); - compute_encoder.set_vector_bytes(x_strides, offset + 2); - compute_encoder.set_bytes(w_batch_ndims, offset + 3); - compute_encoder.set_vector_bytes(w_shape, offset + 4); - 
compute_encoder.set_vector_bytes(w_strides, offset + 5); - compute_encoder.set_vector_bytes(s_strides, offset + 6); - compute_encoder.set_vector_bytes(b_strides, offset + 7); - } - if (gather) { - auto& lhs_indices = inputs[4]; - auto& rhs_indices = inputs[5]; - - // TODO: collapse batch dims - auto& batch_shape = lhs_indices.shape(); - int batch_ndims = batch_shape.size(); - auto& lhs_strides = lhs_indices.strides(); - auto& rhs_strides = rhs_indices.strides(); - - compute_encoder.set_bytes(batch_ndims, offset + 8); - compute_encoder.set_vector_bytes(batch_shape, offset + 9); - compute_encoder.set_input_array(lhs_indices, offset + 10); - compute_encoder.set_input_array(rhs_indices, offset + 11); - compute_encoder.set_vector_bytes(lhs_strides, offset + 12); - compute_encoder.set_vector_bytes(rhs_strides, offset + 13); - } + compute_encoder.set_bytes(K, 5); + compute_encoder.set_bytes(N, 6); + add_strides_and_shapes(compute_encoder, B <= 1, x, w, scales, biases, 7); compute_encoder.dispatch_threadgroups(grid_dims, group_dims); - d.add_temporaries(std::move(copies), s.index); } -void qvm_split_k( - const std::vector& inputs, +void qmv( + const array& x, + const array& w, + const array& scales, + const array& biases, array& out, int group_size, int bits, - int D, - int O, - int B, + int M, int N, + int K, + metal::Device& d, const Stream& s) { - int split_k = D > 8192 ? 32 : 8; - int split_D = (D + split_k - 1) / split_k; - N *= split_k; + int B = out.size() / M / N; - int bo = 64; - int bd = 32; - MTL::Size group_dims = MTL::Size(bd, 2, 1); - MTL::Size grid_dims = MTL::Size(B, O / bo, N); + int bn = 8; + int bk = 32; + MTL::Size group_dims(bk, 2, 1); + MTL::Size grid_dims(M, (N + bn - 1) / bn, B); - auto& x_pre = inputs[0]; - auto& w_pre = inputs[1]; - auto& scales_pre = inputs[2]; - auto& biases_pre = inputs[3]; + std::string kname; + kname.reserve(64); + std::string type_string = get_type_string(x.dtype()); + bool fast = N % bn == 0 && K % 512 == 0; + concatenate( + kname, + fast ? "qmv_fast_" : "qmv_", + type_string, + "_gs_", + group_size, + "_b_", + bits, + B > 1 ? "_batch_1" : "_batch_0"); + auto template_def = get_template_definition( + kname, fast ? "qmv_fast" : "qmv", type_string, group_size, bits, B > 1); - // Ensure that the last two dims are row contiguous. - // TODO: Check if we really need this for x as well... 
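
qvm_split_k trades one long K reduction for split_k shorter ones: each slice of size split_D writes its partial result into an intermediate array with an extra split axis, and a sum reduction over that axis produces the output, which buys parallelism when the output is small but K is large. A reference sketch of the decomposition in plain C++ (quantization ignored):

// Reference sketch of split-K: y = x @ W accumulated over split_k K-slices.
#include <algorithm>
#include <cassert>
#include <cmath>
#include <iostream>
#include <vector>

int main() {
  const int K = 8, N = 2, split_k = 4;
  const int split_D = (K + split_k - 1) / split_k; // 2 elements of K per slice
  std::vector<float> x(K), w(K * N); // w stored K-major: w[k * N + n]
  for (int k = 0; k < K; ++k) x[k] = float(k + 1);
  for (int i = 0; i < K * N; ++i) w[i] = 0.5f * float(i);

  // Each entry along the split axis holds one K-slice's partial sums.
  std::vector<float> partial(split_k * N, 0.0f);
  for (int s = 0; s < split_k; ++s) {
    for (int k = s * split_D; k < std::min(K, (s + 1) * split_D); ++k) {
      for (int n = 0; n < N; ++n) {
        partial[s * N + n] += x[k] * w[k * N + n];
      }
    }
  }

  // The final output is a sum reduction over the split axis.
  std::vector<float> y(N, 0.0f);
  for (int s = 0; s < split_k; ++s)
    for (int n = 0; n < N; ++n) y[n] += partial[s * N + n];

  // Same result as the unsplit matvec.
  for (int n = 0; n < N; ++n) {
    float ref = 0.0f;
    for (int k = 0; k < K; ++k) ref += x[k] * w[k * N + n];
    assert(std::abs(ref - y[n]) < 1e-5f);
    std::cout << y[n] << (n + 1 == N ? '\n' : ' ');
  }
  return 0;
}
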
- std::vector copies; - auto ensure_row_contiguous_last_dims = [&copies, &s](const array& arr) { - auto stride_0 = arr.strides()[arr.ndim() - 2]; - auto stride_1 = arr.strides()[arr.ndim() - 1]; - if (stride_0 == arr.shape(-1) && stride_1 == 1) { - return arr; - } else { - array arr_copy(arr.shape(), arr.dtype(), nullptr, {}); - copy_gpu(arr, arr_copy, CopyType::General, s); - copies.push_back(arr_copy); - return arr_copy; - } - }; - auto x = ensure_row_contiguous_last_dims(x_pre); - auto w = ensure_row_contiguous_last_dims(w_pre); - auto scales = ensure_row_contiguous_last_dims(scales_pre); - auto biases = ensure_row_contiguous_last_dims(biases_pre); + auto kernel = get_quantized_kernel(d, kname, template_def); + auto& compute_encoder = d.get_command_encoder(s.index); + compute_encoder.set_compute_pipeline_state(kernel); + + compute_encoder.set_input_array(w, 0); + compute_encoder.set_input_array(scales, 1); + compute_encoder.set_input_array(biases, 2); + compute_encoder.set_input_array(x, 3); + compute_encoder.set_output_array(out, 4); + compute_encoder.set_bytes(K, 5); + compute_encoder.set_bytes(N, 6); + add_strides_and_shapes(compute_encoder, B <= 1, x, w, scales, biases, 7); + + compute_encoder.dispatch_threadgroups(grid_dims, group_dims); +} + +void qvm_split_k( + const array& x, + const array& w, + const array& scales, + const array& biases, + array& out, + int group_size, + int bits, + int M, + int N, + int K, + metal::Device& d, + const Stream& s) { + int split_k = K > 8192 ? 32 : 8; + int split_D = (K + split_k - 1) / split_k; + int B = out.size() / M / N; + B *= split_k; + + int bn = 64; + int bk = 32; + MTL::Size group_dims = MTL::Size(bk, 2, 1); + MTL::Size grid_dims = MTL::Size(M, N / bn, B); int x_batch_ndims = x.ndim() - 2; auto x_shape = x.shape(); @@ -217,9 +290,7 @@ void qvm_split_k( s_strides.insert(s_strides.end() - 2, split_D * scales.shape(-1)); b_strides.insert(b_strides.end() - 2, split_D * biases.shape(-1)); - int final_block_size = D - (split_k - 1) * split_D; - - auto& d = metal::device(s.device); + int final_block_size = K - (split_k - 1) * split_D; auto temp_shape = out.shape(); temp_shape.insert(temp_shape.end() - 2, split_k); @@ -227,15 +298,24 @@ void qvm_split_k( intermediate.set_data(allocator::malloc(intermediate.nbytes())); d.add_temporary(intermediate, s.index); - std::ostringstream kname; - auto type_string = get_type_string(x.dtype()); - kname << "qvm_split_k" << "_" << type_string << "_gs_" << group_size << "_b_" - << bits << "_spk_" << split_k; + std::string type_string = get_type_string(x.dtype()); + std::string kname; + kname.reserve(64); + concatenate( + kname, + "qvm_split_k_", + type_string, + "_gs_", + group_size, + "_b_", + bits, + "_spk_", + split_k); auto template_def = get_template_definition( - kname.str(), "qvm_split_k", type_string, group_size, bits, split_k); + kname, "qvm_split_k", type_string, group_size, bits, split_k); // Encode and dispatch kernel - auto kernel = get_quantized_kernel(d, kname.str(), template_def); + auto kernel = get_quantized_kernel(d, kname, template_def); auto& compute_encoder = d.get_command_encoder(s.index); compute_encoder.set_compute_pipeline_state(kernel); @@ -245,7 +325,7 @@ void qvm_split_k( compute_encoder.set_input_array(x, 3); compute_encoder.set_output_array(intermediate, 4); compute_encoder.set_bytes(split_D, 5); - compute_encoder.set_bytes(O, 6); + compute_encoder.set_bytes(N, 6); compute_encoder.set_bytes(x_batch_ndims, 7); compute_encoder.set_vector_bytes(x_shape, 8); @@ -258,7 +338,6 @@ void 
qvm_split_k( compute_encoder.set_bytes(final_block_size, 15); compute_encoder.dispatch_threadgroups(grid_dims, group_dims); - d.add_temporaries(std::move(copies), s.index); int axis = intermediate.ndim() - 3; ReductionPlan plan( @@ -269,170 +348,589 @@ void qvm_split_k( intermediate, out, "sum", plan, {axis}, compute_encoder, d, s); } -void qmm_op( - const std::vector& inputs, +void qvm( + const array& x, + const array& w, + const array& scales, + const array& biases, + array& out, + int group_size, + int bits, + int M, + int N, + int K, + metal::Device& d, + const Stream& s) { + int B = out.size() / M / N; + + int bn = 64; + int bk = 32; + MTL::Size group_dims(bk, 2, 1); + MTL::Size grid_dims(M, (N + bn - 1) / bn, B); + + std::string kname; + kname.reserve(64); + std::string type_string = get_type_string(x.dtype()); + concatenate( + kname, + "qvm_", + type_string, + "_gs_", + group_size, + "_b_", + bits, + B > 1 ? "_batch_1" : "_batch_0"); + auto template_def = get_template_definition( + kname, "qvm", type_string, group_size, bits, B > 1); + + auto kernel = get_quantized_kernel(d, kname, template_def); + auto& compute_encoder = d.get_command_encoder(s.index); + compute_encoder.set_compute_pipeline_state(kernel); + + compute_encoder.set_input_array(w, 0); + compute_encoder.set_input_array(scales, 1); + compute_encoder.set_input_array(biases, 2); + compute_encoder.set_input_array(x, 3); + compute_encoder.set_output_array(out, 4); + compute_encoder.set_bytes(K, 5); + compute_encoder.set_bytes(N, 6); + add_strides_and_shapes(compute_encoder, B <= 1, x, w, scales, biases, 7); + + compute_encoder.dispatch_threadgroups(grid_dims, group_dims); +} + +void qmm( + const array& x, + const array& w, + const array& scales, + const array& biases, array& out, bool transpose, int group_size, int bits, - bool gather, + int M, + int N, + int K, + metal::Device& d, const Stream& s) { - out.set_data(allocator::malloc(out.nbytes())); + int B = out.size() / M / N; - MTL::Size group_dims; - MTL::Size grid_dims; + int wm = 2; + int wn = 2; + int bm = 32; + int bn = 32; + MTL::Size group_dims(32, wn, wm); + MTL::Size grid_dims((N + bn - 1) / bn, (M + bm - 1) / bm, B); - auto& x = inputs[0]; - auto& w = inputs[1]; - bool batched = !gather && (w.ndim() > 2 || !x.flags().row_contiguous); + std::string kname; + kname.reserve(64); + bool aligned = N % 32 == 0; + bool batched = B > 1; + std::string type_string = get_type_string(x.dtype()); + concatenate( + kname, + transpose ? "qmm_t_" : "qmm_n_", + type_string, + "_gs_", + group_size, + "_b_", + bits, + transpose ? (aligned ? "_alN_true" : "_alN_false") : "", + batched ? "_batch_1" : "_batch_0"); + std::string template_def; + if (transpose) { + template_def = get_template_definition( + kname, "qmm_t", type_string, group_size, bits, aligned, batched); + } else { + template_def = get_template_definition( + kname, "qmm_n", type_string, group_size, bits, batched); + } - int D = x.shape(-1); - int O = out.shape(-1); - // For the unbatched W case, avoid `adjust_matrix_offsets` - // for a small performance gain. - int B = (batched || gather) ? x.shape(-2) : x.size() / D; - int N = (batched || gather) ? out.size() / B / O : 1; + auto kernel = get_quantized_kernel(d, kname, template_def); + auto& compute_encoder = d.get_command_encoder(s.index); + compute_encoder.set_compute_pipeline_state(kernel); - std::string name = gather ? 
"bs_" : ""; - bool matrix = false; - bool aligned = false; - bool quad = false; + compute_encoder.set_input_array(w, 0); + compute_encoder.set_input_array(scales, 1); + compute_encoder.set_input_array(biases, 2); + compute_encoder.set_input_array(x, 3); + compute_encoder.set_output_array(out, 4); + compute_encoder.set_bytes(K, 5); + compute_encoder.set_bytes(N, 6); + compute_encoder.set_bytes(M, 7); + add_strides_and_shapes(compute_encoder, B <= 1, x, w, scales, biases, 8); - auto get_qmv_batch_limit = [s](int D, int O) { - auto arch = metal::device(s.device).get_architecture(); - auto arch_size = arch.back(); - auto arch_gen = arch.substr(arch.size() - 3, 2); - if (arch_gen == "13" || arch_gen == "14") { - switch (arch_size) { - case 'd': - if (D <= 2048 && O <= 2048) { - return 32; - } else if (D <= 4096 && O <= 4096) { - return 18; - } else { - return 12; - } - default: - if (D <= 2048 && O <= 2048) { - return 14; - } else if (D <= 4096 && O <= 4096) { - return 10; - } else { - return 6; - } - } - } else { - switch (arch_size) { - case 'd': - if (D <= 2048 && O <= 2048) { - return 32; - } else if (D <= 4096 && O <= 4096) { - return 18; - } else { - return 12; - } - default: - if (D <= 2048 && O <= 2048) { - return 18; - } else if (D <= 4096 && O <= 4096) { - return 12; - } else { - return 10; - } - } + compute_encoder.dispatch_threadgroups(grid_dims, group_dims); +} + +void gather_qmm( + const array& x, + const array& w, + const array& scales, + const array& biases, + const array& lhs_indices, + const array& rhs_indices, + array& out, + bool transpose, + int group_size, + int bits, + int M, + int N, + int K, + metal::Device& d, + const Stream& s) { + int B = out.size() / M / N; + + int wm = 2; + int wn = 2; + int bm = 32; + int bn = 32; + MTL::Size group_dims(32, wn, wm); + MTL::Size grid_dims((N + bn - 1) / bn, (M + bm - 1) / bm, B); + + std::string kname; + kname.reserve(64); + bool aligned = N % 32 == 0; + bool batched = B > 1; + std::string type_string = get_type_string(x.dtype()); + concatenate( + kname, + transpose ? "gather_qmm_t_" : "gather_qmm_n_", + type_string, + "_gs_", + group_size, + "_b_", + bits, + transpose ? (aligned ? 
"_alN_true" : "_alN_false") : ""); + std::string template_def; + if (transpose) { + template_def = get_template_definition( + kname, "gather_qmm_t", type_string, group_size, bits, aligned); + } else { + template_def = get_template_definition( + kname, "gather_qmm_n", type_string, group_size, bits); + } + + auto kernel = get_quantized_kernel(d, kname, template_def); + auto& compute_encoder = d.get_command_encoder(s.index); + compute_encoder.set_compute_pipeline_state(kernel); + + compute_encoder.set_input_array(w, 0); + compute_encoder.set_input_array(scales, 1); + compute_encoder.set_input_array(biases, 2); + compute_encoder.set_input_array(x, 3); + compute_encoder.set_input_array(lhs_indices, 4); + compute_encoder.set_input_array(rhs_indices, 5); + compute_encoder.set_output_array(out, 6); + compute_encoder.set_bytes(K, 7); + compute_encoder.set_bytes(N, 8); + compute_encoder.set_bytes(M, 9); + int n = + add_strides_and_shapes(compute_encoder, false, x, w, scales, biases, 10); + add_gather_strides_and_shapes( + compute_encoder, lhs_indices, rhs_indices, 10 + n); + + compute_encoder.dispatch_threadgroups(grid_dims, group_dims); +} + +void gather_qmv( + const array& x, + const array& w, + const array& scales, + const array& biases, + const array& lhs_indices, + const array& rhs_indices, + array& out, + int group_size, + int bits, + int M, + int N, + int K, + metal::Device& d, + const Stream& s) { + int B = out.size() / M / N; + + int bn = 8; + int bk = 32; + MTL::Size group_dims(bk, 2, 1); + MTL::Size grid_dims(M, (N + bn - 1) / bn, B); + + std::string kname; + kname.reserve(64); + std::string type_string = get_type_string(x.dtype()); + bool fast = N % bn == 0 && K % 512 == 0; + concatenate( + kname, + fast ? "gather_qmv_fast_" : "gather_qmv_", + type_string, + "_gs_", + group_size, + "_b_", + bits); + auto template_def = get_template_definition( + kname, + fast ? 
"gather_qmv_fast" : "gather_qmv", + type_string, + group_size, + bits); + + auto kernel = get_quantized_kernel(d, kname, template_def); + auto& compute_encoder = d.get_command_encoder(s.index); + compute_encoder.set_compute_pipeline_state(kernel); + + compute_encoder.set_input_array(w, 0); + compute_encoder.set_input_array(scales, 1); + compute_encoder.set_input_array(biases, 2); + compute_encoder.set_input_array(x, 3); + compute_encoder.set_input_array(lhs_indices, 4); + compute_encoder.set_input_array(rhs_indices, 5); + compute_encoder.set_output_array(out, 6); + compute_encoder.set_bytes(K, 7); + compute_encoder.set_bytes(N, 8); + int n = + add_strides_and_shapes(compute_encoder, false, x, w, scales, biases, 9); + add_gather_strides_and_shapes( + compute_encoder, lhs_indices, rhs_indices, 9 + n); + + compute_encoder.dispatch_threadgroups(grid_dims, group_dims); +} + +void gather_qvm( + const array& x, + const array& w, + const array& scales, + const array& biases, + const array& lhs_indices, + const array& rhs_indices, + array& out, + int group_size, + int bits, + int M, + int N, + int K, + metal::Device& d, + const Stream& s) { + int B = out.size() / M / N; + + int bn = 64; + int bk = 32; + MTL::Size group_dims(bk, 2, 1); + MTL::Size grid_dims(M, (N + bn - 1) / bn, B); + + std::string kname; + kname.reserve(64); + std::string type_string = get_type_string(x.dtype()); + concatenate( + kname, "gather_qvm_", type_string, "_gs_", group_size, "_b_", bits); + auto template_def = get_template_definition( + kname, "gather_qvm", type_string, group_size, bits); + + auto kernel = get_quantized_kernel(d, kname, template_def); + auto& compute_encoder = d.get_command_encoder(s.index); + compute_encoder.set_compute_pipeline_state(kernel); + + compute_encoder.set_input_array(w, 0); + compute_encoder.set_input_array(scales, 1); + compute_encoder.set_input_array(biases, 2); + compute_encoder.set_input_array(x, 3); + compute_encoder.set_input_array(lhs_indices, 4); + compute_encoder.set_input_array(rhs_indices, 5); + compute_encoder.set_output_array(out, 6); + compute_encoder.set_bytes(K, 7); + compute_encoder.set_bytes(N, 8); + int n = + add_strides_and_shapes(compute_encoder, false, x, w, scales, biases, 9); + add_gather_strides_and_shapes( + compute_encoder, lhs_indices, rhs_indices, 9 + n); + + compute_encoder.dispatch_threadgroups(grid_dims, group_dims); +} + +void gather_qmm_rhs( + const array& x_, + const array& w_, + const array& scales_, + const array& biases_, + const array& indices_, + array& out, + bool transpose, + int group_size, + int bits, + int M, + int N, + int K, + metal::Device& d, + const Stream& s) { + // Start by normalizing the indices + array indices = ensure_row_contiguous(indices_, d, s); + + // Broadcast x with indices. If we are here that means lhs_indices were not + // provided so the lhs_indices are implied to be the shape of x broadcasted + // with rhs_indices. We need only broadcast x and copy it as if applying the + // lhs_indices. 
+ auto broadcast_with_indices = [&d, &s, &indices](const array& x) { + if (x.size() / x.shape(-2) / x.shape(-1) == indices.size()) { + return ensure_row_contiguous(x, d, s); } + + auto x_shape = indices.shape(); + x_shape.push_back(x.shape(-2)); + x_shape.push_back(x.shape(-1)); + array new_x(std::move(x_shape), x.dtype(), nullptr, {}); + broadcast(x, new_x); + return ensure_row_contiguous(new_x, d, s); }; - if (transpose) { - auto qmv_batch_limit = get_qmv_batch_limit(D, O); - if (B < qmv_batch_limit && (D == 128 || D == 64) && is_power_of_2(bits)) { - name += "qmv_quad"; - constexpr int quads_per_simd = 8; - constexpr int results_per_quadgroup = 8; - int bo = quads_per_simd * results_per_quadgroup; - int simdgroup_size = 32; - group_dims = MTL::Size(simdgroup_size, 1, 1); - grid_dims = MTL::Size(B, (O + bo - 1) / bo, N); - quad = true; - } else if (B < qmv_batch_limit && O % 8 == 0 && D % 512 == 0 && D >= 512) { - name += "qmv_fast"; - int bo = 8; - int bd = 32; - group_dims = MTL::Size(bd, 2, 1); - grid_dims = MTL::Size(B, O / bo, N); - } else if (B < qmv_batch_limit) { - name += "qmv"; - int bo = 8; - int bd = 32; - group_dims = MTL::Size(bd, 2, 1); - grid_dims = MTL::Size(B, (O + bo - 1) / bo, N); - } else { - int wn = 2; - int wm = 2; - int bm = 32; - int bn = 32; - group_dims = MTL::Size(32, wn, wm); - grid_dims = MTL::Size((O + bn - 1) / bn, (B + bm - 1) / bm, N); - name += "qmm_t"; - matrix = true; - aligned = true; - } - } else { - if (B < 4 && D >= 1024 && !gather) { - return qvm_split_k(inputs, out, group_size, bits, D, O, B, N, s); - } else if (B < 4) { - name += "qvm"; - int bo = 64; - int bd = 32; - group_dims = MTL::Size(bd, 2, 1); - grid_dims = MTL::Size(B, O / bo, N); - } else { - name += "qmm_n"; - int wn = 2; - int wm = 2; - int bm = 32; - int bn = 32; - group_dims = MTL::Size(32, wn, wm); - grid_dims = MTL::Size(O / bn, (B + bm - 1) / bm, N); - matrix = true; - if ((O % bn) != 0) { - std::ostringstream msg; - msg << "[quantized_matmul] The output size should be divisible by " - << bn << " but received " << O << "."; - throw std::runtime_error(msg.str()); - } - } - } - launch_qmm( - name, - inputs, - out, + // Normalize the input arrays + array x = broadcast_with_indices(x_); + array w = ensure_row_contiguous(w_, d, s); + array scales = ensure_row_contiguous(scales_, d, s); + array biases = ensure_row_contiguous(biases_, d, s); + + // TODO: Tune the block sizes + int bm = 16, bn = 32, bk = 32; + int wm = 1, wn = 2; + + const bool align_M = (M % bm) == 0; + const bool align_N = (N % bn) == 0; + const bool align_K = (K % bk) == 0; + + // Make the kernel name + std::string kname; + kname.reserve(64); + std::string type_string = get_type_string(x.dtype()); + concatenate( + kname, + transpose ? "gather_qmm_rhs_nt_" : "gather_qmm_rhs_nn_", + type_string, + "_gs_", + group_size, + "_b_", + bits, + "_bm_", + bm, + "_bn_", + bn, + "_bk_", + bk, + "_wm_", + wm, + "_wn_", + wn); + + metal::MTLFCList func_consts = { + {&align_M, MTL::DataType::DataTypeBool, 200}, + {&align_N, MTL::DataType::DataTypeBool, 201}, + {&align_K, MTL::DataType::DataTypeBool, 202}, + }; + + // And the kernel hash that includes the function constants + std::string hash_name; + hash_name.reserve(128); + concatenate( + hash_name, + kname, + "_align_M_", + align_M ? 't' : 'n', + "_align_N_", + align_N ? 't' : 'n', + "_align_K_", + align_K ? 
't' : 'n'); + + // Get and set the kernel + auto& compute_encoder = d.get_command_encoder(s.index); + auto kernel = get_gather_qmm_kernel( + d, + kname, + hash_name, + func_consts, + x, group_size, bits, - D, - O, - B, - N, - group_dims, - grid_dims, - batched, - matrix, - gather, - aligned, - quad, - s); + bm, + bn, + bk, + wm, + wn, + transpose); + compute_encoder.set_compute_pipeline_state(kernel); + + MTL::Size group_dims(32, wn, wm); + MTL::Size grid_dims((N + bn - 1) / bn, (M + bm - 1) / bm, 1); + + compute_encoder.set_input_array(x, 0); + compute_encoder.set_input_array(w, 1); + compute_encoder.set_input_array(scales, 2); + compute_encoder.set_input_array(biases, 3); + compute_encoder.set_input_array(indices, 4); + compute_encoder.set_output_array(out, 5); + compute_encoder.set_bytes(M, 6); + compute_encoder.set_bytes(N, 7); + compute_encoder.set_bytes(K, 8); + + compute_encoder.dispatch_threadgroups(grid_dims, group_dims); } void QuantizedMatmul::eval_gpu(const std::vector& inputs, array& out) { - assert(inputs.size() == 4); - qmm_op( - inputs, out, transpose_, group_size_, bits_, /*gather=*/false, stream()); + auto& s = stream(); + auto& d = metal::device(s.device); + + out.set_data(allocator::malloc(out.nbytes())); + + // Make sure the last two dims of x and w, s, b are contiguous. This should + // be relaxed for x. + array x = ensure_row_contiguous_matrix(inputs[0], d, s); + array w = ensure_row_contiguous_matrix(inputs[1], d, s); + array scales = ensure_row_contiguous_matrix(inputs[2], d, s); + array biases = ensure_row_contiguous_matrix(inputs[3], d, s); + + // Extract the matmul shapes + bool non_batched = w.ndim() == 2 && x.flags().row_contiguous; + int K = x.shape(-1); + int M = non_batched ? x.size() / K : x.shape(-2); + int N = out.shape(-1); + + int vector_limit = transpose_ ? get_qmv_batch_limit(K, N, d) : 4; + + // It is a matrix matrix product. + if (M >= vector_limit) { + qmm(x, + w, + scales, + biases, + out, + transpose_, + group_size_, + bits_, + M, + N, + K, + d, + s); + return; + } + + // It is a qmv with a small inner dimension so route to qmv_quad kernel + if (transpose_ && (K == 128 || K == 64) && is_power_of_2(bits_)) { + qmv_quad(x, w, scales, biases, out, group_size_, bits_, M, N, K, d, s); + return; + } + + // Run of the mill qmv + if (transpose_) { + qmv(x, w, scales, biases, out, group_size_, bits_, M, N, K, d, s); + return; + } + + // Run of the mill qvm + if (K < 1024) { + qvm(x, w, scales, biases, out, group_size_, bits_, M, N, K, d, s); + return; + } + + // Qvm with large dimension so route to a split K kernel for more parallelism + qvm_split_k(x, w, scales, biases, out, group_size_, bits_, M, N, K, d, s); + return; } void GatherQMM::eval_gpu(const std::vector& inputs, array& out) { - assert(inputs.size() == 6); - qmm_op( - inputs, out, transpose_, group_size_, bits_, /*gather=*/true, stream()); + auto& s = stream(); + auto& d = metal::device(s.device); + + out.set_data(allocator::malloc(out.nbytes())); + + array x = ensure_row_contiguous_matrix(inputs[0], d, s); + array w = ensure_row_contiguous_matrix(inputs[1], d, s); + array scales = ensure_row_contiguous_matrix(inputs[2], d, s); + array biases = ensure_row_contiguous_matrix(inputs[3], d, s); + const array& lhs_indices = inputs[4]; + const array& rhs_indices = inputs[5]; + + int K = x.shape(-1); + int M = x.shape(-2); + int N = out.shape(-1); + int B = out.size() / M / N; + int E = w.size() / w.shape(-1) / w.shape(-2); + int vector_limit = transpose_ ? 
get_qmv_batch_limit(K, N, d) : 4; + + // We are walking x in order and w is also in order so we can batch up the + // matmuls and reuse reading x and w. + // + // TODO: Tune 16 and 8 here a bit better. + if (M == 1 && B >= 16 && right_sorted_ == true && B / E >= 8) { + gather_qmm_rhs( + x, + w, + scales, + biases, + rhs_indices, + out, + transpose_, + group_size_, + bits_, + x.size() / K, + N, + K, + d, + s); + return; + } + + // It is a matrix matrix product + if (M >= vector_limit) { + gather_qmm( + x, + w, + scales, + biases, + lhs_indices, + rhs_indices, + out, + transpose_, + group_size_, + bits_, + M, + N, + K, + d, + s); + return; + } + + if (transpose_) { + gather_qmv( + x, + w, + scales, + biases, + lhs_indices, + rhs_indices, + out, + group_size_, + bits_, + M, + N, + K, + d, + s); + return; + } + + gather_qvm( + x, + w, + scales, + biases, + lhs_indices, + rhs_indices, + out, + group_size_, + bits_, + M, + N, + K, + d, + s); } void fast::AffineQuantize::eval_gpu( @@ -444,27 +942,13 @@ void fast::AffineQuantize::eval_gpu( auto& s = stream(); auto& d = metal::device(s.device); - - std::vector copies; - auto ensure_row_contiguous = [&copies, &s](const array& arr) { - if (arr.flags().row_contiguous) { - return arr; - } else { - array arr_copy(arr.shape(), arr.dtype(), nullptr, {}); - copy_gpu(arr, arr_copy, CopyType::General, s); - copies.push_back(arr_copy); - return arr_copy; - } - }; - auto w = ensure_row_contiguous(w_pre); - auto& compute_encoder = d.get_command_encoder(s.index); + + auto w = ensure_row_contiguous(w_pre, d, s); compute_encoder.set_input_array(w, 0); if (dequantize_) { - auto& scales_pre = inputs[1]; - auto& biases_pre = inputs[2]; - auto scales = ensure_row_contiguous(scales_pre); - auto biases = ensure_row_contiguous(biases_pre); + auto scales = ensure_row_contiguous(inputs[1], d, s); + auto biases = ensure_row_contiguous(inputs[2], d, s); compute_encoder.set_input_array(scales, 1); compute_encoder.set_input_array(biases, 2); compute_encoder.set_output_array(out, 3); @@ -512,8 +996,6 @@ void fast::AffineQuantize::eval_gpu( MTL::Size grid_dims = use_2d ? 
get_2d_grid_dims(grid_shape, w.strides()) : MTL::Size(nthreads, 1, 1); compute_encoder.dispatch_threads(grid_dims, group_dims); - - d.add_temporaries(std::move(copies), s.index); } } // namespace mlx::core diff --git a/mlx/backend/metal/scan.cpp b/mlx/backend/metal/scan.cpp index c7e0087b7..b1800fea9 100644 --- a/mlx/backend/metal/scan.cpp +++ b/mlx/backend/metal/scan.cpp @@ -60,6 +60,9 @@ void Scan::eval_gpu(const std::vector& inputs, array& out) { case Scan::Min: reduce_type = "min"; break; + case Scan::LogAddExp: + reduce_type = "logaddexp"; + break; } kname << reduce_type << "_" << type_to_name(in) << "_" << type_to_name(out); auto kernel = get_scan_kernel( diff --git a/mlx/backend/metal/utils.h b/mlx/backend/metal/utils.h index cc56bab32..079d15f17 100644 --- a/mlx/backend/metal/utils.h +++ b/mlx/backend/metal/utils.h @@ -2,6 +2,8 @@ #pragma once +#include + #include "mlx/array.h" #include "mlx/backend/metal/device.h" #include "mlx/primitives.h" @@ -58,14 +60,27 @@ inline void debug_set_primitive_buffer_label( std::string get_primitive_string(Primitive* primitive); +template +constexpr bool is_numeric_except_char = std::is_arithmetic_v && + !std::is_same_v && !std::is_same_v && + !std::is_same_v && !std::is_same_v; + template void concatenate(std::string& acc, T first) { - acc += first; + if constexpr (is_numeric_except_char) { + acc += std::to_string(first); + } else { + acc += first; + } } template void concatenate(std::string& acc, T first, Args... args) { - acc += first; + if constexpr (is_numeric_except_char) { + acc += std::to_string(first); + } else { + acc += first; + } concatenate(acc, args...); } diff --git a/mlx/dtype_utils.cpp b/mlx/dtype_utils.cpp new file mode 100644 index 000000000..a4448536d --- /dev/null +++ b/mlx/dtype_utils.cpp @@ -0,0 +1,20 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/dtype_utils.h" + +namespace mlx::core { + +const char* dtype_to_string(Dtype arg) { + if (arg == bool_) { + return "bool"; + } +#define SPECIALIZE_DtypeToString(CPP_TYPE, DTYPE) \ + if (DTYPE == arg) { \ + return #DTYPE; \ + } + MLX_FORALL_DTYPES(SPECIALIZE_DtypeToString) +#undef SPECIALIZE_DtypeToString + return "(unknown)"; +} + +} // namespace mlx::core diff --git a/mlx/dtype_utils.h b/mlx/dtype_utils.h new file mode 100644 index 000000000..55de890f2 --- /dev/null +++ b/mlx/dtype_utils.h @@ -0,0 +1,207 @@ +// Copyright © 2025 Apple Inc. +// Copyright © Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the BSD-style license found in +// https://github.com/pytorch/executorch/blob/main/LICENSE +// +// Forked from +// https://github.com/pytorch/executorch/blob/main/runtime/core/exec_aten/util/scalar_type_util.h + +#pragma once + +#include "mlx/dtype.h" + +#include + +namespace mlx::core { + +// Return string representation of dtype. +const char* dtype_to_string(Dtype arg); + +// Macros that iterate across different subsets of Dtypes. +// +// For all of these macros, the final `_` parameter is the name of another macro +// that takes two parameters: the name of a C type, and the name of the +// corresponding Dtype enumerator. +// +// Note that these macros should use fully-qualified namespaces (starting with +// `::`) to ensure that they can be called safely in any arbitrary namespace. 
+#define MLX_FORALL_INT_TYPES(_) \ + _(uint8_t, uint8) \ + _(uint16_t, uint16) \ + _(uint32_t, uint32) \ + _(uint64_t, uint64) \ + _(int8_t, int8) \ + _(int16_t, int16) \ + _(int32_t, int32) \ + _(int64_t, int64) + +#define MLX_FORALL_FLOAT_TYPES(_) \ + _(float16_t, float16) \ + _(float, float32) \ + _(double, float64) \ + _(bfloat16_t, bfloat16) + +// Calls the provided macro on every Dtype, providing the C type and the +// Dtype name to each call. +// +// @param _ A macro that takes two parameters: the name of a C type, and the +// name of the corresponding Dtype enumerator. +#define MLX_FORALL_DTYPES(_) \ + MLX_FORALL_INT_TYPES(_) \ + MLX_FORALL_FLOAT_TYPES(_) \ + _(bool, bool_) \ + _(complex64_t, complex64) + +// Maps Dtypes to C++ types. +template +struct DtypeToCppType; + +#define SPECIALIZE_DtypeToCppType(CPP_TYPE, DTYPE) \ + template <> \ + struct DtypeToCppType { \ + using type = CPP_TYPE; \ + }; + +MLX_FORALL_DTYPES(SPECIALIZE_DtypeToCppType) + +#undef SPECIALIZE_DtypeToCppType + +// Maps C++ types to Dtypes. +template +struct CppTypeToDtype; + +#define SPECIALIZE_CppTypeToDtype(CPP_TYPE, DTYPE) \ + template <> \ + struct CppTypeToDtype \ + : std::integral_constant {}; + +MLX_FORALL_DTYPES(SPECIALIZE_CppTypeToDtype) + +#undef SPECIALIZE_CppTypeToDtype + +// Helper macros for switch case macros (see below) +// +// These macros are not meant to be used directly. They provide an easy way to +// generate a switch statement that can handle subsets of Dtypes supported. + +#define MLX_INTERNAL_SWITCH_CASE(enum_type, CTYPE_ALIAS, ...) \ + case enum_type: { \ + using CTYPE_ALIAS = ::mlx::core::DtypeToCppType::type; \ + __VA_ARGS__; \ + break; \ + } + +#define MLX_INTERNAL_SWITCH_CHECKED(TYPE, NAME, ...) \ + switch (TYPE) { \ + __VA_ARGS__ \ + default: \ + throw std::invalid_argument(fmt::format( \ + "Unhandled dtype %s for %s", dtype_to_string(TYPE), NAME)); \ + } + +#define MLX_INTERNAL_SWITCH_CASE_INT_TYPES(CTYPE_ALIAS, ...) \ + MLX_INTERNAL_SWITCH_CASE( \ + ::mlx::core::Dtype::Val::uint8, CTYPE_ALIAS, __VA_ARGS__) \ + MLX_INTERNAL_SWITCH_CASE( \ + ::mlx::core::Dtype::Val::uint16, CTYPE_ALIAS, __VA_ARGS__) \ + MLX_INTERNAL_SWITCH_CASE( \ + ::mlx::core::Dtype::Val::uint32, CTYPE_ALIAS, __VA_ARGS__) \ + MLX_INTERNAL_SWITCH_CASE( \ + ::mlx::core::Dtype::Val::uint64, CTYPE_ALIAS, __VA_ARGS__) \ + MLX_INTERNAL_SWITCH_CASE( \ + ::mlx::core::Dtype::Val::int8, CTYPE_ALIAS, __VA_ARGS__) \ + MLX_INTERNAL_SWITCH_CASE( \ + ::mlx::core::Dtype::Val::int16, CTYPE_ALIAS, __VA_ARGS__) \ + MLX_INTERNAL_SWITCH_CASE( \ + ::mlx::core::Dtype::Val::int32, CTYPE_ALIAS, __VA_ARGS__) \ + MLX_INTERNAL_SWITCH_CASE( \ + ::mlx::core::Dtype::Val::int64, CTYPE_ALIAS, __VA_ARGS__) + +#define MLX_INTERNAL_SWITCH_CASE_FLOAT_TYPES(CTYPE_ALIAS, ...) \ + MLX_INTERNAL_SWITCH_CASE( \ + ::mlx::core::Dtype::Val::float16, CTYPE_ALIAS, __VA_ARGS__) \ + MLX_INTERNAL_SWITCH_CASE( \ + ::mlx::core::Dtype::Val::float32, CTYPE_ALIAS, __VA_ARGS__) \ + MLX_INTERNAL_SWITCH_CASE( \ + ::mlx::core::Dtype::Val::float64, CTYPE_ALIAS, __VA_ARGS__) \ + MLX_INTERNAL_SWITCH_CASE( \ + ::mlx::core::Dtype::Val::bfloat16, CTYPE_ALIAS, __VA_ARGS__) + +#define MLX_INTERNAL_SWITCH_CASE_INT_FLOAT_TYPES(CTYPE_ALIAS, ...) \ + MLX_INTERNAL_SWITCH_CASE_INT_TYPES(CTYPE_ALIAS, __VA_ARGS__) \ + MLX_INTERNAL_SWITCH_CASE_FLOAT_TYPES(CTYPE_ALIAS, __VA_ARGS__) + +#define MLX_INTERNAL_SWITCH_CASE_REAL_TYPES(CTYPE_ALIAS, ...) 
\ + MLX_INTERNAL_SWITCH_CASE_INT_FLOAT_TYPES(CTYPE_ALIAS, __VA_ARGS__) \ + MLX_INTERNAL_SWITCH_CASE( \ + ::mlx::core::Dtype::Val::bool_, CTYPE_ALIAS, __VA_ARGS__) + +#define MLX_INTERNAL_SWITCH_CASE_COMPLEX_TYPES(CTYPE_ALIAS, ...) \ + MLX_INTERNAL_SWITCH_CASE( \ + ::mlx::core::Dtype::Val::complex64, CTYPE_ALIAS, __VA_ARGS__) + +#define MLX_INTERNAL_SWITCH_CASE_ALL_TYPES(CTYPE_ALIAS, ...) \ + MLX_INTERNAL_SWITCH_CASE_REAL_TYPES(CTYPE_ALIAS, __VA_ARGS__) \ + MLX_INTERNAL_SWITCH_CASE_COMPLEX_TYPES(CTYPE_ALIAS, __VA_ARGS__) + +// Switch case macros +// +// These macros provide an easy way to generate switch statements that apply a +// common lambda function to subsets of Dtypes supported by MLX. +// The lambda function can type specialize to the ctype associated with the +// Dtype being handled through an alias passed as the CTYPE_ALIAS argument. +// +// Arguments: +// - ADDITIONAL: Additional Dtype case to add +// - TYPE: The Dtype to handle through the switch statement +// - NAME: A name for this operation which will be used in error messages +// - CTYPE_ALIAS: A typedef for the ctype associated with the Dtype. +// - ...: A statement to be applied to each Dtype case +// +// An example usage is: +// +// MLX_SWITCH_ALL_TYPES(input.dtype(), CTYPE, { +// output.data[0] = input.data[0]; +// }); +// +// Note that these can be nested as well: +// +// MLX_SWITCH_ALL_TYPES(input.dtype(), CTYPE_IN, { +// MLX_SWITCH_ALL_TYPES(output.dtype(), CTYPE_OUT, { +// output.data[0] = input.data[0]; +// }); +// }); +// +// These macros are adapted from Dispatch.h in the ATen library. The primary +// difference is that the CTYPE_ALIAS argument is exposed to users, which is +// used to alias the ctype associated with the Dtype that is being handled. + +#define MLX_SWITCH_ALL_TYPES(TYPE, CTYPE_ALIAS, ...) \ + switch (TYPE) { MLX_INTERNAL_SWITCH_CASE_ALL_TYPES(CTYPE_ALIAS, __VA_ARGS__) } + +#define MLX_SWITCH_INT_TYPES_CHECKED(TYPE, NAME, CTYPE_ALIAS, ...) \ + MLX_INTERNAL_SWITCH_CHECKED( \ + TYPE, \ + NAME, \ + MLX_INTERNAL_SWITCH_CASE_INT_TYPES(CTYPE_ALIAS, __VA_ARGS__)) + +#define MLX_SWITCH_FLOAT_TYPES_CHECKED(TYPE, NAME, CTYPE_ALIAS, ...) \ + MLX_INTERNAL_SWITCH_CHECKED( \ + TYPE, \ + NAME, \ + MLX_INTERNAL_SWITCH_CASE_FLOAT_TYPES(CTYPE_ALIAS, __VA_ARGS__)) + +#define MLX_SWITCH_INT_FLOAT_TYPES_CHECKED(TYPE, NAME, CTYPE_ALIAS, ...) \ + MLX_INTERNAL_SWITCH_CHECKED( \ + TYPE, \ + NAME, \ + MLX_INTERNAL_SWITCH_CASE_INT_FLOAT_TYPES(CTYPE_ALIAS, __VA_ARGS__)) + +#define MLX_SWITCH_REAL_TYPES_CHECKED(TYPE, NAME, CTYPE_ALIAS, ...) \ + MLX_INTERNAL_SWITCH_CHECKED( \ + TYPE, \ + NAME, \ + MLX_INTERNAL_SWITCH_CASE_REAL_TYPES(CTYPE_ALIAS, __VA_ARGS__)) + +} // namespace mlx::core diff --git a/mlx/export.cpp b/mlx/export.cpp index 8051f786c..effc7a0c1 100644 --- a/mlx/export.cpp +++ b/mlx/export.cpp @@ -1,5 +1,6 @@ // Copyright © 2024 Apple Inc. 
#include "mlx/export.h" +#include #include "mlx/compile_impl.h" #include "mlx/fast_primitives.h" #include "mlx/primitives.h" @@ -298,7 +299,13 @@ struct PrimitiveFactory { SERIALIZE_PRIMITIVE(Reshape), SERIALIZE_PRIMITIVE(Reduce, "And", "Or", "Sum", "Prod", "Min", "Max"), SERIALIZE_PRIMITIVE(Round), - SERIALIZE_PRIMITIVE(Scan, "CumSum", "CumProd", "CumMin", "CumMax"), + SERIALIZE_PRIMITIVE( + Scan, + "CumSum", + "CumProd", + "CumMin", + "CumMax", + "CumLogaddexp"), SERIALIZE_PRIMITIVE(Scatter), SERIALIZE_PRIMITIVE(Select), SERIALIZE_PRIMITIVE(Sigmoid), @@ -475,7 +482,9 @@ bool FunctionTable::match( return false; } } - for (auto& [_, in] : kwargs) { + auto sorted_kwargs = + std::map(kwargs.begin(), kwargs.end()); + for (auto& [_, in] : sorted_kwargs) { if (!match_inputs(in, fun.inputs[i++])) { return false; } @@ -551,7 +560,9 @@ void FunctionExporter::export_function(const Args& args, const Kwargs& kwargs) { // Flatten the inputs to the function for tracing std::vector kwarg_keys; auto inputs = args; - for (auto& [k, v] : kwargs) { + auto sorted_kwargs = + std::map(kwargs.begin(), kwargs.end()); + for (auto& [k, v] : sorted_kwargs) { kwarg_keys.push_back(k); inputs.push_back(v); } diff --git a/mlx/export.h b/mlx/export.h index da090510b..c6859c6d8 100644 --- a/mlx/export.h +++ b/mlx/export.h @@ -2,14 +2,14 @@ #pragma once -#include #include +#include #include "mlx/array.h" namespace mlx::core { using Args = std::vector; -using Kwargs = std::map; +using Kwargs = std::unordered_map; struct FunctionExporter; diff --git a/mlx/fft.cpp b/mlx/fft.cpp index 68b032727..c41b37843 100644 --- a/mlx/fft.cpp +++ b/mlx/fft.cpp @@ -111,7 +111,7 @@ array fft_impl( for (auto ax : axes) { n.push_back(a.shape(ax)); } - if (real && inverse) { + if (real && inverse && a.ndim() > 0) { n.back() = (n.back() - 1) * 2; } return fft_impl(a, n, axes, real, inverse, s); diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 2f3997e7b..54ac62fef 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -3504,6 +3504,28 @@ array cummin( {a}); } +array logcumsumexp( + const array& a, + int axis, + bool reverse /* = false*/, + bool inclusive /* = true*/, + StreamOrDevice s /* = {}*/) { + int ndim = a.ndim(); + if (axis >= ndim || axis < -ndim) { + std::ostringstream msg; + msg << "[logcumsumexp] Axis " << axis << " is out of bounds for array with " + << a.ndim() << " dimensions."; + throw std::invalid_argument(msg.str()); + } + axis = (axis + a.ndim()) % a.ndim(); + return array( + a.shape(), + a.dtype(), + std::make_shared( + to_stream(s), Scan::ReduceType::LogAddExp, axis, reverse, inclusive), + {a}); +} + /** Convolution operations */ namespace { @@ -4006,6 +4028,7 @@ array gather_qmm( bool transpose /* = true */, int group_size /* = 64 */, int bits /* = 4 */, + bool sorted_indices /* = false */, StreamOrDevice s /* = {} */) { if (!lhs_indices_ && !rhs_indices_) { return quantized_matmul( @@ -4045,13 +4068,19 @@ array gather_qmm( return array( std::move(out_shape), out_type, - std::make_shared(to_stream(s), group_size, bits, transpose), + std::make_shared( + to_stream(s), + group_size, + bits, + transpose, + sorted_indices && !rhs_indices_, + sorted_indices && !lhs_indices_), {astype(x, out_type, s), - w, + std::move(w), astype(scales, out_type, s), astype(biases, out_type, s), - lhs_indices, - rhs_indices}); + std::move(lhs_indices), + std::move(rhs_indices)}); } array tensordot( @@ -4477,6 +4506,7 @@ array gather_mm( array b, std::optional lhs_indices_ /* = std::nullopt */, std::optional rhs_indices_ /* = std::nullopt */, + bool 
sorted_indices /* = false */, StreamOrDevice s /* = {} */) { // If no indices, fall back to full matmul if (!lhs_indices_ && !rhs_indices_) { @@ -4552,12 +4582,18 @@ array gather_mm( out_shape.push_back(M); out_shape.push_back(N); - // Caculate array + // Make the output array auto out = array( std::move(out_shape), out_type, - std::make_shared(to_stream(s)), - {a, b, lhs_indices, rhs_indices}); + std::make_shared( + to_stream(s), + sorted_indices && !rhs_indices_, + sorted_indices && !lhs_indices_), + {std::move(a), + std::move(b), + std::move(lhs_indices), + std::move(rhs_indices)}); // Remove the possibly inserted singleton dimensions std::vector axes; @@ -4879,8 +4915,10 @@ array operator^(const array& a, const array& b) { } array left_shift(const array& a, const array& b, StreamOrDevice s /* = {} */) { - // Bit shift on bool always up-casts to uint8 - auto t = promote_types(result_type(a, b), uint8); + auto t = result_type(a, b); + if (t == bool_) { + t = uint8; + } return bitwise_impl( astype(a, t, s), astype(b, t, s), @@ -4893,8 +4931,10 @@ array operator<<(const array& a, const array& b) { } array right_shift(const array& a, const array& b, StreamOrDevice s /* = {} */) { - // Bit shift on bool always up-casts to uint8 - auto t = promote_types(result_type(a, b), uint8); + auto t = result_type(a, b); + if (t == bool_) { + t = uint8; + } return bitwise_impl( astype(a, t, s), astype(b, t, s), diff --git a/mlx/ops.h b/mlx/ops.h index 02428b974..e79ea235d 100644 --- a/mlx/ops.h +++ b/mlx/ops.h @@ -715,6 +715,14 @@ array topk(const array& a, int k, StreamOrDevice s = {}); /** Returns topk elements of the array along a given axis. */ array topk(const array& a, int k, int axis, StreamOrDevice s = {}); +/** Cumulative logsumexp of an array. */ +array logcumsumexp( + const array& a, + int axis, + bool reverse = false, + bool inclusive = true, + StreamOrDevice s = {}); + /** The logsumexp of all elements of the array. */ array logsumexp(const array& a, bool keepdims, StreamOrDevice s = {}); inline array logsumexp(const array& a, StreamOrDevice s = {}) { @@ -1344,6 +1352,7 @@ array gather_qmm( bool transpose = true, int group_size = 64, int bits = 4, + bool sorted_indices = false, StreamOrDevice s = {}); /** Returns a contraction of a and b over multiple dimensions. 
*/ @@ -1391,6 +1400,7 @@ array gather_mm( array b, std::optional lhs_indices = std::nullopt, std::optional rhs_indices = std::nullopt, + bool sorted_indices = false, StreamOrDevice s = {}); /** Extract a diagonal or construct a diagonal array */ diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 6f9e45313..3d36f0881 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -1275,6 +1275,61 @@ std::vector Convolution::vjp( return grads; } +std::pair, std::vector> Convolution::vmap( + const std::vector& inputs, + const std::vector& axes) { + auto do_conv = [&](const array& in, const array& w, int groups) { + return conv_general( + in, + w, + kernel_strides_, + padding_, + kernel_dilation_, + input_dilation_, + groups, + flip_, + stream()); + }; + bool in_vmap = axes[0] >= 0; + bool w_vmap = axes[1] >= 0; + auto in = inputs[0]; + auto w = inputs[1]; + if (in_vmap && !w_vmap) { + // flatten / unflatten the batch dimension + // of the input / output + if (axes[0] > 0) { + in = moveaxis(in, axes[0], 0, stream()); + } + auto out = do_conv(flatten(in, 0, 1, stream()), w, groups_); + out = unflatten(out, 0, {in.shape(0), in.shape(1)}, stream()); + return {{out}, {0}}; + } else if (!in_vmap && w_vmap) { + // flatten into the output channels of w + // unflatten the channels of the output + if (axes[1] > 0) { + w = moveaxis(w, axes[1], 0, stream()); + } + auto out = do_conv(in, flatten(w, 0, 1, stream()), groups_); + out = unflatten(out, -1, {w.shape(0), w.shape(1)}, stream()); + return {{out}, {static_cast(out.ndim() - 2)}}; + } else if (in_vmap && w_vmap) { + // use a group convolution when both inputs are vmapped + auto b = in.shape(axes[0]); + in = moveaxis(in, axes[0], -2, stream()); + in = flatten(in, -2, -1, stream()); + if (axes[1] > 0) { + w = moveaxis(w, axes[1], 0, stream()); + } + auto c_out = w.shape(1); + w = flatten(w, 0, 1, stream()); + auto out = do_conv(in, w, groups_ * b); + out = unflatten(out, -1, {b, c_out}, stream()); + return {{out}, {static_cast(out.ndim() - 2)}}; + } else { + return {{do_conv(in, w, groups_)}, {-1}}; + } +} + bool Convolution::is_equivalent(const Primitive& other) const { const Convolution& c_other = static_cast(other); return padding_ == c_other.padding_ && @@ -3080,6 +3135,8 @@ std::vector GatherQMM::vjp( auto& lhs_indices = primals[4]; auto& rhs_indices = primals[5]; + bool sorted = left_sorted_ || right_sorted_; + for (auto arg : argnums) { // gradient wrt to x if (arg == 0) { @@ -3098,6 +3155,7 @@ std::vector GatherQMM::vjp( !transpose_, group_size_, bits_, + sorted, stream()), -3, stream()), @@ -3478,6 +3536,45 @@ std::vector Scan::vjp( if (reduce_type_ == Scan::Sum) { return {cumsum(cotangents[0], axis_, !reverse_, inclusive_, stream())}; + } else if (reduce_type_ == Scan::LogAddExp) { + // Ref: + // https://github.com/tensorflow/tensorflow/blob/2a5910906a0e0f3dbc186ff9db6386d81a63448c/tensorflow/python/ops/math_grad.py#L1832-L1863 + + auto x = primals[0]; + auto grad = cotangents[0]; + auto results = outputs[0]; + + auto zero = zeros({1}, grad.dtype(), stream()); + auto grad_min = array(finfo(grad.dtype()).min, grad.dtype()); + + // Split the incoming gradient into positive and negative part + // in order to take logs. This is required for stable results. 
+ auto log_abs_grad = log(abs(grad, stream()), stream()); + auto log_grad_positive = + where(greater(grad, zero, stream()), log_abs_grad, grad_min, stream()); + auto log_grad_negative = + where(less(grad, zero, stream()), log_abs_grad, grad_min, stream()); + + auto output_pos = exp( + add(logcumsumexp( + subtract(log_grad_positive, results, stream()), + axis_, + !reverse_, + inclusive_, + stream()), + x, + stream())); + auto output_neg = exp( + add(logcumsumexp( + subtract(log_grad_negative, results, stream()), + axis_, + !reverse_, + inclusive_, + stream()), + x, + stream())); + + return {subtract(output_pos, output_neg, stream())}; } else if (reduce_type_ == Scan::Prod) { auto in = primals[0]; // Find the location of the first 0 and set it to 1: @@ -4856,6 +4953,8 @@ std::vector GatherMM::vjp( int N = cotan.shape(-1); int K = primals[0].shape(-1); + bool sorted = left_sorted_ || right_sorted_; + for (auto arg : argnums) { if (arg == 0) { // M X N * (K X N).T -> M X K @@ -4866,7 +4965,8 @@ std::vector GatherMM::vjp( base = reshape(base, {-1, M, K}, stream()); // g : (out_batch_shape) + (M, K) - auto g = gather_mm(cotan, bt, std::nullopt, rhs_indices, stream()); + auto g = + gather_mm(cotan, bt, std::nullopt, rhs_indices, sorted, stream()); g = expand_dims(g, -3, stream()); auto gacc = scatter_add(base, lhs_indices, g, 0, stream()); @@ -4881,7 +4981,8 @@ std::vector GatherMM::vjp( base = reshape(base, {-1, K, N}, stream()); // g : (out_batch_shape) + (K, N) - auto g = gather_mm(at, cotan, lhs_indices, std::nullopt, stream()); + auto g = + gather_mm(at, cotan, lhs_indices, std::nullopt, sorted, stream()); g = expand_dims(g, -3, stream()); auto gacc = scatter_add(base, rhs_indices, g, 0, stream()); @@ -4894,6 +4995,12 @@ std::vector GatherMM::vjp( return vjps; } +bool GatherMM::is_equivalent(const Primitive& other) const { + const GatherMM& g_other = static_cast(other); + return left_sorted_ == g_other.left_sorted_ && + right_sorted_ == g_other.right_sorted_; +} + bool BlockMaskedMM::is_equivalent(const Primitive& other) const { const BlockMaskedMM& a_other = static_cast(other); return (block_size_ == a_other.block_size_); diff --git a/mlx/primitives.h b/mlx/primitives.h index c7b2de878..3753e43c5 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -498,7 +498,13 @@ class BlockMaskedMM : public UnaryPrimitive { class GatherMM : public UnaryPrimitive { public: - explicit GatherMM(Stream stream) : UnaryPrimitive(stream) {} + explicit GatherMM( + Stream stream, + bool left_sorted = false, + bool right_sorted = false) + : UnaryPrimitive(stream), + left_sorted_(left_sorted), + right_sorted_(right_sorted) {} void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; @@ -510,7 +516,14 @@ class GatherMM : public UnaryPrimitive { const std::vector& outputs) override; DEFINE_PRINT(GatherMM) - DEFINE_DEFAULT_IS_EQUIVALENT() + bool is_equivalent(const Primitive& other) const override; + auto state() const { + return std::make_pair(left_sorted_, right_sorted_); + } + + private: + bool left_sorted_; + bool right_sorted_; }; class BroadcastAxes : public UnaryPrimitive { @@ -698,6 +711,7 @@ class Convolution : public UnaryPrimitive { const std::vector& argnums, const std::vector& outputs) override; + DEFINE_VMAP() DEFINE_PRINT(Convolution) bool is_equivalent(const Primitive& other) const override; auto state() const { @@ -1578,11 +1592,19 @@ class QuantizedMatmul : public UnaryPrimitive { class GatherQMM : public UnaryPrimitive { public: 
- explicit GatherQMM(Stream stream, int group_size, int bits, bool transpose) + explicit GatherQMM( + Stream stream, + int group_size, + int bits, + bool transpose, + bool left_sorted = false, + bool right_sorted = false) : UnaryPrimitive(stream), group_size_(group_size), bits_(bits), - transpose_(transpose) {} + transpose_(transpose), + left_sorted_(left_sorted), + right_sorted_(right_sorted) {} void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; @@ -1592,13 +1614,16 @@ class GatherQMM : public UnaryPrimitive { DEFINE_PRINT(GatherQMM) bool is_equivalent(const Primitive& other) const override; auto state() const { - return std::make_tuple(group_size_, bits_, transpose_); + return std::make_tuple( + group_size_, bits_, transpose_, left_sorted_, right_sorted_); } private: int group_size_; int bits_; bool transpose_; + bool left_sorted_; + bool right_sorted_; }; class RandomBits : public UnaryPrimitive { @@ -1728,7 +1753,7 @@ class Round : public UnaryPrimitive { class Scan : public UnaryPrimitive { public: - enum ReduceType { Max, Min, Sum, Prod }; + enum ReduceType { Max, Min, Sum, Prod, LogAddExp }; explicit Scan( Stream stream, @@ -1763,6 +1788,9 @@ class Scan : public UnaryPrimitive { case Max: os << "Max"; break; + case LogAddExp: + os << "Logaddexp"; + break; } } bool is_equivalent(const Primitive& other) const override; diff --git a/mlx/utils.cpp b/mlx/utils.cpp index 5197e516f..188584174 100644 --- a/mlx/utils.cpp +++ b/mlx/utils.cpp @@ -5,6 +5,7 @@ #include #include +#include "mlx/dtype_utils.h" #include "mlx/types/limits.h" #include "mlx/utils.h" @@ -224,37 +225,7 @@ void print_array(std::ostream& os, const array& a) { } // namespace std::ostream& operator<<(std::ostream& os, const Dtype& dtype) { - switch (dtype) { - case bool_: - return os << "bool"; - case uint8: - return os << "uint8"; - case uint16: - return os << "uint16"; - case uint32: - return os << "uint32"; - case uint64: - return os << "uint64"; - case int8: - return os << "int8"; - case int16: - return os << "int16"; - case int32: - return os << "int32"; - case int64: - return os << "int64"; - case float16: - return os << "float16"; - case float32: - return os << "float32"; - case float64: - return os << "float64"; - case bfloat16: - return os << "bfloat16"; - case complex64: - return os << "complex64"; - } - return os; + return os << dtype_to_string(dtype); } std::ostream& operator<<(std::ostream& os, const Dtype::Kind& k) { @@ -277,50 +248,7 @@ std::ostream& operator<<(std::ostream& os, const Dtype::Kind& k) { std::ostream& operator<<(std::ostream& os, array a) { a.eval(); - switch (a.dtype()) { - case bool_: - print_array(os, a); - break; - case uint8: - print_array(os, a); - break; - case uint16: - print_array(os, a); - break; - case uint32: - print_array(os, a); - break; - case uint64: - print_array(os, a); - break; - case int8: - print_array(os, a); - break; - case int16: - print_array(os, a); - break; - case int32: - print_array(os, a); - break; - case int64: - print_array(os, a); - break; - case float16: - print_array(os, a); - break; - case bfloat16: - print_array(os, a); - break; - case float32: - print_array(os, a); - break; - case float64: - print_array(os, a); - break; - case complex64: - print_array(os, a); - break; - } + MLX_SWITCH_ALL_TYPES(a.dtype(), CTYPE, print_array(os, a)); return os; } @@ -387,36 +315,8 @@ void set_iinfo_limits(int64_t& min, uint64_t& max) { } iinfo::iinfo(Dtype dtype) : dtype(dtype) { - switch (dtype) { - 
case int8: - set_iinfo_limits(min, max); - break; - case uint8: - set_iinfo_limits(min, max); - break; - case int16: - set_iinfo_limits(min, max); - break; - case uint16: - set_iinfo_limits(min, max); - break; - case int32: - set_iinfo_limits(min, max); - break; - case uint32: - set_iinfo_limits(min, max); - break; - case int64: - set_iinfo_limits(min, max); - break; - case uint64: - set_iinfo_limits(min, max); - break; - default: - std::ostringstream msg; - msg << "[iinfo] dtype " << dtype << " is not integral."; - throw std::invalid_argument(msg.str()); - } + MLX_SWITCH_INT_TYPES_CHECKED( + dtype, "[iinfo]", CTYPE, set_iinfo_limits(min, max)); } } // namespace mlx::core diff --git a/mlx/version.h b/mlx/version.h index 35b026149..fe47d96cc 100644 --- a/mlx/version.h +++ b/mlx/version.h @@ -3,8 +3,8 @@ #pragma once #define MLX_VERSION_MAJOR 0 -#define MLX_VERSION_MINOR 24 -#define MLX_VERSION_PATCH 2 +#define MLX_VERSION_MINOR 25 +#define MLX_VERSION_PATCH 0 #define MLX_VERSION_NUMERIC \ (100000 * MLX_VERSION_MAJOR + 1000 * MLX_VERSION_MINOR + MLX_VERSION_PATCH) diff --git a/python/src/array.cpp b/python/src/array.cpp index e380f2652..467bd0fa5 100644 --- a/python/src/array.cpp +++ b/python/src/array.cpp @@ -1202,6 +1202,28 @@ void init_array(nb::module_& m) { nb::kw_only(), "stream"_a = nb::none(), "See :func:`max`.") + .def( + "logcumsumexp", + [](const mx::array& a, + std::optional axis, + bool reverse, + bool inclusive, + mx::StreamOrDevice s) { + if (axis) { + return mx::logcumsumexp(a, *axis, reverse, inclusive, s); + } else { + // TODO: Implement that in the C++ API as well. See concatenate + // above. + return mx::logcumsumexp( + mx::reshape(a, {-1}, s), 0, reverse, inclusive, s); + } + }, + "axis"_a = nb::none(), + nb::kw_only(), + "reverse"_a = false, + "inclusive"_a = true, + "stream"_a = nb::none(), + "See :func:`logcumsumexp`.") .def( "logsumexp", [](const mx::array& a, diff --git a/python/src/export.cpp b/python/src/export.cpp index feefeb12c..30062ae37 100644 --- a/python/src/export.cpp +++ b/python/src/export.cpp @@ -1,8 +1,8 @@ // Copyright © 2024 Apple Inc. 
#include -#include #include #include +#include #include #include @@ -16,8 +16,7 @@ namespace mx = mlx::core; namespace nb = nanobind; using namespace nb::literals; -std::pair, std::map> -validate_and_extract_inputs( +std::pair validate_and_extract_inputs( const nb::args& args, const nb::kwargs& kwargs, const std::string& prefix) { @@ -30,8 +29,8 @@ validate_and_extract_inputs( "and/or dictionary of arrays."); } }; - std::vector args_; - std::map kwargs_; + mx::Args args_; + mx::Kwargs kwargs_; if (args.size() == 0) { // No args so kwargs must be keyword arrays maybe_throw(nb::try_cast(kwargs, kwargs_)); @@ -81,9 +80,7 @@ class PyFunctionExporter { void close() { exporter_.close(); } - void operator()( - const std::vector& args, - const std::map& kwargs) { + void operator()(const mx::Args& args, const mx::Kwargs& kwargs) { exporter_(args, kwargs); } @@ -98,9 +95,12 @@ int py_function_exporter_tp_traverse( PyObject* self, visitproc visit, void* arg) { + Py_VISIT(Py_TYPE(self)); + if (!nb::inst_ready(self)) { + return 0; + } auto* p = nb::inst_ptr(self); Py_VISIT(p->dep_.ptr()); - Py_VISIT(Py_TYPE(self)); return 0; } @@ -109,23 +109,22 @@ PyType_Slot py_function_exporter_slots[] = { {0, 0}}; auto wrap_export_function(nb::callable fun) { - return [fun = std::move(fun)]( - const std::vector& args_, - const std::map& kwargs_) { - auto kwargs = nb::dict(); - kwargs.update(nb::cast(kwargs_)); - auto args = nb::tuple(nb::cast(args_)); - auto outputs = fun(*args, **kwargs); - std::vector outputs_; - if (nb::isinstance(outputs)) { - outputs_.push_back(nb::cast(outputs)); - } else if (!nb::try_cast(outputs, outputs_)) { - throw std::invalid_argument( - "[export_function] Outputs can be either a single array " - "a tuple or list of arrays."); - } - return outputs_; - }; + return + [fun = std::move(fun)](const mx::Args& args_, const mx::Kwargs& kwargs_) { + auto kwargs = nb::dict(); + kwargs.update(nb::cast(kwargs_)); + auto args = nb::tuple(nb::cast(args_)); + auto outputs = fun(*args, **kwargs); + std::vector outputs_; + if (nb::isinstance(outputs)) { + outputs_.push_back(nb::cast(outputs)); + } else if (!nb::try_cast(outputs, outputs_)) { + throw std::invalid_argument( + "[export_function] Outputs can be either a single array " + "a tuple or list of arrays."); + } + return outputs_; + }; } void init_export(nb::module_& m) { diff --git a/python/src/mlx_func.cpp b/python/src/mlx_func.cpp index b2eca5f6f..2f0589bb6 100644 --- a/python/src/mlx_func.cpp +++ b/python/src/mlx_func.cpp @@ -16,12 +16,12 @@ struct gc_func { }; int gc_func_tp_traverse(PyObject* self, visitproc visit, void* arg) { + Py_VISIT(Py_TYPE(self)); gc_func* w = (gc_func*)self; Py_VISIT(w->func); for (auto d : w->deps) { Py_VISIT(d); } - Py_VISIT(Py_TYPE(self)); return 0; }; diff --git a/python/src/ops.cpp b/python/src/ops.cpp index 5f078a08d..f98aa80aa 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -2382,6 +2382,43 @@ void init_ops(nb::module_& m) { Returns: array: The output array with the corresponding axes reduced. 
)pbdoc"); + m.def( + "logcumsumexp", + [](const mx::array& a, + std::optional axis, + bool reverse, + bool inclusive, + mx::StreamOrDevice s) { + if (axis) { + return mx::logcumsumexp(a, *axis, reverse, inclusive, s); + } else { + return mx::logcumsumexp( + mx::reshape(a, {-1}, s), 0, reverse, inclusive, s); + } + }, + nb::arg(), + "axis"_a = nb::none(), + nb::kw_only(), + "reverse"_a = false, + "inclusive"_a = true, + "stream"_a = nb::none(), + nb::sig( + "def logcumsumexp(a: array, /, axis: Optional[int] = None, *, reverse: bool = False, inclusive: bool = True, stream: Union[None, Stream, Device] = None) -> array"), + R"pbdoc( + Return the cumulative logsumexp of the elements along the given axis. + + Args: + a (array): Input array + axis (int, optional): Optional axis to compute the cumulative logsumexp + over. If unspecified the cumulative logsumexp of the flattened array is + returned. + reverse (bool): Perform the cumulative logsumexp in reverse. + inclusive (bool): The i-th element of the output includes the i-th + element of the input. + + Returns: + array: The output array. + )pbdoc"); m.def( "logsumexp", [](const mx::array& a, @@ -4213,9 +4250,10 @@ void init_ops(nb::module_& m) { "group_size"_a = 64, "bits"_a = 4, nb::kw_only(), + "sorted_indices"_a = false, "stream"_a = nb::none(), nb::sig( - "def gather_qmm(x: array, w: array, /, scales: array, biases: array, lhs_indices: Optional[array] = None, rhs_indices: Optional[array] = None, transpose: bool = True, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> array"), + "def gather_qmm(x: array, w: array, /, scales: array, biases: array, lhs_indices: Optional[array] = None, rhs_indices: Optional[array] = None, transpose: bool = True, group_size: int = 64, bits: int = 4, *, sorted_indices: bool = False, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( Perform quantized matrix multiplication with matrix-level gather. @@ -4228,23 +4266,25 @@ void init_ops(nb::module_& m) { as ``w`` since they represent the same quantized matrix. Args: - x (array): Input array - w (array): Quantized matrix packed in unsigned integers - scales (array): The scales to use per ``group_size`` elements of ``w`` - biases (array): The biases to use per ``group_size`` elements of ``w`` - lhs_indices (array, optional): Integer indices for ``x``. Default: ``None``. - rhs_indices (array, optional): Integer indices for ``w``. Default: ``None``. - transpose (bool, optional): Defines whether to multiply with the - transposed ``w`` or not, namely whether we are performing - ``x @ w.T`` or ``x @ w``. Default: ``True``. - group_size (int, optional): The size of the group in ``w`` that - shares a scale and bias. Default: ``64``. - bits (int, optional): The number of bits occupied by each element in - ``w``. Default: ``4``. + x (array): Input array + w (array): Quantized matrix packed in unsigned integers + scales (array): The scales to use per ``group_size`` elements of ``w`` + biases (array): The biases to use per ``group_size`` elements of ``w`` + lhs_indices (array, optional): Integer indices for ``x``. Default: ``None``. + rhs_indices (array, optional): Integer indices for ``w``. Default: ``None``. + transpose (bool, optional): Defines whether to multiply with the + transposed ``w`` or not, namely whether we are performing + ``x @ w.T`` or ``x @ w``. Default: ``True``. + group_size (int, optional): The size of the group in ``w`` that + shares a scale and bias. Default: ``64``. 
+ bits (int, optional): The number of bits occupied by each element in + ``w``. Default: ``4``. + sorted_indices (bool, optional): May allow a faster implementation + if the passed indices are sorted. Default: ``False``. Returns: - array: The result of the multiplication of ``x`` with ``w`` - after gathering using ``lhs_indices`` and ``rhs_indices``. + array: The result of the multiplication of ``x`` with ``w`` + after gathering using ``lhs_indices`` and ``rhs_indices``. )pbdoc"); m.def( "tensordot", @@ -4274,16 +4314,16 @@ void init_ops(nb::module_& m) { Compute the tensor dot product along the specified axes. Args: - a (array): Input array - b (array): Input array - axes (int or list(list(int)), optional): The number of dimensions to - sum over. If an integer is provided, then sum over the last - ``axes`` dimensions of ``a`` and the first ``axes`` dimensions of - ``b``. If a list of lists is provided, then sum over the - corresponding dimensions of ``a`` and ``b``. Default: 2. + a (array): Input array + b (array): Input array + axes (int or list(list(int)), optional): The number of dimensions to + sum over. If an integer is provided, then sum over the last + ``axes`` dimensions of ``a`` and the first ``axes`` dimensions of + ``b``. If a list of lists is provided, then sum over the + corresponding dimensions of ``a`` and ``b``. Default: 2. Returns: - array: The tensor dot product. + array: The tensor dot product. )pbdoc"); m.def( "inner", @@ -4427,9 +4467,10 @@ void init_ops(nb::module_& m) { "lhs_indices"_a = nb::none(), "rhs_indices"_a = nb::none(), nb::kw_only(), + "sorted_indices"_a = false, "stream"_a = nb::none(), nb::sig( - "def gather_mm(a: array, b: array, /, lhs_indices: array, rhs_indices: array, *, stream: Union[None, Stream, Device] = None) -> array"), + "def gather_mm(a: array, b: array, /, lhs_indices: array, rhs_indices: array, *, sorted_indices: bool = False, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( Matrix multiplication with matrix-level gather. @@ -4448,11 +4489,16 @@ void init_ops(nb::module_& m) { For ``b`` with shape ``(B1, B2, ..., BS, M, K)``, ``rhs_indices`` contains indices from the range ``[0, B1 * B2 * ... * BS)`` + If only one index is passed and it is sorted, the ``sorted_indices`` + flag can be passed for a possible faster implementation. + Args: a (array): Input array. b (array): Input array. lhs_indices (array, optional): Integer indices for ``a``. Default: ``None`` rhs_indices (array, optional): Integer indices for ``b``. Default: ``None`` + sorted_indices (bool, optional): May allow a faster implementation + if the passed indices are sorted. Default: ``False``. Returns: array: The output array. 
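Since the docstring above only states when ``sorted_indices`` applies, here is a minimal usage sketch mirroring the ``gather_sort`` / ``scatter_unsort`` helpers exercised in ``test_gather_qmm_sorted`` below. Shapes follow that test's convention (``x: (N, 1, 1, K)``, ``w: (E, M, K)``, ``indices: (N, I)``); the helper name ``sorted_gather_mm`` is illustrative, not part of the API.

    import mlx.core as mx

    def sorted_gather_mm(x, w, indices):
        # Sort tokens by expert id, run the gathered matmul with
        # sorted_indices=True, then restore the original order.
        N, I = indices.shape
        flat = indices.flatten()
        order = mx.argsort(flat)
        inv_order = mx.argsort(order)
        xs = x.flatten(0, -3)[order // I]   # (N * I, 1, K), grouped by expert
        idx = flat[order]                   # expert ids in sorted order
        y = mx.gather_mm(
            xs, w.swapaxes(-1, -2), rhs_indices=idx, sorted_indices=True
        )                                   # (N * I, 1, M)
        return mx.unflatten(y[inv_order], 0, (N, I))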
diff --git a/python/src/transforms.cpp b/python/src/transforms.cpp index 4a5e2e6ac..c47942b72 100644 --- a/python/src/transforms.cpp +++ b/python/src/transforms.cpp @@ -960,6 +960,11 @@ class PyCustomFunction { }; int py_custom_function_tp_traverse(PyObject* self, visitproc visit, void* arg) { + Py_VISIT(Py_TYPE(self)); + if (!nb::inst_ready(self)) { + return 0; + } + auto* p = nb::inst_ptr(self); nb::handle v = nb::find(p->fun_); Py_VISIT(v.ptr()); @@ -975,7 +980,6 @@ int py_custom_function_tp_traverse(PyObject* self, visitproc visit, void* arg) { nb::handle v = nb::find(*(p->vmap_fun_)); Py_VISIT(v.ptr()); } - Py_VISIT(Py_TYPE(self)); return 0; } int py_custom_function_tp_clear(PyObject* self) { diff --git a/python/tests/test_array.py b/python/tests/test_array.py index 67dc7c84b..fa5784ea9 100644 --- a/python/tests/test_array.py +++ b/python/tests/test_array.py @@ -1508,6 +1508,7 @@ class TestArray(mlx_tests.MLXTestCase): ("prod", 1), ("min", 1), ("max", 1), + ("logcumsumexp", 1), ("logsumexp", 1), ("mean", 1), ("var", 1), diff --git a/python/tests/test_blas.py b/python/tests/test_blas.py index 8b7fb462d..6fca4885b 100644 --- a/python/tests/test_blas.py +++ b/python/tests/test_blas.py @@ -1108,7 +1108,7 @@ class TestBlas(mlx_tests.MLXTestCase): lhs_indices_ = mx.broadcast_to(lhs_indices, (3, 2)) rhs_indices_ = mx.broadcast_to(rhs_indices, (3, 2)) M = a.shape[-2] - N = b.shape[-2] + N = b.shape[-1] K = a.shape[-1] a = a.reshape((-1, M, K)) diff --git a/python/tests/test_fft.py b/python/tests/test_fft.py index ec9a48f00..c887cd968 100644 --- a/python/tests/test_fft.py +++ b/python/tests/test_fft.py @@ -194,6 +194,11 @@ class TestFFT(mlx_tests.MLXTestCase): r_np = np.fft.ifft(segment, n=n_fft) self.assertTrue(np.allclose(r, r_np, atol=1e-5, rtol=1e-5)) + def test_fft_throws(self): + x = mx.array(3.0) + with self.assertRaises(ValueError): + mx.fft.irfftn(x) + if __name__ == "__main__": unittest.main() diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index a71d2c253..31ea79345 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -1857,6 +1857,30 @@ class TestOps(mlx_tests.MLXTestCase): y = mx.as_strided(x, (x.size,), (-1,), x.size - 1) self.assertTrue(mx.array_equal(y, x[::-1])) + def test_logcumsumexp(self): + npop = np.logaddexp.accumulate + mxop = mx.logcumsumexp + + a_npy = np.random.randn(32, 32, 32).astype(np.float32) + a_mlx = mx.array(a_npy) + + for axis in (0, 1, 2): + c_npy = npop(a_npy, axis=axis) + c_mlx = mxop(a_mlx, axis=axis) + self.assertTrue(np.allclose(c_npy, c_mlx, rtol=1e-3, atol=1e-3)) + + edge_cases_npy = [ + np.float32([-float("inf")] * 8), + np.float32([-float("inf"), 0, -float("inf")]), + np.float32([-float("inf"), float("inf"), -float("inf")]), + ] + edge_cases_mlx = [mx.array(a) for a in edge_cases_npy] + + for a_npy, a_mlx in zip(edge_cases_npy, edge_cases_mlx): + c_npy = npop(a_npy, axis=0) + c_mlx = mxop(a_mlx, axis=0) + self.assertTrue(np.allclose(c_npy, c_mlx, rtol=1e-3, atol=1e-3)) + def test_scans(self): a_npy = np.random.randn(32, 32, 32).astype(np.float32) a_mlx = mx.array(a_npy) @@ -2910,6 +2934,35 @@ class TestOps(mlx_tests.MLXTestCase): out = a[::-1] self.assertTrue(mx.array_equal(out[-1, :], a[0, :])) + def test_complex_ops(self): + x = mx.array( + [ + 3.0 + 4.0j, + -5.0 + 12.0j, + -8.0 + 0.0j, + 0.0 + 9.0j, + 0.0 + 0.0j, + ] + ) + + ops = ["arccos", "arcsin", "arctan", "square", "sqrt"] + for op in ops: + with self.subTest(op=op): + np_op = getattr(np, op) + mx_op = getattr(mx, op) + 
self.assertTrue(np.allclose(mx_op(x), np_op(x))) + + x = mx.array( + [ + 3.0 + 4.0j, + -5.0 + 12.0j, + -8.0 + 0.0j, + 0.0 + 9.0j, + 9.0 + 1.0j, + ] + ) + self.assertTrue(np.allclose(mx.rsqrt(x), 1.0 / np.sqrt(x))) + if __name__ == "__main__": unittest.main() diff --git a/python/tests/test_quantized.py b/python/tests/test_quantized.py index 160eb6400..eeefcd94f 100644 --- a/python/tests/test_quantized.py +++ b/python/tests/test_quantized.py @@ -174,12 +174,14 @@ class TestQuantized(mlx_tests.MLXTestCase): tests = product( [128, 64, 32], # group_size [2, 3, 4, 6, 8], # bits - [128, 256], # M + [32, 128, 256], # M [128, 256, 67], # N [0, 1, 3, 8], # B ) for group_size, bits, M, N, B in tests: with self.subTest(shape=(B, M, N), group_size=group_size, bits=bits): + if M < group_size: + continue x_shape = (1, N) if B == 0 else (B, 1, N) w_shape = (N, M) if B == 0 else (B, N, M) x = mx.random.normal(shape=x_shape, key=k1) @@ -448,6 +450,7 @@ class TestQuantized(mlx_tests.MLXTestCase): ) for kwargs in inputs: + test_shape(1, 32, 128, **kwargs) test_shape(32, 32, 256, **kwargs) test_shape(1, 32, 256, **kwargs) test_shape(32, 256, 32, transpose=False, **kwargs) @@ -486,6 +489,66 @@ class TestQuantized(mlx_tests.MLXTestCase): g2 = mx.grad(f_test)(x, qw, s, b, lhs_indices, rhs_indices) self.assertTrue(mx.allclose(g1, g2, atol=1e-4)) + def test_gather_qmm_sorted(self): + def quantize(w, transpose=True, group_size=64, bits=4): + qw, s, b = mx.quantize(w, group_size=group_size, bits=bits) + w_hat = mx.dequantize(qw, s, b, group_size=group_size, bits=bits) + if transpose: + w_hat = w_hat.swapaxes(-1, -2) + return w_hat, qw, s, b + + def gather_sort(x, indices): + N, M = indices.shape + indices = indices.flatten() + order = mx.argsort(indices) + inv_order = mx.argsort(order) + return x.flatten(0, -3)[order // M], indices[order], inv_order + + def scatter_unsort(x, inv_order, shape=None): + x = x[inv_order] + if shape is not None: + x = mx.unflatten(x, 0, shape) + return x + + parameters = [ + # L, K, D, E, I, transpose + (128, 1024, 1024, 32, 4, True), + (128, 1024, 544, 32, 4, True), + (433, 1024, 1024, 32, 4, True), + (433, 1024, 555, 32, 4, True), + (433, 2048, 1024, 32, 4, True), + (128, 1024, 1024, 32, 4, False), + (128, 1024, 544, 32, 4, False), + (433, 1024, 1024, 32, 4, False), + (433, 1024, 544, 32, 4, False), + (433, 1024, 555, 32, 4, False), + (433, 2048, 1024, 32, 4, False), + ] + for L, K, D, E, I, transpose in parameters: + K, D = (K, D) if transpose else (D, K) + ishape = (L, I) + xshape = (L, 1, 1, K) + wshape = (E, D, K) if transpose else (E, K, D) + + indices = (mx.random.uniform(shape=ishape) * E).astype(mx.uint32) + x = mx.random.normal(xshape) / K**0.5 + w = mx.random.normal(wshape) / K**0.5 + w, *wq = quantize(w, transpose=transpose) + + y1 = mx.gather_mm(x, w, rhs_indices=indices) + y2 = mx.gather_qmm(x, *wq, transpose=transpose, rhs_indices=indices) + xs, idx, inv_order = gather_sort(x, indices) + y3 = mx.gather_mm(xs, w, rhs_indices=idx, sorted_indices=True) + y4 = mx.gather_qmm( + xs, *wq, rhs_indices=idx, transpose=transpose, sorted_indices=True + ) + y3 = scatter_unsort(y3, inv_order, indices.shape) + y4 = scatter_unsort(y4, inv_order, indices.shape) + + self.assertTrue(mx.allclose(y1, y2, atol=1e-5)) + self.assertTrue(mx.allclose(y1, y3, atol=1e-5)) + self.assertTrue(mx.allclose(y1, y4, atol=1e-5)) + if __name__ == "__main__": unittest.main() diff --git a/python/tests/test_vmap.py b/python/tests/test_vmap.py index 1a1ba23b3..e571678d3 100644 --- a/python/tests/test_vmap.py +++ 
b/python/tests/test_vmap.py @@ -669,6 +669,57 @@ class TestVmap(mlx_tests.MLXTestCase): self.assertEqual(mx.vmap(fun, in_axes=(1,))(x).shape, (3, 8)) self.assertEqual(mx.vmap(fun, in_axes=(2,))(x).shape, (4, 6)) + def test_vmap_conv(self): + # vmap input only + x = mx.random.uniform(shape=(2, 2, 5, 4)) + w = mx.random.uniform(shape=(8, 3, 4)) + + expected = mx.stack([mx.conv1d(xi, w) for xi in x]) + out = mx.vmap(mx.conv1d, in_axes=(0, None))(x, w) + self.assertTrue(mx.allclose(expected, out)) + + x = mx.moveaxis(x, 0, 2) + out = mx.vmap(mx.conv1d, in_axes=(2, None))(x, w) + self.assertTrue(mx.allclose(expected, out)) + + # vmap weights only + x = mx.random.uniform(shape=(2, 5, 4)) + w = mx.random.uniform(shape=(3, 8, 3, 4)) + + expected = mx.stack([mx.conv1d(x, wi) for wi in w]) + out = mx.vmap(mx.conv1d, in_axes=(None, 0))(x, w) + self.assertTrue(mx.allclose(expected, out)) + + w = mx.moveaxis(w, 0, 1) + out = mx.vmap(mx.conv1d, in_axes=(None, 1))(x, w) + self.assertTrue(mx.allclose(expected, out)) + + # vmap weights and input + x = mx.random.uniform(shape=(3, 2, 5, 4)) + w = mx.random.uniform(shape=(3, 8, 3, 4)) + + expected = mx.stack([mx.conv1d(xi, wi) for xi, wi in zip(x, w)]) + out = mx.vmap(mx.conv1d, in_axes=(0, 0))(x, w) + self.assertTrue(mx.allclose(expected, out)) + + x = mx.random.uniform(shape=(2, 3, 5, 4)) + w = mx.random.uniform(shape=(8, 3, 4, 3)) + + expected = mx.stack([mx.conv1d(x[:, i], w[..., i]) for i in range(3)]) + out = mx.vmap(mx.conv1d, in_axes=(1, 3))(x, w) + self.assertTrue(mx.allclose(expected, out)) + + # Test with groups + x = mx.random.uniform(shape=(3, 2, 5, 8)) + w = mx.random.uniform(shape=(3, 2, 3, 4)) + + def gconv(x, w): + return mx.conv1d(x, w, groups=2) + + expected = mx.stack([gconv(xi, wi) for xi, wi in zip(x, w)]) + out = mx.vmap(gconv, in_axes=(0, 0))(x, w) + self.assertTrue(mx.allclose(expected, out)) + if __name__ == "__main__": unittest.main() diff --git a/tests/export_import_tests.cpp b/tests/export_import_tests.cpp index 83ee1e590..7ad2c640d 100644 --- a/tests/export_import_tests.cpp +++ b/tests/export_import_tests.cpp @@ -97,8 +97,7 @@ TEST_CASE("test export primitives with state") { TEST_CASE("test export functions with kwargs") { std::string file_path = get_temp_file("model.mlxfn"); - auto fun = - [](const std::map& kwargs) -> std::vector { + auto fun = [](const Kwargs& kwargs) -> std::vector { return {kwargs.at("x") + kwargs.at("y")}; }; diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp index 356515702..de0f3352c 100644 --- a/tests/ops_tests.cpp +++ b/tests/ops_tests.cpp @@ -3874,3 +3874,41 @@ TEST_CASE("test contiguous") { CHECK(x.flags().col_contiguous); CHECK_EQ(x.strides(), decltype(x.strides()){1, 2}); } + +TEST_CASE("test bitwise shift operations") { + std::vector dtypes = { + int8, int16, int32, int64, uint8, uint16, uint32, uint64}; + + for (const auto& dtype : dtypes) { + array x = full({4}, 1, dtype); + array y = full({4}, 2, dtype); + + auto left_shift_result = left_shift(x, y); + CHECK_EQ(left_shift_result.dtype(), dtype); + CHECK(array_equal(left_shift_result, array({4, 4, 4, 4}, dtype)) + .item()); + + auto right_shift_result = right_shift(full({4}, 4, dtype), y); + CHECK_EQ(right_shift_result.dtype(), dtype); + CHECK(array_equal(right_shift_result, full({4}, 1, dtype)).item()); + } + + array x = array({127, -128}, int8); + array y = array({1, 1}, int8); + auto left_shift_result = left_shift(x, y); + auto right_shift_result = right_shift(x, y); + + CHECK(array_equal(left_shift_result, array({-2, 0}, int8)).item()); + 
CHECK(array_equal(right_shift_result, array({63, -64}, int8)).item()); + + array x_bool = full({4}, true, bool_); + array y_bool = full({4}, true, bool_); + auto left_shift_bool_result = left_shift(x_bool, y_bool); + auto right_shift_bool_result = right_shift(x_bool, y_bool); + + CHECK_EQ(left_shift_bool_result.dtype(), uint8); + CHECK(array_equal(left_shift_bool_result, full({4}, 2, uint8)).item()); + + CHECK_EQ(right_shift_bool_result.dtype(), uint8); + CHECK(array_equal(right_shift_bool_result, full({4}, 0, uint8)).item()); +} \ No newline at end of file
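For reference, a small usage sketch of the new ``mx.logcumsumexp`` op, consistent with the behaviour checked in ``test_logcumsumexp`` above; NumPy is used only to provide the reference result.

    import mlx.core as mx
    import numpy as np

    a_np = np.random.randn(8, 16).astype(np.float32)
    a = mx.array(a_np)

    # Cumulative logsumexp along axis 1; matches np.logaddexp.accumulate.
    out = mx.logcumsumexp(a, axis=1)
    ref = np.logaddexp.accumulate(a_np, axis=1)
    assert np.allclose(out, ref, rtol=1e-3, atol=1e-3)

    # reverse and inclusive follow the same convention as the other scans.
    rev = mx.logcumsumexp(a, axis=1, reverse=True)
    excl = mx.logcumsumexp(a, axis=1, inclusive=False)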