diff --git a/mlx/backend/cuda/device/binary_ops.cuh b/mlx/backend/cuda/device/binary_ops.cuh index 4779a6f33..b96a7f9cc 100644 --- a/mlx/backend/cuda/device/binary_ops.cuh +++ b/mlx/backend/cuda/device/binary_ops.cuh @@ -194,6 +194,13 @@ struct Power { } return res; } else if constexpr (cuda::std::is_same_v) { + if (base.y == 0 && base.x == 0) { + if (isnan(exp.x) || isnan(exp.y)) { + auto nan = cuda::std::numeric_limits::quiet_NaN(); + return make_cuFloatComplex(nan, nan); + } + return make_cuFloatComplex(0.0, 0.0); + } auto x_theta = atan2f(base.y, base.x); auto x_ln_r = 0.5 * logf(base.x * base.x + base.y * base.y); auto mag = expf(exp.x * x_ln_r - exp.y * x_theta); diff --git a/mlx/backend/cuda/jit_module.cpp b/mlx/backend/cuda/jit_module.cpp index b8be103cc..8a033523c 100644 --- a/mlx/backend/cuda/jit_module.cpp +++ b/mlx/backend/cuda/jit_module.cpp @@ -145,7 +145,7 @@ bool compiler_supports_device_sass(Device& device) { } } -#define INCLUDE_PREFIX "mlx/backend/cuda/kernels/" +#define INCLUDE_PREFIX "mlx/backend/cuda/device/" constexpr const char* g_include_names[] = { INCLUDE_PREFIX "atomic_ops.cuh", diff --git a/mlx/backend/cuda/matmul.cpp b/mlx/backend/cuda/matmul.cpp index 89247fd3e..9930c75b8 100644 --- a/mlx/backend/cuda/matmul.cpp +++ b/mlx/backend/cuda/matmul.cpp @@ -44,9 +44,12 @@ class MatMul { int64_t b_batch_stride) { heuristic_.state = CUBLAS_STATUS_NOT_INITIALIZED; - auto type = dtype_to_cuda_type(dtype); + auto scale_type = dtype_to_cuda_type(dtype); + if (dtype == bfloat16) { + scale_type = CUDA_R_32F; + } CHECK_CUBLAS_ERROR(cublasLtMatmulDescCreate( - &matmul_desc_, dtype_to_compute_type(dtype), type)); + &matmul_desc_, dtype_to_compute_type(dtype), scale_type)); int32_t pointer_mode = CUBLASLT_POINTER_MODE_HOST; CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute( matmul_desc_, @@ -65,6 +68,7 @@ class MatMul { &op, sizeof(cublasOperation_t))); + auto type = dtype_to_cuda_type(dtype); a_desc_ = create_matrix_layout( type, a_rows, a_cols, a_transposed, lda, batch_count, a_batch_stride); b_desc_ = create_matrix_layout( @@ -187,15 +191,10 @@ class MatMul { private: cublasComputeType_t dtype_to_compute_type(Dtype dtype) { switch (dtype) { - case uint8: - case uint16: - case int8: - case int16: - case int32: - return CUBLAS_COMPUTE_32I; case float16: - case bfloat16: return CUBLAS_COMPUTE_16F; + case bfloat16: + return CUBLAS_COMPUTE_32F; case float32: return CUBLAS_COMPUTE_32F; case float64: @@ -209,16 +208,6 @@ class MatMul { cudaDataType_t dtype_to_cuda_type(Dtype dtype) { switch (dtype) { - case uint8: - return CUDA_R_8U; - case uint16: - return CUDA_R_16U; - case int8: - return CUDA_R_8I; - case int16: - return CUDA_R_16I; - case int32: - return CUDA_R_32I; case float16: return CUDA_R_16F; case bfloat16: diff --git a/mlx/backend/metal/conv.cpp b/mlx/backend/metal/conv.cpp index 697afa6a1..9eb6a6385 100644 --- a/mlx/backend/metal/conv.cpp +++ b/mlx/backend/metal/conv.cpp @@ -155,26 +155,26 @@ void explicit_gemm_conv_group_ND_gpu( // Perform gemm std::vector copies = {in_unfolded, wt_transpose}; return steel_matmul_regular( - s, - d, - /* a = */ in_unfolded, - /* b = */ wt_transpose, - /* c = */ out, - /* M = */ implicit_M, - /* N = */ implicit_N, - /* K = */ implicit_K, - /* batch_size_out = */ groups, - /* a_cols = */ implicit_K * groups, - /* b_cols = */ implicit_K, - /* out_cols = */ implicit_N * groups, - /* a_transposed = */ false, - /* b_transposed = */ true, - /* batch_shape = */ {1}, - /* batch_strides = */ {0}, - /* A_batch_strides = */ size_t(implicit_K), - /* 
B_batch_strides = */ size_t(implicit_N) * implicit_K, - /* matrix_stride_out = */ size_t(implicit_N), - /*copies = */ copies); + /* const Stream& s = */ s, + /* Device& d = */ d, + /* const array& a = */ in_unfolded, + /* const array& b = */ wt_transpose, + /* array& c = */ out, + /* int M = */ implicit_M, + /* int N = */ implicit_N, + /* int K = */ implicit_K, + /* int batch_size_out = */ groups, + /* int lda = */ implicit_K * groups, + /* int ldb = */ implicit_K, + /* int ldd = */ implicit_N * groups, + /* bool transpose_a = */ false, + /* bool transpose_b = */ true, + /* std::vector& copies = */ copies, + /* Shape batch_shape = */ {1}, + /* Strides batch_strides = */ {0}, + /* int64_t A_batch_strides = */ int64_t(implicit_K), + /* int64_t B_batch_strides = */ int64_t(implicit_N) * implicit_K, + /* int64_t matrix_stride_out = */ int64_t(implicit_N)); } void implicit_gemm_conv_2D_gpu( diff --git a/mlx/backend/metal/device.cpp b/mlx/backend/metal/device.cpp index 425274361..88835eb75 100644 --- a/mlx/backend/metal/device.cpp +++ b/mlx/backend/metal/device.cpp @@ -297,6 +297,9 @@ Device::Device() { device_ = load_device(); default_library_ = load_default_library(device_); arch_ = std::string(device_->architecture()->name()->utf8String()); + int ag_tens = arch_[arch_.size() - 3] - '0'; + int ag_ones = arch_[arch_.size() - 2] - '0'; + arch_gen_ = ag_tens * 10 + ag_ones; auto arch = arch_.back(); switch (arch) { case 'p': // phone diff --git a/mlx/backend/metal/device.h b/mlx/backend/metal/device.h index 5bfcc6649..f87a8c48b 100644 --- a/mlx/backend/metal/device.h +++ b/mlx/backend/metal/device.h @@ -177,6 +177,10 @@ class Device { return arch_; } + int get_architecture_gen() const { + return arch_gen_; + } + void new_queue(int index); MTL::CommandQueue* get_queue(Stream stream); @@ -268,6 +272,7 @@ class Device { library_kernels_; const MTL::ResidencySet* residency_set_{nullptr}; std::string arch_; + int arch_gen_; int max_ops_per_buffer_; int max_mb_per_buffer_; }; diff --git a/mlx/backend/metal/kernels/binary_ops.h b/mlx/backend/metal/kernels/binary_ops.h index 4aaf2b4da..f4deb860e 100644 --- a/mlx/backend/metal/kernels/binary_ops.h +++ b/mlx/backend/metal/kernels/binary_ops.h @@ -235,6 +235,13 @@ struct Power { template <> complex64_t operator()(complex64_t x, complex64_t y) { + if (x.real == 0 && x.imag == 0) { + if (metal::isnan(y.real) || metal::isnan(y.imag)) { + auto nan = metal::numeric_limits::quiet_NaN(); + return {nan, nan}; + } + return {0.0, 0.0}; + } auto x_theta = metal::atan2(x.imag, x.real); auto x_ln_r = 0.5 * metal::log(x.real * x.real + x.imag * x.imag); auto mag = metal::exp(y.real * x_ln_r - y.imag * x_theta); diff --git a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h index add495d93..85830872d 100644 --- a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +++ b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h @@ -33,8 +33,8 @@ template < device T* D [[buffer(3)]], const constant GEMMParams* params [[buffer(4)]], const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]], - const constant int* batch_shape [[buffer(6)]], - const constant int64_t* batch_strides [[buffer(7)]], + const constant int* batch_shape [[buffer(6), function_constant(has_batch)]], + const constant int64_t* batch_strides [[buffer(7), function_constant(has_batch)]], uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_group_id 
[[simdgroup_index_in_threadgroup]], uint3 tid [[threadgroup_position_in_grid]], diff --git a/mlx/backend/metal/matmul.cpp b/mlx/backend/metal/matmul.cpp index ed96d37ea..be7f3e2f8 100644 --- a/mlx/backend/metal/matmul.cpp +++ b/mlx/backend/metal/matmul.cpp @@ -164,11 +164,17 @@ ensure_batch_contiguous(const array& x, metal::Device& d, const Stream& s) { wn = 2; \ } -void steel_matmul_regular( +/////////////////////////////////////////////////////////////////////////////// +// Regular steel matmul dispatch +/////////////////////////////////////////////////////////////////////////////// + +template +void steel_matmul_regular_axpby( const Stream& s, metal::Device& d, const array& a, const array& b, + const array& c, array& out, int M, int N, @@ -179,12 +185,15 @@ void steel_matmul_regular( int ldd, bool transpose_a, bool transpose_b, + std::vector& copies, Shape batch_shape, Strides batch_strides, int64_t A_batch_stride, int64_t B_batch_stride, int64_t matrix_stride_out, - std::vector& copies) { + int64_t C_batch_stride /* = 0*/, + float alpha /* = 1.0f */, + float beta /* = 0.0f */) { using namespace mlx::steel; // Determine dispatch kernel @@ -196,16 +205,21 @@ void steel_matmul_regular( // Prepare kernel name std::ostringstream kname; - kname << "steel_gemm_fused_" << (transpose_a ? 't' : 'n') - << (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_" - << type_to_name(out) << "_bm" << bm << "_bn" << bn << "_bk" << bk - << "_wm" << wm << "_wn" << wn; + + // clang-format off + kname << "steel_gemm_fused_" + << (transpose_a ? 't' : 'n') + << (transpose_b ? 't' : 'n') + << "_" << type_to_name(a) + << "_" << type_to_name(out) + << "_bm" << bm << "_bn" << bn << "_bk" << bk + << "_wm" << wm << "_wn" << wn; // clang-format on std::string base_name = kname.str(); const bool has_batch = (batch_shape.size() > 1); - const bool use_out_source = false; - const bool do_axpby = false; + const bool use_out_source = CHECK_AB && (alpha != 0.0f || beta != 1.0f); + const bool do_axpby = use_out_source && (alpha != 1.0f || beta != 1.0f); const bool align_M = (M % bm) == 0; const bool align_N = (N % bn) == 0; const bool align_K = (K % bk) == 0; @@ -232,18 +246,18 @@ void steel_matmul_regular( // Encode and dispatch kernel auto& compute_encoder = d.get_command_encoder(s.index); auto kernel = get_steel_gemm_fused_kernel( - d, - base_name, - hash_name, - func_consts, - out, - transpose_a, - transpose_b, - bm, - bn, - bk, - wm, - wn); + /* metal::Device& d = */ d, + /* const std::string& kernel_name = */ base_name, + /* const std::string& hash_name = */ hash_name, + /* const metal::MTLFCList& func_consts = */ func_consts, + /* const array& out = */ out, + /* bool transpose_a = */ transpose_a, + /* bool transpose_b = */ transpose_b, + /* int bm = */ bm, + /* int bn = */ bn, + /* int bk = */ bk, + /* int wm = */ wm, + /* int wn = */ wn); compute_encoder.set_compute_pipeline_state(kernel); @@ -286,8 +300,25 @@ void steel_matmul_regular( compute_encoder.set_bytes(params, 4); - compute_encoder.set_vector_bytes(batch_shape, 6); - compute_encoder.set_vector_bytes(batch_strides, 7); + if (has_batch) { + compute_encoder.set_vector_bytes(batch_shape, 6); + compute_encoder.set_vector_bytes(batch_strides, 7); + } + + if (use_out_source) { + int ldc = c.strides()[c.ndim() - 2]; + int fdc = c.strides()[c.ndim() - 1]; + + GEMMAddMMParams params{ + /* const int ldc = */ ldc, + /* const int fdc = */ fdc, + /* const int64_t batch_stride_c = */ C_batch_stride, + /* const float alpha = */ alpha, + /* const float beta = */ beta}; 
+ + compute_encoder.set_input_array(c, 2); + compute_encoder.set_bytes(params, 5); + } compute_encoder.dispatch_threadgroups(grid_dims, group_dims); @@ -295,7 +326,437 @@ void steel_matmul_regular( d.add_temporaries(std::move(copies), s.index); } -void steel_matmul( +/////////////////////////////////////////////////////////////////////////////// +// Split k steel matmul +/////////////////////////////////////////////////////////////////////////////// + +template +void steel_gemm_splitk_axpby( + const Stream& s, + metal::Device& d, + const array& a, + const array& b, + const array& c, + array& out, + int M, + int N, + int K, + int batch_size_out, + int lda, + int ldb, + bool transpose_a, + bool transpose_b, + std::vector& copies, + float alpha = 1.0f, + float beta = 0.0f) { + using namespace mlx::steel; + + int _tm = M / 16; + int _tn = N / 16; + int _tk = K / 16; + + int bm = M < 40 ? 16 : 32; + int bn = N < 40 ? 16 : 32; + int bk = 16; + int wm = 2, wn = 2; + + int split_k_partitions = _tk < 16 ? 2 : (_tk < 32 ? 4 : (_tk < 64 ? 8 : 16)); + int split_k_partition_stride = M * N; + int gemm_k_iterations = (K / bk) / split_k_partitions; + int split_k_partition_size = gemm_k_iterations * bk; + + array C_split({split_k_partitions, M, N}, float32, nullptr, {}); + C_split.set_data(allocator::malloc(C_split.nbytes())); + copies.push_back(C_split); + + bool mn_aligned = M % bm == 0 && N % bn == 0; + bool k_aligned = K % bk == 0; + std::ostringstream kname; + + // clang-format off + kname << "steel_gemm_splitk_" + << (transpose_a ? 't' : 'n') + << (transpose_b ? 't' : 'n') + << "_" << type_to_name(a) + << "_" << type_to_name(C_split) + << "_bm" << bm << "_bn" << bn << "_bk" << bk + << "_wm" << wm << "_wn" << wn + << "_MN_" << (mn_aligned ? "t" : "n") << "aligned" + << "_K_" << (k_aligned ? 
"t" : "n") << "aligned"; // clang-format on + + // Encode and dispatch gemm kernel + auto& compute_encoder = d.get_command_encoder(s.index); + auto kernel = get_steel_gemm_splitk_kernel( + /* metal::Device& d = */ d, + /* const std::string& kernel_name = */ kname.str(), + /* const array& in = */ a, + /* const array& out = */ C_split, + /* bool transpose_a = */ transpose_a, + /* bool transpose_b = */ transpose_b, + /* int bm = */ bm, + /* int bn = */ bn, + /* int bk = */ bk, + /* int wm = */ wm, + /* int wn = */ wn, + /* bool mn_aligned = */ mn_aligned, + /* bool k_aligned = */ k_aligned); + + compute_encoder.set_compute_pipeline_state(kernel); + + int tn = (N + bn - 1) / bn; + int tm = (M + bm - 1) / bm; + + GEMMSpiltKParams params{ + /* const int M = */ M, + /* const int N = */ N, + /* const int K = */ K, + /* const int lda = */ lda, + /* const int ldb = */ ldb, + /* const int ldc = */ N, + /* const int tiles_n = */ tn, + /* const int tiles_m = */ tm, + /* const int split_k_partitions = */ split_k_partitions, + /* const int split_k_partition_stride = */ split_k_partition_stride, + /* const int split_k_partition_size = */ split_k_partition_size, + /* const int gemm_k_iterations_aligned = */ gemm_k_iterations}; + + MTL::Size group_dims = MTL::Size(32, wn, wm); + MTL::Size grid_dims = MTL::Size(tn, tm, split_k_partitions); + + compute_encoder.set_input_array(a, 0); + compute_encoder.set_input_array(b, 1); + compute_encoder.set_output_array(C_split, 2); + + compute_encoder.set_bytes(params, 3); + compute_encoder.dispatch_threadgroups(grid_dims, group_dims); + + // Do accum kernel + { + const bool do_axpby = CHECK_AB && (alpha != 1.0f || beta != 0.0f); + + auto kernel_name = "steel_gemm_splitk_accum_" + type_to_name(out) + "_" + + type_to_name(C_split); + + if (do_axpby) { + kernel_name = kernel_name + "_axbpy"; + } + + auto kernel = get_steel_gemm_splitk_accum_kernel( + /* metal::Device& d = */ d, + /* const std::string& kernel_name = */ kernel_name, + /* const array& in = */ C_split, + /* const array& out = */ out, + /* bool axbpy = */ do_axpby); + compute_encoder.set_compute_pipeline_state(kernel); + + // Set the arguments for the kernel + compute_encoder.set_input_array(C_split, 0); + compute_encoder.set_output_array(out, 1); + compute_encoder.set_bytes(split_k_partitions, 2); + compute_encoder.set_bytes(split_k_partition_stride, 3); + compute_encoder.set_bytes(N, 4); + + if (do_axpby) { + int ldc = c.strides()[c.ndim() - 2]; + int fdc = c.strides()[c.ndim() - 1]; + + compute_encoder.set_input_array(c, 5); + compute_encoder.set_bytes(ldc, 6); + compute_encoder.set_bytes(fdc, 7); + compute_encoder.set_bytes(alpha, 8); + compute_encoder.set_bytes(beta, 9); + } + + // Launch enough thread groups for each output + MTL::Size grid_dims = MTL::Size(N, M, 1); + auto group_dims = get_block_dims(N, M, 1); + compute_encoder.dispatch_threads(grid_dims, group_dims); + } + + d.add_temporaries(std::move(copies), s.index); +} + +/////////////////////////////////////////////////////////////////////////////// +// Split matmul routing +/////////////////////////////////////////////////////////////////////////////// + +template +void steel_matmul_axpby( + const Stream& s, + metal::Device& d, + const array& a, + const array& b, + const array& c, + array& out, + int M, + int N, + int K, + int batch_size_out, + int lda, + int ldb, + bool transpose_a, + bool transpose_b, + std::vector& copies, + Shape batch_shape /* = {} */, + Strides A_batch_stride /* = {} */, + Strides B_batch_stride /* = {} */, + Strides 
C_batch_stride /* = {} */, + float alpha /* = 1.0f */, + float beta /* = 0.0f */) { + if (batch_shape.empty()) { + ///////////////////////////////////////////////////////////////////////////// + // Check and collapse batch dimensions + if constexpr (CHECK_AB) { + auto [batch_shape_, A_bstride_, B_bstride_, C_bstride_] = + collapse_batches(a, b, c); + + batch_shape = batch_shape_; + A_batch_stride = A_bstride_; + B_batch_stride = B_bstride_; + C_batch_stride = C_bstride_; + // Collapse batches into M if needed + if (batch_size_out > 1 && !transpose_a && batch_shape.size() == 1 && + a.strides()[a.ndim() - 2] == K && A_batch_stride.back() == M * K && + C_batch_stride.back() == M * c.strides()[c.ndim() - 2] && + B_batch_stride.back() == 0) { + M *= batch_shape.back(); + batch_size_out = 1; + + A_batch_stride = {0}; + B_batch_stride = {0}; + C_batch_stride = {0}; + batch_shape = {1}; + } + } else { + auto [batch_shape_, A_bstride_, B_bstride_] = collapse_batches(a, b); + + batch_shape = batch_shape_; + A_batch_stride = A_bstride_; + B_batch_stride = B_bstride_; + // Collapse batches into M if needed + if (batch_size_out > 1 && !transpose_a && batch_shape.size() == 1 && + a.strides()[a.ndim() - 2] == K && A_batch_stride.back() == M * K && + B_batch_stride.back() == 0) { + M *= batch_shape.back(); + batch_size_out = 1; + + A_batch_stride = {0}; + B_batch_stride = {0}; + batch_shape = {1}; + } + } + } + + ///////////////////////////////////////////////////////////////////////////// + // Split K specialization + + int _tm = M / 16; + int _tn = N / 16; + int _tk = K / 16; + + if (batch_size_out == 1 && (_tm * _tn) <= 32 && _tk >= 8) { + return steel_gemm_splitk_axpby( + /* const Stream& s = */ s, + /* metal::Device& d = */ d, + /* const array& a = */ a, + /* const array& b = */ b, + /* const array& c = */ c, + /* array& out = */ out, + /* int M = */ M, + /* int N = */ N, + /* int K = */ K, + /* int batch_size_out = */ batch_size_out, + /* int lda = */ lda, + /* int ldb = */ ldb, + /* bool transpose_a = */ transpose_a, + /* bool transpose_b = */ transpose_b, + /* std::vector& copies = */ copies, + /* float alpha = */ alpha, + /* float beta = */ beta); + } + + ///////////////////////////////////////////////////////////////////////////// + // Regular kernel dispatch + auto batch_strides = A_batch_stride; + batch_strides.insert( + batch_strides.end(), B_batch_stride.begin(), B_batch_stride.end()); + if (CHECK_AB && !C_batch_stride.empty()) { + batch_strides.insert( + batch_strides.end(), C_batch_stride.begin(), C_batch_stride.end()); + } + + int64_t A_batch_stride_ = A_batch_stride.empty() ? 0 : A_batch_stride.back(); + int64_t B_batch_stride_ = B_batch_stride.empty() ? 0 : B_batch_stride.back(); + int64_t C_batch_stride_ = C_batch_stride.empty() ? 
0 : C_batch_stride.back(); + + return steel_matmul_regular_axpby( + /* const Stream& s = */ s, + /* metal::Device& d = */ d, + /* const array& a = */ a, + /* const array& b = */ b, + /* const array& c = */ c, + /* array& out = */ out, + /* int M = */ M, + /* int N = */ N, + /* int K = */ K, + /* int batch_size_out = */ batch_size_out, + /* int lda = */ lda, + /* int ldb = */ ldb, + /* int ldd = */ N, + /* bool transpose_a = */ transpose_a, + /* bool transpose_b = */ transpose_b, + /* std::vector& copies = */ copies, + /* Shape batch_shape = */ std::move(batch_shape), + /* Strides batch_strides = */ std::move(batch_strides), + /* int64_t A_batch_stride = */ A_batch_stride_, + /* int64_t B_batch_stride = */ B_batch_stride_, + /* int64_t matrix_stride_out = */ int64_t(M) * N, + /* int64_t C_batch_stride = */ C_batch_stride_, + /* float alpha = */ alpha, + /* float beta = */ beta); +} + +/////////////////////////////////////////////////////////////////////////////// +// GEMV dispatch +/////////////////////////////////////////////////////////////////////////////// + +template +void gemv_axbpy( + const Stream& s, + metal::Device& d, + const array& a, + const array& b, + const array& c, + array& out, + int M, + int N, + int K, + int batch_size_out, + int lda, + int ldb, + bool transpose_a, + bool transpose_b, + std::vector& copies, + Shape batch_shape = {}, + Strides A_batch_stride = {}, + Strides B_batch_stride = {}, + Strides C_batch_stride = {}, + float alpha = 1.0f, + float beta = 0.0f) { + // Collect problem info + bool is_b_matrix = N != 1; + + auto& mat = is_b_matrix ? b : a; + auto& vec = is_b_matrix ? a : b; + bool transpose_mat = is_b_matrix ? !transpose_b : transpose_a; + int in_vector_len = K; + int out_vector_len = is_b_matrix ? N : M; + + int mat_cols = transpose_mat ? out_vector_len : in_vector_len; + int mat_rows = transpose_mat ? in_vector_len : out_vector_len; + int mat_ld = is_b_matrix ? ldb : lda; + + auto batch_strides_mat = is_b_matrix ? B_batch_stride : A_batch_stride; + auto batch_strides_vec = is_b_matrix ? A_batch_stride : B_batch_stride; + + int stride_mat = batch_strides_mat.back(); + int stride_vec = batch_strides_vec.back(); + + // Determine if inputs have simple batching / broadcasting + bool contiguous_kernel = (batch_shape.size() == 1); + + int batch_ndim = batch_shape.size(); + + // Determine dispatch kernel + int tm = 4, tn = 4; + int sm = 1, sn = 32; + int bm = 1, bn = 1; + int n_out_per_tgp; + std::ostringstream kname; + + if (transpose_mat) { + if (in_vector_len >= 8192 && out_vector_len >= 2048) { + sm = 4; + sn = 8; + } else { + sm = 8; + sn = 4; + } + + if (out_vector_len >= 2048) { + bn = 16; + } else if (out_vector_len >= 512) { + bn = 4; + } else { + bn = 2; + } + + // Specialized kernel for very small outputs + tn = out_vector_len < tn ? 1 : tn; + + n_out_per_tgp = bn * sn * tn; + kname << "gemv_t_" << type_to_name(out); + + } else { + bm = out_vector_len >= 4096 ? 8 : 4; + sn = 32; + + // Specialized kernel for very small outputs + tm = out_vector_len < tm ? 
1 : tm; + + n_out_per_tgp = bm * sm * tm; + kname << "gemv_" << type_to_name(out); + } + + const bool do_axpby = CHECK_AB && (alpha != 1.0f || beta != 0.0f); + + // clang-format off + kname << "_bm" << bm << "_bn" << bn + << "_sm" << sm << "_sn" << sn + << "_tm" << tm << "_tn" << tn + << "_nc" << !contiguous_kernel + << "_axpby" << do_axpby; // clang-format on + + // Encode and dispatch kernel + auto& compute_encoder = d.get_command_encoder(s.index); + auto kernel = d.get_kernel(kname.str()); + compute_encoder.set_compute_pipeline_state(kernel); + + int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp; + MTL::Size group_dims = MTL::Size(32, bn, bm); + MTL::Size grid_dims = MTL::Size(n_tgp, 1, batch_size_out); + + compute_encoder.set_input_array(mat, 0); + compute_encoder.set_input_array(vec, 1); + compute_encoder.set_output_array(out, 3); + + compute_encoder.set_bytes(in_vector_len, 4); + compute_encoder.set_bytes(out_vector_len, 5); + compute_encoder.set_bytes(mat_ld, 6); + + compute_encoder.set_bytes(batch_ndim, 9); + compute_encoder.set_vector_bytes(batch_shape, 10); + compute_encoder.set_vector_bytes(batch_strides_vec, 11); + compute_encoder.set_vector_bytes(batch_strides_mat, 12); + + if (do_axpby) { + compute_encoder.set_input_array(c, 2); + + compute_encoder.set_bytes(alpha, 7); + compute_encoder.set_bytes(beta, 8); + + compute_encoder.set_vector_bytes(C_batch_stride, 13); + + int bias_stride = c.strides()[c.ndim() - 1]; + compute_encoder.set_bytes(bias_stride, 14); + } + + compute_encoder.dispatch_threadgroups(grid_dims, group_dims); + + d.add_temporaries(std::move(copies), s.index); +} + +inline void gemv( const Stream& s, metal::Device& d, const array& a, @@ -310,166 +771,34 @@ void steel_matmul( bool transpose_a, bool transpose_b, std::vector& copies, - Shape batch_shape /* = {} */, - Strides A_batch_stride /* = {} */, - Strides B_batch_stride /* = {} */) { - using namespace mlx::steel; - - if (batch_shape.empty()) { - ///////////////////////////////////////////////////////////////////////////// - // Check and collapse batch dimensions - auto [batch_shape_, A_bstride_, B_bstride_] = collapse_batches(a, b); - - batch_shape = batch_shape_; - A_batch_stride = A_bstride_; - B_batch_stride = B_bstride_; - // Collapse batches into M if needed - if (batch_size_out > 1 && !transpose_a && batch_shape.size() == 1 && - a.strides()[a.ndim() - 2] == K && A_batch_stride.back() == M * K && - B_batch_stride.back() == 0) { - M *= batch_shape.back(); - batch_size_out = 1; - - A_batch_stride = {0}; - B_batch_stride = {0}; - batch_shape = {1}; - } - } - - size_t matrix_stride_out = size_t(M) * N; - - ///////////////////////////////////////////////////////////////////////////// - // Split K specialization - - int _tm = M / 16; - int _tn = N / 16; - int _tk = K / 16; - - if (batch_size_out == 1 && (_tm * _tn) <= 32 && _tk >= 8) { - int bm = M < 40 ? 16 : 32; - int bn = N < 40 ? 16 : 32; - int bk = 16; - int wm = 2, wn = 2; - - int split_k_partitions = - _tk < 16 ? 2 : (_tk < 32 ? 4 : (_tk < 64 ? 8 : 16)); - int split_k_partition_stride = M * N; - int gemm_k_iterations = (K / bk) / split_k_partitions; - int split_k_partition_size = gemm_k_iterations * bk; - - array C_split({split_k_partitions, M, N}, float32, nullptr, {}); - C_split.set_data(allocator::malloc(C_split.nbytes())); - copies.push_back(C_split); - - bool mn_aligned = M % bm == 0 && N % bn == 0; - bool k_aligned = K % bk == 0; - std::ostringstream kname; - kname << "steel_gemm_splitk_" << (transpose_a ? 
't' : 'n') - << (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_" - << type_to_name(C_split) << "_bm" << bm << "_bn" << bn << "_bk" << bk - << "_wm" << wm << "_wn" << wn << "_MN_" << (mn_aligned ? "t" : "n") - << "aligned" << "_K_" << (k_aligned ? "t" : "n") << "aligned"; - - // Encode and dispatch gemm kernel - auto& compute_encoder = d.get_command_encoder(s.index); - auto kernel = get_steel_gemm_splitk_kernel( - d, - kname.str(), - a, - C_split, - transpose_a, - transpose_b, - bm, - bn, - bk, - wm, - wn, - mn_aligned, - k_aligned); - compute_encoder.set_compute_pipeline_state(kernel); - - int tn = (N + bn - 1) / bn; - int tm = (M + bm - 1) / bm; - - GEMMSpiltKParams params{ - /* const int M = */ M, - /* const int N = */ N, - /* const int K = */ K, - /* const int lda = */ lda, - /* const int ldb = */ ldb, - /* const int ldc = */ N, - /* const int tiles_n = */ tn, - /* const int tiles_m = */ tm, - /* const int split_k_partitions = */ split_k_partitions, - /* const int split_k_partition_stride = */ split_k_partition_stride, - /* const int split_k_partition_size = */ split_k_partition_size, - /* const int gemm_k_iterations_aligned = */ gemm_k_iterations}; - - MTL::Size group_dims = MTL::Size(32, wn, wm); - MTL::Size grid_dims = MTL::Size(tn, tm, split_k_partitions); - - compute_encoder.set_input_array(a, 0); - compute_encoder.set_input_array(b, 1); - compute_encoder.set_output_array(C_split, 2); - - compute_encoder.set_bytes(params, 3); - compute_encoder.dispatch_threadgroups(grid_dims, group_dims); - - // Do accum kernel - { - auto kernel_name = "steel_gemm_splitk_accum_" + type_to_name(out) + "_" + - type_to_name(C_split); - - auto kernel = get_steel_gemm_splitk_accum_kernel( - d, kernel_name, C_split, out, false); - compute_encoder.set_compute_pipeline_state(kernel); - - // Set the arguments for the kernel - compute_encoder.set_input_array(C_split, 0); - compute_encoder.set_output_array(out, 1); - compute_encoder.set_bytes(split_k_partitions, 2); - compute_encoder.set_bytes(split_k_partition_stride, 3); - compute_encoder.set_bytes(N, 4); - - // Launch enough thread groups for each output - MTL::Size grid_dims = MTL::Size(N, M, 1); - auto group_dims = get_block_dims(N, M, 1); - compute_encoder.dispatch_threads(grid_dims, group_dims); - } - - d.add_temporaries(std::move(copies), s.index); - return; - } - - ///////////////////////////////////////////////////////////////////////////// - // Regular kernel dispatch - auto batch_strides = A_batch_stride; - batch_strides.insert( - batch_strides.end(), B_batch_stride.begin(), B_batch_stride.end()); - - steel_matmul_regular( - s, - d, - a, - b, - out, - M, - N, - K, - batch_size_out, - lda, - ldb, - N, - transpose_a, - transpose_b, - std::move(batch_shape), - std::move(batch_strides), - A_batch_stride.back(), - B_batch_stride.back(), - matrix_stride_out, - copies); + Shape batch_shape = {}, + Strides A_batch_stride = {}, + Strides B_batch_stride = {}) { + return gemv_axbpy( + /* const Stream& s = */ s, + /* metal::Device& d = */ d, + /* const array& a = */ a, + /* const array& b = */ b, + /* const array& c = */ b, + /* array& out = */ out, + /* int M = */ M, + /* int N = */ N, + /* int K = */ K, + /* int batch_size_out = */ batch_size_out, + /* int lda = */ lda, + /* int ldb = */ ldb, + /* bool transpose_a = */ transpose_a, + /* bool transpose_b = */ transpose_b, + /* std::vector& copies = */ copies, + /* Shape batch_shape = */ batch_shape, + /* Strides A_batch_stride = */ A_batch_stride, + /* Strides B_batch_stride = */ B_batch_stride); } 
+/////////////////////////////////////////////////////////////////////////////// +// Matmul implementation +/////////////////////////////////////////////////////////////////////////////// + void Matmul::eval_gpu(const std::vector& inputs, array& out) { assert(inputs.size() == 2); if (!issubdtype(out.dtype(), floating)) { @@ -528,102 +857,26 @@ void Matmul::eval_gpu(const std::vector& inputs, array& out) { // Route to gemv if needed if (std::min(M, N) == 1) { - // Collect problem info - bool is_b_matrix = N != 1; - - auto& mat = is_b_matrix ? b : a; - auto& vec = is_b_matrix ? a : b; - bool transpose_mat = is_b_matrix ? !b_transposed : a_transposed; - int in_vector_len = K; - int out_vector_len = is_b_matrix ? N : M; - - int mat_cols = transpose_mat ? out_vector_len : in_vector_len; - int mat_rows = transpose_mat ? in_vector_len : out_vector_len; - int mat_ld = is_b_matrix ? b_cols : a_cols; - - auto batch_strides_mat = is_b_matrix ? B_batch_stride : A_batch_stride; - auto batch_strides_vec = is_b_matrix ? A_batch_stride : B_batch_stride; - - int stride_mat = batch_strides_mat.back(); - int stride_vec = batch_strides_vec.back(); - - // Determine if inputs have simple batching / broadcasting - bool contiguous_kernel = (batch_shape.size() == 1); - - int batch_ndim = batch_shape.size(); - - // Determine dispatch kernel - int tm = 4, tn = 4; - int sm = 1, sn = 32; - int bm = 1, bn = 1; - int n_out_per_tgp; - std::ostringstream kname; - - if (transpose_mat) { - if (in_vector_len >= 8192 && out_vector_len >= 2048) { - sm = 4; - sn = 8; - } else { - sm = 8; - sn = 4; - } - - if (out_vector_len >= 2048) { - bn = 16; - } else if (out_vector_len >= 512) { - bn = 4; - } else { - bn = 2; - } - - // Specialized kernel for very small outputs - tn = out_vector_len < tn ? 1 : tn; - - n_out_per_tgp = bn * sn * tn; - kname << "gemv_t_" << type_to_name(out); - - } else { - bm = out_vector_len >= 4096 ? 8 : 4; - sn = 32; - - // Specialized kernel for very small outputs - tm = out_vector_len < tm ? 
1 : tm; - - n_out_per_tgp = bm * sm * tm; - kname << "gemv_" << type_to_name(out); - } - - kname << "_bm" << bm << "_bn" << bn << "_sm" << sm << "_sn" << sn << "_tm" - << tm << "_tn" << tn; - kname << "_nc" << !contiguous_kernel << "_axpby0"; - - // Encode and dispatch kernel - auto& compute_encoder = d.get_command_encoder(s.index); - auto kernel = d.get_kernel(kname.str()); - compute_encoder.set_compute_pipeline_state(kernel); - - int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp; - MTL::Size group_dims = MTL::Size(32, bn, bm); - MTL::Size grid_dims = MTL::Size(n_tgp, 1, batch_size_out); - - compute_encoder.set_input_array(mat, 0); - compute_encoder.set_input_array(vec, 1); - compute_encoder.set_output_array(out, 3); - - compute_encoder.set_bytes(in_vector_len, 4); - compute_encoder.set_bytes(out_vector_len, 5); - compute_encoder.set_bytes(mat_ld, 6); - - compute_encoder.set_bytes(batch_ndim, 9); - compute_encoder.set_vector_bytes(batch_shape, 10); - compute_encoder.set_vector_bytes(batch_strides_vec, 11); - compute_encoder.set_vector_bytes(batch_strides_mat, 12); - - compute_encoder.dispatch_threadgroups(grid_dims, group_dims); - - d.add_temporaries(std::move(copies), s.index); - return; + return gemv( + /* const Stream& s = */ s, + /* metal::Device& d = */ d, + /* const array& a = */ a, + /* const array& b = */ b, + /* array& out = */ out, + /* int M = */ M, + /* int N = */ N, + /* int K = */ K, + /* int batch_size_out = */ batch_size_out, + /* int lda = */ a_cols, + /* int ldb = */ b_cols, + /* bool transpose_a = */ a_transposed, + /* bool transpose_b = */ b_transposed, + /* std::vector& copies = */ copies, + /* Shape batch_shape = */ std::move(batch_shape), + /* Strides A_batch_stride = */ std::move(A_batch_stride), + /* Strides B_batch_stride = */ std::move(B_batch_stride)); } + ///////////////////////////////////////////////////////////////////////////// // Gemm specialization @@ -641,12 +894,16 @@ void Matmul::eval_gpu(const std::vector& inputs, array& out) { /* int ldb = */ b_cols, /* bool transpose_a = */ a_transposed, /* bool transpose_b = */ b_transposed, - /* std::vector& = */ copies, - /* Shape batch_shape = */ batch_shape, - /* Strides A_batch_stride = */ A_batch_stride, - /* Strides B_batch_stride = */ B_batch_stride); + /* std::vector& copies = */ copies, + /* Shape batch_shape = */ std::move(batch_shape), + /* Strides A_batch_stride = */ std::move(A_batch_stride), + /* Strides B_batch_stride = */ std::move(B_batch_stride)); } +/////////////////////////////////////////////////////////////////////////////// +// AddMM implementation +/////////////////////////////////////////////////////////////////////////////// + void AddMM::eval_gpu(const std::vector& inputs, array& out) { assert(inputs.size() == 3); if (!issubdtype(out.dtype(), floating)) { @@ -726,346 +983,61 @@ void AddMM::eval_gpu(const std::vector& inputs, array& out) { // Route to gemv if needed if (std::min(M, N) == 1) { - // Collect problem info - bool is_b_matrix = N != 1; - - auto& mat = is_b_matrix ? b : a; - auto& vec = is_b_matrix ? a : b; - bool transpose_mat = is_b_matrix ? !transpose_b : transpose_a; - int in_vector_len = K; - int out_vector_len = is_b_matrix ? N : M; - - int mat_cols = transpose_mat ? out_vector_len : in_vector_len; - int mat_rows = transpose_mat ? in_vector_len : out_vector_len; - int mat_ld = is_b_matrix ? b_cols : a_cols; - - auto batch_strides_mat = is_b_matrix ? B_batch_stride : A_batch_stride; - auto batch_strides_vec = is_b_matrix ? 
A_batch_stride : B_batch_stride; - - int stride_mat = batch_strides_mat.back(); - int stride_vec = batch_strides_vec.back(); - - // Determine if inputs have simple batching / broadcasting - bool contiguous_kernel = (batch_shape.size() == 1); - - int batch_ndim = batch_shape.size(); - - // Determine dispatch kernel - int tm = 4, tn = 4; - int sm = 1, sn = 32; - int bm = 1, bn = 1; - int n_out_per_tgp; - std::ostringstream kname; - - if (transpose_mat) { - if (in_vector_len >= 8192 && out_vector_len >= 2048) { - sm = 4; - sn = 8; - } else { - sm = 8; - sn = 4; - } - - if (out_vector_len >= 2048) { - bn = 16; - } else if (out_vector_len >= 512) { - bn = 4; - } else { - bn = 2; - } - - // Specialized kernel for very small outputs - tn = out_vector_len < tn ? 1 : tn; - - n_out_per_tgp = bn * sn * tn; - kname << "gemv_t_" << type_to_name(out); - - } else { - bm = out_vector_len >= 4096 ? 8 : 4; - sn = 32; - - // Specialized kernel for very small outputs - tm = out_vector_len < tm ? 1 : tm; - - n_out_per_tgp = bm * sm * tm; - kname << "gemv_" << type_to_name(out); - } - - kname << "_bm" << bm << "_bn" << bn << "_sm" << sm << "_sn" << sn << "_tm" - << tm << "_tn" << tn; - kname << "_nc" << !contiguous_kernel << "_axpby1"; - - // Encode and dispatch kernel - auto& compute_encoder = d.get_command_encoder(s.index); - auto kernel = d.get_kernel(kname.str()); - compute_encoder.set_compute_pipeline_state(kernel); - - int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp; - MTL::Size group_dims = MTL::Size(32, bn, bm); - MTL::Size grid_dims = MTL::Size(n_tgp, 1, batch_size_out); - - compute_encoder.set_input_array(mat, 0); - compute_encoder.set_input_array(vec, 1); - compute_encoder.set_input_array(c, 2); - compute_encoder.set_output_array(out, 3); - - compute_encoder.set_bytes(in_vector_len, 4); - compute_encoder.set_bytes(out_vector_len, 5); - compute_encoder.set_bytes(mat_ld, 6); - - compute_encoder.set_bytes(alpha_, 7); - compute_encoder.set_bytes(beta_, 8); - - compute_encoder.set_bytes(batch_ndim, 9); - compute_encoder.set_vector_bytes(batch_shape, 10); - compute_encoder.set_vector_bytes(batch_strides_vec, 11); - compute_encoder.set_vector_bytes(batch_strides_mat, 12); - compute_encoder.set_vector_bytes(C_batch_stride, 13); - - int bias_stride = c.strides()[c.ndim() - 1]; - compute_encoder.set_bytes(bias_stride, 14); - - compute_encoder.dispatch_threadgroups(grid_dims, group_dims); - - d.add_temporaries(std::move(copies), s.index); - return; - } - - using namespace mlx::steel; - - ///////////////////////////////////////////////////////////////////////////// - // Split K specialization - - int _tm = M / 16; - int _tn = N / 16; - int _tk = K / 16; - - if (batch_size_out == 1 && (_tm * _tn) <= 32 && _tk >= 8) { - int bm = M < 40 ? 16 : 32; - int bn = N < 40 ? 16 : 32; - int bk = 16; - int wm = 2, wn = 2; - - int split_k_partitions = - _tk < 16 ? 2 : (_tk < 32 ? 4 : (_tk < 64 ? 8 : 16)); - int split_k_partition_stride = M * N; - int gemm_k_iterations = (K / bk) / split_k_partitions; - int split_k_partition_size = gemm_k_iterations * bk; - - array C_split({split_k_partitions, M, N}, float32, nullptr, {}); - C_split.set_data(allocator::malloc(C_split.nbytes())); - copies.push_back(C_split); - - bool mn_aligned = M % bm == 0 && N % bn == 0; - bool k_aligned = K % bk == 0; - - std::ostringstream kname; - kname << "steel_gemm_splitk_" << (transpose_a ? 't' : 'n') - << (transpose_b ? 
't' : 'n') << "_" << type_to_name(a) << "_" - << type_to_name(C_split) << "_bm" << bm << "_bn" << bn << "_bk" << bk - << "_wm" << wm << "_wn" << wn << "_MN_" << (mn_aligned ? "t" : "n") - << "aligned" << "_K_" << (k_aligned ? "t" : "n") << "aligned"; - - // Encode and dispatch gemm kernel - auto& compute_encoder = d.get_command_encoder(s.index); - auto kernel = get_steel_gemm_splitk_kernel( - d, - kname.str(), - a, - C_split, - transpose_a, - transpose_b, - bm, - bn, - bk, - wm, - wn, - mn_aligned, - k_aligned); - - compute_encoder.set_compute_pipeline_state(kernel); - - int tn = (N + bn - 1) / bn; - int tm = (M + bm - 1) / bm; - - GEMMSpiltKParams params{ - M, - N, - K, - lda, - ldb, - N, - tn, - tm, - split_k_partitions, - split_k_partition_stride, - split_k_partition_size, - gemm_k_iterations}; - - MTL::Size group_dims = MTL::Size(32, wn, wm); - MTL::Size grid_dims = MTL::Size(tn, tm, split_k_partitions); - - compute_encoder.set_input_array(a, 0); - compute_encoder.set_input_array(b, 1); - compute_encoder.set_output_array(C_split, 2); - - compute_encoder.set_bytes(params, 3); - compute_encoder.dispatch_threadgroups(grid_dims, group_dims); - - // Do accum kernel - { - auto kernel_name = "steel_gemm_splitk_accum_" + type_to_name(out) + "_" + - type_to_name(C_split) + "_axbpy"; - auto kernel = get_steel_gemm_splitk_accum_kernel( - d, kernel_name, C_split, out, true); - - compute_encoder.set_compute_pipeline_state(kernel); - - // Set the arguments for the kernel - compute_encoder.set_input_array(C_split, 0); - compute_encoder.set_output_array(out, 1); - compute_encoder.set_bytes(split_k_partitions, 2); - compute_encoder.set_bytes(split_k_partition_stride, 3); - compute_encoder.set_bytes(N, 4); - compute_encoder.set_input_array(c, 5); - compute_encoder.set_bytes(ldc, 6); - compute_encoder.set_bytes(fdc, 7); - compute_encoder.set_bytes(alpha_, 8); - compute_encoder.set_bytes(beta_, 9); - - // Launch enough thread groups for each output - MTL::Size grid_dims = MTL::Size(N, M, 1); - auto group_dims = get_block_dims(N, M, 1); - compute_encoder.dispatch_threads(grid_dims, group_dims); - } - - d.add_temporaries(std::move(copies), s.index); - return; + return gemv_axbpy( + /* const Stream& s = */ s, + /* metal::Device& d = */ d, + /* const array& a = */ a, + /* const array& b = */ b, + /* const array& c = */ c, + /* array& out = */ out, + /* int M = */ M, + /* int N = */ N, + /* int K = */ K, + /* int batch_size_out = */ batch_size_out, + /* int lda = */ lda, + /* int ldb = */ ldb, + /* bool transpose_a = */ transpose_a, + /* bool transpose_b = */ transpose_b, + /* std::vector& copies = */ copies, + /* Shape batch_shape = */ batch_shape, + /* Strides A_batch_stride = */ A_batch_stride, + /* Strides B_batch_stride = */ B_batch_stride, + /* Strides C_batch_stride = */ C_batch_stride, + /* float alpha = */ alpha_, + /* float beta = */ beta_); } ///////////////////////////////////////////////////////////////////////////// // Regular addmm dispatch - // Determine dispatch kernel - int bm = 64, bn = 64, bk = 16; - int wm = 2, wn = 2; - - char devc = d.get_architecture().back(); - GEMM_TPARAM_MACRO(devc) - - // Prepare kernel name - std::ostringstream kname; - kname << "steel_gemm_fused_" << (transpose_a ? 't' : 'n') - << (transpose_b ? 
't' : 'n') << "_" << type_to_name(a) << "_" - << type_to_name(out) << "_bm" << bm << "_bn" << bn << "_bk" << bk - << "_wm" << wm << "_wn" << wn; - - std::string base_name = kname.str(); - - const bool has_batch = (batch_shape.size() > 1); - const bool use_out_source = true; - const bool do_axpby = !(alpha_ == 1. && beta_ == 1.); - const bool align_M = (M % bm) == 0; - const bool align_N = (N % bn) == 0; - const bool align_K = (K % bk) == 0; - - metal::MTLFCList func_consts = { - {&has_batch, MTL::DataType::DataTypeBool, 10}, - {&use_out_source, MTL::DataType::DataTypeBool, 100}, - {&do_axpby, MTL::DataType::DataTypeBool, 110}, - {&align_M, MTL::DataType::DataTypeBool, 200}, - {&align_N, MTL::DataType::DataTypeBool, 201}, - {&align_K, MTL::DataType::DataTypeBool, 202}, - }; - - // clang-format off - kname << "_has_batch_" << (has_batch ? 't' : 'n') - << "_use_out_source_" << (use_out_source ? 't' : 'n') - << "_do_axpby_" << (do_axpby ? 't' : 'n') - << "_align_M_" << (align_M ? 't' : 'n') - << "_align_N_" << (align_N ? 't' : 'n') - << "_align_K_" << (align_K ? 't' : 'n'); // clang-format on - - std::string hash_name = kname.str(); - - // Encode and dispatch kernel - auto& compute_encoder = d.get_command_encoder(s.index); - auto kernel = get_steel_gemm_fused_kernel( - d, - base_name, - hash_name, - func_consts, - out, - transpose_a, - transpose_b, - bm, - bn, - bk, - wm, - wn); - - compute_encoder.set_compute_pipeline_state(kernel); - - int tn = (N + bn - 1) / bn; - int tm = (M + bm - 1) / bm; - - // TODO: Explore device-based tuning for swizzle - int swizzle_log = 0; // tm >= 6 ? 3 : (tm <= 3 ? 0 : 2); - - // Prepare steel matmul params - GEMMParams gemm_params{ - /* const int M = */ M, - /* const int N = */ N, - /* const int K = */ K, - /* const int lda = */ lda, - /* const int ldb = */ ldb, - /* const int ldd = */ N, - /* const int tiles_n = */ tn, - /* const int tiles_m = */ tm, - /* const int64_t batch_stride_a = */ A_batch_stride.back(), - /* const int64_t batch_stride_b = */ B_batch_stride.back(), - /* const int64_t batch_stride_d = */ matrix_stride_out, - /* const int swizzle_log = */ swizzle_log, - /* const int gemm_k_iterations_aligned = */ (K / bk), - /* const int batch_ndim = */ int(batch_shape.size())}; - - GEMMAddMMParams params{ - /* const int ldc = */ ldc, - /* const int fdc = */ fdc, - /* const int64_t batch_stride_c = */ C_batch_stride.back(), - /* const float alpha = */ alpha_, - /* const float beta = */ beta_}; - - int tile = 1 << swizzle_log; - tm = (tm + tile - 1) / tile; - tn = tn * tile; - - MTL::Size group_dims = MTL::Size(32, wn, wm); - MTL::Size grid_dims = MTL::Size(tn, tm, batch_size_out); - - Strides batch_strides = A_batch_stride; - batch_strides.insert( - batch_strides.end(), B_batch_stride.begin(), B_batch_stride.end()); - batch_strides.insert( - batch_strides.end(), C_batch_stride.begin(), C_batch_stride.end()); - - // Launch kernel - compute_encoder.set_input_array(a, 0); - compute_encoder.set_input_array(b, 1); - compute_encoder.set_input_array(c, 2); - compute_encoder.set_output_array(out, 3); - - compute_encoder.set_bytes(gemm_params, 4); - compute_encoder.set_bytes(params, 5); - - compute_encoder.set_vector_bytes(batch_shape, 6); - compute_encoder.set_vector_bytes(batch_strides, 7); - - compute_encoder.dispatch_threadgroups(grid_dims, group_dims); - - d.add_temporaries(std::move(copies), s.index); + return steel_matmul_axpby( + /* const Stream& s = */ s, + /* metal::Device& d = */ d, + /* const array& a = */ a, + /* const array& b = */ b, + /* const 
array& c = */ c, + /* array& out = */ out, + /* int M = */ M, + /* int N = */ N, + /* int K = */ K, + /* int batch_size_out = */ batch_size_out, + /* int lda = */ lda, + /* int ldb = */ ldb, + /* bool transpose_a = */ transpose_a, + /* bool transpose_b = */ transpose_b, + /* std::vector& copies = */ copies, + /* Shape batch_shape = */ batch_shape, + /* Strides A_batch_stride = */ A_batch_stride, + /* Strides B_batch_stride = */ B_batch_stride, + /* Strides B_batch_stride = */ C_batch_stride, + /* float alpha = */ alpha_, + /* float beta = */ beta_); } +/////////////////////////////////////////////////////////////////////////////// +// BlockMaskedMM implementation +/////////////////////////////////////////////////////////////////////////////// + void BlockMaskedMM::eval_gpu(const std::vector& inputs, array& out) { using namespace mlx::steel; // assert(inputs.size() == 2); @@ -1454,6 +1426,10 @@ void BlockMaskedMM::eval_gpu(const std::vector& inputs, array& out) { d.add_temporaries(std::move(copies), s.index); } +/////////////////////////////////////////////////////////////////////////////// +// GatherMM implementation +/////////////////////////////////////////////////////////////////////////////// + void gather_mm_rhs( const array& a_, const array& b_, diff --git a/mlx/backend/metal/matmul.h b/mlx/backend/metal/matmul.h index 09ffe05a8..218664b1f 100644 --- a/mlx/backend/metal/matmul.h +++ b/mlx/backend/metal/matmul.h @@ -6,7 +6,34 @@ namespace mlx::core { -void steel_matmul_regular( +template +void steel_matmul_regular_axpby( + const Stream& s, + metal::Device& d, + const array& a, + const array& b, + const array& c, + array& out, + int M, + int N, + int K, + int batch_size_out, + int lda, + int ldb, + int ldd, + bool transpose_a, + bool transpose_b, + std::vector& copies, + Shape batch_shape, + Strides batch_strides, + int64_t A_batch_stride, + int64_t B_batch_stride, + int64_t matrix_stride_out, + int64_t C_batch_stride = 0, + float alpha = 1.0f, + float beta = 0.0f); + +inline void steel_matmul_regular( const Stream& s, metal::Device& d, const array& a, @@ -21,14 +48,61 @@ void steel_matmul_regular( int ldd, bool transpose_a, bool transpose_b, + std::vector& copies, Shape batch_shape, Strides batch_strides, int64_t A_batch_stride, int64_t B_batch_stride, - int64_t matrix_stride_out, - std::vector& copies); + int64_t matrix_stride_out) { + return steel_matmul_regular_axpby( + /* const Stream& s = */ s, + /* metal::Device& d = */ d, + /* const array& a = */ a, + /* const array& b = */ b, + /* const array& c = */ b, + /* array& out = */ out, + /* int M = */ M, + /* int N = */ N, + /* int K = */ K, + /* int batch_size_out = */ batch_size_out, + /* int lda = */ lda, + /* int ldb = */ ldb, + /* int ldd = */ ldd, + /* bool transpose_a = */ transpose_a, + /* bool transpose_b = */ transpose_b, + /* std::vector& copies = */ copies, + /* Shape batch_shape = */ batch_shape, + /* Strides batch_strides = */ batch_strides, + /* int64_t A_batch_stride = */ A_batch_stride, + /* int64_t B_batch_stride = */ B_batch_stride, + /* int64_t matrix_stride_out = */ matrix_stride_out); +} -void steel_matmul( +template +void steel_matmul_axpby( + const Stream& s, + metal::Device& d, + const array& a, + const array& b, + const array& c, + array& out, + int M, + int N, + int K, + int batch_size_out, + int lda, + int ldb, + bool transpose_a, + bool transpose_b, + std::vector& copies, + Shape batch_shape = {}, + Strides A_batch_stride = {}, + Strides B_batch_stride = {}, + Strides C_batch_stride = {}, + float alpha = 
1.0f, + float beta = 0.0f); + +inline void steel_matmul( const Stream& s, metal::Device& d, const array& a, @@ -45,6 +119,26 @@ void steel_matmul( std::vector& copies, Shape batch_shape = {}, Strides A_batch_stride = {}, - Strides B_batch_stride = {}); + Strides B_batch_stride = {}) { + return steel_matmul_axpby( + /* const Stream& s = */ s, + /* metal::Device& d = */ d, + /* const array& a = */ a, + /* const array& b = */ b, + /* const array& c = */ b, + /* array& out = */ out, + /* int M = */ M, + /* int N = */ N, + /* int K = */ K, + /* int batch_size_out = */ batch_size_out, + /* int lda = */ lda, + /* int ldb = */ ldb, + /* bool transpose_a = */ transpose_a, + /* bool transpose_b = */ transpose_b, + /* std::vector& copies = */ copies, + /* Shape batch_shape = */ batch_shape, + /* Strides A_batch_stride = */ A_batch_stride, + /* Strides B_batch_stride = */ B_batch_stride); +} } // namespace mlx::core diff --git a/mlx/backend/metal/normalization.cpp b/mlx/backend/metal/normalization.cpp index d570bf3c0..8674eff72 100644 --- a/mlx/backend/metal/normalization.cpp +++ b/mlx/backend/metal/normalization.cpp @@ -26,7 +26,7 @@ void RMSNorm::eval_gpu( bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1; if (no_copy && x.ndim() > 1) { auto s = x.strides()[x.ndim() - 2]; - no_copy &= (s == 0 || s == x.shape().back()); + no_copy &= (s == 0 || s == x.shape().back() || x.shape(-2) == 1); } if (no_copy) { if (x.is_donatable()) { @@ -227,7 +227,7 @@ void LayerNorm::eval_gpu( bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1; if (no_copy && x.ndim() > 1) { auto s = x.strides()[x.ndim() - 2]; - no_copy &= (s == 0 || s == x.shape().back()); + no_copy &= (s == 0 || s == x.shape().back() || x.shape(-2) == 1); } if (no_copy) { if (x.is_donatable()) { diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 9602f667a..2b861428f 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -2847,21 +2847,6 @@ array matmul( "[matmul] Got 0 dimension input. 
Inputs must " "have at least one dimension."); } - if (a.ndim() == 1) { - // Insert a singleton dim in the beginning - a = expand_dims(a, 0, s); - } - if (b.ndim() == 1) { - // Insert a singleton dim at the end - b = expand_dims(b, 1, s); - } - if (a.shape(-1) != b.shape(-2)) { - std::ostringstream msg; - msg << "[matmul] Last dimension of first input with shape " << a.shape() - << " must match second to last dimension of" - << " second input with shape " << b.shape() << "."; - throw std::invalid_argument(msg.str()); - } // complex matmul using Karatsuba's Algorithm if (a.dtype() == complex64 || b.dtype() == complex64) { @@ -2883,6 +2868,22 @@ array matmul( c_real, multiply(array(complex64_t{0, 1}, complex64), c_imag, s), s); } + if (a.ndim() == 1) { + // Insert a singleton dim in the beginning + a = expand_dims(a, 0, s); + } + if (b.ndim() == 1) { + // Insert a singleton dim at the end + b = expand_dims(b, 1, s); + } + if (a.shape(-1) != b.shape(-2)) { + std::ostringstream msg; + msg << "[matmul] Last dimension of first input with shape " << a.shape() + << " must match second to last dimension of" + << " second input with shape " << b.shape() << "."; + throw std::invalid_argument(msg.str()); + } + // Type promotion auto out_type = promote_types(a.dtype(), b.dtype()); @@ -4240,6 +4241,16 @@ array addmm( "have at least one dimension."); } + // Type promotion + auto out_type = result_type(a, b, c); + + if (out_type == complex64) { + return add( + multiply(matmul(a, b, s), array(alpha), s), + multiply(array(beta), c, s), + s); + } + if (a.ndim() == 1) { // Insert a singleton dim in the beginning a = expand_dims(a, 0, s); @@ -4257,16 +4268,6 @@ array addmm( throw std::invalid_argument(msg.str()); } - // Type promotion - auto out_type = result_type(a, b, c); - - if (out_type == complex64) { - return add( - multiply(matmul(a, b, s), array(alpha), s), - multiply(array(beta), c, s), - s); - } - if (!issubdtype(out_type, floating)) { std::ostringstream msg; msg << "[addmm] Only real floating point types are supported but " diff --git a/mlx/utils.cpp b/mlx/utils.cpp index 0b2e66352..61b9da3a2 100644 --- a/mlx/utils.cpp +++ b/mlx/utils.cpp @@ -69,7 +69,12 @@ inline void PrintFormatter::print(std::ostream& os, double val) { os << val; } inline void PrintFormatter::print(std::ostream& os, complex64_t val) { - os << val; + os << val.real(); + if (val.imag() >= 0 || std::isnan(val.imag())) { + os << "+" << val.imag() << "j"; + } else { + os << "-" << -val.imag() << "j"; + } } PrintFormatter& get_global_formatter() { diff --git a/python/tests/test_blas.py b/python/tests/test_blas.py index 8c7a97ba8..2762df8f8 100644 --- a/python/tests/test_blas.py +++ b/python/tests/test_blas.py @@ -1195,6 +1195,16 @@ class TestBlas(mlx_tests.MLXTestCase): c_np = np.matmul(np.array(a).T, b) self.assertTrue(np.allclose(c, c_np)) + # Check shapes + a = mx.random.normal((2, 3)).astype(mx.complex64) + b = mx.random.normal((3,)) + self.assertEqual((a @ b).shape, (2,)) + + a = mx.random.normal((2, 3)).astype(mx.complex64) + b = mx.random.normal((3,)) + c = mx.random.normal((2,)) + self.assertEqual(mx.addmm(c, a, b).shape, (2,)) + def test_complex_gemm(self): M = 16 K = 50 diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index f3d48dda3..7c4f3f8e3 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -3078,6 +3078,13 @@ class TestOps(mlx_tests.MLXTestCase): ) self.assertTrue(np.allclose(mx.rsqrt(x), 1.0 / np.sqrt(x))) + def test_complex_power(self): + out = mx.power(mx.array(0j), 2) + 
self.assertEqual(out.item(), 0j)
+
+        out = mx.power(mx.array(0j), float("nan"))
+        self.assertTrue(mx.isnan(out))
+
 
 class TestBroadcast(mlx_tests.MLXTestCase):
     def test_broadcast_shapes(self):
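
Usage sketch of the behavior the new tests exercise, assuming `mlx.core` is available as `mx`; the expected values mirror `test_complex_power` and the new shape checks in `test_blas.py`:

    import mlx.core as mx

    # Complex power with a zero base returns exactly 0j, unless the exponent
    # contains a NaN, in which case the result is nan+nanj.
    print(mx.power(mx.array(0j), 2).item())                       # 0j
    print(mx.isnan(mx.power(mx.array(0j), float("nan"))).item())  # True

    # Complex matmul/addmm now reach the complex (Karatsuba) branch before the
    # 1-D reshaping and shape checks, so matrix-vector shapes work here too.
    a = mx.random.normal((2, 3)).astype(mx.complex64)
    b = mx.random.normal((3,))
    c = mx.random.normal((2,))
    print((a @ b).shape)            # (2,)
    print(mx.addmm(c, a, b).shape)  # (2,)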