Compare commits

..

7 Commits

Author SHA1 Message Date
Angelos Katharopoulos
6fc00d2c10 Add rudimentary barrier 2024-11-05 11:34:55 -08:00
Angelos Katharopoulos
44f0de2854 Fix run without distributed 2024-11-05 11:27:41 -08:00
Angelos Katharopoulos
29ec3539ed TCP socket distributed 2024-11-05 11:27:41 -08:00
Angelos Katharopoulos
e94f0028c3 Change the send message size 2024-11-05 11:27:41 -08:00
Angelos Katharopoulos
e5354fcddb Make it work even for donated inputs 2024-11-05 11:27:41 -08:00
Angelos Katharopoulos
34dd079a64 Start a sockets based distributed backend 2024-11-05 11:27:41 -08:00
Angelos Katharopoulos
c3ccd4919f Add MPI barrier 2024-11-05 11:26:53 -08:00
159 changed files with 4212 additions and 6498 deletions

View File

@@ -1,14 +1,13 @@
repos:
- repo: https://github.com/pre-commit/mirrors-clang-format
rev: v19.1.4
rev: v18.1.8
hooks:
- id: clang-format
# Using this mirror lets us use mypyc-compiled black, which is about 2x faster
- repo: https://github.com/psf/black-pre-commit-mirror
rev: 24.10.0
rev: 24.8.0
hooks:
- id: black
- repo: https://github.com/pycqa/isort
rev: 5.13.2
hooks:

View File

@@ -24,7 +24,7 @@ option(MLX_METAL_JIT "Use JIT compilation for Metal kernels" OFF)
option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
if(NOT MLX_VERSION)
set(MLX_VERSION 0.21.1)
set(MLX_VERSION 0.19.3)
endif()
# --------------------- Processor tests -------------------------
@@ -34,6 +34,8 @@ message(
"Building MLX for ${CMAKE_SYSTEM_PROCESSOR} processor on ${CMAKE_SYSTEM_NAME}"
)
set(MLX_BUILD_ARM OFF)
if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "x86_64")
if(NOT MLX_ENABLE_X64_MAC)
@@ -55,6 +57,10 @@ else()
message(WARNING "MLX is prioritised for Apple silicon systems using macOS.")
endif()
if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "arm64")
set(MLX_BUILD_ARM ON)
endif()
# ----------------------------- Lib -----------------------------
include(FetchContent)
@@ -83,27 +89,25 @@ elseif(MLX_BUILD_METAL)
# Throw an error if xcrun not found
execute_process(
COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version"
OUTPUT_VARIABLE MACOS_SDK_VERSION COMMAND_ERROR_IS_FATAL ANY)
OUTPUT_VARIABLE MACOS_VERSION COMMAND_ERROR_IS_FATAL ANY)
if(${MACOS_SDK_VERSION} LESS 14.0)
if(${MACOS_VERSION} LESS 14.0)
message(
FATAL_ERROR
"MLX requires macOS SDK >= 14.0 to be built with MLX_BUILD_METAL=ON")
endif()
message(STATUS "Building with macOS SDK version ${MACOS_SDK_VERSION}")
message(STATUS "Building with SDK for macOS version ${MACOS_VERSION}")
set(METAL_CPP_URL
https://developer.apple.com/metal/cpp/files/metal-cpp_macOS15_iOS18-beta.zip
)
if(NOT CMAKE_OSX_DEPLOYMENT_TARGET STREQUAL "")
set(XCRUN_FLAGS "-mmacosx-version-min=${CMAKE_OSX_DEPLOYMENT_TARGET}")
endif()
# Get the metal version
execute_process(
COMMAND
zsh "-c"
"echo \"__METAL_VERSION__\" | xcrun -sdk macosx metal ${XCRUN_FLAGS} -E -x metal -P - | tail -1 | tr -d '\n'"
"echo \"__METAL_VERSION__\" | xcrun -sdk macosx metal -E -x metal -P - | tail -1 | tr -d '\n'"
OUTPUT_VARIABLE MLX_METAL_VERSION COMMAND_ERROR_IS_FATAL ANY)
FetchContent_Declare(metal_cpp URL ${METAL_CPP_URL})
FetchContent_MakeAvailable(metal_cpp)
@@ -111,11 +115,13 @@ elseif(MLX_BUILD_METAL)
mlx PUBLIC $<BUILD_INTERFACE:${metal_cpp_SOURCE_DIR}>
$<INSTALL_INTERFACE:include/metal_cpp>)
target_link_libraries(mlx PUBLIC ${METAL_LIB} ${FOUNDATION_LIB} ${QUARTZ_LIB})
add_compile_definitions("MLX_METAL_VERSION=${MLX_METAL_VERSION}")
endif()
if(MLX_BUILD_CPU)
find_library(ACCELERATE_LIBRARY Accelerate)
if(ACCELERATE_LIBRARY)
if(MLX_BUILD_ARM AND ACCELERATE_LIBRARY)
message(STATUS "Accelerate found ${ACCELERATE_LIBRARY}")
set(MLX_BUILD_ACCELERATE ON)
target_link_libraries(mlx PUBLIC ${ACCELERATE_LIBRARY})
@@ -154,13 +160,6 @@ if(MLX_BUILD_CPU)
message(STATUS "Blas include " ${BLAS_INCLUDE_DIRS})
target_include_directories(mlx PRIVATE ${BLAS_INCLUDE_DIRS})
target_link_libraries(mlx PUBLIC ${BLAS_LIBRARIES})
if(WIN32)
find_package(dlfcn-win32 REQUIRED)
message(STATUS "dlfcn-win32 lib " ${dlfcn-win32_LIBRARIES})
message(STATUS "dlfcn-win32 include " ${dlfcn-win32_INCLUDE_DIRS})
target_link_libraries(mlx PUBLIC ${dlfcn-win32_LIBRARIES})
endif()
endif()
else()
set(MLX_BUILD_ACCELERATE OFF)

View File

@@ -1,189 +1,62 @@
# Copyright © 2024 Apple Inc.
import argparse
import math
import os
import subprocess
import time
import mlx.core as mx
import numpy as np
from time_utils import time_fn
device_name = subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"])
device_name = device_name.decode("utf-8").strip("\n")
N_warmup = 5
N_iter_bench = 40
N_iter_func = 8
MAX_SEQ = 300
START_SEQ = 100
SEQ_INCREMENT = 50
def bench(f, *args):
for i in range(N_warmup):
f(*args)
def time_self_attention_primitives():
mx.random.seed(3)
B = 2
H = 38
D = 64
for R in range(START_SEQ, MAX_SEQ, SEQ_INCREMENT):
q = mx.random.uniform(shape=(B, H, R, D))
k = mx.random.uniform(shape=(B, H, R, D))
v = mx.random.uniform(shape=(B, H, R, D))
scale = 1.0 / math.sqrt(float(D))
mx.eval(q, k, v)
s = time.perf_counter_ns()
for i in range(N_iter_bench):
f(*args)
e = time.perf_counter_ns()
return (e - s) * 1e-9
def sdpa_primitives(qs, ks, vs, alpha):
s = (alpha * qs) @ ks.transpose(0, 1, 3, 2)
p = mx.softmax(s.astype(mx.float32), axis=-1).astype(s.dtype)
o = p @ vs
return o
time_fn(sdpa_primitives, q, k, v, scale)
def mlx_sdpa_fused_inner(q, k, v, scale):
return mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=None)
def time_self_attention_sdpa():
mx.random.seed(3)
B = 2
H = 38
D = 64
for R in range(START_SEQ, MAX_SEQ, SEQ_INCREMENT):
q = mx.random.uniform(shape=(B, H, R, D))
k = mx.random.uniform(shape=(B, H, R, D))
v = mx.random.uniform(shape=(B, H, R, D))
scale = 1.0 / math.sqrt(float(D))
mx.eval(q, k, v)
def sdpa_fused(qs, ks, vs, alpha):
o = mx.fast.scaled_dot_product_attention(qs, ks, vs, scale=alpha)
return o
def mlx_sdpa_unfused_inner(q, k, v, scale, f32softmax=False):
q_dtype = q.dtype
q = q * mx.array(scale, q_dtype)
n_q_heads = q.shape[-3]
n_kv_heads = k.shape[-3]
n_repeats = n_q_heads // n_kv_heads
B = q.shape[0]
L = q.shape[2]
if n_repeats > 1:
q = mx.reshape(q, [B, n_kv_heads, n_repeats, L, -1])
k = mx.expand_dims(k, 2)
v = mx.expand_dims(v, 2)
scores = q @ mx.swapaxes(k, -1, -2)
if f32softmax:
scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(q_dtype)
else:
scores = mx.softmax(scores, axis=-1)
out = scores @ v
if n_repeats > 1:
out = mx.reshape(out, [B, n_q_heads, L, -1])
return out
def mlx_spda_unfused(q, k, v, scale, transpose):
q_out = q
if transpose:
k = mx.transpose(k, (0, 2, 1, 3))
v = mx.transpose(v, (0, 2, 1, 3))
for i in range(N_iter_func):
if transpose:
q_out = mx.transpose(q_out, (0, 2, 1, 3))
q_out = mlx_sdpa_unfused_inner(q_out, k, v, scale)
if transpose:
q_out = mx.transpose(q_out, (0, 2, 1, 3))
mx.eval(q_out)
return q_out
def mlx_spda_fused(q, k, v, scale, transpose):
q_out = q
if transpose:
k = mx.transpose(k, (0, 2, 1, 3))
v = mx.transpose(v, (0, 2, 1, 3))
for i in range(N_iter_func):
if transpose:
q_out = mx.transpose(q_out, (0, 2, 1, 3))
q_out = mlx_sdpa_fused_inner(q_out, k, v, scale)
if transpose:
q_out = mx.transpose(q_out, (0, 2, 1, 3))
mx.eval(q_out)
return q_out
def bench_shape(B, qsl, ksl, head_dim, n_q_heads, n_kv_heads, np_dtype, transpose=True):
shape_q = (
(B, qsl, n_q_heads, head_dim) if transpose else (B, n_q_heads, qsl, head_dim)
)
shape_kv = (
(B, ksl, n_kv_heads, head_dim) if transpose else (B, n_kv_heads, ksl, head_dim)
)
q_np = np.random.normal(0.0, 1.0 / math.sqrt(head_dim), shape_q).astype(np_dtype)
k_np = np.random.normal(0.0, 1.0 / math.sqrt(head_dim), shape_kv).astype(np_dtype)
v_np = np.random.normal(0.0, 1.0 / math.sqrt(head_dim), shape_kv).astype(np_dtype)
scale = math.sqrt(1.0 / head_dim)
q_mx = mx.array(q_np)
k_mx = mx.array(k_np)
v_mx = mx.array(v_np)
time_mlx_unfused = bench(mlx_spda_unfused, q_mx, k_mx, v_mx, scale, transpose)
time_mlx_fused = bench(mlx_spda_fused, q_mx, k_mx, v_mx, scale, transpose)
if transpose:
q_mx = mx.transpose(q_mx, (0, 2, 1, 3))
k_mx = mx.transpose(k_mx, (0, 2, 1, 3))
v_mx = mx.transpose(v_mx, (0, 2, 1, 3))
o_mlx_fused = mlx_sdpa_fused_inner(q_mx, k_mx, v_mx, scale)
o_mlx_unfused = mlx_sdpa_unfused_inner(q_mx, k_mx, v_mx, scale, f32softmax=True)
atol = 1e-5 if np_dtype == np.float32 else 1e-4
if not mx.allclose(o_mlx_fused, o_mlx_unfused, atol=atol):
print(
f"Failed at (B: {B}, qsl: {qsl}, ksl: {ksl}, head_dim: {head_dim}, n_qh: {n_q_heads}, n_kvh: {n_kv_heads}) [tpose = {transpose}] with max(|a - b|) = {mx.max(mx.abs(o_mlx_unfused - o_mlx_fused)):3.2e}"
)
return time_mlx_fused, time_mlx_unfused
def get_gflop_count(B, M, N, K):
return float(2.0 * N_iter_bench * N_iter_func * B * M * N * K) / float(1024.0**3)
time_fn(sdpa_fused, q, k, v, scale)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run gemm benchmarks")
parser = argparse.ArgumentParser("MLX benchmarks.")
parser.add_argument("--gpu", action="store_true", help="Use the Metal back-end.")
args = parser.parse_args()
if args.gpu:
mx.set_default_device(mx.gpu)
else:
mx.set_default_device(mx.cpu)
dtypes = ("float16", "float32")[:1]
transposes = (False,)
# fmt: off
shapes_64 = (
# ( B, qsl, ksl, head_dim, n_qh, n_kvh)
( 1, 32, 32, 64, 32, 32),
( 1, 64, 64, 64, 32, 32),
( 1, 128, 128, 64, 32, 32),
( 1, 256, 256, 64, 32, 32),
( 1, 512, 512, 64, 32, 32),
( 1, 1024, 1024, 64, 32, 32),
( 1, 2048, 2048, 64, 32, 32),
( 1, 4096, 4096, 64, 32, 32),
)
shapes_80 = (
# ( B, qsl, ksl, head_dim, n_qh, n_kvh)
( 1, 1024, 1024, 80, 32, 32),
( 1, 2048, 2048, 80, 32, 32),
( 1, 4096, 4096, 80, 32, 32),
)
shapes_128 = (
# ( B, qsl, ksl, head_dim, n_qh, n_kvh)
( 1, 1024, 1024, 128, 32, 32),
( 1, 2048, 2048, 128, 32, 32),
( 1, 4096, 4096, 128, 32, 32),
)
# fmt: on
shapes = shapes_64 + shapes_80 + shapes_128
print(" B, qsl, ksl, hdim, n_qh, n_kvh, tpose, dtype, t_unfs, t_fuse, diff%")
for dtype in dtypes:
for transpose in transposes:
for B, qsl, ksl, head_dim, n_q_heads, n_kv_heads in shapes:
np_dtype = getattr(np, dtype)
time_mlx_fused, time_mlx_unfused = bench_shape(
B, qsl, ksl, head_dim, n_q_heads, n_kv_heads, np_dtype, transpose
)
diff = time_mlx_unfused / time_mlx_fused - 1.0
t_str = 1 if transpose else 0
print(
f"{B:3d}, {qsl:5d}, {ksl:5d}, {head_dim:4d}, {n_q_heads:4d}, {n_kv_heads:5d}, {t_str:5d}, {dtype}, {time_mlx_unfused: 2.3f}, {time_mlx_fused: 2.3f}, {100. * diff:+5.2f}%"
)
time_self_attention_sdpa()
time_self_attention_primitives()

View File

@@ -4,51 +4,42 @@ import math
import mlx.core as mx
from time_utils import time_fn
L = 16384
L = 1024
H = 32
H_k = H // 4
H_k = 32 // 4
D = 128
dtype = mx.float16
loops = 10
def attention(q, k, v):
def _sdpa(q, k, v):
B, Hq, L, D = q.shape
_, Hk, S, _ = k.shape
q = q.reshape(B, Hk, Hq // Hk, L, D)
k = k[:, :, None, :, :]
v = v[:, :, None, :, :]
s = q @ k.transpose(0, 1, 2, 4, 3)
p = mx.softmax(s.astype(mx.float32), axis=-1).astype(s.dtype)
o = p @ v
return o.reshape(B, Hq, L, D)
for i in range(loops):
q = _sdpa(q, k, v)
return q
B, Hq, L, D = q.shape
_, Hk, S, _ = k.shape
q = q.reshape(B, Hk, Hq // Hk, L, D)
k = k[:, :, None, :, :]
v = v[:, :, None, :, :]
s = q @ k.transpose(0, 1, 2, 4, 3)
p = mx.softmax(s.astype(mx.float32), axis=-1).astype(s.dtype)
o = p @ v
return o.reshape(B, Hq, L, D)
def sdpa(q, k, v):
for i in range(loops):
q = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0)
return q
return mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0)
def time_self_attention_primitives():
mx.random.seed(3)
q = mx.random.uniform(shape=(1, H, 1, D)).astype(dtype)
k = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
v = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
q = mx.random.uniform(shape=(1, H, 1, D))
k = mx.random.uniform(shape=(1, H_k, L, D))
v = mx.random.uniform(shape=(1, H_k, L, D))
mx.eval(q, k, v)
time_fn(attention, q, k, v)
def time_self_attention_sdpa():
mx.random.seed(3)
q = mx.random.uniform(shape=(1, H, 1, D)).astype(dtype)
k = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
v = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
q = mx.random.uniform(shape=(1, H, 1, D))
k = mx.random.uniform(shape=(1, H_k, L, D))
v = mx.random.uniform(shape=(1, H_k, L, D))
mx.eval(q, k, v)
time_fn(sdpa, q, k, v)

View File

@@ -494,7 +494,7 @@ below.
// Prepare to encode kernel
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder->setComputePipelineState(kernel);
// Kernel parameters are registered with buffer indices corresponding to
// those in the kernel declaration at axpby.metal
@@ -509,14 +509,14 @@ below.
compute_encoder.set_output_array(out, 2);
// Encode alpha and beta
compute_encoder.set_bytes(alpha_, 3);
compute_encoder.set_bytes(beta_, 4);
compute_encoder->setBytes(&alpha_, sizeof(float), 3);
compute_encoder->setBytes(&beta_, sizeof(float), 4);
// Encode shape, strides and ndim
compute_encoder.set_vector_bytes(x.shape(), 5);
compute_encoder.set_vector_bytes(x.strides(), 6);
compute_encoder.set_bytes(y.strides(), 7);
compute_encoder.set_bytes(ndim, 8);
compute_encoder->setBytes(x.shape().data(), ndim * sizeof(int), 5);
compute_encoder->setBytes(x.strides().data(), ndim * sizeof(size_t), 6);
compute_encoder->setBytes(y.strides().data(), ndim * sizeof(size_t), 7);
compute_encoder->setBytes(&ndim, sizeof(int), 8);
// We launch 1 thread for each input and make sure that the number of
// threads in any given threadgroup is not higher than the max allowed
@@ -530,7 +530,7 @@ below.
// Launch the grid with the given number of threads divided among
// the given threadgroups
compute_encoder.dispatch_threads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
We can now call the :meth:`axpby` operation on both the CPU and the GPU!

View File

@@ -209,7 +209,7 @@ Metal library by run-time compiling kernels the first time they are used in MLX
on a given machine. Note run-time compilation incurs a cold-start cost which can
be anwywhere from a few hundred millisecond to a few seconds depending on the
application. Once a kernel is compiled, it will be cached by the system. The
Metal kernel cache persists across reboots.
Metal kernel cache persists accross reboots.
Troubleshooting
^^^^^^^^^^^^^^^

View File

@@ -12,4 +12,5 @@ Fast
layer_norm
rope
scaled_dot_product_attention
affine_quantize
metal_kernel

View File

@@ -12,7 +12,6 @@ Layers
ALiBi
AvgPool1d
AvgPool2d
AvgPool3d
BatchNorm
CELU
Conv1d
@@ -42,7 +41,6 @@ Layers
LSTM
MaxPool1d
MaxPool2d
MaxPool3d
Mish
MultiHeadAttention
PReLU

View File

@@ -184,8 +184,8 @@ Let's time these two different versions:
print(timeit.timeit(lambda: mx.eval(naive_add(xs, ys)), number=100))
print(timeit.timeit(lambda: mx.eval(vmap_add(xs, ys)), number=100))
On an M1 Max the naive version takes in total ``5.639`` seconds whereas the
vectorized version takes only ``0.024`` seconds, more than 200 times faster.
On an M1 Max the naive version takes in total ``0.390`` seconds whereas the
vectorized version takes only ``0.025`` seconds, more than ten times faster.
Of course, this operation is quite contrived. A better approach is to simply do
``xs + ys.T``, but for more complex functions :func:`vmap` can be quite handy.

View File

@@ -257,7 +257,7 @@ void Axpby::eval_gpu(
// Prepare to encode kernel
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder->setComputePipelineState(kernel);
// Kernel parameters are registered with buffer indices corresponding to
// those in the kernel declaration at axpby.metal
@@ -272,15 +272,15 @@ void Axpby::eval_gpu(
compute_encoder.set_output_array(out, 2);
// Encode alpha and beta
compute_encoder.set_bytes(alpha_, 3);
compute_encoder.set_bytes(beta_, 4);
compute_encoder->setBytes(&alpha_, sizeof(float), 3);
compute_encoder->setBytes(&beta_, sizeof(float), 4);
// Encode shape, strides and ndim if needed
if (!contiguous_kernel) {
compute_encoder.set_vector_bytes(x.shape(), 5);
compute_encoder.set_vector_bytes(x.strides(), 6);
compute_encoder.set_bytes(y.strides(), 7);
compute_encoder.set_bytes(ndim, 8);
compute_encoder->setBytes(x.shape().data(), ndim * sizeof(int), 5);
compute_encoder->setBytes(x.strides().data(), ndim * sizeof(size_t), 6);
compute_encoder->setBytes(y.strides().data(), ndim * sizeof(size_t), 7);
compute_encoder->setBytes(&ndim, sizeof(int), 8);
}
// We launch 1 thread for each input and make sure that the number of
@@ -295,7 +295,7 @@ void Axpby::eval_gpu(
// Launch the grid with the given number of threads divided among
// the given threadgroups
compute_encoder.dispatch_threads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
#else // Metal is not available

View File

@@ -2,6 +2,7 @@
#include <metal_stdlib>
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/utils.h"
template <typename T>
@@ -59,4 +60,4 @@ template <typename T>
instantiate_axpby(float32, float);
instantiate_axpby(float16, half);
instantiate_axpby(bfloat16, bfloat16_t);
instantiate_axpby(complex64, complex64_t);
instantiate_axpby(complex64, complex64_t);

View File

@@ -1,4 +1,4 @@
setuptools>=42
cmake>=3.24
mlx>=0.21.0
mlx>=0.18.1
nanobind==2.2.0

View File

@@ -28,19 +28,10 @@ endif()
if (@MLX_BUILD_METAL@)
set(MLX_BUILD_METAL @MLX_BUILD_METAL@)
set(MLX_CXX_FLAGS ${MLX_CXX_FLAGS} -D_METAL_)
set(MLX_INCLUDE_DIRS
"${MLX_INCLUDE_DIRS};"
set_and_check(MLX_INCLUDE_DIRS
${MLX_INCLUDE_DIRS}
@PACKAGE_CMAKE_INSTALL_INCLUDEDIR@/metal_cpp
)
if(@MLX_METAL_VERSION@ GREATER_EQUAL 310)
set(MLX_INCLUDE_DIRS
"${MLX_INCLUDE_DIRS};"
@PACKAGE_CMAKE_INSTALL_INCLUDEDIR@/mlx/backend/metal/kernels/metal_3_1)
else()
set(MLX_INCLUDE_DIRS
"${MLX_INCLUDE_DIRS};"
@PACKAGE_CMAKE_INSTALL_INCLUDEDIR@/mlx/backend/metal/kernels/metal_3_0)
endif()
endif()
set_target_properties(mlx PROPERTIES
@@ -49,4 +40,4 @@ set_target_properties(mlx PROPERTIES
)
include(FindPackageHandleStandardArgs)
find_package_handle_standard_args(MLX DEFAULT_MSG MLX_LIBRARY MLX_INCLUDE_DIRS)
find_package_handle_standard_args(MLX DEFAULT_MSG MLX_LIBRARY MLX_INCLUDE_DIRS)

View File

@@ -19,7 +19,7 @@ Buffer malloc(size_t size) {
}
void free(Buffer buffer) {
allocator().free(buffer);
return allocator().free(buffer);
}
Buffer CommonAllocator::malloc(size_t size, bool) {

View File

@@ -1,6 +1,5 @@
// Copyright © 2023-2024 Apple Inc.
#include <functional>
#include <unordered_map>
#include "mlx/array.h"
#include "mlx/ops.h"
@@ -31,7 +30,7 @@ array::array(const std::complex<float>& val, Dtype dtype /* = complex64 */)
}
array::array(
Shape shape,
std::vector<int> shape,
Dtype dtype,
std::shared_ptr<Primitive> primitive,
std::vector<array> inputs)
@@ -42,7 +41,7 @@ array::array(
std::move(inputs))) {}
std::vector<array> array::make_arrays(
std::vector<Shape> shapes,
std::vector<std::vector<int>> shapes,
const std::vector<Dtype>& dtypes,
const std::shared_ptr<Primitive>& primitive,
const std::vector<array>& inputs) {
@@ -74,7 +73,11 @@ array::array(std::initializer_list<int> data, Dtype dtype)
}
/* Build an array from a shared buffer */
array::array(allocator::Buffer data, Shape shape, Dtype dtype, Deleter deleter)
array::array(
allocator::Buffer data,
std::vector<int> shape,
Dtype dtype,
deleter_t deleter)
: array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
set_data(data, deleter);
}
@@ -122,7 +125,7 @@ bool array::is_tracer() const {
return array_desc_->is_tracer && in_tracing() || retain_graph();
}
void array::set_data(allocator::Buffer buffer, Deleter d) {
void array::set_data(allocator::Buffer buffer, deleter_t d) {
array_desc_->data = std::make_shared<Data>(buffer, d);
array_desc_->data_ptr = buffer.raw_ptr();
array_desc_->data_size = size();
@@ -135,9 +138,9 @@ void array::set_data(allocator::Buffer buffer, Deleter d) {
void array::set_data(
allocator::Buffer buffer,
size_t data_size,
Strides strides,
std::vector<size_t> strides,
Flags flags,
Deleter d) {
deleter_t d) {
array_desc_->data = std::make_shared<Data>(buffer, d);
array_desc_->data_ptr = buffer.raw_ptr();
array_desc_->data_size = data_size;
@@ -147,7 +150,7 @@ void array::set_data(
void array::copy_shared_buffer(
const array& other,
const Strides& strides,
const std::vector<size_t>& strides,
Flags flags,
size_t data_size,
size_t offset /* = 0 */) {
@@ -166,7 +169,7 @@ void array::copy_shared_buffer(const array& other) {
void array::move_shared_buffer(
array other,
const Strides& strides,
const std::vector<size_t>& strides,
Flags flags,
size_t data_size,
size_t offset /* = 0 */) {
@@ -211,8 +214,6 @@ array::~array() {
if (do_detach) {
for (auto& s : siblings()) {
for (auto& ss : s.siblings()) {
// Set to null here to avoid descending into array destructor
// for siblings
ss.array_desc_ = nullptr;
}
s.array_desc_->siblings.clear();
@@ -233,13 +234,13 @@ void array::ArrayDesc::init() {
}
}
array::ArrayDesc::ArrayDesc(Shape shape, Dtype dtype)
array::ArrayDesc::ArrayDesc(std::vector<int> shape, Dtype dtype)
: shape(std::move(shape)), dtype(dtype), status(Status::available) {
init();
}
array::ArrayDesc::ArrayDesc(
Shape shape,
std::vector<int> shape,
Dtype dtype,
std::shared_ptr<Primitive> primitive,
std::vector<array> inputs)
@@ -291,14 +292,6 @@ array::ArrayDesc::~ArrayDesc() {
auto top = std::move(for_deletion.back());
for_deletion.pop_back();
append_deletable_inputs(*top);
// Clear out possible siblings to break circular references
for (auto& s : top->siblings) {
// Set to null here to avoid descending into top-level
// array destructor for siblings
s.array_desc_ = nullptr;
}
top->siblings.clear();
}
}

View File

@@ -15,10 +15,7 @@ namespace mlx::core {
// Forward declaration
class Primitive;
using Deleter = std::function<void(allocator::Buffer)>;
using Shape = std::vector<int32_t>;
using Strides = std::vector<size_t>;
using deleter_t = std::function<void(allocator::Buffer)>;
class array {
/* An array is really a node in a graph. It contains a shared ArrayDesc
@@ -36,7 +33,7 @@ class array {
template <typename It>
array(
It data,
Shape shape,
std::vector<int> shape,
Dtype dtype =
TypeToDtype<typename std::iterator_traits<It>::value_type>());
@@ -52,15 +49,15 @@ class array {
template <typename T>
array(
std::initializer_list<T> data,
Shape shape,
std::vector<int> shape,
Dtype dtype = TypeToDtype<T>());
/* Build an array from a buffer */
array(
allocator::Buffer data,
Shape shape,
std::vector<int> shape,
Dtype dtype,
Deleter deleter = allocator::free);
deleter_t deleter = allocator::free);
/** Assignment to rvalue does not compile. */
array& operator=(const array& other) && = delete;
@@ -99,7 +96,7 @@ class array {
}
/** The shape of the array as a vector of integers. */
const Shape& shape() const {
const std::vector<int>& shape() const {
return array_desc_->shape;
}
@@ -108,12 +105,12 @@ class array {
*
* This function supports negative indexing and provides
* bounds checking. */
auto shape(int dim) const {
int shape(int dim) const {
return shape().at(dim < 0 ? dim + ndim() : dim);
}
/** The strides of the array. */
const Strides& strides() const {
const std::vector<size_t>& strides() const {
return array_desc_->strides;
}
@@ -122,7 +119,7 @@ class array {
*
* This function supports negative indexing and provides
* bounds checking. */
auto strides(int dim) const {
size_t strides(int dim) const {
return strides().at(dim < 0 ? dim + ndim() : dim);
}
@@ -187,13 +184,13 @@ class array {
*/
array(
Shape shape,
std::vector<int> shape,
Dtype dtype,
std::shared_ptr<Primitive> primitive,
std::vector<array> inputs);
static std::vector<array> make_arrays(
std::vector<Shape> shapes,
std::vector<std::vector<int>> shapes,
const std::vector<Dtype>& dtypes,
const std::shared_ptr<Primitive>& primitive,
const std::vector<array>& inputs);
@@ -210,8 +207,8 @@ class array {
struct Data {
allocator::Buffer buffer;
Deleter d;
Data(allocator::Buffer buffer, Deleter d = allocator::free)
deleter_t d;
Data(allocator::Buffer buffer, deleter_t d = allocator::free)
: buffer(buffer), d(d) {}
// Not copyable
Data(const Data& d) = delete;
@@ -400,18 +397,18 @@ class array {
// Check if the array is a tracer array
bool is_tracer() const;
void set_data(allocator::Buffer buffer, Deleter d = allocator::free);
void set_data(allocator::Buffer buffer, deleter_t d = allocator::free);
void set_data(
allocator::Buffer buffer,
size_t data_size,
Strides strides,
std::vector<size_t> strides,
Flags flags,
Deleter d = allocator::free);
deleter_t d = allocator::free);
void copy_shared_buffer(
const array& other,
const Strides& strides,
const std::vector<size_t>& strides,
Flags flags,
size_t data_size,
size_t offset = 0);
@@ -420,7 +417,7 @@ class array {
void move_shared_buffer(
array other,
const Strides& strides,
const std::vector<size_t>& strides,
Flags flags,
size_t data_size,
size_t offset = 0);
@@ -439,8 +436,8 @@ class array {
void init(const It src);
struct ArrayDesc {
Shape shape;
Strides strides;
std::vector<int> shape;
std::vector<size_t> strides;
size_t size;
Dtype dtype;
std::shared_ptr<Primitive> primitive;
@@ -474,10 +471,10 @@ class array {
// The arrays position in the output list
uint32_t position{0};
explicit ArrayDesc(Shape shape, Dtype dtype);
explicit ArrayDesc(std::vector<int> shape, Dtype dtype);
explicit ArrayDesc(
Shape shape,
std::vector<int> shape,
Dtype dtype,
std::shared_ptr<Primitive> primitive,
std::vector<array> inputs);
@@ -505,7 +502,7 @@ array::array(T val, Dtype dtype /* = TypeToDtype<T>() */)
template <typename It>
array::array(
It data,
Shape shape,
std::vector<int> shape,
Dtype dtype /* = TypeToDtype<typename std::iterator_traits<It>::value_type>() */) :
array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
init(data);
@@ -524,7 +521,7 @@ array::array(
template <typename T>
array::array(
std::initializer_list<T> data,
Shape shape,
std::vector<int> shape,
Dtype dtype /* = TypeToDtype<T>() */)
: array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
if (data.size() != size()) {

View File

@@ -39,7 +39,7 @@ void AsStrided::eval(const std::vector<array>& inputs, array& out) {
// rely on data_size anyway.
size_t data_size = out.size();
return move_or_copy(in, out, strides_, flags, data_size, offset_);
return out.copy_shared_buffer(in, strides_, flags, data_size, offset_);
}
void Broadcast::eval(const std::vector<array>& inputs, array& out) {
@@ -58,12 +58,12 @@ void Broadcast::eval(const std::vector<array>& inputs, array& out) {
if (out.size() > in.size()) {
flags.row_contiguous = flags.col_contiguous = false;
}
move_or_copy(in, out, strides, flags, in.data_size());
out.copy_shared_buffer(in, strides, flags, in.data_size());
}
void Copy::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
move_or_copy(inputs[0], out);
out.copy_shared_buffer(inputs[0]);
}
void CustomTransforms::eval(
@@ -72,7 +72,7 @@ void CustomTransforms::eval(
assert(inputs.size() > outputs.size());
for (int i = 0, j = inputs.size() - outputs.size(); i < outputs.size();
i++, j++) {
move_or_copy(inputs[j], outputs[i]);
outputs[i].copy_shared_buffer(inputs[j]);
}
}
@@ -81,7 +81,7 @@ void Depends::eval(
std::vector<array>& outputs) {
assert(inputs.size() > outputs.size());
for (int i = 0; i < outputs.size(); i++) {
move_or_copy(inputs[i], outputs[i]);
outputs[i].copy_shared_buffer(inputs[i]);
}
}
@@ -194,7 +194,7 @@ void Reshape::shared_buffer_reshape(
auto max_dim = std::max_element(out.shape().begin(), out.shape().end());
flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim;
}
move_or_copy(in, out, out_strides, flags, in.data_size());
out.copy_shared_buffer(in, out_strides, flags, in.data_size());
}
void Split::eval(
@@ -263,7 +263,7 @@ std::tuple<int64_t, std::vector<int64_t>> SliceUpdate::prepare_slice(
void StopGradient::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
move_or_copy(inputs[0], out);
out.copy_shared_buffer(inputs[0]);
}
void Transpose::eval(const std::vector<array>& inputs, array& out) {
@@ -297,7 +297,7 @@ void Transpose::eval(const std::vector<array>& inputs, array& out) {
b_stride *= out.shape(ri);
}
}
move_or_copy(in, out, out_strides, flags, in.data_size());
out.copy_shared_buffer(in, out_strides, flags, in.data_size());
}
} // namespace mlx::core

View File

@@ -47,7 +47,7 @@ bool compile_available_for_device(const Device& device) {
} // namespace detail
std::string get_temp_file(const std::string& name) {
return std::filesystem::temp_directory_path().append(name).string();
return std::filesystem::temp_directory_path().append(name);
}
// Return a pointer to a compiled function
@@ -279,7 +279,7 @@ void Compiled::eval_cpu(
// Figure out which kernel we are using
auto& shape = outputs[0].shape();
auto contiguous = compiled_check_contiguity(inputs, shape);
bool contiguous = compiled_check_contiguity(inputs, shape);
// Handle all broadcasting and collect function input arguments
std::vector<void*> args;

View File

@@ -159,17 +159,6 @@ void Conjugate::eval(const std::vector<array>& inputs, array& out) {
}
}
void Contiguous::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (in.flags().row_contiguous ||
(allow_col_major_ && in.flags().col_contiguous)) {
out.copy_shared_buffer(in);
} else {
copy(in, out, CopyType::General);
}
}
void Cos::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
@@ -617,7 +606,7 @@ void View::eval_cpu(const std::vector<array>& inputs, array& out) {
if (ibytes == obytes || obytes < ibytes && in.strides().back() == 1 ||
in.flags().row_contiguous) {
auto strides = in.strides();
for (int i = 0; i < static_cast<int>(strides.size()) - 1; ++i) {
for (int i = 0; i < strides.size() - 1; ++i) {
strides[i] *= ibytes;
strides[i] /= obytes;
}

View File

@@ -2,38 +2,13 @@
#include <cassert>
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/ops.h"
#include "mlx/fast_primitives.h"
#include "mlx/backend/metal/copy.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
namespace mlx::core {
namespace {
template <typename T, int bits>
void extract_bits(const uint8_t* w_in, T* w_out) {
assert(bits == 3 || bits == 6);
if (bits == 3) {
w_out[0] = static_cast<T>(w_in[0] & 0x7);
w_out[1] = static_cast<T>((w_in[0] & 0x38) >> 3);
w_out[2] = static_cast<T>(((w_in[0] & 0xc0) >> 6) + ((w_in[1] & 0x1) << 2));
w_out[3] = static_cast<T>((w_in[1] & 0xe) >> 1);
w_out[4] = static_cast<T>((w_in[1] & 0x70) >> 4);
w_out[5] = static_cast<T>(((w_in[1] & 0x80) >> 7) + ((w_in[2] & 0x3) << 1));
w_out[6] = static_cast<T>((w_in[2] & 0x1c) >> 2);
w_out[7] = static_cast<T>((w_in[2] & 0xe0) >> 5);
} else if (bits == 6) {
w_out[0] = static_cast<T>(w_in[0] & 0x3f);
w_out[1] =
static_cast<T>(((w_in[0] >> 6) & 0x03) + ((w_in[1] & 0x0f) << 2));
w_out[2] =
static_cast<T>(((w_in[1] >> 4) & 0x0f) + ((w_in[2] & 0x03) << 4));
w_out[3] = static_cast<T>((w_in[2] >> 2) & 0x3f);
}
}
template <typename T, int bits, int group_size>
void _qmm(
T* result,
@@ -45,12 +20,13 @@ void _qmm(
int N,
int K) {
constexpr int bitmask = (1 << bits) - 1;
constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
constexpr int bytes_per_pack = (bits == 3 || bits == 6) ? 3 : 1;
constexpr int pack_factor = 32 / bits;
constexpr int packs_in_group = group_size / pack_factor;
const int Ng = N / group_size;
const int Nw = N / pack_factor;
for (int m = 0; m < M; m++) {
const uint8_t* w_local = (const uint8_t*)w;
const uint32_t* w_local = w;
const T* scales_local = scales;
const T* biases_local = biases;
@@ -64,25 +40,13 @@ void _qmm(
T scale = *scales_local++;
T bias = *biases_local++;
for (int ng = 0; ng < packs_in_group; ng++) {
if (bits == 3 || bits == 6) {
T wl[pack_factor];
extract_bits<T, bits>(w_local, wl);
#pragma clang loop unroll(full)
for (int p = 0; p < pack_factor; p++) {
(*result_local++) += xi * (scale * wl[p] + bias);
}
w_local += bytes_per_pack;
uint32_t wi = *w_local++;
} else {
uint8_t wi = *w_local++;
#pragma clang loop unroll(full)
for (int p = 0; p < pack_factor; p++) {
(*result_local++) +=
xi * (scale * static_cast<T>(wi & bitmask) + bias);
if (bits != 8) {
wi >>= bits;
}
}
for (int p = 0; p < pack_factor; p++) {
(*result_local++) +=
xi * (scale * static_cast<T>(wi & bitmask) + bias);
wi >>= bits;
}
}
}
@@ -103,12 +67,13 @@ void _qmm_t(
int N,
int K) {
constexpr int bitmask = (1 << bits) - 1;
constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
constexpr int bytes_per_pack = (bits == 3 || bits == 6) ? 3 : 1;
constexpr int pack_factor = 32 / bits;
constexpr int packs_in_group = group_size / pack_factor;
const int Kg = K / group_size;
const int Kw = K / pack_factor;
for (int m = 0; m < M; m++) {
const uint8_t* w_local = (const uint8_t*)w;
const uint32_t* w_local = w;
const T* scales_local = scales;
const T* biases_local = biases;
@@ -120,26 +85,12 @@ void _qmm_t(
T bias = *biases_local++;
for (int kw = 0; kw < packs_in_group; kw++) {
if (bits == 3 || bits == 6) {
T wl[pack_factor];
extract_bits<T, bits>(w_local, wl);
#pragma clang loop unroll(full)
for (int p = 0; p < pack_factor; p++) {
sum += x_local[p] * (scale * wl[p] + bias);
}
w_local += bytes_per_pack;
x_local += pack_factor;
uint32_t wi = *w_local++;
} else {
uint8_t wi = *w_local++;
#pragma clang loop unroll(full)
for (int p = 0; p < pack_factor; p++) {
sum +=
(*x_local++) * (scale * static_cast<T>(wi & bitmask) + bias);
if (bits != 8) {
wi >>= bits;
}
}
for (int p = 0; p < pack_factor; p++) {
sum += (*x_local++) * (scale * static_cast<T>(wi & bitmask) + bias);
wi >>= bits;
}
}
}
@@ -151,55 +102,6 @@ void _qmm_t(
}
}
template <typename T, int bits, int group_size>
void _qmm_dispatch_transpose(
T* result,
const T* x,
const uint32_t* w,
const T* scales,
const T* biases,
int M,
int N,
int K,
bool transposed_w) {
if (transposed_w) {
return _qmm_t<T, bits, group_size>(result, x, w, scales, biases, M, N, K);
} else {
return _qmm<T, bits, group_size>(result, x, w, scales, biases, M, N, K);
}
}
template <typename T, int bits>
void _qmm_dispatch_group(
T* result,
const T* x,
const uint32_t* w,
const T* scales,
const T* biases,
int M,
int N,
int K,
int group_size,
bool transposed_w) {
switch (group_size) {
case 32:
_qmm_dispatch_transpose<T, bits, 32>(
result, x, w, scales, biases, M, N, K, transposed_w);
break;
case 64:
_qmm_dispatch_transpose<T, bits, 64>(
result, x, w, scales, biases, M, N, K, transposed_w);
break;
case 128:
_qmm_dispatch_transpose<T, bits, 128>(
result, x, w, scales, biases, M, N, K, transposed_w);
break;
default:
throw std::invalid_argument(
"Quantization group size must be 32, 64 or 128.");
}
}
template <typename T>
void _qmm_dispatch_typed(
T* result,
@@ -214,29 +116,79 @@ void _qmm_dispatch_typed(
int bits,
bool transposed_w) {
switch (bits) {
case 2:
_qmm_dispatch_group<T, 2>(
result, x, w, scales, biases, M, N, K, group_size, transposed_w);
break;
case 3:
_qmm_dispatch_group<T, 3>(
result, x, w, scales, biases, M, N, K, group_size, transposed_w);
break;
case 4:
_qmm_dispatch_group<T, 4>(
result, x, w, scales, biases, M, N, K, group_size, transposed_w);
break;
case 6:
_qmm_dispatch_group<T, 6>(
result, x, w, scales, biases, M, N, K, group_size, transposed_w);
break;
case 8:
_qmm_dispatch_group<T, 8>(
result, x, w, scales, biases, M, N, K, group_size, transposed_w);
break;
default:
throw std::invalid_argument("Quantization bits must be 2, 3, 4, 6 or 8.");
case 2: {
switch (group_size) {
case 32:
if (transposed_w) {
return _qmm_t<T, 2, 32>(result, x, w, scales, biases, M, N, K);
} else {
return _qmm<T, 2, 32>(result, x, w, scales, biases, M, N, K);
}
case 64:
if (transposed_w) {
return _qmm_t<T, 2, 64>(result, x, w, scales, biases, M, N, K);
} else {
return _qmm<T, 2, 64>(result, x, w, scales, biases, M, N, K);
}
case 128:
if (transposed_w) {
return _qmm_t<T, 2, 128>(result, x, w, scales, biases, M, N, K);
} else {
return _qmm<T, 2, 128>(result, x, w, scales, biases, M, N, K);
}
}
}
case 4: {
switch (group_size) {
case 32:
if (transposed_w) {
return _qmm_t<T, 4, 32>(result, x, w, scales, biases, M, N, K);
} else {
return _qmm<T, 4, 32>(result, x, w, scales, biases, M, N, K);
}
case 64:
if (transposed_w) {
return _qmm_t<T, 4, 64>(result, x, w, scales, biases, M, N, K);
} else {
return _qmm<T, 4, 64>(result, x, w, scales, biases, M, N, K);
}
case 128:
if (transposed_w) {
return _qmm_t<T, 4, 128>(result, x, w, scales, biases, M, N, K);
} else {
return _qmm<T, 4, 128>(result, x, w, scales, biases, M, N, K);
}
}
}
case 8: {
switch (group_size) {
case 32:
if (transposed_w) {
return _qmm_t<T, 8, 32>(result, x, w, scales, biases, M, N, K);
} else {
return _qmm<T, 8, 32>(result, x, w, scales, biases, M, N, K);
}
case 64:
if (transposed_w) {
return _qmm_t<T, 8, 64>(result, x, w, scales, biases, M, N, K);
} else {
return _qmm<T, 8, 64>(result, x, w, scales, biases, M, N, K);
}
case 128:
if (transposed_w) {
return _qmm_t<T, 8, 128>(result, x, w, scales, biases, M, N, K);
} else {
return _qmm<T, 8, 128>(result, x, w, scales, biases, M, N, K);
}
}
}
}
std::ostringstream msg;
msg << "Quantization type not supported. Provided bits=" << bits
<< " and group_size=" << group_size
<< ". The supported options are bits in "
<< "{2, 4, 8} and group_size in {64, 128}.";
throw std::invalid_argument(msg.str());
}
void _qmm_dispatch(
@@ -452,114 +404,4 @@ void GatherQMM::eval(const std::vector<array>& inputs, array& out) {
transpose_);
}
template <typename T, typename U>
void quantize(
const array& w_,
array& out_,
array& scales_,
array& biases_,
int bits,
int group_size) {
const T* w = w_.data<T>();
auto out = out_.data<U>();
T* scales = scales_.data<T>();
T* biases = biases_.data<T>();
T n_bins = (1 << bits) - 1;
T eps = 1e-7;
bool power_of_2_bits = is_power_of_2(bits);
int el_per_int = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
// For 3/6 bits we read 3 uint8s at a time instead of 1 uint32
int bytes_per_pack = power_of_2_bits ? 1 : 3;
int int_per_group = group_size * bytes_per_pack / el_per_int;
size_t n_groups = w_.size() / group_size;
for (size_t i = 0; i < n_groups; ++i) {
size_t w_idx = i * group_size;
T w_min = std::numeric_limits<float>::infinity();
T w_max = -w_min;
for (int j = 0; j < group_size; ++j) {
w_max = std::max(w_max, w[w_idx + j]);
w_min = std::min(w_min, w[w_idx + j]);
}
bool mask = std::abs(w_min) > std::abs(w_max);
T scale = std::max(T((w_max - w_min) / n_bins), eps);
scale = mask ? scale : -scale;
auto edge = mask ? w_min : w_max;
auto q0 = std::rint(edge / scale);
if (q0 == 0) {
scales[i] = scale;
biases[i] = 0;
} else {
scales[i] = edge / q0;
biases[i] = edge;
}
size_t out_idx = i * int_per_group;
for (int j = 0; j < int_per_group / bytes_per_pack; ++j) {
uint32_t out_el = 0;
for (int k = 0; k < el_per_int; ++k) {
T w_el = w[w_idx + j * el_per_int + k];
w_el = std::rint((w_el - biases[i]) / scales[i]);
w_el = std::min(std::max(w_el, T(0)), n_bins);
out_el |= static_cast<uint32_t>(w_el) << (k * bits);
}
if (power_of_2_bits) {
out[out_idx + j] = out_el;
} else {
out[out_idx + bytes_per_pack * j] = out_el & 0xff;
out[out_idx + bytes_per_pack * j + 1] = (out_el & 0xff00) >> 8;
out[out_idx + bytes_per_pack * j + 2] = (out_el & 0xff0000) >> 16;
}
}
}
}
void fast::AffineQuantize::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
auto ensure_row_contiguous = [](const array& arr) {
if (arr.flags().row_contiguous) {
return arr;
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::General);
return arr_copy;
}
};
auto w = ensure_row_contiguous(inputs[0]);
auto& out = outputs[0];
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto& scales = outputs[1];
auto& biases = outputs[2];
scales.set_data(allocator::malloc_or_wait(scales.nbytes()));
biases.set_data(allocator::malloc_or_wait(biases.nbytes()));
if (w.dtype() == float16) {
if (is_power_of_2(bits_)) {
quantize<float16_t, uint32_t>(w, out, scales, biases, bits_, group_size_);
} else {
quantize<float16_t, uint8_t>(w, out, scales, biases, bits_, group_size_);
}
} else if (w.dtype() == bfloat16) {
if (is_power_of_2(bits_)) {
quantize<bfloat16_t, uint32_t>(
w, out, scales, biases, bits_, group_size_);
} else {
quantize<bfloat16_t, uint8_t>(w, out, scales, biases, bits_, group_size_);
}
} else if (w.dtype() == float32) {
if (is_power_of_2(bits_)) {
quantize<float, uint32_t>(w, out, scales, biases, bits_, group_size_);
} else {
quantize<float, uint8_t>(w, out, scales, biases, bits_, group_size_);
}
} else {
throw std::runtime_error(
"[fast::AffineQuantize::eval_cpu] Only supports floating point inputs");
}
}
} // namespace mlx::core

View File

@@ -120,53 +120,45 @@ struct MinReduce {
};
template <typename InT>
void reduce_dispatch_and_or(
void reduce_dispatch_out(
const array& in,
array& out,
Reduce::ReduceType rtype,
const std::vector<int>& axes) {
if (rtype == Reduce::And) {
reduction_op<InT, bool>(in, out, axes, true, AndReduce());
} else {
reduction_op<InT, bool>(in, out, axes, false, OrReduce());
}
}
template <typename InT>
void reduce_dispatch_sum_prod(
const array& in,
array& out,
Reduce::ReduceType rtype,
const std::vector<int>& axes) {
if (rtype == Reduce::Sum) {
auto op = [](auto y, auto x) { (*y) = (*y) + x; };
if constexpr (std::is_integral_v<InT> && sizeof(InT) <= 4) {
reduction_op<InT, int32_t>(in, out, axes, 0, op);
} else {
reduction_op<InT, InT>(in, out, axes, 0, op);
switch (rtype) {
case Reduce::And: {
reduction_op<InT, bool>(in, out, axes, true, AndReduce());
break;
}
} else {
auto op = [](auto y, auto x) { (*y) *= x; };
if constexpr (std::is_integral_v<InT> && sizeof(InT) <= 4) {
reduction_op<InT, int32_t>(in, out, axes, 1, op);
} else {
case Reduce::Or: {
reduction_op<InT, bool>(in, out, axes, false, OrReduce());
break;
}
case Reduce::Sum: {
auto op = [](auto y, auto x) { (*y) = (*y) + x; };
if (out.dtype() == int32) {
// special case since the input type can be bool
reduction_op<InT, int32_t>(in, out, axes, 0, op);
} else {
reduction_op<InT, InT>(in, out, axes, 0, op);
}
break;
}
case Reduce::Prod: {
auto op = [](auto y, auto x) { (*y) *= x; };
reduction_op<InT, InT>(in, out, axes, 1, op);
break;
}
case Reduce::Max: {
auto init = Limits<InT>::min;
reduction_op<InT, InT>(in, out, axes, init, MaxReduce());
break;
}
case Reduce::Min: {
auto init = Limits<InT>::max;
reduction_op<InT, InT>(in, out, axes, init, MinReduce());
break;
}
}
}
template <typename InT>
void reduce_dispatch_min_max(
const array& in,
array& out,
Reduce::ReduceType rtype,
const std::vector<int>& axes) {
if (rtype == Reduce::Max) {
auto init = Limits<InT>::min;
reduction_op<InT, InT>(in, out, axes, init, MaxReduce());
} else {
auto init = Limits<InT>::max;
reduction_op<InT, InT>(in, out, axes, init, MinReduce());
}
}
@@ -198,114 +190,46 @@ void nd_loop(
void Reduce::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
switch (reduce_type_) {
case Reduce::And:
case Reduce::Or: {
switch (in.dtype()) {
case bool_:
case uint8:
case int8:
reduce_dispatch_and_or<int8_t>(in, out, reduce_type_, axes_);
break;
case int16:
case uint16:
case float16:
case bfloat16:
reduce_dispatch_and_or<int16_t>(in, out, reduce_type_, axes_);
break;
case uint32:
case int32:
case float32:
reduce_dispatch_and_or<int32_t>(in, out, reduce_type_, axes_);
break;
case uint64:
case int64:
case complex64:
reduce_dispatch_and_or<int64_t>(in, out, reduce_type_, axes_);
break;
}
switch (in.dtype()) {
case bool_:
reduce_dispatch_out<bool>(in, out, reduce_type_, axes_);
break;
}
case Reduce::Sum:
case Reduce::Prod: {
switch (in.dtype()) {
case bool_:
case uint8:
case int8:
reduce_dispatch_sum_prod<int8_t>(in, out, reduce_type_, axes_);
break;
case int16:
case uint16:
reduce_dispatch_sum_prod<int16_t>(in, out, reduce_type_, axes_);
break;
case int32:
case uint32:
reduce_dispatch_sum_prod<int32_t>(in, out, reduce_type_, axes_);
break;
case int64:
case uint64:
reduce_dispatch_sum_prod<int64_t>(in, out, reduce_type_, axes_);
break;
case float16:
reduce_dispatch_sum_prod<float16_t>(in, out, reduce_type_, axes_);
break;
case bfloat16:
reduce_dispatch_sum_prod<bfloat16_t>(in, out, reduce_type_, axes_);
break;
case float32:
reduce_dispatch_sum_prod<float>(in, out, reduce_type_, axes_);
break;
case complex64:
reduce_dispatch_sum_prod<complex64_t>(in, out, reduce_type_, axes_);
break;
}
case uint8:
reduce_dispatch_out<uint8_t>(in, out, reduce_type_, axes_);
break;
}
case Reduce::Max:
case Reduce::Min: {
switch (in.dtype()) {
case bool_:
reduce_dispatch_min_max<bool>(in, out, reduce_type_, axes_);
break;
case uint8:
reduce_dispatch_min_max<uint8_t>(in, out, reduce_type_, axes_);
break;
case uint16:
reduce_dispatch_min_max<uint16_t>(in, out, reduce_type_, axes_);
break;
case uint32:
reduce_dispatch_min_max<uint32_t>(in, out, reduce_type_, axes_);
break;
case uint64:
reduce_dispatch_min_max<uint64_t>(in, out, reduce_type_, axes_);
break;
case int8:
reduce_dispatch_min_max<uint8_t>(in, out, reduce_type_, axes_);
break;
case int16:
reduce_dispatch_min_max<uint16_t>(in, out, reduce_type_, axes_);
break;
case int32:
reduce_dispatch_min_max<int32_t>(in, out, reduce_type_, axes_);
break;
case int64:
reduce_dispatch_min_max<int64_t>(in, out, reduce_type_, axes_);
break;
case float16:
reduce_dispatch_min_max<float16_t>(in, out, reduce_type_, axes_);
break;
case float32:
reduce_dispatch_min_max<float>(in, out, reduce_type_, axes_);
break;
case bfloat16:
reduce_dispatch_min_max<bfloat16_t>(in, out, reduce_type_, axes_);
break;
case complex64:
reduce_dispatch_min_max<complex64_t>(in, out, reduce_type_, axes_);
break;
}
case uint16:
reduce_dispatch_out<uint16_t>(in, out, reduce_type_, axes_);
break;
case uint32:
reduce_dispatch_out<uint32_t>(in, out, reduce_type_, axes_);
break;
case uint64:
reduce_dispatch_out<uint64_t>(in, out, reduce_type_, axes_);
break;
case int8:
reduce_dispatch_out<uint8_t>(in, out, reduce_type_, axes_);
break;
case int16:
reduce_dispatch_out<uint16_t>(in, out, reduce_type_, axes_);
break;
case int32:
reduce_dispatch_out<int32_t>(in, out, reduce_type_, axes_);
break;
case int64:
reduce_dispatch_out<int64_t>(in, out, reduce_type_, axes_);
break;
case float16:
reduce_dispatch_out<float16_t>(in, out, reduce_type_, axes_);
break;
case float32:
reduce_dispatch_out<float>(in, out, reduce_type_, axes_);
break;
case bfloat16:
reduce_dispatch_out<bfloat16_t>(in, out, reduce_type_, axes_);
break;
case complex64:
reduce_dispatch_out<complex64_t>(in, out, reduce_type_, axes_);
break;
}
}
}

View File

@@ -34,7 +34,7 @@ void shared_buffer_slice(
flags.col_contiguous = is_col_contiguous;
flags.contiguous = (no_bsx_size == data_size);
move_or_copy(in, out, out_strides, flags, data_size, data_offset);
out.copy_shared_buffer(in, out_strides, flags, data_size, data_offset);
}
} // namespace mlx::core

View File

@@ -4,28 +4,6 @@
namespace mlx::core {
void move_or_copy(const array& in, array& out) {
if (in.is_donatable()) {
out.move_shared_buffer(in);
} else {
out.copy_shared_buffer(in);
}
}
void move_or_copy(
const array& in,
array& out,
const std::vector<size_t>& strides,
array::Flags flags,
size_t data_size,
size_t offset /* = 0 */) {
if (in.is_donatable()) {
out.move_shared_buffer(in, strides, flags, data_size, offset);
} else {
out.copy_shared_buffer(in, strides, flags, data_size, offset);
}
}
template <typename StrideT>
std::tuple<std::vector<int>, std::vector<std::vector<StrideT>>>
collapse_contiguous_dims_impl(

View File

@@ -178,13 +178,4 @@ inline bool is_donatable(const array& in, const array& out) {
in.buffer_size() <= out.nbytes() + donation_extra;
}
void move_or_copy(const array& in, array& out);
void move_or_copy(
const array& in,
array& out,
const std::vector<size_t>& strides,
array::Flags flags,
size_t data_size,
size_t offset = 0);
} // namespace mlx::core

View File

@@ -14,21 +14,14 @@ function(make_jit_source SRC_FILE)
COMMAND
/bin/bash ${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.sh
${CMAKE_CURRENT_BINARY_DIR}/jit ${CMAKE_C_COMPILER} ${PROJECT_SOURCE_DIR}
${SRC_FILE}
${SRC_FILE} "-DMLX_METAL_VERSION=${MLX_METAL_VERSION}"
DEPENDS make_compiled_preamble.sh kernels/${SRC_FILE}.h ${ARGN})
add_custom_target(${SRC_NAME} DEPENDS jit/${SRC_NAME}.cpp)
add_dependencies(mlx ${SRC_NAME})
target_sources(mlx PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/jit/${SRC_NAME}.cpp)
endfunction(make_jit_source)
make_jit_source(
utils
kernels/jit/bf16.h
kernels/metal_3_0/bf16.h
kernels/metal_3_1/bf16.h
kernels/bf16_math.h
kernels/complex.h
kernels/defines.h)
make_jit_source(utils kernels/bf16.h kernels/complex.h kernels/defines.h)
make_jit_source(unary_ops kernels/erf.h kernels/expm1f.h)
make_jit_source(binary_ops)
make_jit_source(ternary_ops)

View File

@@ -30,7 +30,7 @@ BufferCache::BufferCache(MTL::Device* device)
: device_(device), head_(nullptr), tail_(nullptr), pool_size_(0) {}
BufferCache::~BufferCache() {
auto pool = metal::new_scoped_memory_pool();
auto thread_pool = metal::new_scoped_memory_pool();
clear();
}
@@ -155,13 +155,11 @@ MetalAllocator::MetalAllocator()
}
size_t MetalAllocator::set_cache_limit(size_t limit) {
std::unique_lock lk(mutex_);
std::swap(limit, max_pool_size_);
return limit;
};
size_t MetalAllocator::set_memory_limit(size_t limit, bool relaxed) {
std::unique_lock lk(mutex_);
std::swap(limit, block_limit_);
relaxed_ = relaxed;
gc_limit_ = std::min(
@@ -171,7 +169,6 @@ size_t MetalAllocator::set_memory_limit(size_t limit, bool relaxed) {
};
size_t MetalAllocator::set_wired_limit(size_t limit) {
std::unique_lock lk(mutex_);
std::swap(limit, wired_limit_);
residency_set_.resize(wired_limit_);
return limit;
@@ -208,7 +205,7 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
return Buffer{nullptr};
}
auto pool = metal::new_scoped_memory_pool();
auto thread_pool = metal::new_scoped_memory_pool();
// If we have a lot of memory pressure or are over the maximum cache size,
// try to reclaim memory from the cache
@@ -229,7 +226,7 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
// Maintain the cache below the requested limit
if (get_cache_memory() >= max_pool_size_) {
auto pool = metal::new_scoped_memory_pool();
auto thread_pool = metal::new_scoped_memory_pool();
buffer_cache_.release_cached_buffers(get_cache_memory() - max_pool_size_);
}
@@ -240,15 +237,11 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
void MetalAllocator::clear_cache() {
std::unique_lock lk(mutex_);
auto pool = metal::new_scoped_memory_pool();
buffer_cache_.clear();
}
void MetalAllocator::free(Buffer buffer) {
auto buf = static_cast<MTL::Buffer*>(buffer.ptr());
if (buf == nullptr) {
return;
}
std::unique_lock lk(mutex_);
residency_set_.erase(buf);
active_memory_ -= buf->length();
@@ -256,7 +249,7 @@ void MetalAllocator::free(Buffer buffer) {
buffer_cache_.recycle_to_cache(buf);
} else {
lk.unlock();
auto pool = metal::new_scoped_memory_pool();
auto thread_pool = metal::new_scoped_memory_pool();
buf->release();
}
}

View File

@@ -22,37 +22,37 @@ std::string get_kernel_name(
BinaryOpType bopt,
const std::string& op,
const array& a,
bool large,
bool use_2d,
int ndim,
int work_per_thread) {
std::string kname;
std::ostringstream kname;
switch (bopt) {
case BinaryOpType::ScalarScalar:
kname = "ss";
kname << "ss";
break;
case BinaryOpType::ScalarVector:
kname = (large ? "sv2" : "sv");
kname << (use_2d ? "sv2" : "sv");
break;
case BinaryOpType::VectorScalar:
kname = (large ? "vs2" : "vs");
kname << (use_2d ? "vs2" : "vs");
break;
case BinaryOpType::VectorVector:
kname = (large ? "vv2" : "vv");
kname << (use_2d ? "vv2" : "vv");
break;
case BinaryOpType::General:
kname = "g";
kname << "g";
if (ndim <= 3) {
kname += std::to_string(ndim);
kname << ndim;
} else {
concatenate(kname, "n", std::to_string(work_per_thread));
}
if (large) {
kname += "large";
kname << "n";
if (work_per_thread > 1) {
kname << work_per_thread;
}
}
break;
}
concatenate(kname, "_", op, type_to_name(a));
return kname;
kname << "_" << op << type_to_name(a);
return kname.str();
}
void binary_op_gpu_inplace(
@@ -81,24 +81,18 @@ void binary_op_gpu_inplace(
};
auto [shape, strides_a, strides_b, strides_out] = maybe_collapse();
bool large = out.data_size() > UINT32_MAX;
bool use_2d = out.data_size() > UINT32_MAX;
auto ndim = shape.size();
int work_per_thread;
if (bopt == BinaryOpType::General) {
large |= (a.data_size() > UINT32_MAX || b.data_size() > UINT32_MAX);
work_per_thread = large ? 4 : 2;
} else {
work_per_thread = 1;
}
int work_per_thread = (bopt == BinaryOpType::General) ? 4 : 1;
std::string kernel_name =
get_kernel_name(bopt, op, a, large, shape.size(), work_per_thread);
get_kernel_name(bopt, op, a, use_2d, shape.size(), work_per_thread);
auto& d = metal::device(s.device);
auto kernel = outputs.size() == 2
? get_binary_two_kernel(d, kernel_name, a.dtype(), out.dtype(), op)
: get_binary_kernel(d, kernel_name, a.dtype(), out.dtype(), op);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder->setComputePipelineState(kernel);
// - If a is donated it goes to the first output
// - If b is donated it goes to the first output if a was not donated
@@ -123,15 +117,19 @@ void binary_op_gpu_inplace(
size_t rest = out.size() / (dim0 * dim1);
if (ndim > 3) {
compute_encoder.set_vector_bytes(shape, arg_idx++);
compute_encoder.set_vector_bytes(strides_a, arg_idx++);
compute_encoder.set_vector_bytes(strides_b, arg_idx++);
compute_encoder.set_bytes<int>(ndim, arg_idx++);
compute_encoder->setBytes(shape.data(), ndim * sizeof(int), arg_idx++);
compute_encoder->setBytes(
strides_a.data(), ndim * sizeof(size_t), arg_idx++);
compute_encoder->setBytes(
strides_b.data(), ndim * sizeof(size_t), arg_idx++);
compute_encoder->setBytes(&ndim, sizeof(int), arg_idx++);
dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
} else {
// The shape is implicit in the grid for <= 3D
compute_encoder.set_vector_bytes(strides_a, arg_idx++);
compute_encoder.set_vector_bytes(strides_b, arg_idx++);
compute_encoder->setBytes(
strides_a.data(), ndim * sizeof(size_t), arg_idx++);
compute_encoder->setBytes(
strides_b.data(), ndim * sizeof(size_t), arg_idx++);
}
if (thread_group_size != 1024) {
@@ -139,7 +137,7 @@ void binary_op_gpu_inplace(
}
auto group_dims = get_block_dims(dim0, dim1, rest);
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
compute_encoder.dispatch_threads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
} else {
// Launch a 1D or 2D grid of threads
size_t nthreads = out.data_size();
@@ -147,9 +145,9 @@ void binary_op_gpu_inplace(
thread_group_size = nthreads;
}
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
MTL::Size grid_dims = large ? get_2d_grid_dims(out.shape(), out.strides())
: MTL::Size(nthreads, 1, 1);
compute_encoder.dispatch_threads(grid_dims, group_dims);
MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
: MTL::Size(nthreads, 1, 1);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
}

View File

@@ -1,6 +1,5 @@
// Copyright © 2023-2024 Apple Inc.
#include <fmt/format.h>
#include <iostream> //TODO
#include <sstream>
#include "mlx/backend/common/compiled.h"
@@ -12,12 +11,12 @@
#include "mlx/primitives.h"
#include "mlx/utils.h"
using namespace fmt::literals;
namespace mlx::core {
constexpr int WORK_PER_THREAD = 4;
inline void build_kernel(
std::string& os,
std::ostream& os,
const std::string& kernel_name,
const std::vector<array>& inputs,
const std::vector<array>& outputs,
@@ -42,8 +41,8 @@ inline void build_kernel(
int cnt = 0;
// Start the kernel
os += fmt::format(
"[[host_name(\"{0}\")]]\n[[kernel]] void {0}(\n", kernel_name);
os << "[[host_name(\"" << kernel_name << "\")]]\n"
<< "[[kernel]] void " << kernel_name << "(\n";
// Add the input arguments
for (auto& x : inputs) {
@@ -55,61 +54,51 @@ inline void build_kernel(
}
// Scalars and contiguous need no strides
if (!is_scalar(x) && !contiguous) {
if (is_scalar(x) || contiguous) {
os << " device const " << get_type_string(x.dtype()) << "* " << xname
<< " [[buffer(" << cnt++ << ")]],\n";
} else {
add_indices = true;
os << " device const " << get_type_string(x.dtype()) << "* " << xname
<< " [[buffer(" << cnt++ << ")]],\n";
}
os += fmt::format(
" device const {0}* {1} [[buffer({2})]],\n",
get_type_string(x.dtype()),
xname,
cnt++);
}
if (add_indices) {
os += fmt::format(
" constant const size_t* in_strides [[buffer({0})]],\n", cnt++);
os << " constant const size_t* in_strides [[buffer(" << cnt++
<< ")]],\n";
}
// Add the output arguments
for (auto& x : outputs) {
os += fmt::format(
" device {0}* {1} [[buffer({2})]],\n",
get_type_string(x.dtype()),
namer.get_name(x),
cnt++);
os << " device " << get_type_string(x.dtype()) << "* "
<< namer.get_name(x) << " [[buffer(" << cnt++ << ")]],\n";
}
// Add output strides and shape to extract the indices.
if (!contiguous) {
os += fmt::format(
" constant const size_t* output_strides [[buffer({0})]],\n", cnt++);
os += fmt::format(
" constant const int* output_shape [[buffer({0})]],\n", cnt++);
os << " constant const size_t* output_strides [[buffer(" << cnt++
<< ")]],\n"
<< " constant const int* output_shape [[buffer(" << cnt++ << ")]],\n";
}
if (dynamic_dims) {
os += fmt::format(" constant const int& ndim [[buffer({0})]],\n", cnt++);
os << " constant const int& ndim [[buffer(" << cnt++ << ")]],\n";
}
// The thread index in the whole grid
os += " uint3 pos [[thread_position_in_grid]],\n";
os += " uint3 grid [[threads_per_grid]]) {\n";
os << " uint3 pos [[thread_position_in_grid]],\n"
<< " uint3 grid [[threads_per_grid]]) {\n";
std::string idx_type = use_big_index ? "size_t" : "uint";
if (contiguous && use_big_index) {
if (use_big_index) {
// This is only used for contiguous kernels which don't have
// a third grid dimension
os += " size_t index = pos.x + grid.x * size_t(pos.y);\n";
os << " size_t index = pos.x + grid.x * size_t(pos.y);\n";
} else if (work_per_thread > 1) {
os += fmt::format(" constexpr int N_ = {0};\n", work_per_thread);
os += fmt::format(
" int xshape = output_shape[{0}];\n",
dynamic_dims ? "ndim - 1" : std::to_string(ndim - 1));
os += fmt::format(
" {0} index = N_ * pos.x + xshape * (pos.y + {0}(grid.y) * pos.z);\n",
idx_type);
os << " constexpr int N_ = " << std::to_string(work_per_thread) << ";\n"
<< " int xshape = output_shape["
<< (dynamic_dims ? "ndim - 1" : std::to_string(ndim - 1)) << "];\n"
<< " size_t index = N_ * pos.x + xshape * (pos.y + size_t(grid.y) * pos.z);\n";
} else {
os += fmt::format(
" {0} index = pos.x + grid.x * (pos.y + {0}(grid.y) * pos.z);\n",
idx_type);
os << " size_t index = pos.x + grid.x * (pos.y + size_t(grid.y) * pos.z);\n";
}
// Read constant / contiguous inputs in tmps
@@ -120,19 +109,16 @@ inline void build_kernel(
if (is_constant(x)) {
auto type_str = get_type_string(x.dtype());
std::ostringstream ss;
print_constant(ss, x);
os += fmt::format(
" auto tmp_{0} = static_cast<{1}>({2});\n",
xname,
get_type_string(x.dtype()),
ss.str());
os << " auto tmp_" << xname << " = static_cast<"
<< get_type_string(x.dtype()) << ">(";
print_constant(os, x);
os << ");\n";
} else if (is_scalar(x)) {
os += fmt::format(
" {0} tmp_{1} = {1}[0];\n", get_type_string(x.dtype()), xname);
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
<< xname << "[0];\n";
} else if (contiguous) {
os += fmt::format(
" {0} tmp_{1} = {1}[index];\n", get_type_string(x.dtype()), xname);
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
<< xname << "[index];\n";
} else {
nc_inputs.push_back(x);
}
@@ -141,98 +127,83 @@ inline void build_kernel(
// Initialize the indices for non-contiguous inputs
for (int i = 0; i < nc_inputs.size(); ++i) {
auto& xname = namer.get_name(nc_inputs[i]);
os += fmt::format(" {0} index_{1} = ", idx_type, xname);
if (ndim == 1) {
int offset = i * ndim;
os += fmt::format(
"elem_to_loc_1<size_t, uint>(pos.x, in_strides[{0}]);\n", offset);
os << " size_t index_" << xname << " = elem_to_loc_1(pos.x, "
<< "in_strides[" << offset << "]);\n";
} else if (ndim == 2) {
int offset = i * ndim;
os += fmt::format(
"elem_to_loc_2<size_t, {0}>({{pos.x, pos.y}}, in_strides + {1});\n",
idx_type,
offset);
os << " size_t index_" << xname << " = elem_to_loc_2({pos.x, pos.y}, "
<< "in_strides + " << offset << ");\n";
} else if (ndim == 3) {
int offset = i * ndim;
os += fmt::format(
"elem_to_loc_3<size_t, {0}>(pos, in_strides + {1});\n",
idx_type,
offset);
os << " size_t index_" << xname << " = elem_to_loc_3(pos, "
<< "in_strides + " << offset << ");\n";
} else if (!dynamic_dims) {
int offset = (i + 1) * ndim;
os += fmt::format(
"N_ * pos.x * {0}(in_strides[{1}]) + pos.y * {0}(in_strides[{2}]);\n",
idx_type,
offset - 1,
offset - 2);
int offset = i * ndim;
os << " size_t index_" << xname << " = N_ * pos.x * in_strides["
<< offset + ndim - 1 << "]"
<< " + pos.y * in_strides[" << offset + ndim - 2 << "];\n";
} else {
os += fmt::format(
"N_ * pos.x * {0}(in_strides[ndim * {1} + ndim - 1]) + pos.y * {0}(in_strides[ndim * {1} + ndim - 2]);\n",
idx_type,
i);
os << " size_t index_" << xname << " = N_ * pos.x * in_strides[ndim * "
<< i << " + ndim - 1]"
<< " + pos.y * in_strides[ndim * " << i << " + ndim - 2];\n";
}
}
if (!nc_inputs.empty() && (ndim > 3 || dynamic_dims)) {
os += " uint zpos = pos.z;\n";
os << " uint zpos = pos.z;\n";
if (dynamic_dims) {
os += " for (int d = ndim - 3; d >= 0; --d) {\n";
os << " for (int d = ndim - 3; d >= 0; --d) {\n";
} else {
os += fmt::format(" for (int d = {0}; d >= 0; --d) {{\n", ndim - 3);
os << " for (int d = " << ndim - 3 << "; d >= 0; --d) {\n";
}
os += " uint l = zpos % output_shape[d];\n";
os << " uint l = zpos % output_shape[d];\n";
for (int i = 0; i < nc_inputs.size(); ++i) {
auto& xname = namer.get_name(nc_inputs[i]);
os += fmt::format(" index_{0} += ", xname);
os << " index_" << xname << " += ";
if (dynamic_dims) {
os +=
fmt::format("l * {0}(in_strides[{1} * ndim + d]);\n", idx_type, i);
os << "l * in_strides[" << i << " * ndim + d];\n";
} else {
os +=
fmt::format("l * {0}(in_strides[{1} + d]);\n", idx_type, i * ndim);
os << "l * in_strides[" << i * ndim << " + d];\n";
}
}
os += " zpos /= output_shape[d];\n }\n";
os << " zpos /= output_shape[d];\n }\n";
}
// Open per-thread loop
if (work_per_thread > 1) {
os +=
" for (int i = 0; i < N_ && (int(N_ * pos.x) + i) < xshape; ++i) {\n";
os << " for (int i = 0; i < N_ && (int(N_ * pos.x) + i) < xshape; ++i) {\n";
}
// Read non-contiguous inputs into tmps
for (int i = 0; i < nc_inputs.size(); ++i) {
auto& x = nc_inputs[i];
auto& xname = namer.get_name(x);
os += fmt::format(
" {0} tmp_{1} = {1}[index_{1}];\n", get_type_string(x.dtype()), xname);
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
<< xname << "[index_" << xname << "];\n";
}
// Actually write the computation
for (auto& x : tape) {
os += fmt::format(
" {0} tmp_{1} = ", get_type_string(x.dtype()), namer.get_name(x));
os << " " << get_type_string(x.dtype()) << " tmp_" << namer.get_name(x)
<< " = ";
if (is_static_cast(x.primitive())) {
os += fmt::format(
"static_cast<{0}>(tmp_{1});\n",
get_type_string(x.dtype()),
namer.get_name(x.inputs()[0]));
os << "static_cast<" << get_type_string(x.dtype()) << ">(tmp_"
<< namer.get_name(x.inputs()[0]) << ");\n";
} else {
std::ostringstream ss;
x.primitive().print(ss);
os += ss.str();
os += "()(";
x.primitive().print(os);
os << "()(";
for (int i = 0; i < x.inputs().size() - 1; i++) {
os += fmt::format("tmp_{0}, ", namer.get_name(x.inputs()[i]));
os << "tmp_" << namer.get_name(x.inputs()[i]) << ", ";
}
os += fmt::format("tmp_{0});\n", namer.get_name(x.inputs().back()));
os << "tmp_" << namer.get_name(x.inputs().back()) << ");\n";
}
}
// Write the outputs from tmps
for (auto& x : outputs) {
os += fmt::format(" {0}[index] = tmp_{0};\n", namer.get_name(x));
os << " " << namer.get_name(x) << "[index] = tmp_" << namer.get_name(x)
<< ";\n";
}
// Increment indices and close per thread loop
if (work_per_thread > 1) {
@@ -240,18 +211,18 @@ inline void build_kernel(
auto& x = nc_inputs[i];
auto& xname = namer.get_name(x);
if (!dynamic_dims) {
os += fmt::format(
" index_{0} += in_strides[{1}];\n", xname, i * ndim + ndim - 1);
os << " index_" << xname << " += "
<< "in_strides[" << i * ndim + ndim - 1 << "];\n";
} else {
os += fmt::format(
" index_{0} += in_strides[{1} * ndim + ndim - 1];\n", xname, i);
os << " index_" << xname << " += "
<< "in_strides[" << i << " * ndim + ndim - 1];\n";
}
}
os += " index++;\n }\n";
os << " index++;\n }\n";
}
// Finish the kernel
os += "}\n";
os << "}\n";
if (cnt > 31) {
std::ostringstream msg;
@@ -275,9 +246,9 @@ void Compiled::eval_gpu(
auto& s = stream();
auto& d = metal::device(s.device);
auto lib = d.get_library(kernel_lib_, [&]() {
std::string kernel = metal::utils();
concatenate(
kernel, metal::unary_ops(), metal::binary_ops(), metal::ternary_ops());
std::ostringstream kernel;
kernel << metal::utils() << metal::unary_ops() << metal::binary_ops()
<< metal::ternary_ops();
build_kernel(
kernel,
kernel_lib_ + "_contiguous",
@@ -290,7 +261,7 @@ void Compiled::eval_gpu(
/* dynamic_dims = */ false);
build_kernel(
kernel,
kernel_lib_ + "_contiguous_large",
kernel_lib_ + "_contiguous_big",
inputs_,
outputs_,
tape_,
@@ -311,21 +282,7 @@ void Compiled::eval_gpu(
/* ndim = */ i,
/* dynamic_dims = */ false,
/* use_big_index = */ false,
/* work_per_thread = */ i > 3 ? 2 : 1);
if (i > 1) {
build_kernel(
kernel,
kernel_lib_ + "_strided_" + std::to_string(i) + "_large",
inputs_,
outputs_,
tape_,
constant_ids_,
/* contiguous = */ false,
/* ndim = */ i,
/* dynamic_dims = */ false,
/* use_big_index = */ true,
/* work_per_thread = */ i > 3 ? 4 : 1);
}
/* work_per_thread = */ i > 3 ? WORK_PER_THREAD : 1);
}
build_kernel(
kernel,
@@ -338,25 +295,13 @@ void Compiled::eval_gpu(
/* ndim = */ 0,
/* dynamic_dims = */ true,
/* use_big_index = */ false,
/* work_per_thread = */ 2);
build_kernel(
kernel,
kernel_lib_ + "_strided_dynamic_large",
inputs_,
outputs_,
tape_,
constant_ids_,
/* contiguous = */ false,
/* ndim = */ 0,
/* dynamic_dims = */ true,
/* use_big_index = */ true,
/* work_per_thread = */ 4);
return kernel;
/* work_per_thread = */ WORK_PER_THREAD);
return kernel.str();
});
// Figure out which kernel we are using
auto& output_shape = outputs[0].shape();
auto contiguous = compiled_check_contiguity(inputs, output_shape);
bool contiguous = compiled_check_contiguity(inputs, output_shape);
// Collapse contiguous dims to route to a faster kernel if possible. Also
// handle all broadcasting.
@@ -404,19 +349,13 @@ void Compiled::eval_gpu(
collapse_contiguous_dims(output_shape, initial_strides, INT32_MAX);
}
bool large;
bool use_2d = false;
if (contiguous) {
size_t max_size = 0;
for (auto& in : inputs) {
max_size = std::max(max_size, in.data_size());
}
large = (max_size > UINT32_MAX);
} else {
size_t max_size = 0;
for (auto& o : outputs) {
max_size = std::max(max_size, o.size());
}
large = (max_size > UINT32_MAX);
use_2d = (max_size > UINT32_MAX);
}
// Get the kernel from the lib
@@ -429,13 +368,12 @@ void Compiled::eval_gpu(
} else {
kernel_name += std::to_string(shape.size());
}
}
if (large) {
kernel_name += "_large";
} else if (use_2d) {
kernel_name += "_big";
}
auto kernel = d.get_kernel(kernel_name, lib);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder->setComputePipelineState(kernel);
// Put the inputs in
int cnt = 0;
@@ -456,7 +394,8 @@ void Compiled::eval_gpu(
}
}
if (!in_strides.empty()) {
compute_encoder.set_vector_bytes(in_strides, cnt++);
compute_encoder->setBytes(
in_strides.data(), in_strides.size() * sizeof(size_t), cnt++);
}
compiled_allocate_outputs(
@@ -469,13 +408,14 @@ void Compiled::eval_gpu(
// Put the output shape and strides in
if (!contiguous) {
compute_encoder.set_vector_bytes(strides[0], cnt++);
compute_encoder.set_vector_bytes(shape, cnt++);
compute_encoder->setBytes(
strides[0].data(), strides[0].size() * sizeof(size_t), cnt++);
compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), cnt++);
}
// Put the number of dims in if it is dynamic
if (dynamic) {
compute_encoder.set_bytes(ndim, cnt++);
compute_encoder->setBytes(&ndim, sizeof(int), cnt++);
}
// Launch the kernel
@@ -484,15 +424,15 @@ void Compiled::eval_gpu(
MTL::Size group_dims(
std::min(nthreads, kernel->maxTotalThreadsPerThreadgroup()), 1, 1);
MTL::Size grid_dims = large
MTL::Size grid_dims = use_2d
? get_2d_grid_dims(outputs[0].shape(), outputs[0].strides())
: MTL::Size(nthreads, 1, 1);
compute_encoder.dispatch_threads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
} else {
size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
size_t rest = outputs[0].size() / (dim0 * dim1);
int work_per_thread = ndim > 3 ? (large ? 4 : 2) : 1;
int work_per_thread = ndim > 3 ? WORK_PER_THREAD : 1;
dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
int pow2;
@@ -505,7 +445,7 @@ void Compiled::eval_gpu(
}
auto group_dims = get_block_dims(dim0, dim1, rest, pow2);
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
compute_encoder.dispatch_threads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
}

View File

@@ -44,24 +44,23 @@ void explicit_gemm_conv_ND_gpu(
kname << "naive_unfold_nd_" << type_to_name(in_unfolded) << "_" << N;
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(in_unfolded, 1);
compute_encoder.set_bytes(conv_params, 2);
compute_encoder->setBytes(&conv_params, sizeof(conv_params), 2);
// Launch unfolding kernel
size_t tgp_x = std::min(conv_params.C, 64);
int tgp_x = std::min(conv_params.C, 64);
tgp_x = 32 * ((tgp_x + 32 - 1) / 32);
size_t tgp_y = 256 / tgp_x;
int tgp_y = 256 / tgp_x;
MTL::Size group_dims = MTL::Size(tgp_x, tgp_y, 1);
MTL::Size grid_dims = MTL::Size(
conv_params.C, unfolded_shape[1] / conv_params.C, unfolded_shape[0]);
MTL::Size group_dims = MTL::Size(
std::min(tgp_x, grid_dims.width), std::min(tgp_y, grid_dims.height), 1);
compute_encoder.dispatch_threads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
// Reshape weight
std::vector<int> wt_reshape{implicit_K, implicit_N};
@@ -123,24 +122,23 @@ void explicit_gemm_conv_group_ND_gpu(
<< N;
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(in_unfolded, 1);
compute_encoder.set_bytes(conv_params, 2);
compute_encoder->setBytes(&conv_params, sizeof(conv_params), 2);
// Launch unfolding kernel
size_t tgp_x = std::min(conv_params.C, 64);
int tgp_x = std::min(conv_params.C, 64);
tgp_x = 32 * ((tgp_x + 32 - 1) / 32);
size_t tgp_y = 256 / tgp_x;
int tgp_y = 256 / tgp_x;
MTL::Size group_dims = MTL::Size(tgp_x, tgp_y, 1);
MTL::Size grid_dims = MTL::Size(
conv_params.C, unfolded_shape[1] / conv_params.C, unfolded_shape[0]);
MTL::Size group_dims = MTL::Size(
std::min(tgp_x, grid_dims.width), std::min(tgp_y, grid_dims.height), 1);
compute_encoder.dispatch_threads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
// Transpose kernel weights so that we can slice them by contiguous chunks
// of channel groups.
@@ -239,7 +237,7 @@ void slow_conv_2D_gpu(
// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder->setComputePipelineState(kernel);
size_t n_pixels = conv_params.oS[0] * conv_params.oS[1];
@@ -254,8 +252,8 @@ void slow_conv_2D_gpu(
compute_encoder.set_input_array(wt, 1);
compute_encoder.set_output_array(out, 2);
compute_encoder.set_bytes(conv_params, 3);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
compute_encoder->setBytes(&conv_params, sizeof(MLXConvParams<2>), 3);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
}
void implicit_gemm_conv_2D_gpu(
@@ -354,7 +352,7 @@ void implicit_gemm_conv_2D_gpu(
wn,
n_channel_specialization,
small_filter);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder->setComputePipelineState(kernel);
// Deduce grid launch dimensions
int tile = 1 << swizzle_log;
@@ -370,11 +368,11 @@ void implicit_gemm_conv_2D_gpu(
compute_encoder.set_output_array(out, 2);
// Encode params
compute_encoder.set_bytes(conv_params, 3);
compute_encoder.set_bytes(gemm_params, 4);
compute_encoder->setBytes(&conv_params, sizeof(MLXConvParams<2>), 3);
compute_encoder->setBytes(&gemm_params, sizeof(ImplicitGemmConv2DParams), 4);
// Launch kernel
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
}
void implicit_gemm_conv_2D_general_gpu(
@@ -508,7 +506,7 @@ void implicit_gemm_conv_2D_general_gpu(
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel =
get_steel_conv_general_kernel(d, kname.str(), out, bm, bn, bk, wm, wn);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder->setComputePipelineState(kernel);
// Deduce grid launch dimensions
int tile = 1 << swizzle_log;
@@ -525,15 +523,17 @@ void implicit_gemm_conv_2D_general_gpu(
compute_encoder.set_output_array(out, 2);
// Encode params
compute_encoder.set_bytes(conv_params, 3);
compute_encoder.set_bytes(gemm_params, 4);
compute_encoder.set_bytes(jump_params, 5);
compute_encoder->setBytes(&conv_params, sizeof(MLXConvParams<2>), 3);
compute_encoder->setBytes(&gemm_params, sizeof(ImplicitGemmConv2DParams), 4);
compute_encoder->setBytes(&jump_params, sizeof(Conv2DGeneralJumpParams), 5);
compute_encoder.set_vector_bytes(base_h, 6);
compute_encoder.set_vector_bytes(base_w, 7);
compute_encoder->setBytes(
base_h.data(), sizeof(Conv2DGeneralBaseInfo) * base_h.size(), 6);
compute_encoder->setBytes(
base_w.data(), sizeof(Conv2DGeneralBaseInfo) * base_w.size(), 7);
// Launch kernel
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
}
void winograd_conv_2D_gpu(
@@ -622,18 +622,18 @@ void winograd_conv_2D_gpu(
<< bc;
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array(wt, 0);
compute_encoder.set_output_array(filt_wg, 1);
compute_encoder.set_bytes(C_c, 2);
compute_encoder.set_bytes(O_c, 3);
compute_encoder->setBytes(&C_c, sizeof(int), 2);
compute_encoder->setBytes(&O_c, sizeof(int), 3);
MTL::Size group_dims = MTL::Size(32, bo, 1);
MTL::Size grid_dims = MTL::Size(O_c / bo, 1, 1);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
}
// Do input transform
@@ -650,17 +650,18 @@ void winograd_conv_2D_gpu(
<< bc;
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array(in_padded, 0);
compute_encoder.set_output_array(inp_wg, 1);
compute_encoder.set_bytes(conv_params_updated, 2);
compute_encoder->setBytes(
&conv_params_updated, sizeof(MLXConvParams<2>), 2);
MTL::Size group_dims = MTL::Size(32, wn, wm);
MTL::Size grid_dims = MTL::Size(N_tiles_w, N_tiles_h, N_tiles_n);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
}
// Do batched gemm
@@ -697,17 +698,18 @@ void winograd_conv_2D_gpu(
<< bc;
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array(out_wg, 0);
compute_encoder.set_output_array(out, 1);
compute_encoder.set_bytes(conv_params_updated, 2);
compute_encoder->setBytes(
&conv_params_updated, sizeof(MLXConvParams<2>), 2);
MTL::Size group_dims = MTL::Size(32, wn, wm);
MTL::Size grid_dims = MTL::Size(N_tiles_w, N_tiles_h, N_tiles_n);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
}
}

View File

@@ -74,46 +74,44 @@ void copy_gpu_inplace(
};
auto [shape, strides_in_, strides_out_] = maybe_collapse();
int ndim = shape.size();
bool large;
if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) {
// Allow for negative strides
large = in.data_size() > INT32_MAX || out.data_size() > INT32_MAX;
} else {
large = out.data_size() > UINT32_MAX;
}
bool use_2d = out.data_size() > UINT32_MAX;
auto& d = metal::device(s.device);
int work_per_thread = 1;
std::string kernel_name;
switch (ctype) {
case CopyType::Scalar:
kernel_name = (large ? "s2" : "s");
break;
case CopyType::Vector:
kernel_name = (large ? "v2" : "v");
break;
case CopyType::General:
kernel_name = "g";
break;
case CopyType::GeneralGeneral:
kernel_name = "gg";
break;
}
if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) {
if (shape.size() <= MAX_COPY_SPECIALIZED_DIMS) {
kernel_name += std::to_string(shape.size());
} else {
work_per_thread = large ? 4 : 2;
concatenate(kernel_name, "n", std::to_string(work_per_thread));
{
std::ostringstream kname;
switch (ctype) {
case CopyType::Scalar:
kname << (use_2d ? "s2" : "s");
break;
case CopyType::Vector:
kname << (use_2d ? "v2" : "v");
break;
case CopyType::General:
kname << "g";
break;
case CopyType::GeneralGeneral:
kname << "gg";
break;
}
if (large) {
kernel_name += "large";
if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) {
if (shape.size() <= MAX_COPY_SPECIALIZED_DIMS) {
kname << shape.size();
} else {
work_per_thread = 4;
kname << "n4";
}
}
kname << "_copy";
kname << type_to_name(in) << type_to_name(out);
kernel_name = kname.str();
}
concatenate(kernel_name, "_copy", type_to_name(in), type_to_name(out));
auto kernel = get_copy_kernel(d, kernel_name, in, out);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder->setComputePipelineState(kernel);
bool donate_in = in.data_shared_ptr() == nullptr;
inp_offset *= size_of(in.dtype());
@@ -127,23 +125,23 @@ void copy_gpu_inplace(
std::vector<int64_t> strides_in{strides_in_.begin(), strides_in_.end()};
std::vector<int64_t> strides_out{strides_out_.begin(), strides_out_.end()};
if (ndim > 3) {
compute_encoder.set_vector_bytes(shape, ndim, 2);
set_vector_bytes(compute_encoder, shape, ndim, 2);
}
compute_encoder.set_vector_bytes(strides_in, ndim, 3);
set_vector_bytes(compute_encoder, strides_in, ndim, 3);
if (ctype == CopyType::GeneralGeneral) {
compute_encoder.set_vector_bytes(strides_out, ndim, 4);
set_vector_bytes(compute_encoder, strides_out, ndim, 4);
}
size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
int dim0 = ndim > 0 ? shape[ndim - 1] : 1;
int dim1 = ndim > 1 ? shape[ndim - 2] : 1;
size_t data_size = 1;
for (auto& s : shape)
data_size *= s;
size_t rest = data_size / (dim0 * dim1);
int rest = data_size / (dim0 * dim1);
if (ndim > MAX_COPY_SPECIALIZED_DIMS) {
compute_encoder.set_bytes(ndim, 5);
compute_encoder->setBytes(&ndim, sizeof(int), 5);
dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
}
@@ -154,16 +152,16 @@ void copy_gpu_inplace(
auto group_dims = get_block_dims(dim0, dim1, rest);
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
compute_encoder.dispatch_threads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
} else {
size_t nthreads = out.data_size();
if (thread_group_size > nthreads) {
thread_group_size = nthreads;
}
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
MTL::Size grid_dims = large ? get_2d_grid_dims(out.shape(), out.strides())
: MTL::Size(nthreads, 1, 1);
compute_encoder.dispatch_threads(grid_dims, group_dims);
MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
: MTL::Size(nthreads, 1, 1);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
}
@@ -195,13 +193,13 @@ void fill_gpu(const array& val, array& out, const Stream& s) {
return;
}
out.set_data(allocator::malloc_or_wait(out.nbytes()));
bool large = out.data_size() > UINT32_MAX;
bool use_2d = out.data_size() > UINT32_MAX;
auto& d = metal::device(s.device);
std::string kernel_name = std::string(large ? "s2" : "s") + "_copy" +
std::string kernel_name = std::string(use_2d ? "s2" : "s") + "_copy" +
type_to_name(val) + type_to_name(out);
auto kernel = get_copy_kernel(d, kernel_name, val, out);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array(val, 0);
compute_encoder.set_output_array(out, 1);
@@ -212,9 +210,9 @@ void fill_gpu(const array& val, array& out, const Stream& s) {
thread_group_size = nthreads;
}
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
MTL::Size grid_dims = large ? get_2d_grid_dims(out.shape(), out.strides())
: MTL::Size(nthreads, 1, 1);
compute_encoder.dispatch_threads(grid_dims, group_dims);
MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
: MTL::Size(nthreads, 1, 1);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
} // namespace mlx::core

View File

@@ -43,7 +43,7 @@ void CustomKernel::eval_gpu(
d.get_library(lib_name, [this] { return metal::utils() + source_; });
auto kernel = d.get_kernel(name_, lib);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder->setComputePipelineState(kernel);
int index = 0;
for (int i = 0; i < checked_inputs.size(); i++) {
const array& in = checked_inputs[i];
@@ -53,15 +53,15 @@ void CustomKernel::eval_gpu(
if (in.ndim() > 0) {
int ndim = in.ndim();
if (shape_info.shape) {
compute_encoder.set_vector_bytes(in.shape(), ndim, index);
set_vector_bytes(compute_encoder, in.shape(), ndim, index);
index++;
}
if (shape_info.strides) {
compute_encoder.set_vector_bytes(in.strides(), ndim, index);
set_vector_bytes(compute_encoder, in.strides(), ndim, index);
index++;
}
if (shape_info.ndim) {
compute_encoder.set_bytes(ndim, index);
compute_encoder->setBytes(&ndim, sizeof(int), index);
index++;
}
}
@@ -72,11 +72,10 @@ void CustomKernel::eval_gpu(
}
const auto [tx, ty, tz] = threadgroup_;
MTL::Size group_dims = MTL::Size(tx, ty, tz);
const auto [gx, gy, gz] = grid_;
MTL::Size group_dims =
MTL::Size(std::min(tx, gx), std::min(ty, gy), std::min(tz, gz));
MTL::Size grid_dims = MTL::Size(gx, gy, gz);
compute_encoder.dispatch_threads(grid_dims, group_dims);
compute_encoder->dispatchThreads(grid_dims, group_dims);
d.add_temporaries(std::move(copies), s.index);
}

View File

@@ -23,18 +23,14 @@ constexpr int MAX_BUFFERS_PER_QUEUE = 12;
constexpr const char* default_mtllib_path = METAL_PATH;
auto get_metal_version() {
auto get_metal_version_ = []() {
if (__builtin_available(macOS 15, iOS 18, tvOS 18, visionOS 2, *)) {
return MTL::LanguageVersion3_2;
} else if (__builtin_available(macOS 14, iOS 17, tvOS 17, visionOS 1, *)) {
return MTL::LanguageVersion3_1;
} else {
return MTL::LanguageVersion3_0;
}
};
static auto metal_version_ = get_metal_version_();
return metal_version_;
constexpr auto get_metal_version() {
#if (MLX_METAL_VERSION >= 320)
return MTL::LanguageVersion3_2;
#elif (MLX_METAL_VERSION >= 310)
return MTL::LanguageVersion3_1;
#else
return MTL::LanguageVersion3_0;
#endif
}
auto load_device() {
@@ -175,14 +171,14 @@ void CommandEncoder::maybeInsertBarrier() {
next_outputs_.clear();
}
void CommandEncoder::dispatch_threadgroups(
void CommandEncoder::dispatchThreadgroups(
MTL::Size grid_dims,
MTL::Size group_dims) {
maybeInsertBarrier();
enc_->dispatchThreadgroups(grid_dims, group_dims);
}
void CommandEncoder::dispatch_threads(
void CommandEncoder::dispatchThreads(
MTL::Size grid_dims,
MTL::Size group_dims) {
maybeInsertBarrier();
@@ -280,7 +276,7 @@ void Device::end_encoding(int index) {
// - Update the map of outputs to include this command encoder's outputs.
// - Always signal this command encoders fence.
// - Add a completion handler for this command encoder that removes outputs
// from the map to limit the growth of the map and avoid unnecessary waits
// from the map to limit the growth of the map and avoid unecessary waits
// - Temporaries are a special case as they do not cross command encoder
// boundaries. These can be removed early from the encoders inputs and
// outputs since they don't need synchronization.
@@ -302,7 +298,7 @@ void Device::end_encoding(int index) {
if (auto it = stream.outputs.find(in); it != stream.outputs.end()) {
// If we've already waited on a fence, don't wait on it again.
if (waiting_on.find(it->second) == waiting_on.end()) {
enc.wait_for_fence(it->second->fence);
enc->waitForFence(it->second->fence);
waiting_on.insert(it->second);
}
}
@@ -311,7 +307,7 @@ void Device::end_encoding(int index) {
stream.outputs[out] = stream.fence;
}
}
enc.update_fence(stream.fence->fence);
enc->updateFence(stream.fence->fence);
stream.buffer->addCompletedHandler(
[&stream,
waiting_on = std::move(waiting_on),
@@ -645,27 +641,21 @@ void new_stream(Stream stream) {
std::unordered_map<std::string, std::variant<std::string, size_t>>
device_info() {
auto init_device_info = []()
-> std::unordered_map<std::string, std::variant<std::string, size_t>> {
auto pool = new_scoped_memory_pool();
auto raw_device = device(default_device()).mtl_device();
auto arch = std::string(raw_device->architecture()->name()->utf8String());
auto raw_device = device(default_device()).mtl_device();
auto arch = std::string(raw_device->architecture()->name()->utf8String());
int mib[] = {CTL_HW, HW_MEMSIZE};
size_t memsize = 0;
size_t length = sizeof(memsize);
int mib[] = {CTL_HW, HW_MEMSIZE};
size_t memsize = 0;
size_t length = sizeof(memsize);
sysctl(mib, 2, &memsize, &length, NULL, 0);
sysctl(mib, 2, &memsize, &length, NULL, 0);
return {
{"architecture", arch},
{"max_buffer_length", raw_device->maxBufferLength()},
{"max_recommended_working_set_size",
raw_device->recommendedMaxWorkingSetSize()},
{"memory_size", memsize}};
};
static auto device_info_ = init_device_info();
return device_info_;
return {
{"architecture", arch},
{"max_buffer_length", raw_device->maxBufferLength()},
{"max_recommended_working_set_size",
raw_device->recommendedMaxWorkingSetSize()},
{"memory_size", memsize}};
}
} // namespace mlx::core::metal

View File

@@ -58,43 +58,16 @@ struct CommandEncoder {
CommandEncoder& enc;
};
MTL::ComputeCommandEncoder* operator->() {
return enc_;
}
void set_input_array(const array& a, int idx, int64_t offset = 0);
void set_output_array(array& a, int idx, int64_t offset = 0);
void dispatch_threadgroups(MTL::Size grid_dims, MTL::Size group_dims);
void dispatch_threads(MTL::Size grid_dims, MTL::Size group_dims);
void dispatchThreadgroups(MTL::Size grid_dims, MTL::Size group_dims);
void dispatchThreads(MTL::Size grid_dims, MTL::Size group_dims);
void maybeInsertBarrier();
void set_compute_pipeline_state(MTL::ComputePipelineState* kernel) {
enc_->setComputePipelineState(kernel);
}
void wait_for_fence(MTL::Fence* fence) {
enc_->waitForFence(fence);
}
void update_fence(MTL::Fence* fence) {
enc_->updateFence(fence);
}
template <typename T>
void set_vector_bytes(const std::vector<T>& vec, size_t nelems, int idx) {
enc_->setBytes(vec.data(), nelems * sizeof(T), idx);
}
template <typename T>
void set_vector_bytes(const std::vector<T>& vec, int idx) {
return set_vector_bytes(vec, vec.size(), idx);
}
template <typename T>
void set_bytes(const T* v, int n, int idx) {
return enc_->setBytes(v, n * sizeof(T), idx);
}
template <typename T>
void set_bytes(const T& v, int idx) {
return enc_->setBytes(&v, sizeof(T), idx);
}
ConcurrentContext start_concurrent() {
return ConcurrentContext(*this);
}

View File

@@ -699,7 +699,7 @@ void fft_op(
auto kernel =
get_fft_kernel(d, base_name, hash_name, func_consts, template_def);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array(in_contiguous, 0);
compute_encoder.set_output_array(out, 1);
@@ -711,9 +711,9 @@ void fft_op(
compute_encoder.set_input_array(w_q, 2); // w_q
compute_encoder.set_input_array(w_k, 3); // w_k
compute_encoder.set_bytes(n, 4);
compute_encoder.set_bytes(plan.bluestein_n, 5);
compute_encoder.set_bytes(total_batch_size, 6);
compute_encoder->setBytes(&n, sizeof(int), 4);
compute_encoder->setBytes(&plan.bluestein_n, sizeof(int), 5);
compute_encoder->setBytes(&total_batch_size, sizeof(int), 6);
} else if (plan.rader_n > 1) {
auto [b_q, g_q, g_minus_q] = compute_raders_constants(plan.rader_n, s);
copies.push_back(b_q);
@@ -723,22 +723,22 @@ void fft_op(
compute_encoder.set_input_array(b_q, 2);
compute_encoder.set_input_array(g_q, 3);
compute_encoder.set_input_array(g_minus_q, 4);
compute_encoder.set_bytes(n, 5);
compute_encoder.set_bytes(total_batch_size, 6);
compute_encoder.set_bytes(plan.rader_n, 7);
compute_encoder->setBytes(&n, sizeof(int), 5);
compute_encoder->setBytes(&total_batch_size, sizeof(int), 6);
compute_encoder->setBytes(&plan.rader_n, sizeof(int), 7);
} else if (four_step_params.required) {
compute_encoder.set_bytes(four_step_params.n1, 2);
compute_encoder.set_bytes(four_step_params.n2, 3);
compute_encoder.set_bytes(total_batch_size, 4);
compute_encoder->setBytes(&four_step_params.n1, sizeof(int), 2);
compute_encoder->setBytes(&four_step_params.n2, sizeof(int), 3);
compute_encoder->setBytes(&total_batch_size, sizeof(int), 4);
} else {
compute_encoder.set_bytes(n, 2);
compute_encoder.set_bytes(total_batch_size, 3);
compute_encoder->setBytes(&n, sizeof(int), 2);
compute_encoder->setBytes(&total_batch_size, sizeof(int), 3);
}
auto group_dims = MTL::Size(1, threadgroup_batch_size, threads_per_fft);
auto grid_dims =
MTL::Size(batch_size, threadgroup_batch_size, threads_per_fft);
compute_encoder.dispatch_threads(grid_dims, group_dims);
compute_encoder->dispatchThreads(grid_dims, group_dims);
}
d.add_temporaries(std::move(copies), s.index);

View File

@@ -137,14 +137,14 @@ void Hadamard::eval_gpu(const std::vector<array>& inputs, array& out) {
auto kernel = d.get_kernel(kernel_name, lib);
assert(threads_per <= kernel->maxTotalThreadsPerThreadgroup());
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(out, 1);
compute_encoder.set_bytes(scale, 2);
compute_encoder->setBytes(&scale, sizeof(float), 2);
MTL::Size group_dims = MTL::Size(1, threads_per, 1);
MTL::Size grid_dims = MTL::Size(batch_size, threads_per, 1);
compute_encoder.dispatch_threads(grid_dims, group_dims);
compute_encoder->dispatchThreads(grid_dims, group_dims);
};
if (m > 1) {

View File

@@ -53,31 +53,27 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
int idx_ndim = nidx ? inputs[1].ndim() : 0;
size_t ndim = src.ndim();
bool large_index = nidx && inputs[1].size() > UINT32_MAX;
bool large_src = src.size() > UINT32_MAX;
bool large_out = out.size() > UINT32_MAX;
bool large = large_index || large_src || large_out;
std::string lib_name;
std::string kernel_name;
std::string idx_type_name = nidx ? type_to_name(inputs[1]) : "";
std::string kernel_name = fmt::format(
"gather{0}{1}_{2}_{3}_{4}",
type_to_name(out),
idx_type_name,
nidx,
idx_ndim,
large ? "size_t" : "uint");
std::string lib_name = kernel_name;
{
std::ostringstream kname;
kname << "gather" << type_to_name(out) << idx_type_name << "_" << nidx
<< "_" << idx_ndim;
lib_name = kname.str();
kernel_name = lib_name;
}
auto lib = d.get_library(lib_name, [&]() {
std::string kernel_source = metal::utils();
kernel_source += metal::gather();
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::gather();
std::string out_type_str = get_type_string(out.dtype());
std::string idx_type_str =
nidx ? get_type_string(inputs[1].dtype()) : "bool";
auto [idx_args, idx_arr] = make_index_args(idx_type_str, nidx);
// Index dimension specializations
kernel_source += fmt::format(
kernel_source << fmt::format(
gather_kernels,
type_to_name(out) + idx_type_name,
out_type_str,
@@ -85,14 +81,13 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
nidx,
idx_args,
idx_arr,
idx_ndim,
large ? "size_t" : "uint");
return kernel_source;
idx_ndim);
return kernel_source.str();
});
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kernel_name, lib);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder->setComputePipelineState(kernel);
size_t slice_size = 1;
for (auto s : slice_sizes_) {
@@ -136,20 +131,20 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.set_output_array(out, 1);
// Set source info
compute_encoder.set_vector_bytes(src.shape(), 2);
compute_encoder.set_vector_bytes(src.strides(), 3);
compute_encoder.set_bytes(ndim, 4);
compute_encoder.set_vector_bytes(slice_sizes_, 5);
compute_encoder.set_vector_bytes(axes_, 6);
set_vector_bytes(compute_encoder, src.shape(), 2);
set_vector_bytes(compute_encoder, src.strides(), 3);
compute_encoder->setBytes(&ndim, sizeof(size_t), 4);
set_vector_bytes(compute_encoder, slice_sizes_, 5);
set_vector_bytes(compute_encoder, axes_, 6);
// Set index info
//
// We don't need to check for empty idx_shapes because gather has a
// idx_ndim == 0 specialization
compute_encoder.set_vector_bytes(idx_shapes, 7);
compute_encoder.set_vector_bytes(idx_strides, 8);
compute_encoder.set_vector_bytes(idx_contigs, 9);
compute_encoder.set_bytes(idx_ndim, 10);
set_vector_bytes(compute_encoder, idx_shapes, 7);
set_vector_bytes(compute_encoder, idx_strides, 8);
set_vector_bytes(compute_encoder, idx_contigs, 9);
compute_encoder->setBytes(&idx_ndim, sizeof(int), 10);
// Set index buffers
for (int i = 0; i < nidx; ++i) {
@@ -157,7 +152,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
}
// Launch grid
compute_encoder.dispatch_threads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -214,6 +209,8 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
nwork = 32;
}
std::string lib_name;
std::string kernel_name;
std::string idx_type_name = nidx ? type_to_name(inputs[1]) : "";
std::string op_name;
switch (reduce_type_) {
@@ -234,24 +231,18 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
break;
}
auto upd_contig = upd.flags().row_contiguous;
bool large_out = out.size() > UINT32_MAX;
bool large_idx = nidx && (inputs[1].size() > UINT32_MAX);
bool large_upd = upd.size() > UINT32_MAX;
bool large = large_out || large_idx || large_upd;
std::string kernel_name = fmt::format(
"scatter{0}{1}_{2}_{3}_{4}_nwork{5}_{6}",
type_to_name(out),
idx_type_name,
op_name,
nidx,
upd_contig ? "updc_true" : "updc_false",
nwork,
large ? "size_t" : "uint");
std::string lib_name = kernel_name;
{
std::ostringstream kname;
kname << "scatter" << type_to_name(out) << idx_type_name;
kname << "_" << op_name << "_" << nidx << "_"
<< (upd_contig ? "updc_true" : "updc_false") << "_nwork" << nwork;
lib_name = kname.str();
kernel_name = kname.str();
}
auto lib = d.get_library(lib_name, [&]() {
std::string kernel_source = metal::utils();
concatenate(kernel_source, metal::reduce_utils(), metal::scatter());
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::reduce_utils()
<< metal::scatter();
std::string out_type_str = get_type_string(out.dtype());
std::string idx_type_str =
@@ -279,7 +270,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
}
auto [idx_args, idx_arr] = make_index_args(idx_type_str, nidx);
kernel_source += fmt::format(
kernel_source << fmt::format(
scatter_kernels,
type_to_name(out) + idx_type_name + "_" + op_name,
out_type_str,
@@ -289,9 +280,8 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
idx_args,
idx_arr,
upd_contig,
nwork,
large ? "size_t" : "uint");
return kernel_source;
nwork);
return kernel_source.str();
});
auto& compute_encoder = d.get_command_encoder(s.index);
@@ -299,7 +289,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
size_t nthreads = upd.size();
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder->setComputePipelineState(kernel);
// Set all the buffers
compute_encoder.set_input_array(upd, 1);
@@ -333,14 +323,14 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
// Need placeholders so Metal doesn't compalain
int shape_ = 0;
size_t stride_ = 0;
compute_encoder.set_bytes(shape_, 3);
compute_encoder.set_bytes(stride_, 4);
compute_encoder->setBytes(&shape_, sizeof(int), 3);
compute_encoder->setBytes(&stride_, sizeof(size_t), 4);
} else {
compute_encoder.set_vector_bytes(upd.shape(), 3);
compute_encoder.set_vector_bytes(upd.strides(), 4);
set_vector_bytes(compute_encoder, upd.shape(), 3);
set_vector_bytes(compute_encoder, upd.strides(), 4);
}
compute_encoder.set_bytes(upd_ndim, 5);
compute_encoder.set_bytes(upd_size, 6);
compute_encoder->setBytes(&upd_ndim, sizeof(size_t), 5);
compute_encoder->setBytes(&upd_size, sizeof(size_t), 6);
// Set output info
size_t out_ndim = out.ndim();
@@ -348,14 +338,14 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
// Need placeholders so Metal doesn't compalain
int shape_ = 0;
size_t stride_ = 0;
compute_encoder.set_bytes(shape_, 7);
compute_encoder.set_bytes(stride_, 8);
compute_encoder->setBytes(&shape_, sizeof(int), 7);
compute_encoder->setBytes(&stride_, sizeof(size_t), 8);
} else {
compute_encoder.set_vector_bytes(out.shape(), 7);
compute_encoder.set_vector_bytes(out.strides(), 8);
set_vector_bytes(compute_encoder, out.shape(), 7);
set_vector_bytes(compute_encoder, out.strides(), 8);
}
compute_encoder.set_bytes(out_ndim, 9);
compute_encoder.set_vector_bytes(axes_, 10);
compute_encoder->setBytes(&out_ndim, sizeof(size_t), 9);
compute_encoder->setBytes(axes_.data(), axes_.size() * sizeof(int), 10);
// Set index info
if (idx_ndim == 0) {
@@ -365,11 +355,11 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
idx_strides.push_back(0);
idx_contigs.push_back(false);
}
compute_encoder.set_vector_bytes(idx_shapes, 11);
compute_encoder.set_vector_bytes(idx_strides, 12);
compute_encoder.set_vector_bytes(idx_contigs, 13);
compute_encoder.set_bytes(idx_ndim, 14);
compute_encoder.set_bytes(idx_size, 15);
set_vector_bytes(compute_encoder, idx_shapes, 11);
set_vector_bytes(compute_encoder, idx_strides, 12);
set_vector_bytes(compute_encoder, idx_contigs, 13);
compute_encoder->setBytes(&idx_ndim, sizeof(int), 14);
compute_encoder->setBytes(&idx_size, sizeof(size_t), 15);
// Set index buffers
for (int i = 0; i < nidx; ++i) {
@@ -385,7 +375,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
throw std::runtime_error("[Scatter::eval_gpu] Invalid number of threads");
}
MTL::Size group_dims = get_block_dims(upd_size, grid_y, 1);
compute_encoder.dispatch_threads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
} // namespace mlx::core

View File

@@ -1,7 +1,7 @@
// Copyright © 2023-2024 Apple Inc.
constexpr std::string_view gather_kernels = R"(
[[kernel]] void gather{0}_{3}_{6}_{7}(
[[kernel]] void gather{0}_{3}_{6}(
const device {1}* src [[buffer(0)]],
device {1}* out [[buffer(1)]],
const constant int* src_shape [[buffer(2)]],
@@ -19,7 +19,7 @@ constexpr std::string_view gather_kernels = R"(
Indices<{2}, {3}> idxs{{
{{ {5} }}, idx_shapes, idx_strides, idx_contigs, idx_ndim}};
return gather_impl<{1}, {2}, {3}, {6}, {7}>(
return gather_impl<{1}, {2}, {3}, {6}>(
src,
out,
src_shape,
@@ -34,7 +34,7 @@ constexpr std::string_view gather_kernels = R"(
)";
constexpr std::string_view scatter_kernels = R"(
[[kernel]] void scatter{0}_{4}_updc_{7}_nwork{8}_{9}(
[[kernel]] void scatter{0}_{4}_updc_{7}_nwork{8}(
const device {1}* updates [[buffer(1)]],
device mlx_atomic<{1}>* out [[buffer(2)]],
const constant int* upd_shape [[buffer(3)]],
@@ -54,7 +54,7 @@ constexpr std::string_view scatter_kernels = R"(
uint2 gid [[thread_position_in_grid]]) {{
Indices<{2}, {4}> idxs{{ {{ {6} }}, idx_shapes, idx_strides, idx_contigs, idx_ndim}};
return scatter_impl<{1}, {2}, {3}, {4}, {7}, {8}, {9}>(
return scatter_impl<{1}, {2}, {3}, {4}, {7}, {8}>(
updates,
out,
upd_shape,

View File

@@ -1,4 +1,5 @@
// Copyright © 2024 Apple Inc.
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/metal/jit/arange.h"
#include "mlx/backend/metal/jit/gemv_masked.h"
@@ -45,27 +46,25 @@ MTL::ComputePipelineState* get_unary_kernel(
auto lib = d.get_library(lib_name, [&]() {
auto in_t = get_type_string(in_type);
auto out_t = get_type_string(out_type);
std::string kernel_source = metal::utils();
concatenate(kernel_source, metal::unary_ops(), metal::unary());
kernel_source +=
get_template_definition("v_" + lib_name, "unary_v", in_t, out_t, op);
kernel_source +=
get_template_definition("v2_" + lib_name, "unary_v2", in_t, out_t, op);
kernel_source += get_template_definition(
"gn1_" + lib_name, "unary_g", in_t, out_t, op, 1, "uint");
kernel_source += get_template_definition(
"gn4large_" + lib_name, "unary_g", in_t, out_t, op, 4);
return kernel_source;
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::unary_ops() << metal::unary();
kernel_source << get_template_definition(
"v_" + lib_name, "unary_v", in_t, out_t, op);
kernel_source << get_template_definition(
"v2_" + lib_name, "unary_v2", in_t, out_t, op);
kernel_source << get_template_definition(
"gn4_" + lib_name, "unary_g", in_t, out_t, op, 4);
return kernel_source.str();
});
return d.get_kernel(kernel_name, lib);
}
void append_binary_kernels(
void add_binary_kernels(
const std::string lib_name,
Dtype in_type,
Dtype out_type,
const std::string op,
std::string& kernel_source) {
std::ostringstream& kernel_source) {
const std::array<std::pair<std::string, std::string>, 10> kernel_types = {{
{"ss", "binary_ss"},
{"vs", "binary_vs"},
@@ -75,24 +74,26 @@ void append_binary_kernels(
{"sv2", "binary_sv2"},
{"vv2", "binary_vv2"},
{"g1", "binary_g_nd1"},
{"g2large", "binary_g_nd2"},
{"g3large", "binary_g_nd3"},
{"g2", "binary_g_nd2"},
{"g3", "binary_g_nd3"},
}};
auto in_t = get_type_string(in_type);
auto out_t = get_type_string(out_type);
for (auto& [name, func] : kernel_types) {
kernel_source +=
get_template_definition(name + "_" + lib_name, func, in_t, out_t, op);
std::string template_def;
template_def = get_template_definition(
name + "_" + lib_name,
func,
get_type_string(in_type),
get_type_string(out_type),
op);
kernel_source << template_def;
}
kernel_source += get_template_definition(
"g2_" + lib_name, "binary_g_nd2", in_t, out_t, op, "uint");
kernel_source += get_template_definition(
"g3_" + lib_name, "binary_g_nd3", in_t, out_t, op, "uint");
kernel_source += get_template_definition(
"gn2_" + lib_name, "binary_g", in_t, out_t, op, 2, "uint");
kernel_source += get_template_definition(
"gn4large_" + lib_name, "binary_g", in_t, out_t, op, 4);
kernel_source << get_template_definition(
"gn4_" + lib_name,
"binary_g",
get_type_string(in_type),
get_type_string(out_type),
op,
4);
}
MTL::ComputePipelineState* get_binary_kernel(
@@ -103,11 +104,10 @@ MTL::ComputePipelineState* get_binary_kernel(
const std::string op) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name, [&]() {
std::string kernel_source;
kernel_source = metal::utils();
concatenate(kernel_source, metal::binary_ops(), metal::binary());
append_binary_kernels(lib_name, in_type, out_type, op, kernel_source);
return kernel_source;
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::binary_ops() << metal::binary();
add_binary_kernels(lib_name, in_type, out_type, op, kernel_source);
return kernel_source.str();
});
return d.get_kernel(kernel_name, lib);
}
@@ -120,10 +120,11 @@ MTL::ComputePipelineState* get_binary_two_kernel(
const std::string op) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name, [&]() {
std::string kernel_source = metal::utils();
concatenate(kernel_source, metal::binary_ops(), metal::binary_two());
append_binary_kernels(lib_name, in_type, out_type, op, kernel_source);
return kernel_source;
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::binary_ops()
<< metal::binary_two();
add_binary_kernels(lib_name, in_type, out_type, op, kernel_source);
return kernel_source.str();
});
return d.get_kernel(kernel_name, lib);
}
@@ -135,29 +136,24 @@ MTL::ComputePipelineState* get_ternary_kernel(
const std::string op) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name, [&]() {
auto t_str = get_type_string(type);
std::string kernel_source = metal::utils();
concatenate(kernel_source, metal::ternary_ops(), metal::ternary());
std::ostringstream kernel_source;
const std::array<std::pair<std::string, std::string>, 5> kernel_types = {{
{"v", "ternary_v"},
{"v2", "ternary_v2"},
{"g1", "ternary_g_nd1"},
{"g2large", "ternary_g_nd2"},
{"g3large", "ternary_g_nd3"},
{"g2", "ternary_g_nd2"},
{"g3", "ternary_g_nd3"},
}};
kernel_source << metal::utils() << metal::ternary_ops() << metal::ternary();
for (auto& [name, func] : kernel_types) {
kernel_source +=
get_template_definition(name + "_" + lib_name, func, t_str, op);
std::string template_def;
template_def = get_template_definition(
name + "_" + lib_name, func, get_type_string(type), op);
kernel_source << template_def;
}
kernel_source += get_template_definition(
"g2_" + lib_name, "ternary_g_nd2", t_str, op, "uint");
kernel_source += get_template_definition(
"g3_" + lib_name, "ternary_g_nd3", t_str, op, "uint");
kernel_source += get_template_definition(
"gn2_" + lib_name, "ternary_g", t_str, op, 2, "uint");
kernel_source += get_template_definition(
"gn4large_" + lib_name, "ternary_g", t_str, op, 4);
return kernel_source;
kernel_source << get_template_definition(
"gn4_" + lib_name, "ternary_g", get_type_string(type), op, 4);
return kernel_source.str();
});
return d.get_kernel(kernel_name, lib);
}
@@ -169,43 +165,31 @@ MTL::ComputePipelineState* get_copy_kernel(
const array& out) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name, [&]() {
std::string kernel_source = metal::utils();
kernel_source += metal::copy();
std::ostringstream kernel_source;
auto in_type = get_type_string(in.dtype());
auto out_type = get_type_string(out.dtype());
kernel_source +=
get_template_definition("s_" + lib_name, "copy_s", in_type, out_type);
kernel_source +=
get_template_definition("v_" + lib_name, "copy_v", in_type, out_type);
kernel_source += get_template_definition(
"g1_" + lib_name, "copy_g_nd1", in_type, out_type);
kernel_source += get_template_definition(
"g2_" + lib_name, "copy_g_nd2", in_type, out_type, "int");
kernel_source += get_template_definition(
"g3_" + lib_name, "copy_g_nd3", in_type, out_type, "int");
kernel_source += get_template_definition(
"gn2_" + lib_name, "copy_g", in_type, out_type, 2, "int");
kernel_source += get_template_definition(
"gg1_" + lib_name, "copy_gg_nd1", in_type, out_type);
kernel_source += get_template_definition(
"gg2_" + lib_name, "copy_gg_nd2", in_type, out_type, "int");
kernel_source += get_template_definition(
"gg3_" + lib_name, "copy_gg_nd3", in_type, out_type, "int");
kernel_source += get_template_definition(
"ggn2_" + lib_name, "copy_gg", in_type, out_type, 2, "int");
kernel_source += get_template_definition(
"g2large_" + lib_name, "copy_g_nd2", in_type, out_type);
kernel_source += get_template_definition(
"g3large_" + lib_name, "copy_g_nd3", in_type, out_type);
kernel_source += get_template_definition(
"gn4large_" + lib_name, "copy_g", in_type, out_type, 4);
kernel_source += get_template_definition(
"gg2large_" + lib_name, "copy_gg_nd2", in_type, out_type);
kernel_source += get_template_definition(
"gg3large_" + lib_name, "copy_gg_nd3", in_type, out_type);
kernel_source += get_template_definition(
"ggn4large_" + lib_name, "copy_gg", in_type, out_type, 4);
return kernel_source;
kernel_source << metal::utils() << metal::copy()
<< get_template_definition(
"s_" + lib_name, "copy_s", in_type, out_type)
<< get_template_definition(
"v_" + lib_name, "copy_v", in_type, out_type)
<< get_template_definition(
"g1_" + lib_name, "copy_g_nd1", in_type, out_type)
<< get_template_definition(
"g2_" + lib_name, "copy_g_nd2", in_type, out_type)
<< get_template_definition(
"g3_" + lib_name, "copy_g_nd3", in_type, out_type)
<< get_template_definition(
"gn4_" + lib_name, "copy_g", in_type, out_type, 4)
<< get_template_definition(
"gg1_" + lib_name, "copy_gg_nd1", in_type, out_type)
<< get_template_definition(
"gg2_" + lib_name, "copy_gg_nd2", in_type, out_type)
<< get_template_definition(
"gg3_" + lib_name, "copy_gg_nd3", in_type, out_type)
<< get_template_definition(
"ggn4_" + lib_name, "copy_gg", in_type, out_type, 4);
return kernel_source.str();
});
return d.get_kernel(kernel_name, lib);
}
@@ -337,17 +321,17 @@ MTL::ComputePipelineState* get_reduce_init_kernel(
const std::string& kernel_name,
const std::string& func_name,
const std::string& op_name,
const Dtype& out_type) {
const array& out) {
auto lib = d.get_library(kernel_name, [&]() {
std::ostringstream kernel_source;
std::string op_type = op_name;
op_type[0] = std::toupper(op_name[0]);
auto out_t = get_type_string(out_type);
std::string op = op_type + "<" + out_t + ">";
std::string kernel_source = metal::utils();
kernel_source += metal::reduce_utils();
kernel_source += metal::reduce();
kernel_source += get_template_definition(kernel_name, func_name, out_t, op);
return kernel_source;
auto out_type = get_type_string(out.dtype());
std::string op = op_type + "<" + out_type + ">";
kernel_source << metal::utils() << metal::reduce_utils() << metal::reduce();
kernel_source << get_template_definition(
kernel_name, func_name, out_type, op);
return kernel_source.str();
});
return d.get_kernel(kernel_name, lib);
}
@@ -357,31 +341,30 @@ MTL::ComputePipelineState* get_reduce_kernel(
const std::string& kernel_name,
const std::string& func_name,
const std::string& op_name,
const Dtype& in_type,
const Dtype& out_type,
const std::string& idx_t,
const array& in,
const array& out,
int ndim /* = -1 */,
int bm /* = -1 */,
int bn /* = -1 */) {
auto lib = d.get_library(kernel_name, [&]() {
std::string op_type = op_name;
op_type[0] = std::toupper(op_name[0]);
auto in_t = get_type_string(in_type);
auto out_t = get_type_string(out_type);
std::string op = op_type + "<" + out_t + ">";
std::string kernel_source = metal::utils();
concatenate(kernel_source, metal::reduce_utils(), metal::reduce());
std::ostringstream kernel_source;
auto in_type = get_type_string(in.dtype());
auto out_type = get_type_string(out.dtype());
std::string op = op_type + "<" + out_type + ">";
kernel_source << metal::utils() << metal::reduce_utils() << metal::reduce();
if (bm >= 0) {
kernel_source += get_template_definition(
kernel_name, func_name, in_t, out_t, op, idx_t, ndim, bm, bn);
kernel_source << get_template_definition(
kernel_name, func_name, in_type, out_type, op, ndim, bm, bn);
} else if (ndim >= 0) {
kernel_source += get_template_definition(
kernel_name, func_name, in_t, out_t, op, idx_t, ndim);
kernel_source << get_template_definition(
kernel_name, func_name, in_type, out_type, op, ndim);
} else {
kernel_source += get_template_definition(
kernel_name, func_name, in_t, out_t, op, idx_t);
kernel_source << get_template_definition(
kernel_name, func_name, in_type, out_type, op);
}
return kernel_source;
return kernel_source.str();
});
auto st = d.get_kernel(kernel_name, lib);
return st;

View File

@@ -81,16 +81,15 @@ MTL::ComputePipelineState* get_reduce_init_kernel(
const std::string& kernel_name,
const std::string& func_name,
const std::string& op_name,
const Dtype& out_type);
const array& out);
MTL::ComputePipelineState* get_reduce_kernel(
metal::Device& d,
const std::string& kernel_name,
const std::string& func_name,
const std::string& op_name,
const Dtype& in_type,
const Dtype& out_type,
const std::string& idx_t,
const array& in,
const array& out,
int ndim = -1,
int bm = -1,
int bn = -1);

View File

@@ -1,27 +1,13 @@
set(BASE_HEADERS
metal_3_1/bf16.h
metal_3_0/bf16.h
bf16_math.h
complex.h
defines.h
expm1f.h
utils.h)
set(BASE_HEADERS bf16.h bf16_math.h complex.h defines.h expm1f.h utils.h)
function(build_kernel_base TARGET SRCFILE DEPS)
set(METAL_FLAGS -Wall -Wextra -fno-fast-math)
if(MLX_METAL_DEBUG)
set(METAL_FLAGS ${METAL_FLAGS} -gline-tables-only -frecord-sources)
endif()
if(MLX_METAL_VERSION GREATER_EQUAL 310)
set(VERSION_INCLUDES
${PROJECT_SOURCE_DIR}/mlx/backend/metal/kernels/metal_3_1)
else()
set(VERSION_INCLUDES
${PROJECT_SOURCE_DIR}/mlx/backend/metal/kernels/metal_3_0)
endif()
add_custom_command(
COMMAND xcrun -sdk macosx metal ${METAL_FLAGS} -c ${SRCFILE}
-I${PROJECT_SOURCE_DIR} -I${VERSION_INCLUDES} -o ${TARGET}.air
-I${PROJECT_SOURCE_DIR} -o ${TARGET}.air
DEPENDS ${SRCFILE} ${DEPS} ${BASE_HEADERS}
OUTPUT ${TARGET}.air
COMMENT "Building ${TARGET}.air"
@@ -44,7 +30,9 @@ build_kernel(layer_norm)
build_kernel(random)
build_kernel(rms_norm)
build_kernel(rope)
build_kernel(scaled_dot_product_attention sdpa_vector.h)
build_kernel(
scaled_dot_product_attention scaled_dot_product_attention_params.h
sdpa_vector.h steel/defines.h steel/gemm/transforms.h steel/utils.h)
set(STEEL_HEADERS
steel/defines.h
@@ -66,24 +54,6 @@ set(STEEL_HEADERS
steel/utils/type_traits.h
steel/utils/integral_constant.h)
set(STEEL_ATTN_HEADERS
steel/defines.h
steel/utils.h
steel/gemm/gemm.h
steel/gemm/mma.h
steel/gemm/loader.h
steel/gemm/transforms.h
steel/utils/type_traits.h
steel/utils/integral_constant.h
steel/attn/attn.h
steel/attn/loader.h
steel/attn/mma.h
steel/attn/params.h
steel/attn/transforms.h
steel/attn/kernels/steel_attention.h)
build_kernel(steel/attn/kernels/steel_attention ${STEEL_ATTN_HEADERS})
if(NOT MLX_METAL_JIT)
build_kernel(arange arange.h)
build_kernel(binary binary.h binary_ops.h)

View File

@@ -1,7 +1,7 @@
// Copyright © 2023-2024 Apple Inc.
// clang-format off
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/arange.h"
#define instantiate_arange(tname, type) \

View File

@@ -6,6 +6,12 @@
using namespace metal;
#if (MLX_METAL_VERSION >= 310) || (__METAL_VERSION__ >= 310)
typedef bfloat bfloat16_t;
#else
/////////////////////////////////////////////////////////////////////////////
// Helpers
/////////////////////////////////////////////////////////////////////////////
@@ -305,10 +311,7 @@ METAL_FUNC bool isnan(_MLX_BFloat16 x) {
} // namespace metal
#pragma METAL internals : disable
inline uint16_t bfloat16_to_uint16(const bfloat16_t x) {
return x.bits_;
}
inline bfloat16_t uint16_to_bfloat16(const uint16_t x) {
return _MLX_BFloat16(x, _MLX_BFloat16::bits_to_bfloat());
}
#endif
#include "mlx/backend/metal/kernels/bf16_math.h"

View File

@@ -2,6 +2,8 @@
#pragma once
#include "mlx/backend/metal/kernels/bf16.h"
///////////////////////////////////////////////////////////////////////////////
// Metal math for bfloat16
///////////////////////////////////////////////////////////////////////////////
@@ -367,6 +369,18 @@ instantiate_metal_math_funcs(
return static_cast<otype>(__metal_simd_xor(static_cast<ctype>(data))); \
}
#if (MLX_METAL_VERSION >= 310) || (__METAL_VERSION__ >= 310)
#define bfloat16_to_uint16(x) as_type<uint16_t>(x)
#define uint16_to_bfloat16(x) as_type<bfloat16_t>(x)
#else
#define bfloat16_to_uint16(x) x.bits_
#define uint16_to_bfloat16(x) _MLX_BFloat16(x, _MLX_BFloat16::bits_to_bfloat())
#endif
namespace metal {
instantiate_metal_simd_comm_funcs(

View File

@@ -69,7 +69,7 @@ template <typename T, typename U, typename Op>
c[offset] = Op()(a[offset], b[offset]);
}
template <typename T, typename U, typename Op, typename IdxT = size_t>
template <typename T, typename U, typename Op>
[[kernel]] void binary_g_nd1(
device const T* a,
device const T* b,
@@ -77,12 +77,12 @@ template <typename T, typename U, typename Op, typename IdxT = size_t>
constant const size_t& a_stride,
constant const size_t& b_stride,
uint index [[thread_position_in_grid]]) {
auto a_idx = elem_to_loc_1<size_t, IdxT>(index, a_stride);
auto b_idx = elem_to_loc_1<size_t, IdxT>(index, b_stride);
auto a_idx = elem_to_loc_1(index, a_stride);
auto b_idx = elem_to_loc_1(index, b_stride);
c[index] = Op()(a[a_idx], b[b_idx]);
}
template <typename T, typename U, typename Op, typename IdxT = size_t>
template <typename T, typename U, typename Op>
[[kernel]] void binary_g_nd2(
device const T* a,
device const T* b,
@@ -91,13 +91,13 @@ template <typename T, typename U, typename Op, typename IdxT = size_t>
constant const size_t b_strides[2],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto a_idx = elem_to_loc_2<size_t, IdxT>(index, a_strides);
auto b_idx = elem_to_loc_2<size_t, IdxT>(index, b_strides);
IdxT out_idx = index.x + IdxT(grid_dim.x) * index.y;
auto a_idx = elem_to_loc_2(index, a_strides);
auto b_idx = elem_to_loc_2(index, b_strides);
size_t out_idx = index.x + size_t(grid_dim.x) * index.y;
c[out_idx] = Op()(a[a_idx], b[b_idx]);
}
template <typename T, typename U, typename Op, typename IdxT = size_t>
template <typename T, typename U, typename Op>
[[kernel]] void binary_g_nd3(
device const T* a,
device const T* b,
@@ -106,18 +106,14 @@ template <typename T, typename U, typename Op, typename IdxT = size_t>
constant const size_t b_strides[3],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto a_idx = elem_to_loc_3<size_t, IdxT>(index, a_strides);
auto b_idx = elem_to_loc_3<size_t, IdxT>(index, b_strides);
IdxT out_idx = index.x + grid_dim.x * (index.y + IdxT(grid_dim.y) * index.z);
auto a_idx = elem_to_loc_3(index, a_strides);
auto b_idx = elem_to_loc_3(index, b_strides);
size_t out_idx =
index.x + grid_dim.x * (index.y + size_t(grid_dim.y) * index.z);
c[out_idx] = Op()(a[a_idx], b[b_idx]);
}
template <
typename T,
typename U,
typename Op,
int N = 1,
typename IdxT = size_t>
template <typename T, typename U, typename Op, int N = 1>
[[kernel]] void binary_g(
device const T* a,
device const T* b,
@@ -128,12 +124,13 @@ template <
constant const int& ndim,
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto idx = elem_to_loc_2_nd<size_t, IdxT>(
auto idx = elem_to_loc_2_nd(
{N * index.x, index.y, index.z}, shape, a_strides, b_strides, ndim);
auto xshape = shape[ndim - 1];
IdxT out_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z);
IdxT a_xstride = a_strides[ndim - 1];
IdxT b_xstride = b_strides[ndim - 1];
size_t out_idx =
N * index.x + xshape * (index.y + size_t(grid_dim.y) * index.z);
auto a_xstride = a_strides[ndim - 1];
auto b_xstride = b_strides[ndim - 1];
for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
c[out_idx++] = Op()(a[idx.x], b[idx.y]);
idx.x += a_xstride;

View File

@@ -9,22 +9,18 @@
#include "mlx/backend/metal/kernels/binary_ops.h"
#include "mlx/backend/metal/kernels/binary.h"
#define instantiate_binary_all(op, tname, itype, otype) \
instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op) \
instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op) \
instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op) \
instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op) \
instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \
instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \
instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \
instantiate_kernel("gn2_" #op #tname, binary_g, itype, otype, op, 2, uint) \
instantiate_kernel("gn4large_" #op #tname, binary_g, itype, otype, op, 4) \
instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op, uint) \
instantiate_kernel("g1large_" #op #tname, binary_g_nd1, itype, otype, op) \
instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op, uint) \
instantiate_kernel("g2large_" #op #tname, binary_g_nd2, itype, otype, op) \
instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op, uint) \
instantiate_kernel("g3large_" #op #tname, binary_g_nd3, itype, otype, op)
#define instantiate_binary_all(op, tname, itype, otype) \
instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op) \
instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op) \
instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op) \
instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op) \
instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \
instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \
instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \
instantiate_kernel("gn4_" #op #tname, binary_g, itype, otype, op, 4) \
instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op) \
instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op) \
instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op) \
#define instantiate_binary_integer(op) \
instantiate_binary_all(op, uint8, uint8_t, uint8_t) \

View File

@@ -90,7 +90,7 @@ template <typename T, typename U, typename Op>
d[offset] = out[1];
}
template <typename T, typename U, typename Op, typename IdxT = size_t>
template <typename T, typename U, typename Op>
[[kernel]] void binary_g_nd1(
device const T* a,
device const T* b,
@@ -99,14 +99,14 @@ template <typename T, typename U, typename Op, typename IdxT = size_t>
constant const size_t& a_stride,
constant const size_t& b_stride,
uint index [[thread_position_in_grid]]) {
auto a_idx = elem_to_loc_1<size_t, IdxT>(index, a_stride);
auto b_idx = elem_to_loc_1<size_t, IdxT>(index, b_stride);
auto a_idx = elem_to_loc_1(index, a_stride);
auto b_idx = elem_to_loc_1(index, b_stride);
auto out = Op()(a[a_idx], b[b_idx]);
c[index] = out[0];
d[index] = out[1];
}
template <typename T, typename U, typename Op, typename IdxT = size_t>
template <typename T, typename U, typename Op>
[[kernel]] void binary_g_nd2(
device const T* a,
device const T* b,
@@ -116,15 +116,15 @@ template <typename T, typename U, typename Op, typename IdxT = size_t>
constant const size_t b_strides[2],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto a_idx = elem_to_loc_2<size_t, IdxT>(index, a_strides);
auto b_idx = elem_to_loc_2<size_t, IdxT>(index, b_strides);
IdxT out_idx = index.x + IdxT(grid_dim.x) * index.y;
auto a_idx = elem_to_loc_2(index, a_strides);
auto b_idx = elem_to_loc_2(index, b_strides);
size_t out_idx = index.x + size_t(grid_dim.x) * index.y;
auto out = Op()(a[a_idx], b[b_idx]);
c[out_idx] = out[0];
d[out_idx] = out[1];
}
template <typename T, typename U, typename Op, typename IdxT = size_t>
template <typename T, typename U, typename Op>
[[kernel]] void binary_g_nd3(
device const T* a,
device const T* b,
@@ -134,20 +134,16 @@ template <typename T, typename U, typename Op, typename IdxT = size_t>
constant const size_t b_strides[3],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto a_idx = elem_to_loc_3<size_t, IdxT>(index, a_strides);
auto b_idx = elem_to_loc_3<size_t, IdxT>(index, b_strides);
IdxT out_idx = index.x + grid_dim.x * (index.y + IdxT(grid_dim.y) * index.z);
auto a_idx = elem_to_loc_3(index, a_strides);
auto b_idx = elem_to_loc_3(index, b_strides);
size_t out_idx =
index.x + grid_dim.x * (index.y + size_t(grid_dim.y) * index.z);
auto out = Op()(a[a_idx], b[b_idx]);
c[out_idx] = out[0];
d[out_idx] = out[1];
}
template <
typename T,
typename U,
typename Op,
int N = 1,
typename IdxT = size_t>
template <typename T, typename U, typename Op, int N = 1>
[[kernel]] void binary_g(
device const T* a,
device const T* b,
@@ -159,12 +155,13 @@ template <
constant const int& ndim,
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto idx = elem_to_loc_2_nd<size_t, IdxT>(
auto idx = elem_to_loc_2_nd(
{N * index.x, index.y, index.z}, shape, a_strides, b_strides, ndim);
auto xshape = shape[ndim - 1];
IdxT out_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z);
IdxT a_xstride = a_strides[ndim - 1];
IdxT b_xstride = b_strides[ndim - 1];
size_t out_idx =
N * index.x + xshape * (index.y + size_t(grid_dim.y) * index.z);
auto a_xstride = a_strides[ndim - 1];
auto b_xstride = b_strides[ndim - 1];
for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
auto out = Op()(a[idx.x], b[idx.y]);
c[out_idx] = out[0];

View File

@@ -7,22 +7,18 @@
#include "mlx/backend/metal/kernels/binary_ops.h"
#include "mlx/backend/metal/kernels/binary_two.h"
#define instantiate_binary_all(op, tname, itype, otype) \
instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op) \
instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op) \
instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op) \
instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op) \
instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \
instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \
instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \
instantiate_kernel("gn2_" #op #tname, binary_g, itype, otype, op, 2, uint) \
instantiate_kernel("gn4large_" #op #tname, binary_g, itype, otype, op, 4) \
instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op, uint) \
instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op, uint) \
instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op, uint) \
instantiate_kernel("g1large_" #op #tname, binary_g_nd1, itype, otype, op) \
instantiate_kernel("g2large_" #op #tname, binary_g_nd2, itype, otype, op) \
instantiate_kernel("g3large_" #op #tname, binary_g_nd3, itype, otype, op)
#define instantiate_binary_all(op, tname, itype, otype) \
instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op) \
instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op) \
instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op) \
instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op) \
instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \
instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \
instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \
instantiate_kernel("gn4_" #op #tname, binary_g, itype, otype, op, 4) \
instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op) \
instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op) \
instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op) \
#define instantiate_binary_float(op) \
instantiate_binary_all(op, float16, half, half) \

View File

@@ -4,8 +4,8 @@
#include <metal_simdgroup_matrix>
#include <metal_stdlib>
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/steel/conv/params.h"
#include "mlx/backend/metal/kernels/utils.h"
#define MLX_MTL_CONST static constant constexpr const

View File

@@ -36,42 +36,42 @@ template <typename T, typename U>
dst[offset] = static_cast<U>(src[offset]);
}
template <typename T, typename U, typename IdxT = int64_t>
template <typename T, typename U>
[[kernel]] void copy_g_nd1(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int64_t& src_stride [[buffer(3)]],
uint index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc_1<int64_t, IdxT>(index, src_stride);
auto src_idx = elem_to_loc_1(index, src_stride);
dst[index] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U, typename IdxT = int64_t>
template <typename T, typename U>
[[kernel]] void copy_g_nd2(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int64_t* src_strides [[buffer(3)]],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto src_idx = elem_to_loc_2<int64_t, IdxT>(index, src_strides);
IdxT dst_idx = index.x + IdxT(grid_dim.x) * index.y;
auto src_idx = elem_to_loc_2(index, src_strides);
int64_t dst_idx = index.x + (int64_t)grid_dim.x * index.y;
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U, typename IdxT = int64_t>
template <typename T, typename U>
[[kernel]] void copy_g_nd3(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int64_t* src_strides [[buffer(3)]],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto src_idx = elem_to_loc_3<int64_t, IdxT>(index, src_strides);
IdxT dst_idx =
index.x + IdxT(grid_dim.x) * (index.y + IdxT(grid_dim.y) * index.z);
auto src_idx = elem_to_loc_3(index, src_strides);
int64_t dst_idx =
index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U, int N = 1, typename IdxT = int64_t>
template <typename T, typename U, int N = 1>
[[kernel]] void copy_g(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
@@ -80,16 +80,17 @@ template <typename T, typename U, int N = 1, typename IdxT = int64_t>
constant const int& ndim [[buffer(5)]],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto src_idx = elem_to_loc<int64_t, IdxT>(
auto src_idx = elem_to_loc(
{N * index.x, index.y, index.z}, src_shape, src_strides, ndim);
if (N == 1) {
IdxT dst_idx =
index.x + grid_dim.x * (index.y + IdxT(grid_dim.y) * index.z);
int64_t dst_idx =
index.x + grid_dim.x * (index.y + int64_t(grid_dim.y) * index.z);
dst[dst_idx] = static_cast<U>(src[src_idx]);
return;
}
auto xshape = src_shape[ndim - 1];
IdxT dst_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z);
int64_t dst_idx =
N * index.x + xshape * (index.y + int64_t(grid_dim.y) * index.z);
auto src_xstride = src_strides[ndim - 1];
for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
dst[dst_idx + i] = static_cast<U>(src[src_idx]);
@@ -97,43 +98,43 @@ template <typename T, typename U, int N = 1, typename IdxT = int64_t>
}
}
template <typename T, typename U, typename IdxT = int64_t>
template <typename T, typename U>
[[kernel]] void copy_gg_nd1(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int64_t& src_stride [[buffer(3)]],
constant const int64_t& dst_stride [[buffer(4)]],
uint index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc_1<int64_t, IdxT>(index, src_stride);
auto dst_idx = elem_to_loc_1<int64_t, IdxT>(index, dst_stride);
auto src_idx = elem_to_loc_1(index, src_stride);
auto dst_idx = elem_to_loc_1(index, dst_stride);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U, typename IdxT = int64_t>
template <typename T, typename U>
[[kernel]] void copy_gg_nd2(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int64_t* src_strides [[buffer(3)]],
constant const int64_t* dst_strides [[buffer(4)]],
uint2 index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc_2<int64_t, IdxT>(index, src_strides);
auto dst_idx = elem_to_loc_2<int64_t, IdxT>(index, dst_strides);
auto src_idx = elem_to_loc_2(index, src_strides);
auto dst_idx = elem_to_loc_2(index, dst_strides);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U, typename IdxT = int64_t>
template <typename T, typename U>
[[kernel]] void copy_gg_nd3(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int64_t* src_strides [[buffer(3)]],
constant const int64_t* dst_strides [[buffer(4)]],
uint3 index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc_3<int64_t, IdxT>(index, src_strides);
auto dst_idx = elem_to_loc_3<int64_t, IdxT>(index, dst_strides);
auto src_idx = elem_to_loc_3(index, src_strides);
auto dst_idx = elem_to_loc_3(index, dst_strides);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U, int N = 1, typename IdxT = int64_t>
template <typename T, typename U, int N = 1>
[[kernel]] void copy_gg(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
@@ -142,7 +143,7 @@ template <typename T, typename U, int N = 1, typename IdxT = int64_t>
constant const int64_t* dst_strides [[buffer(4)]],
constant const int& ndim [[buffer(5)]],
uint3 index [[thread_position_in_grid]]) {
auto idx = elem_to_loc_2_nd<int64_t, IdxT>(
auto idx = elem_to_loc_2_nd(
{N * index.x, index.y, index.z},
src_shape,
src_strides,
@@ -152,8 +153,8 @@ template <typename T, typename U, int N = 1, typename IdxT = int64_t>
dst[idx.y] = static_cast<U>(src[idx.x]);
return;
}
IdxT src_xstride = src_strides[ndim - 1];
IdxT dst_xstride = dst_strides[ndim - 1];
auto src_xstride = src_strides[ndim - 1];
auto dst_xstride = dst_strides[ndim - 1];
auto xshape = src_shape[ndim - 1];
for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
dst[idx.y] = static_cast<U>(src[idx.x]);

View File

@@ -2,29 +2,22 @@
// clang-format off
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/copy.h"
#define instantiate_copy_all(tname, itype, otype) \
instantiate_kernel("s_copy" #tname, copy_s, itype, otype) \
instantiate_kernel("v_copy" #tname, copy_v, itype, otype) \
instantiate_kernel("s2_copy" #tname, copy_s2, itype, otype) \
instantiate_kernel("v2_copy" #tname, copy_v2, itype, otype) \
instantiate_kernel("g1_copy" #tname, copy_g_nd1, itype, otype, int) \
instantiate_kernel("g2_copy" #tname, copy_g_nd2, itype, otype, int) \
instantiate_kernel("g3_copy" #tname, copy_g_nd3, itype, otype, int) \
instantiate_kernel("gg1_copy" #tname, copy_gg_nd1, itype, otype, int) \
instantiate_kernel("gg2_copy" #tname, copy_gg_nd2, itype, otype, int) \
instantiate_kernel("gg3_copy" #tname, copy_gg_nd3, itype, otype, int) \
instantiate_kernel("gn2_copy" #tname, copy_g, itype, otype, 2, int) \
instantiate_kernel("ggn2_copy" #tname, copy_gg, itype, otype, 2, int) \
instantiate_kernel("g1large_copy" #tname, copy_g_nd1, itype, otype) \
instantiate_kernel("g2large_copy" #tname, copy_g_nd2, itype, otype) \
instantiate_kernel("g3large_copy" #tname, copy_g_nd3, itype, otype) \
instantiate_kernel("gg1large_copy" #tname, copy_gg_nd1, itype, otype) \
instantiate_kernel("gg2large_copy" #tname, copy_gg_nd2, itype, otype) \
instantiate_kernel("gg3large_copy" #tname, copy_gg_nd3, itype, otype) \
instantiate_kernel("gn4large_copy" #tname, copy_g, itype, otype, 4) \
instantiate_kernel("ggn4large_copy" #tname, copy_gg, itype, otype, 4)
#define instantiate_copy_all(tname, itype, otype) \
instantiate_kernel("s_copy" #tname, copy_s, itype, otype) \
instantiate_kernel("v_copy" #tname, copy_v, itype, otype) \
instantiate_kernel("s2_copy" #tname, copy_s2, itype, otype) \
instantiate_kernel("v2_copy" #tname, copy_v2, itype, otype) \
instantiate_kernel("g1_copy" #tname, copy_g_nd1, itype, otype) \
instantiate_kernel("g2_copy" #tname, copy_g_nd2, itype, otype) \
instantiate_kernel("g3_copy" #tname, copy_g_nd3, itype, otype) \
instantiate_kernel("gg1_copy" #tname, copy_gg_nd1, itype, otype) \
instantiate_kernel("gg2_copy" #tname, copy_gg_nd2, itype, otype) \
instantiate_kernel("gg3_copy" #tname, copy_gg_nd3, itype, otype) \
instantiate_kernel("gn4_copy" #tname, copy_g, itype, otype, 4) \
instantiate_kernel("ggn4_copy" #tname, copy_gg, itype, otype, 4)
#define instantiate_copy_itype(itname, itype) \
instantiate_copy_all(itname ##bool_, itype, bool) \

View File

@@ -4,7 +4,7 @@
#include "mlx/backend/metal/kernels/indexing.h"
template <typename T, typename IdxT, int NIDX, int IDX_NDIM, typename LocT>
template <typename T, typename IdxT, int NIDX, int IDX_NDIM>
METAL_FUNC void gather_impl(
const device T* src [[buffer(0)]],
device T* out [[buffer(1)]],
@@ -16,18 +16,18 @@ METAL_FUNC void gather_impl(
const thread Indices<IdxT, NIDX>& indices,
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
LocT src_idx = 0;
size_t src_idx = 0;
for (int i = 0; i < NIDX; ++i) {
LocT idx_loc;
size_t idx_loc;
if (IDX_NDIM == 0) {
idx_loc = 0;
} else if (IDX_NDIM == 1) {
idx_loc = index.x * static_cast<LocT>(indices.strides[indices.ndim * i]);
idx_loc = index.x * indices.strides[indices.ndim * i];
} else {
idx_loc = index.x * static_cast<LocT>(indices.strides[indices.ndim * i]);
idx_loc = index.x * indices.strides[indices.ndim * i];
idx_loc += indices.row_contiguous[i]
? index.y
: elem_to_loc<size_t, LocT>(
: elem_to_loc(
index.y,
&indices.shapes[indices.ndim * i + 1],
&indices.strides[indices.ndim * i + 1],
@@ -35,17 +35,17 @@ METAL_FUNC void gather_impl(
}
auto ax = axes[i];
auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], src_shape[ax]);
src_idx += static_cast<LocT>(idx_val) * static_cast<LocT>(src_strides[ax]);
src_idx += idx_val * src_strides[ax];
}
auto src_offset =
elem_to_loc<size_t, LocT>(index.z, slice_sizes, src_strides, src_ndim);
auto src_offset = elem_to_loc(index.z, slice_sizes, src_strides, src_ndim);
LocT out_idx = index.z;
size_t out_idx = index.z;
if (IDX_NDIM == 1) {
out_idx += static_cast<LocT>(grid_dim.z) * index.x;
out_idx += static_cast<size_t>(grid_dim.z) * index.x;
} else if (IDX_NDIM >= 2) {
out_idx += grid_dim.z * (index.x * static_cast<LocT>(grid_dim.y) + index.y);
out_idx +=
grid_dim.z * (index.x * static_cast<size_t>(grid_dim.y) + index.y);
}
out[out_idx] = src[src_offset + src_idx];
}

View File

@@ -3,6 +3,8 @@
#include <metal_simdgroup>
#include <metal_stdlib>
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/steel/utils.h"
@@ -910,4 +912,4 @@ template <
// clang-format off
instantiate_gemv_t_bs_blocks(float32, float);
instantiate_gemv_t_bs_blocks(float16, half);
instantiate_gemv_t_bs_blocks(bfloat16, bfloat16_t); // clang-format on
instantiate_gemv_t_bs_blocks(bfloat16, bfloat16_t); // clang-format on

View File

@@ -4,6 +4,8 @@
#include <metal_simdgroup>
#include <metal_stdlib>
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/gemv_masked.h"

View File

@@ -14,7 +14,7 @@ struct Indices {
};
template <typename IdxT>
METAL_FUNC size_t offset_neg_idx(IdxT idx, int size) {
METAL_FUNC size_t offset_neg_idx(IdxT idx, size_t size) {
if (is_unsigned_v<IdxT>) {
return idx;
} else {

View File

@@ -1,16 +0,0 @@
// Copyright © 2024 Apple Inc.
// clang-format off
#define jit_if #if
#define jit_else #else
#define jit_endif #endif
jit_if (__METAL_VERSION__ >= 310)
#include "mlx/backend/metal/kernels/metal_3_1/bf16.h"
jit_else
#include "mlx/backend/metal/kernels/metal_3_0/bf16.h"
jit_endif // clang-format on

View File

@@ -3,6 +3,8 @@
#include <metal_common>
#include <metal_simdgroup>
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/utils.h"
using namespace metal;

View File

@@ -1,16 +0,0 @@
// Copyright © 2023 Apple Inc.
#pragma once
#include <metal_stdlib>
using namespace metal;
typedef bfloat bfloat16_t;
inline uint16_t bfloat16_to_uint16(const bfloat16_t x) {
return as_type<uint16_t>(x);
}
inline bfloat16_t uint16_to_bfloat16(const uint16_t x) {
return as_type<bfloat16_t>(x);
}

View File

@@ -13,8 +13,8 @@ MLX_MTL_CONST int QUAD_SIZE = 4;
template <typename T, typename U, int values_per_thread, int bits>
inline U load_vector(const device T* x, thread U* x_thread) {
static_assert(
bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
"Template undefined for bits not in {2, 3, 4, 6, 8}");
bits == 2 || bits == 4 || bits == 8,
"Template undefined for bits not in {2, 4, 8}");
U sum = 0;
@@ -28,21 +28,6 @@ inline U load_vector(const device T* x, thread U* x_thread) {
}
}
else if (bits == 3) {
for (int i = 0; i < values_per_thread; i += 8) {
sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] +
x[i + 6] + x[i + 7];
x_thread[i] = x[i];
x_thread[i + 1] = x[i + 1] / 8.0f;
x_thread[i + 2] = x[i + 2] / 64.0f;
x_thread[i + 3] = x[i + 3] / 2.0f;
x_thread[i + 4] = x[i + 4] / 16.0f;
x_thread[i + 5] = x[i + 5] / 128.0f;
x_thread[i + 6] = x[i + 6] / 4.0f;
x_thread[i + 7] = x[i + 7] / 32.0f;
}
}
else if (bits == 4) {
for (int i = 0; i < values_per_thread; i += 4) {
sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
@@ -53,16 +38,6 @@ inline U load_vector(const device T* x, thread U* x_thread) {
}
}
else if (bits == 6) {
for (int i = 0; i < values_per_thread; i += 4) {
sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
x_thread[i] = x[i];
x_thread[i + 1] = x[i + 1] / 64.0f;
x_thread[i + 2] = x[i + 2] / 16.0f;
x_thread[i + 3] = x[i + 3] / 4.0f;
}
}
else if (bits == 8) {
for (int i = 0; i < values_per_thread; i++) {
sum += x[i];
@@ -76,8 +51,8 @@ inline U load_vector(const device T* x, thread U* x_thread) {
template <typename T, typename U, int values_per_thread, int bits>
inline U load_vector_safe(const device T* x, thread U* x_thread, int N) {
static_assert(
bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
"Template undefined for bits not in {2, 3, 4, 6, 8}");
bits == 2 || bits == 4 || bits == 8,
"Template undefined for bits not in {2, 4, 8}");
U sum = 0;
@@ -89,21 +64,8 @@ inline U load_vector_safe(const device T* x, thread U* x_thread, int N) {
x_thread[i + 2] = x[i + 2] / 16.0f;
x_thread[i + 3] = x[i + 3] / 64.0f;
}
}
else if (bits == 3) {
for (int i = 0; i < N; i += 8) {
sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] +
x[i + 6] + x[i + 7];
x_thread[i] = x[i];
x_thread[i + 1] = x[i + 1] / 8.0f;
x_thread[i + 2] = x[i + 2] / 64.0f;
x_thread[i + 3] = x[i + 3] / 2.0f;
x_thread[i + 4] = x[i + 4] / 16.0f;
x_thread[i + 5] = x[i + 5] / 128.0f;
x_thread[i + 6] = x[i + 6] / 4.0f;
x_thread[i + 7] = x[i + 7] / 32.0f;
for (int i = N; i < values_per_thread; i++) {
x_thread[i] = 0;
}
}
@@ -115,15 +77,8 @@ inline U load_vector_safe(const device T* x, thread U* x_thread, int N) {
x_thread[i + 2] = x[i + 2] / 256.0f;
x_thread[i + 3] = x[i + 3] / 4096.0f;
}
}
else if (bits == 6) {
for (int i = 0; i < N; i += 4) {
sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
x_thread[i] = x[i];
x_thread[i + 1] = x[i + 1] / 64.0f;
x_thread[i + 2] = x[i + 2] / 16.0f;
x_thread[i + 3] = x[i + 3] / 4.0f;
for (int i = N; i < values_per_thread; i++) {
x_thread[i] = 0;
}
}
@@ -132,10 +87,9 @@ inline U load_vector_safe(const device T* x, thread U* x_thread, int N) {
sum += x[i];
x_thread[i] = x[i];
}
}
for (int i = N; i < values_per_thread; i++) {
x_thread[i] = 0;
for (int i = N; i < values_per_thread; i++) {
x_thread[i] = 0;
}
}
return sum;
@@ -149,8 +103,8 @@ inline U qdot(
U bias,
U sum) {
static_assert(
bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
"Template undefined for bits not in {2, 3, 4, 6, 8}");
bits == 2 || bits == 4 || bits == 8,
"Template undefined for bits not in {2, 4, 8}");
U accum = 0;
@@ -164,26 +118,6 @@ inline U qdot(
}
}
else if (bits == 3) {
for (int i = 0; i < (values_per_thread / 8); i++) {
x_thread += 8 * i;
w += 3 * i;
accum += (w[0] & 0x07) * x_thread[0];
accum += (w[0] & 0x38) * x_thread[1];
accum += (w[0] & 0xc0) * x_thread[2];
accum += (w[1] & 0x01) * (x_thread[2] * 256.0f);
accum += (w[1] & 0x0e) * x_thread[3];
accum += (w[1] & 0x70) * x_thread[4];
accum += (w[1] & 0x80) * x_thread[5];
accum += (w[2] & 0x03) * (x_thread[5] * 256.0f);
accum += (w[2] & 0x1c) * x_thread[6];
accum += (w[2] & 0xe0) * x_thread[7];
}
}
else if (bits == 4) {
const device uint16_t* ws = (const device uint16_t*)w;
for (int i = 0; i < (values_per_thread / 4); i++) {
@@ -195,23 +129,6 @@ inline U qdot(
}
}
else if (bits == 6) {
for (int i = 0; i < (values_per_thread / 4); i++) {
x_thread += 4 * i;
w += 3 * i;
accum += (w[0] & 0x3f) * x_thread[0];
accum += (w[0] & 0xc0) * x_thread[1];
accum += (w[1] & 0x0f) * (x_thread[1] * 256.0f);
accum += (w[1] & 0xf0) * x_thread[2];
accum += (w[2] & 0x03) * (x_thread[2] * 256.0f);
accum += (w[2] & 0xfc) * x_thread[3];
}
}
else if (bits == 8) {
for (int i = 0; i < values_per_thread; i++) {
accum += x_thread[i] * w[i];
@@ -230,8 +147,8 @@ inline U qdot_safe(
U sum,
int N) {
static_assert(
bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
"Template undefined for bits not in {2, 3, 4, 6, 8}");
bits == 2 || bits == 4 || bits == 8,
"Template undefined for bits not in {2, 4, 8}");
U accum = 0;
@@ -245,26 +162,6 @@ inline U qdot_safe(
}
}
else if (bits == 3) {
for (int i = 0; i < (N / 8); i++) {
x_thread += 8 * i;
w += 3 * i;
accum += (w[0] & 0x07) * x_thread[0];
accum += (w[0] & 0x38) * x_thread[1];
accum += (w[0] & 0xc0) * x_thread[2];
accum += (w[1] & 0x01) * (x_thread[2] * 256.0f);
accum += (w[1] & 0x0e) * x_thread[3];
accum += (w[1] & 0x70) * x_thread[4];
accum += (w[1] & 0x80) * x_thread[5];
accum += (w[2] & 0x03) * (x_thread[5] * 256.0f);
accum += (w[2] & 0x1c) * x_thread[6];
accum += (w[2] & 0xe0) * x_thread[7];
}
}
else if (bits == 4) {
const device uint16_t* ws = (const device uint16_t*)w;
for (int i = 0; i < (N / 4); i++) {
@@ -276,23 +173,6 @@ inline U qdot_safe(
}
}
else if (bits == 6) {
for (int i = 0; i < (N / 4); i++) {
x_thread += 4 * i;
w += 3 * i;
accum += (w[0] & 0x3f) * x_thread[0];
accum += (w[0] & 0xc0) * x_thread[1];
accum += (w[1] & 0x0f) * (x_thread[1] * 256.0f);
accum += (w[1] & 0xf0) * x_thread[2];
accum += (w[2] & 0x03) * (x_thread[2] * 256.0f);
accum += (w[2] & 0xfc) * x_thread[3];
}
}
else if (bits == 8) {
for (int i = 0; i < N; i++) {
accum += x_thread[i] * w[i];
@@ -306,8 +186,8 @@ template <typename U, int values_per_thread, int bits>
inline void
qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) {
static_assert(
bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
"Template undefined for bits not in {2, 3, 4, 6, 8}");
bits == 2 || bits == 4 || bits == 8,
"Template undefined for bits not in {2, 4, 8}");
if (bits == 2) {
U s[4] = {scale, scale / 4.0f, scale / 16.0f, scale / 64.0f};
@@ -319,45 +199,12 @@ qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) {
}
}
else if (bits == 3) {
for (int i = 0; i < (values_per_thread / 8); i++) {
uint8_t w0 = w[3 * i];
uint8_t w1 = w[3 * i + 1];
uint8_t w2 = w[3 * i + 2];
result[8 * i] += x * ((w0 & 0x7) * scale + bias);
result[8 * i + 1] += x * (((w0 & 0x38) >> 3) * scale + bias);
result[8 * i + 2] +=
x * ((((w0 & 0xc0) >> 6) + ((w1 & 0x1) << 2)) * scale + bias);
result[8 * i + 3] += x * (((w1 & 0xe) >> 1) * scale + bias);
result[8 * i + 4] += x * (((w1 & 0x70) >> 4) * scale + bias);
result[8 * i + 5] +=
x * ((((w1 & 0x80) >> 7) + ((w2 & 0x3) << 1)) * scale + bias);
result[8 * i + 6] += x * (((w2 & 0x1c) >> 2) * scale + bias);
result[8 * i + 7] += x * (((w2 & 0xe0) >> 5) * scale + bias);
}
}
else if (bits == 4) {
U s[2] = {scale, scale / 16.0f};
for (int i = 0; i < (values_per_thread / 2); i++) {
result[2 * i] += x * (s[0] * (w[i] & 0x0f) + bias);
result[2 * i + 1] += x * (s[1] * (w[i] & 0xf0) + bias);
}
} else if (bits == 6) {
for (int i = 0; i < (values_per_thread / 4); i++) {
uint8_t w0 = w[3 * i];
uint8_t w1 = w[3 * i + 1];
uint8_t w2 = w[3 * i + 2];
result[4 * i] += x * ((w0 & 0x3f) * scale + bias);
result[4 * i + 1] +=
x * ((((w0 >> 6) & 0x03) + ((w1 & 0x0f) << 2)) * scale + bias);
result[4 * i + 2] +=
x * ((((w1 >> 4) & 0x0f) + ((w2 & 0x03) << 4)) * scale + bias);
result[4 * i + 3] += x * (((w2 >> 2) & 0x3f) * scale + bias);
}
}
else if (bits == 8) {
@@ -371,8 +218,8 @@ template <typename U, int N, int bits>
inline void
dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) {
static_assert(
bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
"Template undefined for bits not in {2, 3, 4, 6, 8}");
bits == 2 || bits == 4 || bits == 8,
"Template undefined for bits not in {2, 4, 8}");
if (bits == 2) {
U s[4] = {
@@ -388,22 +235,6 @@ dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) {
}
}
else if (bits == 3) {
for (int i = 0; i < (N / 8); i++) {
w_local += 8 * i;
w += 3 * i;
w_local[0] = (w[0] & 0x7) * scale + bias;
w_local[1] = ((w[0] & 0x38) >> 3) * scale + bias;
w_local[2] = (((w[0] & 0xc0) >> 6) + ((w[1] & 0x1) << 2)) * scale + bias;
w_local[3] = ((w[1] & 0xe) >> 1) * scale + bias;
w_local[4] = ((w[1] & 0x70) >> 4) * scale + bias;
w_local[5] = (((w[1] & 0x80) >> 7) + ((w[2] & 0x3) << 1)) * scale + bias;
w_local[6] = ((w[2] & 0x1c) >> 2) * scale + bias;
w_local[7] = ((w[2] & 0xe0) >> 5) * scale + bias;
}
}
else if (bits == 4) {
U s[2] = {scale, scale / static_cast<U>(16.0f)};
for (int i = 0; i < (N / 2); i++) {
@@ -412,18 +243,6 @@ dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) {
}
}
else if (bits == 6) {
for (int i = 0; i < (N / 4); i++) {
w_local += 4 * i;
w += 3 * i;
w_local[0] = (w[0] & 0x3f) * scale + bias;
w_local[1] = (((w[0] >> 6) & 0x03) + ((w[1] & 0x0f) << 2)) * scale + bias;
w_local[2] = (((w[1] >> 4) & 0x0f) + ((w[2] & 0x03) << 4)) * scale + bias;
w_local[3] = ((w[2] >> 2) & 0x3f) * scale + bias;
}
}
else if (bits == 8) {
for (int i = 0; i < N; i++) {
w_local[i] = scale * w[i] + bias;
@@ -448,11 +267,10 @@ struct QuantizedBlockLoader {
group_size % BCOLS == 0,
"The group size should be divisible by the columns");
static_assert(
bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
"Template undefined for bits not in {2, 3, 4, 6, 8}");
bits == 2 || bits == 4 || bits == 8,
"Template undefined for bits not in {2, 4, 8}");
MLX_MTL_CONST short pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
MLX_MTL_CONST short bytes_per_pack = (bits == 3 || bits == 6) ? 3 : 1;
MLX_MTL_CONST short pack_factor = 32 / bits;
MLX_MTL_CONST short BCOLS_PACKED = BCOLS / pack_factor;
MLX_MTL_CONST short n_reads =
(BCOLS_PACKED * BROWS < tgp_size) ? 1 : (BCOLS_PACKED * BROWS) / tgp_size;
@@ -468,12 +286,12 @@ struct QuantizedBlockLoader {
const short bj;
threadgroup T* dst;
const device uint8_t* src;
const device uint32_t* src;
const device T* scales;
const device T* biases;
QuantizedBlockLoader(
const device uint8_t* src_,
const device uint32_t* src_,
const device T* scales_,
const device T* biases_,
const int src_ld_,
@@ -482,16 +300,14 @@ struct QuantizedBlockLoader {
ushort simd_lane_id [[thread_index_in_simdgroup]])
: src_ld(src_ld_),
tile_stride(
reduction_dim ? BCOLS_PACKED * bytes_per_pack
: BROWS * src_ld * bytes_per_pack / pack_factor),
reduction_dim ? BCOLS_PACKED : BROWS * src_ld / pack_factor),
group_step_cnt(0),
group_stride(BROWS * src_ld / group_size),
thread_idx(simd_group_id * 32 + simd_lane_id),
bi(n_reads * thread_idx / BCOLS_PACKED),
bj((n_reads * thread_idx) % BCOLS_PACKED),
dst(dst_ + bi * dst_ld + bj * pack_factor),
src(src_ + bi * src_ld * bytes_per_pack / pack_factor +
bj * bytes_per_pack),
src(src_ + bi * src_ld / pack_factor + bj),
scales(scales_ + bi * src_ld / group_size),
biases(biases_ + bi * src_ld / group_size) {}
@@ -504,7 +320,7 @@ struct QuantizedBlockLoader {
T bias = *biases;
for (int i = 0; i < n_reads; i++) {
dequantize<T, pack_factor, bits>(
src + i * bytes_per_pack, scale, bias, dst + i * pack_factor);
(device uint8_t*)(src + i), scale, bias, dst + i * pack_factor);
}
}
@@ -531,10 +347,7 @@ struct QuantizedBlockLoader {
T bias = *biases;
for (int i = 0; i < n_reads; i++) {
dequantize<T, pack_factor, bits>(
(device uint8_t*)(src + i * bytes_per_pack),
scale,
bias,
dst + i * pack_factor);
(device uint8_t*)(src + i), scale, bias, dst + i * pack_factor);
}
}
@@ -597,7 +410,8 @@ METAL_FUNC void qmv_quad_impl(
U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
for (int row = 0; row < results_per_quadgroup; row++) {
auto wl = (const device uint8_t*)(w + row * in_vec_size_w * quads_per_simd);
const device uint8_t* wl =
(const device uint8_t*)(w + row * in_vec_size_w * quads_per_simd);
const device T* sl = scales + row * in_vec_size_g * quads_per_simd;
const device T* bl = biases + row * in_vec_size_g * quads_per_simd;
@@ -628,30 +442,25 @@ METAL_FUNC void qmv_fast_impl(
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
constexpr int packs_per_thread = bits == 2 ? 1 : 2;
constexpr int packs_per_thread = bits > 2 ? 2 : 1;
constexpr int num_simdgroups = 2;
constexpr int results_per_simdgroup = 4;
constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
constexpr int bytes_per_pack = power_of_2_bits ? 4 : 3;
constexpr int pack_factor = 32 / bits;
constexpr int values_per_thread = pack_factor * packs_per_thread;
constexpr int block_size = values_per_thread * SIMD_SIZE;
constexpr int scale_step_per_thread = group_size / values_per_thread;
const device uint8_t* ws = (const device uint8_t*)w;
typedef float U;
thread U x_thread[values_per_thread];
thread U result[results_per_simdgroup] = {0};
// Adjust positions
const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor;
const int in_vec_size_w = in_vec_size / pack_factor;
const int in_vec_size_g = in_vec_size / group_size;
const int out_row = tid.x * (num_simdgroups * results_per_simdgroup) +
simd_gid * results_per_simdgroup;
ws += out_row * in_vec_size_w + simd_lid * packs_per_thread * bytes_per_pack;
w += out_row * in_vec_size_w + simd_lid * packs_per_thread;
scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
biases += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
x += tid.y * in_vec_size + simd_lid * values_per_thread;
@@ -661,7 +470,8 @@ METAL_FUNC void qmv_fast_impl(
U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
for (int row = 0; row < results_per_simdgroup; row++) {
auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
const device uint8_t* wl =
(const device uint8_t*)(w + row * in_vec_size_w);
const device T* sl = scales + row * in_vec_size_g;
const device T* bl = biases + row * in_vec_size_g;
@@ -670,7 +480,7 @@ METAL_FUNC void qmv_fast_impl(
result[row] += qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
}
ws += block_size * bytes_per_pack / pack_factor;
w += block_size / pack_factor;
scales += block_size / group_size;
biases += block_size / group_size;
x += block_size;
@@ -696,25 +506,21 @@ METAL_FUNC void qmv_impl(
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
constexpr int num_simdgroups = 2;
constexpr int results_per_simdgroup = 4;
constexpr int packs_per_thread = 1;
constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
constexpr int bytes_per_pack = power_of_2_bits ? 4 : 3;
constexpr int pack_factor = 32 / bits;
constexpr int values_per_thread = pack_factor * packs_per_thread;
constexpr int block_size = values_per_thread * SIMD_SIZE;
constexpr int scale_step_per_thread = group_size / values_per_thread;
const device uint8_t* ws = (const device uint8_t*)w;
typedef float U;
thread U x_thread[values_per_thread];
thread U result[results_per_simdgroup] = {0};
// Adjust positions
const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor;
const int in_vec_size_w = in_vec_size / pack_factor;
const int in_vec_size_g = in_vec_size / group_size;
const int out_row = tid.x * (num_simdgroups * results_per_simdgroup) +
simd_gid * results_per_simdgroup;
@@ -727,8 +533,7 @@ METAL_FUNC void qmv_impl(
// In this case we need to properly guard all our reads because there isn't
// even 1 tile in the matrix
if (out_vec_size < (num_simdgroups * results_per_simdgroup)) {
ws +=
out_row * in_vec_size_w + simd_lid * packs_per_thread * bytes_per_pack;
w += out_row * in_vec_size_w + simd_lid * packs_per_thread;
scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
biases += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
x += tid.y * in_vec_size + simd_lid * values_per_thread;
@@ -739,7 +544,8 @@ METAL_FUNC void qmv_impl(
U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
for (int row = 0; out_row + row < out_vec_size; row++) {
auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
const device uint8_t* wl =
(const device uint8_t*)(w + row * in_vec_size_w);
const device T* sl = scales + row * in_vec_size_g;
const device T* bl = biases + row * in_vec_size_g;
@@ -749,7 +555,7 @@ METAL_FUNC void qmv_impl(
qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
}
ws += block_size * bytes_per_pack / pack_factor;
w += block_size / pack_factor;
scales += block_size / group_size;
biases += block_size / group_size;
x += block_size;
@@ -758,20 +564,18 @@ METAL_FUNC void qmv_impl(
static_cast<int>(in_vec_size - k - simd_lid * values_per_thread),
0,
values_per_thread);
if (remaining > 0) {
U sum = load_vector_safe<T, U, values_per_thread, bits>(
x, x_thread, remaining);
U sum =
load_vector_safe<T, U, values_per_thread, bits>(x, x_thread, remaining);
for (int row = 0; out_row + row < out_vec_size; row++) {
auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
const device T* sl = scales + row * in_vec_size_g;
const device T* bl = biases + row * in_vec_size_g;
for (int row = 0; out_row + row < out_vec_size; row++) {
const device uint8_t* wl =
(const device uint8_t*)(w + row * in_vec_size_w);
const device T* sl = scales + row * in_vec_size_g;
const device T* bl = biases + row * in_vec_size_g;
U s = sl[0];
U b = bl[0];
result[row] +=
qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
}
U s = sl[0];
U b = bl[0];
result[row] += qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
}
for (int row = 0; out_row + row < out_vec_size; row++) {
@@ -784,8 +588,7 @@ METAL_FUNC void qmv_impl(
// In this case the last tile is moved back to redo some output values
else {
ws += used_out_row * in_vec_size_w +
simd_lid * packs_per_thread * bytes_per_pack;
w += used_out_row * in_vec_size_w + simd_lid * packs_per_thread;
scales += used_out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
biases += used_out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
x += tid.y * in_vec_size + simd_lid * values_per_thread;
@@ -796,7 +599,8 @@ METAL_FUNC void qmv_impl(
U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
for (int row = 0; row < results_per_simdgroup; row++) {
auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
const device uint8_t* wl =
(const device uint8_t*)(w + row * in_vec_size_w);
const device T* sl = scales + row * in_vec_size_g;
const device T* bl = biases + row * in_vec_size_g;
@@ -806,7 +610,7 @@ METAL_FUNC void qmv_impl(
qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
}
ws += block_size * bytes_per_pack / pack_factor;
w += block_size / pack_factor;
scales += block_size / group_size;
biases += block_size / group_size;
x += block_size;
@@ -815,21 +619,21 @@ METAL_FUNC void qmv_impl(
static_cast<int>(in_vec_size - k - simd_lid * values_per_thread),
0,
values_per_thread);
if (remaining > 0) {
U sum = load_vector_safe<T, U, values_per_thread, bits>(
x, x_thread, remaining);
U sum =
load_vector_safe<T, U, values_per_thread, bits>(x, x_thread, remaining);
for (int row = 0; row < results_per_simdgroup; row++) {
auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
const device T* sl = scales + row * in_vec_size_g;
const device T* bl = biases + row * in_vec_size_g;
for (int row = 0; row < results_per_simdgroup; row++) {
const device uint8_t* wl =
(const device uint8_t*)(w + row * in_vec_size_w);
const device T* sl = scales + row * in_vec_size_g;
const device T* bl = biases + row * in_vec_size_g;
U s = sl[0];
U b = bl[0];
result[row] += qdot_safe<U, values_per_thread, bits>(
wl, x_thread, s, b, sum, remaining);
}
U s = sl[0];
U b = bl[0];
result[row] += qdot_safe<U, values_per_thread, bits>(
wl, x_thread, s, b, sum, remaining);
}
for (int row = 0; row < results_per_simdgroup; row++) {
result[row] = simd_sum(result[row]);
if (simd_lid == 0) {
@@ -851,20 +655,14 @@ METAL_FUNC void qvm_impl(
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
constexpr int num_simdgroups = 2;
constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3;
constexpr int pack_factor = 32 / bits;
constexpr int tn = 32 / pack_factor;
constexpr int block_size = SIMD_SIZE;
using W_T =
typename ConditionalType<power_of_2_bits, uint32_t, uint8_t>::type;
const device W_T* ws = (const device W_T*)w;
constexpr int blocksize = SIMD_SIZE;
typedef float U;
typedef struct {
W_T wi[tn * bytes_per_pack];
uint32_t wi[tn];
} vec_w;
thread vec_w w_local;
@@ -874,10 +672,11 @@ METAL_FUNC void qvm_impl(
thread U x_local = 0;
// Adjust positions
const int out_vec_size_w = out_vec_size * bytes_per_pack / pack_factor;
const int out_vec_size_w = out_vec_size / pack_factor;
const int out_vec_size_g = out_vec_size / group_size;
int out_col = pack_factor * tn * (tid.x * num_simdgroups + simd_gid);
ws += out_col * bytes_per_pack / pack_factor + simd_lid * out_vec_size_w;
int out_col =
tid.x * (num_simdgroups * pack_factor * tn) + simd_gid * pack_factor * tn;
w += out_col / pack_factor + simd_lid * out_vec_size_w;
scales += out_col / group_size + simd_lid * out_vec_size_g;
biases += out_col / group_size + simd_lid * out_vec_size_g;
x += tid.y * in_vec_size + simd_lid;
@@ -887,42 +686,43 @@ METAL_FUNC void qvm_impl(
return;
}
// Loop over in_vec in blocks of block_size
int remaining = in_vec_size % block_size;
// Loop over in_vec in blocks of blocksize
int remaining = in_vec_size % blocksize;
if (remaining == 0) {
for (int i = 0; i < in_vec_size; i += block_size) {
for (int i = 0; i < in_vec_size; i += blocksize) {
x_local = *x;
scale = *scales;
bias = *biases;
w_local = *((device vec_w*)ws);
w_local = *((device vec_w*)w);
qouter<U, tn * pack_factor, bits>(
(thread uint8_t*)&w_local, x_local, scale, bias, result);
x += block_size;
scales += block_size * out_vec_size_g;
biases += block_size * out_vec_size_g;
ws += block_size * out_vec_size_w;
x += blocksize;
scales += blocksize * out_vec_size_g;
biases += blocksize * out_vec_size_g;
w += blocksize * out_vec_size_w;
}
} else {
for (int i = block_size; i < in_vec_size; i += block_size) {
for (int i = blocksize; i < in_vec_size; i += blocksize) {
x_local = *x;
scale = *scales;
bias = *biases;
w_local = *((device vec_w*)ws);
w_local = *((device vec_w*)w);
qouter<U, tn * pack_factor, bits>(
(thread uint8_t*)&w_local, x_local, scale, bias, result);
x += block_size;
scales += block_size * out_vec_size_g;
biases += block_size * out_vec_size_g;
ws += block_size * out_vec_size_w;
x += blocksize;
scales += blocksize * out_vec_size_g;
biases += blocksize * out_vec_size_g;
w += blocksize * out_vec_size_w;
}
if (static_cast<int>(simd_lid) < remaining) {
x_local = *x;
scale = *scales;
bias = *biases;
w_local = *((device vec_w*)ws);
w_local = *((device vec_w*)w);
} else {
x_local = 0;
scale = 0;
@@ -977,9 +777,8 @@ METAL_FUNC void qmm_t_impl(
constexpr int WM = 2;
constexpr int WN = 2;
constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
constexpr int pack_factor = 32 / bits;
constexpr int BK_padded = (BK + 16 / sizeof(T));
constexpr int bytes_per_pack = (bits == 3 || bits == 6) ? 3 : 1;
// Instantiate the appropriate BlockMMA and Loader
using mma_t = mlx::steel::
@@ -997,15 +796,13 @@ METAL_FUNC void qmm_t_impl(
bits>;
// Set the block
const int K_w = K * bytes_per_pack / pack_factor;
const int K_w = K / pack_factor;
const int K_g = K / group_size;
const int y_row = tid.y * BM;
const int y_col = tid.x * BN;
auto wl = (const device uint8_t*)w;
x += y_row * K;
wl += y_col * K_w;
w += y_col * K_w;
scales += y_col * K_g;
biases += y_col * K_g;
y += y_row * N + y_col;
@@ -1014,7 +811,7 @@ METAL_FUNC void qmm_t_impl(
const short num_els = min(BM, M - y_row);
const short num_outs = min(BN, N - y_col);
loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid);
loader_w_t loader_w(wl, scales, biases, K, Ws, simd_gid, simd_lid);
loader_w_t loader_w(w, scales, biases, K, Ws, simd_gid, simd_lid);
mma_t mma_op(simd_gid, simd_lid);
if (num_els < BM) {
@@ -1056,7 +853,6 @@ METAL_FUNC void qmm_t_impl(
loader_x.load_unsafe();
loader_w.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_op.mma(Xs, Ws);
loader_x.next();
loader_w.next();
@@ -1102,11 +898,9 @@ METAL_FUNC void qmm_n_impl(
constexpr int WM = 2;
constexpr int WN = 2;
constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
constexpr int pack_factor = 32 / bits;
constexpr int BK_padded = (BK + 16 / sizeof(T));
constexpr int BN_padded = (BN + 16 / sizeof(T));
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3;
// Instantiate the appropriate BlockMMA and Loader
using mma_t = mlx::steel::
@@ -1123,13 +917,11 @@ METAL_FUNC void qmm_n_impl(
group_size,
bits>;
auto wl = (const device uint8_t*)w;
// Set the block
const int y_row = tid.y * BM;
const int y_col = tid.x * BN;
x += y_row * K;
wl += y_col * bytes_per_pack / pack_factor;
w += y_col / pack_factor;
scales += y_col / group_size;
biases += y_col / group_size;
y += y_row * N + y_col;
@@ -1137,7 +929,7 @@ METAL_FUNC void qmm_n_impl(
// Make the x loader and mma operation
const short num_els = min(BM, M - y_row);
loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid);
loader_w_t loader_w(wl, scales, biases, N, Ws, simd_gid, simd_lid);
loader_w_t loader_w(w, scales, biases, N, Ws, simd_gid, simd_lid);
mma_t mma_op(simd_gid, simd_lid);
if (num_els < BM) {
@@ -2009,14 +1801,13 @@ template <typename T, const int group_size, const int bits>
uint2 grid_dim [[threads_per_grid]]) {
constexpr T eps = T(1e-7);
constexpr int simd_size = 32;
constexpr int uint8_bits = 8;
constexpr T n_bins = (1 << bits) - 1;
constexpr int packs_per_int = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
constexpr int packs_per_int = uint8_bits / bits;
constexpr int values_per_reduce = group_size / simd_size;
constexpr int writes_per_reduce = packs_per_int / values_per_reduce;
constexpr int writes_per_pack =
writes_per_reduce > 1 ? 1 : values_per_reduce / packs_per_int;
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3;
static_assert(
group_size % simd_size == 0,
@@ -2024,9 +1815,7 @@ template <typename T, const int group_size, const int bits>
size_t offset = index.x + grid_dim.x * size_t(index.y);
size_t in_index = offset * values_per_reduce;
size_t out_index = power_of_2_bits
? offset * writes_per_pack
: offset * bytes_per_pack / writes_per_reduce;
size_t out_index = offset * writes_per_pack;
T w_thread[values_per_reduce];
T w_min = Limits<T>::max;
@@ -2059,9 +1848,7 @@ template <typename T, const int group_size, const int bits>
biases[gindex] = bias;
}
// We accumulate 3 bytes worth for 3/6 bit so we need a uint32_t
uint32_t output = 0;
uint8_t output = 0;
#pragma clang loop unroll(full)
for (int i = 0; i < values_per_reduce; i++) {
uint8_t val = min(round((w_thread[i] - bias) / scale), n_bins);
@@ -2077,23 +1864,47 @@ template <typename T, const int group_size, const int bits>
output = 0;
} else {
#pragma clang loop unroll(full)
for (int j = 1; j < writes_per_reduce; j++) {
uint8_t sval = simd_shuffle_down(val, j);
output += sval << (bits * (j * values_per_reduce + i));
for (int j = 0; j < writes_per_reduce - 1; j++) {
uint8_t sval = simd_shuffle_down(val, j + 1);
output += sval << (bits * (values_per_reduce + j + i));
}
}
}
if (bits == 3 || bits == 6) {
if (in_index % packs_per_int == 0 && out_index % bytes_per_pack == 0) {
out[out_index] = output & 0xff;
out[out_index + 1] = (output & 0xff00) >> 8;
out[out_index + 2] = (output & 0xff0000) >> 16;
}
} else {
if (writes_per_reduce > 0 && out_index % writes_per_reduce == 0) {
out[out_index / writes_per_reduce] = output;
if (writes_per_reduce > 0 && out_index % writes_per_reduce == 0) {
out[out_index / writes_per_reduce] = output;
}
}
template <typename T, const int group_size, const int bits>
[[kernel]] void affine_quantize_scales_biases(
const device T* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]],
device uint8_t* out [[buffer(3)]],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
constexpr int uint8_bits = 8;
constexpr int packs_per_int = uint8_bits / bits;
constexpr T n_bins = (1 << bits) - 1;
size_t offset = index.x + grid_dim.x * size_t(index.y);
size_t in_index = offset * packs_per_int;
size_t gindex = in_index / group_size;
T scale = scales[gindex];
T bias = biases[gindex];
uint8_t output = 0;
#pragma clang loop unroll(full)
for (int i = 0; i < packs_per_int; i++) {
uint8_t val = min(round((w[in_index + i] - bias) / scale), n_bins);
if (bits == 8) {
output = val;
} else {
output += val << (bits * i);
}
}
out[offset] = output;
}
template <typename T, const int group_size, const int bits>
@@ -2104,48 +1915,26 @@ template <typename T, const int group_size, const int bits>
device T* out [[buffer(3)]],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
constexpr int packs_per_int = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3;
constexpr int uint8_bits = 8;
constexpr int packs_per_int = uint8_bits / bits;
size_t offset = index.x + grid_dim.x * size_t(index.y);
size_t oindex = offset * packs_per_int;
size_t gindex = oindex / group_size;
T scale = scales[gindex];
T bias = biases[gindex];
uint val = w[offset];
out += oindex;
if (bits == 3) {
w += offset * bytes_per_pack;
out[0] = (w[0] & 0x7) * scale + bias;
out[1] = ((w[0] & 0x38) >> 3) * scale + bias;
out[2] = (((w[0] & 0xc0) >> 6) + ((w[1] & 0x1) << 2)) * scale + bias;
out[3] = ((w[1] & 0xe) >> 1) * scale + bias;
out[4] = ((w[1] & 0x70) >> 4) * scale + bias;
out[5] = (((w[1] & 0x80) >> 7) + ((w[2] & 0x3) << 1)) * scale + bias;
out[6] = ((w[2] & 0x1c) >> 2) * scale + bias;
out[7] = ((w[2] & 0xe0) >> 5) * scale + bias;
} else if (bits == 6) {
w += offset * bytes_per_pack;
out[0] = (w[0] & 0x3f) * scale + bias;
out[1] = (((w[0] >> 6) & 0x03) + ((w[1] & 0x0f) << 2)) * scale + bias;
out[2] = (((w[1] >> 4) & 0x0f) + ((w[2] & 0x03) << 4)) * scale + bias;
out[3] = ((w[2] >> 2) & 0x3f) * scale + bias;
} else {
uint val = w[offset];
#pragma clang loop unroll(full)
for (int i = 0; i < packs_per_int; i++) {
uint8_t d;
if (bits == 2) {
d = (val >> (bits * i)) & 0x03;
} else if (bits == 4) {
d = (val >> (bits * i)) & 0x0f;
} else if (bits == 8) {
d = val;
}
out[i] = scale * d + bias;
for (int i = 0; i < packs_per_int; i++) {
uint8_t d;
if (bits == 2) {
d = (val >> (bits * i)) & 0x03;
} else if (bits == 4) {
d = (val >> (bits * i)) & 0x0f;
} else if (bits == 8) {
d = val;
}
out[oindex + i] = scale * d + bias;
}
}

View File

@@ -72,6 +72,7 @@
#define instantiate_quantized_all_single(type, group_size, bits) \
instantiate_quantized(affine_quantize, type, group_size, bits) \
instantiate_quantized(affine_quantize_scales_biases, type, group_size, bits) \
instantiate_quantized(affine_dequantize, type, group_size, bits) \
instantiate_quantized(bs_qmv_fast, type, group_size, bits) \
instantiate_quantized(bs_qmv, type, group_size, bits) \
@@ -115,9 +116,7 @@
#define instantiate_quantized_all() \
instantiate_quantized_groups(2) \
instantiate_quantized_groups(3) \
instantiate_quantized_groups(4) \
instantiate_quantized_groups(6) \
instantiate_quantized_groups(8)
instantiate_quantized_all() // clang-format on

View File

@@ -10,156 +10,186 @@
#include "mlx/backend/metal/kernels/reduction/ops.h"
#include "mlx/backend/metal/kernels/reduce.h"
#define instantiate_init_reduce(name, tname, type, op) \
instantiate_kernel("init_reduce_" #name #tname, init_reduce, type, op<type>)
#define instantiate_reduce_helper_floats(inst_f, name, op) \
inst_f(name, float16, half, op) \
inst_f(name, float32, float, op) \
inst_f(name, bfloat16, bfloat16_t, op)
instantiate_init_reduce(and, bool_, bool, And)
instantiate_init_reduce(or, bool_, bool, Or)
#define instantiate_reduce_helper_uints(inst_f, name, op) \
inst_f(name, uint8, uint8_t, op) \
inst_f(name, uint16, uint16_t, op) \
inst_f(name, uint32, uint32_t, op)
#define instantiate_init_sum_prod(name, op) \
instantiate_init_reduce(name, int32, int32_t, op) \
instantiate_init_reduce(name, int64, int64_t, op) \
instantiate_init_reduce(name, float16, float16_t, op) \
instantiate_init_reduce(name, bfloat16, bfloat16_t, op) \
instantiate_init_reduce(name, float32, float, op) \
instantiate_init_reduce(name, complex64, complex64_t, op)
#define instantiate_reduce_helper_ints(inst_f, name, op) \
inst_f(name, int8, int8_t, op) \
inst_f(name, int16, int16_t, op) \
inst_f(name, int32, int32_t, op)
instantiate_init_sum_prod(sum, Sum)
instantiate_init_sum_prod(prod, Prod)
#define instantiate_reduce_helper_64b(inst_f, name, op) \
inst_f(name, int64, int64_t, op) \
inst_f(name, uint64, uint64_t, op) \
inst_f(name, complex64, complex64_t, op)
#define instantiate_init_min_max(name, op) \
instantiate_init_reduce(name, bool_, bool, op) \
instantiate_init_reduce(name, int8, int8_t, op) \
instantiate_init_reduce(name, int16, int16_t, op) \
instantiate_init_reduce(name, int32, int32_t, op) \
instantiate_init_reduce(name, int64, int64_t, op) \
instantiate_init_reduce(name, uint8, uint8_t, op) \
instantiate_init_reduce(name, uint16, uint16_t, op) \
instantiate_init_reduce(name, uint32, uint32_t, op) \
instantiate_init_reduce(name, uint64, uint64_t, op) \
instantiate_init_reduce(name, float16, float16_t, op) \
instantiate_init_reduce(name, bfloat16, bfloat16_t, op) \
instantiate_init_reduce(name, float32, float, op) \
instantiate_init_reduce(name, complex64, complex64_t, op)
#define instantiate_reduce_helper_types(inst_f, name, op) \
instantiate_reduce_helper_floats(inst_f, name, op) \
instantiate_reduce_helper_uints(inst_f, name, op) \
instantiate_reduce_helper_ints(inst_f, name, op)
instantiate_init_min_max(min, Min)
instantiate_init_min_max(max, Max)
#define instantiate_reduce_ops(inst_f, type_f) \
type_f(inst_f, sum, Sum) \
type_f(inst_f, prod, Prod) \
type_f(inst_f, min, Min) \
type_f(inst_f, max, Max)
// Special case for bool reductions
#define instantiate_reduce_from_types_helper( \
inst_f, name, tname, itype, otype, op) \
inst_f(name##tname, itype, otype, op)
#define instantiate_reduce_from_types(inst_f, name, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, name, bool_, bool, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, name, uint8, uint8_t, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, name, uint16, uint16_t, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, name, uint32, uint32_t, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, name, uint64, uint64_t, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, name, int8, int8_t, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, name, int16, int16_t, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, name, int32, int32_t, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, name, int64, int64_t, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, name, float16, half, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, \
name, \
float32, \
float, \
otype, \
op) \
instantiate_reduce_from_types_helper( \
inst_f, \
name, \
bfloat16, \
bfloat16_t, \
otype, \
op)
#define instantiate_init_reduce(name, otype, op) \
instantiate_kernel("init_reduce_" #name, \
init_reduce, \
otype, op)
#define instantiate_init_reduce_helper(name, tname, type, op) \
instantiate_init_reduce(name##tname, type, op<type>)
instantiate_reduce_ops(instantiate_init_reduce_helper, instantiate_reduce_helper_types)
instantiate_reduce_ops(instantiate_init_reduce_helper, instantiate_reduce_helper_64b)
instantiate_init_reduce(andbool_, bool, And<bool>)
instantiate_init_reduce(orbool_, bool, Or<bool>)
#define instantiate_all_reduce(name, itype, otype, op) \
instantiate_kernel("all_reduce_" #name, \
all_reduce, \
itype, otype, op)
#define instantiate_col_reduce_small(name, itype, otype, op, dim) \
instantiate_kernel("col_reduce_small_" #dim "_reduce_" #name, \
col_reduce_small, \
itype, otype, op, uint, dim) \
instantiate_kernel("col_reduce_longcolumn_" #dim "_reduce_" #name, \
col_reduce_longcolumn, \
itype, otype, op, uint, dim) \
instantiate_kernel("col_reduce_small_large_" #dim "_reduce_" #name, \
col_reduce_small, \
itype, otype, op, size_t, dim) \
instantiate_kernel("col_reduce_longcolumn_large_" #dim "_reduce_" #name, \
col_reduce_longcolumn, \
itype, otype, op, size_t, dim)
#define instantiate_same_all_reduce_helper(name, tname, type, op) \
instantiate_all_reduce(name##tname, type, type, op<type>)
#define instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, bm, bn) \
instantiate_kernel("col_reduce_looped_" #dim "_" #bm "_" #bn "_reduce_" #name, \
col_reduce_looped, \
itype, otype, op, uint, dim, bm, bn) \
instantiate_kernel("col_reduce_looped_large_" #dim "_" #bm "_" #bn "_reduce_" #name, \
col_reduce_looped, \
itype, otype, op, size_t, dim, bm, bn)
instantiate_reduce_ops(instantiate_same_all_reduce_helper, instantiate_reduce_helper_types)
instantiate_reduce_ops(instantiate_same_all_reduce_helper, instantiate_reduce_helper_64b)
#define instantiate_col_reduce_2pass_tile(name, itype, otype, op, dim, bm, bn) \
instantiate_kernel("col_reduce_2pass_" #dim "_" #bm "_" #bn "_reduce_" #name, \
col_reduce_2pass, \
itype, otype, op, uint, dim, bm, bn) \
instantiate_kernel("col_reduce_2pass_large_" #dim "_" #bm "_" #bn "_reduce_" #name, \
col_reduce_2pass, \
itype, otype, op, size_t, dim, bm, bn)
instantiate_reduce_from_types(instantiate_all_reduce, and, bool, And<bool>)
instantiate_reduce_from_types(instantiate_all_reduce, or, bool, Or<bool>)
// special case bool with larger output type
instantiate_all_reduce(sumbool_, bool, uint32_t, Sum<uint32_t>)
#define instantiate_col_reduce_small(name, itype, otype, op, dim) \
instantiate_kernel("col_reduce_small_" #dim "_reduce_" #name, \
col_reduce_small, \
itype, otype, op, dim) \
instantiate_kernel("col_reduce_longcolumn_" #dim "_reduce_" #name, \
col_reduce_longcolumn, \
itype, otype, op, dim)
#define instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, bm, bn) \
instantiate_kernel("col_reduce_looped_" #dim "_" #bm "_" #bn "_reduce_" #name, \
col_reduce_looped, \
itype, otype, op, dim, bm, bn)
#define instantiate_col_reduce_2pass_tile(name, itype, otype, op, dim, bm, bn) \
instantiate_kernel("col_reduce_2pass_" #dim "_" #bm "_" #bn "_reduce_" #name, \
col_reduce_2pass, \
itype, otype, op, dim, bm, bn)
#define instantiate_col_reduce_looped(name, itype, otype, op, dim) \
instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, 32, 32) \
instantiate_col_reduce_2pass_tile(name, itype, otype, op, dim, 32, 32)
#define instantiate_col_reduce_general(name, itype, otype, op) \
instantiate_col_reduce_small(name, itype, otype, op, 0) \
instantiate_col_reduce_small(name, itype, otype, op, 1) \
instantiate_col_reduce_small(name, itype, otype, op, 2) \
instantiate_col_reduce_small(name, itype, otype, op, 5) \
instantiate_col_reduce_small(name, itype, otype, op, 3) \
instantiate_col_reduce_small(name, itype, otype, op, 4) \
instantiate_col_reduce_looped(name, itype, otype, op, 0) \
instantiate_col_reduce_looped(name, itype, otype, op, 1) \
instantiate_col_reduce_looped(name, itype, otype, op, 2) \
instantiate_col_reduce_looped(name, itype, otype, op, 5)
instantiate_col_reduce_looped(name, itype, otype, op, 3) \
instantiate_col_reduce_looped(name, itype, otype, op, 4)
#define instantiate_row_reduce_small(name, itype, otype, op, dim) \
instantiate_kernel("row_reduce_small_" #dim "_reduce_" #name, \
row_reduce_small, \
itype, otype, op, uint, dim) \
instantiate_kernel("row_reduce_small_large_" #dim "_reduce_" #name, \
row_reduce_small, \
itype, otype, op, size_t, dim)
#define instantiate_same_col_reduce_helper(name, tname, type, op) \
instantiate_col_reduce_general(name##tname, type, type, op<type>)
#define instantiate_row_reduce_looped(name, itype, otype, op, dim) \
instantiate_kernel("row_reduce_looped_" #dim "_reduce_" #name, \
row_reduce_looped, \
itype, otype, op, uint, dim) \
instantiate_kernel("row_reduce_looped_large_" #dim "_reduce_" #name, \
row_reduce_looped, \
itype, otype, op, size_t, dim)
instantiate_reduce_ops(instantiate_same_col_reduce_helper, instantiate_reduce_helper_types)
instantiate_reduce_ops(instantiate_same_col_reduce_helper, instantiate_reduce_helper_64b)
instantiate_col_reduce_general(sumbool_, bool, uint32_t, Sum<uint32_t>)
instantiate_reduce_from_types(instantiate_col_reduce_general, and, bool, And<bool>)
instantiate_reduce_from_types(instantiate_col_reduce_general, or, bool, Or<bool>)
#define instantiate_row_reduce_small(name, itype, otype, op, dim) \
instantiate_kernel("row_reduce_small_" #dim "_reduce_" #name, \
row_reduce_small, \
itype, otype, op, dim)
#define instantiate_row_reduce_looped(name, itype, otype, op, dim) \
instantiate_kernel("row_reduce_looped_" #dim "_reduce_" #name, \
row_reduce_looped, \
itype, otype, op, dim)
#define instantiate_row_reduce_general(name, itype, otype, op) \
instantiate_row_reduce_small(name, itype, otype, op, 0) \
instantiate_row_reduce_small(name, itype, otype, op, 1) \
instantiate_row_reduce_small(name, itype, otype, op, 2) \
instantiate_row_reduce_small(name, itype, otype, op, 5) \
instantiate_row_reduce_small(name, itype, otype, op, 3) \
instantiate_row_reduce_small(name, itype, otype, op, 4) \
instantiate_row_reduce_looped(name, itype, otype, op, 0) \
instantiate_row_reduce_looped(name, itype, otype, op, 1) \
instantiate_row_reduce_looped(name, itype, otype, op, 2) \
instantiate_row_reduce_looped(name, itype, otype, op, 5) \
instantiate_row_reduce_looped(name, itype, otype, op, 3) \
instantiate_row_reduce_looped(name, itype, otype, op, 4) \
instantiate_kernel("row_reduce_simple_" #name, \
row_reduce_simple, \
itype, otype, op)
#define instantiate_reduce_functions(name, tname, itype, otype, op) \
instantiate_all_reduce(name##tname, itype, otype, op<otype>) \
instantiate_row_reduce_general(name##tname, itype, otype, op<otype>) \
instantiate_col_reduce_general(name##tname, itype, otype, op<otype>)
#define instantiate_same_row_reduce_helper(name, tname, type, op) \
instantiate_row_reduce_general(name##tname, type, type, op<type>)
#define instantiate_and_or(name, op) \
instantiate_reduce_functions(name, bool_, bool, bool, op) \
instantiate_reduce_functions(name, int16, int16_t, bool, op) \
instantiate_reduce_functions(name, int32, int32_t, bool, op) \
instantiate_reduce_functions(name, int64, int64_t, bool, op)
instantiate_reduce_ops(instantiate_same_row_reduce_helper, instantiate_reduce_helper_types)
instantiate_reduce_ops(instantiate_same_row_reduce_helper, instantiate_reduce_helper_64b)
instantiate_and_or(and, And)
instantiate_and_or(or, Or)
instantiate_reduce_from_types(instantiate_row_reduce_general, and, bool, And<bool>)
instantiate_reduce_from_types(instantiate_row_reduce_general, or, bool, Or<bool>)
#define instantiate_sum_prod(name, op) \
instantiate_reduce_functions(name, int8, int8_t, int32_t, op) \
instantiate_reduce_functions(name, int16, int16_t, int32_t, op) \
instantiate_reduce_functions(name, int32, int32_t, int32_t, op) \
instantiate_reduce_functions(name, int64, int64_t, int64_t, op) \
instantiate_reduce_functions(name, float16, float16_t, float16_t, op) \
instantiate_reduce_functions(name, bfloat16, bfloat16_t, bfloat16_t, op) \
instantiate_reduce_functions(name, float32, float, float, op) \
instantiate_reduce_functions(name, complex64, complex64_t, complex64_t, op)
instantiate_sum_prod(sum, Sum)
instantiate_sum_prod(prod, Prod)
#define instantiate_min_max(name, op) \
instantiate_reduce_functions(name, int8, int8_t, int8_t, op) \
instantiate_reduce_functions(name, int16, int16_t, int16_t, op) \
instantiate_reduce_functions(name, int32, int32_t, int32_t, op) \
instantiate_reduce_functions(name, int64, int64_t, int64_t, op) \
instantiate_reduce_functions(name, uint8, uint8_t, uint8_t, op) \
instantiate_reduce_functions(name, uint16, uint16_t, uint16_t, op) \
instantiate_reduce_functions(name, uint32, uint32_t, uint32_t, op) \
instantiate_reduce_functions(name, uint64, uint64_t, uint64_t, op) \
instantiate_reduce_functions(name, float16, float16_t, float16_t, op) \
instantiate_reduce_functions(name, bfloat16, bfloat16_t, bfloat16_t, op) \
instantiate_reduce_functions(name, float32, float, float, op) \
instantiate_reduce_functions(name, complex64, complex64_t, complex64_t, op)
instantiate_min_max(min, Min)
instantiate_min_max(max, Max)
instantiate_row_reduce_general(sumbool_, bool, uint32_t, Sum<uint32_t>)
// clang-format on

View File

@@ -1,11 +1,6 @@
// Copyright © 2023-2024 Apple Inc.
template <
typename T,
typename U,
typename Op,
typename IdxT = int64_t,
int N_READS = REDUCE_N_READS>
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
[[kernel]] void all_reduce(
const device T* in [[buffer(0)]],
device U* out [[buffer(1)]],
@@ -21,10 +16,10 @@ template <
threadgroup U shared_vals[simd_size];
U total = Op::init;
IdxT start_idx = gid.y * IdxT(row_size);
IdxT actual_row =
int64_t start_idx = gid.y * row_size;
int64_t actual_row =
(start_idx + row_size <= in_size) ? row_size : in_size - start_idx;
IdxT blocks = actual_row / (lsize.x * N_READS);
int64_t blocks = actual_row / (lsize.x * N_READS);
int extra = actual_row - blocks * (lsize.x * N_READS);
extra -= lid.x * N_READS;
start_idx += lid.x * N_READS;
@@ -35,7 +30,7 @@ template <
extra = 0;
}
for (IdxT b = 0; b < blocks; b++) {
for (int64_t b = 0; b < blocks; b++) {
for (int i = 0; i < N_READS; i++) {
total = op(static_cast<U>(in[i]), total);
}

View File

@@ -1,6 +1,6 @@
// Copyright © 2023-2024 Apple Inc.
template <typename T, typename U, typename Op, typename IdxT, int NDIMS>
template <typename T, typename U, typename Op, int NDIMS>
[[kernel]] void col_reduce_small(
const device T* in [[buffer(0)]],
device U* out [[buffer(1)]],
@@ -19,7 +19,7 @@ template <typename T, typename U, typename Op, typename IdxT, int NDIMS>
uint3 lsize [[threads_per_threadgroup]]) {
constexpr int n_reads = 4;
Op op;
LoopedElemToLoc<NDIMS, IdxT, (NDIMS > 2)> loop(reduce_ndim);
looped_elem_to_loc<NDIMS> loop;
const device T* row;
U totals[n_reads];
@@ -27,20 +27,20 @@ template <typename T, typename U, typename Op, typename IdxT, int NDIMS>
totals[i] = Op::init;
}
IdxT column = IdxT(gid.x) * lsize.x * n_reads + lid.x * n_reads;
size_t column = size_t(gid.x) * lsize.x * n_reads + lid.x * n_reads;
if (column >= reduction_stride) {
return;
}
bool safe = column + n_reads <= reduction_stride;
IdxT out_idx = gid.y + gsize.y * IdxT(gid.z);
IdxT in_idx = elem_to_loc<size_t, IdxT>(out_idx, shape, strides, ndim);
size_t out_idx = gid.y + gsize.y * size_t(gid.z);
size_t in_idx = elem_to_loc(out_idx, shape, strides, ndim);
in += in_idx + column;
IdxT total_rows = IdxT(non_col_reductions) * IdxT(reduction_size);
size_t total_rows = non_col_reductions * reduction_size;
loop.next(lid.y, reduce_shape, reduce_strides);
for (IdxT r = lid.y; r < total_rows; r += lsize.y) {
row = in + loop.location();
for (size_t r = lid.y; r < total_rows; r += lsize.y) {
row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
if (safe) {
for (int i = 0; i < n_reads; i++) {
totals[i] = op(static_cast<U>(row[i]), totals[i]);
@@ -80,7 +80,7 @@ template <typename T, typename U, typename Op, typename IdxT, int NDIMS>
}
if (lid.y == 0) {
out += out_idx * IdxT(reduction_stride) + column;
out += out_idx * reduction_stride + column;
if (safe) {
for (int i = 0; i < n_reads; i++) {
out[i] = totals[i];
@@ -93,7 +93,7 @@ template <typename T, typename U, typename Op, typename IdxT, int NDIMS>
}
}
template <typename T, typename U, typename Op, typename IdxT, int NDIMS>
template <typename T, typename U, typename Op, int NDIMS>
[[kernel]] void col_reduce_longcolumn(
const device T* in [[buffer(0)]],
device U* out [[buffer(1)]],
@@ -112,19 +112,19 @@ template <typename T, typename U, typename Op, typename IdxT, int NDIMS>
uint3 lid [[thread_position_in_threadgroup]],
uint3 lsize [[threads_per_threadgroup]]) {
Op op;
LoopedElemToLoc<NDIMS, IdxT, (NDIMS > 2)> loop(reduce_ndim);
looped_elem_to_loc<NDIMS> loop;
const device T* row;
IdxT out_idx = gid.x + gsize.x * IdxT(gid.y);
IdxT in_idx = elem_to_loc<size_t, IdxT>(out_idx, shape, strides, ndim);
size_t out_idx = gid.x + gsize.x * size_t(gid.y);
size_t in_idx = elem_to_loc(out_idx, shape, strides, ndim);
in += in_idx + lid.x;
U total = Op::init;
IdxT total_rows = IdxT(non_col_reductions) * IdxT(reduction_size);
size_t total_rows = non_col_reductions * reduction_size;
loop.next(gid.z * lsize.y + lid.y, reduce_shape, reduce_strides);
for (IdxT r = gid.z * lsize.y + lid.y; r < total_rows;
for (size_t r = gid.z * lsize.y + lid.y; r < total_rows;
r += lsize.y * gsize.z) {
row = in + loop.location();
row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
total = op(static_cast<U>(*row), total);
loop.next(lsize.y * gsize.z, reduce_shape, reduce_strides);
}
@@ -136,8 +136,7 @@ template <typename T, typename U, typename Op, typename IdxT, int NDIMS>
for (uint i = 1; i < lsize.y; i++) {
total = op(total, shared_vals[i * lsize.x + lid.x]);
}
out[gid.z * IdxT(out_size) + out_idx * IdxT(reduction_stride) + lid.x] =
total;
out[gid.z * out_size + out_idx * reduction_stride + lid.x] = total;
}
}
@@ -152,14 +151,7 @@ template <typename T, typename U, typename Op, typename IdxT, int NDIMS>
* totals with a loop.
* 7. Write them to the output
*/
template <
typename T,
typename U,
typename Op,
typename IdxT,
int NDIMS,
int BM,
int BN>
template <typename T, typename U, typename Op, int NDIMS, int BM, int BN>
[[kernel]] void col_reduce_looped(
const device T* in [[buffer(0)]],
device U* out [[buffer(1)]],
@@ -184,7 +176,7 @@ template <
threadgroup U shared_vals[BN * BM];
U totals[n_reads];
LoopedElemToLoc<NDIMS, IdxT, (NDIMS > 2)> loop(reduce_ndim);
looped_elem_to_loc<NDIMS> loop;
const device T* row;
for (int i = 0; i < n_reads; i++) {
@@ -193,17 +185,17 @@ template <
short lid = simd_group_id * simd_size + simd_lane_id;
short2 offset((lid % n_read_blocks) * n_reads, lid / n_read_blocks);
IdxT column = BN * gid.x + offset.x;
size_t column = BN * gid.x + offset.x;
bool safe = column + n_reads <= reduction_stride;
IdxT out_idx = gid.y + gsize.y * IdxT(gid.z);
IdxT in_idx = elem_to_loc<size_t, IdxT>(out_idx, shape, strides, ndim);
size_t out_idx = gid.y + gsize.y * size_t(gid.z);
size_t in_idx = elem_to_loc(out_idx, shape, strides, ndim);
in += in_idx + column;
IdxT total = IdxT(non_col_reductions) * IdxT(reduction_size);
size_t total = non_col_reductions * reduction_size;
loop.next(offset.y, reduce_shape, reduce_strides);
for (IdxT r = offset.y; r < total; r += BM) {
row = in + loop.location();
for (size_t r = offset.y; r < total; r += BM) {
row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
if (safe) {
for (int i = 0; i < n_reads; i++) {
@@ -243,8 +235,8 @@ template <
// Write the output.
if (simd_lane_id == 0) {
IdxT out_column = BN * gid.x + out_offset.x;
out += out_idx * IdxT(reduction_stride) + out_column;
size_t out_column = BN * gid.x + out_offset.x;
out += out_idx * reduction_stride + out_column;
if (out_column + n_outputs <= reduction_stride) {
for (int i = 0; i < n_outputs; i++) {
out[i] = totals[i];
@@ -277,7 +269,7 @@ template <
// Write the output.
if (offset.y == 0) {
out += out_idx * IdxT(reduction_stride) + column;
out += out_idx * reduction_stride + column;
if (safe) {
for (int i = 0; i < n_reads; i++) {
out[i] = totals[i];
@@ -291,14 +283,7 @@ template <
}
}
template <
typename T,
typename U,
typename Op,
typename IdxT,
int NDIMS,
int BM,
int BN>
template <typename T, typename U, typename Op, int NDIMS, int BM, int BN>
[[kernel]] void col_reduce_2pass(
const device T* in [[buffer(0)]],
device U* out [[buffer(1)]],
@@ -327,7 +312,7 @@ template <
threadgroup U shared_vals[BN * BM];
U totals[n_reads];
LoopedElemToLoc<NDIMS, IdxT, (NDIMS > 2)> loop(reduce_ndim);
looped_elem_to_loc<NDIMS> loop;
const device T* row;
for (int i = 0; i < n_reads; i++) {
@@ -336,19 +321,20 @@ template <
short lid = simd_group_id * simd_size + simd_lane_id;
short2 offset((lid % n_read_blocks) * n_reads, lid / n_read_blocks);
IdxT column = BN * gid.x + offset.x;
size_t column = BN * gid.x + offset.x;
bool safe = column + n_reads <= reduction_stride;
IdxT full_idx = gid.y + gsize.y * IdxT(gid.z);
IdxT block_idx = full_idx / IdxT(out_size);
IdxT out_idx = full_idx % IdxT(out_size);
IdxT in_idx = elem_to_loc<size_t, IdxT>(out_idx, shape, strides, ndim);
size_t full_idx = gid.y + gsize.y * size_t(gid.z);
size_t block_idx = full_idx / out_size;
size_t out_idx = full_idx % out_size;
size_t in_idx = elem_to_loc(out_idx, shape, strides, ndim);
in += in_idx + column;
IdxT total = IdxT(non_col_reductions) * IdxT(reduction_size);
size_t total = non_col_reductions * reduction_size;
loop.next(offset.y + block_idx * BM, reduce_shape, reduce_strides);
for (IdxT r = offset.y + block_idx * BM; r < total; r += outer_blocks * BM) {
row = in + loop.location();
for (size_t r = offset.y + block_idx * BM; r < total;
r += outer_blocks * BM) {
row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
if (safe) {
for (int i = 0; i < n_reads; i++) {
@@ -383,8 +369,8 @@ template <
// Write the output.
if (simd_lane_id == 0) {
IdxT out_column = BN * gid.x + out_offset.x;
out += full_idx * IdxT(reduction_stride) + out_column;
size_t out_column = BN * gid.x + out_offset.x;
out += full_idx * reduction_stride + out_column;
if (out_column + n_outputs <= reduction_stride) {
for (int i = 0; i < n_outputs; i++) {
out[i] = totals[i];

View File

@@ -193,7 +193,6 @@ template <
typename T,
typename U,
typename Op,
typename IdxT,
int NDIMS,
int N_READS = REDUCE_N_READS>
[[kernel]] void row_reduce_small(
@@ -215,20 +214,20 @@ template <
Op op;
U total_val = Op::init;
LoopedElemToLoc<NDIMS, IdxT, (NDIMS > 2)> loop(reduce_ndim);
looped_elem_to_loc<NDIMS> loop;
// Precompute some row reduction numbers
const device T* row;
int blocks = IdxT(row_size) / N_READS;
int extra = IdxT(row_size) % N_READS;
int blocks = row_size / N_READS;
int extra = row_size % N_READS;
if ((non_row_reductions < 32 && row_size <= 8) || non_row_reductions <= 8) {
// Simple loop over non_row_reductions and reduce the row in the thread.
IdxT out_idx = tid.x + tsize.y * IdxT(tid.y);
in += elem_to_loc<size_t, IdxT>(out_idx, shape, strides, ndim);
size_t out_idx = tid.x + tsize.y * size_t(tid.y);
in += elem_to_loc(out_idx, shape, strides, ndim);
for (uint r = 0; r < non_row_reductions; r++) {
row = in + loop.location();
row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
thread_reduce<T, U, Op, N_READS>(total_val, row, blocks, extra);
loop.next(reduce_shape, reduce_strides);
}
@@ -237,13 +236,13 @@ template <
} else {
// Collaboratively reduce over non_row_reductions in the simdgroup. Each
// thread reduces every 32nd row and then a simple simd reduce.
IdxT out_idx = gid.y + gsize.y * IdxT(gid.z);
in += elem_to_loc<size_t, IdxT>(out_idx, shape, strides, ndim);
size_t out_idx = gid.y + gsize.y * size_t(gid.z);
in += elem_to_loc(out_idx, shape, strides, ndim);
loop.next(simd_lane_id, reduce_shape, reduce_strides);
for (uint r = simd_lane_id; r < non_row_reductions; r += simd_size) {
row = in + loop.location();
row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
thread_reduce<T, U, Op, N_READS>(total_val, row, blocks, extra);
loop.next(simd_size, reduce_shape, reduce_strides);
}
@@ -260,7 +259,6 @@ template <
typename T,
typename U,
typename Op,
typename IdxT = size_t,
int N_READS = REDUCE_N_READS,
int N_WRITES = REDUCE_N_WRITES>
[[kernel]] void row_reduce_simple(
@@ -279,15 +277,15 @@ template <
U totals[N_WRITES];
// Move to the row
IdxT out_idx = N_WRITES * (gid.y + gsize.y * IdxT(gid.z));
size_t out_idx = N_WRITES * (gid.y + gsize.y * size_t(gid.z));
if (out_idx + N_WRITES > out_size) {
out_idx = out_size - N_WRITES;
}
in += out_idx * IdxT(reduction_size);
in += out_idx * reduction_size;
out += out_idx;
// Each thread reduces across the row
int blocks = IdxT(reduction_size) / (lsize.x * N_READS);
int blocks = reduction_size / (lsize.x * N_READS);
int extra = reduction_size - blocks * (lsize.x * N_READS);
per_thread_row_reduce<T, U, Op, N_READS, N_WRITES>(
totals, in, reduction_size, blocks, extra, lsize.x, lid.x);
@@ -308,7 +306,6 @@ template <
typename T,
typename U,
typename Op,
typename IdxT,
int NDIMS,
int N_READS = REDUCE_N_READS>
[[kernel]] void row_reduce_looped(
@@ -333,20 +330,19 @@ template <
threadgroup U shared_vals[simd_size];
U total = Op::init;
IdxT out_idx = gid.y + gsize.y * IdxT(gid.z);
size_t out_idx = gid.y + gsize.y * size_t(gid.z);
// lid.x * N_READS breaks the per_thread_row_reduce interface a bit. Maybe it
// needs a small refactor.
in += elem_to_loc<size_t, IdxT>(out_idx, shape, strides, ndim) +
lid.x * N_READS;
in += elem_to_loc(out_idx, shape, strides, ndim) + lid.x * N_READS;
LoopedElemToLoc<NDIMS, IdxT, (NDIMS > 2)> loop(reduce_ndim);
looped_elem_to_loc<NDIMS> loop;
const device T* row;
int blocks = IdxT(row_size) / (lsize.x * N_READS);
int blocks = row_size / (lsize.x * N_READS);
int extra = row_size - blocks * (lsize.x * N_READS);
for (IdxT i = 0; i < non_row_reductions; i++) {
row = in + loop.location();
for (size_t i = 0; i < non_row_reductions; i++) {
row = in + loop.location(i, reduce_shape, reduce_strides, reduce_ndim);
// Each thread reduces across the row
U row_total;

View File

@@ -3,6 +3,8 @@
#include <metal_common>
#include <metal_simdgroup>
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/utils.h"
using namespace metal;
@@ -15,15 +17,12 @@ template <typename T, int N_READS = RMS_N_READS>
constant float& eps,
constant uint& axis_size,
constant uint& w_stride,
threadgroup float* local_inv_mean [[threadgroup(0)]],
threadgroup float* local_sums [[threadgroup(1)]],
uint gid [[threadgroup_position_in_grid]],
uint lid [[thread_position_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
constexpr int SIMD_SIZE = 32;
threadgroup float local_inv_mean[1];
threadgroup float local_sums[SIMD_SIZE];
float acc = 0;
x += gid * size_t(axis_size) + lid * N_READS;
w += w_stride * lid * N_READS;
@@ -85,15 +84,13 @@ template <typename T, int N_READS = RMS_N_READS>
constant float& eps,
constant uint& axis_size,
constant uint& w_stride,
threadgroup float* local_inv_mean [[threadgroup(0)]],
threadgroup float* local_sums [[threadgroup(1)]],
uint gid [[threadgroup_position_in_grid]],
uint lid [[thread_position_in_threadgroup]],
uint lsize [[threads_per_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
constexpr int SIMD_SIZE = 32;
threadgroup float local_inv_mean[1];
threadgroup float local_sums[SIMD_SIZE];
float acc = 0;
x += gid * size_t(axis_size) + lid * N_READS;
w += w_stride * lid * N_READS;
@@ -379,6 +376,8 @@ template <typename T, int N_READS = RMS_N_READS>
constant float& eps, \
constant uint& axis_size, \
constant uint& w_stride, \
threadgroup float* local_inv_mean [[threadgroup(0)]], \
threadgroup float* local_sums [[threadgroup(1)]], \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
@@ -408,6 +407,8 @@ template <typename T, int N_READS = RMS_N_READS>
constant float& eps, \
constant uint& axis_size, \
constant uint& w_stride, \
threadgroup float* local_inv_mean [[threadgroup(0)]], \
threadgroup float* local_sums [[threadgroup(1)]], \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint lsize [[threads_per_threadgroup]], \

View File

@@ -2,6 +2,7 @@
#include <metal_math>
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/utils.h"
template <typename T, bool traditional, bool forward>
void rope_single_impl(

View File

@@ -1,16 +1,946 @@
#include <metal_stdlib>
#include "mlx/backend/metal/kernels/scaled_dot_product_attention_params.h"
#include "mlx/backend/metal/kernels/sdpa_vector.h"
#include "mlx/backend/metal/kernels/steel/defines.h"
#include "mlx/backend/metal/kernels/steel/gemm/transforms.h"
#include "mlx/backend/metal/kernels/utils.h"
using namespace metal;
using namespace mlx::steel;
template <
typename T,
short BROWS,
short BCOLS,
short dst_ld,
short reduction_dim,
short tgp_size,
short alignment = 1,
short n_reads = (BCOLS * BROWS) / (tgp_size),
short TCOLS = BCOLS / n_reads,
short TROWS = tgp_size / TCOLS>
struct BlockLoaderFA {
STEEL_CONST short n_rows = (BROWS + TROWS - 1) / TROWS;
STEEL_CONST short vec_size = n_reads;
// Leading dimension for src
const int src_ld;
const int tile_stride;
// Thread location indices
const short thread_idx;
const short bi;
const short bj;
// threadgroup and device memory
threadgroup T* dst;
const device T* src;
struct alignas(alignment * sizeof(T)) ReadVector {
uint8_t v[sizeof(T) * vec_size];
};
/* Constructor */
METAL_FUNC BlockLoaderFA(
const device T* src_,
const int src_ld_,
threadgroup T* dst_,
ushort simd_group_id [[simdgroup_index_in_threadgroup]],
ushort simd_lane_id [[thread_index_in_simdgroup]])
: src_ld(src_ld_),
tile_stride(reduction_dim ? BCOLS : BROWS * src_ld),
thread_idx(simd_group_id * 32 + simd_lane_id),
bi(thread_idx / TCOLS),
bj(vec_size * (thread_idx % TCOLS)),
dst(dst_ + bi * dst_ld + bj),
src(src_ + bi * src_ld + bj) {}
/* Load from device memory into threadgroup memory - without bound checking */
METAL_FUNC void load_unsafe() const {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < BROWS; i += TROWS) {
*((threadgroup ReadVector*)(&dst[i * dst_ld])) =
*((const device ReadVector*)(&src[i * src_ld]));
}
}
/* Load from device memory into threadgroup memory - with bound checking */
METAL_FUNC void load_safe(short2 src_tile_dim) const {
src_tile_dim = src_tile_dim - short2(bj, bi);
// Skip loading if thread has no valid reads
if (src_tile_dim.x <= 0 || src_tile_dim.y <= 0) {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < BROWS; i += TROWS) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
dst[i * dst_ld + j] = T(0);
}
}
return;
}
// Use fast thread memory for bound checks
bool tmp_idx[vec_size];
T tmp_val[vec_size];
STEEL_PRAGMA_UNROLL
for (short i = 0; i < BROWS; i += TROWS) {
// Make sure tmp_idx only contains valid indices
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
tmp_idx[j] = (i < src_tile_dim.y) && (j < src_tile_dim.x);
}
// Read valid indices into tmp_val
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)];
}
// Zero out uneeded values
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0);
}
// Copy values to threadgroup memory
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
dst[i * dst_ld + j] = tmp_val[j];
}
}
}
/* Iteration helper */
METAL_FUNC void next() {
src += tile_stride;
}
METAL_FUNC void next(short n) {
src += n * tile_stride;
}
};
template <bool M_aligned, bool N_aligned, bool K_aligned>
struct LoopAlignment {};
template <
typename T,
typename U,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_a,
bool transpose_b,
short lda_tgp,
short ldb_tgp,
typename AccumType = float,
typename Epilogue = TransformNone<U, AccumType>>
struct BlockMMAFA {
// Warp tile simdgroup matrix strides along M
STEEL_CONST short TM_stride = 8 * WM;
// Warp tile simdgroup matrix strides along M
STEEL_CONST short TN_stride = 8 * WN;
// Warp tile size along M
STEEL_CONST short TM = BM / TM_stride;
// Warp tile size along N
STEEL_CONST short TN = BN / TN_stride;
// Strides of A, B along reduction axis
STEEL_CONST short simd_stride_a = {
transpose_a ? TM_stride : TM_stride * lda_tgp};
STEEL_CONST short simd_stride_b = {
transpose_b ? TN_stride * ldb_tgp : TN_stride};
// Jump between elements
STEEL_CONST short jump_a = {transpose_a ? lda_tgp : 1};
STEEL_CONST short jump_b = {transpose_b ? ldb_tgp : 1};
STEEL_CONST short tile_stride_a = {transpose_a ? 8 * lda_tgp : 8};
STEEL_CONST short tile_stride_b = {transpose_b ? 8 : 8 * ldb_tgp};
// Simdgroup matrices
simdgroup_matrix<AccumType, 8, 8> Asimd[TM];
simdgroup_matrix<AccumType, 8, 8> Bsimd[TN];
simdgroup_matrix<AccumType, 8, 8> results[TM * TN] = {
simdgroup_matrix<AccumType, 8, 8>(0)};
// Offsets within threadgroup
const short tm;
const short tn;
short sm;
short sn;
ushort sid;
ushort slid;
short As_offset;
short Bs_offset;
/* Constructor */
METAL_FUNC BlockMMAFA(
ushort simd_group_id [[simdgroup_index_in_threadgroup]],
ushort simd_lane_id [[thread_index_in_simdgroup]])
: tm(8 * (simd_group_id / WN)), tn(8 * (simd_group_id % WN)) {
// Determine thread position in simdgroup matrix
short qid = simd_lane_id / 4;
slid = simd_lane_id;
sid = simd_group_id;
sm = (qid & 4) + (simd_lane_id / 2) % 4;
sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;
// Determine thread and simdgroup offset
As_offset =
transpose_a ? ((sn)*lda_tgp + (tm + sm)) : ((sn) + (tm + sm) * lda_tgp);
Bs_offset =
transpose_b ? ((tn + sn) * ldb_tgp + (sm)) : ((sm)*ldb_tgp + (tn + sn));
}
/* (BM, BK) X (BK, BN) multiply accumulate function */
METAL_FUNC void mma(const threadgroup T* As, const threadgroup T* Bs) {
// Adjust for simdgroup and thread location
As += As_offset;
Bs += Bs_offset;
// Iterate over BK in blocks of 8
STEEL_PRAGMA_UNROLL
for (short kk = 0; kk < BK; kk += 8) {
simdgroup_barrier(mem_flags::mem_none);
// Load elements from threadgroup A as simdgroup matrices
STEEL_PRAGMA_UNROLL
for (short i = 0; i < TM; i++) {
Asimd[i].thread_elements()[0] =
static_cast<AccumType>(As[i * simd_stride_a + 0]);
Asimd[i].thread_elements()[1] =
static_cast<AccumType>(As[i * simd_stride_a + jump_a]);
}
simdgroup_barrier(mem_flags::mem_none);
// Load elements from threadgroup B as simdgroup matrices
STEEL_PRAGMA_UNROLL
for (short j = 0; j < TN; j++) {
Bsimd[j].thread_elements()[0] =
static_cast<AccumType>(Bs[j * simd_stride_b + 0]);
Bsimd[j].thread_elements()[1] =
static_cast<AccumType>(Bs[j * simd_stride_b + jump_b]);
}
simdgroup_barrier(mem_flags::mem_none);
// Multiply and accumulate into result simdgroup matrices
STEEL_PRAGMA_UNROLL
for (short i = 0; i < TM; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < TN; j++) {
short j_serp = (i % 2) ? (TN - 1 - j) : j;
simdgroup_multiply_accumulate(
results[i * TN + j_serp],
Asimd[i],
Bsimd[j_serp],
results[i * TN + j_serp]);
}
}
// Progress to next simdgroup tile
As += tile_stride_a;
Bs += tile_stride_b;
}
}
METAL_FUNC void rescale_output(const threadgroup float* Corrections) {
// Loop over all simdgroup tiles
STEEL_PRAGMA_UNROLL
for (short i = 0; i < TM; i++) {
short row = sm + tm + i * TM_stride;
float scale_value = Corrections[row];
STEEL_PRAGMA_UNROLL
for (short j = 0; j < TN; j++) {
// Get accumulated result and associated offset in C
thread auto& accum = results[i * TN + j].thread_elements();
// int offset = (i * TM_stride) * ldc + (j * TN_stride);
accum[0] *= scale_value;
accum[1] *= scale_value;
}
}
}
/* Store results from simdgroup_matrix results into device memory */
METAL_FUNC void store_result(device U* C, const int ldc) const {
// Adjust for simdgroup and thread location
C += (sm + tm) * ldc + tn + sn;
// Loop over all simdgroup tiles
STEEL_PRAGMA_UNROLL
for (short i = 0; i < TM; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < TN; j++) {
// Get accumulated result and associated offset in C
thread const auto& accum = results[i * TN + j].thread_elements();
int offset = (i * TM_stride) * ldc + (j * TN_stride);
// Apply epilogue
U outs[2] = {Epilogue::apply(accum[0]), Epilogue::apply(accum[1])};
// Write out C
C[offset] = outs[0];
C[offset + 1] = outs[1];
}
}
}
METAL_FUNC void store_result_to_tgp_memory(
threadgroup U* C,
const int ldc,
short2 dst_tile_dims) const {
// Adjust for simdgroup and thread location
C += (sm + tm) * ldc + (tn + sn);
dst_tile_dims -= short2(tn + sn, sm + tm);
STEEL_PRAGMA_UNROLL
for (int i = 0; i < TM; i++) {
if (i * TM_stride < dst_tile_dims.y) {
STEEL_PRAGMA_UNROLL
for (int j = 0; j < TN; j++) {
// Get accumulated result and associated offset in C
thread const auto& accum = results[i * TN + j].thread_elements();
int offset = (i * TM_stride) * ldc + (j * TN_stride);
// Apply epilogue and output C
if (j * TN_stride < dst_tile_dims.x) {
C[offset] = Epilogue::apply(accum[0]);
}
if (j * TN_stride + 1 < dst_tile_dims.x) {
C[offset + 1] = Epilogue::apply(accum[1]);
}
}
}
}
}
METAL_FUNC void
store_result_safe(device U* C, const int ldc, short2 dst_tile_dims) const {
// Adjust for simdgroup and thread location
C += (sm + tm) * ldc + (tn + sn);
dst_tile_dims -= short2(tn + sn, sm + tm);
STEEL_PRAGMA_UNROLL
for (int i = 0; i < TM; i++) {
if (i * TM_stride < dst_tile_dims.y) {
STEEL_PRAGMA_UNROLL
for (int j = 0; j < TN; j++) {
// Get accumulated result and associated offset in C
thread const auto& accum = results[i * TN + j].thread_elements();
int offset = (i * TM_stride) * ldc + (j * TN_stride);
// Apply epilogue and output C
if (j * TN_stride < dst_tile_dims.x) {
C[offset] = Epilogue::apply(accum[0]);
}
if (j * TN_stride + 1 < dst_tile_dims.x) {
C[offset + 1] = Epilogue::apply(accum[1]);
}
}
}
}
}
/* Store results from simdgroup_matrix results into device memory */
METAL_FUNC void store_result(
device U* D,
const int ldd,
const device U* C,
const int ldc,
const int fdc,
thread const Epilogue& epilogue_op) const {
// Adjust for simdgroup and thread location
C += (sm + tm) * ldc + (tn + sn) * fdc;
D += (sm + tm) * ldd + tn + sn;
// Loop over all simdgroup tiles
STEEL_PRAGMA_UNROLL
for (short i = 0; i < TM; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < TN; j++) {
// Get accumulated result and associated offset in C
thread const auto& accum = results[i * TN + j].thread_elements();
int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
int offset_d = (i * TM_stride) * ldd + (j * TN_stride);
// Apply epilogue
U outs[2] = {
epilogue_op.apply(accum[0], C[offset_c]),
epilogue_op.apply(accum[1], C[offset_c + fdc])};
// Write out D
D[offset_d] = outs[0];
D[offset_d + 1] = outs[1];
}
}
}
METAL_FUNC void store_result_safe(
device U* D,
const int ldd,
const device U* C,
const int ldc,
const int fdc,
short2 dst_tile_dims,
thread const Epilogue& epilogue_op) const {
// Adjust for simdgroup and thread location
C += (sm + tm) * ldc + (tn + sn) * fdc;
D += (sm + tm) * ldd + tn + sn;
dst_tile_dims -= short2(tn + sn, sm + tm);
STEEL_PRAGMA_UNROLL
for (int i = 0; i < TM; i++) {
if (i * TM_stride < dst_tile_dims.y) {
STEEL_PRAGMA_UNROLL
for (int j = 0; j < TN; j++) {
// Get accumulated result and associated offset in C
thread const auto& accum = results[i * TN + j].thread_elements();
int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
int offset_d = (i * TM_stride) * ldd + (j * TN_stride);
// Apply epilogue and output C
if (j * TN_stride < dst_tile_dims.x) {
D[offset_d] = epilogue_op.apply(accum[0], C[offset_c]);
}
if (j * TN_stride + 1 < dst_tile_dims.x) {
D[offset_d + 1] = epilogue_op.apply(accum[1], C[offset_c + fdc]);
}
}
}
}
}
METAL_FUNC void clear_results() {
STEEL_PRAGMA_UNROLL
for (int i = 0; i < TM; i++) {
STEEL_PRAGMA_UNROLL
for (int j = 0; j < TN; j++) {
results[i * TN + j] = simdgroup_matrix<AccumType, 8, 8>(0);
}
}
}
};
template <
typename T,
typename U,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_q,
bool transpose_k,
bool transpose_v,
bool MN_aligned,
bool K_aligned,
typename AccumType = typename AccumHelper<T>::accum_type,
typename Epilogue = TransformNone<U, AccumType>>
struct FastAttentionKernel {
STEEL_CONST short tgp_padding = 16 / sizeof(T);
STEEL_CONST short float_padding = 16 / sizeof(float);
STEEL_CONST short tgp_mem_size_q =
transpose_q ? BK * (BM + tgp_padding) : BM * (BK + tgp_padding);
STEEL_CONST short tgp_mem_size_k =
transpose_k ? BK * (BN + tgp_padding) : BN * (BK + tgp_padding);
STEEL_CONST short tgp_mem_size_v =
transpose_v ? BK * (BN + tgp_padding) : BN * (BK + tgp_padding);
STEEL_CONST short tgp_mem_size_s = BM * (BN + tgp_padding);
// maxes, rowsums, rescale
STEEL_CONST short tgp_mem_size_corrections =
4 * (BM * sizeof(float) + float_padding);
STEEL_CONST bool share_kv_smem = transpose_k != transpose_v;
STEEL_CONST short tgp_mem_size = share_kv_smem
? tgp_mem_size_q + tgp_mem_size_k + tgp_mem_size_s +
tgp_mem_size_corrections
: tgp_mem_size_q + tgp_mem_size_k + tgp_mem_size_s +
tgp_mem_size_corrections + tgp_mem_size_v;
STEEL_CONST short tgp_size = WM * WN * 32;
static_assert(transpose_q == false, "Expected Q not transposed.");
static_assert(transpose_k == true, "Expected K transposed.");
static_assert(transpose_v == false, "Expected V not transposed.");
static_assert(tgp_mem_size <= 32768, "Excessive tgp memory requested.");
using loader_q_t = BlockLoaderFA<
T,
transpose_q ? BK : BM,
transpose_q ? BM : BK,
transpose_q ? BM + tgp_padding : BK + tgp_padding,
!transpose_q,
tgp_size>;
using loader_k_t = BlockLoaderFA<
T,
transpose_k ? BN : BK,
transpose_k ? BK : BN,
transpose_k ? BK + tgp_padding : BN + tgp_padding,
transpose_k,
tgp_size>;
using loader_v_t = BlockLoaderFA<
T,
transpose_v ? BK : BN,
transpose_v ? BN : BK,
transpose_v ? BN + tgp_padding : BK + tgp_padding,
transpose_v,
tgp_size>;
using mma_qk_t = BlockMMAFA<
T,
U,
BM,
BN,
BK,
WM,
WN,
transpose_q,
transpose_k,
transpose_q ? BM + tgp_padding : BK + tgp_padding,
transpose_k ? BK + tgp_padding : BN + tgp_padding,
AccumType,
Epilogue>;
using mma_sv_t = BlockMMAFA<
T,
U,
BM,
BK,
BN,
WM,
WN,
false,
transpose_v,
BN + tgp_padding,
BK + tgp_padding,
AccumType,
Epilogue>;
/* Main kernel function */
template <bool M_aligned, bool N_aligned, bool K_aligned_>
static METAL_FUNC void gemm_loop(
threadgroup T* As [[threadgroup(0)]],
threadgroup T* Bs [[threadgroup(1)]],
const int gemm_k_iterations,
thread loader_k_t& loader_b,
thread mma_qk_t& mma_op,
thread const short& tgp_bm,
thread const short& tgp_bn,
LoopAlignment<M_aligned, N_aligned, K_aligned_> l = {}) {
// Appease the compiler
(void)l;
(void)tgp_bm;
short2 tile_dims_B = transpose_k ? short2(BK, tgp_bn) : short2(tgp_bn, BK);
// not valid for gemm_k_iterations > 1 (so, BK == d_k)
for (int k = 0; k < gemm_k_iterations; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
if (N_aligned) {
loader_b.load_unsafe();
} else {
loader_b.load_safe(tile_dims_B);
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
}
}
static METAL_FUNC void initialize_corrections(
threadgroup float* C,
uint simd_lane_id,
uint simd_group_id) {
if (simd_group_id == 0) {
threadgroup float* maxes = C;
threadgroup float* sums = C + (BM + float_padding);
threadgroup float* o_rescale = sums + (BM + float_padding);
threadgroup float* output_rescale = o_rescale + (BM + float_padding);
if (simd_lane_id < BM) {
maxes[simd_lane_id] = -INFINITY; // m_i
sums[simd_lane_id] = 0.f; // l_i
o_rescale[simd_lane_id] = 1.f; // li * exp(mi - mi_new)
output_rescale[simd_lane_id] = 1.f; // 1.0 / l_i
}
}
}
static METAL_FUNC void rescale_ss(
threadgroup T* Ss,
threadgroup float* Corrections,
uint simd_group_id,
uint simd_lane_id,
short2 local_blocks,
float alpha) {
if (simd_group_id == 0) {
short row_offset = BM + float_padding;
threadgroup float* maxes = Corrections;
threadgroup float* sums = Corrections + row_offset;
threadgroup float* o_rescale = sums + row_offset;
threadgroup float* output_scales = o_rescale + row_offset;
if (simd_lane_id < uint(local_blocks.y)) {
float m_i_old = maxes[simd_lane_id];
float l_i_old = sums[simd_lane_id];
float m_i_new = m_i_old;
float l_i_new = l_i_old;
short offset = simd_lane_id * (BN + tgp_padding);
float m_ij = -INFINITY;
for (short j = 0; j < local_blocks.x; j++) {
float val = alpha * float(Ss[offset + j]);
m_ij = max(m_ij, val);
}
m_i_new = max(m_ij, m_i_new);
float rowsum = 0.f; // lij
for (short j = 0; j < local_blocks.x; j++) {
float val = alpha * float(Ss[offset + j]);
float P_i_j = exp(val - m_ij);
rowsum += P_i_j;
P_i_j = P_i_j * exp(m_ij - m_i_new);
Ss[offset + j] = T(P_i_j);
}
l_i_new =
exp(m_i_old - m_i_new) * l_i_old + exp(m_ij - m_i_new) * rowsum;
maxes[simd_lane_id] = m_i_new;
sums[simd_lane_id] = l_i_new;
float rescale = l_i_old * exp(m_i_old - m_i_new);
o_rescale[simd_lane_id] = rescale;
output_scales[simd_lane_id] = 1.0 / l_i_new;
}
}
}
/* Main kernel function */
static METAL_FUNC void run(
const device T* Q [[buffer(0)]],
const device T* K [[buffer(1)]],
const device T* V [[buffer(2)]],
device U* O [[buffer(3)]],
const constant MLXFastAttentionParams* params [[buffer(4)]],
threadgroup T* Qs [[threadgroup(0)]],
threadgroup T* Ks [[threadgroup(1)]],
threadgroup T* Ss [[threadgroup(2)]],
threadgroup T* Vs [[threadgroup(3)]],
threadgroup float* Corrections [[threadgroup(4)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
// Pacifying compiler
(void)lid;
const int tid_y = ((tid.y) << params->swizzle_log) +
((tid.x) & ((1 << params->swizzle_log) - 1));
const int tid_x = (tid.x) >> params->swizzle_log;
if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
return;
}
threadgroup_barrier(mem_flags::mem_none);
// Find block in Q, O; and head in K, V.
const int c_row = tid_y * BM;
Q += transpose_q ? c_row : c_row * params->ldq;
thread loader_q_t loader_q(Q, params->ldq, Qs, simd_group_id, simd_lane_id);
short tgp_bm = min(BM, params->M - c_row);
short2 tile_dims_Q = transpose_q ? short2(tgp_bm, BK) : short2(BK, tgp_bm);
loader_q.load_safe(tile_dims_Q);
initialize_corrections(Corrections, simd_lane_id, simd_group_id);
O += c_row * params->ldo;
// Prepare threadgroup mma operation
thread mma_qk_t mma_qk_op(simd_group_id, simd_lane_id);
thread mma_sv_t mma_softmax_sv_op(simd_group_id, simd_lane_id);
thread loader_k_t loader_k(K, params->ldk, Ks, simd_group_id, simd_lane_id);
thread loader_v_t loader_v(V, params->ldv, Vs, simd_group_id, simd_lane_id);
for (short n_block = 0; n_block < params->gemm_n_iterations_aligned;
n_block++) {
short c_col = BN;
// Prepare threadgroup loading operations
short gemm_k_iterations = params->gemm_k_iterations_aligned;
short tgp_bn_qk = min(BN, params->N - c_col * n_block);
threadgroup_barrier(mem_flags::mem_none);
///////////////////////////////////////////////////////////////////////////////
{ // Loop over K - unaligned case
if (tgp_bm == BM && tgp_bn_qk == BN) {
gemm_loop<true, true, K_aligned>(
Qs,
Ks,
gemm_k_iterations,
loader_k,
mma_qk_op,
tgp_bm,
tgp_bn_qk);
} else if (tgp_bn_qk == BN) {
gemm_loop<false, true, K_aligned>(
Qs,
Ks,
gemm_k_iterations,
loader_k,
mma_qk_op,
tgp_bm,
tgp_bn_qk);
} else if (tgp_bm == BM) {
gemm_loop<true, false, K_aligned>(
Qs,
Ks,
gemm_k_iterations,
loader_k,
mma_qk_op,
tgp_bm,
tgp_bn_qk);
} else {
gemm_loop<false, false, K_aligned>(
Qs,
Ks,
gemm_k_iterations,
loader_k,
mma_qk_op,
tgp_bm,
tgp_bn_qk);
}
}
mma_qk_op.store_result_to_tgp_memory(
Ss, BN + tgp_padding, short2(BN, BM));
threadgroup_barrier(mem_flags::mem_threadgroup);
rescale_ss(
Ss,
Corrections,
simd_group_id,
simd_lane_id,
short2(tgp_bn_qk, tgp_bm),
params->alpha);
loader_v.load_safe(short2(BK, tgp_bn_qk));
threadgroup_barrier(mem_flags::mem_threadgroup);
threadgroup float* o_scales = Corrections + 2 * (BM + float_padding);
mma_softmax_sv_op.rescale_output(o_scales);
mma_softmax_sv_op.mma(Ss, Vs);
threadgroup float* final_output_scales =
Corrections + 3 * (BM + float_padding);
mma_softmax_sv_op.rescale_output(final_output_scales);
loader_v.next();
loader_k.next(BN);
mma_qk_op.clear_results();
}
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_softmax_sv_op.store_result_safe(O, params->ldo, short2(BK, tgp_bm));
}
};
template <
typename T,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_q,
bool transpose_k,
bool transpose_v,
bool MN_aligned,
bool K_aligned>
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void attention(
const device T* Q [[buffer(0)]],
const device T* K [[buffer(1)]],
const device T* V [[buffer(2)]],
device T* O [[buffer(3)]],
const constant MLXFastAttentionParams* params [[buffer(4)]],
const constant int* batch_shape [[buffer(6)]],
const constant size_t* batch_strides [[buffer(7)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
using attention_kernel = FastAttentionKernel<
T,
T,
BM,
BN,
BK,
WM,
WN,
transpose_q,
transpose_k,
transpose_v,
MN_aligned,
K_aligned>;
// Adjust for batch
if (params->batch_ndim > 1) {
const constant size_t* Q_bstrides = batch_strides;
const constant size_t* KV_bstrides = batch_strides + params->batch_ndim;
ulong2 batch_offsets = elem_to_loc_broadcast(
tid.z, batch_shape, Q_bstrides, KV_bstrides, params->batch_ndim);
Q += batch_offsets.x;
K += batch_offsets.y;
V += batch_offsets.y;
} else {
Q += params->batch_stride_q * tid.z;
K += params->batch_stride_k * tid.z;
V += params->batch_stride_v * tid.z;
}
// same shape as input
O += params->batch_stride_o * tid.z;
threadgroup T Qs[attention_kernel::tgp_mem_size_q];
threadgroup T Ss[attention_kernel::tgp_mem_size_s];
threadgroup float Corrections[attention_kernel::tgp_mem_size_corrections];
if (attention_kernel::share_kv_smem) {
threadgroup T Ks[attention_kernel::tgp_mem_size_k];
threadgroup T* Vs = Ks; //[attention_kernel::tgp_mem_size_v];
attention_kernel::run(
Q,
K,
V,
O,
params,
Qs,
Ks,
Ss,
Vs,
Corrections,
simd_lane_id,
simd_group_id,
tid,
lid);
} else {
threadgroup T Ks[attention_kernel::tgp_mem_size_k];
threadgroup T Vs[attention_kernel::tgp_mem_size_v];
attention_kernel::run(
Q,
K,
V,
O,
params,
Qs,
Ks,
Ss,
Vs,
Corrections,
simd_lane_id,
simd_group_id,
tid,
lid);
}
}
// clang-format off
// SDPA full instantiations
#define instantiate_fast_inference_self_attention_kernel( \
itype, otype, bm, bn, bk, wm, wn) \
template [[host_name("steel_gemm_attention_bm_" #bm "_bn_" #bn "_bk_" #bk \
"_itype_" #itype)]] [[kernel]] void \
attention<itype, bm, bn, bk, wm, wn, false, true, false, false, true>( \
const device itype* Q [[buffer(0)]], \
const device itype* K [[buffer(1)]], \
const device itype* V [[buffer(2)]], \
device otype* O [[buffer(3)]], \
const constant MLXFastAttentionParams* params [[buffer(4)]], \
const constant int* batch_shape [[buffer(6)]], \
const constant size_t* batch_strides [[buffer(7)]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]]);
instantiate_fast_inference_self_attention_kernel(
float,
float,
16,
16,
64,
2,
2);
instantiate_fast_inference_self_attention_kernel(
float,
float,
16,
16,
128,
2,
2);
instantiate_fast_inference_self_attention_kernel(half, half, 16, 16, 64, 2, 2);
instantiate_fast_inference_self_attention_kernel(half, half, 16, 16, 128, 2, 2);
// SDPA vector instantiations
#define instantiate_sdpa_vector(type, head_dim) \
instantiate_kernel("sdpa_vector_" #type "_" #head_dim, sdpa_vector, type, head_dim) \
instantiate_kernel("sdpa_vector_2pass_1_" #type "_" #head_dim, sdpa_vector_2pass_1, type, head_dim) \
instantiate_kernel("sdpa_vector_2pass_2_" #type "_" #head_dim, sdpa_vector_2pass_2, type, head_dim)
#define instantiate_sdpa_vector(type, head_dim) \
template [[host_name("sdpa_vector_" #type "_" #head_dim)]] \
[[kernel]] void sdpa_vector<type, head_dim>( \
const device type* queries [[buffer(0)]], \
const device type* keys [[buffer(1)]], \
const device type* values [[buffer(2)]], \
device type* out [[buffer(3)]], \
const constant int& gqa_factor, \
const constant int& N, \
const constant size_t& k_stride, \
const constant size_t& v_stride, \
const constant float& scale, \
uint3 tid [[threadgroup_position_in_grid]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);
#define instantiate_sdpa_vector_heads(type) \
instantiate_sdpa_vector(type, 64) \

View File

@@ -0,0 +1,42 @@
//
// scaled_dot_product_attention_params.h
// mlx
#pragma once
struct MLXFastAttentionParams {
const int M;
const int N;
const int K;
const int ldq; // ldq == ldo
const int ldk;
const int ldv;
const int lds;
const int ldo;
const int tiles_n;
const int tiles_m;
const int batch_stride_q;
const int batch_stride_k;
const int batch_stride_v;
const int batch_stride_o;
const int swizzle_log;
const int gemm_n_iterations_aligned;
const int gemm_k_iterations_aligned;
const int gemm_sv_m_block_iterations;
const int batch_ndim;
const float alpha;
};
struct MLXScaledDotProductAttentionParams {
// Associated dimensions & transposition information
const uint QUERY_SEQUENCE_LENGTH = 1;
const uint N_Q_HEADS = 32;
const uint N_KV_HEADS = 32;
const uint KV_TILES = 1;
const float INV_ALPHA = 0.08838834764831843f;
};

View File

@@ -10,8 +10,7 @@ template <
typename Op,
int NIDX,
bool UPD_ROW_CONTIG,
int NWORK,
typename LocT>
int NWORK>
METAL_FUNC void scatter_impl(
const device T* updates,
device mlx_atomic<T>* out,
@@ -29,31 +28,29 @@ METAL_FUNC void scatter_impl(
Op op;
auto ind_idx = gid.y * NWORK;
LocT out_offset = 0;
size_t out_offset = 0;
if (upd_size > 1) {
out_offset = elem_to_loc<size_t, LocT>(
gid.x, upd_shape + indices.ndim, out_strides, out_ndim);
out_offset =
elem_to_loc(gid.x, upd_shape + indices.ndim, out_strides, out_ndim);
}
for (int j = 0; j < NWORK && ind_idx < idx_size; ++j, ind_idx++) {
LocT out_idx = out_offset;
size_t out_idx = out_offset;
for (int i = 0; i < NIDX; ++i) {
auto idx_loc = indices.row_contiguous[i]
? ind_idx
: elem_to_loc<size_t, LocT>(
: elem_to_loc(
ind_idx,
&indices.shapes[indices.ndim * i],
&indices.strides[indices.ndim * i],
indices.ndim);
auto ax = axes[i];
auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], out_shape[ax]);
out_idx +=
static_cast<LocT>(idx_val) * static_cast<LocT>(out_strides[ax]);
out_idx += idx_val * out_strides[ax];
}
auto upd_idx = ind_idx * static_cast<LocT>(upd_size) + gid.x;
auto upd_idx = ind_idx * upd_size + gid.x;
if constexpr (!UPD_ROW_CONTIG) {
upd_idx =
elem_to_loc<size_t, LocT>(upd_idx, upd_shape, upd_strides, upd_ndim);
upd_idx = elem_to_loc(upd_idx, upd_shape, upd_strides, upd_ndim);
}
op.atomic_update(out, updates[upd_idx], out_idx);
}

View File

@@ -21,7 +21,8 @@ template <typename T, int D>
constexpr int BN = 32;
constexpr int BD = 32;
constexpr int elem_per_thread = D / BD;
constexpr int stride = BN * D;
const int stride = BN * D;
typedef float U;
@@ -83,6 +84,7 @@ template <typename T, int D>
keys += stride;
values += stride;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Each thread has a partial part of the output so we need to combine them.
@@ -112,181 +114,3 @@ template <typename T, int D>
}
}
}
template <typename T, int D>
[[kernel]] void sdpa_vector_2pass_1(
const device T* queries [[buffer(0)]],
const device T* keys [[buffer(1)]],
const device T* values [[buffer(2)]],
device float* out [[buffer(3)]],
device float* sums [[buffer(4)]],
device float* maxs [[buffer(5)]],
const constant int& gqa_factor,
const constant int& N,
const constant size_t& k_stride,
const constant size_t& v_stride,
const constant float& scale,
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
constexpr int BN = 8;
constexpr int BD = 32;
constexpr int elem_per_thread = D / BD;
constexpr int stride = BN * D;
constexpr int blocks = 32;
typedef float U;
thread U q[elem_per_thread];
thread U k[elem_per_thread];
thread U o[elem_per_thread];
threadgroup U outputs[BN * BD];
threadgroup U max_scores[BN];
threadgroup U sum_exp_scores[BN];
// Adjust positions
const int block_idx = tid.z;
const int head_idx = tid.y;
const int kv_head_idx = head_idx / gqa_factor;
queries += head_idx * D + simd_lid * elem_per_thread;
keys += kv_head_idx * k_stride + (block_idx * BN + simd_gid) * D +
simd_lid * elem_per_thread;
values += kv_head_idx * v_stride + (block_idx * BN + simd_gid) * D +
simd_lid * elem_per_thread;
out += head_idx * blocks * D + block_idx * D + simd_lid * elem_per_thread;
sums += head_idx * blocks + block_idx;
maxs += head_idx * blocks + block_idx;
// Read the query and 0 the output accumulator
for (int i = 0; i < elem_per_thread; i++) {
q[i] = static_cast<U>(scale) * queries[i];
}
for (int i = 0; i < elem_per_thread; i++) {
o[i] = 0;
}
U max_score = -1e9;
U sum_exp_score = 0;
// For each key
for (int i = block_idx * BN + simd_gid; i < N; i += blocks * BN) {
// Read the key
for (int i = 0; i < elem_per_thread; i++) {
k[i] = keys[i];
}
// Compute the i-th score
U score = 0;
for (int i = 0; i < elem_per_thread; i++) {
score += q[i] * k[i];
}
score = simd_sum(score);
// Update the accumulators
U new_max = max(max_score, score);
U factor = fast::exp(max_score - new_max);
U exp_score = fast::exp(score - new_max);
max_score = new_max;
sum_exp_score = sum_exp_score * factor + exp_score;
// Update the output accumulator
for (int i = 0; i < elem_per_thread; i++) {
o[i] = o[i] * factor + exp_score * values[i];
}
// Move the pointers to the next kv
keys += blocks * stride;
values += blocks * stride;
}
// Each thread has a partial part of the output so we need to combine them.
// First let's communicate the max and sum_exp
if (simd_lid == 0) {
max_scores[simd_gid] = max_score;
sum_exp_scores[simd_gid] = sum_exp_score;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
max_score = (simd_lid < BN) ? max_scores[simd_lid] : -1e9;
U new_max = simd_max(max_score);
U factor = fast::exp(max_score - new_max);
sum_exp_score = (simd_lid < BN) ? sum_exp_scores[simd_lid] : 0;
sum_exp_score = simd_sum(sum_exp_score * factor);
// Write the sum and new max
if (simd_gid == 0) {
sums[0] = sum_exp_score;
maxs[0] = new_max;
}
// Now we need to aggregate all the outputs
for (int i = 0; i < elem_per_thread; i++) {
outputs[simd_lid * BN + simd_gid] =
o[i] * fast::exp(max_scores[simd_gid] - new_max);
threadgroup_barrier(mem_flags::mem_threadgroup);
// And write the output
if (simd_gid == 0) {
U output = outputs[simd_lid * BN];
for (int j = 1; j < BN; j++) {
output += outputs[simd_lid * BN + j];
}
out[i] = static_cast<T>(output);
}
threadgroup_barrier(mem_flags::mem_threadgroup);
}
}
template <typename T, int D>
[[kernel]] void sdpa_vector_2pass_2(
const device float* partials [[buffer(0)]],
const device float* sums [[buffer(1)]],
const device float* maxs [[buffer(2)]],
device T* out [[buffer(3)]],
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
constexpr int BN = 32;
constexpr int BD = 32;
constexpr int elem_per_thread = D / BD;
constexpr int blocks = 32;
typedef float U;
thread U o[elem_per_thread];
threadgroup U outputs[BN * BD];
// Adjust positions
const int head_idx = tid.y;
partials += head_idx * blocks * D + simd_gid * D + simd_lid * elem_per_thread;
sums += head_idx * blocks;
maxs += head_idx * blocks;
out += head_idx * D + simd_gid * elem_per_thread;
// First everybody reads the max and sum_exp
U max_score = maxs[simd_lid];
U new_max = simd_max(max_score);
U factor = fast::exp(max_score - new_max);
U sum_exp_score = simd_sum(sums[simd_lid] * factor);
// Now read the block into registers and then use shared memory to transpose
// it
for (int i = 0; i < elem_per_thread; i++) {
o[i] = partials[i];
}
for (int i = 0; i < elem_per_thread; i++) {
outputs[simd_lid * BD + simd_gid] = o[i];
threadgroup_barrier(mem_flags::mem_threadgroup);
o[i] = simd_sum(outputs[simd_gid * BD + simd_lid] * factor) / sum_exp_score;
threadgroup_barrier(mem_flags::mem_threadgroup);
}
// And write the output
if (simd_lid == 0) {
for (int i = 0; i < elem_per_thread; i++) {
out[i] = static_cast<T>(o[i]);
}
}
}

View File

@@ -6,6 +6,8 @@
using namespace metal;
// clang-format off
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/softmax.h"

View File

@@ -3,6 +3,8 @@
#include <metal_stdlib>
// clang-format off
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/sort.h"

View File

@@ -1,296 +0,0 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include "mlx/backend/metal/kernels/steel/attn/loader.h"
#include "mlx/backend/metal/kernels/steel/attn/mma.h"
#include "mlx/backend/metal/kernels/steel/attn/params.h"
#include "mlx/backend/metal/kernels/steel/attn/transforms.h"
#include "mlx/backend/metal/kernels/steel/gemm/params.h"
#include "mlx/backend/metal/kernels/steel/utils.h"
using namespace metal;
///////////////////////////////////////////////////////////////////////////////
// GEMM kernel class
///////////////////////////////////////////////////////////////////////////////
namespace mlx {
namespace steel {
template <bool M_aligned, bool N_aligned, bool K_aligned>
struct LoopAlignment {};
template <
typename T,
typename U,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_a,
bool transpose_b,
bool MN_aligned,
bool K_aligned,
typename AccumType = typename AccumHelper<T>::accum_type,
typename Epilogue = TransformNone<U, AccumType>>
struct GEMMKernel {
STEEL_CONST short tgp_padding_a = 16 / sizeof(T);
STEEL_CONST short tgp_padding_b = 16 / sizeof(T);
STEEL_CONST short tgp_mem_size_a =
transpose_a ? BK * (BM + tgp_padding_a) : BM * (BK + tgp_padding_a);
STEEL_CONST short tgp_mem_size_b =
transpose_b ? BN * (BK + tgp_padding_b) : BK * (BN + tgp_padding_b);
STEEL_CONST short tgp_mem_size = tgp_mem_size_a + tgp_mem_size_b;
STEEL_CONST short tgp_size = WM * WN * 32;
using loader_a_t = BlockLoader<
T,
transpose_a ? BK : BM,
transpose_a ? BM : BK,
transpose_a ? BM + tgp_padding_a : BK + tgp_padding_a,
!transpose_a,
tgp_size>;
using loader_b_t = BlockLoader<
T,
transpose_b ? BN : BK,
transpose_b ? BK : BN,
transpose_b ? BK + tgp_padding_b : BN + tgp_padding_b,
transpose_b,
tgp_size>;
using mma_t = BlockMMA<
T,
U,
BM,
BN,
BK,
WM,
WN,
transpose_a,
transpose_b,
transpose_a ? BM + tgp_padding_a : BK + tgp_padding_a,
transpose_b ? BK + tgp_padding_b : BN + tgp_padding_b,
AccumType,
Epilogue>;
/* Main kernel function */
template <bool M_aligned, bool N_aligned, bool K_aligned_>
static METAL_FUNC void gemm_loop(
threadgroup T* As [[threadgroup(0)]],
threadgroup T* Bs [[threadgroup(1)]],
const int gemm_k_iterations,
thread loader_a_t& loader_a,
thread loader_b_t& loader_b,
thread mma_t& mma_op,
thread const short& tgp_bm,
thread const short& tgp_bn,
thread const short& lbk,
LoopAlignment<M_aligned, N_aligned, K_aligned_> l = {}) {
// Appease the compiler
(void)l;
short2 tile_dims_A = transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm);
short2 tile_dims_B = transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK);
for (int k = 0; k < gemm_k_iterations; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup
if (M_aligned) {
loader_a.load_unsafe();
} else {
loader_a.load_safe(tile_dims_A);
}
if (N_aligned) {
loader_b.load_unsafe();
} else {
loader_b.load_safe(tile_dims_B);
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
if (!K_aligned_) {
threadgroup_barrier(mem_flags::mem_threadgroup);
short2 tile_dims_A_last =
transpose_a ? short2(tgp_bm, lbk) : short2(lbk, tgp_bm);
short2 tile_dims_B_last =
transpose_b ? short2(lbk, tgp_bn) : short2(tgp_bn, lbk);
loader_a.load_safe(tile_dims_A_last);
loader_b.load_safe(tile_dims_B_last);
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_op.mma(As, Bs);
}
}
/* Main kernel function */
static METAL_FUNC void run(
const device T* A [[buffer(0)]],
const device T* B [[buffer(1)]],
device U* D [[buffer(2)]],
const constant GEMMParams* params [[buffer(3)]],
threadgroup T* As [[threadgroup(0)]],
threadgroup T* Bs [[threadgroup(1)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
// Pacifying compiler
(void)lid;
const int tid_y = ((tid.y) << params->swizzle_log) +
((tid.x) & ((1 << params->swizzle_log) - 1));
const int tid_x = (tid.x) >> params->swizzle_log;
if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
return;
}
threadgroup_barrier(mem_flags::mem_none);
// Find block in A, B, C
const int c_row = tid_y * BM;
const int c_col = tid_x * BN;
const size_t c_row_long = size_t(c_row);
const size_t c_col_long = size_t(c_col);
A += transpose_a ? c_row_long : c_row_long * params->lda;
B += transpose_b ? c_col_long * params->ldb : c_col_long;
D += c_row_long * params->ldd + c_col_long;
// Prepare threadgroup loading operations
thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);
// Prepare threadgroup mma operation
thread mma_t mma_op(simd_group_id, simd_lane_id);
int gemm_k_iterations = params->gemm_k_iterations_aligned;
///////////////////////////////////////////////////////////////////////////////
// MNK aligned loop
if (MN_aligned) {
for (int k = 0; k < gemm_k_iterations; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup
loader_a.load_unsafe();
loader_b.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
threadgroup_barrier(mem_flags::mem_none);
// Loop tail
if (!K_aligned) {
int lbk = params->K - params->gemm_k_iterations_aligned * BK;
short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM);
short2 tile_dims_B = transpose_b ? short2(lbk, BN) : short2(BN, lbk);
loader_a.load_safe(tile_dims_A);
loader_b.load_safe(tile_dims_B);
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_op.mma(As, Bs);
}
// Store results to device memory
mma_op.store_result(D, params->ldd);
return;
}
///////////////////////////////////////////////////////////////////////////////
// MN unaligned loop
else { // Loop over K - unaligned case
short tgp_bm = min(BM, params->M - c_row);
short tgp_bn = min(BN, params->N - c_col);
short leftover_bk = params->K - params->gemm_k_iterations_aligned * BK;
if (tgp_bm == BM && tgp_bn == BN) {
gemm_loop<true, true, K_aligned>(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk);
mma_op.store_result(D, params->ldd);
return;
} else if (tgp_bn == BN) {
gemm_loop<false, true, K_aligned>(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk);
mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
return;
} else if (tgp_bm == BM) {
gemm_loop<true, false, K_aligned>(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk);
mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
return;
} else {
gemm_loop<false, false, K_aligned>(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk);
mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
return;
}
}
}
};
} // namespace steel
} // namespace mlx

View File

@@ -1,349 +0,0 @@
// Copyright © 2024 Apple Inc.
using namespace mlx::steel;
///////////////////////////////////////////////////////////////////////////////
// GEMM kernels
///////////////////////////////////////////////////////////////////////////////
constant bool align_Q [[function_constant(200)]];
constant bool align_K [[function_constant(201)]];
template <typename T>
struct TransformScale {
T scale;
METAL_FUNC TransformScale(T scale_) : scale(scale_) {}
METAL_FUNC T apply(T x) const {
return scale * x;
}
};
struct MaxOp {
template <typename T>
METAL_FUNC static constexpr T apply(T x, T y) {
return metal::max(x, y);
}
};
struct SumOp {
template <typename T>
METAL_FUNC static constexpr T apply(T x, T y) {
return x + y;
}
};
struct MulOp {
template <typename T>
METAL_FUNC static constexpr T apply(T x, T y) {
return x * y;
}
};
struct SubOp {
template <typename T>
METAL_FUNC static constexpr T apply(T x, T y) {
return x - y;
}
};
struct ExpSubOp {
template <typename T>
METAL_FUNC static constexpr T apply(T x, T y) {
return fast::exp(x - y);
}
};
struct DivOp {
template <typename T>
METAL_FUNC static constexpr T apply(T x, T y) {
return x / y;
}
};
// clang-format off
template <
typename T,
int BQ,
int BK,
int BD,
int WM,
int WN,
typename AccumType = float>
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void attention(
const device T* Q [[buffer(0)]],
const device T* K [[buffer(1)]],
const device T* V [[buffer(2)]],
device T* O [[buffer(3)]],
const constant AttnParams* params [[buffer(4)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) { // clang-format on
// Pacifying compiler
(void)lid;
// Move to correct block
ulong3 tidl{tid.x, tid.y, tid.z};
Q += tidl.z * params->Q_strides[0] + // Batch
tidl.y * params->Q_strides[1] + // Head
tidl.x * BQ * params->Q_strides[2]; // Seqeunce
ulong kv_head_idx = int(tid.y) / params->gqa_factor;
K += tidl.z * params->K_strides[0] + // Batch
kv_head_idx * params->K_strides[1]; // Head
V += tidl.z * params->V_strides[0] + // Batch
kv_head_idx * params->V_strides[1]; // Head
O += tidl.z * params->O_strides[0] + // Batch
tidl.y * params->O_strides[1] + // Head
tidl.x * BQ * params->O_strides[2]; // Seqeunce
// Prepare threadgroup memory
constexpr short padQ = 0; // 16 / sizeof(T);
constexpr short padK = 0; // 16 / sizeof(T);
constexpr short padV = 0; // 16 / sizeof(T);
constexpr short LDQ_tgp = BD + padQ;
constexpr short LDK_tgp = BK + padK;
constexpr short LDV_tgp = BD + padV;
threadgroup T Qs[BQ * (BD + padQ)];
threadgroup T Ks[(BK + padK) * BD];
threadgroup T Vs[BK * (BD + padV)];
// Prepare block loaders
using QBlockLoader = BlockLoaderT<
/* typename T = */ T,
/* short BROWS = */ BQ,
/* short BCOLS = */ BD,
/* short kDstStrRow = */ LDQ_tgp,
/* short kDstStrCol = */ 1,
/* short reduction_dim = */ 1,
/* short tgp_size = */ WM * WN * 32>;
// K is loaded in transposed
using KBlockLoader = BlockLoaderT<
/* typename T = */ T,
/* short BROWS = */ BK,
/* short BCOLS = */ BD,
/* short kDstStrRow = */ 1,
/* short kDstStrCol = */ LDK_tgp,
/* short reduction_dim = */ 0,
/* short tgp_size = */ WM * WN * 32>;
using VBlockLoader = BlockLoaderT<
/* typename T = */ T,
/* short BROWS = */ BK,
/* short BCOLS = */ BD,
/* short kDstStrRow = */ LDV_tgp,
/* short kDstStrCol = */ 1,
/* short reduction_dim = */ 0,
/* short tgp_size = */ WM * WN * 32>;
QBlockLoader loader_q(
Q, params->Q_strides[2], Qs, simd_group_id, simd_lane_id);
KBlockLoader loader_k(
K, params->K_strides[2], Ks, simd_group_id, simd_lane_id);
VBlockLoader loader_v(
V, params->V_strides[2], Vs, simd_group_id, simd_lane_id);
TransformScale<T> ts(static_cast<T>(params->scale));
// Prepare MMA tiles
constexpr short kFragSize = 8; // MMAFrag size
using MMAFrag_acc_t = BaseMMAFrag<AccumType, kFragSize, kFragSize>;
constexpr int kNWarps = WM * WN;
static_assert(
BQ >= (kNWarps * kFragSize) && BQ % (kNWarps * kFragSize) == 0,
"Each simdgroup must host atleast 1 simdgroup matrix along Q sequence.");
// Q seq frags per warp
constexpr int TQ = BQ / (kNWarps * kFragSize);
// KV sequence frags (all warps load the same frags)
constexpr int TK = BK / kFragSize;
// HeadDim frags (all warps load the same frags)
constexpr int TD = BD / kFragSize;
static_assert(TQ == 1, "Check TQ");
MMATile<AccumType, TQ, 1, MMAFrag_acc_t> Qtile;
MMATile<AccumType, 1, TK, MMAFrag_acc_t> Ktile;
MMATile<AccumType, TQ, TK, MMAFrag_acc_t> Stile;
MMATile<AccumType, TK, TD, MMAFrag_acc_t> Vtile;
MMATile<AccumType, TQ, TD, MMAFrag_acc_t> Otile;
Otile.clear();
// Prepare mma tile offsets
const short2 simd_coord = MMAFrag_acc_t::get_coord(simd_lane_id);
const short sm = simd_coord.y;
const short sn = simd_coord.x;
const short tm = kFragSize * TQ * simd_group_id;
const short Qs_offset = (tm + sm) * LDQ_tgp + sn;
const short Ks_offset = sm * LDK_tgp + sn;
const short Vs_offset = sm * LDV_tgp + sn;
constexpr short Qs_tile_stride = kFragSize;
constexpr short Ks_tile_stride = kFragSize * LDK_tgp;
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load Q blocks apply scale
if (!align_Q && int(tid.x) == (params->NQ_aligned)) {
loader_q.load_safe(short2(BD, params->qL - params->NQ_aligned * BQ));
} else {
loader_q.load_unsafe();
}
loader_q.apply_inplace_op(ts);
// Init row reduction variables
constexpr short kRowsPT = decltype(Stile)::kRowsPerThread;
AccumType max_score[kRowsPT];
AccumType sum_score[kRowsPT] = {0};
// Init to -Inf
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kRowsPT; ++i) {
max_score[i] = Limits<AccumType>::min;
}
// Loop over KV seq length
for (int kb = 0; kb < params->NK; kb++) {
// Load K block and apply scale
threadgroup_barrier(mem_flags::mem_threadgroup);
if (!align_K && kb == (params->NK_aligned)) {
loader_k.load_safe(short2(BD, params->kL - params->NK_aligned * BK));
} else {
loader_k.load_unsafe();
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Do S = Q @ K.T
Stile.clear();
for (short dd = 0; dd < TD; dd++) {
simdgroup_barrier(mem_flags::mem_none);
Qtile.template load<T, 1, 1, LDQ_tgp, 1>(
&Qs[Qs_offset + dd * Qs_tile_stride]);
Ktile.template load<T, 1, 1, LDK_tgp, 1>(
&Ks[Ks_offset + dd * Ks_tile_stride]);
simdgroup_barrier(mem_flags::mem_none);
tile_matmad(Stile, Qtile, Ktile, Stile);
}
// Mask out of length sequence
if (!align_K && kb == (params->NK_aligned)) {
using stile_t = decltype(Stile);
using selem_t = typename stile_t::elem_type;
constexpr auto neg_inf = -metal::numeric_limits<selem_t>::infinity();
const short lim = params->kL - params->NK_aligned * BK;
STEEL_PRAGMA_UNROLL
for (short i = 0; i < stile_t::kTileRows; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < stile_t::kTileCols; j++) {
short col_pos = sn + (j * stile_t::kFragCols);
STEEL_PRAGMA_UNROLL
for (short jj = 0; jj < stile_t::MMAFrag_t::kElemCols; jj++) {
if ((col_pos + jj) >= lim) {
Stile.frag_at(i, j)[jj] = neg_inf;
}
}
}
}
}
simdgroup_barrier(mem_flags::mem_none);
// Load V blocks
if (!align_K && kb == (params->NK_aligned)) {
loader_v.load_safe(short2(BD, params->kL - params->NK_aligned * BK));
} else {
loader_v.load_unsafe();
}
// Do softmax
// Temp variables
AccumType new_max[kRowsPT];
AccumType factor[kRowsPT];
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kRowsPT; ++i) {
new_max[i] = max_score[i];
}
// Row max
Stile.template row_reduce<MaxOp>(new_max);
// exp(Si - rowmax(Si))
Stile.template row_bin_op<ExpSubOp>(new_max);
// Factor exp(rowmax(Si) - rowmax(Si-1))
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kRowsPT; ++i) {
factor[i] = fast::exp(max_score[i] - new_max[i]);
}
// Save max for next iteration
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kRowsPT; ++i) {
max_score[i] = new_max[i];
}
// Row Sum
AccumType sum_score_tmp[kRowsPT] = {0};
Stile.template row_reduce<SumOp>(sum_score_tmp);
// Update norm
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kRowsPT; ++i) {
sum_score[i] = sum_score[i] * factor[i] + sum_score_tmp[i];
}
// Update O
Otile.template row_bin_op<MulOp>(factor);
// Load V into registers
threadgroup_barrier(mem_flags::mem_threadgroup);
Vtile.template load<T, 1, 1, LDV_tgp, 1>(&Vs[Vs_offset]);
simdgroup_barrier(mem_flags::mem_none);
// Do O = S @ V
tile_matmad(Otile, Stile, Vtile, Otile);
// Prepare for next iteration
loader_k.next();
loader_v.next();
}
// Normalize output
Otile.template row_bin_op<DivOp>(sum_score);
threadgroup_barrier(mem_flags::mem_none);
// Store results
O += (tm + sm) * params->O_strides[2] + sn;
if (!align_Q && int(tid.x) == (params->NQ_aligned)) {
auto dst_tile_dims =
short2(BD - sn, params->qL - BQ * params->NQ_aligned - (tm + sm));
if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
return;
Otile.template store_safe<T, 1, 1>(O, params->O_strides[2], dst_tile_dims);
} else {
Otile.template store<T, 1, 1>(O, params->O_strides[2]);
}
}

View File

@@ -1,31 +0,0 @@
// Copyright © 2024 Apple Inc.
// clang-format off
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/steel/attn/attn.h"
#include "mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h"
#define instantiate_attn(tname, dtype, bq, bk, bd, wm, wn) \
template [[host_name("steel_attention_" #tname "_bq" #bq "_bk" #bk "_bd" #bd "_wm" #wm "_wn" #wn)]] \
[[kernel]] void attention<dtype, bq, bk, bd, wm, wn, float>( \
const device dtype* Q [[buffer(0)]], \
const device dtype* K [[buffer(1)]], \
const device dtype* V [[buffer(2)]], \
device dtype* O [[buffer(3)]],\
const constant AttnParams* params [[buffer(4)]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]]);
#define instantiate_attn_shapes_helper(iname, itype) \
instantiate_attn(iname, itype, 32, 16, 128, 4, 1) \
instantiate_attn(iname, itype, 32, 32, 80, 4, 1) \
instantiate_attn(iname, itype, 32, 32, 64, 4, 1)
instantiate_attn_shapes_helper(float16, half);
instantiate_attn_shapes_helper(bfloat16, bfloat16_t);
instantiate_attn_shapes_helper(float32, float);
// clang-format on

View File

@@ -1,264 +0,0 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include "mlx/backend/metal/kernels/steel/defines.h"
///////////////////////////////////////////////////////////////////////////////
// Loading helper
///////////////////////////////////////////////////////////////////////////////
namespace mlx {
namespace steel {
template <
typename T,
short BROWS,
short BCOLS,
short dst_ld,
short reduction_dim,
short tgp_size,
short alignment = 1,
short n_reads = (BCOLS * BROWS) / (tgp_size),
short TCOLS = BCOLS / n_reads,
short TROWS = tgp_size / TCOLS>
struct BlockLoader {
STEEL_CONST short n_rows = (BROWS + TROWS - 1) / TROWS;
STEEL_CONST short vec_size = n_reads;
// Leading dimension for src
const int src_ld;
const int tile_stride;
// Thread location indices
const short thread_idx;
const short bi;
const short bj;
// threadgroup and device memory
threadgroup T* dst;
const device T* src;
struct alignas(alignment * sizeof(T)) ReadVector {
uint8_t v[sizeof(T) * vec_size];
};
/* Constructor */
METAL_FUNC BlockLoader(
const device T* src_,
const int src_ld_,
threadgroup T* dst_,
ushort simd_group_id [[simdgroup_index_in_threadgroup]],
ushort simd_lane_id [[thread_index_in_simdgroup]])
: src_ld(src_ld_),
tile_stride(reduction_dim ? BCOLS : BROWS * src_ld),
thread_idx(simd_group_id * 32 + simd_lane_id),
bi(thread_idx / TCOLS),
bj(vec_size * (thread_idx % TCOLS)),
dst(dst_ + bi * dst_ld + bj),
src(src_ + bi * src_ld + bj) {}
/* Apply operation to threadgroup without bound checking */
template <typename UnaryOp>
METAL_FUNC void apply_inplace_op(thread const UnaryOp& op) const {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < BROWS; i += TROWS) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
dst[i * dst_ld + j] = op.apply(dst[i * dst_ld + j]);
}
}
}
/* Load from device memory into threadgroup memory - without bound checking */
METAL_FUNC void load_unsafe() const {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < BROWS; i += TROWS) {
*((threadgroup ReadVector*)(&dst[i * dst_ld])) =
*((const device ReadVector*)(&src[i * src_ld]));
}
}
/* Load from device memory into threadgroup memory - with bound checking */
METAL_FUNC void load_safe(short2 src_tile_dim) const {
src_tile_dim = src_tile_dim - short2(bj, bi);
// Skip loading if thread has no valid reads
if (src_tile_dim.x <= 0 || src_tile_dim.y <= 0) {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < BROWS; i += TROWS) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
dst[i * dst_ld + j] = T(0);
}
}
return;
}
// Use fast thread memory for bound checks
bool tmp_idx[vec_size];
T tmp_val[vec_size];
STEEL_PRAGMA_UNROLL
for (short i = 0; i < BROWS; i += TROWS) {
// Make sure tmp_idx only contains valid indices
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
tmp_idx[j] = (i < src_tile_dim.y) && (j < src_tile_dim.x);
}
// Read valid indices into tmp_val
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)];
}
// Zero out uneeded values
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0);
}
// Copy values to threadgroup memory
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
dst[i * dst_ld + j] = tmp_val[j];
}
}
}
/* Iteration helper */
METAL_FUNC void next() {
src += tile_stride;
}
};
template <int R, int C>
struct CShape {
STEEL_CONST int kRows = R;
STEEL_CONST int kCols = C;
};
template <
typename T,
short BROWS,
short BCOLS,
short kDstStrRow,
short kDstStrCol,
short reduction_dim,
short tgp_size,
short n_reads = (BCOLS * BROWS) / (tgp_size),
short TCOLS = BCOLS / n_reads,
short TROWS = tgp_size / TCOLS>
struct BlockLoaderT {
STEEL_CONST short n_rows = (BROWS + TROWS - 1) / TROWS;
STEEL_CONST short vec_size = n_reads;
// Leading dimension for src
const int src_ld;
const int tile_stride;
// Thread location indices
const short thread_idx;
const short bi;
const short bj;
// threadgroup and device memory
threadgroup T* dst;
const device T* src;
/* Constructor */
METAL_FUNC BlockLoaderT(
const device T* src_,
const int src_ld_,
threadgroup T* dst_,
ushort simd_group_id [[simdgroup_index_in_threadgroup]],
ushort simd_lane_id [[thread_index_in_simdgroup]])
: src_ld(src_ld_),
tile_stride(reduction_dim ? BCOLS : BROWS * src_ld),
thread_idx(simd_group_id * 32 + simd_lane_id),
bi(thread_idx / TCOLS),
bj(vec_size * (thread_idx % TCOLS)),
dst(dst_ + bi * kDstStrRow + bj * kDstStrCol),
src(src_ + bi * src_ld + bj) {}
/* Apply operation to threadgroup without bound checking */
template <typename UnaryOp>
METAL_FUNC void apply_inplace_op(thread const UnaryOp& op) const {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < BROWS; i += TROWS) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
dst[i * kDstStrRow + j * kDstStrCol] =
op.apply(dst[i * kDstStrRow + j * kDstStrCol]);
}
}
}
/* Load from device memory into threadgroup memory - without bound checking */
METAL_FUNC void load_unsafe() const {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < BROWS; i += TROWS) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
dst[i * kDstStrRow + j * kDstStrCol] = src[i * src_ld + j];
}
}
}
/* Load from device memory into threadgroup memory - with bound checking */
METAL_FUNC void load_safe(short2 src_tile_dim) const {
src_tile_dim = src_tile_dim - short2(bj, bi);
// Skip loading if thread has no valid reads
if (src_tile_dim.x <= 0 || src_tile_dim.y <= 0) {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < BROWS; i += TROWS) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
dst[i * kDstStrRow + j * kDstStrCol] = T(0);
}
}
return;
}
// Use fast thread memory for bound checks
bool tmp_idx[vec_size];
T tmp_val[vec_size];
STEEL_PRAGMA_UNROLL
for (short i = 0; i < BROWS; i += TROWS) {
// Make sure tmp_idx only contains valid indices
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
tmp_idx[j] = (i < src_tile_dim.y) && (j < src_tile_dim.x);
}
// Read valid indices into tmp_val
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)];
}
// Zero out uneeded values
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0);
}
// Copy values to threadgroup memory
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
dst[i * kDstStrRow + j * kDstStrCol] = tmp_val[j];
}
}
}
/* Iteration helper */
METAL_FUNC void next() {
src += tile_stride;
}
};
} // namespace steel
} // namespace mlx

View File

@@ -1,726 +0,0 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include <metal_simdgroup>
#include <metal_simdgroup_matrix>
#include <metal_stdlib>
#include "mlx/backend/metal/kernels/steel/attn/transforms.h"
#include "mlx/backend/metal/kernels/steel/defines.h"
#include "mlx/backend/metal/kernels/steel/utils/integral_constant.h"
using namespace metal;
///////////////////////////////////////////////////////////////////////////////
// MMA helper
///////////////////////////////////////////////////////////////////////////////
namespace mlx {
namespace steel {
template <typename RInt, typename CInt>
struct Shape2D {
RInt r;
CInt c;
Shape2D(RInt r_, CInt c_) : r(r_), c(c_) {}
};
template <typename Shape, typename Layout>
struct Layout2D {
Shape shape;
Layout layout;
};
template <typename T, int kFragRows_, int kFragCols_>
struct BaseMMAFrag {
static_assert(
kFragRows_ == 8,
"Only 8 x 8 fragment matrices are currently supported");
static_assert(
kFragCols_ == 8,
"Only 8 x 8 fragment matrices are currently supported");
};
template <typename T>
struct BaseMMAFrag<T, 8, 8> {
STEEL_CONST int kFragRows = 8;
STEEL_CONST int kFragCols = 8;
STEEL_CONST int kElemsPerFrag = (kFragRows * kFragCols) / 32;
STEEL_CONST int kElemRows = 1;
STEEL_CONST int kElemCols = 2;
static_assert(
kElemRows * kElemCols == kElemsPerFrag,
"MMAFrag shape is not consistent with MMAFrag size");
typedef metal::simdgroup_matrix<T, kFragRows, kFragCols> mat_type;
typedef metal::vec<T, kElemsPerFrag> frag_type;
typedef metal::vec<T, kElemRows> row_frag_type;
typedef metal::vec<T, kElemCols> col_frag_type;
METAL_FUNC static constexpr short2 get_coord(ushort simd_lane_id
[[thread_index_in_simdgroup]]) {
const short qid = simd_lane_id / 4;
const short fm = (qid & 4) + ((simd_lane_id / 2) % 4);
const short fn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;
return short2{fn, fm};
}
template <typename SrcPtrType, typename StrX, typename StrY>
METAL_FUNC static constexpr void
load(thread frag_type& dst, SrcPtrType src, StrX str_x, StrY str_y) {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kElemRows; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < kElemCols; j++) {
dst[i * kElemCols + j] = static_cast<T>(src[i * str_x + j * str_y]);
}
}
}
template <
typename SrcPtrType,
typename StrX,
typename StrY,
typename LimX,
typename LimY,
typename OffX,
typename OffY>
METAL_FUNC static constexpr void load_safe(
thread frag_type& dst,
SrcPtrType src,
StrX str_x,
StrY str_y,
LimX lim_x,
LimY lim_y,
OffX off_x = Int<0>{},
OffY off_y = Int<0>{}) {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kElemRows; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < kElemCols; j++) {
if ((off_x + i) < lim_x && (off_y + j) < lim_y) {
dst[i * kElemCols + j] =
static_cast<T>(src[(off_x + i) * str_x + (off_x + j) * str_y]);
} else {
dst[i * kElemCols + j] = T(0);
}
}
}
}
template <typename DstPtrType, typename StrX, typename StrY>
METAL_FUNC static constexpr void
store(const thread frag_type& src, DstPtrType dst, StrX str_x, StrY str_y) {
using U = pointer_element_t<DstPtrType>;
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kElemRows; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < kElemCols; j++) {
dst[i * str_x + j * str_y] = static_cast<U>(src[i * kElemCols + j]);
}
}
}
template <
typename DstPtrType,
typename StrX,
typename StrY,
typename LimX,
typename LimY,
typename OffX,
typename OffY>
METAL_FUNC static constexpr void store_safe(
const thread frag_type& src,
DstPtrType dst,
StrX str_x,
StrY str_y,
LimX lim_x,
LimY lim_y,
OffX off_x = Int<0>{},
OffY off_y = Int<0>{}) {
using U = pointer_element_t<DstPtrType>;
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kElemRows; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < kElemCols; j++) {
if ((off_x + i) < lim_x && (off_y + j) < lim_y) {
dst[(off_x + i) * str_x + (off_y + j) * str_y] =
static_cast<U>(src[i * kElemCols + j]);
}
}
}
}
METAL_FUNC static constexpr void mma(
thread frag_type& D,
thread frag_type& A,
thread frag_type& B,
thread frag_type& C) {
mat_type D_mat;
mat_type A_mat;
mat_type B_mat;
mat_type C_mat;
reinterpret_cast<thread frag_type&>(A_mat.thread_elements()) = A;
reinterpret_cast<thread frag_type&>(B_mat.thread_elements()) = B;
reinterpret_cast<thread frag_type&>(C_mat.thread_elements()) = C;
mma(D_mat, A_mat, B_mat, C_mat);
D = reinterpret_cast<thread frag_type&>(D_mat.thread_elements());
}
METAL_FUNC static constexpr void mma(
thread mat_type& D,
thread mat_type& A,
thread mat_type& B,
thread mat_type& C) {
simdgroup_multiply_accumulate(D, A, B, C);
}
template <typename Op>
METAL_FUNC static constexpr void row_reduce(
thread const frag_type& inp_vals,
thread T* reduced_vals) {
T thr_reduce = Op::apply(inp_vals.x, inp_vals.y);
T qgr_reduce = simd_shuffle_xor(thr_reduce, ushort(1));
qgr_reduce = Op::apply(thr_reduce, qgr_reduce);
T sgr_reduce = simd_shuffle_xor(qgr_reduce, ushort(8));
sgr_reduce = Op::apply(qgr_reduce, sgr_reduce);
reduced_vals[0] = Op::apply(reduced_vals[0], sgr_reduce);
}
template <typename Op>
METAL_FUNC static constexpr void row_bin_op(
thread frag_type& inp_vals,
thread T* row_vals) {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kElemRows; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < kElemCols; j++) {
inp_vals[i * kElemCols + j] =
Op::apply(inp_vals[i * kElemCols + j], row_vals[i]);
}
}
}
};
template <
typename T,
int kTileRows_,
int kTileCols_,
class MMAFrag_ = BaseMMAFrag<T, 8, 8>>
struct MMATile {
using MMAFrag_t = MMAFrag_;
using elem_type = T;
STEEL_CONST int kFragRows = MMAFrag_t::kFragRows;
STEEL_CONST int kFragCols = MMAFrag_t::kFragCols;
STEEL_CONST int kElemsPerFrag = MMAFrag_t::kElemsPerFrag;
STEEL_CONST int kTileRows = kTileRows_;
STEEL_CONST int kTileCols = kTileCols_;
STEEL_CONST int kRows = kTileRows * kFragRows;
STEEL_CONST int kCols = kTileCols * kFragCols;
STEEL_CONST int kNumFrags = kTileRows * kTileCols;
STEEL_CONST int kElemsPerTile = kNumFrags * kElemsPerFrag;
STEEL_CONST int kRowsPerThread = kTileRows * MMAFrag_t::kElemRows;
STEEL_CONST int kColsPerThread = kTileCols * MMAFrag_t::kElemCols;
typedef typename MMAFrag_t::mat_type mat_type;
typedef typename MMAFrag_t::frag_type frag_type;
frag_type val_frags[kNumFrags] = {frag_type(0)};
METAL_FUNC MMATile() thread {}
METAL_FUNC constexpr void clear() {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kNumFrags; ++i) {
val_frags[i] = frag_type(0);
}
}
METAL_FUNC constexpr thread frag_type& frag_at(const short i, const short j) {
return val_frags[i * kTileCols + j];
}
METAL_FUNC constexpr const thread frag_type& frag_at(
const short i,
const short j) const {
return val_frags[i * kTileCols + j];
}
METAL_FUNC mat_type mat_at(const short i, const short j) {
mat_type val_mat;
STEEL_PRAGMA_UNROLL
for (short ii = 0; ii < kElemsPerFrag; ++ii) {
val_mat.thread_elements()[ii] = frag_at(i, j)[ii];
}
return val_mat;
}
METAL_FUNC thread elem_type* elems() {
return reinterpret_cast<thread elem_type*>(val_frags);
}
METAL_FUNC const thread elem_type* elems() const {
return reinterpret_cast<const thread elem_type*>(val_frags);
}
template <typename Op>
METAL_FUNC void row_reduce(thread T vals[kRowsPerThread]) const {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kTileRows; ++i) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < kTileCols; ++j) {
MMAFrag_t::template row_reduce<Op>(
frag_at(i, j), &vals[i * MMAFrag_t::kElemRows]);
}
}
}
template <typename Op>
METAL_FUNC void row_bin_op(thread T vals[kRowsPerThread]) {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kTileRows; ++i) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < kTileCols; ++j) {
MMAFrag_t::template row_bin_op<Op>(
frag_at(i, j), &vals[i * MMAFrag_t::kElemRows]);
}
}
}
template <typename U, int w_x, int w_y, int str_x, int str_y>
METAL_FUNC void load(const threadgroup U* src) {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kTileRows; ++i) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < kTileCols; ++j) {
MMAFrag_t::load(
frag_at(i, j),
&(
src[(i * kFragRows) * w_x * str_x +
(j * kFragCols) * w_y * str_y]),
Int<str_x>{},
Int<str_y>{});
}
}
}
template <typename U, int w_x, int w_y, int str_x, int str_y>
METAL_FUNC void store(threadgroup U* dst) const {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kTileRows; ++i) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < kTileCols; ++j) {
MMAFrag_t::store(
frag_at(i, j),
&(
dst[(i * kFragRows) * w_x * str_x +
(j * kFragCols) * w_y * str_y]),
Int<str_x>{},
Int<str_y>{});
}
}
}
template <typename U, int w_x, int w_y>
METAL_FUNC void load(const device U* src, const int ld) {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kTileRows; ++i) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < kTileCols; ++j) {
MMAFrag_t::load(
frag_at(i, j),
&(src[(i * kFragRows) * w_x * ld + (j * kFragCols) * w_y]),
ld,
Int<1>{});
}
}
}
template <typename U, int w_x, int w_y>
METAL_FUNC void store(device U* dst, const int ld) const {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kTileRows; ++i) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < kTileCols; ++j) {
MMAFrag_t::store(
frag_at(i, j),
&(dst[(i * kFragRows) * w_x * ld + (j * kFragCols) * w_y]),
ld,
Int<1>{});
}
}
}
template <typename U, int w_x, int w_y>
METAL_FUNC void
load_safe(const device U* src, const int ld, const short2 src_tile_dims) {
STEEL_PRAGMA_UNROLL
for (int i = 0; i < kTileRows; ++i) {
STEEL_PRAGMA_UNROLL
for (int j = 0; j < kTileCols; ++j) {
MMAFrag_t::load_safe(
frag_at(i, j),
src,
ld,
Int<1>{},
src_tile_dims.y,
src_tile_dims.x,
(i * kFragRows) * w_x,
(j * kFragCols) * w_y);
}
}
}
template <typename U, int w_x, int w_y>
METAL_FUNC void
store_safe(device U* dst, const int ld, const short2 dst_tile_dims) const {
STEEL_PRAGMA_UNROLL
for (int i = 0; i < kTileRows; ++i) {
STEEL_PRAGMA_UNROLL
for (int j = 0; j < kTileCols; ++j) {
MMAFrag_t::store_safe(
frag_at(i, j),
dst,
ld,
Int<1>{},
dst_tile_dims.y,
dst_tile_dims.x,
(i * kFragRows) * w_x,
(j * kFragCols) * w_y);
}
}
}
};
template <typename T, typename U, int M, int N, int K>
METAL_FUNC void tile_matmad(
thread MMATile<T, M, N>& D,
thread MMATile<U, M, K>& A,
thread MMATile<U, K, N>& B,
thread MMATile<T, M, N>& C) {
STEEL_PRAGMA_UNROLL
for (short k = 0; k < K; ++k) {
STEEL_PRAGMA_UNROLL
for (short m = 0; m < M; ++m) {
STEEL_PRAGMA_UNROLL
for (short n = 0; n < N; ++n) {
short n_serp = (m % 2) ? (N - 1 - n) : n;
MMATile<T, M, N>::MMAFrag_t::mma(
D.frag_at(m, n_serp),
A.frag_at(m, k),
B.frag_at(k, n_serp),
C.frag_at(m, n_serp));
}
}
}
}
template <
typename T,
typename U,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_a,
bool transpose_b,
short lda_tgp,
short ldb_tgp,
typename AccumType = float,
typename Epilogue = TransformNone<U, AccumType>>
struct BlockMMA {
// MMAFrag size
STEEL_CONST short kFragSize = 8;
using MMAFrag_acc_t = BaseMMAFrag<AccumType, kFragSize, kFragSize>;
// Warp tile simdgroup matrix strides along M
STEEL_CONST short TM_stride = kFragSize * WM;
// Warp tile simdgroup matrix strides along M
STEEL_CONST short TN_stride = kFragSize * WN;
// Warp tile size along M
STEEL_CONST short TM = BM / TM_stride;
// Warp tile size along N
STEEL_CONST short TN = BN / TN_stride;
// Threadgroup A strides
STEEL_CONST short A_str_m = transpose_a ? 1 : lda_tgp; // M
STEEL_CONST short A_str_k = transpose_a ? lda_tgp : 1; // K
// Threadgroup B strides
STEEL_CONST short B_str_k = transpose_b ? 1 : ldb_tgp; // K
STEEL_CONST short B_str_n = transpose_b ? ldb_tgp : 1; // N
// Threadgroup strides along K
STEEL_CONST short tile_stride_a = kFragSize * A_str_k;
STEEL_CONST short tile_stride_b = kFragSize * B_str_k;
// Simdgroup matrices
MMATile<AccumType, TM, 1, MMAFrag_acc_t> Atile;
MMATile<AccumType, 1, TN, MMAFrag_acc_t> Btile;
MMATile<AccumType, TM, TN, MMAFrag_acc_t> Ctile;
// Offsets within threadgroup
short sm;
short sn;
short As_offset;
short Bs_offset;
/* Constructor */
METAL_FUNC BlockMMA(
ushort simd_group_id [[simdgroup_index_in_threadgroup]],
ushort simd_lane_id [[thread_index_in_simdgroup]]) {
// Determine thread position in simdgroup matrix
short tm = kFragSize * (simd_group_id / WN);
short tn = kFragSize * (simd_group_id % WN);
short2 simd_coord = MMAFrag_acc_t::get_coord(simd_lane_id);
sm = simd_coord.y;
sn = simd_coord.x;
// Determine thread and simdgroup offset
As_offset = (tm + sm) * A_str_m + (sn)*A_str_k; // M, K
Bs_offset = (sm)*B_str_k + (tn + sn) * B_str_n; // K, N
sm += tm;
sn += tn;
}
/* (BM, BK) X (BK, BN) multiply accumulate function */
METAL_FUNC void mma(const threadgroup T* As, const threadgroup T* Bs) {
// Adjust for simdgroup and thread location
As += As_offset;
Bs += Bs_offset;
// Iterate over BK in blocks of kFragSize
STEEL_PRAGMA_UNROLL
for (short kk = 0; kk < BK; kk += kFragSize) {
simdgroup_barrier(mem_flags::mem_none);
Atile.template load<T, WM, 1, A_str_m, A_str_k>(As);
simdgroup_barrier(mem_flags::mem_none);
Btile.template load<T, 1, WN, B_str_k, B_str_n>(Bs);
simdgroup_barrier(mem_flags::mem_none);
tile_matmad(Ctile, Atile, Btile, Ctile);
// Progress to next simdgroup tile
As += tile_stride_a;
Bs += tile_stride_b;
}
}
/* Store results from simdgroup_matrix results into device memory */
METAL_FUNC void store_result(device U* D, const int ldd) {
// Apply epilogue
STEEL_PRAGMA_UNROLL
for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) {
Ctile.elems()[i] = Epilogue::apply(Ctile.elems()[i]);
}
// Adjust for simdgroup and thread location
D += sm * ldd + sn;
Ctile.template store<U, WM, WN>(D, ldd);
}
METAL_FUNC void
store_result_safe(device U* D, const int ldd, short2 dst_tile_dims) {
// Apply epilogue
STEEL_PRAGMA_UNROLL
for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) {
Ctile.elems()[i] = Epilogue::apply(Ctile.elems()[i]);
}
// Adjust for simdgroup and thread location
D += sm * ldd + sn;
dst_tile_dims -= short2(sn, sm);
if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
return;
Ctile.template store_safe<U, WM, WN>(D, ldd, dst_tile_dims);
}
/* Apply epilogue */
template <typename UnaryEpilogue>
METAL_FUNC void apply_epilogue(thread const UnaryEpilogue& epilogue_op) {
// Loop over all simdgroup tiles
STEEL_PRAGMA_UNROLL
for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) {
Ctile.elems()[i] = epilogue_op.apply(Ctile.elems()[i]);
}
}
/* Apply epilogue */
template <typename BinaryEpilogue>
METAL_FUNC void apply_epilogue(
const device U* C,
const int ldc,
const int fdc,
thread const BinaryEpilogue& epilogue_op) {
// Adjust for simdgroup and thread location
C += (sm)*ldc + (sn)*fdc;
// Loop over all simdgroup tiles
STEEL_PRAGMA_UNROLL
for (short i = 0; i < TM; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < TN; j++) {
// Get accumulated result and associated offset in C
thread auto& accum = Ctile.frag_at(i, j);
int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
// Apply epilogue
STEEL_PRAGMA_UNROLL
for (short k = 0; k < decltype(Ctile)::kElemsPerFrag; k++) {
accum[k] = epilogue_op.apply(accum[k], C[offset_c + k * fdc]);
}
}
}
}
/* Apply epilogue */
template <typename BinaryEpilogue>
METAL_FUNC void apply_epilogue_safe(
const device U* C,
const int ldc,
const int fdc,
short2 dst_tile_dims,
thread const BinaryEpilogue& epilogue_op) {
// Adjust for simdgroup and thread location
C += (sm)*ldc + (sn)*fdc;
dst_tile_dims -= short2(sn, sm);
if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
return;
// Loop over all simdgroup tiles
STEEL_PRAGMA_UNROLL
for (short i = 0; i < TM; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < TN; j++) {
// Get accumulated result and associated offset in C
thread auto& accum = Ctile.frag_at(i, j);
int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
constexpr short kelems = decltype(Ctile)::kElemsPerFrag;
// Read C
U c_elems[kelems] = {0};
STEEL_PRAGMA_UNROLL
for (short k = 0; k < kelems; k++) {
if ((j * TN_stride + k) < dst_tile_dims.x) {
c_elems[k] = C[offset_c + k * fdc];
}
}
// Apply epilogue
STEEL_PRAGMA_UNROLL
for (short k = 0; k < kelems; k++) {
accum[k] = epilogue_op.apply(accum[k], c_elems[k]);
}
}
}
}
/* Store results from simdgroup_matrix results into device memory */
METAL_FUNC void store_result(
device U* D,
const int ldd,
const device U* C,
const int ldc,
const int fdc,
thread const Epilogue& epilogue_op) const {
// Adjust for simdgroup and thread location
C += (sm)*ldc + (sn)*fdc;
D += (sm)*ldd + sn;
constexpr short kelems = decltype(Ctile)::kElemsPerFrag;
// Loop over all simdgroup tiles
STEEL_PRAGMA_UNROLL
for (short i = 0; i < TM; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < TN; j++) {
// Get accumulated result and associated offset in C
thread const auto& accum = Ctile.frag_at(i, j);
int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
int offset_d = (i * TM_stride) * ldd + (j * TN_stride);
// Apply epilogue
STEEL_PRAGMA_UNROLL
for (short k = 0; k < kelems; k++) {
D[offset_d + k] = epilogue_op.apply(accum[k], C[offset_c + k * fdc]);
}
}
}
}
METAL_FUNC void store_result_safe(
device U* D,
const int ldd,
const device U* C,
const int ldc,
const int fdc,
short2 dst_tile_dims,
thread const Epilogue& epilogue_op) const {
// Adjust for simdgroup and thread location
C += (sm)*ldc + (sn)*fdc;
D += (sm)*ldd + sn;
dst_tile_dims -= short2(sn, sm);
if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
return;
constexpr short kelems = decltype(Ctile)::kElemsPerFrag;
STEEL_PRAGMA_UNROLL
for (int i = 0; i < TM; i++) {
if (i * TM_stride < dst_tile_dims.y) {
STEEL_PRAGMA_UNROLL
for (int j = 0; j < TN; j++) {
// Get accumulated result and associated offset in C
thread const auto& accum = Ctile.frag_at(i, j);
int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
int offset_d = (i * TM_stride) * ldd + (j * TN_stride);
// Apply epilogue
STEEL_PRAGMA_UNROLL
for (short k = 0; k < kelems; k++) {
if ((j * TN_stride + k) < dst_tile_dims.x) {
D[offset_d + k] =
epilogue_op.apply(accum[k], C[offset_c + k * fdc]);
}
}
}
}
}
}
};
} // namespace steel
} // namespace mlx

View File

@@ -1,36 +0,0 @@
// Copyright © 2024 Apple Inc.
#pragma once
///////////////////////////////////////////////////////////////////////////////
// Attn param classes
///////////////////////////////////////////////////////////////////////////////
namespace mlx {
namespace steel {
struct AttnParams {
int B; ///< Batch Size
int H; ///< Heads
int D; ///< Head Dim
int qL; ///< Query Sequence Length
int kL; ///< Key Sequence Length
int gqa_factor; ///< Group Query factor
float scale; ///< Attention scale
int NQ; ///< Number of query blocks
int NK; ///< Number of key/value blocks
int NQ_aligned; ///< Number of full query blocks
int NK_aligned; ///< Number of full key/value blocks
size_t Q_strides[3]; ///< Query strides (B, H, L, D = 1)
size_t K_strides[3]; ///< Key strides (B, H, L, D = 1)
size_t V_strides[3]; ///< Value strides (B, H, L, D = 1)
size_t O_strides[3]; ///< Output strides (B, H, L, D = 1)
};
} // namespace steel
} // namespace mlx

View File

@@ -1,71 +0,0 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include "mlx/backend/metal/kernels/steel/utils.h"
///////////////////////////////////////////////////////////////////////////////
// Transforms and Epilogues
///////////////////////////////////////////////////////////////////////////////
namespace mlx {
namespace steel {
template <typename OutT, typename InT>
struct TransformNone {
static METAL_FUNC OutT apply(InT x) {
return static_cast<OutT>(x);
}
static METAL_FUNC OutT apply(InT x, OutT) {
return static_cast<OutT>(x);
}
};
template <typename OutT, typename InT>
struct TransformAdd {
TransformAdd(const float, const float) {}
static METAL_FUNC OutT apply(InT x) {
return static_cast<OutT>(x);
}
static METAL_FUNC OutT apply(InT x, OutT c) {
return static_cast<OutT>(x) + c;
}
};
template <typename OutT, typename InT>
struct TransformAxpby {
const float alpha;
const float beta;
TransformAxpby(const float alpha_, const float beta_)
: alpha(alpha_), beta(beta_) {}
static METAL_FUNC OutT apply(InT x) {
return static_cast<OutT>(x);
}
METAL_FUNC OutT apply(InT x, OutT c) const {
return static_cast<OutT>(x * alpha + (beta * c));
}
};
template <typename T>
struct AccumHelper {
typedef float accum_type;
};
struct BlockSwizzle {
static METAL_FUNC int2
swizzle(uint3 tid [[threadgroup_position_in_grid]], const int swizzle_log) {
const int tid_x = (tid.x) >> swizzle_log;
const int tid_y =
((tid.y) << swizzle_log) + ((tid.x) & ((1 << swizzle_log) - 1));
return int2(tid_x, tid_y);
}
};
} // namespace steel
} // namespace mlx

View File

@@ -5,7 +5,7 @@
// clang-format off
#include "mlx/backend/metal/kernels/steel/gemm/mma.h"
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/steel/conv/conv.h"
#include "mlx/backend/metal/kernels/steel/conv/params.h"
#include "mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h"

View File

@@ -5,7 +5,7 @@
// clang-format off
#include "mlx/backend/metal/kernels/steel/gemm/mma.h"
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/steel/conv/conv.h"
#include "mlx/backend/metal/kernels/steel/conv/params.h"
#include "mlx/backend/metal/kernels/steel/utils.h"

View File

@@ -1,6 +1,7 @@
// Copyright © 2024 Apple Inc.
// clang-format off
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"

View File

@@ -1,6 +1,7 @@
// Copyright © 2024 Apple Inc.
// clang-format off
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h"

View File

@@ -1,7 +1,7 @@
// Copyright © 2024 Apple Inc.
// clang-format off
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
#include "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h"

View File

@@ -385,9 +385,9 @@ struct BlockMMA {
STEEL_CONST short TN_stride = kFragSize * WN;
// Warp tile size along M
STEEL_CONST short TM = BM / (kFragSize * WM);
STEEL_CONST short TM = BM / TM_stride;
// Warp tile size along N
STEEL_CONST short TN = BN / (kFragSize * WN);
STEEL_CONST short TN = BN / TN_stride;
// Threadgroup A strides
STEEL_CONST short A_str_m = transpose_a ? 1 : lda_tgp; // M

View File

@@ -22,7 +22,7 @@ template <typename T, typename Op>
d[offset] = Op()(a[offset], b[offset], c[offset]);
}
template <typename T, typename Op, typename IdxT = size_t>
template <typename T, typename Op>
[[kernel]] void ternary_g_nd1(
device const bool* a,
device const T* b,
@@ -32,13 +32,13 @@ template <typename T, typename Op, typename IdxT = size_t>
constant const size_t& b_strides,
constant const size_t& c_strides,
uint index [[thread_position_in_grid]]) {
auto a_idx = elem_to_loc_1<size_t, IdxT>(index, a_strides);
auto b_idx = elem_to_loc_1<size_t, IdxT>(index, b_strides);
auto c_idx = elem_to_loc_1<size_t, IdxT>(index, c_strides);
auto a_idx = elem_to_loc_1(index, a_strides);
auto b_idx = elem_to_loc_1(index, b_strides);
auto c_idx = elem_to_loc_1(index, c_strides);
d[index] = Op()(a[a_idx], b[b_idx], c[c_idx]);
}
template <typename T, typename Op, typename IdxT = size_t>
template <typename T, typename Op>
[[kernel]] void ternary_g_nd2(
device const bool* a,
device const T* b,
@@ -49,14 +49,14 @@ template <typename T, typename Op, typename IdxT = size_t>
constant const size_t c_strides[2],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto a_idx = elem_to_loc_2<size_t, IdxT>(index, a_strides);
auto b_idx = elem_to_loc_2<size_t, IdxT>(index, b_strides);
auto c_idx = elem_to_loc_2<size_t, IdxT>(index, c_strides);
IdxT out_idx = index.x + IdxT(grid_dim.x) * index.y;
auto a_idx = elem_to_loc_2(index, a_strides);
auto b_idx = elem_to_loc_2(index, b_strides);
auto c_idx = elem_to_loc_2(index, c_strides);
size_t out_idx = index.x + size_t(grid_dim.x) * index.y;
d[out_idx] = Op()(a[a_idx], b[b_idx], c[c_idx]);
}
template <typename T, typename Op, typename IdxT = size_t>
template <typename T, typename Op>
[[kernel]] void ternary_g_nd3(
device const bool* a,
device const T* b,
@@ -67,14 +67,15 @@ template <typename T, typename Op, typename IdxT = size_t>
constant const size_t c_strides[3],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto a_idx = elem_to_loc_3<size_t, IdxT>(index, a_strides);
auto b_idx = elem_to_loc_3<size_t, IdxT>(index, b_strides);
auto c_idx = elem_to_loc_3<size_t, IdxT>(index, c_strides);
IdxT out_idx = index.x + grid_dim.x * (index.y + IdxT(grid_dim.y) * index.z);
auto a_idx = elem_to_loc_3(index, a_strides);
auto b_idx = elem_to_loc_3(index, b_strides);
auto c_idx = elem_to_loc_3(index, c_strides);
size_t out_idx =
index.x + grid_dim.x * (index.y + size_t(grid_dim.y) * index.z);
d[out_idx] = Op()(a[a_idx], b[b_idx], c[c_idx]);
}
template <typename T, typename Op, int N = 1, typename IdxT = size_t>
template <typename T, typename Op, int N = 1>
[[kernel]] void ternary_g(
device const bool* a,
device const T* b,
@@ -87,7 +88,7 @@ template <typename T, typename Op, int N = 1, typename IdxT = size_t>
constant const int& ndim,
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto idx = elem_to_loc_3_nd<IdxT>(
auto idx = elem_to_loc_3_nd(
{N * index.x, index.y, index.z},
shape,
a_strides,
@@ -95,10 +96,11 @@ template <typename T, typename Op, int N = 1, typename IdxT = size_t>
c_strides,
ndim);
auto xshape = shape[ndim - 1];
IdxT out_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z);
IdxT a_xstride = a_strides[ndim - 1];
IdxT b_xstride = b_strides[ndim - 1];
IdxT c_xstride = c_strides[ndim - 1];
size_t out_idx =
N * index.x + xshape * (index.y + size_t(grid_dim.y) * index.z);
auto a_xstride = a_strides[ndim - 1];
auto b_xstride = b_strides[ndim - 1];
auto c_xstride = c_strides[ndim - 1];
for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
d[out_idx++] = Op()(a[idx.x], b[idx.y], c[idx.z]);
idx.x += a_xstride;

View File

@@ -4,21 +4,18 @@
#include <metal_math>
// clang-format off
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/ternary_ops.h"
#include "mlx/backend/metal/kernels/ternary.h"
#define instantiate_ternary_all(op, tname, type) \
instantiate_kernel("v_" #op #tname, ternary_v, type, op) \
instantiate_kernel("v2_" #op #tname, ternary_v2, type, op) \
instantiate_kernel("gn2_" #op #tname, ternary_g, type, op, 1, uint) \
instantiate_kernel("g1_" #op #tname, ternary_g_nd1, type, op, uint) \
instantiate_kernel("g2_" #op #tname, ternary_g_nd2, type, op, uint) \
instantiate_kernel("g3_" #op #tname, ternary_g_nd3, type, op, uint) \
instantiate_kernel("g1large_" #op #tname, ternary_g_nd1, type, op) \
instantiate_kernel("g2large_" #op #tname, ternary_g_nd2, type, op) \
instantiate_kernel("g3large_" #op #tname, ternary_g_nd3, type, op) \
instantiate_kernel("gn4large_" #op #tname, ternary_g, type, op, 4) \
#define instantiate_ternary_all(op, tname, type) \
instantiate_kernel("v_" #op #tname, ternary_v, type, op) \
instantiate_kernel("v2_" #op #tname, ternary_v2, type, op) \
instantiate_kernel("gn4_" #op #tname, ternary_g, type, op, 4) \
instantiate_kernel("g1_" #op #tname, ternary_g_nd1, type, op) \
instantiate_kernel("g2_" #op #tname, ternary_g_nd2, type, op) \
instantiate_kernel("g3_" #op #tname, ternary_g_nd3, type, op)
#define instantiate_ternary_types(op) \
instantiate_ternary_all(op, bool_, bool) \

View File

@@ -18,12 +18,7 @@ template <typename T, typename U, typename Op>
out[offset] = Op()(in[offset]);
}
template <
typename T,
typename U,
typename Op,
int N = 1,
typename IdxT = size_t>
template <typename T, typename U, typename Op, int N = 1>
[[kernel]] void unary_g(
device const T* in,
device U* out,
@@ -32,11 +27,12 @@ template <
device const int& ndim,
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto idx = elem_to_loc<size_t, IdxT>(
{N * index.x, index.y, index.z}, in_shape, in_strides, ndim);
auto idx =
elem_to_loc({N * index.x, index.y, index.z}, in_shape, in_strides, ndim);
auto xshape = in_shape[ndim - 1];
IdxT xstride = in_strides[ndim - 1];
IdxT out_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z);
auto xstride = in_strides[ndim - 1];
size_t out_idx =
N * index.x + xshape * (index.y + size_t(grid_dim.y) * index.z);
for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
out[out_idx++] = Op()(in[idx]);
idx += xstride;

View File

@@ -5,13 +5,11 @@
#include "mlx/backend/metal/kernels/unary_ops.h"
#include "mlx/backend/metal/kernels/unary.h"
#define instantiate_unary_all(op, in_tname, out_tname, in_type, out_type) \
instantiate_kernel("v_" #op #in_tname #out_tname, unary_v, in_type, out_type, op) \
instantiate_kernel("v2_" #op #in_tname #out_tname, unary_v2, in_type, out_type, op) \
instantiate_kernel( \
"gn1_" #op #in_tname #out_tname, unary_g, in_type, out_type, op, 1, uint) \
instantiate_kernel( \
"gn4large" #op #in_tname #out_tname, unary_g, in_type, out_type, op, 4)
#define instantiate_unary_all(op, in_tname, out_tname, in_type, out_type) \
instantiate_kernel("v_" #op #in_tname #out_tname, unary_v, in_type, out_type, op) \
instantiate_kernel("v2_" #op #in_tname #out_tname, unary_v2, in_type, out_type, op) \
instantiate_kernel("gn4_" #op #in_tname #out_tname, unary_g, in_type, out_type, op, 4)
#define instantiate_unary_all_same(op, tname, type) \
instantiate_unary_all(op, tname, tname, type, type)

View File

@@ -3,13 +3,7 @@
#pragma once
#include <metal_math>
// The correct bf16.h is included based on the metal version
// by giving the correct path to -I during compilation
// e.g. mlx/backend/metal/kernels/metal_3_0/ for Metal 3.0
#include "bf16.h"
#include "mlx/backend/metal/kernels/bf16_math.h"
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/complex.h"
#include "mlx/backend/metal/kernels/defines.h"
@@ -89,45 +83,44 @@ struct Limits<complex64_t> {
///////////////////////////////////////////////////////////////////////////////
// Single Array with generic dims
template <typename StrideT, typename IdxT = StrideT>
METAL_FUNC IdxT elem_to_loc(
template <typename stride_t>
METAL_FUNC stride_t elem_to_loc(
uint elem,
constant const int* shape,
constant const StrideT* strides,
constant const stride_t* strides,
int ndim) {
IdxT loc = 0;
stride_t loc = 0;
for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
loc += (elem % shape[i]) * IdxT(strides[i]);
loc += (elem % shape[i]) * strides[i];
elem /= shape[i];
}
return loc;
}
template <typename StrideT, typename IdxT = StrideT>
METAL_FUNC IdxT elem_to_loc(
StrideT elem,
template <typename stride_t>
METAL_FUNC stride_t elem_to_loc(
stride_t elem,
constant const int* shape,
constant const StrideT* strides,
constant const stride_t* strides,
int ndim) {
IdxT loc = 0;
stride_t loc = 0;
for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
loc += (elem % shape[i]) * IdxT(strides[i]);
loc += (elem % shape[i]) * strides[i];
elem /= shape[i];
}
return loc;
}
// Non templated version to handle arbitrary dims
template <typename StrideT, typename IdxT = StrideT>
METAL_FUNC IdxT elem_to_loc(
template <typename stride_t>
METAL_FUNC stride_t elem_to_loc(
uint3 elem,
constant const int* shape,
constant const StrideT* strides,
constant const stride_t* strides,
int ndim) {
IdxT loc =
elem.x * IdxT(strides[ndim - 1]) + elem.y * IdxT(strides[ndim - 2]);
stride_t loc = elem.x * strides[ndim - 1] + elem.y * strides[ndim - 2];
for (int d = ndim - 3; d >= 0; --d) {
loc += (elem.z % shape[d]) * IdxT(strides[d]);
loc += (elem.z % shape[d]) * strides[d];
elem.z /= shape[d];
}
return loc;
@@ -136,65 +129,61 @@ METAL_FUNC IdxT elem_to_loc(
///////////////////////////////////////////////////////////////////////////////
// Single Array with fixed N dims
template <typename StrideT, typename IdxT = StrideT>
METAL_FUNC IdxT elem_to_loc_1(uint elem, constant const StrideT& stride) {
return elem * IdxT(stride);
template <typename stride_t>
METAL_FUNC stride_t elem_to_loc_1(uint elem, constant const stride_t& stride) {
return elem * stride;
}
template <typename StrideT, typename IdxT = StrideT>
METAL_FUNC IdxT elem_to_loc_2(uint2 elem, constant const StrideT strides[2]) {
return elem.x * IdxT(strides[1]) + elem.y * IdxT(strides[0]);
template <typename stride_t>
METAL_FUNC stride_t
elem_to_loc_2(uint2 elem, constant const stride_t strides[2]) {
return elem.x * strides[1] + elem.y * strides[0];
}
template <typename StrideT, typename IdxT = StrideT>
METAL_FUNC IdxT elem_to_loc_3(uint3 elem, constant const StrideT strides[3]) {
return elem.x * IdxT(strides[2]) + elem.y * IdxT(strides[1]) +
elem.z * IdxT(strides[0]);
template <typename stride_t>
METAL_FUNC stride_t
elem_to_loc_3(uint3 elem, constant const stride_t strides[3]) {
return elem.x * strides[2] + elem.y * strides[1] + elem.z * strides[0];
}
///////////////////////////////////////////////////////////////////////////////
// Multiple Arrays with generic dims
template <typename StrideT, typename IdxT = StrideT>
METAL_FUNC vec<IdxT, 2> elem_to_loc_2_nd(
template <typename stride_t>
METAL_FUNC ulong2 elem_to_loc_2_nd(
uint3 elem,
constant const int* shape,
constant const StrideT* a_strides,
constant const StrideT* b_strides,
constant const stride_t* a_strides,
constant const stride_t* b_strides,
int ndim) {
vec<IdxT, 2> loc = {
IdxT(
elem.x * IdxT(a_strides[ndim - 1]) +
IdxT(elem.y) * IdxT(a_strides[ndim - 2])),
IdxT(
elem.x * IdxT(b_strides[ndim - 1]) +
elem.y * IdxT(b_strides[ndim - 2]))};
ulong2 loc = {
ulong(elem.x * a_strides[ndim - 1] + elem.y * a_strides[ndim - 2]),
ulong(elem.x * b_strides[ndim - 1] + elem.y * b_strides[ndim - 2])};
for (int d = ndim - 3; d >= 0; --d) {
uint l = elem.z % shape[d];
loc.x += l * IdxT(a_strides[d]);
loc.y += l * IdxT(b_strides[d]);
loc.x += l * a_strides[d];
loc.y += l * b_strides[d];
elem.z /= shape[d];
}
return loc;
}
template <typename IdxT = size_t>
METAL_FUNC vec<IdxT, 3> elem_to_loc_3_nd(
METAL_FUNC ulong3 elem_to_loc_3_nd(
uint3 elem,
constant const int* shape,
constant const size_t* a_strides,
constant const size_t* b_strides,
constant const size_t* c_strides,
int ndim) {
vec<IdxT, 3> loc = {
elem.x * IdxT(a_strides[ndim - 1]) + elem.y * IdxT(a_strides[ndim - 2]),
elem.x * IdxT(b_strides[ndim - 1]) + elem.y * IdxT(b_strides[ndim - 2]),
elem.x * IdxT(c_strides[ndim - 1]) + elem.y * IdxT(c_strides[ndim - 2])};
ulong3 loc = {
elem.x * a_strides[ndim - 1] + elem.y * a_strides[ndim - 2],
elem.x * b_strides[ndim - 1] + elem.y * b_strides[ndim - 2],
elem.x * c_strides[ndim - 1] + elem.y * c_strides[ndim - 2]};
for (int d = ndim - 3; d >= 0; --d) {
uint l = elem.z % shape[d];
loc.x += l * IdxT(a_strides[d]);
loc.y += l * IdxT(b_strides[d]);
loc.z += l * IdxT(c_strides[d]);
loc.x += l * a_strides[d];
loc.y += l * b_strides[d];
loc.z += l * c_strides[d];
elem.z /= shape[d];
}
return loc;
@@ -204,21 +193,16 @@ METAL_FUNC vec<IdxT, 3> elem_to_loc_3_nd(
// Elem to loc in a loop utils
///////////////////////////////////////////////////////////////////////////////
template <int DIM, typename OffsetT = size_t, bool General = true>
struct LoopedElemToLoc {
int dim;
LoopedElemToLoc<DIM - 1, OffsetT, General> inner_looper;
OffsetT offset{0};
template <int dim, typename offset_t = size_t>
struct looped_elem_to_loc {
looped_elem_to_loc<dim - 1, offset_t> inner_looper;
offset_t offset{0};
int index{0};
LoopedElemToLoc(int dim) : dim(dim), inner_looper(dim - 1) {}
void next(const constant int* shape, const constant size_t* strides) {
if (dim == 0) {
return;
}
index++;
offset += OffsetT(strides[dim - 1]);
offset += strides[dim - 1];
if (index >= shape[dim - 1]) {
index = 0;
inner_looper.next(shape, strides);
@@ -227,21 +211,13 @@ struct LoopedElemToLoc {
}
void next(int n, const constant int* shape, const constant size_t* strides) {
if (dim == 0) {
return;
}
index += n;
offset += n * OffsetT(strides[dim - 1]);
offset += n * strides[dim - 1];
if (index >= shape[dim - 1]) {
int extra = index - shape[dim - 1];
if (extra >= shape[dim - 1]) {
inner_looper.next(1 + extra / shape[dim - 1], shape, strides);
extra = extra % shape[dim - 1];
} else {
inner_looper.next(shape, strides);
}
index = 0;
inner_looper.next(shape, strides);
offset = inner_looper.offset;
if (extra > 0) {
next(extra, shape, strides);
@@ -249,61 +225,44 @@ struct LoopedElemToLoc {
}
}
OffsetT location() {
offset_t
location(offset_t, const constant int*, const constant size_t*, int) {
return offset;
}
};
template <typename OffsetT>
struct LoopedElemToLoc<1, OffsetT, true> {
int dim;
OffsetT offset{0};
uint index{0};
LoopedElemToLoc(int dim) : dim(dim) {}
void next(const constant int* shape, const constant size_t* strides) {
index++;
if (dim > 1) {
offset = elem_to_loc<size_t, OffsetT>(index, shape, strides, dim);
} else {
offset += OffsetT(strides[0]);
}
}
void next(int n, const constant int* shape, const constant size_t* strides) {
index += n;
if (dim > 1) {
offset = elem_to_loc<size_t, OffsetT>(index, shape, strides, dim);
} else {
offset = index * OffsetT(strides[0]);
}
}
OffsetT location() {
return offset;
}
};
template <typename OffsetT>
struct LoopedElemToLoc<1, OffsetT, false> {
OffsetT offset{0};
LoopedElemToLoc(int) {}
template <typename offset_t>
struct looped_elem_to_loc<1, offset_t> {
offset_t offset{0};
void next(const constant int*, const constant size_t* strides) {
offset += OffsetT(strides[0]);
offset += strides[0];
}
void next(int n, const constant int*, const constant size_t* strides) {
offset += n * OffsetT(strides[0]);
offset += n * strides[0];
}
OffsetT location() {
offset_t
location(offset_t, const constant int*, const constant size_t*, int) {
return offset;
}
};
template <typename offset_t>
struct looped_elem_to_loc<0, offset_t> {
void next(const constant int*, const constant size_t*) {}
void next(int, const constant int*, const constant size_t*) {}
offset_t location(
offset_t idx,
const constant int* shape,
const constant size_t* strides,
int ndim) {
return elem_to_loc(idx, shape, strides, ndim);
}
};
///////////////////////////////////////////////////////////////////////////////
// Calculation utils
///////////////////////////////////////////////////////////////////////////////
@@ -421,14 +380,3 @@ inline complex64_t simd_shuffle(complex64_t data, uint16_t lane) {
return complex64_t(
simd_shuffle(data.real, lane), simd_shuffle(data.imag, lane));
}
// std::conditional is not included with Metal
template <bool condition, typename T, typename U>
struct ConditionalType {
using type = U;
};
template <typename T, typename U>
struct ConditionalType<true, T, U> {
using type = T;
};

View File

@@ -11,12 +11,12 @@ SRC_DIR=$3
SRC_FILE=$4
CFLAGS=$5
SRC_NAME=$(basename -- "${SRC_FILE}")
JIT_INCLUDES=${SRC_DIR}/mlx/backend/metal/kernels/jit
INPUT_FILE=${SRC_DIR}/mlx/backend/metal/kernels/${SRC_FILE}.h
OUTPUT_FILE=${OUTPUT_DIR}/${SRC_NAME}.cpp
mkdir -p "$OUTPUT_DIR"
CONTENT=$($CC -I"$SRC_DIR" -I"$JIT_INCLUDES" -DMLX_METAL_JIT -E -P "$INPUT_FILE" $CFLAGS 2>/dev/null)
CONTENT=$($CC -I "$SRC_DIR" -DMLX_METAL_JIT -E -P "$INPUT_FILE" $CFLAGS 2>/dev/null)
cat << EOF > "$OUTPUT_FILE"
namespace mlx::core::metal {

View File

@@ -249,7 +249,7 @@ void steel_matmul_regular(
wm,
wn);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder->setComputePipelineState(kernel);
// Use problem size to determine threadblock swizzle
int tn = (N + bn - 1) / bn;
@@ -288,12 +288,12 @@ void steel_matmul_regular(
compute_encoder.set_input_array(b, 1);
compute_encoder.set_output_array(out, 3);
compute_encoder.set_bytes(params, 4);
compute_encoder->setBytes(&params, sizeof(GEMMParams), 4);
compute_encoder.set_vector_bytes(batch_shape, 6);
compute_encoder.set_vector_bytes(batch_strides, 7);
set_vector_bytes(compute_encoder, batch_shape, 6);
set_vector_bytes(compute_encoder, batch_strides, 7);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
// Record copies
d.add_temporaries(std::move(copies), s.index);
@@ -390,7 +390,7 @@ void steel_matmul(
wn,
mn_aligned,
k_aligned);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder->setComputePipelineState(kernel);
int tn = (N + bn - 1) / bn;
int tm = (M + bm - 1) / bm;
@@ -416,29 +416,34 @@ void steel_matmul(
compute_encoder.set_input_array(b, 1);
compute_encoder.set_output_array(C_split, 2);
compute_encoder.set_bytes(params, 3);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
compute_encoder->setBytes(&params, sizeof(GEMMSpiltKParams), 3);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
// Do accum kernel
{
auto c_split_buf =
static_cast<const MTL::Resource*>(C_split.buffer().ptr());
const class MTL::Resource* const resources[1] = {c_split_buf};
compute_encoder->memoryBarrier(resources, 1);
auto kernel_name = "steel_gemm_splitk_accum_" + type_to_name(out) + "_" +
type_to_name(C_split);
auto kernel = get_steel_gemm_splitk_accum_kernel(
d, kernel_name, C_split, out, false);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder->setComputePipelineState(kernel);
// Set the arguments for the kernel
compute_encoder.set_input_array(C_split, 0);
compute_encoder.set_output_array(out, 1);
compute_encoder.set_bytes(split_k_partitions, 2);
compute_encoder.set_bytes(split_k_partition_stride, 3);
compute_encoder.set_bytes(N, 4);
compute_encoder->setBytes(&split_k_partitions, sizeof(int), 2);
compute_encoder->setBytes(&split_k_partition_stride, sizeof(int), 3);
compute_encoder->setBytes(&N, sizeof(int), 4);
// Launch enough thread groups for each output
MTL::Size grid_dims = MTL::Size(N, M, 1);
auto group_dims = get_block_dims(N, M, 1);
compute_encoder.dispatch_threads(grid_dims, group_dims);
MTL::Size group_dims = MTL::Size(std::min(1024, N * M), 1, 1);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
d.add_temporaries(std::move(copies), s.index);
@@ -620,7 +625,7 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder->setComputePipelineState(kernel);
int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp;
MTL::Size group_dims = MTL::Size(32, bn, bm);
@@ -630,16 +635,16 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.set_input_array(vec, 1);
compute_encoder.set_output_array(out, 3);
compute_encoder.set_bytes(in_vector_len, 4);
compute_encoder.set_bytes(out_vector_len, 5);
compute_encoder.set_bytes(mat_ld, 6);
compute_encoder->setBytes(&in_vector_len, sizeof(int), 4);
compute_encoder->setBytes(&out_vector_len, sizeof(int), 5);
compute_encoder->setBytes(&mat_ld, sizeof(int), 6);
compute_encoder.set_bytes(batch_ndim, 9);
compute_encoder.set_vector_bytes(batch_shape, 10);
compute_encoder.set_vector_bytes(batch_strides_vec, 11);
compute_encoder.set_vector_bytes(batch_strides_mat, 12);
compute_encoder->setBytes(&batch_ndim, sizeof(int), 9);
set_vector_bytes(compute_encoder, batch_shape, 10);
set_vector_bytes(compute_encoder, batch_strides_vec, 11);
set_vector_bytes(compute_encoder, batch_strides_mat, 12);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
d.add_temporaries(std::move(copies), s.index);
return;
@@ -817,7 +822,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder->setComputePipelineState(kernel);
int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp;
MTL::Size group_dims = MTL::Size(32, bn, bm);
@@ -828,23 +833,23 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.set_input_array(c, 2);
compute_encoder.set_output_array(out, 3);
compute_encoder.set_bytes(in_vector_len, 4);
compute_encoder.set_bytes(out_vector_len, 5);
compute_encoder.set_bytes(mat_ld, 6);
compute_encoder->setBytes(&in_vector_len, sizeof(int), 4);
compute_encoder->setBytes(&out_vector_len, sizeof(int), 5);
compute_encoder->setBytes(&mat_ld, sizeof(int), 6);
compute_encoder.set_bytes(alpha_, 7);
compute_encoder.set_bytes(beta_, 8);
compute_encoder->setBytes(&alpha_, sizeof(float), 7);
compute_encoder->setBytes(&beta_, sizeof(float), 8);
compute_encoder.set_bytes(batch_ndim, 9);
compute_encoder.set_vector_bytes(batch_shape, 10);
compute_encoder.set_vector_bytes(batch_strides_vec, 11);
compute_encoder.set_vector_bytes(batch_strides_mat, 12);
compute_encoder.set_vector_bytes(C_batch_stride, 13);
compute_encoder->setBytes(&batch_ndim, sizeof(int), 9);
set_vector_bytes(compute_encoder, batch_shape, 10);
set_vector_bytes(compute_encoder, batch_strides_vec, 11);
set_vector_bytes(compute_encoder, batch_strides_mat, 12);
set_vector_bytes(compute_encoder, C_batch_stride, 13);
int bias_stride = c.strides()[c.ndim() - 1];
compute_encoder.set_bytes(bias_stride, 14);
compute_encoder->setBytes(&bias_stride, sizeof(int), 14);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
d.add_temporaries(std::move(copies), s.index);
return;
@@ -902,7 +907,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
mn_aligned,
k_aligned);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder->setComputePipelineState(kernel);
int tn = (N + bn - 1) / bn;
int tm = (M + bm - 1) / bm;
@@ -928,8 +933,8 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.set_input_array(b, 1);
compute_encoder.set_output_array(C_split, 2);
compute_encoder.set_bytes(params, 3);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
compute_encoder->setBytes(&params, sizeof(GEMMSpiltKParams), 3);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
// Do accum kernel
{
@@ -938,24 +943,25 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
auto kernel = get_steel_gemm_splitk_accum_kernel(
d, kernel_name, C_split, out, true);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder->setComputePipelineState(kernel);
// Set the arguments for the kernel
compute_encoder.set_input_array(C_split, 0);
compute_encoder.set_output_array(out, 1);
compute_encoder.set_bytes(split_k_partitions, 2);
compute_encoder.set_bytes(split_k_partition_stride, 3);
compute_encoder.set_bytes(N, 4);
compute_encoder->setBytes(&split_k_partitions, sizeof(int), 2);
compute_encoder->setBytes(&split_k_partition_stride, sizeof(int), 3);
compute_encoder->setBytes(&N, sizeof(int), 4);
compute_encoder.set_input_array(c, 5);
compute_encoder.set_bytes(ldc, 6);
compute_encoder.set_bytes(fdc, 7);
compute_encoder.set_bytes(alpha_, 8);
compute_encoder.set_bytes(beta_, 9);
compute_encoder->setBytes(&ldc, sizeof(int), 6);
compute_encoder->setBytes(&fdc, sizeof(int), 7);
compute_encoder->setBytes(&alpha_, sizeof(float), 8);
compute_encoder->setBytes(&beta_, sizeof(float), 9);
// Launch enough thread groups for each output
MTL::Size grid_dims = MTL::Size(N, M, 1);
auto group_dims = get_block_dims(N, M, 1);
compute_encoder.dispatch_threads(grid_dims, group_dims);
MTL::Size group_dims = MTL::Size(std::min(1024, N * M), 1, 1);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
d.add_temporaries(std::move(copies), s.index);
@@ -1026,7 +1032,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
wm,
wn);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder->setComputePipelineState(kernel);
int tn = (N + bn - 1) / bn;
int tm = (M + bm - 1) / bm;
@@ -1077,13 +1083,13 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.set_input_array(c, 2);
compute_encoder.set_output_array(out, 3);
compute_encoder.set_bytes(gemm_params, 4);
compute_encoder.set_bytes(params, 5);
compute_encoder->setBytes(&gemm_params, sizeof(GEMMParams), 4);
compute_encoder->setBytes(&params, sizeof(GEMMAddMMParams), 5);
compute_encoder.set_vector_bytes(batch_shape, 6);
compute_encoder.set_vector_bytes(batch_strides, 7);
set_vector_bytes(compute_encoder, batch_shape, 6);
set_vector_bytes(compute_encoder, batch_strides, 7);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
d.add_temporaries(std::move(copies), s.index);
}
@@ -1298,7 +1304,7 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
contiguous_kernel);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder->setComputePipelineState(kernel);
int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp;
MTL::Size group_dims = MTL::Size(32, bn, bm);
@@ -1366,18 +1372,18 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.set_input_array(vec, 1);
compute_encoder.set_output_array(out, 3);
compute_encoder.set_bytes(in_vector_len, 4);
compute_encoder.set_bytes(out_vector_len, 5);
compute_encoder.set_bytes(mat_ld, 6);
compute_encoder.set_bytes(batch_ndim, 9);
compute_encoder.set_vector_bytes(batch_shape, 10);
compute_encoder.set_vector_bytes(batch_strides_vec, 11);
compute_encoder.set_vector_bytes(batch_strides_mat, 12);
compute_encoder->setBytes(&in_vector_len, sizeof(int), 4);
compute_encoder->setBytes(&out_vector_len, sizeof(int), 5);
compute_encoder->setBytes(&mat_ld, sizeof(int), 6);
compute_encoder->setBytes(&batch_ndim, sizeof(int), 9);
set_vector_bytes(compute_encoder, batch_shape, 10);
set_vector_bytes(compute_encoder, batch_strides_vec, 11);
set_vector_bytes(compute_encoder, batch_strides_mat, 12);
compute_encoder.set_vector_bytes(mask_strides, 23);
compute_encoder.set_vector_bytes(mask_batch_strides, 24);
set_vector_bytes(compute_encoder, mask_strides, 23);
set_vector_bytes(compute_encoder, mask_batch_strides, 24);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
d.add_temporaries(std::move(copies), s.index);
return;
@@ -1417,7 +1423,7 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
wn,
mn_aligned,
k_aligned);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder->setComputePipelineState(kernel);
// Use problem size to determine threadblock swizzle
int tn = (N + bn - 1) / bn;
@@ -1480,14 +1486,14 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.set_input_array(b, 1);
compute_encoder.set_output_array(out, 3);
compute_encoder.set_bytes(params, 4);
compute_encoder->setBytes(&params, sizeof(GEMMParams), 4);
compute_encoder.set_vector_bytes(batch_shape, 6);
compute_encoder.set_vector_bytes(batch_strides, 7);
set_vector_bytes(compute_encoder, batch_shape, 6);
set_vector_bytes(compute_encoder, batch_strides, 7);
compute_encoder.set_vector_bytes(mask_strides, 13);
set_vector_bytes(compute_encoder, mask_strides, 13);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
d.add_temporaries(std::move(copies), s.index);
}
@@ -1681,7 +1687,7 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder->setComputePipelineState(kernel);
int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp;
MTL::Size group_dims = MTL::Size(32, bn, bm);
@@ -1691,28 +1697,28 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.set_input_array(vec, 1);
compute_encoder.set_output_array(out, 3);
compute_encoder.set_bytes(in_vector_len, 4);
compute_encoder.set_bytes(out_vector_len, 5);
compute_encoder.set_bytes(mat_ld, 6);
compute_encoder->setBytes(&in_vector_len, sizeof(int), 4);
compute_encoder->setBytes(&out_vector_len, sizeof(int), 5);
compute_encoder->setBytes(&mat_ld, sizeof(int), 6);
compute_encoder.set_bytes(batch_ndim, 9);
compute_encoder.set_vector_bytes(batch_shape, 10);
compute_encoder.set_vector_bytes(batch_strides, 11);
compute_encoder->setBytes(&batch_ndim, sizeof(int), 9);
set_vector_bytes(compute_encoder, batch_shape, 10);
set_vector_bytes(compute_encoder, batch_strides, 11);
int batch_ndim_vec = batch_shape_vec.size();
compute_encoder.set_bytes(batch_ndim_vec, 12);
compute_encoder.set_vector_bytes(batch_shape_vec, 13);
compute_encoder.set_vector_bytes(batch_strides_vec, 14);
compute_encoder->setBytes(&batch_ndim_vec, sizeof(int), 12);
set_vector_bytes(compute_encoder, batch_shape_vec, 13);
set_vector_bytes(compute_encoder, batch_strides_vec, 14);
int batch_ndim_mat = batch_shape_mat.size();
compute_encoder.set_bytes(batch_ndim_mat, 15);
compute_encoder.set_vector_bytes(batch_shape_mat, 16);
compute_encoder.set_vector_bytes(batch_strides_mat, 17);
compute_encoder->setBytes(&batch_ndim_mat, sizeof(int), 15);
set_vector_bytes(compute_encoder, batch_shape_mat, 16);
set_vector_bytes(compute_encoder, batch_strides_mat, 17);
compute_encoder.set_input_array(lhs_indices, 18 + int(!is_b_matrix));
compute_encoder.set_input_array(rhs_indices, 18 + int(is_b_matrix));
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
d.add_temporaries(std::move(copies), s.index);
return;
@@ -1782,7 +1788,7 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
wm,
wn);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder->setComputePipelineState(kernel);
// Use problem size to determine threadblock swizzle
int tn = (N + bn - 1) / bn;
@@ -1821,10 +1827,10 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.set_input_array(b, 1);
compute_encoder.set_output_array(out, 3);
compute_encoder.set_bytes(params, 4);
compute_encoder->setBytes(&params, sizeof(GEMMParams), 4);
compute_encoder.set_vector_bytes(batch_shape, 6);
compute_encoder.set_vector_bytes(batch_strides, 7);
set_vector_bytes(compute_encoder, batch_shape, 6);
set_vector_bytes(compute_encoder, batch_strides, 7);
compute_encoder.set_input_array(lhs_indices, 10);
compute_encoder.set_input_array(rhs_indices, 11);
@@ -1839,11 +1845,11 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
operand_batch_ndim.push_back(0);
compute_encoder.set_vector_bytes(operand_shape, 13);
compute_encoder.set_vector_bytes(operand_strides, 14);
compute_encoder.set_vector_bytes(operand_batch_ndim, 15);
set_vector_bytes(compute_encoder, operand_shape, 13);
set_vector_bytes(compute_encoder, operand_strides, 14);
set_vector_bytes(compute_encoder, operand_batch_ndim, 15);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
d.add_temporaries(std::move(copies), s.index);
}

View File

@@ -1,11 +1,11 @@
// Copyright © 2023-2024 Apple Inc.
#include <cstdlib>
#include <memory>
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"
#include "mlx/scheduler.h"
#include "mlx/utils.h"
namespace mlx::core::metal {
@@ -13,6 +13,20 @@ bool is_available() {
return true;
}
int max_ops_per_buffer() {
auto get_val = []() {
if (const char* buff_str = std::getenv("MLX_MAX_OPS_PER_BUFFER")) {
return atoi(buff_str);
} else {
return 10;
}
};
static int max_ops_per_buffer_ = get_val();
return max_ops_per_buffer_;
}
#define MAX_OPS_PER_BUFFER max_ops_per_buffer()
inline void check_error(MTL::CommandBuffer* cbuf) {
if (cbuf->status() == MTL::CommandBufferStatusError) {
std::ostringstream msg;
@@ -63,8 +77,7 @@ std::function<void()> make_task(array arr, bool signal) {
out.set_status(array::Status::evaluated);
}
if (signal ||
d.get_command_buffer_ops(s.index) >= env::max_ops_per_buffer()) {
if (signal || d.get_command_buffer_ops(s.index) >= MAX_OPS_PER_BUFFER) {
d.end_encoding(s.index);
if (signal) {
command_buffer->encodeSignalEvent(
@@ -94,7 +107,6 @@ std::function<void()> make_synchronize_task(
Stream s,
std::shared_ptr<std::promise<void>> p) {
return [s, p = std::move(p)]() {
auto pool = new_scoped_memory_pool();
auto& d = metal::device(s.device);
auto cb = d.get_command_buffer(s.index);
cb->retain();

View File

@@ -2,7 +2,6 @@
#pragma once
#include <unordered_map>
#include <variant>
#include "mlx/array.h"

View File

@@ -1,4 +1,3 @@
// Copyright © 2024 Apple Inc.
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/utils.h"
@@ -100,7 +99,7 @@ MTL::ComputePipelineState* get_reduce_init_kernel(
const std::string& kernel_name,
const std::string&,
const std::string&,
const Dtype&) {
const array&) {
return d.get_kernel(kernel_name);
}
@@ -109,9 +108,8 @@ MTL::ComputePipelineState* get_reduce_kernel(
const std::string& kernel_name,
const std::string&,
const std::string&,
const Dtype&,
const Dtype&,
const std::string&,
const array&,
const array&,
int,
int,
int) {

View File

@@ -78,15 +78,18 @@ void RMSNorm::eval_gpu(
}
uint32_t w_stride = w.strides()[0];
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array(
x.data_shared_ptr() == nullptr ? out : x, 0);
compute_encoder.set_input_array(w, 1);
compute_encoder.set_output_array(out, 2);
compute_encoder.set_bytes(eps_, 3);
compute_encoder.set_bytes(axis_size, 4);
compute_encoder.set_bytes(w_stride, 5);
compute_encoder.dispatch_threads(grid_dims, group_dims);
compute_encoder->setBytes(&eps_, sizeof(float), 3);
compute_encoder->setBytes(&axis_size, sizeof(int), 4);
compute_encoder->setBytes(&w_stride, sizeof(uint32_t), 5);
compute_encoder->setThreadgroupMemoryLength(
16 * 8, 0); // minimum of 16 bytes
compute_encoder->setThreadgroupMemoryLength(simd_size * sizeof(float), 1);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
d.add_temporaries(std::move(copies), s.index);
@@ -180,16 +183,16 @@ void RMSNormVJP::eval_gpu(
}
uint32_t w_stride = w.strides()[0];
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array(x_in_gx ? gx : x, 0);
compute_encoder.set_input_array(w, 1);
compute_encoder.set_input_array(g_in_gx ? gx : (g_in_gw ? gw_temp : g), 2);
compute_encoder.set_output_array(gx, 3);
compute_encoder.set_output_array(gw_temp, 4);
compute_encoder.set_bytes(eps_, 5);
compute_encoder.set_bytes(axis_size, 6);
compute_encoder.set_bytes(w_stride, 7);
compute_encoder.dispatch_threads(grid_dims, group_dims);
compute_encoder->setBytes(&eps_, sizeof(float), 5);
compute_encoder->setBytes(&axis_size, sizeof(int), 6);
compute_encoder->setBytes(&w_stride, sizeof(uint32_t), 7);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
ReductionPlan plan(
@@ -270,17 +273,17 @@ void LayerNorm::eval_gpu(
uint32_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0;
uint32_t b_stride = (b.ndim() == 1) ? b.strides()[0] : 0;
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array(
x.data_shared_ptr() == nullptr ? out : x, 0);
compute_encoder.set_input_array(w, 1);
compute_encoder.set_input_array(b, 2);
compute_encoder.set_output_array(out, 3);
compute_encoder.set_bytes(eps_, 4);
compute_encoder.set_bytes(axis_size, 5);
compute_encoder.set_bytes(w_stride, 6);
compute_encoder.set_bytes(b_stride, 7);
compute_encoder.dispatch_threads(grid_dims, group_dims);
compute_encoder->setBytes(&eps_, sizeof(float), 4);
compute_encoder->setBytes(&axis_size, sizeof(int), 5);
compute_encoder->setBytes(&w_stride, sizeof(uint32_t), 6);
compute_encoder->setBytes(&b_stride, sizeof(uint32_t), 7);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
d.add_temporaries(std::move(copies), s.index);
@@ -392,16 +395,16 @@ void LayerNormVJP::eval_gpu(
}
uint32_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0;
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array(x_in_gx ? gx : x, 0);
compute_encoder.set_input_array(w, 1);
compute_encoder.set_input_array(g_in_gx ? gx : (g_in_gw ? gw_temp : g), 2);
compute_encoder.set_output_array(gx, 3);
compute_encoder.set_output_array(gw_temp, 4);
compute_encoder.set_bytes(eps_, 5);
compute_encoder.set_bytes(axis_size, 6);
compute_encoder.set_bytes(w_stride, 7);
compute_encoder.dispatch_threads(grid_dims, group_dims);
compute_encoder->setBytes(&eps_, sizeof(float), 5);
compute_encoder->setBytes(&axis_size, sizeof(int), 6);
compute_encoder->setBytes(&w_stride, sizeof(uint32_t), 7);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
if (gw.ndim() == 1 && gw.size() == axis_size) {

View File

@@ -5,7 +5,6 @@
#include <sstream>
#include "mlx/backend/common/load.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels.h"
@@ -18,10 +17,10 @@
namespace mlx::core {
template <typename T>
void arange_set_scalars(T start, T next, metal::CommandEncoder& enc) {
enc.set_bytes(start, 0);
void arange_set_scalars(T start, T next, CommandEncoder& enc) {
enc->setBytes(&start, sizeof(T), 0);
T step = next - start;
enc.set_bytes(step, 1);
enc->setBytes(&step, sizeof(T), 1);
}
void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -38,7 +37,7 @@ void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
MTL::Size group_dims = MTL::Size(
std::min(nthreads, kernel->maxTotalThreadsPerThreadgroup()), 1, 1);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder->setComputePipelineState(kernel);
switch (out.dtype()) {
case bool_: // unsupported
@@ -81,7 +80,7 @@ void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
}
compute_encoder.set_output_array(out, 2);
compute_encoder.dispatch_threads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -130,25 +129,25 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
size_t n_threads = out.size() * thread_group_size;
MTL::Size grid_dims = MTL::Size(n_threads, 1, 1);
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(out, 1);
if (ndim == 0) {
// Pass place holders so metal doesn't complain
int shape_ = 0;
size_t stride_ = 0;
compute_encoder.set_bytes(shape_, 2);
compute_encoder.set_bytes(stride_, 3);
compute_encoder.set_bytes(stride_, 4);
compute_encoder->setBytes(&shape_, sizeof(int), 2);
compute_encoder->setBytes(&stride_, sizeof(size_t), 3);
compute_encoder->setBytes(&stride_, sizeof(size_t), 4);
} else {
compute_encoder.set_vector_bytes(shape, 2);
compute_encoder.set_vector_bytes(in_strides, 3);
compute_encoder.set_vector_bytes(out_strides, 4);
compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 2);
compute_encoder->setBytes(in_strides.data(), ndim * sizeof(size_t), 3);
compute_encoder->setBytes(out_strides.data(), ndim * sizeof(size_t), 4);
}
compute_encoder.set_bytes(ndim, 5);
compute_encoder.set_bytes(axis_stride, 6);
compute_encoder.set_bytes(axis_size, 7);
compute_encoder.dispatch_threads(grid_dims, group_dims);
compute_encoder->setBytes(&ndim, sizeof(size_t), 5);
compute_encoder->setBytes(&axis_stride, sizeof(size_t), 6);
compute_encoder->setBytes(&axis_size, sizeof(size_t), 7);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
}
@@ -170,17 +169,6 @@ void Concatenate::eval_gpu(const std::vector<array>& inputs, array& out) {
concatenate_gpu(inputs, out, axis_, stream());
}
void Contiguous::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (in.flags().row_contiguous ||
(allow_col_major_ && in.flags().col_contiguous)) {
move_or_copy(in, out);
} else {
copy_gpu(in, out, CopyType::General);
}
}
void Copy::eval_gpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
@@ -285,22 +273,24 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
// organize into grid nkeys x elem_per_key
MTL::Size grid_dims = MTL::Size(num_keys, half_size + odd, 1);
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
auto group_dims = get_block_dims(num_keys, half_size + odd, 1);
MTL::Size group_dims = MTL::Size(1, thread_group_size, 1);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array(keys, 0);
compute_encoder.set_output_array(out, 1);
compute_encoder.set_bytes(odd, 2);
compute_encoder.set_bytes(bytes_per_key, 3);
compute_encoder->setBytes(&odd, sizeof(bool), 2);
compute_encoder->setBytes(&bytes_per_key, sizeof(size_t), 3);
if (!keys.flags().row_contiguous) {
int ndim = keys.ndim();
compute_encoder.set_bytes(ndim, 4);
compute_encoder.set_vector_bytes(keys.shape(), 5);
compute_encoder.set_vector_bytes(keys.strides(), 6);
compute_encoder->setBytes(&ndim, sizeof(int), 4);
compute_encoder->setBytes(
keys.shape().data(), keys.ndim() * sizeof(int), 5);
compute_encoder->setBytes(
keys.strides().data(), keys.ndim() * sizeof(size_t), 6);
}
compute_encoder.dispatch_threads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
void Reshape::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -355,7 +345,7 @@ void SliceUpdate::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& upd = inputs[1];
if (upd.size() == 0) {
move_or_copy(in, out);
out.copy_shared_buffer(in);
return;
}
@@ -428,12 +418,12 @@ void View::eval_gpu(const std::vector<array>& inputs, array& out) {
if (ibytes == obytes || obytes < ibytes && in.strides().back() == 1 ||
in.flags().row_contiguous) {
auto strides = in.strides();
for (int i = 0; i < static_cast<int>(strides.size()) - 1; ++i) {
for (int i = 0; i < strides.size() - 1; ++i) {
strides[i] *= ibytes;
strides[i] /= obytes;
}
move_or_copy(
in, out, strides, in.flags(), in.data_size() * ibytes / obytes);
out.copy_shared_buffer(
in, strides, in.flags(), in.data_size() * ibytes / obytes);
} else {
auto tmp = array(in.shape(), in.dtype(), nullptr, {});
tmp.set_data(allocator::malloc_or_wait(tmp.nbytes()));

View File

@@ -10,7 +10,6 @@
#include "mlx/backend/metal/utils.h"
#include "mlx/fast_primitives.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
namespace mlx::core {
@@ -102,31 +101,31 @@ void launch_qmm(
auto& d = metal::device(s.device);
auto kernel = get_quantized_kernel(d, kname.str(), template_def);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array(w, 0);
compute_encoder.set_input_array(scales, 1);
compute_encoder.set_input_array(biases, 2);
compute_encoder.set_input_array(x, 3);
compute_encoder.set_output_array(out, 4);
compute_encoder.set_bytes(D, 5);
compute_encoder.set_bytes(O, 6);
compute_encoder->setBytes(&D, sizeof(int), 5);
compute_encoder->setBytes(&O, sizeof(int), 6);
int offset = 7;
if (matrix) {
compute_encoder.set_bytes(B, 7);
compute_encoder->setBytes(&B, sizeof(int), 7);
offset += 1;
}
if (batched || gather) {
compute_encoder.set_bytes(x_batch_ndims, offset);
compute_encoder.set_vector_bytes(x_shape, offset + 1);
compute_encoder.set_vector_bytes(x_strides, offset + 2);
compute_encoder.set_bytes(w_batch_ndims, offset + 3);
compute_encoder.set_vector_bytes(w_shape, offset + 4);
compute_encoder.set_vector_bytes(w_strides, offset + 5);
compute_encoder.set_vector_bytes(s_strides, offset + 6);
compute_encoder.set_vector_bytes(b_strides, offset + 7);
compute_encoder->setBytes(&x_batch_ndims, sizeof(int), offset);
set_vector_bytes(compute_encoder, x_shape, offset + 1);
set_vector_bytes(compute_encoder, x_strides, offset + 2);
compute_encoder->setBytes(&w_batch_ndims, sizeof(int), offset + 3);
set_vector_bytes(compute_encoder, w_shape, offset + 4);
set_vector_bytes(compute_encoder, w_strides, offset + 5);
set_vector_bytes(compute_encoder, s_strides, offset + 6);
set_vector_bytes(compute_encoder, b_strides, offset + 7);
}
if (gather) {
auto& lhs_indices = inputs[4];
@@ -138,15 +137,15 @@ void launch_qmm(
auto& lhs_strides = lhs_indices.strides();
auto& rhs_strides = rhs_indices.strides();
compute_encoder.set_bytes(batch_ndims, offset + 8);
compute_encoder.set_vector_bytes(batch_shape, offset + 9);
compute_encoder->setBytes(&batch_ndims, sizeof(int), offset + 8);
set_vector_bytes(compute_encoder, batch_shape, offset + 9);
compute_encoder.set_input_array(lhs_indices, offset + 10);
compute_encoder.set_input_array(rhs_indices, offset + 11);
compute_encoder.set_vector_bytes(lhs_strides, offset + 12);
compute_encoder.set_vector_bytes(rhs_strides, offset + 13);
set_vector_bytes(compute_encoder, lhs_strides, offset + 12);
set_vector_bytes(compute_encoder, rhs_strides, offset + 13);
}
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
d.add_temporaries(std::move(copies), s.index);
}
@@ -237,27 +236,27 @@ void qvm_split_k(
// Encode and dispatch kernel
auto kernel = get_quantized_kernel(d, kname.str(), template_def);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array(w, 0);
compute_encoder.set_input_array(scales, 1);
compute_encoder.set_input_array(biases, 2);
compute_encoder.set_input_array(x, 3);
compute_encoder.set_output_array(intermediate, 4);
compute_encoder.set_bytes(split_D, 5);
compute_encoder.set_bytes(O, 6);
compute_encoder->setBytes(&split_D, sizeof(int), 5);
compute_encoder->setBytes(&O, sizeof(int), 6);
compute_encoder.set_bytes(x_batch_ndims, 7);
compute_encoder.set_vector_bytes(x_shape, 8);
compute_encoder.set_vector_bytes(x_strides, 9);
compute_encoder.set_bytes(w_batch_ndims, 10);
compute_encoder.set_vector_bytes(w_shape, 11);
compute_encoder.set_vector_bytes(w_strides, 12);
compute_encoder.set_vector_bytes(s_strides, 13);
compute_encoder.set_vector_bytes(b_strides, 14);
compute_encoder.set_bytes(final_block_size, 15);
compute_encoder->setBytes(&x_batch_ndims, sizeof(int), 7);
set_vector_bytes(compute_encoder, x_shape, 8);
set_vector_bytes(compute_encoder, x_strides, 9);
compute_encoder->setBytes(&w_batch_ndims, sizeof(int), 10);
set_vector_bytes(compute_encoder, w_shape, 11);
set_vector_bytes(compute_encoder, w_strides, 12);
set_vector_bytes(compute_encoder, s_strides, 13);
set_vector_bytes(compute_encoder, b_strides, 14);
compute_encoder->setBytes(&final_block_size, sizeof(int), 15);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
d.add_temporaries(std::move(copies), s.index);
int axis = intermediate.ndim() - 3;
@@ -299,7 +298,7 @@ void qmm_op(
bool quad = false;
if (transpose) {
if (B < 6 && (D == 128 || D == 64) && is_power_of_2(bits)) {
if (B < 6 && (D == 128 || D == 64)) {
name += "qmv_quad";
constexpr int quads_per_simd = 8;
constexpr int results_per_quadgroup = 8;
@@ -392,6 +391,8 @@ void GatherQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
void fast::AffineQuantize::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
bool compute_scale_bias = inputs.size() == 1;
auto& w_pre = inputs[0];
auto& out = outputs[0];
out.set_data(allocator::malloc_or_wait(out.nbytes()));
@@ -414,7 +415,7 @@ void fast::AffineQuantize::eval_gpu(
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_input_array(w, 0);
if (dequantize_) {
if (!compute_scale_bias) {
auto& scales_pre = inputs[1];
auto& biases_pre = inputs[2];
auto scales = ensure_row_contiguous(scales_pre);
@@ -435,21 +436,26 @@ void fast::AffineQuantize::eval_gpu(
std::ostringstream kname;
auto type_string = dequantize_ ? get_type_string(out.dtype())
: get_type_string(w_pre.dtype());
auto kernel_func = dequantize_ ? "affine_dequantize" : "affine_quantize";
auto kernel_func = "affine_quantize_scales_biases";
if (dequantize_) {
kernel_func = "affine_dequantize";
} else if (compute_scale_bias) {
kernel_func = "affine_quantize";
}
kname << kernel_func << "_" << type_string << "_gs_" << group_size_ << "_b_"
<< bits_;
auto template_def = get_template_definition(
kname.str(), kernel_func, type_string, group_size_, bits_);
auto kernel = get_quantized_kernel(d, kname.str(), template_def);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder->setComputePipelineState(kernel);
// Treat uint32 as uint8 in kernel
constexpr int uint8_per_uint32 = 4;
constexpr int simd_size = 32;
int packs_per_int = bits_ == 3 ? 8 : bits_ == 6 ? 4 : 8 / bits_;
int per_thread = dequantize_ ? packs_per_int : group_size_ / simd_size;
int packs_per_int = 8 / bits_;
int per_thread = compute_scale_bias ? group_size_ / simd_size : packs_per_int;
size_t nthreads =
dequantize_ ? out.size() / packs_per_int : w.size() / per_thread;
dequantize_ ? w.size() * uint8_per_uint32 : w.size() / per_thread;
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size > nthreads) {
@@ -465,7 +471,7 @@ void fast::AffineQuantize::eval_gpu(
}
MTL::Size grid_dims = use_2d ? get_2d_grid_dims(grid_shape, w.strides())
: MTL::Size(nthreads, 1, 1);
compute_encoder.dispatch_threads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
d.add_temporaries(std::move(copies), s.index);
}

View File

@@ -2,6 +2,7 @@
#include <algorithm>
#include <cassert>
#include <sstream>
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
@@ -66,14 +67,17 @@ struct RowReduceArgs {
strides.push_back(0);
}
compute_encoder.set_bytes(row_size, 2);
compute_encoder.set_bytes(non_row_reductions, 3);
compute_encoder.set_vector_bytes(shape, 4);
compute_encoder.set_vector_bytes(strides, 5);
compute_encoder.set_bytes(ndim, 6);
compute_encoder.set_vector_bytes(reduce_shape, 7);
compute_encoder.set_vector_bytes(reduce_strides, 8);
compute_encoder.set_bytes(reduce_ndim, 9);
compute_encoder->setBytes(&row_size, sizeof(size_t), 2);
compute_encoder->setBytes(&non_row_reductions, sizeof(size_t), 3);
compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), 4);
compute_encoder->setBytes(
strides.data(), strides.size() * sizeof(size_t), 5);
compute_encoder->setBytes(&ndim, sizeof(int), 6);
compute_encoder->setBytes(
reduce_shape.data(), reduce_shape.size() * sizeof(int), 7);
compute_encoder->setBytes(
reduce_strides.data(), reduce_strides.size() * sizeof(size_t), 8);
compute_encoder->setBytes(&reduce_ndim, sizeof(int), 9);
if (reduce_ndim == 0) {
reduce_shape.pop_back();
@@ -162,15 +166,18 @@ struct ColReduceArgs {
strides.push_back(0);
}
compute_encoder.set_bytes(reduction_size, 2);
compute_encoder.set_bytes(reduction_stride, 3);
compute_encoder.set_vector_bytes(shape, 4);
compute_encoder.set_vector_bytes(strides, 5);
compute_encoder.set_bytes(ndim, 6);
compute_encoder.set_vector_bytes(reduce_shape, 7);
compute_encoder.set_vector_bytes(reduce_strides, 8);
compute_encoder.set_bytes(reduce_ndim, 9);
compute_encoder.set_bytes(non_col_reductions, 10);
compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
compute_encoder->setBytes(&reduction_stride, sizeof(size_t), 3);
compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), 4);
compute_encoder->setBytes(
strides.data(), strides.size() * sizeof(size_t), 5);
compute_encoder->setBytes(&ndim, sizeof(int), 6);
compute_encoder->setBytes(
reduce_shape.data(), reduce_shape.size() * sizeof(int), 7);
compute_encoder->setBytes(
reduce_strides.data(), reduce_strides.size() * sizeof(size_t), 8);
compute_encoder->setBytes(&reduce_ndim, sizeof(int), 9);
compute_encoder->setBytes(&non_col_reductions, sizeof(size_t), 10);
if (reduce_ndim == 0) {
reduce_shape.pop_back();
@@ -201,16 +208,6 @@ inline bool is_64b_dtype(Dtype dtype) {
return dtype == int64 || dtype == uint64 || dtype == complex64;
}
inline int get_kernel_reduce_ndim(int reduce_ndim) {
if (reduce_ndim <= 1) {
return 1;
} else if (reduce_ndim == 2) {
return 2;
} else {
return 5;
}
}
inline int threadgroup_size_from_row_size(int row_size) {
// 1 simdgroup per row smallish rows
if (row_size <= 512) {
@@ -242,51 +239,16 @@ inline auto output_grid_for_col_reduce(
return get_2d_grid_dims(out_shape, out_strides);
}
std::pair<Dtype, Dtype> remap_reduce_types(
const array& in,
const std::string& op_name) {
if (op_name == "sum" || op_name == "prod") {
if (issubdtype(in.dtype(), integer)) {
switch (in.dtype().size()) {
case 1:
return {int8, int32};
case 2:
return {int16, int32};
case 4:
return {int32, int32};
case 8:
return {int64, int64};
}
}
if (in.dtype() == bool_) {
return {int8, int32};
}
return {in.dtype(), in.dtype()};
} else if (op_name == "and" || op_name == "or") {
if (in.dtype().size() == 1) {
return {bool_, bool_};
} else if (in.dtype().size() == 2) {
return {int16, bool_};
} else if (in.dtype().size() == 4) {
return {int32, bool_};
} else {
return {int64, bool_};
}
}
return {in.dtype(), in.dtype()};
}
void init_reduce(
array& out,
const std::string& op_name,
CommandEncoder& compute_encoder,
metal::Device& d,
const Stream& s) {
auto [_, out_type] = remap_reduce_types(out, op_name);
std::ostringstream kname;
const std::string func_name = "init_reduce";
std::string kname = func_name;
concatenate(kname, "_", op_name, type_to_name(out_type));
auto kernel = get_reduce_init_kernel(d, kname, func_name, op_name, out_type);
kname << func_name << "_" << op_name << type_to_name(out);
auto kernel = get_reduce_init_kernel(d, kname.str(), func_name, op_name, out);
size_t nthreads = out.size();
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
@@ -294,9 +256,9 @@ void init_reduce(
thread_group_size = nthreads;
}
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_output_array(out, 0);
compute_encoder.dispatch_threads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
void all_reduce_dispatch(
@@ -307,13 +269,11 @@ void all_reduce_dispatch(
metal::Device& d,
const Stream& s) {
// Set the kernel
auto [in_type, out_type] = remap_reduce_types(in, op_name);
std::ostringstream kname;
const std::string func_name = "all_reduce";
std::string kname = func_name;
concatenate(kname, "_", op_name, type_to_name(in_type));
auto kernel = get_reduce_kernel(
d, kname, func_name, op_name, in_type, out_type, "int64_t");
compute_encoder.set_compute_pipeline_state(kernel);
kname << func_name << "_" << op_name << type_to_name(in);
auto kernel = get_reduce_kernel(d, kname.str(), func_name, op_name, in, out);
compute_encoder->setComputePipelineState(kernel);
size_t in_size = in.size();
@@ -325,9 +285,9 @@ void all_reduce_dispatch(
compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(out, 1);
compute_encoder.set_bytes(in_size, 2);
compute_encoder.set_bytes(in_size, 3);
compute_encoder.dispatch_threads(grid_dims, grid_dims);
compute_encoder->setBytes(&in_size, sizeof(size_t), 2);
compute_encoder->setBytes(&in_size, sizeof(size_t), 3);
compute_encoder.dispatchThreads(grid_dims, grid_dims);
}
// We need multiple threadgroups so we 'll do it in 2 passes.
@@ -346,7 +306,7 @@ void all_reduce_dispatch(
}
// Allocate an intermediate tensor to hold results if needed
array intermediate({n_rows}, out_type, nullptr, {});
array intermediate({n_rows}, out.dtype(), nullptr, {});
intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes()));
d.add_temporary(intermediate, s.index);
@@ -359,24 +319,24 @@ void all_reduce_dispatch(
MTL::Size group_dims(threadgroup_size, 1, 1);
compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(intermediate, 1);
compute_encoder.set_bytes(in_size, 2);
compute_encoder.set_bytes(row_size, 3);
compute_encoder.dispatch_threads(grid_dims, group_dims);
compute_encoder->setBytes(&in_size, sizeof(size_t), 2);
compute_encoder->setBytes(&row_size, sizeof(size_t), 3);
compute_encoder.dispatchThreads(grid_dims, group_dims);
// 2nd pass
std::string kname_2nd_pass = func_name;
concatenate(kname_2nd_pass, "_", op_name, type_to_name(intermediate));
std::ostringstream kname_2nd_pass;
kname_2nd_pass << func_name << "_" << op_name << type_to_name(intermediate);
auto kernel_2nd_pass = get_reduce_kernel(
d, kname_2nd_pass, func_name, op_name, out_type, out_type, "int64_t");
compute_encoder.set_compute_pipeline_state(kernel_2nd_pass);
d, kname_2nd_pass.str(), func_name, op_name, intermediate, out);
compute_encoder->setComputePipelineState(kernel_2nd_pass);
size_t intermediate_size = n_rows;
grid_dims = MTL::Size(threadgroup_2nd_pass, 1, 1);
group_dims = MTL::Size(threadgroup_2nd_pass, 1, 1);
compute_encoder.set_input_array(intermediate, 0);
compute_encoder.set_output_array(out, 1);
compute_encoder.set_bytes(intermediate_size, 2);
compute_encoder.set_bytes(intermediate_size, 3);
compute_encoder.dispatch_threads(grid_dims, group_dims);
compute_encoder->setBytes(&intermediate_size, sizeof(size_t), 2);
compute_encoder->setBytes(&intermediate_size, sizeof(size_t), 3);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
}
@@ -389,31 +349,13 @@ void row_reduce_small(
metal::Device& d,
const Stream& s) {
// Set the kernel
int n = get_kernel_reduce_ndim(args.reduce_ndim);
auto [in_type, out_type] = remap_reduce_types(in, op_name);
std::ostringstream kname;
int n = (args.reduce_ndim < 5) ? std::max(1, args.reduce_ndim) : 0;
const std::string func_name = "row_reduce_small";
std::string kname = func_name;
bool large = in.size() > UINT32_MAX;
if (large) {
kname += "_large";
}
concatenate(
kname,
"_",
std::to_string(n),
"_reduce_",
op_name,
type_to_name(in_type));
auto kernel = get_reduce_kernel(
d,
kname,
func_name,
op_name,
in_type,
out_type,
large ? "size_t" : "uint",
n);
compute_encoder.set_compute_pipeline_state(kernel);
kname << func_name << "_" << n << "_reduce_" << op_name << type_to_name(in);
auto kernel =
get_reduce_kernel(d, kname.str(), func_name, op_name, in, out, n);
compute_encoder->setComputePipelineState(kernel);
// Figure out the grid dims
MTL::Size grid_dims;
@@ -433,7 +375,7 @@ void row_reduce_small(
compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(out, 1);
args.encode(compute_encoder);
compute_encoder.dispatch_threads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
void row_reduce_simple(
@@ -445,14 +387,11 @@ void row_reduce_simple(
metal::Device& d,
const Stream& s) {
// Set the kernel
auto [in_type, out_type] = remap_reduce_types(in, op_name);
std::ostringstream kname;
const std::string func_name = "row_reduce_simple";
std::string kname = func_name;
concatenate(kname, "_", op_name, type_to_name(in_type));
auto kernel = get_reduce_kernel(
d, kname, func_name, op_name, in_type, out_type, "size_t");
compute_encoder.set_compute_pipeline_state(kernel);
kname << func_name << "_" << op_name << type_to_name(in);
auto kernel = get_reduce_kernel(d, kname.str(), func_name, op_name, in, out);
compute_encoder->setComputePipelineState(kernel);
// Figure out the grid dims
size_t row_size = args.row_size;
@@ -471,9 +410,9 @@ void row_reduce_simple(
// Launch
compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(out, 1);
compute_encoder.set_bytes(row_size, 2);
compute_encoder.set_bytes(out_size, 3);
compute_encoder.dispatch_threads(grid_dims, group_dims);
compute_encoder->setBytes(&row_size, sizeof(size_t), 2);
compute_encoder->setBytes(&out_size, sizeof(size_t), 3);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
void row_reduce_looped(
@@ -484,33 +423,14 @@ void row_reduce_looped(
CommandEncoder& compute_encoder,
metal::Device& d,
const Stream& s) {
auto [in_type, out_type] = remap_reduce_types(in, op_name);
// Set the kernel
int n = get_kernel_reduce_ndim(args.reduce_ndim);
std::ostringstream kname;
int n = (args.reduce_ndim < 5) ? std::max(1, args.reduce_ndim) : 0;
const std::string func_name = "row_reduce_looped";
std::string kname = func_name;
bool large = in.size() > UINT32_MAX;
if (large) {
kname += "_large";
}
concatenate(
kname,
"_",
std::to_string(n),
"_reduce_",
op_name,
type_to_name(in_type));
auto kernel = get_reduce_kernel(
d,
kname,
func_name,
op_name,
in_type,
out_type,
large ? "size_t" : "uint",
n);
compute_encoder.set_compute_pipeline_state(kernel);
kname << func_name << "_" << n << "_reduce_" << op_name << type_to_name(in);
auto kernel =
get_reduce_kernel(d, kname.str(), func_name, op_name, in, out, n);
compute_encoder->setComputePipelineState(kernel);
// Figure out the grid
auto out_grid_size = get_2d_grid_dims(out.shape(), out.strides());
@@ -523,7 +443,7 @@ void row_reduce_looped(
compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(out, 1);
args.encode(compute_encoder);
compute_encoder.dispatch_threads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
void row_reduce_general_dispatch(
@@ -561,8 +481,6 @@ void strided_reduce_small(
CommandEncoder& compute_encoder,
metal::Device& d,
const Stream& s) {
auto [in_type, out_type] = remap_reduce_types(in, op_name);
// Figure out the grid dims
MTL::Size grid_dims, group_dims;
@@ -571,30 +489,13 @@ void strided_reduce_small(
args.reduce_strides.push_back(args.reduction_stride);
args.reduce_ndim++;
int n = get_kernel_reduce_ndim(args.reduce_ndim);
int n = (args.reduce_ndim < 5) ? std::max(1, args.reduce_ndim) : 0;
std::ostringstream kname;
const std::string func_name = "col_reduce_small";
std::string kname = func_name;
bool large = in.size() > UINT32_MAX;
if (large) {
kname += "_large";
}
concatenate(
kname,
"_",
std::to_string(n),
"_reduce_",
op_name,
type_to_name(in_type));
auto kernel = get_reduce_kernel(
d,
kname,
func_name,
op_name,
in_type,
out_type,
large ? "size_t" : "uint",
n);
compute_encoder.set_compute_pipeline_state(kernel);
kname << func_name << "_" << n << "_reduce_" << op_name << type_to_name(in);
auto kernel =
get_reduce_kernel(d, kname.str(), func_name, op_name, in, out, n);
compute_encoder->setComputePipelineState(kernel);
const int n_reads = 4;
size_t reduction_stride_blocks =
@@ -616,7 +517,7 @@ void strided_reduce_small(
compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(out, 1);
args.encode(compute_encoder);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
}
void strided_reduce_longcolumn(
@@ -627,7 +528,6 @@ void strided_reduce_longcolumn(
CommandEncoder& compute_encoder,
metal::Device& d,
const Stream& s) {
auto [in_type, out_type] = remap_reduce_types(in, op_name);
size_t total_reduction_size = args.reduction_size * args.non_col_reductions;
size_t outer_blocks = 32;
if (total_reduction_size >= 32768) {
@@ -640,7 +540,7 @@ void strided_reduce_longcolumn(
intermediate_shape.push_back(outer_blocks);
intermediate_shape.insert(
intermediate_shape.end(), out.shape().begin(), out.shape().end());
array intermediate(std::move(intermediate_shape), out_type, nullptr, {});
array intermediate(std::move(intermediate_shape), out.dtype(), nullptr, {});
intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes()));
d.add_temporary(intermediate, s.index);
@@ -662,37 +562,20 @@ void strided_reduce_longcolumn(
MTL::Size group_dims(threadgroup_x, threadgroup_y, 1);
// Set the kernel
int n = get_kernel_reduce_ndim(args.reduce_ndim);
std::string func_name = "col_reduce_longcolumn";
std::string kname = func_name;
bool large = in.size() > UINT32_MAX;
if (large) {
kname += "_large";
}
concatenate(
kname,
"_",
std::to_string(n),
"_reduce_",
op_name,
type_to_name(in_type));
auto kernel = get_reduce_kernel(
d,
kname,
func_name,
op_name,
in_type,
out_type,
large ? "size_t" : "uint",
n);
compute_encoder.set_compute_pipeline_state(kernel);
int n = (args.reduce_ndim < 5) ? std::max(1, args.reduce_ndim) : 0;
std::ostringstream kname;
const std::string func_name = "col_reduce_longcolumn";
kname << func_name << "_" << n << "_reduce_" << op_name << type_to_name(in);
auto kernel =
get_reduce_kernel(d, kname.str(), func_name, op_name, in, out, n);
compute_encoder->setComputePipelineState(kernel);
// Launch
compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(intermediate, 1);
args.encode(compute_encoder);
compute_encoder.set_bytes(out_size, 11);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
compute_encoder->setBytes(&out_size, sizeof(size_t), 11);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
// Make the 2nd pass arguments and grid_dims
ColReduceArgs second_args(intermediate);
@@ -704,30 +587,24 @@ void strided_reduce_longcolumn(
group_dims = MTL::Size(256, 1, 1);
// Set the 2nd kernel
func_name = "col_reduce_looped";
kname = func_name;
large = intermediate.size() > UINT32_MAX;
if (large) {
kname += "_large";
}
concatenate(kname, "_1_32_32_reduce_", op_name, type_to_name(intermediate));
const std::string second_kernel = "col_reduce_looped_1_32_32_reduce_" +
op_name + type_to_name(intermediate);
kernel = get_reduce_kernel(
d,
kname,
func_name,
second_kernel,
"col_reduce_looped",
op_name,
intermediate.dtype(),
out_type,
large ? "size_t" : "uint",
intermediate,
out,
1,
32,
32);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array(intermediate, 0);
compute_encoder.set_output_array(out, 1);
second_args.encode(compute_encoder);
compute_encoder.dispatch_threads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
void strided_reduce_looped(
@@ -738,8 +615,6 @@ void strided_reduce_looped(
CommandEncoder& compute_encoder,
metal::Device& d,
const Stream& s) {
auto [in_type, out_type] = remap_reduce_types(in, op_name);
// Prepare the arguments for the kernel
args.reduce_shape.push_back(args.reduction_size);
args.reduce_strides.push_back(args.reduction_stride);
@@ -757,42 +632,20 @@ void strided_reduce_looped(
MTL::Size group_dims(threadgroup_size, 1, 1);
// Set the kernel
int n = get_kernel_reduce_ndim(args.reduce_ndim);
std::string func_name = "col_reduce_looped";
std::string kname = func_name;
bool large = in.size() > UINT32_MAX;
if (large) {
kname += "_large";
}
concatenate(
kname,
"_",
std::to_string(n),
"_",
std::to_string(BM),
"_",
std::to_string(BN),
"_reduce_",
op_name,
type_to_name(in_type));
auto kernel = get_reduce_kernel(
d,
kname,
func_name,
op_name,
in_type,
out_type,
large ? "size_t" : "uint",
n,
BM,
BN);
compute_encoder.set_compute_pipeline_state(kernel);
int n = (args.reduce_ndim < 5) ? std::max(1, args.reduce_ndim) : 0;
std::ostringstream kname;
const std::string func_name = "col_reduce_looped";
kname << func_name << "_" << n << "_" << BM << "_" << BN << "_reduce_"
<< op_name << type_to_name(in);
auto kernel =
get_reduce_kernel(d, kname.str(), func_name, op_name, in, out, n, BM, BN);
compute_encoder->setComputePipelineState(kernel);
// Launch
compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(out, 1);
args.encode(compute_encoder);
compute_encoder.dispatch_threads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
void strided_reduce_2pass(
@@ -803,15 +656,13 @@ void strided_reduce_2pass(
CommandEncoder& compute_encoder,
metal::Device& d,
const Stream& s) {
auto [in_type, out_type] = remap_reduce_types(in, op_name);
// Prepare the temporary accumulator
std::vector<int> intermediate_shape;
intermediate_shape.reserve(out.ndim() + 1);
intermediate_shape.push_back(32);
intermediate_shape.insert(
intermediate_shape.end(), out.shape().begin(), out.shape().end());
array intermediate(std::move(intermediate_shape), out_type, nullptr, {});
array intermediate(std::move(intermediate_shape), out.dtype(), nullptr, {});
intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes()));
d.add_temporary(intermediate, s.index);
@@ -834,43 +685,21 @@ void strided_reduce_2pass(
MTL::Size group_dims(threadgroup_size, 1, 1);
// Set the kernel
int n = get_kernel_reduce_ndim(args.reduce_ndim);
std::string func_name = "col_reduce_2pass";
std::string kname = func_name;
bool large = in.size() > UINT32_MAX;
if (large) {
kname += "_large";
}
concatenate(
kname,
"_",
std::to_string(n),
"_",
std::to_string(BM),
"_",
std::to_string(BN),
"_reduce_",
op_name,
type_to_name(in_type));
auto kernel = get_reduce_kernel(
d,
kname,
func_name,
op_name,
in_type,
out_type,
large ? "size_t" : "uint",
n,
BM,
BN);
compute_encoder.set_compute_pipeline_state(kernel);
int n = (args.reduce_ndim < 5) ? std::max(1, args.reduce_ndim) : 0;
std::ostringstream kname;
const std::string func_name = "col_reduce_2pass";
kname << func_name << "_" << n << "_" << BM << "_" << BN << "_reduce_"
<< op_name << type_to_name(in);
auto kernel =
get_reduce_kernel(d, kname.str(), func_name, op_name, in, out, n, BM, BN);
compute_encoder->setComputePipelineState(kernel);
// Launch
compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(intermediate, 1);
args.encode(compute_encoder);
compute_encoder.set_bytes(out_size, 11);
compute_encoder.dispatch_threads(grid_dims, group_dims);
compute_encoder->setBytes(&out_size, sizeof(size_t), 11);
compute_encoder.dispatchThreads(grid_dims, group_dims);
// Make the 2nd pass arguments and grid_dims
ColReduceArgs second_args(intermediate);
@@ -880,30 +709,24 @@ void strided_reduce_2pass(
grid_dims = MTL::Size(threadgroup_size * ((out.size() + BN - 1) / BN), 1, 1);
// Set the 2nd kernel
func_name = "col_reduce_looped";
kname = func_name;
large = intermediate.size() > UINT32_MAX;
if (large) {
kname += "_large";
}
concatenate(kname, "_1_32_32_reduce_", op_name, type_to_name(intermediate));
const std::string second_kernel = "col_reduce_looped_1_32_32_reduce_" +
op_name + type_to_name(intermediate);
kernel = get_reduce_kernel(
d,
kname,
func_name,
second_kernel,
"col_reduce_looped",
op_name,
intermediate.dtype(),
out_type,
large ? "size_t" : "uint",
intermediate,
out,
1,
32,
32);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array(intermediate, 0);
compute_encoder.set_output_array(out, 1);
second_args.encode(compute_encoder);
compute_encoder.dispatch_threads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
void strided_reduce_general_dispatch(
@@ -963,7 +786,7 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
op_name = "sum";
break;
case Reduce::Prod:
op_name = "prod";
op_name = out.dtype() == bool_ ? "and" : "prod";
break;
case Reduce::Min:
op_name = out.dtype() == bool_ ? "and" : "min";

View File

@@ -63,7 +63,6 @@ void ResidencySet::resize(size_t size) {
size_t current_size = wired_set_->allocatedSize();
if (current_size < size) {
auto pool = new_scoped_memory_pool();
// Add unwired allocations to the set
for (auto it = unwired_set_.begin(); it != unwired_set_.end();) {
auto buf_size = (*it)->allocatedSize();
@@ -78,7 +77,6 @@ void ResidencySet::resize(size_t size) {
wired_set_->commit();
wired_set_->requestResidency();
} else if (current_size > size) {
auto pool = new_scoped_memory_pool();
// Remove wired allocations until under capacity
auto allocations = wired_set_->allAllocations();
auto num_allocations = wired_set_->allocationCount();
@@ -94,7 +92,6 @@ void ResidencySet::resize(size_t size) {
ResidencySet::~ResidencySet() {
if (wired_set_) {
auto pool = new_scoped_memory_pool();
wired_set_->release();
}
}

View File

@@ -75,24 +75,24 @@ void RoPE::eval_gpu(
auto& compute_encoder = d.get_command_encoder(s.index);
float base = std::log2(base_);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array(donated ? out : in, 0);
compute_encoder.set_output_array(out, 1);
compute_encoder.set_bytes(offset_, 2);
compute_encoder.set_bytes(scale_, 3);
compute_encoder->setBytes(&offset_, sizeof(int), 2);
compute_encoder->setBytes(&scale_, sizeof(float), 3);
size_t n_batch = in.size() / mat_size;
MTL::Size group_dims;
MTL::Size grid_dims;
if (single) {
compute_encoder.set_bytes(out_strides, 1, 4);
compute_encoder->setBytes(out_strides, sizeof(size_t), 4);
uint32_t dim0 = dims_ / 2;
group_dims = get_block_dims(dim0, n_batch, 1);
grid_dims = MTL::Size(dim0, n_batch, 1);
} else {
compute_encoder.set_bytes(strides, 3, 4);
compute_encoder.set_bytes(out_strides, 3, 5);
compute_encoder.set_bytes(n_batch, 6);
compute_encoder->setBytes(&strides, 3 * sizeof(size_t), 4);
compute_encoder->setBytes(&out_strides, 3 * sizeof(size_t), 5);
compute_encoder->setBytes(&n_batch, sizeof(size_t), 6);
uint32_t dim0 = dims_ / 2;
uint32_t dim1 = in.shape(-2);
uint32_t dim2 = (n_batch + n_per_thread - 1) / n_per_thread;
@@ -104,11 +104,11 @@ void RoPE::eval_gpu(
auto& freqs = inputs[1];
compute_encoder.set_input_array(freqs, 10);
auto freq_stride = freqs.strides()[0];
compute_encoder.set_bytes(freq_stride, 11);
compute_encoder->setBytes(&freq_stride, sizeof(size_t), 11);
} else {
compute_encoder.set_bytes(base, 10);
compute_encoder->setBytes(&base, sizeof(float), 10);
}
compute_encoder.dispatch_threads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
} // namespace mlx::core::fast

Some files were not shown because too many files have changed in this diff Show More