[WIP] Init NAX matmuls
mirror of https://github.com/ml-explore/mlx.git
@@ -120,6 +120,14 @@ if(NOT MLX_METAL_PATH)
  set(MLX_METAL_PATH ${CMAKE_CURRENT_BINARY_DIR}/kernels/)
endif()

if((MLX_METAL_VERSION GREATER_EQUAL 400) AND (MACOS_SDK_VERSION GREATER_EQUAL 26.2))
  set(MLX_ENABLE_NAX TRUE)
  target_compile_definitions(mlx PRIVATE MLX_ENABLE_NAX)
else()
  set(MLX_ENABLE_NAX FALSE)
endif()

add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/kernels)

target_compile_definitions(mlx
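Net effect of this block: the NAX code paths are compiled only when both the Metal toolchain (MLX_METAL_VERSION >= 400) and the macOS SDK (>= 26.2) are new enough; when the check fails, MLX_ENABLE_NAX stays off and none of the NAX kernels or host branches below are built. The kernel build flags added further down additionally switch those kernels to -std=metal4.0.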
@@ -265,4 +265,15 @@ Device& device(mlx::core::Device);
std::unique_ptr<void, std::function<void(void*)>> new_scoped_memory_pool();

#ifdef MLX_ENABLE_NAX

inline bool is_nax_available() {
  static bool is_nax_available_ =
      /* __builtin_available(macOS 26.2, *) && */
      metal::device(mlx::core::Device::gpu).get_architecture_gen() >= 17;
  return is_nax_available_;
}

#endif // MLX_ENABLE_NAX

} // namespace mlx::core::metal
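For orientation, a minimal host-side sketch (not part of the diff) of how the compile-time and runtime gates are meant to compose; the helper name is hypothetical and it assumes the header declaring is_nax_available() is included.

// Hypothetical helper, shown only to illustrate the gating pattern used by
// the host code later in this diff.
bool use_nax_gemm_path() {
#ifdef MLX_ENABLE_NAX
  // Compile-time gate: the NAX kernels were built (Metal >= 4.0, SDK >= 26.2).
  // Runtime gate: the GPU architecture generation supports NAX.
  return mlx::core::metal::is_nax_available();
#else
  return false;
#endif
}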
@@ -9,10 +9,13 @@ set(BASE_HEADERS
  utils.h)

function(build_kernel_base TARGET SRCFILE DEPS)
  set(METAL_FLAGS -x metal -Wall -Wextra -fno-fast-math -Wno-c++17-extensions)
  if(MLX_METAL_DEBUG)
    set(METAL_FLAGS ${METAL_FLAGS} -gline-tables-only -frecord-sources)
  endif()
  if(MLX_ENABLE_NAX)
    set(METAL_FLAGS ${METAL_FLAGS} -Wno-c++20-extensions -std=metal4.0)
  endif()
  if(NOT CMAKE_OSX_DEPLOYMENT_TARGET STREQUAL "")
    set(METAL_FLAGS ${METAL_FLAGS}
        "-mmacosx-version-min=${CMAKE_OSX_DEPLOYMENT_TARGET}")
@@ -120,6 +123,22 @@ if(NOT MLX_METAL_JIT)
  build_kernel(gemv_masked steel/utils.h)
endif()

if(MLX_ENABLE_NAX)

  set(STEEL_NAX_HEADERS
      steel/defines.h
      steel/utils.h
      steel/gemm/transforms.h
      steel/gemm/nax.h
      steel/gemm/gemm_nax.h
      steel/utils/type_traits.h
      steel/utils/integral_constant.h)

  build_kernel(steel/gemm/kernels/steel_gemm_fused_nax ${STEEL_NAX_HEADERS})
  build_kernel(steel/gemm/kernels/steel_gemm_gather_nax ${STEEL_NAX_HEADERS})

endif()

add_custom_command(
  OUTPUT ${MLX_METAL_PATH}/mlx.metallib
  COMMAND xcrun -sdk macosx metallib ${KERNEL_AIR} -o
@@ -1,4 +1,7 @@
// Copyright © 2024 Apple Inc.

#pragma once

#define STEEL_CONST static constant constexpr const
#define STEEL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)")
#define STEEL_PRAGMA_NO_UNROLL _Pragma("clang loop unroll(disable)")
mlx/backend/metal/kernels/steel/gemm/gemm_nax.h (new file, 154 lines)
@@ -0,0 +1,154 @@
// Copyright © 2025 Apple Inc.

#pragma once

#include "mlx/backend/metal/kernels/steel/gemm/nax.h"
#include "mlx/backend/metal/kernels/steel/gemm/params.h"

using namespace metal;

namespace mlx::steel {

template <
    typename T,
    short SM,
    short SN,
    short SK,
    short BK,
    bool transpose_a,
    bool transpose_b,
    bool kAlignedM,
    bool kAlignedN,
    bool kAlignedK,
    short UM,
    short UN,
    short UK,
    typename AccumType = float>
auto gemm_loop(
    const device T* A,
    const device T* B,
    const constant GEMMParams* params [[buffer(4)]],
    const short sgp_sm,
    const short sgp_sn) {
  constexpr short TM = SM / UM;
  constexpr short TN = SN / UN;
  constexpr short TK = SK / UK;

  constexpr int RA = transpose_a ? TK : TM;
  constexpr int CA = transpose_a ? TM : TK;

  constexpr int RB = transpose_b ? TN : TK;
  constexpr int CB = transpose_b ? TK : TN;

  using DSubTile = NAXSubTile<AccumType, UM, UN>;
  using ASubTile =
      NAXSubTile<T, (transpose_a ? UK : UM), (transpose_a ? UM : UK)>;
  using BSubTile =
      NAXSubTile<T, (transpose_b ? UN : UK), (transpose_b ? UK : UN)>;

  NAXTile<AccumType, TM, TN, DSubTile> Dtile;
  Dtile.clear();

  int gemm_k_iterations_ = params->gemm_k_iterations_aligned;

  STEEL_PRAGMA_NO_UNROLL
  for (int kk0 = 0; kk0 < gemm_k_iterations_; kk0++) {
    threadgroup_barrier(mem_flags::mem_none);

    STEEL_PRAGMA_NO_UNROLL
    for (int kk1 = 0; kk1 < BK; kk1 += SK) {
      NAXTile<T, RA, CA, ASubTile> Atile;
      NAXTile<T, RB, CB, BSubTile> Btile;
      const int k = kk1;

      volatile int compiler_barrier;

      const int A_offset = transpose_a ? k * params->lda : k;
      const int B_offset = transpose_b ? k : k * params->ldb;

      if constexpr (kAlignedM) {
        Atile.load(A + A_offset, params->lda);
      } else {
        const short rmax = transpose_a ? UK : sgp_sm;
        const short cmax = transpose_a ? sgp_sm : UK;
        Atile.load_safe(A + A_offset, params->lda, short2(cmax, rmax));
      }

      if constexpr (kAlignedN) {
        Btile.load(B + B_offset, params->ldb);
      } else {
        const short rmax = transpose_b ? sgp_sn : UK;
        const short cmax = transpose_b ? UK : sgp_sn;
        Btile.load_safe(B + B_offset, params->ldb, short2(cmax, rmax));
      }

      tile_matmad_nax(
          Dtile,
          Atile,
          metal::bool_constant<transpose_a>{},
          Btile,
          metal::bool_constant<transpose_b>{});

      (void)compiler_barrier;
    }

    A += transpose_a ? (BK * params->lda) : BK;
    B += transpose_b ? BK : (BK * params->ldb);
  }

  if constexpr (!kAlignedK) {
    simdgroup_barrier(mem_flags::mem_none);

    const short rem_bk = params->K - gemm_k_iterations_ * BK;

    STEEL_PRAGMA_NO_UNROLL
    for (int kk1 = 0; kk1 < rem_bk; kk1 += SK) {
      NAXTile<T, 1, 1, ASubTile> Atile;
      NAXTile<T, 1, 1, BSubTile> Btile;

      STEEL_PRAGMA_UNROLL
      for (int mm = 0; mm < TM; mm++) {
        STEEL_PRAGMA_UNROLL
        for (int nn = 0; nn < TN; nn++) {
          STEEL_PRAGMA_UNROLL
          for (int kk = 0; kk < TK; kk++) {
            const int m = mm * UM;
            const int n = nn * UN;
            const int k = kk1 + kk * UK;
            const short psk = max(0, rem_bk - k);

            const int A_offset =
                transpose_a ? (m + k * params->lda) : (m * params->lda + k);
            const int B_offset =
                transpose_b ? (k + n * params->ldb) : (k * params->ldb + n);

            {
              const short psm = kAlignedM ? SM : max(0, sgp_sm - m);
              const short rmax = transpose_a ? psk : psm;
              const short cmax = transpose_a ? psm : psk;
              Atile.load_safe(A + A_offset, params->lda, short2(cmax, rmax));
            }

            {
              const short psn = kAlignedN ? SN : max(0, sgp_sn - n);
              const short rmax = transpose_b ? psn : psk;
              const short cmax = transpose_b ? psk : psn;
              Btile.load_safe(B + B_offset, params->ldb, short2(cmax, rmax));
            }

            subtile_matmad_nax(
                Dtile.subtile_at(mm, nn),
                Atile.subtile_at(0, 0),
                metal::bool_constant<transpose_a>{},
                Btile.subtile_at(0, 0),
                metal::bool_constant<transpose_b>{});
          }
        }
      }
    }
  }

  return Dtile;
}

} // namespace mlx::steel
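As a worked example of the tile decomposition, the following standalone C++ sketch plugs in the constants that appear in the fused kernel below together with its 128x128x512, wm=wn=4 instantiation; the names mirror the kernel's constexprs and nothing here is Metal-specific.

// Compile-time arithmetic only: how one simdgroup's tile breaks into NAX
// sub-tiles for the bm128_bn128_bk512_wm4_wn4 configuration.
constexpr short BM = 128, BN = 128, BK = 512;
constexpr short WM = 4, WN = 4;
constexpr short UM = 16, UN = 32, UK = 16, SK = 32;

constexpr short SM = BM / WM; // 32 rows of D per simdgroup
constexpr short SN = BN / WN; // 32 columns of D per simdgroup
constexpr short TM = SM / UM; // 2 sub-tiles along M
constexpr short TN = SN / UN; // 1 sub-tile along N
constexpr short TK = SK / UK; // 2 sub-tiles along K per SK step

static_assert(TM == 2 && TN == 1 && TK == 2, "sub-tile grid for the 128x128 config");
static_assert(BK % SK == 0, "the inner kk1 loop steps through BK in SK-sized chunks"); // 512 / 32 = 16 steps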
@@ -0,0 +1,207 @@
// Copyright © 2025 Apple Inc.

using namespace mlx::steel;

constant bool has_batch [[function_constant(10)]];

constant bool use_out_source [[function_constant(100)]];
constant bool do_axpby [[function_constant(110)]];

constant bool align_M [[function_constant(200)]];
constant bool align_N [[function_constant(201)]];
constant bool align_K [[function_constant(202)]];

// clang-format off
template <
    bool kAlignedM,
    bool kAlignedN,
    typename NAXTile_t,
    typename T>
void gemm_epilogue(
    thread NAXTile_t& Dtile,
    const device T* C,
    const constant GEMMParams* params,
    const constant GEMMAddMMParams* addmm_params,
    const short sgp_sm,
    const short sgp_sn) { // clang-format on

  (void)params;

  constexpr short UM = NAXTile_t::kSubTileRows;
  constexpr short UN = NAXTile_t::kSubTileCols;
  using CSubTile = NAXSubTile<T, UM, UN>;

  using V = typename NAXTile_t::elem_type;

  constexpr short TM = NAXTile_t::kTileRows;
  constexpr short TN = NAXTile_t::kTileCols;
  constexpr short kElemsPerSubTile = NAXTile_t::kElemsPerSubTile;

  STEEL_PRAGMA_UNROLL
  for (short mm = 0; mm < TM; mm++) {
    STEEL_PRAGMA_UNROLL
    for (short nn = 0; nn < TN; nn++) {
      const short m = mm * UM;
      const short n = nn * UN;

      CSubTile CTile;

      if constexpr (kAlignedM && kAlignedN) {
        CTile.load(C, addmm_params->ldc, addmm_params->fdc, m, n);
      } else {
        CTile.load_safe(
            C, addmm_params->ldc, addmm_params->fdc, sgp_sm, sgp_sn, m, n);
      }

      auto delems = Dtile.subtile_at(mm, nn).elems();
      auto celems = CTile.elems();

      STEEL_PRAGMA_UNROLL
      for (short i = 0; i < kElemsPerSubTile; i++) {
        if (do_axpby) {
          delems[i] = addmm_params->alpha * delems[i] +
              addmm_params->beta * static_cast<V>(celems[i]);
        } else {
          delems[i] += static_cast<V>(celems[i]);
        }
      }
    }
  }
}

// clang-format off
template <
    typename T,
    int BM,
    int BN,
    int BK,
    int WM,
    int WN,
    bool transpose_a,
    bool transpose_b,
    typename AccumType = float>
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void gemm(
    const device T* A [[buffer(0)]],
    const device T* B [[buffer(1)]],
    const device T* C [[buffer(2), function_constant(use_out_source)]],
    device T* D [[buffer(3)]],
    const constant GEMMParams* params [[buffer(4)]],
    const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]],
    const constant int* batch_shape [[buffer(6), function_constant(has_batch)]],
    const constant int64_t* batch_strides [[buffer(7), function_constant(has_batch)]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]],
    uint3 tid [[threadgroup_position_in_grid]]) { // clang-format on
  // Find block
  const int tid_y = ((tid.y) << params->swizzle_log) +
      ((tid.x) & ((1 << params->swizzle_log) - 1));
  const int tid_x = (tid.x) >> params->swizzle_log;

  // Exit early if out of bounds
  if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
    return;
  }

  // Adjust for batch
  if (has_batch) {
    const constant auto* A_bstrides = batch_strides;
    const constant auto* B_bstrides = batch_strides + params->batch_ndim;

    ulong2 batch_offsets = elem_to_loc_broadcast(
        tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim);

    A += batch_offsets.x;
    B += batch_offsets.y;

    if (use_out_source) {
      const constant auto* C_bstrides = B_bstrides + params->batch_ndim;
      C += elem_to_loc(tid.z, batch_shape, C_bstrides, params->batch_ndim);
    }
  } else {
    A += params->batch_stride_a * tid.z;
    B += params->batch_stride_b * tid.z;

    if (use_out_source) {
      C += addmm_params->batch_stride_c * tid.z;
    }
  }

  D += params->batch_stride_d * tid.z;

  // Prepare threadgroup memory
  threadgroup_barrier(mem_flags::mem_none);

  // Find block in A, B, C
  const int c_row = tid_y * BM;
  const int c_col = tid_x * BN;
  const size_t c_row_long = size_t(c_row);
  const size_t c_col_long = size_t(c_col);

  A += transpose_a ? c_row_long : c_row_long * params->lda;
  B += transpose_b ? c_col_long * params->ldb : c_col_long;
  D += c_row_long * params->ldd + c_col_long;

  if (use_out_source) {
    C += c_row_long * addmm_params->ldc + c_col_long * addmm_params->fdc;
  }

  constexpr short UM = 16;
  constexpr short UN = 32;
  constexpr short UK = 16;
  constexpr short SM = BM / WM;
  constexpr short SN = BN / WN;
  constexpr short SK = 32;

  constexpr short TM = SM / UM;
  constexpr short TN = SN / UN;

  const short tm = SM * (simd_group_id / WN);
  const short tn = SN * (simd_group_id % WN);

  const short sgp_sm = align_M ? SM : min(SM, short(params->M - (c_row + tm)));
  const bool is_unaligned_sm = align_M ? false : (sgp_sm != SM);

  const short sgp_sn = align_N ? SN : min(SN, short(params->N - (c_col + tn)));
  const bool is_unaligned_sn = align_N ? false : (sgp_sn != SN);

  A += transpose_a ? tm : (tm * params->lda);
  B += transpose_b ? (tn * params->ldb) : tn;
  D += tm * params->ldd + tn;

  if (use_out_source) {
    C += tm * addmm_params->ldc + tn * addmm_params->fdc;
  }

  using DSubTile = NAXSubTile<AccumType, UM, UN>;
  NAXTile<AccumType, TM, TN, DSubTile> Dtile;

  dispatch_bool(align_K, [&](auto kAlignedK) {
    dispatch_bool(align_M || !is_unaligned_sm, [&](auto kAlignedM) {
      dispatch_bool(align_N || !is_unaligned_sn, [&](auto kAlignedN) {
        Dtile = gemm_loop<
            T,
            SM,
            SN,
            SK,
            BK,
            transpose_a,
            transpose_b,
            kAlignedM.value,
            kAlignedN.value,
            kAlignedK.value,
            UM,
            UN,
            UK,
            AccumType>(A, B, params, sgp_sm, sgp_sn);
        if (use_out_source) {
          gemm_epilogue<kAlignedM.value, kAlignedN.value>(
              Dtile, C, params, addmm_params, sgp_sm, sgp_sn);
        }
        if constexpr (kAlignedM && kAlignedN) {
          Dtile.store(D, int(params->ldd));
        } else {
          Dtile.store_safe(D, int(params->ldd), short2(sgp_sn, sgp_sm));
        }
      });
    });
  });
}
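Note on launch geometry: max_total_threads_per_threadgroup(WM * WN * 32) declares one 32-thread simdgroup per (wm, wn) slot, which matches the MTL::Size(32, wn, wm) threadgroup used by the host dispatch later in this diff (512 threads for the 4x4 configuration).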
@@ -0,0 +1,35 @@
// Copyright © 2025 Apple Inc.

#include <metal_stdlib>

#include "mlx/backend/metal/kernels/utils.h"

#include "mlx/backend/metal/kernels/steel/gemm/gemm_nax.h"
#include "mlx/backend/metal/kernels/steel/gemm/nax.h"
#include "mlx/backend/metal/kernels/steel/gemm/params.h"
#include "mlx/backend/metal/kernels/steel/gemm/transforms.h"
#include "mlx/backend/metal/kernels/steel/utils.h"

#include "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h"

// clang-format off
#define instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
  instantiate_kernel(                                                                             \
      "steel_gemm_fused_nax_" #tname "_" #iname "_" #oname                                        \
      "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn,                                          \
      gemm, itype, bm, bn, bk, wm, wn, trans_a, trans_b, float)

#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
  instantiate_gemm(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn)      \
  instantiate_gemm(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn)      \
  instantiate_gemm(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn)      \
  instantiate_gemm(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn)

#define instantiate_gemm_shapes_helper(iname, itype, oname, otype)                  \
  instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 256, 2, 2)  \
  instantiate_gemm_transpose_helper(iname, itype, oname, otype, 128, 128, 512, 4, 4)

instantiate_gemm_shapes_helper(float16, half, float16, half);
instantiate_gemm_shapes_helper(bfloat16, bfloat, bfloat16, bfloat);
instantiate_gemm_shapes_helper(float32, float, float32, float);
// clang-format on
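For reference, the first instantiation above expands to the kernel name steel_gemm_fused_nax_nn_float16_float16_bm64_bn64_bk256_wm2_wn2; the host-side dispatch added later in this diff requests the bm128_bn128_bk512_wm4_wn4 variants of these names.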
@@ -0,0 +1,132 @@
// Copyright © 2024 Apple Inc.

using namespace mlx::steel;

constant bool align_M [[function_constant(200)]];
constant bool align_N [[function_constant(201)]];
constant bool align_K [[function_constant(202)]];

template <
    typename T,
    int BM,
    int BN,
    int BK,
    int WM,
    int WN,
    bool transpose_a,
    bool transpose_b,
    typename AccumType = float>
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void
gather_mm_rhs_nax(
    const device T* A [[buffer(0)]],
    const device T* B [[buffer(1)]],
    const device uint32_t* rhs_indices [[buffer(2)]],
    device T* C [[buffer(3)]],
    const constant GEMMParams* params [[buffer(4)]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]],
    uint3 tid [[threadgroup_position_in_grid]]) {
  constexpr short UM = 16;
  constexpr short UN = 32;
  constexpr short UK = 16;
  constexpr short SM = BM / WM;
  constexpr short SN = BN / WN;
  constexpr short SK = 32;
  constexpr short TM = SM / UM;
  constexpr short TN = SN / UN;

  if (params->tiles_n <= static_cast<int>(tid.x) ||
      params->tiles_m <= static_cast<int>(tid.y)) {
    return;
  }

  // Find the block in A, B, C
  const int c_row = tid.y * BM;
  const int c_col = tid.x * BN;
  const size_t c_row_long = size_t(c_row);
  const size_t c_col_long = size_t(c_col);

  A += transpose_a ? c_row_long : c_row_long * params->lda;
  B += transpose_b ? c_col_long * params->ldb : c_col_long;
  C += c_row_long * params->ldd + c_col_long;
  rhs_indices += c_row;

  const short tm = SM * (simd_group_id / WN);
  const short tn = SN * (simd_group_id % WN);

  const short sgp_sm = align_M ? SM : min(SM, short(params->M - (c_row + tm)));
  const bool is_unaligned_sm = align_M ? false : (sgp_sm != SM);

  const short sgp_sn = align_N ? SN : min(SN, short(params->N - (c_col + tn)));
  const bool is_unaligned_sn = align_N ? false : (sgp_sn != SN);

  A += transpose_a ? tm : (tm * params->lda);
  B += transpose_b ? (tn * params->ldb) : tn;
  C += tm * params->ldd + tn;
  rhs_indices += tm;

  // Do as many matmuls as necessary
  uint32_t index;
  short offset;
  uint32_t index_next = rhs_indices[0];
  short offset_next = 0;
  int n = 0;
  while (n < sgp_sm) {
    n++;
    offset = offset_next;
    index = index_next;
    offset_next = sgp_sm;
    for (; n < sgp_sm; n++) {
      if (rhs_indices[n] != index) {
        offset_next = n;
        index_next = rhs_indices[n];
        break;
      }
    }
    threadgroup_barrier(mem_flags::mem_none);

    using DSubTile = NAXSubTile<AccumType, UM, UN>;
    NAXTile<AccumType, TM, TN, DSubTile> Ctile;

    dispatch_bool(align_K, [&](auto kAlignedK) {
      dispatch_bool(align_M || !is_unaligned_sm, [&](auto kAlignedM) {
        dispatch_bool(align_N || !is_unaligned_sn, [&](auto kAlignedN) {
          auto do_gemm = gemm_loop<
              T,
              SM,
              SN,
              SK,
              BK,
              transpose_a,
              transpose_b,
              kAlignedM.value,
              kAlignedN.value,
              kAlignedK.value,
              UM,
              UN,
              UK,
              AccumType>;
          Ctile = do_gemm(
              A, B + index * params->batch_stride_b, params, sgp_sm, sgp_sn);

          if constexpr (kAlignedN.value) {
            if (offset_next - offset == SM) {
              Ctile.store(C, int(params->ldd));
            } else {
              Ctile.store_slice(
                  C,
                  int(params->ldd),
                  short2(0, offset),
                  short2(SN, offset_next));
            }
          } else {
            Ctile.store_slice(
                C,
                int(params->ldd),
                short2(0, offset),
                short2(sgp_sn, offset_next));
          }
        });
      });
    });
  }
}
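The while loop above walks the sorted rhs_indices in runs of equal values, issuing one gemm_loop call per run and storing only the rows of that run. A standalone host-side C++ illustration of the same run-splitting logic (std::vector in place of the device pointer; the printout is purely illustrative):

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  // Sorted expert indices, one per output row, as gather_mm_rhs_nax assumes.
  std::vector<uint32_t> rhs_indices = {3, 3, 3, 7, 7, 9};
  int sgp_sm = static_cast<int>(rhs_indices.size());

  uint32_t index;
  int offset;
  uint32_t index_next = rhs_indices[0];
  int offset_next = 0;
  int n = 0;
  while (n < sgp_sm) {
    n++;
    offset = offset_next;
    index = index_next;
    offset_next = sgp_sm;
    for (; n < sgp_sm; n++) {
      if (rhs_indices[n] != index) {
        offset_next = n;
        index_next = rhs_indices[n];
        break;
      }
    }
    // Rows [offset, offset_next) all multiply against B matrix `index`.
    std::printf("rows [%d, %d) -> B[%u]\n", offset, offset_next, index);
  }
  return 0;
}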
@@ -0,0 +1,39 @@
// Copyright © 2024 Apple Inc.

#include <metal_stdlib>

#include "mlx/backend/metal/kernels/steel/gemm/gemm_nax.h"
#include "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h"
#include "mlx/backend/metal/kernels/steel/gemm/nax.h"
#include "mlx/backend/metal/kernels/steel/gemm/params.h"
#include "mlx/backend/metal/kernels/steel/utils.h"
#include "mlx/backend/metal/kernels/utils.h"

// clang-format off
#define instantiate_gather_mm_rhs(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
  instantiate_kernel(                                                                                      \
      "steel_gather_mm_rhs_nax_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn                          \
      "_bk" #bk "_wm" #wm "_wn" #wn,                                                                       \
      gather_mm_rhs_nax,                                                                                   \
      itype,                                                                                               \
      bm,                                                                                                  \
      bn,                                                                                                  \
      bk,                                                                                                  \
      wm,                                                                                                  \
      wn,                                                                                                  \
      trans_a,                                                                                             \
      trans_b,                                                                                             \
      float)

#define instantiate_gather_mm_rhs_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
  instantiate_gather_mm_rhs(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn)      \
  instantiate_gather_mm_rhs(nt, false, true, iname, itype, oname, otype, bm, bn, bk, wm, wn)

#define instantiate_gather_mm_shapes_helper(iname, itype, oname, otype)                      \
  instantiate_gather_mm_rhs_transpose_helper(iname, itype, oname, otype, 16, 128, 128, 1, 4) \
  instantiate_gather_mm_rhs_transpose_helper(iname, itype, oname, otype, 32, 128, 128, 1, 4) \
  instantiate_gather_mm_rhs_transpose_helper(iname, itype, oname, otype, 64, 128, 128, 2, 4)
// clang-format on

instantiate_gather_mm_shapes_helper(float16, half, float16, half);
instantiate_gather_mm_shapes_helper(bfloat16, bfloat, bfloat16, bfloat);
mlx/backend/metal/kernels/steel/gemm/nax.h (new file, 1087 lines) — diff suppressed because it is too large
@@ -74,6 +74,44 @@ integral_const_binop(>=, operator>=);
integral_const_binop(&&, operator&&);
integral_const_binop(||, operator||);

template <typename T, typename = metal::enable_if_t<!is_integral_v<T>>>
METAL_FUNC constexpr auto operator||(true_type, T) {
  return true_type{};
}
template <typename T, typename = metal::enable_if_t<!is_integral_v<T>>>
METAL_FUNC constexpr auto operator||(T, true_type) {
  return true_type{};
}

template <typename T, typename = metal::enable_if_t<!is_integral_v<T>>>
METAL_FUNC constexpr auto operator&&(false_type, T) {
  return false_type{};
}

template <typename T, typename = metal::enable_if_t<!is_integral_v<T>>>
METAL_FUNC constexpr auto operator&&(T, false_type) {
  return false_type{};
}

// Dispatch utilities
template <typename F>
void dispatch_bool(bool v, F f) {
  if (v) {
    f(true_type{});
  } else {
    f(false_type{});
  }
}

template <int start, int stop, int step, typename F>
constexpr void const_for_loop(F f) {
  if constexpr (start < stop) {
    constexpr auto idx = Int<start>{};
    f(idx);
    const_for_loop<start + step, stop, step, F>(f);
  }
}

#undef integral_const_binop

///////////////////////////////////////////////////////////////////////////////
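dispatch_bool is what lets the NAX kernels turn the runtime align_M/align_N/align_K function constants into compile-time template arguments. A standalone C++ analogue (using std::true_type / std::false_type from <type_traits> instead of the library's own integral-constant types):

#include <cstdio>
#include <type_traits>

// A runtime bool is converted into a compile-time true_type / false_type so
// the callee can branch with `if constexpr`, exactly as the kernels above do.
template <typename F>
void dispatch_bool(bool v, F f) {
  if (v) {
    f(std::true_type{});
  } else {
    f(std::false_type{});
  }
}

int main() {
  for (bool aligned : {true, false}) {
    dispatch_bool(aligned, [](auto kAligned) {
      if constexpr (kAligned.value) {
        std::printf("compile-time aligned path\n");
      } else {
        std::printf("compile-time edge path\n");
      }
    });
  }
  return 0;
}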
@@ -172,6 +172,165 @@ ensure_batch_contiguous(const array& x, metal::Device& d, const Stream& s) {
// Regular steel matmul dispatch
///////////////////////////////////////////////////////////////////////////////

#ifdef MLX_ENABLE_NAX

template <bool CHECK_AB>
void steel_matmul_regular_axpby_nax(
    const Stream& s,
    metal::Device& d,
    const array& a,
    const array& b,
    const array& c,
    array& out,
    int M,
    int N,
    int K,
    int batch_size_out,
    int lda,
    int ldb,
    int ldd,
    bool transpose_a,
    bool transpose_b,
    std::vector<array>& copies,
    Shape batch_shape,
    Strides batch_strides,
    int64_t A_batch_stride,
    int64_t B_batch_stride,
    int64_t matrix_stride_out,
    int64_t C_batch_stride /* = 0*/,
    float alpha /* = 1.0f */,
    float beta /* = 0.0f */) {
  using namespace mlx::steel;

  // Determine dispatch kernel
  int bm = 128, bn = 128, bk = 512;
  int wm = 4, wn = 4;

  // Prepare kernel name
  std::ostringstream kname;

  // clang-format off
  kname << "steel_gemm_fused_nax_"
        << (transpose_a ? 't' : 'n')
        << (transpose_b ? 't' : 'n')
        << "_" << type_to_name(a)
        << "_" << type_to_name(out)
        << "_bm" << bm << "_bn" << bn << "_bk" << bk
        << "_wm" << wm << "_wn" << wn; // clang-format on

  std::string base_name = kname.str();

  const bool has_batch = (batch_shape.size() > 1);
  const bool use_out_source = CHECK_AB && (alpha != 0.0f || beta != 1.0f);
  const bool do_axpby = use_out_source && (alpha != 1.0f || beta != 1.0f);
  const bool align_M = (M % bm) == 0;
  const bool align_N = (N % bn) == 0;
  const bool align_K = (K % bk) == 0;

  metal::MTLFCList func_consts = {
      {&has_batch, MTL::DataType::DataTypeBool, 10},
      {&use_out_source, MTL::DataType::DataTypeBool, 100},
      {&do_axpby, MTL::DataType::DataTypeBool, 110},
      {&align_M, MTL::DataType::DataTypeBool, 200},
      {&align_N, MTL::DataType::DataTypeBool, 201},
      {&align_K, MTL::DataType::DataTypeBool, 202},
  };

  // clang-format off
  kname << "_has_batch_" << (has_batch ? 't' : 'n')
        << "_use_out_source_" << (use_out_source ? 't' : 'n')
        << "_do_axpby_" << (do_axpby ? 't' : 'n')
        << "_align_M_" << (align_M ? 't' : 'n')
        << "_align_N_" << (align_N ? 't' : 'n')
        << "_align_K_" << (align_K ? 't' : 'n'); // clang-format on

  std::string hash_name = kname.str();

  // Encode and dispatch kernel
  auto& compute_encoder = d.get_command_encoder(s.index);
  auto kernel = get_steel_gemm_fused_kernel(
      /* metal::Device& d = */ d,
      /* const std::string& kernel_name = */ base_name,
      /* const std::string& hash_name = */ hash_name,
      /* const metal::MTLFCList& func_consts = */ func_consts,
      /* const array& out = */ out,
      /* bool transpose_a = */ transpose_a,
      /* bool transpose_b = */ transpose_b,
      /* int bm = */ bm,
      /* int bn = */ bn,
      /* int bk = */ bk,
      /* int wm = */ wm,
      /* int wn = */ wn);

  compute_encoder.set_compute_pipeline_state(kernel);

  // Use problem size to determine threadblock swizzle
  int tn = (N + bn - 1) / bn;
  int tm = (M + bm - 1) / bm;

  // TODO: Explore device-based tuning for swizzle
  int swizzle_log = tm <= 3 ? 0 : 1;

  // Prepare steel matmul params
  GEMMParams params{
      /* const int M = */ M,
      /* const int N = */ N,
      /* const int K = */ K,
      /* const int lda = */ lda,
      /* const int ldb = */ ldb,
      /* const int ldd = */ ldd,
      /* const int tiles_n = */ tn,
      /* const int tiles_m = */ tm,
      /* const int64_t batch_stride_a = */ A_batch_stride,
      /* const int64_t batch_stride_b = */ B_batch_stride,
      /* const int64_t batch_stride_d = */ matrix_stride_out,
      /* const int swizzle_log = */ swizzle_log,
      /* const int gemm_k_iterations_aligned = */ (K / bk),
      /* const int batch_ndim = */ int(batch_shape.size())};

  // Prepare launch grid params
  int tile = 1 << swizzle_log;
  tm = (tm + tile - 1) / tile;
  tn = tn * tile;

  MTL::Size group_dims = MTL::Size(32, wn, wm);
  MTL::Size grid_dims = MTL::Size(tn, tm, batch_size_out);

  // Launch kernel
  compute_encoder.set_input_array(a, 0);
  compute_encoder.set_input_array(b, 1);
  compute_encoder.set_output_array(out, 3);

  compute_encoder.set_bytes(params, 4);

  if (has_batch) {
    compute_encoder.set_vector_bytes(batch_shape, 6);
    compute_encoder.set_vector_bytes(batch_strides, 7);
  }

  if (use_out_source) {
    int ldc = c.strides()[c.ndim() - 2];
    int fdc = c.strides()[c.ndim() - 1];

    GEMMAddMMParams params{
        /* const int ldc = */ ldc,
        /* const int fdc = */ fdc,
        /* const int64_t batch_stride_c = */ C_batch_stride,
        /* const float alpha = */ alpha,
        /* const float beta = */ beta};

    compute_encoder.set_input_array(c, 2);
    compute_encoder.set_bytes(params, 5);
  }

  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);

  // Record copies
  d.add_temporaries(std::move(copies), s.index);
}

#endif // MLX_ENABLE_NAX

template <bool CHECK_AB>
void steel_matmul_regular_axpby(
    const Stream& s,
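On the launch-grid swizzle: with swizzle_log = 1 the host packs two logical column tiles per hardware x index (tn doubles, tm rounds up to half), and the kernel in steel_gemm_fused_nax.h inverts the mapping via tid_y = (tid.y << swizzle_log) + (tid.x & ((1 << swizzle_log) - 1)) and tid_x = tid.x >> swizzle_log, presumably to improve reuse between neighbouring threadgroups; with three or fewer row tiles the swizzle is disabled.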
@@ -198,6 +357,39 @@ void steel_matmul_regular_axpby(
    int64_t C_batch_stride /* = 0*/,
    float alpha /* = 1.0f */,
    float beta /* = 0.0f */) {
#ifdef MLX_ENABLE_NAX

  if (metal::is_nax_available() &&
      (a.dtype() != float32 || env::enable_tf32())) {
    return steel_matmul_regular_axpby_nax<CHECK_AB>(
        /* const Stream& s = */ s,
        /* metal::Device& d = */ d,
        /* const array& a = */ a,
        /* const array& b = */ b,
        /* const array& c = */ c,
        /* array& out = */ out,
        /* int M = */ M,
        /* int N = */ N,
        /* int K = */ K,
        /* int batch_size_out = */ batch_size_out,
        /* int lda = */ lda,
        /* int ldb = */ ldb,
        /* int ldd = */ ldd,
        /* bool transpose_a = */ transpose_a,
        /* bool transpose_b = */ transpose_b,
        /* std::vector<array>& copies = */ copies,
        /* Shape batch_shape = */ batch_shape,
        /* Strides batch_strides = */ batch_strides,
        /* int64_t A_batch_stride = */ A_batch_stride,
        /* int64_t B_batch_stride = */ B_batch_stride,
        /* int64_t matrix_stride_out = */ matrix_stride_out,
        /* int64_t C_batch_stride = */ C_batch_stride,
        /* float alpha = */ alpha,
        /* float beta = */ beta);
  }

#endif // MLX_ENABLE_NAX

  using namespace mlx::steel;

  // Determine dispatch kernel
@@ -1572,6 +1764,153 @@ void gather_mm_rhs(
  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}

#ifdef MLX_ENABLE_NAX

void gather_mm_rhs_nax(
    const array& a_,
    const array& b_,
    const array& indices_,
    array& out,
    metal::Device& d,
    const Stream& s) {
  array indices = ensure_row_contiguous(indices_, d, s);
  auto [transpose_b, ldb, b] = ensure_batch_contiguous(b_, d, s);

  // Broadcast a with indices. If we are here that means lhs_indices were not
  // provided so the lhs_indices are implied to be the shape of a broadcasted
  // with rhs_indices. We need only broadcast a and copy it as if applying the
  // lhs_indices.
  auto broadcast_with_indices = [&d, &s, &indices](const array& x) {
    if (x.size() / x.shape(-2) / x.shape(-1) == indices.size()) {
      return ensure_row_contiguous(x, d, s);
    }

    auto x_shape = indices.shape();
    x_shape.push_back(x.shape(-2));
    x_shape.push_back(x.shape(-1));
    array new_x(std::move(x_shape), x.dtype(), nullptr, {});
    broadcast(x, new_x);
    return ensure_row_contiguous(new_x, d, s);
  };
  array a = broadcast_with_indices(a_);

  // Extract the matmul shapes
  int K = a.shape(-1);
  int M = a.size() / K;
  int N = b.shape(-1);
  int lda = a.strides()[a.ndim() - 2]; // should be K
  int E = b.shape(0);

  // Define the dispatch blocks
  int bm, bn = 128, bk = 128, wm, wn = 4;
  if (M / E > 48) {
    bm = 64;
    wm = 2;
  } else if (M / E > 24) {
    bm = 32;
    wm = 1;
  } else {
    bm = 16;
    wm = 1;
  }

  const bool align_M = (M % bm) == 0;
  const bool align_N = (N % bn) == 0;
  const bool align_K = (K % bk) == 0;

  // Define the kernel name
  std::string base_name;
  base_name.reserve(64);
  concatenate(
      base_name,
      "steel_gather_mm_rhs_mxu_n",
      transpose_b ? 't' : 'n',
      '_',
      type_to_name(a),
      '_',
      type_to_name(out),
      "_bm",
      bm,
      "_bn",
      bn,
      "_bk",
      bk,
      "_wm",
      wm,
      "_wn",
      wn);

  metal::MTLFCList func_consts = {
      {&align_M, MTL::DataType::DataTypeBool, 200},
      {&align_N, MTL::DataType::DataTypeBool, 201},
      {&align_K, MTL::DataType::DataTypeBool, 202},
  };

  // And the kernel hash that includes the function constants
  std::string hash_name;
  hash_name.reserve(128);
  concatenate(
      hash_name,
      base_name,
      "_align_M_",
      align_M ? 't' : 'n',
      "_align_N_",
      align_N ? 't' : 'n',
      "_align_K_",
      align_K ? 't' : 'n');

  // Get and set the kernel
  auto& compute_encoder = d.get_command_encoder(s.index);
  auto kernel = get_steel_gemm_gather_kernel(
      d,
      base_name,
      hash_name,
      func_consts,
      out,
      false,
      transpose_b,
      bm,
      bn,
      bk,
      wm,
      wn,
      true);
  compute_encoder.set_compute_pipeline_state(kernel);

  // Prepare the matmul params
  auto batch_stride_b = b.ndim() > 2 ? b.strides()[b.ndim() - 3] : b.size();
  steel::GEMMParams params{
      /* const int M = */ M,
      /* const int N = */ N,
      /* const int K = */ K,
      /* const int lda = */ lda,
      /* const int ldb = */ static_cast<int>(ldb),
      /* const int ldd = */ N,
      /* const int tiles_n = */ (N + bn - 1) / bn,
      /* const int tiles_m = */ (M + bm - 1) / bm,
      /* const int64_t batch_stride_a = */ 0,
      /* const int64_t batch_stride_b = */ static_cast<int64_t>(batch_stride_b),
      /* const int64_t batch_stride_d = */ 0,
      /* const int swizzle_log = */ 0,
      /* const int gemm_k_iterations_aligned = */ (K / bk),
      /* const int batch_ndim = */ 0};

  // Prepare the grid
  MTL::Size group_dims = MTL::Size(32, wn, wm);
  MTL::Size grid_dims = MTL::Size(params.tiles_n, params.tiles_m, 1);

  // Launch kernel
  compute_encoder.set_input_array(a, 0);
  compute_encoder.set_input_array(b, 1);
  compute_encoder.set_input_array(indices, 2);
  compute_encoder.set_output_array(out, 3);
  compute_encoder.set_bytes(params, 4);

  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}

#endif // MLX_ENABLE_NAX

void gather_mv(
    const array& mat_,
    const array& vec_,
@@ -1855,6 +2194,14 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
  // We are walking a in order and b is also in order so we can batch up the
  // matmuls and reuse reading a and b.
  if (M == 1 && right_sorted_ == true) {
#ifdef MLX_ENABLE_NAX

    if (metal::is_nax_available() && a.dtype() != float32) {
      return gather_mm_rhs_nax(a, b, rhs_indices, out, d, s);
    }

#endif // MLX_ENABLE_NAX

    gather_mm_rhs(a, b, rhs_indices, out, d, s);
    return;
  }