Compare commits

..

1 Commits

Author SHA1 Message Date
Jagrit Digani
4c46e17a5d Update benchmark output 2025-04-15 10:50:06 -07:00
168 changed files with 1817 additions and 6309 deletions

1
.gitignore vendored
View File

@@ -36,7 +36,6 @@ share/python-wheels/
.installed.cfg
*.egg
MANIFEST
uv.lock
# vim
*.swp

View File

@@ -34,7 +34,6 @@ option(MLX_BUILD_BENCHMARKS "Build benchmarks for mlx" OFF)
option(MLX_BUILD_PYTHON_BINDINGS "Build python bindings for mlx" OFF)
option(MLX_BUILD_METAL "Build metal backend" ON)
option(MLX_BUILD_CPU "Build cpu backend" ON)
option(MLX_BUILD_CUDA "Build cuda backend" OFF)
option(MLX_METAL_DEBUG "Enhance metal debug workflow" OFF)
option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" OFF)
option(MLX_BUILD_GGUF "Include support for GGUF format" ON)
@@ -84,10 +83,6 @@ if(MLX_BUILD_METAL)
set(QUARTZ_LIB "-framework QuartzCore")
endif()
if(MLX_BUILD_CUDA)
enable_language(CUDA)
endif()
if(MLX_BUILD_METAL AND NOT METAL_LIB)
message(STATUS "Metal not found. Unable to build GPU")
set(MLX_BUILD_METAL OFF)

View File

@@ -1,6 +1,4 @@
include CMakeLists.txt
include mlx.pc.in
recursive-include mlx/ *
include cmake/*
include python/src/*
include python/mlx/py.typed # support type hinting as in PEP-561

View File

@@ -157,7 +157,7 @@ def bench_shape(B, M, N, K, np_dtype, transpose="nn"):
def get_gflop_count(B, M, N, K):
return float(2.0 * N_iter_bench * N_iter_func * B * M * N * K) / float(1024.0**3)
return float(2.0 * N_iter_bench * N_iter_func * B * M * N * K) / float(1000.0**3)
if __name__ == "__main__":
@@ -175,6 +175,8 @@ if __name__ == "__main__":
(1, 4096, 4096, 4096),
)
print(f" B, M, N, K, dtype, t, gflops_pt, gflops_mx, diff%")
for dtype in dtypes:
for transpose in transposes:
for B, M, N, K in shapes:
@@ -187,7 +189,7 @@ if __name__ == "__main__":
diff = gflops_mx / gflops_pt - 1.0
print(
f"{B:3d}, {M:4d}, {N:4d}, {K:4d}, {dtype}, {transpose}, {gflops_pt:05.3f}, {gflops_mx:05.3f}, {100. * diff:+5.2f}%"
f"{B:3d}, {M:4d}, {N:4d}, {K:5d}, {dtype}, {transpose}, {gflops_pt:8.2f}, {gflops_mx:8.2f}, {100. * diff:+5.2f}%"
)
if gflops_pt >= 2.0 * gflops_mx:
print("ATTENTION ^^^^^^^")

View File

@@ -1,4 +1,4 @@
# Copyright © 2025 Apple Inc.
# Copyright © 2023-2024 Apple Inc.
import mlx.core as mx
from time_utils import time_fn

View File

@@ -1,84 +0,0 @@
# Copyright © 2025 Apple Inc.
import mlx.core as mx
from time_utils import time_fn
N = 1024
D = 1024
M = 1024
E = 32
I = 4
def gather_sort(x, indices):
N, M = indices.shape
indices = indices.flatten()
order = mx.argsort(indices)
inv_order = mx.argsort(order)
return x.flatten(0, -3)[order // M], indices[order], inv_order
def scatter_unsort(x, inv_order, shape=None):
x = x[inv_order]
if shape is not None:
x = mx.unflatten(x, 0, shape)
return x
def gather_mm_simulate(x, w, indices):
x, idx, inv_order = gather_sort(x, indices)
for i in range(2):
y = mx.concatenate(
[
mx.quantized_matmul(x[i], w[0][j], w[1][j], w[2][j], transpose=True)
for i, j in enumerate(idx.tolist())
],
axis=0,
)
x = y[:, None]
x = scatter_unsort(x, inv_order, indices.shape)
return x
def time_gather_qmm():
x = mx.random.normal((N, 1, 1, D)) / 1024**0.5
w1 = mx.random.normal((E, M, D)) / 1024**0.5
w2 = mx.random.normal((E, D, M)) / 1024**0.5
w1 = mx.quantize(w1)
w2 = mx.quantize(w2)
indices = (mx.random.uniform(shape=(N, I)) * E).astype(mx.uint32)
sorted_indices = mx.sort(indices.flatten()).reshape(N, I)
mx.eval(x, w1, w2, indices, sorted_indices)
def gather_mm(x, w1, w2, indices, sort):
idx = indices
inv_order = None
if sort:
x, idx, inv_order = gather_sort(x, indices)
x = mx.gather_qmm(x, *w1, transpose=True, rhs_indices=idx, sorted_indices=sort)
x = mx.gather_qmm(x, *w2, transpose=True, rhs_indices=idx, sorted_indices=sort)
if sort:
x = scatter_unsort(x, inv_order, indices.shape)
return x
time_fn(gather_mm, x, w1, w2, indices, False)
time_fn(gather_mm, x, w1, w2, sorted_indices, False)
time_fn(gather_mm, x, w1, w2, indices, True)
x = mx.random.normal((N * I, D)) / 1024**0.5
w1 = mx.random.normal((M, D)) / 1024**0.5
w2 = mx.random.normal((D, M)) / 1024**0.5
w1 = mx.quantize(w1)
w2 = mx.quantize(w2)
mx.eval(x, w1, w2)
def equivalent_matmul(x, w1, w2):
x = mx.quantized_matmul(x, *w1, transpose=True)
x = mx.quantized_matmul(x, *w2, transpose=True)
return x
time_fn(equivalent_matmul, x, w1, w2)
if __name__ == "__main__":
time_gather_qmm()

View File

@@ -20,5 +20,3 @@ FFT
irfft2
rfftn
irfftn
fftshift
ifftshift

View File

@@ -5,7 +5,6 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp
${CMAKE_CURRENT_SOURCE_DIR}/dtype_utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/export.cpp
${CMAKE_CURRENT_SOURCE_DIR}/einsum.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp
@@ -49,16 +48,5 @@ add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/io)
if(MLX_BUILD_METAL)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/metal)
else()
target_sources(mlx
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/no_metal.cpp)
endif()
if(MLX_BUILD_CUDA)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/cuda)
endif()
if(MLX_BUILD_METAL OR MLX_BUILD_CUDA)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/gpu)
else()
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_gpu)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_metal)
endif()

View File

@@ -339,11 +339,11 @@ class array {
return allocator::allocator().size(buffer());
}
// Return the shared pointer to the array::Data struct
const std::shared_ptr<Data>& data_shared_ptr() const {
// Return a copy of the shared pointer
// to the array::Data struct
std::shared_ptr<Data> data_shared_ptr() const {
return array_desc_->data;
}
// Return a raw pointer to the arrays data
template <typename T>
T* data() {
@@ -356,7 +356,7 @@ class array {
}
enum Status {
// The output of a computation which has not been scheduled.
// The ouptut of a computation which has not been scheduled.
// For example, the status of `x` in `auto x = a + b`.
unscheduled,

View File

@@ -99,11 +99,7 @@ inline std::pair<int, int> decompose_hadamard(int n) {
"[hadamard] Only supports n = m*2^k where m in (1, 12, 20, 28).");
}
}
if (n > (1 << 26)) {
throw std::invalid_argument(
"[hadamard] Only supports n = m*2^k where k <= 26");
}
return {n, m};
}
} // namespace mlx::core
} // namespace mlx::core

View File

@@ -40,8 +40,7 @@ add_dependencies(mlx cpu_compiled_preamble)
target_sources(
mlx
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/available.cpp
${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cpp
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cpp
${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp

View File

@@ -1,11 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cpu/available.h"
namespace mlx::core::cpu {
bool is_available() {
return true;
}
} // namespace mlx::core::cpu

View File

@@ -1,9 +0,0 @@
// Copyright © 2025 Apple Inc.
#pragma once
namespace mlx::core::cpu {
bool is_available();
} // namespace mlx::core::cpu

View File

@@ -172,12 +172,9 @@ void binary_float(
case bfloat16:
binary_op<bfloat16_t, Op>(a, b, out, bopt);
break;
case complex64:
binary_op<complex64_t, Op>(a, b, out, bopt);
break;
default:
throw std::runtime_error(
"[binary_float] Only supports floating point types.");
"[binary_float] Only supports non-complex floating point types.");
}
});
}

View File

@@ -40,10 +40,7 @@ struct CompilerCache {
std::shared_mutex mtx;
};
static CompilerCache& cache() {
static CompilerCache cache_;
return cache_;
};
static CompilerCache cache{};
// GPU compile is always available if the GPU is available and since we are in
// this file CPU compile is also available.
@@ -59,16 +56,14 @@ void* compile(
const std::string& kernel_name,
const std::function<std::string(void)>& source_builder) {
{
std::shared_lock lock(cache().mtx);
if (auto it = cache().kernels.find(kernel_name);
it != cache().kernels.end()) {
std::shared_lock lock(cache.mtx);
if (auto it = cache.kernels.find(kernel_name); it != cache.kernels.end()) {
return it->second;
}
}
std::unique_lock lock(cache().mtx);
if (auto it = cache().kernels.find(kernel_name);
it != cache().kernels.end()) {
std::unique_lock lock(cache.mtx);
if (auto it = cache.kernels.find(kernel_name); it != cache.kernels.end()) {
return it->second;
}
std::string source_code = source_builder();
@@ -125,10 +120,10 @@ void* compile(
}
// load library
cache().libs.emplace_back(shared_lib_path);
cache.libs.emplace_back(shared_lib_path);
// Load function
void* fun = dlsym(cache().libs.back().lib, kernel_name.c_str());
void* fun = dlsym(cache.libs.back().lib, kernel_name.c_str());
if (!fun) {
std::ostringstream msg;
msg << "[Compile::eval_cpu] Failed to load compiled function "
@@ -136,7 +131,7 @@ void* compile(
<< dlerror();
throw std::runtime_error(msg.str());
}
cache().kernels.insert({kernel_name, fun});
cache.kernels.insert({kernel_name, fun});
return fun;
}

View File

@@ -22,8 +22,7 @@ void slow_conv_1D(
const array& in,
const array& wt,
array out,
const std::vector<int>& padding_lo,
const std::vector<int>& padding_hi,
const std::vector<int>& padding,
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation,
@@ -61,8 +60,7 @@ void slow_conv_1D(
out_stride_O = out.strides()[2],
flip,
padding_lo = padding_lo[0],
padding_hi = padding_hi[0],
padding = padding[0],
wt_stride = wt_strides[0],
wt_dilation = wt_dilation[0],
in_dilation = in_dilation[0]]() mutable {
@@ -79,7 +77,7 @@ void slow_conv_1D(
const T* wt_ptr = filter_wt_ptr + wh * wt_stride_H;
int wh_flip = flip ? (wH - wh - 1) : wh;
int ih = oh * wt_stride - padding_lo + wh_flip * wt_dilation;
int ih = oh * wt_stride - padding + wh_flip * wt_dilation;
auto ih_div = std::div(ih, in_dilation);
@@ -111,8 +109,7 @@ void slow_conv_2D(
const array& in,
const array& wt,
array out,
const std::vector<int>& padding_lo,
const std::vector<int>& padding_hi,
const std::vector<int>& padding,
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation,
@@ -123,235 +120,230 @@ void slow_conv_2D(
encoder.set_input_array(wt);
encoder.set_output_array(out);
encoder.dispatch(
[st_wt_ptr = wt.data<T>(),
st_in_ptr = in.data<T>(),
st_out_ptr = out.data<T>(),
encoder.dispatch([st_wt_ptr = wt.data<T>(),
st_in_ptr = in.data<T>(),
st_out_ptr = out.data<T>(),
N = in.shape(0), // Batch size, should be the same as out.shape(0)
iH = 1 + in_dilation[0] * (in.shape(1) - 1), // Input spatial dim
iW = 1 + in_dilation[1] * (in.shape(2) - 1), // Input spatial dim
C = in.shape(3), // In channels
oH = out.shape(1), // Output spatial dim
oW = out.shape(2), // Output spatial dim
O = wt.shape(0), // Out channels
wH = wt.shape(1), // Weight spatial dim
wW = wt.shape(2), // Weight spatial dim
N = in.shape(
0), // Batch size, should be the same as out.shape(0)
iH = 1 +
in_dilation[0] * (in.shape(1) - 1), // Input spatial dim
iW = 1 +
in_dilation[1] * (in.shape(2) - 1), // Input spatial dim
C = in.shape(3), // In channels
oH = out.shape(1), // Output spatial dim
oW = out.shape(2), // Output spatial dim
O = wt.shape(0), // Out channels
wH = wt.shape(1), // Weight spatial dim
wW = wt.shape(2), // Weight spatial dim
groups = in.shape(3) / wt.shape(3),
C_per_group = wt.shape(3),
groups = in.shape(3) / wt.shape(3),
C_per_group = wt.shape(3),
in_stride_N = in.strides()[0],
in_stride_H = in.strides()[1],
in_stride_W = in.strides()[2],
in_stride_C = in.strides()[3],
in_stride_N = in.strides()[0],
in_stride_H = in.strides()[1],
in_stride_W = in.strides()[2],
in_stride_C = in.strides()[3],
wt_stride_O = wt.strides()[0],
wt_stride_H = wt.strides()[1],
wt_stride_W = wt.strides()[2],
wt_stride_C = wt.strides()[3],
wt_stride_O = wt.strides()[0],
wt_stride_H = wt.strides()[1],
wt_stride_W = wt.strides()[2],
wt_stride_C = wt.strides()[3],
out_stride_N = out.strides()[0],
out_stride_H = out.strides()[1],
out_stride_W = out.strides()[2],
out_stride_O = out.strides()[3],
out_stride_N = out.strides()[0],
out_stride_H = out.strides()[1],
out_stride_W = out.strides()[2],
out_stride_O = out.strides()[3],
padding_lo,
padding_hi,
wt_strides,
wt_dilation,
in_dilation,
flip]() mutable {
bool is_idil_one = in_dilation[0] == 1 && in_dilation[1] == 1;
padding,
wt_strides,
wt_dilation,
in_dilation,
flip]() mutable {
bool is_idil_one = in_dilation[0] == 1 && in_dilation[1] == 1;
const int O_per_group = O / groups;
auto pt_conv_no_checks =
[&](const T* in_ptr, const T* wt_ptr, T* out_ptr, int oh, int ow) {
out_ptr += oh * out_stride_H + ow * out_stride_W;
int ih_base = oh * wt_strides[0] - padding_lo[0];
int iw_base = ow * wt_strides[1] - padding_lo[1];
const int O_per_group = O / groups;
auto pt_conv_no_checks = [&](const T* in_ptr,
const T* wt_ptr,
T* out_ptr,
int oh,
int ow) {
out_ptr += oh * out_stride_H + ow * out_stride_W;
int ih_base = oh * wt_strides[0] - padding[0];
int iw_base = ow * wt_strides[1] - padding[1];
for (int g = 0; g < groups; ++g) {
for (int o = g * O_per_group; o < (g + 1) * O_per_group; ++o) {
float r = 0.;
for (int g = 0; g < groups; ++g) {
for (int o = g * O_per_group; o < (g + 1) * O_per_group; ++o) {
float r = 0.;
for (int wh = 0; wh < wH; ++wh) {
for (int ww = 0; ww < wW; ++ww) {
int wh_flip = flip ? wH - wh - 1 : wh;
int ww_flip = flip ? wW - ww - 1 : ww;
int ih = ih_base + wh_flip * wt_dilation[0];
int iw = iw_base + ww_flip * wt_dilation[1];
for (int wh = 0; wh < wH; ++wh) {
for (int ww = 0; ww < wW; ++ww) {
int wh_flip = flip ? wH - wh - 1 : wh;
int ww_flip = flip ? wW - ww - 1 : ww;
int ih = ih_base + wh_flip * wt_dilation[0];
int iw = iw_base + ww_flip * wt_dilation[1];
const T* wt_ptr_pt =
wt_ptr + wh * wt_stride_H + ww * wt_stride_W;
const T* in_ptr_pt =
in_ptr + ih * in_stride_H + iw * in_stride_W;
const T* wt_ptr_pt = wt_ptr + wh * wt_stride_H + ww * wt_stride_W;
const T* in_ptr_pt = in_ptr + ih * in_stride_H + iw * in_stride_W;
for (int c = g * C_per_group; c < (g + 1) * C_per_group;
++c) {
r += static_cast<float>(in_ptr_pt[c * in_stride_C]) *
static_cast<float>(
wt_ptr_pt[(c % C_per_group) * wt_stride_C]);
} // c
} // ww
} // wh
for (int c = g * C_per_group; c < (g + 1) * C_per_group; ++c) {
r += static_cast<float>(in_ptr_pt[c * in_stride_C]) *
static_cast<float>(
wt_ptr_pt[(c % C_per_group) * wt_stride_C]);
} // c
} // ww
} // wh
out_ptr[0] = static_cast<T>(r);
out_ptr += out_stride_O;
wt_ptr += wt_stride_O;
} // o
} // g
};
out_ptr[0] = static_cast<T>(r);
out_ptr += out_stride_O;
wt_ptr += wt_stride_O;
} // o
} // g
};
int jump_h = flip ? -wt_dilation[0] : wt_dilation[0];
int jump_w = flip ? -wt_dilation[1] : wt_dilation[1];
int jump_h = flip ? -wt_dilation[0] : wt_dilation[0];
int jump_w = flip ? -wt_dilation[1] : wt_dilation[1];
int init_h = (flip ? (wH - 1) * wt_dilation[0] : 0);
int init_w = (flip ? (wW - 1) * wt_dilation[1] : 0);
int init_h = (flip ? (wH - 1) * wt_dilation[0] : 0);
int init_w = (flip ? (wW - 1) * wt_dilation[1] : 0);
int f_wgt_jump_h =
std::lcm(in_dilation[0], wt_dilation[0]) / wt_dilation[0];
int f_wgt_jump_w =
std::lcm(in_dilation[1], wt_dilation[1]) / wt_dilation[1];
int f_wgt_jump_h =
std::lcm(in_dilation[0], wt_dilation[0]) / wt_dilation[0];
int f_wgt_jump_w =
std::lcm(in_dilation[1], wt_dilation[1]) / wt_dilation[1];
int f_out_jump_h =
std::lcm(in_dilation[0], wt_strides[0]) / wt_strides[0];
int f_out_jump_w =
std::lcm(in_dilation[1], wt_strides[1]) / wt_strides[1];
int f_out_jump_h = std::lcm(in_dilation[0], wt_strides[0]) / wt_strides[0];
int f_out_jump_w = std::lcm(in_dilation[1], wt_strides[1]) / wt_strides[1];
std::vector<int> base_h(f_out_jump_h);
std::vector<int> base_w(f_out_jump_w);
std::vector<int> base_h(f_out_jump_h);
std::vector<int> base_w(f_out_jump_w);
for (int i = 0; i < f_out_jump_h; ++i) {
int ih_loop = i * wt_strides[0] - padding_lo[0] + init_h;
for (int i = 0; i < f_out_jump_h; ++i) {
int ih_loop = i * wt_strides[0] - padding[0] + init_h;
int wh_base = 0;
while (wh_base < wH && ih_loop % in_dilation[0] != 0) {
wh_base++;
ih_loop += jump_h;
}
int wh_base = 0;
while (wh_base < wH && ih_loop % in_dilation[0] != 0) {
wh_base++;
ih_loop += jump_h;
}
base_h[i] = wh_base;
}
base_h[i] = wh_base;
}
for (int j = 0; j < f_out_jump_w; ++j) {
int iw_loop = j * wt_strides[1] - padding_lo[1] + init_w;
for (int j = 0; j < f_out_jump_w; ++j) {
int iw_loop = j * wt_strides[1] - padding[1] + init_w;
int ww_base = 0;
while (ww_base < wW && iw_loop % in_dilation[1] != 0) {
ww_base++;
iw_loop += jump_w;
}
int ww_base = 0;
while (ww_base < wW && iw_loop % in_dilation[1] != 0) {
ww_base++;
iw_loop += jump_w;
}
base_w[j] = ww_base;
}
base_w[j] = ww_base;
}
auto pt_conv_all_checks =
[&](const T* in_ptr, const T* wt_ptr, T* out_ptr, int oh, int ow) {
out_ptr += oh * out_stride_H + ow * out_stride_W;
auto pt_conv_all_checks =
[&](const T* in_ptr, const T* wt_ptr, T* out_ptr, int oh, int ow) {
out_ptr += oh * out_stride_H + ow * out_stride_W;
int ih_base = oh * wt_strides[0] - padding_lo[0];
int iw_base = ow * wt_strides[1] - padding_lo[1];
int ih_base = oh * wt_strides[0] - padding[0];
int iw_base = ow * wt_strides[1] - padding[1];
int wh_base = base_h[oh % f_out_jump_h];
int ww_base = base_w[ow % f_out_jump_w];
int wh_base = base_h[oh % f_out_jump_h];
int ww_base = base_w[ow % f_out_jump_w];
for (int g = 0; g < groups; ++g) {
for (int o = g * O_per_group; o < (g + 1) * O_per_group; ++o) {
float r = 0.;
for (int g = 0; g < groups; ++g) {
for (int o = g * O_per_group; o < (g + 1) * O_per_group; ++o) {
float r = 0.;
for (int wh = wh_base; wh < wH; wh += f_wgt_jump_h) {
for (int ww = ww_base; ww < wW; ww += f_wgt_jump_w) {
int wh_flip = flip ? wH - wh - 1 : wh;
int ww_flip = flip ? wW - ww - 1 : ww;
int ih = ih_base + wh_flip * wt_dilation[0];
int iw = iw_base + ww_flip * wt_dilation[1];
for (int wh = wh_base; wh < wH; wh += f_wgt_jump_h) {
for (int ww = ww_base; ww < wW; ww += f_wgt_jump_w) {
int wh_flip = flip ? wH - wh - 1 : wh;
int ww_flip = flip ? wW - ww - 1 : ww;
int ih = ih_base + wh_flip * wt_dilation[0];
int iw = iw_base + ww_flip * wt_dilation[1];
if (ih >= 0 && ih < iH && iw >= 0 && iw < iW) {
const T* wt_ptr_pt =
wt_ptr + wh * wt_stride_H + ww * wt_stride_W;
if (ih >= 0 && ih < iH && iw >= 0 && iw < iW) {
const T* wt_ptr_pt =
wt_ptr + wh * wt_stride_H + ww * wt_stride_W;
int ih_dil = !is_idil_one ? (ih / in_dilation[0]) : ih;
int iw_dil = !is_idil_one ? (iw / in_dilation[1]) : iw;
int ih_dil = !is_idil_one ? (ih / in_dilation[0]) : ih;
int iw_dil = !is_idil_one ? (iw / in_dilation[1]) : iw;
const T* in_ptr_pt = in_ptr + ih_dil * in_stride_H +
iw_dil * in_stride_W;
const T* in_ptr_pt =
in_ptr + ih_dil * in_stride_H + iw_dil * in_stride_W;
for (int c = g * C_per_group; c < (g + 1) * C_per_group;
++c) {
r += static_cast<float>(in_ptr_pt[c * in_stride_C]) *
static_cast<float>(
wt_ptr_pt[(c % C_per_group) * wt_stride_C]);
} // c
for (int c = g * C_per_group; c < (g + 1) * C_per_group;
++c) {
r += static_cast<float>(in_ptr_pt[c * in_stride_C]) *
static_cast<float>(
wt_ptr_pt[(c % C_per_group) * wt_stride_C]);
} // c
} // ih, iw check
} // ww
} // wh
} // ih, iw check
} // ww
} // wh
out_ptr[0] = static_cast<T>(r);
out_ptr += out_stride_O;
wt_ptr += wt_stride_O;
} // o
} // g
};
out_ptr[0] = static_cast<T>(r);
out_ptr += out_stride_O;
wt_ptr += wt_stride_O;
} // o
} // g
};
int oH_border_0 = 0;
int oH_border_1 = is_idil_one
? ((padding_lo[0] + wt_strides[0] - 1) / wt_strides[0])
: oH;
int oH_border_2 = std::max(
oH_border_1,
(iH + padding_lo[0] - wH * wt_dilation[0]) / wt_strides[0]);
int oH_border_3 = oH;
int oH_border_0 = 0;
int oH_border_1 =
is_idil_one ? ((padding[0] + wt_strides[0] - 1) / wt_strides[0]) : oH;
int oH_border_2 = std::max(
oH_border_1, (iH + padding[0] - wH * wt_dilation[0]) / wt_strides[0]);
int oH_border_3 = oH;
int oW_border_0 = 0;
int oW_border_1 = is_idil_one
? ((padding_lo[1] + wt_strides[1] - 1) / wt_strides[1])
: oW;
int oW_border_2 = std::max(
oW_border_1,
(iW + padding_lo[1] - wW * wt_dilation[1]) / wt_strides[1]);
int oW_border_3 = oW;
int oW_border_0 = 0;
int oW_border_1 =
is_idil_one ? ((padding[1] + wt_strides[1] - 1) / wt_strides[1]) : oW;
int oW_border_2 = std::max(
oW_border_1, (iW + padding[1] - wW * wt_dilation[1]) / wt_strides[1]);
int oW_border_3 = oW;
for (int n = 0; n < N; ++n) {
// Case 1: oh might put us out of bounds
for (int oh = oH_border_0; oh < oH_border_1; ++oh) {
for (int ow = 0; ow < oW; ++ow) {
pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);
} // ow
} // oh
for (int n = 0; n < N; ++n) {
// Case 1: oh might put us out of bounds
for (int oh = oH_border_0; oh < oH_border_1; ++oh) {
for (int ow = 0; ow < oW; ++ow) {
pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);
} // ow
} // oh
// Case 2: oh in bounds
for (int oh = oH_border_1; oh < oH_border_2; ++oh) {
// Case a: ow might put us out of bounds
for (int ow = oW_border_0; ow < oW_border_1; ++ow) {
pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);
} // ow
// Case 2: oh in bounds
for (int oh = oH_border_1; oh < oH_border_2; ++oh) {
// Case a: ow might put us out of bounds
for (int ow = oW_border_0; ow < oW_border_1; ++ow) {
pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);
} // ow
// Case b: ow in bounds
for (int ow = oW_border_1; ow < oW_border_2; ++ow) {
pt_conv_no_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);
} // ow
// Case b: ow in bounds
for (int ow = oW_border_1; ow < oW_border_2; ++ow) {
pt_conv_no_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);
} // ow
// Case c: ow might put us out of bounds
for (int ow = oW_border_2; ow < oW_border_3; ++ow) {
pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);
} // ow
// Case c: ow might put us out of bounds
for (int ow = oW_border_2; ow < oW_border_3; ++ow) {
pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);
} // ow
} // oh
} // oh
// Case 3: oh might put us out of bounds
for (int oh = oH_border_2; oh < oH_border_3; ++oh) {
for (int ow = 0; ow < oW; ++ow) {
pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);
} // ow
} // oh
// Case 3: oh might put us out of bounds
for (int oh = oH_border_2; oh < oH_border_3; ++oh) {
for (int ow = 0; ow < oW; ++ow) {
pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);
} // ow
} // oh
st_in_ptr += in_stride_N;
st_out_ptr += out_stride_N;
st_in_ptr += in_stride_N;
st_out_ptr += out_stride_N;
} // n
});
} // n
});
}
template <typename T>
@@ -359,8 +351,7 @@ void slow_conv_3D(
const array& in,
const array& wt,
array out,
const std::vector<int>& padding_lo,
const std::vector<int>& padding_hi,
const std::vector<int>& padding,
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation,
@@ -409,8 +400,7 @@ void slow_conv_3D(
out_stride_H = out.strides()[2],
out_stride_W = out.strides()[3],
out_stride_O = out.strides()[4],
padding_lo,
padding_hi,
padding,
wt_strides,
wt_dilation,
in_dilation,
@@ -425,9 +415,9 @@ void slow_conv_3D(
int oh,
int ow) {
out_ptr += od * out_stride_D + oh * out_stride_H + ow * out_stride_W;
int id_base = od * wt_strides[0] - padding_lo[0];
int ih_base = oh * wt_strides[1] - padding_lo[1];
int iw_base = ow * wt_strides[2] - padding_lo[2];
int id_base = od * wt_strides[0] - padding[0];
int ih_base = oh * wt_strides[1] - padding[1];
int iw_base = ow * wt_strides[2] - padding[2];
for (int o = 0; o < O; ++o) {
float r = 0.;
@@ -488,7 +478,7 @@ void slow_conv_3D(
std::vector<int> base_w(f_out_jump_w);
for (int i = 0; i < f_out_jump_d; ++i) {
int id_loop = i * wt_strides[0] - padding_lo[0] + init_d;
int id_loop = i * wt_strides[0] - padding[0] + init_d;
int wd_base = 0;
while (wd_base < wD && id_loop % in_dilation[0] != 0) {
@@ -500,7 +490,7 @@ void slow_conv_3D(
}
for (int i = 0; i < f_out_jump_h; ++i) {
int ih_loop = i * wt_strides[1] - padding_lo[1] + init_h;
int ih_loop = i * wt_strides[1] - padding[1] + init_h;
int wh_base = 0;
while (wh_base < wH && ih_loop % in_dilation[1] != 0) {
@@ -512,7 +502,7 @@ void slow_conv_3D(
}
for (int j = 0; j < f_out_jump_w; ++j) {
int iw_loop = j * wt_strides[2] - padding_lo[2] + init_w;
int iw_loop = j * wt_strides[2] - padding[2] + init_w;
int ww_base = 0;
while (ww_base < wW && iw_loop % in_dilation[2] != 0) {
@@ -531,9 +521,9 @@ void slow_conv_3D(
int ow) {
out_ptr += od * out_stride_D + oh * out_stride_H + ow * out_stride_W;
int id_base = od * wt_strides[0] - padding_lo[0];
int ih_base = oh * wt_strides[1] - padding_lo[1];
int iw_base = ow * wt_strides[2] - padding_lo[2];
int id_base = od * wt_strides[0] - padding[0];
int ih_base = oh * wt_strides[1] - padding[1];
int iw_base = ow * wt_strides[2] - padding[2];
int wd_base = base_d[od % f_out_jump_d];
int wh_base = base_h[oh % f_out_jump_h];
@@ -583,30 +573,24 @@ void slow_conv_3D(
};
int oD_border_0 = 0;
int oD_border_1 = is_idil_one
? ((padding_lo[0] + wt_strides[0] - 1) / wt_strides[0])
: oD;
int oD_border_1 =
is_idil_one ? ((padding[0] + wt_strides[0] - 1) / wt_strides[0]) : oD;
int oD_border_2 = std::max(
oD_border_1,
(iD + padding_lo[0] - wD * wt_dilation[0]) / wt_strides[0]);
oD_border_1, (iD + padding[0] - wD * wt_dilation[0]) / wt_strides[0]);
int oD_border_3 = oD;
int oH_border_0 = 0;
int oH_border_1 = is_idil_one
? ((padding_lo[1] + wt_strides[1] - 1) / wt_strides[1])
: oH;
int oH_border_1 =
is_idil_one ? ((padding[1] + wt_strides[1] - 1) / wt_strides[1]) : oH;
int oH_border_2 = std::max(
oH_border_1,
(iH + padding_lo[1] - wH * wt_dilation[1]) / wt_strides[1]);
oH_border_1, (iH + padding[1] - wH * wt_dilation[1]) / wt_strides[1]);
int oH_border_3 = oH;
int oW_border_0 = 0;
int oW_border_1 = is_idil_one
? ((padding_lo[2] + wt_strides[2] - 1) / wt_strides[2])
: oW;
int oW_border_1 =
is_idil_one ? ((padding[2] + wt_strides[2] - 1) / wt_strides[2]) : oW;
int oW_border_2 = std::max(
oW_border_1,
(iW + padding_lo[2] - wW * wt_dilation[2]) / wt_strides[2]);
oW_border_1, (iW + padding[2] - wW * wt_dilation[2]) / wt_strides[2]);
int oW_border_3 = oW;
for (int n = 0; n < N; ++n) {
@@ -674,8 +658,7 @@ void dispatch_slow_conv_1D(
const array& in,
const array& wt,
array out,
const std::vector<int>& padding_lo,
const std::vector<int>& padding_hi,
const std::vector<int>& padding,
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation,
@@ -686,8 +669,7 @@ void dispatch_slow_conv_1D(
in,
wt,
out,
padding_lo,
padding_hi,
padding,
wt_strides,
wt_dilation,
in_dilation,
@@ -698,8 +680,7 @@ void dispatch_slow_conv_1D(
in,
wt,
out,
padding_lo,
padding_hi,
padding,
wt_strides,
wt_dilation,
in_dilation,
@@ -710,8 +691,7 @@ void dispatch_slow_conv_1D(
in,
wt,
out,
padding_lo,
padding_hi,
padding,
wt_strides,
wt_dilation,
in_dilation,
@@ -727,8 +707,7 @@ void dispatch_slow_conv_2D(
const array& in,
const array& wt,
array out,
const std::vector<int>& padding_lo,
const std::vector<int>& padding_hi,
const std::vector<int>& padding,
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation,
@@ -739,8 +718,7 @@ void dispatch_slow_conv_2D(
in,
wt,
out,
padding_lo,
padding_hi,
padding,
wt_strides,
wt_dilation,
in_dilation,
@@ -751,8 +729,7 @@ void dispatch_slow_conv_2D(
in,
wt,
out,
padding_lo,
padding_hi,
padding,
wt_strides,
wt_dilation,
in_dilation,
@@ -763,8 +740,7 @@ void dispatch_slow_conv_2D(
in,
wt,
out,
padding_lo,
padding_hi,
padding,
wt_strides,
wt_dilation,
in_dilation,
@@ -780,8 +756,7 @@ void dispatch_slow_conv_3D(
const array& in,
const array& wt,
array out,
const std::vector<int>& padding_lo,
const std::vector<int>& padding_hi,
const std::vector<int>& padding,
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation,
@@ -792,8 +767,7 @@ void dispatch_slow_conv_3D(
in,
wt,
out,
padding_lo,
padding_hi,
padding,
wt_strides,
wt_dilation,
in_dilation,
@@ -804,8 +778,7 @@ void dispatch_slow_conv_3D(
in,
wt,
out,
padding_lo,
padding_hi,
padding,
wt_strides,
wt_dilation,
in_dilation,
@@ -816,8 +789,7 @@ void dispatch_slow_conv_3D(
in,
wt,
out,
padding_lo,
padding_hi,
padding,
wt_strides,
wt_dilation,
in_dilation,
@@ -857,8 +829,7 @@ void explicit_gemm_conv_1D_cpu(
const array& in,
const array& wt,
array out,
const std::vector<int>& padding_lo,
const std::vector<int>& padding_hi,
const std::vector<int>& padding,
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation,
Stream stream) {
@@ -877,7 +848,7 @@ void explicit_gemm_conv_1D_cpu(
auto& encoder = cpu::get_command_encoder(stream);
// Pad input
Shape padded_shape = {N, iH + padding_lo[0] + padding_hi[0], C};
Shape padded_shape = {N, iH + 2 * padding[0], C};
array in_padded(padded_shape, conv_dtype, nullptr, {});
// Fill with zeros
@@ -886,7 +857,7 @@ void explicit_gemm_conv_1D_cpu(
copy(temps.back(), in_padded, CopyType::Scalar, stream);
// Pick input slice from padded
size_t data_offset = padding_lo[0] * in_padded.strides()[1];
size_t data_offset = padding[0] * in_padded.strides()[1];
array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {});
in_padded_slice.copy_shared_buffer(
in_padded,
@@ -1000,8 +971,7 @@ void explicit_gemm_conv_2D_cpu(
const array& in,
const array& wt,
array out,
const std::vector<int>& padding_lo,
const std::vector<int>& padding_hi,
const std::vector<int>& padding,
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation,
Stream stream) {
@@ -1019,11 +989,7 @@ void explicit_gemm_conv_2D_cpu(
auto& encoder = cpu::get_command_encoder(stream);
// Pad input
Shape padded_shape = {
N,
iH + padding_lo[0] + padding_hi[0],
iW + padding_lo[1] + padding_hi[1],
C};
Shape padded_shape = {N, iH + 2 * padding[0], iW + 2 * padding[1], C};
array in_padded(padded_shape, conv_dtype, nullptr, {});
// Fill with zeros
@@ -1032,8 +998,8 @@ void explicit_gemm_conv_2D_cpu(
copy(temps.back(), in_padded, CopyType::Scalar, stream);
// Pick input slice from padded
size_t data_offset = padding_lo[0] * in_padded.strides()[1] +
padding_lo[1] * in_padded.strides()[2];
size_t data_offset =
padding[0] * in_padded.strides()[1] + padding[1] * in_padded.strides()[2];
array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {});
in_padded_slice.copy_shared_buffer(
in_padded,
@@ -1125,8 +1091,7 @@ void explicit_gemm_conv_ND_cpu(
const array& in,
const array& wt,
array out,
const std::vector<int>& padding_lo,
const std::vector<int>& padding_hi,
const std::vector<int>& padding,
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation,
const bool flip,
@@ -1149,7 +1114,7 @@ void explicit_gemm_conv_ND_cpu(
Shape padded_shape(in.shape().size());
padded_shape.front() = N;
for (size_t i = 0; i < iDim.size(); i++) {
padded_shape[i + 1] = iDim[i] + padding_lo[i] + padding_hi[i];
padded_shape[i + 1] = iDim[i] + 2 * padding[i];
}
padded_shape.back() = C;
array in_padded(padded_shape, conv_dtype, nullptr, {});
@@ -1160,10 +1125,9 @@ void explicit_gemm_conv_ND_cpu(
// Pick input slice from padded
size_t data_offset = 0;
for (size_t i = 0; i < padding_lo.size(); i++) {
data_offset += padding_lo[i] * in_padded.strides()[i + 1];
for (size_t i = 0; i < padding.size(); i++) {
data_offset += padding[i] * in_padded.strides()[i + 1];
}
array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {});
in_padded_slice.copy_shared_buffer(
in_padded,
@@ -1297,8 +1261,7 @@ void conv_1D_cpu(
const array& in,
const array& wt,
array out,
const std::vector<int>& padding_lo,
const std::vector<int>& padding_hi,
const std::vector<int>& padding,
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation,
@@ -1307,40 +1270,22 @@ void conv_1D_cpu(
const int groups = in.shape().back() / wt.shape().back();
if (wt_dilation[0] == 1 && in_dilation[0] == 1 && !flip) {
return explicit_gemm_conv_1D_cpu(
in, wt, out, padding_lo, padding_hi, wt_strides, wt_dilation, stream);
in, wt, out, padding, wt_strides, wt_dilation, stream);
}
if (wt_dilation[0] == 1 && in_dilation[0] == 1 && groups == 1) {
return explicit_gemm_conv_ND_cpu(
in,
wt,
out,
padding_lo,
padding_hi,
wt_strides,
wt_dilation,
flip,
stream);
in, wt, out, padding, wt_strides, wt_dilation, flip, stream);
}
return dispatch_slow_conv_1D(
in,
wt,
out,
padding_lo,
padding_hi,
wt_strides,
wt_dilation,
in_dilation,
flip,
stream);
in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip, stream);
}
void conv_2D_cpu(
const array& in,
const array& wt,
array out,
const std::vector<int>& padding_lo,
const std::vector<int>& padding_hi,
const std::vector<int>& padding,
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation,
@@ -1350,35 +1295,18 @@ void conv_2D_cpu(
if (wt_dilation[0] == 1 && wt_dilation[1] == 1 && in_dilation[0] == 1 &&
in_dilation[1] == 1 && groups == 1) {
return explicit_gemm_conv_ND_cpu(
in,
wt,
out,
padding_lo,
padding_hi,
wt_strides,
wt_dilation,
flip,
stream);
in, wt, out, padding, wt_strides, wt_dilation, flip, stream);
}
return dispatch_slow_conv_2D(
in,
wt,
out,
padding_lo,
padding_hi,
wt_strides,
wt_dilation,
in_dilation,
flip,
stream);
in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip, stream);
}
void conv_3D_cpu(
const array& in,
const array& wt,
array out,
const std::vector<int>& padding_lo,
const std::vector<int>& padding_hi,
const std::vector<int>& padding,
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation,
@@ -1389,28 +1317,11 @@ void conv_3D_cpu(
in_dilation[0] == 1 && in_dilation[1] == 1 && in_dilation[2] == 1 &&
groups == 1) {
return explicit_gemm_conv_ND_cpu(
in,
wt,
out,
padding_lo,
padding_hi,
wt_strides,
wt_dilation,
flip,
stream);
in, wt, out, padding, wt_strides, wt_dilation, flip, stream);
}
return dispatch_slow_conv_3D(
in,
wt,
out,
padding_lo,
padding_hi,
wt_strides,
wt_dilation,
in_dilation,
flip,
stream);
in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip, stream);
}
} // namespace
@@ -1427,8 +1338,7 @@ void Convolution::eval_cpu(const std::vector<array>& inputs, array& out) {
in,
wt,
out,
padding_lo_,
padding_hi_,
padding_,
kernel_strides_,
kernel_dilation_,
input_dilation_,
@@ -1441,8 +1351,7 @@ void Convolution::eval_cpu(const std::vector<array>& inputs, array& out) {
in,
wt,
out,
padding_lo_,
padding_hi_,
padding_,
kernel_strides_,
kernel_dilation_,
input_dilation_,
@@ -1455,8 +1364,7 @@ void Convolution::eval_cpu(const std::vector<array>& inputs, array& out) {
in,
wt,
out,
padding_lo_,
padding_hi_,
padding_,
kernel_strides_,
kernel_dilation_,
input_dilation_,

View File

@@ -330,8 +330,7 @@ void Scan::eval_cpu(const std::vector<array>& inputs, array& out) {
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case complex64:
scan_dispatch<complex64_t, complex64_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
throw std::runtime_error("Scan ops do not support complex types yet");
break;
}
});

View File

@@ -88,33 +88,12 @@ DEFAULT_UNARY(expm1, std::expm1)
DEFAULT_UNARY(floor, std::floor)
DEFAULT_UNARY(log, std::log)
DEFAULT_UNARY(log10, std::log10)
DEFAULT_UNARY(log1p, std::log1p)
DEFAULT_UNARY(sinh, std::sinh)
DEFAULT_UNARY(sqrt, std::sqrt)
DEFAULT_UNARY(tan, std::tan)
DEFAULT_UNARY(tanh, std::tanh)
template <typename T>
Simd<T, 1> log1p(Simd<T, 1> in) {
if constexpr (is_complex<T>) {
auto x = in.value.real();
auto y = in.value.imag();
auto zabs = std::abs(in.value);
auto theta = std::atan2(y, x + 1);
if (zabs < 0.5) {
auto r = x * (2 + x) + y * y;
if (r == 0) { // handle underflow
return Simd<T, 1>{T{x, theta}};
}
return Simd<T, 1>{T{((typeof(x))(0.5)) * std::log1p(r), theta}};
} else {
auto z0 = std::hypot(x + 1, y);
return Simd<T, 1>{T{std::log(z0), theta}};
}
} else {
return Simd<T, 1>{std::log1p(in.value)};
}
}
template <typename T>
Simd<T, 1> log2(Simd<T, 1> in) {
if constexpr (is_complex<T>) {

View File

@@ -1,57 +0,0 @@
# Filename rules in cuda backend:
#
# * Use .cu/.cuh if code contains device code, and .cpp/.h if not.
# * Device-only kernel code should be put in kernels/ subdir.
# * Files in kernels/ subdir should not include files outside.
target_sources(
mlx
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
${CMAKE_CURRENT_SOURCE_DIR}/event.cu
${CMAKE_CURRENT_SOURCE_DIR}/fence.cu
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cu
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp)
target_compile_definitions(mlx PUBLIC MLX_USE_CUDA)
# Enable defining device lambda functions.
target_compile_options(mlx
PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--extended-lambda>")
# Compute capability 7 is required for synchronization between CPU/GPU with
# managed memory. TODO: Add more architectures for potential performance gain.
set(MLX_CUDA_ARCHITECTURES
"75;80"
CACHE STRING "CUDA architectures")
message(STATUS "CUDA architectures: ${MLX_CUDA_ARCHITECTURES}")
set_target_properties(mlx PROPERTIES CUDA_ARCHITECTURES
"${MLX_CUDA_ARCHITECTURES}")
# Use fixed version of CCCL.
FetchContent_Declare(
cccl
URL "https://github.com/NVIDIA/cccl/releases/download/v2.8.1/cccl-v2.8.1.zip")
FetchContent_MakeAvailable(cccl)
target_include_directories(mlx PRIVATE BEFORE "${cccl_SOURCE_DIR}/include")
# Use fixed version of NVTX.
FetchContent_Declare(
nvtx3
GIT_REPOSITORY https://github.com/NVIDIA/NVTX.git
GIT_TAG v3.1.1
GIT_SHALLOW TRUE
SOURCE_SUBDIR c EXCLUDE_FROM_ALL)
FetchContent_MakeAvailable(nvtx3)
target_link_libraries(mlx PUBLIC $<BUILD_INTERFACE:nvtx3-cpp>)
# Make cuda runtime APIs available in non-cuda files.
find_package(CUDAToolkit REQUIRED)
target_include_directories(mlx PRIVATE ${CUDAToolkit_INCLUDE_DIRS})
# Suppress nvcc warnings on MLX headers.
target_compile_options(mlx PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xcudafe
--diag_suppress=997>)

View File

@@ -1,154 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/allocator.h"
#include "mlx/backend/cuda/utils.h"
#include "mlx/backend/cuda/worker.h"
#include <cuda_runtime.h>
#include <fmt/format.h>
#include <cassert>
namespace mlx::core {
namespace cu {
CudaAllocator::CudaAllocator() {
// TODO: Set memory limit for multi-device.
size_t free, total;
CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total));
memory_limit_ = total * 0.8;
}
Buffer CudaAllocator::malloc(size_t size) {
// TODO: Check memory limit.
auto* buf = new CudaBuffer{nullptr, size};
cudaError_t err = cudaMallocManaged(&buf->data, size);
if (err != cudaSuccess && err != cudaErrorMemoryAllocation) {
throw std::runtime_error(
fmt::format("cudaMallocManaged failed: {}.", cudaGetErrorString(err)));
}
std::lock_guard lock(mutex_);
active_memory_ += size;
peak_memory_ = std::max(active_memory_, peak_memory_);
return Buffer{buf};
}
void CudaAllocator::free(Buffer buffer) {
auto* buf = static_cast<CudaBuffer*>(buffer.ptr());
if (!buf) {
return;
}
// If free() is called from a unregistered thread, reschedule the call to
// worker.
{
std::lock_guard lock(worker_mutex_);
if (allowed_threads_.count(std::this_thread::get_id()) == 0) {
if (!worker_) {
worker_.reset(new Worker);
}
worker_->add_task([buffer]() { allocator().free(buffer); });
worker_->end_batch();
worker_->commit();
return;
}
}
size_t size = buf->size;
cudaFree(buf->data);
delete buf;
std::lock_guard lock(mutex_);
active_memory_ -= size;
}
size_t CudaAllocator::size(Buffer buffer) const {
auto* buf = static_cast<CudaBuffer*>(buffer.ptr());
if (!buf) {
return 0;
}
return buf->size;
}
void CudaAllocator::register_this_thread() {
std::lock_guard lock(worker_mutex_);
allowed_threads_.insert(std::this_thread::get_id());
}
size_t CudaAllocator::get_active_memory() const {
return active_memory_;
}
size_t CudaAllocator::get_peak_memory() const {
return peak_memory_;
}
void CudaAllocator::reset_peak_memory() {
std::lock_guard lock(mutex_);
peak_memory_ = 0;
}
size_t CudaAllocator::get_memory_limit() {
return memory_limit_;
}
size_t CudaAllocator::set_memory_limit(size_t limit) {
std::lock_guard lock(mutex_);
std::swap(limit, memory_limit_);
return limit;
}
CudaAllocator& allocator() {
// By creating the |allocator_| on heap, the destructor of CudaAllocator
// will not be called on exit and buffers in the cache will be leaked. This
// can save some time at program exit.
static CudaAllocator* allocator_ = new CudaAllocator;
return *allocator_;
}
} // namespace cu
namespace allocator {
Allocator& allocator() {
return cu::allocator();
}
void* Buffer::raw_ptr() {
if (!ptr_) {
return nullptr;
}
return static_cast<cu::CudaBuffer*>(ptr_)->data;
}
} // namespace allocator
size_t get_active_memory() {
return cu::allocator().get_active_memory();
}
size_t get_peak_memory() {
return cu::allocator().get_peak_memory();
}
void reset_peak_memory() {
return cu::allocator().reset_peak_memory();
}
size_t set_memory_limit(size_t limit) {
return cu::allocator().set_memory_limit(limit);
}
size_t get_memory_limit() {
return cu::allocator().get_memory_limit();
}
// TODO: Implement buffer cache.
size_t get_cache_memory() {
return 0;
}
size_t set_cache_limit(size_t) {
return 0;
}
size_t set_wired_limit(size_t) {
return 0;
}
void clear_cache() {}
} // namespace mlx::core

View File

@@ -1,58 +0,0 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include "mlx/allocator.h"
#include <mutex>
#include <set>
#include <thread>
#include <utility>
namespace mlx::core::cu {
class Worker;
using allocator::Buffer;
// Stores cuda-managed unified memory.
struct CudaBuffer {
void* data;
size_t size;
};
class CudaAllocator : public allocator::Allocator {
public:
Buffer malloc(size_t size) override;
void free(Buffer buffer) override;
size_t size(Buffer buffer) const override;
// Register current thread as safe to free buffers.
// In cuda freeing a buffer implicitly synchronizes stream, and for threads
// that may be waited by gpu stream (for example cpu stream threads), freeing
// buffers there would result in dead lock.
void register_this_thread();
size_t get_active_memory() const;
size_t get_peak_memory() const;
void reset_peak_memory();
size_t get_memory_limit();
size_t set_memory_limit(size_t limit);
private:
CudaAllocator();
friend CudaAllocator& allocator();
std::mutex worker_mutex_;
std::unique_ptr<Worker> worker_;
std::set<std::thread::id> allowed_threads_;
std::mutex mutex_;
size_t memory_limit_;
size_t active_memory_{0};
size_t peak_memory_{0};
};
CudaAllocator& allocator();
} // namespace mlx::core::cu

View File

@@ -1,26 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/gpu/copy.h"
namespace mlx::core {
void copy_gpu_inplace(
const array& in,
array& out,
const Shape& data_shape,
const Strides& strides_in_pre,
const Strides& strides_out_pre,
int64_t inp_offset,
int64_t out_offset,
CopyType ctype,
const Stream& s,
const std::optional<array>& dynamic_i_offset /* = std::nullopt */,
const std::optional<array>& dynamic_o_offset /* = std::nullopt */) {
throw std::runtime_error("copy_gpu_inplace not implemented in CUDA backend.");
}
void fill_gpu(const array& val, array& out, const Stream& s) {
throw std::runtime_error("fill_gpu not implemented in CUDA backend.");
}
} // namespace mlx::core

View File

@@ -1,117 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/worker.h"
#include "mlx/backend/metal/metal.h"
#include <fmt/format.h>
#include <nvtx3/nvtx3.hpp>
namespace mlx::core {
namespace cu {
DeviceStream::DeviceStream(Device& device) : device_(device), stream_(device) {}
void DeviceStream::synchronize() {
cudaStreamSynchronize(stream_);
}
cudaStream_t DeviceStream::schedule_cuda_stream() {
// TODO: Return a stream that maximizes parallelism.
return stream_;
}
cudaStream_t DeviceStream::last_cuda_stream() {
return stream_;
}
CommandEncoder& DeviceStream::get_encoder() {
if (!encoder_) {
encoder_ = std::make_unique<CommandEncoder>(*this);
}
return *encoder_;
}
Device::Device(int device) : device_(device) {
// Validate the requirements of device.
int attr = 0;
cudaDeviceGetAttribute(&attr, cudaDevAttrConcurrentManagedAccess, device_);
if (attr != 1) {
throw std::runtime_error(fmt::format(
"Device {} does not support synchronization in managed memory.",
device_));
}
}
void Device::make_current() {
// We need to set/get current CUDA device very frequently, cache it to reduce
// actual calls of CUDA APIs. This function assumes single-thread in host.
static int current = 0;
if (current != device_) {
CHECK_CUDA_ERROR(cudaSetDevice(device_));
current = device_;
}
}
DeviceStream& Device::get_stream(Stream s) {
auto it = streams_.find(s.index);
if (it == streams_.end()) {
it = streams_.try_emplace(s.index, *this).first;
}
return it->second;
}
CommandEncoder::CommandEncoder(DeviceStream& s)
: device_(s.device()), stream_(s) {}
void CommandEncoder::add_completed_handler(std::function<void()> task) {
worker_.add_task(std::move(task));
}
void CommandEncoder::end_encoding() {
if (!temporaries_.empty()) {
add_completed_handler([temporaries = std::move(temporaries_)]() {});
}
// There is no kernel running, run completion handlers immediately.
if (!has_gpu_work_) {
worker_.consume_in_this_thread();
return;
}
has_gpu_work_ = false;
// Put completion handlers in a batch.
worker_.end_batch();
// Signaling kernel completion is expensive, delay until enough batches.
// TODO: This number is arbitrarily picked, profile for a better stragety.
if (worker_.uncommited_batches() > 8) {
commit();
}
}
void CommandEncoder::commit() {
worker_.commit(stream_.last_cuda_stream());
}
Device& device(mlx::core::Device device) {
static std::unordered_map<int, Device> devices;
auto it = devices.find(device.index);
if (it == devices.end()) {
it = devices.try_emplace(device.index, device.index).first;
}
return it->second;
}
DeviceStream& get_stream(Stream s) {
return device(s.device).get_stream(s);
}
CommandEncoder& get_command_encoder(Stream s) {
return get_stream(s).get_encoder();
}
} // namespace cu
} // namespace mlx::core

View File

@@ -1,131 +0,0 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include "mlx/array.h"
#include "mlx/backend/cuda/worker.h"
#include "mlx/stream.h"
#include <thrust/execution_policy.h>
#include <unordered_map>
namespace mlx::core::cu {
class Device;
class CommandEncoder;
class DeviceStream {
public:
explicit DeviceStream(Device& device);
DeviceStream(const DeviceStream&) = delete;
DeviceStream& operator=(const DeviceStream&) = delete;
// Wait until kernels in the stream complete.
void synchronize();
// Return a cuda stream for launching kernels.
cudaStream_t schedule_cuda_stream();
// Return the last cuda stream used.
cudaStream_t last_cuda_stream();
CommandEncoder& get_encoder();
Device& device() {
return device_;
}
private:
Device& device_;
CudaStream stream_;
std::unique_ptr<CommandEncoder> encoder_;
};
class Device {
public:
explicit Device(int device);
Device(const Device&) = delete;
Device& operator=(const Device&) = delete;
// Make this device the current cuda device, required by some cuda calls.
void make_current();
DeviceStream& get_stream(Stream s);
int cuda_device() const {
return device_;
}
private:
int device_;
std::unordered_map<int, DeviceStream> streams_;
};
class CommandEncoder {
public:
explicit CommandEncoder(DeviceStream& stream);
CommandEncoder(const CommandEncoder&) = delete;
CommandEncoder& operator=(const CommandEncoder&) = delete;
void set_input_array(const array& arr) {}
void set_output_array(const array& arr) {}
void add_temporary(const array& arr) {
temporaries_.push_back(arr.data_shared_ptr());
}
void add_completed_handler(std::function<void()> task);
void end_encoding();
void commit();
// Schedule a cuda stream for |fun| to launch kernels, and check error
// afterwards.
template <typename F>
void launch_kernel(F&& fun) {
launch_kernel(stream_.schedule_cuda_stream(), std::forward<F>(fun));
}
template <typename F>
void launch_kernel(cudaStream_t stream, F&& fun) {
device_.make_current();
fun(stream);
check_cuda_error("kernel launch", cudaGetLastError());
has_gpu_work_ = true;
}
Device& device() {
return device_;
}
DeviceStream& stream() {
return stream_;
}
bool has_gpu_work() const {
return has_gpu_work_;
}
private:
Device& device_;
DeviceStream& stream_;
Worker worker_;
bool has_gpu_work_{false};
std::vector<std::shared_ptr<array::Data>> temporaries_;
};
Device& device(mlx::core::Device device);
DeviceStream& get_stream(Stream s);
CommandEncoder& get_command_encoder(Stream s);
// Return an execution policy that does not sync for result.
// Note that not all thrust APIs support async policy, confirm before using.
inline auto thrust_policy(cudaStream_t stream) {
// TODO: Connect thrust's custom allocator with mlx's allocator.
return thrust::cuda::par_nosync.on(stream);
}
} // namespace mlx::core::cu

View File

@@ -1,35 +0,0 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include <cuComplex.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
namespace mlx::core {
// Maps CPU types to CUDA types.
template <typename T>
struct CTypeToCudaType {
using type = T;
};
template <>
struct CTypeToCudaType<float16_t> {
using type = __half;
};
template <>
struct CTypeToCudaType<bfloat16_t> {
using type = __nv_bfloat16;
};
template <>
struct CTypeToCudaType<complex64_t> {
using type = cuComplex;
};
template <typename T>
using cuda_type_t = typename CTypeToCudaType<T>::type;
} // namespace mlx::core

View File

@@ -1,68 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/gpu/eval.h"
#include "mlx/backend/cuda/allocator.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/gpu/available.h"
#include "mlx/primitives.h"
#include <nvtx3/nvtx3.hpp>
namespace mlx::core::gpu {
bool is_available() {
return true;
}
void new_stream(Stream s) {
// Force initalization of cuda, so cuda runtime get destroyed at last.
cudaFree(nullptr);
// Ensure the static stream objects get created.
cu::get_command_encoder(s);
// The main thread is safe to free buffers.
cu::allocator().register_this_thread();
}
void eval(array& arr) {
nvtx3::scoped_range r("gpu::eval");
auto outputs = arr.outputs();
{
// If the array is a tracer hold a reference
// to its inputs so they don't get donated
std::vector<array> inputs;
if (arr.is_tracer()) {
inputs = arr.inputs();
}
arr.primitive().eval_gpu(arr.inputs(), outputs);
}
auto& encoder = cu::get_command_encoder(arr.primitive().stream());
if (encoder.has_gpu_work()) {
// Keep used buffers alive until kernel finishes running.
std::unordered_set<std::shared_ptr<array::Data>> buffers;
for (auto& in : arr.inputs()) {
buffers.insert(in.data_shared_ptr());
}
for (auto& s : arr.siblings()) {
buffers.insert(s.data_shared_ptr());
}
// Remove the output if it was donated to by an input.
if (auto it = buffers.find(arr.data_shared_ptr()); it != buffers.end()) {
buffers.erase(it);
}
encoder.add_completed_handler([buffers = std::move(buffers)]() {});
}
encoder.end_encoding();
}
void finalize(Stream s) {
nvtx3::scoped_range r("gpu::finalize");
cu::get_command_encoder(s).commit();
}
void synchronize(Stream s) {
nvtx3::scoped_range r("gpu::synchronize");
cu::get_stream(s).synchronize();
}
} // namespace mlx::core::gpu

View File

@@ -1,265 +0,0 @@
// Copyright © 2024 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/event.h"
#include "mlx/backend/cuda/utils.h"
#include "mlx/event.h"
#include "mlx/scheduler.h"
#include <nvtx3/nvtx3.hpp>
namespace mlx::core {
namespace cu {
///////////////////////////////////////////////////////////////////////////////
// CudaEvent implementations
///////////////////////////////////////////////////////////////////////////////
// Cuda event managed with RAII.
class CudaEventHandle {
public:
CudaEventHandle() {
CHECK_CUDA_ERROR(cudaEventCreateWithFlags(
&event_, cudaEventDisableTiming | cudaEventBlockingSync));
}
~CudaEventHandle() {
CHECK_CUDA_ERROR(cudaEventDestroy(event_));
}
CudaEventHandle(const CudaEventHandle&) = delete;
CudaEventHandle& operator=(const CudaEventHandle&) = delete;
operator cudaEvent_t() const {
return event_;
}
private:
cudaEvent_t event_;
};
CudaEvent::CudaEvent() : event_(std::make_shared<CudaEventHandle>()) {}
void CudaEvent::wait() {
nvtx3::scoped_range r("cu::CudaEvent::wait");
if (!recorded_) {
throw std::runtime_error("Should not wait on a CudaEvent before record.");
}
cudaEventSynchronize(*event_);
}
void CudaEvent::wait(cudaStream_t stream) {
if (!recorded_) {
throw std::runtime_error("Should not wait on a CudaEvent before record.");
}
cudaStreamWaitEvent(stream, *event_);
}
void CudaEvent::wait(Stream s) {
if (s.device == mlx::core::Device::cpu) {
scheduler::enqueue(s, [*this]() mutable { wait(); });
} else {
wait(cu::get_stream(s).last_cuda_stream());
}
}
void CudaEvent::record(cudaStream_t stream) {
cudaEventRecord(*event_, stream);
recorded_ = true;
}
void CudaEvent::record(Stream s) {
if (s.device == mlx::core::Device::cpu) {
throw std::runtime_error("CudaEvent can not wait on cpu stream.");
} else {
record(cu::get_stream(s).last_cuda_stream());
}
}
bool CudaEvent::completed() const {
return cudaEventQuery(*event_) == cudaSuccess;
}
///////////////////////////////////////////////////////////////////////////////
// SharedEvent implementations
///////////////////////////////////////////////////////////////////////////////
namespace {
__host__ __device__ void event_wait(SharedEvent::Atomic* ac, uint64_t value) {
uint64_t current;
while ((current = ac->load()) < value) {
ac->wait(current);
}
}
__host__ __device__ void event_signal(SharedEvent::Atomic* ac, uint64_t value) {
ac->store(value);
ac->notify_all();
}
__global__ void event_wait_kernel(SharedEvent::Atomic* ac, uint64_t value) {
event_wait(ac, value);
}
__global__ void event_signal_kernel(SharedEvent::Atomic* ac, uint64_t value) {
event_signal(ac, value);
}
} // namespace
SharedEvent::SharedEvent() {
// Allocate cuda::atomic on managed memory.
allocator::Buffer buffer = allocator::malloc(sizeof(Atomic));
Atomic* ac = static_cast<Atomic*>(buffer.raw_ptr());
new (ac) Atomic(0);
ac_ = std::shared_ptr<Atomic>(ac, [buffer](Atomic* ptr) {
ptr->~Atomic();
allocator::free(buffer);
});
}
void SharedEvent::wait(uint64_t value) {
nvtx3::scoped_range r("cu::SharedEvent::wait");
event_wait(ac_.get(), value);
}
void SharedEvent::wait(cudaStream_t stream, uint64_t value) {
event_wait_kernel<<<1, 1, 0, stream>>>(ac_.get(), value);
}
void SharedEvent::wait(Stream s, uint64_t value) {
nvtx3::scoped_range r("cu::SharedEvent::wait(s)");
if (s.device == mlx::core::Device::cpu) {
scheduler::enqueue(s, [*this, value]() mutable { wait(value); });
} else {
auto& encoder = get_command_encoder(s);
encoder.launch_kernel(
encoder.stream().last_cuda_stream(),
[this, value](cudaStream_t stream) { wait(stream, value); });
encoder.add_completed_handler([ac = ac_]() {});
encoder.end_encoding();
}
}
void SharedEvent::signal(uint64_t value) {
nvtx3::scoped_range r("cu::SharedEvent::signal");
event_signal(ac_.get(), value);
}
void SharedEvent::signal(cudaStream_t stream, uint64_t value) {
event_signal_kernel<<<1, 1, 0, stream>>>(ac_.get(), value);
}
void SharedEvent::signal(Stream s, uint64_t value) {
nvtx3::scoped_range r("cu::SharedEvent::signal(s)");
if (s.device == mlx::core::Device::cpu) {
scheduler::enqueue(s, [*this, value]() mutable { signal(value); });
} else {
auto& encoder = get_command_encoder(s);
encoder.launch_kernel(
encoder.stream().last_cuda_stream(),
[this, value](cudaStream_t stream) { signal(stream, value); });
encoder.add_completed_handler([ac = ac_]() {});
encoder.end_encoding();
}
}
bool SharedEvent::is_signaled(uint64_t value) const {
nvtx3::scoped_range r("cu::SharedEvent::is_signaled");
return ac_->load() >= value;
}
uint64_t SharedEvent::value() const {
nvtx3::scoped_range r("cu::SharedEvent::value");
return ac_->load();
}
} // namespace cu
///////////////////////////////////////////////////////////////////////////////
// Event implementations
///////////////////////////////////////////////////////////////////////////////
namespace {
struct EventImpl {
// CudaEvent is preferred when possible because it is fast, however we have
// to fallback to SharedEvent in following cases:
// 1. the event is used to wait/signal a cpu stream;
// 2. signal value other than 1 has been specified.
std::unique_ptr<cu::CudaEvent> cuda;
std::unique_ptr<cu::SharedEvent> shared;
bool is_created() const {
return cuda || shared;
}
void ensure_created(Stream s, uint64_t signal_value) {
if (is_created()) {
return;
}
if (s.device == mlx::core::Device::cpu || signal_value > 1) {
nvtx3::mark("Using slow SharedEvent");
shared = std::make_unique<cu::SharedEvent>();
} else {
cuda = std::make_unique<cu::CudaEvent>();
}
}
};
} // namespace
Event::Event(Stream s) : stream_(s) {
event_ = std::shared_ptr<void>(
new EventImpl(), [](void* ptr) { delete static_cast<EventImpl*>(ptr); });
}
void Event::wait() {
auto* event = static_cast<EventImpl*>(event_.get());
assert(event->is_created());
if (event->cuda) {
assert(value() == 1);
event->cuda->wait();
} else {
event->shared->wait(value());
}
}
void Event::wait(Stream s) {
auto* event = static_cast<EventImpl*>(event_.get());
assert(event->is_created());
if (event->cuda) {
assert(value() == 1);
event->cuda->wait(s);
} else {
event->shared->wait(s, value());
}
}
void Event::signal(Stream s) {
auto* event = static_cast<EventImpl*>(event_.get());
event->ensure_created(s, value());
if (event->cuda) {
assert(value() == 1);
event->cuda->record(s);
} else {
event->shared->signal(s, value());
}
}
bool Event::is_signaled() const {
auto* event = static_cast<EventImpl*>(event_.get());
if (!event->is_created()) {
return false;
}
if (event->cuda) {
assert(value() == 1);
return event->cuda->recorded() && event->cuda->completed();
} else {
return event->shared->is_signaled(value());
}
}
} // namespace mlx::core

View File

@@ -1,66 +0,0 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include "mlx/stream.h"
#include <cuda_runtime.h>
#include <cuda/atomic>
#include <memory>
namespace mlx::core::cu {
class CudaEventHandle;
// Wrapper of native cuda event. It can synchronize between GPU streams, or wait
// on GPU stream in CPU stream, but can not wait on CPU stream.
class CudaEvent {
public:
CudaEvent();
void wait();
void wait(cudaStream_t stream);
void wait(Stream s);
void record(cudaStream_t stream);
void record(Stream s);
// Return whether the recorded kernels have completed. Note that this method
// returns true if record() has not been called.
bool completed() const;
bool recorded() const {
return recorded_;
}
private:
bool recorded_{false};
std::shared_ptr<CudaEventHandle> event_;
};
// Event that can synchronize between CPU and GPU. It is much slower than
// CudaEvent so the latter should always be preferred when possible.
class SharedEvent {
public:
using Atomic = cuda::atomic<uint64_t>;
SharedEvent();
void wait(uint64_t value);
void wait(cudaStream_t stream, uint64_t value);
void wait(Stream s, uint64_t value);
void signal(uint64_t value);
void signal(cudaStream_t stream, uint64_t value);
void signal(Stream s, uint64_t value);
bool is_signaled(uint64_t value) const;
uint64_t value() const;
const std::shared_ptr<Atomic>& atomic() const {
return ac_;
}
private:
std::shared_ptr<Atomic> ac_;
};
} // namespace mlx::core::cu

View File

@@ -1,70 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/event.h"
#include "mlx/fence.h"
#include "mlx/scheduler.h"
#include <nvtx3/nvtx3.hpp>
namespace mlx::core {
namespace {
__host__ __device__ void busy_wait(cuda::atomic<uint64_t>* ac, uint64_t value) {
while (true) {
// In theory the atomic_thread_fence is not needed, but for CUDA 11 without
// it the load() may never return new value.
cuda::atomic_thread_fence(cuda::memory_order_seq_cst);
uint64_t current = ac->load();
if (current >= value) {
break;
}
}
}
__global__ void busy_wait_kernel(cuda::atomic<uint64_t>* ac, uint64_t value) {
busy_wait(ac, value);
}
} // namespace
struct FenceImpl {
uint32_t count;
cu::SharedEvent event;
};
Fence::Fence(Stream s) {
fence_ = std::shared_ptr<void>(
new FenceImpl{0}, [](void* ptr) { delete static_cast<FenceImpl*>(ptr); });
}
void Fence::wait(Stream s, const array&) {
auto* fence = static_cast<FenceImpl*>(fence_.get());
// We can't use SharedEvent::wait because it could hang in CUDA 11, see also:
// https://github.com/ml-explore/mlx/issues/2137
const auto& ac = fence->event.atomic();
if (s.device == mlx::core::Device::cpu) {
scheduler::enqueue(s, [ac, count = fence->count]() {
nvtx3::scoped_range r("Fence::wait()");
busy_wait(ac.get(), count);
});
} else {
nvtx3::scoped_range r("Fence::wait(s)");
auto& encoder = cu::get_command_encoder(s);
encoder.launch_kernel(
encoder.stream().last_cuda_stream(), [&](cudaStream_t stream) {
busy_wait_kernel<<<1, 1, 0>>>(ac.get(), fence->count);
});
encoder.add_completed_handler([ac]() {});
encoder.end_encoding();
}
}
void Fence::update(Stream s, const array&) {
auto* fence = static_cast<FenceImpl*>(fence_.get());
fence->count++;
fence->event.signal(s, fence->count);
}
} // namespace mlx::core

View File

@@ -1,15 +0,0 @@
// Copyright © 2025 Apple Inc.
namespace mlx::core::cu {
template <typename T>
struct Arange {
const T start;
const T step;
__device__ T operator()(uint32_t i) const {
return start + i * step;
}
};
} // namespace mlx::core::cu

View File

@@ -1,107 +0,0 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include <cuda_fp16.h>
#include <cuda/std/limits>
#include <cuda/std/type_traits>
namespace mlx::core::cu {
///////////////////////////////////////////////////////////////////////////////
// Missing C++ operator overrides for CUDA 7.
///////////////////////////////////////////////////////////////////////////////
#if CUDART_VERSION < 12000 && __CUDA_ARCH__ < 800
#define MLX_DEFINE_BF16_OP(OP) \
__forceinline__ __device__ __nv_bfloat16 operator OP( \
__nv_bfloat16 x, __nv_bfloat16 y) { \
return __float2bfloat16(__bfloat162float(x) OP __bfloat162float(y)); \
}
#define MLX_DEFINE_BF16_CMP(OP) \
__forceinline__ __device__ bool operator OP( \
__nv_bfloat16 x, __nv_bfloat16 y) { \
return __float2bfloat16(__bfloat162float(x) OP __bfloat162float(y)); \
}
MLX_DEFINE_BF16_OP(+)
MLX_DEFINE_BF16_OP(-)
MLX_DEFINE_BF16_OP(*)
MLX_DEFINE_BF16_OP(/)
MLX_DEFINE_BF16_CMP(>)
MLX_DEFINE_BF16_CMP(<)
MLX_DEFINE_BF16_CMP(>=)
MLX_DEFINE_BF16_CMP(<=)
#undef MLX_DEFINE_BF16_OP
#undef MLX_DEFINE_BF16_CMP
#endif // CUDART_VERSION < 12000 && __CUDA_ARCH__ < 800
///////////////////////////////////////////////////////////////////////////////
// Additional C++ operator overrides between half types and native types.
///////////////////////////////////////////////////////////////////////////////
template <typename T, typename U>
constexpr bool is_integral_except =
cuda::std::is_integral_v<T> && !cuda::std::is_same_v<T, U>;
template <typename T, typename U>
constexpr bool is_arithmetic_except =
cuda::std::is_arithmetic_v<T> && !cuda::std::is_same_v<T, U>;
#define MLX_DEFINE_HALF_OP(HALF, HALF2FLOAT, FLOAT2HALF, OP) \
template < \
typename T, \
typename = cuda::std::enable_if_t<is_integral_except<T, HALF>>> \
__forceinline__ __device__ HALF operator OP(HALF x, T y) { \
return FLOAT2HALF(HALF2FLOAT(x) OP static_cast<float>(y)); \
} \
template < \
typename T, \
typename = cuda::std::enable_if_t<is_integral_except<T, HALF>>> \
__forceinline__ __device__ HALF operator OP(T x, HALF y) { \
return FLOAT2HALF(static_cast<float>(x) OP HALF2FLOAT(y)); \
}
#define MLX_DEFINE_HALF_CMP(HALF, HALF2FLOAT, OP) \
template < \
typename T, \
typename = cuda::std::enable_if_t<is_arithmetic_except<T, HALF>>> \
__forceinline__ __device__ bool operator OP(HALF x, T y) { \
return HALF2FLOAT(x) OP static_cast<float>(y); \
} \
template < \
typename T, \
typename = cuda::std::enable_if_t<is_arithmetic_except<T, HALF>>> \
__forceinline__ __device__ bool operator OP(T x, HALF y) { \
return static_cast<float>(y) OP HALF2FLOAT(x); \
}
MLX_DEFINE_HALF_OP(__half, __half2float, __float2half, +)
MLX_DEFINE_HALF_OP(__half, __half2float, __float2half, -)
MLX_DEFINE_HALF_OP(__half, __half2float, __float2half, *)
MLX_DEFINE_HALF_OP(__half, __half2float, __float2half, /)
MLX_DEFINE_HALF_OP(__nv_bfloat16, __bfloat162float, __float2bfloat16, +)
MLX_DEFINE_HALF_OP(__nv_bfloat16, __bfloat162float, __float2bfloat16, -)
MLX_DEFINE_HALF_OP(__nv_bfloat16, __bfloat162float, __float2bfloat16, *)
MLX_DEFINE_HALF_OP(__nv_bfloat16, __bfloat162float, __float2bfloat16, /)
MLX_DEFINE_HALF_CMP(__half, __half2float, <)
MLX_DEFINE_HALF_CMP(__half, __half2float, >)
MLX_DEFINE_HALF_CMP(__half, __half2float, <=)
MLX_DEFINE_HALF_CMP(__half, __half2float, >=)
MLX_DEFINE_HALF_CMP(__half, __half2float, ==)
MLX_DEFINE_HALF_CMP(__half, __half2float, !=)
MLX_DEFINE_HALF_CMP(__nv_bfloat16, __bfloat162float, <)
MLX_DEFINE_HALF_CMP(__nv_bfloat16, __bfloat162float, >)
MLX_DEFINE_HALF_CMP(__nv_bfloat16, __bfloat162float, <=)
MLX_DEFINE_HALF_CMP(__nv_bfloat16, __bfloat162float, >=)
MLX_DEFINE_HALF_CMP(__nv_bfloat16, __bfloat162float, ==)
MLX_DEFINE_HALF_CMP(__nv_bfloat16, __bfloat162float, !=)
#undef MLX_DEFINE_HALF_OP
#undef MLX_DEFINE_HALF_CMP
} // namespace mlx::core::cu

View File

@@ -1,163 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/dtype_utils.cuh"
#include "mlx/backend/cuda/kernels/arange.cuh"
#include "mlx/backend/cuda/kernels/fp16_math.cuh"
#include "mlx/distributed/primitives.h"
#include "mlx/dtype_utils.h"
#include "mlx/fast_primitives.h"
#include "mlx/primitives.h"
#include <nvtx3/nvtx3.hpp>
#include <thrust/device_ptr.h>
#include <thrust/transform.h>
#include <cassert>
namespace mlx::core {
void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("Arange::eval_gpu");
assert(inputs.size() == 0);
out.set_data(allocator::malloc(out.nbytes()));
if (out.size() == 0) {
return;
}
auto& s = stream();
auto& encoder = cu::get_command_encoder(s);
encoder.set_output_array(out);
encoder.launch_kernel([&, this](cudaStream_t stream) {
MLX_SWITCH_INT_FLOAT_TYPES_CHECKED(out.dtype(), "Arange", CTYPE, {
using OutType = cuda_type_t<CTYPE>;
CTYPE step =
static_cast<CTYPE>(start_ + step_) - static_cast<CTYPE>(start_);
thrust::transform(
cu::thrust_policy(stream),
thrust::counting_iterator<uint32_t>(0),
thrust::counting_iterator<uint32_t>(out.data_size()),
thrust::device_pointer_cast(out.data<OutType>()),
cu::Arange<OutType>{
static_cast<OutType>(start_), static_cast<OutType>(step)});
});
});
}
#define NO_GPU_MULTI(func) \
void func::eval_gpu( \
const std::vector<array>& inputs, std::vector<array>& outputs) { \
throw std::runtime_error(#func " has no CUDA implementation."); \
}
#define NO_GPU(func) \
void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
throw std::runtime_error(#func " has no CUDA implementation."); \
}
NO_GPU(Abs)
NO_GPU(Add)
NO_GPU(AddMM)
NO_GPU(ArcCos)
NO_GPU(ArcCosh)
NO_GPU(ArcSin)
NO_GPU(ArcSinh)
NO_GPU(ArcTan)
NO_GPU(ArcTan2)
NO_GPU(ArcTanh)
NO_GPU(ArgPartition)
NO_GPU(ArgReduce)
NO_GPU(ArgSort)
NO_GPU(BitwiseBinary)
NO_GPU(BitwiseInvert)
NO_GPU(BlockMaskedMM)
NO_GPU(Ceil)
NO_GPU_MULTI(Compiled)
NO_GPU(Conjugate)
NO_GPU(Convolution)
NO_GPU(Cos)
NO_GPU(Cosh)
NO_GPU(Divide)
NO_GPU_MULTI(DivMod)
NO_GPU(DynamicSlice)
NO_GPU(DynamicSliceUpdate)
NO_GPU(Remainder)
NO_GPU(Equal)
NO_GPU(Erf)
NO_GPU(ErfInv)
NO_GPU(Exp)
NO_GPU(Expm1)
NO_GPU(FFT)
NO_GPU(Floor)
NO_GPU(Gather)
NO_GPU(GatherAxis)
NO_GPU(GatherMM)
NO_GPU(GatherQMM)
NO_GPU(Greater)
NO_GPU(GreaterEqual)
NO_GPU(Hadamard)
NO_GPU(Imag)
NO_GPU(Less)
NO_GPU(LessEqual)
NO_GPU(Load)
NO_GPU(Log)
NO_GPU(Log1p)
NO_GPU(LogicalNot)
NO_GPU(LogicalAnd)
NO_GPU(LogicalOr)
NO_GPU(LogAddExp)
NO_GPU(LogSumExp)
NO_GPU_MULTI(LUF)
NO_GPU(Matmul)
NO_GPU(Maximum)
NO_GPU(Minimum)
NO_GPU(Multiply)
NO_GPU(Negative)
NO_GPU(NotEqual)
NO_GPU(Partition)
NO_GPU(Power)
NO_GPU_MULTI(QRF)
NO_GPU(QuantizedMatmul)
NO_GPU(RandomBits)
NO_GPU(Real)
NO_GPU(Reduce)
NO_GPU(Round)
NO_GPU(Scan)
NO_GPU(Scatter)
NO_GPU(ScatterAxis)
NO_GPU(Select)
NO_GPU(Sigmoid)
NO_GPU(Sign)
NO_GPU(Sin)
NO_GPU(Sinh)
NO_GPU(SliceUpdate)
NO_GPU(Softmax)
NO_GPU(Sort)
NO_GPU(Square)
NO_GPU(Sqrt)
NO_GPU(Subtract)
NO_GPU_MULTI(SVD)
NO_GPU(Tan)
NO_GPU(Tanh)
NO_GPU(Inverse)
NO_GPU(Cholesky)
NO_GPU_MULTI(Eigh)
namespace fast {
NO_GPU_MULTI(LayerNorm)
NO_GPU_MULTI(LayerNormVJP)
NO_GPU_MULTI(RMSNorm)
NO_GPU_MULTI(RMSNormVJP)
NO_GPU_MULTI(RoPE)
NO_GPU(ScaledDotProductAttention)
NO_GPU_MULTI(AffineQuantize)
NO_GPU_MULTI(CustomKernel)
} // namespace fast
namespace distributed {
NO_GPU_MULTI(AllReduce)
NO_GPU_MULTI(AllGather)
NO_GPU_MULTI(Send)
NO_GPU_MULTI(Recv)
} // namespace distributed
} // namespace mlx::core

View File

@@ -1,15 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/gpu/slicing.h"
namespace mlx::core {
void concatenate_gpu(
const std::vector<array>& inputs,
array& out,
int axis,
const Stream& s) {
throw std::runtime_error("concatenate_gpu not implemented in CUDA backend.");
}
} // namespace mlx::core

View File

@@ -1,26 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/utils.h"
#include "mlx/backend/cuda/device.h"
#include <fmt/format.h>
namespace mlx::core {
CudaStream::CudaStream(cu::Device& device) {
device.make_current();
CHECK_CUDA_ERROR(cudaStreamCreateWithFlags(&stream_, cudaStreamNonBlocking));
}
CudaStream::~CudaStream() {
CHECK_CUDA_ERROR(cudaStreamDestroy(stream_));
}
void check_cuda_error(const char* name, cudaError_t err) {
if (err != cudaSuccess) {
throw std::runtime_error(
fmt::format("{} failed: {}", name, cudaGetErrorString(err)));
}
}
} // namespace mlx::core

View File

@@ -1,36 +0,0 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include <cuda_runtime.h>
namespace mlx::core {
namespace cu {
class Device;
}
// Cuda stream managed with RAII.
class CudaStream {
public:
explicit CudaStream(cu::Device& device);
~CudaStream();
CudaStream(const CudaStream&) = delete;
CudaStream& operator=(const CudaStream&) = delete;
operator cudaStream_t() const {
return stream_;
}
private:
cudaStream_t stream_;
};
// Throw exception if the cuda API does not succeed.
void check_cuda_error(const char* name, cudaError_t err);
// The macro version that prints the command that failed.
#define CHECK_CUDA_ERROR(cmd) check_cuda_error(#cmd, (cmd))
} // namespace mlx::core

View File

@@ -1,90 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/worker.h"
#include "mlx/backend/cuda/allocator.h"
#include "mlx/backend/cuda/device.h"
namespace mlx::core::cu {
Worker::Worker()
: signal_stream_(device(mlx::core::Device::gpu)),
worker_(&Worker::thread_fn, this) {}
Worker::~Worker() {
{
std::lock_guard lock(worker_mutex_);
stop_ = true;
}
worker_event_.signal(batch_ + 1);
worker_.join();
}
void Worker::add_task(std::function<void()> task) {
pending_tasks_.push_back(std::move(task));
}
void Worker::consume_in_this_thread() {
for (auto& task : pending_tasks_) {
task();
}
pending_tasks_.clear();
}
void Worker::end_batch() {
batch_++;
{
std::lock_guard lock(worker_mutex_);
worker_tasks_[batch_] = std::move(pending_tasks_);
}
uncommited_batches_++;
}
void Worker::commit() {
if (uncommited_batches_ == 0) {
return;
}
uncommited_batches_ = 0;
worker_event_.signal(batch_);
}
void Worker::commit(cudaStream_t stream) {
if (uncommited_batches_ == 0) {
return;
}
uncommited_batches_ = 0;
// Signal the |worker_event_| in |signal_stream_| after the kernels in
// |stream_| finish running.
signal_event_.record(stream);
signal_event_.wait(signal_stream_);
worker_event_.signal(signal_stream_, batch_);
}
void Worker::thread_fn() {
// The worker thread is safe to free buffers.
allocator().register_this_thread();
while (!stop_) {
uint64_t batch = worker_event_.value();
Tasks tasks;
{
std::lock_guard lock(worker_mutex_);
// Move tasks in signaled batches.
auto end = worker_tasks_.upper_bound(batch);
for (auto it = worker_tasks_.begin(); it != end; ++it) {
if (tasks.empty()) {
tasks = std::move(it->second);
} else {
std::move(
it->second.begin(), it->second.end(), std::back_inserter(tasks));
}
}
worker_tasks_.erase(worker_tasks_.begin(), end);
}
for (auto& task : tasks) {
task();
}
worker_event_.wait(batch + 1);
}
}
} // namespace mlx::core::cu

View File

@@ -1,68 +0,0 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include "mlx/backend/cuda/event.h"
#include "mlx/backend/cuda/utils.h"
#include <functional>
#include <map>
#include <mutex>
#include <thread>
namespace mlx::core::cu {
// Run tasks in worker thread, synchronized with cuda stream.
class Worker {
public:
Worker();
~Worker();
Worker(const Worker&) = delete;
Worker& operator=(const Worker&) = delete;
// Add a pending |task| that will run when consumed or commited.
void add_task(std::function<void()> task);
// Run pending tasks immediately in current thread.
void consume_in_this_thread();
// Put pending tasks in a batch.
void end_batch();
// Inform worker thread to run current batches now.
void commit();
// Inform worker thread to run current batches after kernels in |stream|
// finish running.
void commit(cudaStream_t stream);
// Return how many batches have been added but not committed yet.
size_t uncommited_batches() const {
return uncommited_batches_;
}
private:
void thread_fn();
uint64_t batch_{0};
size_t uncommited_batches_{0};
// Cuda stream and event for signaling kernel completion.
CudaStream signal_stream_;
CudaEvent signal_event_;
// Worker thread.
SharedEvent worker_event_;
std::thread worker_;
std::mutex worker_mutex_;
bool stop_{false};
// Tasks are put in |pending_tasks_| first, and then moved to
// |worker_tasks_| when end_batch() is called.
using Tasks = std::vector<std::function<void()>>;
Tasks pending_tasks_;
std::map<uint64_t, Tasks> worker_tasks_;
};
} // namespace mlx::core::cu

View File

@@ -1,5 +0,0 @@
target_sources(
mlx
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp)

View File

@@ -1,9 +0,0 @@
// Copyright © 2025 Apple Inc.
#pragma once
namespace mlx::core::gpu {
bool is_available();
} // namespace mlx::core::gpu

View File

@@ -1,49 +0,0 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/backend/gpu/copy.h"
#include "mlx/primitives.h"
#include <cassert>
namespace mlx::core {
void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) {
bool donated = set_copy_output_data(in, out, ctype);
if (donated && in.dtype() == out.dtype()) {
// If the output has the same type as the input then there is nothing to
// copy, just use the buffer.
return;
}
if (ctype == CopyType::GeneralGeneral) {
ctype = CopyType::General;
}
copy_gpu_inplace(in, out, ctype, s);
}
void copy_gpu(const array& in, array& out, CopyType ctype) {
copy_gpu(in, out, ctype, out.primitive().stream());
}
void copy_gpu_inplace(
const array& in,
array& out,
CopyType ctype,
const Stream& s) {
assert(in.shape() == out.shape());
return copy_gpu_inplace(
in, out, in.shape(), in.strides(), out.strides(), 0, 0, ctype, s);
}
void copy_gpu_inplace(
const array& in,
array& out,
const Strides& i_strides,
int64_t i_offset,
CopyType ctype,
const Stream& s) {
assert(in.shape() == out.shape());
return copy_gpu_inplace(
in, out, in.shape(), i_strides, out.strides(), i_offset, 0, ctype, s);
}
} // namespace mlx::core

View File

@@ -1,217 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/primitives.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/gpu/slicing.h"
#include <cassert>
#define MLX_PROFILER_RANGE(message)
namespace mlx::core {
namespace {
void reshape(const array& in, array& out, Stream s) {
auto [copy_necessary, out_strides] = prepare_reshape(in, out);
if (copy_necessary) {
out.set_data(allocator::malloc(out.nbytes()));
copy_gpu_inplace(
in,
out,
in.shape(),
in.strides(),
make_contiguous_strides(in.shape()),
0,
0,
CopyType::General,
s);
} else {
shared_buffer_reshape(in, out_strides, out);
}
}
} // namespace
void AsStrided::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("AsStrided::eval_gpu");
eval(inputs, out);
}
void AsType::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("AsType::eval_gpu");
CopyType ctype =
inputs[0].flags().contiguous ? CopyType::Vector : CopyType::General;
copy_gpu(inputs[0], out, ctype);
}
void Broadcast::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("Broadcast::eval_gpu");
eval(inputs, out);
}
void BroadcastAxes::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("BroadcastAxes::eval_gpu");
eval(inputs, out);
}
void Concatenate::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("Concatenate::eval_gpu");
concatenate_gpu(inputs, out, axis_, stream());
}
void Contiguous::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("Contiguous::eval_gpu");
assert(inputs.size() == 1);
auto& in = inputs[0];
constexpr size_t extra_bytes = 16384;
if (in.buffer_size() <= out.nbytes() + extra_bytes &&
(in.flags().row_contiguous ||
(allow_col_major_ && in.flags().col_contiguous))) {
out.copy_shared_buffer(in);
} else {
copy_gpu(in, out, CopyType::General);
}
}
void Copy::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("Copy::eval_gpu");
eval(inputs, out);
}
void CustomTransforms::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
MLX_PROFILER_RANGE("CustomTransforms::eval_gpu");
eval(inputs, outputs);
}
void Depends::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
MLX_PROFILER_RANGE("Depends::eval_gpu");
eval(inputs, outputs);
}
void ExpandDims::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("ExpandDims::eval_gpu");
eval(inputs, out);
}
void Full::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("Full::eval_gpu");
auto in = inputs[0];
CopyType ctype;
if (in.data_size() == 1) {
ctype = CopyType::Scalar;
} else if (in.flags().contiguous) {
ctype = CopyType::Vector;
} else {
ctype = CopyType::General;
}
copy_gpu(in, out, ctype);
}
void Flatten::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("Flatten::eval_gpu");
reshape(inputs[0], out, stream());
}
void NumberOfElements::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("NumberOfElements::eval_gpu");
eval(inputs, out);
}
void Pad::eval_gpu(const std::vector<array>& inputs, array& out) {
// Inputs must be base input array and scalar val array
assert(inputs.size() == 2);
auto& in = inputs[0];
auto& val = inputs[1];
// Padding value must be a scalar
assert(val.size() == 1);
// Padding value, input and output must be of the same type
assert(val.dtype() == in.dtype() && in.dtype() == out.dtype());
pad_gpu(in, val, out, axes_, low_pad_size_, stream());
}
void Reshape::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("Reshape::eval_gpu");
reshape(inputs[0], out, stream());
}
void Split::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
MLX_PROFILER_RANGE("Split::eval_gpu");
eval(inputs, outputs);
}
void Slice::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("Slice::eval_gpu");
assert(inputs.size() == 1);
if (out.size() == 0) {
out.set_data(nullptr);
return;
}
auto& in = inputs[0];
slice_gpu(in, out, start_indices_, strides_, stream());
}
void Squeeze::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("Squeeze::eval_gpu");
eval(inputs, out);
}
void StopGradient::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("StopGradient::eval_gpu");
eval(inputs, out);
}
void Transpose::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("Transpose::eval_gpu");
eval(inputs, out);
}
void Unflatten::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("Unflatten::eval_gpu");
reshape(inputs[0], out, stream());
}
void View::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("View::eval_gpu");
auto& in = inputs[0];
auto ibytes = size_of(in.dtype());
auto obytes = size_of(out.dtype());
// Conditions for buffer copying (disjunction):
// - type size is the same
// - type size is smaller and the last axis is contiguous
// - the entire array is row contiguous
if (ibytes == obytes || (obytes < ibytes && in.strides().back() == 1) ||
in.flags().row_contiguous) {
auto strides = in.strides();
for (int i = 0; i < static_cast<int>(strides.size()) - 1; ++i) {
strides[i] *= ibytes;
strides[i] /= obytes;
}
out.copy_shared_buffer(
in, strides, in.flags(), in.data_size() * ibytes / obytes);
} else {
auto tmp = array(in.shape(), in.dtype(), nullptr, {});
tmp.set_data(allocator::malloc(tmp.nbytes()));
copy_gpu_inplace(in, tmp, CopyType::General, stream());
auto flags = out.flags();
flags.contiguous = true;
flags.row_contiguous = true;
auto max_dim = std::max_element(out.shape().begin(), out.shape().end());
flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim;
out.copy_shared_buffer(tmp, out.strides(), flags, out.size());
}
}
} // namespace mlx::core

View File

@@ -1,44 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/common/slicing.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/gpu/slicing.h"
namespace mlx::core {
void slice_gpu(
const array& in,
array& out,
const Shape& start_indices,
const Shape& strides,
const Stream& s) {
slice(in, out, start_indices, strides);
}
void pad_gpu(
const array& in,
const array& val,
array& out,
const std::vector<int>& axes,
const Shape& low_pad_size,
const Stream& s) {
// Fill output with val
fill_gpu(val, out, s);
// Find offset for start of input values
size_t data_offset = 0;
for (int i = 0; i < axes.size(); i++) {
auto ax = axes[i] < 0 ? out.ndim() + axes[i] : axes[i];
data_offset += out.strides()[ax] * low_pad_size[i];
}
// Extract slice from output where input will be pasted
array out_slice(in.shape(), out.dtype(), nullptr, {});
out_slice.copy_shared_buffer(
out, out.strides(), out.flags(), out_slice.size(), data_offset);
// Copy input values into the slice
copy_gpu_inplace(in, out_slice, CopyType::GeneralGeneral, s);
}
} // namespace mlx::core

View File

@@ -93,7 +93,6 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/event.cpp
${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp

View File

@@ -1,6 +1,7 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/backend/metal/allocator.h"
#include "mlx/backend/metal/metal.h"
#include "mlx/backend/metal/metal_impl.h"
#include "mlx/backend/metal/resident.h"
#include "mlx/memory.h"

View File

@@ -90,7 +90,7 @@ void binary_op_gpu_inplace(
work_per_thread = large ? 4 : 2;
} else {
large = out.data_size() > UINT32_MAX;
work_per_thread = get_work_per_thread(a.dtype());
work_per_thread = 1;
}
std::string kernel_name =
get_kernel_name(bopt, op, a, large, shape.size(), work_per_thread);
@@ -137,20 +137,13 @@ void binary_op_gpu_inplace(
compute_encoder.dispatch_threads(grid_dims, group_dims);
} else {
// Launch a 1D or 2D grid of threads
size_t nthreads = ceildiv(out.data_size(), work_per_thread);
size_t nthreads = out.data_size();
if (thread_group_size > nthreads) {
thread_group_size = nthreads;
}
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
MTL::Size grid_dims;
if (large) {
compute_encoder.set_bytes<int64_t>(out.data_size(), arg_idx++);
grid_dims = get_2d_grid_dims(out.shape(), out.strides(), work_per_thread);
} else {
compute_encoder.set_bytes<int>(out.data_size(), arg_idx++);
grid_dims = MTL::Size(nthreads, 1, 1);
}
MTL::Size grid_dims = large ? get_2d_grid_dims(out.shape(), out.strides())
: MTL::Size(nthreads, 1, 1);
compute_encoder.dispatch_threads(grid_dims, group_dims);
}
}

View File

@@ -64,7 +64,6 @@ inline void build_kernel(
cnt++);
}
std::string idx_type = use_big_index ? "int64_t" : "uint";
if (add_indices) {
os += fmt::format(
" constant const int64_t* in_strides [[buffer({0})]],\n", cnt++);
@@ -84,9 +83,6 @@ inline void build_kernel(
" constant const int64_t* output_strides [[buffer({0})]],\n", cnt++);
os += fmt::format(
" constant const int* output_shape [[buffer({0})]],\n", cnt++);
} else {
os += fmt::format(
" constant const {0}& size [[buffer({1})]],\n", idx_type, cnt++);
}
if (dynamic_dims) {
os += fmt::format(" constant const int& ndim [[buffer({0})]],\n", cnt++);
@@ -96,14 +92,13 @@ inline void build_kernel(
os += " uint3 pos [[thread_position_in_grid]],\n";
os += " uint3 grid [[threads_per_grid]]) {\n";
os += fmt::format(" constexpr int N_ = {0};\n", work_per_thread);
std::string idx_type = use_big_index ? "int64_t" : "uint";
if (contiguous && use_big_index) {
// This is only used for contiguous kernels which don't have
// a third grid dimension
os += " int64_t index = N_ * (pos.x + grid.x * int64_t(pos.y));\n";
} else if (contiguous) {
os += " uint index = N_ * pos.x;\n";
os += " int64_t index = pos.x + grid.x * int64_t(pos.y);\n";
} else if (work_per_thread > 1) {
os += fmt::format(" constexpr int N_ = {0};\n", work_per_thread);
os += fmt::format(
" int xshape = output_shape[{0}];\n",
dynamic_dims ? "ndim - 1" : std::to_string(ndim - 1));
@@ -115,9 +110,6 @@ inline void build_kernel(
" {0} index = pos.x + grid.x * (pos.y + {0}(grid.y) * pos.z);\n",
idx_type);
}
if (work_per_thread > 1 && contiguous) {
os += " for (int i = 0; i < N_ && index < size; ++i) {\n";
}
// Read constant / contiguous inputs in tmps
std::vector<array> nc_inputs;
@@ -201,7 +193,7 @@ inline void build_kernel(
}
// Open per-thread loop
if (work_per_thread > 1 && !contiguous) {
if (work_per_thread > 1) {
os +=
" for (int i = 0; i < N_ && (int(N_ * pos.x) + i) < xshape; ++i) {\n";
}
@@ -280,7 +272,6 @@ void Compiled::eval_gpu(
auto& s = stream();
auto& d = metal::device(s.device);
auto lib = d.get_library(kernel_lib_, [&]() {
int work_per_thread = get_work_per_thread(outputs_[0].dtype());
std::string kernel = metal::utils();
concatenate(
kernel, metal::unary_ops(), metal::binary_ops(), metal::ternary_ops());
@@ -293,9 +284,7 @@ void Compiled::eval_gpu(
constant_ids_,
/* contiguous = */ true,
/* ndim = */ 0,
/* dynamic_dims = */ false,
/* use_big_index = */ false,
/* work_per_thread = */ work_per_thread);
/* dynamic_dims = */ false);
build_kernel(
kernel,
kernel_lib_ + "_contiguous_large",
@@ -306,8 +295,7 @@ void Compiled::eval_gpu(
/* contiguous = */ true,
/* ndim = */ 0,
/* dynamic_dims = */ false,
/* use_big_index = */ true,
/* work_per_thread = */ work_per_thread);
/* use_big_index = */ true);
for (int i = 1; i < 8; i++) {
build_kernel(
kernel,
@@ -480,13 +468,6 @@ void Compiled::eval_gpu(
if (!contiguous) {
compute_encoder.set_vector_bytes(strides[0], cnt++);
compute_encoder.set_vector_bytes(shape, cnt++);
} else {
auto size = outputs[0].data_size();
if (large) {
compute_encoder.set_bytes<int64_t>(size, cnt++);
} else {
compute_encoder.set_bytes<int>(size, cnt++);
}
}
// Put the number of dims in if it is dynamic
@@ -496,13 +477,12 @@ void Compiled::eval_gpu(
// Launch the kernel
if (contiguous) {
int work_per_thread = get_work_per_thread(outputs[0].dtype());
size_t nthreads = ceildiv(outputs[0].data_size(), work_per_thread);
size_t nthreads = outputs[0].data_size();
MTL::Size group_dims(
std::min(nthreads, kernel->maxTotalThreadsPerThreadgroup()), 1, 1);
MTL::Size grid_dims = large
? get_2d_grid_dims(
outputs[0].shape(), outputs[0].strides(), work_per_thread)
? get_2d_grid_dims(outputs[0].shape(), outputs[0].strides())
: MTL::Size(nthreads, 1, 1);
compute_encoder.dispatch_threads(grid_dims, group_dims);
} else {

View File

@@ -5,7 +5,7 @@
#include <numeric>
#include <sstream>
#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/kernels/defines.h"
@@ -952,7 +952,7 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
in,
wt,
out,
padding_lo_,
padding_,
kernel_strides_,
kernel_dilation_,
input_dilation_,
@@ -967,7 +967,7 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
in,
wt,
out,
padding_lo_,
padding_,
kernel_strides_,
kernel_dilation_,
input_dilation_,
@@ -983,7 +983,7 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
in,
wt,
out,
padding_lo_,
padding_,
kernel_strides_,
kernel_dilation_,
input_dilation_,

View File

@@ -1,15 +1,35 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/backend/gpu/copy.h"
#include <sstream>
#include "mlx/backend/common/utils.h"
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"
namespace mlx::core {
constexpr int MAX_COPY_SPECIALIZED_DIMS = 3;
void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) {
bool donated = set_copy_output_data(in, out, ctype);
if (donated && in.dtype() == out.dtype()) {
// If the output has the same type as the input then there is nothing to
// copy, just use the buffer.
return;
}
if (ctype == CopyType::GeneralGeneral) {
ctype = CopyType::General;
}
copy_gpu_inplace(in, out, ctype, s);
}
void copy_gpu(const array& in, array& out, CopyType ctype) {
copy_gpu(in, out, ctype, out.primitive().stream());
}
void copy_gpu_inplace(
const array& in,
array& out,
@@ -84,8 +104,6 @@ void copy_gpu_inplace(
"[Copy::eval_gpu] Dynamic output offset requires GeneralGeneral copy");
}
}
} else {
work_per_thread = get_work_per_thread(in.dtype());
}
concatenate(kernel_name, "_copy", type_to_name(in), type_to_name(out));
auto kernel = dynamic ? get_dynamic_copy_kernel(d, kernel_name, in, out)
@@ -147,23 +165,39 @@ void copy_gpu_inplace(
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
compute_encoder.dispatch_threads(grid_dims, group_dims);
} else {
size_t nthreads = ceildiv(out.data_size(), work_per_thread);
size_t nthreads = out.data_size();
if (thread_group_size > nthreads) {
thread_group_size = nthreads;
}
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
MTL::Size grid_dims;
if (large) {
compute_encoder.set_bytes<int64_t>(out.data_size(), 2);
grid_dims = get_2d_grid_dims(out.shape(), out.strides(), work_per_thread);
} else {
compute_encoder.set_bytes<int>(out.data_size(), 2);
grid_dims = MTL::Size(nthreads, 1, 1);
}
MTL::Size grid_dims = large ? get_2d_grid_dims(out.shape(), out.strides())
: MTL::Size(nthreads, 1, 1);
compute_encoder.dispatch_threads(grid_dims, group_dims);
}
}
void copy_gpu_inplace(
const array& in,
array& out,
CopyType ctype,
const Stream& s) {
assert(in.shape() == out.shape());
return copy_gpu_inplace(
in, out, in.shape(), in.strides(), out.strides(), 0, 0, ctype, s);
}
void copy_gpu_inplace(
const array& in,
array& out,
const Strides& i_strides,
int64_t i_offset,
CopyType ctype,
const Stream& s) {
assert(in.shape() == out.shape());
return copy_gpu_inplace(
in, out, in.shape(), i_strides, out.strides(), i_offset, 0, ctype, s);
}
void fill_gpu(const array& val, array& out, const Stream& s) {
if (out.size() == 0) {
return;
@@ -180,21 +214,14 @@ void fill_gpu(const array& val, array& out, const Stream& s) {
compute_encoder.set_input_array(val, 0);
compute_encoder.set_output_array(out, 1);
int work_per_thread = get_work_per_thread(val.dtype());
auto thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
size_t nthreads = ceildiv(out.data_size(), work_per_thread);
size_t nthreads = out.data_size();
if (thread_group_size > nthreads) {
thread_group_size = nthreads;
}
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
MTL::Size grid_dims;
if (large) {
compute_encoder.set_bytes<int64_t>(out.data_size(), 2);
grid_dims = get_2d_grid_dims(out.shape(), out.strides(), work_per_thread);
} else {
compute_encoder.set_bytes<int>(out.data_size(), 2);
grid_dims = MTL::Size(nthreads, 1, 1);
}
MTL::Size grid_dims = large ? get_2d_grid_dims(out.shape(), out.strides())
: MTL::Size(nthreads, 1, 1);
compute_encoder.dispatch_threads(grid_dims, group_dims);
}

View File

@@ -5,8 +5,6 @@
#include "mlx/backend/common/copy.h"
#include "mlx/stream.h"
#include <optional>
namespace mlx::core {
// Generic copy inplace

View File

@@ -1,6 +1,6 @@
// Copyright © 2024 Apple Inc.
#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/jit/includes.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/fast_primitives.h"

View File

@@ -1,20 +1,20 @@
// Copyright © 2023-2024 Apple Inc.
#include <cstdlib>
#include <filesystem>
#include <sstream>
#include <sys/sysctl.h>
#define NS_PRIVATE_IMPLEMENTATION
#define CA_PRIVATE_IMPLEMENTATION
#define MTL_PRIVATE_IMPLEMENTATION
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/metal.h"
#include "mlx/backend/metal/metal_impl.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/utils.h"
namespace fs = std::filesystem;
namespace mlx::core::metal {
namespace {
@@ -66,8 +66,8 @@ MTL::Library* try_load_bundle(
if (bundle != nullptr) {
std::string resource_path =
std::string(bundle->resourceURL()->fileSystemRepresentation()) + "/" +
lib_name + ".metallib";
auto [lib, error] = load_library_from_path(device, resource_path.c_str());
lib_name + ".metallib" auto [lib, error] =
load_library_from_path(device, resource_path.c_str());
if (lib) {
return lib;
}
@@ -79,18 +79,12 @@ MTL::Library* try_load_bundle(
// Firstly, search for the metallib in the same path as this binary
std::pair<MTL::Library*, NS::Error*> load_colocated_library(
MTL::Device* device,
const std::string& relative_path) {
std::string binary_dir = get_binary_directory();
if (binary_dir.size() == 0) {
return {nullptr, nullptr};
const std::string& lib_name) {
std::string lib_path = get_colocated_mtllib_path(lib_name);
if (lib_path.size() != 0) {
return load_library_from_path(device, lib_path.c_str());
}
auto path = fs::path(binary_dir) / relative_path;
if (!path.has_extension()) {
path.replace_extension(".metallib");
}
return load_library_from_path(device, path.c_str());
return {nullptr, nullptr};
}
std::pair<MTL::Library*, NS::Error*> load_swiftpm_library(
@@ -105,7 +99,7 @@ std::pair<MTL::Library*, NS::Error*> load_swiftpm_library(
auto bundles = NS::Bundle::allBundles();
for (int i = 0, c = (int)bundles->count(); i < c; i++) {
auto bundle = reinterpret_cast<NS::Bundle*>(bundles->object(i));
library = try_load_bundle(device, bundle->resourceURL(), lib_name);
library = try_load_bundle(device, bundle->resourceURL());
if (library != nullptr) {
return {library, nullptr};
}
@@ -115,34 +109,33 @@ std::pair<MTL::Library*, NS::Error*> load_swiftpm_library(
}
MTL::Library* load_default_library(MTL::Device* device) {
NS::Error* error[4];
NS::Error *error1, *error2, *error3;
MTL::Library* lib;
// First try the colocated mlx.metallib
std::tie(lib, error[0]) = load_colocated_library(device, "mlx");
if (lib) {
return lib;
}
std::tie(lib, error[1]) = load_colocated_library(device, "Resources/mlx");
std::tie(lib, error1) = load_colocated_library(device, "mlx");
if (lib) {
return lib;
}
// Then try default.metallib in a SwiftPM bundle if we have one
std::tie(lib, error[2]) = load_swiftpm_library(device, "default");
std::tie(lib, error2) = load_swiftpm_library(device, "default");
if (lib) {
return lib;
}
// Finally try default_mtllib_path
std::tie(lib, error[3]) = load_library_from_path(device, default_mtllib_path);
std::tie(lib, error3) = load_library_from_path(device, default_mtllib_path);
if (!lib) {
std::ostringstream msg;
msg << "Failed to load the default metallib. ";
for (int i = 0; i < 4; i++) {
if (error[i] != nullptr) {
msg << error[i]->localizedDescription()->utf8String() << " ";
}
if (error1 != nullptr) {
msg << error1->localizedDescription()->utf8String() << " ";
}
if (error2 != nullptr) {
msg << error2->localizedDescription()->utf8String() << " ";
}
if (error3 != nullptr) {
msg << error3->localizedDescription()->utf8String() << " ";
}
throw std::runtime_error(msg.str());
}
@@ -163,7 +156,6 @@ MTL::Library* load_library(
<< error->localizedDescription()->utf8String();
throw std::runtime_error(msg.str());
}
return lib;
}
// We have been given a path so try to load from lib_path / lib_name.metallib
@@ -176,7 +168,6 @@ MTL::Library* load_library(
<< "> with error " << error->localizedDescription()->utf8String();
throw std::runtime_error(msg.str());
}
return lib;
}
// Try to load the colocated library
@@ -197,8 +188,8 @@ MTL::Library* load_library(
std::ostringstream msg;
msg << "Failed to load the metallib " << lib_name << ".metallib. "
<< "We attempted to load it from <" << get_binary_directory() << "/"
<< lib_name << ".metallib" << ">";
<< "We attempted to load it from <" << get_colocated_mtllib_path(lib_name)
<< ">";
#ifdef SWIFTPM_BUNDLE
msg << " and from the Swift PM bundle.";
#endif
@@ -769,4 +760,42 @@ std::unique_ptr<void, std::function<void(void*)>> new_scoped_memory_pool() {
NS::AutoreleasePool::alloc()->init(), dtor);
}
void new_stream(Stream stream) {
if (stream.device == mlx::core::Device::gpu) {
device(stream.device).new_queue(stream.index);
}
}
const std::unordered_map<std::string, std::variant<std::string, size_t>>&
device_info() {
auto init_device_info = []()
-> std::unordered_map<std::string, std::variant<std::string, size_t>> {
auto pool = new_scoped_memory_pool();
auto raw_device = device(default_device()).mtl_device();
auto name = std::string(raw_device->name()->utf8String());
auto arch = std::string(raw_device->architecture()->name()->utf8String());
size_t memsize = 0;
size_t length = sizeof(memsize);
sysctlbyname("hw.memsize", &memsize, &length, NULL, 0);
size_t rsrc_limit = 0;
sysctlbyname("iogpu.rsrc_limit", &rsrc_limit, &length, NULL, 0);
if (rsrc_limit == 0) {
rsrc_limit = 499000;
}
return {
{"device_name", name},
{"architecture", arch},
{"max_buffer_length", raw_device->maxBufferLength()},
{"max_recommended_working_set_size",
raw_device->recommendedMaxWorkingSetSize()},
{"memory_size", memsize},
{"resource_limit", rsrc_limit}};
};
static auto device_info_ = init_device_info();
return device_info_;
}
} // namespace mlx::core::metal

View File

@@ -21,14 +21,18 @@ namespace mlx::core::metal {
// Note, this function must be left inline in a header so that it is not
// dynamically linked.
inline std::string get_binary_directory() {
inline std::string get_colocated_mtllib_path(const std::string& lib_name) {
Dl_info info;
std::string directory;
int success = dladdr((void*)get_binary_directory, &info);
std::string mtllib_path;
std::string lib_ext = lib_name + ".metallib";
int success = dladdr((void*)get_colocated_mtllib_path, &info);
if (success) {
directory = fs::path(info.dli_fname).remove_filename().c_str();
auto mtllib = fs::path(info.dli_fname).remove_filename() / lib_ext;
mtllib_path = mtllib.c_str();
}
return directory;
return mtllib_path;
}
using MTLFCList =
@@ -266,6 +270,4 @@ class Device {
Device& device(mlx::core::Device);
std::unique_ptr<void, std::function<void(void*)>> new_scoped_memory_pool();
} // namespace mlx::core::metal

View File

@@ -4,7 +4,7 @@
#include "mlx/allocator.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/distributed/ops.h"

View File

@@ -1,102 +0,0 @@
// Copyright © 2023-2024 Apple Inc.
#include <memory>
#include "mlx/backend/gpu/available.h"
#include "mlx/backend/gpu/eval.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"
#include "mlx/scheduler.h"
namespace mlx::core::gpu {
bool is_available() {
return true;
}
void new_stream(Stream stream) {
if (stream.device == mlx::core::Device::gpu) {
metal::device(stream.device).new_queue(stream.index);
}
}
inline void check_error(MTL::CommandBuffer* cbuf) {
if (cbuf->status() == MTL::CommandBufferStatusError) {
std::ostringstream msg;
msg << "[METAL] Command buffer execution failed: "
<< cbuf->error()->localizedDescription()->utf8String();
throw std::runtime_error(msg.str());
}
}
void eval(array& arr) {
auto pool = metal::new_scoped_memory_pool();
auto s = arr.primitive().stream();
auto& d = metal::device(s.device);
auto command_buffer = d.get_command_buffer(s.index);
auto outputs = arr.outputs();
{
// If the array is a tracer hold a reference
// to its inputs so they don't get donated
std::vector<array> inputs;
if (arr.is_tracer()) {
inputs = arr.inputs();
}
debug_set_primitive_buffer_label(command_buffer, arr.primitive());
arr.primitive().eval_gpu(arr.inputs(), outputs);
}
std::unordered_set<std::shared_ptr<array::Data>> buffers;
for (auto& in : arr.inputs()) {
buffers.insert(in.data_shared_ptr());
}
for (auto& s : arr.siblings()) {
buffers.insert(s.data_shared_ptr());
}
// Remove the output if it was donated to by an input
if (auto it = buffers.find(arr.data_shared_ptr()); it != buffers.end()) {
buffers.erase(it);
}
if (d.command_buffer_needs_commit(s.index)) {
d.end_encoding(s.index);
scheduler::notify_new_task(s);
command_buffer->addCompletedHandler(
[s, buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) {
scheduler::notify_task_completion(s);
check_error(cbuf);
});
d.commit_command_buffer(s.index);
d.get_command_buffer(s.index);
} else {
command_buffer->addCompletedHandler(
[s, buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) {
check_error(cbuf);
});
}
}
void finalize(Stream s) {
auto pool = metal::new_scoped_memory_pool();
auto& d = metal::device(s.device);
auto cb = d.get_command_buffer(s.index);
d.end_encoding(s.index);
cb->addCompletedHandler([s](MTL::CommandBuffer* cbuf) { check_error(cbuf); });
d.commit_command_buffer(s.index);
d.get_command_buffer(s.index);
}
void synchronize(Stream s) {
auto pool = metal::new_scoped_memory_pool();
auto& d = metal::device(s.device);
auto cb = d.get_command_buffer(s.index);
cb->retain();
d.end_encoding(s.index);
d.commit_command_buffer(s.index);
cb->waitUntilCompleted();
check_error(cb);
cb->release();
}
} // namespace mlx::core::gpu

View File

@@ -2,6 +2,7 @@
#include "mlx/event.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/metal_impl.h"
#include "mlx/scheduler.h"
namespace mlx::core {

View File

@@ -1,6 +1,7 @@
// Copyright © 2024 Apple Inc.
#include "mlx/fence.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/metal_impl.h"
#include "mlx/scheduler.h"
#include "mlx/utils.h"
@@ -138,7 +139,7 @@ void Fence::update(Stream stream, const array& x) {
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(x, 0);
compute_encoder.set_bytes(nthreads, 1);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
compute_encoder.dispatch_threadgroups(group_dims, grid_dims);
// Barrier on previous kernels
compute_encoder.barrier();

View File

@@ -7,10 +7,10 @@
#include "mlx/3rdparty/pocketfft.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/gpu/slicing.h"
#include "mlx/backend/metal/binary.h"
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/slicing.h"
#include "mlx/backend/metal/unary.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/utils.h"
@@ -632,7 +632,7 @@ void fft_op(
func_consts.push_back(make_int(&rader_m, 3));
// The overall number of FFTs we're going to compute for this input
size_t size = out.dtype() == float32 ? out.size() : in.size();
int size = out.dtype() == float32 ? out.size() : in.size();
if (real && inverse && four_step_params.required) {
size = out.size();
}
@@ -659,6 +659,8 @@ void fft_op(
// We can perform 2 RFFTs at once so the batch size is halved.
batch_size = (batch_size + 2 - 1) / 2;
}
int out_buffer_size = out.size();
auto& compute_encoder = d.get_command_encoder(s.index);
auto in_type_str = in.dtype() == float32 ? "float" : "float2";
auto out_type_str = out.dtype() == float32 ? "float" : "float2";

View File

@@ -1,9 +1,11 @@
// Copyright © 2024 Apple Inc.
#include "mlx/backend/common/hadamard.h"
#include <map>
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/common/hadamard.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/jit/includes.h"
#include "mlx/backend/metal/kernels.h"
@@ -13,6 +15,7 @@
namespace mlx::core {
constexpr int MAX_HADAMARD_THREADS_PER_GROUP = 256;
constexpr int MAX_HADAMARD_BYTES = 32768; // 32KB
std::string gen_hadamard_codelet(int m) {
// Generate a O(m^2) hadamard codelet for a given M
@@ -57,142 +60,121 @@ std::string gen_hadamard_codelet(int m) {
return source.str();
}
void hadamard_mn_contiguous(
const array& x,
array& y,
int m,
int n1,
int n2,
float scale,
metal::Device& d,
const Stream& s) {
int n = n1 * n2;
int read_width_n1 = n1 == 2 ? 2 : 4;
int read_width_n2 = n2 == 2 ? 2 : 4;
int read_width_m = (n == 2 || m == 28) ? 2 : 4;
int max_radix_1 = std::min(n1, 16);
int max_radix_2 = std::min(n2, 16);
float scale_n1 = 1.0;
float scale_n2 = (m == 1) ? scale : 1.0;
float scale_m = scale;
// n2 is a row contiguous power of 2 hadamard transform
MTL::Size group_dims_n2(n2 / max_radix_2, 1, 1);
MTL::Size grid_dims_n2(n2 / max_radix_2, x.size() / n2, 1);
// n1 is a strided power of 2 hadamard transform with stride n2
MTL::Size group_dims_n1(n1 / max_radix_1, 1, 1);
MTL::Size grid_dims_n1(n1 / max_radix_1, x.size() / n, n2);
// m is a strided hadamard transform with stride n = n1 * n2
MTL::Size group_dims_m(
std::min(n / read_width_m, MAX_HADAMARD_THREADS_PER_GROUP), 1, 1);
MTL::Size grid_dims_m(
group_dims_m.width, x.size() / m / read_width_m / group_dims_m.width, 1);
// Make the kernel
std::string kname;
kname.reserve(32);
concatenate(kname, "hadamard_", n * m, "_", type_to_name(x));
auto lib = d.get_library(kname, [&]() {
std::string kernel;
concatenate(
kernel,
metal::utils(),
gen_hadamard_codelet(m),
metal::hadamard(),
get_template_definition(
"n2" + kname,
"hadamard_n",
get_type_string(x.dtype()),
n2,
max_radix_2,
read_width_n2));
if (n1 > 1) {
kernel += get_template_definition(
"n1" + kname,
"hadamard_n",
get_type_string(x.dtype()),
n1,
max_radix_1,
read_width_n1,
n2);
}
if (m > 1) {
kernel += get_template_definition(
"m" + kname,
"hadamard_m",
get_type_string(x.dtype()),
n,
m,
read_width_m);
}
return kernel;
});
// Launch the strided transform for n1
if (n1 > 1) {
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel("n1" + kname, lib);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(x, 0);
compute_encoder.set_output_array(y, 1);
compute_encoder.set_bytes(scale_n1, 2);
compute_encoder.dispatch_threads(grid_dims_n1, group_dims_n1);
}
// Launch the transform for n2
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel("n2" + kname, lib);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(n1 > 1 ? y : x, 0);
compute_encoder.set_output_array(y, 1);
compute_encoder.set_bytes(scale_n2, 2);
compute_encoder.dispatch_threads(grid_dims_n2, group_dims_n2);
// Launch the strided transform for m
if (m > 1) {
auto kernel = d.get_kernel("m" + kname, lib);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(y, 0);
compute_encoder.set_output_array(y, 1);
compute_encoder.set_bytes(scale_m, 2);
compute_encoder.dispatch_threads(grid_dims_m, group_dims_m);
}
}
void Hadamard::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& s = stream();
auto& d = metal::device(s.device);
auto& in = inputs[0];
// Split the hadamard transform so that all of them work on vectors smaller
// than 8192 elements.
//
// We decompose it in the following way:
//
// n = m * n1 * n2 = m * 2^k1 * 2^k2
//
// where m is in (1, 12, 20, 28) and n1 and n2 <= 8192
auto [n, m] = decompose_hadamard(in.shape().back());
int n1 = 1, n2 = n;
if (n > 8192) {
for (n2 = 2; n2 * n2 < n; n2 *= 2) {
std::vector<array> copies;
// Only support the last axis for now
int axis = in.ndim() - 1;
auto check_input = [&copies, &s](const array& x) {
// TODO(alexbarron) pass strides to kernel to relax this constraint
bool no_copy = x.flags().row_contiguous;
if (no_copy) {
return x;
} else {
copies.push_back(array(x.shape(), x.dtype(), nullptr, {}));
copy_gpu(x, copies.back(), CopyType::General, s);
return copies.back();
}
n1 = n / n2;
};
const array& in_contiguous = check_input(in);
if (in_contiguous.is_donatable()) {
out.copy_shared_buffer(in_contiguous);
} else {
out.set_data(allocator::malloc(out.nbytes()));
}
if (in.flags().row_contiguous) {
if (in.is_donatable()) {
out.copy_shared_buffer(in);
} else {
out.set_data(allocator::malloc(out.nbytes()));
}
hadamard_mn_contiguous(in, out, m, n1, n2, scale_, d, s);
} else {
copy_gpu(in, out, CopyType::General, s);
hadamard_mn_contiguous(out, out, m, n1, n2, scale_, d, s);
int n, m;
std::tie(n, m) = decompose_hadamard(in.shape(axis));
if (n * (int)size_of(in.dtype()) > MAX_HADAMARD_BYTES) {
throw std::invalid_argument(
"[hadamard] For n = m*2^k, 2^k > 8192 for FP32 or 2^k > 16384 for FP16/BF16 NYI");
}
int max_radix = std::min(n, 16);
// Use read_width 2 for m = 28 to avoid register spilling
int read_width = (n == 2 || m == 28) ? 2 : 4;
std::ostringstream kname;
kname << "hadamard_" << n * m << "_" << type_to_name(out);
auto kernel_name = kname.str();
auto& d = metal::device(s.device);
const auto& lib_name = kernel_name;
auto lib = d.get_library(lib_name, [&]() {
std::ostringstream kernel_source;
auto codelet = gen_hadamard_codelet(m);
kernel_source << metal::utils() << codelet << metal::hadamard();
kernel_source << get_template_definition(
"n" + kernel_name,
"hadamard_n",
get_type_string(in.dtype()),
n,
max_radix,
read_width);
kernel_source << get_template_definition(
"m" + kernel_name,
"hadamard_m",
get_type_string(in.dtype()),
n,
m,
read_width);
return kernel_source.str();
});
int batch_size = in.size() / n;
int threads_per = n / max_radix;
auto& compute_encoder = d.get_command_encoder(s.index);
auto launch_hadamard = [&](const array& in,
array& out,
const std::string& kernel_name,
float scale) {
auto kernel = d.get_kernel(kernel_name, lib);
assert(threads_per <= kernel->maxTotalThreadsPerThreadgroup());
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(out, 1);
compute_encoder.set_bytes(scale, 2);
MTL::Size group_dims = MTL::Size(1, threads_per, 1);
MTL::Size grid_dims = MTL::Size(batch_size, threads_per, 1);
compute_encoder.dispatch_threads(grid_dims, group_dims);
};
if (m > 1) {
// When m is greater than 1, we decompose the
// computation into two uploads to the GPU:
//
// e.g. len(x) = 12*4 = 48, m = 12, n = 4
//
// y = h48 @ x
//
// Upload 1:
// tmp = a.reshape(12, 4) @ h4
//
// Upload 2:
// y = h12 @ tmp
array temp(in.shape(), in.dtype(), nullptr, {});
temp.set_data(allocator::malloc(temp.nbytes()));
copies.push_back(temp);
launch_hadamard(in_contiguous, temp, "n" + kernel_name, 1.0);
// Metal sometimes reports 256 max threads per group for hadamard_m kernel
threads_per = std::min(n / read_width, MAX_HADAMARD_THREADS_PER_GROUP);
batch_size = in.size() / m / read_width / threads_per;
launch_hadamard(temp, out, "m" + kernel_name, scale_);
} else {
launch_hadamard(in_contiguous, out, "n" + kernel_name, scale_);
}
d.add_temporaries(std::move(copies), s.index);
}
} // namespace mlx::core

View File

@@ -2,7 +2,7 @@
#include <fmt/format.h>
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/jit/includes.h"
#include "mlx/backend/metal/jit/indexing.h"

View File

@@ -752,43 +752,4 @@ MTL::ComputePipelineState* get_quantized_kernel(
return d.get_kernel(kernel_name, lib);
}
MTL::ComputePipelineState* get_gather_qmm_kernel(
metal::Device& d,
const std::string& kernel_name,
const std::string& hash_name,
const metal::MTLFCList& func_consts,
const array& x,
int group_size,
int bits,
int bm,
int bn,
int bk,
int wm,
int wn,
bool transpose) {
const auto& lib_name = kernel_name;
auto lib = d.get_library(lib_name, [&]() {
std::string kernel_source;
concatenate(
kernel_source,
metal::utils(),
metal::gemm(),
metal::quantized(),
get_template_definition(
lib_name,
"gather_qmm_rhs",
get_type_string(x.dtype()),
group_size,
bits,
bm,
bn,
bk,
wm,
wn,
transpose));
return kernel_source;
});
return d.get_kernel(kernel_name, lib, hash_name, func_consts);
}
} // namespace mlx::core

View File

@@ -224,21 +224,6 @@ MTL::ComputePipelineState* get_quantized_kernel(
const std::string& kernel_name,
const std::string& template_def);
MTL::ComputePipelineState* get_gather_qmm_kernel(
metal::Device& d,
const std::string& kernel_name,
const std::string& hash_name,
const metal::MTLFCList& func_consts,
const array& x,
int group_size,
int bits,
int bm,
int bn,
int bk,
int wm,
int wn,
bool transpose);
// Create a GPU kernel template definition for JIT compilation
template <typename... Args>
std::string

View File

@@ -9,85 +9,64 @@ template <typename T, typename U, typename Op>
c[index] = Op()(a[0], b[0]);
}
template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
template <typename T, typename U, typename Op>
[[kernel]] void binary_sv(
device const T* a,
device const T* b,
device U* c,
constant uint& size,
uint index [[thread_position_in_grid]]) {
index *= N;
for (int i = 0; i < N && (index + i) < size; ++i) {
c[index + i] = Op()(a[0], b[index + i]);
}
c[index] = Op()(a[0], b[index]);
}
template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
template <typename T, typename U, typename Op>
[[kernel]] void binary_vs(
device const T* a,
device const T* b,
device U* c,
constant uint& size,
uint index [[thread_position_in_grid]]) {
index *= N;
for (int i = 0; i < N && (index + i) < size; ++i) {
c[index + i] = Op()(a[index + i], b[0]);
}
c[index] = Op()(a[index], b[0]);
}
template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
template <typename T, typename U, typename Op>
[[kernel]] void binary_vv(
device const T* a,
device const T* b,
device U* c,
constant uint& size,
uint index [[thread_position_in_grid]]) {
index *= N;
for (int i = 0; i < N && (index + i) < size; ++i) {
c[index + i] = Op()(a[index + i], b[index + i]);
}
c[index] = Op()(a[index], b[index]);
}
template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
template <typename T, typename U, typename Op>
[[kernel]] void binary_sv2(
device const T* a,
device const T* b,
device U* c,
constant int64_t& size,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
for (int i = 0; i < N && (offset + i) < size; ++i) {
c[offset + i] = Op()(a[0], b[offset + i]);
}
int64_t offset = index.x + grid_dim.x * int64_t(index.y);
c[offset] = Op()(a[0], b[offset]);
}
template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
template <typename T, typename U, typename Op>
[[kernel]] void binary_vs2(
device const T* a,
device const T* b,
device U* c,
constant int64_t& size,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
for (int i = 0; i < N && (offset + i) < size; ++i) {
c[offset + i] = Op()(a[offset + i], b[0]);
}
int64_t offset = index.x + grid_dim.x * int64_t(index.y);
c[offset] = Op()(a[offset], b[0]);
}
template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
template <typename T, typename U, typename Op>
[[kernel]] void binary_vv2(
device const T* a,
device const T* b,
device U* c,
constant int64_t& size,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
for (int i = 0; i < N && (offset + i) < size; ++i) {
c[offset + i] = Op()(a[offset + i], b[offset + i]);
}
int64_t offset = index.x + grid_dim.x * int64_t(index.y);
c[offset] = Op()(a[offset], b[offset]);
}
template <typename T, typename U, typename Op, typename IdxT = int64_t>

View File

@@ -71,7 +71,6 @@ instantiate_binary_types_bool(Less)
instantiate_binary_types_bool(LessEqual)
instantiate_binary_types_bool(NotEqual)
instantiate_binary_float(LogAddExp)
instantiate_binary_all(LogAddExp, complex64, complex64_t, complex64_t)
instantiate_binary_types(Maximum)
instantiate_binary_types(Minimum)
instantiate_binary_types(Multiply)

View File

@@ -130,24 +130,6 @@ struct LogAddExp {
? maxval
: (maxval + log1p(metal::exp(minval - maxval)));
};
complex64_t operator()(complex64_t x, complex64_t y) {
if (metal::isnan(x.real) || metal::isnan(x.imag) || metal::isnan(y.real) ||
metal::isnan(y.imag)) {
return metal::numeric_limits<float>::quiet_NaN();
}
constexpr float inf = metal::numeric_limits<float>::infinity();
complex64_t maxval = x > y ? x : y;
complex64_t minval = x < y ? x : y;
if (minval.real == -inf || maxval.real == inf)
return maxval;
float m = metal::exp(minval.real - maxval.real);
complex64_t dexp{
m * metal::cos(minval.imag - maxval.imag),
m * metal::sin(minval.imag - maxval.imag),
};
return maxval + log1p(dexp);
}
};
struct Maximum {

View File

@@ -12,103 +12,82 @@ template <typename T, typename U, typename Op>
d[index] = out[1];
}
template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
template <typename T, typename U, typename Op>
[[kernel]] void binary_sv(
device const T* a,
device const T* b,
device U* c,
device U* d,
constant uint& size,
uint index [[thread_position_in_grid]]) {
index *= N;
for (int i = 0; i < N && (index + i) < size; ++i) {
auto out = Op()(a[0], b[index + i]);
c[index + i] = out[0];
d[index + i] = out[1];
}
auto out = Op()(a[0], b[index]);
c[index] = out[0];
d[index] = out[1];
}
template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
template <typename T, typename U, typename Op>
[[kernel]] void binary_vs(
device const T* a,
device const T* b,
device U* c,
device U* d,
constant uint& size,
uint index [[thread_position_in_grid]]) {
index *= N;
for (int i = 0; i < N && (index + i) < size; ++i) {
auto out = Op()(a[index + i], b[0]);
c[index + i] = out[0];
d[index + i] = out[1];
}
auto out = Op()(a[index], b[0]);
c[index] = out[0];
d[index] = out[1];
}
template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
template <typename T, typename U, typename Op>
[[kernel]] void binary_vv(
device const T* a,
device const T* b,
device U* c,
device U* d,
constant uint& size,
uint index [[thread_position_in_grid]]) {
index *= N;
for (int i = 0; i < N && (index + i) < size; ++i) {
auto out = Op()(a[index + i], b[index + i]);
c[index + i] = out[0];
d[index + i] = out[1];
}
auto out = Op()(a[index], b[index]);
c[index] = out[0];
d[index] = out[1];
}
template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
template <typename T, typename U, typename Op>
[[kernel]] void binary_sv2(
device const T* a,
device const T* b,
device U* c,
device U* d,
constant int64_t& size,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto offset = N * (index.x + grid_dim.x * int64_t(index.y));
for (int i = 0; i < N && (offset + i) < size; ++i) {
auto out = Op()(a[0], b[offset + i]);
c[offset + i] = out[0];
d[offset + i] = out[1];
}
auto offset = index.x + grid_dim.x * int64_t(index.y);
auto out = Op()(a[0], b[offset]);
c[offset] = out[0];
d[offset] = out[1];
}
template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
template <typename T, typename U, typename Op>
[[kernel]] void binary_vs2(
device const T* a,
device const T* b,
device U* c,
device U* d,
constant int64_t& size,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto offset = N * (index.x + grid_dim.x * int64_t(index.y));
for (int i = 0; i < N && (offset + i) < size; ++i) {
auto out = Op()(a[offset + i], b[0]);
c[offset + i] = out[0];
d[offset + i] = out[1];
}
auto offset = index.x + grid_dim.x * int64_t(index.y);
auto out = Op()(a[offset], b[0]);
c[offset] = out[0];
d[offset] = out[1];
}
template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
template <typename T, typename U, typename Op>
[[kernel]] void binary_vv2(
device const T* a,
device const T* b,
device U* c,
device U* d,
constant int64_t& size,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto offset = N * (index.x + grid_dim.x * int64_t(index.y));
for (int i = 0; i < N && (offset + i) < size; ++i) {
auto out = Op()(a[offset + i], b[offset + i]);
c[offset + i] = out[0];
d[offset + i] = out[1];
}
auto offset = index.x + grid_dim.x * int64_t(index.y);
auto out = Op()(a[offset], b[offset]);
c[offset] = out[0];
d[offset] = out[1];
}
template <typename T, typename U, typename Op, typename IdxT = int64_t>

View File

@@ -104,22 +104,10 @@ constexpr bool operator==(complex64_t a, complex64_t b) {
constexpr complex64_t operator+(complex64_t a, complex64_t b) {
return {a.real + b.real, a.imag + b.imag};
}
constexpr complex64_t operator+(float a, complex64_t b) {
return {a + b.real, b.imag};
}
constexpr complex64_t operator+(complex64_t a, float b) {
return {a.real + b, a.imag};
}
constexpr complex64_t operator-(complex64_t a, complex64_t b) {
return {a.real - b.real, a.imag - b.imag};
}
constexpr complex64_t operator-(float a, complex64_t b) {
return {a - b.real, -b.imag};
}
constexpr complex64_t operator-(complex64_t a, float b) {
return {a.real - b, a.imag};
}
constexpr complex64_t operator*(complex64_t a, complex64_t b) {
return {a.real * b.real - a.imag * b.imag, a.real * b.imag + a.imag * b.real};
@@ -132,13 +120,6 @@ constexpr complex64_t operator/(complex64_t a, complex64_t b) {
return {x / denom, y / denom};
}
constexpr complex64_t operator/(float a, complex64_t b) {
auto denom = b.real * b.real + b.imag * b.imag;
auto x = a * b.real;
auto y = -a * b.imag;
return {x / denom, y / denom};
}
constexpr complex64_t operator%(complex64_t a, complex64_t b) {
auto real = a.real - (b.real * static_cast<int64_t>(a.real / b.real));
auto imag = a.imag - (b.imag * static_cast<int64_t>(a.imag / b.imag));

View File

@@ -1,53 +1,39 @@
// Copyright © 2024 Apple Inc.
template <typename T, typename U, int N = WorkPerThread<T>::n>
template <typename T, typename U>
[[kernel]] void copy_s(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant uint& size,
uint index [[thread_position_in_grid]]) {
index *= N;
for (int i = 0; i < N && (index + i) < size; ++i) {
dst[index + i] = static_cast<U>(src[0]);
}
dst[index] = static_cast<U>(src[0]);
}
template <typename T, typename U, int N = WorkPerThread<T>::n>
template <typename T, typename U>
[[kernel]] void copy_v(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant uint& size,
uint index [[thread_position_in_grid]]) {
index *= N;
for (int i = 0; i < N && (index + i) < size; ++i) {
dst[index + i] = static_cast<U>(src[index + i]);
}
dst[index] = static_cast<U>(src[index]);
}
template <typename T, typename U, int N = WorkPerThread<T>::n>
template <typename T, typename U>
[[kernel]] void copy_s2(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant int64_t& size,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto offset = N * (index.x + grid_dim.x * int64_t(index.y));
for (int i = 0; i < N && (offset + i) < size; ++i) {
dst[offset + i] = static_cast<U>(src[0]);
}
auto offset = index.x + grid_dim.x * int64_t(index.y);
dst[offset] = static_cast<U>(src[0]);
}
template <typename T, typename U, int N = WorkPerThread<T>::n>
template <typename T, typename U>
[[kernel]] void copy_v2(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant int64_t& size,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto offset = N * (index.x + grid_dim.x * int64_t(index.y));
for (int i = 0; i < N && (offset + i) < size; ++i) {
dst[offset + i] = static_cast<U>(src[offset + i]);
}
auto offset = index.x + grid_dim.x * int64_t(index.y);
dst[offset] = static_cast<U>(src[offset]);
}
template <typename T, typename U, typename IdxT = int64_t>

View File

@@ -10,7 +10,7 @@ For many sizes, GPU FFTs are memory bandwidth bound so
read/write performance is important.
Where possible, we read 128 bits sequentially in each thread,
coalesced with accesses from adjacent threads for optimal performance.
coalesced with accesses from adajcent threads for optimal performance.
We implement specialized reading/writing for:
- FFT
@@ -98,7 +98,7 @@ struct ReadWriter {
}
METAL_FUNC void load() const {
size_t batch_idx = size_t(elem.x * grid.y) * n;
int batch_idx = elem.x * grid.y * n;
short tg_idx = elem.y * grid.z + elem.z;
short max_index = grid.y * n - 2;
@@ -121,7 +121,7 @@ struct ReadWriter {
}
METAL_FUNC void write() const {
size_t batch_idx = size_t(elem.x * grid.y) * n;
int batch_idx = elem.x * grid.y * n;
short tg_idx = elem.y * grid.z + elem.z;
short max_index = grid.y * n - 2;
@@ -144,7 +144,7 @@ struct ReadWriter {
// Padded IO for Bluestein's algorithm
METAL_FUNC void load_padded(int length, const device float2* w_k) const {
size_t batch_idx = size_t(elem.x * grid.y) * length + elem.y * length;
int batch_idx = elem.x * grid.y * length + elem.y * length;
int fft_idx = elem.z;
int m = grid.z;
@@ -161,7 +161,7 @@ struct ReadWriter {
}
METAL_FUNC void write_padded(int length, const device float2* w_k) const {
size_t batch_idx = size_t(elem.x * grid.y) * length + elem.y * length;
int batch_idx = elem.x * grid.y * length + elem.y * length;
int fft_idx = elem.z;
int m = grid.z;
float2 inv_factor = {1.0f / n, -1.0f / n};
@@ -261,7 +261,7 @@ METAL_FUNC bool ReadWriter<float, float2>::out_of_bounds() const {
template <>
METAL_FUNC void ReadWriter<float, float2>::load() const {
size_t batch_idx = size_t(elem.x * grid.y) * n * 2 + elem.y * n * 2;
int batch_idx = elem.x * grid.y * n * 2 + elem.y * n * 2;
threadgroup float2* seq_buf = buf + elem.y * n;
// No out of bounds accesses on odd batch sizes
@@ -283,8 +283,7 @@ template <>
METAL_FUNC void ReadWriter<float, float2>::write() const {
short n_over_2 = (n / 2) + 1;
size_t batch_idx =
size_t(elem.x * grid.y) * n_over_2 * 2 + elem.y * n_over_2 * 2;
int batch_idx = elem.x * grid.y * n_over_2 * 2 + elem.y * n_over_2 * 2;
threadgroup float2* seq_buf = buf + elem.y * n;
int grid_index = elem.x * grid.y + elem.y;
@@ -318,7 +317,7 @@ template <>
METAL_FUNC void ReadWriter<float, float2>::load_padded(
int length,
const device float2* w_k) const {
size_t batch_idx = size_t(elem.x * grid.y) * length * 2 + elem.y * length * 2;
int batch_idx = elem.x * grid.y * length * 2 + elem.y * length * 2;
threadgroup float2* seq_buf = buf + elem.y * n;
// No out of bounds accesses on odd batch sizes
@@ -346,8 +345,8 @@ METAL_FUNC void ReadWriter<float, float2>::write_padded(
int length,
const device float2* w_k) const {
int length_over_2 = (length / 2) + 1;
size_t batch_idx =
size_t(elem.x * grid.y) * length_over_2 * 2 + elem.y * length_over_2 * 2;
int batch_idx =
elem.x * grid.y * length_over_2 * 2 + elem.y * length_over_2 * 2;
threadgroup float2* seq_buf = buf + elem.y * n + length - 1;
int grid_index = elem.x * grid.y + elem.y;
@@ -398,8 +397,7 @@ METAL_FUNC bool ReadWriter<float2, float>::out_of_bounds() const {
template <>
METAL_FUNC void ReadWriter<float2, float>::load() const {
short n_over_2 = (n / 2) + 1;
size_t batch_idx =
size_t(elem.x * grid.y) * n_over_2 * 2 + elem.y * n_over_2 * 2;
int batch_idx = elem.x * grid.y * n_over_2 * 2 + elem.y * n_over_2 * 2;
threadgroup float2* seq_buf = buf + elem.y * n;
// No out of bounds accesses on odd batch sizes
@@ -460,8 +458,8 @@ METAL_FUNC void ReadWriter<float2, float>::load_padded(
int n_over_2 = (n / 2) + 1;
int length_over_2 = (length / 2) + 1;
size_t batch_idx =
size_t(elem.x * grid.y) * length_over_2 * 2 + elem.y * length_over_2 * 2;
int batch_idx =
elem.x * grid.y * length_over_2 * 2 + elem.y * length_over_2 * 2;
threadgroup float2* seq_buf = buf + elem.y * n;
// No out of bounds accesses on odd batch sizes
@@ -505,7 +503,7 @@ template <>
METAL_FUNC void ReadWriter<float2, float>::write_padded(
int length,
const device float2* w_k) const {
size_t batch_idx = size_t(elem.x * grid.y) * length * 2 + elem.y * length * 2;
int batch_idx = elem.x * grid.y * length * 2 + elem.y * length * 2;
threadgroup float2* seq_buf = buf + elem.y * n + length - 1;
int grid_index = elem.x * grid.y + elem.y;

View File

@@ -26,7 +26,7 @@ METAL_FUNC void radix_func(thread float* x) {
}
}
template <typename T, int N, int max_radix, int read_width, int stride = 1>
template <typename T, int N, int max_radix, int read_width>
[[kernel]] void hadamard_n(
const device T* in [[buffer(0)]],
device T* out [[buffer(1)]],
@@ -46,25 +46,18 @@ template <typename T, int N, int max_radix, int read_width, int stride = 1>
constexpr short logFinal = logN % logR;
constexpr short final_radix = 1 << (logFinal);
int batch_idx = elem.y * N * stride + elem.z;
short i = elem.x;
int batch_idx = elem.x * N;
short i = elem.y;
threadgroup T buf[N];
// Read values from device
if (stride == 1) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < max_radix / read_width; j++) {
short index = j * read_width * num_threads + i * read_width;
STEEL_PRAGMA_UNROLL
for (short j = 0; j < max_radix / read_width; j++) {
short index = j * read_width * num_threads + i * read_width;
STEEL_PRAGMA_UNROLL
for (short r = 0; r < read_width; r++) {
buf[index + r] = in[batch_idx + index + r];
}
}
} else {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < max_radix; j++) {
buf[j * num_threads + i] = in[batch_idx + (j * num_threads + i) * stride];
for (short r = 0; r < read_width; r++) {
buf[index + r] = in[batch_idx + index + r];
}
}
@@ -120,20 +113,12 @@ template <typename T, int N, int max_radix, int read_width, int stride = 1>
}
// Write values to device
if (stride == 1) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < max_radix / read_width; j++) {
short index = j * read_width * num_threads + i * read_width;
STEEL_PRAGMA_UNROLL
for (short j = 0; j < max_radix / read_width; j++) {
short index = j * read_width * num_threads + i * read_width;
STEEL_PRAGMA_UNROLL
for (short r = 0; r < read_width; r++) {
out[batch_idx + index + r] = T(buf[index + r] * scale);
}
}
} else {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < max_radix; j++) {
out[batch_idx + (j * num_threads + i) * stride] =
buf[j * num_threads + i];
for (short r = 0; r < read_width; r++) {
out[batch_idx + index + r] = T(buf[index + r] * scale);
}
}
}

View File

@@ -3,10 +3,6 @@
#include <metal_simdgroup>
#include <metal_stdlib>
constant bool align_M [[function_constant(200)]];
constant bool align_N [[function_constant(201)]];
constant bool align_K [[function_constant(202)]];
using namespace metal;
#define MLX_MTL_CONST static constant constexpr const
@@ -1008,11 +1004,11 @@ METAL_FUNC void qmm_t_impl(
auto wl = (const device uint8_t*)w;
x += y_row * static_cast<int64_t>(K);
x += y_row * K;
wl += y_col * K_w;
scales += y_col * K_g;
biases += y_col * K_g;
y += y_row * static_cast<int64_t>(N) + y_col;
y += y_row * N + y_col;
// Make the x loader and mma operation
const short num_els = min(BM, M - y_row);
@@ -1132,11 +1128,11 @@ METAL_FUNC void qmm_n_impl(
// Set the block
const int y_row = tid.y * BM;
const int y_col = tid.x * BN;
x += y_row * static_cast<int64_t>(K);
x += y_row * K;
wl += y_col * bytes_per_pack / pack_factor;
scales += y_col / group_size;
biases += y_col / group_size;
y += y_row * static_cast<int64_t>(N) + y_col;
y += y_row * N + y_col;
// Make the x loader and mma operation
const short num_els = min(BM, M - y_row);
@@ -1690,26 +1686,26 @@ template <
}
template <typename T, int group_size, int bits>
[[kernel]] void gather_qmv_fast(
[[kernel]] void bs_qmv_fast(
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]],
const device T* x [[buffer(3)]],
const device uint32_t* lhs_indices [[buffer(4)]],
const device uint32_t* rhs_indices [[buffer(5)]],
device T* y [[buffer(6)]],
const constant int& in_vec_size [[buffer(7)]],
const constant int& out_vec_size [[buffer(8)]],
const constant int& x_batch_ndims [[buffer(9)]],
const constant int* x_shape [[buffer(10)]],
const constant int64_t* x_strides [[buffer(11)]],
const constant int& w_batch_ndims [[buffer(12)]],
const constant int* w_shape [[buffer(13)]],
const constant int64_t* w_strides [[buffer(14)]],
const constant int64_t* s_strides [[buffer(15)]],
const constant int64_t* b_strides [[buffer(16)]],
const constant int& batch_ndims [[buffer(17)]],
const constant int* batch_shape [[buffer(18)]],
device T* y [[buffer(4)]],
const constant int& in_vec_size [[buffer(5)]],
const constant int& out_vec_size [[buffer(6)]],
const constant int& x_batch_ndims [[buffer(7)]],
const constant int* x_shape [[buffer(8)]],
const constant int64_t* x_strides [[buffer(9)]],
const constant int& w_batch_ndims [[buffer(10)]],
const constant int* w_shape [[buffer(11)]],
const constant int64_t* w_strides [[buffer(12)]],
const constant int64_t* s_strides [[buffer(13)]],
const constant int64_t* b_strides [[buffer(14)]],
const constant int& batch_ndims [[buffer(15)]],
const constant int* batch_shape [[buffer(16)]],
const device uint32_t* lhs_indices [[buffer(17)]],
const device uint32_t* rhs_indices [[buffer(18)]],
const constant int64_t* lhs_strides [[buffer(19)]],
const constant int64_t* rhs_strides [[buffer(20)]],
uint3 tid [[threadgroup_position_in_grid]],
@@ -1752,26 +1748,26 @@ template <typename T, int group_size, int bits>
}
template <typename T, int group_size, int bits>
[[kernel]] void gather_qmv(
[[kernel]] void bs_qmv(
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]],
const device T* x [[buffer(3)]],
const device uint32_t* lhs_indices [[buffer(4)]],
const device uint32_t* rhs_indices [[buffer(5)]],
device T* y [[buffer(6)]],
const constant int& in_vec_size [[buffer(7)]],
const constant int& out_vec_size [[buffer(8)]],
const constant int& x_batch_ndims [[buffer(9)]],
const constant int* x_shape [[buffer(10)]],
const constant int64_t* x_strides [[buffer(11)]],
const constant int& w_batch_ndims [[buffer(12)]],
const constant int* w_shape [[buffer(13)]],
const constant int64_t* w_strides [[buffer(14)]],
const constant int64_t* s_strides [[buffer(15)]],
const constant int64_t* b_strides [[buffer(16)]],
const constant int& batch_ndims [[buffer(17)]],
const constant int* batch_shape [[buffer(18)]],
device T* y [[buffer(4)]],
const constant int& in_vec_size [[buffer(5)]],
const constant int& out_vec_size [[buffer(6)]],
const constant int& x_batch_ndims [[buffer(7)]],
const constant int* x_shape [[buffer(8)]],
const constant int64_t* x_strides [[buffer(9)]],
const constant int& w_batch_ndims [[buffer(10)]],
const constant int* w_shape [[buffer(11)]],
const constant int64_t* w_strides [[buffer(12)]],
const constant int64_t* s_strides [[buffer(13)]],
const constant int64_t* b_strides [[buffer(14)]],
const constant int& batch_ndims [[buffer(15)]],
const constant int* batch_shape [[buffer(16)]],
const device uint32_t* lhs_indices [[buffer(17)]],
const device uint32_t* rhs_indices [[buffer(18)]],
const constant int64_t* lhs_strides [[buffer(19)]],
const constant int64_t* rhs_strides [[buffer(20)]],
uint3 tid [[threadgroup_position_in_grid]],
@@ -1814,26 +1810,26 @@ template <typename T, int group_size, int bits>
}
template <typename T, int group_size, int bits>
[[kernel]] void gather_qvm(
[[kernel]] void bs_qvm(
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]],
const device T* x [[buffer(3)]],
const device uint32_t* lhs_indices [[buffer(4)]],
const device uint32_t* rhs_indices [[buffer(5)]],
device T* y [[buffer(6)]],
const constant int& in_vec_size [[buffer(7)]],
const constant int& out_vec_size [[buffer(8)]],
const constant int& x_batch_ndims [[buffer(9)]],
const constant int* x_shape [[buffer(10)]],
const constant int64_t* x_strides [[buffer(11)]],
const constant int& w_batch_ndims [[buffer(12)]],
const constant int* w_shape [[buffer(13)]],
const constant int64_t* w_strides [[buffer(14)]],
const constant int64_t* s_strides [[buffer(15)]],
const constant int64_t* b_strides [[buffer(16)]],
const constant int& batch_ndims [[buffer(17)]],
const constant int* batch_shape [[buffer(18)]],
device T* y [[buffer(4)]],
const constant int& in_vec_size [[buffer(5)]],
const constant int& out_vec_size [[buffer(6)]],
const constant int& x_batch_ndims [[buffer(7)]],
const constant int* x_shape [[buffer(8)]],
const constant int64_t* x_strides [[buffer(9)]],
const constant int& w_batch_ndims [[buffer(10)]],
const constant int* w_shape [[buffer(11)]],
const constant int64_t* w_strides [[buffer(12)]],
const constant int64_t* s_strides [[buffer(13)]],
const constant int64_t* b_strides [[buffer(14)]],
const constant int& batch_ndims [[buffer(15)]],
const constant int* batch_shape [[buffer(16)]],
const device uint32_t* lhs_indices [[buffer(17)]],
const device uint32_t* rhs_indices [[buffer(18)]],
const constant int64_t* lhs_strides [[buffer(19)]],
const constant int64_t* rhs_strides [[buffer(20)]],
uint3 tid [[threadgroup_position_in_grid]],
@@ -1883,27 +1879,27 @@ template <
const int BM = 32,
const int BK = 32,
const int BN = 32>
[[kernel]] void gather_qmm_t(
[[kernel]] void bs_qmm_t(
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]],
const device T* x [[buffer(3)]],
const device uint32_t* lhs_indices [[buffer(4)]],
const device uint32_t* rhs_indices [[buffer(5)]],
device T* y [[buffer(6)]],
const constant int& K [[buffer(7)]],
const constant int& N [[buffer(8)]],
const constant int& M [[buffer(9)]],
const constant int& x_batch_ndims [[buffer(10)]],
const constant int* x_shape [[buffer(11)]],
const constant int64_t* x_strides [[buffer(12)]],
const constant int& w_batch_ndims [[buffer(13)]],
const constant int* w_shape [[buffer(14)]],
const constant int64_t* w_strides [[buffer(15)]],
const constant int64_t* s_strides [[buffer(16)]],
const constant int64_t* b_strides [[buffer(17)]],
const constant int& batch_ndims [[buffer(18)]],
const constant int* batch_shape [[buffer(19)]],
device T* y [[buffer(4)]],
const constant int& K [[buffer(5)]],
const constant int& N [[buffer(6)]],
const constant int& M [[buffer(7)]],
const constant int& x_batch_ndims [[buffer(8)]],
const constant int* x_shape [[buffer(9)]],
const constant int64_t* x_strides [[buffer(10)]],
const constant int& w_batch_ndims [[buffer(11)]],
const constant int* w_shape [[buffer(12)]],
const constant int64_t* w_strides [[buffer(13)]],
const constant int64_t* s_strides [[buffer(14)]],
const constant int64_t* b_strides [[buffer(15)]],
const constant int& batch_ndims [[buffer(16)]],
const constant int* batch_shape [[buffer(17)]],
const device uint32_t* lhs_indices [[buffer(18)]],
const device uint32_t* rhs_indices [[buffer(19)]],
const constant int64_t* lhs_strides [[buffer(20)]],
const constant int64_t* rhs_strides [[buffer(21)]],
uint3 tid [[threadgroup_position_in_grid]],
@@ -1950,27 +1946,27 @@ template <
const int BM = 32,
const int BK = 32,
const int BN = 32>
[[kernel]] void gather_qmm_n(
[[kernel]] void bs_qmm_n(
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]],
const device T* x [[buffer(3)]],
const device uint32_t* lhs_indices [[buffer(4)]],
const device uint32_t* rhs_indices [[buffer(5)]],
device T* y [[buffer(6)]],
const constant int& K [[buffer(7)]],
const constant int& N [[buffer(8)]],
const constant int& M [[buffer(9)]],
const constant int& x_batch_ndims [[buffer(10)]],
const constant int* x_shape [[buffer(11)]],
const constant int64_t* x_strides [[buffer(12)]],
const constant int& w_batch_ndims [[buffer(13)]],
const constant int* w_shape [[buffer(14)]],
const constant int64_t* w_strides [[buffer(15)]],
const constant int64_t* s_strides [[buffer(16)]],
const constant int64_t* b_strides [[buffer(17)]],
const constant int& batch_ndims [[buffer(18)]],
const constant int* batch_shape [[buffer(19)]],
device T* y [[buffer(4)]],
const constant int& K [[buffer(5)]],
const constant int& N [[buffer(6)]],
const constant int& M [[buffer(7)]],
const constant int& x_batch_ndims [[buffer(8)]],
const constant int* x_shape [[buffer(9)]],
const constant int64_t* x_strides [[buffer(10)]],
const constant int& w_batch_ndims [[buffer(11)]],
const constant int* w_shape [[buffer(12)]],
const constant int64_t* w_strides [[buffer(13)]],
const constant int64_t* s_strides [[buffer(14)]],
const constant int64_t* b_strides [[buffer(15)]],
const constant int& batch_ndims [[buffer(16)]],
const constant int* batch_shape [[buffer(17)]],
const device uint32_t* lhs_indices [[buffer(18)]],
const device uint32_t* rhs_indices [[buffer(19)]],
const constant int64_t* lhs_strides [[buffer(20)]],
const constant int64_t* rhs_strides [[buffer(21)]],
uint3 tid [[threadgroup_position_in_grid]],
@@ -2011,289 +2007,6 @@ template <
w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
}
template <typename T, typename mma_t, typename loader_a_t, typename loader_b_t>
METAL_FUNC void gemm_loop_aligned(
threadgroup T* As,
threadgroup T* Bs,
thread mma_t& mma_op,
thread loader_a_t& loader_a,
thread loader_b_t& loader_b,
const int k_iterations) {
for (int k = 0; k < k_iterations; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup memory
loader_a.load_unsafe();
loader_b.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
}
template <
bool rows_aligned,
bool cols_aligned,
bool transpose,
typename T,
typename mma_t,
typename loader_a_t,
typename loader_b_t>
METAL_FUNC void gemm_loop_unaligned(
threadgroup T* As,
threadgroup T* Bs,
thread mma_t& mma_op,
thread loader_a_t& loader_a,
thread loader_b_t& loader_b,
const int k_iterations,
const short tgp_bm,
const short tgp_bn,
const short tgp_bk) {
for (int k = 0; k < k_iterations; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup memory
if (rows_aligned) {
loader_a.load_unsafe();
} else {
loader_a.load_safe(short2(tgp_bk, tgp_bm));
}
if (cols_aligned) {
loader_b.load_unsafe();
} else {
loader_b.load_safe(
transpose ? short2(tgp_bk, tgp_bn) : short2(tgp_bn, tgp_bk));
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
}
template <typename T, typename mma_t, typename loader_a_t, typename loader_b_t>
METAL_FUNC void gemm_loop_finalize(
threadgroup T* As,
threadgroup T* Bs,
thread mma_t& mma_op,
thread loader_a_t& loader_a,
thread loader_b_t& loader_b,
const short2 tile_a,
const short2 tile_b) {
loader_a.load_safe(tile_a);
loader_b.load_safe(tile_b);
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_op.mma(As, Bs);
}
template <
typename T,
int group_size,
int bits,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose>
[[kernel]] void gather_qmm_rhs(
const device T* x [[buffer(0)]],
const device uint32_t* w [[buffer(1)]],
const device T* scales [[buffer(2)]],
const device T* biases [[buffer(3)]],
const device uint32_t* indices [[buffer(4)]],
device T* y [[buffer(5)]],
const constant int& M [[buffer(6)]],
const constant int& N [[buffer(7)]],
const constant int& K [[buffer(8)]],
uint3 tid [[threadgroup_position_in_grid]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]]) {
constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
constexpr int BK_padded = (BK + 16 / sizeof(T));
constexpr int BN_padded = (BN + 16 / sizeof(T));
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3;
using mma_t = mlx::steel::BlockMMA<
T,
T,
BM,
BN,
BK,
WM,
WN,
false,
transpose,
BK_padded,
transpose ? BK_padded : BN_padded>;
using loader_x_t =
mlx::steel::BlockLoader<T, BM, BK, BK_padded, 1, WM * WN * SIMD_SIZE>;
using loader_w_t = QuantizedBlockLoader<
T,
transpose ? BN : BK,
transpose ? BK : BN,
transpose ? BK_padded : BN_padded,
transpose,
WM * WN * SIMD_SIZE,
group_size,
bits>;
threadgroup T Xs[BM * BK_padded];
threadgroup T Ws[transpose ? BN * BK_padded : BK * BN_padded];
// Compute the block
const int K_w = K * bytes_per_pack / pack_factor;
const int K_g = K / group_size;
const int N_w = N * bytes_per_pack / pack_factor;
const int N_g = N / group_size;
const int K_it = K / BK;
const size_t stride_w = transpose ? N * K_w : K * N_w;
const size_t stride_s = transpose ? N * K_g : K * N_g;
const int y_row = tid.y * BM;
const int y_col = tid.x * BN;
const size_t y_row_long = size_t(y_row);
const size_t y_col_long = size_t(y_col);
// Prepare threadgroup bounds
const short tgp_bm = align_M ? BM : short(min(BM, M - y_row));
const short tgp_bn = align_N ? BN : short(min(BN, N - y_col));
// Calculate the final tiles in the case that K is not aligned
const int k_remain = K - K_it * BK;
const short2 tile_x = short2(k_remain, tgp_bm);
const short2 tile_w =
transpose ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);
// Move x and output to the correct block
auto wl = (const device uint8_t*)w;
x += y_row_long * K;
y += y_row_long * N + y_col_long;
wl += transpose ? y_col_long * K_w : y_col * bytes_per_pack / pack_factor;
scales += transpose ? y_col_long * K_g : y_col / group_size;
biases += transpose ? y_col_long * K_g : y_col / group_size;
// Do as many matmuls as necessary
uint32_t index;
short offset;
uint32_t index_next = indices[y_row];
short offset_next = 0;
int n = 0;
while (n < tgp_bm) {
n++;
offset = offset_next;
index = index_next;
offset_next = tgp_bm;
for (; n < tgp_bm; n++) {
if (indices[y_row + n] != index) {
offset_next = n;
index_next = indices[y_row + n];
break;
}
}
threadgroup_barrier(mem_flags::mem_none);
// Prepare threadgroup mma operation
thread mma_t mma_op(simd_group_id, simd_lane_id);
// Prepare threadgroup loading operations
thread loader_x_t loader_x(x, K, Xs, simd_group_id, simd_lane_id);
thread loader_w_t loader_w(
wl + index * stride_w,
scales + index * stride_s,
biases + index * stride_s,
transpose ? K : N,
Ws,
simd_group_id,
simd_lane_id);
// Matrices are all aligned check nothing
if (align_M && align_N) {
gemm_loop_aligned(Xs, Ws, mma_op, loader_x, loader_w, K_it);
if (!align_K) {
threadgroup_barrier(mem_flags::mem_threadgroup);
gemm_loop_finalize(Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w);
}
// Store results to device memory
if (offset_next - offset == BM) {
mma_op.store_result(y, N);
} else {
mma_op.store_result_slice(
y, N, short2(0, offset), short2(BN, offset_next));
}
} else {
// Tile aligned so check outside of the hot loop
if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) {
gemm_loop_aligned(Xs, Ws, mma_op, loader_x, loader_w, K_it);
if (!align_K) {
threadgroup_barrier(mem_flags::mem_threadgroup);
gemm_loop_finalize(
Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w);
}
// Store results to device memory
if (offset_next - offset == BM) {
mma_op.store_result(y, N);
} else {
mma_op.store_result_slice(
y, N, short2(0, offset), short2(BN, offset_next));
}
}
// Tile partially aligned check rows
else if (align_N || tgp_bn == BN) {
gemm_loop_unaligned<false, true, transpose>(
Xs, Ws, mma_op, loader_x, loader_w, K_it, tgp_bm, tgp_bn, BK);
if (!align_K) {
threadgroup_barrier(mem_flags::mem_threadgroup);
gemm_loop_finalize(
Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w);
}
mma_op.store_result_slice(
y, N, short2(0, offset), short2(BN, offset_next));
}
// Tile partially aligned check cols
else if (align_M || tgp_bm == BM) {
gemm_loop_unaligned<true, false, transpose>(
Xs, Ws, mma_op, loader_x, loader_w, K_it, tgp_bm, tgp_bn, BK);
if (!align_K) {
threadgroup_barrier(mem_flags::mem_threadgroup);
gemm_loop_finalize(
Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w);
}
mma_op.store_result_slice(
y, N, short2(0, offset), short2(tgp_bn, offset_next));
}
// Nothing aligned so check both rows and cols
else {
gemm_loop_unaligned<false, false, transpose>(
Xs, Ws, mma_op, loader_x, loader_w, K_it, tgp_bm, tgp_bn, BK);
if (!align_K) {
threadgroup_barrier(mem_flags::mem_threadgroup);
gemm_loop_finalize(
Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w);
}
mma_op.store_result_slice(
y, N, short2(0, offset), short2(tgp_bn, offset_next));
}
}
}
}
template <typename T, const int group_size, const int bits>
[[kernel]] void affine_quantize(
const device T* w [[buffer(0)]],

View File

@@ -60,20 +60,6 @@
bits, \
split_k)
#define instantiate_gather_qmm_rhs(func, name, type, group_size, bits, bm, bn, bk, wm, wn, transpose) \
instantiate_kernel( \
#name "_" #type "_gs_" #group_size "_b_" #bits "_bm_" #bm "_bn_" #bn "_bk_" #bk "_wm_" #wm "_wn_" #wn, \
func, \
type, \
group_size, \
bits, \
bm, \
bn, \
bk, \
wm, \
wn, \
transpose)
#define instantiate_quantized_batched_wrap(name, type, group_size, bits) \
instantiate_quantized_batched(name, type, group_size, bits, 1) \
instantiate_quantized_batched(name, type, group_size, bits, 0)
@@ -87,14 +73,14 @@
#define instantiate_quantized_all_single(type, group_size, bits) \
instantiate_quantized(affine_quantize, type, group_size, bits) \
instantiate_quantized(affine_dequantize, type, group_size, bits) \
instantiate_quantized(gather_qmv_fast, type, group_size, bits) \
instantiate_quantized(gather_qmv, type, group_size, bits) \
instantiate_quantized(gather_qvm, type, group_size, bits) \
instantiate_quantized(gather_qmm_n, type, group_size, bits)
instantiate_quantized(bs_qmv_fast, type, group_size, bits) \
instantiate_quantized(bs_qmv, type, group_size, bits) \
instantiate_quantized(bs_qvm, type, group_size, bits) \
instantiate_quantized(bs_qmm_n, type, group_size, bits)
#define instantiate_quantized_all_aligned(type, group_size, bits) \
instantiate_quantized_aligned(gather_qmm_t, type, group_size, bits, true) \
instantiate_quantized_aligned(gather_qmm_t, type, group_size, bits, false) \
instantiate_quantized_aligned(bs_qmm_t, type, group_size, bits, true) \
instantiate_quantized_aligned(bs_qmm_t, type, group_size, bits, false) \
instantiate_quantized_aligned_batched(qmm_t, type, group_size, bits, true, 1) \
instantiate_quantized_aligned_batched(qmm_t, type, group_size, bits, true, 0) \
instantiate_quantized_aligned_batched(qmm_t, type, group_size, bits, false, 1) \
@@ -110,17 +96,12 @@
instantiate_quantized_split_k(qvm_split_k, type, group_size, bits, 8) \
instantiate_quantized_split_k(qvm_split_k, type, group_size, bits, 32)
#define instantiate_quantized_all_rhs(type, group_size, bits) \
instantiate_gather_qmm_rhs(gather_qmm_rhs, gather_qmm_rhs_nt, type, group_size, bits, 16, 32, 32, 1, 2, true) \
instantiate_gather_qmm_rhs(gather_qmm_rhs, gather_qmm_rhs_nn, type, group_size, bits, 16, 32, 32, 1, 2, false)
#define instantiate_quantized_funcs(type, group_size, bits) \
instantiate_quantized_all_single(type, group_size, bits) \
instantiate_quantized_all_batched(type, group_size, bits) \
instantiate_quantized_all_aligned(type, group_size, bits) \
instantiate_quantized_all_quad(type, group_size, bits) \
instantiate_quantized_all_splitk(type, group_size, bits) \
instantiate_quantized_all_rhs(type, group_size, bits)
instantiate_quantized_all_splitk(type, group_size, bits)
#define instantiate_quantized_types(group_size, bits) \
instantiate_quantized_funcs(float, group_size, bits) \

View File

@@ -104,5 +104,4 @@ instantiate_scan_helper(min_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumMi
instantiate_scan_helper(min_complex64_complex64, complex64_t, complex64_t, CumMin, 2)
instantiate_scan_helper(logaddexp_float16_float16, half, half, CumLogaddexp, 4)
instantiate_scan_helper(logaddexp_float32_float32, float, float, CumLogaddexp, 4)
instantiate_scan_helper(logaddexp_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumLogaddexp, 4)
instantiate_scan_helper(logaddexp_complex64_complex64, complex64_t, complex64_t, CumLogaddexp, 2) // clang-format on
instantiate_scan_helper(logaddexp_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumLogaddexp, 4) // clang-format on

View File

@@ -56,9 +56,9 @@ template <typename T, int D, int V = D>
const int head_idx = tid.x;
const int q_seq_idx = tid.y;
const int kv_head_idx = head_idx / gqa_factor;
const int o_offset = head_idx * tpg.y + q_seq_idx;
const int o_offset = tpg.x * q_seq_idx + head_idx;
const int q_offset =
query_transposed ? tpg.x * q_seq_idx + head_idx : o_offset;
query_transposed ? o_offset : head_idx * tpg.y + q_seq_idx;
queries += q_offset * D + simd_lid * qk_per_thread;
keys += kv_head_idx * k_head_stride + simd_gid * k_seq_stride +
simd_lid * qk_per_thread;
@@ -213,9 +213,9 @@ template <typename T, int D, int V = D>
const int block_idx = tid.z;
const int head_idx = tid.x;
const int q_seq_idx = tid.y;
const int o_offset = head_idx * tpg.y + q_seq_idx;
const int o_offset = tpg.x * q_seq_idx + head_idx;
const int q_offset =
query_transposed ? tpg.x * q_seq_idx + head_idx : o_offset;
query_transposed ? o_offset : head_idx * tpg.y + q_seq_idx;
const int kv_head_idx = head_idx / gqa_factor;
queries += q_offset * D + simd_lid * qk_per_thread;
@@ -358,8 +358,8 @@ template <typename T, int D>
// Adjust positions
const int head_idx = tid.x;
const int q_seq_idx = tid.y;
const int q_offset = head_idx * tpg.y + q_seq_idx;
;
const int n_heads = tpg.x;
const int q_offset = n_heads * q_seq_idx + head_idx;
partials += q_offset * blocks * D + simd_gid * D + simd_lid * elem_per_thread;
sums += q_offset * blocks;
maxs += q_offset * blocks;

View File

@@ -95,7 +95,7 @@ template <
Q += tidl.z * params->Q_strides[0] + // Batch
tidl.y * params->Q_strides[1] + // Head
tidl.x * BQ * params->Q_strides[2]; // Sequence
tidl.x * BQ * params->Q_strides[2]; // Seqeunce
ulong kv_head_idx = int(tid.y) / params->gqa_factor;
K += tidl.z * params->K_strides[0] + // Batch
@@ -106,7 +106,7 @@ template <
O += tidl.z * params->O_strides[0] + // Batch
tidl.y * params->O_strides[1] + // Head
tidl.x * BQ * params->O_strides[2]; // Sequence
tidl.x * BQ * params->O_strides[2]; // Seqeunce
if (has_mask) {
mask += tidl.z * mask_params->M_strides[0] + // Batch

View File

@@ -113,7 +113,7 @@ struct BlockLoader {
tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)];
}
// Zero out unneeded values
// Zero out uneeded values
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0);
@@ -240,7 +240,7 @@ struct BlockLoaderT {
tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)];
}
// Zero out unneeded values
// Zero out uneeded values
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0);

View File

@@ -141,7 +141,7 @@ implicit_gemm_conv_2d_general(
// Store results to device memory
{
// Adjust for simdgroup and thread location
// Adjust for simdgroup and thread locatio
int offset_m = c_row + mma_op.sm;
int offset_n = c_col + mma_op.sn;
C += offset_n;

View File

@@ -113,7 +113,7 @@ struct BlockLoader {
tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)];
}
// Zero out unneeded values
// Zero out uneeded values
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0);

View File

@@ -1,32 +1,25 @@
// Copyright © 2024 Apple Inc.
template <typename T, typename Op, int N = WorkPerThread<T>::n>
template <typename T, typename Op>
[[kernel]] void ternary_v(
device const bool* a,
device const T* b,
device const T* c,
device T* d,
constant uint& size,
uint index [[thread_position_in_grid]]) {
index *= N;
for (int i = 0; i < N && (index + i) < size; ++i) {
d[index + i] = Op()(a[index + i], b[index + i], c[index + i]);
}
d[index] = Op()(a[index], b[index], c[index]);
}
template <typename T, typename Op, int N = WorkPerThread<T>::n>
template <typename T, typename Op>
[[kernel]] void ternary_v2(
device const bool* a,
device const T* b,
device const T* c,
device T* d,
constant int64_t& size,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto offset = N * (index.x + grid_dim.x * int64_t(index.y));
for (int i = 0; i < N && (offset + i) < size; ++i) {
d[offset + i] = Op()(a[offset + i], b[offset + i], c[offset + i]);
}
auto offset = index.x + grid_dim.x * int64_t(index.y);
d[offset] = Op()(a[offset], b[offset], c[offset]);
}
template <typename T, typename Op, typename IdxT = int64_t>

View File

@@ -1,28 +1,21 @@
// Copyright © 2024 Apple Inc.
template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
template <typename T, typename U, typename Op>
[[kernel]] void unary_v(
device const T* in,
device U* out,
constant uint& size,
uint index [[thread_position_in_grid]]) {
index *= N;
for (int i = 0; i < N && (index + i) < size; ++i) {
out[index + i] = Op()(in[index + i]);
}
out[index] = Op()(in[index]);
}
template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
template <typename T, typename U, typename Op>
[[kernel]] void unary_v2(
device const T* in,
device U* out,
constant int64_t& size,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto offset = N * (index.x + grid_dim.x * int64_t(index.y));
for (int i = 0; i < N && (offset + i) < size; ++i) {
out[offset + i] = Op()(in[offset + i]);
}
auto offset = index.x + grid_dim.x * int64_t(index.y);
out[offset] = Op()(in[offset]);
}
template <

View File

@@ -69,24 +69,17 @@ instantiate_unary_float(Round)
instantiate_unary_int(BitwiseInvert)
instantiate_unary_all_same(Abs, complex64, complex64_t)
instantiate_unary_all_same(ArcCos, complex64, complex64_t)
instantiate_unary_all_same(ArcSin, complex64, complex64_t)
instantiate_unary_all_same(ArcTan, complex64, complex64_t)
instantiate_unary_all_same(Conjugate, complex64, complex64_t)
instantiate_unary_all_same(Cos, complex64, complex64_t)
instantiate_unary_all_same(Cosh, complex64, complex64_t)
instantiate_unary_all_same(Exp, complex64, complex64_t)
instantiate_unary_all_same(Log, complex64, complex64_t)
instantiate_unary_all_same(Log1p, complex64, complex64_t)
instantiate_unary_all_same(Log2, complex64, complex64_t)
instantiate_unary_all_same(Log10, complex64, complex64_t)
instantiate_unary_all_same(Negative, complex64, complex64_t)
instantiate_unary_all_same(Sign, complex64, complex64_t)
instantiate_unary_all_same(Sin, complex64, complex64_t)
instantiate_unary_all_same(Sinh, complex64, complex64_t)
instantiate_unary_all_same(Square, complex64, complex64_t)
instantiate_unary_all_same(Sqrt, complex64, complex64_t)
instantiate_unary_all_same(Rsqrt, complex64, complex64_t)
instantiate_unary_all_same(Tan, complex64, complex64_t)
instantiate_unary_all_same(Tanh, complex64, complex64_t)
instantiate_unary_all_same(Round, complex64, complex64_t)

View File

@@ -17,21 +17,27 @@ struct Abs {
T operator()(T x) {
return metal::abs(x);
};
template <>
uint8_t operator()(uint8_t x) {
return x;
};
template <>
uint16_t operator()(uint16_t x) {
return x;
};
template <>
uint32_t operator()(uint32_t x) {
return x;
};
template <>
uint64_t operator()(uint64_t x) {
return x;
};
template <>
bool operator()(bool x) {
return x;
};
template <>
complex64_t operator()(complex64_t x) {
return {metal::precise::sqrt(x.real * x.real + x.imag * x.imag), 0};
};
@@ -42,8 +48,6 @@ struct ArcCos {
T operator()(T x) {
return metal::precise::acos(x);
};
complex64_t operator()(complex64_t x);
};
struct ArcCosh {
@@ -58,8 +62,6 @@ struct ArcSin {
T operator()(T x) {
return metal::precise::asin(x);
};
complex64_t operator()(complex64_t x);
};
struct ArcSinh {
@@ -74,8 +76,6 @@ struct ArcTan {
T operator()(T x) {
return metal::precise::atan(x);
};
complex64_t operator()(complex64_t x);
};
struct ArcTanh {
@@ -97,30 +97,39 @@ struct Ceil {
T operator()(T x) {
return metal::ceil(x);
};
template <>
int8_t operator()(int8_t x) {
return x;
};
template <>
int16_t operator()(int16_t x) {
return x;
};
template <>
int32_t operator()(int32_t x) {
return x;
};
template <>
int64_t operator()(int64_t x) {
return x;
};
template <>
uint8_t operator()(uint8_t x) {
return x;
};
template <>
uint16_t operator()(uint16_t x) {
return x;
};
template <>
uint32_t operator()(uint32_t x) {
return x;
};
template <>
uint64_t operator()(uint64_t x) {
return x;
};
template <>
bool operator()(bool x) {
return x;
};
@@ -132,6 +141,7 @@ struct Cos {
return metal::precise::cos(x);
};
template <>
complex64_t operator()(complex64_t x) {
return {
metal::precise::cos(x.real) * metal::precise::cosh(x.imag),
@@ -145,6 +155,7 @@ struct Cosh {
return metal::precise::cosh(x);
};
template <>
complex64_t operator()(complex64_t x) {
return {
metal::precise::cosh(x.real) * metal::precise::cos(x.imag),
@@ -177,6 +188,7 @@ struct Exp {
T operator()(T x) {
return metal::precise::exp(x);
};
template <>
complex64_t operator()(complex64_t x) {
auto m = metal::precise::exp(x.real);
return {m * metal::precise::cos(x.imag), m * metal::precise::sin(x.imag)};
@@ -195,30 +207,39 @@ struct Floor {
T operator()(T x) {
return metal::floor(x);
};
template <>
int8_t operator()(int8_t x) {
return x;
};
template <>
int16_t operator()(int16_t x) {
return x;
};
template <>
int32_t operator()(int32_t x) {
return x;
};
template <>
int64_t operator()(int64_t x) {
return x;
};
template <>
uint8_t operator()(uint8_t x) {
return x;
};
template <>
uint16_t operator()(uint16_t x) {
return x;
};
template <>
uint32_t operator()(uint32_t x) {
return x;
};
template <>
uint64_t operator()(uint64_t x) {
return x;
};
template <>
bool operator()(bool x) {
return x;
};
@@ -237,6 +258,7 @@ struct Log {
return metal::precise::log(x);
};
template <>
complex64_t operator()(complex64_t x) {
auto r = metal::precise::log(Abs{}(x).real);
auto i = metal::precise::atan2(x.imag, x.real);
@@ -250,6 +272,7 @@ struct Log2 {
return metal::precise::log2(x);
};
template <>
complex64_t operator()(complex64_t x) {
auto y = Log{}(x);
return {y.real / M_LN2_F, y.imag / M_LN2_F};
@@ -262,6 +285,7 @@ struct Log10 {
return metal::precise::log10(x);
};
template <>
complex64_t operator()(complex64_t x) {
auto y = Log{}(x);
return {y.real / M_LN10_F, y.imag / M_LN10_F};
@@ -301,6 +325,7 @@ struct Round {
T operator()(T x) {
return metal::rint(x);
};
template <>
complex64_t operator()(complex64_t x) {
return {metal::rint(x.real), metal::rint(x.imag)};
};
@@ -319,9 +344,11 @@ struct Sign {
T operator()(T x) {
return (x > T(0)) - (x < T(0));
};
template <>
uint32_t operator()(uint32_t x) {
return x != 0;
};
template <>
complex64_t operator()(complex64_t x) {
if (x == complex64_t(0)) {
return x;
@@ -337,6 +364,7 @@ struct Sin {
return metal::precise::sin(x);
};
template <>
complex64_t operator()(complex64_t x) {
return {
metal::precise::sin(x.real) * metal::precise::cosh(x.imag),
@@ -350,6 +378,7 @@ struct Sinh {
return metal::precise::sinh(x);
};
template <>
complex64_t operator()(complex64_t x) {
return {
metal::precise::sinh(x.real) * metal::precise::cos(x.imag),
@@ -369,17 +398,6 @@ struct Sqrt {
T operator()(T x) {
return metal::precise::sqrt(x);
};
complex64_t operator()(complex64_t x) {
if (x.real == 0.0 && x.imag == 0.0) {
return {0.0, 0.0};
}
auto r = Abs{}(x).real;
auto a = metal::precise::sqrt((r + x.real) / 2.0);
auto b_abs = metal::precise::sqrt((r - x.real) / 2.0);
auto b = metal::copysign(b_abs, x.imag);
return {a, b};
}
};
struct Rsqrt {
@@ -387,10 +405,6 @@ struct Rsqrt {
T operator()(T x) {
return metal::precise::rsqrt(x);
};
complex64_t operator()(complex64_t x) {
return 1.0 / Sqrt{}(x);
}
};
struct Tan {
@@ -399,6 +413,7 @@ struct Tan {
return metal::precise::tan(x);
};
template <>
complex64_t operator()(complex64_t x) {
float tan_a = metal::precise::tan(x.real);
float tanh_b = metal::precise::tanh(x.imag);
@@ -414,6 +429,7 @@ struct Tanh {
return metal::precise::tanh(x);
};
template <>
complex64_t operator()(complex64_t x) {
float tanh_a = metal::precise::tanh(x.real);
float tan_b = metal::precise::tan(x.imag);
@@ -422,21 +438,3 @@ struct Tanh {
return {(tanh_a + tan_b * t1) / denom, (tan_b - tanh_a * t1) / denom};
};
};
complex64_t ArcCos::operator()(complex64_t x) {
auto i = complex64_t{0.0, 1.0};
auto y = Log{}(x + i * Sqrt{}(1.0 - x * x));
return {y.imag, -y.real};
};
complex64_t ArcSin::operator()(complex64_t x) {
auto i = complex64_t{0.0, 1.0};
auto y = Log{}(i * x + Sqrt{}(1.0 - x * x));
return {y.imag, -y.real};
};
complex64_t ArcTan::operator()(complex64_t x) {
auto i = complex64_t{0.0, 1.0};
auto ix = i * x;
return (1.0 / complex64_t{0.0, 2.0}) * Log{}((1.0 + ix) / (1.0 - ix));
};

View File

@@ -15,14 +15,6 @@
typedef half float16_t;
// Work per thread values for different types. The values here are expected to
// match get_work_per_thread in mlx/backend/metal/utils.h
template <typename U>
struct WorkPerThread {
static_assert(sizeof(U) <= 8, "Type too large");
static constexpr int constant n = 8 / sizeof(U);
};
///////////////////////////////////////////////////////////////////////////////
// Type limits utils
///////////////////////////////////////////////////////////////////////////////
@@ -336,23 +328,6 @@ inline bfloat16_t log1p(bfloat16_t x) {
return bfloat16_t(x * (metal::log(xp1) / (xp1 - 1.0f)));
}
inline complex64_t log1p(complex64_t in) {
float x = in.real;
float y = in.imag;
float zabs = metal::precise::sqrt(x * x + y * y);
float theta = metal::atan2(y, x + 1);
if (zabs < 0.5f) {
float r = x * (2 + x) + y * y;
if (r == 0) { // handle underflow
return {x, theta};
}
return {0.5f * log1p(r), theta};
} else {
auto z0 = metal::sqrt((x + 1) * (x + 1) + y * y);
return {metal::log(z0), theta};
}
}
///////////////////////////////////////////////////////////////////////////////
// SIMD shuffle ops
///////////////////////////////////////////////////////////////////////////////

View File

@@ -1,7 +1,7 @@
// Copyright © 2023-2024 Apple Inc.
#include <algorithm>
#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/utils.h"

View File

@@ -7,7 +7,7 @@
#include "mlx/backend/common/broadcasting.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/kernels/defines.h"
@@ -1908,7 +1908,8 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
out.set_data(allocator::malloc(out.nbytes()));
// Extract shapes from inputs.
// Extract shapes strides from inputs and copy in case of non-contiguous
// vectors.
int M = a.shape(-2);
int N = b.shape(-1);
int K = a.shape(-1);

View File

@@ -1,11 +1,11 @@
// Copyright © 2023-2024 Apple Inc.
#include <memory>
#include <sys/sysctl.h>
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/metal.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"
#include "mlx/scheduler.h"
#include "mlx/utils.h"
namespace mlx::core::metal {
@@ -13,6 +13,85 @@ bool is_available() {
return true;
}
inline void check_error(MTL::CommandBuffer* cbuf) {
if (cbuf->status() == MTL::CommandBufferStatusError) {
std::ostringstream msg;
msg << "[METAL] Command buffer execution failed: "
<< cbuf->error()->localizedDescription()->utf8String();
throw std::runtime_error(msg.str());
}
}
void eval(array& arr) {
auto pool = new_scoped_memory_pool();
auto s = arr.primitive().stream();
auto& d = metal::device(s.device);
auto command_buffer = d.get_command_buffer(s.index);
auto outputs = arr.outputs();
{
// If the array is a tracer hold a reference
// to its inputs so they don't get donated
std::vector<array> inputs;
if (arr.is_tracer()) {
inputs = arr.inputs();
}
debug_set_primitive_buffer_label(command_buffer, arr.primitive());
arr.primitive().eval_gpu(arr.inputs(), outputs);
}
std::unordered_set<std::shared_ptr<array::Data>> buffers;
for (auto& in : arr.inputs()) {
buffers.insert(in.data_shared_ptr());
}
for (auto& s : arr.siblings()) {
buffers.insert(s.data_shared_ptr());
}
// Remove the output if it was donated to by an input
if (auto it = buffers.find(arr.data_shared_ptr()); it != buffers.end()) {
buffers.erase(it);
}
if (d.command_buffer_needs_commit(s.index)) {
d.end_encoding(s.index);
scheduler::notify_new_task(s);
command_buffer->addCompletedHandler(
[s, buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) {
scheduler::notify_task_completion(s);
check_error(cbuf);
});
d.commit_command_buffer(s.index);
d.get_command_buffer(s.index);
} else {
command_buffer->addCompletedHandler(
[s, buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) {
check_error(cbuf);
});
}
}
void finalize(Stream s) {
auto pool = new_scoped_memory_pool();
auto& d = metal::device(s.device);
auto cb = d.get_command_buffer(s.index);
d.end_encoding(s.index);
cb->addCompletedHandler([s](MTL::CommandBuffer* cbuf) { check_error(cbuf); });
d.commit_command_buffer(s.index);
d.get_command_buffer(s.index);
}
void synchronize(Stream s) {
auto pool = new_scoped_memory_pool();
auto& d = metal::device(s.device);
auto cb = d.get_command_buffer(s.index);
cb->retain();
d.end_encoding(s.index);
d.commit_command_buffer(s.index);
cb->waitUntilCompleted();
check_error(cb);
cb->release();
}
void start_capture(std::string path, id object) {
auto pool = new_scoped_memory_pool();
@@ -49,36 +128,4 @@ void stop_capture() {
manager->stopCapture();
}
const std::unordered_map<std::string, std::variant<std::string, size_t>>&
device_info() {
auto init_device_info = []()
-> std::unordered_map<std::string, std::variant<std::string, size_t>> {
auto pool = new_scoped_memory_pool();
auto raw_device = device(default_device()).mtl_device();
auto name = std::string(raw_device->name()->utf8String());
auto arch = std::string(raw_device->architecture()->name()->utf8String());
size_t memsize = 0;
size_t length = sizeof(memsize);
sysctlbyname("hw.memsize", &memsize, &length, NULL, 0);
size_t rsrc_limit = 0;
sysctlbyname("iogpu.rsrc_limit", &rsrc_limit, &length, NULL, 0);
if (rsrc_limit == 0) {
rsrc_limit = 499000;
}
return {
{"device_name", name},
{"architecture", arch},
{"max_buffer_length", raw_device->maxBufferLength()},
{"max_recommended_working_set_size",
raw_device->recommendedMaxWorkingSetSize()},
{"memory_size", memsize},
{"resource_limit", rsrc_limit}};
};
static auto device_info_ = init_device_info();
return device_info_;
}
} // namespace mlx::core::metal

View File

@@ -2,10 +2,11 @@
#pragma once
#include <string>
#include <unordered_map>
#include <variant>
#include "mlx/array.h"
namespace mlx::core::metal {
/* Check if the Metal backend is available. */

View File

@@ -8,11 +8,14 @@
#include "mlx/array.h"
#include "mlx/stream.h"
namespace mlx::core::gpu {
namespace mlx::core::metal {
void new_stream(Stream stream);
std::unique_ptr<void, std::function<void(void*)>> new_scoped_memory_pool();
void eval(array& arr);
void finalize(Stream s);
void synchronize(Stream s);
} // namespace mlx::core::gpu
} // namespace mlx::core::metal

View File

@@ -1,22 +0,0 @@
// Copyright © 2025 Apple Inc.
#include <stdexcept>
#include "mlx/backend/metal/metal.h"
namespace mlx::core::metal {
bool is_available() {
return false;
}
void start_capture(std::string) {}
void stop_capture() {}
const std::unordered_map<std::string, std::variant<std::string, size_t>>&
device_info() {
throw std::runtime_error(
"[metal::device_info] Cannot get device info without metal backend");
};
} // namespace mlx::core::metal

View File

@@ -269,21 +269,4 @@ MTL::ComputePipelineState* get_quantized_kernel(
return d.get_kernel(kernel_name);
}
MTL::ComputePipelineState* get_gather_qmm_kernel(
metal::Device& d,
const std::string& kernel_name,
const std::string& hash_name,
const metal::MTLFCList& func_consts,
const array&,
int,
int,
int,
int,
int,
int,
int,
bool) {
return d.get_kernel(kernel_name, "mlx", hash_name, func_consts);
}
} // namespace mlx::core

View File

@@ -1,7 +1,7 @@
// Copyright © 2024 Apple Inc.
#include <algorithm>
#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/reduce.h"

View File

@@ -7,10 +7,10 @@
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/common/slicing.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/gpu/slicing.h"
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/slicing.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"
#include "mlx/scheduler.h"
@@ -25,6 +25,25 @@ void arange_set_scalars(T start, T next, metal::CommandEncoder& enc) {
enc.set_bytes(step, 1);
}
void reshape(const array& in, array& out, Stream s) {
auto [copy_necessary, out_strides] = prepare_reshape(in, out);
if (copy_necessary) {
out.set_data(allocator::malloc(out.nbytes()));
copy_gpu_inplace(
in,
out,
in.shape(),
in.strides(),
make_contiguous_strides(in.shape()),
0,
0,
CopyType::General,
s);
} else {
shared_buffer_reshape(in, out_strides, out);
}
}
static array compute_dynamic_offset(
const array& indices,
const Strides& strides,
@@ -207,10 +226,105 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
}
}
void AsType::eval_gpu(const std::vector<array>& inputs, array& out) {
CopyType ctype =
inputs[0].flags().contiguous ? CopyType::Vector : CopyType::General;
copy_gpu(inputs[0], out, ctype);
}
void AsStrided::eval_gpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
void Broadcast::eval_gpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
void BroadcastAxes::eval_gpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
void Concatenate::eval_gpu(const std::vector<array>& inputs, array& out) {
concatenate_gpu(inputs, out, axis_, stream());
}
void Contiguous::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
constexpr size_t extra_bytes = 16384;
if (in.buffer_size() <= out.nbytes() + extra_bytes &&
(in.flags().row_contiguous ||
(allow_col_major_ && in.flags().col_contiguous))) {
out.copy_shared_buffer(in);
} else {
copy_gpu(in, out, CopyType::General);
}
}
void Copy::eval_gpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
void CustomTransforms::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
eval(inputs, outputs);
}
void Depends::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
eval(inputs, outputs);
}
void Full::eval_gpu(const std::vector<array>& inputs, array& out) {
auto in = inputs[0];
CopyType ctype;
if (in.data_size() == 1) {
ctype = CopyType::Scalar;
} else if (in.flags().contiguous) {
ctype = CopyType::Vector;
} else {
ctype = CopyType::General;
}
copy_gpu(in, out, ctype);
}
void ExpandDims::eval_gpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
void Flatten::eval_gpu(const std::vector<array>& inputs, array& out) {
reshape(inputs[0], out, stream());
}
void Unflatten::eval_gpu(const std::vector<array>& inputs, array& out) {
reshape(inputs[0], out, stream());
}
void Load::eval_gpu(const std::vector<array>& inputs, array& out) {
throw std::runtime_error("[Load::eval_gpu] Not implemented.");
}
void NumberOfElements::eval_gpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
void Pad::eval_gpu(const std::vector<array>& inputs, array& out) {
// Inputs must be base input array and scalar val array
assert(inputs.size() == 2);
auto& in = inputs[0];
auto& val = inputs[1];
// Padding value must be a scalar
assert(val.size() == 1);
// Padding value, input and output must be of the same type
assert(val.dtype() == in.dtype() && in.dtype() == out.dtype());
pad_gpu(in, val, out, axes_, low_pad_size_, stream());
}
void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
@@ -256,6 +370,27 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.dispatch_threads(grid_dims, group_dims);
}
void Reshape::eval_gpu(const std::vector<array>& inputs, array& out) {
reshape(inputs[0], out, stream());
}
void Split::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
eval(inputs, outputs);
}
void Slice::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
if (out.size() == 0) {
out.set_data(nullptr);
return;
}
auto& in = inputs[0];
slice_gpu(in, out, start_indices_, strides_, stream());
}
void DynamicSlice::eval_gpu(const std::vector<array>& inputs, array& out) {
if (out.size() == 0) {
out.set_data(nullptr);
@@ -357,6 +492,18 @@ void SliceUpdate::eval_gpu(const std::vector<array>& inputs, array& out) {
/* const Stream& s = */ stream());
}
void Squeeze::eval_gpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
void StopGradient::eval_gpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
void Transpose::eval_gpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
void QRF::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
@@ -390,4 +537,35 @@ void LUF::eval_gpu(
throw std::runtime_error("[LUF::eval_gpu] Metal LU factorization NYI.");
}
void View::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& in = inputs[0];
auto ibytes = size_of(in.dtype());
auto obytes = size_of(out.dtype());
// Conditions for buffer copying (disjunction):
// - type size is the same
// - type size is smaller and the last axis is contiguous
// - the entire array is row contiguous
if (ibytes == obytes || (obytes < ibytes && in.strides().back() == 1) ||
in.flags().row_contiguous) {
auto strides = in.strides();
for (int i = 0; i < static_cast<int>(strides.size()) - 1; ++i) {
strides[i] *= ibytes;
strides[i] /= obytes;
}
out.copy_shared_buffer(
in, strides, in.flags(), in.data_size() * ibytes / obytes);
} else {
auto tmp = array(in.shape(), in.dtype(), nullptr, {});
tmp.set_data(allocator::malloc(tmp.nbytes()));
copy_gpu_inplace(in, tmp, CopyType::General, stream());
auto flags = out.flags();
flags.contiguous = true;
flags.row_contiguous = true;
auto max_dim = std::max_element(out.shape().begin(), out.shape().end());
flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim;
out.copy_shared_buffer(tmp, out.strides(), flags, out.size());
}
}
} // namespace mlx::core

File diff suppressed because it is too large Load Diff

View File

@@ -3,7 +3,7 @@
#include <algorithm>
#include <cassert>
#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/kernels/defines.h"

View File

@@ -1,6 +1,7 @@
// Copyright © 2024 Apple Inc.
#include "mlx/backend/metal/resident.h"
#include "mlx/backend/metal/metal_impl.h"
namespace mlx::core::metal {

View File

@@ -1,5 +1,5 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/fast_primitives.h"

View File

@@ -2,7 +2,7 @@
#include <sstream>
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels/steel/attn/params.h"
@@ -154,9 +154,9 @@ void sdpa_vector(
int gqa_factor = q.shape(1) / k.shape(1);
int N = k.shape(2);
int B = q.shape(0) * q.shape(1);
size_t k_head_stride = k.shape(1) == 1 ? k.strides(0) : k.strides(1);
size_t k_head_stride = k.strides()[1];
size_t k_seq_stride = k.strides()[2];
size_t v_head_stride = v.shape(1) == 1 ? v.strides(0) : v.strides(1);
size_t v_head_stride = v.strides()[1];
size_t v_seq_stride = v.strides()[2];
MTL::Size group_dims(1024, 1, 1);
@@ -199,10 +199,11 @@ void sdpa_vector(
if (has_mask) {
auto& m = *mask;
compute_encoder.set_input_array(m, 11 + float_mask);
int32_t kv_seq_stride = m.shape(3) > 1 ? m.strides(3) : 0;
int32_t q_seq_stride = m.shape(2) > 1 ? m.strides(2) : 0;
int32_t head_stride =
m.shape(1) > 1 ? m.strides(1) : (m.shape(0) > 1 ? m.strides(0) : 0);
auto nd = m.ndim();
int32_t kv_seq_stride =
nd >= 1 && m.shape(-1) > 1 ? m.strides()[nd - 1] : 0;
int32_t q_seq_stride = nd >= 2 && m.shape(-2) > 1 ? m.strides()[nd - 2] : 0;
int32_t head_stride = nd >= 3 && m.shape(-3) > 1 ? m.strides()[nd - 3] : 0;
compute_encoder.set_bytes(kv_seq_stride, 13);
compute_encoder.set_bytes(q_seq_stride, 14);
compute_encoder.set_bytes(head_stride, 15);
@@ -237,10 +238,9 @@ void sdpa_vector_2pass(
int N = k.shape(2);
int blocks = 32;
int B = q.shape(0) * q.shape(1);
size_t k_head_stride = k.shape(1) == 1 ? k.strides(0) : k.strides(1);
size_t k_head_stride = k.strides()[1];
size_t k_seq_stride = k.strides()[2];
size_t v_head_stride = v.shape(1) == 1 ? v.strides(0) : v.strides(1);
size_t v_head_stride = v.strides()[1];
size_t v_seq_stride = v.strides()[2];
MTL::Size group_dims(8 * 32, 1, 1);
MTL::Size grid_dims(B, q.shape(2), blocks);
@@ -302,10 +302,11 @@ void sdpa_vector_2pass(
if (has_mask) {
auto& m = *mask;
compute_encoder.set_input_array(m, 13 + float_mask);
int32_t kv_seq_stride = m.shape(3) > 1 ? m.strides(3) : 0;
int32_t q_seq_stride = m.shape(2) > 1 ? m.strides(2) : 0;
int32_t head_stride =
m.shape(1) > 1 ? m.strides(1) : (m.shape(0) > 1 ? m.strides(0) : 0);
auto nd = m.ndim();
int32_t kv_seq_stride =
nd >= 1 && m.shape(-1) > 1 ? m.strides()[nd - 1] : 0;
int32_t q_seq_stride = nd >= 2 && m.shape(-2) > 1 ? m.strides()[nd - 2] : 0;
int32_t head_stride = nd >= 3 && m.shape(-3) > 1 ? m.strides()[nd - 3] : 0;
compute_encoder.set_bytes(kv_seq_stride, 15);
compute_encoder.set_bytes(q_seq_stride, 16);
compute_encoder.set_bytes(head_stride, 17);
@@ -367,6 +368,18 @@ void ScaledDotProductAttention::eval_gpu(
}
};
// Checks if arr is row contiguous or the sequence and head dimension are
// transposed
auto is_contiguous_or_head_seq_transposed = [](const array& arr) {
if (arr.flags().row_contiguous) {
return true;
}
auto& strides = arr.strides();
auto& shape = arr.shape();
return (strides[3] == 1) && (strides[2] == shape[3] * shape[1]) &&
(strides[1] == shape[3]) && (strides[0] == strides[2] * shape[2]);
};
// Checks that the headdim dimension has stride 1.
auto is_matrix_contiguous = [](const array& arr) {
return arr.strides(-1) == 1;
@@ -374,58 +387,30 @@ void ScaledDotProductAttention::eval_gpu(
// We are in vector mode ie single query
if (q_pre.shape(2) <= 8) {
auto q_copy_unless = [](const array& arr) {
if (arr.flags().row_contiguous) {
return true;
}
auto& strides = arr.strides();
auto& shape = arr.shape();
if (shape[0] == 1 || shape[1] == 1) {
// If either the batch or head dimension is a singleton, the other can
// be transposed with the sequence dimension
auto bidx = shape[0] == 1 ? 1 : 0;
return (strides[3] == 1) && (strides[2] == shape[3] * shape[bidx]) &&
(strides[bidx] == shape[3]);
}
return false;
};
auto kv_copy_unless = [](const array& arr) {
// keys and values should be copied if:
// - the last dimension is not contiguous
// - the batch and head dim are not contiguous
auto& strides = arr.strides();
auto& shape = arr.shape();
if (strides.back() != 1) {
return false;
}
if (shape[0] == 1 || shape[1] == 1) {
return true;
}
return (strides[0] == strides[1] * shape[1]);
};
const auto& q = copy_unless(q_copy_unless, q_pre);
const auto& k = copy_unless(kv_copy_unless, k_pre);
const auto& v = copy_unless(kv_copy_unless, v_pre);
const auto& q = copy_unless(is_contiguous_or_head_seq_transposed, q_pre);
const auto& k = copy_unless(is_matrix_contiguous, k_pre);
const auto& v = copy_unless(is_matrix_contiguous, v_pre);
// Donate the query if possible
if (q.is_donatable() && q.flags().row_contiguous && q.size() == o.size()) {
if (q.is_donatable() && (q.shape(2) == 1 || !q.flags().row_contiguous) &&
q.size() == o.size()) {
o.copy_shared_buffer(q);
} else {
o.set_data(allocator::malloc(o.nbytes()));
if (o.shape(2) == 1) {
o.set_data(allocator::malloc(o.nbytes()));
} else {
auto strides = o.strides();
strides[2] = o.shape(1) * o.shape(3);
strides[1] = o.shape(3);
auto flags = q.flags();
flags.row_contiguous = q.shape(1) == 1;
o.set_data(
allocator::malloc(o.nbytes()), o.size(), std::move(strides), flags);
}
}
auto mask_copy_unless = [&q](const array& arr) {
auto& strides = arr.strides();
auto& shape = arr.shape();
return arr.flags().row_contiguous || q.shape(0) == 1 || q.shape(1) == 1 ||
(strides[0] == strides[1] * shape[1]);
};
auto mask = inputs.size() > 3
? std::optional<array>{copy_unless(mask_copy_unless, inputs[3])}
: std::nullopt;
auto mask =
inputs.size() > 3 ? std::optional<array>{inputs[3]} : std::nullopt;
// We route to the 2 pass fused attention if
// - The device is large and the sequence length long

View File

@@ -3,7 +3,7 @@
#include <cassert>
#include <sstream>
#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/utils.h"

View File

@@ -2,12 +2,21 @@
#include <numeric>
#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/gpu/slicing.h"
#include "mlx/backend/common/slicing.h"
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
namespace mlx::core {
void slice_gpu(
const array& in,
array& out,
const Shape& start_indices,
const Shape& strides,
const Stream& s) {
slice(in, out, start_indices, strides);
}
void concatenate_gpu(
const std::vector<array>& inputs,
array& out,
@@ -39,4 +48,30 @@ void concatenate_gpu(
}
}
void pad_gpu(
const array& in,
const array& val,
array& out,
const std::vector<int>& axes,
const Shape& low_pad_size,
const Stream& s) {
// Fill output with val
fill_gpu(val, out, s);
// Find offset for start of input values
size_t data_offset = 0;
for (int i = 0; i < axes.size(); i++) {
auto ax = axes[i] < 0 ? out.ndim() + axes[i] : axes[i];
data_offset += out.strides()[ax] * low_pad_size[i];
}
// Extract slice from output where input will be pasted
array out_slice(in.shape(), out.dtype(), nullptr, {});
out_slice.copy_shared_buffer(
out, out.strides(), out.flags(), out_slice.size(), data_offset);
// Copy input values into the slice
copy_gpu_inplace(in, out_slice, CopyType::GeneralGeneral, s);
}
} // namespace mlx::core

View File

@@ -1,7 +1,7 @@
// Copyright © 2023-2024 Apple Inc.
#include <algorithm>
#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/kernels/defines.h"

Some files were not shown because too many files have changed in this diff Show More