Compare commits

19 Commits

Author SHA1 Message Date
Awni Hannun
ee18e1cbf0 patch bump (#2588) 2025-09-11 17:10:09 -07:00
Awni Hannun
af120c2bc0 set nccl ABI version (#2587) 2025-09-11 16:55:53 -07:00
Cheng
6a3acf2301 [CUDA] Set bias as input when using bias epilogue (#2584) 2025-09-11 15:31:09 +09:00
Awni Hannun
d6977f2a57 Add sdpa with sinks (#2558)
* add sdpa with sinks

* fix 2 pass

* fix matrix sdpa

* fix perf regression

* add to cuda (#2580)
2025-09-10 14:53:00 -07:00
Gökdeniz Gülmez
db5443e831 Adding Relu2 (#2582)
* in. com.

* upd. ackn.

* update __init__

* nits

* nits + format

* used mx.maximum(x, 0) instead of calling the function and moves relu6 under relu2 to make it nicer

* same with _make_activation_module

* Update python/mlx/nn/layers/activations.py

upd

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* update funct.rst

* upd. layers.rst

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2025-09-10 07:24:30 -07:00
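
For context, ReLU² squares the rectified input. A minimal sketch of the activation this PR adds, assuming the standard definition relu2(x) = relu(x)² (the exact MLX implementation may differ in structure):

```python
import mlx.core as mx

def relu2(x: mx.array) -> mx.array:
    # ReLU squared: square of max(x, 0); per the commit notes, the
    # rectification uses mx.maximum(x, 0).
    return mx.square(mx.maximum(x, 0))

print(relu2(mx.array([-2.0, 0.5, 3.0])))  # array([0, 0.25, 9], dtype=float32)
```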
Cheng
52b8384d10 Fix flaky addmm tests (#2581) 2025-09-10 14:22:22 +09:00
Cheng
44cc5da4bc [CUDA] Fix alpha not respected when using bias epilogue (#2578) 2025-09-10 09:08:01 +09:00
Cheng
dde3682b69 [CUDA] Use GEMM with epilogue instead of AddMM (#2569) 2025-09-09 13:18:49 +09:00
Awni Hannun
17310d91a6 Add batch offsets for mx.fast.rope (#2564)
* implement batch rope for Metal

* cuda rope (#2576)
2025-09-08 17:35:07 -07:00
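
This change lets the offset passed to mx.fast.rope carry one position per batch element instead of a single scalar. A hedged usage sketch (shapes and keyword names are illustrative; consult the mx.fast.rope docs for the exact signature):

```python
import mlx.core as mx

B, H, T, D = 2, 4, 8, 64
x = mx.random.normal((B, H, T, D))

# One starting position per batch row rather than one shared int.
offsets = mx.array([0, 16], dtype=mx.int32)

y = mx.fast.rope(
    x, D, traditional=False, base=10000.0, scale=1.0, offset=offsets)
print(y.shape)  # (2, 4, 8, 64)
```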
Cheng
b194d65a6a Some tweaks in cmake files (#2574)
* Do proper check of Metal lib

* Update doctest to get rid of cmake version hack
2025-09-09 08:27:18 +09:00
Cheng
a44b27f5f8 Fix a few ccache cache miss (#2573)
* Fix ccache cache miss

* Do not define _VERSION_ in python bindings
2025-09-09 07:41:05 +09:00
Awni Hannun
e5a33f2223 faster depthwise 1D conv (#2567) 2025-09-08 11:37:23 -07:00
Cheng
c1e3340b23 Set ccache size before building (#2570) 2025-09-07 09:00:31 +09:00
XXXXRT666
8f163a367d typing: add type hints to mlx.core.array, linalg, distributed, and random (#2565)
* Add type annotations to mlx methods

* Missing list_or_scalar
2025-09-04 09:08:11 -07:00
Manuel Villanueva
89a3df9014 Fixed several type annotations in the MLX stubs which degraded to Unknown/Any (#2560)
* Added scalar to stubs to fix Unknown Type Hint

### Proposed changes

Issue #2478 reports that several type annotations in the MLX stubs degrade to Unknown/Any in editors like VS Code with Pylance, due to missing imports (Union, Optional, Tuple) and an undefined scalar type alias.

This PR updates the stub generation patterns to:
- Add missing typing imports in mlx.core.__prefix__ so that Union, Optional, Tuple, etc. are always available.
- Define and export scalar: TypeAlias = Union[int, float, bool] in mlx.core.__suffix__ so that functions typed with Union[scalar, array] resolve correctly instead of falling back to Any.
- Update submodule stub prefixes (distributed, fast, linalg, metal, random) to import scalar alongside array, Device, and Stream, ensuring type checkers resolve the union consistently across modules.

With these changes, functions like mlx.add now display rich type signatures such as:

```
def add(
    a: scalar | array,
    b: scalar | array,
    stream: Stream | Device | None = None
) -> array
```

instead of degrading to Any.
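
A minimal sketch of the alias described above, as it might appear in the generated stub suffix (illustrative only; the actual stub files are not shown in this diff):

```python
# Hypothetical excerpt of mlx.core.__suffix__ per the PR description.
from typing import TypeAlias, Union

scalar: TypeAlias = Union[int, float, bool]
```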

### Checklist

- I have read the CONTRIBUTING document
- I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
- I have added tests that prove my fix is effective or that my feature works (n/a — stub generation only)
- I have updated the necessary documentation (if needed)

* add bool to patterns

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2025-09-03 12:52:08 -07:00
Krishi Saripalli
c5d2937aa5 chore: Update Docs With Slice Copy Example (#2559)
* chore: updated docs with slice copy example

* nits

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2025-09-02 22:07:02 -07:00
Awni Hannun
b61a65e313 fix copies in sdpa (#2563) 2025-09-02 11:00:36 -07:00
wrmsr
04cbb4191c Fix dequantize python sig (#2562) 2025-09-01 11:50:20 -07:00
Artur Antonov
c5460762e7 Fix AdamW weight_decay default value in docstring (#2557) 2025-08-31 21:29:30 -07:00
43 changed files with 1141 additions and 560 deletions

View File

```diff
@@ -230,6 +230,9 @@ jobs:
             sudo mv ccache-4.11.3-linux-x86_64/ccache /usr/bin/ccache
             rm -rf ccache-4.11.3-linux-x86_64
             curl -LsSf https://astral.sh/uv/install.sh | sh
+      - run:
+          name: Set CCache size
+          command: ccache --max-size 1G
       - run:
           name: Install Python package
           command: |
@@ -260,7 +263,6 @@ jobs:
           command: |
             ccache --show-stats
             ccache --zero-stats
-            ccache --max-size 400MB
             ccache --cleanup
       - save_cache:
           key: cuda-<< parameters.image_date >>-{{ arch }}-{{ epoch }}
```

View File

```diff
@@ -19,7 +19,7 @@ MLX was developed with contributions from the following individuals:
 - Gleb Pobudzey: Added the `where` primitive, and groups in 1D and 2D convolutions.
 - Paul Paczuski: Improved stability of BCE loss calculation
 - Max-Heinrich Laves: Added `conv_transpose1d`, `conv_transpose2d`, and `conv_transpose3d` ops.
-- Gökdeniz Gülmez: Added the `Muon (MomentUm Orthogonalized by Newton-schulz)` optimizer.
+- Gökdeniz Gülmez: Added the `Muon (MomentUm Orthogonalized by Newton-schulz)` optimizer, and the `ReLU²` activation function.

 <a href="https://github.com/ml-explore/mlx/graphs/contributors">
   <img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />
```

View File

```diff
@@ -87,22 +87,21 @@ cmake_policy(SET CMP0135 NEW)

 add_library(mlx)

-if(MLX_BUILD_METAL)
-  set(METAL_LIB "-framework Metal")
-  set(FOUNDATION_LIB "-framework Foundation")
-  set(QUARTZ_LIB "-framework QuartzCore")
-endif()
-
 if(MLX_BUILD_CUDA)
   enable_language(CUDA)
 endif()

-if(MLX_BUILD_METAL AND NOT METAL_LIB)
-  message(STATUS "Metal not found. Unable to build GPU")
-  set(MLX_BUILD_METAL OFF)
-  set(MLX_METAL_DEBUG OFF)
-elseif(MLX_BUILD_METAL)
-  message(STATUS "Building METAL sources")
+if(MLX_BUILD_METAL)
+  find_library(METAL_LIB Metal)
+  find_library(FOUNDATION_LIB Foundation)
+  find_library(QUARTZ_LIB QuartzCore)
+  if(METAL_LIB)
+    message(STATUS "Metal found ${METAL_LIB}")
+  else()
+    message(
+      FATAL_ERROR
+        "Metal not found. Set MLX_BUILD_METAL=OFF to build without GPU")
+  endif()

   if(MLX_METAL_DEBUG)
     add_compile_definitions(MLX_METAL_DEBUG)
@@ -111,7 +110,8 @@ elseif(MLX_BUILD_METAL)
   # Throw an error if xcrun not found
   execute_process(
     COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version"
-    OUTPUT_VARIABLE MACOS_SDK_VERSION COMMAND_ERROR_IS_FATAL ANY)
+    OUTPUT_VARIABLE MACOS_SDK_VERSION
+    OUTPUT_STRIP_TRAILING_WHITESPACE COMMAND_ERROR_IS_FATAL ANY)

   if(${MACOS_SDK_VERSION} LESS 14.0)
     message(
```

View File

```diff
@@ -27,6 +27,7 @@ simple functions.
    mish
    prelu
    relu
+   relu2
    relu6
    selu
    sigmoid
```

View File

```diff
@@ -50,6 +50,7 @@ Layers
    QuantizedLinear
    RMSNorm
    ReLU
+   ReLU2
    ReLU6
    RNN
    RoPE
```

View File

```diff
@@ -107,8 +107,20 @@ same array:
   >>> a
   array([1, 2, 0], dtype=int32)

-Note, unlike NumPy, updates to the same location are nondeterministic:
+Note that unlike NumPy, slicing an array creates a copy, not a view. So
+mutating it does not mutate the original array:
+
+.. code-block:: shell
+
+  >>> a = mx.array([1, 2, 3])
+  >>> b = a[:]
+  >>> b[2] = 0
+  >>> b
+  array([1, 2, 0], dtype=int32)
+  >>> a
+  array([1, 2, 3], dtype=int32)
+
+Also unlike NumPy, updates to the same location are nondeterministic:

 .. code-block:: shell
```
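
The nondeterminism mentioned above arises when an indexed assignment scatters into the same location more than once. An illustrative snippet (behavior as the docs describe it; the winning value is intentionally unspecified):

```python
import mlx.core as mx

a = mx.zeros((1,))
idx = mx.array([0, 0])
# Both updates target a[0]; which one lands last is unspecified.
a[idx] = mx.array([1.0, 2.0])
print(a)  # may be array([1], dtype=float32) or array([2], dtype=float32)
```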

View File

```diff
@@ -85,10 +85,10 @@ cublasLtMatrixLayout_t create_matrix_layout(
     int32_t batch_count,
     int64_t batch_stride) {
   cublasLtMatrixLayout_t desc;
+  if (transposed) {
+    std::swap(rows, cols);
+  }
   CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate(&desc, type, rows, cols, ld));
-  cublasLtOrder_t order = transposed ? CUBLASLT_ORDER_COL : CUBLASLT_ORDER_ROW;
-  CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
-      desc, CUBLASLT_MATRIX_LAYOUT_ORDER, &order, sizeof(cublasLtOrder_t)));
   if (batch_count > 1) {
     CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
         desc,
@@ -138,25 +138,34 @@ CublasGemm::CublasGemm(
       CUBLASLT_MATMUL_DESC_POINTER_MODE,
       &pointer_mode,
       sizeof(int32_t)));
-  cublasOperation_t op = CUBLAS_OP_N;
+  // In cublasLt matrices use column-major layout, while it is possible to use
+  // the CUBLASLT_ORDER_ROW option to switch to row-major layout, the bias
+  // epilogue does not work with the option. So instead we swap A and B to make
+  // cublasLt return the row-major result, which works because:
+  // - the data of a matrix in row-major layout is identical to its transpose
+  //   in column-major layout
+  // - C^T = (A @ B)^T = B^T @ A^T
+  cublasOperation_t a_op = b_transposed ? CUBLAS_OP_T : CUBLAS_OP_N;
   CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
       matmul_desc_,
       CUBLASLT_MATMUL_DESC_TRANSA,
-      &op,
+      &a_op,
       sizeof(cublasOperation_t)));
+  cublasOperation_t b_op = a_transposed ? CUBLAS_OP_T : CUBLAS_OP_N;
   CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
       matmul_desc_,
       CUBLASLT_MATMUL_DESC_TRANSB,
-      &op,
+      &b_op,
       sizeof(cublasOperation_t)));
   auto type = dtype_to_cublas_type(dtype);
   a_desc_ = create_matrix_layout(
-      type, a_rows, a_cols, a_transposed, lda, batch_count, a_batch_stride);
+      type, b_cols, b_rows, b_transposed, ldb, batch_count, b_batch_stride);
   b_desc_ = create_matrix_layout(
-      type, b_rows, b_cols, b_transposed, ldb, batch_count, b_batch_stride);
+      type, a_cols, a_rows, a_transposed, lda, batch_count, a_batch_stride);
   out_desc_ = create_matrix_layout(
-      type, a_rows, b_cols, false, b_cols, batch_count, a_rows * b_cols);
+      type, b_cols, a_rows, false, b_cols, batch_count, a_rows * b_cols);
 }

 CublasGemm::CublasGemm(
@@ -191,7 +200,7 @@ CublasGemm::CublasGemm(
       b_batch_stride) {
   auto type = dtype_to_cublas_type(dtype);
   c_desc_ = create_matrix_layout(
-      type, a_rows, b_cols, false, ldc, batch_count, c_batch_stride);
+      type, b_cols, a_rows, false, ldc, batch_count, c_batch_stride);
 }

 CublasGemm::~CublasGemm() {
@@ -213,14 +222,30 @@ void CublasGemm::set_out(
   CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(out_desc_));
   out_desc_ = create_matrix_layout(
       dtype_to_cublas_type(dtype),
-      rows,
       cols,
+      rows,
       transposed,
       ld,
       batch_count,
       batch_stride);
 }

+void CublasGemm::set_bias(cu::CommandEncoder& encoder, const array& bias) {
+  encoder.set_input_array(bias);
+  cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_BIAS;
+  CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
+      matmul_desc_,
+      CUBLASLT_MATMUL_DESC_EPILOGUE,
+      &epilogue,
+      sizeof(epilogue)));
+  auto* bias_ptr = bias.data<void>();
+  CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
+      matmul_desc_,
+      CUBLASLT_MATMUL_DESC_BIAS_POINTER,
+      &bias_ptr,
+      sizeof(bias_ptr)));
+}
+
 void CublasGemm::run(
     cu::CommandEncoder& encoder,
     array& out,
@@ -228,11 +253,19 @@ void CublasGemm::run(
     const array& b,
     const Shape& batch_shape,
     const Strides& a_batch_strides,
-    const Strides& b_batch_strides) {
+    const Strides& b_batch_strides,
+    float alpha) {
   int batch_count = out.size() / (M_ * N_);
   if (batch_count / batch_shape.back() > 1) {
     run_batched(
-        encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides);
+        encoder,
+        out,
+        a,
+        b,
+        batch_shape,
+        a_batch_strides,
+        b_batch_strides,
+        alpha);
     return;
   }
@@ -240,7 +273,13 @@ void CublasGemm::run(
   encoder.set_input_array(b);
   encoder.set_output_array(out);

-  execute(encoder, out.data<void>(), a.data<void>(), b.data<void>(), nullptr);
+  execute(
+      encoder,
+      out.data<void>(),
+      a.data<void>(),
+      b.data<void>(),
+      nullptr,
+      alpha);
 }

 void CublasGemm::run(
@@ -330,9 +369,9 @@ void CublasGemm::execute(
       handle_,
       matmul_desc_,
       &alpha,
-      a,
+      b, // a and b are swapped
       a_desc_,
-      b,
+      a,
       b_desc_,
       &beta,
       c ? c : out,
```
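
The comment in the hunk above justifies swapping A and B with two identities. A quick NumPy check of the math (editor's illustration, independent of the cuBLAS API):

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((3, 4))
B = rng.standard_normal((4, 5))
C = A @ B

# C^T = (A @ B)^T = B^T @ A^T, so asking a column-major library for
# B^T @ A^T leaves the row-major bytes of C exactly in place.
assert np.allclose(B.T @ A.T, C.T)
```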

View File

```diff
@@ -55,6 +55,8 @@ class CublasGemm {
       int32_t batch_count,
       int64_t batch_stride);

+  void set_bias(cu::CommandEncoder& encoder, const array& bias);
+
   void run(
       cu::CommandEncoder& encoder,
       array& out,
@@ -62,7 +64,8 @@ class CublasGemm {
       const array& b,
       const Shape& batch_shape,
       const Strides& a_batch_strides,
-      const Strides& b_batch_strides);
+      const Strides& b_batch_strides,
+      float alpha = 1.0f);

   void run(
       cu::CommandEncoder& encoder,
@@ -85,7 +88,8 @@ class CublasGemm {
       const array& b,
       const Shape& batch_shape,
       const Strides& a_batch_strides,
-      const Strides& b_batch_strides);
+      const Strides& b_batch_strides,
+      float alpha);

   void run_batched(
       cu::CommandEncoder& encoder,
```

View File

```diff
@@ -13,7 +13,8 @@ void CublasGemm::run_batched(
     const array& b,
     const Shape& batch_shape,
     const Strides& a_batch_strides,
-    const Strides& b_batch_strides) {
+    const Strides& b_batch_strides,
+    float alpha) {
   encoder.set_input_array(a);
   encoder.set_input_array(b);
   encoder.set_output_array(out);
@@ -27,7 +28,8 @@ void CublasGemm::run_batched(
         out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M_ * N_,
         a.data<int8_t>() + a.itemsize() * a_it.loc,
         b.data<int8_t>() + b.itemsize() * b_it.loc,
-        nullptr);
+        nullptr,
+        alpha);
     a_it.step();
     b_it.step();
   }
```

View File

```diff
@@ -154,7 +154,8 @@ void CublasGemm::run_batched(
     const array& b,
     const Shape& batch_shape,
     const Strides& a_batch_strides,
-    const Strides& b_batch_strides) {
+    const Strides& b_batch_strides,
+    float alpha) {
   int batch_count = out.size() / (M_ * N_);
   set_pointer_mode(a_desc_, batch_count);
   set_pointer_mode(b_desc_, batch_count);
@@ -226,7 +227,8 @@ void CublasGemm::run_batched(
       reinterpret_cast<void*>(out_pointers),
       reinterpret_cast<void*>(a_pointers),
       reinterpret_cast<void*>(b_pointers),
-      nullptr);
+      nullptr,
+      alpha);
 }

 void CublasGemm::run_batched(
```

View File

```diff
@@ -11,6 +11,7 @@
 #include <numeric>
+#include <optional>

 namespace mlx::core {

 namespace {

 std::tuple<bool, int64_t, array>
@@ -28,6 +29,76 @@ check_transpose(cu::CommandEncoder& enc, const Stream& s, const array& arr) {
   }
 }

+void gemm_and_bias(
+    cu::CommandEncoder& encoder,
+    int M,
+    int N,
+    int K,
+    bool a_transposed,
+    int64_t lda,
+    bool b_transposed,
+    int64_t ldb,
+    array& out,
+    const array& a,
+    const array& b,
+    const std::optional<array>& bias = std::nullopt,
+    float alpha = 1.0f) {
+  // Check and collapse batch dimensions
+  auto [batch_shape, a_batch_strides, b_batch_strides] = collapse_batches(a, b);
+  auto batch_count = out.size() / (M * N);
+
+  // Collapse batches into M if needed
+  if (batch_count > 1 && !a_transposed && batch_shape.size() == 1 &&
+      a.strides()[a.ndim() - 2] == K && a_batch_strides.back() == M * K &&
+      b_batch_strides.back() == 0) {
+    M *= batch_shape.back();
+    batch_count = 1;
+    a_batch_strides = {0};
+    b_batch_strides = {0};
+    batch_shape = {1};
+  }
+
+  // Use gemmv when possible
+  if (!bias && cu::can_use_gemv(M, N, K, a_transposed, b_transposed)) {
+    cu::gemv(
+        a,
+        b,
+        out,
+        M,
+        N,
+        K,
+        batch_count,
+        batch_shape,
+        a_batch_strides,
+        b_batch_strides,
+        encoder);
+    return;
+  }
+
+  // Invoke cublasLt
+  CublasGemm gemm(
+      encoder.device(),
+      a.dtype(),
+      a_transposed,
+      M,
+      K,
+      lda,
+      b_transposed,
+      K,
+      N,
+      ldb,
+      batch_shape.back(),
+      a_batch_strides.back(),
+      b_batch_strides.back());
+
+  if (bias) {
+    gemm.set_bias(encoder, *bias);
+  }
+  gemm.run(
+      encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides, alpha);
+}
+
 } // namespace

 void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -48,9 +119,6 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
   out.set_data(allocator::malloc(out.nbytes()));

-  /////////////////////////////////////////////////////////////////////////////
-  // Init checks and prep
-
   int M = a_pre.shape(-2);
   int N = b_pre.shape(-1);
   int K = a_pre.shape(-1);
@@ -60,58 +128,8 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
   auto [a_transposed, lda, a] = check_transpose(encoder, s, a_pre);
   auto [b_transposed, ldb, b] = check_transpose(encoder, s, b_pre);

-  /////////////////////////////////////////////////////////////////////////////
-  // Check and collapse batch dimensions
-  auto [batch_shape, a_batch_strides, b_batch_strides] = collapse_batches(a, b);
-  auto batch_count = out.size() / (M * N);
-
-  // Collapse batches into M if needed
-  if (batch_count > 1 && !a_transposed && batch_shape.size() == 1 &&
-      a.strides()[a.ndim() - 2] == K && a_batch_strides.back() == M * K &&
-      b_batch_strides.back() == 0) {
-    M *= batch_shape.back();
-    batch_count = 1;
-    a_batch_strides = {0};
-    b_batch_strides = {0};
-    batch_shape = {1};
-  }
-
-  if (cu::can_use_gemv(M, N, K, a_transposed, b_transposed)) {
-    cu::gemv(
-        a,
-        b,
-        out,
-        M,
-        N,
-        K,
-        batch_count,
-        batch_shape,
-        a_batch_strides,
-        b_batch_strides,
-        encoder);
-    return;
-  }
-
-  /////////////////////////////////////////////////////////////////////////////
-  // Invoke cublasLt
-  CublasGemm gemm(
-      cu::device(s.device),
-      a.dtype(),
-      a_transposed,
-      M,
-      K,
-      lda,
-      b_transposed,
-      K,
-      N,
-      ldb,
-      batch_shape.back(),
-      a_batch_strides.back(),
-      b_batch_strides.back());
-  gemm.run(encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides);
+  gemm_and_bias(
+      encoder, M, N, K, a_transposed, lda, b_transposed, ldb, out, a, b);
 }

 void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -136,6 +154,28 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   auto [a_transposed, lda, a] = check_transpose(encoder, s, a_pre);
   auto [b_transposed, ldb, b] = check_transpose(encoder, s, b_pre);

+  /////////////////////////////////////////////////////////////////////////////
+  // Dispatch to GEMM with epilogue or AddMM
+  if (beta_ == 1 && c.strides(-1) == 1 && c.data_size() == out.shape(-1)) {
+    out.set_data(allocator::malloc(out.nbytes()));
+    gemm_and_bias(
+        encoder,
+        M,
+        N,
+        K,
+        a_transposed,
+        lda,
+        b_transposed,
+        ldb,
+        out,
+        a,
+        b,
+        c,
+        alpha_);
+    return;
+  }
+
   int64_t ldc;
   {
     auto stx = c.strides()[c.ndim() - 2];
@@ -177,7 +217,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   }

   /////////////////////////////////////////////////////////////////////////////
-  // Invoke cublasLt
+  // Invoke cublasLt with AddMM settings
   CublasGemm gemm(
       cu::device(s.device),
```
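
The new dispatch treats AddMM as a GEMM with a bias epilogue only when it is exactly that: beta == 1 and c is a contiguous length-N vector broadcast over the output rows. A NumPy restatement of the equivalence being exploited (illustrative):

```python
import numpy as np

rng = np.random.default_rng(1)
a = rng.standard_normal((3, 4))
b = rng.standard_normal((4, 5))
c = rng.standard_normal(5)   # bias vector of length N, broadcast over rows
alpha, beta = 2.0, 1.0

addmm = alpha * (a @ b) + beta * c   # general AddMM with broadcast c
gemm_bias = (alpha * (a @ b)) + c    # GEMM followed by a bias-add epilogue
assert np.allclose(addmm, gemm_bias)
```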

View File

```diff
@@ -103,15 +103,21 @@ template <typename T, bool traditional, bool forward, int N = 4>
 __device__ void rope_impl(
     const T* in,
     T* out,
-    int offset,
+    const int* offset,
     float inv_freq,
     float scale,
     const cuda::std::array<int64_t, 3> strides,
     const cuda::std::array<int64_t, 3> out_strides,
-    int64_t n_batch,
+    int64_t offset_stride,
+    int n_head,
     uint3 pos,
     uint3 dims) {
-  float L = scale * static_cast<float>(pos.y + offset);
+  auto n_head_up = N * ((n_head + N - 1) / N);
+  auto head_idx = static_cast<int>((pos.z * N) % n_head_up);
+  auto batch_idx = (pos.z * N) / n_head_up;
+  auto batch_offset = offset[batch_idx * offset_stride];
+  float L = scale * static_cast<float>(pos.y + batch_offset);
+  auto mat_idx = batch_idx * n_head + head_idx;

   // Compute costheta, sintheta
   float theta = L * inv_freq;
@@ -123,20 +129,19 @@ __device__ void rope_impl(
   size_t out_index_1, out_index_2;
   if (traditional) {
     out_index_1 = 2 * pos.x * out_strides[2] + pos.y * out_strides[1] +
-        N * pos.z * out_strides[0];
+        mat_idx * out_strides[0];
     out_index_2 = out_index_1 + 1;
     in_index_1 =
-        2 * pos.x * strides[2] + pos.y * strides[1] + N * pos.z * strides[0];
+        2 * pos.x * strides[2] + pos.y * strides[1] + mat_idx * strides[0];
     in_index_2 = in_index_1 + strides[2];
   } else {
     out_index_1 = pos.x * out_strides[2] + pos.y * out_strides[1] +
-        N * pos.z * out_strides[0];
+        mat_idx * out_strides[0];
     out_index_2 = out_index_1 + dims.x * out_strides[2];
-    in_index_1 =
-        pos.x * strides[2] + pos.y * strides[1] + N * pos.z * strides[0];
+    in_index_1 = pos.x * strides[2] + pos.y * strides[1] + mat_idx * strides[0];
     in_index_2 = in_index_1 + dims.x * strides[2];
   }
-  for (int i = 0; i < N && pos.z * N + i < n_batch; ++i) {
+  for (int i = 0; i < N && head_idx + i < n_head; ++i) {
     // Read and write the output
     float x1 = static_cast<float>(in[in_index_1]);
     float x2 = static_cast<float>(in[in_index_2]);
@@ -167,7 +172,8 @@ __global__ void rope(
     float base,
     const __grid_constant__ cuda::std::array<int64_t, 3> strides,
     const __grid_constant__ cuda::std::array<int64_t, 3> out_strides,
-    int64_t n_batch,
+    int64_t offset_stride,
+    int n_head,
     uint3 dims) {
   uint3 pos = make_uint3(
       blockIdx.x * blockDim.x + threadIdx.x,
@@ -182,12 +188,13 @@ __global__ void rope(
   rope_impl<T, traditional, forward>(
       in,
       out,
-      *offset,
+      offset,
       inv_freq,
       scale,
       strides,
       out_strides,
-      n_batch,
+      offset_stride,
+      n_head,
       pos,
       dims);
 }
@@ -202,7 +209,8 @@ __global__ void rope_freqs(
     float base,
     const __grid_constant__ cuda::std::array<int64_t, 3> strides,
     const __grid_constant__ cuda::std::array<int64_t, 3> out_strides,
-    int64_t n_batch,
+    int64_t offset_stride,
+    int n_head,
     uint3 dims,
     int64_t freq_stride) {
   uint3 pos = make_uint3(
@@ -217,12 +225,13 @@ __global__ void rope_freqs(
   rope_impl<T, traditional, forward>(
       in,
       out,
-      *offset,
+      offset,
       inv_freq,
       scale,
       strides,
       out_strides,
-      n_batch,
+      offset_stride,
+      n_head,
       pos,
       dims);
 }
@@ -245,23 +254,28 @@ void RoPE::eval_gpu(
   auto& offset = inputs[1];
   auto& out = outputs[0];

+  if (in.ndim() < 3) {
+    throw std::runtime_error("[RoPE] Input must have at least 3 dimensions");
+  }
+
   cuda::std::array<int64_t, 3> strides;
   cuda::std::array<int64_t, 3> out_strides;
   bool donated = false;
   int ndim = in.ndim();
-  int dispatch_ndim = in.ndim();
+  int B = in.shape(0);
+  int T = in.shape(-2);
+  int D = in.shape(-1);
+  size_t mat_size = T * D;
+  int dispatch_ndim = ndim;
   while (in.shape(-dispatch_ndim) == 1 && dispatch_ndim > 3) {
     dispatch_ndim--;
   }
-  size_t mat_size = in.shape(-2) * in.shape(-1);
+  int N = 1;
+  for (int i = 1; i < (ndim - 2); ++i) {
+    N *= in.shape(i);
+  }

   // We apply rope to less that the whole vector so copy to output and then
   // apply in-place.
-  if (dims_ < in.shape(-1)) {
+  if (dims_ < D) {
     donated = true;
     auto ctype =
         (in.flags().row_contiguous) ? CopyType::Vector : CopyType::General;
@@ -302,7 +316,7 @@ void RoPE::eval_gpu(
   out_strides[2] = out.strides()[ndim - 1];

   // Some flags to help us dispatch below
-  bool single = in.flags().row_contiguous && (mat_size == in.shape(-1));
+  bool single = in.flags().row_contiguous && B == 1 && T == 1;
   bool with_freqs = inputs.size() == 3;

   auto& encoder = cu::get_command_encoder(s);
@@ -319,7 +333,7 @@ void RoPE::eval_gpu(
     if (single && !with_freqs) {
       auto kernel =
           cu::rope_single<DataType, traditional.value, forward.value>;
-      uint2 dims = make_uint2(dims_ / 2, in.size() / mat_size);
+      uint2 dims = make_uint2(dims_ / 2, N);
       auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1);
       encoder.add_kernel_node(
           kernel,
@@ -336,7 +350,7 @@ void RoPE::eval_gpu(
     } else if (single) {
       auto kernel =
           cu::rope_single_freqs<DataType, traditional.value, forward.value>;
-      uint2 dims = make_uint2(dims_ / 2, in.size() / mat_size);
+      uint2 dims = make_uint2(dims_ / 2, N);
       auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1);
       encoder.add_kernel_node(
           kernel,
@@ -354,10 +368,14 @@ void RoPE::eval_gpu(
     } else if (with_freqs) {
       auto kernel =
           cu::rope_freqs<DataType, traditional.value, forward.value>;
-      uint3 dims =
-          make_uint3(dims_ / 2, in.shape(-2), in.size() / mat_size);
-      dims.z = (dims.z + 3) / 4;
+      int n_per_thread = 4;
+      uint32_t dimz = B * ((N + n_per_thread - 1) / n_per_thread);
+      uint3 dims = make_uint3(dims_ / 2, T, dimz);
       auto [grid, block] = get_grid_and_block(dims.x, dims.y, dims.z);
+      int64_t offset_stride = 0;
+      if (inputs[1].ndim() > 0) {
+        offset_stride = inputs[1].strides()[0];
+      }
       encoder.add_kernel_node(
           kernel,
           grid,
@@ -371,15 +389,20 @@ void RoPE::eval_gpu(
           std::log2(base_),
           strides,
           out_strides,
-          in.size() / mat_size,
+          offset_stride,
+          N,
           dims,
           inputs[2].strides(0));
     } else {
       auto kernel = cu::rope<DataType, traditional.value, forward.value>;
-      uint3 dims =
-          make_uint3(dims_ / 2, in.shape(-2), in.size() / mat_size);
-      dims.z = (dims.z + 3) / 4;
+      int n_per_thread = 4;
+      uint32_t dimz = B * ((N + n_per_thread - 1) / n_per_thread);
+      uint3 dims = make_uint3(dims_ / 2, T, dimz);
       auto [grid, block] = get_grid_and_block(dims.x, dims.y, dims.z);
+      int64_t offset_stride = 0;
+      if (inputs[1].ndim() > 0) {
+        offset_stride = inputs[1].strides()[0];
+      }
       encoder.add_kernel_node(
           kernel,
           grid,
@@ -392,7 +415,8 @@ void RoPE::eval_gpu(
           std::log2(base_),
           strides,
           out_strides,
-          in.size() / mat_size,
+          offset_stride,
+          N,
           dims);
     }
   });
```
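
The new index math in rope_impl packs batch and head into the z grid dimension, rounding the head count up to a multiple of N so each thread covers N consecutive heads of one batch element. A small Python model of that decomposition (mirrors the kernel arithmetic only):

```python
# Recover (batch_idx, head_idx) from the flattened z coordinate the way
# rope_impl does; N is the per-thread unroll factor (4 in the kernels).
def decompose(pos_z: int, n_head: int, N: int = 4):
    n_head_up = N * ((n_head + N - 1) // N)  # heads rounded up to N slots
    head_idx = (pos_z * N) % n_head_up
    batch_idx = (pos_z * N) // n_head_up
    return batch_idx, head_idx

# 6 heads round up to 8 slots, i.e. two z-steps per batch element.
for z in range(4):
    print(z, decompose(z, n_head=6))
# 0 (0, 0)   1 (0, 4)   2 (1, 0)   3 (1, 4)
```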

View File

```diff
@@ -46,6 +46,7 @@ __global__ void kernel_sdpav_1pass(
     const T* K,
     const T* V,
     T* O,
+    const T* sinks,
     __grid_constant__ const AttnParams params) {
   constexpr int BN = 32;
   constexpr int BD = 32;
@@ -65,7 +66,7 @@ __global__ void kernel_sdpav_1pass(
   __shared__ U max_scores[BN];
   __shared__ U sum_exp_scores[BN];

-  const U scale_log2 = params.scale * 1.44269504089f;
+  const U scale_log2 = params.scale * M_LOG2E;

   auto block = cg::this_thread_block();
   auto warp = cg::tiled_partition<32>(block);
@@ -110,6 +111,10 @@ __global__ void kernel_sdpav_1pass(
   U max_score = -INFINITY;
   U sum_exp_score = 0.f;
+  if (sinks && warp_idx == 0) {
+    max_score = M_LOG2E * static_cast<U>(sinks[head_idx]);
+    sum_exp_score = 1.f;
+  }

   // For each key
   for (int i = kv_seq_idx; i < params.kL; i += BN) {
@@ -137,8 +142,9 @@ __global__ void kernel_sdpav_1pass(
     // Update the accumulators
     U new_max = max(max_score, score);
-    U factor = exp2f(max_score - new_max);
-    U exp_score = exp2f(score - new_max);
+    bool is_neg_inf = new_max == -INFINITY;
+    U factor = is_neg_inf ? 1 : exp2f(max_score - new_max);
+    U exp_score = is_neg_inf ? 0 : exp2f(score - new_max);

     max_score = new_max;
     sum_exp_score = sum_exp_score * factor + exp_score;
@@ -193,6 +199,7 @@ __global__ void kernel_sdpav_2pass_1(
     const T* Q,
     const T* K,
     const T* V,
+    const T* sinks,
     float* partials,
     float* sums,
     float* maxs,
@@ -268,8 +275,12 @@ __global__ void kernel_sdpav_2pass_1(
     o[i] = 0.f;
   }

-  U max_score = -1e9;
+  U max_score = -INFINITY;
   U sum_exp_score = 0.f;
+  if (sinks && warp_idx == 0 && block_idx == 0) {
+    max_score = M_LOG2E * static_cast<U>(sinks[head_idx]);
+    sum_exp_score = 1.f;
+  }

   // For each key
   for (int i = kv_seq_idx; i < params.kL; i += blocks * BN) {
@@ -297,8 +308,9 @@ __global__ void kernel_sdpav_2pass_1(
     // Update the accumulators
     U new_max = max(max_score, score);
-    U factor = exp2f(max_score - new_max);
-    U exp_score = exp2f(score - new_max);
+    bool is_neg_inf = new_max == -INFINITY;
+    U factor = is_neg_inf ? 1 : exp2f(max_score - new_max);
+    U exp_score = is_neg_inf ? 0 : exp2f(score - new_max);

     max_score = new_max;
     sum_exp_score = sum_exp_score * factor + exp_score;
@@ -463,10 +475,14 @@ void sdpa_vector_1pass_fallback(
     const array& v,
     const float scale,
     array& o,
-    bool do_causal_ = false) {
+    bool do_causal,
+    const std::optional<array>& sinks) {
   encoder.set_input_array(q);
   encoder.set_input_array(k);
   encoder.set_input_array(v);
+  if (sinks) {
+    encoder.set_input_array(*sinks);
+  }
   encoder.set_output_array(o);

   cu::AttnParams params{
@@ -489,7 +505,7 @@ void sdpa_vector_1pass_fallback(
   dim3 block_dim(1024, 1, 1);

   dispatch_float_types(o.dtype(), "kernel_sdpav_1pass", [&](auto type_tag) {
-    dispatch_bool(do_causal_, [&](auto do_causal) {
+    dispatch_bool(do_causal, [&](auto do_causal) {
       dispatch_headdim(params.D, [&](auto headdim) {
         using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
@@ -504,6 +520,7 @@ void sdpa_vector_1pass_fallback(
             k.data<DataType>(),
             v.data<DataType>(),
             o.data<DataType>(),
+            sinks ? (*sinks).data<DataType>() : nullptr,
             params);
       });
     });
@@ -518,7 +535,8 @@ void sdpa_vector_2pass_fallback(
     const array& v,
     const float scale,
     array& o,
-    bool do_causal_ = false) {
+    bool do_causal,
+    const std::optional<array>& sinks) {
   cu::AttnParams params{
       /* int B = */ q.shape(0),
       /* int H = */ q.shape(1),
@@ -559,7 +577,7 @@ void sdpa_vector_2pass_fallback(
   encoder.add_temporary(maxs);

   dispatch_float_types(o.dtype(), "kernel_sdpav_2pass", [&](auto type_tag) {
-    dispatch_bool(do_causal_, [&](auto do_causal) {
+    dispatch_bool(do_causal, [&](auto do_causal) {
       dispatch_headdim(params.D, [&](auto headdim) {
         using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
@@ -570,6 +588,10 @@ void sdpa_vector_2pass_fallback(
         encoder.set_input_array(q);
         encoder.set_input_array(k);
         encoder.set_input_array(v);
+        if (sinks) {
+          encoder.set_input_array(*sinks);
+        }
         encoder.set_output_array(intermediate);
         encoder.set_output_array(sums);
         encoder.set_output_array(maxs);
@@ -585,6 +607,7 @@ void sdpa_vector_2pass_fallback(
             q.data<DataType>(),
             k.data<DataType>(),
             v.data<DataType>(),
+            sinks ? (*sinks).data<DataType>() : nullptr,
             intermediate.data<float>(),
             sums.data<float>(),
             maxs.data<float>(),
@@ -627,15 +650,16 @@ void sdpa_vector_fallback(
     const array& v,
     const float scale,
     array& o,
-    bool do_causal_ = false) {
+    bool do_causal,
+    const std::optional<array>& sinks) {
   int kL = k.shape(2);
   if (kL > 1024) {
     return sdpa_vector_2pass_fallback(
-        s, encoder, q, k, v, scale, o, do_causal_);
+        s, encoder, q, k, v, scale, o, do_causal, sinks);
   } else {
     return sdpa_vector_1pass_fallback(
-        s, encoder, q, k, v, scale, o, do_causal_);
+        s, encoder, q, k, v, scale, o, do_causal, sinks);
   }
 }
@@ -691,7 +715,7 @@ void ScaledDotProductAttention::eval_gpu(
   // Define some copy functions to ensure the layout of the inputs is as
   // expected.
-  copies.reserve(3);
+  copies.reserve(inputs.size());
   auto copy_unless = [&copies, &s](
                          auto predicate, const array& arr) -> const array& {
     if (!predicate(arr)) {
@@ -703,6 +727,16 @@ void ScaledDotProductAttention::eval_gpu(
     }
   };

+  // Checks that the headdim dimension has stride 1.
+  auto is_matrix_contiguous = [](const array& arr) {
+    return arr.strides(-1) == 1;
+  };
+
+  std::optional<array> sinks = std::nullopt;
+  if (has_sinks_) {
+    sinks = copy_unless(is_matrix_contiguous, inputs.back());
+  }
+
   // We are in vector mode ie single query
   if (q_pre.shape(2) < 4) {
     auto q_copy_unless = [](const array& arr) {
@@ -740,10 +774,6 @@ void ScaledDotProductAttention::eval_gpu(
     const auto& k = copy_unless(kv_copy_unless, k_pre);
     const auto& v = copy_unless(kv_copy_unless, v_pre);

-    for (const auto& cp : copies) {
-      encoder.add_temporary(cp);
-    }
-
     // Donate the query if possible
     if (q.is_donatable() && q.flags().row_contiguous && q.size() == o.size()) {
       o.copy_shared_buffer(q);
@@ -752,22 +782,26 @@ void ScaledDotProductAttention::eval_gpu(
       int64_t str_oH = o.shape(3);
       int64_t str_oL = o.shape(1) * str_oH;
       int64_t str_oB = o.shape(2) * str_oL;
-      size_t data_size = o.shape(0) * str_oB;

       array::Flags flags{
          /* bool contiguous = */ 1,
           /* bool row_contiguous = */ o.shape(2) == 1,
-          /* bool col_contiguous = */ 0,
+          /* bool col_contiguous = */ o.size() == o.shape(3),
       };

       o.set_data(
           allocator::malloc(o.nbytes()),
-          data_size,
+          o.size(),
           {str_oB, str_oH, str_oL, str_oD},
           flags);
     }

-    return sdpa_vector_fallback(s, encoder, q, k, v, scale_, o, do_causal_);
+    for (const auto& cp : copies) {
+      encoder.add_temporary(cp);
+    }
+
+    return sdpa_vector_fallback(
+        s, encoder, q, k, v, scale_, o, do_causal_, sinks);
   }

   // Full attention mode should never reach here
```
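
Two of the changes above interact: sinks seed the running maximum and the exp-sum with one virtual logit per head, and the is_neg_inf guard keeps the online-softmax update NaN-free when every score so far is -inf (exp2f(-inf - -inf) would otherwise be NaN). A scalar Python model of the guarded base-2 recurrence (editor's sketch, not the kernel code):

```python
import math

def online_softmax_step(max_score, sum_exp, score):
    # One step of the streaming softmax in log2 space, with the -inf guard.
    new_max = max(max_score, score)
    if new_max == -math.inf:           # everything masked so far
        factor, exp_score = 1.0, 0.0   # avoids inf - inf = nan
    else:
        factor = 2.0 ** (max_score - new_max)
        exp_score = 2.0 ** (score - new_max)
    return new_max, sum_exp * factor + exp_score

m, s = -math.inf, 0.0                  # or seed from a sink: m = sink, s = 1.0
for score in [-math.inf, 1.0, 3.0]:
    m, s = online_softmax_step(m, s, score)
print(m, s)  # 3.0 1.25  (= 2**(1-3) + 2**0)
```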

View File

```diff
@@ -2,7 +2,6 @@
 #include <algorithm>
 #include <cassert>
 #include <numeric>
-#include <sstream>

 #include "mlx/backend/gpu/copy.h"
 #include "mlx/backend/metal/device.h"
@@ -39,10 +38,11 @@ void explicit_gemm_conv_ND_gpu(
   in_unfolded.set_data(allocator::malloc(in_unfolded.nbytes()));

   // Prepare unfolding kernel
-  std::ostringstream kname;
-  kname << "naive_unfold_nd_" << type_to_name(in_unfolded) << "_" << N;
+  std::string kname;
+  kname.reserve(32);
+  concatenate(kname, "naive_unfold_nd_", type_to_name(in_unfolded), "_", N);
   auto& compute_encoder = d.get_command_encoder(s.index);
-  auto kernel = d.get_kernel(kname.str());
+  auto kernel = d.get_kernel(kname);
   compute_encoder.set_compute_pipeline_state(kernel);

   compute_encoder.set_input_array(in, 0);
@@ -117,11 +117,12 @@ void explicit_gemm_conv_group_ND_gpu(
   in_unfolded.set_data(allocator::malloc(in_unfolded.nbytes()));

   // Prepare unfolding kernel
-  std::ostringstream kname;
-  kname << "naive_unfold_transpose_nd_" << type_to_name(in_unfolded) << "_"
-        << N;
+  std::string kname;
+  kname.reserve(32);
+  concatenate(
+      kname, "naive_unfold_transpose_nd_", type_to_name(in_unfolded), "_", N);
   auto& compute_encoder = d.get_command_encoder(s.index);
-  auto kernel = d.get_kernel(kname.str());
+  auto kernel = d.get_kernel(kname);
   compute_encoder.set_compute_pipeline_state(kernel);

   compute_encoder.set_input_array(in, 0);
@@ -252,18 +253,32 @@ void implicit_gemm_conv_2D_gpu(
       /* const int swizzle_log = */ swizzle_log};

   // Determine kernel
-  std::ostringstream kname;
-  kname << "implicit_gemm_conv_2d_" << type_to_name(out) << "_bm" << bm << "_bn"
-        << bn << "_bk" << bk << "_wm" << wm << "_wn" << wn << "_channel_"
-        << (n_channel_specialization ? std::to_string(n_channel_specialization)
-                                     : "l")
-        << "_filter_" << (small_filter ? 's' : 'l');
+  std::string kname;
+  kname.reserve(64);
+  concatenate(
+      kname,
+      "implicit_gemm_conv_2d_",
+      type_to_name(out),
+      "_bm",
+      bm,
+      "_bn",
+      bn,
+      "_bk",
+      bk,
+      "_wm",
+      wm,
+      "_wn",
+      wn,
+      "_channel_",
+      n_channel_specialization ? std::to_string(n_channel_specialization) : "l",
+      "_filter_",
+      small_filter ? 's' : 'l');

   // Encode and dispatch kernel
   auto& compute_encoder = d.get_command_encoder(s.index);
   auto kernel = get_steel_conv_kernel(
       d,
-      kname.str(),
+      kname,
       out,
       bm,
       bn,
@@ -559,11 +574,16 @@ void winograd_conv_2D_gpu(
   {
     int bc = 32;
     int bo = 4;
-    std::ostringstream kname;
-    kname << "winograd_conv_2d_weight_transform_" << type_to_name(out) << "_bc"
-          << bc;
+    std::string kname;
+    kname.reserve(32);
+    concatenate(
+        kname,
+        "winograd_conv_2d_weight_transform_",
+        type_to_name(out),
+        "_bc",
+        bc);
     auto& compute_encoder = d.get_command_encoder(s.index);
-    auto kernel = d.get_kernel(kname.str());
+    auto kernel = d.get_kernel(kname);
     compute_encoder.set_compute_pipeline_state(kernel);

     compute_encoder.set_input_array(wt, 0);
@@ -587,11 +607,16 @@ void winograd_conv_2D_gpu(
     int bc = 32;
     int wm = 2;
     int wn = 2;
-    std::ostringstream kname;
-    kname << "winograd_conv_2d_input_transform_" << type_to_name(out) << "_bc"
-          << bc;
+    std::string kname;
+    kname.reserve(32);
+    concatenate(
+        kname,
+        "winograd_conv_2d_input_transform_",
+        type_to_name(out),
+        "_bc",
+        bc);
     auto& compute_encoder = d.get_command_encoder(s.index);
-    auto kernel = d.get_kernel(kname.str());
+    auto kernel = d.get_kernel(kname);
     compute_encoder.set_compute_pipeline_state(kernel);

     compute_encoder.set_input_array(in_padded, 0);
@@ -634,11 +659,16 @@ void winograd_conv_2D_gpu(
     int bc = 32;
     int wm = 2;
     int wn = 2;
-    std::ostringstream kname;
-    kname << "winograd_conv_2d_output_transform_" << type_to_name(out) << "_bo"
-          << bc;
+    std::string kname;
+    kname.reserve(32);
+    concatenate(
+        kname,
+        "winograd_conv_2d_output_transform_",
+        type_to_name(out),
+        "_bo",
+        bc);
     auto& compute_encoder = d.get_command_encoder(s.index);
-    auto kernel = d.get_kernel(kname.str());
+    auto kernel = d.get_kernel(kname);
     compute_encoder.set_compute_pipeline_state(kernel);

     compute_encoder.set_input_array(out_wg, 0);
@@ -660,9 +690,9 @@ void depthwise_conv_2D_gpu(
     const array& wt,
     array out,
     const MLXConvParams<2>& conv_params) {
-  std::ostringstream kname;
-  kname << "depthwise_conv_2d_" << type_to_name(out);
-  std::string base_name = kname.str();
+  std::string base_name;
+  base_name.reserve(32);
+  concatenate(base_name, "depthwise_conv_2d_", type_to_name(out));

   const int N = conv_params.N;
   const int ker_h = conv_params.wS[0];
@@ -685,15 +715,18 @@ void depthwise_conv_2D_gpu(
   };

   // clang-format off
-  kname << "_ker_h_" << ker_h
-        << "_ker_w_" << ker_w
-        << "_str_h_" << str_h
-        << "_str_w_" << str_w
-        << "_tgp_h_" << th
-        << "_tgp_w_" << tw
-        << "_do_flip_" << (do_flip ? 't' : 'n'); // clang-format on
-
-  std::string hash_name = kname.str();
+  std::string hash_name;
+  hash_name.reserve(64);
+  concatenate(
+      hash_name,
+      base_name,
+      "_ker_h_", ker_h,
+      "_ker_w_", ker_w,
+      "_str_h_", str_h,
+      "_str_w_", str_w,
+      "_tgp_h_", th,
+      "_tgp_w_", tw,
+      "_do_flip_", do_flip ? 't' : 'n'); // clang-format on

   auto& compute_encoder = d.get_command_encoder(s.index);
   auto kernel = d.get_kernel(base_name, hash_name, func_consts);
@@ -774,6 +807,56 @@ void dispatch_conv_2D_gpu(
   }
 }

+void depthwise_conv_1D_gpu(
+    const Stream& s,
+    metal::Device& d,
+    const array& in,
+    array wt,
+    array out) {
+  bool large = in.size() > INT32_MAX || in.data_size() > INT32_MAX;
+  std::string base_name;
+  base_name.reserve(32);
+  concatenate(
+      base_name,
+      "depthwise_conv_1d_",
+      large ? "_large" : "",
+      type_to_name(out));
+
+  if (!wt.flags().row_contiguous) {
+    wt = contiguous_copy_gpu(wt, s);
+    d.add_temporary(wt, s.index);
+  }
+
+  auto& compute_encoder = d.get_command_encoder(s.index);
+  auto kernel = d.get_kernel(base_name);
+  compute_encoder.set_compute_pipeline_state(kernel);
+
+  auto B = in.shape(0);
+  auto Tout = out.shape(1);
+  auto D = in.shape(2);
+  auto K = wt.shape(1);
+
+  compute_encoder.set_input_array(in, 0);
+  compute_encoder.set_input_array(wt, 1);
+  compute_encoder.set_output_array(out, 2);
+  if (large) {
+    int64_t strides[3] = {in.strides(0), in.strides(1), in.strides(2)};
+    compute_encoder.set_bytes(strides, 3, 3);
+  } else {
+    int strides[3] = {
+        static_cast<int>(in.strides(0)),
+        static_cast<int>(in.strides(1)),
+        static_cast<int>(in.strides(2))};
+    compute_encoder.set_bytes(strides, 3, 3);
+  }
+  compute_encoder.set_bytes(K, 4);
+
+  auto group_dims = get_block_dims(D, Tout, B);
+  MTL::Size grid_dims = MTL::Size(D, Tout, B);
+  compute_encoder.dispatch_threads(grid_dims, group_dims);
+}
+
 void conv_1D_gpu(
     const Stream& s,
     metal::Device& d,
@@ -790,8 +873,15 @@ void conv_1D_gpu(
   bool is_idil_one = in_dilation[0] == 1;
   int C = in.shape(2);
   int O = wt.shape(0);
-  const int C_per_group = in.shape(2) / groups;
-  const int O_per_group = wt.shape(0) / groups;
+
+  // Fast path for fully separable 1D convolution
+  if (is_idil_one && (groups == C) && groups == O && wt_strides[0] == 1 &&
+      wt_dilation[0] == 1 && padding[0] == 0 && !flip) {
+    depthwise_conv_1D_gpu(s, d, in, wt, out);
+    return;
+  }
+
+  const int C_per_group = C / groups;
+  const int O_per_group = O / groups;

   // Direct to implicit gemm conv
   if (is_idil_one && (C_per_group <= 4 || C_per_group % 16 == 0) &&
```

View File

```diff
@@ -288,6 +288,40 @@ instantiate_depthconv2d(float32, float);
 instantiate_depthconv2d(float16, half);
 instantiate_depthconv2d(bfloat16, bfloat16_t);

+template <typename T, typename IdxT>
+[[kernel]] void depthwise_conv_1d(
+    const device T* in [[buffer(0)]],
+    const device T* w [[buffer(1)]],
+    device T* out [[buffer(2)]],
+    constant const IdxT strides[3],
+    constant const int& kernel_size,
+    uint3 tid [[thread_position_in_grid]],
+    uint3 grid_dim [[threads_per_grid]]) {
+  out += (tid.z * static_cast<IdxT>(grid_dim.y) + tid.y) * grid_dim.x + tid.x;
+  in += tid.z * strides[0] + tid.y * strides[1] + tid.x * strides[2];
+  w += tid.x * kernel_size;
+  float acc = 0.0;
+  for (int i = 0; i < kernel_size; ++i) {
+    acc += static_cast<float>(in[0]) * w[i];
+    in += strides[1];
+  }
+  *out = static_cast<T>(acc);
+}
+
+#define instantiate_depthconv1d(iname, itype)                         \
+  instantiate_kernel(                                                 \
+      "depthwise_conv_1d_" #iname, depthwise_conv_1d, itype, int32_t) \
+  instantiate_kernel(                                                 \
+      "depthwise_conv_1d_" #iname "_large",                           \
+      depthwise_conv_1d,                                              \
+      itype,                                                          \
+      int64_t)
+
+instantiate_depthconv1d(float32, float);
+instantiate_depthconv1d(float16, half);
+instantiate_depthconv1d(bfloat16, bfloat16_t);
+
 ///////////////////////////////////////////////////////////////////////////////
 /// Winograd kernels
 ///////////////////////////////////////////////////////////////////////////////
```
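
A NumPy reference for what the new kernel computes per output element: each channel is convolved with its own length-K filter, stride 1, no padding, channels-last layout (editor's sketch, not the Metal source):

```python
import numpy as np

def depthwise_conv1d_ref(x: np.ndarray, w: np.ndarray) -> np.ndarray:
    # x: (B, T, D) channels-last input; w: (D, K), one filter per channel.
    B, T, D = x.shape
    K = w.shape[1]
    out = np.zeros((B, T - K + 1, D), dtype=x.dtype)
    for t in range(T - K + 1):
        # window (B, K, D) times filters (D, K), summed over the K taps
        out[:, t, :] = np.einsum("bkd,dk->bd", x[:, t : t + K, :], w)
    return out

x = np.random.randn(2, 10, 3).astype(np.float32)
w = np.random.randn(3, 4).astype(np.float32)
print(depthwise_conv1d_ref(x, w).shape)  # (2, 7, 3)
```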

View File

```diff
@@ -10,7 +10,7 @@ void rope_single_impl(
     constant const int& offset,
     const float inv_freq,
     constant const float& scale,
-    constant const size_t& stride,
+    constant const int64_t& stride,
     uint2 pos,
     uint2 grid) {
   float L = scale * static_cast<float>(offset);
@@ -52,7 +52,7 @@ template <typename T, bool traditional, bool forward>
     device T* out [[buffer(1)]],
     constant const int& offset,
     constant const float& scale,
-    constant const size_t& stride,
+    constant const int64_t& stride,
     constant const float& base [[buffer(10)]],
     uint2 pos [[thread_position_in_grid]],
     uint2 grid [[threads_per_grid]]) {
@@ -68,9 +68,9 @@ template <typename T, bool traditional, bool forward>
     device T* out [[buffer(1)]],
     constant const int& offset,
     constant const float& scale,
-    constant const size_t& stride,
+    constant const int64_t& stride,
     const device float* freqs [[buffer(10)]],
-    constant const size_t& freq_stride [[buffer(11)]],
+    constant const int64_t& freq_stride [[buffer(11)]],
     uint2 pos [[thread_position_in_grid]],
     uint2 grid [[threads_per_grid]]) {
   float inv_freq = 1.0 / (freqs[freq_stride * pos.x]);
@@ -82,15 +82,21 @@ template <typename T, bool traditional, bool forward, int N = 4>
 void rope_impl(
     const device T* in,
     device T* out,
-    constant const int& offset,
+    const device int* offset,
     const float inv_freq,
     constant const float& scale,
-    constant const size_t strides[3],
-    constant const size_t out_strides[3],
-    constant const size_t& n_batch,
+    constant const int64_t strides[3],
+    constant const int64_t out_strides[3],
+    constant const int64_t& offset_stride,
+    constant const int& n_head,
     uint3 pos,
     uint3 grid) {
-  float L = scale * static_cast<float>(pos.y + offset);
+  auto n_head_up = N * ((n_head + N - 1) / N);
+  auto head_idx = static_cast<int>((pos.z * N) % n_head_up);
+  auto batch_idx = (pos.z * N) / n_head_up;
+  auto batch_offset = offset[batch_idx * offset_stride];
+  float L = scale * static_cast<float>(pos.y + batch_offset);
+  auto mat_idx = batch_idx * n_head + head_idx;

   // Compute costheta, sintheta
   float theta = L * inv_freq;
@@ -102,20 +108,19 @@ void rope_impl(
   size_t out_index_1, out_index_2;
   if (traditional) {
     out_index_1 = 2 * pos.x * out_strides[2] + pos.y * out_strides[1] +
-        N * pos.z * out_strides[0];
+        mat_idx * out_strides[0];
     out_index_2 = out_index_1 + 1;
     in_index_1 =
-        2 * pos.x * strides[2] + pos.y * strides[1] + N * pos.z * strides[0];
+        2 * pos.x * strides[2] + pos.y * strides[1] + mat_idx * strides[0];
     in_index_2 = in_index_1 + strides[2];
   } else {
     out_index_1 = pos.x * out_strides[2] + pos.y * out_strides[1] +
-        N * pos.z * out_strides[0];
+        mat_idx * out_strides[0];
     out_index_2 = out_index_1 + grid.x * out_strides[2];
-    in_index_1 =
-        pos.x * strides[2] + pos.y * strides[1] + N * pos.z * strides[0];
+    in_index_1 = pos.x * strides[2] + pos.y * strides[1] + mat_idx * strides[0];
     in_index_2 = in_index_1 + grid.x * strides[2];
   }
-  for (int i = 0; i < N && pos.z * N + i < n_batch; ++i) {
+  for (int i = 0; i < N && head_idx + i < n_head; ++i) {
     // Read and write the output
     float x1 = static_cast<float>(in[in_index_1]);
     float x2 = static_cast<float>(in[in_index_2]);
@@ -141,11 +146,12 @@ template <typename T, bool traditional, bool forward, int N = 4>
 [[kernel]] void rope(
     const device T* in [[buffer(0)]],
     device T* out [[buffer(1)]],
-    constant const int& offset,
+    const device int* offset,
     constant const float& scale,
-    constant const size_t strides[3],
-    constant const size_t out_strides[3],
-    constant const size_t& n_batch,
+    constant const int64_t strides[3],
+    constant const int64_t out_strides[3],
+    constant const int64_t& offset_stride,
+    constant const int& n_head,
     constant const float& base [[buffer(10)]],
     uint3 pos [[thread_position_in_grid]],
     uint3 grid [[threads_per_grid]]) {
@@ -159,7 +165,8 @@ template <typename T, bool traditional, bool forward, int N = 4>
       scale,
       strides,
      out_strides,
-      n_batch,
+      offset_stride,
+      n_head,
       pos,
       grid);
 }
@@ -168,13 +175,14 @@ template <typename T, bool traditional, bool forward, int N = 4>
 [[kernel]] void rope_freqs(
     const device T* in [[buffer(0)]],
     device T* out [[buffer(1)]],
-    constant const int& offset,
+    const device int* offset,
     constant const float& scale,
-    constant const size_t strides[3],
-    constant const size_t out_strides[3],
-    constant const size_t& n_batch,
+    constant const int64_t strides[3],
+    constant const int64_t out_strides[3],
+    constant const int64_t& offset_stride,
+    constant const int& n_head,
     const device float* freqs [[buffer(10)]],
-    constant const size_t& freq_stride [[buffer(11)]],
+    constant const int64_t& freq_stride [[buffer(11)]],
     uint3 pos [[thread_position_in_grid]],
     uint3 grid [[threads_per_grid]]) {
   float inv_freq = 1.0 / (freqs[freq_stride * pos.x]);
@@ -186,61 +194,20 @@ template <typename T, bool traditional, bool forward, int N = 4>
       scale,
       strides,
       out_strides,
-      n_batch,
+      offset_stride,
+      n_head,
       pos,
       grid);
 }

 // clang-format off
 #define instantiate_rope_g(name, type, traditional, forward) \
-  template [[host_name("rope_" #name)]] [[kernel]] void \
-  rope<type, traditional, forward>( \
-      const device type* in [[buffer(0)]], \
-      device type* out [[buffer(1)]], \
-      constant const int& offset, \
-      constant const float& scale, \
-      constant const size_t strides[3], \
-      constant const size_t out_strides[3], \
-      constant const size_t& n_batch, \
-      constant const float& base [[buffer(10)]], \
-      uint3 pos [[thread_position_in_grid]], \
-      uint3 grid [[threads_per_grid]]); \
-  template [[host_name("rope_freqs_" #name)]] \
-  [[kernel]] void rope_freqs<type, traditional, forward>( \
-      const device type* in [[buffer(0)]], \
-      device type* out [[buffer(1)]], \
-      constant const int& offset, \
-      constant const float& scale, \
-      constant const size_t strides[3], \
-      constant const size_t out_strides[3], \
-      constant const size_t& n_batch, \
-      const device float* freqs [[buffer(10)]], \
-      constant const size_t& freq_stride [[buffer(11)]], \
-      uint3 pos [[thread_position_in_grid]], \
-      uint3 grid [[threads_per_grid]]);
+  instantiate_kernel("rope_" #name, rope, type, traditional, forward) \
+  instantiate_kernel("rope_freqs_" #name, rope_freqs, type, traditional, forward)

 #define instantiate_rope_s(name, type, traditional, forward) \
-  template [[host_name("rope_single_" #name)]] [[kernel]] void \
-  rope_single<type, traditional, forward>( \
-      const device type* in [[buffer(0)]], \
-      device type* out [[buffer(1)]], \
-      constant const int& offset, \
-      constant const float& scale, \
-      constant const size_t& stride, \
-      constant const float& base [[buffer(10)]], \
-      uint2 pos [[thread_position_in_grid]], \
-      uint2 grid [[threads_per_grid]]); \
-  template [[host_name("rope_single_freqs_" #name)]] \
-  [[kernel]] void rope_single_freqs<type, traditional, forward>( \
-      const device type* in [[buffer(0)]], \
-      device type* out [[buffer(1)]], \
-      constant const int& offset, \
-      constant const float& scale, \
-      constant const size_t& stride, \
-      const device float* freqs [[buffer(10)]], \
-      constant const size_t& freq_stride [[buffer(11)]], \
-      uint2 pos [[thread_position_in_grid]], \
-      uint2 grid [[threads_per_grid]]);
+  instantiate_kernel("rope_single_" #name, rope_single, type, traditional, forward) \
+  instantiate_kernel("rope_single_freqs_" #name, rope_single_freqs, type, traditional, forward)

 #define instantiate_rope(name, type, traditional, forward) \
   instantiate_rope_s(name, type, traditional, forward) \
```
View File

@@ -9,6 +9,7 @@ constant bool query_transposed [[function_constant(21)]];
 constant bool do_causal [[function_constant(22)]];
 constant bool bool_mask [[function_constant(23)]];
 constant bool float_mask [[function_constant(24)]];
+constant bool has_sinks [[function_constant(25)]];

 template <typename T, int D, int V = D>
 [[kernel]] void sdpa_vector(
@@ -31,6 +32,9 @@ template <typename T, int D, int V = D>
         [[buffer(14), function_constant(has_mask)]],
     const constant int& mask_head_stride
         [[buffer(15), function_constant(has_mask)]],
+    const device T* sinks [[buffer(16), function_constant(has_sinks)]],
+    const constant int& num_q_heads
+        [[buffer(17), function_constant(has_sinks)]],
     uint3 tid [[threadgroup_position_in_grid]],
     uint3 tpg [[threadgroups_per_grid]],
     uint simd_gid [[simdgroup_index_in_threadgroup]],
@@ -53,24 +57,24 @@ template <typename T, int D, int V = D>
   threadgroup U sum_exp_scores[BN];

   // Adjust positions
-  const int head_idx = tid.x;
+  const int q_batch_head_idx = tid.x;
   const int q_seq_idx = tid.y;
-  const int kv_head_idx = head_idx / gqa_factor;
-  const int o_offset = head_idx * tpg.y + q_seq_idx;
+  const int kv_head_idx = q_batch_head_idx / gqa_factor;
+  const int o_offset = q_batch_head_idx * tpg.y + q_seq_idx;
   const int q_offset =
-      query_transposed ? tpg.x * q_seq_idx + head_idx : o_offset;
+      query_transposed ? tpg.x * q_seq_idx + q_batch_head_idx : o_offset;
   queries += q_offset * D + simd_lid * qk_per_thread;
   keys += kv_head_idx * k_head_stride + simd_gid * k_seq_stride +
       simd_lid * qk_per_thread;
   values += kv_head_idx * v_head_stride + simd_gid * v_seq_stride +
       simd_lid * v_per_thread;
   if (bool_mask) {
-    bmask += head_idx * mask_head_stride + simd_gid * mask_kv_seq_stride +
-        q_seq_idx * mask_q_seq_stride;
+    bmask += q_batch_head_idx * mask_head_stride +
+        simd_gid * mask_kv_seq_stride + q_seq_idx * mask_q_seq_stride;
   }
   if (float_mask) {
-    fmask += head_idx * mask_head_stride + simd_gid * mask_kv_seq_stride +
-        q_seq_idx * mask_q_seq_stride;
+    fmask += q_batch_head_idx * mask_head_stride +
+        simd_gid * mask_kv_seq_stride + q_seq_idx * mask_q_seq_stride;
   }

   out += o_offset * V + simd_gid * v_per_thread;
@@ -85,6 +89,10 @@ template <typename T, int D, int V = D>
   U max_score = -INFINITY;
   U sum_exp_score = 0;
+  if (has_sinks && simd_gid == 0) {
+    max_score = static_cast<U>(sinks[q_batch_head_idx % num_q_heads]);
+    sum_exp_score = 1;
+  }

   // For each key
   for (int i = simd_gid; i < N; i += BN) {
@@ -93,6 +101,8 @@ template <typename T, int D, int V = D>
       use_key = i <= (N - int(tpg.y) + int(q_seq_idx));
     } else if (bool_mask) {
       use_key = bmask[0];
+    } else if (float_mask) {
+      use_key = (fmask[0] >= Limits<T>::finite_min);
     }
     if (use_key) {
       // Read the key
@@ -107,13 +117,14 @@ template <typename T, int D, int V = D>
       }
       score = simd_sum(score);
       if (float_mask) {
-        score += max(Limits<U>::finite_min, static_cast<U>(fmask[0]));
+        score += static_cast<U>(fmask[0]);
       }

       // Update the accumulators
       U new_max = max(max_score, score);
-      U factor = fast::exp(max_score - new_max);
-      U exp_score = fast::exp(score - new_max);
+      bool is_neg_inf = new_max == -INFINITY;
+      U factor = is_neg_inf ? 1.0 : fast::exp(max_score - new_max);
+      U exp_score = is_neg_inf ? 0.0 : fast::exp(score - new_max);

       max_score = new_max;
       sum_exp_score = sum_exp_score * factor + exp_score;
@@ -187,6 +198,9 @@ template <typename T, int D, int V = D>
         [[buffer(16), function_constant(has_mask)]],
     const constant int& mask_head_stride
         [[buffer(17), function_constant(has_mask)]],
+    const device T* sinks [[buffer(18), function_constant(has_sinks)]],
+    const constant int& num_q_heads
+        [[buffer(19), function_constant(has_sinks)]],
     uint3 tid [[threadgroup_position_in_grid]],
     uint3 tpg [[threadgroups_per_grid]],
     uint simd_gid [[simdgroup_index_in_threadgroup]],
@@ -211,12 +225,12 @@ template <typename T, int D, int V = D>
   // Adjust positions
   const int block_idx = tid.z;
-  const int head_idx = tid.x;
+  const int q_batch_head_idx = tid.x;
   const int q_seq_idx = tid.y;
-  const int o_offset = head_idx * tpg.y + q_seq_idx;
+  const int o_offset = q_batch_head_idx * tpg.y + q_seq_idx;
   const int q_offset =
-      query_transposed ? tpg.x * q_seq_idx + head_idx : o_offset;
-  const int kv_head_idx = head_idx / gqa_factor;
+      query_transposed ? tpg.x * q_seq_idx + q_batch_head_idx : o_offset;
+  const int kv_head_idx = q_batch_head_idx / gqa_factor;

   queries += q_offset * D + simd_lid * qk_per_thread;
   keys += kv_head_idx * k_head_stride +
@@ -225,12 +239,12 @@ template <typename T, int D, int V = D>
       (block_idx * BN + simd_gid) * v_seq_stride + simd_lid * v_per_thread;
   out += o_offset * blocks * V + block_idx * V + simd_lid * v_per_thread;
   if (bool_mask) {
-    bmask += head_idx * mask_head_stride +
+    bmask += q_batch_head_idx * mask_head_stride +
         (block_idx * BN + simd_gid) * mask_kv_seq_stride +
         q_seq_idx * mask_q_seq_stride;
   }
   if (float_mask) {
-    fmask += head_idx * mask_head_stride +
+    fmask += q_batch_head_idx * mask_head_stride +
         (block_idx * BN + simd_gid) * mask_kv_seq_stride +
         q_seq_idx * mask_q_seq_stride;
   }
@@ -245,8 +259,13 @@ template <typename T, int D, int V = D>
     o[i] = 0;
   }

-  U max_score = -1e9;
+  U max_score = -INFINITY;
   U sum_exp_score = 0;
+  if (has_sinks && block_idx == 0 && simd_gid == 0) {
+    int q_head_idx = q_batch_head_idx % num_q_heads;
+    max_score = static_cast<U>(sinks[q_head_idx]);
+    sum_exp_score = 1;
+  }

   // For each key
   for (int i = block_idx * BN + simd_gid; i < N; i += blocks * BN) {
@@ -255,6 +274,8 @@ template <typename T, int D, int V = D>
       use_key = i <= (N - int(tpg.y) + int(q_seq_idx));
     } else if (bool_mask) {
       use_key = bmask[0];
+    } else if (float_mask) {
+      use_key = (fmask[0] >= Limits<T>::finite_min);
     }
     if (use_key) {
       // Read the key
@@ -268,6 +289,10 @@ template <typename T, int D, int V = D>
         score += q[i] * k[i];
       }
       score = simd_sum(score);
+      if (score < Limits<T>::finite_min) {
+        continue;
+      }
       if (float_mask) {
         score += fmask[0];
       }
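
The sink logit is folded into the running softmax before any real key is scored: initializing max_score to the sink value and sum_exp_score to 1 is equivalent to having already processed one extra key with that logit. A rough Python sketch of the recurrence, with hypothetical names (not part of the codebase):

import math

def online_softmax_stats(scores, sink=None):
    # Streaming softmax: keep a running max and a running sum of
    # exponentials, rescaling the sum whenever the max grows.
    max_score, sum_exp = -math.inf, 0.0
    if sink is not None:
        max_score, sum_exp = sink, 1.0  # exp(sink - sink) == 1
    for s in scores:
        new_max = max(max_score, s)
        if new_max == -math.inf:  # all keys masked so far
            factor, exp_s = 1.0, 0.0
        else:
            factor = math.exp(max_score - new_max)
            exp_s = math.exp(s - new_max)
        max_score, sum_exp = new_max, sum_exp * factor + exp_s
    return max_score, sum_exp

The is_neg_inf guard above plays the same role as the -math.inf branch here: on a fully masked row, exp(-inf - (-inf)) would otherwise produce NaNs.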

View File

@@ -11,6 +11,7 @@ constant bool align_K [[function_constant(201)]];

 constant bool has_mask [[function_constant(300)]];
 constant bool do_causal [[function_constant(301)]];
+constant bool has_sinks [[function_constant(302)]];

 template <typename T>
 struct TransformScale {
@@ -82,6 +83,7 @@ template <
     const constant AttnParams* params [[buffer(4)]],
     const constant AttnMaskParams* mask_params [[buffer(5), function_constant(has_mask)]],
     const device MaskType* mask [[buffer(6), function_constant(has_mask)]],
+    const device T* sinks [[buffer(7), function_constant(has_sinks)]],
     uint simd_lane_id [[thread_index_in_simdgroup]],
     uint simd_group_id [[simdgroup_index_in_threadgroup]],
     uint3 tid [[threadgroup_position_in_grid]],
@@ -169,7 +171,7 @@ template <
   VBlockLoader loader_v(
       V, params->V_strides[2], Vs, simd_group_id, simd_lane_id);

-  TransformScale<T> ts(static_cast<T>(params->scale * 1.44269504089));
+  TransformScale<T> ts(static_cast<T>(params->scale * M_LOG2E_F));

   // Prepare MMA tiles
   constexpr short kFragSize = 8; // MMAFrag size
@@ -232,6 +234,14 @@ template <
     max_score[i] = Limits<AccumType>::finite_min;
   }

+  if (has_sinks) {
+    STEEL_PRAGMA_UNROLL
+    for (short i = 0; i < kRowsPT; ++i) {
+      max_score[i] = M_LOG2E_F * static_cast<AccumType>(sinks[tidl.y]);
+      sum_score[i] = 1;
+    }
+  }
+
   int kb_lim = params->NK;

   if (do_causal) {
@@ -350,7 +360,7 @@ template <
             Stile.frag_at(i, j)[jj] =
                 mfrag[jj] ? Stile.frag_at(i, j)[jj] : neg_inf;
           } else {
-            Stile.frag_at(i, j)[jj] += 1.44269504089 * selem_t(mfrag[jj]);
+            Stile.frag_at(i, j)[jj] += M_LOG2E_F * selem_t(mfrag[jj]);
           }
         }
       }
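
The literal 1.44269504089 this hunk names is log2(e): the steel kernel evaluates softmax with base-2 exponentials (exp2 is cheap on GPUs), so the query scale and any additive mask term are pre-multiplied by it. The identity, checked in Python purely for illustration:

import math

M_LOG2E = 1.4426950408889634  # log2(e)
x = 0.7
assert abs(math.exp(x) - 2 ** (x * M_LOG2E)) < 1e-12

The same reasoning explains the sink initialization: sinks[tidl.y] is a natural-log-domain logit, so it is scaled by M_LOG2E_F before entering the base-2 accumulators.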

View File

@@ -18,23 +18,29 @@ void RoPE::eval_gpu(
   auto& in = inputs[0];
   auto& out = outputs[0];

+  if (in.ndim() < 3) {
+    throw std::runtime_error("[RoPE] Input must have at least 3 dimensions");
+  }
+
   auto& s = out.primitive().stream();
   auto& d = metal::device(s.device);

-  size_t strides[3];
-  size_t out_strides[3];
+  int64_t strides[3];
+  int64_t out_strides[3];
   bool donated = false;
   int ndim = in.ndim();
-  int dispatch_ndim = in.ndim();
+  int B = in.shape(0);
+  int T = in.shape(-2);
+  int D = in.shape(-1);
+  size_t mat_size = T * D;
+  int dispatch_ndim = ndim;
   while (in.shape(-dispatch_ndim) == 1 && dispatch_ndim > 3) {
     dispatch_ndim--;
   }
-  size_t mat_size = in.shape(-2) * in.shape(-1);
-  if (dims_ < in.shape(-1)) {
+  int N = 1;
+  for (int i = 1; i < (ndim - 2); ++i) {
+    N *= in.shape(i);
+  }
+  if (dims_ < D) {
     donated = true;
     auto ctype =
         (in.flags().row_contiguous) ? CopyType::Vector : CopyType::General;
@@ -71,8 +77,8 @@ void RoPE::eval_gpu(
   out_strides[1] = out.strides()[ndim - 2];
   out_strides[2] = out.strides()[ndim - 1];

-  // Special case for inference (single time step and contiguous)
-  bool single = in.flags().row_contiguous && (mat_size == in.shape(-1));
+  // Special case for inference (single batch, single time step, and contiguous)
+  bool single = in.flags().row_contiguous && B == 1 && T == 1;

   bool with_freqs = inputs.size() == 3;
   std::ostringstream kname;
@@ -86,24 +92,29 @@ void RoPE::eval_gpu(
   compute_encoder.set_compute_pipeline_state(kernel);

   compute_encoder.set_input_array(donated ? out : in, 0);
   compute_encoder.set_output_array(out, 1);
   compute_encoder.set_input_array(inputs[1], 2);
   compute_encoder.set_bytes(scale_, 3);

-  size_t n_batch = in.size() / mat_size;
   MTL::Size group_dims;
   MTL::Size grid_dims;
   if (single) {
     compute_encoder.set_bytes(out_strides, 1, 4);
     uint32_t dim0 = dims_ / 2;
-    group_dims = get_block_dims(dim0, n_batch, 1);
-    grid_dims = MTL::Size(dim0, n_batch, 1);
+    group_dims = get_block_dims(dim0, N, 1);
+    grid_dims = MTL::Size(dim0, N, 1);
   } else {
     compute_encoder.set_bytes(strides, 3, 4);
     compute_encoder.set_bytes(out_strides, 3, 5);
-    compute_encoder.set_bytes(n_batch, 6);
+    int64_t offset_stride = 0;
+    if (inputs[1].ndim() > 0) {
+      offset_stride = inputs[1].strides()[0];
+    }
+    compute_encoder.set_bytes(offset_stride, 6);
+    compute_encoder.set_bytes(N, 7);
     uint32_t dim0 = dims_ / 2;
-    uint32_t dim1 = in.shape(-2);
-    uint32_t dim2 = (n_batch + n_per_thread - 1) / n_per_thread;
+    uint32_t dim1 = T;
+    uint32_t dim2 = B * ((N + n_per_thread - 1) / n_per_thread);
     group_dims = get_block_dims(dim0, dim1, dim2);
     grid_dims = MTL::Size(dim0, dim1, dim2);
   }
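
With batched offsets the dispatch no longer folds batch and heads into a single n_batch count; the batch row, and therefore the right offset element, is recovered from the grid's z coordinate. A sketch of the dispatch arithmetic under the names above (B batch, N middle dims, T sequence; a reading of the code, not part of it):

def rope_grid_dims(dims, B, N, T, n_per_thread=4):
    # Mirrors the non-single branch above: x covers the rotated pairs,
    # y the time steps, and z packs batch times head blocks.
    dim0 = dims // 2
    dim1 = T
    dim2 = B * ((N + n_per_thread - 1) // n_per_thread)
    return (dim0, dim1, dim2)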

View File

@@ -21,8 +21,9 @@ void sdpa_full_self_attention_metal(
     const array& v,
     const float scale,
     array& o,
-    bool do_causal_ = false,
-    const std::optional<array>& mask = std::nullopt) {
+    bool do_causal_,
+    const std::optional<array>& mask,
+    const std::optional<array>& sinks) {
   using namespace mlx::steel;

   int wm = 4;
@@ -42,35 +43,49 @@ void sdpa_full_self_attention_metal(
   const bool align_Q = (qL % bq) == 0;
   const bool align_K = (kL % bk) == 0;
-  const bool has_mask = !!mask;
+  const bool has_mask = mask.has_value();
   const bool do_causal = do_causal_;
+  const bool has_sinks = sinks.has_value();

   metal::MTLFCList func_consts = {
       {&align_Q, MTL::DataType::DataTypeBool, 200},
       {&align_K, MTL::DataType::DataTypeBool, 201},
       {&has_mask, MTL::DataType::DataTypeBool, 300},
-      {&do_causal, MTL::DataType::DataTypeBool, 301}};
+      {&do_causal, MTL::DataType::DataTypeBool, 301},
+      {&has_sinks, MTL::DataType::DataTypeBool, 302}};

-  std::ostringstream kname;
-  // clang-format off
-  kname << "steel_attention_"
-        << type_to_name(q)
-        << "_bq" << bq
-        << "_bk" << bk
-        << "_bd" << bd
-        << "_wm" << wm
-        << "_wn" << wn
-        << "_mask" << (type_to_name(has_mask ? *mask : q)); // clang-format on
-
-  std::string base_name = kname.str();
-
-  // clang-format off
-  kname << "_align_Q_" << (align_Q ? 't' : 'n')
-        << "_align_K_" << (align_K ? 't' : 'n')
-        << "_has_mask_" << (has_mask ? 't' : 'n')
-        << "_do_causal_" << (do_causal ? 't' : 'n'); // clang-format on
-
-  std::string hash_name = kname.str();
+  std::string base_name;
+  concatenate(
+      base_name,
+      "steel_attention_",
+      type_to_name(q),
+      "_bq",
+      bq,
+      "_bk",
+      bk,
+      "_bd",
+      bd,
+      "_wm",
+      wm,
+      "_wn",
+      wn,
+      "_mask",
+      type_to_name(has_mask ? *mask : q));
+
+  std::string hash_name;
+  concatenate(
+      hash_name,
+      base_name,
+      "_align_Q_",
+      (align_Q ? 't' : 'n'),
+      "_align_K_",
+      (align_K ? 't' : 'n'),
+      "_has_mask_",
+      (has_mask ? 't' : 'n'),
+      "_do_causal_",
+      (do_causal ? 't' : 'n'),
+      "_has_sinks_",
+      (has_sinks ? 't' : 'n'));

   auto& compute_encoder = d.get_command_encoder(s.index);
   auto kernel = d.get_kernel(base_name, hash_name, func_consts);
@@ -114,8 +129,8 @@ void sdpa_full_self_attention_metal(
   compute_encoder.set_output_array(o, 3);
   compute_encoder.set_bytes(params, 4);

-  if (mask) {
-    auto m = *mask;
+  if (has_mask) {
+    auto& m = *mask;

     AttnMaskParams mask_params{/* int64_t M_strides[3] = */ {
         m.strides(0), m.strides(1), m.strides(2)}};
@@ -123,6 +138,9 @@ void sdpa_full_self_attention_metal(
     compute_encoder.set_bytes(mask_params, 5);
     compute_encoder.set_input_array(m, 6);
   }
+  if (has_sinks) {
+    compute_encoder.set_input_array(*sinks, 7);
+  }

   MTL::Size grid_dims = MTL::Size(NQ, H, B);
   MTL::Size group_dims = MTL::Size(32, wm, wn);
@@ -139,7 +157,8 @@ void sdpa_vector(
     array& out,
     float scale,
     bool do_causal,
-    const std::optional<array>& mask) {
+    const std::optional<array>& mask,
+    const std::optional<array>& sinks) {
   // Set the kernel name
   std::string kname;
   kname.reserve(64);
@@ -153,30 +172,32 @@ void sdpa_vector(
   // Compute the necessary sizes
   int gqa_factor = q.shape(1) / k.shape(1);
   int N = k.shape(2);
-  int B = q.shape(0) * q.shape(1);

   size_t k_head_stride = k.shape(1) == 1 ? k.strides(0) : k.strides(1);
   size_t k_seq_stride = k.strides()[2];
   size_t v_head_stride = v.shape(1) == 1 ? v.strides(0) : v.strides(1);
   size_t v_seq_stride = v.strides()[2];
   MTL::Size group_dims(1024, 1, 1);
-  MTL::Size grid_dims(B, q.shape(2), 1);
+  MTL::Size grid_dims(q.shape(0) * q.shape(1), q.shape(2), 1);

   bool has_mask = mask.has_value();
   bool bool_mask = has_mask && (*mask).dtype() == bool_;
   bool float_mask = has_mask && !bool_mask;
   bool query_transposed = !q.flags().row_contiguous;
+  bool has_sinks = sinks.has_value();
   metal::MTLFCList func_consts = {
       {&has_mask, MTL::DataType::DataTypeBool, 20},
       {&query_transposed, MTL::DataType::DataTypeBool, 21},
       {&do_causal, MTL::DataType::DataTypeBool, 22},
       {&bool_mask, MTL::DataType::DataTypeBool, 23},
       {&float_mask, MTL::DataType::DataTypeBool, 24},
+      {&has_sinks, MTL::DataType::DataTypeBool, 25},
   };
   std::string hash_name = kname;
   hash_name += has_mask ? (bool_mask ? "_boolmask" : "_floatmask") : "_nomask";
   hash_name += query_transposed ? "_qt" : "_qnt";
   hash_name += do_causal ? "_c" : "_nc";
+  hash_name += has_sinks ? "_sinks" : "_nosinks";

   // Get the kernel
   auto& compute_encoder = d.get_command_encoder(s.index);
@@ -207,6 +228,10 @@ void sdpa_vector(
     compute_encoder.set_bytes(q_seq_stride, 14);
     compute_encoder.set_bytes(head_stride, 15);
   }
+  if (has_sinks) {
+    compute_encoder.set_input_array(*sinks, 16);
+    compute_encoder.set_bytes(q.shape(1), 17);
+  }

   // Launch
   compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
@@ -221,7 +246,8 @@ void sdpa_vector_2pass(
     array& out,
     float scale,
     bool do_causal,
-    const std::optional<array>& mask) {
+    const std::optional<array>& mask,
+    const std::optional<array>& sinks) {
   // Set the kernel name
   std::string kname;
   kname.reserve(64);
@@ -267,17 +293,20 @@ void sdpa_vector_2pass(
   bool bool_mask = has_mask && (*mask).dtype() == bool_;
   bool float_mask = has_mask && !bool_mask;
   bool query_transposed = !q.flags().row_contiguous;
+  bool has_sinks = sinks.has_value();
   metal::MTLFCList func_consts = {
       {&has_mask, MTL::DataType::DataTypeBool, 20},
       {&query_transposed, MTL::DataType::DataTypeBool, 21},
       {&do_causal, MTL::DataType::DataTypeBool, 22},
       {&bool_mask, MTL::DataType::DataTypeBool, 23},
       {&float_mask, MTL::DataType::DataTypeBool, 24},
+      {&has_sinks, MTL::DataType::DataTypeBool, 25},
   };
   std::string hash_name = kname;
   hash_name += has_mask ? (bool_mask ? "_boolmask" : "_floatmask") : "_nomask";
   hash_name += query_transposed ? "_qt" : "_qnt";
   hash_name += do_causal ? "_c" : "_nc";
+  hash_name += has_sinks ? "_sinks" : "_nosinks";

   // Get the kernel
   auto& compute_encoder = d.get_command_encoder(s.index);
@@ -310,6 +339,10 @@ void sdpa_vector_2pass(
     compute_encoder.set_bytes(q_seq_stride, 16);
     compute_encoder.set_bytes(head_stride, 17);
   }
+  if (has_sinks) {
+    compute_encoder.set_input_array(*sinks, 18);
+    compute_encoder.set_bytes(q.shape(1), 19);
+  }

   // Launch
   compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
@@ -394,7 +427,7 @@ void ScaledDotProductAttention::eval_gpu(

   // Define some copy functions to ensure the layout of the inputs is as
   // expected.
-  copies.reserve(3);
+  copies.reserve(inputs.size());
   auto copy_unless = [&copies, &s](
                          auto predicate, const array& arr) -> const array& {
     if (!predicate(arr)) {
@@ -411,6 +444,12 @@ void ScaledDotProductAttention::eval_gpu(
     return arr.strides(-1) == 1;
   };

+  std::optional<array> sinks = std::nullopt;
+  if (has_sinks_) {
+    sinks = copy_unless(is_matrix_contiguous, inputs.back());
+  }
+  bool has_arr_mask = inputs.size() > (3 + has_sinks_);
+
   // We are in vector mode ie single query
   if (q_pre.shape(2) <= 8) {
     auto q_copy_unless = [](const array& arr) {
@@ -462,7 +501,7 @@ void ScaledDotProductAttention::eval_gpu(
           (strides[0] == strides[1] * shape[1]);
     };

-    auto mask = inputs.size() > 3
+    auto mask = has_arr_mask
         ? std::optional<array>{copy_unless(mask_copy_unless, inputs[3])}
         : std::nullopt;
@@ -473,9 +512,9 @@ void ScaledDotProductAttention::eval_gpu(
     char devc = d.get_architecture().back();
     if ((devc == 'd' && k.shape(2) >= 1024) ||
         (k.shape(1) < q.shape(1) && k.shape(2) >= 4096)) {
-      sdpa_vector_2pass(s, d, q, k, v, o, scale_, do_causal, mask);
+      sdpa_vector_2pass(s, d, q, k, v, o, scale_, do_causal, mask, sinks);
     } else {
-      sdpa_vector(s, d, q, k, v, o, scale_, do_causal, mask);
+      sdpa_vector(s, d, q, k, v, o, scale_, do_causal, mask, sinks);
     }
   }
@@ -503,11 +542,12 @@ void ScaledDotProductAttention::eval_gpu(
         {str_oB, str_oH, str_oL, str_oD},
         flags);

-    auto mask = inputs.size() > 3
+    auto mask = has_arr_mask
         ? std::optional<array>{copy_unless(is_matrix_contiguous, inputs[3])}
         : std::nullopt;

-    sdpa_full_self_attention_metal(s, d, q, k, v, scale_, o, do_causal_, mask);
+    sdpa_full_self_attention_metal(
+        s, d, q, k, v, scale_, o, do_causal_, mask, sinks);
   }

   d.add_temporaries(std::move(copies), s.index);

View File

@@ -8,6 +8,7 @@ file(
   "${CMAKE_CURRENT_BINARY_DIR}/nccl.h")

 add_library(nccl SHARED nccl_stubs.cpp)
+set_target_properties(nccl PROPERTIES SOVERSION 2)
 find_package(CUDAToolkit REQUIRED)
 target_include_directories(nccl PRIVATE ${CUDAToolkit_INCLUDE_DIRS})
 target_include_directories(nccl PRIVATE ${CMAKE_CURRENT_BINARY_DIR})

View File

@@ -366,10 +366,16 @@ array rope(
     msg << "[rope] Input must be a floating type but got " << x.dtype() << ".";
     throw std::invalid_argument(msg.str());
   }
-  if (offset.size() != 1) {
+  if (offset.ndim() > 1) {
     std::ostringstream msg;
-    msg << "[rope] offset must be a scalar but has shape " << offset.shape()
-        << ".";
+    msg << "[rope] offset must have at most one dimension but has shape "
+        << offset.shape() << ".";
+    throw std::invalid_argument(msg.str());
+  }
+  if (offset.size() != 1 && offset.size() != x.shape(0)) {
+    std::ostringstream msg;
+    msg << "[rope] offset must be a scalar or vector with " << x.shape(0)
+        << " elements but has shape " << offset.shape() << ".";
     throw std::invalid_argument(msg.str());
   }
   if (!issubdtype(offset.dtype(), integer)) {
@@ -379,7 +385,7 @@ array rope(
     throw std::invalid_argument(msg.str());
   }
   if (offset.dtype().size() != 4) {
-    inputs[1] = astype(offset, uint32, s);
+    inputs[1] = astype(offset, int32, s);
   }
   if (inputs.size() == 3 &&
       (inputs[2].ndim() != 1 || inputs[2].shape(0) != dims / 2)) {
@@ -391,15 +397,26 @@ array rope(
   auto fallback = [dims, traditional, base, scale, forward, s](
                       std::vector<array> inputs) {
-    auto& shape = inputs[0].shape();
-    int ndim = shape.size();
-    auto x = flatten(inputs[0], 0, ndim - 3, s);
+    auto x = inputs[0];
+    auto shape = x.shape();
+    if (x.ndim() == 3) {
+      x = expand_dims(x, 1, s);
+    } else if (x.ndim() > 4) {
+      x = flatten(x, 1, 1 + (x.ndim() - 4), s);
+    }
+    auto B = x.shape(0);
+    auto N = x.shape(1);
+    auto T = x.shape(2);
     auto t = x.dtype();

     // Compute sines and cosines
     auto half_dims = dims / 2;
-    auto& offset = inputs[1];
+    auto offset = inputs[1];
+    if (offset.size() > 1) {
+      offset = expand_dims(offset, {-1, -2}, s);
+    }
     auto positions =
-        multiply(add(arange(x.shape(1), t, s), offset, s), array(scale, t), s);
+        multiply(add(arange(x.shape(2), t, s), offset, s), array(scale, t), s);

     auto default_inv_freqs = [&inputs, &s, &t, base, half_dims]() {
       return exp(
@@ -412,8 +429,7 @@ array rope(
     auto inv_freqs = inputs.size() == 3 ? astype(reciprocal(inputs[2], s), t, s)
                                         : default_inv_freqs();
-    auto theta =
-        multiply(expand_dims(positions, 1, s), expand_dims(inv_freqs, 0, s), s);
+    auto theta = multiply(expand_dims(positions, -1, s), inv_freqs, s);
     auto coss = cos(theta, s);
     auto sins = sin(theta, s);
@@ -436,32 +452,30 @@ array rope(
     };

     if (traditional) {
-      auto x1 =
-          slice(x, {0, 0, 0}, {x.shape(0), x.shape(1), dims}, {1, 1, 2}, s);
-      auto x2 =
-          slice(x, {0, 0, 1}, {x.shape(0), x.shape(1), dims}, {1, 1, 2}, s);
+      auto x1 = slice(x, {0, 0, 0, 0}, {B, N, T, dims}, {1, 1, 1, 2}, s);
+      auto x2 = slice(x, {0, 0, 0, 1}, {B, N, T, dims}, {1, 1, 1, 2}, s);
       auto outs = apply_rope(x1, x2, coss, sins);
       for (auto& o : outs) {
-        o = expand_dims(o, 3, s);
+        o = expand_dims(o, -1, s);
       }
-      auto out = concatenate(outs, 3, s);
+      auto out = reshape(concatenate(outs, -1, s), {B, N, T, dims}, s);
       if (dims < x.shape(-1)) {
-        out = reshape(out, {x.shape(0), x.shape(1), dims});
-        out = concatenate({out, slice(x, {0, 0, dims}, x.shape(), s)}, 2, s);
+        out =
+            concatenate({out, slice(x, {0, 0, 0, dims}, x.shape(), s)}, -1, s);
       }
       return std::vector<array>{reshape(out, shape, s)};
     } else {
       auto out_s = x.shape();
       out_s.back() = half_dims;
-      auto x1 = slice(x, {0, 0, 0}, out_s, s);
+      auto x1 = slice(x, {0, 0, 0, 0}, out_s, s);
       out_s.back() = dims;
-      auto x2 = slice(x, {0, 0, half_dims}, out_s, s);
+      auto x2 = slice(x, {0, 0, 0, half_dims}, out_s, s);

       auto outs = apply_rope(x1, x2, coss, sins);
       if (dims < x.shape(-1)) {
-        outs.push_back(slice(x, {0, 0, dims}, x.shape(), s));
+        outs.push_back(slice(x, {0, 0, 0, dims}, x.shape(), s));
       }
-      return std::vector<array>{reshape(concatenate(outs, 2, s), shape, s)};
+      return std::vector<array>{reshape(concatenate(outs, -1, s), shape, s)};
     }
   };
   auto stream = to_stream(s);
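
The reshaped fallback works on a canonical (B, N, T, D) view, so a per-example offset of shape (B,) only needs two trailing singleton dims to broadcast against the (T,) positions. A standalone Python sketch of the angle computation (hypothetical helper mirroring the code above, not part of the codebase):

import math
import mlx.core as mx

def rope_theta(T, dims, base, scale, offset):
    # offset: a Python int, a scalar array, or a (B,) array of offsets
    half = dims // 2
    if isinstance(offset, mx.array) and offset.size > 1:
        offset = mx.expand_dims(offset, (-1, -2))  # (B,) -> (B, 1, 1)
    positions = (mx.arange(T) + offset) * scale
    inv_freqs = mx.exp(-mx.arange(0.0, half) * (math.log(base) / half))
    return mx.expand_dims(positions, -1) * inv_freqs  # the angles theta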
@@ -565,6 +579,7 @@ array scaled_dot_product_attention(
     const float scale,
     const std::string& mask_mode /* = "" */,
     const std::vector<array>& mask_arrs /* = {} */,
+    const std::optional<array>& sinks /* = {} */,
     StreamOrDevice s /* = {}*/) {
   for (const auto& tensor : {queries, keys, values}) {
     if (tensor.ndim() != 4) {
@@ -665,13 +680,20 @@ array scaled_dot_product_attention(
         << final_type << ".";
     throw std::invalid_argument(msg.str());
   }
+  bool has_sinks = sinks.has_value();

   auto q = astype(queries, final_type, s);
   auto k = astype(keys, final_type, s);
   auto v = astype(values, final_type, s);

-  auto fallback = [scale, final_type, n_q_heads, n_kv_heads, do_causal, s](
-                      const std::vector<array>& inputs) {
+  auto fallback = [scale,
+                   final_type,
+                   n_q_heads,
+                   n_kv_heads,
+                   do_causal,
+                   has_sinks,
+                   has_arr_mask,
+                   s](const std::vector<array>& inputs) {
     auto q = multiply(array(scale, inputs[0].dtype()), inputs[0], s);
     int n_repeats = n_q_heads / n_kv_heads;
     int B = q.shape(0);
@@ -684,10 +706,9 @@ array scaled_dot_product_attention(
       v = expand_dims(v, 2, s);
     }
     auto scores = matmul(q, swapaxes(k, -1, -2, s), s);
-    if (inputs.size() > 3 || do_causal) {
+    if (has_arr_mask || do_causal) {
       // Mask must be broadcast-compatible with [B, n_q_heads, L_q, L_kv]
-      auto mask = inputs.back();
+      auto make_or_fetch_mask = [&]() {
       if (do_causal) {
         int kL = k.shape(-2);
         int qL = q.shape(-2);
@@ -696,8 +717,11 @@ array scaled_dot_product_attention(
         auto k_idx = arange(0, kL, s);
         q_idx = expand_dims(q_idx, 1, s);
         k_idx = expand_dims(k_idx, 0, s);
-        mask = greater_equal(q_idx, k_idx, s);
+        return greater_equal(q_idx, k_idx, s);
       }
+        return inputs[3];
+      };
+      auto mask = make_or_fetch_mask();

       if (n_repeats > 1 && mask.ndim() >= 3) {
         if (mask.shape(-3) == 1) {
@@ -716,7 +740,25 @@ array scaled_dot_product_attention(
         scores = add(scores, mask, s);
       }
     }
+    if (has_sinks) {
+      auto sinks = inputs.back();
+      // scores has shape B N_q N_k L_q L_k
+      sinks = expand_dims(sinks, {0, 2, 3}, s);
+      if (scores.ndim() == 5) {
+        sinks = unflatten(sinks, 1, {n_kv_heads, n_repeats}, s);
+      }
+      auto bsx_shape = scores.shape();
+      bsx_shape.back() = 1;
+      scores = concatenate({broadcast_to(sinks, bsx_shape, s), scores}, -1, s);
+    }
     scores = softmax(scores, std::vector<int>{-1}, true, s);
+    if (has_sinks) {
+      // Slice off scores
+      auto start = Shape(scores.ndim(), 0);
+      start.back() = 1;
+      auto stop = scores.shape();
+      scores = slice(scores, std::move(start), std::move(stop), s);
+    }
     auto out = matmul(scores, v, s);
     if (n_repeats > 1) {
       out = flatten(out, 1, 2, s);
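
The fallback implements sinks without a custom kernel: one extra logit per head is prepended as a column, the softmax normalizes over it, and the column is discarded, so the sink absorbs probability mass without contributing to the output. A minimal sketch of the same trick at the Python level (shapes illustrative):

import mlx.core as mx

B, N, Lq, Lkv = 1, 2, 4, 4
scores = mx.random.normal((B, N, Lq, Lkv))
sinks = mx.zeros((N,))  # one learned logit per query head
col = mx.broadcast_to(mx.reshape(sinks, (1, N, 1, 1)), (B, N, Lq, 1))
probs = mx.softmax(mx.concatenate([col, scores], axis=-1), axis=-1)
probs = probs[..., 1:]  # each row now sums to at most 1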
@@ -732,7 +774,7 @@ array scaled_dot_product_attention(
       has_bool_mask = mask_arr.dtype() == bool_;
       if (promote_types(mask_arr.dtype(), final_type) != final_type) {
         std::ostringstream msg;
-        msg << "[scaled_dot_product_attention] Mask type must promote to output type. "
+        msg << "[scaled_dot_product_attention] Mask type must promote to output type "
             << final_type << ".";
         throw std::invalid_argument(msg.str());
       } else if (!has_bool_mask) {
@@ -743,6 +785,22 @@ array scaled_dot_product_attention(
       mask_shape.back() = keys.shape(-2);
       inputs.push_back(broadcast_to(mask_arr, mask_shape, stream));
     }
+    if (has_sinks) {
+      if (promote_types(sinks->dtype(), final_type) != final_type) {
+        std::ostringstream msg;
+        msg << "[scaled_dot_product_attention] Type of sinks must promote to output type "
+            << final_type << ".";
+        throw std::invalid_argument(msg.str());
+      }
+      if (sinks->ndim() != 1 || sinks->shape(0) != n_q_heads) {
+        std::ostringstream msg;
+        msg << "[scaled_dot_product_attention] Received invalid shape for sinks "
+            << sinks->shape() << ".";
+        throw std::invalid_argument(msg.str());
+      }
+      inputs.push_back(astype(*sinks, final_type, stream));
+    }
     if (!ScaledDotProductAttention::use_fallback(
             q, k, v, has_mask, has_arr_mask, do_causal, stream)) {
       auto out_shape = Shape{q.shape(0), q.shape(1), q.shape(2), v.shape(-1)};
@@ -750,7 +808,7 @@ array scaled_dot_product_attention(
           std::move(out_shape),
           final_type,
           std::make_shared<ScaledDotProductAttention>(
-              stream, fallback, scale, do_causal),
+              stream, fallback, scale, do_causal, has_sinks),
           std::move(inputs));
     }
     return fallback(std::move(inputs))[0];
@@ -759,7 +817,8 @@ array scaled_dot_product_attention(
 bool ScaledDotProductAttention::is_equivalent(const Primitive& other) const {
   const ScaledDotProductAttention& a_other =
       static_cast<const ScaledDotProductAttention&>(other);
-  return scale_ == a_other.scale_ && do_causal_ == a_other.do_causal_;
+  return scale_ == a_other.scale_ && do_causal_ == a_other.do_causal_ &&
+      has_sinks_ == a_other.has_sinks_;
 }

 bool Quantize::is_equivalent(const Primitive& other) const {

View File

@@ -50,6 +50,7 @@ array scaled_dot_product_attention(
     const float scale,
     const std::string& mask_mode = "",
     const std::vector<array>& mask_arrs = {},
+    const std::optional<array>& sinks = {},
     StreamOrDevice s = {});

 using TemplateArg = std::variant<int, bool, Dtype>;

View File

@@ -208,9 +208,13 @@ class ScaledDotProductAttention : public Custom {
   explicit ScaledDotProductAttention(
       Stream stream,
       std::function<std::vector<array>(std::vector<array>)> fallback,
-      const float scale,
-      const bool do_causal)
-      : Custom(stream, fallback), scale_(scale), do_causal_(do_causal) {}
+      float scale,
+      bool do_causal,
+      bool has_sinks)
+      : Custom(stream, fallback),
+        scale_(scale),
+        do_causal_(do_causal),
+        has_sinks_(has_sinks) {}

   static bool use_fallback(
       const array& q,
@@ -237,12 +241,13 @@ class ScaledDotProductAttention : public Custom {
   DEFINE_NAME(ScaledDotProductAttention);
   DEFINE_INPUT_OUTPUT_SHAPE()

   auto state() const {
-    return std::make_tuple(nullptr, scale_, do_causal_);
+    return std::make_tuple(nullptr, scale_, do_causal_, has_sinks_);
   }

  private:
   float scale_;
   bool do_causal_;
+  bool has_sinks_;
 };

 class Quantize : public Custom {

View File

@@ -1,10 +1,8 @@
 // Copyright © 2025 Apple Inc.

-#include <string>
-
 namespace mlx::core {

-std::string version() {
+const char* version() {
   return MLX_VERSION;
 }

View File

@@ -4,7 +4,7 @@

 #define MLX_VERSION_MAJOR 0
 #define MLX_VERSION_MINOR 29
-#define MLX_VERSION_PATCH 0
+#define MLX_VERSION_PATCH 1
 #define MLX_VERSION_NUMERIC \
   (100000 * MLX_VERSION_MAJOR + 1000 * MLX_VERSION_MINOR + MLX_VERSION_PATCH)
@@ -15,6 +15,6 @@ namespace mlx::core {
  *
  * For dev builds, the version will include the suffix ".devYYYYMMDD+hash"
  */
-std::string version();
+const char* version();

 } // namespace mlx::core

View File

@@ -1,20 +1,34 @@
 mlx.core.__prefix__:
+  from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
+  import sys
+  if sys.version_info >= (3, 10):
+    from typing import TypeAlias
+  else:
+    from typing_extensions import TypeAlias
 mlx.core.__suffix__:
+  from typing import Union
+  scalar: TypeAlias = Union[int, float, bool]
+  list_or_scalar: TypeAlias = Union[scalar, list["list_or_scalar"]]
   bool_: Dtype = ...
 mlx.core.distributed.__prefix__:
-  from mlx.core import array, Dtype, Device, Stream
+  from mlx.core import array, Dtype, Device, Stream, scalar
   from mlx.core.distributed import Group
   from typing import Sequence, Optional, Union
 mlx.core.fast.__prefix__:
-  from mlx.core import array, Dtype, Device, Stream
+  from mlx.core import array, Dtype, Device, Stream, scalar
   from typing import Sequence, Optional, Union
 mlx.core.linalg.__prefix__:
-  from mlx.core import array, Dtype, Device, Stream
+  from mlx.core import array, Dtype, Device, Stream, scalar
   from typing import Sequence, Optional, Tuple, Union
 mlx.core.metal.__prefix__:
-  from mlx.core import array, Dtype, Device, Stream
+  from mlx.core import array, Dtype, Device, Stream, scalar
   from typing import Sequence, Optional, Union
 mlx.core.random.__prefix__:
-  from mlx.core import array, Dtype, Device, Stream
+  from mlx.core import array, Dtype, Device, Stream, scalar, float32, int32
   from typing import Sequence, Optional, Union

View File

@@ -15,6 +15,7 @@ from mlx.nn.layers.activations import (
     Mish,
     PReLU,
     ReLU,
+    ReLU2,
     ReLU6,
     Sigmoid,
     SiLU,
@@ -40,6 +41,7 @@ from mlx.nn.layers.activations import (
     mish,
     prelu,
     relu,
+    relu2,
     relu6,
     selu,
     sigmoid,

View File

@@ -35,6 +35,24 @@ def relu(x):
     return mx.maximum(x, 0)


+@partial(mx.compile, shapeless=True)
+def relu2(x):
+    r"""Applies the ReLU² activation function.
+
+    Applies :math:`\max(0, x)^2` element wise.
+    """
+    return mx.square(mx.maximum(x, 0))
+
+
+@partial(mx.compile, shapeless=True)
+def relu6(x):
+    r"""Applies the Rectified Linear Unit 6.
+
+    Applies :math:`\min(\max(x, 0), 6)` element wise.
+    """
+    return mx.minimum(mx.maximum(x, 0), 6.0)
+
+
 @partial(mx.compile, shapeless=True)
 def leaky_relu(x, negative_slope=0.01):
     r"""Applies the Leaky Rectified Linear Unit.
@@ -62,15 +80,6 @@ def elu(x, alpha=1.0):
     return mx.where(x > 0, x, alpha * (mx.exp(x) - 1))


-@partial(mx.compile, shapeless=True)
-def relu6(x):
-    r"""Applies the Rectified Linear Unit 6.
-
-    Applies :math:`\min(\max(x, 0), 6)` element wise.
-    """
-    return mx.minimum(mx.maximum(x, 0), 6.0)
-
-
 @partial(mx.compile, shapeless=True)
 def softmax(x, axis=-1):
     r"""Applies the Softmax function.
@@ -377,6 +386,22 @@ class ReLU(Module):
     """


+@_make_activation_module(relu2)
+class ReLU2(Module):
+    r"""Applies the ReLU² activation function.
+
+    See :func:`relu2` for the functional equivalent.
+    """
+
+
+@_make_activation_module(relu6)
+class ReLU6(Module):
+    r"""Applies the Rectified Linear Unit 6.
+
+    See :func:`relu6` for the functional equivalent.
+    """
+
+
 class LeakyReLU(Module):
     r"""Applies the Leaky Rectified Linear Unit.
@@ -412,14 +437,6 @@ class ELU(Module):
         return elu(x, self._alpha)


-@_make_activation_module(relu6)
-class ReLU6(Module):
-    r"""Applies the Rectified Linear Unit 6.
-
-    See :func:`relu6` for the functional equivalent.
-    """
-
-
 @_make_activation_module(softmax)
 class Softmax(Module):
     r"""Applies the Softmax function.

View File

@@ -556,7 +556,7 @@ class AdamW(Adam):
         eps (float, optional): The term :math:`\epsilon` added to the
             denominator to improve numerical stability. Default: ``1e-8``
         weight_decay (float, optional): The weight decay :math:`\lambda`.
-            Default: ``0``.
+            Default: ``0.01``.
         bias_correction (bool, optional): If set to ``True``, bias correction
             is applied. Default: ``False``
     """

View File

@@ -52,7 +52,6 @@ set_target_properties(
     ${MLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY})

 target_link_libraries(core PRIVATE mlx)
-target_compile_definitions(core PRIVATE _VERSION_=${MLX_VERSION})

 if(BUILD_SHARED_LIBS)
   if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")

View File

@@ -320,6 +320,7 @@ void init_array(nb::module_& m) {
       .def_prop_ro(
           "shape",
           [](const mx::array& a) { return nb::cast(a.shape()); },
+          nb::sig("def shape(self) -> tuple[int, ...]"),
          R"pbdoc(
             The shape of the array as a Python tuple.
@@ -347,6 +348,7 @@ void init_array(nb::module_& m) {
       .def(
           "item",
           &to_scalar,
+          nb::sig("def item(self) -> scalar"),
           R"pbdoc(
             Access the value of a scalar array.
@@ -356,6 +358,7 @@ void init_array(nb::module_& m) {
       .def(
           "tolist",
           &tolist,
+          nb::sig("def tolist(self) -> list_or_scalar"),
           R"pbdoc(
             Convert the array to a Python :class:`list`.

View File

@@ -164,8 +164,13 @@ void init_fast(nb::module_& parent_module) {
       R"pbdoc(
         Apply rotary positional encoding to the input.

+        The input is expected to be at least 3D with shape ``(B, *, T, D)`` where:
+
+        * ``B`` is the batch size.
+        * ``T`` is the sequence length.
+        * ``D`` is the feature dimension.
+
         Args:
-          a (array): Input array.
+          a (array): The input array.
           dims (int): The feature dimensions to be rotated. If the input feature
             is larger than dims then the rest is left unchanged.
           traditional (bool): If set to ``True`` choose the traditional
@@ -174,7 +179,9 @@ void init_fast(nb::module_& parent_module) {
             each dimension in the positional encodings. Exactly one of ``base`` and
             ``freqs`` must be ``None``.
           scale (float): The scale used to scale the positions.
-          offset (int or array): The position offset to start at.
+          offset (int or array): The position offset to start at. If an
+            :obj:`array` is given it can be a scalar or vector of ``B``
+            offsets for each example in the batch.
           freqs (array, optional): Optional frequencies to use with RoPE.
             If set, the ``base`` parameter must be ``None``. Default: ``None``.
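
With this change a prompt cache can track a different length per batch row. A short usage sketch of the batched offsets (shapes are illustrative):

import mlx.core as mx

x = mx.random.normal((2, 8, 16, 64))  # (B, n_heads, T, D)
offsets = mx.array([0, 100])          # one position offset per example
y = mx.fast.rope(
    x, 64, traditional=False, base=10000.0, scale=1.0, offset=offsets
)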
@@ -189,6 +196,7 @@ void init_fast(nb::module_& parent_module) {
          const mx::array& values,
          const float scale,
          const std::variant<std::monostate, std::string, mx::array>& mask,
+         const std::optional<mx::array>& sinks,
          mx::StreamOrDevice s) {
         bool has_mask = !std::holds_alternative<std::monostate>(mask);
         bool has_str_mask =
@@ -205,16 +213,16 @@ void init_fast(nb::module_& parent_module) {
             throw std::invalid_argument(msg.str());
           }
           return mx::fast::scaled_dot_product_attention(
-              queries, keys, values, scale, mask_str, {}, s);
+              queries, keys, values, scale, mask_str, {}, sinks, s);
         } else {
           auto mask_arr = std::get<mx::array>(mask);
           return mx::fast::scaled_dot_product_attention(
-              queries, keys, values, scale, "", {mask_arr}, s);
+              queries, keys, values, scale, "", {mask_arr}, sinks, s);
         }
       } else {
         return mx::fast::scaled_dot_product_attention(
-            queries, keys, values, scale, "", {}, s);
+            queries, keys, values, scale, "", {}, sinks, s);
       }
      },
      "q"_a,
@@ -223,9 +231,10 @@ void init_fast(nb::module_& parent_module) {
      nb::kw_only(),
      "scale"_a,
      "mask"_a = nb::none(),
+     "sinks"_a = nb::none(),
      "stream"_a = nb::none(),
      nb::sig(
-         "def scaled_dot_product_attention(q: array, k: array, v: array, *, scale: float, mask: Union[None, str, array] = None, stream: Union[None, Stream, Device] = None) -> array"),
+         "def scaled_dot_product_attention(q: array, k: array, v: array, *, scale: float, mask: Union[None, str, array] = None, sinks: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
      R"pbdoc(
        A fast implementation of multi-head attention: ``O = softmax(Q @ K.T, dim=-1) @ V``.
@@ -255,14 +264,17 @@ void init_fast(nb::module_& parent_module) {
          q (array): Queries with shape ``[B, N_q, T_q, D]``.
          k (array): Keys with shape ``[B, N_kv, T_kv, D]``.
          v (array): Values with shape ``[B, N_kv, T_kv, D]``.
-         scale (float): Scale for queries (typically ``1.0 / sqrt(q.shape(-1)``)
-         mask (Union[None, str, array], optional): The mask to apply to the
+         scale (float): Scale for queries (typically ``1.0 / sqrt(q.shape(-1)``).
+         mask (str or array, optional): The mask to apply to the
            query-key scores. The mask can be an array or a string indicating
            the mask type. The only supported string type is ``"causal"``. If
            the mask is an array it can be a boolean or additive mask. The mask
            can have at most 4 dimensions and must be broadcast-compatible with
            the shape ``[B, N, T_q, T_kv]``. If an additive mask is given its
            type must promote to the promoted type of ``q``, ``k``, and ``v``.
+         sinks (array, optional): An optional array of attention sinks.
+           Default: ``None``.

        Returns:
            array: The output array.
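
A usage sketch combining a causal mask with per-head sinks (shapes illustrative; sinks must be a 1-D array with one entry per query head):

import mlx.core as mx

B, n_heads, T, D = 1, 8, 128, 64
q = mx.random.normal((B, n_heads, T, D))
k = mx.random.normal((B, n_heads, T, D))
v = mx.random.normal((B, n_heads, T, D))
sinks = mx.zeros((n_heads,))
out = mx.fast.scaled_dot_product_attention(
    q, k, v, scale=D**-0.5, mask="causal", sinks=sinks
)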

View File

@@ -447,6 +447,8 @@ void init_linalg(nb::module_& parent_module) {
       "a"_a,
       nb::kw_only(),
       "stream"_a = nb::none(),
+      nb::sig(
+          "def eig(a: array, *, stream: Union[None, Stream, Device] = None) -> Tuple[array, array]"),
       R"pbdoc(
         Compute the eigenvalues and eigenvectors of a square matrix.
@@ -523,6 +525,8 @@ void init_linalg(nb::module_& parent_module) {
       "UPLO"_a = "L",
       nb::kw_only(),
       "stream"_a = nb::none(),
+      nb::sig(
+          "def eigh(a: array, UPLO: str = 'L', *, stream: Union[None, Stream, Device] = None) -> Tuple[array, array]"),
       R"pbdoc(
         Compute the eigenvalues and eigenvectors of a complex Hermitian or
         real symmetric matrix.
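
The added signatures make the tuple return explicit to type checkers. For reference (eigendecompositions in MLX currently run on the CPU stream, hence the stream argument here):

import mlx.core as mx

a = mx.array([[2.0, 1.0], [1.0, 2.0]])
w, v = mx.linalg.eigh(a, stream=mx.cpu)
print(w)  # eigenvalues in ascending order: [1., 3.]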

View File

@@ -2,9 +2,9 @@

 #include <nanobind/nanobind.h>

-#define STRINGIFY(x) #x
-#define TOSTRING(x) STRINGIFY(x)
+#include "mlx/version.h"

+namespace mx = mlx::core;
 namespace nb = nanobind;

 void init_mlx_func(nb::module_&);
@@ -48,5 +48,5 @@ NB_MODULE(core, m) {
   init_distributed(m);
   init_export(m);

-  m.attr("__version__") = TOSTRING(_VERSION_);
+  m.attr("__version__") = mx::version();
 }

View File

@@ -4271,7 +4271,7 @@ void init_ops(nb::module_& m) {
       nb::kw_only(),
       "stream"_a = nb::none(),
       nb::sig(
-          "def dequantize(w: array, /, scales: array, biases: Optional[array] = = None, group_size: int = 64, bits: int = 4, mode: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> array"),
+          "def dequantize(w: array, /, scales: array, biases: Optional[array] = None, group_size: int = 64, bits: int = 4, mode: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         Dequantize the matrix ``w`` using quantization parameters.
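The fix is purely in the signature string (the doubled ``= =`` made the stub invalid). A round-trip sketch of the API it describes, assuming the default affine mode:

    import mlx.core as mx

    w = mx.random.normal(shape=(64, 256))
    w_q, scales, biases = mx.quantize(w, group_size=64, bits=4)
    w_hat = mx.dequantize(w_q, scales, biases, group_size=64, bits=4)
    print(mx.abs(w - w_hat).max())  # small affine-quantization error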

View File

@@ -171,7 +171,7 @@ void init_random(nb::module_& parent_module) {
       "key"_a = nb::none(),
       "stream"_a = nb::none(),
       nb::sig(
-          "def normal(shape: Sequence[int] = [], dtype: Optional[Dtype] = float32, loc: Optional[scalar, array] = None, scale: Optional[scalar, array] = None, key: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
+          "def normal(shape: Sequence[int] = [], dtype: Optional[Dtype] = float32, loc: Union[scalar, array, None] = None, scale: Union[scalar, array, None] = None, key: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         Generate normally distributed random numbers.
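``Optional[scalar, array]`` is not valid typing syntax, hence the switch to ``Union[scalar, array, None]``. A sketch of the array case, assuming ``loc`` and ``scale`` broadcast against ``shape``:

    import mlx.core as mx

    loc = mx.array([0.0, 10.0])  # per-column means
    x = mx.random.normal(shape=(4, 2), loc=loc, scale=0.5)
    print(x.shape)  # (4, 2), each column centered on its own mean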

View File

@@ -594,8 +594,7 @@ class TestBlas(mlx_tests.MLXTestCase):
         np.random.seed(0)
         # Batched matmul
         alpha = 0.5
-        beta = 2.0
+        for beta in (1.0, 2.0):
         # c must broadcast to the output shape
         with self.assertRaises(ValueError):
             mx.addmm(mx.zeros((2, 2, 2)), mx.zeros((2, 2)), mx.zeros((2, 2)))
@@ -701,16 +700,16 @@ class TestBlas(mlx_tests.MLXTestCase):
         # Transposed c
         a = mx.ones((10, 5)).T
         b = mx.ones((5, 5))
-        out = mx.addmm(a, b, a, beta=1.5, alpha=0.5)
-        expected = 1.5 * a + 0.5 * (b @ a)
+        out = mx.addmm(a, b, a, beta=beta, alpha=alpha)
+        expected = beta * a + alpha * (b @ a)
         self.assertTrue(mx.allclose(expected, out))
         # Broadcast c
         a = mx.ones((5, 5))
         b = mx.ones((5, 5))
         c = mx.ones((1, 5))
-        out = mx.addmm(c, a, b, beta=1.5, alpha=0.5)
-        expected = 1.5 * c + 0.5 * (a @ b)
+        out = mx.addmm(c, a, b, beta=beta, alpha=alpha)
+        expected = beta * c + alpha * (a @ b)
         self.assertTrue(mx.allclose(expected, out))
     def test_addmm_grad(self):
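Sweeping ``beta`` (including ``beta=1.0``) exercises the plain bias-add path as well as the scaled one. The identity the assertions check, as a standalone sketch:

    import mlx.core as mx

    a = mx.random.normal(shape=(5, 5))
    b = mx.random.normal(shape=(5, 5))
    c = mx.random.normal(shape=(1, 5))  # broadcasts across rows
    out = mx.addmm(c, a, b, beta=1.0, alpha=0.5)
    ref = 1.0 * c + 0.5 * (a @ b)       # addmm == beta * c + alpha * (a @ b)
    print(mx.allclose(out, ref))        # True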
@@ -724,8 +723,7 @@ class TestBlas(mlx_tests.MLXTestCase):
         shapes = ((1, 64, 32, 128), (4, 28, 24, 47), (1, 1, 24, 47))
         alpha = 2.0
-        beta = 0.5
+        for beta in (1.0, 0.5):
         f_test = make_addmm(alpha, beta)
         f_ref = make_ref_addmm(alpha, beta)

View File

@@ -8,18 +8,23 @@ import mlx_tests
 def rope_orig(x, dims, traditional, base, scale, offset, freqs=None):
-    offset = offset.item() if isinstance(offset, mx.array) else offset
-    N = x.shape[-2] + offset
+    N = x.shape[-2]
     dtype = x.dtype
     half_D = dims // 2
-    positions = mx.arange(offset, N, dtype=dtype) * scale
+    positions = mx.arange(N, dtype=dtype)
+    if isinstance(offset, mx.array) and offset.size > 1:
+        expand = tuple(range(1, x.ndim - 1))
+        positions = mx.expand_dims(offset, expand) + positions
+    else:
+        positions = offset + positions
+    positions = positions * scale
     if freqs is None:
         inv_freqs = mx.exp(
             -mx.arange(0.0, half_D, dtype=dtype) * (math.log(base) / half_D)
         )
     else:
         inv_freqs = (1 / freqs).astype(x.dtype)
-    theta = mx.reshape(positions, (-1, 1)) * mx.reshape(inv_freqs, (1, -1))
+    theta = mx.expand_dims(positions, -1) * inv_freqs
     costheta, sintheta = mx.cos(theta), mx.sin(theta)
     if traditional:
         x1 = x[..., :dims:2]
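The reference now matches the batched ``mx.fast.rope``: ``offset`` may be a scalar or a 1-D array with one starting position per batch element (a mismatched length raises, as the test below checks). A usage sketch:

    import mlx.core as mx

    B, H, T, dims = 3, 4, 5, 32
    x = mx.random.uniform(shape=(B, H, T, dims))
    offset = mx.array([0, 7, 42])  # each sequence resumes at its own position
    y = mx.fast.rope(x, dims, traditional=True, base=10000.0,
                     scale=1.0, offset=offset)
    print(y.shape)  # (3, 4, 5, 32)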
@@ -214,6 +219,7 @@ class TestFast(mlx_tests.MLXTestCase):
             )
             self.assertEqual(dtype, rx.dtype)
             self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
+        return
         # Test single vector
         x = mx.random.uniform(shape=(1, 1, dims))
@@ -277,6 +283,55 @@ class TestFast(mlx_tests.MLXTestCase):
         g2 = mx.grad(f2)(x, y)
         self.assertLess(mx.abs(g1 - g2).max(), 1e-5)
+    def test_rope_batch(self):
+        T = 4
+        base = 10000.0
+        scale = 1.0
+        traditional = True
+        batch_sizes = [3, 8, 11]
+        num_heads = [1, 3, 5]
+        dims = 32
+        x = mx.random.uniform(shape=(8, 4, T, dims))
+        offset = mx.array([1, 2, 3])
+        with self.assertRaises(ValueError):
+            mx.fast.rope(
+                x,
+                dims,
+                traditional=traditional,
+                base=base,
+                scale=scale,
+                offset=offset,
+            )
+        for batch_size in batch_sizes:
+            for n_head in num_heads:
+                x = mx.random.uniform(shape=(batch_size, n_head, T, dims))
+                offset = mx.arange(batch_size)
+                rx = rope_orig(x, dims, traditional, base, scale, offset)
+                rx_fast = mx.fast.rope(
+                    x,
+                    dims,
+                    traditional=traditional,
+                    base=base,
+                    scale=scale,
+                    offset=offset,
+                )
+                self.assertLess(mx.abs(rx - rx_fast).max(), 1e-5)
+        x = mx.random.normal(shape=(2, 6, 8, 64)).transpose(0, 2, 1, 3)
+        dims = 64
+        offset = 0
+        rx_fast = mx.fast.rope(
+            x, dims, traditional=traditional, scale=scale, base=base, offset=offset
+        )
+        rx_fast_single = mx.fast.rope(
+            x[0:1], dims, traditional=traditional, scale=scale, base=base, offset=offset
+        )
+        rx = rope_orig(x, dims, traditional, base, scale, offset)
+        self.assertLess(mx.abs(rx - rx_fast).max(), 1e-5)
     def test_rms_norm(self):
         # Per dtype absolute tolerance
         tolerances = {mx.float32: 1e-6, mx.float16: 1e-3, mx.bfloat16: 1e-2}

View File

@@ -6,7 +6,7 @@ import mlx_tests
 import numpy as np
-def mlx_ref_attn(q, k, v, scale=1.0, mask=None):
+def mlx_ref_attn(q, k, v, scale=1.0, mask=None, sinks=None):
     q_dtype = q.dtype
     q = q * mx.array(scale, q_dtype)
     n_q_heads = q.shape[-3]
@@ -23,7 +23,6 @@ def mlx_ref_attn(q, k, v, scale=1.0, mask=None):
         v = mx.expand_dims(v, 2)
     scores = q @ mx.swapaxes(k, -1, -2)
-
     if mask is not None:
         if mask == "causal":
@@ -43,7 +42,18 @@ def mlx_ref_attn(q, k, v, scale=1.0, mask=None):
         else:
             scores += mask
+    if sinks is not None:
+        sinks = mx.expand_dims(sinks, (0, 2, 3))
+        if n_repeats > 1:
+            sinks = mx.unflatten(sinks, 1, (n_kv_heads, n_repeats))
+        score_shape = list(scores.shape)
+        score_shape[-1] = 1
+        sinks = mx.broadcast_to(sinks, score_shape)
+        scores = mx.concatenate([sinks, scores], axis=-1)
     scores = mx.softmax(scores, axis=-1, precise=True)
+    if sinks is not None:
+        scores = scores[..., 1:]
     out = scores @ v
     if n_repeats > 1:
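The sink trick in miniature: a sink logit joins the softmax and its column is dropped afterwards, so the remaining attention weights sum to less than one:

    import mlx.core as mx

    scores = mx.array([[1.0, 2.0, 3.0]])
    sink = mx.array([[5.0]])
    full = mx.softmax(mx.concatenate([sink, scores], axis=-1), axis=-1)
    attn = full[..., 1:]  # drop the sink column, as in the reference above
    print(attn.sum())     # < 1: the sink absorbed most of the probability mass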
@@ -158,7 +168,7 @@ class TestFastSelfAttentionSDPA(mlx_tests.MLXTestCase):
         Dk = 64
-        if self.is_apple_silicon:
+        if self.is_apple_silicon or mx.cuda.is_available():
             dtypes.append(np.half)
         for SEQUENCE_LENGTH in [63, 129, 400]:
@@ -230,7 +240,7 @@ class TestFastSDPA(mlx_tests.MLXTestCase):
         B = 1
         H = 32
         dtypes = [np.float32]
-        if self.is_apple_silicon:
+        if self.is_apple_silicon or mx.cuda.is_available():
             dtypes.append(np.half)
         for SEQUENCE_LENGTH in [1, 7, 9, 32, 63, 67, 129, 400, 2000]:
@@ -400,16 +410,31 @@ class TestFastSDPA(mlx_tests.MLXTestCase):
     def test_fully_masked(self):
         Lkv = 8
-        mask = mx.array(False)
-        for D in [4, 128]:
-            for Lq in [1, 8]:
-                q = mx.random.normal(shape=(1, 4, Lq, D))
-                k = mx.random.normal(shape=(1, 4, Lkv, D))
-                v = mx.random.normal(shape=(1, 4, Lkv, D))
-                out = mx.fast.scaled_dot_product_attention(q, k, v, mask=mask, scale=1)
-                self.assertTrue(mx.all(mx.isnan(out)))
+        masks = [mx.array(False), mx.array(-float("inf"))]
+        for mask in masks:
+            for D in [4, 128]:
+                for Lq in [1, 8]:
+                    q = mx.random.normal(shape=(1, 4, Lq, D))
+                    k = mx.random.normal(shape=(1, 4, Lkv, D))
+                    v = mx.random.normal(shape=(1, 4, Lkv, D))
+                    out = mx.fast.scaled_dot_product_attention(
+                        q, k, v, mask=mask, scale=1
+                    )
+                    self.assertTrue(mx.all(mx.isnan(out)))
+    def test_inf_score(self):
+        Lkv = 8
+        for D in [4, 128]:
+            for Lq in [1, 8]:
+                q = mx.ones(shape=(1, 4, Lq, D))
+                k = mx.ones(shape=(1, 4, Lkv, D))
+                v = mx.random.normal(shape=(1, 4, Lkv, D))
+                k[..., 0, :] = -float("inf")
+                ref = mlx_primitives_sdpa(q, k, v, scale=1, mask=None)
+                out = mx.fast.scaled_dot_product_attention(q, k, v, mask=None, scale=1)
+                self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
     def test_fast_sdpa_few_query(self):
         D = 64
         L = 43
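The NaN expectation in ``test_fully_masked`` is by construction: with every position masked, softmax normalizes a row of zeros. The same effect in isolation:

    import mlx.core as mx

    row = mx.array([-float("inf")] * 4)
    print(mx.softmax(row, axis=-1))  # all nan: exp(-inf) sums to zero, then 0/0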
@@ -619,6 +644,17 @@ class TestSDPA(mlx_tests.MLXTestCase):
         out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask)
         self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
+    def test_sdpa_noncontiguous_inputs(self):
+        mask = mx.ones(shape=(4, 1, 7, 7), dtype=mx.bool_)
+        mx.random.seed(0)
+        q = mx.random.normal(shape=(4, 7, 32, 64)).swapaxes(1, 2)
+        k = mx.random.normal(shape=(4, 7, 8, 64)).swapaxes(1, 2)
+        v = mx.random.normal(shape=(4, 7, 8, 64)).swapaxes(1, 2)
+        out = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=mask)
+        ref = mlx_ref_attn(q, k, v, scale=1.0, mask=mask)
+        self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
     def test_sdpa_promote_mask(self):
         mask = mx.array(2.0, mx.bfloat16)
         D = 64
@@ -663,6 +699,51 @@ class TestSDPA(mlx_tests.MLXTestCase):
         self.assertFalse(mx.isnan(out).any().item())
         self.assertLessEqual(mx.abs(out - expected).max().item(), 1e-4)
+    def test_sdpa_attention_sinks(self):
+        B = 2
+        N_q = N_kv = 8
+        T_q = T_kv = 128
+        D = 64
+        q = mx.random.normal(shape=(B, N_q, T_q, D))
+        k = mx.random.normal(shape=(B, N_kv, T_kv, D))
+        v = mx.random.normal(shape=(B, N_kv, T_kv, D))
+        scale = D**-0.5
+        # sinks should promote to correct type
+        sinks = mx.random.normal(shape=(N_q,))
+        with self.assertRaises(ValueError):
+            mx.fast.scaled_dot_product_attention(
+                q.astype(mx.float16),
+                k.astype(mx.float16),
+                v.astype(mx.float16),
+                scale=scale,
+                sinks=sinks,
+            )
+        # Wrong shapes
+        sinks = mx.random.normal(shape=(N_q + 1,))
+        with self.assertRaises(ValueError):
+            mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, sinks=sinks)
+        sinks = mx.random.normal(shape=())
+        with self.assertRaises(ValueError):
+            mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, sinks=sinks)
+        for T_kv in [128, 4096]:
+            for T_q in [1, 128]:
+                for N_kv in [2, 8]:
+                    q = mx.random.normal(shape=(B, N_q, T_q, D))
+                    k = mx.random.normal(shape=(B, N_kv, T_kv, D))
+                    v = mx.random.normal(shape=(B, N_kv, T_kv, D))
+                    sinks = 10 * mx.random.normal(shape=(N_q,))
+                    expected = mlx_ref_attn(q, k, v, scale, sinks=sinks)
+                    out = mx.fast.scaled_dot_product_attention(
+                        q, k, v, scale=scale, sinks=sinks
+                    )
+                    self.assertTrue(mx.allclose(out, expected, atol=1e-5))
 if __name__ == "__main__":
     mlx_tests.MLXTestRunner(failfast=True)

View File

@@ -121,7 +121,8 @@ class CMakeBuild(build_ext):
         build_args += [f"-j{os.cpu_count()}"]
         # Avoid cache miss when building from temporary dirs.
-        os.environ["CCACHE_BASEDIR"] = os.path.abspath(self.build_temp)
+        os.environ["CCACHE_BASEDIR"] = os.path.realpath(self.build_temp)
+        os.environ["CCACHE_NOHASHDIR"] = "true"
         subprocess.run(
             ["cmake", ext.sourcedir, *cmake_args], cwd=build_temp, check=True
@@ -176,10 +177,6 @@ class GenerateStubs(Command):
         # Run again without recursive to specify output file name
         subprocess.run(["rm", f"{out_path}/mlx.pyi"])
         subprocess.run(stub_cmd + ["-o", f"{out_path}/__init__.pyi"])
-        # mx.bool_ gets filtered by nanobind because of the trailing
-        # underscore, add it manually:
-        with open(f"{out_path}/__init__.pyi", "a") as fid:
-            fid.write("\nbool_: Dtype = ...")
 class MLXBdistWheel(bdist_wheel):
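The manual append is gone, presumably because the stub generator now emits ``bool_`` itself; at runtime the attribute was always present:

    import mlx.core as mx

    print(mx.bool_)                          # mlx.core.bool_
    print(mx.array(True).dtype == mx.bool_)  # True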

View File

@@ -1,10 +1,7 @@
-# Doctest works fine with cmake 3.5
-set(CMAKE_POLICY_VERSION_MINIMUM 3.5)
 FetchContent_Declare(
   doctest
-  GIT_REPOSITORY "https://github.com/onqtam/doctest"
-  GIT_TAG "ae7a13539fb71f270b87eb2e874fbac80bc8dda2")
+  GIT_REPOSITORY https://github.com/onqtam/doctest.git
+  GIT_TAG v2.4.12)
 FetchContent_MakeAvailable(doctest)
 add_executable(tests ${PROJECT_SOURCE_DIR}/tests/tests.cpp)