diff --git a/docs/src/dev/extensions.rst b/docs/src/dev/extensions.rst index 9a2be90cd..ecb418468 100644 --- a/docs/src/dev/extensions.rst +++ b/docs/src/dev/extensions.rst @@ -486,9 +486,8 @@ below. std::ostringstream kname; kname << "axpby_" << "general_" << type_to_name(out); - // Make sure the metal library is available and look for it - // in the same folder as this executable if needed - d.register_library("mlx_ext", metal::get_colocated_mtllib_path); + // Make sure the metal library is available + d.register_library("mlx_ext"); // Make a kernel from this metal library auto kernel = d.get_kernel(kname.str(), "mlx_ext"); diff --git a/examples/extensions/axpby/axpby.cpp b/examples/extensions/axpby/axpby.cpp index 57c7d7900..c8ba7c239 100644 --- a/examples/extensions/axpby/axpby.cpp +++ b/examples/extensions/axpby/axpby.cpp @@ -249,9 +249,8 @@ void Axpby::eval_gpu( kname << (contiguous_kernel ? "contiguous_" : "general_"); kname << type_to_name(out); - // Make sure the metal library is available and look for it - // in the same folder as this executable if needed - d.register_library("mlx_ext", metal::get_colocated_mtllib_path); + // Make sure the metal library is available + d.register_library("mlx_ext"); // Make a kernel from this metal library auto kernel = d.get_kernel(kname.str(), "mlx_ext"); diff --git a/mlx/backend/metal/CMakeLists.txt b/mlx/backend/metal/CMakeLists.txt index 2a58df9cc..38bb1589d 100644 --- a/mlx/backend/metal/CMakeLists.txt +++ b/mlx/backend/metal/CMakeLists.txt @@ -114,6 +114,7 @@ if (MLX_METAL_JIT) kernels/steel/conv/loaders/loader_general.h ) make_jit_source(quantized) + make_jit_source(gemv_masked) else() target_sources( mlx @@ -149,6 +150,7 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp ${CMAKE_CURRENT_SOURCE_DIR}/ternary.cpp ${CMAKE_CURRENT_SOURCE_DIR}/unary.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp ) if (NOT MLX_METAL_PATH) diff --git a/mlx/backend/metal/device.cpp b/mlx/backend/metal/device.cpp index 9ef15497b..2ab0fa960 100644 --- a/mlx/backend/metal/device.cpp +++ b/mlx/backend/metal/device.cpp @@ -14,7 +14,6 @@ #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/metal.h" #include "mlx/backend/metal/metal_impl.h" -#include "mlx/backend/metal/mps/gemm.h" #include "mlx/backend/metal/utils.h" namespace fs = std::filesystem; @@ -39,6 +38,20 @@ constexpr auto get_metal_version() { #endif } +std::string get_colocated_mtllib_path(const std::string& lib_name) { + Dl_info info; + std::string mtllib_path; + std::string lib_ext = lib_name + ".metallib"; + + int success = dladdr((void*)get_colocated_mtllib_path, &info); + if (success) { + auto mtllib = fs::path(info.dli_fname).remove_filename() / lib_ext; + mtllib_path = mtllib.c_str(); + } + + return mtllib_path; +} + auto load_device() { auto devices = MTL::CopyAllDevices(); auto device = static_cast(devices->object(0)) @@ -126,6 +139,49 @@ MTL::Library* load_library( } // namespace +CommandEncoder::CommandEncoder(MTL::CommandBuffer* cbuf) : cbuf(cbuf) { + enc = cbuf->computeCommandEncoder(MTL::DispatchTypeConcurrent); + enc->retain(); +} + +CommandEncoder::~CommandEncoder() { + enc->endEncoding(); + enc->release(); +} + +void CommandEncoder::set_input_array( + const array& a, + int idx, + int64_t offset /* = 0 */) { + auto r_buf = static_cast(const_cast(a.buffer().ptr())); + if (auto it = outputs.find(r_buf); it != outputs.end()) { + // Insert a barrier + enc->memoryBarrier(&r_buf, 1); + + // Remove the output + outputs.erase(it); + } + auto a_buf = static_cast(a.buffer().ptr()); + 
auto base_offset = a.data() - + static_cast(const_cast(a_buf)->contents()); + base_offset += offset; + enc->setBuffer(a_buf, base_offset, idx); +} + +void CommandEncoder::set_output_array( + array& a, + int idx, + int64_t offset /* = 0 */) { + // Add barriers before adding the output to the output set + set_input_array(a, idx, offset); + auto buf = static_cast(a.buffer().ptr()); + if (concurrent) { + concurrent_outputs.insert(buf); + } else { + outputs.insert(buf); + } +} + void CommandEncoder::dispatchThreadgroups( MTL::Size grid_dims, MTL::Size group_dims) { @@ -255,13 +311,9 @@ void Device::register_library( } } -void Device::register_library( - const std::string& lib_name, - const std::function& lib_path_func) { +void Device::register_library(const std::string& lib_name) { if (auto it = library_map_.find(lib_name); it == library_map_.end()) { - std::string new_lib_path = lib_path_func(lib_name); - auto new_lib = load_library(device_, lib_name, new_lib_path.c_str()); - library_map_.insert({lib_name, new_lib}); + register_library(lib_name, get_colocated_mtllib_path(lib_name)); } } @@ -271,7 +323,7 @@ MTL::Library* Device::get_library_cache_(const std::string& lib_name) { if (auto it = library_map_.find(lib_name); it != library_map_.end()) { mtl_lib = it->second; } else { // Look for metallib alongside library - register_library(lib_name); + register_library(lib_name, get_colocated_mtllib_path(lib_name)); mtl_lib = library_map_[lib_name]; } diff --git a/mlx/backend/metal/device.h b/mlx/backend/metal/device.h index 047b2735a..a28dd832e 100644 --- a/mlx/backend/metal/device.h +++ b/mlx/backend/metal/device.h @@ -9,38 +9,16 @@ #include #include -#include -#include - #include "mlx/array.h" #include "mlx/device.h" -namespace fs = std::filesystem; - namespace mlx::core::metal { -inline std::string get_colocated_mtllib_path(const std::string& lib_name) { - Dl_info info; - std::string mtllib_path; - std::string lib_ext = lib_name + ".metallib"; - - int success = dladdr((void*)get_colocated_mtllib_path, &info); - if (success) { - auto mtllib = fs::path(info.dli_fname).remove_filename() / lib_ext; - mtllib_path = mtllib.c_str(); - } - - return mtllib_path; -} - using MTLFCList = std::vector>; struct CommandEncoder { - CommandEncoder(MTL::CommandBuffer* cbuf) : cbuf(cbuf) { - enc = cbuf->computeCommandEncoder(MTL::DispatchTypeConcurrent); - enc->retain(); - }; + CommandEncoder(MTL::CommandBuffer* cbuf); CommandEncoder(const CommandEncoder&) = delete; CommandEncoder& operator=(const CommandEncoder&) = delete; @@ -63,34 +41,8 @@ struct CommandEncoder { return enc; } - void set_input_array(const array& a, int idx, int64_t offset = 0) { - auto r_buf = - static_cast(const_cast(a.buffer().ptr())); - if (auto it = outputs.find(r_buf); it != outputs.end()) { - // Insert a barrier - enc->memoryBarrier(&r_buf, 1); - - // Remove the output - outputs.erase(it); - } - auto a_buf = static_cast(a.buffer().ptr()); - auto base_offset = a.data() - - static_cast(const_cast(a_buf)->contents()); - base_offset += offset; - enc->setBuffer(a_buf, base_offset, idx); - } - - void set_output_array(array& a, int idx, int64_t offset = 0) { - // Add barriers before adding the output to the output set - set_input_array(a, idx, offset); - auto buf = static_cast(a.buffer().ptr()); - if (concurrent) { - concurrent_outputs.insert(buf); - } else { - outputs.insert(buf); - } - } - + void set_input_array(const array& a, int idx, int64_t offset = 0); + void set_output_array(array& a, int idx, int64_t offset = 0); void 
dispatchThreadgroups(MTL::Size grid_dims, MTL::Size group_dims); void dispatchThreads(MTL::Size grid_dims, MTL::Size group_dims); @@ -98,10 +50,7 @@ struct CommandEncoder { return ConcurrentContext(*this); } - ~CommandEncoder() { - enc->endEncoding(); - enc->release(); - } + ~CommandEncoder(); private: void maybe_split(); @@ -136,10 +85,8 @@ class Device { void register_library( const std::string& lib_name, const std::string& lib_path); - void register_library( - const std::string& lib_name, - const std::function& lib_path_func = - get_colocated_mtllib_path); + + void register_library(const std::string& lib_name); MTL::Library* get_library(const std::string& name); diff --git a/mlx/backend/metal/fft.cpp b/mlx/backend/metal/fft.cpp index de0edce51..8fb2c9377 100644 --- a/mlx/backend/metal/fft.cpp +++ b/mlx/backend/metal/fft.cpp @@ -1,4 +1,4 @@ -// Copyright © 2023 Apple Inc. +// Copyright © 2024 Apple Inc. #include #include #include @@ -12,8 +12,6 @@ #include "mlx/backend/metal/slicing.h" #include "mlx/backend/metal/unary.h" #include "mlx/backend/metal/utils.h" -#include "mlx/mlx.h" -#include "mlx/primitives.h" #include "mlx/utils.h" namespace mlx::core { @@ -786,10 +784,9 @@ void nd_fft_op( fft_op(in_arr, out_arr, axis, inverse, step_real, inplace, s); } - std::vector copies = {temp1, temp2}; auto& d = metal::device(s.device); d.get_command_buffer(s.index)->addCompletedHandler( - [copies](MTL::CommandBuffer*) mutable { copies.clear(); }); + [temp_arrs](MTL::CommandBuffer*) mutable { temp_arrs.clear(); }); } void FFT::eval_gpu(const std::vector& inputs, array& out) { diff --git a/mlx/backend/metal/jit/gemv_masked.h b/mlx/backend/metal/jit/gemv_masked.h new file mode 100644 index 000000000..deae78865 --- /dev/null +++ b/mlx/backend/metal/jit/gemv_masked.h @@ -0,0 +1,25 @@ +// Copyright © 2024 Apple Inc. 
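+// Format-string template for the JIT-compiled masked GEMV kernels: the
+// {name}, {itype}, {outm_t}, {opm_t}, {trans} and block-size placeholders
+// are filled in with fmt-style named arguments by get_gemv_masked_kernel()
+// in jit_kernels.cpp.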
+ +constexpr std::string_view gemv_masked_kernel = R"( +template [[host_name("{name}")]] [[kernel]] void +gemv_{trans}masked<{itype}, {outm_t}, {opm_t}, {bm}, {bn}, {sm}, {sn}, {tm}, {tn}, {nc}>( + const device {itype}* mat [[buffer(0)]], + const device {itype}* in_vec [[buffer(1)]], + device {itype}* out_vec [[buffer(3)]], + const constant int& in_vec_size [[buffer(4)]], + const constant int& out_vec_size [[buffer(5)]], + const constant int& marix_ld [[buffer(6)]], + const constant int& batch_ndim [[buffer(9)]], + const constant int* batch_shape [[buffer(10)]], + const constant size_t* vector_batch_stride [[buffer(11)]], + const constant size_t* matrix_batch_stride [[buffer(12)]], + const device {outm_t}* out_mask [[buffer(20)]], + const device {opm_t}* mat_mask [[buffer(21)]], + const device {opm_t}* vec_mask [[buffer(22)]], + const constant int* mask_strides [[buffer(23)]], + const constant size_t* mask_batch_strides [[buffer(24)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]); +)"; diff --git a/mlx/backend/metal/jit/includes.h b/mlx/backend/metal/jit/includes.h index aca2c683e..89b025c66 100644 --- a/mlx/backend/metal/jit/includes.h +++ b/mlx/backend/metal/jit/includes.h @@ -33,5 +33,6 @@ const char* steel_gemm_splitk(); const char* conv(); const char* steel_conv(); const char* steel_conv_general(); +const char* gemv_masked(); } // namespace mlx::core::metal diff --git a/mlx/backend/metal/jit_kernels.cpp b/mlx/backend/metal/jit_kernels.cpp index 12e843618..895ad143d 100644 --- a/mlx/backend/metal/jit_kernels.cpp +++ b/mlx/backend/metal/jit_kernels.cpp @@ -4,6 +4,7 @@ #include "mlx/backend/common/compiled.h" #include "mlx/backend/metal/jit/arange.h" #include "mlx/backend/metal/jit/copy.h" +#include "mlx/backend/metal/jit/gemv_masked.h" #include "mlx/backend/metal/jit/includes.h" #include "mlx/backend/metal/jit/reduce.h" #include "mlx/backend/metal/jit/scan.h" @@ -50,10 +51,12 @@ MTL::ComputePipelineState* get_unary_kernel( std::ostringstream kernel_source; auto u_def = get_template_definition( "v" + lib_name, "unary_v", get_type_string(out_type), op); + auto u2_def = get_template_definition( + "v2" + lib_name, "unary_v2", get_type_string(out_type), op); auto g_def = get_template_definition( "g" + lib_name, "unary_g", get_type_string(out_type), op); kernel_source << metal::utils() << metal::unary_ops() << metal::unary() - << u_def << g_def; + << u_def << u2_def << g_def; lib = d.get_library(lib_name, kernel_source.str()); } return d.get_kernel(kernel_name, lib); @@ -70,6 +73,9 @@ void add_binary_kernels( {"vs", "binary_vs"}, {"sv", "binary_sv"}, {"vv", "binary_vv"}, + {"vs2", "binary_vs2"}, + {"sv2", "binary_sv2"}, + {"vv2", "binary_vv2"}, {"g1", "binary_g_nd1"}, {"g2", "binary_g_nd2"}, {"g3", "binary_g_nd3"}, @@ -146,6 +152,7 @@ MTL::ComputePipelineState* get_ternary_kernel( std::ostringstream kernel_source; const std::map kernel_types = { {"v", "ternary_v"}, + {"v2", "ternary_v2"}, {"g", "ternary_g"}, {"g1", "ternary_g_nd1"}, {"g2", "ternary_g_nd2"}, @@ -496,6 +503,49 @@ MTL::ComputePipelineState* get_steel_gemm_masked_kernel( return d.get_kernel(kernel_name, lib); } +MTL::ComputePipelineState* get_gemv_masked_kernel( + metal::Device& d, + const std::string& kernel_name, + const array& out, + const std::optional& mask_out, + const std::optional& mask_op, + bool transpose_mat, + int bm, + int bn, + int sm, + int sn, + int tm, + int tn, + 
bool contiguous) { + const auto& lib_name = kernel_name; + auto lib = d.get_library(lib_name); + if (lib == nullptr) { + std::ostringstream kernel_source; + auto out_mask_type = mask_out.has_value() + ? get_type_string((*mask_out).dtype()) + : "nomask_t"; + auto op_mask_type = + mask_op.has_value() ? get_type_string((*mask_op).dtype()) : "nomask_t"; + kernel_source << metal::utils() << metal::gemv_masked() + << fmt::format( + gemv_masked_kernel, + "name"_a = lib_name, + "itype"_a = get_type_string(out.dtype()), + "outm_t"_a = out_mask_type, + "opm_t"_a = op_mask_type, + "bm"_a = bm, + "bn"_a = bn, + "sm"_a = sm, + "sn"_a = sn, + "tm"_a = tm, + "tn"_a = tn, + "trans"_a = transpose_mat ? "t_" : "", + "nc"_a = contiguous ? "0" : "1"); + lib = d.get_library(lib_name, kernel_source.str()); + } + return d.get_kernel(kernel_name, lib); +} + MTL::ComputePipelineState* get_steel_conv_kernel( metal::Device& d, const std::string& kernel_name, diff --git a/mlx/backend/metal/kernels.h b/mlx/backend/metal/kernels.h index 40f548868..44f260777 100644 --- a/mlx/backend/metal/kernels.h +++ b/mlx/backend/metal/kernels.h @@ -151,6 +151,21 @@ MTL::ComputePipelineState* get_steel_conv_kernel( int n_channel_specialization, bool small_filter); +MTL::ComputePipelineState* get_gemv_masked_kernel( + metal::Device& d, + const std::string& kernel_name, + const array& out, + const std::optional& mask_out, + const std::optional& mask_op, + bool transpose_mat, + int bm, + int bn, + int sm, + int sn, + int tm, + int tn, + bool contiguous); + MTL::ComputePipelineState* get_steel_conv_general_kernel( metal::Device& d, const std::string& kernel_name, diff --git a/mlx/backend/metal/kernels/CMakeLists.txt b/mlx/backend/metal/kernels/CMakeLists.txt index 4834c1f78..8699565a5 100644 --- a/mlx/backend/metal/kernels/CMakeLists.txt +++ b/mlx/backend/metal/kernels/CMakeLists.txt @@ -38,7 +38,6 @@ endfunction(build_kernel) build_kernel(arg_reduce) build_kernel(conv steel/conv/params.h) build_kernel(gemv steel/utils.h) -build_kernel(gemv_masked steel/utils.h) build_kernel(layer_norm) build_kernel(random) build_kernel(rms_norm) @@ -121,6 +120,7 @@ build_kernel( steel/gemm/kernels/steel_gemm_splitk ${STEEL_HEADERS} ) +build_kernel(gemv_masked steel/utils.h) endif() diff --git a/mlx/backend/metal/kernels/gemv_masked.h b/mlx/backend/metal/kernels/gemv_masked.h new file mode 100644 index 000000000..1dd436e32 --- /dev/null +++ b/mlx/backend/metal/kernels/gemv_masked.h @@ -0,0 +1,819 @@ +// Copyright © 2023-2024 Apple Inc. 
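+// Masked GEMV / GEMVT kernel implementations, factored out of
+// gemv_masked.metal so that the same source can be used both by the
+// ahead-of-time compiled metallib and by the JIT path
+// (make_jit_source(gemv_masked) / metal::gemv_masked()).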
+ +#include "mlx/backend/metal/kernels/steel/utils.h" + +using namespace metal; + +#define MLX_MTL_CONST static constant constexpr const +#define MLX_MTL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)") + +struct _NoMask { + char x; + + constexpr METAL_FUNC operator bool() { + return true; + } + constexpr METAL_FUNC operator bool() const threadgroup { + return true; + } + constexpr METAL_FUNC operator bool() const device { + return true; + } + constexpr METAL_FUNC operator bool() const constant { + return true; + } +}; + +typedef struct _NoMask nomask_t; + +template +struct ScaleOp { + OutT scale; + + METAL_FUNC OutT apply(InT x) const { + return static_cast(x) * scale; + } +}; + +template < + typename T, + typename out_mask_t, + typename op_mask_t, + const int BM, /* Threadgroup rows (in simdgroups) */ + const int BN, /* Threadgroup cols (in simdgroups) */ + const int SM, /* Simdgroup rows (in threads) */ + const int SN, /* Simdgroup cols (in threads) */ + const int TM, /* Thread rows (in elements) */ + const int TN> /* Thread cols (in elements) */ +struct GEMVKernel { + MLX_MTL_CONST int threadsM = BM * SM; + MLX_MTL_CONST int threadsN = BN * SN; + + MLX_MTL_CONST int blockM = threadsM * TM; + MLX_MTL_CONST int blockN = threadsN * TN; + + static_assert(SM * SN == 32, "simdgroup can only have 32 threads"); + + static_assert( + SN == 8 || SN == 16 || SN == 32, + "gemv block must have a width of 8, 16, or 32"); + + static_assert(blockN >= blockM, "Masked gemv must have blockN >= blockM"); + + MLX_MTL_CONST bool has_operand_mask = !metal::is_same_v; + MLX_MTL_CONST bool has_output_mask = !metal::is_same_v; + + MLX_MTL_CONST bool has_mul_operand_mask = + has_operand_mask && !metal::is_same_v; + MLX_MTL_CONST bool has_mul_output_mask = + has_output_mask && !metal::is_same_v; + + // - The matrix of size (M = out_vec_size, K = in_vec_size) is divided up + // into blocks of (blockM, blockN) divided among threadgroups + // - Every thread works on a block of (TM, TN) + // - We assume each threadgroup has (threadsN, threadsM, 1) threads + // + // 1. A thread loads TN elements each from mat along TM rows + // and the corresponding scalar from the vector + // 2. The thread then multiplies and adds to accumulate its local result for + // the block + // 3. At the end, each thread has accumulated results over all blocks across + // the rows. These are then summed up across the threadgroup + // 4. Each threadgroup writes its accumulated blockM outputs + // + // Edge case handling: + // - The threadgroup with the largest tid has blocks that exceed the matrix + // * The blocks that start outside the matrix are never read (thread results + // remain zero) + // * The last thread that partially overlaps with the matrix is shifted + // inwards such that the thread block fits exactly in the matrix + + MLX_MTL_CONST short tgp_mem_size = BN > 1 ? BN*(blockM + TM) : 0; + MLX_MTL_CONST bool needs_tgp_reduction = BN > 1; + + static METAL_FUNC void + load_unsafe(const device T* src, thread T dst[TN], const int src_offset = 0) { + MLX_MTL_PRAGMA_UNROLL + for (int tn = 0; tn < TN; tn++) { + dst[tn] = src[src_offset + tn]; + } + } + + static METAL_FUNC void load_safe( + const device T* src, + thread T dst[TN], + const int src_offset = 0, + const int src_size = TN) { + if (src_offset + TN <= src_size) { + MLX_MTL_PRAGMA_UNROLL + for (int tn = 0; tn < TN; tn++) { + dst[tn] = src[src_offset + tn]; + } + } else { // Edgecase + MLX_MTL_PRAGMA_UNROLL + for (int tn = 0; tn < TN; tn++) { + dst[tn] = src_offset + tn < src_size ? 
src[src_offset + tn] : 0; + } + } + } + + static METAL_FUNC void run( + const device T* mat [[buffer(0)]], + const device T* in_vec [[buffer(1)]], + device T* out_vec [[buffer(3)]], + const constant int& in_vec_size [[buffer(4)]], + const constant int& out_vec_size [[buffer(5)]], + const constant int& matrix_ld [[buffer(6)]], + const device out_mask_t* out_mask [[buffer(20)]], + const device op_mask_t* mat_mask [[buffer(21)]], + const device op_mask_t* vec_mask [[buffer(22)]], + const constant int* mask_strides [[buffer(23)]], + threadgroup T* tgp_memory [[threadgroup(0)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + // Appease compiler + (void)lid; + + // Thread local accumulation results + thread T result[TM] = {0}; + thread T inter[TN]; + thread T v_coeff[TN]; + + const int thrM = SN != 32 ? simd_lid / SN : 0; + const int thrN = SN != 32 ? simd_lid % SN : int(simd_lid); + + const int sgN = BN != 1 ? (simd_gid % BN) : 0; + + const int simdM = BN != 1 ? SM * (simd_gid / BN) : int(SM * simd_gid); + const int simdN = BN != 1 ? SN * (simd_gid % BN) : 0; + + int bm = (simdM + thrM) * TM; + int bn = (simdN + thrN) * TN; + + // Block position + int out_row = tid.x * blockM + bm; + + // Exit simdgroup if rows out of bound + if (out_row >= out_vec_size) + return; + + // Adjust tail simdgroup to ensure in bound reads + out_row = out_row + TM <= out_vec_size ? out_row : out_vec_size - TM; + + // Prepare mask offsets + const constant int* out_mask_strides = mask_strides; + const constant int* mat_mask_strides = + mask_strides + (has_output_mask ? 2 : 0); + const constant int* vec_mask_strides = + mat_mask_strides + (has_operand_mask ? 2 : 0); + + const int m_block_idx = blockN > blockM ? out_row / blockN : int(tid.x); + + const int out_mask_offset = + !has_output_mask ? 0 : m_block_idx * out_mask_strides[1]; + + int mat_mask_offset = + !has_operand_mask ? 0 : m_block_idx * mat_mask_strides[1]; + int vec_mask_offset = 0; + const int mat_mask_step = !has_operand_mask ? 0 : mat_mask_strides[0]; + const int vec_mask_step = !has_operand_mask ? 
0 : vec_mask_strides[1]; + + T out_scale{1}; + + // Check output mask + if (has_output_mask) { + auto mask_out = out_mask[out_mask_offset]; + + // Write zeros and return if mask is 0 + if (!mask_out) { + if (simdN == 0 && thrN == 0) { + MLX_MTL_PRAGMA_UNROLL + for (int tm = 0; tm < TM; tm++) { + out_vec[out_row + tm] = T(0.); + } + } + + return; + } + + // Store scalar if multiplicative mask + if (has_mul_output_mask) { + out_scale = T(mask_out); + } + } + + // Advance matrix + mat += out_row * matrix_ld; + + // Prepare for loop + constexpr const uniform loop_stride = make_uniform(blockN); + const uniform in_size = make_uniform(in_vec_size); + const uniform n_iter = in_size / loop_stride; + const uniform last_iter = loop_stride * n_iter; + const uniform leftover = in_size - last_iter; + + // Loop over in_vec in blocks of blockN + for (int i = 0; i < n_iter; ++i) { + if (!has_operand_mask || + (bool(mat_mask[mat_mask_offset]) && + bool(vec_mask[vec_mask_offset]))) { + T block_scale{1}; + if (has_mul_operand_mask) { + block_scale = + T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]); + } + + load_unsafe(in_vec, v_coeff, bn); + + // Apply scale + if (has_mul_operand_mask) { + MLX_MTL_PRAGMA_UNROLL + for (int tn = 0; tn < TN; tn++) { + v_coeff[tn] *= block_scale; + } + } + + // Per thread work loop + int mat_offset = 0; + MLX_MTL_PRAGMA_UNROLL + for (int tm = 0; tm < TM; tm++) { + // Load for the row + load_unsafe(mat, inter, mat_offset + bn); + + // Accumulate results + MLX_MTL_PRAGMA_UNROLL + for (int tn = 0; tn < TN; tn++) { + result[tm] += inter[tn] * v_coeff[tn]; + } + + mat_offset += matrix_ld; + } + } + + bn += blockN; + mat_mask_offset += mat_mask_step; + vec_mask_offset += vec_mask_step; + } + + if (leftover > 0 && + (!has_operand_mask || + (bool(mat_mask[mat_mask_offset]) && + bool(vec_mask[vec_mask_offset])))) { + T block_scale{1}; + if (has_mul_operand_mask) { + block_scale = + T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]); + } + + load_safe(in_vec, v_coeff, bn, in_size); + + // Apply scale + if (has_mul_operand_mask) { + MLX_MTL_PRAGMA_UNROLL + for (int tn = 0; tn < TN; tn++) { + v_coeff[tn] *= block_scale; + } + } + + // Per thread work loop + MLX_MTL_PRAGMA_UNROLL + for (int tm = 0; tm < TM; tm++) { + // Load for the row + load_safe(&mat[tm * matrix_ld], inter, bn, in_size); + + // Accumulate results + MLX_MTL_PRAGMA_UNROLL + for (int tn = 0; tn < TN; tn++) { + result[tm] += inter[tn] * v_coeff[tn]; + } + } + } + + // Apply out scale + if (has_mul_output_mask) { + MLX_MTL_PRAGMA_UNROLL + for (int tm = 0; tm < TM; tm++) { + result[tm] *= out_scale; + } + } + + // Simdgroup accumulations + MLX_MTL_PRAGMA_UNROLL + for (int tm = 0; tm < TM; tm++) { + MLX_MTL_PRAGMA_UNROLL + for (ushort sn = (SN / 2); sn >= 1; sn >>= 1) { + result[tm] += simd_shuffle_down(result[tm], sn); + } + } + + // Threadgroup accumulation results + if (needs_tgp_reduction) { + threadgroup T* tgp_results = tgp_memory + sgN * (blockM + TM) + bm; + if (thrN == 0) { + MLX_MTL_PRAGMA_UNROLL + for (int tm = 0; tm < TM; tm++) { + tgp_results[tm] = result[tm]; + } + + threadgroup_barrier(mem_flags::mem_none); + + if (sgN == 0) { + MLX_MTL_PRAGMA_UNROLL + for (int sgn = 1; sgn < BN; sgn++) { + MLX_MTL_PRAGMA_UNROLL + for (int tm = 0; tm < TM; tm++) { + result[tm] += tgp_results[sgn * (blockM + TM) + tm]; + } + } + } + } + } + + // Write outputs + if (simdN == 0 && thrN == 0) { + MLX_MTL_PRAGMA_UNROLL + for (int tm = 0; tm < TM; tm++) { + out_vec[out_row + tm] = result[tm]; + } + } + } +}; + 
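+// Example GEMVKernel sizing (illustrative, from the (2, 1, 4, 8, 4, 4)
+// instantiation in gemv_masked.metal): threadsM = BM * SM = 8,
+// threadsN = BN * SN = 8, blockM = threadsM * TM = 32 and
+// blockN = threadsN * TN = 32, so each threadgroup accumulates 32 output
+// rows while stepping over in_vec in chunks of 32 elements.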
+/////////////////////////////////////////////////////////////////////////////// +/// Vector matrix multiplication +/////////////////////////////////////////////////////////////////////////////// + +template < + typename T, + typename out_mask_t, + typename op_mask_t, + const int BM, /* Threadgroup rows (in simdgroups) */ + const int BN, /* Threadgroup cols (in simdgroups) */ + const int SM, /* Simdgroup rows (in threads) */ + const int SN, /* Simdgroup cols (in threads) */ + const int TM, /* Thread rows (in elements) */ + const int TN> /* Thread cols (in elements) */ +struct GEMVTKernel { + MLX_MTL_CONST int threadsM = BM * SM; + MLX_MTL_CONST int threadsN = BN * SN; + + MLX_MTL_CONST int blockM = threadsM * TM; + MLX_MTL_CONST int blockN = threadsN * TN; + + static_assert(SM * SN == 32, "simdgroup can only have 32 threads"); + + MLX_MTL_CONST bool has_operand_mask = !metal::is_same_v; + MLX_MTL_CONST bool has_output_mask = !metal::is_same_v; + + MLX_MTL_CONST bool has_mul_operand_mask = + has_operand_mask && !metal::is_same_v; + MLX_MTL_CONST bool has_mul_output_mask = + has_output_mask && !metal::is_same_v; + + // - The matrix of size (M = in_vec_size, N = out_vec_size) is divided up + // into blocks of (blockM, blockN) divided among threadgroups + // - Every thread works on a block of (TM, TN) + // - We assume each threadgroup has (threadsN, threadsM, 1) threads + // + // 1. A thread loads TN elements each from mat along TM contiguous rows + // and the corresponding scalar from the vector + // 2. The thread then accumulates its local result for the block + // 3. At the end, each thread has accumulated results over all blocks across + // the rows. These are then summed up across the threadgroup + // 4. Each threadgroup writes its accumulated BN * TN outputs + // + // Edge case handling: + // - The threadgroup with the largest tid has blocks that exceed the matrix + // * The blocks that start outside the matrix are never read (thread results + // remain zero) + // * The last thread that partially overlaps with the matrix is shifted + // inwards such that the thread block fits exactly in the matrix + + MLX_MTL_CONST short tgp_mem_size = BM > 1 ? BM*(blockN + TN) : 0; + MLX_MTL_CONST bool needs_tgp_reduction = BM > 1; + + static METAL_FUNC void run( + const device T* mat [[buffer(0)]], + const device T* in_vec [[buffer(1)]], + device T* out_vec [[buffer(3)]], + const constant int& in_vec_size [[buffer(4)]], + const constant int& out_vec_size [[buffer(5)]], + const constant int& marix_ld [[buffer(6)]], + const device out_mask_t* out_mask [[buffer(20)]], + const device op_mask_t* mat_mask [[buffer(21)]], + const device op_mask_t* vec_mask [[buffer(22)]], + const constant int* mask_strides [[buffer(23)]], + threadgroup T* tgp_memory [[threadgroup(0)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + // Appease compiler + (void)lid; + + // Thread local accumulation results + T result[TN] = {0}; + T inter[TN]; + T v_coeff[TM]; + + const int thrM = SN != 32 ? simd_lid / SN : 0; + const int thrN = SN != 32 ? simd_lid % SN : int(simd_lid); + + const int sgM = BN != 1 ? (simd_gid / BN) : int(simd_gid); + const int sgN = BN != 1 ? 
(simd_gid % BN) : 0; + + const int simdM = SM * sgM; + const int simdN = SN * sgN; + + int cm = (simdM + thrM); + int cn = (simdN + thrN); + + int bm = cm * TM; + int bn = cn * TN; + + int out_col = tid.x * blockN + bn; + + // Prepare mask offsets + const constant int* out_mask_strides = mask_strides; + const constant int* mat_mask_strides = + out_mask_strides + (has_output_mask ? 2 : 0); + const constant int* vec_mask_strides = + mat_mask_strides + (has_operand_mask ? 2 : 0); + + const int n_block_idx = blockM > blockN ? out_col / blockM : int(tid.x); + + const int out_mask_offset = + !has_output_mask ? 0 : n_block_idx; // * out_mask_strides[0]; + + int mat_mask_offset = + !has_operand_mask ? 0 : n_block_idx * mat_mask_strides[0]; + int vec_mask_offset = 0; + const int mat_mask_step = !has_operand_mask ? 0 : mat_mask_strides[1]; + const int vec_mask_step = !has_operand_mask ? 0 : vec_mask_strides[0]; + + T out_scale{1}; + + // Check output mask + if (has_output_mask) { + auto mask_out = out_mask[out_mask_offset]; + + // Write zeros and return if mask is 0 + if (!mask_out) { + if (cm == 0 && out_col < out_vec_size) { + if (out_col + TN <= out_vec_size) { + MLX_MTL_PRAGMA_UNROLL + for (int tn = 0; tn < TN; tn++) { + out_vec[out_col + tn] = T(0.); + } + } else { + for (int tn = 0; tn < TN && (out_col + tn) < out_vec_size; tn++) { + out_vec[out_col + tn] = T(0.); + } + } + } + + return; + } + + // Store scalar if multiplicative mask + if (has_mul_output_mask) { + out_scale = T(mask_out); + } + } + + // Prepare for loop + constexpr const uniform loop_stride = make_uniform(blockM); + const uniform in_size = make_uniform(in_vec_size); + const uniform n_iter = in_size / loop_stride; + const uniform last_iter = loop_stride * n_iter; + const uniform leftover = in_size - last_iter; + + // Edgecase handling + if (out_col < out_vec_size) { + out_col = (out_col + TN) <= out_vec_size ? 
out_col : out_vec_size - TN; + + // Per thread accumulation main loop + for (int i = 0; i < n_iter; ++i) { + // Adding a threadgroup_barrier improves performance slightly + // This is possibly it may help exploit cache better + threadgroup_barrier(mem_flags::mem_none); + + if (!has_operand_mask || + (bool(mat_mask[mat_mask_offset]) && + bool(vec_mask[vec_mask_offset]))) { + T block_scale{1}; + if (has_mul_operand_mask) { + block_scale = + T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]); + } + + MLX_MTL_PRAGMA_UNROLL + for (int tm = 0; tm < TM; tm++) { + v_coeff[tm] = in_vec[bm + tm]; + } + + // Apply scale + if (has_mul_operand_mask) { + MLX_MTL_PRAGMA_UNROLL + for (int tm = 0; tm < TM; tm++) { + v_coeff[tm] *= block_scale; + } + } + + MLX_MTL_PRAGMA_UNROLL + for (int tm = 0; tm < TM; tm++) { + for (int tn = 0; tn < TN; tn++) { + inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn]; + } + for (int tn = 0; tn < TN; tn++) { + result[tn] += v_coeff[tm] * inter[tn]; + } + } + } + + bm += blockM; + mat_mask_offset += mat_mask_step; + vec_mask_offset += vec_mask_step; + } + + if (leftover > 0 && + (!has_operand_mask || + (bool(mat_mask[mat_mask_offset]) && + bool(vec_mask[vec_mask_offset])))) { + T block_scale{1}; + if (has_mul_operand_mask) { + block_scale = + T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]); + } + + for (int tm = 0; tm < TM && bm + tm < in_vec_size; tm++) { + v_coeff[tm] = in_vec[bm + tm]; + + if (has_mul_operand_mask) { + v_coeff[tm] *= block_scale; + } + + MLX_MTL_PRAGMA_UNROLL + for (int tn = 0; tn < TN; tn++) { + inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn]; + } + + MLX_MTL_PRAGMA_UNROLL + for (int tn = 0; tn < TN; tn++) { + result[tn] += v_coeff[tm] * inter[tn]; + } + } + } + } + + // Apply out scale + if (has_mul_output_mask) { + MLX_MTL_PRAGMA_UNROLL + for (int tn = 0; tn < TN; tn++) { + result[tn] *= out_scale; + } + } + + // Simdgroup accumulations + MLX_MTL_PRAGMA_UNROLL + for (int tn = 0; tn < TN; tn++) { + MLX_MTL_PRAGMA_UNROLL + for (ushort sm = (SM / 2); sm >= 1; sm >>= 1) { + result[tn] += simd_shuffle_down(result[tn], SN * sm); + } + } + + // Threadgroup accumulation results + if (needs_tgp_reduction) { + threadgroup T* tgp_results = tgp_memory + sgM * (blockN + TN) + bn; + if (thrM == 0) { + MLX_MTL_PRAGMA_UNROLL + for (int tn = 0; tn < TN; tn++) { + tgp_results[tn] = result[tn]; + } + + threadgroup_barrier(mem_flags::mem_none); + + if (sgM == 0) { + MLX_MTL_PRAGMA_UNROLL + for (int sgm = 1; sgm < BM; sgm++) { + MLX_MTL_PRAGMA_UNROLL + for (int tn = 0; tn < TN; tn++) { + result[tn] += tgp_results[sgm * (blockN + TN) + tn]; + } + } + } + } + } + + // Threadgroup accumulation and writing out results + if (cm == 0 && out_col < out_vec_size) { + MLX_MTL_PRAGMA_UNROLL + for (int j = 0; j < TN; j++) { + out_vec[out_col + j] = result[j]; + } + } + } +}; + +/////////////////////////////////////////////////////////////////////////////// +/// Matrix vector multiplication +/////////////////////////////////////////////////////////////////////////////// + +template < + typename T, + typename out_mask_t, + typename op_mask_t, + const int BM, /* Threadgroup rows (in simdgroups) */ + const int BN, /* Threadgroup cols (in simdgroups) */ + const int SM, /* Simdgroup rows (in threads) */ + const int SN, /* Simdgroup cols (in threads) */ + const int TM, /* Thread rows (in elements) */ + const int TN, /* Thread cols (in elements) */ + const bool kDoNCBatch> /* Batch ndim > 1 */ +[[kernel, max_total_threads_per_threadgroup(BM* BN * 32)]] void 
gemv_masked( + const device T* mat [[buffer(0)]], + const device T* in_vec [[buffer(1)]], + device T* out_vec [[buffer(3)]], + const constant int& in_vec_size [[buffer(4)]], + const constant int& out_vec_size [[buffer(5)]], + const constant int& marix_ld [[buffer(6)]], + const constant int& batch_ndim [[buffer(9)]], + const constant int* batch_shape [[buffer(10)]], + const constant size_t* vector_batch_stride [[buffer(11)]], + const constant size_t* matrix_batch_stride [[buffer(12)]], + const device out_mask_t* out_mask [[buffer(20)]], + const device op_mask_t* mat_mask [[buffer(21)]], + const device op_mask_t* vec_mask [[buffer(22)]], + const constant int* mask_strides [[buffer(23)]], + const constant size_t* mask_batch_strides [[buffer(24)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + using gemv_kernel = + GEMVKernel; + threadgroup T tgp_memory + [gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size]; + + constexpr bool has_operand_mask = !metal::is_same_v; + constexpr bool has_output_mask = !metal::is_same_v; + + // Update batch offsets + if (kDoNCBatch) { + in_vec += elem_to_loc(tid.z, batch_shape, vector_batch_stride, batch_ndim); + mat += elem_to_loc(tid.z, batch_shape, matrix_batch_stride, batch_ndim); + + if (has_output_mask) { + out_mask += + elem_to_loc(tid.z, batch_shape, mask_batch_strides, batch_ndim); + mask_batch_strides += batch_ndim; + } + + if (has_operand_mask) { + const constant size_t* mask_strides_mat = mask_batch_strides; + const constant size_t* mask_strides_vec = mask_strides_mat + batch_ndim; + + ulong2 batch_offsets = elem_to_loc_broadcast( + tid.z, batch_shape, mask_strides_mat, mask_strides_vec, batch_ndim); + + mat_mask += batch_offsets.x; + vec_mask += batch_offsets.y; + } + + } else { + in_vec += tid.z * vector_batch_stride[0]; + mat += tid.z * matrix_batch_stride[0]; + + if (has_output_mask) { + out_mask += tid.z * mask_batch_strides[0]; + mask_batch_strides += batch_ndim; + } + + if (has_operand_mask) { + mat_mask += tid.z * mask_batch_strides[0]; + vec_mask += tid.z * mask_batch_strides[batch_ndim]; + } + } + + out_vec += tid.z * out_vec_size; + + gemv_kernel::run( + mat, + in_vec, + out_vec, + in_vec_size, + out_vec_size, + marix_ld, + out_mask, + mat_mask, + vec_mask, + mask_strides, + gemv_kernel::tgp_mem_size == 0 ? 
nullptr : tgp_memory, + tid, + lid, + simd_gid, + simd_lid); +} + +/////////////////////////////////////////////////////////////////////////////// +/// Vector matrix multiplication +/////////////////////////////////////////////////////////////////////////////// + +template < + typename T, + typename out_mask_t, + typename op_mask_t, + const int BM, /* Threadgroup rows (in simdgroups) */ + const int BN, /* Threadgroup cols (in simdgroups) */ + const int SM, /* Simdgroup rows (in threads) */ + const int SN, /* Simdgroup cols (in threads) */ + const int TM, /* Thread rows (in elements) */ + const int TN, /* Thread cols (in elements) */ + const bool kDoNCBatch> /* Batch ndim > 1 */ +[[kernel, max_total_threads_per_threadgroup(BM* BN * 32)]] void gemv_t_masked( + const device T* mat [[buffer(0)]], + const device T* in_vec [[buffer(1)]], + device T* out_vec [[buffer(3)]], + const constant int& in_vec_size [[buffer(4)]], + const constant int& out_vec_size [[buffer(5)]], + const constant int& marix_ld [[buffer(6)]], + const constant int& batch_ndim [[buffer(9)]], + const constant int* batch_shape [[buffer(10)]], + const constant size_t* vector_batch_stride [[buffer(11)]], + const constant size_t* matrix_batch_stride [[buffer(12)]], + const device out_mask_t* out_mask [[buffer(20)]], + const device op_mask_t* mat_mask [[buffer(21)]], + const device op_mask_t* vec_mask [[buffer(22)]], + const constant int* mask_strides [[buffer(23)]], + const constant size_t* mask_batch_strides [[buffer(24)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + using gemv_kernel = + GEMVTKernel; + threadgroup T tgp_memory + [gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size]; + + constexpr bool has_operand_mask = !metal::is_same_v; + constexpr bool has_output_mask = !metal::is_same_v; + + // Update batch offsets + if (kDoNCBatch) { + in_vec += elem_to_loc(tid.z, batch_shape, vector_batch_stride, batch_ndim); + mat += elem_to_loc(tid.z, batch_shape, matrix_batch_stride, batch_ndim); + + if (has_output_mask) { + out_mask += + elem_to_loc(tid.z, batch_shape, mask_batch_strides, batch_ndim); + mask_batch_strides += batch_ndim; + } + + if (has_operand_mask) { + const constant size_t* mask_strides_mat = mask_batch_strides; + const constant size_t* mask_strides_vec = mask_strides_mat + batch_ndim; + + ulong2 batch_offsets = elem_to_loc_broadcast( + tid.z, batch_shape, mask_strides_mat, mask_strides_vec, batch_ndim); + + mat_mask += batch_offsets.x; + vec_mask += batch_offsets.y; + } + + } else { + in_vec += tid.z * vector_batch_stride[0]; + mat += tid.z * matrix_batch_stride[0]; + + if (has_output_mask) { + out_mask += tid.z * mask_batch_strides[0]; + mask_batch_strides += batch_ndim; + } + + if (has_operand_mask) { + mat_mask += tid.z * mask_batch_strides[0]; + vec_mask += tid.z * mask_batch_strides[batch_ndim]; + } + } + + out_vec += tid.z * out_vec_size; + + gemv_kernel::run( + mat, + in_vec, + out_vec, + in_vec_size, + out_vec_size, + marix_ld, + out_mask, + mat_mask, + vec_mask, + mask_strides, + gemv_kernel::tgp_mem_size == 0 ? 
nullptr : tgp_memory, + tid, + lid, + simd_gid, + simd_lid); +} diff --git a/mlx/backend/metal/kernels/gemv_masked.metal b/mlx/backend/metal/kernels/gemv_masked.metal index 2d63a6e40..7df97bac3 100644 --- a/mlx/backend/metal/kernels/gemv_masked.metal +++ b/mlx/backend/metal/kernels/gemv_masked.metal @@ -1,5 +1,6 @@ // Copyright © 2023-2024 Apple Inc. +// clang-format off #include #include @@ -7,726 +8,7 @@ #include "mlx/backend/metal/kernels/defines.h" #include "mlx/backend/metal/kernels/utils.h" -#include "mlx/backend/metal/kernels/steel/utils.h" - -using namespace metal; - -/////////////////////////////////////////////////////////////////////////////// -/// Matrix vector multiplication -/////////////////////////////////////////////////////////////////////////////// - -#define MLX_MTL_CONST static constant constexpr const - -struct _NoMask { - char x; - - constexpr METAL_FUNC operator bool() { - return true; - } - constexpr METAL_FUNC operator bool() const threadgroup { - return true; - } - constexpr METAL_FUNC operator bool() const device { - return true; - } - constexpr METAL_FUNC operator bool() const constant { - return true; - } -}; - -typedef struct _NoMask nomask_t; - -template -struct ScaleOp { - OutT scale; - - METAL_FUNC OutT apply(InT x) const { - return static_cast(x) * scale; - } -}; - -template < - typename T, - typename out_mask_t, - typename op_mask_t, - const int BM, /* Threadgroup rows (in simdgroups) */ - const int BN, /* Threadgroup cols (in simdgroups) */ - const int SM, /* Simdgroup rows (in threads) */ - const int SN, /* Simdgroup cols (in threads) */ - const int TM, /* Thread rows (in elements) */ - const int TN> /* Thread cols (in elements) */ -struct GEMVKernel { - MLX_MTL_CONST int threadsM = BM * SM; - MLX_MTL_CONST int threadsN = BN * SN; - - MLX_MTL_CONST int blockM = threadsM * TM; - MLX_MTL_CONST int blockN = threadsN * TN; - - static_assert(SM * SN == 32, "simdgroup can only have 32 threads"); - - static_assert( - SN == 8 || SN == 16 || SN == 32, - "gemv block must have a width of 8, 16, or 32"); - - static_assert(blockN >= blockM, "Masked gemv must have blockN >= blockM"); - - MLX_MTL_CONST bool has_operand_mask = !metal::is_same_v; - MLX_MTL_CONST bool has_output_mask = !metal::is_same_v; - - MLX_MTL_CONST bool has_mul_operand_mask = - has_operand_mask && !metal::is_same_v; - MLX_MTL_CONST bool has_mul_output_mask = - has_output_mask && !metal::is_same_v; - - // - The matrix of size (M = out_vec_size, K = in_vec_size) is divided up - // into blocks of (blockM, blockN) divided among threadgroups - // - Every thread works on a block of (TM, TN) - // - We assume each threadgroup has (threadsN, threadsM, 1) threads - // - // 1. A thread loads TN elements each from mat along TM rows - // and the corresponding scalar from the vector - // 2. The thread then multiplies and adds to accumulate its local result for - // the block - // 3. At the end, each thread has accumulated results over all blocks across - // the rows. These are then summed up across the threadgroup - // 4. Each threadgroup writes its accumulated blockM outputs - // - // Edge case handling: - // - The threadgroup with the largest tid has blocks that exceed the matrix - // * The blocks that start outside the matrix are never read (thread results - // remain zero) - // * The last thread that partially overlaps with the matrix is shifted - // inwards such that the thread block fits exactly in the matrix - - MLX_MTL_CONST short tgp_mem_size = BN > 1 ? 
BN*(blockM + TM) : 0; - MLX_MTL_CONST bool needs_tgp_reduction = BN > 1; - - static METAL_FUNC void - load_unsafe(const device T* src, thread T dst[TN], const int src_offset = 0) { - MLX_MTL_PRAGMA_UNROLL - for (int tn = 0; tn < TN; tn++) { - dst[tn] = src[src_offset + tn]; - } - } - - static METAL_FUNC void load_safe( - const device T* src, - thread T dst[TN], - const int src_offset = 0, - const int src_size = TN) { - if (src_offset + TN <= src_size) { - MLX_MTL_PRAGMA_UNROLL - for (int tn = 0; tn < TN; tn++) { - dst[tn] = src[src_offset + tn]; - } - } else { // Edgecase - MLX_MTL_PRAGMA_UNROLL - for (int tn = 0; tn < TN; tn++) { - dst[tn] = src_offset + tn < src_size ? src[src_offset + tn] : 0; - } - } - } - - static METAL_FUNC void run( - const device T* mat [[buffer(0)]], - const device T* in_vec [[buffer(1)]], - device T* out_vec [[buffer(3)]], - const constant int& in_vec_size [[buffer(4)]], - const constant int& out_vec_size [[buffer(5)]], - const constant int& matrix_ld [[buffer(6)]], - const device out_mask_t* out_mask [[buffer(20)]], - const device op_mask_t* mat_mask [[buffer(21)]], - const device op_mask_t* vec_mask [[buffer(22)]], - const constant int* mask_strides [[buffer(23)]], - threadgroup T* tgp_memory [[threadgroup(0)]], - uint3 tid [[threadgroup_position_in_grid]], - uint3 lid [[thread_position_in_threadgroup]], - uint simd_gid [[simdgroup_index_in_threadgroup]], - uint simd_lid [[thread_index_in_simdgroup]]) { - // Appease compiler - (void)lid; - - // Thread local accumulation results - thread T result[TM] = {0}; - thread T inter[TN]; - thread T v_coeff[TN]; - - const int thrM = SN != 32 ? simd_lid / SN : 0; - const int thrN = SN != 32 ? simd_lid % SN : int(simd_lid); - - const int sgN = BN != 1 ? (simd_gid % BN) : 0; - - const int simdM = BN != 1 ? SM * (simd_gid / BN) : int(SM * simd_gid); - const int simdN = BN != 1 ? SN * (simd_gid % BN) : 0; - - int bm = (simdM + thrM) * TM; - int bn = (simdN + thrN) * TN; - - // Block position - int out_row = tid.x * blockM + bm; - - // Exit simdgroup if rows out of bound - if (out_row >= out_vec_size) - return; - - // Adjust tail simdgroup to ensure in bound reads - out_row = out_row + TM <= out_vec_size ? out_row : out_vec_size - TM; - - // Prepare mask offsets - const constant int* out_mask_strides = mask_strides; - const constant int* mat_mask_strides = - mask_strides + (has_output_mask ? 2 : 0); - const constant int* vec_mask_strides = - mat_mask_strides + (has_operand_mask ? 2 : 0); - - const int m_block_idx = blockN > blockM ? out_row / blockN : int(tid.x); - - const int out_mask_offset = - !has_output_mask ? 0 : m_block_idx * out_mask_strides[1]; - - int mat_mask_offset = - !has_operand_mask ? 0 : m_block_idx * mat_mask_strides[1]; - int vec_mask_offset = 0; - const int mat_mask_step = !has_operand_mask ? 0 : mat_mask_strides[0]; - const int vec_mask_step = !has_operand_mask ? 
0 : vec_mask_strides[1]; - - T out_scale{1}; - - // Check output mask - if (has_output_mask) { - auto mask_out = out_mask[out_mask_offset]; - - // Write zeros and return if mask is 0 - if (!mask_out) { - if (simdN == 0 && thrN == 0) { - MLX_MTL_PRAGMA_UNROLL - for (int tm = 0; tm < TM; tm++) { - out_vec[out_row + tm] = T(0.); - } - } - - return; - } - - // Store scalar if multiplicative mask - if (has_mul_output_mask) { - out_scale = T(mask_out); - } - } - - // Advance matrix - mat += out_row * matrix_ld; - - // Prepare for loop - constexpr const uniform loop_stride = make_uniform(blockN); - const uniform in_size = make_uniform(in_vec_size); - const uniform n_iter = in_size / loop_stride; - const uniform last_iter = loop_stride * n_iter; - const uniform leftover = in_size - last_iter; - - // Loop over in_vec in blocks of blockN - for (int i = 0; i < n_iter; ++i) { - if (!has_operand_mask || - (bool(mat_mask[mat_mask_offset]) && - bool(vec_mask[vec_mask_offset]))) { - T block_scale{1}; - if (has_mul_operand_mask) { - block_scale = - T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]); - } - - load_unsafe(in_vec, v_coeff, bn); - - // Apply scale - if (has_mul_operand_mask) { - MLX_MTL_PRAGMA_UNROLL - for (int tn = 0; tn < TN; tn++) { - v_coeff[tn] *= block_scale; - } - } - - // Per thread work loop - int mat_offset = 0; - MLX_MTL_PRAGMA_UNROLL - for (int tm = 0; tm < TM; tm++) { - // Load for the row - load_unsafe(mat, inter, mat_offset + bn); - - // Accumulate results - MLX_MTL_PRAGMA_UNROLL - for (int tn = 0; tn < TN; tn++) { - result[tm] += inter[tn] * v_coeff[tn]; - } - - mat_offset += matrix_ld; - } - } - - bn += blockN; - mat_mask_offset += mat_mask_step; - vec_mask_offset += vec_mask_step; - } - - if (leftover > 0 && - (!has_operand_mask || - (bool(mat_mask[mat_mask_offset]) && - bool(vec_mask[vec_mask_offset])))) { - T block_scale{1}; - if (has_mul_operand_mask) { - block_scale = - T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]); - } - - load_safe(in_vec, v_coeff, bn, in_size); - - // Apply scale - if (has_mul_operand_mask) { - MLX_MTL_PRAGMA_UNROLL - for (int tn = 0; tn < TN; tn++) { - v_coeff[tn] *= block_scale; - } - } - - // Per thread work loop - MLX_MTL_PRAGMA_UNROLL - for (int tm = 0; tm < TM; tm++) { - // Load for the row - load_safe(&mat[tm * matrix_ld], inter, bn, in_size); - - // Accumulate results - MLX_MTL_PRAGMA_UNROLL - for (int tn = 0; tn < TN; tn++) { - result[tm] += inter[tn] * v_coeff[tn]; - } - } - } - - // Apply out scale - if (has_mul_output_mask) { - MLX_MTL_PRAGMA_UNROLL - for (int tm = 0; tm < TM; tm++) { - result[tm] *= out_scale; - } - } - - // Simdgroup accumulations - MLX_MTL_PRAGMA_UNROLL - for (int tm = 0; tm < TM; tm++) { - MLX_MTL_PRAGMA_UNROLL - for (ushort sn = (SN / 2); sn >= 1; sn >>= 1) { - result[tm] += simd_shuffle_down(result[tm], sn); - } - } - - // Threadgroup accumulation results - if (needs_tgp_reduction) { - threadgroup T* tgp_results = tgp_memory + sgN * (blockM + TM) + bm; - if (thrN == 0) { - MLX_MTL_PRAGMA_UNROLL - for (int tm = 0; tm < TM; tm++) { - tgp_results[tm] = result[tm]; - } - - threadgroup_barrier(mem_flags::mem_none); - - if (sgN == 0) { - MLX_MTL_PRAGMA_UNROLL - for (int sgn = 1; sgn < BN; sgn++) { - MLX_MTL_PRAGMA_UNROLL - for (int tm = 0; tm < TM; tm++) { - result[tm] += tgp_results[sgn * (blockM + TM) + tm]; - } - } - } - } - } - - // Write outputs - if (simdN == 0 && thrN == 0) { - MLX_MTL_PRAGMA_UNROLL - for (int tm = 0; tm < TM; tm++) { - out_vec[out_row + tm] = result[tm]; - } - } - } -}; - 
-/////////////////////////////////////////////////////////////////////////////// -/// Vector matrix multiplication -/////////////////////////////////////////////////////////////////////////////// - -template < - typename T, - typename out_mask_t, - typename op_mask_t, - const int BM, /* Threadgroup rows (in simdgroups) */ - const int BN, /* Threadgroup cols (in simdgroups) */ - const int SM, /* Simdgroup rows (in threads) */ - const int SN, /* Simdgroup cols (in threads) */ - const int TM, /* Thread rows (in elements) */ - const int TN> /* Thread cols (in elements) */ -struct GEMVTKernel { - MLX_MTL_CONST int threadsM = BM * SM; - MLX_MTL_CONST int threadsN = BN * SN; - - MLX_MTL_CONST int blockM = threadsM * TM; - MLX_MTL_CONST int blockN = threadsN * TN; - - static_assert(SM * SN == 32, "simdgroup can only have 32 threads"); - - MLX_MTL_CONST bool has_operand_mask = !metal::is_same_v; - MLX_MTL_CONST bool has_output_mask = !metal::is_same_v; - - MLX_MTL_CONST bool has_mul_operand_mask = - has_operand_mask && !metal::is_same_v; - MLX_MTL_CONST bool has_mul_output_mask = - has_output_mask && !metal::is_same_v; - - // - The matrix of size (M = in_vec_size, N = out_vec_size) is divided up - // into blocks of (blockM, blockN) divided among threadgroups - // - Every thread works on a block of (TM, TN) - // - We assume each threadgroup has (threadsN, threadsM, 1) threads - // - // 1. A thread loads TN elements each from mat along TM contiguous rows - // and the corresponding scalar from the vector - // 2. The thread then accumulates its local result for the block - // 3. At the end, each thread has accumulated results over all blocks across - // the rows. These are then summed up across the threadgroup - // 4. Each threadgroup writes its accumulated BN * TN outputs - // - // Edge case handling: - // - The threadgroup with the largest tid has blocks that exceed the matrix - // * The blocks that start outside the matrix are never read (thread results - // remain zero) - // * The last thread that partially overlaps with the matrix is shifted - // inwards such that the thread block fits exactly in the matrix - - MLX_MTL_CONST short tgp_mem_size = BM > 1 ? BM*(blockN + TN) : 0; - MLX_MTL_CONST bool needs_tgp_reduction = BM > 1; - - static METAL_FUNC void run( - const device T* mat [[buffer(0)]], - const device T* in_vec [[buffer(1)]], - device T* out_vec [[buffer(3)]], - const constant int& in_vec_size [[buffer(4)]], - const constant int& out_vec_size [[buffer(5)]], - const constant int& marix_ld [[buffer(6)]], - const device out_mask_t* out_mask [[buffer(20)]], - const device op_mask_t* mat_mask [[buffer(21)]], - const device op_mask_t* vec_mask [[buffer(22)]], - const constant int* mask_strides [[buffer(23)]], - threadgroup T* tgp_memory [[threadgroup(0)]], - uint3 tid [[threadgroup_position_in_grid]], - uint3 lid [[thread_position_in_threadgroup]], - uint simd_gid [[simdgroup_index_in_threadgroup]], - uint simd_lid [[thread_index_in_simdgroup]]) { - // Appease compiler - (void)lid; - - // Thread local accumulation results - T result[TN] = {0}; - T inter[TN]; - T v_coeff[TM]; - - const int thrM = SN != 32 ? simd_lid / SN : 0; - const int thrN = SN != 32 ? simd_lid % SN : int(simd_lid); - - const int sgM = BN != 1 ? (simd_gid / BN) : int(simd_gid); - const int sgN = BN != 1 ? 
(simd_gid % BN) : 0; - - const int simdM = SM * sgM; - const int simdN = SN * sgN; - - int cm = (simdM + thrM); - int cn = (simdN + thrN); - - int bm = cm * TM; - int bn = cn * TN; - - int out_col = tid.x * blockN + bn; - - // Prepare mask offsets - const constant int* out_mask_strides = mask_strides; - const constant int* mat_mask_strides = - out_mask_strides + (has_output_mask ? 2 : 0); - const constant int* vec_mask_strides = - mat_mask_strides + (has_operand_mask ? 2 : 0); - - const int n_block_idx = blockM > blockN ? out_col / blockM : int(tid.x); - - const int out_mask_offset = - !has_output_mask ? 0 : n_block_idx; // * out_mask_strides[0]; - - int mat_mask_offset = - !has_operand_mask ? 0 : n_block_idx * mat_mask_strides[0]; - int vec_mask_offset = 0; - const int mat_mask_step = !has_operand_mask ? 0 : mat_mask_strides[1]; - const int vec_mask_step = !has_operand_mask ? 0 : vec_mask_strides[0]; - - T out_scale{1}; - - // Check output mask - if (has_output_mask) { - auto mask_out = out_mask[out_mask_offset]; - - // Write zeros and return if mask is 0 - if (!mask_out) { - if (cm == 0 && out_col < out_vec_size) { - if (out_col + TN <= out_vec_size) { - MLX_MTL_PRAGMA_UNROLL - for (int tn = 0; tn < TN; tn++) { - out_vec[out_col + tn] = T(0.); - } - } else { - for (int tn = 0; tn < TN && (out_col + tn) < out_vec_size; tn++) { - out_vec[out_col + tn] = T(0.); - } - } - } - - return; - } - - // Store scalar if multiplicative mask - if (has_mul_output_mask) { - out_scale = T(mask_out); - } - } - - // Prepare for loop - constexpr const uniform loop_stride = make_uniform(blockM); - const uniform in_size = make_uniform(in_vec_size); - const uniform n_iter = in_size / loop_stride; - const uniform last_iter = loop_stride * n_iter; - const uniform leftover = in_size - last_iter; - - // Edgecase handling - if (out_col < out_vec_size) { - out_col = (out_col + TN) <= out_vec_size ? 
out_col : out_vec_size - TN; - - // Per thread accumulation main loop - for (int i = 0; i < n_iter; ++i) { - // Adding a threadgroup_barrier improves performance slightly - // This is possibly it may help exploit cache better - threadgroup_barrier(mem_flags::mem_none); - - if (!has_operand_mask || - (bool(mat_mask[mat_mask_offset]) && - bool(vec_mask[vec_mask_offset]))) { - T block_scale{1}; - if (has_mul_operand_mask) { - block_scale = - T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]); - } - - MLX_MTL_PRAGMA_UNROLL - for (int tm = 0; tm < TM; tm++) { - v_coeff[tm] = in_vec[bm + tm]; - } - - // Apply scale - if (has_mul_operand_mask) { - MLX_MTL_PRAGMA_UNROLL - for (int tm = 0; tm < TM; tm++) { - v_coeff[tm] *= block_scale; - } - } - - MLX_MTL_PRAGMA_UNROLL - for (int tm = 0; tm < TM; tm++) { - for (int tn = 0; tn < TN; tn++) { - inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn]; - } - for (int tn = 0; tn < TN; tn++) { - result[tn] += v_coeff[tm] * inter[tn]; - } - } - } - - bm += blockM; - mat_mask_offset += mat_mask_step; - vec_mask_offset += vec_mask_step; - } - - if (leftover > 0 && - (!has_operand_mask || - (bool(mat_mask[mat_mask_offset]) && - bool(vec_mask[vec_mask_offset])))) { - T block_scale{1}; - if (has_mul_operand_mask) { - block_scale = - T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]); - } - - for (int tm = 0; tm < TM && bm + tm < in_vec_size; tm++) { - v_coeff[tm] = in_vec[bm + tm]; - - if (has_mul_operand_mask) { - v_coeff[tm] *= block_scale; - } - - MLX_MTL_PRAGMA_UNROLL - for (int tn = 0; tn < TN; tn++) { - inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn]; - } - - MLX_MTL_PRAGMA_UNROLL - for (int tn = 0; tn < TN; tn++) { - result[tn] += v_coeff[tm] * inter[tn]; - } - } - } - } - - // Apply out scale - if (has_mul_output_mask) { - MLX_MTL_PRAGMA_UNROLL - for (int tn = 0; tn < TN; tn++) { - result[tn] *= out_scale; - } - } - - // Simdgroup accumulations - MLX_MTL_PRAGMA_UNROLL - for (int tn = 0; tn < TN; tn++) { - MLX_MTL_PRAGMA_UNROLL - for (ushort sm = (SM / 2); sm >= 1; sm >>= 1) { - result[tn] += simd_shuffle_down(result[tn], SN * sm); - } - } - - // Threadgroup accumulation results - if (needs_tgp_reduction) { - threadgroup T* tgp_results = tgp_memory + sgM * (blockN + TN) + bn; - if (thrM == 0) { - MLX_MTL_PRAGMA_UNROLL - for (int tn = 0; tn < TN; tn++) { - tgp_results[tn] = result[tn]; - } - - threadgroup_barrier(mem_flags::mem_none); - - if (sgM == 0) { - MLX_MTL_PRAGMA_UNROLL - for (int sgm = 1; sgm < BM; sgm++) { - MLX_MTL_PRAGMA_UNROLL - for (int tn = 0; tn < TN; tn++) { - result[tn] += tgp_results[sgm * (blockN + TN) + tn]; - } - } - } - } - } - - // Threadgroup accumulation and writing out results - if (cm == 0 && out_col < out_vec_size) { - MLX_MTL_PRAGMA_UNROLL - for (int j = 0; j < TN; j++) { - out_vec[out_col + j] = result[j]; - } - } - } -}; - -/////////////////////////////////////////////////////////////////////////////// -/// Matrix vector multiplication -/////////////////////////////////////////////////////////////////////////////// - -template < - typename T, - typename out_mask_t, - typename op_mask_t, - const int BM, /* Threadgroup rows (in simdgroups) */ - const int BN, /* Threadgroup cols (in simdgroups) */ - const int SM, /* Simdgroup rows (in threads) */ - const int SN, /* Simdgroup cols (in threads) */ - const int TM, /* Thread rows (in elements) */ - const int TN, /* Thread cols (in elements) */ - const bool kDoNCBatch> /* Batch ndim > 1 */ -[[kernel, max_total_threads_per_threadgroup(BM* BN * 32)]] void 
gemv_masked( - const device T* mat [[buffer(0)]], - const device T* in_vec [[buffer(1)]], - device T* out_vec [[buffer(3)]], - const constant int& in_vec_size [[buffer(4)]], - const constant int& out_vec_size [[buffer(5)]], - const constant int& marix_ld [[buffer(6)]], - const constant int& batch_ndim [[buffer(9)]], - const constant int* batch_shape [[buffer(10)]], - const constant size_t* vector_batch_stride [[buffer(11)]], - const constant size_t* matrix_batch_stride [[buffer(12)]], - const device out_mask_t* out_mask [[buffer(20)]], - const device op_mask_t* mat_mask [[buffer(21)]], - const device op_mask_t* vec_mask [[buffer(22)]], - const constant int* mask_strides [[buffer(23)]], - const constant size_t* mask_batch_strides [[buffer(24)]], - uint3 tid [[threadgroup_position_in_grid]], - uint3 lid [[thread_position_in_threadgroup]], - uint simd_gid [[simdgroup_index_in_threadgroup]], - uint simd_lid [[thread_index_in_simdgroup]]) { - using gemv_kernel = - GEMVKernel; - threadgroup T tgp_memory - [gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size]; - - constexpr bool has_operand_mask = !metal::is_same_v; - constexpr bool has_output_mask = !metal::is_same_v; - - // Update batch offsets - if (kDoNCBatch) { - in_vec += elem_to_loc(tid.z, batch_shape, vector_batch_stride, batch_ndim); - mat += elem_to_loc(tid.z, batch_shape, matrix_batch_stride, batch_ndim); - - if (has_output_mask) { - out_mask += - elem_to_loc(tid.z, batch_shape, mask_batch_strides, batch_ndim); - mask_batch_strides += batch_ndim; - } - - if (has_operand_mask) { - const constant size_t* mask_strides_mat = mask_batch_strides; - const constant size_t* mask_strides_vec = mask_strides_mat + batch_ndim; - - ulong2 batch_offsets = elem_to_loc_broadcast( - tid.z, batch_shape, mask_strides_mat, mask_strides_vec, batch_ndim); - - mat_mask += batch_offsets.x; - vec_mask += batch_offsets.y; - } - - } else { - in_vec += tid.z * vector_batch_stride[0]; - mat += tid.z * matrix_batch_stride[0]; - - if (has_output_mask) { - out_mask += tid.z * mask_batch_strides[0]; - mask_batch_strides += batch_ndim; - } - - if (has_operand_mask) { - mat_mask += tid.z * mask_batch_strides[0]; - vec_mask += tid.z * mask_batch_strides[batch_ndim]; - } - } - - out_vec += tid.z * out_vec_size; - - gemv_kernel::run( - mat, - in_vec, - out_vec, - in_vec_size, - out_vec_size, - marix_ld, - out_mask, - mat_mask, - vec_mask, - mask_strides, - gemv_kernel::tgp_mem_size == 0 ? 
nullptr : tgp_memory, - tid, - lid, - simd_gid, - simd_lid); -} +#include "mlx/backend/metal/kernels/gemv_masked.h" #define instantiate_gemv_helper( \ outm_n, outm_t, opm_n, opm_t, name, itype, bm, bn, sm, sn, tm, tn, nc) \ @@ -754,7 +36,6 @@ template < uint simd_gid [[simdgroup_index_in_threadgroup]], \ uint simd_lid [[thread_index_in_simdgroup]]); -// clang-format off #define instantiate_gemv_base(name, itype, bm, bn, sm, sn, tm, tn, nc) \ instantiate_gemv_helper(bool_, bool, bool_, bool, name, itype, bm, bn, sm, sn, tm, tn, nc) \ instantiate_gemv_helper(name, itype, name, itype, name, itype, bm, bn, sm, sn, tm, tn, nc) \ @@ -763,125 +44,23 @@ template < instantiate_gemv_helper(nomask, nomask_t, name, itype, name, itype, bm, bn, sm, sn, tm, tn, nc) \ instantiate_gemv_helper(nomask, nomask_t, bool_, bool, name, itype, bm, bn, sm, sn, tm, tn, nc) \ instantiate_gemv_helper(bool_, bool, nomask, nomask_t, name, itype, bm, bn, sm, sn, tm, tn, nc) \ - instantiate_gemv_helper(name, itype, nomask, nomask_t, name, itype, bm, bn, sm, sn, tm, tn, nc) // clang-format on + instantiate_gemv_helper(name, itype, nomask, nomask_t, name, itype, bm, bn, sm, sn, tm, tn, nc) -// clang-format off -#define instantiate_gemv(name, itype, bm, bn, sm, sn, tm, tn) \ +#define instantiate_gemv(name, itype, bm, bn, sm, sn, tm, tn) \ instantiate_gemv_base(name, itype, bm, bn, sm, sn, tm, tn, 0) \ - instantiate_gemv_base(name, itype, bm, bn, sm, sn, tm, tn, 1) // clang-format on + instantiate_gemv_base(name, itype, bm, bn, sm, sn, tm, tn, 1) -// clang-format off #define instantiate_gemv_blocks(name, itype) \ instantiate_gemv(name, itype, 2, 1, 4, 8, 1, 4) \ instantiate_gemv(name, itype, 2, 1, 4, 8, 4, 4) \ instantiate_gemv(name, itype, 2, 1, 2, 16, 1, 4) \ instantiate_gemv(name, itype, 2, 1, 2, 16, 4, 4) \ - instantiate_gemv(name, itype, 4, 1, 2, 16, 4, 4) // clang-format on + instantiate_gemv(name, itype, 4, 1, 2, 16, 4, 4) instantiate_gemv_blocks(float32, float); instantiate_gemv_blocks(float16, half); instantiate_gemv_blocks(bfloat16, bfloat16_t); -/////////////////////////////////////////////////////////////////////////////// -/// Vector matrix multiplication -/////////////////////////////////////////////////////////////////////////////// - -template < - typename T, - typename out_mask_t, - typename op_mask_t, - const int BM, /* Threadgroup rows (in simdgroups) */ - const int BN, /* Threadgroup cols (in simdgroups) */ - const int SM, /* Simdgroup rows (in threads) */ - const int SN, /* Simdgroup cols (in threads) */ - const int TM, /* Thread rows (in elements) */ - const int TN, /* Thread cols (in elements) */ - const bool kDoNCBatch> /* Batch ndim > 1 */ -[[kernel, max_total_threads_per_threadgroup(BM* BN * 32)]] void gemv_t_masked( - const device T* mat [[buffer(0)]], - const device T* in_vec [[buffer(1)]], - device T* out_vec [[buffer(3)]], - const constant int& in_vec_size [[buffer(4)]], - const constant int& out_vec_size [[buffer(5)]], - const constant int& marix_ld [[buffer(6)]], - const constant int& batch_ndim [[buffer(9)]], - const constant int* batch_shape [[buffer(10)]], - const constant size_t* vector_batch_stride [[buffer(11)]], - const constant size_t* matrix_batch_stride [[buffer(12)]], - const device out_mask_t* out_mask [[buffer(20)]], - const device op_mask_t* mat_mask [[buffer(21)]], - const device op_mask_t* vec_mask [[buffer(22)]], - const constant int* mask_strides [[buffer(23)]], - const constant size_t* mask_batch_strides [[buffer(24)]], - uint3 tid [[threadgroup_position_in_grid]], - uint3 
lid [[thread_position_in_threadgroup]], - uint simd_gid [[simdgroup_index_in_threadgroup]], - uint simd_lid [[thread_index_in_simdgroup]]) { - using gemv_kernel = - GEMVTKernel; - threadgroup T tgp_memory - [gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size]; - - constexpr bool has_operand_mask = !metal::is_same_v; - constexpr bool has_output_mask = !metal::is_same_v; - - // Update batch offsets - if (kDoNCBatch) { - in_vec += elem_to_loc(tid.z, batch_shape, vector_batch_stride, batch_ndim); - mat += elem_to_loc(tid.z, batch_shape, matrix_batch_stride, batch_ndim); - - if (has_output_mask) { - out_mask += - elem_to_loc(tid.z, batch_shape, mask_batch_strides, batch_ndim); - mask_batch_strides += batch_ndim; - } - - if (has_operand_mask) { - const constant size_t* mask_strides_mat = mask_batch_strides; - const constant size_t* mask_strides_vec = mask_strides_mat + batch_ndim; - - ulong2 batch_offsets = elem_to_loc_broadcast( - tid.z, batch_shape, mask_strides_mat, mask_strides_vec, batch_ndim); - - mat_mask += batch_offsets.x; - vec_mask += batch_offsets.y; - } - - } else { - in_vec += tid.z * vector_batch_stride[0]; - mat += tid.z * matrix_batch_stride[0]; - - if (has_output_mask) { - out_mask += tid.z * mask_batch_strides[0]; - mask_batch_strides += batch_ndim; - } - - if (has_operand_mask) { - mat_mask += tid.z * mask_batch_strides[0]; - vec_mask += tid.z * mask_batch_strides[batch_ndim]; - } - } - - out_vec += tid.z * out_vec_size; - - gemv_kernel::run( - mat, - in_vec, - out_vec, - in_vec_size, - out_vec_size, - marix_ld, - out_mask, - mat_mask, - vec_mask, - mask_strides, - gemv_kernel::tgp_mem_size == 0 ? nullptr : tgp_memory, - tid, - lid, - simd_gid, - simd_lid); -} - #define instantiate_gemv_t_helper( \ outm_n, outm_t, opm_n, opm_t, name, itype, bm, bn, sm, sn, tm, tn, nc) \ template [[host_name("gemv_t_outmask_" #outm_n "_opmask_" #opm_n "_" #name \ @@ -908,7 +87,6 @@ template < uint simd_gid [[simdgroup_index_in_threadgroup]], \ uint simd_lid [[thread_index_in_simdgroup]]); -// clang-format off #define instantiate_gemv_t_base(name, itype, bm, bn, sm, sn, tm, tn, nc) \ instantiate_gemv_t_helper(bool_, bool, bool_, bool, name, itype, bm, bn, sm, sn, tm, tn, nc) \ instantiate_gemv_t_helper(name, itype, name, itype, name, itype, bm, bn, sm, sn, tm, tn, nc) \ @@ -917,23 +95,20 @@ template < instantiate_gemv_t_helper(nomask, nomask_t, name, itype, name, itype, bm, bn, sm, sn, tm, tn, nc) \ instantiate_gemv_t_helper(nomask, nomask_t, bool_, bool, name, itype, bm, bn, sm, sn, tm, tn, nc) \ instantiate_gemv_t_helper(bool_, bool, nomask, nomask_t, name, itype, bm, bn, sm, sn, tm, tn, nc) \ - instantiate_gemv_t_helper(name, itype, nomask, nomask_t, name, itype, bm, bn, sm, sn, tm, tn, nc) // clang-format on + instantiate_gemv_t_helper(name, itype, nomask, nomask_t, name, itype, bm, bn, sm, sn, tm, tn, nc) -// clang-format off #define instantiate_gemv_t(name, itype, bm, bn, sm, sn, tm, tn) \ instantiate_gemv_t_base(name, itype, bm, bn, sm, sn, tm, tn, 0) \ - instantiate_gemv_t_base(name, itype, bm, bn, sm, sn, tm, tn, 1) // clang-format on + instantiate_gemv_t_base(name, itype, bm, bn, sm, sn, tm, tn, 1) -// clang-format off #define instantiate_gemv_t_blocks(name, itype) \ instantiate_gemv_t(name, itype, 1, 1, 8, 4, 4, 1) \ instantiate_gemv_t(name, itype, 1, 2, 8, 4, 4, 4) \ instantiate_gemv_t(name, itype, 1, 1, 8, 4, 8, 1) \ instantiate_gemv_t(name, itype, 1, 1, 8, 4, 8, 4) \ instantiate_gemv_t(name, itype, 1, 2, 8, 4, 8, 4) \ - instantiate_gemv_t(name, itype, 1, 4, 8, 4, 8, 4) 
// clang-format on + instantiate_gemv_t(name, itype, 1, 4, 8, 4, 8, 4) -// clang-format off instantiate_gemv_t_blocks(float32, float); instantiate_gemv_t_blocks(float16, half); -instantiate_gemv_t_blocks(bfloat16, bfloat16_t); // clang-format on \ No newline at end of file +instantiate_gemv_t_blocks(bfloat16, bfloat16_t); // clang-format on diff --git a/mlx/backend/metal/matmul.cpp b/mlx/backend/metal/matmul.cpp index 4dd6cd715..5edb66197 100644 --- a/mlx/backend/metal/matmul.cpp +++ b/mlx/backend/metal/matmul.cpp @@ -11,187 +11,14 @@ #include "mlx/backend/metal/kernels/defines.h" #include "mlx/backend/metal/kernels/steel/gemm/params.h" #include "mlx/backend/metal/matmul.h" -#include "mlx/backend/metal/mps/gemm.h" #include "mlx/backend/metal/utils.h" #include "mlx/primitives.h" #include "mlx/utils.h" namespace mlx::core { -/////////////////////////////////////////////////////////////////////////////// -// MPS Matmul fallback -/////////////////////////////////////////////////////////////////////////////// - namespace { -bool use_mps() { - auto get_val = []() { - if (const char* buff_str = std::getenv("MLX_USE_MPS")) { - return std::string(buff_str) != "OFF"; - } else { - return false; - } - }; - static bool use_mps_ = get_val(); - return use_mps_; -} - -#define MAX_OPS_PER_BUFFER max_ops_per_buffer() - -inline void mps_matmul( - const Stream& s, - metal::Device& d, - const array& a, - const array& b, - array& out, - int M, - int N, - int K, - int batch_size_out, - int lda, - int ldb, - bool transpose_a, - bool transpose_b, - std::vector& copies, - float alpha = 1.0f, - float beta = 0.0f) { - MPS::DataType mps_dtype = MPS::DataTypeFloat32; - - if (out.dtype() == float16) { - mps_dtype = MPS::DataTypeFloat16; - } else if (out.dtype() == bfloat16) { - mps_dtype = MPS::DataTypeBFloat16; - } - - // Used batched MPSMatrixMultiplication if batch_size_out > 1 - // We only accept the following cases: - // 1. Both a, b have batch_size_out matrices worth of data - // 2. Only one of a or b has batch_size_out matrices worth of data and - // the other has matrix worth of data - - // The matrix dimensions of a and b are sure to be regularly strided - if (batch_size_out > 1) { - // No broadcasting defaults - auto batch_size_a = a.data_size() / (M * K); - auto batch_size_b = b.data_size() / (K * N); - - auto matrix_stride_a = M * K; - auto matrix_stride_b = K * N; - auto matrix_stride_out = M * N; - - // At this point, batch_size_a, batch_size_b show the number of matrices - // in data, no broadcasted strides considered - if (batch_size_out == std::max(batch_size_a, batch_size_b)) { - // Handle simple broadcasting - if (std::min(batch_size_a, batch_size_b) == 1) { - matrix_stride_a = (batch_size_a == 1) ? 0 : matrix_stride_a; - matrix_stride_b = (batch_size_b == 1) ? 
0 : matrix_stride_b; - - batch_size_a = batch_size_out; - batch_size_b = batch_size_out; - } - - // Only proceed if broadcasting between a and b is simple - // At this point, batch_size_a, batch_size_b show the number of matrices - // after broadcasting - if (batch_size_a == batch_size_b) { - auto a_desc = MPS::MatrixDescriptor::matrixDescriptor( - (M * K) / lda, - lda, - batch_size_a, - lda * a.itemsize(), - (matrix_stride_a * a.itemsize()), - mps_dtype); - - auto b_desc = MPS::MatrixDescriptor::matrixDescriptor( - (K * N) / ldb, - ldb, - batch_size_b, - ldb * b.itemsize(), - (matrix_stride_b * b.itemsize()), - mps_dtype); - - auto out_desc = MPS::MatrixDescriptor::matrixDescriptor( - M, - N, - batch_size_out, - N * out.itemsize(), - matrix_stride_out * out.itemsize(), - mps_dtype); - - auto a_buf = static_cast(a.buffer().ptr()); - auto a_mat = MPS::Matrix::alloc()->init(a_buf, a_desc); - - auto b_buf = static_cast(b.buffer().ptr()); - auto b_mat = MPS::Matrix::alloc()->init(b_buf, b_desc); - - auto out_buf = static_cast(out.buffer().ptr()); - auto out_mat = MPS::Matrix::alloc()->init(out_buf, out_desc); - - auto kernel = MPS::MatrixMultiplication::alloc()->init( - d.mtl_device(), transpose_a, transpose_b, M, N, K, alpha, beta); - - auto command_buffer = d.get_command_buffer(s.index); - kernel->setBatchSize(batch_size_out); - kernel->setBatchStart(0); - kernel->encodeToCommandBuffer(command_buffer, a_mat, b_mat, out_mat); - command_buffer->addCompletedHandler( - [a_mat, b_mat, out_mat, kernel, copies]( - MTL::CommandBuffer*) mutable { - a_mat->release(); - b_mat->release(); - out_mat->release(); - kernel->release(); - copies.clear(); - }); - - return; - } - } - } - - // Schedule as many calls to MPSMatrixMultiplication as needed otherwise - auto a_desc = MPS::MatrixDescriptor::matrixDescriptor( - a.data_size() / lda, lda, lda * a.itemsize(), mps_dtype); - - auto b_desc = MPS::MatrixDescriptor::matrixDescriptor( - b.data_size() / ldb, ldb, ldb * b.itemsize(), mps_dtype); - - auto out_desc = MPS::MatrixDescriptor::matrixDescriptor( - batch_size_out * M, N, N * out.itemsize(), mps_dtype); - - auto a_buf = static_cast(a.buffer().ptr()); - auto a_mat = MPS::Matrix::alloc()->init(a_buf, a_desc); - - auto b_buf = static_cast(b.buffer().ptr()); - auto b_mat = MPS::Matrix::alloc()->init(b_buf, b_desc); - - auto out_buf = static_cast(out.buffer().ptr()); - auto out_mat = MPS::Matrix::alloc()->init(out_buf, out_desc); - - auto kernel = MPS::MatrixMultiplication::alloc()->init( - d.mtl_device(), transpose_a, transpose_b, M, N, K, alpha, beta); - - auto command_buffer = d.get_command_buffer(s.index); - for (int i = 0; i < batch_size_out; ++i) { - auto a_row = elem_to_loc(M * K * i, a.shape(), a.strides()) / lda; - auto b_row = elem_to_loc(K * N * i, b.shape(), b.strides()) / ldb; - kernel->setLeftMatrixOrigin({a_row, 0, 0}); - kernel->setRightMatrixOrigin({b_row, 0, 0}); - kernel->setResultMatrixOrigin({i * static_cast(M), 0, 0}); - kernel->encodeToCommandBuffer(command_buffer, a_mat, b_mat, out_mat); - } - - command_buffer->addCompletedHandler( - [a_mat, b_mat, out_mat, kernel, copies](MTL::CommandBuffer*) mutable { - a_mat->release(); - b_mat->release(); - out_mat->release(); - kernel->release(); - copies.clear(); - }); -} - inline auto collapse_batches(const array& a, const array& b) { // Get and check the shape for the batched dims std::vector A_bshape{a.shape().begin(), a.shape().end() - 2}; @@ -860,26 +687,6 @@ void Matmul::eval_gpu(const std::vector& inputs, array& out) { 
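An illustrative sketch, not part of this patch: the mps_matmul path removed above handles simple batch broadcasting by forcing the batch stride of a single-matrix operand to 0, so every batch element resolves to the same underlying matrix. A minimal model of that addressing, with hypothetical names:

#include <cstddef>

// Resolve the start of batch `i` for an operand whose matrices are
// `matrix_stride` elements apart; a stride of 0 means the operand is
// broadcast and every batch reads the same data.
inline const float* batch_matrix(
    const float* base,
    int i,
    std::size_t matrix_stride) {
  return base + static_cast<std::size_t>(i) * matrix_stride;
}

// With batch_size_a == 1 and batch_size_out == 8, matrix_stride_a is set to 0
// above, so batch_matrix(a_ptr, i, 0) returns the same matrix for every i.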
///////////////////////////////////////////////////////////////////////////// // Gemm specialization - if (use_mps()) { - d.end_encoding(s.index); - - return mps_matmul( - s, - d, - a, - b, - out, - M, - N, - K, - batch_size_out, - a_cols, - b_cols, - a_transposed, - b_transposed, - copies); - } - return steel_matmul( /* const Stream& s = */ s, /* metal::Device& d = */ d, @@ -1529,8 +1336,22 @@ void BlockMaskedMM::eval_gpu(const std::vector& inputs, array& out) { kname << "_nc" << !contiguous_kernel; // Encode and dispatch kernel + auto kernel = get_gemv_masked_kernel( + d, + kname.str(), + out, + has_out_mask ? std::optional{inputs[2]} : std::nullopt, + has_op_mask ? std::optional{inputs.back()} : std::nullopt, + transpose_mat, + bm, + bn, + sm, + sn, + tm, + tn, + contiguous_kernel); + auto& compute_encoder = d.get_command_encoder(s.index); - auto kernel = d.get_kernel(kname.str()); compute_encoder->setComputePipelineState(kernel); int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp; diff --git a/mlx/backend/metal/matmul.h b/mlx/backend/metal/matmul.h index bbe41ea8d..3edf34f66 100644 --- a/mlx/backend/metal/matmul.h +++ b/mlx/backend/metal/matmul.h @@ -1,14 +1,6 @@ // Copyright © 2023 Apple Inc. -#include -#include -#include - -#include "mlx/backend/metal/copy.h" #include "mlx/backend/metal/device.h" -#include "mlx/backend/metal/mps/gemm.h" -#include "mlx/backend/metal/utils.h" -#include "mlx/utils.h" namespace mlx::core { @@ -48,4 +40,4 @@ void steel_matmul( std::vector A_batch_stride = {}, std::vector B_batch_stride = {}); -} // namespace mlx::core \ No newline at end of file +} // namespace mlx::core diff --git a/mlx/backend/metal/mps/gemm.h b/mlx/backend/metal/mps/gemm.h deleted file mode 100644 index c93df68d8..000000000 --- a/mlx/backend/metal/mps/gemm.h +++ /dev/null @@ -1,370 +0,0 @@ -// Copyright © 2023 Apple Inc. 
- -#pragma once - -#include - -#define _MPS_PRIVATE_CLS(symbol) (MTL::Private::Class::s_k##symbol) -#define _MPS_PRIVATE_SEL(accessor) (MTL::Private::Selector::s_k##accessor) - -namespace MTL::Private::Class { -_MTL_PRIVATE_DEF_CLS(MPSMatrixDescriptor); -_MTL_PRIVATE_DEF_CLS(MPSMatrix); -_MTL_PRIVATE_DEF_CLS(MPSVectorDescriptor); -_MTL_PRIVATE_DEF_CLS(MPSVector); -_MTL_PRIVATE_DEF_CLS(MPSKernel); -_MTL_PRIVATE_DEF_CLS(MPSMatrixMultiplication); -_MTL_PRIVATE_DEF_CLS(MPSMatrixVectorMultiplication); -} // namespace MTL::Private::Class - -namespace MTL::Private::Selector { -_MTL_PRIVATE_DEF_SEL( - matrixDescriptorWithRows_columns_rowBytes_dataType, - "matrixDescriptorWithRows:columns:rowBytes:dataType:"); -_MTL_PRIVATE_DEF_SEL( - matrixDescriptorWithRows_columns_matrices_rowBytes_matrixBytes_dataType, - "matrixDescriptorWithRows:columns:matrices:rowBytes:matrixBytes:dataType:"); -_MTL_PRIVATE_DEF_SEL(rows, "rows"); -_MTL_PRIVATE_DEF_SEL(initWithBuffer_descriptor, "initWithBuffer:descriptor:"); -_MTL_PRIVATE_DEF_SEL( - initWithDevice_, - "initWithDevice:transposeLeft:transposeRight:" - "resultRows:resultColumns:interiorColumns:alpha:beta:"); -_MTL_PRIVATE_DEF_SEL( - encodeToCommandBuffer_leftMatrix_rightMatrix_resultMatrix, - "encodeToCommandBuffer:leftMatrix:rightMatrix:resultMatrix:"); -_MTL_PRIVATE_DEF_SEL(setLeftMatrixOrigin_, "setLeftMatrixOrigin:"); -_MTL_PRIVATE_DEF_SEL(setRightMatrixOrigin_, "setRightMatrixOrigin:"); -_MTL_PRIVATE_DEF_SEL(setResultMatrixOrigin_, "setResultMatrixOrigin:"); -_MTL_PRIVATE_DEF_SEL(setBatchStart_, "setBatchStart:"); -_MTL_PRIVATE_DEF_SEL(setBatchSize_, "setBatchSize:"); -_MTL_PRIVATE_DEF_SEL( - vectorDescriptorWithLength_dataType, - "vectorDescriptorWithLength:dataType:"); -_MTL_PRIVATE_DEF_SEL( - vectorDescriptorWithLength_vectors_vectorBytes_dataType, - "vectorDescriptorWithLength:vectors:vectorBytes:dataType:"); -_MTL_PRIVATE_DEF_SEL( - initWithDevice_transpose_rows_columns_alpha_beta, - "initWithDevice:transpose:rows:columns:alpha:beta:"); -_MTL_PRIVATE_DEF_SEL( - encodeToCommandBuffer_inputMatrix_inputVector_resultVector, - "encodeToCommandBuffer:inputMatrix:inputVector:resultVector:"); -} // namespace MTL::Private::Selector - -namespace MPS { - -typedef enum DataType : uint32_t { - DataTypeFloatBit = 0x10000000, - DataTypeAlternateEncodingBit = 0x80000000, - DataTypeFloat16 = DataTypeFloatBit | 16, - DataTypeFloat32 = DataTypeFloatBit | 32, - DataTypeBFloat16 = DataTypeAlternateEncodingBit | DataTypeFloat16 -} DataType; - -class MatrixDescriptor : public NS::Copying { - public: - static class MatrixDescriptor* matrixDescriptor( - NS::UInteger rows, - NS::UInteger columns, - NS::UInteger rowBytes, - NS::UInteger dataType); - static class MatrixDescriptor* matrixDescriptor( - NS::UInteger rows, - NS::UInteger columns, - NS::UInteger matrices, - NS::UInteger rowBytes, - NS::UInteger matrixBytes, - NS::UInteger dataType); - NS::UInteger rows() const; -}; - -class Matrix : public NS::Referencing { - public: - static class Matrix* alloc(); - Matrix* init(MTL::Buffer* buffer, MatrixDescriptor* descriptor); - Matrix* init(const MTL::Buffer* buffer, MatrixDescriptor* descriptor); -}; - -class Kernel : public NS::Referencing { - public: - NS::String* label() const; - MTL::Device* device() const; -}; - -class MatrixMultiplication - : public NS::Referencing { - public: - static class MatrixMultiplication* alloc(); - - MatrixMultiplication* init( - MTL::Device* device, - bool transposeLeft, - bool transposeRight, - NS::UInteger resultRows, - NS::UInteger 
resultColumns, - NS::UInteger interiorColumns, - double alpha, - double beta); - - void encodeToCommandBuffer( - MTL::CommandBuffer* commandBuffer, - Matrix* leftMatrix, - Matrix* rightMatrix, - Matrix* resultMatrix); - - void setLeftMatrixOrigin(MTL::Origin origin); - void setRightMatrixOrigin(MTL::Origin origin); - void setResultMatrixOrigin(MTL::Origin origin); - void setBatchStart(NS::UInteger batchStart); - void setBatchSize(NS::UInteger batchSize); -}; - -class VectorDescriptor : public NS::Copying { - public: - static class VectorDescriptor* vectorDescriptor( - NS::UInteger length, - NS::UInteger dataType); - static class VectorDescriptor* vectorDescriptor( - NS::UInteger length, - NS::UInteger vectors, - NS::UInteger vectorBytes, - NS::UInteger dataType); -}; - -class Vector : public NS::Referencing { - public: - static class Vector* alloc(); - Vector* init(MTL::Buffer* buffer, VectorDescriptor* descriptor); - Vector* init(const MTL::Buffer* buffer, VectorDescriptor* descriptor); -}; - -class MatrixVectorMultiplication - : public NS::Referencing { - public: - static class MatrixVectorMultiplication* alloc(); - - MatrixVectorMultiplication* init( - MTL::Device* device, - bool transpose, - NS::UInteger rows, - NS::UInteger columns, - double alpha, - double beta); - - void encodeToCommandBuffer( - MTL::CommandBuffer* commandBuffer, - Matrix* inputMatrix, - Vector* inputVector, - Vector* resultVector); -}; - -_MTL_INLINE MatrixDescriptor* MatrixDescriptor::matrixDescriptor( - NS::UInteger rows, - NS::UInteger columns, - NS::UInteger rowBytes, - NS::UInteger dataType) { - return Object::sendMessage( - _MPS_PRIVATE_CLS(MPSMatrixDescriptor), - _MPS_PRIVATE_SEL(matrixDescriptorWithRows_columns_rowBytes_dataType), - rows, - columns, - rowBytes, - dataType); -} - -_MTL_INLINE MatrixDescriptor* MatrixDescriptor::matrixDescriptor( - NS::UInteger rows, - NS::UInteger columns, - NS::UInteger matrices, - NS::UInteger rowBytes, - NS::UInteger matrixBytes, - NS::UInteger dataType) { - return Object::sendMessage( - _MPS_PRIVATE_CLS(MPSMatrixDescriptor), - _MPS_PRIVATE_SEL( - matrixDescriptorWithRows_columns_matrices_rowBytes_matrixBytes_dataType), - rows, - columns, - matrices, - rowBytes, - matrixBytes, - dataType); -} - -_MTL_INLINE NS::UInteger MatrixDescriptor::rows() const { - return Object::sendMessage(this, _MPS_PRIVATE_SEL(rows)); -} - -_MTL_INLINE Matrix* Matrix::alloc() { - return NS::Object::alloc(_MPS_PRIVATE_CLS(MPSMatrix)); -} - -_MTL_INLINE Matrix* Matrix::init( - MTL::Buffer* buffer, - MatrixDescriptor* descriptor) { - return Object::sendMessage( - this, _MPS_PRIVATE_SEL(initWithBuffer_descriptor), buffer, descriptor); -} - -_MTL_INLINE Matrix* Matrix::init( - const MTL::Buffer* buffer, - MatrixDescriptor* descriptor) { - return init(const_cast(buffer), descriptor); -} - -_MTL_INLINE NS::String* Kernel::label() const { - return Object::sendMessage(this, _MPS_PRIVATE_SEL(label)); -} - -_MTL_INLINE MTL::Device* Kernel::device() const { - return Object::sendMessage(this, _MPS_PRIVATE_SEL(device)); -} - -_MTL_INLINE MatrixMultiplication* MatrixMultiplication::alloc() { - return NS::Object::alloc( - _MPS_PRIVATE_CLS(MPSMatrixMultiplication)); -} - -_MTL_INLINE MatrixMultiplication* MatrixMultiplication::init( - MTL::Device* device, - bool transposeLeft, - bool transposeRight, - NS::UInteger resultRows, - NS::UInteger resultColumns, - NS::UInteger interiorColumns, - double alpha, - double beta) { - return Object::sendMessage( - this, - _MPS_PRIVATE_SEL(initWithDevice_), - device, - 
transposeLeft, - transposeRight, - resultRows, - resultColumns, - interiorColumns, - alpha, - beta); -} - -_MTL_INLINE void MatrixMultiplication::encodeToCommandBuffer( - MTL::CommandBuffer* commandBuffer, - Matrix* leftMatrix, - Matrix* rightMatrix, - Matrix* resultMatrix) { - return Object::sendMessage( - this, - _MPS_PRIVATE_SEL( - encodeToCommandBuffer_leftMatrix_rightMatrix_resultMatrix), - commandBuffer, - leftMatrix, - rightMatrix, - resultMatrix); -} - -_MTL_INLINE void MatrixMultiplication::setLeftMatrixOrigin(MTL::Origin origin) { - Object::sendMessage( - this, _MPS_PRIVATE_SEL(setLeftMatrixOrigin_), origin); -} - -_MTL_INLINE void MatrixMultiplication::setRightMatrixOrigin( - MTL::Origin origin) { - Object::sendMessage( - this, _MPS_PRIVATE_SEL(setRightMatrixOrigin_), origin); -} - -_MTL_INLINE void MatrixMultiplication::setResultMatrixOrigin( - MTL::Origin origin) { - Object::sendMessage( - this, _MPS_PRIVATE_SEL(setResultMatrixOrigin_), origin); -} - -_MTL_INLINE void MatrixMultiplication::setBatchStart(NS::UInteger batchStart) { - Object::sendMessage(this, _MPS_PRIVATE_SEL(setBatchStart_), batchStart); -} - -_MTL_INLINE void MatrixMultiplication::setBatchSize(NS::UInteger batchSize) { - Object::sendMessage(this, _MPS_PRIVATE_SEL(setBatchSize_), batchSize); -} - -_MTL_INLINE VectorDescriptor* VectorDescriptor::vectorDescriptor( - NS::UInteger length, - NS::UInteger dataType) { - return Object::sendMessage( - _MPS_PRIVATE_CLS(MPSVectorDescriptor), - _MPS_PRIVATE_SEL(vectorDescriptorWithLength_dataType), - length, - dataType); -} - -_MTL_INLINE VectorDescriptor* VectorDescriptor::vectorDescriptor( - NS::UInteger length, - NS::UInteger vectors, - NS::UInteger vectorBytes, - NS::UInteger dataType) { - return Object::sendMessage( - _MPS_PRIVATE_CLS(MPSVectorDescriptor), - _MPS_PRIVATE_SEL(vectorDescriptorWithLength_vectors_vectorBytes_dataType), - length, - vectors, - vectorBytes, - dataType); -} - -_MTL_INLINE Vector* Vector::alloc() { - return NS::Object::alloc(_MPS_PRIVATE_CLS(MPSVector)); -} - -_MTL_INLINE Vector* Vector::init( - MTL::Buffer* buffer, - VectorDescriptor* descriptor) { - return Object::sendMessage( - this, _MPS_PRIVATE_SEL(initWithBuffer_descriptor), buffer, descriptor); -} - -_MTL_INLINE Vector* Vector::init( - const MTL::Buffer* buffer, - VectorDescriptor* descriptor) { - return init(const_cast(buffer), descriptor); -} - -_MTL_INLINE MatrixVectorMultiplication* MatrixVectorMultiplication::alloc() { - return NS::Object::alloc( - _MPS_PRIVATE_CLS(MPSMatrixVectorMultiplication)); -} - -_MTL_INLINE MatrixVectorMultiplication* MatrixVectorMultiplication::init( - MTL::Device* device, - bool transpose, - NS::UInteger rows, - NS::UInteger columns, - double alpha, - double beta) { - return Object::sendMessage( - this, - _MPS_PRIVATE_SEL(initWithDevice_transpose_rows_columns_alpha_beta), - device, - transpose, - rows, - columns, - alpha, - beta); -} - -_MTL_INLINE void MatrixVectorMultiplication::encodeToCommandBuffer( - MTL::CommandBuffer* commandBuffer, - Matrix* inputMatrix, - Vector* inputVector, - Vector* resultVector) { - return Object::sendMessage( - this, - _MPS_PRIVATE_SEL( - encodeToCommandBuffer_inputMatrix_inputVector_resultVector), - commandBuffer, - inputMatrix, - inputVector, - resultVector); -} - -} // namespace MPS \ No newline at end of file diff --git a/mlx/backend/metal/nojit_kernels.cpp b/mlx/backend/metal/nojit_kernels.cpp index ce8a3f582..361e28916 100644 --- a/mlx/backend/metal/nojit_kernels.cpp +++ b/mlx/backend/metal/nojit_kernels.cpp @@ -169,6 
+169,23 @@ MTL::ComputePipelineState* get_steel_gemm_masked_kernel( return d.get_kernel(kernel_name); } +MTL::ComputePipelineState* get_gemv_masked_kernel( + metal::Device& d, + const std::string& kernel_name, + const array&, + const std::optional&, + const std::optional&, + bool, + int, + int, + int, + int, + int, + int, + bool) { + return d.get_kernel(kernel_name); +} + MTL::ComputePipelineState* get_steel_conv_kernel( metal::Device& d, const std::string& kernel_name, diff --git a/mlx/backend/metal/ternary.cpp b/mlx/backend/metal/ternary.cpp index 66364e7de..8f4371939 100644 --- a/mlx/backend/metal/ternary.cpp +++ b/mlx/backend/metal/ternary.cpp @@ -32,7 +32,7 @@ void ternary_op_gpu_inplace( auto& strides_c = strides[2]; auto& strides_out = strides[3]; - bool use_2d = out.data_size(); + bool use_2d = out.data_size() > UINT_MAX; std::string kernel_name; { std::ostringstream kname; diff --git a/mlx/backend/metal/utils.cpp b/mlx/backend/metal/utils.cpp new file mode 100644 index 000000000..bbb80acbd --- /dev/null +++ b/mlx/backend/metal/utils.cpp @@ -0,0 +1,116 @@ +// Copyright © 2023-2024 Apple Inc. + +#include "mlx/backend/metal/utils.h" + +using namespace mlx; + +namespace mlx::core { + +std::string type_to_name(const array& a) { + std::string tname; + switch (a.dtype()) { + case bool_: + tname = "bool_"; + break; + case uint8: + tname = "uint8"; + break; + case uint16: + tname = "uint16"; + break; + case uint32: + tname = "uint32"; + break; + case uint64: + tname = "uint64"; + break; + case int8: + tname = "int8"; + break; + case int16: + tname = "int16"; + break; + case int32: + tname = "int32"; + break; + case int64: + tname = "int64"; + break; + case float16: + tname = "float16"; + break; + case float32: + tname = "float32"; + break; + case bfloat16: + tname = "bfloat16"; + break; + case complex64: + tname = "complex64"; + break; + } + return tname; +} + +MTL::Size get_block_dims(int dim0, int dim1, int dim2) { + int pows[3] = {0, 0, 0}; + int sum = 0; + while (true) { + int presum = sum; + // Check all the pows + if (dim0 >= (1 << (pows[0] + 1))) { + pows[0]++; + sum++; + } + if (sum == 10) { + break; + } + if (dim1 >= (1 << (pows[1] + 1))) { + pows[1]++; + sum++; + } + if (sum == 10) { + break; + } + if (dim2 >= (1 << (pows[2] + 1))) { + pows[2]++; + sum++; + } + if (sum == presum || sum == 10) { + break; + } + } + return MTL::Size{1ul << pows[0], 1ul << pows[1], 1ul << pows[2]}; +} + +MTL::Size get_2d_grid_dims( + const std::vector& shape, + const std::vector& strides) { + // Dims with strides of 0 are ignored as they + // correspond to broadcasted dimensions + size_t grid_x = 1; + size_t grid_y = 1; + for (int i = 0; i < shape.size(); ++i) { + if (strides[i] == 0) { + continue; + } + if (grid_x * shape[i] < UINT32_MAX) { + grid_x *= shape[i]; + } else { + grid_y *= shape[i]; + } + } + if (grid_y > UINT32_MAX || grid_x > UINT32_MAX) { + throw std::runtime_error("Unable to safely factor shape."); + } + return MTL::Size( + static_cast(grid_x), static_cast(grid_y), 1); +} + +std::string get_primitive_string(Primitive* primitive) { + std::ostringstream op_t; + primitive->print(op_t); + return op_t.str(); +} + +} // namespace mlx::core diff --git a/mlx/backend/metal/utils.h b/mlx/backend/metal/utils.h index 0a24e431a..af85281a9 100644 --- a/mlx/backend/metal/utils.h +++ b/mlx/backend/metal/utils.h @@ -8,8 +8,6 @@ namespace mlx::core { -namespace { - using metal::CommandEncoder; template @@ -27,82 +25,13 @@ set_vector_bytes(CommandEncoder& enc, const std::vector& vec, int idx) { return 
set_vector_bytes(enc, vec, vec.size(), idx); } -std::string type_to_name(const array& a) { - std::string tname; - switch (a.dtype()) { - case bool_: - tname = "bool_"; - break; - case uint8: - tname = "uint8"; - break; - case uint16: - tname = "uint16"; - break; - case uint32: - tname = "uint32"; - break; - case uint64: - tname = "uint64"; - break; - case int8: - tname = "int8"; - break; - case int16: - tname = "int16"; - break; - case int32: - tname = "int32"; - break; - case int64: - tname = "int64"; - break; - case float16: - tname = "float16"; - break; - case float32: - tname = "float32"; - break; - case bfloat16: - tname = "bfloat16"; - break; - case complex64: - tname = "complex64"; - break; - } - return tname; -} +std::string type_to_name(const array& a); -MTL::Size get_block_dims(int dim0, int dim1, int dim2) { - int pows[3] = {0, 0, 0}; - int sum = 0; - while (true) { - int presum = sum; - // Check all the pows - if (dim0 >= (1 << (pows[0] + 1))) { - pows[0]++; - sum++; - } - if (sum == 10) { - break; - } - if (dim1 >= (1 << (pows[1] + 1))) { - pows[1]++; - sum++; - } - if (sum == 10) { - break; - } - if (dim2 >= (1 << (pows[2] + 1))) { - pows[2]++; - sum++; - } - if (sum == presum || sum == 10) { - break; - } - } - return MTL::Size{1ul << pows[0], 1ul << pows[1], 1ul << pows[2]}; -} +// Compute the thread block dimensions which fit the given +// input dimensions. +// - The thread block dimensions will be powers of two +// - The thread block size will be less than 1024 +MTL::Size get_block_dims(int dim0, int dim1, int dim2); // Computes a 2D grid where each element is < UINT_MAX // Assumes: @@ -111,27 +40,7 @@ MTL::Size get_block_dims(int dim0, int dim1, int dim2) { // possibly broadcasted array MTL::Size get_2d_grid_dims( const std::vector& shape, - const std::vector& strides) { - // Dims with strides of 0 are ignored as they - // correspond to broadcasted dimensions - size_t grid_x = 1; - size_t grid_y = 1; - for (int i = 0; i < shape.size(); ++i) { - if (strides[i] == 0) { - continue; - } - if (grid_x * shape[i] < UINT32_MAX) { - grid_x *= shape[i]; - } else { - grid_y *= shape[i]; - } - } - if (grid_y > UINT32_MAX || grid_x > UINT32_MAX) { - throw std::runtime_error("Unable to safely factor shape."); - } - return MTL::Size( - static_cast(grid_x), static_cast(grid_y), 1); -} + const std::vector& strides); inline NS::String* make_string(std::ostringstream& os) { std::string string = os.str(); @@ -159,12 +68,6 @@ inline void debug_set_primitive_buffer_label( #endif } -std::string get_primitive_string(Primitive* primitive) { - std::ostringstream op_t; - primitive->print(op_t); - return op_t.str(); -} - -} // namespace +std::string get_primitive_string(Primitive* primitive); } // namespace mlx::core diff --git a/mlx/dtype.cpp b/mlx/dtype.cpp index 465e51bd5..cf78a7805 100644 --- a/mlx/dtype.cpp +++ b/mlx/dtype.cpp @@ -1,11 +1,8 @@ // Copyright © 2023-2024 Apple Inc. #include -#include -#include #include "mlx/dtype.h" -#include "mlx/utils.h" namespace mlx::core { @@ -178,67 +175,4 @@ bool issubdtype(const Dtype::Category& a, const Dtype::Category& b) { [static_cast(b)]; } -// Array protocol typestring for Dtype -std::string dtype_to_array_protocol(const Dtype& t) { - std::ostringstream r; - if (size_of(t) > 1) - r << (is_big_endian() ? 
">" : "<"); - else - r << "|"; - r << kindof(t) << (int)size_of(t); - return r.str(); -} - -// Dtype from array protocol type string -Dtype dtype_from_array_protocol(std::string_view t) { - if (t.length() == 2 || t.length() == 3) { - std::string_view r = t.length() == 3 ? t.substr(1, 2) : t; - - if (r == "V2") { - return bfloat16; - } - - uint8_t size = r[1] - '0'; - - switch (r[0]) { - case 'b': { - if (size == 1) - return bool_; - } - case 'i': { - if (size == 1) - return int8; - else if (size == 2) - return int16; - else if (size == 4) - return int32; - else if (size == 8) - return int64; - } - case 'u': { - if (size == 1) - return uint8; - else if (size == 2) - return uint16; - else if (size == 4) - return uint32; - else if (size == 8) - return uint64; - } - case 'f': { - if (size == 2) - return float16; - else if (size == 4) - return float32; - } - case 'c': { - return complex64; - } - } - } - - throw std::invalid_argument( - "[from_str] Invalid array protocol type-string: " + std::string(t)); -} - } // namespace mlx::core diff --git a/mlx/dtype.h b/mlx/dtype.h index bba650278..5f9ee27cc 100644 --- a/mlx/dtype.h +++ b/mlx/dtype.h @@ -4,8 +4,6 @@ #include #include -#include -#include #include "mlx/types/complex.h" #include "mlx/types/half_types.h" @@ -103,9 +101,4 @@ struct TypeToDtype { operator Dtype(); }; -// Array protocol typestring for Dtype -std::string dtype_to_array_protocol(const Dtype& t); -// Dtype from array protocol type string -Dtype dtype_from_array_protocol(std::string_view t); - } // namespace mlx::core diff --git a/mlx/io/load.cpp b/mlx/io/load.cpp index 294a1229f..d9403e4b2 100644 --- a/mlx/io/load.cpp +++ b/mlx/io/load.cpp @@ -26,6 +26,80 @@ constexpr uint8_t MAGIC[] = { 0x59, }; +inline bool is_big_endian() { + union ByteOrder { + int32_t i; + uint8_t c[4]; + }; + ByteOrder b = {0x01234567}; + + return b.c[0] == 0x01; +} + +// Array protocol typestring for Dtype +std::string dtype_to_array_protocol(const Dtype& t) { + std::ostringstream r; + if (size_of(t) > 1) { + r << (is_big_endian() ? ">" : "<"); + } else { + r << "|"; + } + r << kindof(t) << (int)size_of(t); + return r.str(); +} + +// Dtype from array protocol type string +Dtype dtype_from_array_protocol(std::string_view t) { + if (t.length() == 2 || t.length() == 3) { + std::string_view r = t.length() == 3 ? t.substr(1, 2) : t; + + if (r == "V2") { + return bfloat16; + } + + uint8_t size = r[1] - '0'; + + switch (r[0]) { + case 'b': { + if (size == 1) + return bool_; + } + case 'i': { + if (size == 1) + return int8; + else if (size == 2) + return int16; + else if (size == 4) + return int32; + else if (size == 8) + return int64; + } + case 'u': { + if (size == 1) + return uint8; + else if (size == 2) + return uint16; + else if (size == 4) + return uint32; + else if (size == 8) + return uint64; + } + case 'f': { + if (size == 2) + return float16; + else if (size == 4) + return float32; + } + case 'c': { + return complex64; + } + } + } + + throw std::invalid_argument( + "[from_str] Invalid array protocol type-string: " + std::string(t)); +} + } // namespace /** Save array to out stream in .npy format */ diff --git a/mlx/utils.h b/mlx/utils.h index ee2512554..9f39f5092 100644 --- a/mlx/utils.h +++ b/mlx/utils.h @@ -84,16 +84,6 @@ int check_shape_dim(const T dim) { return static_cast(dim); } -inline bool is_big_endian() { - union ByteOrder { - int32_t i; - uint8_t c[4]; - }; - ByteOrder b = {0x01234567}; - - return b.c[0] == 0x01; -} - /** * Returns the axis normalized to be in the range [0, ndim). 
* Based on numpy's normalize_axis_index. See diff --git a/tests/array_tests.cpp b/tests/array_tests.cpp index 62341c5c7..5144243e7 100644 --- a/tests/array_tests.cpp +++ b/tests/array_tests.cpp @@ -161,6 +161,8 @@ TEST_CASE("test array types") { // bfloat16 { basic_dtype_test(bfloat16_t, bfloat16); } +#undef basic_dtype_test + // uint32 { uint32_t val = UINT_MAX; @@ -233,31 +235,6 @@ TEST_CASE("test array types") { CHECK_EQ(x.dtype(), complex64); CHECK_EQ(x.item(), v); } - -#undef basic_dtype_test - -#define basic_dtype_str_test(s, dtype) \ - CHECK_EQ(s, dtype_to_array_protocol(dtype)); \ - CHECK_EQ(dtype_from_array_protocol(s), dtype); - - // To and from str - { - basic_dtype_str_test("|b1", bool_); - basic_dtype_str_test("|u1", uint8); - basic_dtype_str_test("
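An illustrative sketch, not part of this patch: the array-protocol typestrings handled by dtype_to_array_protocol / dtype_from_array_protocol (now internal to mlx/io/load.cpp) use '|' for one-byte types and '<' or '>' for the host byte order on wider types, followed by a kind character and the element size in bytes. The helper below only mirrors that format for illustration; the name is hypothetical.

#include <string>

std::string example_typestring(char kind, int size_bytes, bool big_endian) {
  std::string s;
  // One-byte types carry no byte order; wider types record the host order.
  s += (size_bytes > 1) ? (big_endian ? '>' : '<') : '|';
  s += kind;
  s += std::to_string(size_bytes);
  return s;
}

// On a little-endian host this reproduces the strings exercised above:
//   example_typestring('b', 1, false) == "|b1"   (bool_)
//   example_typestring('u', 1, false) == "|u1"   (uint8)
//   example_typestring('f', 4, false) == "<f4"   (float32)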