Mirror of https://github.com/ml-explore/mlx.git (synced 2025-12-16 01:49:05 +08:00)

Compare commits: c1e3340b23 ... 44cc5da4bc (6 commits)

| SHA1 |
|---|
| 44cc5da4bc |
| dde3682b69 |
| 17310d91a6 |
| b194d65a6a |
| a44b27f5f8 |
| e5a33f2223 |
@@ -87,22 +87,21 @@ cmake_policy(SET CMP0135 NEW)

add_library(mlx)

if(MLX_BUILD_METAL)
set(METAL_LIB "-framework Metal")
set(FOUNDATION_LIB "-framework Foundation")
set(QUARTZ_LIB "-framework QuartzCore")
endif()

if(MLX_BUILD_CUDA)
enable_language(CUDA)
endif()

if(MLX_BUILD_METAL AND NOT METAL_LIB)
message(STATUS "Metal not found. Unable to build GPU")
set(MLX_BUILD_METAL OFF)
set(MLX_METAL_DEBUG OFF)
elseif(MLX_BUILD_METAL)
message(STATUS "Building METAL sources")
if(MLX_BUILD_METAL)
find_library(METAL_LIB Metal)
find_library(FOUNDATION_LIB Foundation)
find_library(QUARTZ_LIB QuartzCore)
if(METAL_LIB)
message(STATUS "Metal found ${METAL_LIB}")
else()
message(
FATAL_ERROR
"Metal not found. Set MLX_BUILD_METAL=OFF to build without GPU")
endif()

if(MLX_METAL_DEBUG)
add_compile_definitions(MLX_METAL_DEBUG)
@@ -111,7 +110,8 @@ elseif(MLX_BUILD_METAL)
# Throw an error if xcrun not found
execute_process(
COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version"
OUTPUT_VARIABLE MACOS_SDK_VERSION COMMAND_ERROR_IS_FATAL ANY)
OUTPUT_VARIABLE MACOS_SDK_VERSION
OUTPUT_STRIP_TRAILING_WHITESPACE COMMAND_ERROR_IS_FATAL ANY)

if(${MACOS_SDK_VERSION} LESS 14.0)
message(
|
||||
|
||||
@@ -85,10 +85,10 @@ cublasLtMatrixLayout_t create_matrix_layout(
int32_t batch_count,
int64_t batch_stride) {
cublasLtMatrixLayout_t desc;
if (transposed) {
std::swap(rows, cols);
}
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate(&desc, type, rows, cols, ld));
cublasLtOrder_t order = transposed ? CUBLASLT_ORDER_COL : CUBLASLT_ORDER_ROW;
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
desc, CUBLASLT_MATRIX_LAYOUT_ORDER, &order, sizeof(cublasLtOrder_t)));
if (batch_count > 1) {
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
desc,
@@ -138,25 +138,34 @@ CublasGemm::CublasGemm(
CUBLASLT_MATMUL_DESC_POINTER_MODE,
&pointer_mode,
sizeof(int32_t)));
cublasOperation_t op = CUBLAS_OP_N;

// In cublasLt matrices use column-major layout, while it is possible to use
// the CUBLASLT_ORDER_ROW option to switch to row-major layout, the bias
// epilogue does not work with the option. So instead we swap A and B to make
// cublasLt return the row-major result, which works because:
// - the data of a matrix in row-major layout is identical to its transpose in
// column-major layout
// - C^T = (A @ B)^T = B^T @ A^T
cublasOperation_t a_op = b_transposed ? CUBLAS_OP_T : CUBLAS_OP_N;
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
matmul_desc_,
CUBLASLT_MATMUL_DESC_TRANSA,
&op,
&a_op,
sizeof(cublasOperation_t)));
cublasOperation_t b_op = a_transposed ? CUBLAS_OP_T : CUBLAS_OP_N;
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
matmul_desc_,
CUBLASLT_MATMUL_DESC_TRANSB,
&op,
&b_op,
sizeof(cublasOperation_t)));

auto type = dtype_to_cublas_type(dtype);
a_desc_ = create_matrix_layout(
type, a_rows, a_cols, a_transposed, lda, batch_count, a_batch_stride);
type, b_cols, b_rows, b_transposed, ldb, batch_count, b_batch_stride);
b_desc_ = create_matrix_layout(
type, b_rows, b_cols, b_transposed, ldb, batch_count, b_batch_stride);
type, a_cols, a_rows, a_transposed, lda, batch_count, a_batch_stride);
out_desc_ = create_matrix_layout(
type, a_rows, b_cols, false, b_cols, batch_count, a_rows * b_cols);
type, b_cols, a_rows, false, b_cols, batch_count, a_rows * b_cols);
}

CublasGemm::CublasGemm(
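Aside: the swap in the comment above rests on the fact that a row-major matrix shares its memory with its transpose stored column-major, so asking cublasLt for B^T @ A^T in column-major hands back A @ B in row-major. A minimal NumPy check of that identity (illustrative only, not part of the diff):

```python
import numpy as np

A = np.random.rand(3, 4).astype(np.float32)
B = np.random.rand(4, 5).astype(np.float32)

C = A @ B          # the row-major result we want
Ct = B.T @ A.T     # what the swapped call computes in column-major terms

# (A @ B)^T == B^T @ A^T, so the swapped product is the same buffer reinterpreted.
assert np.allclose(C, Ct.T)
```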
|
||||
@@ -191,7 +200,7 @@ CublasGemm::CublasGemm(
|
||||
b_batch_stride) {
|
||||
auto type = dtype_to_cublas_type(dtype);
|
||||
c_desc_ = create_matrix_layout(
|
||||
type, a_rows, b_cols, false, ldc, batch_count, c_batch_stride);
|
||||
type, b_cols, a_rows, false, ldc, batch_count, c_batch_stride);
|
||||
}
|
||||
|
||||
CublasGemm::~CublasGemm() {
|
||||
@@ -213,14 +222,25 @@ void CublasGemm::set_out(
|
||||
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(out_desc_));
|
||||
out_desc_ = create_matrix_layout(
|
||||
dtype_to_cublas_type(dtype),
|
||||
rows,
|
||||
cols,
|
||||
rows,
|
||||
transposed,
|
||||
ld,
|
||||
batch_count,
|
||||
batch_stride);
|
||||
}
|
||||
|
||||
void CublasGemm::set_bias(void* bias) {
|
||||
cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_BIAS;
|
||||
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
|
||||
matmul_desc_,
|
||||
CUBLASLT_MATMUL_DESC_EPILOGUE,
|
||||
&epilogue,
|
||||
sizeof(epilogue)));
|
||||
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
|
||||
matmul_desc_, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias)));
|
||||
}
|
||||
|
||||
void CublasGemm::run(
|
||||
cu::CommandEncoder& encoder,
|
||||
array& out,
|
||||
@@ -228,11 +248,19 @@ void CublasGemm::run(
|
||||
const array& b,
|
||||
const Shape& batch_shape,
|
||||
const Strides& a_batch_strides,
|
||||
const Strides& b_batch_strides) {
|
||||
const Strides& b_batch_strides,
|
||||
float alpha) {
|
||||
int batch_count = out.size() / (M_ * N_);
|
||||
if (batch_count / batch_shape.back() > 1) {
|
||||
run_batched(
|
||||
encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides);
|
||||
encoder,
|
||||
out,
|
||||
a,
|
||||
b,
|
||||
batch_shape,
|
||||
a_batch_strides,
|
||||
b_batch_strides,
|
||||
alpha);
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -240,7 +268,13 @@ void CublasGemm::run(
|
||||
encoder.set_input_array(b);
|
||||
encoder.set_output_array(out);
|
||||
|
||||
execute(encoder, out.data<void>(), a.data<void>(), b.data<void>(), nullptr);
|
||||
execute(
|
||||
encoder,
|
||||
out.data<void>(),
|
||||
a.data<void>(),
|
||||
b.data<void>(),
|
||||
nullptr,
|
||||
alpha);
|
||||
}
|
||||
|
||||
void CublasGemm::run(
|
||||
@@ -330,9 +364,9 @@ void CublasGemm::execute(
|
||||
handle_,
|
||||
matmul_desc_,
|
||||
&alpha,
|
||||
a,
|
||||
b, // a and b are swapped
|
||||
a_desc_,
|
||||
b,
|
||||
a,
|
||||
b_desc_,
|
||||
&beta,
|
||||
c ? c : out,
|
||||
|
||||
@@ -55,6 +55,8 @@ class CublasGemm {
|
||||
int32_t batch_count,
|
||||
int64_t batch_stride);
|
||||
|
||||
void set_bias(void* bias);
|
||||
|
||||
void run(
|
||||
cu::CommandEncoder& encoder,
|
||||
array& out,
|
||||
@@ -62,7 +64,8 @@ class CublasGemm {
|
||||
const array& b,
|
||||
const Shape& batch_shape,
|
||||
const Strides& a_batch_strides,
|
||||
const Strides& b_batch_strides);
|
||||
const Strides& b_batch_strides,
|
||||
float alpha = 1.0f);
|
||||
|
||||
void run(
|
||||
cu::CommandEncoder& encoder,
|
||||
@@ -85,7 +88,8 @@ class CublasGemm {
|
||||
const array& b,
|
||||
const Shape& batch_shape,
|
||||
const Strides& a_batch_strides,
|
||||
const Strides& b_batch_strides);
|
||||
const Strides& b_batch_strides,
|
||||
float alpha);
|
||||
|
||||
void run_batched(
|
||||
cu::CommandEncoder& encoder,
|
||||
|
||||
@@ -13,7 +13,8 @@ void CublasGemm::run_batched(
|
||||
const array& b,
|
||||
const Shape& batch_shape,
|
||||
const Strides& a_batch_strides,
|
||||
const Strides& b_batch_strides) {
|
||||
const Strides& b_batch_strides,
|
||||
float alpha) {
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_input_array(b);
|
||||
encoder.set_output_array(out);
|
||||
@@ -27,7 +28,8 @@ void CublasGemm::run_batched(
|
||||
out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M_ * N_,
|
||||
a.data<int8_t>() + a.itemsize() * a_it.loc,
|
||||
b.data<int8_t>() + b.itemsize() * b_it.loc,
|
||||
nullptr);
|
||||
nullptr,
|
||||
alpha);
|
||||
a_it.step();
|
||||
b_it.step();
|
||||
}
|
||||
|
||||
@@ -154,7 +154,8 @@ void CublasGemm::run_batched(
|
||||
const array& b,
|
||||
const Shape& batch_shape,
|
||||
const Strides& a_batch_strides,
|
||||
const Strides& b_batch_strides) {
|
||||
const Strides& b_batch_strides,
|
||||
float alpha) {
|
||||
int batch_count = out.size() / (M_ * N_);
|
||||
set_pointer_mode(a_desc_, batch_count);
|
||||
set_pointer_mode(b_desc_, batch_count);
|
||||
@@ -226,7 +227,8 @@ void CublasGemm::run_batched(
|
||||
reinterpret_cast<void*>(out_pointers),
|
||||
reinterpret_cast<void*>(a_pointers),
|
||||
reinterpret_cast<void*>(b_pointers),
|
||||
nullptr);
|
||||
nullptr,
|
||||
alpha);
|
||||
}
|
||||
|
||||
void CublasGemm::run_batched(
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
#include <numeric>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
std::tuple<bool, int64_t, array>
|
||||
@@ -28,6 +29,76 @@ check_transpose(cu::CommandEncoder& enc, const Stream& s, const array& arr) {
|
||||
}
|
||||
}
|
||||
|
||||
void gemm_and_bias(
|
||||
cu::CommandEncoder& encoder,
|
||||
int M,
|
||||
int N,
|
||||
int K,
|
||||
bool a_transposed,
|
||||
int64_t lda,
|
||||
bool b_transposed,
|
||||
int64_t ldb,
|
||||
array& out,
|
||||
const array& a,
|
||||
const array& b,
|
||||
void* bias = nullptr,
|
||||
float alpha = 1.0f) {
|
||||
// Check and collapse batch dimensions
|
||||
auto [batch_shape, a_batch_strides, b_batch_strides] = collapse_batches(a, b);
|
||||
|
||||
auto batch_count = out.size() / (M * N);
|
||||
|
||||
// Collapse batches into M if needed
|
||||
if (batch_count > 1 && !a_transposed && batch_shape.size() == 1 &&
|
||||
a.strides()[a.ndim() - 2] == K && a_batch_strides.back() == M * K &&
|
||||
b_batch_strides.back() == 0) {
|
||||
M *= batch_shape.back();
|
||||
batch_count = 1;
|
||||
|
||||
a_batch_strides = {0};
|
||||
b_batch_strides = {0};
|
||||
batch_shape = {1};
|
||||
}
|
||||
|
||||
// Use gemmv when possible
|
||||
if (!bias && cu::can_use_gemv(M, N, K, a_transposed, b_transposed)) {
|
||||
cu::gemv(
|
||||
a,
|
||||
b,
|
||||
out,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
batch_count,
|
||||
batch_shape,
|
||||
a_batch_strides,
|
||||
b_batch_strides,
|
||||
encoder);
|
||||
return;
|
||||
}
|
||||
|
||||
// Invoke cublasLt
|
||||
CublasGemm gemm(
|
||||
encoder.device(),
|
||||
a.dtype(),
|
||||
a_transposed,
|
||||
M,
|
||||
K,
|
||||
lda,
|
||||
b_transposed,
|
||||
K,
|
||||
N,
|
||||
ldb,
|
||||
batch_shape.back(),
|
||||
a_batch_strides.back(),
|
||||
b_batch_strides.back());
|
||||
if (bias) {
|
||||
gemm.set_bias(bias);
|
||||
}
|
||||
gemm.run(
|
||||
encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides, alpha);
|
||||
}
|
||||
|
||||
} // namespace
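The gemm_and_bias helper above threads an optional bias pointer and alpha into the cublasLt call; AddMM later dispatches to it when beta == 1 and c is a row-contiguous bias over the last output dimension. A hedged Python-level sketch of a call shaped to satisfy that condition (whether it actually hits the fused epilogue is an internal detail):

```python
import mlx.core as mx

a = mx.random.normal((32, 128, 16))
b = mx.random.normal((16, 64))
bias = mx.random.normal((64,))          # broadcasts over the last dim of a @ b

# beta == 1 with a 1-D bias matches the fused GEMM-with-bias condition above.
out = mx.addmm(bias, a, b, 0.5, 1.0)    # alpha=0.5, beta=1.0
print(out.shape)  # (32, 128, 64)
```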
|
||||
|
||||
void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
@@ -48,9 +119,6 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Init checks and prep
|
||||
|
||||
int M = a_pre.shape(-2);
|
||||
int N = b_pre.shape(-1);
|
||||
int K = a_pre.shape(-1);
|
||||
@@ -60,58 +128,8 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
auto [a_transposed, lda, a] = check_transpose(encoder, s, a_pre);
|
||||
auto [b_transposed, ldb, b] = check_transpose(encoder, s, b_pre);
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Check and collapse batch dimensions
|
||||
|
||||
auto [batch_shape, a_batch_strides, b_batch_strides] = collapse_batches(a, b);
|
||||
|
||||
auto batch_count = out.size() / (M * N);
|
||||
|
||||
// Collapse batches into M if needed
|
||||
if (batch_count > 1 && !a_transposed && batch_shape.size() == 1 &&
|
||||
a.strides()[a.ndim() - 2] == K && a_batch_strides.back() == M * K &&
|
||||
b_batch_strides.back() == 0) {
|
||||
M *= batch_shape.back();
|
||||
batch_count = 1;
|
||||
|
||||
a_batch_strides = {0};
|
||||
b_batch_strides = {0};
|
||||
batch_shape = {1};
|
||||
}
|
||||
|
||||
if (cu::can_use_gemv(M, N, K, a_transposed, b_transposed)) {
|
||||
cu::gemv(
|
||||
a,
|
||||
b,
|
||||
out,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
batch_count,
|
||||
batch_shape,
|
||||
a_batch_strides,
|
||||
b_batch_strides,
|
||||
encoder);
|
||||
return;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Invoke cublasLt
|
||||
CublasGemm gemm(
|
||||
cu::device(s.device),
|
||||
a.dtype(),
|
||||
a_transposed,
|
||||
M,
|
||||
K,
|
||||
lda,
|
||||
b_transposed,
|
||||
K,
|
||||
N,
|
||||
ldb,
|
||||
batch_shape.back(),
|
||||
a_batch_strides.back(),
|
||||
b_batch_strides.back());
|
||||
gemm.run(encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides);
|
||||
gemm_and_bias(
|
||||
encoder, M, N, K, a_transposed, lda, b_transposed, ldb, out, a, b);
|
||||
}
|
||||
|
||||
void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
@@ -136,6 +154,28 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
auto [a_transposed, lda, a] = check_transpose(encoder, s, a_pre);
|
||||
auto [b_transposed, ldb, b] = check_transpose(encoder, s, b_pre);
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Dispatch to GEMM with epilogue or AddMM
|
||||
|
||||
if (beta_ == 1 && c.strides(-1) == 1 && c.data_size() == out.shape(-1)) {
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
gemm_and_bias(
|
||||
encoder,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
a_transposed,
|
||||
lda,
|
||||
b_transposed,
|
||||
ldb,
|
||||
out,
|
||||
a,
|
||||
b,
|
||||
c.data<void>(),
|
||||
alpha_);
|
||||
return;
|
||||
}
|
||||
|
||||
int64_t ldc;
|
||||
{
|
||||
auto stx = c.strides()[c.ndim() - 2];
|
||||
@@ -177,7 +217,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Invoke cublasLt
|
||||
// Invoke cublasLt with AddMM settings
|
||||
|
||||
CublasGemm gemm(
|
||||
cu::device(s.device),
|
||||
|
||||
@@ -103,15 +103,21 @@ template <typename T, bool traditional, bool forward, int N = 4>
|
||||
__device__ void rope_impl(
|
||||
const T* in,
|
||||
T* out,
|
||||
int offset,
|
||||
const int* offset,
|
||||
float inv_freq,
|
||||
float scale,
|
||||
const cuda::std::array<int64_t, 3> strides,
|
||||
const cuda::std::array<int64_t, 3> out_strides,
|
||||
int64_t n_batch,
|
||||
int64_t offset_stride,
|
||||
int n_head,
|
||||
uint3 pos,
|
||||
uint3 dims) {
|
||||
float L = scale * static_cast<float>(pos.y + offset);
|
||||
auto n_head_up = N * ((n_head + N - 1) / N);
|
||||
auto head_idx = static_cast<int>((pos.z * N) % n_head_up);
|
||||
auto batch_idx = (pos.z * N) / n_head_up;
|
||||
auto batch_offset = offset[batch_idx * offset_stride];
|
||||
float L = scale * static_cast<float>(pos.y + batch_offset);
|
||||
auto mat_idx = batch_idx * n_head + head_idx;
|
||||
|
||||
// Compute costheta, sintheta
|
||||
float theta = L * inv_freq;
|
||||
@@ -123,20 +129,19 @@ __device__ void rope_impl(
|
||||
size_t out_index_1, out_index_2;
|
||||
if (traditional) {
|
||||
out_index_1 = 2 * pos.x * out_strides[2] + pos.y * out_strides[1] +
|
||||
N * pos.z * out_strides[0];
|
||||
mat_idx * out_strides[0];
|
||||
out_index_2 = out_index_1 + 1;
|
||||
in_index_1 =
|
||||
2 * pos.x * strides[2] + pos.y * strides[1] + N * pos.z * strides[0];
|
||||
2 * pos.x * strides[2] + pos.y * strides[1] + mat_idx * strides[0];
|
||||
in_index_2 = in_index_1 + strides[2];
|
||||
} else {
|
||||
out_index_1 = pos.x * out_strides[2] + pos.y * out_strides[1] +
|
||||
N * pos.z * out_strides[0];
|
||||
mat_idx * out_strides[0];
|
||||
out_index_2 = out_index_1 + dims.x * out_strides[2];
|
||||
in_index_1 =
|
||||
pos.x * strides[2] + pos.y * strides[1] + N * pos.z * strides[0];
|
||||
in_index_1 = pos.x * strides[2] + pos.y * strides[1] + mat_idx * strides[0];
|
||||
in_index_2 = in_index_1 + dims.x * strides[2];
|
||||
}
|
||||
for (int i = 0; i < N && pos.z * N + i < n_batch; ++i) {
|
||||
for (int i = 0; i < N && head_idx + i < n_head; ++i) {
|
||||
// Read and write the output
|
||||
float x1 = static_cast<float>(in[in_index_1]);
|
||||
float x2 = static_cast<float>(in[in_index_2]);
|
||||
@@ -167,7 +172,8 @@ __global__ void rope(
|
||||
float base,
|
||||
const __grid_constant__ cuda::std::array<int64_t, 3> strides,
|
||||
const __grid_constant__ cuda::std::array<int64_t, 3> out_strides,
|
||||
int64_t n_batch,
|
||||
int64_t offset_stride,
|
||||
int n_head,
|
||||
uint3 dims) {
|
||||
uint3 pos = make_uint3(
|
||||
blockIdx.x * blockDim.x + threadIdx.x,
|
||||
@@ -182,12 +188,13 @@ __global__ void rope(
|
||||
rope_impl<T, traditional, forward>(
|
||||
in,
|
||||
out,
|
||||
*offset,
|
||||
offset,
|
||||
inv_freq,
|
||||
scale,
|
||||
strides,
|
||||
out_strides,
|
||||
n_batch,
|
||||
offset_stride,
|
||||
n_head,
|
||||
pos,
|
||||
dims);
|
||||
}
|
||||
@@ -202,7 +209,8 @@ __global__ void rope_freqs(
|
||||
float base,
|
||||
const __grid_constant__ cuda::std::array<int64_t, 3> strides,
|
||||
const __grid_constant__ cuda::std::array<int64_t, 3> out_strides,
|
||||
int64_t n_batch,
|
||||
int64_t offset_stride,
|
||||
int n_head,
|
||||
uint3 dims,
|
||||
int64_t freq_stride) {
|
||||
uint3 pos = make_uint3(
|
||||
@@ -217,12 +225,13 @@ __global__ void rope_freqs(
|
||||
rope_impl<T, traditional, forward>(
|
||||
in,
|
||||
out,
|
||||
*offset,
|
||||
offset,
|
||||
inv_freq,
|
||||
scale,
|
||||
strides,
|
||||
out_strides,
|
||||
n_batch,
|
||||
offset_stride,
|
||||
n_head,
|
||||
pos,
|
||||
dims);
|
||||
}
|
||||
@@ -245,23 +254,28 @@ void RoPE::eval_gpu(
|
||||
auto& offset = inputs[1];
|
||||
auto& out = outputs[0];
|
||||
|
||||
if (in.ndim() < 3) {
|
||||
throw std::runtime_error("[RoPE] Input must have at least 3 dimensions");
|
||||
}
|
||||
|
||||
cuda::std::array<int64_t, 3> strides;
|
||||
cuda::std::array<int64_t, 3> out_strides;
|
||||
bool donated = false;
|
||||
int ndim = in.ndim();
|
||||
int dispatch_ndim = in.ndim();
|
||||
|
||||
int B = in.shape(0);
|
||||
int T = in.shape(-2);
|
||||
int D = in.shape(-1);
|
||||
size_t mat_size = T * D;
|
||||
int dispatch_ndim = ndim;
|
||||
while (in.shape(-dispatch_ndim) == 1 && dispatch_ndim > 3) {
|
||||
dispatch_ndim--;
|
||||
}
|
||||
size_t mat_size = in.shape(-2) * in.shape(-1);
|
||||
|
||||
int N = 1;
|
||||
for (int i = 1; i < (ndim - 2); ++i) {
|
||||
N *= in.shape(i);
|
||||
}
|
||||
|
||||
// We apply rope to less than the whole vector so copy to output and then
|
||||
// apply in-place.
|
||||
if (dims_ < in.shape(-1)) {
|
||||
if (dims_ < D) {
|
||||
donated = true;
|
||||
auto ctype =
|
||||
(in.flags().row_contiguous) ? CopyType::Vector : CopyType::General;
|
||||
@@ -302,7 +316,7 @@ void RoPE::eval_gpu(
|
||||
out_strides[2] = out.strides()[ndim - 1];
|
||||
|
||||
// Some flags to help us dispatch below
|
||||
bool single = in.flags().row_contiguous && (mat_size == in.shape(-1));
|
||||
bool single = in.flags().row_contiguous && B == 1 && T == 1;
|
||||
bool with_freqs = inputs.size() == 3;
|
||||
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
@@ -319,7 +333,7 @@ void RoPE::eval_gpu(
|
||||
if (single && !with_freqs) {
|
||||
auto kernel =
|
||||
cu::rope_single<DataType, traditional.value, forward.value>;
|
||||
uint2 dims = make_uint2(dims_ / 2, in.size() / mat_size);
|
||||
uint2 dims = make_uint2(dims_ / 2, N);
|
||||
auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1);
|
||||
encoder.add_kernel_node(
|
||||
kernel,
|
||||
@@ -336,7 +350,7 @@ void RoPE::eval_gpu(
|
||||
} else if (single) {
|
||||
auto kernel =
|
||||
cu::rope_single_freqs<DataType, traditional.value, forward.value>;
|
||||
uint2 dims = make_uint2(dims_ / 2, in.size() / mat_size);
|
||||
uint2 dims = make_uint2(dims_ / 2, N);
|
||||
auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1);
|
||||
encoder.add_kernel_node(
|
||||
kernel,
|
||||
@@ -354,10 +368,14 @@ void RoPE::eval_gpu(
|
||||
} else if (with_freqs) {
|
||||
auto kernel =
|
||||
cu::rope_freqs<DataType, traditional.value, forward.value>;
|
||||
uint3 dims =
|
||||
make_uint3(dims_ / 2, in.shape(-2), in.size() / mat_size);
|
||||
dims.z = (dims.z + 3) / 4;
|
||||
int n_per_thread = 4;
|
||||
uint32_t dimz = B * ((N + n_per_thread - 1) / n_per_thread);
|
||||
uint3 dims = make_uint3(dims_ / 2, T, dimz);
|
||||
auto [grid, block] = get_grid_and_block(dims.x, dims.y, dims.z);
|
||||
int64_t offset_stride = 0;
|
||||
if (inputs[1].ndim() > 0) {
|
||||
offset_stride = inputs[1].strides()[0];
|
||||
}
|
||||
encoder.add_kernel_node(
|
||||
kernel,
|
||||
grid,
|
||||
@@ -371,15 +389,20 @@ void RoPE::eval_gpu(
|
||||
std::log2(base_),
|
||||
strides,
|
||||
out_strides,
|
||||
in.size() / mat_size,
|
||||
offset_stride,
|
||||
N,
|
||||
dims,
|
||||
inputs[2].strides(0));
|
||||
} else {
|
||||
auto kernel = cu::rope<DataType, traditional.value, forward.value>;
|
||||
uint3 dims =
|
||||
make_uint3(dims_ / 2, in.shape(-2), in.size() / mat_size);
|
||||
dims.z = (dims.z + 3) / 4;
|
||||
int n_per_thread = 4;
|
||||
uint32_t dimz = B * ((N + n_per_thread - 1) / n_per_thread);
|
||||
uint3 dims = make_uint3(dims_ / 2, T, dimz);
|
||||
auto [grid, block] = get_grid_and_block(dims.x, dims.y, dims.z);
|
||||
int64_t offset_stride = 0;
|
||||
if (inputs[1].ndim() > 0) {
|
||||
offset_stride = inputs[1].strides()[0];
|
||||
}
|
||||
encoder.add_kernel_node(
|
||||
kernel,
|
||||
grid,
|
||||
@@ -392,7 +415,8 @@ void RoPE::eval_gpu(
|
||||
std::log2(base_),
|
||||
strides,
|
||||
out_strides,
|
||||
in.size() / mat_size,
|
||||
offset_stride,
|
||||
N,
|
||||
dims);
|
||||
}
|
||||
});
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <numeric>
|
||||
#include <sstream>
|
||||
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
@@ -39,10 +38,11 @@ void explicit_gemm_conv_ND_gpu(
|
||||
in_unfolded.set_data(allocator::malloc(in_unfolded.nbytes()));
|
||||
|
||||
// Prepare unfolding kernel
|
||||
std::ostringstream kname;
|
||||
kname << "naive_unfold_nd_" << type_to_name(in_unfolded) << "_" << N;
|
||||
std::string kname;
|
||||
kname.reserve(32);
|
||||
concatenate(kname, "naive_unfold_nd_", type_to_name(in_unfolded), "_", N);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
auto kernel = d.get_kernel(kname);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
compute_encoder.set_input_array(in, 0);
|
||||
@@ -117,11 +117,12 @@ void explicit_gemm_conv_group_ND_gpu(
|
||||
in_unfolded.set_data(allocator::malloc(in_unfolded.nbytes()));
|
||||
|
||||
// Prepare unfolding kernel
|
||||
std::ostringstream kname;
|
||||
kname << "naive_unfold_transpose_nd_" << type_to_name(in_unfolded) << "_"
|
||||
<< N;
|
||||
std::string kname;
|
||||
kname.reserve(32);
|
||||
concatenate(
|
||||
kname, "naive_unfold_transpose_nd_", type_to_name(in_unfolded), "_", N);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
auto kernel = d.get_kernel(kname);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
compute_encoder.set_input_array(in, 0);
|
||||
@@ -252,18 +253,32 @@ void implicit_gemm_conv_2D_gpu(
|
||||
/* const int swizzle_log = */ swizzle_log};
|
||||
|
||||
// Determine kernel
|
||||
std::ostringstream kname;
|
||||
kname << "implicit_gemm_conv_2d_" << type_to_name(out) << "_bm" << bm << "_bn"
|
||||
<< bn << "_bk" << bk << "_wm" << wm << "_wn" << wn << "_channel_"
|
||||
<< (n_channel_specialization ? std::to_string(n_channel_specialization)
|
||||
: "l")
|
||||
<< "_filter_" << (small_filter ? 's' : 'l');
|
||||
std::string kname;
|
||||
kname.reserve(64);
|
||||
concatenate(
|
||||
kname,
|
||||
"implicit_gemm_conv_2d_",
|
||||
type_to_name(out),
|
||||
"_bm",
|
||||
bm,
|
||||
"_bn",
|
||||
bn,
|
||||
"_bk",
|
||||
bk,
|
||||
"_wm",
|
||||
wm,
|
||||
"_wn",
|
||||
wn,
|
||||
"_channel_",
|
||||
n_channel_specialization ? std::to_string(n_channel_specialization) : "l",
|
||||
"_filter_",
|
||||
small_filter ? 's' : 'l');
|
||||
|
||||
// Encode and dispatch kernel
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = get_steel_conv_kernel(
|
||||
d,
|
||||
kname.str(),
|
||||
kname,
|
||||
out,
|
||||
bm,
|
||||
bn,
|
||||
@@ -559,11 +574,16 @@ void winograd_conv_2D_gpu(
|
||||
{
|
||||
int bc = 32;
|
||||
int bo = 4;
|
||||
std::ostringstream kname;
|
||||
kname << "winograd_conv_2d_weight_transform_" << type_to_name(out) << "_bc"
|
||||
<< bc;
|
||||
std::string kname;
|
||||
kname.reserve(32);
|
||||
concatenate(
|
||||
kname,
|
||||
"winograd_conv_2d_weight_transform_",
|
||||
type_to_name(out),
|
||||
"_bc",
|
||||
bc);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
auto kernel = d.get_kernel(kname);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
compute_encoder.set_input_array(wt, 0);
|
||||
@@ -587,11 +607,16 @@ void winograd_conv_2D_gpu(
|
||||
int bc = 32;
|
||||
int wm = 2;
|
||||
int wn = 2;
|
||||
std::ostringstream kname;
|
||||
kname << "winograd_conv_2d_input_transform_" << type_to_name(out) << "_bc"
|
||||
<< bc;
|
||||
std::string kname;
|
||||
kname.reserve(32);
|
||||
concatenate(
|
||||
kname,
|
||||
"winograd_conv_2d_input_transform_",
|
||||
type_to_name(out),
|
||||
"_bc",
|
||||
bc);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
auto kernel = d.get_kernel(kname);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
compute_encoder.set_input_array(in_padded, 0);
|
||||
@@ -634,11 +659,16 @@ void winograd_conv_2D_gpu(
|
||||
int bc = 32;
|
||||
int wm = 2;
|
||||
int wn = 2;
|
||||
std::ostringstream kname;
|
||||
kname << "winograd_conv_2d_output_transform_" << type_to_name(out) << "_bo"
|
||||
<< bc;
|
||||
std::string kname;
|
||||
kname.reserve(32);
|
||||
concatenate(
|
||||
kname,
|
||||
"winograd_conv_2d_output_transform_",
|
||||
type_to_name(out),
|
||||
"_bo",
|
||||
bc);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
auto kernel = d.get_kernel(kname);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
compute_encoder.set_input_array(out_wg, 0);
|
||||
@@ -660,9 +690,9 @@ void depthwise_conv_2D_gpu(
|
||||
const array& wt,
|
||||
array out,
|
||||
const MLXConvParams<2>& conv_params) {
|
||||
std::ostringstream kname;
|
||||
kname << "depthwise_conv_2d_" << type_to_name(out);
|
||||
std::string base_name = kname.str();
|
||||
std::string base_name;
|
||||
base_name.reserve(32);
|
||||
concatenate(base_name, "depthwise_conv_2d_", type_to_name(out));
|
||||
|
||||
const int N = conv_params.N;
|
||||
const int ker_h = conv_params.wS[0];
|
||||
@@ -685,15 +715,18 @@ void depthwise_conv_2D_gpu(
|
||||
};
|
||||
|
||||
// clang-format off
|
||||
kname << "_ker_h_" << ker_h
|
||||
<< "_ker_w_" << ker_w
|
||||
<< "_str_h_" << str_h
|
||||
<< "_str_w_" << str_w
|
||||
<< "_tgp_h_" << th
|
||||
<< "_tgp_w_" << tw
|
||||
<< "_do_flip_" << (do_flip ? 't' : 'n'); // clang-format on
|
||||
|
||||
std::string hash_name = kname.str();
|
||||
std::string hash_name;
|
||||
hash_name.reserve(64);
|
||||
concatenate(
|
||||
hash_name,
|
||||
base_name,
|
||||
"_ker_h_", ker_h,
|
||||
"_ker_w_", ker_w,
|
||||
"_str_h_", str_h,
|
||||
"_str_w_", str_w,
|
||||
"_tgp_h_", th,
|
||||
"_tgp_w_", tw,
|
||||
"_do_flip_", do_flip ? 't' : 'n'); // clang-format on
|
||||
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(base_name, hash_name, func_consts);
|
||||
@@ -774,6 +807,56 @@ void dispatch_conv_2D_gpu(
|
||||
}
|
||||
}
|
||||
|
||||
void depthwise_conv_1D_gpu(
|
||||
const Stream& s,
|
||||
metal::Device& d,
|
||||
const array& in,
|
||||
array wt,
|
||||
array out) {
|
||||
bool large = in.size() > INT32_MAX || in.data_size() > INT32_MAX;
|
||||
std::string base_name;
|
||||
base_name.reserve(32);
|
||||
concatenate(
|
||||
base_name,
|
||||
"depthwise_conv_1d_",
|
||||
large ? "_large" : "",
|
||||
type_to_name(out));
|
||||
|
||||
if (!wt.flags().row_contiguous) {
|
||||
wt = contiguous_copy_gpu(wt, s);
|
||||
d.add_temporary(wt, s.index);
|
||||
}
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(base_name);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
auto B = in.shape(0);
|
||||
auto Tout = out.shape(1);
|
||||
auto D = in.shape(2);
|
||||
auto K = wt.shape(1);
|
||||
|
||||
compute_encoder.set_input_array(in, 0);
|
||||
compute_encoder.set_input_array(wt, 1);
|
||||
compute_encoder.set_output_array(out, 2);
|
||||
if (large) {
|
||||
int64_t strides[3] = {in.strides(0), in.strides(1), in.strides(2)};
|
||||
compute_encoder.set_bytes(strides, 3, 3);
|
||||
|
||||
} else {
|
||||
int strides[3] = {
|
||||
static_cast<int>(in.strides(0)),
|
||||
static_cast<int>(in.strides(1)),
|
||||
static_cast<int>(in.strides(2))};
|
||||
compute_encoder.set_bytes(strides, 3, 3);
|
||||
}
|
||||
|
||||
compute_encoder.set_bytes(K, 4);
|
||||
auto group_dims = get_block_dims(D, Tout, B);
|
||||
MTL::Size grid_dims = MTL::Size(D, Tout, B);
|
||||
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
void conv_1D_gpu(
|
||||
const Stream& s,
|
||||
metal::Device& d,
|
||||
@@ -790,8 +873,15 @@ void conv_1D_gpu(
|
||||
bool is_idil_one = in_dilation[0] == 1;
|
||||
int C = in.shape(2);
|
||||
int O = wt.shape(0);
|
||||
const int C_per_group = in.shape(2) / groups;
|
||||
const int O_per_group = wt.shape(0) / groups;
|
||||
// Fast path for fully separable 1D convolution
|
||||
if (is_idil_one && (groups == C) && groups == O && wt_strides[0] == 1 &&
|
||||
wt_dilation[0] == 1 && padding[0] == 0 && !flip) {
|
||||
depthwise_conv_1D_gpu(s, d, in, wt, out);
|
||||
return;
|
||||
}
|
||||
|
||||
const int C_per_group = C / groups;
|
||||
const int O_per_group = O / groups;
|
||||
|
||||
// Direct to implicit gemm conv
|
||||
if (is_idil_one && (C_per_group <= 4 || C_per_group % 16 == 0) &&
|
||||
|
||||
@@ -288,6 +288,40 @@ instantiate_depthconv2d(float32, float);
instantiate_depthconv2d(float16, half);
instantiate_depthconv2d(bfloat16, bfloat16_t);

template <typename T, typename IdxT>
[[kernel]] void depthwise_conv_1d(
const device T* in [[buffer(0)]],
const device T* w [[buffer(1)]],
device T* out [[buffer(2)]],
constant const IdxT strides[3],
constant const int& kernel_size,
uint3 tid [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
out += (tid.z * static_cast<IdxT>(grid_dim.y) + tid.y) * grid_dim.x + tid.x;
in += tid.z * strides[0] + tid.y * strides[1] + tid.x * strides[2];
w += tid.x * kernel_size;

float acc = 0.0;
for (int i = 0; i < kernel_size; ++i) {
acc += static_cast<float>(in[0]) * w[i];
in += strides[1];
}
*out = static_cast<T>(acc);
}

#define instantiate_depthconv1d(iname, itype) \
instantiate_kernel( \
"depthwise_conv_1d_" #iname, depthwise_conv_1d, itype, int32_t) \
instantiate_kernel( \
"depthwise_conv_1d_" #iname "_large", \
depthwise_conv_1d, \
itype, \
int64_t)

instantiate_depthconv1d(float32, float);
instantiate_depthconv1d(float16, half);
instantiate_depthconv1d(bfloat16, bfloat16_t);
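For reference, each thread of the kernel above walks the time axis for one (batch, output position, channel) triple, so every channel convolves only with its own filter. A hedged NumPy equivalent of this depthwise 1D fast path (groups == channels, stride 1, no padding, dilation, or flip; names are illustrative):

```python
import numpy as np

def depthwise_conv1d_ref(x, w):
    """x: (B, T, D) input, w: (D, K) per-channel filters -> (B, T - K + 1, D)."""
    B, T, D = x.shape
    _, K = w.shape
    out = np.zeros((B, T - K + 1, D), dtype=np.float32)
    for t in range(T - K + 1):
        # channel d is convolved only with its own length-K filter w[d]
        out[:, t, :] = np.einsum("bkd,dk->bd", x[:, t : t + K, :], w)
    return out
```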
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
/// Winograd kernels
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@@ -10,7 +10,7 @@ void rope_single_impl(
|
||||
constant const int& offset,
|
||||
const float inv_freq,
|
||||
constant const float& scale,
|
||||
constant const size_t& stride,
|
||||
constant const int64_t& stride,
|
||||
uint2 pos,
|
||||
uint2 grid) {
|
||||
float L = scale * static_cast<float>(offset);
|
||||
@@ -52,7 +52,7 @@ template <typename T, bool traditional, bool forward>
|
||||
device T* out [[buffer(1)]],
|
||||
constant const int& offset,
|
||||
constant const float& scale,
|
||||
constant const size_t& stride,
|
||||
constant const int64_t& stride,
|
||||
constant const float& base [[buffer(10)]],
|
||||
uint2 pos [[thread_position_in_grid]],
|
||||
uint2 grid [[threads_per_grid]]) {
|
||||
@@ -68,9 +68,9 @@ template <typename T, bool traditional, bool forward>
|
||||
device T* out [[buffer(1)]],
|
||||
constant const int& offset,
|
||||
constant const float& scale,
|
||||
constant const size_t& stride,
|
||||
constant const int64_t& stride,
|
||||
const device float* freqs [[buffer(10)]],
|
||||
constant const size_t& freq_stride [[buffer(11)]],
|
||||
constant const int64_t& freq_stride [[buffer(11)]],
|
||||
uint2 pos [[thread_position_in_grid]],
|
||||
uint2 grid [[threads_per_grid]]) {
|
||||
float inv_freq = 1.0 / (freqs[freq_stride * pos.x]);
|
||||
@@ -82,15 +82,21 @@ template <typename T, bool traditional, bool forward, int N = 4>
|
||||
void rope_impl(
|
||||
const device T* in,
|
||||
device T* out,
|
||||
constant const int& offset,
|
||||
const device int* offset,
|
||||
const float inv_freq,
|
||||
constant const float& scale,
|
||||
constant const size_t strides[3],
|
||||
constant const size_t out_strides[3],
|
||||
constant const size_t& n_batch,
|
||||
constant const int64_t strides[3],
|
||||
constant const int64_t out_strides[3],
|
||||
constant const int64_t& offset_stride,
|
||||
constant const int& n_head,
|
||||
uint3 pos,
|
||||
uint3 grid) {
|
||||
float L = scale * static_cast<float>(pos.y + offset);
|
||||
auto n_head_up = N * ((n_head + N - 1) / N);
|
||||
auto head_idx = static_cast<int>((pos.z * N) % n_head_up);
|
||||
auto batch_idx = (pos.z * N) / n_head_up;
|
||||
auto batch_offset = offset[batch_idx * offset_stride];
|
||||
float L = scale * static_cast<float>(pos.y + batch_offset);
|
||||
auto mat_idx = batch_idx * n_head + head_idx;
|
||||
|
||||
// Compute costheta, sintheta
|
||||
float theta = L * inv_freq;
|
||||
@@ -102,20 +108,19 @@ void rope_impl(
|
||||
size_t out_index_1, out_index_2;
|
||||
if (traditional) {
|
||||
out_index_1 = 2 * pos.x * out_strides[2] + pos.y * out_strides[1] +
|
||||
N * pos.z * out_strides[0];
|
||||
mat_idx * out_strides[0];
|
||||
out_index_2 = out_index_1 + 1;
|
||||
in_index_1 =
|
||||
2 * pos.x * strides[2] + pos.y * strides[1] + N * pos.z * strides[0];
|
||||
2 * pos.x * strides[2] + pos.y * strides[1] + mat_idx * strides[0];
|
||||
in_index_2 = in_index_1 + strides[2];
|
||||
} else {
|
||||
out_index_1 = pos.x * out_strides[2] + pos.y * out_strides[1] +
|
||||
N * pos.z * out_strides[0];
|
||||
mat_idx * out_strides[0];
|
||||
out_index_2 = out_index_1 + grid.x * out_strides[2];
|
||||
in_index_1 =
|
||||
pos.x * strides[2] + pos.y * strides[1] + N * pos.z * strides[0];
|
||||
in_index_1 = pos.x * strides[2] + pos.y * strides[1] + mat_idx * strides[0];
|
||||
in_index_2 = in_index_1 + grid.x * strides[2];
|
||||
}
|
||||
for (int i = 0; i < N && pos.z * N + i < n_batch; ++i) {
|
||||
for (int i = 0; i < N && head_idx + i < n_head; ++i) {
|
||||
// Read and write the output
|
||||
float x1 = static_cast<float>(in[in_index_1]);
|
||||
float x2 = static_cast<float>(in[in_index_2]);
|
||||
@@ -141,11 +146,12 @@ template <typename T, bool traditional, bool forward, int N = 4>
|
||||
[[kernel]] void rope(
|
||||
const device T* in [[buffer(0)]],
|
||||
device T* out [[buffer(1)]],
|
||||
constant const int& offset,
|
||||
const device int* offset,
|
||||
constant const float& scale,
|
||||
constant const size_t strides[3],
|
||||
constant const size_t out_strides[3],
|
||||
constant const size_t& n_batch,
|
||||
constant const int64_t strides[3],
|
||||
constant const int64_t out_strides[3],
|
||||
constant const int64_t& offset_stride,
|
||||
constant const int& n_head,
|
||||
constant const float& base [[buffer(10)]],
|
||||
uint3 pos [[thread_position_in_grid]],
|
||||
uint3 grid [[threads_per_grid]]) {
|
||||
@@ -159,7 +165,8 @@ template <typename T, bool traditional, bool forward, int N = 4>
|
||||
scale,
|
||||
strides,
|
||||
out_strides,
|
||||
n_batch,
|
||||
offset_stride,
|
||||
n_head,
|
||||
pos,
|
||||
grid);
|
||||
}
|
||||
@@ -168,13 +175,14 @@ template <typename T, bool traditional, bool forward, int N = 4>
|
||||
[[kernel]] void rope_freqs(
|
||||
const device T* in [[buffer(0)]],
|
||||
device T* out [[buffer(1)]],
|
||||
constant const int& offset,
|
||||
const device int* offset,
|
||||
constant const float& scale,
|
||||
constant const size_t strides[3],
|
||||
constant const size_t out_strides[3],
|
||||
constant const size_t& n_batch,
|
||||
constant const int64_t strides[3],
|
||||
constant const int64_t out_strides[3],
|
||||
constant const int64_t& offset_stride,
|
||||
constant const int& n_head,
|
||||
const device float* freqs [[buffer(10)]],
|
||||
constant const size_t& freq_stride [[buffer(11)]],
|
||||
constant const int64_t& freq_stride [[buffer(11)]],
|
||||
uint3 pos [[thread_position_in_grid]],
|
||||
uint3 grid [[threads_per_grid]]) {
|
||||
float inv_freq = 1.0 / (freqs[freq_stride * pos.x]);
|
||||
@@ -186,61 +194,20 @@ template <typename T, bool traditional, bool forward, int N = 4>
|
||||
scale,
|
||||
strides,
|
||||
out_strides,
|
||||
n_batch,
|
||||
offset_stride,
|
||||
n_head,
|
||||
pos,
|
||||
grid);
|
||||
}
|
||||
|
||||
// clang-format off
|
||||
#define instantiate_rope_g(name, type, traditional, forward) \
|
||||
template [[host_name("rope_" #name)]] [[kernel]] void \
|
||||
rope<type, traditional, forward>( \
|
||||
const device type* in [[buffer(0)]], \
|
||||
device type* out [[buffer(1)]], \
|
||||
constant const int& offset, \
|
||||
constant const float& scale, \
|
||||
constant const size_t strides[3], \
|
||||
constant const size_t out_strides[3], \
|
||||
constant const size_t& n_batch, \
|
||||
constant const float& base [[buffer(10)]], \
|
||||
uint3 pos [[thread_position_in_grid]], \
|
||||
uint3 grid [[threads_per_grid]]); \
|
||||
template [[host_name("rope_freqs_" #name)]] \
|
||||
[[kernel]] void rope_freqs<type, traditional, forward>( \
|
||||
const device type* in [[buffer(0)]], \
|
||||
device type* out [[buffer(1)]], \
|
||||
constant const int& offset, \
|
||||
constant const float& scale, \
|
||||
constant const size_t strides[3], \
|
||||
constant const size_t out_strides[3], \
|
||||
constant const size_t& n_batch, \
|
||||
const device float* freqs [[buffer(10)]], \
|
||||
constant const size_t& freq_stride [[buffer(11)]], \
|
||||
uint3 pos [[thread_position_in_grid]], \
|
||||
uint3 grid [[threads_per_grid]]);
|
||||
instantiate_kernel("rope_" #name, rope, type, traditional, forward) \
|
||||
instantiate_kernel("rope_freqs_" #name, rope_freqs, type, traditional, forward)
|
||||
|
||||
#define instantiate_rope_s(name, type, traditional, forward) \
|
||||
template [[host_name("rope_single_" #name)]] [[kernel]] void \
|
||||
rope_single<type, traditional, forward>( \
|
||||
const device type* in [[buffer(0)]], \
|
||||
device type* out [[buffer(1)]], \
|
||||
constant const int& offset, \
|
||||
constant const float& scale, \
|
||||
constant const size_t& stride, \
|
||||
constant const float& base [[buffer(10)]], \
|
||||
uint2 pos [[thread_position_in_grid]], \
|
||||
uint2 grid [[threads_per_grid]]); \
|
||||
template [[host_name("rope_single_freqs_" #name)]] \
|
||||
[[kernel]] void rope_single_freqs<type, traditional, forward>( \
|
||||
const device type* in [[buffer(0)]], \
|
||||
device type* out [[buffer(1)]], \
|
||||
constant const int& offset, \
|
||||
constant const float& scale, \
|
||||
constant const size_t& stride, \
|
||||
const device float* freqs [[buffer(10)]], \
|
||||
constant const size_t& freq_stride [[buffer(11)]], \
|
||||
uint2 pos [[thread_position_in_grid]], \
|
||||
uint2 grid [[threads_per_grid]]);
|
||||
#define instantiate_rope_s(name, type, traditional, forward) \
|
||||
instantiate_kernel("rope_single_" #name, rope_single, type, traditional, forward) \
|
||||
instantiate_kernel("rope_single_freqs_" #name, rope_single_freqs, type, traditional, forward)
|
||||
|
||||
#define instantiate_rope(name, type, traditional, forward) \
|
||||
instantiate_rope_s(name, type, traditional, forward) \
|
||||
|
||||
@@ -18,23 +18,29 @@ void RoPE::eval_gpu(
|
||||
auto& in = inputs[0];
|
||||
auto& out = outputs[0];
|
||||
|
||||
if (in.ndim() < 3) {
|
||||
throw std::runtime_error("[RoPE] Input must have at least 3 dimensions");
|
||||
}
|
||||
|
||||
auto& s = out.primitive().stream();
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
size_t strides[3];
|
||||
size_t out_strides[3];
|
||||
int64_t strides[3];
|
||||
int64_t out_strides[3];
|
||||
bool donated = false;
|
||||
int ndim = in.ndim();
|
||||
int dispatch_ndim = in.ndim();
|
||||
int B = in.shape(0);
|
||||
int T = in.shape(-2);
|
||||
int D = in.shape(-1);
|
||||
size_t mat_size = T * D;
|
||||
|
||||
int dispatch_ndim = ndim;
|
||||
while (in.shape(-dispatch_ndim) == 1 && dispatch_ndim > 3) {
|
||||
dispatch_ndim--;
|
||||
}
|
||||
size_t mat_size = in.shape(-2) * in.shape(-1);
|
||||
if (dims_ < in.shape(-1)) {
|
||||
|
||||
int N = 1;
|
||||
for (int i = 1; i < (ndim - 2); ++i) {
|
||||
N *= in.shape(i);
|
||||
}
|
||||
|
||||
if (dims_ < D) {
|
||||
donated = true;
|
||||
auto ctype =
|
||||
(in.flags().row_contiguous) ? CopyType::Vector : CopyType::General;
|
||||
@@ -71,8 +77,8 @@ void RoPE::eval_gpu(
|
||||
out_strides[1] = out.strides()[ndim - 2];
|
||||
out_strides[2] = out.strides()[ndim - 1];
|
||||
|
||||
// Special case for inference (single time step and contiguous)
|
||||
bool single = in.flags().row_contiguous && (mat_size == in.shape(-1));
|
||||
// Special case for inference (single batch, single time step, and contiguous)
|
||||
bool single = in.flags().row_contiguous && B == 1 && T == 1;
|
||||
|
||||
bool with_freqs = inputs.size() == 3;
|
||||
std::ostringstream kname;
|
||||
@@ -86,24 +92,29 @@ void RoPE::eval_gpu(
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
compute_encoder.set_input_array(donated ? out : in, 0);
|
||||
compute_encoder.set_output_array(out, 1);
|
||||
|
||||
compute_encoder.set_input_array(inputs[1], 2);
|
||||
compute_encoder.set_bytes(scale_, 3);
|
||||
|
||||
size_t n_batch = in.size() / mat_size;
|
||||
MTL::Size group_dims;
|
||||
MTL::Size grid_dims;
|
||||
if (single) {
|
||||
compute_encoder.set_bytes(out_strides, 1, 4);
|
||||
uint32_t dim0 = dims_ / 2;
|
||||
group_dims = get_block_dims(dim0, n_batch, 1);
|
||||
grid_dims = MTL::Size(dim0, n_batch, 1);
|
||||
group_dims = get_block_dims(dim0, N, 1);
|
||||
grid_dims = MTL::Size(dim0, N, 1);
|
||||
} else {
|
||||
compute_encoder.set_bytes(strides, 3, 4);
|
||||
compute_encoder.set_bytes(out_strides, 3, 5);
|
||||
compute_encoder.set_bytes(n_batch, 6);
|
||||
int64_t offset_stride = 0;
|
||||
if (inputs[1].ndim() > 0) {
|
||||
offset_stride = inputs[1].strides()[0];
|
||||
}
|
||||
compute_encoder.set_bytes(offset_stride, 6);
|
||||
compute_encoder.set_bytes(N, 7);
|
||||
uint32_t dim0 = dims_ / 2;
|
||||
uint32_t dim1 = in.shape(-2);
|
||||
uint32_t dim2 = (n_batch + n_per_thread - 1) / n_per_thread;
|
||||
uint32_t dim1 = T;
|
||||
uint32_t dim2 = B * ((N + n_per_thread - 1) / n_per_thread);
|
||||
group_dims = get_block_dims(dim0, dim1, dim2);
|
||||
grid_dims = MTL::Size(dim0, dim1, dim2);
|
||||
}
|
||||
|
||||
60
mlx/fast.cpp
60
mlx/fast.cpp
@@ -366,10 +366,16 @@ array rope(
|
||||
msg << "[rope] Input must be a floating type but got " << x.dtype() << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
if (offset.size() != 1) {
|
||||
if (offset.ndim() > 1) {
|
||||
std::ostringstream msg;
|
||||
msg << "[rope] offset must be a scalar but has shape " << offset.shape()
|
||||
<< ".";
|
||||
msg << "[rope] offset must have at most one dimension but has shape "
|
||||
<< offset.shape() << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
if (offset.size() != 1 && offset.size() != x.shape(0)) {
|
||||
std::ostringstream msg;
|
||||
msg << "[rope] offset must be a scalar or vector with " << x.shape(0)
|
||||
<< " elements but has shape " << offset.shape() << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
if (!issubdtype(offset.dtype(), integer)) {
|
||||
@@ -379,7 +385,7 @@ array rope(
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
if (offset.dtype().size() != 4) {
|
||||
inputs[1] = astype(offset, uint32, s);
|
||||
inputs[1] = astype(offset, int32, s);
|
||||
}
|
||||
if (inputs.size() == 3 &&
|
||||
(inputs[2].ndim() != 1 || inputs[2].shape(0) != dims / 2)) {
|
||||
@@ -391,15 +397,26 @@ array rope(
|
||||
|
||||
auto fallback = [dims, traditional, base, scale, forward, s](
|
||||
std::vector<array> inputs) {
|
||||
auto& shape = inputs[0].shape();
|
||||
int ndim = shape.size();
|
||||
auto x = flatten(inputs[0], 0, ndim - 3, s);
|
||||
auto x = inputs[0];
|
||||
auto shape = x.shape();
|
||||
if (x.ndim() == 3) {
|
||||
x = expand_dims(x, 1, s);
|
||||
} else if (x.ndim() > 4) {
|
||||
x = flatten(x, 1, 1 + (x.ndim() - 4), s);
|
||||
}
|
||||
|
||||
auto B = x.shape(0);
|
||||
auto N = x.shape(1);
|
||||
auto T = x.shape(2);
|
||||
auto t = x.dtype();
|
||||
// Compute sines and cosines
|
||||
auto half_dims = dims / 2;
|
||||
auto& offset = inputs[1];
|
||||
auto offset = inputs[1];
|
||||
if (offset.size() > 1) {
|
||||
offset = expand_dims(offset, {-1, -2}, s);
|
||||
}
|
||||
auto positions =
|
||||
multiply(add(arange(x.shape(1), t, s), offset, s), array(scale, t), s);
|
||||
multiply(add(arange(x.shape(2), t, s), offset, s), array(scale, t), s);
|
||||
|
||||
auto default_inv_freqs = [&inputs, &s, &t, base, half_dims]() {
|
||||
return exp(
|
||||
@@ -412,8 +429,7 @@ array rope(
|
||||
|
||||
auto inv_freqs = inputs.size() == 3 ? astype(reciprocal(inputs[2], s), t, s)
|
||||
: default_inv_freqs();
|
||||
auto theta =
|
||||
multiply(expand_dims(positions, 1, s), expand_dims(inv_freqs, 0, s), s);
|
||||
auto theta = multiply(expand_dims(positions, -1, s), inv_freqs, s);
|
||||
auto coss = cos(theta, s);
|
||||
auto sins = sin(theta, s);
|
||||
|
||||
@@ -436,32 +452,30 @@ array rope(
|
||||
};
|
||||
|
||||
if (traditional) {
|
||||
auto x1 =
|
||||
slice(x, {0, 0, 0}, {x.shape(0), x.shape(1), dims}, {1, 1, 2}, s);
|
||||
auto x2 =
|
||||
slice(x, {0, 0, 1}, {x.shape(0), x.shape(1), dims}, {1, 1, 2}, s);
|
||||
auto x1 = slice(x, {0, 0, 0, 0}, {B, N, T, dims}, {1, 1, 1, 2}, s);
|
||||
auto x2 = slice(x, {0, 0, 0, 1}, {B, N, T, dims}, {1, 1, 1, 2}, s);
|
||||
auto outs = apply_rope(x1, x2, coss, sins);
|
||||
for (auto& o : outs) {
|
||||
o = expand_dims(o, 3, s);
|
||||
o = expand_dims(o, -1, s);
|
||||
}
|
||||
auto out = concatenate(outs, 3, s);
|
||||
auto out = reshape(concatenate(outs, -1, s), {B, N, T, dims}, s);
|
||||
if (dims < x.shape(-1)) {
|
||||
out = reshape(out, {x.shape(0), x.shape(1), dims});
|
||||
out = concatenate({out, slice(x, {0, 0, dims}, x.shape(), s)}, 2, s);
|
||||
out =
|
||||
concatenate({out, slice(x, {0, 0, 0, dims}, x.shape(), s)}, -1, s);
|
||||
}
|
||||
return std::vector<array>{reshape(out, shape, s)};
|
||||
} else {
|
||||
auto out_s = x.shape();
|
||||
out_s.back() = half_dims;
|
||||
auto x1 = slice(x, {0, 0, 0}, out_s, s);
|
||||
auto x1 = slice(x, {0, 0, 0, 0}, out_s, s);
|
||||
out_s.back() = dims;
|
||||
auto x2 = slice(x, {0, 0, half_dims}, out_s, s);
|
||||
auto x2 = slice(x, {0, 0, 0, half_dims}, out_s, s);
|
||||
|
||||
auto outs = apply_rope(x1, x2, coss, sins);
|
||||
if (dims < x.shape(-1)) {
|
||||
outs.push_back(slice(x, {0, 0, dims}, x.shape(), s));
|
||||
outs.push_back(slice(x, {0, 0, 0, dims}, x.shape(), s));
|
||||
}
|
||||
return std::vector<array>{reshape(concatenate(outs, 2, s), shape, s)};
|
||||
return std::vector<array>{reshape(concatenate(outs, -1, s), shape, s)};
|
||||
}
|
||||
};
|
||||
auto stream = to_stream(s);
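The fallback lambda above computes positions = (arange(T) + offset) * scale, theta = positions x inv_freqs, and then rotates the first `dims` features while leaving the rest unchanged. A hedged NumPy sketch of that math for the non-traditional case (names and defaults are illustrative):

```python
import numpy as np

def rope_ref(x, dims, base=10000.0, scale=1.0, offset=0):
    """x: (B, N, T, D); rotate the first `dims` features, leave the rest unchanged."""
    B, N, T, D = x.shape
    half = dims // 2
    inv_freqs = base ** (-np.arange(half) / half)      # (half,)
    positions = (np.arange(T) + offset) * scale        # (T,)
    theta = positions[:, None] * inv_freqs[None, :]    # (T, half)
    cos, sin = np.cos(theta), np.sin(theta)
    x1, x2 = x[..., :half], x[..., half:dims]
    rx1 = x1 * cos - x2 * sin                          # broadcast over (B, N, T, half)
    rx2 = x1 * sin + x2 * cos
    return np.concatenate([rx1, rx2, x[..., dims:]], axis=-1)
```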
|
||||
|
||||
@@ -1,10 +1,8 @@
// Copyright © 2025 Apple Inc.

#include <string>

namespace mlx::core {

std::string version() {
const char* version() {
return MLX_VERSION;
}

@@ -15,6 +15,6 @@ namespace mlx::core {
*
* For dev builds, the version will include the suffix ".devYYYYMMDD+hash"
*/
std::string version();
const char* version();

} // namespace mlx::core
|
||||
|
||||
@@ -52,7 +52,6 @@ set_target_properties(
${MLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY})

target_link_libraries(core PRIVATE mlx)
target_compile_definitions(core PRIVATE _VERSION_=${MLX_VERSION})

if(BUILD_SHARED_LIBS)
if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
|
||||
|
||||
@@ -164,8 +164,13 @@ void init_fast(nb::module_& parent_module) {
R"pbdoc(
Apply rotary positional encoding to the input.

The input is expected to be at least 3D with shape ``(B, *, T, D)`` where:
* ``B`` is the batch size.
* ``T`` is the sequence length.
* ``D`` is the feature dimension.

Args:
a (array): Input array.
a (array): The input array.
dims (int): The feature dimensions to be rotated. If the input feature
is larger than dims then the rest is left unchanged.
traditional (bool): If set to ``True`` choose the traditional
@@ -174,7 +179,9 @@ void init_fast(nb::module_& parent_module) {
each dimension in the positional encodings. Exactly one of ``base`` and
``freqs`` must be ``None``.
scale (float): The scale used to scale the positions.
offset (int or array): The position offset to start at.
offset (int or array): The position offset to start at. If an
:obj:`array` is given it can be a scalar or vector of ``B``
offsets for each example in the batch.
freqs (array, optional): Optional frequencies to use with RoPE.
If set, the ``base`` parameter must be ``None``. Default: ``None``.
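A hedged usage sketch of the documented behavior, with one offset per batch example (shapes and values are illustrative):

```python
import mlx.core as mx

B, H, T, D = 2, 8, 16, 64
x = mx.random.normal((B, H, T, D))

# One position offset per example in the batch.
offsets = mx.array([0, 7], dtype=mx.int32)
y = mx.fast.rope(x, D, traditional=False, base=10000.0, scale=1.0, offset=offsets)
print(y.shape)  # (2, 8, 16, 64)
```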
|
||||
|
||||
|
||||
@@ -2,9 +2,9 @@

#include <nanobind/nanobind.h>

#define STRINGIFY(x) #x
#define TOSTRING(x) STRINGIFY(x)
#include "mlx/version.h"

namespace mx = mlx::core;
namespace nb = nanobind;

void init_mlx_func(nb::module_&);
@@ -48,5 +48,5 @@ NB_MODULE(core, m) {
init_distributed(m);
init_export(m);

m.attr("__version__") = TOSTRING(_VERSION_);
m.attr("__version__") = mx::version();
}
|
||||
|
||||
@@ -91,7 +91,7 @@ mx::array to_array_with_accessor(nb::object obj) {
return nb::cast<mx::array>(obj.attr("__mlx_array__")());
} else {
std::ostringstream msg;
msg << "Invalid type " << nb::type_name(obj.type()).c_str()
msg << "Invalid type " << nb::type_name(obj.type()).c_str()
<< " received in array initialization.";
throw std::invalid_argument(msg.str());
}
|
||||
|
||||
@@ -594,124 +594,123 @@ class TestBlas(mlx_tests.MLXTestCase):
|
||||
np.random.seed(0)
|
||||
# Batched matmul
|
||||
alpha = 0.5
|
||||
beta = 2.0
|
||||
for beta in (1.0, 2.0):
|
||||
# c must broadcast to the output shape
|
||||
with self.assertRaises(ValueError):
|
||||
mx.addmm(mx.zeros((2, 2, 2)), mx.zeros((2, 2)), mx.zeros((2, 2)))
|
||||
|
||||
# c must broadcast to the output shape
|
||||
with self.assertRaises(ValueError):
|
||||
mx.addmm(mx.zeros((2, 2, 2)), mx.zeros((2, 2)), mx.zeros((2, 2)))
|
||||
# Regular batched case
|
||||
a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
|
||||
b_npy = np.random.normal(0.0, 1.0 / 128, (32, 16, 16)).astype(np.float32)
|
||||
|
||||
# Regular batched case
|
||||
a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
|
||||
b_npy = np.random.normal(0.0, 1.0 / 128, (32, 16, 16)).astype(np.float32)
|
||||
a_mlx = mx.array(a_npy)
|
||||
b_mlx = mx.array(b_npy)
|
||||
|
||||
a_mlx = mx.array(a_npy)
|
||||
b_mlx = mx.array(b_npy)
|
||||
for c_shape in ((1,), (1, 16), (32, 1, 16), (1, 128, 16)):
|
||||
c_npy = np.ones(c_shape).astype(np.float32)
|
||||
c_mlx = mx.array(c_npy)
|
||||
|
||||
for c_shape in ((1,), (1, 16), (32, 1, 16), (1, 128, 16)):
|
||||
c_npy = np.ones(c_shape).astype(np.float32)
|
||||
c_mlx = mx.array(c_npy)
|
||||
d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
|
||||
d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)
|
||||
|
||||
d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
|
            d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)

            self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
            self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))

        # Batched and transposed matmul
        b_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
        b_mlx = mx.array(b_npy)

        for c_shape in ((1,), (32, 1, 128), (1, 128)):
            c_npy = np.ones(c_shape).astype(np.float32)
            c_mlx = mx.array(c_npy)

            b_np_t = np.transpose(b_npy, (0, 2, 1))
            b_mx_t = mx.transpose(b_mlx, (0, 2, 1))

            d_npy = alpha * (a_npy @ b_np_t) + beta * c_npy
            d_mlx = mx.addmm(c_mlx, a_mlx, b_mx_t, alpha, beta)

            self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
            self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))

        # Batched matmul with simple broadcast
        a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
        b_npy = np.random.normal(0.0, 1.0 / 128, (16, 16)).astype(np.float32)
        a_mlx = mx.array(a_npy)
        b_mlx = mx.array(b_npy)

        for c_shape in ((1,), (1, 16), (32, 1, 16), (1, 128, 16)):
            c_npy = np.ones(c_shape).astype(np.float32)
            c_mlx = mx.array(c_npy)

            d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
            d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)

            self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
            self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))

        # Matmul with vector
        a_npy = np.random.normal(0.0, 1.0 / 128, (16,)).astype(np.float32)
        b_npy = np.random.normal(0.0, 1.0 / 128, (32, 16, 128)).astype(np.float32)
        a_mlx = mx.array(a_npy)
        b_mlx = mx.array(b_npy)

        for c_shape in ((1,), (128,), (32, 128)):
            c_npy = np.ones(c_shape).astype(np.float32)
            c_mlx = mx.array(c_npy)

            d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
            d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)

            self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
            self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))

        # Matmul with vector
        a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
        b_npy = np.random.normal(0.0, 1.0 / 128, (16,)).astype(np.float32)
        a_mlx = mx.array(a_npy)
        b_mlx = mx.array(b_npy)

        for c_shape in ((1,), (32, 128)):
            c_npy = np.ones(c_shape).astype(np.float32)
            c_mlx = mx.array(c_npy)

            d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
            d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)

            self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
            self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))

        # Split K specialization
        a_npy = np.random.normal(0.0, 1.0 / 128, (64, 4096)).astype(np.float32)
        b_npy = np.random.normal(0.0, 1.0 / 128, (4096, 32)).astype(np.float32)
        a_mlx = mx.array(a_npy)
        b_mlx = mx.array(b_npy)

        for c_shape in ((1,), (1, 32), (64, 1), (64, 32)):
            c_npy = np.ones(c_shape).astype(np.float32)
            c_mlx = mx.array(c_npy)

            d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
            d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)

            self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
            self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))

        # Transposed c
        a = mx.ones((10, 5)).T
        b = mx.ones((5, 5))
        out = mx.addmm(a, b, a, beta=beta, alpha=alpha)
        expected = beta * a + alpha * (b @ a)
        self.assertTrue(mx.allclose(expected, out))

        # Transposed c
        a = mx.ones((10, 5)).T
        b = mx.ones((5, 5))
        out = mx.addmm(a, b, a, beta=1.5, alpha=0.5)
        expected = 1.5 * a + 0.5 * (b @ a)
        self.assertTrue(mx.allclose(expected, out))

        # Broadcast c
        a = mx.ones((5, 5))
        b = mx.ones((5, 5))
        c = mx.ones((1, 5))
        out = mx.addmm(c, a, b, beta=1.5, alpha=0.5)
        expected = 1.5 * c + 0.5 * (a @ b)
        self.assertTrue(mx.allclose(expected, out))

        # Broadcast c
        a = mx.ones((5, 5))
        b = mx.ones((5, 5))
        c = mx.ones((1, 5))
        out = mx.addmm(c, a, b, beta=beta, alpha=alpha)
        expected = beta * c + alpha * (a @ b)
        self.assertTrue(mx.allclose(expected, out))
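For readers skimming the hunk above, every one of these assertions checks the same identity: addmm fuses a matmul with a scaled add, out = beta * c + alpha * (a @ b), with c broadcast against the matmul result. A minimal standalone sketch (shapes chosen for illustration, not taken from the diff):

import mlx.core as mx

# addmm fuses the matmul and the scaled add: out = beta * c + alpha * (a @ b)
a = mx.random.normal((4, 8))
b = mx.random.normal((8, 3))
c = mx.ones((1, 3))          # broadcasts across the 4 rows of (a @ b)

out = mx.addmm(c, a, b, alpha=2.0, beta=0.5)
ref = 0.5 * c + 2.0 * (a @ b)
assert mx.allclose(out, ref, atol=1e-5).item()

The test walks the same identity through batched, transposed, vector, and split-K shapes, varying only how c broadcasts.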
    def test_addmm_grad(self):
        def make_ref_addmm(alpha, beta):
@@ -724,33 +723,32 @@ class TestBlas(mlx_tests.MLXTestCase):
        shapes = ((1, 64, 32, 128), (4, 28, 24, 47), (1, 1, 24, 47))

        alpha = 2.0
        beta = 0.5
        for beta in (1.0, 0.5):
            f_test = make_addmm(alpha, beta)
            f_ref = make_ref_addmm(alpha, beta)

            for B, M, N, K in shapes:
                cotan = mx.ones((B, M, N))
                c = mx.random.normal((B, M, N))
                a = mx.random.normal((B, M, K))
                b = mx.random.normal((B, K, N))

                out_ref, dout_ref = mx.vjp(
                    f_ref,
                    [c, a, b],
                    [cotan],
                )
                out_test, dout_test = mx.vjp(
                    f_test,
                    [c, a, b],
                    [cotan],
                )

                self.assertTrue(mx.allclose(out_ref[0], out_test[0], atol=1e-4).item())

                for r, t in zip(dout_ref, dout_test):
                    self.assertEqual(r.shape, t.shape)
                    self.assertTrue(mx.allclose(r, t, atol=1e-4).item())
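The gradient check above compares mx.vjp of the addmm-based function against a reference implementation. Written out by hand for one non-broadcast case, using the standard matrix-calculus identities (shapes chosen so no reduction over broadcast axes is needed), the cotangents it should reproduce are dc = beta * g, da = alpha * (g @ b.T), db = alpha * (a.T @ g). A small sketch, with illustrative shapes that are not from the diff:

import mlx.core as mx

alpha, beta = 2.0, 0.5
a = mx.random.normal((3, 4))
b = mx.random.normal((4, 5))
c = mx.random.normal((3, 5))
g = mx.ones((3, 5))          # incoming cotangent for the output

f = lambda c, a, b: mx.addmm(c, a, b, alpha=alpha, beta=beta)
_, (dc, da, db) = mx.vjp(f, [c, a, b], [g])

# Manual vjp of out = beta * c + alpha * (a @ b)
assert mx.allclose(dc, beta * g, atol=1e-5).item()
assert mx.allclose(da, alpha * (g @ b.T), atol=1e-5).item()
assert mx.allclose(db, alpha * (a.T @ g), atol=1e-5).item()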
    def test_empty_matmul(self):
        a = mx.array([[], []]).T
@@ -8,18 +8,23 @@ import mlx_tests


def rope_orig(x, dims, traditional, base, scale, offset, freqs=None):
    offset = offset.item() if isinstance(offset, mx.array) else offset
    N = x.shape[-2] + offset
    N = x.shape[-2]
    dtype = x.dtype
    half_D = dims // 2
    positions = mx.arange(offset, N, dtype=dtype) * scale
    positions = mx.arange(N, dtype=dtype)
    if isinstance(offset, mx.array) and offset.size > 1:
        expand = tuple(range(1, x.ndim - 1))
        positions = mx.expand_dims(offset, expand) + positions
    else:
        positions = offset + positions
    positions = positions * scale
    if freqs is None:
        inv_freqs = mx.exp(
            -mx.arange(0.0, half_D, dtype=dtype) * (math.log(base) / half_D)
        )
    else:
        inv_freqs = (1 / freqs).astype(x.dtype)
    theta = mx.reshape(positions, (-1, 1)) * mx.reshape(inv_freqs, (1, -1))
    theta = mx.expand_dims(positions, -1) * inv_freqs
    costheta, sintheta = mx.cos(theta), mx.sin(theta)
    if traditional:
        x1 = x[..., :dims:2]
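The change to rope_orig above is what lets offset be a per-sequence mx.array rather than a single scalar. A small shape sketch of how the offset broadcasts into the position grid (values are illustrative, not from the diff):

import mlx.core as mx

# With x of shape (batch, n_heads, T, dims) and offset of shape (batch,),
# each sequence gets its own row of T starting positions.
x = mx.zeros((3, 2, 4, 8))        # batch=3, heads=2, T=4, dims=8
offset = mx.array([0, 5, 10])     # per-sequence starting positions

positions = mx.arange(x.shape[-2])            # shape (4,)
expand = tuple(range(1, x.ndim - 1))          # expand offset over (1, 2)
positions = mx.expand_dims(offset, expand) + positions
print(positions.shape)            # (3, 1, 1, 4): rows start at 0, 5, and 10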
@@ -214,6 +219,7 @@ class TestFast(mlx_tests.MLXTestCase):
            )
            self.assertEqual(dtype, rx.dtype)
            self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
            return

        # Test single vector
        x = mx.random.uniform(shape=(1, 1, dims))
@@ -277,6 +283,55 @@ class TestFast(mlx_tests.MLXTestCase):
        g2 = mx.grad(f2)(x, y)
        self.assertLess(mx.abs(g1 - g2).max(), 1e-5)

    def test_rope_batch(self):
        T = 4
        base = 10000.0
        scale = 1.0
        traditional = True
        batch_sizes = [3, 8, 11]
        num_heads = [1, 3, 5]
        dims = 32

        x = mx.random.uniform(shape=(8, 4, T, dims))

        offset = mx.array([1, 2, 3])
        with self.assertRaises(ValueError):
            mx.fast.rope(
                x,
                dims,
                traditional=traditional,
                base=base,
                scale=scale,
                offset=offset,
            )

        for batch_size in batch_sizes:
            for n_head in num_heads:
                x = mx.random.uniform(shape=(batch_size, n_head, T, dims))
                offset = mx.arange(batch_size)
                rx = rope_orig(x, dims, traditional, base, scale, offset)
                rx_fast = mx.fast.rope(
                    x,
                    dims,
                    traditional=traditional,
                    base=base,
                    scale=scale,
                    offset=offset,
                )
                self.assertLess(mx.abs(rx - rx_fast).max(), 1e-5)

        x = mx.random.normal(shape=(2, 6, 8, 64)).transpose(0, 2, 1, 3)
        dims = 64
        offset = 0
        rx_fast = mx.fast.rope(
            x, dims, traditional=traditional, scale=scale, base=base, offset=offset
        )
        rx_fast_single = mx.fast.rope(
            x[0:1], dims, traditional=traditional, scale=scale, base=base, offset=offset
        )

        rx = rope_orig(x, dims, traditional, base, scale, offset)
        self.assertLess(mx.abs(rx - rx_fast).max(), 1e-5)

    def test_rms_norm(self):
        # Per dtype absolute tolerance
        tolerances = {mx.float32: 1e-6, mx.float16: 1e-3, mx.bfloat16: 1e-2}
setup.py
@@ -121,7 +121,8 @@ class CMakeBuild(build_ext):
        build_args += [f"-j{os.cpu_count()}"]

        # Avoid cache miss when building from temporary dirs.
        os.environ["CCACHE_BASEDIR"] = os.path.abspath(self.build_temp)
        os.environ["CCACHE_BASEDIR"] = os.path.realpath(self.build_temp)
        os.environ["CCACHE_NOHASHDIR"] = "true"

        subprocess.run(
            ["cmake", ext.sourcedir, *cmake_args], cwd=build_temp, check=True
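One plausible reading of the one-line switch above: os.path.abspath only normalizes the path string, while os.path.realpath also resolves symlinks, so CCACHE_BASEDIR matches the physical paths ccache sees during compilation and its relative-path hashing keeps working from temporary build dirs. A tiny illustration with a hypothetical macOS path (on macOS, /var is a symlink to /private/var):

import os

build_temp = "/var/folders/xy/T/build_temp"   # hypothetical temp build dir

# abspath keeps the symlinked prefix; realpath resolves it to the physical path.
print(os.path.abspath(build_temp))    # /var/folders/xy/T/build_temp
print(os.path.realpath(build_temp))   # /private/var/folders/xy/T/build_temp (on macOS)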
@@ -1,10 +1,7 @@
# Doctest works fine with cmake 3.5
set(CMAKE_POLICY_VERSION_MINIMUM 3.5)

FetchContent_Declare(
  doctest
  GIT_REPOSITORY "https://github.com/onqtam/doctest"
  GIT_TAG "ae7a13539fb71f270b87eb2e874fbac80bc8dda2")
  GIT_REPOSITORY https://github.com/onqtam/doctest.git
  GIT_TAG v2.4.12)
FetchContent_MakeAvailable(doctest)

add_executable(tests ${PROJECT_SOURCE_DIR}/tests/tests.cpp)