Faster complex matmul (#2571)

This commit is contained in:
Daniel Yeh
2025-10-03 08:33:15 +02:00
committed by GitHub
parent 287c63a093
commit 22a5da76c8
20 changed files with 623 additions and 73 deletions

View File

@@ -142,9 +142,7 @@ def bench_shape(B, M, N, K, np_dtype, transpose="nn"):
t_b = (0, 1, 2) if transpose[1] == "n" else (0, 2, 1)
c_mlx = a_mx.transpose(t_a) @ b_mx.transpose(t_b)
c_npy = a_np.transpose(t_a).astype(np.float32) @ b_np.transpose(t_b).astype(
np.float32
)
c_npy = a_np.transpose(t_a).astype(np_dtype) @ b_np.transpose(t_b).astype(np_dtype)
atol = 1e-5 if np_dtype == np.float32 else 1e-4
@@ -163,7 +161,7 @@ def get_gflop_count(B, M, N, K):
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run gemm benchmarks")
dtypes = ("float32", "float16")
dtypes = ("float32", "float16", "complex64")
transposes = ("nn", "nt", "tn")
shapes = (
(16, 234, 768, 3072),
@@ -187,7 +185,7 @@ if __name__ == "__main__":
diff = gflops_mx / gflops_pt - 1.0
print(
f"{B:3d}, {M:4d}, {N:4d}, {K:4d}, {dtype}, {transpose}, {gflops_pt:05.3f}, {gflops_mx:05.3f}, {100. * diff:+5.2f}%"
f"{B:3d}, {M:4d}, {N:4d}, {K:4d}, {dtype}, {transpose}, {gflops_pt:05.3f}, {gflops_mx:05.3f}, {100.0 * diff:+5.2f}%"
)
if gflops_pt >= 2.0 * gflops_mx:
print("ATTENTION ^^^^^^^")

View File

@@ -196,7 +196,7 @@ def bench_with_out_len(ax, out_vec_len, in_vector_lens, dtype, transpose):
for transpose in (False, True):
for dtype in ("float32", "float16"):
for dtype in ("float32", "float16", "complex64"):
fig, axs = plt.subplots(
len(in_vec_sizes), 2, figsize=(8.5, 11), layout="constrained"
)
@@ -215,7 +215,7 @@ for transpose in (False, True):
fig.suptitle(f"{device_name}: {dtype} {op_name}")
fig.savefig(
os.path.join(
results_dir, f'{device_name.replace(" ", "_")}_{dtype}_{op_name}.pdf'
results_dir, f"{device_name.replace(' ', '_')}_{dtype}_{op_name}.pdf"
)
)
plt.close(fig)

View File

@@ -88,4 +88,47 @@ void matmul<double>(
}
}
template <>
void matmul<complex64_t>(
const complex64_t* a,
const complex64_t* b,
complex64_t* out,
bool a_transposed,
bool b_transposed,
size_t lda,
size_t ldb,
size_t ldc,
float alpha,
float beta,
size_t batch_size,
const Shape& a_shape,
const Strides& a_strides,
const Shape& b_shape,
const Strides& b_strides) {
auto ndim = a_shape.size();
size_t M = a_shape[ndim - 2];
size_t N = b_shape[ndim - 1];
size_t K = a_shape[ndim - 1];
auto calpha = static_cast<complex64_t>(alpha);
auto cbeta = static_cast<complex64_t>(beta);
for (int i = 0; i < batch_size; ++i) {
cblas_cgemm(
CblasRowMajor,
a_transposed ? CblasTrans : CblasNoTrans, // transA
b_transposed ? CblasTrans : CblasNoTrans, // transB
M,
N,
K,
&calpha,
a + elem_to_loc(M * K * i, a_shape, a_strides),
lda,
b + elem_to_loc(K * N * i, b_shape, b_strides),
ldb,
&cbeta,
out + M * N * i,
ldc);
}
}
} // namespace mlx::core

View File

@@ -108,6 +108,9 @@ void matmul_general(
} else if (out.dtype() == float64) {
matmul_dispatch<double>(
a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta, stream);
} else if (out.dtype() == complex64) {
matmul_dispatch<complex64_t>(
a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta, stream);
} else {
throw std::runtime_error("[Matmul::eval_cpu] Invalid type.");
}

View File

@@ -50,8 +50,10 @@ cublasComputeType_t dtype_to_compute_type(Dtype dtype) {
return mlx::core::env::enable_tf32() ? CUBLAS_COMPUTE_32F_FAST_TF32
: CUBLAS_COMPUTE_32F;
case float64:
case complex64:
return CUBLAS_COMPUTE_64F;
case complex64:
return mlx::core::env::enable_tf32() ? CUBLAS_COMPUTE_32F_FAST_TF32
: CUBLAS_COMPUTE_32F;
default:
throw std::runtime_error(fmt::format(
"Unsupported dtype in CublasGemm: {}.", dtype_to_string(dtype)));
@@ -126,12 +128,13 @@ CublasGemm::CublasGemm(
N_(b_cols) {
heuristic_.state = CUBLAS_STATUS_NOT_INITIALIZED;
auto scale_type = dtype_to_cublas_type(dtype);
scale_type_ = dtype_to_cublas_type(dtype);
if (dtype == bfloat16 || dtype == float16) {
scale_type = CUDA_R_32F;
scale_type_ = CUDA_R_32F;
}
CHECK_CUBLAS_ERROR(cublasLtMatmulDescCreate(
&matmul_desc_, dtype_to_compute_type(dtype), scale_type));
&matmul_desc_, dtype_to_compute_type(dtype), scale_type_));
int32_t pointer_mode = CUBLASLT_POINTER_MODE_HOST;
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
matmul_desc_,
@@ -352,6 +355,16 @@ void CublasGemm::execute(
}
}
const void* alpha_ptr = &alpha;
const void* beta_ptr = &beta;
complex64_t alpha_c, beta_c;
if (scale_type_ == CUDA_C_32F) {
alpha_c = complex64_t{alpha, 0.0f};
beta_c = complex64_t{beta, 0.0f};
alpha_ptr = &alpha_c;
beta_ptr = &beta_c;
}
void* workspace_ptr = nullptr;
if (heuristic_.workspaceSize > 0) {
// Ensure workspace is 256-byte aligned
@@ -368,12 +381,12 @@ void CublasGemm::execute(
CHECK_CUBLAS_ERROR(cublasLtMatmul(
handle_,
matmul_desc_,
&alpha,
alpha_ptr,
b, // a and b are swapped
a_desc_,
a,
b_desc_,
&beta,
beta_ptr,
c ? c : out,
c ? c_desc_ : out_desc_,
out,

View File

@@ -115,6 +115,7 @@ class CublasGemm {
uint64_t M_;
uint64_t N_;
cudaDataType_t scale_type_;
cublasLtMatmulPreference_t pref_{nullptr};
cublasLtHandle_t handle_{nullptr};
cublasLtMatmulDesc_t matmul_desc_{nullptr};

View File

@@ -13,6 +13,37 @@ namespace cg = cooperative_groups;
static constexpr int rows_per_block = 8;
// Accumulator type selection per input element type T.
template <typename T>
struct GemvAccType {
using type = T;
};
template <>
struct GemvAccType<__half> {
using type = float;
};
template <>
struct GemvAccType<__nv_bfloat16> {
using type = float;
};
template <>
struct GemvAccType<float> {
using type = float;
};
template <>
struct GemvAccType<double> {
using type = double;
};
template <>
struct GemvAccType<cu::complex64_t> {
using type = cu::complex64_t;
};
template <typename T, int rows_per_block, int n_per_thread>
__device__ void
gemv_impl(const T* mat, const T* vec, T* out, int rows, int cols) {
@@ -24,7 +55,8 @@ gemv_impl(const T* mat, const T* vec, T* out, int rows, int cols) {
int row = g_idx.x * rows_per_block + t_idx.y;
if (row < rows) {
float sum = 0.0f;
using Acc = typename GemvAccType<T>::type;
Acc sum = Acc(0);
for (int col = n_per_thread * warp.thread_rank(); col < cols;
col += (WARP_SIZE * n_per_thread)) {
auto local_mat =
@@ -32,12 +64,11 @@ gemv_impl(const T* mat, const T* vec, T* out, int rows, int cols) {
auto local_vec = unsafe_load_vector<n_per_thread>(vec + col, 0);
#pragma unroll
for (int j = 0; j < n_per_thread; ++j) {
sum +=
static_cast<float>(local_mat[j]) * static_cast<float>(local_vec[j]);
sum += static_cast<Acc>(local_mat[j]) * static_cast<Acc>(local_vec[j]);
}
}
sum = cg::reduce(warp, sum, cg::plus<float>{});
sum = cg::reduce(warp, sum, cg::plus<Acc>{});
if (warp.thread_rank() == 0) {
out[row] = static_cast<T>(sum);
}
@@ -107,7 +138,7 @@ void gemv(
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
dispatch_float_types(out.dtype(), "gemv", [&](auto type_tag) {
dispatch_inexact_types(out.dtype(), "gemv", [&](auto type_tag) {
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
dim3 block_dims{WARP_SIZE, rows_per_block};
const DataType* mat;

View File

@@ -93,6 +93,10 @@ void gemm_and_bias(
a_batch_strides.back(),
b_batch_strides.back());
if (bias) {
if (a.dtype() == complex64) {
throw std::runtime_error(
"[gemm_and_bias] complex64 bias epilogue isnt supported in cublasLtMatmul.");
}
gemm.set_bias(encoder, *bias);
}
gemm.run(
@@ -157,7 +161,8 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
/////////////////////////////////////////////////////////////////////////////
// Dispatch to GEMM with epilogue or AddMM
if (beta_ == 1 && c.strides(-1) == 1 && c.data_size() == out.shape(-1)) {
if (beta_ == 1 && a.dtype() != complex64 && c.strides(-1) == 1 &&
c.data_size() == out.shape(-1)) {
out.set_data(allocator::malloc(out.nbytes()));
gemm_and_bias(
encoder,

View File

@@ -104,6 +104,27 @@ constexpr bool operator==(complex64_t a, complex64_t b) {
constexpr complex64_t operator+(complex64_t a, complex64_t b) {
return {a.real + b.real, a.imag + b.imag};
}
constexpr thread complex64_t& operator+=(thread complex64_t& a, complex64_t b) {
a.real += b.real;
a.imag += b.imag;
return a;
}
constexpr threadgroup complex64_t& operator+=(
threadgroup complex64_t& a,
complex64_t b) {
a.real += b.real;
a.imag += b.imag;
return a;
}
constexpr device complex64_t& operator+=(device complex64_t& a, complex64_t b) {
a.real += b.real;
a.imag += b.imag;
return a;
}
constexpr complex64_t operator+(float a, complex64_t b) {
return {a + b.real, b.imag};
}

View File

@@ -15,6 +15,15 @@ using namespace metal;
#define MLX_MTL_CONST static constant constexpr const
template <typename U>
struct DefaultAccT {
using type = float;
};
template <>
struct DefaultAccT<complex64_t> {
using type = complex64_t;
};
template <
typename T,
const int BM, /* Threadgroup rows (in simdgroups) */
@@ -24,8 +33,10 @@ template <
const int TM, /* Thread rows (in elements) */
const int TN, /* Thread cols (in elements) */
const bool kDoAxpby, /* Do out = alpha * out + beta * bias */
typename AccT = float>
typename AccT = typename DefaultAccT<T>::type>
struct GEMVKernel {
using acc_type = AccT;
MLX_MTL_CONST int threadsM = BM * SM;
MLX_MTL_CONST int threadsN = BN * SN;
@@ -246,8 +257,10 @@ template <
const int TM, /* Thread rows (in elements) */
const int TN, /* Thread cols (in elements) */
const bool kDoAxpby, /* Do out = alpha * out + beta * bias */
typename AccT = float>
typename AccT = typename DefaultAccT<T>::type>
struct GEMVTKernel {
using acc_type = AccT;
MLX_MTL_CONST int threadsM = BM * SM;
MLX_MTL_CONST int threadsN = BN * SN;
@@ -453,7 +466,7 @@ template <
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
using gemv_kernel = GEMVKernel<T, BM, BN, SM, SN, TM, TN, kDoAxpby>;
threadgroup float tgp_memory
threadgroup typename gemv_kernel::acc_type tgp_memory
[gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size];
// Update batch offsets
@@ -530,6 +543,7 @@ template <
instantiate_gemv_blocks(float32, float);
instantiate_gemv_blocks(float16, half);
instantiate_gemv_blocks(bfloat16, bfloat16_t);
instantiate_gemv_blocks(complex64, complex64_t);
template <
typename T,
@@ -565,7 +579,7 @@ template <
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
using gemv_kernel = GEMVKernel<T, BM, BN, SM, SN, TM, TN, false>;
threadgroup float tgp_memory
threadgroup typename gemv_kernel::acc_type tgp_memory
[gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size];
uint32_t indx_vec;
@@ -636,6 +650,7 @@ template <
instantiate_gemv_bs_blocks(float32, float);
instantiate_gemv_bs_blocks(float16, half);
instantiate_gemv_bs_blocks(bfloat16, bfloat16_t);
instantiate_gemv_bs_blocks(complex64, complex64_t);
///////////////////////////////////////////////////////////////////////////////
/// Vector matrix multiplication
@@ -672,7 +687,7 @@ template <
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
using gemv_kernel = GEMVTKernel<T, BM, BN, SM, SN, TM, TN, kDoAxpby>;
threadgroup float tgp_memory
threadgroup typename gemv_kernel::acc_type tgp_memory
[gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size];
// Update batch offsets
@@ -738,7 +753,8 @@ template <
// clang-format off
instantiate_gemv_t_blocks(float32, float);
instantiate_gemv_t_blocks(float16, half);
instantiate_gemv_t_blocks(bfloat16, bfloat16_t); // clang-format on
instantiate_gemv_t_blocks(bfloat16, bfloat16_t);
instantiate_gemv_t_blocks(complex64, complex64_t); // clang-format on
template <
typename T,
@@ -773,8 +789,8 @@ template <
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
using gemv_kernel = GEMVTKernel<T, BM, BN, SM, SN, TM, TN, false, float>;
threadgroup float tgp_memory
using gemv_kernel = GEMVTKernel<T, BM, BN, SM, SN, TM, TN, false>;
threadgroup typename gemv_kernel::acc_type tgp_memory
[gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size];
uint32_t indx_vec;
@@ -848,4 +864,5 @@ template <
// clang-format off
instantiate_gemv_t_bs_blocks(float32, float);
instantiate_gemv_t_bs_blocks(float16, half);
instantiate_gemv_t_bs_blocks(bfloat16, bfloat16_t); // clang-format on
instantiate_gemv_t_bs_blocks(bfloat16, bfloat16_t);
instantiate_gemv_t_bs_blocks(complex64, complex64_t); // clang-format on

View File

@@ -3,9 +3,8 @@
#include <metal_stdlib>
// clang-format off
#include "mlx/backend/metal/kernels/steel/gemm/mma.h"
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/steel/gemm/mma.h"
#include "mlx/backend/metal/kernels/steel/conv/conv.h"
#include "mlx/backend/metal/kernels/steel/conv/params.h"
#include "mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h"

View File

@@ -3,9 +3,8 @@
#include <metal_stdlib>
// clang-format off
#include "mlx/backend/metal/kernels/steel/gemm/mma.h"
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/steel/gemm/mma.h"
#include "mlx/backend/metal/kernels/steel/conv/conv.h"
#include "mlx/backend/metal/kernels/steel/conv/params.h"
#include "mlx/backend/metal/kernels/steel/utils.h"

View File

@@ -23,10 +23,12 @@
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 1, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 1, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2)
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 8, 4, 1)
instantiate_gemm_shapes_helper(float16, half, float16, half);
instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t);
instantiate_gemm_shapes_helper(float32, float, float32, float);
instantiate_gemm_shapes_helper(complex64, complex64_t, complex64, complex64_t);
// clang-format on

View File

@@ -1,8 +1,8 @@
// Copyright © 2024 Apple Inc.
// clang-format off
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
#include "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h"
#define instantiate_gemm( \

View File

@@ -60,6 +60,7 @@
instantiate_gemm_shapes_helper(float16, half, float32, float);
instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, float32, float);
instantiate_gemm_shapes_helper(float32, float, float32, float);
instantiate_gemm_shapes_helper(complex64, complex64_t, complex64, complex64_t);
#define instantiate_accum(oname, otype, aname, atype) \
instantiate_kernel( \
@@ -71,4 +72,5 @@ instantiate_gemm_shapes_helper(float32, float, float32, float);
instantiate_accum(bfloat16, bfloat16_t, float32, float);
instantiate_accum(float16, half, float32, float);
instantiate_accum(float32, float, float32, float); // clang-format on
instantiate_accum(float32, float, float32, float);
instantiate_accum(complex64, complex64_t, complex64, complex64_t); // clang-format on

View File

@@ -421,6 +421,16 @@ METAL_FUNC void tile_matmad(
}
}
template <typename InT>
struct TransformNone<complex64_t, InT> {
static METAL_FUNC complex64_t apply(complex64_t x) {
return x;
}
static METAL_FUNC complex64_t apply(complex64_t x, complex64_t) {
return x;
}
};
template <
typename T,
typename U,
@@ -731,5 +741,406 @@ struct BlockMMA {
}
};
template <
typename U,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_a,
bool transpose_b,
short lda_tgp,
short ldb_tgp,
typename AccumType,
typename Epilogue>
struct BlockMMA<
complex64_t,
U,
BM,
BN,
BK,
WM,
WN,
transpose_a,
transpose_b,
lda_tgp,
ldb_tgp,
AccumType,
Epilogue> {
static_assert(
metal::is_same_v<AccumType, float>,
"BlockMMA<complex64_t,...> expects float accumulators");
static_assert(
metal::is_same_v<U, complex64_t>,
"For complex BlockMMA, U must be complex64_t; use a different epilogue for projections");
// MMAFrag size
STEEL_CONST short kFragSize = 8;
using MMAFrag_acc_t = BaseMMAFrag<AccumType, kFragSize, kFragSize>;
// Warp tile simdgroup matrix strides along M
STEEL_CONST short TM_stride = kFragSize * WM;
// Warp tile simdgroup matrix strides along M
STEEL_CONST short TN_stride = kFragSize * WN;
// Warp tile size along M
STEEL_CONST short TM = BM / (kFragSize * WM);
// Warp tile size along N
STEEL_CONST short TN = BN / (kFragSize * WN);
// Threadgroup A strides
STEEL_CONST short A_str_m = transpose_a ? 1 : lda_tgp; // M
STEEL_CONST short A_str_k = transpose_a ? lda_tgp : 1; // K
// Threadgroup B strides
STEEL_CONST short B_str_k = transpose_b ? 1 : ldb_tgp; // K
STEEL_CONST short B_str_n = transpose_b ? ldb_tgp : 1; // N
// Threadgroup strides along K
STEEL_CONST short tile_stride_a = kFragSize * A_str_k;
STEEL_CONST short tile_stride_b = kFragSize * B_str_k;
// When indexing complex as float[2]
STEEL_CONST short A_str_m_f = A_str_m * 2;
STEEL_CONST short A_str_k_f = A_str_k * 2;
STEEL_CONST short B_str_k_f = B_str_k * 2;
STEEL_CONST short B_str_n_f = B_str_n * 2;
STEEL_CONST short tile_stride_a_f = tile_stride_a * 2;
STEEL_CONST short tile_stride_b_f = tile_stride_b * 2;
// Accumulators (real/imag)
MMATile<AccumType, TM, TN, MMAFrag_acc_t> Ctile_r;
MMATile<AccumType, TM, TN, MMAFrag_acc_t> Ctile_i;
// Offsets within threadgroup
short sm, sn;
short As_offset, Bs_offset;
/* Constructor */
METAL_FUNC BlockMMA(
ushort simd_group_id [[simdgroup_index_in_threadgroup]],
ushort simd_lane_id [[thread_index_in_simdgroup]]) {
// Determine thread position in simdgroup matrix
short tm = kFragSize * (simd_group_id / WN);
short tn = kFragSize * (simd_group_id % WN);
short2 simd_coord = MMAFrag_acc_t::get_coord(simd_lane_id);
sm = simd_coord.y;
sn = simd_coord.x;
// Determine thread and simdgroup offset
As_offset = (tm + sm) * A_str_m + (sn)*A_str_k; // (M,K)
Bs_offset = (sm)*B_str_k + (tn + sn) * B_str_n; // (K,N)
sm += tm;
sn += tn;
}
/* Karatsuba MMA: 3 real MMAs per K-chunk */
METAL_FUNC void mma(
const threadgroup complex64_t* As,
const threadgroup complex64_t* Bs) {
// Adjust for simdgroup and thread location
As += As_offset;
Bs += Bs_offset;
threadgroup const float* As_f =
reinterpret_cast<threadgroup const float*>(As);
threadgroup const float* Bs_f =
reinterpret_cast<threadgroup const float*>(Bs);
// Iterate over BK in blocks of kFragSize
STEEL_PRAGMA_UNROLL
for (short kk = 0; kk < BK; kk += kFragSize) {
simdgroup_barrier(mem_flags::mem_none);
MMATile<AccumType, TM, 1, MMAFrag_acc_t> Ar, Ai;
Ar.template load<float, WM, 1, A_str_m_f, A_str_k_f>(As_f + 0);
Ai.template load<float, WM, 1, A_str_m_f, A_str_k_f>(As_f + 1);
simdgroup_barrier(mem_flags::mem_none);
MMATile<AccumType, 1, TN, MMAFrag_acc_t> Br, Bi;
Br.template load<float, 1, WN, B_str_k_f, B_str_n_f>(Bs_f + 0);
Bi.template load<float, 1, WN, B_str_k_f, B_str_n_f>(Bs_f + 1);
simdgroup_barrier(mem_flags::mem_none);
// P = Ar*Br ; Q = Ai*Bi ; R = (Ar+Ai)*(Br+Bi)
MMATile<AccumType, TM, TN, MMAFrag_acc_t> P, Q, R;
tile_matmad(P, Ar, Br, P);
tile_matmad(Q, Ai, Bi, Q);
STEEL_PRAGMA_UNROLL
for (short i = 0; i < decltype(Ar)::kElemsPerTile; ++i)
Ar.elems()[i] += Ai.elems()[i];
STEEL_PRAGMA_UNROLL
for (short i = 0; i < decltype(Br)::kElemsPerTile; ++i)
Br.elems()[i] += Bi.elems()[i];
tile_matmad(R, Ar, Br, R);
// C_r += P - Q ; C_i -= Q
STEEL_PRAGMA_UNROLL
for (short i = 0; i < decltype(Ctile_r)::kElemsPerTile; ++i) {
const auto p = P.elems()[i];
const auto q = Q.elems()[i];
const auto r = R.elems()[i];
Ctile_r.elems()[i] += (p - q);
Ctile_i.elems()[i] += (r - p - q);
}
// Progress to next simdgroup tile
As_f += tile_stride_a_f;
Bs_f += tile_stride_b_f;
}
}
/* Store results from simdgroup_matrix results into device memory */
METAL_FUNC void store_result(device U* D, const int ldd) {
// Adjust for simdgroup and thread location
D += sm * ldd + sn;
STEEL_PRAGMA_UNROLL
for (short i = 0; i < TM; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < TN; j++) {
thread const auto& r = Ctile_r.frag_at(i, j);
thread const auto& im = Ctile_i.frag_at(i, j);
int off = (i * TM_stride) * ldd + (j * TN_stride);
STEEL_PRAGMA_UNROLL
for (short k = 0; k < decltype(Ctile_r)::kElemsPerFrag; k++) {
D[off + k] = Epilogue::apply(complex64_t(r[k], im[k]));
}
}
}
}
METAL_FUNC void
store_result_slice(device U* D, const int ldd, short2 start, short2 stop) {
D += sm * ldd + sn;
start -= short2(sn, sm);
stop -= short2(sn, sm);
if (stop.y <= 0 || stop.x <= 0)
return;
STEEL_PRAGMA_UNROLL
for (short i = 0; i < TM; ++i) {
const int row = i * TM_stride;
if (row >= start.y && row < stop.y) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < TN; ++j) {
const int off = row * ldd + (j * TN_stride);
thread const auto& r = Ctile_r.frag_at(i, j);
thread const auto& im = Ctile_i.frag_at(i, j);
STEEL_PRAGMA_UNROLL
for (short k = 0; k < decltype(Ctile_r)::kElemsPerFrag; ++k) {
const int col = j * TN_stride + k;
if (col >= start.x && col < stop.x) {
D[off + k] = Epilogue::apply(complex64_t(r[k], im[k]));
}
}
}
}
}
}
METAL_FUNC void
store_result_safe(device U* D, const int ldd, short2 dst_tile_dims) {
D += sm * ldd + sn;
dst_tile_dims -= short2(sn, sm);
if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
return;
STEEL_PRAGMA_UNROLL
for (short i = 0; i < TM; i++) {
if (i * TM_stride < dst_tile_dims.y) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < TN; j++) {
int off = (i * TM_stride) * ldd + (j * TN_stride);
thread const auto& r = Ctile_r.frag_at(i, j);
thread const auto& im = Ctile_i.frag_at(i, j);
STEEL_PRAGMA_UNROLL
for (short k = 0; k < decltype(Ctile_r)::kElemsPerFrag; k++) {
if ((j * TN_stride + k) < dst_tile_dims.x) {
D[off + k] = Epilogue::apply(complex64_t(r[k], im[k]));
}
}
}
}
}
}
/* Apply epilogue */
template <typename UnaryEpilogue>
METAL_FUNC void apply_epilogue(thread const UnaryEpilogue& epilogue_op) {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < decltype(Ctile_r)::kElemsPerTile; i++) {
complex64_t out = epilogue_op.apply(
complex64_t(Ctile_r.elems()[i], Ctile_i.elems()[i]));
Ctile_r.elems()[i] = out.real;
Ctile_i.elems()[i] = out.imag;
}
}
/* Apply epilogue */
template <typename BinaryEpilogue>
METAL_FUNC void apply_epilogue(
const device U* C,
const int ldc,
const int fdc,
thread const BinaryEpilogue& epilogue_op) {
// Adjust for simdgroup and thread location
C += (sm)*ldc + (sn)*fdc;
// Loop over all simdgroup tiles
STEEL_PRAGMA_UNROLL
for (short i = 0; i < TM; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < TN; j++) {
// Get accumulated result and associated offset in Cr, Ci
thread auto& r = Ctile_r.frag_at(i, j);
thread auto& im = Ctile_i.frag_at(i, j);
int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
STEEL_PRAGMA_UNROLL
for (short k = 0; k < decltype(Ctile_r)::kElemsPerFrag; k++) {
complex64_t out = epilogue_op.apply(
complex64_t(r[k], im[k]), C[offset_c + k * fdc]);
r[k] = out.real;
im[k] = out.imag;
}
}
}
}
/* Apply epilogue */
template <typename BinaryEpilogue>
METAL_FUNC void apply_epilogue_safe(
const device U* C,
const int ldc,
const int fdc,
short2 dst_tile_dims,
thread const BinaryEpilogue& epilogue_op) {
// Adjust for simdgroup and thread location
C += (sm)*ldc + (sn)*fdc;
dst_tile_dims -= short2(sn, sm);
if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
return;
// Loop over all simdgroup tiles
STEEL_PRAGMA_UNROLL
for (short i = 0; i < TM; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < TN; j++) {
// Get accumulated result and associated offset in Cr, Ci
thread auto& r = Ctile_r.frag_at(i, j);
thread auto& im = Ctile_i.frag_at(i, j);
int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
constexpr short kelems = decltype(Ctile_r)::kElemsPerFrag;
complex64_t tmp[kelems];
STEEL_PRAGMA_UNROLL
for (short k = 0; k < kelems; k++) {
if ((j * TN_stride + k) < dst_tile_dims.x &&
(i * TM_stride) < dst_tile_dims.y) {
tmp[k] = C[offset_c + k * fdc];
} else {
tmp[k] = complex64_t(0.0f, 0.0f);
}
}
// Apply epilogue
STEEL_PRAGMA_UNROLL
for (short k = 0; k < kelems; k++) {
complex64_t out = epilogue_op.apply(complex64_t(r[k], im[k]), tmp[k]);
r[k] = out.real;
im[k] = out.imag;
}
}
}
}
/* Store results from simdgroup_matrix results into device memory */
METAL_FUNC void store_result(
device U* D,
const int ldd,
const device U* C,
const int ldc,
const int fdc,
thread const Epilogue& epilogue_op) const {
// Adjust for simdgroup and thread location
C += (sm)*ldc + (sn)*fdc;
D += (sm)*ldd + sn;
constexpr short kelems = decltype(Ctile_r)::kElemsPerFrag;
// Loop over all simdgroup tiles
STEEL_PRAGMA_UNROLL
for (short i = 0; i < TM; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < TN; j++) {
// Get accumulated result and associated offset in Cr, Ci
thread const auto& r = Ctile_r.frag_at(i, j);
thread const auto& im = Ctile_i.frag_at(i, j);
int off_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
int off_d = (i * TM_stride) * ldd + (j * TN_stride);
// Apply epilogue
STEEL_PRAGMA_UNROLL
for (short k = 0; k < kelems; k++) {
D[off_d + k] =
epilogue_op.apply(complex64_t(r[k], im[k]), C[off_c + k * fdc]);
}
}
}
}
METAL_FUNC void store_result_safe(
device U* D,
const int ldd,
const device U* C,
const int ldc,
const int fdc,
short2 dst_tile_dims,
thread const Epilogue& epilogue_op) const {
// Adjust for simdgroup and thread location
C += (sm)*ldc + (sn)*fdc;
D += (sm)*ldd + sn;
dst_tile_dims -= short2(sn, sm);
if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
return;
constexpr short kelems = decltype(Ctile_r)::kElemsPerFrag;
STEEL_PRAGMA_UNROLL
for (int i = 0; i < TM; i++) {
if (i * TM_stride < dst_tile_dims.y) {
STEEL_PRAGMA_UNROLL
for (int j = 0; j < TN; j++) {
// Get accumulated result and associated offset in Cr, Ci
thread const auto& r = Ctile_r.frag_at(i, j);
thread const auto& im = Ctile_i.frag_at(i, j);
int off_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
int off_d = (i * TM_stride) * ldd + (j * TN_stride);
// Apply epilogue
STEEL_PRAGMA_UNROLL
for (short k = 0; k < kelems; k++) {
if ((j * TN_stride + k) < dst_tile_dims.x) {
D[off_d + k] = epilogue_op.apply(
complex64_t(r[k], im[k]), C[off_c + k * fdc]);
}
}
}
}
}
}
};
} // namespace steel
} // namespace mlx

View File

@@ -48,7 +48,8 @@ struct TransformAxpby {
}
METAL_FUNC OutT apply(InT x, OutT c) const {
return static_cast<OutT>(x * alpha + (beta * c));
return static_cast<OutT>(
x * static_cast<InT>(alpha) + (static_cast<OutT>(beta) * c));
}
};

View File

@@ -86,7 +86,13 @@ ensure_batch_contiguous(const array& x, metal::Device& d, const Stream& s) {
#define GEMM_TPARAM_MACRO(devc) \
if (devc == 'g' || devc == 'p') { /* Small device */ \
if (!transpose_a && transpose_b) { /* nt */ \
if (out.dtype() == complex64) { \
bm = 64; \
bn = 32; \
bk = 8; \
wm = 4; \
wn = 1; \
} else if (!transpose_a && transpose_b) { /* nt */ \
bm = 64; \
bn = 32; \
bk = 32; \
@@ -362,7 +368,11 @@ void steel_gemm_splitk_axpby(
int gemm_k_iterations = (K / bk) / split_k_partitions;
int split_k_partition_size = gemm_k_iterations * bk;
array C_split({split_k_partitions, M, N}, float32, nullptr, {});
array C_split(
{split_k_partitions, M, N},
issubdtype(out.dtype(), complexfloating) ? complex64 : float32,
nullptr,
{});
C_split.set_data(allocator::malloc(C_split.nbytes()));
copies.push_back(C_split);
@@ -807,9 +817,8 @@ inline void gemv(
void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
if (!issubdtype(out.dtype(), floating)) {
throw std::runtime_error(
"[matmul] Does not yet support non-floating point types.");
if (!issubdtype(out.dtype(), inexact)) {
throw std::runtime_error("[matmul] dtype must be inexact.");
}
auto& s = stream();
auto& d = metal::device(s.device);
@@ -1338,7 +1347,8 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
<< (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_"
<< type_to_name(out) << "_bm" << bm << "_bn" << bn << "_bk" << bk
<< "_wm" << wm << "_wn" << wn << "_MN_" << (mn_aligned ? "t" : "n")
<< "aligned" << "_K_" << (k_aligned ? "t" : "n") << "aligned";
<< "aligned"
<< "_K_" << (k_aligned ? "t" : "n") << "aligned";
// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);

View File

@@ -76,6 +76,19 @@ void dispatch_float_types(Dtype dt, std::string_view tag, F&& f) {
}
}
template <typename F>
void dispatch_inexact_types(Dtype dt, std::string_view tag, F&& f) {
switch (dt) {
MLX_INTERNAL_DTYPE_SWITCH_FLOATS();
MLX_INTERNAL_DTYPE_SWITCH_CASE(complex64, complex64_t);
default:
std::ostringstream msg;
msg << tag << " Only inexact (float/complex) types supported but " << dt
<< " was provided";
throw std::invalid_argument(msg.str());
}
}
template <typename F>
void dispatch_int_float_types(Dtype dt, std::string_view tag, F&& f) {
switch (dt) {

View File

@@ -1228,14 +1228,14 @@ array pad(
if (low_pad_size[i] < 0) {
std::ostringstream msg;
msg << "Invalid low padding size (" << low_pad_size[i]
<< ") passed to pad" << " for axis " << i
<< ") passed to pad for axis " << i
<< ". Padding sizes must be non-negative";
throw std::invalid_argument(msg.str());
}
if (high_pad_size[i] < 0) {
std::ostringstream msg;
msg << "Invalid high padding size (" << high_pad_size[i]
<< ") passed to pad" << " for axis " << i
<< ") passed to pad for axis " << i
<< ". Padding sizes must be non-negative";
throw std::invalid_argument(msg.str());
}
@@ -2859,26 +2859,6 @@ array matmul(
"have at least one dimension.");
}
// complex matmul using Karatsuba's Algorithm
if (a.dtype() == complex64 || b.dtype() == complex64) {
// Extract real and imaginary parts
auto a_real = real(a, s);
auto a_imag = imag(a, s);
auto b_real = real(b, s);
auto b_imag = imag(b, s);
// Compute real and imaginary components of the result
auto m1 = matmul(a_real, b_real, s);
auto m2 = matmul(a_imag, b_imag, s);
auto m3 = matmul(add(a_real, a_imag, s), add(b_real, b_imag, s), s);
auto c_real = subtract(m1, m2, s);
auto c_imag = subtract(m3, add(m1, m2, s), s);
return add(
c_real, multiply(array(complex64_t{0, 1}, complex64), c_imag, s), s);
}
if (a.ndim() == 1) {
// Insert a singleton dim in the beginning
a = expand_dims(a, 0, s);
@@ -2898,11 +2878,11 @@ array matmul(
// Type promotion
auto out_type = promote_types(a.dtype(), b.dtype());
if (!issubdtype(out_type, floating)) {
if (!issubdtype(out_type, inexact)) {
std::ostringstream msg;
msg << "[matmul] Only real floating point types are supported but "
<< a.dtype() << " and " << b.dtype() << " were provided which results"
<< " in " << out_type << ", which is not a real floating point type.";
msg << "[matmul] Only inexact types are supported but " << a.dtype()
<< " and " << b.dtype() << " were provided which results"
<< " in " << out_type << ", which is not a floating point type.";
throw std::invalid_argument(msg.str());
}
if (a.dtype() != out_type) {
@@ -3146,8 +3126,8 @@ array take_along_axis(
StreamOrDevice s /* = {} */) {
if (axis + a.ndim() < 0 || axis >= static_cast<int>(a.ndim())) {
std::ostringstream msg;
msg << "[take_along_axis] Received invalid axis " << " for array with "
<< a.ndim() << " dimensions.";
msg << "[take_along_axis] Received invalid axis for array with " << a.ndim()
<< " dimensions.";
throw std::invalid_argument(msg.str());
}
@@ -3184,7 +3164,7 @@ array scatter_axis(
(mode == ScatterAxis::None) ? "[put_along_axis]" : "[scatter_add_axis]";
if (axis + a.ndim() < 0 || axis >= static_cast<int>(a.ndim())) {
std::ostringstream msg;
msg << prefix << " Received invalid axis " << " for array with " << a.ndim()
msg << prefix << " Received invalid axis for array with " << a.ndim()
<< " dimensions.";
throw std::invalid_argument(msg.str());
}
@@ -3592,7 +3572,7 @@ run_conv_checks(const array& in, const array& wt, int n_dim, int groups) {
if (in.ndim() != n_dim + 2) {
std::ostringstream msg;
msg << "[conv] Invalid input array with " << in.ndim() << " dimensions for "
<< n_dim << "D convolution." << " Expected an array with " << n_dim + 2
<< n_dim << "D convolution. Expected an array with " << n_dim + 2
<< " dimensions following the format [N, ..., C_in].";
throw std::invalid_argument(msg.str());
}
@@ -4141,7 +4121,8 @@ std::vector<array> quantize(
std::ostringstream msg;
msg << "[quantize] The last dimension of the matrix needs to be divisible by "
<< "the quantization group size " << group_size
<< ". However the provided " << " matrix has shape " << w.shape();
<< ". However the provided "
<< " matrix has shape " << w.shape();
throw std::invalid_argument(msg.str());
}