From 700f7dcf010a0887beb4b096a1cb887bec743832 Mon Sep 17 00:00:00 2001
From: Angelos Katharopoulos
Date: Mon, 21 Jul 2025 23:38:21 -0700
Subject: [PATCH] Refactor the matmul a bit

---
 mlx/backend/cuda/matmul/mma.cuh   |  80 ++++++
 mlx/backend/cuda/matmul/tiles.cuh | 231 +++++++++++++++++
 mlx/backend/cuda/quantized/qmm.cu | 418 +++++-------------------------
 3 files changed, 374 insertions(+), 355 deletions(-)
 create mode 100644 mlx/backend/cuda/matmul/mma.cuh
 create mode 100644 mlx/backend/cuda/matmul/tiles.cuh

diff --git a/mlx/backend/cuda/matmul/mma.cuh b/mlx/backend/cuda/matmul/mma.cuh
new file mode 100644
index 000000000..40367b590
--- /dev/null
+++ b/mlx/backend/cuda/matmul/mma.cuh
@@ -0,0 +1,80 @@
+// Copyright © 2025 Apple Inc.
+
+#pragma once
+
+#include "mlx/backend/cuda/matmul/tiles.cuh"
+
+namespace mlx::core::cu {
+
+template <typename TileAccum, typename Tile>
+__device__ inline void mma(TileAccum& C, Tile& A, Tile& B) {}
+
+/**
+ * Multiply the 16x16 bfloat16 tiles and accumulate the result in one 16x16
+ * float tile.
+ *
+ * We actually perform C += A @ B.T
+ */
+__device__ inline void mma(
+    Tile16x16<float>& C,
+    Tile16x16<__nv_bfloat16>& A,
+    Tile16x16<__nv_bfloat16>& B) {
+  asm volatile(
+      "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
+      "{%0, %1, %2, %3}, "
+      "{%4, %5, %6, %7}, "
+      "{%8, %9}, "
+      "{%10, %11, %12, %13};"
+
+      // D matrix
+      : "+f"(C.values[0].x),
+        "+f"(C.values[0].y),
+        "+f"(C.values[1].x),
+        "+f"(C.values[1].y)
+
+      // A matrix
+      : "r"(*(uint32_t*)(&A.values[0])),
+        "r"(*(uint32_t*)(&A.values[1])),
+        "r"(*(uint32_t*)(&A.values[2])),
+        "r"(*(uint32_t*)(&A.values[3])),
+
+        // B matrix
+        "r"(*(uint32_t*)(&B.values[0])),
+        "r"(*(uint32_t*)(&B.values[2])),
+
+        // C matrix
+        "f"(C.values[0].x),
+        "f"(C.values[0].y),
+        "f"(C.values[1].x),
+        "f"(C.values[1].y));
+  asm volatile(
+      "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
+      "{%0, %1, %2, %3}, "
+      "{%4, %5, %6, %7}, "
+      "{%8, %9}, "
+      "{%10, %11, %12, %13};"
+
+      // D matrix
+      : "+f"(C.values[2].x),
+        "+f"(C.values[2].y),
+        "+f"(C.values[3].x),
+        "+f"(C.values[3].y)
+
+      // A matrix
+      : "r"(*(uint32_t*)(&A.values[0])),
+        "r"(*(uint32_t*)(&A.values[1])),
+        "r"(*(uint32_t*)(&A.values[2])),
+        "r"(*(uint32_t*)(&A.values[3])),
+
+        // B matrix
+        "r"(*(uint32_t*)(&B.values[1])),
+        "r"(*(uint32_t*)(&B.values[3])),
+
+        // C matrix
+        "f"(C.values[2].x),
+        "f"(C.values[2].y),
+        "f"(C.values[3].x),
+        "f"(C.values[3].y));
+}
+
+} // namespace mlx::core::cu
diff --git a/mlx/backend/cuda/matmul/tiles.cuh b/mlx/backend/cuda/matmul/tiles.cuh
new file mode 100644
index 000000000..debf21098
--- /dev/null
+++ b/mlx/backend/cuda/matmul/tiles.cuh
@@ -0,0 +1,231 @@
+// Copyright © 2025 Apple Inc.
+
+#pragma once
+
+namespace mlx::core::cu {
+
+// Map types to their vector-of-2 type: float -> float2, double -> double2, etc.
+template <typename T>
+struct Vector2;
+template <>
+struct Vector2<double> {
+  using type = double2;
+};
+template <>
+struct Vector2<float> {
+  using type = float2;
+};
+template <>
+struct Vector2<__half> {
+  using type = __half2;
+};
+template <>
+struct Vector2<__nv_bfloat16> {
+  using type = __nv_bfloat162;
+};
+template <typename T>
+using Vector2_t = typename Vector2<T>::type;
+
+/**
+ * The basic building block for Ampere mmas. A 16x16 tile distributed across
+ * the warp.
+ *
+ * Each thread holds 8 values. They are distributed according to
+ * https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-fragment-mma-16816-float
+ *
+ * For usage instructions see the individual methods, e.g. load().
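+ *
+ * For reference, the register-to-coordinate mapping (this is the same
+ * indexing used by store_global() below) is: lane `l` of the warp holds the
+ * value pairs
+ *
+ *   values[0] -> (row + 0, col + 0), (row + 0, col + 1)
+ *   values[1] -> (row + 8, col + 0), (row + 8, col + 1)
+ *   values[2] -> (row + 0, col + 8), (row + 0, col + 9)
+ *   values[3] -> (row + 8, col + 8), (row + 8, col + 9)
+ *
+ * with row = l / 4 and col = 2 * (l % 4).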
+ */
+template <typename T>
+struct Tile16x16 {
+  using T2 = Vector2_t<T>;
+
+  T2 values[4];
+
+  __device__ inline void clear() {
+    for (int i = 0; i < 4; i++) {
+      values[i] = static_cast<T2>(0);
+    }
+  }
+
+  /**
+   * Load a 16x16 tile from shared memory.
+   *
+   * The instruction is a bit weird in the sense that the address provided by
+   * each thread and the elements loaded are not the same.
+   *
+   * We load 4 8x8 tiles. The tile rows are stored contiguously in memory. As a
+   * result the warp provides 4*8 = 32 addresses, one per row.
+   *
+   * Threads 0-7 provide the addresses for the first tile, 8-15 for the second
+   * and so on. For instance, to load a non-swizzled tile we would do
+   *
+   *   base_addr + (laneid % 16) * BK + (laneid / 2) * 8
+   *
+   * See
+   * https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-instructions-ldmatrix
+   */
+  __device__ inline void load(uint32_t row_address) {
+    if constexpr (
+        std::is_same_v<T, __half> || std::is_same_v<T, __nv_bfloat16>) {
+      asm volatile(
+          "ldmatrix.sync.aligned.m8n8.x4.shared::cta.b16 {%0, %1, %2, %3}, [%4];\n"
+          : "=r"(*(uint32_t*)&(values[0])),
+            "=r"(*(uint32_t*)&(values[1])),
+            "=r"(*(uint32_t*)&(values[2])),
+            "=r"(*(uint32_t*)&(values[3]))
+          : "r"(row_address));
+    }
+  }
+
+  /**
+   * Store the tile to the address pointed to by `x`.
+   *
+   * The provided pointer is a generic pointer but this is meant to be used to
+   * store to global memory. For storing to shared memory we should use
+   * `stmatrix`.
+   *
+   * This also showcases the format of the tile quite nicely. Each register is
+   * holding two adjacent values. The indices are
+   *
+   *   row + 0, col + 0
+   *   row + 8, col + 0
+   *   row + 0, col + 8
+   *   row + 8, col + 8
+   *
+   * Given that we are dealing with Vector2_t, the column offsets are 4
+   * instead of 8.
+   */
+  template <typename U>
+  __device__ inline void store_global(U* x, int N) {
+    using U2 = Vector2_t<U>;
+    U2* x2 = reinterpret_cast<U2*>(x);
+    const int laneid = threadIdx.x % 32;
+    const int row = laneid / 4;
+    const int col = laneid % 4;
+    if constexpr (std::is_same_v<T, U>) {
+      x2[(row + 0) * (N / 2) + col + 0] = values[0];
+      x2[(row + 0) * (N / 2) + col + 4] = values[2];
+      x2[(row + 8) * (N / 2) + col + 0] = values[1];
+      x2[(row + 8) * (N / 2) + col + 4] = values[3];
+    } else if constexpr (
+        std::is_same_v<T, float> && std::is_same_v<U, __nv_bfloat16>) {
+      x2[(row + 0) * (N / 2) + col + 0] =
+          __floats2bfloat162_rn(values[0].x, values[0].y);
+      x2[(row + 0) * (N / 2) + col + 4] =
+          __floats2bfloat162_rn(values[2].x, values[2].y);
+      x2[(row + 8) * (N / 2) + col + 0] =
+          __floats2bfloat162_rn(values[1].x, values[1].y);
+      x2[(row + 8) * (N / 2) + col + 4] =
+          __floats2bfloat162_rn(values[3].x, values[3].y);
+    }
+  }
+};
+
+template <typename T, int ROWS_, int COLS_>
+struct SharedTile {
+  static constexpr int ROWS = ROWS_;
+  static constexpr int COLS = COLS_;
+  static constexpr int TILES_X = COLS / 16;
+  static constexpr int TILES_Y = ROWS / 16;
+  static constexpr int NUMEL = ROWS * COLS;
+
+  // Swizzle taken from ThunderKittens.
+  //
+  // See includes/types/shared/st.cuh
+  //
+  // I do feel that it is too math-heavy and can be improved. Also the math is
+  // done every time although the addresses don't change from load to load. I
+  // guess we are expecting the compiler to figure that out.
+  static constexpr int swizzle_bytes =
+      (sizeof(T) == 2 ? (TILES_X % 4 == 0 ? 128 : (TILES_X % 2 == 0 ? 64 : 32))
+                      : (sizeof(T) == 4 ? (TILES_X % 2 == 0 ? 128 : 64) : 0));
+
+  T data[ROWS * COLS];
+
+  // Return a pointer to the element at (row, col) using the swizzle.
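+  //
+  // A concrete example of the swizzle (illustrative numbers, assuming the
+  // common bf16 case with COLS = 32, so TILES_X = 2): swizzle_bytes = 64,
+  // subtile_cols = 32 and swizzle_repeat = 512. Each tile row is then 64
+  // bytes, so ((addr % swizzle_repeat) >> 7) << 4 takes bits 1-2 of the row
+  // index and XORs them into bits 4-5 of the byte address, i.e. every pair of
+  // rows permutes its four 16-byte chunks differently (assuming a
+  // sufficiently aligned base address). That XOR is what spreads the ldmatrix
+  // row addresses issued by the kernel across shared memory banks.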
+ __device__ static inline T* ptr(T* ptr, int row, int col) { + if constexpr (swizzle_bytes > 0) { + static constexpr int swizzle_repeat = swizzle_bytes * 8; + static constexpr int subtile_cols = swizzle_bytes / sizeof(T); + const int outer_idx = col / subtile_cols; + const uint64_t addr = + (uint64_t)(&ptr + [outer_idx * ROWS * subtile_cols + row * subtile_cols + + col % subtile_cols]); + const int swizzle = ((addr % swizzle_repeat) >> 7) << 4; + return (T*)(addr ^ swizzle); + } else { + return ptr + row * COLS + col; + } + } + + // Return the location of the element at (row, col) using the swizzle. + __device__ static inline uint32_t loc(uint32_t ptr, int row, int col) { + if constexpr (swizzle_bytes > 0) { + static constexpr int swizzle_repeat = swizzle_bytes * 8; + static constexpr int subtile_cols = swizzle_bytes / sizeof(T); + const int outer_idx = col / subtile_cols; + const uint32_t addr = ptr + + sizeof(T) * + (outer_idx * ROWS * subtile_cols + row * subtile_cols + + col % subtile_cols); + const int swizzle = ((addr % swizzle_repeat) >> 7) << 4; + return (addr ^ swizzle); + } else { + return ptr + sizeof(T) * (row * COLS + col); + } + } + + // Convenience functions to edit elements going through the swizzle. + __device__ inline T& operator()(int row, int col) { + return *ptr(data, row, col); + } + __device__ inline void store(float4& v, int row, int col) { + *(reinterpret_cast(ptr(data, row, col))) = v; + } + __device__ inline void store(float2& v, int row, int col) { + *(reinterpret_cast(ptr(data, row, col))) = v; + } + __device__ inline void store(float& v, int row, int col) { + *(reinterpret_cast(ptr(data, row, col))) = v; + } + template + __device__ inline void store(T (&v)[N], int row, int col) { + if constexpr (sizeof(T) * N == 4) { + store(*(reinterpret_cast(&v[0])), row, col); + } else if constexpr (sizeof(T) * N == 8) { + store(*(reinterpret_cast(&v[0])), row, col); + } else if constexpr (sizeof(T) * N == 16) { + store(*(reinterpret_cast(&v[0])), row, col); + } else { +#pragma unroll + for (int i = 0; i < N; i++) { + *ptr(data, row, col + i) = v[i]; + } + } + } +}; + +template +__device__ inline void load(Tile& tile, const T* x, int N) { + constexpr int NUM_THREADS = NUM_WARPS * 32; + constexpr int ELEMENTS_PER_LOAD = sizeof(float4) / sizeof(T); + constexpr int NUM_LOADS = Tile::NUMEL / ELEMENTS_PER_LOAD; + constexpr int NUM_LOADS_PER_THREAD = NUM_LOADS / NUM_THREADS; + constexpr int NUM_LOADS_PER_ROW = Tile::COLS / ELEMENTS_PER_LOAD; + constexpr int STEP_ROWS = NUM_THREADS / NUM_LOADS_PER_ROW; + + const int row = threadIdx.x / NUM_LOADS_PER_ROW; + const int col = threadIdx.x % NUM_LOADS_PER_ROW; + + x += row * N + col * ELEMENTS_PER_LOAD; + +#pragma unroll + for (int i = 0; i < NUM_LOADS_PER_THREAD; i++) { + float4 tmp; + tmp = *(reinterpret_cast(&x[i * STEP_ROWS * N])); + tile.store(tmp, row + i * STEP_ROWS, col * ELEMENTS_PER_LOAD); + } +} + +} // namespace mlx::core::cu diff --git a/mlx/backend/cuda/quantized/qmm.cu b/mlx/backend/cuda/quantized/qmm.cu index 5c6ddad1a..b3f27ac26 100644 --- a/mlx/backend/cuda/quantized/qmm.cu +++ b/mlx/backend/cuda/quantized/qmm.cu @@ -2,6 +2,8 @@ #include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/kernel_utils.cuh" +#include "mlx/backend/cuda/matmul/mma.cuh" +#include "mlx/backend/cuda/matmul/tiles.cuh" #include "mlx/backend/cuda/quantized/quantized_utils.cuh" #include "mlx/dtype_utils.h" @@ -9,340 +11,43 @@ namespace mlx::core { namespace cu { -template -struct Vector2; -template <> -struct Vector2 { - using type = 
double2; -}; -template <> -struct Vector2 { - using type = float2; -}; -template <> -struct Vector2<__half> { - using type = __half2; -}; -template <> -struct Vector2<__nv_bfloat16> { - using type = __nv_bfloat162; -}; -template -using Vector2_t = typename Vector2::type; +template +__device__ inline void load_quantized( + Tile& tile, + const uint8_t* x, + const T* scales, + const T* biases, + int N) { + constexpr int NUM_THREADS = NUM_WARPS * 32; + constexpr int ELEMENTS_PER_LOAD = sizeof(uint32_t) * get_pack_factor(); + constexpr int NUM_LOADS = Tile::NUMEL / ELEMENTS_PER_LOAD; + constexpr int NUM_LOADS_PER_THREAD = NUM_LOADS / NUM_THREADS; + constexpr int NUM_LOADS_PER_ROW = Tile::COLS / ELEMENTS_PER_LOAD; + constexpr int STEP_ROWS = NUM_THREADS / NUM_LOADS_PER_ROW; + constexpr int MASK = (1 << bits) - 1; -template -struct Tile16x16 { - using T2 = Vector2_t; + const int row = threadIdx.x / NUM_LOADS_PER_ROW; + const int col = threadIdx.x % NUM_LOADS_PER_ROW; - T2 values[4]; + const int Nx = N / get_pack_factor(); + const int Ng = N / group_size; - __device__ inline void clear() { - for (int i = 0; i < 4; i++) { - values[i] = static_cast(0); - } - } - - __device__ inline void load(uint32_t src_address) { - if constexpr ( - std::is_same_v || std::is_same_v) { - asm volatile( - "ldmatrix.sync.aligned.m8n8.x4.shared::cta.b16 {%0, %1, %2, %3}, [%4];\n" - : "=r"(*(uint32_t*)&(values[0])), - "=r"(*(uint32_t*)&(values[1])), - "=r"(*(uint32_t*)&(values[2])), - "=r"(*(uint32_t*)&(values[3])) - : "r"(src_address)); - } - } - - __device__ inline void store(uint32_t dst_address) { - if constexpr ( - std::is_same_v || std::is_same_v) { - asm volatile( - "stmatrix.sync.aligned.m8n8.x4.shared::cta.b16 {%0, %1, %2, %3}, [%4];\n" - : "=r"(*(uint32_t*)&(values[0])), - "=r"(*(uint32_t*)&(values[1])), - "=r"(*(uint32_t*)&(values[2])), - "=r"(*(uint32_t*)&(values[3])) - : "r"(dst_address)); - } else { - const int laneid = threadIdx.x % 32; - const int row = laneid / 4; - const int col = laneid % 4; - - const uint32_t a = dst_address + ((row + 0) * 8 + col + 0) * sizeof(T2); - const uint32_t b = dst_address + ((row + 0) * 8 + col + 4) * sizeof(T2); - const uint32_t c = dst_address + ((row + 8) * 8 + col + 0) * sizeof(T2); - const uint32_t d = dst_address + ((row + 8) * 8 + col + 4) * sizeof(T2); - if constexpr (sizeof(T2) == 4) { - asm volatile("st.shared.b32 [%1], %0;\n" - : - : "r"(*(uint32_t*)&(values[0])), "r"(a)); - asm volatile("st.shared.b32 [%1], %0;\n" - : - : "r"(*(uint32_t*)&(values[2])), "r"(b)); - asm volatile("st.shared.b32 [%1], %0;\n" - : - : "r"(*(uint32_t*)&(values[1])), "r"(c)); - asm volatile("st.shared.b32 [%1], %0;\n" - : - : "r"(*(uint32_t*)&(values[3])), "r"(d)); - } else if constexpr (sizeof(T2) == 8) { - asm volatile("st.shared.b64 [%1], %0;\n" - : - : "r"(*(uint64_t*)&(values[0])), "r"(a)); - asm volatile("st.shared.b64 [%1], %0;\n" - : - : "r"(*(uint64_t*)&(values[2])), "r"(b)); - asm volatile("st.shared.b64 [%1], %0;\n" - : - : "r"(*(uint64_t*)&(values[1])), "r"(c)); - asm volatile("st.shared.b64 [%1], %0;\n" - : - : "r"(*(uint64_t*)&(values[3])), "r"(d)); - } else if constexpr (sizeof(T2) == 16) { - asm volatile("st.shared.b128 [%1], %0;\n" - : - : "r"(*(__int128*)&(values[0])), "r"(a)); - asm volatile("st.shared.b128 [%1], %0;\n" - : - : "r"(*(__int128*)&(values[2])), "r"(b)); - asm volatile("st.shared.b128 [%1], %0;\n" - : - : "r"(*(__int128*)&(values[1])), "r"(c)); - asm volatile("st.shared.b128 [%1], %0;\n" - : - : "r"(*(__int128*)&(values[3])), "r"(d)); - } - } - } - - 
template - __device__ inline void store_global(U* x, int N) { - using U2 = Vector2_t; - U2* x2 = reinterpret_cast(x); - const int laneid = threadIdx.x % 32; - const int row = laneid / 4; - const int col = laneid % 4; - if constexpr (std::is_same_v) { - x2[(row + 0) * (N / 2) + col + 0] = values[0]; - x2[(row + 0) * (N / 2) + col + 4] = values[2]; - x2[(row + 8) * (N / 2) + col + 0] = values[1]; - x2[(row + 8) * (N / 2) + col + 4] = values[3]; - } else if constexpr ( - std::is_same_v && std::is_same_v) { - x2[(row + 0) * (N / 2) + col + 0] = - __floats2bfloat162_rn(values[0].x, values[0].y); - x2[(row + 0) * (N / 2) + col + 4] = - __floats2bfloat162_rn(values[2].x, values[2].y); - x2[(row + 8) * (N / 2) + col + 0] = - __floats2bfloat162_rn(values[1].x, values[1].y); - x2[(row + 8) * (N / 2) + col + 4] = - __floats2bfloat162_rn(values[3].x, values[3].y); - } - } -}; - -template -struct __align__(16) SharedTile { - static constexpr int TILES_R = ROWS / 16; - static constexpr int TILES_C = COLS / 16; - static constexpr int NUM_ELEMENTS = ROWS * COLS; - - static constexpr int swizzle_bytes = - (sizeof(T) == 2 ? (TILES_C % 4 == 0 ? 128 : (TILES_C % 2 == 0 ? 64 : 32)) - : (sizeof(T) == 4 ? (TILES_C % 2 == 0 ? 128 : 64) : 0)); - - T data[ROWS * COLS]; - - __device__ static inline T* idx(T* ptr, int2 coord) { - if constexpr (swizzle_bytes > 0) { - int r = coord.x, c = coord.y; - static constexpr int swizzle_repeat = swizzle_bytes * 8; - static constexpr int subtile_cols = swizzle_bytes / sizeof(T); - const int outer_idx = c / subtile_cols; - const uint64_t addr = - (uint64_t)(&ptr - [outer_idx * ROWS * subtile_cols + r * subtile_cols + - c % subtile_cols]); - const int swizzle = ((addr % swizzle_repeat) >> 7) << 4; - return (T*)(addr ^ swizzle); - } else { - return ptr + coord.y * COLS + coord.x; - } - } - - __device__ static inline uint32_t idx(uint32_t ptr, int2 coord) { - if constexpr (swizzle_bytes > 0) { - int r = coord.x, c = coord.y; - static constexpr int swizzle_repeat = swizzle_bytes * 8; - static constexpr int subtile_cols = swizzle_bytes / sizeof(T); - const int outer_idx = c / subtile_cols; - const uint32_t addr = ptr + - sizeof(T) * - (outer_idx * ROWS * subtile_cols + r * subtile_cols + - c % subtile_cols); - const int swizzle = ((addr % swizzle_repeat) >> 7) << 4; - return (addr ^ swizzle); - } else { - return ptr + sizeof(T) * (coord.y * COLS + coord.x); - } - } - - __device__ inline T& operator[](int2 coord) { - return *idx(&data[0], coord); - } - - __device__ inline void store(float4& v, int2 coord) { - *(reinterpret_cast(idx(data, coord))) = v; - } - - __device__ inline void store(float2& v, int2 coord) { - *(reinterpret_cast(idx(data, coord))) = v; - } - - __device__ inline void store(float& v, int2 coord) { - *(reinterpret_cast(idx(data, coord))) = v; - } - - template - __device__ inline void store(T (&v)[N], int2 coord) { - if constexpr (sizeof(T) * N == 4) { - store(*(reinterpret_cast(&v[0])), coord); - } else if constexpr (sizeof(T) * N == 8) { - store(*(reinterpret_cast(&v[0])), coord); - } else if constexpr (sizeof(T) * N == 16) { - store(*(reinterpret_cast(&v[0])), coord); - } else { -#pragma unroll - for (int i = 0; i < N; i++) { - *idx(data, {coord.x, coord.y + i}) = v[i]; - } - } - } - - template - __device__ inline void load(const T* x, int N) { - constexpr int NUM_THREADS = NUM_WARPS * 32; - constexpr int ELEMENTS_PER_LOAD = sizeof(float4) / sizeof(T); - constexpr int NUM_LOADS = NUM_ELEMENTS / ELEMENTS_PER_LOAD; - constexpr int NUM_LOADS_PER_THREAD = NUM_LOADS / 
NUM_THREADS; - constexpr int NUM_LOADS_PER_ROW = COLS / ELEMENTS_PER_LOAD; - constexpr int STEP_ROWS = NUM_THREADS / NUM_LOADS_PER_ROW; - - const int row = threadIdx.x / NUM_LOADS_PER_ROW; - const int col = threadIdx.x % NUM_LOADS_PER_ROW; - - x += row * N + col * ELEMENTS_PER_LOAD; + x += row * Nx + col * (ELEMENTS_PER_LOAD / get_pack_factor()); + scales += row * Ng + col * ELEMENTS_PER_LOAD / group_size; + biases += row * Ng + col * ELEMENTS_PER_LOAD / group_size; #pragma unroll - for (int i = 0; i < NUM_LOADS_PER_THREAD; i++) { - float4 tmp; - tmp = *(reinterpret_cast(&x[i * STEP_ROWS * N])); - store(tmp, {row + i * STEP_ROWS, col * ELEMENTS_PER_LOAD}); - } - } - - template - __device__ inline void - load_quantized(const uint8_t* x, const T* scales, const T* biases, int N) { - constexpr int NUM_THREADS = NUM_WARPS * 32; - constexpr int ELEMENTS_PER_LOAD = - sizeof(uint32_t) * get_pack_factor(); - constexpr int NUM_LOADS = NUM_ELEMENTS / ELEMENTS_PER_LOAD; - constexpr int NUM_LOADS_PER_THREAD = NUM_LOADS / NUM_THREADS; - constexpr int NUM_LOADS_PER_ROW = COLS / ELEMENTS_PER_LOAD; - constexpr int STEP_ROWS = NUM_THREADS / NUM_LOADS_PER_ROW; - constexpr int MASK = (1 << bits) - 1; - - const int row = threadIdx.x / NUM_LOADS_PER_ROW; - const int col = threadIdx.x % NUM_LOADS_PER_ROW; - - const int Nx = N / get_pack_factor(); - const int Ng = N / group_size; - - x += row * Nx + col * (ELEMENTS_PER_LOAD / get_pack_factor()); - scales += row * Ng + col * ELEMENTS_PER_LOAD / group_size; - biases += row * Ng + col * ELEMENTS_PER_LOAD / group_size; - + for (int i = 0; i < NUM_LOADS_PER_THREAD; i++) { + T vs[ELEMENTS_PER_LOAD]; + uint32_t w = *reinterpret_cast(x + i * STEP_ROWS * Nx); + T s = scales[i * STEP_ROWS * Ng]; + T b = biases[i * STEP_ROWS * Ng]; #pragma unroll - for (int i = 0; i < NUM_LOADS_PER_THREAD; i++) { - T vs[ELEMENTS_PER_LOAD]; - uint32_t w = *reinterpret_cast(x + i * STEP_ROWS * Nx); - T s = scales[i * STEP_ROWS * Ng]; - T b = biases[i * STEP_ROWS * Ng]; -#pragma unroll - for (int j = 0; j < ELEMENTS_PER_LOAD; j++) { - vs[j] = static_cast((w >> (j * bits)) & MASK) * s + b; - } - store(vs, {row + i * STEP_ROWS, col * ELEMENTS_PER_LOAD}); + for (int j = 0; j < ELEMENTS_PER_LOAD; j++) { + vs[j] = static_cast((w >> (j * bits)) & MASK) * s + b; } + tile.store(vs, row + i * STEP_ROWS, col * ELEMENTS_PER_LOAD); } -}; - -template -__device__ inline void mma(TileAccum& C, Tile& A, Tile& B) {} - -__device__ inline void mma( - Tile16x16& C, - Tile16x16<__nv_bfloat16>& A, - Tile16x16<__nv_bfloat16>& B) { - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " - "{%0, %1, %2, %3}, " - "{%4, %5, %6, %7}, " - "{%8, %9}, " - "{%10, %11, %12, %13};" - - // D matrix - : "+f"(C.values[0].x), - "+f"(C.values[0].y), - "+f"(C.values[1].x), - "+f"(C.values[1].y) - - // A matrix - : "r"(*(uint32_t*)(&A.values[0])), - "r"(*(uint32_t*)(&A.values[1])), - "r"(*(uint32_t*)(&A.values[2])), - "r"(*(uint32_t*)(&A.values[3])), - - // B matrix - "r"(*(uint32_t*)(&B.values[0])), - "r"(*(uint32_t*)(&B.values[2])), - - // C matrix - "f"(C.values[0].x), - "f"(C.values[0].y), - "f"(C.values[1].x), - "f"(C.values[1].y)); - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " - "{%0, %1, %2, %3}, " - "{%4, %5, %6, %7}, " - "{%8, %9}, " - "{%10, %11, %12, %13};" - - // D matrix - : "+f"(C.values[2].x), - "+f"(C.values[2].y), - "+f"(C.values[3].x), - "+f"(C.values[3].y) - - // A matrix - : "r"(*(uint32_t*)(&A.values[0])), - "r"(*(uint32_t*)(&A.values[1])), - 
"r"(*(uint32_t*)(&A.values[2])), - "r"(*(uint32_t*)(&A.values[3])), - - // B matrix - "r"(*(uint32_t*)(&B.values[1])), - "r"(*(uint32_t*)(&B.values[3])), - - // C matrix - "f"(C.values[2].x), - "f"(C.values[2].y), - "f"(C.values[3].x), - "f"(C.values[3].y)); } template @@ -389,8 +94,9 @@ __global__ void qmm( uint32_t base_addr_ws = __cvta_generic_to_shared(&ws.data[0]); for (int k_block = 0; k_block < K; k_block += BK) { - xs.load(x + k_block, K); - ws.load_quantized( + load(xs, x + k_block, K); + load_quantized( + ws, w + k_block / get_pack_factor(), scales + k_block / group_size, biases + k_block / group_size, @@ -401,15 +107,17 @@ __global__ void qmm( for (int k = 0; k < WARP_K; k++) { #pragma unroll for (int i = 0; i < WARP_M; i++) { - A[i].load(xs.idx( + A[i].load(xs.loc( base_addr_xs, - {offset_m + i * 16 + laneid % 16, k * 16 + laneid / 16 * 8})); + offset_m + i * 16 + laneid % 16, + k * 16 + laneid / 16 * 8)); } #pragma unroll for (int i = 0; i < WARP_N; i++) { - B[i].load(ws.idx( + B[i].load(ws.loc( base_addr_ws, - {offset_n + i * 16 + laneid % 16, k * 16 + laneid / 16 * 8})); + offset_n + i * 16 + laneid % 16, + k * 16 + laneid / 16 * 8)); } #pragma unroll @@ -420,7 +128,6 @@ __global__ void qmm( } } } - __syncthreads(); } #pragma unroll @@ -450,30 +157,31 @@ void qmm( cu::CommandEncoder& enc, const Stream& s) { dispatch_float_types(x.dtype(), "qmm", [&](auto type_tag) { - // dispatch_groups(group_size_, [&](auto group_size) { - // dispatch_bits(bits_, [&](auto bits) { - using DataType = cuda_type_t; - constexpr int BM = 64; - constexpr int BN = 64; - constexpr int BK = 32; - auto kernel = cu::qmm; + dispatch_groups(group_size_, [&](auto group_size) { + dispatch_bits(bits_, [&](auto bits) { + using DataType = cuda_type_t; + constexpr int BM = 64; + constexpr int BN = 64; + constexpr int BK = 32; + auto kernel = + cu::qmm; - dim3 grid(N / BN, M / BM); + dim3 grid(N / BN, M / BM); - enc.add_kernel_node( - kernel, - grid, - 128, - x.data(), - w.data(), - scales.data(), - biases.data(), - out.data(), - M, - N, - K); - //}); - //}); + enc.add_kernel_node( + kernel, + grid, + 128, + x.data(), + w.data(), + scales.data(), + biases.data(), + out.data(), + M, + N, + K); + }); + }); }); }