Faster complex matmul (#2571)

This commit is contained in:
Daniel Yeh
2025-10-03 08:33:15 +02:00
committed by GitHub
parent 287c63a093
commit 22a5da76c8
20 changed files with 623 additions and 73 deletions

View File

@@ -104,6 +104,27 @@ constexpr bool operator==(complex64_t a, complex64_t b) {
constexpr complex64_t operator+(complex64_t a, complex64_t b) {
return {a.real + b.real, a.imag + b.imag};
}
constexpr thread complex64_t& operator+=(thread complex64_t& a, complex64_t b) {
a.real += b.real;
a.imag += b.imag;
return a;
}
constexpr threadgroup complex64_t& operator+=(
threadgroup complex64_t& a,
complex64_t b) {
a.real += b.real;
a.imag += b.imag;
return a;
}
constexpr device complex64_t& operator+=(device complex64_t& a, complex64_t b) {
a.real += b.real;
a.imag += b.imag;
return a;
}
constexpr complex64_t operator+(float a, complex64_t b) {
return {a + b.real, b.imag};
}

View File

@@ -15,6 +15,15 @@ using namespace metal;
#define MLX_MTL_CONST static constant constexpr const
template <typename U>
struct DefaultAccT {
using type = float;
};
template <>
struct DefaultAccT<complex64_t> {
using type = complex64_t;
};
template <
typename T,
const int BM, /* Threadgroup rows (in simdgroups) */
@@ -24,8 +33,10 @@ template <
const int TM, /* Thread rows (in elements) */
const int TN, /* Thread cols (in elements) */
const bool kDoAxpby, /* Do out = alpha * out + beta * bias */
typename AccT = float>
typename AccT = typename DefaultAccT<T>::type>
struct GEMVKernel {
using acc_type = AccT;
MLX_MTL_CONST int threadsM = BM * SM;
MLX_MTL_CONST int threadsN = BN * SN;
@@ -246,8 +257,10 @@ template <
const int TM, /* Thread rows (in elements) */
const int TN, /* Thread cols (in elements) */
const bool kDoAxpby, /* Do out = alpha * out + beta * bias */
typename AccT = float>
typename AccT = typename DefaultAccT<T>::type>
struct GEMVTKernel {
using acc_type = AccT;
MLX_MTL_CONST int threadsM = BM * SM;
MLX_MTL_CONST int threadsN = BN * SN;
@@ -453,7 +466,7 @@ template <
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
using gemv_kernel = GEMVKernel<T, BM, BN, SM, SN, TM, TN, kDoAxpby>;
threadgroup float tgp_memory
threadgroup typename gemv_kernel::acc_type tgp_memory
[gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size];
// Update batch offsets
@@ -530,6 +543,7 @@ template <
instantiate_gemv_blocks(float32, float);
instantiate_gemv_blocks(float16, half);
instantiate_gemv_blocks(bfloat16, bfloat16_t);
instantiate_gemv_blocks(complex64, complex64_t);
template <
typename T,
@@ -565,7 +579,7 @@ template <
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
using gemv_kernel = GEMVKernel<T, BM, BN, SM, SN, TM, TN, false>;
threadgroup float tgp_memory
threadgroup typename gemv_kernel::acc_type tgp_memory
[gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size];
uint32_t indx_vec;
@@ -636,6 +650,7 @@ template <
instantiate_gemv_bs_blocks(float32, float);
instantiate_gemv_bs_blocks(float16, half);
instantiate_gemv_bs_blocks(bfloat16, bfloat16_t);
instantiate_gemv_bs_blocks(complex64, complex64_t);
///////////////////////////////////////////////////////////////////////////////
/// Vector matrix multiplication
@@ -672,7 +687,7 @@ template <
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
using gemv_kernel = GEMVTKernel<T, BM, BN, SM, SN, TM, TN, kDoAxpby>;
threadgroup float tgp_memory
threadgroup typename gemv_kernel::acc_type tgp_memory
[gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size];
// Update batch offsets
@@ -738,7 +753,8 @@ template <
// clang-format off
instantiate_gemv_t_blocks(float32, float);
instantiate_gemv_t_blocks(float16, half);
instantiate_gemv_t_blocks(bfloat16, bfloat16_t); // clang-format on
instantiate_gemv_t_blocks(bfloat16, bfloat16_t);
instantiate_gemv_t_blocks(complex64, complex64_t); // clang-format on
template <
typename T,
@@ -773,8 +789,8 @@ template <
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
using gemv_kernel = GEMVTKernel<T, BM, BN, SM, SN, TM, TN, false, float>;
threadgroup float tgp_memory
using gemv_kernel = GEMVTKernel<T, BM, BN, SM, SN, TM, TN, false>;
threadgroup typename gemv_kernel::acc_type tgp_memory
[gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size];
uint32_t indx_vec;
@@ -848,4 +864,5 @@ template <
// clang-format off
instantiate_gemv_t_bs_blocks(float32, float);
instantiate_gemv_t_bs_blocks(float16, half);
instantiate_gemv_t_bs_blocks(bfloat16, bfloat16_t); // clang-format on
instantiate_gemv_t_bs_blocks(bfloat16, bfloat16_t);
instantiate_gemv_t_bs_blocks(complex64, complex64_t); // clang-format on

View File

@@ -3,9 +3,8 @@
#include <metal_stdlib>
// clang-format off
#include "mlx/backend/metal/kernels/steel/gemm/mma.h"
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/steel/gemm/mma.h"
#include "mlx/backend/metal/kernels/steel/conv/conv.h"
#include "mlx/backend/metal/kernels/steel/conv/params.h"
#include "mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h"

View File

@@ -3,9 +3,8 @@
#include <metal_stdlib>
// clang-format off
#include "mlx/backend/metal/kernels/steel/gemm/mma.h"
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/steel/gemm/mma.h"
#include "mlx/backend/metal/kernels/steel/conv/conv.h"
#include "mlx/backend/metal/kernels/steel/conv/params.h"
#include "mlx/backend/metal/kernels/steel/utils.h"

View File

@@ -23,10 +23,12 @@
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 1, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 1, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2)
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 8, 4, 1)
instantiate_gemm_shapes_helper(float16, half, float16, half);
instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t);
instantiate_gemm_shapes_helper(float32, float, float32, float);
instantiate_gemm_shapes_helper(complex64, complex64_t, complex64, complex64_t);
// clang-format on

View File

@@ -1,8 +1,8 @@
// Copyright © 2024 Apple Inc.
// clang-format off
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
#include "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h"
#define instantiate_gemm( \

View File

@@ -60,6 +60,7 @@
instantiate_gemm_shapes_helper(float16, half, float32, float);
instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, float32, float);
instantiate_gemm_shapes_helper(float32, float, float32, float);
instantiate_gemm_shapes_helper(complex64, complex64_t, complex64, complex64_t);
#define instantiate_accum(oname, otype, aname, atype) \
instantiate_kernel( \
@@ -71,4 +72,5 @@ instantiate_gemm_shapes_helper(float32, float, float32, float);
instantiate_accum(bfloat16, bfloat16_t, float32, float);
instantiate_accum(float16, half, float32, float);
instantiate_accum(float32, float, float32, float); // clang-format on
instantiate_accum(float32, float, float32, float);
instantiate_accum(complex64, complex64_t, complex64, complex64_t); // clang-format on

View File

@@ -421,6 +421,16 @@ METAL_FUNC void tile_matmad(
}
}
template <typename InT>
struct TransformNone<complex64_t, InT> {
static METAL_FUNC complex64_t apply(complex64_t x) {
return x;
}
static METAL_FUNC complex64_t apply(complex64_t x, complex64_t) {
return x;
}
};
template <
typename T,
typename U,
@@ -731,5 +741,406 @@ struct BlockMMA {
}
};
template <
typename U,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_a,
bool transpose_b,
short lda_tgp,
short ldb_tgp,
typename AccumType,
typename Epilogue>
struct BlockMMA<
complex64_t,
U,
BM,
BN,
BK,
WM,
WN,
transpose_a,
transpose_b,
lda_tgp,
ldb_tgp,
AccumType,
Epilogue> {
static_assert(
metal::is_same_v<AccumType, float>,
"BlockMMA<complex64_t,...> expects float accumulators");
static_assert(
metal::is_same_v<U, complex64_t>,
"For complex BlockMMA, U must be complex64_t; use a different epilogue for projections");
// MMAFrag size
STEEL_CONST short kFragSize = 8;
using MMAFrag_acc_t = BaseMMAFrag<AccumType, kFragSize, kFragSize>;
// Warp tile simdgroup matrix strides along M
STEEL_CONST short TM_stride = kFragSize * WM;
// Warp tile simdgroup matrix strides along M
STEEL_CONST short TN_stride = kFragSize * WN;
// Warp tile size along M
STEEL_CONST short TM = BM / (kFragSize * WM);
// Warp tile size along N
STEEL_CONST short TN = BN / (kFragSize * WN);
// Threadgroup A strides
STEEL_CONST short A_str_m = transpose_a ? 1 : lda_tgp; // M
STEEL_CONST short A_str_k = transpose_a ? lda_tgp : 1; // K
// Threadgroup B strides
STEEL_CONST short B_str_k = transpose_b ? 1 : ldb_tgp; // K
STEEL_CONST short B_str_n = transpose_b ? ldb_tgp : 1; // N
// Threadgroup strides along K
STEEL_CONST short tile_stride_a = kFragSize * A_str_k;
STEEL_CONST short tile_stride_b = kFragSize * B_str_k;
// When indexing complex as float[2]
STEEL_CONST short A_str_m_f = A_str_m * 2;
STEEL_CONST short A_str_k_f = A_str_k * 2;
STEEL_CONST short B_str_k_f = B_str_k * 2;
STEEL_CONST short B_str_n_f = B_str_n * 2;
STEEL_CONST short tile_stride_a_f = tile_stride_a * 2;
STEEL_CONST short tile_stride_b_f = tile_stride_b * 2;
// Accumulators (real/imag)
MMATile<AccumType, TM, TN, MMAFrag_acc_t> Ctile_r;
MMATile<AccumType, TM, TN, MMAFrag_acc_t> Ctile_i;
// Offsets within threadgroup
short sm, sn;
short As_offset, Bs_offset;
/* Constructor */
METAL_FUNC BlockMMA(
ushort simd_group_id [[simdgroup_index_in_threadgroup]],
ushort simd_lane_id [[thread_index_in_simdgroup]]) {
// Determine thread position in simdgroup matrix
short tm = kFragSize * (simd_group_id / WN);
short tn = kFragSize * (simd_group_id % WN);
short2 simd_coord = MMAFrag_acc_t::get_coord(simd_lane_id);
sm = simd_coord.y;
sn = simd_coord.x;
// Determine thread and simdgroup offset
As_offset = (tm + sm) * A_str_m + (sn)*A_str_k; // (M,K)
Bs_offset = (sm)*B_str_k + (tn + sn) * B_str_n; // (K,N)
sm += tm;
sn += tn;
}
/* Karatsuba MMA: 3 real MMAs per K-chunk */
METAL_FUNC void mma(
const threadgroup complex64_t* As,
const threadgroup complex64_t* Bs) {
// Adjust for simdgroup and thread location
As += As_offset;
Bs += Bs_offset;
threadgroup const float* As_f =
reinterpret_cast<threadgroup const float*>(As);
threadgroup const float* Bs_f =
reinterpret_cast<threadgroup const float*>(Bs);
// Iterate over BK in blocks of kFragSize
STEEL_PRAGMA_UNROLL
for (short kk = 0; kk < BK; kk += kFragSize) {
simdgroup_barrier(mem_flags::mem_none);
MMATile<AccumType, TM, 1, MMAFrag_acc_t> Ar, Ai;
Ar.template load<float, WM, 1, A_str_m_f, A_str_k_f>(As_f + 0);
Ai.template load<float, WM, 1, A_str_m_f, A_str_k_f>(As_f + 1);
simdgroup_barrier(mem_flags::mem_none);
MMATile<AccumType, 1, TN, MMAFrag_acc_t> Br, Bi;
Br.template load<float, 1, WN, B_str_k_f, B_str_n_f>(Bs_f + 0);
Bi.template load<float, 1, WN, B_str_k_f, B_str_n_f>(Bs_f + 1);
simdgroup_barrier(mem_flags::mem_none);
// P = Ar*Br ; Q = Ai*Bi ; R = (Ar+Ai)*(Br+Bi)
MMATile<AccumType, TM, TN, MMAFrag_acc_t> P, Q, R;
tile_matmad(P, Ar, Br, P);
tile_matmad(Q, Ai, Bi, Q);
STEEL_PRAGMA_UNROLL
for (short i = 0; i < decltype(Ar)::kElemsPerTile; ++i)
Ar.elems()[i] += Ai.elems()[i];
STEEL_PRAGMA_UNROLL
for (short i = 0; i < decltype(Br)::kElemsPerTile; ++i)
Br.elems()[i] += Bi.elems()[i];
tile_matmad(R, Ar, Br, R);
// C_r += P - Q ; C_i -= Q
STEEL_PRAGMA_UNROLL
for (short i = 0; i < decltype(Ctile_r)::kElemsPerTile; ++i) {
const auto p = P.elems()[i];
const auto q = Q.elems()[i];
const auto r = R.elems()[i];
Ctile_r.elems()[i] += (p - q);
Ctile_i.elems()[i] += (r - p - q);
}
// Progress to next simdgroup tile
As_f += tile_stride_a_f;
Bs_f += tile_stride_b_f;
}
}
/* Store results from simdgroup_matrix results into device memory */
METAL_FUNC void store_result(device U* D, const int ldd) {
// Adjust for simdgroup and thread location
D += sm * ldd + sn;
STEEL_PRAGMA_UNROLL
for (short i = 0; i < TM; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < TN; j++) {
thread const auto& r = Ctile_r.frag_at(i, j);
thread const auto& im = Ctile_i.frag_at(i, j);
int off = (i * TM_stride) * ldd + (j * TN_stride);
STEEL_PRAGMA_UNROLL
for (short k = 0; k < decltype(Ctile_r)::kElemsPerFrag; k++) {
D[off + k] = Epilogue::apply(complex64_t(r[k], im[k]));
}
}
}
}
METAL_FUNC void
store_result_slice(device U* D, const int ldd, short2 start, short2 stop) {
D += sm * ldd + sn;
start -= short2(sn, sm);
stop -= short2(sn, sm);
if (stop.y <= 0 || stop.x <= 0)
return;
STEEL_PRAGMA_UNROLL
for (short i = 0; i < TM; ++i) {
const int row = i * TM_stride;
if (row >= start.y && row < stop.y) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < TN; ++j) {
const int off = row * ldd + (j * TN_stride);
thread const auto& r = Ctile_r.frag_at(i, j);
thread const auto& im = Ctile_i.frag_at(i, j);
STEEL_PRAGMA_UNROLL
for (short k = 0; k < decltype(Ctile_r)::kElemsPerFrag; ++k) {
const int col = j * TN_stride + k;
if (col >= start.x && col < stop.x) {
D[off + k] = Epilogue::apply(complex64_t(r[k], im[k]));
}
}
}
}
}
}
METAL_FUNC void
store_result_safe(device U* D, const int ldd, short2 dst_tile_dims) {
D += sm * ldd + sn;
dst_tile_dims -= short2(sn, sm);
if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
return;
STEEL_PRAGMA_UNROLL
for (short i = 0; i < TM; i++) {
if (i * TM_stride < dst_tile_dims.y) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < TN; j++) {
int off = (i * TM_stride) * ldd + (j * TN_stride);
thread const auto& r = Ctile_r.frag_at(i, j);
thread const auto& im = Ctile_i.frag_at(i, j);
STEEL_PRAGMA_UNROLL
for (short k = 0; k < decltype(Ctile_r)::kElemsPerFrag; k++) {
if ((j * TN_stride + k) < dst_tile_dims.x) {
D[off + k] = Epilogue::apply(complex64_t(r[k], im[k]));
}
}
}
}
}
}
/* Apply epilogue */
template <typename UnaryEpilogue>
METAL_FUNC void apply_epilogue(thread const UnaryEpilogue& epilogue_op) {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < decltype(Ctile_r)::kElemsPerTile; i++) {
complex64_t out = epilogue_op.apply(
complex64_t(Ctile_r.elems()[i], Ctile_i.elems()[i]));
Ctile_r.elems()[i] = out.real;
Ctile_i.elems()[i] = out.imag;
}
}
/* Apply epilogue */
template <typename BinaryEpilogue>
METAL_FUNC void apply_epilogue(
const device U* C,
const int ldc,
const int fdc,
thread const BinaryEpilogue& epilogue_op) {
// Adjust for simdgroup and thread location
C += (sm)*ldc + (sn)*fdc;
// Loop over all simdgroup tiles
STEEL_PRAGMA_UNROLL
for (short i = 0; i < TM; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < TN; j++) {
// Get accumulated result and associated offset in Cr, Ci
thread auto& r = Ctile_r.frag_at(i, j);
thread auto& im = Ctile_i.frag_at(i, j);
int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
STEEL_PRAGMA_UNROLL
for (short k = 0; k < decltype(Ctile_r)::kElemsPerFrag; k++) {
complex64_t out = epilogue_op.apply(
complex64_t(r[k], im[k]), C[offset_c + k * fdc]);
r[k] = out.real;
im[k] = out.imag;
}
}
}
}
/* Apply epilogue */
template <typename BinaryEpilogue>
METAL_FUNC void apply_epilogue_safe(
const device U* C,
const int ldc,
const int fdc,
short2 dst_tile_dims,
thread const BinaryEpilogue& epilogue_op) {
// Adjust for simdgroup and thread location
C += (sm)*ldc + (sn)*fdc;
dst_tile_dims -= short2(sn, sm);
if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
return;
// Loop over all simdgroup tiles
STEEL_PRAGMA_UNROLL
for (short i = 0; i < TM; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < TN; j++) {
// Get accumulated result and associated offset in Cr, Ci
thread auto& r = Ctile_r.frag_at(i, j);
thread auto& im = Ctile_i.frag_at(i, j);
int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
constexpr short kelems = decltype(Ctile_r)::kElemsPerFrag;
complex64_t tmp[kelems];
STEEL_PRAGMA_UNROLL
for (short k = 0; k < kelems; k++) {
if ((j * TN_stride + k) < dst_tile_dims.x &&
(i * TM_stride) < dst_tile_dims.y) {
tmp[k] = C[offset_c + k * fdc];
} else {
tmp[k] = complex64_t(0.0f, 0.0f);
}
}
// Apply epilogue
STEEL_PRAGMA_UNROLL
for (short k = 0; k < kelems; k++) {
complex64_t out = epilogue_op.apply(complex64_t(r[k], im[k]), tmp[k]);
r[k] = out.real;
im[k] = out.imag;
}
}
}
}
/* Store results from simdgroup_matrix results into device memory */
METAL_FUNC void store_result(
device U* D,
const int ldd,
const device U* C,
const int ldc,
const int fdc,
thread const Epilogue& epilogue_op) const {
// Adjust for simdgroup and thread location
C += (sm)*ldc + (sn)*fdc;
D += (sm)*ldd + sn;
constexpr short kelems = decltype(Ctile_r)::kElemsPerFrag;
// Loop over all simdgroup tiles
STEEL_PRAGMA_UNROLL
for (short i = 0; i < TM; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < TN; j++) {
// Get accumulated result and associated offset in Cr, Ci
thread const auto& r = Ctile_r.frag_at(i, j);
thread const auto& im = Ctile_i.frag_at(i, j);
int off_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
int off_d = (i * TM_stride) * ldd + (j * TN_stride);
// Apply epilogue
STEEL_PRAGMA_UNROLL
for (short k = 0; k < kelems; k++) {
D[off_d + k] =
epilogue_op.apply(complex64_t(r[k], im[k]), C[off_c + k * fdc]);
}
}
}
}
METAL_FUNC void store_result_safe(
device U* D,
const int ldd,
const device U* C,
const int ldc,
const int fdc,
short2 dst_tile_dims,
thread const Epilogue& epilogue_op) const {
// Adjust for simdgroup and thread location
C += (sm)*ldc + (sn)*fdc;
D += (sm)*ldd + sn;
dst_tile_dims -= short2(sn, sm);
if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
return;
constexpr short kelems = decltype(Ctile_r)::kElemsPerFrag;
STEEL_PRAGMA_UNROLL
for (int i = 0; i < TM; i++) {
if (i * TM_stride < dst_tile_dims.y) {
STEEL_PRAGMA_UNROLL
for (int j = 0; j < TN; j++) {
// Get accumulated result and associated offset in Cr, Ci
thread const auto& r = Ctile_r.frag_at(i, j);
thread const auto& im = Ctile_i.frag_at(i, j);
int off_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
int off_d = (i * TM_stride) * ldd + (j * TN_stride);
// Apply epilogue
STEEL_PRAGMA_UNROLL
for (short k = 0; k < kelems; k++) {
if ((j * TN_stride + k) < dst_tile_dims.x) {
D[off_d + k] = epilogue_op.apply(
complex64_t(r[k], im[k]), C[off_c + k * fdc]);
}
}
}
}
}
}
};
} // namespace steel
} // namespace mlx

View File

@@ -48,7 +48,8 @@ struct TransformAxpby {
}
METAL_FUNC OutT apply(InT x, OutT c) const {
return static_cast<OutT>(x * alpha + (beta * c));
return static_cast<OutT>(
x * static_cast<InT>(alpha) + (static_cast<OutT>(beta) * c));
}
};

View File

@@ -86,7 +86,13 @@ ensure_batch_contiguous(const array& x, metal::Device& d, const Stream& s) {
#define GEMM_TPARAM_MACRO(devc) \
if (devc == 'g' || devc == 'p') { /* Small device */ \
if (!transpose_a && transpose_b) { /* nt */ \
if (out.dtype() == complex64) { \
bm = 64; \
bn = 32; \
bk = 8; \
wm = 4; \
wn = 1; \
} else if (!transpose_a && transpose_b) { /* nt */ \
bm = 64; \
bn = 32; \
bk = 32; \
@@ -362,7 +368,11 @@ void steel_gemm_splitk_axpby(
int gemm_k_iterations = (K / bk) / split_k_partitions;
int split_k_partition_size = gemm_k_iterations * bk;
array C_split({split_k_partitions, M, N}, float32, nullptr, {});
array C_split(
{split_k_partitions, M, N},
issubdtype(out.dtype(), complexfloating) ? complex64 : float32,
nullptr,
{});
C_split.set_data(allocator::malloc(C_split.nbytes()));
copies.push_back(C_split);
@@ -807,9 +817,8 @@ inline void gemv(
void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
if (!issubdtype(out.dtype(), floating)) {
throw std::runtime_error(
"[matmul] Does not yet support non-floating point types.");
if (!issubdtype(out.dtype(), inexact)) {
throw std::runtime_error("[matmul] dtype must be inexact.");
}
auto& s = stream();
auto& d = metal::device(s.device);
@@ -1338,7 +1347,8 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
<< (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_"
<< type_to_name(out) << "_bm" << bm << "_bn" << bn << "_bk" << bk
<< "_wm" << wm << "_wn" << wn << "_MN_" << (mn_aligned ? "t" : "n")
<< "aligned" << "_K_" << (k_aligned ? "t" : "n") << "aligned";
<< "aligned"
<< "_K_" << (k_aligned ? "t" : "n") << "aligned";
// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);