Compare commits

...

8 Commits

Author | SHA1 | Message | Date
Eric Buehler | 441bd764e6 | Merge 4d68bd3250 into 4fda5fbdf9 | 2025-06-15 21:37:05 +02:00
Awni Hannun | 4fda5fbdf9 | add python testing for cuda with ability to skip list of tests (#2295) | 2025-06-15 10:56:48 -07:00
Angelos Katharopoulos | 580776559b | RoPE for CUDA (#2293): First working CUDA rope; Fix random | 2025-06-15 06:08:07 -07:00
Eric Buehler | 4d68bd3250 | Refactor v1/v2 caller code | 2025-05-31 09:48:24 -04:00
Eric Buehler | 5fbce6c49e | Add v2 call | 2025-05-31 09:30:51 -04:00
Eric Buehler | 0b5c5680f4 | Add v1 call | 2025-05-31 09:20:22 -04:00
Eric Buehler | 221edc4a65 | Add the attention kernel | 2025-05-31 08:25:25 -04:00
Eric Buehler | 190c72739b | Add pagedattn primitive | 2025-05-31 08:10:33 -04:00
56 changed files with 2620 additions and 64 deletions

View File

@@ -234,6 +234,7 @@ jobs:
command: |
source env/bin/activate
LOW_MEMORY=1 DEVICE=cpu python -m unittest discover python/tests -v
LOW_MEMORY=1 DEVICE=gpu python -m tests discover python/tests -v
build_release:
parameters:

View File

@@ -12,6 +12,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp
${CMAKE_CURRENT_SOURCE_DIR}/graph_utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/paged_attention.cpp
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
${CMAKE_CURRENT_SOURCE_DIR}/random.cpp
${CMAKE_CURRENT_SOURCE_DIR}/scheduler.cpp

View File

@@ -209,4 +209,14 @@ Dims get_2d_grid_dims_common(
static_cast<uint32_t>(grid_x), static_cast<uint32_t>(grid_y), 1);
}
std::pair<Dims, Dims> get_grid_and_block_common(int dim0, int dim1, int dim2) {
auto [bx, by, bz] = get_block_dims_common(dim0, dim1, dim2);
auto gx = (dim0 + bx - 1) / bx;
auto gy = (dim1 + by - 1) / by;
auto gz = (dim2 + bz - 1) / bz;
return std::make_pair(
std::make_tuple(gx, gy, gz), std::make_tuple(bx, by, bz));
}
} // namespace mlx::core
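For reference, the ceil division above sizes the grid so that grid times block covers every element in each dimension. A minimal standalone check of that identity (values are illustrative, not from this diff):

// Standalone check of the ceil-division used in get_grid_and_block_common.
// The values below are illustrative only.
#include <cassert>

int main() {
  int dim0 = 1000, bx = 128;
  int gx = (dim0 + bx - 1) / bx; // ceil(1000 / 128) = 8
  assert(gx * bx >= dim0);       // 8 * 128 = 1024 >= 1000, so all covered
  assert((gx - 1) * bx < dim0);  // and no whole block is wasted
  return 0;
}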

View File

@@ -95,6 +95,9 @@ Dims get_2d_grid_dims_common(
const Strides& strides,
size_t divisor);
// Get both the block and a grid of blocks that covers dim0, dim1 and dim2.
std::pair<Dims, Dims> get_grid_and_block_common(int dim0, int dim1, int dim2);
struct ContiguousIterator {
inline void step() {
int dims = shape_.size();

View File

@@ -32,6 +32,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/reduce/row_reduce.cu
${CMAKE_CURRENT_SOURCE_DIR}/reduce/segmented_reduce.cu
${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.cu
${CMAKE_CURRENT_SOURCE_DIR}/rope.cu
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cu
${CMAKE_CURRENT_SOURCE_DIR}/sort.cu

View File

@@ -23,4 +23,11 @@ dim3 get_2d_grid_dims(
return dim3(std::get<0>(dims), std::get<1>(dims), std::get<2>(dims));
}
std::pair<dim3, dim3> get_grid_and_block(int dim0, int dim1, int dim2) {
auto [grid, block] = get_grid_and_block_common(dim0, dim1, dim2);
auto [gx, gy, gz] = grid;
auto [bx, by, bz] = block;
return std::make_pair(dim3(gx, gy, gz), dim3(bx, by, bz));
}
} // namespace mlx::core

View File

@@ -121,6 +121,7 @@ dim3 get_2d_grid_dims(
const Shape& shape,
const Strides& strides,
size_t divisor);
std::pair<dim3, dim3> get_grid_and_block(int dim0, int dim1, int dim2);
// Return a block size that achieves maximum potential occupancy for kernel.
template <typename T>

View File

@@ -94,7 +94,6 @@ NO_GPU_MULTI(Eig)
NO_GPU_MULTI(Eigh)
namespace fast {
NO_GPU_USE_FALLBACK(RoPE)
NO_GPU(ScaledDotProductAttention)
NO_GPU_MULTI(AffineQuantize)
NO_GPU_MULTI(CustomKernel)

View File

@@ -4,6 +4,7 @@
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/primitives.h"
#include <cooperative_groups.h>
#include <nvtx3/nvtx3.hpp>
#include <cassert>
@@ -12,6 +13,8 @@ namespace mlx::core {
namespace cu {
namespace cg = cooperative_groups;
__constant__ constexpr uint32_t rotations[2][4] = {
{13, 15, 26, 6},
{17, 29, 16, 24}};
@@ -47,27 +50,28 @@ __global__ void rbitsc(
dim3 grid_dims,
bool odd,
uint32_t bytes_per_key) {
uint2 index{
blockIdx.x * blockDim.x + threadIdx.x,
blockIdx.y * blockDim.y + threadIdx.y};
if (index.x >= grid_dims.x || index.y >= grid_dims.y) {
auto grid = cg::this_grid();
uint thread_index = grid.thread_rank();
uint index_x = thread_index % grid_dims.x;
uint index_y = thread_index / grid_dims.x;
if (index_x >= grid_dims.x || index_y >= grid_dims.y) {
return;
}
auto kidx = 2 * index.x;
auto kidx = 2 * index_x;
auto key = uint2{keys[kidx], keys[kidx + 1]};
auto half_size = grid_dims.y - odd;
out += index.x * bytes_per_key;
bool drop_last = odd && (index.y == half_size);
out += index_x * bytes_per_key;
bool drop_last = odd && (index_y == half_size);
auto bits = threefry2x32_hash(
key, uint2{index.y, drop_last ? 0 : index.y + grid_dims.y});
size_t idx = size_t(index.y) << 2;
key, uint2{index_y, drop_last ? 0 : index_y + grid_dims.y});
size_t idx = size_t(index_y) << 2;
for (int i = 0; i < 4; ++i) {
out[idx + i] = bits.bytes[0][i];
}
if (!drop_last) {
idx = (drop_last ? 0 : size_t(index.y) + grid_dims.y) << 2;
if ((index.y + 1) == half_size && (bytes_per_key % 4) > 0) {
idx = (drop_last ? 0 : size_t(index_y) + grid_dims.y) << 2;
if ((index_y + 1) == half_size && (bytes_per_key % 4) > 0) {
int edge_bytes = (bytes_per_key % 4);
for (int i = 0; i < edge_bytes; ++i) {
out[idx + i] = bits.bytes[1][i];
@@ -89,30 +93,31 @@ __global__ void rbits(
int32_t ndim,
const __grid_constant__ Shape key_shape,
const __grid_constant__ Strides key_strides) {
uint2 index{
blockIdx.x * blockDim.x + threadIdx.x,
blockIdx.y * blockDim.y + threadIdx.y};
if (index.x >= grid_dims.x || index.y >= grid_dims.y) {
auto grid = cg::this_grid();
uint thread_index = grid.thread_rank();
uint index_x = thread_index % grid_dims.x;
uint index_y = thread_index / grid_dims.x;
if (index_x >= grid_dims.x || index_y >= grid_dims.y) {
return;
}
auto kidx = 2 * index.x;
auto kidx = 2 * index_x;
auto k1_elem = elem_to_loc(kidx, key_shape.data(), key_strides.data(), ndim);
auto k2_elem =
elem_to_loc(kidx + 1, key_shape.data(), key_strides.data(), ndim);
auto key = uint2{keys[k1_elem], keys[k2_elem]};
auto half_size = grid_dims.y - odd;
out += size_t(index.x) * bytes_per_key;
bool drop_last = odd && (index.y == half_size);
out += size_t(index_x) * bytes_per_key;
bool drop_last = odd && (index_y == half_size);
auto bits = threefry2x32_hash(
key, uint2{index.y, drop_last ? 0 : index.y + grid_dims.y});
size_t idx = size_t(index.y) << 2;
key, uint2{index_y, drop_last ? 0 : index_y + grid_dims.y});
size_t idx = size_t(index_y) << 2;
for (int i = 0; i < 4; ++i) {
out[idx + i] = bits.bytes[0][i];
}
if (!drop_last) {
idx = (drop_last ? 0 : size_t(index.y) + grid_dims.y) << 2;
if ((index.y + 1) == half_size && (bytes_per_key % 4) > 0) {
idx = (drop_last ? 0 : size_t(index_y) + grid_dims.y) << 2;
if ((index_y + 1) == half_size && (bytes_per_key % 4) > 0) {
int edge_bytes = (bytes_per_key % 4);
for (int i = 0; i < edge_bytes; ++i) {
out[idx + i] = bits.bytes[1][i];
@@ -153,19 +158,22 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) {
dim3 grid_dims{num_keys, half_size + odd};
dim3 block_dims = get_block_dims(grid_dims.x, grid_dims.y, 1);
dim3 num_blocks{
cuda::ceil_div(grid_dims.x, block_dims.x),
cuda::ceil_div(grid_dims.y, block_dims.y)};
int64_t total = grid_dims.x * grid_dims.y;
int32_t threads_y = 1;
while ((total / threads_y) >= (1U << 31)) {
threads_y *= 2;
}
int32_t threads_x = cuda::ceil_div(total, threads_y);
auto [grid, block] = get_grid_and_block(threads_x, threads_y, 1);
if (keys.flags().row_contiguous) {
cu::rbitsc<<<num_blocks, block_dims, 0, stream>>>(
cu::rbitsc<<<grid, block, 0, stream>>>(
keys.data<uint32_t>(),
out.data<uint8_t>(),
grid_dims,
odd,
bytes_per_key);
} else {
cu::rbits<<<num_blocks, block_dims, 0, stream>>>(
cu::rbits<<<grid, block, 0, stream>>>(
keys.data<uint32_t>(),
out.data<uint8_t>(),
grid_dims,
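A note on the launch-size change in this hunk: the kernels now derive a flat thread rank and split it into (index_x, index_y), while the host grows threads_y until total / threads_y drops below 2^31, presumably so the 32-bit thread-rank arithmetic in the kernels stays in range. A host-side sketch of that split (illustrative values only):

// Host-side sketch of the split used in RandomBits::eval_gpu above:
// grow threads_y until the x extent fits below 2^31 (illustrative values).
#include <cstdint>
#include <cstdio>

int main() {
  int64_t total = int64_t(5) * (1LL << 31); // pretend total = grid_x * grid_y
  int32_t threads_y = 1;
  while ((total / threads_y) >= (1LL << 31)) {
    threads_y *= 2;
  }
  int64_t threads_x = (total + threads_y - 1) / threads_y; // cuda::ceil_div
  std::printf("threads_x=%lld threads_y=%d\n", (long long)threads_x, (int)threads_y);
  return 0;
}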

mlx/backend/cuda/rope.cu Normal file
View File

@@ -0,0 +1,385 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/gpu/copy.h"
#include "mlx/dtype_utils.h"
#include "mlx/fast_primitives.h"
#include <nvtx3/nvtx3.hpp>
namespace mlx::core {
namespace cu {
template <typename T, bool traditional, bool forward>
__device__ void rope_single_impl(
const T* in,
T* out,
int32_t offset,
float inv_freq,
float scale,
int64_t stride,
uint2 pos,
uint2 dims) {
float L = scale * static_cast<float>(offset);
// Compute costheta, sintheta
float theta = L * inv_freq;
float costheta = cos(theta);
float sintheta = sin(theta);
// Compute the input and output indices
uint index_1, index_2;
if (traditional) {
index_1 = 2 * pos.x + pos.y * stride;
index_2 = index_1 + 1;
} else {
index_1 = pos.x + pos.y * stride;
index_2 = index_1 + dims.x;
}
// Read and write the output
float x1 = static_cast<float>(in[index_1]);
float x2 = static_cast<float>(in[index_2]);
float rx1;
float rx2;
if (forward) {
rx1 = x1 * costheta - x2 * sintheta;
rx2 = x1 * sintheta + x2 * costheta;
} else {
rx1 = x2 * sintheta + x1 * costheta;
rx2 = x2 * costheta - x1 * sintheta;
}
out[index_1] = static_cast<T>(rx1);
out[index_2] = static_cast<T>(rx2);
}
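In math terms, rope_single_impl rotates each pair (x_1, x_2) by an angle theta; the forward branch applies the rotation R(theta):

\[
\begin{pmatrix} x_1' \\ x_2' \end{pmatrix}
= \begin{pmatrix} \cos\theta & -\sin\theta \\ \sin\theta & \cos\theta \end{pmatrix}
\begin{pmatrix} x_1 \\ x_2 \end{pmatrix},
\qquad \theta = \mathrm{scale} \cdot \mathrm{offset} \cdot \mathrm{inv\_freq},
\]

and the backward branch applies the inverse rotation R(-theta), as the sign pattern in the else case shows.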
template <typename T, bool traditional, bool forward>
__global__ void rope_single(
const T* in,
T* out,
const int32_t* offset,
float scale,
float base,
int64_t stride,
uint2 dims) {
uint2 pos = make_uint2(
blockIdx.x * blockDim.x + threadIdx.x,
blockIdx.y * blockDim.y + threadIdx.y);
if (pos.x >= dims.x || pos.y >= dims.y) {
return;
}
float d = static_cast<float>(pos.x) / static_cast<float>(dims.x);
float inv_freq = exp2(-d * base);
rope_single_impl<T, traditional, forward>(
in, out, *offset, inv_freq, scale, stride, pos, dims);
}
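Note that the callers below pass std::log2(base_) as the base argument, so exp2(-d * base) evaluates the standard RoPE inverse frequency:

\[
\mathrm{inv\_freq}_i = 2^{-d \,\log_2 b} = b^{-2i/D},
\qquad d = \frac{i}{D/2},
\]

where i = pos.x, b = base_, and D = dims_ (the launch uses dims.x = D/2, one thread per rotated pair).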
template <typename T, bool traditional, bool forward>
__global__ void rope_single_freqs(
const T* in,
T* out,
const int32_t* offset,
const float* freqs,
float scale,
int64_t stride,
uint2 dims,
int64_t freq_stride) {
uint2 pos = make_uint2(
blockIdx.x * blockDim.x + threadIdx.x,
blockIdx.y * blockDim.y + threadIdx.y);
if (pos.x >= dims.x || pos.y >= dims.y) {
return;
}
float inv_freq = 1.0 / freqs[freq_stride * pos.x];
rope_single_impl<T, traditional, forward>(
in, out, *offset, inv_freq, scale, stride, pos, dims);
}
template <typename T, bool traditional, bool forward, int N = 4>
__device__ void rope_impl(
const T* in,
T* out,
int offset,
float inv_freq,
float scale,
const cuda::std::array<int64_t, 3> strides,
const cuda::std::array<int64_t, 3> out_strides,
int64_t n_batch,
uint3 pos,
uint3 dims) {
float L = scale * static_cast<float>(pos.y + offset);
// Compute costheta, sintheta
float theta = L * inv_freq;
float costheta = cos(theta);
float sintheta = sin(theta);
// Compute the input and output indices
size_t in_index_1, in_index_2;
size_t out_index_1, out_index_2;
if (traditional) {
out_index_1 = 2 * pos.x * out_strides[2] + pos.y * out_strides[1] +
N * pos.z * out_strides[0];
out_index_2 = out_index_1 + 1;
in_index_1 =
2 * pos.x * strides[2] + pos.y * strides[1] + N * pos.z * strides[0];
in_index_2 = in_index_1 + strides[2];
} else {
out_index_1 = pos.x * out_strides[2] + pos.y * out_strides[1] +
N * pos.z * out_strides[0];
out_index_2 = out_index_1 + dims.x * out_strides[2];
in_index_1 =
pos.x * strides[2] + pos.y * strides[1] + N * pos.z * strides[0];
in_index_2 = in_index_1 + dims.x * strides[2];
}
for (int i = 0; i < N && pos.z * N + i < n_batch; ++i) {
// Read and write the output
float x1 = static_cast<float>(in[in_index_1]);
float x2 = static_cast<float>(in[in_index_2]);
float rx1;
float rx2;
if (forward) {
rx1 = x1 * costheta - x2 * sintheta;
rx2 = x1 * sintheta + x2 * costheta;
} else {
rx1 = x2 * sintheta + x1 * costheta;
rx2 = x2 * costheta - x1 * sintheta;
}
out[out_index_1] = static_cast<T>(rx1);
out[out_index_2] = static_cast<T>(rx2);
in_index_1 += strides[0];
in_index_2 += strides[0];
out_index_1 += out_strides[0];
out_index_2 += out_strides[0];
}
}
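Each thread in rope_impl covers up to N = 4 consecutive batch entries, stepping by strides[0] per iteration; the pos.z * N + i < n_batch guard handles the tail. This is why the launch code later in this file shrinks the z extent with dims.z = (dims.z + 3) / 4.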
template <typename T, bool traditional, bool forward>
__global__ void rope(
const T* in,
T* out,
const int32_t* offset,
float scale,
float base,
const __grid_constant__ cuda::std::array<int64_t, 3> strides,
const __grid_constant__ cuda::std::array<int64_t, 3> out_strides,
int64_t n_batch,
uint3 dims) {
uint3 pos = make_uint3(
blockIdx.x * blockDim.x + threadIdx.x,
blockIdx.y * blockDim.y + threadIdx.y,
blockIdx.z * blockDim.z + threadIdx.z);
if (pos.x >= dims.x || pos.y >= dims.y || pos.z >= dims.z) {
return;
}
float d = static_cast<float>(pos.x) / static_cast<float>(dims.x);
float inv_freq = exp2(-d * base);
rope_impl<T, traditional, forward>(
in,
out,
*offset,
inv_freq,
scale,
strides,
out_strides,
n_batch,
pos,
dims);
}
template <typename T, bool traditional, bool forward>
__global__ void rope_freqs(
const T* in,
T* out,
const int32_t* offset,
const float* freqs,
float scale,
float base,
const __grid_constant__ cuda::std::array<int64_t, 3> strides,
const __grid_constant__ cuda::std::array<int64_t, 3> out_strides,
int64_t n_batch,
uint3 dims,
int64_t freq_stride) {
uint3 pos = make_uint3(
blockIdx.x * blockDim.x + threadIdx.x,
blockIdx.y * blockDim.y + threadIdx.y,
blockIdx.z * blockDim.z + threadIdx.z);
if (pos.x >= dims.x || pos.y >= dims.y || pos.z >= dims.z) {
return;
}
float inv_freq = 1.0 / freqs[freq_stride * pos.x];
rope_impl<T, traditional, forward>(
in,
out,
*offset,
inv_freq,
scale,
strides,
out_strides,
n_batch,
pos,
dims);
}
} // namespace cu
namespace fast {
bool RoPE::use_fallback(Stream s) {
return s.device == Device::cpu;
}
void RoPE::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
nvtx3::scoped_range r("RoPE::eval_gpu");
auto& s = stream();
auto& in = inputs[0];
auto& offset = inputs[1];
auto& out = outputs[0];
if (in.ndim() < 3) {
throw std::runtime_error("[RoPE] Input must have at least 3 dimensions");
}
cuda::std::array<int64_t, 3> strides;
cuda::std::array<int64_t, 3> out_strides;
bool donated = false;
int ndim = in.ndim();
int dispatch_ndim = in.ndim();
while (in.shape(-dispatch_ndim) == 1 && dispatch_ndim > 3) {
dispatch_ndim--;
}
size_t mat_size = in.shape(-2) * in.shape(-1);
// We apply rope to less than the whole vector so copy to output and then
// apply in-place.
if (dims_ < in.shape(-1)) {
donated = true;
auto ctype =
(in.flags().row_contiguous) ? CopyType::Vector : CopyType::General;
copy_gpu(in, out, ctype, s);
strides[0] = mat_size;
strides[1] = out.strides()[ndim - 2];
strides[2] = out.strides()[ndim - 1];
}
// Either copy or apply in-place
else if (in.flags().row_contiguous) {
if (in.is_donatable()) {
donated = true;
out.copy_shared_buffer(in);
} else {
out.set_data(allocator::malloc(out.nbytes()));
}
strides[0] = mat_size;
strides[1] = in.strides()[ndim - 2];
strides[2] = in.strides()[ndim - 1];
} else if (dispatch_ndim == 3) {
// Handle non-contiguous 3D inputs
out.set_data(allocator::malloc(out.nbytes()));
strides[0] = in.strides()[ndim - 3];
strides[1] = in.strides()[ndim - 2];
strides[2] = in.strides()[ndim - 1];
} else {
// Copy non-contiguous > 3D inputs into the output and treat
// input as donated
donated = true;
copy_gpu(in, out, CopyType::General, s);
strides[0] = mat_size;
strides[1] = out.strides()[ndim - 2];
strides[2] = out.strides()[ndim - 1];
}
out_strides[0] = mat_size;
out_strides[1] = out.strides()[ndim - 2];
out_strides[2] = out.strides()[ndim - 1];
// Some flags to help us dispatch below
bool single = in.flags().row_contiguous && (mat_size == in.shape(-1));
bool with_freqs = inputs.size() == 3;
auto& encoder = cu::get_command_encoder(s);
encoder.set_input_array(donated ? out : in);
encoder.set_input_array(offset);
encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_FLOAT_TYPES_CHECKED(in.dtype(), "rope", CTYPE, {
using DataType = cuda_type_t<CTYPE>;
MLX_SWITCH_BOOL(traditional_, TRADITIONAL, {
MLX_SWITCH_BOOL(forward_, FORWARD, {
if (single && !with_freqs) {
auto kernel = cu::rope_single<DataType, TRADITIONAL, FORWARD>;
uint2 dims = make_uint2(dims_ / 2, in.size() / mat_size);
auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1);
kernel<<<grid, block, 0, stream>>>(
(donated ? out : in).data<DataType>(),
out.data<DataType>(),
offset.data<int32_t>(),
scale_,
std::log2(base_),
mat_size,
dims);
} else if (single) {
auto kernel = cu::rope_single_freqs<DataType, TRADITIONAL, FORWARD>;
uint2 dims = make_uint2(dims_ / 2, in.size() / mat_size);
auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1);
kernel<<<grid, block, 0, stream>>>(
(donated ? out : in).data<DataType>(),
out.data<DataType>(),
offset.data<int32_t>(),
inputs[2].data<float>(),
scale_,
mat_size,
dims,
inputs[2].strides(0));
} else if (with_freqs) {
auto kernel = cu::rope_freqs<DataType, TRADITIONAL, FORWARD>;
uint3 dims =
make_uint3(dims_ / 2, in.shape(-2), in.size() / mat_size);
dims.z = (dims.z + 3) / 4;
auto [grid, block] = get_grid_and_block(dims.x, dims.y, dims.z);
kernel<<<grid, block, 0, stream>>>(
(donated ? out : in).data<DataType>(),
out.data<DataType>(),
offset.data<int32_t>(),
inputs[2].data<float>(),
scale_,
std::log2(base_),
strides,
out_strides,
in.size() / mat_size,
dims,
inputs[2].strides(0));
} else {
auto kernel = cu::rope<DataType, TRADITIONAL, FORWARD>;
uint3 dims =
make_uint3(dims_ / 2, in.shape(-2), in.size() / mat_size);
dims.z = (dims.z + 3) / 4;
auto [grid, block] = get_grid_and_block(dims.x, dims.y, dims.z);
kernel<<<grid, block, 0, stream>>>(
(donated ? out : in).data<DataType>(),
out.data<DataType>(),
offset.data<int32_t>(),
scale_,
std::log2(base_),
strides,
out_strides,
in.size() / mat_size,
dims);
}
});
});
});
});
}
} // namespace fast
} // namespace mlx::core

View File

@@ -102,6 +102,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cpp
${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp
${CMAKE_CURRENT_SOURCE_DIR}/paged_attention.cpp
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
${CMAKE_CURRENT_SOURCE_DIR}/normalization.cpp

View File

@@ -241,6 +241,13 @@ MTL::ComputePipelineState* get_gather_qmm_kernel(
int wn,
bool transpose);
MTL::ComputePipelineState* get_paged_attention_kernel(
metal::Device& d,
const std::string& kernel_name,
const std::string& hash_name,
const metal::MTLFCList& func_consts,
const std::string&);
// Create a GPU kernel template definition for JIT compilation
template <typename... Args>
std::string

View File

@@ -109,6 +109,7 @@ if(NOT MLX_METAL_JIT)
reduction/reduce_row.h)
build_kernel(quantized quantized.h ${STEEL_HEADERS})
build_kernel(scan scan.h)
build_kernel(paged_attention paged_attention.h)
build_kernel(softmax softmax.h)
build_kernel(logsumexp logsumexp.h)
build_kernel(sort sort.h)

File diff suppressed because it is too large

View File

@@ -0,0 +1,131 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/backend/metal/kernels/paged_attention.h"
#include "mlx/backend/metal/kernels/utils.h"
#define instantiate_paged_attention_inner( \
type, head_size, block_size, num_threads, num_simd_lanes, partition_size) \
template \
[[host_name("paged_attention_" #type "_hs" #head_size "_bs" #block_size \
"_nt" #num_threads "_nsl" #num_simd_lanes \
"_ps" #partition_size)]] [[kernel]] void \
paged_attention< \
type, \
head_size, \
block_size, \
num_threads, \
num_simd_lanes, \
partition_size>( \
device float* exp_sums \
[[buffer(0), function_constant(use_partitioning)]], \
device float* max_logits \
[[buffer(1), function_constant(use_partitioning)]], \
device type* out [[buffer(2)]], \
device const type* q [[buffer(3)]], \
device const type* k_cache [[buffer(4)]], \
device const type* v_cache [[buffer(5)]], \
const constant int& num_kv_heads [[buffer(6)]], \
const constant float& scale [[buffer(7)]], \
const constant float& softcapping [[buffer(8)]], \
device const uint32_t* block_tables [[buffer(9)]], \
device const uint32_t* context_lens [[buffer(10)]], \
const constant int& max_num_blocks_per_seq [[buffer(11)]], \
device const float* alibi_slopes \
[[buffer(12), function_constant(use_alibi)]], \
const constant int& q_stride [[buffer(13)]], \
const constant int& kv_block_stride [[buffer(14)]], \
const constant int& kv_head_stride [[buffer(15)]], \
threadgroup char* shared_mem [[threadgroup(0)]], \
uint3 threadgroup_position_in_grid [[threadgroup_position_in_grid]], \
uint3 threadgroups_per_grid [[threadgroups_per_grid]], \
uint3 thread_position_in_threadgroup \
[[thread_position_in_threadgroup]], \
uint simd_tid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);
#define instantiate_paged_attention_v2_reduce_inner( \
type, head_size, num_threads, num_simd_lanes, partition_size) \
template [[host_name("paged_attention_v2_reduce_" #type "_hs" #head_size \
"_nt" #num_threads "_nsl" #num_simd_lanes \
"_ps" #partition_size)]] [[kernel]] void \
paged_attention_v2_reduce< \
type, \
head_size, \
num_threads, \
num_simd_lanes, \
partition_size>( \
device type * out [[buffer(0)]], \
const device float* exp_sums [[buffer(1)]], \
const device float* max_logits [[buffer(2)]], \
const device type* tmp_out [[buffer(3)]], \
device uint32_t* context_lens [[buffer(4)]], \
const constant int& max_num_partitions [[buffer(5)]], \
threadgroup char* shared_mem [[threadgroup(0)]], \
uint3 threadgroup_position_in_grid [[threadgroup_position_in_grid]], \
uint3 threadgroups_per_grid [[threadgroups_per_grid]], \
uint3 thread_position_in_threadgroup [[thread_position_in_threadgroup]], \
uint3 threads_per_threadgroup [[threads_per_threadgroup]], \
uint simd_tid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);
#define instantiate_paged_attention_heads( \
type, block_size, num_threads, num_simd_lanes, partition_size) \
instantiate_paged_attention_inner( \
type, 64, block_size, num_threads, num_simd_lanes, partition_size); \
instantiate_paged_attention_inner( \
type, 80, block_size, num_threads, num_simd_lanes, partition_size); \
instantiate_paged_attention_inner( \
type, 96, block_size, num_threads, num_simd_lanes, partition_size); \
instantiate_paged_attention_inner( \
type, 112, block_size, num_threads, num_simd_lanes, partition_size); \
instantiate_paged_attention_inner( \
type, 128, block_size, num_threads, num_simd_lanes, partition_size); \
instantiate_paged_attention_inner( \
type, 192, block_size, num_threads, num_simd_lanes, partition_size); \
instantiate_paged_attention_inner( \
type, 256, block_size, num_threads, num_simd_lanes, partition_size);
#define instantiate_paged_attention_v2_reduce_heads( \
type, num_threads, num_simd_lanes, partition_size) \
instantiate_paged_attention_v2_reduce_inner( \
type, 64, num_threads, num_simd_lanes, partition_size); \
instantiate_paged_attention_v2_reduce_inner( \
type, 80, num_threads, num_simd_lanes, partition_size); \
instantiate_paged_attention_v2_reduce_inner( \
type, 96, num_threads, num_simd_lanes, partition_size); \
instantiate_paged_attention_v2_reduce_inner( \
type, 112, num_threads, num_simd_lanes, partition_size); \
instantiate_paged_attention_v2_reduce_inner( \
type, 128, num_threads, num_simd_lanes, partition_size); \
instantiate_paged_attention_v2_reduce_inner( \
type, 192, num_threads, num_simd_lanes, partition_size); \
instantiate_paged_attention_v2_reduce_inner( \
type, 256, num_threads, num_simd_lanes, partition_size);
#define instantiate_paged_attention_block_size( \
type, num_threads, num_simd_lanes, partition_size) \
instantiate_paged_attention_heads( \
type, 8, num_threads, num_simd_lanes, partition_size); \
instantiate_paged_attention_heads( \
type, 16, num_threads, num_simd_lanes, partition_size); \
instantiate_paged_attention_heads( \
type, 32, num_threads, num_simd_lanes, partition_size);
// TODO: tune num_threads = 256
// NOTE: partition_size = 0
#define instantiate_paged_attention_v1(type, num_simd_lanes) \
instantiate_paged_attention_block_size(type, 256, num_simd_lanes, 0);
// TODO: tune num_threads = 256
// NOTE: partition_size = 512
#define instantiate_paged_attention_v2(type, num_simd_lanes) \
instantiate_paged_attention_block_size(type, 256, num_simd_lanes, 512); \
instantiate_paged_attention_v2_reduce_heads(type, 256, num_simd_lanes, 512);
instantiate_paged_attention_v1(float, 32);
instantiate_paged_attention_v1(bfloat16_t, 32);
instantiate_paged_attention_v1(half, 32);
instantiate_paged_attention_v2(float, 32);
instantiate_paged_attention_v2(bfloat16_t, 32);
instantiate_paged_attention_v2(half, 32);
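The host side (run_paged_attention later in this diff) looks these kernels up by rebuilding the same host_name string. A minimal sketch of the naming scheme, with illustrative parameter values:

// Sketch of the kernel-name scheme produced by the host_name macros above
// and reassembled by the host for lookup; values below are illustrative.
#include <iostream>
#include <sstream>
#include <string>

std::string paged_attention_kname(
    const std::string& type,
    int head_size,
    int block_size,
    int num_threads,
    int num_simd_lanes,
    int partition_size) {
  std::ostringstream os;
  os << "paged_attention_" << type << "_hs" << head_size << "_bs" << block_size
     << "_nt" << num_threads << "_nsl" << num_simd_lanes
     << "_ps" << partition_size;
  return os.str();
}

int main() {
  // v1 instantiations use partition_size 0, v2 uses 512 (see macros above).
  std::cout << paged_attention_kname("float", 128, 16, 256, 32, 0) << "\n";
  std::cout << paged_attention_kname("float", 128, 16, 256, 32, 512) << "\n";
  return 0;
}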

View File

@@ -288,4 +288,13 @@ MTL::ComputePipelineState* get_gather_qmm_kernel(
return d.get_kernel(kernel_name, hash_name, func_consts);
}
MTL::ComputePipelineState* get_paged_attention_kernel(
metal::Device& d,
const std::string& kernel_name,
const std::string& hash_name,
const metal::MTLFCList& func_consts,
const std::string&) {
return d.get_kernel(kernel_name, "mlx", hash_name, func_consts);
}
} // namespace mlx::core

View File

@@ -0,0 +1,324 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/paged_attention_primitives.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
namespace mlx::core::paged_attention {
static void run_paged_attention(
const array& q,
const array& k_cache,
const array& v_cache,
const array& block_tables,
const array& context_lens,
const int head_size,
const int block_size,
const int num_kv_heads,
const float scale,
const float softcapping,
const int max_context_len,
const int max_num_blocks_per_seq,
const bool use_partitioning,
const std::optional<array> alibi,
const int q_stride,
const int kv_block_stride,
const int kv_head_stride,
const int num_heads,
const int num_seqs,
array& out,
metal::Device& d,
const Stream& s) {
const int partition_size = use_partitioning ? 512 : 0;
const int num_threads = 256;
const int num_simd_lanes = 32;
const bool use_alibi = alibi.has_value();
std::string type_string = get_type_string(q.dtype());
std::string kname;
kname.reserve(64);
concatenate(
kname,
"paged_attention_",
type_string,
"_hs",
head_size,
"_bs",
block_size,
"_nt",
num_threads,
"_nsl",
num_simd_lanes,
"_ps",
partition_size);
auto template_def = get_template_definition(
kname,
"paged_attention",
type_string,
head_size,
block_size,
num_threads,
num_simd_lanes,
partition_size);
// Encode and dispatch kernel
metal::MTLFCList func_consts = {
{use_partitioning, MTL::DataType::DataTypeBool, 10},
{use_alibi, MTL::DataType::DataTypeBool, 20},
};
std::string hash_name = kname;
auto kernel = get_paged_attention_kernel(
d, kname, hash_name, func_consts, template_def);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_compute_pipeline_state(kernel);
int local_max_num_partitions = 1;
if (use_partitioning) {
local_max_num_partitions =
(max_context_len + partition_size - 1) / partition_size;
}
int logits_size = use_partitioning ? partition_size * size_of(float32) : 0;
int outputs_size = use_partitioning
? ((num_threads / num_simd_lanes) / 2) * head_size * size_of(float32)
: 0;
int shared_mem_size =
use_partitioning ? std::max(logits_size, outputs_size) : 0;
if (use_partitioning) {
compute_encoder.set_threadgroup_memory_length(shared_mem_size, 0);
}
if (use_partitioning) {
auto tmp_out = array(
{num_seqs, num_heads, local_max_num_partitions, head_size}, float32);
tmp_out.set_data(allocator::malloc(tmp_out.nbytes()));
auto exp_sums =
array({num_seqs, num_heads, local_max_num_partitions}, float32);
exp_sums.set_data(allocator::malloc(exp_sums.nbytes()));
std::vector<array> temporaries = {tmp_out, exp_sums};
compute_encoder.set_output_array(tmp_out, 0);
compute_encoder.set_output_array(exp_sums, 1);
compute_encoder.set_output_array(out, 2);
compute_encoder.set_input_array(q, 3);
compute_encoder.set_input_array(k_cache, 4);
compute_encoder.set_input_array(v_cache, 5);
compute_encoder.set_bytes(num_kv_heads, 6);
compute_encoder.set_bytes(scale, 7);
compute_encoder.set_bytes(softcapping, 8);
compute_encoder.set_input_array(block_tables, 9);
compute_encoder.set_input_array(context_lens, 10);
compute_encoder.set_bytes(max_num_blocks_per_seq, 11);
if (use_alibi) {
compute_encoder.set_input_array(alibi.value(), 12);
}
compute_encoder.set_bytes(q_stride, 13);
compute_encoder.set_bytes(kv_block_stride, 14);
compute_encoder.set_bytes(kv_head_stride, 15);
MTL::Size grid_dims(num_heads, num_seqs, local_max_num_partitions);
MTL::Size group_dims(num_threads, 1, 1);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
d.add_temporaries(std::move(temporaries), s.index);
} else {
compute_encoder.set_output_array(out, 2);
compute_encoder.set_input_array(q, 3);
compute_encoder.set_input_array(k_cache, 4);
compute_encoder.set_input_array(v_cache, 5);
compute_encoder.set_bytes(num_kv_heads, 6);
compute_encoder.set_bytes(scale, 7);
compute_encoder.set_bytes(softcapping, 8);
compute_encoder.set_input_array(block_tables, 9);
compute_encoder.set_input_array(context_lens, 10);
compute_encoder.set_bytes(max_num_blocks_per_seq, 11);
if (use_alibi) {
compute_encoder.set_input_array(alibi.value(), 12);
}
compute_encoder.set_bytes(q_stride, 13);
compute_encoder.set_bytes(kv_block_stride, 14);
compute_encoder.set_bytes(kv_head_stride, 15);
MTL::Size grid_dims(num_heads, num_seqs, 1);
MTL::Size group_dims(num_threads, 1, 1);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
}
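For the partitioned (v2) path, the threadgroup memory requested above works out to a small constant. With num_threads = 256, num_simd_lanes = 32, partition_size = 512, and an illustrative head_size = 128:

\[
\mathrm{logits} = 512 \cdot 4 = 2048\,\mathrm{B}, \quad
\mathrm{outputs} = \tfrac{256/32}{2} \cdot 128 \cdot 4 = 2048\,\mathrm{B}, \quad
\mathrm{shared} = \max(2048, 2048) = 2048\,\mathrm{B}.
\]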
void paged_attention_v1(
const array& q,
const array& k_cache,
const array& v_cache,
const array& block_tables,
const array& context_lens,
const int head_size,
const int block_size,
const int num_kv_heads,
const float scale,
const float softcapping,
const int max_context_len,
const int max_num_blocks_per_seq,
const std::optional<array> alibi,
const int q_stride,
const int kv_block_stride,
const int kv_head_stride,
const int num_heads,
const int num_seqs,
array& out,
metal::Device& d,
const Stream& s) {
run_paged_attention(
q,
k_cache,
v_cache,
block_tables,
context_lens,
head_size,
block_size,
num_kv_heads,
scale,
softcapping,
max_context_len,
max_num_blocks_per_seq,
/*use_partitioning=*/false,
alibi,
q_stride,
kv_block_stride,
kv_head_stride,
num_heads,
num_seqs,
out,
d,
s);
}
void paged_attention_v2(
const array& q,
const array& k_cache,
const array& v_cache,
const array& block_tables,
const array& context_lens,
const int head_size,
const int block_size,
const int num_kv_heads,
const float scale,
const float softcapping,
const int max_context_len,
const int max_num_blocks_per_seq,
const int /* max_num_partitions */,
const std::optional<array> alibi,
const int q_stride,
const int kv_block_stride,
const int kv_head_stride,
const int num_heads,
const int num_seqs,
array& out,
metal::Device& d,
const Stream& s) {
run_paged_attention(
q,
k_cache,
v_cache,
block_tables,
context_lens,
head_size,
block_size,
num_kv_heads,
scale,
softcapping,
max_context_len,
max_num_blocks_per_seq,
/*use_partitioning=*/true,
alibi,
q_stride,
kv_block_stride,
kv_head_stride,
num_heads,
num_seqs,
out,
d,
s);
}
void PagedAttention::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& s = stream();
auto& d = metal::device(s.device);
out.set_data(allocator::malloc(out.nbytes()));
auto& q = inputs[0];
auto& k_cache = inputs[1];
auto& v_cache = inputs[2];
auto& block_tables = inputs[3];
auto& context_lens = inputs[4];
const auto alibi_slopes =
inputs.size() == 6 ? std::optional{inputs[5]} : std::nullopt;
if (use_v1_) {
paged_attention_v1(
q,
k_cache,
v_cache,
block_tables,
context_lens,
head_size_,
block_size_,
num_kv_heads_,
softmax_scale_,
softcapping_.value_or(1.),
max_context_len_,
max_num_blocks_per_seq_,
alibi_slopes,
q_stride_,
kv_block_stride_,
kv_head_stride_,
num_heads_,
num_seqs_,
out,
d,
s);
} else {
paged_attention_v2(
q,
k_cache,
v_cache,
block_tables,
context_lens,
head_size_,
block_size_,
num_kv_heads_,
softmax_scale_,
softcapping_.value_or(1.),
max_context_len_,
max_num_blocks_per_seq_,
max_num_partitions_,
alibi_slopes,
q_stride_,
kv_block_stride_,
kv_head_stride_,
num_heads_,
num_seqs_,
out,
d,
s);
}
}
} // namespace mlx::core::paged_attention

View File

@@ -17,6 +17,7 @@
#include "mlx/linalg.h"
#include "mlx/memory.h"
#include "mlx/ops.h"
#include "mlx/paged_attention.h"
#include "mlx/random.h"
#include "mlx/stream.h"
#include "mlx/transforms.h"

mlx/paged_attention.cpp Normal file
View File

@@ -0,0 +1,170 @@
// Copyright © 2023-2024 Apple Inc.
// Required for using M_PI in MSVC.
#define _USE_MATH_DEFINES
#include <algorithm>
#include <climits>
#include <cmath>
#include <numeric>
#include <set>
#include <sstream>
#include <stdexcept>
#include "mlx/paged_attention_primitives.h"
#include "mlx/utils.h"
namespace mlx::core::paged_attention {
array paged_attention(
const array& q,
const array& k_cache,
const array& v_cache,
const array& block_tables,
const array& context_lens,
int max_context_len,
float softmax_scale,
std::optional<array> alibi_slopes = std::nullopt,
std::optional<float> softcapping = std::nullopt,
StreamOrDevice s_ = {}) {
auto s = to_stream(s_);
// supported dtypes
if (!issubdtype(q.dtype(), floating)) {
throw std::invalid_argument(
"[paged_attention] Only real floating types are supported");
}
if (!(q.dtype() == k_cache.dtype() && k_cache.dtype() == v_cache.dtype())) {
throw std::invalid_argument(
"[paged_attention] q/k_cache/v_cache dtype must match");
}
if (!(block_tables.dtype() == uint32 && context_lens.dtype() == uint32)) {
throw std::invalid_argument(
"[paged_attention] block_tables/context_lens dtype must be uint32");
}
// rank checks
if (q.ndim() != 3)
throw std::invalid_argument("[paged_attention] `q` must be rank-3");
if (k_cache.ndim() != 5)
throw std::invalid_argument("[paged_attention] `k_cache` must be rank-5");
if (v_cache.ndim() != 4)
throw std::invalid_argument("[paged_attention] `v_cache` must be rank-4");
if (block_tables.ndim() != 2)
throw std::invalid_argument(
"[paged_attention] `block_tables` must be rank-2");
if (context_lens.ndim() != 1)
throw std::invalid_argument(
"[paged_attention] `context_lens` must be rank-1");
// Shape consistency
const auto& q_shape = q.shape(); // [num_seqs, num_heads, head_size]
const auto& kc_shape = k_cache.shape();
const auto& vc_shape = v_cache.shape();
const auto& bt_shape = block_tables.shape();
const auto& cl_shape = context_lens.shape();
int num_seqs = q_shape[0];
int num_heads = q_shape[1];
int head_size = q_shape[2];
// Allowed head sizes
switch (head_size) {
case 64:
case 80:
case 96:
case 112:
case 128:
case 192:
case 256:
break;
default:
throw std::invalid_argument(
"[paged_attention] `head_size` must be one of "
"{64, 80, 96, 112, 128, 192, 256}");
}
int max_num_blocks_per_seq = bt_shape[1];
// block_tables first dimension must match num_seqs
if (bt_shape[0] != num_seqs) {
std::stringstream ss;
ss << "[paged_attention] block_tables.shape[0] (" << bt_shape[0]
<< ") must equal q.shape[0] (" << num_seqs << ")";
throw std::invalid_argument(ss.str());
}
// Extract k_cache dimensions
int num_blocks = kc_shape[0];
int num_kv_heads = kc_shape[1];
int head_size_kc = kc_shape[2];
int block_size = kc_shape[3];
int x = kc_shape[4];
if (head_size_kc * x != head_size) {
std::stringstream ss;
ss << "[paged_attention] k_cache head_size (" << head_size_kc << " * " << x
<< ") must equal q head_size (" << head_size << ")";
throw std::invalid_argument(ss.str());
}
// v_cache must match the derived dimensions
if (!(vc_shape[0] == num_blocks && vc_shape[1] == num_kv_heads &&
vc_shape[2] == head_size && vc_shape[3] == block_size)) {
throw std::invalid_argument(
"[paged_attention] `v_cache` shape mismatch with `k_cache`/`q`");
}
// context_lens length must match num_seqs
if (cl_shape[0] != num_seqs) {
std::stringstream ss;
ss << "paged_attention: context_lens length (" << cl_shape[0]
<< ") must equal q.shape[0] (" << num_seqs << ")";
throw std::invalid_argument(ss.str());
}
constexpr int partition_size = 512;
int max_num_partitions =
(max_context_len + partition_size - 1) / partition_size; // ceildiv
bool use_v1 = ((max_num_partitions == 1) || (num_seqs * num_heads > 512)) &&
(partition_size % block_size == 0);
auto out_shape = q.shape();
auto inputs = std::vector{
std::move(q),
std::move(k_cache),
std::move(v_cache),
std::move(block_tables),
std::move(context_lens)};
if (alibi_slopes.has_value()) {
inputs.push_back(std::move(alibi_slopes.value()));
}
int q_stride = q.strides()[0];
int kv_block_stride = k_cache.strides()[0];
int kv_head_stride = k_cache.strides()[1];
return array(
std::move(out_shape),
q.dtype(),
std::make_shared<PagedAttention>(
to_stream(s),
use_v1,
max_context_len,
head_size,
block_size,
num_kv_heads,
max_num_blocks_per_seq,
max_num_partitions,
q_stride,
kv_block_stride,
kv_head_stride,
num_heads,
num_seqs,
softmax_scale,
softcapping),
inputs);
}
} // namespace mlx::core::paged_attention
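A worked example of the v1/v2 selection above (illustrative numbers): with max_context_len = 1024 and partition_size = 512, max_num_partitions = ceil(1024 / 512) = 2, so use_v1 is true only when num_seqs * num_heads > 512 already saturates the GPU without partitioning; in every case v1 additionally requires that block_size evenly divide the 512-wide partition.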

mlx/paged_attention.h Normal file
View File

@@ -0,0 +1,34 @@
// Copyright © 2023-2024 Apple Inc.
#pragma once
#include <optional>
#include "mlx/array.h"
#include "mlx/device.h"
#include "mlx/stream.h"
#include "mlx/utils.h"
namespace mlx::core::paged_attention {
/**
* \defgroup ops Paged attention operations
* @{
*/
/** PagedAttention operation. */
array paged_attention(
const array& q,
const array& k_cache,
const array& v_cache,
const array& block_tables,
const array& context_lens,
int max_context_len,
float softmax_scale,
std::optional<array> alibi_slopes = std::nullopt,
std::optional<float> softcapping = std::nullopt,
StreamOrDevice s_ = {});
/** @} */
} // namespace mlx::core::paged_attention
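A hypothetical end-to-end call, with shapes chosen to satisfy the checks in mlx/paged_attention.cpp (all dimensions illustrative, not part of this diff):

// Hypothetical usage sketch for paged_attention; shapes follow the
// validation in mlx/paged_attention.cpp. All sizes are illustrative.
#include <cmath>
#include "mlx/mlx.h"

using namespace mlx::core;

int main() {
  int num_seqs = 2, num_heads = 8, num_kv_heads = 8;
  int head_size = 128, block_size = 16, x = 8;
  int num_blocks = 64, max_blocks_per_seq = 4;

  auto q = random::normal({num_seqs, num_heads, head_size});
  auto k_cache = random::normal(
      {num_blocks, num_kv_heads, head_size / x, block_size, x});
  auto v_cache = random::normal(
      {num_blocks, num_kv_heads, head_size, block_size});
  auto block_tables = zeros({num_seqs, max_blocks_per_seq}, uint32);
  auto context_lens = full({num_seqs}, 32, uint32);

  auto out = paged_attention::paged_attention(
      q, k_cache, v_cache, block_tables, context_lens,
      /*max_context_len=*/64,
      /*softmax_scale=*/1.0f / std::sqrt(float(head_size)));
  eval(out);
  return 0;
}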

View File

@@ -0,0 +1,82 @@
// Copyright © 2023-2024 Apple Inc.
// Required for using M_PI in MSVC.
#define _USE_MATH_DEFINES
#include <optional>
#include "mlx/primitives.h"
namespace mlx::core::paged_attention {
class PagedAttention : public UnaryPrimitive {
public:
explicit PagedAttention(
Stream stream,
bool use_v1,
int max_context_len,
int head_size,
int block_size,
int num_kv_heads,
int max_num_blocks_per_seq,
int max_num_partitions,
int q_stride,
int kv_block_stride,
int kv_head_stride,
int num_heads,
int num_seqs,
float softmax_scale,
std::optional<float> softcapping = std::nullopt)
: UnaryPrimitive(stream),
use_v1_(use_v1),
max_context_len_(max_context_len),
head_size_(head_size),
block_size_(block_size),
num_kv_heads_(num_kv_heads),
max_num_blocks_per_seq_(max_num_blocks_per_seq),
max_num_partitions_(max_num_partitions),
q_stride_(q_stride),
kv_block_stride_(kv_block_stride),
kv_head_stride_(kv_head_stride),
num_heads_(num_heads),
num_seqs_(num_seqs),
softmax_scale_(softmax_scale),
softcapping_(softcapping) {}
void eval_cpu(const std::vector<array>& inputs, array& outputs) override {
throw std::runtime_error("NYI");
}
void eval_gpu(const std::vector<array>& inputs, array& outputs) override;
DEFINE_PRINT(PagedAttention);
bool is_equivalent(const Primitive& other) const override;
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
auto state() const {
return std::make_tuple(
max_context_len_,
head_size_,
block_size_,
softmax_scale_,
softcapping_);
}
private:
bool use_v1_;
int max_context_len_;
int head_size_;
int block_size_;
int num_kv_heads_;
int max_num_blocks_per_seq_;
int max_num_partitions_;
int q_stride_;
int kv_block_stride_;
int kv_head_stride_;
int num_heads_;
int num_seqs_;
float softmax_scale_;
std::optional<float> softcapping_ = std::nullopt;
};
} // namespace mlx::core::paged_attention

python/tests/__main__.py Normal file
View File

@@ -0,0 +1,5 @@
from . import mlx_tests
__unittest = True
mlx_tests.MLXTestRunner(module=None)

python/tests/cuda_skip.py Normal file
View File

@@ -0,0 +1,143 @@
cuda_skip = {
"TestArray.test_api",
"TestArray.test_setitem",
"TestAutograd.test_cumprod_grad",
"TestAutograd.test_slice_grads",
"TestAutograd.test_split_against_slice",
"TestAutograd.test_stop_gradient",
"TestAutograd.test_topk_grad",
"TestAutograd.test_update_state",
"TestAutograd.test_vjp",
"TestBF16.test_arg_reduction_ops",
"TestBF16.test_binary_ops",
"TestBF16.test_reduction_ops",
"TestBlas.test_block_masked_matmul",
"TestBlas.test_complex_gemm",
"TestBlas.test_gather_matmul",
"TestBlas.test_gather_matmul_grad",
"TestBlas.test_matmul_batched",
"TestBlas.test_matrix_vector_attn",
"TestCompile.test_compile_dynamic_dims",
"TestCompile.test_compile_inf",
"TestCompile.test_inf_constant",
"TestConv.test_1d_conv_with_2d",
"TestConv.test_asymmetric_padding",
"TestConv.test_basic_grad_shapes",
"TestConv.test_conv2d_unaligned_channels",
"TestConv.test_conv_1d_groups_flipped",
"TestConv.test_conv_general_flip_grad",
"TestConv.test_conv_groups_grad",
"TestConv.test_numpy_conv",
"TestConv.test_repeated_conv",
"TestConv.test_torch_conv_1D",
"TestConv.test_torch_conv_1D_grad",
"TestConv.test_torch_conv_2D",
"TestConv.test_torch_conv_2D_grad",
"TestConv.test_torch_conv_3D",
"TestConv.test_torch_conv_3D_grad",
"TestConv.test_torch_conv_depthwise",
"TestConv.test_torch_conv_general",
"TestConvTranspose.test_torch_conv_tranpose_1d_output_padding",
"TestConvTranspose.test_torch_conv_transpose_1D",
"TestConvTranspose.test_torch_conv_transpose_1D_grad",
"TestConvTranspose.test_torch_conv_transpose_2D",
"TestConvTranspose.test_torch_conv_transpose_2D_grad",
"TestConvTranspose.test_torch_conv_transpose_2d_output_padding",
"TestConvTranspose.test_torch_conv_transpose_3D",
"TestConvTranspose.test_torch_conv_transpose_3D_grad",
"TestConvTranspose.test_torch_conv_transpose_3d_output_padding",
"TestEinsum.test_attention",
"TestEinsum.test_ellipses",
"TestEinsum.test_opt_einsum_test_cases",
"TestEval.test_multi_output_eval_during_transform",
"TestExportImport.test_export_conv",
"TestFast.test_rope_grad",
"TestFFT.test_fft",
"TestFFT.test_fft_big_powers_of_two",
"TestFFT.test_fft_contiguity",
"TestFFT.test_fft_exhaustive",
"TestFFT.test_fft_grads",
"TestFFT.test_fft_into_ifft",
"TestFFT.test_fft_large_numbers",
"TestFFT.test_fft_shared_mem",
"TestFFT.test_fftn",
"TestInit.test_orthogonal",
"TestLinalg.test_cholesky",
"TestLinalg.test_cholesky_inv",
"TestLinalg.test_eig",
"TestLinalg.test_eigh",
"TestLinalg.test_inverse",
"TestLinalg.test_lu",
"TestLinalg.test_lu_factor",
"TestLinalg.test_pseudo_inverse",
"TestLinalg.test_qr_factorization",
"TestLinalg.test_svd_decomposition",
"TestLinalg.test_tri_inverse",
"TestLoad.test_load_f8_e4m3",
"TestLosses.test_binary_cross_entropy",
"TestMemory.test_memory_info",
"TestLayers.test_conv1d",
"TestLayers.test_conv2d",
"TestLayers.test_elu",
"TestLayers.test_group_norm",
"TestLayers.test_hard_shrink",
"TestLayers.test_pooling",
"TestLayers.test_quantized_embedding",
"TestLayers.test_sin_pe",
"TestLayers.test_softshrink",
"TestLayers.test_upsample",
"TestOps.test_argpartition",
"TestOps.test_array_equal",
"TestOps.test_as_strided",
"TestOps.test_atleast_1d",
"TestOps.test_atleast_2d",
"TestOps.test_atleast_3d",
"TestOps.test_binary_ops",
"TestOps.test_bitwise_grad",
"TestOps.test_complex_ops",
"TestOps.test_divmod",
"TestOps.test_dynamic_slicing",
"TestOps.test_hadamard",
"TestOps.test_hadamard_grad_vmap",
"TestOps.test_irregular_binary_ops",
"TestOps.test_isfinite",
"TestOps.test_kron",
"TestOps.test_log",
"TestOps.test_log10",
"TestOps.test_log1p",
"TestOps.test_log2",
"TestOps.test_logaddexp",
"TestOps.test_logcumsumexp",
"TestOps.test_partition",
"TestOps.test_scans",
"TestOps.test_slice_update_reversed",
"TestOps.test_softmax",
"TestOps.test_sort",
"TestOps.test_tensordot",
"TestOps.test_tile",
"TestOps.test_view",
"TestQuantized.test_gather_matmul_grad",
"TestQuantized.test_gather_qmm",
"TestQuantized.test_gather_qmm_sorted",
"TestQuantized.test_non_multiples",
"TestQuantized.test_qmm",
"TestQuantized.test_qmm_jvp",
"TestQuantized.test_qmm_shapes",
"TestQuantized.test_qmm_vjp",
"TestQuantized.test_qmv",
"TestQuantized.test_quantize_dequantize",
"TestQuantized.test_qvm",
"TestQuantized.test_qvm_splitk",
"TestQuantized.test_small_matrix",
"TestQuantized.test_throw",
"TestQuantized.test_vjp_scales_biases",
"TestReduce.test_axis_permutation_sums",
"TestReduce.test_dtypes",
"TestReduce.test_expand_sums",
"TestReduce.test_many_reduction_axes",
"TestUpsample.test_torch_upsample",
"TestVmap.test_unary",
"TestVmap.test_vmap_conv",
"TestVmap.test_vmap_inverse",
"TestVmap.test_vmap_svd",
}

View File

@@ -9,6 +9,42 @@ import mlx.core as mx
import numpy as np
class MLXTestRunner(unittest.TestProgram):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def createTests(self, *args, **kwargs):
super().createTests(*args, **kwargs)
# Assume CUDA backend in this case
device = os.getenv("DEVICE", None)
if device is not None:
device = getattr(mx, device)
else:
device = mx.default_device()
if not (device == mx.gpu and not mx.metal.is_available()):
return
from cuda_skip import cuda_skip
filtered_suite = unittest.TestSuite()
def filter_and_add(t):
if isinstance(t, unittest.TestSuite):
for sub_t in t:
filter_and_add(sub_t)
else:
t_id = ".".join(t.id().split(".")[-2:])
if t_id in cuda_skip:
print(f"Skipping {t_id}")
else:
filtered_suite.addTest(t)
filter_and_add(self.test)
self.test = filtered_suite
class MLXTestCase(unittest.TestCase):
@property
def is_apple_silicon(self):

View File

@@ -130,4 +130,4 @@ class TestRingDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@@ -2033,4 +2033,4 @@ class TestArray(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@@ -799,4 +799,4 @@ class TestAutograd(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@@ -193,4 +193,4 @@ class TestBF16(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@@ -1236,4 +1236,4 @@ class TestBlas(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@@ -981,4 +981,4 @@ class TestCompile(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@@ -38,4 +38,4 @@ class TestConstants(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@@ -1188,4 +1188,4 @@ class TestConv(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@@ -807,4 +807,4 @@ class TestConvTranspose(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@@ -38,7 +38,7 @@ class TestDevice(mlx_tests.MLXTestCase):
# Restore device
mx.set_default_device(device)
@unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
@unittest.skipIf(not mx.is_available(mx.gpu), "GPU is not available")
def test_device_context(self):
default = mx.default_device()
diff = mx.cpu if default == mx.gpu else mx.gpu
@@ -114,4 +114,4 @@ class TestStream(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@@ -294,4 +294,4 @@ class TestDouble(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@@ -360,4 +360,4 @@ class TestEinsum(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@@ -172,7 +172,7 @@ class TestEval(mlx_tests.MLXTestCase):
post = mx.get_peak_memory()
self.assertEqual(pre, post)
@unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
@unittest.skipIf(not mx.is_available(mx.gpu), "GPU is not available")
def test_multistream_deadlock(self):
s1 = mx.default_stream(mx.gpu)
s2 = mx.new_stream(mx.gpu)
@@ -197,4 +197,4 @@ class TestEval(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@@ -348,4 +348,4 @@ class TestExportImport(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@@ -772,4 +772,4 @@ class TestFast(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@@ -607,7 +607,7 @@ class TestSDPA(mlx_tests.MLXTestCase):
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask)
self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
def test_sdpa_prommote_mask(self):
def test_sdpa_promote_mask(self):
mask = mx.array(2.0, mx.bfloat16)
D = 64
Nq = 4
@@ -653,4 +653,4 @@ class TestSDPA(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main(failfast=True)
mlx_tests.MLXTestRunner(failfast=True)

View File

@@ -320,4 +320,4 @@ class TestFFT(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@@ -34,4 +34,4 @@ class TestGraph(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@@ -136,4 +136,4 @@ class TestInit(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@@ -545,4 +545,4 @@ class TestLinalg(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@@ -400,4 +400,4 @@ class TestLoad(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@@ -414,4 +414,4 @@ class TestLosses(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@@ -60,4 +60,4 @@ class TestMemory(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@@ -1907,4 +1907,4 @@ class TestLayers(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@@ -3127,4 +3127,4 @@ class TestBroadcast(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@@ -527,4 +527,4 @@ class TestSchedulers(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@@ -576,4 +576,4 @@ class TestQuantized(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@@ -389,4 +389,4 @@ class TestRandom(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@@ -155,4 +155,4 @@ class TestReduce(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main(failfast=True)
mlx_tests.MLXTestRunner(failfast=True)

View File

@@ -48,4 +48,4 @@ class TestTreeUtils(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@@ -97,4 +97,4 @@ class TestUpsample(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@@ -725,4 +725,4 @@ class TestVmap(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()