Faster complex matmul (#2571)
@@ -104,6 +104,27 @@ constexpr bool operator==(complex64_t a, complex64_t b) {
 constexpr complex64_t operator+(complex64_t a, complex64_t b) {
   return {a.real + b.real, a.imag + b.imag};
 }
 
+constexpr thread complex64_t& operator+=(thread complex64_t& a, complex64_t b) {
+  a.real += b.real;
+  a.imag += b.imag;
+  return a;
+}
+
+constexpr threadgroup complex64_t& operator+=(
+    threadgroup complex64_t& a,
+    complex64_t b) {
+  a.real += b.real;
+  a.imag += b.imag;
+  return a;
+}
+
+constexpr device complex64_t& operator+=(device complex64_t& a, complex64_t b) {
+  a.real += b.real;
+  a.imag += b.imag;
+  return a;
+}
+
+constexpr complex64_t operator+(float a, complex64_t b) {
+  return {a + b.real, b.imag};
+}
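Note: MSL references carry an address-space qualifier, so a single generic operator+= cannot cover registers, threadgroup tiles, and device buffers; hence the three overloads above. A minimal usage sketch (the kernel and variable names are hypothetical, not from this diff, and barriers are omitted for brevity):

    kernel void sketch(device complex64_t* out [[buffer(0)]]) {
      threadgroup complex64_t tile[1];
      thread complex64_t acc = {0.0f, 0.0f};
      acc += complex64_t(1.0f, 2.0f); // thread-space overload
      tile[0] += acc;                 // threadgroup-space overload
      out[0] += tile[0];              // device-space overload
    }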
@@ -15,6 +15,15 @@ using namespace metal;
 
 #define MLX_MTL_CONST static constant constexpr const
 
+template <typename U>
+struct DefaultAccT {
+  using type = float;
+};
+template <>
+struct DefaultAccT<complex64_t> {
+  using type = complex64_t;
+};
+
 template <
     typename T,
     const int BM, /* Threadgroup rows (in simdgroups) */
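Note: DefaultAccT keeps the existing float accumulation for every real input type and opts complex inputs into complex64_t accumulators, which the GEMV kernels below pick up as their default AccT. A compile-time illustration (assuming metal::is_same_v, which this change already uses in mma.h):

    static_assert(
        metal::is_same_v<typename DefaultAccT<half>::type, float>,
        "real inputs keep float accumulators");
    static_assert(
        metal::is_same_v<typename DefaultAccT<complex64_t>::type, complex64_t>,
        "complex inputs accumulate in complex64_t");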
@@ -24,8 +33,10 @@ template <
     const int TM, /* Thread rows (in elements) */
     const int TN, /* Thread cols (in elements) */
     const bool kDoAxpby, /* Do out = alpha * out + beta * bias */
-    typename AccT = float>
+    typename AccT = typename DefaultAccT<T>::type>
 struct GEMVKernel {
+  using acc_type = AccT;
 
   MLX_MTL_CONST int threadsM = BM * SM;
   MLX_MTL_CONST int threadsN = BN * SN;
@@ -246,8 +257,10 @@ template <
     const int TM, /* Thread rows (in elements) */
     const int TN, /* Thread cols (in elements) */
     const bool kDoAxpby, /* Do out = alpha * out + beta * bias */
-    typename AccT = float>
+    typename AccT = typename DefaultAccT<T>::type>
 struct GEMVTKernel {
+  using acc_type = AccT;
 
   MLX_MTL_CONST int threadsM = BM * SM;
   MLX_MTL_CONST int threadsN = BN * SN;
@@ -453,7 +466,7 @@ template <
     uint simd_gid [[simdgroup_index_in_threadgroup]],
     uint simd_lid [[thread_index_in_simdgroup]]) {
   using gemv_kernel = GEMVKernel<T, BM, BN, SM, SN, TM, TN, kDoAxpby>;
-  threadgroup float tgp_memory
+  threadgroup typename gemv_kernel::acc_type tgp_memory
       [gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size];
 
   // Update batch offsets
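Note: the threadgroup scratch buffer now follows the kernel's accumulator type; a plain float array would truncate complex64_t partial sums. Rough sizing sketch, assuming complex64_t is a pair of floats and using an illustrative tgp_mem_size value:

    constexpr int tgp_mem_size = 32;                                  // example
    constexpr size_t float_bytes = tgp_mem_size * sizeof(float);      // 128 B
    constexpr size_t cplx_bytes = tgp_mem_size * sizeof(complex64_t); // 256 B
    static_assert(cplx_bytes == 2 * float_bytes, "complex slots are 2x wider");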
@@ -530,6 +543,7 @@ template <
 instantiate_gemv_blocks(float32, float);
 instantiate_gemv_blocks(float16, half);
 instantiate_gemv_blocks(bfloat16, bfloat16_t);
+instantiate_gemv_blocks(complex64, complex64_t);
 
 template <
     typename T,
@@ -565,7 +579,7 @@ template <
     uint simd_gid [[simdgroup_index_in_threadgroup]],
     uint simd_lid [[thread_index_in_simdgroup]]) {
   using gemv_kernel = GEMVKernel<T, BM, BN, SM, SN, TM, TN, false>;
-  threadgroup float tgp_memory
+  threadgroup typename gemv_kernel::acc_type tgp_memory
       [gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size];
 
   uint32_t indx_vec;
@@ -636,6 +650,7 @@ template <
 instantiate_gemv_bs_blocks(float32, float);
 instantiate_gemv_bs_blocks(float16, half);
 instantiate_gemv_bs_blocks(bfloat16, bfloat16_t);
+instantiate_gemv_bs_blocks(complex64, complex64_t);
 
 ///////////////////////////////////////////////////////////////////////////////
 /// Vector matrix multiplication
@@ -672,7 +687,7 @@ template <
     uint simd_gid [[simdgroup_index_in_threadgroup]],
     uint simd_lid [[thread_index_in_simdgroup]]) {
   using gemv_kernel = GEMVTKernel<T, BM, BN, SM, SN, TM, TN, kDoAxpby>;
-  threadgroup float tgp_memory
+  threadgroup typename gemv_kernel::acc_type tgp_memory
       [gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size];
 
   // Update batch offsets
@@ -738,7 +753,8 @@ template <
 // clang-format off
 instantiate_gemv_t_blocks(float32, float);
 instantiate_gemv_t_blocks(float16, half);
-instantiate_gemv_t_blocks(bfloat16, bfloat16_t); // clang-format on
+instantiate_gemv_t_blocks(bfloat16, bfloat16_t);
+instantiate_gemv_t_blocks(complex64, complex64_t); // clang-format on
 
 template <
     typename T,
@@ -773,8 +789,8 @@ template <
     uint3 lid [[thread_position_in_threadgroup]],
     uint simd_gid [[simdgroup_index_in_threadgroup]],
     uint simd_lid [[thread_index_in_simdgroup]]) {
-  using gemv_kernel = GEMVTKernel<T, BM, BN, SM, SN, TM, TN, false, float>;
-  threadgroup float tgp_memory
+  using gemv_kernel = GEMVTKernel<T, BM, BN, SM, SN, TM, TN, false>;
+  threadgroup typename gemv_kernel::acc_type tgp_memory
       [gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size];
 
   uint32_t indx_vec;
@@ -848,4 +864,5 @@ template <
 // clang-format off
 instantiate_gemv_t_bs_blocks(float32, float);
 instantiate_gemv_t_bs_blocks(float16, half);
-instantiate_gemv_t_bs_blocks(bfloat16, bfloat16_t); // clang-format on
+instantiate_gemv_t_bs_blocks(bfloat16, bfloat16_t);
+instantiate_gemv_t_bs_blocks(complex64, complex64_t); // clang-format on
@@ -3,9 +3,8 @@
 #include <metal_stdlib>
 
 // clang-format off
-#include "mlx/backend/metal/kernels/steel/gemm/mma.h"
-
 #include "mlx/backend/metal/kernels/utils.h"
+#include "mlx/backend/metal/kernels/steel/gemm/mma.h"
 #include "mlx/backend/metal/kernels/steel/conv/conv.h"
 #include "mlx/backend/metal/kernels/steel/conv/params.h"
 #include "mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h"
@@ -3,9 +3,8 @@
 #include <metal_stdlib>
 
 // clang-format off
-#include "mlx/backend/metal/kernels/steel/gemm/mma.h"
-
 #include "mlx/backend/metal/kernels/utils.h"
+#include "mlx/backend/metal/kernels/steel/gemm/mma.h"
 #include "mlx/backend/metal/kernels/steel/conv/conv.h"
 #include "mlx/backend/metal/kernels/steel/conv/params.h"
 #include "mlx/backend/metal/kernels/steel/utils.h"
@@ -23,10 +23,12 @@
 instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 1, 2) \
 instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \
 instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 1, 2) \
-instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2)
+instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) \
+instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 8, 4, 1)
 
 instantiate_gemm_shapes_helper(float16, half, float16, half);
 instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t);
 
 instantiate_gemm_shapes_helper(float32, float, float32, float);
+instantiate_gemm_shapes_helper(complex64, complex64_t, complex64, complex64_t);
 // clang-format on
@@ -1,8 +1,8 @@
 // Copyright © 2024 Apple Inc.
 
 // clang-format off
-#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
 #include "mlx/backend/metal/kernels/utils.h"
+#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
 #include "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h"
 
 #define instantiate_gemm( \
@@ -60,6 +60,7 @@
 instantiate_gemm_shapes_helper(float16, half, float32, float);
 instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, float32, float);
 instantiate_gemm_shapes_helper(float32, float, float32, float);
+instantiate_gemm_shapes_helper(complex64, complex64_t, complex64, complex64_t);
 
 #define instantiate_accum(oname, otype, aname, atype) \
   instantiate_kernel( \
@@ -71,4 +72,5 @@ instantiate_gemm_shapes_helper(float32, float, float32, float);
 
 instantiate_accum(bfloat16, bfloat16_t, float32, float);
 instantiate_accum(float16, half, float32, float);
-instantiate_accum(float32, float, float32, float); // clang-format on
+instantiate_accum(float32, float, float32, float);
+instantiate_accum(complex64, complex64_t, complex64, complex64_t); // clang-format on
@@ -421,6 +421,16 @@ METAL_FUNC void tile_matmad(
   }
 }
 
+template <typename InT>
+struct TransformNone<complex64_t, InT> {
+  static METAL_FUNC complex64_t apply(complex64_t x) {
+    return x;
+  }
+  static METAL_FUNC complex64_t apply(complex64_t x, complex64_t) {
+    return x;
+  }
+};
+
 template <
     typename T,
     typename U,
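Note: a generic identity transform converts elementwise from the accumulator type, which is the wrong notion for complex outputs. Sketched for contrast only (TransformNoneSketch is a hypothetical stand-in, not the real template); the complex BlockMMA below hands its epilogue an already-assembled complex64_t, so the specialization above forwards it untouched:

    // Hypothetical stand-in for the generic template, for contrast only:
    template <typename OutT, typename InT>
    struct TransformNoneSketch {
      static METAL_FUNC OutT apply(InT x) {
        return static_cast<OutT>(x); // scalar cast; meaningless for complex
      }
    };
    // TransformNone<complex64_t, InT>::apply(z) simply returns z.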
@@ -731,5 +741,406 @@ struct BlockMMA {
   }
 };
 
+template <
+    typename U,
+    int BM,
+    int BN,
+    int BK,
+    int WM,
+    int WN,
+    bool transpose_a,
+    bool transpose_b,
+    short lda_tgp,
+    short ldb_tgp,
+    typename AccumType,
+    typename Epilogue>
+struct BlockMMA<
+    complex64_t,
+    U,
+    BM,
+    BN,
+    BK,
+    WM,
+    WN,
+    transpose_a,
+    transpose_b,
+    lda_tgp,
+    ldb_tgp,
+    AccumType,
+    Epilogue> {
+  static_assert(
+      metal::is_same_v<AccumType, float>,
+      "BlockMMA<complex64_t,...> expects float accumulators");
+  static_assert(
+      metal::is_same_v<U, complex64_t>,
+      "For complex BlockMMA, U must be complex64_t; use a different epilogue for projections");
+  // MMAFrag size
+  STEEL_CONST short kFragSize = 8;
+  using MMAFrag_acc_t = BaseMMAFrag<AccumType, kFragSize, kFragSize>;
+
+  // Warp tile simdgroup matrix strides along M
+  STEEL_CONST short TM_stride = kFragSize * WM;
+  // Warp tile simdgroup matrix strides along N
+  STEEL_CONST short TN_stride = kFragSize * WN;
+
+  // Warp tile size along M
+  STEEL_CONST short TM = BM / (kFragSize * WM);
+  // Warp tile size along N
+  STEEL_CONST short TN = BN / (kFragSize * WN);
+
+  // Threadgroup A strides
+  STEEL_CONST short A_str_m = transpose_a ? 1 : lda_tgp; // M
+  STEEL_CONST short A_str_k = transpose_a ? lda_tgp : 1; // K
+
+  // Threadgroup B strides
+  STEEL_CONST short B_str_k = transpose_b ? 1 : ldb_tgp; // K
+  STEEL_CONST short B_str_n = transpose_b ? ldb_tgp : 1; // N
+
+  // Threadgroup strides along K
+  STEEL_CONST short tile_stride_a = kFragSize * A_str_k;
+  STEEL_CONST short tile_stride_b = kFragSize * B_str_k;
+
+  // When indexing complex as float[2]
+  STEEL_CONST short A_str_m_f = A_str_m * 2;
+  STEEL_CONST short A_str_k_f = A_str_k * 2;
+  STEEL_CONST short B_str_k_f = B_str_k * 2;
+  STEEL_CONST short B_str_n_f = B_str_n * 2;
+  STEEL_CONST short tile_stride_a_f = tile_stride_a * 2;
+  STEEL_CONST short tile_stride_b_f = tile_stride_b * 2;
+
+  // Accumulators (real/imag)
+  MMATile<AccumType, TM, TN, MMAFrag_acc_t> Ctile_r;
+  MMATile<AccumType, TM, TN, MMAFrag_acc_t> Ctile_i;
+
+  // Offsets within threadgroup
+  short sm, sn;
+  short As_offset, Bs_offset;
+
+  /* Constructor */
+  METAL_FUNC BlockMMA(
+      ushort simd_group_id [[simdgroup_index_in_threadgroup]],
+      ushort simd_lane_id [[thread_index_in_simdgroup]]) {
+    // Determine thread position in simdgroup matrix
+    short tm = kFragSize * (simd_group_id / WN);
+    short tn = kFragSize * (simd_group_id % WN);
+
+    short2 simd_coord = MMAFrag_acc_t::get_coord(simd_lane_id);
+    sm = simd_coord.y;
+    sn = simd_coord.x;
+
+    // Determine thread and simdgroup offset
+    As_offset = (tm + sm) * A_str_m + (sn)*A_str_k; // (M,K)
+    Bs_offset = (sm)*B_str_k + (tn + sn) * B_str_n; // (K,N)
+
+    sm += tm;
+    sn += tn;
+  }
+
+  /* Karatsuba MMA: 3 real MMAs per K-chunk */
+  METAL_FUNC void mma(
+      const threadgroup complex64_t* As,
+      const threadgroup complex64_t* Bs) {
+    // Adjust for simdgroup and thread location
+    As += As_offset;
+    Bs += Bs_offset;
+    threadgroup const float* As_f =
+        reinterpret_cast<threadgroup const float*>(As);
+    threadgroup const float* Bs_f =
+        reinterpret_cast<threadgroup const float*>(Bs);
+
+    // Iterate over BK in blocks of kFragSize
+    STEEL_PRAGMA_UNROLL
+    for (short kk = 0; kk < BK; kk += kFragSize) {
+      simdgroup_barrier(mem_flags::mem_none);
+
+      MMATile<AccumType, TM, 1, MMAFrag_acc_t> Ar, Ai;
+      Ar.template load<float, WM, 1, A_str_m_f, A_str_k_f>(As_f + 0);
+      Ai.template load<float, WM, 1, A_str_m_f, A_str_k_f>(As_f + 1);
+
+      simdgroup_barrier(mem_flags::mem_none);
+
+      MMATile<AccumType, 1, TN, MMAFrag_acc_t> Br, Bi;
+      Br.template load<float, 1, WN, B_str_k_f, B_str_n_f>(Bs_f + 0);
+      Bi.template load<float, 1, WN, B_str_k_f, B_str_n_f>(Bs_f + 1);
+
+      simdgroup_barrier(mem_flags::mem_none);
+
+      // P = Ar*Br ; Q = Ai*Bi ; R = (Ar+Ai)*(Br+Bi)
+      MMATile<AccumType, TM, TN, MMAFrag_acc_t> P, Q, R;
+
+      tile_matmad(P, Ar, Br, P);
+      tile_matmad(Q, Ai, Bi, Q);
+
+      STEEL_PRAGMA_UNROLL
+      for (short i = 0; i < decltype(Ar)::kElemsPerTile; ++i)
+        Ar.elems()[i] += Ai.elems()[i];
+      STEEL_PRAGMA_UNROLL
+      for (short i = 0; i < decltype(Br)::kElemsPerTile; ++i)
+        Br.elems()[i] += Bi.elems()[i];
+
+      tile_matmad(R, Ar, Br, R);
+
+      // C_r += P - Q ; C_i += R - P - Q
+      STEEL_PRAGMA_UNROLL
+      for (short i = 0; i < decltype(Ctile_r)::kElemsPerTile; ++i) {
+        const auto p = P.elems()[i];
+        const auto q = Q.elems()[i];
+        const auto r = R.elems()[i];
+        Ctile_r.elems()[i] += (p - q);
+        Ctile_i.elems()[i] += (r - p - q);
+      }
+
+      // Progress to next simdgroup tile
+      As_f += tile_stride_a_f;
+      Bs_f += tile_stride_b_f;
+    }
+  }
+
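Note: mma() above replaces the naive 4-multiplication complex product with the 3-multiplication Karatsuba (Gauss) scheme, trading one MMA for a few extra additions. A scalar sketch of the identity it applies tile-wise (plain float names for illustration):

    // (ar + i*ai) * (br + i*bi) with 3 multiplies instead of 4:
    float P = ar * br;               // real x real
    float Q = ai * bi;               // imag x imag
    float R = (ar + ai) * (br + bi); // both cross terms in one product
    float out_r = P - Q;             // real part
    float out_i = R - P - Q;         // imag part (= ar*bi + ai*br)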
+  /* Store results from simdgroup_matrix results into device memory */
+  METAL_FUNC void store_result(device U* D, const int ldd) {
+    // Adjust for simdgroup and thread location
+    D += sm * ldd + sn;
+
+    STEEL_PRAGMA_UNROLL
+    for (short i = 0; i < TM; i++) {
+      STEEL_PRAGMA_UNROLL
+      for (short j = 0; j < TN; j++) {
+        thread const auto& r = Ctile_r.frag_at(i, j);
+        thread const auto& im = Ctile_i.frag_at(i, j);
+        int off = (i * TM_stride) * ldd + (j * TN_stride);
+        STEEL_PRAGMA_UNROLL
+        for (short k = 0; k < decltype(Ctile_r)::kElemsPerFrag; k++) {
+          D[off + k] = Epilogue::apply(complex64_t(r[k], im[k]));
+        }
+      }
+    }
+  }
+
+  METAL_FUNC void
+  store_result_slice(device U* D, const int ldd, short2 start, short2 stop) {
+    D += sm * ldd + sn;
+    start -= short2(sn, sm);
+    stop -= short2(sn, sm);
+
+    if (stop.y <= 0 || stop.x <= 0)
+      return;
+
+    STEEL_PRAGMA_UNROLL
+    for (short i = 0; i < TM; ++i) {
+      const int row = i * TM_stride;
+      if (row >= start.y && row < stop.y) {
+        STEEL_PRAGMA_UNROLL
+        for (short j = 0; j < TN; ++j) {
+          const int off = row * ldd + (j * TN_stride);
+          thread const auto& r = Ctile_r.frag_at(i, j);
+          thread const auto& im = Ctile_i.frag_at(i, j);
+
+          STEEL_PRAGMA_UNROLL
+          for (short k = 0; k < decltype(Ctile_r)::kElemsPerFrag; ++k) {
+            const int col = j * TN_stride + k;
+            if (col >= start.x && col < stop.x) {
+              D[off + k] = Epilogue::apply(complex64_t(r[k], im[k]));
+            }
+          }
+        }
+      }
+    }
+  }
+
+  METAL_FUNC void
+  store_result_safe(device U* D, const int ldd, short2 dst_tile_dims) {
+    D += sm * ldd + sn;
+    dst_tile_dims -= short2(sn, sm);
+    if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
+      return;
+    STEEL_PRAGMA_UNROLL
+    for (short i = 0; i < TM; i++) {
+      if (i * TM_stride < dst_tile_dims.y) {
+        STEEL_PRAGMA_UNROLL
+        for (short j = 0; j < TN; j++) {
+          int off = (i * TM_stride) * ldd + (j * TN_stride);
+          thread const auto& r = Ctile_r.frag_at(i, j);
+          thread const auto& im = Ctile_i.frag_at(i, j);
+          STEEL_PRAGMA_UNROLL
+          for (short k = 0; k < decltype(Ctile_r)::kElemsPerFrag; k++) {
+            if ((j * TN_stride + k) < dst_tile_dims.x) {
+              D[off + k] = Epilogue::apply(complex64_t(r[k], im[k]));
+            }
+          }
+        }
+      }
+    }
+  }
+
+  /* Apply epilogue */
+  template <typename UnaryEpilogue>
+  METAL_FUNC void apply_epilogue(thread const UnaryEpilogue& epilogue_op) {
+    STEEL_PRAGMA_UNROLL
+    for (short i = 0; i < decltype(Ctile_r)::kElemsPerTile; i++) {
+      complex64_t out = epilogue_op.apply(
+          complex64_t(Ctile_r.elems()[i], Ctile_i.elems()[i]));
+      Ctile_r.elems()[i] = out.real;
+      Ctile_i.elems()[i] = out.imag;
+    }
+  }
+
+  /* Apply epilogue */
+  template <typename BinaryEpilogue>
+  METAL_FUNC void apply_epilogue(
+      const device U* C,
+      const int ldc,
+      const int fdc,
+      thread const BinaryEpilogue& epilogue_op) {
+    // Adjust for simdgroup and thread location
+    C += (sm)*ldc + (sn)*fdc;
+
+    // Loop over all simdgroup tiles
+    STEEL_PRAGMA_UNROLL
+    for (short i = 0; i < TM; i++) {
+      STEEL_PRAGMA_UNROLL
+      for (short j = 0; j < TN; j++) {
+        // Get accumulated result and associated offset in Cr, Ci
+        thread auto& r = Ctile_r.frag_at(i, j);
+        thread auto& im = Ctile_i.frag_at(i, j);
+        int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
+
+        STEEL_PRAGMA_UNROLL
+        for (short k = 0; k < decltype(Ctile_r)::kElemsPerFrag; k++) {
+          complex64_t out = epilogue_op.apply(
+              complex64_t(r[k], im[k]), C[offset_c + k * fdc]);
+          r[k] = out.real;
+          im[k] = out.imag;
+        }
+      }
+    }
+  }
+
+  /* Apply epilogue */
+  template <typename BinaryEpilogue>
+  METAL_FUNC void apply_epilogue_safe(
+      const device U* C,
+      const int ldc,
+      const int fdc,
+      short2 dst_tile_dims,
+      thread const BinaryEpilogue& epilogue_op) {
+    // Adjust for simdgroup and thread location
+    C += (sm)*ldc + (sn)*fdc;
+    dst_tile_dims -= short2(sn, sm);
+
+    if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
+      return;
+
+    // Loop over all simdgroup tiles
+    STEEL_PRAGMA_UNROLL
+    for (short i = 0; i < TM; i++) {
+      STEEL_PRAGMA_UNROLL
+      for (short j = 0; j < TN; j++) {
+        // Get accumulated result and associated offset in Cr, Ci
+        thread auto& r = Ctile_r.frag_at(i, j);
+        thread auto& im = Ctile_i.frag_at(i, j);
+        int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
+
+        constexpr short kelems = decltype(Ctile_r)::kElemsPerFrag;
+        complex64_t tmp[kelems];
+
+        STEEL_PRAGMA_UNROLL
+        for (short k = 0; k < kelems; k++) {
+          if ((j * TN_stride + k) < dst_tile_dims.x &&
+              (i * TM_stride) < dst_tile_dims.y) {
+            tmp[k] = C[offset_c + k * fdc];
+          } else {
+            tmp[k] = complex64_t(0.0f, 0.0f);
+          }
+        }
+
+        // Apply epilogue
+        STEEL_PRAGMA_UNROLL
+        for (short k = 0; k < kelems; k++) {
+          complex64_t out = epilogue_op.apply(complex64_t(r[k], im[k]), tmp[k]);
+          r[k] = out.real;
+          im[k] = out.imag;
+        }
+      }
+    }
+  }
+
+  /* Store results from simdgroup_matrix results into device memory */
+  METAL_FUNC void store_result(
+      device U* D,
+      const int ldd,
+      const device U* C,
+      const int ldc,
+      const int fdc,
+      thread const Epilogue& epilogue_op) const {
+    // Adjust for simdgroup and thread location
+    C += (sm)*ldc + (sn)*fdc;
+    D += (sm)*ldd + sn;
+
+    constexpr short kelems = decltype(Ctile_r)::kElemsPerFrag;
+
+    // Loop over all simdgroup tiles
+    STEEL_PRAGMA_UNROLL
+    for (short i = 0; i < TM; i++) {
+      STEEL_PRAGMA_UNROLL
+      for (short j = 0; j < TN; j++) {
+        // Get accumulated result and associated offset in Cr, Ci
+        thread const auto& r = Ctile_r.frag_at(i, j);
+        thread const auto& im = Ctile_i.frag_at(i, j);
+        int off_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
+        int off_d = (i * TM_stride) * ldd + (j * TN_stride);
+
+        // Apply epilogue
+        STEEL_PRAGMA_UNROLL
+        for (short k = 0; k < kelems; k++) {
+          D[off_d + k] =
+              epilogue_op.apply(complex64_t(r[k], im[k]), C[off_c + k * fdc]);
+        }
+      }
+    }
+  }
+
+  METAL_FUNC void store_result_safe(
+      device U* D,
+      const int ldd,
+      const device U* C,
+      const int ldc,
+      const int fdc,
+      short2 dst_tile_dims,
+      thread const Epilogue& epilogue_op) const {
+    // Adjust for simdgroup and thread location
+    C += (sm)*ldc + (sn)*fdc;
+    D += (sm)*ldd + sn;
+    dst_tile_dims -= short2(sn, sm);
+
+    if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
+      return;
+
+    constexpr short kelems = decltype(Ctile_r)::kElemsPerFrag;
+
+    STEEL_PRAGMA_UNROLL
+    for (int i = 0; i < TM; i++) {
+      if (i * TM_stride < dst_tile_dims.y) {
+        STEEL_PRAGMA_UNROLL
+        for (int j = 0; j < TN; j++) {
+          // Get accumulated result and associated offset in Cr, Ci
+          thread const auto& r = Ctile_r.frag_at(i, j);
+          thread const auto& im = Ctile_i.frag_at(i, j);
+          int off_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
+          int off_d = (i * TM_stride) * ldd + (j * TN_stride);
+
+          // Apply epilogue
+          STEEL_PRAGMA_UNROLL
+          for (short k = 0; k < kelems; k++) {
+            if ((j * TN_stride + k) < dst_tile_dims.x) {
+              D[off_d + k] = epilogue_op.apply(
+                  complex64_t(r[k], im[k]), C[off_c + k * fdc]);
+            }
+          }
+        }
+      }
+    }
+  }
+};
 
 } // namespace steel
 } // namespace mlx
@@ -48,7 +48,8 @@ struct TransformAxpby {
   }
 
   METAL_FUNC OutT apply(InT x, OutT c) const {
-    return static_cast<OutT>(x * alpha + (beta * c));
+    return static_cast<OutT>(
+        x * static_cast<InT>(alpha) + (static_cast<OutT>(beta) * c));
   }
 };
 
@@ -86,7 +86,13 @@ ensure_batch_contiguous(const array& x, metal::Device& d, const Stream& s) {
 
 #define GEMM_TPARAM_MACRO(devc)                                              \
   if (devc == 'g' || devc == 'p') { /* Small device */                       \
-    if (!transpose_a && transpose_b) { /* nt */                              \
+    if (out.dtype() == complex64) {                                          \
+      bm = 64;                                                               \
+      bn = 32;                                                               \
+      bk = 8;                                                                \
+      wm = 4;                                                                \
+      wn = 1;                                                                \
+    } else if (!transpose_a && transpose_b) { /* nt */                       \
       bm = 64;                                                               \
       bn = 32;                                                               \
       bk = 32;                                                               \
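Note: the complex branch picks the 64x32x8 tile with wm = 4, wn = 1, matching the instantiate_gemm_transpose_helper(..., 64, 32, 8, 4, 1) shape added in the kernel instantiations above; the shallow bk presumably offsets the 3 MMAs per K-chunk and the doubled threadgroup footprint of complex tiles. Resolved values for a hypothetical complex64 case:

    // out.dtype() == complex64 on a small ('g'/'p') device resolves to:
    int bm = 64, bn = 32, bk = 8;
    int wm = 4, wn = 1;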
@@ -362,7 +368,11 @@ void steel_gemm_splitk_axpby(
   int gemm_k_iterations = (K / bk) / split_k_partitions;
   int split_k_partition_size = gemm_k_iterations * bk;
 
-  array C_split({split_k_partitions, M, N}, float32, nullptr, {});
+  array C_split(
+      {split_k_partitions, M, N},
+      issubdtype(out.dtype(), complexfloating) ? complex64 : float32,
+      nullptr,
+      {});
   C_split.set_data(allocator::malloc(C_split.nbytes()));
   copies.push_back(C_split);
 
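Note: split-k stages one partial [M, N] product per K-partition in C_split and reduces them afterwards with the accumulation kernel (see instantiate_accum(complex64, ...) above), so the staging buffer must now be complex64 for complex outputs rather than the former hard-coded float32. A host-side sketch of the choice and the reduction it feeds:

    // Staging dtype mirrors the output's numeric class (as in the diff):
    Dtype staging =
        issubdtype(out.dtype(), complexfloating) ? complex64 : float32;
    // The accum kernel then computes out[m, n] = sum over p of C_split[p, m, n].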
@@ -807,9 +817,8 @@ inline void gemv(
 
 void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 2);
-  if (!issubdtype(out.dtype(), floating)) {
-    throw std::runtime_error(
-        "[matmul] Does not yet support non-floating point types.");
+  if (!issubdtype(out.dtype(), inexact)) {
+    throw std::runtime_error("[matmul] dtype must be inexact.");
   }
   auto& s = stream();
   auto& d = metal::device(s.device);
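Note: in MLX's dtype hierarchy, inexact covers both floating and complexfloating, so the relaxed guard admits complex64 matmuls while still rejecting integer and bool inputs:

    assert(issubdtype(float32, inexact));   // true, as before
    assert(issubdtype(complex64, inexact)); // true: the newly reachable path
    assert(!issubdtype(int32, inexact));    // false: still throws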
@@ -1338,7 +1347,8 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
         << (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_"
         << type_to_name(out) << "_bm" << bm << "_bn" << bn << "_bk" << bk
         << "_wm" << wm << "_wn" << wn << "_MN_" << (mn_aligned ? "t" : "n")
-        << "aligned" << "_K_" << (k_aligned ? "t" : "n") << "aligned";
+        << "aligned"
+        << "_K_" << (k_aligned ? "t" : "n") << "aligned";
 
   // Encode and dispatch kernel
   auto& compute_encoder = d.get_command_encoder(s.index);