diff --git a/benchmarks/python/blas/bench_gemm.py b/benchmarks/python/blas/bench_gemm.py
index 4914c40ba..ee358a95d 100644
--- a/benchmarks/python/blas/bench_gemm.py
+++ b/benchmarks/python/blas/bench_gemm.py
@@ -142,9 +142,7 @@ def bench_shape(B, M, N, K, np_dtype, transpose="nn"):
     t_b = (0, 1, 2) if transpose[1] == "n" else (0, 2, 1)
 
     c_mlx = a_mx.transpose(t_a) @ b_mx.transpose(t_b)
-    c_npy = a_np.transpose(t_a).astype(np.float32) @ b_np.transpose(t_b).astype(
-        np.float32
-    )
+    c_npy = a_np.transpose(t_a).astype(np_dtype) @ b_np.transpose(t_b).astype(np_dtype)
 
     atol = 1e-5 if np_dtype == np.float32 else 1e-4
 
@@ -163,7 +161,7 @@ def get_gflop_count(B, M, N, K):
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Run gemm benchmarks")
 
-    dtypes = ("float32", "float16")
+    dtypes = ("float32", "float16", "complex64")
     transposes = ("nn", "nt", "tn")
     shapes = (
         (16, 234, 768, 3072),
@@ -187,7 +185,7 @@ if __name__ == "__main__":
                 diff = gflops_mx / gflops_pt - 1.0
 
                 print(
                    f"{B:3d}, {M:4d}, {N:4d}, {K:4d}, {dtype}, {transpose}, {gflops_pt:05.3f}, {gflops_mx:05.3f}, {100. * diff:+5.2f}%"
                 )
-                print(
-                    f"{B:3d}, {M:4d}, {N:4d}, {K:4d}, {dtype}, {transpose}, {gflops_pt:05.3f}, {gflops_mx:05.3f}, {100. * diff:+5.2f}%"
-                )
+                print(
+                    f"{B:3d}, {M:4d}, {N:4d}, {K:4d}, {dtype}, {transpose}, {gflops_pt:05.3f}, {gflops_mx:05.3f}, {100.0 * diff:+5.2f}%"
+                )
                 if gflops_pt >= 2.0 * gflops_mx:
                     print("ATTENTION ^^^^^^^")
diff --git a/benchmarks/python/blas/bench_gemv.py b/benchmarks/python/blas/bench_gemv.py
index 2b564a78a..01a4339c7 100644
--- a/benchmarks/python/blas/bench_gemv.py
+++ b/benchmarks/python/blas/bench_gemv.py
@@ -196,7 +196,7 @@ def bench_with_out_len(ax, out_vec_len, in_vector_lens, dtype, transpose):
 
 for transpose in (False, True):
-    for dtype in ("float32", "float16"):
+    for dtype in ("float32", "float16", "complex64"):
         fig, axs = plt.subplots(
             len(in_vec_sizes), 2, figsize=(8.5, 11), layout="constrained"
         )
@@ -215,7 +215,7 @@ for transpose in (False, True):
         fig.suptitle(f"{device_name}: {dtype} {op_name}")
         fig.savefig(
             os.path.join(
-                results_dir, f'{device_name.replace(" ", "_")}_{dtype}_{op_name}.pdf'
+                results_dir, f"{device_name.replace(' ', '_')}_{dtype}_{op_name}.pdf"
             )
         )
         plt.close(fig)
diff --git a/mlx/backend/cpu/gemms/cblas.cpp b/mlx/backend/cpu/gemms/cblas.cpp
index 13eaeca91..765e9f539 100644
--- a/mlx/backend/cpu/gemms/cblas.cpp
+++ b/mlx/backend/cpu/gemms/cblas.cpp
@@ -88,4 +88,47 @@ void matmul(
   }
 }
 
+template <>
+void matmul(
+    const complex64_t* a,
+    const complex64_t* b,
+    complex64_t* out,
+    bool a_transposed,
+    bool b_transposed,
+    size_t lda,
+    size_t ldb,
+    size_t ldc,
+    float alpha,
+    float beta,
+    size_t batch_size,
+    const Shape& a_shape,
+    const Strides& a_strides,
+    const Shape& b_shape,
+    const Strides& b_strides) {
+  auto ndim = a_shape.size();
+  size_t M = a_shape[ndim - 2];
+  size_t N = b_shape[ndim - 1];
+  size_t K = a_shape[ndim - 1];
+  auto calpha = static_cast<complex64_t>(alpha);
+  auto cbeta = static_cast<complex64_t>(beta);
+
+  for (int i = 0; i < batch_size; ++i) {
+    cblas_cgemm(
+        CblasRowMajor,
+        a_transposed ? CblasTrans : CblasNoTrans, // transA
+        b_transposed ? CblasTrans : CblasNoTrans, // transB
+        M,
+        N,
+        K,
+        &calpha,
+        a + elem_to_loc(M * K * i, a_shape, a_strides),
+        lda,
+        b + elem_to_loc(K * N * i, b_shape, b_strides),
+        ldb,
+        &cbeta,
+        out + M * N * i,
+        ldc);
+  }
+}
+
 } // namespace mlx::core
diff --git a/mlx/backend/cpu/matmul.cpp b/mlx/backend/cpu/matmul.cpp
index 7997c75ed..029e94aab 100644
--- a/mlx/backend/cpu/matmul.cpp
+++ b/mlx/backend/cpu/matmul.cpp
@@ -108,6 +108,9 @@ void matmul_general(
   } else if (out.dtype() == float64) {
     matmul_dispatch<double>(
         a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta, stream);
+  } else if (out.dtype() == complex64) {
+    matmul_dispatch<complex64_t>(
+        a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta, stream);
   } else {
     throw std::runtime_error("[Matmul::eval_cpu] Invalid type.");
   }
diff --git a/mlx/backend/cuda/gemms/cublas_gemm.cpp b/mlx/backend/cuda/gemms/cublas_gemm.cpp
index e58653178..fc2380fc3 100644
--- a/mlx/backend/cuda/gemms/cublas_gemm.cpp
+++ b/mlx/backend/cuda/gemms/cublas_gemm.cpp
@@ -50,8 +50,10 @@ cublasComputeType_t dtype_to_compute_type(Dtype dtype) {
       return mlx::core::env::enable_tf32() ? CUBLAS_COMPUTE_32F_FAST_TF32
                                            : CUBLAS_COMPUTE_32F;
     case float64:
-    case complex64:
       return CUBLAS_COMPUTE_64F;
+    case complex64:
+      return mlx::core::env::enable_tf32() ? CUBLAS_COMPUTE_32F_FAST_TF32
+                                           : CUBLAS_COMPUTE_32F;
     default:
       throw std::runtime_error(fmt::format(
           "Unsupported dtype in CublasGemm: {}.", dtype_to_string(dtype)));
@@ -126,12 +128,13 @@ CublasGemm::CublasGemm(
       N_(b_cols) {
   heuristic_.state = CUBLAS_STATUS_NOT_INITIALIZED;
 
-  auto scale_type = dtype_to_cublas_type(dtype);
+  scale_type_ = dtype_to_cublas_type(dtype);
   if (dtype == bfloat16 || dtype == float16) {
-    scale_type = CUDA_R_32F;
+    scale_type_ = CUDA_R_32F;
   }
+
   CHECK_CUBLAS_ERROR(cublasLtMatmulDescCreate(
-      &matmul_desc_, dtype_to_compute_type(dtype), scale_type));
+      &matmul_desc_, dtype_to_compute_type(dtype), scale_type_));
   int32_t pointer_mode = CUBLASLT_POINTER_MODE_HOST;
   CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
       matmul_desc_,
@@ -352,6 +355,16 @@ void CublasGemm::execute(
     }
   }
 
+  const void* alpha_ptr = &alpha;
+  const void* beta_ptr = &beta;
+  complex64_t alpha_c, beta_c;
+  if (scale_type_ == CUDA_C_32F) {
+    alpha_c = complex64_t{alpha, 0.0f};
+    beta_c = complex64_t{beta, 0.0f};
+    alpha_ptr = &alpha_c;
+    beta_ptr = &beta_c;
+  }
+
   void* workspace_ptr = nullptr;
   if (heuristic_.workspaceSize > 0) {
     // Ensure workspace is 256-byte aligned
@@ -368,12 +381,12 @@ void CublasGemm::execute(
   CHECK_CUBLAS_ERROR(cublasLtMatmul(
       handle_,
       matmul_desc_,
-      &alpha,
+      alpha_ptr,
       b, // a and b are swapped
       a_desc_,
       a,
       b_desc_,
-      &beta,
+      beta_ptr,
       c ? c : out,
       c ? c_desc_ : out_desc_,
       out,
diff --git a/mlx/backend/cuda/gemms/cublas_gemm.h b/mlx/backend/cuda/gemms/cublas_gemm.h
index b202818a1..d6d2189b9 100644
--- a/mlx/backend/cuda/gemms/cublas_gemm.h
+++ b/mlx/backend/cuda/gemms/cublas_gemm.h
@@ -115,6 +115,7 @@ class CublasGemm {
   uint64_t M_;
   uint64_t N_;
 
+  cudaDataType_t scale_type_;
   cublasLtMatmulPreference_t pref_{nullptr};
   cublasLtHandle_t handle_{nullptr};
   cublasLtMatmulDesc_t matmul_desc_{nullptr};
diff --git a/mlx/backend/cuda/gemms/gemv.cu b/mlx/backend/cuda/gemms/gemv.cu
index 552ab9cda..e1c755039 100644
--- a/mlx/backend/cuda/gemms/gemv.cu
+++ b/mlx/backend/cuda/gemms/gemv.cu
@@ -13,6 +13,37 @@ namespace cg = cooperative_groups;
 
 static constexpr int rows_per_block = 8;
 
+// Accumulator type selection per input element type T.
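+// half and bfloat16 accumulate in float, double in double, and complex64 in a
+// complex accumulator so the real and imaginary parts are reduced together;
+// any other type falls back to T itself.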
+template <typename T>
+struct GemvAccType {
+  using type = T;
+};
+
+template <>
+struct GemvAccType<__half> {
+  using type = float;
+};
+
+template <>
+struct GemvAccType<__nv_bfloat16> {
+  using type = float;
+};
+
+template <>
+struct GemvAccType<float> {
+  using type = float;
+};
+
+template <>
+struct GemvAccType<double> {
+  using type = double;
+};
+
+template <>
+struct GemvAccType<cu::complex64_t> {
+  using type = cu::complex64_t;
+};
+
 template
 __device__ void
 gemv_impl(const T* mat, const T* vec, T* out, int rows, int cols) {
@@ -24,7 +55,8 @@ gemv_impl(const T* mat, const T* vec, T* out, int rows, int cols) {
   int row = g_idx.x * rows_per_block + t_idx.y;
 
   if (row < rows) {
-    float sum = 0.0f;
+    using Acc = typename GemvAccType<T>::type;
+    Acc sum = Acc(0);
     for (int col = n_per_thread * warp.thread_rank(); col < cols;
          col += (WARP_SIZE * n_per_thread)) {
       auto local_mat =
           unsafe_load_vector<n_per_thread>(mat + row * cols + col, 0);
       auto local_vec = unsafe_load_vector<n_per_thread>(vec + col, 0);
 #pragma unroll
       for (int j = 0; j < n_per_thread; ++j) {
-        sum +=
-            static_cast<float>(local_mat[j]) * static_cast<float>(local_vec[j]);
+        sum += static_cast<Acc>(local_mat[j]) * static_cast<Acc>(local_vec[j]);
       }
     }
 
-    sum = cg::reduce(warp, sum, cg::plus<float>{});
+    sum = cg::reduce(warp, sum, cg::plus<Acc>{});
     if (warp.thread_rank() == 0) {
       out[row] = static_cast<T>(sum);
     }
@@ -107,7 +138,7 @@ void gemv(
   encoder.set_input_array(a);
   encoder.set_input_array(b);
   encoder.set_output_array(out);
-  dispatch_float_types(out.dtype(), "gemv", [&](auto type_tag) {
+  dispatch_inexact_types(out.dtype(), "gemv", [&](auto type_tag) {
     using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
     dim3 block_dims{WARP_SIZE, rows_per_block};
     const DataType* mat;
diff --git a/mlx/backend/cuda/matmul.cpp b/mlx/backend/cuda/matmul.cpp
index 50c8ee629..8a79b172b 100644
--- a/mlx/backend/cuda/matmul.cpp
+++ b/mlx/backend/cuda/matmul.cpp
@@ -93,6 +93,10 @@ void gemm_and_bias(
       a_batch_strides.back(),
       b_batch_strides.back());
   if (bias) {
+    if (a.dtype() == complex64) {
+      throw std::runtime_error(
+          "[gemm_and_bias] complex64 bias epilogue is not supported in cublasLtMatmul.");
+    }
     gemm.set_bias(encoder, *bias);
   }
   gemm.run(
@@ -157,7 +161,8 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   /////////////////////////////////////////////////////////////////////////////
   // Dispatch to GEMM with epilogue or AddMM
 
-  if (beta_ == 1 && c.strides(-1) == 1 && c.data_size() == out.shape(-1)) {
+  if (beta_ == 1 && a.dtype() != complex64 && c.strides(-1) == 1 &&
+      c.data_size() == out.shape(-1)) {
     out.set_data(allocator::malloc(out.nbytes()));
     gemm_and_bias(
         encoder,
diff --git a/mlx/backend/metal/kernels/complex.h b/mlx/backend/metal/kernels/complex.h
index c88002cb3..6e391483d 100644
--- a/mlx/backend/metal/kernels/complex.h
+++ b/mlx/backend/metal/kernels/complex.h
@@ -104,6 +104,27 @@ constexpr bool operator==(complex64_t a, complex64_t b) {
 constexpr complex64_t operator+(complex64_t a, complex64_t b) {
   return {a.real + b.real, a.imag + b.imag};
 }
+
+constexpr thread complex64_t& operator+=(thread complex64_t& a, complex64_t b) {
+  a.real += b.real;
+  a.imag += b.imag;
+  return a;
+}
+
+constexpr threadgroup complex64_t& operator+=(
+    threadgroup complex64_t& a,
+    complex64_t b) {
+  a.real += b.real;
+  a.imag += b.imag;
+  return a;
+}
+
+constexpr device complex64_t& operator+=(device complex64_t& a, complex64_t b) {
+  a.real += b.real;
+  a.imag += b.imag;
+  return a;
+}
+
 constexpr complex64_t operator+(float a, complex64_t b) {
   return {a + b.real, b.imag};
 }
diff --git a/mlx/backend/metal/kernels/gemv.metal b/mlx/backend/metal/kernels/gemv.metal
index b04545152..0914d8024 100644
--- a/mlx/backend/metal/kernels/gemv.metal
+++ b/mlx/backend/metal/kernels/gemv.metal
@@ -15,6 +15,15 @@ using namespace metal;
 
 #define MLX_MTL_CONST static constant constexpr const
 
+template <typename T>
+struct DefaultAccT {
+  using type = float;
+};
+template <>
+struct DefaultAccT<complex64_t> {
+  using type = complex64_t;
+};
+
 template <
     typename T,
     const int BM, /* Threadgroup rows (in simdgroups) */
@@ -24,8 +33,10 @@ template <
     const int TM, /* Thread rows (in elements) */
     const int TN, /* Thread cols (in elements) */
     const bool kDoAxpby, /* Do out = alpha * out + beta * bias */
-    typename AccT = float>
+    typename AccT = typename DefaultAccT<T>::type>
 struct GEMVKernel {
+  using acc_type = AccT;
+
   MLX_MTL_CONST int threadsM = BM * SM;
   MLX_MTL_CONST int threadsN = BN * SN;
 
@@ -246,8 +257,10 @@ template <
     const int TM, /* Thread rows (in elements) */
     const int TN, /* Thread cols (in elements) */
     const bool kDoAxpby, /* Do out = alpha * out + beta * bias */
-    typename AccT = float>
+    typename AccT = typename DefaultAccT<T>::type>
 struct GEMVTKernel {
+  using acc_type = AccT;
+
   MLX_MTL_CONST int threadsM = BM * SM;
   MLX_MTL_CONST int threadsN = BN * SN;
 
@@ -453,7 +466,7 @@ template <
     uint simd_gid [[simdgroup_index_in_threadgroup]],
     uint simd_lid [[thread_index_in_simdgroup]]) {
   using gemv_kernel = GEMVKernel;
-  threadgroup float tgp_memory
+  threadgroup typename gemv_kernel::acc_type tgp_memory
       [gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size];
 
   // Update batch offsets
@@ -530,6 +543,7 @@ template <
 instantiate_gemv_blocks(float32, float);
 instantiate_gemv_blocks(float16, half);
 instantiate_gemv_blocks(bfloat16, bfloat16_t);
+instantiate_gemv_blocks(complex64, complex64_t);
 
 template <
     typename T,
@@ -565,7 +579,7 @@ template <
     uint simd_gid [[simdgroup_index_in_threadgroup]],
     uint simd_lid [[thread_index_in_simdgroup]]) {
   using gemv_kernel = GEMVKernel;
-  threadgroup float tgp_memory
+  threadgroup typename gemv_kernel::acc_type tgp_memory
      [gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size];
 
   uint32_t indx_vec;
@@ -636,6 +650,7 @@ template <
 instantiate_gemv_bs_blocks(float32, float);
 instantiate_gemv_bs_blocks(float16, half);
 instantiate_gemv_bs_blocks(bfloat16, bfloat16_t);
+instantiate_gemv_bs_blocks(complex64, complex64_t);
 
 ///////////////////////////////////////////////////////////////////////////////
 /// Vector matrix multiplication
@@ -672,7 +687,7 @@ template <
     uint simd_gid [[simdgroup_index_in_threadgroup]],
     uint simd_lid [[thread_index_in_simdgroup]]) {
   using gemv_kernel = GEMVTKernel;
-  threadgroup float tgp_memory
+  threadgroup typename gemv_kernel::acc_type tgp_memory
      [gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size];
 
   // Update batch offsets
@@ -738,7 +753,8 @@ template <
 // clang-format off
 instantiate_gemv_t_blocks(float32, float);
 instantiate_gemv_t_blocks(float16, half);
-instantiate_gemv_t_blocks(bfloat16, bfloat16_t); // clang-format on
+instantiate_gemv_t_blocks(bfloat16, bfloat16_t);
+instantiate_gemv_t_blocks(complex64, complex64_t); // clang-format on
 
 template <
     typename T,
@@ -773,8 +789,8 @@ template <
     uint3 lid [[thread_position_in_threadgroup]],
     uint simd_gid [[simdgroup_index_in_threadgroup]],
     uint simd_lid [[thread_index_in_simdgroup]]) {
-  using gemv_kernel = GEMVTKernel;
-  threadgroup float tgp_memory
+  using gemv_kernel = GEMVTKernel;
+  threadgroup typename gemv_kernel::acc_type tgp_memory
      [gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size];
 
   uint32_t indx_vec;
 
@@ -848,4 +864,5 @@ template <
 // clang-format off
 instantiate_gemv_t_bs_blocks(float32, float);
 instantiate_gemv_t_bs_blocks(float16, half);
-instantiate_gemv_t_bs_blocks(bfloat16, bfloat16_t); // clang-format on
+instantiate_gemv_t_bs_blocks(bfloat16, bfloat16_t);
+instantiate_gemv_t_bs_blocks(complex64, complex64_t); // clang-format on
diff --git a/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.metal b/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.metal
index 960321440..7851e5ef1 100644
--- a/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.metal
+++ b/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.metal
@@ -3,9 +3,8 @@
 #include <metal_stdlib>
 
 // clang-format off
-#include "mlx/backend/metal/kernels/steel/gemm/mma.h"
-
 #include "mlx/backend/metal/kernels/utils.h"
+#include "mlx/backend/metal/kernels/steel/gemm/mma.h"
 #include "mlx/backend/metal/kernels/steel/conv/conv.h"
 #include "mlx/backend/metal/kernels/steel/conv/params.h"
 #include "mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h"
diff --git a/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.metal b/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.metal
index 8f8b8e0e0..23e8575ce 100644
--- a/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.metal
+++ b/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.metal
@@ -3,9 +3,8 @@
 #include <metal_stdlib>
 
 // clang-format off
-#include "mlx/backend/metal/kernels/steel/gemm/mma.h"
-
 #include "mlx/backend/metal/kernels/utils.h"
+#include "mlx/backend/metal/kernels/steel/gemm/mma.h"
 #include "mlx/backend/metal/kernels/steel/conv/conv.h"
 #include "mlx/backend/metal/kernels/steel/conv/params.h"
 #include "mlx/backend/metal/kernels/steel/utils.h"
diff --git a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.metal b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.metal
index e03d4beb2..5549580f9 100644
--- a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.metal
+++ b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.metal
@@ -23,10 +23,12 @@
   instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 1, 2) \
   instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \
   instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 1, 2) \
-  instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2)
+  instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) \
+  instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 8, 4, 1)
 
 instantiate_gemm_shapes_helper(float16, half, float16, half);
 instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t);
 instantiate_gemm_shapes_helper(float32, float, float32, float);
+instantiate_gemm_shapes_helper(complex64, complex64_t, complex64, complex64_t);
 
 // clang-format on
diff --git a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.metal b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.metal
index af34a6870..aa1274c94 100644
--- a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.metal
+++ b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.metal
@@ -1,8 +1,8 @@
 // Copyright © 2024 Apple Inc.
 
 // clang-format off
-#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
 #include "mlx/backend/metal/kernels/utils.h"
+#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
 #include "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h"
 
 #define instantiate_gemm( \
diff --git a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.metal b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.metal
index a046515b0..8c60b426c 100644
--- a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.metal
+++ b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.metal
@@ -60,6 +60,7 @@
 instantiate_gemm_shapes_helper(float16, half, float32, float);
 instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, float32, float);
 instantiate_gemm_shapes_helper(float32, float, float32, float);
+instantiate_gemm_shapes_helper(complex64, complex64_t, complex64, complex64_t);
 
 #define instantiate_accum(oname, otype, aname, atype) \
   instantiate_kernel(                                 \
@@ -71,4 +72,5 @@ instantiate_gemm_shapes_helper(float32, float, float32, float);
 
 instantiate_accum(bfloat16, bfloat16_t, float32, float);
 instantiate_accum(float16, half, float32, float);
-instantiate_accum(float32, float, float32, float); // clang-format on
+instantiate_accum(float32, float, float32, float);
+instantiate_accum(complex64, complex64_t, complex64, complex64_t); // clang-format on
diff --git a/mlx/backend/metal/kernels/steel/gemm/mma.h b/mlx/backend/metal/kernels/steel/gemm/mma.h
index 64b87655e..74151a954 100644
--- a/mlx/backend/metal/kernels/steel/gemm/mma.h
+++ b/mlx/backend/metal/kernels/steel/gemm/mma.h
@@ -421,6 +421,16 @@ METAL_FUNC void tile_matmad(
   }
 }
 
+template <>
+struct TransformNone<complex64_t, float> {
+  static METAL_FUNC complex64_t apply(complex64_t x) {
+    return x;
+  }
+  static METAL_FUNC complex64_t apply(complex64_t x, complex64_t) {
+    return x;
+  }
+};
+
 template <
     typename T,
     typename U,
@@ -731,5 +741,406 @@ struct BlockMMA {
   }
 };
 
+template <
+    typename U,
+    int BM,
+    int BN,
+    int BK,
+    int WM,
+    int WN,
+    bool transpose_a,
+    bool transpose_b,
+    short lda_tgp,
+    short ldb_tgp,
+    typename AccumType,
+    typename Epilogue>
+struct BlockMMA<
+    complex64_t,
+    U,
+    BM,
+    BN,
+    BK,
+    WM,
+    WN,
+    transpose_a,
+    transpose_b,
+    lda_tgp,
+    ldb_tgp,
+    AccumType,
+    Epilogue> {
+  static_assert(
+      metal::is_same_v<AccumType, float>,
+      "BlockMMA expects float accumulators");
+  static_assert(
+      metal::is_same_v<U, complex64_t>,
+      "For complex BlockMMA, U must be complex64_t; use a different epilogue for projections");
+  // MMAFrag size
+  STEEL_CONST short kFragSize = 8;
+  using MMAFrag_acc_t = BaseMMAFrag<AccumType, kFragSize, kFragSize>;
+
+  // Warp tile simdgroup matrix strides along M
+  STEEL_CONST short TM_stride = kFragSize * WM;
+  // Warp tile simdgroup matrix strides along N
+  STEEL_CONST short TN_stride = kFragSize * WN;
+
+  // Warp tile size along M
+  STEEL_CONST short TM = BM / (kFragSize * WM);
+  // Warp tile size along N
+  STEEL_CONST short TN = BN / (kFragSize * WN);
+
+  // Threadgroup A strides
+  STEEL_CONST short A_str_m = transpose_a ? 1 : lda_tgp; // M
+  STEEL_CONST short A_str_k = transpose_a ? lda_tgp : 1; // K
+
+  // Threadgroup B strides
+  STEEL_CONST short B_str_k = transpose_b ? 1 : ldb_tgp; // K
+  STEEL_CONST short B_str_n = transpose_b ? ldb_tgp : 1; // N
+
+  // Threadgroup strides along K
+  STEEL_CONST short tile_stride_a = kFragSize * A_str_k;
+  STEEL_CONST short tile_stride_b = kFragSize * B_str_k;
+
+  // When indexing complex as float[2]
+  STEEL_CONST short A_str_m_f = A_str_m * 2;
+  STEEL_CONST short A_str_k_f = A_str_k * 2;
+  STEEL_CONST short B_str_k_f = B_str_k * 2;
+  STEEL_CONST short B_str_n_f = B_str_n * 2;
+  STEEL_CONST short tile_stride_a_f = tile_stride_a * 2;
+  STEEL_CONST short tile_stride_b_f = tile_stride_b * 2;
+
+  // Accumulators (real/imag)
+  MMATile<AccumType, TM, TN, MMAFrag_acc_t> Ctile_r;
+  MMATile<AccumType, TM, TN, MMAFrag_acc_t> Ctile_i;
+
+  // Offsets within threadgroup
+  short sm, sn;
+  short As_offset, Bs_offset;
+
+  /* Constructor */
+  METAL_FUNC BlockMMA(
+      ushort simd_group_id [[simdgroup_index_in_threadgroup]],
+      ushort simd_lane_id [[thread_index_in_simdgroup]]) {
+    // Determine thread position in simdgroup matrix
+    short tm = kFragSize * (simd_group_id / WN);
+    short tn = kFragSize * (simd_group_id % WN);
+
+    short2 simd_coord = MMAFrag_acc_t::get_coord(simd_lane_id);
+    sm = simd_coord.y;
+    sn = simd_coord.x;
+
+    // Determine thread and simdgroup offset
+    As_offset = (tm + sm) * A_str_m + (sn)*A_str_k; // (M,K)
+    Bs_offset = (sm)*B_str_k + (tn + sn) * B_str_n; // (K,N)
+
+    sm += tm;
+    sn += tn;
+  }
+
+  /* Karatsuba MMA: 3 real MMAs per K-chunk */
+  METAL_FUNC void mma(
+      const threadgroup complex64_t* As,
+      const threadgroup complex64_t* Bs) {
+    // Adjust for simdgroup and thread location
+    As += As_offset;
+    Bs += Bs_offset;
+    threadgroup const float* As_f =
+        reinterpret_cast<threadgroup const float*>(As);
+    threadgroup const float* Bs_f =
+        reinterpret_cast<threadgroup const float*>(Bs);
+
+    // Iterate over BK in blocks of kFragSize
+    STEEL_PRAGMA_UNROLL
+    for (short kk = 0; kk < BK; kk += kFragSize) {
+      simdgroup_barrier(mem_flags::mem_none);
+
+      MMATile<AccumType, TM, 1, MMAFrag_acc_t> Ar, Ai;
+      Ar.template load<float, WM, 1, A_str_m_f, A_str_k_f>(As_f + 0);
+      Ai.template load<float, WM, 1, A_str_m_f, A_str_k_f>(As_f + 1);
+
+      simdgroup_barrier(mem_flags::mem_none);
+
+      MMATile<AccumType, 1, TN, MMAFrag_acc_t> Br, Bi;
+      Br.template load<float, 1, WN, B_str_k_f, B_str_n_f>(Bs_f + 0);
+      Bi.template load<float, 1, WN, B_str_k_f, B_str_n_f>(Bs_f + 1);
+
+      simdgroup_barrier(mem_flags::mem_none);
+
+      // P = Ar*Br ; Q = Ai*Bi ; R = (Ar+Ai)*(Br+Bi)
+      MMATile<AccumType, TM, TN, MMAFrag_acc_t> P, Q, R;
+
+      tile_matmad(P, Ar, Br, P);
+      tile_matmad(Q, Ai, Bi, Q);
+
+      STEEL_PRAGMA_UNROLL
+      for (short i = 0; i < decltype(Ar)::kElemsPerTile; ++i)
+        Ar.elems()[i] += Ai.elems()[i];
+      STEEL_PRAGMA_UNROLL
+      for (short i = 0; i < decltype(Br)::kElemsPerTile; ++i)
+        Br.elems()[i] += Bi.elems()[i];
+
+      tile_matmad(R, Ar, Br, R);
+
+      // C_r += P - Q ; C_i += R - P - Q
+      STEEL_PRAGMA_UNROLL
+      for (short i = 0; i < decltype(Ctile_r)::kElemsPerTile; ++i) {
+        const auto p = P.elems()[i];
+        const auto q = Q.elems()[i];
+        const auto r = R.elems()[i];
+        Ctile_r.elems()[i] += (p - q);
+        Ctile_i.elems()[i] += (r - p - q);
+      }
+
+      // Progress to next simdgroup tile
+      As_f += tile_stride_a_f;
+      Bs_f += tile_stride_b_f;
+    }
+  }
+
+  /* Store results from simdgroup_matrix results into device memory */
+  METAL_FUNC void store_result(device U* D, const int ldd) {
+    // Adjust for simdgroup and thread location
+    D += sm * ldd + sn;
+
+    STEEL_PRAGMA_UNROLL
+    for (short i = 0; i < TM; i++) {
+      STEEL_PRAGMA_UNROLL
+      for (short j = 0; j < TN; j++) {
+        thread const auto& r = Ctile_r.frag_at(i, j);
+        thread const auto& im = Ctile_i.frag_at(i, j);
+        int off = (i * TM_stride) * ldd + (j * TN_stride);
+        STEEL_PRAGMA_UNROLL
+        for (short k = 0; k < decltype(Ctile_r)::kElemsPerFrag; k++) {
+          D[off + k] = Epilogue::apply(complex64_t(r[k], im[k]));
+        }
+      }
+    }
+  }
+
+  METAL_FUNC void
+  store_result_slice(device U* D, const int ldd, short2 start, short2 stop) {
+    D += sm * ldd + sn;
+    start -= short2(sn, sm);
+    stop -= short2(sn, sm);
+
+    if (stop.y <= 0 || stop.x <= 0)
+      return;
+
+    STEEL_PRAGMA_UNROLL
+    for (short i = 0; i < TM; ++i) {
+      const int row = i * TM_stride;
+      if (row >= start.y && row < stop.y) {
+        STEEL_PRAGMA_UNROLL
+        for (short j = 0; j < TN; ++j) {
+          const int off = row * ldd + (j * TN_stride);
+          thread const auto& r = Ctile_r.frag_at(i, j);
+          thread const auto& im = Ctile_i.frag_at(i, j);
+
+          STEEL_PRAGMA_UNROLL
+          for (short k = 0; k < decltype(Ctile_r)::kElemsPerFrag; ++k) {
+            const int col = j * TN_stride + k;
+            if (col >= start.x && col < stop.x) {
+              D[off + k] = Epilogue::apply(complex64_t(r[k], im[k]));
+            }
+          }
+        }
+      }
+    }
+  }
+
+  METAL_FUNC void
+  store_result_safe(device U* D, const int ldd, short2 dst_tile_dims) {
+    D += sm * ldd + sn;
+    dst_tile_dims -= short2(sn, sm);
+    if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
+      return;
+    STEEL_PRAGMA_UNROLL
+    for (short i = 0; i < TM; i++) {
+      if (i * TM_stride < dst_tile_dims.y) {
+        STEEL_PRAGMA_UNROLL
+        for (short j = 0; j < TN; j++) {
+          int off = (i * TM_stride) * ldd + (j * TN_stride);
+          thread const auto& r = Ctile_r.frag_at(i, j);
+          thread const auto& im = Ctile_i.frag_at(i, j);
+          STEEL_PRAGMA_UNROLL
+          for (short k = 0; k < decltype(Ctile_r)::kElemsPerFrag; k++) {
+            if ((j * TN_stride + k) < dst_tile_dims.x) {
+              D[off + k] = Epilogue::apply(complex64_t(r[k], im[k]));
+            }
+          }
+        }
+      }
+    }
+  }
+
+  /* Apply epilogue */
+  template <typename UnaryEpilogue>
+  METAL_FUNC void apply_epilogue(thread const UnaryEpilogue& epilogue_op) {
+    STEEL_PRAGMA_UNROLL
+    for (short i = 0; i < decltype(Ctile_r)::kElemsPerTile; i++) {
+      complex64_t out = epilogue_op.apply(
+          complex64_t(Ctile_r.elems()[i], Ctile_i.elems()[i]));
+      Ctile_r.elems()[i] = out.real;
+      Ctile_i.elems()[i] = out.imag;
+    }
+  }
+
+  /* Apply epilogue */
+  template <typename BinaryEpilogue>
+  METAL_FUNC void apply_epilogue(
+      const device U* C,
+      const int ldc,
+      const int fdc,
+      thread const BinaryEpilogue& epilogue_op) {
+    // Adjust for simdgroup and thread location
+    C += (sm)*ldc + (sn)*fdc;
+
+    // Loop over all simdgroup tiles
+    STEEL_PRAGMA_UNROLL
+    for (short i = 0; i < TM; i++) {
+      STEEL_PRAGMA_UNROLL
+      for (short j = 0; j < TN; j++) {
+        // Get accumulated result and associated offset in Cr, Ci
+        thread auto& r = Ctile_r.frag_at(i, j);
+        thread auto& im = Ctile_i.frag_at(i, j);
+        int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
+
+        STEEL_PRAGMA_UNROLL
+        for (short k = 0; k < decltype(Ctile_r)::kElemsPerFrag; k++) {
+          complex64_t out = epilogue_op.apply(
+              complex64_t(r[k], im[k]), C[offset_c + k * fdc]);
+          r[k] = out.real;
+          im[k] = out.imag;
+        }
+      }
+    }
+  }
+
+  /* Apply epilogue */
+  template <typename BinaryEpilogue>
+  METAL_FUNC void apply_epilogue_safe(
+      const device U* C,
+      const int ldc,
+      const int fdc,
+      short2 dst_tile_dims,
+      thread const BinaryEpilogue& epilogue_op) {
+    // Adjust for simdgroup and thread location
+    C += (sm)*ldc + (sn)*fdc;
+    dst_tile_dims -= short2(sn, sm);
+
+    if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
+      return;
+
+    // Loop over all simdgroup tiles
+    STEEL_PRAGMA_UNROLL
+    for (short i = 0; i < TM; i++) {
+      STEEL_PRAGMA_UNROLL
+      for (short j = 0; j < TN; j++) {
+        // Get accumulated result and associated offset in Cr, Ci
+        thread auto& r = Ctile_r.frag_at(i, j);
+        thread auto& im = Ctile_i.frag_at(i, j);
+        int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
+
+        constexpr short kelems = decltype(Ctile_r)::kElemsPerFrag;
+        complex64_t tmp[kelems];
+
+        STEEL_PRAGMA_UNROLL
+        for (short k = 0; k < kelems; k++) {
+          if ((j * TN_stride + k) < dst_tile_dims.x &&
+              (i * TM_stride) < dst_tile_dims.y) {
+            tmp[k] = C[offset_c + k * fdc];
+          } else {
+            tmp[k] = complex64_t(0.0f, 0.0f);
+          }
+        }
+
+        // Apply epilogue
+        STEEL_PRAGMA_UNROLL
+        for (short k = 0; k < kelems; k++) {
+          complex64_t out = epilogue_op.apply(complex64_t(r[k], im[k]), tmp[k]);
+          r[k] = out.real;
+          im[k] = out.imag;
+        }
+      }
+    }
+  }
+
+  /* Store results from simdgroup_matrix results into device memory */
+  METAL_FUNC void store_result(
+      device U* D,
+      const int ldd,
+      const device U* C,
+      const int ldc,
+      const int fdc,
+      thread const Epilogue& epilogue_op) const {
+    // Adjust for simdgroup and thread location
+    C += (sm)*ldc + (sn)*fdc;
+    D += (sm)*ldd + sn;
+
+    constexpr short kelems = decltype(Ctile_r)::kElemsPerFrag;
+
+    // Loop over all simdgroup tiles
+    STEEL_PRAGMA_UNROLL
+    for (short i = 0; i < TM; i++) {
+      STEEL_PRAGMA_UNROLL
+      for (short j = 0; j < TN; j++) {
+        // Get accumulated result and associated offset in Cr, Ci
+        thread const auto& r = Ctile_r.frag_at(i, j);
+        thread const auto& im = Ctile_i.frag_at(i, j);
+        int off_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
+        int off_d = (i * TM_stride) * ldd + (j * TN_stride);
+
+        // Apply epilogue
+        STEEL_PRAGMA_UNROLL
+        for (short k = 0; k < kelems; k++) {
+          D[off_d + k] =
+              epilogue_op.apply(complex64_t(r[k], im[k]), C[off_c + k * fdc]);
+        }
+      }
+    }
+  }
+
+  METAL_FUNC void store_result_safe(
+      device U* D,
+      const int ldd,
+      const device U* C,
+      const int ldc,
+      const int fdc,
+      short2 dst_tile_dims,
+      thread const Epilogue& epilogue_op) const {
+    // Adjust for simdgroup and thread location
+    C += (sm)*ldc + (sn)*fdc;
+    D += (sm)*ldd + sn;
+    dst_tile_dims -= short2(sn, sm);
+
+    if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
+      return;
+
+    constexpr short kelems = decltype(Ctile_r)::kElemsPerFrag;
+
+    STEEL_PRAGMA_UNROLL
+    for (int i = 0; i < TM; i++) {
+      if (i * TM_stride < dst_tile_dims.y) {
+        STEEL_PRAGMA_UNROLL
+        for (int j = 0; j < TN; j++) {
+          // Get accumulated result and associated offset in Cr, Ci
+          thread const auto& r = Ctile_r.frag_at(i, j);
+          thread const auto& im = Ctile_i.frag_at(i, j);
+          int off_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
+          int off_d = (i * TM_stride) * ldd + (j * TN_stride);
+
+          // Apply epilogue
+          STEEL_PRAGMA_UNROLL
+          for (short k = 0; k < kelems; k++) {
+            if ((j * TN_stride + k) < dst_tile_dims.x) {
+              D[off_d + k] = epilogue_op.apply(
+                  complex64_t(r[k], im[k]), C[off_c + k * fdc]);
+            }
+          }
+        }
+      }
+    }
+  }
+};
+
 } // namespace steel
 } // namespace mlx
diff --git a/mlx/backend/metal/kernels/steel/gemm/transforms.h b/mlx/backend/metal/kernels/steel/gemm/transforms.h
index c0624d21b..0282a1223 100644
--- a/mlx/backend/metal/kernels/steel/gemm/transforms.h
+++ b/mlx/backend/metal/kernels/steel/gemm/transforms.h
@@ -48,7 +48,8 @@ struct TransformAxpby {
   }
 
   METAL_FUNC OutT apply(InT x, OutT c) const {
-    return static_cast<OutT>(x * alpha + (beta * c));
+    return static_cast<OutT>(
+        x * static_cast<InT>(alpha) + (static_cast<OutT>(beta) * c));
   }
 };
 
diff --git a/mlx/backend/metal/matmul.cpp b/mlx/backend/metal/matmul.cpp
index 1ceae3f33..d9497a665 100644
--- a/mlx/backend/metal/matmul.cpp
+++ b/mlx/backend/metal/matmul.cpp
@@ -86,7 +86,13 @@ ensure_batch_contiguous(const array& x, metal::Device& d, const Stream& s) {
 
 #define GEMM_TPARAM_MACRO(devc)                         \
   if (devc == 'g' || devc == 'p') { /* Small device */  \
-    if (!transpose_a && transpose_b) { /* nt */         \
+    if (out.dtype() == complex64) {                     \
+      bm = 64;                                          \
+      bn = 32;                                          \
+      bk = 8;                                           \
+      wm = 4;                                           \
+      wn = 1;                                           \
+    } else if (!transpose_a && transpose_b) { /* nt */  \
       bm = 64;                                          \
       bn = 32;                                          \
       bk = 32;                                          \
@@ -362,7 +368,11 @@ void steel_gemm_splitk_axpby(
   int gemm_k_iterations = (K / bk) / split_k_partitions;
   int split_k_partition_size = gemm_k_iterations * bk;
 
-  array C_split({split_k_partitions, M, N}, float32, nullptr, {});
+  array C_split(
+      {split_k_partitions, M, N},
+      issubdtype(out.dtype(), complexfloating) ? complex64 : float32,
+      nullptr,
+      {});
   C_split.set_data(allocator::malloc(C_split.nbytes()));
   copies.push_back(C_split);
 
@@ -807,9 +817,8 @@ inline void gemv(
 void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 2);
-  if (!issubdtype(out.dtype(), floating)) {
-    throw std::runtime_error(
-        "[matmul] Does not yet support non-floating point types.");
+  if (!issubdtype(out.dtype(), inexact)) {
+    throw std::runtime_error("[matmul] dtype must be inexact.");
   }
   auto& s = stream();
   auto& d = metal::device(s.device);
@@ -1338,7 +1347,8 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
         << (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_"
         << type_to_name(out) << "_bm" << bm << "_bn" << bn << "_bk" << bk
         << "_wm" << wm << "_wn" << wn << "_MN_" << (mn_aligned ? "t" : "n")
-        << "aligned" << "_K_" << (k_aligned ? "t" : "n") << "aligned";
+        << "aligned"
+        << "_K_" << (k_aligned ? "t" : "n") << "aligned";
 
   // Encode and dispatch kernel
   auto& compute_encoder = d.get_command_encoder(s.index);
diff --git a/mlx/dtype_utils.h b/mlx/dtype_utils.h
index 27fe432f6..47c6ed661 100644
--- a/mlx/dtype_utils.h
+++ b/mlx/dtype_utils.h
@@ -76,6 +76,19 @@ void dispatch_float_types(Dtype dt, std::string_view tag, F&& f) {
   }
 }
 
+template <typename F>
+void dispatch_inexact_types(Dtype dt, std::string_view tag, F&& f) {
+  switch (dt) {
+    MLX_INTERNAL_DTYPE_SWITCH_FLOATS();
+    MLX_INTERNAL_DTYPE_SWITCH_CASE(complex64, complex64_t);
+    default:
+      std::ostringstream msg;
+      msg << tag << " Only inexact (float/complex) types supported but " << dt
+          << " was provided";
+      throw std::invalid_argument(msg.str());
+  }
+}
+
 template <typename F>
 void dispatch_int_float_types(Dtype dt, std::string_view tag, F&& f) {
   switch (dt) {
diff --git a/mlx/ops.cpp b/mlx/ops.cpp
index 73d3a1f23..ffc7e4bb4 100644
--- a/mlx/ops.cpp
+++ b/mlx/ops.cpp
@@ -1228,14 +1228,14 @@ array pad(
     if (low_pad_size[i] < 0) {
       std::ostringstream msg;
       msg << "Invalid low padding size (" << low_pad_size[i]
-          << ") passed to pad" << " for axis " << i
+          << ") passed to pad for axis " << i
           << ". Padding sizes must be non-negative";
       throw std::invalid_argument(msg.str());
     }
     if (high_pad_size[i] < 0) {
       std::ostringstream msg;
       msg << "Invalid high padding size (" << high_pad_size[i]
-          << ") passed to pad" << " for axis " << i
+          << ") passed to pad for axis " << i
           << ". Padding sizes must be non-negative";
       throw std::invalid_argument(msg.str());
     }
   }
@@ -2859,26 +2859,6 @@ array matmul(
         "have at least one dimension.");
   }
 
-  // complex matmul using Karatsuba's Algorithm
-  if (a.dtype() == complex64 || b.dtype() == complex64) {
-    // Extract real and imaginary parts
-    auto a_real = real(a, s);
-    auto a_imag = imag(a, s);
-    auto b_real = real(b, s);
-    auto b_imag = imag(b, s);
-
-    // Compute real and imaginary components of the result
-    auto m1 = matmul(a_real, b_real, s);
-    auto m2 = matmul(a_imag, b_imag, s);
-    auto m3 = matmul(add(a_real, a_imag, s), add(b_real, b_imag, s), s);
-
-    auto c_real = subtract(m1, m2, s);
-    auto c_imag = subtract(m3, add(m1, m2, s), s);
-
-    return add(
-        c_real, multiply(array(complex64_t{0, 1}, complex64), c_imag, s), s);
-  }
-
   if (a.ndim() == 1) {
     // Insert a singleton dim in the beginning
     a = expand_dims(a, 0, s);
@@ -2898,11 +2878,11 @@ array matmul(
 
   // Type promotion
   auto out_type = promote_types(a.dtype(), b.dtype());
-  if (!issubdtype(out_type, floating)) {
+  if (!issubdtype(out_type, inexact)) {
     std::ostringstream msg;
-    msg << "[matmul] Only real floating point types are supported but "
-        << a.dtype() << " and " << b.dtype() << " were provided which results"
-        << " in " << out_type << ", which is not a real floating point type.";
+    msg << "[matmul] Only inexact types are supported but " << a.dtype()
+        << " and " << b.dtype() << " were provided which results"
+        << " in " << out_type << ", which is not an inexact type.";
     throw std::invalid_argument(msg.str());
   }
   if (a.dtype() != out_type) {
@@ -3146,8 +3126,8 @@ array take_along_axis(
     StreamOrDevice s /* = {} */) {
   if (axis + a.ndim() < 0 || axis >= static_cast<int>(a.ndim())) {
     std::ostringstream msg;
-    msg << "[take_along_axis] Received invalid axis " << " for array with "
-        << a.ndim() << " dimensions.";
+    msg << "[take_along_axis] Received invalid axis for array with " << a.ndim()
+        << " dimensions.";
     throw std::invalid_argument(msg.str());
   }
 
@@ -3184,7 +3164,7 @@ array scatter_axis(
       (mode == ScatterAxis::None) ? "[put_along_axis]" : "[scatter_add_axis]";
   if (axis + a.ndim() < 0 || axis >= static_cast<int>(a.ndim())) {
     std::ostringstream msg;
-    msg << prefix << " Received invalid axis " << " for array with " << a.ndim()
+    msg << prefix << " Received invalid axis for array with " << a.ndim()
         << " dimensions.";
     throw std::invalid_argument(msg.str());
   }
@@ -3592,7 +3572,7 @@
 run_conv_checks(const array& in, const array& wt, int n_dim, int groups) {
   if (in.ndim() != n_dim + 2) {
     std::ostringstream msg;
     msg << "[conv] Invalid input array with " << in.ndim() << " dimensions for "
-        << n_dim << "D convolution." << " Expected an array with " << n_dim + 2
+        << n_dim << "D convolution. Expected an array with " << n_dim + 2
         << " dimensions following the format [N, ..., C_in].";
     throw std::invalid_argument(msg.str());
   }
@@ -4141,7 +4121,8 @@ std::vector<array> quantize(
     std::ostringstream msg;
     msg << "[quantize] The last dimension of the matrix needs to be divisible by "
        << "the quantization group size " << group_size
-        << ". However the provided " << " matrix has shape " << w.shape();
+        << ". However the provided "
+        << " matrix has shape " << w.shape();
     throw std::invalid_argument(msg.str());
   }