diff --git a/mlx/backend/cuda/quantized/qmm.cu b/mlx/backend/cuda/quantized/qmm.cu
index 671723483..5c6ddad1a 100644
--- a/mlx/backend/cuda/quantized/qmm.cu
+++ b/mlx/backend/cuda/quantized/qmm.cu
@@ -90,29 +90,29 @@ struct Tile16x16 {
       } else if constexpr (sizeof(T2) == 8) {
         asm volatile("st.shared.b64 [%1], %0;\n"
                      :
-                     : "r"(*(uint32_t*)&(values[0])), "r"(a));
+                     : "r"(*(uint64_t*)&(values[0])), "r"(a));
         asm volatile("st.shared.b64 [%1], %0;\n"
                      :
-                     : "r"(*(uint32_t*)&(values[2])), "r"(b));
+                     : "r"(*(uint64_t*)&(values[2])), "r"(b));
         asm volatile("st.shared.b64 [%1], %0;\n"
                      :
-                     : "r"(*(uint32_t*)&(values[1])), "r"(c));
+                     : "r"(*(uint64_t*)&(values[1])), "r"(c));
         asm volatile("st.shared.b64 [%1], %0;\n"
                      :
-                     : "r"(*(uint32_t*)&(values[3])), "r"(d));
+                     : "r"(*(uint64_t*)&(values[3])), "r"(d));
       } else if constexpr (sizeof(T2) == 16) {
         asm volatile("st.shared.b128 [%1], %0;\n"
                      :
-                     : "r"(*(uint32_t*)&(values[0])), "r"(a));
+                     : "r"(*(__int128*)&(values[0])), "r"(a));
         asm volatile("st.shared.b128 [%1], %0;\n"
                      :
-                     : "r"(*(uint32_t*)&(values[2])), "r"(b));
+                     : "r"(*(__int128*)&(values[2])), "r"(b));
         asm volatile("st.shared.b128 [%1], %0;\n"
                      :
-                     : "r"(*(uint32_t*)&(values[1])), "r"(c));
+                     : "r"(*(__int128*)&(values[1])), "r"(c));
         asm volatile("st.shared.b128 [%1], %0;\n"
                      :
-                     : "r"(*(uint32_t*)&(values[3])), "r"(d));
+                     : "r"(*(__int128*)&(values[3])), "r"(d));
       }
     }
   }
@@ -143,17 +143,17 @@ struct Tile16x16 {
   }
 };
 
-template <typename T, int R, int C>
+template <typename T, int ROWS, int COLS>
 struct __align__(16) SharedTile {
-  static constexpr int TILES_R = R / 16;
-  static constexpr int TILES_C = C / 16;
-  static constexpr int NUM_ELEMENTS = R * C;
+  static constexpr int TILES_R = ROWS / 16;
+  static constexpr int TILES_C = COLS / 16;
+  static constexpr int NUM_ELEMENTS = ROWS * COLS;
 
   static constexpr int swizzle_bytes =
       (sizeof(T) == 2
           ? (TILES_C % 4 == 0 ? 128 : (TILES_C % 2 == 0 ? 64 : 32))
          : (sizeof(T) == 4 ? (TILES_C % 2 == 0 ? 128 : 64) : 0));
 
-  T data[R * C];
+  T data[ROWS * COLS];
 
   __device__ static inline T* idx(T* ptr, int2 coord) {
     if constexpr (swizzle_bytes > 0) {
@@ -163,12 +163,12 @@ struct __align__(16) SharedTile {
       const int outer_idx = c / subtile_cols;
       const uint64_t addr = (uint64_t)(&ptr
-          [outer_idx * R * subtile_cols + r * subtile_cols +
+          [outer_idx * ROWS * subtile_cols + r * subtile_cols +
            c % subtile_cols]);
       const int swizzle = ((addr % swizzle_repeat) >> 7) << 4;
       return (T*)(addr ^ swizzle);
     } else {
-      return ptr + coord.y * C + coord.x;
+      return ptr + coord.y * COLS + coord.x;
     }
   }
 
@@ -180,33 +180,59 @@ struct __align__(16) SharedTile {
       const int outer_idx = c / subtile_cols;
       const uint32_t addr = ptr +
          sizeof(T) *
-             (outer_idx * R * subtile_cols + r * subtile_cols +
+             (outer_idx * ROWS * subtile_cols + r * subtile_cols +
               c % subtile_cols);
       const int swizzle = ((addr % swizzle_repeat) >> 7) << 4;
       return (addr ^ swizzle);
     } else {
-      return ptr + sizeof(T) * (coord.y * C + coord.x);
+      return ptr + sizeof(T) * (coord.y * COLS + coord.x);
     }
   }
 
+  __device__ inline T& operator[](int2 coord) {
+    return *idx(&data[0], coord);
+  }
+
   __device__ inline void store(float4& v, int2 coord) {
     *(reinterpret_cast<float4*>(idx(data, coord))) = v;
   }
 
+  __device__ inline void store(float2& v, int2 coord) {
+    *(reinterpret_cast<float2*>(idx(data, coord))) = v;
+  }
+
+  __device__ inline void store(float& v, int2 coord) {
+    *(reinterpret_cast<float*>(idx(data, coord))) = v;
+  }
+
+  template <int N>
+  __device__ inline void store(T (&v)[N], int2 coord) {
+    if constexpr (sizeof(T) * N == 4) {
+      store(*(reinterpret_cast<float*>(&v[0])), coord);
+    } else if constexpr (sizeof(T) * N == 8) {
+      store(*(reinterpret_cast<float2*>(&v[0])), coord);
+    } else if constexpr (sizeof(T) * N == 16) {
+      store(*(reinterpret_cast<float4*>(&v[0])), coord);
+    } else {
+#pragma unroll
+      for (int i = 0; i < N; i++) {
+        *idx(data, {coord.x, coord.y + i}) = v[i];
+      }
+    }
+  }
+
   template <int NUM_WARPS>
   __device__ inline void load(const T* x, int N) {
     constexpr int NUM_THREADS = NUM_WARPS * 32;
     constexpr int ELEMENTS_PER_LOAD = sizeof(float4) / sizeof(T);
     constexpr int NUM_LOADS = NUM_ELEMENTS / ELEMENTS_PER_LOAD;
     constexpr int NUM_LOADS_PER_THREAD = NUM_LOADS / NUM_THREADS;
-    constexpr int NUM_LOADS_PER_ROW = C / ELEMENTS_PER_LOAD;
+    constexpr int NUM_LOADS_PER_ROW = COLS / ELEMENTS_PER_LOAD;
     constexpr int STEP_ROWS = NUM_THREADS / NUM_LOADS_PER_ROW;
 
     const int row = threadIdx.x / NUM_LOADS_PER_ROW;
     const int col = threadIdx.x % NUM_LOADS_PER_ROW;
 
-    uint32_t data_ptr =
-        static_cast<uint32_t>(__cvta_generic_to_shared(&data[0]));
-
     x += row * N + col * ELEMENTS_PER_LOAD;
 #pragma unroll
@@ -216,6 +242,42 @@ struct __align__(16) SharedTile {
       store(tmp, {row + i * STEP_ROWS, col * ELEMENTS_PER_LOAD});
     }
   }
+
+  template <int NUM_WARPS, int group_size, int bits>
+  __device__ inline void
+  load_quantized(const uint8_t* x, const T* scales, const T* biases, int N) {
+    constexpr int NUM_THREADS = NUM_WARPS * 32;
+    constexpr int ELEMENTS_PER_LOAD =
+        sizeof(uint32_t) * get_pack_factor();
+    constexpr int NUM_LOADS = NUM_ELEMENTS / ELEMENTS_PER_LOAD;
+    constexpr int NUM_LOADS_PER_THREAD = NUM_LOADS / NUM_THREADS;
+    constexpr int NUM_LOADS_PER_ROW = COLS / ELEMENTS_PER_LOAD;
+    constexpr int STEP_ROWS = NUM_THREADS / NUM_LOADS_PER_ROW;
+    constexpr int MASK = (1 << bits) - 1;
+
+    const int row = threadIdx.x / NUM_LOADS_PER_ROW;
+    const int col = threadIdx.x % NUM_LOADS_PER_ROW;
+
+    const int Nx = N / get_pack_factor();
+    const int Ng = N / group_size;
+
+    x += row * Nx + col * (ELEMENTS_PER_LOAD / get_pack_factor());
+    scales += row * Ng + col * ELEMENTS_PER_LOAD / group_size;
+    biases += row * Ng + col * ELEMENTS_PER_LOAD / group_size;
+
+#pragma unroll
+    for (int i = 0; i < NUM_LOADS_PER_THREAD; i++) {
+      T vs[ELEMENTS_PER_LOAD];
+      uint32_t w = *reinterpret_cast<const uint32_t*>(x + i * STEP_ROWS * Nx);
+      T s = scales[i * STEP_ROWS * Ng];
+      T b = biases[i * STEP_ROWS * Ng];
+#pragma unroll
+      for (int j = 0; j < ELEMENTS_PER_LOAD; j++) {
+        vs[j] = static_cast<T>((w >> (j * bits)) & MASK) * s + b;
+      }
+      store(vs, {row + i * STEP_ROWS, col * ELEMENTS_PER_LOAD});
+    }
+  }
 };
 
 template
@@ -312,40 +374,53 @@ __global__ void qmm(
 
   Tile16x16 A[WARP_M];
   Tile16x16 B[WARP_N];
+  x += blockIdx.y * BM * K;
+  w += blockIdx.x * BN * K / get_pack_factor();
+  scales += blockIdx.x * BN * K / group_size;
+  biases += blockIdx.x * BN * K / group_size;
+  y += blockIdx.y * BM * N + blockIdx.x * BN;
+
 #pragma unroll
   for (int i = 0; i < WARP_M * WARP_N; i++) {
     C[i].clear();
   }
 
-  xs.load(x, K);
-  ws.load(x + BM * K, K);
-  __syncthreads();
-
   uint32_t base_addr_xs = __cvta_generic_to_shared(&xs.data[0]);
   uint32_t base_addr_ws = __cvta_generic_to_shared(&ws.data[0]);
 
-#pragma unroll
-  for (int k = 0; k < WARP_K; k++) {
-#pragma unroll
-    for (int i = 0; i < WARP_M; i++) {
-      A[i].load(xs.idx(
-          base_addr_xs,
-          {offset_m + i * 16 + laneid % 16, k * 16 + laneid / 16 * 8}));
-    }
-#pragma unroll
-    for (int i = 0; i < WARP_N; i++) {
-      B[i].load(ws.idx(
-          base_addr_ws,
-          {offset_n + i * 16 + laneid % 16, k * 16 + laneid / 16 * 8}));
-    }
+  for (int k_block = 0; k_block < K; k_block += BK) {
+    xs.load(x + k_block, K);
+    ws.load_quantized(
+        w + k_block / get_pack_factor(),
+        scales + k_block / group_size,
+        biases + k_block / group_size,
+        K);
+    __syncthreads();
 
 #pragma unroll
-  for (int i = 0; i < WARP_M; i++) {
+    for (int k = 0; k < WARP_K; k++) {
 #pragma unroll
-    for (int j = 0; j < WARP_N; j++) {
-      mma(C[i * WARP_N + j], A[i], B[j]);
+      for (int i = 0; i < WARP_M; i++) {
+        A[i].load(xs.idx(
+            base_addr_xs,
+            {offset_m + i * 16 + laneid % 16, k * 16 + laneid / 16 * 8}));
+      }
+#pragma unroll
+      for (int i = 0; i < WARP_N; i++) {
+        B[i].load(ws.idx(
+            base_addr_ws,
+            {offset_n + i * 16 + laneid % 16, k * 16 + laneid / 16 * 8}));
+      }
+
+#pragma unroll
+      for (int i = 0; i < WARP_M; i++) {
+#pragma unroll
+        for (int j = 0; j < WARP_N; j++) {
+          mma(C[i * WARP_N + j], A[i], B[j]);
+        }
       }
     }
+    __syncthreads();
   }
 
 #pragma unroll
@@ -378,11 +453,16 @@ void qmm(
   // dispatch_groups(group_size_, [&](auto group_size) {
   // dispatch_bits(bits_, [&](auto bits) {
   using DataType = cuda_type_t;
-  auto kernel = cu::qmm;
+  constexpr int BM = 64;
+  constexpr int BN = 64;
+  constexpr int BK = 32;
+  auto kernel = cu::qmm;
+
+  dim3 grid(N / BN, M / BM);
 
   enc.add_kernel_node(
       kernel,
-      1,
+      grid,
       128,
       x.data(),
      w.data(),
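For reference, each packed `uint32_t` that `load_quantized` reads holds `32 / bits` codes, and every code is recovered as `q * scale + bias` with the scale and bias of its group, exactly the `vs[j] = static_cast<T>((w >> (j * bits)) & MASK) * s + b` line in the patch. The host-side sketch below is illustrative only and not part of the patch: the `unpack_word` helper name is made up, it assumes `bits` is 2, 4, or 8, and it uses `float` where the kernel uses `T`.

```cpp
// Host-side reference for the per-word affine dequantization in load_quantized.
// Illustrative sketch only; unpack_word is a hypothetical helper name.
#include <cstdint>
#include <cstdio>

template <int bits>
void unpack_word(uint32_t w, float scale, float bias, float* out) {
  constexpr int values_per_word = 32 / bits;   // e.g. 8 codes per word for bits == 4
  constexpr uint32_t mask = (1u << bits) - 1;  // same low-bit mask as MASK in the kernel
  for (int j = 0; j < values_per_word; j++) {
    // Same recovery as the kernel inner loop: code * scale + bias.
    out[j] = static_cast<float>((w >> (j * bits)) & mask) * scale + bias;
  }
}

int main() {
  // Pack the eight 4-bit codes 0..7 into one word, then dequantize them.
  uint32_t w = 0;
  for (uint32_t j = 0; j < 8; j++) {
    w |= (j & 0xFu) << (4 * j);
  }
  float out[8];
  unpack_word<4>(w, /*scale=*/0.5f, /*bias=*/-1.0f, out);
  for (int j = 0; j < 8; j++) {
    printf("%.2f ", out[j]);
  }
  printf("\n");
  return 0;
}
```

Running the sketch prints `-1.00 -0.50 0.00 0.50 1.00 1.50 2.00 2.50`, i.e. the codes 0 through 7 mapped through scale 0.5 and bias -1, which is the same mapping the kernel applies per group before the MMA stage.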