Compare commits

..

12 Commits

Author SHA1 Message Date
Awni Hannun
ee18e1cbf0 patch bump (#2588) 2025-09-11 17:10:09 -07:00
Awni Hannun
af120c2bc0 set nccl ABI version (#2587) 2025-09-11 16:55:53 -07:00
Cheng
6a3acf2301 [CUDA] Set bias as input when using bias epilogue (#2584) 2025-09-11 15:31:09 +09:00
Awni Hannun
d6977f2a57 Add sdpa with sinks (#2558)
* add sdpa with sinks

* fix 2 pass

* fix matrix sdpa

* fix perf regression

* add to cuda (#2580)
2025-09-10 14:53:00 -07:00
Gökdeniz Gülmez
db5443e831 Adding Relu2 (#2582)
* in. com.

* upd. ackn.

* update __init__

* nits

* nits + format

* used mx.maximum(x, 0) instead of calling the function and moves relu6 under relu2 to make it nicer

* same with _make_activation_module

* Update python/mlx/nn/layers/activations.py

upd

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* update funct.rst

* upd. layers.rst

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2025-09-10 07:24:30 -07:00
Cheng
52b8384d10 Fix flaky addmm tests (#2581) 2025-09-10 14:22:22 +09:00
Cheng
44cc5da4bc [CUDA] Fix alpha not respected when using bias epilogue (#2578) 2025-09-10 09:08:01 +09:00
Cheng
dde3682b69 [CUDA] Use GEMM with epilogue instead of AddMM (#2569) 2025-09-09 13:18:49 +09:00
Awni Hannun
17310d91a6 Add batch offsets for mx.fast.rope (#2564)
* implement batch rope for Metal

* cuda rope (#2576)
2025-09-08 17:35:07 -07:00
Cheng
b194d65a6a Some tweaks in cmake files (#2574)
* Do proper check of Metal lib

* Update doctest to get rid of cmake version hack
2025-09-09 08:27:18 +09:00
Cheng
a44b27f5f8 Fix a few ccache cache miss (#2573)
* Fix ccache cache miss

* Do not define _VERSION_ in python bindings
2025-09-09 07:41:05 +09:00
Awni Hannun
e5a33f2223 faster depthwise 1D conv (#2567) 2025-09-08 11:37:23 -07:00
35 changed files with 1084 additions and 545 deletions

View File

@@ -19,7 +19,7 @@ MLX was developed with contributions from the following individuals:
- Gleb Pobudzey: Added the `where` primitive, and groups in 1D and 2D convolutions.
- Paul Paczuski: Improved stability of BCE loss calculation
- Max-Heinrich Laves: Added `conv_transpose1d`, `conv_transpose2d`, and `conv_transpose3d` ops.
- Gökdeniz Gülmez: Added the `Muon (MomentUm Orthogonalized by Newton-schulz)` optimizer.
- Gökdeniz Gülmez: Added the `Muon (MomentUm Orthogonalized by Newton-schulz)` optimizer, and the `ReLU²` activation function.
<a href="https://github.com/ml-explore/mlx/graphs/contributors">
<img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />

View File

@@ -87,22 +87,21 @@ cmake_policy(SET CMP0135 NEW)
add_library(mlx)
if(MLX_BUILD_METAL)
set(METAL_LIB "-framework Metal")
set(FOUNDATION_LIB "-framework Foundation")
set(QUARTZ_LIB "-framework QuartzCore")
endif()
if(MLX_BUILD_CUDA)
enable_language(CUDA)
endif()
if(MLX_BUILD_METAL AND NOT METAL_LIB)
message(STATUS "Metal not found. Unable to build GPU")
set(MLX_BUILD_METAL OFF)
set(MLX_METAL_DEBUG OFF)
elseif(MLX_BUILD_METAL)
message(STATUS "Building METAL sources")
if(MLX_BUILD_METAL)
find_library(METAL_LIB Metal)
find_library(FOUNDATION_LIB Foundation)
find_library(QUARTZ_LIB QuartzCore)
if(METAL_LIB)
message(STATUS "Metal found ${METAL_LIB}")
else()
message(
FATAL_ERROR
"Metal not found. Set MLX_BUILD_METAL=OFF to build without GPU")
endif()
if(MLX_METAL_DEBUG)
add_compile_definitions(MLX_METAL_DEBUG)
@@ -111,7 +110,8 @@ elseif(MLX_BUILD_METAL)
# Throw an error if xcrun not found
execute_process(
COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version"
OUTPUT_VARIABLE MACOS_SDK_VERSION COMMAND_ERROR_IS_FATAL ANY)
OUTPUT_VARIABLE MACOS_SDK_VERSION
OUTPUT_STRIP_TRAILING_WHITESPACE COMMAND_ERROR_IS_FATAL ANY)
if(${MACOS_SDK_VERSION} LESS 14.0)
message(

View File

@@ -27,6 +27,7 @@ simple functions.
mish
prelu
relu
relu2
relu6
selu
sigmoid

View File

@@ -50,6 +50,7 @@ Layers
QuantizedLinear
RMSNorm
ReLU
ReLU2
ReLU6
RNN
RoPE

View File

@@ -85,10 +85,10 @@ cublasLtMatrixLayout_t create_matrix_layout(
int32_t batch_count,
int64_t batch_stride) {
cublasLtMatrixLayout_t desc;
if (transposed) {
std::swap(rows, cols);
}
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate(&desc, type, rows, cols, ld));
cublasLtOrder_t order = transposed ? CUBLASLT_ORDER_COL : CUBLASLT_ORDER_ROW;
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
desc, CUBLASLT_MATRIX_LAYOUT_ORDER, &order, sizeof(cublasLtOrder_t)));
if (batch_count > 1) {
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
desc,
@@ -138,25 +138,34 @@ CublasGemm::CublasGemm(
CUBLASLT_MATMUL_DESC_POINTER_MODE,
&pointer_mode,
sizeof(int32_t)));
cublasOperation_t op = CUBLAS_OP_N;
// In cublasLt matrices use column-major layout, while it is possible to use
// the CUBLASLT_ORDER_ROW option to switch to row-major layout, the bias
// epilogue does not work with the option. So instead we swap A and B to make
// cublasLt return the row-major result, which works because:
// - the data of a matrix in row-major layout is identical to its transpose in
// column-major layout
// - C^T = (A @ B)^T = B^T @ A^T
cublasOperation_t a_op = b_transposed ? CUBLAS_OP_T : CUBLAS_OP_N;
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
matmul_desc_,
CUBLASLT_MATMUL_DESC_TRANSA,
&op,
&a_op,
sizeof(cublasOperation_t)));
cublasOperation_t b_op = a_transposed ? CUBLAS_OP_T : CUBLAS_OP_N;
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
matmul_desc_,
CUBLASLT_MATMUL_DESC_TRANSB,
&op,
&b_op,
sizeof(cublasOperation_t)));
auto type = dtype_to_cublas_type(dtype);
a_desc_ = create_matrix_layout(
type, a_rows, a_cols, a_transposed, lda, batch_count, a_batch_stride);
type, b_cols, b_rows, b_transposed, ldb, batch_count, b_batch_stride);
b_desc_ = create_matrix_layout(
type, b_rows, b_cols, b_transposed, ldb, batch_count, b_batch_stride);
type, a_cols, a_rows, a_transposed, lda, batch_count, a_batch_stride);
out_desc_ = create_matrix_layout(
type, a_rows, b_cols, false, b_cols, batch_count, a_rows * b_cols);
type, b_cols, a_rows, false, b_cols, batch_count, a_rows * b_cols);
}
CublasGemm::CublasGemm(
@@ -191,7 +200,7 @@ CublasGemm::CublasGemm(
b_batch_stride) {
auto type = dtype_to_cublas_type(dtype);
c_desc_ = create_matrix_layout(
type, a_rows, b_cols, false, ldc, batch_count, c_batch_stride);
type, b_cols, a_rows, false, ldc, batch_count, c_batch_stride);
}
CublasGemm::~CublasGemm() {
@@ -213,14 +222,30 @@ void CublasGemm::set_out(
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(out_desc_));
out_desc_ = create_matrix_layout(
dtype_to_cublas_type(dtype),
rows,
cols,
rows,
transposed,
ld,
batch_count,
batch_stride);
}
void CublasGemm::set_bias(cu::CommandEncoder& encoder, const array& bias) {
encoder.set_input_array(bias);
cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_BIAS;
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
matmul_desc_,
CUBLASLT_MATMUL_DESC_EPILOGUE,
&epilogue,
sizeof(epilogue)));
auto* bias_ptr = bias.data<void>();
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
matmul_desc_,
CUBLASLT_MATMUL_DESC_BIAS_POINTER,
&bias_ptr,
sizeof(bias_ptr)));
}
void CublasGemm::run(
cu::CommandEncoder& encoder,
array& out,
@@ -228,11 +253,19 @@ void CublasGemm::run(
const array& b,
const Shape& batch_shape,
const Strides& a_batch_strides,
const Strides& b_batch_strides) {
const Strides& b_batch_strides,
float alpha) {
int batch_count = out.size() / (M_ * N_);
if (batch_count / batch_shape.back() > 1) {
run_batched(
encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides);
encoder,
out,
a,
b,
batch_shape,
a_batch_strides,
b_batch_strides,
alpha);
return;
}
@@ -240,7 +273,13 @@ void CublasGemm::run(
encoder.set_input_array(b);
encoder.set_output_array(out);
execute(encoder, out.data<void>(), a.data<void>(), b.data<void>(), nullptr);
execute(
encoder,
out.data<void>(),
a.data<void>(),
b.data<void>(),
nullptr,
alpha);
}
void CublasGemm::run(
@@ -330,9 +369,9 @@ void CublasGemm::execute(
handle_,
matmul_desc_,
&alpha,
a,
b, // a and b are swapped
a_desc_,
b,
a,
b_desc_,
&beta,
c ? c : out,

View File

@@ -55,6 +55,8 @@ class CublasGemm {
int32_t batch_count,
int64_t batch_stride);
void set_bias(cu::CommandEncoder& encoder, const array& bias);
void run(
cu::CommandEncoder& encoder,
array& out,
@@ -62,7 +64,8 @@ class CublasGemm {
const array& b,
const Shape& batch_shape,
const Strides& a_batch_strides,
const Strides& b_batch_strides);
const Strides& b_batch_strides,
float alpha = 1.0f);
void run(
cu::CommandEncoder& encoder,
@@ -85,7 +88,8 @@ class CublasGemm {
const array& b,
const Shape& batch_shape,
const Strides& a_batch_strides,
const Strides& b_batch_strides);
const Strides& b_batch_strides,
float alpha);
void run_batched(
cu::CommandEncoder& encoder,

View File

@@ -13,7 +13,8 @@ void CublasGemm::run_batched(
const array& b,
const Shape& batch_shape,
const Strides& a_batch_strides,
const Strides& b_batch_strides) {
const Strides& b_batch_strides,
float alpha) {
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
@@ -27,7 +28,8 @@ void CublasGemm::run_batched(
out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M_ * N_,
a.data<int8_t>() + a.itemsize() * a_it.loc,
b.data<int8_t>() + b.itemsize() * b_it.loc,
nullptr);
nullptr,
alpha);
a_it.step();
b_it.step();
}

View File

@@ -154,7 +154,8 @@ void CublasGemm::run_batched(
const array& b,
const Shape& batch_shape,
const Strides& a_batch_strides,
const Strides& b_batch_strides) {
const Strides& b_batch_strides,
float alpha) {
int batch_count = out.size() / (M_ * N_);
set_pointer_mode(a_desc_, batch_count);
set_pointer_mode(b_desc_, batch_count);
@@ -226,7 +227,8 @@ void CublasGemm::run_batched(
reinterpret_cast<void*>(out_pointers),
reinterpret_cast<void*>(a_pointers),
reinterpret_cast<void*>(b_pointers),
nullptr);
nullptr,
alpha);
}
void CublasGemm::run_batched(

View File

@@ -11,6 +11,7 @@
#include <numeric>
namespace mlx::core {
namespace {
std::tuple<bool, int64_t, array>
@@ -28,6 +29,76 @@ check_transpose(cu::CommandEncoder& enc, const Stream& s, const array& arr) {
}
}
void gemm_and_bias(
cu::CommandEncoder& encoder,
int M,
int N,
int K,
bool a_transposed,
int64_t lda,
bool b_transposed,
int64_t ldb,
array& out,
const array& a,
const array& b,
const std::optional<array>& bias = std::nullopt,
float alpha = 1.0f) {
// Check and collapse batch dimensions
auto [batch_shape, a_batch_strides, b_batch_strides] = collapse_batches(a, b);
auto batch_count = out.size() / (M * N);
// Collapse batches into M if needed
if (batch_count > 1 && !a_transposed && batch_shape.size() == 1 &&
a.strides()[a.ndim() - 2] == K && a_batch_strides.back() == M * K &&
b_batch_strides.back() == 0) {
M *= batch_shape.back();
batch_count = 1;
a_batch_strides = {0};
b_batch_strides = {0};
batch_shape = {1};
}
// Use gemmv when possible
if (!bias && cu::can_use_gemv(M, N, K, a_transposed, b_transposed)) {
cu::gemv(
a,
b,
out,
M,
N,
K,
batch_count,
batch_shape,
a_batch_strides,
b_batch_strides,
encoder);
return;
}
// Invoke cublasLt
CublasGemm gemm(
encoder.device(),
a.dtype(),
a_transposed,
M,
K,
lda,
b_transposed,
K,
N,
ldb,
batch_shape.back(),
a_batch_strides.back(),
b_batch_strides.back());
if (bias) {
gemm.set_bias(encoder, *bias);
}
gemm.run(
encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides, alpha);
}
} // namespace
void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -48,9 +119,6 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
out.set_data(allocator::malloc(out.nbytes()));
/////////////////////////////////////////////////////////////////////////////
// Init checks and prep
int M = a_pre.shape(-2);
int N = b_pre.shape(-1);
int K = a_pre.shape(-1);
@@ -60,58 +128,8 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
auto [a_transposed, lda, a] = check_transpose(encoder, s, a_pre);
auto [b_transposed, ldb, b] = check_transpose(encoder, s, b_pre);
/////////////////////////////////////////////////////////////////////////////
// Check and collapse batch dimensions
auto [batch_shape, a_batch_strides, b_batch_strides] = collapse_batches(a, b);
auto batch_count = out.size() / (M * N);
// Collapse batches into M if needed
if (batch_count > 1 && !a_transposed && batch_shape.size() == 1 &&
a.strides()[a.ndim() - 2] == K && a_batch_strides.back() == M * K &&
b_batch_strides.back() == 0) {
M *= batch_shape.back();
batch_count = 1;
a_batch_strides = {0};
b_batch_strides = {0};
batch_shape = {1};
}
if (cu::can_use_gemv(M, N, K, a_transposed, b_transposed)) {
cu::gemv(
a,
b,
out,
M,
N,
K,
batch_count,
batch_shape,
a_batch_strides,
b_batch_strides,
encoder);
return;
}
/////////////////////////////////////////////////////////////////////////////
// Invoke cublasLt
CublasGemm gemm(
cu::device(s.device),
a.dtype(),
a_transposed,
M,
K,
lda,
b_transposed,
K,
N,
ldb,
batch_shape.back(),
a_batch_strides.back(),
b_batch_strides.back());
gemm.run(encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides);
gemm_and_bias(
encoder, M, N, K, a_transposed, lda, b_transposed, ldb, out, a, b);
}
void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -136,6 +154,28 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
auto [a_transposed, lda, a] = check_transpose(encoder, s, a_pre);
auto [b_transposed, ldb, b] = check_transpose(encoder, s, b_pre);
/////////////////////////////////////////////////////////////////////////////
// Dispatch to GEMM with epilogue or AddMM
if (beta_ == 1 && c.strides(-1) == 1 && c.data_size() == out.shape(-1)) {
out.set_data(allocator::malloc(out.nbytes()));
gemm_and_bias(
encoder,
M,
N,
K,
a_transposed,
lda,
b_transposed,
ldb,
out,
a,
b,
c,
alpha_);
return;
}
int64_t ldc;
{
auto stx = c.strides()[c.ndim() - 2];
@@ -177,7 +217,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
}
/////////////////////////////////////////////////////////////////////////////
// Invoke cublasLt
// Invoke cublasLt with AddMM settings
CublasGemm gemm(
cu::device(s.device),

View File

@@ -103,15 +103,21 @@ template <typename T, bool traditional, bool forward, int N = 4>
__device__ void rope_impl(
const T* in,
T* out,
int offset,
const int* offset,
float inv_freq,
float scale,
const cuda::std::array<int64_t, 3> strides,
const cuda::std::array<int64_t, 3> out_strides,
int64_t n_batch,
int64_t offset_stride,
int n_head,
uint3 pos,
uint3 dims) {
float L = scale * static_cast<float>(pos.y + offset);
auto n_head_up = N * ((n_head + N - 1) / N);
auto head_idx = static_cast<int>((pos.z * N) % n_head_up);
auto batch_idx = (pos.z * N) / n_head_up;
auto batch_offset = offset[batch_idx * offset_stride];
float L = scale * static_cast<float>(pos.y + batch_offset);
auto mat_idx = batch_idx * n_head + head_idx;
// Compute costheta, sintheta
float theta = L * inv_freq;
@@ -123,20 +129,19 @@ __device__ void rope_impl(
size_t out_index_1, out_index_2;
if (traditional) {
out_index_1 = 2 * pos.x * out_strides[2] + pos.y * out_strides[1] +
N * pos.z * out_strides[0];
mat_idx * out_strides[0];
out_index_2 = out_index_1 + 1;
in_index_1 =
2 * pos.x * strides[2] + pos.y * strides[1] + N * pos.z * strides[0];
2 * pos.x * strides[2] + pos.y * strides[1] + mat_idx * strides[0];
in_index_2 = in_index_1 + strides[2];
} else {
out_index_1 = pos.x * out_strides[2] + pos.y * out_strides[1] +
N * pos.z * out_strides[0];
mat_idx * out_strides[0];
out_index_2 = out_index_1 + dims.x * out_strides[2];
in_index_1 =
pos.x * strides[2] + pos.y * strides[1] + N * pos.z * strides[0];
in_index_1 = pos.x * strides[2] + pos.y * strides[1] + mat_idx * strides[0];
in_index_2 = in_index_1 + dims.x * strides[2];
}
for (int i = 0; i < N && pos.z * N + i < n_batch; ++i) {
for (int i = 0; i < N && head_idx + i < n_head; ++i) {
// Read and write the output
float x1 = static_cast<float>(in[in_index_1]);
float x2 = static_cast<float>(in[in_index_2]);
@@ -167,7 +172,8 @@ __global__ void rope(
float base,
const __grid_constant__ cuda::std::array<int64_t, 3> strides,
const __grid_constant__ cuda::std::array<int64_t, 3> out_strides,
int64_t n_batch,
int64_t offset_stride,
int n_head,
uint3 dims) {
uint3 pos = make_uint3(
blockIdx.x * blockDim.x + threadIdx.x,
@@ -182,12 +188,13 @@ __global__ void rope(
rope_impl<T, traditional, forward>(
in,
out,
*offset,
offset,
inv_freq,
scale,
strides,
out_strides,
n_batch,
offset_stride,
n_head,
pos,
dims);
}
@@ -202,7 +209,8 @@ __global__ void rope_freqs(
float base,
const __grid_constant__ cuda::std::array<int64_t, 3> strides,
const __grid_constant__ cuda::std::array<int64_t, 3> out_strides,
int64_t n_batch,
int64_t offset_stride,
int n_head,
uint3 dims,
int64_t freq_stride) {
uint3 pos = make_uint3(
@@ -217,12 +225,13 @@ __global__ void rope_freqs(
rope_impl<T, traditional, forward>(
in,
out,
*offset,
offset,
inv_freq,
scale,
strides,
out_strides,
n_batch,
offset_stride,
n_head,
pos,
dims);
}
@@ -245,23 +254,28 @@ void RoPE::eval_gpu(
auto& offset = inputs[1];
auto& out = outputs[0];
if (in.ndim() < 3) {
throw std::runtime_error("[RoPE] Input must have at least 3 dimensions");
}
cuda::std::array<int64_t, 3> strides;
cuda::std::array<int64_t, 3> out_strides;
bool donated = false;
int ndim = in.ndim();
int dispatch_ndim = in.ndim();
int B = in.shape(0);
int T = in.shape(-2);
int D = in.shape(-1);
size_t mat_size = T * D;
int dispatch_ndim = ndim;
while (in.shape(-dispatch_ndim) == 1 && dispatch_ndim > 3) {
dispatch_ndim--;
}
size_t mat_size = in.shape(-2) * in.shape(-1);
int N = 1;
for (int i = 1; i < (ndim - 2); ++i) {
N *= in.shape(i);
}
// We apply rope to less that the whole vector so copy to output and then
// apply in-place.
if (dims_ < in.shape(-1)) {
if (dims_ < D) {
donated = true;
auto ctype =
(in.flags().row_contiguous) ? CopyType::Vector : CopyType::General;
@@ -302,7 +316,7 @@ void RoPE::eval_gpu(
out_strides[2] = out.strides()[ndim - 1];
// Some flags to help us dispatch below
bool single = in.flags().row_contiguous && (mat_size == in.shape(-1));
bool single = in.flags().row_contiguous && B == 1 && T == 1;
bool with_freqs = inputs.size() == 3;
auto& encoder = cu::get_command_encoder(s);
@@ -319,7 +333,7 @@ void RoPE::eval_gpu(
if (single && !with_freqs) {
auto kernel =
cu::rope_single<DataType, traditional.value, forward.value>;
uint2 dims = make_uint2(dims_ / 2, in.size() / mat_size);
uint2 dims = make_uint2(dims_ / 2, N);
auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1);
encoder.add_kernel_node(
kernel,
@@ -336,7 +350,7 @@ void RoPE::eval_gpu(
} else if (single) {
auto kernel =
cu::rope_single_freqs<DataType, traditional.value, forward.value>;
uint2 dims = make_uint2(dims_ / 2, in.size() / mat_size);
uint2 dims = make_uint2(dims_ / 2, N);
auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1);
encoder.add_kernel_node(
kernel,
@@ -354,10 +368,14 @@ void RoPE::eval_gpu(
} else if (with_freqs) {
auto kernel =
cu::rope_freqs<DataType, traditional.value, forward.value>;
uint3 dims =
make_uint3(dims_ / 2, in.shape(-2), in.size() / mat_size);
dims.z = (dims.z + 3) / 4;
int n_per_thread = 4;
uint32_t dimz = B * ((N + n_per_thread - 1) / n_per_thread);
uint3 dims = make_uint3(dims_ / 2, T, dimz);
auto [grid, block] = get_grid_and_block(dims.x, dims.y, dims.z);
int64_t offset_stride = 0;
if (inputs[1].ndim() > 0) {
offset_stride = inputs[1].strides()[0];
}
encoder.add_kernel_node(
kernel,
grid,
@@ -371,15 +389,20 @@ void RoPE::eval_gpu(
std::log2(base_),
strides,
out_strides,
in.size() / mat_size,
offset_stride,
N,
dims,
inputs[2].strides(0));
} else {
auto kernel = cu::rope<DataType, traditional.value, forward.value>;
uint3 dims =
make_uint3(dims_ / 2, in.shape(-2), in.size() / mat_size);
dims.z = (dims.z + 3) / 4;
int n_per_thread = 4;
uint32_t dimz = B * ((N + n_per_thread - 1) / n_per_thread);
uint3 dims = make_uint3(dims_ / 2, T, dimz);
auto [grid, block] = get_grid_and_block(dims.x, dims.y, dims.z);
int64_t offset_stride = 0;
if (inputs[1].ndim() > 0) {
offset_stride = inputs[1].strides()[0];
}
encoder.add_kernel_node(
kernel,
grid,
@@ -392,7 +415,8 @@ void RoPE::eval_gpu(
std::log2(base_),
strides,
out_strides,
in.size() / mat_size,
offset_stride,
N,
dims);
}
});

View File

@@ -46,6 +46,7 @@ __global__ void kernel_sdpav_1pass(
const T* K,
const T* V,
T* O,
const T* sinks,
__grid_constant__ const AttnParams params) {
constexpr int BN = 32;
constexpr int BD = 32;
@@ -65,7 +66,7 @@ __global__ void kernel_sdpav_1pass(
__shared__ U max_scores[BN];
__shared__ U sum_exp_scores[BN];
const U scale_log2 = params.scale * 1.44269504089f;
const U scale_log2 = params.scale * M_LOG2E;
auto block = cg::this_thread_block();
auto warp = cg::tiled_partition<32>(block);
@@ -110,6 +111,10 @@ __global__ void kernel_sdpav_1pass(
U max_score = -INFINITY;
U sum_exp_score = 0.f;
if (sinks && warp_idx == 0) {
max_score = M_LOG2E * static_cast<U>(sinks[head_idx]);
sum_exp_score = 1.f;
}
// For each key
for (int i = kv_seq_idx; i < params.kL; i += BN) {
@@ -137,8 +142,9 @@ __global__ void kernel_sdpav_1pass(
// Update the accumulators
U new_max = max(max_score, score);
U factor = exp2f(max_score - new_max);
U exp_score = exp2f(score - new_max);
bool is_neg_inf = new_max == -INFINITY;
U factor = is_neg_inf ? 1 : exp2f(max_score - new_max);
U exp_score = is_neg_inf ? 0 : exp2f(score - new_max);
max_score = new_max;
sum_exp_score = sum_exp_score * factor + exp_score;
@@ -193,6 +199,7 @@ __global__ void kernel_sdpav_2pass_1(
const T* Q,
const T* K,
const T* V,
const T* sinks,
float* partials,
float* sums,
float* maxs,
@@ -268,8 +275,12 @@ __global__ void kernel_sdpav_2pass_1(
o[i] = 0.f;
}
U max_score = -1e9;
U max_score = -INFINITY;
U sum_exp_score = 0.f;
if (sinks && warp_idx == 0 && block_idx == 0) {
max_score = M_LOG2E * static_cast<U>(sinks[head_idx]);
sum_exp_score = 1.f;
}
// For each key
for (int i = kv_seq_idx; i < params.kL; i += blocks * BN) {
@@ -297,8 +308,9 @@ __global__ void kernel_sdpav_2pass_1(
// Update the accumulators
U new_max = max(max_score, score);
U factor = exp2f(max_score - new_max);
U exp_score = exp2f(score - new_max);
bool is_neg_inf = new_max == -INFINITY;
U factor = is_neg_inf ? 1 : exp2f(max_score - new_max);
U exp_score = is_neg_inf ? 0 : exp2f(score - new_max);
max_score = new_max;
sum_exp_score = sum_exp_score * factor + exp_score;
@@ -463,10 +475,14 @@ void sdpa_vector_1pass_fallback(
const array& v,
const float scale,
array& o,
bool do_causal_ = false) {
bool do_causal,
const std::optional<array>& sinks) {
encoder.set_input_array(q);
encoder.set_input_array(k);
encoder.set_input_array(v);
if (sinks) {
encoder.set_input_array(*sinks);
}
encoder.set_output_array(o);
cu::AttnParams params{
@@ -489,7 +505,7 @@ void sdpa_vector_1pass_fallback(
dim3 block_dim(1024, 1, 1);
dispatch_float_types(o.dtype(), "kernel_sdpav_1pass", [&](auto type_tag) {
dispatch_bool(do_causal_, [&](auto do_causal) {
dispatch_bool(do_causal, [&](auto do_causal) {
dispatch_headdim(params.D, [&](auto headdim) {
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
@@ -504,6 +520,7 @@ void sdpa_vector_1pass_fallback(
k.data<DataType>(),
v.data<DataType>(),
o.data<DataType>(),
sinks ? (*sinks).data<DataType>() : nullptr,
params);
});
});
@@ -518,7 +535,8 @@ void sdpa_vector_2pass_fallback(
const array& v,
const float scale,
array& o,
bool do_causal_ = false) {
bool do_causal,
const std::optional<array>& sinks) {
cu::AttnParams params{
/* int B = */ q.shape(0),
/* int H = */ q.shape(1),
@@ -559,7 +577,7 @@ void sdpa_vector_2pass_fallback(
encoder.add_temporary(maxs);
dispatch_float_types(o.dtype(), "kernel_sdpav_2pass", [&](auto type_tag) {
dispatch_bool(do_causal_, [&](auto do_causal) {
dispatch_bool(do_causal, [&](auto do_causal) {
dispatch_headdim(params.D, [&](auto headdim) {
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
@@ -570,6 +588,10 @@ void sdpa_vector_2pass_fallback(
encoder.set_input_array(q);
encoder.set_input_array(k);
encoder.set_input_array(v);
if (sinks) {
encoder.set_input_array(*sinks);
}
encoder.set_output_array(intermediate);
encoder.set_output_array(sums);
encoder.set_output_array(maxs);
@@ -585,6 +607,7 @@ void sdpa_vector_2pass_fallback(
q.data<DataType>(),
k.data<DataType>(),
v.data<DataType>(),
sinks ? (*sinks).data<DataType>() : nullptr,
intermediate.data<float>(),
sums.data<float>(),
maxs.data<float>(),
@@ -627,15 +650,16 @@ void sdpa_vector_fallback(
const array& v,
const float scale,
array& o,
bool do_causal_ = false) {
bool do_causal,
const std::optional<array>& sinks) {
int kL = k.shape(2);
if (kL > 1024) {
return sdpa_vector_2pass_fallback(
s, encoder, q, k, v, scale, o, do_causal_);
s, encoder, q, k, v, scale, o, do_causal, sinks);
} else {
return sdpa_vector_1pass_fallback(
s, encoder, q, k, v, scale, o, do_causal_);
s, encoder, q, k, v, scale, o, do_causal, sinks);
}
}
@@ -691,7 +715,7 @@ void ScaledDotProductAttention::eval_gpu(
// Define some copy functions to ensure the layout of the inputs is as
// expected.
copies.reserve(3);
copies.reserve(inputs.size());
auto copy_unless = [&copies, &s](
auto predicate, const array& arr) -> const array& {
if (!predicate(arr)) {
@@ -703,6 +727,16 @@ void ScaledDotProductAttention::eval_gpu(
}
};
// Checks that the headdim dimension has stride 1.
auto is_matrix_contiguous = [](const array& arr) {
return arr.strides(-1) == 1;
};
std::optional<array> sinks = std::nullopt;
if (has_sinks_) {
sinks = copy_unless(is_matrix_contiguous, inputs.back());
}
// We are in vector mode ie single query
if (q_pre.shape(2) < 4) {
auto q_copy_unless = [](const array& arr) {
@@ -740,10 +774,6 @@ void ScaledDotProductAttention::eval_gpu(
const auto& k = copy_unless(kv_copy_unless, k_pre);
const auto& v = copy_unless(kv_copy_unless, v_pre);
for (const auto& cp : copies) {
encoder.add_temporary(cp);
}
// Donate the query if possible
if (q.is_donatable() && q.flags().row_contiguous && q.size() == o.size()) {
o.copy_shared_buffer(q);
@@ -752,22 +782,26 @@ void ScaledDotProductAttention::eval_gpu(
int64_t str_oH = o.shape(3);
int64_t str_oL = o.shape(1) * str_oH;
int64_t str_oB = o.shape(2) * str_oL;
size_t data_size = o.shape(0) * str_oB;
array::Flags flags{
/* bool contiguous = */ 1,
/* bool row_contiguous = */ o.shape(2) == 1,
/* bool col_contiguous = */ 0,
/* bool col_contiguous = */ o.size() == o.shape(3),
};
o.set_data(
allocator::malloc(o.nbytes()),
data_size,
o.size(),
{str_oB, str_oH, str_oL, str_oD},
flags);
}
return sdpa_vector_fallback(s, encoder, q, k, v, scale_, o, do_causal_);
for (const auto& cp : copies) {
encoder.add_temporary(cp);
}
return sdpa_vector_fallback(
s, encoder, q, k, v, scale_, o, do_causal_, sinks);
}
// Full attention mode should never reach here

View File

@@ -2,7 +2,6 @@
#include <algorithm>
#include <cassert>
#include <numeric>
#include <sstream>
#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/metal/device.h"
@@ -39,10 +38,11 @@ void explicit_gemm_conv_ND_gpu(
in_unfolded.set_data(allocator::malloc(in_unfolded.nbytes()));
// Prepare unfolding kernel
std::ostringstream kname;
kname << "naive_unfold_nd_" << type_to_name(in_unfolded) << "_" << N;
std::string kname;
kname.reserve(32);
concatenate(kname, "naive_unfold_nd_", type_to_name(in_unfolded), "_", N);
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
auto kernel = d.get_kernel(kname);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(in, 0);
@@ -117,11 +117,12 @@ void explicit_gemm_conv_group_ND_gpu(
in_unfolded.set_data(allocator::malloc(in_unfolded.nbytes()));
// Prepare unfolding kernel
std::ostringstream kname;
kname << "naive_unfold_transpose_nd_" << type_to_name(in_unfolded) << "_"
<< N;
std::string kname;
kname.reserve(32);
concatenate(
kname, "naive_unfold_transpose_nd_", type_to_name(in_unfolded), "_", N);
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
auto kernel = d.get_kernel(kname);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(in, 0);
@@ -252,18 +253,32 @@ void implicit_gemm_conv_2D_gpu(
/* const int swizzle_log = */ swizzle_log};
// Determine kernel
std::ostringstream kname;
kname << "implicit_gemm_conv_2d_" << type_to_name(out) << "_bm" << bm << "_bn"
<< bn << "_bk" << bk << "_wm" << wm << "_wn" << wn << "_channel_"
<< (n_channel_specialization ? std::to_string(n_channel_specialization)
: "l")
<< "_filter_" << (small_filter ? 's' : 'l');
std::string kname;
kname.reserve(64);
concatenate(
kname,
"implicit_gemm_conv_2d_",
type_to_name(out),
"_bm",
bm,
"_bn",
bn,
"_bk",
bk,
"_wm",
wm,
"_wn",
wn,
"_channel_",
n_channel_specialization ? std::to_string(n_channel_specialization) : "l",
"_filter_",
small_filter ? 's' : 'l');
// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = get_steel_conv_kernel(
d,
kname.str(),
kname,
out,
bm,
bn,
@@ -559,11 +574,16 @@ void winograd_conv_2D_gpu(
{
int bc = 32;
int bo = 4;
std::ostringstream kname;
kname << "winograd_conv_2d_weight_transform_" << type_to_name(out) << "_bc"
<< bc;
std::string kname;
kname.reserve(32);
concatenate(
kname,
"winograd_conv_2d_weight_transform_",
type_to_name(out),
"_bc",
bc);
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
auto kernel = d.get_kernel(kname);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(wt, 0);
@@ -587,11 +607,16 @@ void winograd_conv_2D_gpu(
int bc = 32;
int wm = 2;
int wn = 2;
std::ostringstream kname;
kname << "winograd_conv_2d_input_transform_" << type_to_name(out) << "_bc"
<< bc;
std::string kname;
kname.reserve(32);
concatenate(
kname,
"winograd_conv_2d_input_transform_",
type_to_name(out),
"_bc",
bc);
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
auto kernel = d.get_kernel(kname);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(in_padded, 0);
@@ -634,11 +659,16 @@ void winograd_conv_2D_gpu(
int bc = 32;
int wm = 2;
int wn = 2;
std::ostringstream kname;
kname << "winograd_conv_2d_output_transform_" << type_to_name(out) << "_bo"
<< bc;
std::string kname;
kname.reserve(32);
concatenate(
kname,
"winograd_conv_2d_output_transform_",
type_to_name(out),
"_bo",
bc);
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
auto kernel = d.get_kernel(kname);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(out_wg, 0);
@@ -660,9 +690,9 @@ void depthwise_conv_2D_gpu(
const array& wt,
array out,
const MLXConvParams<2>& conv_params) {
std::ostringstream kname;
kname << "depthwise_conv_2d_" << type_to_name(out);
std::string base_name = kname.str();
std::string base_name;
base_name.reserve(32);
concatenate(base_name, "depthwise_conv_2d_", type_to_name(out));
const int N = conv_params.N;
const int ker_h = conv_params.wS[0];
@@ -685,15 +715,18 @@ void depthwise_conv_2D_gpu(
};
// clang-format off
kname << "_ker_h_" << ker_h
<< "_ker_w_" << ker_w
<< "_str_h_" << str_h
<< "_str_w_" << str_w
<< "_tgp_h_" << th
<< "_tgp_w_" << tw
<< "_do_flip_" << (do_flip ? 't' : 'n'); // clang-format on
std::string hash_name = kname.str();
std::string hash_name;
hash_name.reserve(64);
concatenate(
hash_name,
base_name,
"_ker_h_", ker_h,
"_ker_w_", ker_w,
"_str_h_", str_h,
"_str_w_", str_w,
"_tgp_h_", th,
"_tgp_w_", tw,
"_do_flip_", do_flip ? 't' : 'n'); // clang-format on
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(base_name, hash_name, func_consts);
@@ -774,6 +807,56 @@ void dispatch_conv_2D_gpu(
}
}
void depthwise_conv_1D_gpu(
const Stream& s,
metal::Device& d,
const array& in,
array wt,
array out) {
bool large = in.size() > INT32_MAX || in.data_size() > INT32_MAX;
std::string base_name;
base_name.reserve(32);
concatenate(
base_name,
"depthwise_conv_1d_",
large ? "_large" : "",
type_to_name(out));
if (!wt.flags().row_contiguous) {
wt = contiguous_copy_gpu(wt, s);
d.add_temporary(wt, s.index);
}
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(base_name);
compute_encoder.set_compute_pipeline_state(kernel);
auto B = in.shape(0);
auto Tout = out.shape(1);
auto D = in.shape(2);
auto K = wt.shape(1);
compute_encoder.set_input_array(in, 0);
compute_encoder.set_input_array(wt, 1);
compute_encoder.set_output_array(out, 2);
if (large) {
int64_t strides[3] = {in.strides(0), in.strides(1), in.strides(2)};
compute_encoder.set_bytes(strides, 3, 3);
} else {
int strides[3] = {
static_cast<int>(in.strides(0)),
static_cast<int>(in.strides(1)),
static_cast<int>(in.strides(2))};
compute_encoder.set_bytes(strides, 3, 3);
}
compute_encoder.set_bytes(K, 4);
auto group_dims = get_block_dims(D, Tout, B);
MTL::Size grid_dims = MTL::Size(D, Tout, B);
compute_encoder.dispatch_threads(grid_dims, group_dims);
}
void conv_1D_gpu(
const Stream& s,
metal::Device& d,
@@ -790,8 +873,15 @@ void conv_1D_gpu(
bool is_idil_one = in_dilation[0] == 1;
int C = in.shape(2);
int O = wt.shape(0);
const int C_per_group = in.shape(2) / groups;
const int O_per_group = wt.shape(0) / groups;
// Fast path for fully separable 1D convolution
if (is_idil_one && (groups == C) && groups == O && wt_strides[0] == 1 &&
wt_dilation[0] == 1 && padding[0] == 0 && !flip) {
depthwise_conv_1D_gpu(s, d, in, wt, out);
return;
}
const int C_per_group = C / groups;
const int O_per_group = O / groups;
// Direct to implicit gemm conv
if (is_idil_one && (C_per_group <= 4 || C_per_group % 16 == 0) &&

View File

@@ -288,6 +288,40 @@ instantiate_depthconv2d(float32, float);
instantiate_depthconv2d(float16, half);
instantiate_depthconv2d(bfloat16, bfloat16_t);
template <typename T, typename IdxT>
[[kernel]] void depthwise_conv_1d(
const device T* in [[buffer(0)]],
const device T* w [[buffer(1)]],
device T* out [[buffer(2)]],
constant const IdxT strides[3],
constant const int& kernel_size,
uint3 tid [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
out += (tid.z * static_cast<IdxT>(grid_dim.y) + tid.y) * grid_dim.x + tid.x;
in += tid.z * strides[0] + tid.y * strides[1] + tid.x * strides[2];
w += tid.x * kernel_size;
float acc = 0.0;
for (int i = 0; i < kernel_size; ++i) {
acc += static_cast<float>(in[0]) * w[i];
in += strides[1];
}
*out = static_cast<T>(acc);
}
#define instantiate_depthconv1d(iname, itype) \
instantiate_kernel( \
"depthwise_conv_1d_" #iname, depthwise_conv_1d, itype, int32_t) \
instantiate_kernel( \
"depthwise_conv_1d_" #iname "_large", \
depthwise_conv_1d, \
itype, \
int64_t)
instantiate_depthconv1d(float32, float);
instantiate_depthconv1d(float16, half);
instantiate_depthconv1d(bfloat16, bfloat16_t);
///////////////////////////////////////////////////////////////////////////////
/// Winograd kernels
///////////////////////////////////////////////////////////////////////////////

View File

@@ -10,7 +10,7 @@ void rope_single_impl(
constant const int& offset,
const float inv_freq,
constant const float& scale,
constant const size_t& stride,
constant const int64_t& stride,
uint2 pos,
uint2 grid) {
float L = scale * static_cast<float>(offset);
@@ -52,7 +52,7 @@ template <typename T, bool traditional, bool forward>
device T* out [[buffer(1)]],
constant const int& offset,
constant const float& scale,
constant const size_t& stride,
constant const int64_t& stride,
constant const float& base [[buffer(10)]],
uint2 pos [[thread_position_in_grid]],
uint2 grid [[threads_per_grid]]) {
@@ -68,9 +68,9 @@ template <typename T, bool traditional, bool forward>
device T* out [[buffer(1)]],
constant const int& offset,
constant const float& scale,
constant const size_t& stride,
constant const int64_t& stride,
const device float* freqs [[buffer(10)]],
constant const size_t& freq_stride [[buffer(11)]],
constant const int64_t& freq_stride [[buffer(11)]],
uint2 pos [[thread_position_in_grid]],
uint2 grid [[threads_per_grid]]) {
float inv_freq = 1.0 / (freqs[freq_stride * pos.x]);
@@ -82,15 +82,21 @@ template <typename T, bool traditional, bool forward, int N = 4>
void rope_impl(
const device T* in,
device T* out,
constant const int& offset,
const device int* offset,
const float inv_freq,
constant const float& scale,
constant const size_t strides[3],
constant const size_t out_strides[3],
constant const size_t& n_batch,
constant const int64_t strides[3],
constant const int64_t out_strides[3],
constant const int64_t& offset_stride,
constant const int& n_head,
uint3 pos,
uint3 grid) {
float L = scale * static_cast<float>(pos.y + offset);
auto n_head_up = N * ((n_head + N - 1) / N);
auto head_idx = static_cast<int>((pos.z * N) % n_head_up);
auto batch_idx = (pos.z * N) / n_head_up;
auto batch_offset = offset[batch_idx * offset_stride];
float L = scale * static_cast<float>(pos.y + batch_offset);
auto mat_idx = batch_idx * n_head + head_idx;
// Compute costheta, sintheta
float theta = L * inv_freq;
@@ -102,20 +108,19 @@ void rope_impl(
size_t out_index_1, out_index_2;
if (traditional) {
out_index_1 = 2 * pos.x * out_strides[2] + pos.y * out_strides[1] +
N * pos.z * out_strides[0];
mat_idx * out_strides[0];
out_index_2 = out_index_1 + 1;
in_index_1 =
2 * pos.x * strides[2] + pos.y * strides[1] + N * pos.z * strides[0];
2 * pos.x * strides[2] + pos.y * strides[1] + mat_idx * strides[0];
in_index_2 = in_index_1 + strides[2];
} else {
out_index_1 = pos.x * out_strides[2] + pos.y * out_strides[1] +
N * pos.z * out_strides[0];
mat_idx * out_strides[0];
out_index_2 = out_index_1 + grid.x * out_strides[2];
in_index_1 =
pos.x * strides[2] + pos.y * strides[1] + N * pos.z * strides[0];
in_index_1 = pos.x * strides[2] + pos.y * strides[1] + mat_idx * strides[0];
in_index_2 = in_index_1 + grid.x * strides[2];
}
for (int i = 0; i < N && pos.z * N + i < n_batch; ++i) {
for (int i = 0; i < N && head_idx + i < n_head; ++i) {
// Read and write the output
float x1 = static_cast<float>(in[in_index_1]);
float x2 = static_cast<float>(in[in_index_2]);
@@ -141,11 +146,12 @@ template <typename T, bool traditional, bool forward, int N = 4>
[[kernel]] void rope(
const device T* in [[buffer(0)]],
device T* out [[buffer(1)]],
constant const int& offset,
const device int* offset,
constant const float& scale,
constant const size_t strides[3],
constant const size_t out_strides[3],
constant const size_t& n_batch,
constant const int64_t strides[3],
constant const int64_t out_strides[3],
constant const int64_t& offset_stride,
constant const int& n_head,
constant const float& base [[buffer(10)]],
uint3 pos [[thread_position_in_grid]],
uint3 grid [[threads_per_grid]]) {
@@ -159,7 +165,8 @@ template <typename T, bool traditional, bool forward, int N = 4>
scale,
strides,
out_strides,
n_batch,
offset_stride,
n_head,
pos,
grid);
}
@@ -168,13 +175,14 @@ template <typename T, bool traditional, bool forward, int N = 4>
[[kernel]] void rope_freqs(
const device T* in [[buffer(0)]],
device T* out [[buffer(1)]],
constant const int& offset,
const device int* offset,
constant const float& scale,
constant const size_t strides[3],
constant const size_t out_strides[3],
constant const size_t& n_batch,
constant const int64_t strides[3],
constant const int64_t out_strides[3],
constant const int64_t& offset_stride,
constant const int& n_head,
const device float* freqs [[buffer(10)]],
constant const size_t& freq_stride [[buffer(11)]],
constant const int64_t& freq_stride [[buffer(11)]],
uint3 pos [[thread_position_in_grid]],
uint3 grid [[threads_per_grid]]) {
float inv_freq = 1.0 / (freqs[freq_stride * pos.x]);
@@ -186,61 +194,20 @@ template <typename T, bool traditional, bool forward, int N = 4>
scale,
strides,
out_strides,
n_batch,
offset_stride,
n_head,
pos,
grid);
}
// clang-format off
#define instantiate_rope_g(name, type, traditional, forward) \
template [[host_name("rope_" #name)]] [[kernel]] void \
rope<type, traditional, forward>( \
const device type* in [[buffer(0)]], \
device type* out [[buffer(1)]], \
constant const int& offset, \
constant const float& scale, \
constant const size_t strides[3], \
constant const size_t out_strides[3], \
constant const size_t& n_batch, \
constant const float& base [[buffer(10)]], \
uint3 pos [[thread_position_in_grid]], \
uint3 grid [[threads_per_grid]]); \
template [[host_name("rope_freqs_" #name)]] \
[[kernel]] void rope_freqs<type, traditional, forward>( \
const device type* in [[buffer(0)]], \
device type* out [[buffer(1)]], \
constant const int& offset, \
constant const float& scale, \
constant const size_t strides[3], \
constant const size_t out_strides[3], \
constant const size_t& n_batch, \
const device float* freqs [[buffer(10)]], \
constant const size_t& freq_stride [[buffer(11)]], \
uint3 pos [[thread_position_in_grid]], \
uint3 grid [[threads_per_grid]]);
instantiate_kernel("rope_" #name, rope, type, traditional, forward) \
instantiate_kernel("rope_freqs_" #name, rope_freqs, type, traditional, forward)
#define instantiate_rope_s(name, type, traditional, forward) \
template [[host_name("rope_single_" #name)]] [[kernel]] void \
rope_single<type, traditional, forward>( \
const device type* in [[buffer(0)]], \
device type* out [[buffer(1)]], \
constant const int& offset, \
constant const float& scale, \
constant const size_t& stride, \
constant const float& base [[buffer(10)]], \
uint2 pos [[thread_position_in_grid]], \
uint2 grid [[threads_per_grid]]); \
template [[host_name("rope_single_freqs_" #name)]] \
[[kernel]] void rope_single_freqs<type, traditional, forward>( \
const device type* in [[buffer(0)]], \
device type* out [[buffer(1)]], \
constant const int& offset, \
constant const float& scale, \
constant const size_t& stride, \
const device float* freqs [[buffer(10)]], \
constant const size_t& freq_stride [[buffer(11)]], \
uint2 pos [[thread_position_in_grid]], \
uint2 grid [[threads_per_grid]]);
#define instantiate_rope_s(name, type, traditional, forward) \
instantiate_kernel("rope_single_" #name, rope_single, type, traditional, forward) \
instantiate_kernel("rope_single_freqs_" #name, rope_single_freqs, type, traditional, forward)
#define instantiate_rope(name, type, traditional, forward) \
instantiate_rope_s(name, type, traditional, forward) \

View File

@@ -9,6 +9,7 @@ constant bool query_transposed [[function_constant(21)]];
constant bool do_causal [[function_constant(22)]];
constant bool bool_mask [[function_constant(23)]];
constant bool float_mask [[function_constant(24)]];
constant bool has_sinks [[function_constant(25)]];
template <typename T, int D, int V = D>
[[kernel]] void sdpa_vector(
@@ -31,6 +32,9 @@ template <typename T, int D, int V = D>
[[buffer(14), function_constant(has_mask)]],
const constant int& mask_head_stride
[[buffer(15), function_constant(has_mask)]],
const device T* sinks [[buffer(16), function_constant(has_sinks)]],
const constant int& num_q_heads
[[buffer(17), function_constant(has_sinks)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 tpg [[threadgroups_per_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
@@ -53,24 +57,24 @@ template <typename T, int D, int V = D>
threadgroup U sum_exp_scores[BN];
// Adjust positions
const int head_idx = tid.x;
const int q_batch_head_idx = tid.x;
const int q_seq_idx = tid.y;
const int kv_head_idx = head_idx / gqa_factor;
const int o_offset = head_idx * tpg.y + q_seq_idx;
const int kv_head_idx = q_batch_head_idx / gqa_factor;
const int o_offset = q_batch_head_idx * tpg.y + q_seq_idx;
const int q_offset =
query_transposed ? tpg.x * q_seq_idx + head_idx : o_offset;
query_transposed ? tpg.x * q_seq_idx + q_batch_head_idx : o_offset;
queries += q_offset * D + simd_lid * qk_per_thread;
keys += kv_head_idx * k_head_stride + simd_gid * k_seq_stride +
simd_lid * qk_per_thread;
values += kv_head_idx * v_head_stride + simd_gid * v_seq_stride +
simd_lid * v_per_thread;
if (bool_mask) {
bmask += head_idx * mask_head_stride + simd_gid * mask_kv_seq_stride +
q_seq_idx * mask_q_seq_stride;
bmask += q_batch_head_idx * mask_head_stride +
simd_gid * mask_kv_seq_stride + q_seq_idx * mask_q_seq_stride;
}
if (float_mask) {
fmask += head_idx * mask_head_stride + simd_gid * mask_kv_seq_stride +
q_seq_idx * mask_q_seq_stride;
fmask += q_batch_head_idx * mask_head_stride +
simd_gid * mask_kv_seq_stride + q_seq_idx * mask_q_seq_stride;
}
out += o_offset * V + simd_gid * v_per_thread;
@@ -85,6 +89,10 @@ template <typename T, int D, int V = D>
U max_score = -INFINITY;
U sum_exp_score = 0;
if (has_sinks && simd_gid == 0) {
max_score = static_cast<U>(sinks[q_batch_head_idx % num_q_heads]);
sum_exp_score = 1;
}
// For each key
for (int i = simd_gid; i < N; i += BN) {
@@ -93,6 +101,8 @@ template <typename T, int D, int V = D>
use_key = i <= (N - int(tpg.y) + int(q_seq_idx));
} else if (bool_mask) {
use_key = bmask[0];
} else if (float_mask) {
use_key = (fmask[0] >= Limits<T>::finite_min);
}
if (use_key) {
// Read the key
@@ -107,13 +117,14 @@ template <typename T, int D, int V = D>
}
score = simd_sum(score);
if (float_mask) {
score += max(Limits<U>::finite_min, static_cast<U>(fmask[0]));
score += static_cast<U>(fmask[0]);
}
// Update the accumulators
U new_max = max(max_score, score);
U factor = fast::exp(max_score - new_max);
U exp_score = fast::exp(score - new_max);
bool is_neg_inf = new_max == -INFINITY;
U factor = is_neg_inf ? 1.0 : fast::exp(max_score - new_max);
U exp_score = is_neg_inf ? 0.0 : fast::exp(score - new_max);
max_score = new_max;
sum_exp_score = sum_exp_score * factor + exp_score;
@@ -187,6 +198,9 @@ template <typename T, int D, int V = D>
[[buffer(16), function_constant(has_mask)]],
const constant int& mask_head_stride
[[buffer(17), function_constant(has_mask)]],
const device T* sinks [[buffer(18), function_constant(has_sinks)]],
const constant int& num_q_heads
[[buffer(19), function_constant(has_sinks)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 tpg [[threadgroups_per_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
@@ -211,12 +225,12 @@ template <typename T, int D, int V = D>
// Adjust positions
const int block_idx = tid.z;
const int head_idx = tid.x;
const int q_batch_head_idx = tid.x;
const int q_seq_idx = tid.y;
const int o_offset = head_idx * tpg.y + q_seq_idx;
const int o_offset = q_batch_head_idx * tpg.y + q_seq_idx;
const int q_offset =
query_transposed ? tpg.x * q_seq_idx + head_idx : o_offset;
const int kv_head_idx = head_idx / gqa_factor;
query_transposed ? tpg.x * q_seq_idx + q_batch_head_idx : o_offset;
const int kv_head_idx = q_batch_head_idx / gqa_factor;
queries += q_offset * D + simd_lid * qk_per_thread;
keys += kv_head_idx * k_head_stride +
@@ -225,12 +239,12 @@ template <typename T, int D, int V = D>
(block_idx * BN + simd_gid) * v_seq_stride + simd_lid * v_per_thread;
out += o_offset * blocks * V + block_idx * V + simd_lid * v_per_thread;
if (bool_mask) {
bmask += head_idx * mask_head_stride +
bmask += q_batch_head_idx * mask_head_stride +
(block_idx * BN + simd_gid) * mask_kv_seq_stride +
q_seq_idx * mask_q_seq_stride;
}
if (float_mask) {
fmask += head_idx * mask_head_stride +
fmask += q_batch_head_idx * mask_head_stride +
(block_idx * BN + simd_gid) * mask_kv_seq_stride +
q_seq_idx * mask_q_seq_stride;
}
@@ -245,8 +259,13 @@ template <typename T, int D, int V = D>
o[i] = 0;
}
U max_score = -1e9;
U max_score = -INFINITY;
U sum_exp_score = 0;
if (has_sinks && block_idx == 0 && simd_gid == 0) {
int q_head_idx = q_batch_head_idx % num_q_heads;
max_score = static_cast<U>(sinks[q_head_idx]);
sum_exp_score = 1;
}
// For each key
for (int i = block_idx * BN + simd_gid; i < N; i += blocks * BN) {
@@ -255,6 +274,8 @@ template <typename T, int D, int V = D>
use_key = i <= (N - int(tpg.y) + int(q_seq_idx));
} else if (bool_mask) {
use_key = bmask[0];
} else if (float_mask) {
use_key = (fmask[0] >= Limits<T>::finite_min);
}
if (use_key) {
// Read the key
@@ -268,6 +289,10 @@ template <typename T, int D, int V = D>
score += q[i] * k[i];
}
score = simd_sum(score);
if (score < Limits<T>::finite_min) {
continue;
}
if (float_mask) {
score += fmask[0];
}

View File

@@ -11,6 +11,7 @@ constant bool align_K [[function_constant(201)]];
constant bool has_mask [[function_constant(300)]];
constant bool do_causal [[function_constant(301)]];
constant bool has_sinks [[function_constant(302)]];
template <typename T>
struct TransformScale {
@@ -82,6 +83,7 @@ template <
const constant AttnParams* params [[buffer(4)]],
const constant AttnMaskParams* mask_params [[buffer(5), function_constant(has_mask)]],
const device MaskType* mask [[buffer(6), function_constant(has_mask)]],
const device T* sinks [[buffer(7), function_constant(has_sinks)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
@@ -169,7 +171,7 @@ template <
VBlockLoader loader_v(
V, params->V_strides[2], Vs, simd_group_id, simd_lane_id);
TransformScale<T> ts(static_cast<T>(params->scale * 1.44269504089));
TransformScale<T> ts(static_cast<T>(params->scale * M_LOG2E_F));
// Prepare MMA tiles
constexpr short kFragSize = 8; // MMAFrag size
@@ -232,6 +234,14 @@ template <
max_score[i] = Limits<AccumType>::finite_min;
}
if (has_sinks) {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kRowsPT; ++i) {
max_score[i] = M_LOG2E_F * static_cast<AccumType>(sinks[tidl.y]);
sum_score[i] = 1;
}
}
int kb_lim = params->NK;
if (do_causal) {
@@ -350,7 +360,7 @@ template <
Stile.frag_at(i, j)[jj] =
mfrag[jj] ? Stile.frag_at(i, j)[jj] : neg_inf;
} else {
Stile.frag_at(i, j)[jj] += 1.44269504089 * selem_t(mfrag[jj]);
Stile.frag_at(i, j)[jj] += M_LOG2E_F * selem_t(mfrag[jj]);
}
}
}

View File

@@ -18,23 +18,29 @@ void RoPE::eval_gpu(
auto& in = inputs[0];
auto& out = outputs[0];
if (in.ndim() < 3) {
throw std::runtime_error("[RoPE] Input must have at least 3 dimensions");
}
auto& s = out.primitive().stream();
auto& d = metal::device(s.device);
size_t strides[3];
size_t out_strides[3];
int64_t strides[3];
int64_t out_strides[3];
bool donated = false;
int ndim = in.ndim();
int dispatch_ndim = in.ndim();
int B = in.shape(0);
int T = in.shape(-2);
int D = in.shape(-1);
size_t mat_size = T * D;
int dispatch_ndim = ndim;
while (in.shape(-dispatch_ndim) == 1 && dispatch_ndim > 3) {
dispatch_ndim--;
}
size_t mat_size = in.shape(-2) * in.shape(-1);
if (dims_ < in.shape(-1)) {
int N = 1;
for (int i = 1; i < (ndim - 2); ++i) {
N *= in.shape(i);
}
if (dims_ < D) {
donated = true;
auto ctype =
(in.flags().row_contiguous) ? CopyType::Vector : CopyType::General;
@@ -71,8 +77,8 @@ void RoPE::eval_gpu(
out_strides[1] = out.strides()[ndim - 2];
out_strides[2] = out.strides()[ndim - 1];
// Special case for inference (single time step and contiguous)
bool single = in.flags().row_contiguous && (mat_size == in.shape(-1));
// Special case for inference (single batch, single time step, and contiguous)
bool single = in.flags().row_contiguous && B == 1 && T == 1;
bool with_freqs = inputs.size() == 3;
std::ostringstream kname;
@@ -86,24 +92,29 @@ void RoPE::eval_gpu(
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(donated ? out : in, 0);
compute_encoder.set_output_array(out, 1);
compute_encoder.set_input_array(inputs[1], 2);
compute_encoder.set_bytes(scale_, 3);
size_t n_batch = in.size() / mat_size;
MTL::Size group_dims;
MTL::Size grid_dims;
if (single) {
compute_encoder.set_bytes(out_strides, 1, 4);
uint32_t dim0 = dims_ / 2;
group_dims = get_block_dims(dim0, n_batch, 1);
grid_dims = MTL::Size(dim0, n_batch, 1);
group_dims = get_block_dims(dim0, N, 1);
grid_dims = MTL::Size(dim0, N, 1);
} else {
compute_encoder.set_bytes(strides, 3, 4);
compute_encoder.set_bytes(out_strides, 3, 5);
compute_encoder.set_bytes(n_batch, 6);
int64_t offset_stride = 0;
if (inputs[1].ndim() > 0) {
offset_stride = inputs[1].strides()[0];
}
compute_encoder.set_bytes(offset_stride, 6);
compute_encoder.set_bytes(N, 7);
uint32_t dim0 = dims_ / 2;
uint32_t dim1 = in.shape(-2);
uint32_t dim2 = (n_batch + n_per_thread - 1) / n_per_thread;
uint32_t dim1 = T;
uint32_t dim2 = B * ((N + n_per_thread - 1) / n_per_thread);
group_dims = get_block_dims(dim0, dim1, dim2);
grid_dims = MTL::Size(dim0, dim1, dim2);
}

View File

@@ -21,8 +21,9 @@ void sdpa_full_self_attention_metal(
const array& v,
const float scale,
array& o,
bool do_causal_ = false,
const std::optional<array>& mask = std::nullopt) {
bool do_causal_,
const std::optional<array>& mask,
const std::optional<array>& sinks) {
using namespace mlx::steel;
int wm = 4;
@@ -42,35 +43,49 @@ void sdpa_full_self_attention_metal(
const bool align_Q = (qL % bq) == 0;
const bool align_K = (kL % bk) == 0;
const bool has_mask = !!mask;
const bool has_mask = mask.has_value();
const bool do_causal = do_causal_;
const bool has_sinks = sinks.has_value();
metal::MTLFCList func_consts = {
{&align_Q, MTL::DataType::DataTypeBool, 200},
{&align_K, MTL::DataType::DataTypeBool, 201},
{&has_mask, MTL::DataType::DataTypeBool, 300},
{&do_causal, MTL::DataType::DataTypeBool, 301}};
{&do_causal, MTL::DataType::DataTypeBool, 301},
{&has_sinks, MTL::DataType::DataTypeBool, 302}};
std::ostringstream kname;
// clang-format off
kname << "steel_attention_"
<< type_to_name(q)
<< "_bq" << bq
<< "_bk" << bk
<< "_bd" << bd
<< "_wm" << wm
<< "_wn" << wn
<< "_mask" << (type_to_name(has_mask ? *mask : q)); // clang-format on
std::string base_name;
concatenate(
base_name,
"steel_attention_",
type_to_name(q),
"_bq",
bq,
"_bk",
bk,
"_bd",
bd,
"_wm",
wm,
"_wn",
wn,
"_mask",
type_to_name(has_mask ? *mask : q));
std::string base_name = kname.str();
// clang-format off
kname << "_align_Q_" << (align_Q ? 't' : 'n')
<< "_align_K_" << (align_K ? 't' : 'n')
<< "_has_mask_" << (has_mask ? 't' : 'n')
<< "_do_causal_" << (do_causal ? 't' : 'n'); // clang-format on
std::string hash_name = kname.str();
std::string hash_name;
concatenate(
hash_name,
base_name,
"_align_Q_",
(align_Q ? 't' : 'n'),
"_align_K_",
(align_K ? 't' : 'n'),
"_has_mask_",
(has_mask ? 't' : 'n'),
"_do_causal_",
(do_causal ? 't' : 'n'),
"_has_sinks_",
(has_sinks ? 't' : 'n'));
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(base_name, hash_name, func_consts);
@@ -114,8 +129,8 @@ void sdpa_full_self_attention_metal(
compute_encoder.set_output_array(o, 3);
compute_encoder.set_bytes(params, 4);
if (mask) {
auto m = *mask;
if (has_mask) {
auto& m = *mask;
AttnMaskParams mask_params{/* int64_t M_strides[3] = */ {
m.strides(0), m.strides(1), m.strides(2)}};
@@ -123,6 +138,9 @@ void sdpa_full_self_attention_metal(
compute_encoder.set_bytes(mask_params, 5);
compute_encoder.set_input_array(m, 6);
}
if (has_sinks) {
compute_encoder.set_input_array(*sinks, 7);
}
MTL::Size grid_dims = MTL::Size(NQ, H, B);
MTL::Size group_dims = MTL::Size(32, wm, wn);
@@ -139,7 +157,8 @@ void sdpa_vector(
array& out,
float scale,
bool do_causal,
const std::optional<array>& mask) {
const std::optional<array>& mask,
const std::optional<array>& sinks) {
// Set the kernel name
std::string kname;
kname.reserve(64);
@@ -153,30 +172,32 @@ void sdpa_vector(
// Compute the necessary sizes
int gqa_factor = q.shape(1) / k.shape(1);
int N = k.shape(2);
int B = q.shape(0) * q.shape(1);
size_t k_head_stride = k.shape(1) == 1 ? k.strides(0) : k.strides(1);
size_t k_seq_stride = k.strides()[2];
size_t v_head_stride = v.shape(1) == 1 ? v.strides(0) : v.strides(1);
size_t v_seq_stride = v.strides()[2];
MTL::Size group_dims(1024, 1, 1);
MTL::Size grid_dims(B, q.shape(2), 1);
MTL::Size grid_dims(q.shape(0) * q.shape(1), q.shape(2), 1);
bool has_mask = mask.has_value();
bool bool_mask = has_mask && (*mask).dtype() == bool_;
bool float_mask = has_mask && !bool_mask;
bool query_transposed = !q.flags().row_contiguous;
bool has_sinks = sinks.has_value();
metal::MTLFCList func_consts = {
{&has_mask, MTL::DataType::DataTypeBool, 20},
{&query_transposed, MTL::DataType::DataTypeBool, 21},
{&do_causal, MTL::DataType::DataTypeBool, 22},
{&bool_mask, MTL::DataType::DataTypeBool, 23},
{&float_mask, MTL::DataType::DataTypeBool, 24},
{&has_sinks, MTL::DataType::DataTypeBool, 25},
};
std::string hash_name = kname;
hash_name += has_mask ? (bool_mask ? "_boolmask" : "_floatmask") : "_nomask";
hash_name += query_transposed ? "_qt" : "_qnt";
hash_name += do_causal ? "_c" : "_nc";
hash_name += has_sinks ? "_sinks" : "_nosinks";
// Get the kernel
auto& compute_encoder = d.get_command_encoder(s.index);
@@ -207,6 +228,10 @@ void sdpa_vector(
compute_encoder.set_bytes(q_seq_stride, 14);
compute_encoder.set_bytes(head_stride, 15);
}
if (has_sinks) {
compute_encoder.set_input_array(*sinks, 16);
compute_encoder.set_bytes(q.shape(1), 17);
}
// Launch
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
@@ -221,7 +246,8 @@ void sdpa_vector_2pass(
array& out,
float scale,
bool do_causal,
const std::optional<array>& mask) {
const std::optional<array>& mask,
const std::optional<array>& sinks) {
// Set the kernel name
std::string kname;
kname.reserve(64);
@@ -267,17 +293,20 @@ void sdpa_vector_2pass(
bool bool_mask = has_mask && (*mask).dtype() == bool_;
bool float_mask = has_mask && !bool_mask;
bool query_transposed = !q.flags().row_contiguous;
bool has_sinks = sinks.has_value();
metal::MTLFCList func_consts = {
{&has_mask, MTL::DataType::DataTypeBool, 20},
{&query_transposed, MTL::DataType::DataTypeBool, 21},
{&do_causal, MTL::DataType::DataTypeBool, 22},
{&bool_mask, MTL::DataType::DataTypeBool, 23},
{&float_mask, MTL::DataType::DataTypeBool, 24},
{&has_sinks, MTL::DataType::DataTypeBool, 25},
};
std::string hash_name = kname;
hash_name += has_mask ? (bool_mask ? "_boolmask" : "_floatmask") : "_nomask";
hash_name += query_transposed ? "_qt" : "_qnt";
hash_name += do_causal ? "_c" : "_nc";
hash_name += has_sinks ? "_sinks" : "_nosinks";
// Get the kernel
auto& compute_encoder = d.get_command_encoder(s.index);
@@ -310,6 +339,10 @@ void sdpa_vector_2pass(
compute_encoder.set_bytes(q_seq_stride, 16);
compute_encoder.set_bytes(head_stride, 17);
}
if (has_sinks) {
compute_encoder.set_input_array(*sinks, 18);
compute_encoder.set_bytes(q.shape(1), 19);
}
// Launch
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
@@ -411,6 +444,12 @@ void ScaledDotProductAttention::eval_gpu(
return arr.strides(-1) == 1;
};
std::optional<array> sinks = std::nullopt;
if (has_sinks_) {
sinks = copy_unless(is_matrix_contiguous, inputs.back());
}
bool has_arr_mask = inputs.size() > (3 + has_sinks_);
// We are in vector mode ie single query
if (q_pre.shape(2) <= 8) {
auto q_copy_unless = [](const array& arr) {
@@ -462,7 +501,7 @@ void ScaledDotProductAttention::eval_gpu(
(strides[0] == strides[1] * shape[1]);
};
auto mask = inputs.size() > 3
auto mask = has_arr_mask
? std::optional<array>{copy_unless(mask_copy_unless, inputs[3])}
: std::nullopt;
@@ -473,9 +512,9 @@ void ScaledDotProductAttention::eval_gpu(
char devc = d.get_architecture().back();
if ((devc == 'd' && k.shape(2) >= 1024) ||
(k.shape(1) < q.shape(1) && k.shape(2) >= 4096)) {
sdpa_vector_2pass(s, d, q, k, v, o, scale_, do_causal, mask);
sdpa_vector_2pass(s, d, q, k, v, o, scale_, do_causal, mask, sinks);
} else {
sdpa_vector(s, d, q, k, v, o, scale_, do_causal, mask);
sdpa_vector(s, d, q, k, v, o, scale_, do_causal, mask, sinks);
}
}
@@ -503,11 +542,12 @@ void ScaledDotProductAttention::eval_gpu(
{str_oB, str_oH, str_oL, str_oD},
flags);
auto mask = inputs.size() > 3
auto mask = has_arr_mask
? std::optional<array>{copy_unless(is_matrix_contiguous, inputs[3])}
: std::nullopt;
sdpa_full_self_attention_metal(s, d, q, k, v, scale_, o, do_causal_, mask);
sdpa_full_self_attention_metal(
s, d, q, k, v, scale_, o, do_causal_, mask, sinks);
}
d.add_temporaries(std::move(copies), s.index);

View File

@@ -8,6 +8,7 @@ file(
"${CMAKE_CURRENT_BINARY_DIR}/nccl.h")
add_library(nccl SHARED nccl_stubs.cpp)
set_target_properties(nccl PROPERTIES SOVERSION 2)
find_package(CUDAToolkit REQUIRED)
target_include_directories(nccl PRIVATE ${CUDAToolkit_INCLUDE_DIRS})
target_include_directories(nccl PRIVATE ${CMAKE_CURRENT_BINARY_DIR})

View File

@@ -366,10 +366,16 @@ array rope(
msg << "[rope] Input must be a floating type but got " << x.dtype() << ".";
throw std::invalid_argument(msg.str());
}
if (offset.size() != 1) {
if (offset.ndim() > 1) {
std::ostringstream msg;
msg << "[rope] offset must be a scalar but has shape " << offset.shape()
<< ".";
msg << "[rope] offset must have at most one dimension but has shape "
<< offset.shape() << ".";
throw std::invalid_argument(msg.str());
}
if (offset.size() != 1 && offset.size() != x.shape(0)) {
std::ostringstream msg;
msg << "[rope] offset must be a scalar or vector with " << x.shape(0)
<< " elements but has shape " << offset.shape() << ".";
throw std::invalid_argument(msg.str());
}
if (!issubdtype(offset.dtype(), integer)) {
@@ -379,7 +385,7 @@ array rope(
throw std::invalid_argument(msg.str());
}
if (offset.dtype().size() != 4) {
inputs[1] = astype(offset, uint32, s);
inputs[1] = astype(offset, int32, s);
}
if (inputs.size() == 3 &&
(inputs[2].ndim() != 1 || inputs[2].shape(0) != dims / 2)) {
@@ -391,15 +397,26 @@ array rope(
auto fallback = [dims, traditional, base, scale, forward, s](
std::vector<array> inputs) {
auto& shape = inputs[0].shape();
int ndim = shape.size();
auto x = flatten(inputs[0], 0, ndim - 3, s);
auto x = inputs[0];
auto shape = x.shape();
if (x.ndim() == 3) {
x = expand_dims(x, 1, s);
} else if (x.ndim() > 4) {
x = flatten(x, 1, 1 + (x.ndim() - 4), s);
}
auto B = x.shape(0);
auto N = x.shape(1);
auto T = x.shape(2);
auto t = x.dtype();
// Compute sines and cosines
auto half_dims = dims / 2;
auto& offset = inputs[1];
auto offset = inputs[1];
if (offset.size() > 1) {
offset = expand_dims(offset, {-1, -2}, s);
}
auto positions =
multiply(add(arange(x.shape(1), t, s), offset, s), array(scale, t), s);
multiply(add(arange(x.shape(2), t, s), offset, s), array(scale, t), s);
auto default_inv_freqs = [&inputs, &s, &t, base, half_dims]() {
return exp(
@@ -412,8 +429,7 @@ array rope(
auto inv_freqs = inputs.size() == 3 ? astype(reciprocal(inputs[2], s), t, s)
: default_inv_freqs();
auto theta =
multiply(expand_dims(positions, 1, s), expand_dims(inv_freqs, 0, s), s);
auto theta = multiply(expand_dims(positions, -1, s), inv_freqs, s);
auto coss = cos(theta, s);
auto sins = sin(theta, s);
@@ -436,32 +452,30 @@ array rope(
};
if (traditional) {
auto x1 =
slice(x, {0, 0, 0}, {x.shape(0), x.shape(1), dims}, {1, 1, 2}, s);
auto x2 =
slice(x, {0, 0, 1}, {x.shape(0), x.shape(1), dims}, {1, 1, 2}, s);
auto x1 = slice(x, {0, 0, 0, 0}, {B, N, T, dims}, {1, 1, 1, 2}, s);
auto x2 = slice(x, {0, 0, 0, 1}, {B, N, T, dims}, {1, 1, 1, 2}, s);
auto outs = apply_rope(x1, x2, coss, sins);
for (auto& o : outs) {
o = expand_dims(o, 3, s);
o = expand_dims(o, -1, s);
}
auto out = concatenate(outs, 3, s);
auto out = reshape(concatenate(outs, -1, s), {B, N, T, dims}, s);
if (dims < x.shape(-1)) {
out = reshape(out, {x.shape(0), x.shape(1), dims});
out = concatenate({out, slice(x, {0, 0, dims}, x.shape(), s)}, 2, s);
out =
concatenate({out, slice(x, {0, 0, 0, dims}, x.shape(), s)}, -1, s);
}
return std::vector<array>{reshape(out, shape, s)};
} else {
auto out_s = x.shape();
out_s.back() = half_dims;
auto x1 = slice(x, {0, 0, 0}, out_s, s);
auto x1 = slice(x, {0, 0, 0, 0}, out_s, s);
out_s.back() = dims;
auto x2 = slice(x, {0, 0, half_dims}, out_s, s);
auto x2 = slice(x, {0, 0, 0, half_dims}, out_s, s);
auto outs = apply_rope(x1, x2, coss, sins);
if (dims < x.shape(-1)) {
outs.push_back(slice(x, {0, 0, dims}, x.shape(), s));
outs.push_back(slice(x, {0, 0, 0, dims}, x.shape(), s));
}
return std::vector<array>{reshape(concatenate(outs, 2, s), shape, s)};
return std::vector<array>{reshape(concatenate(outs, -1, s), shape, s)};
}
};
auto stream = to_stream(s);
@@ -565,6 +579,7 @@ array scaled_dot_product_attention(
const float scale,
const std::string& mask_mode /* = "" */,
const std::vector<array>& mask_arrs /* = {} */,
const std::optional<array>& sinks /* = {} */,
StreamOrDevice s /* = {}*/) {
for (const auto& tensor : {queries, keys, values}) {
if (tensor.ndim() != 4) {
@@ -665,13 +680,20 @@ array scaled_dot_product_attention(
<< final_type << ".";
throw std::invalid_argument(msg.str());
}
bool has_sinks = sinks.has_value();
auto q = astype(queries, final_type, s);
auto k = astype(keys, final_type, s);
auto v = astype(values, final_type, s);
auto fallback = [scale, final_type, n_q_heads, n_kv_heads, do_causal, s](
const std::vector<array>& inputs) {
auto fallback = [scale,
final_type,
n_q_heads,
n_kv_heads,
do_causal,
has_sinks,
has_arr_mask,
s](const std::vector<array>& inputs) {
auto q = multiply(array(scale, inputs[0].dtype()), inputs[0], s);
int n_repeats = n_q_heads / n_kv_heads;
int B = q.shape(0);
@@ -684,20 +706,22 @@ array scaled_dot_product_attention(
v = expand_dims(v, 2, s);
}
auto scores = matmul(q, swapaxes(k, -1, -2, s), s);
if (inputs.size() > 3 || do_causal) {
if (has_arr_mask || do_causal) {
// Mask must be broadcast-compatible with [B, n_q_heads, L_q, L_kv]
auto mask = inputs.back();
if (do_causal) {
int kL = k.shape(-2);
int qL = q.shape(-2);
int q_off = (kL - qL) < 0 ? 0 : (kL - qL);
auto q_idx = arange(q_off, q_off + qL, s);
auto k_idx = arange(0, kL, s);
q_idx = expand_dims(q_idx, 1, s);
k_idx = expand_dims(k_idx, 0, s);
mask = greater_equal(q_idx, k_idx, s);
}
auto make_or_fetch_mask = [&]() {
if (do_causal) {
int kL = k.shape(-2);
int qL = q.shape(-2);
int q_off = (kL - qL) < 0 ? 0 : (kL - qL);
auto q_idx = arange(q_off, q_off + qL, s);
auto k_idx = arange(0, kL, s);
q_idx = expand_dims(q_idx, 1, s);
k_idx = expand_dims(k_idx, 0, s);
return greater_equal(q_idx, k_idx, s);
}
return inputs[3];
};
auto mask = make_or_fetch_mask();
if (n_repeats > 1 && mask.ndim() >= 3) {
if (mask.shape(-3) == 1) {
@@ -716,7 +740,25 @@ array scaled_dot_product_attention(
scores = add(scores, mask, s);
}
}
if (has_sinks) {
auto sinks = inputs.back();
// scores has shape B N_q N_k L_q L_k
sinks = expand_dims(sinks, {0, 2, 3}, s);
if (scores.ndim() == 5) {
sinks = unflatten(sinks, 1, {n_kv_heads, n_repeats}, s);
}
auto bsx_shape = scores.shape();
bsx_shape.back() = 1;
scores = concatenate({broadcast_to(sinks, bsx_shape, s), scores}, -1, s);
}
scores = softmax(scores, std::vector<int>{-1}, true, s);
if (has_sinks) {
// Slice off scores
auto start = Shape(scores.ndim(), 0);
start.back() = 1;
auto stop = scores.shape();
scores = slice(scores, std::move(start), std::move(stop), s);
}
auto out = matmul(scores, v, s);
if (n_repeats > 1) {
out = flatten(out, 1, 2, s);
@@ -732,7 +774,7 @@ array scaled_dot_product_attention(
has_bool_mask = mask_arr.dtype() == bool_;
if (promote_types(mask_arr.dtype(), final_type) != final_type) {
std::ostringstream msg;
msg << "[scaled_dot_product_attention] Mask type must promote to output type. "
msg << "[scaled_dot_product_attention] Mask type must promote to output type "
<< final_type << ".";
throw std::invalid_argument(msg.str());
} else if (!has_bool_mask) {
@@ -743,6 +785,22 @@ array scaled_dot_product_attention(
mask_shape.back() = keys.shape(-2);
inputs.push_back(broadcast_to(mask_arr, mask_shape, stream));
}
if (has_sinks) {
if (promote_types(sinks->dtype(), final_type) != final_type) {
std::ostringstream msg;
msg << "[scaled_dot_product_attention] Type of sinks must promote to output type "
<< final_type << ".";
throw std::invalid_argument(msg.str());
}
if (sinks->ndim() != 1 || sinks->shape(0) != n_q_heads) {
std::ostringstream msg;
msg << "[scaled_dot_product_attention] Received invalid shape for sinks "
<< sinks->shape() << ".";
throw std::invalid_argument(msg.str());
}
inputs.push_back(astype(*sinks, final_type, stream));
}
if (!ScaledDotProductAttention::use_fallback(
q, k, v, has_mask, has_arr_mask, do_causal, stream)) {
auto out_shape = Shape{q.shape(0), q.shape(1), q.shape(2), v.shape(-1)};
@@ -750,7 +808,7 @@ array scaled_dot_product_attention(
std::move(out_shape),
final_type,
std::make_shared<ScaledDotProductAttention>(
stream, fallback, scale, do_causal),
stream, fallback, scale, do_causal, has_sinks),
std::move(inputs));
}
return fallback(std::move(inputs))[0];
@@ -759,7 +817,8 @@ array scaled_dot_product_attention(
bool ScaledDotProductAttention::is_equivalent(const Primitive& other) const {
const ScaledDotProductAttention& a_other =
static_cast<const ScaledDotProductAttention&>(other);
return scale_ == a_other.scale_ && do_causal_ == a_other.do_causal_;
return scale_ == a_other.scale_ && do_causal_ == a_other.do_causal_ &&
has_sinks_ == a_other.has_sinks_;
}
bool Quantize::is_equivalent(const Primitive& other) const {

View File

@@ -50,6 +50,7 @@ array scaled_dot_product_attention(
const float scale,
const std::string& mask_mode = "",
const std::vector<array>& mask_arrs = {},
const std::optional<array>& sinks = {},
StreamOrDevice s = {});
using TemplateArg = std::variant<int, bool, Dtype>;

View File

@@ -208,9 +208,13 @@ class ScaledDotProductAttention : public Custom {
explicit ScaledDotProductAttention(
Stream stream,
std::function<std::vector<array>(std::vector<array>)> fallback,
const float scale,
const bool do_causal)
: Custom(stream, fallback), scale_(scale), do_causal_(do_causal) {}
float scale,
bool do_causal,
bool has_sinks)
: Custom(stream, fallback),
scale_(scale),
do_causal_(do_causal),
has_sinks_(has_sinks) {}
static bool use_fallback(
const array& q,
@@ -237,12 +241,13 @@ class ScaledDotProductAttention : public Custom {
DEFINE_NAME(ScaledDotProductAttention);
DEFINE_INPUT_OUTPUT_SHAPE()
auto state() const {
return std::make_tuple(nullptr, scale_, do_causal_);
return std::make_tuple(nullptr, scale_, do_causal_, has_sinks_);
}
private:
float scale_;
bool do_causal_;
bool has_sinks_;
};
class Quantize : public Custom {

View File

@@ -1,10 +1,8 @@
// Copyright © 2025 Apple Inc.
#include <string>
namespace mlx::core {
std::string version() {
const char* version() {
return MLX_VERSION;
}

View File

@@ -4,7 +4,7 @@
#define MLX_VERSION_MAJOR 0
#define MLX_VERSION_MINOR 29
#define MLX_VERSION_PATCH 0
#define MLX_VERSION_PATCH 1
#define MLX_VERSION_NUMERIC \
(100000 * MLX_VERSION_MAJOR + 1000 * MLX_VERSION_MINOR + MLX_VERSION_PATCH)
@@ -15,6 +15,6 @@ namespace mlx::core {
*
* For dev builds, the version will include the suffix ".devYYYYMMDD+hash"
*/
std::string version();
const char* version();
} // namespace mlx::core

View File

@@ -15,6 +15,7 @@ from mlx.nn.layers.activations import (
Mish,
PReLU,
ReLU,
ReLU2,
ReLU6,
Sigmoid,
SiLU,
@@ -40,6 +41,7 @@ from mlx.nn.layers.activations import (
mish,
prelu,
relu,
relu2,
relu6,
selu,
sigmoid,

View File

@@ -35,6 +35,24 @@ def relu(x):
return mx.maximum(x, 0)
@partial(mx.compile, shapeless=True)
def relu2(x):
r"""Applies the ReLU² activation function.
Applies :math:`\max(0, x)^2` element wise.
"""
return mx.square(mx.maximum(x, 0))
@partial(mx.compile, shapeless=True)
def relu6(x):
r"""Applies the Rectified Linear Unit 6.
Applies :math:`\min(\max(x, 0), 6)` element wise.
"""
return mx.minimum(mx.maximum(x, 0), 6.0)
@partial(mx.compile, shapeless=True)
def leaky_relu(x, negative_slope=0.01):
r"""Applies the Leaky Rectified Linear Unit.
@@ -62,15 +80,6 @@ def elu(x, alpha=1.0):
return mx.where(x > 0, x, alpha * (mx.exp(x) - 1))
@partial(mx.compile, shapeless=True)
def relu6(x):
r"""Applies the Rectified Linear Unit 6.
Applies :math:`\min(\max(x, 0), 6)` element wise.
"""
return mx.minimum(mx.maximum(x, 0), 6.0)
@partial(mx.compile, shapeless=True)
def softmax(x, axis=-1):
r"""Applies the Softmax function.
@@ -377,6 +386,22 @@ class ReLU(Module):
"""
@_make_activation_module(relu2)
class ReLU2(Module):
r"""Applies the ReLU² activation function.
See :func:`relu2` for the functional equivalent.
"""
@_make_activation_module(relu6)
class ReLU6(Module):
r"""Applies the Rectified Linear Unit 6.
See :func:`relu6` for the functional equivalent.
"""
class LeakyReLU(Module):
r"""Applies the Leaky Rectified Linear Unit.
@@ -412,14 +437,6 @@ class ELU(Module):
return elu(x, self._alpha)
@_make_activation_module(relu6)
class ReLU6(Module):
r"""Applies the Rectified Linear Unit 6.
See :func:`relu6` for the functional equivalent.
"""
@_make_activation_module(softmax)
class Softmax(Module):
r"""Applies the Softmax function.

View File

@@ -52,7 +52,6 @@ set_target_properties(
${MLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY})
target_link_libraries(core PRIVATE mlx)
target_compile_definitions(core PRIVATE _VERSION_=${MLX_VERSION})
if(BUILD_SHARED_LIBS)
if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")

View File

@@ -164,8 +164,13 @@ void init_fast(nb::module_& parent_module) {
R"pbdoc(
Apply rotary positional encoding to the input.
The input is expected to be at least 3D with shape ``(B, *, T, D)`` where:
* ``B`` is the batch size.
* ``T`` is the sequence length.
* ``D`` is the feature dimension.
Args:
a (array): Input array.
a (array): The input array.
dims (int): The feature dimensions to be rotated. If the input feature
is larger than dims then the rest is left unchanged.
traditional (bool): If set to ``True`` choose the traditional
@@ -174,7 +179,9 @@ void init_fast(nb::module_& parent_module) {
each dimension in the positional encodings. Exactly one of ``base`` and
``freqs`` must be ``None``.
scale (float): The scale used to scale the positions.
offset (int or array): The position offset to start at.
offset (int or array): The position offset to start at. If an
:obj:`array` is given it can be a scalar or vector of ``B``
offsets for each example in the batch.
freqs (array, optional): Optional frequencies to use with RoPE.
If set, the ``base`` parameter must be ``None``. Default: ``None``.
@@ -189,6 +196,7 @@ void init_fast(nb::module_& parent_module) {
const mx::array& values,
const float scale,
const std::variant<std::monostate, std::string, mx::array>& mask,
const std::optional<mx::array>& sinks,
mx::StreamOrDevice s) {
bool has_mask = !std::holds_alternative<std::monostate>(mask);
bool has_str_mask =
@@ -205,16 +213,16 @@ void init_fast(nb::module_& parent_module) {
throw std::invalid_argument(msg.str());
}
return mx::fast::scaled_dot_product_attention(
queries, keys, values, scale, mask_str, {}, s);
queries, keys, values, scale, mask_str, {}, sinks, s);
} else {
auto mask_arr = std::get<mx::array>(mask);
return mx::fast::scaled_dot_product_attention(
queries, keys, values, scale, "", {mask_arr}, s);
queries, keys, values, scale, "", {mask_arr}, sinks, s);
}
} else {
return mx::fast::scaled_dot_product_attention(
queries, keys, values, scale, "", {}, s);
queries, keys, values, scale, "", {}, sinks, s);
}
},
"q"_a,
@@ -223,9 +231,10 @@ void init_fast(nb::module_& parent_module) {
nb::kw_only(),
"scale"_a,
"mask"_a = nb::none(),
"sinks"_a = nb::none(),
"stream"_a = nb::none(),
nb::sig(
"def scaled_dot_product_attention(q: array, k: array, v: array, *, scale: float, mask: Union[None, str, array] = None, stream: Union[None, Stream, Device] = None) -> array"),
"def scaled_dot_product_attention(q: array, k: array, v: array, *, scale: float, mask: Union[None, str, array] = None, sinks: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
A fast implementation of multi-head attention: ``O = softmax(Q @ K.T, dim=-1) @ V``.
@@ -255,14 +264,17 @@ void init_fast(nb::module_& parent_module) {
q (array): Queries with shape ``[B, N_q, T_q, D]``.
k (array): Keys with shape ``[B, N_kv, T_kv, D]``.
v (array): Values with shape ``[B, N_kv, T_kv, D]``.
scale (float): Scale for queries (typically ``1.0 / sqrt(q.shape(-1)``)
mask (Union[None, str, array], optional): The mask to apply to the
scale (float): Scale for queries (typically ``1.0 / sqrt(q.shape(-1)``).
mask (str or array, optional): The mask to apply to the
query-key scores. The mask can be an array or a string indicating
the mask type. The only supported string type is ``"causal"``. If
the mask is an array it can be a boolean or additive mask. The mask
can have at most 4 dimensions and must be broadcast-compatible with
the shape ``[B, N, T_q, T_kv]``. If an additive mask is given its
type must promote to the promoted type of ``q``, ``k``, and ``v``.
sinks (array, optional): An optional array of attention sinks.
Default: ``None``.
Returns:
array: The output array.

View File

@@ -2,9 +2,9 @@
#include <nanobind/nanobind.h>
#define STRINGIFY(x) #x
#define TOSTRING(x) STRINGIFY(x)
#include "mlx/version.h"
namespace mx = mlx::core;
namespace nb = nanobind;
void init_mlx_func(nb::module_&);
@@ -48,5 +48,5 @@ NB_MODULE(core, m) {
init_distributed(m);
init_export(m);
m.attr("__version__") = TOSTRING(_VERSION_);
m.attr("__version__") = mx::version();
}

View File

@@ -91,7 +91,7 @@ mx::array to_array_with_accessor(nb::object obj) {
return nb::cast<mx::array>(obj.attr("__mlx_array__")());
} else {
std::ostringstream msg;
msg << "Invalid type " << nb::type_name(obj.type()).c_str()
msg << "Invalid type " << nb::type_name(obj.type()).c_str()
<< " received in array initialization.";
throw std::invalid_argument(msg.str());
}

View File

@@ -594,124 +594,123 @@ class TestBlas(mlx_tests.MLXTestCase):
np.random.seed(0)
# Batched matmul
alpha = 0.5
beta = 2.0
for beta in (1.0, 2.0):
# c must broadcast to the output shape
with self.assertRaises(ValueError):
mx.addmm(mx.zeros((2, 2, 2)), mx.zeros((2, 2)), mx.zeros((2, 2)))
# c must broadcast to the output shape
with self.assertRaises(ValueError):
mx.addmm(mx.zeros((2, 2, 2)), mx.zeros((2, 2)), mx.zeros((2, 2)))
# Regular batched case
a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
b_npy = np.random.normal(0.0, 1.0 / 128, (32, 16, 16)).astype(np.float32)
# Regular batched case
a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
b_npy = np.random.normal(0.0, 1.0 / 128, (32, 16, 16)).astype(np.float32)
a_mlx = mx.array(a_npy)
b_mlx = mx.array(b_npy)
a_mlx = mx.array(a_npy)
b_mlx = mx.array(b_npy)
for c_shape in ((1,), (1, 16), (32, 1, 16), (1, 128, 16)):
c_npy = np.ones(c_shape).astype(np.float32)
c_mlx = mx.array(c_npy)
for c_shape in ((1,), (1, 16), (32, 1, 16), (1, 128, 16)):
c_npy = np.ones(c_shape).astype(np.float32)
c_mlx = mx.array(c_npy)
d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)
d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)
self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
# Batched and transposed matmul
b_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
b_mlx = mx.array(b_npy)
# Batched and transposed matmul
b_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
b_mlx = mx.array(b_npy)
for c_shape in ((1,), (32, 1, 128), (1, 128)):
c_npy = np.ones(c_shape).astype(np.float32)
c_mlx = mx.array(c_npy)
for c_shape in ((1,), (32, 1, 128), (1, 128)):
c_npy = np.ones(c_shape).astype(np.float32)
c_mlx = mx.array(c_npy)
b_np_t = np.transpose(b_npy, (0, 2, 1))
b_mx_t = mx.transpose(b_mlx, (0, 2, 1))
b_np_t = np.transpose(b_npy, (0, 2, 1))
b_mx_t = mx.transpose(b_mlx, (0, 2, 1))
d_npy = alpha * (a_npy @ b_np_t) + beta * c_npy
d_mlx = mx.addmm(c_mlx, a_mlx, b_mx_t, alpha, beta)
d_npy = alpha * (a_npy @ b_np_t) + beta * c_npy
d_mlx = mx.addmm(c_mlx, a_mlx, b_mx_t, alpha, beta)
self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
# Batched matmul with simple broadcast
a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
b_npy = np.random.normal(0.0, 1.0 / 128, (16, 16)).astype(np.float32)
self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
# Batched matmul with simple broadcast
a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
b_npy = np.random.normal(0.0, 1.0 / 128, (16, 16)).astype(np.float32)
a_mlx = mx.array(a_npy)
b_mlx = mx.array(b_npy)
a_mlx = mx.array(a_npy)
b_mlx = mx.array(b_npy)
for c_shape in ((1,), (1, 16), (32, 1, 16), (1, 128, 16)):
c_npy = np.ones(c_shape).astype(np.float32)
c_mlx = mx.array(c_npy)
for c_shape in ((1,), (1, 16), (32, 1, 16), (1, 128, 16)):
c_npy = np.ones(c_shape).astype(np.float32)
c_mlx = mx.array(c_npy)
d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)
d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)
self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
# Matmul with vector
a_npy = np.random.normal(0.0, 1.0 / 128, (16,)).astype(np.float32)
b_npy = np.random.normal(0.0, 1.0 / 128, (32, 16, 128)).astype(np.float32)
a_mlx = mx.array(a_npy)
b_mlx = mx.array(b_npy)
self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
# Matmul with vector
a_npy = np.random.normal(0.0, 1.0 / 128, (16,)).astype(np.float32)
b_npy = np.random.normal(0.0, 1.0 / 128, (32, 16, 128)).astype(np.float32)
a_mlx = mx.array(a_npy)
b_mlx = mx.array(b_npy)
for c_shape in ((1,), (128,), (32, 128)):
c_npy = np.ones(c_shape).astype(np.float32)
c_mlx = mx.array(c_npy)
for c_shape in ((1,), (128,), (32, 128)):
c_npy = np.ones(c_shape).astype(np.float32)
c_mlx = mx.array(c_npy)
d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)
d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)
self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
# Matmul with vector
a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
b_npy = np.random.normal(0.0, 1.0 / 128, (16,)).astype(np.float32)
a_mlx = mx.array(a_npy)
b_mlx = mx.array(b_npy)
# Matmul with vector
a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
b_npy = np.random.normal(0.0, 1.0 / 128, (16,)).astype(np.float32)
a_mlx = mx.array(a_npy)
b_mlx = mx.array(b_npy)
for c_shape in ((1,), (32, 128)):
c_npy = np.ones(c_shape).astype(np.float32)
c_mlx = mx.array(c_npy)
for c_shape in ((1,), (32, 128)):
c_npy = np.ones(c_shape).astype(np.float32)
c_mlx = mx.array(c_npy)
d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)
d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)
self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
# Split K specializtion
a_npy = np.random.normal(0.0, 1.0 / 128, (64, 4096)).astype(np.float32)
b_npy = np.random.normal(0.0, 1.0 / 128, (4096, 32)).astype(np.float32)
# Split K specializtion
a_npy = np.random.normal(0.0, 1.0 / 128, (64, 4096)).astype(np.float32)
b_npy = np.random.normal(0.0, 1.0 / 128, (4096, 32)).astype(np.float32)
a_mlx = mx.array(a_npy)
b_mlx = mx.array(b_npy)
a_mlx = mx.array(a_npy)
b_mlx = mx.array(b_npy)
for c_shape in ((1,), (1, 32), (64, 1), (64, 32)):
c_npy = np.ones(c_shape).astype(np.float32)
c_mlx = mx.array(c_npy)
for c_shape in ((1,), (1, 32), (64, 1), (64, 32)):
c_npy = np.ones(c_shape).astype(np.float32)
c_mlx = mx.array(c_npy)
d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)
d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)
self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
# Transposed c
a = mx.ones((10, 5)).T
b = mx.ones((5, 5))
out = mx.addmm(a, b, a, beta=beta, alpha=alpha)
expected = beta * a + alpha * (b @ a)
self.assertTrue(mx.allclose(expected, out))
# Transposed c
a = mx.ones((10, 5)).T
b = mx.ones((5, 5))
out = mx.addmm(a, b, a, beta=1.5, alpha=0.5)
expected = 1.5 * a + 0.5 * (b @ a)
self.assertTrue(mx.allclose(expected, out))
# Broadcast c
a = mx.ones((5, 5))
b = mx.ones((5, 5))
c = mx.ones((1, 5))
out = mx.addmm(c, a, b, beta=1.5, alpha=0.5)
expected = 1.5 * c + 0.5 * (a @ b)
self.assertTrue(mx.allclose(expected, out))
# Broadcast c
a = mx.ones((5, 5))
b = mx.ones((5, 5))
c = mx.ones((1, 5))
out = mx.addmm(c, a, b, beta=beta, alpha=alpha)
expected = beta * c + alpha * (a @ b)
self.assertTrue(mx.allclose(expected, out))
def test_addmm_grad(self):
def make_ref_addmm(alpha, beta):
@@ -724,33 +723,32 @@ class TestBlas(mlx_tests.MLXTestCase):
shapes = ((1, 64, 32, 128), (4, 28, 24, 47), (1, 1, 24, 47))
alpha = 2.0
beta = 0.5
for beta in (1.0, 0.5):
f_test = make_addmm(alpha, beta)
f_ref = make_ref_addmm(alpha, beta)
f_test = make_addmm(alpha, beta)
f_ref = make_ref_addmm(alpha, beta)
for B, M, N, K in shapes:
cotan = mx.ones((B, M, N))
c = mx.random.normal((B, M, N))
a = mx.random.normal((B, M, K))
b = mx.random.normal((B, K, N))
for B, M, N, K in shapes:
cotan = mx.ones((B, M, N))
c = mx.random.normal((B, M, N))
a = mx.random.normal((B, M, K))
b = mx.random.normal((B, K, N))
out_ref, dout_ref = mx.vjp(
f_ref,
[c, a, b],
[cotan],
)
out_test, dout_test = mx.vjp(
f_test,
[c, a, b],
[cotan],
)
out_ref, dout_ref = mx.vjp(
f_ref,
[c, a, b],
[cotan],
)
out_test, dout_test = mx.vjp(
f_test,
[c, a, b],
[cotan],
)
self.assertTrue(mx.allclose(out_ref[0], out_test[0], atol=1e-4).item())
self.assertTrue(mx.allclose(out_ref[0], out_test[0], atol=1e-4).item())
for r, t in zip(dout_ref, dout_test):
self.assertEqual(r.shape, t.shape)
self.assertTrue(mx.allclose(r, t, atol=1e-4).item())
for r, t in zip(dout_ref, dout_test):
self.assertEqual(r.shape, t.shape)
self.assertTrue(mx.allclose(r, t, atol=1e-4).item())
def test_empty_matmul(self):
a = mx.array([[], []]).T

View File

@@ -8,18 +8,23 @@ import mlx_tests
def rope_orig(x, dims, traditional, base, scale, offset, freqs=None):
offset = offset.item() if isinstance(offset, mx.array) else offset
N = x.shape[-2] + offset
N = x.shape[-2]
dtype = x.dtype
half_D = dims // 2
positions = mx.arange(offset, N, dtype=dtype) * scale
positions = mx.arange(N, dtype=dtype)
if isinstance(offset, mx.array) and offset.size > 1:
expand = tuple(range(1, x.ndim - 1))
positions = mx.expand_dims(offset, expand) + positions
else:
positions = offset + positions
positions = positions * scale
if freqs is None:
inv_freqs = mx.exp(
-mx.arange(0.0, half_D, dtype=dtype) * (math.log(base) / half_D)
)
else:
inv_freqs = (1 / freqs).astype(x.dtype)
theta = mx.reshape(positions, (-1, 1)) * mx.reshape(inv_freqs, (1, -1))
theta = mx.expand_dims(positions, -1) * inv_freqs
costheta, sintheta = mx.cos(theta), mx.sin(theta)
if traditional:
x1 = x[..., :dims:2]
@@ -214,6 +219,7 @@ class TestFast(mlx_tests.MLXTestCase):
)
self.assertEqual(dtype, rx.dtype)
self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
return
# Test single vector
x = mx.random.uniform(shape=(1, 1, dims))
@@ -277,6 +283,55 @@ class TestFast(mlx_tests.MLXTestCase):
g2 = mx.grad(f2)(x, y)
self.assertLess(mx.abs(g1 - g2).max(), 1e-5)
def test_rope_batch(self):
T = 4
base = 10000.0
scale = 1.0
traditional = True
batch_sizes = [3, 8, 11]
num_heads = [1, 3, 5]
dims = 32
x = mx.random.uniform(shape=(8, 4, T, dims))
offset = mx.array([1, 2, 3])
with self.assertRaises(ValueError):
mx.fast.rope(
x,
dims,
traditional=traditional,
base=base,
scale=scale,
offset=offset,
)
for batch_size in batch_sizes:
for n_head in num_heads:
x = mx.random.uniform(shape=(batch_size, n_head, T, dims))
offset = mx.arange(batch_size)
rx = rope_orig(x, dims, traditional, base, scale, offset)
rx_fast = mx.fast.rope(
x,
dims,
traditional=traditional,
base=base,
scale=scale,
offset=offset,
)
self.assertLess(mx.abs(rx - rx_fast).max(), 1e-5)
x = mx.random.normal(shape=(2, 6, 8, 64)).transpose(0, 2, 1, 3)
dims = 64
offset = 0
rx_fast = mx.fast.rope(
x, dims, traditional=traditional, scale=scale, base=base, offset=offset
)
rx_fast_single = mx.fast.rope(
x[0:1], dims, traditional=traditional, scale=scale, base=base, offset=offset
)
rx = rope_orig(x, dims, traditional, base, scale, offset)
self.assertLess(mx.abs(rx - rx_fast).max(), 1e-5)
def test_rms_norm(self):
# Per dtype absolute tolerance
tolerances = {mx.float32: 1e-6, mx.float16: 1e-3, mx.bfloat16: 1e-2}

View File

@@ -6,7 +6,7 @@ import mlx_tests
import numpy as np
def mlx_ref_attn(q, k, v, scale=1.0, mask=None):
def mlx_ref_attn(q, k, v, scale=1.0, mask=None, sinks=None):
q_dtype = q.dtype
q = q * mx.array(scale, q_dtype)
n_q_heads = q.shape[-3]
@@ -23,7 +23,6 @@ def mlx_ref_attn(q, k, v, scale=1.0, mask=None):
v = mx.expand_dims(v, 2)
scores = q @ mx.swapaxes(k, -1, -2)
if mask is not None:
if mask == "causal":
@@ -43,7 +42,18 @@ def mlx_ref_attn(q, k, v, scale=1.0, mask=None):
else:
scores += mask
if sinks is not None:
sinks = mx.expand_dims(sinks, (0, 2, 3))
if n_repeats > 1:
sinks = mx.unflatten(sinks, 1, (n_kv_heads, n_repeats))
score_shape = list(scores.shape)
score_shape[-1] = 1
sinks = mx.broadcast_to(sinks, score_shape)
scores = mx.concatenate([sinks, scores], axis=-1)
scores = mx.softmax(scores, axis=-1, precise=True)
if sinks is not None:
scores = scores[..., 1:]
out = scores @ v
if n_repeats > 1:
@@ -158,7 +168,7 @@ class TestFastSelfAttentionSDPA(mlx_tests.MLXTestCase):
Dk = 64
if self.is_apple_silicon:
if self.is_apple_silicon or mx.cuda.is_available():
dtypes.append(np.half)
for SEQUENCE_LENGTH in [63, 129, 400]:
@@ -230,7 +240,7 @@ class TestFastSDPA(mlx_tests.MLXTestCase):
B = 1
H = 32
dtypes = [np.float32]
if self.is_apple_silicon:
if self.is_apple_silicon or mx.cuda.is_available():
dtypes.append(np.half)
for SEQUENCE_LENGTH in [1, 7, 9, 32, 63, 67, 129, 400, 2000]:
@@ -400,15 +410,30 @@ class TestFastSDPA(mlx_tests.MLXTestCase):
def test_fully_masked(self):
Lkv = 8
mask = mx.array(False)
masks = [mx.array(False), mx.array(-float("inf"))]
for mask in masks:
for D in [4, 128]:
for Lq in [1, 8]:
q = mx.random.normal(shape=(1, 4, Lq, D))
k = mx.random.normal(shape=(1, 4, Lkv, D))
v = mx.random.normal(shape=(1, 4, Lkv, D))
out = mx.fast.scaled_dot_product_attention(
q, k, v, mask=mask, scale=1
)
self.assertTrue(mx.all(mx.isnan(out)))
def test_inf_score(self):
Lkv = 8
for D in [4, 128]:
for Lq in [1, 8]:
q = mx.random.normal(shape=(1, 4, Lq, D))
k = mx.random.normal(shape=(1, 4, Lkv, D))
q = mx.ones(shape=(1, 4, Lq, D))
k = mx.ones(shape=(1, 4, Lkv, D))
v = mx.random.normal(shape=(1, 4, Lkv, D))
out = mx.fast.scaled_dot_product_attention(q, k, v, mask=mask, scale=1)
self.assertTrue(mx.all(mx.isnan(out)))
k[..., 0, :] = -float("inf")
ref = mlx_primitives_sdpa(q, k, v, scale=1, mask=None)
out = mx.fast.scaled_dot_product_attention(q, k, v, mask=None, scale=1)
self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
def test_fast_sdpa_few_query(self):
D = 64
@@ -674,6 +699,51 @@ class TestSDPA(mlx_tests.MLXTestCase):
self.assertFalse(mx.isnan(out).any().item())
self.assertLessEqual(mx.abs(out - expected).max().item(), 1e-4)
def test_sdpa_attention_sinks(self):
B = 2
N_q = N_kv = 8
T_q = T_kv = 128
D = 64
q = mx.random.normal(shape=(B, N_q, T_q, D))
k = mx.random.normal(shape=(B, N_kv, T_kv, D))
v = mx.random.normal(shape=(B, N_kv, T_kv, D))
scale = D**-0.5
# sinks should promote to correct type
sinks = mx.random.normal(shape=(N_q,))
with self.assertRaises(ValueError):
mx.fast.scaled_dot_product_attention(
q.astype(mx.float16),
k.astype(mx.float16),
v.astype(mx.float16),
scale=scale,
sinks=sinks,
)
# Wrong shapes
sinks = mx.random.normal(shape=(N_q + 1,))
with self.assertRaises(ValueError):
mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, sinks=sinks)
sinks = mx.random.normal(shape=())
with self.assertRaises(ValueError):
mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, sinks=sinks)
for T_kv in [128, 4096]:
for T_q in [1, 128]:
for N_kv in [2, 8]:
q = mx.random.normal(shape=(B, N_q, T_q, D))
k = mx.random.normal(shape=(B, N_kv, T_kv, D))
v = mx.random.normal(shape=(B, N_kv, T_kv, D))
sinks = 10 * mx.random.normal(shape=(N_q,))
expected = mlx_ref_attn(q, k, v, scale, sinks=sinks)
out = mx.fast.scaled_dot_product_attention(
q, k, v, scale=scale, sinks=sinks
)
self.assertTrue(mx.allclose(out, expected, atol=1e-5))
if __name__ == "__main__":
mlx_tests.MLXTestRunner(failfast=True)

View File

@@ -121,7 +121,8 @@ class CMakeBuild(build_ext):
build_args += [f"-j{os.cpu_count()}"]
# Avoid cache miss when building from temporary dirs.
os.environ["CCACHE_BASEDIR"] = os.path.abspath(self.build_temp)
os.environ["CCACHE_BASEDIR"] = os.path.realpath(self.build_temp)
os.environ["CCACHE_NOHASHDIR"] = "true"
subprocess.run(
["cmake", ext.sourcedir, *cmake_args], cwd=build_temp, check=True

View File

@@ -1,10 +1,7 @@
# Doctest works fine with cmake 3.5
set(CMAKE_POLICY_VERSION_MINIMUM 3.5)
FetchContent_Declare(
doctest
GIT_REPOSITORY "https://github.com/onqtam/doctest"
GIT_TAG "ae7a13539fb71f270b87eb2e874fbac80bc8dda2")
GIT_REPOSITORY https://github.com/onqtam/doctest.git
GIT_TAG v2.4.12)
FetchContent_MakeAvailable(doctest)
add_executable(tests ${PROJECT_SOURCE_DIR}/tests/tests.cpp)