Compare commits

..

13 Commits

Author SHA1 Message Date
Cheng
44cc5da4bc [CUDA] Fix alpha not respected when using bias epilogue (#2578) 2025-09-10 09:08:01 +09:00
Cheng
dde3682b69 [CUDA] Use GEMM with epilogue instead of AddMM (#2569) 2025-09-09 13:18:49 +09:00
Awni Hannun
17310d91a6 Add batch offsets for mx.fast.rope (#2564)
* implement batch rope for Metal

* cuda rope (#2576)
2025-09-08 17:35:07 -07:00
Cheng
b194d65a6a Some tweaks in cmake files (#2574)
* Do proper check of Metal lib

* Update doctest to get rid of cmake version hack
2025-09-09 08:27:18 +09:00
Cheng
a44b27f5f8 Fix a few ccache cache miss (#2573)
* Fix ccache cache miss

* Do not define _VERSION_ in python bindings
2025-09-09 07:41:05 +09:00
Awni Hannun
e5a33f2223 faster depthwise 1D conv (#2567) 2025-09-08 11:37:23 -07:00
Cheng
c1e3340b23 Set ccache size before building (#2570) 2025-09-07 09:00:31 +09:00
XXXXRT666
8f163a367d typing: add type hints to mlx.core.array, linalg, distributed, and random (#2565)
* Add type annotations to mlx methods

* Missing list_or_scalar
2025-09-04 09:08:11 -07:00
Manuel Villanueva
89a3df9014 Fixed several type annotations in the MLX stubs which degraded to Unknown/Any (#2560)
* Added scalar to stubs to fix Unkown Type Hint

### Proposed changes

Issue #2478 reports that several type annotations in the MLX stubs degrade to Unknown/Any in editors like VS Code with Pylance, due to missing imports (Union, Optional, Tuple) and an undefined scalar type alias.

This PR updates the stub generation patterns to:
	•	Add missing typing imports in mlx.core.__prefix__ so that Union, Optional, Tuple, etc. are always available.
	•	Define and export scalar: TypeAlias = Union[int, float, bool] in mlx.core.__suffix__ so that functions typed with Union[scalar, array] resolve correctly instead of falling back to Any.
	•	Update submodule stub prefixes (distributed, fast, linalg, metal, random) to import scalar alongside array, Device, and Stream, ensuring type checkers resolve the union consistently across modules.

With these changes, functions like mlx.add now display rich type signatures such as:

```
def add(
    a: scalar | array,
    b: scalar | array,
    stream: Stream | Device | None = None
) -> array
```

instead of degrading to Any.

### Checklist

	•	I have read the CONTRIBUTING document
	•	I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
	•	I have added tests that prove my fix is effective or that my feature works (n/a — stub generation only)
	•	I have updated the necessary documentation (if needed)

* add bool to patterns

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2025-09-03 12:52:08 -07:00
Krishi Saripalli
c5d2937aa5 chore: Update Docs With Slice Copy Example (#2559)
* chore: updated docs with slice copy example

* nits

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2025-09-02 22:07:02 -07:00
Awni Hannun
b61a65e313 fix copies in sdpa (#2563) 2025-09-02 11:00:36 -07:00
wrmsr
04cbb4191c Fix dequantize python sig (#2562) 2025-09-01 11:50:20 -07:00
Artur Antonov
c5460762e7 Fix AdamW weight_decay default value in docstring (#2557) 2025-08-31 21:29:30 -07:00
32 changed files with 746 additions and 427 deletions

View File

@@ -230,6 +230,9 @@ jobs:
sudo mv ccache-4.11.3-linux-x86_64/ccache /usr/bin/ccache
rm -rf ccache-4.11.3-linux-x86_64
curl -LsSf https://astral.sh/uv/install.sh | sh
- run:
name: Set CCache size
command: ccache --max-size 1G
- run:
name: Install Python package
command: |
@@ -260,7 +263,6 @@ jobs:
command: |
ccache --show-stats
ccache --zero-stats
ccache --max-size 400MB
ccache --cleanup
- save_cache:
key: cuda-<< parameters.image_date >>-{{ arch }}-{{ epoch }}

View File

@@ -87,22 +87,21 @@ cmake_policy(SET CMP0135 NEW)
add_library(mlx)
if(MLX_BUILD_METAL)
set(METAL_LIB "-framework Metal")
set(FOUNDATION_LIB "-framework Foundation")
set(QUARTZ_LIB "-framework QuartzCore")
endif()
if(MLX_BUILD_CUDA)
enable_language(CUDA)
endif()
if(MLX_BUILD_METAL AND NOT METAL_LIB)
message(STATUS "Metal not found. Unable to build GPU")
set(MLX_BUILD_METAL OFF)
set(MLX_METAL_DEBUG OFF)
elseif(MLX_BUILD_METAL)
message(STATUS "Building METAL sources")
if(MLX_BUILD_METAL)
find_library(METAL_LIB Metal)
find_library(FOUNDATION_LIB Foundation)
find_library(QUARTZ_LIB QuartzCore)
if(METAL_LIB)
message(STATUS "Metal found ${METAL_LIB}")
else()
message(
FATAL_ERROR
"Metal not found. Set MLX_BUILD_METAL=OFF to build without GPU")
endif()
if(MLX_METAL_DEBUG)
add_compile_definitions(MLX_METAL_DEBUG)
@@ -111,7 +110,8 @@ elseif(MLX_BUILD_METAL)
# Throw an error if xcrun not found
execute_process(
COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version"
OUTPUT_VARIABLE MACOS_SDK_VERSION COMMAND_ERROR_IS_FATAL ANY)
OUTPUT_VARIABLE MACOS_SDK_VERSION
OUTPUT_STRIP_TRAILING_WHITESPACE COMMAND_ERROR_IS_FATAL ANY)
if(${MACOS_SDK_VERSION} LESS 14.0)
message(

View File

@@ -107,8 +107,20 @@ same array:
>>> a
array([1, 2, 0], dtype=int32)
Note that unlike NumPy, slicing an array creates a copy, not a view. So
mutating it does not mutate the original array:
Note, unlike NumPy, updates to the same location are nondeterministic:
.. code-block:: shell
>>> a = mx.array([1, 2, 3])
>>> b = a[:]
>>> b[2] = 0
>>> b
array([1, 2, 0], dtype=int32)
>>> a
array([1, 2, 3], dtype=int32)
Also unlike NumPy, updates to the same location are nondeterministic:
.. code-block:: shell

View File

@@ -85,10 +85,10 @@ cublasLtMatrixLayout_t create_matrix_layout(
int32_t batch_count,
int64_t batch_stride) {
cublasLtMatrixLayout_t desc;
if (transposed) {
std::swap(rows, cols);
}
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate(&desc, type, rows, cols, ld));
cublasLtOrder_t order = transposed ? CUBLASLT_ORDER_COL : CUBLASLT_ORDER_ROW;
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
desc, CUBLASLT_MATRIX_LAYOUT_ORDER, &order, sizeof(cublasLtOrder_t)));
if (batch_count > 1) {
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
desc,
@@ -138,25 +138,34 @@ CublasGemm::CublasGemm(
CUBLASLT_MATMUL_DESC_POINTER_MODE,
&pointer_mode,
sizeof(int32_t)));
cublasOperation_t op = CUBLAS_OP_N;
// In cublasLt matrices use column-major layout, while it is possible to use
// the CUBLASLT_ORDER_ROW option to switch to row-major layout, the bias
// epilogue does not work with the option. So instead we swap A and B to make
// cublasLt return the row-major result, which works because:
// - the data of a matrix in row-major layout is identical to its transpose in
// column-major layout
// - C^T = (A @ B)^T = B^T @ A^T
cublasOperation_t a_op = b_transposed ? CUBLAS_OP_T : CUBLAS_OP_N;
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
matmul_desc_,
CUBLASLT_MATMUL_DESC_TRANSA,
&op,
&a_op,
sizeof(cublasOperation_t)));
cublasOperation_t b_op = a_transposed ? CUBLAS_OP_T : CUBLAS_OP_N;
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
matmul_desc_,
CUBLASLT_MATMUL_DESC_TRANSB,
&op,
&b_op,
sizeof(cublasOperation_t)));
auto type = dtype_to_cublas_type(dtype);
a_desc_ = create_matrix_layout(
type, a_rows, a_cols, a_transposed, lda, batch_count, a_batch_stride);
type, b_cols, b_rows, b_transposed, ldb, batch_count, b_batch_stride);
b_desc_ = create_matrix_layout(
type, b_rows, b_cols, b_transposed, ldb, batch_count, b_batch_stride);
type, a_cols, a_rows, a_transposed, lda, batch_count, a_batch_stride);
out_desc_ = create_matrix_layout(
type, a_rows, b_cols, false, b_cols, batch_count, a_rows * b_cols);
type, b_cols, a_rows, false, b_cols, batch_count, a_rows * b_cols);
}
CublasGemm::CublasGemm(
@@ -191,7 +200,7 @@ CublasGemm::CublasGemm(
b_batch_stride) {
auto type = dtype_to_cublas_type(dtype);
c_desc_ = create_matrix_layout(
type, a_rows, b_cols, false, ldc, batch_count, c_batch_stride);
type, b_cols, a_rows, false, ldc, batch_count, c_batch_stride);
}
CublasGemm::~CublasGemm() {
@@ -213,14 +222,25 @@ void CublasGemm::set_out(
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(out_desc_));
out_desc_ = create_matrix_layout(
dtype_to_cublas_type(dtype),
rows,
cols,
rows,
transposed,
ld,
batch_count,
batch_stride);
}
void CublasGemm::set_bias(void* bias) {
cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_BIAS;
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
matmul_desc_,
CUBLASLT_MATMUL_DESC_EPILOGUE,
&epilogue,
sizeof(epilogue)));
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
matmul_desc_, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias)));
}
void CublasGemm::run(
cu::CommandEncoder& encoder,
array& out,
@@ -228,11 +248,19 @@ void CublasGemm::run(
const array& b,
const Shape& batch_shape,
const Strides& a_batch_strides,
const Strides& b_batch_strides) {
const Strides& b_batch_strides,
float alpha) {
int batch_count = out.size() / (M_ * N_);
if (batch_count / batch_shape.back() > 1) {
run_batched(
encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides);
encoder,
out,
a,
b,
batch_shape,
a_batch_strides,
b_batch_strides,
alpha);
return;
}
@@ -240,7 +268,13 @@ void CublasGemm::run(
encoder.set_input_array(b);
encoder.set_output_array(out);
execute(encoder, out.data<void>(), a.data<void>(), b.data<void>(), nullptr);
execute(
encoder,
out.data<void>(),
a.data<void>(),
b.data<void>(),
nullptr,
alpha);
}
void CublasGemm::run(
@@ -330,9 +364,9 @@ void CublasGemm::execute(
handle_,
matmul_desc_,
&alpha,
a,
b, // a and b are swapped
a_desc_,
b,
a,
b_desc_,
&beta,
c ? c : out,

View File

@@ -55,6 +55,8 @@ class CublasGemm {
int32_t batch_count,
int64_t batch_stride);
void set_bias(void* bias);
void run(
cu::CommandEncoder& encoder,
array& out,
@@ -62,7 +64,8 @@ class CublasGemm {
const array& b,
const Shape& batch_shape,
const Strides& a_batch_strides,
const Strides& b_batch_strides);
const Strides& b_batch_strides,
float alpha = 1.0f);
void run(
cu::CommandEncoder& encoder,
@@ -85,7 +88,8 @@ class CublasGemm {
const array& b,
const Shape& batch_shape,
const Strides& a_batch_strides,
const Strides& b_batch_strides);
const Strides& b_batch_strides,
float alpha);
void run_batched(
cu::CommandEncoder& encoder,

View File

@@ -13,7 +13,8 @@ void CublasGemm::run_batched(
const array& b,
const Shape& batch_shape,
const Strides& a_batch_strides,
const Strides& b_batch_strides) {
const Strides& b_batch_strides,
float alpha) {
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
@@ -27,7 +28,8 @@ void CublasGemm::run_batched(
out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M_ * N_,
a.data<int8_t>() + a.itemsize() * a_it.loc,
b.data<int8_t>() + b.itemsize() * b_it.loc,
nullptr);
nullptr,
alpha);
a_it.step();
b_it.step();
}

View File

@@ -154,7 +154,8 @@ void CublasGemm::run_batched(
const array& b,
const Shape& batch_shape,
const Strides& a_batch_strides,
const Strides& b_batch_strides) {
const Strides& b_batch_strides,
float alpha) {
int batch_count = out.size() / (M_ * N_);
set_pointer_mode(a_desc_, batch_count);
set_pointer_mode(b_desc_, batch_count);
@@ -226,7 +227,8 @@ void CublasGemm::run_batched(
reinterpret_cast<void*>(out_pointers),
reinterpret_cast<void*>(a_pointers),
reinterpret_cast<void*>(b_pointers),
nullptr);
nullptr,
alpha);
}
void CublasGemm::run_batched(

View File

@@ -11,6 +11,7 @@
#include <numeric>
namespace mlx::core {
namespace {
std::tuple<bool, int64_t, array>
@@ -28,6 +29,76 @@ check_transpose(cu::CommandEncoder& enc, const Stream& s, const array& arr) {
}
}
void gemm_and_bias(
cu::CommandEncoder& encoder,
int M,
int N,
int K,
bool a_transposed,
int64_t lda,
bool b_transposed,
int64_t ldb,
array& out,
const array& a,
const array& b,
void* bias = nullptr,
float alpha = 1.0f) {
// Check and collapse batch dimensions
auto [batch_shape, a_batch_strides, b_batch_strides] = collapse_batches(a, b);
auto batch_count = out.size() / (M * N);
// Collapse batches into M if needed
if (batch_count > 1 && !a_transposed && batch_shape.size() == 1 &&
a.strides()[a.ndim() - 2] == K && a_batch_strides.back() == M * K &&
b_batch_strides.back() == 0) {
M *= batch_shape.back();
batch_count = 1;
a_batch_strides = {0};
b_batch_strides = {0};
batch_shape = {1};
}
// Use gemmv when possible
if (!bias && cu::can_use_gemv(M, N, K, a_transposed, b_transposed)) {
cu::gemv(
a,
b,
out,
M,
N,
K,
batch_count,
batch_shape,
a_batch_strides,
b_batch_strides,
encoder);
return;
}
// Invoke cublasLt
CublasGemm gemm(
encoder.device(),
a.dtype(),
a_transposed,
M,
K,
lda,
b_transposed,
K,
N,
ldb,
batch_shape.back(),
a_batch_strides.back(),
b_batch_strides.back());
if (bias) {
gemm.set_bias(bias);
}
gemm.run(
encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides, alpha);
}
} // namespace
void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -48,9 +119,6 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
out.set_data(allocator::malloc(out.nbytes()));
/////////////////////////////////////////////////////////////////////////////
// Init checks and prep
int M = a_pre.shape(-2);
int N = b_pre.shape(-1);
int K = a_pre.shape(-1);
@@ -60,58 +128,8 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
auto [a_transposed, lda, a] = check_transpose(encoder, s, a_pre);
auto [b_transposed, ldb, b] = check_transpose(encoder, s, b_pre);
/////////////////////////////////////////////////////////////////////////////
// Check and collapse batch dimensions
auto [batch_shape, a_batch_strides, b_batch_strides] = collapse_batches(a, b);
auto batch_count = out.size() / (M * N);
// Collapse batches into M if needed
if (batch_count > 1 && !a_transposed && batch_shape.size() == 1 &&
a.strides()[a.ndim() - 2] == K && a_batch_strides.back() == M * K &&
b_batch_strides.back() == 0) {
M *= batch_shape.back();
batch_count = 1;
a_batch_strides = {0};
b_batch_strides = {0};
batch_shape = {1};
}
if (cu::can_use_gemv(M, N, K, a_transposed, b_transposed)) {
cu::gemv(
a,
b,
out,
M,
N,
K,
batch_count,
batch_shape,
a_batch_strides,
b_batch_strides,
encoder);
return;
}
/////////////////////////////////////////////////////////////////////////////
// Invoke cublasLt
CublasGemm gemm(
cu::device(s.device),
a.dtype(),
a_transposed,
M,
K,
lda,
b_transposed,
K,
N,
ldb,
batch_shape.back(),
a_batch_strides.back(),
b_batch_strides.back());
gemm.run(encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides);
gemm_and_bias(
encoder, M, N, K, a_transposed, lda, b_transposed, ldb, out, a, b);
}
void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -136,6 +154,28 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
auto [a_transposed, lda, a] = check_transpose(encoder, s, a_pre);
auto [b_transposed, ldb, b] = check_transpose(encoder, s, b_pre);
/////////////////////////////////////////////////////////////////////////////
// Dispatch to GEMM with epilogue or AddMM
if (beta_ == 1 && c.strides(-1) == 1 && c.data_size() == out.shape(-1)) {
out.set_data(allocator::malloc(out.nbytes()));
gemm_and_bias(
encoder,
M,
N,
K,
a_transposed,
lda,
b_transposed,
ldb,
out,
a,
b,
c.data<void>(),
alpha_);
return;
}
int64_t ldc;
{
auto stx = c.strides()[c.ndim() - 2];
@@ -177,7 +217,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
}
/////////////////////////////////////////////////////////////////////////////
// Invoke cublasLt
// Invoke cublasLt with AddMM settings
CublasGemm gemm(
cu::device(s.device),

View File

@@ -103,15 +103,21 @@ template <typename T, bool traditional, bool forward, int N = 4>
__device__ void rope_impl(
const T* in,
T* out,
int offset,
const int* offset,
float inv_freq,
float scale,
const cuda::std::array<int64_t, 3> strides,
const cuda::std::array<int64_t, 3> out_strides,
int64_t n_batch,
int64_t offset_stride,
int n_head,
uint3 pos,
uint3 dims) {
float L = scale * static_cast<float>(pos.y + offset);
auto n_head_up = N * ((n_head + N - 1) / N);
auto head_idx = static_cast<int>((pos.z * N) % n_head_up);
auto batch_idx = (pos.z * N) / n_head_up;
auto batch_offset = offset[batch_idx * offset_stride];
float L = scale * static_cast<float>(pos.y + batch_offset);
auto mat_idx = batch_idx * n_head + head_idx;
// Compute costheta, sintheta
float theta = L * inv_freq;
@@ -123,20 +129,19 @@ __device__ void rope_impl(
size_t out_index_1, out_index_2;
if (traditional) {
out_index_1 = 2 * pos.x * out_strides[2] + pos.y * out_strides[1] +
N * pos.z * out_strides[0];
mat_idx * out_strides[0];
out_index_2 = out_index_1 + 1;
in_index_1 =
2 * pos.x * strides[2] + pos.y * strides[1] + N * pos.z * strides[0];
2 * pos.x * strides[2] + pos.y * strides[1] + mat_idx * strides[0];
in_index_2 = in_index_1 + strides[2];
} else {
out_index_1 = pos.x * out_strides[2] + pos.y * out_strides[1] +
N * pos.z * out_strides[0];
mat_idx * out_strides[0];
out_index_2 = out_index_1 + dims.x * out_strides[2];
in_index_1 =
pos.x * strides[2] + pos.y * strides[1] + N * pos.z * strides[0];
in_index_1 = pos.x * strides[2] + pos.y * strides[1] + mat_idx * strides[0];
in_index_2 = in_index_1 + dims.x * strides[2];
}
for (int i = 0; i < N && pos.z * N + i < n_batch; ++i) {
for (int i = 0; i < N && head_idx + i < n_head; ++i) {
// Read and write the output
float x1 = static_cast<float>(in[in_index_1]);
float x2 = static_cast<float>(in[in_index_2]);
@@ -167,7 +172,8 @@ __global__ void rope(
float base,
const __grid_constant__ cuda::std::array<int64_t, 3> strides,
const __grid_constant__ cuda::std::array<int64_t, 3> out_strides,
int64_t n_batch,
int64_t offset_stride,
int n_head,
uint3 dims) {
uint3 pos = make_uint3(
blockIdx.x * blockDim.x + threadIdx.x,
@@ -182,12 +188,13 @@ __global__ void rope(
rope_impl<T, traditional, forward>(
in,
out,
*offset,
offset,
inv_freq,
scale,
strides,
out_strides,
n_batch,
offset_stride,
n_head,
pos,
dims);
}
@@ -202,7 +209,8 @@ __global__ void rope_freqs(
float base,
const __grid_constant__ cuda::std::array<int64_t, 3> strides,
const __grid_constant__ cuda::std::array<int64_t, 3> out_strides,
int64_t n_batch,
int64_t offset_stride,
int n_head,
uint3 dims,
int64_t freq_stride) {
uint3 pos = make_uint3(
@@ -217,12 +225,13 @@ __global__ void rope_freqs(
rope_impl<T, traditional, forward>(
in,
out,
*offset,
offset,
inv_freq,
scale,
strides,
out_strides,
n_batch,
offset_stride,
n_head,
pos,
dims);
}
@@ -245,23 +254,28 @@ void RoPE::eval_gpu(
auto& offset = inputs[1];
auto& out = outputs[0];
if (in.ndim() < 3) {
throw std::runtime_error("[RoPE] Input must have at least 3 dimensions");
}
cuda::std::array<int64_t, 3> strides;
cuda::std::array<int64_t, 3> out_strides;
bool donated = false;
int ndim = in.ndim();
int dispatch_ndim = in.ndim();
int B = in.shape(0);
int T = in.shape(-2);
int D = in.shape(-1);
size_t mat_size = T * D;
int dispatch_ndim = ndim;
while (in.shape(-dispatch_ndim) == 1 && dispatch_ndim > 3) {
dispatch_ndim--;
}
size_t mat_size = in.shape(-2) * in.shape(-1);
int N = 1;
for (int i = 1; i < (ndim - 2); ++i) {
N *= in.shape(i);
}
// We apply rope to less that the whole vector so copy to output and then
// apply in-place.
if (dims_ < in.shape(-1)) {
if (dims_ < D) {
donated = true;
auto ctype =
(in.flags().row_contiguous) ? CopyType::Vector : CopyType::General;
@@ -302,7 +316,7 @@ void RoPE::eval_gpu(
out_strides[2] = out.strides()[ndim - 1];
// Some flags to help us dispatch below
bool single = in.flags().row_contiguous && (mat_size == in.shape(-1));
bool single = in.flags().row_contiguous && B == 1 && T == 1;
bool with_freqs = inputs.size() == 3;
auto& encoder = cu::get_command_encoder(s);
@@ -319,7 +333,7 @@ void RoPE::eval_gpu(
if (single && !with_freqs) {
auto kernel =
cu::rope_single<DataType, traditional.value, forward.value>;
uint2 dims = make_uint2(dims_ / 2, in.size() / mat_size);
uint2 dims = make_uint2(dims_ / 2, N);
auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1);
encoder.add_kernel_node(
kernel,
@@ -336,7 +350,7 @@ void RoPE::eval_gpu(
} else if (single) {
auto kernel =
cu::rope_single_freqs<DataType, traditional.value, forward.value>;
uint2 dims = make_uint2(dims_ / 2, in.size() / mat_size);
uint2 dims = make_uint2(dims_ / 2, N);
auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1);
encoder.add_kernel_node(
kernel,
@@ -354,10 +368,14 @@ void RoPE::eval_gpu(
} else if (with_freqs) {
auto kernel =
cu::rope_freqs<DataType, traditional.value, forward.value>;
uint3 dims =
make_uint3(dims_ / 2, in.shape(-2), in.size() / mat_size);
dims.z = (dims.z + 3) / 4;
int n_per_thread = 4;
uint32_t dimz = B * ((N + n_per_thread - 1) / n_per_thread);
uint3 dims = make_uint3(dims_ / 2, T, dimz);
auto [grid, block] = get_grid_and_block(dims.x, dims.y, dims.z);
int64_t offset_stride = 0;
if (inputs[1].ndim() > 0) {
offset_stride = inputs[1].strides()[0];
}
encoder.add_kernel_node(
kernel,
grid,
@@ -371,15 +389,20 @@ void RoPE::eval_gpu(
std::log2(base_),
strides,
out_strides,
in.size() / mat_size,
offset_stride,
N,
dims,
inputs[2].strides(0));
} else {
auto kernel = cu::rope<DataType, traditional.value, forward.value>;
uint3 dims =
make_uint3(dims_ / 2, in.shape(-2), in.size() / mat_size);
dims.z = (dims.z + 3) / 4;
int n_per_thread = 4;
uint32_t dimz = B * ((N + n_per_thread - 1) / n_per_thread);
uint3 dims = make_uint3(dims_ / 2, T, dimz);
auto [grid, block] = get_grid_and_block(dims.x, dims.y, dims.z);
int64_t offset_stride = 0;
if (inputs[1].ndim() > 0) {
offset_stride = inputs[1].strides()[0];
}
encoder.add_kernel_node(
kernel,
grid,
@@ -392,7 +415,8 @@ void RoPE::eval_gpu(
std::log2(base_),
strides,
out_strides,
in.size() / mat_size,
offset_stride,
N,
dims);
}
});

View File

@@ -2,7 +2,6 @@
#include <algorithm>
#include <cassert>
#include <numeric>
#include <sstream>
#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/metal/device.h"
@@ -39,10 +38,11 @@ void explicit_gemm_conv_ND_gpu(
in_unfolded.set_data(allocator::malloc(in_unfolded.nbytes()));
// Prepare unfolding kernel
std::ostringstream kname;
kname << "naive_unfold_nd_" << type_to_name(in_unfolded) << "_" << N;
std::string kname;
kname.reserve(32);
concatenate(kname, "naive_unfold_nd_", type_to_name(in_unfolded), "_", N);
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
auto kernel = d.get_kernel(kname);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(in, 0);
@@ -117,11 +117,12 @@ void explicit_gemm_conv_group_ND_gpu(
in_unfolded.set_data(allocator::malloc(in_unfolded.nbytes()));
// Prepare unfolding kernel
std::ostringstream kname;
kname << "naive_unfold_transpose_nd_" << type_to_name(in_unfolded) << "_"
<< N;
std::string kname;
kname.reserve(32);
concatenate(
kname, "naive_unfold_transpose_nd_", type_to_name(in_unfolded), "_", N);
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
auto kernel = d.get_kernel(kname);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(in, 0);
@@ -252,18 +253,32 @@ void implicit_gemm_conv_2D_gpu(
/* const int swizzle_log = */ swizzle_log};
// Determine kernel
std::ostringstream kname;
kname << "implicit_gemm_conv_2d_" << type_to_name(out) << "_bm" << bm << "_bn"
<< bn << "_bk" << bk << "_wm" << wm << "_wn" << wn << "_channel_"
<< (n_channel_specialization ? std::to_string(n_channel_specialization)
: "l")
<< "_filter_" << (small_filter ? 's' : 'l');
std::string kname;
kname.reserve(64);
concatenate(
kname,
"implicit_gemm_conv_2d_",
type_to_name(out),
"_bm",
bm,
"_bn",
bn,
"_bk",
bk,
"_wm",
wm,
"_wn",
wn,
"_channel_",
n_channel_specialization ? std::to_string(n_channel_specialization) : "l",
"_filter_",
small_filter ? 's' : 'l');
// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = get_steel_conv_kernel(
d,
kname.str(),
kname,
out,
bm,
bn,
@@ -559,11 +574,16 @@ void winograd_conv_2D_gpu(
{
int bc = 32;
int bo = 4;
std::ostringstream kname;
kname << "winograd_conv_2d_weight_transform_" << type_to_name(out) << "_bc"
<< bc;
std::string kname;
kname.reserve(32);
concatenate(
kname,
"winograd_conv_2d_weight_transform_",
type_to_name(out),
"_bc",
bc);
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
auto kernel = d.get_kernel(kname);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(wt, 0);
@@ -587,11 +607,16 @@ void winograd_conv_2D_gpu(
int bc = 32;
int wm = 2;
int wn = 2;
std::ostringstream kname;
kname << "winograd_conv_2d_input_transform_" << type_to_name(out) << "_bc"
<< bc;
std::string kname;
kname.reserve(32);
concatenate(
kname,
"winograd_conv_2d_input_transform_",
type_to_name(out),
"_bc",
bc);
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
auto kernel = d.get_kernel(kname);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(in_padded, 0);
@@ -634,11 +659,16 @@ void winograd_conv_2D_gpu(
int bc = 32;
int wm = 2;
int wn = 2;
std::ostringstream kname;
kname << "winograd_conv_2d_output_transform_" << type_to_name(out) << "_bo"
<< bc;
std::string kname;
kname.reserve(32);
concatenate(
kname,
"winograd_conv_2d_output_transform_",
type_to_name(out),
"_bo",
bc);
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
auto kernel = d.get_kernel(kname);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(out_wg, 0);
@@ -660,9 +690,9 @@ void depthwise_conv_2D_gpu(
const array& wt,
array out,
const MLXConvParams<2>& conv_params) {
std::ostringstream kname;
kname << "depthwise_conv_2d_" << type_to_name(out);
std::string base_name = kname.str();
std::string base_name;
base_name.reserve(32);
concatenate(base_name, "depthwise_conv_2d_", type_to_name(out));
const int N = conv_params.N;
const int ker_h = conv_params.wS[0];
@@ -685,15 +715,18 @@ void depthwise_conv_2D_gpu(
};
// clang-format off
kname << "_ker_h_" << ker_h
<< "_ker_w_" << ker_w
<< "_str_h_" << str_h
<< "_str_w_" << str_w
<< "_tgp_h_" << th
<< "_tgp_w_" << tw
<< "_do_flip_" << (do_flip ? 't' : 'n'); // clang-format on
std::string hash_name = kname.str();
std::string hash_name;
hash_name.reserve(64);
concatenate(
hash_name,
base_name,
"_ker_h_", ker_h,
"_ker_w_", ker_w,
"_str_h_", str_h,
"_str_w_", str_w,
"_tgp_h_", th,
"_tgp_w_", tw,
"_do_flip_", do_flip ? 't' : 'n'); // clang-format on
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(base_name, hash_name, func_consts);
@@ -774,6 +807,56 @@ void dispatch_conv_2D_gpu(
}
}
void depthwise_conv_1D_gpu(
const Stream& s,
metal::Device& d,
const array& in,
array wt,
array out) {
bool large = in.size() > INT32_MAX || in.data_size() > INT32_MAX;
std::string base_name;
base_name.reserve(32);
concatenate(
base_name,
"depthwise_conv_1d_",
large ? "_large" : "",
type_to_name(out));
if (!wt.flags().row_contiguous) {
wt = contiguous_copy_gpu(wt, s);
d.add_temporary(wt, s.index);
}
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(base_name);
compute_encoder.set_compute_pipeline_state(kernel);
auto B = in.shape(0);
auto Tout = out.shape(1);
auto D = in.shape(2);
auto K = wt.shape(1);
compute_encoder.set_input_array(in, 0);
compute_encoder.set_input_array(wt, 1);
compute_encoder.set_output_array(out, 2);
if (large) {
int64_t strides[3] = {in.strides(0), in.strides(1), in.strides(2)};
compute_encoder.set_bytes(strides, 3, 3);
} else {
int strides[3] = {
static_cast<int>(in.strides(0)),
static_cast<int>(in.strides(1)),
static_cast<int>(in.strides(2))};
compute_encoder.set_bytes(strides, 3, 3);
}
compute_encoder.set_bytes(K, 4);
auto group_dims = get_block_dims(D, Tout, B);
MTL::Size grid_dims = MTL::Size(D, Tout, B);
compute_encoder.dispatch_threads(grid_dims, group_dims);
}
void conv_1D_gpu(
const Stream& s,
metal::Device& d,
@@ -790,8 +873,15 @@ void conv_1D_gpu(
bool is_idil_one = in_dilation[0] == 1;
int C = in.shape(2);
int O = wt.shape(0);
const int C_per_group = in.shape(2) / groups;
const int O_per_group = wt.shape(0) / groups;
// Fast path for fully separable 1D convolution
if (is_idil_one && (groups == C) && groups == O && wt_strides[0] == 1 &&
wt_dilation[0] == 1 && padding[0] == 0 && !flip) {
depthwise_conv_1D_gpu(s, d, in, wt, out);
return;
}
const int C_per_group = C / groups;
const int O_per_group = O / groups;
// Direct to implicit gemm conv
if (is_idil_one && (C_per_group <= 4 || C_per_group % 16 == 0) &&

View File

@@ -288,6 +288,40 @@ instantiate_depthconv2d(float32, float);
instantiate_depthconv2d(float16, half);
instantiate_depthconv2d(bfloat16, bfloat16_t);
template <typename T, typename IdxT>
[[kernel]] void depthwise_conv_1d(
const device T* in [[buffer(0)]],
const device T* w [[buffer(1)]],
device T* out [[buffer(2)]],
constant const IdxT strides[3],
constant const int& kernel_size,
uint3 tid [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
out += (tid.z * static_cast<IdxT>(grid_dim.y) + tid.y) * grid_dim.x + tid.x;
in += tid.z * strides[0] + tid.y * strides[1] + tid.x * strides[2];
w += tid.x * kernel_size;
float acc = 0.0;
for (int i = 0; i < kernel_size; ++i) {
acc += static_cast<float>(in[0]) * w[i];
in += strides[1];
}
*out = static_cast<T>(acc);
}
#define instantiate_depthconv1d(iname, itype) \
instantiate_kernel( \
"depthwise_conv_1d_" #iname, depthwise_conv_1d, itype, int32_t) \
instantiate_kernel( \
"depthwise_conv_1d_" #iname "_large", \
depthwise_conv_1d, \
itype, \
int64_t)
instantiate_depthconv1d(float32, float);
instantiate_depthconv1d(float16, half);
instantiate_depthconv1d(bfloat16, bfloat16_t);
///////////////////////////////////////////////////////////////////////////////
/// Winograd kernels
///////////////////////////////////////////////////////////////////////////////

View File

@@ -10,7 +10,7 @@ void rope_single_impl(
constant const int& offset,
const float inv_freq,
constant const float& scale,
constant const size_t& stride,
constant const int64_t& stride,
uint2 pos,
uint2 grid) {
float L = scale * static_cast<float>(offset);
@@ -52,7 +52,7 @@ template <typename T, bool traditional, bool forward>
device T* out [[buffer(1)]],
constant const int& offset,
constant const float& scale,
constant const size_t& stride,
constant const int64_t& stride,
constant const float& base [[buffer(10)]],
uint2 pos [[thread_position_in_grid]],
uint2 grid [[threads_per_grid]]) {
@@ -68,9 +68,9 @@ template <typename T, bool traditional, bool forward>
device T* out [[buffer(1)]],
constant const int& offset,
constant const float& scale,
constant const size_t& stride,
constant const int64_t& stride,
const device float* freqs [[buffer(10)]],
constant const size_t& freq_stride [[buffer(11)]],
constant const int64_t& freq_stride [[buffer(11)]],
uint2 pos [[thread_position_in_grid]],
uint2 grid [[threads_per_grid]]) {
float inv_freq = 1.0 / (freqs[freq_stride * pos.x]);
@@ -82,15 +82,21 @@ template <typename T, bool traditional, bool forward, int N = 4>
void rope_impl(
const device T* in,
device T* out,
constant const int& offset,
const device int* offset,
const float inv_freq,
constant const float& scale,
constant const size_t strides[3],
constant const size_t out_strides[3],
constant const size_t& n_batch,
constant const int64_t strides[3],
constant const int64_t out_strides[3],
constant const int64_t& offset_stride,
constant const int& n_head,
uint3 pos,
uint3 grid) {
float L = scale * static_cast<float>(pos.y + offset);
auto n_head_up = N * ((n_head + N - 1) / N);
auto head_idx = static_cast<int>((pos.z * N) % n_head_up);
auto batch_idx = (pos.z * N) / n_head_up;
auto batch_offset = offset[batch_idx * offset_stride];
float L = scale * static_cast<float>(pos.y + batch_offset);
auto mat_idx = batch_idx * n_head + head_idx;
// Compute costheta, sintheta
float theta = L * inv_freq;
@@ -102,20 +108,19 @@ void rope_impl(
size_t out_index_1, out_index_2;
if (traditional) {
out_index_1 = 2 * pos.x * out_strides[2] + pos.y * out_strides[1] +
N * pos.z * out_strides[0];
mat_idx * out_strides[0];
out_index_2 = out_index_1 + 1;
in_index_1 =
2 * pos.x * strides[2] + pos.y * strides[1] + N * pos.z * strides[0];
2 * pos.x * strides[2] + pos.y * strides[1] + mat_idx * strides[0];
in_index_2 = in_index_1 + strides[2];
} else {
out_index_1 = pos.x * out_strides[2] + pos.y * out_strides[1] +
N * pos.z * out_strides[0];
mat_idx * out_strides[0];
out_index_2 = out_index_1 + grid.x * out_strides[2];
in_index_1 =
pos.x * strides[2] + pos.y * strides[1] + N * pos.z * strides[0];
in_index_1 = pos.x * strides[2] + pos.y * strides[1] + mat_idx * strides[0];
in_index_2 = in_index_1 + grid.x * strides[2];
}
for (int i = 0; i < N && pos.z * N + i < n_batch; ++i) {
for (int i = 0; i < N && head_idx + i < n_head; ++i) {
// Read and write the output
float x1 = static_cast<float>(in[in_index_1]);
float x2 = static_cast<float>(in[in_index_2]);
@@ -141,11 +146,12 @@ template <typename T, bool traditional, bool forward, int N = 4>
[[kernel]] void rope(
const device T* in [[buffer(0)]],
device T* out [[buffer(1)]],
constant const int& offset,
const device int* offset,
constant const float& scale,
constant const size_t strides[3],
constant const size_t out_strides[3],
constant const size_t& n_batch,
constant const int64_t strides[3],
constant const int64_t out_strides[3],
constant const int64_t& offset_stride,
constant const int& n_head,
constant const float& base [[buffer(10)]],
uint3 pos [[thread_position_in_grid]],
uint3 grid [[threads_per_grid]]) {
@@ -159,7 +165,8 @@ template <typename T, bool traditional, bool forward, int N = 4>
scale,
strides,
out_strides,
n_batch,
offset_stride,
n_head,
pos,
grid);
}
@@ -168,13 +175,14 @@ template <typename T, bool traditional, bool forward, int N = 4>
[[kernel]] void rope_freqs(
const device T* in [[buffer(0)]],
device T* out [[buffer(1)]],
constant const int& offset,
const device int* offset,
constant const float& scale,
constant const size_t strides[3],
constant const size_t out_strides[3],
constant const size_t& n_batch,
constant const int64_t strides[3],
constant const int64_t out_strides[3],
constant const int64_t& offset_stride,
constant const int& n_head,
const device float* freqs [[buffer(10)]],
constant const size_t& freq_stride [[buffer(11)]],
constant const int64_t& freq_stride [[buffer(11)]],
uint3 pos [[thread_position_in_grid]],
uint3 grid [[threads_per_grid]]) {
float inv_freq = 1.0 / (freqs[freq_stride * pos.x]);
@@ -186,61 +194,20 @@ template <typename T, bool traditional, bool forward, int N = 4>
scale,
strides,
out_strides,
n_batch,
offset_stride,
n_head,
pos,
grid);
}
// clang-format off
#define instantiate_rope_g(name, type, traditional, forward) \
template [[host_name("rope_" #name)]] [[kernel]] void \
rope<type, traditional, forward>( \
const device type* in [[buffer(0)]], \
device type* out [[buffer(1)]], \
constant const int& offset, \
constant const float& scale, \
constant const size_t strides[3], \
constant const size_t out_strides[3], \
constant const size_t& n_batch, \
constant const float& base [[buffer(10)]], \
uint3 pos [[thread_position_in_grid]], \
uint3 grid [[threads_per_grid]]); \
template [[host_name("rope_freqs_" #name)]] \
[[kernel]] void rope_freqs<type, traditional, forward>( \
const device type* in [[buffer(0)]], \
device type* out [[buffer(1)]], \
constant const int& offset, \
constant const float& scale, \
constant const size_t strides[3], \
constant const size_t out_strides[3], \
constant const size_t& n_batch, \
const device float* freqs [[buffer(10)]], \
constant const size_t& freq_stride [[buffer(11)]], \
uint3 pos [[thread_position_in_grid]], \
uint3 grid [[threads_per_grid]]);
instantiate_kernel("rope_" #name, rope, type, traditional, forward) \
instantiate_kernel("rope_freqs_" #name, rope_freqs, type, traditional, forward)
#define instantiate_rope_s(name, type, traditional, forward) \
template [[host_name("rope_single_" #name)]] [[kernel]] void \
rope_single<type, traditional, forward>( \
const device type* in [[buffer(0)]], \
device type* out [[buffer(1)]], \
constant const int& offset, \
constant const float& scale, \
constant const size_t& stride, \
constant const float& base [[buffer(10)]], \
uint2 pos [[thread_position_in_grid]], \
uint2 grid [[threads_per_grid]]); \
template [[host_name("rope_single_freqs_" #name)]] \
[[kernel]] void rope_single_freqs<type, traditional, forward>( \
const device type* in [[buffer(0)]], \
device type* out [[buffer(1)]], \
constant const int& offset, \
constant const float& scale, \
constant const size_t& stride, \
const device float* freqs [[buffer(10)]], \
constant const size_t& freq_stride [[buffer(11)]], \
uint2 pos [[thread_position_in_grid]], \
uint2 grid [[threads_per_grid]]);
#define instantiate_rope_s(name, type, traditional, forward) \
instantiate_kernel("rope_single_" #name, rope_single, type, traditional, forward) \
instantiate_kernel("rope_single_freqs_" #name, rope_single_freqs, type, traditional, forward)
#define instantiate_rope(name, type, traditional, forward) \
instantiate_rope_s(name, type, traditional, forward) \

View File

@@ -18,23 +18,29 @@ void RoPE::eval_gpu(
auto& in = inputs[0];
auto& out = outputs[0];
if (in.ndim() < 3) {
throw std::runtime_error("[RoPE] Input must have at least 3 dimensions");
}
auto& s = out.primitive().stream();
auto& d = metal::device(s.device);
size_t strides[3];
size_t out_strides[3];
int64_t strides[3];
int64_t out_strides[3];
bool donated = false;
int ndim = in.ndim();
int dispatch_ndim = in.ndim();
int B = in.shape(0);
int T = in.shape(-2);
int D = in.shape(-1);
size_t mat_size = T * D;
int dispatch_ndim = ndim;
while (in.shape(-dispatch_ndim) == 1 && dispatch_ndim > 3) {
dispatch_ndim--;
}
size_t mat_size = in.shape(-2) * in.shape(-1);
if (dims_ < in.shape(-1)) {
int N = 1;
for (int i = 1; i < (ndim - 2); ++i) {
N *= in.shape(i);
}
if (dims_ < D) {
donated = true;
auto ctype =
(in.flags().row_contiguous) ? CopyType::Vector : CopyType::General;
@@ -71,8 +77,8 @@ void RoPE::eval_gpu(
out_strides[1] = out.strides()[ndim - 2];
out_strides[2] = out.strides()[ndim - 1];
// Special case for inference (single time step and contiguous)
bool single = in.flags().row_contiguous && (mat_size == in.shape(-1));
// Special case for inference (single batch, single time step, and contiguous)
bool single = in.flags().row_contiguous && B == 1 && T == 1;
bool with_freqs = inputs.size() == 3;
std::ostringstream kname;
@@ -86,24 +92,29 @@ void RoPE::eval_gpu(
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(donated ? out : in, 0);
compute_encoder.set_output_array(out, 1);
compute_encoder.set_input_array(inputs[1], 2);
compute_encoder.set_bytes(scale_, 3);
size_t n_batch = in.size() / mat_size;
MTL::Size group_dims;
MTL::Size grid_dims;
if (single) {
compute_encoder.set_bytes(out_strides, 1, 4);
uint32_t dim0 = dims_ / 2;
group_dims = get_block_dims(dim0, n_batch, 1);
grid_dims = MTL::Size(dim0, n_batch, 1);
group_dims = get_block_dims(dim0, N, 1);
grid_dims = MTL::Size(dim0, N, 1);
} else {
compute_encoder.set_bytes(strides, 3, 4);
compute_encoder.set_bytes(out_strides, 3, 5);
compute_encoder.set_bytes(n_batch, 6);
int64_t offset_stride = 0;
if (inputs[1].ndim() > 0) {
offset_stride = inputs[1].strides()[0];
}
compute_encoder.set_bytes(offset_stride, 6);
compute_encoder.set_bytes(N, 7);
uint32_t dim0 = dims_ / 2;
uint32_t dim1 = in.shape(-2);
uint32_t dim2 = (n_batch + n_per_thread - 1) / n_per_thread;
uint32_t dim1 = T;
uint32_t dim2 = B * ((N + n_per_thread - 1) / n_per_thread);
group_dims = get_block_dims(dim0, dim1, dim2);
grid_dims = MTL::Size(dim0, dim1, dim2);
}

View File

@@ -394,7 +394,7 @@ void ScaledDotProductAttention::eval_gpu(
// Define some copy functions to ensure the layout of the inputs is as
// expected.
copies.reserve(3);
copies.reserve(inputs.size());
auto copy_unless = [&copies, &s](
auto predicate, const array& arr) -> const array& {
if (!predicate(arr)) {

View File

@@ -366,10 +366,16 @@ array rope(
msg << "[rope] Input must be a floating type but got " << x.dtype() << ".";
throw std::invalid_argument(msg.str());
}
if (offset.size() != 1) {
if (offset.ndim() > 1) {
std::ostringstream msg;
msg << "[rope] offset must be a scalar but has shape " << offset.shape()
<< ".";
msg << "[rope] offset must have at most one dimension but has shape "
<< offset.shape() << ".";
throw std::invalid_argument(msg.str());
}
if (offset.size() != 1 && offset.size() != x.shape(0)) {
std::ostringstream msg;
msg << "[rope] offset must be a scalar or vector with " << x.shape(0)
<< " elements but has shape " << offset.shape() << ".";
throw std::invalid_argument(msg.str());
}
if (!issubdtype(offset.dtype(), integer)) {
@@ -379,7 +385,7 @@ array rope(
throw std::invalid_argument(msg.str());
}
if (offset.dtype().size() != 4) {
inputs[1] = astype(offset, uint32, s);
inputs[1] = astype(offset, int32, s);
}
if (inputs.size() == 3 &&
(inputs[2].ndim() != 1 || inputs[2].shape(0) != dims / 2)) {
@@ -391,15 +397,26 @@ array rope(
auto fallback = [dims, traditional, base, scale, forward, s](
std::vector<array> inputs) {
auto& shape = inputs[0].shape();
int ndim = shape.size();
auto x = flatten(inputs[0], 0, ndim - 3, s);
auto x = inputs[0];
auto shape = x.shape();
if (x.ndim() == 3) {
x = expand_dims(x, 1, s);
} else if (x.ndim() > 4) {
x = flatten(x, 1, 1 + (x.ndim() - 4), s);
}
auto B = x.shape(0);
auto N = x.shape(1);
auto T = x.shape(2);
auto t = x.dtype();
// Compute sines and cosines
auto half_dims = dims / 2;
auto& offset = inputs[1];
auto offset = inputs[1];
if (offset.size() > 1) {
offset = expand_dims(offset, {-1, -2}, s);
}
auto positions =
multiply(add(arange(x.shape(1), t, s), offset, s), array(scale, t), s);
multiply(add(arange(x.shape(2), t, s), offset, s), array(scale, t), s);
auto default_inv_freqs = [&inputs, &s, &t, base, half_dims]() {
return exp(
@@ -412,8 +429,7 @@ array rope(
auto inv_freqs = inputs.size() == 3 ? astype(reciprocal(inputs[2], s), t, s)
: default_inv_freqs();
auto theta =
multiply(expand_dims(positions, 1, s), expand_dims(inv_freqs, 0, s), s);
auto theta = multiply(expand_dims(positions, -1, s), inv_freqs, s);
auto coss = cos(theta, s);
auto sins = sin(theta, s);
@@ -436,32 +452,30 @@ array rope(
};
if (traditional) {
auto x1 =
slice(x, {0, 0, 0}, {x.shape(0), x.shape(1), dims}, {1, 1, 2}, s);
auto x2 =
slice(x, {0, 0, 1}, {x.shape(0), x.shape(1), dims}, {1, 1, 2}, s);
auto x1 = slice(x, {0, 0, 0, 0}, {B, N, T, dims}, {1, 1, 1, 2}, s);
auto x2 = slice(x, {0, 0, 0, 1}, {B, N, T, dims}, {1, 1, 1, 2}, s);
auto outs = apply_rope(x1, x2, coss, sins);
for (auto& o : outs) {
o = expand_dims(o, 3, s);
o = expand_dims(o, -1, s);
}
auto out = concatenate(outs, 3, s);
auto out = reshape(concatenate(outs, -1, s), {B, N, T, dims}, s);
if (dims < x.shape(-1)) {
out = reshape(out, {x.shape(0), x.shape(1), dims});
out = concatenate({out, slice(x, {0, 0, dims}, x.shape(), s)}, 2, s);
out =
concatenate({out, slice(x, {0, 0, 0, dims}, x.shape(), s)}, -1, s);
}
return std::vector<array>{reshape(out, shape, s)};
} else {
auto out_s = x.shape();
out_s.back() = half_dims;
auto x1 = slice(x, {0, 0, 0}, out_s, s);
auto x1 = slice(x, {0, 0, 0, 0}, out_s, s);
out_s.back() = dims;
auto x2 = slice(x, {0, 0, half_dims}, out_s, s);
auto x2 = slice(x, {0, 0, 0, half_dims}, out_s, s);
auto outs = apply_rope(x1, x2, coss, sins);
if (dims < x.shape(-1)) {
outs.push_back(slice(x, {0, 0, dims}, x.shape(), s));
outs.push_back(slice(x, {0, 0, 0, dims}, x.shape(), s));
}
return std::vector<array>{reshape(concatenate(outs, 2, s), shape, s)};
return std::vector<array>{reshape(concatenate(outs, -1, s), shape, s)};
}
};
auto stream = to_stream(s);

View File

@@ -1,10 +1,8 @@
// Copyright © 2025 Apple Inc.
#include <string>
namespace mlx::core {
std::string version() {
const char* version() {
return MLX_VERSION;
}

View File

@@ -15,6 +15,6 @@ namespace mlx::core {
*
* For dev builds, the version will include the suffix ".devYYYYMMDD+hash"
*/
std::string version();
const char* version();
} // namespace mlx::core

View File

@@ -1,20 +1,34 @@
mlx.core.__prefix__:
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
import sys
if sys.version_info >= (3, 10):
from typing import TypeAlias
else:
from typing_extensions import TypeAlias
mlx.core.__suffix__:
from typing import Union
scalar: TypeAlias = Union[int, float, bool]
list_or_scalar: TypeAlias = Union[scalar, list["list_or_scalar"]]
bool_: Dtype = ...
mlx.core.distributed.__prefix__:
from mlx.core import array, Dtype, Device, Stream
from mlx.core import array, Dtype, Device, Stream, scalar
from mlx.core.distributed import Group
from typing import Sequence, Optional, Union
mlx.core.fast.__prefix__:
from mlx.core import array, Dtype, Device, Stream
from mlx.core import array, Dtype, Device, Stream, scalar
from typing import Sequence, Optional, Union
mlx.core.linalg.__prefix__:
from mlx.core import array, Dtype, Device, Stream
from mlx.core import array, Dtype, Device, Stream, scalar
from typing import Sequence, Optional, Tuple, Union
mlx.core.metal.__prefix__:
from mlx.core import array, Dtype, Device, Stream
from mlx.core import array, Dtype, Device, Stream, scalar
from typing import Sequence, Optional, Union
mlx.core.random.__prefix__:
from mlx.core import array, Dtype, Device, Stream
from mlx.core import array, Dtype, Device, Stream, scalar, float32, int32
from typing import Sequence, Optional, Union

View File

@@ -556,7 +556,7 @@ class AdamW(Adam):
eps (float, optional): The term :math:`\epsilon` added to the
denominator to improve numerical stability. Default: ``1e-8``
weight_decay (float, optional): The weight decay :math:`\lambda`.
Default: ``0``.
Default: ``0.01``.
bias_correction (bool, optional): If set to ``True``, bias correction
is applied. Default: ``False``
"""

View File

@@ -52,7 +52,6 @@ set_target_properties(
${MLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY})
target_link_libraries(core PRIVATE mlx)
target_compile_definitions(core PRIVATE _VERSION_=${MLX_VERSION})
if(BUILD_SHARED_LIBS)
if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")

View File

@@ -320,6 +320,7 @@ void init_array(nb::module_& m) {
.def_prop_ro(
"shape",
[](const mx::array& a) { return nb::cast(a.shape()); },
nb::sig("def shape(self) -> tuple[int, ...]"),
R"pbdoc(
The shape of the array as a Python tuple.
@@ -347,6 +348,7 @@ void init_array(nb::module_& m) {
.def(
"item",
&to_scalar,
nb::sig("def item(self) -> scalar"),
R"pbdoc(
Access the value of a scalar array.
@@ -356,6 +358,7 @@ void init_array(nb::module_& m) {
.def(
"tolist",
&tolist,
nb::sig("def tolist(self) -> list_or_scalar"),
R"pbdoc(
Convert the array to a Python :class:`list`.

View File

@@ -164,8 +164,13 @@ void init_fast(nb::module_& parent_module) {
R"pbdoc(
Apply rotary positional encoding to the input.
The input is expected to be at least 3D with shape ``(B, *, T, D)`` where:
* ``B`` is the batch size.
* ``T`` is the sequence length.
* ``D`` is the feature dimension.
Args:
a (array): Input array.
a (array): The input array.
dims (int): The feature dimensions to be rotated. If the input feature
is larger than dims then the rest is left unchanged.
traditional (bool): If set to ``True`` choose the traditional
@@ -174,7 +179,9 @@ void init_fast(nb::module_& parent_module) {
each dimension in the positional encodings. Exactly one of ``base`` and
``freqs`` must be ``None``.
scale (float): The scale used to scale the positions.
offset (int or array): The position offset to start at.
offset (int or array): The position offset to start at. If an
:obj:`array` is given it can be a scalar or vector of ``B``
offsets for each example in the batch.
freqs (array, optional): Optional frequencies to use with RoPE.
If set, the ``base`` parameter must be ``None``. Default: ``None``.

View File

@@ -447,6 +447,8 @@ void init_linalg(nb::module_& parent_module) {
"a"_a,
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def eig(a: array, *, stream: Union[None, Stream, Device] = None) -> Tuple[array, array]"),
R"pbdoc(
Compute the eigenvalues and eigenvectors of a square matrix.
@@ -523,6 +525,8 @@ void init_linalg(nb::module_& parent_module) {
"UPLO"_a = "L",
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def eigh(a: array, UPLO: str = 'L', *, stream: Union[None, Stream, Device] = None) -> Tuple[array, array]"),
R"pbdoc(
Compute the eigenvalues and eigenvectors of a complex Hermitian or
real symmetric matrix.

View File

@@ -2,9 +2,9 @@
#include <nanobind/nanobind.h>
#define STRINGIFY(x) #x
#define TOSTRING(x) STRINGIFY(x)
#include "mlx/version.h"
namespace mx = mlx::core;
namespace nb = nanobind;
void init_mlx_func(nb::module_&);
@@ -48,5 +48,5 @@ NB_MODULE(core, m) {
init_distributed(m);
init_export(m);
m.attr("__version__") = TOSTRING(_VERSION_);
m.attr("__version__") = mx::version();
}

View File

@@ -4271,7 +4271,7 @@ void init_ops(nb::module_& m) {
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def dequantize(w: array, /, scales: array, biases: Optional[array] = = None, group_size: int = 64, bits: int = 4, mode: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> array"),
"def dequantize(w: array, /, scales: array, biases: Optional[array] = None, group_size: int = 64, bits: int = 4, mode: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Dequantize the matrix ``w`` using quantization parameters.

View File

@@ -171,7 +171,7 @@ void init_random(nb::module_& parent_module) {
"key"_a = nb::none(),
"stream"_a = nb::none(),
nb::sig(
"def normal(shape: Sequence[int] = [], dtype: Optional[Dtype] = float32, loc: Optional[scalar, array] = None, scale: Optional[scalar, array] = None, key: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
"def normal(shape: Sequence[int] = [], dtype: Optional[Dtype] = float32, loc: Union[scalar, array, None] = None, scale: Union[scalar, array, None] = None, key: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Generate normally distributed random numbers.

View File

@@ -91,7 +91,7 @@ mx::array to_array_with_accessor(nb::object obj) {
return nb::cast<mx::array>(obj.attr("__mlx_array__")());
} else {
std::ostringstream msg;
msg << "Invalid type " << nb::type_name(obj.type()).c_str()
msg << "Invalid type " << nb::type_name(obj.type()).c_str()
<< " received in array initialization.";
throw std::invalid_argument(msg.str());
}

View File

@@ -594,124 +594,123 @@ class TestBlas(mlx_tests.MLXTestCase):
np.random.seed(0)
# Batched matmul
alpha = 0.5
beta = 2.0
for beta in (1.0, 2.0):
# c must broadcast to the output shape
with self.assertRaises(ValueError):
mx.addmm(mx.zeros((2, 2, 2)), mx.zeros((2, 2)), mx.zeros((2, 2)))
# c must broadcast to the output shape
with self.assertRaises(ValueError):
mx.addmm(mx.zeros((2, 2, 2)), mx.zeros((2, 2)), mx.zeros((2, 2)))
# Regular batched case
a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
b_npy = np.random.normal(0.0, 1.0 / 128, (32, 16, 16)).astype(np.float32)
# Regular batched case
a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
b_npy = np.random.normal(0.0, 1.0 / 128, (32, 16, 16)).astype(np.float32)
a_mlx = mx.array(a_npy)
b_mlx = mx.array(b_npy)
a_mlx = mx.array(a_npy)
b_mlx = mx.array(b_npy)
for c_shape in ((1,), (1, 16), (32, 1, 16), (1, 128, 16)):
c_npy = np.ones(c_shape).astype(np.float32)
c_mlx = mx.array(c_npy)
for c_shape in ((1,), (1, 16), (32, 1, 16), (1, 128, 16)):
c_npy = np.ones(c_shape).astype(np.float32)
c_mlx = mx.array(c_npy)
d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)
d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)
self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
# Batched and transposed matmul
b_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
b_mlx = mx.array(b_npy)
# Batched and transposed matmul
b_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
b_mlx = mx.array(b_npy)
for c_shape in ((1,), (32, 1, 128), (1, 128)):
c_npy = np.ones(c_shape).astype(np.float32)
c_mlx = mx.array(c_npy)
for c_shape in ((1,), (32, 1, 128), (1, 128)):
c_npy = np.ones(c_shape).astype(np.float32)
c_mlx = mx.array(c_npy)
b_np_t = np.transpose(b_npy, (0, 2, 1))
b_mx_t = mx.transpose(b_mlx, (0, 2, 1))
b_np_t = np.transpose(b_npy, (0, 2, 1))
b_mx_t = mx.transpose(b_mlx, (0, 2, 1))
d_npy = alpha * (a_npy @ b_np_t) + beta * c_npy
d_mlx = mx.addmm(c_mlx, a_mlx, b_mx_t, alpha, beta)
d_npy = alpha * (a_npy @ b_np_t) + beta * c_npy
d_mlx = mx.addmm(c_mlx, a_mlx, b_mx_t, alpha, beta)
self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
# Batched matmul with simple broadcast
a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
b_npy = np.random.normal(0.0, 1.0 / 128, (16, 16)).astype(np.float32)
self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
# Batched matmul with simple broadcast
a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
b_npy = np.random.normal(0.0, 1.0 / 128, (16, 16)).astype(np.float32)
a_mlx = mx.array(a_npy)
b_mlx = mx.array(b_npy)
a_mlx = mx.array(a_npy)
b_mlx = mx.array(b_npy)
for c_shape in ((1,), (1, 16), (32, 1, 16), (1, 128, 16)):
c_npy = np.ones(c_shape).astype(np.float32)
c_mlx = mx.array(c_npy)
for c_shape in ((1,), (1, 16), (32, 1, 16), (1, 128, 16)):
c_npy = np.ones(c_shape).astype(np.float32)
c_mlx = mx.array(c_npy)
d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)
d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)
self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
# Matmul with vector
a_npy = np.random.normal(0.0, 1.0 / 128, (16,)).astype(np.float32)
b_npy = np.random.normal(0.0, 1.0 / 128, (32, 16, 128)).astype(np.float32)
a_mlx = mx.array(a_npy)
b_mlx = mx.array(b_npy)
self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
# Matmul with vector
a_npy = np.random.normal(0.0, 1.0 / 128, (16,)).astype(np.float32)
b_npy = np.random.normal(0.0, 1.0 / 128, (32, 16, 128)).astype(np.float32)
a_mlx = mx.array(a_npy)
b_mlx = mx.array(b_npy)
for c_shape in ((1,), (128,), (32, 128)):
c_npy = np.ones(c_shape).astype(np.float32)
c_mlx = mx.array(c_npy)
for c_shape in ((1,), (128,), (32, 128)):
c_npy = np.ones(c_shape).astype(np.float32)
c_mlx = mx.array(c_npy)
d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)
d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)
self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
# Matmul with vector
a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
b_npy = np.random.normal(0.0, 1.0 / 128, (16,)).astype(np.float32)
a_mlx = mx.array(a_npy)
b_mlx = mx.array(b_npy)
# Matmul with vector
a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
b_npy = np.random.normal(0.0, 1.0 / 128, (16,)).astype(np.float32)
a_mlx = mx.array(a_npy)
b_mlx = mx.array(b_npy)
for c_shape in ((1,), (32, 128)):
c_npy = np.ones(c_shape).astype(np.float32)
c_mlx = mx.array(c_npy)
for c_shape in ((1,), (32, 128)):
c_npy = np.ones(c_shape).astype(np.float32)
c_mlx = mx.array(c_npy)
d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)
d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)
self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
# Split K specializtion
a_npy = np.random.normal(0.0, 1.0 / 128, (64, 4096)).astype(np.float32)
b_npy = np.random.normal(0.0, 1.0 / 128, (4096, 32)).astype(np.float32)
# Split K specializtion
a_npy = np.random.normal(0.0, 1.0 / 128, (64, 4096)).astype(np.float32)
b_npy = np.random.normal(0.0, 1.0 / 128, (4096, 32)).astype(np.float32)
a_mlx = mx.array(a_npy)
b_mlx = mx.array(b_npy)
a_mlx = mx.array(a_npy)
b_mlx = mx.array(b_npy)
for c_shape in ((1,), (1, 32), (64, 1), (64, 32)):
c_npy = np.ones(c_shape).astype(np.float32)
c_mlx = mx.array(c_npy)
for c_shape in ((1,), (1, 32), (64, 1), (64, 32)):
c_npy = np.ones(c_shape).astype(np.float32)
c_mlx = mx.array(c_npy)
d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)
d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)
self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
# Transposed c
a = mx.ones((10, 5)).T
b = mx.ones((5, 5))
out = mx.addmm(a, b, a, beta=beta, alpha=alpha)
expected = beta * a + alpha * (b @ a)
self.assertTrue(mx.allclose(expected, out))
# Transposed c
a = mx.ones((10, 5)).T
b = mx.ones((5, 5))
out = mx.addmm(a, b, a, beta=1.5, alpha=0.5)
expected = 1.5 * a + 0.5 * (b @ a)
self.assertTrue(mx.allclose(expected, out))
# Broadcast c
a = mx.ones((5, 5))
b = mx.ones((5, 5))
c = mx.ones((1, 5))
out = mx.addmm(c, a, b, beta=1.5, alpha=0.5)
expected = 1.5 * c + 0.5 * (a @ b)
self.assertTrue(mx.allclose(expected, out))
# Broadcast c
a = mx.ones((5, 5))
b = mx.ones((5, 5))
c = mx.ones((1, 5))
out = mx.addmm(c, a, b, beta=beta, alpha=alpha)
expected = beta * c + alpha * (a @ b)
self.assertTrue(mx.allclose(expected, out))
def test_addmm_grad(self):
def make_ref_addmm(alpha, beta):
@@ -724,33 +723,32 @@ class TestBlas(mlx_tests.MLXTestCase):
shapes = ((1, 64, 32, 128), (4, 28, 24, 47), (1, 1, 24, 47))
alpha = 2.0
beta = 0.5
for beta in (1.0, 0.5):
f_test = make_addmm(alpha, beta)
f_ref = make_ref_addmm(alpha, beta)
f_test = make_addmm(alpha, beta)
f_ref = make_ref_addmm(alpha, beta)
for B, M, N, K in shapes:
cotan = mx.ones((B, M, N))
c = mx.random.normal((B, M, N))
a = mx.random.normal((B, M, K))
b = mx.random.normal((B, K, N))
for B, M, N, K in shapes:
cotan = mx.ones((B, M, N))
c = mx.random.normal((B, M, N))
a = mx.random.normal((B, M, K))
b = mx.random.normal((B, K, N))
out_ref, dout_ref = mx.vjp(
f_ref,
[c, a, b],
[cotan],
)
out_test, dout_test = mx.vjp(
f_test,
[c, a, b],
[cotan],
)
out_ref, dout_ref = mx.vjp(
f_ref,
[c, a, b],
[cotan],
)
out_test, dout_test = mx.vjp(
f_test,
[c, a, b],
[cotan],
)
self.assertTrue(mx.allclose(out_ref[0], out_test[0], atol=1e-4).item())
self.assertTrue(mx.allclose(out_ref[0], out_test[0], atol=1e-4).item())
for r, t in zip(dout_ref, dout_test):
self.assertEqual(r.shape, t.shape)
self.assertTrue(mx.allclose(r, t, atol=1e-4).item())
for r, t in zip(dout_ref, dout_test):
self.assertEqual(r.shape, t.shape)
self.assertTrue(mx.allclose(r, t, atol=1e-4).item())
def test_empty_matmul(self):
a = mx.array([[], []]).T

View File

@@ -8,18 +8,23 @@ import mlx_tests
def rope_orig(x, dims, traditional, base, scale, offset, freqs=None):
offset = offset.item() if isinstance(offset, mx.array) else offset
N = x.shape[-2] + offset
N = x.shape[-2]
dtype = x.dtype
half_D = dims // 2
positions = mx.arange(offset, N, dtype=dtype) * scale
positions = mx.arange(N, dtype=dtype)
if isinstance(offset, mx.array) and offset.size > 1:
expand = tuple(range(1, x.ndim - 1))
positions = mx.expand_dims(offset, expand) + positions
else:
positions = offset + positions
positions = positions * scale
if freqs is None:
inv_freqs = mx.exp(
-mx.arange(0.0, half_D, dtype=dtype) * (math.log(base) / half_D)
)
else:
inv_freqs = (1 / freqs).astype(x.dtype)
theta = mx.reshape(positions, (-1, 1)) * mx.reshape(inv_freqs, (1, -1))
theta = mx.expand_dims(positions, -1) * inv_freqs
costheta, sintheta = mx.cos(theta), mx.sin(theta)
if traditional:
x1 = x[..., :dims:2]
@@ -214,6 +219,7 @@ class TestFast(mlx_tests.MLXTestCase):
)
self.assertEqual(dtype, rx.dtype)
self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
return
# Test single vector
x = mx.random.uniform(shape=(1, 1, dims))
@@ -277,6 +283,55 @@ class TestFast(mlx_tests.MLXTestCase):
g2 = mx.grad(f2)(x, y)
self.assertLess(mx.abs(g1 - g2).max(), 1e-5)
def test_rope_batch(self):
T = 4
base = 10000.0
scale = 1.0
traditional = True
batch_sizes = [3, 8, 11]
num_heads = [1, 3, 5]
dims = 32
x = mx.random.uniform(shape=(8, 4, T, dims))
offset = mx.array([1, 2, 3])
with self.assertRaises(ValueError):
mx.fast.rope(
x,
dims,
traditional=traditional,
base=base,
scale=scale,
offset=offset,
)
for batch_size in batch_sizes:
for n_head in num_heads:
x = mx.random.uniform(shape=(batch_size, n_head, T, dims))
offset = mx.arange(batch_size)
rx = rope_orig(x, dims, traditional, base, scale, offset)
rx_fast = mx.fast.rope(
x,
dims,
traditional=traditional,
base=base,
scale=scale,
offset=offset,
)
self.assertLess(mx.abs(rx - rx_fast).max(), 1e-5)
x = mx.random.normal(shape=(2, 6, 8, 64)).transpose(0, 2, 1, 3)
dims = 64
offset = 0
rx_fast = mx.fast.rope(
x, dims, traditional=traditional, scale=scale, base=base, offset=offset
)
rx_fast_single = mx.fast.rope(
x[0:1], dims, traditional=traditional, scale=scale, base=base, offset=offset
)
rx = rope_orig(x, dims, traditional, base, scale, offset)
self.assertLess(mx.abs(rx - rx_fast).max(), 1e-5)
def test_rms_norm(self):
# Per dtype absolute tolerance
tolerances = {mx.float32: 1e-6, mx.float16: 1e-3, mx.bfloat16: 1e-2}

View File

@@ -619,6 +619,17 @@ class TestSDPA(mlx_tests.MLXTestCase):
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask)
self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
def test_sdpa_noncontiguous_inputs(self):
mask = mx.ones(shape=(4, 1, 7, 7), dtype=mx.bool_)
mx.random.seed(0)
q = mx.random.normal(shape=(4, 7, 32, 64)).swapaxes(1, 2)
k = mx.random.normal(shape=(4, 7, 8, 64)).swapaxes(1, 2)
v = mx.random.normal(shape=(4, 7, 8, 64)).swapaxes(1, 2)
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=mask)
ref = mlx_ref_attn(q, k, v, scale=1.0, mask=mask)
self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
def test_sdpa_promote_mask(self):
mask = mx.array(2.0, mx.bfloat16)
D = 64

View File

@@ -121,7 +121,8 @@ class CMakeBuild(build_ext):
build_args += [f"-j{os.cpu_count()}"]
# Avoid cache miss when building from temporary dirs.
os.environ["CCACHE_BASEDIR"] = os.path.abspath(self.build_temp)
os.environ["CCACHE_BASEDIR"] = os.path.realpath(self.build_temp)
os.environ["CCACHE_NOHASHDIR"] = "true"
subprocess.run(
["cmake", ext.sourcedir, *cmake_args], cwd=build_temp, check=True
@@ -176,10 +177,6 @@ class GenerateStubs(Command):
# Run again without recursive to specify output file name
subprocess.run(["rm", f"{out_path}/mlx.pyi"])
subprocess.run(stub_cmd + ["-o", f"{out_path}/__init__.pyi"])
# mx.bool_ gets filtered by nanobind because of the trailing
# underscore, add it manually:
with open(f"{out_path}/__init__.pyi", "a") as fid:
fid.write("\nbool_: Dtype = ...")
class MLXBdistWheel(bdist_wheel):

View File

@@ -1,10 +1,7 @@
# Doctest works fine with cmake 3.5
set(CMAKE_POLICY_VERSION_MINIMUM 3.5)
FetchContent_Declare(
doctest
GIT_REPOSITORY "https://github.com/onqtam/doctest"
GIT_TAG "ae7a13539fb71f270b87eb2e874fbac80bc8dda2")
GIT_REPOSITORY https://github.com/onqtam/doctest.git
GIT_TAG v2.4.12)
FetchContent_MakeAvailable(doctest)
add_executable(tests ${PROJECT_SOURCE_DIR}/tests/tests.cpp)