Merge branch 'main' into stft

Param Thakkar 2025-04-23 00:02:52 +05:30 committed by GitHub
commit a963a15b8d
57 changed files with 3477 additions and 1014 deletions


@@ -251,7 +251,7 @@ jobs:
name: Install Python package
command: |
source env/bin/activate
-MACOSX_DEPLOYMENT_TARGET="" DEV_RELEASE=1 \
+env -u MACOSX_DEPLOYMENT_TARGET DEV_RELEASE=1 \
CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
pip install . -v
- run:


@@ -0,0 +1,74 @@
# Copyright © 2025 Apple Inc.
import mlx.core as mx
from time_utils import time_fn
N = 1024
D = 1024
M = 1024
E = 32
I = 4
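# Benchmark setup: N inputs of width D, E expert weight matrices of shape
# (M, D) (and (D, M) for the second layer), with I experts selected per input
# via `indices`.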
def gather_sort(x, indices):
N, M = indices.shape
indices = indices.flatten()
order = mx.argsort(indices)
inv_order = mx.argsort(order)
return x.flatten(0, -3)[order // M], indices[order], inv_order
def scatter_unsort(x, inv_order, shape=None):
x = x[inv_order]
if shape is not None:
x = mx.unflatten(x, 0, shape)
return x
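# Pure-Python reference that emulates gather_mm with an explicit loop over the
# sorted indices; kept for checking/illustration, it is not timed below.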
def gather_mm_simulate(x, w, indices):
x, idx, inv_order = gather_sort(x, indices)
for i in range(2):
y = mx.concatenate([x[i] @ w[j].T for i, j in enumerate(idx.tolist())], axis=0)
x = y[:, None]
x = scatter_unsort(x, inv_order, indices.shape)
return x
def time_gather_mm():
x = mx.random.normal((N, 1, 1, D)) / 1024**0.5
w1 = mx.random.normal((E, M, D)) / 1024**0.5
w2 = mx.random.normal((E, D, M)) / 1024**0.5
indices = (mx.random.uniform(shape=(N, I)) * E).astype(mx.uint32)
sorted_indices = mx.sort(indices.flatten()).reshape(N, I)
mx.eval(x, w1, w2, indices, sorted_indices)
def gather_mm(x, w1, w2, indices, sort):
idx = indices
inv_order = None
if sort:
x, idx, inv_order = gather_sort(x, indices)
x = mx.gather_mm(x, w1.swapaxes(-1, -2), rhs_indices=idx, sorted_indices=sort)
x = mx.gather_mm(x, w2.swapaxes(-1, -2), rhs_indices=idx, sorted_indices=sort)
if sort:
x = scatter_unsort(x, inv_order, indices.shape)
return x
time_fn(gather_mm, x, w1, w2, indices, False)
time_fn(gather_mm, x, w1, w2, sorted_indices, False)
time_fn(gather_mm, x, w1, w2, indices, True)
x = mx.random.normal((N * I, D)) / 1024**0.5
w1 = mx.random.normal((M, D)) / 1024**0.5
w2 = mx.random.normal((D, M)) / 1024**0.5
mx.eval(x, w1, w2)
def equivalent_matmul(x, w1, w2):
x = x @ w1.T
x = x @ w2.T
return x
time_fn(equivalent_matmul, x, w1, w2)
if __name__ == "__main__":
time_gather_mm()


@@ -0,0 +1,84 @@
# Copyright © 2025 Apple Inc.
import mlx.core as mx
from time_utils import time_fn
N = 1024
D = 1024
M = 1024
E = 32
I = 4
def gather_sort(x, indices):
N, M = indices.shape
indices = indices.flatten()
order = mx.argsort(indices)
inv_order = mx.argsort(order)
return x.flatten(0, -3)[order // M], indices[order], inv_order
def scatter_unsort(x, inv_order, shape=None):
x = x[inv_order]
if shape is not None:
x = mx.unflatten(x, 0, shape)
return x
def gather_mm_simulate(x, w, indices):
x, idx, inv_order = gather_sort(x, indices)
for i in range(2):
y = mx.concatenate(
[
mx.quantized_matmul(x[i], w[0][j], w[1][j], w[2][j], transpose=True)
for i, j in enumerate(idx.tolist())
],
axis=0,
)
x = y[:, None]
x = scatter_unsort(x, inv_order, indices.shape)
return x
def time_gather_qmm():
x = mx.random.normal((N, 1, 1, D)) / 1024**0.5
w1 = mx.random.normal((E, M, D)) / 1024**0.5
w2 = mx.random.normal((E, D, M)) / 1024**0.5
w1 = mx.quantize(w1)
w2 = mx.quantize(w2)
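# mx.quantize returns a (w_q, scales, biases) tuple; it is splatted below into
# mx.gather_qmm / mx.quantized_matmul as *w1 and *w2.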
indices = (mx.random.uniform(shape=(N, I)) * E).astype(mx.uint32)
sorted_indices = mx.sort(indices.flatten()).reshape(N, I)
mx.eval(x, w1, w2, indices, sorted_indices)
def gather_mm(x, w1, w2, indices, sort):
idx = indices
inv_order = None
if sort:
x, idx, inv_order = gather_sort(x, indices)
x = mx.gather_qmm(x, *w1, transpose=True, rhs_indices=idx, sorted_indices=sort)
x = mx.gather_qmm(x, *w2, transpose=True, rhs_indices=idx, sorted_indices=sort)
if sort:
x = scatter_unsort(x, inv_order, indices.shape)
return x
time_fn(gather_mm, x, w1, w2, indices, False)
time_fn(gather_mm, x, w1, w2, sorted_indices, False)
time_fn(gather_mm, x, w1, w2, indices, True)
x = mx.random.normal((N * I, D)) / 1024**0.5
w1 = mx.random.normal((M, D)) / 1024**0.5
w2 = mx.random.normal((D, M)) / 1024**0.5
w1 = mx.quantize(w1)
w2 = mx.quantize(w2)
mx.eval(x, w1, w2)
def equivalent_matmul(x, w1, w2):
x = mx.quantized_matmul(x, *w1, transpose=True)
x = mx.quantized_matmul(x, *w2, transpose=True)
return x
time_fn(equivalent_matmul, x, w1, w2)
if __name__ == "__main__":
time_gather_qmm()


@@ -38,6 +38,7 @@ Array
array.log10
array.log1p
array.log2
+array.logcumsumexp
array.logsumexp
array.max
array.mean


@@ -103,6 +103,7 @@ Operations
log10
log1p
logaddexp
+logcumsumexp
logical_not
logical_and
logical_or
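The documentation entries above (array.logcumsumexp and the logcumsumexp op) cover the new op added in this commit. As a rough, hedged illustration (assuming the op takes an axis argument like the other cumulative ops in these tables), it computes a numerically stable running log-sum-exp:

import mlx.core as mx

x = mx.array([0.0, 1.0, 2.0, 3.0])
y = mx.logcumsumexp(x, axis=0)
# Reference definition: the log of the cumulative sum of exponentials.
ref = mx.log(mx.cumsum(mx.exp(x), axis=0))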


@@ -5,6 +5,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp
+${CMAKE_CURRENT_SOURCE_DIR}/dtype_utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/export.cpp
${CMAKE_CURRENT_SOURCE_DIR}/einsum.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp


@@ -339,11 +339,11 @@ class array {
return allocator::allocator().size(buffer());
}
-// Return a copy of the shared pointer
-// to the array::Data struct
-std::shared_ptr<Data> data_shared_ptr() const {
+// Return the shared pointer to the array::Data struct
+const std::shared_ptr<Data>& data_shared_ptr() const {
return array_desc_->data;
}
// Return a raw pointer to the arrays data
template <typename T>
T* data() {


@@ -1,6 +1,7 @@
target_sources(
mlx
-PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
+PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/broadcasting.cpp
+${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
${CMAKE_CURRENT_SOURCE_DIR}/common.cpp
${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp


@@ -0,0 +1,24 @@
// Copyright © 2024 Apple Inc.
#include "mlx/backend/common/utils.h"
namespace mlx::core {
void broadcast(const array& in, array& out) {
if (out.size() == 0) {
out.set_data(nullptr);
return;
}
Strides strides(out.ndim(), 0);
int diff = out.ndim() - in.ndim();
for (int i = in.ndim() - 1; i >= 0; --i) {
strides[i + diff] = (in.shape()[i] == 1) ? 0 : in.strides()[i];
}
auto flags = in.flags();
if (out.size() > in.size()) {
flags.row_contiguous = flags.col_contiguous = false;
}
out.copy_shared_buffer(in, strides, flags, in.data_size());
}
} // namespace mlx::core
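For illustration, the zero-stride rule that broadcast() applies, sketched in plain Python (the helper name is mine, not an MLX API): size-1 dimensions and any leading dimensions the input lacks get stride 0, so the output aliases the input buffer instead of copying it.

def broadcast_strides(in_shape, in_strides, out_shape):
    # Mirrors the C++ loop above: stride 0 revisits the same data along
    # broadcast dimensions.
    diff = len(out_shape) - len(in_shape)
    strides = [0] * len(out_shape)
    for i in range(len(in_shape) - 1, -1, -1):
        strides[i + diff] = 0 if in_shape[i] == 1 else in_strides[i]
    return strides

# Broadcasting a (1, 4) row (row-major strides (4, 1)) to (3, 4) gives [0, 1].
assert broadcast_strides((1, 4), (4, 1), (3, 4)) == [0, 1]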


@@ -0,0 +1,11 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include "mlx/array.h"
namespace mlx::core {
void broadcast(const array& in, array& out);
} // namespace mlx::core


@@ -1,6 +1,7 @@
// Copyright © 2024 Apple Inc.
#include <cassert>
+#include "mlx/backend/common/broadcasting.h"
#include "mlx/backend/common/utils.h"
#include "mlx/primitives.h"
@@ -42,23 +43,6 @@ void AsStrided::eval(const std::vector<array>& inputs, array& out) {
return out.copy_shared_buffer(in, strides_, flags, data_size, offset_);
}
-void broadcast(const array& in, array& out) {
-if (out.size() == 0) {
-out.set_data(nullptr);
-return;
-}
-Strides strides(out.ndim(), 0);
-int diff = out.ndim() - in.ndim();
-for (int i = in.ndim() - 1; i >= 0; --i) {
-strides[i + diff] = (in.shape()[i] == 1) ? 0 : in.strides()[i];
-}
-auto flags = in.flags();
-if (out.size() > in.size()) {
-flags.row_contiguous = flags.col_contiguous = false;
-}
-out.copy_shared_buffer(in, strides, flags, in.data_size());
-}
void Broadcast::eval(const std::vector<array>& inputs, array& out) {
broadcast(inputs[0], out);
}


@@ -3,6 +3,7 @@
#include <cassert>
#include "mlx/backend/common/utils.h"
+#include "mlx/backend/cpu/binary_ops.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/simd/simd.h"
@@ -226,6 +227,16 @@ void scan_dispatch(
scan_op<T, U>(in, out, axis, reverse, inclusive, op, init);
break;
}
+case Scan::LogAddExp: {
+auto op = [](U a, T b) {
+return detail::LogAddExp{}(a, static_cast<U>(b));
+};
+auto init = (issubdtype(in.dtype(), floating))
+? static_cast<U>(-std::numeric_limits<float>::infinity())
+: std::numeric_limits<U>::min();
+scan_op<T, U>(in, out, axis, reverse, inclusive, op, init);
+break;
+}
}
}
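The new Scan::LogAddExp case accumulates a running log-sum-exp with a -inf initial value for floating types. A plain-Python sketch of the same recurrence, for illustration only:

import math

def logaddexp(a, b):
    # log(exp(a) + exp(b)) without overflowing exp()
    if a == -math.inf:
        return b
    if b == -math.inf:
        return a
    return max(a, b) + math.log1p(math.exp(-abs(a - b)))

def logcumsumexp(xs):
    acc, out = -math.inf, []  # -inf matches the kernel's init for floats
    for x in xs:
        acc = logaddexp(acc, x)
        out.append(acc)
    return out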


@@ -61,6 +61,7 @@ if(MLX_METAL_JIT)
kernels/steel/gemm/transforms.h)
make_jit_source(steel/gemm/kernels/steel_gemm_fused)
make_jit_source(steel/gemm/kernels/steel_gemm_masked kernels/steel/defines.h)
+make_jit_source(steel/gemm/kernels/steel_gemm_gather)
make_jit_source(steel/gemm/kernels/steel_gemm_splitk)
make_jit_source(
steel/conv/conv


@@ -33,6 +33,7 @@ const char* gemm();
const char* steel_gemm_fused();
const char* steel_gemm_masked();
const char* steel_gemm_splitk();
+const char* steel_gemm_gather();
const char* conv();
const char* steel_conv();
const char* steel_conv_general();


@@ -584,6 +584,44 @@ MTL::ComputePipelineState* get_steel_gemm_masked_kernel(
return d.get_kernel(kernel_name, lib);
}
MTL::ComputePipelineState* get_steel_gemm_gather_kernel(
metal::Device& d,
const std::string& kernel_name,
const std::string& hash_name,
const metal::MTLFCList& func_consts,
const array& out,
bool transpose_a,
bool transpose_b,
int bm,
int bn,
int bk,
int wm,
int wn,
bool rhs) {
const auto& lib_name = kernel_name;
auto lib = d.get_library(lib_name, [&]() {
std::string kernel_source;
concatenate(
kernel_source,
metal::utils(),
metal::gemm(),
metal::steel_gemm_gather(),
get_template_definition(
lib_name,
rhs ? "gather_mm_rhs" : "gather_mm",
get_type_string(out.dtype()),
bm,
bn,
bk,
wm,
wn,
transpose_a,
transpose_b));
return kernel_source;
});
return d.get_kernel(kernel_name, lib, hash_name, func_consts);
}
MTL::ComputePipelineState* get_gemv_masked_kernel(
metal::Device& d,
const std::string& kernel_name,
@@ -714,4 +752,43 @@ MTL::ComputePipelineState* get_quantized_kernel(
return d.get_kernel(kernel_name, lib);
}
MTL::ComputePipelineState* get_gather_qmm_kernel(
metal::Device& d,
const std::string& kernel_name,
const std::string& hash_name,
const metal::MTLFCList& func_consts,
const array& x,
int group_size,
int bits,
int bm,
int bn,
int bk,
int wm,
int wn,
bool transpose) {
const auto& lib_name = kernel_name;
auto lib = d.get_library(lib_name, [&]() {
std::string kernel_source;
concatenate(
kernel_source,
metal::utils(),
metal::gemm(),
metal::quantized(),
get_template_definition(
lib_name,
"gather_qmm_rhs",
get_type_string(x.dtype()),
group_size,
bits,
bm,
bn,
bk,
wm,
wn,
transpose));
return kernel_source;
});
return d.get_kernel(kernel_name, lib, hash_name, func_consts);
}
} // namespace mlx::core


@@ -160,6 +160,21 @@ MTL::ComputePipelineState* get_steel_gemm_masked_kernel(
bool mn_aligned,
bool k_aligned);
MTL::ComputePipelineState* get_steel_gemm_gather_kernel(
metal::Device& d,
const std::string& kernel_name,
const std::string& hash_name,
const metal::MTLFCList& func_consts,
const array& out,
bool transpose_a,
bool transpose_b,
int bm,
int bn,
int bk,
int wm,
int wn,
bool rhs);
MTL::ComputePipelineState* get_steel_conv_kernel(
metal::Device& d,
const std::string& kernel_name,
@@ -209,6 +224,21 @@ MTL::ComputePipelineState* get_quantized_kernel(
const std::string& kernel_name,
const std::string& template_def);
MTL::ComputePipelineState* get_gather_qmm_kernel(
metal::Device& d,
const std::string& kernel_name,
const std::string& hash_name,
const metal::MTLFCList& func_consts,
const array& x,
int group_size,
int bits,
int bm,
int bn,
int bk,
int wm,
int wn,
bool transpose);
// Create a GPU kernel template definition for JIT compilation
template <typename... Args>
std::string


@@ -69,6 +69,7 @@ set(STEEL_HEADERS
steel/gemm/loader.h
steel/gemm/transforms.h
steel/gemm/kernels/steel_gemm_fused.h
+steel/gemm/kernels/steel_gemm_gather.h
steel/gemm/kernels/steel_gemm_masked.h
steel/gemm/kernels/steel_gemm_splitk.h
steel/utils/type_traits.h
@@ -116,6 +117,7 @@ if(NOT MLX_METAL_JIT)
build_kernel(steel/conv/kernels/steel_conv ${STEEL_HEADERS})
build_kernel(steel/conv/kernels/steel_conv_general ${STEEL_HEADERS})
build_kernel(steel/gemm/kernels/steel_gemm_fused ${STEEL_HEADERS})
+build_kernel(steel/gemm/kernels/steel_gemm_gather ${STEEL_HEADERS})
build_kernel(steel/gemm/kernels/steel_gemm_masked ${STEEL_HEADERS})
build_kernel(steel/gemm/kernels/steel_gemm_splitk ${STEEL_HEADERS})
build_kernel(gemv_masked steel/utils.h)


@@ -104,10 +104,22 @@ constexpr bool operator==(complex64_t a, complex64_t b) {
constexpr complex64_t operator+(complex64_t a, complex64_t b) {
return {a.real + b.real, a.imag + b.imag};
}
+constexpr complex64_t operator+(float a, complex64_t b) {
+return {a + b.real, b.imag};
+}
+constexpr complex64_t operator+(complex64_t a, float b) {
+return {a.real + b, a.imag};
+}
constexpr complex64_t operator-(complex64_t a, complex64_t b) {
return {a.real - b.real, a.imag - b.imag};
}
+constexpr complex64_t operator-(float a, complex64_t b) {
+return {a - b.real, -b.imag};
+}
+constexpr complex64_t operator-(complex64_t a, float b) {
+return {a.real - b, a.imag};
+}
constexpr complex64_t operator*(complex64_t a, complex64_t b) {
return {a.real * b.real - a.imag * b.imag, a.real * b.imag + a.imag * b.real};
@@ -120,6 +132,13 @@ constexpr complex64_t operator/(complex64_t a, complex64_t b) {
return {x / denom, y / denom};
}
+constexpr complex64_t operator/(float a, complex64_t b) {
+auto denom = b.real * b.real + b.imag * b.imag;
+auto x = a * b.real;
+auto y = -a * b.imag;
+return {x / denom, y / denom};
+}
constexpr complex64_t operator%(complex64_t a, complex64_t b) {
auto real = a.real - (b.real * static_cast<int64_t>(a.real / b.real));
auto imag = a.imag - (b.imag * static_cast<int64_t>(a.imag / b.imag));
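A quick numerical check of the new float/complex division overload, using Python's built-in complex type as the reference (illustration only):

import cmath

a, b = 2.0, complex(3.0, -4.0)
denom = b.real * b.real + b.imag * b.imag            # 25.0
manual = complex(a * b.real / denom, -a * b.imag / denom)
assert cmath.isclose(manual, a / b)                  # both give 0.24+0.32j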


@@ -3,6 +3,10 @@
#include <metal_simdgroup>
#include <metal_stdlib>
+constant bool align_M [[function_constant(200)]];
+constant bool align_N [[function_constant(201)]];
+constant bool align_K [[function_constant(202)]];
using namespace metal;
#define MLX_MTL_CONST static constant constexpr const
@@ -1686,26 +1690,26 @@ template <
}
template <typename T, int group_size, int bits>
-[[kernel]] void bs_qmv_fast(
+[[kernel]] void gather_qmv_fast(
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]],
const device T* x [[buffer(3)]],
-device T* y [[buffer(4)]],
-const constant int& in_vec_size [[buffer(5)]],
-const constant int& out_vec_size [[buffer(6)]],
-const constant int& x_batch_ndims [[buffer(7)]],
-const constant int* x_shape [[buffer(8)]],
-const constant int64_t* x_strides [[buffer(9)]],
-const constant int& w_batch_ndims [[buffer(10)]],
-const constant int* w_shape [[buffer(11)]],
-const constant int64_t* w_strides [[buffer(12)]],
-const constant int64_t* s_strides [[buffer(13)]],
-const constant int64_t* b_strides [[buffer(14)]],
-const constant int& batch_ndims [[buffer(15)]],
-const constant int* batch_shape [[buffer(16)]],
-const device uint32_t* lhs_indices [[buffer(17)]],
-const device uint32_t* rhs_indices [[buffer(18)]],
+const device uint32_t* lhs_indices [[buffer(4)]],
+const device uint32_t* rhs_indices [[buffer(5)]],
+device T* y [[buffer(6)]],
+const constant int& in_vec_size [[buffer(7)]],
+const constant int& out_vec_size [[buffer(8)]],
+const constant int& x_batch_ndims [[buffer(9)]],
+const constant int* x_shape [[buffer(10)]],
+const constant int64_t* x_strides [[buffer(11)]],
+const constant int& w_batch_ndims [[buffer(12)]],
+const constant int* w_shape [[buffer(13)]],
+const constant int64_t* w_strides [[buffer(14)]],
+const constant int64_t* s_strides [[buffer(15)]],
+const constant int64_t* b_strides [[buffer(16)]],
+const constant int& batch_ndims [[buffer(17)]],
+const constant int* batch_shape [[buffer(18)]],
const constant int64_t* lhs_strides [[buffer(19)]],
const constant int64_t* rhs_strides [[buffer(20)]],
uint3 tid [[threadgroup_position_in_grid]],
@@ -1748,26 +1752,26 @@ template <typename T, int group_size, int bits>
}
template <typename T, int group_size, int bits>
-[[kernel]] void bs_qmv(
+[[kernel]] void gather_qmv(
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]],
const device T* x [[buffer(3)]],
-device T* y [[buffer(4)]],
-const constant int& in_vec_size [[buffer(5)]],
-const constant int& out_vec_size [[buffer(6)]],
-const constant int& x_batch_ndims [[buffer(7)]],
-const constant int* x_shape [[buffer(8)]],
-const constant int64_t* x_strides [[buffer(9)]],
-const constant int& w_batch_ndims [[buffer(10)]],
-const constant int* w_shape [[buffer(11)]],
-const constant int64_t* w_strides [[buffer(12)]],
-const constant int64_t* s_strides [[buffer(13)]],
-const constant int64_t* b_strides [[buffer(14)]],
-const constant int& batch_ndims [[buffer(15)]],
-const constant int* batch_shape [[buffer(16)]],
-const device uint32_t* lhs_indices [[buffer(17)]],
-const device uint32_t* rhs_indices [[buffer(18)]],
+const device uint32_t* lhs_indices [[buffer(4)]],
+const device uint32_t* rhs_indices [[buffer(5)]],
+device T* y [[buffer(6)]],
+const constant int& in_vec_size [[buffer(7)]],
+const constant int& out_vec_size [[buffer(8)]],
+const constant int& x_batch_ndims [[buffer(9)]],
+const constant int* x_shape [[buffer(10)]],
+const constant int64_t* x_strides [[buffer(11)]],
+const constant int& w_batch_ndims [[buffer(12)]],
+const constant int* w_shape [[buffer(13)]],
+const constant int64_t* w_strides [[buffer(14)]],
+const constant int64_t* s_strides [[buffer(15)]],
+const constant int64_t* b_strides [[buffer(16)]],
+const constant int& batch_ndims [[buffer(17)]],
+const constant int* batch_shape [[buffer(18)]],
const constant int64_t* lhs_strides [[buffer(19)]],
const constant int64_t* rhs_strides [[buffer(20)]],
uint3 tid [[threadgroup_position_in_grid]],
@@ -1810,26 +1814,26 @@ template <typename T, int group_size, int bits>
}
template <typename T, int group_size, int bits>
-[[kernel]] void bs_qvm(
+[[kernel]] void gather_qvm(
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]],
const device T* x [[buffer(3)]],
-device T* y [[buffer(4)]],
-const constant int& in_vec_size [[buffer(5)]],
-const constant int& out_vec_size [[buffer(6)]],
-const constant int& x_batch_ndims [[buffer(7)]],
-const constant int* x_shape [[buffer(8)]],
-const constant int64_t* x_strides [[buffer(9)]],
-const constant int& w_batch_ndims [[buffer(10)]],
-const constant int* w_shape [[buffer(11)]],
-const constant int64_t* w_strides [[buffer(12)]],
-const constant int64_t* s_strides [[buffer(13)]],
-const constant int64_t* b_strides [[buffer(14)]],
-const constant int& batch_ndims [[buffer(15)]],
-const constant int* batch_shape [[buffer(16)]],
-const device uint32_t* lhs_indices [[buffer(17)]],
-const device uint32_t* rhs_indices [[buffer(18)]],
+const device uint32_t* lhs_indices [[buffer(4)]],
+const device uint32_t* rhs_indices [[buffer(5)]],
+device T* y [[buffer(6)]],
+const constant int& in_vec_size [[buffer(7)]],
+const constant int& out_vec_size [[buffer(8)]],
+const constant int& x_batch_ndims [[buffer(9)]],
+const constant int* x_shape [[buffer(10)]],
+const constant int64_t* x_strides [[buffer(11)]],
+const constant int& w_batch_ndims [[buffer(12)]],
+const constant int* w_shape [[buffer(13)]],
+const constant int64_t* w_strides [[buffer(14)]],
+const constant int64_t* s_strides [[buffer(15)]],
+const constant int64_t* b_strides [[buffer(16)]],
+const constant int& batch_ndims [[buffer(17)]],
+const constant int* batch_shape [[buffer(18)]],
const constant int64_t* lhs_strides [[buffer(19)]],
const constant int64_t* rhs_strides [[buffer(20)]],
uint3 tid [[threadgroup_position_in_grid]],
@@ -1879,27 +1883,27 @@ template <
const int BM = 32,
const int BK = 32,
const int BN = 32>
-[[kernel]] void bs_qmm_t(
+[[kernel]] void gather_qmm_t(
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]],
const device T* x [[buffer(3)]],
-device T* y [[buffer(4)]],
-const constant int& K [[buffer(5)]],
-const constant int& N [[buffer(6)]],
-const constant int& M [[buffer(7)]],
-const constant int& x_batch_ndims [[buffer(8)]],
-const constant int* x_shape [[buffer(9)]],
-const constant int64_t* x_strides [[buffer(10)]],
-const constant int& w_batch_ndims [[buffer(11)]],
-const constant int* w_shape [[buffer(12)]],
-const constant int64_t* w_strides [[buffer(13)]],
-const constant int64_t* s_strides [[buffer(14)]],
-const constant int64_t* b_strides [[buffer(15)]],
-const constant int& batch_ndims [[buffer(16)]],
-const constant int* batch_shape [[buffer(17)]],
-const device uint32_t* lhs_indices [[buffer(18)]],
-const device uint32_t* rhs_indices [[buffer(19)]],
+const device uint32_t* lhs_indices [[buffer(4)]],
+const device uint32_t* rhs_indices [[buffer(5)]],
+device T* y [[buffer(6)]],
+const constant int& K [[buffer(7)]],
+const constant int& N [[buffer(8)]],
+const constant int& M [[buffer(9)]],
+const constant int& x_batch_ndims [[buffer(10)]],
+const constant int* x_shape [[buffer(11)]],
+const constant int64_t* x_strides [[buffer(12)]],
+const constant int& w_batch_ndims [[buffer(13)]],
+const constant int* w_shape [[buffer(14)]],
+const constant int64_t* w_strides [[buffer(15)]],
+const constant int64_t* s_strides [[buffer(16)]],
+const constant int64_t* b_strides [[buffer(17)]],
+const constant int& batch_ndims [[buffer(18)]],
+const constant int* batch_shape [[buffer(19)]],
const constant int64_t* lhs_strides [[buffer(20)]],
const constant int64_t* rhs_strides [[buffer(21)]],
uint3 tid [[threadgroup_position_in_grid]],
@@ -1946,27 +1950,27 @@ template <
const int BM = 32,
const int BK = 32,
const int BN = 32>
-[[kernel]] void bs_qmm_n(
+[[kernel]] void gather_qmm_n(
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]],
const device T* x [[buffer(3)]],
-device T* y [[buffer(4)]],
-const constant int& K [[buffer(5)]],
-const constant int& N [[buffer(6)]],
-const constant int& M [[buffer(7)]],
-const constant int& x_batch_ndims [[buffer(8)]],
-const constant int* x_shape [[buffer(9)]],
-const constant int64_t* x_strides [[buffer(10)]],
-const constant int& w_batch_ndims [[buffer(11)]],
-const constant int* w_shape [[buffer(12)]],
-const constant int64_t* w_strides [[buffer(13)]],
-const constant int64_t* s_strides [[buffer(14)]],
-const constant int64_t* b_strides [[buffer(15)]],
-const constant int& batch_ndims [[buffer(16)]],
-const constant int* batch_shape [[buffer(17)]],
-const device uint32_t* lhs_indices [[buffer(18)]],
-const device uint32_t* rhs_indices [[buffer(19)]],
+const device uint32_t* lhs_indices [[buffer(4)]],
+const device uint32_t* rhs_indices [[buffer(5)]],
+device T* y [[buffer(6)]],
+const constant int& K [[buffer(7)]],
+const constant int& N [[buffer(8)]],
+const constant int& M [[buffer(9)]],
+const constant int& x_batch_ndims [[buffer(10)]],
+const constant int* x_shape [[buffer(11)]],
+const constant int64_t* x_strides [[buffer(12)]],
+const constant int& w_batch_ndims [[buffer(13)]],
+const constant int* w_shape [[buffer(14)]],
+const constant int64_t* w_strides [[buffer(15)]],
+const constant int64_t* s_strides [[buffer(16)]],
+const constant int64_t* b_strides [[buffer(17)]],
+const constant int& batch_ndims [[buffer(18)]],
+const constant int* batch_shape [[buffer(19)]],
const constant int64_t* lhs_strides [[buffer(20)]],
const constant int64_t* rhs_strides [[buffer(21)]],
uint3 tid [[threadgroup_position_in_grid]],
@@ -2007,6 +2011,289 @@ template <
w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
}
template <typename T, typename mma_t, typename loader_a_t, typename loader_b_t>
METAL_FUNC void gemm_loop_aligned(
threadgroup T* As,
threadgroup T* Bs,
thread mma_t& mma_op,
thread loader_a_t& loader_a,
thread loader_b_t& loader_b,
const int k_iterations) {
for (int k = 0; k < k_iterations; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup memory
loader_a.load_unsafe();
loader_b.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
}
template <
bool rows_aligned,
bool cols_aligned,
bool transpose,
typename T,
typename mma_t,
typename loader_a_t,
typename loader_b_t>
METAL_FUNC void gemm_loop_unaligned(
threadgroup T* As,
threadgroup T* Bs,
thread mma_t& mma_op,
thread loader_a_t& loader_a,
thread loader_b_t& loader_b,
const int k_iterations,
const short tgp_bm,
const short tgp_bn,
const short tgp_bk) {
for (int k = 0; k < k_iterations; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup memory
if (rows_aligned) {
loader_a.load_unsafe();
} else {
loader_a.load_safe(short2(tgp_bk, tgp_bm));
}
if (cols_aligned) {
loader_b.load_unsafe();
} else {
loader_b.load_safe(
transpose ? short2(tgp_bk, tgp_bn) : short2(tgp_bn, tgp_bk));
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
}
template <typename T, typename mma_t, typename loader_a_t, typename loader_b_t>
METAL_FUNC void gemm_loop_finalize(
threadgroup T* As,
threadgroup T* Bs,
thread mma_t& mma_op,
thread loader_a_t& loader_a,
thread loader_b_t& loader_b,
const short2 tile_a,
const short2 tile_b) {
loader_a.load_safe(tile_a);
loader_b.load_safe(tile_b);
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_op.mma(As, Bs);
}
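// gather_qmm_rhs below assumes the rows of x arrive grouped (sorted) by expert
// index: each BM-row block is split into runs of equal `indices` values and one
// quantized GEMM is accumulated per run against that expert's w/scales/biases,
// storing into the matching row slice of y.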
template <
typename T,
int group_size,
int bits,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose>
[[kernel]] void gather_qmm_rhs(
const device T* x [[buffer(0)]],
const device uint32_t* w [[buffer(1)]],
const device T* scales [[buffer(2)]],
const device T* biases [[buffer(3)]],
const device uint32_t* indices [[buffer(4)]],
device T* y [[buffer(5)]],
const constant int& M [[buffer(6)]],
const constant int& N [[buffer(7)]],
const constant int& K [[buffer(8)]],
uint3 tid [[threadgroup_position_in_grid]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]]) {
constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
constexpr int BK_padded = (BK + 16 / sizeof(T));
constexpr int BN_padded = (BN + 16 / sizeof(T));
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3;
using mma_t = mlx::steel::BlockMMA<
T,
T,
BM,
BN,
BK,
WM,
WN,
false,
transpose,
BK_padded,
transpose ? BK_padded : BN_padded>;
using loader_x_t =
mlx::steel::BlockLoader<T, BM, BK, BK_padded, 1, WM * WN * SIMD_SIZE>;
using loader_w_t = QuantizedBlockLoader<
T,
transpose ? BN : BK,
transpose ? BK : BN,
transpose ? BK_padded : BN_padded,
transpose,
WM * WN * SIMD_SIZE,
group_size,
bits>;
threadgroup T Xs[BM * BK_padded];
threadgroup T Ws[transpose ? BN * BK_padded : BK * BN_padded];
// Compute the block
const int K_w = K * bytes_per_pack / pack_factor;
const int K_g = K / group_size;
const int N_w = N * bytes_per_pack / pack_factor;
const int N_g = N / group_size;
const int K_it = K / BK;
const size_t stride_w = transpose ? N * K_w : K * N_w;
const size_t stride_s = transpose ? N * K_g : K * N_g;
const int y_row = tid.y * BM;
const int y_col = tid.x * BN;
const size_t y_row_long = size_t(y_row);
const size_t y_col_long = size_t(y_col);
// Prepare threadgroup bounds
const short tgp_bm = align_M ? BM : short(min(BM, M - y_row));
const short tgp_bn = align_N ? BN : short(min(BN, N - y_col));
// Calculate the final tiles in the case that K is not aligned
const int k_remain = K - K_it * BK;
const short2 tile_x = short2(k_remain, tgp_bm);
const short2 tile_w =
transpose ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);
// Move x and output to the correct block
auto wl = (const device uint8_t*)w;
x += y_row_long * K;
y += y_row_long * N + y_col_long;
wl += transpose ? y_col_long * K_w : y_col * bytes_per_pack / pack_factor;
scales += transpose ? y_col_long * K_g : y_col / group_size;
biases += transpose ? y_col_long * K_g : y_col / group_size;
// Do as many matmuls as necessary
uint32_t index;
short offset;
uint32_t index_next = indices[y_row];
short offset_next = 0;
int n = 0;
while (n < tgp_bm) {
n++;
offset = offset_next;
index = index_next;
offset_next = tgp_bm;
for (; n < tgp_bm; n++) {
if (indices[y_row + n] != index) {
offset_next = n;
index_next = indices[y_row + n];
break;
}
}
threadgroup_barrier(mem_flags::mem_none);
// Prepare threadgroup mma operation
thread mma_t mma_op(simd_group_id, simd_lane_id);
// Prepare threadgroup loading operations
thread loader_x_t loader_x(x, K, Xs, simd_group_id, simd_lane_id);
thread loader_w_t loader_w(
wl + index * stride_w,
scales + index * stride_s,
biases + index * stride_s,
transpose ? K : N,
Ws,
simd_group_id,
simd_lane_id);
// Matrices are all aligned check nothing
if (align_M && align_N) {
gemm_loop_aligned(Xs, Ws, mma_op, loader_x, loader_w, K_it);
if (!align_K) {
threadgroup_barrier(mem_flags::mem_threadgroup);
gemm_loop_finalize(Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w);
}
// Store results to device memory
if (offset_next - offset == BM) {
mma_op.store_result(y, N);
} else {
mma_op.store_result_slice(
y, N, short2(0, offset), short2(BN, offset_next));
}
} else {
// Tile aligned so check outside of the hot loop
if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) {
gemm_loop_aligned(Xs, Ws, mma_op, loader_x, loader_w, K_it);
if (!align_K) {
threadgroup_barrier(mem_flags::mem_threadgroup);
gemm_loop_finalize(
Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w);
}
// Store results to device memory
if (offset_next - offset == BM) {
mma_op.store_result(y, N);
} else {
mma_op.store_result_slice(
y, N, short2(0, offset), short2(BN, offset_next));
}
}
// Tile partially aligned check rows
else if (align_N || tgp_bn == BN) {
gemm_loop_unaligned<false, true, transpose>(
Xs, Ws, mma_op, loader_x, loader_w, K_it, tgp_bm, tgp_bn, BK);
if (!align_K) {
threadgroup_barrier(mem_flags::mem_threadgroup);
gemm_loop_finalize(
Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w);
}
mma_op.store_result_slice(
y, N, short2(0, offset), short2(BN, offset_next));
}
// Tile partially aligned check cols
else if (align_M || tgp_bm == BM) {
gemm_loop_unaligned<true, false, transpose>(
Xs, Ws, mma_op, loader_x, loader_w, K_it, tgp_bm, tgp_bn, BK);
if (!align_K) {
threadgroup_barrier(mem_flags::mem_threadgroup);
gemm_loop_finalize(
Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w);
}
mma_op.store_result_slice(
y, N, short2(0, offset), short2(tgp_bn, offset_next));
}
// Nothing aligned so check both rows and cols
else {
gemm_loop_unaligned<false, false, transpose>(
Xs, Ws, mma_op, loader_x, loader_w, K_it, tgp_bm, tgp_bn, BK);
if (!align_K) {
threadgroup_barrier(mem_flags::mem_threadgroup);
gemm_loop_finalize(
Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w);
}
mma_op.store_result_slice(
y, N, short2(0, offset), short2(tgp_bn, offset_next));
}
}
}
}
template <typename T, const int group_size, const int bits> template <typename T, const int group_size, const int bits>
[[kernel]] void affine_quantize( [[kernel]] void affine_quantize(
const device T* w [[buffer(0)]], const device T* w [[buffer(0)]],


@@ -60,6 +60,20 @@
bits, \
split_k)
#define instantiate_gather_qmm_rhs(func, name, type, group_size, bits, bm, bn, bk, wm, wn, transpose) \
instantiate_kernel( \
#name "_" #type "_gs_" #group_size "_b_" #bits "_bm_" #bm "_bn_" #bn "_bk_" #bk "_wm_" #wm "_wn_" #wn, \
func, \
type, \
group_size, \
bits, \
bm, \
bn, \
bk, \
wm, \
wn, \
transpose)
#define instantiate_quantized_batched_wrap(name, type, group_size, bits) \
instantiate_quantized_batched(name, type, group_size, bits, 1) \
instantiate_quantized_batched(name, type, group_size, bits, 0)
@@ -73,14 +87,14 @@
#define instantiate_quantized_all_single(type, group_size, bits) \
instantiate_quantized(affine_quantize, type, group_size, bits) \
instantiate_quantized(affine_dequantize, type, group_size, bits) \
-instantiate_quantized(bs_qmv_fast, type, group_size, bits) \
-instantiate_quantized(bs_qmv, type, group_size, bits) \
-instantiate_quantized(bs_qvm, type, group_size, bits) \
-instantiate_quantized(bs_qmm_n, type, group_size, bits)
+instantiate_quantized(gather_qmv_fast, type, group_size, bits) \
+instantiate_quantized(gather_qmv, type, group_size, bits) \
+instantiate_quantized(gather_qvm, type, group_size, bits) \
+instantiate_quantized(gather_qmm_n, type, group_size, bits)
#define instantiate_quantized_all_aligned(type, group_size, bits) \
-instantiate_quantized_aligned(bs_qmm_t, type, group_size, bits, true) \
-instantiate_quantized_aligned(bs_qmm_t, type, group_size, bits, false) \
+instantiate_quantized_aligned(gather_qmm_t, type, group_size, bits, true) \
+instantiate_quantized_aligned(gather_qmm_t, type, group_size, bits, false) \
instantiate_quantized_aligned_batched(qmm_t, type, group_size, bits, true, 1) \
instantiate_quantized_aligned_batched(qmm_t, type, group_size, bits, true, 0) \
instantiate_quantized_aligned_batched(qmm_t, type, group_size, bits, false, 1) \
@@ -96,12 +110,17 @@
instantiate_quantized_split_k(qvm_split_k, type, group_size, bits, 8) \
instantiate_quantized_split_k(qvm_split_k, type, group_size, bits, 32)
#define instantiate_quantized_all_rhs(type, group_size, bits) \
instantiate_gather_qmm_rhs(gather_qmm_rhs, gather_qmm_rhs_nt, type, group_size, bits, 16, 32, 32, 1, 2, true) \
instantiate_gather_qmm_rhs(gather_qmm_rhs, gather_qmm_rhs_nn, type, group_size, bits, 16, 32, 32, 1, 2, false)
#define instantiate_quantized_funcs(type, group_size, bits) \
instantiate_quantized_all_single(type, group_size, bits) \
instantiate_quantized_all_batched(type, group_size, bits) \
instantiate_quantized_all_aligned(type, group_size, bits) \
instantiate_quantized_all_quad(type, group_size, bits) \
-instantiate_quantized_all_splitk(type, group_size, bits)
+instantiate_quantized_all_splitk(type, group_size, bits) \
+instantiate_quantized_all_rhs(type, group_size, bits)
#define instantiate_quantized_types(group_size, bits) \
instantiate_quantized_funcs(float, group_size, bits) \


@@ -2,6 +2,8 @@
#pragma once
+#include "mlx/backend/metal/kernels/binary_ops.h"
#define DEFINE_SIMD_SCAN() \
template <typename T, metal::enable_if_t<sizeof(T) < 8, bool> = true> \
T simd_scan(T val) { \
@@ -139,6 +141,29 @@ struct CumMin {
}
};
template <typename U>
struct CumLogaddexp {
static constexpr constant U init = Limits<U>::min;
template <typename T>
U operator()(U a, T b) {
return LogAddExp{}(a, static_cast<U>(b));
}
U simd_scan(U x) {
for (int i = 1; i <= 16; i *= 2) {
U other = simd_shuffle_and_fill_up(x, init, i);
x = LogAddExp{}(x, other);
}
return x;
}
U simd_exclusive_scan(U x) {
x = simd_scan(x);
return simd_shuffle_and_fill_up(x, init, 1);
}
};
template <typename T, typename U, int N_READS, bool reverse>
inline void load_unsafe(U values[N_READS], const device T* input) {
if (reverse) {


@@ -101,4 +101,7 @@ instantiate_scan_helper(min_int64_int64, int64_t, int64_t, CumMi
instantiate_scan_helper(min_float16_float16, half, half, CumMin, 4)
instantiate_scan_helper(min_float32_float32, float, float, CumMin, 4)
instantiate_scan_helper(min_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumMin, 4)
-instantiate_scan_helper(min_complex64_complex64, complex64_t, complex64_t, CumMin, 2) // clang-format on
+instantiate_scan_helper(min_complex64_complex64, complex64_t, complex64_t, CumMin, 2)
+instantiate_scan_helper(logaddexp_float16_float16, half, half, CumLogaddexp, 4)
+instantiate_scan_helper(logaddexp_float32_float32, float, float, CumLogaddexp, 4)
+instantiate_scan_helper(logaddexp_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumLogaddexp, 4) // clang-format on


@@ -15,10 +15,6 @@ constant bool align_M [[function_constant(200)]];
constant bool align_N [[function_constant(201)]];
constant bool align_K [[function_constant(202)]];
-constant bool do_gather [[function_constant(300)]];
-constant bool gather_bias = do_gather && use_out_source;
// clang-format off
template <
typename T,
@@ -39,12 +35,6 @@ template <
const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]],
const constant int* batch_shape [[buffer(6)]],
const constant int64_t* batch_strides [[buffer(7)]],
-const constant uint32_t* lhs_indices [[buffer(10), function_constant(do_gather)]],
-const constant uint32_t* rhs_indices [[buffer(11), function_constant(do_gather)]],
-const constant uint32_t* C_indices [[buffer(12), function_constant(gather_bias)]],
-const constant int* operand_shape [[buffer(13), function_constant(do_gather)]],
-const constant int64_t* operand_strides [[buffer(14), function_constant(do_gather)]],
-const constant packed_int3& operand_batch_ndim [[buffer(15), function_constant(do_gather)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
@@ -81,84 +71,26 @@ template <
}
// Adjust for batch
-// Handle gather
-if (do_gather) {
-// Read indices
-uint32_t indx_A, indx_B, indx_C;
-if (has_batch) {
-const constant auto* indx_A_bstrides = batch_strides;
-const constant auto* indx_B_bstrides = batch_strides + params->batch_ndim;
-ulong2 indx_offsets = elem_to_loc_broadcast(
-tid.z,
-batch_shape,
-indx_A_bstrides,
-indx_B_bstrides,
-params->batch_ndim);
-indx_A = lhs_indices[indx_offsets.x];
-indx_B = rhs_indices[indx_offsets.y];
-if (use_out_source) {
-const constant auto* indx_C_bstrides =
-indx_B_bstrides + params->batch_ndim;
-auto indx_offset_C = elem_to_loc(
-tid.z, batch_shape, indx_C_bstrides, params->batch_ndim);
-indx_C = C_indices[indx_offset_C];
-}
-} else {
-indx_A = lhs_indices[params->batch_stride_a * tid.z];
-indx_B = rhs_indices[params->batch_stride_b * tid.z];
-if (use_out_source) {
-indx_C = C_indices[addmm_params->batch_stride_c * tid.z];
-}
-}
-// Translate indices to offsets
-int batch_ndim_A = operand_batch_ndim.x;
-const constant int* batch_shape_A = operand_shape;
-const constant auto* batch_strides_A = operand_strides;
-A += elem_to_loc(indx_A, batch_shape_A, batch_strides_A, batch_ndim_A);
-int batch_ndim_B = operand_batch_ndim.y;
-const constant int* batch_shape_B = batch_shape_A + batch_ndim_A;
-const constant auto* batch_strides_B = batch_strides_A + batch_ndim_A;
-B += elem_to_loc(indx_B, batch_shape_B, batch_strides_B, batch_ndim_B);
-if (use_out_source) {
-int batch_ndim_C = operand_batch_ndim.z;
-const constant int* batch_shape_C = batch_shape_B + batch_ndim_B;
-const constant auto* batch_strides_C = batch_strides_B + batch_ndim_B;
-C += elem_to_loc(indx_C, batch_shape_C, batch_strides_C, batch_ndim_C);
-}
-}
-// Handle regular batch
-else {
-if (has_batch) {
-const constant auto* A_bstrides = batch_strides;
-const constant auto* B_bstrides = batch_strides + params->batch_ndim;
-ulong2 batch_offsets = elem_to_loc_broadcast(
-tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim);
-A += batch_offsets.x;
-B += batch_offsets.y;
-if (use_out_source) {
-const constant auto* C_bstrides = B_bstrides + params->batch_ndim;
-C += elem_to_loc(tid.z, batch_shape, C_bstrides, params->batch_ndim);
-}
-} else {
-A += params->batch_stride_a * tid.z;
-B += params->batch_stride_b * tid.z;
-if (use_out_source) {
-C += addmm_params->batch_stride_c * tid.z;
-}
-}
-}
+if (has_batch) {
+const constant auto* A_bstrides = batch_strides;
+const constant auto* B_bstrides = batch_strides + params->batch_ndim;
+ulong2 batch_offsets = elem_to_loc_broadcast(
+tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim);
+A += batch_offsets.x;
+B += batch_offsets.y;
+if (use_out_source) {
+const constant auto* C_bstrides = B_bstrides + params->batch_ndim;
+C += elem_to_loc(tid.z, batch_shape, C_bstrides, params->batch_ndim);
+}
+} else {
+A += params->batch_stride_a * tid.z;
+B += params->batch_stride_b * tid.z;
+if (use_out_source) {
+C += addmm_params->batch_stride_c * tid.z;
+}
+}


@@ -0,0 +1,459 @@
// Copyright © 2024 Apple Inc.
using namespace mlx::steel;
constant bool has_batch [[function_constant(10)]];
constant bool align_M [[function_constant(200)]];
constant bool align_N [[function_constant(201)]];
constant bool align_K [[function_constant(202)]];
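// gather_mm_rhs uses the same run-splitting strategy as the quantized
// gather_qmm_rhs kernel: rows of A are expected to be pre-sorted by
// rhs_indices, and each run of equal indices is multiplied against the
// corresponding expert matrix in B.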
template <
typename T,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_a,
bool transpose_b,
typename AccumType = float>
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void gather_mm_rhs(
const device T* A [[buffer(0)]],
const device T* B [[buffer(1)]],
const device uint32_t* rhs_indices [[buffer(2)]],
device T* C [[buffer(3)]],
const constant GEMMParams* params [[buffer(4)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]]) {
using gemm_kernel = GEMMKernel<
T,
T,
BM,
BN,
BK,
WM,
WN,
transpose_a,
transpose_b,
true,
true,
AccumType>;
using loader_a_t = typename gemm_kernel::loader_a_t;
using loader_b_t = typename gemm_kernel::loader_b_t;
using mma_t = typename gemm_kernel::mma_t;
if (params->tiles_n <= static_cast<int>(tid.x) ||
params->tiles_m <= static_cast<int>(tid.y)) {
return;
}
// Prepare threadgroup memory
threadgroup T As[gemm_kernel::tgp_mem_size_a];
threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
// Find the block in A, B, C
const int c_row = tid.y * BM;
const int c_col = tid.x * BN;
const size_t c_row_long = size_t(c_row);
const size_t c_col_long = size_t(c_col);
// Prepare threadgroup bounds
const short tgp_bm = align_M ? BM : short(min(BM, params->M - c_row));
const short tgp_bn = align_N ? BN : short(min(BN, params->N - c_col));
A += transpose_a ? c_row_long : c_row_long * params->lda;
B += transpose_b ? c_col_long * params->ldb : c_col_long;
C += c_row_long * params->ldd + c_col_long;
// Do as many matmuls as necessary
uint32_t index;
short offset;
uint32_t index_next = rhs_indices[c_row];
short offset_next = 0;
int n = 0;
while (n < tgp_bm) {
n++;
offset = offset_next;
index = index_next;
offset_next = tgp_bm;
for (; n < tgp_bm; n++) {
if (rhs_indices[c_row + n] != index) {
offset_next = n;
index_next = rhs_indices[c_row + n];
break;
}
}
threadgroup_barrier(mem_flags::mem_none);
// Prepare threadgroup mma operation
thread mma_t mma_op(simd_group_id, simd_lane_id);
// Prepare threadgroup loading operations
thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
thread loader_b_t loader_b(
B + index * params->batch_stride_b,
params->ldb,
Bs,
simd_group_id,
simd_lane_id);
// Prepare iterations
const int gemm_k_iterations = params->gemm_k_iterations_aligned;
// Do unaligned K iterations first
if (!align_K) {
const int k_last = params->gemm_k_iterations_aligned * BK;
const int k_remain = params->K - k_last;
const size_t k_jump_a =
transpose_a ? params->lda * size_t(k_last) : size_t(k_last);
const size_t k_jump_b =
transpose_b ? size_t(k_last) : params->ldb * size_t(k_last);
// Move loader source ahead to end
loader_a.src += k_jump_a;
loader_b.src += k_jump_b;
// Load tile
const short2 tile_dims_A =
transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
const short2 tile_dims_B =
transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);
loader_a.load_safe(tile_dims_A);
loader_b.load_safe(tile_dims_B);
threadgroup_barrier(mem_flags::mem_threadgroup);
// Do matmul
mma_op.mma(As, Bs);
// Reset source back to start
loader_a.src -= k_jump_a;
loader_b.src -= k_jump_b;
}
// Matrix level aligned never check
if (align_M && align_N) {
for (int k = 0; k < gemm_k_iterations; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup
loader_a.load_unsafe();
loader_b.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
// Store results to device memory
if (offset_next - offset == BM) {
mma_op.store_result(C, params->ldd);
} else {
mma_op.store_result_slice(
C, params->ldd, short2(0, offset), short2(BN, offset_next));
}
} else {
const short lbk = 0;
// Tile aligned don't check
if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
lbk,
LoopAlignment<true, true, true>{});
if (offset_next - offset == BM) {
mma_op.store_result(C, params->ldd);
} else {
mma_op.store_result_slice(
C, params->ldd, short2(0, offset), short2(BN, offset_next));
}
}
// Tile partially aligned check rows
else if (align_N || tgp_bn == BN) {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
lbk,
LoopAlignment<false, true, true>{});
mma_op.store_result_slice(
C, params->ldd, short2(0, offset), short2(BN, offset_next));
}
// Tile partially aligned check cols
else if (align_M || tgp_bm == BM) {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
lbk,
LoopAlignment<true, false, true>{});
mma_op.store_result_slice(
C, params->ldd, short2(0, offset), short2(tgp_bn, offset_next));
}
// Nothing aligned so check both rows and cols
else {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
lbk,
LoopAlignment<false, false, true>{});
mma_op.store_result_slice(
C, params->ldd, short2(0, offset), short2(tgp_bn, offset_next));
}
}
}
}
template <
typename T,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_a,
bool transpose_b,
typename AccumType = float>
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void gather_mm(
const device T* A [[buffer(0)]],
const device T* B [[buffer(1)]],
const device uint32_t* lhs_indices [[buffer(2)]],
const device uint32_t* rhs_indices [[buffer(3)]],
device T* C [[buffer(4)]],
const constant GEMMParams* params [[buffer(5)]],
const constant int* indices_shape [[buffer(6)]],
const constant int64_t* lhs_strides [[buffer(7)]],
const constant int64_t* rhs_strides [[buffer(8)]],
const constant int& batch_ndim_a [[buffer(9)]],
const constant int* batch_shape_a [[buffer(10)]],
const constant int64_t* batch_strides_a [[buffer(11)]],
const constant int& batch_ndim_b [[buffer(12)]],
const constant int* batch_shape_b [[buffer(13)]],
const constant int64_t* batch_strides_b [[buffer(14)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]]) {
using gemm_kernel = GEMMKernel<
T,
T,
BM,
BN,
BK,
WM,
WN,
transpose_a,
transpose_b,
true,
true,
AccumType>;
using loader_a_t = typename gemm_kernel::loader_a_t;
using loader_b_t = typename gemm_kernel::loader_b_t;
using mma_t = typename gemm_kernel::mma_t;
if (params->tiles_n <= static_cast<int>(tid.x) ||
params->tiles_m <= static_cast<int>(tid.y)) {
return;
}
// Move A and B to the locations pointed by lhs_indices and rhs_indices.
uint32_t indx_A, indx_B;
if (has_batch) {
ulong2 indices_offsets = elem_to_loc_broadcast(
tid.z, indices_shape, lhs_strides, rhs_strides, params->batch_ndim);
indx_A = lhs_indices[indices_offsets.x];
indx_B = rhs_indices[indices_offsets.y];
} else {
indx_A = lhs_indices[params->batch_stride_a * tid.z];
indx_B = rhs_indices[params->batch_stride_b * tid.z];
}
A += elem_to_loc(indx_A, batch_shape_a, batch_strides_a, batch_ndim_a);
B += elem_to_loc(indx_B, batch_shape_b, batch_strides_b, batch_ndim_b);
C += params->batch_stride_d * tid.z;
// Prepare threadgroup memory
threadgroup T As[gemm_kernel::tgp_mem_size_a];
threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
// Just make sure everybody's finished with the indexing math above.
threadgroup_barrier(mem_flags::mem_none);
// Find block in A, B, C
const int c_row = tid.y * BM;
const int c_col = tid.x * BN;
const size_t c_row_long = size_t(c_row);
const size_t c_col_long = size_t(c_col);
A += transpose_a ? c_row_long : c_row_long * params->lda;
B += transpose_b ? c_col_long * params->ldb : c_col_long;
C += c_row_long * params->ldd + c_col_long;
// Prepare threadgroup mma operation
thread mma_t mma_op(simd_group_id, simd_lane_id);
// Prepare threadgroup loading operations
thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);
// Prepare threadgroup bounds
const short tgp_bm = align_M ? BM : short(min(BM, params->M - c_row));
const short tgp_bn = align_N ? BN : short(min(BN, params->N - c_col));
// Prepare iterations
int gemm_k_iterations = params->gemm_k_iterations_aligned;
// Do unaligned K iterations first
if (!align_K) {
const int k_last = params->gemm_k_iterations_aligned * BK;
const int k_remain = params->K - k_last;
const size_t k_jump_a =
transpose_a ? params->lda * size_t(k_last) : size_t(k_last);
const size_t k_jump_b =
transpose_b ? size_t(k_last) : params->ldb * size_t(k_last);
// Move loader source ahead to end
loader_a.src += k_jump_a;
loader_b.src += k_jump_b;
// Load tile
const short2 tile_dims_A =
transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
const short2 tile_dims_B =
transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);
loader_a.load_safe(tile_dims_A);
loader_b.load_safe(tile_dims_B);
threadgroup_barrier(mem_flags::mem_threadgroup);
// Do matmul
mma_op.mma(As, Bs);
// Reset source back to start
loader_a.src -= k_jump_a;
loader_b.src -= k_jump_b;
}
// Matrix level aligned never check
if (align_M && align_N) {
for (int k = 0; k < gemm_k_iterations; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup
loader_a.load_unsafe();
loader_b.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
// Store results to device memory
mma_op.store_result(C, params->ldd);
} else {
const short lbk = 0;
// Tile fully covered: don't check bounds
if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
lbk,
LoopAlignment<true, true, true>{});
mma_op.store_result(C, params->ldd);
}
// Tile partially aligned: check rows
else if (align_N || tgp_bn == BN) {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
lbk,
LoopAlignment<false, true, true>{});
mma_op.store_result_safe(C, params->ldd, short2(tgp_bn, tgp_bm));
}
// Tile partially aligned: check cols
else if (align_M || tgp_bm == BM) {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
lbk,
LoopAlignment<true, false, true>{});
mma_op.store_result_safe(C, params->ldd, short2(tgp_bn, tgp_bm));
}
// Nothing aligned so check both rows and cols
else {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
lbk,
LoopAlignment<false, false, true>{});
mma_op.store_result_safe(C, params->ldd, short2(tgp_bn, tgp_bm));
}
}
}

View File

@ -0,0 +1,59 @@
// Copyright © 2024 Apple Inc.
// clang-format off
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
#include "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h"
#define instantiate_gather_mm_rhs(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_kernel( \
"steel_gather_mm_rhs_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn \
"_bk" #bk "_wm" #wm "_wn" #wn, \
gather_mm_rhs, \
itype, \
bm, \
bn, \
bk, \
wm, \
wn, \
trans_a, \
trans_b, \
float)
#define instantiate_gather_mm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_kernel( \
"steel_gather_mm_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn \
"_bk" #bk "_wm" #wm "_wn" #wn, \
gather_mm, \
itype, \
bm, \
bn, \
bk, \
wm, \
wn, \
trans_a, \
trans_b, \
float)
#define instantiate_gather_mm_rhs_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gather_mm_rhs(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gather_mm_rhs(nt, false, true, iname, itype, oname, otype, bm, bn, bk, wm, wn)
#define instantiate_gather_mm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gather_mm(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gather_mm(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gather_mm(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gather_mm(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn)
#define instantiate_gather_mm_shapes_helper(iname, itype, oname, otype) \
instantiate_gather_mm_rhs_transpose_helper(iname, itype, oname, otype, 16, 64, 16, 1, 2) \
instantiate_gather_mm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) \
instantiate_gather_mm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 1, 2) \
instantiate_gather_mm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \
instantiate_gather_mm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 1, 2) \
instantiate_gather_mm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2)
// clang-format on
instantiate_gather_mm_shapes_helper(float16, half, float16, half);
instantiate_gather_mm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t);
instantiate_gather_mm_shapes_helper(float32, float, float32, float);
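// Editorial note: each instantiation expands to a kernel whose name encodes
// its configuration, e.g. the (nt, float16) rhs variant above becomes
//   steel_gather_mm_rhs_nt_float16_float16_bm16_bn64_bk16_wm1_wn2
// The host code (see the concatenate() calls in matmul.cpp) must rebuild this
// exact string for the kernel lookup to succeed.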

View File

@ -142,6 +142,42 @@ struct BaseMMAFrag<T, 8, 8> {
} }
} }
template <
typename DstPtrType,
typename StrX,
typename StrY,
typename StartX,
typename StopX,
typename StartY,
typename StopY,
typename OffX,
typename OffY>
METAL_FUNC static constexpr void store_slice(
const thread frag_type& src,
DstPtrType dst,
StrX str_x,
StrY str_y,
StartX start_x,
StopX stop_x,
StartY start_y,
StopY stop_y,
OffX off_x = Int<0>{},
OffY off_y = Int<0>{}) {
using U = pointer_element_t<DstPtrType>;
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kElemRows; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < kElemCols; j++) {
if ((off_x + i) < stop_x && (off_x + i) >= start_x &&
(off_y + j) < stop_y && (off_y + j) >= start_y) {
dst[(off_x + i) * str_x + (off_y + j) * str_y] =
static_cast<U>(src[i * kElemCols + j]);
}
}
}
}
METAL_FUNC static constexpr void mma( METAL_FUNC static constexpr void mma(
thread frag_type& D, thread frag_type& D,
thread frag_type& A, thread frag_type& A,
@ -335,6 +371,31 @@ struct MMATile {
} }
} }
} }
template <typename U, int w_x, int w_y>
METAL_FUNC void store_slice(
device U* dst,
const int ld,
const short2 start,
const short2 stop) const {
STEEL_PRAGMA_UNROLL
for (int i = 0; i < kTileRows; ++i) {
STEEL_PRAGMA_UNROLL
for (int j = 0; j < kTileCols; ++j) {
MMAFrag_t::store_slice(
frag_at(i, j),
dst,
ld,
Int<1>{},
start.y,
stop.y,
start.x,
stop.x,
(i * kFragRows) * w_x,
(j * kFragCols) * w_y);
}
}
}
}; };
template <typename T, typename U, int M, int N, int K> template <typename T, typename U, int M, int N, int K>
@ -474,6 +535,26 @@ struct BlockMMA {
Ctile.template store<U, WM, WN>(D, ldd); Ctile.template store<U, WM, WN>(D, ldd);
} }
METAL_FUNC void
store_result_slice(device U* D, const int ldd, short2 start, short2 stop) {
// Apply epilogue
STEEL_PRAGMA_UNROLL
for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) {
Ctile.elems()[i] = Epilogue::apply(Ctile.elems()[i]);
}
D += sm * ldd + sn;
start -= short2(sn, sm);
stop -= short2(sn, sm);
// TODO: Check the start as well
if (stop.y <= 0 || stop.x <= 0) {
return;
}
Ctile.template store_slice<U, WM, WN>(D, ldd, start, stop);
}
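// Editorial note: unlike store_result_safe, which clips to a rectangle
// anchored at the tile origin, store_result_slice writes only the
// [start, stop) window of the tile; presumably this lets the rhs-gather
// kernel emit just the rows belonging to the current index group and leave
// the rest of the output untouched.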
METAL_FUNC void METAL_FUNC void
store_result_safe(device U* D, const int ldd, short2 dst_tile_dims) { store_result_safe(device U* D, const int ldd, short2 dst_tile_dims) {
// Apply epilogue // Apply epilogue

View File

@ -69,6 +69,9 @@ instantiate_unary_float(Round)
instantiate_unary_int(BitwiseInvert) instantiate_unary_int(BitwiseInvert)
instantiate_unary_all_same(Abs, complex64, complex64_t) instantiate_unary_all_same(Abs, complex64, complex64_t)
instantiate_unary_all_same(ArcCos, complex64, complex64_t)
instantiate_unary_all_same(ArcSin, complex64, complex64_t)
instantiate_unary_all_same(ArcTan, complex64, complex64_t)
instantiate_unary_all_same(Conjugate, complex64, complex64_t) instantiate_unary_all_same(Conjugate, complex64, complex64_t)
instantiate_unary_all_same(Cos, complex64, complex64_t) instantiate_unary_all_same(Cos, complex64, complex64_t)
instantiate_unary_all_same(Cosh, complex64, complex64_t) instantiate_unary_all_same(Cosh, complex64, complex64_t)
@ -80,6 +83,9 @@ instantiate_unary_all_same(Negative, complex64, complex64_t)
instantiate_unary_all_same(Sign, complex64, complex64_t) instantiate_unary_all_same(Sign, complex64, complex64_t)
instantiate_unary_all_same(Sin, complex64, complex64_t) instantiate_unary_all_same(Sin, complex64, complex64_t)
instantiate_unary_all_same(Sinh, complex64, complex64_t) instantiate_unary_all_same(Sinh, complex64, complex64_t)
instantiate_unary_all_same(Square, complex64, complex64_t)
instantiate_unary_all_same(Sqrt, complex64, complex64_t)
instantiate_unary_all_same(Rsqrt, complex64, complex64_t)
instantiate_unary_all_same(Tan, complex64, complex64_t) instantiate_unary_all_same(Tan, complex64, complex64_t)
instantiate_unary_all_same(Tanh, complex64, complex64_t) instantiate_unary_all_same(Tanh, complex64, complex64_t)
instantiate_unary_all_same(Round, complex64, complex64_t) instantiate_unary_all_same(Round, complex64, complex64_t)

View File

@ -17,27 +17,21 @@ struct Abs {
T operator()(T x) { T operator()(T x) {
return metal::abs(x); return metal::abs(x);
}; };
template <>
uint8_t operator()(uint8_t x) { uint8_t operator()(uint8_t x) {
return x; return x;
}; };
template <>
uint16_t operator()(uint16_t x) { uint16_t operator()(uint16_t x) {
return x; return x;
}; };
template <>
uint32_t operator()(uint32_t x) { uint32_t operator()(uint32_t x) {
return x; return x;
}; };
template <>
uint64_t operator()(uint64_t x) { uint64_t operator()(uint64_t x) {
return x; return x;
}; };
template <>
bool operator()(bool x) { bool operator()(bool x) {
return x; return x;
}; };
template <>
complex64_t operator()(complex64_t x) { complex64_t operator()(complex64_t x) {
return {metal::precise::sqrt(x.real * x.real + x.imag * x.imag), 0}; return {metal::precise::sqrt(x.real * x.real + x.imag * x.imag), 0};
}; };
@ -48,6 +42,8 @@ struct ArcCos {
T operator()(T x) { T operator()(T x) {
return metal::precise::acos(x); return metal::precise::acos(x);
}; };
complex64_t operator()(complex64_t x);
}; };
struct ArcCosh { struct ArcCosh {
@ -62,6 +58,8 @@ struct ArcSin {
T operator()(T x) { T operator()(T x) {
return metal::precise::asin(x); return metal::precise::asin(x);
}; };
complex64_t operator()(complex64_t x);
}; };
struct ArcSinh { struct ArcSinh {
@ -76,6 +74,8 @@ struct ArcTan {
T operator()(T x) { T operator()(T x) {
return metal::precise::atan(x); return metal::precise::atan(x);
}; };
complex64_t operator()(complex64_t x);
}; };
struct ArcTanh { struct ArcTanh {
@ -97,39 +97,30 @@ struct Ceil {
T operator()(T x) { T operator()(T x) {
return metal::ceil(x); return metal::ceil(x);
}; };
template <>
int8_t operator()(int8_t x) { int8_t operator()(int8_t x) {
return x; return x;
}; };
template <>
int16_t operator()(int16_t x) { int16_t operator()(int16_t x) {
return x; return x;
}; };
template <>
int32_t operator()(int32_t x) { int32_t operator()(int32_t x) {
return x; return x;
}; };
template <>
int64_t operator()(int64_t x) { int64_t operator()(int64_t x) {
return x; return x;
}; };
template <>
uint8_t operator()(uint8_t x) { uint8_t operator()(uint8_t x) {
return x; return x;
}; };
template <>
uint16_t operator()(uint16_t x) { uint16_t operator()(uint16_t x) {
return x; return x;
}; };
template <>
uint32_t operator()(uint32_t x) { uint32_t operator()(uint32_t x) {
return x; return x;
}; };
template <>
uint64_t operator()(uint64_t x) { uint64_t operator()(uint64_t x) {
return x; return x;
}; };
template <>
bool operator()(bool x) { bool operator()(bool x) {
return x; return x;
}; };
@ -141,7 +132,6 @@ struct Cos {
return metal::precise::cos(x); return metal::precise::cos(x);
}; };
template <>
complex64_t operator()(complex64_t x) { complex64_t operator()(complex64_t x) {
return { return {
metal::precise::cos(x.real) * metal::precise::cosh(x.imag), metal::precise::cos(x.real) * metal::precise::cosh(x.imag),
@ -155,7 +145,6 @@ struct Cosh {
return metal::precise::cosh(x); return metal::precise::cosh(x);
}; };
template <>
complex64_t operator()(complex64_t x) { complex64_t operator()(complex64_t x) {
return { return {
metal::precise::cosh(x.real) * metal::precise::cos(x.imag), metal::precise::cosh(x.real) * metal::precise::cos(x.imag),
@ -188,7 +177,6 @@ struct Exp {
T operator()(T x) { T operator()(T x) {
return metal::precise::exp(x); return metal::precise::exp(x);
}; };
template <>
complex64_t operator()(complex64_t x) { complex64_t operator()(complex64_t x) {
auto m = metal::precise::exp(x.real); auto m = metal::precise::exp(x.real);
return {m * metal::precise::cos(x.imag), m * metal::precise::sin(x.imag)}; return {m * metal::precise::cos(x.imag), m * metal::precise::sin(x.imag)};
@ -207,39 +195,30 @@ struct Floor {
T operator()(T x) { T operator()(T x) {
return metal::floor(x); return metal::floor(x);
}; };
template <>
int8_t operator()(int8_t x) { int8_t operator()(int8_t x) {
return x; return x;
}; };
template <>
int16_t operator()(int16_t x) { int16_t operator()(int16_t x) {
return x; return x;
}; };
template <>
int32_t operator()(int32_t x) { int32_t operator()(int32_t x) {
return x; return x;
}; };
template <>
int64_t operator()(int64_t x) { int64_t operator()(int64_t x) {
return x; return x;
}; };
template <>
uint8_t operator()(uint8_t x) { uint8_t operator()(uint8_t x) {
return x; return x;
}; };
template <>
uint16_t operator()(uint16_t x) { uint16_t operator()(uint16_t x) {
return x; return x;
}; };
template <>
uint32_t operator()(uint32_t x) { uint32_t operator()(uint32_t x) {
return x; return x;
}; };
template <>
uint64_t operator()(uint64_t x) { uint64_t operator()(uint64_t x) {
return x; return x;
}; };
template <>
bool operator()(bool x) { bool operator()(bool x) {
return x; return x;
}; };
@ -258,7 +237,6 @@ struct Log {
return metal::precise::log(x); return metal::precise::log(x);
}; };
template <>
complex64_t operator()(complex64_t x) { complex64_t operator()(complex64_t x) {
auto r = metal::precise::log(Abs{}(x).real); auto r = metal::precise::log(Abs{}(x).real);
auto i = metal::precise::atan2(x.imag, x.real); auto i = metal::precise::atan2(x.imag, x.real);
@ -272,7 +250,6 @@ struct Log2 {
return metal::precise::log2(x); return metal::precise::log2(x);
}; };
template <>
complex64_t operator()(complex64_t x) { complex64_t operator()(complex64_t x) {
auto y = Log{}(x); auto y = Log{}(x);
return {y.real / M_LN2_F, y.imag / M_LN2_F}; return {y.real / M_LN2_F, y.imag / M_LN2_F};
@ -285,7 +262,6 @@ struct Log10 {
return metal::precise::log10(x); return metal::precise::log10(x);
}; };
template <>
complex64_t operator()(complex64_t x) { complex64_t operator()(complex64_t x) {
auto y = Log{}(x); auto y = Log{}(x);
return {y.real / M_LN10_F, y.imag / M_LN10_F}; return {y.real / M_LN10_F, y.imag / M_LN10_F};
@ -325,7 +301,6 @@ struct Round {
T operator()(T x) { T operator()(T x) {
return metal::rint(x); return metal::rint(x);
}; };
template <>
complex64_t operator()(complex64_t x) { complex64_t operator()(complex64_t x) {
return {metal::rint(x.real), metal::rint(x.imag)}; return {metal::rint(x.real), metal::rint(x.imag)};
}; };
@ -344,11 +319,9 @@ struct Sign {
T operator()(T x) { T operator()(T x) {
return (x > T(0)) - (x < T(0)); return (x > T(0)) - (x < T(0));
}; };
template <>
uint32_t operator()(uint32_t x) { uint32_t operator()(uint32_t x) {
return x != 0; return x != 0;
}; };
template <>
complex64_t operator()(complex64_t x) { complex64_t operator()(complex64_t x) {
if (x == complex64_t(0)) { if (x == complex64_t(0)) {
return x; return x;
@ -364,7 +337,6 @@ struct Sin {
return metal::precise::sin(x); return metal::precise::sin(x);
}; };
template <>
complex64_t operator()(complex64_t x) { complex64_t operator()(complex64_t x) {
return { return {
metal::precise::sin(x.real) * metal::precise::cosh(x.imag), metal::precise::sin(x.real) * metal::precise::cosh(x.imag),
@ -378,7 +350,6 @@ struct Sinh {
return metal::precise::sinh(x); return metal::precise::sinh(x);
}; };
template <>
complex64_t operator()(complex64_t x) { complex64_t operator()(complex64_t x) {
return { return {
metal::precise::sinh(x.real) * metal::precise::cos(x.imag), metal::precise::sinh(x.real) * metal::precise::cos(x.imag),
@ -398,6 +369,17 @@ struct Sqrt {
T operator()(T x) { T operator()(T x) {
return metal::precise::sqrt(x); return metal::precise::sqrt(x);
}; };
complex64_t operator()(complex64_t x) {
if (x.real == 0.0 && x.imag == 0.0) {
return {0.0, 0.0};
}
auto r = Abs{}(x).real;
auto a = metal::precise::sqrt((r + x.real) / 2.0);
auto b_abs = metal::precise::sqrt((r - x.real) / 2.0);
auto b = metal::copysign(b_abs, x.imag);
return {a, b};
}
}; };
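// Math note (editorial): for z = a + b*i with r = |z|, the principal square
// root computed above is
//   sqrt(z) = sqrt((r + a) / 2) + i * copysign(sqrt((r - a) / 2), b)
// which keeps the real part non-negative and gives the imaginary part the
// sign of b.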
struct Rsqrt { struct Rsqrt {
@ -405,6 +387,10 @@ struct Rsqrt {
T operator()(T x) { T operator()(T x) {
return metal::precise::rsqrt(x); return metal::precise::rsqrt(x);
}; };
complex64_t operator()(complex64_t x) {
return 1.0 / Sqrt{}(x);
}
}; };
struct Tan { struct Tan {
@ -413,7 +399,6 @@ struct Tan {
return metal::precise::tan(x); return metal::precise::tan(x);
}; };
template <>
complex64_t operator()(complex64_t x) { complex64_t operator()(complex64_t x) {
float tan_a = metal::precise::tan(x.real); float tan_a = metal::precise::tan(x.real);
float tanh_b = metal::precise::tanh(x.imag); float tanh_b = metal::precise::tanh(x.imag);
@ -429,7 +414,6 @@ struct Tanh {
return metal::precise::tanh(x); return metal::precise::tanh(x);
}; };
template <>
complex64_t operator()(complex64_t x) { complex64_t operator()(complex64_t x) {
float tanh_a = metal::precise::tanh(x.real); float tanh_a = metal::precise::tanh(x.real);
float tan_b = metal::precise::tan(x.imag); float tan_b = metal::precise::tan(x.imag);
@ -438,3 +422,21 @@ struct Tanh {
return {(tanh_a + tan_b * t1) / denom, (tan_b - tanh_a * t1) / denom}; return {(tanh_a + tan_b * t1) / denom, (tan_b - tanh_a * t1) / denom};
}; };
}; };
complex64_t ArcCos::operator()(complex64_t x) {
auto i = complex64_t{0.0, 1.0};
auto y = Log{}(x + i * Sqrt{}(1.0 - x * x));
return {y.imag, -y.real};
};
complex64_t ArcSin::operator()(complex64_t x) {
auto i = complex64_t{0.0, 1.0};
auto y = Log{}(i * x + Sqrt{}(1.0 - x * x));
return {y.imag, -y.real};
};
complex64_t ArcTan::operator()(complex64_t x) {
auto i = complex64_t{0.0, 1.0};
auto ix = i * x;
return (1.0 / complex64_t{0.0, 2.0}) * Log{}((1.0 + ix) / (1.0 - ix));
};
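// Math note (editorial): these follow the standard principal-branch
// identities
//   acos(z) = -i * log(z + i * sqrt(1 - z^2))
//   asin(z) = -i * log(i*z + sqrt(1 - z^2))
//   atan(z) = (1 / (2i)) * log((1 + i*z) / (1 - i*z))
// using the complex Log and Sqrt defined above; returning {y.imag, -y.real}
// implements the multiplication by -i.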

View File

@ -5,6 +5,7 @@
#include <numeric> #include <numeric>
#include <sstream> #include <sstream>
#include "mlx/backend/common/broadcasting.h"
#include "mlx/backend/common/utils.h" #include "mlx/backend/common/utils.h"
#include "mlx/backend/metal/copy.h" #include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h" #include "mlx/backend/metal/device.h"
@ -102,6 +103,47 @@ std::tuple<bool, int64_t, array> check_transpose(
} }
}; };
inline array
ensure_row_contiguous(const array& x, metal::Device& d, const Stream& s) {
if (!x.flags().row_contiguous) {
array x_copy(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
d.add_temporary(x_copy, s.index);
return x_copy;
} else {
return x;
}
}
inline std::tuple<bool, int64_t, array>
ensure_batch_contiguous(const array& x, metal::Device& d, const Stream& s) {
if (x.flags().row_contiguous) {
return std::make_tuple(false, x.strides()[x.ndim() - 2], x);
}
bool rc = true;
for (int i = 0; i < x.ndim() - 3; i++) {
rc &= x.strides()[i + 1] * x.shape(i + 1) == x.strides()[i];
}
if (rc) {
auto stx = x.strides()[x.ndim() - 2];
auto sty = x.strides()[x.ndim() - 1];
auto K = x.shape(-2);
auto N = x.shape(-1);
if (sty == 1 && (N != 1 || stx == N)) {
return std::make_tuple(false, stx, x);
}
if (stx == 1 && (N != 1 || sty == K)) {
return std::make_tuple(true, sty, x);
}
}
array x_copy(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
d.add_temporary(x_copy, s.index);
return std::make_tuple(false, x_copy.strides()[x_copy.ndim() - 2], x_copy);
}
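// Editorial note: the returned tuple is (transposed, leading dimension,
// array). A row-contiguous input is used directly; otherwise, if the batch
// dimensions are contiguous among themselves, the matrix can still be fed to
// the kernel when its last dimension has stride 1 (use as-is) or its
// second-to-last dimension has stride 1 (treat as transposed); anything else
// falls back to a general copy.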
} // namespace } // namespace
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
@ -230,7 +272,6 @@ void steel_matmul_regular(
const bool align_M = (M % bm) == 0; const bool align_M = (M % bm) == 0;
const bool align_N = (N % bn) == 0; const bool align_N = (N % bn) == 0;
const bool align_K = (K % bk) == 0; const bool align_K = (K % bk) == 0;
const bool do_gather = false;
metal::MTLFCList func_consts = { metal::MTLFCList func_consts = {
{&has_batch, MTL::DataType::DataTypeBool, 10}, {&has_batch, MTL::DataType::DataTypeBool, 10},
@ -239,7 +280,6 @@ void steel_matmul_regular(
{&align_M, MTL::DataType::DataTypeBool, 200}, {&align_M, MTL::DataType::DataTypeBool, 200},
{&align_N, MTL::DataType::DataTypeBool, 201}, {&align_N, MTL::DataType::DataTypeBool, 201},
{&align_K, MTL::DataType::DataTypeBool, 202}, {&align_K, MTL::DataType::DataTypeBool, 202},
{&do_gather, MTL::DataType::DataTypeBool, 300},
}; };
// clang-format off // clang-format off
@ -248,8 +288,7 @@ void steel_matmul_regular(
<< "_do_axpby_" << (do_axpby ? 't' : 'n') << "_do_axpby_" << (do_axpby ? 't' : 'n')
<< "_align_M_" << (align_M ? 't' : 'n') << "_align_M_" << (align_M ? 't' : 'n')
<< "_align_N_" << (align_N ? 't' : 'n') << "_align_N_" << (align_N ? 't' : 'n')
<< "_align_K_" << (align_K ? 't' : 'n') << "_align_K_" << (align_K ? 't' : 'n'); // clang-format on
<< "_do_gather_" << (do_gather ? 't' : 'n'); // clang-format on
std::string hash_name = kname.str(); std::string hash_name = kname.str();
@ -975,7 +1014,6 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
const bool align_M = (M % bm) == 0; const bool align_M = (M % bm) == 0;
const bool align_N = (N % bn) == 0; const bool align_N = (N % bn) == 0;
const bool align_K = (K % bk) == 0; const bool align_K = (K % bk) == 0;
const bool do_gather = false;
metal::MTLFCList func_consts = { metal::MTLFCList func_consts = {
{&has_batch, MTL::DataType::DataTypeBool, 10}, {&has_batch, MTL::DataType::DataTypeBool, 10},
@ -984,7 +1022,6 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
{&align_M, MTL::DataType::DataTypeBool, 200}, {&align_M, MTL::DataType::DataTypeBool, 200},
{&align_N, MTL::DataType::DataTypeBool, 201}, {&align_N, MTL::DataType::DataTypeBool, 201},
{&align_K, MTL::DataType::DataTypeBool, 202}, {&align_K, MTL::DataType::DataTypeBool, 202},
{&do_gather, MTL::DataType::DataTypeBool, 300},
}; };
// clang-format off // clang-format off
@ -993,8 +1030,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
<< "_do_axpby_" << (do_axpby ? 't' : 'n') << "_do_axpby_" << (do_axpby ? 't' : 'n')
<< "_align_M_" << (align_M ? 't' : 'n') << "_align_M_" << (align_M ? 't' : 'n')
<< "_align_N_" << (align_N ? 't' : 'n') << "_align_N_" << (align_N ? 't' : 'n')
<< "_align_K_" << (align_K ? 't' : 'n') << "_align_K_" << (align_K ? 't' : 'n'); // clang-format on
<< "_do_gather_" << (do_gather ? 't' : 'n'); // clang-format on
std::string hash_name = kname.str(); std::string hash_name = kname.str();
@ -1464,267 +1500,337 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
d.add_temporaries(std::move(copies), s.index); d.add_temporaries(std::move(copies), s.index);
} }
void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) { void gather_mm_rhs(
using namespace mlx::steel; const array& a_,
// assert(inputs.size() == 2); const array& b_,
if (!issubdtype(out.dtype(), floating)) { const array& indices_,
throw std::runtime_error( array& out,
"[GatherMM] Does not yet support non-floating point types."); metal::Device& d,
} const Stream& s) {
auto& s = stream(); array indices = ensure_row_contiguous(indices_, d, s);
auto& d = metal::device(s.device); auto [transpose_b, ldb, b] = ensure_batch_contiguous(b_, d, s);
auto& a_pre = inputs[0]; // Broadcast a with indices. If we are here that means lhs_indices were not
auto& b_pre = inputs[1]; // provided so the lhs_indices are implied to be the shape of a broadcasted
// Return 0s if either input is empty // with rhs_indices. We need only broadcast a and copy it as if applying the
if (a_pre.size() == 0 || b_pre.size() == 0) { // lhs_indices.
array zero = array(0, a_pre.dtype()); auto broadcast_with_indices = [&d, &s, &indices](const array& x) {
fill_gpu(zero, out, s); if (x.size() / x.shape(-2) / x.shape(-1) == indices.size()) {
d.add_temporary(std::move(zero), s.index); return ensure_row_contiguous(x, d, s);
return; }
}
out.set_data(allocator::malloc(out.nbytes())); auto x_shape = indices.shape();
x_shape.push_back(x.shape(-2));
x_shape.push_back(x.shape(-1));
array new_x(std::move(x_shape), x.dtype(), nullptr, {});
broadcast(x, new_x);
return ensure_row_contiguous(new_x, d, s);
};
array a = broadcast_with_indices(a_);
///////////////////////////////////////////////////////////////////////////// // Extract the matmul shapes
// Init checks and prep int K = a.shape(-1);
int M = a.size() / K;
int N = b.shape(-1);
int lda = a.strides()[a.ndim() - 2]; // should be K
int M = a_pre.shape(-2); // Define the dispatch blocks
int N = b_pre.shape(-1); int bm = 16, bn = 64, bk = 16;
int K = a_pre.shape(-1); int wm = 1, wn = 2;
// Keep a vector with copies to be cleared in the completed buffer to release const bool align_M = (M % bm) == 0;
// the arrays const bool align_N = (N % bn) == 0;
std::vector<array> copies; const bool align_K = (K % bk) == 0;
auto [transpose_a, a_cols, a] = check_transpose(copies, s, a_pre, M == 1);
auto [transpose_b, b_cols, b] = check_transpose(copies, s, b_pre, N == 1);
int lda = a_cols; // Define the kernel name
int ldb = b_cols; std::string base_name;
base_name.reserve(64);
concatenate(
base_name,
"steel_gather_mm_rhs_n",
transpose_b ? 't' : 'n',
'_',
type_to_name(a),
'_',
type_to_name(out),
"_bm",
bm,
"_bn",
bn,
"_bk",
bk,
"_wm",
wm,
"_wn",
wn);
///////////////////////////////////////////////////////////////////////////// metal::MTLFCList func_consts = {
// Check and collapse batch dimensions {&align_M, MTL::DataType::DataTypeBool, 200},
{&align_N, MTL::DataType::DataTypeBool, 201},
auto get_batch_dims = [](const auto& v) { {&align_K, MTL::DataType::DataTypeBool, 202},
return decltype(v){v.begin(), v.end() - 2};
}; };
auto& lhs_indices = inputs[2]; // And the kernel hash that includes the function constants
auto& rhs_indices = inputs[3]; std::string hash_name;
hash_name.reserve(128);
concatenate(
hash_name,
base_name,
"_align_M_",
align_M ? 't' : 'n',
"_align_N_",
align_N ? 't' : 'n',
"_align_K_",
align_K ? 't' : 'n');
Shape batch_shape = get_batch_dims(out.shape()); // Get and set the kernel
Strides batch_strides; auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = get_steel_gemm_gather_kernel(
d,
base_name,
hash_name,
func_consts,
out,
false,
transpose_b,
bm,
bn,
bk,
wm,
wn,
true);
compute_encoder.set_compute_pipeline_state(kernel);
batch_strides.insert( // Prepare the matmul params
batch_strides.end(), auto batch_stride_b = b.ndim() > 2 ? b.strides()[b.ndim() - 3] : b.size();
lhs_indices.strides().begin(), steel::GEMMParams params{
lhs_indices.strides().end()); /* const int M = */ M,
auto lhs_indices_str = batch_strides.empty() ? 0 : batch_strides.back(); /* const int N = */ N,
/* const int K = */ K,
/* const int lda = */ lda,
/* const int ldb = */ static_cast<int>(ldb),
/* const int ldd = */ N,
/* const int tiles_n = */ (N + bn - 1) / bn,
/* const int tiles_m = */ (M + bm - 1) / bm,
/* const int64_t batch_stride_a = */ 0,
/* const int64_t batch_stride_b = */ static_cast<int64_t>(batch_stride_b),
/* const int64_t batch_stride_d = */ 0,
/* const int swizzle_log = */ 0,
/* const int gemm_k_iterations_aligned = */ (K / bk),
/* const int batch_ndim = */ 0};
batch_strides.insert( // Prepare the grid
batch_strides.end(), MTL::Size group_dims = MTL::Size(32, wn, wm);
rhs_indices.strides().begin(), MTL::Size grid_dims = MTL::Size(params.tiles_n, params.tiles_m, 1);
rhs_indices.strides().end());
auto rhs_indices_str = batch_strides.empty() ? 0 : batch_strides.back();
int batch_ndim = batch_shape.size(); // Launch kernel
compute_encoder.set_input_array(a, 0);
compute_encoder.set_input_array(b, 1);
compute_encoder.set_input_array(indices, 2);
compute_encoder.set_output_array(out, 3);
compute_encoder.set_bytes(params, 4);
if (batch_ndim == 0) { compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
batch_shape = {1}; }
batch_strides = {0};
}
int batch_ndim_A = a.ndim() - 2; void gather_mv(
int batch_ndim_B = b.ndim() - 2; const array& mat_,
std::vector<int> operand_batch_ndim = {batch_ndim_A, batch_ndim_B}; const array& vec_,
const array& mat_indices_,
const array& vec_indices_,
array& out,
int N,
int K,
bool is_mv,
metal::Device& d,
const Stream& s) {
// Copy if needed
std::vector<array> copies;
auto [transpose_mat, mat_cols, mat] =
check_transpose(copies, s, mat_, N == 1);
auto [transpose_vec, vec_cols, vec] = check_transpose(copies, s, vec_, true);
d.add_temporaries(std::move(copies), s.index);
Shape batch_shape_A = get_batch_dims(a.shape()); // If we are doing vector matrix instead of matrix vector we need to flip the
Strides batch_strides_A = get_batch_dims(a.strides()); // matrix transposition. Basically m @ v = v @ m.T assuming that v is treated
Shape batch_shape_B = get_batch_dims(b.shape()); // as a one dimensional array.
Strides batch_strides_B = get_batch_dims(b.strides()); transpose_mat = (!is_mv) ^ transpose_mat;
if (batch_ndim_A == 0) { // Define some shapes
batch_shape_A = {1}; int in_vector_len = K;
batch_strides_A = {0}; int out_vector_len = N;
} int mat_ld = mat_cols;
if (batch_ndim_B == 0) { int batch_size_out = out.size() / N;
batch_shape_B = {1}; int batch_ndim = out.ndim() - 2;
batch_strides_B = {0}; int batch_ndim_mat = mat.ndim() - 2;
} int batch_ndim_vec = vec.ndim() - 2;
Strides index_strides = vec_indices_.strides();
index_strides.insert(
index_strides.end(),
mat_indices_.strides().begin(),
mat_indices_.strides().end());
auto matrix_stride_out = static_cast<int64_t>(M) * N; // Determine dispatch kernel
auto batch_size_out = out.size() / matrix_stride_out; int tm = 4, tn = 4;
int sm = 1, sn = 32;
///////////////////////////////////////////////////////////////////////////// int bm = 1, bn = 1;
// Gemv specialization int n_out_per_tgp;
std::ostringstream kname;
// Route to gemv if needed
if (std::min(M, N) == 1) {
// Collect problem info
bool is_b_matrix = N != 1;
auto& mat = is_b_matrix ? b : a;
auto& vec = is_b_matrix ? a : b;
bool transpose_mat = is_b_matrix ? !transpose_b : transpose_a;
int in_vector_len = K;
int out_vector_len = is_b_matrix ? N : M;
int mat_cols = transpose_mat ? out_vector_len : in_vector_len;
int mat_rows = transpose_mat ? in_vector_len : out_vector_len;
int mat_ld = is_b_matrix ? b_cols : a_cols;
auto batch_strides_mat = is_b_matrix ? batch_strides_B : batch_strides_A;
auto batch_strides_vec = is_b_matrix ? batch_strides_A : batch_strides_B;
auto batch_shape_mat = is_b_matrix ? batch_shape_B : batch_shape_A;
auto batch_shape_vec = is_b_matrix ? batch_shape_A : batch_shape_B;
if (!is_b_matrix) {
batch_strides = rhs_indices.strides();
batch_strides.insert(
batch_strides.end(),
lhs_indices.strides().begin(),
lhs_indices.strides().end());
}
int batch_ndim = batch_shape.size();
// Determine dispatch kernel
int tm = 4, tn = 4;
int sm = 1, sn = 32;
int bm = 1, bn = 1;
int n_out_per_tgp;
std::ostringstream kname;
if (transpose_mat) {
if (in_vector_len >= 8192 && out_vector_len >= 2048) {
sm = 4;
sn = 8;
} else {
sm = 8;
sn = 4;
}
if (out_vector_len >= 2048) {
bn = 16;
} else if (out_vector_len >= 512) {
bn = 4;
} else {
bn = 2;
}
// Specialized kernel for very small outputs
tn = out_vector_len < tn ? 1 : tn;
n_out_per_tgp = bn * sn * tn;
kname << "gemv_t_gather_" << type_to_name(out);
if (transpose_mat) {
if (in_vector_len >= 8192 && out_vector_len >= 2048) {
sm = 4;
sn = 8;
} else { } else {
bm = out_vector_len >= 4096 ? 8 : 4; sm = 8;
sn = 32; sn = 4;
// Specialized kernel for very small outputs
tm = out_vector_len < tm ? 1 : tm;
n_out_per_tgp = bm * sm * tm;
kname << "gemv_gather_" << type_to_name(out);
} }
kname << "_bm" << bm << "_bn" << bn << "_sm" << sm << "_sn" << sn << "_tm" if (out_vector_len >= 2048) {
<< tm << "_tn" << tn; bn = 16;
} else if (out_vector_len >= 512) {
bn = 4;
} else {
bn = 2;
}
// Encode and dispatch kernel // Specialized kernel for very small outputs
auto& compute_encoder = d.get_command_encoder(s.index); tn = out_vector_len < tn ? 1 : tn;
auto kernel = d.get_kernel(kname.str());
compute_encoder.set_compute_pipeline_state(kernel);
int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp; n_out_per_tgp = bn * sn * tn;
MTL::Size group_dims = MTL::Size(32, bn, bm); kname << "gemv_t_gather_" << type_to_name(out);
MTL::Size grid_dims = MTL::Size(n_tgp, 1, batch_size_out);
compute_encoder.set_input_array(mat, 0); } else {
compute_encoder.set_input_array(vec, 1); bm = out_vector_len >= 4096 ? 8 : 4;
compute_encoder.set_output_array(out, 3); sn = 32;
compute_encoder.set_bytes(in_vector_len, 4); // Specialized kernel for very small outputs
compute_encoder.set_bytes(out_vector_len, 5); tm = out_vector_len < tm ? 1 : tm;
compute_encoder.set_bytes(mat_ld, 6);
compute_encoder.set_bytes(batch_ndim, 9); n_out_per_tgp = bm * sm * tm;
compute_encoder.set_vector_bytes(batch_shape, 10); kname << "gemv_gather_" << type_to_name(out);
compute_encoder.set_vector_bytes(batch_strides, 11);
int batch_ndim_vec = batch_shape_vec.size();
compute_encoder.set_bytes(batch_ndim_vec, 12);
compute_encoder.set_vector_bytes(batch_shape_vec, 13);
compute_encoder.set_vector_bytes(batch_strides_vec, 14);
int batch_ndim_mat = batch_shape_mat.size();
compute_encoder.set_bytes(batch_ndim_mat, 15);
compute_encoder.set_vector_bytes(batch_shape_mat, 16);
compute_encoder.set_vector_bytes(batch_strides_mat, 17);
compute_encoder.set_input_array(lhs_indices, 18 + int(!is_b_matrix));
compute_encoder.set_input_array(rhs_indices, 18 + int(is_b_matrix));
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
d.add_temporaries(std::move(copies), s.index);
return;
} }
///////////////////////////////////////////////////////////////////////////// kname << "_bm" << bm << "_bn" << bn << "_sm" << sm << "_sn" << sn << "_tm"
// Regular kernel dispatch << tm << "_tn" << tn;
// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder.set_compute_pipeline_state(kernel);
int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp;
MTL::Size group_dims = MTL::Size(32, bn, bm);
MTL::Size grid_dims = MTL::Size(n_tgp, 1, batch_size_out);
compute_encoder.set_input_array(mat, 0);
compute_encoder.set_input_array(vec, 1);
compute_encoder.set_output_array(out, 3);
compute_encoder.set_bytes(in_vector_len, 4);
compute_encoder.set_bytes(out_vector_len, 5);
compute_encoder.set_bytes(mat_ld, 6);
compute_encoder.set_bytes(batch_ndim, 9);
compute_encoder.set_vector_bytes(out.shape(), 10);
compute_encoder.set_vector_bytes(index_strides, 11);
compute_encoder.set_bytes(batch_ndim_vec, 12);
compute_encoder.set_vector_bytes(vec.shape(), 13);
compute_encoder.set_vector_bytes(vec.strides(), 14);
compute_encoder.set_bytes(batch_ndim_mat, 15);
compute_encoder.set_vector_bytes(mat.shape(), 16);
compute_encoder.set_vector_bytes(mat.strides(), 17);
compute_encoder.set_input_array(vec_indices_, 18);
compute_encoder.set_input_array(mat_indices_, 19);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
void gather_mm(
const array& a_,
const array& b_,
const array& lhs_indices,
const array& rhs_indices,
array& out,
int M,
int N,
int K,
metal::Device& d,
const Stream& s) {
// Copy if needed
std::vector<array> copies;
auto [transpose_a, lda, a] = check_transpose(copies, s, a_, false);
auto [transpose_b, ldb, b] = check_transpose(copies, s, b_, false);
d.add_temporaries(std::move(copies), s.index);
// Determine dispatch kernel // Determine dispatch kernel
int bm = 64, bn = 64, bk = 16; int bm = 64, bn = 64, bk = 16;
int wm = 2, wn = 2; int wm = 2, wn = 2;
size_t batch_size_out = out.size() / M / N;
int batch_ndim = out.ndim() - 2;
int batch_ndim_a = a.ndim() - 2;
int batch_ndim_b = b.ndim() - 2;
char devc = d.get_architecture().back(); char devc = d.get_architecture().back();
GEMM_TPARAM_MACRO(devc) GEMM_TPARAM_MACRO(devc)
// Prepare kernel name
std::ostringstream kname;
kname << "steel_gemm_fused_" << (transpose_a ? 't' : 'n')
<< (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_"
<< type_to_name(out) << "_bm" << bm << "_bn" << bn << "_bk" << bk
<< "_wm" << wm << "_wn" << wn;
std::string base_name = kname.str();
const bool has_batch = batch_ndim > 1; const bool has_batch = batch_ndim > 1;
const bool use_out_source = false;
const bool do_axpby = false;
const bool align_M = (M % bm) == 0; const bool align_M = (M % bm) == 0;
const bool align_N = (N % bn) == 0; const bool align_N = (N % bn) == 0;
const bool align_K = (K % bk) == 0; const bool align_K = (K % bk) == 0;
const bool do_gather = true;
// Define the kernel name
std::string base_name;
base_name.reserve(128);
concatenate(
base_name,
"steel_gather_mm_",
transpose_a ? 't' : 'n',
transpose_b ? 't' : 'n',
"_",
type_to_name(a),
"_",
type_to_name(out),
"_bm",
bm,
"_bn",
bn,
"_bk",
bk,
"_wm",
wm,
"_wn",
wn);
metal::MTLFCList func_consts = { metal::MTLFCList func_consts = {
{&has_batch, MTL::DataType::DataTypeBool, 10}, {&has_batch, MTL::DataType::DataTypeBool, 10},
{&use_out_source, MTL::DataType::DataTypeBool, 100},
{&do_axpby, MTL::DataType::DataTypeBool, 110},
{&align_M, MTL::DataType::DataTypeBool, 200}, {&align_M, MTL::DataType::DataTypeBool, 200},
{&align_N, MTL::DataType::DataTypeBool, 201}, {&align_N, MTL::DataType::DataTypeBool, 201},
{&align_K, MTL::DataType::DataTypeBool, 202}, {&align_K, MTL::DataType::DataTypeBool, 202},
{&do_gather, MTL::DataType::DataTypeBool, 300},
}; };
// clang-format off // And the kernel hash that includes the function constants
kname << "_has_batch_" << (has_batch ? 't' : 'n') std::string hash_name;
<< "_use_out_source_" << (use_out_source ? 't' : 'n') hash_name.reserve(128);
<< "_do_axpby_" << (do_axpby ? 't' : 'n') concatenate(
<< "_align_M_" << (align_M ? 't' : 'n') hash_name,
<< "_align_N_" << (align_N ? 't' : 'n') base_name,
<< "_align_K_" << (align_K ? 't' : 'n') "_has_batch_",
<< "_do_gather_" << (do_gather ? 't' : 'n'); // clang-format on has_batch ? 't' : 'n',
"_align_M_",
align_M ? 't' : 'n',
"_align_N_",
align_N ? 't' : 'n',
"_align_K_",
align_K ? 't' : 'n');
std::string hash_name = kname.str(); // Get and set the kernel
// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = get_steel_gemm_fused_kernel( auto kernel = get_steel_gemm_gather_kernel(
d, d,
base_name, base_name,
hash_name, hash_name,
@ -1736,72 +1842,96 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
bn, bn,
bk, bk,
wm, wm,
wn); wn,
false);
compute_encoder.set_compute_pipeline_state(kernel); compute_encoder.set_compute_pipeline_state(kernel);
// Use problem size to determine threadblock swizzle // Prepare the matmul params
int tn = (N + bn - 1) / bn; steel::GEMMParams params{
int tm = (M + bm - 1) / bm;
// TODO: Explore device-based tuning for swizzle
int swizzle_log = 0; // tm >= 6 ? 3 : (tm <= 3 ? 0 : 2);
// Prepare steel matmul params
GEMMParams params{
/* const int M = */ M, /* const int M = */ M,
/* const int N = */ N, /* const int N = */ N,
/* const int K = */ K, /* const int K = */ K,
/* const int lda = */ lda, /* const int lda = */ static_cast<int>(lda),
/* const int ldb = */ ldb, /* const int ldb = */ static_cast<int>(ldb),
/* const int ldd = */ N, /* const int ldd = */ N,
/* const int tiles_n = */ tn, /* const int tiles_n = */ (N + bn - 1) / bn,
/* const int tiles_m = */ tm, /* const int tiles_m = */ (M + bm - 1) / bm,
/* const int64_t batch_stride_a = */ lhs_indices_str, /* const int64_t batch_stride_a = */
/* const int64_t batch_stride_b = */ rhs_indices_str, (batch_ndim > 0) ? lhs_indices.strides()[0] : 0,
/* const int64_t batch_stride_d = */ matrix_stride_out, /* const int64_t batch_stride_b = */
/* const int swizzle_log = */ swizzle_log, (batch_ndim > 0) ? rhs_indices.strides()[0] : 0,
/* const int64_t batch_stride_d = */ M * N,
/* const int swizzle_log = */ 0,
/* const int gemm_k_iterations_aligned = */ (K / bk), /* const int gemm_k_iterations_aligned = */ (K / bk),
/* const int batch_ndim = */ batch_ndim}; /* const int batch_ndim = */ batch_ndim};
// Prepare launch grid params // Prepare the grid
int tile = 1 << swizzle_log;
tm = (tm + tile - 1) / tile;
tn = tn * tile;
MTL::Size group_dims = MTL::Size(32, wn, wm); MTL::Size group_dims = MTL::Size(32, wn, wm);
MTL::Size grid_dims = MTL::Size(tn, tm, batch_size_out); MTL::Size grid_dims =
MTL::Size(params.tiles_n, params.tiles_m, batch_size_out);
// Launch kernel // Launch kernel
compute_encoder.set_input_array(a, 0); compute_encoder.set_input_array(a, 0);
compute_encoder.set_input_array(b, 1); compute_encoder.set_input_array(b, 1);
compute_encoder.set_output_array(out, 3); compute_encoder.set_input_array(lhs_indices, 2);
compute_encoder.set_input_array(rhs_indices, 3);
compute_encoder.set_bytes(params, 4); compute_encoder.set_output_array(out, 4);
compute_encoder.set_bytes(params, 5);
compute_encoder.set_vector_bytes(batch_shape, 6); compute_encoder.set_vector_bytes(lhs_indices.shape(), 6);
compute_encoder.set_vector_bytes(batch_strides, 7); compute_encoder.set_vector_bytes(lhs_indices.strides(), 7);
compute_encoder.set_vector_bytes(rhs_indices.strides(), 8);
compute_encoder.set_input_array(lhs_indices, 10); compute_encoder.set_bytes(batch_ndim_a, 9);
compute_encoder.set_input_array(rhs_indices, 11); compute_encoder.set_vector_bytes(a.shape(), 10);
compute_encoder.set_vector_bytes(a.strides(), 11);
std::vector operand_shape = batch_shape_A; compute_encoder.set_bytes(batch_ndim_b, 12);
operand_shape.insert( compute_encoder.set_vector_bytes(b.shape(), 13);
operand_shape.end(), batch_shape_B.begin(), batch_shape_B.end()); compute_encoder.set_vector_bytes(b.strides(), 14);
std::vector operand_strides = batch_strides_A;
operand_strides.insert(
operand_strides.end(), batch_strides_B.begin(), batch_strides_B.end());
operand_batch_ndim.push_back(0);
compute_encoder.set_vector_bytes(operand_shape, 13);
compute_encoder.set_vector_bytes(operand_strides, 14);
compute_encoder.set_vector_bytes(operand_batch_ndim, 15);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims); compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
d.add_temporaries(std::move(copies), s.index); void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& s = stream();
auto& d = metal::device(s.device);
auto& a = inputs[0];
auto& b = inputs[1];
auto& lhs_indices = inputs[2];
auto& rhs_indices = inputs[3];
// Return 0s if either input is empty
if (a.size() == 0 || b.size() == 0) {
array zero = array(0, a.dtype());
fill_gpu(zero, out, s);
d.add_temporary(std::move(zero), s.index);
return;
}
out.set_data(allocator::malloc(out.nbytes()));
// Extract shapes from inputs.
int M = a.shape(-2);
int N = b.shape(-1);
int K = a.shape(-1);
// We walk a in order and b is also accessed in order, so we can batch up the
// matmuls and reuse reads of a and b.
if (M == 1 && right_sorted_ == true) {
gather_mm_rhs(a, b, rhs_indices, out, d, s);
return;
}
// Route to gather gemv if any of a or b are vectors
if (M == 1) {
gather_mv(b, a, rhs_indices, lhs_indices, out, N, K, false, d, s);
return;
}
if (N == 1) {
gather_mv(a, b, lhs_indices, rhs_indices, out, M, K, true, d, s);
return;
}
// Route to non specialized gather mm
gather_mm(a, b, lhs_indices, rhs_indices, out, M, N, K, d, s);
} }
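// Editorial summary of the dispatch above: a sorted rhs with M == 1 (each
// gathered a is a single row) takes the specialized gather_mm_rhs path; a
// pure vector operand on either side routes to the gather gemv kernels;
// everything else falls through to the tiled gather_mm GEMM.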
} // namespace mlx::core } // namespace mlx::core

View File

@ -193,6 +193,23 @@ MTL::ComputePipelineState* get_steel_gemm_masked_kernel(
return d.get_kernel(kernel_name); return d.get_kernel(kernel_name);
} }
MTL::ComputePipelineState* get_steel_gemm_gather_kernel(
metal::Device& d,
const std::string& kernel_name,
const std::string& hash_name,
const metal::MTLFCList& func_consts,
const array&,
bool,
bool,
int,
int,
int,
int,
int,
bool) {
return d.get_kernel(kernel_name, "mlx", hash_name, func_consts);
}
MTL::ComputePipelineState* get_gemv_masked_kernel( MTL::ComputePipelineState* get_gemv_masked_kernel(
metal::Device& d, metal::Device& d,
const std::string& kernel_name, const std::string& kernel_name,
@ -252,4 +269,21 @@ MTL::ComputePipelineState* get_quantized_kernel(
return d.get_kernel(kernel_name); return d.get_kernel(kernel_name);
} }
MTL::ComputePipelineState* get_gather_qmm_kernel(
metal::Device& d,
const std::string& kernel_name,
const std::string& hash_name,
const metal::MTLFCList& func_consts,
const array&,
int,
int,
int,
int,
int,
int,
int,
bool) {
return d.get_kernel(kernel_name, "mlx", hash_name, func_consts);
}
} // namespace mlx::core } // namespace mlx::core

File diff suppressed because it is too large

View File

@ -60,6 +60,9 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
case Scan::Min: case Scan::Min:
reduce_type = "min"; reduce_type = "min";
break; break;
case Scan::LogAddExp:
reduce_type = "logaddexp";
break;
} }
kname << reduce_type << "_" << type_to_name(in) << "_" << type_to_name(out); kname << reduce_type << "_" << type_to_name(in) << "_" << type_to_name(out);
auto kernel = get_scan_kernel( auto kernel = get_scan_kernel(

View File

@ -2,6 +2,8 @@
#pragma once #pragma once
#include <type_traits>
#include "mlx/array.h" #include "mlx/array.h"
#include "mlx/backend/metal/device.h" #include "mlx/backend/metal/device.h"
#include "mlx/primitives.h" #include "mlx/primitives.h"
@ -58,14 +60,27 @@ inline void debug_set_primitive_buffer_label(
std::string get_primitive_string(Primitive* primitive); std::string get_primitive_string(Primitive* primitive);
template <typename T>
constexpr bool is_numeric_except_char = std::is_arithmetic_v<T> &&
!std::is_same_v<T, char> && !std::is_same_v<T, signed char> &&
!std::is_same_v<T, unsigned char> && !std::is_same_v<T, wchar_t>;
template <typename T> template <typename T>
void concatenate(std::string& acc, T first) { void concatenate(std::string& acc, T first) {
acc += first; if constexpr (is_numeric_except_char<T>) {
acc += std::to_string(first);
} else {
acc += first;
}
} }
template <typename T, typename... Args> template <typename T, typename... Args>
void concatenate(std::string& acc, T first, Args... args) { void concatenate(std::string& acc, T first, Args... args) {
acc += first; if constexpr (is_numeric_except_char<T>) {
acc += std::to_string(first);
} else {
acc += first;
}
concatenate(acc, args...); concatenate(acc, args...);
} }
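// Illustrative usage (editorial, assuming the helpers above): numeric
// arguments no longer need manual std::to_string conversions, while chars and
// string literals are appended verbatim:
//
//   std::string name;
//   concatenate(name, "steel_gather_mm_", 't', 'n', "_bm", 64, "_bn", 64);
//   // name == "steel_gather_mm_tn_bm64_bn64"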

mlx/dtype_utils.cpp (new file, 20 lines)
View File

@ -0,0 +1,20 @@
// Copyright © 2025 Apple Inc.
#include "mlx/dtype_utils.h"
namespace mlx::core {
const char* dtype_to_string(Dtype arg) {
if (arg == bool_) {
return "bool";
}
#define SPECIALIZE_DtypeToString(CPP_TYPE, DTYPE) \
if (DTYPE == arg) { \
return #DTYPE; \
}
MLX_FORALL_DTYPES(SPECIALIZE_DtypeToString)
#undef SPECIALIZE_DtypeToString
return "(unknown)";
}
} // namespace mlx::core

mlx/dtype_utils.h (new file, 207 lines)
View File

@ -0,0 +1,207 @@
// Copyright © 2025 Apple Inc.
// Copyright © Meta Platforms, Inc. and affiliates.
//
// This source code is licensed under the BSD-style license found in
// https://github.com/pytorch/executorch/blob/main/LICENSE
//
// Forked from
// https://github.com/pytorch/executorch/blob/main/runtime/core/exec_aten/util/scalar_type_util.h
#pragma once
#include "mlx/dtype.h"
#include <fmt/format.h>
namespace mlx::core {
// Return string representation of dtype.
const char* dtype_to_string(Dtype arg);
// Macros that iterate across different subsets of Dtypes.
//
// For all of these macros, the final `_` parameter is the name of another macro
// that takes two parameters: the name of a C type, and the name of the
// corresponding Dtype enumerator.
//
// Note that these macros should use fully-qualified namespaces (starting with
// `::`) to ensure that they can be called safely in any arbitrary namespace.
#define MLX_FORALL_INT_TYPES(_) \
_(uint8_t, uint8) \
_(uint16_t, uint16) \
_(uint32_t, uint32) \
_(uint64_t, uint64) \
_(int8_t, int8) \
_(int16_t, int16) \
_(int32_t, int32) \
_(int64_t, int64)
#define MLX_FORALL_FLOAT_TYPES(_) \
_(float16_t, float16) \
_(float, float32) \
_(double, float64) \
_(bfloat16_t, bfloat16)
// Calls the provided macro on every Dtype, providing the C type and the
// Dtype name to each call.
//
// @param _ A macro that takes two parameters: the name of a C type, and the
// name of the corresponding Dtype enumerator.
#define MLX_FORALL_DTYPES(_) \
MLX_FORALL_INT_TYPES(_) \
MLX_FORALL_FLOAT_TYPES(_) \
_(bool, bool_) \
_(complex64_t, complex64)
// Maps Dtypes to C++ types.
template <Dtype::Val N>
struct DtypeToCppType;
#define SPECIALIZE_DtypeToCppType(CPP_TYPE, DTYPE) \
template <> \
struct DtypeToCppType<Dtype::Val::DTYPE> { \
using type = CPP_TYPE; \
};
MLX_FORALL_DTYPES(SPECIALIZE_DtypeToCppType)
#undef SPECIALIZE_DtypeToCppType
// Maps C++ types to Dtypes.
template <typename T>
struct CppTypeToDtype;
#define SPECIALIZE_CppTypeToDtype(CPP_TYPE, DTYPE) \
template <> \
struct CppTypeToDtype<CPP_TYPE> \
: std::integral_constant<Dtype::Val, Dtype::Val::DTYPE> {};
MLX_FORALL_DTYPES(SPECIALIZE_CppTypeToDtype)
#undef SPECIALIZE_CppTypeToDtype
// Helper macros for switch case macros (see below)
//
// These macros are not meant to be used directly. They provide an easy way to
// generate a switch statement that can handle subsets of Dtypes supported.
#define MLX_INTERNAL_SWITCH_CASE(enum_type, CTYPE_ALIAS, ...) \
case enum_type: { \
using CTYPE_ALIAS = ::mlx::core::DtypeToCppType<enum_type>::type; \
__VA_ARGS__; \
break; \
}
#define MLX_INTERNAL_SWITCH_CHECKED(TYPE, NAME, ...) \
switch (TYPE) { \
__VA_ARGS__ \
default: \
throw std::invalid_argument(fmt::format( \
"Unhandled dtype %s for %s", dtype_to_string(TYPE), NAME)); \
}
#define MLX_INTERNAL_SWITCH_CASE_INT_TYPES(CTYPE_ALIAS, ...) \
MLX_INTERNAL_SWITCH_CASE( \
::mlx::core::Dtype::Val::uint8, CTYPE_ALIAS, __VA_ARGS__) \
MLX_INTERNAL_SWITCH_CASE( \
::mlx::core::Dtype::Val::uint16, CTYPE_ALIAS, __VA_ARGS__) \
MLX_INTERNAL_SWITCH_CASE( \
::mlx::core::Dtype::Val::uint32, CTYPE_ALIAS, __VA_ARGS__) \
MLX_INTERNAL_SWITCH_CASE( \
::mlx::core::Dtype::Val::uint64, CTYPE_ALIAS, __VA_ARGS__) \
MLX_INTERNAL_SWITCH_CASE( \
::mlx::core::Dtype::Val::int8, CTYPE_ALIAS, __VA_ARGS__) \
MLX_INTERNAL_SWITCH_CASE( \
::mlx::core::Dtype::Val::int16, CTYPE_ALIAS, __VA_ARGS__) \
MLX_INTERNAL_SWITCH_CASE( \
::mlx::core::Dtype::Val::int32, CTYPE_ALIAS, __VA_ARGS__) \
MLX_INTERNAL_SWITCH_CASE( \
::mlx::core::Dtype::Val::int64, CTYPE_ALIAS, __VA_ARGS__)
#define MLX_INTERNAL_SWITCH_CASE_FLOAT_TYPES(CTYPE_ALIAS, ...) \
MLX_INTERNAL_SWITCH_CASE( \
::mlx::core::Dtype::Val::float16, CTYPE_ALIAS, __VA_ARGS__) \
MLX_INTERNAL_SWITCH_CASE( \
::mlx::core::Dtype::Val::float32, CTYPE_ALIAS, __VA_ARGS__) \
MLX_INTERNAL_SWITCH_CASE( \
::mlx::core::Dtype::Val::float64, CTYPE_ALIAS, __VA_ARGS__) \
MLX_INTERNAL_SWITCH_CASE( \
::mlx::core::Dtype::Val::bfloat16, CTYPE_ALIAS, __VA_ARGS__)
#define MLX_INTERNAL_SWITCH_CASE_INT_FLOAT_TYPES(CTYPE_ALIAS, ...) \
MLX_INTERNAL_SWITCH_CASE_INT_TYPES(CTYPE_ALIAS, __VA_ARGS__) \
MLX_INTERNAL_SWITCH_CASE_FLOAT_TYPES(CTYPE_ALIAS, __VA_ARGS__)
#define MLX_INTERNAL_SWITCH_CASE_REAL_TYPES(CTYPE_ALIAS, ...) \
MLX_INTERNAL_SWITCH_CASE_INT_FLOAT_TYPES(CTYPE_ALIAS, __VA_ARGS__) \
MLX_INTERNAL_SWITCH_CASE( \
::mlx::core::Dtype::Val::bool_, CTYPE_ALIAS, __VA_ARGS__)
#define MLX_INTERNAL_SWITCH_CASE_COMPLEX_TYPES(CTYPE_ALIAS, ...) \
MLX_INTERNAL_SWITCH_CASE( \
::mlx::core::Dtype::Val::complex64, CTYPE_ALIAS, __VA_ARGS__)
#define MLX_INTERNAL_SWITCH_CASE_ALL_TYPES(CTYPE_ALIAS, ...) \
MLX_INTERNAL_SWITCH_CASE_REAL_TYPES(CTYPE_ALIAS, __VA_ARGS__) \
MLX_INTERNAL_SWITCH_CASE_COMPLEX_TYPES(CTYPE_ALIAS, __VA_ARGS__)
// Switch case macros
//
// These macros provide an easy way to generate switch statements that apply a
// common lambda function to subsets of Dtypes supported by MLX.
// The lambda function can type specialize to the ctype associated with the
// Dtype being handled through an alias passed as the CTYPE_ALIAS argument.
//
// Arguments:
// - ADDITIONAL: Additional Dtype case to add
// - TYPE: The Dtype to handle through the switch statement
// - NAME: A name for this operation which will be used in error messages
// - CTYPE_ALIAS: A typedef for the ctype associated with the Dtype.
// - ...: A statement to be applied to each Dtype case
//
// An example usage is:
//
// MLX_SWITCH_ALL_TYPES(input.dtype(), CTYPE, {
// output.data<CTYPE>[0] = input.data<CTYPE>[0];
// });
//
// Note that these can be nested as well:
//
// MLX_SWITCH_ALL_TYPES(input.dtype(), CTYPE_IN, {
// MLX_SWITCH_ALL_TYPES(output.dtype(), CTYPE_OUT, {
// output.data<CTYPE_OUT>[0] = input.data<CTYPE_IN>[0];
// });
// });
//
// These macros are adapted from Dispatch.h in the ATen library. The primary
// difference is that the CTYPE_ALIAS argument is exposed to users, which is
// used to alias the ctype associated with the Dtype that is being handled.
#define MLX_SWITCH_ALL_TYPES(TYPE, CTYPE_ALIAS, ...) \
switch (TYPE) { MLX_INTERNAL_SWITCH_CASE_ALL_TYPES(CTYPE_ALIAS, __VA_ARGS__) }
#define MLX_SWITCH_INT_TYPES_CHECKED(TYPE, NAME, CTYPE_ALIAS, ...) \
MLX_INTERNAL_SWITCH_CHECKED( \
TYPE, \
NAME, \
MLX_INTERNAL_SWITCH_CASE_INT_TYPES(CTYPE_ALIAS, __VA_ARGS__))
#define MLX_SWITCH_FLOAT_TYPES_CHECKED(TYPE, NAME, CTYPE_ALIAS, ...) \
MLX_INTERNAL_SWITCH_CHECKED( \
TYPE, \
NAME, \
MLX_INTERNAL_SWITCH_CASE_FLOAT_TYPES(CTYPE_ALIAS, __VA_ARGS__))
#define MLX_SWITCH_INT_FLOAT_TYPES_CHECKED(TYPE, NAME, CTYPE_ALIAS, ...) \
MLX_INTERNAL_SWITCH_CHECKED( \
TYPE, \
NAME, \
MLX_INTERNAL_SWITCH_CASE_INT_FLOAT_TYPES(CTYPE_ALIAS, __VA_ARGS__))
#define MLX_SWITCH_REAL_TYPES_CHECKED(TYPE, NAME, CTYPE_ALIAS, ...) \
MLX_INTERNAL_SWITCH_CHECKED( \
TYPE, \
NAME, \
MLX_INTERNAL_SWITCH_CASE_REAL_TYPES(CTYPE_ALIAS, __VA_ARGS__))
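// Illustrative usage (editorial; dispatch_softmax is a hypothetical helper):
//
//   MLX_SWITCH_FLOAT_TYPES_CHECKED(out.dtype(), "softmax", CTYPE, {
//     dispatch_softmax<CTYPE>(in, out);
//   });
//
// Non-float dtypes fall into the default case and throw std::invalid_argument
// with the given operation name.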
} // namespace mlx::core

View File

@ -1,5 +1,6 @@
// Copyright © 2024 Apple Inc. // Copyright © 2024 Apple Inc.
#include "mlx/export.h" #include "mlx/export.h"
#include <map>
#include "mlx/compile_impl.h" #include "mlx/compile_impl.h"
#include "mlx/fast_primitives.h" #include "mlx/fast_primitives.h"
#include "mlx/primitives.h" #include "mlx/primitives.h"
@ -298,7 +299,13 @@ struct PrimitiveFactory {
SERIALIZE_PRIMITIVE(Reshape), SERIALIZE_PRIMITIVE(Reshape),
SERIALIZE_PRIMITIVE(Reduce, "And", "Or", "Sum", "Prod", "Min", "Max"), SERIALIZE_PRIMITIVE(Reduce, "And", "Or", "Sum", "Prod", "Min", "Max"),
SERIALIZE_PRIMITIVE(Round), SERIALIZE_PRIMITIVE(Round),
SERIALIZE_PRIMITIVE(Scan, "CumSum", "CumProd", "CumMin", "CumMax"), SERIALIZE_PRIMITIVE(
Scan,
"CumSum",
"CumProd",
"CumMin",
"CumMax",
"CumLogaddexp"),
SERIALIZE_PRIMITIVE(Scatter), SERIALIZE_PRIMITIVE(Scatter),
SERIALIZE_PRIMITIVE(Select), SERIALIZE_PRIMITIVE(Select),
SERIALIZE_PRIMITIVE(Sigmoid), SERIALIZE_PRIMITIVE(Sigmoid),
@ -475,7 +482,9 @@ bool FunctionTable::match(
return false; return false;
} }
} }
for (auto& [_, in] : kwargs) { auto sorted_kwargs =
std::map<std::string, array>(kwargs.begin(), kwargs.end());
for (auto& [_, in] : sorted_kwargs) {
if (!match_inputs(in, fun.inputs[i++])) { if (!match_inputs(in, fun.inputs[i++])) {
return false; return false;
} }
@ -551,7 +560,9 @@ void FunctionExporter::export_function(const Args& args, const Kwargs& kwargs) {
// Flatten the inputs to the function for tracing // Flatten the inputs to the function for tracing
std::vector<std::string> kwarg_keys; std::vector<std::string> kwarg_keys;
auto inputs = args; auto inputs = args;
for (auto& [k, v] : kwargs) { auto sorted_kwargs =
std::map<std::string, array>(kwargs.begin(), kwargs.end());
for (auto& [k, v] : sorted_kwargs) {
kwarg_keys.push_back(k); kwarg_keys.push_back(k);
inputs.push_back(v); inputs.push_back(v);
} }

View File

@ -2,14 +2,14 @@
#pragma once #pragma once
#include <map>
#include <set> #include <set>
#include <unordered_map>
#include "mlx/array.h" #include "mlx/array.h"
namespace mlx::core { namespace mlx::core {
using Args = std::vector<array>; using Args = std::vector<array>;
using Kwargs = std::map<std::string, array>; using Kwargs = std::unordered_map<std::string, array>;
struct FunctionExporter; struct FunctionExporter;

View File

@ -111,7 +111,7 @@ array fft_impl(
for (auto ax : axes) { for (auto ax : axes) {
n.push_back(a.shape(ax)); n.push_back(a.shape(ax));
} }
if (real && inverse) { if (real && inverse && a.ndim() > 0) {
n.back() = (n.back() - 1) * 2; n.back() = (n.back() - 1) * 2;
} }
return fft_impl(a, n, axes, real, inverse, s); return fft_impl(a, n, axes, real, inverse, s);
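// Editorial note: a real forward transform of length n keeps n / 2 + 1
// frequency bins, so the inverse reconstructs the time-domain length as
// (bins - 1) * 2; the added a.ndim() > 0 guard skips this for
// zero-dimensional inputs, where there is no axis to resize.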

View File

@ -3504,6 +3504,28 @@ array cummin(
{a}); {a});
} }
array logcumsumexp(
const array& a,
int axis,
bool reverse /* = false*/,
bool inclusive /* = true*/,
StreamOrDevice s /* = {}*/) {
int ndim = a.ndim();
if (axis >= ndim || axis < -ndim) {
std::ostringstream msg;
msg << "[logcumsumexp] Axis " << axis << " is out of bounds for array with "
<< a.ndim() << " dimensions.";
throw std::invalid_argument(msg.str());
}
axis = (axis + a.ndim()) % a.ndim();
return array(
a.shape(),
a.dtype(),
std::make_shared<Scan>(
to_stream(s), Scan::ReduceType::LogAddExp, axis, reverse, inclusive),
{a});
}
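// Editorial note: along the chosen axis this computes
//   y_i = log(exp(x_0) + ... + exp(x_i)),
// which a numerically stable scan can evaluate with the running logaddexp
// recurrence
//   y_i = max(y_{i-1}, x_i) + log1p(exp(-|y_{i-1} - x_i|))
// so that no intermediate exp() overflows.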
/** Convolution operations */ /** Convolution operations */
namespace { namespace {
@ -4006,6 +4028,7 @@ array gather_qmm(
bool transpose /* = true */, bool transpose /* = true */,
int group_size /* = 64 */, int group_size /* = 64 */,
int bits /* = 4 */, int bits /* = 4 */,
bool sorted_indices /* = false */,
StreamOrDevice s /* = {} */) { StreamOrDevice s /* = {} */) {
if (!lhs_indices_ && !rhs_indices_) { if (!lhs_indices_ && !rhs_indices_) {
return quantized_matmul( return quantized_matmul(
@ -4045,13 +4068,19 @@ array gather_qmm(
return array( return array(
std::move(out_shape), std::move(out_shape),
out_type, out_type,
std::make_shared<GatherQMM>(to_stream(s), group_size, bits, transpose), std::make_shared<GatherQMM>(
to_stream(s),
group_size,
bits,
transpose,
sorted_indices && !rhs_indices_,
sorted_indices && !lhs_indices_),
{astype(x, out_type, s), {astype(x, out_type, s),
w, std::move(w),
astype(scales, out_type, s), astype(scales, out_type, s),
astype(biases, out_type, s), astype(biases, out_type, s),
lhs_indices, std::move(lhs_indices),
rhs_indices}); std::move(rhs_indices)});
} }
array tensordot( array tensordot(
@ -4477,6 +4506,7 @@ array gather_mm(
array b, array b,
std::optional<array> lhs_indices_ /* = std::nullopt */, std::optional<array> lhs_indices_ /* = std::nullopt */,
std::optional<array> rhs_indices_ /* = std::nullopt */, std::optional<array> rhs_indices_ /* = std::nullopt */,
bool sorted_indices /* = false */,
StreamOrDevice s /* = {} */) { StreamOrDevice s /* = {} */) {
// If no indices, fall back to full matmul // If no indices, fall back to full matmul
if (!lhs_indices_ && !rhs_indices_) { if (!lhs_indices_ && !rhs_indices_) {
@ -4552,12 +4582,18 @@ array gather_mm(
out_shape.push_back(M); out_shape.push_back(M);
out_shape.push_back(N); out_shape.push_back(N);
// Caculate array // Make the output array
auto out = array( auto out = array(
std::move(out_shape), std::move(out_shape),
out_type, out_type,
std::make_shared<GatherMM>(to_stream(s)), std::make_shared<GatherMM>(
{a, b, lhs_indices, rhs_indices}); to_stream(s),
sorted_indices && !rhs_indices_,
sorted_indices && !lhs_indices_),
{std::move(a),
std::move(b),
std::move(lhs_indices),
std::move(rhs_indices)});
// Remove the possibly inserted singleton dimensions // Remove the possibly inserted singleton dimensions
std::vector<int> axes; std::vector<int> axes;
@ -4879,8 +4915,10 @@ array operator^(const array& a, const array& b) {
} }
array left_shift(const array& a, const array& b, StreamOrDevice s /* = {} */) { array left_shift(const array& a, const array& b, StreamOrDevice s /* = {} */) {
// Bit shift on bool always up-casts to uint8 auto t = result_type(a, b);
auto t = promote_types(result_type(a, b), uint8); if (t == bool_) {
t = uint8;
}
return bitwise_impl( return bitwise_impl(
astype(a, t, s), astype(a, t, s),
astype(b, t, s), astype(b, t, s),
@ -4893,8 +4931,10 @@ array operator<<(const array& a, const array& b) {
} }
array right_shift(const array& a, const array& b, StreamOrDevice s /* = {} */) { array right_shift(const array& a, const array& b, StreamOrDevice s /* = {} */) {
// Bit shift on bool always up-casts to uint8 auto t = result_type(a, b);
auto t = promote_types(result_type(a, b), uint8); if (t == bool_) {
t = uint8;
}
return bitwise_impl( return bitwise_impl(
astype(a, t, s), astype(a, t, s),
astype(b, t, s), astype(b, t, s),
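With this change the shift ops keep the common input dtype and only special-case
``bool``, which has no shift of its own. A small sketch of the intended dtype
behavior (it mirrors the new C++ test at the end of this diff):

import mlx.core as mx

a = mx.array([1, 1, 1, 1], dtype=mx.int8)
b = mx.array([2, 2, 2, 2], dtype=mx.int8)
print(mx.left_shift(a, b).dtype)   # int8: the common input dtype is preserved

t = mx.array([True, True])
print(mx.left_shift(t, t).dtype)   # uint8: bool is promoted before shifting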

View File

@ -715,6 +715,14 @@ array topk(const array& a, int k, StreamOrDevice s = {});
/** Returns topk elements of the array along a given axis. */ /** Returns topk elements of the array along a given axis. */
array topk(const array& a, int k, int axis, StreamOrDevice s = {}); array topk(const array& a, int k, int axis, StreamOrDevice s = {});
/** Cumulative logsumexp of an array. */
array logcumsumexp(
const array& a,
int axis,
bool reverse = false,
bool inclusive = true,
StreamOrDevice s = {});
/** The logsumexp of all elements of the array. */ /** The logsumexp of all elements of the array. */
array logsumexp(const array& a, bool keepdims, StreamOrDevice s = {}); array logsumexp(const array& a, bool keepdims, StreamOrDevice s = {});
inline array logsumexp(const array& a, StreamOrDevice s = {}) { inline array logsumexp(const array& a, StreamOrDevice s = {}) {
@ -1344,6 +1352,7 @@ array gather_qmm(
bool transpose = true, bool transpose = true,
int group_size = 64, int group_size = 64,
int bits = 4, int bits = 4,
bool sorted_indices = false,
StreamOrDevice s = {}); StreamOrDevice s = {});
/** Returns a contraction of a and b over multiple dimensions. */ /** Returns a contraction of a and b over multiple dimensions. */
@ -1391,6 +1400,7 @@ array gather_mm(
array b, array b,
std::optional<array> lhs_indices = std::nullopt, std::optional<array> lhs_indices = std::nullopt,
std::optional<array> rhs_indices = std::nullopt, std::optional<array> rhs_indices = std::nullopt,
bool sorted_indices = false,
StreamOrDevice s = {}); StreamOrDevice s = {});
/** Extract a diagonal or construct a diagonal array */ /** Extract a diagonal or construct a diagonal array */

View File

@ -1275,6 +1275,61 @@ std::vector<array> Convolution::vjp(
return grads; return grads;
} }
std::pair<std::vector<array>, std::vector<int>> Convolution::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
auto do_conv = [&](const array& in, const array& w, int groups) {
return conv_general(
in,
w,
kernel_strides_,
padding_,
kernel_dilation_,
input_dilation_,
groups,
flip_,
stream());
};
bool in_vmap = axes[0] >= 0;
bool w_vmap = axes[1] >= 0;
auto in = inputs[0];
auto w = inputs[1];
if (in_vmap && !w_vmap) {
// flatten / unflatten the batch dimension
// of the input / output
if (axes[0] > 0) {
in = moveaxis(in, axes[0], 0, stream());
}
auto out = do_conv(flatten(in, 0, 1, stream()), w, groups_);
out = unflatten(out, 0, {in.shape(0), in.shape(1)}, stream());
return {{out}, {0}};
} else if (!in_vmap && w_vmap) {
// flatten into the output channels of w
// unflatten the channels of the output
if (axes[1] > 0) {
w = moveaxis(w, axes[1], 0, stream());
}
auto out = do_conv(in, flatten(w, 0, 1, stream()), groups_);
out = unflatten(out, -1, {w.shape(0), w.shape(1)}, stream());
return {{out}, {static_cast<int>(out.ndim() - 2)}};
} else if (in_vmap && w_vmap) {
// use a group convolution when both inputs are vmapped
auto b = in.shape(axes[0]);
in = moveaxis(in, axes[0], -2, stream());
in = flatten(in, -2, -1, stream());
if (axes[1] > 0) {
w = moveaxis(w, axes[1], 0, stream());
}
auto c_out = w.shape(1);
w = flatten(w, 0, 1, stream());
auto out = do_conv(in, w, groups_ * b);
out = unflatten(out, -1, {b, c_out}, stream());
return {{out}, {static_cast<int>(out.ndim() - 2)}};
} else {
return {{do_conv(in, w, groups_)}, {-1}};
}
}
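The vmap rule covers three cases: a vmapped input folds into the batch dimension,
vmapped weights fold into the output channels, and when both are vmapped the batched
problem becomes a single grouped convolution. A short sketch of the equivalence the
rule provides, written against ``mx.conv1d`` and mirroring the new vmap test below:

import mlx.core as mx

x = mx.random.uniform(shape=(3, 2, 5, 4))   # (vmap, N, L, C_in)
w = mx.random.uniform(shape=(3, 8, 3, 4))   # (vmap, C_out, K, C_in)

expected = mx.stack([mx.conv1d(xi, wi) for xi, wi in zip(x, w)])
out = mx.vmap(mx.conv1d, in_axes=(0, 0))(x, w)
print(mx.allclose(expected, out))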
bool Convolution::is_equivalent(const Primitive& other) const { bool Convolution::is_equivalent(const Primitive& other) const {
const Convolution& c_other = static_cast<const Convolution&>(other); const Convolution& c_other = static_cast<const Convolution&>(other);
return padding_ == c_other.padding_ && return padding_ == c_other.padding_ &&
@ -3080,6 +3135,8 @@ std::vector<array> GatherQMM::vjp(
auto& lhs_indices = primals[4]; auto& lhs_indices = primals[4];
auto& rhs_indices = primals[5]; auto& rhs_indices = primals[5];
bool sorted = left_sorted_ || right_sorted_;
for (auto arg : argnums) { for (auto arg : argnums) {
// gradient wrt to x // gradient wrt to x
if (arg == 0) { if (arg == 0) {
@ -3098,6 +3155,7 @@ std::vector<array> GatherQMM::vjp(
!transpose_, !transpose_,
group_size_, group_size_,
bits_, bits_,
sorted,
stream()), stream()),
-3, -3,
stream()), stream()),
@ -3478,6 +3536,45 @@ std::vector<array> Scan::vjp(
if (reduce_type_ == Scan::Sum) { if (reduce_type_ == Scan::Sum) {
return {cumsum(cotangents[0], axis_, !reverse_, inclusive_, stream())}; return {cumsum(cotangents[0], axis_, !reverse_, inclusive_, stream())};
} else if (reduce_type_ == Scan::LogAddExp) {
// Ref:
// https://github.com/tensorflow/tensorflow/blob/2a5910906a0e0f3dbc186ff9db6386d81a63448c/tensorflow/python/ops/math_grad.py#L1832-L1863
auto x = primals[0];
auto grad = cotangents[0];
auto results = outputs[0];
auto zero = zeros({1}, grad.dtype(), stream());
auto grad_min = array(finfo(grad.dtype()).min, grad.dtype());
// Split the incoming gradient into positive and negative parts
// in order to take logs. This is required for stable results.
auto log_abs_grad = log(abs(grad, stream()), stream());
auto log_grad_positive =
where(greater(grad, zero, stream()), log_abs_grad, grad_min, stream());
auto log_grad_negative =
where(less(grad, zero, stream()), log_abs_grad, grad_min, stream());
auto output_pos = exp(
add(logcumsumexp(
subtract(log_grad_positive, results, stream()),
axis_,
!reverse_,
inclusive_,
stream()),
x,
stream()));
auto output_neg = exp(
add(logcumsumexp(
subtract(log_grad_negative, results, stream()),
axis_,
!reverse_,
inclusive_,
stream()),
x,
stream()));
return {subtract(output_pos, output_neg, stream())};
} else if (reduce_type_ == Scan::Prod) { } else if (reduce_type_ == Scan::Prod) {
auto in = primals[0]; auto in = primals[0];
// Find the location of the first 0 and set it to 1: // Find the location of the first 0 and set it to 1:
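For y = logcumsumexp(x) the adjoint implemented in the LogAddExp branch above is
(vjp(g))_j = sum_{i >= j} g_i * exp(x_j - y_i) = exp(x_j + logcumsumexp(log(g) - y, reverse)),
and since g may be negative it is split into positive and negative parts before
taking logs, with the two pieces subtracted at the end. A small numerical sanity
check of that formula for unit upstream gradients (a sketch; ``mx.tri`` is only
used here to build the i >= j mask):

import mlx.core as mx

x = mx.random.normal((8,))

def f(x):
    return mx.logcumsumexp(x, axis=0).sum()

g = mx.grad(f)(x)

y = mx.logcumsumexp(x, axis=0)
weights = mx.exp(x[:, None] - y[None, :])    # [j, i] = exp(x_j - y_i)
mask = mx.tri(8, 8).T                        # 1 where i >= j
print(mx.allclose(g, (weights * mask).sum(axis=1), atol=1e-4))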
@ -4856,6 +4953,8 @@ std::vector<array> GatherMM::vjp(
int N = cotan.shape(-1); int N = cotan.shape(-1);
int K = primals[0].shape(-1); int K = primals[0].shape(-1);
bool sorted = left_sorted_ || right_sorted_;
for (auto arg : argnums) { for (auto arg : argnums) {
if (arg == 0) { if (arg == 0) {
// M X N * (K X N).T -> M X K // M X N * (K X N).T -> M X K
@ -4866,7 +4965,8 @@ std::vector<array> GatherMM::vjp(
base = reshape(base, {-1, M, K}, stream()); base = reshape(base, {-1, M, K}, stream());
// g : (out_batch_shape) + (M, K) // g : (out_batch_shape) + (M, K)
auto g = gather_mm(cotan, bt, std::nullopt, rhs_indices, stream()); auto g =
gather_mm(cotan, bt, std::nullopt, rhs_indices, sorted, stream());
g = expand_dims(g, -3, stream()); g = expand_dims(g, -3, stream());
auto gacc = scatter_add(base, lhs_indices, g, 0, stream()); auto gacc = scatter_add(base, lhs_indices, g, 0, stream());
@ -4881,7 +4981,8 @@ std::vector<array> GatherMM::vjp(
base = reshape(base, {-1, K, N}, stream()); base = reshape(base, {-1, K, N}, stream());
// g : (out_batch_shape) + (K, N) // g : (out_batch_shape) + (K, N)
auto g = gather_mm(at, cotan, lhs_indices, std::nullopt, stream()); auto g =
gather_mm(at, cotan, lhs_indices, std::nullopt, sorted, stream());
g = expand_dims(g, -3, stream()); g = expand_dims(g, -3, stream());
auto gacc = scatter_add(base, rhs_indices, g, 0, stream()); auto gacc = scatter_add(base, rhs_indices, g, 0, stream());
@ -4894,6 +4995,12 @@ std::vector<array> GatherMM::vjp(
return vjps; return vjps;
} }
bool GatherMM::is_equivalent(const Primitive& other) const {
const GatherMM& g_other = static_cast<const GatherMM&>(other);
return left_sorted_ == g_other.left_sorted_ &&
right_sorted_ == g_other.right_sorted_;
}
bool BlockMaskedMM::is_equivalent(const Primitive& other) const { bool BlockMaskedMM::is_equivalent(const Primitive& other) const {
const BlockMaskedMM& a_other = static_cast<const BlockMaskedMM&>(other); const BlockMaskedMM& a_other = static_cast<const BlockMaskedMM&>(other);
return (block_size_ == a_other.block_size_); return (block_size_ == a_other.block_size_);

View File

@ -498,7 +498,13 @@ class BlockMaskedMM : public UnaryPrimitive {
class GatherMM : public UnaryPrimitive { class GatherMM : public UnaryPrimitive {
public: public:
explicit GatherMM(Stream stream) : UnaryPrimitive(stream) {} explicit GatherMM(
Stream stream,
bool left_sorted = false,
bool right_sorted = false)
: UnaryPrimitive(stream),
left_sorted_(left_sorted),
right_sorted_(right_sorted) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override; void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override; void eval_gpu(const std::vector<array>& inputs, array& out) override;
@ -510,7 +516,14 @@ class GatherMM : public UnaryPrimitive {
const std::vector<array>& outputs) override; const std::vector<array>& outputs) override;
DEFINE_PRINT(GatherMM) DEFINE_PRINT(GatherMM)
DEFINE_DEFAULT_IS_EQUIVALENT() bool is_equivalent(const Primitive& other) const override;
auto state() const {
return std::make_pair(left_sorted_, right_sorted_);
}
private:
bool left_sorted_;
bool right_sorted_;
}; };
class BroadcastAxes : public UnaryPrimitive { class BroadcastAxes : public UnaryPrimitive {
@ -698,6 +711,7 @@ class Convolution : public UnaryPrimitive {
const std::vector<int>& argnums, const std::vector<int>& argnums,
const std::vector<array>& outputs) override; const std::vector<array>& outputs) override;
DEFINE_VMAP()
DEFINE_PRINT(Convolution) DEFINE_PRINT(Convolution)
bool is_equivalent(const Primitive& other) const override; bool is_equivalent(const Primitive& other) const override;
auto state() const { auto state() const {
@ -1578,11 +1592,19 @@ class QuantizedMatmul : public UnaryPrimitive {
class GatherQMM : public UnaryPrimitive { class GatherQMM : public UnaryPrimitive {
public: public:
explicit GatherQMM(Stream stream, int group_size, int bits, bool transpose) explicit GatherQMM(
Stream stream,
int group_size,
int bits,
bool transpose,
bool left_sorted = false,
bool right_sorted = false)
: UnaryPrimitive(stream), : UnaryPrimitive(stream),
group_size_(group_size), group_size_(group_size),
bits_(bits), bits_(bits),
transpose_(transpose) {} transpose_(transpose),
left_sorted_(left_sorted),
right_sorted_(right_sorted) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override; void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override; void eval_gpu(const std::vector<array>& inputs, array& out) override;
@ -1592,13 +1614,16 @@ class GatherQMM : public UnaryPrimitive {
DEFINE_PRINT(GatherQMM) DEFINE_PRINT(GatherQMM)
bool is_equivalent(const Primitive& other) const override; bool is_equivalent(const Primitive& other) const override;
auto state() const { auto state() const {
return std::make_tuple(group_size_, bits_, transpose_); return std::make_tuple(
group_size_, bits_, transpose_, left_sorted_, right_sorted_);
} }
private: private:
int group_size_; int group_size_;
int bits_; int bits_;
bool transpose_; bool transpose_;
bool left_sorted_;
bool right_sorted_;
}; };
class RandomBits : public UnaryPrimitive { class RandomBits : public UnaryPrimitive {
@ -1728,7 +1753,7 @@ class Round : public UnaryPrimitive {
class Scan : public UnaryPrimitive { class Scan : public UnaryPrimitive {
public: public:
enum ReduceType { Max, Min, Sum, Prod }; enum ReduceType { Max, Min, Sum, Prod, LogAddExp };
explicit Scan( explicit Scan(
Stream stream, Stream stream,
@ -1763,6 +1788,9 @@ class Scan : public UnaryPrimitive {
case Max: case Max:
os << "Max"; os << "Max";
break; break;
case LogAddExp:
os << "Logaddexp";
break;
} }
} }
bool is_equivalent(const Primitive& other) const override; bool is_equivalent(const Primitive& other) const override;

View File

@ -5,6 +5,7 @@
#include <sstream> #include <sstream>
#include <vector> #include <vector>
#include "mlx/dtype_utils.h"
#include "mlx/types/limits.h" #include "mlx/types/limits.h"
#include "mlx/utils.h" #include "mlx/utils.h"
@ -224,37 +225,7 @@ void print_array(std::ostream& os, const array& a) {
} // namespace } // namespace
std::ostream& operator<<(std::ostream& os, const Dtype& dtype) { std::ostream& operator<<(std::ostream& os, const Dtype& dtype) {
switch (dtype) { return os << dtype_to_string(dtype);
case bool_:
return os << "bool";
case uint8:
return os << "uint8";
case uint16:
return os << "uint16";
case uint32:
return os << "uint32";
case uint64:
return os << "uint64";
case int8:
return os << "int8";
case int16:
return os << "int16";
case int32:
return os << "int32";
case int64:
return os << "int64";
case float16:
return os << "float16";
case float32:
return os << "float32";
case float64:
return os << "float64";
case bfloat16:
return os << "bfloat16";
case complex64:
return os << "complex64";
}
return os;
} }
std::ostream& operator<<(std::ostream& os, const Dtype::Kind& k) { std::ostream& operator<<(std::ostream& os, const Dtype::Kind& k) {
@ -277,50 +248,7 @@ std::ostream& operator<<(std::ostream& os, const Dtype::Kind& k) {
std::ostream& operator<<(std::ostream& os, array a) { std::ostream& operator<<(std::ostream& os, array a) {
a.eval(); a.eval();
switch (a.dtype()) { MLX_SWITCH_ALL_TYPES(a.dtype(), CTYPE, print_array<CTYPE>(os, a));
case bool_:
print_array<bool>(os, a);
break;
case uint8:
print_array<uint8_t>(os, a);
break;
case uint16:
print_array<uint16_t>(os, a);
break;
case uint32:
print_array<uint32_t>(os, a);
break;
case uint64:
print_array<uint64_t>(os, a);
break;
case int8:
print_array<int8_t>(os, a);
break;
case int16:
print_array<int16_t>(os, a);
break;
case int32:
print_array<int32_t>(os, a);
break;
case int64:
print_array<int64_t>(os, a);
break;
case float16:
print_array<float16_t>(os, a);
break;
case bfloat16:
print_array<bfloat16_t>(os, a);
break;
case float32:
print_array<float>(os, a);
break;
case float64:
print_array<double>(os, a);
break;
case complex64:
print_array<complex64_t>(os, a);
break;
}
return os; return os;
} }
@ -387,36 +315,8 @@ void set_iinfo_limits(int64_t& min, uint64_t& max) {
} }
iinfo::iinfo(Dtype dtype) : dtype(dtype) { iinfo::iinfo(Dtype dtype) : dtype(dtype) {
switch (dtype) { MLX_SWITCH_INT_TYPES_CHECKED(
case int8: dtype, "[iinfo]", CTYPE, set_iinfo_limits<CTYPE>(min, max));
set_iinfo_limits<int8_t>(min, max);
break;
case uint8:
set_iinfo_limits<uint8_t>(min, max);
break;
case int16:
set_iinfo_limits<int16_t>(min, max);
break;
case uint16:
set_iinfo_limits<uint16_t>(min, max);
break;
case int32:
set_iinfo_limits<int32_t>(min, max);
break;
case uint32:
set_iinfo_limits<uint32_t>(min, max);
break;
case int64:
set_iinfo_limits<int64_t>(min, max);
break;
case uint64:
set_iinfo_limits<uint64_t>(min, max);
break;
default:
std::ostringstream msg;
msg << "[iinfo] dtype " << dtype << " is not integral.";
throw std::invalid_argument(msg.str());
}
} }
} // namespace mlx::core } // namespace mlx::core

View File

@ -3,8 +3,8 @@
#pragma once #pragma once
#define MLX_VERSION_MAJOR 0 #define MLX_VERSION_MAJOR 0
#define MLX_VERSION_MINOR 24 #define MLX_VERSION_MINOR 25
#define MLX_VERSION_PATCH 2 #define MLX_VERSION_PATCH 0
#define MLX_VERSION_NUMERIC \ #define MLX_VERSION_NUMERIC \
(100000 * MLX_VERSION_MAJOR + 1000 * MLX_VERSION_MINOR + MLX_VERSION_PATCH) (100000 * MLX_VERSION_MAJOR + 1000 * MLX_VERSION_MINOR + MLX_VERSION_PATCH)

View File

@ -1202,6 +1202,28 @@ void init_array(nb::module_& m) {
nb::kw_only(), nb::kw_only(),
"stream"_a = nb::none(), "stream"_a = nb::none(),
"See :func:`max`.") "See :func:`max`.")
.def(
"logcumsumexp",
[](const mx::array& a,
std::optional<int> axis,
bool reverse,
bool inclusive,
mx::StreamOrDevice s) {
if (axis) {
return mx::logcumsumexp(a, *axis, reverse, inclusive, s);
} else {
// TODO: Implement that in the C++ API as well. See concatenate
// above.
return mx::logcumsumexp(
mx::reshape(a, {-1}, s), 0, reverse, inclusive, s);
}
},
"axis"_a = nb::none(),
nb::kw_only(),
"reverse"_a = false,
"inclusive"_a = true,
"stream"_a = nb::none(),
"See :func:`logcumsumexp`.")
.def( .def(
"logsumexp", "logsumexp",
[](const mx::array& a, [](const mx::array& a,

View File

@ -1,8 +1,8 @@
// Copyright © 2024 Apple Inc. // Copyright © 2024 Apple Inc.
#include <nanobind/nanobind.h> #include <nanobind/nanobind.h>
#include <nanobind/stl/map.h>
#include <nanobind/stl/optional.h> #include <nanobind/stl/optional.h>
#include <nanobind/stl/string.h> #include <nanobind/stl/string.h>
#include <nanobind/stl/unordered_map.h>
#include <nanobind/stl/vector.h> #include <nanobind/stl/vector.h>
#include <fstream> #include <fstream>
@ -16,8 +16,7 @@ namespace mx = mlx::core;
namespace nb = nanobind; namespace nb = nanobind;
using namespace nb::literals; using namespace nb::literals;
std::pair<std::vector<mx::array>, std::map<std::string, mx::array>> std::pair<mx::Args, mx::Kwargs> validate_and_extract_inputs(
validate_and_extract_inputs(
const nb::args& args, const nb::args& args,
const nb::kwargs& kwargs, const nb::kwargs& kwargs,
const std::string& prefix) { const std::string& prefix) {
@ -30,8 +29,8 @@ validate_and_extract_inputs(
"and/or dictionary of arrays."); "and/or dictionary of arrays.");
} }
}; };
std::vector<mx::array> args_; mx::Args args_;
std::map<std::string, mx::array> kwargs_; mx::Kwargs kwargs_;
if (args.size() == 0) { if (args.size() == 0) {
// No args so kwargs must be keyword arrays // No args so kwargs must be keyword arrays
maybe_throw(nb::try_cast(kwargs, kwargs_)); maybe_throw(nb::try_cast(kwargs, kwargs_));
@ -81,9 +80,7 @@ class PyFunctionExporter {
void close() { void close() {
exporter_.close(); exporter_.close();
} }
void operator()( void operator()(const mx::Args& args, const mx::Kwargs& kwargs) {
const std::vector<mx::array>& args,
const std::map<std::string, mx::array>& kwargs) {
exporter_(args, kwargs); exporter_(args, kwargs);
} }
@ -98,9 +95,12 @@ int py_function_exporter_tp_traverse(
PyObject* self, PyObject* self,
visitproc visit, visitproc visit,
void* arg) { void* arg) {
Py_VISIT(Py_TYPE(self));
if (!nb::inst_ready(self)) {
return 0;
}
auto* p = nb::inst_ptr<PyFunctionExporter>(self); auto* p = nb::inst_ptr<PyFunctionExporter>(self);
Py_VISIT(p->dep_.ptr()); Py_VISIT(p->dep_.ptr());
Py_VISIT(Py_TYPE(self));
return 0; return 0;
} }
@ -109,23 +109,22 @@ PyType_Slot py_function_exporter_slots[] = {
{0, 0}}; {0, 0}};
auto wrap_export_function(nb::callable fun) { auto wrap_export_function(nb::callable fun) {
return [fun = std::move(fun)]( return
const std::vector<mx::array>& args_, [fun = std::move(fun)](const mx::Args& args_, const mx::Kwargs& kwargs_) {
const std::map<std::string, mx::array>& kwargs_) { auto kwargs = nb::dict();
auto kwargs = nb::dict(); kwargs.update(nb::cast(kwargs_));
kwargs.update(nb::cast(kwargs_)); auto args = nb::tuple(nb::cast(args_));
auto args = nb::tuple(nb::cast(args_)); auto outputs = fun(*args, **kwargs);
auto outputs = fun(*args, **kwargs); std::vector<mx::array> outputs_;
std::vector<mx::array> outputs_; if (nb::isinstance<mx::array>(outputs)) {
if (nb::isinstance<mx::array>(outputs)) { outputs_.push_back(nb::cast<mx::array>(outputs));
outputs_.push_back(nb::cast<mx::array>(outputs)); } else if (!nb::try_cast(outputs, outputs_)) {
} else if (!nb::try_cast(outputs, outputs_)) { throw std::invalid_argument(
throw std::invalid_argument( "[export_function] Outputs can be either a single array "
"[export_function] Outputs can be either a single array " "a tuple or list of arrays.");
"a tuple or list of arrays."); }
} return outputs_;
return outputs_; };
};
} }
void init_export(nb::module_& m) { void init_export(nb::module_& m) {

View File

@ -16,12 +16,12 @@ struct gc_func {
}; };
int gc_func_tp_traverse(PyObject* self, visitproc visit, void* arg) { int gc_func_tp_traverse(PyObject* self, visitproc visit, void* arg) {
Py_VISIT(Py_TYPE(self));
gc_func* w = (gc_func*)self; gc_func* w = (gc_func*)self;
Py_VISIT(w->func); Py_VISIT(w->func);
for (auto d : w->deps) { for (auto d : w->deps) {
Py_VISIT(d); Py_VISIT(d);
} }
Py_VISIT(Py_TYPE(self));
return 0; return 0;
}; };

View File

@ -2382,6 +2382,43 @@ void init_ops(nb::module_& m) {
Returns: Returns:
array: The output array with the corresponding axes reduced. array: The output array with the corresponding axes reduced.
)pbdoc"); )pbdoc");
m.def(
"logcumsumexp",
[](const mx::array& a,
std::optional<int> axis,
bool reverse,
bool inclusive,
mx::StreamOrDevice s) {
if (axis) {
return mx::logcumsumexp(a, *axis, reverse, inclusive, s);
} else {
return mx::logcumsumexp(
mx::reshape(a, {-1}, s), 0, reverse, inclusive, s);
}
},
nb::arg(),
"axis"_a = nb::none(),
nb::kw_only(),
"reverse"_a = false,
"inclusive"_a = true,
"stream"_a = nb::none(),
nb::sig(
"def logcumsumexp(a: array, /, axis: Optional[int] = None, *, reverse: bool = False, inclusive: bool = True, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Return the cumulative logsumexp of the elements along the given axis.
Args:
a (array): Input array
axis (int, optional): The axis to compute the cumulative logsumexp
over. If unspecified, the cumulative logsumexp of the flattened array
is returned.
reverse (bool): Perform the cumulative logsumexp in reverse.
inclusive (bool): The i-th element of the output includes the i-th
element of the input.
Returns:
array: The output array.
)pbdoc");
m.def( m.def(
"logsumexp", "logsumexp",
[](const mx::array& a, [](const mx::array& a,
@ -4213,9 +4250,10 @@ void init_ops(nb::module_& m) {
"group_size"_a = 64, "group_size"_a = 64,
"bits"_a = 4, "bits"_a = 4,
nb::kw_only(), nb::kw_only(),
"sorted_indices"_a = false,
"stream"_a = nb::none(), "stream"_a = nb::none(),
nb::sig( nb::sig(
"def gather_qmm(x: array, w: array, /, scales: array, biases: array, lhs_indices: Optional[array] = None, rhs_indices: Optional[array] = None, transpose: bool = True, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> array"), "def gather_qmm(x: array, w: array, /, scales: array, biases: array, lhs_indices: Optional[array] = None, rhs_indices: Optional[array] = None, transpose: bool = True, group_size: int = 64, bits: int = 4, *, sorted_indices: bool = False, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc( R"pbdoc(
Perform quantized matrix multiplication with matrix-level gather. Perform quantized matrix multiplication with matrix-level gather.
@ -4228,23 +4266,25 @@ void init_ops(nb::module_& m) {
as ``w`` since they represent the same quantized matrix. as ``w`` since they represent the same quantized matrix.
Args: Args:
x (array): Input array x (array): Input array
w (array): Quantized matrix packed in unsigned integers w (array): Quantized matrix packed in unsigned integers
scales (array): The scales to use per ``group_size`` elements of ``w`` scales (array): The scales to use per ``group_size`` elements of ``w``
biases (array): The biases to use per ``group_size`` elements of ``w`` biases (array): The biases to use per ``group_size`` elements of ``w``
lhs_indices (array, optional): Integer indices for ``x``. Default: ``None``. lhs_indices (array, optional): Integer indices for ``x``. Default: ``None``.
rhs_indices (array, optional): Integer indices for ``w``. Default: ``None``. rhs_indices (array, optional): Integer indices for ``w``. Default: ``None``.
transpose (bool, optional): Defines whether to multiply with the transpose (bool, optional): Defines whether to multiply with the
transposed ``w`` or not, namely whether we are performing transposed ``w`` or not, namely whether we are performing
``x @ w.T`` or ``x @ w``. Default: ``True``. ``x @ w.T`` or ``x @ w``. Default: ``True``.
group_size (int, optional): The size of the group in ``w`` that group_size (int, optional): The size of the group in ``w`` that
shares a scale and bias. Default: ``64``. shares a scale and bias. Default: ``64``.
bits (int, optional): The number of bits occupied by each element in bits (int, optional): The number of bits occupied by each element in
``w``. Default: ``4``. ``w``. Default: ``4``.
sorted_indices (bool, optional): May allow a faster implementation
if the passed indices are sorted. Default: ``False``.
Returns: Returns:
array: The result of the multiplication of ``x`` with ``w`` array: The result of the multiplication of ``x`` with ``w``
after gathering using ``lhs_indices`` and ``rhs_indices``. after gathering using ``lhs_indices`` and ``rhs_indices``.
)pbdoc"); )pbdoc");
m.def( m.def(
"tensordot", "tensordot",
@ -4274,16 +4314,16 @@ void init_ops(nb::module_& m) {
Compute the tensor dot product along the specified axes. Compute the tensor dot product along the specified axes.
Args: Args:
a (array): Input array a (array): Input array
b (array): Input array b (array): Input array
axes (int or list(list(int)), optional): The number of dimensions to axes (int or list(list(int)), optional): The number of dimensions to
sum over. If an integer is provided, then sum over the last sum over. If an integer is provided, then sum over the last
``axes`` dimensions of ``a`` and the first ``axes`` dimensions of ``axes`` dimensions of ``a`` and the first ``axes`` dimensions of
``b``. If a list of lists is provided, then sum over the ``b``. If a list of lists is provided, then sum over the
corresponding dimensions of ``a`` and ``b``. Default: 2. corresponding dimensions of ``a`` and ``b``. Default: 2.
Returns: Returns:
array: The tensor dot product. array: The tensor dot product.
)pbdoc"); )pbdoc");
m.def( m.def(
"inner", "inner",
@ -4427,9 +4467,10 @@ void init_ops(nb::module_& m) {
"lhs_indices"_a = nb::none(), "lhs_indices"_a = nb::none(),
"rhs_indices"_a = nb::none(), "rhs_indices"_a = nb::none(),
nb::kw_only(), nb::kw_only(),
"sorted_indices"_a = false,
"stream"_a = nb::none(), "stream"_a = nb::none(),
nb::sig( nb::sig(
"def gather_mm(a: array, b: array, /, lhs_indices: array, rhs_indices: array, *, stream: Union[None, Stream, Device] = None) -> array"), "def gather_mm(a: array, b: array, /, lhs_indices: array, rhs_indices: array, *, sorted_indices: bool = False, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc( R"pbdoc(
Matrix multiplication with matrix-level gather. Matrix multiplication with matrix-level gather.
@ -4448,11 +4489,16 @@ void init_ops(nb::module_& m) {
For ``b`` with shape ``(B1, B2, ..., BS, M, K)``, ``rhs_indices`` For ``b`` with shape ``(B1, B2, ..., BS, M, K)``, ``rhs_indices``
contains indices from the range ``[0, B1 * B2 * ... * BS)`` contains indices from the range ``[0, B1 * B2 * ... * BS)``
If only one of ``lhs_indices`` or ``rhs_indices`` is passed and it is
sorted, the ``sorted_indices`` flag can be set for a possibly faster
implementation.
Args: Args:
a (array): Input array. a (array): Input array.
b (array): Input array. b (array): Input array.
lhs_indices (array, optional): Integer indices for ``a``. Default: ``None`` lhs_indices (array, optional): Integer indices for ``a``. Default: ``None``
rhs_indices (array, optional): Integer indices for ``b``. Default: ``None`` rhs_indices (array, optional): Integer indices for ``b``. Default: ``None``
sorted_indices (bool, optional): May allow a faster implementation
if the passed indices are sorted. Default: ``False``.
Returns: Returns:
array: The output array. array: The output array.

View File

@ -960,6 +960,11 @@ class PyCustomFunction {
}; };
int py_custom_function_tp_traverse(PyObject* self, visitproc visit, void* arg) { int py_custom_function_tp_traverse(PyObject* self, visitproc visit, void* arg) {
Py_VISIT(Py_TYPE(self));
if (!nb::inst_ready(self)) {
return 0;
}
auto* p = nb::inst_ptr<PyCustomFunction>(self); auto* p = nb::inst_ptr<PyCustomFunction>(self);
nb::handle v = nb::find(p->fun_); nb::handle v = nb::find(p->fun_);
Py_VISIT(v.ptr()); Py_VISIT(v.ptr());
@ -975,7 +980,6 @@ int py_custom_function_tp_traverse(PyObject* self, visitproc visit, void* arg) {
nb::handle v = nb::find(*(p->vmap_fun_)); nb::handle v = nb::find(*(p->vmap_fun_));
Py_VISIT(v.ptr()); Py_VISIT(v.ptr());
} }
Py_VISIT(Py_TYPE(self));
return 0; return 0;
} }
int py_custom_function_tp_clear(PyObject* self) { int py_custom_function_tp_clear(PyObject* self) {

View File

@ -1508,6 +1508,7 @@ class TestArray(mlx_tests.MLXTestCase):
("prod", 1), ("prod", 1),
("min", 1), ("min", 1),
("max", 1), ("max", 1),
("logcumsumexp", 1),
("logsumexp", 1), ("logsumexp", 1),
("mean", 1), ("mean", 1),
("var", 1), ("var", 1),

View File

@ -1108,7 +1108,7 @@ class TestBlas(mlx_tests.MLXTestCase):
lhs_indices_ = mx.broadcast_to(lhs_indices, (3, 2)) lhs_indices_ = mx.broadcast_to(lhs_indices, (3, 2))
rhs_indices_ = mx.broadcast_to(rhs_indices, (3, 2)) rhs_indices_ = mx.broadcast_to(rhs_indices, (3, 2))
M = a.shape[-2] M = a.shape[-2]
N = b.shape[-2] N = b.shape[-1]
K = a.shape[-1] K = a.shape[-1]
a = a.reshape((-1, M, K)) a = a.reshape((-1, M, K))

View File

@ -194,6 +194,11 @@ class TestFFT(mlx_tests.MLXTestCase):
r_np = np.fft.ifft(segment, n=n_fft) r_np = np.fft.ifft(segment, n=n_fft)
self.assertTrue(np.allclose(r, r_np, atol=1e-5, rtol=1e-5)) self.assertTrue(np.allclose(r, r_np, atol=1e-5, rtol=1e-5))
def test_fft_throws(self):
x = mx.array(3.0)
with self.assertRaises(ValueError):
mx.fft.irfftn(x)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()

View File

@ -1857,6 +1857,30 @@ class TestOps(mlx_tests.MLXTestCase):
y = mx.as_strided(x, (x.size,), (-1,), x.size - 1) y = mx.as_strided(x, (x.size,), (-1,), x.size - 1)
self.assertTrue(mx.array_equal(y, x[::-1])) self.assertTrue(mx.array_equal(y, x[::-1]))
def test_logcumsumexp(self):
npop = np.logaddexp.accumulate
mxop = mx.logcumsumexp
a_npy = np.random.randn(32, 32, 32).astype(np.float32)
a_mlx = mx.array(a_npy)
for axis in (0, 1, 2):
c_npy = npop(a_npy, axis=axis)
c_mlx = mxop(a_mlx, axis=axis)
self.assertTrue(np.allclose(c_npy, c_mlx, rtol=1e-3, atol=1e-3))
edge_cases_npy = [
np.float32([-float("inf")] * 8),
np.float32([-float("inf"), 0, -float("inf")]),
np.float32([-float("inf"), float("inf"), -float("inf")]),
]
edge_cases_mlx = [mx.array(a) for a in edge_cases_npy]
for a_npy, a_mlx in zip(edge_cases_npy, edge_cases_mlx):
c_npy = npop(a_npy, axis=0)
c_mlx = mxop(a_mlx, axis=0)
self.assertTrue(np.allclose(c_npy, c_mlx, rtol=1e-3, atol=1e-3))
def test_scans(self): def test_scans(self):
a_npy = np.random.randn(32, 32, 32).astype(np.float32) a_npy = np.random.randn(32, 32, 32).astype(np.float32)
a_mlx = mx.array(a_npy) a_mlx = mx.array(a_npy)
@ -2910,6 +2934,35 @@ class TestOps(mlx_tests.MLXTestCase):
out = a[::-1] out = a[::-1]
self.assertTrue(mx.array_equal(out[-1, :], a[0, :])) self.assertTrue(mx.array_equal(out[-1, :], a[0, :]))
def test_complex_ops(self):
x = mx.array(
[
3.0 + 4.0j,
-5.0 + 12.0j,
-8.0 + 0.0j,
0.0 + 9.0j,
0.0 + 0.0j,
]
)
ops = ["arccos", "arcsin", "arctan", "square", "sqrt"]
for op in ops:
with self.subTest(op=op):
np_op = getattr(np, op)
mx_op = getattr(mx, op)
self.assertTrue(np.allclose(mx_op(x), np_op(x)))
x = mx.array(
[
3.0 + 4.0j,
-5.0 + 12.0j,
-8.0 + 0.0j,
0.0 + 9.0j,
9.0 + 1.0j,
]
)
self.assertTrue(np.allclose(mx.rsqrt(x), 1.0 / np.sqrt(x)))
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()

View File

@ -174,12 +174,14 @@ class TestQuantized(mlx_tests.MLXTestCase):
tests = product( tests = product(
[128, 64, 32], # group_size [128, 64, 32], # group_size
[2, 3, 4, 6, 8], # bits [2, 3, 4, 6, 8], # bits
[128, 256], # M [32, 128, 256], # M
[128, 256, 67], # N [128, 256, 67], # N
[0, 1, 3, 8], # B [0, 1, 3, 8], # B
) )
for group_size, bits, M, N, B in tests: for group_size, bits, M, N, B in tests:
with self.subTest(shape=(B, M, N), group_size=group_size, bits=bits): with self.subTest(shape=(B, M, N), group_size=group_size, bits=bits):
if M < group_size:
continue
x_shape = (1, N) if B == 0 else (B, 1, N) x_shape = (1, N) if B == 0 else (B, 1, N)
w_shape = (N, M) if B == 0 else (B, N, M) w_shape = (N, M) if B == 0 else (B, N, M)
x = mx.random.normal(shape=x_shape, key=k1) x = mx.random.normal(shape=x_shape, key=k1)
@ -448,6 +450,7 @@ class TestQuantized(mlx_tests.MLXTestCase):
) )
for kwargs in inputs: for kwargs in inputs:
test_shape(1, 32, 128, **kwargs)
test_shape(32, 32, 256, **kwargs) test_shape(32, 32, 256, **kwargs)
test_shape(1, 32, 256, **kwargs) test_shape(1, 32, 256, **kwargs)
test_shape(32, 256, 32, transpose=False, **kwargs) test_shape(32, 256, 32, transpose=False, **kwargs)
@ -486,6 +489,66 @@ class TestQuantized(mlx_tests.MLXTestCase):
g2 = mx.grad(f_test)(x, qw, s, b, lhs_indices, rhs_indices) g2 = mx.grad(f_test)(x, qw, s, b, lhs_indices, rhs_indices)
self.assertTrue(mx.allclose(g1, g2, atol=1e-4)) self.assertTrue(mx.allclose(g1, g2, atol=1e-4))
def test_gather_qmm_sorted(self):
def quantize(w, transpose=True, group_size=64, bits=4):
qw, s, b = mx.quantize(w, group_size=group_size, bits=bits)
w_hat = mx.dequantize(qw, s, b, group_size=group_size, bits=bits)
if transpose:
w_hat = w_hat.swapaxes(-1, -2)
return w_hat, qw, s, b
def gather_sort(x, indices):
N, M = indices.shape
indices = indices.flatten()
order = mx.argsort(indices)
inv_order = mx.argsort(order)
return x.flatten(0, -3)[order // M], indices[order], inv_order
def scatter_unsort(x, inv_order, shape=None):
x = x[inv_order]
if shape is not None:
x = mx.unflatten(x, 0, shape)
return x
parameters = [
# L, K, D, E, I, transpose
(128, 1024, 1024, 32, 4, True),
(128, 1024, 544, 32, 4, True),
(433, 1024, 1024, 32, 4, True),
(433, 1024, 555, 32, 4, True),
(433, 2048, 1024, 32, 4, True),
(128, 1024, 1024, 32, 4, False),
(128, 1024, 544, 32, 4, False),
(433, 1024, 1024, 32, 4, False),
(433, 1024, 544, 32, 4, False),
(433, 1024, 555, 32, 4, False),
(433, 2048, 1024, 32, 4, False),
]
for L, K, D, E, I, transpose in parameters:
K, D = (K, D) if transpose else (D, K)
ishape = (L, I)
xshape = (L, 1, 1, K)
wshape = (E, D, K) if transpose else (E, K, D)
indices = (mx.random.uniform(shape=ishape) * E).astype(mx.uint32)
x = mx.random.normal(xshape) / K**0.5
w = mx.random.normal(wshape) / K**0.5
w, *wq = quantize(w, transpose=transpose)
y1 = mx.gather_mm(x, w, rhs_indices=indices)
y2 = mx.gather_qmm(x, *wq, transpose=transpose, rhs_indices=indices)
xs, idx, inv_order = gather_sort(x, indices)
y3 = mx.gather_mm(xs, w, rhs_indices=idx, sorted_indices=True)
y4 = mx.gather_qmm(
xs, *wq, rhs_indices=idx, transpose=transpose, sorted_indices=True
)
y3 = scatter_unsort(y3, inv_order, indices.shape)
y4 = scatter_unsort(y4, inv_order, indices.shape)
self.assertTrue(mx.allclose(y1, y2, atol=1e-5))
self.assertTrue(mx.allclose(y1, y3, atol=1e-5))
self.assertTrue(mx.allclose(y1, y4, atol=1e-5))
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()

View File

@ -669,6 +669,57 @@ class TestVmap(mlx_tests.MLXTestCase):
self.assertEqual(mx.vmap(fun, in_axes=(1,))(x).shape, (3, 8)) self.assertEqual(mx.vmap(fun, in_axes=(1,))(x).shape, (3, 8))
self.assertEqual(mx.vmap(fun, in_axes=(2,))(x).shape, (4, 6)) self.assertEqual(mx.vmap(fun, in_axes=(2,))(x).shape, (4, 6))
def test_vmap_conv(self):
# vmap input only
x = mx.random.uniform(shape=(2, 2, 5, 4))
w = mx.random.uniform(shape=(8, 3, 4))
expected = mx.stack([mx.conv1d(xi, w) for xi in x])
out = mx.vmap(mx.conv1d, in_axes=(0, None))(x, w)
self.assertTrue(mx.allclose(expected, out))
x = mx.moveaxis(x, 0, 2)
out = mx.vmap(mx.conv1d, in_axes=(2, None))(x, w)
self.assertTrue(mx.allclose(expected, out))
# vmap weights only
x = mx.random.uniform(shape=(2, 5, 4))
w = mx.random.uniform(shape=(3, 8, 3, 4))
expected = mx.stack([mx.conv1d(x, wi) for wi in w])
out = mx.vmap(mx.conv1d, in_axes=(None, 0))(x, w)
self.assertTrue(mx.allclose(expected, out))
w = mx.moveaxis(w, 0, 1)
out = mx.vmap(mx.conv1d, in_axes=(None, 1))(x, w)
self.assertTrue(mx.allclose(expected, out))
# vmap weights and input
x = mx.random.uniform(shape=(3, 2, 5, 4))
w = mx.random.uniform(shape=(3, 8, 3, 4))
expected = mx.stack([mx.conv1d(xi, wi) for xi, wi in zip(x, w)])
out = mx.vmap(mx.conv1d, in_axes=(0, 0))(x, w)
self.assertTrue(mx.allclose(expected, out))
x = mx.random.uniform(shape=(2, 3, 5, 4))
w = mx.random.uniform(shape=(8, 3, 4, 3))
expected = mx.stack([mx.conv1d(x[:, i], w[..., i]) for i in range(3)])
out = mx.vmap(mx.conv1d, in_axes=(1, 3))(x, w)
self.assertTrue(mx.allclose(expected, out))
# Test with groups
x = mx.random.uniform(shape=(3, 2, 5, 8))
w = mx.random.uniform(shape=(3, 2, 3, 4))
def gconv(x, w):
return mx.conv1d(x, w, groups=2)
expected = mx.stack([gconv(xi, wi) for xi, wi in zip(x, w)])
out = mx.vmap(gconv, in_axes=(0, 0))(x, w)
self.assertTrue(mx.allclose(expected, out))
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()

View File

@ -97,8 +97,7 @@ TEST_CASE("test export primitives with state") {
TEST_CASE("test export functions with kwargs") { TEST_CASE("test export functions with kwargs") {
std::string file_path = get_temp_file("model.mlxfn"); std::string file_path = get_temp_file("model.mlxfn");
auto fun = auto fun = [](const Kwargs& kwargs) -> std::vector<array> {
[](const std::map<std::string, array>& kwargs) -> std::vector<array> {
return {kwargs.at("x") + kwargs.at("y")}; return {kwargs.at("x") + kwargs.at("y")};
}; };

View File

@ -3874,3 +3874,41 @@ TEST_CASE("test contiguous") {
CHECK(x.flags().col_contiguous); CHECK(x.flags().col_contiguous);
CHECK_EQ(x.strides(), decltype(x.strides()){1, 2}); CHECK_EQ(x.strides(), decltype(x.strides()){1, 2});
} }
TEST_CASE("test bitwise shift operations") {
std::vector<Dtype> dtypes = {
int8, int16, int32, int64, uint8, uint16, uint32, uint64};
for (const auto& dtype : dtypes) {
array x = full({4}, 1, dtype);
array y = full({4}, 2, dtype);
auto left_shift_result = left_shift(x, y);
CHECK_EQ(left_shift_result.dtype(), dtype);
CHECK(array_equal(left_shift_result, array({4, 4, 4, 4}, dtype))
.item<bool>());
auto right_shift_result = right_shift(full({4}, 4, dtype), y);
CHECK_EQ(right_shift_result.dtype(), dtype);
CHECK(array_equal(right_shift_result, full({4}, 1, dtype)).item<bool>());
}
array x = array({127, -128}, int8);
array y = array({1, 1}, int8);
auto left_shift_result = left_shift(x, y);
auto right_shift_result = right_shift(x, y);
CHECK(array_equal(left_shift_result, array({-2, 0}, int8)).item<bool>());
CHECK(array_equal(right_shift_result, array({63, -64}, int8)).item<bool>());
array x_bool = full({4}, true, bool_);
array y_bool = full({4}, true, bool_);
auto left_shift_bool_result = left_shift(x_bool, y_bool);
auto right_shift_bool_result = right_shift(x_bool, y_bool);
CHECK_EQ(left_shift_bool_result.dtype(), uint8);
CHECK(array_equal(left_shift_bool_result, full({4}, 2, uint8)).item<bool>());
CHECK_EQ(right_shift_bool_result.dtype(), uint8);
CHECK(array_equal(right_shift_bool_result, full({4}, 0, uint8)).item<bool>());
}