Mirror of https://github.com/ml-explore/mlx.git (synced 2025-09-04 15:04:40 +08:00)

Compare commits: spda_sinks...simple-gem (8 commits)
Commits (SHA1):
- 4987e7615a
- e1303f6160
- cf5eef095d
- 395d582719
- 05583bcd10
- 6fce01593a
- 97afe40b7b
- f70c62d69c
@@ -26,6 +26,8 @@ target_sources(
  ${CMAKE_CURRENT_SOURCE_DIR}/event.cu
  ${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/gemms/gemv.cu
  ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cutlass_gemm.cu
  ${CMAKE_CURRENT_SOURCE_DIR}/gemms/simple_gemm.cu
  ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/jit_module.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
@@ -88,6 +90,9 @@ target_include_directories(mlx PRIVATE "${CMAKE_CURRENT_BINARY_DIR}/gen")
target_compile_options(mlx
                       PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--extended-lambda>")

# Keep ptx around for inspection
target_compile_options(mlx PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--keep>")

# Enable calling host constexpr functions from device. This is needed because
# the constexpr version of isnan is host only.
target_compile_options(
@@ -173,3 +178,12 @@ target_compile_options(mlx PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xcudafe
# Install CCCL headers for JIT.
install(DIRECTORY ${cccl_SOURCE_DIR}/include/cuda
        DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/cccl)

# Fetch and make available cutlass
FetchContent_Declare(
  cutlass
  GIT_REPOSITORY https://github.com/NVIDIA/cutlass.git
  GIT_TAG v4.1.0)
FetchContent_Populate(cutlass)
target_include_directories(
  mlx PRIVATE $<BUILD_INTERFACE:${cutlass_SOURCE_DIR}/include>)
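Editor's note on the "host constexpr from device" comment in the hunk above: NVCC normally rejects a call from device code to a function that is only declared for the host, and the option referenced here relaxes that for constexpr host functions (the motivating case in the comment is the host-only constexpr isnan). A minimal, hypothetical sketch of the kind of call this enables; host_scale and scale_kernel are illustrative names, not part of the diff:

// Editor's sketch, not part of the commit. host_scale() is a plain host
// constexpr function; calling it from the kernel below is only accepted when
// the relaxed-constexpr style option mentioned in the CMake comment is on.
constexpr float host_scale() {
  return 0.5f;
}

__global__ void scale_kernel(const float* in, float* out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    out[i] = host_scale() * in[i];
  }
}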
mlx/backend/cuda/gemms/cutlass_gemm.cu (new file, 396 lines)
@@ -0,0 +1,396 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/dtype_utils.h"

#include <cute/tensor.hpp>
#include <cutlass/arch/arch.h>
#include <cutlass/cutlass.h>
#include <cutlass/gemm/device/gemm.h>
#include <cutlass/layout/matrix.h>
#include <cutlass/numeric_types.h>

#include <iostream>

namespace mlx::core::cu {

namespace {

using namespace cute;
using bf16 = cute::bfloat16_t;

template <typename Kernel>
void configure_matmul(Kernel kernel, int smem_size) {
  static bool initialized = false;
  if (!initialized) {
    initialized = true;
    cudaFuncSetAttribute(
        kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
  }
}

template <bool transpose, typename Tiler>
constexpr int get_feature_size(Tiler smem) {
  int feature_size = (transpose) ? size<0>(smem) : size<1>(smem);
  return (feature_size >= 64) ? 64 : feature_size;
}

constexpr int constexpr_log2(int x) {
  return (x > 0) ? 1 + constexpr_log2(x >> 1) : -1;
}

template <int feature_size, int itemsize, int copy_bits>
constexpr int get_swizzle_bits() {
  constexpr int swizzle_bits =
      constexpr_log2(feature_size * itemsize / copy_bits);
  return (swizzle_bits > 3) ? 3 : swizzle_bits;
}

template <int itemsize, bool transpose, int copy_bits, typename Tiler>
constexpr auto make_smem_layout(Tiler smem) {
  constexpr int feature_size = get_feature_size<transpose>(smem);
  constexpr int swizzle_bits =
      get_swizzle_bits<feature_size, itemsize, copy_bits>();

  using F = Int<feature_size>;
  using BaseLayout = std::conditional_t<
      transpose,
      Layout<cute::Shape<F, _8>, cute::Stride<_1, F>>,
      Layout<cute::Shape<_8, F>, cute::Stride<F, _1>>>;

  auto swizzled =
      make_composed_layout(Swizzle<swizzle_bits, 3, 3>{}, 0, BaseLayout{});

  return tile_to_shape(swizzled, smem);
}

template <int itemsize, bool transpose, int copy_bits, typename Tiler>
constexpr auto make_result_smem_layout(Tiler smem) {
  constexpr int feature_size = get_feature_size<transpose>(smem);
  constexpr int swizzle_bits =
      get_swizzle_bits<feature_size, itemsize, copy_bits>();

  using F = Int<feature_size>;
  using BaseLayout = std::conditional_t<
      transpose,
      Layout<cute::Shape<F, _8>, cute::Stride<_1, F>>,
      Layout<cute::Shape<_8, F>, cute::Stride<F, _1>>>;

  auto swizzled = make_composed_layout(
      Swizzle<transpose ? 0 : swizzle_bits, 3, 4>{}, 0, BaseLayout{});

  return tile_to_shape(swizzled, smem);
}

template <
    int num_threads,
    int itemsize,
    bool transpose,
    int copy_bits,
    typename Copier,
    typename Tiler>
constexpr auto make_tiled_copy(Copier copy_op, Tiler smem) {
  constexpr int num_elements = copy_bits / itemsize;
  constexpr int feature_size = transpose ? size<0>(smem) : size<1>(smem);
  constexpr int copies_per_feature = feature_size / num_elements;

  using E = Int<num_elements>;
  using C = Int<copies_per_feature>;
  using R = Int<num_threads / copies_per_feature>;

  using ThreadLayout = std::conditional_t<
      transpose,
      Layout<cute::Shape<C, R>, cute::Stride<_1, C>>,
      Layout<cute::Shape<R, C>, cute::Stride<C, _1>>>;
  using ValueLayout = std::conditional_t<
      transpose,
      Layout<cute::Shape<E, _1>>,
      Layout<cute::Shape<_1, E>>>;

  return make_tiled_copy(copy_op, ThreadLayout{}, ValueLayout{});
}

template <int rasterization_factor>
__device__ inline int2 raster_tile(int x, int y) {
  return {
      x / rasterization_factor,
      (x % rasterization_factor) + y * rasterization_factor};
}

template <
    typename T,
    typename SLayoutA,
    typename SLayoutB,
    typename SLayoutC,
    typename CopyA,
    typename CopyB,
    typename CopyC,
    typename MMA,
    int rasterization_factor>
__global__ static __launch_bounds__(decltype(size(MMA{}))::value) void matmul_kernel(
    const T* __restrict__ A,
    const T* __restrict__ B,
    T* __restrict__ C,
    SLayoutA SA,
    SLayoutB SB,
    SLayoutC SC,
    CopyA copy_a,
    CopyB copy_b,
    CopyC copy_c,
    MMA mma,
    int M,
    int N,
    int K) {
  constexpr auto BM = size<0>(SA);
  constexpr auto BN = size<0>(SB);
  constexpr auto BK = size<1>(SA);
  constexpr auto PIPE = size<2>(SA);

  const int2 tile = raster_tile<rasterization_factor>(blockIdx.x, blockIdx.y);
  const int blocks_m = ceil_div(M, BM);
  const int blocks_n = ceil_div(N, BN);

  // Exit early if the tile is OOB
  if (tile.x >= blocks_m || tile.y >= blocks_n) {
    return;
  }

  // Make the full tensors
  Tensor full_A =
      make_tensor(make_gmem_ptr(A), make_shape(M, K), make_stride(K, _1{}));
  Tensor full_B =
      make_tensor(make_gmem_ptr(B), make_shape(N, K), make_stride(K, _1{}));
  Tensor full_C =
      make_tensor(make_gmem_ptr(C), make_shape(M, N), make_stride(N, _1{}));

  // Partition the tensors into tiles and select the ones for this threadblock
  Tensor local_A =
      local_tile(full_A, make_shape(BM, BK), make_coord(tile.x, _));
  Tensor local_B =
      local_tile(full_B, make_shape(BN, BK), make_coord(tile.y, _));
  Tensor local_C =
      local_tile(full_C, make_shape(BM, BN), make_coord(tile.x, tile.y));

  // Make shared memory tensors
  extern __shared__ char shared_memory[];
  T* shared_A_ptr = reinterpret_cast<T*>(shared_memory);
  T* shared_B_ptr =
      reinterpret_cast<T*>(shared_memory + cosize(SA) * sizeof(T));
  T* shared_C_ptr = reinterpret_cast<T*>(shared_memory);
  Tensor shared_A = make_tensor(make_smem_ptr(shared_A_ptr), SA);
  Tensor shared_B = make_tensor(make_smem_ptr(shared_B_ptr), SB);
  Tensor shared_C = make_tensor(make_smem_ptr(shared_C_ptr), SC);

  // Get the copies that correspond to this thread
  auto thread_copy_a = copy_a.get_slice(threadIdx.x);
  Tensor local_A_src = thread_copy_a.partition_S(local_A);
  Tensor local_A_dst = thread_copy_a.partition_D(shared_A);
  auto thread_copy_b = copy_b.get_slice(threadIdx.x);
  Tensor local_B_src = thread_copy_b.partition_S(local_B);
  Tensor local_B_dst = thread_copy_b.partition_D(shared_B);
  auto thread_copy_c = copy_c.get_slice(threadIdx.x);
  Tensor local_C_src = thread_copy_c.partition_S(shared_C);
  Tensor local_C_dst = thread_copy_c.partition_D(local_C);

  // Start fetches
  int k_tile_count = size<2>(local_A);
  int k_tile_next = 0;
  CUTE_UNROLL
  for (int k = 0; k < PIPE - 1; k++) {
    copy(copy_a, local_A_src(_, _, _, k_tile_next), local_A_dst(_, _, _, k));
    copy(copy_b, local_B_src(_, _, _, k_tile_next), local_B_dst(_, _, _, k));
    cp_async_fence();
    k_tile_count--;
    k_tile_next += (k_tile_count > 0);
  }

  // Get the MMA that corresponds to this thread and allocate registers
  auto thread_mma = mma.get_slice(threadIdx.x);
  Tensor mma_shared_A = thread_mma.partition_A(shared_A);
  Tensor mma_shared_B = thread_mma.partition_B(shared_B);
  Tensor mma_shared_C = thread_mma.partition_C(shared_C);
  Tensor mma_global_C = thread_mma.partition_C(local_C);
  Tensor mma_frag_A = mma.make_fragment_A(mma_shared_A(_, _, _, 0));
  Tensor mma_frag_B = mma.make_fragment_B(mma_shared_B(_, _, _, 0));
  Tensor mma_frag_C = mma.make_fragment_C(mma_global_C);
  clear(mma_frag_C);

  // Make shared to register copies
  Copy_Atom<SM75_U32x4_LDSM_N, bf16> s2r_atom_a;
  Copy_Atom<SM75_U32x4_LDSM_N, bf16> s2r_atom_b;
  auto s2r_copy_a = make_tiled_copy_A(s2r_atom_a, mma);
  auto s2r_thread_copy_a = s2r_copy_a.get_slice(threadIdx.x);
  auto s2r_copy_b = make_tiled_copy_B(s2r_atom_b, mma);
  auto s2r_thread_copy_b = s2r_copy_b.get_slice(threadIdx.x);
  Tensor mma_A_src = s2r_thread_copy_a.partition_S(shared_A);
  Tensor mma_A_dst = s2r_thread_copy_a.retile_D(mma_frag_A);
  Tensor mma_B_src = s2r_thread_copy_b.partition_S(shared_B);
  Tensor mma_B_dst = s2r_thread_copy_b.retile_D(mma_frag_B);

  constexpr auto RPIPE = size<2>(mma_shared_A);
  int smem_read = 0;
  int smem_write = PIPE - 1;
  Tensor mma_A_src_p = mma_A_src(_, _, _, smem_read);
  Tensor mma_B_src_p = mma_B_src(_, _, _, smem_read);

  // Start the register pipeline
  if constexpr (RPIPE > 1) {
    cp_async_wait<PIPE - 2>();
    __syncthreads();
    copy(s2r_copy_a, mma_A_src_p(_, _, Int<0>{}), mma_A_dst(_, _, Int<0>{}));
    copy(s2r_copy_b, mma_B_src_p(_, _, Int<0>{}), mma_B_dst(_, _, Int<0>{}));
  }

  CUTE_NO_UNROLL
  while (k_tile_count > -(PIPE - 1)) {
    CUTE_UNROLL
    for (int k_block = 0; k_block < RPIPE; k_block++) {
      if (k_block == RPIPE - 1) {
        mma_A_src_p = mma_A_src(_, _, _, smem_read);
        mma_B_src_p = mma_B_src(_, _, _, smem_read);
        cp_async_wait<PIPE - 2>();
        __syncthreads();
      }

      // Load the next register tile
      auto k_block_next = (k_block + 1) % RPIPE;
      copy(
          s2r_copy_a,
          mma_A_src_p(_, _, k_block_next),
          mma_A_dst(_, _, k_block_next));
      copy(
          s2r_copy_b,
          mma_B_src_p(_, _, k_block_next),
          mma_B_dst(_, _, k_block_next));

      if (k_block == 0) {
        copy(
            copy_a,
            local_A_src(_, _, _, k_tile_next),
            local_A_dst(_, _, _, smem_write));
        copy(
            copy_b,
            local_B_src(_, _, _, k_tile_next),
            local_B_dst(_, _, _, smem_write));
        cp_async_fence();
        k_tile_count--;
        k_tile_next += (k_tile_count > 0);
        smem_write = smem_read;
        smem_read = (smem_read == PIPE - 1) ? 0 : (smem_read + 1);
      }

      gemm(
          mma,
          mma_frag_A(_, _, k_block),
          mma_frag_B(_, _, k_block),
          mma_frag_C);
    }
  }

  copy(mma_frag_C, mma_shared_C);
  __syncthreads();
  copy(copy_c, local_C_src, local_C_dst);

  // if (threadIdx.x == 0) {
  //  print("fC: "); print(mma_frag_C); print("\n");
  //  print("sC: "); print(mma_shared_C); print("\n");
  //  print("dC: "); print(local_C_dst); print("\n");
  //
  //  print(s2r_atom_a); print("\n");
  // }
}

} // namespace

void cutlass_gemm(
    const array& a,
    const array& b,
    array& out,
    int M,
    int N,
    int K,
    cu::CommandEncoder& enc) {
  enc.set_input_array(a);
  enc.set_input_array(b);
  enc.set_output_array(out);
  dispatch_float_types(a.dtype(), "simple_gemm", [&](auto type_tag) {
    using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
    if constexpr (std::is_same_v<DataType, __nv_bfloat16>) {
      using namespace cute;

      // Tile definitions
      auto BM = Int<128>{};
      auto BN = Int<128>{};
      auto BK = Int<64>{};
      auto BP = Int<3>{};
      auto GM = Int<8>{};

      // Thread definitions
      using TM = Int<2>;
      using TN = Int<2>;
      using TK = Int<1>;
      constexpr int num_threads = TM::value * TN::value * 32;

      auto SA = make_smem_layout<16, false, 128>(make_shape(BM, BK, BP));
      auto SB = make_smem_layout<16, false, 128>(make_shape(BN, BK, BP));
      auto SC = make_result_smem_layout<16, false, 128>(make_shape(BM, BN));

      constexpr auto smem_size = (cosize(SA) + cosize(SB)) * sizeof(bf16);

      auto async_copy_op =
          Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<uint128_t>, bf16>{};
      auto tiled_copy_a = make_tiled_copy<num_threads, 16, false, 128>(
          async_copy_op, make_shape(BM, BK));
      auto tiled_copy_b = make_tiled_copy<num_threads, 16, false, 128>(
          async_copy_op, make_shape(BN, BK));

      auto sync_copy_op = Copy_Atom<UniversalCopy<uint128_t>, bf16>{};
      auto tiled_copy_c = make_tiled_copy<num_threads, 16, false, 128>(
          sync_copy_op, make_shape(BM, BN));

      auto mma_op = SM80_16x8x16_F32BF16BF16F32_TN{};
      auto tiled_mma = make_tiled_mma(
          mma_op, Layout<cute::Shape<TM, TN, TK>>{}, Tile<_32, _32, _16>{});

      auto kernel = matmul_kernel<
          bf16,
          decltype(SA),
          decltype(SB),
          decltype(SC),
          decltype(tiled_copy_a),
          decltype(tiled_copy_b),
          decltype(tiled_copy_c),
          decltype(tiled_mma),
          GM.value>;
      configure_matmul(kernel, smem_size);

      dim3 block(size(tiled_mma));
      dim3 grid(
          size(ceil_div(M, BM) * GM), size(ceil_div(ceil_div(N, BN), GM)));

      enc.add_kernel_node(
          kernel,
          grid,
          block,
          smem_size,
          a.data<bf16>(),
          b.data<bf16>(),
          out.data<bf16>(),
          SA,
          SB,
          SC,
          tiled_copy_a,
          tiled_copy_b,
          tiled_copy_c,
          tiled_mma,
          M,
          N,
          K);
    } else {
      throw std::runtime_error("Only bfloat16 supported");
    }
  });
}

} // namespace mlx::core::cu
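Editor's note: the swizzle chosen in make_smem_layout above comes down to a small piece of integer arithmetic on the tile's feature size, element width and copy width. A hedged, standalone sketch of that arithmetic for the bf16 path used in cutlass_gemm (feature size 64, 16-bit elements, 128-bit copies); log2_int and swizzle_bits below mirror constexpr_log2 and get_swizzle_bits but are the editor's illustration, not part of the diff:

// Editor's sketch: compile-time check of the swizzle-bit arithmetic above.
constexpr int log2_int(int x) {
  return (x > 0) ? 1 + log2_int(x >> 1) : -1;
}

constexpr int swizzle_bits(int feature_size, int itemsize, int copy_bits) {
  int bits = log2_int(feature_size * itemsize / copy_bits);
  return (bits > 3) ? 3 : bits;
}

// 64 bf16 features per row, 16 bits each, copied 128 bits at a time:
// 64 * 16 / 128 = 8 vectors per row, log2(8) = 3, already at the clamp of 3.
static_assert(swizzle_bits(64, 16, 128) == 3, "bf16, 64-wide feature");
// A narrower 32-wide feature would give 32 * 16 / 128 = 4 vectors, i.e. 2 bits.
static_assert(swizzle_bits(32, 16, 128) == 2, "bf16, 32-wide feature");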
mlx/backend/cuda/gemms/cutlass_gemm.h (new file, 18 lines)
@@ -0,0 +1,18 @@
// Copyright © 2025 Apple Inc.

#pragma once

#include "mlx/backend/cuda/device.h"

namespace mlx::core::cu {

void cutlass_gemm(
    const array& a,
    const array& b,
    array& out,
    int M,
    int N,
    int K,
    cu::CommandEncoder& enc);

}
mlx/backend/cuda/gemms/simple_gemm.cu (new file, 69 lines)
@@ -0,0 +1,69 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/cuda/steel/gemm.cuh"
#include "mlx/dtype_utils.h"

#include <iostream>

namespace mlx::core::cu {

namespace {

template <typename Kernel>
static void configure_smem(Kernel kernel, int SM) {
  static bool done = false;
  if (done) {
    return;
  }
  std::cout << "configuring" << std::endl;
  cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, SM);
  cudaFuncSetAttribute(
      kernel,
      cudaFuncAttributePreferredSharedMemoryCarveout,
      cudaSharedmemCarveoutMaxShared);
  done = true;
}

} // namespace

void simple_gemm(
    const array& a,
    const array& b,
    array& out,
    int M,
    int N,
    int K,
    cu::CommandEncoder& enc) {
  enc.set_input_array(a);
  enc.set_input_array(b);
  enc.set_output_array(out);
  dispatch_float_types(a.dtype(), "simple_gemm", [&](auto type_tag) {
    using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
    constexpr int BM = 128;
    constexpr int BN = 128;
    constexpr int BK = 32;
    constexpr int PIPE = 3;
    constexpr int SM = PIPE * sizeof(DataType) * (BM * BK + BN * BK);
    constexpr int WM = 2;
    constexpr int WN = 4;

    auto kernel = ab_t_aligned<DataType, BM, BN, BK, WM, WN, PIPE>;
    configure_smem(kernel, SM);

    dim3 grid(N / BN, M / BM);
    enc.add_kernel_node(
        kernel,
        grid,
        WM * WN * WARP_SIZE,
        SM,
        a.data<DataType>(),
        b.data<DataType>(),
        out.data<DataType>(),
        N,
        K);
  });
}

} // namespace mlx::core::cu
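Editor's note: the dynamic shared memory requested above is just PIPE buffered A and B tiles. A small sketch of the arithmetic (assuming the tile sizes in the file; the static_asserts are the editor's illustration, not part of the diff):

// With BM = BN = 128, BK = 32 and PIPE = 3, the per-block request is
//   PIPE * sizeof(T) * (BM * BK + BN * BK) = 3 * sizeof(T) * 8192 bytes.
static_assert(3 * 2 * (128 * 32 + 128 * 32) == 49152, "bf16/half tiles: 48 KB");
static_assert(3 * 4 * (128 * 32 + 128 * 32) == 98304, "float tiles: 96 KB");
// The float case exceeds the default 48 KB dynamic shared memory limit, which
// is why configure_smem raises cudaFuncAttributeMaxDynamicSharedMemorySize
// before the kernel is launched.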
mlx/backend/cuda/gemms/simple_gemm.h (new file, 18 lines)
@@ -0,0 +1,18 @@
// Copyright © 2025 Apple Inc.

#pragma once

#include "mlx/backend/cuda/device.h"

namespace mlx::core::cu {

void simple_gemm(
    const array& a,
    const array& b,
    array& out,
    int M,
    int N,
    int K,
    cu::CommandEncoder& enc);

}
@@ -3,7 +3,9 @@
#include "mlx/backend/common/matmul.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/gemms/cublas_gemm.h"
#include "mlx/backend/cuda/gemms/cutlass_gemm.h"
#include "mlx/backend/cuda/gemms/gemv.h"
#include "mlx/backend/cuda/gemms/simple_gemm.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/primitives.h"

@@ -11,8 +13,14 @@
#include <numeric>

namespace mlx::core {

namespace {

int get_test_gemm() {
  static int t = env::get_var("MLX_ENABLE_TEST_GEMM", 0);
  return t;
}

std::tuple<bool, int64_t, array>
check_transpose(cu::CommandEncoder& enc, const Stream& s, const array& arr) {
  auto stx = arr.strides()[arr.ndim() - 2];
@@ -95,6 +103,18 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
    return;
  }

  if (M % 512 == 0 && N % 512 == 0 && K % 512 == 0 && !a_transposed &&
      b_transposed && batch_count == 1 && get_test_gemm() == 1) {
    cu::simple_gemm(a, b, out, M, N, K, encoder);
    return;
  }

  if (M % 512 == 0 && N % 512 == 0 && K % 512 == 0 && !a_transposed &&
      b_transposed && batch_count == 1 && get_test_gemm() == 2) {
    cu::cutlass_gemm(a, b, out, M, N, K, encoder);
    return;
  }

  /////////////////////////////////////////////////////////////////////////////
  // Invoke cublasLt
  CublasGemm gemm(
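Editor's note: the two early-return branches above gate the experimental kernels behind the MLX_ENABLE_TEST_GEMM environment variable (1 selects simple_gemm, 2 selects cutlass_gemm) and a narrow shape class. A sketch of that predicate as a standalone helper; the function name is illustrative and not in the diff:

// Editor's sketch of the eligibility check used above: the test GEMMs only
// run for a single, untransposed-A / transposed-B matmul whose dimensions are
// all multiples of 512, and only when MLX_ENABLE_TEST_GEMM selects them.
bool uses_test_gemm(
    int M,
    int N,
    int K,
    bool a_transposed,
    bool b_transposed,
    int batch_count,
    int test_gemm_flag) {
  bool aligned = (M % 512 == 0) && (N % 512 == 0) && (K % 512 == 0);
  bool layout_ok = !a_transposed && b_transposed && batch_count == 1;
  return aligned && layout_ok && (test_gemm_flag == 1 || test_gemm_flag == 2);
}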
@@ -4,95 +4,189 @@

namespace mlx::core::cu {

template <typename T, int BM, int BN, int BK, int WM, int WN>
__device__ inline void gemm_ab_t(
    RegisterTile<float, BM / WM, BN / WN>& C,
    SharedTile<T, BM, BK>& As,
    SharedTile<T, BN, BK>& Bs,
    RegisterTileLoader<SharedTile<T, BM, BK>>& rloader_a,
    RegisterTileLoader<SharedTile<T, BN, BK>>& rloader_b) {
  RegisterTile<T, BM / WM, 16> A[2];
  RegisterTile<T, BN / WN, 16> B[2];

  rloader_a.load(A[0], As.base_addr(), 0);
  rloader_b.load(B[0], Bs.base_addr(), 0);

  MLX_UNROLL
  for (int k = 1; k < BK / 16; k++) {
    rloader_a.load(A[k & 1], As.base_addr(), k);
    rloader_b.load(B[k & 1], Bs.base_addr(), k);

    mma_t(C, A[(k - 1) & 1], B[(k - 1) & 1]);
  }
  mma_t(C, A[(BK / 16 - 1) & 1], B[(BK / 16 - 1) & 1]);
}

/**
 * An example gemm written with the utils.
 *
 * Computes A @ B.T when A and B are all aligned with the block sizes.
 */
template <typename T, int BM, int BN, int BK>
__global__ void ab_t_aligned(const T* a, const T* b, T* y, int N, int K) {
  constexpr int WARPS_M = 2;
  constexpr int WARPS_N = 2;
  constexpr int NUM_WARPS = WARPS_M * WARPS_N;
  constexpr int WARP_STEP_M = BM / WARPS_M;
  constexpr int WARP_STEP_N = BN / WARPS_N;
// template <typename T, int BM, int BN, int BK, int WM, int WN, int PIPE>
//__global__ __launch_bounds__(WM * WN * WARP_SIZE, 1)
// void ab_t_aligned(const T* a, const T* b, T* y, int N, int K) {
//   constexpr int NUM_WARPS = WM * WN;
//   constexpr int WARP_STEP_M = BM / WM;
//   constexpr int WARP_STEP_N = BN / WN;
//
//   // Precompute some offsets for each thread
//   const int warpid = threadIdx.x / 32;
//   const int laneid = threadIdx.x % 32;
//   const int wm = warpid / WN;
//   const int wn = warpid % WN;
//   const int offset_m = wm * WARP_STEP_M;
//   const int offset_n = wn * WARP_STEP_N;
//
//   // Allocate shared memory
//   extern __shared__ char shmem[];
//   SharedTile<T, BM, BK>(&as)[PIPE] =
//       *(SharedTile<T, BM, BK>(*)[PIPE])(&shmem[0]);
//   SharedTile<T, BN, BK>(&bs)[PIPE] =
//       *(SharedTile<T, BN, BK>(*)[PIPE])(&shmem[sizeof(T) * PIPE * BM * BK]);
//
//   // Move the global pointers to the tile
//   a += blockIdx.y * BM * K;
//   b += blockIdx.x * BN * K;
//   y += blockIdx.y * BM * N + blockIdx.x * BN;
//
//   // Make the loaders to/from SMEM
//   SharedTileLoader<NUM_WARPS, SharedTile<T, BM, BK>> sloader_a(a, K);
//   SharedTileLoader<NUM_WARPS, SharedTile<T, BN, BK>> sloader_b(b, K);
//   RegisterTileLoader<SharedTile<T, BM, BK>> rloader_a(offset_m, laneid);
//   RegisterTileLoader<SharedTile<T, BN, BK>> rloader_b(offset_n, laneid);
//
//   // Start the SM pipeline
//   MLX_UNROLL
//   for (int i = 0; i < PIPE - 1; i++) {
//     sloader_a.load_async(as[i].base_addr());
//     sloader_b.load_async(bs[i].base_addr());
//     cp_async_commit();
//     sloader_a.next();
//     sloader_b.next();
//   }
//
//   // Allocate and zero the MMA accumulator
//   RegisterTile<float, BM / WM, BN / WN> C;
//   C.fill(0);
//
//   // Matmul loop
//   int num_blocks = K / BK;
//   int sread = 0;
//   int swrite = PIPE - 1;
//   for (int i = 0; i < num_blocks; i++) {
//     cp_async_wait<PIPE - 1>();
//
//     gemm_ab_t<T, BM, BN, BK, WM, WN>(
//         C, as[sread], bs[sread], rloader_a, rloader_b);
//
//     sloader_a.load_async(as[swrite].base_addr());
//     sloader_b.load_async(bs[swrite].base_addr());
//     cp_async_commit();
//     sloader_a.next(i + PIPE < num_blocks);
//     sloader_b.next(i + PIPE < num_blocks);
//
//     swrite = sread;
//     sread = (sread + 1) % PIPE;
//   }
//
//   C.store_global(y, N, offset_m, offset_n);
// }

template <typename T, int BM, int BN, int BK, int WM, int WN, int PIPE>
__global__ __launch_bounds__(
    WM* WN* WARP_SIZE,
    1) void ab_t_aligned(const T* a, const T* b, T* y, int N, int K) {
  constexpr int NUM_WARPS = WM * WN;
  constexpr int WARP_STEP_M = BM / WM;
  constexpr int WARP_STEP_N = BN / WN;

  // Precompute some offsets for each thread
  const int warpid = threadIdx.x / 32;
  const int laneid = threadIdx.x % 32;
  const int wm = warpid / WARPS_N;
  const int wn = warpid % WARPS_N;
  const int wm = warpid / WN;
  const int wn = warpid % WN;
  const int offset_m = wm * WARP_STEP_M;
  const int offset_n = wn * WARP_STEP_N;

  // Allocate shared memory
  extern __shared__ char shmem[];
  SharedTile<T, BM, BK>(&as)[2] = *(SharedTile<T, BM, BK>(*)[2])(&shmem[0]);
  SharedTile<T, BN, BK>(&bs)[2] =
      *(SharedTile<T, BN, BK>(*)[2])(&shmem[sizeof(T) * 2 * BM * BK]);

  // Allocate registers for the MMA
  RegisterTile<float, BM / WARPS_M, BN / WARPS_N> C;
  RegisterTile<T, BM / WARPS_M, 16> A;
  RegisterTile<T, BN / WARPS_N, 16> B;
  SharedTile<T, BM, BK>(&as)[PIPE] =
      *(SharedTile<T, BM, BK>(*)[PIPE])(&shmem[0]);
  SharedTile<T, BN, BK>(&bs)[PIPE] =
      *(SharedTile<T, BN, BK>(*)[PIPE])(&shmem[sizeof(T) * PIPE * BM * BK]);

  // Move the global pointers to the tile
  a += blockIdx.y * BM * K;
  b += blockIdx.x * BN * K;
  y += blockIdx.y * BM * N + blockIdx.x * BN;

  // Zero the accumulators
  C.fill(0);
  // Make the loaders to/from SMEM
  using sloader = SharedTileLoader<NUM_WARPS, SharedTile<T, BM, BK>>;
  constexpr int SSTEP = sloader::STEP_ROWS * sizeof(T) * BK;
  const int srow = threadIdx.x / sloader::NUM_LOADS_PER_ROW;
  const int scol =
      (threadIdx.x % sloader::NUM_LOADS_PER_ROW) * sloader::ELEMENTS_PER_LOAD;
  a += srow * K + scol;
  b += srow * K + scol;
  uint32_t sm_offsets[PIPE][2];
  MLX_UNROLL
  for (int s = 0; s < PIPE; s++) {
    sm_offsets[s][0] = as[s].loc(as[s].base_addr(), srow, scol);
    sm_offsets[s][1] = bs[s].loc(bs[s].base_addr(), srow, scol);
  }
  RegisterTileLoader<SharedTile<T, BM, BK>> rloader_a(offset_m, laneid);
  RegisterTileLoader<SharedTile<T, BN, BK>> rloader_b(offset_n, laneid);

  // Start the SM pipeline
  load_async<NUM_WARPS>(as[0], as[0].base_addr(), a, K);
  load_async<NUM_WARPS>(bs[0], bs[0].base_addr(), b, K);
  cp_async_commit();

  int tic = 0;
  for (int k_block = BK; k_block < K; k_block += BK) {
    load_async<NUM_WARPS>(as[tic ^ 1], as[tic ^ 1].base_addr(), a + k_block, K);
    load_async<NUM_WARPS>(bs[tic ^ 1], bs[tic ^ 1].base_addr(), b + k_block, K);
    cp_async_commit();
    cp_async_wait<1>();
    __syncthreads();

  MLX_UNROLL
  for (int s = 0; s < PIPE - 1; s++) {
    MLX_UNROLL
    for (int k = 0; k < BK / 16; k++) {
      A.load(
          as[tic],
          as[tic].base_addr(),
          offset_m + laneid % 16,
          k * 16 + laneid / 16 * 8);
      B.load(
          bs[tic],
          bs[tic].base_addr(),
          offset_n + laneid % 16,
          k * 16 + laneid / 16 * 8);

      mma_t(C, A, B);
    for (int l = 0; l < sloader::NUM_LOADS_PER_THREAD; l++) {
      cp_async<16>(sm_offsets[s][0] + l * SSTEP, a);
      cp_async<16>(sm_offsets[s][1] + l * SSTEP, b);
      a += sloader::STEP_ROWS * K;
      b += sloader::STEP_ROWS * K;
    }

    tic ^= 1;
    cp_async_commit();
  }

  // Empty the pipeline
  cp_async_wait_all();
  __syncthreads();
  MLX_UNROLL
  for (int k = 0; k < BK / 16; k++) {
    A.load(
        as[tic],
        as[tic].base_addr(),
        offset_m + laneid % 16,
        k * 16 + laneid / 16 * 8);
    B.load(
        bs[tic],
        bs[tic].base_addr(),
        offset_n + laneid % 16,
        k * 16 + laneid / 16 * 8);
  // Allocate and zero the MMA accumulator
  RegisterTile<float, BM / WM, BN / WN> C;
  C.fill(0);

    mma_t(C, A, B);
  // Matmul loop
  int num_blocks = K / BK;
  int sread = 0;
  int swrite = PIPE - 1;
  for (int i = 0; i < num_blocks; i++) {
    cp_async_wait<PIPE - 1>();

    gemm_ab_t<T, BM, BN, BK, WM, WN>(
        C, as[sread], bs[sread], rloader_a, rloader_b);

    if (false) {
      MLX_UNROLL
      for (int l = 0; l < sloader::NUM_LOADS_PER_THREAD; l++) {
        cp_async<16>(sm_offsets[swrite][0] + l * SSTEP, a);
        cp_async<16>(sm_offsets[swrite][1] + l * SSTEP, b);
        a += sloader::STEP_ROWS * K;
        b += sloader::STEP_ROWS * K;
      }
    }
    cp_async_commit();

    swrite = sread;
    sread = (sread + 1) % PIPE;
  }

  C.store_global(y, N, offset_m, offset_n);
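Editor's note: the matmul loop above keeps PIPE shared-memory stages in flight and rotates a read and a write stage through them (swrite = sread; sread = (sread + 1) % PIPE). A small host-side sketch of just that rotation, for PIPE = 3; the program is the editor's illustration, not part of the diff:

// Editor's sketch: the read/write stage rotation used by the pipelined loop
// above. With PIPE = 3 the stages are visited as
// (read, write) = (0, 2), (1, 0), (2, 1), (0, 2), ...
#include <cstdio>

int main() {
  constexpr int PIPE = 3;
  int sread = 0;
  int swrite = PIPE - 1;
  for (int i = 0; i < 6; i++) {
    std::printf("iter %d: read stage %d, write stage %d\n", i, sread, swrite);
    swrite = sread;
    sread = (sread + 1) % PIPE;
  }
  return 0;
}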
@@ -223,59 +223,10 @@ struct RegisterTile {
  }
};

/**
 * A simple container of multiple Tile16x16.
 *
 * Provides utility functions for loading and manipulating collections of basic
 * tiles.
 */
template <typename T, int ROWS_, int COLS_>
struct RegisterTile {
  static constexpr int ROWS = ROWS_;
  static constexpr int COLS = COLS_;
  static constexpr int TILES_X = COLS / 16;
  static constexpr int TILES_Y = ROWS / 16;

  Tile16x16<T> data[TILES_X * TILES_Y];

  __device__ inline void fill(T v) {
    MLX_UNROLL
    for (int i = 0; i < TILES_Y; i++) {
      MLX_UNROLL
      for (int j = 0; j < TILES_X; j++) {
        data[i * TILES_X + j].fill(v);
      }
    }
  }

  template <typename Tile>
  __device__ inline void
  load(Tile& tile, uint32_t base_address, int row, int col) {
    MLX_UNROLL
    for (int i = 0; i < TILES_Y; i++) {
      MLX_UNROLL
      for (int j = 0; j < TILES_X; j++) {
        data[i * TILES_X + j].load(
            tile.loc(base_address, row + i * 16, col + j * 16));
      }
    }
  }

  template <typename U>
  __device__ inline void store_global(U* x, int N, int row, int col) {
    MLX_UNROLL
    for (int i = 0; i < TILES_Y; i++) {
      MLX_UNROLL
      for (int j = 0; j < TILES_X; j++) {
        data[i * TILES_X + j].store_global(
            x + (row + i * 16) * N + col + j * 16, N);
      }
    }
  }
};

template <typename T, int ROWS_, int COLS_>
struct SharedTile {
  using value_type = T;

  static constexpr int ROWS = ROWS_;
  static constexpr int COLS = COLS_;
  static constexpr int TILES_X = COLS / 16;
@@ -317,23 +268,26 @@ struct SharedTile {
    }
  }

  // Return the location of the element at (row, col) using the swizzle.
  __device__ static inline uint32_t loc(uint32_t ptr, int row, int col) {
  __device__ static inline uint32_t offset(int row, int col) {
    if constexpr (swizzle_bytes > 0) {
      static constexpr int swizzle_repeat = swizzle_bytes * 8;
      static constexpr int subtile_cols = swizzle_bytes / sizeof(T);
      const int outer_idx = col / subtile_cols;
      const uint32_t addr = ptr +
          sizeof(T) *
              (outer_idx * ROWS * subtile_cols + row * subtile_cols +
               col % subtile_cols);
      const uint32_t addr = sizeof(T) *
          (outer_idx * ROWS * subtile_cols + row * subtile_cols +
           col % subtile_cols);
      const int swizzle = ((addr % swizzle_repeat) >> 7) << 4;
      return (addr ^ swizzle);
    } else {
      return ptr + sizeof(T) * (row * COLS + col);
      return sizeof(T) * (row * COLS + col);
    }
  }

  // Return the location of the element at (row, col) using the swizzle.
  __device__ static inline uint32_t loc(uint32_t ptr, int row, int col) {
    return ptr + offset(row, col);
  }

  // Convenience functions to edit elements going through the swizzle.
  __device__ inline T& operator()(int row, int col) {
    return *ptr(data, row, col);
@@ -364,6 +318,76 @@ struct SharedTile {
  }
};

template <int NUM_WARPS, typename Tile>
struct SharedTileLoader {
  using T = typename Tile::value_type;

  static constexpr int NUM_THREADS = NUM_WARPS * 32;
  static constexpr int ELEMENTS_PER_LOAD = sizeof(float4) / sizeof(T);
  static constexpr int NUM_LOADS = Tile::NUMEL / ELEMENTS_PER_LOAD;
  static constexpr int NUM_LOADS_PER_THREAD = NUM_LOADS / NUM_THREADS;
  static constexpr int NUM_LOADS_PER_ROW = Tile::COLS / ELEMENTS_PER_LOAD;
  static constexpr int STEP_ROWS = NUM_THREADS / NUM_LOADS_PER_ROW;

  const T* x_;
  int N_;
  uint32_t offset_;

  __device__ SharedTileLoader(const T* x, int N) : x_(x), N_(N) {
    const int row = threadIdx.x / NUM_LOADS_PER_ROW;
    const int col = threadIdx.x % NUM_LOADS_PER_ROW;

    x_ += row * N + col * ELEMENTS_PER_LOAD;
    offset_ = Tile::offset(row, col * ELEMENTS_PER_LOAD);
  }

  __device__ inline void load_async(uint32_t base_address) {
    MLX_UNROLL
    for (int i = 0; i < NUM_LOADS_PER_THREAD; i++) {
      cp_async<16>(
          base_address + offset_ + i * STEP_ROWS * sizeof(T) * Tile::COLS,
          x_ + i * STEP_ROWS * N_);
    }
  }

  __device__ inline void next() {
    x_ += Tile::COLS;
  }
};

template <typename Tile>
struct RegisterTileLoader {
  using T = typename Tile::value_type;

  uint32_t offset_[Tile::COLS / 16];

  __device__ RegisterTileLoader(int offset_row, int laneid) {
    // Each lane covers one row in [offset_row, offset_row + 16) and an
    // 8-column slice selected by the high lane bit.
    const int row = offset_row + (laneid & 15);
    const int col = (laneid >> 4) << 3;

    MLX_UNROLL
    for (int i = 0; i < Tile::COLS / 16; i++) {
      offset_[i] = Tile::offset(row, col + i * 16);
    }
  }

  template <typename T, int ROWS, int COLS>
  __device__ inline void
  load(RegisterTile<T, ROWS, COLS>& x, uint32_t base_address, int col) {
    constexpr int TILES_Y = RegisterTile<T, ROWS, COLS>::TILES_Y;
    constexpr int TILES_X = RegisterTile<T, ROWS, COLS>::TILES_X;

    MLX_UNROLL
    for (int i = 0; i < TILES_Y; i++) {
      MLX_UNROLL
      for (int j = 0; j < TILES_X; j++) {
        x.data[i * TILES_X + j].load(
            base_address + offset_[j + col] + i * 16 * Tile::COLS * sizeof(T));
      }
    }
  }
};

/**
 * Load the tile from global memory by loading 16 bytes at a time and storing
 * them immediately.
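Editor's note: the constants in SharedTileLoader above fully determine how a block of warps sweeps a tile. A worked example, assuming bf16 elements, a 128x32 tile and 8 warps (the simple_gemm configuration); the constants below restate the struct's formulas and are the editor's illustration, not part of the diff:

// Editor's sketch: SharedTileLoader constants for T = bf16 (2 bytes),
// Tile = SharedTile<bf16, 128, 32>, NUM_WARPS = 8 (256 threads).
constexpr int kThreads = 8 * 32;                       // NUM_THREADS
constexpr int kElemsPerLoad = 16 / 2;                  // 16-byte loads hold 8 bf16
constexpr int kNumLoads = 128 * 32 / kElemsPerLoad;    // 512 loads cover the tile
constexpr int kLoadsPerThread = kNumLoads / kThreads;  // 2 loads per thread
constexpr int kLoadsPerRow = 32 / kElemsPerLoad;       // 4 loads per 32-wide row
constexpr int kStepRows = kThreads / kLoadsPerRow;     // passes are 64 rows apart
static_assert(
    kLoadsPerThread == 2 && kLoadsPerRow == 4 && kStepRows == 64,
    "derived SharedTileLoader constants");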
@@ -21,15 +21,15 @@ __device__ inline void cp_async(uint32_t row_address, const T* x) {
#if defined(MLX_CUDA_SM_80_ENABLED)
  if constexpr (N == 16) {
    asm volatile(
        "cp.async.ca.shared::cta.global [%0], [%1], 16;\n" ::"r"(row_address),
        "cp.async.cg.shared::cta.global [%0], [%1], 16;\n" ::"r"(row_address),
        "l"(reinterpret_cast<const int4*>(x)));
  } else if constexpr (N == 8) {
    asm volatile(
        "cp.async.ca.shared::cta.global [%0], [%1], 8;\n" ::"r"(row_address),
        "cp.async.cg.shared::cta.global [%0], [%1], 8;\n" ::"r"(row_address),
        "l"(reinterpret_cast<const int2*>(x)));
  } else if constexpr (N == 4) {
    asm volatile(
        "cp.async.ca.shared::cta.global [%0], [%1], 4;\n" ::"r"(row_address),
        "cp.async.cg.shared::cta.global [%0], [%1], 4;\n" ::"r"(row_address),
        "l"(reinterpret_cast<const int*>(x)));
  }
#endif