mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-11 06:24:35 +08:00
Compare commits
46 Commits
trellis-qu
...
fft
Author | SHA1 | Date | |
---|---|---|---|
![]() |
83762691ba | ||
![]() |
2a41caa00e | ||
![]() |
6593281d25 | ||
![]() |
da98e8bce8 | ||
![]() |
be57a16a80 | ||
![]() |
1704809f29 | ||
![]() |
0cae0bdac8 | ||
![]() |
5a1a5d5ed1 | ||
![]() |
1683975acf | ||
![]() |
af705590ac | ||
![]() |
825124af8f | ||
![]() |
9c5e7da507 | ||
![]() |
481349495b | ||
![]() |
9daa6b003f | ||
![]() |
a3a632d567 | ||
![]() |
e496c5a4b4 | ||
![]() |
ea890d8710 | ||
![]() |
aa5d84f102 | ||
![]() |
f1606486d2 | ||
![]() |
87720a8908 | ||
![]() |
bb6565ef14 | ||
![]() |
7bb063bcb3 | ||
![]() |
b36dd472bb | ||
![]() |
167b759a38 | ||
![]() |
99b9868859 | ||
![]() |
6b2d5448f2 | ||
![]() |
eaf709b83e | ||
![]() |
f0e70afff0 | ||
![]() |
86984cad68 | ||
![]() |
fbc89e3ced | ||
![]() |
38c1e720c2 | ||
![]() |
600e87e03c | ||
![]() |
3836445241 | ||
![]() |
1d2c9d6a07 | ||
![]() |
e8ac6bd2f5 | ||
![]() |
fdadc4f22c | ||
![]() |
79b527f45f | ||
![]() |
dc4eada7f0 | ||
![]() |
70ebc3b598 | ||
![]() |
b13f2aed16 | ||
![]() |
5f04c0f818 | ||
![]() |
55935ccae7 | ||
![]() |
b529515eb1 | ||
![]() |
3cde719eb7 | ||
![]() |
5de6d94a90 | ||
![]() |
99eefd2ec0 |
1
.gitignore
vendored
1
.gitignore
vendored
@@ -36,6 +36,7 @@ share/python-wheels/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
uv.lock
|
||||
|
||||
# vim
|
||||
*.swp
|
||||
|
@@ -34,6 +34,7 @@ option(MLX_BUILD_BENCHMARKS "Build benchmarks for mlx" OFF)
|
||||
option(MLX_BUILD_PYTHON_BINDINGS "Build python bindings for mlx" OFF)
|
||||
option(MLX_BUILD_METAL "Build metal backend" ON)
|
||||
option(MLX_BUILD_CPU "Build cpu backend" ON)
|
||||
option(MLX_BUILD_CUDA "Build cuda backend" OFF)
|
||||
option(MLX_METAL_DEBUG "Enhance metal debug workflow" OFF)
|
||||
option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" OFF)
|
||||
option(MLX_BUILD_GGUF "Include support for GGUF format" ON)
|
||||
@@ -83,6 +84,10 @@ if(MLX_BUILD_METAL)
|
||||
set(QUARTZ_LIB "-framework QuartzCore")
|
||||
endif()
|
||||
|
||||
if(MLX_BUILD_CUDA)
|
||||
enable_language(CUDA)
|
||||
endif()
|
||||
|
||||
if(MLX_BUILD_METAL AND NOT METAL_LIB)
|
||||
message(STATUS "Metal not found. Unable to build GPU")
|
||||
set(MLX_BUILD_METAL OFF)
|
||||
|
@@ -1,4 +1,6 @@
|
||||
include CMakeLists.txt
|
||||
include mlx.pc.in
|
||||
recursive-include mlx/ *
|
||||
include cmake/*
|
||||
include python/src/*
|
||||
include python/mlx/py.typed # support type hinting as in PEP-561
|
||||
|
74
benchmarks/python/gather_mm_bench.py
Normal file
74
benchmarks/python/gather_mm_bench.py
Normal file
@@ -0,0 +1,74 @@
|
||||
# Copyright © 2025 Apple Inc.
|
||||
|
||||
import mlx.core as mx
|
||||
from time_utils import time_fn
|
||||
|
||||
N = 1024
|
||||
D = 1024
|
||||
M = 1024
|
||||
E = 32
|
||||
I = 4
|
||||
|
||||
|
||||
def gather_sort(x, indices):
|
||||
N, M = indices.shape
|
||||
indices = indices.flatten()
|
||||
order = mx.argsort(indices)
|
||||
inv_order = mx.argsort(order)
|
||||
return x.flatten(0, -3)[order // M], indices[order], inv_order
|
||||
|
||||
|
||||
def scatter_unsort(x, inv_order, shape=None):
|
||||
x = x[inv_order]
|
||||
if shape is not None:
|
||||
x = mx.unflatten(x, 0, shape)
|
||||
return x
|
||||
|
||||
|
||||
def gather_mm_simulate(x, w, indices):
|
||||
x, idx, inv_order = gather_sort(x, indices)
|
||||
for i in range(2):
|
||||
y = mx.concatenate([x[i] @ w[j].T for i, j in enumerate(idx.tolist())], axis=0)
|
||||
x = y[:, None]
|
||||
x = scatter_unsort(x, inv_order, indices.shape)
|
||||
return x
|
||||
|
||||
|
||||
def time_gather_mm():
|
||||
x = mx.random.normal((N, 1, 1, D)) / 1024**0.5
|
||||
w1 = mx.random.normal((E, M, D)) / 1024**0.5
|
||||
w2 = mx.random.normal((E, D, M)) / 1024**0.5
|
||||
indices = (mx.random.uniform(shape=(N, I)) * E).astype(mx.uint32)
|
||||
sorted_indices = mx.sort(indices.flatten()).reshape(N, I)
|
||||
mx.eval(x, w1, w2, indices, sorted_indices)
|
||||
|
||||
def gather_mm(x, w1, w2, indices, sort):
|
||||
idx = indices
|
||||
inv_order = None
|
||||
if sort:
|
||||
x, idx, inv_order = gather_sort(x, indices)
|
||||
x = mx.gather_mm(x, w1.swapaxes(-1, -2), rhs_indices=idx, sorted_indices=sort)
|
||||
x = mx.gather_mm(x, w2.swapaxes(-1, -2), rhs_indices=idx, sorted_indices=sort)
|
||||
if sort:
|
||||
x = scatter_unsort(x, inv_order, indices.shape)
|
||||
return x
|
||||
|
||||
time_fn(gather_mm, x, w1, w2, indices, False)
|
||||
time_fn(gather_mm, x, w1, w2, sorted_indices, False)
|
||||
time_fn(gather_mm, x, w1, w2, indices, True)
|
||||
|
||||
x = mx.random.normal((N * I, D)) / 1024**0.5
|
||||
w1 = mx.random.normal((M, D)) / 1024**0.5
|
||||
w2 = mx.random.normal((D, M)) / 1024**0.5
|
||||
mx.eval(x, w1, w2)
|
||||
|
||||
def equivalent_matmul(x, w1, w2):
|
||||
x = x @ w1.T
|
||||
x = x @ w2.T
|
||||
return x
|
||||
|
||||
time_fn(equivalent_matmul, x, w1, w2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
time_gather_mm()
|
84
benchmarks/python/gather_qmm_bench.py
Normal file
84
benchmarks/python/gather_qmm_bench.py
Normal file
@@ -0,0 +1,84 @@
|
||||
# Copyright © 2025 Apple Inc.
|
||||
|
||||
import mlx.core as mx
|
||||
from time_utils import time_fn
|
||||
|
||||
N = 1024
|
||||
D = 1024
|
||||
M = 1024
|
||||
E = 32
|
||||
I = 4
|
||||
|
||||
|
||||
def gather_sort(x, indices):
|
||||
N, M = indices.shape
|
||||
indices = indices.flatten()
|
||||
order = mx.argsort(indices)
|
||||
inv_order = mx.argsort(order)
|
||||
return x.flatten(0, -3)[order // M], indices[order], inv_order
|
||||
|
||||
|
||||
def scatter_unsort(x, inv_order, shape=None):
|
||||
x = x[inv_order]
|
||||
if shape is not None:
|
||||
x = mx.unflatten(x, 0, shape)
|
||||
return x
|
||||
|
||||
|
||||
def gather_mm_simulate(x, w, indices):
|
||||
x, idx, inv_order = gather_sort(x, indices)
|
||||
for i in range(2):
|
||||
y = mx.concatenate(
|
||||
[
|
||||
mx.quantized_matmul(x[i], w[0][j], w[1][j], w[2][j], transpose=True)
|
||||
for i, j in enumerate(idx.tolist())
|
||||
],
|
||||
axis=0,
|
||||
)
|
||||
x = y[:, None]
|
||||
x = scatter_unsort(x, inv_order, indices.shape)
|
||||
return x
|
||||
|
||||
|
||||
def time_gather_qmm():
|
||||
x = mx.random.normal((N, 1, 1, D)) / 1024**0.5
|
||||
w1 = mx.random.normal((E, M, D)) / 1024**0.5
|
||||
w2 = mx.random.normal((E, D, M)) / 1024**0.5
|
||||
w1 = mx.quantize(w1)
|
||||
w2 = mx.quantize(w2)
|
||||
indices = (mx.random.uniform(shape=(N, I)) * E).astype(mx.uint32)
|
||||
sorted_indices = mx.sort(indices.flatten()).reshape(N, I)
|
||||
mx.eval(x, w1, w2, indices, sorted_indices)
|
||||
|
||||
def gather_mm(x, w1, w2, indices, sort):
|
||||
idx = indices
|
||||
inv_order = None
|
||||
if sort:
|
||||
x, idx, inv_order = gather_sort(x, indices)
|
||||
x = mx.gather_qmm(x, *w1, transpose=True, rhs_indices=idx, sorted_indices=sort)
|
||||
x = mx.gather_qmm(x, *w2, transpose=True, rhs_indices=idx, sorted_indices=sort)
|
||||
if sort:
|
||||
x = scatter_unsort(x, inv_order, indices.shape)
|
||||
return x
|
||||
|
||||
time_fn(gather_mm, x, w1, w2, indices, False)
|
||||
time_fn(gather_mm, x, w1, w2, sorted_indices, False)
|
||||
time_fn(gather_mm, x, w1, w2, indices, True)
|
||||
|
||||
x = mx.random.normal((N * I, D)) / 1024**0.5
|
||||
w1 = mx.random.normal((M, D)) / 1024**0.5
|
||||
w2 = mx.random.normal((D, M)) / 1024**0.5
|
||||
w1 = mx.quantize(w1)
|
||||
w2 = mx.quantize(w2)
|
||||
mx.eval(x, w1, w2)
|
||||
|
||||
def equivalent_matmul(x, w1, w2):
|
||||
x = mx.quantized_matmul(x, *w1, transpose=True)
|
||||
x = mx.quantized_matmul(x, *w2, transpose=True)
|
||||
return x
|
||||
|
||||
time_fn(equivalent_matmul, x, w1, w2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
time_gather_qmm()
|
@@ -20,3 +20,5 @@ FFT
|
||||
irfft2
|
||||
rfftn
|
||||
irfftn
|
||||
fftshift
|
||||
ifftshift
|
||||
|
@@ -5,6 +5,7 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/dtype_utils.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/export.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/einsum.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp
|
||||
@@ -48,5 +49,16 @@ add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/io)
|
||||
if(MLX_BUILD_METAL)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/metal)
|
||||
else()
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_metal)
|
||||
target_sources(mlx
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/no_metal.cpp)
|
||||
endif()
|
||||
|
||||
if(MLX_BUILD_CUDA)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/cuda)
|
||||
endif()
|
||||
|
||||
if(MLX_BUILD_METAL OR MLX_BUILD_CUDA)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/gpu)
|
||||
else()
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_gpu)
|
||||
endif()
|
||||
|
@@ -339,11 +339,11 @@ class array {
|
||||
return allocator::allocator().size(buffer());
|
||||
}
|
||||
|
||||
// Return a copy of the shared pointer
|
||||
// to the array::Data struct
|
||||
std::shared_ptr<Data> data_shared_ptr() const {
|
||||
// Return the shared pointer to the array::Data struct
|
||||
const std::shared_ptr<Data>& data_shared_ptr() const {
|
||||
return array_desc_->data;
|
||||
}
|
||||
|
||||
// Return a raw pointer to the arrays data
|
||||
template <typename T>
|
||||
T* data() {
|
||||
@@ -356,7 +356,7 @@ class array {
|
||||
}
|
||||
|
||||
enum Status {
|
||||
// The ouptut of a computation which has not been scheduled.
|
||||
// The output of a computation which has not been scheduled.
|
||||
// For example, the status of `x` in `auto x = a + b`.
|
||||
unscheduled,
|
||||
|
||||
|
@@ -1,8 +1,10 @@
|
||||
target_sources(
|
||||
mlx
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/broadcasting.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/common.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/transpose.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp)
|
||||
|
24
mlx/backend/common/broadcasting.cpp
Normal file
24
mlx/backend/common/broadcasting.cpp
Normal file
@@ -0,0 +1,24 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include "mlx/backend/common/utils.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void broadcast(const array& in, array& out) {
|
||||
if (out.size() == 0) {
|
||||
out.set_data(nullptr);
|
||||
return;
|
||||
}
|
||||
Strides strides(out.ndim(), 0);
|
||||
int diff = out.ndim() - in.ndim();
|
||||
for (int i = in.ndim() - 1; i >= 0; --i) {
|
||||
strides[i + diff] = (in.shape()[i] == 1) ? 0 : in.strides()[i];
|
||||
}
|
||||
auto flags = in.flags();
|
||||
if (out.size() > in.size()) {
|
||||
flags.row_contiguous = flags.col_contiguous = false;
|
||||
}
|
||||
out.copy_shared_buffer(in, strides, flags, in.data_size());
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
11
mlx/backend/common/broadcasting.h
Normal file
11
mlx/backend/common/broadcasting.h
Normal file
@@ -0,0 +1,11 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/array.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void broadcast(const array& in, array& out);
|
||||
|
||||
} // namespace mlx::core
|
@@ -1,6 +1,8 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
#include <cassert>
|
||||
|
||||
#include "mlx/backend/common/broadcasting.h"
|
||||
#include "mlx/backend/common/transpose.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
@@ -18,47 +20,23 @@ void AsStrided::eval(const std::vector<array>& inputs, array& out) {
|
||||
"AsStrided must be used with row contiguous arrays only.");
|
||||
}
|
||||
|
||||
// Compute the flags given the shape and strides
|
||||
bool row_contiguous = true, col_contiguous = true;
|
||||
size_t r = 1, c = 1;
|
||||
for (int i = strides_.size() - 1, j = 0; i >= 0; i--, j++) {
|
||||
row_contiguous &= (r == strides_[i]) || (shape_[i] == 1);
|
||||
col_contiguous &= (c == strides_[j]) || (shape_[j] == 1);
|
||||
r *= shape_[i];
|
||||
c *= shape_[j];
|
||||
}
|
||||
// Calculate the contiguity based on the given shape and strides
|
||||
auto [ds, rc, cc] = check_contiguity(shape_, strides_);
|
||||
auto flags = in.flags();
|
||||
|
||||
// TODO: Compute the contiguous flag in a better way cause now we are
|
||||
// unnecessarily strict.
|
||||
flags.contiguous = row_contiguous || col_contiguous;
|
||||
flags.row_contiguous = row_contiguous;
|
||||
flags.col_contiguous = col_contiguous;
|
||||
flags.contiguous = rc || cc;
|
||||
flags.row_contiguous = rc;
|
||||
flags.col_contiguous = cc;
|
||||
|
||||
// There is no easy way to compute the actual data size so we use out.size().
|
||||
// The contiguous flag will almost certainly not be set so no code should
|
||||
// rely on data_size anyway.
|
||||
size_t data_size = out.size();
|
||||
// There is no easy way to compute the actual data size so we use out.size()
|
||||
// when the array is not contiguous.
|
||||
size_t data_size = flags.contiguous ? ds : out.size();
|
||||
|
||||
return out.copy_shared_buffer(in, strides_, flags, data_size, offset_);
|
||||
}
|
||||
|
||||
void broadcast(const array& in, array& out) {
|
||||
if (out.size() == 0) {
|
||||
out.set_data(nullptr);
|
||||
return;
|
||||
}
|
||||
Strides strides(out.ndim(), 0);
|
||||
int diff = out.ndim() - in.ndim();
|
||||
for (int i = in.ndim() - 1; i >= 0; --i) {
|
||||
strides[i + diff] = (in.shape()[i] == 1) ? 0 : in.strides()[i];
|
||||
}
|
||||
auto flags = in.flags();
|
||||
if (out.size() > in.size()) {
|
||||
flags.row_contiguous = flags.col_contiguous = false;
|
||||
}
|
||||
out.copy_shared_buffer(in, strides, flags, in.data_size());
|
||||
}
|
||||
|
||||
void Broadcast::eval(const std::vector<array>& inputs, array& out) {
|
||||
broadcast(inputs[0], out);
|
||||
}
|
||||
@@ -286,36 +264,7 @@ void StopGradient::eval(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
void Transpose::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
Strides out_strides(out.ndim());
|
||||
auto& in = inputs[0];
|
||||
for (int ax = 0; ax < axes_.size(); ++ax) {
|
||||
out_strides[ax] = in.strides()[axes_[ax]];
|
||||
}
|
||||
|
||||
// Conditions for {row/col}_contiguous
|
||||
// - array must be contiguous (no gaps)
|
||||
// - underlying buffer size should have the same size as the array
|
||||
// - cumulative product of shapes is equal to the strides (we can ignore axes
|
||||
// with size == 1)
|
||||
// - in the forward direction (column contiguous)
|
||||
// - in the reverse direction (row contiguous)
|
||||
// - vectors are both row and col contiguous (hence if both row/col are
|
||||
// true, they stay true)
|
||||
auto flags = in.flags();
|
||||
if (flags.contiguous && in.data_size() == in.size()) {
|
||||
int64_t f_stride = 1;
|
||||
int64_t b_stride = 1;
|
||||
flags.col_contiguous = true;
|
||||
flags.row_contiguous = true;
|
||||
for (int i = 0, ri = out.ndim() - 1; i < out.ndim(); ++i, --ri) {
|
||||
flags.col_contiguous &= (out_strides[i] == f_stride || out.shape(i) == 1);
|
||||
f_stride *= out.shape(i);
|
||||
flags.row_contiguous &=
|
||||
(out_strides[ri] == b_stride || out.shape(ri) == 1);
|
||||
b_stride *= out.shape(ri);
|
||||
}
|
||||
}
|
||||
out.copy_shared_buffer(in, out_strides, flags, in.data_size());
|
||||
transpose(inputs[0], out, axes_);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -99,7 +99,11 @@ inline std::pair<int, int> decompose_hadamard(int n) {
|
||||
"[hadamard] Only supports n = m*2^k where m in (1, 12, 20, 28).");
|
||||
}
|
||||
}
|
||||
if (n > (1 << 26)) {
|
||||
throw std::invalid_argument(
|
||||
"[hadamard] Only supports n = m*2^k where k <= 26");
|
||||
}
|
||||
return {n, m};
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
} // namespace mlx::core
|
||||
|
57
mlx/backend/common/transpose.cpp
Normal file
57
mlx/backend/common/transpose.cpp
Normal file
@@ -0,0 +1,57 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include <cassert>
|
||||
|
||||
#include "mlx/backend/common/utils.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void transpose(const array& in, array& out, const std::vector<int>& axes) {
|
||||
Strides out_strides(out.ndim());
|
||||
for (int ax = 0; ax < axes.size(); ++ax) {
|
||||
out_strides[ax] = in.strides()[axes[ax]];
|
||||
}
|
||||
|
||||
// Conditions for {row/col}_contiguous
|
||||
// - array must be contiguous (no gaps)
|
||||
// - underlying buffer size should have the same size as the array
|
||||
// - cumulative product of shapes is equal to the strides (we can ignore axes
|
||||
// with size == 1)
|
||||
// - in the forward direction (column contiguous)
|
||||
// - in the reverse direction (row contiguous)
|
||||
// - vectors are both row and col contiguous (hence if both row/col are
|
||||
// true, they stay true)
|
||||
auto flags = in.flags();
|
||||
if (flags.contiguous && in.data_size() == in.size()) {
|
||||
auto [_, rc, cc] = check_contiguity(out.shape(), out_strides);
|
||||
flags.row_contiguous = rc;
|
||||
flags.col_contiguous = cc;
|
||||
}
|
||||
out.copy_shared_buffer(in, out_strides, flags, in.data_size());
|
||||
}
|
||||
|
||||
void as_transposed(array& out, const std::vector<int>& axes) {
|
||||
assert(out.data_size() == out.size() && out.flags().contiguous);
|
||||
|
||||
// Calculate the contiguous strides.
|
||||
Strides strides(out.ndim(), 1);
|
||||
for (int i = out.ndim() - 2; i >= 0; i--) {
|
||||
strides[i] = strides[i + 1] * out.shape(i);
|
||||
}
|
||||
|
||||
// Calculate the new strides for transposing.
|
||||
Strides new_strides;
|
||||
new_strides.reserve(out.ndim());
|
||||
for (auto ax : axes) {
|
||||
new_strides.push_back(strides[ax]);
|
||||
}
|
||||
|
||||
auto [ds, rc, cc] = check_contiguity(out.shape(), new_strides);
|
||||
auto flags = out.flags();
|
||||
flags.row_contiguous = rc;
|
||||
flags.col_contiguous = cc;
|
||||
|
||||
out.copy_shared_buffer(out, new_strides, flags, ds);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
12
mlx/backend/common/transpose.h
Normal file
12
mlx/backend/common/transpose.h
Normal file
@@ -0,0 +1,12 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/array.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void transpose(const array& in, array& out, const std::vector<int>& axes);
|
||||
void as_transposed(array& out, const std::vector<int>& axes);
|
||||
|
||||
} // namespace mlx::core
|
@@ -132,6 +132,11 @@ struct ContiguousIterator {
|
||||
};
|
||||
|
||||
inline auto check_contiguity(const Shape& shape, const Strides& strides) {
|
||||
// Conditions for {row/col}_contiguous
|
||||
// - cumulative product of shapes is equal to the strides (we can ignore axes
|
||||
// with size == 1)
|
||||
// - in the forward direction (column contiguous)
|
||||
// - in the reverse direction (row contiguous)
|
||||
size_t no_broadcast_data_size = 1;
|
||||
int64_t f_stride = 1;
|
||||
int64_t b_stride = 1;
|
||||
|
@@ -40,7 +40,8 @@ add_dependencies(mlx cpu_compiled_preamble)
|
||||
|
||||
target_sources(
|
||||
mlx
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cpp
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/available.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
|
||||
|
11
mlx/backend/cpu/available.cpp
Normal file
11
mlx/backend/cpu/available.cpp
Normal file
@@ -0,0 +1,11 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cpu/available.h"
|
||||
|
||||
namespace mlx::core::cpu {
|
||||
|
||||
bool is_available() {
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace mlx::core::cpu
|
9
mlx/backend/cpu/available.h
Normal file
9
mlx/backend/cpu/available.h
Normal file
@@ -0,0 +1,9 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace mlx::core::cpu {
|
||||
|
||||
bool is_available();
|
||||
|
||||
} // namespace mlx::core::cpu
|
@@ -172,9 +172,12 @@ void binary_float(
|
||||
case bfloat16:
|
||||
binary_op<bfloat16_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case complex64:
|
||||
binary_op<complex64_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
"[binary_float] Only supports non-complex floating point types.");
|
||||
"[binary_float] Only supports floating point types.");
|
||||
}
|
||||
});
|
||||
}
|
||||
|
@@ -40,7 +40,10 @@ struct CompilerCache {
|
||||
std::shared_mutex mtx;
|
||||
};
|
||||
|
||||
static CompilerCache cache{};
|
||||
static CompilerCache& cache() {
|
||||
static CompilerCache cache_;
|
||||
return cache_;
|
||||
};
|
||||
|
||||
// GPU compile is always available if the GPU is available and since we are in
|
||||
// this file CPU compile is also available.
|
||||
@@ -56,14 +59,16 @@ void* compile(
|
||||
const std::string& kernel_name,
|
||||
const std::function<std::string(void)>& source_builder) {
|
||||
{
|
||||
std::shared_lock lock(cache.mtx);
|
||||
if (auto it = cache.kernels.find(kernel_name); it != cache.kernels.end()) {
|
||||
std::shared_lock lock(cache().mtx);
|
||||
if (auto it = cache().kernels.find(kernel_name);
|
||||
it != cache().kernels.end()) {
|
||||
return it->second;
|
||||
}
|
||||
}
|
||||
|
||||
std::unique_lock lock(cache.mtx);
|
||||
if (auto it = cache.kernels.find(kernel_name); it != cache.kernels.end()) {
|
||||
std::unique_lock lock(cache().mtx);
|
||||
if (auto it = cache().kernels.find(kernel_name);
|
||||
it != cache().kernels.end()) {
|
||||
return it->second;
|
||||
}
|
||||
std::string source_code = source_builder();
|
||||
@@ -120,10 +125,10 @@ void* compile(
|
||||
}
|
||||
|
||||
// load library
|
||||
cache.libs.emplace_back(shared_lib_path);
|
||||
cache().libs.emplace_back(shared_lib_path);
|
||||
|
||||
// Load function
|
||||
void* fun = dlsym(cache.libs.back().lib, kernel_name.c_str());
|
||||
void* fun = dlsym(cache().libs.back().lib, kernel_name.c_str());
|
||||
if (!fun) {
|
||||
std::ostringstream msg;
|
||||
msg << "[Compile::eval_cpu] Failed to load compiled function "
|
||||
@@ -131,7 +136,7 @@ void* compile(
|
||||
<< dlerror();
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
cache.kernels.insert({kernel_name, fun});
|
||||
cache().kernels.insert({kernel_name, fun});
|
||||
return fun;
|
||||
}
|
||||
|
||||
|
@@ -330,7 +330,8 @@ void Scan::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
reduce_type_, in, out, axis_, reverse_, inclusive_);
|
||||
break;
|
||||
case complex64:
|
||||
throw std::runtime_error("Scan ops do not support complex types yet");
|
||||
scan_dispatch<complex64_t, complex64_t>(
|
||||
reduce_type_, in, out, axis_, reverse_, inclusive_);
|
||||
break;
|
||||
}
|
||||
});
|
||||
|
@@ -88,12 +88,33 @@ DEFAULT_UNARY(expm1, std::expm1)
|
||||
DEFAULT_UNARY(floor, std::floor)
|
||||
DEFAULT_UNARY(log, std::log)
|
||||
DEFAULT_UNARY(log10, std::log10)
|
||||
DEFAULT_UNARY(log1p, std::log1p)
|
||||
DEFAULT_UNARY(sinh, std::sinh)
|
||||
DEFAULT_UNARY(sqrt, std::sqrt)
|
||||
DEFAULT_UNARY(tan, std::tan)
|
||||
DEFAULT_UNARY(tanh, std::tanh)
|
||||
|
||||
template <typename T>
|
||||
Simd<T, 1> log1p(Simd<T, 1> in) {
|
||||
if constexpr (is_complex<T>) {
|
||||
auto x = in.value.real();
|
||||
auto y = in.value.imag();
|
||||
auto zabs = std::abs(in.value);
|
||||
auto theta = std::atan2(y, x + 1);
|
||||
if (zabs < 0.5) {
|
||||
auto r = x * (2 + x) + y * y;
|
||||
if (r == 0) { // handle underflow
|
||||
return Simd<T, 1>{T{x, theta}};
|
||||
}
|
||||
return Simd<T, 1>{T{((typeof(x))(0.5)) * std::log1p(r), theta}};
|
||||
} else {
|
||||
auto z0 = std::hypot(x + 1, y);
|
||||
return Simd<T, 1>{T{std::log(z0), theta}};
|
||||
}
|
||||
} else {
|
||||
return Simd<T, 1>{std::log1p(in.value)};
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Simd<T, 1> log2(Simd<T, 1> in) {
|
||||
if constexpr (is_complex<T>) {
|
||||
|
57
mlx/backend/cuda/CMakeLists.txt
Normal file
57
mlx/backend/cuda/CMakeLists.txt
Normal file
@@ -0,0 +1,57 @@
|
||||
# Filename rules in cuda backend:
|
||||
#
|
||||
# * Use .cu/.cuh if code contains device code, and .cpp/.h if not.
|
||||
# * Device-only kernel code should be put in kernels/ subdir.
|
||||
# * Files in kernels/ subdir should not include files outside.
|
||||
target_sources(
|
||||
mlx
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/event.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fence.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp)
|
||||
|
||||
target_compile_definitions(mlx PUBLIC MLX_USE_CUDA)
|
||||
|
||||
# Enable defining device lambda functions.
|
||||
target_compile_options(mlx
|
||||
PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--extended-lambda>")
|
||||
|
||||
# Compute capability 7 is required for synchronization between CPU/GPU with
|
||||
# managed memory. TODO: Add more architectures for potential performance gain.
|
||||
set(MLX_CUDA_ARCHITECTURES
|
||||
"75;80"
|
||||
CACHE STRING "CUDA architectures")
|
||||
message(STATUS "CUDA architectures: ${MLX_CUDA_ARCHITECTURES}")
|
||||
set_target_properties(mlx PROPERTIES CUDA_ARCHITECTURES
|
||||
"${MLX_CUDA_ARCHITECTURES}")
|
||||
|
||||
# Use fixed version of CCCL.
|
||||
FetchContent_Declare(
|
||||
cccl
|
||||
URL "https://github.com/NVIDIA/cccl/releases/download/v2.8.1/cccl-v2.8.1.zip")
|
||||
FetchContent_MakeAvailable(cccl)
|
||||
target_include_directories(mlx PRIVATE BEFORE "${cccl_SOURCE_DIR}/include")
|
||||
|
||||
# Use fixed version of NVTX.
|
||||
FetchContent_Declare(
|
||||
nvtx3
|
||||
GIT_REPOSITORY https://github.com/NVIDIA/NVTX.git
|
||||
GIT_TAG v3.1.1
|
||||
GIT_SHALLOW TRUE
|
||||
SOURCE_SUBDIR c EXCLUDE_FROM_ALL)
|
||||
FetchContent_MakeAvailable(nvtx3)
|
||||
target_link_libraries(mlx PUBLIC $<BUILD_INTERFACE:nvtx3-cpp>)
|
||||
|
||||
# Make cuda runtime APIs available in non-cuda files.
|
||||
find_package(CUDAToolkit REQUIRED)
|
||||
target_include_directories(mlx PRIVATE ${CUDAToolkit_INCLUDE_DIRS})
|
||||
|
||||
# Suppress nvcc warnings on MLX headers.
|
||||
target_compile_options(mlx PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xcudafe
|
||||
--diag_suppress=997>)
|
154
mlx/backend/cuda/allocator.cpp
Normal file
154
mlx/backend/cuda/allocator.cpp
Normal file
@@ -0,0 +1,154 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/allocator.h"
|
||||
#include "mlx/backend/cuda/utils.h"
|
||||
#include "mlx/backend/cuda/worker.h"
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include <fmt/format.h>
|
||||
|
||||
#include <cassert>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace cu {
|
||||
|
||||
CudaAllocator::CudaAllocator() {
|
||||
// TODO: Set memory limit for multi-device.
|
||||
size_t free, total;
|
||||
CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total));
|
||||
memory_limit_ = total * 0.8;
|
||||
}
|
||||
|
||||
Buffer CudaAllocator::malloc(size_t size) {
|
||||
// TODO: Check memory limit.
|
||||
auto* buf = new CudaBuffer{nullptr, size};
|
||||
cudaError_t err = cudaMallocManaged(&buf->data, size);
|
||||
if (err != cudaSuccess && err != cudaErrorMemoryAllocation) {
|
||||
throw std::runtime_error(
|
||||
fmt::format("cudaMallocManaged failed: {}.", cudaGetErrorString(err)));
|
||||
}
|
||||
std::lock_guard lock(mutex_);
|
||||
active_memory_ += size;
|
||||
peak_memory_ = std::max(active_memory_, peak_memory_);
|
||||
return Buffer{buf};
|
||||
}
|
||||
|
||||
void CudaAllocator::free(Buffer buffer) {
|
||||
auto* buf = static_cast<CudaBuffer*>(buffer.ptr());
|
||||
if (!buf) {
|
||||
return;
|
||||
}
|
||||
|
||||
// If free() is called from a unregistered thread, reschedule the call to
|
||||
// worker.
|
||||
{
|
||||
std::lock_guard lock(worker_mutex_);
|
||||
if (allowed_threads_.count(std::this_thread::get_id()) == 0) {
|
||||
if (!worker_) {
|
||||
worker_.reset(new Worker);
|
||||
}
|
||||
worker_->add_task([buffer]() { allocator().free(buffer); });
|
||||
worker_->end_batch();
|
||||
worker_->commit();
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
size_t size = buf->size;
|
||||
cudaFree(buf->data);
|
||||
delete buf;
|
||||
std::lock_guard lock(mutex_);
|
||||
active_memory_ -= size;
|
||||
}
|
||||
|
||||
size_t CudaAllocator::size(Buffer buffer) const {
|
||||
auto* buf = static_cast<CudaBuffer*>(buffer.ptr());
|
||||
if (!buf) {
|
||||
return 0;
|
||||
}
|
||||
return buf->size;
|
||||
}
|
||||
|
||||
void CudaAllocator::register_this_thread() {
|
||||
std::lock_guard lock(worker_mutex_);
|
||||
allowed_threads_.insert(std::this_thread::get_id());
|
||||
}
|
||||
|
||||
size_t CudaAllocator::get_active_memory() const {
|
||||
return active_memory_;
|
||||
}
|
||||
|
||||
size_t CudaAllocator::get_peak_memory() const {
|
||||
return peak_memory_;
|
||||
}
|
||||
|
||||
void CudaAllocator::reset_peak_memory() {
|
||||
std::lock_guard lock(mutex_);
|
||||
peak_memory_ = 0;
|
||||
}
|
||||
|
||||
size_t CudaAllocator::get_memory_limit() {
|
||||
return memory_limit_;
|
||||
}
|
||||
|
||||
size_t CudaAllocator::set_memory_limit(size_t limit) {
|
||||
std::lock_guard lock(mutex_);
|
||||
std::swap(limit, memory_limit_);
|
||||
return limit;
|
||||
}
|
||||
|
||||
CudaAllocator& allocator() {
|
||||
// By creating the |allocator_| on heap, the destructor of CudaAllocator
|
||||
// will not be called on exit and buffers in the cache will be leaked. This
|
||||
// can save some time at program exit.
|
||||
static CudaAllocator* allocator_ = new CudaAllocator;
|
||||
return *allocator_;
|
||||
}
|
||||
|
||||
} // namespace cu
|
||||
|
||||
namespace allocator {
|
||||
|
||||
Allocator& allocator() {
|
||||
return cu::allocator();
|
||||
}
|
||||
|
||||
void* Buffer::raw_ptr() {
|
||||
if (!ptr_) {
|
||||
return nullptr;
|
||||
}
|
||||
return static_cast<cu::CudaBuffer*>(ptr_)->data;
|
||||
}
|
||||
|
||||
} // namespace allocator
|
||||
|
||||
size_t get_active_memory() {
|
||||
return cu::allocator().get_active_memory();
|
||||
}
|
||||
size_t get_peak_memory() {
|
||||
return cu::allocator().get_peak_memory();
|
||||
}
|
||||
void reset_peak_memory() {
|
||||
return cu::allocator().reset_peak_memory();
|
||||
}
|
||||
size_t set_memory_limit(size_t limit) {
|
||||
return cu::allocator().set_memory_limit(limit);
|
||||
}
|
||||
size_t get_memory_limit() {
|
||||
return cu::allocator().get_memory_limit();
|
||||
}
|
||||
|
||||
// TODO: Implement buffer cache.
|
||||
size_t get_cache_memory() {
|
||||
return 0;
|
||||
}
|
||||
size_t set_cache_limit(size_t) {
|
||||
return 0;
|
||||
}
|
||||
size_t set_wired_limit(size_t) {
|
||||
return 0;
|
||||
}
|
||||
void clear_cache() {}
|
||||
|
||||
} // namespace mlx::core
|
58
mlx/backend/cuda/allocator.h
Normal file
58
mlx/backend/cuda/allocator.h
Normal file
@@ -0,0 +1,58 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
|
||||
#include <mutex>
|
||||
#include <set>
|
||||
#include <thread>
|
||||
#include <utility>
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
class Worker;
|
||||
|
||||
using allocator::Buffer;
|
||||
|
||||
// Stores cuda-managed unified memory.
|
||||
struct CudaBuffer {
|
||||
void* data;
|
||||
size_t size;
|
||||
};
|
||||
|
||||
class CudaAllocator : public allocator::Allocator {
|
||||
public:
|
||||
Buffer malloc(size_t size) override;
|
||||
void free(Buffer buffer) override;
|
||||
size_t size(Buffer buffer) const override;
|
||||
|
||||
// Register current thread as safe to free buffers.
|
||||
// In cuda freeing a buffer implicitly synchronizes stream, and for threads
|
||||
// that may be waited by gpu stream (for example cpu stream threads), freeing
|
||||
// buffers there would result in dead lock.
|
||||
void register_this_thread();
|
||||
|
||||
size_t get_active_memory() const;
|
||||
size_t get_peak_memory() const;
|
||||
void reset_peak_memory();
|
||||
size_t get_memory_limit();
|
||||
size_t set_memory_limit(size_t limit);
|
||||
|
||||
private:
|
||||
CudaAllocator();
|
||||
friend CudaAllocator& allocator();
|
||||
|
||||
std::mutex worker_mutex_;
|
||||
std::unique_ptr<Worker> worker_;
|
||||
std::set<std::thread::id> allowed_threads_;
|
||||
|
||||
std::mutex mutex_;
|
||||
size_t memory_limit_;
|
||||
size_t active_memory_{0};
|
||||
size_t peak_memory_{0};
|
||||
};
|
||||
|
||||
CudaAllocator& allocator();
|
||||
|
||||
} // namespace mlx::core::cu
|
26
mlx/backend/cuda/copy.cpp
Normal file
26
mlx/backend/cuda/copy.cpp
Normal file
@@ -0,0 +1,26 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void copy_gpu_inplace(
|
||||
const array& in,
|
||||
array& out,
|
||||
const Shape& data_shape,
|
||||
const Strides& strides_in_pre,
|
||||
const Strides& strides_out_pre,
|
||||
int64_t inp_offset,
|
||||
int64_t out_offset,
|
||||
CopyType ctype,
|
||||
const Stream& s,
|
||||
const std::optional<array>& dynamic_i_offset /* = std::nullopt */,
|
||||
const std::optional<array>& dynamic_o_offset /* = std::nullopt */) {
|
||||
throw std::runtime_error("copy_gpu_inplace not implemented in CUDA backend.");
|
||||
}
|
||||
|
||||
void fill_gpu(const array& val, array& out, const Stream& s) {
|
||||
throw std::runtime_error("fill_gpu not implemented in CUDA backend.");
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
117
mlx/backend/cuda/device.cpp
Normal file
117
mlx/backend/cuda/device.cpp
Normal file
@@ -0,0 +1,117 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/worker.h"
|
||||
#include "mlx/backend/metal/metal.h"
|
||||
|
||||
#include <fmt/format.h>
|
||||
#include <nvtx3/nvtx3.hpp>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace cu {
|
||||
|
||||
DeviceStream::DeviceStream(Device& device) : device_(device), stream_(device) {}
|
||||
|
||||
void DeviceStream::synchronize() {
|
||||
cudaStreamSynchronize(stream_);
|
||||
}
|
||||
|
||||
cudaStream_t DeviceStream::schedule_cuda_stream() {
|
||||
// TODO: Return a stream that maximizes parallelism.
|
||||
return stream_;
|
||||
}
|
||||
|
||||
cudaStream_t DeviceStream::last_cuda_stream() {
|
||||
return stream_;
|
||||
}
|
||||
|
||||
CommandEncoder& DeviceStream::get_encoder() {
|
||||
if (!encoder_) {
|
||||
encoder_ = std::make_unique<CommandEncoder>(*this);
|
||||
}
|
||||
return *encoder_;
|
||||
}
|
||||
|
||||
Device::Device(int device) : device_(device) {
|
||||
// Validate the requirements of device.
|
||||
int attr = 0;
|
||||
cudaDeviceGetAttribute(&attr, cudaDevAttrConcurrentManagedAccess, device_);
|
||||
if (attr != 1) {
|
||||
throw std::runtime_error(fmt::format(
|
||||
"Device {} does not support synchronization in managed memory.",
|
||||
device_));
|
||||
}
|
||||
}
|
||||
|
||||
void Device::make_current() {
|
||||
// We need to set/get current CUDA device very frequently, cache it to reduce
|
||||
// actual calls of CUDA APIs. This function assumes single-thread in host.
|
||||
static int current = 0;
|
||||
if (current != device_) {
|
||||
CHECK_CUDA_ERROR(cudaSetDevice(device_));
|
||||
current = device_;
|
||||
}
|
||||
}
|
||||
|
||||
DeviceStream& Device::get_stream(Stream s) {
|
||||
auto it = streams_.find(s.index);
|
||||
if (it == streams_.end()) {
|
||||
it = streams_.try_emplace(s.index, *this).first;
|
||||
}
|
||||
return it->second;
|
||||
}
|
||||
|
||||
CommandEncoder::CommandEncoder(DeviceStream& s)
|
||||
: device_(s.device()), stream_(s) {}
|
||||
|
||||
void CommandEncoder::add_completed_handler(std::function<void()> task) {
|
||||
worker_.add_task(std::move(task));
|
||||
}
|
||||
|
||||
void CommandEncoder::end_encoding() {
|
||||
if (!temporaries_.empty()) {
|
||||
add_completed_handler([temporaries = std::move(temporaries_)]() {});
|
||||
}
|
||||
|
||||
// There is no kernel running, run completion handlers immediately.
|
||||
if (!has_gpu_work_) {
|
||||
worker_.consume_in_this_thread();
|
||||
return;
|
||||
}
|
||||
has_gpu_work_ = false;
|
||||
|
||||
// Put completion handlers in a batch.
|
||||
worker_.end_batch();
|
||||
|
||||
// Signaling kernel completion is expensive, delay until enough batches.
|
||||
// TODO: This number is arbitrarily picked, profile for a better stragety.
|
||||
if (worker_.uncommited_batches() > 8) {
|
||||
commit();
|
||||
}
|
||||
}
|
||||
|
||||
void CommandEncoder::commit() {
|
||||
worker_.commit(stream_.last_cuda_stream());
|
||||
}
|
||||
|
||||
Device& device(mlx::core::Device device) {
|
||||
static std::unordered_map<int, Device> devices;
|
||||
auto it = devices.find(device.index);
|
||||
if (it == devices.end()) {
|
||||
it = devices.try_emplace(device.index, device.index).first;
|
||||
}
|
||||
return it->second;
|
||||
}
|
||||
|
||||
DeviceStream& get_stream(Stream s) {
|
||||
return device(s.device).get_stream(s);
|
||||
}
|
||||
|
||||
CommandEncoder& get_command_encoder(Stream s) {
|
||||
return get_stream(s).get_encoder();
|
||||
}
|
||||
|
||||
} // namespace cu
|
||||
|
||||
} // namespace mlx::core
|
131
mlx/backend/cuda/device.h
Normal file
131
mlx/backend/cuda/device.h
Normal file
@@ -0,0 +1,131 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/backend/cuda/worker.h"
|
||||
#include "mlx/stream.h"
|
||||
|
||||
#include <thrust/execution_policy.h>
|
||||
|
||||
#include <unordered_map>
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
class Device;
|
||||
class CommandEncoder;
|
||||
|
||||
class DeviceStream {
|
||||
public:
|
||||
explicit DeviceStream(Device& device);
|
||||
|
||||
DeviceStream(const DeviceStream&) = delete;
|
||||
DeviceStream& operator=(const DeviceStream&) = delete;
|
||||
|
||||
// Wait until kernels in the stream complete.
|
||||
void synchronize();
|
||||
|
||||
// Return a cuda stream for launching kernels.
|
||||
cudaStream_t schedule_cuda_stream();
|
||||
|
||||
// Return the last cuda stream used.
|
||||
cudaStream_t last_cuda_stream();
|
||||
|
||||
CommandEncoder& get_encoder();
|
||||
|
||||
Device& device() {
|
||||
return device_;
|
||||
}
|
||||
|
||||
private:
|
||||
Device& device_;
|
||||
CudaStream stream_;
|
||||
std::unique_ptr<CommandEncoder> encoder_;
|
||||
};
|
||||
|
||||
class Device {
|
||||
public:
|
||||
explicit Device(int device);
|
||||
|
||||
Device(const Device&) = delete;
|
||||
Device& operator=(const Device&) = delete;
|
||||
|
||||
// Make this device the current cuda device, required by some cuda calls.
|
||||
void make_current();
|
||||
|
||||
DeviceStream& get_stream(Stream s);
|
||||
|
||||
int cuda_device() const {
|
||||
return device_;
|
||||
}
|
||||
|
||||
private:
|
||||
int device_;
|
||||
std::unordered_map<int, DeviceStream> streams_;
|
||||
};
|
||||
|
||||
class CommandEncoder {
|
||||
public:
|
||||
explicit CommandEncoder(DeviceStream& stream);
|
||||
|
||||
CommandEncoder(const CommandEncoder&) = delete;
|
||||
CommandEncoder& operator=(const CommandEncoder&) = delete;
|
||||
|
||||
void set_input_array(const array& arr) {}
|
||||
void set_output_array(const array& arr) {}
|
||||
|
||||
void add_temporary(const array& arr) {
|
||||
temporaries_.push_back(arr.data_shared_ptr());
|
||||
}
|
||||
|
||||
void add_completed_handler(std::function<void()> task);
|
||||
void end_encoding();
|
||||
void commit();
|
||||
|
||||
// Schedule a cuda stream for |fun| to launch kernels, and check error
|
||||
// afterwards.
|
||||
template <typename F>
|
||||
void launch_kernel(F&& fun) {
|
||||
launch_kernel(stream_.schedule_cuda_stream(), std::forward<F>(fun));
|
||||
}
|
||||
|
||||
template <typename F>
|
||||
void launch_kernel(cudaStream_t stream, F&& fun) {
|
||||
device_.make_current();
|
||||
fun(stream);
|
||||
check_cuda_error("kernel launch", cudaGetLastError());
|
||||
has_gpu_work_ = true;
|
||||
}
|
||||
|
||||
Device& device() {
|
||||
return device_;
|
||||
}
|
||||
|
||||
DeviceStream& stream() {
|
||||
return stream_;
|
||||
}
|
||||
|
||||
bool has_gpu_work() const {
|
||||
return has_gpu_work_;
|
||||
}
|
||||
|
||||
private:
|
||||
Device& device_;
|
||||
DeviceStream& stream_;
|
||||
Worker worker_;
|
||||
bool has_gpu_work_{false};
|
||||
std::vector<std::shared_ptr<array::Data>> temporaries_;
|
||||
};
|
||||
|
||||
Device& device(mlx::core::Device device);
|
||||
DeviceStream& get_stream(Stream s);
|
||||
CommandEncoder& get_command_encoder(Stream s);
|
||||
|
||||
// Return an execution policy that does not sync for result.
|
||||
// Note that not all thrust APIs support async policy, confirm before using.
|
||||
inline auto thrust_policy(cudaStream_t stream) {
|
||||
// TODO: Connect thrust's custom allocator with mlx's allocator.
|
||||
return thrust::cuda::par_nosync.on(stream);
|
||||
}
|
||||
|
||||
} // namespace mlx::core::cu
|
35
mlx/backend/cuda/dtype_utils.cuh
Normal file
35
mlx/backend/cuda/dtype_utils.cuh
Normal file
@@ -0,0 +1,35 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cuComplex.h>
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp16.h>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
// Maps CPU types to CUDA types.
|
||||
template <typename T>
|
||||
struct CTypeToCudaType {
|
||||
using type = T;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct CTypeToCudaType<float16_t> {
|
||||
using type = __half;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct CTypeToCudaType<bfloat16_t> {
|
||||
using type = __nv_bfloat16;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct CTypeToCudaType<complex64_t> {
|
||||
using type = cuComplex;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
using cuda_type_t = typename CTypeToCudaType<T>::type;
|
||||
|
||||
} // namespace mlx::core
|
68
mlx/backend/cuda/eval.cpp
Normal file
68
mlx/backend/cuda/eval.cpp
Normal file
@@ -0,0 +1,68 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/gpu/eval.h"
|
||||
#include "mlx/backend/cuda/allocator.h"
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/gpu/available.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
#include <nvtx3/nvtx3.hpp>
|
||||
|
||||
namespace mlx::core::gpu {
|
||||
|
||||
bool is_available() {
|
||||
return true;
|
||||
}
|
||||
|
||||
void new_stream(Stream s) {
|
||||
// Force initalization of cuda, so cuda runtime get destroyed at last.
|
||||
cudaFree(nullptr);
|
||||
// Ensure the static stream objects get created.
|
||||
cu::get_command_encoder(s);
|
||||
// The main thread is safe to free buffers.
|
||||
cu::allocator().register_this_thread();
|
||||
}
|
||||
|
||||
void eval(array& arr) {
|
||||
nvtx3::scoped_range r("gpu::eval");
|
||||
auto outputs = arr.outputs();
|
||||
{
|
||||
// If the array is a tracer hold a reference
|
||||
// to its inputs so they don't get donated
|
||||
std::vector<array> inputs;
|
||||
if (arr.is_tracer()) {
|
||||
inputs = arr.inputs();
|
||||
}
|
||||
arr.primitive().eval_gpu(arr.inputs(), outputs);
|
||||
}
|
||||
|
||||
auto& encoder = cu::get_command_encoder(arr.primitive().stream());
|
||||
if (encoder.has_gpu_work()) {
|
||||
// Keep used buffers alive until kernel finishes running.
|
||||
std::unordered_set<std::shared_ptr<array::Data>> buffers;
|
||||
for (auto& in : arr.inputs()) {
|
||||
buffers.insert(in.data_shared_ptr());
|
||||
}
|
||||
for (auto& s : arr.siblings()) {
|
||||
buffers.insert(s.data_shared_ptr());
|
||||
}
|
||||
// Remove the output if it was donated to by an input.
|
||||
if (auto it = buffers.find(arr.data_shared_ptr()); it != buffers.end()) {
|
||||
buffers.erase(it);
|
||||
}
|
||||
encoder.add_completed_handler([buffers = std::move(buffers)]() {});
|
||||
}
|
||||
encoder.end_encoding();
|
||||
}
|
||||
|
||||
void finalize(Stream s) {
|
||||
nvtx3::scoped_range r("gpu::finalize");
|
||||
cu::get_command_encoder(s).commit();
|
||||
}
|
||||
|
||||
void synchronize(Stream s) {
|
||||
nvtx3::scoped_range r("gpu::synchronize");
|
||||
cu::get_stream(s).synchronize();
|
||||
}
|
||||
|
||||
} // namespace mlx::core::gpu
|
265
mlx/backend/cuda/event.cu
Normal file
265
mlx/backend/cuda/event.cu
Normal file
@@ -0,0 +1,265 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/event.h"
|
||||
#include "mlx/backend/cuda/utils.h"
|
||||
#include "mlx/event.h"
|
||||
#include "mlx/scheduler.h"
|
||||
|
||||
#include <nvtx3/nvtx3.hpp>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace cu {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// CudaEvent implementations
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Cuda event managed with RAII.
|
||||
class CudaEventHandle {
|
||||
public:
|
||||
CudaEventHandle() {
|
||||
CHECK_CUDA_ERROR(cudaEventCreateWithFlags(
|
||||
&event_, cudaEventDisableTiming | cudaEventBlockingSync));
|
||||
}
|
||||
|
||||
~CudaEventHandle() {
|
||||
CHECK_CUDA_ERROR(cudaEventDestroy(event_));
|
||||
}
|
||||
|
||||
CudaEventHandle(const CudaEventHandle&) = delete;
|
||||
CudaEventHandle& operator=(const CudaEventHandle&) = delete;
|
||||
|
||||
operator cudaEvent_t() const {
|
||||
return event_;
|
||||
}
|
||||
|
||||
private:
|
||||
cudaEvent_t event_;
|
||||
};
|
||||
|
||||
CudaEvent::CudaEvent() : event_(std::make_shared<CudaEventHandle>()) {}
|
||||
|
||||
void CudaEvent::wait() {
|
||||
nvtx3::scoped_range r("cu::CudaEvent::wait");
|
||||
if (!recorded_) {
|
||||
throw std::runtime_error("Should not wait on a CudaEvent before record.");
|
||||
}
|
||||
cudaEventSynchronize(*event_);
|
||||
}
|
||||
|
||||
void CudaEvent::wait(cudaStream_t stream) {
|
||||
if (!recorded_) {
|
||||
throw std::runtime_error("Should not wait on a CudaEvent before record.");
|
||||
}
|
||||
cudaStreamWaitEvent(stream, *event_);
|
||||
}
|
||||
|
||||
void CudaEvent::wait(Stream s) {
|
||||
if (s.device == mlx::core::Device::cpu) {
|
||||
scheduler::enqueue(s, [*this]() mutable { wait(); });
|
||||
} else {
|
||||
wait(cu::get_stream(s).last_cuda_stream());
|
||||
}
|
||||
}
|
||||
|
||||
void CudaEvent::record(cudaStream_t stream) {
|
||||
cudaEventRecord(*event_, stream);
|
||||
recorded_ = true;
|
||||
}
|
||||
|
||||
void CudaEvent::record(Stream s) {
|
||||
if (s.device == mlx::core::Device::cpu) {
|
||||
throw std::runtime_error("CudaEvent can not wait on cpu stream.");
|
||||
} else {
|
||||
record(cu::get_stream(s).last_cuda_stream());
|
||||
}
|
||||
}
|
||||
|
||||
bool CudaEvent::completed() const {
|
||||
return cudaEventQuery(*event_) == cudaSuccess;
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// SharedEvent implementations
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace {
|
||||
|
||||
__host__ __device__ void event_wait(SharedEvent::Atomic* ac, uint64_t value) {
|
||||
uint64_t current;
|
||||
while ((current = ac->load()) < value) {
|
||||
ac->wait(current);
|
||||
}
|
||||
}
|
||||
|
||||
__host__ __device__ void event_signal(SharedEvent::Atomic* ac, uint64_t value) {
|
||||
ac->store(value);
|
||||
ac->notify_all();
|
||||
}
|
||||
|
||||
__global__ void event_wait_kernel(SharedEvent::Atomic* ac, uint64_t value) {
|
||||
event_wait(ac, value);
|
||||
}
|
||||
|
||||
__global__ void event_signal_kernel(SharedEvent::Atomic* ac, uint64_t value) {
|
||||
event_signal(ac, value);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
SharedEvent::SharedEvent() {
|
||||
// Allocate cuda::atomic on managed memory.
|
||||
allocator::Buffer buffer = allocator::malloc(sizeof(Atomic));
|
||||
Atomic* ac = static_cast<Atomic*>(buffer.raw_ptr());
|
||||
new (ac) Atomic(0);
|
||||
ac_ = std::shared_ptr<Atomic>(ac, [buffer](Atomic* ptr) {
|
||||
ptr->~Atomic();
|
||||
allocator::free(buffer);
|
||||
});
|
||||
}
|
||||
|
||||
void SharedEvent::wait(uint64_t value) {
|
||||
nvtx3::scoped_range r("cu::SharedEvent::wait");
|
||||
event_wait(ac_.get(), value);
|
||||
}
|
||||
|
||||
void SharedEvent::wait(cudaStream_t stream, uint64_t value) {
|
||||
event_wait_kernel<<<1, 1, 0, stream>>>(ac_.get(), value);
|
||||
}
|
||||
|
||||
void SharedEvent::wait(Stream s, uint64_t value) {
|
||||
nvtx3::scoped_range r("cu::SharedEvent::wait(s)");
|
||||
if (s.device == mlx::core::Device::cpu) {
|
||||
scheduler::enqueue(s, [*this, value]() mutable { wait(value); });
|
||||
} else {
|
||||
auto& encoder = get_command_encoder(s);
|
||||
encoder.launch_kernel(
|
||||
encoder.stream().last_cuda_stream(),
|
||||
[this, value](cudaStream_t stream) { wait(stream, value); });
|
||||
encoder.add_completed_handler([ac = ac_]() {});
|
||||
encoder.end_encoding();
|
||||
}
|
||||
}
|
||||
|
||||
void SharedEvent::signal(uint64_t value) {
|
||||
nvtx3::scoped_range r("cu::SharedEvent::signal");
|
||||
event_signal(ac_.get(), value);
|
||||
}
|
||||
|
||||
void SharedEvent::signal(cudaStream_t stream, uint64_t value) {
|
||||
event_signal_kernel<<<1, 1, 0, stream>>>(ac_.get(), value);
|
||||
}
|
||||
|
||||
void SharedEvent::signal(Stream s, uint64_t value) {
|
||||
nvtx3::scoped_range r("cu::SharedEvent::signal(s)");
|
||||
if (s.device == mlx::core::Device::cpu) {
|
||||
scheduler::enqueue(s, [*this, value]() mutable { signal(value); });
|
||||
} else {
|
||||
auto& encoder = get_command_encoder(s);
|
||||
encoder.launch_kernel(
|
||||
encoder.stream().last_cuda_stream(),
|
||||
[this, value](cudaStream_t stream) { signal(stream, value); });
|
||||
encoder.add_completed_handler([ac = ac_]() {});
|
||||
encoder.end_encoding();
|
||||
}
|
||||
}
|
||||
|
||||
bool SharedEvent::is_signaled(uint64_t value) const {
|
||||
nvtx3::scoped_range r("cu::SharedEvent::is_signaled");
|
||||
return ac_->load() >= value;
|
||||
}
|
||||
|
||||
uint64_t SharedEvent::value() const {
|
||||
nvtx3::scoped_range r("cu::SharedEvent::value");
|
||||
return ac_->load();
|
||||
}
|
||||
|
||||
} // namespace cu
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Event implementations
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace {
|
||||
|
||||
struct EventImpl {
|
||||
// CudaEvent is preferred when possible because it is fast, however we have
|
||||
// to fallback to SharedEvent in following cases:
|
||||
// 1. the event is used to wait/signal a cpu stream;
|
||||
// 2. signal value other than 1 has been specified.
|
||||
std::unique_ptr<cu::CudaEvent> cuda;
|
||||
std::unique_ptr<cu::SharedEvent> shared;
|
||||
|
||||
bool is_created() const {
|
||||
return cuda || shared;
|
||||
}
|
||||
|
||||
void ensure_created(Stream s, uint64_t signal_value) {
|
||||
if (is_created()) {
|
||||
return;
|
||||
}
|
||||
if (s.device == mlx::core::Device::cpu || signal_value > 1) {
|
||||
nvtx3::mark("Using slow SharedEvent");
|
||||
shared = std::make_unique<cu::SharedEvent>();
|
||||
} else {
|
||||
cuda = std::make_unique<cu::CudaEvent>();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
Event::Event(Stream s) : stream_(s) {
|
||||
event_ = std::shared_ptr<void>(
|
||||
new EventImpl(), [](void* ptr) { delete static_cast<EventImpl*>(ptr); });
|
||||
}
|
||||
|
||||
void Event::wait() {
|
||||
auto* event = static_cast<EventImpl*>(event_.get());
|
||||
assert(event->is_created());
|
||||
if (event->cuda) {
|
||||
assert(value() == 1);
|
||||
event->cuda->wait();
|
||||
} else {
|
||||
event->shared->wait(value());
|
||||
}
|
||||
}
|
||||
|
||||
void Event::wait(Stream s) {
|
||||
auto* event = static_cast<EventImpl*>(event_.get());
|
||||
assert(event->is_created());
|
||||
if (event->cuda) {
|
||||
assert(value() == 1);
|
||||
event->cuda->wait(s);
|
||||
} else {
|
||||
event->shared->wait(s, value());
|
||||
}
|
||||
}
|
||||
|
||||
void Event::signal(Stream s) {
|
||||
auto* event = static_cast<EventImpl*>(event_.get());
|
||||
event->ensure_created(s, value());
|
||||
if (event->cuda) {
|
||||
assert(value() == 1);
|
||||
event->cuda->record(s);
|
||||
} else {
|
||||
event->shared->signal(s, value());
|
||||
}
|
||||
}
|
||||
|
||||
bool Event::is_signaled() const {
|
||||
auto* event = static_cast<EventImpl*>(event_.get());
|
||||
if (!event->is_created()) {
|
||||
return false;
|
||||
}
|
||||
if (event->cuda) {
|
||||
assert(value() == 1);
|
||||
return event->cuda->recorded() && event->cuda->completed();
|
||||
} else {
|
||||
return event->shared->is_signaled(value());
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
66
mlx/backend/cuda/event.h
Normal file
66
mlx/backend/cuda/event.h
Normal file
@@ -0,0 +1,66 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/stream.h"
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include <cuda/atomic>
|
||||
|
||||
#include <memory>
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
class CudaEventHandle;
|
||||
|
||||
// Wrapper of native cuda event. It can synchronize between GPU streams, or wait
|
||||
// on GPU stream in CPU stream, but can not wait on CPU stream.
|
||||
class CudaEvent {
|
||||
public:
|
||||
CudaEvent();
|
||||
|
||||
void wait();
|
||||
void wait(cudaStream_t stream);
|
||||
void wait(Stream s);
|
||||
void record(cudaStream_t stream);
|
||||
void record(Stream s);
|
||||
|
||||
// Return whether the recorded kernels have completed. Note that this method
|
||||
// returns true if record() has not been called.
|
||||
bool completed() const;
|
||||
|
||||
bool recorded() const {
|
||||
return recorded_;
|
||||
}
|
||||
|
||||
private:
|
||||
bool recorded_{false};
|
||||
std::shared_ptr<CudaEventHandle> event_;
|
||||
};
|
||||
|
||||
// Event that can synchronize between CPU and GPU. It is much slower than
|
||||
// CudaEvent so the latter should always be preferred when possible.
|
||||
class SharedEvent {
|
||||
public:
|
||||
using Atomic = cuda::atomic<uint64_t>;
|
||||
|
||||
SharedEvent();
|
||||
|
||||
void wait(uint64_t value);
|
||||
void wait(cudaStream_t stream, uint64_t value);
|
||||
void wait(Stream s, uint64_t value);
|
||||
void signal(uint64_t value);
|
||||
void signal(cudaStream_t stream, uint64_t value);
|
||||
void signal(Stream s, uint64_t value);
|
||||
bool is_signaled(uint64_t value) const;
|
||||
uint64_t value() const;
|
||||
|
||||
const std::shared_ptr<Atomic>& atomic() const {
|
||||
return ac_;
|
||||
}
|
||||
|
||||
private:
|
||||
std::shared_ptr<Atomic> ac_;
|
||||
};
|
||||
|
||||
} // namespace mlx::core::cu
|
70
mlx/backend/cuda/fence.cu
Normal file
70
mlx/backend/cuda/fence.cu
Normal file
@@ -0,0 +1,70 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/event.h"
|
||||
#include "mlx/fence.h"
|
||||
#include "mlx/scheduler.h"
|
||||
|
||||
#include <nvtx3/nvtx3.hpp>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
__host__ __device__ void busy_wait(cuda::atomic<uint64_t>* ac, uint64_t value) {
|
||||
while (true) {
|
||||
// In theory the atomic_thread_fence is not needed, but for CUDA 11 without
|
||||
// it the load() may never return new value.
|
||||
cuda::atomic_thread_fence(cuda::memory_order_seq_cst);
|
||||
uint64_t current = ac->load();
|
||||
if (current >= value) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void busy_wait_kernel(cuda::atomic<uint64_t>* ac, uint64_t value) {
|
||||
busy_wait(ac, value);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
struct FenceImpl {
|
||||
uint32_t count;
|
||||
cu::SharedEvent event;
|
||||
};
|
||||
|
||||
Fence::Fence(Stream s) {
|
||||
fence_ = std::shared_ptr<void>(
|
||||
new FenceImpl{0}, [](void* ptr) { delete static_cast<FenceImpl*>(ptr); });
|
||||
}
|
||||
|
||||
void Fence::wait(Stream s, const array&) {
|
||||
auto* fence = static_cast<FenceImpl*>(fence_.get());
|
||||
// We can't use SharedEvent::wait because it could hang in CUDA 11, see also:
|
||||
// https://github.com/ml-explore/mlx/issues/2137
|
||||
const auto& ac = fence->event.atomic();
|
||||
if (s.device == mlx::core::Device::cpu) {
|
||||
scheduler::enqueue(s, [ac, count = fence->count]() {
|
||||
nvtx3::scoped_range r("Fence::wait()");
|
||||
busy_wait(ac.get(), count);
|
||||
});
|
||||
} else {
|
||||
nvtx3::scoped_range r("Fence::wait(s)");
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
encoder.launch_kernel(
|
||||
encoder.stream().last_cuda_stream(), [&](cudaStream_t stream) {
|
||||
busy_wait_kernel<<<1, 1, 0>>>(ac.get(), fence->count);
|
||||
});
|
||||
encoder.add_completed_handler([ac]() {});
|
||||
encoder.end_encoding();
|
||||
}
|
||||
}
|
||||
|
||||
void Fence::update(Stream s, const array&) {
|
||||
auto* fence = static_cast<FenceImpl*>(fence_.get());
|
||||
fence->count++;
|
||||
fence->event.signal(s, fence->count);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
15
mlx/backend/cuda/kernels/arange.cuh
Normal file
15
mlx/backend/cuda/kernels/arange.cuh
Normal file
@@ -0,0 +1,15 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
template <typename T>
|
||||
struct Arange {
|
||||
const T start;
|
||||
const T step;
|
||||
|
||||
__device__ T operator()(uint32_t i) const {
|
||||
return start + i * step;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace mlx::core::cu
|
107
mlx/backend/cuda/kernels/fp16_math.cuh
Normal file
107
mlx/backend/cuda/kernels/fp16_math.cuh
Normal file
@@ -0,0 +1,107 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda/std/limits>
|
||||
#include <cuda/std/type_traits>
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Missing C++ operator overrides for CUDA 7.
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#if CUDART_VERSION < 12000 && __CUDA_ARCH__ < 800
|
||||
|
||||
#define MLX_DEFINE_BF16_OP(OP) \
|
||||
__forceinline__ __device__ __nv_bfloat16 operator OP( \
|
||||
__nv_bfloat16 x, __nv_bfloat16 y) { \
|
||||
return __float2bfloat16(__bfloat162float(x) OP __bfloat162float(y)); \
|
||||
}
|
||||
|
||||
#define MLX_DEFINE_BF16_CMP(OP) \
|
||||
__forceinline__ __device__ bool operator OP( \
|
||||
__nv_bfloat16 x, __nv_bfloat16 y) { \
|
||||
return __float2bfloat16(__bfloat162float(x) OP __bfloat162float(y)); \
|
||||
}
|
||||
|
||||
MLX_DEFINE_BF16_OP(+)
|
||||
MLX_DEFINE_BF16_OP(-)
|
||||
MLX_DEFINE_BF16_OP(*)
|
||||
MLX_DEFINE_BF16_OP(/)
|
||||
MLX_DEFINE_BF16_CMP(>)
|
||||
MLX_DEFINE_BF16_CMP(<)
|
||||
MLX_DEFINE_BF16_CMP(>=)
|
||||
MLX_DEFINE_BF16_CMP(<=)
|
||||
|
||||
#undef MLX_DEFINE_BF16_OP
|
||||
#undef MLX_DEFINE_BF16_CMP
|
||||
|
||||
#endif // CUDART_VERSION < 12000 && __CUDA_ARCH__ < 800
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Additional C++ operator overrides between half types and native types.
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename T, typename U>
|
||||
constexpr bool is_integral_except =
|
||||
cuda::std::is_integral_v<T> && !cuda::std::is_same_v<T, U>;
|
||||
|
||||
template <typename T, typename U>
|
||||
constexpr bool is_arithmetic_except =
|
||||
cuda::std::is_arithmetic_v<T> && !cuda::std::is_same_v<T, U>;
|
||||
|
||||
#define MLX_DEFINE_HALF_OP(HALF, HALF2FLOAT, FLOAT2HALF, OP) \
|
||||
template < \
|
||||
typename T, \
|
||||
typename = cuda::std::enable_if_t<is_integral_except<T, HALF>>> \
|
||||
__forceinline__ __device__ HALF operator OP(HALF x, T y) { \
|
||||
return FLOAT2HALF(HALF2FLOAT(x) OP static_cast<float>(y)); \
|
||||
} \
|
||||
template < \
|
||||
typename T, \
|
||||
typename = cuda::std::enable_if_t<is_integral_except<T, HALF>>> \
|
||||
__forceinline__ __device__ HALF operator OP(T x, HALF y) { \
|
||||
return FLOAT2HALF(static_cast<float>(x) OP HALF2FLOAT(y)); \
|
||||
}
|
||||
|
||||
#define MLX_DEFINE_HALF_CMP(HALF, HALF2FLOAT, OP) \
|
||||
template < \
|
||||
typename T, \
|
||||
typename = cuda::std::enable_if_t<is_arithmetic_except<T, HALF>>> \
|
||||
__forceinline__ __device__ bool operator OP(HALF x, T y) { \
|
||||
return HALF2FLOAT(x) OP static_cast<float>(y); \
|
||||
} \
|
||||
template < \
|
||||
typename T, \
|
||||
typename = cuda::std::enable_if_t<is_arithmetic_except<T, HALF>>> \
|
||||
__forceinline__ __device__ bool operator OP(T x, HALF y) { \
|
||||
return static_cast<float>(y) OP HALF2FLOAT(x); \
|
||||
}
|
||||
|
||||
MLX_DEFINE_HALF_OP(__half, __half2float, __float2half, +)
|
||||
MLX_DEFINE_HALF_OP(__half, __half2float, __float2half, -)
|
||||
MLX_DEFINE_HALF_OP(__half, __half2float, __float2half, *)
|
||||
MLX_DEFINE_HALF_OP(__half, __half2float, __float2half, /)
|
||||
MLX_DEFINE_HALF_OP(__nv_bfloat16, __bfloat162float, __float2bfloat16, +)
|
||||
MLX_DEFINE_HALF_OP(__nv_bfloat16, __bfloat162float, __float2bfloat16, -)
|
||||
MLX_DEFINE_HALF_OP(__nv_bfloat16, __bfloat162float, __float2bfloat16, *)
|
||||
MLX_DEFINE_HALF_OP(__nv_bfloat16, __bfloat162float, __float2bfloat16, /)
|
||||
MLX_DEFINE_HALF_CMP(__half, __half2float, <)
|
||||
MLX_DEFINE_HALF_CMP(__half, __half2float, >)
|
||||
MLX_DEFINE_HALF_CMP(__half, __half2float, <=)
|
||||
MLX_DEFINE_HALF_CMP(__half, __half2float, >=)
|
||||
MLX_DEFINE_HALF_CMP(__half, __half2float, ==)
|
||||
MLX_DEFINE_HALF_CMP(__half, __half2float, !=)
|
||||
MLX_DEFINE_HALF_CMP(__nv_bfloat16, __bfloat162float, <)
|
||||
MLX_DEFINE_HALF_CMP(__nv_bfloat16, __bfloat162float, >)
|
||||
MLX_DEFINE_HALF_CMP(__nv_bfloat16, __bfloat162float, <=)
|
||||
MLX_DEFINE_HALF_CMP(__nv_bfloat16, __bfloat162float, >=)
|
||||
MLX_DEFINE_HALF_CMP(__nv_bfloat16, __bfloat162float, ==)
|
||||
MLX_DEFINE_HALF_CMP(__nv_bfloat16, __bfloat162float, !=)
|
||||
|
||||
#undef MLX_DEFINE_HALF_OP
|
||||
#undef MLX_DEFINE_HALF_CMP
|
||||
|
||||
} // namespace mlx::core::cu
|
163
mlx/backend/cuda/primitives.cu
Normal file
163
mlx/backend/cuda/primitives.cu
Normal file
@@ -0,0 +1,163 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/dtype_utils.cuh"
|
||||
#include "mlx/backend/cuda/kernels/arange.cuh"
|
||||
#include "mlx/backend/cuda/kernels/fp16_math.cuh"
|
||||
#include "mlx/distributed/primitives.h"
|
||||
#include "mlx/dtype_utils.h"
|
||||
#include "mlx/fast_primitives.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
#include <nvtx3/nvtx3.hpp>
|
||||
#include <thrust/device_ptr.h>
|
||||
#include <thrust/transform.h>
|
||||
|
||||
#include <cassert>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
nvtx3::scoped_range r("Arange::eval_gpu");
|
||||
assert(inputs.size() == 0);
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
if (out.size() == 0) {
|
||||
return;
|
||||
}
|
||||
auto& s = stream();
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
encoder.set_output_array(out);
|
||||
encoder.launch_kernel([&, this](cudaStream_t stream) {
|
||||
MLX_SWITCH_INT_FLOAT_TYPES_CHECKED(out.dtype(), "Arange", CTYPE, {
|
||||
using OutType = cuda_type_t<CTYPE>;
|
||||
CTYPE step =
|
||||
static_cast<CTYPE>(start_ + step_) - static_cast<CTYPE>(start_);
|
||||
thrust::transform(
|
||||
cu::thrust_policy(stream),
|
||||
thrust::counting_iterator<uint32_t>(0),
|
||||
thrust::counting_iterator<uint32_t>(out.data_size()),
|
||||
thrust::device_pointer_cast(out.data<OutType>()),
|
||||
cu::Arange<OutType>{
|
||||
static_cast<OutType>(start_), static_cast<OutType>(step)});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
#define NO_GPU_MULTI(func) \
|
||||
void func::eval_gpu( \
|
||||
const std::vector<array>& inputs, std::vector<array>& outputs) { \
|
||||
throw std::runtime_error(#func " has no CUDA implementation."); \
|
||||
}
|
||||
|
||||
#define NO_GPU(func) \
|
||||
void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
|
||||
throw std::runtime_error(#func " has no CUDA implementation."); \
|
||||
}
|
||||
|
||||
NO_GPU(Abs)
|
||||
NO_GPU(Add)
|
||||
NO_GPU(AddMM)
|
||||
NO_GPU(ArcCos)
|
||||
NO_GPU(ArcCosh)
|
||||
NO_GPU(ArcSin)
|
||||
NO_GPU(ArcSinh)
|
||||
NO_GPU(ArcTan)
|
||||
NO_GPU(ArcTan2)
|
||||
NO_GPU(ArcTanh)
|
||||
NO_GPU(ArgPartition)
|
||||
NO_GPU(ArgReduce)
|
||||
NO_GPU(ArgSort)
|
||||
NO_GPU(BitwiseBinary)
|
||||
NO_GPU(BitwiseInvert)
|
||||
NO_GPU(BlockMaskedMM)
|
||||
NO_GPU(Ceil)
|
||||
NO_GPU_MULTI(Compiled)
|
||||
NO_GPU(Conjugate)
|
||||
NO_GPU(Convolution)
|
||||
NO_GPU(Cos)
|
||||
NO_GPU(Cosh)
|
||||
NO_GPU(Divide)
|
||||
NO_GPU_MULTI(DivMod)
|
||||
NO_GPU(DynamicSlice)
|
||||
NO_GPU(DynamicSliceUpdate)
|
||||
NO_GPU(Remainder)
|
||||
NO_GPU(Equal)
|
||||
NO_GPU(Erf)
|
||||
NO_GPU(ErfInv)
|
||||
NO_GPU(Exp)
|
||||
NO_GPU(Expm1)
|
||||
NO_GPU(FFT)
|
||||
NO_GPU(Floor)
|
||||
NO_GPU(Gather)
|
||||
NO_GPU(GatherAxis)
|
||||
NO_GPU(GatherMM)
|
||||
NO_GPU(GatherQMM)
|
||||
NO_GPU(Greater)
|
||||
NO_GPU(GreaterEqual)
|
||||
NO_GPU(Hadamard)
|
||||
NO_GPU(Imag)
|
||||
NO_GPU(Less)
|
||||
NO_GPU(LessEqual)
|
||||
NO_GPU(Load)
|
||||
NO_GPU(Log)
|
||||
NO_GPU(Log1p)
|
||||
NO_GPU(LogicalNot)
|
||||
NO_GPU(LogicalAnd)
|
||||
NO_GPU(LogicalOr)
|
||||
NO_GPU(LogAddExp)
|
||||
NO_GPU(LogSumExp)
|
||||
NO_GPU_MULTI(LUF)
|
||||
NO_GPU(Matmul)
|
||||
NO_GPU(Maximum)
|
||||
NO_GPU(Minimum)
|
||||
NO_GPU(Multiply)
|
||||
NO_GPU(Negative)
|
||||
NO_GPU(NotEqual)
|
||||
NO_GPU(Partition)
|
||||
NO_GPU(Power)
|
||||
NO_GPU_MULTI(QRF)
|
||||
NO_GPU(QuantizedMatmul)
|
||||
NO_GPU(RandomBits)
|
||||
NO_GPU(Real)
|
||||
NO_GPU(Reduce)
|
||||
NO_GPU(Round)
|
||||
NO_GPU(Scan)
|
||||
NO_GPU(Scatter)
|
||||
NO_GPU(ScatterAxis)
|
||||
NO_GPU(Select)
|
||||
NO_GPU(Sigmoid)
|
||||
NO_GPU(Sign)
|
||||
NO_GPU(Sin)
|
||||
NO_GPU(Sinh)
|
||||
NO_GPU(SliceUpdate)
|
||||
NO_GPU(Softmax)
|
||||
NO_GPU(Sort)
|
||||
NO_GPU(Square)
|
||||
NO_GPU(Sqrt)
|
||||
NO_GPU(Subtract)
|
||||
NO_GPU_MULTI(SVD)
|
||||
NO_GPU(Tan)
|
||||
NO_GPU(Tanh)
|
||||
NO_GPU(Inverse)
|
||||
NO_GPU(Cholesky)
|
||||
NO_GPU_MULTI(Eigh)
|
||||
|
||||
namespace fast {
|
||||
NO_GPU_MULTI(LayerNorm)
|
||||
NO_GPU_MULTI(LayerNormVJP)
|
||||
NO_GPU_MULTI(RMSNorm)
|
||||
NO_GPU_MULTI(RMSNormVJP)
|
||||
NO_GPU_MULTI(RoPE)
|
||||
NO_GPU(ScaledDotProductAttention)
|
||||
NO_GPU_MULTI(AffineQuantize)
|
||||
NO_GPU_MULTI(CustomKernel)
|
||||
} // namespace fast
|
||||
|
||||
namespace distributed {
|
||||
NO_GPU_MULTI(AllReduce)
|
||||
NO_GPU_MULTI(AllGather)
|
||||
NO_GPU_MULTI(Send)
|
||||
NO_GPU_MULTI(Recv)
|
||||
} // namespace distributed
|
||||
|
||||
} // namespace mlx::core
|
15
mlx/backend/cuda/slicing.cpp
Normal file
15
mlx/backend/cuda/slicing.cpp
Normal file
@@ -0,0 +1,15 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/gpu/slicing.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void concatenate_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
array& out,
|
||||
int axis,
|
||||
const Stream& s) {
|
||||
throw std::runtime_error("concatenate_gpu not implemented in CUDA backend.");
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
26
mlx/backend/cuda/utils.cpp
Normal file
26
mlx/backend/cuda/utils.cpp
Normal file
@@ -0,0 +1,26 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/utils.h"
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
|
||||
#include <fmt/format.h>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
CudaStream::CudaStream(cu::Device& device) {
|
||||
device.make_current();
|
||||
CHECK_CUDA_ERROR(cudaStreamCreateWithFlags(&stream_, cudaStreamNonBlocking));
|
||||
}
|
||||
|
||||
CudaStream::~CudaStream() {
|
||||
CHECK_CUDA_ERROR(cudaStreamDestroy(stream_));
|
||||
}
|
||||
|
||||
void check_cuda_error(const char* name, cudaError_t err) {
|
||||
if (err != cudaSuccess) {
|
||||
throw std::runtime_error(
|
||||
fmt::format("{} failed: {}", name, cudaGetErrorString(err)));
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
36
mlx/backend/cuda/utils.h
Normal file
36
mlx/backend/cuda/utils.h
Normal file
@@ -0,0 +1,36 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace cu {
|
||||
class Device;
|
||||
}
|
||||
|
||||
// Cuda stream managed with RAII.
|
||||
class CudaStream {
|
||||
public:
|
||||
explicit CudaStream(cu::Device& device);
|
||||
~CudaStream();
|
||||
|
||||
CudaStream(const CudaStream&) = delete;
|
||||
CudaStream& operator=(const CudaStream&) = delete;
|
||||
|
||||
operator cudaStream_t() const {
|
||||
return stream_;
|
||||
}
|
||||
|
||||
private:
|
||||
cudaStream_t stream_;
|
||||
};
|
||||
|
||||
// Throw exception if the cuda API does not succeed.
|
||||
void check_cuda_error(const char* name, cudaError_t err);
|
||||
|
||||
// The macro version that prints the command that failed.
|
||||
#define CHECK_CUDA_ERROR(cmd) check_cuda_error(#cmd, (cmd))
|
||||
|
||||
} // namespace mlx::core
|
90
mlx/backend/cuda/worker.cpp
Normal file
90
mlx/backend/cuda/worker.cpp
Normal file
@@ -0,0 +1,90 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/worker.h"
|
||||
#include "mlx/backend/cuda/allocator.h"
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
Worker::Worker()
|
||||
: signal_stream_(device(mlx::core::Device::gpu)),
|
||||
worker_(&Worker::thread_fn, this) {}
|
||||
|
||||
Worker::~Worker() {
|
||||
{
|
||||
std::lock_guard lock(worker_mutex_);
|
||||
stop_ = true;
|
||||
}
|
||||
worker_event_.signal(batch_ + 1);
|
||||
worker_.join();
|
||||
}
|
||||
|
||||
void Worker::add_task(std::function<void()> task) {
|
||||
pending_tasks_.push_back(std::move(task));
|
||||
}
|
||||
|
||||
void Worker::consume_in_this_thread() {
|
||||
for (auto& task : pending_tasks_) {
|
||||
task();
|
||||
}
|
||||
pending_tasks_.clear();
|
||||
}
|
||||
|
||||
void Worker::end_batch() {
|
||||
batch_++;
|
||||
{
|
||||
std::lock_guard lock(worker_mutex_);
|
||||
worker_tasks_[batch_] = std::move(pending_tasks_);
|
||||
}
|
||||
uncommited_batches_++;
|
||||
}
|
||||
|
||||
void Worker::commit() {
|
||||
if (uncommited_batches_ == 0) {
|
||||
return;
|
||||
}
|
||||
uncommited_batches_ = 0;
|
||||
worker_event_.signal(batch_);
|
||||
}
|
||||
|
||||
void Worker::commit(cudaStream_t stream) {
|
||||
if (uncommited_batches_ == 0) {
|
||||
return;
|
||||
}
|
||||
uncommited_batches_ = 0;
|
||||
// Signal the |worker_event_| in |signal_stream_| after the kernels in
|
||||
// |stream_| finish running.
|
||||
signal_event_.record(stream);
|
||||
signal_event_.wait(signal_stream_);
|
||||
worker_event_.signal(signal_stream_, batch_);
|
||||
}
|
||||
|
||||
void Worker::thread_fn() {
|
||||
// The worker thread is safe to free buffers.
|
||||
allocator().register_this_thread();
|
||||
|
||||
while (!stop_) {
|
||||
uint64_t batch = worker_event_.value();
|
||||
Tasks tasks;
|
||||
{
|
||||
std::lock_guard lock(worker_mutex_);
|
||||
// Move tasks in signaled batches.
|
||||
auto end = worker_tasks_.upper_bound(batch);
|
||||
for (auto it = worker_tasks_.begin(); it != end; ++it) {
|
||||
if (tasks.empty()) {
|
||||
tasks = std::move(it->second);
|
||||
} else {
|
||||
std::move(
|
||||
it->second.begin(), it->second.end(), std::back_inserter(tasks));
|
||||
}
|
||||
}
|
||||
worker_tasks_.erase(worker_tasks_.begin(), end);
|
||||
}
|
||||
for (auto& task : tasks) {
|
||||
task();
|
||||
}
|
||||
worker_event_.wait(batch + 1);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core::cu
|
68
mlx/backend/cuda/worker.h
Normal file
68
mlx/backend/cuda/worker.h
Normal file
@@ -0,0 +1,68 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/backend/cuda/event.h"
|
||||
#include "mlx/backend/cuda/utils.h"
|
||||
|
||||
#include <functional>
|
||||
#include <map>
|
||||
#include <mutex>
|
||||
#include <thread>
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
// Run tasks in worker thread, synchronized with cuda stream.
|
||||
class Worker {
|
||||
public:
|
||||
Worker();
|
||||
~Worker();
|
||||
|
||||
Worker(const Worker&) = delete;
|
||||
Worker& operator=(const Worker&) = delete;
|
||||
|
||||
// Add a pending |task| that will run when consumed or commited.
|
||||
void add_task(std::function<void()> task);
|
||||
|
||||
// Run pending tasks immediately in current thread.
|
||||
void consume_in_this_thread();
|
||||
|
||||
// Put pending tasks in a batch.
|
||||
void end_batch();
|
||||
|
||||
// Inform worker thread to run current batches now.
|
||||
void commit();
|
||||
|
||||
// Inform worker thread to run current batches after kernels in |stream|
|
||||
// finish running.
|
||||
void commit(cudaStream_t stream);
|
||||
|
||||
// Return how many batches have been added but not committed yet.
|
||||
size_t uncommited_batches() const {
|
||||
return uncommited_batches_;
|
||||
}
|
||||
|
||||
private:
|
||||
void thread_fn();
|
||||
|
||||
uint64_t batch_{0};
|
||||
size_t uncommited_batches_{0};
|
||||
|
||||
// Cuda stream and event for signaling kernel completion.
|
||||
CudaStream signal_stream_;
|
||||
CudaEvent signal_event_;
|
||||
|
||||
// Worker thread.
|
||||
SharedEvent worker_event_;
|
||||
std::thread worker_;
|
||||
std::mutex worker_mutex_;
|
||||
bool stop_{false};
|
||||
|
||||
// Tasks are put in |pending_tasks_| first, and then moved to
|
||||
// |worker_tasks_| when end_batch() is called.
|
||||
using Tasks = std::vector<std::function<void()>>;
|
||||
Tasks pending_tasks_;
|
||||
std::map<uint64_t, Tasks> worker_tasks_;
|
||||
};
|
||||
|
||||
} // namespace mlx::core::cu
|
5
mlx/backend/gpu/CMakeLists.txt
Normal file
5
mlx/backend/gpu/CMakeLists.txt
Normal file
@@ -0,0 +1,5 @@
|
||||
target_sources(
|
||||
mlx
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp)
|
9
mlx/backend/gpu/available.h
Normal file
9
mlx/backend/gpu/available.h
Normal file
@@ -0,0 +1,9 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace mlx::core::gpu {
|
||||
|
||||
bool is_available();
|
||||
|
||||
} // namespace mlx::core::gpu
|
49
mlx/backend/gpu/copy.cpp
Normal file
49
mlx/backend/gpu/copy.cpp
Normal file
@@ -0,0 +1,49 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
#include <cassert>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) {
|
||||
bool donated = set_copy_output_data(in, out, ctype);
|
||||
if (donated && in.dtype() == out.dtype()) {
|
||||
// If the output has the same type as the input then there is nothing to
|
||||
// copy, just use the buffer.
|
||||
return;
|
||||
}
|
||||
if (ctype == CopyType::GeneralGeneral) {
|
||||
ctype = CopyType::General;
|
||||
}
|
||||
copy_gpu_inplace(in, out, ctype, s);
|
||||
}
|
||||
|
||||
void copy_gpu(const array& in, array& out, CopyType ctype) {
|
||||
copy_gpu(in, out, ctype, out.primitive().stream());
|
||||
}
|
||||
|
||||
void copy_gpu_inplace(
|
||||
const array& in,
|
||||
array& out,
|
||||
CopyType ctype,
|
||||
const Stream& s) {
|
||||
assert(in.shape() == out.shape());
|
||||
return copy_gpu_inplace(
|
||||
in, out, in.shape(), in.strides(), out.strides(), 0, 0, ctype, s);
|
||||
}
|
||||
|
||||
void copy_gpu_inplace(
|
||||
const array& in,
|
||||
array& out,
|
||||
const Strides& i_strides,
|
||||
int64_t i_offset,
|
||||
CopyType ctype,
|
||||
const Stream& s) {
|
||||
assert(in.shape() == out.shape());
|
||||
return copy_gpu_inplace(
|
||||
in, out, in.shape(), i_strides, out.strides(), i_offset, 0, ctype, s);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
@@ -5,6 +5,8 @@
|
||||
#include "mlx/backend/common/copy.h"
|
||||
#include "mlx/stream.h"
|
||||
|
||||
#include <optional>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
// Generic copy inplace
|
@@ -8,14 +8,11 @@
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/stream.h"
|
||||
|
||||
namespace mlx::core::metal {
|
||||
namespace mlx::core::gpu {
|
||||
|
||||
void new_stream(Stream stream);
|
||||
|
||||
std::unique_ptr<void, std::function<void(void*)>> new_scoped_memory_pool();
|
||||
|
||||
void eval(array& arr);
|
||||
void finalize(Stream s);
|
||||
void synchronize(Stream s);
|
||||
|
||||
} // namespace mlx::core::metal
|
||||
} // namespace mlx::core::gpu
|
222
mlx/backend/gpu/primitives.cpp
Normal file
222
mlx/backend/gpu/primitives.cpp
Normal file
@@ -0,0 +1,222 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/backend/gpu/slicing.h"
|
||||
|
||||
#include <cassert>
|
||||
|
||||
#define MLX_PROFILER_RANGE(message)
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
void reshape(const array& in, array& out, Stream s) {
|
||||
auto [copy_necessary, out_strides] = prepare_reshape(in, out);
|
||||
if (copy_necessary) {
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
copy_gpu_inplace(
|
||||
in,
|
||||
out,
|
||||
in.shape(),
|
||||
in.strides(),
|
||||
make_contiguous_strides(in.shape()),
|
||||
0,
|
||||
0,
|
||||
CopyType::General,
|
||||
s);
|
||||
} else {
|
||||
shared_buffer_reshape(in, out_strides, out);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void AsStrided::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
MLX_PROFILER_RANGE("AsStrided::eval_gpu");
|
||||
eval(inputs, out);
|
||||
}
|
||||
|
||||
void AsType::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
MLX_PROFILER_RANGE("AsType::eval_gpu");
|
||||
CopyType ctype =
|
||||
inputs[0].flags().contiguous ? CopyType::Vector : CopyType::General;
|
||||
copy_gpu(inputs[0], out, ctype);
|
||||
}
|
||||
|
||||
void Broadcast::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
MLX_PROFILER_RANGE("Broadcast::eval_gpu");
|
||||
eval(inputs, out);
|
||||
}
|
||||
|
||||
void BroadcastAxes::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
MLX_PROFILER_RANGE("BroadcastAxes::eval_gpu");
|
||||
eval(inputs, out);
|
||||
}
|
||||
|
||||
void Concatenate::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
MLX_PROFILER_RANGE("Concatenate::eval_gpu");
|
||||
concatenate_gpu(inputs, out, axis_, stream());
|
||||
}
|
||||
|
||||
void Contiguous::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
MLX_PROFILER_RANGE("Contiguous::eval_gpu");
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
constexpr size_t extra_bytes = 16384;
|
||||
if (in.buffer_size() <= out.nbytes() + extra_bytes &&
|
||||
(in.flags().row_contiguous ||
|
||||
(allow_col_major_ && in.flags().col_contiguous))) {
|
||||
out.copy_shared_buffer(in);
|
||||
} else {
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
copy_gpu_inplace(
|
||||
in,
|
||||
out,
|
||||
in.flags().row_contiguous ? CopyType::Vector : CopyType::General,
|
||||
stream());
|
||||
}
|
||||
}
|
||||
|
||||
void Copy::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
MLX_PROFILER_RANGE("Copy::eval_gpu");
|
||||
eval(inputs, out);
|
||||
}
|
||||
|
||||
void CustomTransforms::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
MLX_PROFILER_RANGE("CustomTransforms::eval_gpu");
|
||||
eval(inputs, outputs);
|
||||
}
|
||||
|
||||
void Depends::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
MLX_PROFILER_RANGE("Depends::eval_gpu");
|
||||
eval(inputs, outputs);
|
||||
}
|
||||
|
||||
void ExpandDims::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
MLX_PROFILER_RANGE("ExpandDims::eval_gpu");
|
||||
eval(inputs, out);
|
||||
}
|
||||
|
||||
void Full::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
MLX_PROFILER_RANGE("Full::eval_gpu");
|
||||
auto in = inputs[0];
|
||||
CopyType ctype;
|
||||
if (in.data_size() == 1) {
|
||||
ctype = CopyType::Scalar;
|
||||
} else if (in.flags().contiguous) {
|
||||
ctype = CopyType::Vector;
|
||||
} else {
|
||||
ctype = CopyType::General;
|
||||
}
|
||||
copy_gpu(in, out, ctype);
|
||||
}
|
||||
|
||||
void Flatten::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
MLX_PROFILER_RANGE("Flatten::eval_gpu");
|
||||
reshape(inputs[0], out, stream());
|
||||
}
|
||||
|
||||
void NumberOfElements::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
MLX_PROFILER_RANGE("NumberOfElements::eval_gpu");
|
||||
eval(inputs, out);
|
||||
}
|
||||
|
||||
void Pad::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
// Inputs must be base input array and scalar val array
|
||||
assert(inputs.size() == 2);
|
||||
auto& in = inputs[0];
|
||||
auto& val = inputs[1];
|
||||
|
||||
// Padding value must be a scalar
|
||||
assert(val.size() == 1);
|
||||
|
||||
// Padding value, input and output must be of the same type
|
||||
assert(val.dtype() == in.dtype() && in.dtype() == out.dtype());
|
||||
|
||||
pad_gpu(in, val, out, axes_, low_pad_size_, stream());
|
||||
}
|
||||
|
||||
void Reshape::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
MLX_PROFILER_RANGE("Reshape::eval_gpu");
|
||||
reshape(inputs[0], out, stream());
|
||||
}
|
||||
|
||||
void Split::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
MLX_PROFILER_RANGE("Split::eval_gpu");
|
||||
eval(inputs, outputs);
|
||||
}
|
||||
|
||||
void Slice::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
MLX_PROFILER_RANGE("Slice::eval_gpu");
|
||||
assert(inputs.size() == 1);
|
||||
if (out.size() == 0) {
|
||||
out.set_data(nullptr);
|
||||
return;
|
||||
}
|
||||
|
||||
auto& in = inputs[0];
|
||||
slice_gpu(in, out, start_indices_, strides_, stream());
|
||||
}
|
||||
|
||||
void Squeeze::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
MLX_PROFILER_RANGE("Squeeze::eval_gpu");
|
||||
eval(inputs, out);
|
||||
}
|
||||
|
||||
void StopGradient::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
MLX_PROFILER_RANGE("StopGradient::eval_gpu");
|
||||
eval(inputs, out);
|
||||
}
|
||||
|
||||
void Transpose::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
MLX_PROFILER_RANGE("Transpose::eval_gpu");
|
||||
eval(inputs, out);
|
||||
}
|
||||
|
||||
void Unflatten::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
MLX_PROFILER_RANGE("Unflatten::eval_gpu");
|
||||
reshape(inputs[0], out, stream());
|
||||
}
|
||||
|
||||
void View::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
MLX_PROFILER_RANGE("View::eval_gpu");
|
||||
auto& in = inputs[0];
|
||||
auto ibytes = size_of(in.dtype());
|
||||
auto obytes = size_of(out.dtype());
|
||||
// Conditions for buffer copying (disjunction):
|
||||
// - type size is the same
|
||||
// - type size is smaller and the last axis is contiguous
|
||||
// - the entire array is row contiguous
|
||||
if (ibytes == obytes || (obytes < ibytes && in.strides().back() == 1) ||
|
||||
in.flags().row_contiguous) {
|
||||
auto strides = in.strides();
|
||||
for (int i = 0; i < static_cast<int>(strides.size()) - 1; ++i) {
|
||||
strides[i] *= ibytes;
|
||||
strides[i] /= obytes;
|
||||
}
|
||||
out.copy_shared_buffer(
|
||||
in, strides, in.flags(), in.data_size() * ibytes / obytes);
|
||||
} else {
|
||||
auto tmp = array(in.shape(), in.dtype(), nullptr, {});
|
||||
tmp.set_data(allocator::malloc(tmp.nbytes()));
|
||||
copy_gpu_inplace(in, tmp, CopyType::General, stream());
|
||||
|
||||
auto flags = out.flags();
|
||||
flags.contiguous = true;
|
||||
flags.row_contiguous = true;
|
||||
auto max_dim = std::max_element(out.shape().begin(), out.shape().end());
|
||||
flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim;
|
||||
out.copy_shared_buffer(tmp, out.strides(), flags, out.size());
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
44
mlx/backend/gpu/slicing.cpp
Normal file
44
mlx/backend/gpu/slicing.cpp
Normal file
@@ -0,0 +1,44 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/common/slicing.h"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/backend/gpu/slicing.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void slice_gpu(
|
||||
const array& in,
|
||||
array& out,
|
||||
const Shape& start_indices,
|
||||
const Shape& strides,
|
||||
const Stream& s) {
|
||||
slice(in, out, start_indices, strides);
|
||||
}
|
||||
|
||||
void pad_gpu(
|
||||
const array& in,
|
||||
const array& val,
|
||||
array& out,
|
||||
const std::vector<int>& axes,
|
||||
const Shape& low_pad_size,
|
||||
const Stream& s) {
|
||||
// Fill output with val
|
||||
fill_gpu(val, out, s);
|
||||
|
||||
// Find offset for start of input values
|
||||
size_t data_offset = 0;
|
||||
for (int i = 0; i < axes.size(); i++) {
|
||||
auto ax = axes[i] < 0 ? out.ndim() + axes[i] : axes[i];
|
||||
data_offset += out.strides()[ax] * low_pad_size[i];
|
||||
}
|
||||
|
||||
// Extract slice from output where input will be pasted
|
||||
array out_slice(in.shape(), out.dtype(), nullptr, {});
|
||||
out_slice.copy_shared_buffer(
|
||||
out, out.strides(), out.flags(), out_slice.size(), data_offset);
|
||||
|
||||
// Copy input values into the slice
|
||||
copy_gpu_inplace(in, out_slice, CopyType::GeneralGeneral, s);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
@@ -61,6 +61,7 @@ if(MLX_METAL_JIT)
|
||||
kernels/steel/gemm/transforms.h)
|
||||
make_jit_source(steel/gemm/kernels/steel_gemm_fused)
|
||||
make_jit_source(steel/gemm/kernels/steel_gemm_masked kernels/steel/defines.h)
|
||||
make_jit_source(steel/gemm/kernels/steel_gemm_gather)
|
||||
make_jit_source(steel/gemm/kernels/steel_gemm_splitk)
|
||||
make_jit_source(
|
||||
steel/conv/conv
|
||||
@@ -92,6 +93,7 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/event.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp
|
||||
|
@@ -1,7 +1,6 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
#include "mlx/backend/metal/allocator.h"
|
||||
#include "mlx/backend/metal/metal.h"
|
||||
#include "mlx/backend/metal/metal_impl.h"
|
||||
#include "mlx/backend/metal/resident.h"
|
||||
#include "mlx/memory.h"
|
||||
|
||||
|
@@ -90,7 +90,7 @@ void binary_op_gpu_inplace(
|
||||
work_per_thread = large ? 4 : 2;
|
||||
} else {
|
||||
large = out.data_size() > UINT32_MAX;
|
||||
work_per_thread = 1;
|
||||
work_per_thread = get_work_per_thread(a.dtype());
|
||||
}
|
||||
std::string kernel_name =
|
||||
get_kernel_name(bopt, op, a, large, shape.size(), work_per_thread);
|
||||
@@ -137,13 +137,20 @@ void binary_op_gpu_inplace(
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
} else {
|
||||
// Launch a 1D or 2D grid of threads
|
||||
size_t nthreads = out.data_size();
|
||||
size_t nthreads = ceildiv(out.data_size(), work_per_thread);
|
||||
if (thread_group_size > nthreads) {
|
||||
thread_group_size = nthreads;
|
||||
}
|
||||
|
||||
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
|
||||
MTL::Size grid_dims = large ? get_2d_grid_dims(out.shape(), out.strides())
|
||||
: MTL::Size(nthreads, 1, 1);
|
||||
MTL::Size grid_dims;
|
||||
if (large) {
|
||||
compute_encoder.set_bytes<int64_t>(out.data_size(), arg_idx++);
|
||||
grid_dims = get_2d_grid_dims(out.shape(), out.strides(), work_per_thread);
|
||||
} else {
|
||||
compute_encoder.set_bytes<int>(out.data_size(), arg_idx++);
|
||||
grid_dims = MTL::Size(nthreads, 1, 1);
|
||||
}
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
}
|
||||
}
|
||||
|
@@ -64,6 +64,7 @@ inline void build_kernel(
|
||||
cnt++);
|
||||
}
|
||||
|
||||
std::string idx_type = use_big_index ? "int64_t" : "uint";
|
||||
if (add_indices) {
|
||||
os += fmt::format(
|
||||
" constant const int64_t* in_strides [[buffer({0})]],\n", cnt++);
|
||||
@@ -83,6 +84,9 @@ inline void build_kernel(
|
||||
" constant const int64_t* output_strides [[buffer({0})]],\n", cnt++);
|
||||
os += fmt::format(
|
||||
" constant const int* output_shape [[buffer({0})]],\n", cnt++);
|
||||
} else {
|
||||
os += fmt::format(
|
||||
" constant const {0}& size [[buffer({1})]],\n", idx_type, cnt++);
|
||||
}
|
||||
if (dynamic_dims) {
|
||||
os += fmt::format(" constant const int& ndim [[buffer({0})]],\n", cnt++);
|
||||
@@ -92,13 +96,14 @@ inline void build_kernel(
|
||||
os += " uint3 pos [[thread_position_in_grid]],\n";
|
||||
os += " uint3 grid [[threads_per_grid]]) {\n";
|
||||
|
||||
std::string idx_type = use_big_index ? "int64_t" : "uint";
|
||||
os += fmt::format(" constexpr int N_ = {0};\n", work_per_thread);
|
||||
if (contiguous && use_big_index) {
|
||||
// This is only used for contiguous kernels which don't have
|
||||
// a third grid dimension
|
||||
os += " int64_t index = pos.x + grid.x * int64_t(pos.y);\n";
|
||||
os += " int64_t index = N_ * (pos.x + grid.x * int64_t(pos.y));\n";
|
||||
} else if (contiguous) {
|
||||
os += " uint index = N_ * pos.x;\n";
|
||||
} else if (work_per_thread > 1) {
|
||||
os += fmt::format(" constexpr int N_ = {0};\n", work_per_thread);
|
||||
os += fmt::format(
|
||||
" int xshape = output_shape[{0}];\n",
|
||||
dynamic_dims ? "ndim - 1" : std::to_string(ndim - 1));
|
||||
@@ -110,6 +115,9 @@ inline void build_kernel(
|
||||
" {0} index = pos.x + grid.x * (pos.y + {0}(grid.y) * pos.z);\n",
|
||||
idx_type);
|
||||
}
|
||||
if (work_per_thread > 1 && contiguous) {
|
||||
os += " for (int i = 0; i < N_ && index < size; ++i) {\n";
|
||||
}
|
||||
|
||||
// Read constant / contiguous inputs in tmps
|
||||
std::vector<array> nc_inputs;
|
||||
@@ -193,7 +201,7 @@ inline void build_kernel(
|
||||
}
|
||||
|
||||
// Open per-thread loop
|
||||
if (work_per_thread > 1) {
|
||||
if (work_per_thread > 1 && !contiguous) {
|
||||
os +=
|
||||
" for (int i = 0; i < N_ && (int(N_ * pos.x) + i) < xshape; ++i) {\n";
|
||||
}
|
||||
@@ -272,6 +280,7 @@ void Compiled::eval_gpu(
|
||||
auto& s = stream();
|
||||
auto& d = metal::device(s.device);
|
||||
auto lib = d.get_library(kernel_lib_, [&]() {
|
||||
int work_per_thread = get_work_per_thread(outputs_[0].dtype());
|
||||
std::string kernel = metal::utils();
|
||||
concatenate(
|
||||
kernel, metal::unary_ops(), metal::binary_ops(), metal::ternary_ops());
|
||||
@@ -284,7 +293,9 @@ void Compiled::eval_gpu(
|
||||
constant_ids_,
|
||||
/* contiguous = */ true,
|
||||
/* ndim = */ 0,
|
||||
/* dynamic_dims = */ false);
|
||||
/* dynamic_dims = */ false,
|
||||
/* use_big_index = */ false,
|
||||
/* work_per_thread = */ work_per_thread);
|
||||
build_kernel(
|
||||
kernel,
|
||||
kernel_lib_ + "_contiguous_large",
|
||||
@@ -295,7 +306,8 @@ void Compiled::eval_gpu(
|
||||
/* contiguous = */ true,
|
||||
/* ndim = */ 0,
|
||||
/* dynamic_dims = */ false,
|
||||
/* use_big_index = */ true);
|
||||
/* use_big_index = */ true,
|
||||
/* work_per_thread = */ work_per_thread);
|
||||
for (int i = 1; i < 8; i++) {
|
||||
build_kernel(
|
||||
kernel,
|
||||
@@ -468,6 +480,13 @@ void Compiled::eval_gpu(
|
||||
if (!contiguous) {
|
||||
compute_encoder.set_vector_bytes(strides[0], cnt++);
|
||||
compute_encoder.set_vector_bytes(shape, cnt++);
|
||||
} else {
|
||||
auto size = outputs[0].data_size();
|
||||
if (large) {
|
||||
compute_encoder.set_bytes<int64_t>(size, cnt++);
|
||||
} else {
|
||||
compute_encoder.set_bytes<int>(size, cnt++);
|
||||
}
|
||||
}
|
||||
|
||||
// Put the number of dims in if it is dynamic
|
||||
@@ -477,12 +496,13 @@ void Compiled::eval_gpu(
|
||||
|
||||
// Launch the kernel
|
||||
if (contiguous) {
|
||||
size_t nthreads = outputs[0].data_size();
|
||||
int work_per_thread = get_work_per_thread(outputs[0].dtype());
|
||||
size_t nthreads = ceildiv(outputs[0].data_size(), work_per_thread);
|
||||
MTL::Size group_dims(
|
||||
std::min(nthreads, kernel->maxTotalThreadsPerThreadgroup()), 1, 1);
|
||||
|
||||
MTL::Size grid_dims = large
|
||||
? get_2d_grid_dims(outputs[0].shape(), outputs[0].strides())
|
||||
? get_2d_grid_dims(
|
||||
outputs[0].shape(), outputs[0].strides(), work_per_thread)
|
||||
: MTL::Size(nthreads, 1, 1);
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
} else {
|
||||
|
@@ -5,7 +5,7 @@
|
||||
#include <numeric>
|
||||
#include <sstream>
|
||||
|
||||
#include "mlx/backend/metal/copy.h"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/kernels.h"
|
||||
#include "mlx/backend/metal/kernels/defines.h"
|
||||
|
@@ -1,35 +1,15 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <sstream>
|
||||
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/metal/copy.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/kernels.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
constexpr int MAX_COPY_SPECIALIZED_DIMS = 3;
|
||||
|
||||
void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) {
|
||||
bool donated = set_copy_output_data(in, out, ctype);
|
||||
if (donated && in.dtype() == out.dtype()) {
|
||||
// If the output has the same type as the input then there is nothing to
|
||||
// copy, just use the buffer.
|
||||
return;
|
||||
}
|
||||
if (ctype == CopyType::GeneralGeneral) {
|
||||
ctype = CopyType::General;
|
||||
}
|
||||
copy_gpu_inplace(in, out, ctype, s);
|
||||
}
|
||||
|
||||
void copy_gpu(const array& in, array& out, CopyType ctype) {
|
||||
copy_gpu(in, out, ctype, out.primitive().stream());
|
||||
}
|
||||
|
||||
void copy_gpu_inplace(
|
||||
const array& in,
|
||||
array& out,
|
||||
@@ -104,6 +84,8 @@ void copy_gpu_inplace(
|
||||
"[Copy::eval_gpu] Dynamic output offset requires GeneralGeneral copy");
|
||||
}
|
||||
}
|
||||
} else {
|
||||
work_per_thread = get_work_per_thread(in.dtype());
|
||||
}
|
||||
concatenate(kernel_name, "_copy", type_to_name(in), type_to_name(out));
|
||||
auto kernel = dynamic ? get_dynamic_copy_kernel(d, kernel_name, in, out)
|
||||
@@ -165,39 +147,23 @@ void copy_gpu_inplace(
|
||||
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
} else {
|
||||
size_t nthreads = out.data_size();
|
||||
size_t nthreads = ceildiv(out.data_size(), work_per_thread);
|
||||
if (thread_group_size > nthreads) {
|
||||
thread_group_size = nthreads;
|
||||
}
|
||||
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
|
||||
MTL::Size grid_dims = large ? get_2d_grid_dims(out.shape(), out.strides())
|
||||
: MTL::Size(nthreads, 1, 1);
|
||||
MTL::Size grid_dims;
|
||||
if (large) {
|
||||
compute_encoder.set_bytes<int64_t>(out.data_size(), 2);
|
||||
grid_dims = get_2d_grid_dims(out.shape(), out.strides(), work_per_thread);
|
||||
} else {
|
||||
compute_encoder.set_bytes<int>(out.data_size(), 2);
|
||||
grid_dims = MTL::Size(nthreads, 1, 1);
|
||||
}
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
}
|
||||
}
|
||||
|
||||
void copy_gpu_inplace(
|
||||
const array& in,
|
||||
array& out,
|
||||
CopyType ctype,
|
||||
const Stream& s) {
|
||||
assert(in.shape() == out.shape());
|
||||
return copy_gpu_inplace(
|
||||
in, out, in.shape(), in.strides(), out.strides(), 0, 0, ctype, s);
|
||||
}
|
||||
|
||||
void copy_gpu_inplace(
|
||||
const array& in,
|
||||
array& out,
|
||||
const Strides& i_strides,
|
||||
int64_t i_offset,
|
||||
CopyType ctype,
|
||||
const Stream& s) {
|
||||
assert(in.shape() == out.shape());
|
||||
return copy_gpu_inplace(
|
||||
in, out, in.shape(), i_strides, out.strides(), i_offset, 0, ctype, s);
|
||||
}
|
||||
|
||||
void fill_gpu(const array& val, array& out, const Stream& s) {
|
||||
if (out.size() == 0) {
|
||||
return;
|
||||
@@ -214,14 +180,21 @@ void fill_gpu(const array& val, array& out, const Stream& s) {
|
||||
compute_encoder.set_input_array(val, 0);
|
||||
compute_encoder.set_output_array(out, 1);
|
||||
|
||||
int work_per_thread = get_work_per_thread(val.dtype());
|
||||
auto thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||
size_t nthreads = out.data_size();
|
||||
size_t nthreads = ceildiv(out.data_size(), work_per_thread);
|
||||
if (thread_group_size > nthreads) {
|
||||
thread_group_size = nthreads;
|
||||
}
|
||||
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
|
||||
MTL::Size grid_dims = large ? get_2d_grid_dims(out.shape(), out.strides())
|
||||
: MTL::Size(nthreads, 1, 1);
|
||||
MTL::Size grid_dims;
|
||||
if (large) {
|
||||
compute_encoder.set_bytes<int64_t>(out.data_size(), 2);
|
||||
grid_dims = get_2d_grid_dims(out.shape(), out.strides(), work_per_thread);
|
||||
} else {
|
||||
compute_encoder.set_bytes<int>(out.data_size(), 2);
|
||||
grid_dims = MTL::Size(nthreads, 1, 1);
|
||||
}
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
|
@@ -1,6 +1,6 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include "mlx/backend/metal/copy.h"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/backend/metal/jit/includes.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
#include "mlx/fast_primitives.h"
|
||||
|
@@ -1,20 +1,20 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <cstdlib>
|
||||
#include <filesystem>
|
||||
#include <sstream>
|
||||
|
||||
#include <sys/sysctl.h>
|
||||
|
||||
#define NS_PRIVATE_IMPLEMENTATION
|
||||
#define CA_PRIVATE_IMPLEMENTATION
|
||||
#define MTL_PRIVATE_IMPLEMENTATION
|
||||
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/metal.h"
|
||||
#include "mlx/backend/metal/metal_impl.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
namespace fs = std::filesystem;
|
||||
|
||||
namespace mlx::core::metal {
|
||||
|
||||
namespace {
|
||||
@@ -66,8 +66,8 @@ MTL::Library* try_load_bundle(
|
||||
if (bundle != nullptr) {
|
||||
std::string resource_path =
|
||||
std::string(bundle->resourceURL()->fileSystemRepresentation()) + "/" +
|
||||
lib_name + ".metallib" auto [lib, error] =
|
||||
load_library_from_path(device, resource_path.c_str());
|
||||
lib_name + ".metallib";
|
||||
auto [lib, error] = load_library_from_path(device, resource_path.c_str());
|
||||
if (lib) {
|
||||
return lib;
|
||||
}
|
||||
@@ -79,12 +79,18 @@ MTL::Library* try_load_bundle(
|
||||
// Firstly, search for the metallib in the same path as this binary
|
||||
std::pair<MTL::Library*, NS::Error*> load_colocated_library(
|
||||
MTL::Device* device,
|
||||
const std::string& lib_name) {
|
||||
std::string lib_path = get_colocated_mtllib_path(lib_name);
|
||||
if (lib_path.size() != 0) {
|
||||
return load_library_from_path(device, lib_path.c_str());
|
||||
const std::string& relative_path) {
|
||||
std::string binary_dir = get_binary_directory();
|
||||
if (binary_dir.size() == 0) {
|
||||
return {nullptr, nullptr};
|
||||
}
|
||||
return {nullptr, nullptr};
|
||||
|
||||
auto path = fs::path(binary_dir) / relative_path;
|
||||
if (!path.has_extension()) {
|
||||
path.replace_extension(".metallib");
|
||||
}
|
||||
|
||||
return load_library_from_path(device, path.c_str());
|
||||
}
|
||||
|
||||
std::pair<MTL::Library*, NS::Error*> load_swiftpm_library(
|
||||
@@ -99,7 +105,7 @@ std::pair<MTL::Library*, NS::Error*> load_swiftpm_library(
|
||||
auto bundles = NS::Bundle::allBundles();
|
||||
for (int i = 0, c = (int)bundles->count(); i < c; i++) {
|
||||
auto bundle = reinterpret_cast<NS::Bundle*>(bundles->object(i));
|
||||
library = try_load_bundle(device, bundle->resourceURL());
|
||||
library = try_load_bundle(device, bundle->resourceURL(), lib_name);
|
||||
if (library != nullptr) {
|
||||
return {library, nullptr};
|
||||
}
|
||||
@@ -109,33 +115,34 @@ std::pair<MTL::Library*, NS::Error*> load_swiftpm_library(
|
||||
}
|
||||
|
||||
MTL::Library* load_default_library(MTL::Device* device) {
|
||||
NS::Error *error1, *error2, *error3;
|
||||
NS::Error* error[4];
|
||||
MTL::Library* lib;
|
||||
// First try the colocated mlx.metallib
|
||||
std::tie(lib, error1) = load_colocated_library(device, "mlx");
|
||||
std::tie(lib, error[0]) = load_colocated_library(device, "mlx");
|
||||
if (lib) {
|
||||
return lib;
|
||||
}
|
||||
|
||||
std::tie(lib, error[1]) = load_colocated_library(device, "Resources/mlx");
|
||||
if (lib) {
|
||||
return lib;
|
||||
}
|
||||
|
||||
// Then try default.metallib in a SwiftPM bundle if we have one
|
||||
std::tie(lib, error2) = load_swiftpm_library(device, "default");
|
||||
std::tie(lib, error[2]) = load_swiftpm_library(device, "default");
|
||||
if (lib) {
|
||||
return lib;
|
||||
}
|
||||
|
||||
// Finally try default_mtllib_path
|
||||
std::tie(lib, error3) = load_library_from_path(device, default_mtllib_path);
|
||||
std::tie(lib, error[3]) = load_library_from_path(device, default_mtllib_path);
|
||||
if (!lib) {
|
||||
std::ostringstream msg;
|
||||
msg << "Failed to load the default metallib. ";
|
||||
if (error1 != nullptr) {
|
||||
msg << error1->localizedDescription()->utf8String() << " ";
|
||||
}
|
||||
if (error2 != nullptr) {
|
||||
msg << error2->localizedDescription()->utf8String() << " ";
|
||||
}
|
||||
if (error3 != nullptr) {
|
||||
msg << error3->localizedDescription()->utf8String() << " ";
|
||||
for (int i = 0; i < 4; i++) {
|
||||
if (error[i] != nullptr) {
|
||||
msg << error[i]->localizedDescription()->utf8String() << " ";
|
||||
}
|
||||
}
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
@@ -156,6 +163,7 @@ MTL::Library* load_library(
|
||||
<< error->localizedDescription()->utf8String();
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
return lib;
|
||||
}
|
||||
|
||||
// We have been given a path so try to load from lib_path / lib_name.metallib
|
||||
@@ -168,6 +176,7 @@ MTL::Library* load_library(
|
||||
<< "> with error " << error->localizedDescription()->utf8String();
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
return lib;
|
||||
}
|
||||
|
||||
// Try to load the colocated library
|
||||
@@ -188,8 +197,8 @@ MTL::Library* load_library(
|
||||
|
||||
std::ostringstream msg;
|
||||
msg << "Failed to load the metallib " << lib_name << ".metallib. "
|
||||
<< "We attempted to load it from <" << get_colocated_mtllib_path(lib_name)
|
||||
<< ">";
|
||||
<< "We attempted to load it from <" << get_binary_directory() << "/"
|
||||
<< lib_name << ".metallib" << ">";
|
||||
#ifdef SWIFTPM_BUNDLE
|
||||
msg << " and from the Swift PM bundle.";
|
||||
#endif
|
||||
@@ -760,42 +769,4 @@ std::unique_ptr<void, std::function<void(void*)>> new_scoped_memory_pool() {
|
||||
NS::AutoreleasePool::alloc()->init(), dtor);
|
||||
}
|
||||
|
||||
void new_stream(Stream stream) {
|
||||
if (stream.device == mlx::core::Device::gpu) {
|
||||
device(stream.device).new_queue(stream.index);
|
||||
}
|
||||
}
|
||||
|
||||
const std::unordered_map<std::string, std::variant<std::string, size_t>>&
|
||||
device_info() {
|
||||
auto init_device_info = []()
|
||||
-> std::unordered_map<std::string, std::variant<std::string, size_t>> {
|
||||
auto pool = new_scoped_memory_pool();
|
||||
auto raw_device = device(default_device()).mtl_device();
|
||||
auto name = std::string(raw_device->name()->utf8String());
|
||||
auto arch = std::string(raw_device->architecture()->name()->utf8String());
|
||||
|
||||
size_t memsize = 0;
|
||||
size_t length = sizeof(memsize);
|
||||
sysctlbyname("hw.memsize", &memsize, &length, NULL, 0);
|
||||
|
||||
size_t rsrc_limit = 0;
|
||||
sysctlbyname("iogpu.rsrc_limit", &rsrc_limit, &length, NULL, 0);
|
||||
if (rsrc_limit == 0) {
|
||||
rsrc_limit = 499000;
|
||||
}
|
||||
|
||||
return {
|
||||
{"device_name", name},
|
||||
{"architecture", arch},
|
||||
{"max_buffer_length", raw_device->maxBufferLength()},
|
||||
{"max_recommended_working_set_size",
|
||||
raw_device->recommendedMaxWorkingSetSize()},
|
||||
{"memory_size", memsize},
|
||||
{"resource_limit", rsrc_limit}};
|
||||
};
|
||||
static auto device_info_ = init_device_info();
|
||||
return device_info_;
|
||||
}
|
||||
|
||||
} // namespace mlx::core::metal
|
||||
|
@@ -21,18 +21,14 @@ namespace mlx::core::metal {
|
||||
|
||||
// Note, this function must be left inline in a header so that it is not
|
||||
// dynamically linked.
|
||||
inline std::string get_colocated_mtllib_path(const std::string& lib_name) {
|
||||
inline std::string get_binary_directory() {
|
||||
Dl_info info;
|
||||
std::string mtllib_path;
|
||||
std::string lib_ext = lib_name + ".metallib";
|
||||
|
||||
int success = dladdr((void*)get_colocated_mtllib_path, &info);
|
||||
std::string directory;
|
||||
int success = dladdr((void*)get_binary_directory, &info);
|
||||
if (success) {
|
||||
auto mtllib = fs::path(info.dli_fname).remove_filename() / lib_ext;
|
||||
mtllib_path = mtllib.c_str();
|
||||
directory = fs::path(info.dli_fname).remove_filename().c_str();
|
||||
}
|
||||
|
||||
return mtllib_path;
|
||||
return directory;
|
||||
}
|
||||
|
||||
using MTLFCList =
|
||||
@@ -270,4 +266,6 @@ class Device {
|
||||
|
||||
Device& device(mlx::core::Device);
|
||||
|
||||
std::unique_ptr<void, std::function<void(void*)>> new_scoped_memory_pool();
|
||||
|
||||
} // namespace mlx::core::metal
|
||||
|
@@ -4,7 +4,7 @@
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/metal/copy.h"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
#include "mlx/distributed/ops.h"
|
||||
|
102
mlx/backend/metal/eval.cpp
Normal file
102
mlx/backend/metal/eval.cpp
Normal file
@@ -0,0 +1,102 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
#include <memory>
|
||||
|
||||
#include "mlx/backend/gpu/available.h"
|
||||
#include "mlx/backend/gpu/eval.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/scheduler.h"
|
||||
|
||||
namespace mlx::core::gpu {
|
||||
|
||||
bool is_available() {
|
||||
return true;
|
||||
}
|
||||
|
||||
void new_stream(Stream stream) {
|
||||
if (stream.device == mlx::core::Device::gpu) {
|
||||
metal::device(stream.device).new_queue(stream.index);
|
||||
}
|
||||
}
|
||||
|
||||
inline void check_error(MTL::CommandBuffer* cbuf) {
|
||||
if (cbuf->status() == MTL::CommandBufferStatusError) {
|
||||
std::ostringstream msg;
|
||||
msg << "[METAL] Command buffer execution failed: "
|
||||
<< cbuf->error()->localizedDescription()->utf8String();
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
}
|
||||
|
||||
void eval(array& arr) {
|
||||
auto pool = metal::new_scoped_memory_pool();
|
||||
auto s = arr.primitive().stream();
|
||||
auto& d = metal::device(s.device);
|
||||
auto command_buffer = d.get_command_buffer(s.index);
|
||||
|
||||
auto outputs = arr.outputs();
|
||||
{
|
||||
// If the array is a tracer hold a reference
|
||||
// to its inputs so they don't get donated
|
||||
std::vector<array> inputs;
|
||||
if (arr.is_tracer()) {
|
||||
inputs = arr.inputs();
|
||||
}
|
||||
|
||||
debug_set_primitive_buffer_label(command_buffer, arr.primitive());
|
||||
arr.primitive().eval_gpu(arr.inputs(), outputs);
|
||||
}
|
||||
std::unordered_set<std::shared_ptr<array::Data>> buffers;
|
||||
for (auto& in : arr.inputs()) {
|
||||
buffers.insert(in.data_shared_ptr());
|
||||
}
|
||||
for (auto& s : arr.siblings()) {
|
||||
buffers.insert(s.data_shared_ptr());
|
||||
}
|
||||
// Remove the output if it was donated to by an input
|
||||
if (auto it = buffers.find(arr.data_shared_ptr()); it != buffers.end()) {
|
||||
buffers.erase(it);
|
||||
}
|
||||
|
||||
if (d.command_buffer_needs_commit(s.index)) {
|
||||
d.end_encoding(s.index);
|
||||
scheduler::notify_new_task(s);
|
||||
command_buffer->addCompletedHandler(
|
||||
[s, buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) {
|
||||
scheduler::notify_task_completion(s);
|
||||
check_error(cbuf);
|
||||
});
|
||||
d.commit_command_buffer(s.index);
|
||||
d.get_command_buffer(s.index);
|
||||
} else {
|
||||
command_buffer->addCompletedHandler(
|
||||
[s, buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) {
|
||||
check_error(cbuf);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
void finalize(Stream s) {
|
||||
auto pool = metal::new_scoped_memory_pool();
|
||||
auto& d = metal::device(s.device);
|
||||
auto cb = d.get_command_buffer(s.index);
|
||||
d.end_encoding(s.index);
|
||||
cb->addCompletedHandler([s](MTL::CommandBuffer* cbuf) { check_error(cbuf); });
|
||||
d.commit_command_buffer(s.index);
|
||||
d.get_command_buffer(s.index);
|
||||
}
|
||||
|
||||
void synchronize(Stream s) {
|
||||
auto pool = metal::new_scoped_memory_pool();
|
||||
auto& d = metal::device(s.device);
|
||||
auto cb = d.get_command_buffer(s.index);
|
||||
cb->retain();
|
||||
d.end_encoding(s.index);
|
||||
d.commit_command_buffer(s.index);
|
||||
cb->waitUntilCompleted();
|
||||
check_error(cb);
|
||||
cb->release();
|
||||
}
|
||||
|
||||
} // namespace mlx::core::gpu
|
@@ -2,7 +2,6 @@
|
||||
|
||||
#include "mlx/event.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/metal_impl.h"
|
||||
#include "mlx/scheduler.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
@@ -1,7 +1,6 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
#include "mlx/fence.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/metal_impl.h"
|
||||
#include "mlx/scheduler.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
@@ -139,7 +138,7 @@ void Fence::update(Stream stream, const array& x) {
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
compute_encoder.set_input_array(x, 0);
|
||||
compute_encoder.set_bytes(nthreads, 1);
|
||||
compute_encoder.dispatch_threadgroups(group_dims, grid_dims);
|
||||
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
||||
|
||||
// Barrier on previous kernels
|
||||
compute_encoder.barrier();
|
||||
|
@@ -1,16 +1,18 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
#include <cassert>
|
||||
#include <complex>
|
||||
#include <iostream>
|
||||
#include <map>
|
||||
#include <numeric>
|
||||
#include <set>
|
||||
|
||||
#include "mlx/3rdparty/pocketfft.h"
|
||||
#include "mlx/backend/common/transpose.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/backend/gpu/slicing.h"
|
||||
#include "mlx/backend/metal/binary.h"
|
||||
#include "mlx/backend/metal/copy.h"
|
||||
#include "mlx/backend/metal/kernels.h"
|
||||
#include "mlx/backend/metal/slicing.h"
|
||||
#include "mlx/backend/metal/unary.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
#include "mlx/utils.h"
|
||||
@@ -27,7 +29,7 @@ using MTLFC = std::tuple<const void*, MTL::DataType, NS::UInteger>;
|
||||
// For strided reads/writes, coalesce at least this many complex64s
|
||||
#define MIN_COALESCE_WIDTH 4
|
||||
|
||||
inline const std::vector<int> supported_radices() {
|
||||
inline constexpr std::array<int, 9> supported_radices() {
|
||||
// Ordered by preference in decomposition.
|
||||
return {13, 11, 8, 7, 6, 5, 4, 3, 2};
|
||||
}
|
||||
@@ -49,6 +51,35 @@ std::vector<int> prime_factors(int n) {
|
||||
return factors;
|
||||
}
|
||||
|
||||
int next_fast_n(int n) {
|
||||
return next_power_of_2(n);
|
||||
}
|
||||
|
||||
std::vector<int> stockham_decompose(int n) {
|
||||
auto radices = supported_radices();
|
||||
std::vector<int> steps(radices.size(), 0);
|
||||
int orig_n = n;
|
||||
|
||||
for (int i = 0; i < radices.size(); i++) {
|
||||
int radix = radices[i];
|
||||
|
||||
// Manually tuned radices for powers of 2
|
||||
if (is_power_of_2(orig_n) && orig_n < 512 && radix > 4) {
|
||||
continue;
|
||||
}
|
||||
|
||||
while (n % radix == 0) {
|
||||
steps[i] += 1;
|
||||
n /= radix;
|
||||
if (n == 1) {
|
||||
return steps;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return {};
|
||||
}
|
||||
|
||||
struct FourStepParams {
|
||||
bool required = false;
|
||||
bool first_step = true;
|
||||
@@ -65,9 +96,10 @@ void fft_op(
|
||||
bool real,
|
||||
const FourStepParams four_step_params,
|
||||
bool inplace,
|
||||
metal::Device& d,
|
||||
const Stream& s);
|
||||
|
||||
struct FFTPlan {
|
||||
struct OldFFTPlan {
|
||||
int n = 0;
|
||||
// Number of steps for each radix in the Stockham decomposition
|
||||
std::vector<int> stockham;
|
||||
@@ -82,9 +114,104 @@ struct FFTPlan {
|
||||
int n2 = 0;
|
||||
};
|
||||
|
||||
int next_fast_n(int n) {
|
||||
return next_power_of_2(n);
|
||||
}
|
||||
class FFTPlan {
|
||||
public:
|
||||
enum FFTType {
|
||||
UNSUPPORTED,
|
||||
NOOP,
|
||||
STOCKHAM,
|
||||
RADER,
|
||||
BLUESTEIN,
|
||||
MULTIUPLOAD_BLUESTEIN,
|
||||
SMALL_FOUR_STEP,
|
||||
LARGE_FOUR_STEP
|
||||
};
|
||||
|
||||
FFTPlan(int n) : n_(n) {
|
||||
// NOOP
|
||||
if (n == 1) {
|
||||
type_ = NOOP;
|
||||
}
|
||||
|
||||
// Too large for Stockham so do four step fft for powers of 2
|
||||
else if (n > MAX_STOCKHAM_FFT_SIZE && is_power_of_2(n)) {
|
||||
if (n <= 1 << 20) {
|
||||
type_ = SMALL_FOUR_STEP;
|
||||
n2_ = n > 65536 ? 1024 : 64;
|
||||
n1_ = n / n2_;
|
||||
steps1_ = stockham_decompose(n1_);
|
||||
steps2_ = stockham_decompose(n2_);
|
||||
} else {
|
||||
type_ = LARGE_FOUR_STEP;
|
||||
}
|
||||
}
|
||||
|
||||
// Too large and not power of 2 so do multi-upload Bluestein fft
|
||||
else if (n > MAX_STOCKHAM_FFT_SIZE) {
|
||||
type_ = MULTIUPLOAD_BLUESTEIN;
|
||||
bluestein_n_ = next_fast_n(2 * n - 1);
|
||||
}
|
||||
|
||||
// Stockham fft
|
||||
else if (auto steps = stockham_decompose(n); steps.size() > 0) {
|
||||
type_ = STOCKHAM;
|
||||
steps_ = steps;
|
||||
}
|
||||
|
||||
// Add rader but for now simply fall back to bluestein when stockham not
|
||||
// posssible
|
||||
else if (n > MAX_BLUESTEIN_FFT_SIZE) {
|
||||
type_ = MULTIUPLOAD_BLUESTEIN;
|
||||
bluestein_n_ = next_fast_n(2 * n - 1);
|
||||
} else {
|
||||
type_ = BLUESTEIN;
|
||||
bluestein_n_ = next_fast_n(2 * n - 1);
|
||||
steps_ = stockham_decompose(bluestein_n_);
|
||||
}
|
||||
}
|
||||
|
||||
FFTType type() const {
|
||||
return type_;
|
||||
}
|
||||
|
||||
int size() const {
|
||||
return n_;
|
||||
}
|
||||
|
||||
const std::vector<int>& steps() const {
|
||||
return steps_;
|
||||
}
|
||||
|
||||
int first_size() const {
|
||||
return n1_;
|
||||
}
|
||||
|
||||
const std::vector<int>& first_steps() const {
|
||||
return steps1_;
|
||||
}
|
||||
|
||||
int second_size() const {
|
||||
return n2_;
|
||||
}
|
||||
|
||||
const std::vector<int>& second_steps() const {
|
||||
return steps2_;
|
||||
}
|
||||
|
||||
int bluestein_size() const {
|
||||
return bluestein_n_;
|
||||
}
|
||||
|
||||
private:
|
||||
int n_;
|
||||
FFTType type_;
|
||||
std::vector<int> steps_;
|
||||
int n1_;
|
||||
std::vector<int> steps1_;
|
||||
int n2_;
|
||||
std::vector<int> steps2_;
|
||||
int bluestein_n_;
|
||||
};
|
||||
|
||||
std::vector<int> plan_stockham_fft(int n) {
|
||||
auto radices = supported_radices();
|
||||
@@ -110,15 +237,12 @@ std::vector<int> plan_stockham_fft(int n) {
|
||||
throw std::runtime_error("Unplannable");
|
||||
}
|
||||
|
||||
FFTPlan plan_fft(int n) {
|
||||
OldFFTPlan plan_fft(int n) {
|
||||
auto radices = supported_radices();
|
||||
std::set<int> radices_set(radices.begin(), radices.end());
|
||||
|
||||
FFTPlan plan;
|
||||
OldFFTPlan plan;
|
||||
plan.n = n;
|
||||
plan.rader = std::vector<int>(radices.size(), 0);
|
||||
auto factors = prime_factors(n);
|
||||
int remaining_n = n;
|
||||
|
||||
// Four Step FFT when N is too large for shared mem.
|
||||
if (n > MAX_STOCKHAM_FFT_SIZE && is_power_of_2(n)) {
|
||||
@@ -128,16 +252,20 @@ FFTPlan plan_fft(int n) {
|
||||
plan.n2 = n > 65536 ? 1024 : 64;
|
||||
plan.n1 = n / plan.n2;
|
||||
return plan;
|
||||
} else if (n > MAX_STOCKHAM_FFT_SIZE) {
|
||||
}
|
||||
|
||||
if (n > MAX_STOCKHAM_FFT_SIZE) {
|
||||
// Otherwise we use a multi-upload Bluestein's
|
||||
plan.four_step = true;
|
||||
plan.bluestein_n = next_fast_n(2 * n - 1);
|
||||
return plan;
|
||||
}
|
||||
|
||||
int remaining_n = n;
|
||||
auto factors = prime_factors(n);
|
||||
for (int factor : factors) {
|
||||
// Make sure the factor is a supported radix
|
||||
if (radices_set.find(factor) == radices_set.end()) {
|
||||
if (std::find(radices.begin(), radices.end(), factor) == radices.end()) {
|
||||
// We only support a single Rader factor currently
|
||||
// TODO(alexbarron) investigate weirdness with large
|
||||
// Rader sizes -- possibly a compiler issue?
|
||||
@@ -154,7 +282,7 @@ FFTPlan plan_fft(int n) {
|
||||
for (int rf : rader_factors) {
|
||||
// We don't nest Rader's algorithm so if `factor - 1`
|
||||
// isn't Stockham decomposable we give up and do Bluestein's.
|
||||
if (radices_set.find(rf) == radices_set.end()) {
|
||||
if (std::find(radices.begin(), radices.end(), rf) == radices.end()) {
|
||||
plan.four_step = n > MAX_BLUESTEIN_FFT_SIZE;
|
||||
plan.bluestein_n = next_fast_n(2 * n - 1);
|
||||
plan.stockham = plan_stockham_fft(plan.bluestein_n);
|
||||
@@ -172,7 +300,7 @@ FFTPlan plan_fft(int n) {
|
||||
return plan;
|
||||
}
|
||||
|
||||
int compute_elems_per_thread(FFTPlan plan) {
|
||||
int compute_elems_per_thread(OldFFTPlan plan) {
|
||||
// Heuristics for selecting an efficient number
|
||||
// of threads to use for a particular mixed-radix FFT.
|
||||
auto n = plan.n;
|
||||
@@ -355,9 +483,11 @@ void multi_upload_bluestein_fft(
|
||||
size_t axis,
|
||||
bool inverse,
|
||||
bool real,
|
||||
FFTPlan& plan,
|
||||
OldFFTPlan& plan,
|
||||
std::vector<array>& copies,
|
||||
const Stream& s) {
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
// TODO(alexbarron) Implement fused kernels for mutli upload bluestein's
|
||||
// algorithm
|
||||
int n = inverse ? out.shape(axis) : in.shape(axis);
|
||||
@@ -420,6 +550,7 @@ void multi_upload_bluestein_fft(
|
||||
/*real=*/false,
|
||||
FourStepParams(),
|
||||
/*inplace=*/false,
|
||||
d,
|
||||
s);
|
||||
copies.push_back(pad_temp1);
|
||||
|
||||
@@ -435,6 +566,7 @@ void multi_upload_bluestein_fft(
|
||||
/* real= */ false,
|
||||
FourStepParams(),
|
||||
/*inplace=*/true,
|
||||
d,
|
||||
s);
|
||||
|
||||
int offset = plan.bluestein_n - (2 * n - 1);
|
||||
@@ -480,7 +612,7 @@ void four_step_fft(
|
||||
size_t axis,
|
||||
bool inverse,
|
||||
bool real,
|
||||
FFTPlan& plan,
|
||||
OldFFTPlan& plan,
|
||||
std::vector<array>& copies,
|
||||
const Stream& s,
|
||||
bool in_place) {
|
||||
@@ -493,7 +625,15 @@ void four_step_fft(
|
||||
auto temp_shape = (real && inverse) ? out.shape() : in.shape();
|
||||
array temp(temp_shape, complex64, nullptr, {});
|
||||
fft_op(
|
||||
in, temp, axis, inverse, real, four_step_params, /*inplace=*/false, s);
|
||||
in,
|
||||
temp,
|
||||
axis,
|
||||
inverse,
|
||||
real,
|
||||
four_step_params,
|
||||
/*inplace=*/false,
|
||||
d,
|
||||
s);
|
||||
four_step_params.first_step = false;
|
||||
fft_op(
|
||||
temp,
|
||||
@@ -503,6 +643,7 @@ void four_step_fft(
|
||||
real,
|
||||
four_step_params,
|
||||
/*inplace=*/in_place,
|
||||
d,
|
||||
s);
|
||||
copies.push_back(temp);
|
||||
} else {
|
||||
@@ -518,9 +659,8 @@ void fft_op(
|
||||
bool real,
|
||||
const FourStepParams four_step_params,
|
||||
bool inplace,
|
||||
metal::Device& d,
|
||||
const Stream& s) {
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
size_t n = out.dtype() == float32 ? out.shape(axis) : in.shape(axis);
|
||||
if (n == 1) {
|
||||
out.copy_shared_buffer(in);
|
||||
@@ -755,57 +895,517 @@ void fft_op(
|
||||
d.add_temporaries(std::move(copies), s.index);
|
||||
}
|
||||
|
||||
void fft_op(
|
||||
inline int compute_elems_per_thread(int n, const std::vector<int>& steps) {
|
||||
auto radices = supported_radices();
|
||||
std::set<int> used_radices;
|
||||
for (int i = 0; i < steps.size(); i++) {
|
||||
if (steps[i] > 0) {
|
||||
used_radices.insert(radices[i % radices.size()]);
|
||||
}
|
||||
}
|
||||
|
||||
// Manual tuning for 7/11/13
|
||||
if (used_radices.find(7) != used_radices.end() &&
|
||||
(used_radices.find(11) != used_radices.end() ||
|
||||
used_radices.find(13) != used_radices.end())) {
|
||||
return 7;
|
||||
} else if (
|
||||
used_radices.find(11) != used_radices.end() &&
|
||||
used_radices.find(13) != used_radices.end()) {
|
||||
return 11;
|
||||
}
|
||||
|
||||
// TODO(alexbarron) Some really weird stuff is going on
|
||||
// for certain `elems_per_thread` on large composite n.
|
||||
// Possibly a compiler issue?
|
||||
if (n == 3159)
|
||||
return 13;
|
||||
if (n == 3645)
|
||||
return 5;
|
||||
if (n == 3969)
|
||||
return 7;
|
||||
if (n == 1982)
|
||||
return 5;
|
||||
|
||||
if (used_radices.size() == 1) {
|
||||
return *(used_radices.begin());
|
||||
}
|
||||
if (used_radices.size() == 2 &&
|
||||
(used_radices.find(11) != used_radices.end() ||
|
||||
used_radices.find(13) != used_radices.end())) {
|
||||
return std::accumulate(used_radices.begin(), used_radices.end(), 0) / 2;
|
||||
}
|
||||
|
||||
// In all other cases use the second smallest radix.
|
||||
return *(++used_radices.begin());
|
||||
}
|
||||
|
||||
inline array ensure_fastest_moving_axis(
|
||||
const array& x,
|
||||
int axis,
|
||||
metal::Device& d,
|
||||
const Stream& s) {
|
||||
// The axis is already with a stride of 1 so check that we have no overlaps
|
||||
// and broadcasting and avoid the copy.
|
||||
if (x.strides(axis) == 1) {
|
||||
// This is a fairly strict test perhaps consider relaxing it in the future.
|
||||
if (x.flags().row_contiguous || x.flags().col_contiguous) {
|
||||
return x;
|
||||
}
|
||||
}
|
||||
|
||||
// To make it the fastest moving axis simply transpose it, then copy it and
|
||||
// then transpose it back.
|
||||
|
||||
// Transpose it
|
||||
std::vector<int> axes(x.ndim(), 0);
|
||||
for (int ax = 0; ax < axes.size(); ax++) {
|
||||
axes[ax] = (ax < axis) ? ax : ax + 1;
|
||||
}
|
||||
axes.back() = axis;
|
||||
Shape xtshape;
|
||||
xtshape.reserve(axes.size());
|
||||
for (auto ax : axes) {
|
||||
xtshape.push_back(x.shape(ax));
|
||||
}
|
||||
array xt(xtshape, x.dtype(), nullptr, {});
|
||||
transpose(x, xt, axes);
|
||||
|
||||
// Copy it
|
||||
array xtc(xt.shape(), x.dtype(), nullptr, {});
|
||||
copy_gpu(
|
||||
xt,
|
||||
xtc,
|
||||
xt.flags().row_contiguous ? CopyType::Vector : CopyType::General,
|
||||
s);
|
||||
d.add_temporary(xtc, s.index);
|
||||
|
||||
// Transpose it
|
||||
for (int ax = 0; ax < axes.size(); ax++) {
|
||||
axes[ax] = (ax < axis) ? ax : ((ax == axis) ? axes.size() - 1 : ax - 1);
|
||||
}
|
||||
array y(x.shape(), x.dtype(), nullptr, {});
|
||||
transpose(xtc, y, axes);
|
||||
|
||||
return y;
|
||||
}
|
||||
|
||||
inline void prepare_output_array(const array& in, array& out, int axis) {
|
||||
// Prepare the output array such that it matches the input in terms of
|
||||
// stride ordering. Namely we might have moved `axis` around in the `in`
|
||||
// array. We must do the same in `out`. The difference is that we don't have
|
||||
// to copy anything because `out` contains garbage at the moment.
|
||||
|
||||
if (in.flags().row_contiguous && out.flags().row_contiguous) {
|
||||
return;
|
||||
}
|
||||
|
||||
std::vector<int> axes(out.ndim(), 0);
|
||||
for (int ax = 0; ax < axes.size(); ax++) {
|
||||
axes[ax] = (ax < axis) ? ax : ax + 1;
|
||||
}
|
||||
axes.back() = axis;
|
||||
as_transposed(out, axes);
|
||||
}
|
||||
|
||||
void fft_stockham_inplace(
|
||||
const FFTPlan& plan,
|
||||
const array& in_,
|
||||
array& out,
|
||||
size_t axis,
|
||||
bool inverse,
|
||||
bool real,
|
||||
metal::Device& d,
|
||||
const Stream& s) {
|
||||
// Prepare the input and output arrays such that `axis` has stride 1.
|
||||
// Possibly copy the input but never the output as it doesn't have anything
|
||||
// useful in it yet.
|
||||
array in = ensure_fastest_moving_axis(in_, axis, d, s);
|
||||
prepare_output_array(in, out, axis);
|
||||
|
||||
// Prepare the arguments for stockham fft
|
||||
int n = plan.size();
|
||||
bool power_of_2 = is_power_of_2(n);
|
||||
int total_batch_size =
|
||||
out.dtype() == float32 ? out.size() / n : in.size() / n;
|
||||
auto& steps = plan.steps();
|
||||
int elems_per_thread = compute_elems_per_thread(n, steps);
|
||||
int threads_per_fft = ceildiv(n, elems_per_thread);
|
||||
int tg_batch_size = std::max(MIN_THREADGROUP_MEM_SIZE / n, 1);
|
||||
int tg_mem_size = next_power_of_2(tg_batch_size * n);
|
||||
int batch_size = ceildiv(total_batch_size, tg_batch_size);
|
||||
batch_size = real ? ceildiv(batch_size, 2) : batch_size; // 2 RFFTs at once
|
||||
std::vector<MTLFC> func_consts = {
|
||||
{&inverse, MTL::DataType::DataTypeBool, 0},
|
||||
{&power_of_2, MTL::DataType::DataTypeBool, 1},
|
||||
{&elems_per_thread, MTL::DataType::DataTypeInt, 2}};
|
||||
for (int i = 0; i < steps.size(); i++) {
|
||||
func_consts.emplace_back(&steps[i], MTL::DataType::DataTypeInt, 4 + i);
|
||||
}
|
||||
|
||||
// Get the kernel
|
||||
auto in_type = in.dtype() == float32 ? "float" : "float2";
|
||||
auto out_type = out.dtype() == float32 ? "float" : "float2";
|
||||
std::string hash_name;
|
||||
std::string kname;
|
||||
kname.reserve(64);
|
||||
hash_name.reserve(64);
|
||||
concatenate(kname, "fft_mem_", tg_mem_size, "_", in_type, "_", out_type);
|
||||
concatenate(hash_name, kname, "_n", n, "_inv_", inverse);
|
||||
auto template_def =
|
||||
get_template_definition(kname, "fft", tg_mem_size, in_type, out_type);
|
||||
auto kernel = get_fft_kernel(d, kname, hash_name, func_consts, template_def);
|
||||
|
||||
// Launch it
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
compute_encoder.set_input_array(in, 0);
|
||||
compute_encoder.set_output_array(out, 1);
|
||||
compute_encoder.set_bytes(n, 2);
|
||||
compute_encoder.set_bytes(total_batch_size, 3);
|
||||
|
||||
MTL::Size group_dims(1, tg_batch_size, threads_per_fft);
|
||||
MTL::Size grid_dims(batch_size, tg_batch_size, threads_per_fft);
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
void fft_four_step_inplace(
|
||||
const FFTPlan& plan,
|
||||
const array& in_,
|
||||
array& out,
|
||||
size_t axis,
|
||||
bool inverse,
|
||||
bool real,
|
||||
metal::Device& d,
|
||||
const Stream& s) {
|
||||
// Prepare the input and output arrays such that `axis` has stride 1.
|
||||
// Possibly copy the input but never the output as it doesn't have anything
|
||||
// useful in it yet.
|
||||
array in = ensure_fastest_moving_axis(in_, axis, d, s);
|
||||
prepare_output_array(in, out, axis);
|
||||
|
||||
// Also prepare the intermediate array for the four-step fft which is
|
||||
// implemented with 2 kernel calls.
|
||||
array intermediate(
|
||||
(real && inverse) ? out.shape() : in.shape(), complex64, nullptr, {});
|
||||
intermediate.set_data(allocator::malloc(intermediate.nbytes()));
|
||||
prepare_output_array(in, intermediate, axis);
|
||||
d.add_temporary(intermediate, s.index);
|
||||
|
||||
// Make the two calls
|
||||
for (int step = 0; step < 2; step++) {
|
||||
// Create the parameters
|
||||
int n1 = plan.first_size();
|
||||
int n2 = plan.second_size();
|
||||
int n = (step == 0) ? n1 : n2;
|
||||
bool power_of_2 = true;
|
||||
int total_batch_size =
|
||||
out.dtype() == float32 ? out.size() / n : in.size() / n;
|
||||
auto& steps = (step == 0) ? plan.first_steps() : plan.second_steps();
|
||||
int elems_per_thread = compute_elems_per_thread(n, steps);
|
||||
int threads_per_fft = ceildiv(n, elems_per_thread);
|
||||
int tg_batch_size =
|
||||
std::max(MIN_THREADGROUP_MEM_SIZE / n, MIN_COALESCE_WIDTH);
|
||||
int tg_mem_size = next_power_of_2(tg_batch_size * n);
|
||||
int batch_size = ceildiv(total_batch_size, tg_batch_size);
|
||||
std::vector<MTLFC> func_consts = {
|
||||
{&inverse, MTL::DataType::DataTypeBool, 0},
|
||||
{&power_of_2, MTL::DataType::DataTypeBool, 1},
|
||||
{&elems_per_thread, MTL::DataType::DataTypeInt, 2}};
|
||||
for (int i = 0; i < steps.size(); i++) {
|
||||
func_consts.emplace_back(&steps[i], MTL::DataType::DataTypeInt, 4 + i);
|
||||
}
|
||||
|
||||
// Get the kernel
|
||||
auto to_type = [](const array& x) {
|
||||
return x.dtype() == float32 ? "float" : "float2";
|
||||
};
|
||||
auto in_type = step == 0 ? to_type(in) : to_type(intermediate);
|
||||
auto out_type = step == 0 ? to_type(intermediate) : to_type(out);
|
||||
std::string hash_name;
|
||||
std::string kname;
|
||||
kname.reserve(64);
|
||||
hash_name.reserve(64);
|
||||
concatenate(
|
||||
kname,
|
||||
"four_step_mem_",
|
||||
tg_mem_size,
|
||||
"_",
|
||||
in_type,
|
||||
"_",
|
||||
out_type,
|
||||
"_",
|
||||
step,
|
||||
(real ? "_true" : "_false"));
|
||||
concatenate(hash_name, kname, "_n", n, "_inv_", inverse);
|
||||
auto template_def = get_template_definition(
|
||||
kname, "four_step_fft", tg_mem_size, in_type, out_type, step, real);
|
||||
auto kernel =
|
||||
get_fft_kernel(d, kname, hash_name, func_consts, template_def);
|
||||
|
||||
// Launch it
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
compute_encoder.set_input_array((step == 0) ? in : intermediate, 0);
|
||||
compute_encoder.set_output_array((step == 0) ? intermediate : out, 1);
|
||||
compute_encoder.set_bytes(n1, 2);
|
||||
compute_encoder.set_bytes(n2, 3);
|
||||
compute_encoder.set_bytes(total_batch_size, 4);
|
||||
|
||||
MTL::Size group_dims(1, tg_batch_size, threads_per_fft);
|
||||
MTL::Size grid_dims(batch_size, tg_batch_size, threads_per_fft);
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
}
|
||||
}
|
||||
|
||||
void fft_bluestein(
|
||||
const FFTPlan& plan,
|
||||
const array& in_,
|
||||
array& out,
|
||||
size_t axis,
|
||||
bool inverse,
|
||||
bool real,
|
||||
metal::Device& d,
|
||||
const Stream& s) {
|
||||
// Prepare the input and output arrays such that `axis` has stride 1.
|
||||
// Possibly copy the input but never the output as it doesn't have anything
|
||||
// useful in it yet.
|
||||
array in = ensure_fastest_moving_axis(in_, axis, d, s);
|
||||
prepare_output_array(in, out, axis);
|
||||
|
||||
// Prepare the arguments for bluestein fft
|
||||
int n = plan.bluestein_size();
|
||||
bool power_of_2 = true;
|
||||
int total_batch_size = out.dtype() == float32 ? out.size() / plan.size()
|
||||
: in.size() / plan.size();
|
||||
auto& steps = plan.steps();
|
||||
int elems_per_thread = compute_elems_per_thread(n, steps);
|
||||
int threads_per_fft = ceildiv(n, elems_per_thread);
|
||||
int tg_batch_size = std::max(MIN_THREADGROUP_MEM_SIZE / n, 1);
|
||||
int tg_mem_size = next_power_of_2(tg_batch_size * n);
|
||||
int batch_size = ceildiv(total_batch_size, tg_batch_size);
|
||||
batch_size = real ? ceildiv(batch_size, 2) : batch_size; // 2 RFFTs at once
|
||||
std::vector<MTLFC> func_consts = {
|
||||
{&inverse, MTL::DataType::DataTypeBool, 0},
|
||||
{&power_of_2, MTL::DataType::DataTypeBool, 1},
|
||||
{&elems_per_thread, MTL::DataType::DataTypeInt, 2}};
|
||||
for (int i = 0; i < steps.size(); i++) {
|
||||
func_consts.emplace_back(&steps[i], MTL::DataType::DataTypeInt, 4 + i);
|
||||
}
|
||||
|
||||
// Get the kernel
|
||||
auto in_type = in.dtype() == float32 ? "float" : "float2";
|
||||
auto out_type = out.dtype() == float32 ? "float" : "float2";
|
||||
std::string hash_name;
|
||||
std::string kname;
|
||||
kname.reserve(64);
|
||||
hash_name.reserve(64);
|
||||
concatenate(
|
||||
kname, "bluestein_fft_mem_", tg_mem_size, "_", in_type, "_", out_type);
|
||||
concatenate(hash_name, kname, "_n", n, "_inv_", inverse);
|
||||
auto template_def = get_template_definition(
|
||||
kname, "bluestein_fft", tg_mem_size, in_type, out_type);
|
||||
auto kernel = get_fft_kernel(d, kname, hash_name, func_consts, template_def);
|
||||
|
||||
// Get the bluestein constants
|
||||
auto [w_k, w_q] =
|
||||
compute_bluestein_constants(plan.size(), plan.bluestein_size());
|
||||
d.add_temporary(w_k, s.index);
|
||||
d.add_temporary(w_q, s.index);
|
||||
|
||||
// Launch it
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
compute_encoder.set_input_array(in, 0);
|
||||
compute_encoder.set_output_array(out, 1);
|
||||
compute_encoder.set_input_array(w_q, 2);
|
||||
compute_encoder.set_input_array(w_k, 3);
|
||||
compute_encoder.set_bytes(plan.size(), 4);
|
||||
compute_encoder.set_bytes(n, 5);
|
||||
compute_encoder.set_bytes(total_batch_size, 6);
|
||||
|
||||
MTL::Size group_dims(1, tg_batch_size, threads_per_fft);
|
||||
MTL::Size grid_dims(batch_size, tg_batch_size, threads_per_fft);
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
void fft_multi_upload_bluestein(
|
||||
const FFTPlan& plan,
|
||||
const array& in_,
|
||||
array& out,
|
||||
size_t axis,
|
||||
bool inverse,
|
||||
bool real,
|
||||
metal::Device& d,
|
||||
const Stream& s) {
|
||||
// Get Bluestein's constants using the CPU (this is done in the submission
|
||||
// thread which is pretty bad).
|
||||
auto [w_k, w_q] =
|
||||
compute_bluestein_constants(plan.size(), plan.bluestein_size());
|
||||
d.add_temporary(w_k, s.index);
|
||||
d.add_temporary(w_q, s.index);
|
||||
|
||||
// Prepare the input
|
||||
auto in_shape = inverse ? out.shape() : in_.shape();
|
||||
array in(std::move(in_shape), complex64, nullptr, {});
|
||||
if (real && !inverse) {
|
||||
copy_gpu(
|
||||
in_,
|
||||
in,
|
||||
in_.flags().row_contiguous ? CopyType::Vector : CopyType::General,
|
||||
s);
|
||||
d.add_temporary(in, s.index);
|
||||
} else if (real && inverse) {
|
||||
int back_offset = plan.size() % 2 == 0 ? 2 : 1;
|
||||
auto slice_shape = in.shape();
|
||||
slice_shape[axis] -= back_offset;
|
||||
array slice_temp(slice_shape, complex64, nullptr, {});
|
||||
array conj_temp(in.shape(), complex64, nullptr, {});
|
||||
Shape rstarts(in.ndim(), 0);
|
||||
Shape rstrides(in.ndim(), 1);
|
||||
rstarts[axis] = in.shape(axis) - back_offset;
|
||||
rstrides[axis] = -1;
|
||||
unary_op_gpu({in_}, conj_temp, "Conjugate", s);
|
||||
slice_gpu(in_, slice_temp, rstarts, rstrides, s);
|
||||
concatenate_gpu({conj_temp, slice_temp}, in, (int)axis, s);
|
||||
d.add_temporary(conj_temp, s.index);
|
||||
} else if (inverse) {
|
||||
unary_op_gpu({in_}, in, "Conjugate", s);
|
||||
d.add_temporary(in, s.index);
|
||||
} else {
|
||||
in.copy_shared_buffer(in_);
|
||||
}
|
||||
|
||||
// Multiply with
|
||||
Strides b_strides(in.ndim(), 0);
|
||||
b_strides[axis] = 1;
|
||||
array w_k_broadcast(in.shape(), complex64, nullptr, {});
|
||||
w_k_broadcast.copy_shared_buffer(w_k, b_strides, {}, w_k.data_size());
|
||||
array x(in.shape(), complex64, nullptr, {});
|
||||
binary_op_gpu({in, w_k_broadcast}, x, "Multiply", s);
|
||||
d.add_temporary(x, s.index);
|
||||
|
||||
// Pad
|
||||
auto padded_shape = out.shape();
|
||||
padded_shape[axis] = plan.bluestein_size();
|
||||
array padded_x(padded_shape, complex64, nullptr, {});
|
||||
auto zero = array(complex64_t{0.0f, 0.0f});
|
||||
pad_gpu(x, zero, padded_x, {(int)axis}, {0}, s);
|
||||
d.add_temporary(zero, s.index);
|
||||
d.add_temporary(padded_x, s.index);
|
||||
|
||||
// First fft
|
||||
}
|
||||
|
||||
void fft_op_inplace(
|
||||
const array& in,
|
||||
array& out,
|
||||
size_t axis,
|
||||
bool inverse,
|
||||
bool real,
|
||||
bool inplace,
|
||||
metal::Device& d,
|
||||
const Stream& s) {
|
||||
fft_op(in, out, axis, inverse, real, FourStepParams(), inplace, s);
|
||||
// Get the FFT size and plan it
|
||||
auto plan =
|
||||
FFTPlan(out.dtype() == float32 ? out.shape(axis) : in.shape(axis));
|
||||
|
||||
switch (plan.type()) {
|
||||
case FFTPlan::NOOP:
|
||||
std::cout << "--------------> 1-size FFT <-----------------" << std::endl;
|
||||
break;
|
||||
case FFTPlan::STOCKHAM:
|
||||
return fft_stockham_inplace(plan, in, out, axis, inverse, real, d, s);
|
||||
case FFTPlan::SMALL_FOUR_STEP:
|
||||
return fft_four_step_inplace(plan, in, out, axis, inverse, real, d, s);
|
||||
case FFTPlan::BLUESTEIN:
|
||||
return fft_bluestein(plan, in, out, axis, inverse, real, d, s);
|
||||
case FFTPlan::UNSUPPORTED: {
|
||||
std::string msg;
|
||||
concatenate(msg, "FFT of size ", plan.size(), " not supported");
|
||||
throw std::runtime_error(msg);
|
||||
}
|
||||
default:
|
||||
std::cout << "----- NYI ----" << std::endl;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
void nd_fft_op(
|
||||
void nd_fft_op_inplace(
|
||||
const array& in,
|
||||
array& out,
|
||||
const std::vector<size_t>& axes,
|
||||
bool inverse,
|
||||
bool real,
|
||||
metal::Device& d,
|
||||
const Stream& s) {
|
||||
// Perform ND FFT on GPU as a series of 1D FFTs
|
||||
auto temp_shape = inverse ? in.shape() : out.shape();
|
||||
array temp1(temp_shape, complex64, nullptr, {});
|
||||
array temp2(temp_shape, complex64, nullptr, {});
|
||||
std::vector<array> temp_arrs = {temp1, temp2};
|
||||
for (int i = axes.size() - 1; i >= 0; i--) {
|
||||
int reverse_index = axes.size() - i - 1;
|
||||
// For 5D and above, we don't want to reallocate our two temporary arrays
|
||||
bool inplace = reverse_index >= 3 && i != 0;
|
||||
// Opposite order for fft vs ifft
|
||||
int index = inverse ? reverse_index : i;
|
||||
size_t axis = axes[index];
|
||||
// Mirror np.fft.(i)rfftn and perform a real transform
|
||||
// only on the final axis.
|
||||
bool step_real = (real && index == axes.size() - 1);
|
||||
auto step_shape = inverse ? out.shape(axis) : in.shape(axis);
|
||||
const array& in_arr = i == axes.size() - 1 ? in : temp_arrs[1 - i % 2];
|
||||
array& out_arr = i == 0 ? out : temp_arrs[i % 2];
|
||||
fft_op(in_arr, out_arr, axis, inverse, step_real, inplace, s);
|
||||
}
|
||||
// We are going to make and possibly reuse some intermediate arrays that will
|
||||
// hold the intermediate fft results.
|
||||
auto shape = inverse ? in.shape() : out.shape();
|
||||
std::vector<array> intermediates;
|
||||
intermediates.reserve(2);
|
||||
|
||||
auto& d = metal::device(s.device);
|
||||
d.add_temporaries(std::move(temp_arrs), s.index);
|
||||
// Utility to return either in or one of the intermediates.
|
||||
auto get_input_array = [&](int step) -> const array& {
|
||||
// The first step so use the input array
|
||||
if (step == 0) {
|
||||
return in;
|
||||
}
|
||||
|
||||
return intermediates[(step - 1) % 2];
|
||||
};
|
||||
|
||||
// Utility to return either out or one of the intermediates. It also informs
|
||||
// us if we should allocate memory for that output or there is already some
|
||||
// allocated.
|
||||
auto get_output_array = [&](int step) -> array& {
|
||||
// It is the final step so return the output array
|
||||
if (step == axes.size() - 1) {
|
||||
return out;
|
||||
}
|
||||
|
||||
// We already have made an array that we can use so go ahead and use it and
|
||||
// don't reallocate the memory.
|
||||
if (step % 2 < intermediates.size()) {
|
||||
return intermediates[step % 2];
|
||||
}
|
||||
|
||||
array x(shape, complex64, nullptr, {});
|
||||
x.set_data(allocator::malloc(x.nbytes()));
|
||||
intermediates.emplace_back(std::move(x));
|
||||
d.add_temporary(intermediates.back(), s.index);
|
||||
|
||||
return intermediates.back();
|
||||
};
|
||||
|
||||
// Perform ND FFT on GPU as a series of 1D FFTs
|
||||
for (int step = 0; step < axes.size(); step++) {
|
||||
auto x = get_input_array(step);
|
||||
auto y = get_output_array(step);
|
||||
auto step_axis = axes[inverse ? step : axes.size() - step - 1];
|
||||
auto step_real = real && (inverse ? step == axes.size() - 1 : step == 0);
|
||||
fft_op_inplace(x, y, step_axis, inverse, step_real, d, s);
|
||||
}
|
||||
}
|
||||
|
||||
void FFT::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& s = stream();
|
||||
auto& d = metal::device(s.device);
|
||||
auto& in = inputs[0];
|
||||
|
||||
// The FFT ops above have the *_inplace suffix. This means that the memory
|
||||
// needs to be already allocated in the output array. Similar to
|
||||
// copy_gpu_inplace and so on.
|
||||
//
|
||||
// Even though we allocate the memory, we do not necessarily want the
|
||||
// contiguous strides so the *_inplace ops may change the strides and flags
|
||||
// of the array but will not reallocate the memory.
|
||||
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
if (axes_.size() > 1) {
|
||||
nd_fft_op(in, out, axes_, inverse_, real_, s);
|
||||
nd_fft_op_inplace(in, out, axes_, inverse_, real_, d, s);
|
||||
} else {
|
||||
fft_op(in, out, axes_[0], inverse_, real_, /*inplace=*/false, s);
|
||||
fft_op_inplace(in, out, axes_[0], inverse_, real_, d, s);
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -1,11 +1,9 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include <map>
|
||||
|
||||
#include "mlx/backend/common/compiled.h"
|
||||
#include "mlx/backend/common/hadamard.h"
|
||||
#include "mlx/backend/common/compiled.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/metal/copy.h"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/jit/includes.h"
|
||||
#include "mlx/backend/metal/kernels.h"
|
||||
@@ -15,7 +13,6 @@
|
||||
namespace mlx::core {
|
||||
|
||||
constexpr int MAX_HADAMARD_THREADS_PER_GROUP = 256;
|
||||
constexpr int MAX_HADAMARD_BYTES = 32768; // 32KB
|
||||
|
||||
std::string gen_hadamard_codelet(int m) {
|
||||
// Generate a O(m^2) hadamard codelet for a given M
|
||||
@@ -60,121 +57,142 @@ std::string gen_hadamard_codelet(int m) {
|
||||
return source.str();
|
||||
}
|
||||
|
||||
void Hadamard::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& s = stream();
|
||||
void hadamard_mn_contiguous(
|
||||
const array& x,
|
||||
array& y,
|
||||
int m,
|
||||
int n1,
|
||||
int n2,
|
||||
float scale,
|
||||
metal::Device& d,
|
||||
const Stream& s) {
|
||||
int n = n1 * n2;
|
||||
int read_width_n1 = n1 == 2 ? 2 : 4;
|
||||
int read_width_n2 = n2 == 2 ? 2 : 4;
|
||||
int read_width_m = (n == 2 || m == 28) ? 2 : 4;
|
||||
int max_radix_1 = std::min(n1, 16);
|
||||
int max_radix_2 = std::min(n2, 16);
|
||||
float scale_n1 = 1.0;
|
||||
float scale_n2 = (m == 1) ? scale : 1.0;
|
||||
float scale_m = scale;
|
||||
|
||||
auto& in = inputs[0];
|
||||
// n2 is a row contiguous power of 2 hadamard transform
|
||||
MTL::Size group_dims_n2(n2 / max_radix_2, 1, 1);
|
||||
MTL::Size grid_dims_n2(n2 / max_radix_2, x.size() / n2, 1);
|
||||
|
||||
std::vector<array> copies;
|
||||
// Only support the last axis for now
|
||||
int axis = in.ndim() - 1;
|
||||
auto check_input = [&copies, &s](const array& x) {
|
||||
// TODO(alexbarron) pass strides to kernel to relax this constraint
|
||||
bool no_copy = x.flags().row_contiguous;
|
||||
if (no_copy) {
|
||||
return x;
|
||||
} else {
|
||||
copies.push_back(array(x.shape(), x.dtype(), nullptr, {}));
|
||||
copy_gpu(x, copies.back(), CopyType::General, s);
|
||||
return copies.back();
|
||||
// n1 is a strided power of 2 hadamard transform with stride n2
|
||||
MTL::Size group_dims_n1(n1 / max_radix_1, 1, 1);
|
||||
MTL::Size grid_dims_n1(n1 / max_radix_1, x.size() / n, n2);
|
||||
|
||||
// m is a strided hadamard transform with stride n = n1 * n2
|
||||
MTL::Size group_dims_m(
|
||||
std::min(n / read_width_m, MAX_HADAMARD_THREADS_PER_GROUP), 1, 1);
|
||||
MTL::Size grid_dims_m(
|
||||
group_dims_m.width, x.size() / m / read_width_m / group_dims_m.width, 1);
|
||||
|
||||
// Make the kernel
|
||||
std::string kname;
|
||||
kname.reserve(32);
|
||||
concatenate(kname, "hadamard_", n * m, "_", type_to_name(x));
|
||||
auto lib = d.get_library(kname, [&]() {
|
||||
std::string kernel;
|
||||
concatenate(
|
||||
kernel,
|
||||
metal::utils(),
|
||||
gen_hadamard_codelet(m),
|
||||
metal::hadamard(),
|
||||
get_template_definition(
|
||||
"n2" + kname,
|
||||
"hadamard_n",
|
||||
get_type_string(x.dtype()),
|
||||
n2,
|
||||
max_radix_2,
|
||||
read_width_n2));
|
||||
if (n1 > 1) {
|
||||
kernel += get_template_definition(
|
||||
"n1" + kname,
|
||||
"hadamard_n",
|
||||
get_type_string(x.dtype()),
|
||||
n1,
|
||||
max_radix_1,
|
||||
read_width_n1,
|
||||
n2);
|
||||
}
|
||||
};
|
||||
const array& in_contiguous = check_input(in);
|
||||
|
||||
if (in_contiguous.is_donatable()) {
|
||||
out.copy_shared_buffer(in_contiguous);
|
||||
} else {
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
}
|
||||
|
||||
int n, m;
|
||||
std::tie(n, m) = decompose_hadamard(in.shape(axis));
|
||||
|
||||
if (n * (int)size_of(in.dtype()) > MAX_HADAMARD_BYTES) {
|
||||
throw std::invalid_argument(
|
||||
"[hadamard] For n = m*2^k, 2^k > 8192 for FP32 or 2^k > 16384 for FP16/BF16 NYI");
|
||||
}
|
||||
|
||||
int max_radix = std::min(n, 16);
|
||||
// Use read_width 2 for m = 28 to avoid register spilling
|
||||
int read_width = (n == 2 || m == 28) ? 2 : 4;
|
||||
|
||||
std::ostringstream kname;
|
||||
kname << "hadamard_" << n * m << "_" << type_to_name(out);
|
||||
auto kernel_name = kname.str();
|
||||
auto& d = metal::device(s.device);
|
||||
const auto& lib_name = kernel_name;
|
||||
auto lib = d.get_library(lib_name, [&]() {
|
||||
std::ostringstream kernel_source;
|
||||
auto codelet = gen_hadamard_codelet(m);
|
||||
kernel_source << metal::utils() << codelet << metal::hadamard();
|
||||
kernel_source << get_template_definition(
|
||||
"n" + kernel_name,
|
||||
"hadamard_n",
|
||||
get_type_string(in.dtype()),
|
||||
n,
|
||||
max_radix,
|
||||
read_width);
|
||||
kernel_source << get_template_definition(
|
||||
"m" + kernel_name,
|
||||
"hadamard_m",
|
||||
get_type_string(in.dtype()),
|
||||
n,
|
||||
m,
|
||||
read_width);
|
||||
return kernel_source.str();
|
||||
if (m > 1) {
|
||||
kernel += get_template_definition(
|
||||
"m" + kname,
|
||||
"hadamard_m",
|
||||
get_type_string(x.dtype()),
|
||||
n,
|
||||
m,
|
||||
read_width_m);
|
||||
}
|
||||
return kernel;
|
||||
});
|
||||
|
||||
int batch_size = in.size() / n;
|
||||
int threads_per = n / max_radix;
|
||||
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
|
||||
auto launch_hadamard = [&](const array& in,
|
||||
array& out,
|
||||
const std::string& kernel_name,
|
||||
float scale) {
|
||||
auto kernel = d.get_kernel(kernel_name, lib);
|
||||
assert(threads_per <= kernel->maxTotalThreadsPerThreadgroup());
|
||||
|
||||
// Launch the strided transform for n1
|
||||
if (n1 > 1) {
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel("n1" + kname, lib);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
compute_encoder.set_input_array(in, 0);
|
||||
compute_encoder.set_output_array(out, 1);
|
||||
compute_encoder.set_bytes(scale, 2);
|
||||
|
||||
MTL::Size group_dims = MTL::Size(1, threads_per, 1);
|
||||
MTL::Size grid_dims = MTL::Size(batch_size, threads_per, 1);
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
};
|
||||
|
||||
if (m > 1) {
|
||||
// When m is greater than 1, we decompose the
|
||||
// computation into two uploads to the GPU:
|
||||
//
|
||||
// e.g. len(x) = 12*4 = 48, m = 12, n = 4
|
||||
//
|
||||
// y = h48 @ x
|
||||
//
|
||||
// Upload 1:
|
||||
// tmp = a.reshape(12, 4) @ h4
|
||||
//
|
||||
// Upload 2:
|
||||
// y = h12 @ tmp
|
||||
array temp(in.shape(), in.dtype(), nullptr, {});
|
||||
temp.set_data(allocator::malloc(temp.nbytes()));
|
||||
copies.push_back(temp);
|
||||
|
||||
launch_hadamard(in_contiguous, temp, "n" + kernel_name, 1.0);
|
||||
|
||||
// Metal sometimes reports 256 max threads per group for hadamard_m kernel
|
||||
threads_per = std::min(n / read_width, MAX_HADAMARD_THREADS_PER_GROUP);
|
||||
batch_size = in.size() / m / read_width / threads_per;
|
||||
launch_hadamard(temp, out, "m" + kernel_name, scale_);
|
||||
} else {
|
||||
launch_hadamard(in_contiguous, out, "n" + kernel_name, scale_);
|
||||
compute_encoder.set_input_array(x, 0);
|
||||
compute_encoder.set_output_array(y, 1);
|
||||
compute_encoder.set_bytes(scale_n1, 2);
|
||||
compute_encoder.dispatch_threads(grid_dims_n1, group_dims_n1);
|
||||
}
|
||||
|
||||
d.add_temporaries(std::move(copies), s.index);
|
||||
// Launch the transform for n2
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel("n2" + kname, lib);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
compute_encoder.set_input_array(n1 > 1 ? y : x, 0);
|
||||
compute_encoder.set_output_array(y, 1);
|
||||
compute_encoder.set_bytes(scale_n2, 2);
|
||||
compute_encoder.dispatch_threads(grid_dims_n2, group_dims_n2);
|
||||
|
||||
// Launch the strided transform for m
|
||||
if (m > 1) {
|
||||
auto kernel = d.get_kernel("m" + kname, lib);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
compute_encoder.set_input_array(y, 0);
|
||||
compute_encoder.set_output_array(y, 1);
|
||||
compute_encoder.set_bytes(scale_m, 2);
|
||||
compute_encoder.dispatch_threads(grid_dims_m, group_dims_m);
|
||||
}
|
||||
}
|
||||
|
||||
void Hadamard::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& s = stream();
|
||||
auto& d = metal::device(s.device);
|
||||
auto& in = inputs[0];
|
||||
|
||||
// Split the hadamard transform so that all of them work on vectors smaller
|
||||
// than 8192 elements.
|
||||
//
|
||||
// We decompose it in the following way:
|
||||
//
|
||||
// n = m * n1 * n2 = m * 2^k1 * 2^k2
|
||||
//
|
||||
// where m is in (1, 12, 20, 28) and n1 and n2 <= 8192
|
||||
auto [n, m] = decompose_hadamard(in.shape().back());
|
||||
int n1 = 1, n2 = n;
|
||||
if (n > 8192) {
|
||||
for (n2 = 2; n2 * n2 < n; n2 *= 2) {
|
||||
}
|
||||
n1 = n / n2;
|
||||
}
|
||||
|
||||
if (in.flags().row_contiguous) {
|
||||
if (in.is_donatable()) {
|
||||
out.copy_shared_buffer(in);
|
||||
} else {
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
}
|
||||
hadamard_mn_contiguous(in, out, m, n1, n2, scale_, d, s);
|
||||
} else {
|
||||
copy_gpu(in, out, CopyType::General, s);
|
||||
hadamard_mn_contiguous(out, out, m, n1, n2, scale_, d, s);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -2,7 +2,7 @@
|
||||
#include <fmt/format.h>
|
||||
|
||||
#include "mlx/backend/common/compiled.h"
|
||||
#include "mlx/backend/metal/copy.h"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/jit/includes.h"
|
||||
#include "mlx/backend/metal/jit/indexing.h"
|
||||
|
@@ -33,6 +33,7 @@ const char* gemm();
|
||||
const char* steel_gemm_fused();
|
||||
const char* steel_gemm_masked();
|
||||
const char* steel_gemm_splitk();
|
||||
const char* steel_gemm_gather();
|
||||
const char* conv();
|
||||
const char* steel_conv();
|
||||
const char* steel_conv_general();
|
||||
|
@@ -584,6 +584,44 @@ MTL::ComputePipelineState* get_steel_gemm_masked_kernel(
|
||||
return d.get_kernel(kernel_name, lib);
|
||||
}
|
||||
|
||||
MTL::ComputePipelineState* get_steel_gemm_gather_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const std::string& hash_name,
|
||||
const metal::MTLFCList& func_consts,
|
||||
const array& out,
|
||||
bool transpose_a,
|
||||
bool transpose_b,
|
||||
int bm,
|
||||
int bn,
|
||||
int bk,
|
||||
int wm,
|
||||
int wn,
|
||||
bool rhs) {
|
||||
const auto& lib_name = kernel_name;
|
||||
auto lib = d.get_library(lib_name, [&]() {
|
||||
std::string kernel_source;
|
||||
concatenate(
|
||||
kernel_source,
|
||||
metal::utils(),
|
||||
metal::gemm(),
|
||||
metal::steel_gemm_gather(),
|
||||
get_template_definition(
|
||||
lib_name,
|
||||
rhs ? "gather_mm_rhs" : "gather_mm",
|
||||
get_type_string(out.dtype()),
|
||||
bm,
|
||||
bn,
|
||||
bk,
|
||||
wm,
|
||||
wn,
|
||||
transpose_a,
|
||||
transpose_b));
|
||||
return kernel_source;
|
||||
});
|
||||
return d.get_kernel(kernel_name, lib, hash_name, func_consts);
|
||||
}
|
||||
|
||||
MTL::ComputePipelineState* get_gemv_masked_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
@@ -714,4 +752,43 @@ MTL::ComputePipelineState* get_quantized_kernel(
|
||||
return d.get_kernel(kernel_name, lib);
|
||||
}
|
||||
|
||||
MTL::ComputePipelineState* get_gather_qmm_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const std::string& hash_name,
|
||||
const metal::MTLFCList& func_consts,
|
||||
const array& x,
|
||||
int group_size,
|
||||
int bits,
|
||||
int bm,
|
||||
int bn,
|
||||
int bk,
|
||||
int wm,
|
||||
int wn,
|
||||
bool transpose) {
|
||||
const auto& lib_name = kernel_name;
|
||||
auto lib = d.get_library(lib_name, [&]() {
|
||||
std::string kernel_source;
|
||||
concatenate(
|
||||
kernel_source,
|
||||
metal::utils(),
|
||||
metal::gemm(),
|
||||
metal::quantized(),
|
||||
get_template_definition(
|
||||
lib_name,
|
||||
"gather_qmm_rhs",
|
||||
get_type_string(x.dtype()),
|
||||
group_size,
|
||||
bits,
|
||||
bm,
|
||||
bn,
|
||||
bk,
|
||||
wm,
|
||||
wn,
|
||||
transpose));
|
||||
return kernel_source;
|
||||
});
|
||||
return d.get_kernel(kernel_name, lib, hash_name, func_consts);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -160,6 +160,21 @@ MTL::ComputePipelineState* get_steel_gemm_masked_kernel(
|
||||
bool mn_aligned,
|
||||
bool k_aligned);
|
||||
|
||||
MTL::ComputePipelineState* get_steel_gemm_gather_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const std::string& hash_name,
|
||||
const metal::MTLFCList& func_consts,
|
||||
const array& out,
|
||||
bool transpose_a,
|
||||
bool transpose_b,
|
||||
int bm,
|
||||
int bn,
|
||||
int bk,
|
||||
int wm,
|
||||
int wn,
|
||||
bool rhs);
|
||||
|
||||
MTL::ComputePipelineState* get_steel_conv_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
@@ -209,6 +224,21 @@ MTL::ComputePipelineState* get_quantized_kernel(
|
||||
const std::string& kernel_name,
|
||||
const std::string& template_def);
|
||||
|
||||
MTL::ComputePipelineState* get_gather_qmm_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const std::string& hash_name,
|
||||
const metal::MTLFCList& func_consts,
|
||||
const array& x,
|
||||
int group_size,
|
||||
int bits,
|
||||
int bm,
|
||||
int bn,
|
||||
int bk,
|
||||
int wm,
|
||||
int wn,
|
||||
bool transpose);
|
||||
|
||||
// Create a GPU kernel template definition for JIT compilation
|
||||
template <typename... Args>
|
||||
std::string
|
||||
|
@@ -69,6 +69,7 @@ set(STEEL_HEADERS
|
||||
steel/gemm/loader.h
|
||||
steel/gemm/transforms.h
|
||||
steel/gemm/kernels/steel_gemm_fused.h
|
||||
steel/gemm/kernels/steel_gemm_gather.h
|
||||
steel/gemm/kernels/steel_gemm_masked.h
|
||||
steel/gemm/kernels/steel_gemm_splitk.h
|
||||
steel/utils/type_traits.h
|
||||
@@ -116,6 +117,7 @@ if(NOT MLX_METAL_JIT)
|
||||
build_kernel(steel/conv/kernels/steel_conv ${STEEL_HEADERS})
|
||||
build_kernel(steel/conv/kernels/steel_conv_general ${STEEL_HEADERS})
|
||||
build_kernel(steel/gemm/kernels/steel_gemm_fused ${STEEL_HEADERS})
|
||||
build_kernel(steel/gemm/kernels/steel_gemm_gather ${STEEL_HEADERS})
|
||||
build_kernel(steel/gemm/kernels/steel_gemm_masked ${STEEL_HEADERS})
|
||||
build_kernel(steel/gemm/kernels/steel_gemm_splitk ${STEEL_HEADERS})
|
||||
build_kernel(gemv_masked steel/utils.h)
|
||||
|
@@ -9,64 +9,85 @@ template <typename T, typename U, typename Op>
|
||||
c[index] = Op()(a[0], b[0]);
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
|
||||
[[kernel]] void binary_sv(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
device U* c,
|
||||
constant uint& size,
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
c[index] = Op()(a[0], b[index]);
|
||||
index *= N;
|
||||
for (int i = 0; i < N && (index + i) < size; ++i) {
|
||||
c[index + i] = Op()(a[0], b[index + i]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
|
||||
[[kernel]] void binary_vs(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
device U* c,
|
||||
constant uint& size,
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
c[index] = Op()(a[index], b[0]);
|
||||
index *= N;
|
||||
for (int i = 0; i < N && (index + i) < size; ++i) {
|
||||
c[index + i] = Op()(a[index + i], b[0]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
|
||||
[[kernel]] void binary_vv(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
device U* c,
|
||||
constant uint& size,
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
c[index] = Op()(a[index], b[index]);
|
||||
index *= N;
|
||||
for (int i = 0; i < N && (index + i) < size; ++i) {
|
||||
c[index + i] = Op()(a[index + i], b[index + i]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
|
||||
[[kernel]] void binary_sv2(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
device U* c,
|
||||
constant int64_t& size,
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]) {
|
||||
int64_t offset = index.x + grid_dim.x * int64_t(index.y);
|
||||
c[offset] = Op()(a[0], b[offset]);
|
||||
int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
|
||||
for (int i = 0; i < N && (offset + i) < size; ++i) {
|
||||
c[offset + i] = Op()(a[0], b[offset + i]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
|
||||
[[kernel]] void binary_vs2(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
device U* c,
|
||||
constant int64_t& size,
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]) {
|
||||
int64_t offset = index.x + grid_dim.x * int64_t(index.y);
|
||||
c[offset] = Op()(a[offset], b[0]);
|
||||
int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
|
||||
for (int i = 0; i < N && (offset + i) < size; ++i) {
|
||||
c[offset + i] = Op()(a[offset + i], b[0]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
|
||||
[[kernel]] void binary_vv2(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
device U* c,
|
||||
constant int64_t& size,
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]) {
|
||||
int64_t offset = index.x + grid_dim.x * int64_t(index.y);
|
||||
c[offset] = Op()(a[offset], b[offset]);
|
||||
int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
|
||||
for (int i = 0; i < N && (offset + i) < size; ++i) {
|
||||
c[offset + i] = Op()(a[offset + i], b[offset + i]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op, typename IdxT = int64_t>
|
||||
|
@@ -71,6 +71,7 @@ instantiate_binary_types_bool(Less)
|
||||
instantiate_binary_types_bool(LessEqual)
|
||||
instantiate_binary_types_bool(NotEqual)
|
||||
instantiate_binary_float(LogAddExp)
|
||||
instantiate_binary_all(LogAddExp, complex64, complex64_t, complex64_t)
|
||||
instantiate_binary_types(Maximum)
|
||||
instantiate_binary_types(Minimum)
|
||||
instantiate_binary_types(Multiply)
|
||||
|
@@ -130,6 +130,24 @@ struct LogAddExp {
|
||||
? maxval
|
||||
: (maxval + log1p(metal::exp(minval - maxval)));
|
||||
};
|
||||
|
||||
complex64_t operator()(complex64_t x, complex64_t y) {
|
||||
if (metal::isnan(x.real) || metal::isnan(x.imag) || metal::isnan(y.real) ||
|
||||
metal::isnan(y.imag)) {
|
||||
return metal::numeric_limits<float>::quiet_NaN();
|
||||
}
|
||||
constexpr float inf = metal::numeric_limits<float>::infinity();
|
||||
complex64_t maxval = x > y ? x : y;
|
||||
complex64_t minval = x < y ? x : y;
|
||||
if (minval.real == -inf || maxval.real == inf)
|
||||
return maxval;
|
||||
float m = metal::exp(minval.real - maxval.real);
|
||||
complex64_t dexp{
|
||||
m * metal::cos(minval.imag - maxval.imag),
|
||||
m * metal::sin(minval.imag - maxval.imag),
|
||||
};
|
||||
return maxval + log1p(dexp);
|
||||
}
|
||||
};
|
||||
|
||||
struct Maximum {
|
||||
|
@@ -12,82 +12,103 @@ template <typename T, typename U, typename Op>
|
||||
d[index] = out[1];
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
|
||||
[[kernel]] void binary_sv(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
device U* c,
|
||||
device U* d,
|
||||
constant uint& size,
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
auto out = Op()(a[0], b[index]);
|
||||
c[index] = out[0];
|
||||
d[index] = out[1];
|
||||
index *= N;
|
||||
for (int i = 0; i < N && (index + i) < size; ++i) {
|
||||
auto out = Op()(a[0], b[index + i]);
|
||||
c[index + i] = out[0];
|
||||
d[index + i] = out[1];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
|
||||
[[kernel]] void binary_vs(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
device U* c,
|
||||
device U* d,
|
||||
constant uint& size,
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
auto out = Op()(a[index], b[0]);
|
||||
c[index] = out[0];
|
||||
d[index] = out[1];
|
||||
index *= N;
|
||||
for (int i = 0; i < N && (index + i) < size; ++i) {
|
||||
auto out = Op()(a[index + i], b[0]);
|
||||
c[index + i] = out[0];
|
||||
d[index + i] = out[1];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
|
||||
[[kernel]] void binary_vv(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
device U* c,
|
||||
device U* d,
|
||||
constant uint& size,
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
auto out = Op()(a[index], b[index]);
|
||||
c[index] = out[0];
|
||||
d[index] = out[1];
|
||||
index *= N;
|
||||
for (int i = 0; i < N && (index + i) < size; ++i) {
|
||||
auto out = Op()(a[index + i], b[index + i]);
|
||||
c[index + i] = out[0];
|
||||
d[index + i] = out[1];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
|
||||
[[kernel]] void binary_sv2(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
device U* c,
|
||||
device U* d,
|
||||
constant int64_t& size,
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]) {
|
||||
auto offset = index.x + grid_dim.x * int64_t(index.y);
|
||||
auto out = Op()(a[0], b[offset]);
|
||||
c[offset] = out[0];
|
||||
d[offset] = out[1];
|
||||
auto offset = N * (index.x + grid_dim.x * int64_t(index.y));
|
||||
for (int i = 0; i < N && (offset + i) < size; ++i) {
|
||||
auto out = Op()(a[0], b[offset + i]);
|
||||
c[offset + i] = out[0];
|
||||
d[offset + i] = out[1];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
|
||||
[[kernel]] void binary_vs2(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
device U* c,
|
||||
device U* d,
|
||||
constant int64_t& size,
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]) {
|
||||
auto offset = index.x + grid_dim.x * int64_t(index.y);
|
||||
auto out = Op()(a[offset], b[0]);
|
||||
c[offset] = out[0];
|
||||
d[offset] = out[1];
|
||||
auto offset = N * (index.x + grid_dim.x * int64_t(index.y));
|
||||
for (int i = 0; i < N && (offset + i) < size; ++i) {
|
||||
auto out = Op()(a[offset + i], b[0]);
|
||||
c[offset + i] = out[0];
|
||||
d[offset + i] = out[1];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
|
||||
[[kernel]] void binary_vv2(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
device U* c,
|
||||
device U* d,
|
||||
constant int64_t& size,
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]) {
|
||||
auto offset = index.x + grid_dim.x * int64_t(index.y);
|
||||
auto out = Op()(a[offset], b[offset]);
|
||||
c[offset] = out[0];
|
||||
d[offset] = out[1];
|
||||
auto offset = N * (index.x + grid_dim.x * int64_t(index.y));
|
||||
for (int i = 0; i < N && (offset + i) < size; ++i) {
|
||||
auto out = Op()(a[offset + i], b[offset + i]);
|
||||
c[offset + i] = out[0];
|
||||
d[offset + i] = out[1];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op, typename IdxT = int64_t>
|
||||
|
@@ -104,10 +104,22 @@ constexpr bool operator==(complex64_t a, complex64_t b) {
|
||||
constexpr complex64_t operator+(complex64_t a, complex64_t b) {
|
||||
return {a.real + b.real, a.imag + b.imag};
|
||||
}
|
||||
constexpr complex64_t operator+(float a, complex64_t b) {
|
||||
return {a + b.real, b.imag};
|
||||
}
|
||||
constexpr complex64_t operator+(complex64_t a, float b) {
|
||||
return {a.real + b, a.imag};
|
||||
}
|
||||
|
||||
constexpr complex64_t operator-(complex64_t a, complex64_t b) {
|
||||
return {a.real - b.real, a.imag - b.imag};
|
||||
}
|
||||
constexpr complex64_t operator-(float a, complex64_t b) {
|
||||
return {a - b.real, -b.imag};
|
||||
}
|
||||
constexpr complex64_t operator-(complex64_t a, float b) {
|
||||
return {a.real - b, a.imag};
|
||||
}
|
||||
|
||||
constexpr complex64_t operator*(complex64_t a, complex64_t b) {
|
||||
return {a.real * b.real - a.imag * b.imag, a.real * b.imag + a.imag * b.real};
|
||||
@@ -120,6 +132,13 @@ constexpr complex64_t operator/(complex64_t a, complex64_t b) {
|
||||
return {x / denom, y / denom};
|
||||
}
|
||||
|
||||
constexpr complex64_t operator/(float a, complex64_t b) {
|
||||
auto denom = b.real * b.real + b.imag * b.imag;
|
||||
auto x = a * b.real;
|
||||
auto y = -a * b.imag;
|
||||
return {x / denom, y / denom};
|
||||
}
|
||||
|
||||
constexpr complex64_t operator%(complex64_t a, complex64_t b) {
|
||||
auto real = a.real - (b.real * static_cast<int64_t>(a.real / b.real));
|
||||
auto imag = a.imag - (b.imag * static_cast<int64_t>(a.imag / b.imag));
|
||||
|
@@ -1,39 +1,53 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
template <typename T, typename U>
|
||||
template <typename T, typename U, int N = WorkPerThread<T>::n>
|
||||
[[kernel]] void copy_s(
|
||||
device const T* src [[buffer(0)]],
|
||||
device U* dst [[buffer(1)]],
|
||||
constant uint& size,
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
dst[index] = static_cast<U>(src[0]);
|
||||
index *= N;
|
||||
for (int i = 0; i < N && (index + i) < size; ++i) {
|
||||
dst[index + i] = static_cast<U>(src[0]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U>
|
||||
template <typename T, typename U, int N = WorkPerThread<T>::n>
|
||||
[[kernel]] void copy_v(
|
||||
device const T* src [[buffer(0)]],
|
||||
device U* dst [[buffer(1)]],
|
||||
constant uint& size,
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
dst[index] = static_cast<U>(src[index]);
|
||||
index *= N;
|
||||
for (int i = 0; i < N && (index + i) < size; ++i) {
|
||||
dst[index + i] = static_cast<U>(src[index + i]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U>
|
||||
template <typename T, typename U, int N = WorkPerThread<T>::n>
|
||||
[[kernel]] void copy_s2(
|
||||
device const T* src [[buffer(0)]],
|
||||
device U* dst [[buffer(1)]],
|
||||
constant int64_t& size,
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]) {
|
||||
auto offset = index.x + grid_dim.x * int64_t(index.y);
|
||||
dst[offset] = static_cast<U>(src[0]);
|
||||
auto offset = N * (index.x + grid_dim.x * int64_t(index.y));
|
||||
for (int i = 0; i < N && (offset + i) < size; ++i) {
|
||||
dst[offset + i] = static_cast<U>(src[0]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U>
|
||||
template <typename T, typename U, int N = WorkPerThread<T>::n>
|
||||
[[kernel]] void copy_v2(
|
||||
device const T* src [[buffer(0)]],
|
||||
device U* dst [[buffer(1)]],
|
||||
constant int64_t& size,
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]) {
|
||||
auto offset = index.x + grid_dim.x * int64_t(index.y);
|
||||
dst[offset] = static_cast<U>(src[offset]);
|
||||
auto offset = N * (index.x + grid_dim.x * int64_t(index.y));
|
||||
for (int i = 0; i < N && (offset + i) < size; ++i) {
|
||||
dst[offset + i] = static_cast<U>(src[offset + i]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename IdxT = int64_t>
|
||||
|
@@ -10,7 +10,7 @@ For many sizes, GPU FFTs are memory bandwidth bound so
|
||||
read/write performance is important.
|
||||
|
||||
Where possible, we read 128 bits sequentially in each thread,
|
||||
coalesced with accesses from adajcent threads for optimal performance.
|
||||
coalesced with accesses from adjacent threads for optimal performance.
|
||||
|
||||
We implement specialized reading/writing for:
|
||||
- FFT
|
||||
|
@@ -26,7 +26,7 @@ METAL_FUNC void radix_func(thread float* x) {
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int N, int max_radix, int read_width>
|
||||
template <typename T, int N, int max_radix, int read_width, int stride = 1>
|
||||
[[kernel]] void hadamard_n(
|
||||
const device T* in [[buffer(0)]],
|
||||
device T* out [[buffer(1)]],
|
||||
@@ -46,18 +46,25 @@ template <typename T, int N, int max_radix, int read_width>
|
||||
constexpr short logFinal = logN % logR;
|
||||
constexpr short final_radix = 1 << (logFinal);
|
||||
|
||||
int batch_idx = elem.x * N;
|
||||
short i = elem.y;
|
||||
int batch_idx = elem.y * N * stride + elem.z;
|
||||
short i = elem.x;
|
||||
|
||||
threadgroup T buf[N];
|
||||
|
||||
// Read values from device
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < max_radix / read_width; j++) {
|
||||
short index = j * read_width * num_threads + i * read_width;
|
||||
if (stride == 1) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short r = 0; r < read_width; r++) {
|
||||
buf[index + r] = in[batch_idx + index + r];
|
||||
for (short j = 0; j < max_radix / read_width; j++) {
|
||||
short index = j * read_width * num_threads + i * read_width;
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short r = 0; r < read_width; r++) {
|
||||
buf[index + r] = in[batch_idx + index + r];
|
||||
}
|
||||
}
|
||||
} else {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < max_radix; j++) {
|
||||
buf[j * num_threads + i] = in[batch_idx + (j * num_threads + i) * stride];
|
||||
}
|
||||
}
|
||||
|
||||
@@ -113,12 +120,20 @@ template <typename T, int N, int max_radix, int read_width>
|
||||
}
|
||||
|
||||
// Write values to device
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < max_radix / read_width; j++) {
|
||||
short index = j * read_width * num_threads + i * read_width;
|
||||
if (stride == 1) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short r = 0; r < read_width; r++) {
|
||||
out[batch_idx + index + r] = T(buf[index + r] * scale);
|
||||
for (short j = 0; j < max_radix / read_width; j++) {
|
||||
short index = j * read_width * num_threads + i * read_width;
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short r = 0; r < read_width; r++) {
|
||||
out[batch_idx + index + r] = T(buf[index + r] * scale);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < max_radix; j++) {
|
||||
out[batch_idx + (j * num_threads + i) * stride] =
|
||||
buf[j * num_threads + i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@@ -3,6 +3,10 @@
|
||||
#include <metal_simdgroup>
|
||||
#include <metal_stdlib>
|
||||
|
||||
constant bool align_M [[function_constant(200)]];
|
||||
constant bool align_N [[function_constant(201)]];
|
||||
constant bool align_K [[function_constant(202)]];
|
||||
|
||||
using namespace metal;
|
||||
|
||||
#define MLX_MTL_CONST static constant constexpr const
|
||||
@@ -1004,11 +1008,11 @@ METAL_FUNC void qmm_t_impl(
|
||||
|
||||
auto wl = (const device uint8_t*)w;
|
||||
|
||||
x += y_row * K;
|
||||
x += y_row * static_cast<int64_t>(K);
|
||||
wl += y_col * K_w;
|
||||
scales += y_col * K_g;
|
||||
biases += y_col * K_g;
|
||||
y += y_row * N + y_col;
|
||||
y += y_row * static_cast<int64_t>(N) + y_col;
|
||||
|
||||
// Make the x loader and mma operation
|
||||
const short num_els = min(BM, M - y_row);
|
||||
@@ -1128,11 +1132,11 @@ METAL_FUNC void qmm_n_impl(
|
||||
// Set the block
|
||||
const int y_row = tid.y * BM;
|
||||
const int y_col = tid.x * BN;
|
||||
x += y_row * K;
|
||||
x += y_row * static_cast<int64_t>(K);
|
||||
wl += y_col * bytes_per_pack / pack_factor;
|
||||
scales += y_col / group_size;
|
||||
biases += y_col / group_size;
|
||||
y += y_row * N + y_col;
|
||||
y += y_row * static_cast<int64_t>(N) + y_col;
|
||||
|
||||
// Make the x loader and mma operation
|
||||
const short num_els = min(BM, M - y_row);
|
||||
@@ -1686,26 +1690,26 @@ template <
|
||||
}
|
||||
|
||||
template <typename T, int group_size, int bits>
|
||||
[[kernel]] void bs_qmv_fast(
|
||||
[[kernel]] void gather_qmv_fast(
|
||||
const device uint32_t* w [[buffer(0)]],
|
||||
const device T* scales [[buffer(1)]],
|
||||
const device T* biases [[buffer(2)]],
|
||||
const device T* x [[buffer(3)]],
|
||||
device T* y [[buffer(4)]],
|
||||
const constant int& in_vec_size [[buffer(5)]],
|
||||
const constant int& out_vec_size [[buffer(6)]],
|
||||
const constant int& x_batch_ndims [[buffer(7)]],
|
||||
const constant int* x_shape [[buffer(8)]],
|
||||
const constant int64_t* x_strides [[buffer(9)]],
|
||||
const constant int& w_batch_ndims [[buffer(10)]],
|
||||
const constant int* w_shape [[buffer(11)]],
|
||||
const constant int64_t* w_strides [[buffer(12)]],
|
||||
const constant int64_t* s_strides [[buffer(13)]],
|
||||
const constant int64_t* b_strides [[buffer(14)]],
|
||||
const constant int& batch_ndims [[buffer(15)]],
|
||||
const constant int* batch_shape [[buffer(16)]],
|
||||
const device uint32_t* lhs_indices [[buffer(17)]],
|
||||
const device uint32_t* rhs_indices [[buffer(18)]],
|
||||
const device uint32_t* lhs_indices [[buffer(4)]],
|
||||
const device uint32_t* rhs_indices [[buffer(5)]],
|
||||
device T* y [[buffer(6)]],
|
||||
const constant int& in_vec_size [[buffer(7)]],
|
||||
const constant int& out_vec_size [[buffer(8)]],
|
||||
const constant int& x_batch_ndims [[buffer(9)]],
|
||||
const constant int* x_shape [[buffer(10)]],
|
||||
const constant int64_t* x_strides [[buffer(11)]],
|
||||
const constant int& w_batch_ndims [[buffer(12)]],
|
||||
const constant int* w_shape [[buffer(13)]],
|
||||
const constant int64_t* w_strides [[buffer(14)]],
|
||||
const constant int64_t* s_strides [[buffer(15)]],
|
||||
const constant int64_t* b_strides [[buffer(16)]],
|
||||
const constant int& batch_ndims [[buffer(17)]],
|
||||
const constant int* batch_shape [[buffer(18)]],
|
||||
const constant int64_t* lhs_strides [[buffer(19)]],
|
||||
const constant int64_t* rhs_strides [[buffer(20)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
@@ -1748,26 +1752,26 @@ template <typename T, int group_size, int bits>
|
||||
}
|
||||
|
||||
template <typename T, int group_size, int bits>
|
||||
[[kernel]] void bs_qmv(
|
||||
[[kernel]] void gather_qmv(
|
||||
const device uint32_t* w [[buffer(0)]],
|
||||
const device T* scales [[buffer(1)]],
|
||||
const device T* biases [[buffer(2)]],
|
||||
const device T* x [[buffer(3)]],
|
||||
device T* y [[buffer(4)]],
|
||||
const constant int& in_vec_size [[buffer(5)]],
|
||||
const constant int& out_vec_size [[buffer(6)]],
|
||||
const constant int& x_batch_ndims [[buffer(7)]],
|
||||
const constant int* x_shape [[buffer(8)]],
|
||||
const constant int64_t* x_strides [[buffer(9)]],
|
||||
const constant int& w_batch_ndims [[buffer(10)]],
|
||||
const constant int* w_shape [[buffer(11)]],
|
||||
const constant int64_t* w_strides [[buffer(12)]],
|
||||
const constant int64_t* s_strides [[buffer(13)]],
|
||||
const constant int64_t* b_strides [[buffer(14)]],
|
||||
const constant int& batch_ndims [[buffer(15)]],
|
||||
const constant int* batch_shape [[buffer(16)]],
|
||||
const device uint32_t* lhs_indices [[buffer(17)]],
|
||||
const device uint32_t* rhs_indices [[buffer(18)]],
|
||||
const device uint32_t* lhs_indices [[buffer(4)]],
|
||||
const device uint32_t* rhs_indices [[buffer(5)]],
|
||||
device T* y [[buffer(6)]],
|
||||
const constant int& in_vec_size [[buffer(7)]],
|
||||
const constant int& out_vec_size [[buffer(8)]],
|
||||
const constant int& x_batch_ndims [[buffer(9)]],
|
||||
const constant int* x_shape [[buffer(10)]],
|
||||
const constant int64_t* x_strides [[buffer(11)]],
|
||||
const constant int& w_batch_ndims [[buffer(12)]],
|
||||
const constant int* w_shape [[buffer(13)]],
|
||||
const constant int64_t* w_strides [[buffer(14)]],
|
||||
const constant int64_t* s_strides [[buffer(15)]],
|
||||
const constant int64_t* b_strides [[buffer(16)]],
|
||||
const constant int& batch_ndims [[buffer(17)]],
|
||||
const constant int* batch_shape [[buffer(18)]],
|
||||
const constant int64_t* lhs_strides [[buffer(19)]],
|
||||
const constant int64_t* rhs_strides [[buffer(20)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
@@ -1810,26 +1814,26 @@ template <typename T, int group_size, int bits>
|
||||
}
|
||||
|
||||
template <typename T, int group_size, int bits>
|
||||
[[kernel]] void bs_qvm(
|
||||
[[kernel]] void gather_qvm(
|
||||
const device uint32_t* w [[buffer(0)]],
|
||||
const device T* scales [[buffer(1)]],
|
||||
const device T* biases [[buffer(2)]],
|
||||
const device T* x [[buffer(3)]],
|
||||
device T* y [[buffer(4)]],
|
||||
const constant int& in_vec_size [[buffer(5)]],
|
||||
const constant int& out_vec_size [[buffer(6)]],
|
||||
const constant int& x_batch_ndims [[buffer(7)]],
|
||||
const constant int* x_shape [[buffer(8)]],
|
||||
const constant int64_t* x_strides [[buffer(9)]],
|
||||
const constant int& w_batch_ndims [[buffer(10)]],
|
||||
const constant int* w_shape [[buffer(11)]],
|
||||
const constant int64_t* w_strides [[buffer(12)]],
|
||||
const constant int64_t* s_strides [[buffer(13)]],
|
||||
const constant int64_t* b_strides [[buffer(14)]],
|
||||
const constant int& batch_ndims [[buffer(15)]],
|
||||
const constant int* batch_shape [[buffer(16)]],
|
||||
const device uint32_t* lhs_indices [[buffer(17)]],
|
||||
const device uint32_t* rhs_indices [[buffer(18)]],
|
||||
const device uint32_t* lhs_indices [[buffer(4)]],
|
||||
const device uint32_t* rhs_indices [[buffer(5)]],
|
||||
device T* y [[buffer(6)]],
|
||||
const constant int& in_vec_size [[buffer(7)]],
|
||||
const constant int& out_vec_size [[buffer(8)]],
|
||||
const constant int& x_batch_ndims [[buffer(9)]],
|
||||
const constant int* x_shape [[buffer(10)]],
|
||||
const constant int64_t* x_strides [[buffer(11)]],
|
||||
const constant int& w_batch_ndims [[buffer(12)]],
|
||||
const constant int* w_shape [[buffer(13)]],
|
||||
const constant int64_t* w_strides [[buffer(14)]],
|
||||
const constant int64_t* s_strides [[buffer(15)]],
|
||||
const constant int64_t* b_strides [[buffer(16)]],
|
||||
const constant int& batch_ndims [[buffer(17)]],
|
||||
const constant int* batch_shape [[buffer(18)]],
|
||||
const constant int64_t* lhs_strides [[buffer(19)]],
|
||||
const constant int64_t* rhs_strides [[buffer(20)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
@@ -1879,27 +1883,27 @@ template <
|
||||
const int BM = 32,
|
||||
const int BK = 32,
|
||||
const int BN = 32>
|
||||
[[kernel]] void bs_qmm_t(
|
||||
[[kernel]] void gather_qmm_t(
|
||||
const device uint32_t* w [[buffer(0)]],
|
||||
const device T* scales [[buffer(1)]],
|
||||
const device T* biases [[buffer(2)]],
|
||||
const device T* x [[buffer(3)]],
|
||||
device T* y [[buffer(4)]],
|
||||
const constant int& K [[buffer(5)]],
|
||||
const constant int& N [[buffer(6)]],
|
||||
const constant int& M [[buffer(7)]],
|
||||
const constant int& x_batch_ndims [[buffer(8)]],
|
||||
const constant int* x_shape [[buffer(9)]],
|
||||
const constant int64_t* x_strides [[buffer(10)]],
|
||||
const constant int& w_batch_ndims [[buffer(11)]],
|
||||
const constant int* w_shape [[buffer(12)]],
|
||||
const constant int64_t* w_strides [[buffer(13)]],
|
||||
const constant int64_t* s_strides [[buffer(14)]],
|
||||
const constant int64_t* b_strides [[buffer(15)]],
|
||||
const constant int& batch_ndims [[buffer(16)]],
|
||||
const constant int* batch_shape [[buffer(17)]],
|
||||
const device uint32_t* lhs_indices [[buffer(18)]],
|
||||
const device uint32_t* rhs_indices [[buffer(19)]],
|
||||
const device uint32_t* lhs_indices [[buffer(4)]],
|
||||
const device uint32_t* rhs_indices [[buffer(5)]],
|
||||
device T* y [[buffer(6)]],
|
||||
const constant int& K [[buffer(7)]],
|
||||
const constant int& N [[buffer(8)]],
|
||||
const constant int& M [[buffer(9)]],
|
||||
const constant int& x_batch_ndims [[buffer(10)]],
|
||||
const constant int* x_shape [[buffer(11)]],
|
||||
const constant int64_t* x_strides [[buffer(12)]],
|
||||
const constant int& w_batch_ndims [[buffer(13)]],
|
||||
const constant int* w_shape [[buffer(14)]],
|
||||
const constant int64_t* w_strides [[buffer(15)]],
|
||||
const constant int64_t* s_strides [[buffer(16)]],
|
||||
const constant int64_t* b_strides [[buffer(17)]],
|
||||
const constant int& batch_ndims [[buffer(18)]],
|
||||
const constant int* batch_shape [[buffer(19)]],
|
||||
const constant int64_t* lhs_strides [[buffer(20)]],
|
||||
const constant int64_t* rhs_strides [[buffer(21)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
@@ -1946,27 +1950,27 @@ template <
|
||||
const int BM = 32,
|
||||
const int BK = 32,
|
||||
const int BN = 32>
|
||||
[[kernel]] void bs_qmm_n(
|
||||
[[kernel]] void gather_qmm_n(
|
||||
const device uint32_t* w [[buffer(0)]],
|
||||
const device T* scales [[buffer(1)]],
|
||||
const device T* biases [[buffer(2)]],
|
||||
const device T* x [[buffer(3)]],
|
||||
device T* y [[buffer(4)]],
|
||||
const constant int& K [[buffer(5)]],
|
||||
const constant int& N [[buffer(6)]],
|
||||
const constant int& M [[buffer(7)]],
|
||||
const constant int& x_batch_ndims [[buffer(8)]],
|
||||
const constant int* x_shape [[buffer(9)]],
|
||||
const constant int64_t* x_strides [[buffer(10)]],
|
||||
const constant int& w_batch_ndims [[buffer(11)]],
|
||||
const constant int* w_shape [[buffer(12)]],
|
||||
const constant int64_t* w_strides [[buffer(13)]],
|
||||
const constant int64_t* s_strides [[buffer(14)]],
|
||||
const constant int64_t* b_strides [[buffer(15)]],
|
||||
const constant int& batch_ndims [[buffer(16)]],
|
||||
const constant int* batch_shape [[buffer(17)]],
|
||||
const device uint32_t* lhs_indices [[buffer(18)]],
|
||||
const device uint32_t* rhs_indices [[buffer(19)]],
|
||||
const device uint32_t* lhs_indices [[buffer(4)]],
|
||||
const device uint32_t* rhs_indices [[buffer(5)]],
|
||||
device T* y [[buffer(6)]],
|
||||
const constant int& K [[buffer(7)]],
|
||||
const constant int& N [[buffer(8)]],
|
||||
const constant int& M [[buffer(9)]],
|
||||
const constant int& x_batch_ndims [[buffer(10)]],
|
||||
const constant int* x_shape [[buffer(11)]],
|
||||
const constant int64_t* x_strides [[buffer(12)]],
|
||||
const constant int& w_batch_ndims [[buffer(13)]],
|
||||
const constant int* w_shape [[buffer(14)]],
|
||||
const constant int64_t* w_strides [[buffer(15)]],
|
||||
const constant int64_t* s_strides [[buffer(16)]],
|
||||
const constant int64_t* b_strides [[buffer(17)]],
|
||||
const constant int& batch_ndims [[buffer(18)]],
|
||||
const constant int* batch_shape [[buffer(19)]],
|
||||
const constant int64_t* lhs_strides [[buffer(20)]],
|
||||
const constant int64_t* rhs_strides [[buffer(21)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
@@ -2007,6 +2011,289 @@ template <
|
||||
w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
|
||||
}
|
||||
|
||||
template <typename T, typename mma_t, typename loader_a_t, typename loader_b_t>
|
||||
METAL_FUNC void gemm_loop_aligned(
|
||||
threadgroup T* As,
|
||||
threadgroup T* Bs,
|
||||
thread mma_t& mma_op,
|
||||
thread loader_a_t& loader_a,
|
||||
thread loader_b_t& loader_b,
|
||||
const int k_iterations) {
|
||||
for (int k = 0; k < k_iterations; k++) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Load elements into threadgroup memory
|
||||
loader_a.load_unsafe();
|
||||
loader_b.load_unsafe();
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Multiply and accumulate threadgroup elements
|
||||
mma_op.mma(As, Bs);
|
||||
|
||||
// Prepare for next iteration
|
||||
loader_a.next();
|
||||
loader_b.next();
|
||||
}
|
||||
}
|
||||
|
||||
template <
|
||||
bool rows_aligned,
|
||||
bool cols_aligned,
|
||||
bool transpose,
|
||||
typename T,
|
||||
typename mma_t,
|
||||
typename loader_a_t,
|
||||
typename loader_b_t>
|
||||
METAL_FUNC void gemm_loop_unaligned(
|
||||
threadgroup T* As,
|
||||
threadgroup T* Bs,
|
||||
thread mma_t& mma_op,
|
||||
thread loader_a_t& loader_a,
|
||||
thread loader_b_t& loader_b,
|
||||
const int k_iterations,
|
||||
const short tgp_bm,
|
||||
const short tgp_bn,
|
||||
const short tgp_bk) {
|
||||
for (int k = 0; k < k_iterations; k++) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Load elements into threadgroup memory
|
||||
if (rows_aligned) {
|
||||
loader_a.load_unsafe();
|
||||
} else {
|
||||
loader_a.load_safe(short2(tgp_bk, tgp_bm));
|
||||
}
|
||||
if (cols_aligned) {
|
||||
loader_b.load_unsafe();
|
||||
} else {
|
||||
loader_b.load_safe(
|
||||
transpose ? short2(tgp_bk, tgp_bn) : short2(tgp_bn, tgp_bk));
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Multiply and accumulate threadgroup elements
|
||||
mma_op.mma(As, Bs);
|
||||
|
||||
// Prepare for next iteration
|
||||
loader_a.next();
|
||||
loader_b.next();
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename mma_t, typename loader_a_t, typename loader_b_t>
|
||||
METAL_FUNC void gemm_loop_finalize(
|
||||
threadgroup T* As,
|
||||
threadgroup T* Bs,
|
||||
thread mma_t& mma_op,
|
||||
thread loader_a_t& loader_a,
|
||||
thread loader_b_t& loader_b,
|
||||
const short2 tile_a,
|
||||
const short2 tile_b) {
|
||||
loader_a.load_safe(tile_a);
|
||||
loader_b.load_safe(tile_b);
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
mma_op.mma(As, Bs);
|
||||
}
|
||||
|
||||
template <
|
||||
typename T,
|
||||
int group_size,
|
||||
int bits,
|
||||
int BM,
|
||||
int BN,
|
||||
int BK,
|
||||
int WM,
|
||||
int WN,
|
||||
bool transpose>
|
||||
[[kernel]] void gather_qmm_rhs(
|
||||
const device T* x [[buffer(0)]],
|
||||
const device uint32_t* w [[buffer(1)]],
|
||||
const device T* scales [[buffer(2)]],
|
||||
const device T* biases [[buffer(3)]],
|
||||
const device uint32_t* indices [[buffer(4)]],
|
||||
device T* y [[buffer(5)]],
|
||||
const constant int& M [[buffer(6)]],
|
||||
const constant int& N [[buffer(7)]],
|
||||
const constant int& K [[buffer(8)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]]) {
|
||||
constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
|
||||
constexpr int BK_padded = (BK + 16 / sizeof(T));
|
||||
constexpr int BN_padded = (BN + 16 / sizeof(T));
|
||||
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
|
||||
constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3;
|
||||
|
||||
using mma_t = mlx::steel::BlockMMA<
|
||||
T,
|
||||
T,
|
||||
BM,
|
||||
BN,
|
||||
BK,
|
||||
WM,
|
||||
WN,
|
||||
false,
|
||||
transpose,
|
||||
BK_padded,
|
||||
transpose ? BK_padded : BN_padded>;
|
||||
using loader_x_t =
|
||||
mlx::steel::BlockLoader<T, BM, BK, BK_padded, 1, WM * WN * SIMD_SIZE>;
|
||||
using loader_w_t = QuantizedBlockLoader<
|
||||
T,
|
||||
transpose ? BN : BK,
|
||||
transpose ? BK : BN,
|
||||
transpose ? BK_padded : BN_padded,
|
||||
transpose,
|
||||
WM * WN * SIMD_SIZE,
|
||||
group_size,
|
||||
bits>;
|
||||
|
||||
threadgroup T Xs[BM * BK_padded];
|
||||
threadgroup T Ws[transpose ? BN * BK_padded : BK * BN_padded];
|
||||
|
||||
// Compute the block
|
||||
const int K_w = K * bytes_per_pack / pack_factor;
|
||||
const int K_g = K / group_size;
|
||||
const int N_w = N * bytes_per_pack / pack_factor;
|
||||
const int N_g = N / group_size;
|
||||
const int K_it = K / BK;
|
||||
const size_t stride_w = transpose ? N * K_w : K * N_w;
|
||||
const size_t stride_s = transpose ? N * K_g : K * N_g;
|
||||
const int y_row = tid.y * BM;
|
||||
const int y_col = tid.x * BN;
|
||||
const size_t y_row_long = size_t(y_row);
|
||||
const size_t y_col_long = size_t(y_col);
|
||||
|
||||
// Prepare threadgroup bounds
|
||||
const short tgp_bm = align_M ? BM : short(min(BM, M - y_row));
|
||||
const short tgp_bn = align_N ? BN : short(min(BN, N - y_col));
|
||||
|
||||
// Calculate the final tiles in the case that K is not aligned
|
||||
const int k_remain = K - K_it * BK;
|
||||
const short2 tile_x = short2(k_remain, tgp_bm);
|
||||
const short2 tile_w =
|
||||
transpose ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);
|
||||
|
||||
// Move x and output to the correct block
|
||||
auto wl = (const device uint8_t*)w;
|
||||
x += y_row_long * K;
|
||||
y += y_row_long * N + y_col_long;
|
||||
wl += transpose ? y_col_long * K_w : y_col * bytes_per_pack / pack_factor;
|
||||
scales += transpose ? y_col_long * K_g : y_col / group_size;
|
||||
biases += transpose ? y_col_long * K_g : y_col / group_size;
|
||||
|
||||
// Do as many matmuls as necessary
|
||||
uint32_t index;
|
||||
short offset;
|
||||
uint32_t index_next = indices[y_row];
|
||||
short offset_next = 0;
|
||||
int n = 0;
|
||||
while (n < tgp_bm) {
|
||||
n++;
|
||||
offset = offset_next;
|
||||
index = index_next;
|
||||
offset_next = tgp_bm;
|
||||
for (; n < tgp_bm; n++) {
|
||||
if (indices[y_row + n] != index) {
|
||||
offset_next = n;
|
||||
index_next = indices[y_row + n];
|
||||
break;
|
||||
}
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
// Prepare threadgroup mma operation
|
||||
thread mma_t mma_op(simd_group_id, simd_lane_id);
|
||||
|
||||
// Prepare threadgroup loading operations
|
||||
thread loader_x_t loader_x(x, K, Xs, simd_group_id, simd_lane_id);
|
||||
thread loader_w_t loader_w(
|
||||
wl + index * stride_w,
|
||||
scales + index * stride_s,
|
||||
biases + index * stride_s,
|
||||
transpose ? K : N,
|
||||
Ws,
|
||||
simd_group_id,
|
||||
simd_lane_id);
|
||||
|
||||
// Matrices are all aligned check nothing
|
||||
if (align_M && align_N) {
|
||||
gemm_loop_aligned(Xs, Ws, mma_op, loader_x, loader_w, K_it);
|
||||
if (!align_K) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
gemm_loop_finalize(Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w);
|
||||
}
|
||||
|
||||
// Store results to device memory
|
||||
if (offset_next - offset == BM) {
|
||||
mma_op.store_result(y, N);
|
||||
} else {
|
||||
mma_op.store_result_slice(
|
||||
y, N, short2(0, offset), short2(BN, offset_next));
|
||||
}
|
||||
} else {
|
||||
// Tile aligned so check outside of the hot loop
|
||||
if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) {
|
||||
gemm_loop_aligned(Xs, Ws, mma_op, loader_x, loader_w, K_it);
|
||||
if (!align_K) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
gemm_loop_finalize(
|
||||
Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w);
|
||||
}
|
||||
|
||||
// Store results to device memory
|
||||
if (offset_next - offset == BM) {
|
||||
mma_op.store_result(y, N);
|
||||
} else {
|
||||
mma_op.store_result_slice(
|
||||
y, N, short2(0, offset), short2(BN, offset_next));
|
||||
}
|
||||
}
|
||||
|
||||
// Tile partially aligned check rows
|
||||
else if (align_N || tgp_bn == BN) {
|
||||
gemm_loop_unaligned<false, true, transpose>(
|
||||
Xs, Ws, mma_op, loader_x, loader_w, K_it, tgp_bm, tgp_bn, BK);
|
||||
if (!align_K) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
gemm_loop_finalize(
|
||||
Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w);
|
||||
}
|
||||
mma_op.store_result_slice(
|
||||
y, N, short2(0, offset), short2(BN, offset_next));
|
||||
}
|
||||
|
||||
// Tile partially aligned check cols
|
||||
else if (align_M || tgp_bm == BM) {
|
||||
gemm_loop_unaligned<true, false, transpose>(
|
||||
Xs, Ws, mma_op, loader_x, loader_w, K_it, tgp_bm, tgp_bn, BK);
|
||||
if (!align_K) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
gemm_loop_finalize(
|
||||
Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w);
|
||||
}
|
||||
mma_op.store_result_slice(
|
||||
y, N, short2(0, offset), short2(tgp_bn, offset_next));
|
||||
}
|
||||
|
||||
// Nothing aligned so check both rows and cols
|
||||
else {
|
||||
gemm_loop_unaligned<false, false, transpose>(
|
||||
Xs, Ws, mma_op, loader_x, loader_w, K_it, tgp_bm, tgp_bn, BK);
|
||||
if (!align_K) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
gemm_loop_finalize(
|
||||
Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w);
|
||||
}
|
||||
mma_op.store_result_slice(
|
||||
y, N, short2(0, offset), short2(tgp_bn, offset_next));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, const int group_size, const int bits>
|
||||
[[kernel]] void affine_quantize(
|
||||
const device T* w [[buffer(0)]],
|
||||
|
@@ -60,6 +60,20 @@
|
||||
bits, \
|
||||
split_k)
|
||||
|
||||
#define instantiate_gather_qmm_rhs(func, name, type, group_size, bits, bm, bn, bk, wm, wn, transpose) \
|
||||
instantiate_kernel( \
|
||||
#name "_" #type "_gs_" #group_size "_b_" #bits "_bm_" #bm "_bn_" #bn "_bk_" #bk "_wm_" #wm "_wn_" #wn, \
|
||||
func, \
|
||||
type, \
|
||||
group_size, \
|
||||
bits, \
|
||||
bm, \
|
||||
bn, \
|
||||
bk, \
|
||||
wm, \
|
||||
wn, \
|
||||
transpose)
|
||||
|
||||
#define instantiate_quantized_batched_wrap(name, type, group_size, bits) \
|
||||
instantiate_quantized_batched(name, type, group_size, bits, 1) \
|
||||
instantiate_quantized_batched(name, type, group_size, bits, 0)
|
||||
@@ -73,14 +87,14 @@
|
||||
#define instantiate_quantized_all_single(type, group_size, bits) \
|
||||
instantiate_quantized(affine_quantize, type, group_size, bits) \
|
||||
instantiate_quantized(affine_dequantize, type, group_size, bits) \
|
||||
instantiate_quantized(bs_qmv_fast, type, group_size, bits) \
|
||||
instantiate_quantized(bs_qmv, type, group_size, bits) \
|
||||
instantiate_quantized(bs_qvm, type, group_size, bits) \
|
||||
instantiate_quantized(bs_qmm_n, type, group_size, bits)
|
||||
instantiate_quantized(gather_qmv_fast, type, group_size, bits) \
|
||||
instantiate_quantized(gather_qmv, type, group_size, bits) \
|
||||
instantiate_quantized(gather_qvm, type, group_size, bits) \
|
||||
instantiate_quantized(gather_qmm_n, type, group_size, bits)
|
||||
|
||||
#define instantiate_quantized_all_aligned(type, group_size, bits) \
|
||||
instantiate_quantized_aligned(bs_qmm_t, type, group_size, bits, true) \
|
||||
instantiate_quantized_aligned(bs_qmm_t, type, group_size, bits, false) \
|
||||
instantiate_quantized_aligned(gather_qmm_t, type, group_size, bits, true) \
|
||||
instantiate_quantized_aligned(gather_qmm_t, type, group_size, bits, false) \
|
||||
instantiate_quantized_aligned_batched(qmm_t, type, group_size, bits, true, 1) \
|
||||
instantiate_quantized_aligned_batched(qmm_t, type, group_size, bits, true, 0) \
|
||||
instantiate_quantized_aligned_batched(qmm_t, type, group_size, bits, false, 1) \
|
||||
@@ -96,12 +110,17 @@
|
||||
instantiate_quantized_split_k(qvm_split_k, type, group_size, bits, 8) \
|
||||
instantiate_quantized_split_k(qvm_split_k, type, group_size, bits, 32)
|
||||
|
||||
#define instantiate_quantized_all_rhs(type, group_size, bits) \
|
||||
instantiate_gather_qmm_rhs(gather_qmm_rhs, gather_qmm_rhs_nt, type, group_size, bits, 16, 32, 32, 1, 2, true) \
|
||||
instantiate_gather_qmm_rhs(gather_qmm_rhs, gather_qmm_rhs_nn, type, group_size, bits, 16, 32, 32, 1, 2, false)
|
||||
|
||||
#define instantiate_quantized_funcs(type, group_size, bits) \
|
||||
instantiate_quantized_all_single(type, group_size, bits) \
|
||||
instantiate_quantized_all_batched(type, group_size, bits) \
|
||||
instantiate_quantized_all_aligned(type, group_size, bits) \
|
||||
instantiate_quantized_all_quad(type, group_size, bits) \
|
||||
instantiate_quantized_all_splitk(type, group_size, bits)
|
||||
instantiate_quantized_all_splitk(type, group_size, bits) \
|
||||
instantiate_quantized_all_rhs(type, group_size, bits)
|
||||
|
||||
#define instantiate_quantized_types(group_size, bits) \
|
||||
instantiate_quantized_funcs(float, group_size, bits) \
|
||||
|
@@ -104,4 +104,5 @@ instantiate_scan_helper(min_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumMi
|
||||
instantiate_scan_helper(min_complex64_complex64, complex64_t, complex64_t, CumMin, 2)
|
||||
instantiate_scan_helper(logaddexp_float16_float16, half, half, CumLogaddexp, 4)
|
||||
instantiate_scan_helper(logaddexp_float32_float32, float, float, CumLogaddexp, 4)
|
||||
instantiate_scan_helper(logaddexp_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumLogaddexp, 4) // clang-format on
|
||||
instantiate_scan_helper(logaddexp_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumLogaddexp, 4)
|
||||
instantiate_scan_helper(logaddexp_complex64_complex64, complex64_t, complex64_t, CumLogaddexp, 2) // clang-format on
|
||||
|
@@ -56,9 +56,9 @@ template <typename T, int D, int V = D>
|
||||
const int head_idx = tid.x;
|
||||
const int q_seq_idx = tid.y;
|
||||
const int kv_head_idx = head_idx / gqa_factor;
|
||||
const int o_offset = tpg.x * q_seq_idx + head_idx;
|
||||
const int o_offset = head_idx * tpg.y + q_seq_idx;
|
||||
const int q_offset =
|
||||
query_transposed ? o_offset : head_idx * tpg.y + q_seq_idx;
|
||||
query_transposed ? tpg.x * q_seq_idx + head_idx : o_offset;
|
||||
queries += q_offset * D + simd_lid * qk_per_thread;
|
||||
keys += kv_head_idx * k_head_stride + simd_gid * k_seq_stride +
|
||||
simd_lid * qk_per_thread;
|
||||
@@ -213,9 +213,9 @@ template <typename T, int D, int V = D>
|
||||
const int block_idx = tid.z;
|
||||
const int head_idx = tid.x;
|
||||
const int q_seq_idx = tid.y;
|
||||
const int o_offset = tpg.x * q_seq_idx + head_idx;
|
||||
const int o_offset = head_idx * tpg.y + q_seq_idx;
|
||||
const int q_offset =
|
||||
query_transposed ? o_offset : head_idx * tpg.y + q_seq_idx;
|
||||
query_transposed ? tpg.x * q_seq_idx + head_idx : o_offset;
|
||||
const int kv_head_idx = head_idx / gqa_factor;
|
||||
|
||||
queries += q_offset * D + simd_lid * qk_per_thread;
|
||||
@@ -358,8 +358,8 @@ template <typename T, int D>
|
||||
// Adjust positions
|
||||
const int head_idx = tid.x;
|
||||
const int q_seq_idx = tid.y;
|
||||
const int n_heads = tpg.x;
|
||||
const int q_offset = n_heads * q_seq_idx + head_idx;
|
||||
const int q_offset = head_idx * tpg.y + q_seq_idx;
|
||||
;
|
||||
partials += q_offset * blocks * D + simd_gid * D + simd_lid * elem_per_thread;
|
||||
sums += q_offset * blocks;
|
||||
maxs += q_offset * blocks;
|
||||
|
@@ -95,7 +95,7 @@ template <
|
||||
|
||||
Q += tidl.z * params->Q_strides[0] + // Batch
|
||||
tidl.y * params->Q_strides[1] + // Head
|
||||
tidl.x * BQ * params->Q_strides[2]; // Seqeunce
|
||||
tidl.x * BQ * params->Q_strides[2]; // Sequence
|
||||
|
||||
ulong kv_head_idx = int(tid.y) / params->gqa_factor;
|
||||
K += tidl.z * params->K_strides[0] + // Batch
|
||||
@@ -106,7 +106,7 @@ template <
|
||||
|
||||
O += tidl.z * params->O_strides[0] + // Batch
|
||||
tidl.y * params->O_strides[1] + // Head
|
||||
tidl.x * BQ * params->O_strides[2]; // Seqeunce
|
||||
tidl.x * BQ * params->O_strides[2]; // Sequence
|
||||
|
||||
if (has_mask) {
|
||||
mask += tidl.z * mask_params->M_strides[0] + // Batch
|
||||
|
@@ -113,7 +113,7 @@ struct BlockLoader {
|
||||
tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)];
|
||||
}
|
||||
|
||||
// Zero out uneeded values
|
||||
// Zero out unneeded values
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < vec_size; j++) {
|
||||
tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0);
|
||||
@@ -240,7 +240,7 @@ struct BlockLoaderT {
|
||||
tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)];
|
||||
}
|
||||
|
||||
// Zero out uneeded values
|
||||
// Zero out unneeded values
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < vec_size; j++) {
|
||||
tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0);
|
||||
|
@@ -141,7 +141,7 @@ implicit_gemm_conv_2d_general(
|
||||
|
||||
// Store results to device memory
|
||||
{
|
||||
// Adjust for simdgroup and thread locatio
|
||||
// Adjust for simdgroup and thread location
|
||||
int offset_m = c_row + mma_op.sm;
|
||||
int offset_n = c_col + mma_op.sn;
|
||||
C += offset_n;
|
||||
|
@@ -15,10 +15,6 @@ constant bool align_M [[function_constant(200)]];
|
||||
constant bool align_N [[function_constant(201)]];
|
||||
constant bool align_K [[function_constant(202)]];
|
||||
|
||||
constant bool do_gather [[function_constant(300)]];
|
||||
|
||||
constant bool gather_bias = do_gather && use_out_source;
|
||||
|
||||
// clang-format off
|
||||
template <
|
||||
typename T,
|
||||
@@ -39,12 +35,6 @@ template <
|
||||
const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]],
|
||||
const constant int* batch_shape [[buffer(6)]],
|
||||
const constant int64_t* batch_strides [[buffer(7)]],
|
||||
const constant uint32_t* lhs_indices [[buffer(10), function_constant(do_gather)]],
|
||||
const constant uint32_t* rhs_indices [[buffer(11), function_constant(do_gather)]],
|
||||
const constant uint32_t* C_indices [[buffer(12), function_constant(gather_bias)]],
|
||||
const constant int* operand_shape [[buffer(13), function_constant(do_gather)]],
|
||||
const constant int64_t* operand_strides [[buffer(14), function_constant(do_gather)]],
|
||||
const constant packed_int3& operand_batch_ndim [[buffer(15), function_constant(do_gather)]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
@@ -81,84 +71,26 @@ template <
|
||||
}
|
||||
|
||||
// Adjust for batch
|
||||
if (has_batch) {
|
||||
const constant auto* A_bstrides = batch_strides;
|
||||
const constant auto* B_bstrides = batch_strides + params->batch_ndim;
|
||||
|
||||
// Handle gather
|
||||
if (do_gather) {
|
||||
// Read indices
|
||||
uint32_t indx_A, indx_B, indx_C;
|
||||
ulong2 batch_offsets = elem_to_loc_broadcast(
|
||||
tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim);
|
||||
|
||||
if (has_batch) {
|
||||
const constant auto* indx_A_bstrides = batch_strides;
|
||||
const constant auto* indx_B_bstrides = batch_strides + params->batch_ndim;
|
||||
|
||||
ulong2 indx_offsets = elem_to_loc_broadcast(
|
||||
tid.z,
|
||||
batch_shape,
|
||||
indx_A_bstrides,
|
||||
indx_B_bstrides,
|
||||
params->batch_ndim);
|
||||
indx_A = lhs_indices[indx_offsets.x];
|
||||
indx_B = rhs_indices[indx_offsets.y];
|
||||
|
||||
if (use_out_source) {
|
||||
const constant auto* indx_C_bstrides =
|
||||
indx_B_bstrides + params->batch_ndim;
|
||||
auto indx_offset_C = elem_to_loc(
|
||||
tid.z, batch_shape, indx_C_bstrides, params->batch_ndim);
|
||||
indx_C = C_indices[indx_offset_C];
|
||||
}
|
||||
} else {
|
||||
indx_A = lhs_indices[params->batch_stride_a * tid.z];
|
||||
indx_B = rhs_indices[params->batch_stride_b * tid.z];
|
||||
|
||||
if (use_out_source) {
|
||||
indx_C = C_indices[addmm_params->batch_stride_c * tid.z];
|
||||
}
|
||||
}
|
||||
|
||||
// Translate indices to offsets
|
||||
int batch_ndim_A = operand_batch_ndim.x;
|
||||
const constant int* batch_shape_A = operand_shape;
|
||||
const constant auto* batch_strides_A = operand_strides;
|
||||
A += elem_to_loc(indx_A, batch_shape_A, batch_strides_A, batch_ndim_A);
|
||||
|
||||
int batch_ndim_B = operand_batch_ndim.y;
|
||||
const constant int* batch_shape_B = batch_shape_A + batch_ndim_A;
|
||||
const constant auto* batch_strides_B = batch_strides_A + batch_ndim_A;
|
||||
B += elem_to_loc(indx_B, batch_shape_B, batch_strides_B, batch_ndim_B);
|
||||
A += batch_offsets.x;
|
||||
B += batch_offsets.y;
|
||||
|
||||
if (use_out_source) {
|
||||
int batch_ndim_C = operand_batch_ndim.z;
|
||||
const constant int* batch_shape_C = batch_shape_B + batch_ndim_B;
|
||||
const constant auto* batch_strides_C = batch_strides_B + batch_ndim_B;
|
||||
C += elem_to_loc(indx_C, batch_shape_C, batch_strides_C, batch_ndim_C);
|
||||
const constant auto* C_bstrides = B_bstrides + params->batch_ndim;
|
||||
C += elem_to_loc(tid.z, batch_shape, C_bstrides, params->batch_ndim);
|
||||
}
|
||||
} else {
|
||||
A += params->batch_stride_a * tid.z;
|
||||
B += params->batch_stride_b * tid.z;
|
||||
|
||||
}
|
||||
|
||||
// Handle regular batch
|
||||
else {
|
||||
if (has_batch) {
|
||||
const constant auto* A_bstrides = batch_strides;
|
||||
const constant auto* B_bstrides = batch_strides + params->batch_ndim;
|
||||
|
||||
ulong2 batch_offsets = elem_to_loc_broadcast(
|
||||
tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim);
|
||||
|
||||
A += batch_offsets.x;
|
||||
B += batch_offsets.y;
|
||||
|
||||
if (use_out_source) {
|
||||
const constant auto* C_bstrides = B_bstrides + params->batch_ndim;
|
||||
C += elem_to_loc(tid.z, batch_shape, C_bstrides, params->batch_ndim);
|
||||
}
|
||||
} else {
|
||||
A += params->batch_stride_a * tid.z;
|
||||
B += params->batch_stride_b * tid.z;
|
||||
|
||||
if (use_out_source) {
|
||||
C += addmm_params->batch_stride_c * tid.z;
|
||||
}
|
||||
if (use_out_source) {
|
||||
C += addmm_params->batch_stride_c * tid.z;
|
||||
}
|
||||
}
|
||||
|
||||
|
459
mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h
Normal file
459
mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h
Normal file
@@ -0,0 +1,459 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
using namespace mlx::steel;
|
||||
|
||||
constant bool has_batch [[function_constant(10)]];
|
||||
constant bool align_M [[function_constant(200)]];
|
||||
constant bool align_N [[function_constant(201)]];
|
||||
constant bool align_K [[function_constant(202)]];
|
||||
|
||||
template <
|
||||
typename T,
|
||||
int BM,
|
||||
int BN,
|
||||
int BK,
|
||||
int WM,
|
||||
int WN,
|
||||
bool transpose_a,
|
||||
bool transpose_b,
|
||||
typename AccumType = float>
|
||||
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void gather_mm_rhs(
|
||||
const device T* A [[buffer(0)]],
|
||||
const device T* B [[buffer(1)]],
|
||||
const device uint32_t* rhs_indices [[buffer(2)]],
|
||||
device T* C [[buffer(3)]],
|
||||
const constant GEMMParams* params [[buffer(4)]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||
uint3 tid [[threadgroup_position_in_grid]]) {
|
||||
using gemm_kernel = GEMMKernel<
|
||||
T,
|
||||
T,
|
||||
BM,
|
||||
BN,
|
||||
BK,
|
||||
WM,
|
||||
WN,
|
||||
transpose_a,
|
||||
transpose_b,
|
||||
true,
|
||||
true,
|
||||
AccumType>;
|
||||
|
||||
using loader_a_t = typename gemm_kernel::loader_a_t;
|
||||
using loader_b_t = typename gemm_kernel::loader_b_t;
|
||||
using mma_t = typename gemm_kernel::mma_t;
|
||||
|
||||
if (params->tiles_n <= static_cast<int>(tid.x) ||
|
||||
params->tiles_m <= static_cast<int>(tid.y)) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Prepare threadgroup memory
|
||||
threadgroup T As[gemm_kernel::tgp_mem_size_a];
|
||||
threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
|
||||
|
||||
// Find the block in A, B, C
|
||||
const int c_row = tid.y * BM;
|
||||
const int c_col = tid.x * BN;
|
||||
const size_t c_row_long = size_t(c_row);
|
||||
const size_t c_col_long = size_t(c_col);
|
||||
|
||||
// Prepare threadgroup bounds
|
||||
const short tgp_bm = align_M ? BM : short(min(BM, params->M - c_row));
|
||||
const short tgp_bn = align_N ? BN : short(min(BN, params->N - c_col));
|
||||
|
||||
A += transpose_a ? c_row_long : c_row_long * params->lda;
|
||||
B += transpose_b ? c_col_long * params->ldb : c_col_long;
|
||||
C += c_row_long * params->ldd + c_col_long;
|
||||
|
||||
// Do as many matmuls as necessary
|
||||
uint32_t index;
|
||||
short offset;
|
||||
uint32_t index_next = rhs_indices[c_row];
|
||||
short offset_next = 0;
|
||||
int n = 0;
|
||||
while (n < tgp_bm) {
|
||||
n++;
|
||||
offset = offset_next;
|
||||
index = index_next;
|
||||
offset_next = tgp_bm;
|
||||
for (; n < tgp_bm; n++) {
|
||||
if (rhs_indices[c_row + n] != index) {
|
||||
offset_next = n;
|
||||
index_next = rhs_indices[c_row + n];
|
||||
break;
|
||||
}
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
// Prepare threadgroup mma operation
|
||||
thread mma_t mma_op(simd_group_id, simd_lane_id);
|
||||
|
||||
// Prepare threadgroup loading operations
|
||||
thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
|
||||
thread loader_b_t loader_b(
|
||||
B + index * params->batch_stride_b,
|
||||
params->ldb,
|
||||
Bs,
|
||||
simd_group_id,
|
||||
simd_lane_id);
|
||||
|
||||
// Prepare iterations
|
||||
const int gemm_k_iterations = params->gemm_k_iterations_aligned;
|
||||
|
||||
// Do unaligned K iterations first
|
||||
if (!align_K) {
|
||||
const int k_last = params->gemm_k_iterations_aligned * BK;
|
||||
const int k_remain = params->K - k_last;
|
||||
const size_t k_jump_a =
|
||||
transpose_a ? params->lda * size_t(k_last) : size_t(k_last);
|
||||
const size_t k_jump_b =
|
||||
transpose_b ? size_t(k_last) : params->ldb * size_t(k_last);
|
||||
|
||||
// Move loader source ahead to end
|
||||
loader_a.src += k_jump_a;
|
||||
loader_b.src += k_jump_b;
|
||||
|
||||
// Load tile
|
||||
const short2 tile_dims_A =
|
||||
transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
|
||||
const short2 tile_dims_B =
|
||||
transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);
|
||||
|
||||
loader_a.load_safe(tile_dims_A);
|
||||
loader_b.load_safe(tile_dims_B);
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Do matmul
|
||||
mma_op.mma(As, Bs);
|
||||
|
||||
// Reset source back to start
|
||||
loader_a.src -= k_jump_a;
|
||||
loader_b.src -= k_jump_b;
|
||||
}
|
||||
|
||||
// Matrix level aligned never check
|
||||
if (align_M && align_N) {
|
||||
for (int k = 0; k < gemm_k_iterations; k++) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Load elements into threadgroup
|
||||
loader_a.load_unsafe();
|
||||
loader_b.load_unsafe();
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Multiply and accumulate threadgroup elements
|
||||
mma_op.mma(As, Bs);
|
||||
|
||||
// Prepare for next iteration
|
||||
loader_a.next();
|
||||
loader_b.next();
|
||||
}
|
||||
|
||||
// Store results to device memory
|
||||
if (offset_next - offset == BM) {
|
||||
mma_op.store_result(C, params->ldd);
|
||||
} else {
|
||||
mma_op.store_result_slice(
|
||||
C, params->ldd, short2(0, offset), short2(BN, offset_next));
|
||||
}
|
||||
} else {
|
||||
const short lbk = 0;
|
||||
|
||||
// Tile aligned don't check
|
||||
if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) {
|
||||
gemm_kernel::gemm_loop(
|
||||
As,
|
||||
Bs,
|
||||
gemm_k_iterations,
|
||||
loader_a,
|
||||
loader_b,
|
||||
mma_op,
|
||||
tgp_bm,
|
||||
tgp_bn,
|
||||
lbk,
|
||||
LoopAlignment<true, true, true>{});
|
||||
if (offset_next - offset == BM) {
|
||||
mma_op.store_result(C, params->ldd);
|
||||
} else {
|
||||
mma_op.store_result_slice(
|
||||
C, params->ldd, short2(0, offset), short2(BN, offset_next));
|
||||
}
|
||||
}
|
||||
|
||||
// Tile partially aligned check rows
|
||||
else if (align_N || tgp_bn == BN) {
|
||||
gemm_kernel::gemm_loop(
|
||||
As,
|
||||
Bs,
|
||||
gemm_k_iterations,
|
||||
loader_a,
|
||||
loader_b,
|
||||
mma_op,
|
||||
tgp_bm,
|
||||
tgp_bn,
|
||||
lbk,
|
||||
LoopAlignment<false, true, true>{});
|
||||
mma_op.store_result_slice(
|
||||
C, params->ldd, short2(0, offset), short2(BN, offset_next));
|
||||
}
|
||||
|
||||
// Tile partially aligned check cols
|
||||
else if (align_M || tgp_bm == BM) {
|
||||
gemm_kernel::gemm_loop(
|
||||
As,
|
||||
Bs,
|
||||
gemm_k_iterations,
|
||||
loader_a,
|
||||
loader_b,
|
||||
mma_op,
|
||||
tgp_bm,
|
||||
tgp_bn,
|
||||
lbk,
|
||||
LoopAlignment<true, false, true>{});
|
||||
mma_op.store_result_slice(
|
||||
C, params->ldd, short2(0, offset), short2(tgp_bn, offset_next));
|
||||
}
|
||||
|
||||
// Nothing aligned so check both rows and cols
|
||||
else {
|
||||
gemm_kernel::gemm_loop(
|
||||
As,
|
||||
Bs,
|
||||
gemm_k_iterations,
|
||||
loader_a,
|
||||
loader_b,
|
||||
mma_op,
|
||||
tgp_bm,
|
||||
tgp_bn,
|
||||
lbk,
|
||||
LoopAlignment<false, false, true>{});
|
||||
mma_op.store_result_slice(
|
||||
C, params->ldd, short2(0, offset), short2(tgp_bn, offset_next));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <
|
||||
typename T,
|
||||
int BM,
|
||||
int BN,
|
||||
int BK,
|
||||
int WM,
|
||||
int WN,
|
||||
bool transpose_a,
|
||||
bool transpose_b,
|
||||
typename AccumType = float>
|
||||
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void gather_mm(
|
||||
const device T* A [[buffer(0)]],
|
||||
const device T* B [[buffer(1)]],
|
||||
const device uint32_t* lhs_indices [[buffer(2)]],
|
||||
const device uint32_t* rhs_indices [[buffer(3)]],
|
||||
device T* C [[buffer(4)]],
|
||||
const constant GEMMParams* params [[buffer(5)]],
|
||||
const constant int* indices_shape [[buffer(6)]],
|
||||
const constant int64_t* lhs_strides [[buffer(7)]],
|
||||
const constant int64_t* rhs_strides [[buffer(8)]],
|
||||
const constant int& batch_ndim_a [[buffer(9)]],
|
||||
const constant int* batch_shape_a [[buffer(10)]],
|
||||
const constant int64_t* batch_strides_a [[buffer(11)]],
|
||||
const constant int& batch_ndim_b [[buffer(12)]],
|
||||
const constant int* batch_shape_b [[buffer(13)]],
|
||||
const constant int64_t* batch_strides_b [[buffer(14)]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||
uint3 tid [[threadgroup_position_in_grid]]) {
|
||||
using gemm_kernel = GEMMKernel<
|
||||
T,
|
||||
T,
|
||||
BM,
|
||||
BN,
|
||||
BK,
|
||||
WM,
|
||||
WN,
|
||||
transpose_a,
|
||||
transpose_b,
|
||||
true,
|
||||
true,
|
||||
AccumType>;
|
||||
|
||||
using loader_a_t = typename gemm_kernel::loader_a_t;
|
||||
using loader_b_t = typename gemm_kernel::loader_b_t;
|
||||
using mma_t = typename gemm_kernel::mma_t;
|
||||
|
||||
if (params->tiles_n <= static_cast<int>(tid.x) ||
|
||||
params->tiles_m <= static_cast<int>(tid.y)) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Move A and B to the locations pointed by lhs_indices and rhs_indices.
|
||||
uint32_t indx_A, indx_B;
|
||||
if (has_batch) {
|
||||
ulong2 indices_offsets = elem_to_loc_broadcast(
|
||||
tid.z, indices_shape, lhs_strides, rhs_strides, params->batch_ndim);
|
||||
indx_A = lhs_indices[indices_offsets.x];
|
||||
indx_B = rhs_indices[indices_offsets.y];
|
||||
} else {
|
||||
indx_A = lhs_indices[params->batch_stride_a * tid.z];
|
||||
indx_B = rhs_indices[params->batch_stride_b * tid.z];
|
||||
}
|
||||
A += elem_to_loc(indx_A, batch_shape_a, batch_strides_a, batch_ndim_a);
|
||||
B += elem_to_loc(indx_B, batch_shape_b, batch_strides_b, batch_ndim_b);
|
||||
C += params->batch_stride_d * tid.z;
|
||||
|
||||
// Prepare threadgroup memory
|
||||
threadgroup T As[gemm_kernel::tgp_mem_size_a];
|
||||
threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
|
||||
|
||||
// Just make sure everybody's finished with the indexing math above.
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
// Find block in A, B, C
|
||||
const int c_row = tid.y * BM;
|
||||
const int c_col = tid.x * BN;
|
||||
const size_t c_row_long = size_t(c_row);
|
||||
const size_t c_col_long = size_t(c_col);
|
||||
|
||||
A += transpose_a ? c_row_long : c_row_long * params->lda;
|
||||
B += transpose_b ? c_col_long * params->ldb : c_col_long;
|
||||
C += c_row_long * params->ldd + c_col_long;
|
||||
|
||||
// Prepare threadgroup mma operation
|
||||
thread mma_t mma_op(simd_group_id, simd_lane_id);
|
||||
|
||||
// Prepare threadgroup loading operations
|
||||
thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
|
||||
thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);
|
||||
|
||||
// Prepare threadgroup bounds
|
||||
const short tgp_bm = align_M ? BM : short(min(BM, params->M - c_row));
|
||||
const short tgp_bn = align_N ? BN : short(min(BN, params->N - c_col));
|
||||
|
||||
// Prepare iterations
|
||||
int gemm_k_iterations = params->gemm_k_iterations_aligned;
|
||||
|
||||
// Do unaligned K iterations first
|
||||
if (!align_K) {
|
||||
const int k_last = params->gemm_k_iterations_aligned * BK;
|
||||
const int k_remain = params->K - k_last;
|
||||
const size_t k_jump_a =
|
||||
transpose_a ? params->lda * size_t(k_last) : size_t(k_last);
|
||||
const size_t k_jump_b =
|
||||
transpose_b ? size_t(k_last) : params->ldb * size_t(k_last);
|
||||
|
||||
// Move loader source ahead to end
|
||||
loader_a.src += k_jump_a;
|
||||
loader_b.src += k_jump_b;
|
||||
|
||||
// Load tile
|
||||
const short2 tile_dims_A =
|
||||
transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
|
||||
const short2 tile_dims_B =
|
||||
transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);
|
||||
|
||||
loader_a.load_safe(tile_dims_A);
|
||||
loader_b.load_safe(tile_dims_B);
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Do matmul
|
||||
mma_op.mma(As, Bs);
|
||||
|
||||
// Reset source back to start
|
||||
loader_a.src -= k_jump_a;
|
||||
loader_b.src -= k_jump_b;
|
||||
}
|
||||
|
||||
// Matrix level aligned never check
|
||||
if (align_M && align_N) {
|
||||
for (int k = 0; k < gemm_k_iterations; k++) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Load elements into threadgroup
|
||||
loader_a.load_unsafe();
|
||||
loader_b.load_unsafe();
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Multiply and accumulate threadgroup elements
|
||||
mma_op.mma(As, Bs);
|
||||
|
||||
// Prepare for next iteration
|
||||
loader_a.next();
|
||||
loader_b.next();
|
||||
}
|
||||
|
||||
// Store results to device memory
|
||||
mma_op.store_result(C, params->ldd);
|
||||
} else {
|
||||
const short lbk = 0;
|
||||
|
||||
// Tile aligned don't check
|
||||
if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) {
|
||||
gemm_kernel::gemm_loop(
|
||||
As,
|
||||
Bs,
|
||||
gemm_k_iterations,
|
||||
loader_a,
|
||||
loader_b,
|
||||
mma_op,
|
||||
tgp_bm,
|
||||
tgp_bn,
|
||||
lbk,
|
||||
LoopAlignment<true, true, true>{});
|
||||
mma_op.store_result(C, params->ldd);
|
||||
}
|
||||
|
||||
// Tile partially aligned check rows
|
||||
else if (align_N || tgp_bn == BN) {
|
||||
gemm_kernel::gemm_loop(
|
||||
As,
|
||||
Bs,
|
||||
gemm_k_iterations,
|
||||
loader_a,
|
||||
loader_b,
|
||||
mma_op,
|
||||
tgp_bm,
|
||||
tgp_bn,
|
||||
lbk,
|
||||
LoopAlignment<false, true, true>{});
|
||||
mma_op.store_result_safe(C, params->ldd, short2(tgp_bn, tgp_bm));
|
||||
}
|
||||
|
||||
// Tile partially aligned check cols
|
||||
else if (align_M || tgp_bm == BM) {
|
||||
gemm_kernel::gemm_loop(
|
||||
As,
|
||||
Bs,
|
||||
gemm_k_iterations,
|
||||
loader_a,
|
||||
loader_b,
|
||||
mma_op,
|
||||
tgp_bm,
|
||||
tgp_bn,
|
||||
lbk,
|
||||
LoopAlignment<true, false, true>{});
|
||||
mma_op.store_result_safe(C, params->ldd, short2(tgp_bn, tgp_bm));
|
||||
}
|
||||
|
||||
// Nothing aligned so check both rows and cols
|
||||
else {
|
||||
gemm_kernel::gemm_loop(
|
||||
As,
|
||||
Bs,
|
||||
gemm_k_iterations,
|
||||
loader_a,
|
||||
loader_b,
|
||||
mma_op,
|
||||
tgp_bm,
|
||||
tgp_bn,
|
||||
lbk,
|
||||
LoopAlignment<false, false, true>{});
|
||||
mma_op.store_result_safe(C, params->ldd, short2(tgp_bn, tgp_bm));
|
||||
}
|
||||
}
|
||||
}
|
@@ -0,0 +1,59 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
// clang-format off
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
|
||||
#include "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h"
|
||||
|
||||
#define instantiate_gather_mm_rhs(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
instantiate_kernel( \
|
||||
"steel_gather_mm_rhs_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn \
|
||||
"_bk" #bk "_wm" #wm "_wn" #wn, \
|
||||
gather_mm_rhs, \
|
||||
itype, \
|
||||
bm, \
|
||||
bn, \
|
||||
bk, \
|
||||
wm, \
|
||||
wn, \
|
||||
trans_a, \
|
||||
trans_b, \
|
||||
float)
|
||||
|
||||
#define instantiate_gather_mm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
instantiate_kernel( \
|
||||
"steel_gather_mm_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn \
|
||||
"_bk" #bk "_wm" #wm "_wn" #wn, \
|
||||
gather_mm, \
|
||||
itype, \
|
||||
bm, \
|
||||
bn, \
|
||||
bk, \
|
||||
wm, \
|
||||
wn, \
|
||||
trans_a, \
|
||||
trans_b, \
|
||||
float)
|
||||
|
||||
#define instantiate_gather_mm_rhs_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
instantiate_gather_mm_rhs(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
instantiate_gather_mm_rhs(nt, false, true, iname, itype, oname, otype, bm, bn, bk, wm, wn)
|
||||
|
||||
#define instantiate_gather_mm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
instantiate_gather_mm(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
instantiate_gather_mm(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
instantiate_gather_mm(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
instantiate_gather_mm(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn)
|
||||
|
||||
#define instantiate_gather_mm_shapes_helper(iname, itype, oname, otype) \
|
||||
instantiate_gather_mm_rhs_transpose_helper(iname, itype, oname, otype, 16, 64, 16, 1, 2) \
|
||||
instantiate_gather_mm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) \
|
||||
instantiate_gather_mm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 1, 2) \
|
||||
instantiate_gather_mm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \
|
||||
instantiate_gather_mm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 1, 2) \
|
||||
instantiate_gather_mm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2)
|
||||
// clang-format on
|
||||
|
||||
instantiate_gather_mm_shapes_helper(float16, half, float16, half);
|
||||
instantiate_gather_mm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t);
|
||||
instantiate_gather_mm_shapes_helper(float32, float, float32, float);
|
@@ -113,7 +113,7 @@ struct BlockLoader {
|
||||
tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)];
|
||||
}
|
||||
|
||||
// Zero out uneeded values
|
||||
// Zero out unneeded values
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < vec_size; j++) {
|
||||
tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0);
|
||||
|
@@ -142,6 +142,42 @@ struct BaseMMAFrag<T, 8, 8> {
|
||||
}
|
||||
}
|
||||
|
||||
template <
|
||||
typename DstPtrType,
|
||||
typename StrX,
|
||||
typename StrY,
|
||||
typename StartX,
|
||||
typename StopX,
|
||||
typename StartY,
|
||||
typename StopY,
|
||||
typename OffX,
|
||||
typename OffY>
|
||||
METAL_FUNC static constexpr void store_slice(
|
||||
const thread frag_type& src,
|
||||
DstPtrType dst,
|
||||
StrX str_x,
|
||||
StrY str_y,
|
||||
StartX start_x,
|
||||
StopX stop_x,
|
||||
StartY start_y,
|
||||
StopY stop_y,
|
||||
OffX off_x = Int<0>{},
|
||||
OffY off_y = Int<0>{}) {
|
||||
using U = pointer_element_t<DstPtrType>;
|
||||
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < kElemRows; i++) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < kElemCols; j++) {
|
||||
if ((off_x + i) < stop_x && (off_x + i) >= start_x &&
|
||||
(off_y + j) < stop_y && (off_y + j) >= start_y) {
|
||||
dst[(off_x + i) * str_x + (off_y + j) * str_y] =
|
||||
static_cast<U>(src[i * kElemCols + j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
METAL_FUNC static constexpr void mma(
|
||||
thread frag_type& D,
|
||||
thread frag_type& A,
|
||||
@@ -335,6 +371,31 @@ struct MMATile {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename U, int w_x, int w_y>
|
||||
METAL_FUNC void store_slice(
|
||||
device U* dst,
|
||||
const int ld,
|
||||
const short2 start,
|
||||
const short2 stop) const {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (int i = 0; i < kTileRows; ++i) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (int j = 0; j < kTileCols; ++j) {
|
||||
MMAFrag_t::store_slice(
|
||||
frag_at(i, j),
|
||||
dst,
|
||||
ld,
|
||||
Int<1>{},
|
||||
start.y,
|
||||
stop.y,
|
||||
start.x,
|
||||
stop.x,
|
||||
(i * kFragRows) * w_x,
|
||||
(j * kFragCols) * w_y);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename U, int M, int N, int K>
|
||||
@@ -474,6 +535,26 @@ struct BlockMMA {
|
||||
Ctile.template store<U, WM, WN>(D, ldd);
|
||||
}
|
||||
|
||||
METAL_FUNC void
|
||||
store_result_slice(device U* D, const int ldd, short2 start, short2 stop) {
|
||||
// Apply epilogue
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) {
|
||||
Ctile.elems()[i] = Epilogue::apply(Ctile.elems()[i]);
|
||||
}
|
||||
|
||||
D += sm * ldd + sn;
|
||||
start -= short2(sn, sm);
|
||||
stop -= short2(sn, sm);
|
||||
|
||||
// TODO: Check the start as well
|
||||
if (stop.y <= 0 || stop.x <= 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
Ctile.template store_slice<U, WM, WN>(D, ldd, start, stop);
|
||||
}
|
||||
|
||||
METAL_FUNC void
|
||||
store_result_safe(device U* D, const int ldd, short2 dst_tile_dims) {
|
||||
// Apply epilogue
|
||||
|
@@ -1,25 +1,32 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
template <typename T, typename Op>
|
||||
template <typename T, typename Op, int N = WorkPerThread<T>::n>
|
||||
[[kernel]] void ternary_v(
|
||||
device const bool* a,
|
||||
device const T* b,
|
||||
device const T* c,
|
||||
device T* d,
|
||||
constant uint& size,
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
d[index] = Op()(a[index], b[index], c[index]);
|
||||
index *= N;
|
||||
for (int i = 0; i < N && (index + i) < size; ++i) {
|
||||
d[index + i] = Op()(a[index + i], b[index + i], c[index + i]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename Op>
|
||||
template <typename T, typename Op, int N = WorkPerThread<T>::n>
|
||||
[[kernel]] void ternary_v2(
|
||||
device const bool* a,
|
||||
device const T* b,
|
||||
device const T* c,
|
||||
device T* d,
|
||||
constant int64_t& size,
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]) {
|
||||
auto offset = index.x + grid_dim.x * int64_t(index.y);
|
||||
d[offset] = Op()(a[offset], b[offset], c[offset]);
|
||||
auto offset = N * (index.x + grid_dim.x * int64_t(index.y));
|
||||
for (int i = 0; i < N && (offset + i) < size; ++i) {
|
||||
d[offset + i] = Op()(a[offset + i], b[offset + i], c[offset + i]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename Op, typename IdxT = int64_t>
|
||||
|
@@ -1,21 +1,28 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
|
||||
[[kernel]] void unary_v(
|
||||
device const T* in,
|
||||
device U* out,
|
||||
constant uint& size,
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
out[index] = Op()(in[index]);
|
||||
index *= N;
|
||||
for (int i = 0; i < N && (index + i) < size; ++i) {
|
||||
out[index + i] = Op()(in[index + i]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
|
||||
[[kernel]] void unary_v2(
|
||||
device const T* in,
|
||||
device U* out,
|
||||
constant int64_t& size,
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]) {
|
||||
auto offset = index.x + grid_dim.x * int64_t(index.y);
|
||||
out[offset] = Op()(in[offset]);
|
||||
auto offset = N * (index.x + grid_dim.x * int64_t(index.y));
|
||||
for (int i = 0; i < N && (offset + i) < size; ++i) {
|
||||
out[offset + i] = Op()(in[offset + i]);
|
||||
}
|
||||
}
|
||||
|
||||
template <
|
||||
|
@@ -69,17 +69,24 @@ instantiate_unary_float(Round)
|
||||
instantiate_unary_int(BitwiseInvert)
|
||||
|
||||
instantiate_unary_all_same(Abs, complex64, complex64_t)
|
||||
instantiate_unary_all_same(ArcCos, complex64, complex64_t)
|
||||
instantiate_unary_all_same(ArcSin, complex64, complex64_t)
|
||||
instantiate_unary_all_same(ArcTan, complex64, complex64_t)
|
||||
instantiate_unary_all_same(Conjugate, complex64, complex64_t)
|
||||
instantiate_unary_all_same(Cos, complex64, complex64_t)
|
||||
instantiate_unary_all_same(Cosh, complex64, complex64_t)
|
||||
instantiate_unary_all_same(Exp, complex64, complex64_t)
|
||||
instantiate_unary_all_same(Log, complex64, complex64_t)
|
||||
instantiate_unary_all_same(Log1p, complex64, complex64_t)
|
||||
instantiate_unary_all_same(Log2, complex64, complex64_t)
|
||||
instantiate_unary_all_same(Log10, complex64, complex64_t)
|
||||
instantiate_unary_all_same(Negative, complex64, complex64_t)
|
||||
instantiate_unary_all_same(Sign, complex64, complex64_t)
|
||||
instantiate_unary_all_same(Sin, complex64, complex64_t)
|
||||
instantiate_unary_all_same(Sinh, complex64, complex64_t)
|
||||
instantiate_unary_all_same(Square, complex64, complex64_t)
|
||||
instantiate_unary_all_same(Sqrt, complex64, complex64_t)
|
||||
instantiate_unary_all_same(Rsqrt, complex64, complex64_t)
|
||||
instantiate_unary_all_same(Tan, complex64, complex64_t)
|
||||
instantiate_unary_all_same(Tanh, complex64, complex64_t)
|
||||
instantiate_unary_all_same(Round, complex64, complex64_t)
|
||||
|
@@ -17,27 +17,21 @@ struct Abs {
|
||||
T operator()(T x) {
|
||||
return metal::abs(x);
|
||||
};
|
||||
template <>
|
||||
uint8_t operator()(uint8_t x) {
|
||||
return x;
|
||||
};
|
||||
template <>
|
||||
uint16_t operator()(uint16_t x) {
|
||||
return x;
|
||||
};
|
||||
template <>
|
||||
uint32_t operator()(uint32_t x) {
|
||||
return x;
|
||||
};
|
||||
template <>
|
||||
uint64_t operator()(uint64_t x) {
|
||||
return x;
|
||||
};
|
||||
template <>
|
||||
bool operator()(bool x) {
|
||||
return x;
|
||||
};
|
||||
template <>
|
||||
complex64_t operator()(complex64_t x) {
|
||||
return {metal::precise::sqrt(x.real * x.real + x.imag * x.imag), 0};
|
||||
};
|
||||
@@ -48,6 +42,8 @@ struct ArcCos {
|
||||
T operator()(T x) {
|
||||
return metal::precise::acos(x);
|
||||
};
|
||||
|
||||
complex64_t operator()(complex64_t x);
|
||||
};
|
||||
|
||||
struct ArcCosh {
|
||||
@@ -62,6 +58,8 @@ struct ArcSin {
|
||||
T operator()(T x) {
|
||||
return metal::precise::asin(x);
|
||||
};
|
||||
|
||||
complex64_t operator()(complex64_t x);
|
||||
};
|
||||
|
||||
struct ArcSinh {
|
||||
@@ -76,6 +74,8 @@ struct ArcTan {
|
||||
T operator()(T x) {
|
||||
return metal::precise::atan(x);
|
||||
};
|
||||
|
||||
complex64_t operator()(complex64_t x);
|
||||
};
|
||||
|
||||
struct ArcTanh {
|
||||
@@ -97,39 +97,30 @@ struct Ceil {
|
||||
T operator()(T x) {
|
||||
return metal::ceil(x);
|
||||
};
|
||||
template <>
|
||||
int8_t operator()(int8_t x) {
|
||||
return x;
|
||||
};
|
||||
template <>
|
||||
int16_t operator()(int16_t x) {
|
||||
return x;
|
||||
};
|
||||
template <>
|
||||
int32_t operator()(int32_t x) {
|
||||
return x;
|
||||
};
|
||||
template <>
|
||||
int64_t operator()(int64_t x) {
|
||||
return x;
|
||||
};
|
||||
template <>
|
||||
uint8_t operator()(uint8_t x) {
|
||||
return x;
|
||||
};
|
||||
template <>
|
||||
uint16_t operator()(uint16_t x) {
|
||||
return x;
|
||||
};
|
||||
template <>
|
||||
uint32_t operator()(uint32_t x) {
|
||||
return x;
|
||||
};
|
||||
template <>
|
||||
uint64_t operator()(uint64_t x) {
|
||||
return x;
|
||||
};
|
||||
template <>
|
||||
bool operator()(bool x) {
|
||||
return x;
|
||||
};
|
||||
@@ -141,7 +132,6 @@ struct Cos {
|
||||
return metal::precise::cos(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
complex64_t operator()(complex64_t x) {
|
||||
return {
|
||||
metal::precise::cos(x.real) * metal::precise::cosh(x.imag),
|
||||
@@ -155,7 +145,6 @@ struct Cosh {
|
||||
return metal::precise::cosh(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
complex64_t operator()(complex64_t x) {
|
||||
return {
|
||||
metal::precise::cosh(x.real) * metal::precise::cos(x.imag),
|
||||
@@ -188,7 +177,6 @@ struct Exp {
|
||||
T operator()(T x) {
|
||||
return metal::precise::exp(x);
|
||||
};
|
||||
template <>
|
||||
complex64_t operator()(complex64_t x) {
|
||||
auto m = metal::precise::exp(x.real);
|
||||
return {m * metal::precise::cos(x.imag), m * metal::precise::sin(x.imag)};
|
||||
@@ -207,39 +195,30 @@ struct Floor {
|
||||
T operator()(T x) {
|
||||
return metal::floor(x);
|
||||
};
|
||||
template <>
|
||||
int8_t operator()(int8_t x) {
|
||||
return x;
|
||||
};
|
||||
template <>
|
||||
int16_t operator()(int16_t x) {
|
||||
return x;
|
||||
};
|
||||
template <>
|
||||
int32_t operator()(int32_t x) {
|
||||
return x;
|
||||
};
|
||||
template <>
|
||||
int64_t operator()(int64_t x) {
|
||||
return x;
|
||||
};
|
||||
template <>
|
||||
uint8_t operator()(uint8_t x) {
|
||||
return x;
|
||||
};
|
||||
template <>
|
||||
uint16_t operator()(uint16_t x) {
|
||||
return x;
|
||||
};
|
||||
template <>
|
||||
uint32_t operator()(uint32_t x) {
|
||||
return x;
|
||||
};
|
||||
template <>
|
||||
uint64_t operator()(uint64_t x) {
|
||||
return x;
|
||||
};
|
||||
template <>
|
||||
bool operator()(bool x) {
|
||||
return x;
|
||||
};
|
||||
@@ -258,7 +237,6 @@ struct Log {
|
||||
return metal::precise::log(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
complex64_t operator()(complex64_t x) {
|
||||
auto r = metal::precise::log(Abs{}(x).real);
|
||||
auto i = metal::precise::atan2(x.imag, x.real);
|
||||
@@ -272,7 +250,6 @@ struct Log2 {
|
||||
return metal::precise::log2(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
complex64_t operator()(complex64_t x) {
|
||||
auto y = Log{}(x);
|
||||
return {y.real / M_LN2_F, y.imag / M_LN2_F};
|
||||
@@ -285,7 +262,6 @@ struct Log10 {
|
||||
return metal::precise::log10(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
complex64_t operator()(complex64_t x) {
|
||||
auto y = Log{}(x);
|
||||
return {y.real / M_LN10_F, y.imag / M_LN10_F};
|
||||
@@ -325,7 +301,6 @@ struct Round {
|
||||
T operator()(T x) {
|
||||
return metal::rint(x);
|
||||
};
|
||||
template <>
|
||||
complex64_t operator()(complex64_t x) {
|
||||
return {metal::rint(x.real), metal::rint(x.imag)};
|
||||
};
|
||||
@@ -344,11 +319,9 @@ struct Sign {
|
||||
T operator()(T x) {
|
||||
return (x > T(0)) - (x < T(0));
|
||||
};
|
||||
template <>
|
||||
uint32_t operator()(uint32_t x) {
|
||||
return x != 0;
|
||||
};
|
||||
template <>
|
||||
complex64_t operator()(complex64_t x) {
|
||||
if (x == complex64_t(0)) {
|
||||
return x;
|
||||
@@ -364,7 +337,6 @@ struct Sin {
|
||||
return metal::precise::sin(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
complex64_t operator()(complex64_t x) {
|
||||
return {
|
||||
metal::precise::sin(x.real) * metal::precise::cosh(x.imag),
|
||||
@@ -378,7 +350,6 @@ struct Sinh {
|
||||
return metal::precise::sinh(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
complex64_t operator()(complex64_t x) {
|
||||
return {
|
||||
metal::precise::sinh(x.real) * metal::precise::cos(x.imag),
|
||||
@@ -398,6 +369,17 @@ struct Sqrt {
|
||||
T operator()(T x) {
|
||||
return metal::precise::sqrt(x);
|
||||
};
|
||||
|
||||
complex64_t operator()(complex64_t x) {
|
||||
if (x.real == 0.0 && x.imag == 0.0) {
|
||||
return {0.0, 0.0};
|
||||
}
|
||||
auto r = Abs{}(x).real;
|
||||
auto a = metal::precise::sqrt((r + x.real) / 2.0);
|
||||
auto b_abs = metal::precise::sqrt((r - x.real) / 2.0);
|
||||
auto b = metal::copysign(b_abs, x.imag);
|
||||
return {a, b};
|
||||
}
|
||||
};
|
||||
|
||||
struct Rsqrt {
|
||||
@@ -405,6 +387,10 @@ struct Rsqrt {
|
||||
T operator()(T x) {
|
||||
return metal::precise::rsqrt(x);
|
||||
};
|
||||
|
||||
complex64_t operator()(complex64_t x) {
|
||||
return 1.0 / Sqrt{}(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct Tan {
|
||||
@@ -413,7 +399,6 @@ struct Tan {
|
||||
return metal::precise::tan(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
complex64_t operator()(complex64_t x) {
|
||||
float tan_a = metal::precise::tan(x.real);
|
||||
float tanh_b = metal::precise::tanh(x.imag);
|
||||
@@ -429,7 +414,6 @@ struct Tanh {
|
||||
return metal::precise::tanh(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
complex64_t operator()(complex64_t x) {
|
||||
float tanh_a = metal::precise::tanh(x.real);
|
||||
float tan_b = metal::precise::tan(x.imag);
|
||||
@@ -438,3 +422,21 @@ struct Tanh {
|
||||
return {(tanh_a + tan_b * t1) / denom, (tan_b - tanh_a * t1) / denom};
|
||||
};
|
||||
};
|
||||
|
||||
complex64_t ArcCos::operator()(complex64_t x) {
|
||||
auto i = complex64_t{0.0, 1.0};
|
||||
auto y = Log{}(x + i * Sqrt{}(1.0 - x * x));
|
||||
return {y.imag, -y.real};
|
||||
};
|
||||
|
||||
complex64_t ArcSin::operator()(complex64_t x) {
|
||||
auto i = complex64_t{0.0, 1.0};
|
||||
auto y = Log{}(i * x + Sqrt{}(1.0 - x * x));
|
||||
return {y.imag, -y.real};
|
||||
};
|
||||
|
||||
complex64_t ArcTan::operator()(complex64_t x) {
|
||||
auto i = complex64_t{0.0, 1.0};
|
||||
auto ix = i * x;
|
||||
return (1.0 / complex64_t{0.0, 2.0}) * Log{}((1.0 + ix) / (1.0 - ix));
|
||||
};
|
||||
|
@@ -15,6 +15,14 @@
|
||||
|
||||
typedef half float16_t;
|
||||
|
||||
// Work per thread values for different types. The values here are expected to
|
||||
// match get_work_per_thread in mlx/backend/metal/utils.h
|
||||
template <typename U>
|
||||
struct WorkPerThread {
|
||||
static_assert(sizeof(U) <= 8, "Type too large");
|
||||
static constexpr int constant n = 8 / sizeof(U);
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Type limits utils
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
@@ -328,6 +336,23 @@ inline bfloat16_t log1p(bfloat16_t x) {
|
||||
return bfloat16_t(x * (metal::log(xp1) / (xp1 - 1.0f)));
|
||||
}
|
||||
|
||||
inline complex64_t log1p(complex64_t in) {
|
||||
float x = in.real;
|
||||
float y = in.imag;
|
||||
float zabs = metal::precise::sqrt(x * x + y * y);
|
||||
float theta = metal::atan2(y, x + 1);
|
||||
if (zabs < 0.5f) {
|
||||
float r = x * (2 + x) + y * y;
|
||||
if (r == 0) { // handle underflow
|
||||
return {x, theta};
|
||||
}
|
||||
return {0.5f * log1p(r), theta};
|
||||
} else {
|
||||
auto z0 = metal::sqrt((x + 1) * (x + 1) + y * y);
|
||||
return {metal::log(z0), theta};
|
||||
}
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// SIMD shuffle ops
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
@@ -1,7 +1,7 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
#include <algorithm>
|
||||
|
||||
#include "mlx/backend/metal/copy.h"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/kernels.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
|
@@ -5,8 +5,9 @@
|
||||
#include <numeric>
|
||||
#include <sstream>
|
||||
|
||||
#include "mlx/backend/common/broadcasting.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/metal/copy.h"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/kernels.h"
|
||||
#include "mlx/backend/metal/kernels/defines.h"
|
||||
@@ -102,6 +103,47 @@ std::tuple<bool, int64_t, array> check_transpose(
|
||||
}
|
||||
};
|
||||
|
||||
inline array
|
||||
ensure_row_contiguous(const array& x, metal::Device& d, const Stream& s) {
|
||||
if (!x.flags().row_contiguous) {
|
||||
array x_copy(x.shape(), x.dtype(), nullptr, {});
|
||||
copy_gpu(x, x_copy, CopyType::General, s);
|
||||
d.add_temporary(x_copy, s.index);
|
||||
return x_copy;
|
||||
} else {
|
||||
return x;
|
||||
}
|
||||
}
|
||||
|
||||
inline std::tuple<bool, int64_t, array>
|
||||
ensure_batch_contiguous(const array& x, metal::Device& d, const Stream& s) {
|
||||
if (x.flags().row_contiguous) {
|
||||
return std::make_tuple(false, x.strides()[x.ndim() - 2], x);
|
||||
}
|
||||
|
||||
bool rc = true;
|
||||
for (int i = 0; i < x.ndim() - 3; i++) {
|
||||
rc &= x.strides()[i + 1] * x.shape(i) == x.strides()[i];
|
||||
}
|
||||
if (rc) {
|
||||
auto stx = x.strides()[x.ndim() - 2];
|
||||
auto sty = x.strides()[x.ndim() - 1];
|
||||
auto K = x.shape(-2);
|
||||
auto N = x.shape(-1);
|
||||
if (sty == 1 && (N != 1 || stx == N)) {
|
||||
return std::make_tuple(false, stx, x);
|
||||
}
|
||||
if (stx == 1 && (N != 1 || sty == K)) {
|
||||
return std::make_tuple(true, sty, x);
|
||||
}
|
||||
}
|
||||
|
||||
array x_copy(x.shape(), x.dtype(), nullptr, {});
|
||||
copy_gpu(x, x_copy, CopyType::General, s);
|
||||
d.add_temporary(x_copy, s.index);
|
||||
return std::make_tuple(false, x_copy.strides()[x_copy.ndim() - 2], x_copy);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
@@ -230,7 +272,6 @@ void steel_matmul_regular(
|
||||
const bool align_M = (M % bm) == 0;
|
||||
const bool align_N = (N % bn) == 0;
|
||||
const bool align_K = (K % bk) == 0;
|
||||
const bool do_gather = false;
|
||||
|
||||
metal::MTLFCList func_consts = {
|
||||
{&has_batch, MTL::DataType::DataTypeBool, 10},
|
||||
@@ -239,7 +280,6 @@ void steel_matmul_regular(
|
||||
{&align_M, MTL::DataType::DataTypeBool, 200},
|
||||
{&align_N, MTL::DataType::DataTypeBool, 201},
|
||||
{&align_K, MTL::DataType::DataTypeBool, 202},
|
||||
{&do_gather, MTL::DataType::DataTypeBool, 300},
|
||||
};
|
||||
|
||||
// clang-format off
|
||||
@@ -248,8 +288,7 @@ void steel_matmul_regular(
|
||||
<< "_do_axpby_" << (do_axpby ? 't' : 'n')
|
||||
<< "_align_M_" << (align_M ? 't' : 'n')
|
||||
<< "_align_N_" << (align_N ? 't' : 'n')
|
||||
<< "_align_K_" << (align_K ? 't' : 'n')
|
||||
<< "_do_gather_" << (do_gather ? 't' : 'n'); // clang-format on
|
||||
<< "_align_K_" << (align_K ? 't' : 'n'); // clang-format on
|
||||
|
||||
std::string hash_name = kname.str();
|
||||
|
||||
@@ -975,7 +1014,6 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
const bool align_M = (M % bm) == 0;
|
||||
const bool align_N = (N % bn) == 0;
|
||||
const bool align_K = (K % bk) == 0;
|
||||
const bool do_gather = false;
|
||||
|
||||
metal::MTLFCList func_consts = {
|
||||
{&has_batch, MTL::DataType::DataTypeBool, 10},
|
||||
@@ -984,7 +1022,6 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
{&align_M, MTL::DataType::DataTypeBool, 200},
|
||||
{&align_N, MTL::DataType::DataTypeBool, 201},
|
||||
{&align_K, MTL::DataType::DataTypeBool, 202},
|
||||
{&do_gather, MTL::DataType::DataTypeBool, 300},
|
||||
};
|
||||
|
||||
// clang-format off
|
||||
@@ -993,8 +1030,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
<< "_do_axpby_" << (do_axpby ? 't' : 'n')
|
||||
<< "_align_M_" << (align_M ? 't' : 'n')
|
||||
<< "_align_N_" << (align_N ? 't' : 'n')
|
||||
<< "_align_K_" << (align_K ? 't' : 'n')
|
||||
<< "_do_gather_" << (do_gather ? 't' : 'n'); // clang-format on
|
||||
<< "_align_K_" << (align_K ? 't' : 'n'); // clang-format on
|
||||
|
||||
std::string hash_name = kname.str();
|
||||
|
||||
@@ -1464,267 +1500,337 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
d.add_temporaries(std::move(copies), s.index);
|
||||
}
|
||||
|
||||
void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
using namespace mlx::steel;
|
||||
// assert(inputs.size() == 2);
|
||||
if (!issubdtype(out.dtype(), floating)) {
|
||||
throw std::runtime_error(
|
||||
"[GatherMM] Does not yet support non-floating point types.");
|
||||
}
|
||||
auto& s = stream();
|
||||
auto& d = metal::device(s.device);
|
||||
void gather_mm_rhs(
|
||||
const array& a_,
|
||||
const array& b_,
|
||||
const array& indices_,
|
||||
array& out,
|
||||
metal::Device& d,
|
||||
const Stream& s) {
|
||||
array indices = ensure_row_contiguous(indices_, d, s);
|
||||
auto [transpose_b, ldb, b] = ensure_batch_contiguous(b_, d, s);
|
||||
|
||||
auto& a_pre = inputs[0];
|
||||
auto& b_pre = inputs[1];
|
||||
// Return 0s if either input is empty
|
||||
if (a_pre.size() == 0 || b_pre.size() == 0) {
|
||||
array zero = array(0, a_pre.dtype());
|
||||
fill_gpu(zero, out, s);
|
||||
d.add_temporary(std::move(zero), s.index);
|
||||
return;
|
||||
}
|
||||
// Broadcast a with indices. If we are here that means lhs_indices were not
|
||||
// provided so the lhs_indices are implied to be the shape of a broadcasted
|
||||
// with rhs_indices. We need only broadcast a and copy it as if applying the
|
||||
// lhs_indices.
|
||||
auto broadcast_with_indices = [&d, &s, &indices](const array& x) {
|
||||
if (x.size() / x.shape(-2) / x.shape(-1) == indices.size()) {
|
||||
return ensure_row_contiguous(x, d, s);
|
||||
}
|
||||
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
auto x_shape = indices.shape();
|
||||
x_shape.push_back(x.shape(-2));
|
||||
x_shape.push_back(x.shape(-1));
|
||||
array new_x(std::move(x_shape), x.dtype(), nullptr, {});
|
||||
broadcast(x, new_x);
|
||||
return ensure_row_contiguous(new_x, d, s);
|
||||
};
|
||||
array a = broadcast_with_indices(a_);
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Init checks and prep
|
||||
// Extract the matmul shapes
|
||||
int K = a.shape(-1);
|
||||
int M = a.size() / K;
|
||||
int N = b.shape(-1);
|
||||
int lda = a.strides()[a.ndim() - 2]; // should be K
|
||||
|
||||
int M = a_pre.shape(-2);
|
||||
int N = b_pre.shape(-1);
|
||||
int K = a_pre.shape(-1);
|
||||
// Define the dispatch blocks
|
||||
int bm = 16, bn = 64, bk = 16;
|
||||
int wm = 1, wn = 2;
|
||||
|
||||
// Keep a vector with copies to be cleared in the completed buffer to release
|
||||
// the arrays
|
||||
std::vector<array> copies;
|
||||
auto [transpose_a, a_cols, a] = check_transpose(copies, s, a_pre, M == 1);
|
||||
auto [transpose_b, b_cols, b] = check_transpose(copies, s, b_pre, N == 1);
|
||||
const bool align_M = (M % bm) == 0;
|
||||
const bool align_N = (N % bn) == 0;
|
||||
const bool align_K = (K % bk) == 0;
|
||||
|
||||
int lda = a_cols;
|
||||
int ldb = b_cols;
|
||||
// Define the kernel name
|
||||
std::string base_name;
|
||||
base_name.reserve(64);
|
||||
concatenate(
|
||||
base_name,
|
||||
"steel_gather_mm_rhs_n",
|
||||
transpose_b ? 't' : 'n',
|
||||
'_',
|
||||
type_to_name(a),
|
||||
'_',
|
||||
type_to_name(out),
|
||||
"_bm",
|
||||
bm,
|
||||
"_bn",
|
||||
bn,
|
||||
"_bk",
|
||||
bk,
|
||||
"_wm",
|
||||
wm,
|
||||
"_wn",
|
||||
wn);
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Check and collapse batch dimensions
|
||||
|
||||
auto get_batch_dims = [](const auto& v) {
|
||||
return decltype(v){v.begin(), v.end() - 2};
|
||||
metal::MTLFCList func_consts = {
|
||||
{&align_M, MTL::DataType::DataTypeBool, 200},
|
||||
{&align_N, MTL::DataType::DataTypeBool, 201},
|
||||
{&align_K, MTL::DataType::DataTypeBool, 202},
|
||||
};
|
||||
|
||||
auto& lhs_indices = inputs[2];
|
||||
auto& rhs_indices = inputs[3];
|
||||
// And the kernel hash that includes the function constants
|
||||
std::string hash_name;
|
||||
hash_name.reserve(128);
|
||||
concatenate(
|
||||
hash_name,
|
||||
base_name,
|
||||
"_align_M_",
|
||||
align_M ? 't' : 'n',
|
||||
"_align_N_",
|
||||
align_N ? 't' : 'n',
|
||||
"_align_K_",
|
||||
align_K ? 't' : 'n');
|
||||
|
||||
Shape batch_shape = get_batch_dims(out.shape());
|
||||
Strides batch_strides;
|
||||
// Get and set the kernel
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = get_steel_gemm_gather_kernel(
|
||||
d,
|
||||
base_name,
|
||||
hash_name,
|
||||
func_consts,
|
||||
out,
|
||||
false,
|
||||
transpose_b,
|
||||
bm,
|
||||
bn,
|
||||
bk,
|
||||
wm,
|
||||
wn,
|
||||
true);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
batch_strides.insert(
|
||||
batch_strides.end(),
|
||||
lhs_indices.strides().begin(),
|
||||
lhs_indices.strides().end());
|
||||
auto lhs_indices_str = batch_strides.empty() ? 0 : batch_strides.back();
|
||||
// Prepare the matmul params
|
||||
auto batch_stride_b = b.ndim() > 2 ? b.strides()[b.ndim() - 3] : b.size();
|
||||
steel::GEMMParams params{
|
||||
/* const int M = */ M,
|
||||
/* const int N = */ N,
|
||||
/* const int K = */ K,
|
||||
/* const int lda = */ lda,
|
||||
/* const int ldb = */ static_cast<int>(ldb),
|
||||
/* const int ldd = */ N,
|
||||
/* const int tiles_n = */ (N + bn - 1) / bn,
|
||||
/* const int tiles_m = */ (M + bm - 1) / bm,
|
||||
/* const int64_t batch_stride_a = */ 0,
|
||||
/* const int64_t batch_stride_b = */ static_cast<int64_t>(batch_stride_b),
|
||||
/* const int64_t batch_stride_d = */ 0,
|
||||
/* const int swizzle_log = */ 0,
|
||||
/* const int gemm_k_iterations_aligned = */ (K / bk),
|
||||
/* const int batch_ndim = */ 0};
|
||||
|
||||
batch_strides.insert(
|
||||
batch_strides.end(),
|
||||
rhs_indices.strides().begin(),
|
||||
rhs_indices.strides().end());
|
||||
auto rhs_indices_str = batch_strides.empty() ? 0 : batch_strides.back();
|
||||
// Prepare the grid
|
||||
MTL::Size group_dims = MTL::Size(32, wn, wm);
|
||||
MTL::Size grid_dims = MTL::Size(params.tiles_n, params.tiles_m, 1);
|
||||
|
||||
int batch_ndim = batch_shape.size();
|
||||
// Launch kernel
|
||||
compute_encoder.set_input_array(a, 0);
|
||||
compute_encoder.set_input_array(b, 1);
|
||||
compute_encoder.set_input_array(indices, 2);
|
||||
compute_encoder.set_output_array(out, 3);
|
||||
compute_encoder.set_bytes(params, 4);
|
||||
|
||||
if (batch_ndim == 0) {
|
||||
batch_shape = {1};
|
||||
batch_strides = {0};
|
||||
}
|
||||
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
int batch_ndim_A = a.ndim() - 2;
|
||||
int batch_ndim_B = b.ndim() - 2;
|
||||
std::vector<int> operand_batch_ndim = {batch_ndim_A, batch_ndim_B};
|
||||
void gather_mv(
|
||||
const array& mat_,
|
||||
const array& vec_,
|
||||
const array& mat_indices_,
|
||||
const array& vec_indices_,
|
||||
array& out,
|
||||
int N,
|
||||
int K,
|
||||
bool is_mv,
|
||||
metal::Device& d,
|
||||
const Stream& s) {
|
||||
// Copy if needed
|
||||
std::vector<array> copies;
|
||||
auto [transpose_mat, mat_cols, mat] =
|
||||
check_transpose(copies, s, mat_, N == 1);
|
||||
auto [transpose_vec, vec_cols, vec] = check_transpose(copies, s, vec_, true);
|
||||
d.add_temporaries(std::move(copies), s.index);
|
||||
|
||||
Shape batch_shape_A = get_batch_dims(a.shape());
|
||||
Strides batch_strides_A = get_batch_dims(a.strides());
|
||||
Shape batch_shape_B = get_batch_dims(b.shape());
|
||||
Strides batch_strides_B = get_batch_dims(b.strides());
|
||||
// If we are doing vector matrix instead of matrix vector we need to flip the
|
||||
// matrix transposition. Basically m @ v = v @ m.T assuming that v is treated
|
||||
// as a one dimensional array.
|
||||
transpose_mat = (!is_mv) ^ transpose_mat;
|
||||
|
||||
if (batch_ndim_A == 0) {
|
||||
batch_shape_A = {1};
|
||||
batch_strides_A = {0};
|
||||
}
|
||||
// Define some shapes
|
||||
int in_vector_len = K;
|
||||
int out_vector_len = N;
|
||||
int mat_ld = mat_cols;
|
||||
|
||||
if (batch_ndim_B == 0) {
|
||||
batch_shape_B = {1};
|
||||
batch_strides_B = {0};
|
||||
}
|
||||
int batch_size_out = out.size() / N;
|
||||
int batch_ndim = out.ndim() - 2;
|
||||
int batch_ndim_mat = mat.ndim() - 2;
|
||||
int batch_ndim_vec = vec.ndim() - 2;
|
||||
Strides index_strides = vec_indices_.strides();
|
||||
index_strides.insert(
|
||||
index_strides.end(),
|
||||
mat_indices_.strides().begin(),
|
||||
mat_indices_.strides().end());
|
||||
|
||||
auto matrix_stride_out = static_cast<int64_t>(M) * N;
|
||||
auto batch_size_out = out.size() / matrix_stride_out;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Gemv specialization
|
||||
|
||||
// Route to gemv if needed
|
||||
if (std::min(M, N) == 1) {
|
||||
// Collect problem info
|
||||
bool is_b_matrix = N != 1;
|
||||
|
||||
auto& mat = is_b_matrix ? b : a;
|
||||
auto& vec = is_b_matrix ? a : b;
|
||||
bool transpose_mat = is_b_matrix ? !transpose_b : transpose_a;
|
||||
int in_vector_len = K;
|
||||
int out_vector_len = is_b_matrix ? N : M;
|
||||
|
||||
int mat_cols = transpose_mat ? out_vector_len : in_vector_len;
|
||||
int mat_rows = transpose_mat ? in_vector_len : out_vector_len;
|
||||
int mat_ld = is_b_matrix ? b_cols : a_cols;
|
||||
|
||||
auto batch_strides_mat = is_b_matrix ? batch_strides_B : batch_strides_A;
|
||||
auto batch_strides_vec = is_b_matrix ? batch_strides_A : batch_strides_B;
|
||||
|
||||
auto batch_shape_mat = is_b_matrix ? batch_shape_B : batch_shape_A;
|
||||
auto batch_shape_vec = is_b_matrix ? batch_shape_A : batch_shape_B;
|
||||
|
||||
if (!is_b_matrix) {
|
||||
batch_strides = rhs_indices.strides();
|
||||
batch_strides.insert(
|
||||
batch_strides.end(),
|
||||
lhs_indices.strides().begin(),
|
||||
lhs_indices.strides().end());
|
||||
}
|
||||
|
||||
int batch_ndim = batch_shape.size();
|
||||
|
||||
// Determine dispatch kernel
|
||||
int tm = 4, tn = 4;
|
||||
int sm = 1, sn = 32;
|
||||
int bm = 1, bn = 1;
|
||||
int n_out_per_tgp;
|
||||
std::ostringstream kname;
|
||||
|
||||
if (transpose_mat) {
|
||||
if (in_vector_len >= 8192 && out_vector_len >= 2048) {
|
||||
sm = 4;
|
||||
sn = 8;
|
||||
} else {
|
||||
sm = 8;
|
||||
sn = 4;
|
||||
}
|
||||
|
||||
if (out_vector_len >= 2048) {
|
||||
bn = 16;
|
||||
} else if (out_vector_len >= 512) {
|
||||
bn = 4;
|
||||
} else {
|
||||
bn = 2;
|
||||
}
|
||||
|
||||
// Specialized kernel for very small outputs
|
||||
tn = out_vector_len < tn ? 1 : tn;
|
||||
|
||||
n_out_per_tgp = bn * sn * tn;
|
||||
kname << "gemv_t_gather_" << type_to_name(out);
|
||||
// Determine dispatch kernel
|
||||
int tm = 4, tn = 4;
|
||||
int sm = 1, sn = 32;
|
||||
int bm = 1, bn = 1;
|
||||
int n_out_per_tgp;
|
||||
std::ostringstream kname;
|
||||
|
||||
if (transpose_mat) {
|
||||
if (in_vector_len >= 8192 && out_vector_len >= 2048) {
|
||||
sm = 4;
|
||||
sn = 8;
|
||||
} else {
|
||||
bm = out_vector_len >= 4096 ? 8 : 4;
|
||||
sn = 32;
|
||||
|
||||
// Specialized kernel for very small outputs
|
||||
tm = out_vector_len < tm ? 1 : tm;
|
||||
|
||||
n_out_per_tgp = bm * sm * tm;
|
||||
kname << "gemv_gather_" << type_to_name(out);
|
||||
sm = 8;
|
||||
sn = 4;
|
||||
}
|
||||
|
||||
kname << "_bm" << bm << "_bn" << bn << "_sm" << sm << "_sn" << sn << "_tm"
|
||||
<< tm << "_tn" << tn;
|
||||
if (out_vector_len >= 2048) {
|
||||
bn = 16;
|
||||
} else if (out_vector_len >= 512) {
|
||||
bn = 4;
|
||||
} else {
|
||||
bn = 2;
|
||||
}
|
||||
|
||||
// Encode and dispatch kernel
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
// Specialized kernel for very small outputs
|
||||
tn = out_vector_len < tn ? 1 : tn;
|
||||
|
||||
int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp;
|
||||
MTL::Size group_dims = MTL::Size(32, bn, bm);
|
||||
MTL::Size grid_dims = MTL::Size(n_tgp, 1, batch_size_out);
|
||||
n_out_per_tgp = bn * sn * tn;
|
||||
kname << "gemv_t_gather_" << type_to_name(out);
|
||||
|
||||
compute_encoder.set_input_array(mat, 0);
|
||||
compute_encoder.set_input_array(vec, 1);
|
||||
compute_encoder.set_output_array(out, 3);
|
||||
} else {
|
||||
bm = out_vector_len >= 4096 ? 8 : 4;
|
||||
sn = 32;
|
||||
|
||||
compute_encoder.set_bytes(in_vector_len, 4);
|
||||
compute_encoder.set_bytes(out_vector_len, 5);
|
||||
compute_encoder.set_bytes(mat_ld, 6);
|
||||
// Specialized kernel for very small outputs
|
||||
tm = out_vector_len < tm ? 1 : tm;
|
||||
|
||||
compute_encoder.set_bytes(batch_ndim, 9);
|
||||
compute_encoder.set_vector_bytes(batch_shape, 10);
|
||||
compute_encoder.set_vector_bytes(batch_strides, 11);
|
||||
|
||||
int batch_ndim_vec = batch_shape_vec.size();
|
||||
compute_encoder.set_bytes(batch_ndim_vec, 12);
|
||||
compute_encoder.set_vector_bytes(batch_shape_vec, 13);
|
||||
compute_encoder.set_vector_bytes(batch_strides_vec, 14);
|
||||
|
||||
int batch_ndim_mat = batch_shape_mat.size();
|
||||
compute_encoder.set_bytes(batch_ndim_mat, 15);
|
||||
compute_encoder.set_vector_bytes(batch_shape_mat, 16);
|
||||
compute_encoder.set_vector_bytes(batch_strides_mat, 17);
|
||||
|
||||
compute_encoder.set_input_array(lhs_indices, 18 + int(!is_b_matrix));
|
||||
compute_encoder.set_input_array(rhs_indices, 18 + int(is_b_matrix));
|
||||
|
||||
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
||||
|
||||
d.add_temporaries(std::move(copies), s.index);
|
||||
return;
|
||||
n_out_per_tgp = bm * sm * tm;
|
||||
kname << "gemv_gather_" << type_to_name(out);
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Regular kernel dispatch
|
||||
kname << "_bm" << bm << "_bn" << bn << "_sm" << sm << "_sn" << sn << "_tm"
|
||||
<< tm << "_tn" << tn;
|
||||
|
||||
// Encode and dispatch kernel
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp;
|
||||
MTL::Size group_dims = MTL::Size(32, bn, bm);
|
||||
MTL::Size grid_dims = MTL::Size(n_tgp, 1, batch_size_out);
|
||||
|
||||
compute_encoder.set_input_array(mat, 0);
|
||||
compute_encoder.set_input_array(vec, 1);
|
||||
compute_encoder.set_output_array(out, 3);
|
||||
|
||||
compute_encoder.set_bytes(in_vector_len, 4);
|
||||
compute_encoder.set_bytes(out_vector_len, 5);
|
||||
compute_encoder.set_bytes(mat_ld, 6);
|
||||
|
||||
compute_encoder.set_bytes(batch_ndim, 9);
|
||||
compute_encoder.set_vector_bytes(out.shape(), 10);
|
||||
compute_encoder.set_vector_bytes(index_strides, 11);
|
||||
|
||||
compute_encoder.set_bytes(batch_ndim_vec, 12);
|
||||
compute_encoder.set_vector_bytes(vec.shape(), 13);
|
||||
compute_encoder.set_vector_bytes(vec.strides(), 14);
|
||||
|
||||
compute_encoder.set_bytes(batch_ndim_mat, 15);
|
||||
compute_encoder.set_vector_bytes(mat.shape(), 16);
|
||||
compute_encoder.set_vector_bytes(mat.strides(), 17);
|
||||
|
||||
compute_encoder.set_input_array(vec_indices_, 18);
|
||||
compute_encoder.set_input_array(mat_indices_, 19);
|
||||
|
||||
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
void gather_mm(
|
||||
const array& a_,
|
||||
const array& b_,
|
||||
const array& lhs_indices,
|
||||
const array& rhs_indices,
|
||||
array& out,
|
||||
int M,
|
||||
int N,
|
||||
int K,
|
||||
metal::Device& d,
|
||||
const Stream& s) {
|
||||
// Copy if needed
|
||||
std::vector<array> copies;
|
||||
auto [transpose_a, lda, a] = check_transpose(copies, s, a_, false);
|
||||
auto [transpose_b, ldb, b] = check_transpose(copies, s, b_, false);
|
||||
d.add_temporaries(std::move(copies), s.index);
|
||||
|
||||
// Determine dispatch kernel
|
||||
int bm = 64, bn = 64, bk = 16;
|
||||
int wm = 2, wn = 2;
|
||||
size_t batch_size_out = out.size() / M / N;
|
||||
int batch_ndim = out.ndim() - 2;
|
||||
int batch_ndim_a = a.ndim() - 2;
|
||||
int batch_ndim_b = b.ndim() - 2;
|
||||
|
||||
char devc = d.get_architecture().back();
|
||||
GEMM_TPARAM_MACRO(devc)
|
||||
|
||||
// Prepare kernel name
|
||||
std::ostringstream kname;
|
||||
kname << "steel_gemm_fused_" << (transpose_a ? 't' : 'n')
|
||||
<< (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_"
|
||||
<< type_to_name(out) << "_bm" << bm << "_bn" << bn << "_bk" << bk
|
||||
<< "_wm" << wm << "_wn" << wn;
|
||||
|
||||
std::string base_name = kname.str();
|
||||
|
||||
const bool has_batch = batch_ndim > 1;
|
||||
const bool use_out_source = false;
|
||||
const bool do_axpby = false;
|
||||
const bool align_M = (M % bm) == 0;
|
||||
const bool align_N = (N % bn) == 0;
|
||||
const bool align_K = (K % bk) == 0;
|
||||
const bool do_gather = true;
|
||||
|
||||
// Define the kernel name
|
||||
std::string base_name;
|
||||
base_name.reserve(128);
|
||||
concatenate(
|
||||
base_name,
|
||||
"steel_gather_mm_",
|
||||
transpose_a ? 't' : 'n',
|
||||
transpose_b ? 't' : 'n',
|
||||
"_",
|
||||
type_to_name(a),
|
||||
"_",
|
||||
type_to_name(out),
|
||||
"_bm",
|
||||
bm,
|
||||
"_bn",
|
||||
bn,
|
||||
"_bk",
|
||||
bk,
|
||||
"_wm",
|
||||
wm,
|
||||
"_wn",
|
||||
wn);
|
||||
|
||||
metal::MTLFCList func_consts = {
|
||||
{&has_batch, MTL::DataType::DataTypeBool, 10},
|
||||
{&use_out_source, MTL::DataType::DataTypeBool, 100},
|
||||
{&do_axpby, MTL::DataType::DataTypeBool, 110},
|
||||
{&align_M, MTL::DataType::DataTypeBool, 200},
|
||||
{&align_N, MTL::DataType::DataTypeBool, 201},
|
||||
{&align_K, MTL::DataType::DataTypeBool, 202},
|
||||
{&do_gather, MTL::DataType::DataTypeBool, 300},
|
||||
};
|
||||
|
||||
// clang-format off
|
||||
kname << "_has_batch_" << (has_batch ? 't' : 'n')
|
||||
<< "_use_out_source_" << (use_out_source ? 't' : 'n')
|
||||
<< "_do_axpby_" << (do_axpby ? 't' : 'n')
|
||||
<< "_align_M_" << (align_M ? 't' : 'n')
|
||||
<< "_align_N_" << (align_N ? 't' : 'n')
|
||||
<< "_align_K_" << (align_K ? 't' : 'n')
|
||||
<< "_do_gather_" << (do_gather ? 't' : 'n'); // clang-format on
|
||||
// And the kernel hash that includes the function constants
|
||||
std::string hash_name;
|
||||
hash_name.reserve(128);
|
||||
concatenate(
|
||||
hash_name,
|
||||
base_name,
|
||||
"_has_batch_",
|
||||
has_batch ? 't' : 'n',
|
||||
"_align_M_",
|
||||
align_M ? 't' : 'n',
|
||||
"_align_N_",
|
||||
align_N ? 't' : 'n',
|
||||
"_align_K_",
|
||||
align_K ? 't' : 'n');
|
||||
|
||||
std::string hash_name = kname.str();
|
||||
|
||||
// Encode and dispatch kernel
|
||||
// Get and set the kernel
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = get_steel_gemm_fused_kernel(
|
||||
auto kernel = get_steel_gemm_gather_kernel(
|
||||
d,
|
||||
base_name,
|
||||
hash_name,
|
||||
@@ -1736,72 +1842,96 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
bn,
|
||||
bk,
|
||||
wm,
|
||||
wn);
|
||||
|
||||
wn,
|
||||
false);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
// Use problem size to determine threadblock swizzle
|
||||
int tn = (N + bn - 1) / bn;
|
||||
int tm = (M + bm - 1) / bm;
|
||||
|
||||
// TODO: Explore device-based tuning for swizzle
|
||||
int swizzle_log = 0; // tm >= 6 ? 3 : (tm <= 3 ? 0 : 2);
|
||||
|
||||
// Prepare steel matmul params
|
||||
GEMMParams params{
|
||||
// Prepare the matmul params
|
||||
steel::GEMMParams params{
|
||||
/* const int M = */ M,
|
||||
/* const int N = */ N,
|
||||
/* const int K = */ K,
|
||||
/* const int lda = */ lda,
|
||||
/* const int ldb = */ ldb,
|
||||
/* const int lda = */ static_cast<int>(lda),
|
||||
/* const int ldb = */ static_cast<int>(ldb),
|
||||
/* const int ldd = */ N,
|
||||
/* const int tiles_n = */ tn,
|
||||
/* const int tiles_m = */ tm,
|
||||
/* const int64_t batch_stride_a = */ lhs_indices_str,
|
||||
/* const int64_t batch_stride_b = */ rhs_indices_str,
|
||||
/* const int64_t batch_stride_d = */ matrix_stride_out,
|
||||
/* const int swizzle_log = */ swizzle_log,
|
||||
/* const int tiles_n = */ (N + bn - 1) / bn,
|
||||
/* const int tiles_m = */ (M + bm - 1) / bm,
|
||||
/* const int64_t batch_stride_a = */
|
||||
(batch_ndim > 0) ? lhs_indices.strides()[0] : 0,
|
||||
/* const int64_t batch_stride_b = */
|
||||
(batch_ndim > 0) ? rhs_indices.strides()[0] : 0,
|
||||
/* const int64_t batch_stride_d = */ M * N,
|
||||
/* const int swizzle_log = */ 0,
|
||||
/* const int gemm_k_iterations_aligned = */ (K / bk),
|
||||
/* const int batch_ndim = */ batch_ndim};
|
||||
|
||||
// Prepare launch grid params
|
||||
int tile = 1 << swizzle_log;
|
||||
tm = (tm + tile - 1) / tile;
|
||||
tn = tn * tile;
|
||||
|
||||
// Prepare the grid
|
||||
MTL::Size group_dims = MTL::Size(32, wn, wm);
|
||||
MTL::Size grid_dims = MTL::Size(tn, tm, batch_size_out);
|
||||
MTL::Size grid_dims =
|
||||
MTL::Size(params.tiles_n, params.tiles_m, batch_size_out);
|
||||
|
||||
// Launch kernel
|
||||
compute_encoder.set_input_array(a, 0);
|
||||
compute_encoder.set_input_array(b, 1);
|
||||
compute_encoder.set_output_array(out, 3);
|
||||
|
||||
compute_encoder.set_bytes(params, 4);
|
||||
|
||||
compute_encoder.set_vector_bytes(batch_shape, 6);
|
||||
compute_encoder.set_vector_bytes(batch_strides, 7);
|
||||
|
||||
compute_encoder.set_input_array(lhs_indices, 10);
|
||||
compute_encoder.set_input_array(rhs_indices, 11);
|
||||
|
||||
std::vector operand_shape = batch_shape_A;
|
||||
operand_shape.insert(
|
||||
operand_shape.end(), batch_shape_B.begin(), batch_shape_B.end());
|
||||
|
||||
std::vector operand_strides = batch_strides_A;
|
||||
operand_strides.insert(
|
||||
operand_strides.end(), batch_strides_B.begin(), batch_strides_B.end());
|
||||
|
||||
operand_batch_ndim.push_back(0);
|
||||
|
||||
compute_encoder.set_vector_bytes(operand_shape, 13);
|
||||
compute_encoder.set_vector_bytes(operand_strides, 14);
|
||||
compute_encoder.set_vector_bytes(operand_batch_ndim, 15);
|
||||
|
||||
compute_encoder.set_input_array(lhs_indices, 2);
|
||||
compute_encoder.set_input_array(rhs_indices, 3);
|
||||
compute_encoder.set_output_array(out, 4);
|
||||
compute_encoder.set_bytes(params, 5);
|
||||
compute_encoder.set_vector_bytes(lhs_indices.shape(), 6);
|
||||
compute_encoder.set_vector_bytes(lhs_indices.strides(), 7);
|
||||
compute_encoder.set_vector_bytes(rhs_indices.strides(), 8);
|
||||
compute_encoder.set_bytes(batch_ndim_a, 9);
|
||||
compute_encoder.set_vector_bytes(a.shape(), 10);
|
||||
compute_encoder.set_vector_bytes(a.strides(), 11);
|
||||
compute_encoder.set_bytes(batch_ndim_b, 12);
|
||||
compute_encoder.set_vector_bytes(b.shape(), 13);
|
||||
compute_encoder.set_vector_bytes(b.strides(), 14);
|
||||
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
d.add_temporaries(std::move(copies), s.index);
|
||||
void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& s = stream();
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
auto& lhs_indices = inputs[2];
|
||||
auto& rhs_indices = inputs[3];
|
||||
|
||||
// Return 0s if either input is empty
|
||||
if (a.size() == 0 || b.size() == 0) {
|
||||
array zero = array(0, a.dtype());
|
||||
fill_gpu(zero, out, s);
|
||||
d.add_temporary(std::move(zero), s.index);
|
||||
return;
|
||||
}
|
||||
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
// Extract shapes from inputs.
|
||||
int M = a.shape(-2);
|
||||
int N = b.shape(-1);
|
||||
int K = a.shape(-1);
|
||||
|
||||
// We are walking a in order and b is also in order so we can batch up the
|
||||
// matmuls and reuse reading a and b.
|
||||
if (M == 1 && right_sorted_ == true) {
|
||||
gather_mm_rhs(a, b, rhs_indices, out, d, s);
|
||||
return;
|
||||
}
|
||||
|
||||
// Route to gather gemv if any of a or b are vectors
|
||||
if (M == 1) {
|
||||
gather_mv(b, a, rhs_indices, lhs_indices, out, N, K, false, d, s);
|
||||
return;
|
||||
}
|
||||
if (N == 1) {
|
||||
gather_mv(a, b, lhs_indices, rhs_indices, out, M, K, true, d, s);
|
||||
return;
|
||||
}
|
||||
|
||||
// Route to non specialized gather mm
|
||||
gather_mm(a, b, lhs_indices, rhs_indices, out, M, N, K, d, s);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -1,11 +1,11 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
#include <memory>
|
||||
|
||||
#include <sys/sysctl.h>
|
||||
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/metal.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/scheduler.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
namespace mlx::core::metal {
|
||||
|
||||
@@ -13,85 +13,6 @@ bool is_available() {
|
||||
return true;
|
||||
}
|
||||
|
||||
inline void check_error(MTL::CommandBuffer* cbuf) {
|
||||
if (cbuf->status() == MTL::CommandBufferStatusError) {
|
||||
std::ostringstream msg;
|
||||
msg << "[METAL] Command buffer execution failed: "
|
||||
<< cbuf->error()->localizedDescription()->utf8String();
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
}
|
||||
|
||||
void eval(array& arr) {
|
||||
auto pool = new_scoped_memory_pool();
|
||||
auto s = arr.primitive().stream();
|
||||
auto& d = metal::device(s.device);
|
||||
auto command_buffer = d.get_command_buffer(s.index);
|
||||
|
||||
auto outputs = arr.outputs();
|
||||
{
|
||||
// If the array is a tracer hold a reference
|
||||
// to its inputs so they don't get donated
|
||||
std::vector<array> inputs;
|
||||
if (arr.is_tracer()) {
|
||||
inputs = arr.inputs();
|
||||
}
|
||||
|
||||
debug_set_primitive_buffer_label(command_buffer, arr.primitive());
|
||||
arr.primitive().eval_gpu(arr.inputs(), outputs);
|
||||
}
|
||||
std::unordered_set<std::shared_ptr<array::Data>> buffers;
|
||||
for (auto& in : arr.inputs()) {
|
||||
buffers.insert(in.data_shared_ptr());
|
||||
}
|
||||
for (auto& s : arr.siblings()) {
|
||||
buffers.insert(s.data_shared_ptr());
|
||||
}
|
||||
// Remove the output if it was donated to by an input
|
||||
if (auto it = buffers.find(arr.data_shared_ptr()); it != buffers.end()) {
|
||||
buffers.erase(it);
|
||||
}
|
||||
|
||||
if (d.command_buffer_needs_commit(s.index)) {
|
||||
d.end_encoding(s.index);
|
||||
scheduler::notify_new_task(s);
|
||||
command_buffer->addCompletedHandler(
|
||||
[s, buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) {
|
||||
scheduler::notify_task_completion(s);
|
||||
check_error(cbuf);
|
||||
});
|
||||
d.commit_command_buffer(s.index);
|
||||
d.get_command_buffer(s.index);
|
||||
} else {
|
||||
command_buffer->addCompletedHandler(
|
||||
[s, buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) {
|
||||
check_error(cbuf);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
void finalize(Stream s) {
|
||||
auto pool = new_scoped_memory_pool();
|
||||
auto& d = metal::device(s.device);
|
||||
auto cb = d.get_command_buffer(s.index);
|
||||
d.end_encoding(s.index);
|
||||
cb->addCompletedHandler([s](MTL::CommandBuffer* cbuf) { check_error(cbuf); });
|
||||
d.commit_command_buffer(s.index);
|
||||
d.get_command_buffer(s.index);
|
||||
}
|
||||
|
||||
void synchronize(Stream s) {
|
||||
auto pool = new_scoped_memory_pool();
|
||||
auto& d = metal::device(s.device);
|
||||
auto cb = d.get_command_buffer(s.index);
|
||||
cb->retain();
|
||||
d.end_encoding(s.index);
|
||||
d.commit_command_buffer(s.index);
|
||||
cb->waitUntilCompleted();
|
||||
check_error(cb);
|
||||
cb->release();
|
||||
}
|
||||
|
||||
void start_capture(std::string path, id object) {
|
||||
auto pool = new_scoped_memory_pool();
|
||||
|
||||
@@ -128,4 +49,36 @@ void stop_capture() {
|
||||
manager->stopCapture();
|
||||
}
|
||||
|
||||
const std::unordered_map<std::string, std::variant<std::string, size_t>>&
|
||||
device_info() {
|
||||
auto init_device_info = []()
|
||||
-> std::unordered_map<std::string, std::variant<std::string, size_t>> {
|
||||
auto pool = new_scoped_memory_pool();
|
||||
auto raw_device = device(default_device()).mtl_device();
|
||||
auto name = std::string(raw_device->name()->utf8String());
|
||||
auto arch = std::string(raw_device->architecture()->name()->utf8String());
|
||||
|
||||
size_t memsize = 0;
|
||||
size_t length = sizeof(memsize);
|
||||
sysctlbyname("hw.memsize", &memsize, &length, NULL, 0);
|
||||
|
||||
size_t rsrc_limit = 0;
|
||||
sysctlbyname("iogpu.rsrc_limit", &rsrc_limit, &length, NULL, 0);
|
||||
if (rsrc_limit == 0) {
|
||||
rsrc_limit = 499000;
|
||||
}
|
||||
|
||||
return {
|
||||
{"device_name", name},
|
||||
{"architecture", arch},
|
||||
{"max_buffer_length", raw_device->maxBufferLength()},
|
||||
{"max_recommended_working_set_size",
|
||||
raw_device->recommendedMaxWorkingSetSize()},
|
||||
{"memory_size", memsize},
|
||||
{"resource_limit", rsrc_limit}};
|
||||
};
|
||||
static auto device_info_ = init_device_info();
|
||||
return device_info_;
|
||||
}
|
||||
|
||||
} // namespace mlx::core::metal
|
||||
|
@@ -2,11 +2,10 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <variant>
|
||||
|
||||
#include "mlx/array.h"
|
||||
|
||||
namespace mlx::core::metal {
|
||||
|
||||
/* Check if the Metal backend is available. */
|
||||
|
22
mlx/backend/metal/no_metal.cpp
Normal file
22
mlx/backend/metal/no_metal.cpp
Normal file
@@ -0,0 +1,22 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include <stdexcept>
|
||||
|
||||
#include "mlx/backend/metal/metal.h"
|
||||
|
||||
namespace mlx::core::metal {
|
||||
|
||||
bool is_available() {
|
||||
return false;
|
||||
}
|
||||
|
||||
void start_capture(std::string) {}
|
||||
void stop_capture() {}
|
||||
|
||||
const std::unordered_map<std::string, std::variant<std::string, size_t>>&
|
||||
device_info() {
|
||||
throw std::runtime_error(
|
||||
"[metal::device_info] Cannot get device info without metal backend");
|
||||
};
|
||||
|
||||
} // namespace mlx::core::metal
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user