Compare commits

..

6 Commits

Author SHA1 Message Date
Cheng
44cc5da4bc [CUDA] Fix alpha not respected when using bias epilogue (#2578) 2025-09-10 09:08:01 +09:00
Cheng
dde3682b69 [CUDA] Use GEMM with epilogue instead of AddMM (#2569) 2025-09-09 13:18:49 +09:00
Awni Hannun
17310d91a6 Add batch offsets for mx.fast.rope (#2564)
* implement batch rope for Metal

* cuda rope (#2576)
2025-09-08 17:35:07 -07:00
Cheng
b194d65a6a Some tweaks in cmake files (#2574)
* Do proper check of Metal lib

* Update doctest to get rid of cmake version hack
2025-09-09 08:27:18 +09:00
Cheng
a44b27f5f8 Fix a few ccache cache miss (#2573)
* Fix ccache cache miss

* Do not define _VERSION_ in python bindings
2025-09-09 07:41:05 +09:00
Awni Hannun
e5a33f2223 faster depthwise 1D conv (#2567) 2025-09-08 11:37:23 -07:00
22 changed files with 689 additions and 412 deletions

View File

@@ -87,22 +87,21 @@ cmake_policy(SET CMP0135 NEW)
add_library(mlx) add_library(mlx)
if(MLX_BUILD_METAL)
set(METAL_LIB "-framework Metal")
set(FOUNDATION_LIB "-framework Foundation")
set(QUARTZ_LIB "-framework QuartzCore")
endif()
if(MLX_BUILD_CUDA) if(MLX_BUILD_CUDA)
enable_language(CUDA) enable_language(CUDA)
endif() endif()
if(MLX_BUILD_METAL AND NOT METAL_LIB) if(MLX_BUILD_METAL)
message(STATUS "Metal not found. Unable to build GPU") find_library(METAL_LIB Metal)
set(MLX_BUILD_METAL OFF) find_library(FOUNDATION_LIB Foundation)
set(MLX_METAL_DEBUG OFF) find_library(QUARTZ_LIB QuartzCore)
elseif(MLX_BUILD_METAL) if(METAL_LIB)
message(STATUS "Building METAL sources") message(STATUS "Metal found ${METAL_LIB}")
else()
message(
FATAL_ERROR
"Metal not found. Set MLX_BUILD_METAL=OFF to build without GPU")
endif()
if(MLX_METAL_DEBUG) if(MLX_METAL_DEBUG)
add_compile_definitions(MLX_METAL_DEBUG) add_compile_definitions(MLX_METAL_DEBUG)
@@ -111,7 +110,8 @@ elseif(MLX_BUILD_METAL)
# Throw an error if xcrun not found # Throw an error if xcrun not found
execute_process( execute_process(
COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version" COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version"
OUTPUT_VARIABLE MACOS_SDK_VERSION COMMAND_ERROR_IS_FATAL ANY) OUTPUT_VARIABLE MACOS_SDK_VERSION
OUTPUT_STRIP_TRAILING_WHITESPACE COMMAND_ERROR_IS_FATAL ANY)
if(${MACOS_SDK_VERSION} LESS 14.0) if(${MACOS_SDK_VERSION} LESS 14.0)
message( message(

View File

@@ -85,10 +85,10 @@ cublasLtMatrixLayout_t create_matrix_layout(
int32_t batch_count, int32_t batch_count,
int64_t batch_stride) { int64_t batch_stride) {
cublasLtMatrixLayout_t desc; cublasLtMatrixLayout_t desc;
if (transposed) {
std::swap(rows, cols);
}
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate(&desc, type, rows, cols, ld)); CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate(&desc, type, rows, cols, ld));
cublasLtOrder_t order = transposed ? CUBLASLT_ORDER_COL : CUBLASLT_ORDER_ROW;
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
desc, CUBLASLT_MATRIX_LAYOUT_ORDER, &order, sizeof(cublasLtOrder_t)));
if (batch_count > 1) { if (batch_count > 1) {
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute( CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
desc, desc,
@@ -138,25 +138,34 @@ CublasGemm::CublasGemm(
CUBLASLT_MATMUL_DESC_POINTER_MODE, CUBLASLT_MATMUL_DESC_POINTER_MODE,
&pointer_mode, &pointer_mode,
sizeof(int32_t))); sizeof(int32_t)));
cublasOperation_t op = CUBLAS_OP_N;
// In cublasLt matrices use column-major layout, while it is possible to use
// the CUBLASLT_ORDER_ROW option to switch to row-major layout, the bias
// epilogue does not work with the option. So instead we swap A and B to make
// cublasLt return the row-major result, which works because:
// - the data of a matrix in row-major layout is identical to its transpose in
// column-major layout
// - C^T = (A @ B)^T = B^T @ A^T
cublasOperation_t a_op = b_transposed ? CUBLAS_OP_T : CUBLAS_OP_N;
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute( CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
matmul_desc_, matmul_desc_,
CUBLASLT_MATMUL_DESC_TRANSA, CUBLASLT_MATMUL_DESC_TRANSA,
&op, &a_op,
sizeof(cublasOperation_t))); sizeof(cublasOperation_t)));
cublasOperation_t b_op = a_transposed ? CUBLAS_OP_T : CUBLAS_OP_N;
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute( CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
matmul_desc_, matmul_desc_,
CUBLASLT_MATMUL_DESC_TRANSB, CUBLASLT_MATMUL_DESC_TRANSB,
&op, &b_op,
sizeof(cublasOperation_t))); sizeof(cublasOperation_t)));
auto type = dtype_to_cublas_type(dtype); auto type = dtype_to_cublas_type(dtype);
a_desc_ = create_matrix_layout( a_desc_ = create_matrix_layout(
type, a_rows, a_cols, a_transposed, lda, batch_count, a_batch_stride); type, b_cols, b_rows, b_transposed, ldb, batch_count, b_batch_stride);
b_desc_ = create_matrix_layout( b_desc_ = create_matrix_layout(
type, b_rows, b_cols, b_transposed, ldb, batch_count, b_batch_stride); type, a_cols, a_rows, a_transposed, lda, batch_count, a_batch_stride);
out_desc_ = create_matrix_layout( out_desc_ = create_matrix_layout(
type, a_rows, b_cols, false, b_cols, batch_count, a_rows * b_cols); type, b_cols, a_rows, false, b_cols, batch_count, a_rows * b_cols);
} }
CublasGemm::CublasGemm( CublasGemm::CublasGemm(
@@ -191,7 +200,7 @@ CublasGemm::CublasGemm(
b_batch_stride) { b_batch_stride) {
auto type = dtype_to_cublas_type(dtype); auto type = dtype_to_cublas_type(dtype);
c_desc_ = create_matrix_layout( c_desc_ = create_matrix_layout(
type, a_rows, b_cols, false, ldc, batch_count, c_batch_stride); type, b_cols, a_rows, false, ldc, batch_count, c_batch_stride);
} }
CublasGemm::~CublasGemm() { CublasGemm::~CublasGemm() {
@@ -213,14 +222,25 @@ void CublasGemm::set_out(
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(out_desc_)); CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(out_desc_));
out_desc_ = create_matrix_layout( out_desc_ = create_matrix_layout(
dtype_to_cublas_type(dtype), dtype_to_cublas_type(dtype),
rows,
cols, cols,
rows,
transposed, transposed,
ld, ld,
batch_count, batch_count,
batch_stride); batch_stride);
} }
void CublasGemm::set_bias(void* bias) {
cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_BIAS;
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
matmul_desc_,
CUBLASLT_MATMUL_DESC_EPILOGUE,
&epilogue,
sizeof(epilogue)));
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
matmul_desc_, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias)));
}
void CublasGemm::run( void CublasGemm::run(
cu::CommandEncoder& encoder, cu::CommandEncoder& encoder,
array& out, array& out,
@@ -228,11 +248,19 @@ void CublasGemm::run(
const array& b, const array& b,
const Shape& batch_shape, const Shape& batch_shape,
const Strides& a_batch_strides, const Strides& a_batch_strides,
const Strides& b_batch_strides) { const Strides& b_batch_strides,
float alpha) {
int batch_count = out.size() / (M_ * N_); int batch_count = out.size() / (M_ * N_);
if (batch_count / batch_shape.back() > 1) { if (batch_count / batch_shape.back() > 1) {
run_batched( run_batched(
encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides); encoder,
out,
a,
b,
batch_shape,
a_batch_strides,
b_batch_strides,
alpha);
return; return;
} }
@@ -240,7 +268,13 @@ void CublasGemm::run(
encoder.set_input_array(b); encoder.set_input_array(b);
encoder.set_output_array(out); encoder.set_output_array(out);
execute(encoder, out.data<void>(), a.data<void>(), b.data<void>(), nullptr); execute(
encoder,
out.data<void>(),
a.data<void>(),
b.data<void>(),
nullptr,
alpha);
} }
void CublasGemm::run( void CublasGemm::run(
@@ -330,9 +364,9 @@ void CublasGemm::execute(
handle_, handle_,
matmul_desc_, matmul_desc_,
&alpha, &alpha,
a, b, // a and b are swapped
a_desc_, a_desc_,
b, a,
b_desc_, b_desc_,
&beta, &beta,
c ? c : out, c ? c : out,

View File

@@ -55,6 +55,8 @@ class CublasGemm {
int32_t batch_count, int32_t batch_count,
int64_t batch_stride); int64_t batch_stride);
void set_bias(void* bias);
void run( void run(
cu::CommandEncoder& encoder, cu::CommandEncoder& encoder,
array& out, array& out,
@@ -62,7 +64,8 @@ class CublasGemm {
const array& b, const array& b,
const Shape& batch_shape, const Shape& batch_shape,
const Strides& a_batch_strides, const Strides& a_batch_strides,
const Strides& b_batch_strides); const Strides& b_batch_strides,
float alpha = 1.0f);
void run( void run(
cu::CommandEncoder& encoder, cu::CommandEncoder& encoder,
@@ -85,7 +88,8 @@ class CublasGemm {
const array& b, const array& b,
const Shape& batch_shape, const Shape& batch_shape,
const Strides& a_batch_strides, const Strides& a_batch_strides,
const Strides& b_batch_strides); const Strides& b_batch_strides,
float alpha);
void run_batched( void run_batched(
cu::CommandEncoder& encoder, cu::CommandEncoder& encoder,

View File

@@ -13,7 +13,8 @@ void CublasGemm::run_batched(
const array& b, const array& b,
const Shape& batch_shape, const Shape& batch_shape,
const Strides& a_batch_strides, const Strides& a_batch_strides,
const Strides& b_batch_strides) { const Strides& b_batch_strides,
float alpha) {
encoder.set_input_array(a); encoder.set_input_array(a);
encoder.set_input_array(b); encoder.set_input_array(b);
encoder.set_output_array(out); encoder.set_output_array(out);
@@ -27,7 +28,8 @@ void CublasGemm::run_batched(
out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M_ * N_, out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M_ * N_,
a.data<int8_t>() + a.itemsize() * a_it.loc, a.data<int8_t>() + a.itemsize() * a_it.loc,
b.data<int8_t>() + b.itemsize() * b_it.loc, b.data<int8_t>() + b.itemsize() * b_it.loc,
nullptr); nullptr,
alpha);
a_it.step(); a_it.step();
b_it.step(); b_it.step();
} }

View File

@@ -154,7 +154,8 @@ void CublasGemm::run_batched(
const array& b, const array& b,
const Shape& batch_shape, const Shape& batch_shape,
const Strides& a_batch_strides, const Strides& a_batch_strides,
const Strides& b_batch_strides) { const Strides& b_batch_strides,
float alpha) {
int batch_count = out.size() / (M_ * N_); int batch_count = out.size() / (M_ * N_);
set_pointer_mode(a_desc_, batch_count); set_pointer_mode(a_desc_, batch_count);
set_pointer_mode(b_desc_, batch_count); set_pointer_mode(b_desc_, batch_count);
@@ -226,7 +227,8 @@ void CublasGemm::run_batched(
reinterpret_cast<void*>(out_pointers), reinterpret_cast<void*>(out_pointers),
reinterpret_cast<void*>(a_pointers), reinterpret_cast<void*>(a_pointers),
reinterpret_cast<void*>(b_pointers), reinterpret_cast<void*>(b_pointers),
nullptr); nullptr,
alpha);
} }
void CublasGemm::run_batched( void CublasGemm::run_batched(

View File

@@ -11,6 +11,7 @@
#include <numeric> #include <numeric>
namespace mlx::core { namespace mlx::core {
namespace { namespace {
std::tuple<bool, int64_t, array> std::tuple<bool, int64_t, array>
@@ -28,6 +29,76 @@ check_transpose(cu::CommandEncoder& enc, const Stream& s, const array& arr) {
} }
} }
void gemm_and_bias(
cu::CommandEncoder& encoder,
int M,
int N,
int K,
bool a_transposed,
int64_t lda,
bool b_transposed,
int64_t ldb,
array& out,
const array& a,
const array& b,
void* bias = nullptr,
float alpha = 1.0f) {
// Check and collapse batch dimensions
auto [batch_shape, a_batch_strides, b_batch_strides] = collapse_batches(a, b);
auto batch_count = out.size() / (M * N);
// Collapse batches into M if needed
if (batch_count > 1 && !a_transposed && batch_shape.size() == 1 &&
a.strides()[a.ndim() - 2] == K && a_batch_strides.back() == M * K &&
b_batch_strides.back() == 0) {
M *= batch_shape.back();
batch_count = 1;
a_batch_strides = {0};
b_batch_strides = {0};
batch_shape = {1};
}
// Use gemmv when possible
if (!bias && cu::can_use_gemv(M, N, K, a_transposed, b_transposed)) {
cu::gemv(
a,
b,
out,
M,
N,
K,
batch_count,
batch_shape,
a_batch_strides,
b_batch_strides,
encoder);
return;
}
// Invoke cublasLt
CublasGemm gemm(
encoder.device(),
a.dtype(),
a_transposed,
M,
K,
lda,
b_transposed,
K,
N,
ldb,
batch_shape.back(),
a_batch_strides.back(),
b_batch_strides.back());
if (bias) {
gemm.set_bias(bias);
}
gemm.run(
encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides, alpha);
}
} // namespace } // namespace
void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) { void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -48,9 +119,6 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
out.set_data(allocator::malloc(out.nbytes())); out.set_data(allocator::malloc(out.nbytes()));
/////////////////////////////////////////////////////////////////////////////
// Init checks and prep
int M = a_pre.shape(-2); int M = a_pre.shape(-2);
int N = b_pre.shape(-1); int N = b_pre.shape(-1);
int K = a_pre.shape(-1); int K = a_pre.shape(-1);
@@ -60,58 +128,8 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
auto [a_transposed, lda, a] = check_transpose(encoder, s, a_pre); auto [a_transposed, lda, a] = check_transpose(encoder, s, a_pre);
auto [b_transposed, ldb, b] = check_transpose(encoder, s, b_pre); auto [b_transposed, ldb, b] = check_transpose(encoder, s, b_pre);
///////////////////////////////////////////////////////////////////////////// gemm_and_bias(
// Check and collapse batch dimensions encoder, M, N, K, a_transposed, lda, b_transposed, ldb, out, a, b);
auto [batch_shape, a_batch_strides, b_batch_strides] = collapse_batches(a, b);
auto batch_count = out.size() / (M * N);
// Collapse batches into M if needed
if (batch_count > 1 && !a_transposed && batch_shape.size() == 1 &&
a.strides()[a.ndim() - 2] == K && a_batch_strides.back() == M * K &&
b_batch_strides.back() == 0) {
M *= batch_shape.back();
batch_count = 1;
a_batch_strides = {0};
b_batch_strides = {0};
batch_shape = {1};
}
if (cu::can_use_gemv(M, N, K, a_transposed, b_transposed)) {
cu::gemv(
a,
b,
out,
M,
N,
K,
batch_count,
batch_shape,
a_batch_strides,
b_batch_strides,
encoder);
return;
}
/////////////////////////////////////////////////////////////////////////////
// Invoke cublasLt
CublasGemm gemm(
cu::device(s.device),
a.dtype(),
a_transposed,
M,
K,
lda,
b_transposed,
K,
N,
ldb,
batch_shape.back(),
a_batch_strides.back(),
b_batch_strides.back());
gemm.run(encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides);
} }
void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) { void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -136,6 +154,28 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
auto [a_transposed, lda, a] = check_transpose(encoder, s, a_pre); auto [a_transposed, lda, a] = check_transpose(encoder, s, a_pre);
auto [b_transposed, ldb, b] = check_transpose(encoder, s, b_pre); auto [b_transposed, ldb, b] = check_transpose(encoder, s, b_pre);
/////////////////////////////////////////////////////////////////////////////
// Dispatch to GEMM with epilogue or AddMM
if (beta_ == 1 && c.strides(-1) == 1 && c.data_size() == out.shape(-1)) {
out.set_data(allocator::malloc(out.nbytes()));
gemm_and_bias(
encoder,
M,
N,
K,
a_transposed,
lda,
b_transposed,
ldb,
out,
a,
b,
c.data<void>(),
alpha_);
return;
}
int64_t ldc; int64_t ldc;
{ {
auto stx = c.strides()[c.ndim() - 2]; auto stx = c.strides()[c.ndim() - 2];
@@ -177,7 +217,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
} }
///////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////
// Invoke cublasLt // Invoke cublasLt with AddMM settings
CublasGemm gemm( CublasGemm gemm(
cu::device(s.device), cu::device(s.device),

View File

@@ -103,15 +103,21 @@ template <typename T, bool traditional, bool forward, int N = 4>
__device__ void rope_impl( __device__ void rope_impl(
const T* in, const T* in,
T* out, T* out,
int offset, const int* offset,
float inv_freq, float inv_freq,
float scale, float scale,
const cuda::std::array<int64_t, 3> strides, const cuda::std::array<int64_t, 3> strides,
const cuda::std::array<int64_t, 3> out_strides, const cuda::std::array<int64_t, 3> out_strides,
int64_t n_batch, int64_t offset_stride,
int n_head,
uint3 pos, uint3 pos,
uint3 dims) { uint3 dims) {
float L = scale * static_cast<float>(pos.y + offset); auto n_head_up = N * ((n_head + N - 1) / N);
auto head_idx = static_cast<int>((pos.z * N) % n_head_up);
auto batch_idx = (pos.z * N) / n_head_up;
auto batch_offset = offset[batch_idx * offset_stride];
float L = scale * static_cast<float>(pos.y + batch_offset);
auto mat_idx = batch_idx * n_head + head_idx;
// Compute costheta, sintheta // Compute costheta, sintheta
float theta = L * inv_freq; float theta = L * inv_freq;
@@ -123,20 +129,19 @@ __device__ void rope_impl(
size_t out_index_1, out_index_2; size_t out_index_1, out_index_2;
if (traditional) { if (traditional) {
out_index_1 = 2 * pos.x * out_strides[2] + pos.y * out_strides[1] + out_index_1 = 2 * pos.x * out_strides[2] + pos.y * out_strides[1] +
N * pos.z * out_strides[0]; mat_idx * out_strides[0];
out_index_2 = out_index_1 + 1; out_index_2 = out_index_1 + 1;
in_index_1 = in_index_1 =
2 * pos.x * strides[2] + pos.y * strides[1] + N * pos.z * strides[0]; 2 * pos.x * strides[2] + pos.y * strides[1] + mat_idx * strides[0];
in_index_2 = in_index_1 + strides[2]; in_index_2 = in_index_1 + strides[2];
} else { } else {
out_index_1 = pos.x * out_strides[2] + pos.y * out_strides[1] + out_index_1 = pos.x * out_strides[2] + pos.y * out_strides[1] +
N * pos.z * out_strides[0]; mat_idx * out_strides[0];
out_index_2 = out_index_1 + dims.x * out_strides[2]; out_index_2 = out_index_1 + dims.x * out_strides[2];
in_index_1 = in_index_1 = pos.x * strides[2] + pos.y * strides[1] + mat_idx * strides[0];
pos.x * strides[2] + pos.y * strides[1] + N * pos.z * strides[0];
in_index_2 = in_index_1 + dims.x * strides[2]; in_index_2 = in_index_1 + dims.x * strides[2];
} }
for (int i = 0; i < N && pos.z * N + i < n_batch; ++i) { for (int i = 0; i < N && head_idx + i < n_head; ++i) {
// Read and write the output // Read and write the output
float x1 = static_cast<float>(in[in_index_1]); float x1 = static_cast<float>(in[in_index_1]);
float x2 = static_cast<float>(in[in_index_2]); float x2 = static_cast<float>(in[in_index_2]);
@@ -167,7 +172,8 @@ __global__ void rope(
float base, float base,
const __grid_constant__ cuda::std::array<int64_t, 3> strides, const __grid_constant__ cuda::std::array<int64_t, 3> strides,
const __grid_constant__ cuda::std::array<int64_t, 3> out_strides, const __grid_constant__ cuda::std::array<int64_t, 3> out_strides,
int64_t n_batch, int64_t offset_stride,
int n_head,
uint3 dims) { uint3 dims) {
uint3 pos = make_uint3( uint3 pos = make_uint3(
blockIdx.x * blockDim.x + threadIdx.x, blockIdx.x * blockDim.x + threadIdx.x,
@@ -182,12 +188,13 @@ __global__ void rope(
rope_impl<T, traditional, forward>( rope_impl<T, traditional, forward>(
in, in,
out, out,
*offset, offset,
inv_freq, inv_freq,
scale, scale,
strides, strides,
out_strides, out_strides,
n_batch, offset_stride,
n_head,
pos, pos,
dims); dims);
} }
@@ -202,7 +209,8 @@ __global__ void rope_freqs(
float base, float base,
const __grid_constant__ cuda::std::array<int64_t, 3> strides, const __grid_constant__ cuda::std::array<int64_t, 3> strides,
const __grid_constant__ cuda::std::array<int64_t, 3> out_strides, const __grid_constant__ cuda::std::array<int64_t, 3> out_strides,
int64_t n_batch, int64_t offset_stride,
int n_head,
uint3 dims, uint3 dims,
int64_t freq_stride) { int64_t freq_stride) {
uint3 pos = make_uint3( uint3 pos = make_uint3(
@@ -217,12 +225,13 @@ __global__ void rope_freqs(
rope_impl<T, traditional, forward>( rope_impl<T, traditional, forward>(
in, in,
out, out,
*offset, offset,
inv_freq, inv_freq,
scale, scale,
strides, strides,
out_strides, out_strides,
n_batch, offset_stride,
n_head,
pos, pos,
dims); dims);
} }
@@ -245,23 +254,28 @@ void RoPE::eval_gpu(
auto& offset = inputs[1]; auto& offset = inputs[1];
auto& out = outputs[0]; auto& out = outputs[0];
if (in.ndim() < 3) {
throw std::runtime_error("[RoPE] Input must have at least 3 dimensions");
}
cuda::std::array<int64_t, 3> strides; cuda::std::array<int64_t, 3> strides;
cuda::std::array<int64_t, 3> out_strides; cuda::std::array<int64_t, 3> out_strides;
bool donated = false; bool donated = false;
int ndim = in.ndim(); int ndim = in.ndim();
int dispatch_ndim = in.ndim();
int B = in.shape(0);
int T = in.shape(-2);
int D = in.shape(-1);
size_t mat_size = T * D;
int dispatch_ndim = ndim;
while (in.shape(-dispatch_ndim) == 1 && dispatch_ndim > 3) { while (in.shape(-dispatch_ndim) == 1 && dispatch_ndim > 3) {
dispatch_ndim--; dispatch_ndim--;
} }
size_t mat_size = in.shape(-2) * in.shape(-1);
int N = 1;
for (int i = 1; i < (ndim - 2); ++i) {
N *= in.shape(i);
}
// We apply rope to less that the whole vector so copy to output and then // We apply rope to less that the whole vector so copy to output and then
// apply in-place. // apply in-place.
if (dims_ < in.shape(-1)) { if (dims_ < D) {
donated = true; donated = true;
auto ctype = auto ctype =
(in.flags().row_contiguous) ? CopyType::Vector : CopyType::General; (in.flags().row_contiguous) ? CopyType::Vector : CopyType::General;
@@ -302,7 +316,7 @@ void RoPE::eval_gpu(
out_strides[2] = out.strides()[ndim - 1]; out_strides[2] = out.strides()[ndim - 1];
// Some flags to help us dispatch below // Some flags to help us dispatch below
bool single = in.flags().row_contiguous && (mat_size == in.shape(-1)); bool single = in.flags().row_contiguous && B == 1 && T == 1;
bool with_freqs = inputs.size() == 3; bool with_freqs = inputs.size() == 3;
auto& encoder = cu::get_command_encoder(s); auto& encoder = cu::get_command_encoder(s);
@@ -319,7 +333,7 @@ void RoPE::eval_gpu(
if (single && !with_freqs) { if (single && !with_freqs) {
auto kernel = auto kernel =
cu::rope_single<DataType, traditional.value, forward.value>; cu::rope_single<DataType, traditional.value, forward.value>;
uint2 dims = make_uint2(dims_ / 2, in.size() / mat_size); uint2 dims = make_uint2(dims_ / 2, N);
auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1); auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1);
encoder.add_kernel_node( encoder.add_kernel_node(
kernel, kernel,
@@ -336,7 +350,7 @@ void RoPE::eval_gpu(
} else if (single) { } else if (single) {
auto kernel = auto kernel =
cu::rope_single_freqs<DataType, traditional.value, forward.value>; cu::rope_single_freqs<DataType, traditional.value, forward.value>;
uint2 dims = make_uint2(dims_ / 2, in.size() / mat_size); uint2 dims = make_uint2(dims_ / 2, N);
auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1); auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1);
encoder.add_kernel_node( encoder.add_kernel_node(
kernel, kernel,
@@ -354,10 +368,14 @@ void RoPE::eval_gpu(
} else if (with_freqs) { } else if (with_freqs) {
auto kernel = auto kernel =
cu::rope_freqs<DataType, traditional.value, forward.value>; cu::rope_freqs<DataType, traditional.value, forward.value>;
uint3 dims = int n_per_thread = 4;
make_uint3(dims_ / 2, in.shape(-2), in.size() / mat_size); uint32_t dimz = B * ((N + n_per_thread - 1) / n_per_thread);
dims.z = (dims.z + 3) / 4; uint3 dims = make_uint3(dims_ / 2, T, dimz);
auto [grid, block] = get_grid_and_block(dims.x, dims.y, dims.z); auto [grid, block] = get_grid_and_block(dims.x, dims.y, dims.z);
int64_t offset_stride = 0;
if (inputs[1].ndim() > 0) {
offset_stride = inputs[1].strides()[0];
}
encoder.add_kernel_node( encoder.add_kernel_node(
kernel, kernel,
grid, grid,
@@ -371,15 +389,20 @@ void RoPE::eval_gpu(
std::log2(base_), std::log2(base_),
strides, strides,
out_strides, out_strides,
in.size() / mat_size, offset_stride,
N,
dims, dims,
inputs[2].strides(0)); inputs[2].strides(0));
} else { } else {
auto kernel = cu::rope<DataType, traditional.value, forward.value>; auto kernel = cu::rope<DataType, traditional.value, forward.value>;
uint3 dims = int n_per_thread = 4;
make_uint3(dims_ / 2, in.shape(-2), in.size() / mat_size); uint32_t dimz = B * ((N + n_per_thread - 1) / n_per_thread);
dims.z = (dims.z + 3) / 4; uint3 dims = make_uint3(dims_ / 2, T, dimz);
auto [grid, block] = get_grid_and_block(dims.x, dims.y, dims.z); auto [grid, block] = get_grid_and_block(dims.x, dims.y, dims.z);
int64_t offset_stride = 0;
if (inputs[1].ndim() > 0) {
offset_stride = inputs[1].strides()[0];
}
encoder.add_kernel_node( encoder.add_kernel_node(
kernel, kernel,
grid, grid,
@@ -392,7 +415,8 @@ void RoPE::eval_gpu(
std::log2(base_), std::log2(base_),
strides, strides,
out_strides, out_strides,
in.size() / mat_size, offset_stride,
N,
dims); dims);
} }
}); });

View File

@@ -2,7 +2,6 @@
#include <algorithm> #include <algorithm>
#include <cassert> #include <cassert>
#include <numeric> #include <numeric>
#include <sstream>
#include "mlx/backend/gpu/copy.h" #include "mlx/backend/gpu/copy.h"
#include "mlx/backend/metal/device.h" #include "mlx/backend/metal/device.h"
@@ -39,10 +38,11 @@ void explicit_gemm_conv_ND_gpu(
in_unfolded.set_data(allocator::malloc(in_unfolded.nbytes())); in_unfolded.set_data(allocator::malloc(in_unfolded.nbytes()));
// Prepare unfolding kernel // Prepare unfolding kernel
std::ostringstream kname; std::string kname;
kname << "naive_unfold_nd_" << type_to_name(in_unfolded) << "_" << N; kname.reserve(32);
concatenate(kname, "naive_unfold_nd_", type_to_name(in_unfolded), "_", N);
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str()); auto kernel = d.get_kernel(kname);
compute_encoder.set_compute_pipeline_state(kernel); compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(in, 0); compute_encoder.set_input_array(in, 0);
@@ -117,11 +117,12 @@ void explicit_gemm_conv_group_ND_gpu(
in_unfolded.set_data(allocator::malloc(in_unfolded.nbytes())); in_unfolded.set_data(allocator::malloc(in_unfolded.nbytes()));
// Prepare unfolding kernel // Prepare unfolding kernel
std::ostringstream kname; std::string kname;
kname << "naive_unfold_transpose_nd_" << type_to_name(in_unfolded) << "_" kname.reserve(32);
<< N; concatenate(
kname, "naive_unfold_transpose_nd_", type_to_name(in_unfolded), "_", N);
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str()); auto kernel = d.get_kernel(kname);
compute_encoder.set_compute_pipeline_state(kernel); compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(in, 0); compute_encoder.set_input_array(in, 0);
@@ -252,18 +253,32 @@ void implicit_gemm_conv_2D_gpu(
/* const int swizzle_log = */ swizzle_log}; /* const int swizzle_log = */ swizzle_log};
// Determine kernel // Determine kernel
std::ostringstream kname; std::string kname;
kname << "implicit_gemm_conv_2d_" << type_to_name(out) << "_bm" << bm << "_bn" kname.reserve(64);
<< bn << "_bk" << bk << "_wm" << wm << "_wn" << wn << "_channel_" concatenate(
<< (n_channel_specialization ? std::to_string(n_channel_specialization) kname,
: "l") "implicit_gemm_conv_2d_",
<< "_filter_" << (small_filter ? 's' : 'l'); type_to_name(out),
"_bm",
bm,
"_bn",
bn,
"_bk",
bk,
"_wm",
wm,
"_wn",
wn,
"_channel_",
n_channel_specialization ? std::to_string(n_channel_specialization) : "l",
"_filter_",
small_filter ? 's' : 'l');
// Encode and dispatch kernel // Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = get_steel_conv_kernel( auto kernel = get_steel_conv_kernel(
d, d,
kname.str(), kname,
out, out,
bm, bm,
bn, bn,
@@ -559,11 +574,16 @@ void winograd_conv_2D_gpu(
{ {
int bc = 32; int bc = 32;
int bo = 4; int bo = 4;
std::ostringstream kname; std::string kname;
kname << "winograd_conv_2d_weight_transform_" << type_to_name(out) << "_bc" kname.reserve(32);
<< bc; concatenate(
kname,
"winograd_conv_2d_weight_transform_",
type_to_name(out),
"_bc",
bc);
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str()); auto kernel = d.get_kernel(kname);
compute_encoder.set_compute_pipeline_state(kernel); compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(wt, 0); compute_encoder.set_input_array(wt, 0);
@@ -587,11 +607,16 @@ void winograd_conv_2D_gpu(
int bc = 32; int bc = 32;
int wm = 2; int wm = 2;
int wn = 2; int wn = 2;
std::ostringstream kname; std::string kname;
kname << "winograd_conv_2d_input_transform_" << type_to_name(out) << "_bc" kname.reserve(32);
<< bc; concatenate(
kname,
"winograd_conv_2d_input_transform_",
type_to_name(out),
"_bc",
bc);
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str()); auto kernel = d.get_kernel(kname);
compute_encoder.set_compute_pipeline_state(kernel); compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(in_padded, 0); compute_encoder.set_input_array(in_padded, 0);
@@ -634,11 +659,16 @@ void winograd_conv_2D_gpu(
int bc = 32; int bc = 32;
int wm = 2; int wm = 2;
int wn = 2; int wn = 2;
std::ostringstream kname; std::string kname;
kname << "winograd_conv_2d_output_transform_" << type_to_name(out) << "_bo" kname.reserve(32);
<< bc; concatenate(
kname,
"winograd_conv_2d_output_transform_",
type_to_name(out),
"_bo",
bc);
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str()); auto kernel = d.get_kernel(kname);
compute_encoder.set_compute_pipeline_state(kernel); compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(out_wg, 0); compute_encoder.set_input_array(out_wg, 0);
@@ -660,9 +690,9 @@ void depthwise_conv_2D_gpu(
const array& wt, const array& wt,
array out, array out,
const MLXConvParams<2>& conv_params) { const MLXConvParams<2>& conv_params) {
std::ostringstream kname; std::string base_name;
kname << "depthwise_conv_2d_" << type_to_name(out); base_name.reserve(32);
std::string base_name = kname.str(); concatenate(base_name, "depthwise_conv_2d_", type_to_name(out));
const int N = conv_params.N; const int N = conv_params.N;
const int ker_h = conv_params.wS[0]; const int ker_h = conv_params.wS[0];
@@ -685,15 +715,18 @@ void depthwise_conv_2D_gpu(
}; };
// clang-format off // clang-format off
kname << "_ker_h_" << ker_h std::string hash_name;
<< "_ker_w_" << ker_w hash_name.reserve(64);
<< "_str_h_" << str_h concatenate(
<< "_str_w_" << str_w hash_name,
<< "_tgp_h_" << th base_name,
<< "_tgp_w_" << tw "_ker_h_", ker_h,
<< "_do_flip_" << (do_flip ? 't' : 'n'); // clang-format on "_ker_w_", ker_w,
"_str_h_", str_h,
std::string hash_name = kname.str(); "_str_w_", str_w,
"_tgp_h_", th,
"_tgp_w_", tw,
"_do_flip_", do_flip ? 't' : 'n'); // clang-format on
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(base_name, hash_name, func_consts); auto kernel = d.get_kernel(base_name, hash_name, func_consts);
@@ -774,6 +807,56 @@ void dispatch_conv_2D_gpu(
} }
} }
void depthwise_conv_1D_gpu(
const Stream& s,
metal::Device& d,
const array& in,
array wt,
array out) {
bool large = in.size() > INT32_MAX || in.data_size() > INT32_MAX;
std::string base_name;
base_name.reserve(32);
concatenate(
base_name,
"depthwise_conv_1d_",
large ? "_large" : "",
type_to_name(out));
if (!wt.flags().row_contiguous) {
wt = contiguous_copy_gpu(wt, s);
d.add_temporary(wt, s.index);
}
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(base_name);
compute_encoder.set_compute_pipeline_state(kernel);
auto B = in.shape(0);
auto Tout = out.shape(1);
auto D = in.shape(2);
auto K = wt.shape(1);
compute_encoder.set_input_array(in, 0);
compute_encoder.set_input_array(wt, 1);
compute_encoder.set_output_array(out, 2);
if (large) {
int64_t strides[3] = {in.strides(0), in.strides(1), in.strides(2)};
compute_encoder.set_bytes(strides, 3, 3);
} else {
int strides[3] = {
static_cast<int>(in.strides(0)),
static_cast<int>(in.strides(1)),
static_cast<int>(in.strides(2))};
compute_encoder.set_bytes(strides, 3, 3);
}
compute_encoder.set_bytes(K, 4);
auto group_dims = get_block_dims(D, Tout, B);
MTL::Size grid_dims = MTL::Size(D, Tout, B);
compute_encoder.dispatch_threads(grid_dims, group_dims);
}
void conv_1D_gpu( void conv_1D_gpu(
const Stream& s, const Stream& s,
metal::Device& d, metal::Device& d,
@@ -790,8 +873,15 @@ void conv_1D_gpu(
bool is_idil_one = in_dilation[0] == 1; bool is_idil_one = in_dilation[0] == 1;
int C = in.shape(2); int C = in.shape(2);
int O = wt.shape(0); int O = wt.shape(0);
const int C_per_group = in.shape(2) / groups; // Fast path for fully separable 1D convolution
const int O_per_group = wt.shape(0) / groups; if (is_idil_one && (groups == C) && groups == O && wt_strides[0] == 1 &&
wt_dilation[0] == 1 && padding[0] == 0 && !flip) {
depthwise_conv_1D_gpu(s, d, in, wt, out);
return;
}
const int C_per_group = C / groups;
const int O_per_group = O / groups;
// Direct to implicit gemm conv // Direct to implicit gemm conv
if (is_idil_one && (C_per_group <= 4 || C_per_group % 16 == 0) && if (is_idil_one && (C_per_group <= 4 || C_per_group % 16 == 0) &&

View File

@@ -288,6 +288,40 @@ instantiate_depthconv2d(float32, float);
instantiate_depthconv2d(float16, half); instantiate_depthconv2d(float16, half);
instantiate_depthconv2d(bfloat16, bfloat16_t); instantiate_depthconv2d(bfloat16, bfloat16_t);
template <typename T, typename IdxT>
[[kernel]] void depthwise_conv_1d(
const device T* in [[buffer(0)]],
const device T* w [[buffer(1)]],
device T* out [[buffer(2)]],
constant const IdxT strides[3],
constant const int& kernel_size,
uint3 tid [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
out += (tid.z * static_cast<IdxT>(grid_dim.y) + tid.y) * grid_dim.x + tid.x;
in += tid.z * strides[0] + tid.y * strides[1] + tid.x * strides[2];
w += tid.x * kernel_size;
float acc = 0.0;
for (int i = 0; i < kernel_size; ++i) {
acc += static_cast<float>(in[0]) * w[i];
in += strides[1];
}
*out = static_cast<T>(acc);
}
#define instantiate_depthconv1d(iname, itype) \
instantiate_kernel( \
"depthwise_conv_1d_" #iname, depthwise_conv_1d, itype, int32_t) \
instantiate_kernel( \
"depthwise_conv_1d_" #iname "_large", \
depthwise_conv_1d, \
itype, \
int64_t)
instantiate_depthconv1d(float32, float);
instantiate_depthconv1d(float16, half);
instantiate_depthconv1d(bfloat16, bfloat16_t);
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
/// Winograd kernels /// Winograd kernels
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////

View File

@@ -10,7 +10,7 @@ void rope_single_impl(
constant const int& offset, constant const int& offset,
const float inv_freq, const float inv_freq,
constant const float& scale, constant const float& scale,
constant const size_t& stride, constant const int64_t& stride,
uint2 pos, uint2 pos,
uint2 grid) { uint2 grid) {
float L = scale * static_cast<float>(offset); float L = scale * static_cast<float>(offset);
@@ -52,7 +52,7 @@ template <typename T, bool traditional, bool forward>
device T* out [[buffer(1)]], device T* out [[buffer(1)]],
constant const int& offset, constant const int& offset,
constant const float& scale, constant const float& scale,
constant const size_t& stride, constant const int64_t& stride,
constant const float& base [[buffer(10)]], constant const float& base [[buffer(10)]],
uint2 pos [[thread_position_in_grid]], uint2 pos [[thread_position_in_grid]],
uint2 grid [[threads_per_grid]]) { uint2 grid [[threads_per_grid]]) {
@@ -68,9 +68,9 @@ template <typename T, bool traditional, bool forward>
device T* out [[buffer(1)]], device T* out [[buffer(1)]],
constant const int& offset, constant const int& offset,
constant const float& scale, constant const float& scale,
constant const size_t& stride, constant const int64_t& stride,
const device float* freqs [[buffer(10)]], const device float* freqs [[buffer(10)]],
constant const size_t& freq_stride [[buffer(11)]], constant const int64_t& freq_stride [[buffer(11)]],
uint2 pos [[thread_position_in_grid]], uint2 pos [[thread_position_in_grid]],
uint2 grid [[threads_per_grid]]) { uint2 grid [[threads_per_grid]]) {
float inv_freq = 1.0 / (freqs[freq_stride * pos.x]); float inv_freq = 1.0 / (freqs[freq_stride * pos.x]);
@@ -82,15 +82,21 @@ template <typename T, bool traditional, bool forward, int N = 4>
void rope_impl( void rope_impl(
const device T* in, const device T* in,
device T* out, device T* out,
constant const int& offset, const device int* offset,
const float inv_freq, const float inv_freq,
constant const float& scale, constant const float& scale,
constant const size_t strides[3], constant const int64_t strides[3],
constant const size_t out_strides[3], constant const int64_t out_strides[3],
constant const size_t& n_batch, constant const int64_t& offset_stride,
constant const int& n_head,
uint3 pos, uint3 pos,
uint3 grid) { uint3 grid) {
float L = scale * static_cast<float>(pos.y + offset); auto n_head_up = N * ((n_head + N - 1) / N);
auto head_idx = static_cast<int>((pos.z * N) % n_head_up);
auto batch_idx = (pos.z * N) / n_head_up;
auto batch_offset = offset[batch_idx * offset_stride];
float L = scale * static_cast<float>(pos.y + batch_offset);
auto mat_idx = batch_idx * n_head + head_idx;
// Compute costheta, sintheta // Compute costheta, sintheta
float theta = L * inv_freq; float theta = L * inv_freq;
@@ -102,20 +108,19 @@ void rope_impl(
size_t out_index_1, out_index_2; size_t out_index_1, out_index_2;
if (traditional) { if (traditional) {
out_index_1 = 2 * pos.x * out_strides[2] + pos.y * out_strides[1] + out_index_1 = 2 * pos.x * out_strides[2] + pos.y * out_strides[1] +
N * pos.z * out_strides[0]; mat_idx * out_strides[0];
out_index_2 = out_index_1 + 1; out_index_2 = out_index_1 + 1;
in_index_1 = in_index_1 =
2 * pos.x * strides[2] + pos.y * strides[1] + N * pos.z * strides[0]; 2 * pos.x * strides[2] + pos.y * strides[1] + mat_idx * strides[0];
in_index_2 = in_index_1 + strides[2]; in_index_2 = in_index_1 + strides[2];
} else { } else {
out_index_1 = pos.x * out_strides[2] + pos.y * out_strides[1] + out_index_1 = pos.x * out_strides[2] + pos.y * out_strides[1] +
N * pos.z * out_strides[0]; mat_idx * out_strides[0];
out_index_2 = out_index_1 + grid.x * out_strides[2]; out_index_2 = out_index_1 + grid.x * out_strides[2];
in_index_1 = in_index_1 = pos.x * strides[2] + pos.y * strides[1] + mat_idx * strides[0];
pos.x * strides[2] + pos.y * strides[1] + N * pos.z * strides[0];
in_index_2 = in_index_1 + grid.x * strides[2]; in_index_2 = in_index_1 + grid.x * strides[2];
} }
for (int i = 0; i < N && pos.z * N + i < n_batch; ++i) { for (int i = 0; i < N && head_idx + i < n_head; ++i) {
// Read and write the output // Read and write the output
float x1 = static_cast<float>(in[in_index_1]); float x1 = static_cast<float>(in[in_index_1]);
float x2 = static_cast<float>(in[in_index_2]); float x2 = static_cast<float>(in[in_index_2]);
@@ -141,11 +146,12 @@ template <typename T, bool traditional, bool forward, int N = 4>
[[kernel]] void rope( [[kernel]] void rope(
const device T* in [[buffer(0)]], const device T* in [[buffer(0)]],
device T* out [[buffer(1)]], device T* out [[buffer(1)]],
constant const int& offset, const device int* offset,
constant const float& scale, constant const float& scale,
constant const size_t strides[3], constant const int64_t strides[3],
constant const size_t out_strides[3], constant const int64_t out_strides[3],
constant const size_t& n_batch, constant const int64_t& offset_stride,
constant const int& n_head,
constant const float& base [[buffer(10)]], constant const float& base [[buffer(10)]],
uint3 pos [[thread_position_in_grid]], uint3 pos [[thread_position_in_grid]],
uint3 grid [[threads_per_grid]]) { uint3 grid [[threads_per_grid]]) {
@@ -159,7 +165,8 @@ template <typename T, bool traditional, bool forward, int N = 4>
scale, scale,
strides, strides,
out_strides, out_strides,
n_batch, offset_stride,
n_head,
pos, pos,
grid); grid);
} }
@@ -168,13 +175,14 @@ template <typename T, bool traditional, bool forward, int N = 4>
[[kernel]] void rope_freqs( [[kernel]] void rope_freqs(
const device T* in [[buffer(0)]], const device T* in [[buffer(0)]],
device T* out [[buffer(1)]], device T* out [[buffer(1)]],
constant const int& offset, const device int* offset,
constant const float& scale, constant const float& scale,
constant const size_t strides[3], constant const int64_t strides[3],
constant const size_t out_strides[3], constant const int64_t out_strides[3],
constant const size_t& n_batch, constant const int64_t& offset_stride,
constant const int& n_head,
const device float* freqs [[buffer(10)]], const device float* freqs [[buffer(10)]],
constant const size_t& freq_stride [[buffer(11)]], constant const int64_t& freq_stride [[buffer(11)]],
uint3 pos [[thread_position_in_grid]], uint3 pos [[thread_position_in_grid]],
uint3 grid [[threads_per_grid]]) { uint3 grid [[threads_per_grid]]) {
float inv_freq = 1.0 / (freqs[freq_stride * pos.x]); float inv_freq = 1.0 / (freqs[freq_stride * pos.x]);
@@ -186,61 +194,20 @@ template <typename T, bool traditional, bool forward, int N = 4>
scale, scale,
strides, strides,
out_strides, out_strides,
n_batch, offset_stride,
n_head,
pos, pos,
grid); grid);
} }
// clang-format off // clang-format off
#define instantiate_rope_g(name, type, traditional, forward) \ #define instantiate_rope_g(name, type, traditional, forward) \
template [[host_name("rope_" #name)]] [[kernel]] void \ instantiate_kernel("rope_" #name, rope, type, traditional, forward) \
rope<type, traditional, forward>( \ instantiate_kernel("rope_freqs_" #name, rope_freqs, type, traditional, forward)
const device type* in [[buffer(0)]], \
device type* out [[buffer(1)]], \
constant const int& offset, \
constant const float& scale, \
constant const size_t strides[3], \
constant const size_t out_strides[3], \
constant const size_t& n_batch, \
constant const float& base [[buffer(10)]], \
uint3 pos [[thread_position_in_grid]], \
uint3 grid [[threads_per_grid]]); \
template [[host_name("rope_freqs_" #name)]] \
[[kernel]] void rope_freqs<type, traditional, forward>( \
const device type* in [[buffer(0)]], \
device type* out [[buffer(1)]], \
constant const int& offset, \
constant const float& scale, \
constant const size_t strides[3], \
constant const size_t out_strides[3], \
constant const size_t& n_batch, \
const device float* freqs [[buffer(10)]], \
constant const size_t& freq_stride [[buffer(11)]], \
uint3 pos [[thread_position_in_grid]], \
uint3 grid [[threads_per_grid]]);
#define instantiate_rope_s(name, type, traditional, forward) \ #define instantiate_rope_s(name, type, traditional, forward) \
template [[host_name("rope_single_" #name)]] [[kernel]] void \ instantiate_kernel("rope_single_" #name, rope_single, type, traditional, forward) \
rope_single<type, traditional, forward>( \ instantiate_kernel("rope_single_freqs_" #name, rope_single_freqs, type, traditional, forward)
const device type* in [[buffer(0)]], \
device type* out [[buffer(1)]], \
constant const int& offset, \
constant const float& scale, \
constant const size_t& stride, \
constant const float& base [[buffer(10)]], \
uint2 pos [[thread_position_in_grid]], \
uint2 grid [[threads_per_grid]]); \
template [[host_name("rope_single_freqs_" #name)]] \
[[kernel]] void rope_single_freqs<type, traditional, forward>( \
const device type* in [[buffer(0)]], \
device type* out [[buffer(1)]], \
constant const int& offset, \
constant const float& scale, \
constant const size_t& stride, \
const device float* freqs [[buffer(10)]], \
constant const size_t& freq_stride [[buffer(11)]], \
uint2 pos [[thread_position_in_grid]], \
uint2 grid [[threads_per_grid]]);
#define instantiate_rope(name, type, traditional, forward) \ #define instantiate_rope(name, type, traditional, forward) \
instantiate_rope_s(name, type, traditional, forward) \ instantiate_rope_s(name, type, traditional, forward) \

View File

@@ -18,23 +18,29 @@ void RoPE::eval_gpu(
auto& in = inputs[0]; auto& in = inputs[0];
auto& out = outputs[0]; auto& out = outputs[0];
if (in.ndim() < 3) {
throw std::runtime_error("[RoPE] Input must have at least 3 dimensions");
}
auto& s = out.primitive().stream(); auto& s = out.primitive().stream();
auto& d = metal::device(s.device); auto& d = metal::device(s.device);
size_t strides[3]; int64_t strides[3];
size_t out_strides[3]; int64_t out_strides[3];
bool donated = false; bool donated = false;
int ndim = in.ndim(); int ndim = in.ndim();
int dispatch_ndim = in.ndim(); int B = in.shape(0);
int T = in.shape(-2);
int D = in.shape(-1);
size_t mat_size = T * D;
int dispatch_ndim = ndim;
while (in.shape(-dispatch_ndim) == 1 && dispatch_ndim > 3) { while (in.shape(-dispatch_ndim) == 1 && dispatch_ndim > 3) {
dispatch_ndim--; dispatch_ndim--;
} }
size_t mat_size = in.shape(-2) * in.shape(-1);
if (dims_ < in.shape(-1)) { int N = 1;
for (int i = 1; i < (ndim - 2); ++i) {
N *= in.shape(i);
}
if (dims_ < D) {
donated = true; donated = true;
auto ctype = auto ctype =
(in.flags().row_contiguous) ? CopyType::Vector : CopyType::General; (in.flags().row_contiguous) ? CopyType::Vector : CopyType::General;
@@ -71,8 +77,8 @@ void RoPE::eval_gpu(
out_strides[1] = out.strides()[ndim - 2]; out_strides[1] = out.strides()[ndim - 2];
out_strides[2] = out.strides()[ndim - 1]; out_strides[2] = out.strides()[ndim - 1];
// Special case for inference (single time step and contiguous) // Special case for inference (single batch, single time step, and contiguous)
bool single = in.flags().row_contiguous && (mat_size == in.shape(-1)); bool single = in.flags().row_contiguous && B == 1 && T == 1;
bool with_freqs = inputs.size() == 3; bool with_freqs = inputs.size() == 3;
std::ostringstream kname; std::ostringstream kname;
@@ -86,24 +92,29 @@ void RoPE::eval_gpu(
compute_encoder.set_compute_pipeline_state(kernel); compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(donated ? out : in, 0); compute_encoder.set_input_array(donated ? out : in, 0);
compute_encoder.set_output_array(out, 1); compute_encoder.set_output_array(out, 1);
compute_encoder.set_input_array(inputs[1], 2); compute_encoder.set_input_array(inputs[1], 2);
compute_encoder.set_bytes(scale_, 3); compute_encoder.set_bytes(scale_, 3);
size_t n_batch = in.size() / mat_size;
MTL::Size group_dims; MTL::Size group_dims;
MTL::Size grid_dims; MTL::Size grid_dims;
if (single) { if (single) {
compute_encoder.set_bytes(out_strides, 1, 4); compute_encoder.set_bytes(out_strides, 1, 4);
uint32_t dim0 = dims_ / 2; uint32_t dim0 = dims_ / 2;
group_dims = get_block_dims(dim0, n_batch, 1); group_dims = get_block_dims(dim0, N, 1);
grid_dims = MTL::Size(dim0, n_batch, 1); grid_dims = MTL::Size(dim0, N, 1);
} else { } else {
compute_encoder.set_bytes(strides, 3, 4); compute_encoder.set_bytes(strides, 3, 4);
compute_encoder.set_bytes(out_strides, 3, 5); compute_encoder.set_bytes(out_strides, 3, 5);
compute_encoder.set_bytes(n_batch, 6); int64_t offset_stride = 0;
if (inputs[1].ndim() > 0) {
offset_stride = inputs[1].strides()[0];
}
compute_encoder.set_bytes(offset_stride, 6);
compute_encoder.set_bytes(N, 7);
uint32_t dim0 = dims_ / 2; uint32_t dim0 = dims_ / 2;
uint32_t dim1 = in.shape(-2); uint32_t dim1 = T;
uint32_t dim2 = (n_batch + n_per_thread - 1) / n_per_thread; uint32_t dim2 = B * ((N + n_per_thread - 1) / n_per_thread);
group_dims = get_block_dims(dim0, dim1, dim2); group_dims = get_block_dims(dim0, dim1, dim2);
grid_dims = MTL::Size(dim0, dim1, dim2); grid_dims = MTL::Size(dim0, dim1, dim2);
} }

View File

@@ -366,10 +366,16 @@ array rope(
msg << "[rope] Input must be a floating type but got " << x.dtype() << "."; msg << "[rope] Input must be a floating type but got " << x.dtype() << ".";
throw std::invalid_argument(msg.str()); throw std::invalid_argument(msg.str());
} }
if (offset.size() != 1) { if (offset.ndim() > 1) {
std::ostringstream msg; std::ostringstream msg;
msg << "[rope] offset must be a scalar but has shape " << offset.shape() msg << "[rope] offset must have at most one dimension but has shape "
<< "."; << offset.shape() << ".";
throw std::invalid_argument(msg.str());
}
if (offset.size() != 1 && offset.size() != x.shape(0)) {
std::ostringstream msg;
msg << "[rope] offset must be a scalar or vector with " << x.shape(0)
<< " elements but has shape " << offset.shape() << ".";
throw std::invalid_argument(msg.str()); throw std::invalid_argument(msg.str());
} }
if (!issubdtype(offset.dtype(), integer)) { if (!issubdtype(offset.dtype(), integer)) {
@@ -379,7 +385,7 @@ array rope(
throw std::invalid_argument(msg.str()); throw std::invalid_argument(msg.str());
} }
if (offset.dtype().size() != 4) { if (offset.dtype().size() != 4) {
inputs[1] = astype(offset, uint32, s); inputs[1] = astype(offset, int32, s);
} }
if (inputs.size() == 3 && if (inputs.size() == 3 &&
(inputs[2].ndim() != 1 || inputs[2].shape(0) != dims / 2)) { (inputs[2].ndim() != 1 || inputs[2].shape(0) != dims / 2)) {
@@ -391,15 +397,26 @@ array rope(
auto fallback = [dims, traditional, base, scale, forward, s]( auto fallback = [dims, traditional, base, scale, forward, s](
std::vector<array> inputs) { std::vector<array> inputs) {
auto& shape = inputs[0].shape(); auto x = inputs[0];
int ndim = shape.size(); auto shape = x.shape();
auto x = flatten(inputs[0], 0, ndim - 3, s); if (x.ndim() == 3) {
x = expand_dims(x, 1, s);
} else if (x.ndim() > 4) {
x = flatten(x, 1, 1 + (x.ndim() - 4), s);
}
auto B = x.shape(0);
auto N = x.shape(1);
auto T = x.shape(2);
auto t = x.dtype(); auto t = x.dtype();
// Compute sines and cosines // Compute sines and cosines
auto half_dims = dims / 2; auto half_dims = dims / 2;
auto& offset = inputs[1]; auto offset = inputs[1];
if (offset.size() > 1) {
offset = expand_dims(offset, {-1, -2}, s);
}
auto positions = auto positions =
multiply(add(arange(x.shape(1), t, s), offset, s), array(scale, t), s); multiply(add(arange(x.shape(2), t, s), offset, s), array(scale, t), s);
auto default_inv_freqs = [&inputs, &s, &t, base, half_dims]() { auto default_inv_freqs = [&inputs, &s, &t, base, half_dims]() {
return exp( return exp(
@@ -412,8 +429,7 @@ array rope(
auto inv_freqs = inputs.size() == 3 ? astype(reciprocal(inputs[2], s), t, s) auto inv_freqs = inputs.size() == 3 ? astype(reciprocal(inputs[2], s), t, s)
: default_inv_freqs(); : default_inv_freqs();
auto theta = auto theta = multiply(expand_dims(positions, -1, s), inv_freqs, s);
multiply(expand_dims(positions, 1, s), expand_dims(inv_freqs, 0, s), s);
auto coss = cos(theta, s); auto coss = cos(theta, s);
auto sins = sin(theta, s); auto sins = sin(theta, s);
@@ -436,32 +452,30 @@ array rope(
}; };
if (traditional) { if (traditional) {
auto x1 = auto x1 = slice(x, {0, 0, 0, 0}, {B, N, T, dims}, {1, 1, 1, 2}, s);
slice(x, {0, 0, 0}, {x.shape(0), x.shape(1), dims}, {1, 1, 2}, s); auto x2 = slice(x, {0, 0, 0, 1}, {B, N, T, dims}, {1, 1, 1, 2}, s);
auto x2 =
slice(x, {0, 0, 1}, {x.shape(0), x.shape(1), dims}, {1, 1, 2}, s);
auto outs = apply_rope(x1, x2, coss, sins); auto outs = apply_rope(x1, x2, coss, sins);
for (auto& o : outs) { for (auto& o : outs) {
o = expand_dims(o, 3, s); o = expand_dims(o, -1, s);
} }
auto out = concatenate(outs, 3, s); auto out = reshape(concatenate(outs, -1, s), {B, N, T, dims}, s);
if (dims < x.shape(-1)) { if (dims < x.shape(-1)) {
out = reshape(out, {x.shape(0), x.shape(1), dims}); out =
out = concatenate({out, slice(x, {0, 0, dims}, x.shape(), s)}, 2, s); concatenate({out, slice(x, {0, 0, 0, dims}, x.shape(), s)}, -1, s);
} }
return std::vector<array>{reshape(out, shape, s)}; return std::vector<array>{reshape(out, shape, s)};
} else { } else {
auto out_s = x.shape(); auto out_s = x.shape();
out_s.back() = half_dims; out_s.back() = half_dims;
auto x1 = slice(x, {0, 0, 0}, out_s, s); auto x1 = slice(x, {0, 0, 0, 0}, out_s, s);
out_s.back() = dims; out_s.back() = dims;
auto x2 = slice(x, {0, 0, half_dims}, out_s, s); auto x2 = slice(x, {0, 0, 0, half_dims}, out_s, s);
auto outs = apply_rope(x1, x2, coss, sins); auto outs = apply_rope(x1, x2, coss, sins);
if (dims < x.shape(-1)) { if (dims < x.shape(-1)) {
outs.push_back(slice(x, {0, 0, dims}, x.shape(), s)); outs.push_back(slice(x, {0, 0, 0, dims}, x.shape(), s));
} }
return std::vector<array>{reshape(concatenate(outs, 2, s), shape, s)}; return std::vector<array>{reshape(concatenate(outs, -1, s), shape, s)};
} }
}; };
auto stream = to_stream(s); auto stream = to_stream(s);

View File

@@ -1,10 +1,8 @@
// Copyright © 2025 Apple Inc. // Copyright © 2025 Apple Inc.
#include <string>
namespace mlx::core { namespace mlx::core {
std::string version() { const char* version() {
return MLX_VERSION; return MLX_VERSION;
} }

View File

@@ -15,6 +15,6 @@ namespace mlx::core {
* *
* For dev builds, the version will include the suffix ".devYYYYMMDD+hash" * For dev builds, the version will include the suffix ".devYYYYMMDD+hash"
*/ */
std::string version(); const char* version();
} // namespace mlx::core } // namespace mlx::core

View File

@@ -52,7 +52,6 @@ set_target_properties(
${MLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY}) ${MLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY})
target_link_libraries(core PRIVATE mlx) target_link_libraries(core PRIVATE mlx)
target_compile_definitions(core PRIVATE _VERSION_=${MLX_VERSION})
if(BUILD_SHARED_LIBS) if(BUILD_SHARED_LIBS)
if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin") if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")

View File

@@ -164,8 +164,13 @@ void init_fast(nb::module_& parent_module) {
R"pbdoc( R"pbdoc(
Apply rotary positional encoding to the input. Apply rotary positional encoding to the input.
The input is expected to be at least 3D with shape ``(B, *, T, D)`` where:
* ``B`` is the batch size.
* ``T`` is the sequence length.
* ``D`` is the feature dimension.
Args: Args:
a (array): Input array. a (array): The input array.
dims (int): The feature dimensions to be rotated. If the input feature dims (int): The feature dimensions to be rotated. If the input feature
is larger than dims then the rest is left unchanged. is larger than dims then the rest is left unchanged.
traditional (bool): If set to ``True`` choose the traditional traditional (bool): If set to ``True`` choose the traditional
@@ -174,7 +179,9 @@ void init_fast(nb::module_& parent_module) {
each dimension in the positional encodings. Exactly one of ``base`` and each dimension in the positional encodings. Exactly one of ``base`` and
``freqs`` must be ``None``. ``freqs`` must be ``None``.
scale (float): The scale used to scale the positions. scale (float): The scale used to scale the positions.
offset (int or array): The position offset to start at. offset (int or array): The position offset to start at. If an
:obj:`array` is given it can be a scalar or vector of ``B``
offsets for each example in the batch.
freqs (array, optional): Optional frequencies to use with RoPE. freqs (array, optional): Optional frequencies to use with RoPE.
If set, the ``base`` parameter must be ``None``. Default: ``None``. If set, the ``base`` parameter must be ``None``. Default: ``None``.

View File

@@ -2,9 +2,9 @@
#include <nanobind/nanobind.h> #include <nanobind/nanobind.h>
#define STRINGIFY(x) #x #include "mlx/version.h"
#define TOSTRING(x) STRINGIFY(x)
namespace mx = mlx::core;
namespace nb = nanobind; namespace nb = nanobind;
void init_mlx_func(nb::module_&); void init_mlx_func(nb::module_&);
@@ -48,5 +48,5 @@ NB_MODULE(core, m) {
init_distributed(m); init_distributed(m);
init_export(m); init_export(m);
m.attr("__version__") = TOSTRING(_VERSION_); m.attr("__version__") = mx::version();
} }

View File

@@ -91,7 +91,7 @@ mx::array to_array_with_accessor(nb::object obj) {
return nb::cast<mx::array>(obj.attr("__mlx_array__")()); return nb::cast<mx::array>(obj.attr("__mlx_array__")());
} else { } else {
std::ostringstream msg; std::ostringstream msg;
msg << "Invalid type " << nb::type_name(obj.type()).c_str() msg << "Invalid type " << nb::type_name(obj.type()).c_str()
<< " received in array initialization."; << " received in array initialization.";
throw std::invalid_argument(msg.str()); throw std::invalid_argument(msg.str());
} }

View File

@@ -594,124 +594,123 @@ class TestBlas(mlx_tests.MLXTestCase):
np.random.seed(0) np.random.seed(0)
# Batched matmul # Batched matmul
alpha = 0.5 alpha = 0.5
beta = 2.0 for beta in (1.0, 2.0):
# c must broadcast to the output shape
with self.assertRaises(ValueError):
mx.addmm(mx.zeros((2, 2, 2)), mx.zeros((2, 2)), mx.zeros((2, 2)))
# c must broadcast to the output shape # Regular batched case
with self.assertRaises(ValueError): a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
mx.addmm(mx.zeros((2, 2, 2)), mx.zeros((2, 2)), mx.zeros((2, 2))) b_npy = np.random.normal(0.0, 1.0 / 128, (32, 16, 16)).astype(np.float32)
# Regular batched case a_mlx = mx.array(a_npy)
a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32) b_mlx = mx.array(b_npy)
b_npy = np.random.normal(0.0, 1.0 / 128, (32, 16, 16)).astype(np.float32)
a_mlx = mx.array(a_npy) for c_shape in ((1,), (1, 16), (32, 1, 16), (1, 128, 16)):
b_mlx = mx.array(b_npy) c_npy = np.ones(c_shape).astype(np.float32)
c_mlx = mx.array(c_npy)
for c_shape in ((1,), (1, 16), (32, 1, 16), (1, 128, 16)): d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
c_npy = np.ones(c_shape).astype(np.float32) d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)
c_mlx = mx.array(c_npy)
d_npy = alpha * (a_npy @ b_npy) + beta * c_npy self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta) self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
self.assertListEqual(list(d_npy.shape), list(d_mlx.shape)) # Batched and transposed matmul
self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5)) b_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
b_mlx = mx.array(b_npy)
# Batched and transposed matmul for c_shape in ((1,), (32, 1, 128), (1, 128)):
b_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32) c_npy = np.ones(c_shape).astype(np.float32)
b_mlx = mx.array(b_npy) c_mlx = mx.array(c_npy)
for c_shape in ((1,), (32, 1, 128), (1, 128)): b_np_t = np.transpose(b_npy, (0, 2, 1))
c_npy = np.ones(c_shape).astype(np.float32) b_mx_t = mx.transpose(b_mlx, (0, 2, 1))
c_mlx = mx.array(c_npy)
b_np_t = np.transpose(b_npy, (0, 2, 1)) d_npy = alpha * (a_npy @ b_np_t) + beta * c_npy
b_mx_t = mx.transpose(b_mlx, (0, 2, 1)) d_mlx = mx.addmm(c_mlx, a_mlx, b_mx_t, alpha, beta)
d_npy = alpha * (a_npy @ b_np_t) + beta * c_npy self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
d_mlx = mx.addmm(c_mlx, a_mlx, b_mx_t, alpha, beta) self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
# Batched matmul with simple broadcast
a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
b_npy = np.random.normal(0.0, 1.0 / 128, (16, 16)).astype(np.float32)
self.assertListEqual(list(d_npy.shape), list(d_mlx.shape)) a_mlx = mx.array(a_npy)
self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5)) b_mlx = mx.array(b_npy)
# Batched matmul with simple broadcast
a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
b_npy = np.random.normal(0.0, 1.0 / 128, (16, 16)).astype(np.float32)
a_mlx = mx.array(a_npy) for c_shape in ((1,), (1, 16), (32, 1, 16), (1, 128, 16)):
b_mlx = mx.array(b_npy) c_npy = np.ones(c_shape).astype(np.float32)
c_mlx = mx.array(c_npy)
for c_shape in ((1,), (1, 16), (32, 1, 16), (1, 128, 16)): d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
c_npy = np.ones(c_shape).astype(np.float32) d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)
c_mlx = mx.array(c_npy)
d_npy = alpha * (a_npy @ b_npy) + beta * c_npy self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta) self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
# Matmul with vector
a_npy = np.random.normal(0.0, 1.0 / 128, (16,)).astype(np.float32)
b_npy = np.random.normal(0.0, 1.0 / 128, (32, 16, 128)).astype(np.float32)
a_mlx = mx.array(a_npy)
b_mlx = mx.array(b_npy)
self.assertListEqual(list(d_npy.shape), list(d_mlx.shape)) for c_shape in ((1,), (128,), (32, 128)):
self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5)) c_npy = np.ones(c_shape).astype(np.float32)
# Matmul with vector c_mlx = mx.array(c_npy)
a_npy = np.random.normal(0.0, 1.0 / 128, (16,)).astype(np.float32)
b_npy = np.random.normal(0.0, 1.0 / 128, (32, 16, 128)).astype(np.float32)
a_mlx = mx.array(a_npy)
b_mlx = mx.array(b_npy)
for c_shape in ((1,), (128,), (32, 128)): d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
c_npy = np.ones(c_shape).astype(np.float32) d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)
c_mlx = mx.array(c_npy)
d_npy = alpha * (a_npy @ b_npy) + beta * c_npy self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta) self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
self.assertListEqual(list(d_npy.shape), list(d_mlx.shape)) # Matmul with vector
self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5)) a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
b_npy = np.random.normal(0.0, 1.0 / 128, (16,)).astype(np.float32)
a_mlx = mx.array(a_npy)
b_mlx = mx.array(b_npy)
# Matmul with vector for c_shape in ((1,), (32, 128)):
a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32) c_npy = np.ones(c_shape).astype(np.float32)
b_npy = np.random.normal(0.0, 1.0 / 128, (16,)).astype(np.float32) c_mlx = mx.array(c_npy)
a_mlx = mx.array(a_npy)
b_mlx = mx.array(b_npy)
for c_shape in ((1,), (32, 128)): d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
c_npy = np.ones(c_shape).astype(np.float32) d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)
c_mlx = mx.array(c_npy)
d_npy = alpha * (a_npy @ b_npy) + beta * c_npy self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta) self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
self.assertListEqual(list(d_npy.shape), list(d_mlx.shape)) # Split K specializtion
self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5)) a_npy = np.random.normal(0.0, 1.0 / 128, (64, 4096)).astype(np.float32)
b_npy = np.random.normal(0.0, 1.0 / 128, (4096, 32)).astype(np.float32)
# Split K specializtion a_mlx = mx.array(a_npy)
a_npy = np.random.normal(0.0, 1.0 / 128, (64, 4096)).astype(np.float32) b_mlx = mx.array(b_npy)
b_npy = np.random.normal(0.0, 1.0 / 128, (4096, 32)).astype(np.float32)
a_mlx = mx.array(a_npy) for c_shape in ((1,), (1, 32), (64, 1), (64, 32)):
b_mlx = mx.array(b_npy) c_npy = np.ones(c_shape).astype(np.float32)
c_mlx = mx.array(c_npy)
for c_shape in ((1,), (1, 32), (64, 1), (64, 32)): d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
c_npy = np.ones(c_shape).astype(np.float32) d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)
c_mlx = mx.array(c_npy)
d_npy = alpha * (a_npy @ b_npy) + beta * c_npy self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta) self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
self.assertListEqual(list(d_npy.shape), list(d_mlx.shape)) # Transposed c
self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5)) a = mx.ones((10, 5)).T
b = mx.ones((5, 5))
out = mx.addmm(a, b, a, beta=beta, alpha=alpha)
expected = beta * a + alpha * (b @ a)
self.assertTrue(mx.allclose(expected, out))
# Transposed c # Broadcast c
a = mx.ones((10, 5)).T a = mx.ones((5, 5))
b = mx.ones((5, 5)) b = mx.ones((5, 5))
out = mx.addmm(a, b, a, beta=1.5, alpha=0.5) c = mx.ones((1, 5))
expected = 1.5 * a + 0.5 * (b @ a) out = mx.addmm(c, a, b, beta=beta, alpha=alpha)
self.assertTrue(mx.allclose(expected, out)) expected = beta * c + alpha * (a @ b)
self.assertTrue(mx.allclose(expected, out))
# Broadcast c
a = mx.ones((5, 5))
b = mx.ones((5, 5))
c = mx.ones((1, 5))
out = mx.addmm(c, a, b, beta=1.5, alpha=0.5)
expected = 1.5 * c + 0.5 * (a @ b)
self.assertTrue(mx.allclose(expected, out))
def test_addmm_grad(self): def test_addmm_grad(self):
def make_ref_addmm(alpha, beta): def make_ref_addmm(alpha, beta):
@@ -724,33 +723,32 @@ class TestBlas(mlx_tests.MLXTestCase):
shapes = ((1, 64, 32, 128), (4, 28, 24, 47), (1, 1, 24, 47)) shapes = ((1, 64, 32, 128), (4, 28, 24, 47), (1, 1, 24, 47))
alpha = 2.0 alpha = 2.0
beta = 0.5 for beta in (1.0, 0.5):
f_test = make_addmm(alpha, beta)
f_ref = make_ref_addmm(alpha, beta)
f_test = make_addmm(alpha, beta) for B, M, N, K in shapes:
f_ref = make_ref_addmm(alpha, beta) cotan = mx.ones((B, M, N))
c = mx.random.normal((B, M, N))
a = mx.random.normal((B, M, K))
b = mx.random.normal((B, K, N))
for B, M, N, K in shapes: out_ref, dout_ref = mx.vjp(
cotan = mx.ones((B, M, N)) f_ref,
c = mx.random.normal((B, M, N)) [c, a, b],
a = mx.random.normal((B, M, K)) [cotan],
b = mx.random.normal((B, K, N)) )
out_test, dout_test = mx.vjp(
f_test,
[c, a, b],
[cotan],
)
out_ref, dout_ref = mx.vjp( self.assertTrue(mx.allclose(out_ref[0], out_test[0], atol=1e-4).item())
f_ref,
[c, a, b],
[cotan],
)
out_test, dout_test = mx.vjp(
f_test,
[c, a, b],
[cotan],
)
self.assertTrue(mx.allclose(out_ref[0], out_test[0], atol=1e-4).item()) for r, t in zip(dout_ref, dout_test):
self.assertEqual(r.shape, t.shape)
for r, t in zip(dout_ref, dout_test): self.assertTrue(mx.allclose(r, t, atol=1e-4).item())
self.assertEqual(r.shape, t.shape)
self.assertTrue(mx.allclose(r, t, atol=1e-4).item())
def test_empty_matmul(self): def test_empty_matmul(self):
a = mx.array([[], []]).T a = mx.array([[], []]).T

View File

@@ -8,18 +8,23 @@ import mlx_tests
def rope_orig(x, dims, traditional, base, scale, offset, freqs=None): def rope_orig(x, dims, traditional, base, scale, offset, freqs=None):
offset = offset.item() if isinstance(offset, mx.array) else offset N = x.shape[-2]
N = x.shape[-2] + offset
dtype = x.dtype dtype = x.dtype
half_D = dims // 2 half_D = dims // 2
positions = mx.arange(offset, N, dtype=dtype) * scale positions = mx.arange(N, dtype=dtype)
if isinstance(offset, mx.array) and offset.size > 1:
expand = tuple(range(1, x.ndim - 1))
positions = mx.expand_dims(offset, expand) + positions
else:
positions = offset + positions
positions = positions * scale
if freqs is None: if freqs is None:
inv_freqs = mx.exp( inv_freqs = mx.exp(
-mx.arange(0.0, half_D, dtype=dtype) * (math.log(base) / half_D) -mx.arange(0.0, half_D, dtype=dtype) * (math.log(base) / half_D)
) )
else: else:
inv_freqs = (1 / freqs).astype(x.dtype) inv_freqs = (1 / freqs).astype(x.dtype)
theta = mx.reshape(positions, (-1, 1)) * mx.reshape(inv_freqs, (1, -1)) theta = mx.expand_dims(positions, -1) * inv_freqs
costheta, sintheta = mx.cos(theta), mx.sin(theta) costheta, sintheta = mx.cos(theta), mx.sin(theta)
if traditional: if traditional:
x1 = x[..., :dims:2] x1 = x[..., :dims:2]
@@ -214,6 +219,7 @@ class TestFast(mlx_tests.MLXTestCase):
) )
self.assertEqual(dtype, rx.dtype) self.assertEqual(dtype, rx.dtype)
self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype]) self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
return
# Test single vector # Test single vector
x = mx.random.uniform(shape=(1, 1, dims)) x = mx.random.uniform(shape=(1, 1, dims))
@@ -277,6 +283,55 @@ class TestFast(mlx_tests.MLXTestCase):
g2 = mx.grad(f2)(x, y) g2 = mx.grad(f2)(x, y)
self.assertLess(mx.abs(g1 - g2).max(), 1e-5) self.assertLess(mx.abs(g1 - g2).max(), 1e-5)
def test_rope_batch(self):
T = 4
base = 10000.0
scale = 1.0
traditional = True
batch_sizes = [3, 8, 11]
num_heads = [1, 3, 5]
dims = 32
x = mx.random.uniform(shape=(8, 4, T, dims))
offset = mx.array([1, 2, 3])
with self.assertRaises(ValueError):
mx.fast.rope(
x,
dims,
traditional=traditional,
base=base,
scale=scale,
offset=offset,
)
for batch_size in batch_sizes:
for n_head in num_heads:
x = mx.random.uniform(shape=(batch_size, n_head, T, dims))
offset = mx.arange(batch_size)
rx = rope_orig(x, dims, traditional, base, scale, offset)
rx_fast = mx.fast.rope(
x,
dims,
traditional=traditional,
base=base,
scale=scale,
offset=offset,
)
self.assertLess(mx.abs(rx - rx_fast).max(), 1e-5)
x = mx.random.normal(shape=(2, 6, 8, 64)).transpose(0, 2, 1, 3)
dims = 64
offset = 0
rx_fast = mx.fast.rope(
x, dims, traditional=traditional, scale=scale, base=base, offset=offset
)
rx_fast_single = mx.fast.rope(
x[0:1], dims, traditional=traditional, scale=scale, base=base, offset=offset
)
rx = rope_orig(x, dims, traditional, base, scale, offset)
self.assertLess(mx.abs(rx - rx_fast).max(), 1e-5)
def test_rms_norm(self): def test_rms_norm(self):
# Per dtype absolute tolerance # Per dtype absolute tolerance
tolerances = {mx.float32: 1e-6, mx.float16: 1e-3, mx.bfloat16: 1e-2} tolerances = {mx.float32: 1e-6, mx.float16: 1e-3, mx.bfloat16: 1e-2}

View File

@@ -121,7 +121,8 @@ class CMakeBuild(build_ext):
build_args += [f"-j{os.cpu_count()}"] build_args += [f"-j{os.cpu_count()}"]
# Avoid cache miss when building from temporary dirs. # Avoid cache miss when building from temporary dirs.
os.environ["CCACHE_BASEDIR"] = os.path.abspath(self.build_temp) os.environ["CCACHE_BASEDIR"] = os.path.realpath(self.build_temp)
os.environ["CCACHE_NOHASHDIR"] = "true"
subprocess.run( subprocess.run(
["cmake", ext.sourcedir, *cmake_args], cwd=build_temp, check=True ["cmake", ext.sourcedir, *cmake_args], cwd=build_temp, check=True

View File

@@ -1,10 +1,7 @@
# Doctest works fine with cmake 3.5
set(CMAKE_POLICY_VERSION_MINIMUM 3.5)
FetchContent_Declare( FetchContent_Declare(
doctest doctest
GIT_REPOSITORY "https://github.com/onqtam/doctest" GIT_REPOSITORY https://github.com/onqtam/doctest.git
GIT_TAG "ae7a13539fb71f270b87eb2e874fbac80bc8dda2") GIT_TAG v2.4.12)
FetchContent_MakeAvailable(doctest) FetchContent_MakeAvailable(doctest)
add_executable(tests ${PROJECT_SOURCE_DIR}/tests/tests.cpp) add_executable(tests ${PROJECT_SOURCE_DIR}/tests/tests.cpp)