Early stages of the steel utils

This commit is contained in:
Angelos Katharopoulos
2025-07-29 16:47:45 -07:00
parent 8b25ce62d5
commit 1523b803f3
5 changed files with 722 additions and 0 deletions

View File

@@ -143,6 +143,7 @@ void gemv(
kernel,
num_blocks_x,
block_dims,
0,
mat,
vec,
out.data<DataType>(),
@@ -154,6 +155,7 @@ void gemv(
kernel,
dim3{num_blocks_x, batch_count},
block_dims,
0,
mat,
vec,
out.data<DataType>(),

View File

@@ -0,0 +1,101 @@
#include "mlx/backend/cuda/steel/mma.cuh"
#include "mlx/backend/cuda/steel/tiles.cuh"
namespace mlx::core::cu {
/**
* An example gemm written with the utils.
*
* Computes A @ B.T when A and B are all aligned with the block sizes.
*/
template <typename T, int BM, int BN, int BK>
__global__ void ab_t_aligned(const T* a, const T* b, T* y, int N, int K) {
constexpr int WARPS_M = 2;
constexpr int WARPS_N = 2;
constexpr int NUM_WARPS = WARPS_M * WARPS_N;
constexpr int WARP_STEP_M = BM / WARPS_M;
constexpr int WARP_STEP_N = BN / WARPS_N;
// Precompute some offsets for each thread
const int warpid = threadIdx.x / 32;
const int laneid = threadIdx.x % 32;
const int wm = warpid / WARPS_N;
const int wn = warpid % WARPS_N;
const int offset_m = wm * WARP_STEP_M;
const int offset_n = wn * WARP_STEP_N;
// Allocate shared memory
extern __shared__ char shmem[];
SharedTile<T, BM, BK>(&as)[2] = *(SharedTile<T, BM, BK>(*)[2])(&shmem[0]);
SharedTile<T, BN, BK>(&bs)[2] =
*(SharedTile<T, BN, BK>(*)[2])(&shmem[sizeof(T) * 2 * BM * BK]);
// Allocate registers for the MMA
RegisterTile<float, BM / WARPS_M, BN / WARPS_N> C;
RegisterTile<T, BM / WARPS_M, 16> A;
RegisterTile<T, BN / WARPS_N, 16> B;
// Move the global pointers to the tile
a += blockIdx.y * BM * K;
b += blockIdx.x * BN * K;
y += blockIdx.y * BM * N + blockIdx.x * BN;
// Zero the accumulators
C.fill(0);
// Start the SM pipeline
load_async<NUM_WARPS>(as[0], as[0].base_addr(), a, K);
load_async<NUM_WARPS>(bs[0], bs[0].base_addr(), b, K);
cp_async_commit();
int tic = 0;
for (int k_block = BK; k_block < K; k_block += BK) {
load_async<NUM_WARPS>(as[tic ^ 1], as[tic ^ 1].base_addr(), a + k_block, K);
load_async<NUM_WARPS>(bs[tic ^ 1], bs[tic ^ 1].base_addr(), b + k_block, K);
cp_async_commit();
cp_async_wait<1>();
__syncthreads();
MLX_UNROLL
for (int k = 0; k < BK / 16; k++) {
A.load(
as[tic],
as[tic].base_addr(),
offset_m + laneid % 16,
k * 16 + laneid / 16 * 8);
B.load(
bs[tic],
bs[tic].base_addr(),
offset_n + laneid % 16,
k * 16 + laneid / 16 * 8);
mma_t(C, A, B);
}
tic ^= 1;
}
// Empty the pipeline
cp_async_wait_all();
__syncthreads();
MLX_UNROLL
for (int k = 0; k < BK / 16; k++) {
A.load(
as[tic],
as[tic].base_addr(),
offset_m + laneid % 16,
k * 16 + laneid / 16 * 8);
B.load(
bs[tic],
bs[tic].base_addr(),
offset_n + laneid % 16,
k * 16 + laneid / 16 * 8);
mma_t(C, A, B);
}
C.store_global(y, N, offset_m, offset_n);
}
} // namespace mlx::core::cu

View File

@@ -0,0 +1,114 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include "mlx/backend/cuda/steel/tiles.cuh"
namespace mlx::core::cu {
/**
* Fallback mma.
*
* We should probably a) implement a fallback or complain about it to the
* compiler.
*/
template <typename U, typename T>
__device__ inline void
mma_t(Tile16x16<U>& C, Tile16x16<T>& A, Tile16x16<T>& B) {}
/**
* Multiply the 16x16 bfloat16 tiles and accumulate the result in one 16x16
* float tile.
*
* We actually perform C += A @ B.T
*/
__device__ __forceinline__ void mma_t(
Tile16x16<float>& C,
Tile16x16<__nv_bfloat16>& A,
Tile16x16<__nv_bfloat16>& B) {
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
"{%0, %1, %2, %3}, "
"{%4, %5, %6, %7}, "
"{%8, %9}, "
"{%10, %11, %12, %13};"
// D matrix
: "+f"(C.values[0].x),
"+f"(C.values[0].y),
"+f"(C.values[1].x),
"+f"(C.values[1].y)
// A matrix
: "r"(*(uint32_t*)(&A.values[0])),
"r"(*(uint32_t*)(&A.values[1])),
"r"(*(uint32_t*)(&A.values[2])),
"r"(*(uint32_t*)(&A.values[3])),
// B matrix
"r"(*(uint32_t*)(&B.values[0])),
"r"(*(uint32_t*)(&B.values[2])),
// C matrix
"f"(C.values[0].x),
"f"(C.values[0].y),
"f"(C.values[1].x),
"f"(C.values[1].y));
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
"{%0, %1, %2, %3}, "
"{%4, %5, %6, %7}, "
"{%8, %9}, "
"{%10, %11, %12, %13};"
// D matrix
: "+f"(C.values[2].x),
"+f"(C.values[2].y),
"+f"(C.values[3].x),
"+f"(C.values[3].y)
// A matrix
: "r"(*(uint32_t*)(&A.values[0])),
"r"(*(uint32_t*)(&A.values[1])),
"r"(*(uint32_t*)(&A.values[2])),
"r"(*(uint32_t*)(&A.values[3])),
// B matrix
"r"(*(uint32_t*)(&B.values[1])),
"r"(*(uint32_t*)(&B.values[3])),
// C matrix
"f"(C.values[2].x),
"f"(C.values[2].y),
"f"(C.values[3].x),
"f"(C.values[3].y));
}
/**
* Multiply larger register tiles by delegating to mma_t.
*/
template <typename U, typename T, int M, int N, int K>
__device__ __forceinline__ void mma_t(
RegisterTile<U, M, N>& C,
RegisterTile<T, M, K>& A,
RegisterTile<T, N, K>& B) {
constexpr int TILES_M = RegisterTile<T, M, K>::TILES_Y;
constexpr int TILES_K = RegisterTile<T, M, K>::TILES_X;
constexpr int TILES_N = RegisterTile<T, N, K>::TILES_Y;
MLX_UNROLL
for (int k = 0; k < TILES_K; k++) {
MLX_UNROLL
for (int m = 0; m < TILES_M; m++) {
MLX_UNROLL
for (int n = 0; n < TILES_N; n++) {
mma_t(
C.data[m * TILES_N + n],
A.data[m * TILES_K + k],
B.data[n * TILES_K + k]);
}
}
}
}
} // namespace mlx::core::cu

View File

@@ -0,0 +1,420 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include "mlx/backend/steel/utils.cuh"
namespace mlx::core::cu {
// Map types to their vector of 2 type float -> float2, double -> double2 etc
template <typename T>
struct Vector2;
template <>
struct Vector2<double> {
using type = double2;
};
template <>
struct Vector2<float> {
using type = float2;
};
template <>
struct Vector2<__half> {
using type = __half2;
};
template <>
struct Vector2<__nv_bfloat16> {
using type = __nv_bfloat162;
};
template <typename T>
using Vector2_t = typename Vector2<T>::type;
/**
* The basic building block for Ampere mmas. A 16x16 tile distributed across
* the warp.
*
* Each thread holds 8 values. They are distributed according to
* https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-fragment-mma-16816-float
*
* For use instructions see the individual methods eg load().
*/
template <typename T>
struct Tile16x16 {
using T2 = Vector2_t<T>;
T2 values[4];
__device__ inline void fill(T v) {
T2 v2 = {v, v};
for (int i = 0; i < 4; i++) {
values[i] = v2;
}
}
/**
* Load a 16x16 tile from shared memory.
*
* The instruction is a bit weird in the sense that the address provided by
* each thread and the elements loaded are not the same.
*
* We load 4 8x8 tiles. The tile rows are stored contiguously in memory. As a
* result the warp provides 4*8 = 32 addresses one per row.
*
* Threads 0-7 provide the addresses for the first tile, 8-15 for the second
* and so on. For instance to load a non swizzled tile we would do
*
* base_addr + (laneid % 16) * BK + (laneid / 2) * 8
*
* See
* https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-instructions-ldmatrix
*/
__device__ __forceinline__ void load(uint32_t row_address) {
if constexpr (
std::is_same_v<T2, __nv_bfloat162> || std::is_same_v<T2, __half2>) {
asm volatile(
"ldmatrix.sync.aligned.m8n8.x4.shared::cta.b16 {%0, %1, %2, %3}, [%4];\n"
: "=r"(*(uint32_t*)&(values[0])),
"=r"(*(uint32_t*)&(values[1])),
"=r"(*(uint32_t*)&(values[2])),
"=r"(*(uint32_t*)&(values[3]))
: "r"(row_address));
}
}
/**
* Store the tile to the address pointed to by `x`.
*
* The provided pointer is a generic pointer but this is meant to be used to
* store to global memory. For storing to shared memory we should use
* `stmatrix`.
*
* This also showcases the format of the tile quite nicely. Each register is
* holding to adjacent values. The indices are
*
* row + 0, col + 0
* row + 8, col + 0
* row + 0, col + 8
* row + 8, col + 8
*
* Given that we are dealing with Vector2_t<U> the column offsets are 4
* instead of 8.
*/
template <typename U>
__device__ inline void store_global(U* x, int N) {
using U2 = Vector2_t<U>;
U2* x2 = reinterpret_cast<U2*>(x);
const int laneid = threadIdx.x % 32;
const int row = laneid / 4;
const int col = laneid % 4;
if constexpr (std::is_same_v<U2, T2>) {
x2[(row + 0) * (N / 2) + col + 0] = values[0];
x2[(row + 0) * (N / 2) + col + 4] = values[2];
x2[(row + 8) * (N / 2) + col + 0] = values[1];
x2[(row + 8) * (N / 2) + col + 4] = values[3];
} else if constexpr (
std::is_same_v<T2, float2> && std::is_same_v<U, __nv_bfloat16>) {
x2[(row + 0) * (N / 2) + col + 0] =
__floats2bfloat162_rn(values[0].x, values[0].y);
x2[(row + 0) * (N / 2) + col + 4] =
__floats2bfloat162_rn(values[2].x, values[2].y);
x2[(row + 8) * (N / 2) + col + 0] =
__floats2bfloat162_rn(values[1].x, values[1].y);
x2[(row + 8) * (N / 2) + col + 4] =
__floats2bfloat162_rn(values[3].x, values[3].y);
}
}
template <typename U>
__device__ inline void store_global_safe(U* x, int N, int max_rows) {
const int laneid = threadIdx.x % 32;
const int row = laneid / 4;
const int col = laneid % 4;
if (row < max_rows) {
x[(row + 0) * N + 2 * col + 0] = static_cast<U>(values[0].x);
x[(row + 0) * N + 2 * col + 1] = static_cast<U>(values[0].y);
x[(row + 0) * N + 2 * col + 8] = static_cast<U>(values[2].x);
x[(row + 0) * N + 2 * col + 9] = static_cast<U>(values[2].y);
}
if (row + 8 < max_rows) {
x[(row + 8) * N + 2 * col + 0] = static_cast<U>(values[1].x);
x[(row + 8) * N + 2 * col + 1] = static_cast<U>(values[1].y);
x[(row + 8) * N + 2 * col + 8] = static_cast<U>(values[3].x);
x[(row + 8) * N + 2 * col + 9] = static_cast<U>(values[3].y);
}
}
};
/**
* A simple container of multiple Tile16x16.
*
* Provides utility functions for loading and manipulating collections of basic
* tiles.
*/
template <typename T, int ROWS_, int COLS_>
struct RegisterTile {
static constexpr int ROWS = ROWS_;
static constexpr int COLS = COLS_;
static constexpr int TILES_X = COLS / 16;
static constexpr int TILES_Y = ROWS / 16;
Tile16x16<T> data[TILES_X * TILES_Y];
__device__ inline void fill(T v) {
MLX_UNROLL
for (int i = 0; i < TILES_Y; i++) {
MLX_UNROLL
for (int j = 0; j < TILES_X; j++) {
data[i * TILES_X + j].fill(v);
}
}
}
template <typename Tile>
__device__ __forceinline__ void
load(Tile& tile, uint32_t base_address, int row, int col) {
MLX_UNROLL
for (int i = 0; i < TILES_Y; i++) {
MLX_UNROLL
for (int j = 0; j < TILES_X; j++) {
data[i * TILES_X + j].load(
tile.loc(base_address, row + i * 16, col + j * 16));
}
}
}
template <typename Tile, typename F>
__device__ __forceinline__ void
load(Tile& tile, F f, uint32_t base_address, int row, int col) {
MLX_UNROLL
for (int i = 0; i < TILES_Y; i++) {
MLX_UNROLL
for (int j = 0; j < TILES_X; j++) {
f(data[i * TILES_X + j],
tile,
base_address,
row + i * 16,
col + j * 16);
}
}
}
template <typename U>
__device__ inline void store_global(U* x, int N, int row, int col) {
MLX_UNROLL
for (int i = 0; i < TILES_Y; i++) {
MLX_UNROLL
for (int j = 0; j < TILES_X; j++) {
data[i * TILES_X + j].store_global(
x + (row + i * 16) * N + col + j * 16, N);
}
}
}
template <typename U>
__device__ inline void
store_global_safe(U* x, int N, int row, int col, int max_rows) {
MLX_UNROLL
for (int i = 0; i < TILES_Y; i++) {
MLX_UNROLL
for (int j = 0; j < TILES_X; j++) {
data[i * TILES_X + j].store_global_safe(
x + (row + i * 16) * N + col + j * 16, N, max_rows - row - i * 16);
}
}
}
};
template <typename T, int ROWS_, int COLS_>
struct SharedTile {
static constexpr int ROWS = ROWS_;
static constexpr int COLS = COLS_;
static constexpr int TILES_X = COLS / 16;
static constexpr int TILES_Y = ROWS / 16;
static constexpr int NUMEL = ROWS * COLS;
// Swizzle taken from ThunderKittens. Should be changed when we switch to
// cute Layouts.
//
// See inludes/types/shared/st.cuh
//
// I do feel that it is too math heavy and can be improved. Also the math is
// done every time although the addresses don't change from load to load. I
// guess we are expecting the compiler to figure that out.
static constexpr int swizzle_bytes =
(sizeof(T) == 2 ? (TILES_X % 4 == 0 ? 128 : (TILES_X % 2 == 0 ? 64 : 32))
: (sizeof(T) == 4 ? (TILES_X % 2 == 0 ? 128 : 64) : 0));
T data[ROWS * COLS];
__device__ inline uint32_t base_addr() const {
return __cvta_generic_to_shared(&data[0]);
}
// Return a pointer to the element at (row, col) using the swizzle.
__device__ static inline T* ptr(T* ptr, int row, int col) {
if constexpr (swizzle_bytes > 0) {
static constexpr int swizzle_repeat = swizzle_bytes * 8;
static constexpr int subtile_cols = swizzle_bytes / sizeof(T);
const int outer_idx = col / subtile_cols;
const uint64_t addr =
(uint64_t)(&ptr
[outer_idx * ROWS * subtile_cols + row * subtile_cols +
col % subtile_cols]);
const int swizzle = ((addr % swizzle_repeat) >> 7) << 4;
return (T*)(addr ^ swizzle);
} else {
return ptr + row * COLS + col;
}
}
// Return the location of the element at (row, col) using the swizzle.
__device__ static inline uint32_t loc(uint32_t ptr, int row, int col) {
if constexpr (swizzle_bytes > 0) {
static constexpr int swizzle_repeat = swizzle_bytes * 8;
static constexpr int subtile_cols = swizzle_bytes / sizeof(T);
const int outer_idx = col / subtile_cols;
const uint32_t addr = ptr +
sizeof(T) *
(outer_idx * ROWS * subtile_cols + row * subtile_cols +
col % subtile_cols);
const int swizzle = ((addr % swizzle_repeat) >> 7) << 4;
return (addr ^ swizzle);
} else {
return ptr + sizeof(T) * (row * COLS + col);
}
}
// Convenience functions to edit elements going through the swizzle.
__device__ inline T& operator()(int row, int col) {
return *ptr(data, row, col);
}
__device__ inline void store(float4& v, int row, int col) {
*(reinterpret_cast<float4*>(ptr(data, row, col))) = v;
}
__device__ inline void store(float2& v, int row, int col) {
*(reinterpret_cast<float2*>(ptr(data, row, col))) = v;
}
__device__ inline void store(float& v, int row, int col) {
*(reinterpret_cast<float*>(ptr(data, row, col))) = v;
}
template <int N>
__device__ inline void store(T (&v)[N], int row, int col) {
if constexpr (sizeof(T) * N == 4) {
store(*(reinterpret_cast<float*>(&v[0])), row, col);
} else if constexpr (sizeof(T) * N == 8) {
store(*(reinterpret_cast<float2*>(&v[0])), row, col);
} else if constexpr (sizeof(T) * N == 16) {
store(*(reinterpret_cast<float4*>(&v[0])), row, col);
} else {
MLX_UNROLL
for (int i = 0; i < N; i++) {
*ptr(data, row, col + i) = v[i];
}
}
}
};
/**
* Load the tile from global memory by loading 16 bytes at a time and storing
* them immediately.
*
* Can also be used as a fallback for architectures before sm_80.
*/
template <int NUM_WARPS, typename T, typename Tile>
__device__ inline void load(Tile& tile, const T* x, int N) {
constexpr int NUM_THREADS = NUM_WARPS * 32;
constexpr int ELEMENTS_PER_LOAD = sizeof(float4) / sizeof(T);
constexpr int NUM_LOADS = Tile::NUMEL / ELEMENTS_PER_LOAD;
constexpr int NUM_LOADS_PER_THREAD = NUM_LOADS / NUM_THREADS;
constexpr int NUM_LOADS_PER_ROW = Tile::COLS / ELEMENTS_PER_LOAD;
constexpr int STEP_ROWS = NUM_THREADS / NUM_LOADS_PER_ROW;
const int row = threadIdx.x / NUM_LOADS_PER_ROW;
const int col = threadIdx.x % NUM_LOADS_PER_ROW;
x += row * N + col * ELEMENTS_PER_LOAD;
MLX_UNROLL
for (int i = 0; i < NUM_LOADS_PER_THREAD; i++) {
float4 tmp;
tmp = *(reinterpret_cast<const float4*>(&x[i * STEP_ROWS * N]));
tile.store(tmp, row + i * STEP_ROWS, col * ELEMENTS_PER_LOAD);
}
}
/**
* The asynchronous equivalent of load.
*
* Loads the tile from global memory by submitting a bunch of async copy
* instructions. The copy won't start until commit is called and we don't have
* a guarantee it will finish until wait is called.
*
* It should be used as follows
*
* load(...)
* load(...)
* cp_async_commit()
* do_other_stuff()
* cp_async_wait_all()
* do_stuff_with_shmem()
*/
template <int NUM_WARPS, typename T, typename Tile>
__device__ inline void
load_async(Tile& tile, uint32_t base_address, const T* x, int N) {
constexpr int NUM_THREADS = NUM_WARPS * 32;
constexpr int ELEMENTS_PER_LOAD = sizeof(float4) / sizeof(T);
constexpr int NUM_LOADS = Tile::NUMEL / ELEMENTS_PER_LOAD;
constexpr int NUM_LOADS_PER_THREAD = NUM_LOADS / NUM_THREADS;
constexpr int NUM_LOADS_PER_ROW = Tile::COLS / ELEMENTS_PER_LOAD;
constexpr int STEP_ROWS = NUM_THREADS / NUM_LOADS_PER_ROW;
const int row = threadIdx.x / NUM_LOADS_PER_ROW;
const int col = threadIdx.x % NUM_LOADS_PER_ROW;
x += row * N + col * ELEMENTS_PER_LOAD;
MLX_UNROLL
for (int i = 0; i < NUM_LOADS_PER_THREAD; i++) {
cp_async<16>(
tile.loc(base_address, row + i * STEP_ROWS, col * ELEMENTS_PER_LOAD),
x + i * STEP_ROWS * N);
}
}
/**
* Same as load_async but checks if we can load the row.
*
* NOTE: It should be changed to use a predicated cp async instead.
*/
template <int NUM_WARPS, typename T, typename Tile>
__device__ inline void load_async_safe(
Tile& tile,
uint32_t base_address,
const T* x,
int N,
int max_rows) {
constexpr int NUM_THREADS = NUM_WARPS * 32;
constexpr int ELEMENTS_PER_LOAD = sizeof(float4) / sizeof(T);
constexpr int NUM_LOADS = Tile::NUMEL / ELEMENTS_PER_LOAD;
constexpr int NUM_LOADS_PER_THREAD = NUM_LOADS / NUM_THREADS;
constexpr int NUM_LOADS_PER_ROW = Tile::COLS / ELEMENTS_PER_LOAD;
constexpr int STEP_ROWS = NUM_THREADS / NUM_LOADS_PER_ROW;
const int row = threadIdx.x / NUM_LOADS_PER_ROW;
const int col = threadIdx.x % NUM_LOADS_PER_ROW;
x += row * N + col * ELEMENTS_PER_LOAD;
MLX_UNROLL
for (int i = 0; i < NUM_LOADS_PER_THREAD; i++) {
if (row + i * STEP_ROWS < max_rows) {
cp_async<16>(
tile.loc(base_address, row + i * STEP_ROWS, col * ELEMENTS_PER_LOAD),
x + i * STEP_ROWS * N);
} else {
float4 tmp = {0, 0, 0, 0};
tile.store(tmp, row + i * STEP_ROWS, col * ELEMENTS_PER_LOAD);
}
}
}
} // namespace mlx::core::cu

View File

@@ -0,0 +1,85 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include "mlx/backend/cuda/device/utils.cuh"
#define MLX_UNROLL _Pragma("unroll")
namespace mlx::core::cu {
/**
* Copy bytes from the global memory address pointed to by x to the smem
* address pointed to by row_address.
*
* A simple wrapper over the PTX.
*/
template <int N, typename T>
__device__ inline void cp_async(uint32_t row_address, const T* x) {
static_assert(
N == 16 || N == 8 || N == 4,
"cp.async is only supported for N in {4, 8, 16}.");
if constexpr (N == 16) {
asm volatile(
"cp.async.ca.shared::cta.global [%0], [%1], 16;\n" ::"r"(row_address),
"l"(reinterpret_cast<const int4*>(x)));
} else if constexpr (N == 8) {
asm volatile(
"cp.async.ca.shared::cta.global [%0], [%1], 8;\n" ::"r"(row_address),
"l"(reinterpret_cast<const int2*>(x)));
} else if constexpr (N == 4) {
asm volatile(
"cp.async.ca.shared::cta.global [%0], [%1], 4;\n" ::"r"(row_address),
"l"(reinterpret_cast<const int*>(x)));
}
}
/**
* Submit all the previous async copies to be executed.
*/
__device__ inline void cp_async_commit() {
asm volatile("cp.async.commit_group;\n" ::);
}
/**
* Wait for all but N of the async copies to finish.
*/
template <int N>
__device__ inline void cp_async_wait() {
if constexpr (N == 0) {
asm volatile("cp.async.wait_all;\n" ::);
} else {
asm volatile("cp.async.wait_group %0;\n" ::"n"(N));
}
}
/**
* Wait for all the async copies to finish.
*/
__device__ inline void cp_async_wait_all() {
cp_async_wait<0>();
}
/**
* Extract ``bits`` bits from the 32 bit value.
*
* Single instruction shift and mask.
*/
template <int bits>
__device__ inline uint32_t extract_bits(uint32_t value, int start_bit) {
static_assert(
bits == 2 || bits == 4 || bits == 8,
"extract_bits only supports 2, 4, 8 for now.");
uint32_t result;
if constexpr (bits == 2) {
asm("bfe.u32 %0, %1, %2, 2;" : "=r"(result) : "r"(value), "r"(start_bit));
} else if constexpr (bits == 4) {
asm("bfe.u32 %0, %1, %2, 4;" : "=r"(result) : "r"(value), "r"(start_bit));
} else if constexpr (bits == 8) {
asm("bfe.u32 %0, %1, %2, 8;" : "=r"(result) : "r"(value), "r"(start_bit));
}
return result;
}
} // namespace mlx::core::cu