	tmp

@@ -90,6 +90,9 @@ target_include_directories(mlx PRIVATE "${CMAKE_CURRENT_BINARY_DIR}/gen")
 target_compile_options(mlx
                        PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--extended-lambda>")
 
+# Keep ptx around for inspection
+target_compile_options(mlx PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--keep>")
+
 # Enable calling host constexpr functions from device. This is needed because
 # the constexpr version of isnan is host only.
 target_compile_options(
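
Note: nvcc's --keep leaves the intermediate files (including the generated .ptx) next to the object files so the emitted PTX can be inspected, and --extended-lambda allows __device__ lambdas to be written in host code. A minimal sketch of the kind of code that needs the latter flag (illustration only, not taken from this repository):

    // Compiles only with --extended-lambda: a __device__ lambda defined in host code.
    #include <cuda_runtime.h>

    template <typename F>
    __global__ void apply(F f, float* x, int n) {
      int i = blockIdx.x * blockDim.x + threadIdx.x;
      if (i < n) x[i] = f(x[i]);
    }

    void scale(float* x, int n, float s) {
      auto f = [=] __device__(float v) { return s * v; }; // extended lambda
      apply<<<(n + 255) / 256, 256>>>(f, x, n);
    }
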
@@ -5,8 +5,29 @@
 #include "mlx/backend/cuda/steel/gemm.cuh"
 #include "mlx/dtype_utils.h"
 
+#include <iostream>
+
 namespace mlx::core::cu {
 
+namespace {
+
+template <typename Kernel>
+static void configure_smem(Kernel kernel, int SM) {
+  static bool done = false;
+  if (done) {
+    return;
+  }
+  std::cout << "configuring" << std::endl;
+  cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, SM);
+  cudaFuncSetAttribute(
+      kernel,
+      cudaFuncAttributePreferredSharedMemoryCarveout,
+      cudaSharedmemCarveoutMaxShared);
+  done = true;
+}
+
+} // namespace
+
 void simple_gemm(
     const array& a,
     const array& b,
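
Note on configure_smem: kernels may use at most 48 KiB of dynamic shared memory per block unless cudaFuncAttributeMaxDynamicSharedMemorySize is raised, and the carveout attribute is a hint asking the driver to favor shared memory over L1. The sketch below (not part of the commit, and assuming a 2-byte DataType such as __half or bfloat16) works out the footprint that SM requests for the tile sizes used here:

    // Illustration only: one pipeline stage holds one BM x BK tile of A plus
    // one BN x BK tile of B in shared memory, assuming 2-byte elements.
    constexpr int BM = 128, BN = 128, BK = 32;
    constexpr int elem_bytes = 2;
    constexpr int stage_bytes = elem_bytes * (BM * BK + BN * BK); // 16 KiB per stage
    static_assert(stage_bytes == 16 * 1024, "one stage of A + B tiles");
    static_assert(3 * stage_bytes == 48 * 1024, "PIPE = 3 lands exactly on the default limit");
    static_assert(4 * stage_bytes == 64 * 1024, "the old 4-stage setup needed the 65536-byte opt-in");
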
@@ -23,17 +44,20 @@ void simple_gemm(
     constexpr int BM = 128;
     constexpr int BN = 128;
     constexpr int BK = 32;
+    constexpr int PIPE = 3;
+    constexpr int SM = PIPE * sizeof(DataType) * (BM * BK + BN * BK);
+    constexpr int WM = 2;
+    constexpr int WN = 4;
 
-    auto kernel = ab_t_aligned<DataType, BM, BN, BK>;
-    cudaFuncSetAttribute(
-        kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536);
+    auto kernel = ab_t_aligned<DataType, BM, BN, BK, WM, WN, PIPE>;
+    configure_smem(kernel, SM);
 
     dim3 grid(N / BN, M / BM);
     enc.add_kernel_node(
         kernel,
         grid,
-        8 * WARP_SIZE,
-        4 * sizeof(DataType) * (BM * BK + BN * BK),
+        WM * WN * WARP_SIZE,
+        SM,
         a.data<DataType>(),
         b.data<DataType>(),
         out.data<DataType>(),
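
As a concrete reading of the new launch configuration (illustration only; the M = N = K = 1024 problem size and the 2-byte element type are assumptions, the block constants are the ones above):

    #include <cstdio>

    int main() {
      constexpr int BM = 128, BN = 128, BK = 32, WM = 2, WN = 4, PIPE = 3;
      constexpr int WARP_SIZE = 32;
      constexpr int elem_bytes = 2; // __half / bfloat16 assumed
      int M = 1024, N = 1024, K = 1024;
      std::printf("grid = (%d, %d), block = %d threads, dynamic smem = %d bytes\n",
                  N / BN, M / BM, WM * WN * WARP_SIZE,
                  PIPE * elem_bytes * (BM * BK + BN * BK));
      // prints: grid = (8, 8), block = 256 threads, dynamic smem = 49152 bytes
      (void)K;
      return 0;
    }

The dispatch added in the matmul hunk below only takes this path when M, N and K are multiples of 512, so the grid divisions above are exact.
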
@@ -16,6 +16,11 @@ namespace mlx::core {
 
 namespace {
 
+int get_test_gemm() {
+  static int t = env::get_var("MLX_ENABLE_TEST_GEMM", 0);
+  return t;
+}
+
 std::tuple<bool, int64_t, array>
 check_transpose(cu::CommandEncoder& enc, const Stream& s, const array& arr) {
   auto stx = arr.strides()[arr.ndim() - 2];
@@ -99,15 +104,13 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
   }
 
   if (M % 512 == 0 && N % 512 == 0 && K % 512 == 0 && !a_transposed &&
-      b_transposed && batch_count == 1 &&
-      env::get_var("MLX_ENABLE_TEST_GEMM", 0) == 1) {
+      b_transposed && batch_count == 1 && get_test_gemm() == 1) {
     cu::simple_gemm(a, b, out, M, N, K, encoder);
     return;
   }
 
   if (M % 512 == 0 && N % 512 == 0 && K % 512 == 0 && !a_transposed &&
-      b_transposed && batch_count == 1 &&
-      env::get_var("MLX_ENABLE_TEST_GEMM", 0) == 2) {
+      b_transposed && batch_count == 1 && get_test_gemm() == 2) {
     cu::cutlass_gemm(a, b, out, M, N, K, encoder);
     return;
   }
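
The refactor caches the environment variable in a function-local static, so it is parsed once per process instead of on every matmul; MLX_ENABLE_TEST_GEMM=1 routes eligible matmuls to cu::simple_gemm and =2 to cu::cutlass_gemm. A minimal stand-in for the helper's behaviour (assumption: env::get_var has these semantics; this is not MLX code):

    #include <cstdlib>
    #include <string>

    // Hypothetical equivalent of env::get_var("MLX_ENABLE_TEST_GEMM", 0):
    // read an integer environment variable, falling back to a default.
    static int get_env_int(const char* name, int default_value) {
      const char* v = std::getenv(name);
      return v ? std::stoi(v) : default_value;
    }

    static int get_test_gemm_sketch() {
      // A function-local static is initialized exactly once (thread-safe since C++11).
      static int t = get_env_int("MLX_ENABLE_TEST_GEMM", 0);
      return t;
    }
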
@@ -8,20 +8,19 @@ template <typename T, int BM, int BN, int BK, int WM, int WN>
 __device__ inline void gemm_ab_t(
     RegisterTile<float, BM / WM, BN / WN>& C,
     SharedTile<T, BM, BK>& As,
-    SharedTile<T, BM, BK>& Bs,
-    int lane_row_a,
-    int lane_row_b,
-    int lane_col) {
+    SharedTile<T, BN, BK>& Bs,
+    RegisterTileLoader<SharedTile<T, BM, BK>>& rloader_a,
+    RegisterTileLoader<SharedTile<T, BN, BK>>& rloader_b) {
   RegisterTile<T, BM / WM, 16> A[2];
   RegisterTile<T, BN / WN, 16> B[2];
 
-  A[0].load(As, As.base_addr(), lane_row_a, lane_col);
-  B[0].load(Bs, Bs.base_addr(), lane_row_b, lane_col);
+  rloader_a.load(A[0], As.base_addr(), 0);
+  rloader_b.load(B[0], Bs.base_addr(), 0);
 
   MLX_UNROLL
   for (int k = 1; k < BK / 16; k++) {
-    A[k & 1].load(As, As.base_addr(), lane_row_a, lane_col + k * 16);
-    B[k & 1].load(Bs, Bs.base_addr(), lane_row_b, lane_col + k * 16);
+    rloader_a.load(A[k & 1], As.base_addr(), k);
+    rloader_b.load(B[k & 1], Bs.base_addr(), k);
 
     mma_t(C, A[(k - 1) & 1], B[(k - 1) & 1]);
   }
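
gemm_ab_t keeps two register fragments per operand and overlaps the shared-memory load of k-slice k with the MMA on slice k - 1; the slice loaded last is consumed by a final MMA after the loop (outside this hunk). A host-side illustration of the two-slot indexing (not MLX code; K_TILES = BK / 16 as above):

    #include <cstdio>

    int main() {
      constexpr int BK = 32, K_TILES = BK / 16;
      int in_slot[2] = {0, -1}; // prologue: slice 0 is loaded into slot 0
      for (int k = 1; k < K_TILES; ++k) {
        in_slot[k & 1] = k; // prefetch slice k into the other slot
        std::printf("mma consumes slice %d (slot %d)\n",
                    in_slot[(k - 1) & 1], (k - 1) & 1);
      }
      std::printf("tail mma consumes slice %d (slot %d)\n",
                  in_slot[(K_TILES - 1) & 1], (K_TILES - 1) & 1);
      return 0;
    }
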
@@ -33,25 +32,91 @@ __device__ inline void gemm_ab_t(
  *
  * Computes A @ B.T when A and B are all aligned with the block sizes.
  */
-template <typename T, int BM, int BN, int BK>
-__global__ void ab_t_aligned(const T* a, const T* b, T* y, int N, int K) {
-  constexpr int WARPS_M = 4;
-  constexpr int WARPS_N = 2;
-  constexpr int NUM_WARPS = WARPS_M * WARPS_N;
-  constexpr int WARP_STEP_M = BM / WARPS_M;
-  constexpr int WARP_STEP_N = BN / WARPS_N;
-  constexpr int PIPE = 4;
+// template <typename T, int BM, int BN, int BK, int WM, int WN, int PIPE>
+//__global__ __launch_bounds__(WM * WN * WARP_SIZE, 1)
+// void ab_t_aligned(const T* a, const T* b, T* y, int N, int K) {
+//   constexpr int NUM_WARPS = WM * WN;
+//   constexpr int WARP_STEP_M = BM / WM;
+//   constexpr int WARP_STEP_N = BN / WN;
+//
+//   // Precompute some offsets for each thread
+//   const int warpid = threadIdx.x / 32;
+//   const int laneid = threadIdx.x % 32;
+//   const int wm = warpid / WN;
+//   const int wn = warpid % WN;
+//   const int offset_m = wm * WARP_STEP_M;
+//   const int offset_n = wn * WARP_STEP_N;
+//
+//   // Allocate shared memory
+//   extern __shared__ char shmem[];
+//   SharedTile<T, BM, BK>(&as)[PIPE] =
+//       *(SharedTile<T, BM, BK>(*)[PIPE])(&shmem[0]);
+//   SharedTile<T, BN, BK>(&bs)[PIPE] =
+//       *(SharedTile<T, BN, BK>(*)[PIPE])(&shmem[sizeof(T) * PIPE * BM * BK]);
+//
+//   // Move the global pointers to the tile
+//   a += blockIdx.y * BM * K;
+//   b += blockIdx.x * BN * K;
+//   y += blockIdx.y * BM * N + blockIdx.x * BN;
+//
+//   // Make the loaders to/from SMEM
+//   SharedTileLoader<NUM_WARPS, SharedTile<T, BM, BK>> sloader_a(a, K);
+//   SharedTileLoader<NUM_WARPS, SharedTile<T, BN, BK>> sloader_b(b, K);
+//   RegisterTileLoader<SharedTile<T, BM, BK>> rloader_a(offset_m, laneid);
+//   RegisterTileLoader<SharedTile<T, BN, BK>> rloader_b(offset_n, laneid);
+//
+//   // Start the SM pipeline
+//   MLX_UNROLL
+//   for (int i = 0; i < PIPE - 1; i++) {
+//     sloader_a.load_async(as[i].base_addr());
+//     sloader_b.load_async(bs[i].base_addr());
+//     cp_async_commit();
+//     sloader_a.next();
+//     sloader_b.next();
+//   }
+//
+//   // Allocate and zero the MMA accumulator
+//   RegisterTile<float, BM / WM, BN / WN> C;
+//   C.fill(0);
+//
+//   // Matmul loop
+//   int num_blocks = K / BK;
+//   int sread = 0;
+//   int swrite = PIPE - 1;
+//   for (int i = 0; i < num_blocks; i++) {
+//     cp_async_wait<PIPE - 1>();
+//
+//     gemm_ab_t<T, BM, BN, BK, WM, WN>(
+//         C, as[sread], bs[sread], rloader_a, rloader_b);
+//
+//     sloader_a.load_async(as[swrite].base_addr());
+//     sloader_b.load_async(bs[swrite].base_addr());
+//     cp_async_commit();
+//     sloader_a.next(i + PIPE < num_blocks);
+//     sloader_b.next(i + PIPE < num_blocks);
+//
+//     swrite = sread;
+//     sread = (sread + 1) % PIPE;
+//   }
+//
+//   C.store_global(y, N, offset_m, offset_n);
+// }
+
+template <typename T, int BM, int BN, int BK, int WM, int WN, int PIPE>
+__global__ __launch_bounds__(
+    WM* WN* WARP_SIZE,
+    1) void ab_t_aligned(const T* a, const T* b, T* y, int N, int K) {
+  constexpr int NUM_WARPS = WM * WN;
+  constexpr int WARP_STEP_M = BM / WM;
+  constexpr int WARP_STEP_N = BN / WN;
 
   // Precompute some offsets for each thread
   const int warpid = threadIdx.x / 32;
   const int laneid = threadIdx.x % 32;
-  const int wm = warpid / WARPS_N;
-  const int wn = warpid % WARPS_N;
+  const int wm = warpid / WN;
+  const int wn = warpid % WN;
   const int offset_m = wm * WARP_STEP_M;
   const int offset_n = wn * WARP_STEP_N;
-  const int lane_row_a = offset_m + (laneid & 15);
-  const int lane_row_b = offset_n + (laneid & 15);
-  const int lane_col = (laneid >> 4) << 3;
 
   // Allocate shared memory
   extern __shared__ char shmem[];
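
The new kernel takes its warp layout from the WM and WN template parameters instead of the hard-coded WARPS_M/WARPS_N, and __launch_bounds__(WM * WN * WARP_SIZE, 1) declares that block size (and a minimum of one resident block) to the compiler. How the 8 warps of the WM = 2, WN = 4 configuration tile the BM x BN output block (illustration only, not MLX code):

    #include <cstdio>

    int main() {
      constexpr int WM = 2, WN = 4, BM = 128, BN = 128;
      for (int warpid = 0; warpid < WM * WN; ++warpid) {
        int wm = warpid / WN, wn = warpid % WN;
        std::printf("warp %d -> rows [%3d, %3d), cols [%3d, %3d)\n", warpid,
                    wm * (BM / WM), (wm + 1) * (BM / WM),
                    wn * (BN / WN), (wn + 1) * (BN / WN));
      }
      return 0;
    }
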
@@ -65,34 +130,59 @@ __global__ void ab_t_aligned(const T* a, const T* b, T* y, int N, int K) {
   b += blockIdx.x * BN * K;
   y += blockIdx.y * BM * N + blockIdx.x * BN;
 
+  // Make the loaders to/from SMEM
+  using sloader = SharedTileLoader<NUM_WARPS, SharedTile<T, BM, BK>>;
+  constexpr int SSTEP = sloader::STEP_ROWS * sizeof(T) * BK;
+  const int srow = threadIdx.x / sloader::NUM_LOADS_PER_ROW;
+  const int scol =
+      (threadIdx.x % sloader::NUM_LOADS_PER_ROW) * sloader::ELEMENTS_PER_LOAD;
+  a += srow * K + scol;
+  b += srow * K + scol;
+  uint32_t sm_offsets[PIPE][2];
+  MLX_UNROLL
+  for (int s = 0; s < PIPE; s++) {
+    sm_offsets[s][0] = as[s].loc(as[s].base_addr(), srow, scol);
+    sm_offsets[s][1] = bs[s].loc(bs[s].base_addr(), srow, scol);
+  }
+  RegisterTileLoader<SharedTile<T, BM, BK>> rloader_a(offset_m, laneid);
+  RegisterTileLoader<SharedTile<T, BN, BK>> rloader_b(offset_n, laneid);
+
   // Start the SM pipeline
-  MLX_UNROLL
-  for (int i = 0; i < PIPE - 1; i++) {
-    load_async<NUM_WARPS>(as[i], as[i].base_addr(), a + i * BK, K);
-    load_async<NUM_WARPS>(bs[i], bs[i].base_addr(), b + i * BK, K);
+  for (int s = 0; s < PIPE - 1; s++) {
+    MLX_UNROLL
+    for (int l = 0; l < sloader::NUM_LOADS_PER_THREAD; l++) {
+      cp_async<16>(sm_offsets[s][0] + l * SSTEP, a);
+      cp_async<16>(sm_offsets[s][1] + l * SSTEP, b);
+      a += sloader::STEP_ROWS * K;
+      b += sloader::STEP_ROWS * K;
+    }
     cp_async_commit();
   }
 
   // Allocate and zero the MMA accumulator
-  RegisterTile<float, BM / WARPS_M, BN / WARPS_N> C;
+  RegisterTile<float, BM / WM, BN / WN> C;
   C.fill(0);
 
   // Matmul loop
   int num_blocks = K / BK;
-  int k_block = (PIPE - 1) * BK;
   int sread = 0;
   int swrite = PIPE - 1;
   for (int i = 0; i < num_blocks; i++) {
-    cp_async_wait<PIPE - 2>();
+    cp_async_wait<PIPE - 1>();
 
-    if (k_block < K) {
-      load_async<NUM_WARPS>(as[swrite], as[swrite].base_addr(), a + k_block, K);
-      load_async<NUM_WARPS>(bs[swrite], bs[swrite].base_addr(), b + k_block, K);
+    gemm_ab_t<T, BM, BN, BK, WM, WN>(
+        C, as[sread], bs[sread], rloader_a, rloader_b);
+
+    if (false) {
+      MLX_UNROLL
+      for (int l = 0; l < sloader::NUM_LOADS_PER_THREAD; l++) {
+        cp_async<16>(sm_offsets[swrite][0] + l * SSTEP, a);
+        cp_async<16>(sm_offsets[swrite][1] + l * SSTEP, b);
+        a += sloader::STEP_ROWS * K;
+        b += sloader::STEP_ROWS * K;
+      }
     }
 
-    gemm_ab_t<T, BM, BN, BK, WARPS_M, WARPS_N>(
-        C, as[sread], bs[sread], lane_row_a, lane_row_b, lane_col);
-
     cp_async_commit();
 
     swrite = sread;

@@ -225,6 +225,8 @@ struct RegisterTile {
 
 template <typename T, int ROWS_, int COLS_>
 struct SharedTile {
+  using value_type = T;
+
   static constexpr int ROWS = ROWS_;
   static constexpr int COLS = COLS_;
   static constexpr int TILES_X = COLS / 16;
@@ -266,23 +268,26 @@ struct SharedTile {
     }
   }
 
-  // Return the location of the element at (row, col) using the swizzle.
-  __device__ static inline uint32_t loc(uint32_t ptr, int row, int col) {
+  __device__ static inline uint32_t offset(int row, int col) {
     if constexpr (swizzle_bytes > 0) {
       static constexpr int swizzle_repeat = swizzle_bytes * 8;
       static constexpr int subtile_cols = swizzle_bytes / sizeof(T);
       const int outer_idx = col / subtile_cols;
-      const uint32_t addr = ptr +
-          sizeof(T) *
-              (outer_idx * ROWS * subtile_cols + row * subtile_cols +
-               col % subtile_cols);
+      const uint32_t addr = sizeof(T) *
+          (outer_idx * ROWS * subtile_cols + row * subtile_cols +
+           col % subtile_cols);
       const int swizzle = ((addr % swizzle_repeat) >> 7) << 4;
       return (addr ^ swizzle);
     } else {
-      return ptr + sizeof(T) * (row * COLS + col);
+      return sizeof(T) * (row * COLS + col);
     }
   }
 
+  // Return the location of the element at (row, col) using the swizzle.
+  __device__ static inline uint32_t loc(uint32_t ptr, int row, int col) {
+    return ptr + offset(row, col);
+  }
+
   // Convenience functions to edit elements going through the swizzle.
   __device__ inline T& operator()(int row, int col) {
     return *ptr(data, row, col);
@@ -313,6 +318,76 @@ struct SharedTile {
   }
 };
 
+template <int NUM_WARPS, typename Tile>
+struct SharedTileLoader {
+  using T = typename Tile::value_type;
+
+  static constexpr int NUM_THREADS = NUM_WARPS * 32;
+  static constexpr int ELEMENTS_PER_LOAD = sizeof(float4) / sizeof(T);
+  static constexpr int NUM_LOADS = Tile::NUMEL / ELEMENTS_PER_LOAD;
+  static constexpr int NUM_LOADS_PER_THREAD = NUM_LOADS / NUM_THREADS;
+  static constexpr int NUM_LOADS_PER_ROW = Tile::COLS / ELEMENTS_PER_LOAD;
+  static constexpr int STEP_ROWS = NUM_THREADS / NUM_LOADS_PER_ROW;
+
+  const T* x_;
+  int N_;
+  uint32_t offset_;
+
+  __device__ SharedTileLoader(const T* x, int N) : x_(x), N_(N) {
+    const int row = threadIdx.x / NUM_LOADS_PER_ROW;
+    const int col = threadIdx.x % NUM_LOADS_PER_ROW;
+
+    x_ += row * N + col * ELEMENTS_PER_LOAD;
+    offset_ = Tile::offset(row, col * ELEMENTS_PER_LOAD);
+  }
+
+  __device__ inline void load_async(uint32_t base_address) {
+    MLX_UNROLL
+    for (int i = 0; i < NUM_LOADS_PER_THREAD; i++) {
+      cp_async<16>(
+          base_address + offset_ + i * STEP_ROWS * sizeof(T) * Tile::COLS,
+          x_ + i * STEP_ROWS * N_);
+    }
+  }
+
+  __device__ inline void next() {
+    x_ += Tile::COLS;
+  }
+};
+
+template <typename Tile>
+struct RegisterTileLoader {
+  using T = typename Tile::value_type;
+
+  uint32_t offset_[Tile::COLS / 16];
+
+  __device__ RegisterTileLoader(int offset_row, int laneid) {
+    const int row = offset_row + (laneid & 15);
+    const int col = (laneid >> 4) << 3;
+
+    MLX_UNROLL
+    for (int i = 0; i < Tile::COLS / 16; i++) {
+      offset_[i] = Tile::offset(row, col + i * 16);
+    }
+  }
+
+  template <typename T, int ROWS, int COLS>
+  __device__ inline void
+  load(RegisterTile<T, ROWS, COLS>& x, uint32_t base_address, int col) {
+    constexpr int TILES_Y = RegisterTile<T, ROWS, COLS>::TILES_Y;
+    constexpr int TILES_X = RegisterTile<T, ROWS, COLS>::TILES_X;
+
+    MLX_UNROLL
+    for (int i = 0; i < TILES_Y; i++) {
+      MLX_UNROLL
+      for (int j = 0; j < TILES_X; j++) {
+        x.data[i * TILES_X + j].load(
+            base_address + offset_[j + col] + i * 16 * Tile::COLS * sizeof(T));
+      }
+    }
+  }
+};
+
 /**
  * Load the tile from global memory by loading 16 bytes at a time and storing
  * them immediately.

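
For the configuration used by the new kernel (256 threads, 128 x 32 tiles of a 2-byte type), the SharedTileLoader constants work out as below; the 2-byte element size is an assumption, the formulas are the ones above. Each thread ends up issuing two 16-byte cp.async per tile per stage, 64 rows apart, which is the SSTEP stride used in the kernel's prologue.

    // Per-thread cp.async pattern for NUM_WARPS = 8, Tile = SharedTile<half, 128, 32>.
    constexpr int NUM_THREADS = 8 * 32;                            // 256
    constexpr int ELEMENTS_PER_LOAD = 16 / 2;                      // 8 elements per 16-byte copy
    constexpr int NUM_LOADS = 128 * 32 / ELEMENTS_PER_LOAD;        // 512 copies per tile
    constexpr int NUM_LOADS_PER_THREAD = NUM_LOADS / NUM_THREADS;  // 2
    constexpr int NUM_LOADS_PER_ROW = 32 / ELEMENTS_PER_LOAD;      // 4
    constexpr int STEP_ROWS = NUM_THREADS / NUM_LOADS_PER_ROW;     // 64 rows between copies
    constexpr int SSTEP = STEP_ROWS * 2 * 32;                      // 4096-byte shared-memory stride
    static_assert(NUM_LOADS_PER_THREAD == 2 && STEP_ROWS == 64 && SSTEP == 4096, "");
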
@@ -21,15 +21,15 @@ __device__ inline void cp_async(uint32_t row_address, const T* x) {
 #if defined(MLX_CUDA_SM_80_ENABLED)
   if constexpr (N == 16) {
     asm volatile(
-        "cp.async.ca.shared::cta.global [%0], [%1], 16;\n" ::"r"(row_address),
+        "cp.async.cg.shared::cta.global [%0], [%1], 16;\n" ::"r"(row_address),
         "l"(reinterpret_cast<const int4*>(x)));
   } else if constexpr (N == 8) {
     asm volatile(
-        "cp.async.ca.shared::cta.global [%0], [%1], 8;\n" ::"r"(row_address),
+        "cp.async.cg.shared::cta.global [%0], [%1], 8;\n" ::"r"(row_address),
         "l"(reinterpret_cast<const int2*>(x)));
   } else if constexpr (N == 4) {
     asm volatile(
-        "cp.async.ca.shared::cta.global [%0], [%1], 4;\n" ::"r"(row_address),
+        "cp.async.cg.shared::cta.global [%0], [%1], 4;\n" ::"r"(row_address),
         "l"(reinterpret_cast<const int*>(x)));
   }
 #endif
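
On the .ca to .cg switch: cp.async.ca allocates the incoming data in both L1 and L2, while cp.async.cg bypasses L1 and caches only in L2, which suits GEMM operands that are consumed out of shared memory and never re-read through L1. The PTX ISA defines the .cg variant only for a 16-byte cp-size, whereas .ca accepts 4, 8, or 16; the loaders in this change only appear to call cp_async<16>, so only the first branch should be exercised. A hedged sketch (not part of the commit) of how the narrower widths could keep .ca if they are ever needed:

    // Sketch only: .cg for the 16-byte path, .ca for 4/8-byte copies,
    // since PTX restricts cp.async.cg to a 16-byte cp-size.
    template <int N, typename T>
    __device__ inline void cp_async_sketch(uint32_t dst, const T* src) {
      if constexpr (N == 16) {
        asm volatile("cp.async.cg.shared::cta.global [%0], [%1], 16;\n" ::"r"(dst),
                     "l"(reinterpret_cast<const int4*>(src)));
      } else if constexpr (N == 8) {
        asm volatile("cp.async.ca.shared::cta.global [%0], [%1], 8;\n" ::"r"(dst),
                     "l"(reinterpret_cast<const int2*>(src)));
      } else if constexpr (N == 4) {
        asm volatile("cp.async.ca.shared::cta.global [%0], [%1], 4;\n" ::"r"(dst),
                     "l"(reinterpret_cast<const int*>(src)));
      }
    }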