Mirror of https://github.com/ml-explore/mlx.git, synced 2025-12-16 01:49:05 +08:00

Compare commits: 6 commits, c1e3340b23 ... 44cc5da4bc

  44cc5da4bc
  dde3682b69
  17310d91a6
  b194d65a6a
  a44b27f5f8
  e5a33f2223
CMake build script. The hardcoded "-framework" link flags are replaced by find_library lookups, and a missing Metal framework is now a hard configure error instead of silently disabling the GPU backend:

@@ -87,22 +87,21 @@ cmake_policy(SET CMP0135 NEW)
 add_library(mlx)

-if(MLX_BUILD_METAL)
-  set(METAL_LIB "-framework Metal")
-  set(FOUNDATION_LIB "-framework Foundation")
-  set(QUARTZ_LIB "-framework QuartzCore")
-endif()
-
 if(MLX_BUILD_CUDA)
   enable_language(CUDA)
 endif()

-if(MLX_BUILD_METAL AND NOT METAL_LIB)
-  message(STATUS "Metal not found. Unable to build GPU")
-  set(MLX_BUILD_METAL OFF)
-  set(MLX_METAL_DEBUG OFF)
-elseif(MLX_BUILD_METAL)
-  message(STATUS "Building METAL sources")
+if(MLX_BUILD_METAL)
+  find_library(METAL_LIB Metal)
+  find_library(FOUNDATION_LIB Foundation)
+  find_library(QUARTZ_LIB QuartzCore)
+  if(METAL_LIB)
+    message(STATUS "Metal found ${METAL_LIB}")
+  else()
+    message(
+      FATAL_ERROR
+      "Metal not found. Set MLX_BUILD_METAL=OFF to build without GPU")
+  endif()
   if(MLX_METAL_DEBUG)
     add_compile_definitions(MLX_METAL_DEBUG)
@@ -111,7 +110,8 @@ elseif(MLX_BUILD_METAL)
   # Throw an error if xcrun not found
   execute_process(
     COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version"
-    OUTPUT_VARIABLE MACOS_SDK_VERSION COMMAND_ERROR_IS_FATAL ANY)
+    OUTPUT_VARIABLE MACOS_SDK_VERSION
+    OUTPUT_STRIP_TRAILING_WHITESPACE COMMAND_ERROR_IS_FATAL ANY)

   if(${MACOS_SDK_VERSION} LESS 14.0)
     message(
CUDA backend, cuBLASLt GEMM wrapper. Layouts are now created in cuBLASLt's native column-major order, with transposition handled by swapping rows and columns instead of the CUBLASLT_ORDER_ROW attribute:

@@ -85,10 +85,10 @@ cublasLtMatrixLayout_t create_matrix_layout(
     int32_t batch_count,
     int64_t batch_stride) {
   cublasLtMatrixLayout_t desc;
+  if (transposed) {
+    std::swap(rows, cols);
+  }
   CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate(&desc, type, rows, cols, ld));
-  cublasLtOrder_t order = transposed ? CUBLASLT_ORDER_COL : CUBLASLT_ORDER_ROW;
-  CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
-      desc, CUBLASLT_MATRIX_LAYOUT_ORDER, &order, sizeof(cublasLtOrder_t)));
   if (batch_count > 1) {
     CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
         desc,
@@ -138,25 +138,34 @@ CublasGemm::CublasGemm(
       CUBLASLT_MATMUL_DESC_POINTER_MODE,
       &pointer_mode,
       sizeof(int32_t)));
-  cublasOperation_t op = CUBLAS_OP_N;
+  // In cublasLt matrices use column-major layout, while it is possible to use
+  // the CUBLASLT_ORDER_ROW option to switch to row-major layout, the bias
+  // epilogue does not work with the option. So instead we swap A and B to make
+  // cublasLt return the row-major result, which works because:
+  // - the data of a matrix in row-major layout is identical to its transpose in
+  //   column-major layout
+  // - C^T = (A @ B)^T = B^T @ A^T
+  cublasOperation_t a_op = b_transposed ? CUBLAS_OP_T : CUBLAS_OP_N;
   CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
       matmul_desc_,
       CUBLASLT_MATMUL_DESC_TRANSA,
-      &op,
+      &a_op,
       sizeof(cublasOperation_t)));
+  cublasOperation_t b_op = a_transposed ? CUBLAS_OP_T : CUBLAS_OP_N;
   CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
       matmul_desc_,
       CUBLASLT_MATMUL_DESC_TRANSB,
-      &op,
+      &b_op,
       sizeof(cublasOperation_t)));

   auto type = dtype_to_cublas_type(dtype);
   a_desc_ = create_matrix_layout(
-      type, a_rows, a_cols, a_transposed, lda, batch_count, a_batch_stride);
+      type, b_cols, b_rows, b_transposed, ldb, batch_count, b_batch_stride);
   b_desc_ = create_matrix_layout(
-      type, b_rows, b_cols, b_transposed, ldb, batch_count, b_batch_stride);
+      type, a_cols, a_rows, a_transposed, lda, batch_count, a_batch_stride);
   out_desc_ = create_matrix_layout(
-      type, a_rows, b_cols, false, b_cols, batch_count, a_rows * b_cols);
+      type, b_cols, a_rows, false, b_cols, batch_count, a_rows * b_cols);
 }

 CublasGemm::CublasGemm(
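The comment in the hunk above is terse, so here is the identity spelled out: a row-major M x N buffer reinterpreted as column-major is the N x M transpose of the same data, so asking a column-major GEMM for B^T @ A^T = (A @ B)^T and reading the result back as row-major yields A @ B directly. A minimal host-side check of this (plain C++ with a hypothetical gemm_colmajor reference, no cuBLAS involved):

#include <cassert>
#include <vector>

// Column-major GEMM reference: C (m x n) = A (m x k) * B (k x n),
// with element (i, j) of an r-row matrix stored at j * r + i.
void gemm_colmajor(
    int m, int n, int k, const float* A, const float* B, float* C) {
  for (int j = 0; j < n; ++j)
    for (int i = 0; i < m; ++i) {
      float acc = 0.f;
      for (int p = 0; p < k; ++p)
        acc += A[p * m + i] * B[j * k + p];
      C[j * m + i] = acc;
    }
}

int main() {
  // Row-major inputs: A is M x K, B is K x N.
  const int M = 2, K = 3, N = 4;
  std::vector<float> A = {1, 2, 3, 4, 5, 6};  // 2 x 3 row-major
  std::vector<float> B(K * N);                // 3 x 4 row-major
  for (int i = 0; i < K * N; ++i) B[i] = float(i);
  // A's buffer viewed column-major is A^T (K x M); likewise B's is B^T.
  // Computing D = B^T * A^T column-major gives an N x M column-major
  // result whose buffer is exactly C = A * B (M x N) in row-major.
  std::vector<float> C(M * N);
  gemm_colmajor(N, M, K, B.data(), A.data(), C.data());
  // Check against a direct row-major product.
  for (int i = 0; i < M; ++i)
    for (int j = 0; j < N; ++j) {
      float acc = 0.f;
      for (int p = 0; p < K; ++p) acc += A[i * K + p] * B[p * N + j];
      assert(C[i * N + j] == acc);
    }
}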
@@ -191,7 +200,7 @@ CublasGemm::CublasGemm(
     b_batch_stride) {
   auto type = dtype_to_cublas_type(dtype);
   c_desc_ = create_matrix_layout(
-      type, a_rows, b_cols, false, ldc, batch_count, c_batch_stride);
+      type, b_cols, a_rows, false, ldc, batch_count, c_batch_stride);
 }

 CublasGemm::~CublasGemm() {
@@ -213,14 +222,25 @@ void CublasGemm::set_out(
   CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(out_desc_));
   out_desc_ = create_matrix_layout(
       dtype_to_cublas_type(dtype),
-      rows,
       cols,
+      rows,
       transposed,
       ld,
       batch_count,
       batch_stride);
 }

+void CublasGemm::set_bias(void* bias) {
+  cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_BIAS;
+  CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
+      matmul_desc_,
+      CUBLASLT_MATMUL_DESC_EPILOGUE,
+      &epilogue,
+      sizeof(epilogue)));
+  CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
+      matmul_desc_, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias)));
+}

 void CublasGemm::run(
     cu::CommandEncoder& encoder,
     array& out,
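CUBLASLT_EPILOGUE_BIAS fuses a bias add into the matmul itself, so AddMM with a contiguous length-N bias vector no longer needs a separate addition pass. As a reference for what the fused path is expected to compute (a plain C++ sketch of the semantics, not the cuBLAS call; with the A/B swap above, the bias ends up applied once per output column of the row-major result and broadcast across rows, matching the c.data_size() == out.shape(-1) dispatch condition later in this changeset):

#include <vector>

// Reference for GEMM with a fused bias epilogue: out = A @ B + bias,
// where bias has one entry per output column (length N) and is
// broadcast across the M rows. Row-major buffers throughout.
void gemm_bias_reference(
    int M, int N, int K,
    const float* a,     // M x K
    const float* b,     // K x N
    const float* bias,  // N
    float* out) {       // M x N
  for (int i = 0; i < M; ++i) {
    for (int j = 0; j < N; ++j) {
      float acc = bias[j];
      for (int p = 0; p < K; ++p) {
        acc += a[i * K + p] * b[p * N + j];
      }
      out[i * N + j] = acc;
    }
  }
}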
The run paths gain a scale parameter alpha:

@@ -228,11 +248,19 @@ void CublasGemm::run(
     const array& b,
     const Shape& batch_shape,
     const Strides& a_batch_strides,
-    const Strides& b_batch_strides) {
+    const Strides& b_batch_strides,
+    float alpha) {
   int batch_count = out.size() / (M_ * N_);
   if (batch_count / batch_shape.back() > 1) {
     run_batched(
-        encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides);
+        encoder,
+        out,
+        a,
+        b,
+        batch_shape,
+        a_batch_strides,
+        b_batch_strides,
+        alpha);
     return;
   }

@@ -240,7 +268,13 @@ void CublasGemm::run(
   encoder.set_input_array(b);
   encoder.set_output_array(out);

-  execute(encoder, out.data<void>(), a.data<void>(), b.data<void>(), nullptr);
+  execute(
+      encoder,
+      out.data<void>(),
+      a.data<void>(),
+      b.data<void>(),
+      nullptr,
+      alpha);
 }

 void CublasGemm::run(
@@ -330,9 +364,9 @@ void CublasGemm::execute(
       handle_,
       matmul_desc_,
       &alpha,
-      a,
+      b, // a and b are swapped
       a_desc_,
-      b,
+      a,
       b_desc_,
       &beta,
       c ? c : out,
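For context, cublasLtMatmul computes D = alpha * op(A) @ op(B) + beta * C; threading alpha through run() and execute() is what lets AddMM reuse this path with its own scale. A row-major reference of that contract (plain C++, independent of the library):

#include <vector>

// Reference for the GEMM contract D = alpha * (A @ B) + beta * C.
// All buffers row-major; when beta == 0, c may be null (as in the
// `c ? c : out` call above).
void gemm_alpha_beta(
    int M, int N, int K, float alpha, float beta,
    const float* a, const float* b, const float* c, float* d) {
  for (int i = 0; i < M; ++i) {
    for (int j = 0; j < N; ++j) {
      float acc = 0.f;
      for (int p = 0; p < K; ++p) {
        acc += a[i * K + p] * b[p * N + j];
      }
      d[i * N + j] = alpha * acc + (beta != 0.f ? beta * c[i * N + j] : 0.f);
    }
  }
}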
The corresponding header declarations:

@@ -55,6 +55,8 @@ class CublasGemm {
       int32_t batch_count,
       int64_t batch_stride);

+  void set_bias(void* bias);
+
   void run(
       cu::CommandEncoder& encoder,
       array& out,
@@ -62,7 +64,8 @@ class CublasGemm {
       const array& b,
       const Shape& batch_shape,
       const Strides& a_batch_strides,
-      const Strides& b_batch_strides);
+      const Strides& b_batch_strides,
+      float alpha = 1.0f);

   void run(
       cu::CommandEncoder& encoder,
@@ -85,7 +88,8 @@ class CublasGemm {
       const array& b,
       const Shape& batch_shape,
       const Strides& a_batch_strides,
-      const Strides& b_batch_strides);
+      const Strides& b_batch_strides,
+      float alpha);

   void run_batched(
       cu::CommandEncoder& encoder,
The batched implementations thread alpha through as well:

@@ -13,7 +13,8 @@ void CublasGemm::run_batched(
     const array& b,
     const Shape& batch_shape,
     const Strides& a_batch_strides,
-    const Strides& b_batch_strides) {
+    const Strides& b_batch_strides,
+    float alpha) {
   encoder.set_input_array(a);
   encoder.set_input_array(b);
   encoder.set_output_array(out);
@@ -27,7 +28,8 @@ void CublasGemm::run_batched(
         out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M_ * N_,
         a.data<int8_t>() + a.itemsize() * a_it.loc,
         b.data<int8_t>() + b.itemsize() * b_it.loc,
-        nullptr);
+        nullptr,
+        alpha);
     a_it.step();
     b_it.step();
   }
@@ -154,7 +154,8 @@ void CublasGemm::run_batched(
     const array& b,
     const Shape& batch_shape,
     const Strides& a_batch_strides,
-    const Strides& b_batch_strides) {
+    const Strides& b_batch_strides,
+    float alpha) {
   int batch_count = out.size() / (M_ * N_);
   set_pointer_mode(a_desc_, batch_count);
   set_pointer_mode(b_desc_, batch_count);
@@ -226,7 +227,8 @@ void CublasGemm::run_batched(
       reinterpret_cast<void*>(out_pointers),
       reinterpret_cast<void*>(a_pointers),
       reinterpret_cast<void*>(b_pointers),
-      nullptr);
+      nullptr,
+      alpha);
 }

 void CublasGemm::run_batched(
CUDA matmul primitives. A shared gemm_and_bias helper replaces duplicated setup in Matmul and AddMM:

@@ -11,6 +11,7 @@
 #include <numeric>

 namespace mlx::core {

 namespace {

 std::tuple<bool, int64_t, array>
@@ -28,6 +29,76 @@ check_transpose(cu::CommandEncoder& enc, const Stream& s, const array& arr) {
   }
 }

+void gemm_and_bias(
+    cu::CommandEncoder& encoder,
+    int M,
+    int N,
+    int K,
+    bool a_transposed,
+    int64_t lda,
+    bool b_transposed,
+    int64_t ldb,
+    array& out,
+    const array& a,
+    const array& b,
+    void* bias = nullptr,
+    float alpha = 1.0f) {
+  // Check and collapse batch dimensions
+  auto [batch_shape, a_batch_strides, b_batch_strides] = collapse_batches(a, b);
+
+  auto batch_count = out.size() / (M * N);
+
+  // Collapse batches into M if needed
+  if (batch_count > 1 && !a_transposed && batch_shape.size() == 1 &&
+      a.strides()[a.ndim() - 2] == K && a_batch_strides.back() == M * K &&
+      b_batch_strides.back() == 0) {
+    M *= batch_shape.back();
+    batch_count = 1;
+
+    a_batch_strides = {0};
+    b_batch_strides = {0};
+    batch_shape = {1};
+  }
+
+  // Use gemv when possible
+  if (!bias && cu::can_use_gemv(M, N, K, a_transposed, b_transposed)) {
+    cu::gemv(
+        a,
+        b,
+        out,
+        M,
+        N,
+        K,
+        batch_count,
+        batch_shape,
+        a_batch_strides,
+        b_batch_strides,
+        encoder);
+    return;
+  }
+
+  // Invoke cublasLt
+  CublasGemm gemm(
+      encoder.device(),
+      a.dtype(),
+      a_transposed,
+      M,
+      K,
+      lda,
+      b_transposed,
+      K,
+      N,
+      ldb,
+      batch_shape.back(),
+      a_batch_strides.back(),
+      b_batch_strides.back());
+  if (bias) {
+    gemm.set_bias(bias);
+  }
+  gemm.run(
+      encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides, alpha);
+}
+
 } // namespace

 void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
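The "collapse batches into M" branch turns a batched (batch, M, K) @ (K, N) product with contiguous, untransposed A and broadcast B into a single (batch * M, K) @ (K, N) GEMM. A small index-level illustration (plain C++, hypothetical sizes) of why the flattening is valid under exactly the stride conditions the branch checks:

#include <cassert>

int main() {
  // Batched product: A has shape (batch, M, K), row-contiguous; B has
  // shape (K, N) shared across the batch (batch stride 0).
  int batch = 8, M = 16, K = 32;
  // Conditions checked by the fast path:
  long a_row_stride = K;              // a.strides()[ndim - 2] == K
  long a_batch_stride = long(M) * K;  // batches tightly packed
  long b_batch_stride = 0;            // same B for every batch element

  // Element A[b][i][p] lives at b * (M*K) + i * K + p, which is the same
  // address as element A'[b*M + i][p] of a flat (batch*M, K) matrix.
  for (int b = 0; b < batch; ++b)
    for (int i = 0; i < M; ++i)
      for (int p = 0; p < K; ++p) {
        long batched = b * a_batch_stride + i * a_row_stride + p;
        long flat = (long(b) * M + i) * K + p;
        assert(batched == flat);
      }
  (void)b_batch_stride;
}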
@@ -48,9 +119,6 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {

   out.set_data(allocator::malloc(out.nbytes()));

-  /////////////////////////////////////////////////////////////////////////////
-  // Init checks and prep
-
   int M = a_pre.shape(-2);
   int N = b_pre.shape(-1);
   int K = a_pre.shape(-1);
@@ -60,58 +128,8 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
   auto [a_transposed, lda, a] = check_transpose(encoder, s, a_pre);
   auto [b_transposed, ldb, b] = check_transpose(encoder, s, b_pre);

-  /////////////////////////////////////////////////////////////////////////////
-  // Check and collapse batch dimensions
-
-  auto [batch_shape, a_batch_strides, b_batch_strides] = collapse_batches(a, b);
-
-  auto batch_count = out.size() / (M * N);
-
-  // Collapse batches into M if needed
-  if (batch_count > 1 && !a_transposed && batch_shape.size() == 1 &&
-      a.strides()[a.ndim() - 2] == K && a_batch_strides.back() == M * K &&
-      b_batch_strides.back() == 0) {
-    M *= batch_shape.back();
-    batch_count = 1;
-
-    a_batch_strides = {0};
-    b_batch_strides = {0};
-    batch_shape = {1};
-  }
-
-  if (cu::can_use_gemv(M, N, K, a_transposed, b_transposed)) {
-    cu::gemv(
-        a,
-        b,
-        out,
-        M,
-        N,
-        K,
-        batch_count,
-        batch_shape,
-        a_batch_strides,
-        b_batch_strides,
-        encoder);
-    return;
-  }
-
-  /////////////////////////////////////////////////////////////////////////////
-  // Invoke cublasLt
-  CublasGemm gemm(
-      cu::device(s.device),
-      a.dtype(),
-      a_transposed,
-      M,
-      K,
-      lda,
-      b_transposed,
-      K,
-      N,
-      ldb,
-      batch_shape.back(),
-      a_batch_strides.back(),
-      b_batch_strides.back());
-  gemm.run(encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides);
+  gemm_and_bias(
+      encoder, M, N, K, a_transposed, lda, b_transposed, ldb, out, a, b);
 }

 void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -136,6 +154,28 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   auto [a_transposed, lda, a] = check_transpose(encoder, s, a_pre);
   auto [b_transposed, ldb, b] = check_transpose(encoder, s, b_pre);

+  /////////////////////////////////////////////////////////////////////////////
+  // Dispatch to GEMM with epilogue or AddMM
+
+  if (beta_ == 1 && c.strides(-1) == 1 && c.data_size() == out.shape(-1)) {
+    out.set_data(allocator::malloc(out.nbytes()));
+    gemm_and_bias(
+        encoder,
+        M,
+        N,
+        K,
+        a_transposed,
+        lda,
+        b_transposed,
+        ldb,
+        out,
+        a,
+        b,
+        c.data<void>(),
+        alpha_);
+    return;
+  }
+
   int64_t ldc;
   {
     auto stx = c.strides()[c.ndim() - 2];
@@ -177,7 +217,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   }

   /////////////////////////////////////////////////////////////////////////////
-  // Invoke cublasLt
+  // Invoke cublasLt with AddMM settings

   CublasGemm gemm(
       cu::device(s.device),
CUDA RoPE kernel. The position offset becomes a per-batch device array, so each thread resolves its batch element's offset via a head/batch index decomposition:

@@ -103,15 +103,21 @@ template <typename T, bool traditional, bool forward, int N = 4>
 __device__ void rope_impl(
     const T* in,
     T* out,
-    int offset,
+    const int* offset,
     float inv_freq,
     float scale,
     const cuda::std::array<int64_t, 3> strides,
     const cuda::std::array<int64_t, 3> out_strides,
-    int64_t n_batch,
+    int64_t offset_stride,
+    int n_head,
     uint3 pos,
     uint3 dims) {
-  float L = scale * static_cast<float>(pos.y + offset);
+  auto n_head_up = N * ((n_head + N - 1) / N);
+  auto head_idx = static_cast<int>((pos.z * N) % n_head_up);
+  auto batch_idx = (pos.z * N) / n_head_up;
+  auto batch_offset = offset[batch_idx * offset_stride];
+  float L = scale * static_cast<float>(pos.y + batch_offset);
+  auto mat_idx = batch_idx * n_head + head_idx;

   // Compute costheta, sintheta
   float theta = L * inv_freq;
@@ -123,20 +129,19 @@ __device__ void rope_impl(
   size_t out_index_1, out_index_2;
   if (traditional) {
     out_index_1 = 2 * pos.x * out_strides[2] + pos.y * out_strides[1] +
-        N * pos.z * out_strides[0];
+        mat_idx * out_strides[0];
     out_index_2 = out_index_1 + 1;
     in_index_1 =
-        2 * pos.x * strides[2] + pos.y * strides[1] + N * pos.z * strides[0];
+        2 * pos.x * strides[2] + pos.y * strides[1] + mat_idx * strides[0];
     in_index_2 = in_index_1 + strides[2];
   } else {
     out_index_1 = pos.x * out_strides[2] + pos.y * out_strides[1] +
-        N * pos.z * out_strides[0];
+        mat_idx * out_strides[0];
     out_index_2 = out_index_1 + dims.x * out_strides[2];
-    in_index_1 =
-        pos.x * strides[2] + pos.y * strides[1] + N * pos.z * strides[0];
+    in_index_1 = pos.x * strides[2] + pos.y * strides[1] + mat_idx * strides[0];
     in_index_2 = in_index_1 + dims.x * strides[2];
   }
-  for (int i = 0; i < N && pos.z * N + i < n_batch; ++i) {
+  for (int i = 0; i < N && head_idx + i < n_head; ++i) {
     // Read and write the output
     float x1 = static_cast<float>(in[in_index_1]);
     float x2 = static_cast<float>(in[in_index_2]);
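Each slot along the grid's z axis now covers N heads of a single batch element, with the head count rounded up to a multiple of N so one slot never straddles two batch elements (which would need two different offsets). A quick host-side check of the decomposition (plain C++, hypothetical sizes):

#include <cassert>

int main() {
  // Mirror of the kernel's index math: each z-slot handles N consecutive
  // heads of one batch element; n_head is rounded up to a multiple of N
  // so slots never cross a batch boundary.
  const int N = 4;       // heads per thread (template parameter)
  const int n_head = 6;  // hypothetical head count
  const int n_batch = 3; // hypothetical batch size
  const int n_head_up = N * ((n_head + N - 1) / N);  // 8

  int slots = n_batch * (n_head_up / N);  // grid extent along z
  for (int z = 0; z < slots; ++z) {
    int head_idx = (z * N) % n_head_up;
    int batch_idx = (z * N) / n_head_up;
    assert(batch_idx < n_batch);
    // The processing loop guards against the padded tail:
    for (int i = 0; i < N && head_idx + i < n_head; ++i) {
      int mat_idx = batch_idx * n_head + (head_idx + i);
      assert(mat_idx < n_batch * n_head);
    }
  }
}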
@@ -167,7 +172,8 @@ __global__ void rope(
     float base,
     const __grid_constant__ cuda::std::array<int64_t, 3> strides,
     const __grid_constant__ cuda::std::array<int64_t, 3> out_strides,
-    int64_t n_batch,
+    int64_t offset_stride,
+    int n_head,
     uint3 dims) {
   uint3 pos = make_uint3(
       blockIdx.x * blockDim.x + threadIdx.x,
@@ -182,12 +188,13 @@ __global__ void rope(
   rope_impl<T, traditional, forward>(
       in,
       out,
-      *offset,
+      offset,
       inv_freq,
       scale,
       strides,
       out_strides,
-      n_batch,
+      offset_stride,
+      n_head,
       pos,
       dims);
 }
@@ -202,7 +209,8 @@ __global__ void rope_freqs(
     float base,
     const __grid_constant__ cuda::std::array<int64_t, 3> strides,
     const __grid_constant__ cuda::std::array<int64_t, 3> out_strides,
-    int64_t n_batch,
+    int64_t offset_stride,
+    int n_head,
     uint3 dims,
     int64_t freq_stride) {
   uint3 pos = make_uint3(
@@ -217,12 +225,13 @@ __global__ void rope_freqs(
   rope_impl<T, traditional, forward>(
       in,
       out,
-      *offset,
+      offset,
       inv_freq,
       scale,
       strides,
       out_strides,
-      n_batch,
+      offset_stride,
+      n_head,
       pos,
       dims);
 }
@@ -245,23 +254,28 @@ void RoPE::eval_gpu(
   auto& offset = inputs[1];
   auto& out = outputs[0];

-  if (in.ndim() < 3) {
-    throw std::runtime_error("[RoPE] Input must have at least 3 dimensions");
-  }
-
   cuda::std::array<int64_t, 3> strides;
   cuda::std::array<int64_t, 3> out_strides;
   bool donated = false;
   int ndim = in.ndim();
-  int dispatch_ndim = in.ndim();
+  int B = in.shape(0);
+  int T = in.shape(-2);
+  int D = in.shape(-1);
+  size_t mat_size = T * D;
+  int dispatch_ndim = ndim;
   while (in.shape(-dispatch_ndim) == 1 && dispatch_ndim > 3) {
     dispatch_ndim--;
   }
-  size_t mat_size = in.shape(-2) * in.shape(-1);
+
+  int N = 1;
+  for (int i = 1; i < (ndim - 2); ++i) {
+    N *= in.shape(i);
+  }

   // We apply rope to less that the whole vector so copy to output and then
   // apply in-place.
-  if (dims_ < in.shape(-1)) {
+  if (dims_ < D) {
     donated = true;
     auto ctype =
         (in.flags().row_contiguous) ? CopyType::Vector : CopyType::General;
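The shape is now split into B (leading batch), N (everything between the batch axis and the last two axes, typically the head count), T (sequence length), and D (feature size), so the per-batch offset can be indexed independently of the head layout. A concrete read on that loop (plain C++, hypothetical shape):

#include <cassert>
#include <vector>

int main() {
  // Hypothetical RoPE input shape (B, H, T, D): batch 2, 8 heads,
  // 128 time steps, 64 features.
  std::vector<int> shape = {2, 8, 128, 64};
  int ndim = int(shape.size());
  int B = shape[0];
  int T = shape[ndim - 2];
  int D = shape[ndim - 1];

  // N multiplies every axis strictly between the batch axis and the
  // last two axes -- here just the head axis.
  int N = 1;
  for (int i = 1; i < ndim - 2; ++i) {
    N *= shape[i];
  }
  assert(B == 2 && N == 8 && T == 128 && D == 64);
}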
@@ -302,7 +316,7 @@ void RoPE::eval_gpu(
   out_strides[2] = out.strides()[ndim - 1];

   // Some flags to help us dispatch below
-  bool single = in.flags().row_contiguous && (mat_size == in.shape(-1));
+  bool single = in.flags().row_contiguous && B == 1 && T == 1;
   bool with_freqs = inputs.size() == 3;

   auto& encoder = cu::get_command_encoder(s);
@@ -319,7 +333,7 @@
   if (single && !with_freqs) {
     auto kernel =
         cu::rope_single<DataType, traditional.value, forward.value>;
-    uint2 dims = make_uint2(dims_ / 2, in.size() / mat_size);
+    uint2 dims = make_uint2(dims_ / 2, N);
     auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1);
     encoder.add_kernel_node(
         kernel,
@@ -336,7 +350,7 @@
   } else if (single) {
     auto kernel =
         cu::rope_single_freqs<DataType, traditional.value, forward.value>;
-    uint2 dims = make_uint2(dims_ / 2, in.size() / mat_size);
+    uint2 dims = make_uint2(dims_ / 2, N);
     auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1);
     encoder.add_kernel_node(
         kernel,
@@ -354,10 +368,14 @@
   } else if (with_freqs) {
     auto kernel =
         cu::rope_freqs<DataType, traditional.value, forward.value>;
-    uint3 dims =
-        make_uint3(dims_ / 2, in.shape(-2), in.size() / mat_size);
-    dims.z = (dims.z + 3) / 4;
+    int n_per_thread = 4;
+    uint32_t dimz = B * ((N + n_per_thread - 1) / n_per_thread);
+    uint3 dims = make_uint3(dims_ / 2, T, dimz);
     auto [grid, block] = get_grid_and_block(dims.x, dims.y, dims.z);
+    int64_t offset_stride = 0;
+    if (inputs[1].ndim() > 0) {
+      offset_stride = inputs[1].strides()[0];
+    }
     encoder.add_kernel_node(
         kernel,
         grid,
@@ -371,15 +389,20 @@
         std::log2(base_),
         strides,
         out_strides,
-        in.size() / mat_size,
+        offset_stride,
+        N,
         dims,
         inputs[2].strides(0));
   } else {
     auto kernel = cu::rope<DataType, traditional.value, forward.value>;
-    uint3 dims =
-        make_uint3(dims_ / 2, in.shape(-2), in.size() / mat_size);
-    dims.z = (dims.z + 3) / 4;
+    int n_per_thread = 4;
+    uint32_t dimz = B * ((N + n_per_thread - 1) / n_per_thread);
+    uint3 dims = make_uint3(dims_ / 2, T, dimz);
     auto [grid, block] = get_grid_and_block(dims.x, dims.y, dims.z);
+    int64_t offset_stride = 0;
+    if (inputs[1].ndim() > 0) {
+      offset_stride = inputs[1].strides()[0];
+    }
     encoder.add_kernel_node(
         kernel,
         grid,
@@ -392,7 +415,8 @@
         std::log2(base_),
         strides,
         out_strides,
-        in.size() / mat_size,
+        offset_stride,
+        N,
         dims);
   }
 });
Metal backend convolutions. Kernel-name construction switches from std::ostringstream to a concatenate helper, and a depthwise 1D fast path is added:

@@ -2,7 +2,6 @@
 #include <algorithm>
 #include <cassert>
 #include <numeric>
-#include <sstream>

 #include "mlx/backend/gpu/copy.h"
 #include "mlx/backend/metal/device.h"
@@ -39,10 +38,11 @@ void explicit_gemm_conv_ND_gpu(
   in_unfolded.set_data(allocator::malloc(in_unfolded.nbytes()));

   // Prepare unfolding kernel
-  std::ostringstream kname;
-  kname << "naive_unfold_nd_" << type_to_name(in_unfolded) << "_" << N;
+  std::string kname;
+  kname.reserve(32);
+  concatenate(kname, "naive_unfold_nd_", type_to_name(in_unfolded), "_", N);
   auto& compute_encoder = d.get_command_encoder(s.index);
-  auto kernel = d.get_kernel(kname.str());
+  auto kernel = d.get_kernel(kname);
   compute_encoder.set_compute_pipeline_state(kernel);

   compute_encoder.set_input_array(in, 0);
@@ -117,11 +117,12 @@ void explicit_gemm_conv_group_ND_gpu(
   in_unfolded.set_data(allocator::malloc(in_unfolded.nbytes()));

   // Prepare unfolding kernel
-  std::ostringstream kname;
-  kname << "naive_unfold_transpose_nd_" << type_to_name(in_unfolded) << "_"
-        << N;
+  std::string kname;
+  kname.reserve(32);
+  concatenate(
+      kname, "naive_unfold_transpose_nd_", type_to_name(in_unfolded), "_", N);
   auto& compute_encoder = d.get_command_encoder(s.index);
-  auto kernel = d.get_kernel(kname.str());
+  auto kernel = d.get_kernel(kname);
   compute_encoder.set_compute_pipeline_state(kernel);

   compute_encoder.set_input_array(in, 0);
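concatenate here is MLX's string-building utility; its signature is not part of this diff, so the following is only a hedged sketch of a variadic helper with the same call shape (the real implementation may differ):

#include <string>
#include <type_traits>

// Hedged sketch of a concatenate(str, args...) helper matching the call
// sites above -- not the actual MLX definition.
inline void concatenate(std::string&) {}

template <typename T, typename... Args>
void concatenate(std::string& acc, const T& first, const Args&... rest) {
  if constexpr (std::is_arithmetic_v<T> && !std::is_same_v<T, char>) {
    acc += std::to_string(first);  // numeric pieces such as block sizes
  } else {
    acc += first;  // string literals, std::string, single chars
  }
  concatenate(acc, rest...);
}

// Example: builds "naive_unfold_nd_float32_2" into a pre-reserved buffer.
// std::string kname;
// kname.reserve(32);
// concatenate(kname, "naive_unfold_nd_", std::string("float32"), "_", 2);

Relative to std::ostringstream, this avoids constructing a stream object on every kernel launch and keeps the short names in a single reserved buffer.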
@@ -252,18 +253,32 @@ void implicit_gemm_conv_2D_gpu(
       /* const int swizzle_log = */ swizzle_log};

   // Determine kernel
-  std::ostringstream kname;
-  kname << "implicit_gemm_conv_2d_" << type_to_name(out) << "_bm" << bm << "_bn"
-        << bn << "_bk" << bk << "_wm" << wm << "_wn" << wn << "_channel_"
-        << (n_channel_specialization ? std::to_string(n_channel_specialization)
-                                     : "l")
-        << "_filter_" << (small_filter ? 's' : 'l');
+  std::string kname;
+  kname.reserve(64);
+  concatenate(
+      kname,
+      "implicit_gemm_conv_2d_",
+      type_to_name(out),
+      "_bm",
+      bm,
+      "_bn",
+      bn,
+      "_bk",
+      bk,
+      "_wm",
+      wm,
+      "_wn",
+      wn,
+      "_channel_",
+      n_channel_specialization ? std::to_string(n_channel_specialization) : "l",
+      "_filter_",
+      small_filter ? 's' : 'l');

   // Encode and dispatch kernel
   auto& compute_encoder = d.get_command_encoder(s.index);
   auto kernel = get_steel_conv_kernel(
       d,
-      kname.str(),
+      kname,
       out,
       bm,
       bn,
@@ -559,11 +574,16 @@ void winograd_conv_2D_gpu(
   {
     int bc = 32;
     int bo = 4;
-    std::ostringstream kname;
-    kname << "winograd_conv_2d_weight_transform_" << type_to_name(out) << "_bc"
-          << bc;
+    std::string kname;
+    kname.reserve(32);
+    concatenate(
+        kname,
+        "winograd_conv_2d_weight_transform_",
+        type_to_name(out),
+        "_bc",
+        bc);
     auto& compute_encoder = d.get_command_encoder(s.index);
-    auto kernel = d.get_kernel(kname.str());
+    auto kernel = d.get_kernel(kname);
     compute_encoder.set_compute_pipeline_state(kernel);

     compute_encoder.set_input_array(wt, 0);
@@ -587,11 +607,16 @@ void winograd_conv_2D_gpu(
     int bc = 32;
     int wm = 2;
     int wn = 2;
-    std::ostringstream kname;
-    kname << "winograd_conv_2d_input_transform_" << type_to_name(out) << "_bc"
-          << bc;
+    std::string kname;
+    kname.reserve(32);
+    concatenate(
+        kname,
+        "winograd_conv_2d_input_transform_",
+        type_to_name(out),
+        "_bc",
+        bc);
     auto& compute_encoder = d.get_command_encoder(s.index);
-    auto kernel = d.get_kernel(kname.str());
+    auto kernel = d.get_kernel(kname);
     compute_encoder.set_compute_pipeline_state(kernel);

     compute_encoder.set_input_array(in_padded, 0);
@@ -634,11 +659,16 @@ void winograd_conv_2D_gpu(
     int bc = 32;
     int wm = 2;
     int wn = 2;
-    std::ostringstream kname;
-    kname << "winograd_conv_2d_output_transform_" << type_to_name(out) << "_bo"
-          << bc;
+    std::string kname;
+    kname.reserve(32);
+    concatenate(
+        kname,
+        "winograd_conv_2d_output_transform_",
+        type_to_name(out),
+        "_bo",
+        bc);
     auto& compute_encoder = d.get_command_encoder(s.index);
-    auto kernel = d.get_kernel(kname.str());
+    auto kernel = d.get_kernel(kname);
     compute_encoder.set_compute_pipeline_state(kernel);

     compute_encoder.set_input_array(out_wg, 0);
@@ -660,9 +690,9 @@ void depthwise_conv_2D_gpu(
     const array& wt,
     array out,
     const MLXConvParams<2>& conv_params) {
-  std::ostringstream kname;
-  kname << "depthwise_conv_2d_" << type_to_name(out);
-  std::string base_name = kname.str();
+  std::string base_name;
+  base_name.reserve(32);
+  concatenate(base_name, "depthwise_conv_2d_", type_to_name(out));

   const int N = conv_params.N;
   const int ker_h = conv_params.wS[0];
@@ -685,15 +715,18 @@ void depthwise_conv_2D_gpu(
   };

   // clang-format off
-  kname << "_ker_h_" << ker_h
-        << "_ker_w_" << ker_w
-        << "_str_h_" << str_h
-        << "_str_w_" << str_w
-        << "_tgp_h_" << th
-        << "_tgp_w_" << tw
-        << "_do_flip_" << (do_flip ? 't' : 'n'); // clang-format on
-
-  std::string hash_name = kname.str();
+  std::string hash_name;
+  hash_name.reserve(64);
+  concatenate(
+      hash_name,
+      base_name,
+      "_ker_h_", ker_h,
+      "_ker_w_", ker_w,
+      "_str_h_", str_h,
+      "_str_w_", str_w,
+      "_tgp_h_", th,
+      "_tgp_w_", tw,
+      "_do_flip_", do_flip ? 't' : 'n'); // clang-format on

   auto& compute_encoder = d.get_command_encoder(s.index);
   auto kernel = d.get_kernel(base_name, hash_name, func_consts);
@@ -774,6 +807,56 @@ void dispatch_conv_2D_gpu(
   }
 }

+void depthwise_conv_1D_gpu(
+    const Stream& s,
+    metal::Device& d,
+    const array& in,
+    array wt,
+    array out) {
+  bool large = in.size() > INT32_MAX || in.data_size() > INT32_MAX;
+  std::string base_name;
+  base_name.reserve(32);
+  concatenate(
+      base_name,
+      "depthwise_conv_1d_",
+      large ? "_large" : "",
+      type_to_name(out));
+
+  if (!wt.flags().row_contiguous) {
+    wt = contiguous_copy_gpu(wt, s);
+    d.add_temporary(wt, s.index);
+  }
+  auto& compute_encoder = d.get_command_encoder(s.index);
+  auto kernel = d.get_kernel(base_name);
+  compute_encoder.set_compute_pipeline_state(kernel);
+
+  auto B = in.shape(0);
+  auto Tout = out.shape(1);
+  auto D = in.shape(2);
+  auto K = wt.shape(1);
+
+  compute_encoder.set_input_array(in, 0);
+  compute_encoder.set_input_array(wt, 1);
+  compute_encoder.set_output_array(out, 2);
+  if (large) {
+    int64_t strides[3] = {in.strides(0), in.strides(1), in.strides(2)};
+    compute_encoder.set_bytes(strides, 3, 3);
+  } else {
+    int strides[3] = {
+        static_cast<int>(in.strides(0)),
+        static_cast<int>(in.strides(1)),
+        static_cast<int>(in.strides(2))};
+    compute_encoder.set_bytes(strides, 3, 3);
+  }
+
+  compute_encoder.set_bytes(K, 4);
+  auto group_dims = get_block_dims(D, Tout, B);
+  MTL::Size grid_dims = MTL::Size(D, Tout, B);
+
+  compute_encoder.dispatch_threads(grid_dims, group_dims);
+}
+
 void conv_1D_gpu(
     const Stream& s,
     metal::Device& d,
@@ -790,8 +873,15 @@ void conv_1D_gpu(
   bool is_idil_one = in_dilation[0] == 1;
   int C = in.shape(2);
   int O = wt.shape(0);
-  const int C_per_group = in.shape(2) / groups;
-  const int O_per_group = wt.shape(0) / groups;
+  // Fast path for fully separable 1D convolution
+  if (is_idil_one && (groups == C) && groups == O && wt_strides[0] == 1 &&
+      wt_dilation[0] == 1 && padding[0] == 0 && !flip) {
+    depthwise_conv_1D_gpu(s, d, in, wt, out);
+    return;
+  }
+
+  const int C_per_group = C / groups;
+  const int O_per_group = O / groups;

   // Direct to implicit gemm conv
   if (is_idil_one && (C_per_group <= 4 || C_per_group % 16 == 0) &&
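The fast path triggers when the convolution is fully depthwise: one input and one output channel per group, unit weight stride and dilation, no padding, no flip, so each output channel depends only on its own input channel. A reference for what the dedicated kernel computes (plain C++, channels-last layout as in the shapes above; valid-mode output length since the fast path requires zero padding):

#include <vector>

// Reference depthwise 1D convolution, channels-last: in is (B, T_in, D),
// w is (D, K), out is (B, T_out, D) with T_out = T_in - K + 1.
void depthwise_conv_1d_reference(
    int B, int T_in, int D, int K,
    const float* in, const float* w, float* out) {
  int T_out = T_in - K + 1;
  for (int b = 0; b < B; ++b) {
    for (int t = 0; t < T_out; ++t) {
      for (int d = 0; d < D; ++d) {
        float acc = 0.f;
        for (int k = 0; k < K; ++k) {
          // Channel d convolves only with its own filter row w[d].
          acc += in[(b * T_in + t + k) * D + d] * w[d * K + k];
        }
        out[(b * T_out + t) * D + d] = acc;
      }
    }
  }
}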
Metal kernel source for the new depthwise 1D convolution:

@@ -288,6 +288,40 @@ instantiate_depthconv2d(float32, float);
 instantiate_depthconv2d(float16, half);
 instantiate_depthconv2d(bfloat16, bfloat16_t);

+template <typename T, typename IdxT>
+[[kernel]] void depthwise_conv_1d(
+    const device T* in [[buffer(0)]],
+    const device T* w [[buffer(1)]],
+    device T* out [[buffer(2)]],
+    constant const IdxT strides[3],
+    constant const int& kernel_size,
+    uint3 tid [[thread_position_in_grid]],
+    uint3 grid_dim [[threads_per_grid]]) {
+  out += (tid.z * static_cast<IdxT>(grid_dim.y) + tid.y) * grid_dim.x + tid.x;
+  in += tid.z * strides[0] + tid.y * strides[1] + tid.x * strides[2];
+  w += tid.x * kernel_size;
+
+  float acc = 0.0;
+  for (int i = 0; i < kernel_size; ++i) {
+    acc += static_cast<float>(in[0]) * w[i];
+    in += strides[1];
+  }
+  *out = static_cast<T>(acc);
+}
+
+#define instantiate_depthconv1d(iname, itype) \
+  instantiate_kernel( \
+      "depthwise_conv_1d_" #iname, depthwise_conv_1d, itype, int32_t) \
+  instantiate_kernel( \
+      "depthwise_conv_1d_" #iname "_large", \
+      depthwise_conv_1d, \
+      itype, \
+      int64_t)
+
+instantiate_depthconv1d(float32, float);
+instantiate_depthconv1d(float16, half);
+instantiate_depthconv1d(bfloat16, bfloat16_t);
+
 ///////////////////////////////////////////////////////////////////////////////
 /// Winograd kernels
 ///////////////////////////////////////////////////////////////////////////////
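The two instantiations differ only in the stride index type: the _large variant uses int64_t strides for tensors whose element count exceeds INT32_MAX, where 32-bit offset arithmetic would overflow. A quick illustration of when that matters (plain C++, hypothetical sizes):

#include <cassert>
#include <cstdint>

int main() {
  // A (B, T, D) tensor can exceed 2^31 - 1 elements at fairly modest sizes.
  int64_t B = 64, T = 40000, D = 1024;
  int64_t elems = B * T * D;  // 2,621,440,000
  assert(elems > INT32_MAX);  // so the "_large" (int64_t strides) path is used

  // The largest linear offset a kernel thread would compute:
  int64_t stride0 = T * D;  // batch stride, 40,960,000 (still fits in int32)
  int64_t max_offset = (B - 1) * stride0;
  assert(max_offset > INT32_MAX);  // but the full offset would not fit
}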
Metal RoPE kernels. Strides widen from size_t to int64_t and the offset becomes a device pointer, mirroring the CUDA changes:

@@ -10,7 +10,7 @@ void rope_single_impl(
     constant const int& offset,
     const float inv_freq,
     constant const float& scale,
-    constant const size_t& stride,
+    constant const int64_t& stride,
     uint2 pos,
     uint2 grid) {
   float L = scale * static_cast<float>(offset);
@@ -52,7 +52,7 @@ template <typename T, bool traditional, bool forward>
     device T* out [[buffer(1)]],
     constant const int& offset,
     constant const float& scale,
-    constant const size_t& stride,
+    constant const int64_t& stride,
     constant const float& base [[buffer(10)]],
     uint2 pos [[thread_position_in_grid]],
     uint2 grid [[threads_per_grid]]) {
@@ -68,9 +68,9 @@ template <typename T, bool traditional, bool forward>
     device T* out [[buffer(1)]],
     constant const int& offset,
     constant const float& scale,
-    constant const size_t& stride,
+    constant const int64_t& stride,
     const device float* freqs [[buffer(10)]],
-    constant const size_t& freq_stride [[buffer(11)]],
+    constant const int64_t& freq_stride [[buffer(11)]],
     uint2 pos [[thread_position_in_grid]],
     uint2 grid [[threads_per_grid]]) {
   float inv_freq = 1.0 / (freqs[freq_stride * pos.x]);
@@ -82,15 +82,21 @@ template <typename T, bool traditional, bool forward, int N = 4>
 void rope_impl(
     const device T* in,
     device T* out,
-    constant const int& offset,
+    const device int* offset,
     const float inv_freq,
     constant const float& scale,
-    constant const size_t strides[3],
-    constant const size_t out_strides[3],
-    constant const size_t& n_batch,
+    constant const int64_t strides[3],
+    constant const int64_t out_strides[3],
+    constant const int64_t& offset_stride,
+    constant const int& n_head,
     uint3 pos,
     uint3 grid) {
-  float L = scale * static_cast<float>(pos.y + offset);
+  auto n_head_up = N * ((n_head + N - 1) / N);
+  auto head_idx = static_cast<int>((pos.z * N) % n_head_up);
+  auto batch_idx = (pos.z * N) / n_head_up;
+  auto batch_offset = offset[batch_idx * offset_stride];
+  float L = scale * static_cast<float>(pos.y + batch_offset);
+  auto mat_idx = batch_idx * n_head + head_idx;

   // Compute costheta, sintheta
   float theta = L * inv_freq;
@@ -102,20 +108,19 @@ void rope_impl(
   size_t out_index_1, out_index_2;
   if (traditional) {
     out_index_1 = 2 * pos.x * out_strides[2] + pos.y * out_strides[1] +
-        N * pos.z * out_strides[0];
+        mat_idx * out_strides[0];
     out_index_2 = out_index_1 + 1;
     in_index_1 =
-        2 * pos.x * strides[2] + pos.y * strides[1] + N * pos.z * strides[0];
+        2 * pos.x * strides[2] + pos.y * strides[1] + mat_idx * strides[0];
     in_index_2 = in_index_1 + strides[2];
   } else {
     out_index_1 = pos.x * out_strides[2] + pos.y * out_strides[1] +
-        N * pos.z * out_strides[0];
+        mat_idx * out_strides[0];
     out_index_2 = out_index_1 + grid.x * out_strides[2];
-    in_index_1 =
-        pos.x * strides[2] + pos.y * strides[1] + N * pos.z * strides[0];
+    in_index_1 = pos.x * strides[2] + pos.y * strides[1] + mat_idx * strides[0];
     in_index_2 = in_index_1 + grid.x * strides[2];
   }
-  for (int i = 0; i < N && pos.z * N + i < n_batch; ++i) {
+  for (int i = 0; i < N && head_idx + i < n_head; ++i) {
     // Read and write the output
     float x1 = static_cast<float>(in[in_index_1]);
     float x2 = static_cast<float>(in[in_index_2]);
@@ -141,11 +146,12 @@ template <typename T, bool traditional, bool forward, int N = 4>
 [[kernel]] void rope(
     const device T* in [[buffer(0)]],
     device T* out [[buffer(1)]],
-    constant const int& offset,
+    const device int* offset,
     constant const float& scale,
-    constant const size_t strides[3],
-    constant const size_t out_strides[3],
-    constant const size_t& n_batch,
+    constant const int64_t strides[3],
+    constant const int64_t out_strides[3],
+    constant const int64_t& offset_stride,
+    constant const int& n_head,
     constant const float& base [[buffer(10)]],
     uint3 pos [[thread_position_in_grid]],
     uint3 grid [[threads_per_grid]]) {
@@ -159,7 +165,8 @@ template <typename T, bool traditional, bool forward, int N = 4>
       scale,
       strides,
       out_strides,
-      n_batch,
+      offset_stride,
+      n_head,
       pos,
       grid);
 }
@@ -168,13 +175,14 @@ template <typename T, bool traditional, bool forward, int N = 4>
 [[kernel]] void rope_freqs(
     const device T* in [[buffer(0)]],
     device T* out [[buffer(1)]],
-    constant const int& offset,
+    const device int* offset,
     constant const float& scale,
-    constant const size_t strides[3],
-    constant const size_t out_strides[3],
-    constant const size_t& n_batch,
+    constant const int64_t strides[3],
+    constant const int64_t out_strides[3],
+    constant const int64_t& offset_stride,
+    constant const int& n_head,
     const device float* freqs [[buffer(10)]],
-    constant const size_t& freq_stride [[buffer(11)]],
+    constant const int64_t& freq_stride [[buffer(11)]],
     uint3 pos [[thread_position_in_grid]],
     uint3 grid [[threads_per_grid]]) {
   float inv_freq = 1.0 / (freqs[freq_stride * pos.x]);
@@ -186,61 +194,20 @@ template <typename T, bool traditional, bool forward, int N = 4>
       scale,
       strides,
       out_strides,
-      n_batch,
+      offset_stride,
+      n_head,
       pos,
       grid);
 }

 // clang-format off
 #define instantiate_rope_g(name, type, traditional, forward) \
-  template [[host_name("rope_" #name)]] [[kernel]] void \
-  rope<type, traditional, forward>( \
-      const device type* in [[buffer(0)]], \
-      device type* out [[buffer(1)]], \
-      constant const int& offset, \
-      constant const float& scale, \
-      constant const size_t strides[3], \
-      constant const size_t out_strides[3], \
-      constant const size_t& n_batch, \
-      constant const float& base [[buffer(10)]], \
-      uint3 pos [[thread_position_in_grid]], \
-      uint3 grid [[threads_per_grid]]); \
-  template [[host_name("rope_freqs_" #name)]] \
-  [[kernel]] void rope_freqs<type, traditional, forward>( \
-      const device type* in [[buffer(0)]], \
-      device type* out [[buffer(1)]], \
-      constant const int& offset, \
-      constant const float& scale, \
-      constant const size_t strides[3], \
-      constant const size_t out_strides[3], \
-      constant const size_t& n_batch, \
-      const device float* freqs [[buffer(10)]], \
-      constant const size_t& freq_stride [[buffer(11)]], \
-      uint3 pos [[thread_position_in_grid]], \
-      uint3 grid [[threads_per_grid]]);
+  instantiate_kernel("rope_" #name, rope, type, traditional, forward) \
+  instantiate_kernel("rope_freqs_" #name, rope_freqs, type, traditional, forward)

 #define instantiate_rope_s(name, type, traditional, forward) \
-  template [[host_name("rope_single_" #name)]] [[kernel]] void \
-  rope_single<type, traditional, forward>( \
-      const device type* in [[buffer(0)]], \
-      device type* out [[buffer(1)]], \
-      constant const int& offset, \
-      constant const float& scale, \
-      constant const size_t& stride, \
-      constant const float& base [[buffer(10)]], \
-      uint2 pos [[thread_position_in_grid]], \
-      uint2 grid [[threads_per_grid]]); \
-  template [[host_name("rope_single_freqs_" #name)]] \
-  [[kernel]] void rope_single_freqs<type, traditional, forward>( \
-      const device type* in [[buffer(0)]], \
-      device type* out [[buffer(1)]], \
-      constant const int& offset, \
-      constant const float& scale, \
-      constant const size_t& stride, \
-      const device float* freqs [[buffer(10)]], \
-      constant const size_t& freq_stride [[buffer(11)]], \
-      uint2 pos [[thread_position_in_grid]], \
-      uint2 grid [[threads_per_grid]]);
+  instantiate_kernel("rope_single_" #name, rope_single, type, traditional, forward) \
+  instantiate_kernel("rope_single_freqs_" #name, rope_single_freqs, type, traditional, forward)

 #define instantiate_rope(name, type, traditional, forward) \
   instantiate_rope_s(name, type, traditional, forward) \
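instantiate_kernel is MLX's helper for naming explicit kernel template instantiations; its definition is not part of this diff, so the following is only a hedged sketch of the shape such a macro typically takes (the real one may differ):

// Hedged sketch -- not necessarily the actual MLX definition. A macro like
// this lets an explicit instantiation carry a host-visible name without
// restating the kernel's full parameter list, which is what the deleted
// macros above had to do:
#define instantiate_kernel(name, func, ...) \
  template [[host_name(name)]] [[kernel]] decltype(func<__VA_ARGS__>) \
      func<__VA_ARGS__>;

The payoff is visible in the hunk: each instantiation shrinks from a dozen lines restating the signature to one line, and signature changes such as the new offset_stride and n_head parameters no longer have to be mirrored in the macros.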
Metal RoPE host code gets the same B/T/D/N shape split:

@@ -18,23 +18,29 @@ void RoPE::eval_gpu(
   auto& in = inputs[0];
   auto& out = outputs[0];

-  if (in.ndim() < 3) {
-    throw std::runtime_error("[RoPE] Input must have at least 3 dimensions");
-  }
-
   auto& s = out.primitive().stream();
   auto& d = metal::device(s.device);

-  size_t strides[3];
-  size_t out_strides[3];
+  int64_t strides[3];
+  int64_t out_strides[3];
   bool donated = false;
   int ndim = in.ndim();
-  int dispatch_ndim = in.ndim();
+  int B = in.shape(0);
+  int T = in.shape(-2);
+  int D = in.shape(-1);
+  size_t mat_size = T * D;
+
+  int dispatch_ndim = ndim;
   while (in.shape(-dispatch_ndim) == 1 && dispatch_ndim > 3) {
     dispatch_ndim--;
   }
-  size_t mat_size = in.shape(-2) * in.shape(-1);
-  if (dims_ < in.shape(-1)) {
+
+  int N = 1;
+  for (int i = 1; i < (ndim - 2); ++i) {
+    N *= in.shape(i);
+  }
+
+  if (dims_ < D) {
     donated = true;
     auto ctype =
         (in.flags().row_contiguous) ? CopyType::Vector : CopyType::General;
@@ -71,8 +77,8 @@ void RoPE::eval_gpu(
   out_strides[1] = out.strides()[ndim - 2];
   out_strides[2] = out.strides()[ndim - 1];

-  // Special case for inference (single time step and contiguous)
-  bool single = in.flags().row_contiguous && (mat_size == in.shape(-1));
+  // Special case for inference (single batch, single time step, and contiguous)
+  bool single = in.flags().row_contiguous && B == 1 && T == 1;

   bool with_freqs = inputs.size() == 3;
   std::ostringstream kname;
@@ -86,24 +92,29 @@ void RoPE::eval_gpu(
|
|||||||
compute_encoder.set_compute_pipeline_state(kernel);
|
compute_encoder.set_compute_pipeline_state(kernel);
|
||||||
compute_encoder.set_input_array(donated ? out : in, 0);
|
compute_encoder.set_input_array(donated ? out : in, 0);
|
||||||
compute_encoder.set_output_array(out, 1);
|
compute_encoder.set_output_array(out, 1);
|
||||||
|
|
||||||
compute_encoder.set_input_array(inputs[1], 2);
|
compute_encoder.set_input_array(inputs[1], 2);
|
||||||
compute_encoder.set_bytes(scale_, 3);
|
compute_encoder.set_bytes(scale_, 3);
|
||||||
|
|
||||||
size_t n_batch = in.size() / mat_size;
|
|
||||||
MTL::Size group_dims;
|
MTL::Size group_dims;
|
||||||
MTL::Size grid_dims;
|
MTL::Size grid_dims;
|
||||||
if (single) {
|
if (single) {
|
||||||
compute_encoder.set_bytes(out_strides, 1, 4);
|
compute_encoder.set_bytes(out_strides, 1, 4);
|
||||||
uint32_t dim0 = dims_ / 2;
|
uint32_t dim0 = dims_ / 2;
|
||||||
group_dims = get_block_dims(dim0, n_batch, 1);
|
group_dims = get_block_dims(dim0, N, 1);
|
||||||
grid_dims = MTL::Size(dim0, n_batch, 1);
|
grid_dims = MTL::Size(dim0, N, 1);
|
||||||
} else {
|
} else {
|
||||||
compute_encoder.set_bytes(strides, 3, 4);
|
compute_encoder.set_bytes(strides, 3, 4);
|
||||||
compute_encoder.set_bytes(out_strides, 3, 5);
|
compute_encoder.set_bytes(out_strides, 3, 5);
|
||||||
compute_encoder.set_bytes(n_batch, 6);
|
int64_t offset_stride = 0;
|
||||||
|
if (inputs[1].ndim() > 0) {
|
||||||
|
offset_stride = inputs[1].strides()[0];
|
||||||
|
}
|
||||||
|
compute_encoder.set_bytes(offset_stride, 6);
|
||||||
|
compute_encoder.set_bytes(N, 7);
|
||||||
uint32_t dim0 = dims_ / 2;
|
uint32_t dim0 = dims_ / 2;
|
||||||
uint32_t dim1 = in.shape(-2);
|
uint32_t dim1 = T;
|
||||||
uint32_t dim2 = (n_batch + n_per_thread - 1) / n_per_thread;
|
uint32_t dim2 = B * ((N + n_per_thread - 1) / n_per_thread);
|
||||||
group_dims = get_block_dims(dim0, dim1, dim2);
|
group_dims = get_block_dims(dim0, dim1, dim2);
|
||||||
grid_dims = MTL::Size(dim0, dim1, dim2);
|
grid_dims = MTL::Size(dim0, dim1, dim2);
|
||||||
}
|
}
|
||||||
|
|||||||
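In short, the new dispatch treats the input as ``(B, *, T, D)``, folds every dimension between batch and time into ``N``, and spreads the batch along the third grid axis. A rough Python sketch of that bookkeeping (``rope_grid`` is a hypothetical helper and ``n_per_thread=4`` is a placeholder, not the kernel's actual constant):

```python
import math

def rope_grid(shape, dims, n_per_thread=4):
    # shape is (B, *, T, D); fold the middle dims into N, as the kernel does.
    B, T = shape[0], shape[-2]
    N = math.prod(shape[1:-2])
    dim0 = dims // 2  # the kernel rotates (x1, x2) pairs, hence dims / 2
    dim1 = T
    dim2 = B * ((N + n_per_thread - 1) // n_per_thread)
    return (dim0, dim1, dim2)

print(rope_grid((8, 4, 16, 64), dims=64))  # (32, 16, 8)
```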
60  mlx/fast.cpp
@@ -366,10 +366,16 @@ array rope(
     msg << "[rope] Input must be a floating type but got " << x.dtype() << ".";
     throw std::invalid_argument(msg.str());
   }
-  if (offset.size() != 1) {
+  if (offset.ndim() > 1) {
     std::ostringstream msg;
-    msg << "[rope] offset must be a scalar but has shape " << offset.shape()
-        << ".";
+    msg << "[rope] offset must have at most one dimension but has shape "
+        << offset.shape() << ".";
+    throw std::invalid_argument(msg.str());
+  }
+  if (offset.size() != 1 && offset.size() != x.shape(0)) {
+    std::ostringstream msg;
+    msg << "[rope] offset must be a scalar or vector with " << x.shape(0)
+        << " elements but has shape " << offset.shape() << ".";
     throw std::invalid_argument(msg.str());
   }
   if (!issubdtype(offset.dtype(), integer)) {
@@ -379,7 +385,7 @@ array rope(
     throw std::invalid_argument(msg.str());
   }
   if (offset.dtype().size() != 4) {
-    inputs[1] = astype(offset, uint32, s);
+    inputs[1] = astype(offset, int32, s);
   }
   if (inputs.size() == 3 &&
       (inputs[2].ndim() != 1 || inputs[2].shape(0) != dims / 2)) {
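Stated plainly, the validation above now accepts a scalar offset or a vector with one offset per batch row, and normalizes 4-byte integer types to ``int32`` (previously ``uint32``). A hedged Python rendering of the same rules (``check_offset`` is illustrative, not an MLX API):

```python
import mlx.core as mx

def check_offset(x, offset):
    # offset: a scalar or a (B,) vector of per-example position offsets
    if offset.ndim > 1:
        raise ValueError(f"offset must have at most one dimension, got {offset.shape}")
    if offset.size != 1 and offset.size != x.shape[0]:
        raise ValueError(f"offset must be a scalar or a vector of {x.shape[0]} offsets")
    return offset.astype(mx.int32)  # mirrors the astype(offset, int32, s) cast
```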
@@ -391,15 +397,26 @@ array rope(
 
   auto fallback = [dims, traditional, base, scale, forward, s](
                       std::vector<array> inputs) {
-    auto& shape = inputs[0].shape();
-    int ndim = shape.size();
-    auto x = flatten(inputs[0], 0, ndim - 3, s);
+    auto x = inputs[0];
+    auto shape = x.shape();
+    if (x.ndim() == 3) {
+      x = expand_dims(x, 1, s);
+    } else if (x.ndim() > 4) {
+      x = flatten(x, 1, 1 + (x.ndim() - 4), s);
+    }
+
+    auto B = x.shape(0);
+    auto N = x.shape(1);
+    auto T = x.shape(2);
     auto t = x.dtype();
     // Compute sines and cosines
     auto half_dims = dims / 2;
-    auto& offset = inputs[1];
+    auto offset = inputs[1];
+    if (offset.size() > 1) {
+      offset = expand_dims(offset, {-1, -2}, s);
+    }
     auto positions =
-        multiply(add(arange(x.shape(1), t, s), offset, s), array(scale, t), s);
+        multiply(add(arange(x.shape(2), t, s), offset, s), array(scale, t), s);
 
     auto default_inv_freqs = [&inputs, &s, &t, base, half_dims]() {
       return exp(
@@ -412,8 +429,7 @@ array rope(
 
     auto inv_freqs = inputs.size() == 3 ? astype(reciprocal(inputs[2], s), t, s)
                                         : default_inv_freqs();
-    auto theta =
-        multiply(expand_dims(positions, 1, s), expand_dims(inv_freqs, 0, s), s);
+    auto theta = multiply(expand_dims(positions, -1, s), inv_freqs, s);
     auto coss = cos(theta, s);
     auto sins = sin(theta, s);
 
@@ -436,32 +452,30 @@ array rope(
     };
 
     if (traditional) {
-      auto x1 =
-          slice(x, {0, 0, 0}, {x.shape(0), x.shape(1), dims}, {1, 1, 2}, s);
-      auto x2 =
-          slice(x, {0, 0, 1}, {x.shape(0), x.shape(1), dims}, {1, 1, 2}, s);
+      auto x1 = slice(x, {0, 0, 0, 0}, {B, N, T, dims}, {1, 1, 1, 2}, s);
+      auto x2 = slice(x, {0, 0, 0, 1}, {B, N, T, dims}, {1, 1, 1, 2}, s);
       auto outs = apply_rope(x1, x2, coss, sins);
       for (auto& o : outs) {
-        o = expand_dims(o, 3, s);
+        o = expand_dims(o, -1, s);
       }
-      auto out = concatenate(outs, 3, s);
+      auto out = reshape(concatenate(outs, -1, s), {B, N, T, dims}, s);
       if (dims < x.shape(-1)) {
-        out = reshape(out, {x.shape(0), x.shape(1), dims});
-        out = concatenate({out, slice(x, {0, 0, dims}, x.shape(), s)}, 2, s);
+        out =
+            concatenate({out, slice(x, {0, 0, 0, dims}, x.shape(), s)}, -1, s);
       }
       return std::vector<array>{reshape(out, shape, s)};
     } else {
       auto out_s = x.shape();
       out_s.back() = half_dims;
-      auto x1 = slice(x, {0, 0, 0}, out_s, s);
+      auto x1 = slice(x, {0, 0, 0, 0}, out_s, s);
       out_s.back() = dims;
-      auto x2 = slice(x, {0, 0, half_dims}, out_s, s);
+      auto x2 = slice(x, {0, 0, 0, half_dims}, out_s, s);
 
       auto outs = apply_rope(x1, x2, coss, sins);
       if (dims < x.shape(-1)) {
-        outs.push_back(slice(x, {0, 0, dims}, x.shape(), s));
+        outs.push_back(slice(x, {0, 0, 0, dims}, x.shape(), s));
       }
-      return std::vector<array>{reshape(concatenate(outs, 2, s), shape, s)};
+      return std::vector<array>{reshape(concatenate(outs, -1, s), shape, s)};
     }
   };
   auto stream = to_stream(s);
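The core of the fallback is the position/angle math: per-batch offsets are broadcast against the time axis before the outer product with the inverse frequencies. A minimal standalone sketch of that broadcast in Python, assuming the same shape conventions as the fallback:

```python
import math
import mlx.core as mx

B, T, half_dims, base, scale = 2, 5, 4, 10000.0, 1.0
offset = mx.array([0.0, 3.0])  # one starting position per batch row
positions = mx.arange(T, dtype=mx.float32) + mx.expand_dims(offset, -1)  # (B, T)
positions = positions * scale
inv_freqs = mx.exp(
    -mx.arange(half_dims, dtype=mx.float32) * (math.log(base) / half_dims)
)
theta = mx.expand_dims(positions, -1) * inv_freqs  # (B, T, half_dims)
coss, sins = mx.cos(theta), mx.sin(theta)
```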
@@ -1,10 +1,8 @@
 // Copyright © 2025 Apple Inc.
 
-#include <string>
-
 namespace mlx::core {
 
-std::string version() {
+const char* version() {
   return MLX_VERSION;
 }
 
@@ -15,6 +15,6 @@ namespace mlx::core {
  *
  * For dev builds, the version will include the suffix ".devYYYYMMDD+hash"
  */
-std::string version();
+const char* version();
 
 } // namespace mlx::core
@@ -52,7 +52,6 @@ set_target_properties(
     ${MLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY})
 
 target_link_libraries(core PRIVATE mlx)
-target_compile_definitions(core PRIVATE _VERSION_=${MLX_VERSION})
 
 if(BUILD_SHARED_LIBS)
   if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
@@ -164,8 +164,13 @@ void init_fast(nb::module_& parent_module) {
       R"pbdoc(
         Apply rotary positional encoding to the input.
 
+        The input is expected to be at least 3D with shape ``(B, *, T, D)`` where:
+        * ``B`` is the batch size.
+        * ``T`` is the sequence length.
+        * ``D`` is the feature dimension.
+
         Args:
-            a (array): Input array.
+            a (array): The input array.
             dims (int): The feature dimensions to be rotated. If the input feature
                 is larger than dims then the rest is left unchanged.
@@ -174,7 +179,9 @@ void init_fast(nb::module_& parent_module) {
                 each dimension in the positional encodings. Exactly one of ``base`` and
                 ``freqs`` must be ``None``.
             scale (float): The scale used to scale the positions.
-            offset (int or array): The position offset to start at.
+            offset (int or array): The position offset to start at. If an
+                :obj:`array` is given it can be a scalar or vector of ``B``
+                offsets for each example in the batch.
             freqs (array, optional): Optional frequencies to use with RoPE.
                 If set, the ``base`` parameter must be ``None``. Default: ``None``.
 
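A short usage sketch of the documented behavior (shapes chosen arbitrarily): a scalar ``offset`` applies to every sequence, while a length-``B`` vector gives each batch row its own starting position.

```python
import mlx.core as mx

B, H, T, D = 4, 2, 8, 64
x = mx.random.uniform(shape=(B, H, T, D))

# Same offset for the whole batch.
y = mx.fast.rope(x, D, traditional=False, base=10000.0, scale=1.0, offset=7)

# One offset per example in the batch.
y = mx.fast.rope(
    x, D, traditional=False, base=10000.0, scale=1.0, offset=mx.arange(B)
)
```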
@@ -2,9 +2,9 @@
 
 #include <nanobind/nanobind.h>
 
-#define STRINGIFY(x) #x
-#define TOSTRING(x) STRINGIFY(x)
+#include "mlx/version.h"
 
+namespace mx = mlx::core;
 namespace nb = nanobind;
 
 void init_mlx_func(nb::module_&);
@@ -48,5 +48,5 @@ NB_MODULE(core, m) {
   init_distributed(m);
   init_export(m);
 
-  m.attr("__version__") = TOSTRING(_VERSION_);
+  m.attr("__version__") = mx::version();
 }
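With the binding above, the Python-visible version string is sourced from the C++ ``mlx::core::version()`` rather than a compile definition; a quick check:

```python
import mlx.core as mx

print(mx.__version__)  # backed by mx::version() after this change
```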
@@ -91,7 +91,7 @@ mx::array to_array_with_accessor(nb::object obj) {
     return nb::cast<mx::array>(obj.attr("__mlx_array__")());
   } else {
     std::ostringstream msg;
     msg << "Invalid type " << nb::type_name(obj.type()).c_str()
         << " received in array initialization.";
     throw std::invalid_argument(msg.str());
   }
@@ -594,124 +594,123 @@ class TestBlas(mlx_tests.MLXTestCase):
         np.random.seed(0)
         # Batched matmul
         alpha = 0.5
-        beta = 2.0
-
-        # c must broadcast to the output shape
-        with self.assertRaises(ValueError):
-            mx.addmm(mx.zeros((2, 2, 2)), mx.zeros((2, 2)), mx.zeros((2, 2)))
-
-        # Regular batched case
-        a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
-        b_npy = np.random.normal(0.0, 1.0 / 128, (32, 16, 16)).astype(np.float32)
-
-        a_mlx = mx.array(a_npy)
-        b_mlx = mx.array(b_npy)
-
-        for c_shape in ((1,), (1, 16), (32, 1, 16), (1, 128, 16)):
-            c_npy = np.ones(c_shape).astype(np.float32)
-            c_mlx = mx.array(c_npy)
-
-            d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
-            d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)
-
-            self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
-            self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
-
-        # Batched and transposed matmul
-        b_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
-        b_mlx = mx.array(b_npy)
-
-        for c_shape in ((1,), (32, 1, 128), (1, 128)):
-            c_npy = np.ones(c_shape).astype(np.float32)
-            c_mlx = mx.array(c_npy)
-
-            b_np_t = np.transpose(b_npy, (0, 2, 1))
-            b_mx_t = mx.transpose(b_mlx, (0, 2, 1))
-
-            d_npy = alpha * (a_npy @ b_np_t) + beta * c_npy
-            d_mlx = mx.addmm(c_mlx, a_mlx, b_mx_t, alpha, beta)
-
-            self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
-            self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
-
-        # Batched matmul with simple broadcast
-        a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
-        b_npy = np.random.normal(0.0, 1.0 / 128, (16, 16)).astype(np.float32)
-
-        a_mlx = mx.array(a_npy)
-        b_mlx = mx.array(b_npy)
-
-        for c_shape in ((1,), (1, 16), (32, 1, 16), (1, 128, 16)):
-            c_npy = np.ones(c_shape).astype(np.float32)
-            c_mlx = mx.array(c_npy)
-
-            d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
-            d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)
-
-            self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
-            self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
-
-        # Matmul with vector
-        a_npy = np.random.normal(0.0, 1.0 / 128, (16,)).astype(np.float32)
-        b_npy = np.random.normal(0.0, 1.0 / 128, (32, 16, 128)).astype(np.float32)
-        a_mlx = mx.array(a_npy)
-        b_mlx = mx.array(b_npy)
-
-        for c_shape in ((1,), (128,), (32, 128)):
-            c_npy = np.ones(c_shape).astype(np.float32)
-            c_mlx = mx.array(c_npy)
-
-            d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
-            d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)
-
-            self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
-            self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
-
-        # Matmul with vector
-        a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
-        b_npy = np.random.normal(0.0, 1.0 / 128, (16,)).astype(np.float32)
-        a_mlx = mx.array(a_npy)
-        b_mlx = mx.array(b_npy)
-
-        for c_shape in ((1,), (32, 128)):
-            c_npy = np.ones(c_shape).astype(np.float32)
-            c_mlx = mx.array(c_npy)
-
-            d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
-            d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)
-
-            self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
-            self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
-
-        # Split K specializtion
-        a_npy = np.random.normal(0.0, 1.0 / 128, (64, 4096)).astype(np.float32)
-        b_npy = np.random.normal(0.0, 1.0 / 128, (4096, 32)).astype(np.float32)
-
-        a_mlx = mx.array(a_npy)
-        b_mlx = mx.array(b_npy)
-
-        for c_shape in ((1,), (1, 32), (64, 1), (64, 32)):
-            c_npy = np.ones(c_shape).astype(np.float32)
-            c_mlx = mx.array(c_npy)
-
-            d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
-            d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)
-
-            self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
-            self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
-
-        # Transposed c
-        a = mx.ones((10, 5)).T
-        b = mx.ones((5, 5))
-        out = mx.addmm(a, b, a, beta=1.5, alpha=0.5)
-        expected = 1.5 * a + 0.5 * (b @ a)
-        self.assertTrue(mx.allclose(expected, out))
-
-        # Broadcast c
-        a = mx.ones((5, 5))
-        b = mx.ones((5, 5))
-        c = mx.ones((1, 5))
-        out = mx.addmm(c, a, b, beta=1.5, alpha=0.5)
-        expected = 1.5 * c + 0.5 * (a @ b)
-        self.assertTrue(mx.allclose(expected, out))
+        for beta in (1.0, 2.0):
+            # c must broadcast to the output shape
+            with self.assertRaises(ValueError):
+                mx.addmm(mx.zeros((2, 2, 2)), mx.zeros((2, 2)), mx.zeros((2, 2)))
+
+            # Regular batched case
+            a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
+            b_npy = np.random.normal(0.0, 1.0 / 128, (32, 16, 16)).astype(np.float32)
+
+            a_mlx = mx.array(a_npy)
+            b_mlx = mx.array(b_npy)
+
+            for c_shape in ((1,), (1, 16), (32, 1, 16), (1, 128, 16)):
+                c_npy = np.ones(c_shape).astype(np.float32)
+                c_mlx = mx.array(c_npy)
+
+                d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
+                d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)
+
+                self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
+                self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
+
+            # Batched and transposed matmul
+            b_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
+            b_mlx = mx.array(b_npy)
+
+            for c_shape in ((1,), (32, 1, 128), (1, 128)):
+                c_npy = np.ones(c_shape).astype(np.float32)
+                c_mlx = mx.array(c_npy)
+
+                b_np_t = np.transpose(b_npy, (0, 2, 1))
+                b_mx_t = mx.transpose(b_mlx, (0, 2, 1))
+
+                d_npy = alpha * (a_npy @ b_np_t) + beta * c_npy
+                d_mlx = mx.addmm(c_mlx, a_mlx, b_mx_t, alpha, beta)
+
+                self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
+                self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
+
+            # Batched matmul with simple broadcast
+            a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
+            b_npy = np.random.normal(0.0, 1.0 / 128, (16, 16)).astype(np.float32)
+
+            a_mlx = mx.array(a_npy)
+            b_mlx = mx.array(b_npy)
+
+            for c_shape in ((1,), (1, 16), (32, 1, 16), (1, 128, 16)):
+                c_npy = np.ones(c_shape).astype(np.float32)
+                c_mlx = mx.array(c_npy)
+
+                d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
+                d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)
+
+                self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
+                self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
+
+            # Matmul with vector
+            a_npy = np.random.normal(0.0, 1.0 / 128, (16,)).astype(np.float32)
+            b_npy = np.random.normal(0.0, 1.0 / 128, (32, 16, 128)).astype(np.float32)
+            a_mlx = mx.array(a_npy)
+            b_mlx = mx.array(b_npy)
+
+            for c_shape in ((1,), (128,), (32, 128)):
+                c_npy = np.ones(c_shape).astype(np.float32)
+                c_mlx = mx.array(c_npy)
+
+                d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
+                d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)
+
+                self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
+                self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
+
+            # Matmul with vector
+            a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
+            b_npy = np.random.normal(0.0, 1.0 / 128, (16,)).astype(np.float32)
+            a_mlx = mx.array(a_npy)
+            b_mlx = mx.array(b_npy)
+
+            for c_shape in ((1,), (32, 128)):
+                c_npy = np.ones(c_shape).astype(np.float32)
+                c_mlx = mx.array(c_npy)
+
+                d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
+                d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)
+
+                self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
+                self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
+
+            # Split K specializtion
+            a_npy = np.random.normal(0.0, 1.0 / 128, (64, 4096)).astype(np.float32)
+            b_npy = np.random.normal(0.0, 1.0 / 128, (4096, 32)).astype(np.float32)
+
+            a_mlx = mx.array(a_npy)
+            b_mlx = mx.array(b_npy)
+
+            for c_shape in ((1,), (1, 32), (64, 1), (64, 32)):
+                c_npy = np.ones(c_shape).astype(np.float32)
+                c_mlx = mx.array(c_npy)
+
+                d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
+                d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)
+
+                self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
+                self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
+
+            # Transposed c
+            a = mx.ones((10, 5)).T
+            b = mx.ones((5, 5))
+            out = mx.addmm(a, b, a, beta=beta, alpha=alpha)
+            expected = beta * a + alpha * (b @ a)
+            self.assertTrue(mx.allclose(expected, out))
+
+            # Broadcast c
+            a = mx.ones((5, 5))
+            b = mx.ones((5, 5))
+            c = mx.ones((1, 5))
+            out = mx.addmm(c, a, b, beta=beta, alpha=alpha)
+            expected = beta * c + alpha * (a @ b)
+            self.assertTrue(mx.allclose(expected, out))
 
     def test_addmm_grad(self):
         def make_ref_addmm(alpha, beta):
@@ -724,33 +723,32 @@ class TestBlas(mlx_tests.MLXTestCase):
         shapes = ((1, 64, 32, 128), (4, 28, 24, 47), (1, 1, 24, 47))
 
         alpha = 2.0
-        beta = 0.5
-
-        f_test = make_addmm(alpha, beta)
-        f_ref = make_ref_addmm(alpha, beta)
-
-        for B, M, N, K in shapes:
-            cotan = mx.ones((B, M, N))
-            c = mx.random.normal((B, M, N))
-            a = mx.random.normal((B, M, K))
-            b = mx.random.normal((B, K, N))
-
-            out_ref, dout_ref = mx.vjp(
-                f_ref,
-                [c, a, b],
-                [cotan],
-            )
-            out_test, dout_test = mx.vjp(
-                f_test,
-                [c, a, b],
-                [cotan],
-            )
-
-            self.assertTrue(mx.allclose(out_ref[0], out_test[0], atol=1e-4).item())
-
-            for r, t in zip(dout_ref, dout_test):
-                self.assertEqual(r.shape, t.shape)
-                self.assertTrue(mx.allclose(r, t, atol=1e-4).item())
+        for beta in (1.0, 0.5):
+            f_test = make_addmm(alpha, beta)
+            f_ref = make_ref_addmm(alpha, beta)
+
+            for B, M, N, K in shapes:
+                cotan = mx.ones((B, M, N))
+                c = mx.random.normal((B, M, N))
+                a = mx.random.normal((B, M, K))
+                b = mx.random.normal((B, K, N))
+
+                out_ref, dout_ref = mx.vjp(
+                    f_ref,
+                    [c, a, b],
+                    [cotan],
+                )
+                out_test, dout_test = mx.vjp(
+                    f_test,
+                    [c, a, b],
+                    [cotan],
+                )
+
+                self.assertTrue(mx.allclose(out_ref[0], out_test[0], atol=1e-4).item())
+
+                for r, t in zip(dout_ref, dout_test):
+                    self.assertEqual(r.shape, t.shape)
+                    self.assertTrue(mx.allclose(r, t, atol=1e-4).item())
 
     def test_empty_matmul(self):
         a = mx.array([[], []]).T
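For context, the identity these tests exercise is ``addmm``'s definition, ``out = beta * c + alpha * (a @ b)``; a minimal standalone check:

```python
import mlx.core as mx

alpha, beta = 2.0, 0.5
a, b, c = mx.ones((3, 4)), mx.ones((4, 5)), mx.ones((3, 5))
out = mx.addmm(c, a, b, alpha=alpha, beta=beta)
assert mx.allclose(out, beta * c + alpha * (a @ b)).item()
```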
@@ -8,18 +8,23 @@ import mlx_tests
 
 
 def rope_orig(x, dims, traditional, base, scale, offset, freqs=None):
-    offset = offset.item() if isinstance(offset, mx.array) else offset
-    N = x.shape[-2] + offset
+    N = x.shape[-2]
+
     dtype = x.dtype
     half_D = dims // 2
-    positions = mx.arange(offset, N, dtype=dtype) * scale
+    positions = mx.arange(N, dtype=dtype)
+    if isinstance(offset, mx.array) and offset.size > 1:
+        expand = tuple(range(1, x.ndim - 1))
+        positions = mx.expand_dims(offset, expand) + positions
+    else:
+        positions = offset + positions
+    positions = positions * scale
     if freqs is None:
         inv_freqs = mx.exp(
             -mx.arange(0.0, half_D, dtype=dtype) * (math.log(base) / half_D)
         )
     else:
         inv_freqs = (1 / freqs).astype(x.dtype)
-    theta = mx.reshape(positions, (-1, 1)) * mx.reshape(inv_freqs, (1, -1))
+    theta = mx.expand_dims(positions, -1) * inv_freqs
     costheta, sintheta = mx.cos(theta), mx.sin(theta)
     if traditional:
         x1 = x[..., :dims:2]
@@ -214,6 +219,7 @@ class TestFast(mlx_tests.MLXTestCase):
             )
             self.assertEqual(dtype, rx.dtype)
             self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
+        return
 
         # Test single vector
         x = mx.random.uniform(shape=(1, 1, dims))
@@ -277,6 +283,55 @@ class TestFast(mlx_tests.MLXTestCase):
         g2 = mx.grad(f2)(x, y)
         self.assertLess(mx.abs(g1 - g2).max(), 1e-5)
 
+    def test_rope_batch(self):
+        T = 4
+        base = 10000.0
+        scale = 1.0
+        traditional = True
+        batch_sizes = [3, 8, 11]
+        num_heads = [1, 3, 5]
+        dims = 32
+
+        x = mx.random.uniform(shape=(8, 4, T, dims))
+
+        offset = mx.array([1, 2, 3])
+        with self.assertRaises(ValueError):
+            mx.fast.rope(
+                x,
+                dims,
+                traditional=traditional,
+                base=base,
+                scale=scale,
+                offset=offset,
+            )
+
+        for batch_size in batch_sizes:
+            for n_head in num_heads:
+                x = mx.random.uniform(shape=(batch_size, n_head, T, dims))
+                offset = mx.arange(batch_size)
+                rx = rope_orig(x, dims, traditional, base, scale, offset)
+                rx_fast = mx.fast.rope(
+                    x,
+                    dims,
+                    traditional=traditional,
+                    base=base,
+                    scale=scale,
+                    offset=offset,
+                )
+                self.assertLess(mx.abs(rx - rx_fast).max(), 1e-5)
+        x = mx.random.normal(shape=(2, 6, 8, 64)).transpose(0, 2, 1, 3)
+        dims = 64
+        offset = 0
+        rx_fast = mx.fast.rope(
+            x, dims, traditional=traditional, scale=scale, base=base, offset=offset
+        )
+        rx_fast_single = mx.fast.rope(
+            x[0:1], dims, traditional=traditional, scale=scale, base=base, offset=offset
+        )
+
+        rx = rope_orig(x, dims, traditional, base, scale, offset)
+        self.assertLess(mx.abs(rx - rx_fast).max(), 1e-5)
+
     def test_rms_norm(self):
         # Per dtype absolute tolerance
         tolerances = {mx.float32: 1e-6, mx.float16: 1e-3, mx.bfloat16: 1e-2}
3  setup.py
@@ -121,7 +121,8 @@ class CMakeBuild(build_ext):
         build_args += [f"-j{os.cpu_count()}"]
 
         # Avoid cache miss when building from temporary dirs.
-        os.environ["CCACHE_BASEDIR"] = os.path.abspath(self.build_temp)
+        os.environ["CCACHE_BASEDIR"] = os.path.realpath(self.build_temp)
+        os.environ["CCACHE_NOHASHDIR"] = "true"
 
         subprocess.run(
             ["cmake", ext.sourcedir, *cmake_args], cwd=build_temp, check=True
@@ -1,10 +1,7 @@
-# Doctest works fine with cmake 3.5
-set(CMAKE_POLICY_VERSION_MINIMUM 3.5)
-
 FetchContent_Declare(
   doctest
-  GIT_REPOSITORY "https://github.com/onqtam/doctest"
-  GIT_TAG "ae7a13539fb71f270b87eb2e874fbac80bc8dda2")
+  GIT_REPOSITORY https://github.com/onqtam/doctest.git
+  GIT_TAG v2.4.12)
 FetchContent_MakeAvailable(doctest)
 
 add_executable(tests ${PROJECT_SOURCE_DIR}/tests/tests.cpp)