mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 01:17:26 +08:00
Compare commits
8 Commits
441bd764e6
...
14d22ddedb
Author | SHA1 | Date | |
---|---|---|---|
![]() |
14d22ddedb | ||
![]() |
bc53f8293f | ||
![]() |
c552ff2451 | ||
![]() |
4d68bd3250 | ||
![]() |
5fbce6c49e | ||
![]() |
0b5c5680f4 | ||
![]() |
221edc4a65 | ||
![]() |
190c72739b |
@ -107,6 +107,16 @@ same array:
|
||||
>>> a
|
||||
array([1, 2, 0], dtype=int32)
|
||||
|
||||
|
||||
Note, unlike NumPy, updates to the same location are nondeterministic:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
>>> a = mx.array([1, 2, 3])
|
||||
>>> a[[0, 0]] = mx.array([4, 5])
|
||||
|
||||
The first element of ``a`` could be ``4`` or ``5``.
|
||||
|
||||
Transformations of functions which use in-place updates are allowed and work as
|
||||
expected. For example:
|
||||
|
||||
|
@ -12,6 +12,7 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/graph_utils.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/paged_attention.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/random.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/scheduler.cpp
|
||||
|
@ -101,10 +101,12 @@ constexpr bool supports_binary_op() {
|
||||
return std::is_same_v<Out, bool> && std::is_same_v<In, bool>;
|
||||
}
|
||||
if (std::is_same_v<Op, NaNEqual>) {
|
||||
return std::is_same_v<Out, bool> &&
|
||||
(is_floating_v<In> || std::is_same_v<In, complex64_t>);
|
||||
return std::is_same_v<Out, bool> && is_inexact_v<In>;
|
||||
}
|
||||
if (std::is_same_v<Op, LogAddExp> || std::is_same_v<Op, ArcTan2>) {
|
||||
if (std::is_same_v<Op, LogAddExp>) {
|
||||
return std::is_same_v<In, Out> && is_inexact_v<In>;
|
||||
}
|
||||
if (std::is_same_v<Op, ArcTan2>) {
|
||||
return std::is_same_v<In, Out> && is_floating_v<In>;
|
||||
}
|
||||
if (std::is_same_v<Op, BitwiseAnd> || std::is_same_v<Op, BitwiseOr> ||
|
||||
@ -150,10 +152,10 @@ void binary_op_gpu_inplace(
|
||||
auto [shape, strides] = collapse_contiguous_dims(a, b, out);
|
||||
auto& a_strides = strides[0];
|
||||
auto& b_strides = strides[1];
|
||||
bool large = a.data_size() > UINT32_MAX ||
|
||||
b.data_size() > UINT32_MAX || out.data_size() > UINT32_MAX;
|
||||
bool large = a.data_size() > INT32_MAX ||
|
||||
b.data_size() > INT32_MAX || out.data_size() > INT32_MAX;
|
||||
MLX_SWITCH_BOOL(large, LARGE, {
|
||||
using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
|
||||
using IdxT = std::conditional_t<LARGE, int64_t, int32_t>;
|
||||
int ndim = shape.size();
|
||||
if (ndim <= 3) {
|
||||
MLX_SWITCH_1_2_3(ndim, NDIM, {
|
||||
@ -165,7 +167,7 @@ void binary_op_gpu_inplace(
|
||||
a.data<InType>(),
|
||||
b.data<InType>(),
|
||||
out.data<OutType>(),
|
||||
out.data_size(),
|
||||
out.size(),
|
||||
const_param<NDIM>(shape),
|
||||
const_param<NDIM>(a_strides),
|
||||
const_param<NDIM>(b_strides));
|
||||
@ -178,7 +180,7 @@ void binary_op_gpu_inplace(
|
||||
a.data<InType>(),
|
||||
b.data<InType>(),
|
||||
out.data<OutType>(),
|
||||
out.data_size(),
|
||||
out.size(),
|
||||
const_param(shape),
|
||||
const_param(a_strides),
|
||||
const_param(b_strides),
|
||||
@ -196,8 +198,8 @@ void binary_op_gpu_inplace(
|
||||
} else if (bopt == BinaryOpType::VectorVector) {
|
||||
kernel = cu::binary_vv<Op, InType, OutType, IdxT>;
|
||||
}
|
||||
auto [num_blocks, block_dims] =
|
||||
get_launch_args(kernel, out, LARGE);
|
||||
auto [num_blocks, block_dims] = get_launch_args(
|
||||
kernel, out.data_size(), out.shape(), out.strides(), LARGE);
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
a.data<InType>(),
|
||||
b.data<InType>(),
|
||||
@ -264,7 +266,6 @@ BINARY_GPU(Add)
|
||||
BINARY_GPU(ArcTan2)
|
||||
BINARY_GPU(Divide)
|
||||
BINARY_GPU(Remainder)
|
||||
BINARY_GPU(Equal)
|
||||
BINARY_GPU(Greater)
|
||||
BINARY_GPU(GreaterEqual)
|
||||
BINARY_GPU(Less)
|
||||
@ -279,6 +280,17 @@ BINARY_GPU(NotEqual)
|
||||
BINARY_GPU(Power)
|
||||
BINARY_GPU(Subtract)
|
||||
|
||||
void Equal::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
nvtx3::scoped_range r("Equal::eval_gpu");
|
||||
auto& s = out.primitive().stream();
|
||||
auto op = get_primitive_string(this);
|
||||
if (equal_nan_) {
|
||||
binary_op_gpu<cu::NaNEqual>(inputs, out, op, s);
|
||||
} else {
|
||||
binary_op_gpu<cu::Equal>(inputs, out, op, s);
|
||||
}
|
||||
}
|
||||
|
||||
void BitwiseBinary::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
nvtx3::scoped_range r("BitwiseBinary::eval_gpu");
|
||||
auto& s = out.primitive().stream();
|
||||
|
@ -130,11 +130,13 @@ struct FusedKernelBuilder {
|
||||
|
||||
constexpr const char* g_jit_includes = R"(
|
||||
#include "mlx/backend/cuda/device/binary_ops.cuh"
|
||||
#include "mlx/backend/cuda/device/ternary_ops.cuh"
|
||||
#include "mlx/backend/cuda/device/unary_ops.cuh"
|
||||
#include "mlx/backend/cuda/device/utils.cuh"
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
|
||||
#define inf cuda::std::numeric_limits<float>::infinity()
|
||||
)";
|
||||
|
||||
void Compiled::eval_gpu(
|
||||
|
@ -6,7 +6,7 @@
|
||||
namespace mlx::core {
|
||||
|
||||
void copy_gpu_inplace(
|
||||
const array& in_,
|
||||
const array& in,
|
||||
array& out,
|
||||
const Shape& shape,
|
||||
const Strides& strides_in,
|
||||
@ -20,7 +20,6 @@ void copy_gpu_inplace(
|
||||
if (out.size() == 0) {
|
||||
return;
|
||||
}
|
||||
const array& in = in_.data_shared_ptr() ? in_ : out;
|
||||
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
encoder.set_input_array(in);
|
||||
|
@ -10,20 +10,13 @@
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
#define MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, ...) \
|
||||
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE_IN, { \
|
||||
MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE_OUT, { \
|
||||
using InType = cuda_type_t<CTYPE_IN>; \
|
||||
using OutType = cuda_type_t<CTYPE_OUT>; \
|
||||
if constexpr (cu::CastOp<InType, OutType>::is_castable) { \
|
||||
__VA_ARGS__; \
|
||||
} else { \
|
||||
throw std::runtime_error(fmt::format( \
|
||||
"Can not copy data from dtype {} to {}.", \
|
||||
dtype_to_string(out.dtype()), \
|
||||
dtype_to_string(in.dtype()))); \
|
||||
} \
|
||||
}); \
|
||||
#define MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, ...) \
|
||||
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE_IN, { \
|
||||
MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE_OUT, { \
|
||||
using InType = cuda_type_t<CTYPE_IN>; \
|
||||
using OutType = cuda_type_t<CTYPE_OUT>; \
|
||||
__VA_ARGS__; \
|
||||
}); \
|
||||
})
|
||||
|
||||
void copy_contiguous(
|
||||
|
@ -43,7 +43,8 @@ void copy_contiguous(
|
||||
if (ctype == CopyType::Vector) {
|
||||
kernel = cu::copy_v<InType, OutType, IdxT>;
|
||||
}
|
||||
auto [num_blocks, block_dims] = get_launch_args(kernel, out, LARGE);
|
||||
auto [num_blocks, block_dims] = get_launch_args(
|
||||
kernel, out.data_size(), out.shape(), out.strides(), LARGE);
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
in.data<InType>() + in_offset,
|
||||
out.data<OutType>() + out_offset,
|
||||
|
@ -59,9 +59,9 @@ void copy_general(
|
||||
MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, {
|
||||
const InType* in_ptr = in.data<InType>() + offset_in;
|
||||
OutType* out_ptr = out.data<OutType>() + offset_out;
|
||||
bool large = in.data_size() > UINT32_MAX || out.data_size() > UINT32_MAX;
|
||||
bool large = in.data_size() > INT32_MAX || out.data_size() > INT32_MAX;
|
||||
MLX_SWITCH_BOOL(large, LARGE, {
|
||||
using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
|
||||
using IdxT = std::conditional_t<LARGE, int64_t, int32_t>;
|
||||
int ndim = shape.size();
|
||||
if (ndim <= 3) {
|
||||
MLX_SWITCH_1_2_3(ndim, NDIM, {
|
||||
@ -70,7 +70,7 @@ void copy_general(
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
in_ptr,
|
||||
out_ptr,
|
||||
out.data_size(),
|
||||
out.size(),
|
||||
const_param<NDIM>(shape),
|
||||
const_param<NDIM>(strides_in),
|
||||
const_param<NDIM>(strides_out));
|
||||
@ -81,7 +81,7 @@ void copy_general(
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
in_ptr,
|
||||
out_ptr,
|
||||
out.data_size(),
|
||||
out.size(),
|
||||
const_param(shape),
|
||||
const_param(strides_in),
|
||||
const_param(strides_out),
|
||||
|
@ -65,9 +65,9 @@ void copy_general_dynamic(
|
||||
MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, {
|
||||
const InType* in_ptr = in.data<InType>() + offset_in;
|
||||
OutType* out_ptr = out.data<OutType>() + offset_out;
|
||||
bool large = in.data_size() > UINT32_MAX || out.data_size() > UINT32_MAX;
|
||||
bool large = in.data_size() > INT32_MAX || out.data_size() > INT32_MAX;
|
||||
MLX_SWITCH_BOOL(large, LARGE, {
|
||||
using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
|
||||
using IdxT = std::conditional_t<LARGE, int64_t, int32_t>;
|
||||
int ndim = shape.size();
|
||||
if (ndim <= 3) {
|
||||
MLX_SWITCH_1_2_3(ndim, NDIM, {
|
||||
@ -76,7 +76,7 @@ void copy_general_dynamic(
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
in_ptr,
|
||||
out_ptr,
|
||||
out.data_size(),
|
||||
out.size(),
|
||||
const_param<NDIM>(shape),
|
||||
const_param<NDIM>(strides_in),
|
||||
const_param<NDIM>(strides_out),
|
||||
@ -89,7 +89,7 @@ void copy_general_dynamic(
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
in_ptr,
|
||||
out_ptr,
|
||||
out.data_size(),
|
||||
out.size(),
|
||||
const_param(shape),
|
||||
const_param(strides_in),
|
||||
const_param(strides_out),
|
||||
|
@ -54,9 +54,9 @@ void copy_general_input(
|
||||
MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, {
|
||||
const InType* in_ptr = in.data<InType>() + offset_in;
|
||||
OutType* out_ptr = out.data<OutType>() + offset_out;
|
||||
bool large = in.data_size() > UINT32_MAX || out.data_size() > UINT32_MAX;
|
||||
bool large = in.data_size() > INT32_MAX || out.data_size() > INT32_MAX;
|
||||
MLX_SWITCH_BOOL(large, LARGE, {
|
||||
using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
|
||||
using IdxT = std::conditional_t<LARGE, int64_t, int32_t>;
|
||||
int ndim = shape.size();
|
||||
if (ndim <= 3) {
|
||||
MLX_SWITCH_1_2_3(ndim, NDIM, {
|
||||
@ -65,7 +65,7 @@ void copy_general_input(
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
in_ptr,
|
||||
out_ptr,
|
||||
out.data_size(),
|
||||
out.size(),
|
||||
const_param<NDIM>(shape),
|
||||
const_param<NDIM>(strides_in));
|
||||
});
|
||||
@ -75,7 +75,7 @@ void copy_general_input(
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
in_ptr,
|
||||
out_ptr,
|
||||
out.data_size(),
|
||||
out.size(),
|
||||
const_param(shape),
|
||||
const_param(strides_in),
|
||||
ndim);
|
||||
|
@ -1,6 +1,8 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/device/cucomplex_math.cuh"
|
||||
#include "mlx/backend/cuda/device/fp16_math.cuh"
|
||||
#include "mlx/backend/cuda/device/utils.cuh"
|
||||
|
||||
#include <cuComplex.h>
|
||||
#include <cuda/std/array>
|
||||
@ -122,6 +124,26 @@ struct LogAddExp {
|
||||
? maxval
|
||||
: T(float(maxval) + log1p(expf(minval - maxval)));
|
||||
};
|
||||
|
||||
__device__ cuComplex operator()(cuComplex x, cuComplex y) {
|
||||
if (isnan(cuCrealf(x)) || isnan(cuCimagf(x)) || isnan(cuCrealf(y)) ||
|
||||
isnan(cuCimagf(y))) {
|
||||
return {
|
||||
cuda::std::numeric_limits<float>::quiet_NaN(),
|
||||
cuda::std::numeric_limits<float>::quiet_NaN()};
|
||||
}
|
||||
constexpr float inf = cuda::std::numeric_limits<float>::infinity();
|
||||
auto maxval = x > y ? x : y;
|
||||
auto minval = x < y ? x : y;
|
||||
if (cuCrealf(minval) == -inf || cuCrealf(maxval) == inf)
|
||||
return maxval;
|
||||
float m = exp(cuCrealf(minval) - cuCrealf(maxval));
|
||||
cuComplex dexp{
|
||||
m * cos(cuCimagf(minval) - cuCimagf(maxval)),
|
||||
m * sin(cuCimagf(minval) - cuCimagf(maxval)),
|
||||
};
|
||||
return maxval + log1p(dexp);
|
||||
}
|
||||
};
|
||||
|
||||
struct Maximum {
|
||||
|
@ -45,6 +45,18 @@ struct CastOp<
|
||||
}
|
||||
};
|
||||
|
||||
template <typename SrcT, typename DstT>
|
||||
struct CastOp<
|
||||
SrcT,
|
||||
DstT,
|
||||
cuda::std::enable_if_t<cuda::std::is_same_v<SrcT, DstT>>> {
|
||||
static constexpr bool is_castable = true;
|
||||
|
||||
__device__ SrcT operator()(SrcT x) {
|
||||
return x;
|
||||
}
|
||||
};
|
||||
|
||||
// Return an iterator that cast the value to DstT using CastOp.
|
||||
template <typename DstT, typename Iterator>
|
||||
__host__ __device__ auto make_cast_iterator(Iterator it) {
|
||||
|
@ -1,4 +1,5 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
#pragma once
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
|
@ -5,6 +5,8 @@
|
||||
#include "mlx/backend/cuda/device/fp16_math.cuh"
|
||||
#include "mlx/backend/cuda/device/utils.cuh"
|
||||
|
||||
#include <math_constants.h>
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
struct Abs {
|
||||
@ -183,21 +185,38 @@ struct Imag {
|
||||
struct Log {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
return log(x);
|
||||
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
||||
auto r = log(cuCrealf(Abs{}(x)));
|
||||
auto i = atan2f(cuCimagf(x), cuCrealf(x));
|
||||
return {r, i};
|
||||
} else {
|
||||
return log(x);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Log2 {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
return log2(x);
|
||||
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
||||
auto y = Log{}(x);
|
||||
return {cuCrealf(y) / CUDART_LN2_F, cuCimagf(y) / CUDART_LN2_F};
|
||||
} else {
|
||||
return log2(x);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Log10 {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
return log10(x);
|
||||
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
||||
auto y = Log{}(x);
|
||||
return {cuCrealf(y) / CUDART_LNT_F, cuCimagf(y) / CUDART_LNT_F};
|
||||
return y;
|
||||
} else {
|
||||
return log10(x);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -187,8 +187,8 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT, IdxT> elem_to_loc_nd(
|
||||
template <typename IdxT = int64_t>
|
||||
inline __host__ __device__ IdxT
|
||||
elem_to_loc_4d(IdxT elem, const int* shape, const int64_t* strides, int ndim) {
|
||||
IdxT loc = elem_to_loc_nd<3>(elem, shape, strides);
|
||||
for (int i = ndim - 1; i >= 3; --i) {
|
||||
IdxT loc = 0;
|
||||
for (int i = ndim - 1; i >= 0; --i) {
|
||||
loc += (elem % shape[i]) * IdxT(strides[i]);
|
||||
elem /= shape[i];
|
||||
}
|
||||
@ -202,8 +202,9 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT> elem_to_loc_4d(
|
||||
const int64_t* a_strides,
|
||||
const int64_t* b_strides,
|
||||
int ndim) {
|
||||
auto [a_loc, b_loc] = elem_to_loc_nd<3>(elem, shape, a_strides, b_strides);
|
||||
for (int i = ndim - 1; i >= 3; --i) {
|
||||
IdxT a_loc = 0;
|
||||
IdxT b_loc = 0;
|
||||
for (int i = ndim - 1; i >= 0; --i) {
|
||||
int dim_idx = elem % shape[i];
|
||||
a_loc += dim_idx * a_strides[i];
|
||||
b_loc += dim_idx * b_strides[i];
|
||||
@ -220,9 +221,10 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT, IdxT> elem_to_loc_4d(
|
||||
const int64_t* b_strides,
|
||||
const int64_t* c_strides,
|
||||
int ndim) {
|
||||
auto [a_loc, b_loc, c_loc] =
|
||||
elem_to_loc_nd<3>(elem, shape, a_strides, b_strides, c_strides);
|
||||
for (int i = ndim - 1; i >= 3; --i) {
|
||||
IdxT a_loc = 0;
|
||||
IdxT b_loc = 0;
|
||||
IdxT c_loc = 0;
|
||||
for (int i = ndim - 1; i >= 0; --i) {
|
||||
int dim_idx = elem % shape[i];
|
||||
a_loc += dim_idx * a_strides[i];
|
||||
b_loc += dim_idx * b_strides[i];
|
||||
@ -336,4 +338,21 @@ struct LoopedElemToLoc<1, false, OffsetT> {
|
||||
}
|
||||
};
|
||||
|
||||
inline __device__ cuComplex log1p(cuComplex in) {
|
||||
float x = cuCrealf(in);
|
||||
float y = cuCimagf(in);
|
||||
float zabs = sqrt(x * x + y * y);
|
||||
float theta = atan2f(y, x + 1);
|
||||
if (zabs < 0.5f) {
|
||||
float r = x * (2 + x) + y * y;
|
||||
if (r == 0) { // handle underflow
|
||||
return {x, theta};
|
||||
}
|
||||
return {0.5f * log1pf(r), theta};
|
||||
} else {
|
||||
auto z0 = sqrt((x + 1) * (x + 1) + y * y);
|
||||
return {log(z0), theta};
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core::cu
|
||||
|
@ -65,8 +65,8 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
Dtype idx_dtype = nidx > 0 ? inputs[1].dtype() : int32;
|
||||
int32_t idx_ndim = nidx > 0 ? inputs[1].ndim() : 0;
|
||||
|
||||
bool large = (nidx > 0 && inputs[1].size() > UINT32_MAX) ||
|
||||
(src.size() > UINT32_MAX) || (out.size() > UINT32_MAX);
|
||||
bool large = (nidx > 0 && inputs[1].size() > INT32_MAX) ||
|
||||
(src.size() > INT32_MAX) || (out.size() > INT32_MAX);
|
||||
|
||||
uint32_t slice_size = std::accumulate(
|
||||
slice_sizes_.begin(), slice_sizes_.end(), 1, std::multiplies<uint32_t>());
|
||||
@ -88,7 +88,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
dtype_to_cuda_type(idx_dtype),
|
||||
nidx,
|
||||
ndim,
|
||||
large ? "int64_t" : "uint32_t"));
|
||||
large ? "int64_t" : "int32_t"));
|
||||
}
|
||||
}
|
||||
return std::make_pair(jit_source_gather, std::move(kernel_names));
|
||||
@ -99,7 +99,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
if (large) {
|
||||
mod.append_arg<int64_t>(out.size());
|
||||
} else {
|
||||
mod.append_arg<uint32_t>(out.size());
|
||||
mod.append_arg<int32_t>(out.size());
|
||||
}
|
||||
mod.append_ndim_arg(src.shape());
|
||||
mod.append_ndim_arg(src.strides());
|
||||
@ -115,7 +115,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
dtype_to_cuda_type(idx_dtype),
|
||||
nidx,
|
||||
idx_ndim,
|
||||
large ? "int64_t" : "uint32_t");
|
||||
large ? "int64_t" : "int32_t");
|
||||
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
for (const auto& in : inputs) {
|
||||
@ -152,14 +152,14 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
Dtype idx_dtype = nidx > 0 ? inputs[1].dtype() : int32;
|
||||
int32_t idx_ndim = nidx > 0 ? inputs[1].ndim() : 0;
|
||||
|
||||
bool large = (nidx > 0 && inputs[1].size() > UINT32_MAX) ||
|
||||
(upd.size() > UINT32_MAX) || (out.size() > UINT32_MAX);
|
||||
bool large = (nidx > 0 && inputs[1].size() > INT32_MAX) ||
|
||||
(upd.size() > INT32_MAX) || (out.size() > INT32_MAX);
|
||||
|
||||
uint32_t upd_post_idx_size = std::accumulate(
|
||||
int32_t upd_post_idx_size = std::accumulate(
|
||||
upd.shape().begin() + idx_ndim,
|
||||
upd.shape().end(),
|
||||
1,
|
||||
std::multiplies<uint32_t>());
|
||||
std::multiplies<int32_t>());
|
||||
|
||||
const char* op = g_scatter_ops[reduce_type_];
|
||||
std::string module_name = fmt::format(
|
||||
@ -181,7 +181,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
op,
|
||||
nidx,
|
||||
ndim,
|
||||
large ? "int64_t" : "uint32_t"));
|
||||
large ? "int64_t" : "int32_t"));
|
||||
}
|
||||
}
|
||||
return std::make_pair(jit_source_scatter, std::move(kernel_names));
|
||||
@ -192,7 +192,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
if (large) {
|
||||
mod.append_arg<int64_t>(upd.size());
|
||||
} else {
|
||||
mod.append_arg<uint32_t>(upd.size());
|
||||
mod.append_arg<int32_t>(upd.size());
|
||||
}
|
||||
mod.append_ndim_arg(upd.shape());
|
||||
mod.append_ndim_arg(upd.strides());
|
||||
@ -200,7 +200,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
if (large) {
|
||||
mod.append_arg<int64_t>(upd_post_idx_size);
|
||||
} else {
|
||||
mod.append_arg<uint32_t>(upd_post_idx_size);
|
||||
mod.append_arg<int32_t>(upd_post_idx_size);
|
||||
}
|
||||
mod.append_ndim_arg(out.shape());
|
||||
mod.append_ndim_arg(out.strides());
|
||||
@ -215,7 +215,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
op,
|
||||
nidx,
|
||||
idx_ndim,
|
||||
large ? "int64_t" : "uint32_t");
|
||||
large ? "int64_t" : "int32_t");
|
||||
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
for (const auto& in : inputs) {
|
||||
@ -238,7 +238,7 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
return;
|
||||
}
|
||||
|
||||
bool large = idx.size() > UINT32_MAX || src.size() > UINT32_MAX;
|
||||
bool large = idx.size() > INT32_MAX || src.size() > INT32_MAX;
|
||||
|
||||
std::string module_name = fmt::format(
|
||||
"gather_axis_{}_{}",
|
||||
@ -258,7 +258,7 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
ndim,
|
||||
contiguous & 1 ? true : false,
|
||||
contiguous & 2 ? true : false,
|
||||
large ? "int64_t" : "uint32_t"));
|
||||
large ? "int64_t" : "int32_t"));
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -283,9 +283,9 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
mod.append_arg<int64_t>(idx_size_axis);
|
||||
mod.append_arg<int64_t>(idx_size_post);
|
||||
} else {
|
||||
mod.append_arg<uint32_t>(idx_size_pre);
|
||||
mod.append_arg<uint32_t>(idx_size_axis);
|
||||
mod.append_arg<uint32_t>(idx_size_post);
|
||||
mod.append_arg<int32_t>(idx_size_pre);
|
||||
mod.append_arg<int32_t>(idx_size_axis);
|
||||
mod.append_arg<int32_t>(idx_size_post);
|
||||
}
|
||||
mod.append_arg(remove_index(idx.shape(), axis_));
|
||||
mod.append_arg(remove_index(src.strides(), axis_));
|
||||
@ -302,7 +302,7 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
src.ndim() - 1,
|
||||
src.flags().row_contiguous,
|
||||
idx.flags().row_contiguous,
|
||||
large ? "int64_t" : "uint32_t");
|
||||
large ? "int64_t" : "int32_t");
|
||||
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
for (const auto& in : inputs) {
|
||||
@ -337,7 +337,7 @@ void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
return;
|
||||
}
|
||||
|
||||
bool large = idx.size() > UINT32_MAX || src.size() > UINT32_MAX;
|
||||
bool large = idx.size() > INT32_MAX || src.size() > INT32_MAX;
|
||||
|
||||
const char* op = reduce_type_ == ScatterAxis::Sum ? "Sum" : "Assign";
|
||||
std::string module_name = fmt::format(
|
||||
@ -360,7 +360,7 @@ void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
ndim,
|
||||
contiguous & 1 ? true : false,
|
||||
contiguous & 2 ? true : false,
|
||||
large ? "int64_t" : "uint32_t"));
|
||||
large ? "int64_t" : "int32_t"));
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -385,9 +385,9 @@ void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
mod.append_arg<int64_t>(idx_size_axis);
|
||||
mod.append_arg<int64_t>(idx_size_post);
|
||||
} else {
|
||||
mod.append_arg<uint32_t>(idx_size_pre);
|
||||
mod.append_arg<uint32_t>(idx_size_axis);
|
||||
mod.append_arg<uint32_t>(idx_size_post);
|
||||
mod.append_arg<int32_t>(idx_size_pre);
|
||||
mod.append_arg<int32_t>(idx_size_axis);
|
||||
mod.append_arg<int32_t>(idx_size_post);
|
||||
}
|
||||
mod.append_arg(remove_index(idx.shape(), axis_));
|
||||
mod.append_arg(remove_index(upd.strides(), axis_));
|
||||
@ -405,7 +405,7 @@ void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
idx.ndim() - 1,
|
||||
upd.flags().row_contiguous,
|
||||
idx.flags().row_contiguous,
|
||||
large ? "int64_t" : "uint32_t");
|
||||
large ? "int64_t" : "int32_t");
|
||||
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
for (const auto& in : inputs) {
|
||||
|
@ -102,6 +102,11 @@ inline constexpr bool is_floating_v =
|
||||
cuda::std::is_same_v<T, float> || cuda::std::is_same_v<T, double> ||
|
||||
cuda::std::is_same_v<T, float16_t> || cuda::std::is_same_v<T, bfloat16_t>;
|
||||
|
||||
// Type traits for detecting complex or real floating point numbers.
|
||||
template <typename T>
|
||||
inline constexpr bool is_inexact_v =
|
||||
is_floating_v<T> || cuda::std::is_same_v<T, complex64_t>;
|
||||
|
||||
// Utility to copy data from vector to array in host.
|
||||
template <int NDIM = MAX_NDIM, typename T = int32_t>
|
||||
inline cuda::std::array<T, NDIM> const_param(const std::vector<T>& vec) {
|
||||
@ -136,17 +141,19 @@ inline uint max_occupancy_block_dim(T kernel) {
|
||||
template <typename T>
|
||||
inline std::tuple<dim3, uint> get_launch_args(
|
||||
T kernel,
|
||||
const array& arr,
|
||||
size_t size,
|
||||
const Shape& shape,
|
||||
const Strides& strides,
|
||||
bool large,
|
||||
int work_per_thread = 1) {
|
||||
size_t nthreads = cuda::ceil_div(arr.size(), work_per_thread);
|
||||
size_t nthreads = cuda::ceil_div(size, work_per_thread);
|
||||
uint block_dim = max_occupancy_block_dim(kernel);
|
||||
if (block_dim > nthreads) {
|
||||
block_dim = nthreads;
|
||||
}
|
||||
dim3 num_blocks;
|
||||
if (large) {
|
||||
num_blocks = get_2d_grid_dims(arr.shape(), arr.strides(), work_per_thread);
|
||||
num_blocks = get_2d_grid_dims(shape, strides, work_per_thread);
|
||||
num_blocks.x = cuda::ceil_div(num_blocks.x, block_dim);
|
||||
} else {
|
||||
num_blocks.x = cuda::ceil_div(nthreads, block_dim);
|
||||
@ -154,4 +161,14 @@ inline std::tuple<dim3, uint> get_launch_args(
|
||||
return std::make_tuple(num_blocks, block_dim);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline std::tuple<dim3, uint> get_launch_args(
|
||||
T kernel,
|
||||
const array& arr,
|
||||
bool large,
|
||||
int work_per_thread = 1) {
|
||||
return get_launch_args(
|
||||
kernel, arr.size(), arr.shape(), arr.strides(), large, work_per_thread);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@ -101,10 +101,10 @@ void ternary_op_gpu_inplace(
|
||||
auto& a_strides = strides[0];
|
||||
auto& b_strides = strides[1];
|
||||
auto& c_strides = strides[2];
|
||||
bool large = a.data_size() > UINT32_MAX || b.data_size() > UINT32_MAX ||
|
||||
c.data_size() > UINT32_MAX || out.data_size() > UINT32_MAX;
|
||||
bool large = a.data_size() > INT32_MAX || b.data_size() > INT32_MAX ||
|
||||
c.data_size() > INT32_MAX || out.data_size() > INT32_MAX;
|
||||
MLX_SWITCH_BOOL(large, LARGE, {
|
||||
using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
|
||||
using IdxT = std::conditional_t<LARGE, int64_t, int32_t>;
|
||||
int ndim = shape.size();
|
||||
if (ndim <= 3) {
|
||||
MLX_SWITCH_1_2_3(ndim, NDIM, {
|
||||
@ -116,7 +116,7 @@ void ternary_op_gpu_inplace(
|
||||
b.data<DType>(),
|
||||
c.data<DType>(),
|
||||
out.data<DType>(),
|
||||
out.data_size(),
|
||||
out.size(),
|
||||
const_param<NDIM>(shape),
|
||||
const_param<NDIM>(a_strides),
|
||||
const_param<NDIM>(b_strides),
|
||||
@ -142,7 +142,8 @@ void ternary_op_gpu_inplace(
|
||||
MLX_SWITCH_BOOL(out.data_size() > UINT32_MAX, LARGE, {
|
||||
using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
|
||||
auto kernel = cu::ternary_v<Op, DType, IdxT>;
|
||||
auto [num_blocks, block_dims] = get_launch_args(kernel, out, LARGE);
|
||||
auto [num_blocks, block_dims] = get_launch_args(
|
||||
kernel, out.data_size(), out.shape(), out.strides(), LARGE);
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
a.data<bool>(),
|
||||
b.data<DType>(),
|
||||
|
@ -27,12 +27,14 @@ constexpr bool supports_unary_op() {
|
||||
std::is_same_v<Op, ArcSin> || std::is_same_v<Op, ArcSinh> ||
|
||||
std::is_same_v<Op, ArcTan> || std::is_same_v<Op, ArcTanh> ||
|
||||
std::is_same_v<Op, Erf> || std::is_same_v<Op, ErfInv> ||
|
||||
std::is_same_v<Op, Expm1> || std::is_same_v<Op, Log1p> ||
|
||||
std::is_same_v<Op, Log> || std::is_same_v<Op, Log2> ||
|
||||
std::is_same_v<Op, Log10> || std::is_same_v<Op, Sigmoid> ||
|
||||
std::is_same_v<Op, Expm1> || std::is_same_v<Op, Sigmoid> ||
|
||||
std::is_same_v<Op, Sqrt> || std::is_same_v<Op, Rsqrt>) {
|
||||
return std::is_same_v<In, Out> && is_floating_v<In>;
|
||||
}
|
||||
if (std::is_same_v<Op, Log> || std::is_same_v<Op, Log2> ||
|
||||
std::is_same_v<Op, Log10> || std::is_same_v<Op, Log1p>) {
|
||||
return std::is_same_v<In, Out> && is_inexact_v<In>;
|
||||
}
|
||||
if (std::is_same_v<Op, BitwiseInvert>) {
|
||||
return std::is_same_v<In, Out> && std::is_integral_v<In> &&
|
||||
!std::is_same_v<In, bool>;
|
||||
@ -91,7 +93,7 @@ void unary_op_gpu_inplace(
|
||||
} else {
|
||||
auto [shape, strides] = collapse_contiguous_dims(in);
|
||||
auto [in_begin, in_end] = cu::make_general_iterators<int64_t>(
|
||||
in_ptr, in.data_size(), shape, strides);
|
||||
in_ptr, in.size(), shape, strides);
|
||||
thrust::transform(policy, in_begin, in_end, out_ptr, Op());
|
||||
}
|
||||
} else {
|
||||
|
@ -31,6 +31,9 @@ const char* dtype_to_cuda_type(const Dtype& dtype) {
|
||||
if (dtype == bfloat16) {
|
||||
return "__nv_bfloat16";
|
||||
}
|
||||
if (dtype == complex64) {
|
||||
return "cuComplex";
|
||||
}
|
||||
#define SPECIALIZE_DtypeToString(CPP_TYPE, DTYPE) \
|
||||
if (dtype == DTYPE) { \
|
||||
return #CPP_TYPE; \
|
||||
|
@ -102,6 +102,7 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/paged_attention.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/normalization.cpp
|
||||
|
@ -241,6 +241,13 @@ MTL::ComputePipelineState* get_gather_qmm_kernel(
|
||||
int wn,
|
||||
bool transpose);
|
||||
|
||||
MTL::ComputePipelineState* get_paged_attention_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const std::string& hash_name,
|
||||
const metal::MTLFCList& func_consts,
|
||||
const std::string&);
|
||||
|
||||
// Create a GPU kernel template definition for JIT compilation
|
||||
template <typename... Args>
|
||||
std::string
|
||||
|
@ -109,6 +109,7 @@ if(NOT MLX_METAL_JIT)
|
||||
reduction/reduce_row.h)
|
||||
build_kernel(quantized quantized.h ${STEEL_HEADERS})
|
||||
build_kernel(scan scan.h)
|
||||
build_kernel(paged_attention paged_attention.h)
|
||||
build_kernel(softmax softmax.h)
|
||||
build_kernel(logsumexp logsumexp.h)
|
||||
build_kernel(sort sort.h)
|
||||
|
1196
mlx/backend/metal/kernels/paged_attention.h
Normal file
1196
mlx/backend/metal/kernels/paged_attention.h
Normal file
File diff suppressed because it is too large
Load Diff
131
mlx/backend/metal/kernels/paged_attention.metal
Normal file
131
mlx/backend/metal/kernels/paged_attention.metal
Normal file
@ -0,0 +1,131 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include "mlx/backend/metal/kernels/paged_attention.h"
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
|
||||
#define instantiate_paged_attention_inner( \
|
||||
type, head_size, block_size, num_threads, num_simd_lanes, partition_size) \
|
||||
template \
|
||||
[[host_name("paged_attention_" #type "_hs" #head_size "_bs" #block_size \
|
||||
"_nt" #num_threads "_nsl" #num_simd_lanes \
|
||||
"_ps" #partition_size)]] [[kernel]] void \
|
||||
paged_attention< \
|
||||
type, \
|
||||
head_size, \
|
||||
block_size, \
|
||||
num_threads, \
|
||||
num_simd_lanes, \
|
||||
partition_size>( \
|
||||
device float* exp_sums \
|
||||
[[buffer(0), function_constant(use_partitioning)]], \
|
||||
device float* max_logits \
|
||||
[[buffer(1), function_constant(use_partitioning)]], \
|
||||
device type* out [[buffer(2)]], \
|
||||
device const type* q [[buffer(3)]], \
|
||||
device const type* k_cache [[buffer(4)]], \
|
||||
device const type* v_cache [[buffer(5)]], \
|
||||
const constant int& num_kv_heads [[buffer(6)]], \
|
||||
const constant float& scale [[buffer(7)]], \
|
||||
const constant float& softcapping [[buffer(8)]], \
|
||||
device const uint32_t* block_tables [[buffer(9)]], \
|
||||
device const uint32_t* context_lens [[buffer(10)]], \
|
||||
const constant int& max_num_blocks_per_seq [[buffer(11)]], \
|
||||
device const float* alibi_slopes \
|
||||
[[buffer(12), function_constant(use_alibi)]], \
|
||||
const constant int& q_stride [[buffer(13)]], \
|
||||
const constant int& kv_block_stride [[buffer(14)]], \
|
||||
const constant int& kv_head_stride [[buffer(15)]], \
|
||||
threadgroup char* shared_mem [[threadgroup(0)]], \
|
||||
uint3 threadgroup_position_in_grid [[threadgroup_position_in_grid]], \
|
||||
uint3 threadgroups_per_grid [[threadgroups_per_grid]], \
|
||||
uint3 thread_position_in_threadgroup \
|
||||
[[thread_position_in_threadgroup]], \
|
||||
uint simd_tid [[simdgroup_index_in_threadgroup]], \
|
||||
uint simd_lid [[thread_index_in_simdgroup]]);
|
||||
|
||||
#define instantiate_paged_attention_v2_reduce_inner( \
|
||||
type, head_size, num_threads, num_simd_lanes, partition_size) \
|
||||
template [[host_name("paged_attention_v2_reduce_" #type "_hs" #head_size \
|
||||
"_nt" #num_threads "_nsl" #num_simd_lanes \
|
||||
"_ps" #partition_size)]] [[kernel]] void \
|
||||
paged_attention_v2_reduce< \
|
||||
type, \
|
||||
head_size, \
|
||||
num_threads, \
|
||||
num_simd_lanes, \
|
||||
partition_size>( \
|
||||
device type * out [[buffer(0)]], \
|
||||
const device float* exp_sums [[buffer(1)]], \
|
||||
const device float* max_logits [[buffer(2)]], \
|
||||
const device type* tmp_out [[buffer(3)]], \
|
||||
device uint32_t* context_lens [[buffer(4)]], \
|
||||
const constant int& max_num_partitions [[buffer(5)]], \
|
||||
threadgroup char* shared_mem [[threadgroup(0)]], \
|
||||
uint3 threadgroup_position_in_grid [[threadgroup_position_in_grid]], \
|
||||
uint3 threadgroups_per_grid [[threadgroups_per_grid]], \
|
||||
uint3 thread_position_in_threadgroup [[thread_position_in_threadgroup]], \
|
||||
uint3 threads_per_threadgroup [[threads_per_threadgroup]], \
|
||||
uint simd_tid [[simdgroup_index_in_threadgroup]], \
|
||||
uint simd_lid [[thread_index_in_simdgroup]]);
|
||||
|
||||
#define instantiate_paged_attention_heads( \
|
||||
type, block_size, num_threads, num_simd_lanes, partition_size) \
|
||||
instantiate_paged_attention_inner( \
|
||||
type, 64, block_size, num_threads, num_simd_lanes, partition_size); \
|
||||
instantiate_paged_attention_inner( \
|
||||
type, 80, block_size, num_threads, num_simd_lanes, partition_size); \
|
||||
instantiate_paged_attention_inner( \
|
||||
type, 96, block_size, num_threads, num_simd_lanes, partition_size); \
|
||||
instantiate_paged_attention_inner( \
|
||||
type, 112, block_size, num_threads, num_simd_lanes, partition_size); \
|
||||
instantiate_paged_attention_inner( \
|
||||
type, 128, block_size, num_threads, num_simd_lanes, partition_size); \
|
||||
instantiate_paged_attention_inner( \
|
||||
type, 192, block_size, num_threads, num_simd_lanes, partition_size); \
|
||||
instantiate_paged_attention_inner( \
|
||||
type, 256, block_size, num_threads, num_simd_lanes, partition_size);
|
||||
|
||||
#define instantiate_paged_attention_v2_reduce_heads( \
|
||||
type, num_threads, num_simd_lanes, partition_size) \
|
||||
instantiate_paged_attention_v2_reduce_inner( \
|
||||
type, 64, num_threads, num_simd_lanes, partition_size); \
|
||||
instantiate_paged_attention_v2_reduce_inner( \
|
||||
type, 80, num_threads, num_simd_lanes, partition_size); \
|
||||
instantiate_paged_attention_v2_reduce_inner( \
|
||||
type, 96, num_threads, num_simd_lanes, partition_size); \
|
||||
instantiate_paged_attention_v2_reduce_inner( \
|
||||
type, 112, num_threads, num_simd_lanes, partition_size); \
|
||||
instantiate_paged_attention_v2_reduce_inner( \
|
||||
type, 128, num_threads, num_simd_lanes, partition_size); \
|
||||
instantiate_paged_attention_v2_reduce_inner( \
|
||||
type, 192, num_threads, num_simd_lanes, partition_size); \
|
||||
instantiate_paged_attention_v2_reduce_inner( \
|
||||
type, 256, num_threads, num_simd_lanes, partition_size);
|
||||
|
||||
#define instantiate_paged_attention_block_size( \
|
||||
type, num_threads, num_simd_lanes, partition_size) \
|
||||
instantiate_paged_attention_heads( \
|
||||
type, 8, num_threads, num_simd_lanes, partition_size); \
|
||||
instantiate_paged_attention_heads( \
|
||||
type, 16, num_threads, num_simd_lanes, partition_size); \
|
||||
instantiate_paged_attention_heads( \
|
||||
type, 32, num_threads, num_simd_lanes, partition_size);
|
||||
|
||||
// TODO: tune num_threads = 256
|
||||
// NOTE: partition_size = 0
|
||||
#define instantiate_paged_attention_v1(type, num_simd_lanes) \
|
||||
instantiate_paged_attention_block_size(type, 256, num_simd_lanes, 0);
|
||||
|
||||
// TODO: tune num_threads = 256
|
||||
// NOTE: partition_size = 512
|
||||
#define instantiate_paged_attention_v2(type, num_simd_lanes) \
|
||||
instantiate_paged_attention_block_size(type, 256, num_simd_lanes, 512); \
|
||||
instantiate_paged_attention_v2_reduce_heads(type, 256, num_simd_lanes, 512);
|
||||
|
||||
instantiate_paged_attention_v1(float, 32);
|
||||
instantiate_paged_attention_v1(bfloat16_t, 32);
|
||||
instantiate_paged_attention_v1(half, 32);
|
||||
|
||||
instantiate_paged_attention_v2(float, 32);
|
||||
instantiate_paged_attention_v2(bfloat16_t, 32);
|
||||
instantiate_paged_attention_v2(half, 32);
|
@ -288,4 +288,13 @@ MTL::ComputePipelineState* get_gather_qmm_kernel(
|
||||
return d.get_kernel(kernel_name, hash_name, func_consts);
|
||||
}
|
||||
|
||||
MTL::ComputePipelineState* get_paged_attention_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const std::string& hash_name,
|
||||
const metal::MTLFCList& func_consts,
|
||||
const std::string&) {
|
||||
return d.get_kernel(kernel_name, "mlx", hash_name, func_consts);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
324
mlx/backend/metal/paged_attention.cpp
Normal file
324
mlx/backend/metal/paged_attention.cpp
Normal file
@ -0,0 +1,324 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include "mlx/backend/common/compiled.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/kernels.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
#include "mlx/paged_attention_primitives.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
namespace mlx::core::paged_attention {
|
||||
|
||||
static void run_paged_attention(
|
||||
const array& q,
|
||||
const array& k_cache,
|
||||
const array& v_cache,
|
||||
const array& block_tables,
|
||||
const array& context_lens,
|
||||
const int head_size,
|
||||
const int block_size,
|
||||
const int num_kv_heads,
|
||||
const float scale,
|
||||
const float softcapping,
|
||||
const int max_context_len,
|
||||
const int max_num_blocks_per_seq,
|
||||
const bool use_partitioning,
|
||||
const std::optional<array> alibi,
|
||||
const int q_stride,
|
||||
const int kv_block_stride,
|
||||
const int kv_head_stride,
|
||||
const int num_heads,
|
||||
const int num_seqs,
|
||||
array& out,
|
||||
metal::Device& d,
|
||||
const Stream& s) {
|
||||
const int partition_size = use_partitioning ? 512 : 0;
|
||||
const int num_threads = 256;
|
||||
const int num_simd_lanes = 32;
|
||||
const bool use_alibi = alibi.has_value();
|
||||
|
||||
std::string type_string = get_type_string(q.dtype());
|
||||
std::string kname;
|
||||
kname.reserve(64);
|
||||
concatenate(
|
||||
kname,
|
||||
"paged_attention_",
|
||||
type_string,
|
||||
"_hs",
|
||||
head_size,
|
||||
"_bs",
|
||||
block_size,
|
||||
"_nt",
|
||||
num_threads,
|
||||
"_nsl",
|
||||
num_simd_lanes,
|
||||
"_ps",
|
||||
partition_size);
|
||||
|
||||
auto template_def = get_template_definition(
|
||||
kname,
|
||||
"paged_attention",
|
||||
type_string,
|
||||
head_size,
|
||||
block_size,
|
||||
num_threads,
|
||||
num_simd_lanes,
|
||||
partition_size);
|
||||
|
||||
// Encode and dispatch kernel
|
||||
metal::MTLFCList func_consts = {
|
||||
{use_partitioning, MTL::DataType::DataTypeBool, 10},
|
||||
{use_alibi, MTL::DataType::DataTypeBool, 20},
|
||||
};
|
||||
|
||||
std::string hash_name = kname;
|
||||
auto kernel = get_paged_attention_kernel(
|
||||
d, kname, hash_name, func_consts, template_def);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
int local_max_num_partitions = 1;
|
||||
if (use_partitioning) {
|
||||
local_max_num_partitions =
|
||||
(max_context_len + partition_size - 1) / partition_size;
|
||||
}
|
||||
|
||||
int logits_size = use_partitioning ? partition_size * size_of(float32) : 0;
|
||||
int outputs_size = use_partitioning
|
||||
? ((num_threads / num_simd_lanes) / 2) * head_size * size_of(float32)
|
||||
: 0;
|
||||
int shared_mem_size =
|
||||
use_partitioning ? std::max(logits_size, outputs_size) : 0;
|
||||
if (use_partitioning) {
|
||||
compute_encoder.set_threadgroup_memory_length(shared_mem_size, 0);
|
||||
}
|
||||
|
||||
if (use_partitioning) {
|
||||
auto tmp_out = array(
|
||||
{num_seqs, num_heads, local_max_num_partitions, head_size}, float32);
|
||||
tmp_out.set_data(allocator::malloc(tmp_out.nbytes()));
|
||||
auto exp_sums =
|
||||
array({num_seqs, num_heads, local_max_num_partitions}, float32);
|
||||
exp_sums.set_data(allocator::malloc(exp_sums.nbytes()));
|
||||
|
||||
std::vector<array> temporaries = {tmp_out, exp_sums};
|
||||
|
||||
compute_encoder.set_output_array(tmp_out, 0);
|
||||
compute_encoder.set_output_array(exp_sums, 1);
|
||||
compute_encoder.set_output_array(out, 2);
|
||||
compute_encoder.set_input_array(q, 3);
|
||||
compute_encoder.set_input_array(k_cache, 4);
|
||||
compute_encoder.set_input_array(v_cache, 5);
|
||||
|
||||
compute_encoder.set_bytes(num_kv_heads, 6);
|
||||
compute_encoder.set_bytes(scale, 7);
|
||||
compute_encoder.set_bytes(softcapping, 8);
|
||||
|
||||
compute_encoder.set_input_array(block_tables, 9);
|
||||
compute_encoder.set_input_array(context_lens, 10);
|
||||
|
||||
compute_encoder.set_bytes(max_num_blocks_per_seq, 11);
|
||||
|
||||
if (use_alibi) {
|
||||
compute_encoder.set_input_array(alibi.value(), 12);
|
||||
}
|
||||
|
||||
compute_encoder.set_bytes(q_stride, 13);
|
||||
compute_encoder.set_bytes(kv_block_stride, 14);
|
||||
compute_encoder.set_bytes(kv_head_stride, 15);
|
||||
|
||||
MTL::Size grid_dims(num_heads, num_seqs, local_max_num_partitions);
|
||||
MTL::Size group_dims(num_threads, 1, 1);
|
||||
|
||||
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
||||
d.add_temporaries(std::move(temporaries), s.index);
|
||||
} else {
|
||||
compute_encoder.set_output_array(out, 2);
|
||||
compute_encoder.set_input_array(q, 3);
|
||||
compute_encoder.set_input_array(k_cache, 4);
|
||||
compute_encoder.set_input_array(v_cache, 5);
|
||||
|
||||
compute_encoder.set_bytes(num_kv_heads, 6);
|
||||
compute_encoder.set_bytes(scale, 7);
|
||||
compute_encoder.set_bytes(softcapping, 8);
|
||||
|
||||
compute_encoder.set_input_array(block_tables, 9);
|
||||
compute_encoder.set_input_array(context_lens, 10);
|
||||
|
||||
compute_encoder.set_bytes(max_num_blocks_per_seq, 11);
|
||||
|
||||
if (use_alibi) {
|
||||
compute_encoder.set_input_array(alibi.value(), 12);
|
||||
}
|
||||
|
||||
compute_encoder.set_bytes(q_stride, 13);
|
||||
compute_encoder.set_bytes(kv_block_stride, 14);
|
||||
compute_encoder.set_bytes(kv_head_stride, 15);
|
||||
|
||||
MTL::Size grid_dims(num_heads, num_seqs, 1);
|
||||
MTL::Size group_dims(num_threads, 1, 1);
|
||||
|
||||
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
||||
}
|
||||
}
|
||||
|
||||
void paged_attention_v1(
|
||||
const array& q,
|
||||
const array& k_cache,
|
||||
const array& v_cache,
|
||||
const array& block_tables,
|
||||
const array& context_lens,
|
||||
const int head_size,
|
||||
const int block_size,
|
||||
const int num_kv_heads,
|
||||
const float scale,
|
||||
const float softcapping,
|
||||
const int max_context_len,
|
||||
const int max_num_blocks_per_seq,
|
||||
const std::optional<array> alibi,
|
||||
const int q_stride,
|
||||
const int kv_block_stride,
|
||||
const int kv_head_stride,
|
||||
const int num_heads,
|
||||
const int num_seqs,
|
||||
array& out,
|
||||
metal::Device& d,
|
||||
const Stream& s) {
|
||||
run_paged_attention(
|
||||
q,
|
||||
k_cache,
|
||||
v_cache,
|
||||
block_tables,
|
||||
context_lens,
|
||||
head_size,
|
||||
block_size,
|
||||
num_kv_heads,
|
||||
scale,
|
||||
softcapping,
|
||||
max_context_len,
|
||||
max_num_blocks_per_seq,
|
||||
/*use_partitioning=*/false,
|
||||
alibi,
|
||||
q_stride,
|
||||
kv_block_stride,
|
||||
kv_head_stride,
|
||||
num_heads,
|
||||
num_seqs,
|
||||
out,
|
||||
d,
|
||||
s);
|
||||
}
|
||||
|
||||
void paged_attention_v2(
|
||||
const array& q,
|
||||
const array& k_cache,
|
||||
const array& v_cache,
|
||||
const array& block_tables,
|
||||
const array& context_lens,
|
||||
const int head_size,
|
||||
const int block_size,
|
||||
const int num_kv_heads,
|
||||
const float scale,
|
||||
const float softcapping,
|
||||
const int max_context_len,
|
||||
const int max_num_blocks_per_seq,
|
||||
const int /* max_num_partitions */,
|
||||
const std::optional<array> alibi,
|
||||
const int q_stride,
|
||||
const int kv_block_stride,
|
||||
const int kv_head_stride,
|
||||
const int num_heads,
|
||||
const int num_seqs,
|
||||
array& out,
|
||||
metal::Device& d,
|
||||
const Stream& s) {
|
||||
run_paged_attention(
|
||||
q,
|
||||
k_cache,
|
||||
v_cache,
|
||||
block_tables,
|
||||
context_lens,
|
||||
head_size,
|
||||
block_size,
|
||||
num_kv_heads,
|
||||
scale,
|
||||
softcapping,
|
||||
max_context_len,
|
||||
max_num_blocks_per_seq,
|
||||
/*use_partitioning=*/true,
|
||||
alibi,
|
||||
q_stride,
|
||||
kv_block_stride,
|
||||
kv_head_stride,
|
||||
num_heads,
|
||||
num_seqs,
|
||||
out,
|
||||
d,
|
||||
s);
|
||||
}
|
||||
|
||||
void PagedAttention::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& s = stream();
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
auto& q = inputs[0];
|
||||
auto& k_cache = inputs[1];
|
||||
auto& v_cache = inputs[2];
|
||||
auto& block_tables = inputs[3];
|
||||
auto& context_lens = inputs[4];
|
||||
const auto alibi_slopes =
|
||||
inputs.size() == 6 ? std::optional{inputs[5]} : std::nullopt;
|
||||
|
||||
if (use_v1_) {
|
||||
paged_attention_v1(
|
||||
q,
|
||||
k_cache,
|
||||
v_cache,
|
||||
block_tables,
|
||||
context_lens,
|
||||
head_size_,
|
||||
block_size_,
|
||||
num_kv_heads_,
|
||||
softmax_scale_,
|
||||
softcapping_.value_or(1.),
|
||||
max_context_len_,
|
||||
max_num_blocks_per_seq_,
|
||||
alibi_slopes,
|
||||
q_stride_,
|
||||
kv_block_stride_,
|
||||
kv_head_stride_,
|
||||
num_heads_,
|
||||
num_seqs_,
|
||||
out,
|
||||
d,
|
||||
s);
|
||||
} else {
|
||||
paged_attention_v2(
|
||||
q,
|
||||
k_cache,
|
||||
v_cache,
|
||||
block_tables,
|
||||
context_lens,
|
||||
head_size_,
|
||||
block_size_,
|
||||
num_kv_heads_,
|
||||
softmax_scale_,
|
||||
softcapping_.value_or(1.),
|
||||
max_context_len_,
|
||||
max_num_blocks_per_seq_,
|
||||
max_num_partitions_,
|
||||
alibi_slopes,
|
||||
q_stride_,
|
||||
kv_block_stride_,
|
||||
kv_head_stride_,
|
||||
num_heads_,
|
||||
num_seqs_,
|
||||
out,
|
||||
d,
|
||||
s);
|
||||
}
|
||||
}
|
||||
} // namespace mlx::core::paged_attention
|
@ -17,6 +17,7 @@
|
||||
#include "mlx/linalg.h"
|
||||
#include "mlx/memory.h"
|
||||
#include "mlx/ops.h"
|
||||
#include "mlx/paged_attention.h"
|
||||
#include "mlx/random.h"
|
||||
#include "mlx/stream.h"
|
||||
#include "mlx/transforms.h"
|
||||
|
170
mlx/paged_attention.cpp
Normal file
170
mlx/paged_attention.cpp
Normal file
@ -0,0 +1,170 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
// Required for using M_PI in MSVC.
|
||||
#define _USE_MATH_DEFINES
|
||||
|
||||
#include <algorithm>
|
||||
#include <climits>
|
||||
#include <cmath>
|
||||
#include <numeric>
|
||||
#include <set>
|
||||
#include <sstream>
|
||||
#include <stdexcept>
|
||||
|
||||
#include "mlx/paged_attention_primitives.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
namespace mlx::core::paged_attention {
|
||||
|
||||
array paged_attention(
|
||||
const array& q,
|
||||
const array& k_cache,
|
||||
const array& v_cache,
|
||||
const array& block_tables,
|
||||
const array& context_lens,
|
||||
int max_context_len,
|
||||
float softmax_scale,
|
||||
std::optional<array> alibi_slopes = std::nullopt,
|
||||
std::optional<float> softcapping = std::nullopt,
|
||||
StreamOrDevice s_ = {}) {
|
||||
auto s = to_stream(s_);
|
||||
|
||||
// supported dtypes
|
||||
if (!issubdtype(q.dtype(), floating)) {
|
||||
throw std::invalid_argument(
|
||||
"[paged_attention] Only real floating types are supported");
|
||||
}
|
||||
if (!(q.dtype() == k_cache.dtype() && k_cache.dtype() == v_cache.dtype())) {
|
||||
throw std::invalid_argument(
|
||||
"[paged_attention] q/k_cache/v_cache dtype must match");
|
||||
}
|
||||
if (!(block_tables.dtype() == uint32 && context_lens.dtype() == uint32)) {
|
||||
throw std::invalid_argument(
|
||||
"[paged_attention] block_tables/context_lens dtype must be uint32");
|
||||
}
|
||||
|
||||
// rank checks
|
||||
if (q.ndim() != 3)
|
||||
throw std::invalid_argument("[paged_attention] `q` must be rank-3");
|
||||
if (k_cache.ndim() != 5)
|
||||
throw std::invalid_argument("[paged_attention] `k_cache` must be rank-5");
|
||||
if (v_cache.ndim() != 4)
|
||||
throw std::invalid_argument("[paged_attention] `v_cache` must be rank-4");
|
||||
if (block_tables.ndim() != 2)
|
||||
throw std::invalid_argument(
|
||||
"[paged_attention] `block_tables` must be rank-2");
|
||||
if (context_lens.ndim() != 1)
|
||||
throw std::invalid_argument(
|
||||
"[paged_attention] `context_lens` must be rank-1");
|
||||
|
||||
// 4. Shape consistency
|
||||
const auto& q_shape = q.shape(); // [num_seqs, num_heads, head_size]
|
||||
const auto& kc_shape = k_cache.shape();
|
||||
const auto& vc_shape = v_cache.shape();
|
||||
const auto& bt_shape = block_tables.shape();
|
||||
const auto& cl_shape = context_lens.shape();
|
||||
|
||||
int num_seqs = q_shape[0];
|
||||
int num_heads = q_shape[1];
|
||||
int head_size = q_shape[2];
|
||||
|
||||
// Allowed head sizes
|
||||
switch (head_size) {
|
||||
case 64:
|
||||
case 80:
|
||||
case 96:
|
||||
case 112:
|
||||
case 128:
|
||||
case 192:
|
||||
case 256:
|
||||
break;
|
||||
default:
|
||||
throw std::invalid_argument(
|
||||
"[paged_attention] `head_size` must be one of "
|
||||
"{64, 80, 96, 112, 128, 192, 256}");
|
||||
}
|
||||
|
||||
int max_num_blocks_per_seq = bt_shape[1];
|
||||
|
||||
// block_tables first dimension must match num_seqs
|
||||
if (bt_shape[0] != num_seqs) {
|
||||
std::stringstream ss;
|
||||
ss << "[paged_attention] block_tables.shape[0] (" << bt_shape[0]
|
||||
<< ") must equal q.shape[0] (" << num_seqs << ")";
|
||||
throw std::invalid_argument(ss.str());
|
||||
}
|
||||
|
||||
// Extract k_cache dimensions
|
||||
int num_blocks = kc_shape[0];
|
||||
int num_kv_heads = kc_shape[1];
|
||||
int head_size_kc = kc_shape[2];
|
||||
int block_size = kc_shape[3];
|
||||
int x = kc_shape[4];
|
||||
|
||||
if (head_size_kc * x != head_size) {
|
||||
std::stringstream ss;
|
||||
ss << "[paged_attention] k_cache head_size (" << head_size_kc << " * " << x
|
||||
<< ") must equal q head_size (" << head_size << ")";
|
||||
throw std::invalid_argument(ss.str());
|
||||
}
|
||||
|
||||
// v_cache must match the derived dimensions
|
||||
if (!(vc_shape[0] == num_blocks && vc_shape[1] == num_kv_heads &&
|
||||
vc_shape[2] == head_size && vc_shape[3] == block_size)) {
|
||||
throw std::invalid_argument(
|
||||
"[paged_attention] `v_cache` shape mismatch with `k_cache`/`q`");
|
||||
}
|
||||
|
||||
// context_lens length must match num_seqs
|
||||
if (cl_shape[0] != num_seqs) {
|
||||
std::stringstream ss;
|
||||
ss << "paged_attention: context_lens length (" << cl_shape[0]
|
||||
<< ") must equal q.shape[0] (" << num_seqs << ")";
|
||||
throw std::invalid_argument(ss.str());
|
||||
}
|
||||
|
||||
constexpr int partition_size = 512;
|
||||
int max_num_partitions =
|
||||
(max_context_len + partition_size - 1) / partition_size; // ceil‑div
|
||||
bool use_v1 = ((max_num_partitions == 1) || (num_seqs * num_heads > 512)) &&
|
||||
(partition_size % block_size == 0);
|
||||
|
||||
auto out_shape = q.shape();
|
||||
|
||||
auto inputs = std::vector{
|
||||
std::move(q),
|
||||
std::move(k_cache),
|
||||
std::move(v_cache),
|
||||
std::move(block_tables),
|
||||
std::move(context_lens)};
|
||||
if (alibi_slopes.has_value()) {
|
||||
inputs.push_back(std::move(alibi_slopes.value()));
|
||||
}
|
||||
|
||||
int q_stride = q.strides()[0];
|
||||
int kv_block_stride = k_cache.strides()[0];
|
||||
int kv_head_stride = k_cache.strides()[1];
|
||||
|
||||
return array(
|
||||
std::move(out_shape),
|
||||
q.dtype(),
|
||||
std::make_shared<PagedAttention>(
|
||||
to_stream(s),
|
||||
use_v1,
|
||||
max_context_len,
|
||||
head_size,
|
||||
block_size,
|
||||
num_kv_heads,
|
||||
softmax_scale,
|
||||
max_num_blocks_per_seq,
|
||||
max_num_partitions,
|
||||
q_stride,
|
||||
kv_block_stride,
|
||||
kv_head_stride,
|
||||
num_heads,
|
||||
num_seqs,
|
||||
softcapping),
|
||||
inputs);
|
||||
}
|
||||
|
||||
} // namespace mlx::core::paged_attention
|
34
mlx/paged_attention.h
Normal file
34
mlx/paged_attention.h
Normal file
@ -0,0 +1,34 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <optional>
|
||||
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/device.h"
|
||||
#include "mlx/stream.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
namespace mlx::core::paged_attention {
|
||||
|
||||
/**
|
||||
* \defgroup ops Paged attention operations
|
||||
* @{
|
||||
*/
|
||||
|
||||
/** PagedAttention operation. */
|
||||
array paged_attention(
|
||||
const array& q,
|
||||
const array& k_cache,
|
||||
const array& v_cache,
|
||||
const array& block_tables,
|
||||
const array& context_lens,
|
||||
int max_context_len,
|
||||
float softmax_scale,
|
||||
std::optional<array> alibi_slopes = std::nullopt,
|
||||
std::optional<float> softcapping = std::nullopt,
|
||||
StreamOrDevice s_ = {});
|
||||
|
||||
/** @} */
|
||||
|
||||
} // namespace mlx::core::paged_attention
|
82
mlx/paged_attention_primitives.h
Normal file
82
mlx/paged_attention_primitives.h
Normal file
@ -0,0 +1,82 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
// Required for using M_PI in MSVC.
|
||||
#define _USE_MATH_DEFINES
|
||||
|
||||
#include <optional>
|
||||
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core::paged_attention {
|
||||
|
||||
class PagedAttention : public UnaryPrimitive {
|
||||
public:
|
||||
explicit PagedAttention(
|
||||
Stream stream,
|
||||
bool use_v1,
|
||||
int max_context_len,
|
||||
int head_size,
|
||||
int block_size,
|
||||
int num_kv_heads,
|
||||
int max_num_blocks_per_seq,
|
||||
int max_num_partitions,
|
||||
int q_stride,
|
||||
int kv_block_stride,
|
||||
int kv_head_stride,
|
||||
int num_heads,
|
||||
int num_seqs,
|
||||
float softmax_scale,
|
||||
std::optional<float> softcapping = std::nullopt)
|
||||
: UnaryPrimitive(stream),
|
||||
use_v1_(use_v1),
|
||||
max_context_len_(max_context_len),
|
||||
head_size_(head_size),
|
||||
block_size_(block_size),
|
||||
num_kv_heads_(num_kv_heads),
|
||||
max_num_blocks_per_seq_(max_num_blocks_per_seq),
|
||||
max_num_partitions_(max_num_partitions),
|
||||
q_stride_(q_stride),
|
||||
kv_block_stride_(kv_block_stride),
|
||||
kv_head_stride_(kv_head_stride),
|
||||
num_heads_(num_heads),
|
||||
num_seqs_(num_seqs),
|
||||
softmax_scale_(softmax_scale),
|
||||
softcapping_(softcapping) {}
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& outputs) override {
|
||||
throw std::runtime_error("NYI");
|
||||
}
|
||||
|
||||
void eval_gpu(const std::vector<array>& inputs, array& outputs) override;
|
||||
|
||||
DEFINE_PRINT(PagedAttention);
|
||||
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
|
||||
auto state() const {
|
||||
return std::make_tuple(
|
||||
max_context_len_,
|
||||
head_size_,
|
||||
block_size_,
|
||||
softmax_scale_,
|
||||
softcapping_);
|
||||
}
|
||||
|
||||
private:
|
||||
bool use_v1_;
|
||||
int max_context_len_;
|
||||
int head_size_;
|
||||
int block_size_;
|
||||
int num_kv_heads_;
|
||||
int max_num_blocks_per_seq_;
|
||||
int max_num_partitions_;
|
||||
int q_stride_;
|
||||
int kv_block_stride_;
|
||||
int kv_head_stride_;
|
||||
int num_heads_;
|
||||
int num_seqs_;
|
||||
float softmax_scale_;
|
||||
std::optional<float> softcapping_ = std::nullopt;
|
||||
};
|
||||
|
||||
} // namespace mlx::core::paged_attention
|
@ -1,25 +1,50 @@
|
||||
cuda_skip = {
|
||||
"TestArray.test_api",
|
||||
"TestArray.test_setitem",
|
||||
"TestAutograd.test_cumprod_grad",
|
||||
"TestAutograd.test_slice_grads",
|
||||
"TestAutograd.test_split_against_slice",
|
||||
"TestAutograd.test_stop_gradient",
|
||||
"TestAutograd.test_topk_grad",
|
||||
"TestAutograd.test_update_state",
|
||||
"TestAutograd.test_vjp",
|
||||
"TestBF16.test_arg_reduction_ops",
|
||||
"TestBF16.test_binary_ops",
|
||||
"TestBF16.test_reduction_ops",
|
||||
"TestBlas.test_block_masked_matmul",
|
||||
"TestBlas.test_complex_gemm",
|
||||
"TestCompile.test_compile_dynamic_dims",
|
||||
"TestEinsum.test_ellipses",
|
||||
"TestEinsum.test_opt_einsum_test_cases",
|
||||
"TestLoad.test_load_f8_e4m3",
|
||||
"TestMemory.test_memory_info",
|
||||
"TestLayers.test_group_norm",
|
||||
"TestLayers.test_pooling",
|
||||
"TestLayers.test_quantized_embedding",
|
||||
"TestLayers.test_sin_pe",
|
||||
"TestLayers.test_upsample",
|
||||
"TestOps.test_array_equal",
|
||||
"TestOps.test_complex_ops",
|
||||
"TestOps.test_dynamic_slicing",
|
||||
"TestOps.test_softmax",
|
||||
"TestOps.test_sort",
|
||||
"TestOps.test_tile",
|
||||
"TestReduce.test_axis_permutation_sums",
|
||||
"TestReduce.test_dtypes",
|
||||
"TestReduce.test_expand_sums",
|
||||
"TestReduce.test_many_reduction_axes",
|
||||
"TestUpsample.test_torch_upsample",
|
||||
# DivMod NYI
|
||||
"TestOps.test_divmod",
|
||||
"TestEval.test_multi_output_eval_during_transform",
|
||||
# Partition NYI
|
||||
"TestAutograd.test_topk_grad",
|
||||
"TestOps.test_argpartition",
|
||||
"TestOps.test_partition",
|
||||
# Block masked matmul NYI
|
||||
"TestBlas.test_block_masked_matmul",
|
||||
# Gather matmul NYI
|
||||
"TestBlas.test_gather_matmul",
|
||||
"TestBlas.test_gather_matmul_grad",
|
||||
"TestBlas.test_matmul_batched",
|
||||
"TestBlas.test_matrix_vector_attn",
|
||||
"TestCompile.test_compile_dynamic_dims",
|
||||
"TestCompile.test_compile_inf",
|
||||
"TestCompile.test_inf_constant",
|
||||
# Scan NYI
|
||||
"TestAutograd.test_cumprod_grad",
|
||||
"TestOps.test_scans",
|
||||
"TestOps.test_logcumsumexp",
|
||||
# Hadamard NYI
|
||||
"TestOps.test_hadamard",
|
||||
"TestOps.test_hadamard_grad_vmap",
|
||||
# Convolutions NYI
|
||||
"TestConv.test_1d_conv_with_2d",
|
||||
"TestConv.test_asymmetric_padding",
|
||||
"TestConv.test_basic_grad_shapes",
|
||||
@ -46,12 +71,11 @@ cuda_skip = {
|
||||
"TestConvTranspose.test_torch_conv_transpose_3D",
|
||||
"TestConvTranspose.test_torch_conv_transpose_3D_grad",
|
||||
"TestConvTranspose.test_torch_conv_transpose_3d_output_padding",
|
||||
"TestEinsum.test_attention",
|
||||
"TestEinsum.test_ellipses",
|
||||
"TestEinsum.test_opt_einsum_test_cases",
|
||||
"TestEval.test_multi_output_eval_during_transform",
|
||||
"TestExportImport.test_export_conv",
|
||||
"TestFast.test_rope_grad",
|
||||
"TestLayers.test_conv1d",
|
||||
"TestLayers.test_conv2d",
|
||||
"TestVmap.test_vmap_conv",
|
||||
# FFTs NYI
|
||||
"TestFFT.test_fft",
|
||||
"TestFFT.test_fft_big_powers_of_two",
|
||||
"TestFFT.test_fft_contiguity",
|
||||
@ -61,61 +85,22 @@ cuda_skip = {
|
||||
"TestFFT.test_fft_large_numbers",
|
||||
"TestFFT.test_fft_shared_mem",
|
||||
"TestFFT.test_fftn",
|
||||
"TestInit.test_orthogonal",
|
||||
# Lapack ops NYI
|
||||
"TestLinalg.test_cholesky",
|
||||
"TestLinalg.test_cholesky_inv",
|
||||
"TestLinalg.test_eig",
|
||||
"TestLinalg.test_eigh",
|
||||
"TestLinalg.test_inverse",
|
||||
"TestVmap.test_vmap_inverse",
|
||||
"TestLinalg.test_lu",
|
||||
"TestLinalg.test_lu_factor",
|
||||
"TestLinalg.test_pseudo_inverse",
|
||||
"TestLinalg.test_qr_factorization",
|
||||
"TestInit.test_orthogonal",
|
||||
"TestLinalg.test_svd_decomposition",
|
||||
"TestVmap.test_vmap_svd",
|
||||
"TestLinalg.test_tri_inverse",
|
||||
"TestLoad.test_load_f8_e4m3",
|
||||
"TestLosses.test_binary_cross_entropy",
|
||||
"TestMemory.test_memory_info",
|
||||
"TestLayers.test_conv1d",
|
||||
"TestLayers.test_conv2d",
|
||||
"TestLayers.test_elu",
|
||||
"TestLayers.test_group_norm",
|
||||
"TestLayers.test_hard_shrink",
|
||||
"TestLayers.test_pooling",
|
||||
"TestLayers.test_quantized_embedding",
|
||||
"TestLayers.test_sin_pe",
|
||||
"TestLayers.test_softshrink",
|
||||
"TestLayers.test_upsample",
|
||||
"TestOps.test_argpartition",
|
||||
"TestOps.test_array_equal",
|
||||
"TestOps.test_as_strided",
|
||||
"TestOps.test_atleast_1d",
|
||||
"TestOps.test_atleast_2d",
|
||||
"TestOps.test_atleast_3d",
|
||||
"TestOps.test_binary_ops",
|
||||
"TestOps.test_bitwise_grad",
|
||||
"TestOps.test_complex_ops",
|
||||
"TestOps.test_divmod",
|
||||
"TestOps.test_dynamic_slicing",
|
||||
"TestOps.test_hadamard",
|
||||
"TestOps.test_hadamard_grad_vmap",
|
||||
"TestOps.test_irregular_binary_ops",
|
||||
"TestOps.test_isfinite",
|
||||
"TestOps.test_kron",
|
||||
"TestOps.test_log",
|
||||
"TestOps.test_log10",
|
||||
"TestOps.test_log1p",
|
||||
"TestOps.test_log2",
|
||||
"TestOps.test_logaddexp",
|
||||
"TestOps.test_logcumsumexp",
|
||||
"TestOps.test_partition",
|
||||
"TestOps.test_scans",
|
||||
"TestOps.test_slice_update_reversed",
|
||||
"TestOps.test_softmax",
|
||||
"TestOps.test_sort",
|
||||
"TestOps.test_tensordot",
|
||||
"TestOps.test_tile",
|
||||
"TestOps.test_view",
|
||||
# Quantization NYI
|
||||
"TestQuantized.test_gather_matmul_grad",
|
||||
"TestQuantized.test_gather_qmm",
|
||||
"TestQuantized.test_gather_qmm_sorted",
|
||||
@ -131,13 +116,4 @@ cuda_skip = {
|
||||
"TestQuantized.test_small_matrix",
|
||||
"TestQuantized.test_throw",
|
||||
"TestQuantized.test_vjp_scales_biases",
|
||||
"TestReduce.test_axis_permutation_sums",
|
||||
"TestReduce.test_dtypes",
|
||||
"TestReduce.test_expand_sums",
|
||||
"TestReduce.test_many_reduction_axes",
|
||||
"TestUpsample.test_torch_upsample",
|
||||
"TestVmap.test_unary",
|
||||
"TestVmap.test_vmap_conv",
|
||||
"TestVmap.test_vmap_inverse",
|
||||
"TestVmap.test_vmap_svd",
|
||||
}
|
||||
|
@ -1187,7 +1187,7 @@ class TestArray(mlx_tests.MLXTestCase):
|
||||
check_slices(np.zeros((3, 2)), np.array([[3, 3], [4, 4]]), np.array([0, 1]))
|
||||
check_slices(np.zeros((3, 2)), np.array([[3, 3], [4, 4]]), np.array([0, 1]))
|
||||
check_slices(
|
||||
np.zeros((3, 2)), np.array([[3, 3], [4, 4], [5, 5]]), np.array([0, 0, 1])
|
||||
np.zeros((3, 2)), np.array([[3, 3], [4, 4], [5, 5]]), np.array([0, 2, 1])
|
||||
)
|
||||
|
||||
# Multiple slices
|
||||
|
@ -83,14 +83,14 @@ class TestLosses(mlx_tests.MLXTestCase):
|
||||
logits, targets, reduction="mean"
|
||||
)
|
||||
expected_mean = mx.mean(expected_none)
|
||||
self.assertEqual(losses_mean, expected_mean)
|
||||
self.assertTrue(mx.allclose(losses_mean, expected_mean))
|
||||
|
||||
# Test with reduction 'sum'
|
||||
losses_sum = nn.losses.binary_cross_entropy(
|
||||
logits, targets, reduction="sum"
|
||||
)
|
||||
expected_sum = mx.sum(expected_none)
|
||||
self.assertEqual(losses_sum, expected_sum)
|
||||
self.assertTrue(mx.allclose(losses_sum, expected_sum))
|
||||
|
||||
# With weights, no label smoothing
|
||||
weights = mx.array([1.0, 2.0, 1.0, 2.0])
|
||||
|
@ -2586,17 +2586,6 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
self.assertEqualArray(result, mx.array(expected))
|
||||
|
||||
def test_atleast_1d(self):
|
||||
def compare_nested_lists(x, y):
|
||||
if isinstance(x, list) and isinstance(y, list):
|
||||
if len(x) != len(y):
|
||||
return False
|
||||
for i in range(len(x)):
|
||||
if not compare_nested_lists(x[i], y[i]):
|
||||
return False
|
||||
return True
|
||||
else:
|
||||
return x == y
|
||||
|
||||
# Test 1D input
|
||||
arrays = [
|
||||
[1],
|
||||
@ -2614,23 +2603,11 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
for i, array in enumerate(arrays):
|
||||
mx_res = mx.atleast_1d(mx.array(array))
|
||||
np_res = np.atleast_1d(np.array(array))
|
||||
self.assertTrue(compare_nested_lists(mx_res.tolist(), np_res.tolist()))
|
||||
self.assertEqual(mx_res.shape, np_res.shape)
|
||||
self.assertEqual(mx_res.ndim, np_res.ndim)
|
||||
self.assertTrue(mx.all(mx.equal(mx_res, atleast_arrays[i])))
|
||||
self.assertTrue(mx.array_equal(mx_res, atleast_arrays[i]))
|
||||
|
||||
def test_atleast_2d(self):
|
||||
def compare_nested_lists(x, y):
|
||||
if isinstance(x, list) and isinstance(y, list):
|
||||
if len(x) != len(y):
|
||||
return False
|
||||
for i in range(len(x)):
|
||||
if not compare_nested_lists(x[i], y[i]):
|
||||
return False
|
||||
return True
|
||||
else:
|
||||
return x == y
|
||||
|
||||
# Test 1D input
|
||||
arrays = [
|
||||
[1],
|
||||
@ -2648,23 +2625,11 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
for i, array in enumerate(arrays):
|
||||
mx_res = mx.atleast_2d(mx.array(array))
|
||||
np_res = np.atleast_2d(np.array(array))
|
||||
self.assertTrue(compare_nested_lists(mx_res.tolist(), np_res.tolist()))
|
||||
self.assertEqual(mx_res.shape, np_res.shape)
|
||||
self.assertEqual(mx_res.ndim, np_res.ndim)
|
||||
self.assertTrue(mx.all(mx.equal(mx_res, atleast_arrays[i])))
|
||||
self.assertTrue(mx.array_equal(mx_res, atleast_arrays[i]))
|
||||
|
||||
def test_atleast_3d(self):
|
||||
def compare_nested_lists(x, y):
|
||||
if isinstance(x, list) and isinstance(y, list):
|
||||
if len(x) != len(y):
|
||||
return False
|
||||
for i in range(len(x)):
|
||||
if not compare_nested_lists(x[i], y[i]):
|
||||
return False
|
||||
return True
|
||||
else:
|
||||
return x == y
|
||||
|
||||
# Test 1D input
|
||||
arrays = [
|
||||
[1],
|
||||
@ -2682,10 +2647,9 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
for i, array in enumerate(arrays):
|
||||
mx_res = mx.atleast_3d(mx.array(array))
|
||||
np_res = np.atleast_3d(np.array(array))
|
||||
self.assertTrue(compare_nested_lists(mx_res.tolist(), np_res.tolist()))
|
||||
self.assertEqual(mx_res.shape, np_res.shape)
|
||||
self.assertEqual(mx_res.ndim, np_res.ndim)
|
||||
self.assertTrue(mx.all(mx.equal(mx_res, atleast_arrays[i])))
|
||||
self.assertTrue(mx.array_equal(mx_res, atleast_arrays[i]))
|
||||
|
||||
def test_issubdtype(self):
|
||||
self.assertTrue(mx.issubdtype(mx.bfloat16, mx.inexact))
|
||||
|
Loading…
Reference in New Issue
Block a user