Mirror of https://github.com/ml-explore/mlx.git
Synced 2025-09-12 23:34:36 +08:00

Compare commits

19 Commits
ee18e1cbf0
af120c2bc0
6a3acf2301
d6977f2a57
db5443e831
52b8384d10
44cc5da4bc
dde3682b69
17310d91a6
b194d65a6a
a44b27f5f8
e5a33f2223
c1e3340b23
8f163a367d
89a3df9014
c5d2937aa5
b61a65e313
04cbb4191c
c5460762e7
@@ -230,6 +230,9 @@ jobs:
            sudo mv ccache-4.11.3-linux-x86_64/ccache /usr/bin/ccache
            rm -rf ccache-4.11.3-linux-x86_64
            curl -LsSf https://astral.sh/uv/install.sh | sh
      - run:
          name: Set CCache size
          command: ccache --max-size 1G
      - run:
          name: Install Python package
          command: |
@@ -260,7 +263,6 @@ jobs:
          command: |
            ccache --show-stats
            ccache --zero-stats
            ccache --max-size 400MB
            ccache --cleanup
      - save_cache:
          key: cuda-<< parameters.image_date >>-{{ arch }}-{{ epoch }}
@@ -19,7 +19,7 @@ MLX was developed with contributions from the following individuals:

- Gleb Pobudzey: Added the `where` primitive, and groups in 1D and 2D convolutions.
- Paul Paczuski: Improved stability of BCE loss calculation
- Max-Heinrich Laves: Added `conv_transpose1d`, `conv_transpose2d`, and `conv_transpose3d` ops.
- Gökdeniz Gülmez: Added the `Muon (MomentUm Orthogonalized by Newton-schulz)` optimizer.
- Gökdeniz Gülmez: Added the `Muon (MomentUm Orthogonalized by Newton-schulz)` optimizer, and the `ReLU²` activation function.

<a href="https://github.com/ml-explore/mlx/graphs/contributors">
  <img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />
@@ -87,22 +87,21 @@ cmake_policy(SET CMP0135 NEW)

add_library(mlx)

if(MLX_BUILD_METAL)
  set(METAL_LIB "-framework Metal")
  set(FOUNDATION_LIB "-framework Foundation")
  set(QUARTZ_LIB "-framework QuartzCore")
endif()

if(MLX_BUILD_CUDA)
  enable_language(CUDA)
endif()

if(MLX_BUILD_METAL AND NOT METAL_LIB)
  message(STATUS "Metal not found. Unable to build GPU")
  set(MLX_BUILD_METAL OFF)
  set(MLX_METAL_DEBUG OFF)
elseif(MLX_BUILD_METAL)
  message(STATUS "Building METAL sources")
if(MLX_BUILD_METAL)
  find_library(METAL_LIB Metal)
  find_library(FOUNDATION_LIB Foundation)
  find_library(QUARTZ_LIB QuartzCore)
  if(METAL_LIB)
    message(STATUS "Metal found ${METAL_LIB}")
  else()
    message(
      FATAL_ERROR
      "Metal not found. Set MLX_BUILD_METAL=OFF to build without GPU")
  endif()

  if(MLX_METAL_DEBUG)
    add_compile_definitions(MLX_METAL_DEBUG)
@@ -111,7 +110,8 @@ elseif(MLX_BUILD_METAL)
  # Throw an error if xcrun not found
  execute_process(
    COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version"
    OUTPUT_VARIABLE MACOS_SDK_VERSION COMMAND_ERROR_IS_FATAL ANY)
    OUTPUT_VARIABLE MACOS_SDK_VERSION
    OUTPUT_STRIP_TRAILING_WHITESPACE COMMAND_ERROR_IS_FATAL ANY)

  if(${MACOS_SDK_VERSION} LESS 14.0)
    message(
@@ -27,6 +27,7 @@ simple functions.

   mish
   prelu
   relu
   relu2
   relu6
   selu
   sigmoid
@@ -50,6 +50,7 @@ Layers

   QuantizedLinear
   RMSNorm
   ReLU
   ReLU2
   ReLU6
   RNN
   RoPE
@@ -107,8 +107,20 @@ same array:

   >>> a
   array([1, 2, 0], dtype=int32)

Note that unlike NumPy, slicing an array creates a copy, not a view. So
mutating it does not mutate the original array:

Note, unlike NumPy, updates to the same location are nondeterministic:

.. code-block:: shell

   >>> a = mx.array([1, 2, 3])
   >>> b = a[:]
   >>> b[2] = 0
   >>> b
   array([1, 2, 0], dtype=int32)
   >>> a
   array([1, 2, 3], dtype=int32)

Also unlike NumPy, updates to the same location are nondeterministic:

.. code-block:: shell
@@ -85,10 +85,10 @@ cublasLtMatrixLayout_t create_matrix_layout(
    int32_t batch_count,
    int64_t batch_stride) {
  cublasLtMatrixLayout_t desc;
  if (transposed) {
    std::swap(rows, cols);
  }
  CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate(&desc, type, rows, cols, ld));
  cublasLtOrder_t order = transposed ? CUBLASLT_ORDER_COL : CUBLASLT_ORDER_ROW;
  CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
      desc, CUBLASLT_MATRIX_LAYOUT_ORDER, &order, sizeof(cublasLtOrder_t)));
  if (batch_count > 1) {
    CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
        desc,
@@ -138,25 +138,34 @@ CublasGemm::CublasGemm(
      CUBLASLT_MATMUL_DESC_POINTER_MODE,
      &pointer_mode,
      sizeof(int32_t)));
  cublasOperation_t op = CUBLAS_OP_N;

  // In cublasLt, matrices use column-major layout. While it is possible to use
  // the CUBLASLT_ORDER_ROW option to switch to row-major layout, the bias
  // epilogue does not work with that option. So instead we swap A and B to make
  // cublasLt return the row-major result, which works because:
  // - the data of a matrix in row-major layout is identical to its transpose in
  //   column-major layout
  // - C^T = (A @ B)^T = B^T @ A^T
  cublasOperation_t a_op = b_transposed ? CUBLAS_OP_T : CUBLAS_OP_N;
  CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
      matmul_desc_,
      CUBLASLT_MATMUL_DESC_TRANSA,
      &op,
      &a_op,
      sizeof(cublasOperation_t)));
  cublasOperation_t b_op = a_transposed ? CUBLAS_OP_T : CUBLAS_OP_N;
  CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
      matmul_desc_,
      CUBLASLT_MATMUL_DESC_TRANSB,
      &op,
      &b_op,
      sizeof(cublasOperation_t)));

  auto type = dtype_to_cublas_type(dtype);
  a_desc_ = create_matrix_layout(
      type, a_rows, a_cols, a_transposed, lda, batch_count, a_batch_stride);
      type, b_cols, b_rows, b_transposed, ldb, batch_count, b_batch_stride);
  b_desc_ = create_matrix_layout(
      type, b_rows, b_cols, b_transposed, ldb, batch_count, b_batch_stride);
      type, a_cols, a_rows, a_transposed, lda, batch_count, a_batch_stride);
  out_desc_ = create_matrix_layout(
      type, a_rows, b_cols, false, b_cols, batch_count, a_rows * b_cols);
      type, b_cols, a_rows, false, b_cols, batch_count, a_rows * b_cols);
}
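The comment block above carries the whole trick of this change: cublasLt is column-major, and its bias epilogue rules out `CUBLASLT_ORDER_ROW`, so the code requests B^T @ A^T and reads the resulting buffer as the row-major C. A minimal standalone sketch of the layout identity it relies on (plain CPU C++ with made-up sizes, not MLX or cuBLAS code):

```cpp
#include <cstdio>

int main() {
  // C = A @ B with every buffer in row-major order.
  const int M = 2, K = 3, N = 2;
  float A[M * K] = {1, 2, 3, 4, 5, 6};     // row-major M x K
  float B[K * N] = {7, 8, 9, 10, 11, 12};  // row-major K x N
  float C[M * N] = {0};

  for (int i = 0; i < M; ++i)
    for (int j = 0; j < N; ++j)
      for (int k = 0; k < K; ++k)
        C[i * N + j] += A[i * K + k] * B[k * N + j];

  // Reading the same bytes as a column-major N x M matrix gives C^T:
  // element (i, j) of that matrix lives at C[j * N + i] = C[j][i].
  for (int i = 0; i < N; ++i) {
    for (int j = 0; j < M; ++j)
      std::printf("%6.1f ", C[j * N + i]);
    std::printf("\n");
  }
  return 0;
}
```

So a column-major library asked for B^T @ A^T writes, byte for byte, the row-major A @ B; that is why only the operand order and the transpose flags change here, never the data.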
CublasGemm::CublasGemm(
@@ -191,7 +200,7 @@ CublasGemm::CublasGemm(
        b_batch_stride) {
  auto type = dtype_to_cublas_type(dtype);
  c_desc_ = create_matrix_layout(
      type, a_rows, b_cols, false, ldc, batch_count, c_batch_stride);
      type, b_cols, a_rows, false, ldc, batch_count, c_batch_stride);
}

CublasGemm::~CublasGemm() {
@@ -213,14 +222,30 @@ void CublasGemm::set_out(
  CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(out_desc_));
  out_desc_ = create_matrix_layout(
      dtype_to_cublas_type(dtype),
      rows,
      cols,
      rows,
      transposed,
      ld,
      batch_count,
      batch_stride);
}

void CublasGemm::set_bias(cu::CommandEncoder& encoder, const array& bias) {
  encoder.set_input_array(bias);
  cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_BIAS;
  CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
      matmul_desc_,
      CUBLASLT_MATMUL_DESC_EPILOGUE,
      &epilogue,
      sizeof(epilogue)));
  auto* bias_ptr = bias.data<void>();
  CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
      matmul_desc_,
      CUBLASLT_MATMUL_DESC_BIAS_POINTER,
      &bias_ptr,
      sizeof(bias_ptr)));
}
void CublasGemm::run(
    cu::CommandEncoder& encoder,
    array& out,
@@ -228,11 +253,19 @@ void CublasGemm::run(
    const array& b,
    const Shape& batch_shape,
    const Strides& a_batch_strides,
    const Strides& b_batch_strides) {
    const Strides& b_batch_strides,
    float alpha) {
  int batch_count = out.size() / (M_ * N_);
  if (batch_count / batch_shape.back() > 1) {
    run_batched(
        encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides);
        encoder,
        out,
        a,
        b,
        batch_shape,
        a_batch_strides,
        b_batch_strides,
        alpha);
    return;
  }

@@ -240,7 +273,13 @@ void CublasGemm::run(
  encoder.set_input_array(b);
  encoder.set_output_array(out);

  execute(encoder, out.data<void>(), a.data<void>(), b.data<void>(), nullptr);
  execute(
      encoder,
      out.data<void>(),
      a.data<void>(),
      b.data<void>(),
      nullptr,
      alpha);
}

void CublasGemm::run(
@@ -330,9 +369,9 @@ void CublasGemm::execute(
      handle_,
      matmul_desc_,
      &alpha,
      a,
      b, // a and b are swapped
      a_desc_,
      b,
      a,
      b_desc_,
      &beta,
      c ? c : out,
@@ -55,6 +55,8 @@ class CublasGemm {
      int32_t batch_count,
      int64_t batch_stride);

  void set_bias(cu::CommandEncoder& encoder, const array& bias);

  void run(
      cu::CommandEncoder& encoder,
      array& out,
@@ -62,7 +64,8 @@ class CublasGemm {
      const array& b,
      const Shape& batch_shape,
      const Strides& a_batch_strides,
      const Strides& b_batch_strides);
      const Strides& b_batch_strides,
      float alpha = 1.0f);

  void run(
      cu::CommandEncoder& encoder,
@@ -85,7 +88,8 @@ class CublasGemm {
      const array& b,
      const Shape& batch_shape,
      const Strides& a_batch_strides,
      const Strides& b_batch_strides);
      const Strides& b_batch_strides,
      float alpha);

  void run_batched(
      cu::CommandEncoder& encoder,
@@ -13,7 +13,8 @@ void CublasGemm::run_batched(
    const array& b,
    const Shape& batch_shape,
    const Strides& a_batch_strides,
    const Strides& b_batch_strides) {
    const Strides& b_batch_strides,
    float alpha) {
  encoder.set_input_array(a);
  encoder.set_input_array(b);
  encoder.set_output_array(out);
@@ -27,7 +28,8 @@ void CublasGemm::run_batched(
        out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M_ * N_,
        a.data<int8_t>() + a.itemsize() * a_it.loc,
        b.data<int8_t>() + b.itemsize() * b_it.loc,
        nullptr);
        nullptr,
        alpha);
    a_it.step();
    b_it.step();
  }
@@ -154,7 +154,8 @@ void CublasGemm::run_batched(
    const array& b,
    const Shape& batch_shape,
    const Strides& a_batch_strides,
    const Strides& b_batch_strides) {
    const Strides& b_batch_strides,
    float alpha) {
  int batch_count = out.size() / (M_ * N_);
  set_pointer_mode(a_desc_, batch_count);
  set_pointer_mode(b_desc_, batch_count);
@@ -226,7 +227,8 @@ void CublasGemm::run_batched(
      reinterpret_cast<void*>(out_pointers),
      reinterpret_cast<void*>(a_pointers),
      reinterpret_cast<void*>(b_pointers),
      nullptr);
      nullptr,
      alpha);
}

void CublasGemm::run_batched(
@@ -11,6 +11,7 @@
#include <numeric>

namespace mlx::core {

namespace {

std::tuple<bool, int64_t, array>
@@ -28,6 +29,76 @@ check_transpose(cu::CommandEncoder& enc, const Stream& s, const array& arr) {
  }
}
void gemm_and_bias(
    cu::CommandEncoder& encoder,
    int M,
    int N,
    int K,
    bool a_transposed,
    int64_t lda,
    bool b_transposed,
    int64_t ldb,
    array& out,
    const array& a,
    const array& b,
    const std::optional<array>& bias = std::nullopt,
    float alpha = 1.0f) {
  // Check and collapse batch dimensions
  auto [batch_shape, a_batch_strides, b_batch_strides] = collapse_batches(a, b);

  auto batch_count = out.size() / (M * N);

  // Collapse batches into M if needed
  if (batch_count > 1 && !a_transposed && batch_shape.size() == 1 &&
      a.strides()[a.ndim() - 2] == K && a_batch_strides.back() == M * K &&
      b_batch_strides.back() == 0) {
    M *= batch_shape.back();
    batch_count = 1;

    a_batch_strides = {0};
    b_batch_strides = {0};
    batch_shape = {1};
  }

  // Use gemv when possible
  if (!bias && cu::can_use_gemv(M, N, K, a_transposed, b_transposed)) {
    cu::gemv(
        a,
        b,
        out,
        M,
        N,
        K,
        batch_count,
        batch_shape,
        a_batch_strides,
        b_batch_strides,
        encoder);
    return;
  }

  // Invoke cublasLt
  CublasGemm gemm(
      encoder.device(),
      a.dtype(),
      a_transposed,
      M,
      K,
      lda,
      b_transposed,
      K,
      N,
      ldb,
      batch_shape.back(),
      a_batch_strides.back(),
      b_batch_strides.back());
  if (bias) {
    gemm.set_bias(encoder, *bias);
  }
  gemm.run(
      encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides, alpha);
}

} // namespace
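A note on the "collapse batches into M" branch above: when `a` is a contiguous stack of row-major (M, K) matrices and `b` is shared across the batch (batch stride 0), the batched product touches exactly the memory of one tall (B*M, K) @ (K, N) GEMM, so the batch loop can be dropped. A rough sketch of that shape argument, with hypothetical sizes (not MLX code):

```cpp
#include <cassert>

int main() {
  // Hypothetical batched problem: a is (B, M, K), row-major and contiguous;
  // b is one (K, N) matrix broadcast to every batch element.
  const int B = 8, M = 16, K = 32;
  const long a_batch_stride = static_cast<long>(M) * K;  // contiguous stack
  const long b_batch_stride = 0;                         // broadcast b

  // The same conditions the dispatch above checks.
  bool collapsible =
      a_batch_stride == static_cast<long>(M) * K && b_batch_stride == 0;
  assert(collapsible);

  // One (B*M, K) GEMM covers the same rows as B separate (M, K) GEMMs.
  const int collapsed_M = B * M;
  assert(static_cast<long>(collapsed_M) * K == B * a_batch_stride);
  return 0;
}
```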
void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -48,9 +119,6 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {

  out.set_data(allocator::malloc(out.nbytes()));

  /////////////////////////////////////////////////////////////////////////////
  // Init checks and prep

  int M = a_pre.shape(-2);
  int N = b_pre.shape(-1);
  int K = a_pre.shape(-1);
@@ -60,58 +128,8 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
  auto [a_transposed, lda, a] = check_transpose(encoder, s, a_pre);
  auto [b_transposed, ldb, b] = check_transpose(encoder, s, b_pre);

  /////////////////////////////////////////////////////////////////////////////
  // Check and collapse batch dimensions

  auto [batch_shape, a_batch_strides, b_batch_strides] = collapse_batches(a, b);

  auto batch_count = out.size() / (M * N);

  // Collapse batches into M if needed
  if (batch_count > 1 && !a_transposed && batch_shape.size() == 1 &&
      a.strides()[a.ndim() - 2] == K && a_batch_strides.back() == M * K &&
      b_batch_strides.back() == 0) {
    M *= batch_shape.back();
    batch_count = 1;

    a_batch_strides = {0};
    b_batch_strides = {0};
    batch_shape = {1};
  }

  if (cu::can_use_gemv(M, N, K, a_transposed, b_transposed)) {
    cu::gemv(
        a,
        b,
        out,
        M,
        N,
        K,
        batch_count,
        batch_shape,
        a_batch_strides,
        b_batch_strides,
        encoder);
    return;
  }

  /////////////////////////////////////////////////////////////////////////////
  // Invoke cublasLt
  CublasGemm gemm(
      cu::device(s.device),
      a.dtype(),
      a_transposed,
      M,
      K,
      lda,
      b_transposed,
      K,
      N,
      ldb,
      batch_shape.back(),
      a_batch_strides.back(),
      b_batch_strides.back());
  gemm.run(encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides);
  gemm_and_bias(
      encoder, M, N, K, a_transposed, lda, b_transposed, ldb, out, a, b);
}

void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -136,6 +154,28 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
  auto [a_transposed, lda, a] = check_transpose(encoder, s, a_pre);
  auto [b_transposed, ldb, b] = check_transpose(encoder, s, b_pre);

  /////////////////////////////////////////////////////////////////////////////
  // Dispatch to GEMM with epilogue or AddMM

  if (beta_ == 1 && c.strides(-1) == 1 && c.data_size() == out.shape(-1)) {
    out.set_data(allocator::malloc(out.nbytes()));
    gemm_and_bias(
        encoder,
        M,
        N,
        K,
        a_transposed,
        lda,
        b_transposed,
        ldb,
        out,
        a,
        b,
        c,
        alpha_);
    return;
  }

  int64_t ldc;
  {
    auto stx = c.strides()[c.ndim() - 2];
@@ -177,7 +217,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
  }

  /////////////////////////////////////////////////////////////////////////////
  // Invoke cublasLt
  // Invoke cublasLt with AddMM settings

  CublasGemm gemm(
      cu::device(s.device),
@@ -103,15 +103,21 @@ template <typename T, bool traditional, bool forward, int N = 4>
__device__ void rope_impl(
    const T* in,
    T* out,
    int offset,
    const int* offset,
    float inv_freq,
    float scale,
    const cuda::std::array<int64_t, 3> strides,
    const cuda::std::array<int64_t, 3> out_strides,
    int64_t n_batch,
    int64_t offset_stride,
    int n_head,
    uint3 pos,
    uint3 dims) {
  float L = scale * static_cast<float>(pos.y + offset);
  auto n_head_up = N * ((n_head + N - 1) / N);
  auto head_idx = static_cast<int>((pos.z * N) % n_head_up);
  auto batch_idx = (pos.z * N) / n_head_up;
  auto batch_offset = offset[batch_idx * offset_stride];
  float L = scale * static_cast<float>(pos.y + batch_offset);
  auto mat_idx = batch_idx * n_head + head_idx;

  // Compute costheta, sintheta
  float theta = L * inv_freq;
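The new index math above packs (batch, head) pairs into `pos.z`: the head count is rounded up to a multiple of N so that a thread's group of N heads never straddles a batch boundary, which is what makes a single per-batch `offset` lookup valid for the whole group. A small host-side C++ sketch of the same decomposition, with hypothetical sizes:

```cpp
#include <cassert>

int main() {
  const int N = 4;       // heads handled per thread, as in the kernel
  const int n_head = 6;  // deliberately not a multiple of N
  const int n_batch = 3;
  const int n_head_up = N * ((n_head + N - 1) / N);  // 8: rounded up to N

  // pos.z enumerates groups of N head slots, padded per batch.
  for (int z = 0; z < n_batch * n_head_up / N; ++z) {
    int head_idx = (z * N) % n_head_up;
    int batch_idx = (z * N) / n_head_up;
    assert(batch_idx < n_batch);
    assert(head_idx + N <= n_head_up);  // group stays inside one batch
    // The kernel's loop guard (head_idx + i < n_head) makes the padded
    // slots with head_idx + i >= n_head do no work.
  }
  return 0;
}
```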
@@ -123,20 +129,19 @@ __device__ void rope_impl(
  size_t out_index_1, out_index_2;
  if (traditional) {
    out_index_1 = 2 * pos.x * out_strides[2] + pos.y * out_strides[1] +
        N * pos.z * out_strides[0];
        mat_idx * out_strides[0];
    out_index_2 = out_index_1 + 1;
    in_index_1 =
        2 * pos.x * strides[2] + pos.y * strides[1] + N * pos.z * strides[0];
        2 * pos.x * strides[2] + pos.y * strides[1] + mat_idx * strides[0];
    in_index_2 = in_index_1 + strides[2];
  } else {
    out_index_1 = pos.x * out_strides[2] + pos.y * out_strides[1] +
        N * pos.z * out_strides[0];
        mat_idx * out_strides[0];
    out_index_2 = out_index_1 + dims.x * out_strides[2];
    in_index_1 =
        pos.x * strides[2] + pos.y * strides[1] + N * pos.z * strides[0];
    in_index_1 = pos.x * strides[2] + pos.y * strides[1] + mat_idx * strides[0];
    in_index_2 = in_index_1 + dims.x * strides[2];
  }
  for (int i = 0; i < N && pos.z * N + i < n_batch; ++i) {
  for (int i = 0; i < N && head_idx + i < n_head; ++i) {
    // Read and write the output
    float x1 = static_cast<float>(in[in_index_1]);
    float x2 = static_cast<float>(in[in_index_2]);
@@ -167,7 +172,8 @@ __global__ void rope(
    float base,
    const __grid_constant__ cuda::std::array<int64_t, 3> strides,
    const __grid_constant__ cuda::std::array<int64_t, 3> out_strides,
    int64_t n_batch,
    int64_t offset_stride,
    int n_head,
    uint3 dims) {
  uint3 pos = make_uint3(
      blockIdx.x * blockDim.x + threadIdx.x,
@@ -182,12 +188,13 @@ __global__ void rope(
  rope_impl<T, traditional, forward>(
      in,
      out,
      *offset,
      offset,
      inv_freq,
      scale,
      strides,
      out_strides,
      n_batch,
      offset_stride,
      n_head,
      pos,
      dims);
}
@@ -202,7 +209,8 @@ __global__ void rope_freqs(
    float base,
    const __grid_constant__ cuda::std::array<int64_t, 3> strides,
    const __grid_constant__ cuda::std::array<int64_t, 3> out_strides,
    int64_t n_batch,
    int64_t offset_stride,
    int n_head,
    uint3 dims,
    int64_t freq_stride) {
  uint3 pos = make_uint3(
@@ -217,12 +225,13 @@ __global__ void rope_freqs(
  rope_impl<T, traditional, forward>(
      in,
      out,
      *offset,
      offset,
      inv_freq,
      scale,
      strides,
      out_strides,
      n_batch,
      offset_stride,
      n_head,
      pos,
      dims);
}
@@ -245,23 +254,28 @@ void RoPE::eval_gpu(
  auto& offset = inputs[1];
  auto& out = outputs[0];

  if (in.ndim() < 3) {
    throw std::runtime_error("[RoPE] Input must have at least 3 dimensions");
  }

  cuda::std::array<int64_t, 3> strides;
  cuda::std::array<int64_t, 3> out_strides;
  bool donated = false;
  int ndim = in.ndim();
  int dispatch_ndim = in.ndim();

  int B = in.shape(0);
  int T = in.shape(-2);
  int D = in.shape(-1);
  size_t mat_size = T * D;
  int dispatch_ndim = ndim;
  while (in.shape(-dispatch_ndim) == 1 && dispatch_ndim > 3) {
    dispatch_ndim--;
  }
  size_t mat_size = in.shape(-2) * in.shape(-1);

  int N = 1;
  for (int i = 1; i < (ndim - 2); ++i) {
    N *= in.shape(i);
  }

  // We apply rope to less than the whole vector so copy to output and then
  // apply in-place.
  if (dims_ < in.shape(-1)) {
  if (dims_ < D) {
    donated = true;
    auto ctype =
        (in.flags().row_contiguous) ? CopyType::Vector : CopyType::General;
@@ -302,7 +316,7 @@ void RoPE::eval_gpu(
  out_strides[2] = out.strides()[ndim - 1];

  // Some flags to help us dispatch below
  bool single = in.flags().row_contiguous && (mat_size == in.shape(-1));
  bool single = in.flags().row_contiguous && B == 1 && T == 1;
  bool with_freqs = inputs.size() == 3;

  auto& encoder = cu::get_command_encoder(s);
@@ -319,7 +333,7 @@ void RoPE::eval_gpu(
    if (single && !with_freqs) {
      auto kernel =
          cu::rope_single<DataType, traditional.value, forward.value>;
      uint2 dims = make_uint2(dims_ / 2, in.size() / mat_size);
      uint2 dims = make_uint2(dims_ / 2, N);
      auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1);
      encoder.add_kernel_node(
          kernel,
@@ -336,7 +350,7 @@ void RoPE::eval_gpu(
    } else if (single) {
      auto kernel =
          cu::rope_single_freqs<DataType, traditional.value, forward.value>;
      uint2 dims = make_uint2(dims_ / 2, in.size() / mat_size);
      uint2 dims = make_uint2(dims_ / 2, N);
      auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1);
      encoder.add_kernel_node(
          kernel,
@@ -354,10 +368,14 @@ void RoPE::eval_gpu(
    } else if (with_freqs) {
      auto kernel =
          cu::rope_freqs<DataType, traditional.value, forward.value>;
      uint3 dims =
          make_uint3(dims_ / 2, in.shape(-2), in.size() / mat_size);
      dims.z = (dims.z + 3) / 4;
      int n_per_thread = 4;
      uint32_t dimz = B * ((N + n_per_thread - 1) / n_per_thread);
      uint3 dims = make_uint3(dims_ / 2, T, dimz);
      auto [grid, block] = get_grid_and_block(dims.x, dims.y, dims.z);
      int64_t offset_stride = 0;
      if (inputs[1].ndim() > 0) {
        offset_stride = inputs[1].strides()[0];
      }
      encoder.add_kernel_node(
          kernel,
          grid,
@@ -371,15 +389,20 @@ void RoPE::eval_gpu(
          std::log2(base_),
          strides,
          out_strides,
          in.size() / mat_size,
          offset_stride,
          N,
          dims,
          inputs[2].strides(0));
    } else {
      auto kernel = cu::rope<DataType, traditional.value, forward.value>;
      uint3 dims =
          make_uint3(dims_ / 2, in.shape(-2), in.size() / mat_size);
      dims.z = (dims.z + 3) / 4;
      int n_per_thread = 4;
      uint32_t dimz = B * ((N + n_per_thread - 1) / n_per_thread);
      uint3 dims = make_uint3(dims_ / 2, T, dimz);
      auto [grid, block] = get_grid_and_block(dims.x, dims.y, dims.z);
      int64_t offset_stride = 0;
      if (inputs[1].ndim() > 0) {
        offset_stride = inputs[1].strides()[0];
      }
      encoder.add_kernel_node(
          kernel,
          grid,
@@ -392,7 +415,8 @@ void RoPE::eval_gpu(
          std::log2(base_),
          strides,
          out_strides,
          in.size() / mat_size,
          offset_stride,
          N,
          dims);
    }
  });
@@ -46,6 +46,7 @@ __global__ void kernel_sdpav_1pass(
    const T* K,
    const T* V,
    T* O,
    const T* sinks,
    __grid_constant__ const AttnParams params) {
  constexpr int BN = 32;
  constexpr int BD = 32;
@@ -65,7 +66,7 @@ __global__ void kernel_sdpav_1pass(
  __shared__ U max_scores[BN];
  __shared__ U sum_exp_scores[BN];

  const U scale_log2 = params.scale * 1.44269504089f;
  const U scale_log2 = params.scale * M_LOG2E;

  auto block = cg::this_thread_block();
  auto warp = cg::tiled_partition<32>(block);
@@ -110,6 +111,10 @@ __global__ void kernel_sdpav_1pass(

  U max_score = -INFINITY;
  U sum_exp_score = 0.f;
  if (sinks && warp_idx == 0) {
    max_score = M_LOG2E * static_cast<U>(sinks[head_idx]);
    sum_exp_score = 1.f;
  }

  // For each key
  for (int i = kv_seq_idx; i < params.kL; i += BN) {
@@ -137,8 +142,9 @@

      // Update the accumulators
      U new_max = max(max_score, score);
      U factor = exp2f(max_score - new_max);
      U exp_score = exp2f(score - new_max);
      bool is_neg_inf = new_max == -INFINITY;
      U factor = is_neg_inf ? 1 : exp2f(max_score - new_max);
      U exp_score = is_neg_inf ? 0 : exp2f(score - new_max);

      max_score = new_max;
      sum_exp_score = sum_exp_score * factor + exp_score;
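This guarded update is the heart of the streaming (online) softmax: key blocks arrive one at a time, and the running maximum and exp-sum are rescaled whenever the maximum grows. The math stays in the exp2 domain (hence `scale * M_LOG2E` above, since e^x = 2^(x·log2 e) and exp2f is cheap on GPUs), and the new `-INFINITY` guard avoids `exp2f(-inf - -inf) = exp2f(NaN)` poisoning the accumulators when every score so far is masked out. The sink initialization earlier in the kernel is just one such update applied up front (max = sink, sum = 1). A standalone CPU sketch of the update rule:

```cpp
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  // Scores already scaled by log2(e); -inf entries model masked-out keys.
  std::vector<float> scores = {-INFINITY, -INFINITY, 1.5f, 0.25f, 2.0f};

  float max_score = -INFINITY;
  float sum_exp = 0.0f;
  for (float score : scores) {
    float new_max = std::max(max_score, score);
    bool is_neg_inf = new_max == -INFINITY;
    // Without the guard, (-inf) - (-inf) = NaN and the sum is poisoned.
    float factor = is_neg_inf ? 1.0f : std::exp2(max_score - new_max);
    float exp_score = is_neg_inf ? 0.0f : std::exp2(score - new_max);
    max_score = new_max;
    sum_exp = sum_exp * factor + exp_score;
  }
  // softmax(s_i) = exp2(s_i - max_score) / sum_exp, identical to the
  // exp-domain result because of the log2(e) pre-scaling.
  std::printf("max = %f, sum = %f\n", max_score, sum_exp);
  return 0;
}
```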
@@ -193,6 +199,7 @@ __global__ void kernel_sdpav_2pass_1(
    const T* Q,
    const T* K,
    const T* V,
    const T* sinks,
    float* partials,
    float* sums,
    float* maxs,
@@ -268,8 +275,12 @@ __global__ void kernel_sdpav_2pass_1(
    o[i] = 0.f;
  }

  U max_score = -1e9;
  U max_score = -INFINITY;
  U sum_exp_score = 0.f;
  if (sinks && warp_idx == 0 && block_idx == 0) {
    max_score = M_LOG2E * static_cast<U>(sinks[head_idx]);
    sum_exp_score = 1.f;
  }

  // For each key
  for (int i = kv_seq_idx; i < params.kL; i += blocks * BN) {
@@ -297,8 +308,9 @@ __global__ void kernel_sdpav_2pass_1(

      // Update the accumulators
      U new_max = max(max_score, score);
      U factor = exp2f(max_score - new_max);
      U exp_score = exp2f(score - new_max);
      bool is_neg_inf = new_max == -INFINITY;
      U factor = is_neg_inf ? 1 : exp2f(max_score - new_max);
      U exp_score = is_neg_inf ? 0 : exp2f(score - new_max);

      max_score = new_max;
      sum_exp_score = sum_exp_score * factor + exp_score;
@@ -463,10 +475,14 @@ void sdpa_vector_1pass_fallback(
    const array& v,
    const float scale,
    array& o,
    bool do_causal_ = false) {
    bool do_causal,
    const std::optional<array>& sinks) {
  encoder.set_input_array(q);
  encoder.set_input_array(k);
  encoder.set_input_array(v);
  if (sinks) {
    encoder.set_input_array(*sinks);
  }
  encoder.set_output_array(o);

  cu::AttnParams params{
@@ -489,7 +505,7 @@ void sdpa_vector_1pass_fallback(
  dim3 block_dim(1024, 1, 1);

  dispatch_float_types(o.dtype(), "kernel_sdpav_1pass", [&](auto type_tag) {
    dispatch_bool(do_causal_, [&](auto do_causal) {
    dispatch_bool(do_causal, [&](auto do_causal) {
      dispatch_headdim(params.D, [&](auto headdim) {
        using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;

@@ -504,6 +520,7 @@ void sdpa_vector_1pass_fallback(
            k.data<DataType>(),
            v.data<DataType>(),
            o.data<DataType>(),
            sinks ? (*sinks).data<DataType>() : nullptr,
            params);
      });
    });
@@ -518,7 +535,8 @@ void sdpa_vector_2pass_fallback(
    const array& v,
    const float scale,
    array& o,
    bool do_causal_ = false) {
    bool do_causal,
    const std::optional<array>& sinks) {
  cu::AttnParams params{
      /* int B = */ q.shape(0),
      /* int H = */ q.shape(1),
@@ -559,7 +577,7 @@ void sdpa_vector_2pass_fallback(
  encoder.add_temporary(maxs);

  dispatch_float_types(o.dtype(), "kernel_sdpav_2pass", [&](auto type_tag) {
    dispatch_bool(do_causal_, [&](auto do_causal) {
    dispatch_bool(do_causal, [&](auto do_causal) {
      dispatch_headdim(params.D, [&](auto headdim) {
        using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;

@@ -570,6 +588,10 @@ void sdpa_vector_2pass_fallback(
        encoder.set_input_array(q);
        encoder.set_input_array(k);
        encoder.set_input_array(v);
        if (sinks) {
          encoder.set_input_array(*sinks);
        }

        encoder.set_output_array(intermediate);
        encoder.set_output_array(sums);
        encoder.set_output_array(maxs);
@@ -585,6 +607,7 @@ void sdpa_vector_2pass_fallback(
            q.data<DataType>(),
            k.data<DataType>(),
            v.data<DataType>(),
            sinks ? (*sinks).data<DataType>() : nullptr,
            intermediate.data<float>(),
            sums.data<float>(),
            maxs.data<float>(),
@@ -627,15 +650,16 @@ void sdpa_vector_fallback(
    const array& v,
    const float scale,
    array& o,
    bool do_causal_ = false) {
    bool do_causal,
    const std::optional<array>& sinks) {
  int kL = k.shape(2);

  if (kL > 1024) {
    return sdpa_vector_2pass_fallback(
        s, encoder, q, k, v, scale, o, do_causal_);
        s, encoder, q, k, v, scale, o, do_causal, sinks);
  } else {
    return sdpa_vector_1pass_fallback(
        s, encoder, q, k, v, scale, o, do_causal_);
        s, encoder, q, k, v, scale, o, do_causal, sinks);
  }
}
@@ -691,7 +715,7 @@ void ScaledDotProductAttention::eval_gpu(

  // Define some copy functions to ensure the layout of the inputs is as
  // expected.
  copies.reserve(3);
  copies.reserve(inputs.size());
  auto copy_unless = [&copies, &s](
                         auto predicate, const array& arr) -> const array& {
    if (!predicate(arr)) {
@@ -703,6 +727,16 @@ void ScaledDotProductAttention::eval_gpu(
    }
  };

  // Checks that the headdim dimension has stride 1.
  auto is_matrix_contiguous = [](const array& arr) {
    return arr.strides(-1) == 1;
  };

  std::optional<array> sinks = std::nullopt;
  if (has_sinks_) {
    sinks = copy_unless(is_matrix_contiguous, inputs.back());
  }

  // We are in vector mode, i.e. a single query
  if (q_pre.shape(2) < 4) {
    auto q_copy_unless = [](const array& arr) {
@@ -740,10 +774,6 @@ void ScaledDotProductAttention::eval_gpu(
    const auto& k = copy_unless(kv_copy_unless, k_pre);
    const auto& v = copy_unless(kv_copy_unless, v_pre);

    for (const auto& cp : copies) {
      encoder.add_temporary(cp);
    }

    // Donate the query if possible
    if (q.is_donatable() && q.flags().row_contiguous && q.size() == o.size()) {
      o.copy_shared_buffer(q);
@@ -752,22 +782,26 @@ void ScaledDotProductAttention::eval_gpu(
      int64_t str_oH = o.shape(3);
      int64_t str_oL = o.shape(1) * str_oH;
      int64_t str_oB = o.shape(2) * str_oL;
      size_t data_size = o.shape(0) * str_oB;

      array::Flags flags{
          /* bool contiguous = */ 1,
          /* bool row_contiguous = */ o.shape(2) == 1,
          /* bool col_contiguous = */ 0,
          /* bool col_contiguous = */ o.size() == o.shape(3),
      };

      o.set_data(
          allocator::malloc(o.nbytes()),
          data_size,
          o.size(),
          {str_oB, str_oH, str_oL, str_oD},
          flags);
    }

    return sdpa_vector_fallback(s, encoder, q, k, v, scale_, o, do_causal_);
    for (const auto& cp : copies) {
      encoder.add_temporary(cp);
    }

    return sdpa_vector_fallback(
        s, encoder, q, k, v, scale_, o, do_causal_, sinks);
  }

  // Full attention mode should never reach here
@@ -2,7 +2,6 @@

#include <algorithm>
#include <cassert>
#include <numeric>
#include <sstream>

#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/metal/device.h"
@@ -39,10 +38,11 @@ void explicit_gemm_conv_ND_gpu(
  in_unfolded.set_data(allocator::malloc(in_unfolded.nbytes()));

  // Prepare unfolding kernel
  std::ostringstream kname;
  kname << "naive_unfold_nd_" << type_to_name(in_unfolded) << "_" << N;
  std::string kname;
  kname.reserve(32);
  concatenate(kname, "naive_unfold_nd_", type_to_name(in_unfolded), "_", N);
  auto& compute_encoder = d.get_command_encoder(s.index);
  auto kernel = d.get_kernel(kname.str());
  auto kernel = d.get_kernel(kname);
  compute_encoder.set_compute_pipeline_state(kernel);

  compute_encoder.set_input_array(in, 0);
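These call sites drop `std::ostringstream` in favor of reserving one string and appending the pieces directly. `concatenate` is MLX's own utility; a sketch of the shape its call sites imply (an assumption inferred only from usage in this diff, not the actual MLX implementation):

```cpp
#include <string>
#include <type_traits>
#include <utility>

// Append each argument to dst: strings and chars verbatim, numbers via
// std::to_string. Matches calls like concatenate(kname, "prefix_", name, "_", N).
template <typename... Args>
void concatenate(std::string& dst, Args&&... args) {
  auto append = [&dst](auto&& v) {
    using V = std::decay_t<decltype(v)>;
    if constexpr (std::is_same_v<V, char>) {
      dst += v;
    } else if constexpr (std::is_arithmetic_v<V>) {
      dst += std::to_string(v);
    } else {
      dst += v;
    }
  };
  (append(std::forward<Args>(args)), ...);
}
```

The `reserve` calls beforehand size the buffer once so the appends do not reallocate, which is the point of the refactor.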
@@ -117,11 +117,12 @@ void explicit_gemm_conv_group_ND_gpu(
  in_unfolded.set_data(allocator::malloc(in_unfolded.nbytes()));

  // Prepare unfolding kernel
  std::ostringstream kname;
  kname << "naive_unfold_transpose_nd_" << type_to_name(in_unfolded) << "_"
        << N;
  std::string kname;
  kname.reserve(32);
  concatenate(
      kname, "naive_unfold_transpose_nd_", type_to_name(in_unfolded), "_", N);
  auto& compute_encoder = d.get_command_encoder(s.index);
  auto kernel = d.get_kernel(kname.str());
  auto kernel = d.get_kernel(kname);
  compute_encoder.set_compute_pipeline_state(kernel);

  compute_encoder.set_input_array(in, 0);
@@ -252,18 +253,32 @@ void implicit_gemm_conv_2D_gpu(
      /* const int swizzle_log = */ swizzle_log};

  // Determine kernel
  std::ostringstream kname;
  kname << "implicit_gemm_conv_2d_" << type_to_name(out) << "_bm" << bm << "_bn"
        << bn << "_bk" << bk << "_wm" << wm << "_wn" << wn << "_channel_"
        << (n_channel_specialization ? std::to_string(n_channel_specialization)
                                     : "l")
        << "_filter_" << (small_filter ? 's' : 'l');
  std::string kname;
  kname.reserve(64);
  concatenate(
      kname,
      "implicit_gemm_conv_2d_",
      type_to_name(out),
      "_bm",
      bm,
      "_bn",
      bn,
      "_bk",
      bk,
      "_wm",
      wm,
      "_wn",
      wn,
      "_channel_",
      n_channel_specialization ? std::to_string(n_channel_specialization) : "l",
      "_filter_",
      small_filter ? 's' : 'l');

  // Encode and dispatch kernel
  auto& compute_encoder = d.get_command_encoder(s.index);
  auto kernel = get_steel_conv_kernel(
      d,
      kname.str(),
      kname,
      out,
      bm,
      bn,
@@ -559,11 +574,16 @@ void winograd_conv_2D_gpu(
  {
    int bc = 32;
    int bo = 4;
    std::ostringstream kname;
    kname << "winograd_conv_2d_weight_transform_" << type_to_name(out) << "_bc"
          << bc;
    std::string kname;
    kname.reserve(32);
    concatenate(
        kname,
        "winograd_conv_2d_weight_transform_",
        type_to_name(out),
        "_bc",
        bc);
    auto& compute_encoder = d.get_command_encoder(s.index);
    auto kernel = d.get_kernel(kname.str());
    auto kernel = d.get_kernel(kname);
    compute_encoder.set_compute_pipeline_state(kernel);

    compute_encoder.set_input_array(wt, 0);
@@ -587,11 +607,16 @@ void winograd_conv_2D_gpu(
    int bc = 32;
    int wm = 2;
    int wn = 2;
    std::ostringstream kname;
    kname << "winograd_conv_2d_input_transform_" << type_to_name(out) << "_bc"
          << bc;
    std::string kname;
    kname.reserve(32);
    concatenate(
        kname,
        "winograd_conv_2d_input_transform_",
        type_to_name(out),
        "_bc",
        bc);
    auto& compute_encoder = d.get_command_encoder(s.index);
    auto kernel = d.get_kernel(kname.str());
    auto kernel = d.get_kernel(kname);
    compute_encoder.set_compute_pipeline_state(kernel);

    compute_encoder.set_input_array(in_padded, 0);
@@ -634,11 +659,16 @@ void winograd_conv_2D_gpu(
    int bc = 32;
    int wm = 2;
    int wn = 2;
    std::ostringstream kname;
    kname << "winograd_conv_2d_output_transform_" << type_to_name(out) << "_bo"
          << bc;
    std::string kname;
    kname.reserve(32);
    concatenate(
        kname,
        "winograd_conv_2d_output_transform_",
        type_to_name(out),
        "_bo",
        bc);
    auto& compute_encoder = d.get_command_encoder(s.index);
    auto kernel = d.get_kernel(kname.str());
    auto kernel = d.get_kernel(kname);
    compute_encoder.set_compute_pipeline_state(kernel);

    compute_encoder.set_input_array(out_wg, 0);
@@ -660,9 +690,9 @@ void depthwise_conv_2D_gpu(
    const array& wt,
    array out,
    const MLXConvParams<2>& conv_params) {
  std::ostringstream kname;
  kname << "depthwise_conv_2d_" << type_to_name(out);
  std::string base_name = kname.str();
  std::string base_name;
  base_name.reserve(32);
  concatenate(base_name, "depthwise_conv_2d_", type_to_name(out));

  const int N = conv_params.N;
  const int ker_h = conv_params.wS[0];
@@ -685,15 +715,18 @@ void depthwise_conv_2D_gpu(
  };

  // clang-format off
  kname << "_ker_h_" << ker_h
        << "_ker_w_" << ker_w
        << "_str_h_" << str_h
        << "_str_w_" << str_w
        << "_tgp_h_" << th
        << "_tgp_w_" << tw
        << "_do_flip_" << (do_flip ? 't' : 'n'); // clang-format on

  std::string hash_name = kname.str();
  std::string hash_name;
  hash_name.reserve(64);
  concatenate(
      hash_name,
      base_name,
      "_ker_h_", ker_h,
      "_ker_w_", ker_w,
      "_str_h_", str_h,
      "_str_w_", str_w,
      "_tgp_h_", th,
      "_tgp_w_", tw,
      "_do_flip_", do_flip ? 't' : 'n'); // clang-format on

  auto& compute_encoder = d.get_command_encoder(s.index);
  auto kernel = d.get_kernel(base_name, hash_name, func_consts);
@@ -774,6 +807,56 @@ void dispatch_conv_2D_gpu(
  }
}

void depthwise_conv_1D_gpu(
    const Stream& s,
    metal::Device& d,
    const array& in,
    array wt,
    array out) {
  bool large = in.size() > INT32_MAX || in.data_size() > INT32_MAX;
  std::string base_name;
  base_name.reserve(32);
  concatenate(
      base_name,
      "depthwise_conv_1d_",
      type_to_name(out),
      large ? "_large" : "");

  if (!wt.flags().row_contiguous) {
    wt = contiguous_copy_gpu(wt, s);
    d.add_temporary(wt, s.index);
  }
  auto& compute_encoder = d.get_command_encoder(s.index);
  auto kernel = d.get_kernel(base_name);
  compute_encoder.set_compute_pipeline_state(kernel);

  auto B = in.shape(0);
  auto Tout = out.shape(1);
  auto D = in.shape(2);
  auto K = wt.shape(1);

  compute_encoder.set_input_array(in, 0);
  compute_encoder.set_input_array(wt, 1);
  compute_encoder.set_output_array(out, 2);
  if (large) {
    int64_t strides[3] = {in.strides(0), in.strides(1), in.strides(2)};
    compute_encoder.set_bytes(strides, 3, 3);
  } else {
    int strides[3] = {
        static_cast<int>(in.strides(0)),
        static_cast<int>(in.strides(1)),
        static_cast<int>(in.strides(2))};
    compute_encoder.set_bytes(strides, 3, 3);
  }

  compute_encoder.set_bytes(K, 4);
  auto group_dims = get_block_dims(D, Tout, B);
  MTL::Size grid_dims = MTL::Size(D, Tout, B);

  compute_encoder.dispatch_threads(grid_dims, group_dims);
}

void conv_1D_gpu(
    const Stream& s,
    metal::Device& d,
@@ -790,8 +873,15 @@ void conv_1D_gpu(
  bool is_idil_one = in_dilation[0] == 1;
  int C = in.shape(2);
  int O = wt.shape(0);
  const int C_per_group = in.shape(2) / groups;
  const int O_per_group = wt.shape(0) / groups;
  // Fast path for fully separable 1D convolution
  if (is_idil_one && (groups == C) && groups == O && wt_strides[0] == 1 &&
      wt_dilation[0] == 1 && padding[0] == 0 && !flip) {
    depthwise_conv_1D_gpu(s, d, in, wt, out);
    return;
  }

  const int C_per_group = C / groups;
  const int O_per_group = O / groups;

  // Direct to implicit gemm conv
  if (is_idil_one && (C_per_group <= 4 || C_per_group % 16 == 0) &&
@@ -288,6 +288,40 @@ instantiate_depthconv2d(float32, float);
instantiate_depthconv2d(float16, half);
instantiate_depthconv2d(bfloat16, bfloat16_t);

template <typename T, typename IdxT>
[[kernel]] void depthwise_conv_1d(
    const device T* in [[buffer(0)]],
    const device T* w [[buffer(1)]],
    device T* out [[buffer(2)]],
    constant const IdxT strides[3],
    constant const int& kernel_size,
    uint3 tid [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]) {
  out += (tid.z * static_cast<IdxT>(grid_dim.y) + tid.y) * grid_dim.x + tid.x;
  in += tid.z * strides[0] + tid.y * strides[1] + tid.x * strides[2];
  w += tid.x * kernel_size;

  float acc = 0.0;
  for (int i = 0; i < kernel_size; ++i) {
    acc += static_cast<float>(in[0]) * w[i];
    in += strides[1];
  }
  *out = static_cast<T>(acc);
}

#define instantiate_depthconv1d(iname, itype)                         \
  instantiate_kernel(                                                 \
      "depthwise_conv_1d_" #iname, depthwise_conv_1d, itype, int32_t) \
  instantiate_kernel(                                                 \
      "depthwise_conv_1d_" #iname "_large",                           \
      depthwise_conv_1d,                                              \
      itype,                                                          \
      int64_t)

instantiate_depthconv1d(float32, float);
instantiate_depthconv1d(float16, half);
instantiate_depthconv1d(bfloat16, bfloat16_t);
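For reference, a plain CPU sketch of what this kernel computes, under the layout the host fast path dispatches it with (stride 1, no padding, no dilation, no flip; shapes and the T_out formula are inferred from that condition, so treat the names as illustrative):

```cpp
#include <vector>

// Depthwise 1D convolution: channel d of the input is convolved with its
// own length-K filter, with no mixing across channels.
// in is (B, T_in, D) row-major, w is (D, K), out is (B, T_out, D) with
// T_out = T_in - K + 1.
std::vector<float> depthwise_conv_1d_ref(
    const std::vector<float>& in,
    const std::vector<float>& w,
    int B, int T_in, int D, int K) {
  const int T_out = T_in - K + 1;
  std::vector<float> out(static_cast<size_t>(B) * T_out * D, 0.0f);
  for (int b = 0; b < B; ++b)
    for (int t = 0; t < T_out; ++t)
      for (int d = 0; d < D; ++d) {
        float acc = 0.0f;
        for (int k = 0; k < K; ++k)
          acc += in[(static_cast<size_t>(b) * T_in + t + k) * D + d] *
              w[static_cast<size_t>(d) * K + k];
        out[(static_cast<size_t>(b) * T_out + t) * D + d] = acc;
      }
  return out;
}
```

Each Metal thread corresponds to one (d, t, b) triple of the three outer loops here; advancing `in` by `strides[1]` in the kernel is the `t + k` step along the time axis.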
///////////////////////////////////////////////////////////////////////////////
/// Winograd kernels
///////////////////////////////////////////////////////////////////////////////
@@ -10,7 +10,7 @@ void rope_single_impl(
    constant const int& offset,
    const float inv_freq,
    constant const float& scale,
    constant const size_t& stride,
    constant const int64_t& stride,
    uint2 pos,
    uint2 grid) {
  float L = scale * static_cast<float>(offset);
@@ -52,7 +52,7 @@ template <typename T, bool traditional, bool forward>
    device T* out [[buffer(1)]],
    constant const int& offset,
    constant const float& scale,
    constant const size_t& stride,
    constant const int64_t& stride,
    constant const float& base [[buffer(10)]],
    uint2 pos [[thread_position_in_grid]],
    uint2 grid [[threads_per_grid]]) {
@@ -68,9 +68,9 @@ template <typename T, bool traditional, bool forward>
    device T* out [[buffer(1)]],
    constant const int& offset,
    constant const float& scale,
    constant const size_t& stride,
    constant const int64_t& stride,
    const device float* freqs [[buffer(10)]],
    constant const size_t& freq_stride [[buffer(11)]],
    constant const int64_t& freq_stride [[buffer(11)]],
    uint2 pos [[thread_position_in_grid]],
    uint2 grid [[threads_per_grid]]) {
  float inv_freq = 1.0 / (freqs[freq_stride * pos.x]);
@@ -82,15 +82,21 @@ template <typename T, bool traditional, bool forward, int N = 4>
void rope_impl(
    const device T* in,
    device T* out,
    constant const int& offset,
    const device int* offset,
    const float inv_freq,
    constant const float& scale,
    constant const size_t strides[3],
    constant const size_t out_strides[3],
    constant const size_t& n_batch,
    constant const int64_t strides[3],
    constant const int64_t out_strides[3],
    constant const int64_t& offset_stride,
    constant const int& n_head,
    uint3 pos,
    uint3 grid) {
  float L = scale * static_cast<float>(pos.y + offset);
  auto n_head_up = N * ((n_head + N - 1) / N);
  auto head_idx = static_cast<int>((pos.z * N) % n_head_up);
  auto batch_idx = (pos.z * N) / n_head_up;
  auto batch_offset = offset[batch_idx * offset_stride];
  float L = scale * static_cast<float>(pos.y + batch_offset);
  auto mat_idx = batch_idx * n_head + head_idx;

  // Compute costheta, sintheta
  float theta = L * inv_freq;
@@ -102,20 +108,19 @@ void rope_impl(
  size_t out_index_1, out_index_2;
  if (traditional) {
    out_index_1 = 2 * pos.x * out_strides[2] + pos.y * out_strides[1] +
        N * pos.z * out_strides[0];
        mat_idx * out_strides[0];
    out_index_2 = out_index_1 + 1;
    in_index_1 =
        2 * pos.x * strides[2] + pos.y * strides[1] + N * pos.z * strides[0];
        2 * pos.x * strides[2] + pos.y * strides[1] + mat_idx * strides[0];
    in_index_2 = in_index_1 + strides[2];
  } else {
    out_index_1 = pos.x * out_strides[2] + pos.y * out_strides[1] +
        N * pos.z * out_strides[0];
        mat_idx * out_strides[0];
    out_index_2 = out_index_1 + grid.x * out_strides[2];
    in_index_1 =
        pos.x * strides[2] + pos.y * strides[1] + N * pos.z * strides[0];
    in_index_1 = pos.x * strides[2] + pos.y * strides[1] + mat_idx * strides[0];
    in_index_2 = in_index_1 + grid.x * strides[2];
  }
  for (int i = 0; i < N && pos.z * N + i < n_batch; ++i) {
  for (int i = 0; i < N && head_idx + i < n_head; ++i) {
    // Read and write the output
    float x1 = static_cast<float>(in[in_index_1]);
    float x2 = static_cast<float>(in[in_index_2]);
@@ -141,11 +146,12 @@ template <typename T, bool traditional, bool forward, int N = 4>
[[kernel]] void rope(
    const device T* in [[buffer(0)]],
    device T* out [[buffer(1)]],
    constant const int& offset,
    const device int* offset,
    constant const float& scale,
    constant const size_t strides[3],
    constant const size_t out_strides[3],
    constant const size_t& n_batch,
    constant const int64_t strides[3],
    constant const int64_t out_strides[3],
    constant const int64_t& offset_stride,
    constant const int& n_head,
    constant const float& base [[buffer(10)]],
    uint3 pos [[thread_position_in_grid]],
    uint3 grid [[threads_per_grid]]) {
@@ -159,7 +165,8 @@ template <typename T, bool traditional, bool forward, int N = 4>
      scale,
      strides,
      out_strides,
      n_batch,
      offset_stride,
      n_head,
      pos,
      grid);
}
@@ -168,13 +175,14 @@ template <typename T, bool traditional, bool forward, int N = 4>
[[kernel]] void rope_freqs(
    const device T* in [[buffer(0)]],
    device T* out [[buffer(1)]],
    constant const int& offset,
    const device int* offset,
    constant const float& scale,
    constant const size_t strides[3],
    constant const size_t out_strides[3],
    constant const size_t& n_batch,
    constant const int64_t strides[3],
    constant const int64_t out_strides[3],
    constant const int64_t& offset_stride,
    constant const int& n_head,
    const device float* freqs [[buffer(10)]],
    constant const size_t& freq_stride [[buffer(11)]],
    constant const int64_t& freq_stride [[buffer(11)]],
    uint3 pos [[thread_position_in_grid]],
    uint3 grid [[threads_per_grid]]) {
  float inv_freq = 1.0 / (freqs[freq_stride * pos.x]);
@@ -186,61 +194,20 @@ template <typename T, bool traditional, bool forward, int N = 4>
      scale,
      strides,
      out_strides,
      n_batch,
      offset_stride,
      n_head,
      pos,
      grid);
}

// clang-format off
#define instantiate_rope_g(name, type, traditional, forward)  \
  template [[host_name("rope_" #name)]] [[kernel]] void       \
  rope<type, traditional, forward>(                           \
      const device type* in [[buffer(0)]],                    \
      device type* out [[buffer(1)]],                         \
      constant const int& offset,                             \
      constant const float& scale,                            \
      constant const size_t strides[3],                       \
      constant const size_t out_strides[3],                   \
      constant const size_t& n_batch,                         \
      constant const float& base [[buffer(10)]],              \
      uint3 pos [[thread_position_in_grid]],                  \
      uint3 grid [[threads_per_grid]]);                       \
  template [[host_name("rope_freqs_" #name)]]                 \
  [[kernel]] void rope_freqs<type, traditional, forward>(     \
      const device type* in [[buffer(0)]],                    \
      device type* out [[buffer(1)]],                         \
      constant const int& offset,                             \
      constant const float& scale,                            \
      constant const size_t strides[3],                       \
      constant const size_t out_strides[3],                   \
      constant const size_t& n_batch,                         \
      const device float* freqs [[buffer(10)]],               \
      constant const size_t& freq_stride [[buffer(11)]],      \
      uint3 pos [[thread_position_in_grid]],                  \
      uint3 grid [[threads_per_grid]]);
  instantiate_kernel("rope_" #name, rope, type, traditional, forward) \
  instantiate_kernel("rope_freqs_" #name, rope_freqs, type, traditional, forward)

#define instantiate_rope_s(name, type, traditional, forward)      \
  template [[host_name("rope_single_" #name)]] [[kernel]] void    \
  rope_single<type, traditional, forward>(                        \
      const device type* in [[buffer(0)]],                        \
      device type* out [[buffer(1)]],                             \
      constant const int& offset,                                 \
      constant const float& scale,                                \
      constant const size_t& stride,                              \
      constant const float& base [[buffer(10)]],                  \
      uint2 pos [[thread_position_in_grid]],                      \
      uint2 grid [[threads_per_grid]]);                           \
  template [[host_name("rope_single_freqs_" #name)]]              \
  [[kernel]] void rope_single_freqs<type, traditional, forward>(  \
      const device type* in [[buffer(0)]],                        \
      device type* out [[buffer(1)]],                             \
      constant const int& offset,                                 \
      constant const float& scale,                                \
      constant const size_t& stride,                              \
      const device float* freqs [[buffer(10)]],                   \
      constant const size_t& freq_stride [[buffer(11)]],          \
      uint2 pos [[thread_position_in_grid]],                      \
      uint2 grid [[threads_per_grid]]);
#define instantiate_rope_s(name, type, traditional, forward) \
  instantiate_kernel("rope_single_" #name, rope_single, type, traditional, forward) \
  instantiate_kernel("rope_single_freqs_" #name, rope_single_freqs, type, traditional, forward)

#define instantiate_rope(name, type, traditional, forward) \
  instantiate_rope_s(name, type, traditional, forward) \
@@ -9,6 +9,7 @@ constant bool query_transposed [[function_constant(21)]];
constant bool do_causal [[function_constant(22)]];
constant bool bool_mask [[function_constant(23)]];
constant bool float_mask [[function_constant(24)]];
constant bool has_sinks [[function_constant(25)]];

template <typename T, int D, int V = D>
[[kernel]] void sdpa_vector(
@@ -31,6 +32,9 @@ template <typename T, int D, int V = D>
        [[buffer(14), function_constant(has_mask)]],
    const constant int& mask_head_stride
        [[buffer(15), function_constant(has_mask)]],
    const device T* sinks [[buffer(16), function_constant(has_sinks)]],
    const constant int& num_q_heads
        [[buffer(17), function_constant(has_sinks)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 tpg [[threadgroups_per_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
@@ -53,24 +57,24 @@ template <typename T, int D, int V = D>
  threadgroup U sum_exp_scores[BN];

  // Adjust positions
  const int head_idx = tid.x;
  const int q_batch_head_idx = tid.x;
  const int q_seq_idx = tid.y;
  const int kv_head_idx = head_idx / gqa_factor;
  const int o_offset = head_idx * tpg.y + q_seq_idx;
  const int kv_head_idx = q_batch_head_idx / gqa_factor;
  const int o_offset = q_batch_head_idx * tpg.y + q_seq_idx;
  const int q_offset =
      query_transposed ? tpg.x * q_seq_idx + head_idx : o_offset;
      query_transposed ? tpg.x * q_seq_idx + q_batch_head_idx : o_offset;
  queries += q_offset * D + simd_lid * qk_per_thread;
  keys += kv_head_idx * k_head_stride + simd_gid * k_seq_stride +
      simd_lid * qk_per_thread;
  values += kv_head_idx * v_head_stride + simd_gid * v_seq_stride +
      simd_lid * v_per_thread;
  if (bool_mask) {
    bmask += head_idx * mask_head_stride + simd_gid * mask_kv_seq_stride +
        q_seq_idx * mask_q_seq_stride;
    bmask += q_batch_head_idx * mask_head_stride +
        simd_gid * mask_kv_seq_stride + q_seq_idx * mask_q_seq_stride;
  }
  if (float_mask) {
    fmask += head_idx * mask_head_stride + simd_gid * mask_kv_seq_stride +
        q_seq_idx * mask_q_seq_stride;
    fmask += q_batch_head_idx * mask_head_stride +
        simd_gid * mask_kv_seq_stride + q_seq_idx * mask_q_seq_stride;
  }

  out += o_offset * V + simd_gid * v_per_thread;
@@ -85,6 +89,10 @@ template <typename T, int D, int V = D>

  U max_score = -INFINITY;
  U sum_exp_score = 0;
  if (has_sinks && simd_gid == 0) {
    max_score = static_cast<U>(sinks[q_batch_head_idx % num_q_heads]);
    sum_exp_score = 1;
  }

  // For each key
  for (int i = simd_gid; i < N; i += BN) {
@@ -93,6 +101,8 @@ template <typename T, int D, int V = D>
      use_key = i <= (N - int(tpg.y) + int(q_seq_idx));
    } else if (bool_mask) {
      use_key = bmask[0];
    } else if (float_mask) {
      use_key = (fmask[0] >= Limits<T>::finite_min);
    }
    if (use_key) {
      // Read the key
@@ -107,13 +117,14 @@ template <typename T, int D, int V = D>
      }
      score = simd_sum(score);
      if (float_mask) {
        score += max(Limits<U>::finite_min, static_cast<U>(fmask[0]));
        score += static_cast<U>(fmask[0]);
      }

      // Update the accumulators
      U new_max = max(max_score, score);
      U factor = fast::exp(max_score - new_max);
      U exp_score = fast::exp(score - new_max);
      bool is_neg_inf = new_max == -INFINITY;
      U factor = is_neg_inf ? 1.0 : fast::exp(max_score - new_max);
      U exp_score = is_neg_inf ? 0.0 : fast::exp(score - new_max);

      max_score = new_max;
      sum_exp_score = sum_exp_score * factor + exp_score;
@@ -187,6 +198,9 @@ template <typename T, int D, int V = D>
|
||||
[[buffer(16), function_constant(has_mask)]],
|
||||
const constant int& mask_head_stride
|
||||
[[buffer(17), function_constant(has_mask)]],
|
||||
const device T* sinks [[buffer(18), function_constant(has_sinks)]],
|
||||
const constant int& num_q_heads
|
||||
[[buffer(19), function_constant(has_sinks)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 tpg [[threadgroups_per_grid]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
@@ -211,12 +225,12 @@ template <typename T, int D, int V = D>
|
||||
|
||||
// Adjust positions
|
||||
const int block_idx = tid.z;
|
||||
const int head_idx = tid.x;
|
||||
const int q_batch_head_idx = tid.x;
|
||||
const int q_seq_idx = tid.y;
|
||||
const int o_offset = head_idx * tpg.y + q_seq_idx;
|
||||
const int o_offset = q_batch_head_idx * tpg.y + q_seq_idx;
|
||||
const int q_offset =
|
||||
query_transposed ? tpg.x * q_seq_idx + head_idx : o_offset;
|
||||
const int kv_head_idx = head_idx / gqa_factor;
|
||||
query_transposed ? tpg.x * q_seq_idx + q_batch_head_idx : o_offset;
|
||||
const int kv_head_idx = q_batch_head_idx / gqa_factor;
|
||||
|
||||
queries += q_offset * D + simd_lid * qk_per_thread;
|
||||
keys += kv_head_idx * k_head_stride +
|
||||
@@ -225,12 +239,12 @@ template <typename T, int D, int V = D>
|
||||
(block_idx * BN + simd_gid) * v_seq_stride + simd_lid * v_per_thread;
|
||||
out += o_offset * blocks * V + block_idx * V + simd_lid * v_per_thread;
|
||||
if (bool_mask) {
|
||||
bmask += head_idx * mask_head_stride +
|
||||
bmask += q_batch_head_idx * mask_head_stride +
|
||||
(block_idx * BN + simd_gid) * mask_kv_seq_stride +
|
||||
q_seq_idx * mask_q_seq_stride;
|
||||
}
|
||||
if (float_mask) {
|
||||
fmask += head_idx * mask_head_stride +
|
||||
fmask += q_batch_head_idx * mask_head_stride +
|
||||
(block_idx * BN + simd_gid) * mask_kv_seq_stride +
|
||||
q_seq_idx * mask_q_seq_stride;
|
||||
}
|
||||
@@ -245,8 +259,13 @@ template <typename T, int D, int V = D>
|
||||
o[i] = 0;
|
||||
}
|
||||
|
||||
U max_score = -1e9;
|
||||
U max_score = -INFINITY;
|
||||
U sum_exp_score = 0;
|
||||
if (has_sinks && block_idx == 0 && simd_gid == 0) {
|
||||
int q_head_idx = q_batch_head_idx % num_q_heads;
|
||||
max_score = static_cast<U>(sinks[q_head_idx]);
|
||||
sum_exp_score = 1;
|
||||
}
|
||||
|
||||
// For each key
|
||||
for (int i = block_idx * BN + simd_gid; i < N; i += blocks * BN) {
|
||||
@@ -255,6 +274,8 @@ template <typename T, int D, int V = D>
|
||||
use_key = i <= (N - int(tpg.y) + int(q_seq_idx));
|
||||
} else if (bool_mask) {
|
||||
use_key = bmask[0];
|
||||
} else if (float_mask) {
|
||||
use_key = (fmask[0] >= Limits<T>::finite_min);
|
||||
}
|
||||
if (use_key) {
|
||||
// Read the key
|
||||
@@ -268,6 +289,10 @@ template <typename T, int D, int V = D>
|
||||
score += q[i] * k[i];
|
||||
}
|
||||
score = simd_sum(score);
|
||||
if (score < Limits<T>::finite_min) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (float_mask) {
|
||||
score += fmask[0];
|
||||
}
|
||||
|
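Note on the hunks above: the sink logit is folded into the running softmax state before any key is scored, and the new `-INFINITY` guard keeps a fully masked row from producing NaNs. A minimal NumPy sketch of that online update (the function name and scalar loop are illustrative, not the kernel):

    import numpy as np

    def online_softmax_state(scores, sink=None):
        # Mirrors max_score / sum_exp_score in the kernel: the sink acts as
        # an initial score whose exponential contributes exactly 1.
        max_score = float(sink) if sink is not None else -np.inf
        sum_exp = 1.0 if sink is not None else 0.0
        for s in scores:
            new_max = max(max_score, s)
            if new_max == -np.inf:  # everything so far is masked out
                factor, exp_s = 1.0, 0.0
            else:
                factor = np.exp(max_score - new_max)
                exp_s = np.exp(s - new_max)
            max_score = new_max
            sum_exp = sum_exp * factor + exp_s
        return max_score, sum_exp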
@@ -11,6 +11,7 @@ constant bool align_K [[function_constant(201)]];

constant bool has_mask [[function_constant(300)]];
constant bool do_causal [[function_constant(301)]];
constant bool has_sinks [[function_constant(302)]];

template <typename T>
struct TransformScale {
@@ -82,6 +83,7 @@ template <
const constant AttnParams* params [[buffer(4)]],
const constant AttnMaskParams* mask_params [[buffer(5), function_constant(has_mask)]],
const device MaskType* mask [[buffer(6), function_constant(has_mask)]],
const device T* sinks [[buffer(7), function_constant(has_sinks)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
@@ -169,7 +171,7 @@ template <
VBlockLoader loader_v(
V, params->V_strides[2], Vs, simd_group_id, simd_lane_id);

TransformScale<T> ts(static_cast<T>(params->scale * 1.44269504089));
TransformScale<T> ts(static_cast<T>(params->scale * M_LOG2E_F));

// Prepare MMA tiles
constexpr short kFragSize = 8; // MMAFrag size
@@ -232,6 +234,14 @@ template <
max_score[i] = Limits<AccumType>::finite_min;
}

if (has_sinks) {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kRowsPT; ++i) {
max_score[i] = M_LOG2E_F * static_cast<AccumType>(sinks[tidl.y]);
sum_score[i] = 1;
}
}

int kb_lim = params->NK;

if (do_causal) {
@@ -350,7 +360,7 @@ template <
Stile.frag_at(i, j)[jj] =
mfrag[jj] ? Stile.frag_at(i, j)[jj] : neg_inf;
} else {
Stile.frag_at(i, j)[jj] += 1.44269504089 * selem_t(mfrag[jj]);
Stile.frag_at(i, j)[jj] += M_LOG2E_F * selem_t(mfrag[jj]);
}
}
}
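The `M_LOG2E_F` change above replaces the hard-coded constant 1.44269504089 with the named log2(e) factor: the steel kernel works in base-2 exponentials, so sink logits must be scaled by log2(e) before they join the running maximum. A small sketch of the equivalence (helper name assumed, not kernel code):

    import numpy as np

    LOG2E = 1.4426950408889634  # log2(e), the M_LOG2E_F constant

    def softmax_base2(logits):
        # exp(x) == 2 ** (x * log2(e)), so scaling once up front lets the
        # hot loop use exp2 instead of exp.
        z = logits * LOG2E
        z = z - z.max()
        p = np.exp2(z)
        return p / p.sum()

    x = np.array([0.5, -1.0, 2.0])
    ref = np.exp(x - x.max()) / np.exp(x - x.max()).sum()
    assert np.allclose(softmax_base2(x), ref)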
@@ -18,23 +18,29 @@ void RoPE::eval_gpu(
auto& in = inputs[0];
auto& out = outputs[0];

if (in.ndim() < 3) {
throw std::runtime_error("[RoPE] Input must have at least 3 dimensions");
}

auto& s = out.primitive().stream();
auto& d = metal::device(s.device);

size_t strides[3];
size_t out_strides[3];
int64_t strides[3];
int64_t out_strides[3];
bool donated = false;
int ndim = in.ndim();
int dispatch_ndim = in.ndim();
int B = in.shape(0);
int T = in.shape(-2);
int D = in.shape(-1);
size_t mat_size = T * D;

int dispatch_ndim = ndim;
while (in.shape(-dispatch_ndim) == 1 && dispatch_ndim > 3) {
dispatch_ndim--;
}
size_t mat_size = in.shape(-2) * in.shape(-1);
if (dims_ < in.shape(-1)) {

int N = 1;
for (int i = 1; i < (ndim - 2); ++i) {
N *= in.shape(i);
}

if (dims_ < D) {
donated = true;
auto ctype =
(in.flags().row_contiguous) ? CopyType::Vector : CopyType::General;
@@ -71,8 +77,8 @@ void RoPE::eval_gpu(
out_strides[1] = out.strides()[ndim - 2];
out_strides[2] = out.strides()[ndim - 1];

// Special case for inference (single time step and contiguous)
bool single = in.flags().row_contiguous && (mat_size == in.shape(-1));
// Special case for inference (single batch, single time step, and contiguous)
bool single = in.flags().row_contiguous && B == 1 && T == 1;

bool with_freqs = inputs.size() == 3;
std::ostringstream kname;
@@ -86,24 +92,29 @@ void RoPE::eval_gpu(
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(donated ? out : in, 0);
compute_encoder.set_output_array(out, 1);

compute_encoder.set_input_array(inputs[1], 2);
compute_encoder.set_bytes(scale_, 3);

size_t n_batch = in.size() / mat_size;
MTL::Size group_dims;
MTL::Size grid_dims;
if (single) {
compute_encoder.set_bytes(out_strides, 1, 4);
uint32_t dim0 = dims_ / 2;
group_dims = get_block_dims(dim0, n_batch, 1);
grid_dims = MTL::Size(dim0, n_batch, 1);
group_dims = get_block_dims(dim0, N, 1);
grid_dims = MTL::Size(dim0, N, 1);
} else {
compute_encoder.set_bytes(strides, 3, 4);
compute_encoder.set_bytes(out_strides, 3, 5);
compute_encoder.set_bytes(n_batch, 6);
int64_t offset_stride = 0;
if (inputs[1].ndim() > 0) {
offset_stride = inputs[1].strides()[0];
}
compute_encoder.set_bytes(offset_stride, 6);
compute_encoder.set_bytes(N, 7);
uint32_t dim0 = dims_ / 2;
uint32_t dim1 = in.shape(-2);
uint32_t dim2 = (n_batch + n_per_thread - 1) / n_per_thread;
uint32_t dim1 = T;
uint32_t dim2 = B * ((N + n_per_thread - 1) / n_per_thread);
group_dims = get_block_dims(dim0, dim1, dim2);
grid_dims = MTL::Size(dim0, dim1, dim2);
}
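The RoPE changes above introduce a per-batch offset: the offset input may now be a vector with one starting position per example, threaded through as `offset_stride`. A sketch of what that means for the positions (shapes chosen for illustration, not the kernel indexing):

    import mlx.core as mx

    B, T = 3, 4
    offset = mx.array([0, 5, 9])                 # one starting position per example
    positions = offset[:, None] + mx.arange(T)   # shape (B, T)
    # positions[b] runs from offset[b] to offset[b] + T - 1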
@@ -21,8 +21,9 @@ void sdpa_full_self_attention_metal(
const array& v,
const float scale,
array& o,
bool do_causal_ = false,
const std::optional<array>& mask = std::nullopt) {
bool do_causal_,
const std::optional<array>& mask,
const std::optional<array>& sinks) {
using namespace mlx::steel;

int wm = 4;
@@ -42,35 +43,49 @@ void sdpa_full_self_attention_metal(

const bool align_Q = (qL % bq) == 0;
const bool align_K = (kL % bk) == 0;
const bool has_mask = !!mask;
const bool has_mask = mask.has_value();
const bool do_causal = do_causal_;
const bool has_sinks = sinks.has_value();

metal::MTLFCList func_consts = {
{&align_Q, MTL::DataType::DataTypeBool, 200},
{&align_K, MTL::DataType::DataTypeBool, 201},
{&has_mask, MTL::DataType::DataTypeBool, 300},
{&do_causal, MTL::DataType::DataTypeBool, 301}};
{&do_causal, MTL::DataType::DataTypeBool, 301},
{&has_sinks, MTL::DataType::DataTypeBool, 302}};

std::ostringstream kname;
// clang-format off
kname << "steel_attention_"
<< type_to_name(q)
<< "_bq" << bq
<< "_bk" << bk
<< "_bd" << bd
<< "_wm" << wm
<< "_wn" << wn
<< "_mask" << (type_to_name(has_mask ? *mask : q)); // clang-format on
std::string base_name;
concatenate(
base_name,
"steel_attention_",
type_to_name(q),
"_bq",
bq,
"_bk",
bk,
"_bd",
bd,
"_wm",
wm,
"_wn",
wn,
"_mask",
type_to_name(has_mask ? *mask : q));

std::string base_name = kname.str();

// clang-format off
kname << "_align_Q_" << (align_Q ? 't' : 'n')
<< "_align_K_" << (align_K ? 't' : 'n')
<< "_has_mask_" << (has_mask ? 't' : 'n')
<< "_do_causal_" << (do_causal ? 't' : 'n'); // clang-format on

std::string hash_name = kname.str();
std::string hash_name;
concatenate(
hash_name,
base_name,
"_align_Q_",
(align_Q ? 't' : 'n'),
"_align_K_",
(align_K ? 't' : 'n'),
"_has_mask_",
(has_mask ? 't' : 'n'),
"_do_causal_",
(do_causal ? 't' : 'n'),
"_has_sinks_",
(has_sinks ? 't' : 'n'));

auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(base_name, hash_name, func_consts);
@@ -114,8 +129,8 @@ void sdpa_full_self_attention_metal(
compute_encoder.set_output_array(o, 3);
compute_encoder.set_bytes(params, 4);

if (mask) {
auto m = *mask;
if (has_mask) {
auto& m = *mask;

AttnMaskParams mask_params{/* int64_t M_strides[3] = */ {
m.strides(0), m.strides(1), m.strides(2)}};
@@ -123,6 +138,9 @@ void sdpa_full_self_attention_metal(
compute_encoder.set_bytes(mask_params, 5);
compute_encoder.set_input_array(m, 6);
}
if (has_sinks) {
compute_encoder.set_input_array(*sinks, 7);
}

MTL::Size grid_dims = MTL::Size(NQ, H, B);
MTL::Size group_dims = MTL::Size(32, wm, wn);
@@ -139,7 +157,8 @@ void sdpa_vector(
array& out,
float scale,
bool do_causal,
const std::optional<array>& mask) {
const std::optional<array>& mask,
const std::optional<array>& sinks) {
// Set the kernel name
std::string kname;
kname.reserve(64);
@@ -153,30 +172,32 @@ void sdpa_vector(
// Compute the necessary sizes
int gqa_factor = q.shape(1) / k.shape(1);
int N = k.shape(2);
int B = q.shape(0) * q.shape(1);
size_t k_head_stride = k.shape(1) == 1 ? k.strides(0) : k.strides(1);
size_t k_seq_stride = k.strides()[2];
size_t v_head_stride = v.shape(1) == 1 ? v.strides(0) : v.strides(1);
size_t v_seq_stride = v.strides()[2];

MTL::Size group_dims(1024, 1, 1);
MTL::Size grid_dims(B, q.shape(2), 1);
MTL::Size grid_dims(q.shape(0) * q.shape(1), q.shape(2), 1);

bool has_mask = mask.has_value();
bool bool_mask = has_mask && (*mask).dtype() == bool_;
bool float_mask = has_mask && !bool_mask;
bool query_transposed = !q.flags().row_contiguous;
bool has_sinks = sinks.has_value();
metal::MTLFCList func_consts = {
{&has_mask, MTL::DataType::DataTypeBool, 20},
{&query_transposed, MTL::DataType::DataTypeBool, 21},
{&do_causal, MTL::DataType::DataTypeBool, 22},
{&bool_mask, MTL::DataType::DataTypeBool, 23},
{&float_mask, MTL::DataType::DataTypeBool, 24},
{&has_sinks, MTL::DataType::DataTypeBool, 25},
};
std::string hash_name = kname;
hash_name += has_mask ? (bool_mask ? "_boolmask" : "_floatmask") : "_nomask";
hash_name += query_transposed ? "_qt" : "_qnt";
hash_name += do_causal ? "_c" : "_nc";
hash_name += has_sinks ? "_sinks" : "_nosinks";

// Get the kernel
auto& compute_encoder = d.get_command_encoder(s.index);
@@ -207,6 +228,10 @@ void sdpa_vector(
compute_encoder.set_bytes(q_seq_stride, 14);
compute_encoder.set_bytes(head_stride, 15);
}
if (has_sinks) {
compute_encoder.set_input_array(*sinks, 16);
compute_encoder.set_bytes(q.shape(1), 17);
}

// Launch
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
@@ -221,7 +246,8 @@ void sdpa_vector_2pass(
array& out,
float scale,
bool do_causal,
const std::optional<array>& mask) {
const std::optional<array>& mask,
const std::optional<array>& sinks) {
// Set the kernel name
std::string kname;
kname.reserve(64);
@@ -267,17 +293,20 @@ void sdpa_vector_2pass(
bool bool_mask = has_mask && (*mask).dtype() == bool_;
bool float_mask = has_mask && !bool_mask;
bool query_transposed = !q.flags().row_contiguous;
bool has_sinks = sinks.has_value();
metal::MTLFCList func_consts = {
{&has_mask, MTL::DataType::DataTypeBool, 20},
{&query_transposed, MTL::DataType::DataTypeBool, 21},
{&do_causal, MTL::DataType::DataTypeBool, 22},
{&bool_mask, MTL::DataType::DataTypeBool, 23},
{&float_mask, MTL::DataType::DataTypeBool, 24},
{&has_sinks, MTL::DataType::DataTypeBool, 25},
};
std::string hash_name = kname;
hash_name += has_mask ? (bool_mask ? "_boolmask" : "_floatmask") : "_nomask";
hash_name += query_transposed ? "_qt" : "_qnt";
hash_name += do_causal ? "_c" : "_nc";
hash_name += has_sinks ? "_sinks" : "_nosinks";

// Get the kernel
auto& compute_encoder = d.get_command_encoder(s.index);
@@ -310,6 +339,10 @@ void sdpa_vector_2pass(
compute_encoder.set_bytes(q_seq_stride, 16);
compute_encoder.set_bytes(head_stride, 17);
}
if (has_sinks) {
compute_encoder.set_input_array(*sinks, 18);
compute_encoder.set_bytes(q.shape(1), 19);
}

// Launch
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
@@ -394,7 +427,7 @@ void ScaledDotProductAttention::eval_gpu(

// Define some copy functions to ensure the layout of the inputs is as
// expected.
copies.reserve(3);
copies.reserve(inputs.size());
auto copy_unless = [&copies, &s](
auto predicate, const array& arr) -> const array& {
if (!predicate(arr)) {
@@ -411,6 +444,12 @@ void ScaledDotProductAttention::eval_gpu(
return arr.strides(-1) == 1;
};

std::optional<array> sinks = std::nullopt;
if (has_sinks_) {
sinks = copy_unless(is_matrix_contiguous, inputs.back());
}
bool has_arr_mask = inputs.size() > (3 + has_sinks_);

// We are in vector mode ie single query
if (q_pre.shape(2) <= 8) {
auto q_copy_unless = [](const array& arr) {
@@ -462,7 +501,7 @@ void ScaledDotProductAttention::eval_gpu(
(strides[0] == strides[1] * shape[1]);
};

auto mask = inputs.size() > 3
auto mask = has_arr_mask
? std::optional<array>{copy_unless(mask_copy_unless, inputs[3])}
: std::nullopt;

@@ -473,9 +512,9 @@ void ScaledDotProductAttention::eval_gpu(
char devc = d.get_architecture().back();
if ((devc == 'd' && k.shape(2) >= 1024) ||
(k.shape(1) < q.shape(1) && k.shape(2) >= 4096)) {
sdpa_vector_2pass(s, d, q, k, v, o, scale_, do_causal, mask);
sdpa_vector_2pass(s, d, q, k, v, o, scale_, do_causal, mask, sinks);
} else {
sdpa_vector(s, d, q, k, v, o, scale_, do_causal, mask);
sdpa_vector(s, d, q, k, v, o, scale_, do_causal, mask, sinks);
}
}

@@ -503,11 +542,12 @@ void ScaledDotProductAttention::eval_gpu(
{str_oB, str_oH, str_oL, str_oD},
flags);

auto mask = inputs.size() > 3
auto mask = has_arr_mask
? std::optional<array>{copy_unless(is_matrix_contiguous, inputs[3])}
: std::nullopt;

sdpa_full_self_attention_metal(s, d, q, k, v, scale_, o, do_causal_, mask);
sdpa_full_self_attention_metal(
s, d, q, k, v, scale_, o, do_causal_, mask, sinks);
}

d.add_temporaries(std::move(copies), s.index);
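A brief note on the two-level kernel naming above: `base_name` selects the compiled template, while `hash_name` also encodes the function-constant flags so each specialization is cached separately. A purely illustrative Python analogue of the string construction (nothing here is MLX API):

    def kernel_names(dtype, bq, bk, bd, wm, wn, mask_t, flags):
        # base picks the template; hashed adds the specialization suffix
        base = f"steel_attention_{dtype}_bq{bq}_bk{bk}_bd{bd}_wm{wm}_wn{wn}_mask{mask_t}"
        hashed = base + "".join(f"_{name}_{'t' if on else 'n'}" for name, on in flags)
        return base, hashed

    base, hashed = kernel_names(
        "float16", 32, 16, 64, 4, 1, "float16",
        [("align_Q", True), ("align_K", False), ("has_mask", True),
         ("do_causal", False), ("has_sinks", True)],
    )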
@@ -8,6 +8,7 @@ file(
"${CMAKE_CURRENT_BINARY_DIR}/nccl.h")

add_library(nccl SHARED nccl_stubs.cpp)
set_target_properties(nccl PROPERTIES SOVERSION 2)
find_package(CUDAToolkit REQUIRED)
target_include_directories(nccl PRIVATE ${CUDAToolkit_INCLUDE_DIRS})
target_include_directories(nccl PRIVATE ${CMAKE_CURRENT_BINARY_DIR})
@@ -366,10 +366,16 @@ array rope(
msg << "[rope] Input must be a floating type but got " << x.dtype() << ".";
throw std::invalid_argument(msg.str());
}
if (offset.size() != 1) {
if (offset.ndim() > 1) {
std::ostringstream msg;
msg << "[rope] offset must be a scalar but has shape " << offset.shape()
<< ".";
msg << "[rope] offset must have at most one dimension but has shape "
<< offset.shape() << ".";
throw std::invalid_argument(msg.str());
}
if (offset.size() != 1 && offset.size() != x.shape(0)) {
std::ostringstream msg;
msg << "[rope] offset must be a scalar or vector with " << x.shape(0)
<< " elements but has shape " << offset.shape() << ".";
throw std::invalid_argument(msg.str());
}
if (!issubdtype(offset.dtype(), integer)) {
@@ -379,7 +385,7 @@ array rope(
throw std::invalid_argument(msg.str());
}
if (offset.dtype().size() != 4) {
inputs[1] = astype(offset, uint32, s);
inputs[1] = astype(offset, int32, s);
}
if (inputs.size() == 3 &&
(inputs[2].ndim() != 1 || inputs[2].shape(0) != dims / 2)) {
@@ -391,15 +397,26 @@ array rope(

auto fallback = [dims, traditional, base, scale, forward, s](
std::vector<array> inputs) {
auto& shape = inputs[0].shape();
int ndim = shape.size();
auto x = flatten(inputs[0], 0, ndim - 3, s);
auto x = inputs[0];
auto shape = x.shape();
if (x.ndim() == 3) {
x = expand_dims(x, 1, s);
} else if (x.ndim() > 4) {
x = flatten(x, 1, 1 + (x.ndim() - 4), s);
}

auto B = x.shape(0);
auto N = x.shape(1);
auto T = x.shape(2);
auto t = x.dtype();
// Compute sines and cosines
auto half_dims = dims / 2;
auto& offset = inputs[1];
auto offset = inputs[1];
if (offset.size() > 1) {
offset = expand_dims(offset, {-1, -2}, s);
}
auto positions =
multiply(add(arange(x.shape(1), t, s), offset, s), array(scale, t), s);
multiply(add(arange(x.shape(2), t, s), offset, s), array(scale, t), s);

auto default_inv_freqs = [&inputs, &s, &t, base, half_dims]() {
return exp(
@@ -412,8 +429,7 @@ array rope(

auto inv_freqs = inputs.size() == 3 ? astype(reciprocal(inputs[2], s), t, s)
: default_inv_freqs();
auto theta =
multiply(expand_dims(positions, 1, s), expand_dims(inv_freqs, 0, s), s);
auto theta = multiply(expand_dims(positions, -1, s), inv_freqs, s);
auto coss = cos(theta, s);
auto sins = sin(theta, s);

@@ -436,32 +452,30 @@ array rope(
};

if (traditional) {
auto x1 =
slice(x, {0, 0, 0}, {x.shape(0), x.shape(1), dims}, {1, 1, 2}, s);
auto x2 =
slice(x, {0, 0, 1}, {x.shape(0), x.shape(1), dims}, {1, 1, 2}, s);
auto x1 = slice(x, {0, 0, 0, 0}, {B, N, T, dims}, {1, 1, 1, 2}, s);
auto x2 = slice(x, {0, 0, 0, 1}, {B, N, T, dims}, {1, 1, 1, 2}, s);
auto outs = apply_rope(x1, x2, coss, sins);
for (auto& o : outs) {
o = expand_dims(o, 3, s);
o = expand_dims(o, -1, s);
}
auto out = concatenate(outs, 3, s);
auto out = reshape(concatenate(outs, -1, s), {B, N, T, dims}, s);
if (dims < x.shape(-1)) {
out = reshape(out, {x.shape(0), x.shape(1), dims});
out = concatenate({out, slice(x, {0, 0, dims}, x.shape(), s)}, 2, s);
out =
concatenate({out, slice(x, {0, 0, 0, dims}, x.shape(), s)}, -1, s);
}
return std::vector<array>{reshape(out, shape, s)};
} else {
auto out_s = x.shape();
out_s.back() = half_dims;
auto x1 = slice(x, {0, 0, 0}, out_s, s);
auto x1 = slice(x, {0, 0, 0, 0}, out_s, s);
out_s.back() = dims;
auto x2 = slice(x, {0, 0, half_dims}, out_s, s);
auto x2 = slice(x, {0, 0, 0, half_dims}, out_s, s);

auto outs = apply_rope(x1, x2, coss, sins);
if (dims < x.shape(-1)) {
outs.push_back(slice(x, {0, 0, dims}, x.shape(), s));
outs.push_back(slice(x, {0, 0, 0, dims}, x.shape(), s));
}
return std::vector<array>{reshape(concatenate(outs, 2, s), shape, s)};
return std::vector<array>{reshape(concatenate(outs, -1, s), shape, s)};
}
};
auto stream = to_stream(s);
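The fallback above now normalizes any input of three or more dimensions to a canonical (B, N, T, D) layout instead of flattening everything into one batch axis, which is what lets a vector of per-batch offsets broadcast cleanly. A sketch of just the shape handling (a Python mirror of the C++, under the stated assumptions):

    import mlx.core as mx

    def normalize_for_rope(x):
        # Mirror of the fallback's reshaping: 3-D inputs gain a heads axis,
        # inputs with more than 4 dims merge the extra middle axes into one.
        shape = x.shape                       # remembered so output can be restored
        if x.ndim == 3:
            x = mx.expand_dims(x, 1)          # (B, T, D) -> (B, 1, T, D)
        elif x.ndim > 4:
            x = mx.flatten(x, 1, x.ndim - 3)  # (B, ..., T, D) -> (B, N, T, D)
        return x, shape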
@@ -565,6 +579,7 @@ array scaled_dot_product_attention(
const float scale,
const std::string& mask_mode /* = "" */,
const std::vector<array>& mask_arrs /* = {} */,
const std::optional<array>& sinks /* = {} */,
StreamOrDevice s /* = {}*/) {
for (const auto& tensor : {queries, keys, values}) {
if (tensor.ndim() != 4) {
@@ -665,13 +680,20 @@ array scaled_dot_product_attention(
<< final_type << ".";
throw std::invalid_argument(msg.str());
}
bool has_sinks = sinks.has_value();

auto q = astype(queries, final_type, s);
auto k = astype(keys, final_type, s);
auto v = astype(values, final_type, s);

auto fallback = [scale, final_type, n_q_heads, n_kv_heads, do_causal, s](
const std::vector<array>& inputs) {
auto fallback = [scale,
final_type,
n_q_heads,
n_kv_heads,
do_causal,
has_sinks,
has_arr_mask,
s](const std::vector<array>& inputs) {
auto q = multiply(array(scale, inputs[0].dtype()), inputs[0], s);
int n_repeats = n_q_heads / n_kv_heads;
int B = q.shape(0);
@@ -684,20 +706,22 @@ array scaled_dot_product_attention(
v = expand_dims(v, 2, s);
}
auto scores = matmul(q, swapaxes(k, -1, -2, s), s);
if (inputs.size() > 3 || do_causal) {
if (has_arr_mask || do_causal) {
// Mask must be broadcast-compatible with [B, n_q_heads, L_q, L_kv]
auto mask = inputs.back();

if (do_causal) {
int kL = k.shape(-2);
int qL = q.shape(-2);
int q_off = (kL - qL) < 0 ? 0 : (kL - qL);
auto q_idx = arange(q_off, q_off + qL, s);
auto k_idx = arange(0, kL, s);
q_idx = expand_dims(q_idx, 1, s);
k_idx = expand_dims(k_idx, 0, s);
mask = greater_equal(q_idx, k_idx, s);
}
auto make_or_fetch_mask = [&]() {
if (do_causal) {
int kL = k.shape(-2);
int qL = q.shape(-2);
int q_off = (kL - qL) < 0 ? 0 : (kL - qL);
auto q_idx = arange(q_off, q_off + qL, s);
auto k_idx = arange(0, kL, s);
q_idx = expand_dims(q_idx, 1, s);
k_idx = expand_dims(k_idx, 0, s);
return greater_equal(q_idx, k_idx, s);
}
return inputs[3];
};
auto mask = make_or_fetch_mask();

if (n_repeats > 1 && mask.ndim() >= 3) {
if (mask.shape(-3) == 1) {
@@ -716,7 +740,25 @@ array scaled_dot_product_attention(
scores = add(scores, mask, s);
}
}
if (has_sinks) {
auto sinks = inputs.back();
// scores has shape B N_q N_k L_q L_k
sinks = expand_dims(sinks, {0, 2, 3}, s);
if (scores.ndim() == 5) {
sinks = unflatten(sinks, 1, {n_kv_heads, n_repeats}, s);
}
auto bsx_shape = scores.shape();
bsx_shape.back() = 1;
scores = concatenate({broadcast_to(sinks, bsx_shape, s), scores}, -1, s);
}
scores = softmax(scores, std::vector<int>{-1}, true, s);
if (has_sinks) {
// Slice off scores
auto start = Shape(scores.ndim(), 0);
start.back() = 1;
auto stop = scores.shape();
scores = slice(scores, std::move(start), std::move(stop), s);
}
auto out = matmul(scores, v, s);
if (n_repeats > 1) {
out = flatten(out, 1, 2, s);
@@ -732,7 +774,7 @@ array scaled_dot_product_attention(
has_bool_mask = mask_arr.dtype() == bool_;
if (promote_types(mask_arr.dtype(), final_type) != final_type) {
std::ostringstream msg;
msg << "[scaled_dot_product_attention] Mask type must promote to output type. "
msg << "[scaled_dot_product_attention] Mask type must promote to output type "
<< final_type << ".";
throw std::invalid_argument(msg.str());
} else if (!has_bool_mask) {
@@ -743,6 +785,22 @@ array scaled_dot_product_attention(
mask_shape.back() = keys.shape(-2);
inputs.push_back(broadcast_to(mask_arr, mask_shape, stream));
}
if (has_sinks) {
if (promote_types(sinks->dtype(), final_type) != final_type) {
std::ostringstream msg;
msg << "[scaled_dot_product_attention] Type of sinks must promote to output type "
<< final_type << ".";
throw std::invalid_argument(msg.str());
}
if (sinks->ndim() != 1 || sinks->shape(0) != n_q_heads) {
std::ostringstream msg;
msg << "[scaled_dot_product_attention] Received invalid shape for sinks "
<< sinks->shape() << ".";
throw std::invalid_argument(msg.str());
}
inputs.push_back(astype(*sinks, final_type, stream));
}

if (!ScaledDotProductAttention::use_fallback(
q, k, v, has_mask, has_arr_mask, do_causal, stream)) {
auto out_shape = Shape{q.shape(0), q.shape(1), q.shape(2), v.shape(-1)};
@@ -750,7 +808,7 @@ array scaled_dot_product_attention(
std::move(out_shape),
final_type,
std::make_shared<ScaledDotProductAttention>(
stream, fallback, scale, do_causal),
stream, fallback, scale, do_causal, has_sinks),
std::move(inputs));
}
return fallback(std::move(inputs))[0];
@@ -759,7 +817,8 @@ array scaled_dot_product_attention(
bool ScaledDotProductAttention::is_equivalent(const Primitive& other) const {
const ScaledDotProductAttention& a_other =
static_cast<const ScaledDotProductAttention&>(other);
return scale_ == a_other.scale_ && do_causal_ == a_other.do_causal_;
return scale_ == a_other.scale_ && do_causal_ == a_other.do_causal_ &&
has_sinks_ == a_other.has_sinks_;
}

bool Quantize::is_equivalent(const Primitive& other) const {
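The sink handling in the fallback is worth spelling out: one extra logit per head is prepended as a phantom key column, the softmax runs over the widened axis, and the column is sliced off again, so the sink absorbs probability mass without contributing to the output. A standalone sketch of the same trick (function name assumed):

    import mlx.core as mx

    def softmax_with_sinks(scores, sinks):
        # scores: (B, H, L_q, L_k); sinks: (H,), one logit per query head.
        B, H, Lq, _ = scores.shape
        sink_col = mx.broadcast_to(sinks.reshape(1, H, 1, 1), (B, H, Lq, 1))
        widened = mx.concatenate([sink_col, scores], axis=-1)
        probs = mx.softmax(widened, axis=-1)
        return probs[..., 1:]  # same shape as scores; sink column dropped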
@@ -50,6 +50,7 @@ array scaled_dot_product_attention(
const float scale,
const std::string& mask_mode = "",
const std::vector<array>& mask_arrs = {},
const std::optional<array>& sinks = {},
StreamOrDevice s = {});

using TemplateArg = std::variant<int, bool, Dtype>;
@@ -208,9 +208,13 @@ class ScaledDotProductAttention : public Custom {
explicit ScaledDotProductAttention(
Stream stream,
std::function<std::vector<array>(std::vector<array>)> fallback,
const float scale,
const bool do_causal)
: Custom(stream, fallback), scale_(scale), do_causal_(do_causal) {}
float scale,
bool do_causal,
bool has_sinks)
: Custom(stream, fallback),
scale_(scale),
do_causal_(do_causal),
has_sinks_(has_sinks) {}

static bool use_fallback(
const array& q,
@@ -237,12 +241,13 @@ class ScaledDotProductAttention : public Custom {
DEFINE_NAME(ScaledDotProductAttention);
DEFINE_INPUT_OUTPUT_SHAPE()
auto state() const {
return std::make_tuple(nullptr, scale_, do_causal_);
return std::make_tuple(nullptr, scale_, do_causal_, has_sinks_);
}

private:
float scale_;
bool do_causal_;
bool has_sinks_;
};

class Quantize : public Custom {
@@ -1,10 +1,8 @@
// Copyright © 2025 Apple Inc.

#include <string>

namespace mlx::core {

std::string version() {
const char* version() {
return MLX_VERSION;
}
@@ -4,7 +4,7 @@

#define MLX_VERSION_MAJOR 0
#define MLX_VERSION_MINOR 29
#define MLX_VERSION_PATCH 0
#define MLX_VERSION_PATCH 1
#define MLX_VERSION_NUMERIC \
(100000 * MLX_VERSION_MAJOR + 1000 * MLX_VERSION_MINOR + MLX_VERSION_PATCH)
@@ -15,6 +15,6 @@ namespace mlx::core {
*
* For dev builds, the version will include the suffix ".devYYYYMMDD+hash"
*/
std::string version();
const char* version();

} // namespace mlx::core
@@ -1,20 +1,34 @@
mlx.core.__prefix__:
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
import sys
if sys.version_info >= (3, 10):
from typing import TypeAlias
else:
from typing_extensions import TypeAlias

mlx.core.__suffix__:
from typing import Union
scalar: TypeAlias = Union[int, float, bool]
list_or_scalar: TypeAlias = Union[scalar, list["list_or_scalar"]]
bool_: Dtype = ...

mlx.core.distributed.__prefix__:
from mlx.core import array, Dtype, Device, Stream
from mlx.core import array, Dtype, Device, Stream, scalar
from mlx.core.distributed import Group
from typing import Sequence, Optional, Union

mlx.core.fast.__prefix__:
from mlx.core import array, Dtype, Device, Stream
from mlx.core import array, Dtype, Device, Stream, scalar
from typing import Sequence, Optional, Union

mlx.core.linalg.__prefix__:
from mlx.core import array, Dtype, Device, Stream
from mlx.core import array, Dtype, Device, Stream, scalar
from typing import Sequence, Optional, Tuple, Union

mlx.core.metal.__prefix__:
from mlx.core import array, Dtype, Device, Stream
from mlx.core import array, Dtype, Device, Stream, scalar
from typing import Sequence, Optional, Union

mlx.core.random.__prefix__:
from mlx.core import array, Dtype, Device, Stream
from mlx.core import array, Dtype, Device, Stream, scalar, float32, int32
from typing import Sequence, Optional, Union
@@ -15,6 +15,7 @@ from mlx.nn.layers.activations import (
Mish,
PReLU,
ReLU,
ReLU2,
ReLU6,
Sigmoid,
SiLU,
@@ -40,6 +41,7 @@ from mlx.nn.layers.activations import (
mish,
prelu,
relu,
relu2,
relu6,
selu,
sigmoid,
@@ -35,6 +35,24 @@ def relu(x):
return mx.maximum(x, 0)


@partial(mx.compile, shapeless=True)
def relu2(x):
r"""Applies the ReLU² activation function.

Applies :math:`\max(0, x)^2` element wise.
"""
return mx.square(mx.maximum(x, 0))


@partial(mx.compile, shapeless=True)
def relu6(x):
r"""Applies the Rectified Linear Unit 6.

Applies :math:`\min(\max(x, 0), 6)` element wise.
"""
return mx.minimum(mx.maximum(x, 0), 6.0)


@partial(mx.compile, shapeless=True)
def leaky_relu(x, negative_slope=0.01):
r"""Applies the Leaky Rectified Linear Unit.
@@ -62,15 +80,6 @@ def elu(x, alpha=1.0):
return mx.where(x > 0, x, alpha * (mx.exp(x) - 1))


@partial(mx.compile, shapeless=True)
def relu6(x):
r"""Applies the Rectified Linear Unit 6.

Applies :math:`\min(\max(x, 0), 6)` element wise.
"""
return mx.minimum(mx.maximum(x, 0), 6.0)


@partial(mx.compile, shapeless=True)
def softmax(x, axis=-1):
r"""Applies the Softmax function.
@@ -377,6 +386,22 @@ class ReLU(Module):
"""


@_make_activation_module(relu2)
class ReLU2(Module):
r"""Applies the ReLU² activation function.

See :func:`relu2` for the functional equivalent.
"""


@_make_activation_module(relu6)
class ReLU6(Module):
r"""Applies the Rectified Linear Unit 6.

See :func:`relu6` for the functional equivalent.
"""


class LeakyReLU(Module):
r"""Applies the Leaky Rectified Linear Unit.

@@ -412,14 +437,6 @@ class ELU(Module):
return elu(x, self._alpha)


@_make_activation_module(relu6)
class ReLU6(Module):
r"""Applies the Rectified Linear Unit 6.

See :func:`relu6` for the functional equivalent.
"""


@_make_activation_module(softmax)
class Softmax(Module):
r"""Applies the Softmax function.
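A quick numeric check of the new activation: relu2 squares the positive part, so negatives map to zero and positives grow quadratically.

    import mlx.core as mx
    import mlx.nn as nn

    x = mx.array([-1.0, 0.5, 2.0])
    print(nn.relu2(x))  # array([0, 0.25, 4], dtype=float32)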
@@ -556,7 +556,7 @@ class AdamW(Adam):
eps (float, optional): The term :math:`\epsilon` added to the
denominator to improve numerical stability. Default: ``1e-8``
weight_decay (float, optional): The weight decay :math:`\lambda`.
Default: ``0``.
Default: ``0.01``.
bias_correction (bool, optional): If set to ``True``, bias correction
is applied. Default: ``False``
"""
@@ -52,7 +52,6 @@ set_target_properties(
${MLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY})

target_link_libraries(core PRIVATE mlx)
target_compile_definitions(core PRIVATE _VERSION_=${MLX_VERSION})

if(BUILD_SHARED_LIBS)
if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
@@ -320,6 +320,7 @@ void init_array(nb::module_& m) {
.def_prop_ro(
"shape",
[](const mx::array& a) { return nb::cast(a.shape()); },
nb::sig("def shape(self) -> tuple[int, ...]"),
R"pbdoc(
The shape of the array as a Python tuple.

@@ -347,6 +348,7 @@ void init_array(nb::module_& m) {
.def(
"item",
&to_scalar,
nb::sig("def item(self) -> scalar"),
R"pbdoc(
Access the value of a scalar array.

@@ -356,6 +358,7 @@ void init_array(nb::module_& m) {
.def(
"tolist",
&tolist,
nb::sig("def tolist(self) -> list_or_scalar"),
R"pbdoc(
Convert the array to a Python :class:`list`.
@@ -164,8 +164,13 @@ void init_fast(nb::module_& parent_module) {
R"pbdoc(
Apply rotary positional encoding to the input.

The input is expected to be at least 3D with shape ``(B, *, T, D)`` where:
* ``B`` is the batch size.
* ``T`` is the sequence length.
* ``D`` is the feature dimension.

Args:
a (array): Input array.
a (array): The input array.
dims (int): The feature dimensions to be rotated. If the input feature
is larger than dims then the rest is left unchanged.
traditional (bool): If set to ``True`` choose the traditional
@@ -174,7 +179,9 @@ void init_fast(nb::module_& parent_module) {
each dimension in the positional encodings. Exactly one of ``base`` and
``freqs`` must be ``None``.
scale (float): The scale used to scale the positions.
offset (int or array): The position offset to start at.
offset (int or array): The position offset to start at. If an
:obj:`array` is given it can be a scalar or vector of ``B``
offsets for each example in the batch.
freqs (array, optional): Optional frequencies to use with RoPE.
If set, the ``base`` parameter must be ``None``. Default: ``None``.

@@ -189,6 +196,7 @@ void init_fast(nb::module_& parent_module) {
const mx::array& values,
const float scale,
const std::variant<std::monostate, std::string, mx::array>& mask,
const std::optional<mx::array>& sinks,
mx::StreamOrDevice s) {
bool has_mask = !std::holds_alternative<std::monostate>(mask);
bool has_str_mask =
@@ -205,16 +213,16 @@ void init_fast(nb::module_& parent_module) {
throw std::invalid_argument(msg.str());
}
return mx::fast::scaled_dot_product_attention(
queries, keys, values, scale, mask_str, {}, s);
queries, keys, values, scale, mask_str, {}, sinks, s);
} else {
auto mask_arr = std::get<mx::array>(mask);
return mx::fast::scaled_dot_product_attention(
queries, keys, values, scale, "", {mask_arr}, s);
queries, keys, values, scale, "", {mask_arr}, sinks, s);
}

} else {
return mx::fast::scaled_dot_product_attention(
queries, keys, values, scale, "", {}, s);
queries, keys, values, scale, "", {}, sinks, s);
}
},
"q"_a,
@@ -223,9 +231,10 @@ void init_fast(nb::module_& parent_module) {
nb::kw_only(),
"scale"_a,
"mask"_a = nb::none(),
"sinks"_a = nb::none(),
"stream"_a = nb::none(),
nb::sig(
"def scaled_dot_product_attention(q: array, k: array, v: array, *, scale: float, mask: Union[None, str, array] = None, stream: Union[None, Stream, Device] = None) -> array"),
"def scaled_dot_product_attention(q: array, k: array, v: array, *, scale: float, mask: Union[None, str, array] = None, sinks: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
A fast implementation of multi-head attention: ``O = softmax(Q @ K.T, dim=-1) @ V``.

@@ -255,14 +264,17 @@ void init_fast(nb::module_& parent_module) {
q (array): Queries with shape ``[B, N_q, T_q, D]``.
k (array): Keys with shape ``[B, N_kv, T_kv, D]``.
v (array): Values with shape ``[B, N_kv, T_kv, D]``.
scale (float): Scale for queries (typically ``1.0 / sqrt(q.shape(-1)``)
mask (Union[None, str, array], optional): The mask to apply to the
scale (float): Scale for queries (typically ``1.0 / sqrt(q.shape(-1)``).
mask (str or array, optional): The mask to apply to the
query-key scores. The mask can be an array or a string indicating
the mask type. The only supported string type is ``"causal"``. If
the mask is an array it can be a boolean or additive mask. The mask
can have at most 4 dimensions and must be broadcast-compatible with
the shape ``[B, N, T_q, T_kv]``. If an additive mask is given its
type must promote to the promoted type of ``q``, ``k``, and ``v``.
sinks (array, optional): An optional array of attention sinks.
Default: ``None``.

Returns:
array: The output array.
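Given the updated signature, a typical call with the new keyword looks like the following (shapes chosen arbitrarily for illustration):

    import mlx.core as mx

    B, n_q, n_kv, T_q, T_kv, D = 1, 8, 2, 4, 16, 64
    q = mx.random.normal((B, n_q, T_q, D))
    k = mx.random.normal((B, n_kv, T_kv, D))
    v = mx.random.normal((B, n_kv, T_kv, D))
    sinks = mx.zeros((n_q,))  # one sink logit per query head
    out = mx.fast.scaled_dot_product_attention(
        q, k, v, scale=D**-0.5, mask="causal", sinks=sinks
    )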
@@ -447,6 +447,8 @@ void init_linalg(nb::module_& parent_module) {
"a"_a,
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def eig(a: array, *, stream: Union[None, Stream, Device] = None) -> Tuple[array, array]"),
R"pbdoc(
Compute the eigenvalues and eigenvectors of a square matrix.

@@ -523,6 +525,8 @@ void init_linalg(nb::module_& parent_module) {
"UPLO"_a = "L",
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def eigh(a: array, UPLO: str = 'L', *, stream: Union[None, Stream, Device] = None) -> Tuple[array, array]"),
R"pbdoc(
Compute the eigenvalues and eigenvectors of a complex Hermitian or
real symmetric matrix.
@@ -2,9 +2,9 @@

#include <nanobind/nanobind.h>

#define STRINGIFY(x) #x
#define TOSTRING(x) STRINGIFY(x)
#include "mlx/version.h"

namespace mx = mlx::core;
namespace nb = nanobind;

void init_mlx_func(nb::module_&);
@@ -48,5 +48,5 @@ NB_MODULE(core, m) {
init_distributed(m);
init_export(m);

m.attr("__version__") = TOSTRING(_VERSION_);
m.attr("__version__") = mx::version();
}
@@ -4271,7 +4271,7 @@ void init_ops(nb::module_& m) {
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def dequantize(w: array, /, scales: array, biases: Optional[array] = = None, group_size: int = 64, bits: int = 4, mode: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> array"),
"def dequantize(w: array, /, scales: array, biases: Optional[array] = None, group_size: int = 64, bits: int = 4, mode: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Dequantize the matrix ``w`` using quantization parameters.
@@ -171,7 +171,7 @@ void init_random(nb::module_& parent_module) {
"key"_a = nb::none(),
"stream"_a = nb::none(),
nb::sig(
"def normal(shape: Sequence[int] = [], dtype: Optional[Dtype] = float32, loc: Optional[scalar, array] = None, scale: Optional[scalar, array] = None, key: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
"def normal(shape: Sequence[int] = [], dtype: Optional[Dtype] = float32, loc: Union[scalar, array, None] = None, scale: Union[scalar, array, None] = None, key: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Generate normally distributed random numbers.
@@ -91,7 +91,7 @@ mx::array to_array_with_accessor(nb::object obj) {
return nb::cast<mx::array>(obj.attr("__mlx_array__")());
} else {
std::ostringstream msg;
msg << "Invalid type " << nb::type_name(obj.type()).c_str()
msg << "Invalid type " << nb::type_name(obj.type()).c_str()
<< " received in array initialization.";
throw std::invalid_argument(msg.str());
}
@@ -594,124 +594,123 @@ class TestBlas(mlx_tests.MLXTestCase):
np.random.seed(0)
# Batched matmul
alpha = 0.5
beta = 2.0
for beta in (1.0, 2.0):
# c must broadcast to the output shape
with self.assertRaises(ValueError):
mx.addmm(mx.zeros((2, 2, 2)), mx.zeros((2, 2)), mx.zeros((2, 2)))

# c must broadcast to the output shape
with self.assertRaises(ValueError):
mx.addmm(mx.zeros((2, 2, 2)), mx.zeros((2, 2)), mx.zeros((2, 2)))
# Regular batched case
a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
b_npy = np.random.normal(0.0, 1.0 / 128, (32, 16, 16)).astype(np.float32)

# Regular batched case
a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
b_npy = np.random.normal(0.0, 1.0 / 128, (32, 16, 16)).astype(np.float32)
a_mlx = mx.array(a_npy)
b_mlx = mx.array(b_npy)

a_mlx = mx.array(a_npy)
b_mlx = mx.array(b_npy)
for c_shape in ((1,), (1, 16), (32, 1, 16), (1, 128, 16)):
c_npy = np.ones(c_shape).astype(np.float32)
c_mlx = mx.array(c_npy)

for c_shape in ((1,), (1, 16), (32, 1, 16), (1, 128, 16)):
c_npy = np.ones(c_shape).astype(np.float32)
c_mlx = mx.array(c_npy)
d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)

d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)
self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))

self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
# Batched and transposed matmul
b_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
b_mlx = mx.array(b_npy)

# Batched and transposed matmul
b_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
b_mlx = mx.array(b_npy)
for c_shape in ((1,), (32, 1, 128), (1, 128)):
c_npy = np.ones(c_shape).astype(np.float32)
c_mlx = mx.array(c_npy)

for c_shape in ((1,), (32, 1, 128), (1, 128)):
c_npy = np.ones(c_shape).astype(np.float32)
c_mlx = mx.array(c_npy)
b_np_t = np.transpose(b_npy, (0, 2, 1))
b_mx_t = mx.transpose(b_mlx, (0, 2, 1))

b_np_t = np.transpose(b_npy, (0, 2, 1))
b_mx_t = mx.transpose(b_mlx, (0, 2, 1))
d_npy = alpha * (a_npy @ b_np_t) + beta * c_npy
d_mlx = mx.addmm(c_mlx, a_mlx, b_mx_t, alpha, beta)

d_npy = alpha * (a_npy @ b_np_t) + beta * c_npy
d_mlx = mx.addmm(c_mlx, a_mlx, b_mx_t, alpha, beta)
self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
# Batched matmul with simple broadcast
a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
b_npy = np.random.normal(0.0, 1.0 / 128, (16, 16)).astype(np.float32)

self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
# Batched matmul with simple broadcast
a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
b_npy = np.random.normal(0.0, 1.0 / 128, (16, 16)).astype(np.float32)
a_mlx = mx.array(a_npy)
b_mlx = mx.array(b_npy)

a_mlx = mx.array(a_npy)
b_mlx = mx.array(b_npy)
for c_shape in ((1,), (1, 16), (32, 1, 16), (1, 128, 16)):
c_npy = np.ones(c_shape).astype(np.float32)
c_mlx = mx.array(c_npy)

for c_shape in ((1,), (1, 16), (32, 1, 16), (1, 128, 16)):
c_npy = np.ones(c_shape).astype(np.float32)
c_mlx = mx.array(c_npy)
d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)

d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)
self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
# Matmul with vector
a_npy = np.random.normal(0.0, 1.0 / 128, (16,)).astype(np.float32)
b_npy = np.random.normal(0.0, 1.0 / 128, (32, 16, 128)).astype(np.float32)
a_mlx = mx.array(a_npy)
b_mlx = mx.array(b_npy)

self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
# Matmul with vector
a_npy = np.random.normal(0.0, 1.0 / 128, (16,)).astype(np.float32)
b_npy = np.random.normal(0.0, 1.0 / 128, (32, 16, 128)).astype(np.float32)
a_mlx = mx.array(a_npy)
b_mlx = mx.array(b_npy)
for c_shape in ((1,), (128,), (32, 128)):
c_npy = np.ones(c_shape).astype(np.float32)
c_mlx = mx.array(c_npy)

for c_shape in ((1,), (128,), (32, 128)):
c_npy = np.ones(c_shape).astype(np.float32)
c_mlx = mx.array(c_npy)
d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)

d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)
self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))

self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
# Matmul with vector
a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
b_npy = np.random.normal(0.0, 1.0 / 128, (16,)).astype(np.float32)
a_mlx = mx.array(a_npy)
b_mlx = mx.array(b_npy)

# Matmul with vector
a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
b_npy = np.random.normal(0.0, 1.0 / 128, (16,)).astype(np.float32)
a_mlx = mx.array(a_npy)
b_mlx = mx.array(b_npy)
for c_shape in ((1,), (32, 128)):
c_npy = np.ones(c_shape).astype(np.float32)
c_mlx = mx.array(c_npy)

for c_shape in ((1,), (32, 128)):
c_npy = np.ones(c_shape).astype(np.float32)
c_mlx = mx.array(c_npy)
d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)

d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)
self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))

self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
# Split K specializtion
a_npy = np.random.normal(0.0, 1.0 / 128, (64, 4096)).astype(np.float32)
b_npy = np.random.normal(0.0, 1.0 / 128, (4096, 32)).astype(np.float32)

# Split K specializtion
a_npy = np.random.normal(0.0, 1.0 / 128, (64, 4096)).astype(np.float32)
b_npy = np.random.normal(0.0, 1.0 / 128, (4096, 32)).astype(np.float32)
a_mlx = mx.array(a_npy)
b_mlx = mx.array(b_npy)

a_mlx = mx.array(a_npy)
b_mlx = mx.array(b_npy)
for c_shape in ((1,), (1, 32), (64, 1), (64, 32)):
c_npy = np.ones(c_shape).astype(np.float32)
c_mlx = mx.array(c_npy)

for c_shape in ((1,), (1, 32), (64, 1), (64, 32)):
c_npy = np.ones(c_shape).astype(np.float32)
c_mlx = mx.array(c_npy)
d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)

d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)
self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))

self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
# Transposed c
a = mx.ones((10, 5)).T
b = mx.ones((5, 5))
out = mx.addmm(a, b, a, beta=beta, alpha=alpha)
expected = beta * a + alpha * (b @ a)
self.assertTrue(mx.allclose(expected, out))

# Transposed c
a = mx.ones((10, 5)).T
b = mx.ones((5, 5))
out = mx.addmm(a, b, a, beta=1.5, alpha=0.5)
expected = 1.5 * a + 0.5 * (b @ a)
self.assertTrue(mx.allclose(expected, out))

# Broadcast c
a = mx.ones((5, 5))
b = mx.ones((5, 5))
c = mx.ones((1, 5))
out = mx.addmm(c, a, b, beta=1.5, alpha=0.5)
expected = 1.5 * c + 0.5 * (a @ b)
self.assertTrue(mx.allclose(expected, out))
# Broadcast c
a = mx.ones((5, 5))
b = mx.ones((5, 5))
c = mx.ones((1, 5))
out = mx.addmm(c, a, b, beta=beta, alpha=alpha)
expected = beta * c + alpha * (a @ b)
self.assertTrue(mx.allclose(expected, out))

def test_addmm_grad(self):
def make_ref_addmm(alpha, beta):
@@ -724,33 +723,32 @@ class TestBlas(mlx_tests.MLXTestCase):
shapes = ((1, 64, 32, 128), (4, 28, 24, 47), (1, 1, 24, 47))

alpha = 2.0
beta = 0.5
for beta in (1.0, 0.5):
f_test = make_addmm(alpha, beta)
f_ref = make_ref_addmm(alpha, beta)

f_test = make_addmm(alpha, beta)
f_ref = make_ref_addmm(alpha, beta)
for B, M, N, K in shapes:
cotan = mx.ones((B, M, N))
c = mx.random.normal((B, M, N))
a = mx.random.normal((B, M, K))
b = mx.random.normal((B, K, N))

for B, M, N, K in shapes:
cotan = mx.ones((B, M, N))
c = mx.random.normal((B, M, N))
a = mx.random.normal((B, M, K))
b = mx.random.normal((B, K, N))
out_ref, dout_ref = mx.vjp(
f_ref,
[c, a, b],
[cotan],
)
out_test, dout_test = mx.vjp(
f_test,
[c, a, b],
[cotan],
)

out_ref, dout_ref = mx.vjp(
f_ref,
[c, a, b],
[cotan],
)
out_test, dout_test = mx.vjp(
f_test,
[c, a, b],
[cotan],
)
self.assertTrue(mx.allclose(out_ref[0], out_test[0], atol=1e-4).item())

self.assertTrue(mx.allclose(out_ref[0], out_test[0], atol=1e-4).item())

for r, t in zip(dout_ref, dout_test):
self.assertEqual(r.shape, t.shape)
self.assertTrue(mx.allclose(r, t, atol=1e-4).item())
for r, t in zip(dout_ref, dout_test):
self.assertEqual(r.shape, t.shape)
self.assertTrue(mx.allclose(r, t, atol=1e-4).item())

def test_empty_matmul(self):
a = mx.array([[], []]).T
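For reference, the identity the loops above exercise is out = alpha * (a @ b) + beta * c, with c broadcasting to the output shape. A minimal standalone check (shapes chosen for illustration):

    import mlx.core as mx

    a = mx.ones((2, 3))
    b = mx.ones((3, 4))
    c = mx.ones((1, 4))  # broadcasts over the output rows
    out = mx.addmm(c, a, b, alpha=0.5, beta=2.0)
    ref = 0.5 * (a @ b) + 2.0 * c
    assert mx.allclose(out, ref).item()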
@@ -8,18 +8,23 @@ import mlx_tests


def rope_orig(x, dims, traditional, base, scale, offset, freqs=None):
offset = offset.item() if isinstance(offset, mx.array) else offset
N = x.shape[-2] + offset
N = x.shape[-2]
dtype = x.dtype
half_D = dims // 2
positions = mx.arange(offset, N, dtype=dtype) * scale
positions = mx.arange(N, dtype=dtype)
if isinstance(offset, mx.array) and offset.size > 1:
expand = tuple(range(1, x.ndim - 1))
positions = mx.expand_dims(offset, expand) + positions
else:
positions = offset + positions
positions = positions * scale
if freqs is None:
inv_freqs = mx.exp(
-mx.arange(0.0, half_D, dtype=dtype) * (math.log(base) / half_D)
)
else:
inv_freqs = (1 / freqs).astype(x.dtype)
theta = mx.reshape(positions, (-1, 1)) * mx.reshape(inv_freqs, (1, -1))
theta = mx.expand_dims(positions, -1) * inv_freqs
costheta, sintheta = mx.cos(theta), mx.sin(theta)
if traditional:
x1 = x[..., :dims:2]
@@ -214,6 +219,7 @@ class TestFast(mlx_tests.MLXTestCase):
)
self.assertEqual(dtype, rx.dtype)
self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
return

# Test single vector
x = mx.random.uniform(shape=(1, 1, dims))
@@ -277,6 +283,55 @@ class TestFast(mlx_tests.MLXTestCase):
g2 = mx.grad(f2)(x, y)
self.assertLess(mx.abs(g1 - g2).max(), 1e-5)

def test_rope_batch(self):
T = 4
base = 10000.0
scale = 1.0
traditional = True
batch_sizes = [3, 8, 11]
num_heads = [1, 3, 5]
dims = 32

x = mx.random.uniform(shape=(8, 4, T, dims))

offset = mx.array([1, 2, 3])
with self.assertRaises(ValueError):
mx.fast.rope(
x,
dims,
traditional=traditional,
base=base,
scale=scale,
offset=offset,
)

for batch_size in batch_sizes:
for n_head in num_heads:
x = mx.random.uniform(shape=(batch_size, n_head, T, dims))
offset = mx.arange(batch_size)
rx = rope_orig(x, dims, traditional, base, scale, offset)
rx_fast = mx.fast.rope(
x,
dims,
traditional=traditional,
base=base,
scale=scale,
offset=offset,
)
self.assertLess(mx.abs(rx - rx_fast).max(), 1e-5)
x = mx.random.normal(shape=(2, 6, 8, 64)).transpose(0, 2, 1, 3)
dims = 64
offset = 0
rx_fast = mx.fast.rope(
x, dims, traditional=traditional, scale=scale, base=base, offset=offset
)
rx_fast_single = mx.fast.rope(
x[0:1], dims, traditional=traditional, scale=scale, base=base, offset=offset
)

rx = rope_orig(x, dims, traditional, base, scale, offset)
self.assertLess(mx.abs(rx - rx_fast).max(), 1e-5)

def test_rms_norm(self):
# Per dtype absolute tolerance
tolerances = {mx.float32: 1e-6, mx.float16: 1e-3, mx.bfloat16: 1e-2}
@@ -6,7 +6,7 @@ import mlx_tests
import numpy as np


def mlx_ref_attn(q, k, v, scale=1.0, mask=None):
def mlx_ref_attn(q, k, v, scale=1.0, mask=None, sinks=None):
    q_dtype = q.dtype
    q = q * mx.array(scale, q_dtype)
    n_q_heads = q.shape[-3]
@@ -23,7 +23,6 @@ def mlx_ref_attn(q, k, v, scale=1.0, mask=None):
    v = mx.expand_dims(v, 2)

    scores = q @ mx.swapaxes(k, -1, -2)

    if mask is not None:

        if mask == "causal":
@@ -43,7 +42,18 @@ def mlx_ref_attn(q, k, v, scale=1.0, mask=None):
        else:
            scores += mask

    if sinks is not None:
        sinks = mx.expand_dims(sinks, (0, 2, 3))
        if n_repeats > 1:
            sinks = mx.unflatten(sinks, 1, (n_kv_heads, n_repeats))
        score_shape = list(scores.shape)
        score_shape[-1] = 1
        sinks = mx.broadcast_to(sinks, score_shape)
        scores = mx.concatenate([sinks, scores], axis=-1)

    scores = mx.softmax(scores, axis=-1, precise=True)
    if sinks is not None:
        scores = scores[..., 1:]

    out = scores @ v
    if n_repeats > 1:
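The reference prepends the sink column before the softmax and drops it afterwards, so each row's remaining attention weights sum to less than one. A minimal numeric sketch of that effect (not part of the diff):

```python
import mlx.core as mx

logits = mx.array([1.0, 2.0, 3.0])
sink = mx.array([0.5])
# softmax over [sink, logits], then discard the sink column
probs = mx.softmax(mx.concatenate([sink, logits]))[1:]
print(probs.sum().item())  # < 1.0: some mass is absorbed by the sink
```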
@@ -158,7 +168,7 @@ class TestFastSelfAttentionSDPA(mlx_tests.MLXTestCase):

        Dk = 64

        if self.is_apple_silicon:
        if self.is_apple_silicon or mx.cuda.is_available():
            dtypes.append(np.half)

        for SEQUENCE_LENGTH in [63, 129, 400]:
@@ -230,7 +240,7 @@ class TestFastSDPA(mlx_tests.MLXTestCase):
        B = 1
        H = 32
        dtypes = [np.float32]
        if self.is_apple_silicon:
        if self.is_apple_silicon or mx.cuda.is_available():
            dtypes.append(np.half)

        for SEQUENCE_LENGTH in [1, 7, 9, 32, 63, 67, 129, 400, 2000]:
@@ -400,15 +410,30 @@ class TestFastSDPA(mlx_tests.MLXTestCase):

    def test_fully_masked(self):
        Lkv = 8
        mask = mx.array(False)
        masks = [mx.array(False), mx.array(-float("inf"))]
        for mask in masks:
            for D in [4, 128]:
                for Lq in [1, 8]:
                    q = mx.random.normal(shape=(1, 4, Lq, D))
                    k = mx.random.normal(shape=(1, 4, Lkv, D))
                    v = mx.random.normal(shape=(1, 4, Lkv, D))

                    out = mx.fast.scaled_dot_product_attention(
                        q, k, v, mask=mask, scale=1
                    )
                    self.assertTrue(mx.all(mx.isnan(out)))

    def test_inf_score(self):
        Lkv = 8
        for D in [4, 128]:
            for Lq in [1, 8]:
                q = mx.random.normal(shape=(1, 4, Lq, D))
                k = mx.random.normal(shape=(1, 4, Lkv, D))
                q = mx.ones(shape=(1, 4, Lq, D))
                k = mx.ones(shape=(1, 4, Lkv, D))
                v = mx.random.normal(shape=(1, 4, Lkv, D))

                out = mx.fast.scaled_dot_product_attention(q, k, v, mask=mask, scale=1)
                self.assertTrue(mx.all(mx.isnan(out)))
                k[..., 0, :] = -float("inf")
                ref = mlx_primitives_sdpa(q, k, v, scale=1, mask=None)
                out = mx.fast.scaled_dot_product_attention(q, k, v, mask=None, scale=1)
                self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
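A fully masked row drives every logit to -inf, and softmax over a row of -inf values is NaN (0/0), which is exactly what the `isnan` assertions above check. A one-line sketch (not part of the diff):

```python
import mlx.core as mx

row = mx.full((4,), -float("inf"))
print(mx.softmax(row))  # all NaN: exp(-inf) terms sum to 0, so 0/0
```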
    def test_fast_sdpa_few_query(self):
        D = 64
@@ -619,6 +644,17 @@ class TestSDPA(mlx_tests.MLXTestCase):
        out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask)
        self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))

    def test_sdpa_noncontiguous_inputs(self):
        mask = mx.ones(shape=(4, 1, 7, 7), dtype=mx.bool_)
        mx.random.seed(0)
        q = mx.random.normal(shape=(4, 7, 32, 64)).swapaxes(1, 2)

        k = mx.random.normal(shape=(4, 7, 8, 64)).swapaxes(1, 2)
        v = mx.random.normal(shape=(4, 7, 8, 64)).swapaxes(1, 2)
        out = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=mask)
        ref = mlx_ref_attn(q, k, v, scale=1.0, mask=mask)
        self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))

    def test_sdpa_promote_mask(self):
        mask = mx.array(2.0, mx.bfloat16)
        D = 64
@@ -663,6 +699,51 @@ class TestSDPA(mlx_tests.MLXTestCase):
        self.assertFalse(mx.isnan(out).any().item())
        self.assertLessEqual(mx.abs(out - expected).max().item(), 1e-4)

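`swapaxes` yields inputs that are not row-contiguous, which is the layout this test exercises. A minimal sketch of constructing such an input (not part of the diff):

```python
import mlx.core as mx

# materialized as (4, 7, 32, 64), then viewed as (4, 32, 7, 64):
# the last two axes are no longer laid out contiguously in memory
q = mx.random.normal(shape=(4, 7, 32, 64)).swapaxes(1, 2)
print(q.shape)  # (4, 32, 7, 64)
```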
    def test_sdpa_attention_sinks(self):
        B = 2
        N_q = N_kv = 8
        T_q = T_kv = 128
        D = 64

        q = mx.random.normal(shape=(B, N_q, T_q, D))
        k = mx.random.normal(shape=(B, N_kv, T_kv, D))
        v = mx.random.normal(shape=(B, N_kv, T_kv, D))
        scale = D**-0.5

        # A mismatched sinks dtype should raise rather than silently promote
        sinks = mx.random.normal(shape=(N_q,))
        with self.assertRaises(ValueError):
            mx.fast.scaled_dot_product_attention(
                q.astype(mx.float16),
                k.astype(mx.float16),
                v.astype(mx.float16),
                scale=scale,
                sinks=sinks,
            )

        # Wrong shapes
        sinks = mx.random.normal(shape=(N_q + 1,))
        with self.assertRaises(ValueError):
            mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, sinks=sinks)

        sinks = mx.random.normal(shape=())
        with self.assertRaises(ValueError):
            mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, sinks=sinks)

        for T_kv in [128, 4096]:
            for T_q in [1, 128]:
                for N_kv in [2, 8]:
                    q = mx.random.normal(shape=(B, N_q, T_q, D))
                    k = mx.random.normal(shape=(B, N_kv, T_kv, D))
                    v = mx.random.normal(shape=(B, N_kv, T_kv, D))
                    sinks = 10 * mx.random.normal(shape=(N_q,))

                    expected = mlx_ref_attn(q, k, v, scale, sinks=sinks)
                    out = mx.fast.scaled_dot_product_attention(
                        q, k, v, scale=scale, sinks=sinks
                    )
                    self.assertTrue(mx.allclose(out, expected, atol=1e-5))

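Putting the new parameter together: `sinks` carries one logit per query head and is passed straight to the fused kernel. A minimal usage sketch, with shapes assumed from the test above:

```python
import mlx.core as mx

B, N, T, D = 1, 8, 16, 64
q = mx.random.normal(shape=(B, N, T, D))
k = mx.random.normal(shape=(B, N, T, D))
v = mx.random.normal(shape=(B, N, T, D))
sinks = mx.random.normal(shape=(N,))  # one sink logit per query head
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=D**-0.5, sinks=sinks)
print(out.shape)  # (1, 8, 16, 64)
```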

if __name__ == "__main__":
    mlx_tests.MLXTestRunner(failfast=True)
setup.py
@@ -121,7 +121,8 @@ class CMakeBuild(build_ext):
        build_args += [f"-j{os.cpu_count()}"]

        # Avoid cache miss when building from temporary dirs.
        os.environ["CCACHE_BASEDIR"] = os.path.abspath(self.build_temp)
        os.environ["CCACHE_BASEDIR"] = os.path.realpath(self.build_temp)
        os.environ["CCACHE_NOHASHDIR"] = "true"

        subprocess.run(
            ["cmake", ext.sourcedir, *cmake_args], cwd=build_temp, check=True
@@ -176,10 +177,6 @@ class GenerateStubs(Command):
        # Run again without recursive to specify output file name
        subprocess.run(["rm", f"{out_path}/mlx.pyi"])
        subprocess.run(stub_cmd + ["-o", f"{out_path}/__init__.pyi"])
        # mx.bool_ gets filtered by nanobind because of the trailing
        # underscore, add it manually:
        with open(f"{out_path}/__init__.pyi", "a") as fid:
            fid.write("\nbool_: Dtype = ...")


class MLXBdistWheel(bdist_wheel):
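The switch to `realpath` matters when the build directory sits behind a symlink, since ccache compares compiler paths literally against `CCACHE_BASEDIR`. A small sketch of the difference, assuming a macOS-style `/tmp` symlink:

```python
import os

p = "/tmp/mlx-build"
print(os.path.abspath(p))   # /tmp/mlx-build (symlink left in place)
print(os.path.realpath(p))  # /private/tmp/mlx-build on macOS
```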
@@ -1,10 +1,7 @@
# Doctest works fine with cmake 3.5
set(CMAKE_POLICY_VERSION_MINIMUM 3.5)

FetchContent_Declare(
  doctest
  GIT_REPOSITORY "https://github.com/onqtam/doctest"
  GIT_TAG "ae7a13539fb71f270b87eb2e874fbac80bc8dda2")
  GIT_REPOSITORY https://github.com/onqtam/doctest.git
  GIT_TAG v2.4.12)
FetchContent_MakeAvailable(doctest)

add_executable(tests ${PROJECT_SOURCE_DIR}/tests/tests.cpp)