From 1523b803f31c5924aa6f25e8a09ca29d2e7b1a91 Mon Sep 17 00:00:00 2001
From: Angelos Katharopoulos
Date: Tue, 29 Jul 2025 16:47:45 -0700
Subject: [PATCH] Early stages of the steel utils

---
 mlx/backend/cuda/gemms/gemv.cu   |   2 +
 mlx/backend/cuda/steel/gemm.cuh  | 101 ++++++++
 mlx/backend/cuda/steel/mma.cuh   | 114 +++++++++
 mlx/backend/cuda/steel/tiles.cuh | 420 +++++++++++++++++++++++++++++++
 mlx/backend/cuda/steel/utils.cuh |  85 +++++++
 5 files changed, 722 insertions(+)
 create mode 100644 mlx/backend/cuda/steel/gemm.cuh
 create mode 100644 mlx/backend/cuda/steel/mma.cuh
 create mode 100644 mlx/backend/cuda/steel/tiles.cuh
 create mode 100644 mlx/backend/cuda/steel/utils.cuh

diff --git a/mlx/backend/cuda/gemms/gemv.cu b/mlx/backend/cuda/gemms/gemv.cu
index 55333adea..552ab9cda 100644
--- a/mlx/backend/cuda/gemms/gemv.cu
+++ b/mlx/backend/cuda/gemms/gemv.cu
@@ -143,6 +143,7 @@ void gemv(
         kernel,
         num_blocks_x,
         block_dims,
+        0,
         mat,
         vec,
         out.data<T>(),
@@ -154,6 +155,7 @@ void gemv(
         kernel,
         dim3{num_blocks_x, batch_count},
         block_dims,
+        0,
         mat,
         vec,
         out.data<T>(),
diff --git a/mlx/backend/cuda/steel/gemm.cuh b/mlx/backend/cuda/steel/gemm.cuh
new file mode 100644
index 000000000..99580d2de
--- /dev/null
+++ b/mlx/backend/cuda/steel/gemm.cuh
@@ -0,0 +1,101 @@
+#pragma once
+
+#include "mlx/backend/cuda/steel/mma.cuh"
+#include "mlx/backend/cuda/steel/tiles.cuh"
+
+namespace mlx::core::cu {
+
+/**
+ * An example gemm written with the utils.
+ *
+ * Computes A @ B.T when the dimensions of A and B are aligned with the block
+ * sizes, so no bounds checking is performed.
+ */
+template <typename T, int BM, int BN, int BK>
+__global__ void ab_t_aligned(const T* a, const T* b, T* y, int N, int K) {
+  constexpr int WARPS_M = 2;
+  constexpr int WARPS_N = 2;
+  constexpr int NUM_WARPS = WARPS_M * WARPS_N;
+  constexpr int WARP_STEP_M = BM / WARPS_M;
+  constexpr int WARP_STEP_N = BN / WARPS_N;
+
+  // Precompute some offsets for each thread
+  const int warpid = threadIdx.x / 32;
+  const int laneid = threadIdx.x % 32;
+  const int wm = warpid / WARPS_N;
+  const int wn = warpid % WARPS_N;
+  const int offset_m = wm * WARP_STEP_M;
+  const int offset_n = wn * WARP_STEP_N;
+
+  // Allocate shared memory
+  extern __shared__ char shmem[];
+  SharedTile<T, BM, BK>(&as)[2] = *(SharedTile<T, BM, BK>(*)[2])(&shmem[0]);
+  SharedTile<T, BN, BK>(&bs)[2] =
+      *(SharedTile<T, BN, BK>(*)[2])(&shmem[sizeof(T) * 2 * BM * BK]);
+
+  // Allocate registers for the MMA
+  RegisterTile<float, BM / WARPS_M, BN / WARPS_N> C;
+  RegisterTile<T, BM / WARPS_M, 16> A;
+  RegisterTile<T, BN / WARPS_N, 16> B;
+
+  // Move the global pointers to the tile
+  a += blockIdx.y * BM * K;
+  b += blockIdx.x * BN * K;
+  y += blockIdx.y * BM * N + blockIdx.x * BN;
+
+  // Zero the accumulators
+  C.fill(0);
+
+  // Start the SM pipeline
+  load_async<NUM_WARPS>(as[0], as[0].base_addr(), a, K);
+  load_async<NUM_WARPS>(bs[0], bs[0].base_addr(), b, K);
+  cp_async_commit();
+
+  int tic = 0;
+  for (int k_block = BK; k_block < K; k_block += BK) {
+    load_async<NUM_WARPS>(as[tic ^ 1], as[tic ^ 1].base_addr(), a + k_block, K);
+    load_async<NUM_WARPS>(bs[tic ^ 1], bs[tic ^ 1].base_addr(), b + k_block, K);
+    cp_async_commit();
+    cp_async_wait<1>();
+    __syncthreads();
+
+    MLX_UNROLL
+    for (int k = 0; k < BK / 16; k++) {
+      A.load(
+          as[tic],
+          as[tic].base_addr(),
+          offset_m + laneid % 16,
+          k * 16 + laneid / 16 * 8);
+      B.load(
+          bs[tic],
+          bs[tic].base_addr(),
+          offset_n + laneid % 16,
+          k * 16 + laneid / 16 * 8);
+
+      mma_t(C, A, B);
+    }
+
+    tic ^= 1;
+  }
+
+  // Empty the pipeline
+  cp_async_wait_all();
+  __syncthreads();
+  MLX_UNROLL
+  for (int k = 0; k < BK / 16; k++) {
+    A.load(
+        as[tic],
+        as[tic].base_addr(),
+        offset_m + laneid % 16,
+        k * 16 + laneid / 16 * 8);
+    B.load(
+        bs[tic],
+        bs[tic].base_addr(),
+        offset_n + laneid % 16,
+        k * 16 + laneid / 16 * 8);
+
+    mma_t(C, A, B);
+  }
+
+  C.store_global(y, N, offset_m, offset_n);
+}
+
+} // namespace mlx::core::cu
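+
+// Rough host-side launch sketch for the kernel above (illustrative only; the
+// sizes, pointers and stream are assumptions of this example). The grid tiles
+// the output, each block runs NUM_WARPS = 4 warps, and the dynamic shared
+// memory holds the two double-buffered input tiles.
+//
+//   constexpr int BM = 64, BN = 64, BK = 32;
+//   dim3 grid(N / BN, M / BM);
+//   dim3 block(4 * 32);
+//   size_t smem = sizeof(__nv_bfloat16) * 2 * (BM + BN) * BK;
+//   ab_t_aligned<__nv_bfloat16, BM, BN, BK>
+//       <<<grid, block, smem, stream>>>(a, b, y, N, K);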
diff --git a/mlx/backend/cuda/steel/mma.cuh b/mlx/backend/cuda/steel/mma.cuh
new file mode 100644
index 000000000..42d3c9040
--- /dev/null
+++ b/mlx/backend/cuda/steel/mma.cuh
@@ -0,0 +1,114 @@
+// Copyright © 2025 Apple Inc.
+
+#pragma once
+
+#include "mlx/backend/cuda/steel/tiles.cuh"
+
+namespace mlx::core::cu {
+
+/**
+ * Fallback mma.
+ *
+ * We should probably either a) implement a real fallback or b) complain about
+ * it at compile time.
+ */
+template <typename T, typename U>
+__device__ inline void
+mma_t(Tile16x16<U>& C, Tile16x16<T>& A, Tile16x16<T>& B) {}
+
+/**
+ * Multiply the 16x16 bfloat16 tiles and accumulate the result in one 16x16
+ * float tile.
+ *
+ * We actually perform C += A @ B.T
+ */
+__device__ __forceinline__ void mma_t(
+    Tile16x16<float>& C,
+    Tile16x16<__nv_bfloat16>& A,
+    Tile16x16<__nv_bfloat16>& B) {
+  asm volatile(
+      "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
+      "{%0, %1, %2, %3}, "
+      "{%4, %5, %6, %7}, "
+      "{%8, %9}, "
+      "{%10, %11, %12, %13};"
+
+      // D matrix
+      : "+f"(C.values[0].x),
+        "+f"(C.values[0].y),
+        "+f"(C.values[1].x),
+        "+f"(C.values[1].y)
+
+      // A matrix
+      : "r"(*(uint32_t*)(&A.values[0])),
+        "r"(*(uint32_t*)(&A.values[1])),
+        "r"(*(uint32_t*)(&A.values[2])),
+        "r"(*(uint32_t*)(&A.values[3])),
+
+        // B matrix
+        "r"(*(uint32_t*)(&B.values[0])),
+        "r"(*(uint32_t*)(&B.values[2])),
+
+        // C matrix
+        "f"(C.values[0].x),
+        "f"(C.values[0].y),
+        "f"(C.values[1].x),
+        "f"(C.values[1].y));
+  asm volatile(
+      "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
+      "{%0, %1, %2, %3}, "
+      "{%4, %5, %6, %7}, "
+      "{%8, %9}, "
+      "{%10, %11, %12, %13};"
+
+      // D matrix
+      : "+f"(C.values[2].x),
+        "+f"(C.values[2].y),
+        "+f"(C.values[3].x),
+        "+f"(C.values[3].y)
+
+      // A matrix
+      : "r"(*(uint32_t*)(&A.values[0])),
+        "r"(*(uint32_t*)(&A.values[1])),
+        "r"(*(uint32_t*)(&A.values[2])),
+        "r"(*(uint32_t*)(&A.values[3])),
+
+        // B matrix
+        "r"(*(uint32_t*)(&B.values[1])),
+        "r"(*(uint32_t*)(&B.values[3])),
+
+        // C matrix
+        "f"(C.values[2].x),
+        "f"(C.values[2].y),
+        "f"(C.values[3].x),
+        "f"(C.values[3].y));
+}
+
+/**
+ * Multiply larger register tiles by delegating to the 16x16 mma_t.
+ */
+template <typename T, typename U, int M, int N, int K>
+__device__ __forceinline__ void mma_t(
+    RegisterTile<U, M, N>& C,
+    RegisterTile<T, M, K>& A,
+    RegisterTile<T, N, K>& B) {
+  constexpr int TILES_M = RegisterTile<U, M, N>::TILES_Y;
+  constexpr int TILES_K = RegisterTile<T, M, K>::TILES_X;
+  constexpr int TILES_N = RegisterTile<T, N, K>::TILES_Y;
+
+  MLX_UNROLL
+  for (int k = 0; k < TILES_K; k++) {
+    MLX_UNROLL
+    for (int m = 0; m < TILES_M; m++) {
+      MLX_UNROLL
+      for (int n = 0; n < TILES_N; n++) {
+        mma_t(
+            C.data[m * TILES_N + n],
+            A.data[m * TILES_K + k],
+            B.data[n * TILES_K + k]);
+      }
+    }
+  }
+}
+
+} // namespace mlx::core::cu
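+
+// Worked example (illustrative): for C : RegisterTile<float, 32, 32>,
+// A : RegisterTile<__nv_bfloat16, 32, 16> and B : RegisterTile<__nv_bfloat16, 32, 16>
+// we get TILES_M = TILES_N = 2 and TILES_K = 1, so mma_t(C, A, B) expands to
+// 2 * 2 * 1 = 4 calls of the 16x16 mma_t above, i.e. 8 mma.sync instructions.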
diff --git a/mlx/backend/cuda/steel/tiles.cuh b/mlx/backend/cuda/steel/tiles.cuh
new file mode 100644
index 000000000..a44113e6b
--- /dev/null
+++ b/mlx/backend/cuda/steel/tiles.cuh
@@ -0,0 +1,420 @@
+// Copyright © 2025 Apple Inc.
+
+#pragma once
+
+#include "mlx/backend/cuda/steel/utils.cuh"
+
+namespace mlx::core::cu {
+
+// Map types to their vector-of-2 type, e.g. float -> float2, double -> double2
+template <typename T>
+struct Vector2;
+template <>
+struct Vector2<double> {
+  using type = double2;
+};
+template <>
+struct Vector2<float> {
+  using type = float2;
+};
+template <>
+struct Vector2<__half> {
+  using type = __half2;
+};
+template <>
+struct Vector2<__nv_bfloat16> {
+  using type = __nv_bfloat162;
+};
+template <typename T>
+using Vector2_t = typename Vector2<T>::type;
+
+/**
+ * The basic building block for Ampere mmas. A 16x16 tile distributed across
+ * the warp.
+ *
+ * Each thread holds 8 values. They are distributed according to
+ * https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-fragment-mma-16816-float
+ *
+ * For usage instructions see the individual methods, e.g. load().
+ */
+template <typename T>
+struct Tile16x16 {
+  using T2 = Vector2_t<T>;
+
+  T2 values[4];
+
+  __device__ inline void fill(T v) {
+    T2 v2 = {v, v};
+    for (int i = 0; i < 4; i++) {
+      values[i] = v2;
+    }
+  }
+
+  /**
+   * Load a 16x16 tile from shared memory.
+   *
+   * The instruction is a bit peculiar in the sense that the address provided
+   * by a thread and the elements that end up in its registers are unrelated.
+   *
+   * We load four 8x8 tiles. Each tile row is stored contiguously in memory.
+   * As a result the warp provides 4 * 8 = 32 addresses, one per row.
+   *
+   * Threads 0-7 provide the addresses for the first tile, 8-15 for the second
+   * and so on. For instance, to load a non-swizzled tile we would do
+   *
+   *    base_addr + (laneid % 16) * BK + (laneid / 16) * 8
+   *
+   * See
+   * https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-instructions-ldmatrix
+   */
+  __device__ __forceinline__ void load(uint32_t row_address) {
+    if constexpr (
+        std::is_same_v<T, __half> || std::is_same_v<T, __nv_bfloat16>) {
+      asm volatile(
+          "ldmatrix.sync.aligned.m8n8.x4.shared::cta.b16 {%0, %1, %2, %3}, [%4];\n"
+          : "=r"(*(uint32_t*)&(values[0])),
+            "=r"(*(uint32_t*)&(values[1])),
+            "=r"(*(uint32_t*)&(values[2])),
+            "=r"(*(uint32_t*)&(values[3]))
+          : "r"(row_address));
+    }
+  }
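+
+  // Worked example (illustrative): for BK = 32 and laneid = 20 the formula
+  // above gives (20 % 16) * 32 + (20 / 16) * 8 = 136, i.e. element (4, 8) of
+  // the shared tile (in elements, before SharedTile::loc applies the swizzle
+  // and the sizeof(T) scaling). Independently of the address it provides,
+  // after the load lane 20 holds, per the fragment layout linked above,
+  //   values[0] = (5, 0), (5, 1)     values[2] = (5, 8), (5, 9)
+  //   values[1] = (13, 0), (13, 1)   values[3] = (13, 8), (13, 9)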
+
+  /**
+   * Store the tile to the address pointed to by `x`.
+   *
+   * The provided pointer is a generic pointer but this is meant to be used to
+   * store to global memory. For storing to shared memory we should use
+   * `stmatrix`.
+   *
+   * This also showcases the format of the tile quite nicely. Each register
+   * holds two adjacent values. The indices are
+   *
+   *    row + 0, col + 0
+   *    row + 8, col + 0
+   *    row + 0, col + 8
+   *    row + 8, col + 8
+   *
+   * Given that we are dealing with Vector2_t<U>, the column offsets are 4
+   * instead of 8.
+   */
+  template <typename U>
+  __device__ inline void store_global(U* x, int N) {
+    using U2 = Vector2_t<U>;
+    U2* x2 = reinterpret_cast<U2*>(x);
+    const int laneid = threadIdx.x % 32;
+    const int row = laneid / 4;
+    const int col = laneid % 4;
+    if constexpr (std::is_same_v<U, T>) {
+      x2[(row + 0) * (N / 2) + col + 0] = values[0];
+      x2[(row + 0) * (N / 2) + col + 4] = values[2];
+      x2[(row + 8) * (N / 2) + col + 0] = values[1];
+      x2[(row + 8) * (N / 2) + col + 4] = values[3];
+    } else if constexpr (
+        std::is_same_v<U, __nv_bfloat16> && std::is_same_v<T, float>) {
+      x2[(row + 0) * (N / 2) + col + 0] =
+          __floats2bfloat162_rn(values[0].x, values[0].y);
+      x2[(row + 0) * (N / 2) + col + 4] =
+          __floats2bfloat162_rn(values[2].x, values[2].y);
+      x2[(row + 8) * (N / 2) + col + 0] =
+          __floats2bfloat162_rn(values[1].x, values[1].y);
+      x2[(row + 8) * (N / 2) + col + 4] =
+          __floats2bfloat162_rn(values[3].x, values[3].y);
+    }
+  }
+
+  template <typename U>
+  __device__ inline void store_global_safe(U* x, int N, int max_rows) {
+    const int laneid = threadIdx.x % 32;
+    const int row = laneid / 4;
+    const int col = laneid % 4;
+    if (row < max_rows) {
+      x[(row + 0) * N + 2 * col + 0] = static_cast<U>(values[0].x);
+      x[(row + 0) * N + 2 * col + 1] = static_cast<U>(values[0].y);
+      x[(row + 0) * N + 2 * col + 8] = static_cast<U>(values[2].x);
+      x[(row + 0) * N + 2 * col + 9] = static_cast<U>(values[2].y);
+    }
+    if (row + 8 < max_rows) {
+      x[(row + 8) * N + 2 * col + 0] = static_cast<U>(values[1].x);
+      x[(row + 8) * N + 2 * col + 1] = static_cast<U>(values[1].y);
+      x[(row + 8) * N + 2 * col + 8] = static_cast<U>(values[3].x);
+      x[(row + 8) * N + 2 * col + 9] = static_cast<U>(values[3].y);
+    }
+  }
+};
+
+/**
+ * A simple container of multiple Tile16x16.
+ *
+ * Provides utility functions for loading and manipulating collections of
+ * basic tiles.
+ */
+template <typename T, int ROWS_, int COLS_>
+struct RegisterTile {
+  static constexpr int ROWS = ROWS_;
+  static constexpr int COLS = COLS_;
+  static constexpr int TILES_X = COLS / 16;
+  static constexpr int TILES_Y = ROWS / 16;
+
+  Tile16x16<T> data[TILES_X * TILES_Y];
+
+  __device__ inline void fill(T v) {
+    MLX_UNROLL
+    for (int i = 0; i < TILES_Y; i++) {
+      MLX_UNROLL
+      for (int j = 0; j < TILES_X; j++) {
+        data[i * TILES_X + j].fill(v);
+      }
+    }
+  }
+
+  template <typename Tile>
+  __device__ __forceinline__ void
+  load(Tile& tile, uint32_t base_address, int row, int col) {
+    MLX_UNROLL
+    for (int i = 0; i < TILES_Y; i++) {
+      MLX_UNROLL
+      for (int j = 0; j < TILES_X; j++) {
+        data[i * TILES_X + j].load(
+            tile.loc(base_address, row + i * 16, col + j * 16));
+      }
+    }
+  }
+
+  template <typename Tile, typename F>
+  __device__ __forceinline__ void
+  load(Tile& tile, F f, uint32_t base_address, int row, int col) {
+    MLX_UNROLL
+    for (int i = 0; i < TILES_Y; i++) {
+      MLX_UNROLL
+      for (int j = 0; j < TILES_X; j++) {
+        f(data[i * TILES_X + j],
+          tile,
+          base_address,
+          row + i * 16,
+          col + j * 16);
+      }
+    }
+  }
+
+  template <typename U>
+  __device__ inline void store_global(U* x, int N, int row, int col) {
+    MLX_UNROLL
+    for (int i = 0; i < TILES_Y; i++) {
+      MLX_UNROLL
+      for (int j = 0; j < TILES_X; j++) {
+        data[i * TILES_X + j].store_global(
+            x + (row + i * 16) * N + col + j * 16, N);
+      }
+    }
+  }
+
+  template <typename U>
+  __device__ inline void
+  store_global_safe(U* x, int N, int row, int col, int max_rows) {
+    MLX_UNROLL
+    for (int i = 0; i < TILES_Y; i++) {
+      MLX_UNROLL
+      for (int j = 0; j < TILES_X; j++) {
+        data[i * TILES_X + j].store_global_safe(
+            x + (row + i * 16) * N + col + j * 16, N, max_rows - row - i * 16);
+      }
+    }
+  }
+};
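+
+// Worked example (illustrative): RegisterTile<float, 32, 64> is a 2x4 grid of
+// Tile16x16<float> fragments, i.e. 8 fragments * 8 values = 64 floats held in
+// registers by every thread of the warp. In the gemm example each warp owns
+// one RegisterTile<float, BM / WARPS_M, BN / WARPS_N> accumulator of this
+// kind, so the block tile sizes directly set the register pressure.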
+
+template <typename T, int ROWS_, int COLS_>
+struct SharedTile {
+  static constexpr int ROWS = ROWS_;
+  static constexpr int COLS = COLS_;
+  static constexpr int TILES_X = COLS / 16;
+  static constexpr int TILES_Y = ROWS / 16;
+  static constexpr int NUMEL = ROWS * COLS;
+
+  // Swizzle taken from ThunderKittens. Should be changed when we switch to
+  // cute Layouts.
+  //
+  // See include/types/shared/st.cuh
+  //
+  // It does feel too math heavy and could be improved. Also the math is done
+  // on every access even though the addresses don't change from load to load;
+  // we are relying on the compiler to figure that out.
+  static constexpr int swizzle_bytes =
+      (sizeof(T) == 2 ? (TILES_X % 4 == 0 ? 128 : (TILES_X % 2 == 0 ? 64 : 32))
+                      : (sizeof(T) == 4 ? (TILES_X % 2 == 0 ? 128 : 64) : 0));
+
+  T data[ROWS * COLS];
+
+  __device__ inline uint32_t base_addr() const {
+    return __cvta_generic_to_shared(&data[0]);
+  }
+
+  // Return a pointer to the element at (row, col) using the swizzle.
+  __device__ static inline T* ptr(T* ptr, int row, int col) {
+    if constexpr (swizzle_bytes > 0) {
+      static constexpr int swizzle_repeat = swizzle_bytes * 8;
+      static constexpr int subtile_cols = swizzle_bytes / sizeof(T);
+      const int outer_idx = col / subtile_cols;
+      const uint64_t addr =
+          (uint64_t)(&ptr
+                         [outer_idx * ROWS * subtile_cols + row * subtile_cols +
+                          col % subtile_cols]);
+      const int swizzle = ((addr % swizzle_repeat) >> 7) << 4;
+      return (T*)(addr ^ swizzle);
+    } else {
+      return ptr + row * COLS + col;
+    }
+  }
+
+  // Return the location of the element at (row, col) using the swizzle.
+  __device__ static inline uint32_t loc(uint32_t ptr, int row, int col) {
+    if constexpr (swizzle_bytes > 0) {
+      static constexpr int swizzle_repeat = swizzle_bytes * 8;
+      static constexpr int subtile_cols = swizzle_bytes / sizeof(T);
+      const int outer_idx = col / subtile_cols;
+      const uint32_t addr = ptr +
+          sizeof(T) *
+              (outer_idx * ROWS * subtile_cols + row * subtile_cols +
+               col % subtile_cols);
+      const int swizzle = ((addr % swizzle_repeat) >> 7) << 4;
+      return (addr ^ swizzle);
+    } else {
+      return ptr + sizeof(T) * (row * COLS + col);
+    }
+  }
+
+  // Convenience functions to edit elements going through the swizzle.
+  __device__ inline T& operator()(int row, int col) {
+    return *ptr(data, row, col);
+  }
+  __device__ inline void store(float4& v, int row, int col) {
+    *(reinterpret_cast<float4*>(ptr(data, row, col))) = v;
+  }
+  __device__ inline void store(float2& v, int row, int col) {
+    *(reinterpret_cast<float2*>(ptr(data, row, col))) = v;
+  }
+  __device__ inline void store(float& v, int row, int col) {
+    *(reinterpret_cast<float*>(ptr(data, row, col))) = v;
+  }
+  template <int N>
+  __device__ inline void store(T (&v)[N], int row, int col) {
+    if constexpr (sizeof(T) * N == 4) {
+      store(*(reinterpret_cast<float*>(&v[0])), row, col);
+    } else if constexpr (sizeof(T) * N == 8) {
+      store(*(reinterpret_cast<float2*>(&v[0])), row, col);
+    } else if constexpr (sizeof(T) * N == 16) {
+      store(*(reinterpret_cast<float4*>(&v[0])), row, col);
+    } else {
+      MLX_UNROLL
+      for (int i = 0; i < N; i++) {
+        *ptr(data, row, col + i) = v[i];
+      }
+    }
+  }
+};
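+
+// Worked example (illustrative): for T = __nv_bfloat16 and COLS = 64 we get
+// swizzle_bytes = 128 and subtile_cols = 64, so a row occupies exactly 128
+// bytes and outer_idx is always 0. Assuming the tile base is 1024-byte
+// aligned, loc(row, col) works out to
+//
+//   base + row * 128 + (((col / 8) ^ (row % 8)) * 8 + (col % 8)) * 2
+//
+// i.e. the swizzle permutes the eight 16-byte chunks of each row by XORing
+// their index with row % 8, which helps avoid shared memory bank conflicts
+// for the ldmatrix loads.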
+
+/**
+ * Load the tile from global memory by loading 16 bytes at a time and storing
+ * them immediately.
+ *
+ * Can also be used as a fallback for architectures before sm_80.
+ */
+template <int NUM_WARPS, typename Tile, typename T>
+__device__ inline void load(Tile& tile, const T* x, int N) {
+  constexpr int NUM_THREADS = NUM_WARPS * 32;
+  constexpr int ELEMENTS_PER_LOAD = sizeof(float4) / sizeof(T);
+  constexpr int NUM_LOADS = Tile::NUMEL / ELEMENTS_PER_LOAD;
+  constexpr int NUM_LOADS_PER_THREAD = NUM_LOADS / NUM_THREADS;
+  constexpr int NUM_LOADS_PER_ROW = Tile::COLS / ELEMENTS_PER_LOAD;
+  constexpr int STEP_ROWS = NUM_THREADS / NUM_LOADS_PER_ROW;
+
+  const int row = threadIdx.x / NUM_LOADS_PER_ROW;
+  const int col = threadIdx.x % NUM_LOADS_PER_ROW;
+
+  x += row * N + col * ELEMENTS_PER_LOAD;
+
+  MLX_UNROLL
+  for (int i = 0; i < NUM_LOADS_PER_THREAD; i++) {
+    float4 tmp;
+    tmp = *(reinterpret_cast<const float4*>(&x[i * STEP_ROWS * N]));
+    tile.store(tmp, row + i * STEP_ROWS, col * ELEMENTS_PER_LOAD);
+  }
+}
+
+/**
+ * The asynchronous equivalent of load.
+ *
+ * Loads the tile from global memory by issuing a bunch of async copy
+ * instructions. The copies are not guaranteed to have started until commit is
+ * called and are not guaranteed to have finished until wait is called.
+ *
+ * It should be used as follows
+ *
+ *    load_async(...)
+ *    load_async(...)
+ *    cp_async_commit()
+ *    do_other_stuff()
+ *    cp_async_wait_all()
+ *    do_stuff_with_shmem()
+ */
+template <int NUM_WARPS, typename Tile, typename T>
+__device__ inline void
+load_async(Tile& tile, uint32_t base_address, const T* x, int N) {
+  constexpr int NUM_THREADS = NUM_WARPS * 32;
+  constexpr int ELEMENTS_PER_LOAD = sizeof(float4) / sizeof(T);
+  constexpr int NUM_LOADS = Tile::NUMEL / ELEMENTS_PER_LOAD;
+  constexpr int NUM_LOADS_PER_THREAD = NUM_LOADS / NUM_THREADS;
+  constexpr int NUM_LOADS_PER_ROW = Tile::COLS / ELEMENTS_PER_LOAD;
+  constexpr int STEP_ROWS = NUM_THREADS / NUM_LOADS_PER_ROW;
+
+  const int row = threadIdx.x / NUM_LOADS_PER_ROW;
+  const int col = threadIdx.x % NUM_LOADS_PER_ROW;
+
+  x += row * N + col * ELEMENTS_PER_LOAD;
+
+  MLX_UNROLL
+  for (int i = 0; i < NUM_LOADS_PER_THREAD; i++) {
+    cp_async<16>(
+        tile.loc(base_address, row + i * STEP_ROWS, col * ELEMENTS_PER_LOAD),
+        x + i * STEP_ROWS * N);
+  }
+}
+
+/**
+ * Same as load_async but checks whether each row is within bounds.
+ *
+ * NOTE: It should be changed to use a predicated cp.async instead.
+ */
+template <int NUM_WARPS, typename Tile, typename T>
+__device__ inline void load_async_safe(
+    Tile& tile,
+    uint32_t base_address,
+    const T* x,
+    int N,
+    int max_rows) {
+  constexpr int NUM_THREADS = NUM_WARPS * 32;
+  constexpr int ELEMENTS_PER_LOAD = sizeof(float4) / sizeof(T);
+  constexpr int NUM_LOADS = Tile::NUMEL / ELEMENTS_PER_LOAD;
+  constexpr int NUM_LOADS_PER_THREAD = NUM_LOADS / NUM_THREADS;
+  constexpr int NUM_LOADS_PER_ROW = Tile::COLS / ELEMENTS_PER_LOAD;
+  constexpr int STEP_ROWS = NUM_THREADS / NUM_LOADS_PER_ROW;
+
+  const int row = threadIdx.x / NUM_LOADS_PER_ROW;
+  const int col = threadIdx.x % NUM_LOADS_PER_ROW;
+
+  x += row * N + col * ELEMENTS_PER_LOAD;
+
+  MLX_UNROLL
+  for (int i = 0; i < NUM_LOADS_PER_THREAD; i++) {
+    if (row + i * STEP_ROWS < max_rows) {
+      cp_async<16>(
+          tile.loc(base_address, row + i * STEP_ROWS, col * ELEMENTS_PER_LOAD),
+          x + i * STEP_ROWS * N);
+    } else {
+      float4 tmp = {0, 0, 0, 0};
+      tile.store(tmp, row + i * STEP_ROWS, col * ELEMENTS_PER_LOAD);
+    }
+  }
+}
+
+} // namespace mlx::core::cu
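+
+// Worked example (illustrative): loading a 128x32 SharedTile<__nv_bfloat16>
+// with NUM_WARPS = 4 (128 threads). ELEMENTS_PER_LOAD = 16 / 2 = 8, so
+// NUM_LOADS_PER_ROW = 32 / 8 = 4, STEP_ROWS = 128 / 4 = 32 and
+// NUM_LOADS_PER_THREAD = (128 * 32 / 8) / 128 = 4. Each thread therefore
+// issues four 16-byte cp.async copies, covering rows r, r + 32, r + 64 and
+// r + 96 of the tile.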
diff --git a/mlx/backend/cuda/steel/utils.cuh b/mlx/backend/cuda/steel/utils.cuh
new file mode 100644
index 000000000..cfa8c0ad5
--- /dev/null
+++ b/mlx/backend/cuda/steel/utils.cuh
@@ -0,0 +1,85 @@
+// Copyright © 2025 Apple Inc.
+
+#pragma once
+
+#include "mlx/backend/cuda/device/utils.cuh"
+
+#define MLX_UNROLL _Pragma("unroll")
+
+namespace mlx::core::cu {
+
+/**
+ * Copy N bytes from the global memory address pointed to by x to the shared
+ * memory address given by row_address.
+ *
+ * A simple wrapper over the PTX.
+ */
+template <int N, typename T>
+__device__ inline void cp_async(uint32_t row_address, const T* x) {
+  static_assert(
+      N == 16 || N == 8 || N == 4,
+      "cp.async is only supported for N in {4, 8, 16}.");
+
+  if constexpr (N == 16) {
+    asm volatile(
+        "cp.async.ca.shared::cta.global [%0], [%1], 16;\n" ::"r"(row_address),
+        "l"(reinterpret_cast<uint64_t>(x)));
+  } else if constexpr (N == 8) {
+    asm volatile(
+        "cp.async.ca.shared::cta.global [%0], [%1], 8;\n" ::"r"(row_address),
+        "l"(reinterpret_cast<uint64_t>(x)));
+  } else if constexpr (N == 4) {
+    asm volatile(
+        "cp.async.ca.shared::cta.global [%0], [%1], 4;\n" ::"r"(row_address),
+        "l"(reinterpret_cast<uint64_t>(x)));
+  }
+}
+
+/**
+ * Commit all the previously issued async copies into a group that can be
+ * waited on.
+ */
+__device__ inline void cp_async_commit() {
+  asm volatile("cp.async.commit_group;\n" ::);
+}
+
+/**
+ * Wait until all but the last N groups of async copies have finished.
+ */
+template <int N>
+__device__ inline void cp_async_wait() {
+  if constexpr (N == 0) {
+    asm volatile("cp.async.wait_all;\n" ::);
+  } else {
+    asm volatile("cp.async.wait_group %0;\n" ::"n"(N));
+  }
+}
+
+/**
+ * Wait for all the async copies to finish.
+ */
+__device__ inline void cp_async_wait_all() {
+  cp_async_wait<0>();
+}
+
+/**
+ * Extract ``bits`` bits from the 32 bit value, starting at start_bit.
+ *
+ * Single instruction shift and mask (bfe).
+ */
+template <int bits>
+__device__ inline uint32_t extract_bits(uint32_t value, int start_bit) {
+  static_assert(
+      bits == 2 || bits == 4 || bits == 8,
+      "extract_bits only supports 2, 4, 8 for now.");
+  uint32_t result;
+  if constexpr (bits == 2) {
+    asm("bfe.u32 %0, %1, %2, 2;" : "=r"(result) : "r"(value), "r"(start_bit));
+  } else if constexpr (bits == 4) {
+    asm("bfe.u32 %0, %1, %2, 4;" : "=r"(result) : "r"(value), "r"(start_bit));
+  } else if constexpr (bits == 8) {
+    asm("bfe.u32 %0, %1, %2, 8;" : "=r"(result) : "r"(value), "r"(start_bit));
+  }
+  return result;
+}
+
+} // namespace mlx::core::cu
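+
+// Worked example (illustrative): extract_bits<4>(w, 4 * i) returns the i-th
+// 4-bit field of a packed word, so for w = 0x87654321, extract_bits<4>(w, 8)
+// is 3. Similarly, cp_async<16>(dst, src) copies one float4-sized chunk whose
+// contents are only guaranteed visible after cp_async_commit() and a matching
+// cp_async_wait.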