diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt
index c5d36484f..b73e03a49 100644
--- a/mlx/backend/cuda/CMakeLists.txt
+++ b/mlx/backend/cuda/CMakeLists.txt
@@ -92,7 +92,7 @@ target_compile_options(
 # Compute capability 7 is required for synchronization between CPU/GPU with
 # managed memory. TODO: Add more architectures for potential performance gain.
 set(MLX_CUDA_ARCHITECTURES
-    "70;80"
+    "80"
     CACHE STRING "CUDA architectures")
 message(STATUS "CUDA architectures: ${MLX_CUDA_ARCHITECTURES}")
 set_target_properties(mlx PROPERTIES CUDA_ARCHITECTURES
diff --git a/mlx/backend/cuda/quantized/qmm.cu b/mlx/backend/cuda/quantized/qmm.cu
index 9a58a59c4..671723483 100644
--- a/mlx/backend/cuda/quantized/qmm.cu
+++ b/mlx/backend/cuda/quantized/qmm.cu
@@ -7,7 +7,358 @@
 
 namespace mlx::core {
 
-namespace cu {} // namespace cu
+namespace cu {
+
+template <typename T>
+struct Vector2;
+template <>
+struct Vector2<double> {
+  using type = double2;
+};
+template <>
+struct Vector2<float> {
+  using type = float2;
+};
+template <>
+struct Vector2<__half> {
+  using type = __half2;
+};
+template <>
+struct Vector2<__nv_bfloat16> {
+  using type = __nv_bfloat162;
+};
+template <typename T>
+using Vector2_t = typename Vector2<T>::type;
+
+// A 16x16 tile held in registers: each of the 32 lanes owns 4 packed pairs,
+// matching the ldmatrix / mma.m16n8k16 fragment layout.
+template <typename T>
+struct Tile16x16 {
+  using T2 = Vector2_t<T>;
+
+  T2 values[4];
+
+  __device__ inline void clear() {
+    for (int i = 0; i < 4; i++) {
+      values[i] = static_cast<T2>(0);
+    }
+  }
+
+  __device__ inline void load(uint32_t src_address) {
+    if constexpr (
+        std::is_same_v<T, __half> || std::is_same_v<T, __nv_bfloat16>) {
+      asm volatile(
+          "ldmatrix.sync.aligned.m8n8.x4.shared::cta.b16 {%0, %1, %2, %3}, [%4];\n"
+          : "=r"(*(uint32_t*)&(values[0])),
+            "=r"(*(uint32_t*)&(values[1])),
+            "=r"(*(uint32_t*)&(values[2])),
+            "=r"(*(uint32_t*)&(values[3]))
+          : "r"(src_address));
+    }
+  }
+
+  __device__ inline void store(uint32_t dst_address) {
+    if constexpr (
+        std::is_same_v<T, __half> || std::is_same_v<T, __nv_bfloat16>) {
+      asm volatile(
+          "stmatrix.sync.aligned.m8n8.x4.shared::cta.b16 {%0, %1, %2, %3}, [%4];\n"
+          : "=r"(*(uint32_t*)&(values[0])),
+            "=r"(*(uint32_t*)&(values[1])),
+            "=r"(*(uint32_t*)&(values[2])),
+            "=r"(*(uint32_t*)&(values[3]))
+          : "r"(dst_address));
+    } else {
+      const int laneid = threadIdx.x % 32;
+      const int row = laneid / 4;
+      const int col = laneid % 4;
+
+      const uint32_t a = dst_address + ((row + 0) * 8 + col + 0) * sizeof(T2);
+      const uint32_t b = dst_address + ((row + 0) * 8 + col + 4) * sizeof(T2);
+      const uint32_t c = dst_address + ((row + 8) * 8 + col + 0) * sizeof(T2);
+      const uint32_t d = dst_address + ((row + 8) * 8 + col + 4) * sizeof(T2);
+      if constexpr (sizeof(T2) == 4) {
+        asm volatile("st.shared.b32 [%1], %0;\n"
+                     :
+                     : "r"(*(uint32_t*)&(values[0])), "r"(a));
+        asm volatile("st.shared.b32 [%1], %0;\n"
+                     :
+                     : "r"(*(uint32_t*)&(values[2])), "r"(b));
+        asm volatile("st.shared.b32 [%1], %0;\n"
+                     :
+                     : "r"(*(uint32_t*)&(values[1])), "r"(c));
+        asm volatile("st.shared.b32 [%1], %0;\n"
+                     :
+                     : "r"(*(uint32_t*)&(values[3])), "r"(d));
+      } else if constexpr (sizeof(T2) == 8) {
+        asm volatile("st.shared.b64 [%1], %0;\n"
+                     :
+                     : "r"(*(uint32_t*)&(values[0])), "r"(a));
+        asm volatile("st.shared.b64 [%1], %0;\n"
+                     :
+                     : "r"(*(uint32_t*)&(values[2])), "r"(b));
+        asm volatile("st.shared.b64 [%1], %0;\n"
+                     :
+                     : "r"(*(uint32_t*)&(values[1])), "r"(c));
+        asm volatile("st.shared.b64 [%1], %0;\n"
+                     :
+                     : "r"(*(uint32_t*)&(values[3])), "r"(d));
+      } else if constexpr (sizeof(T2) == 16) {
+        asm volatile("st.shared.b128 [%1], %0;\n"
+                     :
+                     : "r"(*(uint32_t*)&(values[0])), "r"(a));
+        asm volatile("st.shared.b128 [%1], %0;\n"
+                     :
+                     : "r"(*(uint32_t*)&(values[2])), "r"(b));
+        asm volatile("st.shared.b128 [%1], %0;\n"
+                     :
+                     : "r"(*(uint32_t*)&(values[1])), "r"(c));
+        asm volatile("st.shared.b128 [%1], %0;\n"
+                     :
+                     : "r"(*(uint32_t*)&(values[3])), "r"(d));
+      }
+    }
+  }
+
+  template <typename U>
+  __device__ inline void store_global(U* x, int N) {
+    using U2 = Vector2_t<U>;
+    U2* x2 = reinterpret_cast<U2*>(x);
+    const int laneid = threadIdx.x % 32;
+    const int row = laneid / 4;
+    const int col = laneid % 4;
+    if constexpr (std::is_same_v<U, T>) {
+      x2[(row + 0) * (N / 2) + col + 0] = values[0];
+      x2[(row + 0) * (N / 2) + col + 4] = values[2];
+      x2[(row + 8) * (N / 2) + col + 0] = values[1];
+      x2[(row + 8) * (N / 2) + col + 4] = values[3];
+    } else if constexpr (
+        std::is_same_v<U, __nv_bfloat16> && std::is_same_v<T, float>) {
+      x2[(row + 0) * (N / 2) + col + 0] =
+          __floats2bfloat162_rn(values[0].x, values[0].y);
+      x2[(row + 0) * (N / 2) + col + 4] =
+          __floats2bfloat162_rn(values[2].x, values[2].y);
+      x2[(row + 8) * (N / 2) + col + 0] =
+          __floats2bfloat162_rn(values[1].x, values[1].y);
+      x2[(row + 8) * (N / 2) + col + 4] =
+          __floats2bfloat162_rn(values[3].x, values[3].y);
+    }
+  }
+};
+
+// R x C tile staged in shared memory. `idx` maps a (row, col) coordinate to
+// an address with an XOR swizzle (the 128-byte line index is folded into the
+// 16-byte chunk index) to reduce bank conflicts for the ldmatrix reads done
+// by Tile16x16; `load` fills the tile from global memory with 16-byte
+// vectorized loads spread across NUM_WARPS warps.
+template <int R, int C, typename T>
+struct __align__(16) SharedTile {
+  static constexpr int TILES_R = R / 16;
+  static constexpr int TILES_C = C / 16;
+  static constexpr int NUM_ELEMENTS = R * C;
+
+  static constexpr int swizzle_bytes =
+      (sizeof(T) == 2 ? (TILES_C % 4 == 0 ? 128 : (TILES_C % 2 == 0 ? 64 : 32))
+                      : (sizeof(T) == 4 ? (TILES_C % 2 == 0 ? 128 : 64) : 0));
+
+  T data[R * C];
+
+  __device__ static inline T* idx(T* ptr, int2 coord) {
+    if constexpr (swizzle_bytes > 0) {
+      int r = coord.x, c = coord.y;
+      static constexpr int swizzle_repeat = swizzle_bytes * 8;
+      static constexpr int subtile_cols = swizzle_bytes / sizeof(T);
+      const int outer_idx = c / subtile_cols;
+      const uint64_t addr =
+          (uint64_t)(&ptr
+                         [outer_idx * R * subtile_cols + r * subtile_cols +
+                          c % subtile_cols]);
+      const int swizzle = ((addr % swizzle_repeat) >> 7) << 4;
+      return (T*)(addr ^ swizzle);
+    } else {
+      return ptr + coord.y * C + coord.x;
+    }
+  }
+
+  __device__ static inline uint32_t idx(uint32_t ptr, int2 coord) {
+    if constexpr (swizzle_bytes > 0) {
+      int r = coord.x, c = coord.y;
+      static constexpr int swizzle_repeat = swizzle_bytes * 8;
+      static constexpr int subtile_cols = swizzle_bytes / sizeof(T);
+      const int outer_idx = c / subtile_cols;
+      const uint32_t addr = ptr +
+          sizeof(T) *
+              (outer_idx * R * subtile_cols + r * subtile_cols +
+               c % subtile_cols);
+      const int swizzle = ((addr % swizzle_repeat) >> 7) << 4;
+      return (addr ^ swizzle);
+    } else {
+      return ptr + sizeof(T) * (coord.y * C + coord.x);
+    }
+  }
+
+  __device__ inline void store(float4& v, int2 coord) {
+    *(reinterpret_cast<float4*>(idx(data, coord))) = v;
+  }
+
+  template <int NUM_WARPS>
+  __device__ inline void load(const T* x, int N) {
+    constexpr int NUM_THREADS = NUM_WARPS * 32;
+    constexpr int ELEMENTS_PER_LOAD = sizeof(float4) / sizeof(T);
+    constexpr int NUM_LOADS = NUM_ELEMENTS / ELEMENTS_PER_LOAD;
+    constexpr int NUM_LOADS_PER_THREAD = NUM_LOADS / NUM_THREADS;
+    constexpr int NUM_LOADS_PER_ROW = C / ELEMENTS_PER_LOAD;
+    constexpr int STEP_ROWS = NUM_THREADS / NUM_LOADS_PER_ROW;
+
+    const int row = threadIdx.x / NUM_LOADS_PER_ROW;
+    const int col = threadIdx.x % NUM_LOADS_PER_ROW;
+
+    uint32_t data_ptr =
+        static_cast<uint32_t>(__cvta_generic_to_shared(&data[0]));
+    x += row * N + col * ELEMENTS_PER_LOAD;
+
+#pragma unroll
+    for (int i = 0; i < NUM_LOADS_PER_THREAD; i++) {
+      float4 tmp;
+      tmp = *(reinterpret_cast<const float4*>(&x[i * STEP_ROWS * N]));
+      store(tmp, {row + i * STEP_ROWS, col * ELEMENTS_PER_LOAD});
+    }
+  }
+};
+
+// Warp-level matmul wrappers. The bf16 overload issues two
+// mma.sync.aligned.m16n8k16 instructions so a single call accumulates a full
+// 16x16 product into the fp32 accumulator tile; the generic overload is an
+// empty fallback for unsupported type combinations.
+template <typename TileAccum, typename Tile>
+__device__ inline void mma(TileAccum& C, Tile& A, Tile& B) {}
+
+__device__ inline void mma(
+    Tile16x16<float>& C,
+    Tile16x16<__nv_bfloat16>& A,
+    Tile16x16<__nv_bfloat16>& B) {
+  asm volatile(
+      "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
+      "{%0, %1, %2, %3}, "
+      "{%4, %5, %6, %7}, "
+      "{%8, %9}, "
+      "{%10, %11, %12, %13};"
+
+      // D matrix
+      : "+f"(C.values[0].x),
+        "+f"(C.values[0].y),
+        "+f"(C.values[1].x),
+        "+f"(C.values[1].y)
+
+      // A matrix
+      : "r"(*(uint32_t*)(&A.values[0])),
+        "r"(*(uint32_t*)(&A.values[1])),
+        "r"(*(uint32_t*)(&A.values[2])),
+        "r"(*(uint32_t*)(&A.values[3])),
+
+        // B matrix
+        "r"(*(uint32_t*)(&B.values[0])),
+        "r"(*(uint32_t*)(&B.values[2])),
+
+        // C matrix
+        "f"(C.values[0].x),
+        "f"(C.values[0].y),
+        "f"(C.values[1].x),
+        "f"(C.values[1].y));
+  asm volatile(
+      "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
+      "{%0, %1, %2, %3}, "
+      "{%4, %5, %6, %7}, "
+      "{%8, %9}, "
+      "{%10, %11, %12, %13};"
+
+      // D matrix
+      : "+f"(C.values[2].x),
+        "+f"(C.values[2].y),
+        "+f"(C.values[3].x),
+        "+f"(C.values[3].y)
+
+      // A matrix
+      : "r"(*(uint32_t*)(&A.values[0])),
+        "r"(*(uint32_t*)(&A.values[1])),
+        "r"(*(uint32_t*)(&A.values[2])),
+        "r"(*(uint32_t*)(&A.values[3])),
+
+        // B matrix
+        "r"(*(uint32_t*)(&B.values[1])),
+        "r"(*(uint32_t*)(&B.values[3])),
+
+        // C matrix
+        "f"(C.values[2].x),
+        "f"(C.values[2].y),
+        "f"(C.values[3].x),
+        "f"(C.values[3].y));
+}
+
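+// Prototype qmm kernel: four warps arranged 2x2 cooperatively compute one
+// BM x BN output tile from a single BK-deep slice of the inputs.
+// Quantization is not wired up yet: `w`, `scales` and `biases` are unused,
+// and the second operand is staged from `x` (at offset BM * K), so only the
+// shared-memory staging and tensor-core pipeline is exercised.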
+template <typename T, int BM, int BN, int BK>
+__global__ void qmm(
+    const T* x,
+    const uint8_t* w,
+    const T* scales,
+    const T* biases,
+    T* y,
+    int M,
+    int N,
+    int K) {
+  constexpr int NUM_WARPS = 4;
+  constexpr int WARP_M = (BM / 16) / (NUM_WARPS / 2);
+  constexpr int WARP_N = (BN / 16) / (NUM_WARPS / 2);
+  constexpr int WARP_K = BK / 16;
+  constexpr int WARP_STEP_M = WARP_M * 16;
+  constexpr int WARP_STEP_N = WARP_N * 16;
+
+  const int warpid = threadIdx.x / 32;
+  const int laneid = threadIdx.x % 32;
+  const int offset_m = (warpid / 2) * WARP_STEP_M;
+  const int offset_n = (warpid % 2) * WARP_STEP_N;
+
+  __shared__ SharedTile<BM, BK, T> xs;
+  __shared__ SharedTile<BN, BK, T> ws;
+
+  Tile16x16<float> C[WARP_M * WARP_N];
+  Tile16x16<T> A[WARP_M];
+  Tile16x16<T> B[WARP_N];
+
+#pragma unroll
+  for (int i = 0; i < WARP_M * WARP_N; i++) {
+    C[i].clear();
+  }
+
+  xs.load<NUM_WARPS>(x, K);
+  ws.load<NUM_WARPS>(x + BM * K, K);
+  __syncthreads();
+
+  uint32_t base_addr_xs = __cvta_generic_to_shared(&xs.data[0]);
+  uint32_t base_addr_ws = __cvta_generic_to_shared(&ws.data[0]);
+
+#pragma unroll
+  for (int k = 0; k < WARP_K; k++) {
+#pragma unroll
+    for (int i = 0; i < WARP_M; i++) {
+      A[i].load(xs.idx(
+          base_addr_xs,
+          {offset_m + i * 16 + laneid % 16, k * 16 + laneid / 16 * 8}));
+    }
+#pragma unroll
+    for (int i = 0; i < WARP_N; i++) {
+      B[i].load(ws.idx(
+          base_addr_ws,
+          {offset_n + i * 16 + laneid % 16, k * 16 + laneid / 16 * 8}));
+    }
+
+#pragma unroll
+    for (int i = 0; i < WARP_M; i++) {
+#pragma unroll
+      for (int j = 0; j < WARP_N; j++) {
+        mma(C[i * WARP_N + j], A[i], B[j]);
+      }
+    }
+  }
+
+#pragma unroll
+  for (int i = 0; i < WARP_M; i++) {
+#pragma unroll
+    for (int j = 0; j < WARP_N; j++) {
+      C[i * WARP_N + j].store_global(
+          y + (offset_m + i * 16) * N + offset_n + j * 16, N);
+    }
+  }
+}
+
+} // namespace cu
 
 void qmm(
     const array& x,
@@ -24,13 +375,25 @@ void qmm(
     cu::CommandEncoder& enc,
     const Stream& s) {
   dispatch_float_types(x.dtype(), "qmm", [&](auto type_tag) {
-    dispatch_groups(group_size_, [&](auto group_size) {
-      dispatch_bits(bits_, [&](auto bits) {
-        dispatch_bool(transpose_, [&](auto transpose) {
-          using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
-        });
-      });
-    });
+    // dispatch_groups(group_size_, [&](auto group_size) {
+    //   dispatch_bits(bits_, [&](auto bits) {
+    using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
+    auto kernel = cu::qmm<DataType, 64, 64, 32>;
+
+    enc.add_kernel_node(
+        kernel,
+        1,
+        128,
+        x.data<DataType>(),
+        w.data<uint8_t>(),
+        scales.data<DataType>(),
+        biases.data<DataType>(),
+        out.data<DataType>(),
+        M,
+        N,
+        K);
+    //});
+    //});
   });
 }