Merge branch 'main' into stft

Author: Param Thakkar, 2025-04-23 00:02:52 +05:30 (committed via GitHub)
Commit: a963a15b8d
57 changed files with 3477 additions and 1014 deletions


@ -251,7 +251,7 @@ jobs:
name: Install Python package
command: |
source env/bin/activate
MACOSX_DEPLOYMENT_TARGET="" DEV_RELEASE=1 \
env -u MACOSX_DEPLOYMENT_TARGET DEV_RELEASE=1 \
CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
pip install . -v
- run:


@ -0,0 +1,74 @@
# Copyright © 2025 Apple Inc.
import mlx.core as mx
from time_utils import time_fn
N = 1024
D = 1024
M = 1024
E = 32
I = 4
def gather_sort(x, indices):
N, M = indices.shape
indices = indices.flatten()
order = mx.argsort(indices)
inv_order = mx.argsort(order)
return x.flatten(0, -3)[order // M], indices[order], inv_order
def scatter_unsort(x, inv_order, shape=None):
x = x[inv_order]
if shape is not None:
x = mx.unflatten(x, 0, shape)
return x
def gather_mm_simulate(x, w, indices):
x, idx, inv_order = gather_sort(x, indices)
for i in range(2):
y = mx.concatenate([x[i] @ w[j].T for i, j in enumerate(idx.tolist())], axis=0)
x = y[:, None]
x = scatter_unsort(x, inv_order, indices.shape)
return x
def time_gather_mm():
x = mx.random.normal((N, 1, 1, D)) / 1024**0.5
w1 = mx.random.normal((E, M, D)) / 1024**0.5
w2 = mx.random.normal((E, D, M)) / 1024**0.5
indices = (mx.random.uniform(shape=(N, I)) * E).astype(mx.uint32)
sorted_indices = mx.sort(indices.flatten()).reshape(N, I)
mx.eval(x, w1, w2, indices, sorted_indices)
def gather_mm(x, w1, w2, indices, sort):
idx = indices
inv_order = None
if sort:
x, idx, inv_order = gather_sort(x, indices)
x = mx.gather_mm(x, w1.swapaxes(-1, -2), rhs_indices=idx, sorted_indices=sort)
x = mx.gather_mm(x, w2.swapaxes(-1, -2), rhs_indices=idx, sorted_indices=sort)
if sort:
x = scatter_unsort(x, inv_order, indices.shape)
return x
time_fn(gather_mm, x, w1, w2, indices, False)
time_fn(gather_mm, x, w1, w2, sorted_indices, False)
time_fn(gather_mm, x, w1, w2, indices, True)
x = mx.random.normal((N * I, D)) / 1024**0.5
w1 = mx.random.normal((M, D)) / 1024**0.5
w2 = mx.random.normal((D, M)) / 1024**0.5
mx.eval(x, w1, w2)
def equivalent_matmul(x, w1, w2):
x = x @ w1.T
x = x @ w2.T
return x
time_fn(equivalent_matmul, x, w1, w2)
if __name__ == "__main__":
time_gather_mm()
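
A note on the sort/unsort trick used by gather_sort and scatter_unsort above: the argsort of a permutation's argsort is its inverse, so rows can be grouped by expert index for the matmul and restored to their original order afterwards. A minimal sketch of the round trip, with toy, illustrative values only:

import mlx.core as mx

# Six "tokens", each routed to one of three experts (hypothetical values).
indices = mx.array([2, 0, 1, 0, 2, 1], dtype=mx.uint32)
x = mx.arange(6)

order = mx.argsort(indices)       # groups rows that share an expert index
inv_order = mx.argsort(order)     # argsort of argsort gives the inverse permutation

x_sorted = x[order]               # rows grouped by expert
x_restored = x_sorted[inv_order]  # original order recovered
assert mx.array_equal(x_restored, x).item()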


@ -0,0 +1,84 @@
# Copyright © 2025 Apple Inc.
import mlx.core as mx
from time_utils import time_fn
N = 1024
D = 1024
M = 1024
E = 32
I = 4
def gather_sort(x, indices):
N, M = indices.shape
indices = indices.flatten()
order = mx.argsort(indices)
inv_order = mx.argsort(order)
return x.flatten(0, -3)[order // M], indices[order], inv_order
def scatter_unsort(x, inv_order, shape=None):
x = x[inv_order]
if shape is not None:
x = mx.unflatten(x, 0, shape)
return x
def gather_mm_simulate(x, w, indices):
x, idx, inv_order = gather_sort(x, indices)
for i in range(2):
y = mx.concatenate(
[
mx.quantized_matmul(x[i], w[0][j], w[1][j], w[2][j], transpose=True)
for i, j in enumerate(idx.tolist())
],
axis=0,
)
x = y[:, None]
x = scatter_unsort(x, inv_order, indices.shape)
return x
def time_gather_qmm():
x = mx.random.normal((N, 1, 1, D)) / 1024**0.5
w1 = mx.random.normal((E, M, D)) / 1024**0.5
w2 = mx.random.normal((E, D, M)) / 1024**0.5
w1 = mx.quantize(w1)
w2 = mx.quantize(w2)
indices = (mx.random.uniform(shape=(N, I)) * E).astype(mx.uint32)
sorted_indices = mx.sort(indices.flatten()).reshape(N, I)
mx.eval(x, w1, w2, indices, sorted_indices)
def gather_mm(x, w1, w2, indices, sort):
idx = indices
inv_order = None
if sort:
x, idx, inv_order = gather_sort(x, indices)
x = mx.gather_qmm(x, *w1, transpose=True, rhs_indices=idx, sorted_indices=sort)
x = mx.gather_qmm(x, *w2, transpose=True, rhs_indices=idx, sorted_indices=sort)
if sort:
x = scatter_unsort(x, inv_order, indices.shape)
return x
time_fn(gather_mm, x, w1, w2, indices, False)
time_fn(gather_mm, x, w1, w2, sorted_indices, False)
time_fn(gather_mm, x, w1, w2, indices, True)
x = mx.random.normal((N * I, D)) / 1024**0.5
w1 = mx.random.normal((M, D)) / 1024**0.5
w2 = mx.random.normal((D, M)) / 1024**0.5
w1 = mx.quantize(w1)
w2 = mx.quantize(w2)
mx.eval(x, w1, w2)
def equivalent_matmul(x, w1, w2):
x = mx.quantized_matmul(x, *w1, transpose=True)
x = mx.quantized_matmul(x, *w2, transpose=True)
return x
time_fn(equivalent_matmul, x, w1, w2)
if __name__ == "__main__":
time_gather_qmm()
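
In the quantized variant above, mx.quantize returns the packed weights together with their scales and biases, which is why w1 and w2 are stored as tuples and later unpacked with *w1 and *w2. A small sketch of the plain (non-gather) quantized matmul path, assuming the default group size and bit width:

import mlx.core as mx

x = mx.random.normal((8, 512))
w = mx.random.normal((256, 512))

w_q, scales, biases = mx.quantize(w)  # packed weights plus per-group quantization params
y = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)
print(y.shape)  # (8, 256)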


@ -38,6 +38,7 @@ Array
array.log10
array.log1p
array.log2
array.logcumsumexp
array.logsumexp
array.max
array.mean


@ -103,6 +103,7 @@ Operations
log10
log1p
logaddexp
logcumsumexp
logical_not
logical_and
logical_or

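Both documentation lists above gain the new logcumsumexp entry, a cumulative logsumexp along an axis. A hedged sketch of the expected behavior, assuming its keyword arguments mirror cumsum; the naive reference below is for comparison only and can overflow for large inputs:

import mlx.core as mx

x = mx.random.normal((4, 16))

y = mx.logcumsumexp(x, axis=-1)
y_ref = mx.log(mx.cumsum(mx.exp(x), axis=-1))  # naive, numerically unstable reference
print(mx.allclose(y, y_ref).item())
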

@ -5,6 +5,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp
${CMAKE_CURRENT_SOURCE_DIR}/dtype_utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/export.cpp
${CMAKE_CURRENT_SOURCE_DIR}/einsum.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp


@ -339,11 +339,11 @@ class array {
return allocator::allocator().size(buffer());
}
// Return a copy of the shared pointer
// to the array::Data struct
std::shared_ptr<Data> data_shared_ptr() const {
// Return the shared pointer to the array::Data struct
const std::shared_ptr<Data>& data_shared_ptr() const {
return array_desc_->data;
}
// Return a raw pointer to the array's data
template <typename T>
T* data() {


@ -1,6 +1,7 @@
target_sources(
mlx
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/broadcasting.cpp
${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
${CMAKE_CURRENT_SOURCE_DIR}/common.cpp
${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp


@ -0,0 +1,24 @@
// Copyright © 2024 Apple Inc.
#include "mlx/backend/common/utils.h"
namespace mlx::core {
void broadcast(const array& in, array& out) {
if (out.size() == 0) {
out.set_data(nullptr);
return;
}
Strides strides(out.ndim(), 0);
int diff = out.ndim() - in.ndim();
for (int i = in.ndim() - 1; i >= 0; --i) {
strides[i + diff] = (in.shape()[i] == 1) ? 0 : in.strides()[i];
}
auto flags = in.flags();
if (out.size() > in.size()) {
flags.row_contiguous = flags.col_contiguous = false;
}
out.copy_shared_buffer(in, strides, flags, in.data_size());
}
} // namespace mlx::core
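
The broadcast helper factored out into this new file builds a zero-copy view: output dimensions that come from a size-1 or missing input dimension get stride 0, and the contiguity flags are cleared whenever the output is larger than the input. A plain-Python sketch of the stride computation (function name and shapes are illustrative):

def broadcast_strides(in_shape, in_strides, out_shape):
    # Size-1 dims and missing leading dims contribute stride 0, i.e. they repeat.
    diff = len(out_shape) - len(in_shape)
    strides = [0] * len(out_shape)
    for i in range(len(in_shape) - 1, -1, -1):
        strides[i + diff] = 0 if in_shape[i] == 1 else in_strides[i]
    return strides

# A row-major (3, 1) input broadcast to (2, 3, 4) repeats along the new leading
# axis and the size-1 trailing axis without any copy.
print(broadcast_strides((3, 1), (1, 1), (2, 3, 4)))  # [0, 1, 0]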


@ -0,0 +1,11 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include "mlx/array.h"
namespace mlx::core {
void broadcast(const array& in, array& out);
} // namespace mlx::core


@ -1,6 +1,7 @@
// Copyright © 2024 Apple Inc.
#include <cassert>
#include "mlx/backend/common/broadcasting.h"
#include "mlx/backend/common/utils.h"
#include "mlx/primitives.h"
@ -42,23 +43,6 @@ void AsStrided::eval(const std::vector<array>& inputs, array& out) {
return out.copy_shared_buffer(in, strides_, flags, data_size, offset_);
}
void broadcast(const array& in, array& out) {
if (out.size() == 0) {
out.set_data(nullptr);
return;
}
Strides strides(out.ndim(), 0);
int diff = out.ndim() - in.ndim();
for (int i = in.ndim() - 1; i >= 0; --i) {
strides[i + diff] = (in.shape()[i] == 1) ? 0 : in.strides()[i];
}
auto flags = in.flags();
if (out.size() > in.size()) {
flags.row_contiguous = flags.col_contiguous = false;
}
out.copy_shared_buffer(in, strides, flags, in.data_size());
}
void Broadcast::eval(const std::vector<array>& inputs, array& out) {
broadcast(inputs[0], out);
}


@ -3,6 +3,7 @@
#include <cassert>
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/binary_ops.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/simd/simd.h"
@ -226,6 +227,16 @@ void scan_dispatch(
scan_op<T, U>(in, out, axis, reverse, inclusive, op, init);
break;
}
case Scan::LogAddExp: {
auto op = [](U a, T b) {
return detail::LogAddExp{}(a, static_cast<U>(b));
};
auto init = (issubdtype(in.dtype(), floating))
? static_cast<U>(-std::numeric_limits<float>::infinity())
: std::numeric_limits<U>::min();
scan_op<T, U>(in, out, axis, reverse, inclusive, op, init);
break;
}
}
}
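
The new Scan::LogAddExp case starts the accumulator at -inf for floating types, the identity of logaddexp, and folds one element at a time. A plain-Python sketch of the inclusive recurrence the dispatch implements (helper name is illustrative):

import math

def logcumsumexp_1d(xs):
    # Inclusive scan with the logaddexp identity (-inf) as the initial value.
    acc = -math.inf
    out = []
    for x in xs:
        hi, lo = max(acc, x), min(acc, x)
        acc = hi if lo == -math.inf else hi + math.log1p(math.exp(lo - hi))
        out.append(acc)
    return out

print(logcumsumexp_1d([0.0, 0.0, 0.0]))  # ~[0.0, log(2), log(3)]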


@ -61,6 +61,7 @@ if(MLX_METAL_JIT)
kernels/steel/gemm/transforms.h)
make_jit_source(steel/gemm/kernels/steel_gemm_fused)
make_jit_source(steel/gemm/kernels/steel_gemm_masked kernels/steel/defines.h)
make_jit_source(steel/gemm/kernels/steel_gemm_gather)
make_jit_source(steel/gemm/kernels/steel_gemm_splitk)
make_jit_source(
steel/conv/conv


@ -33,6 +33,7 @@ const char* gemm();
const char* steel_gemm_fused();
const char* steel_gemm_masked();
const char* steel_gemm_splitk();
const char* steel_gemm_gather();
const char* conv();
const char* steel_conv();
const char* steel_conv_general();


@ -584,6 +584,44 @@ MTL::ComputePipelineState* get_steel_gemm_masked_kernel(
return d.get_kernel(kernel_name, lib);
}
MTL::ComputePipelineState* get_steel_gemm_gather_kernel(
metal::Device& d,
const std::string& kernel_name,
const std::string& hash_name,
const metal::MTLFCList& func_consts,
const array& out,
bool transpose_a,
bool transpose_b,
int bm,
int bn,
int bk,
int wm,
int wn,
bool rhs) {
const auto& lib_name = kernel_name;
auto lib = d.get_library(lib_name, [&]() {
std::string kernel_source;
concatenate(
kernel_source,
metal::utils(),
metal::gemm(),
metal::steel_gemm_gather(),
get_template_definition(
lib_name,
rhs ? "gather_mm_rhs" : "gather_mm",
get_type_string(out.dtype()),
bm,
bn,
bk,
wm,
wn,
transpose_a,
transpose_b));
return kernel_source;
});
return d.get_kernel(kernel_name, lib, hash_name, func_consts);
}
MTL::ComputePipelineState* get_gemv_masked_kernel(
metal::Device& d,
const std::string& kernel_name,
@ -714,4 +752,43 @@ MTL::ComputePipelineState* get_quantized_kernel(
return d.get_kernel(kernel_name, lib);
}
MTL::ComputePipelineState* get_gather_qmm_kernel(
metal::Device& d,
const std::string& kernel_name,
const std::string& hash_name,
const metal::MTLFCList& func_consts,
const array& x,
int group_size,
int bits,
int bm,
int bn,
int bk,
int wm,
int wn,
bool transpose) {
const auto& lib_name = kernel_name;
auto lib = d.get_library(lib_name, [&]() {
std::string kernel_source;
concatenate(
kernel_source,
metal::utils(),
metal::gemm(),
metal::quantized(),
get_template_definition(
lib_name,
"gather_qmm_rhs",
get_type_string(x.dtype()),
group_size,
bits,
bm,
bn,
bk,
wm,
wn,
transpose));
return kernel_source;
});
return d.get_kernel(kernel_name, lib, hash_name, func_consts);
}
} // namespace mlx::core


@ -160,6 +160,21 @@ MTL::ComputePipelineState* get_steel_gemm_masked_kernel(
bool mn_aligned,
bool k_aligned);
MTL::ComputePipelineState* get_steel_gemm_gather_kernel(
metal::Device& d,
const std::string& kernel_name,
const std::string& hash_name,
const metal::MTLFCList& func_consts,
const array& out,
bool transpose_a,
bool transpose_b,
int bm,
int bn,
int bk,
int wm,
int wn,
bool rhs);
MTL::ComputePipelineState* get_steel_conv_kernel(
metal::Device& d,
const std::string& kernel_name,
@ -209,6 +224,21 @@ MTL::ComputePipelineState* get_quantized_kernel(
const std::string& kernel_name,
const std::string& template_def);
MTL::ComputePipelineState* get_gather_qmm_kernel(
metal::Device& d,
const std::string& kernel_name,
const std::string& hash_name,
const metal::MTLFCList& func_consts,
const array& x,
int group_size,
int bits,
int bm,
int bn,
int bk,
int wm,
int wn,
bool transpose);
// Create a GPU kernel template definition for JIT compilation
template <typename... Args>
std::string


@ -69,6 +69,7 @@ set(STEEL_HEADERS
steel/gemm/loader.h
steel/gemm/transforms.h
steel/gemm/kernels/steel_gemm_fused.h
steel/gemm/kernels/steel_gemm_gather.h
steel/gemm/kernels/steel_gemm_masked.h
steel/gemm/kernels/steel_gemm_splitk.h
steel/utils/type_traits.h
@ -116,6 +117,7 @@ if(NOT MLX_METAL_JIT)
build_kernel(steel/conv/kernels/steel_conv ${STEEL_HEADERS})
build_kernel(steel/conv/kernels/steel_conv_general ${STEEL_HEADERS})
build_kernel(steel/gemm/kernels/steel_gemm_fused ${STEEL_HEADERS})
build_kernel(steel/gemm/kernels/steel_gemm_gather ${STEEL_HEADERS})
build_kernel(steel/gemm/kernels/steel_gemm_masked ${STEEL_HEADERS})
build_kernel(steel/gemm/kernels/steel_gemm_splitk ${STEEL_HEADERS})
build_kernel(gemv_masked steel/utils.h)


@ -104,10 +104,22 @@ constexpr bool operator==(complex64_t a, complex64_t b) {
constexpr complex64_t operator+(complex64_t a, complex64_t b) {
return {a.real + b.real, a.imag + b.imag};
}
constexpr complex64_t operator+(float a, complex64_t b) {
return {a + b.real, b.imag};
}
constexpr complex64_t operator+(complex64_t a, float b) {
return {a.real + b, a.imag};
}
constexpr complex64_t operator-(complex64_t a, complex64_t b) {
return {a.real - b.real, a.imag - b.imag};
}
constexpr complex64_t operator-(float a, complex64_t b) {
return {a - b.real, -b.imag};
}
constexpr complex64_t operator-(complex64_t a, float b) {
return {a.real - b, a.imag};
}
constexpr complex64_t operator*(complex64_t a, complex64_t b) {
return {a.real * b.real - a.imag * b.imag, a.real * b.imag + a.imag * b.real};
@ -120,6 +132,13 @@ constexpr complex64_t operator/(complex64_t a, complex64_t b) {
return {x / denom, y / denom};
}
constexpr complex64_t operator/(float a, complex64_t b) {
auto denom = b.real * b.real + b.imag * b.imag;
auto x = a * b.real;
auto y = -a * b.imag;
return {x / denom, y / denom};
}
constexpr complex64_t operator%(complex64_t a, complex64_t b) {
auto real = a.real - (b.real * static_cast<int64_t>(a.real / b.real));
auto imag = a.imag - (b.imag * static_cast<int64_t>(a.imag / b.imag));

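The new mixed float/complex overloads above follow the usual identities; for example, float / complex is a·conj(b) / |b|². A quick plain-Python check of that formula against the built-in complex type (helper name is illustrative):

def div_float_complex(a, br, bi):
    # a / (br + bi*i) = a * conj(b) / |b|^2, matching the Metal overload above.
    denom = br * br + bi * bi
    return (a * br / denom, -a * bi / denom)

a, b = 3.0, complex(1.5, -2.0)
print(div_float_complex(a, b.real, b.imag))  # (0.72, 0.96)
print(a / b)                                 # (0.72+0.96j)
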

@ -3,6 +3,10 @@
#include <metal_simdgroup>
#include <metal_stdlib>
constant bool align_M [[function_constant(200)]];
constant bool align_N [[function_constant(201)]];
constant bool align_K [[function_constant(202)]];
using namespace metal;
#define MLX_MTL_CONST static constant constexpr const
@ -1686,26 +1690,26 @@ template <
}
template <typename T, int group_size, int bits>
[[kernel]] void bs_qmv_fast(
[[kernel]] void gather_qmv_fast(
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]],
const device T* x [[buffer(3)]],
device T* y [[buffer(4)]],
const constant int& in_vec_size [[buffer(5)]],
const constant int& out_vec_size [[buffer(6)]],
const constant int& x_batch_ndims [[buffer(7)]],
const constant int* x_shape [[buffer(8)]],
const constant int64_t* x_strides [[buffer(9)]],
const constant int& w_batch_ndims [[buffer(10)]],
const constant int* w_shape [[buffer(11)]],
const constant int64_t* w_strides [[buffer(12)]],
const constant int64_t* s_strides [[buffer(13)]],
const constant int64_t* b_strides [[buffer(14)]],
const constant int& batch_ndims [[buffer(15)]],
const constant int* batch_shape [[buffer(16)]],
const device uint32_t* lhs_indices [[buffer(17)]],
const device uint32_t* rhs_indices [[buffer(18)]],
const device uint32_t* lhs_indices [[buffer(4)]],
const device uint32_t* rhs_indices [[buffer(5)]],
device T* y [[buffer(6)]],
const constant int& in_vec_size [[buffer(7)]],
const constant int& out_vec_size [[buffer(8)]],
const constant int& x_batch_ndims [[buffer(9)]],
const constant int* x_shape [[buffer(10)]],
const constant int64_t* x_strides [[buffer(11)]],
const constant int& w_batch_ndims [[buffer(12)]],
const constant int* w_shape [[buffer(13)]],
const constant int64_t* w_strides [[buffer(14)]],
const constant int64_t* s_strides [[buffer(15)]],
const constant int64_t* b_strides [[buffer(16)]],
const constant int& batch_ndims [[buffer(17)]],
const constant int* batch_shape [[buffer(18)]],
const constant int64_t* lhs_strides [[buffer(19)]],
const constant int64_t* rhs_strides [[buffer(20)]],
uint3 tid [[threadgroup_position_in_grid]],
@ -1748,26 +1752,26 @@ template <typename T, int group_size, int bits>
}
template <typename T, int group_size, int bits>
[[kernel]] void bs_qmv(
[[kernel]] void gather_qmv(
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]],
const device T* x [[buffer(3)]],
device T* y [[buffer(4)]],
const constant int& in_vec_size [[buffer(5)]],
const constant int& out_vec_size [[buffer(6)]],
const constant int& x_batch_ndims [[buffer(7)]],
const constant int* x_shape [[buffer(8)]],
const constant int64_t* x_strides [[buffer(9)]],
const constant int& w_batch_ndims [[buffer(10)]],
const constant int* w_shape [[buffer(11)]],
const constant int64_t* w_strides [[buffer(12)]],
const constant int64_t* s_strides [[buffer(13)]],
const constant int64_t* b_strides [[buffer(14)]],
const constant int& batch_ndims [[buffer(15)]],
const constant int* batch_shape [[buffer(16)]],
const device uint32_t* lhs_indices [[buffer(17)]],
const device uint32_t* rhs_indices [[buffer(18)]],
const device uint32_t* lhs_indices [[buffer(4)]],
const device uint32_t* rhs_indices [[buffer(5)]],
device T* y [[buffer(6)]],
const constant int& in_vec_size [[buffer(7)]],
const constant int& out_vec_size [[buffer(8)]],
const constant int& x_batch_ndims [[buffer(9)]],
const constant int* x_shape [[buffer(10)]],
const constant int64_t* x_strides [[buffer(11)]],
const constant int& w_batch_ndims [[buffer(12)]],
const constant int* w_shape [[buffer(13)]],
const constant int64_t* w_strides [[buffer(14)]],
const constant int64_t* s_strides [[buffer(15)]],
const constant int64_t* b_strides [[buffer(16)]],
const constant int& batch_ndims [[buffer(17)]],
const constant int* batch_shape [[buffer(18)]],
const constant int64_t* lhs_strides [[buffer(19)]],
const constant int64_t* rhs_strides [[buffer(20)]],
uint3 tid [[threadgroup_position_in_grid]],
@ -1810,26 +1814,26 @@ template <typename T, int group_size, int bits>
}
template <typename T, int group_size, int bits>
[[kernel]] void bs_qvm(
[[kernel]] void gather_qvm(
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]],
const device T* x [[buffer(3)]],
device T* y [[buffer(4)]],
const constant int& in_vec_size [[buffer(5)]],
const constant int& out_vec_size [[buffer(6)]],
const constant int& x_batch_ndims [[buffer(7)]],
const constant int* x_shape [[buffer(8)]],
const constant int64_t* x_strides [[buffer(9)]],
const constant int& w_batch_ndims [[buffer(10)]],
const constant int* w_shape [[buffer(11)]],
const constant int64_t* w_strides [[buffer(12)]],
const constant int64_t* s_strides [[buffer(13)]],
const constant int64_t* b_strides [[buffer(14)]],
const constant int& batch_ndims [[buffer(15)]],
const constant int* batch_shape [[buffer(16)]],
const device uint32_t* lhs_indices [[buffer(17)]],
const device uint32_t* rhs_indices [[buffer(18)]],
const device uint32_t* lhs_indices [[buffer(4)]],
const device uint32_t* rhs_indices [[buffer(5)]],
device T* y [[buffer(6)]],
const constant int& in_vec_size [[buffer(7)]],
const constant int& out_vec_size [[buffer(8)]],
const constant int& x_batch_ndims [[buffer(9)]],
const constant int* x_shape [[buffer(10)]],
const constant int64_t* x_strides [[buffer(11)]],
const constant int& w_batch_ndims [[buffer(12)]],
const constant int* w_shape [[buffer(13)]],
const constant int64_t* w_strides [[buffer(14)]],
const constant int64_t* s_strides [[buffer(15)]],
const constant int64_t* b_strides [[buffer(16)]],
const constant int& batch_ndims [[buffer(17)]],
const constant int* batch_shape [[buffer(18)]],
const constant int64_t* lhs_strides [[buffer(19)]],
const constant int64_t* rhs_strides [[buffer(20)]],
uint3 tid [[threadgroup_position_in_grid]],
@ -1879,27 +1883,27 @@ template <
const int BM = 32,
const int BK = 32,
const int BN = 32>
[[kernel]] void bs_qmm_t(
[[kernel]] void gather_qmm_t(
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]],
const device T* x [[buffer(3)]],
device T* y [[buffer(4)]],
const constant int& K [[buffer(5)]],
const constant int& N [[buffer(6)]],
const constant int& M [[buffer(7)]],
const constant int& x_batch_ndims [[buffer(8)]],
const constant int* x_shape [[buffer(9)]],
const constant int64_t* x_strides [[buffer(10)]],
const constant int& w_batch_ndims [[buffer(11)]],
const constant int* w_shape [[buffer(12)]],
const constant int64_t* w_strides [[buffer(13)]],
const constant int64_t* s_strides [[buffer(14)]],
const constant int64_t* b_strides [[buffer(15)]],
const constant int& batch_ndims [[buffer(16)]],
const constant int* batch_shape [[buffer(17)]],
const device uint32_t* lhs_indices [[buffer(18)]],
const device uint32_t* rhs_indices [[buffer(19)]],
const device uint32_t* lhs_indices [[buffer(4)]],
const device uint32_t* rhs_indices [[buffer(5)]],
device T* y [[buffer(6)]],
const constant int& K [[buffer(7)]],
const constant int& N [[buffer(8)]],
const constant int& M [[buffer(9)]],
const constant int& x_batch_ndims [[buffer(10)]],
const constant int* x_shape [[buffer(11)]],
const constant int64_t* x_strides [[buffer(12)]],
const constant int& w_batch_ndims [[buffer(13)]],
const constant int* w_shape [[buffer(14)]],
const constant int64_t* w_strides [[buffer(15)]],
const constant int64_t* s_strides [[buffer(16)]],
const constant int64_t* b_strides [[buffer(17)]],
const constant int& batch_ndims [[buffer(18)]],
const constant int* batch_shape [[buffer(19)]],
const constant int64_t* lhs_strides [[buffer(20)]],
const constant int64_t* rhs_strides [[buffer(21)]],
uint3 tid [[threadgroup_position_in_grid]],
@ -1946,27 +1950,27 @@ template <
const int BM = 32,
const int BK = 32,
const int BN = 32>
[[kernel]] void bs_qmm_n(
[[kernel]] void gather_qmm_n(
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]],
const device T* x [[buffer(3)]],
device T* y [[buffer(4)]],
const constant int& K [[buffer(5)]],
const constant int& N [[buffer(6)]],
const constant int& M [[buffer(7)]],
const constant int& x_batch_ndims [[buffer(8)]],
const constant int* x_shape [[buffer(9)]],
const constant int64_t* x_strides [[buffer(10)]],
const constant int& w_batch_ndims [[buffer(11)]],
const constant int* w_shape [[buffer(12)]],
const constant int64_t* w_strides [[buffer(13)]],
const constant int64_t* s_strides [[buffer(14)]],
const constant int64_t* b_strides [[buffer(15)]],
const constant int& batch_ndims [[buffer(16)]],
const constant int* batch_shape [[buffer(17)]],
const device uint32_t* lhs_indices [[buffer(18)]],
const device uint32_t* rhs_indices [[buffer(19)]],
const device uint32_t* lhs_indices [[buffer(4)]],
const device uint32_t* rhs_indices [[buffer(5)]],
device T* y [[buffer(6)]],
const constant int& K [[buffer(7)]],
const constant int& N [[buffer(8)]],
const constant int& M [[buffer(9)]],
const constant int& x_batch_ndims [[buffer(10)]],
const constant int* x_shape [[buffer(11)]],
const constant int64_t* x_strides [[buffer(12)]],
const constant int& w_batch_ndims [[buffer(13)]],
const constant int* w_shape [[buffer(14)]],
const constant int64_t* w_strides [[buffer(15)]],
const constant int64_t* s_strides [[buffer(16)]],
const constant int64_t* b_strides [[buffer(17)]],
const constant int& batch_ndims [[buffer(18)]],
const constant int* batch_shape [[buffer(19)]],
const constant int64_t* lhs_strides [[buffer(20)]],
const constant int64_t* rhs_strides [[buffer(21)]],
uint3 tid [[threadgroup_position_in_grid]],
@ -2007,6 +2011,289 @@ template <
w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
}
template <typename T, typename mma_t, typename loader_a_t, typename loader_b_t>
METAL_FUNC void gemm_loop_aligned(
threadgroup T* As,
threadgroup T* Bs,
thread mma_t& mma_op,
thread loader_a_t& loader_a,
thread loader_b_t& loader_b,
const int k_iterations) {
for (int k = 0; k < k_iterations; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup memory
loader_a.load_unsafe();
loader_b.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
}
template <
bool rows_aligned,
bool cols_aligned,
bool transpose,
typename T,
typename mma_t,
typename loader_a_t,
typename loader_b_t>
METAL_FUNC void gemm_loop_unaligned(
threadgroup T* As,
threadgroup T* Bs,
thread mma_t& mma_op,
thread loader_a_t& loader_a,
thread loader_b_t& loader_b,
const int k_iterations,
const short tgp_bm,
const short tgp_bn,
const short tgp_bk) {
for (int k = 0; k < k_iterations; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup memory
if (rows_aligned) {
loader_a.load_unsafe();
} else {
loader_a.load_safe(short2(tgp_bk, tgp_bm));
}
if (cols_aligned) {
loader_b.load_unsafe();
} else {
loader_b.load_safe(
transpose ? short2(tgp_bk, tgp_bn) : short2(tgp_bn, tgp_bk));
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
}
template <typename T, typename mma_t, typename loader_a_t, typename loader_b_t>
METAL_FUNC void gemm_loop_finalize(
threadgroup T* As,
threadgroup T* Bs,
thread mma_t& mma_op,
thread loader_a_t& loader_a,
thread loader_b_t& loader_b,
const short2 tile_a,
const short2 tile_b) {
loader_a.load_safe(tile_a);
loader_b.load_safe(tile_b);
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_op.mma(As, Bs);
}
template <
typename T,
int group_size,
int bits,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose>
[[kernel]] void gather_qmm_rhs(
const device T* x [[buffer(0)]],
const device uint32_t* w [[buffer(1)]],
const device T* scales [[buffer(2)]],
const device T* biases [[buffer(3)]],
const device uint32_t* indices [[buffer(4)]],
device T* y [[buffer(5)]],
const constant int& M [[buffer(6)]],
const constant int& N [[buffer(7)]],
const constant int& K [[buffer(8)]],
uint3 tid [[threadgroup_position_in_grid]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]]) {
constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
constexpr int BK_padded = (BK + 16 / sizeof(T));
constexpr int BN_padded = (BN + 16 / sizeof(T));
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3;
using mma_t = mlx::steel::BlockMMA<
T,
T,
BM,
BN,
BK,
WM,
WN,
false,
transpose,
BK_padded,
transpose ? BK_padded : BN_padded>;
using loader_x_t =
mlx::steel::BlockLoader<T, BM, BK, BK_padded, 1, WM * WN * SIMD_SIZE>;
using loader_w_t = QuantizedBlockLoader<
T,
transpose ? BN : BK,
transpose ? BK : BN,
transpose ? BK_padded : BN_padded,
transpose,
WM * WN * SIMD_SIZE,
group_size,
bits>;
threadgroup T Xs[BM * BK_padded];
threadgroup T Ws[transpose ? BN * BK_padded : BK * BN_padded];
// Compute the block
const int K_w = K * bytes_per_pack / pack_factor;
const int K_g = K / group_size;
const int N_w = N * bytes_per_pack / pack_factor;
const int N_g = N / group_size;
const int K_it = K / BK;
const size_t stride_w = transpose ? N * K_w : K * N_w;
const size_t stride_s = transpose ? N * K_g : K * N_g;
const int y_row = tid.y * BM;
const int y_col = tid.x * BN;
const size_t y_row_long = size_t(y_row);
const size_t y_col_long = size_t(y_col);
// Prepare threadgroup bounds
const short tgp_bm = align_M ? BM : short(min(BM, M - y_row));
const short tgp_bn = align_N ? BN : short(min(BN, N - y_col));
// Calculate the final tiles in the case that K is not aligned
const int k_remain = K - K_it * BK;
const short2 tile_x = short2(k_remain, tgp_bm);
const short2 tile_w =
transpose ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);
// Move x and output to the correct block
auto wl = (const device uint8_t*)w;
x += y_row_long * K;
y += y_row_long * N + y_col_long;
wl += transpose ? y_col_long * K_w : y_col * bytes_per_pack / pack_factor;
scales += transpose ? y_col_long * K_g : y_col / group_size;
biases += transpose ? y_col_long * K_g : y_col / group_size;
// Do as many matmuls as necessary
uint32_t index;
short offset;
uint32_t index_next = indices[y_row];
short offset_next = 0;
int n = 0;
while (n < tgp_bm) {
n++;
offset = offset_next;
index = index_next;
offset_next = tgp_bm;
for (; n < tgp_bm; n++) {
if (indices[y_row + n] != index) {
offset_next = n;
index_next = indices[y_row + n];
break;
}
}
threadgroup_barrier(mem_flags::mem_none);
// Prepare threadgroup mma operation
thread mma_t mma_op(simd_group_id, simd_lane_id);
// Prepare threadgroup loading operations
thread loader_x_t loader_x(x, K, Xs, simd_group_id, simd_lane_id);
thread loader_w_t loader_w(
wl + index * stride_w,
scales + index * stride_s,
biases + index * stride_s,
transpose ? K : N,
Ws,
simd_group_id,
simd_lane_id);
// Matrices are fully aligned so no bounds checks are needed
if (align_M && align_N) {
gemm_loop_aligned(Xs, Ws, mma_op, loader_x, loader_w, K_it);
if (!align_K) {
threadgroup_barrier(mem_flags::mem_threadgroup);
gemm_loop_finalize(Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w);
}
// Store results to device memory
if (offset_next - offset == BM) {
mma_op.store_result(y, N);
} else {
mma_op.store_result_slice(
y, N, short2(0, offset), short2(BN, offset_next));
}
} else {
// Tile aligned so check outside of the hot loop
if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) {
gemm_loop_aligned(Xs, Ws, mma_op, loader_x, loader_w, K_it);
if (!align_K) {
threadgroup_barrier(mem_flags::mem_threadgroup);
gemm_loop_finalize(
Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w);
}
// Store results to device memory
if (offset_next - offset == BM) {
mma_op.store_result(y, N);
} else {
mma_op.store_result_slice(
y, N, short2(0, offset), short2(BN, offset_next));
}
}
// Tile partially aligned check rows
else if (align_N || tgp_bn == BN) {
gemm_loop_unaligned<false, true, transpose>(
Xs, Ws, mma_op, loader_x, loader_w, K_it, tgp_bm, tgp_bn, BK);
if (!align_K) {
threadgroup_barrier(mem_flags::mem_threadgroup);
gemm_loop_finalize(
Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w);
}
mma_op.store_result_slice(
y, N, short2(0, offset), short2(BN, offset_next));
}
// Tile partially aligned check cols
else if (align_M || tgp_bm == BM) {
gemm_loop_unaligned<true, false, transpose>(
Xs, Ws, mma_op, loader_x, loader_w, K_it, tgp_bm, tgp_bn, BK);
if (!align_K) {
threadgroup_barrier(mem_flags::mem_threadgroup);
gemm_loop_finalize(
Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w);
}
mma_op.store_result_slice(
y, N, short2(0, offset), short2(tgp_bn, offset_next));
}
// Nothing aligned so check both rows and cols
else {
gemm_loop_unaligned<false, false, transpose>(
Xs, Ws, mma_op, loader_x, loader_w, K_it, tgp_bm, tgp_bn, BK);
if (!align_K) {
threadgroup_barrier(mem_flags::mem_threadgroup);
gemm_loop_finalize(
Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w);
}
mma_op.store_result_slice(
y, N, short2(0, offset), short2(tgp_bn, offset_next));
}
}
}
}
template <typename T, const int group_size, const int bits>
[[kernel]] void affine_quantize(
const device T* w [[buffer(0)]],

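gather_qmm_rhs above assumes its row indices are sorted and walks each threadgroup tile in runs of equal expert index: offset and offset_next bracket the rows that share index, and each run is accumulated against that expert's quantized weights before the corresponding output slice is stored. A hedged Python sketch of the same run detection (helper name is illustrative):

def index_runs(indices):
    # Yield (start, stop, index) for each run of equal, consecutive indices.
    start = 0
    while start < len(indices):
        stop = start
        while stop < len(indices) and indices[stop] == indices[start]:
            stop += 1
        yield start, stop, indices[start]
        start = stop

print(list(index_runs([0, 0, 1, 1, 1, 3])))  # [(0, 2, 0), (2, 5, 1), (5, 6, 3)]
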

@ -60,6 +60,20 @@
bits, \
split_k)
#define instantiate_gather_qmm_rhs(func, name, type, group_size, bits, bm, bn, bk, wm, wn, transpose) \
instantiate_kernel( \
#name "_" #type "_gs_" #group_size "_b_" #bits "_bm_" #bm "_bn_" #bn "_bk_" #bk "_wm_" #wm "_wn_" #wn, \
func, \
type, \
group_size, \
bits, \
bm, \
bn, \
bk, \
wm, \
wn, \
transpose)
#define instantiate_quantized_batched_wrap(name, type, group_size, bits) \
instantiate_quantized_batched(name, type, group_size, bits, 1) \
instantiate_quantized_batched(name, type, group_size, bits, 0)
@ -73,14 +87,14 @@
#define instantiate_quantized_all_single(type, group_size, bits) \
instantiate_quantized(affine_quantize, type, group_size, bits) \
instantiate_quantized(affine_dequantize, type, group_size, bits) \
instantiate_quantized(bs_qmv_fast, type, group_size, bits) \
instantiate_quantized(bs_qmv, type, group_size, bits) \
instantiate_quantized(bs_qvm, type, group_size, bits) \
instantiate_quantized(bs_qmm_n, type, group_size, bits)
instantiate_quantized(gather_qmv_fast, type, group_size, bits) \
instantiate_quantized(gather_qmv, type, group_size, bits) \
instantiate_quantized(gather_qvm, type, group_size, bits) \
instantiate_quantized(gather_qmm_n, type, group_size, bits)
#define instantiate_quantized_all_aligned(type, group_size, bits) \
instantiate_quantized_aligned(bs_qmm_t, type, group_size, bits, true) \
instantiate_quantized_aligned(bs_qmm_t, type, group_size, bits, false) \
instantiate_quantized_aligned(gather_qmm_t, type, group_size, bits, true) \
instantiate_quantized_aligned(gather_qmm_t, type, group_size, bits, false) \
instantiate_quantized_aligned_batched(qmm_t, type, group_size, bits, true, 1) \
instantiate_quantized_aligned_batched(qmm_t, type, group_size, bits, true, 0) \
instantiate_quantized_aligned_batched(qmm_t, type, group_size, bits, false, 1) \
@ -96,12 +110,17 @@
instantiate_quantized_split_k(qvm_split_k, type, group_size, bits, 8) \
instantiate_quantized_split_k(qvm_split_k, type, group_size, bits, 32)
#define instantiate_quantized_all_rhs(type, group_size, bits) \
instantiate_gather_qmm_rhs(gather_qmm_rhs, gather_qmm_rhs_nt, type, group_size, bits, 16, 32, 32, 1, 2, true) \
instantiate_gather_qmm_rhs(gather_qmm_rhs, gather_qmm_rhs_nn, type, group_size, bits, 16, 32, 32, 1, 2, false)
#define instantiate_quantized_funcs(type, group_size, bits) \
instantiate_quantized_all_single(type, group_size, bits) \
instantiate_quantized_all_batched(type, group_size, bits) \
instantiate_quantized_all_aligned(type, group_size, bits) \
instantiate_quantized_all_quad(type, group_size, bits) \
instantiate_quantized_all_splitk(type, group_size, bits)
instantiate_quantized_all_splitk(type, group_size, bits) \
instantiate_quantized_all_rhs(type, group_size, bits)
#define instantiate_quantized_types(group_size, bits) \
instantiate_quantized_funcs(float, group_size, bits) \


@ -2,6 +2,8 @@
#pragma once
#include "mlx/backend/metal/kernels/binary_ops.h"
#define DEFINE_SIMD_SCAN() \
template <typename T, metal::enable_if_t<sizeof(T) < 8, bool> = true> \
T simd_scan(T val) { \
@ -139,6 +141,29 @@ struct CumMin {
}
};
template <typename U>
struct CumLogaddexp {
static constexpr constant U init = Limits<U>::min;
template <typename T>
U operator()(U a, T b) {
return LogAddExp{}(a, static_cast<U>(b));
}
U simd_scan(U x) {
for (int i = 1; i <= 16; i *= 2) {
U other = simd_shuffle_and_fill_up(x, init, i);
x = LogAddExp{}(x, other);
}
return x;
}
U simd_exclusive_scan(U x) {
x = simd_scan(x);
return simd_shuffle_and_fill_up(x, init, 1);
}
};
template <typename T, typename U, int N_READS, bool reverse>
inline void load_unsafe(U values[N_READS], const device T* input) {
if (reverse) {


@ -101,4 +101,7 @@ instantiate_scan_helper(min_int64_int64, int64_t, int64_t, CumMi
instantiate_scan_helper(min_float16_float16, half, half, CumMin, 4)
instantiate_scan_helper(min_float32_float32, float, float, CumMin, 4)
instantiate_scan_helper(min_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumMin, 4)
instantiate_scan_helper(min_complex64_complex64, complex64_t, complex64_t, CumMin, 2) // clang-format on
instantiate_scan_helper(min_complex64_complex64, complex64_t, complex64_t, CumMin, 2)
instantiate_scan_helper(logaddexp_float16_float16, half, half, CumLogaddexp, 4)
instantiate_scan_helper(logaddexp_float32_float32, float, float, CumLogaddexp, 4)
instantiate_scan_helper(logaddexp_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumLogaddexp, 4) // clang-format on


@ -15,10 +15,6 @@ constant bool align_M [[function_constant(200)]];
constant bool align_N [[function_constant(201)]];
constant bool align_K [[function_constant(202)]];
constant bool do_gather [[function_constant(300)]];
constant bool gather_bias = do_gather && use_out_source;
// clang-format off
template <
typename T,
@ -39,12 +35,6 @@ template <
const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]],
const constant int* batch_shape [[buffer(6)]],
const constant int64_t* batch_strides [[buffer(7)]],
const constant uint32_t* lhs_indices [[buffer(10), function_constant(do_gather)]],
const constant uint32_t* rhs_indices [[buffer(11), function_constant(do_gather)]],
const constant uint32_t* C_indices [[buffer(12), function_constant(gather_bias)]],
const constant int* operand_shape [[buffer(13), function_constant(do_gather)]],
const constant int64_t* operand_strides [[buffer(14), function_constant(do_gather)]],
const constant packed_int3& operand_batch_ndim [[buffer(15), function_constant(do_gather)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
@ -81,63 +71,6 @@ template <
}
// Adjust for batch
// Handle gather
if (do_gather) {
// Read indices
uint32_t indx_A, indx_B, indx_C;
if (has_batch) {
const constant auto* indx_A_bstrides = batch_strides;
const constant auto* indx_B_bstrides = batch_strides + params->batch_ndim;
ulong2 indx_offsets = elem_to_loc_broadcast(
tid.z,
batch_shape,
indx_A_bstrides,
indx_B_bstrides,
params->batch_ndim);
indx_A = lhs_indices[indx_offsets.x];
indx_B = rhs_indices[indx_offsets.y];
if (use_out_source) {
const constant auto* indx_C_bstrides =
indx_B_bstrides + params->batch_ndim;
auto indx_offset_C = elem_to_loc(
tid.z, batch_shape, indx_C_bstrides, params->batch_ndim);
indx_C = C_indices[indx_offset_C];
}
} else {
indx_A = lhs_indices[params->batch_stride_a * tid.z];
indx_B = rhs_indices[params->batch_stride_b * tid.z];
if (use_out_source) {
indx_C = C_indices[addmm_params->batch_stride_c * tid.z];
}
}
// Translate indices to offsets
int batch_ndim_A = operand_batch_ndim.x;
const constant int* batch_shape_A = operand_shape;
const constant auto* batch_strides_A = operand_strides;
A += elem_to_loc(indx_A, batch_shape_A, batch_strides_A, batch_ndim_A);
int batch_ndim_B = operand_batch_ndim.y;
const constant int* batch_shape_B = batch_shape_A + batch_ndim_A;
const constant auto* batch_strides_B = batch_strides_A + batch_ndim_A;
B += elem_to_loc(indx_B, batch_shape_B, batch_strides_B, batch_ndim_B);
if (use_out_source) {
int batch_ndim_C = operand_batch_ndim.z;
const constant int* batch_shape_C = batch_shape_B + batch_ndim_B;
const constant auto* batch_strides_C = batch_strides_B + batch_ndim_B;
C += elem_to_loc(indx_C, batch_shape_C, batch_strides_C, batch_ndim_C);
}
}
// Handle regular batch
else {
if (has_batch) {
const constant auto* A_bstrides = batch_strides;
const constant auto* B_bstrides = batch_strides + params->batch_ndim;
@ -160,7 +93,6 @@ template <
C += addmm_params->batch_stride_c * tid.z;
}
}
}
D += params->batch_stride_d * tid.z;


@ -0,0 +1,459 @@
// Copyright © 2024 Apple Inc.
using namespace mlx::steel;
constant bool has_batch [[function_constant(10)]];
constant bool align_M [[function_constant(200)]];
constant bool align_N [[function_constant(201)]];
constant bool align_K [[function_constant(202)]];
template <
typename T,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_a,
bool transpose_b,
typename AccumType = float>
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void gather_mm_rhs(
const device T* A [[buffer(0)]],
const device T* B [[buffer(1)]],
const device uint32_t* rhs_indices [[buffer(2)]],
device T* C [[buffer(3)]],
const constant GEMMParams* params [[buffer(4)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]]) {
using gemm_kernel = GEMMKernel<
T,
T,
BM,
BN,
BK,
WM,
WN,
transpose_a,
transpose_b,
true,
true,
AccumType>;
using loader_a_t = typename gemm_kernel::loader_a_t;
using loader_b_t = typename gemm_kernel::loader_b_t;
using mma_t = typename gemm_kernel::mma_t;
if (params->tiles_n <= static_cast<int>(tid.x) ||
params->tiles_m <= static_cast<int>(tid.y)) {
return;
}
// Prepare threadgroup memory
threadgroup T As[gemm_kernel::tgp_mem_size_a];
threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
// Find the block in A, B, C
const int c_row = tid.y * BM;
const int c_col = tid.x * BN;
const size_t c_row_long = size_t(c_row);
const size_t c_col_long = size_t(c_col);
// Prepare threadgroup bounds
const short tgp_bm = align_M ? BM : short(min(BM, params->M - c_row));
const short tgp_bn = align_N ? BN : short(min(BN, params->N - c_col));
A += transpose_a ? c_row_long : c_row_long * params->lda;
B += transpose_b ? c_col_long * params->ldb : c_col_long;
C += c_row_long * params->ldd + c_col_long;
// Do as many matmuls as necessary
uint32_t index;
short offset;
uint32_t index_next = rhs_indices[c_row];
short offset_next = 0;
int n = 0;
while (n < tgp_bm) {
n++;
offset = offset_next;
index = index_next;
offset_next = tgp_bm;
for (; n < tgp_bm; n++) {
if (rhs_indices[c_row + n] != index) {
offset_next = n;
index_next = rhs_indices[c_row + n];
break;
}
}
threadgroup_barrier(mem_flags::mem_none);
// Prepare threadgroup mma operation
thread mma_t mma_op(simd_group_id, simd_lane_id);
// Prepare threadgroup loading operations
thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
thread loader_b_t loader_b(
B + index * params->batch_stride_b,
params->ldb,
Bs,
simd_group_id,
simd_lane_id);
// Prepare iterations
const int gemm_k_iterations = params->gemm_k_iterations_aligned;
// Do unaligned K iterations first
if (!align_K) {
const int k_last = params->gemm_k_iterations_aligned * BK;
const int k_remain = params->K - k_last;
const size_t k_jump_a =
transpose_a ? params->lda * size_t(k_last) : size_t(k_last);
const size_t k_jump_b =
transpose_b ? size_t(k_last) : params->ldb * size_t(k_last);
// Move loader source ahead to end
loader_a.src += k_jump_a;
loader_b.src += k_jump_b;
// Load tile
const short2 tile_dims_A =
transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
const short2 tile_dims_B =
transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);
loader_a.load_safe(tile_dims_A);
loader_b.load_safe(tile_dims_B);
threadgroup_barrier(mem_flags::mem_threadgroup);
// Do matmul
mma_op.mma(As, Bs);
// Reset source back to start
loader_a.src -= k_jump_a;
loader_b.src -= k_jump_b;
}
// Aligned at the matrix level so bounds checks are never needed
if (align_M && align_N) {
for (int k = 0; k < gemm_k_iterations; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup
loader_a.load_unsafe();
loader_b.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
// Store results to device memory
if (offset_next - offset == BM) {
mma_op.store_result(C, params->ldd);
} else {
mma_op.store_result_slice(
C, params->ldd, short2(0, offset), short2(BN, offset_next));
}
} else {
const short lbk = 0;
// Tile is aligned so no bounds checks are needed
if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
lbk,
LoopAlignment<true, true, true>{});
if (offset_next - offset == BM) {
mma_op.store_result(C, params->ldd);
} else {
mma_op.store_result_slice(
C, params->ldd, short2(0, offset), short2(BN, offset_next));
}
}
// Tile partially aligned check rows
else if (align_N || tgp_bn == BN) {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
lbk,
LoopAlignment<false, true, true>{});
mma_op.store_result_slice(
C, params->ldd, short2(0, offset), short2(BN, offset_next));
}
// Tile partially aligned check cols
else if (align_M || tgp_bm == BM) {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
lbk,
LoopAlignment<true, false, true>{});
mma_op.store_result_slice(
C, params->ldd, short2(0, offset), short2(tgp_bn, offset_next));
}
// Nothing aligned so check both rows and cols
else {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
lbk,
LoopAlignment<false, false, true>{});
mma_op.store_result_slice(
C, params->ldd, short2(0, offset), short2(tgp_bn, offset_next));
}
}
}
}
template <
typename T,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_a,
bool transpose_b,
typename AccumType = float>
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void gather_mm(
const device T* A [[buffer(0)]],
const device T* B [[buffer(1)]],
const device uint32_t* lhs_indices [[buffer(2)]],
const device uint32_t* rhs_indices [[buffer(3)]],
device T* C [[buffer(4)]],
const constant GEMMParams* params [[buffer(5)]],
const constant int* indices_shape [[buffer(6)]],
const constant int64_t* lhs_strides [[buffer(7)]],
const constant int64_t* rhs_strides [[buffer(8)]],
const constant int& batch_ndim_a [[buffer(9)]],
const constant int* batch_shape_a [[buffer(10)]],
const constant int64_t* batch_strides_a [[buffer(11)]],
const constant int& batch_ndim_b [[buffer(12)]],
const constant int* batch_shape_b [[buffer(13)]],
const constant int64_t* batch_strides_b [[buffer(14)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]]) {
using gemm_kernel = GEMMKernel<
T,
T,
BM,
BN,
BK,
WM,
WN,
transpose_a,
transpose_b,
true,
true,
AccumType>;
using loader_a_t = typename gemm_kernel::loader_a_t;
using loader_b_t = typename gemm_kernel::loader_b_t;
using mma_t = typename gemm_kernel::mma_t;
if (params->tiles_n <= static_cast<int>(tid.x) ||
params->tiles_m <= static_cast<int>(tid.y)) {
return;
}
// Move A and B to the locations pointed by lhs_indices and rhs_indices.
uint32_t indx_A, indx_B;
if (has_batch) {
ulong2 indices_offsets = elem_to_loc_broadcast(
tid.z, indices_shape, lhs_strides, rhs_strides, params->batch_ndim);
indx_A = lhs_indices[indices_offsets.x];
indx_B = rhs_indices[indices_offsets.y];
} else {
indx_A = lhs_indices[params->batch_stride_a * tid.z];
indx_B = rhs_indices[params->batch_stride_b * tid.z];
}
A += elem_to_loc(indx_A, batch_shape_a, batch_strides_a, batch_ndim_a);
B += elem_to_loc(indx_B, batch_shape_b, batch_strides_b, batch_ndim_b);
C += params->batch_stride_d * tid.z;
// Prepare threadgroup memory
threadgroup T As[gemm_kernel::tgp_mem_size_a];
threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
// Just make sure everybody's finished with the indexing math above.
threadgroup_barrier(mem_flags::mem_none);
// Find block in A, B, C
const int c_row = tid.y * BM;
const int c_col = tid.x * BN;
const size_t c_row_long = size_t(c_row);
const size_t c_col_long = size_t(c_col);
A += transpose_a ? c_row_long : c_row_long * params->lda;
B += transpose_b ? c_col_long * params->ldb : c_col_long;
C += c_row_long * params->ldd + c_col_long;
// Prepare threadgroup mma operation
thread mma_t mma_op(simd_group_id, simd_lane_id);
// Prepare threadgroup loading operations
thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);
// Prepare threadgroup bounds
const short tgp_bm = align_M ? BM : short(min(BM, params->M - c_row));
const short tgp_bn = align_N ? BN : short(min(BN, params->N - c_col));
// Prepare iterations
int gemm_k_iterations = params->gemm_k_iterations_aligned;
// Do unaligned K iterations first
if (!align_K) {
const int k_last = params->gemm_k_iterations_aligned * BK;
const int k_remain = params->K - k_last;
const size_t k_jump_a =
transpose_a ? params->lda * size_t(k_last) : size_t(k_last);
const size_t k_jump_b =
transpose_b ? size_t(k_last) : params->ldb * size_t(k_last);
// Move loader source ahead to end
loader_a.src += k_jump_a;
loader_b.src += k_jump_b;
// Load tile
const short2 tile_dims_A =
transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
const short2 tile_dims_B =
transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);
loader_a.load_safe(tile_dims_A);
loader_b.load_safe(tile_dims_B);
threadgroup_barrier(mem_flags::mem_threadgroup);
// Do matmul
mma_op.mma(As, Bs);
// Reset source back to start
loader_a.src -= k_jump_a;
loader_b.src -= k_jump_b;
}
// Aligned at the matrix level so bounds checks are never needed
if (align_M && align_N) {
for (int k = 0; k < gemm_k_iterations; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup
loader_a.load_unsafe();
loader_b.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
// Store results to device memory
mma_op.store_result(C, params->ldd);
} else {
const short lbk = 0;
// Tile is aligned so no bounds checks are needed
if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
lbk,
LoopAlignment<true, true, true>{});
mma_op.store_result(C, params->ldd);
}
// Tile partially aligned check rows
else if (align_N || tgp_bn == BN) {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
lbk,
LoopAlignment<false, true, true>{});
mma_op.store_result_safe(C, params->ldd, short2(tgp_bn, tgp_bm));
}
// Tile partially aligned check cols
else if (align_M || tgp_bm == BM) {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
lbk,
LoopAlignment<true, false, true>{});
mma_op.store_result_safe(C, params->ldd, short2(tgp_bn, tgp_bm));
}
// Nothing aligned so check both rows and cols
else {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
lbk,
LoopAlignment<false, false, true>{});
mma_op.store_result_safe(C, params->ldd, short2(tgp_bn, tgp_bm));
}
}
}


@ -0,0 +1,59 @@
// Copyright © 2024 Apple Inc.
// clang-format off
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
#include "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h"
#define instantiate_gather_mm_rhs(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_kernel( \
"steel_gather_mm_rhs_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn \
"_bk" #bk "_wm" #wm "_wn" #wn, \
gather_mm_rhs, \
itype, \
bm, \
bn, \
bk, \
wm, \
wn, \
trans_a, \
trans_b, \
float)
#define instantiate_gather_mm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_kernel( \
"steel_gather_mm_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn \
"_bk" #bk "_wm" #wm "_wn" #wn, \
gather_mm, \
itype, \
bm, \
bn, \
bk, \
wm, \
wn, \
trans_a, \
trans_b, \
float)
#define instantiate_gather_mm_rhs_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gather_mm_rhs(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gather_mm_rhs(nt, false, true, iname, itype, oname, otype, bm, bn, bk, wm, wn)
#define instantiate_gather_mm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gather_mm(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gather_mm(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gather_mm(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gather_mm(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn)
#define instantiate_gather_mm_shapes_helper(iname, itype, oname, otype) \
instantiate_gather_mm_rhs_transpose_helper(iname, itype, oname, otype, 16, 64, 16, 1, 2) \
instantiate_gather_mm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) \
instantiate_gather_mm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 1, 2) \
instantiate_gather_mm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \
instantiate_gather_mm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 1, 2) \
instantiate_gather_mm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2)
// clang-format on
instantiate_gather_mm_shapes_helper(float16, half, float16, half);
instantiate_gather_mm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t);
instantiate_gather_mm_shapes_helper(float32, float, float32, float);


@ -142,6 +142,42 @@ struct BaseMMAFrag<T, 8, 8> {
}
}
template <
typename DstPtrType,
typename StrX,
typename StrY,
typename StartX,
typename StopX,
typename StartY,
typename StopY,
typename OffX,
typename OffY>
METAL_FUNC static constexpr void store_slice(
const thread frag_type& src,
DstPtrType dst,
StrX str_x,
StrY str_y,
StartX start_x,
StopX stop_x,
StartY start_y,
StopY stop_y,
OffX off_x = Int<0>{},
OffY off_y = Int<0>{}) {
using U = pointer_element_t<DstPtrType>;
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kElemRows; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < kElemCols; j++) {
if ((off_x + i) < stop_x && (off_x + i) >= start_x &&
(off_y + j) < stop_y && (off_y + j) >= start_y) {
dst[(off_x + i) * str_x + (off_y + j) * str_y] =
static_cast<U>(src[i * kElemCols + j]);
}
}
}
}
METAL_FUNC static constexpr void mma(
thread frag_type& D,
thread frag_type& A,
@ -335,6 +371,31 @@ struct MMATile {
}
}
}
template <typename U, int w_x, int w_y>
METAL_FUNC void store_slice(
device U* dst,
const int ld,
const short2 start,
const short2 stop) const {
STEEL_PRAGMA_UNROLL
for (int i = 0; i < kTileRows; ++i) {
STEEL_PRAGMA_UNROLL
for (int j = 0; j < kTileCols; ++j) {
MMAFrag_t::store_slice(
frag_at(i, j),
dst,
ld,
Int<1>{},
start.y,
stop.y,
start.x,
stop.x,
(i * kFragRows) * w_x,
(j * kFragCols) * w_y);
}
}
}
};
template <typename T, typename U, int M, int N, int K>
@ -474,6 +535,26 @@ struct BlockMMA {
Ctile.template store<U, WM, WN>(D, ldd);
}
METAL_FUNC void
store_result_slice(device U* D, const int ldd, short2 start, short2 stop) {
// Apply epilogue
STEEL_PRAGMA_UNROLL
for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) {
Ctile.elems()[i] = Epilogue::apply(Ctile.elems()[i]);
}
D += sm * ldd + sn;
start -= short2(sn, sm);
stop -= short2(sn, sm);
// TODO: Check the start as well
if (stop.y <= 0 || stop.x <= 0) {
return;
}
Ctile.template store_slice<U, WM, WN>(D, ldd, start, stop);
}
METAL_FUNC void
store_result_safe(device U* D, const int ldd, short2 dst_tile_dims) {
// Apply epilogue


@ -69,6 +69,9 @@ instantiate_unary_float(Round)
instantiate_unary_int(BitwiseInvert)
instantiate_unary_all_same(Abs, complex64, complex64_t)
instantiate_unary_all_same(ArcCos, complex64, complex64_t)
instantiate_unary_all_same(ArcSin, complex64, complex64_t)
instantiate_unary_all_same(ArcTan, complex64, complex64_t)
instantiate_unary_all_same(Conjugate, complex64, complex64_t)
instantiate_unary_all_same(Cos, complex64, complex64_t)
instantiate_unary_all_same(Cosh, complex64, complex64_t)
@ -80,6 +83,9 @@ instantiate_unary_all_same(Negative, complex64, complex64_t)
instantiate_unary_all_same(Sign, complex64, complex64_t)
instantiate_unary_all_same(Sin, complex64, complex64_t)
instantiate_unary_all_same(Sinh, complex64, complex64_t)
instantiate_unary_all_same(Square, complex64, complex64_t)
instantiate_unary_all_same(Sqrt, complex64, complex64_t)
instantiate_unary_all_same(Rsqrt, complex64, complex64_t)
instantiate_unary_all_same(Tan, complex64, complex64_t)
instantiate_unary_all_same(Tanh, complex64, complex64_t)
instantiate_unary_all_same(Round, complex64, complex64_t)

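The added instantiations give several unary kernels complex64 support (arccos, arcsin, arctan, square, sqrt, and rsqrt), whose complex implementations appear in the next file's diff. A quick plain-Python check of the principal square root, which the Metal implementation later in this diff computes from the half-angle identities:

import cmath

def complex_sqrt(z):
    # Principal sqrt: a = sqrt((|z| + Re z) / 2), b = copysign(sqrt((|z| - Re z) / 2), Im z).
    r = abs(z)
    a = ((r + z.real) / 2) ** 0.5
    b = ((r - z.real) / 2) ** 0.5
    return complex(a, b if z.imag >= 0 else -b)

z = complex(-3.0, 4.0)
print(complex_sqrt(z))  # (1+2j)
print(cmath.sqrt(z))    # reference
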

@ -17,27 +17,21 @@ struct Abs {
T operator()(T x) {
return metal::abs(x);
};
template <>
uint8_t operator()(uint8_t x) {
return x;
};
template <>
uint16_t operator()(uint16_t x) {
return x;
};
template <>
uint32_t operator()(uint32_t x) {
return x;
};
template <>
uint64_t operator()(uint64_t x) {
return x;
};
template <>
bool operator()(bool x) {
return x;
};
template <>
complex64_t operator()(complex64_t x) {
return {metal::precise::sqrt(x.real * x.real + x.imag * x.imag), 0};
};
@ -48,6 +42,8 @@ struct ArcCos {
T operator()(T x) {
return metal::precise::acos(x);
};
complex64_t operator()(complex64_t x);
};
struct ArcCosh {
@ -62,6 +58,8 @@ struct ArcSin {
T operator()(T x) {
return metal::precise::asin(x);
};
complex64_t operator()(complex64_t x);
};
struct ArcSinh {
@ -76,6 +74,8 @@ struct ArcTan {
T operator()(T x) {
return metal::precise::atan(x);
};
complex64_t operator()(complex64_t x);
};
struct ArcTanh {
@ -97,39 +97,30 @@ struct Ceil {
T operator()(T x) {
return metal::ceil(x);
};
template <>
int8_t operator()(int8_t x) {
return x;
};
template <>
int16_t operator()(int16_t x) {
return x;
};
template <>
int32_t operator()(int32_t x) {
return x;
};
template <>
int64_t operator()(int64_t x) {
return x;
};
template <>
uint8_t operator()(uint8_t x) {
return x;
};
template <>
uint16_t operator()(uint16_t x) {
return x;
};
template <>
uint32_t operator()(uint32_t x) {
return x;
};
template <>
uint64_t operator()(uint64_t x) {
return x;
};
template <>
bool operator()(bool x) {
return x;
};
@ -141,7 +132,6 @@ struct Cos {
return metal::precise::cos(x);
};
template <>
complex64_t operator()(complex64_t x) {
return {
metal::precise::cos(x.real) * metal::precise::cosh(x.imag),
@ -155,7 +145,6 @@ struct Cosh {
return metal::precise::cosh(x);
};
template <>
complex64_t operator()(complex64_t x) {
return {
metal::precise::cosh(x.real) * metal::precise::cos(x.imag),
@ -188,7 +177,6 @@ struct Exp {
T operator()(T x) {
return metal::precise::exp(x);
};
template <>
complex64_t operator()(complex64_t x) {
auto m = metal::precise::exp(x.real);
return {m * metal::precise::cos(x.imag), m * metal::precise::sin(x.imag)};
@ -207,39 +195,30 @@ struct Floor {
T operator()(T x) {
return metal::floor(x);
};
template <>
int8_t operator()(int8_t x) {
return x;
};
template <>
int16_t operator()(int16_t x) {
return x;
};
template <>
int32_t operator()(int32_t x) {
return x;
};
template <>
int64_t operator()(int64_t x) {
return x;
};
template <>
uint8_t operator()(uint8_t x) {
return x;
};
template <>
uint16_t operator()(uint16_t x) {
return x;
};
template <>
uint32_t operator()(uint32_t x) {
return x;
};
template <>
uint64_t operator()(uint64_t x) {
return x;
};
template <>
bool operator()(bool x) {
return x;
};
@ -258,7 +237,6 @@ struct Log {
return metal::precise::log(x);
};
template <>
complex64_t operator()(complex64_t x) {
auto r = metal::precise::log(Abs{}(x).real);
auto i = metal::precise::atan2(x.imag, x.real);
@ -272,7 +250,6 @@ struct Log2 {
return metal::precise::log2(x);
};
template <>
complex64_t operator()(complex64_t x) {
auto y = Log{}(x);
return {y.real / M_LN2_F, y.imag / M_LN2_F};
@ -285,7 +262,6 @@ struct Log10 {
return metal::precise::log10(x);
};
template <>
complex64_t operator()(complex64_t x) {
auto y = Log{}(x);
return {y.real / M_LN10_F, y.imag / M_LN10_F};
@ -325,7 +301,6 @@ struct Round {
T operator()(T x) {
return metal::rint(x);
};
template <>
complex64_t operator()(complex64_t x) {
return {metal::rint(x.real), metal::rint(x.imag)};
};
@ -344,11 +319,9 @@ struct Sign {
T operator()(T x) {
return (x > T(0)) - (x < T(0));
};
template <>
uint32_t operator()(uint32_t x) {
return x != 0;
};
template <>
complex64_t operator()(complex64_t x) {
if (x == complex64_t(0)) {
return x;
@ -364,7 +337,6 @@ struct Sin {
return metal::precise::sin(x);
};
template <>
complex64_t operator()(complex64_t x) {
return {
metal::precise::sin(x.real) * metal::precise::cosh(x.imag),
@ -378,7 +350,6 @@ struct Sinh {
return metal::precise::sinh(x);
};
template <>
complex64_t operator()(complex64_t x) {
return {
metal::precise::sinh(x.real) * metal::precise::cos(x.imag),
@ -398,6 +369,17 @@ struct Sqrt {
T operator()(T x) {
return metal::precise::sqrt(x);
};
complex64_t operator()(complex64_t x) {
if (x.real == 0.0 && x.imag == 0.0) {
return {0.0, 0.0};
}
auto r = Abs{}(x).real;
auto a = metal::precise::sqrt((r + x.real) / 2.0);
auto b_abs = metal::precise::sqrt((r - x.real) / 2.0);
auto b = metal::copysign(b_abs, x.imag);
return {a, b};
}
};
struct Rsqrt {
@ -405,6 +387,10 @@ struct Rsqrt {
T operator()(T x) {
return metal::precise::rsqrt(x);
};
complex64_t operator()(complex64_t x) {
return 1.0 / Sqrt{}(x);
}
};
struct Tan {
@ -413,7 +399,6 @@ struct Tan {
return metal::precise::tan(x);
};
template <>
complex64_t operator()(complex64_t x) {
float tan_a = metal::precise::tan(x.real);
float tanh_b = metal::precise::tanh(x.imag);
@ -429,7 +414,6 @@ struct Tanh {
return metal::precise::tanh(x);
};
template <>
complex64_t operator()(complex64_t x) {
float tanh_a = metal::precise::tanh(x.real);
float tan_b = metal::precise::tan(x.imag);
@ -438,3 +422,21 @@ struct Tanh {
return {(tanh_a + tan_b * t1) / denom, (tan_b - tanh_a * t1) / denom};
};
};
complex64_t ArcCos::operator()(complex64_t x) {
auto i = complex64_t{0.0, 1.0};
auto y = Log{}(x + i * Sqrt{}(1.0 - x * x));
return {y.imag, -y.real};
};
complex64_t ArcSin::operator()(complex64_t x) {
auto i = complex64_t{0.0, 1.0};
auto y = Log{}(i * x + Sqrt{}(1.0 - x * x));
return {y.imag, -y.real};
};
complex64_t ArcTan::operator()(complex64_t x) {
auto i = complex64_t{0.0, 1.0};
auto ix = i * x;
return (1.0 / complex64_t{0.0, 2.0}) * Log{}((1.0 + ix) / (1.0 - ix));
};
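For reference, the complex specializations above follow the standard principal-branch identities, with r = |z|:

\sqrt{z} = \sqrt{\tfrac{r + \operatorname{Re} z}{2}} + i\,\operatorname{sign}(\operatorname{Im} z)\sqrt{\tfrac{r - \operatorname{Re} z}{2}}, \qquad
\arccos z = -i \log\bigl(z + i\sqrt{1 - z^2}\bigr),

\arcsin z = -i \log\bigl(iz + \sqrt{1 - z^2}\bigr), \qquad
\arctan z = \frac{1}{2i} \log\frac{1 + iz}{1 - iz}.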

View File

@ -5,6 +5,7 @@
#include <numeric>
#include <sstream>
#include "mlx/backend/common/broadcasting.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
@ -102,6 +103,47 @@ std::tuple<bool, int64_t, array> check_transpose(
}
};
inline array
ensure_row_contiguous(const array& x, metal::Device& d, const Stream& s) {
if (!x.flags().row_contiguous) {
array x_copy(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
d.add_temporary(x_copy, s.index);
return x_copy;
} else {
return x;
}
}
inline std::tuple<bool, int64_t, array>
ensure_batch_contiguous(const array& x, metal::Device& d, const Stream& s) {
if (x.flags().row_contiguous) {
return std::make_tuple(false, x.strides()[x.ndim() - 2], x);
}
bool rc = true;
for (int i = 0; i < x.ndim() - 3; i++) {
rc &= x.strides()[i + 1] * x.shape(i) == x.strides()[i];
}
if (rc) {
auto stx = x.strides()[x.ndim() - 2];
auto sty = x.strides()[x.ndim() - 1];
auto K = x.shape(-2);
auto N = x.shape(-1);
if (sty == 1 && (N != 1 || stx == N)) {
return std::make_tuple(false, stx, x);
}
if (stx == 1 && (N != 1 || sty == K)) {
return std::make_tuple(true, sty, x);
}
}
array x_copy(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
d.add_temporary(x_copy, s.index);
return std::make_tuple(false, x_copy.strides()[x_copy.ndim() - 2], x_copy);
}
} // namespace
///////////////////////////////////////////////////////////////////////////////
@ -230,7 +272,6 @@ void steel_matmul_regular(
const bool align_M = (M % bm) == 0;
const bool align_N = (N % bn) == 0;
const bool align_K = (K % bk) == 0;
const bool do_gather = false;
metal::MTLFCList func_consts = {
{&has_batch, MTL::DataType::DataTypeBool, 10},
@ -239,7 +280,6 @@ void steel_matmul_regular(
{&align_M, MTL::DataType::DataTypeBool, 200},
{&align_N, MTL::DataType::DataTypeBool, 201},
{&align_K, MTL::DataType::DataTypeBool, 202},
{&do_gather, MTL::DataType::DataTypeBool, 300},
};
// clang-format off
@ -248,8 +288,7 @@ void steel_matmul_regular(
<< "_do_axpby_" << (do_axpby ? 't' : 'n')
<< "_align_M_" << (align_M ? 't' : 'n')
<< "_align_N_" << (align_N ? 't' : 'n')
<< "_align_K_" << (align_K ? 't' : 'n')
<< "_do_gather_" << (do_gather ? 't' : 'n'); // clang-format on
<< "_align_K_" << (align_K ? 't' : 'n'); // clang-format on
std::string hash_name = kname.str();
@ -975,7 +1014,6 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
const bool align_M = (M % bm) == 0;
const bool align_N = (N % bn) == 0;
const bool align_K = (K % bk) == 0;
const bool do_gather = false;
metal::MTLFCList func_consts = {
{&has_batch, MTL::DataType::DataTypeBool, 10},
@ -984,7 +1022,6 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
{&align_M, MTL::DataType::DataTypeBool, 200},
{&align_N, MTL::DataType::DataTypeBool, 201},
{&align_K, MTL::DataType::DataTypeBool, 202},
{&do_gather, MTL::DataType::DataTypeBool, 300},
};
// clang-format off
@ -993,8 +1030,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
<< "_do_axpby_" << (do_axpby ? 't' : 'n')
<< "_align_M_" << (align_M ? 't' : 'n')
<< "_align_N_" << (align_N ? 't' : 'n')
<< "_align_K_" << (align_K ? 't' : 'n')
<< "_do_gather_" << (do_gather ? 't' : 'n'); // clang-format on
<< "_align_K_" << (align_K ? 't' : 'n'); // clang-format on
std::string hash_name = kname.str();
@ -1464,131 +1500,176 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
d.add_temporaries(std::move(copies), s.index);
}
void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
using namespace mlx::steel;
// assert(inputs.size() == 2);
if (!issubdtype(out.dtype(), floating)) {
throw std::runtime_error(
"[GatherMM] Does not yet support non-floating point types.");
}
auto& s = stream();
auto& d = metal::device(s.device);
void gather_mm_rhs(
const array& a_,
const array& b_,
const array& indices_,
array& out,
metal::Device& d,
const Stream& s) {
array indices = ensure_row_contiguous(indices_, d, s);
auto [transpose_b, ldb, b] = ensure_batch_contiguous(b_, d, s);
auto& a_pre = inputs[0];
auto& b_pre = inputs[1];
// Return 0s if either input is empty
if (a_pre.size() == 0 || b_pre.size() == 0) {
array zero = array(0, a_pre.dtype());
fill_gpu(zero, out, s);
d.add_temporary(std::move(zero), s.index);
return;
// Broadcast a with indices. If we are here that means lhs_indices were not
// provided so the lhs_indices are implied to be the shape of a broadcasted
// with rhs_indices. We need only broadcast a and copy it as if applying the
// lhs_indices.
auto broadcast_with_indices = [&d, &s, &indices](const array& x) {
if (x.size() / x.shape(-2) / x.shape(-1) == indices.size()) {
return ensure_row_contiguous(x, d, s);
}
out.set_data(allocator::malloc(out.nbytes()));
auto x_shape = indices.shape();
x_shape.push_back(x.shape(-2));
x_shape.push_back(x.shape(-1));
array new_x(std::move(x_shape), x.dtype(), nullptr, {});
broadcast(x, new_x);
return ensure_row_contiguous(new_x, d, s);
};
array a = broadcast_with_indices(a_);
/////////////////////////////////////////////////////////////////////////////
// Init checks and prep
// Extract the matmul shapes
int K = a.shape(-1);
int M = a.size() / K;
int N = b.shape(-1);
int lda = a.strides()[a.ndim() - 2]; // should be K
int M = a_pre.shape(-2);
int N = b_pre.shape(-1);
int K = a_pre.shape(-1);
// Define the dispatch blocks
int bm = 16, bn = 64, bk = 16;
int wm = 1, wn = 2;
// Keep a vector with copies to be cleared in the completed buffer to release
// the arrays
std::vector<array> copies;
auto [transpose_a, a_cols, a] = check_transpose(copies, s, a_pre, M == 1);
auto [transpose_b, b_cols, b] = check_transpose(copies, s, b_pre, N == 1);
const bool align_M = (M % bm) == 0;
const bool align_N = (N % bn) == 0;
const bool align_K = (K % bk) == 0;
int lda = a_cols;
int ldb = b_cols;
// Define the kernel name
std::string base_name;
base_name.reserve(64);
concatenate(
base_name,
"steel_gather_mm_rhs_n",
transpose_b ? 't' : 'n',
'_',
type_to_name(a),
'_',
type_to_name(out),
"_bm",
bm,
"_bn",
bn,
"_bk",
bk,
"_wm",
wm,
"_wn",
wn);
/////////////////////////////////////////////////////////////////////////////
// Check and collapse batch dimensions
auto get_batch_dims = [](const auto& v) {
return decltype(v){v.begin(), v.end() - 2};
metal::MTLFCList func_consts = {
{&align_M, MTL::DataType::DataTypeBool, 200},
{&align_N, MTL::DataType::DataTypeBool, 201},
{&align_K, MTL::DataType::DataTypeBool, 202},
};
auto& lhs_indices = inputs[2];
auto& rhs_indices = inputs[3];
// And the kernel hash that includes the function constants
std::string hash_name;
hash_name.reserve(128);
concatenate(
hash_name,
base_name,
"_align_M_",
align_M ? 't' : 'n',
"_align_N_",
align_N ? 't' : 'n',
"_align_K_",
align_K ? 't' : 'n');
Shape batch_shape = get_batch_dims(out.shape());
Strides batch_strides;
// Get and set the kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = get_steel_gemm_gather_kernel(
d,
base_name,
hash_name,
func_consts,
out,
false,
transpose_b,
bm,
bn,
bk,
wm,
wn,
true);
compute_encoder.set_compute_pipeline_state(kernel);
batch_strides.insert(
batch_strides.end(),
lhs_indices.strides().begin(),
lhs_indices.strides().end());
auto lhs_indices_str = batch_strides.empty() ? 0 : batch_strides.back();
// Prepare the matmul params
auto batch_stride_b = b.ndim() > 2 ? b.strides()[b.ndim() - 3] : b.size();
steel::GEMMParams params{
/* const int M = */ M,
/* const int N = */ N,
/* const int K = */ K,
/* const int lda = */ lda,
/* const int ldb = */ static_cast<int>(ldb),
/* const int ldd = */ N,
/* const int tiles_n = */ (N + bn - 1) / bn,
/* const int tiles_m = */ (M + bm - 1) / bm,
/* const int64_t batch_stride_a = */ 0,
/* const int64_t batch_stride_b = */ static_cast<int64_t>(batch_stride_b),
/* const int64_t batch_stride_d = */ 0,
/* const int swizzle_log = */ 0,
/* const int gemm_k_iterations_aligned = */ (K / bk),
/* const int batch_ndim = */ 0};
batch_strides.insert(
batch_strides.end(),
rhs_indices.strides().begin(),
rhs_indices.strides().end());
auto rhs_indices_str = batch_strides.empty() ? 0 : batch_strides.back();
// Prepare the grid
MTL::Size group_dims = MTL::Size(32, wn, wm);
MTL::Size grid_dims = MTL::Size(params.tiles_n, params.tiles_m, 1);
int batch_ndim = batch_shape.size();
// Launch kernel
compute_encoder.set_input_array(a, 0);
compute_encoder.set_input_array(b, 1);
compute_encoder.set_input_array(indices, 2);
compute_encoder.set_output_array(out, 3);
compute_encoder.set_bytes(params, 4);
if (batch_ndim == 0) {
batch_shape = {1};
batch_strides = {0};
}
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
int batch_ndim_A = a.ndim() - 2;
int batch_ndim_B = b.ndim() - 2;
std::vector<int> operand_batch_ndim = {batch_ndim_A, batch_ndim_B};
void gather_mv(
const array& mat_,
const array& vec_,
const array& mat_indices_,
const array& vec_indices_,
array& out,
int N,
int K,
bool is_mv,
metal::Device& d,
const Stream& s) {
// Copy if needed
std::vector<array> copies;
auto [transpose_mat, mat_cols, mat] =
check_transpose(copies, s, mat_, N == 1);
auto [transpose_vec, vec_cols, vec] = check_transpose(copies, s, vec_, true);
d.add_temporaries(std::move(copies), s.index);
Shape batch_shape_A = get_batch_dims(a.shape());
Strides batch_strides_A = get_batch_dims(a.strides());
Shape batch_shape_B = get_batch_dims(b.shape());
Strides batch_strides_B = get_batch_dims(b.strides());
// If we are doing vector matrix instead of matrix vector we need to flip the
// matrix transposition. Basically m @ v = v @ m.T assuming that v is treated
// as a one dimensional array.
transpose_mat = (!is_mv) ^ transpose_mat;
if (batch_ndim_A == 0) {
batch_shape_A = {1};
batch_strides_A = {0};
}
if (batch_ndim_B == 0) {
batch_shape_B = {1};
batch_strides_B = {0};
}
auto matrix_stride_out = static_cast<int64_t>(M) * N;
auto batch_size_out = out.size() / matrix_stride_out;
/////////////////////////////////////////////////////////////////////////////
// Gemv specialization
// Route to gemv if needed
if (std::min(M, N) == 1) {
// Collect problem info
bool is_b_matrix = N != 1;
auto& mat = is_b_matrix ? b : a;
auto& vec = is_b_matrix ? a : b;
bool transpose_mat = is_b_matrix ? !transpose_b : transpose_a;
// Define some shapes
int in_vector_len = K;
int out_vector_len = is_b_matrix ? N : M;
int out_vector_len = N;
int mat_ld = mat_cols;
int mat_cols = transpose_mat ? out_vector_len : in_vector_len;
int mat_rows = transpose_mat ? in_vector_len : out_vector_len;
int mat_ld = is_b_matrix ? b_cols : a_cols;
auto batch_strides_mat = is_b_matrix ? batch_strides_B : batch_strides_A;
auto batch_strides_vec = is_b_matrix ? batch_strides_A : batch_strides_B;
auto batch_shape_mat = is_b_matrix ? batch_shape_B : batch_shape_A;
auto batch_shape_vec = is_b_matrix ? batch_shape_A : batch_shape_B;
if (!is_b_matrix) {
batch_strides = rhs_indices.strides();
batch_strides.insert(
batch_strides.end(),
lhs_indices.strides().begin(),
lhs_indices.strides().end());
}
int batch_ndim = batch_shape.size();
int batch_size_out = out.size() / N;
int batch_ndim = out.ndim() - 2;
int batch_ndim_mat = mat.ndim() - 2;
int batch_ndim_vec = vec.ndim() - 2;
Strides index_strides = vec_indices_.strides();
index_strides.insert(
index_strides.end(),
mat_indices_.strides().begin(),
mat_indices_.strides().end());
// Determine dispatch kernel
int tm = 4, tn = 4;
@ -1652,79 +1733,104 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.set_bytes(mat_ld, 6);
compute_encoder.set_bytes(batch_ndim, 9);
compute_encoder.set_vector_bytes(batch_shape, 10);
compute_encoder.set_vector_bytes(batch_strides, 11);
compute_encoder.set_vector_bytes(out.shape(), 10);
compute_encoder.set_vector_bytes(index_strides, 11);
int batch_ndim_vec = batch_shape_vec.size();
compute_encoder.set_bytes(batch_ndim_vec, 12);
compute_encoder.set_vector_bytes(batch_shape_vec, 13);
compute_encoder.set_vector_bytes(batch_strides_vec, 14);
compute_encoder.set_vector_bytes(vec.shape(), 13);
compute_encoder.set_vector_bytes(vec.strides(), 14);
int batch_ndim_mat = batch_shape_mat.size();
compute_encoder.set_bytes(batch_ndim_mat, 15);
compute_encoder.set_vector_bytes(batch_shape_mat, 16);
compute_encoder.set_vector_bytes(batch_strides_mat, 17);
compute_encoder.set_vector_bytes(mat.shape(), 16);
compute_encoder.set_vector_bytes(mat.strides(), 17);
compute_encoder.set_input_array(lhs_indices, 18 + int(!is_b_matrix));
compute_encoder.set_input_array(rhs_indices, 18 + int(is_b_matrix));
compute_encoder.set_input_array(vec_indices_, 18);
compute_encoder.set_input_array(mat_indices_, 19);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
void gather_mm(
const array& a_,
const array& b_,
const array& lhs_indices,
const array& rhs_indices,
array& out,
int M,
int N,
int K,
metal::Device& d,
const Stream& s) {
// Copy if needed
std::vector<array> copies;
auto [transpose_a, lda, a] = check_transpose(copies, s, a_, false);
auto [transpose_b, ldb, b] = check_transpose(copies, s, b_, false);
d.add_temporaries(std::move(copies), s.index);
return;
}
/////////////////////////////////////////////////////////////////////////////
// Regular kernel dispatch
// Determine dispatch kernel
int bm = 64, bn = 64, bk = 16;
int wm = 2, wn = 2;
size_t batch_size_out = out.size() / M / N;
int batch_ndim = out.ndim() - 2;
int batch_ndim_a = a.ndim() - 2;
int batch_ndim_b = b.ndim() - 2;
char devc = d.get_architecture().back();
GEMM_TPARAM_MACRO(devc)
// Prepare kernel name
std::ostringstream kname;
kname << "steel_gemm_fused_" << (transpose_a ? 't' : 'n')
<< (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_"
<< type_to_name(out) << "_bm" << bm << "_bn" << bn << "_bk" << bk
<< "_wm" << wm << "_wn" << wn;
std::string base_name = kname.str();
const bool has_batch = batch_ndim > 1;
const bool use_out_source = false;
const bool do_axpby = false;
const bool align_M = (M % bm) == 0;
const bool align_N = (N % bn) == 0;
const bool align_K = (K % bk) == 0;
const bool do_gather = true;
// Define the kernel name
std::string base_name;
base_name.reserve(128);
concatenate(
base_name,
"steel_gather_mm_",
transpose_a ? 't' : 'n',
transpose_b ? 't' : 'n',
"_",
type_to_name(a),
"_",
type_to_name(out),
"_bm",
bm,
"_bn",
bn,
"_bk",
bk,
"_wm",
wm,
"_wn",
wn);
metal::MTLFCList func_consts = {
{&has_batch, MTL::DataType::DataTypeBool, 10},
{&use_out_source, MTL::DataType::DataTypeBool, 100},
{&do_axpby, MTL::DataType::DataTypeBool, 110},
{&align_M, MTL::DataType::DataTypeBool, 200},
{&align_N, MTL::DataType::DataTypeBool, 201},
{&align_K, MTL::DataType::DataTypeBool, 202},
{&do_gather, MTL::DataType::DataTypeBool, 300},
};
// clang-format off
kname << "_has_batch_" << (has_batch ? 't' : 'n')
<< "_use_out_source_" << (use_out_source ? 't' : 'n')
<< "_do_axpby_" << (do_axpby ? 't' : 'n')
<< "_align_M_" << (align_M ? 't' : 'n')
<< "_align_N_" << (align_N ? 't' : 'n')
<< "_align_K_" << (align_K ? 't' : 'n')
<< "_do_gather_" << (do_gather ? 't' : 'n'); // clang-format on
// And the kernel hash that includes the function constants
std::string hash_name;
hash_name.reserve(128);
concatenate(
hash_name,
base_name,
"_has_batch_",
has_batch ? 't' : 'n',
"_align_M_",
align_M ? 't' : 'n',
"_align_N_",
align_N ? 't' : 'n',
"_align_K_",
align_K ? 't' : 'n');
std::string hash_name = kname.str();
// Encode and dispatch kernel
// Get and set the kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = get_steel_gemm_fused_kernel(
auto kernel = get_steel_gemm_gather_kernel(
d,
base_name,
hash_name,
@ -1736,72 +1842,96 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
bn,
bk,
wm,
wn);
wn,
false);
compute_encoder.set_compute_pipeline_state(kernel);
// Use problem size to determine threadblock swizzle
int tn = (N + bn - 1) / bn;
int tm = (M + bm - 1) / bm;
// TODO: Explore device-based tuning for swizzle
int swizzle_log = 0; // tm >= 6 ? 3 : (tm <= 3 ? 0 : 2);
// Prepare steel matmul params
GEMMParams params{
// Prepare the matmul params
steel::GEMMParams params{
/* const int M = */ M,
/* const int N = */ N,
/* const int K = */ K,
/* const int lda = */ lda,
/* const int ldb = */ ldb,
/* const int lda = */ static_cast<int>(lda),
/* const int ldb = */ static_cast<int>(ldb),
/* const int ldd = */ N,
/* const int tiles_n = */ tn,
/* const int tiles_m = */ tm,
/* const int64_t batch_stride_a = */ lhs_indices_str,
/* const int64_t batch_stride_b = */ rhs_indices_str,
/* const int64_t batch_stride_d = */ matrix_stride_out,
/* const int swizzle_log = */ swizzle_log,
/* const int tiles_n = */ (N + bn - 1) / bn,
/* const int tiles_m = */ (M + bm - 1) / bm,
/* const int64_t batch_stride_a = */
(batch_ndim > 0) ? lhs_indices.strides()[0] : 0,
/* const int64_t batch_stride_b = */
(batch_ndim > 0) ? rhs_indices.strides()[0] : 0,
/* const int64_t batch_stride_d = */ M * N,
/* const int swizzle_log = */ 0,
/* const int gemm_k_iterations_aligned = */ (K / bk),
/* const int batch_ndim = */ batch_ndim};
// Prepare launch grid params
int tile = 1 << swizzle_log;
tm = (tm + tile - 1) / tile;
tn = tn * tile;
// Prepare the grid
MTL::Size group_dims = MTL::Size(32, wn, wm);
MTL::Size grid_dims = MTL::Size(tn, tm, batch_size_out);
MTL::Size grid_dims =
MTL::Size(params.tiles_n, params.tiles_m, batch_size_out);
// Launch kernel
compute_encoder.set_input_array(a, 0);
compute_encoder.set_input_array(b, 1);
compute_encoder.set_output_array(out, 3);
compute_encoder.set_bytes(params, 4);
compute_encoder.set_vector_bytes(batch_shape, 6);
compute_encoder.set_vector_bytes(batch_strides, 7);
compute_encoder.set_input_array(lhs_indices, 10);
compute_encoder.set_input_array(rhs_indices, 11);
std::vector operand_shape = batch_shape_A;
operand_shape.insert(
operand_shape.end(), batch_shape_B.begin(), batch_shape_B.end());
std::vector operand_strides = batch_strides_A;
operand_strides.insert(
operand_strides.end(), batch_strides_B.begin(), batch_strides_B.end());
operand_batch_ndim.push_back(0);
compute_encoder.set_vector_bytes(operand_shape, 13);
compute_encoder.set_vector_bytes(operand_strides, 14);
compute_encoder.set_vector_bytes(operand_batch_ndim, 15);
compute_encoder.set_input_array(lhs_indices, 2);
compute_encoder.set_input_array(rhs_indices, 3);
compute_encoder.set_output_array(out, 4);
compute_encoder.set_bytes(params, 5);
compute_encoder.set_vector_bytes(lhs_indices.shape(), 6);
compute_encoder.set_vector_bytes(lhs_indices.strides(), 7);
compute_encoder.set_vector_bytes(rhs_indices.strides(), 8);
compute_encoder.set_bytes(batch_ndim_a, 9);
compute_encoder.set_vector_bytes(a.shape(), 10);
compute_encoder.set_vector_bytes(a.strides(), 11);
compute_encoder.set_bytes(batch_ndim_b, 12);
compute_encoder.set_vector_bytes(b.shape(), 13);
compute_encoder.set_vector_bytes(b.strides(), 14);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
d.add_temporaries(std::move(copies), s.index);
void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& s = stream();
auto& d = metal::device(s.device);
auto& a = inputs[0];
auto& b = inputs[1];
auto& lhs_indices = inputs[2];
auto& rhs_indices = inputs[3];
// Return 0s if either input is empty
if (a.size() == 0 || b.size() == 0) {
array zero = array(0, a.dtype());
fill_gpu(zero, out, s);
d.add_temporary(std::move(zero), s.index);
return;
}
out.set_data(allocator::malloc(out.nbytes()));
// Extract shapes from inputs.
int M = a.shape(-2);
int N = b.shape(-1);
int K = a.shape(-1);
// We are walking a in order and b is also in order so we can batch up the
// matmuls and reuse reading a and b.
if (M == 1 && right_sorted_ == true) {
gather_mm_rhs(a, b, rhs_indices, out, d, s);
return;
}
// Route to gather gemv if any of a or b are vectors
if (M == 1) {
gather_mv(b, a, rhs_indices, lhs_indices, out, N, K, false, d, s);
return;
}
if (N == 1) {
gather_mv(a, b, lhs_indices, rhs_indices, out, M, K, true, d, s);
return;
}
// Route to non specialized gather mm
gather_mm(a, b, lhs_indices, rhs_indices, out, M, N, K, d, s);
}
} // namespace mlx::core

View File

@ -193,6 +193,23 @@ MTL::ComputePipelineState* get_steel_gemm_masked_kernel(
return d.get_kernel(kernel_name);
}
MTL::ComputePipelineState* get_steel_gemm_gather_kernel(
metal::Device& d,
const std::string& kernel_name,
const std::string& hash_name,
const metal::MTLFCList& func_consts,
const array&,
bool,
bool,
int,
int,
int,
int,
int,
bool) {
return d.get_kernel(kernel_name, "mlx", hash_name, func_consts);
}
MTL::ComputePipelineState* get_gemv_masked_kernel(
metal::Device& d,
const std::string& kernel_name,
@ -252,4 +269,21 @@ MTL::ComputePipelineState* get_quantized_kernel(
return d.get_kernel(kernel_name);
}
MTL::ComputePipelineState* get_gather_qmm_kernel(
metal::Device& d,
const std::string& kernel_name,
const std::string& hash_name,
const metal::MTLFCList& func_consts,
const array&,
int,
int,
int,
int,
int,
int,
int,
bool) {
return d.get_kernel(kernel_name, "mlx", hash_name, func_consts);
}
} // namespace mlx::core

File diff suppressed because it is too large

View File

@ -60,6 +60,9 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
case Scan::Min:
reduce_type = "min";
break;
case Scan::LogAddExp:
reduce_type = "logaddexp";
break;
}
kname << reduce_type << "_" << type_to_name(in) << "_" << type_to_name(out);
auto kernel = get_scan_kernel(

View File

@ -2,6 +2,8 @@
#pragma once
#include <type_traits>
#include "mlx/array.h"
#include "mlx/backend/metal/device.h"
#include "mlx/primitives.h"
@ -58,14 +60,27 @@ inline void debug_set_primitive_buffer_label(
std::string get_primitive_string(Primitive* primitive);
template <typename T>
constexpr bool is_numeric_except_char = std::is_arithmetic_v<T> &&
!std::is_same_v<T, char> && !std::is_same_v<T, signed char> &&
!std::is_same_v<T, unsigned char> && !std::is_same_v<T, wchar_t>;
template <typename T>
void concatenate(std::string& acc, T first) {
if constexpr (is_numeric_except_char<T>) {
acc += std::to_string(first);
} else {
acc += first;
}
}
template <typename T, typename... Args>
void concatenate(std::string& acc, T first, Args... args) {
if constexpr (is_numeric_except_char<T>) {
acc += std::to_string(first);
} else {
acc += first;
}
concatenate(acc, args...);
}

20
mlx/dtype_utils.cpp Normal file
View File

@ -0,0 +1,20 @@
// Copyright © 2025 Apple Inc.
#include "mlx/dtype_utils.h"
namespace mlx::core {
const char* dtype_to_string(Dtype arg) {
if (arg == bool_) {
return "bool";
}
#define SPECIALIZE_DtypeToString(CPP_TYPE, DTYPE) \
if (DTYPE == arg) { \
return #DTYPE; \
}
MLX_FORALL_DTYPES(SPECIALIZE_DtypeToString)
#undef SPECIALIZE_DtypeToString
return "(unknown)";
}
} // namespace mlx::core

207
mlx/dtype_utils.h Normal file
View File

@ -0,0 +1,207 @@
// Copyright © 2025 Apple Inc.
// Copyright © Meta Platforms, Inc. and affiliates.
//
// This source code is licensed under the BSD-style license found in
// https://github.com/pytorch/executorch/blob/main/LICENSE
//
// Forked from
// https://github.com/pytorch/executorch/blob/main/runtime/core/exec_aten/util/scalar_type_util.h
#pragma once
#include "mlx/dtype.h"
#include <fmt/format.h>
namespace mlx::core {
// Return string representation of dtype.
const char* dtype_to_string(Dtype arg);
// Macros that iterate across different subsets of Dtypes.
//
// For all of these macros, the final `_` parameter is the name of another macro
// that takes two parameters: the name of a C type, and the name of the
// corresponding Dtype enumerator.
//
// Note that these macros should use fully-qualified namespaces (starting with
// `::`) to ensure that they can be called safely in any arbitrary namespace.
#define MLX_FORALL_INT_TYPES(_) \
_(uint8_t, uint8) \
_(uint16_t, uint16) \
_(uint32_t, uint32) \
_(uint64_t, uint64) \
_(int8_t, int8) \
_(int16_t, int16) \
_(int32_t, int32) \
_(int64_t, int64)
#define MLX_FORALL_FLOAT_TYPES(_) \
_(float16_t, float16) \
_(float, float32) \
_(double, float64) \
_(bfloat16_t, bfloat16)
// Calls the provided macro on every Dtype, providing the C type and the
// Dtype name to each call.
//
// @param _ A macro that takes two parameters: the name of a C type, and the
// name of the corresponding Dtype enumerator.
#define MLX_FORALL_DTYPES(_) \
MLX_FORALL_INT_TYPES(_) \
MLX_FORALL_FLOAT_TYPES(_) \
_(bool, bool_) \
_(complex64_t, complex64)
// Maps Dtypes to C++ types.
template <Dtype::Val N>
struct DtypeToCppType;
#define SPECIALIZE_DtypeToCppType(CPP_TYPE, DTYPE) \
template <> \
struct DtypeToCppType<Dtype::Val::DTYPE> { \
using type = CPP_TYPE; \
};
MLX_FORALL_DTYPES(SPECIALIZE_DtypeToCppType)
#undef SPECIALIZE_DtypeToCppType
// Maps C++ types to Dtypes.
template <typename T>
struct CppTypeToDtype;
#define SPECIALIZE_CppTypeToDtype(CPP_TYPE, DTYPE) \
template <> \
struct CppTypeToDtype<CPP_TYPE> \
: std::integral_constant<Dtype::Val, Dtype::Val::DTYPE> {};
MLX_FORALL_DTYPES(SPECIALIZE_CppTypeToDtype)
#undef SPECIALIZE_CppTypeToDtype
// Helper macros for switch case macros (see below)
//
// These macros are not meant to be used directly. They provide an easy way to
// generate a switch statement that can handle subsets of Dtypes supported.
#define MLX_INTERNAL_SWITCH_CASE(enum_type, CTYPE_ALIAS, ...) \
case enum_type: { \
using CTYPE_ALIAS = ::mlx::core::DtypeToCppType<enum_type>::type; \
__VA_ARGS__; \
break; \
}
#define MLX_INTERNAL_SWITCH_CHECKED(TYPE, NAME, ...) \
switch (TYPE) { \
__VA_ARGS__ \
default: \
throw std::invalid_argument(fmt::format( \
"Unhandled dtype %s for %s", dtype_to_string(TYPE), NAME)); \
}
#define MLX_INTERNAL_SWITCH_CASE_INT_TYPES(CTYPE_ALIAS, ...) \
MLX_INTERNAL_SWITCH_CASE( \
::mlx::core::Dtype::Val::uint8, CTYPE_ALIAS, __VA_ARGS__) \
MLX_INTERNAL_SWITCH_CASE( \
::mlx::core::Dtype::Val::uint16, CTYPE_ALIAS, __VA_ARGS__) \
MLX_INTERNAL_SWITCH_CASE( \
::mlx::core::Dtype::Val::uint32, CTYPE_ALIAS, __VA_ARGS__) \
MLX_INTERNAL_SWITCH_CASE( \
::mlx::core::Dtype::Val::uint64, CTYPE_ALIAS, __VA_ARGS__) \
MLX_INTERNAL_SWITCH_CASE( \
::mlx::core::Dtype::Val::int8, CTYPE_ALIAS, __VA_ARGS__) \
MLX_INTERNAL_SWITCH_CASE( \
::mlx::core::Dtype::Val::int16, CTYPE_ALIAS, __VA_ARGS__) \
MLX_INTERNAL_SWITCH_CASE( \
::mlx::core::Dtype::Val::int32, CTYPE_ALIAS, __VA_ARGS__) \
MLX_INTERNAL_SWITCH_CASE( \
::mlx::core::Dtype::Val::int64, CTYPE_ALIAS, __VA_ARGS__)
#define MLX_INTERNAL_SWITCH_CASE_FLOAT_TYPES(CTYPE_ALIAS, ...) \
MLX_INTERNAL_SWITCH_CASE( \
::mlx::core::Dtype::Val::float16, CTYPE_ALIAS, __VA_ARGS__) \
MLX_INTERNAL_SWITCH_CASE( \
::mlx::core::Dtype::Val::float32, CTYPE_ALIAS, __VA_ARGS__) \
MLX_INTERNAL_SWITCH_CASE( \
::mlx::core::Dtype::Val::float64, CTYPE_ALIAS, __VA_ARGS__) \
MLX_INTERNAL_SWITCH_CASE( \
::mlx::core::Dtype::Val::bfloat16, CTYPE_ALIAS, __VA_ARGS__)
#define MLX_INTERNAL_SWITCH_CASE_INT_FLOAT_TYPES(CTYPE_ALIAS, ...) \
MLX_INTERNAL_SWITCH_CASE_INT_TYPES(CTYPE_ALIAS, __VA_ARGS__) \
MLX_INTERNAL_SWITCH_CASE_FLOAT_TYPES(CTYPE_ALIAS, __VA_ARGS__)
#define MLX_INTERNAL_SWITCH_CASE_REAL_TYPES(CTYPE_ALIAS, ...) \
MLX_INTERNAL_SWITCH_CASE_INT_FLOAT_TYPES(CTYPE_ALIAS, __VA_ARGS__) \
MLX_INTERNAL_SWITCH_CASE( \
::mlx::core::Dtype::Val::bool_, CTYPE_ALIAS, __VA_ARGS__)
#define MLX_INTERNAL_SWITCH_CASE_COMPLEX_TYPES(CTYPE_ALIAS, ...) \
MLX_INTERNAL_SWITCH_CASE( \
::mlx::core::Dtype::Val::complex64, CTYPE_ALIAS, __VA_ARGS__)
#define MLX_INTERNAL_SWITCH_CASE_ALL_TYPES(CTYPE_ALIAS, ...) \
MLX_INTERNAL_SWITCH_CASE_REAL_TYPES(CTYPE_ALIAS, __VA_ARGS__) \
MLX_INTERNAL_SWITCH_CASE_COMPLEX_TYPES(CTYPE_ALIAS, __VA_ARGS__)
// Switch case macros
//
// These macros provide an easy way to generate switch statements that apply a
// common lambda function to subsets of Dtypes supported by MLX.
// The lambda function can type specialize to the ctype associated with the
// Dtype being handled through an alias passed as the CTYPE_ALIAS argument.
//
// Arguments:
// - ADDITIONAL: Additional Dtype case to add
// - TYPE: The Dtype to handle through the switch statement
// - NAME: A name for this operation which will be used in error messages
// - CTYPE_ALIAS: A typedef for the ctype associated with the Dtype.
// - ...: A statement to be applied to each Dtype case
//
// An example usage is:
//
// MLX_SWITCH_ALL_TYPES(input.dtype(), CTYPE, {
// output.data<CTYPE>[0] = input.data<CTYPE>[0];
// });
//
// Note that these can be nested as well:
//
// MLX_SWITCH_ALL_TYPES(input.dtype(), CTYPE_IN, {
// MLX_SWITCH_ALL_TYPES(output.dtype(), CTYPE_OUT, {
// output.data<CTYPE_OUT>[0] = input.data<CTYPE_IN>[0];
// });
// });
//
// These macros are adapted from Dispatch.h in the ATen library. The primary
// difference is that the CTYPE_ALIAS argument is exposed to users, which is
// used to alias the ctype associated with the Dtype that is being handled.
#define MLX_SWITCH_ALL_TYPES(TYPE, CTYPE_ALIAS, ...) \
switch (TYPE) { MLX_INTERNAL_SWITCH_CASE_ALL_TYPES(CTYPE_ALIAS, __VA_ARGS__) }
#define MLX_SWITCH_INT_TYPES_CHECKED(TYPE, NAME, CTYPE_ALIAS, ...) \
MLX_INTERNAL_SWITCH_CHECKED( \
TYPE, \
NAME, \
MLX_INTERNAL_SWITCH_CASE_INT_TYPES(CTYPE_ALIAS, __VA_ARGS__))
#define MLX_SWITCH_FLOAT_TYPES_CHECKED(TYPE, NAME, CTYPE_ALIAS, ...) \
MLX_INTERNAL_SWITCH_CHECKED( \
TYPE, \
NAME, \
MLX_INTERNAL_SWITCH_CASE_FLOAT_TYPES(CTYPE_ALIAS, __VA_ARGS__))
#define MLX_SWITCH_INT_FLOAT_TYPES_CHECKED(TYPE, NAME, CTYPE_ALIAS, ...) \
MLX_INTERNAL_SWITCH_CHECKED( \
TYPE, \
NAME, \
MLX_INTERNAL_SWITCH_CASE_INT_FLOAT_TYPES(CTYPE_ALIAS, __VA_ARGS__))
#define MLX_SWITCH_REAL_TYPES_CHECKED(TYPE, NAME, CTYPE_ALIAS, ...) \
MLX_INTERNAL_SWITCH_CHECKED( \
TYPE, \
NAME, \
MLX_INTERNAL_SWITCH_CASE_REAL_TYPES(CTYPE_ALIAS, __VA_ARGS__))
} // namespace mlx::core

View File

@ -1,5 +1,6 @@
// Copyright © 2024 Apple Inc.
#include "mlx/export.h"
#include <map>
#include "mlx/compile_impl.h"
#include "mlx/fast_primitives.h"
#include "mlx/primitives.h"
@ -298,7 +299,13 @@ struct PrimitiveFactory {
SERIALIZE_PRIMITIVE(Reshape),
SERIALIZE_PRIMITIVE(Reduce, "And", "Or", "Sum", "Prod", "Min", "Max"),
SERIALIZE_PRIMITIVE(Round),
SERIALIZE_PRIMITIVE(Scan, "CumSum", "CumProd", "CumMin", "CumMax"),
SERIALIZE_PRIMITIVE(
Scan,
"CumSum",
"CumProd",
"CumMin",
"CumMax",
"CumLogaddexp"),
SERIALIZE_PRIMITIVE(Scatter),
SERIALIZE_PRIMITIVE(Select),
SERIALIZE_PRIMITIVE(Sigmoid),
@ -475,7 +482,9 @@ bool FunctionTable::match(
return false;
}
}
for (auto& [_, in] : kwargs) {
auto sorted_kwargs =
std::map<std::string, array>(kwargs.begin(), kwargs.end());
for (auto& [_, in] : sorted_kwargs) {
if (!match_inputs(in, fun.inputs[i++])) {
return false;
}
@ -551,7 +560,9 @@ void FunctionExporter::export_function(const Args& args, const Kwargs& kwargs) {
// Flatten the inputs to the function for tracing
std::vector<std::string> kwarg_keys;
auto inputs = args;
for (auto& [k, v] : kwargs) {
auto sorted_kwargs =
std::map<std::string, array>(kwargs.begin(), kwargs.end());
for (auto& [k, v] : sorted_kwargs) {
kwarg_keys.push_back(k);
inputs.push_back(v);
}

View File

@ -2,14 +2,14 @@
#pragma once
#include <map>
#include <set>
#include <unordered_map>
#include "mlx/array.h"
namespace mlx::core {
using Args = std::vector<array>;
using Kwargs = std::map<std::string, array>;
using Kwargs = std::unordered_map<std::string, array>;
struct FunctionExporter;

View File

@ -111,7 +111,7 @@ array fft_impl(
for (auto ax : axes) {
n.push_back(a.shape(ax));
}
if (real && inverse) {
if (real && inverse && a.ndim() > 0) {
n.back() = (n.back() - 1) * 2;
}
return fft_impl(a, n, axes, real, inverse, s);

View File

@ -3504,6 +3504,28 @@ array cummin(
{a});
}
array logcumsumexp(
const array& a,
int axis,
bool reverse /* = false*/,
bool inclusive /* = true*/,
StreamOrDevice s /* = {}*/) {
int ndim = a.ndim();
if (axis >= ndim || axis < -ndim) {
std::ostringstream msg;
msg << "[logcumsumexp] Axis " << axis << " is out of bounds for array with "
<< a.ndim() << " dimensions.";
throw std::invalid_argument(msg.str());
}
axis = (axis + a.ndim()) % a.ndim();
return array(
a.shape(),
a.dtype(),
std::make_shared<Scan>(
to_stream(s), Scan::ReduceType::LogAddExp, axis, reverse, inclusive),
{a});
}
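In the inclusive forward case the scan added here computes, along the chosen axis,

\operatorname{logcumsumexp}(x)_i = \log \sum_{j \le i} e^{x_j},

with reverse=True accumulating over j >= i instead and inclusive=False excluding the i-th term.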
/** Convolution operations */
namespace {
@ -4006,6 +4028,7 @@ array gather_qmm(
bool transpose /* = true */,
int group_size /* = 64 */,
int bits /* = 4 */,
bool sorted_indices /* = false */,
StreamOrDevice s /* = {} */) {
if (!lhs_indices_ && !rhs_indices_) {
return quantized_matmul(
@ -4045,13 +4068,19 @@ array gather_qmm(
return array(
std::move(out_shape),
out_type,
std::make_shared<GatherQMM>(to_stream(s), group_size, bits, transpose),
std::make_shared<GatherQMM>(
to_stream(s),
group_size,
bits,
transpose,
sorted_indices && !rhs_indices_,
sorted_indices && !lhs_indices_),
{astype(x, out_type, s),
w,
std::move(w),
astype(scales, out_type, s),
astype(biases, out_type, s),
lhs_indices,
rhs_indices});
std::move(lhs_indices),
std::move(rhs_indices)});
}
array tensordot(
@ -4477,6 +4506,7 @@ array gather_mm(
array b,
std::optional<array> lhs_indices_ /* = std::nullopt */,
std::optional<array> rhs_indices_ /* = std::nullopt */,
bool sorted_indices /* = false */,
StreamOrDevice s /* = {} */) {
// If no indices, fall back to full matmul
if (!lhs_indices_ && !rhs_indices_) {
@ -4552,12 +4582,18 @@ array gather_mm(
out_shape.push_back(M);
out_shape.push_back(N);
// Caculate array
// Make the output array
auto out = array(
std::move(out_shape),
out_type,
std::make_shared<GatherMM>(to_stream(s)),
{a, b, lhs_indices, rhs_indices});
std::make_shared<GatherMM>(
to_stream(s),
sorted_indices && !rhs_indices_,
sorted_indices && !lhs_indices_),
{std::move(a),
std::move(b),
std::move(lhs_indices),
std::move(rhs_indices)});
// Remove the possibly inserted singleton dimensions
std::vector<int> axes;
@ -4879,8 +4915,10 @@ array operator^(const array& a, const array& b) {
}
array left_shift(const array& a, const array& b, StreamOrDevice s /* = {} */) {
// Bit shift on bool always up-casts to uint8
auto t = promote_types(result_type(a, b), uint8);
auto t = result_type(a, b);
if (t == bool_) {
t = uint8;
}
return bitwise_impl(
astype(a, t, s),
astype(b, t, s),
@ -4893,8 +4931,10 @@ array operator<<(const array& a, const array& b) {
}
array right_shift(const array& a, const array& b, StreamOrDevice s /* = {} */) {
// Bit shift on bool always up-casts to uint8
auto t = promote_types(result_type(a, b), uint8);
auto t = result_type(a, b);
if (t == bool_) {
t = uint8;
}
return bitwise_impl(
astype(a, t, s),
astype(b, t, s),
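A small illustration of the intended dtype behavior after this change (a hypothetical session; only the dtypes matter): boolean operands are still up-cast to uint8 for shifts, while other integer dtypes are no longer promoted through uint8.

import mlx.core as mx

a = mx.array([1, 2], dtype=mx.int8)
print((a << a).dtype)  # int8: integer operands keep their dtype

t = mx.array([True, False])
print((t << t).dtype)  # uint8: bool is still up-cast for bit shifts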

View File

@ -715,6 +715,14 @@ array topk(const array& a, int k, StreamOrDevice s = {});
/** Returns topk elements of the array along a given axis. */
array topk(const array& a, int k, int axis, StreamOrDevice s = {});
/** Cumulative logsumexp of an array. */
array logcumsumexp(
const array& a,
int axis,
bool reverse = false,
bool inclusive = true,
StreamOrDevice s = {});
/** The logsumexp of all elements of the array. */
array logsumexp(const array& a, bool keepdims, StreamOrDevice s = {});
inline array logsumexp(const array& a, StreamOrDevice s = {}) {
@ -1344,6 +1352,7 @@ array gather_qmm(
bool transpose = true,
int group_size = 64,
int bits = 4,
bool sorted_indices = false,
StreamOrDevice s = {});
/** Returns a contraction of a and b over multiple dimensions. */
@ -1391,6 +1400,7 @@ array gather_mm(
array b,
std::optional<array> lhs_indices = std::nullopt,
std::optional<array> rhs_indices = std::nullopt,
bool sorted_indices = false,
StreamOrDevice s = {});
/** Extract a diagonal or construct a diagonal array */

View File

@ -1275,6 +1275,61 @@ std::vector<array> Convolution::vjp(
return grads;
}
std::pair<std::vector<array>, std::vector<int>> Convolution::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
auto do_conv = [&](const array& in, const array& w, int groups) {
return conv_general(
in,
w,
kernel_strides_,
padding_,
kernel_dilation_,
input_dilation_,
groups,
flip_,
stream());
};
bool in_vmap = axes[0] >= 0;
bool w_vmap = axes[1] >= 0;
auto in = inputs[0];
auto w = inputs[1];
if (in_vmap && !w_vmap) {
// flatten / unflatten the batch dimension
// of the input / output
if (axes[0] > 0) {
in = moveaxis(in, axes[0], 0, stream());
}
auto out = do_conv(flatten(in, 0, 1, stream()), w, groups_);
out = unflatten(out, 0, {in.shape(0), in.shape(1)}, stream());
return {{out}, {0}};
} else if (!in_vmap && w_vmap) {
// flatten into the output channels of w
// unflatten the channels of the output
if (axes[1] > 0) {
w = moveaxis(w, axes[1], 0, stream());
}
auto out = do_conv(in, flatten(w, 0, 1, stream()), groups_);
out = unflatten(out, -1, {w.shape(0), w.shape(1)}, stream());
return {{out}, {static_cast<int>(out.ndim() - 2)}};
} else if (in_vmap && w_vmap) {
// use a group convolution when both inputs are vmapped
auto b = in.shape(axes[0]);
in = moveaxis(in, axes[0], -2, stream());
in = flatten(in, -2, -1, stream());
if (axes[1] > 0) {
w = moveaxis(w, axes[1], 0, stream());
}
auto c_out = w.shape(1);
w = flatten(w, 0, 1, stream());
auto out = do_conv(in, w, groups_ * b);
out = unflatten(out, -1, {b, c_out}, stream());
return {{out}, {static_cast<int>(out.ndim() - 2)}};
} else {
return {{do_conv(in, w, groups_)}, {-1}};
}
}
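A minimal sketch of what the new vmap rule enables from Python; the shapes and channels-last layout here are illustrative assumptions. Mapping over a stack of weights is handled inside a single convolution call instead of a Python loop.

import mlx.core as mx

x = mx.random.normal((2, 32, 32, 3))     # (N, H, W, C_in)
ws = mx.random.normal((4, 16, 3, 3, 3))  # four weights, each (C_out, kH, kW, C_in)

# The mapped weight axis is folded into the output channels of one conv2d call.
outs = mx.vmap(lambda w: mx.conv2d(x, w))(ws)
print(outs.shape)  # (4, 2, 30, 30, 16)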
bool Convolution::is_equivalent(const Primitive& other) const {
const Convolution& c_other = static_cast<const Convolution&>(other);
return padding_ == c_other.padding_ &&
@ -3080,6 +3135,8 @@ std::vector<array> GatherQMM::vjp(
auto& lhs_indices = primals[4];
auto& rhs_indices = primals[5];
bool sorted = left_sorted_ || right_sorted_;
for (auto arg : argnums) {
// gradient wrt to x
if (arg == 0) {
@ -3098,6 +3155,7 @@ std::vector<array> GatherQMM::vjp(
!transpose_,
group_size_,
bits_,
sorted,
stream()),
-3,
stream()),
@ -3478,6 +3536,45 @@ std::vector<array> Scan::vjp(
if (reduce_type_ == Scan::Sum) {
return {cumsum(cotangents[0], axis_, !reverse_, inclusive_, stream())};
} else if (reduce_type_ == Scan::LogAddExp) {
// Ref:
// https://github.com/tensorflow/tensorflow/blob/2a5910906a0e0f3dbc186ff9db6386d81a63448c/tensorflow/python/ops/math_grad.py#L1832-L1863
auto x = primals[0];
auto grad = cotangents[0];
auto results = outputs[0];
auto zero = zeros({1}, grad.dtype(), stream());
auto grad_min = array(finfo(grad.dtype()).min, grad.dtype());
// Split the incoming gradient into positive and negative parts
// in order to take logs. This is required for stable results.
auto log_abs_grad = log(abs(grad, stream()), stream());
auto log_grad_positive =
where(greater(grad, zero, stream()), log_abs_grad, grad_min, stream());
auto log_grad_negative =
where(less(grad, zero, stream()), log_abs_grad, grad_min, stream());
auto output_pos = exp(
add(logcumsumexp(
subtract(log_grad_positive, results, stream()),
axis_,
!reverse_,
inclusive_,
stream()),
x,
stream()));
auto output_neg = exp(
add(logcumsumexp(
subtract(log_grad_negative, results, stream()),
axis_,
!reverse_,
inclusive_,
stream()),
x,
stream()));
return {subtract(output_pos, output_neg, stream())};
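To spell out the sign split: for the inclusive forward scan y_i = \log \sum_{j \le i} e^{x_j}, the cotangent g maps back to the input as

\frac{\partial L}{\partial x_k} = \sum_{i \ge k} g_i\, e^{x_k - y_i} = e^{x_k} \sum_{i \ge k} g_i\, e^{-y_i},

and the inner sum is evaluated as a reverse logcumsumexp of \log\lvert g_i\rvert - y_i, once over the positive and once over the negative entries of g, so the logarithm is always taken of a non-negative quantity before the two parts are exponentiated and subtracted.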
} else if (reduce_type_ == Scan::Prod) {
auto in = primals[0];
// Find the location of the first 0 and set it to 1:
@ -4856,6 +4953,8 @@ std::vector<array> GatherMM::vjp(
int N = cotan.shape(-1);
int K = primals[0].shape(-1);
bool sorted = left_sorted_ || right_sorted_;
for (auto arg : argnums) {
if (arg == 0) {
// M X N * (K X N).T -> M X K
@ -4866,7 +4965,8 @@ std::vector<array> GatherMM::vjp(
base = reshape(base, {-1, M, K}, stream());
// g : (out_batch_shape) + (M, K)
auto g = gather_mm(cotan, bt, std::nullopt, rhs_indices, stream());
auto g =
gather_mm(cotan, bt, std::nullopt, rhs_indices, sorted, stream());
g = expand_dims(g, -3, stream());
auto gacc = scatter_add(base, lhs_indices, g, 0, stream());
@ -4881,7 +4981,8 @@ std::vector<array> GatherMM::vjp(
base = reshape(base, {-1, K, N}, stream());
// g : (out_batch_shape) + (K, N)
auto g = gather_mm(at, cotan, lhs_indices, std::nullopt, stream());
auto g =
gather_mm(at, cotan, lhs_indices, std::nullopt, sorted, stream());
g = expand_dims(g, -3, stream());
auto gacc = scatter_add(base, rhs_indices, g, 0, stream());
@ -4894,6 +4995,12 @@ std::vector<array> GatherMM::vjp(
return vjps;
}
bool GatherMM::is_equivalent(const Primitive& other) const {
const GatherMM& g_other = static_cast<const GatherMM&>(other);
return left_sorted_ == g_other.left_sorted_ &&
right_sorted_ == g_other.right_sorted_;
}
bool BlockMaskedMM::is_equivalent(const Primitive& other) const {
const BlockMaskedMM& a_other = static_cast<const BlockMaskedMM&>(other);
return (block_size_ == a_other.block_size_);

View File

@ -498,7 +498,13 @@ class BlockMaskedMM : public UnaryPrimitive {
class GatherMM : public UnaryPrimitive {
public:
explicit GatherMM(Stream stream) : UnaryPrimitive(stream) {}
explicit GatherMM(
Stream stream,
bool left_sorted = false,
bool right_sorted = false)
: UnaryPrimitive(stream),
left_sorted_(left_sorted),
right_sorted_(right_sorted) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
@ -510,7 +516,14 @@ class GatherMM : public UnaryPrimitive {
const std::vector<array>& outputs) override;
DEFINE_PRINT(GatherMM)
DEFINE_DEFAULT_IS_EQUIVALENT()
bool is_equivalent(const Primitive& other) const override;
auto state() const {
return std::make_pair(left_sorted_, right_sorted_);
}
private:
bool left_sorted_;
bool right_sorted_;
};
class BroadcastAxes : public UnaryPrimitive {
@ -698,6 +711,7 @@ class Convolution : public UnaryPrimitive {
const std::vector<int>& argnums,
const std::vector<array>& outputs) override;
DEFINE_VMAP()
DEFINE_PRINT(Convolution)
bool is_equivalent(const Primitive& other) const override;
auto state() const {
@ -1578,11 +1592,19 @@ class QuantizedMatmul : public UnaryPrimitive {
class GatherQMM : public UnaryPrimitive {
public:
explicit GatherQMM(Stream stream, int group_size, int bits, bool transpose)
explicit GatherQMM(
Stream stream,
int group_size,
int bits,
bool transpose,
bool left_sorted = false,
bool right_sorted = false)
: UnaryPrimitive(stream),
group_size_(group_size),
bits_(bits),
transpose_(transpose) {}
transpose_(transpose),
left_sorted_(left_sorted),
right_sorted_(right_sorted) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
@ -1592,13 +1614,16 @@ class GatherQMM : public UnaryPrimitive {
DEFINE_PRINT(GatherQMM)
bool is_equivalent(const Primitive& other) const override;
auto state() const {
return std::make_tuple(group_size_, bits_, transpose_);
return std::make_tuple(
group_size_, bits_, transpose_, left_sorted_, right_sorted_);
}
private:
int group_size_;
int bits_;
bool transpose_;
bool left_sorted_;
bool right_sorted_;
};
class RandomBits : public UnaryPrimitive {
@ -1728,7 +1753,7 @@ class Round : public UnaryPrimitive {
class Scan : public UnaryPrimitive {
public:
enum ReduceType { Max, Min, Sum, Prod };
enum ReduceType { Max, Min, Sum, Prod, LogAddExp };
explicit Scan(
Stream stream,
@ -1763,6 +1788,9 @@ class Scan : public UnaryPrimitive {
case Max:
os << "Max";
break;
case LogAddExp:
os << "Logaddexp";
break;
}
}
bool is_equivalent(const Primitive& other) const override;

View File

@ -5,6 +5,7 @@
#include <sstream>
#include <vector>
#include "mlx/dtype_utils.h"
#include "mlx/types/limits.h"
#include "mlx/utils.h"
@ -224,37 +225,7 @@ void print_array(std::ostream& os, const array& a) {
} // namespace
std::ostream& operator<<(std::ostream& os, const Dtype& dtype) {
switch (dtype) {
case bool_:
return os << "bool";
case uint8:
return os << "uint8";
case uint16:
return os << "uint16";
case uint32:
return os << "uint32";
case uint64:
return os << "uint64";
case int8:
return os << "int8";
case int16:
return os << "int16";
case int32:
return os << "int32";
case int64:
return os << "int64";
case float16:
return os << "float16";
case float32:
return os << "float32";
case float64:
return os << "float64";
case bfloat16:
return os << "bfloat16";
case complex64:
return os << "complex64";
}
return os;
return os << dtype_to_string(dtype);
}
std::ostream& operator<<(std::ostream& os, const Dtype::Kind& k) {
@ -277,50 +248,7 @@ std::ostream& operator<<(std::ostream& os, const Dtype::Kind& k) {
std::ostream& operator<<(std::ostream& os, array a) {
a.eval();
switch (a.dtype()) {
case bool_:
print_array<bool>(os, a);
break;
case uint8:
print_array<uint8_t>(os, a);
break;
case uint16:
print_array<uint16_t>(os, a);
break;
case uint32:
print_array<uint32_t>(os, a);
break;
case uint64:
print_array<uint64_t>(os, a);
break;
case int8:
print_array<int8_t>(os, a);
break;
case int16:
print_array<int16_t>(os, a);
break;
case int32:
print_array<int32_t>(os, a);
break;
case int64:
print_array<int64_t>(os, a);
break;
case float16:
print_array<float16_t>(os, a);
break;
case bfloat16:
print_array<bfloat16_t>(os, a);
break;
case float32:
print_array<float>(os, a);
break;
case float64:
print_array<double>(os, a);
break;
case complex64:
print_array<complex64_t>(os, a);
break;
}
MLX_SWITCH_ALL_TYPES(a.dtype(), CTYPE, print_array<CTYPE>(os, a));
return os;
}
@ -387,36 +315,8 @@ void set_iinfo_limits(int64_t& min, uint64_t& max) {
}
iinfo::iinfo(Dtype dtype) : dtype(dtype) {
switch (dtype) {
case int8:
set_iinfo_limits<int8_t>(min, max);
break;
case uint8:
set_iinfo_limits<uint8_t>(min, max);
break;
case int16:
set_iinfo_limits<int16_t>(min, max);
break;
case uint16:
set_iinfo_limits<uint16_t>(min, max);
break;
case int32:
set_iinfo_limits<int32_t>(min, max);
break;
case uint32:
set_iinfo_limits<uint32_t>(min, max);
break;
case int64:
set_iinfo_limits<int64_t>(min, max);
break;
case uint64:
set_iinfo_limits<uint64_t>(min, max);
break;
default:
std::ostringstream msg;
msg << "[iinfo] dtype " << dtype << " is not integral.";
throw std::invalid_argument(msg.str());
}
MLX_SWITCH_INT_TYPES_CHECKED(
dtype, "[iinfo]", CTYPE, set_iinfo_limits<CTYPE>(min, max));
}
} // namespace mlx::core

View File

@ -3,8 +3,8 @@
#pragma once
#define MLX_VERSION_MAJOR 0
#define MLX_VERSION_MINOR 24
#define MLX_VERSION_PATCH 2
#define MLX_VERSION_MINOR 25
#define MLX_VERSION_PATCH 0
#define MLX_VERSION_NUMERIC \
(100000 * MLX_VERSION_MAJOR + 1000 * MLX_VERSION_MINOR + MLX_VERSION_PATCH)

View File

@ -1202,6 +1202,28 @@ void init_array(nb::module_& m) {
nb::kw_only(),
"stream"_a = nb::none(),
"See :func:`max`.")
.def(
"logcumsumexp",
[](const mx::array& a,
std::optional<int> axis,
bool reverse,
bool inclusive,
mx::StreamOrDevice s) {
if (axis) {
return mx::logcumsumexp(a, *axis, reverse, inclusive, s);
} else {
// TODO: Implement that in the C++ API as well. See concatenate
// above.
return mx::logcumsumexp(
mx::reshape(a, {-1}, s), 0, reverse, inclusive, s);
}
},
"axis"_a = nb::none(),
nb::kw_only(),
"reverse"_a = false,
"inclusive"_a = true,
"stream"_a = nb::none(),
"See :func:`logcumsumexp`.")
.def(
"logsumexp",
[](const mx::array& a,

View File

@ -1,8 +1,8 @@
// Copyright © 2024 Apple Inc.
#include <nanobind/nanobind.h>
#include <nanobind/stl/map.h>
#include <nanobind/stl/optional.h>
#include <nanobind/stl/string.h>
#include <nanobind/stl/unordered_map.h>
#include <nanobind/stl/vector.h>
#include <fstream>
@ -16,8 +16,7 @@ namespace mx = mlx::core;
namespace nb = nanobind;
using namespace nb::literals;
std::pair<std::vector<mx::array>, std::map<std::string, mx::array>>
validate_and_extract_inputs(
std::pair<mx::Args, mx::Kwargs> validate_and_extract_inputs(
const nb::args& args,
const nb::kwargs& kwargs,
const std::string& prefix) {
@ -30,8 +29,8 @@ validate_and_extract_inputs(
"and/or dictionary of arrays.");
}
};
std::vector<mx::array> args_;
std::map<std::string, mx::array> kwargs_;
mx::Args args_;
mx::Kwargs kwargs_;
if (args.size() == 0) {
// No args so kwargs must be keyword arrays
maybe_throw(nb::try_cast(kwargs, kwargs_));
@ -81,9 +80,7 @@ class PyFunctionExporter {
void close() {
exporter_.close();
}
void operator()(
const std::vector<mx::array>& args,
const std::map<std::string, mx::array>& kwargs) {
void operator()(const mx::Args& args, const mx::Kwargs& kwargs) {
exporter_(args, kwargs);
}
@ -98,9 +95,12 @@ int py_function_exporter_tp_traverse(
PyObject* self,
visitproc visit,
void* arg) {
Py_VISIT(Py_TYPE(self));
if (!nb::inst_ready(self)) {
return 0;
}
auto* p = nb::inst_ptr<PyFunctionExporter>(self);
Py_VISIT(p->dep_.ptr());
Py_VISIT(Py_TYPE(self));
return 0;
}
@ -109,9 +109,8 @@ PyType_Slot py_function_exporter_slots[] = {
{0, 0}};
auto wrap_export_function(nb::callable fun) {
return [fun = std::move(fun)](
const std::vector<mx::array>& args_,
const std::map<std::string, mx::array>& kwargs_) {
return
[fun = std::move(fun)](const mx::Args& args_, const mx::Kwargs& kwargs_) {
auto kwargs = nb::dict();
kwargs.update(nb::cast(kwargs_));
auto args = nb::tuple(nb::cast(args_));

View File

@ -16,12 +16,12 @@ struct gc_func {
};
int gc_func_tp_traverse(PyObject* self, visitproc visit, void* arg) {
Py_VISIT(Py_TYPE(self));
gc_func* w = (gc_func*)self;
Py_VISIT(w->func);
for (auto d : w->deps) {
Py_VISIT(d);
}
Py_VISIT(Py_TYPE(self));
return 0;
};

View File

@ -2382,6 +2382,43 @@ void init_ops(nb::module_& m) {
Returns:
array: The output array with the corresponding axes reduced.
)pbdoc");
m.def(
"logcumsumexp",
[](const mx::array& a,
std::optional<int> axis,
bool reverse,
bool inclusive,
mx::StreamOrDevice s) {
if (axis) {
return mx::logcumsumexp(a, *axis, reverse, inclusive, s);
} else {
return mx::logcumsumexp(
mx::reshape(a, {-1}, s), 0, reverse, inclusive, s);
}
},
nb::arg(),
"axis"_a = nb::none(),
nb::kw_only(),
"reverse"_a = false,
"inclusive"_a = true,
"stream"_a = nb::none(),
nb::sig(
"def logcumsumexp(a: array, /, axis: Optional[int] = None, *, reverse: bool = False, inclusive: bool = True, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Return the cumulative logsumexp of the elements along the given axis.
Args:
a (array): Input array
axis (int, optional): Optional axis to compute the cumulative logsumexp
over. If unspecified the cumulative logsumexp of the flattened array is
returned.
reverse (bool): Perform the cumulative logsumexp in reverse.
inclusive (bool): The i-th element of the output includes the i-th
element of the input.
Returns:
array: The output array.
)pbdoc");
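A quick usage sketch of the binding above (values shown are approximate):

import mlx.core as mx

x = mx.zeros(3)
print(mx.logcumsumexp(x))                # ~ [0.0, log 2, log 3]
print(mx.logcumsumexp(x, reverse=True))  # ~ [log 3, log 2, 0.0]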
m.def(
"logsumexp",
[](const mx::array& a,
@ -4213,9 +4250,10 @@ void init_ops(nb::module_& m) {
"group_size"_a = 64,
"bits"_a = 4,
nb::kw_only(),
"sorted_indices"_a = false,
"stream"_a = nb::none(),
nb::sig(
"def gather_qmm(x: array, w: array, /, scales: array, biases: array, lhs_indices: Optional[array] = None, rhs_indices: Optional[array] = None, transpose: bool = True, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> array"),
"def gather_qmm(x: array, w: array, /, scales: array, biases: array, lhs_indices: Optional[array] = None, rhs_indices: Optional[array] = None, transpose: bool = True, group_size: int = 64, bits: int = 4, *, sorted_indices: bool = False, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Perform quantized matrix multiplication with matrix-level gather.
@ -4241,6 +4279,8 @@ void init_ops(nb::module_& m) {
shares a scale and bias. Default: ``64``.
bits (int, optional): The number of bits occupied by each element in
``w``. Default: ``4``.
sorted_indices (bool, optional): May allow a faster implementation
if the passed indices are sorted. Default: ``False``.
Returns:
array: The result of the multiplication of ``x`` with ``w``
@ -4427,9 +4467,10 @@ void init_ops(nb::module_& m) {
"lhs_indices"_a = nb::none(),
"rhs_indices"_a = nb::none(),
nb::kw_only(),
"sorted_indices"_a = false,
"stream"_a = nb::none(),
nb::sig(
"def gather_mm(a: array, b: array, /, lhs_indices: array, rhs_indices: array, *, stream: Union[None, Stream, Device] = None) -> array"),
"def gather_mm(a: array, b: array, /, lhs_indices: array, rhs_indices: array, *, sorted_indices: bool = False, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Matrix multiplication with matrix-level gather.
@ -4448,11 +4489,16 @@ void init_ops(nb::module_& m) {
For ``b`` with shape ``(B1, B2, ..., BS, M, K)``, ``rhs_indices``
contains indices from the range ``[0, B1 * B2 * ... * BS)``
If only one of the index arrays is passed and it is sorted, the ``sorted_indices``
flag can be passed for a possibly faster implementation.
Args:
a (array): Input array.
b (array): Input array.
lhs_indices (array, optional): Integer indices for ``a``. Default: ``None``
rhs_indices (array, optional): Integer indices for ``b``. Default: ``None``
sorted_indices (bool, optional): May allow a faster implementation
if the passed indices are sorted. Default: ``False``.
Returns:
array: The output array.
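Treating ``a`` and ``b`` as stacks of ``(M, K)`` and ``(K, N)`` matrices, the reference semantics are, for every broadcast position p of the index arrays,

\mathrm{out}_p = A_{\mathrm{lhs\_indices}[p]}\; B_{\mathrm{rhs\_indices}[p]},

where an omitted index array defaults to an implicit arange over that operand's batch dimensions, broadcast against the provided indices.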

View File

@ -960,6 +960,11 @@ class PyCustomFunction {
};
int py_custom_function_tp_traverse(PyObject* self, visitproc visit, void* arg) {
Py_VISIT(Py_TYPE(self));
if (!nb::inst_ready(self)) {
return 0;
}
auto* p = nb::inst_ptr<PyCustomFunction>(self);
nb::handle v = nb::find(p->fun_);
Py_VISIT(v.ptr());
@ -975,7 +980,6 @@ int py_custom_function_tp_traverse(PyObject* self, visitproc visit, void* arg) {
nb::handle v = nb::find(*(p->vmap_fun_));
Py_VISIT(v.ptr());
}
Py_VISIT(Py_TYPE(self));
return 0;
}
int py_custom_function_tp_clear(PyObject* self) {

View File

@ -1508,6 +1508,7 @@ class TestArray(mlx_tests.MLXTestCase):
("prod", 1),
("min", 1),
("max", 1),
("logcumsumexp", 1),
("logsumexp", 1),
("mean", 1),
("var", 1),

View File

@ -1108,7 +1108,7 @@ class TestBlas(mlx_tests.MLXTestCase):
lhs_indices_ = mx.broadcast_to(lhs_indices, (3, 2))
rhs_indices_ = mx.broadcast_to(rhs_indices, (3, 2))
M = a.shape[-2]
N = b.shape[-2]
N = b.shape[-1]
K = a.shape[-1]
a = a.reshape((-1, M, K))

View File

@ -194,6 +194,11 @@ class TestFFT(mlx_tests.MLXTestCase):
r_np = np.fft.ifft(segment, n=n_fft)
self.assertTrue(np.allclose(r, r_np, atol=1e-5, rtol=1e-5))
def test_fft_throws(self):
x = mx.array(3.0)
with self.assertRaises(ValueError):
mx.fft.irfftn(x)
if __name__ == "__main__":
unittest.main()

View File

@ -1857,6 +1857,30 @@ class TestOps(mlx_tests.MLXTestCase):
y = mx.as_strided(x, (x.size,), (-1,), x.size - 1)
self.assertTrue(mx.array_equal(y, x[::-1]))
def test_logcumsumexp(self):
npop = np.logaddexp.accumulate
mxop = mx.logcumsumexp
a_npy = np.random.randn(32, 32, 32).astype(np.float32)
a_mlx = mx.array(a_npy)
for axis in (0, 1, 2):
c_npy = npop(a_npy, axis=axis)
c_mlx = mxop(a_mlx, axis=axis)
self.assertTrue(np.allclose(c_npy, c_mlx, rtol=1e-3, atol=1e-3))
edge_cases_npy = [
np.float32([-float("inf")] * 8),
np.float32([-float("inf"), 0, -float("inf")]),
np.float32([-float("inf"), float("inf"), -float("inf")]),
]
edge_cases_mlx = [mx.array(a) for a in edge_cases_npy]
for a_npy, a_mlx in zip(edge_cases_npy, edge_cases_mlx):
c_npy = npop(a_npy, axis=0)
c_mlx = mxop(a_mlx, axis=0)
self.assertTrue(np.allclose(c_npy, c_mlx, rtol=1e-3, atol=1e-3))
def test_scans(self):
a_npy = np.random.randn(32, 32, 32).astype(np.float32)
a_mlx = mx.array(a_npy)
@ -2910,6 +2934,35 @@ class TestOps(mlx_tests.MLXTestCase):
out = a[::-1]
self.assertTrue(mx.array_equal(out[-1, :], a[0, :]))
def test_complex_ops(self):
x = mx.array(
[
3.0 + 4.0j,
-5.0 + 12.0j,
-8.0 + 0.0j,
0.0 + 9.0j,
0.0 + 0.0j,
]
)
ops = ["arccos", "arcsin", "arctan", "square", "sqrt"]
for op in ops:
with self.subTest(op=op):
np_op = getattr(np, op)
mx_op = getattr(mx, op)
self.assertTrue(np.allclose(mx_op(x), np_op(x)))
x = mx.array(
[
3.0 + 4.0j,
-5.0 + 12.0j,
-8.0 + 0.0j,
0.0 + 9.0j,
9.0 + 1.0j,
]
)
self.assertTrue(np.allclose(mx.rsqrt(x), 1.0 / np.sqrt(x)))
if __name__ == "__main__":
unittest.main()

View File

@ -174,12 +174,14 @@ class TestQuantized(mlx_tests.MLXTestCase):
tests = product(
[128, 64, 32], # group_size
[2, 3, 4, 6, 8], # bits
[128, 256], # M
[32, 128, 256], # M
[128, 256, 67], # N
[0, 1, 3, 8], # B
)
for group_size, bits, M, N, B in tests:
with self.subTest(shape=(B, M, N), group_size=group_size, bits=bits):
if M < group_size:
continue
x_shape = (1, N) if B == 0 else (B, 1, N)
w_shape = (N, M) if B == 0 else (B, N, M)
x = mx.random.normal(shape=x_shape, key=k1)
@ -448,6 +450,7 @@ class TestQuantized(mlx_tests.MLXTestCase):
)
for kwargs in inputs:
test_shape(1, 32, 128, **kwargs)
test_shape(32, 32, 256, **kwargs)
test_shape(1, 32, 256, **kwargs)
test_shape(32, 256, 32, transpose=False, **kwargs)
@ -486,6 +489,66 @@ class TestQuantized(mlx_tests.MLXTestCase):
g2 = mx.grad(f_test)(x, qw, s, b, lhs_indices, rhs_indices)
self.assertTrue(mx.allclose(g1, g2, atol=1e-4))
def test_gather_qmm_sorted(self):
def quantize(w, transpose=True, group_size=64, bits=4):
qw, s, b = mx.quantize(w, group_size=group_size, bits=bits)
w_hat = mx.dequantize(qw, s, b, group_size=group_size, bits=bits)
if transpose:
w_hat = w_hat.swapaxes(-1, -2)
return w_hat, qw, s, b
def gather_sort(x, indices):
N, M = indices.shape
indices = indices.flatten()
order = mx.argsort(indices)
inv_order = mx.argsort(order)
return x.flatten(0, -3)[order // M], indices[order], inv_order
def scatter_unsort(x, inv_order, shape=None):
x = x[inv_order]
if shape is not None:
x = mx.unflatten(x, 0, shape)
return x
parameters = [
# L, K, D, E, I, transpose
(128, 1024, 1024, 32, 4, True),
(128, 1024, 544, 32, 4, True),
(433, 1024, 1024, 32, 4, True),
(433, 1024, 555, 32, 4, True),
(433, 2048, 1024, 32, 4, True),
(128, 1024, 1024, 32, 4, False),
(128, 1024, 544, 32, 4, False),
(433, 1024, 1024, 32, 4, False),
(433, 1024, 544, 32, 4, False),
(433, 1024, 555, 32, 4, False),
(433, 2048, 1024, 32, 4, False),
]
for L, K, D, E, I, transpose in parameters:
K, D = (K, D) if transpose else (D, K)
ishape = (L, I)
xshape = (L, 1, 1, K)
wshape = (E, D, K) if transpose else (E, K, D)
indices = (mx.random.uniform(shape=ishape) * E).astype(mx.uint32)
x = mx.random.normal(xshape) / K**0.5
w = mx.random.normal(wshape) / K**0.5
w, *wq = quantize(w, transpose=transpose)
y1 = mx.gather_mm(x, w, rhs_indices=indices)
y2 = mx.gather_qmm(x, *wq, transpose=transpose, rhs_indices=indices)
xs, idx, inv_order = gather_sort(x, indices)
y3 = mx.gather_mm(xs, w, rhs_indices=idx, sorted_indices=True)
y4 = mx.gather_qmm(
xs, *wq, rhs_indices=idx, transpose=transpose, sorted_indices=True
)
y3 = scatter_unsort(y3, inv_order, indices.shape)
y4 = scatter_unsort(y4, inv_order, indices.shape)
self.assertTrue(mx.allclose(y1, y2, atol=1e-5))
self.assertTrue(mx.allclose(y1, y3, atol=1e-5))
self.assertTrue(mx.allclose(y1, y4, atol=1e-5))
if __name__ == "__main__":
unittest.main()

View File

@ -669,6 +669,57 @@ class TestVmap(mlx_tests.MLXTestCase):
self.assertEqual(mx.vmap(fun, in_axes=(1,))(x).shape, (3, 8))
self.assertEqual(mx.vmap(fun, in_axes=(2,))(x).shape, (4, 6))
def test_vmap_conv(self):
# vmap input only
x = mx.random.uniform(shape=(2, 2, 5, 4))
w = mx.random.uniform(shape=(8, 3, 4))
expected = mx.stack([mx.conv1d(xi, w) for xi in x])
out = mx.vmap(mx.conv1d, in_axes=(0, None))(x, w)
self.assertTrue(mx.allclose(expected, out))
x = mx.moveaxis(x, 0, 2)
out = mx.vmap(mx.conv1d, in_axes=(2, None))(x, w)
self.assertTrue(mx.allclose(expected, out))
# vmap weights only
x = mx.random.uniform(shape=(2, 5, 4))
w = mx.random.uniform(shape=(3, 8, 3, 4))
expected = mx.stack([mx.conv1d(x, wi) for wi in w])
out = mx.vmap(mx.conv1d, in_axes=(None, 0))(x, w)
self.assertTrue(mx.allclose(expected, out))
w = mx.moveaxis(w, 0, 1)
out = mx.vmap(mx.conv1d, in_axes=(None, 1))(x, w)
self.assertTrue(mx.allclose(expected, out))
# vmap weights and input
x = mx.random.uniform(shape=(3, 2, 5, 4))
w = mx.random.uniform(shape=(3, 8, 3, 4))
expected = mx.stack([mx.conv1d(xi, wi) for xi, wi in zip(x, w)])
out = mx.vmap(mx.conv1d, in_axes=(0, 0))(x, w)
self.assertTrue(mx.allclose(expected, out))
x = mx.random.uniform(shape=(2, 3, 5, 4))
w = mx.random.uniform(shape=(8, 3, 4, 3))
expected = mx.stack([mx.conv1d(x[:, i], w[..., i]) for i in range(3)])
out = mx.vmap(mx.conv1d, in_axes=(1, 3))(x, w)
self.assertTrue(mx.allclose(expected, out))
# Test with groups
x = mx.random.uniform(shape=(3, 2, 5, 8))
w = mx.random.uniform(shape=(3, 2, 3, 4))
def gconv(x, w):
return mx.conv1d(x, w, groups=2)
expected = mx.stack([gconv(xi, wi) for xi, wi in zip(x, w)])
out = mx.vmap(gconv, in_axes=(0, 0))(x, w)
self.assertTrue(mx.allclose(expected, out))
if __name__ == "__main__":
unittest.main()

View File

@ -97,8 +97,7 @@ TEST_CASE("test export primitives with state") {
TEST_CASE("test export functions with kwargs") {
std::string file_path = get_temp_file("model.mlxfn");
auto fun =
[](const std::map<std::string, array>& kwargs) -> std::vector<array> {
auto fun = [](const Kwargs& kwargs) -> std::vector<array> {
return {kwargs.at("x") + kwargs.at("y")};
};

View File

@ -3874,3 +3874,41 @@ TEST_CASE("test contiguous") {
CHECK(x.flags().col_contiguous);
CHECK_EQ(x.strides(), decltype(x.strides()){1, 2});
}
TEST_CASE("test bitwise shift operations") {
std::vector<Dtype> dtypes = {
int8, int16, int32, int64, uint8, uint16, uint32, uint64};
for (const auto& dtype : dtypes) {
array x = full({4}, 1, dtype);
array y = full({4}, 2, dtype);
auto left_shift_result = left_shift(x, y);
CHECK_EQ(left_shift_result.dtype(), dtype);
CHECK(array_equal(left_shift_result, array({4, 4, 4, 4}, dtype))
.item<bool>());
auto right_shift_result = right_shift(full({4}, 4, dtype), y);
CHECK_EQ(right_shift_result.dtype(), dtype);
CHECK(array_equal(right_shift_result, full({4}, 1, dtype)).item<bool>());
}
array x = array({127, -128}, int8);
array y = array({1, 1}, int8);
auto left_shift_result = left_shift(x, y);
auto right_shift_result = right_shift(x, y);
CHECK(array_equal(left_shift_result, array({-2, 0}, int8)).item<bool>());
CHECK(array_equal(right_shift_result, array({63, -64}, int8)).item<bool>());
array x_bool = full({4}, true, bool_);
array y_bool = full({4}, true, bool_);
auto left_shift_bool_result = left_shift(x_bool, y_bool);
auto right_shift_bool_result = right_shift(x_bool, y_bool);
CHECK_EQ(left_shift_bool_result.dtype(), uint8);
CHECK(array_equal(left_shift_bool_result, full({4}, 2, uint8)).item<bool>());
CHECK_EQ(right_shift_bool_result.dtype(), uint8);
CHECK(array_equal(right_shift_bool_result, full({4}, 0, uint8)).item<bool>());
}