mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
Merge branch 'main' into stft
This commit is contained in:
commit
a963a15b8d
@ -251,7 +251,7 @@ jobs:
|
|||||||
name: Install Python package
|
name: Install Python package
|
||||||
command: |
|
command: |
|
||||||
source env/bin/activate
|
source env/bin/activate
|
||||||
MACOSX_DEPLOYMENT_TARGET="" DEV_RELEASE=1 \
|
env -u MACOSX_DEPLOYMENT_TARGET DEV_RELEASE=1 \
|
||||||
CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
|
CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
|
||||||
pip install . -v
|
pip install . -v
|
||||||
- run:
|
- run:
|
||||||
|
74
benchmarks/python/gather_mm_bench.py
Normal file
74
benchmarks/python/gather_mm_bench.py
Normal file
@ -0,0 +1,74 @@
|
|||||||
|
# Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
from time_utils import time_fn
|
||||||
|
|
||||||
|
N = 1024
|
||||||
|
D = 1024
|
||||||
|
M = 1024
|
||||||
|
E = 32
|
||||||
|
I = 4
|
||||||
|
|
||||||
|
|
||||||
|
def gather_sort(x, indices):
|
||||||
|
N, M = indices.shape
|
||||||
|
indices = indices.flatten()
|
||||||
|
order = mx.argsort(indices)
|
||||||
|
inv_order = mx.argsort(order)
|
||||||
|
return x.flatten(0, -3)[order // M], indices[order], inv_order
|
||||||
|
|
||||||
|
|
||||||
|
def scatter_unsort(x, inv_order, shape=None):
|
||||||
|
x = x[inv_order]
|
||||||
|
if shape is not None:
|
||||||
|
x = mx.unflatten(x, 0, shape)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def gather_mm_simulate(x, w, indices):
|
||||||
|
x, idx, inv_order = gather_sort(x, indices)
|
||||||
|
for i in range(2):
|
||||||
|
y = mx.concatenate([x[i] @ w[j].T for i, j in enumerate(idx.tolist())], axis=0)
|
||||||
|
x = y[:, None]
|
||||||
|
x = scatter_unsort(x, inv_order, indices.shape)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def time_gather_mm():
|
||||||
|
x = mx.random.normal((N, 1, 1, D)) / 1024**0.5
|
||||||
|
w1 = mx.random.normal((E, M, D)) / 1024**0.5
|
||||||
|
w2 = mx.random.normal((E, D, M)) / 1024**0.5
|
||||||
|
indices = (mx.random.uniform(shape=(N, I)) * E).astype(mx.uint32)
|
||||||
|
sorted_indices = mx.sort(indices.flatten()).reshape(N, I)
|
||||||
|
mx.eval(x, w1, w2, indices, sorted_indices)
|
||||||
|
|
||||||
|
def gather_mm(x, w1, w2, indices, sort):
|
||||||
|
idx = indices
|
||||||
|
inv_order = None
|
||||||
|
if sort:
|
||||||
|
x, idx, inv_order = gather_sort(x, indices)
|
||||||
|
x = mx.gather_mm(x, w1.swapaxes(-1, -2), rhs_indices=idx, sorted_indices=sort)
|
||||||
|
x = mx.gather_mm(x, w2.swapaxes(-1, -2), rhs_indices=idx, sorted_indices=sort)
|
||||||
|
if sort:
|
||||||
|
x = scatter_unsort(x, inv_order, indices.shape)
|
||||||
|
return x
|
||||||
|
|
||||||
|
time_fn(gather_mm, x, w1, w2, indices, False)
|
||||||
|
time_fn(gather_mm, x, w1, w2, sorted_indices, False)
|
||||||
|
time_fn(gather_mm, x, w1, w2, indices, True)
|
||||||
|
|
||||||
|
x = mx.random.normal((N * I, D)) / 1024**0.5
|
||||||
|
w1 = mx.random.normal((M, D)) / 1024**0.5
|
||||||
|
w2 = mx.random.normal((D, M)) / 1024**0.5
|
||||||
|
mx.eval(x, w1, w2)
|
||||||
|
|
||||||
|
def equivalent_matmul(x, w1, w2):
|
||||||
|
x = x @ w1.T
|
||||||
|
x = x @ w2.T
|
||||||
|
return x
|
||||||
|
|
||||||
|
time_fn(equivalent_matmul, x, w1, w2)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
time_gather_mm()
|
84
benchmarks/python/gather_qmm_bench.py
Normal file
84
benchmarks/python/gather_qmm_bench.py
Normal file
@ -0,0 +1,84 @@
|
|||||||
|
# Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
from time_utils import time_fn
|
||||||
|
|
||||||
|
N = 1024
|
||||||
|
D = 1024
|
||||||
|
M = 1024
|
||||||
|
E = 32
|
||||||
|
I = 4
|
||||||
|
|
||||||
|
|
||||||
|
def gather_sort(x, indices):
|
||||||
|
N, M = indices.shape
|
||||||
|
indices = indices.flatten()
|
||||||
|
order = mx.argsort(indices)
|
||||||
|
inv_order = mx.argsort(order)
|
||||||
|
return x.flatten(0, -3)[order // M], indices[order], inv_order
|
||||||
|
|
||||||
|
|
||||||
|
def scatter_unsort(x, inv_order, shape=None):
|
||||||
|
x = x[inv_order]
|
||||||
|
if shape is not None:
|
||||||
|
x = mx.unflatten(x, 0, shape)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def gather_mm_simulate(x, w, indices):
|
||||||
|
x, idx, inv_order = gather_sort(x, indices)
|
||||||
|
for i in range(2):
|
||||||
|
y = mx.concatenate(
|
||||||
|
[
|
||||||
|
mx.quantized_matmul(x[i], w[0][j], w[1][j], w[2][j], transpose=True)
|
||||||
|
for i, j in enumerate(idx.tolist())
|
||||||
|
],
|
||||||
|
axis=0,
|
||||||
|
)
|
||||||
|
x = y[:, None]
|
||||||
|
x = scatter_unsort(x, inv_order, indices.shape)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def time_gather_qmm():
|
||||||
|
x = mx.random.normal((N, 1, 1, D)) / 1024**0.5
|
||||||
|
w1 = mx.random.normal((E, M, D)) / 1024**0.5
|
||||||
|
w2 = mx.random.normal((E, D, M)) / 1024**0.5
|
||||||
|
w1 = mx.quantize(w1)
|
||||||
|
w2 = mx.quantize(w2)
|
||||||
|
indices = (mx.random.uniform(shape=(N, I)) * E).astype(mx.uint32)
|
||||||
|
sorted_indices = mx.sort(indices.flatten()).reshape(N, I)
|
||||||
|
mx.eval(x, w1, w2, indices, sorted_indices)
|
||||||
|
|
||||||
|
def gather_mm(x, w1, w2, indices, sort):
|
||||||
|
idx = indices
|
||||||
|
inv_order = None
|
||||||
|
if sort:
|
||||||
|
x, idx, inv_order = gather_sort(x, indices)
|
||||||
|
x = mx.gather_qmm(x, *w1, transpose=True, rhs_indices=idx, sorted_indices=sort)
|
||||||
|
x = mx.gather_qmm(x, *w2, transpose=True, rhs_indices=idx, sorted_indices=sort)
|
||||||
|
if sort:
|
||||||
|
x = scatter_unsort(x, inv_order, indices.shape)
|
||||||
|
return x
|
||||||
|
|
||||||
|
time_fn(gather_mm, x, w1, w2, indices, False)
|
||||||
|
time_fn(gather_mm, x, w1, w2, sorted_indices, False)
|
||||||
|
time_fn(gather_mm, x, w1, w2, indices, True)
|
||||||
|
|
||||||
|
x = mx.random.normal((N * I, D)) / 1024**0.5
|
||||||
|
w1 = mx.random.normal((M, D)) / 1024**0.5
|
||||||
|
w2 = mx.random.normal((D, M)) / 1024**0.5
|
||||||
|
w1 = mx.quantize(w1)
|
||||||
|
w2 = mx.quantize(w2)
|
||||||
|
mx.eval(x, w1, w2)
|
||||||
|
|
||||||
|
def equivalent_matmul(x, w1, w2):
|
||||||
|
x = mx.quantized_matmul(x, *w1, transpose=True)
|
||||||
|
x = mx.quantized_matmul(x, *w2, transpose=True)
|
||||||
|
return x
|
||||||
|
|
||||||
|
time_fn(equivalent_matmul, x, w1, w2)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
time_gather_qmm()
|
@ -38,6 +38,7 @@ Array
|
|||||||
array.log10
|
array.log10
|
||||||
array.log1p
|
array.log1p
|
||||||
array.log2
|
array.log2
|
||||||
|
array.logcumsumexp
|
||||||
array.logsumexp
|
array.logsumexp
|
||||||
array.max
|
array.max
|
||||||
array.mean
|
array.mean
|
||||||
|
@ -103,6 +103,7 @@ Operations
|
|||||||
log10
|
log10
|
||||||
log1p
|
log1p
|
||||||
logaddexp
|
logaddexp
|
||||||
|
logcumsumexp
|
||||||
logical_not
|
logical_not
|
||||||
logical_and
|
logical_and
|
||||||
logical_or
|
logical_or
|
||||||
|
@ -5,6 +5,7 @@ target_sources(
|
|||||||
${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/dtype_utils.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/export.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/export.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/einsum.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/einsum.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp
|
||||||
|
@ -339,11 +339,11 @@ class array {
|
|||||||
return allocator::allocator().size(buffer());
|
return allocator::allocator().size(buffer());
|
||||||
}
|
}
|
||||||
|
|
||||||
// Return a copy of the shared pointer
|
// Return the shared pointer to the array::Data struct
|
||||||
// to the array::Data struct
|
const std::shared_ptr<Data>& data_shared_ptr() const {
|
||||||
std::shared_ptr<Data> data_shared_ptr() const {
|
|
||||||
return array_desc_->data;
|
return array_desc_->data;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Return a raw pointer to the arrays data
|
// Return a raw pointer to the arrays data
|
||||||
template <typename T>
|
template <typename T>
|
||||||
T* data() {
|
T* data() {
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
target_sources(
|
target_sources(
|
||||||
mlx
|
mlx
|
||||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
|
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/broadcasting.cpp
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/common.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/common.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
|
||||||
|
24
mlx/backend/common/broadcasting.cpp
Normal file
24
mlx/backend/common/broadcasting.cpp
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
// Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/backend/common/utils.h"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
void broadcast(const array& in, array& out) {
|
||||||
|
if (out.size() == 0) {
|
||||||
|
out.set_data(nullptr);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
Strides strides(out.ndim(), 0);
|
||||||
|
int diff = out.ndim() - in.ndim();
|
||||||
|
for (int i = in.ndim() - 1; i >= 0; --i) {
|
||||||
|
strides[i + diff] = (in.shape()[i] == 1) ? 0 : in.strides()[i];
|
||||||
|
}
|
||||||
|
auto flags = in.flags();
|
||||||
|
if (out.size() > in.size()) {
|
||||||
|
flags.row_contiguous = flags.col_contiguous = false;
|
||||||
|
}
|
||||||
|
out.copy_shared_buffer(in, strides, flags, in.data_size());
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
11
mlx/backend/common/broadcasting.h
Normal file
11
mlx/backend/common/broadcasting.h
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
// Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlx/array.h"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
void broadcast(const array& in, array& out);
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
@ -1,6 +1,7 @@
|
|||||||
// Copyright © 2024 Apple Inc.
|
// Copyright © 2024 Apple Inc.
|
||||||
#include <cassert>
|
#include <cassert>
|
||||||
|
|
||||||
|
#include "mlx/backend/common/broadcasting.h"
|
||||||
#include "mlx/backend/common/utils.h"
|
#include "mlx/backend/common/utils.h"
|
||||||
#include "mlx/primitives.h"
|
#include "mlx/primitives.h"
|
||||||
|
|
||||||
@ -42,23 +43,6 @@ void AsStrided::eval(const std::vector<array>& inputs, array& out) {
|
|||||||
return out.copy_shared_buffer(in, strides_, flags, data_size, offset_);
|
return out.copy_shared_buffer(in, strides_, flags, data_size, offset_);
|
||||||
}
|
}
|
||||||
|
|
||||||
void broadcast(const array& in, array& out) {
|
|
||||||
if (out.size() == 0) {
|
|
||||||
out.set_data(nullptr);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
Strides strides(out.ndim(), 0);
|
|
||||||
int diff = out.ndim() - in.ndim();
|
|
||||||
for (int i = in.ndim() - 1; i >= 0; --i) {
|
|
||||||
strides[i + diff] = (in.shape()[i] == 1) ? 0 : in.strides()[i];
|
|
||||||
}
|
|
||||||
auto flags = in.flags();
|
|
||||||
if (out.size() > in.size()) {
|
|
||||||
flags.row_contiguous = flags.col_contiguous = false;
|
|
||||||
}
|
|
||||||
out.copy_shared_buffer(in, strides, flags, in.data_size());
|
|
||||||
}
|
|
||||||
|
|
||||||
void Broadcast::eval(const std::vector<array>& inputs, array& out) {
|
void Broadcast::eval(const std::vector<array>& inputs, array& out) {
|
||||||
broadcast(inputs[0], out);
|
broadcast(inputs[0], out);
|
||||||
}
|
}
|
||||||
|
@ -3,6 +3,7 @@
|
|||||||
#include <cassert>
|
#include <cassert>
|
||||||
|
|
||||||
#include "mlx/backend/common/utils.h"
|
#include "mlx/backend/common/utils.h"
|
||||||
|
#include "mlx/backend/cpu/binary_ops.h"
|
||||||
#include "mlx/backend/cpu/copy.h"
|
#include "mlx/backend/cpu/copy.h"
|
||||||
#include "mlx/backend/cpu/encoder.h"
|
#include "mlx/backend/cpu/encoder.h"
|
||||||
#include "mlx/backend/cpu/simd/simd.h"
|
#include "mlx/backend/cpu/simd/simd.h"
|
||||||
@ -226,6 +227,16 @@ void scan_dispatch(
|
|||||||
scan_op<T, U>(in, out, axis, reverse, inclusive, op, init);
|
scan_op<T, U>(in, out, axis, reverse, inclusive, op, init);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
case Scan::LogAddExp: {
|
||||||
|
auto op = [](U a, T b) {
|
||||||
|
return detail::LogAddExp{}(a, static_cast<U>(b));
|
||||||
|
};
|
||||||
|
auto init = (issubdtype(in.dtype(), floating))
|
||||||
|
? static_cast<U>(-std::numeric_limits<float>::infinity())
|
||||||
|
: std::numeric_limits<U>::min();
|
||||||
|
scan_op<T, U>(in, out, axis, reverse, inclusive, op, init);
|
||||||
|
break;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -61,6 +61,7 @@ if(MLX_METAL_JIT)
|
|||||||
kernels/steel/gemm/transforms.h)
|
kernels/steel/gemm/transforms.h)
|
||||||
make_jit_source(steel/gemm/kernels/steel_gemm_fused)
|
make_jit_source(steel/gemm/kernels/steel_gemm_fused)
|
||||||
make_jit_source(steel/gemm/kernels/steel_gemm_masked kernels/steel/defines.h)
|
make_jit_source(steel/gemm/kernels/steel_gemm_masked kernels/steel/defines.h)
|
||||||
|
make_jit_source(steel/gemm/kernels/steel_gemm_gather)
|
||||||
make_jit_source(steel/gemm/kernels/steel_gemm_splitk)
|
make_jit_source(steel/gemm/kernels/steel_gemm_splitk)
|
||||||
make_jit_source(
|
make_jit_source(
|
||||||
steel/conv/conv
|
steel/conv/conv
|
||||||
|
@ -33,6 +33,7 @@ const char* gemm();
|
|||||||
const char* steel_gemm_fused();
|
const char* steel_gemm_fused();
|
||||||
const char* steel_gemm_masked();
|
const char* steel_gemm_masked();
|
||||||
const char* steel_gemm_splitk();
|
const char* steel_gemm_splitk();
|
||||||
|
const char* steel_gemm_gather();
|
||||||
const char* conv();
|
const char* conv();
|
||||||
const char* steel_conv();
|
const char* steel_conv();
|
||||||
const char* steel_conv_general();
|
const char* steel_conv_general();
|
||||||
|
@ -584,6 +584,44 @@ MTL::ComputePipelineState* get_steel_gemm_masked_kernel(
|
|||||||
return d.get_kernel(kernel_name, lib);
|
return d.get_kernel(kernel_name, lib);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
MTL::ComputePipelineState* get_steel_gemm_gather_kernel(
|
||||||
|
metal::Device& d,
|
||||||
|
const std::string& kernel_name,
|
||||||
|
const std::string& hash_name,
|
||||||
|
const metal::MTLFCList& func_consts,
|
||||||
|
const array& out,
|
||||||
|
bool transpose_a,
|
||||||
|
bool transpose_b,
|
||||||
|
int bm,
|
||||||
|
int bn,
|
||||||
|
int bk,
|
||||||
|
int wm,
|
||||||
|
int wn,
|
||||||
|
bool rhs) {
|
||||||
|
const auto& lib_name = kernel_name;
|
||||||
|
auto lib = d.get_library(lib_name, [&]() {
|
||||||
|
std::string kernel_source;
|
||||||
|
concatenate(
|
||||||
|
kernel_source,
|
||||||
|
metal::utils(),
|
||||||
|
metal::gemm(),
|
||||||
|
metal::steel_gemm_gather(),
|
||||||
|
get_template_definition(
|
||||||
|
lib_name,
|
||||||
|
rhs ? "gather_mm_rhs" : "gather_mm",
|
||||||
|
get_type_string(out.dtype()),
|
||||||
|
bm,
|
||||||
|
bn,
|
||||||
|
bk,
|
||||||
|
wm,
|
||||||
|
wn,
|
||||||
|
transpose_a,
|
||||||
|
transpose_b));
|
||||||
|
return kernel_source;
|
||||||
|
});
|
||||||
|
return d.get_kernel(kernel_name, lib, hash_name, func_consts);
|
||||||
|
}
|
||||||
|
|
||||||
MTL::ComputePipelineState* get_gemv_masked_kernel(
|
MTL::ComputePipelineState* get_gemv_masked_kernel(
|
||||||
metal::Device& d,
|
metal::Device& d,
|
||||||
const std::string& kernel_name,
|
const std::string& kernel_name,
|
||||||
@ -714,4 +752,43 @@ MTL::ComputePipelineState* get_quantized_kernel(
|
|||||||
return d.get_kernel(kernel_name, lib);
|
return d.get_kernel(kernel_name, lib);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
MTL::ComputePipelineState* get_gather_qmm_kernel(
|
||||||
|
metal::Device& d,
|
||||||
|
const std::string& kernel_name,
|
||||||
|
const std::string& hash_name,
|
||||||
|
const metal::MTLFCList& func_consts,
|
||||||
|
const array& x,
|
||||||
|
int group_size,
|
||||||
|
int bits,
|
||||||
|
int bm,
|
||||||
|
int bn,
|
||||||
|
int bk,
|
||||||
|
int wm,
|
||||||
|
int wn,
|
||||||
|
bool transpose) {
|
||||||
|
const auto& lib_name = kernel_name;
|
||||||
|
auto lib = d.get_library(lib_name, [&]() {
|
||||||
|
std::string kernel_source;
|
||||||
|
concatenate(
|
||||||
|
kernel_source,
|
||||||
|
metal::utils(),
|
||||||
|
metal::gemm(),
|
||||||
|
metal::quantized(),
|
||||||
|
get_template_definition(
|
||||||
|
lib_name,
|
||||||
|
"gather_qmm_rhs",
|
||||||
|
get_type_string(x.dtype()),
|
||||||
|
group_size,
|
||||||
|
bits,
|
||||||
|
bm,
|
||||||
|
bn,
|
||||||
|
bk,
|
||||||
|
wm,
|
||||||
|
wn,
|
||||||
|
transpose));
|
||||||
|
return kernel_source;
|
||||||
|
});
|
||||||
|
return d.get_kernel(kernel_name, lib, hash_name, func_consts);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
@ -160,6 +160,21 @@ MTL::ComputePipelineState* get_steel_gemm_masked_kernel(
|
|||||||
bool mn_aligned,
|
bool mn_aligned,
|
||||||
bool k_aligned);
|
bool k_aligned);
|
||||||
|
|
||||||
|
MTL::ComputePipelineState* get_steel_gemm_gather_kernel(
|
||||||
|
metal::Device& d,
|
||||||
|
const std::string& kernel_name,
|
||||||
|
const std::string& hash_name,
|
||||||
|
const metal::MTLFCList& func_consts,
|
||||||
|
const array& out,
|
||||||
|
bool transpose_a,
|
||||||
|
bool transpose_b,
|
||||||
|
int bm,
|
||||||
|
int bn,
|
||||||
|
int bk,
|
||||||
|
int wm,
|
||||||
|
int wn,
|
||||||
|
bool rhs);
|
||||||
|
|
||||||
MTL::ComputePipelineState* get_steel_conv_kernel(
|
MTL::ComputePipelineState* get_steel_conv_kernel(
|
||||||
metal::Device& d,
|
metal::Device& d,
|
||||||
const std::string& kernel_name,
|
const std::string& kernel_name,
|
||||||
@ -209,6 +224,21 @@ MTL::ComputePipelineState* get_quantized_kernel(
|
|||||||
const std::string& kernel_name,
|
const std::string& kernel_name,
|
||||||
const std::string& template_def);
|
const std::string& template_def);
|
||||||
|
|
||||||
|
MTL::ComputePipelineState* get_gather_qmm_kernel(
|
||||||
|
metal::Device& d,
|
||||||
|
const std::string& kernel_name,
|
||||||
|
const std::string& hash_name,
|
||||||
|
const metal::MTLFCList& func_consts,
|
||||||
|
const array& x,
|
||||||
|
int group_size,
|
||||||
|
int bits,
|
||||||
|
int bm,
|
||||||
|
int bn,
|
||||||
|
int bk,
|
||||||
|
int wm,
|
||||||
|
int wn,
|
||||||
|
bool transpose);
|
||||||
|
|
||||||
// Create a GPU kernel template definition for JIT compilation
|
// Create a GPU kernel template definition for JIT compilation
|
||||||
template <typename... Args>
|
template <typename... Args>
|
||||||
std::string
|
std::string
|
||||||
|
@ -69,6 +69,7 @@ set(STEEL_HEADERS
|
|||||||
steel/gemm/loader.h
|
steel/gemm/loader.h
|
||||||
steel/gemm/transforms.h
|
steel/gemm/transforms.h
|
||||||
steel/gemm/kernels/steel_gemm_fused.h
|
steel/gemm/kernels/steel_gemm_fused.h
|
||||||
|
steel/gemm/kernels/steel_gemm_gather.h
|
||||||
steel/gemm/kernels/steel_gemm_masked.h
|
steel/gemm/kernels/steel_gemm_masked.h
|
||||||
steel/gemm/kernels/steel_gemm_splitk.h
|
steel/gemm/kernels/steel_gemm_splitk.h
|
||||||
steel/utils/type_traits.h
|
steel/utils/type_traits.h
|
||||||
@ -116,6 +117,7 @@ if(NOT MLX_METAL_JIT)
|
|||||||
build_kernel(steel/conv/kernels/steel_conv ${STEEL_HEADERS})
|
build_kernel(steel/conv/kernels/steel_conv ${STEEL_HEADERS})
|
||||||
build_kernel(steel/conv/kernels/steel_conv_general ${STEEL_HEADERS})
|
build_kernel(steel/conv/kernels/steel_conv_general ${STEEL_HEADERS})
|
||||||
build_kernel(steel/gemm/kernels/steel_gemm_fused ${STEEL_HEADERS})
|
build_kernel(steel/gemm/kernels/steel_gemm_fused ${STEEL_HEADERS})
|
||||||
|
build_kernel(steel/gemm/kernels/steel_gemm_gather ${STEEL_HEADERS})
|
||||||
build_kernel(steel/gemm/kernels/steel_gemm_masked ${STEEL_HEADERS})
|
build_kernel(steel/gemm/kernels/steel_gemm_masked ${STEEL_HEADERS})
|
||||||
build_kernel(steel/gemm/kernels/steel_gemm_splitk ${STEEL_HEADERS})
|
build_kernel(steel/gemm/kernels/steel_gemm_splitk ${STEEL_HEADERS})
|
||||||
build_kernel(gemv_masked steel/utils.h)
|
build_kernel(gemv_masked steel/utils.h)
|
||||||
|
@ -104,10 +104,22 @@ constexpr bool operator==(complex64_t a, complex64_t b) {
|
|||||||
constexpr complex64_t operator+(complex64_t a, complex64_t b) {
|
constexpr complex64_t operator+(complex64_t a, complex64_t b) {
|
||||||
return {a.real + b.real, a.imag + b.imag};
|
return {a.real + b.real, a.imag + b.imag};
|
||||||
}
|
}
|
||||||
|
constexpr complex64_t operator+(float a, complex64_t b) {
|
||||||
|
return {a + b.real, b.imag};
|
||||||
|
}
|
||||||
|
constexpr complex64_t operator+(complex64_t a, float b) {
|
||||||
|
return {a.real + b, a.imag};
|
||||||
|
}
|
||||||
|
|
||||||
constexpr complex64_t operator-(complex64_t a, complex64_t b) {
|
constexpr complex64_t operator-(complex64_t a, complex64_t b) {
|
||||||
return {a.real - b.real, a.imag - b.imag};
|
return {a.real - b.real, a.imag - b.imag};
|
||||||
}
|
}
|
||||||
|
constexpr complex64_t operator-(float a, complex64_t b) {
|
||||||
|
return {a - b.real, -b.imag};
|
||||||
|
}
|
||||||
|
constexpr complex64_t operator-(complex64_t a, float b) {
|
||||||
|
return {a.real - b, a.imag};
|
||||||
|
}
|
||||||
|
|
||||||
constexpr complex64_t operator*(complex64_t a, complex64_t b) {
|
constexpr complex64_t operator*(complex64_t a, complex64_t b) {
|
||||||
return {a.real * b.real - a.imag * b.imag, a.real * b.imag + a.imag * b.real};
|
return {a.real * b.real - a.imag * b.imag, a.real * b.imag + a.imag * b.real};
|
||||||
@ -120,6 +132,13 @@ constexpr complex64_t operator/(complex64_t a, complex64_t b) {
|
|||||||
return {x / denom, y / denom};
|
return {x / denom, y / denom};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
constexpr complex64_t operator/(float a, complex64_t b) {
|
||||||
|
auto denom = b.real * b.real + b.imag * b.imag;
|
||||||
|
auto x = a * b.real;
|
||||||
|
auto y = -a * b.imag;
|
||||||
|
return {x / denom, y / denom};
|
||||||
|
}
|
||||||
|
|
||||||
constexpr complex64_t operator%(complex64_t a, complex64_t b) {
|
constexpr complex64_t operator%(complex64_t a, complex64_t b) {
|
||||||
auto real = a.real - (b.real * static_cast<int64_t>(a.real / b.real));
|
auto real = a.real - (b.real * static_cast<int64_t>(a.real / b.real));
|
||||||
auto imag = a.imag - (b.imag * static_cast<int64_t>(a.imag / b.imag));
|
auto imag = a.imag - (b.imag * static_cast<int64_t>(a.imag / b.imag));
|
||||||
|
@ -3,6 +3,10 @@
|
|||||||
#include <metal_simdgroup>
|
#include <metal_simdgroup>
|
||||||
#include <metal_stdlib>
|
#include <metal_stdlib>
|
||||||
|
|
||||||
|
constant bool align_M [[function_constant(200)]];
|
||||||
|
constant bool align_N [[function_constant(201)]];
|
||||||
|
constant bool align_K [[function_constant(202)]];
|
||||||
|
|
||||||
using namespace metal;
|
using namespace metal;
|
||||||
|
|
||||||
#define MLX_MTL_CONST static constant constexpr const
|
#define MLX_MTL_CONST static constant constexpr const
|
||||||
@ -1686,26 +1690,26 @@ template <
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <typename T, int group_size, int bits>
|
template <typename T, int group_size, int bits>
|
||||||
[[kernel]] void bs_qmv_fast(
|
[[kernel]] void gather_qmv_fast(
|
||||||
const device uint32_t* w [[buffer(0)]],
|
const device uint32_t* w [[buffer(0)]],
|
||||||
const device T* scales [[buffer(1)]],
|
const device T* scales [[buffer(1)]],
|
||||||
const device T* biases [[buffer(2)]],
|
const device T* biases [[buffer(2)]],
|
||||||
const device T* x [[buffer(3)]],
|
const device T* x [[buffer(3)]],
|
||||||
device T* y [[buffer(4)]],
|
const device uint32_t* lhs_indices [[buffer(4)]],
|
||||||
const constant int& in_vec_size [[buffer(5)]],
|
const device uint32_t* rhs_indices [[buffer(5)]],
|
||||||
const constant int& out_vec_size [[buffer(6)]],
|
device T* y [[buffer(6)]],
|
||||||
const constant int& x_batch_ndims [[buffer(7)]],
|
const constant int& in_vec_size [[buffer(7)]],
|
||||||
const constant int* x_shape [[buffer(8)]],
|
const constant int& out_vec_size [[buffer(8)]],
|
||||||
const constant int64_t* x_strides [[buffer(9)]],
|
const constant int& x_batch_ndims [[buffer(9)]],
|
||||||
const constant int& w_batch_ndims [[buffer(10)]],
|
const constant int* x_shape [[buffer(10)]],
|
||||||
const constant int* w_shape [[buffer(11)]],
|
const constant int64_t* x_strides [[buffer(11)]],
|
||||||
const constant int64_t* w_strides [[buffer(12)]],
|
const constant int& w_batch_ndims [[buffer(12)]],
|
||||||
const constant int64_t* s_strides [[buffer(13)]],
|
const constant int* w_shape [[buffer(13)]],
|
||||||
const constant int64_t* b_strides [[buffer(14)]],
|
const constant int64_t* w_strides [[buffer(14)]],
|
||||||
const constant int& batch_ndims [[buffer(15)]],
|
const constant int64_t* s_strides [[buffer(15)]],
|
||||||
const constant int* batch_shape [[buffer(16)]],
|
const constant int64_t* b_strides [[buffer(16)]],
|
||||||
const device uint32_t* lhs_indices [[buffer(17)]],
|
const constant int& batch_ndims [[buffer(17)]],
|
||||||
const device uint32_t* rhs_indices [[buffer(18)]],
|
const constant int* batch_shape [[buffer(18)]],
|
||||||
const constant int64_t* lhs_strides [[buffer(19)]],
|
const constant int64_t* lhs_strides [[buffer(19)]],
|
||||||
const constant int64_t* rhs_strides [[buffer(20)]],
|
const constant int64_t* rhs_strides [[buffer(20)]],
|
||||||
uint3 tid [[threadgroup_position_in_grid]],
|
uint3 tid [[threadgroup_position_in_grid]],
|
||||||
@ -1748,26 +1752,26 @@ template <typename T, int group_size, int bits>
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <typename T, int group_size, int bits>
|
template <typename T, int group_size, int bits>
|
||||||
[[kernel]] void bs_qmv(
|
[[kernel]] void gather_qmv(
|
||||||
const device uint32_t* w [[buffer(0)]],
|
const device uint32_t* w [[buffer(0)]],
|
||||||
const device T* scales [[buffer(1)]],
|
const device T* scales [[buffer(1)]],
|
||||||
const device T* biases [[buffer(2)]],
|
const device T* biases [[buffer(2)]],
|
||||||
const device T* x [[buffer(3)]],
|
const device T* x [[buffer(3)]],
|
||||||
device T* y [[buffer(4)]],
|
const device uint32_t* lhs_indices [[buffer(4)]],
|
||||||
const constant int& in_vec_size [[buffer(5)]],
|
const device uint32_t* rhs_indices [[buffer(5)]],
|
||||||
const constant int& out_vec_size [[buffer(6)]],
|
device T* y [[buffer(6)]],
|
||||||
const constant int& x_batch_ndims [[buffer(7)]],
|
const constant int& in_vec_size [[buffer(7)]],
|
||||||
const constant int* x_shape [[buffer(8)]],
|
const constant int& out_vec_size [[buffer(8)]],
|
||||||
const constant int64_t* x_strides [[buffer(9)]],
|
const constant int& x_batch_ndims [[buffer(9)]],
|
||||||
const constant int& w_batch_ndims [[buffer(10)]],
|
const constant int* x_shape [[buffer(10)]],
|
||||||
const constant int* w_shape [[buffer(11)]],
|
const constant int64_t* x_strides [[buffer(11)]],
|
||||||
const constant int64_t* w_strides [[buffer(12)]],
|
const constant int& w_batch_ndims [[buffer(12)]],
|
||||||
const constant int64_t* s_strides [[buffer(13)]],
|
const constant int* w_shape [[buffer(13)]],
|
||||||
const constant int64_t* b_strides [[buffer(14)]],
|
const constant int64_t* w_strides [[buffer(14)]],
|
||||||
const constant int& batch_ndims [[buffer(15)]],
|
const constant int64_t* s_strides [[buffer(15)]],
|
||||||
const constant int* batch_shape [[buffer(16)]],
|
const constant int64_t* b_strides [[buffer(16)]],
|
||||||
const device uint32_t* lhs_indices [[buffer(17)]],
|
const constant int& batch_ndims [[buffer(17)]],
|
||||||
const device uint32_t* rhs_indices [[buffer(18)]],
|
const constant int* batch_shape [[buffer(18)]],
|
||||||
const constant int64_t* lhs_strides [[buffer(19)]],
|
const constant int64_t* lhs_strides [[buffer(19)]],
|
||||||
const constant int64_t* rhs_strides [[buffer(20)]],
|
const constant int64_t* rhs_strides [[buffer(20)]],
|
||||||
uint3 tid [[threadgroup_position_in_grid]],
|
uint3 tid [[threadgroup_position_in_grid]],
|
||||||
@ -1810,26 +1814,26 @@ template <typename T, int group_size, int bits>
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <typename T, int group_size, int bits>
|
template <typename T, int group_size, int bits>
|
||||||
[[kernel]] void bs_qvm(
|
[[kernel]] void gather_qvm(
|
||||||
const device uint32_t* w [[buffer(0)]],
|
const device uint32_t* w [[buffer(0)]],
|
||||||
const device T* scales [[buffer(1)]],
|
const device T* scales [[buffer(1)]],
|
||||||
const device T* biases [[buffer(2)]],
|
const device T* biases [[buffer(2)]],
|
||||||
const device T* x [[buffer(3)]],
|
const device T* x [[buffer(3)]],
|
||||||
device T* y [[buffer(4)]],
|
const device uint32_t* lhs_indices [[buffer(4)]],
|
||||||
const constant int& in_vec_size [[buffer(5)]],
|
const device uint32_t* rhs_indices [[buffer(5)]],
|
||||||
const constant int& out_vec_size [[buffer(6)]],
|
device T* y [[buffer(6)]],
|
||||||
const constant int& x_batch_ndims [[buffer(7)]],
|
const constant int& in_vec_size [[buffer(7)]],
|
||||||
const constant int* x_shape [[buffer(8)]],
|
const constant int& out_vec_size [[buffer(8)]],
|
||||||
const constant int64_t* x_strides [[buffer(9)]],
|
const constant int& x_batch_ndims [[buffer(9)]],
|
||||||
const constant int& w_batch_ndims [[buffer(10)]],
|
const constant int* x_shape [[buffer(10)]],
|
||||||
const constant int* w_shape [[buffer(11)]],
|
const constant int64_t* x_strides [[buffer(11)]],
|
||||||
const constant int64_t* w_strides [[buffer(12)]],
|
const constant int& w_batch_ndims [[buffer(12)]],
|
||||||
const constant int64_t* s_strides [[buffer(13)]],
|
const constant int* w_shape [[buffer(13)]],
|
||||||
const constant int64_t* b_strides [[buffer(14)]],
|
const constant int64_t* w_strides [[buffer(14)]],
|
||||||
const constant int& batch_ndims [[buffer(15)]],
|
const constant int64_t* s_strides [[buffer(15)]],
|
||||||
const constant int* batch_shape [[buffer(16)]],
|
const constant int64_t* b_strides [[buffer(16)]],
|
||||||
const device uint32_t* lhs_indices [[buffer(17)]],
|
const constant int& batch_ndims [[buffer(17)]],
|
||||||
const device uint32_t* rhs_indices [[buffer(18)]],
|
const constant int* batch_shape [[buffer(18)]],
|
||||||
const constant int64_t* lhs_strides [[buffer(19)]],
|
const constant int64_t* lhs_strides [[buffer(19)]],
|
||||||
const constant int64_t* rhs_strides [[buffer(20)]],
|
const constant int64_t* rhs_strides [[buffer(20)]],
|
||||||
uint3 tid [[threadgroup_position_in_grid]],
|
uint3 tid [[threadgroup_position_in_grid]],
|
||||||
@ -1879,27 +1883,27 @@ template <
|
|||||||
const int BM = 32,
|
const int BM = 32,
|
||||||
const int BK = 32,
|
const int BK = 32,
|
||||||
const int BN = 32>
|
const int BN = 32>
|
||||||
[[kernel]] void bs_qmm_t(
|
[[kernel]] void gather_qmm_t(
|
||||||
const device uint32_t* w [[buffer(0)]],
|
const device uint32_t* w [[buffer(0)]],
|
||||||
const device T* scales [[buffer(1)]],
|
const device T* scales [[buffer(1)]],
|
||||||
const device T* biases [[buffer(2)]],
|
const device T* biases [[buffer(2)]],
|
||||||
const device T* x [[buffer(3)]],
|
const device T* x [[buffer(3)]],
|
||||||
device T* y [[buffer(4)]],
|
const device uint32_t* lhs_indices [[buffer(4)]],
|
||||||
const constant int& K [[buffer(5)]],
|
const device uint32_t* rhs_indices [[buffer(5)]],
|
||||||
const constant int& N [[buffer(6)]],
|
device T* y [[buffer(6)]],
|
||||||
const constant int& M [[buffer(7)]],
|
const constant int& K [[buffer(7)]],
|
||||||
const constant int& x_batch_ndims [[buffer(8)]],
|
const constant int& N [[buffer(8)]],
|
||||||
const constant int* x_shape [[buffer(9)]],
|
const constant int& M [[buffer(9)]],
|
||||||
const constant int64_t* x_strides [[buffer(10)]],
|
const constant int& x_batch_ndims [[buffer(10)]],
|
||||||
const constant int& w_batch_ndims [[buffer(11)]],
|
const constant int* x_shape [[buffer(11)]],
|
||||||
const constant int* w_shape [[buffer(12)]],
|
const constant int64_t* x_strides [[buffer(12)]],
|
||||||
const constant int64_t* w_strides [[buffer(13)]],
|
const constant int& w_batch_ndims [[buffer(13)]],
|
||||||
const constant int64_t* s_strides [[buffer(14)]],
|
const constant int* w_shape [[buffer(14)]],
|
||||||
const constant int64_t* b_strides [[buffer(15)]],
|
const constant int64_t* w_strides [[buffer(15)]],
|
||||||
const constant int& batch_ndims [[buffer(16)]],
|
const constant int64_t* s_strides [[buffer(16)]],
|
||||||
const constant int* batch_shape [[buffer(17)]],
|
const constant int64_t* b_strides [[buffer(17)]],
|
||||||
const device uint32_t* lhs_indices [[buffer(18)]],
|
const constant int& batch_ndims [[buffer(18)]],
|
||||||
const device uint32_t* rhs_indices [[buffer(19)]],
|
const constant int* batch_shape [[buffer(19)]],
|
||||||
const constant int64_t* lhs_strides [[buffer(20)]],
|
const constant int64_t* lhs_strides [[buffer(20)]],
|
||||||
const constant int64_t* rhs_strides [[buffer(21)]],
|
const constant int64_t* rhs_strides [[buffer(21)]],
|
||||||
uint3 tid [[threadgroup_position_in_grid]],
|
uint3 tid [[threadgroup_position_in_grid]],
|
||||||
@ -1946,27 +1950,27 @@ template <
|
|||||||
const int BM = 32,
|
const int BM = 32,
|
||||||
const int BK = 32,
|
const int BK = 32,
|
||||||
const int BN = 32>
|
const int BN = 32>
|
||||||
[[kernel]] void bs_qmm_n(
|
[[kernel]] void gather_qmm_n(
|
||||||
const device uint32_t* w [[buffer(0)]],
|
const device uint32_t* w [[buffer(0)]],
|
||||||
const device T* scales [[buffer(1)]],
|
const device T* scales [[buffer(1)]],
|
||||||
const device T* biases [[buffer(2)]],
|
const device T* biases [[buffer(2)]],
|
||||||
const device T* x [[buffer(3)]],
|
const device T* x [[buffer(3)]],
|
||||||
device T* y [[buffer(4)]],
|
const device uint32_t* lhs_indices [[buffer(4)]],
|
||||||
const constant int& K [[buffer(5)]],
|
const device uint32_t* rhs_indices [[buffer(5)]],
|
||||||
const constant int& N [[buffer(6)]],
|
device T* y [[buffer(6)]],
|
||||||
const constant int& M [[buffer(7)]],
|
const constant int& K [[buffer(7)]],
|
||||||
const constant int& x_batch_ndims [[buffer(8)]],
|
const constant int& N [[buffer(8)]],
|
||||||
const constant int* x_shape [[buffer(9)]],
|
const constant int& M [[buffer(9)]],
|
||||||
const constant int64_t* x_strides [[buffer(10)]],
|
const constant int& x_batch_ndims [[buffer(10)]],
|
||||||
const constant int& w_batch_ndims [[buffer(11)]],
|
const constant int* x_shape [[buffer(11)]],
|
||||||
const constant int* w_shape [[buffer(12)]],
|
const constant int64_t* x_strides [[buffer(12)]],
|
||||||
const constant int64_t* w_strides [[buffer(13)]],
|
const constant int& w_batch_ndims [[buffer(13)]],
|
||||||
const constant int64_t* s_strides [[buffer(14)]],
|
const constant int* w_shape [[buffer(14)]],
|
||||||
const constant int64_t* b_strides [[buffer(15)]],
|
const constant int64_t* w_strides [[buffer(15)]],
|
||||||
const constant int& batch_ndims [[buffer(16)]],
|
const constant int64_t* s_strides [[buffer(16)]],
|
||||||
const constant int* batch_shape [[buffer(17)]],
|
const constant int64_t* b_strides [[buffer(17)]],
|
||||||
const device uint32_t* lhs_indices [[buffer(18)]],
|
const constant int& batch_ndims [[buffer(18)]],
|
||||||
const device uint32_t* rhs_indices [[buffer(19)]],
|
const constant int* batch_shape [[buffer(19)]],
|
||||||
const constant int64_t* lhs_strides [[buffer(20)]],
|
const constant int64_t* lhs_strides [[buffer(20)]],
|
||||||
const constant int64_t* rhs_strides [[buffer(21)]],
|
const constant int64_t* rhs_strides [[buffer(21)]],
|
||||||
uint3 tid [[threadgroup_position_in_grid]],
|
uint3 tid [[threadgroup_position_in_grid]],
|
||||||
@ -2007,6 +2011,289 @@ template <
|
|||||||
w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
|
w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename T, typename mma_t, typename loader_a_t, typename loader_b_t>
|
||||||
|
METAL_FUNC void gemm_loop_aligned(
|
||||||
|
threadgroup T* As,
|
||||||
|
threadgroup T* Bs,
|
||||||
|
thread mma_t& mma_op,
|
||||||
|
thread loader_a_t& loader_a,
|
||||||
|
thread loader_b_t& loader_b,
|
||||||
|
const int k_iterations) {
|
||||||
|
for (int k = 0; k < k_iterations; k++) {
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
|
// Load elements into threadgroup memory
|
||||||
|
loader_a.load_unsafe();
|
||||||
|
loader_b.load_unsafe();
|
||||||
|
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
|
// Multiply and accumulate threadgroup elements
|
||||||
|
mma_op.mma(As, Bs);
|
||||||
|
|
||||||
|
// Prepare for next iteration
|
||||||
|
loader_a.next();
|
||||||
|
loader_b.next();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <
|
||||||
|
bool rows_aligned,
|
||||||
|
bool cols_aligned,
|
||||||
|
bool transpose,
|
||||||
|
typename T,
|
||||||
|
typename mma_t,
|
||||||
|
typename loader_a_t,
|
||||||
|
typename loader_b_t>
|
||||||
|
METAL_FUNC void gemm_loop_unaligned(
|
||||||
|
threadgroup T* As,
|
||||||
|
threadgroup T* Bs,
|
||||||
|
thread mma_t& mma_op,
|
||||||
|
thread loader_a_t& loader_a,
|
||||||
|
thread loader_b_t& loader_b,
|
||||||
|
const int k_iterations,
|
||||||
|
const short tgp_bm,
|
||||||
|
const short tgp_bn,
|
||||||
|
const short tgp_bk) {
|
||||||
|
for (int k = 0; k < k_iterations; k++) {
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
|
// Load elements into threadgroup memory
|
||||||
|
if (rows_aligned) {
|
||||||
|
loader_a.load_unsafe();
|
||||||
|
} else {
|
||||||
|
loader_a.load_safe(short2(tgp_bk, tgp_bm));
|
||||||
|
}
|
||||||
|
if (cols_aligned) {
|
||||||
|
loader_b.load_unsafe();
|
||||||
|
} else {
|
||||||
|
loader_b.load_safe(
|
||||||
|
transpose ? short2(tgp_bk, tgp_bn) : short2(tgp_bn, tgp_bk));
|
||||||
|
}
|
||||||
|
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
|
// Multiply and accumulate threadgroup elements
|
||||||
|
mma_op.mma(As, Bs);
|
||||||
|
|
||||||
|
// Prepare for next iteration
|
||||||
|
loader_a.next();
|
||||||
|
loader_b.next();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename mma_t, typename loader_a_t, typename loader_b_t>
|
||||||
|
METAL_FUNC void gemm_loop_finalize(
|
||||||
|
threadgroup T* As,
|
||||||
|
threadgroup T* Bs,
|
||||||
|
thread mma_t& mma_op,
|
||||||
|
thread loader_a_t& loader_a,
|
||||||
|
thread loader_b_t& loader_b,
|
||||||
|
const short2 tile_a,
|
||||||
|
const short2 tile_b) {
|
||||||
|
loader_a.load_safe(tile_a);
|
||||||
|
loader_b.load_safe(tile_b);
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
mma_op.mma(As, Bs);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <
|
||||||
|
typename T,
|
||||||
|
int group_size,
|
||||||
|
int bits,
|
||||||
|
int BM,
|
||||||
|
int BN,
|
||||||
|
int BK,
|
||||||
|
int WM,
|
||||||
|
int WN,
|
||||||
|
bool transpose>
|
||||||
|
[[kernel]] void gather_qmm_rhs(
|
||||||
|
const device T* x [[buffer(0)]],
|
||||||
|
const device uint32_t* w [[buffer(1)]],
|
||||||
|
const device T* scales [[buffer(2)]],
|
||||||
|
const device T* biases [[buffer(3)]],
|
||||||
|
const device uint32_t* indices [[buffer(4)]],
|
||||||
|
device T* y [[buffer(5)]],
|
||||||
|
const constant int& M [[buffer(6)]],
|
||||||
|
const constant int& N [[buffer(7)]],
|
||||||
|
const constant int& K [[buffer(8)]],
|
||||||
|
uint3 tid [[threadgroup_position_in_grid]],
|
||||||
|
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||||
|
uint simd_lane_id [[thread_index_in_simdgroup]]) {
|
||||||
|
constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
|
||||||
|
constexpr int BK_padded = (BK + 16 / sizeof(T));
|
||||||
|
constexpr int BN_padded = (BN + 16 / sizeof(T));
|
||||||
|
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
|
||||||
|
constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3;
|
||||||
|
|
||||||
|
using mma_t = mlx::steel::BlockMMA<
|
||||||
|
T,
|
||||||
|
T,
|
||||||
|
BM,
|
||||||
|
BN,
|
||||||
|
BK,
|
||||||
|
WM,
|
||||||
|
WN,
|
||||||
|
false,
|
||||||
|
transpose,
|
||||||
|
BK_padded,
|
||||||
|
transpose ? BK_padded : BN_padded>;
|
||||||
|
using loader_x_t =
|
||||||
|
mlx::steel::BlockLoader<T, BM, BK, BK_padded, 1, WM * WN * SIMD_SIZE>;
|
||||||
|
using loader_w_t = QuantizedBlockLoader<
|
||||||
|
T,
|
||||||
|
transpose ? BN : BK,
|
||||||
|
transpose ? BK : BN,
|
||||||
|
transpose ? BK_padded : BN_padded,
|
||||||
|
transpose,
|
||||||
|
WM * WN * SIMD_SIZE,
|
||||||
|
group_size,
|
||||||
|
bits>;
|
||||||
|
|
||||||
|
threadgroup T Xs[BM * BK_padded];
|
||||||
|
threadgroup T Ws[transpose ? BN * BK_padded : BK * BN_padded];
|
||||||
|
|
||||||
|
// Compute the block
|
||||||
|
const int K_w = K * bytes_per_pack / pack_factor;
|
||||||
|
const int K_g = K / group_size;
|
||||||
|
const int N_w = N * bytes_per_pack / pack_factor;
|
||||||
|
const int N_g = N / group_size;
|
||||||
|
const int K_it = K / BK;
|
||||||
|
const size_t stride_w = transpose ? N * K_w : K * N_w;
|
||||||
|
const size_t stride_s = transpose ? N * K_g : K * N_g;
|
||||||
|
const int y_row = tid.y * BM;
|
||||||
|
const int y_col = tid.x * BN;
|
||||||
|
const size_t y_row_long = size_t(y_row);
|
||||||
|
const size_t y_col_long = size_t(y_col);
|
||||||
|
|
||||||
|
// Prepare threadgroup bounds
|
||||||
|
const short tgp_bm = align_M ? BM : short(min(BM, M - y_row));
|
||||||
|
const short tgp_bn = align_N ? BN : short(min(BN, N - y_col));
|
||||||
|
|
||||||
|
// Calculate the final tiles in the case that K is not aligned
|
||||||
|
const int k_remain = K - K_it * BK;
|
||||||
|
const short2 tile_x = short2(k_remain, tgp_bm);
|
||||||
|
const short2 tile_w =
|
||||||
|
transpose ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);
|
||||||
|
|
||||||
|
// Move x and output to the correct block
|
||||||
|
auto wl = (const device uint8_t*)w;
|
||||||
|
x += y_row_long * K;
|
||||||
|
y += y_row_long * N + y_col_long;
|
||||||
|
wl += transpose ? y_col_long * K_w : y_col * bytes_per_pack / pack_factor;
|
||||||
|
scales += transpose ? y_col_long * K_g : y_col / group_size;
|
||||||
|
biases += transpose ? y_col_long * K_g : y_col / group_size;
|
||||||
|
|
||||||
|
// Do as many matmuls as necessary
|
||||||
|
uint32_t index;
|
||||||
|
short offset;
|
||||||
|
uint32_t index_next = indices[y_row];
|
||||||
|
short offset_next = 0;
|
||||||
|
int n = 0;
|
||||||
|
while (n < tgp_bm) {
|
||||||
|
n++;
|
||||||
|
offset = offset_next;
|
||||||
|
index = index_next;
|
||||||
|
offset_next = tgp_bm;
|
||||||
|
for (; n < tgp_bm; n++) {
|
||||||
|
if (indices[y_row + n] != index) {
|
||||||
|
offset_next = n;
|
||||||
|
index_next = indices[y_row + n];
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
threadgroup_barrier(mem_flags::mem_none);
|
||||||
|
|
||||||
|
// Prepare threadgroup mma operation
|
||||||
|
thread mma_t mma_op(simd_group_id, simd_lane_id);
|
||||||
|
|
||||||
|
// Prepare threadgroup loading operations
|
||||||
|
thread loader_x_t loader_x(x, K, Xs, simd_group_id, simd_lane_id);
|
||||||
|
thread loader_w_t loader_w(
|
||||||
|
wl + index * stride_w,
|
||||||
|
scales + index * stride_s,
|
||||||
|
biases + index * stride_s,
|
||||||
|
transpose ? K : N,
|
||||||
|
Ws,
|
||||||
|
simd_group_id,
|
||||||
|
simd_lane_id);
|
||||||
|
|
||||||
|
// Matrices are all aligned check nothing
|
||||||
|
if (align_M && align_N) {
|
||||||
|
gemm_loop_aligned(Xs, Ws, mma_op, loader_x, loader_w, K_it);
|
||||||
|
if (!align_K) {
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
gemm_loop_finalize(Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store results to device memory
|
||||||
|
if (offset_next - offset == BM) {
|
||||||
|
mma_op.store_result(y, N);
|
||||||
|
} else {
|
||||||
|
mma_op.store_result_slice(
|
||||||
|
y, N, short2(0, offset), short2(BN, offset_next));
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Tile aligned so check outside of the hot loop
|
||||||
|
if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) {
|
||||||
|
gemm_loop_aligned(Xs, Ws, mma_op, loader_x, loader_w, K_it);
|
||||||
|
if (!align_K) {
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
gemm_loop_finalize(
|
||||||
|
Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store results to device memory
|
||||||
|
if (offset_next - offset == BM) {
|
||||||
|
mma_op.store_result(y, N);
|
||||||
|
} else {
|
||||||
|
mma_op.store_result_slice(
|
||||||
|
y, N, short2(0, offset), short2(BN, offset_next));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Tile partially aligned check rows
|
||||||
|
else if (align_N || tgp_bn == BN) {
|
||||||
|
gemm_loop_unaligned<false, true, transpose>(
|
||||||
|
Xs, Ws, mma_op, loader_x, loader_w, K_it, tgp_bm, tgp_bn, BK);
|
||||||
|
if (!align_K) {
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
gemm_loop_finalize(
|
||||||
|
Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w);
|
||||||
|
}
|
||||||
|
mma_op.store_result_slice(
|
||||||
|
y, N, short2(0, offset), short2(BN, offset_next));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Tile partially aligned check cols
|
||||||
|
else if (align_M || tgp_bm == BM) {
|
||||||
|
gemm_loop_unaligned<true, false, transpose>(
|
||||||
|
Xs, Ws, mma_op, loader_x, loader_w, K_it, tgp_bm, tgp_bn, BK);
|
||||||
|
if (!align_K) {
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
gemm_loop_finalize(
|
||||||
|
Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w);
|
||||||
|
}
|
||||||
|
mma_op.store_result_slice(
|
||||||
|
y, N, short2(0, offset), short2(tgp_bn, offset_next));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Nothing aligned so check both rows and cols
|
||||||
|
else {
|
||||||
|
gemm_loop_unaligned<false, false, transpose>(
|
||||||
|
Xs, Ws, mma_op, loader_x, loader_w, K_it, tgp_bm, tgp_bn, BK);
|
||||||
|
if (!align_K) {
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
gemm_loop_finalize(
|
||||||
|
Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w);
|
||||||
|
}
|
||||||
|
mma_op.store_result_slice(
|
||||||
|
y, N, short2(0, offset), short2(tgp_bn, offset_next));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
template <typename T, const int group_size, const int bits>
|
template <typename T, const int group_size, const int bits>
|
||||||
[[kernel]] void affine_quantize(
|
[[kernel]] void affine_quantize(
|
||||||
const device T* w [[buffer(0)]],
|
const device T* w [[buffer(0)]],
|
||||||
|
@ -60,6 +60,20 @@
|
|||||||
bits, \
|
bits, \
|
||||||
split_k)
|
split_k)
|
||||||
|
|
||||||
|
#define instantiate_gather_qmm_rhs(func, name, type, group_size, bits, bm, bn, bk, wm, wn, transpose) \
|
||||||
|
instantiate_kernel( \
|
||||||
|
#name "_" #type "_gs_" #group_size "_b_" #bits "_bm_" #bm "_bn_" #bn "_bk_" #bk "_wm_" #wm "_wn_" #wn, \
|
||||||
|
func, \
|
||||||
|
type, \
|
||||||
|
group_size, \
|
||||||
|
bits, \
|
||||||
|
bm, \
|
||||||
|
bn, \
|
||||||
|
bk, \
|
||||||
|
wm, \
|
||||||
|
wn, \
|
||||||
|
transpose)
|
||||||
|
|
||||||
#define instantiate_quantized_batched_wrap(name, type, group_size, bits) \
|
#define instantiate_quantized_batched_wrap(name, type, group_size, bits) \
|
||||||
instantiate_quantized_batched(name, type, group_size, bits, 1) \
|
instantiate_quantized_batched(name, type, group_size, bits, 1) \
|
||||||
instantiate_quantized_batched(name, type, group_size, bits, 0)
|
instantiate_quantized_batched(name, type, group_size, bits, 0)
|
||||||
@ -73,14 +87,14 @@
|
|||||||
#define instantiate_quantized_all_single(type, group_size, bits) \
|
#define instantiate_quantized_all_single(type, group_size, bits) \
|
||||||
instantiate_quantized(affine_quantize, type, group_size, bits) \
|
instantiate_quantized(affine_quantize, type, group_size, bits) \
|
||||||
instantiate_quantized(affine_dequantize, type, group_size, bits) \
|
instantiate_quantized(affine_dequantize, type, group_size, bits) \
|
||||||
instantiate_quantized(bs_qmv_fast, type, group_size, bits) \
|
instantiate_quantized(gather_qmv_fast, type, group_size, bits) \
|
||||||
instantiate_quantized(bs_qmv, type, group_size, bits) \
|
instantiate_quantized(gather_qmv, type, group_size, bits) \
|
||||||
instantiate_quantized(bs_qvm, type, group_size, bits) \
|
instantiate_quantized(gather_qvm, type, group_size, bits) \
|
||||||
instantiate_quantized(bs_qmm_n, type, group_size, bits)
|
instantiate_quantized(gather_qmm_n, type, group_size, bits)
|
||||||
|
|
||||||
#define instantiate_quantized_all_aligned(type, group_size, bits) \
|
#define instantiate_quantized_all_aligned(type, group_size, bits) \
|
||||||
instantiate_quantized_aligned(bs_qmm_t, type, group_size, bits, true) \
|
instantiate_quantized_aligned(gather_qmm_t, type, group_size, bits, true) \
|
||||||
instantiate_quantized_aligned(bs_qmm_t, type, group_size, bits, false) \
|
instantiate_quantized_aligned(gather_qmm_t, type, group_size, bits, false) \
|
||||||
instantiate_quantized_aligned_batched(qmm_t, type, group_size, bits, true, 1) \
|
instantiate_quantized_aligned_batched(qmm_t, type, group_size, bits, true, 1) \
|
||||||
instantiate_quantized_aligned_batched(qmm_t, type, group_size, bits, true, 0) \
|
instantiate_quantized_aligned_batched(qmm_t, type, group_size, bits, true, 0) \
|
||||||
instantiate_quantized_aligned_batched(qmm_t, type, group_size, bits, false, 1) \
|
instantiate_quantized_aligned_batched(qmm_t, type, group_size, bits, false, 1) \
|
||||||
@ -96,12 +110,17 @@
|
|||||||
instantiate_quantized_split_k(qvm_split_k, type, group_size, bits, 8) \
|
instantiate_quantized_split_k(qvm_split_k, type, group_size, bits, 8) \
|
||||||
instantiate_quantized_split_k(qvm_split_k, type, group_size, bits, 32)
|
instantiate_quantized_split_k(qvm_split_k, type, group_size, bits, 32)
|
||||||
|
|
||||||
|
#define instantiate_quantized_all_rhs(type, group_size, bits) \
|
||||||
|
instantiate_gather_qmm_rhs(gather_qmm_rhs, gather_qmm_rhs_nt, type, group_size, bits, 16, 32, 32, 1, 2, true) \
|
||||||
|
instantiate_gather_qmm_rhs(gather_qmm_rhs, gather_qmm_rhs_nn, type, group_size, bits, 16, 32, 32, 1, 2, false)
|
||||||
|
|
||||||
#define instantiate_quantized_funcs(type, group_size, bits) \
|
#define instantiate_quantized_funcs(type, group_size, bits) \
|
||||||
instantiate_quantized_all_single(type, group_size, bits) \
|
instantiate_quantized_all_single(type, group_size, bits) \
|
||||||
instantiate_quantized_all_batched(type, group_size, bits) \
|
instantiate_quantized_all_batched(type, group_size, bits) \
|
||||||
instantiate_quantized_all_aligned(type, group_size, bits) \
|
instantiate_quantized_all_aligned(type, group_size, bits) \
|
||||||
instantiate_quantized_all_quad(type, group_size, bits) \
|
instantiate_quantized_all_quad(type, group_size, bits) \
|
||||||
instantiate_quantized_all_splitk(type, group_size, bits)
|
instantiate_quantized_all_splitk(type, group_size, bits) \
|
||||||
|
instantiate_quantized_all_rhs(type, group_size, bits)
|
||||||
|
|
||||||
#define instantiate_quantized_types(group_size, bits) \
|
#define instantiate_quantized_types(group_size, bits) \
|
||||||
instantiate_quantized_funcs(float, group_size, bits) \
|
instantiate_quantized_funcs(float, group_size, bits) \
|
||||||
|
@ -2,6 +2,8 @@
|
|||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlx/backend/metal/kernels/binary_ops.h"
|
||||||
|
|
||||||
#define DEFINE_SIMD_SCAN() \
|
#define DEFINE_SIMD_SCAN() \
|
||||||
template <typename T, metal::enable_if_t<sizeof(T) < 8, bool> = true> \
|
template <typename T, metal::enable_if_t<sizeof(T) < 8, bool> = true> \
|
||||||
T simd_scan(T val) { \
|
T simd_scan(T val) { \
|
||||||
@ -139,6 +141,29 @@ struct CumMin {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template <typename U>
|
||||||
|
struct CumLogaddexp {
|
||||||
|
static constexpr constant U init = Limits<U>::min;
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
U operator()(U a, T b) {
|
||||||
|
return LogAddExp{}(a, static_cast<U>(b));
|
||||||
|
}
|
||||||
|
|
||||||
|
U simd_scan(U x) {
|
||||||
|
for (int i = 1; i <= 16; i *= 2) {
|
||||||
|
U other = simd_shuffle_and_fill_up(x, init, i);
|
||||||
|
x = LogAddExp{}(x, other);
|
||||||
|
}
|
||||||
|
return x;
|
||||||
|
}
|
||||||
|
|
||||||
|
U simd_exclusive_scan(U x) {
|
||||||
|
x = simd_scan(x);
|
||||||
|
return simd_shuffle_and_fill_up(x, init, 1);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
template <typename T, typename U, int N_READS, bool reverse>
|
template <typename T, typename U, int N_READS, bool reverse>
|
||||||
inline void load_unsafe(U values[N_READS], const device T* input) {
|
inline void load_unsafe(U values[N_READS], const device T* input) {
|
||||||
if (reverse) {
|
if (reverse) {
|
||||||
|
@ -101,4 +101,7 @@ instantiate_scan_helper(min_int64_int64, int64_t, int64_t, CumMi
|
|||||||
instantiate_scan_helper(min_float16_float16, half, half, CumMin, 4)
|
instantiate_scan_helper(min_float16_float16, half, half, CumMin, 4)
|
||||||
instantiate_scan_helper(min_float32_float32, float, float, CumMin, 4)
|
instantiate_scan_helper(min_float32_float32, float, float, CumMin, 4)
|
||||||
instantiate_scan_helper(min_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumMin, 4)
|
instantiate_scan_helper(min_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumMin, 4)
|
||||||
instantiate_scan_helper(min_complex64_complex64, complex64_t, complex64_t, CumMin, 2) // clang-format on
|
instantiate_scan_helper(min_complex64_complex64, complex64_t, complex64_t, CumMin, 2)
|
||||||
|
instantiate_scan_helper(logaddexp_float16_float16, half, half, CumLogaddexp, 4)
|
||||||
|
instantiate_scan_helper(logaddexp_float32_float32, float, float, CumLogaddexp, 4)
|
||||||
|
instantiate_scan_helper(logaddexp_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumLogaddexp, 4) // clang-format on
|
||||||
|
@ -15,10 +15,6 @@ constant bool align_M [[function_constant(200)]];
|
|||||||
constant bool align_N [[function_constant(201)]];
|
constant bool align_N [[function_constant(201)]];
|
||||||
constant bool align_K [[function_constant(202)]];
|
constant bool align_K [[function_constant(202)]];
|
||||||
|
|
||||||
constant bool do_gather [[function_constant(300)]];
|
|
||||||
|
|
||||||
constant bool gather_bias = do_gather && use_out_source;
|
|
||||||
|
|
||||||
// clang-format off
|
// clang-format off
|
||||||
template <
|
template <
|
||||||
typename T,
|
typename T,
|
||||||
@ -39,12 +35,6 @@ template <
|
|||||||
const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]],
|
const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]],
|
||||||
const constant int* batch_shape [[buffer(6)]],
|
const constant int* batch_shape [[buffer(6)]],
|
||||||
const constant int64_t* batch_strides [[buffer(7)]],
|
const constant int64_t* batch_strides [[buffer(7)]],
|
||||||
const constant uint32_t* lhs_indices [[buffer(10), function_constant(do_gather)]],
|
|
||||||
const constant uint32_t* rhs_indices [[buffer(11), function_constant(do_gather)]],
|
|
||||||
const constant uint32_t* C_indices [[buffer(12), function_constant(gather_bias)]],
|
|
||||||
const constant int* operand_shape [[buffer(13), function_constant(do_gather)]],
|
|
||||||
const constant int64_t* operand_strides [[buffer(14), function_constant(do_gather)]],
|
|
||||||
const constant packed_int3& operand_batch_ndim [[buffer(15), function_constant(do_gather)]],
|
|
||||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||||
uint3 tid [[threadgroup_position_in_grid]],
|
uint3 tid [[threadgroup_position_in_grid]],
|
||||||
@ -81,84 +71,26 @@ template <
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Adjust for batch
|
// Adjust for batch
|
||||||
|
if (has_batch) {
|
||||||
|
const constant auto* A_bstrides = batch_strides;
|
||||||
|
const constant auto* B_bstrides = batch_strides + params->batch_ndim;
|
||||||
|
|
||||||
// Handle gather
|
ulong2 batch_offsets = elem_to_loc_broadcast(
|
||||||
if (do_gather) {
|
tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim);
|
||||||
// Read indices
|
|
||||||
uint32_t indx_A, indx_B, indx_C;
|
|
||||||
|
|
||||||
if (has_batch) {
|
A += batch_offsets.x;
|
||||||
const constant auto* indx_A_bstrides = batch_strides;
|
B += batch_offsets.y;
|
||||||
const constant auto* indx_B_bstrides = batch_strides + params->batch_ndim;
|
|
||||||
|
|
||||||
ulong2 indx_offsets = elem_to_loc_broadcast(
|
|
||||||
tid.z,
|
|
||||||
batch_shape,
|
|
||||||
indx_A_bstrides,
|
|
||||||
indx_B_bstrides,
|
|
||||||
params->batch_ndim);
|
|
||||||
indx_A = lhs_indices[indx_offsets.x];
|
|
||||||
indx_B = rhs_indices[indx_offsets.y];
|
|
||||||
|
|
||||||
if (use_out_source) {
|
|
||||||
const constant auto* indx_C_bstrides =
|
|
||||||
indx_B_bstrides + params->batch_ndim;
|
|
||||||
auto indx_offset_C = elem_to_loc(
|
|
||||||
tid.z, batch_shape, indx_C_bstrides, params->batch_ndim);
|
|
||||||
indx_C = C_indices[indx_offset_C];
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
indx_A = lhs_indices[params->batch_stride_a * tid.z];
|
|
||||||
indx_B = rhs_indices[params->batch_stride_b * tid.z];
|
|
||||||
|
|
||||||
if (use_out_source) {
|
|
||||||
indx_C = C_indices[addmm_params->batch_stride_c * tid.z];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Translate indices to offsets
|
|
||||||
int batch_ndim_A = operand_batch_ndim.x;
|
|
||||||
const constant int* batch_shape_A = operand_shape;
|
|
||||||
const constant auto* batch_strides_A = operand_strides;
|
|
||||||
A += elem_to_loc(indx_A, batch_shape_A, batch_strides_A, batch_ndim_A);
|
|
||||||
|
|
||||||
int batch_ndim_B = operand_batch_ndim.y;
|
|
||||||
const constant int* batch_shape_B = batch_shape_A + batch_ndim_A;
|
|
||||||
const constant auto* batch_strides_B = batch_strides_A + batch_ndim_A;
|
|
||||||
B += elem_to_loc(indx_B, batch_shape_B, batch_strides_B, batch_ndim_B);
|
|
||||||
|
|
||||||
if (use_out_source) {
|
if (use_out_source) {
|
||||||
int batch_ndim_C = operand_batch_ndim.z;
|
const constant auto* C_bstrides = B_bstrides + params->batch_ndim;
|
||||||
const constant int* batch_shape_C = batch_shape_B + batch_ndim_B;
|
C += elem_to_loc(tid.z, batch_shape, C_bstrides, params->batch_ndim);
|
||||||
const constant auto* batch_strides_C = batch_strides_B + batch_ndim_B;
|
|
||||||
C += elem_to_loc(indx_C, batch_shape_C, batch_strides_C, batch_ndim_C);
|
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
A += params->batch_stride_a * tid.z;
|
||||||
|
B += params->batch_stride_b * tid.z;
|
||||||
|
|
||||||
}
|
if (use_out_source) {
|
||||||
|
C += addmm_params->batch_stride_c * tid.z;
|
||||||
// Handle regular batch
|
|
||||||
else {
|
|
||||||
if (has_batch) {
|
|
||||||
const constant auto* A_bstrides = batch_strides;
|
|
||||||
const constant auto* B_bstrides = batch_strides + params->batch_ndim;
|
|
||||||
|
|
||||||
ulong2 batch_offsets = elem_to_loc_broadcast(
|
|
||||||
tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim);
|
|
||||||
|
|
||||||
A += batch_offsets.x;
|
|
||||||
B += batch_offsets.y;
|
|
||||||
|
|
||||||
if (use_out_source) {
|
|
||||||
const constant auto* C_bstrides = B_bstrides + params->batch_ndim;
|
|
||||||
C += elem_to_loc(tid.z, batch_shape, C_bstrides, params->batch_ndim);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
A += params->batch_stride_a * tid.z;
|
|
||||||
B += params->batch_stride_b * tid.z;
|
|
||||||
|
|
||||||
if (use_out_source) {
|
|
||||||
C += addmm_params->batch_stride_c * tid.z;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
459
mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h
Normal file
459
mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h
Normal file
@ -0,0 +1,459 @@
|
|||||||
|
// Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
|
using namespace mlx::steel;
|
||||||
|
|
||||||
|
constant bool has_batch [[function_constant(10)]];
|
||||||
|
constant bool align_M [[function_constant(200)]];
|
||||||
|
constant bool align_N [[function_constant(201)]];
|
||||||
|
constant bool align_K [[function_constant(202)]];
|
||||||
|
|
||||||
|
template <
|
||||||
|
typename T,
|
||||||
|
int BM,
|
||||||
|
int BN,
|
||||||
|
int BK,
|
||||||
|
int WM,
|
||||||
|
int WN,
|
||||||
|
bool transpose_a,
|
||||||
|
bool transpose_b,
|
||||||
|
typename AccumType = float>
|
||||||
|
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void gather_mm_rhs(
|
||||||
|
const device T* A [[buffer(0)]],
|
||||||
|
const device T* B [[buffer(1)]],
|
||||||
|
const device uint32_t* rhs_indices [[buffer(2)]],
|
||||||
|
device T* C [[buffer(3)]],
|
||||||
|
const constant GEMMParams* params [[buffer(4)]],
|
||||||
|
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||||
|
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||||
|
uint3 tid [[threadgroup_position_in_grid]]) {
|
||||||
|
using gemm_kernel = GEMMKernel<
|
||||||
|
T,
|
||||||
|
T,
|
||||||
|
BM,
|
||||||
|
BN,
|
||||||
|
BK,
|
||||||
|
WM,
|
||||||
|
WN,
|
||||||
|
transpose_a,
|
||||||
|
transpose_b,
|
||||||
|
true,
|
||||||
|
true,
|
||||||
|
AccumType>;
|
||||||
|
|
||||||
|
using loader_a_t = typename gemm_kernel::loader_a_t;
|
||||||
|
using loader_b_t = typename gemm_kernel::loader_b_t;
|
||||||
|
using mma_t = typename gemm_kernel::mma_t;
|
||||||
|
|
||||||
|
if (params->tiles_n <= static_cast<int>(tid.x) ||
|
||||||
|
params->tiles_m <= static_cast<int>(tid.y)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Prepare threadgroup memory
|
||||||
|
threadgroup T As[gemm_kernel::tgp_mem_size_a];
|
||||||
|
threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
|
||||||
|
|
||||||
|
// Find the block in A, B, C
|
||||||
|
const int c_row = tid.y * BM;
|
||||||
|
const int c_col = tid.x * BN;
|
||||||
|
const size_t c_row_long = size_t(c_row);
|
||||||
|
const size_t c_col_long = size_t(c_col);
|
||||||
|
|
||||||
|
// Prepare threadgroup bounds
|
||||||
|
const short tgp_bm = align_M ? BM : short(min(BM, params->M - c_row));
|
||||||
|
const short tgp_bn = align_N ? BN : short(min(BN, params->N - c_col));
|
||||||
|
|
||||||
|
A += transpose_a ? c_row_long : c_row_long * params->lda;
|
||||||
|
B += transpose_b ? c_col_long * params->ldb : c_col_long;
|
||||||
|
C += c_row_long * params->ldd + c_col_long;
|
||||||
|
|
||||||
|
// Do as many matmuls as necessary
|
||||||
|
uint32_t index;
|
||||||
|
short offset;
|
||||||
|
uint32_t index_next = rhs_indices[c_row];
|
||||||
|
short offset_next = 0;
|
||||||
|
int n = 0;
|
||||||
|
while (n < tgp_bm) {
|
||||||
|
n++;
|
||||||
|
offset = offset_next;
|
||||||
|
index = index_next;
|
||||||
|
offset_next = tgp_bm;
|
||||||
|
for (; n < tgp_bm; n++) {
|
||||||
|
if (rhs_indices[c_row + n] != index) {
|
||||||
|
offset_next = n;
|
||||||
|
index_next = rhs_indices[c_row + n];
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
threadgroup_barrier(mem_flags::mem_none);
|
||||||
|
|
||||||
|
// Prepare threadgroup mma operation
|
||||||
|
thread mma_t mma_op(simd_group_id, simd_lane_id);
|
||||||
|
|
||||||
|
// Prepare threadgroup loading operations
|
||||||
|
thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
|
||||||
|
thread loader_b_t loader_b(
|
||||||
|
B + index * params->batch_stride_b,
|
||||||
|
params->ldb,
|
||||||
|
Bs,
|
||||||
|
simd_group_id,
|
||||||
|
simd_lane_id);
|
||||||
|
|
||||||
|
// Prepare iterations
|
||||||
|
const int gemm_k_iterations = params->gemm_k_iterations_aligned;
|
||||||
|
|
||||||
|
// Do unaligned K iterations first
|
||||||
|
if (!align_K) {
|
||||||
|
const int k_last = params->gemm_k_iterations_aligned * BK;
|
||||||
|
const int k_remain = params->K - k_last;
|
||||||
|
const size_t k_jump_a =
|
||||||
|
transpose_a ? params->lda * size_t(k_last) : size_t(k_last);
|
||||||
|
const size_t k_jump_b =
|
||||||
|
transpose_b ? size_t(k_last) : params->ldb * size_t(k_last);
|
||||||
|
|
||||||
|
// Move loader source ahead to end
|
||||||
|
loader_a.src += k_jump_a;
|
||||||
|
loader_b.src += k_jump_b;
|
||||||
|
|
||||||
|
// Load tile
|
||||||
|
const short2 tile_dims_A =
|
||||||
|
transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
|
||||||
|
const short2 tile_dims_B =
|
||||||
|
transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);
|
||||||
|
|
||||||
|
loader_a.load_safe(tile_dims_A);
|
||||||
|
loader_b.load_safe(tile_dims_B);
|
||||||
|
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
|
// Do matmul
|
||||||
|
mma_op.mma(As, Bs);
|
||||||
|
|
||||||
|
// Reset source back to start
|
||||||
|
loader_a.src -= k_jump_a;
|
||||||
|
loader_b.src -= k_jump_b;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Matrix level aligned never check
|
||||||
|
if (align_M && align_N) {
|
||||||
|
for (int k = 0; k < gemm_k_iterations; k++) {
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
|
// Load elements into threadgroup
|
||||||
|
loader_a.load_unsafe();
|
||||||
|
loader_b.load_unsafe();
|
||||||
|
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
|
// Multiply and accumulate threadgroup elements
|
||||||
|
mma_op.mma(As, Bs);
|
||||||
|
|
||||||
|
// Prepare for next iteration
|
||||||
|
loader_a.next();
|
||||||
|
loader_b.next();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store results to device memory
|
||||||
|
if (offset_next - offset == BM) {
|
||||||
|
mma_op.store_result(C, params->ldd);
|
||||||
|
} else {
|
||||||
|
mma_op.store_result_slice(
|
||||||
|
C, params->ldd, short2(0, offset), short2(BN, offset_next));
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
const short lbk = 0;
|
||||||
|
|
||||||
|
// Tile aligned don't check
|
||||||
|
if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) {
|
||||||
|
gemm_kernel::gemm_loop(
|
||||||
|
As,
|
||||||
|
Bs,
|
||||||
|
gemm_k_iterations,
|
||||||
|
loader_a,
|
||||||
|
loader_b,
|
||||||
|
mma_op,
|
||||||
|
tgp_bm,
|
||||||
|
tgp_bn,
|
||||||
|
lbk,
|
||||||
|
LoopAlignment<true, true, true>{});
|
||||||
|
if (offset_next - offset == BM) {
|
||||||
|
mma_op.store_result(C, params->ldd);
|
||||||
|
} else {
|
||||||
|
mma_op.store_result_slice(
|
||||||
|
C, params->ldd, short2(0, offset), short2(BN, offset_next));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Tile partially aligned check rows
|
||||||
|
else if (align_N || tgp_bn == BN) {
|
||||||
|
gemm_kernel::gemm_loop(
|
||||||
|
As,
|
||||||
|
Bs,
|
||||||
|
gemm_k_iterations,
|
||||||
|
loader_a,
|
||||||
|
loader_b,
|
||||||
|
mma_op,
|
||||||
|
tgp_bm,
|
||||||
|
tgp_bn,
|
||||||
|
lbk,
|
||||||
|
LoopAlignment<false, true, true>{});
|
||||||
|
mma_op.store_result_slice(
|
||||||
|
C, params->ldd, short2(0, offset), short2(BN, offset_next));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Tile partially aligned check cols
|
||||||
|
else if (align_M || tgp_bm == BM) {
|
||||||
|
gemm_kernel::gemm_loop(
|
||||||
|
As,
|
||||||
|
Bs,
|
||||||
|
gemm_k_iterations,
|
||||||
|
loader_a,
|
||||||
|
loader_b,
|
||||||
|
mma_op,
|
||||||
|
tgp_bm,
|
||||||
|
tgp_bn,
|
||||||
|
lbk,
|
||||||
|
LoopAlignment<true, false, true>{});
|
||||||
|
mma_op.store_result_slice(
|
||||||
|
C, params->ldd, short2(0, offset), short2(tgp_bn, offset_next));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Nothing aligned so check both rows and cols
|
||||||
|
else {
|
||||||
|
gemm_kernel::gemm_loop(
|
||||||
|
As,
|
||||||
|
Bs,
|
||||||
|
gemm_k_iterations,
|
||||||
|
loader_a,
|
||||||
|
loader_b,
|
||||||
|
mma_op,
|
||||||
|
tgp_bm,
|
||||||
|
tgp_bn,
|
||||||
|
lbk,
|
||||||
|
LoopAlignment<false, false, true>{});
|
||||||
|
mma_op.store_result_slice(
|
||||||
|
C, params->ldd, short2(0, offset), short2(tgp_bn, offset_next));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <
|
||||||
|
typename T,
|
||||||
|
int BM,
|
||||||
|
int BN,
|
||||||
|
int BK,
|
||||||
|
int WM,
|
||||||
|
int WN,
|
||||||
|
bool transpose_a,
|
||||||
|
bool transpose_b,
|
||||||
|
typename AccumType = float>
|
||||||
|
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void gather_mm(
|
||||||
|
const device T* A [[buffer(0)]],
|
||||||
|
const device T* B [[buffer(1)]],
|
||||||
|
const device uint32_t* lhs_indices [[buffer(2)]],
|
||||||
|
const device uint32_t* rhs_indices [[buffer(3)]],
|
||||||
|
device T* C [[buffer(4)]],
|
||||||
|
const constant GEMMParams* params [[buffer(5)]],
|
||||||
|
const constant int* indices_shape [[buffer(6)]],
|
||||||
|
const constant int64_t* lhs_strides [[buffer(7)]],
|
||||||
|
const constant int64_t* rhs_strides [[buffer(8)]],
|
||||||
|
const constant int& batch_ndim_a [[buffer(9)]],
|
||||||
|
const constant int* batch_shape_a [[buffer(10)]],
|
||||||
|
const constant int64_t* batch_strides_a [[buffer(11)]],
|
||||||
|
const constant int& batch_ndim_b [[buffer(12)]],
|
||||||
|
const constant int* batch_shape_b [[buffer(13)]],
|
||||||
|
const constant int64_t* batch_strides_b [[buffer(14)]],
|
||||||
|
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||||
|
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||||
|
uint3 tid [[threadgroup_position_in_grid]]) {
|
||||||
|
using gemm_kernel = GEMMKernel<
|
||||||
|
T,
|
||||||
|
T,
|
||||||
|
BM,
|
||||||
|
BN,
|
||||||
|
BK,
|
||||||
|
WM,
|
||||||
|
WN,
|
||||||
|
transpose_a,
|
||||||
|
transpose_b,
|
||||||
|
true,
|
||||||
|
true,
|
||||||
|
AccumType>;
|
||||||
|
|
||||||
|
using loader_a_t = typename gemm_kernel::loader_a_t;
|
||||||
|
using loader_b_t = typename gemm_kernel::loader_b_t;
|
||||||
|
using mma_t = typename gemm_kernel::mma_t;
|
||||||
|
|
||||||
|
if (params->tiles_n <= static_cast<int>(tid.x) ||
|
||||||
|
params->tiles_m <= static_cast<int>(tid.y)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Move A and B to the locations pointed by lhs_indices and rhs_indices.
|
||||||
|
uint32_t indx_A, indx_B;
|
||||||
|
if (has_batch) {
|
||||||
|
ulong2 indices_offsets = elem_to_loc_broadcast(
|
||||||
|
tid.z, indices_shape, lhs_strides, rhs_strides, params->batch_ndim);
|
||||||
|
indx_A = lhs_indices[indices_offsets.x];
|
||||||
|
indx_B = rhs_indices[indices_offsets.y];
|
||||||
|
} else {
|
||||||
|
indx_A = lhs_indices[params->batch_stride_a * tid.z];
|
||||||
|
indx_B = rhs_indices[params->batch_stride_b * tid.z];
|
||||||
|
}
|
||||||
|
A += elem_to_loc(indx_A, batch_shape_a, batch_strides_a, batch_ndim_a);
|
||||||
|
B += elem_to_loc(indx_B, batch_shape_b, batch_strides_b, batch_ndim_b);
|
||||||
|
C += params->batch_stride_d * tid.z;
|
||||||
|
|
||||||
|
// Prepare threadgroup memory
|
||||||
|
threadgroup T As[gemm_kernel::tgp_mem_size_a];
|
||||||
|
threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
|
||||||
|
|
||||||
|
// Just make sure everybody's finished with the indexing math above.
|
||||||
|
threadgroup_barrier(mem_flags::mem_none);
|
||||||
|
|
||||||
|
// Find block in A, B, C
|
||||||
|
const int c_row = tid.y * BM;
|
||||||
|
const int c_col = tid.x * BN;
|
||||||
|
const size_t c_row_long = size_t(c_row);
|
||||||
|
const size_t c_col_long = size_t(c_col);
|
||||||
|
|
||||||
|
A += transpose_a ? c_row_long : c_row_long * params->lda;
|
||||||
|
B += transpose_b ? c_col_long * params->ldb : c_col_long;
|
||||||
|
C += c_row_long * params->ldd + c_col_long;
|
||||||
|
|
||||||
|
// Prepare threadgroup mma operation
|
||||||
|
thread mma_t mma_op(simd_group_id, simd_lane_id);
|
||||||
|
|
||||||
|
// Prepare threadgroup loading operations
|
||||||
|
thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
|
||||||
|
thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);
|
||||||
|
|
||||||
|
// Prepare threadgroup bounds
|
||||||
|
const short tgp_bm = align_M ? BM : short(min(BM, params->M - c_row));
|
||||||
|
const short tgp_bn = align_N ? BN : short(min(BN, params->N - c_col));
|
||||||
|
|
||||||
|
// Prepare iterations
|
||||||
|
int gemm_k_iterations = params->gemm_k_iterations_aligned;
|
||||||
|
|
||||||
|
// Do unaligned K iterations first
|
||||||
|
if (!align_K) {
|
||||||
|
const int k_last = params->gemm_k_iterations_aligned * BK;
|
||||||
|
const int k_remain = params->K - k_last;
|
||||||
|
const size_t k_jump_a =
|
||||||
|
transpose_a ? params->lda * size_t(k_last) : size_t(k_last);
|
||||||
|
const size_t k_jump_b =
|
||||||
|
transpose_b ? size_t(k_last) : params->ldb * size_t(k_last);
|
||||||
|
|
||||||
|
// Move loader source ahead to end
|
||||||
|
loader_a.src += k_jump_a;
|
||||||
|
loader_b.src += k_jump_b;
|
||||||
|
|
||||||
|
// Load tile
|
||||||
|
const short2 tile_dims_A =
|
||||||
|
transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
|
||||||
|
const short2 tile_dims_B =
|
||||||
|
transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);
|
||||||
|
|
||||||
|
loader_a.load_safe(tile_dims_A);
|
||||||
|
loader_b.load_safe(tile_dims_B);
|
||||||
|
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
|
// Do matmul
|
||||||
|
mma_op.mma(As, Bs);
|
||||||
|
|
||||||
|
// Reset source back to start
|
||||||
|
loader_a.src -= k_jump_a;
|
||||||
|
loader_b.src -= k_jump_b;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Matrix level aligned never check
|
||||||
|
if (align_M && align_N) {
|
||||||
|
for (int k = 0; k < gemm_k_iterations; k++) {
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
|
// Load elements into threadgroup
|
||||||
|
loader_a.load_unsafe();
|
||||||
|
loader_b.load_unsafe();
|
||||||
|
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
|
// Multiply and accumulate threadgroup elements
|
||||||
|
mma_op.mma(As, Bs);
|
||||||
|
|
||||||
|
// Prepare for next iteration
|
||||||
|
loader_a.next();
|
||||||
|
loader_b.next();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store results to device memory
|
||||||
|
mma_op.store_result(C, params->ldd);
|
||||||
|
} else {
|
||||||
|
const short lbk = 0;
|
||||||
|
|
||||||
|
// Tile aligned don't check
|
||||||
|
if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) {
|
||||||
|
gemm_kernel::gemm_loop(
|
||||||
|
As,
|
||||||
|
Bs,
|
||||||
|
gemm_k_iterations,
|
||||||
|
loader_a,
|
||||||
|
loader_b,
|
||||||
|
mma_op,
|
||||||
|
tgp_bm,
|
||||||
|
tgp_bn,
|
||||||
|
lbk,
|
||||||
|
LoopAlignment<true, true, true>{});
|
||||||
|
mma_op.store_result(C, params->ldd);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Tile partially aligned check rows
|
||||||
|
else if (align_N || tgp_bn == BN) {
|
||||||
|
gemm_kernel::gemm_loop(
|
||||||
|
As,
|
||||||
|
Bs,
|
||||||
|
gemm_k_iterations,
|
||||||
|
loader_a,
|
||||||
|
loader_b,
|
||||||
|
mma_op,
|
||||||
|
tgp_bm,
|
||||||
|
tgp_bn,
|
||||||
|
lbk,
|
||||||
|
LoopAlignment<false, true, true>{});
|
||||||
|
mma_op.store_result_safe(C, params->ldd, short2(tgp_bn, tgp_bm));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Tile partially aligned check cols
|
||||||
|
else if (align_M || tgp_bm == BM) {
|
||||||
|
gemm_kernel::gemm_loop(
|
||||||
|
As,
|
||||||
|
Bs,
|
||||||
|
gemm_k_iterations,
|
||||||
|
loader_a,
|
||||||
|
loader_b,
|
||||||
|
mma_op,
|
||||||
|
tgp_bm,
|
||||||
|
tgp_bn,
|
||||||
|
lbk,
|
||||||
|
LoopAlignment<true, false, true>{});
|
||||||
|
mma_op.store_result_safe(C, params->ldd, short2(tgp_bn, tgp_bm));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Nothing aligned so check both rows and cols
|
||||||
|
else {
|
||||||
|
gemm_kernel::gemm_loop(
|
||||||
|
As,
|
||||||
|
Bs,
|
||||||
|
gemm_k_iterations,
|
||||||
|
loader_a,
|
||||||
|
loader_b,
|
||||||
|
mma_op,
|
||||||
|
tgp_bm,
|
||||||
|
tgp_bn,
|
||||||
|
lbk,
|
||||||
|
LoopAlignment<false, false, true>{});
|
||||||
|
mma_op.store_result_safe(C, params->ldd, short2(tgp_bn, tgp_bm));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,59 @@
|
|||||||
|
// Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
|
#include "mlx/backend/metal/kernels/utils.h"
|
||||||
|
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
|
||||||
|
#include "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h"
|
||||||
|
|
||||||
|
#define instantiate_gather_mm_rhs(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||||
|
instantiate_kernel( \
|
||||||
|
"steel_gather_mm_rhs_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn \
|
||||||
|
"_bk" #bk "_wm" #wm "_wn" #wn, \
|
||||||
|
gather_mm_rhs, \
|
||||||
|
itype, \
|
||||||
|
bm, \
|
||||||
|
bn, \
|
||||||
|
bk, \
|
||||||
|
wm, \
|
||||||
|
wn, \
|
||||||
|
trans_a, \
|
||||||
|
trans_b, \
|
||||||
|
float)
|
||||||
|
|
||||||
|
#define instantiate_gather_mm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||||
|
instantiate_kernel( \
|
||||||
|
"steel_gather_mm_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn \
|
||||||
|
"_bk" #bk "_wm" #wm "_wn" #wn, \
|
||||||
|
gather_mm, \
|
||||||
|
itype, \
|
||||||
|
bm, \
|
||||||
|
bn, \
|
||||||
|
bk, \
|
||||||
|
wm, \
|
||||||
|
wn, \
|
||||||
|
trans_a, \
|
||||||
|
trans_b, \
|
||||||
|
float)
|
||||||
|
|
||||||
|
#define instantiate_gather_mm_rhs_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||||
|
instantiate_gather_mm_rhs(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||||
|
instantiate_gather_mm_rhs(nt, false, true, iname, itype, oname, otype, bm, bn, bk, wm, wn)
|
||||||
|
|
||||||
|
#define instantiate_gather_mm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||||
|
instantiate_gather_mm(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||||
|
instantiate_gather_mm(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||||
|
instantiate_gather_mm(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||||
|
instantiate_gather_mm(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn)
|
||||||
|
|
||||||
|
#define instantiate_gather_mm_shapes_helper(iname, itype, oname, otype) \
|
||||||
|
instantiate_gather_mm_rhs_transpose_helper(iname, itype, oname, otype, 16, 64, 16, 1, 2) \
|
||||||
|
instantiate_gather_mm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) \
|
||||||
|
instantiate_gather_mm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 1, 2) \
|
||||||
|
instantiate_gather_mm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \
|
||||||
|
instantiate_gather_mm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 1, 2) \
|
||||||
|
instantiate_gather_mm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2)
|
||||||
|
// clang-format on
|
||||||
|
|
||||||
|
instantiate_gather_mm_shapes_helper(float16, half, float16, half);
|
||||||
|
instantiate_gather_mm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t);
|
||||||
|
instantiate_gather_mm_shapes_helper(float32, float, float32, float);
|
@ -142,6 +142,42 @@ struct BaseMMAFrag<T, 8, 8> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <
|
||||||
|
typename DstPtrType,
|
||||||
|
typename StrX,
|
||||||
|
typename StrY,
|
||||||
|
typename StartX,
|
||||||
|
typename StopX,
|
||||||
|
typename StartY,
|
||||||
|
typename StopY,
|
||||||
|
typename OffX,
|
||||||
|
typename OffY>
|
||||||
|
METAL_FUNC static constexpr void store_slice(
|
||||||
|
const thread frag_type& src,
|
||||||
|
DstPtrType dst,
|
||||||
|
StrX str_x,
|
||||||
|
StrY str_y,
|
||||||
|
StartX start_x,
|
||||||
|
StopX stop_x,
|
||||||
|
StartY start_y,
|
||||||
|
StopY stop_y,
|
||||||
|
OffX off_x = Int<0>{},
|
||||||
|
OffY off_y = Int<0>{}) {
|
||||||
|
using U = pointer_element_t<DstPtrType>;
|
||||||
|
|
||||||
|
STEEL_PRAGMA_UNROLL
|
||||||
|
for (short i = 0; i < kElemRows; i++) {
|
||||||
|
STEEL_PRAGMA_UNROLL
|
||||||
|
for (short j = 0; j < kElemCols; j++) {
|
||||||
|
if ((off_x + i) < stop_x && (off_x + i) >= start_x &&
|
||||||
|
(off_y + j) < stop_y && (off_y + j) >= start_y) {
|
||||||
|
dst[(off_x + i) * str_x + (off_y + j) * str_y] =
|
||||||
|
static_cast<U>(src[i * kElemCols + j]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
METAL_FUNC static constexpr void mma(
|
METAL_FUNC static constexpr void mma(
|
||||||
thread frag_type& D,
|
thread frag_type& D,
|
||||||
thread frag_type& A,
|
thread frag_type& A,
|
||||||
@ -335,6 +371,31 @@ struct MMATile {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename U, int w_x, int w_y>
|
||||||
|
METAL_FUNC void store_slice(
|
||||||
|
device U* dst,
|
||||||
|
const int ld,
|
||||||
|
const short2 start,
|
||||||
|
const short2 stop) const {
|
||||||
|
STEEL_PRAGMA_UNROLL
|
||||||
|
for (int i = 0; i < kTileRows; ++i) {
|
||||||
|
STEEL_PRAGMA_UNROLL
|
||||||
|
for (int j = 0; j < kTileCols; ++j) {
|
||||||
|
MMAFrag_t::store_slice(
|
||||||
|
frag_at(i, j),
|
||||||
|
dst,
|
||||||
|
ld,
|
||||||
|
Int<1>{},
|
||||||
|
start.y,
|
||||||
|
stop.y,
|
||||||
|
start.x,
|
||||||
|
stop.x,
|
||||||
|
(i * kFragRows) * w_x,
|
||||||
|
(j * kFragCols) * w_y);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename T, typename U, int M, int N, int K>
|
template <typename T, typename U, int M, int N, int K>
|
||||||
@ -474,6 +535,26 @@ struct BlockMMA {
|
|||||||
Ctile.template store<U, WM, WN>(D, ldd);
|
Ctile.template store<U, WM, WN>(D, ldd);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
METAL_FUNC void
|
||||||
|
store_result_slice(device U* D, const int ldd, short2 start, short2 stop) {
|
||||||
|
// Apply epilogue
|
||||||
|
STEEL_PRAGMA_UNROLL
|
||||||
|
for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) {
|
||||||
|
Ctile.elems()[i] = Epilogue::apply(Ctile.elems()[i]);
|
||||||
|
}
|
||||||
|
|
||||||
|
D += sm * ldd + sn;
|
||||||
|
start -= short2(sn, sm);
|
||||||
|
stop -= short2(sn, sm);
|
||||||
|
|
||||||
|
// TODO: Check the start as well
|
||||||
|
if (stop.y <= 0 || stop.x <= 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ctile.template store_slice<U, WM, WN>(D, ldd, start, stop);
|
||||||
|
}
|
||||||
|
|
||||||
METAL_FUNC void
|
METAL_FUNC void
|
||||||
store_result_safe(device U* D, const int ldd, short2 dst_tile_dims) {
|
store_result_safe(device U* D, const int ldd, short2 dst_tile_dims) {
|
||||||
// Apply epilogue
|
// Apply epilogue
|
||||||
|
@ -69,6 +69,9 @@ instantiate_unary_float(Round)
|
|||||||
instantiate_unary_int(BitwiseInvert)
|
instantiate_unary_int(BitwiseInvert)
|
||||||
|
|
||||||
instantiate_unary_all_same(Abs, complex64, complex64_t)
|
instantiate_unary_all_same(Abs, complex64, complex64_t)
|
||||||
|
instantiate_unary_all_same(ArcCos, complex64, complex64_t)
|
||||||
|
instantiate_unary_all_same(ArcSin, complex64, complex64_t)
|
||||||
|
instantiate_unary_all_same(ArcTan, complex64, complex64_t)
|
||||||
instantiate_unary_all_same(Conjugate, complex64, complex64_t)
|
instantiate_unary_all_same(Conjugate, complex64, complex64_t)
|
||||||
instantiate_unary_all_same(Cos, complex64, complex64_t)
|
instantiate_unary_all_same(Cos, complex64, complex64_t)
|
||||||
instantiate_unary_all_same(Cosh, complex64, complex64_t)
|
instantiate_unary_all_same(Cosh, complex64, complex64_t)
|
||||||
@ -80,6 +83,9 @@ instantiate_unary_all_same(Negative, complex64, complex64_t)
|
|||||||
instantiate_unary_all_same(Sign, complex64, complex64_t)
|
instantiate_unary_all_same(Sign, complex64, complex64_t)
|
||||||
instantiate_unary_all_same(Sin, complex64, complex64_t)
|
instantiate_unary_all_same(Sin, complex64, complex64_t)
|
||||||
instantiate_unary_all_same(Sinh, complex64, complex64_t)
|
instantiate_unary_all_same(Sinh, complex64, complex64_t)
|
||||||
|
instantiate_unary_all_same(Square, complex64, complex64_t)
|
||||||
|
instantiate_unary_all_same(Sqrt, complex64, complex64_t)
|
||||||
|
instantiate_unary_all_same(Rsqrt, complex64, complex64_t)
|
||||||
instantiate_unary_all_same(Tan, complex64, complex64_t)
|
instantiate_unary_all_same(Tan, complex64, complex64_t)
|
||||||
instantiate_unary_all_same(Tanh, complex64, complex64_t)
|
instantiate_unary_all_same(Tanh, complex64, complex64_t)
|
||||||
instantiate_unary_all_same(Round, complex64, complex64_t)
|
instantiate_unary_all_same(Round, complex64, complex64_t)
|
||||||
|
@ -17,27 +17,21 @@ struct Abs {
|
|||||||
T operator()(T x) {
|
T operator()(T x) {
|
||||||
return metal::abs(x);
|
return metal::abs(x);
|
||||||
};
|
};
|
||||||
template <>
|
|
||||||
uint8_t operator()(uint8_t x) {
|
uint8_t operator()(uint8_t x) {
|
||||||
return x;
|
return x;
|
||||||
};
|
};
|
||||||
template <>
|
|
||||||
uint16_t operator()(uint16_t x) {
|
uint16_t operator()(uint16_t x) {
|
||||||
return x;
|
return x;
|
||||||
};
|
};
|
||||||
template <>
|
|
||||||
uint32_t operator()(uint32_t x) {
|
uint32_t operator()(uint32_t x) {
|
||||||
return x;
|
return x;
|
||||||
};
|
};
|
||||||
template <>
|
|
||||||
uint64_t operator()(uint64_t x) {
|
uint64_t operator()(uint64_t x) {
|
||||||
return x;
|
return x;
|
||||||
};
|
};
|
||||||
template <>
|
|
||||||
bool operator()(bool x) {
|
bool operator()(bool x) {
|
||||||
return x;
|
return x;
|
||||||
};
|
};
|
||||||
template <>
|
|
||||||
complex64_t operator()(complex64_t x) {
|
complex64_t operator()(complex64_t x) {
|
||||||
return {metal::precise::sqrt(x.real * x.real + x.imag * x.imag), 0};
|
return {metal::precise::sqrt(x.real * x.real + x.imag * x.imag), 0};
|
||||||
};
|
};
|
||||||
@ -48,6 +42,8 @@ struct ArcCos {
|
|||||||
T operator()(T x) {
|
T operator()(T x) {
|
||||||
return metal::precise::acos(x);
|
return metal::precise::acos(x);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
complex64_t operator()(complex64_t x);
|
||||||
};
|
};
|
||||||
|
|
||||||
struct ArcCosh {
|
struct ArcCosh {
|
||||||
@ -62,6 +58,8 @@ struct ArcSin {
|
|||||||
T operator()(T x) {
|
T operator()(T x) {
|
||||||
return metal::precise::asin(x);
|
return metal::precise::asin(x);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
complex64_t operator()(complex64_t x);
|
||||||
};
|
};
|
||||||
|
|
||||||
struct ArcSinh {
|
struct ArcSinh {
|
||||||
@ -76,6 +74,8 @@ struct ArcTan {
|
|||||||
T operator()(T x) {
|
T operator()(T x) {
|
||||||
return metal::precise::atan(x);
|
return metal::precise::atan(x);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
complex64_t operator()(complex64_t x);
|
||||||
};
|
};
|
||||||
|
|
||||||
struct ArcTanh {
|
struct ArcTanh {
|
||||||
@ -97,39 +97,30 @@ struct Ceil {
|
|||||||
T operator()(T x) {
|
T operator()(T x) {
|
||||||
return metal::ceil(x);
|
return metal::ceil(x);
|
||||||
};
|
};
|
||||||
template <>
|
|
||||||
int8_t operator()(int8_t x) {
|
int8_t operator()(int8_t x) {
|
||||||
return x;
|
return x;
|
||||||
};
|
};
|
||||||
template <>
|
|
||||||
int16_t operator()(int16_t x) {
|
int16_t operator()(int16_t x) {
|
||||||
return x;
|
return x;
|
||||||
};
|
};
|
||||||
template <>
|
|
||||||
int32_t operator()(int32_t x) {
|
int32_t operator()(int32_t x) {
|
||||||
return x;
|
return x;
|
||||||
};
|
};
|
||||||
template <>
|
|
||||||
int64_t operator()(int64_t x) {
|
int64_t operator()(int64_t x) {
|
||||||
return x;
|
return x;
|
||||||
};
|
};
|
||||||
template <>
|
|
||||||
uint8_t operator()(uint8_t x) {
|
uint8_t operator()(uint8_t x) {
|
||||||
return x;
|
return x;
|
||||||
};
|
};
|
||||||
template <>
|
|
||||||
uint16_t operator()(uint16_t x) {
|
uint16_t operator()(uint16_t x) {
|
||||||
return x;
|
return x;
|
||||||
};
|
};
|
||||||
template <>
|
|
||||||
uint32_t operator()(uint32_t x) {
|
uint32_t operator()(uint32_t x) {
|
||||||
return x;
|
return x;
|
||||||
};
|
};
|
||||||
template <>
|
|
||||||
uint64_t operator()(uint64_t x) {
|
uint64_t operator()(uint64_t x) {
|
||||||
return x;
|
return x;
|
||||||
};
|
};
|
||||||
template <>
|
|
||||||
bool operator()(bool x) {
|
bool operator()(bool x) {
|
||||||
return x;
|
return x;
|
||||||
};
|
};
|
||||||
@ -141,7 +132,6 @@ struct Cos {
|
|||||||
return metal::precise::cos(x);
|
return metal::precise::cos(x);
|
||||||
};
|
};
|
||||||
|
|
||||||
template <>
|
|
||||||
complex64_t operator()(complex64_t x) {
|
complex64_t operator()(complex64_t x) {
|
||||||
return {
|
return {
|
||||||
metal::precise::cos(x.real) * metal::precise::cosh(x.imag),
|
metal::precise::cos(x.real) * metal::precise::cosh(x.imag),
|
||||||
@ -155,7 +145,6 @@ struct Cosh {
|
|||||||
return metal::precise::cosh(x);
|
return metal::precise::cosh(x);
|
||||||
};
|
};
|
||||||
|
|
||||||
template <>
|
|
||||||
complex64_t operator()(complex64_t x) {
|
complex64_t operator()(complex64_t x) {
|
||||||
return {
|
return {
|
||||||
metal::precise::cosh(x.real) * metal::precise::cos(x.imag),
|
metal::precise::cosh(x.real) * metal::precise::cos(x.imag),
|
||||||
@ -188,7 +177,6 @@ struct Exp {
|
|||||||
T operator()(T x) {
|
T operator()(T x) {
|
||||||
return metal::precise::exp(x);
|
return metal::precise::exp(x);
|
||||||
};
|
};
|
||||||
template <>
|
|
||||||
complex64_t operator()(complex64_t x) {
|
complex64_t operator()(complex64_t x) {
|
||||||
auto m = metal::precise::exp(x.real);
|
auto m = metal::precise::exp(x.real);
|
||||||
return {m * metal::precise::cos(x.imag), m * metal::precise::sin(x.imag)};
|
return {m * metal::precise::cos(x.imag), m * metal::precise::sin(x.imag)};
|
||||||
@ -207,39 +195,30 @@ struct Floor {
|
|||||||
T operator()(T x) {
|
T operator()(T x) {
|
||||||
return metal::floor(x);
|
return metal::floor(x);
|
||||||
};
|
};
|
||||||
template <>
|
|
||||||
int8_t operator()(int8_t x) {
|
int8_t operator()(int8_t x) {
|
||||||
return x;
|
return x;
|
||||||
};
|
};
|
||||||
template <>
|
|
||||||
int16_t operator()(int16_t x) {
|
int16_t operator()(int16_t x) {
|
||||||
return x;
|
return x;
|
||||||
};
|
};
|
||||||
template <>
|
|
||||||
int32_t operator()(int32_t x) {
|
int32_t operator()(int32_t x) {
|
||||||
return x;
|
return x;
|
||||||
};
|
};
|
||||||
template <>
|
|
||||||
int64_t operator()(int64_t x) {
|
int64_t operator()(int64_t x) {
|
||||||
return x;
|
return x;
|
||||||
};
|
};
|
||||||
template <>
|
|
||||||
uint8_t operator()(uint8_t x) {
|
uint8_t operator()(uint8_t x) {
|
||||||
return x;
|
return x;
|
||||||
};
|
};
|
||||||
template <>
|
|
||||||
uint16_t operator()(uint16_t x) {
|
uint16_t operator()(uint16_t x) {
|
||||||
return x;
|
return x;
|
||||||
};
|
};
|
||||||
template <>
|
|
||||||
uint32_t operator()(uint32_t x) {
|
uint32_t operator()(uint32_t x) {
|
||||||
return x;
|
return x;
|
||||||
};
|
};
|
||||||
template <>
|
|
||||||
uint64_t operator()(uint64_t x) {
|
uint64_t operator()(uint64_t x) {
|
||||||
return x;
|
return x;
|
||||||
};
|
};
|
||||||
template <>
|
|
||||||
bool operator()(bool x) {
|
bool operator()(bool x) {
|
||||||
return x;
|
return x;
|
||||||
};
|
};
|
||||||
@ -258,7 +237,6 @@ struct Log {
|
|||||||
return metal::precise::log(x);
|
return metal::precise::log(x);
|
||||||
};
|
};
|
||||||
|
|
||||||
template <>
|
|
||||||
complex64_t operator()(complex64_t x) {
|
complex64_t operator()(complex64_t x) {
|
||||||
auto r = metal::precise::log(Abs{}(x).real);
|
auto r = metal::precise::log(Abs{}(x).real);
|
||||||
auto i = metal::precise::atan2(x.imag, x.real);
|
auto i = metal::precise::atan2(x.imag, x.real);
|
||||||
@ -272,7 +250,6 @@ struct Log2 {
|
|||||||
return metal::precise::log2(x);
|
return metal::precise::log2(x);
|
||||||
};
|
};
|
||||||
|
|
||||||
template <>
|
|
||||||
complex64_t operator()(complex64_t x) {
|
complex64_t operator()(complex64_t x) {
|
||||||
auto y = Log{}(x);
|
auto y = Log{}(x);
|
||||||
return {y.real / M_LN2_F, y.imag / M_LN2_F};
|
return {y.real / M_LN2_F, y.imag / M_LN2_F};
|
||||||
@ -285,7 +262,6 @@ struct Log10 {
|
|||||||
return metal::precise::log10(x);
|
return metal::precise::log10(x);
|
||||||
};
|
};
|
||||||
|
|
||||||
template <>
|
|
||||||
complex64_t operator()(complex64_t x) {
|
complex64_t operator()(complex64_t x) {
|
||||||
auto y = Log{}(x);
|
auto y = Log{}(x);
|
||||||
return {y.real / M_LN10_F, y.imag / M_LN10_F};
|
return {y.real / M_LN10_F, y.imag / M_LN10_F};
|
||||||
@ -325,7 +301,6 @@ struct Round {
|
|||||||
T operator()(T x) {
|
T operator()(T x) {
|
||||||
return metal::rint(x);
|
return metal::rint(x);
|
||||||
};
|
};
|
||||||
template <>
|
|
||||||
complex64_t operator()(complex64_t x) {
|
complex64_t operator()(complex64_t x) {
|
||||||
return {metal::rint(x.real), metal::rint(x.imag)};
|
return {metal::rint(x.real), metal::rint(x.imag)};
|
||||||
};
|
};
|
||||||
@ -344,11 +319,9 @@ struct Sign {
|
|||||||
T operator()(T x) {
|
T operator()(T x) {
|
||||||
return (x > T(0)) - (x < T(0));
|
return (x > T(0)) - (x < T(0));
|
||||||
};
|
};
|
||||||
template <>
|
|
||||||
uint32_t operator()(uint32_t x) {
|
uint32_t operator()(uint32_t x) {
|
||||||
return x != 0;
|
return x != 0;
|
||||||
};
|
};
|
||||||
template <>
|
|
||||||
complex64_t operator()(complex64_t x) {
|
complex64_t operator()(complex64_t x) {
|
||||||
if (x == complex64_t(0)) {
|
if (x == complex64_t(0)) {
|
||||||
return x;
|
return x;
|
||||||
@ -364,7 +337,6 @@ struct Sin {
|
|||||||
return metal::precise::sin(x);
|
return metal::precise::sin(x);
|
||||||
};
|
};
|
||||||
|
|
||||||
template <>
|
|
||||||
complex64_t operator()(complex64_t x) {
|
complex64_t operator()(complex64_t x) {
|
||||||
return {
|
return {
|
||||||
metal::precise::sin(x.real) * metal::precise::cosh(x.imag),
|
metal::precise::sin(x.real) * metal::precise::cosh(x.imag),
|
||||||
@ -378,7 +350,6 @@ struct Sinh {
|
|||||||
return metal::precise::sinh(x);
|
return metal::precise::sinh(x);
|
||||||
};
|
};
|
||||||
|
|
||||||
template <>
|
|
||||||
complex64_t operator()(complex64_t x) {
|
complex64_t operator()(complex64_t x) {
|
||||||
return {
|
return {
|
||||||
metal::precise::sinh(x.real) * metal::precise::cos(x.imag),
|
metal::precise::sinh(x.real) * metal::precise::cos(x.imag),
|
||||||
@ -398,6 +369,17 @@ struct Sqrt {
|
|||||||
T operator()(T x) {
|
T operator()(T x) {
|
||||||
return metal::precise::sqrt(x);
|
return metal::precise::sqrt(x);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
complex64_t operator()(complex64_t x) {
|
||||||
|
if (x.real == 0.0 && x.imag == 0.0) {
|
||||||
|
return {0.0, 0.0};
|
||||||
|
}
|
||||||
|
auto r = Abs{}(x).real;
|
||||||
|
auto a = metal::precise::sqrt((r + x.real) / 2.0);
|
||||||
|
auto b_abs = metal::precise::sqrt((r - x.real) / 2.0);
|
||||||
|
auto b = metal::copysign(b_abs, x.imag);
|
||||||
|
return {a, b};
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct Rsqrt {
|
struct Rsqrt {
|
||||||
@ -405,6 +387,10 @@ struct Rsqrt {
|
|||||||
T operator()(T x) {
|
T operator()(T x) {
|
||||||
return metal::precise::rsqrt(x);
|
return metal::precise::rsqrt(x);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
complex64_t operator()(complex64_t x) {
|
||||||
|
return 1.0 / Sqrt{}(x);
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct Tan {
|
struct Tan {
|
||||||
@ -413,7 +399,6 @@ struct Tan {
|
|||||||
return metal::precise::tan(x);
|
return metal::precise::tan(x);
|
||||||
};
|
};
|
||||||
|
|
||||||
template <>
|
|
||||||
complex64_t operator()(complex64_t x) {
|
complex64_t operator()(complex64_t x) {
|
||||||
float tan_a = metal::precise::tan(x.real);
|
float tan_a = metal::precise::tan(x.real);
|
||||||
float tanh_b = metal::precise::tanh(x.imag);
|
float tanh_b = metal::precise::tanh(x.imag);
|
||||||
@ -429,7 +414,6 @@ struct Tanh {
|
|||||||
return metal::precise::tanh(x);
|
return metal::precise::tanh(x);
|
||||||
};
|
};
|
||||||
|
|
||||||
template <>
|
|
||||||
complex64_t operator()(complex64_t x) {
|
complex64_t operator()(complex64_t x) {
|
||||||
float tanh_a = metal::precise::tanh(x.real);
|
float tanh_a = metal::precise::tanh(x.real);
|
||||||
float tan_b = metal::precise::tan(x.imag);
|
float tan_b = metal::precise::tan(x.imag);
|
||||||
@ -438,3 +422,21 @@ struct Tanh {
|
|||||||
return {(tanh_a + tan_b * t1) / denom, (tan_b - tanh_a * t1) / denom};
|
return {(tanh_a + tan_b * t1) / denom, (tan_b - tanh_a * t1) / denom};
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
|
complex64_t ArcCos::operator()(complex64_t x) {
|
||||||
|
auto i = complex64_t{0.0, 1.0};
|
||||||
|
auto y = Log{}(x + i * Sqrt{}(1.0 - x * x));
|
||||||
|
return {y.imag, -y.real};
|
||||||
|
};
|
||||||
|
|
||||||
|
complex64_t ArcSin::operator()(complex64_t x) {
|
||||||
|
auto i = complex64_t{0.0, 1.0};
|
||||||
|
auto y = Log{}(i * x + Sqrt{}(1.0 - x * x));
|
||||||
|
return {y.imag, -y.real};
|
||||||
|
};
|
||||||
|
|
||||||
|
complex64_t ArcTan::operator()(complex64_t x) {
|
||||||
|
auto i = complex64_t{0.0, 1.0};
|
||||||
|
auto ix = i * x;
|
||||||
|
return (1.0 / complex64_t{0.0, 2.0}) * Log{}((1.0 + ix) / (1.0 - ix));
|
||||||
|
};
|
||||||
|
@ -5,6 +5,7 @@
|
|||||||
#include <numeric>
|
#include <numeric>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
|
|
||||||
|
#include "mlx/backend/common/broadcasting.h"
|
||||||
#include "mlx/backend/common/utils.h"
|
#include "mlx/backend/common/utils.h"
|
||||||
#include "mlx/backend/metal/copy.h"
|
#include "mlx/backend/metal/copy.h"
|
||||||
#include "mlx/backend/metal/device.h"
|
#include "mlx/backend/metal/device.h"
|
||||||
@ -102,6 +103,47 @@ std::tuple<bool, int64_t, array> check_transpose(
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
inline array
|
||||||
|
ensure_row_contiguous(const array& x, metal::Device& d, const Stream& s) {
|
||||||
|
if (!x.flags().row_contiguous) {
|
||||||
|
array x_copy(x.shape(), x.dtype(), nullptr, {});
|
||||||
|
copy_gpu(x, x_copy, CopyType::General, s);
|
||||||
|
d.add_temporary(x_copy, s.index);
|
||||||
|
return x_copy;
|
||||||
|
} else {
|
||||||
|
return x;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
inline std::tuple<bool, int64_t, array>
|
||||||
|
ensure_batch_contiguous(const array& x, metal::Device& d, const Stream& s) {
|
||||||
|
if (x.flags().row_contiguous) {
|
||||||
|
return std::make_tuple(false, x.strides()[x.ndim() - 2], x);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool rc = true;
|
||||||
|
for (int i = 0; i < x.ndim() - 3; i++) {
|
||||||
|
rc &= x.strides()[i + 1] * x.shape(i) == x.strides()[i];
|
||||||
|
}
|
||||||
|
if (rc) {
|
||||||
|
auto stx = x.strides()[x.ndim() - 2];
|
||||||
|
auto sty = x.strides()[x.ndim() - 1];
|
||||||
|
auto K = x.shape(-2);
|
||||||
|
auto N = x.shape(-1);
|
||||||
|
if (sty == 1 && (N != 1 || stx == N)) {
|
||||||
|
return std::make_tuple(false, stx, x);
|
||||||
|
}
|
||||||
|
if (stx == 1 && (N != 1 || sty == K)) {
|
||||||
|
return std::make_tuple(true, sty, x);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
array x_copy(x.shape(), x.dtype(), nullptr, {});
|
||||||
|
copy_gpu(x, x_copy, CopyType::General, s);
|
||||||
|
d.add_temporary(x_copy, s.index);
|
||||||
|
return std::make_tuple(false, x_copy.strides()[x_copy.ndim() - 2], x_copy);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
@ -230,7 +272,6 @@ void steel_matmul_regular(
|
|||||||
const bool align_M = (M % bm) == 0;
|
const bool align_M = (M % bm) == 0;
|
||||||
const bool align_N = (N % bn) == 0;
|
const bool align_N = (N % bn) == 0;
|
||||||
const bool align_K = (K % bk) == 0;
|
const bool align_K = (K % bk) == 0;
|
||||||
const bool do_gather = false;
|
|
||||||
|
|
||||||
metal::MTLFCList func_consts = {
|
metal::MTLFCList func_consts = {
|
||||||
{&has_batch, MTL::DataType::DataTypeBool, 10},
|
{&has_batch, MTL::DataType::DataTypeBool, 10},
|
||||||
@ -239,7 +280,6 @@ void steel_matmul_regular(
|
|||||||
{&align_M, MTL::DataType::DataTypeBool, 200},
|
{&align_M, MTL::DataType::DataTypeBool, 200},
|
||||||
{&align_N, MTL::DataType::DataTypeBool, 201},
|
{&align_N, MTL::DataType::DataTypeBool, 201},
|
||||||
{&align_K, MTL::DataType::DataTypeBool, 202},
|
{&align_K, MTL::DataType::DataTypeBool, 202},
|
||||||
{&do_gather, MTL::DataType::DataTypeBool, 300},
|
|
||||||
};
|
};
|
||||||
|
|
||||||
// clang-format off
|
// clang-format off
|
||||||
@ -248,8 +288,7 @@ void steel_matmul_regular(
|
|||||||
<< "_do_axpby_" << (do_axpby ? 't' : 'n')
|
<< "_do_axpby_" << (do_axpby ? 't' : 'n')
|
||||||
<< "_align_M_" << (align_M ? 't' : 'n')
|
<< "_align_M_" << (align_M ? 't' : 'n')
|
||||||
<< "_align_N_" << (align_N ? 't' : 'n')
|
<< "_align_N_" << (align_N ? 't' : 'n')
|
||||||
<< "_align_K_" << (align_K ? 't' : 'n')
|
<< "_align_K_" << (align_K ? 't' : 'n'); // clang-format on
|
||||||
<< "_do_gather_" << (do_gather ? 't' : 'n'); // clang-format on
|
|
||||||
|
|
||||||
std::string hash_name = kname.str();
|
std::string hash_name = kname.str();
|
||||||
|
|
||||||
@ -975,7 +1014,6 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
const bool align_M = (M % bm) == 0;
|
const bool align_M = (M % bm) == 0;
|
||||||
const bool align_N = (N % bn) == 0;
|
const bool align_N = (N % bn) == 0;
|
||||||
const bool align_K = (K % bk) == 0;
|
const bool align_K = (K % bk) == 0;
|
||||||
const bool do_gather = false;
|
|
||||||
|
|
||||||
metal::MTLFCList func_consts = {
|
metal::MTLFCList func_consts = {
|
||||||
{&has_batch, MTL::DataType::DataTypeBool, 10},
|
{&has_batch, MTL::DataType::DataTypeBool, 10},
|
||||||
@ -984,7 +1022,6 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
{&align_M, MTL::DataType::DataTypeBool, 200},
|
{&align_M, MTL::DataType::DataTypeBool, 200},
|
||||||
{&align_N, MTL::DataType::DataTypeBool, 201},
|
{&align_N, MTL::DataType::DataTypeBool, 201},
|
||||||
{&align_K, MTL::DataType::DataTypeBool, 202},
|
{&align_K, MTL::DataType::DataTypeBool, 202},
|
||||||
{&do_gather, MTL::DataType::DataTypeBool, 300},
|
|
||||||
};
|
};
|
||||||
|
|
||||||
// clang-format off
|
// clang-format off
|
||||||
@ -993,8 +1030,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
<< "_do_axpby_" << (do_axpby ? 't' : 'n')
|
<< "_do_axpby_" << (do_axpby ? 't' : 'n')
|
||||||
<< "_align_M_" << (align_M ? 't' : 'n')
|
<< "_align_M_" << (align_M ? 't' : 'n')
|
||||||
<< "_align_N_" << (align_N ? 't' : 'n')
|
<< "_align_N_" << (align_N ? 't' : 'n')
|
||||||
<< "_align_K_" << (align_K ? 't' : 'n')
|
<< "_align_K_" << (align_K ? 't' : 'n'); // clang-format on
|
||||||
<< "_do_gather_" << (do_gather ? 't' : 'n'); // clang-format on
|
|
||||||
|
|
||||||
std::string hash_name = kname.str();
|
std::string hash_name = kname.str();
|
||||||
|
|
||||||
@ -1464,267 +1500,337 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
d.add_temporaries(std::move(copies), s.index);
|
d.add_temporaries(std::move(copies), s.index);
|
||||||
}
|
}
|
||||||
|
|
||||||
void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
void gather_mm_rhs(
|
||||||
using namespace mlx::steel;
|
const array& a_,
|
||||||
// assert(inputs.size() == 2);
|
const array& b_,
|
||||||
if (!issubdtype(out.dtype(), floating)) {
|
const array& indices_,
|
||||||
throw std::runtime_error(
|
array& out,
|
||||||
"[GatherMM] Does not yet support non-floating point types.");
|
metal::Device& d,
|
||||||
}
|
const Stream& s) {
|
||||||
auto& s = stream();
|
array indices = ensure_row_contiguous(indices_, d, s);
|
||||||
auto& d = metal::device(s.device);
|
auto [transpose_b, ldb, b] = ensure_batch_contiguous(b_, d, s);
|
||||||
|
|
||||||
auto& a_pre = inputs[0];
|
// Broadcast a with indices. If we are here that means lhs_indices were not
|
||||||
auto& b_pre = inputs[1];
|
// provided so the lhs_indices are implied to be the shape of a broadcasted
|
||||||
// Return 0s if either input is empty
|
// with rhs_indices. We need only broadcast a and copy it as if applying the
|
||||||
if (a_pre.size() == 0 || b_pre.size() == 0) {
|
// lhs_indices.
|
||||||
array zero = array(0, a_pre.dtype());
|
auto broadcast_with_indices = [&d, &s, &indices](const array& x) {
|
||||||
fill_gpu(zero, out, s);
|
if (x.size() / x.shape(-2) / x.shape(-1) == indices.size()) {
|
||||||
d.add_temporary(std::move(zero), s.index);
|
return ensure_row_contiguous(x, d, s);
|
||||||
return;
|
}
|
||||||
}
|
|
||||||
|
|
||||||
out.set_data(allocator::malloc(out.nbytes()));
|
auto x_shape = indices.shape();
|
||||||
|
x_shape.push_back(x.shape(-2));
|
||||||
|
x_shape.push_back(x.shape(-1));
|
||||||
|
array new_x(std::move(x_shape), x.dtype(), nullptr, {});
|
||||||
|
broadcast(x, new_x);
|
||||||
|
return ensure_row_contiguous(new_x, d, s);
|
||||||
|
};
|
||||||
|
array a = broadcast_with_indices(a_);
|
||||||
|
|
||||||
/////////////////////////////////////////////////////////////////////////////
|
// Extract the matmul shapes
|
||||||
// Init checks and prep
|
int K = a.shape(-1);
|
||||||
|
int M = a.size() / K;
|
||||||
|
int N = b.shape(-1);
|
||||||
|
int lda = a.strides()[a.ndim() - 2]; // should be K
|
||||||
|
|
||||||
int M = a_pre.shape(-2);
|
// Define the dispatch blocks
|
||||||
int N = b_pre.shape(-1);
|
int bm = 16, bn = 64, bk = 16;
|
||||||
int K = a_pre.shape(-1);
|
int wm = 1, wn = 2;
|
||||||
|
|
||||||
// Keep a vector with copies to be cleared in the completed buffer to release
|
const bool align_M = (M % bm) == 0;
|
||||||
// the arrays
|
const bool align_N = (N % bn) == 0;
|
||||||
std::vector<array> copies;
|
const bool align_K = (K % bk) == 0;
|
||||||
auto [transpose_a, a_cols, a] = check_transpose(copies, s, a_pre, M == 1);
|
|
||||||
auto [transpose_b, b_cols, b] = check_transpose(copies, s, b_pre, N == 1);
|
|
||||||
|
|
||||||
int lda = a_cols;
|
// Define the kernel name
|
||||||
int ldb = b_cols;
|
std::string base_name;
|
||||||
|
base_name.reserve(64);
|
||||||
|
concatenate(
|
||||||
|
base_name,
|
||||||
|
"steel_gather_mm_rhs_n",
|
||||||
|
transpose_b ? 't' : 'n',
|
||||||
|
'_',
|
||||||
|
type_to_name(a),
|
||||||
|
'_',
|
||||||
|
type_to_name(out),
|
||||||
|
"_bm",
|
||||||
|
bm,
|
||||||
|
"_bn",
|
||||||
|
bn,
|
||||||
|
"_bk",
|
||||||
|
bk,
|
||||||
|
"_wm",
|
||||||
|
wm,
|
||||||
|
"_wn",
|
||||||
|
wn);
|
||||||
|
|
||||||
/////////////////////////////////////////////////////////////////////////////
|
metal::MTLFCList func_consts = {
|
||||||
// Check and collapse batch dimensions
|
{&align_M, MTL::DataType::DataTypeBool, 200},
|
||||||
|
{&align_N, MTL::DataType::DataTypeBool, 201},
|
||||||
auto get_batch_dims = [](const auto& v) {
|
{&align_K, MTL::DataType::DataTypeBool, 202},
|
||||||
return decltype(v){v.begin(), v.end() - 2};
|
|
||||||
};
|
};
|
||||||
|
|
||||||
auto& lhs_indices = inputs[2];
|
// And the kernel hash that includes the function constants
|
||||||
auto& rhs_indices = inputs[3];
|
std::string hash_name;
|
||||||
|
hash_name.reserve(128);
|
||||||
|
concatenate(
|
||||||
|
hash_name,
|
||||||
|
base_name,
|
||||||
|
"_align_M_",
|
||||||
|
align_M ? 't' : 'n',
|
||||||
|
"_align_N_",
|
||||||
|
align_N ? 't' : 'n',
|
||||||
|
"_align_K_",
|
||||||
|
align_K ? 't' : 'n');
|
||||||
|
|
||||||
Shape batch_shape = get_batch_dims(out.shape());
|
// Get and set the kernel
|
||||||
Strides batch_strides;
|
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||||
|
auto kernel = get_steel_gemm_gather_kernel(
|
||||||
|
d,
|
||||||
|
base_name,
|
||||||
|
hash_name,
|
||||||
|
func_consts,
|
||||||
|
out,
|
||||||
|
false,
|
||||||
|
transpose_b,
|
||||||
|
bm,
|
||||||
|
bn,
|
||||||
|
bk,
|
||||||
|
wm,
|
||||||
|
wn,
|
||||||
|
true);
|
||||||
|
compute_encoder.set_compute_pipeline_state(kernel);
|
||||||
|
|
||||||
batch_strides.insert(
|
// Prepare the matmul params
|
||||||
batch_strides.end(),
|
auto batch_stride_b = b.ndim() > 2 ? b.strides()[b.ndim() - 3] : b.size();
|
||||||
lhs_indices.strides().begin(),
|
steel::GEMMParams params{
|
||||||
lhs_indices.strides().end());
|
/* const int M = */ M,
|
||||||
auto lhs_indices_str = batch_strides.empty() ? 0 : batch_strides.back();
|
/* const int N = */ N,
|
||||||
|
/* const int K = */ K,
|
||||||
|
/* const int lda = */ lda,
|
||||||
|
/* const int ldb = */ static_cast<int>(ldb),
|
||||||
|
/* const int ldd = */ N,
|
||||||
|
/* const int tiles_n = */ (N + bn - 1) / bn,
|
||||||
|
/* const int tiles_m = */ (M + bm - 1) / bm,
|
||||||
|
/* const int64_t batch_stride_a = */ 0,
|
||||||
|
/* const int64_t batch_stride_b = */ static_cast<int64_t>(batch_stride_b),
|
||||||
|
/* const int64_t batch_stride_d = */ 0,
|
||||||
|
/* const int swizzle_log = */ 0,
|
||||||
|
/* const int gemm_k_iterations_aligned = */ (K / bk),
|
||||||
|
/* const int batch_ndim = */ 0};
|
||||||
|
|
||||||
batch_strides.insert(
|
// Prepare the grid
|
||||||
batch_strides.end(),
|
MTL::Size group_dims = MTL::Size(32, wn, wm);
|
||||||
rhs_indices.strides().begin(),
|
MTL::Size grid_dims = MTL::Size(params.tiles_n, params.tiles_m, 1);
|
||||||
rhs_indices.strides().end());
|
|
||||||
auto rhs_indices_str = batch_strides.empty() ? 0 : batch_strides.back();
|
|
||||||
|
|
||||||
int batch_ndim = batch_shape.size();
|
// Launch kernel
|
||||||
|
compute_encoder.set_input_array(a, 0);
|
||||||
|
compute_encoder.set_input_array(b, 1);
|
||||||
|
compute_encoder.set_input_array(indices, 2);
|
||||||
|
compute_encoder.set_output_array(out, 3);
|
||||||
|
compute_encoder.set_bytes(params, 4);
|
||||||
|
|
||||||
if (batch_ndim == 0) {
|
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
||||||
batch_shape = {1};
|
}
|
||||||
batch_strides = {0};
|
|
||||||
}
|
|
||||||
|
|
||||||
int batch_ndim_A = a.ndim() - 2;
|
void gather_mv(
|
||||||
int batch_ndim_B = b.ndim() - 2;
|
const array& mat_,
|
||||||
std::vector<int> operand_batch_ndim = {batch_ndim_A, batch_ndim_B};
|
const array& vec_,
|
||||||
|
const array& mat_indices_,
|
||||||
|
const array& vec_indices_,
|
||||||
|
array& out,
|
||||||
|
int N,
|
||||||
|
int K,
|
||||||
|
bool is_mv,
|
||||||
|
metal::Device& d,
|
||||||
|
const Stream& s) {
|
||||||
|
// Copy if needed
|
||||||
|
std::vector<array> copies;
|
||||||
|
auto [transpose_mat, mat_cols, mat] =
|
||||||
|
check_transpose(copies, s, mat_, N == 1);
|
||||||
|
auto [transpose_vec, vec_cols, vec] = check_transpose(copies, s, vec_, true);
|
||||||
|
d.add_temporaries(std::move(copies), s.index);
|
||||||
|
|
||||||
Shape batch_shape_A = get_batch_dims(a.shape());
|
// If we are doing vector matrix instead of matrix vector we need to flip the
|
||||||
Strides batch_strides_A = get_batch_dims(a.strides());
|
// matrix transposition. Basically m @ v = v @ m.T assuming that v is treated
|
||||||
Shape batch_shape_B = get_batch_dims(b.shape());
|
// as a one dimensional array.
|
||||||
Strides batch_strides_B = get_batch_dims(b.strides());
|
transpose_mat = (!is_mv) ^ transpose_mat;
|
||||||
|
|
||||||
if (batch_ndim_A == 0) {
|
// Define some shapes
|
||||||
batch_shape_A = {1};
|
int in_vector_len = K;
|
||||||
batch_strides_A = {0};
|
int out_vector_len = N;
|
||||||
}
|
int mat_ld = mat_cols;
|
||||||
|
|
||||||
if (batch_ndim_B == 0) {
|
int batch_size_out = out.size() / N;
|
||||||
batch_shape_B = {1};
|
int batch_ndim = out.ndim() - 2;
|
||||||
batch_strides_B = {0};
|
int batch_ndim_mat = mat.ndim() - 2;
|
||||||
}
|
int batch_ndim_vec = vec.ndim() - 2;
|
||||||
|
Strides index_strides = vec_indices_.strides();
|
||||||
|
index_strides.insert(
|
||||||
|
index_strides.end(),
|
||||||
|
mat_indices_.strides().begin(),
|
||||||
|
mat_indices_.strides().end());
|
||||||
|
|
||||||
auto matrix_stride_out = static_cast<int64_t>(M) * N;
|
// Determine dispatch kernel
|
||||||
auto batch_size_out = out.size() / matrix_stride_out;
|
int tm = 4, tn = 4;
|
||||||
|
int sm = 1, sn = 32;
|
||||||
/////////////////////////////////////////////////////////////////////////////
|
int bm = 1, bn = 1;
|
||||||
// Gemv specialization
|
int n_out_per_tgp;
|
||||||
|
std::ostringstream kname;
|
||||||
// Route to gemv if needed
|
|
||||||
if (std::min(M, N) == 1) {
|
|
||||||
// Collect problem info
|
|
||||||
bool is_b_matrix = N != 1;
|
|
||||||
|
|
||||||
auto& mat = is_b_matrix ? b : a;
|
|
||||||
auto& vec = is_b_matrix ? a : b;
|
|
||||||
bool transpose_mat = is_b_matrix ? !transpose_b : transpose_a;
|
|
||||||
int in_vector_len = K;
|
|
||||||
int out_vector_len = is_b_matrix ? N : M;
|
|
||||||
|
|
||||||
int mat_cols = transpose_mat ? out_vector_len : in_vector_len;
|
|
||||||
int mat_rows = transpose_mat ? in_vector_len : out_vector_len;
|
|
||||||
int mat_ld = is_b_matrix ? b_cols : a_cols;
|
|
||||||
|
|
||||||
auto batch_strides_mat = is_b_matrix ? batch_strides_B : batch_strides_A;
|
|
||||||
auto batch_strides_vec = is_b_matrix ? batch_strides_A : batch_strides_B;
|
|
||||||
|
|
||||||
auto batch_shape_mat = is_b_matrix ? batch_shape_B : batch_shape_A;
|
|
||||||
auto batch_shape_vec = is_b_matrix ? batch_shape_A : batch_shape_B;
|
|
||||||
|
|
||||||
if (!is_b_matrix) {
|
|
||||||
batch_strides = rhs_indices.strides();
|
|
||||||
batch_strides.insert(
|
|
||||||
batch_strides.end(),
|
|
||||||
lhs_indices.strides().begin(),
|
|
||||||
lhs_indices.strides().end());
|
|
||||||
}
|
|
||||||
|
|
||||||
int batch_ndim = batch_shape.size();
|
|
||||||
|
|
||||||
// Determine dispatch kernel
|
|
||||||
int tm = 4, tn = 4;
|
|
||||||
int sm = 1, sn = 32;
|
|
||||||
int bm = 1, bn = 1;
|
|
||||||
int n_out_per_tgp;
|
|
||||||
std::ostringstream kname;
|
|
||||||
|
|
||||||
if (transpose_mat) {
|
|
||||||
if (in_vector_len >= 8192 && out_vector_len >= 2048) {
|
|
||||||
sm = 4;
|
|
||||||
sn = 8;
|
|
||||||
} else {
|
|
||||||
sm = 8;
|
|
||||||
sn = 4;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (out_vector_len >= 2048) {
|
|
||||||
bn = 16;
|
|
||||||
} else if (out_vector_len >= 512) {
|
|
||||||
bn = 4;
|
|
||||||
} else {
|
|
||||||
bn = 2;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Specialized kernel for very small outputs
|
|
||||||
tn = out_vector_len < tn ? 1 : tn;
|
|
||||||
|
|
||||||
n_out_per_tgp = bn * sn * tn;
|
|
||||||
kname << "gemv_t_gather_" << type_to_name(out);
|
|
||||||
|
|
||||||
|
if (transpose_mat) {
|
||||||
|
if (in_vector_len >= 8192 && out_vector_len >= 2048) {
|
||||||
|
sm = 4;
|
||||||
|
sn = 8;
|
||||||
} else {
|
} else {
|
||||||
bm = out_vector_len >= 4096 ? 8 : 4;
|
sm = 8;
|
||||||
sn = 32;
|
sn = 4;
|
||||||
|
|
||||||
// Specialized kernel for very small outputs
|
|
||||||
tm = out_vector_len < tm ? 1 : tm;
|
|
||||||
|
|
||||||
n_out_per_tgp = bm * sm * tm;
|
|
||||||
kname << "gemv_gather_" << type_to_name(out);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
kname << "_bm" << bm << "_bn" << bn << "_sm" << sm << "_sn" << sn << "_tm"
|
if (out_vector_len >= 2048) {
|
||||||
<< tm << "_tn" << tn;
|
bn = 16;
|
||||||
|
} else if (out_vector_len >= 512) {
|
||||||
|
bn = 4;
|
||||||
|
} else {
|
||||||
|
bn = 2;
|
||||||
|
}
|
||||||
|
|
||||||
// Encode and dispatch kernel
|
// Specialized kernel for very small outputs
|
||||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
tn = out_vector_len < tn ? 1 : tn;
|
||||||
auto kernel = d.get_kernel(kname.str());
|
|
||||||
compute_encoder.set_compute_pipeline_state(kernel);
|
|
||||||
|
|
||||||
int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp;
|
n_out_per_tgp = bn * sn * tn;
|
||||||
MTL::Size group_dims = MTL::Size(32, bn, bm);
|
kname << "gemv_t_gather_" << type_to_name(out);
|
||||||
MTL::Size grid_dims = MTL::Size(n_tgp, 1, batch_size_out);
|
|
||||||
|
|
||||||
compute_encoder.set_input_array(mat, 0);
|
} else {
|
||||||
compute_encoder.set_input_array(vec, 1);
|
bm = out_vector_len >= 4096 ? 8 : 4;
|
||||||
compute_encoder.set_output_array(out, 3);
|
sn = 32;
|
||||||
|
|
||||||
compute_encoder.set_bytes(in_vector_len, 4);
|
// Specialized kernel for very small outputs
|
||||||
compute_encoder.set_bytes(out_vector_len, 5);
|
tm = out_vector_len < tm ? 1 : tm;
|
||||||
compute_encoder.set_bytes(mat_ld, 6);
|
|
||||||
|
|
||||||
compute_encoder.set_bytes(batch_ndim, 9);
|
n_out_per_tgp = bm * sm * tm;
|
||||||
compute_encoder.set_vector_bytes(batch_shape, 10);
|
kname << "gemv_gather_" << type_to_name(out);
|
||||||
compute_encoder.set_vector_bytes(batch_strides, 11);
|
|
||||||
|
|
||||||
int batch_ndim_vec = batch_shape_vec.size();
|
|
||||||
compute_encoder.set_bytes(batch_ndim_vec, 12);
|
|
||||||
compute_encoder.set_vector_bytes(batch_shape_vec, 13);
|
|
||||||
compute_encoder.set_vector_bytes(batch_strides_vec, 14);
|
|
||||||
|
|
||||||
int batch_ndim_mat = batch_shape_mat.size();
|
|
||||||
compute_encoder.set_bytes(batch_ndim_mat, 15);
|
|
||||||
compute_encoder.set_vector_bytes(batch_shape_mat, 16);
|
|
||||||
compute_encoder.set_vector_bytes(batch_strides_mat, 17);
|
|
||||||
|
|
||||||
compute_encoder.set_input_array(lhs_indices, 18 + int(!is_b_matrix));
|
|
||||||
compute_encoder.set_input_array(rhs_indices, 18 + int(is_b_matrix));
|
|
||||||
|
|
||||||
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
|
||||||
|
|
||||||
d.add_temporaries(std::move(copies), s.index);
|
|
||||||
return;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/////////////////////////////////////////////////////////////////////////////
|
kname << "_bm" << bm << "_bn" << bn << "_sm" << sm << "_sn" << sn << "_tm"
|
||||||
// Regular kernel dispatch
|
<< tm << "_tn" << tn;
|
||||||
|
|
||||||
|
// Encode and dispatch kernel
|
||||||
|
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||||
|
auto kernel = d.get_kernel(kname.str());
|
||||||
|
compute_encoder.set_compute_pipeline_state(kernel);
|
||||||
|
|
||||||
|
int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp;
|
||||||
|
MTL::Size group_dims = MTL::Size(32, bn, bm);
|
||||||
|
MTL::Size grid_dims = MTL::Size(n_tgp, 1, batch_size_out);
|
||||||
|
|
||||||
|
compute_encoder.set_input_array(mat, 0);
|
||||||
|
compute_encoder.set_input_array(vec, 1);
|
||||||
|
compute_encoder.set_output_array(out, 3);
|
||||||
|
|
||||||
|
compute_encoder.set_bytes(in_vector_len, 4);
|
||||||
|
compute_encoder.set_bytes(out_vector_len, 5);
|
||||||
|
compute_encoder.set_bytes(mat_ld, 6);
|
||||||
|
|
||||||
|
compute_encoder.set_bytes(batch_ndim, 9);
|
||||||
|
compute_encoder.set_vector_bytes(out.shape(), 10);
|
||||||
|
compute_encoder.set_vector_bytes(index_strides, 11);
|
||||||
|
|
||||||
|
compute_encoder.set_bytes(batch_ndim_vec, 12);
|
||||||
|
compute_encoder.set_vector_bytes(vec.shape(), 13);
|
||||||
|
compute_encoder.set_vector_bytes(vec.strides(), 14);
|
||||||
|
|
||||||
|
compute_encoder.set_bytes(batch_ndim_mat, 15);
|
||||||
|
compute_encoder.set_vector_bytes(mat.shape(), 16);
|
||||||
|
compute_encoder.set_vector_bytes(mat.strides(), 17);
|
||||||
|
|
||||||
|
compute_encoder.set_input_array(vec_indices_, 18);
|
||||||
|
compute_encoder.set_input_array(mat_indices_, 19);
|
||||||
|
|
||||||
|
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
||||||
|
}
|
||||||
|
|
||||||
|
void gather_mm(
|
||||||
|
const array& a_,
|
||||||
|
const array& b_,
|
||||||
|
const array& lhs_indices,
|
||||||
|
const array& rhs_indices,
|
||||||
|
array& out,
|
||||||
|
int M,
|
||||||
|
int N,
|
||||||
|
int K,
|
||||||
|
metal::Device& d,
|
||||||
|
const Stream& s) {
|
||||||
|
// Copy if needed
|
||||||
|
std::vector<array> copies;
|
||||||
|
auto [transpose_a, lda, a] = check_transpose(copies, s, a_, false);
|
||||||
|
auto [transpose_b, ldb, b] = check_transpose(copies, s, b_, false);
|
||||||
|
d.add_temporaries(std::move(copies), s.index);
|
||||||
|
|
||||||
// Determine dispatch kernel
|
// Determine dispatch kernel
|
||||||
int bm = 64, bn = 64, bk = 16;
|
int bm = 64, bn = 64, bk = 16;
|
||||||
int wm = 2, wn = 2;
|
int wm = 2, wn = 2;
|
||||||
|
size_t batch_size_out = out.size() / M / N;
|
||||||
|
int batch_ndim = out.ndim() - 2;
|
||||||
|
int batch_ndim_a = a.ndim() - 2;
|
||||||
|
int batch_ndim_b = b.ndim() - 2;
|
||||||
|
|
||||||
char devc = d.get_architecture().back();
|
char devc = d.get_architecture().back();
|
||||||
GEMM_TPARAM_MACRO(devc)
|
GEMM_TPARAM_MACRO(devc)
|
||||||
|
|
||||||
// Prepare kernel name
|
|
||||||
std::ostringstream kname;
|
|
||||||
kname << "steel_gemm_fused_" << (transpose_a ? 't' : 'n')
|
|
||||||
<< (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_"
|
|
||||||
<< type_to_name(out) << "_bm" << bm << "_bn" << bn << "_bk" << bk
|
|
||||||
<< "_wm" << wm << "_wn" << wn;
|
|
||||||
|
|
||||||
std::string base_name = kname.str();
|
|
||||||
|
|
||||||
const bool has_batch = batch_ndim > 1;
|
const bool has_batch = batch_ndim > 1;
|
||||||
const bool use_out_source = false;
|
|
||||||
const bool do_axpby = false;
|
|
||||||
const bool align_M = (M % bm) == 0;
|
const bool align_M = (M % bm) == 0;
|
||||||
const bool align_N = (N % bn) == 0;
|
const bool align_N = (N % bn) == 0;
|
||||||
const bool align_K = (K % bk) == 0;
|
const bool align_K = (K % bk) == 0;
|
||||||
const bool do_gather = true;
|
|
||||||
|
// Define the kernel name
|
||||||
|
std::string base_name;
|
||||||
|
base_name.reserve(128);
|
||||||
|
concatenate(
|
||||||
|
base_name,
|
||||||
|
"steel_gather_mm_",
|
||||||
|
transpose_a ? 't' : 'n',
|
||||||
|
transpose_b ? 't' : 'n',
|
||||||
|
"_",
|
||||||
|
type_to_name(a),
|
||||||
|
"_",
|
||||||
|
type_to_name(out),
|
||||||
|
"_bm",
|
||||||
|
bm,
|
||||||
|
"_bn",
|
||||||
|
bn,
|
||||||
|
"_bk",
|
||||||
|
bk,
|
||||||
|
"_wm",
|
||||||
|
wm,
|
||||||
|
"_wn",
|
||||||
|
wn);
|
||||||
|
|
||||||
metal::MTLFCList func_consts = {
|
metal::MTLFCList func_consts = {
|
||||||
{&has_batch, MTL::DataType::DataTypeBool, 10},
|
{&has_batch, MTL::DataType::DataTypeBool, 10},
|
||||||
{&use_out_source, MTL::DataType::DataTypeBool, 100},
|
|
||||||
{&do_axpby, MTL::DataType::DataTypeBool, 110},
|
|
||||||
{&align_M, MTL::DataType::DataTypeBool, 200},
|
{&align_M, MTL::DataType::DataTypeBool, 200},
|
||||||
{&align_N, MTL::DataType::DataTypeBool, 201},
|
{&align_N, MTL::DataType::DataTypeBool, 201},
|
||||||
{&align_K, MTL::DataType::DataTypeBool, 202},
|
{&align_K, MTL::DataType::DataTypeBool, 202},
|
||||||
{&do_gather, MTL::DataType::DataTypeBool, 300},
|
|
||||||
};
|
};
|
||||||
|
|
||||||
// clang-format off
|
// And the kernel hash that includes the function constants
|
||||||
kname << "_has_batch_" << (has_batch ? 't' : 'n')
|
std::string hash_name;
|
||||||
<< "_use_out_source_" << (use_out_source ? 't' : 'n')
|
hash_name.reserve(128);
|
||||||
<< "_do_axpby_" << (do_axpby ? 't' : 'n')
|
concatenate(
|
||||||
<< "_align_M_" << (align_M ? 't' : 'n')
|
hash_name,
|
||||||
<< "_align_N_" << (align_N ? 't' : 'n')
|
base_name,
|
||||||
<< "_align_K_" << (align_K ? 't' : 'n')
|
"_has_batch_",
|
||||||
<< "_do_gather_" << (do_gather ? 't' : 'n'); // clang-format on
|
has_batch ? 't' : 'n',
|
||||||
|
"_align_M_",
|
||||||
|
align_M ? 't' : 'n',
|
||||||
|
"_align_N_",
|
||||||
|
align_N ? 't' : 'n',
|
||||||
|
"_align_K_",
|
||||||
|
align_K ? 't' : 'n');
|
||||||
|
|
||||||
std::string hash_name = kname.str();
|
// Get and set the kernel
|
||||||
|
|
||||||
// Encode and dispatch kernel
|
|
||||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||||
auto kernel = get_steel_gemm_fused_kernel(
|
auto kernel = get_steel_gemm_gather_kernel(
|
||||||
d,
|
d,
|
||||||
base_name,
|
base_name,
|
||||||
hash_name,
|
hash_name,
|
||||||
@ -1736,72 +1842,96 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
bn,
|
bn,
|
||||||
bk,
|
bk,
|
||||||
wm,
|
wm,
|
||||||
wn);
|
wn,
|
||||||
|
false);
|
||||||
compute_encoder.set_compute_pipeline_state(kernel);
|
compute_encoder.set_compute_pipeline_state(kernel);
|
||||||
|
|
||||||
// Use problem size to determine threadblock swizzle
|
// Prepare the matmul params
|
||||||
int tn = (N + bn - 1) / bn;
|
steel::GEMMParams params{
|
||||||
int tm = (M + bm - 1) / bm;
|
|
||||||
|
|
||||||
// TODO: Explore device-based tuning for swizzle
|
|
||||||
int swizzle_log = 0; // tm >= 6 ? 3 : (tm <= 3 ? 0 : 2);
|
|
||||||
|
|
||||||
// Prepare steel matmul params
|
|
||||||
GEMMParams params{
|
|
||||||
/* const int M = */ M,
|
/* const int M = */ M,
|
||||||
/* const int N = */ N,
|
/* const int N = */ N,
|
||||||
/* const int K = */ K,
|
/* const int K = */ K,
|
||||||
/* const int lda = */ lda,
|
/* const int lda = */ static_cast<int>(lda),
|
||||||
/* const int ldb = */ ldb,
|
/* const int ldb = */ static_cast<int>(ldb),
|
||||||
/* const int ldd = */ N,
|
/* const int ldd = */ N,
|
||||||
/* const int tiles_n = */ tn,
|
/* const int tiles_n = */ (N + bn - 1) / bn,
|
||||||
/* const int tiles_m = */ tm,
|
/* const int tiles_m = */ (M + bm - 1) / bm,
|
||||||
/* const int64_t batch_stride_a = */ lhs_indices_str,
|
/* const int64_t batch_stride_a = */
|
||||||
/* const int64_t batch_stride_b = */ rhs_indices_str,
|
(batch_ndim > 0) ? lhs_indices.strides()[0] : 0,
|
||||||
/* const int64_t batch_stride_d = */ matrix_stride_out,
|
/* const int64_t batch_stride_b = */
|
||||||
/* const int swizzle_log = */ swizzle_log,
|
(batch_ndim > 0) ? rhs_indices.strides()[0] : 0,
|
||||||
|
/* const int64_t batch_stride_d = */ M * N,
|
||||||
|
/* const int swizzle_log = */ 0,
|
||||||
/* const int gemm_k_iterations_aligned = */ (K / bk),
|
/* const int gemm_k_iterations_aligned = */ (K / bk),
|
||||||
/* const int batch_ndim = */ batch_ndim};
|
/* const int batch_ndim = */ batch_ndim};
|
||||||
|
|
||||||
// Prepare launch grid params
|
// Prepare the grid
|
||||||
int tile = 1 << swizzle_log;
|
|
||||||
tm = (tm + tile - 1) / tile;
|
|
||||||
tn = tn * tile;
|
|
||||||
|
|
||||||
MTL::Size group_dims = MTL::Size(32, wn, wm);
|
MTL::Size group_dims = MTL::Size(32, wn, wm);
|
||||||
MTL::Size grid_dims = MTL::Size(tn, tm, batch_size_out);
|
MTL::Size grid_dims =
|
||||||
|
MTL::Size(params.tiles_n, params.tiles_m, batch_size_out);
|
||||||
|
|
||||||
// Launch kernel
|
// Launch kernel
|
||||||
compute_encoder.set_input_array(a, 0);
|
compute_encoder.set_input_array(a, 0);
|
||||||
compute_encoder.set_input_array(b, 1);
|
compute_encoder.set_input_array(b, 1);
|
||||||
compute_encoder.set_output_array(out, 3);
|
compute_encoder.set_input_array(lhs_indices, 2);
|
||||||
|
compute_encoder.set_input_array(rhs_indices, 3);
|
||||||
compute_encoder.set_bytes(params, 4);
|
compute_encoder.set_output_array(out, 4);
|
||||||
|
compute_encoder.set_bytes(params, 5);
|
||||||
compute_encoder.set_vector_bytes(batch_shape, 6);
|
compute_encoder.set_vector_bytes(lhs_indices.shape(), 6);
|
||||||
compute_encoder.set_vector_bytes(batch_strides, 7);
|
compute_encoder.set_vector_bytes(lhs_indices.strides(), 7);
|
||||||
|
compute_encoder.set_vector_bytes(rhs_indices.strides(), 8);
|
||||||
compute_encoder.set_input_array(lhs_indices, 10);
|
compute_encoder.set_bytes(batch_ndim_a, 9);
|
||||||
compute_encoder.set_input_array(rhs_indices, 11);
|
compute_encoder.set_vector_bytes(a.shape(), 10);
|
||||||
|
compute_encoder.set_vector_bytes(a.strides(), 11);
|
||||||
std::vector operand_shape = batch_shape_A;
|
compute_encoder.set_bytes(batch_ndim_b, 12);
|
||||||
operand_shape.insert(
|
compute_encoder.set_vector_bytes(b.shape(), 13);
|
||||||
operand_shape.end(), batch_shape_B.begin(), batch_shape_B.end());
|
compute_encoder.set_vector_bytes(b.strides(), 14);
|
||||||
|
|
||||||
std::vector operand_strides = batch_strides_A;
|
|
||||||
operand_strides.insert(
|
|
||||||
operand_strides.end(), batch_strides_B.begin(), batch_strides_B.end());
|
|
||||||
|
|
||||||
operand_batch_ndim.push_back(0);
|
|
||||||
|
|
||||||
compute_encoder.set_vector_bytes(operand_shape, 13);
|
|
||||||
compute_encoder.set_vector_bytes(operand_strides, 14);
|
|
||||||
compute_encoder.set_vector_bytes(operand_batch_ndim, 15);
|
|
||||||
|
|
||||||
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
||||||
|
}
|
||||||
|
|
||||||
d.add_temporaries(std::move(copies), s.index);
|
void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
auto& s = stream();
|
||||||
|
auto& d = metal::device(s.device);
|
||||||
|
|
||||||
|
auto& a = inputs[0];
|
||||||
|
auto& b = inputs[1];
|
||||||
|
auto& lhs_indices = inputs[2];
|
||||||
|
auto& rhs_indices = inputs[3];
|
||||||
|
|
||||||
|
// Return 0s if either input is empty
|
||||||
|
if (a.size() == 0 || b.size() == 0) {
|
||||||
|
array zero = array(0, a.dtype());
|
||||||
|
fill_gpu(zero, out, s);
|
||||||
|
d.add_temporary(std::move(zero), s.index);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
out.set_data(allocator::malloc(out.nbytes()));
|
||||||
|
|
||||||
|
// Extract shapes from inputs.
|
||||||
|
int M = a.shape(-2);
|
||||||
|
int N = b.shape(-1);
|
||||||
|
int K = a.shape(-1);
|
||||||
|
|
||||||
|
// We are walking a in order and b is also in order so we can batch up the
|
||||||
|
// matmuls and reuse reading a and b.
|
||||||
|
if (M == 1 && right_sorted_ == true) {
|
||||||
|
gather_mm_rhs(a, b, rhs_indices, out, d, s);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Route to gather gemv if any of a or b are vectors
|
||||||
|
if (M == 1) {
|
||||||
|
gather_mv(b, a, rhs_indices, lhs_indices, out, N, K, false, d, s);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (N == 1) {
|
||||||
|
gather_mv(a, b, lhs_indices, rhs_indices, out, M, K, true, d, s);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Route to non specialized gather mm
|
||||||
|
gather_mm(a, b, lhs_indices, rhs_indices, out, M, N, K, d, s);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
@ -193,6 +193,23 @@ MTL::ComputePipelineState* get_steel_gemm_masked_kernel(
|
|||||||
return d.get_kernel(kernel_name);
|
return d.get_kernel(kernel_name);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
MTL::ComputePipelineState* get_steel_gemm_gather_kernel(
|
||||||
|
metal::Device& d,
|
||||||
|
const std::string& kernel_name,
|
||||||
|
const std::string& hash_name,
|
||||||
|
const metal::MTLFCList& func_consts,
|
||||||
|
const array&,
|
||||||
|
bool,
|
||||||
|
bool,
|
||||||
|
int,
|
||||||
|
int,
|
||||||
|
int,
|
||||||
|
int,
|
||||||
|
int,
|
||||||
|
bool) {
|
||||||
|
return d.get_kernel(kernel_name, "mlx", hash_name, func_consts);
|
||||||
|
}
|
||||||
|
|
||||||
MTL::ComputePipelineState* get_gemv_masked_kernel(
|
MTL::ComputePipelineState* get_gemv_masked_kernel(
|
||||||
metal::Device& d,
|
metal::Device& d,
|
||||||
const std::string& kernel_name,
|
const std::string& kernel_name,
|
||||||
@ -252,4 +269,21 @@ MTL::ComputePipelineState* get_quantized_kernel(
|
|||||||
return d.get_kernel(kernel_name);
|
return d.get_kernel(kernel_name);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
MTL::ComputePipelineState* get_gather_qmm_kernel(
|
||||||
|
metal::Device& d,
|
||||||
|
const std::string& kernel_name,
|
||||||
|
const std::string& hash_name,
|
||||||
|
const metal::MTLFCList& func_consts,
|
||||||
|
const array&,
|
||||||
|
int,
|
||||||
|
int,
|
||||||
|
int,
|
||||||
|
int,
|
||||||
|
int,
|
||||||
|
int,
|
||||||
|
int,
|
||||||
|
bool) {
|
||||||
|
return d.get_kernel(kernel_name, "mlx", hash_name, func_consts);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
File diff suppressed because it is too large
Load Diff
@ -60,6 +60,9 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
case Scan::Min:
|
case Scan::Min:
|
||||||
reduce_type = "min";
|
reduce_type = "min";
|
||||||
break;
|
break;
|
||||||
|
case Scan::LogAddExp:
|
||||||
|
reduce_type = "logaddexp";
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
kname << reduce_type << "_" << type_to_name(in) << "_" << type_to_name(out);
|
kname << reduce_type << "_" << type_to_name(in) << "_" << type_to_name(out);
|
||||||
auto kernel = get_scan_kernel(
|
auto kernel = get_scan_kernel(
|
||||||
|
@ -2,6 +2,8 @@
|
|||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#include <type_traits>
|
||||||
|
|
||||||
#include "mlx/array.h"
|
#include "mlx/array.h"
|
||||||
#include "mlx/backend/metal/device.h"
|
#include "mlx/backend/metal/device.h"
|
||||||
#include "mlx/primitives.h"
|
#include "mlx/primitives.h"
|
||||||
@ -58,14 +60,27 @@ inline void debug_set_primitive_buffer_label(
|
|||||||
|
|
||||||
std::string get_primitive_string(Primitive* primitive);
|
std::string get_primitive_string(Primitive* primitive);
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
constexpr bool is_numeric_except_char = std::is_arithmetic_v<T> &&
|
||||||
|
!std::is_same_v<T, char> && !std::is_same_v<T, signed char> &&
|
||||||
|
!std::is_same_v<T, unsigned char> && !std::is_same_v<T, wchar_t>;
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void concatenate(std::string& acc, T first) {
|
void concatenate(std::string& acc, T first) {
|
||||||
acc += first;
|
if constexpr (is_numeric_except_char<T>) {
|
||||||
|
acc += std::to_string(first);
|
||||||
|
} else {
|
||||||
|
acc += first;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T, typename... Args>
|
template <typename T, typename... Args>
|
||||||
void concatenate(std::string& acc, T first, Args... args) {
|
void concatenate(std::string& acc, T first, Args... args) {
|
||||||
acc += first;
|
if constexpr (is_numeric_except_char<T>) {
|
||||||
|
acc += std::to_string(first);
|
||||||
|
} else {
|
||||||
|
acc += first;
|
||||||
|
}
|
||||||
concatenate(acc, args...);
|
concatenate(acc, args...);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
20
mlx/dtype_utils.cpp
Normal file
20
mlx/dtype_utils.cpp
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/dtype_utils.h"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
const char* dtype_to_string(Dtype arg) {
|
||||||
|
if (arg == bool_) {
|
||||||
|
return "bool";
|
||||||
|
}
|
||||||
|
#define SPECIALIZE_DtypeToString(CPP_TYPE, DTYPE) \
|
||||||
|
if (DTYPE == arg) { \
|
||||||
|
return #DTYPE; \
|
||||||
|
}
|
||||||
|
MLX_FORALL_DTYPES(SPECIALIZE_DtypeToString)
|
||||||
|
#undef SPECIALIZE_DtypeToString
|
||||||
|
return "(unknown)";
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
207
mlx/dtype_utils.h
Normal file
207
mlx/dtype_utils.h
Normal file
@ -0,0 +1,207 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
// Copyright © Meta Platforms, Inc. and affiliates.
|
||||||
|
//
|
||||||
|
// This source code is licensed under the BSD-style license found in
|
||||||
|
// https://github.com/pytorch/executorch/blob/main/LICENSE
|
||||||
|
//
|
||||||
|
// Forked from
|
||||||
|
// https://github.com/pytorch/executorch/blob/main/runtime/core/exec_aten/util/scalar_type_util.h
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlx/dtype.h"
|
||||||
|
|
||||||
|
#include <fmt/format.h>
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
// Return string representation of dtype.
|
||||||
|
const char* dtype_to_string(Dtype arg);
|
||||||
|
|
||||||
|
// Macros that iterate across different subsets of Dtypes.
|
||||||
|
//
|
||||||
|
// For all of these macros, the final `_` parameter is the name of another macro
|
||||||
|
// that takes two parameters: the name of a C type, and the name of the
|
||||||
|
// corresponding Dtype enumerator.
|
||||||
|
//
|
||||||
|
// Note that these macros should use fully-qualified namespaces (starting with
|
||||||
|
// `::`) to ensure that they can be called safely in any arbitrary namespace.
|
||||||
|
#define MLX_FORALL_INT_TYPES(_) \
|
||||||
|
_(uint8_t, uint8) \
|
||||||
|
_(uint16_t, uint16) \
|
||||||
|
_(uint32_t, uint32) \
|
||||||
|
_(uint64_t, uint64) \
|
||||||
|
_(int8_t, int8) \
|
||||||
|
_(int16_t, int16) \
|
||||||
|
_(int32_t, int32) \
|
||||||
|
_(int64_t, int64)
|
||||||
|
|
||||||
|
#define MLX_FORALL_FLOAT_TYPES(_) \
|
||||||
|
_(float16_t, float16) \
|
||||||
|
_(float, float32) \
|
||||||
|
_(double, float64) \
|
||||||
|
_(bfloat16_t, bfloat16)
|
||||||
|
|
||||||
|
// Calls the provided macro on every Dtype, providing the C type and the
|
||||||
|
// Dtype name to each call.
|
||||||
|
//
|
||||||
|
// @param _ A macro that takes two parameters: the name of a C type, and the
|
||||||
|
// name of the corresponding Dtype enumerator.
|
||||||
|
#define MLX_FORALL_DTYPES(_) \
|
||||||
|
MLX_FORALL_INT_TYPES(_) \
|
||||||
|
MLX_FORALL_FLOAT_TYPES(_) \
|
||||||
|
_(bool, bool_) \
|
||||||
|
_(complex64_t, complex64)
|
||||||
|
|
||||||
|
// Maps Dtypes to C++ types.
|
||||||
|
template <Dtype::Val N>
|
||||||
|
struct DtypeToCppType;
|
||||||
|
|
||||||
|
#define SPECIALIZE_DtypeToCppType(CPP_TYPE, DTYPE) \
|
||||||
|
template <> \
|
||||||
|
struct DtypeToCppType<Dtype::Val::DTYPE> { \
|
||||||
|
using type = CPP_TYPE; \
|
||||||
|
};
|
||||||
|
|
||||||
|
MLX_FORALL_DTYPES(SPECIALIZE_DtypeToCppType)
|
||||||
|
|
||||||
|
#undef SPECIALIZE_DtypeToCppType
|
||||||
|
|
||||||
|
// Maps C++ types to Dtypes.
|
||||||
|
template <typename T>
|
||||||
|
struct CppTypeToDtype;
|
||||||
|
|
||||||
|
#define SPECIALIZE_CppTypeToDtype(CPP_TYPE, DTYPE) \
|
||||||
|
template <> \
|
||||||
|
struct CppTypeToDtype<CPP_TYPE> \
|
||||||
|
: std::integral_constant<Dtype::Val, Dtype::Val::DTYPE> {};
|
||||||
|
|
||||||
|
MLX_FORALL_DTYPES(SPECIALIZE_CppTypeToDtype)
|
||||||
|
|
||||||
|
#undef SPECIALIZE_CppTypeToDtype
|
||||||
|
|
||||||
|
// Helper macros for switch case macros (see below)
|
||||||
|
//
|
||||||
|
// These macros are not meant to be used directly. They provide an easy way to
|
||||||
|
// generate a switch statement that can handle subsets of Dtypes supported.
|
||||||
|
|
||||||
|
#define MLX_INTERNAL_SWITCH_CASE(enum_type, CTYPE_ALIAS, ...) \
|
||||||
|
case enum_type: { \
|
||||||
|
using CTYPE_ALIAS = ::mlx::core::DtypeToCppType<enum_type>::type; \
|
||||||
|
__VA_ARGS__; \
|
||||||
|
break; \
|
||||||
|
}
|
||||||
|
|
||||||
|
#define MLX_INTERNAL_SWITCH_CHECKED(TYPE, NAME, ...) \
|
||||||
|
switch (TYPE) { \
|
||||||
|
__VA_ARGS__ \
|
||||||
|
default: \
|
||||||
|
throw std::invalid_argument(fmt::format( \
|
||||||
|
"Unhandled dtype %s for %s", dtype_to_string(TYPE), NAME)); \
|
||||||
|
}
|
||||||
|
|
||||||
|
#define MLX_INTERNAL_SWITCH_CASE_INT_TYPES(CTYPE_ALIAS, ...) \
|
||||||
|
MLX_INTERNAL_SWITCH_CASE( \
|
||||||
|
::mlx::core::Dtype::Val::uint8, CTYPE_ALIAS, __VA_ARGS__) \
|
||||||
|
MLX_INTERNAL_SWITCH_CASE( \
|
||||||
|
::mlx::core::Dtype::Val::uint16, CTYPE_ALIAS, __VA_ARGS__) \
|
||||||
|
MLX_INTERNAL_SWITCH_CASE( \
|
||||||
|
::mlx::core::Dtype::Val::uint32, CTYPE_ALIAS, __VA_ARGS__) \
|
||||||
|
MLX_INTERNAL_SWITCH_CASE( \
|
||||||
|
::mlx::core::Dtype::Val::uint64, CTYPE_ALIAS, __VA_ARGS__) \
|
||||||
|
MLX_INTERNAL_SWITCH_CASE( \
|
||||||
|
::mlx::core::Dtype::Val::int8, CTYPE_ALIAS, __VA_ARGS__) \
|
||||||
|
MLX_INTERNAL_SWITCH_CASE( \
|
||||||
|
::mlx::core::Dtype::Val::int16, CTYPE_ALIAS, __VA_ARGS__) \
|
||||||
|
MLX_INTERNAL_SWITCH_CASE( \
|
||||||
|
::mlx::core::Dtype::Val::int32, CTYPE_ALIAS, __VA_ARGS__) \
|
||||||
|
MLX_INTERNAL_SWITCH_CASE( \
|
||||||
|
::mlx::core::Dtype::Val::int64, CTYPE_ALIAS, __VA_ARGS__)
|
||||||
|
|
||||||
|
#define MLX_INTERNAL_SWITCH_CASE_FLOAT_TYPES(CTYPE_ALIAS, ...) \
|
||||||
|
MLX_INTERNAL_SWITCH_CASE( \
|
||||||
|
::mlx::core::Dtype::Val::float16, CTYPE_ALIAS, __VA_ARGS__) \
|
||||||
|
MLX_INTERNAL_SWITCH_CASE( \
|
||||||
|
::mlx::core::Dtype::Val::float32, CTYPE_ALIAS, __VA_ARGS__) \
|
||||||
|
MLX_INTERNAL_SWITCH_CASE( \
|
||||||
|
::mlx::core::Dtype::Val::float64, CTYPE_ALIAS, __VA_ARGS__) \
|
||||||
|
MLX_INTERNAL_SWITCH_CASE( \
|
||||||
|
::mlx::core::Dtype::Val::bfloat16, CTYPE_ALIAS, __VA_ARGS__)
|
||||||
|
|
||||||
|
#define MLX_INTERNAL_SWITCH_CASE_INT_FLOAT_TYPES(CTYPE_ALIAS, ...) \
|
||||||
|
MLX_INTERNAL_SWITCH_CASE_INT_TYPES(CTYPE_ALIAS, __VA_ARGS__) \
|
||||||
|
MLX_INTERNAL_SWITCH_CASE_FLOAT_TYPES(CTYPE_ALIAS, __VA_ARGS__)
|
||||||
|
|
||||||
|
#define MLX_INTERNAL_SWITCH_CASE_REAL_TYPES(CTYPE_ALIAS, ...) \
|
||||||
|
MLX_INTERNAL_SWITCH_CASE_INT_FLOAT_TYPES(CTYPE_ALIAS, __VA_ARGS__) \
|
||||||
|
MLX_INTERNAL_SWITCH_CASE( \
|
||||||
|
::mlx::core::Dtype::Val::bool_, CTYPE_ALIAS, __VA_ARGS__)
|
||||||
|
|
||||||
|
#define MLX_INTERNAL_SWITCH_CASE_COMPLEX_TYPES(CTYPE_ALIAS, ...) \
|
||||||
|
MLX_INTERNAL_SWITCH_CASE( \
|
||||||
|
::mlx::core::Dtype::Val::complex64, CTYPE_ALIAS, __VA_ARGS__)
|
||||||
|
|
||||||
|
#define MLX_INTERNAL_SWITCH_CASE_ALL_TYPES(CTYPE_ALIAS, ...) \
|
||||||
|
MLX_INTERNAL_SWITCH_CASE_REAL_TYPES(CTYPE_ALIAS, __VA_ARGS__) \
|
||||||
|
MLX_INTERNAL_SWITCH_CASE_COMPLEX_TYPES(CTYPE_ALIAS, __VA_ARGS__)
|
||||||
|
|
||||||
|
// Switch case macros
|
||||||
|
//
|
||||||
|
// These macros provide an easy way to generate switch statements that apply a
|
||||||
|
// common lambda function to subsets of Dtypes supported by MLX.
|
||||||
|
// The lambda function can type specialize to the ctype associated with the
|
||||||
|
// Dtype being handled through an alias passed as the CTYPE_ALIAS argument.
|
||||||
|
//
|
||||||
|
// Arguments:
|
||||||
|
// - ADDITIONAL: Additional Dtype case to add
|
||||||
|
// - TYPE: The Dtype to handle through the switch statement
|
||||||
|
// - NAME: A name for this operation which will be used in error messages
|
||||||
|
// - CTYPE_ALIAS: A typedef for the ctype associated with the Dtype.
|
||||||
|
// - ...: A statement to be applied to each Dtype case
|
||||||
|
//
|
||||||
|
// An example usage is:
|
||||||
|
//
|
||||||
|
// MLX_SWITCH_ALL_TYPES(input.dtype(), CTYPE, {
|
||||||
|
// output.data<CTYPE>[0] = input.data<CTYPE>[0];
|
||||||
|
// });
|
||||||
|
//
|
||||||
|
// Note that these can be nested as well:
|
||||||
|
//
|
||||||
|
// MLX_SWITCH_ALL_TYPES(input.dtype(), CTYPE_IN, {
|
||||||
|
// MLX_SWITCH_ALL_TYPES(output.dtype(), CTYPE_OUT, {
|
||||||
|
// output.data<CTYPE_OUT>[0] = input.data<CTYPE_IN>[0];
|
||||||
|
// });
|
||||||
|
// });
|
||||||
|
//
|
||||||
|
// These macros are adapted from Dispatch.h in the ATen library. The primary
|
||||||
|
// difference is that the CTYPE_ALIAS argument is exposed to users, which is
|
||||||
|
// used to alias the ctype associated with the Dtype that is being handled.
|
||||||
|
|
||||||
|
#define MLX_SWITCH_ALL_TYPES(TYPE, CTYPE_ALIAS, ...) \
|
||||||
|
switch (TYPE) { MLX_INTERNAL_SWITCH_CASE_ALL_TYPES(CTYPE_ALIAS, __VA_ARGS__) }
|
||||||
|
|
||||||
|
#define MLX_SWITCH_INT_TYPES_CHECKED(TYPE, NAME, CTYPE_ALIAS, ...) \
|
||||||
|
MLX_INTERNAL_SWITCH_CHECKED( \
|
||||||
|
TYPE, \
|
||||||
|
NAME, \
|
||||||
|
MLX_INTERNAL_SWITCH_CASE_INT_TYPES(CTYPE_ALIAS, __VA_ARGS__))
|
||||||
|
|
||||||
|
#define MLX_SWITCH_FLOAT_TYPES_CHECKED(TYPE, NAME, CTYPE_ALIAS, ...) \
|
||||||
|
MLX_INTERNAL_SWITCH_CHECKED( \
|
||||||
|
TYPE, \
|
||||||
|
NAME, \
|
||||||
|
MLX_INTERNAL_SWITCH_CASE_FLOAT_TYPES(CTYPE_ALIAS, __VA_ARGS__))
|
||||||
|
|
||||||
|
#define MLX_SWITCH_INT_FLOAT_TYPES_CHECKED(TYPE, NAME, CTYPE_ALIAS, ...) \
|
||||||
|
MLX_INTERNAL_SWITCH_CHECKED( \
|
||||||
|
TYPE, \
|
||||||
|
NAME, \
|
||||||
|
MLX_INTERNAL_SWITCH_CASE_INT_FLOAT_TYPES(CTYPE_ALIAS, __VA_ARGS__))
|
||||||
|
|
||||||
|
#define MLX_SWITCH_REAL_TYPES_CHECKED(TYPE, NAME, CTYPE_ALIAS, ...) \
|
||||||
|
MLX_INTERNAL_SWITCH_CHECKED( \
|
||||||
|
TYPE, \
|
||||||
|
NAME, \
|
||||||
|
MLX_INTERNAL_SWITCH_CASE_REAL_TYPES(CTYPE_ALIAS, __VA_ARGS__))
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
@ -1,5 +1,6 @@
|
|||||||
// Copyright © 2024 Apple Inc.
|
// Copyright © 2024 Apple Inc.
|
||||||
#include "mlx/export.h"
|
#include "mlx/export.h"
|
||||||
|
#include <map>
|
||||||
#include "mlx/compile_impl.h"
|
#include "mlx/compile_impl.h"
|
||||||
#include "mlx/fast_primitives.h"
|
#include "mlx/fast_primitives.h"
|
||||||
#include "mlx/primitives.h"
|
#include "mlx/primitives.h"
|
||||||
@ -298,7 +299,13 @@ struct PrimitiveFactory {
|
|||||||
SERIALIZE_PRIMITIVE(Reshape),
|
SERIALIZE_PRIMITIVE(Reshape),
|
||||||
SERIALIZE_PRIMITIVE(Reduce, "And", "Or", "Sum", "Prod", "Min", "Max"),
|
SERIALIZE_PRIMITIVE(Reduce, "And", "Or", "Sum", "Prod", "Min", "Max"),
|
||||||
SERIALIZE_PRIMITIVE(Round),
|
SERIALIZE_PRIMITIVE(Round),
|
||||||
SERIALIZE_PRIMITIVE(Scan, "CumSum", "CumProd", "CumMin", "CumMax"),
|
SERIALIZE_PRIMITIVE(
|
||||||
|
Scan,
|
||||||
|
"CumSum",
|
||||||
|
"CumProd",
|
||||||
|
"CumMin",
|
||||||
|
"CumMax",
|
||||||
|
"CumLogaddexp"),
|
||||||
SERIALIZE_PRIMITIVE(Scatter),
|
SERIALIZE_PRIMITIVE(Scatter),
|
||||||
SERIALIZE_PRIMITIVE(Select),
|
SERIALIZE_PRIMITIVE(Select),
|
||||||
SERIALIZE_PRIMITIVE(Sigmoid),
|
SERIALIZE_PRIMITIVE(Sigmoid),
|
||||||
@ -475,7 +482,9 @@ bool FunctionTable::match(
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for (auto& [_, in] : kwargs) {
|
auto sorted_kwargs =
|
||||||
|
std::map<std::string, array>(kwargs.begin(), kwargs.end());
|
||||||
|
for (auto& [_, in] : sorted_kwargs) {
|
||||||
if (!match_inputs(in, fun.inputs[i++])) {
|
if (!match_inputs(in, fun.inputs[i++])) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
@ -551,7 +560,9 @@ void FunctionExporter::export_function(const Args& args, const Kwargs& kwargs) {
|
|||||||
// Flatten the inputs to the function for tracing
|
// Flatten the inputs to the function for tracing
|
||||||
std::vector<std::string> kwarg_keys;
|
std::vector<std::string> kwarg_keys;
|
||||||
auto inputs = args;
|
auto inputs = args;
|
||||||
for (auto& [k, v] : kwargs) {
|
auto sorted_kwargs =
|
||||||
|
std::map<std::string, array>(kwargs.begin(), kwargs.end());
|
||||||
|
for (auto& [k, v] : sorted_kwargs) {
|
||||||
kwarg_keys.push_back(k);
|
kwarg_keys.push_back(k);
|
||||||
inputs.push_back(v);
|
inputs.push_back(v);
|
||||||
}
|
}
|
||||||
|
@ -2,14 +2,14 @@
|
|||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <map>
|
|
||||||
#include <set>
|
#include <set>
|
||||||
|
#include <unordered_map>
|
||||||
#include "mlx/array.h"
|
#include "mlx/array.h"
|
||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
using Args = std::vector<array>;
|
using Args = std::vector<array>;
|
||||||
using Kwargs = std::map<std::string, array>;
|
using Kwargs = std::unordered_map<std::string, array>;
|
||||||
|
|
||||||
struct FunctionExporter;
|
struct FunctionExporter;
|
||||||
|
|
||||||
|
@ -111,7 +111,7 @@ array fft_impl(
|
|||||||
for (auto ax : axes) {
|
for (auto ax : axes) {
|
||||||
n.push_back(a.shape(ax));
|
n.push_back(a.shape(ax));
|
||||||
}
|
}
|
||||||
if (real && inverse) {
|
if (real && inverse && a.ndim() > 0) {
|
||||||
n.back() = (n.back() - 1) * 2;
|
n.back() = (n.back() - 1) * 2;
|
||||||
}
|
}
|
||||||
return fft_impl(a, n, axes, real, inverse, s);
|
return fft_impl(a, n, axes, real, inverse, s);
|
||||||
|
62
mlx/ops.cpp
62
mlx/ops.cpp
@ -3504,6 +3504,28 @@ array cummin(
|
|||||||
{a});
|
{a});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
array logcumsumexp(
|
||||||
|
const array& a,
|
||||||
|
int axis,
|
||||||
|
bool reverse /* = false*/,
|
||||||
|
bool inclusive /* = true*/,
|
||||||
|
StreamOrDevice s /* = {}*/) {
|
||||||
|
int ndim = a.ndim();
|
||||||
|
if (axis >= ndim || axis < -ndim) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[logcumsumexp] Axis " << axis << " is out of bounds for array with "
|
||||||
|
<< a.ndim() << " dimensions.";
|
||||||
|
throw std::invalid_argument(msg.str());
|
||||||
|
}
|
||||||
|
axis = (axis + a.ndim()) % a.ndim();
|
||||||
|
return array(
|
||||||
|
a.shape(),
|
||||||
|
a.dtype(),
|
||||||
|
std::make_shared<Scan>(
|
||||||
|
to_stream(s), Scan::ReduceType::LogAddExp, axis, reverse, inclusive),
|
||||||
|
{a});
|
||||||
|
}
|
||||||
|
|
||||||
/** Convolution operations */
|
/** Convolution operations */
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
@ -4006,6 +4028,7 @@ array gather_qmm(
|
|||||||
bool transpose /* = true */,
|
bool transpose /* = true */,
|
||||||
int group_size /* = 64 */,
|
int group_size /* = 64 */,
|
||||||
int bits /* = 4 */,
|
int bits /* = 4 */,
|
||||||
|
bool sorted_indices /* = false */,
|
||||||
StreamOrDevice s /* = {} */) {
|
StreamOrDevice s /* = {} */) {
|
||||||
if (!lhs_indices_ && !rhs_indices_) {
|
if (!lhs_indices_ && !rhs_indices_) {
|
||||||
return quantized_matmul(
|
return quantized_matmul(
|
||||||
@ -4045,13 +4068,19 @@ array gather_qmm(
|
|||||||
return array(
|
return array(
|
||||||
std::move(out_shape),
|
std::move(out_shape),
|
||||||
out_type,
|
out_type,
|
||||||
std::make_shared<GatherQMM>(to_stream(s), group_size, bits, transpose),
|
std::make_shared<GatherQMM>(
|
||||||
|
to_stream(s),
|
||||||
|
group_size,
|
||||||
|
bits,
|
||||||
|
transpose,
|
||||||
|
sorted_indices && !rhs_indices_,
|
||||||
|
sorted_indices && !lhs_indices_),
|
||||||
{astype(x, out_type, s),
|
{astype(x, out_type, s),
|
||||||
w,
|
std::move(w),
|
||||||
astype(scales, out_type, s),
|
astype(scales, out_type, s),
|
||||||
astype(biases, out_type, s),
|
astype(biases, out_type, s),
|
||||||
lhs_indices,
|
std::move(lhs_indices),
|
||||||
rhs_indices});
|
std::move(rhs_indices)});
|
||||||
}
|
}
|
||||||
|
|
||||||
array tensordot(
|
array tensordot(
|
||||||
@ -4477,6 +4506,7 @@ array gather_mm(
|
|||||||
array b,
|
array b,
|
||||||
std::optional<array> lhs_indices_ /* = std::nullopt */,
|
std::optional<array> lhs_indices_ /* = std::nullopt */,
|
||||||
std::optional<array> rhs_indices_ /* = std::nullopt */,
|
std::optional<array> rhs_indices_ /* = std::nullopt */,
|
||||||
|
bool sorted_indices /* = false */,
|
||||||
StreamOrDevice s /* = {} */) {
|
StreamOrDevice s /* = {} */) {
|
||||||
// If no indices, fall back to full matmul
|
// If no indices, fall back to full matmul
|
||||||
if (!lhs_indices_ && !rhs_indices_) {
|
if (!lhs_indices_ && !rhs_indices_) {
|
||||||
@ -4552,12 +4582,18 @@ array gather_mm(
|
|||||||
out_shape.push_back(M);
|
out_shape.push_back(M);
|
||||||
out_shape.push_back(N);
|
out_shape.push_back(N);
|
||||||
|
|
||||||
// Caculate array
|
// Make the output array
|
||||||
auto out = array(
|
auto out = array(
|
||||||
std::move(out_shape),
|
std::move(out_shape),
|
||||||
out_type,
|
out_type,
|
||||||
std::make_shared<GatherMM>(to_stream(s)),
|
std::make_shared<GatherMM>(
|
||||||
{a, b, lhs_indices, rhs_indices});
|
to_stream(s),
|
||||||
|
sorted_indices && !rhs_indices_,
|
||||||
|
sorted_indices && !lhs_indices_),
|
||||||
|
{std::move(a),
|
||||||
|
std::move(b),
|
||||||
|
std::move(lhs_indices),
|
||||||
|
std::move(rhs_indices)});
|
||||||
|
|
||||||
// Remove the possibly inserted singleton dimensions
|
// Remove the possibly inserted singleton dimensions
|
||||||
std::vector<int> axes;
|
std::vector<int> axes;
|
||||||
@ -4879,8 +4915,10 @@ array operator^(const array& a, const array& b) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
array left_shift(const array& a, const array& b, StreamOrDevice s /* = {} */) {
|
array left_shift(const array& a, const array& b, StreamOrDevice s /* = {} */) {
|
||||||
// Bit shift on bool always up-casts to uint8
|
auto t = result_type(a, b);
|
||||||
auto t = promote_types(result_type(a, b), uint8);
|
if (t == bool_) {
|
||||||
|
t = uint8;
|
||||||
|
}
|
||||||
return bitwise_impl(
|
return bitwise_impl(
|
||||||
astype(a, t, s),
|
astype(a, t, s),
|
||||||
astype(b, t, s),
|
astype(b, t, s),
|
||||||
@ -4893,8 +4931,10 @@ array operator<<(const array& a, const array& b) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
array right_shift(const array& a, const array& b, StreamOrDevice s /* = {} */) {
|
array right_shift(const array& a, const array& b, StreamOrDevice s /* = {} */) {
|
||||||
// Bit shift on bool always up-casts to uint8
|
auto t = result_type(a, b);
|
||||||
auto t = promote_types(result_type(a, b), uint8);
|
if (t == bool_) {
|
||||||
|
t = uint8;
|
||||||
|
}
|
||||||
return bitwise_impl(
|
return bitwise_impl(
|
||||||
astype(a, t, s),
|
astype(a, t, s),
|
||||||
astype(b, t, s),
|
astype(b, t, s),
|
||||||
|
10
mlx/ops.h
10
mlx/ops.h
@ -715,6 +715,14 @@ array topk(const array& a, int k, StreamOrDevice s = {});
|
|||||||
/** Returns topk elements of the array along a given axis. */
|
/** Returns topk elements of the array along a given axis. */
|
||||||
array topk(const array& a, int k, int axis, StreamOrDevice s = {});
|
array topk(const array& a, int k, int axis, StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** Cumulative logsumexp of an array. */
|
||||||
|
array logcumsumexp(
|
||||||
|
const array& a,
|
||||||
|
int axis,
|
||||||
|
bool reverse = false,
|
||||||
|
bool inclusive = true,
|
||||||
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
/** The logsumexp of all elements of the array. */
|
/** The logsumexp of all elements of the array. */
|
||||||
array logsumexp(const array& a, bool keepdims, StreamOrDevice s = {});
|
array logsumexp(const array& a, bool keepdims, StreamOrDevice s = {});
|
||||||
inline array logsumexp(const array& a, StreamOrDevice s = {}) {
|
inline array logsumexp(const array& a, StreamOrDevice s = {}) {
|
||||||
@ -1344,6 +1352,7 @@ array gather_qmm(
|
|||||||
bool transpose = true,
|
bool transpose = true,
|
||||||
int group_size = 64,
|
int group_size = 64,
|
||||||
int bits = 4,
|
int bits = 4,
|
||||||
|
bool sorted_indices = false,
|
||||||
StreamOrDevice s = {});
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
/** Returns a contraction of a and b over multiple dimensions. */
|
/** Returns a contraction of a and b over multiple dimensions. */
|
||||||
@ -1391,6 +1400,7 @@ array gather_mm(
|
|||||||
array b,
|
array b,
|
||||||
std::optional<array> lhs_indices = std::nullopt,
|
std::optional<array> lhs_indices = std::nullopt,
|
||||||
std::optional<array> rhs_indices = std::nullopt,
|
std::optional<array> rhs_indices = std::nullopt,
|
||||||
|
bool sorted_indices = false,
|
||||||
StreamOrDevice s = {});
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
/** Extract a diagonal or construct a diagonal array */
|
/** Extract a diagonal or construct a diagonal array */
|
||||||
|
@ -1275,6 +1275,61 @@ std::vector<array> Convolution::vjp(
|
|||||||
return grads;
|
return grads;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::pair<std::vector<array>, std::vector<int>> Convolution::vmap(
|
||||||
|
const std::vector<array>& inputs,
|
||||||
|
const std::vector<int>& axes) {
|
||||||
|
auto do_conv = [&](const array& in, const array& w, int groups) {
|
||||||
|
return conv_general(
|
||||||
|
in,
|
||||||
|
w,
|
||||||
|
kernel_strides_,
|
||||||
|
padding_,
|
||||||
|
kernel_dilation_,
|
||||||
|
input_dilation_,
|
||||||
|
groups,
|
||||||
|
flip_,
|
||||||
|
stream());
|
||||||
|
};
|
||||||
|
bool in_vmap = axes[0] >= 0;
|
||||||
|
bool w_vmap = axes[1] >= 0;
|
||||||
|
auto in = inputs[0];
|
||||||
|
auto w = inputs[1];
|
||||||
|
if (in_vmap && !w_vmap) {
|
||||||
|
// flatten / unflatten the batch dimension
|
||||||
|
// of the input / output
|
||||||
|
if (axes[0] > 0) {
|
||||||
|
in = moveaxis(in, axes[0], 0, stream());
|
||||||
|
}
|
||||||
|
auto out = do_conv(flatten(in, 0, 1, stream()), w, groups_);
|
||||||
|
out = unflatten(out, 0, {in.shape(0), in.shape(1)}, stream());
|
||||||
|
return {{out}, {0}};
|
||||||
|
} else if (!in_vmap && w_vmap) {
|
||||||
|
// flatten into the output channels of w
|
||||||
|
// unflatten the channels of the output
|
||||||
|
if (axes[1] > 0) {
|
||||||
|
w = moveaxis(w, axes[1], 0, stream());
|
||||||
|
}
|
||||||
|
auto out = do_conv(in, flatten(w, 0, 1, stream()), groups_);
|
||||||
|
out = unflatten(out, -1, {w.shape(0), w.shape(1)}, stream());
|
||||||
|
return {{out}, {static_cast<int>(out.ndim() - 2)}};
|
||||||
|
} else if (in_vmap && w_vmap) {
|
||||||
|
// use a group convolution when both inputs are vmapped
|
||||||
|
auto b = in.shape(axes[0]);
|
||||||
|
in = moveaxis(in, axes[0], -2, stream());
|
||||||
|
in = flatten(in, -2, -1, stream());
|
||||||
|
if (axes[1] > 0) {
|
||||||
|
w = moveaxis(w, axes[1], 0, stream());
|
||||||
|
}
|
||||||
|
auto c_out = w.shape(1);
|
||||||
|
w = flatten(w, 0, 1, stream());
|
||||||
|
auto out = do_conv(in, w, groups_ * b);
|
||||||
|
out = unflatten(out, -1, {b, c_out}, stream());
|
||||||
|
return {{out}, {static_cast<int>(out.ndim() - 2)}};
|
||||||
|
} else {
|
||||||
|
return {{do_conv(in, w, groups_)}, {-1}};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
bool Convolution::is_equivalent(const Primitive& other) const {
|
bool Convolution::is_equivalent(const Primitive& other) const {
|
||||||
const Convolution& c_other = static_cast<const Convolution&>(other);
|
const Convolution& c_other = static_cast<const Convolution&>(other);
|
||||||
return padding_ == c_other.padding_ &&
|
return padding_ == c_other.padding_ &&
|
||||||
@ -3080,6 +3135,8 @@ std::vector<array> GatherQMM::vjp(
|
|||||||
auto& lhs_indices = primals[4];
|
auto& lhs_indices = primals[4];
|
||||||
auto& rhs_indices = primals[5];
|
auto& rhs_indices = primals[5];
|
||||||
|
|
||||||
|
bool sorted = left_sorted_ || right_sorted_;
|
||||||
|
|
||||||
for (auto arg : argnums) {
|
for (auto arg : argnums) {
|
||||||
// gradient wrt to x
|
// gradient wrt to x
|
||||||
if (arg == 0) {
|
if (arg == 0) {
|
||||||
@ -3098,6 +3155,7 @@ std::vector<array> GatherQMM::vjp(
|
|||||||
!transpose_,
|
!transpose_,
|
||||||
group_size_,
|
group_size_,
|
||||||
bits_,
|
bits_,
|
||||||
|
sorted,
|
||||||
stream()),
|
stream()),
|
||||||
-3,
|
-3,
|
||||||
stream()),
|
stream()),
|
||||||
@ -3478,6 +3536,45 @@ std::vector<array> Scan::vjp(
|
|||||||
|
|
||||||
if (reduce_type_ == Scan::Sum) {
|
if (reduce_type_ == Scan::Sum) {
|
||||||
return {cumsum(cotangents[0], axis_, !reverse_, inclusive_, stream())};
|
return {cumsum(cotangents[0], axis_, !reverse_, inclusive_, stream())};
|
||||||
|
} else if (reduce_type_ == Scan::LogAddExp) {
|
||||||
|
// Ref:
|
||||||
|
// https://github.com/tensorflow/tensorflow/blob/2a5910906a0e0f3dbc186ff9db6386d81a63448c/tensorflow/python/ops/math_grad.py#L1832-L1863
|
||||||
|
|
||||||
|
auto x = primals[0];
|
||||||
|
auto grad = cotangents[0];
|
||||||
|
auto results = outputs[0];
|
||||||
|
|
||||||
|
auto zero = zeros({1}, grad.dtype(), stream());
|
||||||
|
auto grad_min = array(finfo(grad.dtype()).min, grad.dtype());
|
||||||
|
|
||||||
|
// Split the incoming gradient into positive and negative part
|
||||||
|
// in order to take logs. This is required for stable results.
|
||||||
|
auto log_abs_grad = log(abs(grad, stream()), stream());
|
||||||
|
auto log_grad_positive =
|
||||||
|
where(greater(grad, zero, stream()), log_abs_grad, grad_min, stream());
|
||||||
|
auto log_grad_negative =
|
||||||
|
where(less(grad, zero, stream()), log_abs_grad, grad_min, stream());
|
||||||
|
|
||||||
|
auto output_pos = exp(
|
||||||
|
add(logcumsumexp(
|
||||||
|
subtract(log_grad_positive, results, stream()),
|
||||||
|
axis_,
|
||||||
|
!reverse_,
|
||||||
|
inclusive_,
|
||||||
|
stream()),
|
||||||
|
x,
|
||||||
|
stream()));
|
||||||
|
auto output_neg = exp(
|
||||||
|
add(logcumsumexp(
|
||||||
|
subtract(log_grad_negative, results, stream()),
|
||||||
|
axis_,
|
||||||
|
!reverse_,
|
||||||
|
inclusive_,
|
||||||
|
stream()),
|
||||||
|
x,
|
||||||
|
stream()));
|
||||||
|
|
||||||
|
return {subtract(output_pos, output_neg, stream())};
|
||||||
} else if (reduce_type_ == Scan::Prod) {
|
} else if (reduce_type_ == Scan::Prod) {
|
||||||
auto in = primals[0];
|
auto in = primals[0];
|
||||||
// Find the location of the first 0 and set it to 1:
|
// Find the location of the first 0 and set it to 1:
|
||||||
@ -4856,6 +4953,8 @@ std::vector<array> GatherMM::vjp(
|
|||||||
int N = cotan.shape(-1);
|
int N = cotan.shape(-1);
|
||||||
int K = primals[0].shape(-1);
|
int K = primals[0].shape(-1);
|
||||||
|
|
||||||
|
bool sorted = left_sorted_ || right_sorted_;
|
||||||
|
|
||||||
for (auto arg : argnums) {
|
for (auto arg : argnums) {
|
||||||
if (arg == 0) {
|
if (arg == 0) {
|
||||||
// M X N * (K X N).T -> M X K
|
// M X N * (K X N).T -> M X K
|
||||||
@ -4866,7 +4965,8 @@ std::vector<array> GatherMM::vjp(
|
|||||||
base = reshape(base, {-1, M, K}, stream());
|
base = reshape(base, {-1, M, K}, stream());
|
||||||
|
|
||||||
// g : (out_batch_shape) + (M, K)
|
// g : (out_batch_shape) + (M, K)
|
||||||
auto g = gather_mm(cotan, bt, std::nullopt, rhs_indices, stream());
|
auto g =
|
||||||
|
gather_mm(cotan, bt, std::nullopt, rhs_indices, sorted, stream());
|
||||||
g = expand_dims(g, -3, stream());
|
g = expand_dims(g, -3, stream());
|
||||||
auto gacc = scatter_add(base, lhs_indices, g, 0, stream());
|
auto gacc = scatter_add(base, lhs_indices, g, 0, stream());
|
||||||
|
|
||||||
@ -4881,7 +4981,8 @@ std::vector<array> GatherMM::vjp(
|
|||||||
base = reshape(base, {-1, K, N}, stream());
|
base = reshape(base, {-1, K, N}, stream());
|
||||||
|
|
||||||
// g : (out_batch_shape) + (K, N)
|
// g : (out_batch_shape) + (K, N)
|
||||||
auto g = gather_mm(at, cotan, lhs_indices, std::nullopt, stream());
|
auto g =
|
||||||
|
gather_mm(at, cotan, lhs_indices, std::nullopt, sorted, stream());
|
||||||
g = expand_dims(g, -3, stream());
|
g = expand_dims(g, -3, stream());
|
||||||
auto gacc = scatter_add(base, rhs_indices, g, 0, stream());
|
auto gacc = scatter_add(base, rhs_indices, g, 0, stream());
|
||||||
|
|
||||||
@ -4894,6 +4995,12 @@ std::vector<array> GatherMM::vjp(
|
|||||||
return vjps;
|
return vjps;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool GatherMM::is_equivalent(const Primitive& other) const {
|
||||||
|
const GatherMM& g_other = static_cast<const GatherMM&>(other);
|
||||||
|
return left_sorted_ == g_other.left_sorted_ &&
|
||||||
|
right_sorted_ == g_other.right_sorted_;
|
||||||
|
}
|
||||||
|
|
||||||
bool BlockMaskedMM::is_equivalent(const Primitive& other) const {
|
bool BlockMaskedMM::is_equivalent(const Primitive& other) const {
|
||||||
const BlockMaskedMM& a_other = static_cast<const BlockMaskedMM&>(other);
|
const BlockMaskedMM& a_other = static_cast<const BlockMaskedMM&>(other);
|
||||||
return (block_size_ == a_other.block_size_);
|
return (block_size_ == a_other.block_size_);
|
||||||
|
@ -498,7 +498,13 @@ class BlockMaskedMM : public UnaryPrimitive {
|
|||||||
|
|
||||||
class GatherMM : public UnaryPrimitive {
|
class GatherMM : public UnaryPrimitive {
|
||||||
public:
|
public:
|
||||||
explicit GatherMM(Stream stream) : UnaryPrimitive(stream) {}
|
explicit GatherMM(
|
||||||
|
Stream stream,
|
||||||
|
bool left_sorted = false,
|
||||||
|
bool right_sorted = false)
|
||||||
|
: UnaryPrimitive(stream),
|
||||||
|
left_sorted_(left_sorted),
|
||||||
|
right_sorted_(right_sorted) {}
|
||||||
|
|
||||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||||
@ -510,7 +516,14 @@ class GatherMM : public UnaryPrimitive {
|
|||||||
const std::vector<array>& outputs) override;
|
const std::vector<array>& outputs) override;
|
||||||
|
|
||||||
DEFINE_PRINT(GatherMM)
|
DEFINE_PRINT(GatherMM)
|
||||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
bool is_equivalent(const Primitive& other) const override;
|
||||||
|
auto state() const {
|
||||||
|
return std::make_pair(left_sorted_, right_sorted_);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
bool left_sorted_;
|
||||||
|
bool right_sorted_;
|
||||||
};
|
};
|
||||||
|
|
||||||
class BroadcastAxes : public UnaryPrimitive {
|
class BroadcastAxes : public UnaryPrimitive {
|
||||||
@ -698,6 +711,7 @@ class Convolution : public UnaryPrimitive {
|
|||||||
const std::vector<int>& argnums,
|
const std::vector<int>& argnums,
|
||||||
const std::vector<array>& outputs) override;
|
const std::vector<array>& outputs) override;
|
||||||
|
|
||||||
|
DEFINE_VMAP()
|
||||||
DEFINE_PRINT(Convolution)
|
DEFINE_PRINT(Convolution)
|
||||||
bool is_equivalent(const Primitive& other) const override;
|
bool is_equivalent(const Primitive& other) const override;
|
||||||
auto state() const {
|
auto state() const {
|
||||||
@ -1578,11 +1592,19 @@ class QuantizedMatmul : public UnaryPrimitive {
|
|||||||
|
|
||||||
class GatherQMM : public UnaryPrimitive {
|
class GatherQMM : public UnaryPrimitive {
|
||||||
public:
|
public:
|
||||||
explicit GatherQMM(Stream stream, int group_size, int bits, bool transpose)
|
explicit GatherQMM(
|
||||||
|
Stream stream,
|
||||||
|
int group_size,
|
||||||
|
int bits,
|
||||||
|
bool transpose,
|
||||||
|
bool left_sorted = false,
|
||||||
|
bool right_sorted = false)
|
||||||
: UnaryPrimitive(stream),
|
: UnaryPrimitive(stream),
|
||||||
group_size_(group_size),
|
group_size_(group_size),
|
||||||
bits_(bits),
|
bits_(bits),
|
||||||
transpose_(transpose) {}
|
transpose_(transpose),
|
||||||
|
left_sorted_(left_sorted),
|
||||||
|
right_sorted_(right_sorted) {}
|
||||||
|
|
||||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||||
@ -1592,13 +1614,16 @@ class GatherQMM : public UnaryPrimitive {
|
|||||||
DEFINE_PRINT(GatherQMM)
|
DEFINE_PRINT(GatherQMM)
|
||||||
bool is_equivalent(const Primitive& other) const override;
|
bool is_equivalent(const Primitive& other) const override;
|
||||||
auto state() const {
|
auto state() const {
|
||||||
return std::make_tuple(group_size_, bits_, transpose_);
|
return std::make_tuple(
|
||||||
|
group_size_, bits_, transpose_, left_sorted_, right_sorted_);
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
int group_size_;
|
int group_size_;
|
||||||
int bits_;
|
int bits_;
|
||||||
bool transpose_;
|
bool transpose_;
|
||||||
|
bool left_sorted_;
|
||||||
|
bool right_sorted_;
|
||||||
};
|
};
|
||||||
|
|
||||||
class RandomBits : public UnaryPrimitive {
|
class RandomBits : public UnaryPrimitive {
|
||||||
@ -1728,7 +1753,7 @@ class Round : public UnaryPrimitive {
|
|||||||
|
|
||||||
class Scan : public UnaryPrimitive {
|
class Scan : public UnaryPrimitive {
|
||||||
public:
|
public:
|
||||||
enum ReduceType { Max, Min, Sum, Prod };
|
enum ReduceType { Max, Min, Sum, Prod, LogAddExp };
|
||||||
|
|
||||||
explicit Scan(
|
explicit Scan(
|
||||||
Stream stream,
|
Stream stream,
|
||||||
@ -1763,6 +1788,9 @@ class Scan : public UnaryPrimitive {
|
|||||||
case Max:
|
case Max:
|
||||||
os << "Max";
|
os << "Max";
|
||||||
break;
|
break;
|
||||||
|
case LogAddExp:
|
||||||
|
os << "Logaddexp";
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
bool is_equivalent(const Primitive& other) const override;
|
bool is_equivalent(const Primitive& other) const override;
|
||||||
|
110
mlx/utils.cpp
110
mlx/utils.cpp
@ -5,6 +5,7 @@
|
|||||||
#include <sstream>
|
#include <sstream>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include "mlx/dtype_utils.h"
|
||||||
#include "mlx/types/limits.h"
|
#include "mlx/types/limits.h"
|
||||||
#include "mlx/utils.h"
|
#include "mlx/utils.h"
|
||||||
|
|
||||||
@ -224,37 +225,7 @@ void print_array(std::ostream& os, const array& a) {
|
|||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
std::ostream& operator<<(std::ostream& os, const Dtype& dtype) {
|
std::ostream& operator<<(std::ostream& os, const Dtype& dtype) {
|
||||||
switch (dtype) {
|
return os << dtype_to_string(dtype);
|
||||||
case bool_:
|
|
||||||
return os << "bool";
|
|
||||||
case uint8:
|
|
||||||
return os << "uint8";
|
|
||||||
case uint16:
|
|
||||||
return os << "uint16";
|
|
||||||
case uint32:
|
|
||||||
return os << "uint32";
|
|
||||||
case uint64:
|
|
||||||
return os << "uint64";
|
|
||||||
case int8:
|
|
||||||
return os << "int8";
|
|
||||||
case int16:
|
|
||||||
return os << "int16";
|
|
||||||
case int32:
|
|
||||||
return os << "int32";
|
|
||||||
case int64:
|
|
||||||
return os << "int64";
|
|
||||||
case float16:
|
|
||||||
return os << "float16";
|
|
||||||
case float32:
|
|
||||||
return os << "float32";
|
|
||||||
case float64:
|
|
||||||
return os << "float64";
|
|
||||||
case bfloat16:
|
|
||||||
return os << "bfloat16";
|
|
||||||
case complex64:
|
|
||||||
return os << "complex64";
|
|
||||||
}
|
|
||||||
return os;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
std::ostream& operator<<(std::ostream& os, const Dtype::Kind& k) {
|
std::ostream& operator<<(std::ostream& os, const Dtype::Kind& k) {
|
||||||
@ -277,50 +248,7 @@ std::ostream& operator<<(std::ostream& os, const Dtype::Kind& k) {
|
|||||||
|
|
||||||
std::ostream& operator<<(std::ostream& os, array a) {
|
std::ostream& operator<<(std::ostream& os, array a) {
|
||||||
a.eval();
|
a.eval();
|
||||||
switch (a.dtype()) {
|
MLX_SWITCH_ALL_TYPES(a.dtype(), CTYPE, print_array<CTYPE>(os, a));
|
||||||
case bool_:
|
|
||||||
print_array<bool>(os, a);
|
|
||||||
break;
|
|
||||||
case uint8:
|
|
||||||
print_array<uint8_t>(os, a);
|
|
||||||
break;
|
|
||||||
case uint16:
|
|
||||||
print_array<uint16_t>(os, a);
|
|
||||||
break;
|
|
||||||
case uint32:
|
|
||||||
print_array<uint32_t>(os, a);
|
|
||||||
break;
|
|
||||||
case uint64:
|
|
||||||
print_array<uint64_t>(os, a);
|
|
||||||
break;
|
|
||||||
case int8:
|
|
||||||
print_array<int8_t>(os, a);
|
|
||||||
break;
|
|
||||||
case int16:
|
|
||||||
print_array<int16_t>(os, a);
|
|
||||||
break;
|
|
||||||
case int32:
|
|
||||||
print_array<int32_t>(os, a);
|
|
||||||
break;
|
|
||||||
case int64:
|
|
||||||
print_array<int64_t>(os, a);
|
|
||||||
break;
|
|
||||||
case float16:
|
|
||||||
print_array<float16_t>(os, a);
|
|
||||||
break;
|
|
||||||
case bfloat16:
|
|
||||||
print_array<bfloat16_t>(os, a);
|
|
||||||
break;
|
|
||||||
case float32:
|
|
||||||
print_array<float>(os, a);
|
|
||||||
break;
|
|
||||||
case float64:
|
|
||||||
print_array<double>(os, a);
|
|
||||||
break;
|
|
||||||
case complex64:
|
|
||||||
print_array<complex64_t>(os, a);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
return os;
|
return os;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -387,36 +315,8 @@ void set_iinfo_limits(int64_t& min, uint64_t& max) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
iinfo::iinfo(Dtype dtype) : dtype(dtype) {
|
iinfo::iinfo(Dtype dtype) : dtype(dtype) {
|
||||||
switch (dtype) {
|
MLX_SWITCH_INT_TYPES_CHECKED(
|
||||||
case int8:
|
dtype, "[iinfo]", CTYPE, set_iinfo_limits<CTYPE>(min, max));
|
||||||
set_iinfo_limits<int8_t>(min, max);
|
|
||||||
break;
|
|
||||||
case uint8:
|
|
||||||
set_iinfo_limits<uint8_t>(min, max);
|
|
||||||
break;
|
|
||||||
case int16:
|
|
||||||
set_iinfo_limits<int16_t>(min, max);
|
|
||||||
break;
|
|
||||||
case uint16:
|
|
||||||
set_iinfo_limits<uint16_t>(min, max);
|
|
||||||
break;
|
|
||||||
case int32:
|
|
||||||
set_iinfo_limits<int32_t>(min, max);
|
|
||||||
break;
|
|
||||||
case uint32:
|
|
||||||
set_iinfo_limits<uint32_t>(min, max);
|
|
||||||
break;
|
|
||||||
case int64:
|
|
||||||
set_iinfo_limits<int64_t>(min, max);
|
|
||||||
break;
|
|
||||||
case uint64:
|
|
||||||
set_iinfo_limits<uint64_t>(min, max);
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
std::ostringstream msg;
|
|
||||||
msg << "[iinfo] dtype " << dtype << " is not integral.";
|
|
||||||
throw std::invalid_argument(msg.str());
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
@ -3,8 +3,8 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#define MLX_VERSION_MAJOR 0
|
#define MLX_VERSION_MAJOR 0
|
||||||
#define MLX_VERSION_MINOR 24
|
#define MLX_VERSION_MINOR 25
|
||||||
#define MLX_VERSION_PATCH 2
|
#define MLX_VERSION_PATCH 0
|
||||||
#define MLX_VERSION_NUMERIC \
|
#define MLX_VERSION_NUMERIC \
|
||||||
(100000 * MLX_VERSION_MAJOR + 1000 * MLX_VERSION_MINOR + MLX_VERSION_PATCH)
|
(100000 * MLX_VERSION_MAJOR + 1000 * MLX_VERSION_MINOR + MLX_VERSION_PATCH)
|
||||||
|
|
||||||
|
@ -1202,6 +1202,28 @@ void init_array(nb::module_& m) {
|
|||||||
nb::kw_only(),
|
nb::kw_only(),
|
||||||
"stream"_a = nb::none(),
|
"stream"_a = nb::none(),
|
||||||
"See :func:`max`.")
|
"See :func:`max`.")
|
||||||
|
.def(
|
||||||
|
"logcumsumexp",
|
||||||
|
[](const mx::array& a,
|
||||||
|
std::optional<int> axis,
|
||||||
|
bool reverse,
|
||||||
|
bool inclusive,
|
||||||
|
mx::StreamOrDevice s) {
|
||||||
|
if (axis) {
|
||||||
|
return mx::logcumsumexp(a, *axis, reverse, inclusive, s);
|
||||||
|
} else {
|
||||||
|
// TODO: Implement that in the C++ API as well. See concatenate
|
||||||
|
// above.
|
||||||
|
return mx::logcumsumexp(
|
||||||
|
mx::reshape(a, {-1}, s), 0, reverse, inclusive, s);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"axis"_a = nb::none(),
|
||||||
|
nb::kw_only(),
|
||||||
|
"reverse"_a = false,
|
||||||
|
"inclusive"_a = true,
|
||||||
|
"stream"_a = nb::none(),
|
||||||
|
"See :func:`logcumsumexp`.")
|
||||||
.def(
|
.def(
|
||||||
"logsumexp",
|
"logsumexp",
|
||||||
[](const mx::array& a,
|
[](const mx::array& a,
|
||||||
|
@ -1,8 +1,8 @@
|
|||||||
// Copyright © 2024 Apple Inc.
|
// Copyright © 2024 Apple Inc.
|
||||||
#include <nanobind/nanobind.h>
|
#include <nanobind/nanobind.h>
|
||||||
#include <nanobind/stl/map.h>
|
|
||||||
#include <nanobind/stl/optional.h>
|
#include <nanobind/stl/optional.h>
|
||||||
#include <nanobind/stl/string.h>
|
#include <nanobind/stl/string.h>
|
||||||
|
#include <nanobind/stl/unordered_map.h>
|
||||||
#include <nanobind/stl/vector.h>
|
#include <nanobind/stl/vector.h>
|
||||||
|
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
@ -16,8 +16,7 @@ namespace mx = mlx::core;
|
|||||||
namespace nb = nanobind;
|
namespace nb = nanobind;
|
||||||
using namespace nb::literals;
|
using namespace nb::literals;
|
||||||
|
|
||||||
std::pair<std::vector<mx::array>, std::map<std::string, mx::array>>
|
std::pair<mx::Args, mx::Kwargs> validate_and_extract_inputs(
|
||||||
validate_and_extract_inputs(
|
|
||||||
const nb::args& args,
|
const nb::args& args,
|
||||||
const nb::kwargs& kwargs,
|
const nb::kwargs& kwargs,
|
||||||
const std::string& prefix) {
|
const std::string& prefix) {
|
||||||
@ -30,8 +29,8 @@ validate_and_extract_inputs(
|
|||||||
"and/or dictionary of arrays.");
|
"and/or dictionary of arrays.");
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
std::vector<mx::array> args_;
|
mx::Args args_;
|
||||||
std::map<std::string, mx::array> kwargs_;
|
mx::Kwargs kwargs_;
|
||||||
if (args.size() == 0) {
|
if (args.size() == 0) {
|
||||||
// No args so kwargs must be keyword arrays
|
// No args so kwargs must be keyword arrays
|
||||||
maybe_throw(nb::try_cast(kwargs, kwargs_));
|
maybe_throw(nb::try_cast(kwargs, kwargs_));
|
||||||
@ -81,9 +80,7 @@ class PyFunctionExporter {
|
|||||||
void close() {
|
void close() {
|
||||||
exporter_.close();
|
exporter_.close();
|
||||||
}
|
}
|
||||||
void operator()(
|
void operator()(const mx::Args& args, const mx::Kwargs& kwargs) {
|
||||||
const std::vector<mx::array>& args,
|
|
||||||
const std::map<std::string, mx::array>& kwargs) {
|
|
||||||
exporter_(args, kwargs);
|
exporter_(args, kwargs);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -98,9 +95,12 @@ int py_function_exporter_tp_traverse(
|
|||||||
PyObject* self,
|
PyObject* self,
|
||||||
visitproc visit,
|
visitproc visit,
|
||||||
void* arg) {
|
void* arg) {
|
||||||
|
Py_VISIT(Py_TYPE(self));
|
||||||
|
if (!nb::inst_ready(self)) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
auto* p = nb::inst_ptr<PyFunctionExporter>(self);
|
auto* p = nb::inst_ptr<PyFunctionExporter>(self);
|
||||||
Py_VISIT(p->dep_.ptr());
|
Py_VISIT(p->dep_.ptr());
|
||||||
Py_VISIT(Py_TYPE(self));
|
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -109,23 +109,22 @@ PyType_Slot py_function_exporter_slots[] = {
|
|||||||
{0, 0}};
|
{0, 0}};
|
||||||
|
|
||||||
auto wrap_export_function(nb::callable fun) {
|
auto wrap_export_function(nb::callable fun) {
|
||||||
return [fun = std::move(fun)](
|
return
|
||||||
const std::vector<mx::array>& args_,
|
[fun = std::move(fun)](const mx::Args& args_, const mx::Kwargs& kwargs_) {
|
||||||
const std::map<std::string, mx::array>& kwargs_) {
|
auto kwargs = nb::dict();
|
||||||
auto kwargs = nb::dict();
|
kwargs.update(nb::cast(kwargs_));
|
||||||
kwargs.update(nb::cast(kwargs_));
|
auto args = nb::tuple(nb::cast(args_));
|
||||||
auto args = nb::tuple(nb::cast(args_));
|
auto outputs = fun(*args, **kwargs);
|
||||||
auto outputs = fun(*args, **kwargs);
|
std::vector<mx::array> outputs_;
|
||||||
std::vector<mx::array> outputs_;
|
if (nb::isinstance<mx::array>(outputs)) {
|
||||||
if (nb::isinstance<mx::array>(outputs)) {
|
outputs_.push_back(nb::cast<mx::array>(outputs));
|
||||||
outputs_.push_back(nb::cast<mx::array>(outputs));
|
} else if (!nb::try_cast(outputs, outputs_)) {
|
||||||
} else if (!nb::try_cast(outputs, outputs_)) {
|
throw std::invalid_argument(
|
||||||
throw std::invalid_argument(
|
"[export_function] Outputs can be either a single array "
|
||||||
"[export_function] Outputs can be either a single array "
|
"a tuple or list of arrays.");
|
||||||
"a tuple or list of arrays.");
|
}
|
||||||
}
|
return outputs_;
|
||||||
return outputs_;
|
};
|
||||||
};
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void init_export(nb::module_& m) {
|
void init_export(nb::module_& m) {
|
||||||
|
@ -16,12 +16,12 @@ struct gc_func {
|
|||||||
};
|
};
|
||||||
|
|
||||||
int gc_func_tp_traverse(PyObject* self, visitproc visit, void* arg) {
|
int gc_func_tp_traverse(PyObject* self, visitproc visit, void* arg) {
|
||||||
|
Py_VISIT(Py_TYPE(self));
|
||||||
gc_func* w = (gc_func*)self;
|
gc_func* w = (gc_func*)self;
|
||||||
Py_VISIT(w->func);
|
Py_VISIT(w->func);
|
||||||
for (auto d : w->deps) {
|
for (auto d : w->deps) {
|
||||||
Py_VISIT(d);
|
Py_VISIT(d);
|
||||||
}
|
}
|
||||||
Py_VISIT(Py_TYPE(self));
|
|
||||||
return 0;
|
return 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -2382,6 +2382,43 @@ void init_ops(nb::module_& m) {
|
|||||||
Returns:
|
Returns:
|
||||||
array: The output array with the corresponding axes reduced.
|
array: The output array with the corresponding axes reduced.
|
||||||
)pbdoc");
|
)pbdoc");
|
||||||
|
m.def(
|
||||||
|
"logcumsumexp",
|
||||||
|
[](const mx::array& a,
|
||||||
|
std::optional<int> axis,
|
||||||
|
bool reverse,
|
||||||
|
bool inclusive,
|
||||||
|
mx::StreamOrDevice s) {
|
||||||
|
if (axis) {
|
||||||
|
return mx::logcumsumexp(a, *axis, reverse, inclusive, s);
|
||||||
|
} else {
|
||||||
|
return mx::logcumsumexp(
|
||||||
|
mx::reshape(a, {-1}, s), 0, reverse, inclusive, s);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
nb::arg(),
|
||||||
|
"axis"_a = nb::none(),
|
||||||
|
nb::kw_only(),
|
||||||
|
"reverse"_a = false,
|
||||||
|
"inclusive"_a = true,
|
||||||
|
"stream"_a = nb::none(),
|
||||||
|
nb::sig(
|
||||||
|
"def logcumsumexp(a: array, /, axis: Optional[int] = None, *, reverse: bool = False, inclusive: bool = True, stream: Union[None, Stream, Device] = None) -> array"),
|
||||||
|
R"pbdoc(
|
||||||
|
Return the cumulative logsumexp of the elements along the given axis.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
a (array): Input array
|
||||||
|
axis (int, optional): Optional axis to compute the cumulative logsumexp
|
||||||
|
over. If unspecified the cumulative logsumexp of the flattened array is
|
||||||
|
returned.
|
||||||
|
reverse (bool): Perform the cumulative logsumexp in reverse.
|
||||||
|
inclusive (bool): The i-th element of the output includes the i-th
|
||||||
|
element of the input.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
array: The output array.
|
||||||
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"logsumexp",
|
"logsumexp",
|
||||||
[](const mx::array& a,
|
[](const mx::array& a,
|
||||||
@ -4213,9 +4250,10 @@ void init_ops(nb::module_& m) {
|
|||||||
"group_size"_a = 64,
|
"group_size"_a = 64,
|
||||||
"bits"_a = 4,
|
"bits"_a = 4,
|
||||||
nb::kw_only(),
|
nb::kw_only(),
|
||||||
|
"sorted_indices"_a = false,
|
||||||
"stream"_a = nb::none(),
|
"stream"_a = nb::none(),
|
||||||
nb::sig(
|
nb::sig(
|
||||||
"def gather_qmm(x: array, w: array, /, scales: array, biases: array, lhs_indices: Optional[array] = None, rhs_indices: Optional[array] = None, transpose: bool = True, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> array"),
|
"def gather_qmm(x: array, w: array, /, scales: array, biases: array, lhs_indices: Optional[array] = None, rhs_indices: Optional[array] = None, transpose: bool = True, group_size: int = 64, bits: int = 4, *, sorted_indices: bool = False, stream: Union[None, Stream, Device] = None) -> array"),
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
Perform quantized matrix multiplication with matrix-level gather.
|
Perform quantized matrix multiplication with matrix-level gather.
|
||||||
|
|
||||||
@ -4228,23 +4266,25 @@ void init_ops(nb::module_& m) {
|
|||||||
as ``w`` since they represent the same quantized matrix.
|
as ``w`` since they represent the same quantized matrix.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
x (array): Input array
|
x (array): Input array
|
||||||
w (array): Quantized matrix packed in unsigned integers
|
w (array): Quantized matrix packed in unsigned integers
|
||||||
scales (array): The scales to use per ``group_size`` elements of ``w``
|
scales (array): The scales to use per ``group_size`` elements of ``w``
|
||||||
biases (array): The biases to use per ``group_size`` elements of ``w``
|
biases (array): The biases to use per ``group_size`` elements of ``w``
|
||||||
lhs_indices (array, optional): Integer indices for ``x``. Default: ``None``.
|
lhs_indices (array, optional): Integer indices for ``x``. Default: ``None``.
|
||||||
rhs_indices (array, optional): Integer indices for ``w``. Default: ``None``.
|
rhs_indices (array, optional): Integer indices for ``w``. Default: ``None``.
|
||||||
transpose (bool, optional): Defines whether to multiply with the
|
transpose (bool, optional): Defines whether to multiply with the
|
||||||
transposed ``w`` or not, namely whether we are performing
|
transposed ``w`` or not, namely whether we are performing
|
||||||
``x @ w.T`` or ``x @ w``. Default: ``True``.
|
``x @ w.T`` or ``x @ w``. Default: ``True``.
|
||||||
group_size (int, optional): The size of the group in ``w`` that
|
group_size (int, optional): The size of the group in ``w`` that
|
||||||
shares a scale and bias. Default: ``64``.
|
shares a scale and bias. Default: ``64``.
|
||||||
bits (int, optional): The number of bits occupied by each element in
|
bits (int, optional): The number of bits occupied by each element in
|
||||||
``w``. Default: ``4``.
|
``w``. Default: ``4``.
|
||||||
|
sorted_indices (bool, optional): May allow a faster implementation
|
||||||
|
if the passed indices are sorted. Default: ``False``.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
array: The result of the multiplication of ``x`` with ``w``
|
array: The result of the multiplication of ``x`` with ``w``
|
||||||
after gathering using ``lhs_indices`` and ``rhs_indices``.
|
after gathering using ``lhs_indices`` and ``rhs_indices``.
|
||||||
)pbdoc");
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"tensordot",
|
"tensordot",
|
||||||
@ -4274,16 +4314,16 @@ void init_ops(nb::module_& m) {
|
|||||||
Compute the tensor dot product along the specified axes.
|
Compute the tensor dot product along the specified axes.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
a (array): Input array
|
a (array): Input array
|
||||||
b (array): Input array
|
b (array): Input array
|
||||||
axes (int or list(list(int)), optional): The number of dimensions to
|
axes (int or list(list(int)), optional): The number of dimensions to
|
||||||
sum over. If an integer is provided, then sum over the last
|
sum over. If an integer is provided, then sum over the last
|
||||||
``axes`` dimensions of ``a`` and the first ``axes`` dimensions of
|
``axes`` dimensions of ``a`` and the first ``axes`` dimensions of
|
||||||
``b``. If a list of lists is provided, then sum over the
|
``b``. If a list of lists is provided, then sum over the
|
||||||
corresponding dimensions of ``a`` and ``b``. Default: 2.
|
corresponding dimensions of ``a`` and ``b``. Default: 2.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
array: The tensor dot product.
|
array: The tensor dot product.
|
||||||
)pbdoc");
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"inner",
|
"inner",
|
||||||
@ -4427,9 +4467,10 @@ void init_ops(nb::module_& m) {
|
|||||||
"lhs_indices"_a = nb::none(),
|
"lhs_indices"_a = nb::none(),
|
||||||
"rhs_indices"_a = nb::none(),
|
"rhs_indices"_a = nb::none(),
|
||||||
nb::kw_only(),
|
nb::kw_only(),
|
||||||
|
"sorted_indices"_a = false,
|
||||||
"stream"_a = nb::none(),
|
"stream"_a = nb::none(),
|
||||||
nb::sig(
|
nb::sig(
|
||||||
"def gather_mm(a: array, b: array, /, lhs_indices: array, rhs_indices: array, *, stream: Union[None, Stream, Device] = None) -> array"),
|
"def gather_mm(a: array, b: array, /, lhs_indices: array, rhs_indices: array, *, sorted_indices: bool = False, stream: Union[None, Stream, Device] = None) -> array"),
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
Matrix multiplication with matrix-level gather.
|
Matrix multiplication with matrix-level gather.
|
||||||
|
|
||||||
@ -4448,11 +4489,16 @@ void init_ops(nb::module_& m) {
|
|||||||
For ``b`` with shape ``(B1, B2, ..., BS, M, K)``, ``rhs_indices``
|
For ``b`` with shape ``(B1, B2, ..., BS, M, K)``, ``rhs_indices``
|
||||||
contains indices from the range ``[0, B1 * B2 * ... * BS)``
|
contains indices from the range ``[0, B1 * B2 * ... * BS)``
|
||||||
|
|
||||||
|
If only one index is passed and it is sorted, the ``sorted_indices``
|
||||||
|
flag can be passed for a possible faster implementation.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
a (array): Input array.
|
a (array): Input array.
|
||||||
b (array): Input array.
|
b (array): Input array.
|
||||||
lhs_indices (array, optional): Integer indices for ``a``. Default: ``None``
|
lhs_indices (array, optional): Integer indices for ``a``. Default: ``None``
|
||||||
rhs_indices (array, optional): Integer indices for ``b``. Default: ``None``
|
rhs_indices (array, optional): Integer indices for ``b``. Default: ``None``
|
||||||
|
sorted_indices (bool, optional): May allow a faster implementation
|
||||||
|
if the passed indices are sorted. Default: ``False``.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
array: The output array.
|
array: The output array.
|
||||||
|
@ -960,6 +960,11 @@ class PyCustomFunction {
|
|||||||
};
|
};
|
||||||
|
|
||||||
int py_custom_function_tp_traverse(PyObject* self, visitproc visit, void* arg) {
|
int py_custom_function_tp_traverse(PyObject* self, visitproc visit, void* arg) {
|
||||||
|
Py_VISIT(Py_TYPE(self));
|
||||||
|
if (!nb::inst_ready(self)) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
auto* p = nb::inst_ptr<PyCustomFunction>(self);
|
auto* p = nb::inst_ptr<PyCustomFunction>(self);
|
||||||
nb::handle v = nb::find(p->fun_);
|
nb::handle v = nb::find(p->fun_);
|
||||||
Py_VISIT(v.ptr());
|
Py_VISIT(v.ptr());
|
||||||
@ -975,7 +980,6 @@ int py_custom_function_tp_traverse(PyObject* self, visitproc visit, void* arg) {
|
|||||||
nb::handle v = nb::find(*(p->vmap_fun_));
|
nb::handle v = nb::find(*(p->vmap_fun_));
|
||||||
Py_VISIT(v.ptr());
|
Py_VISIT(v.ptr());
|
||||||
}
|
}
|
||||||
Py_VISIT(Py_TYPE(self));
|
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
int py_custom_function_tp_clear(PyObject* self) {
|
int py_custom_function_tp_clear(PyObject* self) {
|
||||||
|
@ -1508,6 +1508,7 @@ class TestArray(mlx_tests.MLXTestCase):
|
|||||||
("prod", 1),
|
("prod", 1),
|
||||||
("min", 1),
|
("min", 1),
|
||||||
("max", 1),
|
("max", 1),
|
||||||
|
("logcumsumexp", 1),
|
||||||
("logsumexp", 1),
|
("logsumexp", 1),
|
||||||
("mean", 1),
|
("mean", 1),
|
||||||
("var", 1),
|
("var", 1),
|
||||||
|
@ -1108,7 +1108,7 @@ class TestBlas(mlx_tests.MLXTestCase):
|
|||||||
lhs_indices_ = mx.broadcast_to(lhs_indices, (3, 2))
|
lhs_indices_ = mx.broadcast_to(lhs_indices, (3, 2))
|
||||||
rhs_indices_ = mx.broadcast_to(rhs_indices, (3, 2))
|
rhs_indices_ = mx.broadcast_to(rhs_indices, (3, 2))
|
||||||
M = a.shape[-2]
|
M = a.shape[-2]
|
||||||
N = b.shape[-2]
|
N = b.shape[-1]
|
||||||
K = a.shape[-1]
|
K = a.shape[-1]
|
||||||
|
|
||||||
a = a.reshape((-1, M, K))
|
a = a.reshape((-1, M, K))
|
||||||
|
@ -194,6 +194,11 @@ class TestFFT(mlx_tests.MLXTestCase):
|
|||||||
r_np = np.fft.ifft(segment, n=n_fft)
|
r_np = np.fft.ifft(segment, n=n_fft)
|
||||||
self.assertTrue(np.allclose(r, r_np, atol=1e-5, rtol=1e-5))
|
self.assertTrue(np.allclose(r, r_np, atol=1e-5, rtol=1e-5))
|
||||||
|
|
||||||
|
def test_fft_throws(self):
|
||||||
|
x = mx.array(3.0)
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
mx.fft.irfftn(x)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
@ -1857,6 +1857,30 @@ class TestOps(mlx_tests.MLXTestCase):
|
|||||||
y = mx.as_strided(x, (x.size,), (-1,), x.size - 1)
|
y = mx.as_strided(x, (x.size,), (-1,), x.size - 1)
|
||||||
self.assertTrue(mx.array_equal(y, x[::-1]))
|
self.assertTrue(mx.array_equal(y, x[::-1]))
|
||||||
|
|
||||||
|
def test_logcumsumexp(self):
|
||||||
|
npop = np.logaddexp.accumulate
|
||||||
|
mxop = mx.logcumsumexp
|
||||||
|
|
||||||
|
a_npy = np.random.randn(32, 32, 32).astype(np.float32)
|
||||||
|
a_mlx = mx.array(a_npy)
|
||||||
|
|
||||||
|
for axis in (0, 1, 2):
|
||||||
|
c_npy = npop(a_npy, axis=axis)
|
||||||
|
c_mlx = mxop(a_mlx, axis=axis)
|
||||||
|
self.assertTrue(np.allclose(c_npy, c_mlx, rtol=1e-3, atol=1e-3))
|
||||||
|
|
||||||
|
edge_cases_npy = [
|
||||||
|
np.float32([-float("inf")] * 8),
|
||||||
|
np.float32([-float("inf"), 0, -float("inf")]),
|
||||||
|
np.float32([-float("inf"), float("inf"), -float("inf")]),
|
||||||
|
]
|
||||||
|
edge_cases_mlx = [mx.array(a) for a in edge_cases_npy]
|
||||||
|
|
||||||
|
for a_npy, a_mlx in zip(edge_cases_npy, edge_cases_mlx):
|
||||||
|
c_npy = npop(a_npy, axis=0)
|
||||||
|
c_mlx = mxop(a_mlx, axis=0)
|
||||||
|
self.assertTrue(np.allclose(c_npy, c_mlx, rtol=1e-3, atol=1e-3))
|
||||||
|
|
||||||
def test_scans(self):
|
def test_scans(self):
|
||||||
a_npy = np.random.randn(32, 32, 32).astype(np.float32)
|
a_npy = np.random.randn(32, 32, 32).astype(np.float32)
|
||||||
a_mlx = mx.array(a_npy)
|
a_mlx = mx.array(a_npy)
|
||||||
@ -2910,6 +2934,35 @@ class TestOps(mlx_tests.MLXTestCase):
|
|||||||
out = a[::-1]
|
out = a[::-1]
|
||||||
self.assertTrue(mx.array_equal(out[-1, :], a[0, :]))
|
self.assertTrue(mx.array_equal(out[-1, :], a[0, :]))
|
||||||
|
|
||||||
|
def test_complex_ops(self):
|
||||||
|
x = mx.array(
|
||||||
|
[
|
||||||
|
3.0 + 4.0j,
|
||||||
|
-5.0 + 12.0j,
|
||||||
|
-8.0 + 0.0j,
|
||||||
|
0.0 + 9.0j,
|
||||||
|
0.0 + 0.0j,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
ops = ["arccos", "arcsin", "arctan", "square", "sqrt"]
|
||||||
|
for op in ops:
|
||||||
|
with self.subTest(op=op):
|
||||||
|
np_op = getattr(np, op)
|
||||||
|
mx_op = getattr(mx, op)
|
||||||
|
self.assertTrue(np.allclose(mx_op(x), np_op(x)))
|
||||||
|
|
||||||
|
x = mx.array(
|
||||||
|
[
|
||||||
|
3.0 + 4.0j,
|
||||||
|
-5.0 + 12.0j,
|
||||||
|
-8.0 + 0.0j,
|
||||||
|
0.0 + 9.0j,
|
||||||
|
9.0 + 1.0j,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
self.assertTrue(np.allclose(mx.rsqrt(x), 1.0 / np.sqrt(x)))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
@ -174,12 +174,14 @@ class TestQuantized(mlx_tests.MLXTestCase):
|
|||||||
tests = product(
|
tests = product(
|
||||||
[128, 64, 32], # group_size
|
[128, 64, 32], # group_size
|
||||||
[2, 3, 4, 6, 8], # bits
|
[2, 3, 4, 6, 8], # bits
|
||||||
[128, 256], # M
|
[32, 128, 256], # M
|
||||||
[128, 256, 67], # N
|
[128, 256, 67], # N
|
||||||
[0, 1, 3, 8], # B
|
[0, 1, 3, 8], # B
|
||||||
)
|
)
|
||||||
for group_size, bits, M, N, B in tests:
|
for group_size, bits, M, N, B in tests:
|
||||||
with self.subTest(shape=(B, M, N), group_size=group_size, bits=bits):
|
with self.subTest(shape=(B, M, N), group_size=group_size, bits=bits):
|
||||||
|
if M < group_size:
|
||||||
|
continue
|
||||||
x_shape = (1, N) if B == 0 else (B, 1, N)
|
x_shape = (1, N) if B == 0 else (B, 1, N)
|
||||||
w_shape = (N, M) if B == 0 else (B, N, M)
|
w_shape = (N, M) if B == 0 else (B, N, M)
|
||||||
x = mx.random.normal(shape=x_shape, key=k1)
|
x = mx.random.normal(shape=x_shape, key=k1)
|
||||||
@ -448,6 +450,7 @@ class TestQuantized(mlx_tests.MLXTestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
for kwargs in inputs:
|
for kwargs in inputs:
|
||||||
|
test_shape(1, 32, 128, **kwargs)
|
||||||
test_shape(32, 32, 256, **kwargs)
|
test_shape(32, 32, 256, **kwargs)
|
||||||
test_shape(1, 32, 256, **kwargs)
|
test_shape(1, 32, 256, **kwargs)
|
||||||
test_shape(32, 256, 32, transpose=False, **kwargs)
|
test_shape(32, 256, 32, transpose=False, **kwargs)
|
||||||
@ -486,6 +489,66 @@ class TestQuantized(mlx_tests.MLXTestCase):
|
|||||||
g2 = mx.grad(f_test)(x, qw, s, b, lhs_indices, rhs_indices)
|
g2 = mx.grad(f_test)(x, qw, s, b, lhs_indices, rhs_indices)
|
||||||
self.assertTrue(mx.allclose(g1, g2, atol=1e-4))
|
self.assertTrue(mx.allclose(g1, g2, atol=1e-4))
|
||||||
|
|
||||||
|
def test_gather_qmm_sorted(self):
|
||||||
|
def quantize(w, transpose=True, group_size=64, bits=4):
|
||||||
|
qw, s, b = mx.quantize(w, group_size=group_size, bits=bits)
|
||||||
|
w_hat = mx.dequantize(qw, s, b, group_size=group_size, bits=bits)
|
||||||
|
if transpose:
|
||||||
|
w_hat = w_hat.swapaxes(-1, -2)
|
||||||
|
return w_hat, qw, s, b
|
||||||
|
|
||||||
|
def gather_sort(x, indices):
|
||||||
|
N, M = indices.shape
|
||||||
|
indices = indices.flatten()
|
||||||
|
order = mx.argsort(indices)
|
||||||
|
inv_order = mx.argsort(order)
|
||||||
|
return x.flatten(0, -3)[order // M], indices[order], inv_order
|
||||||
|
|
||||||
|
def scatter_unsort(x, inv_order, shape=None):
|
||||||
|
x = x[inv_order]
|
||||||
|
if shape is not None:
|
||||||
|
x = mx.unflatten(x, 0, shape)
|
||||||
|
return x
|
||||||
|
|
||||||
|
parameters = [
|
||||||
|
# L, K, D, E, I, transpose
|
||||||
|
(128, 1024, 1024, 32, 4, True),
|
||||||
|
(128, 1024, 544, 32, 4, True),
|
||||||
|
(433, 1024, 1024, 32, 4, True),
|
||||||
|
(433, 1024, 555, 32, 4, True),
|
||||||
|
(433, 2048, 1024, 32, 4, True),
|
||||||
|
(128, 1024, 1024, 32, 4, False),
|
||||||
|
(128, 1024, 544, 32, 4, False),
|
||||||
|
(433, 1024, 1024, 32, 4, False),
|
||||||
|
(433, 1024, 544, 32, 4, False),
|
||||||
|
(433, 1024, 555, 32, 4, False),
|
||||||
|
(433, 2048, 1024, 32, 4, False),
|
||||||
|
]
|
||||||
|
for L, K, D, E, I, transpose in parameters:
|
||||||
|
K, D = (K, D) if transpose else (D, K)
|
||||||
|
ishape = (L, I)
|
||||||
|
xshape = (L, 1, 1, K)
|
||||||
|
wshape = (E, D, K) if transpose else (E, K, D)
|
||||||
|
|
||||||
|
indices = (mx.random.uniform(shape=ishape) * E).astype(mx.uint32)
|
||||||
|
x = mx.random.normal(xshape) / K**0.5
|
||||||
|
w = mx.random.normal(wshape) / K**0.5
|
||||||
|
w, *wq = quantize(w, transpose=transpose)
|
||||||
|
|
||||||
|
y1 = mx.gather_mm(x, w, rhs_indices=indices)
|
||||||
|
y2 = mx.gather_qmm(x, *wq, transpose=transpose, rhs_indices=indices)
|
||||||
|
xs, idx, inv_order = gather_sort(x, indices)
|
||||||
|
y3 = mx.gather_mm(xs, w, rhs_indices=idx, sorted_indices=True)
|
||||||
|
y4 = mx.gather_qmm(
|
||||||
|
xs, *wq, rhs_indices=idx, transpose=transpose, sorted_indices=True
|
||||||
|
)
|
||||||
|
y3 = scatter_unsort(y3, inv_order, indices.shape)
|
||||||
|
y4 = scatter_unsort(y4, inv_order, indices.shape)
|
||||||
|
|
||||||
|
self.assertTrue(mx.allclose(y1, y2, atol=1e-5))
|
||||||
|
self.assertTrue(mx.allclose(y1, y3, atol=1e-5))
|
||||||
|
self.assertTrue(mx.allclose(y1, y4, atol=1e-5))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
@ -669,6 +669,57 @@ class TestVmap(mlx_tests.MLXTestCase):
|
|||||||
self.assertEqual(mx.vmap(fun, in_axes=(1,))(x).shape, (3, 8))
|
self.assertEqual(mx.vmap(fun, in_axes=(1,))(x).shape, (3, 8))
|
||||||
self.assertEqual(mx.vmap(fun, in_axes=(2,))(x).shape, (4, 6))
|
self.assertEqual(mx.vmap(fun, in_axes=(2,))(x).shape, (4, 6))
|
||||||
|
|
||||||
|
def test_vmap_conv(self):
|
||||||
|
# vmap input only
|
||||||
|
x = mx.random.uniform(shape=(2, 2, 5, 4))
|
||||||
|
w = mx.random.uniform(shape=(8, 3, 4))
|
||||||
|
|
||||||
|
expected = mx.stack([mx.conv1d(xi, w) for xi in x])
|
||||||
|
out = mx.vmap(mx.conv1d, in_axes=(0, None))(x, w)
|
||||||
|
self.assertTrue(mx.allclose(expected, out))
|
||||||
|
|
||||||
|
x = mx.moveaxis(x, 0, 2)
|
||||||
|
out = mx.vmap(mx.conv1d, in_axes=(2, None))(x, w)
|
||||||
|
self.assertTrue(mx.allclose(expected, out))
|
||||||
|
|
||||||
|
# vmap weights only
|
||||||
|
x = mx.random.uniform(shape=(2, 5, 4))
|
||||||
|
w = mx.random.uniform(shape=(3, 8, 3, 4))
|
||||||
|
|
||||||
|
expected = mx.stack([mx.conv1d(x, wi) for wi in w])
|
||||||
|
out = mx.vmap(mx.conv1d, in_axes=(None, 0))(x, w)
|
||||||
|
self.assertTrue(mx.allclose(expected, out))
|
||||||
|
|
||||||
|
w = mx.moveaxis(w, 0, 1)
|
||||||
|
out = mx.vmap(mx.conv1d, in_axes=(None, 1))(x, w)
|
||||||
|
self.assertTrue(mx.allclose(expected, out))
|
||||||
|
|
||||||
|
# vmap weights and input
|
||||||
|
x = mx.random.uniform(shape=(3, 2, 5, 4))
|
||||||
|
w = mx.random.uniform(shape=(3, 8, 3, 4))
|
||||||
|
|
||||||
|
expected = mx.stack([mx.conv1d(xi, wi) for xi, wi in zip(x, w)])
|
||||||
|
out = mx.vmap(mx.conv1d, in_axes=(0, 0))(x, w)
|
||||||
|
self.assertTrue(mx.allclose(expected, out))
|
||||||
|
|
||||||
|
x = mx.random.uniform(shape=(2, 3, 5, 4))
|
||||||
|
w = mx.random.uniform(shape=(8, 3, 4, 3))
|
||||||
|
|
||||||
|
expected = mx.stack([mx.conv1d(x[:, i], w[..., i]) for i in range(3)])
|
||||||
|
out = mx.vmap(mx.conv1d, in_axes=(1, 3))(x, w)
|
||||||
|
self.assertTrue(mx.allclose(expected, out))
|
||||||
|
|
||||||
|
# Test with groups
|
||||||
|
x = mx.random.uniform(shape=(3, 2, 5, 8))
|
||||||
|
w = mx.random.uniform(shape=(3, 2, 3, 4))
|
||||||
|
|
||||||
|
def gconv(x, w):
|
||||||
|
return mx.conv1d(x, w, groups=2)
|
||||||
|
|
||||||
|
expected = mx.stack([gconv(xi, wi) for xi, wi in zip(x, w)])
|
||||||
|
out = mx.vmap(gconv, in_axes=(0, 0))(x, w)
|
||||||
|
self.assertTrue(mx.allclose(expected, out))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
@ -97,8 +97,7 @@ TEST_CASE("test export primitives with state") {
|
|||||||
TEST_CASE("test export functions with kwargs") {
|
TEST_CASE("test export functions with kwargs") {
|
||||||
std::string file_path = get_temp_file("model.mlxfn");
|
std::string file_path = get_temp_file("model.mlxfn");
|
||||||
|
|
||||||
auto fun =
|
auto fun = [](const Kwargs& kwargs) -> std::vector<array> {
|
||||||
[](const std::map<std::string, array>& kwargs) -> std::vector<array> {
|
|
||||||
return {kwargs.at("x") + kwargs.at("y")};
|
return {kwargs.at("x") + kwargs.at("y")};
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -3874,3 +3874,41 @@ TEST_CASE("test contiguous") {
|
|||||||
CHECK(x.flags().col_contiguous);
|
CHECK(x.flags().col_contiguous);
|
||||||
CHECK_EQ(x.strides(), decltype(x.strides()){1, 2});
|
CHECK_EQ(x.strides(), decltype(x.strides()){1, 2});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_CASE("test bitwise shift operations") {
|
||||||
|
std::vector<Dtype> dtypes = {
|
||||||
|
int8, int16, int32, int64, uint8, uint16, uint32, uint64};
|
||||||
|
|
||||||
|
for (const auto& dtype : dtypes) {
|
||||||
|
array x = full({4}, 1, dtype);
|
||||||
|
array y = full({4}, 2, dtype);
|
||||||
|
|
||||||
|
auto left_shift_result = left_shift(x, y);
|
||||||
|
CHECK_EQ(left_shift_result.dtype(), dtype);
|
||||||
|
CHECK(array_equal(left_shift_result, array({4, 4, 4, 4}, dtype))
|
||||||
|
.item<bool>());
|
||||||
|
|
||||||
|
auto right_shift_result = right_shift(full({4}, 4, dtype), y);
|
||||||
|
CHECK_EQ(right_shift_result.dtype(), dtype);
|
||||||
|
CHECK(array_equal(right_shift_result, full({4}, 1, dtype)).item<bool>());
|
||||||
|
}
|
||||||
|
|
||||||
|
array x = array({127, -128}, int8);
|
||||||
|
array y = array({1, 1}, int8);
|
||||||
|
auto left_shift_result = left_shift(x, y);
|
||||||
|
auto right_shift_result = right_shift(x, y);
|
||||||
|
|
||||||
|
CHECK(array_equal(left_shift_result, array({-2, 0}, int8)).item<bool>());
|
||||||
|
CHECK(array_equal(right_shift_result, array({63, -64}, int8)).item<bool>());
|
||||||
|
|
||||||
|
array x_bool = full({4}, true, bool_);
|
||||||
|
array y_bool = full({4}, true, bool_);
|
||||||
|
auto left_shift_bool_result = left_shift(x_bool, y_bool);
|
||||||
|
auto right_shift_bool_result = right_shift(x_bool, y_bool);
|
||||||
|
|
||||||
|
CHECK_EQ(left_shift_bool_result.dtype(), uint8);
|
||||||
|
CHECK(array_equal(left_shift_bool_result, full({4}, 2, uint8)).item<bool>());
|
||||||
|
|
||||||
|
CHECK_EQ(right_shift_bool_result.dtype(), uint8);
|
||||||
|
CHECK(array_equal(right_shift_bool_result, full({4}, 0, uint8)).item<bool>());
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user