mirror of https://github.com/ml-explore/mlx.git
synced 2025-06-24 01:17:26 +08:00

Compare commits: a14aaa7c9d ... 4fda5fbdf9

2 commits:
  4fda5fbdf9
  580776559b
@ -234,6 +234,7 @@ jobs:
           command: |
             source env/bin/activate
             LOW_MEMORY=1 DEVICE=cpu python -m unittest discover python/tests -v
+            LOW_MEMORY=1 DEVICE=gpu python -m tests discover python/tests -v
 
   build_release:
     parameters:
@ -209,4 +209,14 @@ Dims get_2d_grid_dims_common(
       static_cast<uint32_t>(grid_x), static_cast<uint32_t>(grid_y), 1);
 }
 
+std::pair<Dims, Dims> get_grid_and_block_common(int dim0, int dim1, int dim2) {
+  auto [bx, by, bz] = get_block_dims_common(dim0, dim1, dim2);
+  auto gx = (dim0 + bx - 1) / bx;
+  auto gy = (dim1 + by - 1) / by;
+  auto gz = (dim2 + bz - 1) / bz;
+
+  return std::make_pair(
+      std::make_tuple(gx, gy, gz), std::make_tuple(bx, by, bz));
+}
+
 } // namespace mlx::core

@ -95,6 +95,9 @@ Dims get_2d_grid_dims_common(
     const Strides& strides,
     size_t divisor);
 
+// Get both the block and a grid of blocks that covers dim0, dim1 and dim2.
+std::pair<Dims, Dims> get_grid_and_block_common(int dim0, int dim1, int dim2);
+
 struct ContiguousIterator {
   inline void step() {
     int dims = shape_.size();
@ -32,6 +32,7 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/reduce/row_reduce.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/reduce/segmented_reduce.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.cu
+  ${CMAKE_CURRENT_SOURCE_DIR}/rope.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/sort.cu

@ -23,4 +23,11 @@ dim3 get_2d_grid_dims(
   return dim3(std::get<0>(dims), std::get<1>(dims), std::get<2>(dims));
 }
 
+std::pair<dim3, dim3> get_grid_and_block(int dim0, int dim1, int dim2) {
+  auto [grid, block] = get_grid_and_block_common(dim0, dim1, dim2);
+  auto [gx, gy, gz] = grid;
+  auto [bx, by, bz] = block;
+  return std::make_pair(dim3(gx, gy, gz), dim3(bx, by, bz));
+}
+
 } // namespace mlx::core

@ -121,6 +121,7 @@ dim3 get_2d_grid_dims(
     const Shape& shape,
     const Strides& strides,
     size_t divisor);
+std::pair<dim3, dim3> get_grid_and_block(int dim0, int dim1, int dim2);
 
 // Return a block size that achieves maximum potential occupancy for kernel.
 template <typename T>
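Note: the helper added above pairs the block size from get_block_dims_common with a ceil-divided grid, so that grid x block covers (dim0, dim1, dim2). A minimal Python sketch of the same arithmetic (illustrative only; the fixed block shape is an assumption standing in for whatever get_block_dims_common returns):

def get_grid_and_block(dim0, dim1, dim2, block=(32, 8, 4)):
    # `block` stands in for get_block_dims_common's result.
    bx, by, bz = block
    gx = (dim0 + bx - 1) // bx  # ceil(dim0 / bx)
    gy = (dim1 + by - 1) // by
    gz = (dim2 + bz - 1) // bz
    return (gx, gy, gz), (bx, by, bz)

# e.g. get_grid_and_block(100, 17, 1) -> ((4, 3, 1), (32, 8, 4))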
@ -94,7 +94,6 @@ NO_GPU_MULTI(Eig)
 NO_GPU_MULTI(Eigh)
 
 namespace fast {
-NO_GPU_USE_FALLBACK(RoPE)
 NO_GPU(ScaledDotProductAttention)
 NO_GPU_MULTI(AffineQuantize)
 NO_GPU_MULTI(CustomKernel)
@ -4,6 +4,7 @@
 #include "mlx/backend/cuda/kernel_utils.cuh"
 #include "mlx/primitives.h"
 
+#include <cooperative_groups.h>
 #include <nvtx3/nvtx3.hpp>
 
 #include <cassert>

@ -12,6 +13,8 @@ namespace mlx::core {
 
 namespace cu {
 
+namespace cg = cooperative_groups;
+
 __constant__ constexpr uint32_t rotations[2][4] = {
     {13, 15, 26, 6},
     {17, 29, 16, 24}};

@ -47,27 +50,28 @@ __global__ void rbitsc(
     dim3 grid_dims,
     bool odd,
     uint32_t bytes_per_key) {
-  uint2 index{
-      blockIdx.x * blockDim.x + threadIdx.x,
-      blockIdx.y * blockDim.y + threadIdx.y};
-  if (index.x >= grid_dims.x || index.y >= grid_dims.y) {
+  auto grid = cg::this_grid();
+  uint thread_index = grid.thread_rank();
+  uint index_x = thread_index % grid_dims.x;
+  uint index_y = thread_index / grid_dims.x;
+  if (index_x >= grid_dims.x || index_y >= grid_dims.y) {
     return;
   }
 
-  auto kidx = 2 * index.x;
+  auto kidx = 2 * index_x;
   auto key = uint2{keys[kidx], keys[kidx + 1]};
   auto half_size = grid_dims.y - odd;
-  out += index.x * bytes_per_key;
-  bool drop_last = odd && (index.y == half_size);
+  out += index_x * bytes_per_key;
+  bool drop_last = odd && (index_y == half_size);
   auto bits = threefry2x32_hash(
-      key, uint2{index.y, drop_last ? 0 : index.y + grid_dims.y});
-  size_t idx = size_t(index.y) << 2;
+      key, uint2{index_y, drop_last ? 0 : index_y + grid_dims.y});
+  size_t idx = size_t(index_y) << 2;
   for (int i = 0; i < 4; ++i) {
     out[idx + i] = bits.bytes[0][i];
   }
   if (!drop_last) {
-    idx = (drop_last ? 0 : size_t(index.y) + grid_dims.y) << 2;
-    if ((index.y + 1) == half_size && (bytes_per_key % 4) > 0) {
+    idx = (drop_last ? 0 : size_t(index_y) + grid_dims.y) << 2;
+    if ((index_y + 1) == half_size && (bytes_per_key % 4) > 0) {
       int edge_bytes = (bytes_per_key % 4);
       for (int i = 0; i < edge_bytes; ++i) {
         out[idx + i] = bits.bytes[1][i];

@ -89,30 +93,31 @@ __global__ void rbits(
     int32_t ndim,
     const __grid_constant__ Shape key_shape,
     const __grid_constant__ Strides key_strides) {
-  uint2 index{
-      blockIdx.x * blockDim.x + threadIdx.x,
-      blockIdx.y * blockDim.y + threadIdx.y};
-  if (index.x >= grid_dims.x || index.y >= grid_dims.y) {
+  auto grid = cg::this_grid();
+  uint thread_index = grid.thread_rank();
+  uint index_x = thread_index % grid_dims.x;
+  uint index_y = thread_index / grid_dims.x;
+  if (index_x >= grid_dims.x || index_y >= grid_dims.y) {
     return;
   }
 
-  auto kidx = 2 * index.x;
+  auto kidx = 2 * index_x;
   auto k1_elem = elem_to_loc(kidx, key_shape.data(), key_strides.data(), ndim);
   auto k2_elem =
       elem_to_loc(kidx + 1, key_shape.data(), key_strides.data(), ndim);
   auto key = uint2{keys[k1_elem], keys[k2_elem]};
   auto half_size = grid_dims.y - odd;
-  out += size_t(index.x) * bytes_per_key;
-  bool drop_last = odd && (index.y == half_size);
+  out += size_t(index_x) * bytes_per_key;
+  bool drop_last = odd && (index_y == half_size);
   auto bits = threefry2x32_hash(
-      key, uint2{index.y, drop_last ? 0 : index.y + grid_dims.y});
-  size_t idx = size_t(index.y) << 2;
+      key, uint2{index_y, drop_last ? 0 : index_y + grid_dims.y});
+  size_t idx = size_t(index_y) << 2;
   for (int i = 0; i < 4; ++i) {
     out[idx + i] = bits.bytes[0][i];
   }
   if (!drop_last) {
-    idx = (drop_last ? 0 : size_t(index.y) + grid_dims.y) << 2;
-    if ((index.y + 1) == half_size && (bytes_per_key % 4) > 0) {
+    idx = (drop_last ? 0 : size_t(index_y) + grid_dims.y) << 2;
+    if ((index_y + 1) == half_size && (bytes_per_key % 4) > 0) {
       int edge_bytes = (bytes_per_key % 4);
       for (int i = 0; i < edge_bytes; ++i) {
         out[idx + i] = bits.bytes[1][i];

@ -153,19 +158,22 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
   encoder.set_output_array(out);
   encoder.launch_kernel([&](cudaStream_t stream) {
     dim3 grid_dims{num_keys, half_size + odd};
-    dim3 block_dims = get_block_dims(grid_dims.x, grid_dims.y, 1);
-    dim3 num_blocks{
-        cuda::ceil_div(grid_dims.x, block_dims.x),
-        cuda::ceil_div(grid_dims.y, block_dims.y)};
+    int64_t total = grid_dims.x * grid_dims.y;
+    int32_t threads_y = 1;
+    while ((total / threads_y) >= (1U << 31)) {
+      threads_y *= 2;
+    }
+    int32_t threads_x = cuda::ceil_div(total, threads_y);
+    auto [grid, block] = get_grid_and_block(threads_x, threads_y, 1);
     if (keys.flags().row_contiguous) {
-      cu::rbitsc<<<num_blocks, block_dims, 0, stream>>>(
+      cu::rbitsc<<<grid, block, 0, stream>>>(
           keys.data<uint32_t>(),
           out.data<uint8_t>(),
           grid_dims,
           odd,
           bytes_per_key);
     } else {
-      cu::rbits<<<num_blocks, block_dims, 0, stream>>>(
+      cu::rbits<<<grid, block, 0, stream>>>(
           keys.data<uint32_t>(),
           out.data<uint8_t>(),
           grid_dims,
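Note: the launch change above flattens the (num_keys, half_size + odd) domain into a single linear thread range, doubling threads_y until the x extent ceil(total / threads_y) drops below 2^31, and the kernels then recover 2-D coordinates with mod/div. A Python sketch of that arithmetic (illustrative only, mirroring the loop in the diff):

def split_launch(total):
    # Double threads_y until the x extent fits below 2^31.
    threads_y = 1
    while (total // threads_y) >= (1 << 31):
        threads_y *= 2
    threads_x = -(-total // threads_y)  # cuda::ceil_div(total, threads_y)
    return threads_x, threads_y

def recover_xy(thread_index, grid_dims_x):
    # Matches index_x / index_y in rbitsc and rbits above.
    return thread_index % grid_dims_x, thread_index // grid_dims_x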
mlx/backend/cuda/rope.cu (new file, 385 lines)
@ -0,0 +1,385 @@

// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/gpu/copy.h"
#include "mlx/dtype_utils.h"
#include "mlx/fast_primitives.h"

#include <nvtx3/nvtx3.hpp>

namespace mlx::core {

namespace cu {

template <typename T, bool traditional, bool forward>
__device__ void rope_single_impl(
    const T* in,
    T* out,
    int32_t offset,
    float inv_freq,
    float scale,
    int64_t stride,
    uint2 pos,
    uint2 dims) {
  float L = scale * static_cast<float>(offset);

  // Compute costheta, sintheta
  float theta = L * inv_freq;
  float costheta = cos(theta);
  float sintheta = sin(theta);

  // Compute the input and output indices
  uint index_1, index_2;
  if (traditional) {
    index_1 = 2 * pos.x + pos.y * stride;
    index_2 = index_1 + 1;
  } else {
    index_1 = pos.x + pos.y * stride;
    index_2 = index_1 + dims.x;
  }

  // Read and write the output
  float x1 = static_cast<float>(in[index_1]);
  float x2 = static_cast<float>(in[index_2]);
  float rx1;
  float rx2;
  if (forward) {
    rx1 = x1 * costheta - x2 * sintheta;
    rx2 = x1 * sintheta + x2 * costheta;
  } else {
    rx1 = x2 * sintheta + x1 * costheta;
    rx2 = x2 * costheta - x1 * sintheta;
  }
  out[index_1] = static_cast<T>(rx1);
  out[index_2] = static_cast<T>(rx2);
}

template <typename T, bool traditional, bool forward>
__global__ void rope_single(
    const T* in,
    T* out,
    const int32_t* offset,
    float scale,
    float base,
    int64_t stride,
    uint2 dims) {
  uint2 pos = make_uint2(
      blockIdx.x * blockDim.x + threadIdx.x,
      blockIdx.y * blockDim.y + threadIdx.y);
  if (pos.x >= dims.x || pos.y >= dims.y) {
    return;
  }

  float d = static_cast<float>(pos.x) / static_cast<float>(dims.x);
  float inv_freq = exp2(-d * base);
  rope_single_impl<T, traditional, forward>(
      in, out, *offset, inv_freq, scale, stride, pos, dims);
}

template <typename T, bool traditional, bool forward>
__global__ void rope_single_freqs(
    const T* in,
    T* out,
    const int32_t* offset,
    const float* freqs,
    float scale,
    int64_t stride,
    uint2 dims,
    int64_t freq_stride) {
  uint2 pos = make_uint2(
      blockIdx.x * blockDim.x + threadIdx.x,
      blockIdx.y * blockDim.y + threadIdx.y);
  if (pos.x >= dims.x || pos.y >= dims.y) {
    return;
  }

  float inv_freq = 1.0 / freqs[freq_stride * pos.x];
  rope_single_impl<T, traditional, forward>(
      in, out, *offset, inv_freq, scale, stride, pos, dims);
}

template <typename T, bool traditional, bool forward, int N = 4>
__device__ void rope_impl(
    const T* in,
    T* out,
    int offset,
    float inv_freq,
    float scale,
    const cuda::std::array<int64_t, 3> strides,
    const cuda::std::array<int64_t, 3> out_strides,
    int64_t n_batch,
    uint3 pos,
    uint3 dims) {
  float L = scale * static_cast<float>(pos.y + offset);

  // Compute costheta, sintheta
  float theta = L * inv_freq;
  float costheta = cos(theta);
  float sintheta = sin(theta);

  // Compute the input and output indices
  size_t in_index_1, in_index_2;
  size_t out_index_1, out_index_2;
  if (traditional) {
    out_index_1 = 2 * pos.x * out_strides[2] + pos.y * out_strides[1] +
        N * pos.z * out_strides[0];
    out_index_2 = out_index_1 + 1;
    in_index_1 =
        2 * pos.x * strides[2] + pos.y * strides[1] + N * pos.z * strides[0];
    in_index_2 = in_index_1 + strides[2];
  } else {
    out_index_1 = pos.x * out_strides[2] + pos.y * out_strides[1] +
        N * pos.z * out_strides[0];
    out_index_2 = out_index_1 + dims.x * out_strides[2];
    in_index_1 =
        pos.x * strides[2] + pos.y * strides[1] + N * pos.z * strides[0];
    in_index_2 = in_index_1 + dims.x * strides[2];
  }
  for (int i = 0; i < N && pos.z * N + i < n_batch; ++i) {
    // Read and write the output
    float x1 = static_cast<float>(in[in_index_1]);
    float x2 = static_cast<float>(in[in_index_2]);
    float rx1;
    float rx2;
    if (forward) {
      rx1 = x1 * costheta - x2 * sintheta;
      rx2 = x1 * sintheta + x2 * costheta;
    } else {
      rx1 = x2 * sintheta + x1 * costheta;
      rx2 = x2 * costheta - x1 * sintheta;
    }
    out[out_index_1] = static_cast<T>(rx1);
    out[out_index_2] = static_cast<T>(rx2);
    in_index_1 += strides[0];
    in_index_2 += strides[0];
    out_index_1 += out_strides[0];
    out_index_2 += out_strides[0];
  }
}

template <typename T, bool traditional, bool forward>
__global__ void rope(
    const T* in,
    T* out,
    const int32_t* offset,
    float scale,
    float base,
    const __grid_constant__ cuda::std::array<int64_t, 3> strides,
    const __grid_constant__ cuda::std::array<int64_t, 3> out_strides,
    int64_t n_batch,
    uint3 dims) {
  uint3 pos = make_uint3(
      blockIdx.x * blockDim.x + threadIdx.x,
      blockIdx.y * blockDim.y + threadIdx.y,
      blockIdx.z * blockDim.z + threadIdx.z);
  if (pos.x >= dims.x || pos.y >= dims.y || pos.z >= dims.z) {
    return;
  }

  float d = static_cast<float>(pos.x) / static_cast<float>(dims.x);
  float inv_freq = exp2(-d * base);
  rope_impl<T, traditional, forward>(
      in,
      out,
      *offset,
      inv_freq,
      scale,
      strides,
      out_strides,
      n_batch,
      pos,
      dims);
}

template <typename T, bool traditional, bool forward>
__global__ void rope_freqs(
    const T* in,
    T* out,
    const int32_t* offset,
    const float* freqs,
    float scale,
    float base,
    const __grid_constant__ cuda::std::array<int64_t, 3> strides,
    const __grid_constant__ cuda::std::array<int64_t, 3> out_strides,
    int64_t n_batch,
    uint3 dims,
    int64_t freq_stride) {
  uint3 pos = make_uint3(
      blockIdx.x * blockDim.x + threadIdx.x,
      blockIdx.y * blockDim.y + threadIdx.y,
      blockIdx.z * blockDim.z + threadIdx.z);
  if (pos.x >= dims.x || pos.y >= dims.y || pos.z >= dims.z) {
    return;
  }

  float inv_freq = 1.0 / freqs[freq_stride * pos.x];
  rope_impl<T, traditional, forward>(
      in,
      out,
      *offset,
      inv_freq,
      scale,
      strides,
      out_strides,
      n_batch,
      pos,
      dims);
}

} // namespace cu

namespace fast {

bool RoPE::use_fallback(Stream s) {
  return s.device == Device::cpu;
}

void RoPE::eval_gpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  nvtx3::scoped_range r("RoPE::eval_gpu");

  auto& s = stream();
  auto& in = inputs[0];
  auto& offset = inputs[1];
  auto& out = outputs[0];

  if (in.ndim() < 3) {
    throw std::runtime_error("[RoPE] Input must have at least 3 dimensions");
  }

  cuda::std::array<int64_t, 3> strides;
  cuda::std::array<int64_t, 3> out_strides;
  bool donated = false;
  int ndim = in.ndim();
  int dispatch_ndim = in.ndim();
  while (in.shape(-dispatch_ndim) == 1 && dispatch_ndim > 3) {
    dispatch_ndim--;
  }
  size_t mat_size = in.shape(-2) * in.shape(-1);

  // We apply rope to less than the whole vector so copy to output and then
  // apply in-place.
  if (dims_ < in.shape(-1)) {
    donated = true;
    auto ctype =
        (in.flags().row_contiguous) ? CopyType::Vector : CopyType::General;
    copy_gpu(in, out, ctype, s);
    strides[0] = mat_size;
    strides[1] = out.strides()[ndim - 2];
    strides[2] = out.strides()[ndim - 1];
  }

  // Either copy or apply in-place
  else if (in.flags().row_contiguous) {
    if (in.is_donatable()) {
      donated = true;
      out.copy_shared_buffer(in);
    } else {
      out.set_data(allocator::malloc(out.nbytes()));
    }
    strides[0] = mat_size;
    strides[1] = in.strides()[ndim - 2];
    strides[2] = in.strides()[ndim - 1];
  } else if (dispatch_ndim == 3) {
    // Handle non-contiguous 3D inputs
    out.set_data(allocator::malloc(out.nbytes()));
    strides[0] = in.strides()[ndim - 3];
    strides[1] = in.strides()[ndim - 2];
    strides[2] = in.strides()[ndim - 1];
  } else {
    // Copy non-contiguous > 3D inputs into the output and treat
    // input as donated
    donated = true;
    copy_gpu(in, out, CopyType::General, s);
    strides[0] = mat_size;
    strides[1] = out.strides()[ndim - 2];
    strides[2] = out.strides()[ndim - 1];
  }
  out_strides[0] = mat_size;
  out_strides[1] = out.strides()[ndim - 2];
  out_strides[2] = out.strides()[ndim - 1];

  // Some flags to help us dispatch below
  bool single = in.flags().row_contiguous && (mat_size == in.shape(-1));
  bool with_freqs = inputs.size() == 3;

  auto& encoder = cu::get_command_encoder(s);
  encoder.set_input_array(donated ? out : in);
  encoder.set_input_array(offset);
  encoder.set_output_array(out);
  encoder.launch_kernel([&](cudaStream_t stream) {
    MLX_SWITCH_FLOAT_TYPES_CHECKED(in.dtype(), "rope", CTYPE, {
      using DataType = cuda_type_t<CTYPE>;
      MLX_SWITCH_BOOL(traditional_, TRADITIONAL, {
        MLX_SWITCH_BOOL(forward_, FORWARD, {
          if (single && !with_freqs) {
            auto kernel = cu::rope_single<DataType, TRADITIONAL, FORWARD>;
            uint2 dims = make_uint2(dims_ / 2, in.size() / mat_size);
            auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1);
            kernel<<<grid, block, 0, stream>>>(
                (donated ? out : in).data<DataType>(),
                out.data<DataType>(),
                offset.data<int32_t>(),
                scale_,
                std::log2(base_),
                mat_size,
                dims);
          } else if (single) {
            auto kernel = cu::rope_single_freqs<DataType, TRADITIONAL, FORWARD>;
            uint2 dims = make_uint2(dims_ / 2, in.size() / mat_size);
            auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1);
            kernel<<<grid, block, 0, stream>>>(
                (donated ? out : in).data<DataType>(),
                out.data<DataType>(),
                offset.data<int32_t>(),
                inputs[2].data<float>(),
                scale_,
                mat_size,
                dims,
                inputs[2].strides(0));
          } else if (with_freqs) {
            auto kernel = cu::rope_freqs<DataType, TRADITIONAL, FORWARD>;
            uint3 dims =
                make_uint3(dims_ / 2, in.shape(-2), in.size() / mat_size);
            dims.z = (dims.z + 3) / 4;
            auto [grid, block] = get_grid_and_block(dims.x, dims.y, dims.z);
            kernel<<<grid, block, 0, stream>>>(
                (donated ? out : in).data<DataType>(),
                out.data<DataType>(),
                offset.data<int32_t>(),
                inputs[2].data<float>(),
                scale_,
                std::log2(base_),
                strides,
                out_strides,
                in.size() / mat_size,
                dims,
                inputs[2].strides(0));
          } else {
            auto kernel = cu::rope<DataType, TRADITIONAL, FORWARD>;
            uint3 dims =
                make_uint3(dims_ / 2, in.shape(-2), in.size() / mat_size);
            dims.z = (dims.z + 3) / 4;
            auto [grid, block] = get_grid_and_block(dims.x, dims.y, dims.z);
            kernel<<<grid, block, 0, stream>>>(
                (donated ? out : in).data<DataType>(),
                out.data<DataType>(),
                offset.data<int32_t>(),
                scale_,
                std::log2(base_),
                strides,
                out_strides,
                in.size() / mat_size,
                dims);
          }
        });
      });
    });
  });
}

} // namespace fast

} // namespace mlx::core
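Note: in math form, the kernels in this file rotate each feature pair (x1, x2) by a position- and frequency-dependent angle. With scale s, base b, half-dimension D = dims.x, feature index i, and position p (matching inv_freq = exp2(-d * base) above, where the dispatch code passes base = log2(base_)):

\[
\theta_{i,p} = s\,(p + \mathrm{offset})\, b^{-i/D}, \qquad
\begin{pmatrix} x_1' \\ x_2' \end{pmatrix}
=
\begin{pmatrix} \cos\theta_{i,p} & -\sin\theta_{i,p} \\ \sin\theta_{i,p} & \cos\theta_{i,p} \end{pmatrix}
\begin{pmatrix} x_1 \\ x_2 \end{pmatrix}
\]

The forward = false path applies the transpose of this rotation, i.e. its inverse.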
python/tests/__main__.py (new file, 5 lines)
@ -0,0 +1,5 @@

from . import mlx_tests

__unittest = True

mlx_tests.MLXTestRunner(module=None)
python/tests/cuda_skip.py (new file, 143 lines)
@ -0,0 +1,143 @@

cuda_skip = {
    "TestArray.test_api",
    "TestArray.test_setitem",
    "TestAutograd.test_cumprod_grad",
    "TestAutograd.test_slice_grads",
    "TestAutograd.test_split_against_slice",
    "TestAutograd.test_stop_gradient",
    "TestAutograd.test_topk_grad",
    "TestAutograd.test_update_state",
    "TestAutograd.test_vjp",
    "TestBF16.test_arg_reduction_ops",
    "TestBF16.test_binary_ops",
    "TestBF16.test_reduction_ops",
    "TestBlas.test_block_masked_matmul",
    "TestBlas.test_complex_gemm",
    "TestBlas.test_gather_matmul",
    "TestBlas.test_gather_matmul_grad",
    "TestBlas.test_matmul_batched",
    "TestBlas.test_matrix_vector_attn",
    "TestCompile.test_compile_dynamic_dims",
    "TestCompile.test_compile_inf",
    "TestCompile.test_inf_constant",
    "TestConv.test_1d_conv_with_2d",
    "TestConv.test_asymmetric_padding",
    "TestConv.test_basic_grad_shapes",
    "TestConv.test_conv2d_unaligned_channels",
    "TestConv.test_conv_1d_groups_flipped",
    "TestConv.test_conv_general_flip_grad",
    "TestConv.test_conv_groups_grad",
    "TestConv.test_numpy_conv",
    "TestConv.test_repeated_conv",
    "TestConv.test_torch_conv_1D",
    "TestConv.test_torch_conv_1D_grad",
    "TestConv.test_torch_conv_2D",
    "TestConv.test_torch_conv_2D_grad",
    "TestConv.test_torch_conv_3D",
    "TestConv.test_torch_conv_3D_grad",
    "TestConv.test_torch_conv_depthwise",
    "TestConv.test_torch_conv_general",
    "TestConvTranspose.test_torch_conv_tranpose_1d_output_padding",
    "TestConvTranspose.test_torch_conv_transpose_1D",
    "TestConvTranspose.test_torch_conv_transpose_1D_grad",
    "TestConvTranspose.test_torch_conv_transpose_2D",
    "TestConvTranspose.test_torch_conv_transpose_2D_grad",
    "TestConvTranspose.test_torch_conv_transpose_2d_output_padding",
    "TestConvTranspose.test_torch_conv_transpose_3D",
    "TestConvTranspose.test_torch_conv_transpose_3D_grad",
    "TestConvTranspose.test_torch_conv_transpose_3d_output_padding",
    "TestEinsum.test_attention",
    "TestEinsum.test_ellipses",
    "TestEinsum.test_opt_einsum_test_cases",
    "TestEval.test_multi_output_eval_during_transform",
    "TestExportImport.test_export_conv",
    "TestFast.test_rope_grad",
    "TestFFT.test_fft",
    "TestFFT.test_fft_big_powers_of_two",
    "TestFFT.test_fft_contiguity",
    "TestFFT.test_fft_exhaustive",
    "TestFFT.test_fft_grads",
    "TestFFT.test_fft_into_ifft",
    "TestFFT.test_fft_large_numbers",
    "TestFFT.test_fft_shared_mem",
    "TestFFT.test_fftn",
    "TestInit.test_orthogonal",
    "TestLinalg.test_cholesky",
    "TestLinalg.test_cholesky_inv",
    "TestLinalg.test_eig",
    "TestLinalg.test_eigh",
    "TestLinalg.test_inverse",
    "TestLinalg.test_lu",
    "TestLinalg.test_lu_factor",
    "TestLinalg.test_pseudo_inverse",
    "TestLinalg.test_qr_factorization",
    "TestLinalg.test_svd_decomposition",
    "TestLinalg.test_tri_inverse",
    "TestLoad.test_load_f8_e4m3",
    "TestLosses.test_binary_cross_entropy",
    "TestMemory.test_memory_info",
    "TestLayers.test_conv1d",
    "TestLayers.test_conv2d",
    "TestLayers.test_elu",
    "TestLayers.test_group_norm",
    "TestLayers.test_hard_shrink",
    "TestLayers.test_pooling",
    "TestLayers.test_quantized_embedding",
    "TestLayers.test_sin_pe",
    "TestLayers.test_softshrink",
    "TestLayers.test_upsample",
    "TestOps.test_argpartition",
    "TestOps.test_array_equal",
    "TestOps.test_as_strided",
    "TestOps.test_atleast_1d",
    "TestOps.test_atleast_2d",
    "TestOps.test_atleast_3d",
    "TestOps.test_binary_ops",
    "TestOps.test_bitwise_grad",
    "TestOps.test_complex_ops",
    "TestOps.test_divmod",
    "TestOps.test_dynamic_slicing",
    "TestOps.test_hadamard",
    "TestOps.test_hadamard_grad_vmap",
    "TestOps.test_irregular_binary_ops",
    "TestOps.test_isfinite",
    "TestOps.test_kron",
    "TestOps.test_log",
    "TestOps.test_log10",
    "TestOps.test_log1p",
    "TestOps.test_log2",
    "TestOps.test_logaddexp",
    "TestOps.test_logcumsumexp",
    "TestOps.test_partition",
    "TestOps.test_scans",
    "TestOps.test_slice_update_reversed",
    "TestOps.test_softmax",
    "TestOps.test_sort",
    "TestOps.test_tensordot",
    "TestOps.test_tile",
    "TestOps.test_view",
    "TestQuantized.test_gather_matmul_grad",
    "TestQuantized.test_gather_qmm",
    "TestQuantized.test_gather_qmm_sorted",
    "TestQuantized.test_non_multiples",
    "TestQuantized.test_qmm",
    "TestQuantized.test_qmm_jvp",
    "TestQuantized.test_qmm_shapes",
    "TestQuantized.test_qmm_vjp",
    "TestQuantized.test_qmv",
    "TestQuantized.test_quantize_dequantize",
    "TestQuantized.test_qvm",
    "TestQuantized.test_qvm_splitk",
    "TestQuantized.test_small_matrix",
    "TestQuantized.test_throw",
    "TestQuantized.test_vjp_scales_biases",
    "TestReduce.test_axis_permutation_sums",
    "TestReduce.test_dtypes",
    "TestReduce.test_expand_sums",
    "TestReduce.test_many_reduction_axes",
    "TestUpsample.test_torch_upsample",
    "TestVmap.test_unary",
    "TestVmap.test_vmap_conv",
    "TestVmap.test_vmap_inverse",
    "TestVmap.test_vmap_svd",
}
@ -9,6 +9,42 @@ import mlx.core as mx
 import numpy as np
 
 
+class MLXTestRunner(unittest.TestProgram):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def createTests(self, *args, **kwargs):
+        super().createTests(*args, **kwargs)
+
+        # Assume the CUDA backend in this case
+        device = os.getenv("DEVICE", None)
+        if device is not None:
+            device = getattr(mx, device)
+        else:
+            device = mx.default_device()
+
+        if not (device == mx.gpu and not mx.metal.is_available()):
+            return
+
+        from cuda_skip import cuda_skip
+
+        filtered_suite = unittest.TestSuite()
+
+        def filter_and_add(t):
+            if isinstance(t, unittest.TestSuite):
+                for sub_t in t:
+                    filter_and_add(sub_t)
+            else:
+                t_id = ".".join(t.id().split(".")[-2:])
+                if t_id in cuda_skip:
+                    print(f"Skipping {t_id}")
+                else:
+                    filtered_suite.addTest(t)
+
+        filter_and_add(self.test)
+        self.test = filtered_suite
+
+
 class MLXTestCase(unittest.TestCase):
     @property
     def is_apple_silicon(self):
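Note: with MLXTestRunner above and the new python/tests/__main__.py, the CI command `python -m tests discover python/tests -v` runs the suite through this runner, which drops the tests listed in cuda_skip when the target device is a GPU without Metal (i.e. the CUDA backend). A Python sketch of the filtering step (illustrative; filter_suite is a hypothetical helper, the real logic lives in createTests above):

import unittest

def filter_suite(suite, skip):
    # Keep only tests whose "Class.method" id is not in `skip`.
    filtered = unittest.TestSuite()
    for t in suite:
        if isinstance(t, unittest.TestSuite):
            filtered.addTest(filter_suite(t, skip))
        elif ".".join(t.id().split(".")[-2:]) not in skip:
            filtered.addTest(t)
    return filtered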
@ -130,4 +130,4 @@ class TestRingDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase):
 
 
 if __name__ == "__main__":
-    unittest.main()
+    mlx_tests.MLXTestRunner()

@ -2033,4 +2033,4 @@ class TestArray(mlx_tests.MLXTestCase):
 
 
 if __name__ == "__main__":
-    unittest.main()
+    mlx_tests.MLXTestRunner()

@ -799,4 +799,4 @@ class TestAutograd(mlx_tests.MLXTestCase):
 
 
 if __name__ == "__main__":
-    unittest.main()
+    mlx_tests.MLXTestRunner()

@ -193,4 +193,4 @@ class TestBF16(mlx_tests.MLXTestCase):
 
 
 if __name__ == "__main__":
-    unittest.main()
+    mlx_tests.MLXTestRunner()

@ -1236,4 +1236,4 @@ class TestBlas(mlx_tests.MLXTestCase):
 
 
 if __name__ == "__main__":
-    unittest.main()
+    mlx_tests.MLXTestRunner()

@ -981,4 +981,4 @@ class TestCompile(mlx_tests.MLXTestCase):
 
 
 if __name__ == "__main__":
-    unittest.main()
+    mlx_tests.MLXTestRunner()

@ -38,4 +38,4 @@ class TestConstants(mlx_tests.MLXTestCase):
 
 
 if __name__ == "__main__":
-    unittest.main()
+    mlx_tests.MLXTestRunner()

@ -1188,4 +1188,4 @@ class TestConv(mlx_tests.MLXTestCase):
 
 
 if __name__ == "__main__":
-    unittest.main()
+    mlx_tests.MLXTestRunner()

@ -807,4 +807,4 @@ class TestConvTranspose(mlx_tests.MLXTestCase):
 
 
 if __name__ == "__main__":
-    unittest.main()
+    mlx_tests.MLXTestRunner()

@ -38,7 +38,7 @@ class TestDevice(mlx_tests.MLXTestCase):
         # Restore device
         mx.set_default_device(device)
 
-    @unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
+    @unittest.skipIf(not mx.is_available(mx.gpu), "GPU is not available")
     def test_device_context(self):
         default = mx.default_device()
         diff = mx.cpu if default == mx.gpu else mx.gpu

@ -114,4 +114,4 @@ class TestStream(mlx_tests.MLXTestCase):
 
 
 if __name__ == "__main__":
-    unittest.main()
+    mlx_tests.MLXTestRunner()

@ -294,4 +294,4 @@ class TestDouble(mlx_tests.MLXTestCase):
 
 
 if __name__ == "__main__":
-    unittest.main()
+    mlx_tests.MLXTestRunner()

@ -360,4 +360,4 @@ class TestEinsum(mlx_tests.MLXTestCase):
 
 
 if __name__ == "__main__":
-    unittest.main()
+    mlx_tests.MLXTestRunner()

@ -172,7 +172,7 @@ class TestEval(mlx_tests.MLXTestCase):
         post = mx.get_peak_memory()
         self.assertEqual(pre, post)
 
-    @unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
+    @unittest.skipIf(not mx.is_available(mx.gpu), "GPU is not available")
    def test_multistream_deadlock(self):
         s1 = mx.default_stream(mx.gpu)
         s2 = mx.new_stream(mx.gpu)

@ -197,4 +197,4 @@ class TestEval(mlx_tests.MLXTestCase):
 
 
 if __name__ == "__main__":
-    unittest.main()
+    mlx_tests.MLXTestRunner()

@ -348,4 +348,4 @@ class TestExportImport(mlx_tests.MLXTestCase):
 
 
 if __name__ == "__main__":
-    unittest.main()
+    mlx_tests.MLXTestRunner()

@ -772,4 +772,4 @@ class TestFast(mlx_tests.MLXTestCase):
 
 
 if __name__ == "__main__":
-    unittest.main()
+    mlx_tests.MLXTestRunner()

@ -607,7 +607,7 @@ class TestSDPA(mlx_tests.MLXTestCase):
         out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask)
         self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
 
-    def test_sdpa_prommote_mask(self):
+    def test_sdpa_promote_mask(self):
         mask = mx.array(2.0, mx.bfloat16)
         D = 64
         Nq = 4

@ -653,4 +653,4 @@ class TestSDPA(mlx_tests.MLXTestCase):
 
 
 if __name__ == "__main__":
-    unittest.main(failfast=True)
+    mlx_tests.MLXTestRunner(failfast=True)

@ -320,4 +320,4 @@ class TestFFT(mlx_tests.MLXTestCase):
 
 
 if __name__ == "__main__":
-    unittest.main()
+    mlx_tests.MLXTestRunner()

@ -34,4 +34,4 @@ class TestGraph(mlx_tests.MLXTestCase):
 
 
 if __name__ == "__main__":
-    unittest.main()
+    mlx_tests.MLXTestRunner()

@ -136,4 +136,4 @@ class TestInit(mlx_tests.MLXTestCase):
 
 
 if __name__ == "__main__":
-    unittest.main()
+    mlx_tests.MLXTestRunner()

@ -545,4 +545,4 @@ class TestLinalg(mlx_tests.MLXTestCase):
 
 
 if __name__ == "__main__":
-    unittest.main()
+    mlx_tests.MLXTestRunner()

@ -400,4 +400,4 @@ class TestLoad(mlx_tests.MLXTestCase):
 
 
 if __name__ == "__main__":
-    unittest.main()
+    mlx_tests.MLXTestRunner()

@ -414,4 +414,4 @@ class TestLosses(mlx_tests.MLXTestCase):
 
 
 if __name__ == "__main__":
-    unittest.main()
+    mlx_tests.MLXTestRunner()

@ -60,4 +60,4 @@ class TestMemory(mlx_tests.MLXTestCase):
 
 
 if __name__ == "__main__":
-    unittest.main()
+    mlx_tests.MLXTestRunner()

@ -1907,4 +1907,4 @@ class TestLayers(mlx_tests.MLXTestCase):
 
 
 if __name__ == "__main__":
-    unittest.main()
+    mlx_tests.MLXTestRunner()

@ -3127,4 +3127,4 @@ class TestBroadcast(mlx_tests.MLXTestCase):
 
 
 if __name__ == "__main__":
-    unittest.main()
+    mlx_tests.MLXTestRunner()

@ -527,4 +527,4 @@ class TestSchedulers(mlx_tests.MLXTestCase):
 
 
 if __name__ == "__main__":
-    unittest.main()
+    mlx_tests.MLXTestRunner()

@ -576,4 +576,4 @@ class TestQuantized(mlx_tests.MLXTestCase):
 
 
 if __name__ == "__main__":
-    unittest.main()
+    mlx_tests.MLXTestRunner()

@ -389,4 +389,4 @@ class TestRandom(mlx_tests.MLXTestCase):
 
 
 if __name__ == "__main__":
-    unittest.main()
+    mlx_tests.MLXTestRunner()

@ -155,4 +155,4 @@ class TestReduce(mlx_tests.MLXTestCase):
 
 
 if __name__ == "__main__":
-    unittest.main(failfast=True)
+    mlx_tests.MLXTestRunner(failfast=True)

@ -48,4 +48,4 @@ class TestTreeUtils(mlx_tests.MLXTestCase):
 
 
 if __name__ == "__main__":
-    unittest.main()
+    mlx_tests.MLXTestRunner()

@ -97,4 +97,4 @@ class TestUpsample(mlx_tests.MLXTestCase):
 
 
 if __name__ == "__main__":
-    unittest.main()
+    mlx_tests.MLXTestRunner()

@ -725,4 +725,4 @@ class TestVmap(mlx_tests.MLXTestCase):
 
 
 if __name__ == "__main__":
-    unittest.main()
+    mlx_tests.MLXTestRunner()