Compare commits

...

24 Commits

Author SHA1 Message Date
Awni Hannun
8bd6bfa4b5 version (#1099) 2024-05-09 17:52:39 -07:00
Awni Hannun
8b1906abd0 Add compiler flags to disable safetensors and gguf (#1098)
* with docs

* nit
2024-05-09 17:39:44 -07:00
Awni Hannun
06375e6605 Split encoders in non-concurrent context with a max ops per encoder (#1085)
* split encoders

* fix race
2024-05-09 16:21:02 -07:00
Awni Hannun
b21242faf1 Allow unary ops to accept array like (#1093) 2024-05-09 09:36:02 -07:00
Rahul Yedida
cc05a281c4 Added ArcTan2 operation (#1079)
* Added ArcTan2 operation

* Cleanup, bug fixes from code review

* Minor cleanup, fixed Linux tests
2024-05-08 08:35:15 -07:00
Jagrit Digani
fe96ceee66 Update block offset adjustment to be in size_t (#1087) 2024-05-08 08:10:23 -07:00
Awni Hannun
9814a2ae12 fix conversion to array (#1070) 2024-05-06 16:02:49 -07:00
Shubham
6992498e7a add keyword positonal (#1081) 2024-05-06 07:18:49 -07:00
Awni Hannun
21623156a3 Reset peak memory (#1074)
* reset peak memory

* fix linux

* nits in docs
2024-05-03 17:12:51 -07:00
Nripesh Niketan
79c859e2e0 feat: implement clip_grad_norm (#1043)
* feat: implement `clip_grad_norm`

* pre-commit

* Add test for clip_grad_norm function in test_optimizers.py

* small fixes

* fix

* lint

* Update tree_reduce

* Update python/mlx/utils.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/mlx/utils.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/mlx/utils.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/mlx/utils.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/mlx/utils.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/mlx/utils.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Refactor clip_grad_norm function to include documentation and improve readability

* format docstring

* Add acknowlegements

* text wrap

* pre-commit

* nits in docs

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>
2024-05-03 09:07:02 -07:00
Awni Hannun
b00ac960b4 change initial memory limits and add memory size to device info (#1064) 2024-05-03 06:50:15 -07:00
Awni Hannun
02a9fc7bfa Patch bump (#1067)
* version

* use 0.12.2
2024-05-02 16:37:31 -07:00
Jagrit Digani
f390957685 Block sparse mm (#1058) 2024-05-02 14:03:58 -07:00
Angelos Katharopoulos
17f57df797 Improvements in the quantizer and dequantization kernel (#1061) 2024-05-01 18:19:11 -07:00
Awni Hannun
7f7b9662ea Fix leak for multi-output primitives which are never detached (#1059)
* fix multi output leak

* ignore arrays that will be detached

* add some comments

* stray print
2024-05-01 07:31:45 -07:00
Awni Hannun
19bef39f5c Add a mx.metal.device_info (#1060)
* device inof

* add variant

* fix linux

* fix doc
2024-04-30 15:47:27 -07:00
Nripesh Niketan
a30e7ed2da feat: metal formatting and pre-commit bump (#1038)
* feat: metal formatting and pre-commit bump

* add guards

* update

* more guards

* more guards

* smakk fix

* Refactor instantiation of ternary types in ternary.metal

* fix scan.metal
2024-04-30 07:18:09 -07:00
Angelos Katharopoulos
8db7161c94 Bug fix in quantize (#1054) 2024-04-29 20:55:04 -07:00
Awni Hannun
09f1777896 fix slice update indexing (#1053) 2024-04-29 12:17:40 -07:00
Jacket
490c0c4fdc [Fix] expand axes for dimension with integer indices in mlx_slice_update (#1035)
* Not sure if this is correct

* Format

* Edit tests

* Add negative test

* Format

* add one more test

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-04-29 07:57:28 -07:00
Rifur13
c4a471c99d Add groups to Conv1d (#948)
* Add conv1d grouped convs on CPU

* Add GPU support

* Parallelize inside metal kernel

* clenaup

* Update mlx/ops.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* New unfold kernel + remove unused code

* Remove copy and refactor

* Update vjp and reuse steel gemm

* Fixed groups on cpu

* Fix metal validation

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2024-04-27 06:24:57 -07:00
Awni Hannun
86f495985b Add bitwise ops (#1037)
* bitwise ops

* fix tests
2024-04-26 22:03:42 -07:00
Awni Hannun
67d1894759 fix order device -> scheduler (#1039) 2024-04-26 13:46:41 -07:00
Awni Hannun
5bfe89bdb1 Cpp docs (#1036)
* start of C++ docs

* fix stream doc

* only include ops for now
2024-04-26 12:56:05 -07:00
126 changed files with 7393 additions and 3823 deletions

View File

@@ -1,11 +1,11 @@
repos:
- repo: https://github.com/pre-commit/mirrors-clang-format
rev: v18.1.3
rev: v18.1.4
hooks:
- id: clang-format
# Using this mirror lets us use mypyc-compiled black, which is about 2x faster
- repo: https://github.com/psf/black-pre-commit-mirror
rev: 24.3.0
rev: 24.4.2
hooks:
- id: black
- repo: https://github.com/pycqa/isort

View File

@@ -7,7 +7,7 @@ with a short description of your contribution(s) below. For example:
MLX was developed with contributions from the following individuals:
- Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops. Added `LogicalAnd` and `LogicalOR` ops.
- Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops. Added `LogicalAnd` and `LogicalOR` ops. Added `clip_grad_norm` along with `tree_reduce`.
- Juarez Bochi: Fixed bug in cross attention.
- Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example.
- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile`, `StreamContext`, `stream` and safetensor support.

View File

@@ -17,10 +17,12 @@ option(MLX_BUILD_PYTHON_BINDINGS "Build python bindings for mlx" OFF)
option(MLX_BUILD_METAL "Build metal backend" ON)
option(MLX_METAL_DEBUG "Enhance metal debug workflow" OFF)
option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" OFF)
option(MLX_BUILD_GGUF "Include support for GGUF format" ON)
option(MLX_BUILD_SAFETENSORS "Include support for safetensors format" ON)
option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
if(NOT MLX_VERSION)
set(MLX_VERSION 0.12.0)
set(MLX_VERSION 0.13.0)
endif()
# --------------------- Processor tests -------------------------
@@ -94,7 +96,7 @@ elseif (MLX_BUILD_METAL)
FetchContent_Declare(
metal_cpp
URL ${METAL_CPP_URL}
PATCH_COMMAND patch -N -i ${METAL_CPP_PATCH} || true
PATCH_COMMAND /usr/bin/patch -N -i ${METAL_CPP_PATCH} || true
)
FetchContent_MakeAvailable(metal_cpp)

View File

@@ -0,0 +1,123 @@
import argparse
import math
import os
import subprocess
import time
import mlx.core as mx
import numpy as np
import torch
device_name = subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"])
device_name = device_name.decode("utf-8").strip("\n")
N_warmup = 10
N_iter_bench = 100
N_iter_func = 5
def bench(f, a, b):
    """Time repeated calls of ``f(a, b)`` and return the elapsed seconds.

    ``f`` is first called ``N_warmup`` times, then ``torch.mps.synchronize()``
    is invoked before the clock starts. The timed section is ``N_iter_bench``
    further calls; the nanosecond difference is scaled to seconds.
    """
    for _ in range(N_warmup):
        f(a, b)
    torch.mps.synchronize()

    start_ns = time.perf_counter_ns()
    for _ in range(N_iter_bench):
        f(a, b)
    stop_ns = time.perf_counter_ns()

    return (stop_ns - start_ns) * 1e-9
def make_mx_conv_1D(strides=1, padding=0, groups=1):
    """Build a closure that runs ``mx.conv1d`` with the given settings.

    The returned function applies the convolution ``N_iter_func`` times to
    the same operands, passes the collected outputs to ``mx.eval`` and
    returns them as a list.
    """

    def mx_conv_1D(a, b):
        ys = [
            mx.conv1d(a, b, stride=strides, padding=padding, groups=groups)
            for _ in range(N_iter_func)
        ]
        mx.eval(ys)
        return ys

    return mx_conv_1D
def make_pt_conv_1D(strides=1, padding=0, groups=1):
    """Build a closure that runs ``torch.conv1d`` with the given settings.

    The returned function runs under ``torch.no_grad()``, applies the
    convolution ``N_iter_func`` times to the same operands, calls
    ``torch.mps.synchronize()`` and returns the outputs as a list.
    """

    @torch.no_grad()
    def pt_conv_1D(a, b):
        ys = [
            torch.conv1d(a, b, stride=strides, padding=padding, groups=groups)
            for _ in range(N_iter_func)
        ]
        torch.mps.synchronize()
        return ys

    return pt_conv_1D
def bench_shape(N, iH, C, wH, O, strides, padding, np_dtype, groups):
    """Benchmark one grouped conv1d configuration on MLX and PyTorch (MPS).

    Random operands are drawn with NumPy: the input has shape ``(N, iH, C)``
    and the weight has shape ``(O, wH, int(C / groups))``. The PyTorch copies
    have their last two axes swapped before being moved to ``"mps"``.

    After timing both backends, one MLX result is compared against a
    PyTorch CPU result and a message is printed if they differ beyond the
    tolerance.

    Returns:
        ``(time_mlx, time_torch)`` as returned by ``bench``.
    """
    scale = 1.0 / math.sqrt(wH * C)
    a_np = np.random.uniform(0, 0.5, (N, iH, C)).astype(np_dtype)
    b_np = np.random.uniform(-scale, scale, (O, wH, int(C / groups))).astype(np_dtype)

    a_mx, b_mx = mx.array(a_np), mx.array(b_np)

    # Swap the last two axes for the PyTorch operands.
    a_pt = torch.from_numpy(a_np.transpose((0, 2, 1))).to("mps")
    b_pt = torch.from_numpy(b_np.transpose((0, 2, 1))).to("mps")
    torch.mps.synchronize()

    # PyTorch is timed first, then MLX.
    time_torch = bench(make_pt_conv_1D(strides, padding, groups), a_pt, b_pt)
    time_mlx = bench(make_mx_conv_1D(strides, padding, groups), a_mx, b_mx)

    conv_kwargs = {"stride": strides, "padding": padding, "groups": groups}
    out_mx = mx.conv1d(a_mx, b_mx, **conv_kwargs)
    out_pt = torch.conv1d(a_pt.to("cpu"), b_pt.to("cpu"), **conv_kwargs)
    # Swap the reference's last two axes back before comparing with MLX.
    out_pt = torch.permute(out_pt, (0, 2, 1)).numpy(force=True)

    if np_dtype == np.float32:
        atol = 2e-5
    else:
        atol = 1e-4

    outputs_match = np.allclose(out_pt, out_mx, atol=atol)
    if not outputs_match:
        print(
            f"Failed at {(N, iH, C)}, {(O, wH, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
        )

    return time_mlx, time_torch
if __name__ == "__main__":
    # The benchmark takes no options, but parsing makes `--help` work and
    # rejects stray arguments instead of silently ignoring them.
    parser = argparse.ArgumentParser(description="Run conv benchmarks")
    parser.parse_args()

    dtypes = ("float32",)

    # Each row is (N, iH, C, wH, O, strides, padding, groups).
    shapes = (
        (4, 32, 32, 5, 32, 1, 2, 1),
        (4, 32, 32, 5, 32, 1, 2, 2),
        (4, 32, 32, 5, 32, 1, 2, 4),
        (4, 32, 32, 5, 32, 1, 2, 8),
        (4, 32, 32, 5, 32, 1, 2, 16),
        (4, 32, 32, 5, 32, 1, 2, 32),
        (4, 32, 256, 5, 512, 1, 2, 2),
        (4, 32, 256, 5, 512, 1, 2, 128),
        (4, 32, 256, 5, 512, 1, 2, 256),
    )

    for dtype in dtypes:
        # Resolve the NumPy dtype once per dtype rather than once per shape.
        np_dtype = getattr(np, dtype)
        print("(N, iH, C), (O, wH, C), dtype, stride, pads, groups, diff%")
        for N, iH, C, wH, O, strides, padding, groups in shapes:
            time_mlx, time_torch = bench_shape(
                N, iH, C, wH, O, strides, padding, np_dtype, groups
            )
            # Positive diff means MLX was faster than PyTorch.
            diff = time_torch / time_mlx - 1.0

            print(
                f"({N}, {iH:3d}, {C:3d}), ({O:3d}, {wH:2d}, {C:3d}), {dtype}, {strides:5d}, {padding:4d}, {groups:6d}, {100. * diff:+5.2f}%"
            )
            # Flag configurations where MLX takes at least twice as long.
            if time_mlx >= 2.0 * time_torch:
                print("ATTENTION ^^^^^^^")

50
docs/Doxyfile Normal file
View File

@@ -0,0 +1,50 @@
################################################################################
# Primary project setup. #
################################################################################
PROJECT_NAME = "MLX"
OUTPUT_DIRECTORY = build
XML_OUTPUT = xml
HTML_OUTPUT = html
STRIP_FROM_PATH = ../
INPUT = ../mlx
FILE_PATTERNS = *.h
EXCLUDE_PATTERNS = */private/*
CREATE_SUBDIRS = NO
FULL_PATH_NAMES = YES
RECURSIVE = YES
GENERATE_HTML = YES
GENERATE_LATEX = NO
GENERATE_XML = YES
XML_PROGRAMLISTING = YES
################################################################################
# Doxygen preprocessor / parser control. #
################################################################################
ENABLE_PREPROCESSING = YES
MACRO_EXPANSION = YES
EXPAND_ONLY_PREDEF = NO
SKIP_FUNCTION_MACROS = NO
################################################################################
# Compound extraction control. #
################################################################################
EXTRACT_ALL = YES
EXTRACT_PACKAGE = YES
EXTRACT_STATIC = YES
CASE_SENSE_NAMES = NO
################################################################################
# Docstring control / customization. #
################################################################################
JAVADOC_AUTOBRIEF = YES
################################################################################
# Warning suppression. #
################################################################################
QUIET = YES
WARN_IF_UNDOCUMENTED = NO

View File

@@ -2,12 +2,16 @@
### Setup (do once)
Install [sphinx](https://www.sphinx-doc.org/en/master/usage/installation.html)
for example with `conda`:
Install Doxygen:
```
conda install sphinx
pip install sphinx-book-theme
brew install doxygen
```
Install Python packages:
```
pip install -r requirements.txt
```
### Build
@@ -15,7 +19,7 @@ pip install sphinx-book-theme
Build the docs from `mlx/docs/`
```
make html
doxygen && make html
```
View the docs by running a server in `mlx/docs/build/html/`:

3
docs/requirements.txt Normal file
View File

@@ -0,0 +1,3 @@
sphinx
breathe
sphinx-book-theme

View File

@@ -22,6 +22,7 @@ extensions = [
"sphinx.ext.autosummary",
"sphinx.ext.intersphinx",
"sphinx.ext.napoleon",
"breathe",
]
python_use_unqualified_type_names = True
@@ -33,6 +34,9 @@ intersphinx_mapping = {
"numpy": ("https://numpy.org/doc/stable/", None),
}
breathe_projects = {"mlx": "../build/xml"}
breathe_default_project = "mlx"
templates_path = ["_templates"]
html_static_path = ["_static"]
source_suffix = ".rst"

View File

@@ -3,4 +3,5 @@
Operations
==========
.. doxygengroup:: ops
:content-only:

View File

@@ -157,7 +157,10 @@ should point to the path to the built metal library.
- OFF
* - MLX_METAL_DEBUG
- OFF
* - MLX_BUILD_SAFETENSORS
- ON
* - MLX_BUILD_GGUF
- ON
.. note::

View File

@@ -7,8 +7,10 @@ Metal
:toctree: _autosummary
is_available
device_info
get_active_memory
get_peak_memory
reset_peak_memory
get_cache_memory
set_memory_limit
set_cache_limit

View File

@@ -19,6 +19,7 @@ Operations
arcsin
arcsinh
arctan
arctan2
arctanh
argmax
argmin
@@ -28,8 +29,12 @@ Operations
atleast_1d
atleast_2d
atleast_3d
broadcast_to
bitwise_and
bitwise_or
bitwise_xor
block_masked_mm
block_sparse_mm
broadcast_to
ceil
clip
concatenate
@@ -69,6 +74,7 @@ Operations
isnan
isneginf
isposinf
left_shift
less
less_equal
linspace
@@ -105,6 +111,7 @@ Operations
reciprocal
repeat
reshape
right_shift
round
rsqrt
save

View File

@@ -1,5 +1,7 @@
.. _optimizers:
.. currentmodule:: mlx.optimizers
Optimizers
==========
@@ -34,3 +36,8 @@ model's parameters and the **optimizer state**.
optimizers/optimizer
optimizers/common_optimizers
optimizers/schedulers
.. autosummary::
:toctree: _autosummary
clip_grad_norm

View File

@@ -20,3 +20,4 @@ return python trees will be using the default python ``dict``, ``list`` and
tree_unflatten
tree_map
tree_map_with_path
tree_reduce

View File

@@ -33,7 +33,7 @@ array axpby(
class Axpby : public Primitive {
public:
explicit Axpby(Stream stream, float alpha, float beta)
: Primitive(stream), alpha_(alpha), beta_(beta){};
: Primitive(stream), alpha_(alpha), beta_(beta) {};
/**
* A primitive must know how to evaluate itself on the CPU/GPU

View File

@@ -19,7 +19,7 @@ template <typename T>
uint index [[thread_position_in_grid]]) {
auto x_offset = elem_to_loc(index, shape, x_strides, ndim);
auto y_offset = elem_to_loc(index, shape, y_strides, ndim);
out[index] =
out[index] =
static_cast<T>(alpha) * x[x_offset] + static_cast<T>(beta) * y[y_offset];
}
@@ -31,30 +31,30 @@ template <typename T>
constant const float& alpha [[buffer(3)]],
constant const float& beta [[buffer(4)]],
uint index [[thread_position_in_grid]]) {
out[index] =
out[index] =
static_cast<T>(alpha) * x[index] + static_cast<T>(beta) * y[index];
}
#define instantiate_axpby(type_name, type) \
template [[host_name("axpby_general_" #type_name)]] \
[[kernel]] void axpby_general<type>( \
device const type* x [[buffer(0)]], \
device const type* y [[buffer(1)]], \
device type* out [[buffer(2)]], \
constant const float& alpha [[buffer(3)]], \
constant const float& beta [[buffer(4)]], \
constant const int* shape [[buffer(5)]], \
constant const size_t* x_strides [[buffer(6)]], \
constant const size_t* y_strides [[buffer(7)]], \
constant const int& ndim [[buffer(8)]], \
uint index [[thread_position_in_grid]]); \
template [[host_name("axpby_contiguous_" #type_name)]] \
[[kernel]] void axpby_contiguous<type>( \
device const type* x [[buffer(0)]], \
device const type* y [[buffer(1)]], \
device type* out [[buffer(2)]], \
constant const float& alpha [[buffer(3)]], \
constant const float& beta [[buffer(4)]], \
#define instantiate_axpby(type_name, type) \
template [[host_name("axpby_general_" #type_name)]] [[kernel]] void \
axpby_general<type>( \
device const type* x [[buffer(0)]], \
device const type* y [[buffer(1)]], \
device type* out [[buffer(2)]], \
constant const float& alpha [[buffer(3)]], \
constant const float& beta [[buffer(4)]], \
constant const int* shape [[buffer(5)]], \
constant const size_t* x_strides [[buffer(6)]], \
constant const size_t* y_strides [[buffer(7)]], \
constant const int& ndim [[buffer(8)]], \
uint index [[thread_position_in_grid]]); \
template [[host_name("axpby_contiguous_" #type_name)]] [[kernel]] void \
axpby_contiguous<type>( \
device const type* x [[buffer(0)]], \
device const type* y [[buffer(1)]], \
device type* out [[buffer(2)]], \
constant const float& alpha [[buffer(3)]], \
constant const float& beta [[buffer(4)]], \
uint index [[thread_position_in_grid]]);
instantiate_axpby(float32, float);

View File

@@ -14,7 +14,7 @@ class Buffer {
void* ptr_;
public:
Buffer(void* ptr) : ptr_(ptr){};
Buffer(void* ptr) : ptr_(ptr) {};
// Get the raw data pointer from the buffer
void* raw_ptr();

View File

@@ -1,5 +1,4 @@
// Copyright © 2023-2024 Apple Inc.
#include <functional>
#include "mlx/array.h"
@@ -167,6 +166,39 @@ void array::move_shared_buffer(array other) {
move_shared_buffer(other, other.strides(), other.flags(), other.data_size());
}
// Destructor. Besides releasing this handle, it checks whether this array
// and its siblings only reference each other, and if so clears those
// references so the group does not keep itself alive.
array::~array() {
// Nothing to inspect when this handle holds no descriptor.
if (array_desc_ == nullptr) {
return;
}
// Ignore arrays that will be detached
if (status() != array::Status::unscheduled) {
return;
}
// Break circular reference for non-detached arrays with siblings
if (auto n = siblings().size(); n > 0) {
bool do_detach = true;
// If all siblings have siblings.size() references except
// the one we are currently destroying (which has siblings.size() + 1)
// then there are no more external references
do_detach &= (array_desc_.use_count() == (n + 1));
for (auto& s : siblings()) {
do_detach &= (s.array_desc_.use_count() == n);
// Stop at the first sibling whose count does not match.
if (!do_detach) {
break;
}
}
if (do_detach) {
// For every sibling, null out the descriptors held in its own sibling
// list and then empty that list.
for (auto& s : siblings()) {
for (auto& ss : s.siblings()) {
ss.array_desc_ = nullptr;
}
s.array_desc_->siblings.clear();
}
}
}
}
void array::ArrayDesc::init() {
strides.resize(shape.size());
size = 1;

View File

@@ -114,6 +114,15 @@ class array {
return array_desc_->strides;
};
/**
 * Stride of a single dimension.
 *
 * A negative ``dim`` is shifted by ``ndim()`` before the lookup, and the
 * lookup itself goes through ``at()`` so the index is bounds checked. */
size_t strides(int dim) const {
  const auto axis = dim < 0 ? dim + ndim() : dim;
  return strides().at(axis);
};
/** Get the array's data type. */
Dtype dtype() const {
return array_desc_->dtype;
@@ -200,7 +209,7 @@ class array {
allocator::Buffer buffer;
deleter_t d;
Data(allocator::Buffer buffer, deleter_t d = allocator::free)
: buffer(buffer), d(d){};
: buffer(buffer), d(d) {};
// Not copyable
Data(const Data& d) = delete;
Data& operator=(const Data& d) = delete;
@@ -252,22 +261,16 @@ class array {
return array_desc_->siblings;
};
/** The array's siblings. */
std::vector<array>& siblings() {
return array_desc_->siblings;
};
void set_siblings(std::vector<array> siblings, uint16_t position) {
array_desc_->siblings = std::move(siblings);
array_desc_->position = position;
}
/** The i-th output of the array's primitive. */
const array& output(int i) const {
if (i == array_desc_->position) {
return *this;
} else if (i < array_desc_->position) {
return siblings()[i];
} else {
return siblings()[i + 1];
}
};
/** The outputs of the array's primitive (i.e. this array and
* its siblings) in the order the primitive expects. */
std::vector<array> outputs() const {
@@ -377,6 +380,8 @@ class array {
array_desc_ = other.array_desc_;
}
~array();
private:
// Initialize the array's data
template <typename It>

View File

@@ -32,6 +32,7 @@ DEFAULT(ArgReduce)
DEFAULT(ArgSort)
DEFAULT(AsStrided)
DEFAULT(BlockMaskedMM)
DEFAULT(BlockSparseMM)
DEFAULT(Broadcast)
DEFAULT(Ceil)
DEFAULT(Concatenate)
@@ -192,6 +193,26 @@ void ArcTan::eval_cpu(const std::vector<array>& inputs, array& out) {
}
}
// Evaluate arctan2 on the CPU. When the output is float32 and both inputs
// are row contiguous the work is handed to vvatan2f; every other case is
// forwarded to the generic eval().
void ArcTan2::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  auto& a = inputs[0];
  auto& b = inputs[1];

  const bool use_vforce = out.dtype() == float32 &&
      a.flags().row_contiguous && b.flags().row_contiguous;
  if (!use_vforce) {
    eval(inputs, out);
    return;
  }

  // Reuse an input buffer for the output when one is donatable (a is
  // preferred over b); otherwise allocate fresh storage.
  if (a.is_donatable()) {
    out.copy_shared_buffer(a);
  } else if (b.is_donatable()) {
    out.copy_shared_buffer(b);
  } else {
    out.set_data(allocator::malloc_or_wait(out.nbytes()));
  }

  int count = a.data_size();
  vvatan2f(out.data<float>(), a.data<float>(), b.data<float>(), &count);
}
void ArcTanh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];

View File

@@ -236,4 +236,82 @@ void Subtract::eval(const std::vector<array>& inputs, array& out) {
binary(a, b, out, detail::Subtract());
}
// Evaluate a bitwise binary primitive (and, or, xor, left shift, right
// shift) on the CPU.
//
// The operation is selected by op_ and the element type by the dtype of
// `out`. Boolean and integer dtypes are handled; any other dtype throws
// std::runtime_error.
void BitwiseBinary::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  auto& a = inputs[0];
  auto& b = inputs[1];

  // Instantiate binary_op for the element type matching out.dtype().
  auto dispatch_type = [&a, &b, &out](auto op) {
    switch (out.dtype()) {
      case bool_:
        binary_op<bool>(a, b, out, op);
        // This break was missing: the bool case fell through and ran the
        // op a second time with uint8_t elements.
        break;
      case uint8:
        binary_op<uint8_t>(a, b, out, op);
        break;
      case uint16:
        binary_op<uint16_t>(a, b, out, op);
        break;
      case uint32:
        binary_op<uint32_t>(a, b, out, op);
        break;
      case uint64:
        binary_op<uint64_t>(a, b, out, op);
        break;
      case int8:
        binary_op<int8_t>(a, b, out, op);
        break;
      case int16:
        binary_op<int16_t>(a, b, out, op);
        break;
      case int32:
        binary_op<int32_t>(a, b, out, op);
        break;
      case int64:
        binary_op<int64_t>(a, b, out, op);
        break;
      default:
        throw std::runtime_error(
            "[BitwiseBinary::eval_cpu] Type not supported");
    }
  };

  // Pick the functor for the requested operation.
  switch (op_) {
    case BitwiseBinary::And:
      dispatch_type(detail::BitwiseAnd());
      break;
    case BitwiseBinary::Or:
      dispatch_type(detail::BitwiseOr());
      break;
    case BitwiseBinary::Xor:
      dispatch_type(detail::BitwiseXor());
      break;
    case BitwiseBinary::LeftShift:
      dispatch_type(detail::LeftShift());
      break;
    case BitwiseBinary::RightShift:
      dispatch_type(detail::RightShift());
      break;
  }
}
// Generic CPU evaluation of arctan2. Dispatches on the output dtype:
// float32, float16 and bfloat16 are computed with binary_op; any other
// dtype throws std::invalid_argument, with a different message for inexact
// and non-inexact types.
void ArcTan2::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  const auto& a = inputs[0];
  const auto& b = inputs[1];

  if (out.dtype() == float32) {
    binary_op<float>(a, b, out, detail::ArcTan2());
    return;
  }
  if (out.dtype() == float16) {
    binary_op<float16_t>(a, b, out, detail::ArcTan2());
    return;
  }
  if (out.dtype() == bfloat16) {
    binary_op<bfloat16_t>(a, b, out, detail::ArcTan2());
    return;
  }

  // Unsupported dtype: report which kind it was.
  if (issubdtype(out.dtype(), inexact)) {
    std::ostringstream err;
    err << "[arctan2] Does not support " << out.dtype();
    throw std::invalid_argument(err.str());
  }
  throw std::invalid_argument(
      "[arctan2] Cannot compute inverse tangent for arrays"
      " with non floating point type.");
}
} // namespace mlx::core

View File

@@ -38,11 +38,15 @@ void slow_conv_1D(
const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
const int iH = 1 + in_dilation[0] * (in.shape(1) - 1); // Input spatial dim
const int C = in.shape(2); // Input channels
const int oH = out.shape(1); // Output spatial dim
const int O = wt.shape(0); // Out channels
const int C = wt.shape(2); // In channels
const int wH = wt.shape(1); // Weight spatial dim
const int groups = C / wt.shape(2);
const int C_per_group = wt.shape(2);
const int O_per_group = O / groups;
const size_t in_stride_N = in.strides()[0];
const size_t in_stride_H = in.strides()[1];
const size_t in_stride_C = in.strides()[2];
@@ -57,35 +61,36 @@ void slow_conv_1D(
for (int n = 0; n < N; ++n) {
for (int oh = 0; oh < oH; ++oh) {
for (int o = 0; o < O; ++o) {
const T* filter_wt_ptr = start_wt_ptr + o * wt_stride_O;
float r = 0.;
for (int g = 0; g < groups; ++g) {
for (int o = g * O_per_group; o < (g + 1) * O_per_group; ++o) {
const T* filter_wt_ptr = start_wt_ptr + o * wt_stride_O;
float r = 0.;
for (int wh = 0; wh < wH; ++wh) {
const T* wt_ptr = filter_wt_ptr + wh * wt_stride_H;
for (int wh = 0; wh < wH; ++wh) {
const T* wt_ptr = filter_wt_ptr + wh * wt_stride_H;
int wh_flip = flip ? (wH - wh - 1) : wh;
int ih = oh * wt_strides[0] - padding[0] + wh_flip * wt_dilation[0];
int wh_flip = flip ? (wH - wh - 1) : wh;
int ih = oh * wt_strides[0] - padding[0] + wh_flip * wt_dilation[0];
auto ih_div = std::div(ih, in_dilation[0]);
auto ih_div = std::div(ih, in_dilation[0]);
if (ih >= 0 && ih < iH && ih_div.rem == 0) {
for (int c = 0; c < C; ++c) {
r += static_cast<float>(
in_ptr[ih_div.quot * in_stride_H + c * in_stride_C]) *
static_cast<float>(wt_ptr[c * wt_stride_C]);
} // c
if (ih >= 0 && ih < iH && ih_div.rem == 0) {
for (int c = g * C_per_group; c < (g + 1) * C_per_group; ++c) {
r += static_cast<float>(
in_ptr[ih_div.quot * in_stride_H + c * in_stride_C]) *
static_cast<float>(wt_ptr[(c % C_per_group) * wt_stride_C]);
} // c
} // ih check
} // wh
} // ih check
} // wh
out_ptr[oh * out_stride_H + o * out_stride_O] = static_cast<T>(r);
} // o
out_ptr[oh * out_stride_H + o * out_stride_O] = static_cast<T>(r);
} // o
} // g
} // oh
in_ptr += in_stride_N;
out_ptr += out_stride_N;
} // n
}
@@ -366,11 +371,15 @@ void explicit_gemm_conv_1D_cpu(
const std::vector<int>& wt_dilation) {
const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
const int iH = in.shape(1); // Input spatial dim
const int C = in.shape(2); // Input channels
const int oH = out.shape(1); // Output spatial dim
const int O = wt.shape(0); // Out channels
const int C = wt.shape(2); // In channels
const int wH = wt.shape(1); // Weight spatial dim
const int groups = C / wt.shape(2);
const int C_per_group = wt.shape(2);
const int O_per_group = O / groups;
auto conv_dtype = float32;
// Pad input
@@ -402,6 +411,11 @@ void explicit_gemm_conv_1D_cpu(
in_padded.strides()[1],
in_padded.strides()[2]};
auto flags = in_padded.flags();
if (groups > 1) {
// Transpose the last two dimensions for grouped convolutions
std::swap(strided_shape[2], strided_shape[3]);
std::swap(strided_strides[2], strided_strides[3]);
}
array in_strided_view(strided_shape, in_padded.dtype(), nullptr, {});
in_strided_view.copy_shared_buffer(
@@ -416,7 +430,19 @@ void explicit_gemm_conv_1D_cpu(
auto gemm_wt = wt;
auto gemm_out = out;
if (wt.dtype() != float32 || !wt.flags().row_contiguous) {
if (groups > 1) {
// Transpose the last two dimensions for grouped convolutions
array wt_transpose(
{wt.shape(0), wt.shape(2), wt.shape(1)}, wt.dtype(), nullptr, {});
wt_transpose.copy_shared_buffer(
wt,
{wt.strides(0), wt.strides(2), wt.strides(1)},
wt.flags(),
wt.size(),
0);
gemm_wt = array(wt_transpose.shape(), float32, nullptr, {});
copy(wt_transpose, gemm_wt, CopyType::General);
} else if (wt.dtype() != float32 || !wt.flags().row_contiguous) {
auto ctype =
wt.flags().row_contiguous ? CopyType::Vector : CopyType::General;
gemm_wt = array(wt.shape(), float32, nullptr, {});
@@ -428,27 +454,29 @@ void explicit_gemm_conv_1D_cpu(
gemm_out.set_data(allocator::malloc_or_wait(gemm_out.nbytes()));
}
// Perform gemm
cblas_sgemm(
CblasRowMajor,
CblasNoTrans, // no trans A
CblasTrans, // transB
strided_reshape[0], // M
O, // N
strided_reshape[1], // K
1.0f, // alpha
in_strided.data<float>(),
strided_reshape[1], // lda
gemm_wt.data<float>(),
strided_reshape[1], // ldb
0.0f, // beta
gemm_out.data<float>(),
O // ldc
);
for (int g = 0; g < groups; ++g) {
// Perform gemm
cblas_sgemm(
CblasRowMajor,
CblasNoTrans, // no trans A
CblasTrans, // transB
strided_reshape[0], // M
O_per_group, // N
C_per_group * wH, // K
1.0f, // alpha
in_strided.data<float>() + g * C_per_group * wH, // A
wH * C, // lda
gemm_wt.data<float>() + g * O_per_group * C_per_group * wH, // B
wH * C_per_group, // ldb
0.0f, // beta
gemm_out.data<float>() + g * O_per_group, // C
O // ldc
);
// Copy results if needed
if (out.dtype() != float32) {
copy(gemm_out, out, CopyType::Vector);
// Copy results if needed
if (out.dtype() != float32) {
copy(gemm_out, out, CopyType::Vector);
}
}
}

View File

@@ -34,6 +34,7 @@ DEFAULT(ArcCosh)
DEFAULT(ArcSin)
DEFAULT(ArcSinh)
DEFAULT(ArcTan)
DEFAULT(ArcTan2)
DEFAULT(ArcTanh)
DEFAULT(ArgPartition)
DEFAULT(ArgReduce)
@@ -42,6 +43,7 @@ DEFAULT(AsType)
DEFAULT(AsStrided)
DEFAULT(Broadcast)
DEFAULT(BlockMaskedMM)
DEFAULT(BlockSparseMM)
DEFAULT_MULTI(DivMod)
DEFAULT(Ceil)
DEFAULT(Concatenate)

View File

@@ -11,7 +11,7 @@ GCC=$2
SRCDIR=$3
CLANG=$4
if [ $CLANG = "TRUE" ]; then
if [ "$CLANG" = "TRUE" ]; then
read -r -d '' INCLUDES <<- EOM
#include <cmath>
#include <complex>

View File

@@ -190,4 +190,91 @@ void BlockMaskedMM::eval(const std::vector<array>& inputs, array& out) {
}
}
// CPU evaluation of BlockSparseMM.
//
// inputs[0] and inputs[1] are the two matrix operands; inputs[2] and
// inputs[3] are read as uint32 index arrays that select, for each output
// matrix, which matrix of each operand to multiply. Only a float32 output
// is supported; any other output dtype throws std::runtime_error.
void BlockSparseMM::eval(const std::vector<array>& inputs, array& out) {
if (out.dtype() != float32) {
throw std::runtime_error(
"[BlockSparseMM::eval] Currently only supports float32.");
}
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto& a_pre = inputs[0];
auto& b_pre = inputs[1];
// Inspect the strides of the last two dimensions and return
// (is_transposed, leading dimension, array to use). When the layout is
// neither of the two recognized ones, a copy with default strides is made.
auto check_transpose = [](const array& arr) {
auto stx = arr.strides()[arr.ndim() - 2];
auto sty = arr.strides()[arr.ndim() - 1];
if (stx == arr.shape(-1) && sty == 1) {
return std::make_tuple(false, stx, arr);
} else if (stx == 1 && sty == arr.shape(-2)) {
return std::make_tuple(true, sty, arr);
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::General);
size_t stx = arr.shape(-1);
return std::make_tuple(false, stx, arr_copy);
}
};
auto [a_transposed, lda, a] = check_transpose(a_pre);
auto [b_transposed, ldb, b] = check_transpose(b_pre);
// Matrix dimensions of each product: (M x K) times (K x N).
size_t M = a.shape(-2);
size_t N = b.shape(-1);
size_t K = a.shape(-1);
// Nothing to compute when either output dimension is zero.
if (M == 0 || N == 0) {
return;
}
// With K == 0 the output is filled with zero bytes and no gemm is run.
if (K == 0) {
std::memset(static_cast<void*>(out.data<float>()), 0, out.nbytes());
return;
}
// Get batch dims
auto batch_size_out = out.size() / (M * N);
size_t matrix_stride_out = M * N;
// Drop the last two entries (the matrix dimensions) of a shape or
// strides vector.
auto get_batch_dims = [](const auto& v) {
return decltype(v){v.begin(), v.end() - 2};
};
auto& lhs_indices = inputs[2];
auto& rhs_indices = inputs[3];
std::vector<int> batch_shape = get_batch_dims(out.shape());
int batch_ndim = batch_shape.size();
std::vector<int> batch_shape_A = get_batch_dims(a.shape());
std::vector<size_t> batch_strides_A = get_batch_dims(a.strides());
std::vector<int> batch_shape_B = get_batch_dims(b.shape());
std::vector<size_t> batch_strides_B = get_batch_dims(b.strides());
const uint32_t* lhs_indices_ptr = lhs_indices.data<uint32_t>();
const uint32_t* rhs_indices_ptr = rhs_indices.data<uint32_t>();
// One gemm per output matrix; output matrices are written back to back,
// matrix_stride_out elements apart.
for (int i = 0; i < batch_size_out; i++) {
// Get index
uint32_t indx_A = lhs_indices_ptr[elem_to_loc(i, lhs_indices)];
uint32_t indx_B = rhs_indices_ptr[elem_to_loc(i, rhs_indices)];
cblas_sgemm(
CblasRowMajor,
a_transposed ? CblasTrans : CblasNoTrans, // transA
b_transposed ? CblasTrans : CblasNoTrans, // transB
M,
N,
K,
1.0f, // alpha
a.data<float>() + elem_to_loc(indx_A, batch_shape_A, batch_strides_A),
lda,
b.data<float>() + elem_to_loc(indx_B, batch_shape_B, batch_strides_B),
ldb,
0.0f, // beta
out.data<float>() + matrix_stride_out * i,
out.shape(-1) // ldc
);
}
}
} // namespace mlx::core

View File

@@ -161,6 +161,13 @@ struct ArcTan {
};
};
/** Binary functor returning std::atan2(y, x), converted to T. */
struct ArcTan2 {
  template <typename T>
  T operator()(T y, T x) {
    T angle = std::atan2(y, x);
    return angle;
  }
};
struct ArcTanh {
template <typename T>
T operator()(T x) {
@@ -606,4 +613,39 @@ struct Select {
}
};
/** Binary functor returning x & y, converted to T. */
struct BitwiseAnd {
  template <typename T>
  T operator()(T x, T y) {
    T result = x & y;
    return result;
  }
};
/** Binary functor returning x | y, converted to T. */
struct BitwiseOr {
  template <typename T>
  T operator()(T x, T y) {
    T result = x | y;
    return result;
  }
};
/** Binary functor returning x ^ y, converted to T. */
struct BitwiseXor {
  template <typename T>
  T operator()(T x, T y) {
    T result = x ^ y;
    return result;
  }
};
/** Binary functor returning x << y, converted to T. */
struct LeftShift {
  template <typename T>
  T operator()(T x, T y) {
    T result = x << y;
    return result;
  }
};
/** Binary functor returning x >> y, converted to T. */
struct RightShift {
  template <typename T>
  T operator()(T x, T y) {
    T result = x >> y;
    return result;
  }
};
} // namespace mlx::core::detail

View File

@@ -140,10 +140,15 @@ void BufferCache::remove_from_list(BufferCache::BufferHolder* to_remove) {
MetalAllocator::MetalAllocator()
: device_(device(mlx::core::Device::gpu).mtl_device()),
buffer_cache_(device_),
block_limit_(1.5 * device_->recommendedMaxWorkingSetSize()),
gc_limit_(0.95 * device_->recommendedMaxWorkingSetSize()),
max_pool_size_(block_limit_) {}
buffer_cache_(device_) {
auto memsize = std::get<size_t>(device_info()["memory_size"]);
block_limit_ =
std::min(1.5 * device_->recommendedMaxWorkingSetSize(), 0.95 * memsize);
gc_limit_ = std::min(
static_cast<size_t>(0.95 * device_->recommendedMaxWorkingSetSize()),
block_limit_);
max_pool_size_ = block_limit_;
}
size_t MetalAllocator::set_cache_limit(size_t limit) {
std::swap(limit, max_pool_size_);
@@ -165,6 +170,15 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
return Buffer{nullptr};
}
// More helpful message if maximum buffer length is exceeded
if (size > device_->maxBufferLength()) {
std::ostringstream msg;
msg << "Attempting to allocate " << size << " bytes which is greater than"
<< " the maximum allowed buffer size of " << device_->maxBufferLength()
<< " bytes.";
throw std::runtime_error(msg.str());
}
// Align up memory
if (size > vm_page_size) {
size = vm_page_size * ((size + vm_page_size - 1) / vm_page_size);
@@ -244,6 +258,9 @@ size_t get_active_memory() {
// Returns the value reported by the allocator's get_peak_memory().
size_t get_peak_memory() {
  const size_t peak = allocator().get_peak_memory();
  return peak;
}
// Forwards to the allocator's reset_peak_memory().
void reset_peak_memory() {
  auto&& alloc = allocator();
  alloc.reset_peak_memory();
}
// Returns the value reported by the allocator's get_cache_memory().
size_t get_cache_memory() {
  const size_t cached = allocator().get_cache_memory();
  return cached;
}

View File

@@ -62,6 +62,10 @@ class MetalAllocator : public allocator::Allocator {
// Returns the current value of peak_memory_.
// NOTE(review): read without taking mutex_, unlike reset_peak_memory() —
// confirm this is intentional.
size_t get_peak_memory() {
  return peak_memory_;
}
// Sets peak_memory_ back to zero while holding mutex_.
void reset_peak_memory() {
  std::lock_guard guard(mutex_);
  peak_memory_ = 0;
}
// Returns the size reported by the buffer cache.
size_t get_cache_memory() {
  return buffer_cache_.cache_size();
}

View File

@@ -336,7 +336,7 @@ void Compiled::eval_gpu(
MTL::Size grid_dims(nthreads, 1, 1);
MTL::Size group_dims(
std::min(nthreads, kernel->maxTotalThreadsPerThreadgroup()), 1, 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
} else {
size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
@@ -347,7 +347,7 @@ void Compiled::eval_gpu(
}
auto group_dims = get_block_dims(dim0, dim1, rest);
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
compute_encoder->dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
}

View File

@@ -59,7 +59,7 @@ void explicit_gemm_conv_ND_gpu(
MTL::Size grid_dims = MTL::Size(
conv_params.C, unfolded_shape[1] / conv_params.C, unfolded_shape[0]);
compute_encoder->dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
// Reshape weight
std::vector<int> wt_reshape{implicit_K, implicit_N};
@@ -89,6 +89,90 @@ void explicit_gemm_conv_ND_gpu(
/*copies = */ copies);
}
// Grouped N-dimensional convolution implemented as "unfold + grouped matmul".
//
// Steps, in the order they are encoded:
//   1. Unfold `in` into a 2-D scratch array with the
//      "naive_unfold_transpose_nd_*" kernel.
//   2. Build a strided view over `wt` and copy it into a new array
//      (`wt_transpose`).
//   3. Hand both to steel_matmul_conv_groups, which writes into `out`.
//
// Parameters:
//   s           - stream whose command encoder receives the work.
//   d           - Metal device used to look up the kernel and encoder.
//   in          - convolution input.
//   wt          - convolution weights.
//   out         - destination array (taken by value, as in the original
//                 signature).
//   conv_params - convolution description; `groups`, `C`, `O` and `wS` are
//                 read here, and the whole struct is passed to the unfold
//                 kernel as bytes.
template <int N>
void explicit_gemm_conv_group_ND_gpu(
const Stream& s,
metal::Device& d,
const array& in,
const array& wt,
array out,
const MLXConvParams<N>& conv_params) {
// Per-group channel counts (integer division; assumes C and O are divisible
// by groups — TODO confirm this is validated by the caller).
const int groups = conv_params.groups;
const int C_per_group = conv_params.C / conv_params.groups;
const int O_per_group = conv_params.O / conv_params.groups;
// Get gemm shapes: M is the element count of `out` divided by O, K is the
// element count of `wt` divided by O, and N is the output channels per group.
const int implicit_M = out.size() / conv_params.O;
const int implicit_K = wt.size() / conv_params.O;
const int implicit_N = O_per_group;
// Product of the N kernel extents in conv_params.wS.
int kernel_size = 1;
for (int i = 0; i < N; ++i) {
kernel_size *= conv_params.wS[i];
}
// Prepare unfolding array: one row per gemm row, with the K columns of every
// group stored side by side (implicit_K * groups columns in total).
std::vector<int> unfolded_shape{implicit_M, implicit_K * groups};
array in_unfolded(unfolded_shape, in.dtype(), nullptr, {});
in_unfolded.set_data(allocator::malloc_or_wait(in_unfolded.nbytes()));
// Prepare unfolding kernel; the kernel name encodes the element type and N.
std::ostringstream kname;
kname << "naive_unfold_transpose_nd_" << type_to_name(in_unfolded) << "_"
<< N;
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(in_unfolded, 1);
compute_encoder->setBytes(&conv_params, sizeof(conv_params), 2);
// Launch unfolding kernel.
// Threadgroup x is min(C, 64) rounded up to a multiple of 32; y is chosen so
// that x * y == 256.
int tgp_x = std::min(conv_params.C, 64);
tgp_x = 32 * ((tgp_x + 32 - 1) / 32);
int tgp_y = 256 / tgp_x;
MTL::Size group_dims = MTL::Size(tgp_x, tgp_y, 1);
MTL::Size grid_dims = MTL::Size(
conv_params.C, unfolded_shape[1] / conv_params.C, unfolded_shape[0]);
compute_encoder.dispatchThreads(grid_dims, group_dims);
// Transpose kernel weights so that we can slice them by contiguous chunks
// of channel groups.
// wt_view aliases wt's buffer with shape (wt.shape(0), C_per_group,
// kernel_size) and strides (wt.strides(0), 1, C_per_group); no data is copied
// at this point.
array wt_view(
{wt.shape(0), C_per_group, kernel_size}, wt.dtype(), nullptr, {});
wt_view.copy_shared_buffer(
wt,
{wt.strides(0), 1, static_cast<size_t>(C_per_group)},
wt.flags(),
wt.size());
// Materialize the strided view into a separate array via a general copy.
auto wt_transpose = array(wt_view.shape(), wt_view.dtype(), nullptr, {});
copy_gpu(wt_view, wt_transpose, CopyType::General, s);
// Perform gemm.
// The temporaries are passed along as `copies` — presumably so they stay
// alive until the GPU work completes; confirm in steel_matmul_conv_groups.
std::vector<array> copies = {in_unfolded, wt_view, wt_transpose};
return steel_matmul_conv_groups(
s,
d,
/*a = */ in_unfolded,
/*b = */ wt_transpose,
/*c = */ out,
/*M = */ implicit_M,
/*N = */ implicit_N,
/*K = */ implicit_K,
/*a_cols = */ implicit_K * groups,
/*b_cols = */ implicit_K,
/*out_cols = */ implicit_N * groups,
/*a_transposed = */ false,
/*b_transposed = */ true,
/* groups = */ groups,
/*copies = */ copies);
}
void conv_1D_gpu(
const Stream& s,
metal::Device& d,
@@ -99,6 +183,7 @@ void conv_1D_gpu(
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation,
int groups,
bool flip) {
// Make conv params
MLXConvParams<1> conv_params{
@@ -118,11 +203,15 @@ void conv_1D_gpu(
{wt.strides()[0], wt.strides()[1], wt.strides()[2]},
/* const size_t out_strides[NDIM + 2] = */
{out.strides()[0], out.strides()[1], out.strides()[2]},
/* const int groups = */ 1,
/* const int groups = */ groups,
/* const bool flip = */ flip};
// Direct to explicit gemm conv
return explicit_gemm_conv_ND_gpu(s, d, in, wt, out, conv_params);
if (groups > 1) {
return explicit_gemm_conv_group_ND_gpu(s, d, in, wt, out, conv_params);
} else {
return explicit_gemm_conv_ND_gpu(s, d, in, wt, out, conv_params);
}
}
void slow_conv_2D_gpu(
@@ -158,7 +247,7 @@ void slow_conv_2D_gpu(
compute_encoder.set_output_array(out, 2);
compute_encoder->setBytes(&conv_params, sizeof(MLXConvParams<2>), 3);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
}
void implicit_gemm_conv_2D_gpu(
@@ -263,7 +352,7 @@ void implicit_gemm_conv_2D_gpu(
compute_encoder->setBytes(&gemm_params, sizeof(ImplicitGemmConv2DParams), 4);
// Launch kernel
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
}
void implicit_gemm_conv_2D_general_gpu(
@@ -423,7 +512,7 @@ void implicit_gemm_conv_2D_general_gpu(
base_w.data(), sizeof(Conv2DGeneralBaseInfo) * base_w.size(), 7);
// Launch kernel
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
}
void winograd_conv_2D_gpu(
@@ -524,7 +613,7 @@ void winograd_conv_2D_gpu(
MTL::Size group_dims = MTL::Size(32, bo, 1);
MTL::Size grid_dims = MTL::Size(O_c / bo, 1, 1);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
}
// Do input transform
@@ -552,7 +641,7 @@ void winograd_conv_2D_gpu(
MTL::Size group_dims = MTL::Size(32, wn, wm);
MTL::Size grid_dims = MTL::Size(N_tiles_w, N_tiles_h, N_tiles_n);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
}
// Do batched gemm
@@ -600,7 +689,7 @@ void winograd_conv_2D_gpu(
MTL::Size group_dims = MTL::Size(32, wn, wm);
MTL::Size grid_dims = MTL::Size(N_tiles_w, N_tiles_h, N_tiles_n);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
}
}
@@ -721,6 +810,7 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
kernel_strides_,
kernel_dilation_,
input_dilation_,
groups_,
flip_);
}
// Throw error

View File

@@ -126,7 +126,7 @@ void copy_gpu_inplace(
auto group_dims = get_block_dims(dim0, dim1, rest);
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
compute_encoder->dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
} else {
size_t nthreads = out.data_size();
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
@@ -135,7 +135,7 @@ void copy_gpu_inplace(
thread_group_size = nthreads;
}
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
}

View File

@@ -5,6 +5,8 @@
#include <filesystem>
#include <sstream>
#include <sys/sysctl.h>
#define NS_PRIVATE_IMPLEMENTATION
#define CA_PRIVATE_IMPLEMENTATION
#define MTL_PRIVATE_IMPLEMENTATION
@@ -23,6 +25,7 @@ namespace {
// TODO nicer way to set this or possibly expose as an environment variable
constexpr int MAX_BUFFERS_PER_QUEUE = 12;
constexpr int MAX_DISPATCHES_PER_ENCODER = 2;
constexpr const char* default_mtllib_path = METAL_PATH;
@@ -35,7 +38,6 @@ auto load_device() {
}
return device;
}
std::pair<MTL::Library*, NS::Error*> load_library_from_path(
MTL::Device* device,
const char* path) {
@@ -114,6 +116,33 @@ MTL::Library* load_library(
} // namespace
// Counts the dispatch, encodes a threadgroup-based dispatch on the current
// encoder, and then lets maybe_split() decide whether to start a new encoder.
void CommandEncoder::dispatchThreadgroups(
    MTL::Size grid_dims,
    MTL::Size group_dims) {
  ++num_dispatches;
  enc->dispatchThreadgroups(grid_dims, group_dims);
  maybe_split();
}
// Counts the dispatch, encodes a thread-based dispatch on the current
// encoder, and then lets maybe_split() decide whether to start a new encoder.
void CommandEncoder::dispatchThreads(
    MTL::Size grid_dims,
    MTL::Size group_dims) {
  ++num_dispatches;
  enc->dispatchThreads(grid_dims, group_dims);
  maybe_split();
}
// Replaces the current compute encoder with a fresh one once more than
// MAX_DISPATCHES_PER_ENCODER dispatches have been recorded on it. Does
// nothing while `concurrent` is set.
void CommandEncoder::maybe_split() {
  // Keep the current encoder while in a concurrent context or under the cap.
  if (concurrent || num_dispatches <= MAX_DISPATCHES_PER_ENCODER) {
    return;
  }

  // Finish and drop the current encoder.
  enc->endEncoding();
  enc->release();

  // Reset the per-encoder bookkeeping.
  num_dispatches = 0;
  outputs.clear();

  // Open a replacement encoder on the same command buffer and retain it.
  enc = cbuf->computeCommandEncoder(MTL::DispatchTypeConcurrent);
  enc->retain();
}
Device::Device() {
auto pool = new_scoped_memory_pool();
device_ = load_device();
@@ -128,9 +157,6 @@ Device::~Device() {
for (auto& b : buffer_map_) {
b.second.second->release();
}
for (auto& e : encoder_map_) {
(*e.second)->release();
}
for (auto& k : kernel_map_) {
k.second->release();
}
@@ -167,27 +193,26 @@ void Device::increment_command_buffer_ops(int index) {
MTL::CommandBuffer* Device::get_command_buffer(int index) {
auto bit = buffer_map_.find(index);
return (bit == buffer_map_.end()) ? nullptr : bit->second.second;
}
if (bit == buffer_map_.end()) {
auto qit = queue_map_.find(index);
if (qit == queue_map_.end()) {
throw std::runtime_error(
"[metal::Device] Attempting to get command buffer for invalid queue.");
}
MTL::CommandBuffer* Device::new_command_buffer(int index) {
auto qit = queue_map_.find(index);
if (qit == queue_map_.end()) {
throw std::runtime_error(
"[metal::Device] Attempting to get command buffer for invalid queue.");
auto cb = qit->second->commandBufferWithUnretainedReferences();
if (!cb) {
throw std::runtime_error(
"[metal::Device] Unable to create new command buffer");
}
// Increment ref count so the buffer is not garbage collected
cb->retain();
bit = buffer_map_.insert({index, {0, cb}}).first;
}
auto cb = qit->second->commandBufferWithUnretainedReferences();
if (!cb) {
throw std::runtime_error(
"[metal::Device] Unable to create new command buffer");
}
// Increment ref count so the buffer is not garbage collected
cb->retain();
return buffer_map_.insert({index, {0, cb}}).first->second.second;
return bit->second.second;
}
void Device::commit_command_buffer(int index) {
@@ -198,25 +223,15 @@ void Device::commit_command_buffer(int index) {
}
void Device::end_encoding(int index) {
auto eit = encoder_map_.find(index);
if (eit != encoder_map_.end()) {
(*eit->second)->endEncoding();
(*eit->second)->release();
encoder_map_.erase(eit);
}
encoder_map_.erase(index);
}
CommandEncoder& Device::get_command_encoder(int index) {
auto eit = encoder_map_.find(index);
if (eit == encoder_map_.end()) {
auto cb = get_command_buffer(index);
auto compute_encoder =
cb->computeCommandEncoder(MTL::DispatchTypeConcurrent);
// Increment ref count so the buffer is not garbage collected
compute_encoder->retain();
eit = encoder_map_
.emplace(index, std::make_unique<CommandEncoder>(compute_encoder))
.first;
eit =
encoder_map_.emplace(index, std::make_unique<CommandEncoder>(cb)).first;
}
return *(eit->second);
}
@@ -265,8 +280,7 @@ MTL::Library* Device::get_library_(const std::string& source_string) {
// Throw error if unable to compile library
if (!mtl_lib) {
std::ostringstream msg;
msg << "[metal::Device] Unable to load build metal library from source"
<< "\n";
msg << "[metal::Device] Unable to build metal library from source" << "\n";
if (error) {
msg << error->localizedDescription()->utf8String() << "\n";
}
@@ -285,8 +299,7 @@ MTL::Library* Device::get_library_(const MTL::StitchedLibraryDescriptor* desc) {
// Throw error if unable to compile library
if (!mtl_lib) {
std::ostringstream msg;
msg << "[metal::Device] Unable to load build stitched metal library"
<< "\n";
msg << "[metal::Device] Unable to build stitched metal library" << "\n";
if (error) {
msg << error->localizedDescription()->utf8String() << "\n";
}
@@ -558,4 +571,23 @@ void new_stream(Stream stream) {
}
}
std::unordered_map<std::string, std::variant<std::string, size_t>>
device_info() {
auto raw_device = device(default_device()).mtl_device();
auto arch = std::string(raw_device->architecture()->name()->utf8String());
int mib[] = {CTL_HW, HW_MEMSIZE};
size_t memsize = 0;
size_t length = sizeof(memsize);
sysctl(mib, 2, &memsize, &length, NULL, 0);
return {
{"architecture", arch},
{"max_buffer_length", raw_device->maxBufferLength()},
{"max_recommended_working_set_size",
raw_device->recommendedMaxWorkingSetSize()},
{"memory_size", memsize}};
}
} // namespace mlx::core::metal

View File

@@ -37,8 +37,10 @@ using MTLFCList =
std::vector<std::tuple<const void*, MTL::DataType, NS::UInteger>>;
struct CommandEncoder {
CommandEncoder(MTL::ComputeCommandEncoder* enc)
: enc(enc), concurrent(false){};
CommandEncoder(MTL::CommandBuffer* cbuf) : cbuf(cbuf) {
enc = cbuf->computeCommandEncoder(MTL::DispatchTypeConcurrent);
enc->retain();
};
CommandEncoder(const CommandEncoder&) = delete;
CommandEncoder& operator=(const CommandEncoder&) = delete;
@@ -89,13 +91,25 @@ struct CommandEncoder {
}
}
void dispatchThreadgroups(MTL::Size grid_dims, MTL::Size group_dims);
void dispatchThreads(MTL::Size grid_dims, MTL::Size group_dims);
ConcurrentContext start_concurrent() {
return ConcurrentContext(*this);
}
~CommandEncoder() {
enc->endEncoding();
enc->release();
}
private:
void maybe_split();
int num_dispatches{0};
MTL::CommandBuffer* cbuf;
MTL::ComputeCommandEncoder* enc;
bool concurrent;
bool concurrent{false};
std::unordered_set<MTL::Resource*> outputs;
std::unordered_set<MTL::Resource*> concurrent_outputs;
};
@@ -112,7 +126,6 @@ class Device {
};
void new_queue(int index);
MTL::CommandBuffer* new_command_buffer(int index);
MTL::CommandBuffer* get_command_buffer(int index);
int get_command_buffer_ops(int index);
void increment_command_buffer_ops(int index);

View File

@@ -97,7 +97,7 @@ void FFT::eval_gpu(const std::vector<array>& inputs, array& out) {
auto group_dims = MTL::Size(1, m, 1);
auto grid_dims = MTL::Size(batch, m, 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
d.get_command_buffer(s.index)->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });

View File

@@ -107,7 +107,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
}
// Launch grid
compute_encoder->dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -216,7 +216,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
// Launch grid
MTL::Size grid_dims = MTL::Size(upd_size, nthreads / upd_size, 1);
MTL::Size group_dims = get_block_dims(upd_size, nthreads / upd_size, 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
} else {
// Collect all idx shapes and strides into one place
@@ -286,7 +286,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
// Launch grid
MTL::Size grid_dims = MTL::Size(upd_size, nthreads / upd_size, 1);
MTL::Size group_dims = get_block_dims(upd_size, nthreads / upd_size, 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
}

View File

@@ -11,22 +11,22 @@ template <typename T>
out[index] = start + index * step;
}
#define instantiate_arange(tname, type) \
template [[host_name("arange" #tname)]] \
[[kernel]] void arange<type>( \
constant const type& start, \
constant const type& step, \
device type* out, \
uint index [[thread_position_in_grid]]);
#define instantiate_arange(tname, type) \
template [[host_name("arange" #tname)]] [[kernel]] void arange<type>( \
constant const type& start, \
constant const type& step, \
device type* out, \
uint index [[thread_position_in_grid]]);
instantiate_arange(uint8, uint8_t)
// clang-format off
instantiate_arange(uint8, uint8_t)
instantiate_arange(uint16, uint16_t)
instantiate_arange(uint32, uint32_t)
instantiate_arange(uint32, uint32_t)
instantiate_arange(uint64, uint64_t)
instantiate_arange(int8, int8_t)
instantiate_arange(int8, int8_t)
instantiate_arange(int16, int16_t)
instantiate_arange(int32, int32_t)
instantiate_arange(int64, int64_t)
instantiate_arange(float16, half)
instantiate_arange(float32, float)
instantiate_arange(bfloat16, bfloat16_t)
instantiate_arange(bfloat16, bfloat16_t) // clang-format on

View File

@@ -18,7 +18,8 @@ struct ArgMin {
static constexpr constant U init = Limits<U>::max;
IndexValPair<U> reduce(IndexValPair<U> best, IndexValPair<U> current) {
if (best.val > current.val || (best.val == current.val && best.index > current.index)) {
if (best.val > current.val ||
(best.val == current.val && best.index > current.index)) {
return current;
} else {
return best;
@@ -26,11 +27,12 @@ struct ArgMin {
}
template <int N>
IndexValPair<U> reduce_many(IndexValPair<U> best, thread U* vals, uint32_t offset) {
for (int i=0; i<N; i++) {
IndexValPair<U>
reduce_many(IndexValPair<U> best, thread U* vals, uint32_t offset) {
for (int i = 0; i < N; i++) {
if (vals[i] < best.val) {
best.val = vals[i];
best.index = offset+i;
best.index = offset + i;
}
}
return best;
@@ -42,7 +44,8 @@ struct ArgMax {
static constexpr constant U init = Limits<U>::min;
IndexValPair<U> reduce(IndexValPair<U> best, IndexValPair<U> current) {
if (best.val < current.val || (best.val == current.val && best.index > current.index)) {
if (best.val < current.val ||
(best.val == current.val && best.index > current.index)) {
return current;
} else {
return best;
@@ -50,11 +53,12 @@ struct ArgMax {
}
template <int N>
IndexValPair<U> reduce_many(IndexValPair<U> best, thread U* vals, uint32_t offset) {
for (int i=0; i<N; i++) {
IndexValPair<U>
reduce_many(IndexValPair<U> best, thread U* vals, uint32_t offset) {
for (int i = 0; i < N; i++) {
if (vals[i] > best.val) {
best.val = vals[i];
best.index = offset+i;
best.index = offset + i;
}
}
return best;
@@ -64,19 +68,16 @@ struct ArgMax {
template <typename U>
IndexValPair<U> simd_shuffle_down(IndexValPair<U> data, uint16_t delta) {
return IndexValPair<U>{
simd_shuffle_down(data.index, delta),
simd_shuffle_down(data.val, delta)
};
simd_shuffle_down(data.index, delta), simd_shuffle_down(data.val, delta)};
}
template <typename T, typename Op, int N_READS>
[[kernel]] void arg_reduce_general(
const device T *in [[buffer(0)]],
device uint32_t *out [[buffer(1)]],
const device int *shape [[buffer(2)]],
const device size_t *in_strides [[buffer(3)]],
const device size_t *out_strides [[buffer(4)]],
const device T* in [[buffer(0)]],
device uint32_t* out [[buffer(1)]],
const device int* shape [[buffer(2)]],
const device size_t* in_strides [[buffer(3)]],
const device size_t* out_strides [[buffer(4)]],
const device size_t& ndim [[buffer(5)]],
const device size_t& axis_stride [[buffer(6)]],
const device size_t& axis_size [[buffer(7)]],
@@ -86,7 +87,6 @@ template <typename T, typename Op, int N_READS>
uint simd_size [[threads_per_simdgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
// Shapes and strides *do not* contain the reduction axis. The reduction size
// and stride are provided in axis_stride and axis_size.
//
@@ -113,13 +113,13 @@ template <typename T, typename Op, int N_READS>
threadgroup IndexValPair<T> local_data[32];
// Loop over the reduction axis in lsize*N_READS buckets
for (uint r=0; r < ceildiv(axis_size, N_READS*lsize); r++) {
for (uint r = 0; r < ceildiv(axis_size, N_READS * lsize); r++) {
// Read the current value
uint32_t current_index = r*lsize*N_READS + lid*N_READS;
uint32_t current_index = r * lsize * N_READS + lid * N_READS;
uint32_t offset = current_index;
const device T * current_in = in + in_idx + current_index * axis_stride;
const device T* current_in = in + in_idx + current_index * axis_stride;
T vals[N_READS];
for (int i=0; i<N_READS; i++) {
for (int i = 0; i < N_READS; i++) {
vals[i] = (current_index < axis_size) ? *current_in : T(Op::init);
current_index++;
current_in += axis_stride;
@@ -130,7 +130,7 @@ template <typename T, typename Op, int N_READS>
// need to reduce across the thread group.
// First per simd reduction.
for (uint offset=simd_size/2; offset>0; offset/=2) {
for (uint offset = simd_size / 2; offset > 0; offset /= 2) {
IndexValPair<T> neighbor = simd_shuffle_down(best, offset);
best = op.reduce(best, neighbor);
}
@@ -149,7 +149,7 @@ template <typename T, typename Op, int N_READS>
if (simd_lane_id < simd_groups) {
best = local_data[simd_lane_id];
}
for (uint offset=simd_size/2; offset>0; offset/=2) {
for (uint offset = simd_size / 2; offset > 0; offset /= 2) {
IndexValPair<T> neighbor = simd_shuffle_down(best, offset);
best = op.reduce(best, neighbor);
}
@@ -161,24 +161,25 @@ template <typename T, typename Op, int N_READS>
}
#define instantiate_arg_reduce_helper(name, itype, op) \
template [[host_name(name)]] \
[[kernel]] void arg_reduce_general<itype, op<itype>, 4>( \
const device itype *in [[buffer(0)]], \
device uint32_t * out [[buffer(1)]], \
const device int *shape [[buffer(2)]], \
const device size_t *in_strides [[buffer(3)]], \
const device size_t *out_strides [[buffer(4)]], \
const device size_t& ndim [[buffer(5)]], \
const device size_t& axis_stride [[buffer(6)]], \
const device size_t& axis_size [[buffer(7)]], \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint lsize [[threads_per_threadgroup]], \
uint simd_size [[threads_per_simdgroup]], \
template [[host_name(name)]] [[kernel]] void \
arg_reduce_general<itype, op<itype>, 4>( \
const device itype* in [[buffer(0)]], \
device uint32_t* out [[buffer(1)]], \
const device int* shape [[buffer(2)]], \
const device size_t* in_strides [[buffer(3)]], \
const device size_t* out_strides [[buffer(4)]], \
const device size_t& ndim [[buffer(5)]], \
const device size_t& axis_stride [[buffer(6)]], \
const device size_t& axis_size [[buffer(7)]], \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint lsize [[threads_per_threadgroup]], \
uint simd_size [[threads_per_simdgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
#define instantiate_arg_reduce(name, itype) \
// clang-format off
#define instantiate_arg_reduce(name, itype) \
instantiate_arg_reduce_helper("argmin_" #name , itype, ArgMin) \
instantiate_arg_reduce_helper("argmax_" #name , itype, ArgMax)
@@ -193,4 +194,4 @@ instantiate_arg_reduce(int32, int32_t)
instantiate_arg_reduce(int64, int64_t)
instantiate_arg_reduce(float16, half)
instantiate_arg_reduce(float32, float)
instantiate_arg_reduce(bfloat16, bfloat16_t)
instantiate_arg_reduce(bfloat16, bfloat16_t) // clang-format on

View File

@@ -229,3 +229,45 @@ struct LogicalOr {
return x || y;
};
};
// Element-wise bitwise AND functor for the binary kernels.
struct BitwiseAnd {
  template <typename T>
  T operator()(T x, T y) {
    // Convert the operator result back to T explicitly.
    return static_cast<T>(x & y);
  }
};
// Element-wise bitwise OR functor for the binary kernels.
struct BitwiseOr {
  template <typename T>
  T operator()(T x, T y) {
    // Convert the operator result back to T explicitly.
    return static_cast<T>(x | y);
  }
};
// Element-wise bitwise XOR functor for the binary kernels.
struct BitwiseXor {
  template <typename T>
  T operator()(T x, T y) {
    // Convert the operator result back to T explicitly.
    return static_cast<T>(x ^ y);
  }
};
// Element-wise left-shift functor for the binary kernels: shifts `x` left by
// `y` bit positions.
struct LeftShift {
  template <typename T>
  T operator()(T x, T y) {
    // Convert the operator result back to T explicitly.
    return static_cast<T>(x << y);
  }
};
// Element-wise right-shift functor for the binary kernels: shifts `x` right
// by `y` bit positions.
struct RightShift {
  template <typename T>
  T operator()(T x, T y) {
    // Convert the operator result back to T explicitly.
    return static_cast<T>(x >> y);
  }
};
// Element-wise two-argument arctangent functor for the binary kernels.
// Takes `y` first and `x` second and forwards them to
// metal::precise::atan2.
struct ArcTan2 {
  template <typename T>
  T operator()(T y, T x) {
    const auto angle = metal::precise::atan2(y, x);
    return angle;
  }
};

View File

@@ -77,7 +77,8 @@ template <typename T, typename U, typename Op>
uint3 grid_dim [[threads_per_grid]]) {
auto a_idx = elem_to_loc_3(index, a_strides);
auto b_idx = elem_to_loc_3(index, b_strides);
size_t out_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
size_t out_idx =
index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
c[out_idx] = Op()(a[a_idx], b[b_idx]);
}
@@ -92,7 +93,8 @@ template <typename T, typename U, typename Op, int DIM>
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto idx = elem_to_loc_2_nd<DIM>(index, shape, a_strides, b_strides);
size_t out_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
size_t out_idx =
index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
c[out_idx] = Op()(a[idx.x], b[idx.y]);
}
@@ -112,111 +114,118 @@ template <typename T, typename U, typename Op>
c[out_idx] = Op()(a[idx.x], b[idx.y]);
}
#define instantiate_binary(name, itype, otype, op, bopt) \
template [[host_name(name)]] \
[[kernel]] void binary_op_##bopt<itype, otype, op>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
uint index [[thread_position_in_grid]]);
#define instantiate_binary(name, itype, otype, op, bopt) \
template \
[[host_name(name)]] [[kernel]] void binary_op_##bopt<itype, otype, op>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
uint index [[thread_position_in_grid]]);
#define instantiate_binary_g_dim(name, itype, otype, op, dims) \
template [[host_name(name "_" #dims)]] \
[[kernel]] void binary_op_g_nd<itype, otype, op, dims>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
constant const int shape[dims], \
constant const size_t a_strides[dims], \
constant const size_t b_strides[dims], \
uint3 index [[thread_position_in_grid]], \
template [[host_name(name "_" #dims)]] [[kernel]] void \
binary_op_g_nd<itype, otype, op, dims>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
constant const int shape[dims], \
constant const size_t a_strides[dims], \
constant const size_t b_strides[dims], \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]);
#define instantiate_binary_g_nd(name, itype, otype, op) \
template [[host_name(name "_1")]] \
[[kernel]] void binary_op_g_nd1<itype, otype, op>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
constant const size_t& a_stride, \
constant const size_t& b_stride, \
uint index [[thread_position_in_grid]]); \
template [[host_name(name "_2")]] \
[[kernel]] void binary_op_g_nd2<itype, otype, op>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
constant const size_t a_strides[2], \
constant const size_t b_strides[2], \
uint2 index [[thread_position_in_grid]], \
uint2 grid_dim [[threads_per_grid]]); \
template [[host_name(name "_3")]] \
[[kernel]] void binary_op_g_nd3<itype, otype, op>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
constant const size_t a_strides[3], \
constant const size_t b_strides[3], \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]); \
instantiate_binary_g_dim(name, itype, otype, op, 4) \
instantiate_binary_g_dim(name, itype, otype, op, 5)
template [[host_name(name "_1")]] [[kernel]] void \
binary_op_g_nd1<itype, otype, op>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
constant const size_t& a_stride, \
constant const size_t& b_stride, \
uint index [[thread_position_in_grid]]); \
template [[host_name(name "_2")]] [[kernel]] void \
binary_op_g_nd2<itype, otype, op>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
constant const size_t a_strides[2], \
constant const size_t b_strides[2], \
uint2 index [[thread_position_in_grid]], \
uint2 grid_dim [[threads_per_grid]]); \
template [[host_name(name "_3")]] [[kernel]] void \
binary_op_g_nd3<itype, otype, op>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
constant const size_t a_strides[3], \
constant const size_t b_strides[3], \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]); \
instantiate_binary_g_dim(name, itype, otype, op, 4) \
instantiate_binary_g_dim(name, itype, otype, op, 5)
#define instantiate_binary_g(name, itype, otype, op) \
template [[host_name(name)]] \
[[kernel]] void binary_op_g<itype, otype, op>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
constant const int* shape, \
constant const size_t* a_strides, \
constant const size_t* b_strides, \
constant const int& ndim, \
uint3 index [[thread_position_in_grid]], \
#define instantiate_binary_g(name, itype, otype, op) \
template [[host_name(name)]] [[kernel]] void binary_op_g<itype, otype, op>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
constant const int* shape, \
constant const size_t* a_strides, \
constant const size_t* b_strides, \
constant const int& ndim, \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]);
// clang-format off
#define instantiate_binary_all(name, tname, itype, otype, op) \
instantiate_binary("ss" #name #tname, itype, otype, op, ss) \
instantiate_binary("sv" #name #tname, itype, otype, op, sv) \
instantiate_binary("vs" #name #tname, itype, otype, op, vs) \
instantiate_binary("vv" #name #tname, itype, otype, op, vv) \
instantiate_binary_g("g" #name #tname, itype, otype, op) \
instantiate_binary_g_nd("g" #name #tname, itype, otype, op)
instantiate_binary_g("g" #name #tname, itype, otype, op) \
instantiate_binary_g_nd("g" #name #tname, itype, otype, op) // clang-format on
#define instantiate_binary_float(name, op) \
instantiate_binary_all(name, float16, half, half, op) \
instantiate_binary_all(name, float32, float, float, op) \
instantiate_binary_all(name, bfloat16, bfloat16_t, bfloat16_t, op)
#define instantiate_binary_types(name, op) \
instantiate_binary_all(name, bool_, bool, bool, op) \
instantiate_binary_all(name, uint8, uint8_t, uint8_t, op) \
// clang-format off
#define instantiate_binary_integer(name, op) \
instantiate_binary_all(name, uint8, uint8_t, uint8_t, op) \
instantiate_binary_all(name, uint16, uint16_t, uint16_t, op) \
instantiate_binary_all(name, uint32, uint32_t, uint32_t, op) \
instantiate_binary_all(name, uint64, uint64_t, uint64_t, op) \
instantiate_binary_all(name, int8, int8_t, int8_t, op) \
instantiate_binary_all(name, int16, int16_t, int16_t, op) \
instantiate_binary_all(name, int32, int32_t, int32_t, op) \
instantiate_binary_all(name, int64, int64_t, int64_t, op) \
instantiate_binary_all(name, int8, int8_t, int8_t, op) \
instantiate_binary_all(name, int16, int16_t, int16_t, op) \
instantiate_binary_all(name, int32, int32_t, int32_t, op) \
instantiate_binary_all(name, int64, int64_t, int64_t, op) // clang-format on
// clang-format off
#define instantiate_binary_float(name, op) \
instantiate_binary_all(name, float16, half, half, op) \
instantiate_binary_all(name, float32, float, float, op) \
instantiate_binary_all(name, bfloat16, bfloat16_t, bfloat16_t, op) // clang-format on
// clang-format off
#define instantiate_binary_types(name, op) \
instantiate_binary_all(name, bool_, bool, bool, op) \
instantiate_binary_integer(name, op) \
instantiate_binary_all(name, complex64, complex64_t, complex64_t, op) \
instantiate_binary_float(name, op)
instantiate_binary_float(name, op) // clang-format on
// Instantiates `op` for every input element type with a bool output type
// (used for comparison-style ops such as Equal).
// clang-format off
#define instantiate_binary_types_bool(name, op) \
  instantiate_binary_all(name, bool_, bool, bool, op) \
  instantiate_binary_all(name, uint8, uint8_t, bool, op) \
  instantiate_binary_all(name, uint16, uint16_t, bool, op) \
  instantiate_binary_all(name, uint32, uint32_t, bool, op) \
  instantiate_binary_all(name, uint64, uint64_t, bool, op) \
  instantiate_binary_all(name, int8, int8_t, bool, op) \
  instantiate_binary_all(name, int16, int16_t, bool, op) \
  instantiate_binary_all(name, int32, int32_t, bool, op) \
  instantiate_binary_all(name, int64, int64_t, bool, op) \
  instantiate_binary_all(name, float16, half, bool, op) \
  instantiate_binary_all(name, float32, float, bool, op) \
  instantiate_binary_all(name, bfloat16, bfloat16_t, bool, op) \
  instantiate_binary_all(name, complex64, complex64_t, bool, op) // clang-format on
// clang-format off
instantiate_binary_types(add, Add)
instantiate_binary_types(div, Divide)
instantiate_binary_types_bool(eq, Equal)
@@ -232,6 +241,7 @@ instantiate_binary_types(mul, Multiply)
instantiate_binary_types(sub, Subtract)
instantiate_binary_types(pow, Power)
instantiate_binary_types(rem, Remainder)
instantiate_binary_float(arctan2, ArcTan2)
// NaNEqual only needed for floating point types with boolean output
instantiate_binary_all(naneq, float16, half, bool, NaNEqual)
@@ -241,3 +251,13 @@ instantiate_binary_all(naneq, complex64, complex64_t, bool, NaNEqual)
// Logical ops are instantiated for bool inputs and bool outputs only
instantiate_binary_all(lor, bool_, bool, bool, LogicalOr)
instantiate_binary_all(land, bool_, bool, bool, LogicalAnd)
// Bitwise ops only need integer types and bool (except for l/r shift)
instantiate_binary_integer(bitwise_and, BitwiseAnd)
instantiate_binary_all(bitwise_and, bool_, bool, bool, BitwiseAnd)
instantiate_binary_integer(bitwise_or, BitwiseOr)
instantiate_binary_all(bitwise_or, bool_, bool, bool, BitwiseOr)
instantiate_binary_integer(bitwise_xor, BitwiseXor)
instantiate_binary_all(bitwise_xor, bool_, bool, bool, BitwiseXor)
// Shifts get the integer instantiations only; no bool variant is emitted
instantiate_binary_integer(left_shift, LeftShift)
instantiate_binary_integer(right_shift, RightShift) // clang-format on

View File

@@ -3,28 +3,42 @@
#include <metal_integer>
#include <metal_math>
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/utils.h"
// Quotient functor paired with Remainder for the `divmod` kernels.
//
// The generic overload returns `x / y` unchanged (the language's native
// division for T). The float, half and bfloat16_t specializations pass the
// quotient through trunc(), i.e. they drop the fractional part toward zero.
// NOTE(review): despite the name, no floor() is applied here - confirm that
// truncation is the intended rounding for negative quotients.
struct FloorDivide {
  template <typename T>
  T operator()(T x, T y) {
    return x / y;
  }
  template <>
  float operator()(float x, float y) {
    return trunc(x / y);
  }
  template <>
  half operator()(half x, half y) {
    return trunc(x / y);
  }
  template <>
  bfloat16_t operator()(bfloat16_t x, bfloat16_t y) {
    return trunc(x / y);
  }
};
struct Remainder {
template <typename T>
metal::enable_if_t<metal::is_integral_v<T> & !metal::is_signed_v<T>, T> operator()(T x, T y) {
metal::enable_if_t<metal::is_integral_v<T> & !metal::is_signed_v<T>, T>
operator()(T x, T y) {
return x % y;
}
template <typename T>
metal::enable_if_t<metal::is_integral_v<T> & metal::is_signed_v<T>, T> operator()(T x, T y) {
metal::enable_if_t<metal::is_integral_v<T> & metal::is_signed_v<T>, T>
operator()(T x, T y) {
auto r = x % y;
if (r != 0 && (r < 0 != y < 0)) {
r += y;
}
return r;
return r;
}
template <typename T>
metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T x, T y) {
@@ -32,10 +46,11 @@ struct Remainder {
if (r != 0 && (r < 0 != y < 0)) {
r += y;
}
return r;
return r;
}
template <> complex64_t operator()(complex64_t x, complex64_t y) {
return x % y;
template <>
complex64_t operator()(complex64_t x, complex64_t y) {
return x % y;
}
};
@@ -50,7 +65,6 @@ template <typename T, typename U, typename Op1, typename Op2>
d[index] = Op2()(a[0], b[0]);
}
template <typename T, typename U, typename Op1, typename Op2>
[[kernel]] void binary_op_ss(
device const T* a,
@@ -139,7 +153,8 @@ template <typename T, typename U, typename Op1, typename Op2>
uint3 grid_dim [[threads_per_grid]]) {
auto a_idx = elem_to_loc_3(index, a_strides);
auto b_idx = elem_to_loc_3(index, b_strides);
size_t out_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
size_t out_idx =
index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
c[out_idx] = Op1()(a[a_idx], b[b_idx]);
d[out_idx] = Op2()(a[a_idx], b[b_idx]);
}
@@ -156,7 +171,8 @@ template <typename T, typename U, typename Op1, typename Op2, int DIM>
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto idx = elem_to_loc_2_nd<DIM>(index, shape, a_strides, b_strides);
size_t out_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
size_t out_idx =
index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
c[out_idx] = Op1()(a[idx.x], b[idx.y]);
d[out_idx] = Op2()(a[idx.x], b[idx.y]);
}
@@ -180,99 +196,102 @@ template <typename T, typename U, typename Op1, typename Op2>
}
// Explicitly instantiates the two-output kernel `binary_op_<bopt>` for one
// (input type, output type, op1, op2) combination and exports it under the
// host-visible kernel name `name`.
//   name  - string literal used as the [[host_name]]
//   itype - element type of inputs a and b
//   otype - element type of outputs c and d
//   op1   - functor producing c
//   op2   - functor producing d
//   bopt  - kernel suffix (ss, sv, vs or vv in instantiate_binary_all)
#define instantiate_binary(name, itype, otype, op1, op2, bopt) \
  template [[host_name(name)]] [[kernel]] void \
  binary_op_##bopt<itype, otype, op1, op2>( \
      device const itype* a, \
      device const itype* b, \
      device otype* c, \
      device otype* d, \
      uint index [[thread_position_in_grid]]);
// Explicitly instantiates the strided two-output kernel `binary_op_g_nd`
// with a compile-time number of dimensions `dims`. The host name is `name`
// with "_<dims>" appended.
#define instantiate_binary_g_dim(name, itype, otype, op1, op2, dims) \
  template [[host_name(name "_" #dims)]] [[kernel]] void \
  binary_op_g_nd<itype, otype, op1, op2, dims>( \
      device const itype* a, \
      device const itype* b, \
      device otype* c, \
      device otype* d, \
      constant const int shape[dims], \
      constant const size_t a_strides[dims], \
      constant const size_t b_strides[dims], \
      uint3 index [[thread_position_in_grid]], \
      uint3 grid_dim [[threads_per_grid]]);
// Explicitly instantiates the strided two-output kernels for 1 to 5
// dimensions. Ranks 1-3 use the dedicated binary_op_g_nd1/2/3 kernels
// (host names `name` "_1" .. "_3"); ranks 4 and 5 go through
// instantiate_binary_g_dim.
// clang-format off
#define instantiate_binary_g_nd(name, itype, otype, op1, op2) \
  template [[host_name(name "_1")]] [[kernel]] void \
  binary_op_g_nd1<itype, otype, op1, op2>( \
      device const itype* a, \
      device const itype* b, \
      device otype* c, \
      device otype* d, \
      constant const size_t& a_stride, \
      constant const size_t& b_stride, \
      uint index [[thread_position_in_grid]]); \
  template [[host_name(name "_2")]] [[kernel]] void \
  binary_op_g_nd2<itype, otype, op1, op2>( \
      device const itype* a, \
      device const itype* b, \
      device otype* c, \
      device otype* d, \
      constant const size_t a_strides[2], \
      constant const size_t b_strides[2], \
      uint2 index [[thread_position_in_grid]], \
      uint2 grid_dim [[threads_per_grid]]); \
  template [[host_name(name "_3")]] [[kernel]] void \
  binary_op_g_nd3<itype, otype, op1, op2>( \
      device const itype* a, \
      device const itype* b, \
      device otype* c, \
      device otype* d, \
      constant const size_t a_strides[3], \
      constant const size_t b_strides[3], \
      uint3 index [[thread_position_in_grid]], \
      uint3 grid_dim [[threads_per_grid]]); \
  instantiate_binary_g_dim(name, itype, otype, op1, op2, 4) \
  instantiate_binary_g_dim(name, itype, otype, op1, op2, 5) // clang-format on
// Explicitly instantiates the general strided two-output kernel
// `binary_op_g` (runtime `ndim`) under the host name `name`.
//
// Both functors must be forwarded in order: op1 fills the first output (c)
// and op2 the second (d), matching every other instantiate_binary_* macro in
// this file. Passing op2 for both template slots would silently drop op1
// (e.g. the FloorDivide half of divmod).
#define instantiate_binary_g(name, itype, otype, op1, op2) \
  template [[host_name(name)]] [[kernel]] void \
  binary_op_g<itype, otype, op1, op2>( \
      device const itype* a, \
      device const itype* b, \
      device otype* c, \
      device otype* d, \
      constant const int* shape, \
      constant const size_t* a_strides, \
      constant const size_t* b_strides, \
      constant const int& ndim, \
      uint3 index [[thread_position_in_grid]], \
      uint3 grid_dim [[threads_per_grid]]);
// Instantiates every kernel variant of a two-output binary op for one
// element type: the ss/sv/vs/vv kernels plus the general strided kernel and
// its fixed-rank versions. Host names are the variant prefix followed by the
// stringified `name` and `tname`.
// clang-format off
#define instantiate_binary_all(name, tname, itype, otype, op1, op2) \
  instantiate_binary("ss" #name #tname, itype, otype, op1, op2, ss) \
  instantiate_binary("sv" #name #tname, itype, otype, op1, op2, sv) \
  instantiate_binary("vs" #name #tname, itype, otype, op1, op2, vs) \
  instantiate_binary("vv" #name #tname, itype, otype, op1, op2, vv) \
  instantiate_binary_g("g" #name #tname, itype, otype, op1, op2) \
  instantiate_binary_g_nd("g" #name #tname, itype, otype, op1, op2) // clang-format on
// Instantiates a two-output binary op for the floating-point element types
// (half, float, bfloat16_t); outputs have the same type as the inputs.
// clang-format off
#define instantiate_binary_float(name, op1, op2) \
  instantiate_binary_all(name, float16, half, half, op1, op2) \
  instantiate_binary_all(name, float32, float, float, op1, op2) \
  instantiate_binary_all(name, bfloat16, bfloat16_t, bfloat16_t, op1, op2) // clang-format on
// Instantiates a two-output binary op for every supported element type
// (bool, unsigned/signed integers, complex64 and the floating-point types),
// with outputs of the same type as the inputs.
// clang-format off
#define instantiate_binary_types(name, op1, op2) \
  instantiate_binary_all(name, bool_, bool, bool, op1, op2) \
  instantiate_binary_all(name, uint8, uint8_t, uint8_t, op1, op2) \
  instantiate_binary_all(name, uint16, uint16_t, uint16_t, op1, op2) \
  instantiate_binary_all(name, uint32, uint32_t, uint32_t, op1, op2) \
  instantiate_binary_all(name, uint64, uint64_t, uint64_t, op1, op2) \
  instantiate_binary_all(name, int8, int8_t, int8_t, op1, op2) \
  instantiate_binary_all(name, int16, int16_t, int16_t, op1, op2) \
  instantiate_binary_all(name, int32, int32_t, int32_t, op1, op2) \
  instantiate_binary_all(name, int64, int64_t, int64_t, op1, op2) \
  instantiate_binary_all(name, complex64, complex64_t, complex64_t, op1, op2) \
  instantiate_binary_float(name, op1, op2)

// divmod: first output is the quotient (FloorDivide), second the remainder
instantiate_binary_types(divmod, FloorDivide, Remainder) // clang-format on

View File

@@ -22,7 +22,7 @@ struct complex64_t {
float imag;
// Constructors
constexpr complex64_t(float real, float imag) : real(real), imag(imag){};
constexpr complex64_t(float real, float imag) : real(real), imag(imag) {};
// Conversions to complex64_t
template <

View File

@@ -1,13 +1,11 @@
// Copyright © 2023-2024 Apple Inc.
#include <metal_stdlib>
#include <metal_simdgroup>
#include <metal_simdgroup_matrix>
#include <metal_stdlib>
#include "mlx/backend/metal/kernels/steel/conv/params.h"
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/steel/conv/params.h"
#define MLX_MTL_CONST static constant constexpr const
@@ -23,17 +21,18 @@ template <typename T, int N>
device T* out [[buffer(1)]],
const constant MLXConvParams<N>* params [[buffer(2)]],
uint3 gid [[thread_position_in_grid]]) {
int filter_size = params->C;
for(short i = 0; i < N; i++) filter_size *= params->wS[i];
for (short i = 0; i < N; i++)
filter_size *= params->wS[i];
int out_pixels = 1;
for(short i = 0; i < N; i++) out_pixels *= params->oS[i];
for (short i = 0; i < N; i++)
out_pixels *= params->oS[i];
// Set out
// Set out
out += gid.z * filter_size + gid.y * (params->C);
// Corrdinates in input
// Coordinates in input
int is[N] = {0};
// gid.z: N oS (Batch and row in unfolded output)
@@ -46,11 +45,11 @@ template <typename T, int N>
bool valid = n < params->N;
// Unroll dimensions
// Unroll dimensions
for (int i = N - 1; i >= 0; --i) {
int os_ = (oS % params->oS[i]);
int ws_ = (wS % params->wS[i]);
ws_ = params->flip ? params->wS[i] - ws_ - 1 : ws_;
int is_ = os_ * params->str[i] - params->pad[i] + ws_ * params->kdil[i];
@@ -64,10 +63,10 @@ template <typename T, int N>
wS /= params->wS[i];
}
if(valid) {
if (valid) {
size_t in_offset = n * params->in_strides[0];
for(int i = 0; i < N; ++i) {
for (int i = 0; i < N; ++i) {
in_offset += is[i] * params->in_strides[i + 1];
}
@@ -75,21 +74,91 @@ template <typename T, int N>
} else {
out[gid.x] = T(0);
}
}
// This kernel unfolds the input array of size (N, *spatial_dims, C)
// into an array of size (N x *spatial_dims, C x *kernel_dims).
//
// Each thread writes exactly one output element: either the input sample
// selected by (gid.z, gid.y, gid.x) or T(0) when that sample falls outside
// the input (padding) or between dilated input positions.
template <typename T, int N>
[[kernel]] void naive_unfold_transpose_Nd(
    const device T* in [[buffer(0)]],
    device T* out [[buffer(1)]],
    const constant MLXConvParams<N>* params [[buffer(2)]],
    uint3 gid [[thread_position_in_grid]]) {
  // Elements per unfolded row: C times the product of the kernel extents
  int filter_size = params->C;
  for (short i = 0; i < N; i++)
    filter_size *= params->wS[i];

  // Output positions per batch element
  int out_pixels = 1;
  for (short i = 0; i < N; i++)
    out_pixels *= params->oS[i];

  // Set out
  out += gid.z * filter_size + gid.x * (filter_size / params->C);

  // Coordinates in input
  int is[N] = {0};

  // gid.z: N oS (Batch and row in unfolded output)
  // gid.y: wS (Filter location to unfold input)
  // gid.x: C (channel)

  int n = (gid.z) / out_pixels;
  int oS = (gid.z) % out_pixels;
  int wS = gid.y;

  bool valid = n < params->N;

  // Unroll dimensions
  for (int i = N - 1; i >= 0; --i) {
    int os_ = (oS % params->oS[i]);
    int ws_ = (wS % params->wS[i]);

    // With `flip` set the filter tap index is mirrored along this axis
    ws_ = params->flip ? params->wS[i] - ws_ - 1 : ws_;

    // Position on the input-dilated grid; only multiples of idil[i] that lie
    // inside [0, is_max) correspond to a real input sample
    int is_ = os_ * params->str[i] - params->pad[i] + ws_ * params->kdil[i];
    int is_max = 1 + params->idil[i] * (params->iS[i] - 1);

    valid &= is_ >= 0 && is_ < is_max && (is_ % params->idil[i] == 0);

    is[i] = is_ / params->idil[i];

    oS /= params->oS[i];
    wS /= params->wS[i];

    out += ws_ * params->str[i];
  }

  if (valid) {
    size_t in_offset = n * params->in_strides[0];

    for (int i = 0; i < N; ++i) {
      in_offset += is[i] * params->in_strides[i + 1];
    }

    out[0] = in[in_offset + gid.x];
  } else {
    out[0] = T(0);
  }
}
// Explicitly instantiates both unfold kernels (naive_unfold_Nd and
// naive_unfold_transpose_Nd) for one element type and one spatial rank.
// Host names are "naive_unfold_nd_<tag>_<rank>" and
// "naive_unfold_transpose_nd_<tag>_<rank>".
#define instantiate_naive_unfold_nd(tag, elem_t, rank) \
  template [[host_name("naive_unfold_nd_" #tag "_" #rank)]] \
  [[kernel]] void naive_unfold_Nd( \
      const device elem_t* in [[buffer(0)]], \
      device elem_t* out [[buffer(1)]], \
      const constant MLXConvParams<rank>* params [[buffer(2)]], \
      uint3 gid [[thread_position_in_grid]]); \
  template [[host_name("naive_unfold_transpose_nd_" #tag "_" #rank)]] \
  [[kernel]] void naive_unfold_transpose_Nd( \
      const device elem_t* in [[buffer(0)]], \
      device elem_t* out [[buffer(1)]], \
      const constant MLXConvParams<rank>* params [[buffer(2)]], \
      uint3 gid [[thread_position_in_grid]]);
// Instantiates the unfold kernels for 1, 2 and 3 spatial dimensions.
// clang-format off
#define instantiate_naive_unfold_nd_dims(tag, elem_t) \
  instantiate_naive_unfold_nd(tag, elem_t, 1) \
  instantiate_naive_unfold_nd(tag, elem_t, 2) \
  instantiate_naive_unfold_nd(tag, elem_t, 3) // clang-format on
instantiate_naive_unfold_nd_dims(float32, float);
instantiate_naive_unfold_nd_dims(float16, half);
@@ -99,12 +168,13 @@ instantiate_naive_unfold_nd_dims(bfloat16, bfloat16_t);
/// Slow and naive conv2d kernels
///////////////////////////////////////////////////////////////////////////////
template <typename T,
const int BM, /* Threadgroup rows (in threads) */
const int BN, /* Threadgroup cols (in threads) */
const int TM, /* Thread rows (in elements) */
const int TN, /* Thread cols (in elements) */
const int BC = 16>
template <
typename T,
const int BM, /* Threadgroup rows (in threads) */
const int BN, /* Threadgroup cols (in threads) */
const int TM, /* Thread rows (in elements) */
const int TN, /* Thread cols (in elements) */
const int BC = 16>
[[kernel]] void naive_conv_2d(
const device T* in [[buffer(0)]],
const device T* wt [[buffer(1)]],
@@ -114,7 +184,6 @@ template <typename T,
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
(void)simd_gid;
(void)simd_lid;
@@ -123,80 +192,82 @@ template <typename T,
int out_o = tid.y * BN * TN + lid.y * TN;
int out_hw = tid.x * BM * TM + lid.x * TM;
int out_h[TM];
int out_w[TN];
for(int m = 0; m < TM; ++m) {
for (int m = 0; m < TM; ++m) {
int mm = (out_hw + m);
out_h[m] = mm / params.oS[1];
out_w[m] = mm % params.oS[1];
}
T in_local[TM];
T wt_local[TN];
T out_local[TM * TN] = {T(0)};
for(int h = 0; h < params.wS[0]; ++h) {
for(int w = 0; w < params.wS[1]; ++w) {
for(int c = 0; c < params.C; ++c) {
for (int h = 0; h < params.wS[0]; ++h) {
for (int w = 0; w < params.wS[1]; ++w) {
for (int c = 0; c < params.C; ++c) {
// Local in
for(int m = 0; m < TM; m++) {
for (int m = 0; m < TM; m++) {
int i = out_h[m] * params.str[0] - params.pad[0] + h * params.kdil[0];
int j = out_w[m] * params.str[1] - params.pad[1] + w * params.kdil[1];
bool valid = i >= 0 && i < params.iS[0] && j >= 0 && j < params.iS[1];
in_local[m] = valid ? in[i * params.in_strides[1] + j * params.in_strides[2] + c] : T(0);
in_local[m] = valid
? in[i * params.in_strides[1] + j * params.in_strides[2] + c]
: T(0);
}
// Load weight
for (int n = 0; n < TN; ++n) {
int o = out_o + n;
wt_local[n] = o < params.O ? wt[o * params.wt_strides[0] +
h * params.wt_strides[1] +
w * params.wt_strides[2] + c] : T(0);
wt_local[n] = o < params.O
? wt[o * params.wt_strides[0] + h * params.wt_strides[1] +
w * params.wt_strides[2] + c]
: T(0);
}
// Accumulate
for(int m = 0; m < TM; ++m) {
for(int n = 0; n < TN; ++n) {
for (int m = 0; m < TM; ++m) {
for (int n = 0; n < TN; ++n) {
out_local[m * TN + n] += in_local[m] * wt_local[n];
}
}
}
}
}
for(int m = 0; m < TM; ++m) {
for(int n = 0; n < TN; ++n) {
if(out_h[m] < params.oS[0] && out_w[m] < params.oS[1] && (out_o + n) < params.O)
out[out_h[m] * params.out_strides[1] +
out_w[m] * params.out_strides[2] + out_o + n] = out_local[m * TN + n];
for (int m = 0; m < TM; ++m) {
for (int n = 0; n < TN; ++n) {
if (out_h[m] < params.oS[0] && out_w[m] < params.oS[1] &&
(out_o + n) < params.O)
out[out_h[m] * params.out_strides[1] +
out_w[m] * params.out_strides[2] + out_o + n] =
out_local[m * TN + n];
}
}
}
// Instantiations
// Explicitly instantiates naive_conv_2d for one element type and one
// (bm, bn, tm, tn) tiling; the tiling is encoded in the host name as
// "naive_conv_2d_<name>_bm<bm>_bn<bn>_tm<tm>_tn<tn>".
#define instantiate_naive_conv_2d(name, itype, bm, bn, tm, tn) \
  template [[host_name("naive_conv_2d_" #name "_bm" #bm "_bn" #bn "_tm" #tm \
                       "_tn" #tn)]] [[kernel]] void \
  naive_conv_2d<itype, bm, bn, tm, tn>( \
      const device itype* in [[buffer(0)]], \
      const device itype* wt [[buffer(1)]], \
      device itype* out [[buffer(2)]], \
      const constant MLXConvParams<2>& params [[buffer(3)]], \
      uint3 tid [[threadgroup_position_in_grid]], \
      uint3 lid [[thread_position_in_threadgroup]], \
      uint simd_gid [[simdgroup_index_in_threadgroup]], \
      uint simd_lid [[thread_index_in_simdgroup]]);
// Instantiates naive_conv_2d for the two tilings used with this element
// type: (bm, bn, tm, tn) = (16, 8, 4, 4) and (16, 8, 2, 4).
#define instantiate_naive_conv_2d_blocks(name, itype) \
  instantiate_naive_conv_2d(name, itype, 16, 8, 4, 4) \
  instantiate_naive_conv_2d(name, itype, 16, 8, 2, 4)
instantiate_naive_conv_2d_blocks(float32, float);
instantiate_naive_conv_2d_blocks(float16, half);
@@ -207,9 +278,7 @@ instantiate_naive_conv_2d_blocks(bfloat16, bfloat16_t);
///////////////////////////////////////////////////////////////////////////////
// Primary template: intentionally empty. The transform matrices are only
// provided by explicit specializations (e.g. WinogradTransforms<6, 3, 8>).
template <int M, int R, int S>
struct WinogradTransforms {};
template <>
struct WinogradTransforms<6, 3, 8> {
@@ -218,36 +287,36 @@ struct WinogradTransforms<6, 3, 8> {
MLX_MTL_CONST int IN_TILE_SIZE = OUT_TILE_SIZE + FILTER_SIZE - 1;
MLX_MTL_CONST int SIMD_MATRIX_SIZE = 8;
MLX_MTL_CONST float in_transform[SIMD_MATRIX_SIZE][SIMD_MATRIX_SIZE] = {
{ 1.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f},
{ 0.00f, 1.00f, -1.00f, 0.50f, -0.50f, 2.00f, -2.00f, -1.00f},
{-5.25f, 1.00f, 1.00f, 0.25f, 0.25f, 4.00f, 4.00f, 0.00f},
{ 0.00f, -4.25f, 4.25f, -2.50f, 2.50f, -2.50f, 2.50f, 5.25f},
{ 5.25f, -4.25f, -4.25f, -1.25f, -1.25f, -5.00f, -5.00f, 0.00f},
{ 0.00f, 1.00f, -1.00f, 2.00f, -2.00f, 0.50f, -0.50f, -5.25f},
{-1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 0.00f},
{ 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 1.00f},
{1.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f},
{0.00f, 1.00f, -1.00f, 0.50f, -0.50f, 2.00f, -2.00f, -1.00f},
{-5.25f, 1.00f, 1.00f, 0.25f, 0.25f, 4.00f, 4.00f, 0.00f},
{0.00f, -4.25f, 4.25f, -2.50f, 2.50f, -2.50f, 2.50f, 5.25f},
{5.25f, -4.25f, -4.25f, -1.25f, -1.25f, -5.00f, -5.00f, 0.00f},
{0.00f, 1.00f, -1.00f, 2.00f, -2.00f, 0.50f, -0.50f, -5.25f},
{-1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 0.00f},
{0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 1.00f},
};
MLX_MTL_CONST float out_transform[SIMD_MATRIX_SIZE][SIMD_MATRIX_SIZE] = {
{ 1.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f},
{ 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f},
{ 1.00f, -1.00f, 1.00f, -1.00f, 1.00f, -1.00f},
{ 1.00f, 2.00f, 4.00f, 8.00f, 16.00f, 32.00f},
{ 1.00f, -2.00f, 4.00f, -8.00f, 16.00f, -32.00f},
{ 1.00f, 0.50f, 0.25f, 0.125f, 0.0625f, 0.03125f},
{ 1.00f, -0.50f, 0.25f, -0.125f, 0.0625f, -0.03125f},
{ 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 1.00f},
{1.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f},
{1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f},
{1.00f, -1.00f, 1.00f, -1.00f, 1.00f, -1.00f},
{1.00f, 2.00f, 4.00f, 8.00f, 16.00f, 32.00f},
{1.00f, -2.00f, 4.00f, -8.00f, 16.00f, -32.00f},
{1.00f, 0.50f, 0.25f, 0.125f, 0.0625f, 0.03125f},
{1.00f, -0.50f, 0.25f, -0.125f, 0.0625f, -0.03125f},
{0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 1.00f},
};
MLX_MTL_CONST float wt_transform[SIMD_MATRIX_SIZE][SIMD_MATRIX_SIZE] = {
{ 1.00, 0.00, 0.00},
{ -2.0/9.00, -2.0/9.00, -2.0/9.00},
{ -2.0/9.00, 2.0/9.00, -2.0/9.00},
{ 1.0/90.0, 1.0/45.0, 2.0/45.0},
{ 1.0/90.0, -1.0/45.0, 2.0/45.0},
{ 32.0/45.0, 16.0/45.0, 8.0/45.0},
{ 32.0/45.0, -16.0/45.0, 8.0/45.0},
{ 0.00, 0.00, 1.00},
{1.00, 0.00, 0.00},
{-2.0 / 9.00, -2.0 / 9.00, -2.0 / 9.00},
{-2.0 / 9.00, 2.0 / 9.00, -2.0 / 9.00},
{1.0 / 90.0, 1.0 / 45.0, 2.0 / 45.0},
{1.0 / 90.0, -1.0 / 45.0, 2.0 / 45.0},
{32.0 / 45.0, 16.0 / 45.0, 8.0 / 45.0},
{32.0 / 45.0, -16.0 / 45.0, 8.0 / 45.0},
{0.00, 0.00, 1.00},
};
};
@@ -255,12 +324,9 @@ constant constexpr const float WinogradTransforms<6, 3, 8>::wt_transform[8][8];
constant constexpr const float WinogradTransforms<6, 3, 8>::in_transform[8][8];
constant constexpr const float WinogradTransforms<6, 3, 8>::out_transform[8][8];
template <typename T,
int BC = 32,
int BO = 4,
int M = 6,
int R = 3>
[[kernel, max_total_threads_per_threadgroup(BO * 32)]] void winograd_conv_2d_weight_transform(
template <typename T, int BC = 32, int BO = 4, int M = 6, int R = 3>
[[kernel, max_total_threads_per_threadgroup(BO * 32)]] void
winograd_conv_2d_weight_transform(
const device T* wt_in [[buffer(0)]],
device T* wt_out [[buffer(1)]],
const constant int& C [[buffer(2)]],
@@ -268,7 +334,6 @@ template <typename T,
uint tid [[threadgroup_position_in_grid]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]]) {
using WGT = WinogradTransforms<M, R, 8>;
// Get lane position in simdgroup
@@ -288,35 +353,37 @@ template <typename T,
// Move to the correct output filter
size_t ko = BO * tid + simd_group_id;
wt_in += ko * R * R * C;
wt_in += ko * R * R * C;
// wt_out is stored transposed (A x A x C x O)
short ohw_0 = sm * 8 + sn;
short ohw_1 = sm * 8 + sn + 1;
device T* wt_out_0 = wt_out + ohw_0 * C * O + ko;
device T* wt_out_1 = wt_out + ohw_1 * C * O + ko;
device T* wt_out_1 = wt_out + ohw_1 * C * O + ko;
// Prepare shared memory
threadgroup T Ws[BO][R][R][BC];
// Loop over C
for(int bc = 0; bc < C; bc += BC) {
for (int bc = 0; bc < C; bc += BC) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Read into shared memory
for(int kh = 0; kh < R; ++kh) {
for(int kw = 0; kw < R; ++kw) {
for(int kc = simd_lane_id; kc < BC; kc += 32) {
for (int kh = 0; kh < R; ++kh) {
for (int kw = 0; kw < R; ++kw) {
for (int kc = simd_lane_id; kc < BC; kc += 32) {
Ws[simd_group_id][kh][kw][kc] = wt_in[kh * R * C + kw * C + kc];
}
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Do transform and store the result
for(int c = 0; c < BC; ++c) {
// Do transform and store the result
for (int c = 0; c < BC; ++c) {
simdgroup_matrix<T, 8, 8> g;
g.thread_elements()[0] = sm < R && sn < R ? Ws[simd_group_id][sm][sn][c] : T(0);
g.thread_elements()[1] = sm < R && sn + 1 < R ? Ws[simd_group_id][sm][sn + 1][c] : T(0);
g.thread_elements()[0] =
sm < R && sn < R ? Ws[simd_group_id][sm][sn][c] : T(0);
g.thread_elements()[1] =
sm < R && sn + 1 < R ? Ws[simd_group_id][sm][sn + 1][c] : T(0);
simdgroup_matrix<T, 8, 8> g_out = (G * g) * Gt;
wt_out_0[c * O] = g_out.thread_elements()[0];
@@ -327,27 +394,23 @@ template <typename T,
wt_out_0 += BC * O;
wt_out_1 += BC * O;
}
}
// Explicitly instantiates winograd_conv_2d_weight_transform for one element
// type and channel block size `bc` (remaining template parameters keep their
// defaults). Host name: "winograd_conv_2d_weight_transform_<name>_bc<bc>".
#define instantiate_winograd_conv_2d_weight_transform_base(name, itype, bc) \
  template [[host_name("winograd_conv_2d_weight_transform_" #name \
                       "_bc" #bc)]] [[kernel]] void \
  winograd_conv_2d_weight_transform<itype, bc>( \
      const device itype* wt_in [[buffer(0)]], \
      device itype* wt_out [[buffer(1)]], \
      const constant int& C [[buffer(2)]], \
      const constant int& O [[buffer(3)]], \
      uint tid [[threadgroup_position_in_grid]], \
      uint simd_group_id [[simdgroup_index_in_threadgroup]], \
      uint simd_lane_id [[thread_index_in_simdgroup]]);
template <typename T,
int BC,
int WM,
int WN,
int M = 6,
int R = 3>
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void winograd_conv_2d_input_transform(
template <typename T, int BC, int WM, int WN, int M = 6, int R = 3>
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void
winograd_conv_2d_input_transform(
const device T* inp_in [[buffer(0)]],
device T* inp_out [[buffer(1)]],
const constant MLXConvParams<2>& params [[buffer(2)]],
@@ -356,7 +419,6 @@ template <typename T,
uint3 tgp_per_grid [[threadgroups_per_grid]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]]) {
(void)lid;
using WGT = WinogradTransforms<M, R, 8>;
@@ -387,46 +449,48 @@ template <typename T,
int bw = M * tid.x + kw;
// Move to the correct input tile
inp_in += tid.z * params.in_strides[0]
+ bh * params.in_strides[1]
+ bw * params.in_strides[2];
inp_in += tid.z * params.in_strides[0] + bh * params.in_strides[1] +
bw * params.in_strides[2];
// Pre compute strides
// Pre compute strides
int jump_in[TH][TW];
for(int h = 0; h < TH; h++) {
for(int w = 0; w < TW; w++) {
jump_in[h][w] = h * params.in_strides[1] + w * params.in_strides[2];
for (int h = 0; h < TH; h++) {
for (int w = 0; w < TW; w++) {
jump_in[h][w] = h * params.in_strides[1] + w * params.in_strides[2];
}
}
// inp_out is stored interleaved (A x A x tiles x C)
size_t N_TILES = tgp_per_grid.x * tgp_per_grid.y * tgp_per_grid.z;
size_t tile_id = tid.z * tgp_per_grid.x * tgp_per_grid.y + tid.y * tgp_per_grid.x + tid.x;
size_t tile_id =
tid.z * tgp_per_grid.x * tgp_per_grid.y + tid.y * tgp_per_grid.x + tid.x;
size_t ohw_0 = sm * 8 + sn;
size_t ohw_1 = sm * 8 + sn + 1;
device T* inp_out_0 = inp_out + ohw_0 * N_TILES * params.C + tile_id * params.C;
device T* inp_out_1 = inp_out + ohw_1 * N_TILES * params.C + tile_id * params.C;
device T* inp_out_0 =
inp_out + ohw_0 * N_TILES * params.C + tile_id * params.C;
device T* inp_out_1 =
inp_out + ohw_1 * N_TILES * params.C + tile_id * params.C;
// Prepare shared memory
threadgroup T Is[A][A][BC];
// Loop over C
for(int bc = 0; bc < params.C; bc += BC) {
for (int bc = 0; bc < params.C; bc += BC) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Read into shared memory
for(int h = 0; h < TH; h++) {
for(int w = 0; w < TW; w++) {
for (int h = 0; h < TH; h++) {
for (int w = 0; w < TW; w++) {
const device T* in_ptr = inp_in + jump_in[h][w];
for(int c = simd_lane_id; c < BC; c += 32) {
for (int c = simd_lane_id; c < BC; c += 32) {
Is[kh + h][kw + w][c] = in_ptr[c];
}
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Do transform and store the result
for(int c = simd_group_id; c < BC; c += N_SIMD_GROUPS) {
// Do transform and store the result
for (int c = simd_group_id; c < BC; c += N_SIMD_GROUPS) {
simdgroup_matrix<T, 8, 8> I;
I.thread_elements()[0] = Is[sm][sn][c];
I.thread_elements()[1] = Is[sm][sn + 1][c];
@@ -440,28 +504,24 @@ template <typename T,
inp_out_0 += BC;
inp_out_1 += BC;
}
}
// Explicitly instantiates winograd_conv_2d_input_transform for one element
// type and channel block size `bc`, with WM = WN = 2.
// Host name: "winograd_conv_2d_input_transform_<name>_bc<bc>".
#define instantiate_winograd_conv_2d_input_transform(name, itype, bc) \
  template [[host_name("winograd_conv_2d_input_transform_" #name \
                       "_bc" #bc)]] [[kernel]] void \
  winograd_conv_2d_input_transform<itype, bc, 2, 2>( \
      const device itype* inp_in [[buffer(0)]], \
      device itype* inp_out [[buffer(1)]], \
      const constant MLXConvParams<2>& params [[buffer(2)]], \
      uint3 tid [[threadgroup_position_in_grid]], \
      uint3 lid [[thread_position_in_threadgroup]], \
      uint3 tgp_per_grid [[threadgroups_per_grid]], \
      uint simd_group_id [[simdgroup_index_in_threadgroup]], \
      uint simd_lane_id [[thread_index_in_simdgroup]]);
template <typename T,
int BO,
int WM,
int WN,
int M = 6,
int R = 3>
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void winograd_conv_2d_output_transform(
template <typename T, int BO, int WM, int WN, int M = 6, int R = 3>
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void
winograd_conv_2d_output_transform(
const device T* out_in [[buffer(0)]],
device T* out_out [[buffer(1)]],
const constant MLXConvParams<2>& params [[buffer(2)]],
@@ -470,7 +530,6 @@ template <typename T,
uint3 tgp_per_grid [[threadgroups_per_grid]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]]) {
(void)lid;
using WGT = WinogradTransforms<M, R, 8>;
@@ -503,57 +562,59 @@ template <typename T,
int bw = M * tid.x + kw;
// Move to the correct input tile
out_out += tid.z * params.out_strides[0]
+ bh * params.out_strides[1]
+ bw * params.out_strides[2];
out_out += tid.z * params.out_strides[0] + bh * params.out_strides[1] +
bw * params.out_strides[2];
// Pre compute strides
// Pre compute strides
int jump_in[TH][TW];
for(int h = 0; h < TH; h++) {
for(int w = 0; w < TW; w++) {
for (int h = 0; h < TH; h++) {
for (int w = 0; w < TW; w++) {
bool valid = ((bh + h) < params.oS[0]) && ((bw + w) < params.oS[1]);
jump_in[h][w] = valid ? h * params.out_strides[1] + w * params.out_strides[2] : -1;
jump_in[h][w] =
valid ? h * params.out_strides[1] + w * params.out_strides[2] : -1;
}
}
// out_in is stored interleaved (A x A x tiles x O)
size_t N_TILES = tgp_per_grid.x * tgp_per_grid.y * tgp_per_grid.z;
size_t tile_id = tid.z * tgp_per_grid.x * tgp_per_grid.y + tid.y * tgp_per_grid.x + tid.x;
size_t tile_id =
tid.z * tgp_per_grid.x * tgp_per_grid.y + tid.y * tgp_per_grid.x + tid.x;
size_t ohw_0 = sm * 8 + sn;
size_t ohw_1 = sm * 8 + sn + 1;
const device T* out_in_0 = out_in + ohw_0 * N_TILES * params.O + tile_id * params.O;
const device T* out_in_1 = out_in + ohw_1 * N_TILES * params.O + tile_id * params.O;
const device T* out_in_0 =
out_in + ohw_0 * N_TILES * params.O + tile_id * params.O;
const device T* out_in_1 =
out_in + ohw_1 * N_TILES * params.O + tile_id * params.O;
// Prepare shared memory
threadgroup T Os[M][M][BO];
// Loop over O
for(int bo = 0; bo < params.O; bo += BO) {
for (int bo = 0; bo < params.O; bo += BO) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Do transform and store the result
for(int c = simd_group_id; c < BO; c += N_SIMD_GROUPS) {
// Do transform and store the result
for (int c = simd_group_id; c < BO; c += N_SIMD_GROUPS) {
simdgroup_matrix<T, 8, 8> O_mat;
O_mat.thread_elements()[0] = out_in_0[c];
O_mat.thread_elements()[1] = out_in_1[c];
simdgroup_matrix<T, 8, 8> O_out = (Bt * (O_mat * B));
if((sm < M) && (sn < M)) {
if ((sm < M) && (sn < M)) {
Os[sm][sn][c] = O_out.thread_elements()[0];
}
if((sm < M) && ((sn + 1) < M)) {
if ((sm < M) && ((sn + 1) < M)) {
Os[sm][sn + 1][c] = O_out.thread_elements()[1];
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Read out from shared memory
for(int h = 0; h < TH; h++) {
for(int w = 0; w < TW; w++) {
if(jump_in[h][w] >= 0) {
for (int h = 0; h < TH; h++) {
for (int w = 0; w < TW; w++) {
if (jump_in[h][w] >= 0) {
device T* out_ptr = out_out + jump_in[h][w];
for(int c = simd_lane_id; c < BO; c += 32) {
for (int c = simd_lane_id; c < BO; c += 32) {
out_ptr[c] = Os[kh + h][kw + w][c];
}
}
@@ -564,25 +625,27 @@ template <typename T,
out_in_0 += BO;
out_in_1 += BO;
}
}
// Explicitly instantiates the kernel
// winograd_conv_2d_output_transform<itype, bo, 2, 2> and exports it under the
// host name "winograd_conv_2d_output_transform_<name>_bo<bo>".
//   name  - token stringized into the host name
//   itype - element type of the out_in / out_out buffers
//   bo    - passed as the BO template argument; also stringized into the name
// The two trailing template arguments (WM, WN) are fixed to 2, 2; M and R keep
// their declared defaults.
// NOTE(review): the signature appears twice below (the same text in its pre-
// and post-clang-format layout); this looks like a diff listing, so confirm
// which copy is live before compiling.
#define instantiate_winograd_conv_2d_output_transform(name, itype, bo) \
template [[host_name("winograd_conv_2d_output_transform_" #name "_bo" #bo)]]\
[[kernel]] void winograd_conv_2d_output_transform<itype, bo, 2, 2>(\
const device itype* out_in [[buffer(0)]],\
device itype* out_out [[buffer(1)]],\
const constant MLXConvParams<2>& params [[buffer(2)]],\
uint3 tid [[threadgroup_position_in_grid]],\
uint3 lid [[thread_position_in_threadgroup]],\
uint3 tgp_per_grid [[threadgroups_per_grid]],\
uint simd_group_id [[simdgroup_index_in_threadgroup]],\
template [[host_name("winograd_conv_2d_output_transform_" #name \
"_bo" #bo)]] [[kernel]] void \
winograd_conv_2d_output_transform<itype, bo, 2, 2>( \
const device itype* out_in [[buffer(0)]], \
device itype* out_out [[buffer(1)]], \
const constant MLXConvParams<2>& params [[buffer(2)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint3 tgp_per_grid [[threadgroups_per_grid]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]]);
#define instantiate_winograd_conv_2d(name, itype) \
// clang-format off
#define instantiate_winograd_conv_2d(name, itype) \
instantiate_winograd_conv_2d_weight_transform_base(name, itype, 32) \
instantiate_winograd_conv_2d_input_transform(name, itype, 32) \
instantiate_winograd_conv_2d_output_transform(name, itype, 32)
instantiate_winograd_conv_2d_input_transform(name, itype, 32) \
instantiate_winograd_conv_2d_output_transform(name, itype, 32) // clang-format on
// clang-format off
instantiate_winograd_conv_2d(float32, float);
instantiate_winograd_conv_2d(float16, half);
instantiate_winograd_conv_2d(float16, half); // clang-format on

View File

@@ -49,7 +49,8 @@ template <typename T, typename U>
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto src_idx = elem_to_loc_3(index, src_strides);
int64_t dst_idx = index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
int64_t dst_idx =
index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
@@ -62,7 +63,8 @@ template <typename T, typename U, int DIM>
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto src_idx = elem_to_loc_nd<DIM>(index, src_shape, src_strides);
int64_t dst_idx = index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
int64_t dst_idx =
index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
@@ -76,7 +78,8 @@ template <typename T, typename U>
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto src_idx = elem_to_loc(index, src_shape, src_strides, ndim);
int64_t dst_idx = index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
int64_t dst_idx =
index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
@@ -143,116 +146,110 @@ template <typename T, typename U>
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
// Explicitly instantiates copy_<ctype><itype, otype> (ctype is token-pasted
// onto `copy_`) and exports it under the host name `name`.
//   name  - string literal used verbatim as the host name
//   itype - element type of the source buffer (buffer 0)
//   otype - element type of the destination buffer (buffer 1)
//   ctype - kernel-name suffix selecting which copy_* template to instantiate
// NOTE(review): the parameter list appears twice below (the same text in its
// pre- and post-clang-format layout); this looks like a diff listing, so
// confirm which copy is live before compiling.
#define instantiate_copy(name, itype, otype, ctype) \
template [[host_name(name)]] \
[[kernel]] void copy_##ctype<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
#define instantiate_copy(name, itype, otype, ctype) \
template [[host_name(name)]] [[kernel]] void copy_##ctype<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
uint index [[thread_position_in_grid]]);
// Explicitly instantiates the fixed-rank general copy kernels for `dims`
// dimensions:
//   copy_g_nd<itype, otype, dims>   exported as  name "_" #dims
//   copy_gg_nd<itype, otype, dims>  exported as  "g" name "_" #dims
// Both take src_shape (buffer 2) and src_strides (buffer 3); only the copy_gg
// variant additionally takes dst_strides (buffer 4), while the copy_g variant
// takes grid_dim instead.
// NOTE(review): each signature appears twice below (the same text in its pre-
// and post-clang-format layout); this looks like a diff listing, so confirm
// which copy is live before compiling.
#define instantiate_copy_g_dim(name, itype, otype, dims) \
template [[host_name(name "_" #dims)]] \
[[kernel]] void copy_g_nd<itype, otype, dims>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int* src_shape [[buffer(2)]], \
constant const int64_t* src_strides [[buffer(3)]], \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]); \
template [[host_name("g" name "_" #dims)]] \
[[kernel]] void copy_gg_nd<itype, otype, dims>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int* src_shape [[buffer(2)]], \
constant const int64_t* src_strides [[buffer(3)]], \
constant const int64_t* dst_strides [[buffer(4)]], \
#define instantiate_copy_g_dim(name, itype, otype, dims) \
template [[host_name(name "_" #dims)]] [[kernel]] void \
copy_g_nd<itype, otype, dims>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int* src_shape [[buffer(2)]], \
constant const int64_t* src_strides [[buffer(3)]], \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]); \
template [[host_name("g" name "_" #dims)]] [[kernel]] void \
copy_gg_nd<itype, otype, dims>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int* src_shape [[buffer(2)]], \
constant const int64_t* src_strides [[buffer(3)]], \
constant const int64_t* dst_strides [[buffer(4)]], \
uint3 index [[thread_position_in_grid]]);
// Explicitly instantiates the specialized 1-, 2- and 3-dimensional general
// copy kernels and then the generic fixed-rank ones for 4 and 5 dimensions:
//   copy_g_nd1 / copy_g_nd2 / copy_g_nd3     exported as  name "_1" .. "_3"
//   copy_gg_nd1 / copy_gg_nd2 / copy_gg_nd3  exported as  "g" name "_1" .. "_3"
//   instantiate_copy_g_dim(name, itype, otype, 4) and (..., 5)
// The 1-3 dimensional variants declare no buffer(2) argument (no src_shape);
// strides are bound at buffer(3) (source) and buffer(4) (destination, copy_gg
// only). The 1-D variants take the stride by reference rather than pointer.
// NOTE(review): the whole macro appears twice below (the same text in two
// clang-format layouts); this looks like a diff listing, so confirm which copy
// is live before compiling.
#define instantiate_copy_g_nd(name, itype, otype) \
template [[host_name(name "_1")]] [[kernel]] void copy_g_nd1<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int64_t& src_stride [[buffer(3)]], \
uint index [[thread_position_in_grid]]); \
template [[host_name(name "_2")]] [[kernel]] void copy_g_nd2<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int64_t* src_strides [[buffer(3)]], \
uint2 index [[thread_position_in_grid]], \
uint2 grid_dim [[threads_per_grid]]); \
template [[host_name(name "_3")]] [[kernel]] void copy_g_nd3<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int64_t* src_strides [[buffer(3)]], \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]); \
template [[host_name("g" name "_1")]] [[kernel]] void \
copy_gg_nd1<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int64_t& src_stride [[buffer(3)]], \
constant const int64_t& dst_stride [[buffer(4)]], \
uint index [[thread_position_in_grid]]); \
template [[host_name("g" name "_2")]] [[kernel]] void \
copy_gg_nd2<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int64_t* src_strides [[buffer(3)]], \
constant const int64_t* dst_strides [[buffer(4)]], \
uint2 index [[thread_position_in_grid]]); \
template [[host_name("g" name "_3")]] [[kernel]] void \
copy_gg_nd3<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int64_t* src_strides [[buffer(3)]], \
constant const int64_t* dst_strides [[buffer(4)]], \
uint3 index [[thread_position_in_grid]]); \
instantiate_copy_g_dim(name, itype, otype, 4) \
instantiate_copy_g_dim(name, itype, otype, 5)
#define instantiate_copy_g_nd(name, itype, otype) \
template [[host_name(name "_1")]] \
[[kernel]] void copy_g_nd1<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int64_t& src_stride [[buffer(3)]], \
uint index [[thread_position_in_grid]]); \
template [[host_name(name "_2")]] \
[[kernel]] void copy_g_nd2<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int64_t* src_strides [[buffer(3)]], \
uint2 index [[thread_position_in_grid]], \
uint2 grid_dim [[threads_per_grid]]); \
template [[host_name(name "_3")]] \
[[kernel]] void copy_g_nd3<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int64_t* src_strides [[buffer(3)]], \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]); \
template [[host_name("g" name "_1")]] \
[[kernel]] void copy_gg_nd1<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int64_t& src_stride [[buffer(3)]], \
constant const int64_t& dst_stride [[buffer(4)]], \
uint index [[thread_position_in_grid]]); \
template [[host_name("g" name "_2")]] \
[[kernel]] void copy_gg_nd2<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int64_t* src_strides [[buffer(3)]], \
constant const int64_t* dst_strides [[buffer(4)]], \
uint2 index [[thread_position_in_grid]]); \
template [[host_name("g" name "_3")]] \
[[kernel]] void copy_gg_nd3<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int64_t* src_strides [[buffer(3)]], \
constant const int64_t* dst_strides [[buffer(4)]], \
uint3 index [[thread_position_in_grid]]); \
instantiate_copy_g_dim(name, itype, otype, 4) \
instantiate_copy_g_dim(name, itype, otype, 5)
// Explicitly instantiates the runtime-rank general copy kernels:
//   copy_g<itype, otype>   exported as  name
//   copy_gg<itype, otype>  exported as  "g" name
// Both receive src_shape (buffer 2), src_strides (buffer 3) and the number of
// dimensions `ndim` (buffer 5). copy_gg additionally receives dst_strides
// (buffer 4); copy_g receives grid_dim instead and leaves buffer 4 unbound.
// NOTE(review): each signature appears twice below (the same text in its pre-
// and post-clang-format layout); this looks like a diff listing, so confirm
// which copy is live before compiling.
#define instantiate_copy_g(name, itype, otype) \
template [[host_name(name)]] \
[[kernel]] void copy_g<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int* src_shape [[buffer(2)]], \
constant const int64_t* src_strides [[buffer(3)]], \
constant const int& ndim [[buffer(5)]], \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]); \
template [[host_name("g" name)]] \
[[kernel]] void copy_gg<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int* src_shape [[buffer(2)]], \
constant const int64_t* src_strides [[buffer(3)]], \
constant const int64_t* dst_strides [[buffer(4)]], \
constant const int& ndim [[buffer(5)]], \
#define instantiate_copy_g(name, itype, otype) \
template [[host_name(name)]] [[kernel]] void copy_g<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int* src_shape [[buffer(2)]], \
constant const int64_t* src_strides [[buffer(3)]], \
constant const int& ndim [[buffer(5)]], \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]); \
template [[host_name("g" name)]] [[kernel]] void copy_gg<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int* src_shape [[buffer(2)]], \
constant const int64_t* src_strides [[buffer(3)]], \
constant const int64_t* dst_strides [[buffer(4)]], \
constant const int& ndim [[buffer(5)]], \
uint3 index [[thread_position_in_grid]]);
#define instantiate_copy_all(tname, itype, otype) \
// clang-format off
#define instantiate_copy_all(tname, itype, otype) \
instantiate_copy("scopy" #tname, itype, otype, s) \
instantiate_copy("vcopy" #tname, itype, otype, v) \
instantiate_copy_g("gcopy" #tname, itype, otype) \
instantiate_copy_g_nd("gcopy" #tname, itype, otype)
instantiate_copy_g("gcopy" #tname, itype, otype) \
instantiate_copy_g_nd("gcopy" #tname, itype, otype) // clang-format on
#define instantiate_copy_itype(itname, itype) \
instantiate_copy_all(itname ##bool_, itype, bool) \
instantiate_copy_all(itname ##uint8, itype, uint8_t) \
instantiate_copy_all(itname ##uint16, itype, uint16_t) \
instantiate_copy_all(itname ##uint32, itype, uint32_t) \
instantiate_copy_all(itname ##uint64, itype, uint64_t) \
instantiate_copy_all(itname ##int8, itype, int8_t) \
instantiate_copy_all(itname ##int16, itype, int16_t) \
instantiate_copy_all(itname ##int32, itype, int32_t) \
instantiate_copy_all(itname ##int64, itype, int64_t) \
instantiate_copy_all(itname ##float16, itype, half) \
instantiate_copy_all(itname ##float32, itype, float) \
// clang-format off
#define instantiate_copy_itype(itname, itype) \
instantiate_copy_all(itname ##bool_, itype, bool) \
instantiate_copy_all(itname ##uint8, itype, uint8_t) \
instantiate_copy_all(itname ##uint16, itype, uint16_t) \
instantiate_copy_all(itname ##uint32, itype, uint32_t) \
instantiate_copy_all(itname ##uint64, itype, uint64_t) \
instantiate_copy_all(itname ##int8, itype, int8_t) \
instantiate_copy_all(itname ##int16, itype, int16_t) \
instantiate_copy_all(itname ##int32, itype, int32_t) \
instantiate_copy_all(itname ##int64, itype, int64_t) \
instantiate_copy_all(itname ##float16, itype, half) \
instantiate_copy_all(itname ##float32, itype, float) \
instantiate_copy_all(itname ##bfloat16, itype, bfloat16_t) \
instantiate_copy_all(itname ##complex64, itype, complex64_t)
@@ -268,4 +265,4 @@ instantiate_copy_itype(int64, int64_t)
instantiate_copy_itype(float16, half)
instantiate_copy_itype(float32, float)
instantiate_copy_itype(bfloat16, bfloat16_t)
instantiate_copy_itype(complex64, complex64_t)
instantiate_copy_itype(complex64, complex64_t) // clang-format on

View File

@@ -6,9 +6,8 @@
// - VkFFT (https://github.com/DTolm/VkFFT)
// - Eric Bainville's excellent page (http://www.bealto.com/gpu-fft.html)
#include <metal_math>
#include <metal_common>
#include <metal_math>
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/utils.h"
@@ -23,7 +22,7 @@ float2 complex_mul(float2 a, float2 b) {
}
float2 get_twiddle(int k, int p) {
float theta = -1.0f * k * M_PI_F / (2*p);
float theta = -1.0f * k * M_PI_F / (2 * p);
float2 twiddle;
twiddle.x = metal::fast::cos(theta);
@@ -32,7 +31,12 @@ float2 get_twiddle(int k, int p) {
}
// Single-threaded radix-2 implementation
void radix2(int i, int p, int m, threadgroup float2* read_buf, threadgroup float2* write_buf) {
void radix2(
int i,
int p,
int m,
threadgroup float2* read_buf,
threadgroup float2* write_buf) {
float2 x_0 = read_buf[i];
float2 x_1 = read_buf[i + m];
@@ -53,11 +57,16 @@ void radix2(int i, int p, int m, threadgroup float2* read_buf, threadgroup float
}
// Single-threaded radix-4 implementation
void radix4(int i, int p, int m, threadgroup float2* read_buf, threadgroup float2* write_buf) {
void radix4(
int i,
int p,
int m,
threadgroup float2* read_buf,
threadgroup float2* write_buf) {
float2 x_0 = read_buf[i];
float2 x_1 = read_buf[i + m];
float2 x_2 = read_buf[i + 2*m];
float2 x_3 = read_buf[i + 3*m];
float2 x_2 = read_buf[i + 2 * m];
float2 x_3 = read_buf[i + 3 * m];
// The index within this sub-DFT
int k = i & (p - 1);
@@ -90,11 +99,10 @@ void radix4(int i, int p, int m, threadgroup float2* read_buf, threadgroup float
write_buf[j] = y_0;
write_buf[j + p] = y_1;
write_buf[j + 2*p] = y_2;
write_buf[j + 3*p] = y_3;
write_buf[j + 2 * p] = y_2;
write_buf[j + 3 * p] = y_3;
}
// Each FFT is computed entirely in shared GPU memory.
//
// N is decomposed into radix-2 and radix-4 DFTs:
@@ -107,11 +115,10 @@ void radix4(int i, int p, int m, threadgroup float2* read_buf, threadgroup float
// steps at compile time for a ~20% performance boost.
template <size_t n, size_t radix_2_steps, size_t radix_4_steps>
[[kernel]] void fft(
const device float2 *in [[buffer(0)]],
device float2 * out [[buffer(1)]],
const device float2* in [[buffer(0)]],
device float2* out [[buffer(1)]],
uint3 thread_position_in_grid [[thread_position_in_grid]],
uint3 threads_per_grid [[threads_per_grid]]) {
// Index of the DFT in batch
int batch_idx = thread_position_in_grid.x * n;
// The index in the DFT we're working on
@@ -132,16 +139,16 @@ template <size_t n, size_t radix_2_steps, size_t radix_4_steps>
// Copy input into shared memory
shared_in[i] = in[batch_idx + i];
shared_in[i + m] = in[batch_idx + i + m];
shared_in[i + 2*m] = in[batch_idx + i + 2*m];
shared_in[i + 3*m] = in[batch_idx + i + 3*m];
shared_in[i + 2 * m] = in[batch_idx + i + 2 * m];
shared_in[i + 3 * m] = in[batch_idx + i + 3 * m];
threadgroup_barrier(mem_flags::mem_threadgroup);
int p = 1;
for (size_t r = 0; r < radix_2_steps; r++) {
radix2(i, p, m*2, read_buf, write_buf);
radix2(i + m, p, m*2, read_buf, write_buf);
radix2(i, p, m * 2, read_buf, write_buf);
radix2(i + m, p, m * 2, read_buf, write_buf);
p *= 2;
threadgroup_barrier(mem_flags::mem_threadgroup);
@@ -167,29 +174,26 @@ template <size_t n, size_t radix_2_steps, size_t radix_4_steps>
// Copy shared memory to output
out[batch_idx + i] = read_buf[i];
out[batch_idx + i + m] = read_buf[i + m];
out[batch_idx + i + 2*m] = read_buf[i + 2*m];
out[batch_idx + i + 3*m] = read_buf[i + 3*m];
out[batch_idx + i + 2 * m] = read_buf[i + 2 * m];
out[batch_idx + i + 3 * m] = read_buf[i + 3 * m];
}
// Explicitly instantiates fft<n, radix_2_steps, radix_4_steps> and exports it
// under the host name "fft_<name>".
//   name          - token stringized into the host name
//   n             - first template argument of fft (the kernel multiplies the
//                   batch index by it to locate each DFT)
//   radix_2_steps - second template argument; the kernel runs this many
//                   radix-2 passes
//   radix_4_steps - third template argument
// NOTE(review): the macro appears twice below (the same text in its pre- and
// post-clang-format layout); this looks like a diff listing, so confirm which
// copy is live before compiling.
#define instantiate_fft(name, n, radix_2_steps, radix_4_steps) \
template [[host_name("fft_" #name)]] \
[[kernel]] void fft<n, radix_2_steps, radix_4_steps>( \
const device float2* in [[buffer(0)]], \
device float2* out [[buffer(1)]], \
uint3 thread_position_in_grid [[thread_position_in_grid]], \
uint3 threads_per_grid [[threads_per_grid]]);
#define instantiate_fft(name, n, radix_2_steps, radix_4_steps) \
template [[host_name("fft_" #name)]] [[kernel]] void \
fft<n, radix_2_steps, radix_4_steps>( \
const device float2* in [[buffer(0)]], \
device float2* out [[buffer(1)]], \
uint3 thread_position_in_grid [[thread_position_in_grid]], \
uint3 threads_per_grid [[threads_per_grid]]);
// Explicitly define kernels for each power of 2.
// clang-format off
instantiate_fft(4, /* n= */ 4, /* radix_2_steps= */ 0, /* radix_4_steps= */ 1)
instantiate_fft(8, 8, 1, 1)
instantiate_fft(16, 16, 0, 2)
instantiate_fft(32, 32, 1, 2)
instantiate_fft(64, 64, 0, 3)
instantiate_fft(128, 128, 1, 3)
instantiate_fft(256, 256, 0, 4)
instantiate_fft(8, 8, 1, 1) instantiate_fft(16, 16, 0, 2)
instantiate_fft(32, 32, 1, 2) instantiate_fft(64, 64, 0, 3)
instantiate_fft(128, 128, 1, 3) instantiate_fft(256, 256, 0, 4)
instantiate_fft(512, 512, 1, 4)
instantiate_fft(1024, 1024, 0, 5)
// 2048 is the max that will fit into 32KB of threadgroup memory.
// TODO: implement 4 step FFT for larger n.
instantiate_fft(2048, 2048, 1, 5)
instantiate_fft(2048, 2048, 1, 5) // clang-format on

View File

@@ -14,17 +14,16 @@ using namespace metal;
template <typename T, typename IdxT, int NIDX, int IDX_NDIM>
METAL_FUNC void gather_impl(
const device T *src [[buffer(0)]],
device T *out [[buffer(1)]],
const constant int *src_shape [[buffer(2)]],
const constant size_t *src_strides [[buffer(3)]],
const device T* src [[buffer(0)]],
device T* out [[buffer(1)]],
const constant int* src_shape [[buffer(2)]],
const constant size_t* src_strides [[buffer(3)]],
const constant size_t& src_ndim [[buffer(4)]],
const constant int *slice_sizes [[buffer(5)]],
const constant int *axes [[buffer(6)]],
const constant int* slice_sizes [[buffer(5)]],
const constant int* axes [[buffer(6)]],
const thread Indices<IdxT, NIDX>& indices,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto ind_idx = index.x;
auto ind_offset = index.y;
@@ -43,93 +42,78 @@ METAL_FUNC void gather_impl(
indices.ndim);
}
auto ax = axes[i];
auto idx_val = offset_neg_idx(
indices.buffers[i][idx_loc], src_shape[ax]);
auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], src_shape[ax]);
src_idx += idx_val * src_strides[ax];
}
auto src_offset = elem_to_loc(
ind_offset, slice_sizes, src_strides, src_ndim);
auto src_offset = elem_to_loc(ind_offset, slice_sizes, src_strides, src_ndim);
size_t out_idx = index.y + static_cast<size_t>(grid_dim.y) * index.x;
out[out_idx] = src[src_offset + src_idx];
}
// Defines one overload of the `gather` kernel template.
//   IDX_ARG - function-like macro; IDX_ARG(IdxT) is expanded inside the
//             parameter list, immediately before `index`
//   IDX_ARR - function-like macro; IDX_ARR() is expanded as the initializer of
//             the first member of Indices<IdxT, NIDX>
// The generated kernel bundles the IDX_ARR() expansion together with
// idx_shapes, idx_strides and idx_ndim into an Indices<IdxT, NIDX> value and
// forwards everything to gather_impl<T, IdxT, NIDX, IDX_NDIM>.
// NOTE(review): IDX_ARG_* / IDX_ARR_* are defined outside this view;
// presumably they declare and collect the per-index buffer arguments — verify
// against their definitions.
// NOTE(review): the macro appears twice below (the same text in its pre- and
// post-clang-format layout); this looks like a diff listing, so confirm which
// copy is live before compiling.
#define make_gather_impl(IDX_ARG, IDX_ARR) \
template <typename T, typename IdxT, int NIDX, int IDX_NDIM> \
[[kernel]] void gather( \
const device T *src [[buffer(0)]], \
device T *out [[buffer(1)]], \
const constant int *src_shape [[buffer(2)]], \
const constant size_t *src_strides [[buffer(3)]], \
const constant size_t& src_ndim [[buffer(4)]], \
const constant int *slice_sizes [[buffer(5)]], \
const constant int *axes [[buffer(6)]], \
const constant int *idx_shapes [[buffer(7)]], \
const constant size_t *idx_strides [[buffer(8)]], \
const constant int& idx_ndim [[buffer(9)]], \
IDX_ARG(IdxT) \
uint2 index [[thread_position_in_grid]], \
uint2 grid_dim [[threads_per_grid]]) { \
\
Indices<IdxT, NIDX> idxs{ \
{{IDX_ARR()}}, \
idx_shapes, \
idx_strides, \
idx_ndim}; \
\
return gather_impl<T, IdxT, NIDX, IDX_NDIM>( \
src, \
out, \
src_shape, \
src_strides, \
src_ndim, \
slice_sizes, \
axes, \
idxs, \
index, \
grid_dim); \
}
#define make_gather_impl(IDX_ARG, IDX_ARR) \
template <typename T, typename IdxT, int NIDX, int IDX_NDIM> \
[[kernel]] void gather( \
const device T* src [[buffer(0)]], \
device T* out [[buffer(1)]], \
const constant int* src_shape [[buffer(2)]], \
const constant size_t* src_strides [[buffer(3)]], \
const constant size_t& src_ndim [[buffer(4)]], \
const constant int* slice_sizes [[buffer(5)]], \
const constant int* axes [[buffer(6)]], \
const constant int* idx_shapes [[buffer(7)]], \
const constant size_t* idx_strides [[buffer(8)]], \
const constant int& idx_ndim [[buffer(9)]], \
IDX_ARG(IdxT) uint2 index [[thread_position_in_grid]], \
uint2 grid_dim [[threads_per_grid]]) { \
Indices<IdxT, NIDX> idxs{ \
{{IDX_ARR()}}, idx_shapes, idx_strides, idx_ndim}; \
\
return gather_impl<T, IdxT, NIDX, IDX_NDIM>( \
src, \
out, \
src_shape, \
src_strides, \
src_ndim, \
slice_sizes, \
axes, \
idxs, \
index, \
grid_dim); \
}
#define make_gather(n) make_gather_impl(IDX_ARG_ ##n, IDX_ARR_ ##n)
#define make_gather(n) make_gather_impl(IDX_ARG_##n, IDX_ARR_##n)
make_gather(0)
make_gather(1)
make_gather(2)
make_gather(3)
make_gather(4)
make_gather(5)
make_gather(6)
make_gather(7)
make_gather(8)
make_gather(9)
make_gather(10)
make_gather(0) make_gather(1) make_gather(2) make_gather(3) make_gather(4)
make_gather(5) make_gather(6) make_gather(7) make_gather(8) make_gather(9)
make_gather(10)
/////////////////////////////////////////////////////////////////////
// Gather instantiations
/////////////////////////////////////////////////////////////////////
// Explicitly instantiates gather<src_t, idx_t, nidx, nd> and exports it under
// the host name "gather" name "_" #nidx #nd_name.
//   name    - string literal inserted after "gather" in the host name
//   src_t   - element type of both the src and out buffers
//   idx_t   - index element type, passed to IDX_ARG to expand the index
//             parameters
//   nidx    - third template argument (NIDX); stringized into the host name
//   IDX_ARG - function-like macro expanded inside the parameter list
//   nd      - fourth template argument (IDX_NDIM)
//   nd_name - token stringized as the final host-name suffix
// NOTE(review): the macro appears twice below (the same text in its pre- and
// post-clang-format layout); this looks like a diff listing, so confirm which
// copy is live before compiling.
#define instantiate_gather6(name, src_t, idx_t, nidx, IDX_ARG, nd, nd_name) \
template [[host_name("gather" name "_" #nidx "" #nd_name)]] \
[[kernel]] void gather<src_t, idx_t, nidx, nd>( \
const device src_t *src [[buffer(0)]], \
device src_t *out [[buffer(1)]], \
const constant int *src_shape [[buffer(2)]], \
const constant size_t *src_strides [[buffer(3)]], \
const constant size_t& src_ndim [[buffer(4)]], \
const constant int *slice_sizes [[buffer(5)]], \
const constant int *axes [[buffer(6)]], \
const constant int *idx_shapes [[buffer(7)]], \
const constant size_t *idx_strides [[buffer(8)]], \
const constant int& idx_ndim [[buffer(9)]], \
IDX_ARG(idx_t) \
uint2 index [[thread_position_in_grid]], \
uint2 grid_dim [[threads_per_grid]]);
#define instantiate_gather6(name, src_t, idx_t, nidx, IDX_ARG, nd, nd_name) \
template [[host_name("gather" name "_" #nidx "" #nd_name)]] [[kernel]] void \
gather<src_t, idx_t, nidx, nd>( \
const device src_t* src [[buffer(0)]], \
device src_t* out [[buffer(1)]], \
const constant int* src_shape [[buffer(2)]], \
const constant size_t* src_strides [[buffer(3)]], \
const constant size_t& src_ndim [[buffer(4)]], \
const constant int* slice_sizes [[buffer(5)]], \
const constant int* axes [[buffer(6)]], \
const constant int* idx_shapes [[buffer(7)]], \
const constant size_t* idx_strides [[buffer(8)]], \
const constant int& idx_ndim [[buffer(9)]], \
IDX_ARG(idx_t) uint2 index [[thread_position_in_grid]], \
uint2 grid_dim [[threads_per_grid]]);
// clang-format off
#define instantiate_gather5(name, src_t, idx_t, nidx, nd, nd_name) \
instantiate_gather6(name, src_t, idx_t, nidx, IDX_ARG_ ##nidx, nd, nd_name)
instantiate_gather6(name, src_t, idx_t, nidx, IDX_ARG_ ##nidx, nd, nd_name) // clang-format on
// clang-format off
#define instantiate_gather4(name, src_t, idx_t, nidx) \
instantiate_gather5(name, src_t, idx_t, nidx, 0, _0) \
instantiate_gather5(name, src_t, idx_t, nidx, 1, _1) \
@@ -148,29 +132,31 @@ instantiate_gather4("int32", int32_t, bool, 0)
instantiate_gather4("int64", int64_t, bool, 0)
instantiate_gather4("float16", half, bool, 0)
instantiate_gather4("float32", float, bool, 0)
instantiate_gather4("bfloat16", bfloat16_t, bool, 0)
instantiate_gather4("bfloat16", bfloat16_t, bool, 0) // clang-format on
// clang-format off
#define instantiate_gather3(name, src_type, ind_type) \
instantiate_gather4(name, src_type, ind_type, 1) \
instantiate_gather4(name, src_type, ind_type, 2) \
instantiate_gather4(name, src_type, ind_type, 3) \
instantiate_gather4(name, src_type, ind_type, 4) \
instantiate_gather4(name, src_type, ind_type, 5) \
instantiate_gather4(name, src_type, ind_type, 6) \
instantiate_gather4(name, src_type, ind_type, 7) \
instantiate_gather4(name, src_type, ind_type, 8) \
instantiate_gather4(name, src_type, ind_type, 9) \
instantiate_gather4(name, src_type, ind_type, 10)
instantiate_gather4(name, src_type, ind_type, 1) \
instantiate_gather4(name, src_type, ind_type, 2) \
instantiate_gather4(name, src_type, ind_type, 3) \
instantiate_gather4(name, src_type, ind_type, 4) \
instantiate_gather4(name, src_type, ind_type, 5) \
instantiate_gather4(name, src_type, ind_type, 6) \
instantiate_gather4(name, src_type, ind_type, 7) \
instantiate_gather4(name, src_type, ind_type, 8) \
instantiate_gather4(name, src_type, ind_type, 9) \
instantiate_gather4(name, src_type, ind_type, 10) // clang-format on
#define instantiate_gather(name, src_type) \
instantiate_gather3(#name "bool_", src_type, bool) \
instantiate_gather3(#name "uint8", src_type, uint8_t) \
// clang-format off
#define instantiate_gather(name, src_type) \
instantiate_gather3(#name "bool_", src_type, bool) \
instantiate_gather3(#name "uint8", src_type, uint8_t) \
instantiate_gather3(#name "uint16", src_type, uint16_t) \
instantiate_gather3(#name "uint32", src_type, uint32_t) \
instantiate_gather3(#name "uint64", src_type, uint64_t) \
instantiate_gather3(#name "int8", src_type, int8_t) \
instantiate_gather3(#name "int16", src_type, int16_t) \
instantiate_gather3(#name "int32", src_type, int32_t) \
instantiate_gather3(#name "int8", src_type, int8_t) \
instantiate_gather3(#name "int16", src_type, int16_t) \
instantiate_gather3(#name "int32", src_type, int32_t) \
instantiate_gather3(#name "int64", src_type, int64_t)
instantiate_gather(bool_, bool)
@@ -184,4 +170,4 @@ instantiate_gather(int32, int32_t)
instantiate_gather(int64, int64_t)
instantiate_gather(float16, half)
instantiate_gather(float32, float)
instantiate_gather(bfloat16, bfloat16_t)
instantiate_gather(bfloat16, bfloat16_t) // clang-format on

View File

@@ -1,12 +1,14 @@
// Copyright © 2023-2024 Apple Inc.
#include <metal_stdlib>
#include <metal_simdgroup>
#include <metal_stdlib>
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/steel/utils.h"
using namespace metal;
///////////////////////////////////////////////////////////////////////////////
@@ -18,33 +20,34 @@ using namespace metal;
MLX_MTL_CONST int SIMD_SIZE = 32;
template <
typename T,
const int BM, /* Threadgroup rows (in threads) */
const int BN, /* Threadgroup cols (in threads) */
const int TM, /* Thread rows (in elements) */
const int TN , /* Thread cols (in elements) */
const bool kDoAxpby> /* Do out = alpha * out + beta * bias */
typename T,
const int BM, /* Threadgroup rows (in threads) */
const int BN, /* Threadgroup cols (in threads) */
const int TM, /* Thread rows (in elements) */
const int TN, /* Thread cols (in elements) */
const bool kDoAxpby> /* Do out = alpha * out + beta * bias */
struct GEMVKernel {
static_assert(BN == SIMD_SIZE, "gemv block must have a width of SIMD_SIZE");
// - The matrix of size (M = out_vec_size, N = in_vec_size) is divided up
// - The matrix of size (M = out_vec_size, N = in_vec_size) is divided up
// into blocks of (BM * TM, BN * TN) divided among threadgroups
// - Every thread works on a block of (TM, TN)
// - We assume each threadgroup is launched with (BN, BM, 1) threads
//
// 1. A thread loads TN elements each from mat along TM contiguous rows
// and the corresponding scalar from the vector
// 2. The thread then multiplies and adds to accumulate its local result for the block
// 3. At the end, each thread has accumulated results over all blocks across the rows
// These are then summed up across the threadgroup
// 1. A thread loads TN elements each from mat along TM contiguous rows
// and the corresponding scalar from the vector
// 2. The thread then multiplies and adds to accumulate its local result for
// the block
// 3. At the end, each thread has accumulated results over all blocks across
// the rows. These are then summed up across the threadgroup
// 4. Each threadgroup writes its accumulated BM * TM outputs
//
// Edge case handling:
// - The threadgroup with the largest tid will have blocks that exceed the matrix
// * The blocks that start outside the matrix are never read (thread results remain zero)
// * The last thread that partially overlaps with the matrix is shifted inwards
// such that the thread block fits exactly in the matrix
// - The threadgroup with the largest tid has blocks that exceed the matrix
// * The blocks that start outside the matrix are never read (thread results
// remain zero)
// * The last thread that partially overlaps with the matrix is shifted
// inwards such that the thread block fits exactly in the matrix
MLX_MTL_CONST short tgp_mem_size = BN * TN * 2;
@@ -52,7 +55,7 @@ struct GEMVKernel {
const device T* mat [[buffer(0)]],
const device T* in_vec [[buffer(1)]],
const device T* bias [[buffer(2)]],
device T* out_vec [[buffer(3)]],
device T* out_vec [[buffer(3)]],
const constant int& in_vec_size [[buffer(4)]],
const constant int& out_vec_size [[buffer(5)]],
const constant int& marix_ld [[buffer(6)]],
@@ -64,14 +67,13 @@ struct GEMVKernel {
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
// Appease compiler
// Appease compiler
(void)lid;
// Threadgroup in_vec cache
threadgroup T* in_vec_block = tgp_memory + simd_lid * TN * 2;
// Thread local accumulation results
// Thread local accumulation results
thread T result[TM] = {0};
thread T inter[TN];
thread T v_coeff[TN];
@@ -80,7 +82,7 @@ struct GEMVKernel {
int out_row = (tid.x * BM + simd_gid) * TM;
// Exit simdgroup if rows out of bound
if(out_row >= out_vec_size)
if (out_row >= out_vec_size)
return;
// Adjust tail simdgroup to ensure in bound reads
@@ -90,89 +92,81 @@ struct GEMVKernel {
mat += out_row * marix_ld;
// Loop over in_vec in blocks of BN * TN
for(int bn = simd_lid * TN; bn < in_vec_size; bn += BN * TN) {
for (int bn = simd_lid * TN; bn < in_vec_size; bn += BN * TN) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Prefetch in_vector for threadgroup use
if(simd_gid == 0) {
if (simd_gid == 0) {
// Main load loop
if(bn + TN <= in_vec_size) {
#pragma clang loop unroll(full)
for(int tn = 0; tn < TN; tn++) {
if (bn + TN <= in_vec_size) {
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
in_vec_block[tn] = in_vec[bn + tn];
}
} else { // Edgecase
#pragma clang loop unroll(full)
for(int tn = 0; tn < TN; tn++) {
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
in_vec_block[tn] = bn + tn < in_vec_size ? in_vec[bn + tn] : 0;
}
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load for all rows
#pragma clang loop unroll(full)
for(int tn = 0; tn < TN; tn++) {
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
v_coeff[tn] = in_vec_block[tn];
}
// Per thread work loop
#pragma clang loop unroll(full)
for(int tm = 0; tm < TM; tm++) {
// Load for the row
if(bn + TN <= in_vec_size) {
#pragma clang loop unroll(full)
for(int tn = 0; tn < TN; tn++) {
MLX_MTL_PRAGMA_UNROLL
for (int tm = 0; tm < TM; tm++) {
// Load for the row
if (bn + TN <= in_vec_size) {
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
inter[tn] = mat[tm * marix_ld + bn + tn];
}
} else { // Edgecase
#pragma clang loop unroll(full)
for(int tn = 0; tn < TN; tn++) {
int col_idx = (bn + tn) < in_vec_size ? (bn + tn) : (in_vec_size - 1);
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
int col_idx =
(bn + tn) < in_vec_size ? (bn + tn) : (in_vec_size - 1);
inter[tn] = mat[tm * marix_ld + col_idx];
}
}
// Accumulate results
for(int tn = 0; tn < TN; tn++) {
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
result[tm] += inter[tn] * v_coeff[tn];
}
}
}
// Simdgroup accumulations
#pragma clang loop unroll(full)
for(int tm = 0; tm < TM; tm++) {
MLX_MTL_PRAGMA_UNROLL
for (int tm = 0; tm < TM; tm++) {
result[tm] = simd_sum(result[tm]);
}
// Write outputs
if(simd_lid == 0) {
#pragma clang loop unroll(full)
for(int tm = 0; tm < TM; tm++) {
if(kDoAxpby) {
out_vec[out_row + tm] =
static_cast<T>(alpha) * result[tm] +
if (simd_lid == 0) {
MLX_MTL_PRAGMA_UNROLL
for (int tm = 0; tm < TM; tm++) {
if (kDoAxpby) {
out_vec[out_row + tm] = static_cast<T>(alpha) * result[tm] +
static_cast<T>(beta) * bias[(out_row + tm) * bias_stride];
} else {
out_vec[out_row + tm] = result[tm];
}
}
}
}
};
///////////////////////////////////////////////////////////////////////////////
@@ -180,32 +174,31 @@ struct GEMVKernel {
///////////////////////////////////////////////////////////////////////////////
template <
typename T,
const int BM, /* Threadgroup rows (in threads) */
const int BN, /* Threadgroup cols (in threads) */
const int TM, /* Thread rows (in elements) */
const int TN, /* Thread cols (in elements) */
const bool kDoAxpby> /* Do out = alpha * out + beta * bias */
typename T,
const int BM, /* Threadgroup rows (in threads) */
const int BN, /* Threadgroup cols (in threads) */
const int TM, /* Thread rows (in elements) */
const int TN, /* Thread cols (in elements) */
const bool kDoAxpby> /* Do out = alpha * out + beta * bias */
struct GEMVTKernel {
// - The matrix of size (M = in_vec_size, N = out_vec_size) is divided up
// - The matrix of size (M = in_vec_size, N = out_vec_size) is divided up
// into blocks of (BM * TM, BN * TN) divided among threadgroups
// - Every thread works on a block of (TM, TN)
// - We assume each threadgroup is launched with (BN, BM, 1) threads
//
// 1. A thread loads TN elements each from mat along TM contiguous rows
// and the corresponding scalar from the vector
// 2. The thread then multiplies and adds to accumulate its local result for the block
// 3. At the end, each thread has accumulated results over all blocks across the rows
// These are then summed up across the threadgroup
// 1. A thread loads TN elements each from mat along TM contiguous rows
// and the corresponding scalar from the vector
// 2. The thread then accumulates its local result for the block
// 3. At the end, each thread has accumulated results over all blocks across
// the rows. These are then summed up across the threadgroup
// 4. Each threadgroup writes its accumulated BN * TN outputs
//
// Edge case handling:
// - The threadgroup with the largest tid will have blocks that exceed the matrix
// * The blocks that start outside the matrix are never read (thread results remain zero)
// * The last thread that partially overlaps with the matrix is shifted inwards
// such that the thread block fits exactly in the matrix
// - The threadgroup with the largest tid has blocks that exceed the matrix
// * The blocks that start outside the matrix are never read (thread results
// remain zero)
// * The last thread that partially overlaps with the matrix is shifted
// inwards such that the thread block fits exactly in the matrix
MLX_MTL_CONST short tgp_mem_size = BN * BM * TN;
@@ -213,7 +206,7 @@ struct GEMVTKernel {
const device T* mat [[buffer(0)]],
const device T* in_vec [[buffer(1)]],
const device T* bias [[buffer(2)]],
device T* out_vec [[buffer(3)]],
device T* out_vec [[buffer(3)]],
const constant int& in_vec_size [[buffer(4)]],
const constant int& out_vec_size [[buffer(5)]],
const constant int& marix_ld [[buffer(6)]],
@@ -225,8 +218,7 @@ struct GEMVTKernel {
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
// Appease compiler
// Appease compiler
(void)simd_gid;
(void)simd_lid;
@@ -243,77 +235,69 @@ struct GEMVTKernel {
// Edgecase handling
if (out_col < out_vec_size) {
out_col = out_col + TN < out_vec_size ? out_col : out_vec_size - TN;
// Per thread accumulation main loop
int bm = in_row;
for(; bm < in_vec_size; bm += BM * TM) {
for (; bm < in_vec_size; bm += BM * TM) {
// Adding a threadgroup_barrier improves performance slightly
// This is possibly because it helps exploit the cache better
threadgroup_barrier(mem_flags::mem_none);
if(bm + TM <= in_vec_size) {
#pragma clang loop unroll(full)
for(int tm = 0; tm < TM; tm++) {
if (bm + TM <= in_vec_size) {
MLX_MTL_PRAGMA_UNROLL
for (int tm = 0; tm < TM; tm++) {
v_coeff[tm] = in_vec[bm + tm];
}
#pragma clang loop unroll(full)
for(int tm = 0; tm < TM; tm++) {
for(int tn = 0; tn < TN; tn++) {
MLX_MTL_PRAGMA_UNROLL
for (int tm = 0; tm < TM; tm++) {
for (int tn = 0; tn < TN; tn++) {
inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn];
}
for(int tn = 0; tn < TN; tn++) {
for (int tn = 0; tn < TN; tn++) {
result[tn] += v_coeff[tm] * inter[tn];
}
}
} else { // Edgecase handling
for(int tm = 0; bm + tm < in_vec_size; tm++) {
for (int tm = 0; bm + tm < in_vec_size; tm++) {
v_coeff[tm] = in_vec[bm + tm];
for(int tn = 0; tn < TN; tn++) {
for (int tn = 0; tn < TN; tn++) {
inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn];
}
for(int tn = 0; tn < TN; tn++) {
for (int tn = 0; tn < TN; tn++) {
result[tn] += v_coeff[tm] * inter[tn];
}
}
}
}
}
// Threadgroup collection
#pragma clang loop unroll(full)
for(int i = 0; i < TN; i++) {
MLX_MTL_PRAGMA_UNROLL
for (int i = 0; i < TN; i++) {
tgp_results[lid.y * TN + i] = result[i];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Threadgroup accumulation and writing out results
if(lid.y == 0 && out_col < out_vec_size) {
#pragma clang loop unroll(full)
for(int i = 1; i < BM; i++) {
#pragma clang loop unroll(full)
for(int j = 0; j < TN; j++) {
if (lid.y == 0 && out_col < out_vec_size) {
MLX_MTL_PRAGMA_UNROLL
for (int i = 1; i < BM; i++) {
MLX_MTL_PRAGMA_UNROLL
for (int j = 0; j < TN; j++) {
result[j] += tgp_results[i * TN + j];
}
}
#pragma clang loop unroll(full)
for(int j = 0; j < TN; j++) {
if(kDoAxpby) {
out_vec[out_col + j] =
static_cast<T>(alpha) * result[j] +
MLX_MTL_PRAGMA_UNROLL
for (int j = 0; j < TN; j++) {
if (kDoAxpby) {
out_vec[out_col + j] = static_cast<T>(alpha) * result[j] +
static_cast<T>(beta) * bias[(out_col + j) * bias_stride];
} else {
out_vec[out_col + j] = result[j];
@@ -328,18 +312,18 @@ struct GEMVTKernel {
///////////////////////////////////////////////////////////////////////////////
template <
typename T,
typename T,
const int BM, /* Threadgroup rows (in threads) */
const int BN, /* Threadgroup cols (in threads) */
const int TM, /* Thread rows (in elements) */
const int TN, /* Thread cols (in elements) */
const bool kDoNCBatch, /* Batch ndim > 1 */
const bool kDoAxpby> /* Do out = alpha * out + beta * bias */
[[kernel, max_total_threads_per_threadgroup(BM * BN)]] void gemv(
[[kernel, max_total_threads_per_threadgroup(BM* BN)]] void gemv(
const device T* mat [[buffer(0)]],
const device T* in_vec [[buffer(1)]],
const device T* bias [[buffer(2)]],
device T* out_vec [[buffer(3)]],
device T* out_vec [[buffer(3)]],
const constant int& in_vec_size [[buffer(4)]],
const constant int& out_vec_size [[buffer(5)]],
const constant int& marix_ld [[buffer(6)]],
@@ -355,16 +339,15 @@ template <
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
using gemv_kernel = GEMVKernel<T, BM, BN, TM, TN, kDoAxpby>;
threadgroup T tgp_memory[gemv_kernel::tgp_mem_size];
// Update batch offsets
if(kDoNCBatch) {
if (kDoNCBatch) {
in_vec += elem_to_loc(tid.z, batch_shape, vector_batch_stride, batch_ndim);
mat += elem_to_loc(tid.z, batch_shape, matrix_batch_stride, batch_ndim);
if(kDoAxpby) {
if (kDoAxpby) {
bias += elem_to_loc(tid.z, batch_shape, bias_batch_stride, batch_ndim);
}
@@ -372,89 +355,215 @@ template <
in_vec += tid.z * vector_batch_stride[0];
mat += tid.z * matrix_batch_stride[0];
if(kDoAxpby) {
if (kDoAxpby) {
bias += tid.z * bias_batch_stride[0];
}
}
out_vec += tid.z * out_vec_size;
gemv_kernel::run(
mat,
in_vec,
bias,
out_vec,
in_vec_size,
out_vec_size,
marix_ld,
alpha,
beta,
bias_stride,
tgp_memory,
tid,
lid,
simd_gid,
simd_lid
);
gemv_kernel::run(
mat,
in_vec,
bias,
out_vec,
in_vec_size,
out_vec_size,
marix_ld,
alpha,
beta,
bias_stride,
tgp_memory,
tid,
lid,
simd_gid,
simd_lid);
}
#define instantiate_gemv_helper(name, itype, bm, bn, tm, tn, nc, axpby) \
template [[host_name("gemv_" #name "_bm" #bm "_bn" #bn "_tm" #tm "_tn" #tn \
"_nc" #nc "_axpby" #axpby)]] [[kernel]] void \
gemv<itype, bm, bn, tm, tn, nc, axpby>( \
const device itype* mat [[buffer(0)]], \
const device itype* in_vec [[buffer(1)]], \
const device itype* bias [[buffer(2)]], \
device itype* out_vec [[buffer(3)]], \
const constant int& in_vec_size [[buffer(4)]], \
const constant int& out_vec_size [[buffer(5)]], \
const constant int& marix_ld [[buffer(6)]], \
const constant float& alpha [[buffer(7)]], \
const constant float& beta [[buffer(8)]], \
const constant int& batch_ndim [[buffer(9)]], \
const constant int* batch_shape [[buffer(10)]], \
const constant size_t* vector_batch_stride [[buffer(11)]], \
const constant size_t* matrix_batch_stride [[buffer(12)]], \
const constant size_t* bias_batch_stride [[buffer(13)]], \
const constant int& bias_stride [[buffer(14)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);
#define instantiate_gemv_helper(name, itype, bm, bn, tm, tn, nc, axpby) \
template [[host_name("gemv_" #name "_bm" #bm "_bn" #bn "_tm" #tm "_tn" #tn "_nc" #nc "_axpby" #axpby)]] \
[[kernel]] void gemv<itype, bm, bn, tm, tn, nc, axpby>( \
const device itype* mat [[buffer(0)]], \
const device itype* in_vec [[buffer(1)]], \
const device itype* bias [[buffer(2)]], \
device itype* out_vec [[buffer(3)]], \
const constant int& in_vec_size [[buffer(4)]], \
const constant int& out_vec_size [[buffer(5)]], \
const constant int& marix_ld [[buffer(6)]], \
const constant float& alpha [[buffer(7)]], \
const constant float& beta [[buffer(8)]], \
const constant int& batch_ndim [[buffer(9)]], \
const constant int* batch_shape [[buffer(10)]], \
const constant size_t* vector_batch_stride [[buffer(11)]], \
const constant size_t* matrix_batch_stride [[buffer(12)]], \
const constant size_t* bias_batch_stride [[buffer(13)]], \
const constant int& bias_stride [[buffer(14)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);
#define instantiate_gemv(name, itype, bm, bn, tm, tn) \
// clang-format off
#define instantiate_gemv(name, itype, bm, bn, tm, tn) \
instantiate_gemv_helper(name, itype, bm, bn, tm, tn, 0, 0) \
instantiate_gemv_helper(name, itype, bm, bn, tm, tn, 0, 1) \
instantiate_gemv_helper(name, itype, bm, bn, tm, tn, 1, 0) \
instantiate_gemv_helper(name, itype, bm, bn, tm, tn, 1, 1)
instantiate_gemv_helper(name, itype, bm, bn, tm, tn, 1, 1) // clang-format on
// clang-format off
#define instantiate_gemv_blocks(name, itype) \
instantiate_gemv(name, itype, 4, 32, 1, 4) \
instantiate_gemv(name, itype, 4, 32, 4, 4) \
instantiate_gemv(name, itype, 8, 32, 4, 4)
instantiate_gemv(name, itype, 8, 32, 4, 4) // clang-format on
instantiate_gemv_blocks(float32, float);
instantiate_gemv_blocks(float16, half);
instantiate_gemv_blocks(bfloat16, bfloat16_t);
// Index-selected batched matrix-vector product.
//
// For the batch element handled by this threadgroup (tid.z), the vector and
// matrix operands are not taken from consecutive batch slots; instead they are
// picked through vec_indices / mat_indices:
//   1. tid.z is mapped to a position inside each index array using
//      index_batch_strides. Its first batch_ndim entries are the strides for
//      vec_indices and the following batch_ndim entries are the strides for
//      mat_indices.
//   2. The looked-up indices are converted to element offsets into in_vec and
//      mat using the respective batch shape / stride arrays.
//   3. The output slab for the batch element starts at tid.z * out_vec_size.
//   4. The actual product is delegated to GEMVKernel::run, instantiated with
//      kDoAxpby = false.
template <
    typename T,
    const int BM, /* Threadgroup rows (in threads) */
    const int BN, /* Threadgroup cols (in threads) */
    const int TM, /* Thread rows (in elements) */
    const int TN> /* Thread cols (in elements) */
[[kernel, max_total_threads_per_threadgroup(BM* BN)]] void gemv_bs(
    const device T* mat [[buffer(0)]],
    const device T* in_vec [[buffer(1)]],
    const device T* bias [[buffer(2)]],
    device T* out_vec [[buffer(3)]],
    const constant int& in_vec_size [[buffer(4)]],
    const constant int& out_vec_size [[buffer(5)]],
    const constant int& marix_ld [[buffer(6)]],
    const constant float& alpha [[buffer(7)]],
    const constant float& beta [[buffer(8)]],
    const constant int& batch_ndim [[buffer(9)]],
    const constant int* batch_shape [[buffer(10)]],
    const constant size_t* index_batch_strides [[buffer(11)]],
    const constant int& vector_batch_ndim [[buffer(12)]],
    const constant int* vector_batch_shape [[buffer(13)]],
    const constant size_t* vector_batch_stride [[buffer(14)]],
    const constant int& matrix_batch_ndim [[buffer(15)]],
    const constant int* matrix_batch_shape [[buffer(16)]],
    const constant size_t* matrix_batch_stride [[buffer(17)]],
    const constant uint32_t* vec_indices [[buffer(18)]],
    const constant uint32_t* mat_indices [[buffer(19)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  using kernel_t = GEMVKernel<T, BM, BN, TM, TN, false>;
  threadgroup T tgp_memory[kernel_t::tgp_mem_size];

  // Step 1: where in the index arrays does this batch element live?
  size_t vec_slot;
  size_t mat_slot;
  if (batch_ndim <= 1) {
    // Single (or no) batch dimension: plain strided addressing
    vec_slot = index_batch_strides[0] * tid.z;
    mat_slot = index_batch_strides[batch_ndim] * tid.z;
  } else {
    // General case: unravel tid.z against both stride sets at once
    const constant size_t* vec_idx_strides = index_batch_strides;
    const constant size_t* mat_idx_strides = index_batch_strides + batch_ndim;
    const ulong2 slots = elem_to_loc_broadcast(
        tid.z, batch_shape, vec_idx_strides, mat_idx_strides, batch_ndim);
    vec_slot = slots.x;
    mat_slot = slots.y;
  }
  const uint32_t vec_batch = vec_indices[vec_slot];
  const uint32_t mat_batch = mat_indices[mat_slot];

  // Step 2: move the operand pointers to the selected batch entries
  if (vector_batch_ndim <= 1) {
    in_vec += vec_batch * vector_batch_stride[0];
  } else {
    in_vec += elem_to_loc(
        vec_batch, vector_batch_shape, vector_batch_stride, vector_batch_ndim);
  }

  if (matrix_batch_ndim <= 1) {
    mat += mat_batch * matrix_batch_stride[0];
  } else {
    mat += elem_to_loc(
        mat_batch, matrix_batch_shape, matrix_batch_stride, matrix_batch_ndim);
  }

  // Step 3: outputs are laid out one slab of out_vec_size per batch element
  out_vec += tid.z * out_vec_size;

  // Step 4: run the shared gemv implementation
  kernel_t::run(
      mat,
      in_vec,
      bias,
      out_vec,
      in_vec_size,
      out_vec_size,
      marix_ld,
      alpha,
      beta,
      batch_ndim, // Fills the bias_stride slot; not used
      tgp_memory,
      tid,
      lid,
      simd_gid,
      simd_lid);
}
// Explicitly instantiates the gemv_bs kernel template for one element type
// (itype) and one tile configuration (bm, bn, tm, tn).
// The instantiation is exported under the host-visible name
//   gemv_bs_<nm>_bm<bm>_bn<bn>_tm<tm>_tn<tn>
// The parameter list below has to mirror the gemv_bs template declaration
// (same types, order and buffer slots), with T replaced by itype.
#define instantiate_gemv_bs_helper(nm, itype, bm, bn, tm, tn) \
template [[host_name("gemv_bs_" #nm "_bm" #bm "_bn" #bn "_tm" #tm \
"_tn" #tn)]] [[kernel]] void \
gemv_bs<itype, bm, bn, tm, tn>( \
const device itype* mat [[buffer(0)]], \
const device itype* in_vec [[buffer(1)]], \
const device itype* bias [[buffer(2)]], \
device itype* out_vec [[buffer(3)]], \
const constant int& in_vec_size [[buffer(4)]], \
const constant int& out_vec_size [[buffer(5)]], \
const constant int& marix_ld [[buffer(6)]], \
const constant float& alpha [[buffer(7)]], \
const constant float& beta [[buffer(8)]], \
const constant int& batch_ndim [[buffer(9)]], \
const constant int* batch_shape [[buffer(10)]], \
const constant size_t* index_batch_strides [[buffer(11)]], \
const constant int& vector_batch_ndim [[buffer(12)]], \
const constant int* vector_batch_shape [[buffer(13)]], \
const constant size_t* vector_batch_stride [[buffer(14)]], \
const constant int& matrix_batch_ndim [[buffer(15)]], \
const constant int* matrix_batch_shape [[buffer(16)]], \
const constant size_t* matrix_batch_stride [[buffer(17)]], \
const constant uint32_t* vec_indices [[buffer(18)]], \
const constant uint32_t* mat_indices [[buffer(19)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);
// Instantiates gemv_bs for every supported (bm, bn, tm, tn) tile configuration
// of a given element type. bn is 32 in all configurations because gemv_bs runs
// on GEMVKernel, which static_asserts BN == SIMD_SIZE.
// clang-format off
#define instantiate_gemv_bs_blocks(name, itype) \
instantiate_gemv_bs_helper(name, itype, 4, 32, 1, 4) \
instantiate_gemv_bs_helper(name, itype, 4, 32, 4, 4) \
instantiate_gemv_bs_helper(name, itype, 8, 32, 4, 4) // clang-format on
// Emit the gemv_bs kernels for each supported element type
instantiate_gemv_bs_blocks(float32, float);
instantiate_gemv_bs_blocks(float16, half);
instantiate_gemv_bs_blocks(bfloat16, bfloat16_t);
///////////////////////////////////////////////////////////////////////////////
/// Vector matrix multiplication
///////////////////////////////////////////////////////////////////////////////
template <
typename T,
typename T,
const int BM, /* Threadgroup rows (in threads) */
const int BN, /* Threadgroup cols (in threads) */
const int TM, /* Thread rows (in elements) */
const int TN, /* Thread cols (in elements) */
const bool kDoNCBatch, /* Batch ndim > 1 */
const bool kDoAxpby> /* Do out = alpha * out + beta * bias */
[[kernel, max_total_threads_per_threadgroup(BM * BN)]] void gemv_t(
[[kernel, max_total_threads_per_threadgroup(BM* BN)]] void gemv_t(
const device T* mat [[buffer(0)]],
const device T* in_vec [[buffer(1)]],
const device T* bias [[buffer(2)]],
device T* out_vec [[buffer(3)]],
device T* out_vec [[buffer(3)]],
const constant int& in_vec_size [[buffer(4)]],
const constant int& out_vec_size [[buffer(5)]],
const constant int& marix_ld [[buffer(6)]],
@@ -470,16 +579,15 @@ template <
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
using gemv_kernel = GEMVTKernel<T, BM, BN, TM, TN, kDoAxpby>;
threadgroup T tgp_memory[gemv_kernel::tgp_mem_size];
// Update batch offsets
if(kDoNCBatch) {
if (kDoNCBatch) {
in_vec += elem_to_loc(tid.z, batch_shape, vector_batch_stride, batch_ndim);
mat += elem_to_loc(tid.z, batch_shape, matrix_batch_stride, batch_ndim);
if(kDoAxpby) {
if (kDoAxpby) {
bias += elem_to_loc(tid.z, batch_shape, bias_batch_stride, batch_ndim);
}
@@ -487,70 +595,202 @@ template <
in_vec += tid.z * vector_batch_stride[0];
mat += tid.z * matrix_batch_stride[0];
if(kDoAxpby) {
if (kDoAxpby) {
bias += tid.z * bias_batch_stride[0];
}
}
out_vec += tid.z * out_vec_size;
gemv_kernel::run(
mat,
in_vec,
bias,
out_vec,
in_vec_size,
out_vec_size,
marix_ld,
alpha,
beta,
bias_stride,
tgp_memory,
tid,
lid,
simd_gid,
simd_lid
);
gemv_kernel::run(
mat,
in_vec,
bias,
out_vec,
in_vec_size,
out_vec_size,
marix_ld,
alpha,
beta,
bias_stride,
tgp_memory,
tid,
lid,
simd_gid,
simd_lid);
}
#define instantiate_gemv_t_helper(name, itype, bm, bn, tm, tn, nc, axpby) \
template [[host_name("gemv_t_" #name "_bm" #bm "_bn" #bn "_tm" #tm "_tn" #tn "_nc" #nc "_axpby" #axpby)]] \
[[kernel]] void gemv_t<itype, bm, bn, tm, tn, nc, axpby>( \
const device itype* mat [[buffer(0)]], \
const device itype* in_vec [[buffer(1)]], \
const device itype* bias [[buffer(2)]], \
device itype* out_vec [[buffer(3)]], \
const constant int& in_vec_size [[buffer(4)]], \
const constant int& out_vec_size [[buffer(5)]], \
const constant int& marix_ld [[buffer(6)]], \
const constant float& alpha [[buffer(7)]], \
const constant float& beta [[buffer(8)]], \
const constant int& batch_ndim [[buffer(9)]], \
const constant int* batch_shape [[buffer(10)]], \
const constant size_t* vector_batch_stride [[buffer(11)]], \
const constant size_t* matrix_batch_stride [[buffer(12)]], \
const constant size_t* bias_batch_stride [[buffer(13)]], \
const constant int& bias_stride [[buffer(14)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);
#define instantiate_gemv_t_helper(name, itype, bm, bn, tm, tn, nc, axpby) \
template [[host_name("gemv_t_" #name "_bm" #bm "_bn" #bn "_tm" #tm "_tn" #tn \
"_nc" #nc "_axpby" #axpby)]] [[kernel]] void \
gemv_t<itype, bm, bn, tm, tn, nc, axpby>( \
const device itype* mat [[buffer(0)]], \
const device itype* in_vec [[buffer(1)]], \
const device itype* bias [[buffer(2)]], \
device itype* out_vec [[buffer(3)]], \
const constant int& in_vec_size [[buffer(4)]], \
const constant int& out_vec_size [[buffer(5)]], \
const constant int& marix_ld [[buffer(6)]], \
const constant float& alpha [[buffer(7)]], \
const constant float& beta [[buffer(8)]], \
const constant int& batch_ndim [[buffer(9)]], \
const constant int* batch_shape [[buffer(10)]], \
const constant size_t* vector_batch_stride [[buffer(11)]], \
const constant size_t* matrix_batch_stride [[buffer(12)]], \
const constant size_t* bias_batch_stride [[buffer(13)]], \
const constant int& bias_stride [[buffer(14)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);
#define instantiate_gemv_t(name, itype, bm, bn, tm, tn) \
// clang-format off
#define instantiate_gemv_t(name, itype, bm, bn, tm, tn) \
instantiate_gemv_t_helper(name, itype, bm, bn, tm, tn, 0, 0) \
instantiate_gemv_t_helper(name, itype, bm, bn, tm, tn, 0, 1) \
instantiate_gemv_t_helper(name, itype, bm, bn, tm, tn, 1, 0) \
instantiate_gemv_t_helper(name, itype, bm, bn, tm, tn, 1, 1)
instantiate_gemv_t_helper(name, itype, bm, bn, tm, tn, 1, 1) // clang-format on
// clang-format off
#define instantiate_gemv_t_blocks(name, itype) \
instantiate_gemv_t(name, itype, 8, 8, 4, 1) \
instantiate_gemv_t(name, itype, 8, 8, 4, 4) \
instantiate_gemv_t(name, itype, 8, 8, 4, 1) \
instantiate_gemv_t(name, itype, 8, 8, 4, 4) \
instantiate_gemv_t(name, itype, 8, 16, 4, 4) \
instantiate_gemv_t(name, itype, 8, 32, 4, 4) \
instantiate_gemv_t(name, itype, 8, 64, 4, 4) \
instantiate_gemv_t(name, itype, 8, 128, 4, 4)
instantiate_gemv_t(name, itype, 8, 128, 4, 4) // clang-format on
// clang-format off
instantiate_gemv_t_blocks(float32, float);
instantiate_gemv_t_blocks(float16, half);
instantiate_gemv_t_blocks(bfloat16, bfloat16_t);
instantiate_gemv_t_blocks(bfloat16, bfloat16_t); // clang-format on
// Index-selected batched vector-matrix product.
//
// Same batching scheme as the non-transposed index-selected kernel, but the
// product itself is delegated to GEMVTKernel::run (instantiated with
// kDoAxpby = false):
//   1. tid.z is mapped to a position inside each index array using
//      index_batch_strides. Its first batch_ndim entries are the strides for
//      vec_indices and the following batch_ndim entries are the strides for
//      mat_indices.
//   2. The looked-up indices are converted to element offsets into in_vec and
//      mat using the respective batch shape / stride arrays.
//   3. The output slab for the batch element starts at tid.z * out_vec_size.
template <
    typename T,
    const int BM, /* Threadgroup rows (in threads) */
    const int BN, /* Threadgroup cols (in threads) */
    const int TM, /* Thread rows (in elements) */
    const int TN> /* Thread cols (in elements) */
[[kernel, max_total_threads_per_threadgroup(BM* BN)]] void gemv_t_bs(
    const device T* mat [[buffer(0)]],
    const device T* in_vec [[buffer(1)]],
    const device T* bias [[buffer(2)]],
    device T* out_vec [[buffer(3)]],
    const constant int& in_vec_size [[buffer(4)]],
    const constant int& out_vec_size [[buffer(5)]],
    const constant int& marix_ld [[buffer(6)]],
    const constant float& alpha [[buffer(7)]],
    const constant float& beta [[buffer(8)]],
    const constant int& batch_ndim [[buffer(9)]],
    const constant int* batch_shape [[buffer(10)]],
    const constant size_t* index_batch_strides [[buffer(11)]],
    const constant int& vector_batch_ndim [[buffer(12)]],
    const constant int* vector_batch_shape [[buffer(13)]],
    const constant size_t* vector_batch_stride [[buffer(14)]],
    const constant int& matrix_batch_ndim [[buffer(15)]],
    const constant int* matrix_batch_shape [[buffer(16)]],
    const constant size_t* matrix_batch_stride [[buffer(17)]],
    const constant uint32_t* vec_indices [[buffer(18)]],
    const constant uint32_t* mat_indices [[buffer(19)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  using kernel_t = GEMVTKernel<T, BM, BN, TM, TN, false>;
  threadgroup T tgp_memory[kernel_t::tgp_mem_size];

  // Step 1: where in the index arrays does this batch element live?
  size_t vec_slot;
  size_t mat_slot;
  if (batch_ndim <= 1) {
    // Single (or no) batch dimension: plain strided addressing
    vec_slot = index_batch_strides[0] * tid.z;
    mat_slot = index_batch_strides[batch_ndim] * tid.z;
  } else {
    // General case: unravel tid.z against both stride sets at once
    const constant size_t* vec_idx_strides = index_batch_strides;
    const constant size_t* mat_idx_strides = index_batch_strides + batch_ndim;
    const ulong2 slots = elem_to_loc_broadcast(
        tid.z, batch_shape, vec_idx_strides, mat_idx_strides, batch_ndim);
    vec_slot = slots.x;
    mat_slot = slots.y;
  }
  const uint32_t vec_batch = vec_indices[vec_slot];
  const uint32_t mat_batch = mat_indices[mat_slot];

  // Step 2: move the operand pointers to the selected batch entries
  if (vector_batch_ndim <= 1) {
    in_vec += vec_batch * vector_batch_stride[0];
  } else {
    in_vec += elem_to_loc(
        vec_batch, vector_batch_shape, vector_batch_stride, vector_batch_ndim);
  }

  if (matrix_batch_ndim <= 1) {
    mat += mat_batch * matrix_batch_stride[0];
  } else {
    mat += elem_to_loc(
        mat_batch, matrix_batch_shape, matrix_batch_stride, matrix_batch_ndim);
  }

  // Step 3: outputs are laid out one slab of out_vec_size per batch element
  out_vec += tid.z * out_vec_size;

  // Step 4: run the shared transposed gemv implementation
  kernel_t::run(
      mat,
      in_vec,
      bias,
      out_vec,
      in_vec_size,
      out_vec_size,
      marix_ld,
      alpha,
      beta,
      batch_ndim, // Fills the bias_stride slot; not used
      tgp_memory,
      tid,
      lid,
      simd_gid,
      simd_lid);
}
// Explicitly instantiates the gemv_t_bs kernel template for one element type
// (itype) and one tile configuration (bm, bn, tm, tn).
// The instantiation is exported under the host-visible name
//   gemv_t_bs_<nm>_bm<bm>_bn<bn>_tm<tm>_tn<tn>
// The parameter list below has to mirror the gemv_t_bs template declaration
// (same types, order and buffer slots), with T replaced by itype.
#define instantiate_gemv_t_bs_helper(nm, itype, bm, bn, tm, tn) \
template [[host_name("gemv_t_bs_" #nm "_bm" #bm "_bn" #bn "_tm" #tm \
"_tn" #tn)]] [[kernel]] void \
gemv_t_bs<itype, bm, bn, tm, tn>( \
const device itype* mat [[buffer(0)]], \
const device itype* in_vec [[buffer(1)]], \
const device itype* bias [[buffer(2)]], \
device itype* out_vec [[buffer(3)]], \
const constant int& in_vec_size [[buffer(4)]], \
const constant int& out_vec_size [[buffer(5)]], \
const constant int& marix_ld [[buffer(6)]], \
const constant float& alpha [[buffer(7)]], \
const constant float& beta [[buffer(8)]], \
const constant int& batch_ndim [[buffer(9)]], \
const constant int* batch_shape [[buffer(10)]], \
const constant size_t* index_batch_strides [[buffer(11)]], \
const constant int& vector_batch_ndim [[buffer(12)]], \
const constant int* vector_batch_shape [[buffer(13)]], \
const constant size_t* vector_batch_stride [[buffer(14)]], \
const constant int& matrix_batch_ndim [[buffer(15)]], \
const constant int* matrix_batch_shape [[buffer(16)]], \
const constant size_t* matrix_batch_stride [[buffer(17)]], \
const constant uint32_t* vec_indices [[buffer(18)]], \
const constant uint32_t* mat_indices [[buffer(19)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);
// Instantiates gemv_t_bs for every supported (bm, bn, tm, tn) tile
// configuration of a given element type.
// clang-format off
#define instantiate_gemv_t_bs_blocks(name, itype) \
instantiate_gemv_t_bs_helper(name, itype, 8, 8, 4, 1) \
instantiate_gemv_t_bs_helper(name, itype, 8, 8, 4, 4) \
instantiate_gemv_t_bs_helper(name, itype, 8, 16, 4, 4) \
instantiate_gemv_t_bs_helper(name, itype, 8, 32, 4, 4) \
instantiate_gemv_t_bs_helper(name, itype, 8, 64, 4, 4) \
instantiate_gemv_t_bs_helper(name, itype, 8, 128, 4, 4) // clang-format on
// Emit the gemv_t_bs kernels for each supported element type
// clang-format off
instantiate_gemv_t_bs_blocks(float32, float);
instantiate_gemv_t_bs_blocks(float16, half);
instantiate_gemv_t_bs_blocks(bfloat16, bfloat16_t); // clang-format on

View File

@@ -99,7 +99,8 @@ template <typename T, int N_READS = RMS_N_READS>
for (int i = 0; i < N_READS; i++) {
if ((lid * N_READS + i) < axis_size) {
thread_x[i] = (thread_x[i] - mean) * normalizer;
out[i] = w[w_stride * i] * static_cast<T>(thread_x[i]) + b[b_stride * i];
out[i] =
w[w_stride * i] * static_cast<T>(thread_x[i]) + b[b_stride * i];
}
}
}
@@ -192,13 +193,15 @@ template <typename T, int N_READS = RMS_N_READS>
if (r + lid * N_READS + N_READS <= axis_size) {
for (int i = 0; i < N_READS; i++) {
float xi = (x[r + i] - mean) * normalizer;
out[r + i] = w[w_stride * (i + r)] * static_cast<T>(xi) + b[b_stride * (i + r)];
out[r + i] =
w[w_stride * (i + r)] * static_cast<T>(xi) + b[b_stride * (i + r)];
}
} else {
for (int i = 0; i < N_READS; i++) {
if ((r + lid * N_READS + i) < axis_size) {
float xi = (x[r + i] - mean) * normalizer;
out[r + i] = w[w_stride * (i + r)] * static_cast<T>(xi) + b[b_stride * (i + r)];
out[r + i] = w[w_stride * (i + r)] * static_cast<T>(xi) +
b[b_stride * (i + r)];
}
}
}
@@ -323,16 +326,18 @@ template <typename T, int N_READS = RMS_N_READS>
if (lid * N_READS + N_READS <= axis_size) {
for (int i = 0; i < N_READS; i++) {
thread_x[i] = (thread_x[i] - mean) * normalizer;
gx[i] = static_cast<T>(normalizer * (thread_w[i] * thread_g[i] - meanwg) -
thread_x[i] * meanwgxc * normalizer2);
gx[i] = static_cast<T>(
normalizer * (thread_w[i] * thread_g[i] - meanwg) -
thread_x[i] * meanwgxc * normalizer2);
gw[i] = static_cast<T>(thread_g[i] * thread_x[i]);
}
} else {
for (int i = 0; i < N_READS; i++) {
if ((lid * N_READS + i) < axis_size) {
thread_x[i] = (thread_x[i] - mean) * normalizer;
gx[i] = static_cast<T>(normalizer * (thread_w[i] * thread_g[i] - meanwg) -
thread_x[i] * meanwgxc * normalizer2);
gx[i] = static_cast<T>(
normalizer * (thread_w[i] * thread_g[i] - meanwg) -
thread_x[i] * meanwgxc * normalizer2);
gw[i] = static_cast<T>(thread_g[i] * thread_x[i]);
}
}
@@ -460,8 +465,8 @@ template <typename T, int N_READS = RMS_N_READS>
float xi = (x[i + r] - mean) * normalizer;
float wi = w[(i + r) * w_stride];
float gi = g[i + r];
gx[i + r] = static_cast<T>(normalizer * (wi * gi - meanwg) -
xi * meanwgxc * normalizer2);
gx[i + r] = static_cast<T>(
normalizer * (wi * gi - meanwg) - xi * meanwgxc * normalizer2);
gw[i + r] = static_cast<T>(gi * xi);
}
} else {
@@ -470,8 +475,8 @@ template <typename T, int N_READS = RMS_N_READS>
float xi = (x[i + r] - mean) * normalizer;
float wi = w[(i + r) * w_stride];
float gi = g[i + r];
gx[i + r] = static_cast<T>(normalizer * (wi * gi - meanwg) -
xi * meanwgxc * normalizer2);
gx[i + r] = static_cast<T>(
normalizer * (wi * gi - meanwg) - xi * meanwgxc * normalizer2);
gw[i + r] = static_cast<T>(gi * xi);
}
}
@@ -548,6 +553,4 @@ template <typename T, int N_READS = RMS_N_READS>
instantiate_layer_norm(float32, float)
instantiate_layer_norm(float16, half)
instantiate_layer_norm(bfloat16, bfloat16_t)
// clang-format on
instantiate_layer_norm(bfloat16, bfloat16_t) // clang-format on

View File

@@ -1,7 +1,7 @@
// Copyright © 2023-2024 Apple Inc.
#include <metal_stdlib>
#include <metal_simdgroup>
#include <metal_stdlib>
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/defines.h"
@@ -15,30 +15,31 @@ using namespace metal;
MLX_MTL_CONST int SIMD_SIZE = 32;
template <typename T, typename U, int values_per_thread, int bits>
inline U load_vector(const device T *x, thread U *x_thread) {
static_assert(bits == 2 || bits == 4 || bits == 8, "Template undefined for bits not in {2, 4, 8}");
inline U load_vector(const device T* x, thread U* x_thread) {
static_assert(
bits == 2 || bits == 4 || bits == 8,
"Template undefined for bits not in {2, 4, 8}");
U sum = 0;
if (bits == 2) {
for (int i = 0; i < values_per_thread; i += 4) {
sum += x[i] + x[i+1] + x[i+2] + x[i+3];
sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
x_thread[i] = x[i];
x_thread[i+1] = x[i+1] / 4.0f;
x_thread[i+2] = x[i+2] / 16.0f;
x_thread[i+3] = x[i+3] / 64.0f;
x_thread[i + 1] = x[i + 1] / 4.0f;
x_thread[i + 2] = x[i + 2] / 16.0f;
x_thread[i + 3] = x[i + 3] / 64.0f;
}
}
else if (bits == 4) {
for (int i = 0; i < values_per_thread; i += 4) {
sum += x[i] + x[i+1] + x[i+2] + x[i+3];
sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
x_thread[i] = x[i];
x_thread[i+1] = x[i+1] / 16.0f;
x_thread[i+2] = x[i+2] / 256.0f;
x_thread[i+3] = x[i+3] / 4096.0f;
x_thread[i + 1] = x[i + 1] / 16.0f;
x_thread[i + 2] = x[i + 2] / 256.0f;
x_thread[i + 3] = x[i + 3] / 4096.0f;
}
}
@@ -53,33 +54,35 @@ inline U load_vector(const device T *x, thread U *x_thread) {
}
template <typename T, typename U, int values_per_thread, int bits>
inline U load_vector_safe(const device T *x, thread U *x_thread, int N) {
static_assert(bits == 2 || bits == 4 || bits == 8, "Template undefined for bits not in {2, 4, 8}");
inline U load_vector_safe(const device T* x, thread U* x_thread, int N) {
static_assert(
bits == 2 || bits == 4 || bits == 8,
"Template undefined for bits not in {2, 4, 8}");
U sum = 0;
if (bits == 2) {
for (int i = 0; i < N; i += 4) {
sum += x[i] + x[i+1] + x[i+2] + x[i+3];
sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
x_thread[i] = x[i];
x_thread[i+1] = x[i+1] / 4.0f;
x_thread[i+2] = x[i+2] / 16.0f;
x_thread[i+3] = x[i+3] / 64.0f;
x_thread[i + 1] = x[i + 1] / 4.0f;
x_thread[i + 2] = x[i + 2] / 16.0f;
x_thread[i + 3] = x[i + 3] / 64.0f;
}
for (int i=N; i<values_per_thread; i++) {
for (int i = N; i < values_per_thread; i++) {
x_thread[i] = 0;
}
}
else if (bits == 4) {
for (int i = 0; i < N; i += 4) {
sum += x[i] + x[i+1] + x[i+2] + x[i+3];
sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
x_thread[i] = x[i];
x_thread[i+1] = x[i+1] / 16.0f;
x_thread[i+2] = x[i+2] / 256.0f;
x_thread[i+3] = x[i+3] / 4096.0f;
x_thread[i + 1] = x[i + 1] / 16.0f;
x_thread[i + 2] = x[i + 2] / 256.0f;
x_thread[i + 3] = x[i + 3] / 4096.0f;
}
for (int i=N; i<values_per_thread; i++) {
for (int i = N; i < values_per_thread; i++) {
x_thread[i] = 0;
}
}
@@ -89,7 +92,7 @@ inline U load_vector_safe(const device T *x, thread U *x_thread, int N) {
sum += x[i];
x_thread[i] = x[i];
}
for (int i=N; i<values_per_thread; i++) {
for (int i = N; i < values_per_thread; i++) {
x_thread[i] = 0;
}
}
@@ -98,29 +101,36 @@ inline U load_vector_safe(const device T *x, thread U *x_thread, int N) {
}
template <typename U, int values_per_thread, int bits>
inline U qdot(const device uint8_t* w, const thread U *x_thread, U scale, U bias, U sum) {
static_assert(bits == 2 || bits == 4 || bits == 8, "Template undefined for bits not in {2, 4, 8}");
inline U qdot(
const device uint8_t* w,
const thread U* x_thread,
U scale,
U bias,
U sum) {
static_assert(
bits == 2 || bits == 4 || bits == 8,
"Template undefined for bits not in {2, 4, 8}");
U accum = 0;
if (bits == 2) {
for (int i = 0; i < (values_per_thread / 4); i++) {
accum += (
x_thread[4*i] * (w[i] & 0x03)
+ x_thread[4*i+1] * (w[i] & 0x0c)
+ x_thread[4*i+2] * (w[i] & 0x30)
+ x_thread[4*i+3] * (w[i] & 0xc0));
accum +=
(x_thread[4 * i] * (w[i] & 0x03) +
x_thread[4 * i + 1] * (w[i] & 0x0c) +
x_thread[4 * i + 2] * (w[i] & 0x30) +
x_thread[4 * i + 3] * (w[i] & 0xc0));
}
}
else if (bits == 4) {
const device uint16_t* ws = (const device uint16_t*)w;
for (int i = 0; i < (values_per_thread / 4); i++) {
accum += (
x_thread[4*i] * (ws[i] & 0x000f)
+ x_thread[4*i+1] * (ws[i] & 0x00f0)
+ x_thread[4*i+2] * (ws[i] & 0x0f00)
+ x_thread[4*i+3] * (ws[i] & 0xf000));
accum +=
(x_thread[4 * i] * (ws[i] & 0x000f) +
x_thread[4 * i + 1] * (ws[i] & 0x00f0) +
x_thread[4 * i + 2] * (ws[i] & 0x0f00) +
x_thread[4 * i + 3] * (ws[i] & 0xf000));
}
}
@@ -134,29 +144,37 @@ inline U qdot(const device uint8_t* w, const thread U *x_thread, U scale, U bias
}
template <typename U, int values_per_thread, int bits>
inline U qdot_safe(const device uint8_t* w, const thread U *x_thread, U scale, U bias, U sum, int N) {
static_assert(bits == 2 || bits == 4 || bits == 8, "Template undefined for bits not in {2, 4, 8}");
inline U qdot_safe(
const device uint8_t* w,
const thread U* x_thread,
U scale,
U bias,
U sum,
int N) {
static_assert(
bits == 2 || bits == 4 || bits == 8,
"Template undefined for bits not in {2, 4, 8}");
U accum = 0;
if (bits == 2) {
for (int i = 0; i < (N / 4); i++) {
accum += (
x_thread[4*i] * (w[i] & 0x03)
+ x_thread[4*i+1] * (w[i] & 0x0c)
+ x_thread[4*i+2] * (w[i] & 0x30)
+ x_thread[4*i+3] * (w[i] & 0xc0));
accum +=
(x_thread[4 * i] * (w[i] & 0x03) +
x_thread[4 * i + 1] * (w[i] & 0x0c) +
x_thread[4 * i + 2] * (w[i] & 0x30) +
x_thread[4 * i + 3] * (w[i] & 0xc0));
}
}
else if (bits == 4) {
const device uint16_t* ws = (const device uint16_t*)w;
for (int i = 0; i < (N / 4); i++) {
accum += (
x_thread[4*i] * (ws[i] & 0x000f)
+ x_thread[4*i+1] * (ws[i] & 0x00f0)
+ x_thread[4*i+2] * (ws[i] & 0x0f00)
+ x_thread[4*i+3] * (ws[i] & 0xf000));
accum +=
(x_thread[4 * i] * (ws[i] & 0x000f) +
x_thread[4 * i + 1] * (ws[i] & 0x00f0) +
x_thread[4 * i + 2] * (ws[i] & 0x0f00) +
x_thread[4 * i + 3] * (ws[i] & 0xf000));
}
}
@@ -170,27 +188,27 @@ inline U qdot_safe(const device uint8_t* w, const thread U *x_thread, U scale, U
}
template <typename U, int values_per_thread, int bits>
inline void qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) {
static_assert(bits == 2 || bits == 4 || bits == 8, "Template undefined for bits not in {2, 4, 8}");
inline void
qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) {
static_assert(
bits == 2 || bits == 4 || bits == 8,
"Template undefined for bits not in {2, 4, 8}");
if (bits == 2) {
U s[4] = {scale, scale / 4.0f, scale / 16.0f, scale / 64.0f};
for (int i = 0; i < (values_per_thread / 4); i++) {
result[4*i] += x * (s[0] * (w[i] & 0x03) + bias);
result[4*i+1] += x * (s[1] * (w[i] & 0x0c) + bias);
result[4*i+2] += x * (s[2] * (w[i] & 0x30) + bias);
result[4*i+3] += x * (s[3] * (w[i] & 0xc0) + bias);
result[4 * i] += x * (s[0] * (w[i] & 0x03) + bias);
result[4 * i + 1] += x * (s[1] * (w[i] & 0x0c) + bias);
result[4 * i + 2] += x * (s[2] * (w[i] & 0x30) + bias);
result[4 * i + 3] += x * (s[3] * (w[i] & 0xc0) + bias);
}
}
else if (bits == 4) {
const thread uint16_t* ws = (const thread uint16_t*)w;
U s[4] = {scale, scale / 16.0f, scale / 256.0f, scale / 4096.0f};
for (int i = 0; i < (values_per_thread / 4); i++) {
result[4*i] += x * (s[0] * (ws[i] & 0x000f) + bias);
result[4*i+1] += x * (s[1] * (ws[i] & 0x00f0) + bias);
result[4*i+2] += x * (s[2] * (ws[i] & 0x0f00) + bias);
result[4*i+3] += x * (s[3] * (ws[i] & 0xf000) + bias);
U s[2] = {scale, scale / 16.0f};
for (int i = 0; i < (values_per_thread / 2); i++) {
result[2 * i] += x * (s[0] * (w[i] & 0x0f) + bias);
result[2 * i + 1] += x * (s[1] * (w[i] & 0xf0) + bias);
}
}
@@ -202,27 +220,31 @@ inline void qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* resu
}
template <typename U, int N, int bits>
inline void dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) {
static_assert(bits == 2 || bits == 4 || bits == 8, "Template undefined for bits not in {2, 4, 8}");
inline void
dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) {
static_assert(
bits == 2 || bits == 4 || bits == 8,
"Template undefined for bits not in {2, 4, 8}");
if (bits == 2) {
U s[4] = {scale, scale / static_cast<U>(4.0f), scale / static_cast<U>(16.0f), scale / static_cast<U>(64.0f)};
U s[4] = {
scale,
scale / static_cast<U>(4.0f),
scale / static_cast<U>(16.0f),
scale / static_cast<U>(64.0f)};
for (int i = 0; i < (N / 4); i++) {
w_local[4*i] = s[0] * (w[i] & 0x03) + bias;
w_local[4*i+1] = s[1] * (w[i] & 0x0c) + bias;
w_local[4*i+2] = s[2] * (w[i] & 0x30) + bias;
w_local[4*i+3] = s[3] * (w[i] & 0xc0) + bias;
w_local[4 * i] = s[0] * (w[i] & 0x03) + bias;
w_local[4 * i + 1] = s[1] * (w[i] & 0x0c) + bias;
w_local[4 * i + 2] = s[2] * (w[i] & 0x30) + bias;
w_local[4 * i + 3] = s[3] * (w[i] & 0xc0) + bias;
}
}
else if (bits == 4) {
const device uint16_t* ws = (const device uint16_t*)w;
U s[4] = {scale, scale / static_cast<U>(16.0f), scale / static_cast<U>(256.0f), scale / static_cast<U>(4096.0f)};
for (int i = 0; i < (N / 4); i++) {
w_local[4*i] = s[0] * (ws[i] & 0x000f) + bias;
w_local[4*i+1] = s[1] * (ws[i] & 0x00f0) + bias;
w_local[4*i+2] = s[2] * (ws[i] & 0x0f00) + bias;
w_local[4*i+3] = s[3] * (ws[i] & 0xf000) + bias;
U s[2] = {scale, scale / static_cast<U>(16.0f)};
for (int i = 0; i < (N / 2); i++) {
w_local[2 * i] = s[0] * (w[i] & 0x0f) + bias;
w_local[2 * i + 1] = s[1] * (w[i] & 0xf0) + bias;
}
}
@@ -243,13 +265,20 @@ template <
short group_size,
short bits>
struct QuantizedBlockLoader {
static_assert(BCOLS <= group_size, "The group size should be larger than the columns");
static_assert(group_size % BCOLS == 0, "The group size should be divisible by the columns");
static_assert(bits == 2 || bits == 4 || bits == 8, "Template undefined for bits not in {2, 4, 8}");
static_assert(
BCOLS <= group_size,
"The group size should be larger than the columns");
static_assert(
group_size % BCOLS == 0,
"The group size should be divisible by the columns");
static_assert(
bits == 2 || bits == 4 || bits == 8,
"Template undefined for bits not in {2, 4, 8}");
MLX_MTL_CONST short pack_factor = 32 / bits;
MLX_MTL_CONST short BCOLS_PACKED = BCOLS / pack_factor;
MLX_MTL_CONST short n_reads = (BCOLS_PACKED * BROWS < tgp_size) ? 1 : (BCOLS_PACKED * BROWS) / tgp_size;
MLX_MTL_CONST short n_reads =
(BCOLS_PACKED * BROWS < tgp_size) ? 1 : (BCOLS_PACKED * BROWS) / tgp_size;
MLX_MTL_CONST short group_steps = group_size / BCOLS;
const int src_ld;
@@ -275,7 +304,8 @@ struct QuantizedBlockLoader {
ushort simd_group_id [[simdgroup_index_in_threadgroup]],
ushort simd_lane_id [[thread_index_in_simdgroup]])
: src_ld(src_ld_),
tile_stride(reduction_dim ? BCOLS_PACKED : BROWS * src_ld / pack_factor),
tile_stride(
reduction_dim ? BCOLS_PACKED : BROWS * src_ld / pack_factor),
group_step_cnt(0),
group_stride(BROWS * src_ld / group_size),
thread_idx(simd_group_id * 32 + simd_lane_id),
@@ -293,8 +323,9 @@ struct QuantizedBlockLoader {
T scale = *scales;
T bias = *biases;
for (int i=0; i<n_reads; i++) {
dequantize<T, pack_factor, bits>((device uint8_t*)(src + i), scale, bias, dst + i * pack_factor);
for (int i = 0; i < n_reads; i++) {
dequantize<T, pack_factor, bits>(
(device uint8_t*)(src + i), scale, bias, dst + i * pack_factor);
}
}
@@ -304,14 +335,14 @@ struct QuantizedBlockLoader {
}
if (reduction_dim == 1 && bi >= src_tile_dim.y) {
for (int i=0; i<n_reads*pack_factor; i++) {
for (int i = 0; i < n_reads * pack_factor; i++) {
dst[i] = T(0);
}
return;
}
if (reduction_dim == 0 && bi >= src_tile_dim.x) {
for (int i=0; i<n_reads*pack_factor; i++) {
for (int i = 0; i < n_reads * pack_factor; i++) {
dst[i] = T(0);
}
return;
@@ -319,8 +350,9 @@ struct QuantizedBlockLoader {
T scale = *scales;
T bias = *biases;
for (int i=0; i<n_reads; i++) {
dequantize<T, pack_factor, bits>((device uint8_t*)(src + i), scale, bias, dst + i * pack_factor);
for (int i = 0; i < n_reads; i++) {
dequantize<T, pack_factor, bits>(
(device uint8_t*)(src + i), scale, bias, dst + i * pack_factor);
}
}
@@ -357,7 +389,6 @@ template <typename T, int group_size, int bits, int packs_per_thread>
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
constexpr int num_simdgroups = 2;
constexpr int results_per_simdgroup = 4;
constexpr int pack_factor = 32 / bits;
@@ -373,7 +404,8 @@ template <typename T, int group_size, int bits, int packs_per_thread>
// Adjust positions
const int in_vec_size_w = in_vec_size / pack_factor;
const int in_vec_size_g = in_vec_size / group_size;
const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) + simd_gid * results_per_simdgroup;
const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) +
simd_gid * results_per_simdgroup;
w += out_row * in_vec_size_w + simd_lid * packs_per_thread;
scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
biases += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
@@ -384,7 +416,8 @@ template <typename T, int group_size, int bits, int packs_per_thread>
U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
for (int row = 0; row < results_per_simdgroup; row++) {
const device uint8_t* wl = (const device uint8_t *)(w + row * in_vec_size_w);
const device uint8_t* wl =
(const device uint8_t*)(w + row * in_vec_size_w);
const device T* sl = scales + row * in_vec_size_g;
const device T* bl = biases + row * in_vec_size_g;
@@ -407,7 +440,6 @@ template <typename T, int group_size, int bits, int packs_per_thread>
}
}
template <typename T, const int group_size, const int bits>
[[kernel]] void qmv(
const device uint32_t* w [[buffer(0)]],
@@ -420,7 +452,6 @@ template <typename T, const int group_size, const int bits>
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
constexpr int num_simdgroups = 2;
constexpr int results_per_simdgroup = 4;
constexpr int packs_per_thread = 1;
@@ -437,7 +468,8 @@ template <typename T, const int group_size, const int bits>
// Adjust positions
const int in_vec_size_w = in_vec_size / pack_factor;
const int in_vec_size_g = in_vec_size / group_size;
const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) + simd_gid * results_per_simdgroup;
const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) +
simd_gid * results_per_simdgroup;
const int used_out_row = min(out_vec_size - results_per_simdgroup, out_row);
if (out_row >= out_vec_size) {
@@ -454,17 +486,19 @@ template <typename T, const int group_size, const int bits>
y += tid.z * out_vec_size + out_row;
int k = 0;
for (; k < in_vec_size-block_size; k += block_size) {
for (; k < in_vec_size - block_size; k += block_size) {
U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
for (int row = 0; out_row + row < out_vec_size; row++) {
const device uint8_t* wl = (const device uint8_t *)(w + row * in_vec_size_w);
const device uint8_t* wl =
(const device uint8_t*)(w + row * in_vec_size_w);
const device T* sl = scales + row * in_vec_size_g;
const device T* bl = biases + row * in_vec_size_g;
U s = sl[0];
U b = bl[0];
result[row] += qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
result[row] +=
qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
}
w += block_size / pack_factor;
@@ -472,11 +506,16 @@ template <typename T, const int group_size, const int bits>
biases += block_size / group_size;
x += block_size;
}
const int remaining = clamp(static_cast<int>(in_vec_size - k - simd_lid * values_per_thread), 0, values_per_thread);
U sum = load_vector_safe<T, U, values_per_thread, bits>(x, x_thread, remaining);
const int remaining = clamp(
static_cast<int>(in_vec_size - k - simd_lid * values_per_thread),
0,
values_per_thread);
U sum =
load_vector_safe<T, U, values_per_thread, bits>(x, x_thread, remaining);
for (int row = 0; out_row + row < out_vec_size; row++) {
const device uint8_t* wl = (const device uint8_t *)(w + row * in_vec_size_w);
const device uint8_t* wl =
(const device uint8_t*)(w + row * in_vec_size_w);
const device T* sl = scales + row * in_vec_size_g;
const device T* bl = biases + row * in_vec_size_g;
@@ -502,17 +541,19 @@ template <typename T, const int group_size, const int bits>
y += tid.z * out_vec_size + used_out_row;
int k = 0;
for (; k < in_vec_size-block_size; k += block_size) {
for (; k < in_vec_size - block_size; k += block_size) {
U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
for (int row = 0; row < results_per_simdgroup; row++) {
const device uint8_t* wl = (const device uint8_t *)(w + row * in_vec_size_w);
const device uint8_t* wl =
(const device uint8_t*)(w + row * in_vec_size_w);
const device T* sl = scales + row * in_vec_size_g;
const device T* bl = biases + row * in_vec_size_g;
U s = sl[0];
U b = bl[0];
result[row] += qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
result[row] +=
qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
}
w += block_size / pack_factor;
@@ -520,17 +561,23 @@ template <typename T, const int group_size, const int bits>
biases += block_size / group_size;
x += block_size;
}
const int remaining = clamp(static_cast<int>(in_vec_size - k - simd_lid * values_per_thread), 0, values_per_thread);
U sum = load_vector_safe<T, U, values_per_thread, bits>(x, x_thread, remaining);
const int remaining = clamp(
static_cast<int>(in_vec_size - k - simd_lid * values_per_thread),
0,
values_per_thread);
U sum =
load_vector_safe<T, U, values_per_thread, bits>(x, x_thread, remaining);
for (int row = 0; row < results_per_simdgroup; row++) {
const device uint8_t* wl = (const device uint8_t *)(w + row * in_vec_size_w);
const device uint8_t* wl =
(const device uint8_t*)(w + row * in_vec_size_w);
const device T* sl = scales + row * in_vec_size_g;
const device T* bl = biases + row * in_vec_size_g;
U s = sl[0];
U b = bl[0];
result[row] += qdot_safe<U, values_per_thread, bits>(wl, x_thread, s, b, sum, remaining);
result[row] += qdot_safe<U, values_per_thread, bits>(
wl, x_thread, s, b, sum, remaining);
}
for (int row = 0; row < results_per_simdgroup; row++) {
@@ -542,7 +589,6 @@ template <typename T, const int group_size, const int bits>
}
}
template <typename T, const int group_size, const int bits>
[[kernel]] void qvm(
const device T* x [[buffer(0)]],
@@ -555,7 +601,6 @@ template <typename T, const int group_size, const int bits>
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
constexpr int num_simdgroups = 8;
constexpr int pack_factor = 32 / bits;
constexpr int blocksize = SIMD_SIZE;
@@ -590,7 +635,8 @@ template <typename T, const int group_size, const int bits>
bias = biases[(i + simd_lid) * out_vec_size_g];
w_local = w[(i + simd_lid) * out_vec_size_w];
qouter<U, pack_factor, bits>((thread uint8_t *)&w_local, x_local, scale, bias, result);
qouter<U, pack_factor, bits>(
(thread uint8_t*)&w_local, x_local, scale, bias, result);
}
if (static_cast<int>(i + simd_lid) < in_vec_size) {
x_local = x[i + simd_lid];
@@ -603,25 +649,32 @@ template <typename T, const int group_size, const int bits>
bias = 0;
w_local = 0;
}
qouter<U, pack_factor, bits>((thread uint8_t *)&w_local, x_local, scale, bias, result);
qouter<U, pack_factor, bits>(
(thread uint8_t*)&w_local, x_local, scale, bias, result);
// Accumulate in the simdgroup
#pragma clang loop unroll(full)
for (int k=0; k<pack_factor; k++) {
// Accumulate in the simdgroup
#pragma clang loop unroll(full)
for (int k = 0; k < pack_factor; k++) {
result[k] = simd_sum(result[k]);
}
// Store the result
if (simd_lid == 0) {
#pragma clang loop unroll(full)
for (int k=0; k<pack_factor; k++) {
#pragma clang loop unroll(full)
for (int k = 0; k < pack_factor; k++) {
y[k] = static_cast<T>(result[k]);
}
}
}
template <typename T, const int BM, const int BK, const int BN, const int group_size, const int bits, const bool aligned_N>
template <
typename T,
const int BM,
const int BK,
const int BN,
const int group_size,
const int bits,
const bool aligned_N>
[[kernel]] void qmm_t(
const device T* x [[buffer(0)]],
const device uint32_t* w [[buffer(1)]],
@@ -635,7 +688,6 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
uint lid [[thread_index_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE");
static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE");
@@ -647,9 +699,19 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
constexpr int BK_padded = (BK + 16 / sizeof(T));
// Instantiate the appropriate BlockMMA and Loader
using mma_t = mlx::steel::BlockMMA<T, T, BM, BN, BK, WM, WN, false, true, BK_padded, BK_padded>;
using loader_x_t = mlx::steel::BlockLoader<T, BM, BK, BK_padded, 1, WM * WN * SIMD_SIZE>;
using loader_w_t = QuantizedBlockLoader<T, BN, BK, BK_padded, 1, WM * WN * SIMD_SIZE, group_size, bits>;
using mma_t = mlx::steel::
BlockMMA<T, T, BM, BN, BK, WM, WN, false, true, BK_padded, BK_padded>;
using loader_x_t =
mlx::steel::BlockLoader<T, BM, BK, BK_padded, 1, WM * WN * SIMD_SIZE>;
using loader_w_t = QuantizedBlockLoader<
T,
BN,
BK,
BK_padded,
1,
WM * WN * SIMD_SIZE,
group_size,
bits>;
threadgroup T Xs[BM * BK_padded];
threadgroup T Ws[BN * BK_padded];
@@ -675,7 +737,7 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
if (num_els < BM) {
if (!aligned_N && num_outs < BN) {
for (int k=0; k<K; k += BK) {
for (int k = 0; k < K; k += BK) {
threadgroup_barrier(mem_flags::mem_threadgroup);
loader_x.load_safe(short2(BK, num_els));
loader_w.load_safe(short2(BK, num_outs));
@@ -685,7 +747,7 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
loader_w.next();
}
} else {
for (int k=0; k<K; k += BK) {
for (int k = 0; k < K; k += BK) {
threadgroup_barrier(mem_flags::mem_threadgroup);
loader_x.load_safe(short2(BK, num_els));
loader_w.load_unsafe();
@@ -697,7 +759,7 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
}
} else {
if (!aligned_N && num_outs < BN) {
for (int k=0; k<K; k += BK) {
for (int k = 0; k < K; k += BK) {
threadgroup_barrier(mem_flags::mem_threadgroup);
loader_x.load_unsafe();
loader_w.load_safe(short2(BK, num_outs));
@@ -707,7 +769,7 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
loader_w.next();
}
} else {
for (int k=0; k<K; k += BK) {
for (int k = 0; k < K; k += BK) {
threadgroup_barrier(mem_flags::mem_threadgroup);
loader_x.load_unsafe();
loader_w.load_unsafe();
@@ -728,8 +790,13 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
}
}
template <typename T, const int BM, const int BK, const int BN, const int group_size, const int bits>
template <
typename T,
const int BM,
const int BK,
const int BN,
const int group_size,
const int bits>
[[kernel]] void qmm_n(
const device T* x [[buffer(0)]],
const device uint32_t* w [[buffer(1)]],
@@ -743,7 +810,6 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
uint lid [[thread_index_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE");
static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE");
@@ -756,9 +822,19 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
constexpr int BN_padded = (BN + 16 / sizeof(T));
// Instantiate the appropriate BlockMMA and Loader
using mma_t = mlx::steel::BlockMMA<T, T, BM, BN, BK, WM, WN, false, false, BK_padded, BN_padded>;
using loader_x_t = mlx::steel::BlockLoader<T, BM, BK, BK_padded, 1, WM * WN * SIMD_SIZE, 1, 4>;
using loader_w_t = QuantizedBlockLoader<T, BK, BN, BN_padded, 0, WM * WN * SIMD_SIZE, group_size, bits>;
using mma_t = mlx::steel::
BlockMMA<T, T, BM, BN, BK, WM, WN, false, false, BK_padded, BN_padded>;
using loader_x_t = mlx::steel::
BlockLoader<T, BM, BK, BK_padded, 1, WM * WN * SIMD_SIZE, 1, 4>;
using loader_w_t = QuantizedBlockLoader<
T,
BK,
BN,
BN_padded,
0,
WM * WN * SIMD_SIZE,
group_size,
bits>;
threadgroup T Xs[BM * BK_padded];
threadgroup T Ws[BK * BN_padded];
@@ -780,8 +856,8 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
if (num_els < BM) {
if ((K % BK) != 0) {
const int k_blocks = K/BK;
for (int k=0; k<k_blocks; k++) {
const int k_blocks = K / BK;
for (int k = 0; k < k_blocks; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
loader_x.load_safe(short2(BK, num_els));
loader_w.load_unsafe();
@@ -797,7 +873,7 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_op.mma(Xs, Ws);
} else {
for (int k=0; k<K; k += BK) {
for (int k = 0; k < K; k += BK) {
threadgroup_barrier(mem_flags::mem_threadgroup);
loader_x.load_safe(short2(BK, num_els));
loader_w.load_unsafe();
@@ -809,8 +885,8 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
}
} else {
if ((K % BK) != 0) {
const int k_blocks = K/BK;
for (int k=0; k<k_blocks; k++) {
const int k_blocks = K / BK;
for (int k = 0; k < k_blocks; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
loader_x.load_unsafe();
loader_w.load_unsafe();
@@ -826,7 +902,7 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_op.mma(Xs, Ws);
} else {
for (int k=0; k<K; k += BK) {
for (int k = 0; k < K; k += BK) {
threadgroup_barrier(mem_flags::mem_threadgroup);
loader_x.load_unsafe();
loader_w.load_unsafe();
@@ -847,26 +923,28 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
}
}
#define instantiate_qmv_fast(name, itype, group_size, bits, packs_per_thread) \
template [[host_name("qmv_" #name "_gs_" #group_size "_b_" #bits "_fast")]] \
[[kernel]] void qmv_fast<itype, group_size, bits, packs_per_thread>( \
const device uint32_t* w [[buffer(0)]], \
const device itype* scales [[buffer(1)]], \
const device itype* biases [[buffer(2)]], \
const device itype* x [[buffer(3)]], \
device itype* y [[buffer(4)]], \
const constant int& in_vec_size [[buffer(5)]], \
const constant int& out_vec_size [[buffer(6)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);
template [[host_name("qmv_" #name "_gs_" #group_size "_b_" #bits \
"_fast")]] [[kernel]] void \
qmv_fast<itype, group_size, bits, packs_per_thread>( \
const device uint32_t* w [[buffer(0)]], \
const device itype* scales [[buffer(1)]], \
const device itype* biases [[buffer(2)]], \
const device itype* x [[buffer(3)]], \
device itype* y [[buffer(4)]], \
const constant int& in_vec_size [[buffer(5)]], \
const constant int& out_vec_size [[buffer(6)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);
#define instantiate_qmv_fast_types(group_size, bits, packs_per_thread) \
// clang-format off
#define instantiate_qmv_fast_types(group_size, bits, packs_per_thread) \
instantiate_qmv_fast(float32, float, group_size, bits, packs_per_thread) \
instantiate_qmv_fast(float16, half, group_size, bits, packs_per_thread) \
instantiate_qmv_fast(bfloat16, bfloat16_t, group_size, bits, packs_per_thread)
instantiate_qmv_fast(float16, half, group_size, bits, packs_per_thread) \
instantiate_qmv_fast(bfloat16, bfloat16_t, group_size, bits, packs_per_thread) // clang-format on
// clang-format off
instantiate_qmv_fast_types(128, 2, 1)
instantiate_qmv_fast_types(128, 4, 2)
instantiate_qmv_fast_types(128, 8, 2)
@@ -875,27 +953,30 @@ instantiate_qmv_fast_types( 64, 4, 2)
instantiate_qmv_fast_types( 64, 8, 2)
instantiate_qmv_fast_types( 32, 2, 1)
instantiate_qmv_fast_types( 32, 4, 2)
instantiate_qmv_fast_types( 32, 8, 2)
instantiate_qmv_fast_types( 32, 8, 2) // clang-format on
#define instantiate_qmv(name, itype, group_size, bits) \
template [[host_name("qmv_" #name "_gs_" #group_size "_b_" #bits)]] \
[[kernel]] void qmv<itype, group_size, bits>( \
const device uint32_t* w [[buffer(0)]], \
const device itype* scales [[buffer(1)]], \
const device itype* biases [[buffer(2)]], \
const device itype* x [[buffer(3)]], \
device itype* y [[buffer(4)]], \
const constant int& in_vec_size [[buffer(5)]], \
const constant int& out_vec_size [[buffer(6)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);
// Explicitly instantiates the `qmv` kernel template as
// qmv<itype, group_size, bits>.
//   name       - token stringified into the host-visible kernel name
//   itype      - element type used for scales, biases, x and y
//   group_size - stringified into the kernel name and passed as a template arg
//   bits       - stringified into the kernel name and passed as a template arg
// The resulting host name is "qmv_<name>_gs_<group_size>_b_<bits>", built by
// adjacent string-literal concatenation of the stringified macro arguments.
#define instantiate_qmv(name, itype, group_size, bits) \
template [[host_name("qmv_" #name "_gs_" #group_size \
"_b_" #bits)]] [[kernel]] void \
qmv<itype, group_size, bits>( \
const device uint32_t* w [[buffer(0)]], \
const device itype* scales [[buffer(1)]], \
const device itype* biases [[buffer(2)]], \
const device itype* x [[buffer(3)]], \
device itype* y [[buffer(4)]], \
const constant int& in_vec_size [[buffer(5)]], \
const constant int& out_vec_size [[buffer(6)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);
#define instantiate_qmv_types(group_size, bits) \
// clang-format off
#define instantiate_qmv_types(group_size, bits) \
instantiate_qmv(float32, float, group_size, bits) \
instantiate_qmv(float16, half, group_size, bits) \
instantiate_qmv(bfloat16, bfloat16_t, group_size, bits)
instantiate_qmv(float16, half, group_size, bits) \
instantiate_qmv(bfloat16, bfloat16_t, group_size, bits) // clang-format on
// clang-format off
instantiate_qmv_types(128, 2)
instantiate_qmv_types(128, 4)
instantiate_qmv_types(128, 8)
@@ -904,27 +985,30 @@ instantiate_qmv_types( 64, 4)
instantiate_qmv_types( 64, 8)
instantiate_qmv_types( 32, 2)
instantiate_qmv_types( 32, 4)
instantiate_qmv_types( 32, 8)
instantiate_qmv_types( 32, 8) // clang-format on
#define instantiate_qvm(name, itype, group_size, bits) \
template [[host_name("qvm_" #name "_gs_" #group_size "_b_" #bits)]] \
[[kernel]] void qvm<itype, group_size, bits>( \
const device itype* x [[buffer(0)]], \
const device uint32_t* w [[buffer(1)]], \
const device itype* scales [[buffer(2)]], \
const device itype* biases [[buffer(3)]], \
device itype* y [[buffer(4)]], \
const constant int& in_vec_size [[buffer(5)]], \
const constant int& out_vec_size [[buffer(6)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);
// Explicitly instantiates the `qvm` kernel template as
// qvm<itype, group_size, bits>.
//   name       - token stringified into the host-visible kernel name
//   itype      - element type used for x, scales, biases and y
//   group_size - stringified into the kernel name and passed as a template arg
//   bits       - stringified into the kernel name and passed as a template arg
// The resulting host name is "qvm_<name>_gs_<group_size>_b_<bits>", built by
// adjacent string-literal concatenation of the stringified macro arguments.
// Note the buffer order differs from the qmv instantiation: here x is
// buffer(0) and w is buffer(1).
#define instantiate_qvm(name, itype, group_size, bits) \
template [[host_name("qvm_" #name "_gs_" #group_size \
"_b_" #bits)]] [[kernel]] void \
qvm<itype, group_size, bits>( \
const device itype* x [[buffer(0)]], \
const device uint32_t* w [[buffer(1)]], \
const device itype* scales [[buffer(2)]], \
const device itype* biases [[buffer(3)]], \
device itype* y [[buffer(4)]], \
const constant int& in_vec_size [[buffer(5)]], \
const constant int& out_vec_size [[buffer(6)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);
// Expands instantiate_qvm once per supported element type
// (float32/float, float16/half, bfloat16/bfloat16_t) for the given
// group_size and bits.
#define instantiate_qvm_types(group_size, bits) \
// clang-format off
#define instantiate_qvm_types(group_size, bits) \
instantiate_qvm(float32, float, group_size, bits) \
instantiate_qvm(float16, half, group_size, bits) \
instantiate_qvm(bfloat16, bfloat16_t, group_size, bits)
instantiate_qvm(float16, half, group_size, bits) \
instantiate_qvm(bfloat16, bfloat16_t, group_size, bits) // clang-format on
// clang-format off
instantiate_qvm_types(128, 2)
instantiate_qvm_types(128, 4)
instantiate_qvm_types(128, 8)
@@ -933,32 +1017,35 @@ instantiate_qvm_types( 64, 4)
instantiate_qvm_types( 64, 8)
instantiate_qvm_types( 32, 2)
instantiate_qvm_types( 32, 4)
instantiate_qvm_types( 32, 8)
instantiate_qvm_types( 32, 8) // clang-format on
// Explicitly instantiates qmm_t<itype, 32, 32, 32, group_size, bits, aligned_N>
// and exposes it to the host as
// "qmm_t_<name>_gs_<group_size>_b_<bits>_alN_<aligned_N>".
// The three literal 32s are fixed template arguments here; their meaning is
// defined by the qmm_t template, which is outside this view.
#define instantiate_qmm_t(name, itype, group_size, bits, aligned_N) \
template [[host_name("qmm_t_" #name "_gs_" #group_size "_b_" #bits "_alN_" #aligned_N)]] \
[[kernel]] void qmm_t<itype, 32, 32, 32, group_size, bits, aligned_N>( \
const device itype* x [[buffer(0)]], \
const device uint32_t* w [[buffer(1)]], \
const device itype* scales [[buffer(2)]], \
const device itype* biases [[buffer(3)]], \
device itype* y [[buffer(4)]], \
const constant int& M [[buffer(5)]], \
const constant int& N [[buffer(6)]], \
const constant int& K [[buffer(7)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint lid [[thread_index_in_threadgroup]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
#define instantiate_qmm_t(name, itype, group_size, bits, aligned_N) \
template [[host_name("qmm_t_" #name "_gs_" #group_size "_b_" #bits \
"_alN_" #aligned_N)]] [[kernel]] void \
qmm_t<itype, 32, 32, 32, group_size, bits, aligned_N>( \
const device itype* x [[buffer(0)]], \
const device uint32_t* w [[buffer(1)]], \
const device itype* scales [[buffer(2)]], \
const device itype* biases [[buffer(3)]], \
device itype* y [[buffer(4)]], \
const constant int& M [[buffer(5)]], \
const constant int& N [[buffer(6)]], \
const constant int& K [[buffer(7)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint lid [[thread_index_in_threadgroup]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);
// Expands instantiate_qmm_t for every element type
// (float32/float, float16/half, bfloat16/bfloat16_t) crossed with both
// aligned_N settings (false and true) at the given group_size and bits.
#define instantiate_qmm_t_types(group_size, bits) \
instantiate_qmm_t(float32, float, group_size, bits, false) \
instantiate_qmm_t(float16, half, group_size, bits, false) \
// clang-format off
#define instantiate_qmm_t_types(group_size, bits) \
instantiate_qmm_t(float32, float, group_size, bits, false) \
instantiate_qmm_t(float16, half, group_size, bits, false) \
instantiate_qmm_t(bfloat16, bfloat16_t, group_size, bits, false) \
instantiate_qmm_t(float32, float, group_size, bits, true) \
instantiate_qmm_t(float16, half, group_size, bits, true) \
instantiate_qmm_t(bfloat16, bfloat16_t, group_size, bits, true)
instantiate_qmm_t(float32, float, group_size, bits, true) \
instantiate_qmm_t(float16, half, group_size, bits, true) \
instantiate_qmm_t(bfloat16, bfloat16_t, group_size, bits, true) // clang-format on
// clang-format off
instantiate_qmm_t_types(128, 2)
instantiate_qmm_t_types(128, 4)
instantiate_qmm_t_types(128, 8)
@@ -967,29 +1054,32 @@ instantiate_qmm_t_types( 64, 4)
instantiate_qmm_t_types( 64, 8)
instantiate_qmm_t_types( 32, 2)
instantiate_qmm_t_types( 32, 4)
instantiate_qmm_t_types( 32, 8)
instantiate_qmm_t_types( 32, 8) // clang-format on
// Explicitly instantiates qmm_n<itype, 32, 32, 32, group_size, bits> and
// exposes it to the host as "qmm_n_<name>_gs_<group_size>_b_<bits>".
// Unlike instantiate_qmm_t, there is no aligned_N argument.
#define instantiate_qmm_n(name, itype, group_size, bits) \
template [[host_name("qmm_n_" #name "_gs_" #group_size "_b_" #bits)]] \
[[kernel]] void qmm_n<itype, 32, 32, 32, group_size, bits>( \
const device itype* x [[buffer(0)]], \
const device uint32_t* w [[buffer(1)]], \
const device itype* scales [[buffer(2)]], \
const device itype* biases [[buffer(3)]], \
device itype* y [[buffer(4)]], \
const constant int& M [[buffer(5)]], \
const constant int& N [[buffer(6)]], \
const constant int& K [[buffer(7)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint lid [[thread_index_in_threadgroup]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
template [[host_name("qmm_n_" #name "_gs_" #group_size \
"_b_" #bits)]] [[kernel]] void \
qmm_n<itype, 32, 32, 32, group_size, bits>( \
const device itype* x [[buffer(0)]], \
const device uint32_t* w [[buffer(1)]], \
const device itype* scales [[buffer(2)]], \
const device itype* biases [[buffer(3)]], \
device itype* y [[buffer(4)]], \
const constant int& M [[buffer(5)]], \
const constant int& N [[buffer(6)]], \
const constant int& K [[buffer(7)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint lid [[thread_index_in_threadgroup]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);
// Expands instantiate_qmm_n once per supported element type
// (float32/float, float16/half, bfloat16/bfloat16_t) for the given
// group_size and bits.
#define instantiate_qmm_n_types(group_size, bits) \
// clang-format off
#define instantiate_qmm_n_types(group_size, bits) \
instantiate_qmm_n(float32, float, group_size, bits) \
instantiate_qmm_n(float16, half, group_size, bits) \
instantiate_qmm_n(bfloat16, bfloat16_t, group_size, bits)
instantiate_qmm_n(float16, half, group_size, bits) \
instantiate_qmm_n(bfloat16, bfloat16_t, group_size, bits) // clang-format on
// clang-format off
instantiate_qmm_n_types(128, 2)
instantiate_qmm_n_types(128, 4)
instantiate_qmm_n_types(128, 8)
@@ -998,4 +1088,4 @@ instantiate_qmm_n_types( 64, 4)
instantiate_qmm_n_types( 64, 8)
instantiate_qmm_n_types( 32, 2)
instantiate_qmm_n_types( 32, 4)
instantiate_qmm_n_types( 32, 8)
instantiate_qmm_n_types( 32, 8) // clang-format on

View File

@@ -3,9 +3,8 @@
#include "mlx/backend/metal/kernels/utils.h"
// 2x4 table of 32-bit rotation amounts held in the constant address space.
// NOTE(review): presumably the per-round rotation schedule consumed by
// threefry2x32_hash below; the code that indexes it is outside this view,
// so confirm before relying on that.
static constexpr constant uint32_t rotations[2][4] = {
{13, 15, 26, 6},
{17, 29, 16, 24}
};
{13, 15, 26, 6},
{17, 29, 16, 24}};
union rbits {
uint2 val;
@@ -13,7 +12,6 @@ union rbits {
};
rbits threefry2x32_hash(const thread uint2& key, uint2 count) {
uint4 ks = {key.x, key.y, key.x ^ key.y ^ 0x1BD11BDA};
rbits v;
@@ -51,7 +49,7 @@ rbits threefry2x32_hash(const thread uint2& key, uint2 count) {
out[4 * count.x + i] = bits.bytes[0][i];
}
if (!drop_last) {
if ((index.y + 1) == half_size && (bytes_per_key % 4) > 0) {
if ((index.y + 1) == half_size && (bytes_per_key % 4) > 0) {
int edge_bytes = (bytes_per_key % 4);
for (int i = 0; i < edge_bytes; ++i) {
out[4 * count.y + i] = bits.bytes[1][i];
@@ -87,7 +85,7 @@ rbits threefry2x32_hash(const thread uint2& key, uint2 count) {
out[4 * count.x + i] = bits.bytes[0][i];
}
if (!drop_last) {
if ((index.y + 1) == half_size && (bytes_per_key % 4) > 0) {
if ((index.y + 1) == half_size && (bytes_per_key % 4) > 0) {
int edge_bytes = (bytes_per_key % 4);
for (int i = 0; i < edge_bytes; ++i) {
out[4 * count.y + i] = bits.bytes[1][i];

View File

@@ -1,8 +1,8 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/backend/metal/kernels/reduction/utils.h"
#include "mlx/backend/metal/kernels/reduction/ops.h"
#include "mlx/backend/metal/kernels/reduction/reduce_inst.h"
#include "mlx/backend/metal/kernels/reduction/utils.h"
using namespace metal;
@@ -60,14 +60,13 @@ METAL_FUNC U per_thread_all_reduce(
// All reduce kernel
///////////////////////////////////////////////////////////////////////////////
// NB: This kernel assumes threads_per_threadgroup is at most
// 1024. This way with a simd_size of 32, we are guaranteed to
// complete the reduction in two steps of simd-level reductions.
template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
[[kernel]] void all_reduce(
const device T *in [[buffer(0)]],
device mlx_atomic<U> *out [[buffer(1)]],
const device T* in [[buffer(0)]],
device mlx_atomic<U>* out [[buffer(1)]],
const device size_t& in_size [[buffer(2)]],
uint gid [[thread_position_in_grid]],
uint lid [[thread_position_in_threadgroup]],
@@ -75,11 +74,11 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
uint simd_per_group [[simdgroups_per_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
Op op;
threadgroup U local_vals[simd_size];
U total_val = per_thread_all_reduce<T, U, Op, N_READS>(in, in_size, gid, grid_size);
U total_val =
per_thread_all_reduce<T, U, Op, N_READS>(in, in_size, gid, grid_size);
// Reduction within simd group
total_val = op.simd_reduce(total_val);
@@ -98,10 +97,10 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
}
}
template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
[[kernel]] void all_reduce_no_atomics(
const device T *in [[buffer(0)]],
device U *out [[buffer(1)]],
const device T* in [[buffer(0)]],
device U* out [[buffer(1)]],
const device size_t& in_size [[buffer(2)]],
uint gid [[thread_position_in_grid]],
uint lid [[thread_position_in_threadgroup]],
@@ -110,14 +109,16 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint thread_group_id [[threadgroup_position_in_grid]]) {
Op op;
threadgroup U local_vals[simd_size];
U total_val = per_thread_all_reduce<T, U, Op, N_READS>(in, in_size, gid, grid_size);
U total_val =
per_thread_all_reduce<T, U, Op, N_READS>(in, in_size, gid, grid_size);
// Reduction within simd group (simd_add isn't supported for uint64/int64 types)
for (uint16_t lane_offset = simd_size/2; lane_offset > 0; lane_offset /= 2) {
// Reduction within simd group (simd_add isn't supported for uint64/int64
// types)
for (uint16_t lane_offset = simd_size / 2; lane_offset > 0;
lane_offset /= 2) {
total_val = op(total_val, simd_shuffle_down(total_val, lane_offset));
}
// Write simd group reduction results to local memory
@@ -128,7 +129,8 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
// Reduction of simdgroup reduction results within threadgroup.
total_val = lid < simd_per_group ? local_vals[lid] : op.init;
for (uint16_t lane_offset = simd_size/2; lane_offset > 0; lane_offset /= 2) {
for (uint16_t lane_offset = simd_size / 2; lane_offset > 0;
lane_offset /= 2) {
total_val = op(total_val, simd_shuffle_down(total_val, lane_offset));
}
@@ -138,31 +140,31 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
}
}
// Explicitly instantiates all_reduce<itype, otype, op> and exposes it to the
// host as "all_reduce_<name>". In this variant the output buffer is typed
// as device mlx_atomic<otype>*.
#define instantiate_all_reduce(name, itype, otype, op) \
template [[host_name("all_reduce_" #name)]] \
[[kernel]] void all_reduce<itype, otype, op>( \
const device itype *in [[buffer(0)]], \
device mlx_atomic<otype> *out [[buffer(1)]], \
const device size_t& in_size [[buffer(2)]], \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint grid_size [[threads_per_grid]], \
uint simd_per_group [[simdgroups_per_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
#define instantiate_all_reduce(name, itype, otype, op) \
template [[host_name("all_reduce_" #name)]] [[kernel]] void \
all_reduce<itype, otype, op>( \
const device itype* in [[buffer(0)]], \
device mlx_atomic<otype>* out [[buffer(1)]], \
const device size_t& in_size [[buffer(2)]], \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint grid_size [[threads_per_grid]], \
uint simd_per_group [[simdgroups_per_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
// Explicitly instantiates all_reduce_no_atomics<itype, otype, op> and exposes
// it to the host as "all_reduce_no_atomics_<name>". Compared with
// instantiate_all_reduce, the output is a plain device otype* and the kernel
// takes an extra thread_group_id (threadgroup_position_in_grid) argument.
#define instantiate_all_reduce_no_atomics(name, itype, otype, op) \
template [[host_name("all_reduce_no_atomics_" #name)]] \
[[kernel]] void all_reduce_no_atomics<itype, otype, op>( \
const device itype *in [[buffer(0)]], \
device otype *out [[buffer(1)]], \
const device size_t& in_size [[buffer(2)]], \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint grid_size [[threads_per_grid]], \
uint simd_per_group [[simdgroups_per_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
#define instantiate_all_reduce_no_atomics(name, itype, otype, op) \
template [[host_name("all_reduce_no_atomics_" #name)]] [[kernel]] void \
all_reduce_no_atomics<itype, otype, op>( \
const device itype* in [[buffer(0)]], \
device otype* out [[buffer(1)]], \
const device size_t& in_size [[buffer(2)]], \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint grid_size [[threads_per_grid]], \
uint simd_per_group [[simdgroups_per_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
uint thread_group_id [[threadgroup_position_in_grid]]);
///////////////////////////////////////////////////////////////////////////////
@@ -170,11 +172,12 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
///////////////////////////////////////////////////////////////////////////////
// Adapter for instantiate_reduce_ops: instantiates all_reduce with identical
// input and output type, pasting name and tname into one kernel-name suffix
// and applying the op template to that type (op<type>).
#define instantiate_same_all_reduce_helper(name, tname, type, op) \
instantiate_all_reduce(name ##tname, type, type, op<type>)
instantiate_all_reduce(name##tname, type, type, op<type>)
// Same adapter shape as instantiate_same_all_reduce_helper, but forwards to
// instantiate_all_reduce_no_atomics (input and output type identical,
// op applied as op<type>).
#define instantiate_same_all_reduce_na_helper(name, tname, type, op) \
instantiate_all_reduce_no_atomics(name ##tname, type, type, op<type>)
instantiate_all_reduce_no_atomics(name##tname, type, type, op<type>)
// clang-format off
instantiate_reduce_ops(instantiate_same_all_reduce_helper, instantiate_reduce_helper_types)
instantiate_reduce_ops(instantiate_same_all_reduce_na_helper, instantiate_reduce_helper_64b)
@@ -182,4 +185,4 @@ instantiate_reduce_from_types(instantiate_all_reduce, and, bool, And)
instantiate_reduce_from_types(instantiate_all_reduce, or, bool, Or)
// special case bool with larger output type
instantiate_all_reduce(sumbool_, bool, uint32_t, Sum<uint32_t>)
instantiate_all_reduce(sumbool_, bool, uint32_t, Sum<uint32_t>) // clang-format on

View File

@@ -1,8 +1,8 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/backend/metal/kernels/reduction/utils.h"
#include "mlx/backend/metal/kernels/reduction/ops.h"
#include "mlx/backend/metal/kernels/reduction/reduce_inst.h"
#include "mlx/backend/metal/kernels/reduction/utils.h"
using namespace metal;
@@ -12,8 +12,8 @@ using namespace metal;
template <typename T, typename U, typename Op>
[[kernel]] void col_reduce_small(
const device T *in [[buffer(0)]],
device U *out [[buffer(1)]],
const device T* in [[buffer(0)]],
device U* out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& reduction_stride [[buffer(3)]],
const constant size_t& out_size [[buffer(4)]],
@@ -25,7 +25,6 @@ template <typename T, typename U, typename Op>
const constant size_t* non_col_strides [[buffer(10)]],
const constant int& non_col_ndim [[buffer(11)]],
uint tid [[thread_position_in_grid]]) {
// Appease the compiler
(void)out_size;
@@ -35,15 +34,16 @@ template <typename T, typename U, typename Op>
auto out_idx = tid;
in += elem_to_loc(
out_idx,
shape + non_col_ndim,
strides + non_col_ndim,
ndim - non_col_ndim);
out_idx,
shape + non_col_ndim,
strides + non_col_ndim,
ndim - non_col_ndim);
for(uint i = 0; i < non_col_reductions; i++) {
size_t in_idx = elem_to_loc(i, non_col_shapes, non_col_strides, non_col_ndim);
for (uint i = 0; i < non_col_reductions; i++) {
size_t in_idx =
elem_to_loc(i, non_col_shapes, non_col_strides, non_col_ndim);
for(uint j = 0; j < reduction_size; j++, in_idx += reduction_stride) {
for (uint j = 0; j < reduction_size; j++, in_idx += reduction_stride) {
U val = static_cast<U>(in[in_idx]);
total_val = op(total_val, val);
}
@@ -52,21 +52,21 @@ template <typename T, typename U, typename Op>
out[out_idx] = total_val;
}
// Explicitly instantiates col_reduce_small<itype, otype, op> and exposes it
// to the host as "col_reduce_small_<name>". The parameter list mirrors the
// kernel template: buffers 0-11 plus the thread position in grid.
#define instantiate_col_reduce_small(name, itype, otype, op) \
template [[host_name("col_reduce_small_" #name)]] \
[[kernel]] void col_reduce_small<itype, otype, op>( \
const device itype *in [[buffer(0)]], \
device otype *out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& reduction_stride [[buffer(3)]], \
const constant size_t& out_size [[buffer(4)]], \
const constant int* shape [[buffer(5)]], \
const constant size_t* strides [[buffer(6)]], \
const constant int& ndim [[buffer(7)]], \
const constant size_t& non_col_reductions [[buffer(8)]], \
const constant int* non_col_shapes [[buffer(9)]], \
const constant size_t* non_col_strides [[buffer(10)]], \
const constant int& non_col_ndim [[buffer(11)]], \
#define instantiate_col_reduce_small(name, itype, otype, op) \
template [[host_name("col_reduce_small_" #name)]] [[kernel]] void \
col_reduce_small<itype, otype, op>( \
const device itype* in [[buffer(0)]], \
device otype* out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& reduction_stride [[buffer(3)]], \
const constant size_t& out_size [[buffer(4)]], \
const constant int* shape [[buffer(5)]], \
const constant size_t* strides [[buffer(6)]], \
const constant int& ndim [[buffer(7)]], \
const constant size_t& non_col_reductions [[buffer(8)]], \
const constant int* non_col_shapes [[buffer(9)]], \
const constant size_t* non_col_strides [[buffer(10)]], \
const constant int& non_col_ndim [[buffer(11)]], \
uint tid [[thread_position_in_grid]]);
///////////////////////////////////////////////////////////////////////////////
@@ -112,39 +112,35 @@ METAL_FUNC U _contiguous_strided_reduce(
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
[[kernel]] void col_reduce_general(
const device T *in [[buffer(0)]],
device mlx_atomic<U> *out [[buffer(1)]],
const device T* in [[buffer(0)]],
device mlx_atomic<U>* out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& reduction_stride [[buffer(3)]],
const constant size_t& out_size [[buffer(4)]],
const constant int* shape [[buffer(5)]],
const constant size_t* strides [[buffer(6)]],
const constant int& ndim [[buffer(7)]],
threadgroup U *local_data [[threadgroup(0)]],
threadgroup U* local_data [[threadgroup(0)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint3 lsize [[threads_per_threadgroup]]) {
auto out_idx = tid.x * lsize.x + lid.x;
auto in_idx = elem_to_loc(
out_idx + tid.z * out_size,
shape,
strides,
ndim
);
auto in_idx = elem_to_loc(out_idx + tid.z * out_size, shape, strides, ndim);
Op op;
if(out_idx < out_size) {
if (out_idx < out_size) {
U val = _contiguous_strided_reduce<T, U, Op, N_READS>(
in,
local_data,
in_idx,
reduction_size,
reduction_stride,
tid.xy,
lid.xy,
lsize.xy);
in,
local_data,
in_idx,
reduction_size,
reduction_stride,
tid.xy,
lid.xy,
lsize.xy);
// Write out reduction results generated by threadgroups working on specific output element, contiguously.
// Write out reduction results generated by threadgroups working on specific
// output element, contiguously.
if (lid.y == 0) {
op.atomic_update(out, val, out_idx);
}
@@ -153,40 +149,36 @@ template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
[[kernel]] void col_reduce_general_no_atomics(
const device T *in [[buffer(0)]],
device U *out [[buffer(1)]],
const device T* in [[buffer(0)]],
device U* out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& reduction_stride [[buffer(3)]],
const constant size_t& out_size [[buffer(4)]],
const constant int* shape [[buffer(5)]],
const constant size_t* strides [[buffer(6)]],
const constant int& ndim [[buffer(7)]],
threadgroup U *local_data [[threadgroup(0)]],
threadgroup U* local_data [[threadgroup(0)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint3 gid [[thread_position_in_grid]],
uint3 lsize [[threads_per_threadgroup]],
uint3 gsize [[threads_per_grid]]) {
auto out_idx = tid.x * lsize.x + lid.x;
auto in_idx = elem_to_loc(
out_idx + tid.z * out_size,
shape,
strides,
ndim
);
auto in_idx = elem_to_loc(out_idx + tid.z * out_size, shape, strides, ndim);
if(out_idx < out_size) {
if (out_idx < out_size) {
U val = _contiguous_strided_reduce<T, U, Op, N_READS>(
in,
local_data,
in_idx,
reduction_size,
reduction_stride,
tid.xy,
lid.xy,
lsize.xy);
in,
local_data,
in_idx,
reduction_size,
reduction_stride,
tid.xy,
lid.xy,
lsize.xy);
// Write out reduction results generated by threadgroups working on specific output element, contiguously.
// Write out reduction results generated by threadgroups working on specific
// output element, contiguously.
if (lid.y == 0) {
uint tgsize_y = ceildiv(gsize.y, lsize.y);
uint tgsize_z = ceildiv(gsize.z, lsize.z);
@@ -195,52 +187,56 @@ template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
}
}
// Explicitly instantiates col_reduce_general<itype, otype, op> and exposes it
// to the host as "col_reduce_general_<name>". The output buffer is typed as
// device mlx_atomic<otype>* and the kernel receives a threadgroup otype*
// scratch pointer bound at threadgroup(0).
#define instantiate_col_reduce_general(name, itype, otype, op) \
template [[host_name("col_reduce_general_" #name)]] \
[[kernel]] void col_reduce_general<itype, otype, op>( \
const device itype *in [[buffer(0)]], \
device mlx_atomic<otype> *out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& reduction_stride [[buffer(3)]], \
const constant size_t& out_size [[buffer(4)]], \
const constant int* shape [[buffer(5)]], \
const constant size_t* strides [[buffer(6)]], \
const constant int& ndim [[buffer(7)]], \
threadgroup otype *local_data [[threadgroup(0)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
#define instantiate_col_reduce_general(name, itype, otype, op) \
template [[host_name("col_reduce_general_" #name)]] [[kernel]] void \
col_reduce_general<itype, otype, op>( \
const device itype* in [[buffer(0)]], \
device mlx_atomic<otype>* out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& reduction_stride [[buffer(3)]], \
const constant size_t& out_size [[buffer(4)]], \
const constant int* shape [[buffer(5)]], \
const constant size_t* strides [[buffer(6)]], \
const constant int& ndim [[buffer(7)]], \
threadgroup otype* local_data [[threadgroup(0)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint3 lsize [[threads_per_threadgroup]]);
// Explicitly instantiates col_reduce_general_no_atomics<itype, otype, op> and
// exposes it to the host as "col_reduce_general_no_atomics_<name>".
// Compared with instantiate_col_reduce_general, the output is a plain
// device otype* and the kernel additionally takes gid
// (thread_position_in_grid) and gsize (threads_per_grid).
#define instantiate_col_reduce_general_no_atomics(name, itype, otype, op) \
template [[host_name("col_reduce_general_no_atomics_" #name)]] \
[[kernel]] void col_reduce_general_no_atomics<itype, otype, op>( \
const device itype *in [[buffer(0)]], \
device otype *out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& reduction_stride [[buffer(3)]], \
const constant size_t& out_size [[buffer(4)]], \
const constant int* shape [[buffer(5)]], \
const constant size_t* strides [[buffer(6)]], \
const constant int& ndim [[buffer(7)]], \
threadgroup otype *local_data [[threadgroup(0)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint3 gid [[thread_position_in_grid]], \
uint3 lsize [[threads_per_threadgroup]], \
uint3 gsize [[threads_per_grid]]);
#define instantiate_col_reduce_general_no_atomics(name, itype, otype, op) \
template \
[[host_name("col_reduce_general_no_atomics_" #name)]] [[kernel]] void \
col_reduce_general_no_atomics<itype, otype, op>( \
const device itype* in [[buffer(0)]], \
device otype* out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& reduction_stride [[buffer(3)]], \
const constant size_t& out_size [[buffer(4)]], \
const constant int* shape [[buffer(5)]], \
const constant size_t* strides [[buffer(6)]], \
const constant int& ndim [[buffer(7)]], \
threadgroup otype* local_data [[threadgroup(0)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint3 gid [[thread_position_in_grid]], \
uint3 lsize [[threads_per_threadgroup]], \
uint3 gsize [[threads_per_grid]]);
///////////////////////////////////////////////////////////////////////////////
// Instantiations
///////////////////////////////////////////////////////////////////////////////
// Adapter for instantiate_reduce_ops: for one (name, type, op) combination,
// instantiates both col_reduce_small and col_reduce_general with identical
// input and output type and op applied as op<type>.
#define instantiate_same_col_reduce_helper(name, tname, type, op) \
// clang-format off
#define instantiate_same_col_reduce_helper(name, tname, type, op) \
instantiate_col_reduce_small(name ##tname, type, type, op<type>) \
instantiate_col_reduce_general(name ##tname, type, type, op<type>)
instantiate_col_reduce_general(name ##tname, type, type, op<type>) // clang-format on
// clang-format off
// Same adapter shape as instantiate_same_col_reduce_helper, but pairs
// col_reduce_small with col_reduce_general_no_atomics instead of the
// atomic col_reduce_general.
#define instantiate_same_col_reduce_na_helper(name, tname, type, op) \
instantiate_col_reduce_small(name ##tname, type, type, op<type>) \
instantiate_col_reduce_general_no_atomics(name ##tname, type, type, op<type>)
instantiate_col_reduce_small(name ##tname, type, type, op<type>) \
instantiate_col_reduce_general_no_atomics(name ##tname, type, type, op<type>) // clang-format on
// clang-format off
instantiate_reduce_ops(instantiate_same_col_reduce_helper, instantiate_reduce_helper_types)
instantiate_reduce_ops(instantiate_same_col_reduce_na_helper, instantiate_reduce_helper_64b)
@@ -250,4 +246,4 @@ instantiate_reduce_from_types(instantiate_col_reduce_general, or, bool, Or)
instantiate_col_reduce_small(sumbool_, bool, uint32_t, Sum<uint32_t>)
instantiate_reduce_from_types(instantiate_col_reduce_small, and, bool, And)
instantiate_reduce_from_types(instantiate_col_reduce_small, or, bool, Or)
instantiate_reduce_from_types(instantiate_col_reduce_small, or, bool, Or) // clang-format on

View File

@@ -1,8 +1,8 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/backend/metal/kernels/reduction/utils.h"
#include "mlx/backend/metal/kernels/reduction/ops.h"
#include "mlx/backend/metal/kernels/reduction/reduce_inst.h"
#include "mlx/backend/metal/kernels/reduction/utils.h"
using namespace metal;
@@ -12,22 +12,21 @@ using namespace metal;
// Fills the output buffer with the reduction operator's initial value:
// each thread stores Op::init at its own grid position, out[tid].
// There is no bounds check, so the dispatch is expected to launch exactly
// one thread per output element.
template <typename T, typename Op>
[[kernel]] void init_reduce(
device T *out [[buffer(0)]],
device T* out [[buffer(0)]],
uint tid [[thread_position_in_grid]]) {
out[tid] = Op::init;
}
// Explicitly instantiates init_reduce<otype, op> and exposes it to the host
// as "i<name>" (the literal "i" followed by the stringified name).
// NOTE(review): the instantiation spells the `out` binding as [[buffer(1)]]
// while the init_reduce template declares [[buffer(0)]] -- confirm which
// index the host side binds and make the two agree.
#define instantiate_init_reduce(name, otype, op) \
template [[host_name("i" #name)]] \
[[kernel]] void init_reduce<otype, op>( \
device otype *out [[buffer(1)]], \
uint tid [[thread_position_in_grid]]);
#define instantiate_init_reduce(name, otype, op) \
template [[host_name("i" #name)]] [[kernel]] void init_reduce<otype, op>( \
device otype * out [[buffer(1)]], uint tid [[thread_position_in_grid]]);
// Adapter for instantiate_reduce_ops: pastes name and tname into one
// kernel-name suffix and instantiates init_reduce for `type` with op<type>.
#define instantiate_init_reduce_helper(name, tname, type, op) \
instantiate_init_reduce(name ##tname, type, op<type>)
instantiate_init_reduce(name##tname, type, op<type>)
// Concrete init_reduce kernels: one per combination generated by
// instantiate_reduce_ops over the two type lists (both macros are defined
// outside this view), plus the And / Or operators on bool.
// clang-format off
instantiate_reduce_ops(instantiate_init_reduce_helper, instantiate_reduce_helper_types)
instantiate_reduce_ops(instantiate_init_reduce_helper, instantiate_reduce_helper_64b)
instantiate_init_reduce(andbool_, bool, And)
instantiate_init_reduce(orbool_, bool, Or)
instantiate_init_reduce(orbool_, bool, Or) // clang-format on

View File

@@ -1,8 +1,8 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/backend/metal/kernels/reduction/utils.h"
#include "mlx/backend/metal/kernels/reduction/ops.h"
#include "mlx/backend/metal/kernels/reduction/reduce_inst.h"
#include "mlx/backend/metal/kernels/reduction/utils.h"
using namespace metal;
@@ -13,8 +13,8 @@ using namespace metal;
// Each thread reduces for one output
template <typename T, typename U, typename Op>
[[kernel]] void row_reduce_general_small(
const device T *in [[buffer(0)]],
device U *out [[buffer(1)]],
const device T* in [[buffer(0)]],
device U* out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& out_size [[buffer(3)]],
const constant size_t& non_row_reductions [[buffer(4)]],
@@ -22,22 +22,21 @@ template <typename T, typename U, typename Op>
const constant size_t* strides [[buffer(6)]],
const constant int& ndim [[buffer(7)]],
uint lid [[thread_position_in_grid]]) {
Op op;
uint out_idx = lid;
if(out_idx >= out_size) {
if (out_idx >= out_size) {
return;
}
U total_val = Op::init;
for(short r = 0; r < short(non_row_reductions); r++) {
for (short r = 0; r < short(non_row_reductions); r++) {
uint in_idx = elem_to_loc(out_idx + r * out_size, shape, strides, ndim);
const device T * in_row = in + in_idx;
for(short i = 0; i < short(reduction_size); i++) {
const device T* in_row = in + in_idx;
for (short i = 0; i < short(reduction_size); i++) {
total_val = op(static_cast<U>(in_row[i]), total_val);
}
}
@@ -48,8 +47,8 @@ template <typename T, typename U, typename Op>
// Each simdgroup reduces for one output
template <typename T, typename U, typename Op>
[[kernel]] void row_reduce_general_med(
const device T *in [[buffer(0)]],
device U *out [[buffer(1)]],
const device T* in [[buffer(0)]],
device U* out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& out_size [[buffer(3)]],
const constant size_t& non_row_reductions [[buffer(4)]],
@@ -60,45 +59,42 @@ template <typename T, typename U, typename Op>
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_per_group [[dispatch_simdgroups_per_threadgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
Op op;
uint out_idx = simd_per_group * tid + simd_group_id;
if(out_idx >= out_size) {
if (out_idx >= out_size) {
return;
}
U total_val = Op::init;
if(short(non_row_reductions) == 1) {
if (short(non_row_reductions) == 1) {
uint in_idx = elem_to_loc(out_idx, shape, strides, ndim);
const device T * in_row = in + in_idx;
const device T* in_row = in + in_idx;
for(short i = simd_lane_id; i < short(reduction_size); i += 32) {
for (short i = simd_lane_id; i < short(reduction_size); i += 32) {
total_val = op(static_cast<U>(in_row[i]), total_val);
}
}
else if (short(non_row_reductions) >= 32) {
for(short r = simd_lane_id; r < short(non_row_reductions); r+=32) {
for (short r = simd_lane_id; r < short(non_row_reductions); r += 32) {
uint in_idx = elem_to_loc(out_idx + r * out_size, shape, strides, ndim);
const device T * in_row = in + in_idx;
const device T* in_row = in + in_idx;
for(short i = 0; i < short(reduction_size); i++) {
for (short i = 0; i < short(reduction_size); i++) {
total_val = op(static_cast<U>(in_row[i]), total_val);
}
}
}
else {
const short n_reductions = short(reduction_size) * short(non_row_reductions);
const short reductions_per_thread = (n_reductions + simd_size - 1) / simd_size;
const short n_reductions =
short(reduction_size) * short(non_row_reductions);
const short reductions_per_thread =
(n_reductions + simd_size - 1) / simd_size;
const short r_st = simd_lane_id / reductions_per_thread;
const short r_ed = short(non_row_reductions);
@@ -108,54 +104,50 @@ template <typename T, typename U, typename Op>
const short i_ed = short(reduction_size);
const short i_jump = reductions_per_thread;
if(r_st < r_jump) {
for(short r = r_st; r < r_ed; r += r_jump) {
if (r_st < r_jump) {
for (short r = r_st; r < r_ed; r += r_jump) {
uint in_idx = elem_to_loc(out_idx + r * out_size, shape, strides, ndim);
const device T * in_row = in + in_idx;
const device T* in_row = in + in_idx;
for(short i = i_st; i < i_ed; i += i_jump) {
for (short i = i_st; i < i_ed; i += i_jump) {
total_val = op(static_cast<U>(in_row[i]), total_val);
}
}
}
}
total_val = op.simd_reduce(total_val);
if(simd_lane_id == 0) {
if (simd_lane_id == 0) {
out[out_idx] = total_val;
}
}
// Explicitly instantiates two kernels for one (itype, otype, op) triple:
//   row_reduce_general_small -> host name "row_reduce_general_small_<name>"
//   row_reduce_general_med   -> host name "row_reduce_general_med_<name>"
// Both share buffers 0-7. The small variant is indexed by
// thread_position_in_grid; the med variant instead takes the threadgroup
// position and the simdgroup lane/count/index attributes.
#define instantiate_row_reduce_small(name, itype, otype, op) \
template[[host_name("row_reduce_general_small_" #name)]] \
[[kernel]] void row_reduce_general_small<itype, otype, op>( \
const device itype *in [[buffer(0)]], \
device otype *out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& out_size [[buffer(3)]], \
const constant size_t& non_row_reductions [[buffer(4)]], \
const constant int* shape [[buffer(5)]], \
const constant size_t* strides [[buffer(6)]], \
const constant int& ndim [[buffer(7)]], \
uint lid [[thread_position_in_grid]]); \
template[[host_name("row_reduce_general_med_" #name)]] \
[[kernel]] void row_reduce_general_med<itype, otype, op>( \
const device itype *in [[buffer(0)]], \
device otype *out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& out_size [[buffer(3)]], \
const constant size_t& non_row_reductions [[buffer(4)]], \
const constant int* shape [[buffer(5)]], \
const constant size_t* strides [[buffer(6)]], \
const constant int& ndim [[buffer(7)]], \
uint tid [[threadgroup_position_in_grid]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_per_group [[dispatch_simdgroups_per_threadgroup]], \
#define instantiate_row_reduce_small(name, itype, otype, op) \
template [[host_name("row_reduce_general_small_" #name)]] [[kernel]] void \
row_reduce_general_small<itype, otype, op>( \
const device itype* in [[buffer(0)]], \
device otype* out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& out_size [[buffer(3)]], \
const constant size_t& non_row_reductions [[buffer(4)]], \
const constant int* shape [[buffer(5)]], \
const constant size_t* strides [[buffer(6)]], \
const constant int& ndim [[buffer(7)]], \
uint lid [[thread_position_in_grid]]); \
template [[host_name("row_reduce_general_med_" #name)]] [[kernel]] void \
row_reduce_general_med<itype, otype, op>( \
const device itype* in [[buffer(0)]], \
device otype* out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& out_size [[buffer(3)]], \
const constant size_t& non_row_reductions [[buffer(4)]], \
const constant int* shape [[buffer(5)]], \
const constant size_t* strides [[buffer(6)]], \
const constant int& ndim [[buffer(7)]], \
uint tid [[threadgroup_position_in_grid]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_per_group [[dispatch_simdgroups_per_threadgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
///////////////////////////////////////////////////////////////////////////////
@@ -217,10 +209,10 @@ METAL_FUNC U per_thread_row_reduce(
return total_val;
}
template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
[[kernel]] void row_reduce_general(
const device T *in [[buffer(0)]],
device mlx_atomic<U> *out [[buffer(1)]],
const device T* in [[buffer(0)]],
device mlx_atomic<U>* out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& out_size [[buffer(3)]],
const constant size_t& non_row_reductions [[buffer(4)]],
@@ -233,25 +225,33 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_per_group [[simdgroups_per_threadgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
(void)non_row_reductions;
Op op;
threadgroup U local_vals[simd_size];
U total_val = per_thread_row_reduce<T, U, Op, N_READS>(in, reduction_size, out_size, shape, strides, ndim, lsize.x, lid.x, tid.xy);
U total_val = per_thread_row_reduce<T, U, Op, N_READS>(
in,
reduction_size,
out_size,
shape,
strides,
ndim,
lsize.x,
lid.x,
tid.xy);
total_val = op.simd_reduce(total_val);
// Prepare next level
if (simd_lane_id == 0) {
local_vals[simd_group_id] = total_val;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Reduction within thread group
// Only needed if multiple simd groups
if(reduction_size > simd_size) {
if (reduction_size > simd_size) {
total_val = lid.x < simd_per_group ? local_vals[lid.x] : op.init;
total_val = op.simd_reduce(total_val);
}
@@ -261,10 +261,10 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
}
}
template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
[[kernel]] void row_reduce_general_no_atomics(
const device T *in [[buffer(0)]],
device U *out [[buffer(1)]],
const device T* in [[buffer(0)]],
device U* out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& out_size [[buffer(3)]],
const constant size_t& non_row_reductions [[buffer(4)]],
@@ -278,16 +278,24 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_per_group [[simdgroups_per_threadgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
(void)non_row_reductions;
Op op;
threadgroup U local_vals[simd_size];
U total_val = per_thread_row_reduce<T, U, Op, N_READS>(in, reduction_size, out_size, shape, strides, ndim, lsize.x, lid.x, tid.xy);
U total_val = per_thread_row_reduce<T, U, Op, N_READS>(
in,
reduction_size,
out_size,
shape,
strides,
ndim,
lsize.x,
lid.x,
tid.xy);
// Reduction within simd group - simd_add isn't supported for int64 types
for (uint16_t i = simd_size/2; i > 0; i /= 2) {
for (uint16_t i = simd_size / 2; i > 0; i /= 2) {
total_val = op(total_val, simd_shuffle_down(total_val, i));
}
@@ -299,9 +307,9 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
// Reduction within thread group
// Only needed if thread group has multiple simd groups
if(ceildiv(reduction_size, N_READS) > simd_size) {
if (ceildiv(reduction_size, N_READS) > simd_size) {
total_val = lid.x < simd_per_group ? local_vals[lid.x] : op.init;
for (uint16_t i = simd_size/2; i > 0; i /= 2) {
for (uint16_t i = simd_size / 2; i > 0; i /= 2) {
total_val = op(total_val, simd_shuffle_down(total_val, i));
}
}
@@ -311,61 +319,60 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
}
}
#define instantiate_row_reduce_general(name, itype, otype, op) \
instantiate_row_reduce_small(name, itype, otype, op) \
template [[host_name("row_reduce_general_" #name)]] \
[[kernel]] void row_reduce_general<itype, otype, op>( \
const device itype *in [[buffer(0)]], \
device mlx_atomic<otype> *out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& out_size [[buffer(3)]], \
const constant size_t& non_row_reductions [[buffer(4)]], \
const constant int* shape [[buffer(5)]], \
const constant size_t* strides [[buffer(6)]], \
const constant int& ndim [[buffer(7)]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint3 lsize [[threads_per_threadgroup]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_per_group [[simdgroups_per_threadgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
#define instantiate_row_reduce_general_no_atomics(name, itype, otype, op) \
instantiate_row_reduce_small(name, itype, otype, op) \
template [[host_name("row_reduce_general_no_atomics_" #name)]] \
[[kernel]] void row_reduce_general_no_atomics<itype, otype, op>( \
const device itype *in [[buffer(0)]], \
device otype *out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& out_size [[buffer(3)]], \
const constant size_t& non_row_reductions [[buffer(4)]], \
const constant int* shape [[buffer(5)]], \
const constant size_t* strides [[buffer(6)]], \
const constant int& ndim [[buffer(7)]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint3 lsize [[threads_per_threadgroup]], \
uint3 gsize [[threads_per_grid]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_per_group [[simdgroups_per_threadgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
#define instantiate_row_reduce_general(name, itype, otype, op) \
instantiate_row_reduce_small(name, itype, otype, op) template \
[[host_name("row_reduce_general_" #name)]] [[kernel]] void \
row_reduce_general<itype, otype, op>( \
const device itype* in [[buffer(0)]], \
device mlx_atomic<otype>* out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& out_size [[buffer(3)]], \
const constant size_t& non_row_reductions [[buffer(4)]], \
const constant int* shape [[buffer(5)]], \
const constant size_t* strides [[buffer(6)]], \
const constant int& ndim [[buffer(7)]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint3 lsize [[threads_per_threadgroup]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_per_group [[simdgroups_per_threadgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
#define instantiate_row_reduce_general_no_atomics(name, itype, otype, op) \
instantiate_row_reduce_small(name, itype, otype, op) template \
[[host_name("row_reduce_general_no_atomics_" #name)]] [[kernel]] void \
row_reduce_general_no_atomics<itype, otype, op>( \
const device itype* in [[buffer(0)]], \
device otype* out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& out_size [[buffer(3)]], \
const constant size_t& non_row_reductions [[buffer(4)]], \
const constant int* shape [[buffer(5)]], \
const constant size_t* strides [[buffer(6)]], \
const constant int& ndim [[buffer(7)]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint3 lsize [[threads_per_threadgroup]], \
uint3 gsize [[threads_per_grid]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_per_group [[simdgroups_per_threadgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
///////////////////////////////////////////////////////////////////////////////
// Instantiations
///////////////////////////////////////////////////////////////////////////////
#define instantiate_same_row_reduce_helper(name, tname, type, op) \
instantiate_row_reduce_general(name ##tname, type, type, op<type>)
instantiate_row_reduce_general(name##tname, type, type, op<type>)
#define instantiate_same_row_reduce_na_helper(name, tname, type, op) \
instantiate_row_reduce_general_no_atomics(name ##tname, type, type, op<type>)
instantiate_row_reduce_general_no_atomics(name##tname, type, type, op<type>)
// clang-format off
instantiate_reduce_ops(instantiate_same_row_reduce_helper, instantiate_reduce_helper_types)
instantiate_reduce_ops(instantiate_same_row_reduce_na_helper, instantiate_reduce_helper_64b)
instantiate_reduce_from_types(instantiate_row_reduce_general, and, bool, And)
instantiate_reduce_from_types(instantiate_row_reduce_general, or, bool, Or)
instantiate_row_reduce_general(sumbool_, bool, uint32_t, Sum<uint32_t>)
instantiate_row_reduce_general(sumbool_, bool, uint32_t, Sum<uint32_t>) // clang-format on

View File

@@ -237,13 +237,17 @@ template <typename T, int N_READS = RMS_N_READS>
gw += gid * axis_size + lid * N_READS;
if (lid * N_READS + N_READS <= axis_size) {
for (int i = 0; i < N_READS; i++) {
gx[i] = static_cast<T>(thread_g[i] * thread_w[i] * normalizer - thread_x[i] * meangwx * normalizer3);
gx[i] = static_cast<T>(
thread_g[i] * thread_w[i] * normalizer -
thread_x[i] * meangwx * normalizer3);
gw[i] = static_cast<T>(thread_g[i] * thread_x[i] * normalizer);
}
} else {
for (int i = 0; i < N_READS; i++) {
if ((lid * N_READS + i) < axis_size) {
gx[i] = static_cast<T>(thread_g[i] * thread_w[i] * normalizer - thread_x[i] * meangwx * normalizer3);
gx[i] = static_cast<T>(
thread_g[i] * thread_w[i] * normalizer -
thread_x[i] * meangwx * normalizer3);
gw[i] = static_cast<T>(thread_g[i] * thread_x[i] * normalizer);
}
}
@@ -342,7 +346,8 @@ template <typename T, int N_READS = RMS_N_READS>
float wi = w[w_stride * (i + r)];
float gi = g[i + r];
gx[i + r] = static_cast<T>(gi * wi * normalizer - xi * meangwx * normalizer3);
gx[i + r] =
static_cast<T>(gi * wi * normalizer - xi * meangwx * normalizer3);
gw[i + r] = static_cast<T>(gi * xi * normalizer);
}
} else {
@@ -352,7 +357,8 @@ template <typename T, int N_READS = RMS_N_READS>
float wi = w[w_stride * (i + r)];
float gi = g[i + r];
gx[i + r] = static_cast<T>(gi * wi * normalizer - xi * meangwx * normalizer3);
gx[i + r] =
static_cast<T>(gi * wi * normalizer - xi * meangwx * normalizer3);
gw[i + r] = static_cast<T>(gi * xi * normalizer);
}
}
@@ -431,5 +437,4 @@ template <typename T, int N_READS = RMS_N_READS>
instantiate_rms(float32, float)
instantiate_rms(float16, half)
instantiate_rms(bfloat16, bfloat16_t)
// clang-format on
instantiate_rms(bfloat16, bfloat16_t) // clang-format on

View File

@@ -7,8 +7,8 @@
template <typename T, bool traditional, bool forward>
[[kernel]] void rope(
const device T *in [[buffer(0)]],
device T * out [[buffer(1)]],
const device T* in [[buffer(0)]],
device T* out [[buffer(1)]],
constant const size_t strides[3],
constant const size_t out_strides[3],
constant const int& offset,
@@ -20,12 +20,15 @@ template <typename T, bool traditional, bool forward>
uint in_index_1, in_index_2;
uint out_index_1, out_index_2;
if (traditional) {
out_index_1 = 2 * pos.x * out_strides[2] + pos.y * out_strides[1] + pos.z * out_strides[0];
out_index_1 = 2 * pos.x * out_strides[2] + pos.y * out_strides[1] +
pos.z * out_strides[0];
out_index_2 = out_index_1 + 1;
in_index_1 = 2 * pos.x * strides[2] + pos.y * strides[1] + pos.z * strides[0];
in_index_1 =
2 * pos.x * strides[2] + pos.y * strides[1] + pos.z * strides[0];
in_index_2 = in_index_1 + strides[2];
} else {
out_index_1 = pos.x * out_strides[2] + pos.y * out_strides[1] + pos.z * out_strides[0];
out_index_1 = pos.x * out_strides[2] + pos.y * out_strides[1] +
pos.z * out_strides[0];
out_index_2 = out_index_1 + grid.x * out_strides[2];
in_index_1 = pos.x * strides[2] + pos.y * strides[1] + pos.z * strides[0];
in_index_2 = in_index_1 + grid.x * strides[2];
@@ -57,18 +60,19 @@ template <typename T, bool traditional, bool forward>
}
#define instantiate_rope(name, type, traditional, forward) \
template [[host_name("rope_" #name)]] \
[[kernel]] void rope<type, traditional, forward>( \
const device type* in [[buffer(0)]], \
device type* out [[buffer(1)]], \
constant const size_t strides[3], \
constant const size_t out_strides[3], \
constant const int& offset, \
constant const float& base, \
constant const float& scale, \
uint3 pos [[thread_position_in_grid]], \
uint3 grid [[threads_per_grid]]);
template [[host_name("rope_" #name)]] [[kernel]] void \
rope<type, traditional, forward>( \
const device type* in [[buffer(0)]], \
device type* out [[buffer(1)]], \
constant const size_t strides[3], \
constant const size_t out_strides[3], \
constant const int& offset, \
constant const float& base, \
constant const float& scale, \
uint3 pos [[thread_position_in_grid]], \
uint3 grid [[threads_per_grid]]);
// clang-format off
instantiate_rope(traditional_float16, half, true, true)
instantiate_rope(traditional_bfloat16, bfloat16_t, true, true)
instantiate_rope(traditional_float32, float, true, true)
@@ -80,4 +84,4 @@ instantiate_rope(vjp_traditional_bfloat16, bfloat16_t, true, false)
instantiate_rope(vjp_traditional_float32, float, true, false)
instantiate_rope(vjp_float16, half, false, false)
instantiate_rope(vjp_bfloat16, bfloat16_t, false, false)
instantiate_rope(vjp_float32, float, false, false)
instantiate_rope(vjp_float32, float, false, false) // clang-format on

View File

@@ -1,451 +1,551 @@
#include <metal_stdlib>
#include <metal_simdgroup>
#include <metal_stdlib>
#include "mlx/backend/metal/kernels/scaled_dot_product_attention_params.h"
using namespace metal;
template<typename T, typename T2, typename T4, uint16_t TILE_SIZE_CONST, uint16_t NSIMDGROUPS>
[[kernel]] void fast_inference_sdpa_compute_partials_template(const device T *Q [[buffer(0)]],
const device T *K [[buffer(1)]],
const device T *V [[buffer(2)]],
const device uint64_t& L [[buffer(3)]],
const device MLXScaledDotProductAttentionParams& params [[buffer(4)]],
device float* O_partials [[buffer(5)]],
device float* p_lse [[buffer(6)]],
device float* p_maxes [[buffer(7)]],
threadgroup T* threadgroup_block [[threadgroup(0)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]]) {
constexpr const size_t DK = 128;
constexpr const ulong SIMDGROUP_MATRIX_LOAD_FACTOR = 8;
constexpr const size_t THREADS_PER_SIMDGROUP = 32;
constexpr const uint iter_offset = NSIMDGROUPS * 4;
const bool is_gqa = params.N_KV_HEADS != params.N_Q_HEADS;
uint kv_head_offset_factor = tid.x;
if(is_gqa) {
int q_kv_head_ratio = params.N_Q_HEADS / params.N_KV_HEADS;
kv_head_offset_factor = tid.x / q_kv_head_ratio;
template <
typename T,
typename T2,
typename T4,
uint16_t TILE_SIZE_CONST,
uint16_t NSIMDGROUPS>
[[kernel]] void fast_inference_sdpa_compute_partials_template(
const device T* Q [[buffer(0)]],
const device T* K [[buffer(1)]],
const device T* V [[buffer(2)]],
const device uint64_t& L [[buffer(3)]],
const device MLXScaledDotProductAttentionParams& params [[buffer(4)]],
device float* O_partials [[buffer(5)]],
device float* p_lse [[buffer(6)]],
device float* p_maxes [[buffer(7)]],
threadgroup T* threadgroup_block [[threadgroup(0)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]]) {
constexpr const size_t DK = 128;
constexpr const ulong SIMDGROUP_MATRIX_LOAD_FACTOR = 8;
constexpr const size_t THREADS_PER_SIMDGROUP = 32;
constexpr const uint iter_offset = NSIMDGROUPS * 4;
const bool is_gqa = params.N_KV_HEADS != params.N_Q_HEADS;
uint kv_head_offset_factor = tid.x;
if (is_gqa) {
int q_kv_head_ratio = params.N_Q_HEADS / params.N_KV_HEADS;
kv_head_offset_factor = tid.x / q_kv_head_ratio;
}
constexpr const uint16_t P_VEC4 = TILE_SIZE_CONST / NSIMDGROUPS / 4;
constexpr const size_t MATRIX_LOADS_PER_SIMDGROUP =
TILE_SIZE_CONST / (SIMDGROUP_MATRIX_LOAD_FACTOR * NSIMDGROUPS);
constexpr const size_t MATRIX_COLS = DK / SIMDGROUP_MATRIX_LOAD_FACTOR;
constexpr const uint totalSmemV = SIMDGROUP_MATRIX_LOAD_FACTOR *
SIMDGROUP_MATRIX_LOAD_FACTOR * (MATRIX_LOADS_PER_SIMDGROUP + 1) *
NSIMDGROUPS;
threadgroup T4* smemFlush = (threadgroup T4*)threadgroup_block;
#pragma clang loop unroll(full)
for (uint i = 0; i < 8; i++) {
smemFlush
[simd_lane_id + simd_group_id * THREADS_PER_SIMDGROUP +
i * NSIMDGROUPS * THREADS_PER_SIMDGROUP] = T4(0.f);
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// TODO: multiple query sequence length for speculative decoding
const uint tgroup_query_head_offset =
tid.x * DK + tid.z * (params.N_Q_HEADS * DK);
const uint tgroup_k_head_offset = kv_head_offset_factor * DK * L;
const uint tgroup_k_tile_offset = tid.y * TILE_SIZE_CONST * DK;
const uint tgroup_k_batch_offset = tid.z * L * params.N_KV_HEADS * DK;
const device T* baseK =
K + tgroup_k_batch_offset + tgroup_k_tile_offset + tgroup_k_head_offset;
const device T* baseQ = Q + tgroup_query_head_offset;
device T4* simdgroupQueryData = (device T4*)baseQ;
constexpr const size_t ACCUM_PER_GROUP = TILE_SIZE_CONST / NSIMDGROUPS;
float threadAccum[ACCUM_PER_GROUP];
#pragma clang loop unroll(full)
for (size_t threadAccumIndex = 0; threadAccumIndex < ACCUM_PER_GROUP;
threadAccumIndex++) {
threadAccum[threadAccumIndex] = -INFINITY;
}
uint KROW_ACCUM_INDEX = 0;
const int32_t SEQUENCE_LENGTH_LESS_TILE_SIZE = L - TILE_SIZE_CONST;
const bool LAST_TILE = (tid.y + 1) * TILE_SIZE_CONST >= L;
const bool LAST_TILE_ALIGNED =
(SEQUENCE_LENGTH_LESS_TILE_SIZE == int32_t(tid.y * TILE_SIZE_CONST));
T4 thread_data_x4;
T4 thread_data_y4;
if (!LAST_TILE || LAST_TILE_ALIGNED) {
thread_data_x4 = *(simdgroupQueryData + simd_lane_id);
#pragma clang loop unroll(full)
for (size_t KROW = simd_group_id; KROW < TILE_SIZE_CONST;
KROW += NSIMDGROUPS) {
const uint KROW_OFFSET = KROW * DK;
const device T* baseKRow = baseK + KROW_OFFSET;
device T4* keysData = (device T4*)baseKRow;
thread_data_y4 = *(keysData + simd_lane_id);
T kq_scalar = dot(thread_data_x4, thread_data_y4);
threadAccum[KROW_ACCUM_INDEX] = float(kq_scalar);
KROW_ACCUM_INDEX++;
}
constexpr const uint16_t P_VEC4 = TILE_SIZE_CONST / NSIMDGROUPS / 4;
constexpr const size_t MATRIX_LOADS_PER_SIMDGROUP = TILE_SIZE_CONST / (SIMDGROUP_MATRIX_LOAD_FACTOR * NSIMDGROUPS);
constexpr const size_t MATRIX_COLS = DK / SIMDGROUP_MATRIX_LOAD_FACTOR;
constexpr const uint totalSmemV = SIMDGROUP_MATRIX_LOAD_FACTOR * SIMDGROUP_MATRIX_LOAD_FACTOR * (MATRIX_LOADS_PER_SIMDGROUP + 1) * NSIMDGROUPS;
} else {
thread_data_x4 = *(simdgroupQueryData + simd_lane_id);
const uint START_ROW = tid.y * TILE_SIZE_CONST;
const device T* baseKThisHead =
K + tgroup_k_batch_offset + tgroup_k_head_offset;
threadgroup T4* smemFlush = (threadgroup T4*)threadgroup_block;
#pragma clang loop unroll(full)
for(uint i = 0; i < 8; i++) {
smemFlush[simd_lane_id + simd_group_id * THREADS_PER_SIMDGROUP + i * NSIMDGROUPS * THREADS_PER_SIMDGROUP] = T4(0.f);
for (size_t KROW = START_ROW + simd_group_id; KROW < L;
KROW += NSIMDGROUPS) {
const uint KROW_OFFSET = KROW * DK;
const device T* baseKRow = baseKThisHead + KROW_OFFSET;
device T4* keysData = (device T4*)baseKRow;
thread_data_y4 = *(keysData + simd_lane_id);
T kq_scalar = dot(thread_data_x4, thread_data_y4);
threadAccum[KROW_ACCUM_INDEX] = float(kq_scalar);
KROW_ACCUM_INDEX++;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// TODO: multiple query sequence length for speculative decoding
const uint tgroup_query_head_offset = tid.x * DK + tid.z * (params.N_Q_HEADS * DK);
}
threadgroup float* smemP = (threadgroup float*)threadgroup_block;
const uint tgroup_k_head_offset = kv_head_offset_factor * DK * L;
const uint tgroup_k_tile_offset = tid.y * TILE_SIZE_CONST * DK;
const uint tgroup_k_batch_offset = tid.z * L * params.N_KV_HEADS * DK;
const device T* baseK = K + tgroup_k_batch_offset + tgroup_k_tile_offset + tgroup_k_head_offset;
const device T* baseQ = Q + tgroup_query_head_offset;
device T4* simdgroupQueryData = (device T4*)baseQ;
constexpr const size_t ACCUM_PER_GROUP = TILE_SIZE_CONST / NSIMDGROUPS;
float threadAccum[ACCUM_PER_GROUP];
#pragma clang loop unroll(full)
for(size_t threadAccumIndex = 0; threadAccumIndex < ACCUM_PER_GROUP; threadAccumIndex++) {
threadAccum[threadAccumIndex] = -INFINITY;
#pragma clang loop unroll(full)
for (size_t i = 0; i < P_VEC4; i++) {
thread_data_x4 =
T4(threadAccum[4 * i],
threadAccum[4 * i + 1],
threadAccum[4 * i + 2],
threadAccum[4 * i + 3]);
simdgroup_barrier(mem_flags::mem_none);
thread_data_y4 = simd_sum(thread_data_x4);
if (simd_lane_id == 0) {
const uint base_smem_p_offset = i * iter_offset + simd_group_id;
smemP[base_smem_p_offset + NSIMDGROUPS * 0] = float(thread_data_y4.x);
smemP[base_smem_p_offset + NSIMDGROUPS * 1] = float(thread_data_y4.y);
smemP[base_smem_p_offset + NSIMDGROUPS * 2] = float(thread_data_y4.z);
smemP[base_smem_p_offset + NSIMDGROUPS * 3] = float(thread_data_y4.w);
}
}
uint KROW_ACCUM_INDEX = 0;
threadgroup_barrier(mem_flags::mem_threadgroup);
const int32_t SEQUENCE_LENGTH_LESS_TILE_SIZE = L - TILE_SIZE_CONST;
const bool LAST_TILE = (tid.y + 1) * TILE_SIZE_CONST >= L;
const bool LAST_TILE_ALIGNED = (SEQUENCE_LENGTH_LESS_TILE_SIZE == int32_t(tid.y * TILE_SIZE_CONST));
float groupMax;
float lse = 0.f;
T4 thread_data_x4;
T4 thread_data_y4;
if(!LAST_TILE || LAST_TILE_ALIGNED) {
thread_data_x4 = *(simdgroupQueryData + simd_lane_id);
#pragma clang loop unroll(full)
for(size_t KROW = simd_group_id; KROW < TILE_SIZE_CONST; KROW += NSIMDGROUPS) {
const uint KROW_OFFSET = KROW * DK;
const device T* baseKRow = baseK + KROW_OFFSET;
device T4* keysData = (device T4*)baseKRow;
thread_data_y4 = *(keysData + simd_lane_id);
T kq_scalar = dot(thread_data_x4, thread_data_y4);
threadAccum[KROW_ACCUM_INDEX] = float(kq_scalar);
KROW_ACCUM_INDEX++;
constexpr const size_t THREADS_PER_THREADGROUP_TIMES_4 = 4 * 32;
constexpr const size_t ACCUM_ARRAY_LENGTH =
TILE_SIZE_CONST / THREADS_PER_THREADGROUP_TIMES_4 + 1;
float4 pvals[ACCUM_ARRAY_LENGTH];
#pragma clang loop unroll(full)
for (uint accum_array_iter = 0; accum_array_iter < ACCUM_ARRAY_LENGTH;
accum_array_iter++) {
pvals[accum_array_iter] = float4(-INFINITY);
}
if (TILE_SIZE_CONST == 64) {
threadgroup float2* smemPtrFlt2 = (threadgroup float2*)threadgroup_block;
float2 vals = smemPtrFlt2[simd_lane_id];
vals *= params.INV_ALPHA;
float maxval = max(vals.x, vals.y);
simdgroup_barrier(mem_flags::mem_none);
groupMax = simd_max(maxval);
float2 expf_shifted = exp(vals - groupMax);
float sumExpLocal = expf_shifted.x + expf_shifted.y;
simdgroup_barrier(mem_flags::mem_none);
float tgroupExpSum = simd_sum(sumExpLocal);
lse = log(tgroupExpSum);
float2 local_p_hat = expf_shifted / tgroupExpSum;
pvals[0].x = local_p_hat.x;
pvals[0].y = local_p_hat.y;
smemPtrFlt2[simd_lane_id] = float2(0.f);
}
constexpr const bool TILE_SIZE_LARGER_THAN_64 = TILE_SIZE_CONST > 64;
constexpr const int TILE_SIZE_ITERS_128 = TILE_SIZE_CONST / 128;
if (TILE_SIZE_LARGER_THAN_64) {
float maxval = -INFINITY;
threadgroup float4* smemPtrFlt4 = (threadgroup float4*)threadgroup_block;
#pragma clang loop unroll(full)
for (int i = 0; i < TILE_SIZE_ITERS_128; i++) {
float4 vals = smemPtrFlt4[simd_lane_id + i * THREADS_PER_SIMDGROUP];
vals *= params.INV_ALPHA;
pvals[i] = vals;
maxval = fmax3(vals.x, vals.y, maxval);
maxval = fmax3(vals.z, vals.w, maxval);
}
simdgroup_barrier(mem_flags::mem_none);
groupMax = simd_max(maxval);
float sumExpLocal = 0.f;
#pragma clang loop unroll(full)
for (int i = 0; i < TILE_SIZE_ITERS_128; i++) {
pvals[i] = exp(pvals[i] - groupMax);
sumExpLocal += pvals[i].x + pvals[i].y + pvals[i].z + pvals[i].w;
}
simdgroup_barrier(mem_flags::mem_none);
float tgroupExpSum = simd_sum(sumExpLocal);
lse = log(tgroupExpSum);
#pragma clang loop unroll(full)
for (int i = 0; i < TILE_SIZE_ITERS_128; i++) {
pvals[i] = pvals[i] / tgroupExpSum;
smemPtrFlt4[simd_lane_id + i * THREADS_PER_SIMDGROUP] = float4(0.f);
}
}
threadgroup T* smemV = (threadgroup T*)threadgroup_block;
const size_t v_batch_offset = tid.z * params.N_KV_HEADS * L * DK;
const size_t v_head_offset = kv_head_offset_factor * L * DK;
const size_t v_tile_offset = tid.y * TILE_SIZE_CONST * DK;
const size_t v_offset = v_batch_offset + v_head_offset + v_tile_offset;
device T* baseV = (device T*)V + v_offset;
threadgroup float* smemOpartial = (threadgroup float*)(smemV + totalSmemV);
if (!LAST_TILE || LAST_TILE_ALIGNED) {
#pragma clang loop unroll(full)
for (size_t col = 0; col < MATRIX_COLS; col++) {
uint matrix_load_loop_iter = 0;
constexpr const size_t TILE_SIZE_CONST_DIV_8 = TILE_SIZE_CONST / 8;
for (size_t tile_start = simd_group_id;
tile_start < TILE_SIZE_CONST_DIV_8;
tile_start += NSIMDGROUPS) {
simdgroup_matrix<T, 8, 8> tmp;
ulong simdgroup_matrix_offset =
matrix_load_loop_iter * NSIMDGROUPS * SIMDGROUP_MATRIX_LOAD_FACTOR +
simd_group_id * SIMDGROUP_MATRIX_LOAD_FACTOR;
ulong2 matrixOrigin =
ulong2(col * SIMDGROUP_MATRIX_LOAD_FACTOR, simdgroup_matrix_offset);
simdgroup_load(tmp, baseV, DK, matrixOrigin, true);
const ulong2 matrixOriginSmem = ulong2(simdgroup_matrix_offset, 0);
const ulong elemsPerRowSmem = TILE_SIZE_CONST;
simdgroup_store(tmp, smemV, elemsPerRowSmem, matrixOriginSmem, false);
matrix_load_loop_iter++;
};
threadgroup_barrier(mem_flags::mem_threadgroup);
if (TILE_SIZE_CONST == 64) {
T2 local_p_hat = T2(pvals[0].x, pvals[0].y);
uint loop_iter = 0;
threadgroup float* oPartialSmem =
smemOpartial + SIMDGROUP_MATRIX_LOAD_FACTOR * col;
#pragma clang loop unroll(full)
for (size_t row = simd_group_id; row < SIMDGROUP_MATRIX_LOAD_FACTOR;
row += NSIMDGROUPS) {
threadgroup T* smemV_row = smemV + (TILE_SIZE_CONST * row);
threadgroup T2* smemV2 = (threadgroup T2*)smemV_row;
T2 v_local = *(smemV2 + simd_lane_id);
T val = dot(local_p_hat, v_local);
simdgroup_barrier(mem_flags::mem_none);
T row_sum = simd_sum(val);
oPartialSmem[simd_group_id + loop_iter * NSIMDGROUPS] =
float(row_sum);
loop_iter++;
}
} else {
thread_data_x4 = *(simdgroupQueryData + simd_lane_id);
const uint START_ROW = tid.y * TILE_SIZE_CONST;
const device T* baseKThisHead = K + tgroup_k_batch_offset + tgroup_k_head_offset;
}
for(size_t KROW = START_ROW + simd_group_id; KROW < L; KROW += NSIMDGROUPS) {
const uint KROW_OFFSET = KROW * DK;
const device T* baseKRow = baseKThisHead + KROW_OFFSET;
device T4* keysData = (device T4*)baseKRow;
thread_data_y4 = *(keysData + simd_lane_id);
T kq_scalar = dot(thread_data_x4, thread_data_y4);
threadAccum[KROW_ACCUM_INDEX] = float(kq_scalar);
KROW_ACCUM_INDEX++;
if (TILE_SIZE_CONST > 64) {
constexpr const size_t TILE_SIZE_CONST_DIV_128 =
(TILE_SIZE_CONST + 1) / 128;
threadgroup float* oPartialSmem =
smemOpartial + SIMDGROUP_MATRIX_LOAD_FACTOR * col;
uint loop_iter = 0;
for (size_t row = simd_group_id; row < SIMDGROUP_MATRIX_LOAD_FACTOR;
row += NSIMDGROUPS) {
threadgroup T* smemV_row = smemV + (TILE_SIZE_CONST * row);
T row_sum = 0.f;
for (size_t i = 0; i < TILE_SIZE_CONST_DIV_128; i++) {
threadgroup T4* smemV2 = (threadgroup T4*)smemV_row;
T4 v_local = *(smemV2 + simd_lane_id + i * THREADS_PER_SIMDGROUP);
T4 p_local = T4(pvals[i]);
T val = dot(p_local, v_local);
row_sum += val;
}
simdgroup_barrier(mem_flags::mem_none);
row_sum = simd_sum(row_sum);
oPartialSmem[simd_group_id + loop_iter * NSIMDGROUPS] =
float(row_sum);
loop_iter++;
}
}
}
threadgroup float* smemP = (threadgroup float*)threadgroup_block;
} else {
const int32_t START_ROW = tid.y * TILE_SIZE_CONST;
const int32_t MAX_START_ROW = L - SIMDGROUP_MATRIX_LOAD_FACTOR + 1;
const device T* baseVThisHead = V + v_batch_offset + v_head_offset;
constexpr const int ROWS_PER_ITER = 8;
#pragma clang loop unroll(full)
for (size_t col = 0; col < MATRIX_COLS; col++) {
uint smem_col_index = simd_group_id * SIMDGROUP_MATRIX_LOAD_FACTOR;
int32_t tile_start;
for (tile_start =
START_ROW + simd_group_id * SIMDGROUP_MATRIX_LOAD_FACTOR;
tile_start < MAX_START_ROW;
tile_start += NSIMDGROUPS * SIMDGROUP_MATRIX_LOAD_FACTOR) {
simdgroup_matrix<T, 8, 8> tmp;
ulong2 matrixOrigin =
ulong2(col * SIMDGROUP_MATRIX_LOAD_FACTOR, tile_start);
simdgroup_load(
tmp, baseVThisHead, DK, matrixOrigin, /* transpose */ true);
const ulong2 matrixOriginSmem = ulong2(smem_col_index, 0);
constexpr const ulong elemsPerRowSmem = TILE_SIZE_CONST;
simdgroup_store(
tmp,
smemV,
elemsPerRowSmem,
matrixOriginSmem,
/* transpose */ false);
smem_col_index += NSIMDGROUPS * SIMDGROUP_MATRIX_LOAD_FACTOR;
};
#pragma clang loop unroll(full)
for(size_t i = 0; i < P_VEC4; i++) {
thread_data_x4 = T4(threadAccum[4 * i], threadAccum[4 * i + 1], threadAccum[4 * i + 2], threadAccum[4 * i + 3]);
simdgroup_barrier(mem_flags::mem_none);
thread_data_y4 = simd_sum(thread_data_x4);
if(simd_lane_id == 0) {
const uint base_smem_p_offset = i * iter_offset + simd_group_id;
smemP[base_smem_p_offset + NSIMDGROUPS * 0] = float(thread_data_y4.x);
smemP[base_smem_p_offset + NSIMDGROUPS * 1] = float(thread_data_y4.y);
smemP[base_smem_p_offset + NSIMDGROUPS * 2] = float(thread_data_y4.z);
smemP[base_smem_p_offset + NSIMDGROUPS * 3] = float(thread_data_y4.w);
tile_start =
((L / SIMDGROUP_MATRIX_LOAD_FACTOR) * SIMDGROUP_MATRIX_LOAD_FACTOR);
const int32_t INT_L = int32_t(L);
for (int row_index = tile_start + simd_group_id; row_index < INT_L;
row_index += NSIMDGROUPS) {
if (simd_lane_id < SIMDGROUP_MATRIX_LOAD_FACTOR) {
const uint elems_per_row_gmem = DK;
const uint col_index_v_gmem =
col * SIMDGROUP_MATRIX_LOAD_FACTOR + simd_lane_id;
const uint row_index_v_gmem = row_index;
const uint elems_per_row_smem = TILE_SIZE_CONST;
const uint col_index_v_smem = row_index % TILE_SIZE_CONST;
const uint row_index_v_smem = simd_lane_id;
const uint scalar_offset_gmem =
row_index_v_gmem * elems_per_row_gmem + col_index_v_gmem;
const uint scalar_offset_smem =
row_index_v_smem * elems_per_row_smem + col_index_v_smem;
T vdata = T(*(baseVThisHead + scalar_offset_gmem));
smemV[scalar_offset_smem] = vdata;
smem_col_index += NSIMDGROUPS;
}
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
threadgroup_barrier(mem_flags::mem_threadgroup);
float groupMax;
float lse = 0.f;
constexpr const size_t THREADS_PER_THREADGROUP_TIMES_4 = 4 * 32;
constexpr const size_t ACCUM_ARRAY_LENGTH = TILE_SIZE_CONST / THREADS_PER_THREADGROUP_TIMES_4 + 1;
float4 pvals[ACCUM_ARRAY_LENGTH];
#pragma clang loop unroll(full)
for(uint accum_array_iter = 0; accum_array_iter < ACCUM_ARRAY_LENGTH; accum_array_iter++) {
pvals[accum_array_iter] = float4(-INFINITY);
}
if (TILE_SIZE_CONST == 64) {
threadgroup float2* smemPtrFlt2 = (threadgroup float2*)threadgroup_block;
float2 vals = smemPtrFlt2[simd_lane_id];
vals *= params.INV_ALPHA;
float maxval = max(vals.x, vals.y);
simdgroup_barrier(mem_flags::mem_none);
groupMax = simd_max(maxval);
float2 expf_shifted = exp(vals - groupMax);
float sumExpLocal = expf_shifted.x + expf_shifted.y;
simdgroup_barrier(mem_flags::mem_none);
float tgroupExpSum = simd_sum(sumExpLocal);
lse = log(tgroupExpSum);
float2 local_p_hat = expf_shifted / tgroupExpSum;
pvals[0].x = local_p_hat.x;
pvals[0].y = local_p_hat.y;
smemPtrFlt2[simd_lane_id] = float2(0.f);
}
constexpr const bool TILE_SIZE_LARGER_THAN_64 = TILE_SIZE_CONST > 64;
constexpr const int TILE_SIZE_ITERS_128 = TILE_SIZE_CONST / 128;
if (TILE_SIZE_LARGER_THAN_64) {
float maxval = -INFINITY;
threadgroup float4* smemPtrFlt4 = (threadgroup float4*)threadgroup_block;
#pragma clang loop unroll(full)
for(int i = 0; i < TILE_SIZE_ITERS_128; i++) {
float4 vals = smemPtrFlt4[simd_lane_id + i * THREADS_PER_SIMDGROUP];
vals *= params.INV_ALPHA;
pvals[i] = vals;
maxval = fmax3(vals.x, vals.y, maxval);
maxval = fmax3(vals.z, vals.w, maxval);
if (TILE_SIZE_CONST == 64) {
T2 local_p_hat = T2(pvals[0].x, pvals[0].y);
threadgroup float* oPartialSmem =
smemOpartial + SIMDGROUP_MATRIX_LOAD_FACTOR * col;
for (size_t smem_row_index = simd_group_id;
smem_row_index < ROWS_PER_ITER;
smem_row_index += NSIMDGROUPS) {
threadgroup T* smemV_row = smemV + (TILE_SIZE_CONST * smem_row_index);
threadgroup T2* smemV2 = (threadgroup T2*)smemV_row;
T2 v_local = *(smemV2 + simd_lane_id);
T val = dot(local_p_hat, v_local);
simdgroup_barrier(mem_flags::mem_none);
T row_sum = simd_sum(val);
oPartialSmem[smem_row_index] = float(row_sum);
}
simdgroup_barrier(mem_flags::mem_none);
groupMax = simd_max(maxval);
}
float sumExpLocal = 0.f;
#pragma clang loop unroll(full)
for(int i = 0; i < TILE_SIZE_ITERS_128; i++) {
pvals[i] = exp(pvals[i] - groupMax);
sumExpLocal += pvals[i].x + pvals[i].y + pvals[i].z + pvals[i].w;
}
simdgroup_barrier(mem_flags::mem_none);
float tgroupExpSum = simd_sum(sumExpLocal);
lse = log(tgroupExpSum);
#pragma clang loop unroll(full)
for(int i = 0; i < TILE_SIZE_ITERS_128; i++) {
pvals[i] = pvals[i] / tgroupExpSum;
smemPtrFlt4[simd_lane_id + i * THREADS_PER_SIMDGROUP] = float4(0.f);
if (TILE_SIZE_CONST > 64) {
threadgroup float* oPartialSmem =
smemOpartial + SIMDGROUP_MATRIX_LOAD_FACTOR * col;
uint loop_count = 0;
for (size_t row_index = simd_group_id; row_index < ROWS_PER_ITER;
row_index += NSIMDGROUPS) {
T row_sum = 0.f;
for (size_t tile_iters = 0; tile_iters < TILE_SIZE_ITERS_128;
tile_iters++) {
threadgroup T* smemV_row = smemV + (TILE_SIZE_CONST * row_index);
threadgroup T4* smemV2 = (threadgroup T4*)smemV_row;
T4 v_local =
*(smemV2 + simd_lane_id + tile_iters * THREADS_PER_SIMDGROUP);
T4 p_local = T4(pvals[tile_iters]);
row_sum += dot(p_local, v_local);
}
simdgroup_barrier(mem_flags::mem_none);
row_sum = simd_sum(row_sum);
oPartialSmem[simd_group_id + NSIMDGROUPS * loop_count] =
float(row_sum);
loop_count++;
}
}
}
}
threadgroup T* smemV = (threadgroup T*)threadgroup_block;
threadgroup_barrier(mem_flags::mem_threadgroup);
const size_t v_batch_offset = tid.z * params.N_KV_HEADS * L * DK;
const size_t v_head_offset = kv_head_offset_factor * L * DK;
if (simd_group_id == 0) {
threadgroup float4* oPartialVec4 = (threadgroup float4*)smemOpartial;
float4 vals = *(oPartialVec4 + simd_lane_id);
device float* oPartialGmem =
O_partials + tid.x * DK * params.KV_TILES + tid.y * DK;
device float4* oPartialGmemVec4 = (device float4*)oPartialGmem;
oPartialGmemVec4[simd_lane_id] = vals;
}
const size_t v_tile_offset = tid.y * TILE_SIZE_CONST * DK;
const size_t v_offset = v_batch_offset + v_head_offset + v_tile_offset;
device T* baseV = (device T*)V + v_offset;
threadgroup float* smemOpartial = (threadgroup float*)(smemV + totalSmemV);
if (!LAST_TILE || LAST_TILE_ALIGNED) {
#pragma clang loop unroll(full)
for(size_t col = 0; col < MATRIX_COLS; col++) {
uint matrix_load_loop_iter = 0;
constexpr const size_t TILE_SIZE_CONST_DIV_8 = TILE_SIZE_CONST / 8;
for(size_t tile_start = simd_group_id; tile_start < TILE_SIZE_CONST_DIV_8; tile_start += NSIMDGROUPS) {
simdgroup_matrix<T, 8, 8> tmp;
ulong simdgroup_matrix_offset = matrix_load_loop_iter * NSIMDGROUPS * SIMDGROUP_MATRIX_LOAD_FACTOR + simd_group_id * SIMDGROUP_MATRIX_LOAD_FACTOR;
ulong2 matrixOrigin = ulong2(col * SIMDGROUP_MATRIX_LOAD_FACTOR, simdgroup_matrix_offset);
simdgroup_load(tmp, baseV, DK, matrixOrigin, true);
const ulong2 matrixOriginSmem = ulong2(simdgroup_matrix_offset, 0);
const ulong elemsPerRowSmem = TILE_SIZE_CONST;
simdgroup_store(tmp, smemV, elemsPerRowSmem, matrixOriginSmem, false);
matrix_load_loop_iter++;
};
threadgroup_barrier(mem_flags::mem_threadgroup);
if (TILE_SIZE_CONST == 64) {
T2 local_p_hat = T2(pvals[0].x, pvals[0].y);
uint loop_iter = 0;
threadgroup float* oPartialSmem = smemOpartial + SIMDGROUP_MATRIX_LOAD_FACTOR * col;
#pragma clang loop unroll(full)
for(size_t row = simd_group_id; row < SIMDGROUP_MATRIX_LOAD_FACTOR; row += NSIMDGROUPS) {
threadgroup T* smemV_row = smemV + (TILE_SIZE_CONST * row);
threadgroup T2* smemV2 = (threadgroup T2*)smemV_row;
T2 v_local = *(smemV2 + simd_lane_id);
T val = dot(local_p_hat, v_local);
simdgroup_barrier(mem_flags::mem_none);
T row_sum = simd_sum(val);
oPartialSmem[simd_group_id + loop_iter * NSIMDGROUPS] = float(row_sum);
loop_iter++;
}
}
if (TILE_SIZE_CONST > 64) {
constexpr const size_t TILE_SIZE_CONST_DIV_128 = (TILE_SIZE_CONST + 1) / 128;
threadgroup float* oPartialSmem = smemOpartial + SIMDGROUP_MATRIX_LOAD_FACTOR * col;
uint loop_iter = 0;
for(size_t row = simd_group_id; row < SIMDGROUP_MATRIX_LOAD_FACTOR; row += NSIMDGROUPS) {
threadgroup T* smemV_row = smemV + (TILE_SIZE_CONST * row);
T row_sum = 0.f;
for(size_t i = 0; i < TILE_SIZE_CONST_DIV_128; i++) {
threadgroup T4* smemV2 = (threadgroup T4*)smemV_row;
T4 v_local = *(smemV2 + simd_lane_id + i * THREADS_PER_SIMDGROUP);
T4 p_local = T4(pvals[i]);
T val = dot(p_local, v_local);
row_sum += val;
}
simdgroup_barrier(mem_flags::mem_none);
row_sum = simd_sum(row_sum);
oPartialSmem[simd_group_id + loop_iter * NSIMDGROUPS] = float(row_sum);
loop_iter++;
}
}
}
} else {
const int32_t START_ROW = tid.y * TILE_SIZE_CONST;
const int32_t MAX_START_ROW = L - SIMDGROUP_MATRIX_LOAD_FACTOR + 1;
const device T* baseVThisHead = V + v_batch_offset + v_head_offset;
constexpr const int ROWS_PER_ITER = 8;
#pragma clang loop unroll(full)
for(size_t col = 0; col < MATRIX_COLS; col++) {
uint smem_col_index = simd_group_id * SIMDGROUP_MATRIX_LOAD_FACTOR;
int32_t tile_start;
for(tile_start = START_ROW + simd_group_id * SIMDGROUP_MATRIX_LOAD_FACTOR; tile_start < MAX_START_ROW; tile_start += NSIMDGROUPS * SIMDGROUP_MATRIX_LOAD_FACTOR) {
simdgroup_matrix<T, 8, 8> tmp;
ulong2 matrixOrigin = ulong2(col * SIMDGROUP_MATRIX_LOAD_FACTOR, tile_start);
simdgroup_load(tmp, baseVThisHead, DK, matrixOrigin, /* transpose */ true);
const ulong2 matrixOriginSmem = ulong2(smem_col_index, 0);
constexpr const ulong elemsPerRowSmem = TILE_SIZE_CONST;
simdgroup_store(tmp, smemV, elemsPerRowSmem, matrixOriginSmem, /* transpose */ false);
smem_col_index += NSIMDGROUPS * SIMDGROUP_MATRIX_LOAD_FACTOR;
};
tile_start = ((L / SIMDGROUP_MATRIX_LOAD_FACTOR) * SIMDGROUP_MATRIX_LOAD_FACTOR);
const int32_t INT_L = int32_t(L);
for(int row_index = tile_start + simd_group_id ; row_index < INT_L; row_index += NSIMDGROUPS) {
if(simd_lane_id < SIMDGROUP_MATRIX_LOAD_FACTOR) {
const uint elems_per_row_gmem = DK;
const uint col_index_v_gmem = col * SIMDGROUP_MATRIX_LOAD_FACTOR + simd_lane_id;
const uint row_index_v_gmem = row_index;
const uint elems_per_row_smem = TILE_SIZE_CONST;
const uint col_index_v_smem = row_index % TILE_SIZE_CONST;
const uint row_index_v_smem = simd_lane_id;
const uint scalar_offset_gmem = row_index_v_gmem * elems_per_row_gmem + col_index_v_gmem;
const uint scalar_offset_smem = row_index_v_smem * elems_per_row_smem + col_index_v_smem;
T vdata = T(*(baseVThisHead + scalar_offset_gmem));
smemV[scalar_offset_smem] = vdata;
smem_col_index += NSIMDGROUPS;
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (TILE_SIZE_CONST == 64) {
T2 local_p_hat = T2(pvals[0].x, pvals[0].y);
threadgroup float* oPartialSmem = smemOpartial + SIMDGROUP_MATRIX_LOAD_FACTOR * col;
for(size_t smem_row_index = simd_group_id;
smem_row_index < ROWS_PER_ITER; smem_row_index += NSIMDGROUPS) {
threadgroup T* smemV_row = smemV + (TILE_SIZE_CONST * smem_row_index);
threadgroup T2* smemV2 = (threadgroup T2*)smemV_row;
T2 v_local = *(smemV2 + simd_lane_id);
T val = dot(local_p_hat, v_local);
simdgroup_barrier(mem_flags::mem_none);
T row_sum = simd_sum(val);
oPartialSmem[smem_row_index] = float(row_sum);
}
}
if (TILE_SIZE_CONST > 64) {
threadgroup float* oPartialSmem = smemOpartial + SIMDGROUP_MATRIX_LOAD_FACTOR * col;
uint loop_count = 0;
for(size_t row_index = simd_group_id;
row_index < ROWS_PER_ITER; row_index += NSIMDGROUPS) {
T row_sum = 0.f;
for(size_t tile_iters = 0; tile_iters < TILE_SIZE_ITERS_128; tile_iters++) {
threadgroup T* smemV_row = smemV + (TILE_SIZE_CONST * row_index);
threadgroup T4* smemV2 = (threadgroup T4*)smemV_row;
T4 v_local = *(smemV2 + simd_lane_id + tile_iters * THREADS_PER_SIMDGROUP);
T4 p_local = T4(pvals[tile_iters]);
row_sum += dot(p_local, v_local);
}
simdgroup_barrier(mem_flags::mem_none);
row_sum = simd_sum(row_sum);
oPartialSmem[simd_group_id + NSIMDGROUPS * loop_count] = float(row_sum);
loop_count++;
}
}
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if(simd_group_id == 0) {
threadgroup float4* oPartialVec4 = (threadgroup float4*)smemOpartial;
float4 vals = *(oPartialVec4 + simd_lane_id);
device float* oPartialGmem = O_partials + tid.x * DK * params.KV_TILES + tid.y * DK;
device float4* oPartialGmemVec4 = (device float4*)oPartialGmem;
oPartialGmemVec4[simd_lane_id] = vals;
}
if(simd_group_id == 0 && simd_lane_id == 0) {
const uint tileIndex = tid.y;
const uint gmem_partial_scalar_offset = tid.z * params.N_Q_HEADS * params.KV_TILES + tid.x * params.KV_TILES + tileIndex;
p_lse[gmem_partial_scalar_offset] = lse;
p_maxes[gmem_partial_scalar_offset] = groupMax;
}
if (simd_group_id == 0 && simd_lane_id == 0) {
const uint tileIndex = tid.y;
const uint gmem_partial_scalar_offset =
tid.z * params.N_Q_HEADS * params.KV_TILES + tid.x * params.KV_TILES +
tileIndex;
p_lse[gmem_partial_scalar_offset] = lse;
p_maxes[gmem_partial_scalar_offset] = groupMax;
}
}
// Explicitly instantiates fast_inference_sdpa_compute_partials_template for
// one (itype, itype2, itype4, tile_size, nsimdgroups) combination and gives
// the instantiation the host-visible name
// "fast_inference_sdpa_compute_partials_<itype>_<tile_size>_<nsimdgroups>"
// (itype2 and itype4 are template arguments only; they are not part of the
// name).
// Argument bindings declared here: Q, K, V, L, params, O_partials, p_lse and
// p_maxes use buffer slots 0..7 in that order; threadgroup_block uses
// threadgroup slot 0.
#define instantiate_fast_inference_sdpa_to_partials_kernel(itype, itype2, itype4, tile_size, nsimdgroups) \
template [[host_name("fast_inference_sdpa_compute_partials_" #itype "_" #tile_size "_" #nsimdgroups )]] \
[[kernel]] void fast_inference_sdpa_compute_partials_template<itype, itype2, itype4, tile_size, nsimdgroups>( \
const device itype *Q [[buffer(0)]], \
const device itype *K [[buffer(1)]], \
const device itype *V [[buffer(2)]], \
const device uint64_t& L [[buffer(3)]], \
const device MLXScaledDotProductAttentionParams& params [[buffer(4)]], \
device float* O_partials [[buffer(5)]], \
device float* p_lse [[buffer(6)]], \
device float* p_maxes [[buffer(7)]], \
threadgroup itype *threadgroup_block [[threadgroup(0)]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
uint3 tid [[threadgroup_position_in_grid]]);
// Emits one explicit instantiation of
// fast_inference_sdpa_compute_partials_template<itype, itype2, itype4,
// tile_size, nsimdgroups>. The host name is built by stringizing itype,
// tile_size and nsimdgroups and joining them with underscores after the
// "fast_inference_sdpa_compute_partials_" prefix.
// Inputs Q/K/V/L/params occupy buffers 0-4, outputs O_partials/p_lse/p_maxes
// occupy buffers 5-7, and threadgroup_block is threadgroup argument 0.
#define instantiate_fast_inference_sdpa_to_partials_kernel( \
itype, itype2, itype4, tile_size, nsimdgroups) \
template [[host_name("fast_inference_sdpa_compute_partials_" #itype \
"_" #tile_size "_" #nsimdgroups)]] [[kernel]] void \
fast_inference_sdpa_compute_partials_template< \
itype, \
itype2, \
itype4, \
tile_size, \
nsimdgroups>( \
const device itype* Q [[buffer(0)]], \
const device itype* K [[buffer(1)]], \
const device itype* V [[buffer(2)]], \
const device uint64_t& L [[buffer(3)]], \
const device MLXScaledDotProductAttentionParams& params [[buffer(4)]], \
device float* O_partials [[buffer(5)]], \
device float* p_lse [[buffer(6)]], \
device float* p_maxes [[buffer(7)]], \
threadgroup itype* threadgroup_block [[threadgroup(0)]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
uint3 tid [[threadgroup_position_in_grid]]);
// For a given element type (itype with its 2- and 4-wide companions itype2 /
// itype4) and tile_size, instantiates the compute-partials kernel twice:
// once with nsimdgroups = 4 and once with nsimdgroups = 8.
// The trailing "// clang-format on" shares the macro's last line, so that
// line does not end in a continuation backslash.
// clang-format off
#define instantiate_fast_inference_sdpa_to_partials_shapes_helper( \
itype, itype2, itype4, tile_size) \
instantiate_fast_inference_sdpa_to_partials_kernel( \
itype, itype2, itype4, tile_size, 4) \
instantiate_fast_inference_sdpa_to_partials_kernel( \
itype, itype2, itype4, tile_size, 8) // clang-format on
// Helper that instantiates the compute-partials kernel for one element type
// and tile size with nsimdgroups = 4 and nsimdgroups = 8, followed by its
// uses for float and half at tile sizes 64, 128, 256 and 512.
// NOTE(review): the last line of the macro body ends in a continuation
// backslash, so as listed here the line that follows it is spliced into the
// macro definition. Presumably a blank line separated them in the original
// file -- confirm before relying on the first float/64 instantiation.
#define instantiate_fast_inference_sdpa_to_partials_shapes_helper(itype, itype2, itype4, tile_size) \
instantiate_fast_inference_sdpa_to_partials_kernel(itype, itype2, itype4, tile_size, 4) \
instantiate_fast_inference_sdpa_to_partials_kernel(itype, itype2, itype4, tile_size, 8) \
instantiate_fast_inference_sdpa_to_partials_shapes_helper(float, float2, float4, 64);
instantiate_fast_inference_sdpa_to_partials_shapes_helper(float, float2, float4, 128);
instantiate_fast_inference_sdpa_to_partials_shapes_helper(float, float2, float4, 256);
instantiate_fast_inference_sdpa_to_partials_shapes_helper(float, float2, float4, 512);
instantiate_fast_inference_sdpa_to_partials_shapes_helper(half, half2, half4, 64);
instantiate_fast_inference_sdpa_to_partials_shapes_helper(half, half2, half4, 128);
instantiate_fast_inference_sdpa_to_partials_shapes_helper(half, half2, half4, 256);
instantiate_fast_inference_sdpa_to_partials_shapes_helper(half, half2, half4, 512);
// Compute-partials kernel instantiations: element types float and half, each
// at tile sizes 64, 128, 256 and 512. Every helper invocation expands to the
// nsimdgroups = 4 and nsimdgroups = 8 variants.
instantiate_fast_inference_sdpa_to_partials_shapes_helper(
float,
float2,
float4,
64);
instantiate_fast_inference_sdpa_to_partials_shapes_helper(
float,
float2,
float4,
128);
instantiate_fast_inference_sdpa_to_partials_shapes_helper(
float,
float2,
float4,
256);
instantiate_fast_inference_sdpa_to_partials_shapes_helper(
float,
float2,
float4,
512);
instantiate_fast_inference_sdpa_to_partials_shapes_helper(
half,
half2,
half4,
64);
instantiate_fast_inference_sdpa_to_partials_shapes_helper(
half,
half2,
half4,
128);
instantiate_fast_inference_sdpa_to_partials_shapes_helper(
half,
half2,
half4,
256);
instantiate_fast_inference_sdpa_to_partials_shapes_helper(
half,
half2,
half4,
512);
// Combines the per-tile partial attention outputs of one row into the final
// output value for channel lid.x.
//
// Parameters (as used by the code below):
//   O_partials - partial outputs; for the row selected by (tid.z, tid.x) the
//                value of tile i, channel lid.x is read at
//                row_base + i * DK + lid.x.
//   p_lse      - one float per tile; exp() of it seeds the tile's weight.
//   p_maxes    - one float per tile; used to rescale each tile's weight
//                relative to the largest value across tiles.
//   params     - supplies KV_TILES and N_Q_HEADS.
//   O          - output; element (tid.z, tid.x, lid.x) is written as T.
//   tid / lid  - threadgroup position in grid / thread position in
//                threadgroup.
//
// Result per thread:
//   O = sum_i O_partials[i] * w_i / sum_j w_j,
//   with w_i = exp(p_lse[i]) * exp(p_maxes[i] - max_k p_maxes[k]).
//
// NOTE(review): this listing carries the parameter list and every statement
// group twice (the same code in two formattings); the comments describe the
// computation once.
// NOTE(review): the per-tile arrays have `reserve` (128) entries and are
// indexed with i < params.KV_TILES without a bounds check, so KV_TILES is
// assumed not to exceed 128.
template <typename T>
void fast_inference_sdpa_reduce_tiles_template(
const device float *O_partials [[buffer(0)]],
const device float *p_lse[[buffer(1)]],
const device float *p_maxes [[buffer(2)]],
const device float* O_partials [[buffer(0)]],
const device float* p_lse [[buffer(1)]],
const device float* p_maxes [[buffer(2)]],
const device MLXScaledDotProductAttentionParams& params [[buffer(3)]],
device T* O [[buffer(4)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
// Number of output channels per row; hard-coded.
constexpr const int DK = 128;
// Offset of this row's first tile entry in p_lse / p_maxes.
const ulong offset_rows =
tid.z * params.KV_TILES * params.N_Q_HEADS + tid.x * params.KV_TILES;
const device float* p_lse_row = p_lse + offset_rows;
const device float* p_rowmax_row = p_maxes + offset_rows;
// Fixed-size per-tile scratch arrays. Their length is an assumption on the
// maximum value of KV_TILES.
constexpr const uint8_t reserve = 128;
float p_lse_regs[reserve];
float p_rowmax_regs[reserve];
float weights[reserve];
constexpr const int DK = 128;
const ulong offset_rows = tid.z * params.KV_TILES * params.N_Q_HEADS + tid.x * params.KV_TILES;
const device float* p_lse_row = p_lse + offset_rows;
const device float* p_rowmax_row = p_maxes + offset_rows;
// Fixed-size per-tile scratch arrays; length assumes a maximum value of KV_TILES.
constexpr const uint8_t reserve = 128;
float p_lse_regs[reserve];
float p_rowmax_regs[reserve];
float weights[reserve];
// Pass 1: load each tile's lse and max, track the largest max across tiles,
// and seed the tile weight with exp(lse).
float true_max = -INFINITY;
for (size_t i = 0; i < params.KV_TILES; i++) {
p_lse_regs[i] = float(*(p_lse_row + i));
p_rowmax_regs[i] = float(*(p_rowmax_row + i));
true_max = fmax(p_rowmax_regs[i], true_max);
weights[i] = exp(p_lse_regs[i]);
}
float true_max = -INFINITY;
for(size_t i = 0; i < params.KV_TILES; i++) {
p_lse_regs[i] = float(*(p_lse_row + i));
p_rowmax_regs[i] = float(*(p_rowmax_row + i));
true_max = fmax(p_rowmax_regs[i], true_max);
weights[i] = exp(p_lse_regs[i]);
}
// Pass 2: rescale each weight by exp(tile_max - true_max) and accumulate the
// normalizer.
float denom = 0.f;
for (size_t i = 0; i < params.KV_TILES; i++) {
weights[i] *= exp(p_rowmax_regs[i] - true_max);
denom += weights[i];
}
float denom = 0.f;
for(size_t i = 0; i < params.KV_TILES; i++) {
weights[i] *= exp(p_rowmax_regs[i]-true_max);
denom += weights[i];
}
// Start of this row's partial outputs (KV_TILES blocks of DK floats).
const device float* O_partials_with_offset = O_partials +
tid.z * params.N_Q_HEADS * DK * params.KV_TILES +
tid.x * DK * params.KV_TILES;
const device float* O_partials_with_offset = O_partials + tid.z * params.N_Q_HEADS * DK * params.KV_TILES + tid.x * DK * params.KV_TILES;
// Pass 3: normalized weighted sum of the tiles' partial outputs for channel
// lid.x, then store it converted to T.
float o_value = 0.f;
for(size_t i = 0; i < params.KV_TILES; i++) {
float val = *(O_partials_with_offset + i * DK + lid.x);
o_value += val * weights[i] / denom;
}
device T* O_gmem = O + tid.z * params.N_Q_HEADS * DK + tid.x * DK;
O_gmem[lid.x] = T(o_value);
return;
float o_value = 0.f;
for (size_t i = 0; i < params.KV_TILES; i++) {
float val = *(O_partials_with_offset + i * DK + lid.x);
o_value += val * weights[i] / denom;
}
device T* O_gmem = O + tid.z * params.N_Q_HEADS * DK + tid.x * DK;
O_gmem[lid.x] = T(o_value);
return;
}
// Kernel entry point for float output: forwards all of its arguments,
// unchanged and in order, to fast_inference_sdpa_reduce_tiles_template<float>.
// NOTE(review): the parameter list and the forwarding call each appear twice
// in this listing (same code in two formattings).
kernel void fast_inference_sdpa_reduce_tiles_float(
const device float *O_partials [[buffer(0)]],
const device float *p_lse[[buffer(1)]],
const device float *p_maxes [[buffer(2)]],
const device float* O_partials [[buffer(0)]],
const device float* p_lse [[buffer(1)]],
const device float* p_maxes [[buffer(2)]],
const device MLXScaledDotProductAttentionParams& params [[buffer(3)]],
device float* O [[buffer(4)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]])
{
fast_inference_sdpa_reduce_tiles_template<float>(O_partials, p_lse, p_maxes, params,
O, tid, lid);
uint3 lid [[thread_position_in_threadgroup]]) {
fast_inference_sdpa_reduce_tiles_template<float>(
O_partials, p_lse, p_maxes, params, O, tid, lid);
}
// Kernel entry point for half output: forwards all of its arguments,
// unchanged and in order, to fast_inference_sdpa_reduce_tiles_template<half>.
// The partial inputs stay float; only the output pointer O is half.
// NOTE(review): the parameter list and the forwarding call each appear twice
// in this listing (same code in two formattings).
kernel void fast_inference_sdpa_reduce_tiles_half(
const device float *O_partials [[buffer(0)]],
const device float *p_lse[[buffer(1)]],
const device float *p_maxes [[buffer(2)]],
const device float* O_partials [[buffer(0)]],
const device float* p_lse [[buffer(1)]],
const device float* p_maxes [[buffer(2)]],
const device MLXScaledDotProductAttentionParams& params [[buffer(3)]],
device half* O [[buffer(4)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]])
{
fast_inference_sdpa_reduce_tiles_template<half>(O_partials, p_lse, p_maxes, params,
O, tid, lid);
uint3 lid [[thread_position_in_threadgroup]]) {
fast_inference_sdpa_reduce_tiles_template<half>(
O_partials, p_lse, p_maxes, params, O, tid, lid);
}

View File

@@ -54,7 +54,7 @@ struct CumProd<bool> {
}
bool simd_scan(bool x) {
for (int i=1; i<=16; i*=2) {
for (int i = 1; i <= 16; i *= 2) {
bool other = simd_shuffle_up(x, i);
x &= other;
}
@@ -77,7 +77,7 @@ struct CumMax {
}
U simd_scan(U x) {
for (int i=1; i<=16; i*=2) {
for (int i = 1; i <= 16; i *= 2) {
U other = simd_shuffle_up(x, i);
x = (x >= other) ? x : other;
}
@@ -100,7 +100,7 @@ struct CumMin {
}
U simd_scan(U x) {
for (int i=1; i<=16; i*=2) {
for (int i = 1; i <= 16; i *= 2) {
U other = simd_shuffle_up(x, i);
x = (x <= other) ? x : other;
}
@@ -114,54 +114,60 @@ struct CumMin {
};
// Copies N_READS consecutive elements from `input` into `values`, converting
// from T to U by assignment, with no bounds checking.
// When `reverse` is true the order is flipped: input[i] is stored in
// values[N_READS - i - 1]; otherwise input[i] is stored in values[i].
template <typename T, typename U, int N_READS, bool reverse>
inline void load_unsafe(U values[N_READS], const device T * input) {
inline void load_unsafe(U values[N_READS], const device T* input) {
if (reverse) {
for (int i=0; i<N_READS; i++) {
values[N_READS-i-1] = input[i];
for (int i = 0; i < N_READS; i++) {
values[N_READS - i - 1] = input[i];
}
} else {
for (int i=0; i<N_READS; i++) {
for (int i = 0; i < N_READS; i++) {
values[i] = input[i];
}
}
}
// Bounds-guarded variant of the N_READS-element load.
// Slot j of `values` is filled from `input` only when start + j < total;
// otherwise it receives `init`. When `reverse` is true, slot
// N_READS - i - 1 takes input[i]; otherwise slot i takes input[i].
//   start - position of values[0] along the axis being guarded
//   total - number of valid positions along that axis
//   init  - value stored in slots that fall outside the valid range
template <typename T, typename U, int N_READS, bool reverse>
inline void load_safe(U values[N_READS], const device T * input, int start, int total, U init) {
inline void load_safe(
U values[N_READS],
const device T* input,
int start,
int total,
U init) {
if (reverse) {
for (int i=0; i<N_READS; i++) {
values[N_READS-i-1] = (start + N_READS - i - 1 < total) ? input[i] : init;
for (int i = 0; i < N_READS; i++) {
values[N_READS - i - 1] =
(start + N_READS - i - 1 < total) ? input[i] : init;
}
} else {
for (int i=0; i<N_READS; i++) {
for (int i = 0; i < N_READS; i++) {
values[i] = (start + i < total) ? input[i] : init;
}
}
}
// Stores the N_READS entries of `values` to `out` with no bounds checking.
// When `reverse` is true out[i] receives values[N_READS - i - 1]; otherwise
// out[i] receives values[i].
template <typename U, int N_READS, bool reverse>
inline void write_unsafe(U values[N_READS], device U * out) {
inline void write_unsafe(U values[N_READS], device U* out) {
if (reverse) {
for (int i=0; i<N_READS; i++) {
out[i] = values[N_READS-i-1];
for (int i = 0; i < N_READS; i++) {
out[i] = values[N_READS - i - 1];
}
} else {
for (int i=0; i<N_READS; i++) {
for (int i = 0; i < N_READS; i++) {
out[i] = values[i];
}
}
}
// Bounds-guarded variant of the N_READS-element store.
// Slot j of `values` is written to `out` only when start + j < total; other
// slots are skipped. When `reverse` is true, slot N_READS - i - 1 goes to
// out[i]; otherwise slot i goes to out[i].
template <typename U, int N_READS, bool reverse>
inline void write_safe(U values[N_READS], device U * out, int start, int total) {
inline void write_safe(U values[N_READS], device U* out, int start, int total) {
if (reverse) {
for (int i=0; i<N_READS; i++) {
for (int i = 0; i < N_READS; i++) {
if (start + N_READS - i - 1 < total) {
out[i] = values[N_READS-i-1];
out[i] = values[N_READS - i - 1];
}
}
} else {
for (int i=0; i<N_READS; i++) {
for (int i = 0; i < N_READS; i++) {
if (start + i < total) {
out[i] = values[i];
}
@@ -169,12 +175,17 @@ inline void write_safe(U values[N_READS], device U * out, int start, int total)
}
}
template <typename T, typename U, typename Op, int N_READS, bool inclusive, bool reverse>
template <
typename T,
typename U,
typename Op,
int N_READS,
bool inclusive,
bool reverse>
[[kernel]] void contiguous_scan(
const device T* in [[buffer(0)]],
device U* out [[buffer(1)]],
const constant size_t & axis_size [[buffer(2)]],
const constant size_t& axis_size [[buffer(2)]],
uint gid [[thread_position_in_grid]],
uint lid [[thread_position_in_threadgroup]],
uint lsize [[threads_per_threadgroup]],
@@ -195,42 +206,51 @@ template <typename T, typename U, typename Op, int N_READS, bool inclusive, bool
U values[N_READS];
threadgroup U simdgroup_sums[32];
// Loop over the reduced axis in blocks of size ceildiv(axis_size, N_READS*lsize)
// Loop over the reduced axis in blocks of size ceildiv(axis_size,
// N_READS*lsize)
// Read block
// Compute inclusive scan of the block
// Compute inclusive scan per thread
// Compute exclusive scan of thread sums in simdgroup
// Write simdgroup sums in SM
// Compute exclusive scan of simdgroup sums
// Compute the output by scanning prefix, prev_simdgroup, prev_thread, value
// Compute the output by scanning prefix, prev_simdgroup, prev_thread,
// value
// Write block
for (uint r = 0; r < ceildiv(axis_size, N_READS*lsize); r++) {
for (uint r = 0; r < ceildiv(axis_size, N_READS * lsize); r++) {
// Compute the block offset
uint offset = r*lsize*N_READS + lid*N_READS;
uint offset = r * lsize * N_READS + lid * N_READS;
// Read the values
if (reverse) {
if ((offset + N_READS) < axis_size) {
load_unsafe<T, U, N_READS, reverse>(values, in + axis_size - offset - N_READS);
load_unsafe<T, U, N_READS, reverse>(
values, in + axis_size - offset - N_READS);
} else {
load_safe<T, U, N_READS, reverse>(values, in + axis_size - offset - N_READS, offset, axis_size, Op::init);
load_safe<T, U, N_READS, reverse>(
values,
in + axis_size - offset - N_READS,
offset,
axis_size,
Op::init);
}
} else {
if ((offset + N_READS) < axis_size) {
load_unsafe<T, U, N_READS, reverse>(values, in + offset);
} else {
load_safe<T, U, N_READS, reverse>(values, in + offset, offset, axis_size, Op::init);
load_safe<T, U, N_READS, reverse>(
values, in + offset, offset, axis_size, Op::init);
}
}
// Compute an inclusive scan per thread
for (int i=1; i<N_READS; i++) {
values[i] = op(values[i], values[i-1]);
for (int i = 1; i < N_READS; i++) {
values[i] = op(values[i], values[i - 1]);
}
// Compute exclusive scan of thread sums
U prev_thread = op.simd_exclusive_scan(values[N_READS-1]);
U prev_thread = op.simd_exclusive_scan(values[N_READS - 1]);
// Write simdgroup_sums to SM
if (simd_lane_id == simd_size - 1) {
@@ -246,7 +266,7 @@ template <typename T, typename U, typename Op, int N_READS, bool inclusive, bool
threadgroup_barrier(mem_flags::mem_threadgroup);
// Compute the output
for (int i=0; i<N_READS; i++) {
for (int i = 0; i < N_READS; i++) {
values[i] = op(values[i], prefix);
values[i] = op(values[i], simdgroup_sums[simd_group_id]);
values[i] = op(values[i], prev_thread);
@@ -256,18 +276,25 @@ template <typename T, typename U, typename Op, int N_READS, bool inclusive, bool
if (reverse) {
if (inclusive) {
if ((offset + N_READS) < axis_size) {
write_unsafe<U, N_READS, reverse>(values, out + axis_size - offset - N_READS);
write_unsafe<U, N_READS, reverse>(
values, out + axis_size - offset - N_READS);
} else {
write_safe<U, N_READS, reverse>(values, out + axis_size - offset - N_READS, offset, axis_size);
write_safe<U, N_READS, reverse>(
values, out + axis_size - offset - N_READS, offset, axis_size);
}
} else {
if (lid == 0 && offset == 0) {
out[axis_size-1] = Op::init;
out[axis_size - 1] = Op::init;
}
if ((offset + N_READS + 1) < axis_size) {
write_unsafe<U, N_READS, reverse>(values, out + axis_size - offset - 1 - N_READS);
write_unsafe<U, N_READS, reverse>(
values, out + axis_size - offset - 1 - N_READS);
} else {
write_safe<U, N_READS, reverse>(values, out + axis_size - offset - 1 - N_READS, offset + 1, axis_size);
write_safe<U, N_READS, reverse>(
values,
out + axis_size - offset - 1 - N_READS,
offset + 1,
axis_size);
}
}
} else {
@@ -275,7 +302,8 @@ template <typename T, typename U, typename Op, int N_READS, bool inclusive, bool
if ((offset + N_READS) < axis_size) {
write_unsafe<U, N_READS, reverse>(values, out + offset);
} else {
write_safe<U, N_READS, reverse>(values, out + offset, offset, axis_size);
write_safe<U, N_READS, reverse>(
values, out + offset, offset, axis_size);
}
} else {
if (lid == 0 && offset == 0) {
@@ -284,26 +312,33 @@ template <typename T, typename U, typename Op, int N_READS, bool inclusive, bool
if ((offset + N_READS + 1) < axis_size) {
write_unsafe<U, N_READS, reverse>(values, out + offset + 1);
} else {
write_safe<U, N_READS, reverse>(values, out + offset + 1, offset + 1, axis_size);
write_safe<U, N_READS, reverse>(
values, out + offset + 1, offset + 1, axis_size);
}
}
}
// Share the prefix
if (simd_group_id == simd_groups - 1 && simd_lane_id == simd_size - 1) {
simdgroup_sums[0] = values[N_READS-1];
simdgroup_sums[0] = values[N_READS - 1];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
prefix = simdgroup_sums[0];
}
}
template <typename T, typename U, typename Op, int N_READS, bool inclusive, bool reverse>
template <
typename T,
typename U,
typename Op,
int N_READS,
bool inclusive,
bool reverse>
[[kernel]] void strided_scan(
const device T* in [[buffer(0)]],
device U* out [[buffer(1)]],
const constant size_t & axis_size [[buffer(2)]],
const constant size_t & stride [[buffer(3)]],
const constant size_t& axis_size [[buffer(2)]],
const constant size_t& stride [[buffer(3)]],
uint2 gid [[threadgroup_position_in_grid]],
uint2 lid [[thread_position_in_threadgroup]],
uint2 lsize [[threads_per_threadgroup]],
@@ -311,10 +346,10 @@ template <typename T, typename U, typename Op, int N_READS, bool inclusive, bool
Op op;
// Allocate memory
threadgroup U read_buffer[N_READS*32*32 + N_READS*32];
threadgroup U read_buffer[N_READS * 32 * 32 + N_READS * 32];
U values[N_READS];
U prefix[N_READS];
for (int i=0; i<N_READS; i++) {
for (int i = 0; i < N_READS; i++) {
prefix[i] = Op::init;
}
@@ -322,7 +357,7 @@ template <typename T, typename U, typename Op, int N_READS, bool inclusive, bool
int offset = gid.y * axis_size * stride;
int global_index_x = gid.x * lsize.y * N_READS;
for (uint j=0; j<axis_size; j+=simd_size) {
for (uint j = 0; j < axis_size; j += simd_size) {
// Calculate the indices for the current thread
uint index_y = j + lid.y;
uint check_index_y = index_y;
@@ -333,37 +368,43 @@ template <typename T, typename U, typename Op, int N_READS, bool inclusive, bool
// Read in SM
if (check_index_y < axis_size && (index_x + N_READS) < stride) {
for (int i=0; i<N_READS; i++) {
read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i] = in[offset + index_y * stride + index_x + i];
for (int i = 0; i < N_READS; i++) {
read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i] =
in[offset + index_y * stride + index_x + i];
}
} else {
for (int i=0; i<N_READS; i++) {
for (int i = 0; i < N_READS; i++) {
if (check_index_y < axis_size && (index_x + i) < stride) {
read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i] = in[offset + index_y * stride + index_x + i];
read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i] =
in[offset + index_y * stride + index_x + i];
} else {
read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i] = Op::init;
read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i] =
Op::init;
}
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Read strided into registers
for (int i=0; i<N_READS; i++) {
values[i] = read_buffer[lid.x * simd_size * N_READS + lid.y * N_READS + i];
for (int i = 0; i < N_READS; i++) {
values[i] =
read_buffer[lid.x * simd_size * N_READS + lid.y * N_READS + i];
}
// Do we need the following barrier? Shouldn't all simd threads execute simultaneously?
// Do we need the following barrier? Shouldn't all simd threads execute
// simultaneously?
simdgroup_barrier(mem_flags::mem_threadgroup);
// Perform the scan
for (int i=0; i<N_READS; i++) {
for (int i = 0; i < N_READS; i++) {
values[i] = op.simd_scan(values[i]);
values[i] = op(values[i], prefix[i]);
prefix[i] = simd_shuffle(values[i], simd_size-1);
prefix[i] = simd_shuffle(values[i], simd_size - 1);
}
// Write to SM
for (int i=0; i<N_READS; i++) {
read_buffer[lid.x * simd_size * N_READS + lid.y * N_READS + i] = values[i];
for (int i = 0; i < N_READS; i++) {
read_buffer[lid.x * simd_size * N_READS + lid.y * N_READS + i] =
values[i];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
@@ -371,11 +412,11 @@ template <typename T, typename U, typename Op, int N_READS, bool inclusive, bool
if (!inclusive) {
if (check_index_y == 0) {
if ((index_x + N_READS) < stride) {
for (int i=0; i<N_READS; i++) {
for (int i = 0; i < N_READS; i++) {
out[offset + index_y * stride + index_x + i] = Op::init;
}
} else {
for (int i=0; i<N_READS; i++) {
for (int i = 0; i < N_READS; i++) {
if ((index_x + i) < stride) {
out[offset + index_y * stride + index_x + i] = Op::init;
}
@@ -391,55 +432,60 @@ template <typename T, typename U, typename Op, int N_READS, bool inclusive, bool
}
}
if (check_index_y < axis_size && (index_x + N_READS) < stride) {
for (int i=0; i<N_READS; i++) {
out[offset + index_y * stride + index_x + i] = read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i];
for (int i = 0; i < N_READS; i++) {
out[offset + index_y * stride + index_x + i] =
read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i];
}
} else {
for (int i=0; i<N_READS; i++) {
for (int i = 0; i < N_READS; i++) {
if (check_index_y < axis_size && (index_x + i) < stride) {
out[offset + index_y * stride + index_x + i] = read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i];
out[offset + index_y * stride + index_x + i] =
read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i];
}
}
}
}
}
// Explicitly instantiates
// contiguous_scan<itype, otype, op<otype>, nreads, inclusive, reverse> and
// exposes it under the host name "contiguous_scan_<name>".
// Note the argument order: the macro takes (inclusive, reverse, nreads) while
// the template takes (nreads, inclusive, reverse).
#define instantiate_contiguous_scan(name, itype, otype, op, inclusive, reverse, nreads) \
template [[host_name("contiguous_scan_" #name)]] \
[[kernel]] void contiguous_scan<itype, otype, op<otype>, nreads, inclusive, reverse>( \
const device itype* in [[buffer(0)]], \
device otype* out [[buffer(1)]], \
const constant size_t & axis_size [[buffer(2)]], \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint lsize [[threads_per_threadgroup]], \
uint simd_size [[threads_per_simdgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
#define instantiate_contiguous_scan( \
name, itype, otype, op, inclusive, reverse, nreads) \
template [[host_name("contiguous_scan_" #name)]] [[kernel]] void \
contiguous_scan<itype, otype, op<otype>, nreads, inclusive, reverse>( \
const device itype* in [[buffer(0)]], \
device otype* out [[buffer(1)]], \
const constant size_t& axis_size [[buffer(2)]], \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint lsize [[threads_per_threadgroup]], \
uint simd_size [[threads_per_simdgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
#define instantiate_strided_scan(name, itype, otype, op, inclusive, reverse, nreads) \
template [[host_name("strided_scan_" #name)]] \
[[kernel]] void strided_scan<itype, otype, op<otype>, nreads, inclusive, reverse>( \
const device itype* in [[buffer(0)]], \
device otype* out [[buffer(1)]], \
const constant size_t & axis_size [[buffer(2)]], \
const constant size_t & stride [[buffer(3)]], \
uint2 gid [[thread_position_in_grid]], \
uint2 lid [[thread_position_in_threadgroup]], \
uint2 lsize [[threads_per_threadgroup]], \
uint simd_size [[threads_per_simdgroup]]);
#define instantiate_strided_scan( \
name, itype, otype, op, inclusive, reverse, nreads) \
template [[host_name("strided_scan_" #name)]] [[kernel]] void \
strided_scan<itype, otype, op<otype>, nreads, inclusive, reverse>( \
const device itype* in [[buffer(0)]], \
device otype* out [[buffer(1)]], \
const constant size_t& axis_size [[buffer(2)]], \
const constant size_t& stride [[buffer(3)]], \
uint2 gid [[thread_position_in_grid]], \
uint2 lid [[thread_position_in_threadgroup]], \
uint2 lsize [[threads_per_threadgroup]], \
uint simd_size [[threads_per_simdgroup]]);
#define instantiate_scan_helper(name, itype, otype, op, nreads) \
instantiate_contiguous_scan(inclusive_##name, itype, otype, op, true, false, nreads) \
instantiate_contiguous_scan(exclusive_##name, itype, otype, op, false, false, nreads) \
instantiate_contiguous_scan(reverse_inclusive_##name, itype, otype, op, true, true, nreads) \
// clang-format off
#define instantiate_scan_helper(name, itype, otype, op, nreads) \
instantiate_contiguous_scan(inclusive_##name, itype, otype, op, true, false, nreads) \
instantiate_contiguous_scan(exclusive_##name, itype, otype, op, false, false, nreads) \
instantiate_contiguous_scan(reverse_inclusive_##name, itype, otype, op, true, true, nreads) \
instantiate_contiguous_scan(reverse_exclusive_##name, itype, otype, op, false, true, nreads) \
instantiate_strided_scan(inclusive_##name, itype, otype, op, true, false, nreads) \
instantiate_strided_scan(exclusive_##name, itype, otype, op, false, false, nreads) \
instantiate_strided_scan(reverse_inclusive_##name, itype, otype, op, true, true, nreads) \
instantiate_strided_scan(reverse_exclusive_##name, itype, otype, op, false, true, nreads)
instantiate_strided_scan(inclusive_##name, itype, otype, op, true, false, nreads) \
instantiate_strided_scan(exclusive_##name, itype, otype, op, false, false, nreads) \
instantiate_strided_scan(reverse_inclusive_##name, itype, otype, op, true, true, nreads) \
instantiate_strided_scan(reverse_exclusive_##name, itype, otype, op, false, true, nreads) // clang-format on
// clang-format off
instantiate_scan_helper(sum_bool__int32, bool, int32_t, CumSum, 4)
instantiate_scan_helper(sum_uint8_uint8, uint8_t, uint8_t, CumSum, 4)
instantiate_scan_helper(sum_uint16_uint16, uint16_t, uint16_t, CumSum, 4)
@@ -491,4 +537,4 @@ instantiate_scan_helper(min_int32_int32, int32_t, int32_t, CumMi
instantiate_scan_helper(min_float16_float16, half, half, CumMin, 4)
instantiate_scan_helper(min_float32_float32, float, float, CumMin, 4)
instantiate_scan_helper(min_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumMin, 4)
//instantiate_scan_helper(min_complex64_complex64, complex64_t, complex64_t, CumMin)
//instantiate_scan_helper(min_complex64_complex64, complex64_t, complex64_t, CumMin) // clang-format on

View File

@@ -13,67 +13,55 @@ using namespace metal;
// Scatter kernel
/////////////////////////////////////////////////////////////////////
template <typename T, typename IdxT, typename Op, int NIDX> \
template <typename T, typename IdxT, typename Op, int NIDX>
METAL_FUNC void scatter_1d_index_impl(
const device T *updates [[buffer(1)]],
device mlx_atomic<T> *out [[buffer(2)]],
const constant int* out_shape [[buffer(3)]],
const constant size_t* out_strides [[buffer(4)]],
const constant size_t& upd_size [[buffer(5)]],
const thread array<const device IdxT*, NIDX>& idx_buffers,
uint2 gid [[thread_position_in_grid]]) {
const device T* updates [[buffer(1)]],
device mlx_atomic<T>* out [[buffer(2)]],
const constant int* out_shape [[buffer(3)]],
const constant size_t* out_strides [[buffer(4)]],
const constant size_t& upd_size [[buffer(5)]],
const thread array<const device IdxT*, NIDX>& idx_buffers,
uint2 gid [[thread_position_in_grid]]) {
Op op;
uint out_idx = 0;
for (int i = 0; i < NIDX; i++) {
auto idx_val = offset_neg_idx(
idx_buffers[i][gid.y], out_shape[i]);
auto idx_val = offset_neg_idx(idx_buffers[i][gid.y], out_shape[i]);
out_idx += idx_val * out_strides[i];
}
op.atomic_update(out, updates[gid.y * upd_size + gid.x], out_idx + gid.x);
}
#define make_scatter_1d_index(IDX_ARG, IDX_ARR) \
template <typename T, typename IdxT, typename Op, int NIDX> \
[[kernel]] void scatter_1d_index( \
const device T *updates [[buffer(1)]], \
device mlx_atomic<T> *out [[buffer(2)]], \
const constant int* out_shape [[buffer(3)]], \
const constant size_t* out_strides [[buffer(4)]], \
const constant size_t& upd_size [[buffer(5)]], \
IDX_ARG(IdxT) \
uint2 gid [[thread_position_in_grid]]) { \
\
const array<const device IdxT*, NIDX> idx_buffers = {IDX_ARR()}; \
\
return scatter_1d_index_impl<T, IdxT, Op, NIDX>( \
updates, \
out, \
out_shape, \
out_strides, \
upd_size, \
idx_buffers, \
gid); \
\
}
#define make_scatter_1d_index(IDX_ARG, IDX_ARR) \
template <typename T, typename IdxT, typename Op, int NIDX> \
[[kernel]] void scatter_1d_index( \
const device T* updates [[buffer(1)]], \
device mlx_atomic<T>* out [[buffer(2)]], \
const constant int* out_shape [[buffer(3)]], \
const constant size_t* out_strides [[buffer(4)]], \
const constant size_t& upd_size [[buffer(5)]], \
IDX_ARG(IdxT) uint2 gid [[thread_position_in_grid]]) { \
const array<const device IdxT*, NIDX> idx_buffers = {IDX_ARR()}; \
\
return scatter_1d_index_impl<T, IdxT, Op, NIDX>( \
updates, out, out_shape, out_strides, upd_size, idx_buffers, gid); \
}
template <typename T, typename IdxT, typename Op, int NIDX>
METAL_FUNC void scatter_impl(
const device T *updates [[buffer(1)]],
device mlx_atomic<T> *out [[buffer(2)]],
const constant int *upd_shape [[buffer(3)]],
const constant size_t *upd_strides [[buffer(4)]],
const device T* updates [[buffer(1)]],
device mlx_atomic<T>* out [[buffer(2)]],
const constant int* upd_shape [[buffer(3)]],
const constant size_t* upd_strides [[buffer(4)]],
const constant size_t& upd_ndim [[buffer(5)]],
const constant size_t& upd_size [[buffer(6)]],
const constant int *out_shape [[buffer(7)]],
const constant size_t *out_strides [[buffer(8)]],
const constant int* out_shape [[buffer(7)]],
const constant size_t* out_strides [[buffer(8)]],
const constant size_t& out_ndim [[buffer(9)]],
const constant int* axes [[buffer(10)]],
const thread Indices<IdxT, NIDX>& indices,
uint2 gid [[thread_position_in_grid]]) {
Op op;
auto ind_idx = gid.y;
auto ind_offset = gid.x;
@@ -86,8 +74,7 @@ METAL_FUNC void scatter_impl(
&indices.strides[indices.ndim * i],
indices.ndim);
auto ax = axes[i];
auto idx_val = offset_neg_idx(
indices.buffers[i][idx_loc], out_shape[ax]);
auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], out_shape[ax]);
out_idx += idx_val * out_strides[ax];
}
@@ -97,142 +84,134 @@ METAL_FUNC void scatter_impl(
out_idx += out_offset;
}
auto upd_idx = elem_to_loc(gid.y * upd_size + gid.x, upd_shape, upd_strides, upd_ndim);
auto upd_idx =
elem_to_loc(gid.y * upd_size + gid.x, upd_shape, upd_strides, upd_ndim);
op.atomic_update(out, updates[upd_idx], out_idx);
}
#define make_scatter_impl(IDX_ARG, IDX_ARR) \
template <typename T, typename IdxT, typename Op, int NIDX> \
[[kernel]] void scatter( \
const device T *updates [[buffer(1)]], \
device mlx_atomic<T> *out [[buffer(2)]], \
const constant int *upd_shape [[buffer(3)]], \
const constant size_t *upd_strides [[buffer(4)]], \
const constant size_t& upd_ndim [[buffer(5)]], \
const constant size_t& upd_size [[buffer(6)]], \
const constant int *out_shape [[buffer(7)]], \
const constant size_t *out_strides [[buffer(8)]], \
const constant size_t& out_ndim [[buffer(9)]], \
const constant int* axes [[buffer(10)]], \
const constant int *idx_shapes [[buffer(11)]], \
const constant size_t *idx_strides [[buffer(12)]], \
const constant int& idx_ndim [[buffer(13)]], \
IDX_ARG(IdxT) \
uint2 gid [[thread_position_in_grid]]) { \
\
Indices<IdxT, NIDX> idxs{ \
{{IDX_ARR()}}, \
idx_shapes, \
idx_strides, \
idx_ndim}; \
\
return scatter_impl<T, IdxT, Op, NIDX>( \
updates, \
out, \
upd_shape, \
upd_strides, \
upd_ndim, \
upd_size, \
out_shape, \
out_strides, \
out_ndim, \
axes, \
idxs, \
gid); \
}
#define make_scatter_impl(IDX_ARG, IDX_ARR) \
template <typename T, typename IdxT, typename Op, int NIDX> \
[[kernel]] void scatter( \
const device T* updates [[buffer(1)]], \
device mlx_atomic<T>* out [[buffer(2)]], \
const constant int* upd_shape [[buffer(3)]], \
const constant size_t* upd_strides [[buffer(4)]], \
const constant size_t& upd_ndim [[buffer(5)]], \
const constant size_t& upd_size [[buffer(6)]], \
const constant int* out_shape [[buffer(7)]], \
const constant size_t* out_strides [[buffer(8)]], \
const constant size_t& out_ndim [[buffer(9)]], \
const constant int* axes [[buffer(10)]], \
const constant int* idx_shapes [[buffer(11)]], \
const constant size_t* idx_strides [[buffer(12)]], \
const constant int& idx_ndim [[buffer(13)]], \
IDX_ARG(IdxT) uint2 gid [[thread_position_in_grid]]) { \
Indices<IdxT, NIDX> idxs{ \
{{IDX_ARR()}}, idx_shapes, idx_strides, idx_ndim}; \
\
return scatter_impl<T, IdxT, Op, NIDX>( \
updates, \
out, \
upd_shape, \
upd_strides, \
upd_ndim, \
upd_size, \
out_shape, \
out_strides, \
out_ndim, \
axes, \
idxs, \
gid); \
}
#define make_scatter(n) \
make_scatter_impl(IDX_ARG_ ##n, IDX_ARR_ ##n) \
make_scatter_1d_index(IDX_ARG_ ##n, IDX_ARR_ ##n)
#define make_scatter(n) \
make_scatter_impl(IDX_ARG_##n, IDX_ARR_##n) \
make_scatter_1d_index(IDX_ARG_##n, IDX_ARR_##n)
make_scatter(0)
make_scatter(1)
make_scatter(2)
make_scatter(3)
make_scatter(4)
make_scatter(5)
make_scatter(6)
make_scatter(7)
make_scatter(8)
make_scatter(9)
make_scatter(10)
make_scatter(0) make_scatter(1) make_scatter(2) make_scatter(3) make_scatter(4)
make_scatter(5) make_scatter(6) make_scatter(7) make_scatter(8)
make_scatter(9) make_scatter(10)
/////////////////////////////////////////////////////////////////////
// Scatter instantiations
/////////////////////////////////////////////////////////////////////
#define instantiate_scatter5(name, src_t, idx_t, op_t, nidx, IDX_ARG) \
template [[host_name("scatter" name "_" #nidx)]] \
[[kernel]] void scatter<src_t, idx_t, op_t, nidx>( \
const device src_t *updates [[buffer(1)]], \
device mlx_atomic<src_t> *out [[buffer(2)]], \
const constant int *upd_shape [[buffer(3)]], \
const constant size_t *upd_strides [[buffer(4)]], \
const constant size_t& upd_ndim [[buffer(5)]], \
const constant size_t& upd_size [[buffer(6)]], \
const constant int *out_shape [[buffer(7)]], \
const constant size_t *out_strides [[buffer(8)]], \
const constant size_t& out_ndim [[buffer(9)]], \
const constant int* axes [[buffer(10)]], \
const constant int *idx_shapes [[buffer(11)]], \
const constant size_t *idx_strides [[buffer(12)]], \
const constant int& idx_ndim [[buffer(13)]], \
IDX_ARG(idx_t) \
uint2 gid [[thread_position_in_grid]]);
template [[host_name("scatter" name "_" #nidx)]] [[kernel]] void \
scatter<src_t, idx_t, op_t, nidx>( \
const device src_t* updates [[buffer(1)]], \
device mlx_atomic<src_t>* out [[buffer(2)]], \
const constant int* upd_shape [[buffer(3)]], \
const constant size_t* upd_strides [[buffer(4)]], \
const constant size_t& upd_ndim [[buffer(5)]], \
const constant size_t& upd_size [[buffer(6)]], \
const constant int* out_shape [[buffer(7)]], \
const constant size_t* out_strides [[buffer(8)]], \
const constant size_t& out_ndim [[buffer(9)]], \
const constant int* axes [[buffer(10)]], \
const constant int* idx_shapes [[buffer(11)]], \
const constant size_t* idx_strides [[buffer(12)]], \
const constant int& idx_ndim [[buffer(13)]], \
IDX_ARG(idx_t) uint2 gid [[thread_position_in_grid]]);
#define instantiate_scatter6(name, src_t, idx_t, op_t, nidx, IDX_ARG) \
template [[host_name("scatter_1d_index" name "_" #nidx)]] \
[[kernel]] void scatter_1d_index<src_t, idx_t, op_t, nidx>( \
const device src_t *updates [[buffer(1)]], \
device mlx_atomic<src_t> *out [[buffer(2)]], \
const constant int* out_shape [[buffer(3)]], \
const constant size_t* out_strides [[buffer(4)]], \
const constant size_t& upd_size [[buffer(5)]], \
IDX_ARG(idx_t) \
uint2 gid [[thread_position_in_grid]]);
#define instantiate_scatter6(name, src_t, idx_t, op_t, nidx, IDX_ARG) \
template [[host_name("scatter_1d_index" name "_" #nidx)]] [[kernel]] void \
scatter_1d_index<src_t, idx_t, op_t, nidx>( \
const device src_t* updates [[buffer(1)]], \
device mlx_atomic<src_t>* out [[buffer(2)]], \
const constant int* out_shape [[buffer(3)]], \
const constant size_t* out_strides [[buffer(4)]], \
const constant size_t& upd_size [[buffer(5)]], \
IDX_ARG(idx_t) uint2 gid [[thread_position_in_grid]]);
#define instantiate_scatter4(name, src_t, idx_t, op_t, nidx) \
// clang-format off
#define instantiate_scatter4(name, src_t, idx_t, op_t, nidx) \
instantiate_scatter5(name, src_t, idx_t, op_t, nidx, IDX_ARG_ ##nidx) \
instantiate_scatter6(name, src_t, idx_t, op_t, nidx, IDX_ARG_ ##nidx)
instantiate_scatter6(name, src_t, idx_t, op_t, nidx, IDX_ARG_ ##nidx) // clang-format on
// clang-format off
// Special case NINDEX=0
#define instantiate_scatter_nd0(name, type) \
instantiate_scatter4(#name "none", type, bool, None, 0) \
instantiate_scatter4(#name "_sum", type, bool, Sum<type>, 0) \
#define instantiate_scatter_nd0(name, type) \
instantiate_scatter4(#name "none", type, bool, None, 0) \
instantiate_scatter4(#name "_sum", type, bool, Sum<type>, 0) \
instantiate_scatter4(#name "_prod", type, bool, Prod<type>, 0) \
instantiate_scatter4(#name "_max", type, bool, Max<type>, 0) \
instantiate_scatter4(#name "_min", type, bool, Min<type>, 0)
instantiate_scatter4(#name "_max", type, bool, Max<type>, 0) \
instantiate_scatter4(#name "_min", type, bool, Min<type>, 0) // clang-format on
// clang-format off
#define instantiate_scatter3(name, type, ind_type, op_type) \
instantiate_scatter4(name, type, ind_type, op_type, 1) \
instantiate_scatter4(name, type, ind_type, op_type, 2) \
instantiate_scatter4(name, type, ind_type, op_type, 3) \
instantiate_scatter4(name, type, ind_type, op_type, 4) \
instantiate_scatter4(name, type, ind_type, op_type, 5) \
instantiate_scatter4(name, type, ind_type, op_type, 6) \
instantiate_scatter4(name, type, ind_type, op_type, 7) \
instantiate_scatter4(name, type, ind_type, op_type, 8) \
instantiate_scatter4(name, type, ind_type, op_type, 9) \
instantiate_scatter4(name, type, ind_type, op_type, 10)
instantiate_scatter4(name, type, ind_type, op_type, 1) \
instantiate_scatter4(name, type, ind_type, op_type, 2) \
instantiate_scatter4(name, type, ind_type, op_type, 3) \
instantiate_scatter4(name, type, ind_type, op_type, 4) \
instantiate_scatter4(name, type, ind_type, op_type, 5) \
instantiate_scatter4(name, type, ind_type, op_type, 6) \
instantiate_scatter4(name, type, ind_type, op_type, 7) \
instantiate_scatter4(name, type, ind_type, op_type, 8) \
instantiate_scatter4(name, type, ind_type, op_type, 9) \
instantiate_scatter4(name, type, ind_type, op_type, 10) // clang-format on
#define instantiate_scatter2(name, type, ind_type) \
instantiate_scatter3(name "_none", type, ind_type, None) \
instantiate_scatter3(name "_sum", type, ind_type, Sum<type>) \
// clang-format off
#define instantiate_scatter2(name, type, ind_type) \
instantiate_scatter3(name "_none", type, ind_type, None) \
instantiate_scatter3(name "_sum", type, ind_type, Sum<type>) \
instantiate_scatter3(name "_prod", type, ind_type, Prod<type>) \
instantiate_scatter3(name "_max", type, ind_type, Max<type>) \
instantiate_scatter3(name "_min", type, ind_type, Min<type>)
instantiate_scatter3(name "_max", type, ind_type, Max<type>) \
instantiate_scatter3(name "_min", type, ind_type, Min<type>) // clang-format on
#define instantiate_scatter(name, type) \
instantiate_scatter2(#name "bool_", type, bool) \
instantiate_scatter2(#name "uint8", type, uint8_t) \
// clang-format off
#define instantiate_scatter(name, type) \
instantiate_scatter2(#name "bool_", type, bool) \
instantiate_scatter2(#name "uint8", type, uint8_t) \
instantiate_scatter2(#name "uint16", type, uint16_t) \
instantiate_scatter2(#name "uint32", type, uint32_t) \
instantiate_scatter2(#name "uint64", type, uint64_t) \
instantiate_scatter2(#name "int8", type, int8_t) \
instantiate_scatter2(#name "int16", type, int16_t) \
instantiate_scatter2(#name "int32", type, int32_t) \
instantiate_scatter2(#name "int64", type, int64_t)
instantiate_scatter2(#name "int8", type, int8_t) \
instantiate_scatter2(#name "int16", type, int16_t) \
instantiate_scatter2(#name "int32", type, int32_t) \
instantiate_scatter2(#name "int64", type, int64_t) // clang-format on
// clang-format off
// TODO uint64 and int64 unsupported
instantiate_scatter_nd0(bool_, bool)
instantiate_scatter_nd0(uint8, uint8_t)
@@ -254,4 +233,4 @@ instantiate_scatter(int16, int16_t)
instantiate_scatter(int32, int32_t)
instantiate_scatter(float16, half)
instantiate_scatter(float32, float)
instantiate_scatter(bfloat16, bfloat16_t)
instantiate_scatter(bfloat16, bfloat16_t) // clang-format on

View File

@@ -198,17 +198,16 @@ template <typename T, typename AccT = T, int N_READS = SOFTMAX_N_READS>
}
}
// clang-format off
#define instantiate_softmax(name, itype) \
template [[host_name("softmax_" #name)]] [[kernel]] void \
softmax_single_row<itype>( \
const device itype* in, \
device itype* out, \
constant int& axis_size, \
uint gid [[thread_position_in_grid]], \
uint _lid [[thread_position_in_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]); \
#define instantiate_softmax(name, itype) \
template [[host_name("softmax_" #name)]] [[kernel]] void \
softmax_single_row<itype>( \
const device itype* in, \
device itype* out, \
constant int& axis_size, \
uint gid [[thread_position_in_grid]], \
uint _lid [[thread_position_in_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]); \
template [[host_name("softmax_looped_" #name)]] [[kernel]] void \
softmax_looped<itype>( \
const device itype* in, \
@@ -220,16 +219,16 @@ template <typename T, typename AccT = T, int N_READS = SOFTMAX_N_READS>
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
#define instantiate_softmax_precise(name, itype) \
template [[host_name("softmax_precise_" #name)]] [[kernel]] void \
softmax_single_row<itype, float>( \
const device itype* in, \
device itype* out, \
constant int& axis_size, \
uint gid [[thread_position_in_grid]], \
uint _lid [[thread_position_in_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]); \
#define instantiate_softmax_precise(name, itype) \
template [[host_name("softmax_precise_" #name)]] [[kernel]] void \
softmax_single_row<itype, float>( \
const device itype* in, \
device itype* out, \
constant int& axis_size, \
uint gid [[thread_position_in_grid]], \
uint _lid [[thread_position_in_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]); \
template [[host_name("softmax_looped_precise_" #name)]] [[kernel]] void \
softmax_looped<itype, float>( \
const device itype* in, \
@@ -241,9 +240,9 @@ template <typename T, typename AccT = T, int N_READS = SOFTMAX_N_READS>
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
// clang-format off
instantiate_softmax(float32, float)
instantiate_softmax(float16, half)
instantiate_softmax(bfloat16, bfloat16_t)
instantiate_softmax_precise(float16, half)
instantiate_softmax_precise(bfloat16, bfloat16_t)
// clang-format on
instantiate_softmax_precise(bfloat16, bfloat16_t) // clang-format on

View File

@@ -11,7 +11,8 @@
using namespace metal;
// Based on GPU merge sort algorithm at https://github.com/NVIDIA/cccl/tree/main/cub/cub
// Based on GPU merge sort algorithm at
// https://github.com/NVIDIA/cccl/tree/main/cub/cub
///////////////////////////////////////////////////////////////////////////////
// Thread-level sort
@@ -43,20 +44,18 @@ struct ThreadSort {
static METAL_FUNC void sort(
thread val_t (&vals)[N_PER_THREAD],
thread idx_t (&idxs)[N_PER_THREAD]) {
CompareOp op;
MLX_MTL_LOOP_UNROLL
for(short i = 0; i < N_PER_THREAD; ++i) {
MLX_MTL_LOOP_UNROLL
for(short j = i & 1; j < N_PER_THREAD - 1; j += 2) {
if(op(vals[j + 1], vals[j])) {
for (short i = 0; i < N_PER_THREAD; ++i) {
MLX_MTL_LOOP_UNROLL
for (short j = i & 1; j < N_PER_THREAD - 1; j += 2) {
if (op(vals[j + 1], vals[j])) {
thread_swap(vals[j + 1], vals[j]);
thread_swap(idxs[j + 1], idxs[j]);
}
}
}
}
};
@@ -72,25 +71,25 @@ template <
short N_PER_THREAD,
typename CompareOp>
struct BlockMergeSort {
using thread_sort_t = ThreadSort<val_t, idx_t, ARG_SORT, N_PER_THREAD, CompareOp>;
using thread_sort_t =
ThreadSort<val_t, idx_t, ARG_SORT, N_PER_THREAD, CompareOp>;
static METAL_FUNC int merge_partition(
const threadgroup val_t* As,
const threadgroup val_t* Bs,
short A_sz,
short B_sz,
short sort_md) {
CompareOp op;
short A_st = max(0, sort_md - B_sz);
short A_ed = min(sort_md, A_sz);
while(A_st < A_ed) {
while (A_st < A_ed) {
short md = A_st + (A_ed - A_st) / 2;
auto a = As[md];
auto b = Bs[sort_md - 1 - md];
if(op(b, a)) {
if (op(b, a)) {
A_ed = md;
} else {
A_st = md + 1;
@@ -98,7 +97,6 @@ struct BlockMergeSort {
}
return A_ed;
}
static METAL_FUNC void merge_step(
@@ -110,12 +108,11 @@ struct BlockMergeSort {
short B_sz,
thread val_t (&vals)[N_PER_THREAD],
thread idx_t (&idxs)[N_PER_THREAD]) {
CompareOp op;
short a_idx = 0;
short b_idx = 0;
for(int i = 0; i < N_PER_THREAD; ++i) {
for (int i = 0; i < N_PER_THREAD; ++i) {
auto a = As[a_idx];
auto b = Bs[b_idx];
bool pred = (b_idx < B_sz) && (a_idx >= A_sz || op(b, a));
@@ -126,7 +123,6 @@ struct BlockMergeSort {
b_idx += short(pred);
a_idx += short(!pred);
}
}
static METAL_FUNC void sort(
@@ -134,32 +130,32 @@ struct BlockMergeSort {
threadgroup idx_t* tgp_idxs [[threadgroup(1)]],
int size_sorted_axis,
uint3 lid [[thread_position_in_threadgroup]]) {
// Get thread location
int idx = lid.x * N_PER_THREAD;
// Load from shared memory
thread val_t thread_vals[N_PER_THREAD];
thread idx_t thread_idxs[N_PER_THREAD];
for(int i = 0; i < N_PER_THREAD; ++i) {
for (int i = 0; i < N_PER_THREAD; ++i) {
thread_vals[i] = tgp_vals[idx + i];
if(ARG_SORT) {
if (ARG_SORT) {
thread_idxs[i] = tgp_idxs[idx + i];
}
}
// Per thread sort
if(idx < size_sorted_axis) {
// Per thread sort
if (idx < size_sorted_axis) {
thread_sort_t::sort(thread_vals, thread_idxs);
}
// Do merges using threadgroup memory
for (int merge_threads = 2; merge_threads <= BLOCK_THREADS; merge_threads *= 2) {
for (int merge_threads = 2; merge_threads <= BLOCK_THREADS;
merge_threads *= 2) {
// Update threadgroup memory
threadgroup_barrier(mem_flags::mem_threadgroup);
for(int i = 0; i < N_PER_THREAD; ++i) {
for (int i = 0; i < N_PER_THREAD; ++i) {
tgp_vals[idx + i] = thread_vals[i];
if(ARG_SORT) {
if (ARG_SORT) {
tgp_idxs[idx + i] = thread_idxs[i];
}
}
@@ -167,7 +163,7 @@ struct BlockMergeSort {
// Find location in merge step
int merge_group = lid.x / merge_threads;
int merge_lane = lid.x % merge_threads;
int merge_lane = lid.x % merge_threads;
int sort_sz = N_PER_THREAD * merge_threads;
int sort_st = N_PER_THREAD * merge_threads * merge_group;
@@ -185,16 +181,11 @@ struct BlockMergeSort {
int B_sz = B_ed - B_st;
// Find a partition of merge elements
// Ci = merge(As[partition:], Bs[sort_md - partition:])
// Ci = merge(As[partition:], Bs[sort_md - partition:])
// of size N_PER_THREAD for each merge lane i
// C = [Ci] is sorted
// C = [Ci] is sorted
int sort_md = N_PER_THREAD * merge_lane;
int partition = merge_partition(
As,
Bs,
A_sz,
B_sz,
sort_md);
int partition = merge_partition(As, Bs, A_sz, B_sz, sort_md);
As += partition;
Bs += sort_md - partition;
@@ -202,27 +193,20 @@ struct BlockMergeSort {
A_sz -= partition;
B_sz -= sort_md - partition;
const threadgroup idx_t* As_idx = ARG_SORT ? tgp_idxs + A_st + partition : nullptr;
const threadgroup idx_t* Bs_idx = ARG_SORT ? tgp_idxs + B_st + sort_md - partition : nullptr;
const threadgroup idx_t* As_idx =
ARG_SORT ? tgp_idxs + A_st + partition : nullptr;
const threadgroup idx_t* Bs_idx =
ARG_SORT ? tgp_idxs + B_st + sort_md - partition : nullptr;
// Merge starting at the partition and store results in thread registers
merge_step(
As,
Bs,
As_idx,
Bs_idx,
A_sz,
B_sz,
thread_vals,
thread_idxs);
merge_step(As, Bs, As_idx, Bs_idx, A_sz, B_sz, thread_vals, thread_idxs);
}
// Write out to shared memory
threadgroup_barrier(mem_flags::mem_threadgroup);
for(int i = 0; i < N_PER_THREAD; ++i) {
for (int i = 0; i < N_PER_THREAD; ++i) {
tgp_vals[idx + i] = thread_vals[i];
if(ARG_SORT) {
if (ARG_SORT) {
tgp_idxs[idx + i] = thread_idxs[i];
}
}
@@ -235,7 +219,7 @@ struct BlockMergeSort {
template <
typename T,
typename U,
typename U,
bool ARG_SORT,
short BLOCK_THREADS,
short N_PER_THREAD,
@@ -244,13 +228,13 @@ struct KernelMergeSort {
using val_t = T;
using idx_t = uint;
using block_merge_sort_t = BlockMergeSort<
val_t,
val_t,
idx_t,
ARG_SORT,
BLOCK_THREADS,
BLOCK_THREADS,
N_PER_THREAD,
CompareOp>;
MLX_MTL_CONST short N_PER_BLOCK = BLOCK_THREADS * N_PER_THREAD;
static METAL_FUNC void block_sort(
@@ -263,15 +247,15 @@ struct KernelMergeSort {
threadgroup idx_t* tgp_idxs,
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
// tid.y tells us the segment index
inp += tid.y * stride_segment_axis;
out += tid.y * stride_segment_axis;
// Copy into threadgroup memory
for(short i = lid.x; i < N_PER_BLOCK; i+= BLOCK_THREADS) {
tgp_vals[i] = i < size_sorted_axis ? inp[i * stride_sorted_axis] : val_t(CompareOp::init);
if(ARG_SORT) {
for (short i = lid.x; i < N_PER_BLOCK; i += BLOCK_THREADS) {
tgp_vals[i] = i < size_sorted_axis ? inp[i * stride_sorted_axis]
: val_t(CompareOp::init);
if (ARG_SORT) {
tgp_idxs[i] = i;
}
}
@@ -284,8 +268,8 @@ struct KernelMergeSort {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Write output
for(int i = lid.x; i < size_sorted_axis; i+= BLOCK_THREADS) {
if(ARG_SORT) {
for (int i = lid.x; i < size_sorted_axis; i += BLOCK_THREADS) {
if (ARG_SORT) {
out[i * stride_sorted_axis] = tgp_idxs[i];
} else {
out[i * stride_sorted_axis] = tgp_vals[i];
@@ -296,7 +280,7 @@ struct KernelMergeSort {
template <
typename T,
typename U,
typename U,
bool ARG_SORT,
short BLOCK_THREADS,
short N_PER_THREAD>
@@ -308,12 +292,12 @@ template <
const constant int& stride_segment_axis [[buffer(4)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
using sort_kernel = KernelMergeSort<T, U, ARG_SORT, BLOCK_THREADS, N_PER_THREAD>;
using sort_kernel =
KernelMergeSort<T, U, ARG_SORT, BLOCK_THREADS, N_PER_THREAD>;
using val_t = typename sort_kernel::val_t;
using idx_t = typename sort_kernel::idx_t;
if(ARG_SORT) {
if (ARG_SORT) {
threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
threadgroup idx_t tgp_idxs[sort_kernel::N_PER_BLOCK];
sort_kernel::block_sort(
@@ -339,14 +323,13 @@ template <
tid,
lid);
}
}
constant constexpr const int zero_helper = 0;
template <
typename T,
typename U,
typename U,
bool ARG_SORT,
short BLOCK_THREADS,
short N_PER_THREAD>
@@ -360,8 +343,8 @@ template <
const device size_t* nc_strides [[buffer(6)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
using sort_kernel = KernelMergeSort<T, U, ARG_SORT, BLOCK_THREADS, N_PER_THREAD>;
using sort_kernel =
KernelMergeSort<T, U, ARG_SORT, BLOCK_THREADS, N_PER_THREAD>;
using val_t = typename sort_kernel::val_t;
using idx_t = typename sort_kernel::idx_t;
@@ -369,7 +352,7 @@ template <
inp += block_idx;
out += block_idx;
if(ARG_SORT) {
if (ARG_SORT) {
threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
threadgroup idx_t tgp_idxs[sort_kernel::N_PER_BLOCK];
sort_kernel::block_sort(
@@ -395,50 +378,55 @@ template <
tid,
lid);
}
}
///////////////////////////////////////////////////////////////////////////////
// Instantiations
///////////////////////////////////////////////////////////////////////////////
#define instantiate_block_sort(name, itname, itype, otname, otype, arg_sort, bn, tn) \
template [[host_name(#name "_" #itname "_" #otname "_bn" #bn "_tn" #tn)]] \
[[kernel]] void block_sort<itype, otype, arg_sort, bn, tn>( \
const device itype* inp [[buffer(0)]], \
device otype* out [[buffer(1)]], \
const constant int& size_sorted_axis [[buffer(2)]], \
const constant int& stride_sorted_axis [[buffer(3)]], \
const constant int& stride_segment_axis [[buffer(4)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]]); \
template [[host_name(#name "_" #itname "_" #otname "_bn" #bn "_tn" #tn "_nc")]] \
[[kernel]] void block_sort_nc<itype, otype, arg_sort, bn, tn>( \
const device itype* inp [[buffer(0)]], \
device otype* out [[buffer(1)]], \
const constant int& size_sorted_axis [[buffer(2)]], \
const constant int& stride_sorted_axis [[buffer(3)]], \
const constant int& nc_dim [[buffer(4)]], \
const device int* nc_shape [[buffer(5)]], \
const device size_t* nc_strides [[buffer(6)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]]);
#define instantiate_block_sort( \
name, itname, itype, otname, otype, arg_sort, bn, tn) \
template [[host_name(#name "_" #itname "_" #otname "_bn" #bn \
"_tn" #tn)]] [[kernel]] void \
block_sort<itype, otype, arg_sort, bn, tn>( \
const device itype* inp [[buffer(0)]], \
device otype* out [[buffer(1)]], \
const constant int& size_sorted_axis [[buffer(2)]], \
const constant int& stride_sorted_axis [[buffer(3)]], \
const constant int& stride_segment_axis [[buffer(4)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]]); \
template [[host_name(#name "_" #itname "_" #otname "_bn" #bn "_tn" #tn \
"_nc")]] [[kernel]] void \
block_sort_nc<itype, otype, arg_sort, bn, tn>( \
const device itype* inp [[buffer(0)]], \
device otype* out [[buffer(1)]], \
const constant int& size_sorted_axis [[buffer(2)]], \
const constant int& stride_sorted_axis [[buffer(3)]], \
const constant int& nc_dim [[buffer(4)]], \
const device int* nc_shape [[buffer(5)]], \
const device size_t* nc_strides [[buffer(6)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]]);
#define instantiate_arg_block_sort_base(itname, itype, bn, tn) \
instantiate_block_sort(arg_block_merge_sort, itname, itype, uint32, uint32_t, true, bn, tn)
instantiate_block_sort( \
arg_block_merge_sort, itname, itype, uint32, uint32_t, true, bn, tn)
#define instantiate_block_sort_base(itname, itype, bn, tn) \
instantiate_block_sort(block_merge_sort, itname, itype, itname, itype, false, bn, tn)
instantiate_block_sort( \
block_merge_sort, itname, itype, itname, itype, false, bn, tn)
// clang-format off
#define instantiate_block_sort_tn(itname, itype, bn) \
instantiate_block_sort_base(itname, itype, bn, 8) \
instantiate_arg_block_sort_base(itname, itype, bn, 8)
instantiate_block_sort_base(itname, itype, bn, 8) \
instantiate_arg_block_sort_base(itname, itype, bn, 8) // clang-format on
// clang-format off
#define instantiate_block_sort_bn(itname, itype) \
instantiate_block_sort_tn(itname, itype, 128) \
instantiate_block_sort_tn(itname, itype, 256) \
instantiate_block_sort_tn(itname, itype, 512)
instantiate_block_sort_tn(itname, itype, 128) \
instantiate_block_sort_tn(itname, itype, 256) \
instantiate_block_sort_tn(itname, itype, 512)
instantiate_block_sort_bn(uint8, uint8_t)
instantiate_block_sort_bn(uint16, uint16_t)
@@ -448,35 +436,35 @@ instantiate_block_sort_bn(int16, int16_t)
instantiate_block_sort_bn(int32, int32_t)
instantiate_block_sort_bn(float16, half)
instantiate_block_sort_bn(float32, float)
instantiate_block_sort_bn(bfloat16, bfloat16_t)
instantiate_block_sort_bn(bfloat16, bfloat16_t) // clang-format on
// clang-format off
#define instantiate_block_sort_long(itname, itype) \
instantiate_block_sort_tn(itname, itype, 128) \
instantiate_block_sort_tn(itname, itype, 128) \
instantiate_block_sort_tn(itname, itype, 256)
instantiate_block_sort_long(uint64, uint64_t)
instantiate_block_sort_long(int64, int64_t)
instantiate_block_sort_long(int64, int64_t) // clang-format on
///////////////////////////////////////////////////////////////////////////////
// Multi block merge sort
///////////////////////////////////////////////////////////////////////////////
///////////////////////////////////////////////////////////////////////////////
// Multi block merge sort
///////////////////////////////////////////////////////////////////////////////
template <
typename val_t,
typename idx_t,
bool ARG_SORT,
short BLOCK_THREADS,
short N_PER_THREAD,
typename CompareOp = LessThan<val_t>>
struct KernelMultiBlockMergeSort {
template <
typename val_t,
typename idx_t,
bool ARG_SORT,
short BLOCK_THREADS,
short N_PER_THREAD,
typename CompareOp = LessThan<val_t>>
struct KernelMultiBlockMergeSort {
using block_merge_sort_t = BlockMergeSort<
val_t,
val_t,
idx_t,
ARG_SORT,
BLOCK_THREADS,
BLOCK_THREADS,
N_PER_THREAD,
CompareOp>;
MLX_MTL_CONST short N_PER_BLOCK = BLOCK_THREADS * N_PER_THREAD;
static METAL_FUNC void block_sort(
@@ -489,14 +477,14 @@ struct KernelMultiBlockMergeSort {
threadgroup idx_t* tgp_idxs,
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
// tid.y tells us the segment index
int base_idx = tid.x * N_PER_BLOCK;
// Copy into threadgroup memory
for(short i = lid.x; i < N_PER_BLOCK; i+= BLOCK_THREADS) {
for (short i = lid.x; i < N_PER_BLOCK; i += BLOCK_THREADS) {
int idx = base_idx + i;
tgp_vals[i] = idx < size_sorted_axis ? inp[idx * stride_sorted_axis] : val_t(CompareOp::init);
tgp_vals[i] = idx < size_sorted_axis ? inp[idx * stride_sorted_axis]
: val_t(CompareOp::init);
tgp_idxs[i] = idx;
}
@@ -508,9 +496,9 @@ struct KernelMultiBlockMergeSort {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Write output
for(int i = lid.x; i < N_PER_BLOCK; i+= BLOCK_THREADS) {
for (int i = lid.x; i < N_PER_BLOCK; i += BLOCK_THREADS) {
int idx = base_idx + i;
if(idx < size_sorted_axis) {
if (idx < size_sorted_axis) {
out_vals[idx] = tgp_vals[i];
out_idxs[idx] = tgp_idxs[i];
}
@@ -523,18 +511,17 @@ struct KernelMultiBlockMergeSort {
int A_sz,
int B_sz,
int sort_md) {
CompareOp op;
int A_st = max(0, sort_md - B_sz);
int A_ed = min(sort_md, A_sz);
while(A_st < A_ed) {
while (A_st < A_ed) {
int md = A_st + (A_ed - A_st) / 2;
auto a = As[md];
auto b = Bs[sort_md - 1 - md];
if(op(b, a)) {
if (op(b, a)) {
A_ed = md;
} else {
A_st = md + 1;
@@ -542,7 +529,6 @@ struct KernelMultiBlockMergeSort {
}
return A_ed;
}
};
@@ -563,8 +549,12 @@ template <
const device size_t* nc_strides [[buffer(7)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
using sort_kernel = KernelMultiBlockMergeSort<val_t, idx_t, ARG_SORT, BLOCK_THREADS, N_PER_THREAD>;
using sort_kernel = KernelMultiBlockMergeSort<
val_t,
idx_t,
ARG_SORT,
BLOCK_THREADS,
N_PER_THREAD>;
auto block_idx = elem_to_loc(tid.y, nc_shape, nc_strides, nc_dim);
inp += block_idx;
@@ -575,12 +565,12 @@ template <
threadgroup idx_t tgp_idxs[sort_kernel::N_PER_BLOCK];
sort_kernel::block_sort(
inp,
out_vals,
out_idxs,
size_sorted_axis,
stride_sorted_axis,
tgp_vals,
inp,
out_vals,
out_idxs,
size_sorted_axis,
stride_sorted_axis,
tgp_vals,
tgp_idxs,
tid,
lid);
@@ -592,7 +582,8 @@ template <
bool ARG_SORT,
short BLOCK_THREADS,
short N_PER_THREAD>
[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void mb_block_partition(
[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void
mb_block_partition(
device idx_t* block_partitions [[buffer(0)]],
const device val_t* dev_vals [[buffer(1)]],
const device idx_t* dev_idxs [[buffer(2)]],
@@ -601,21 +592,20 @@ template <
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint3 tgp_dims [[threads_per_threadgroup]]) {
using sort_kernel = KernelMultiBlockMergeSort<
val_t,
idx_t,
ARG_SORT,
BLOCK_THREADS,
val_t,
idx_t,
ARG_SORT,
BLOCK_THREADS,
N_PER_THREAD>;
block_partitions += tid.y * tgp_dims.x;
dev_vals += tid.y * size_sorted_axis;
dev_idxs += tid.y * size_sorted_axis;
// Find location in merge step
int merge_group = lid.x / merge_tiles;
int merge_lane = lid.x % merge_tiles;
int merge_lane = lid.x % merge_tiles;
int sort_sz = sort_kernel::N_PER_BLOCK * merge_tiles;
int sort_st = sort_kernel::N_PER_BLOCK * merge_tiles * merge_group;
@@ -627,14 +617,9 @@ template <
int partition_at = min(B_ed - A_st, sort_kernel::N_PER_BLOCK * merge_lane);
int partition = sort_kernel::merge_partition(
dev_vals + A_st,
dev_vals + B_st,
A_ed - A_st,
B_ed - B_st,
partition_at);
dev_vals + A_st, dev_vals + B_st, A_ed - A_st, B_ed - B_st, partition_at);
block_partitions[lid.x] = A_st + partition;
}
template <
@@ -644,7 +629,8 @@ template <
short BLOCK_THREADS,
short N_PER_THREAD,
typename CompareOp = LessThan<val_t>>
[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void mb_block_merge(
[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void
mb_block_merge(
const device idx_t* block_partitions [[buffer(0)]],
const device val_t* dev_vals_in [[buffer(1)]],
const device idx_t* dev_idxs_in [[buffer(2)]],
@@ -655,20 +641,19 @@ template <
const constant int& num_tiles [[buffer(7)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
using sort_kernel = KernelMultiBlockMergeSort<
val_t,
idx_t,
ARG_SORT,
BLOCK_THREADS,
N_PER_THREAD,
val_t,
idx_t,
ARG_SORT,
BLOCK_THREADS,
N_PER_THREAD,
CompareOp>;
using block_sort_t = typename sort_kernel::block_merge_sort_t;
block_partitions += tid.y * (num_tiles + 1);
dev_vals_in += tid.y * size_sorted_axis;
dev_idxs_in += tid.y * size_sorted_axis;
dev_vals_in += tid.y * size_sorted_axis;
dev_idxs_in += tid.y * size_sorted_axis;
dev_vals_out += tid.y * size_sorted_axis;
dev_idxs_out += tid.y * size_sorted_axis;
@@ -680,25 +665,29 @@ template <
int A_st = block_partitions[block_idx + 0];
int A_ed = block_partitions[block_idx + 1];
int B_st = min(size_sorted_axis, 2 * sort_st + sort_sz/2 + sort_md - A_st);
int B_ed = min(size_sorted_axis, 2 * sort_st + sort_sz/2 + sort_md + sort_kernel::N_PER_BLOCK - A_ed);
int B_st = min(size_sorted_axis, 2 * sort_st + sort_sz / 2 + sort_md - A_st);
int B_ed = min(
size_sorted_axis,
2 * sort_st + sort_sz / 2 + sort_md + sort_kernel::N_PER_BLOCK - A_ed);
if((block_idx % merge_tiles) == merge_tiles - 1) {
A_ed = min(size_sorted_axis, sort_st + sort_sz/2);
if ((block_idx % merge_tiles) == merge_tiles - 1) {
A_ed = min(size_sorted_axis, sort_st + sort_sz / 2);
B_ed = min(size_sorted_axis, sort_st + sort_sz);
}
int A_sz = A_ed - A_st;
int B_sz = B_ed - B_st;
// Load from global memory
thread val_t thread_vals[N_PER_THREAD];
thread idx_t thread_idxs[N_PER_THREAD];
for(int i = 0; i < N_PER_THREAD; i++) {
for (int i = 0; i < N_PER_THREAD; i++) {
int idx = BLOCK_THREADS * i + lid.x;
if(idx < (A_sz + B_sz)) {
thread_vals[i] = (idx < A_sz) ? dev_vals_in[A_st + idx] : dev_vals_in[B_st + idx - A_sz];
thread_idxs[i] = (idx < A_sz) ? dev_idxs_in[A_st + idx] : dev_idxs_in[B_st + idx - A_sz];
if (idx < (A_sz + B_sz)) {
thread_vals[i] = (idx < A_sz) ? dev_vals_in[A_st + idx]
: dev_vals_in[B_st + idx - A_sz];
thread_idxs[i] = (idx < A_sz) ? dev_idxs_in[A_st + idx]
: dev_idxs_in[B_st + idx - A_sz];
} else {
thread_vals[i] = CompareOp::init;
thread_idxs[i] = 0;
@@ -709,7 +698,7 @@ template <
threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
threadgroup idx_t tgp_idxs[sort_kernel::N_PER_BLOCK];
threadgroup_barrier(mem_flags::mem_threadgroup);
for(int i = 0; i < N_PER_THREAD; i++) {
for (int i = 0; i < N_PER_THREAD; i++) {
int idx = BLOCK_THREADS * i + lid.x;
tgp_vals[idx] = thread_vals[i];
tgp_idxs[idx] = thread_idxs[i];
@@ -720,11 +709,7 @@ template <
int sort_md_local = min(A_sz + B_sz, N_PER_THREAD * int(lid.x));
int A_st_local = block_sort_t::merge_partition(
tgp_vals,
tgp_vals + A_sz,
A_sz,
B_sz,
sort_md_local);
tgp_vals, tgp_vals + A_sz, A_sz, B_sz, sort_md_local);
int A_ed_local = A_sz;
int B_st_local = sort_md_local - A_st_local;
@@ -733,7 +718,7 @@ template <
int A_sz_local = A_ed_local - A_st_local;
int B_sz_local = B_ed_local - B_st_local;
// Do merge
// Do merge
block_sort_t::merge_step(
tgp_vals + A_st_local,
tgp_vals + A_ed_local + B_st_local,
@@ -745,61 +730,65 @@ template <
thread_idxs);
threadgroup_barrier(mem_flags::mem_threadgroup);
for(int i = 0; i < N_PER_THREAD; ++i) {
for (int i = 0; i < N_PER_THREAD; ++i) {
int idx = lid.x * N_PER_THREAD;
tgp_vals[idx + i] = thread_vals[i];
tgp_idxs[idx + i] = thread_idxs[i];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Write output
int base_idx = tid.x * sort_kernel::N_PER_BLOCK;
for(int i = lid.x; i < sort_kernel::N_PER_BLOCK; i+= BLOCK_THREADS) {
for (int i = lid.x; i < sort_kernel::N_PER_BLOCK; i += BLOCK_THREADS) {
int idx = base_idx + i;
if(idx < size_sorted_axis) {
if (idx < size_sorted_axis) {
dev_vals_out[idx] = tgp_vals[i];
dev_idxs_out[idx] = tgp_idxs[i];
}
}
}
#define instantiate_multi_block_sort(vtname, vtype, itname, itype, arg_sort, bn, tn) \
template [[host_name("mb_block_sort_" #vtname "_" #itname "_bn" #bn "_tn" #tn)]] \
[[kernel]] void mb_block_sort<vtype, itype, arg_sort, bn, tn>( \
const device vtype* inp [[buffer(0)]], \
device vtype* out_vals [[buffer(1)]], \
device itype* out_idxs [[buffer(2)]], \
const constant int& size_sorted_axis [[buffer(3)]], \
const constant int& stride_sorted_axis [[buffer(4)]], \
const constant int& nc_dim [[buffer(5)]], \
const device int* nc_shape [[buffer(6)]], \
const device size_t* nc_strides [[buffer(7)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]]); \
template [[host_name("mb_block_partition_" #vtname "_" #itname "_bn" #bn "_tn" #tn)]] \
[[kernel]] void mb_block_partition<vtype, itype, arg_sort, bn, tn>( \
device itype* block_partitions [[buffer(0)]], \
const device vtype* dev_vals [[buffer(1)]], \
const device itype* dev_idxs [[buffer(2)]], \
const constant int& size_sorted_axis [[buffer(3)]], \
const constant int& merge_tiles [[buffer(4)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint3 tgp_dims [[threads_per_threadgroup]]); \
template [[host_name("mb_block_merge_" #vtname "_" #itname "_bn" #bn "_tn" #tn)]] \
[[kernel]] void mb_block_merge<vtype, itype, arg_sort, bn, tn>( \
const device itype* block_partitions [[buffer(0)]], \
const device vtype* dev_vals_in [[buffer(1)]], \
const device itype* dev_idxs_in [[buffer(2)]], \
device vtype* dev_vals_out [[buffer(3)]], \
device itype* dev_idxs_out [[buffer(4)]], \
const constant int& size_sorted_axis [[buffer(5)]], \
const constant int& merge_tiles [[buffer(6)]], \
const constant int& num_tiles [[buffer(7)]], \
uint3 tid [[threadgroup_position_in_grid]], \
#define instantiate_multi_block_sort( \
vtname, vtype, itname, itype, arg_sort, bn, tn) \
template [[host_name("mb_block_sort_" #vtname "_" #itname "_bn" #bn \
"_tn" #tn)]] [[kernel]] void \
mb_block_sort<vtype, itype, arg_sort, bn, tn>( \
const device vtype* inp [[buffer(0)]], \
device vtype* out_vals [[buffer(1)]], \
device itype* out_idxs [[buffer(2)]], \
const constant int& size_sorted_axis [[buffer(3)]], \
const constant int& stride_sorted_axis [[buffer(4)]], \
const constant int& nc_dim [[buffer(5)]], \
const device int* nc_shape [[buffer(6)]], \
const device size_t* nc_strides [[buffer(7)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]]); \
template [[host_name("mb_block_partition_" #vtname "_" #itname "_bn" #bn \
"_tn" #tn)]] [[kernel]] void \
mb_block_partition<vtype, itype, arg_sort, bn, tn>( \
device itype * block_partitions [[buffer(0)]], \
const device vtype* dev_vals [[buffer(1)]], \
const device itype* dev_idxs [[buffer(2)]], \
const constant int& size_sorted_axis [[buffer(3)]], \
const constant int& merge_tiles [[buffer(4)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint3 tgp_dims [[threads_per_threadgroup]]); \
template [[host_name("mb_block_merge_" #vtname "_" #itname "_bn" #bn \
"_tn" #tn)]] [[kernel]] void \
mb_block_merge<vtype, itype, arg_sort, bn, tn>( \
const device itype* block_partitions [[buffer(0)]], \
const device vtype* dev_vals_in [[buffer(1)]], \
const device itype* dev_idxs_in [[buffer(2)]], \
device vtype* dev_vals_out [[buffer(3)]], \
device itype* dev_idxs_out [[buffer(4)]], \
const constant int& size_sorted_axis [[buffer(5)]], \
const constant int& merge_tiles [[buffer(6)]], \
const constant int& num_tiles [[buffer(7)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]]);
// clang-format off
#define instantiate_multi_block_sort_base(vtname, vtype) \
instantiate_multi_block_sort(vtname, vtype, uint32, uint32_t, true, 512, 8)
@@ -811,10 +800,11 @@ instantiate_multi_block_sort_base(int16, int16_t)
instantiate_multi_block_sort_base(int32, int32_t)
instantiate_multi_block_sort_base(float16, half)
instantiate_multi_block_sort_base(float32, float)
instantiate_multi_block_sort_base(bfloat16, bfloat16_t)
instantiate_multi_block_sort_base(bfloat16, bfloat16_t) // clang-format on
// clang-format off
#define instantiate_multi_block_sort_long(vtname, vtype) \
instantiate_multi_block_sort(vtname, vtype, uint32, uint32_t, true, 256, 8)
instantiate_multi_block_sort_long(uint64, uint64_t)
instantiate_multi_block_sort_long(int64, int64_t)
instantiate_multi_block_sort_long(int64, int64_t) // clang-format on

View File

@@ -4,21 +4,23 @@
#include "mlx/backend/metal/kernels/steel/gemm/mma.h"
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/steel/conv/conv.h"
#include "mlx/backend/metal/kernels/steel/conv/params.h"
#include "mlx/backend/metal/kernels/bf16.h"
using namespace metal;
template <typename T,
int BM,
int BN,
int BK,
int WM,
int WN,
int N_CHANNELS = 0,
bool SMALL_FILTER = false>
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void implicit_gemm_conv_2d(
template <
typename T,
int BM,
int BN,
int BK,
int WM,
int WN,
int N_CHANNELS = 0,
bool SMALL_FILTER = false>
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void
implicit_gemm_conv_2d(
const device T* A [[buffer(0)]],
const device T* B [[buffer(1)]],
device T* C [[buffer(2)]],
@@ -28,12 +30,10 @@ template <typename T,
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
using namespace mlx::steel;
(void)lid;
constexpr bool transpose_a = false;
constexpr bool transpose_b = true;
constexpr short tgp_padding_a = 16 / sizeof(T);
@@ -47,46 +47,64 @@ template <typename T,
constexpr short tgp_mem_size_b = shape_b_cols * shape_b_rows;
constexpr short tgp_size = WM * WN * 32;
// Input loader
// Input loader
using loader_a_t = typename metal::conditional_t<
// Check for small channel specialization
N_CHANNELS != 0 && N_CHANNELS <= 4,
// Go to small channel specialization
Conv2DInputBlockLoaderSmallChannels<
T, BM, BN, BK, tgp_size, N_CHANNELS, tgp_padding_a>,
// Go to small channel specialization
Conv2DInputBlockLoaderSmallChannels<
T,
BM,
BN,
BK,
tgp_size,
N_CHANNELS,
tgp_padding_a>,
// Else go to general loader
typename metal::conditional_t<
// Check if filter size is small enough
SMALL_FILTER,
// Else go to general loader
typename metal::conditional_t<
// Check if filter size is small enough
SMALL_FILTER,
// Go to small filter specialization
Conv2DInputBlockLoaderSmallFilter<
T, BM, BN, BK, tgp_size, tgp_padding_a>,
// Else go to large filter generalization
Conv2DInputBlockLoaderLargeFilter<
T, BM, BN, BK, tgp_size, tgp_padding_a>
>
>;
// Go to small filter specialization
Conv2DInputBlockLoaderSmallFilter<
T,
BM,
BN,
BK,
tgp_size,
tgp_padding_a>,
// Else go to large filter generalization
Conv2DInputBlockLoaderLargeFilter<
T,
BM,
BN,
BK,
tgp_size,
tgp_padding_a>>>;
// Weight loader
using loader_b_t = typename metal::conditional_t<
// Check for small channel specialization
N_CHANNELS != 0 && N_CHANNELS <= 4,
// Go to small channel specialization
Conv2DWeightBlockLoaderSmallChannels<
T, BM, BN, BK, tgp_size, N_CHANNELS, tgp_padding_b>,
// Go to small channel specialization
Conv2DWeightBlockLoaderSmallChannels<
T,
BM,
BN,
BK,
tgp_size,
N_CHANNELS,
tgp_padding_b>,
// Else go to general loader
Conv2DWeightBlockLoader<T, BM, BN, BK, tgp_size, tgp_padding_b>>;
// Else go to general loader
Conv2DWeightBlockLoader<T, BM, BN, BK, tgp_size, tgp_padding_b>
>;
using mma_t = BlockMMA<
T,
T,
@@ -99,12 +117,12 @@ template <typename T,
transpose_b,
shape_a_cols,
shape_b_cols>;
threadgroup T As[tgp_mem_size_a];
threadgroup T Bs[tgp_mem_size_b];
const int tid_y = ((tid.y) << gemm_params->swizzle_log) +
((tid.x) & ((1 << gemm_params->swizzle_log) - 1));
((tid.x) & ((1 << gemm_params->swizzle_log) - 1));
const int tid_x = (tid.x) >> gemm_params->swizzle_log;
if (gemm_params->tiles_n <= tid_x || gemm_params->tiles_m <= tid_y) {
@@ -123,8 +141,10 @@ template <typename T,
const int2 offsets_b(0, c_col);
// Prepare threadgroup loading operations
loader_a_t loader_a(A, As, offsets_a, params, gemm_params, simd_gid, simd_lid);
loader_b_t loader_b(B, Bs, offsets_b, params, gemm_params, simd_gid, simd_lid);
loader_a_t loader_a(
A, As, offsets_a, params, gemm_params, simd_gid, simd_lid);
loader_b_t loader_b(
B, Bs, offsets_b, params, gemm_params, simd_gid, simd_lid);
// Prepare threadgroup mma operation
mma_t mma_op(simd_gid, simd_lid);
@@ -152,38 +172,53 @@ template <typename T,
short tgp_bm = min(BM, gemm_params->M - c_row);
short tgp_bn = min(BN, gemm_params->N - c_col);
mma_op.store_result_safe(C, N, short2(tgp_bn, tgp_bm));
}
#define instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, channel_name, n_channels, filter_name, small_filter) \
template [[host_name("implicit_gemm_conv_2d_" #name "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_channel_" #channel_name "_filter_" #filter_name)]] \
[[kernel]] void implicit_gemm_conv_2d<itype, bm, bn, bk, wm, wn, n_channels, small_filter>( \
const device itype* A [[buffer(0)]], \
const device itype* B [[buffer(1)]], \
device itype* C [[buffer(2)]], \
const constant MLXConvParams<2>* params [[buffer(3)]], \
const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
#define instantiate_implicit_conv_2d( \
name, \
itype, \
bm, \
bn, \
bk, \
wm, \
wn, \
channel_name, \
n_channels, \
filter_name, \
small_filter) \
template [[host_name("implicit_gemm_conv_2d_" #name "_bm" #bm "_bn" #bn \
"_bk" #bk "_wm" #wm "_wn" #wn "_channel_" #channel_name \
"_filter_" #filter_name)]] [[kernel]] void \
implicit_gemm_conv_2d<itype, bm, bn, bk, wm, wn, n_channels, small_filter>( \
const device itype* A [[buffer(0)]], \
const device itype* B [[buffer(1)]], \
device itype* C [[buffer(2)]], \
const constant MLXConvParams<2>* params [[buffer(3)]], \
const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);
#define instantiate_implicit_2d_filter(name, itype, bm, bn, bk, wm, wn) \
instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, l, 0, s, true) \
// clang-format off
#define instantiate_implicit_2d_filter(name, itype, bm, bn, bk, wm, wn) \
instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, l, 0, s, true) \
instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, l, 0, l, false) \
instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, 1, 1, l, false) \
instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, 2, 2, l, false) \
instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, 3, 3, l, false) \
instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, 4, 4, l, false)
instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, 4, 4, l, false) // clang-format on
#define instantiate_implicit_2d_blocks(name, itype) \
// clang-format off
#define instantiate_implicit_2d_blocks(name, itype) \
instantiate_implicit_2d_filter(name, itype, 32, 8, 16, 4, 1) \
instantiate_implicit_2d_filter(name, itype, 64, 8, 16, 4, 1) \
instantiate_implicit_2d_filter(name, itype, 32, 32, 16, 2, 2) \
instantiate_implicit_2d_filter(name, itype, 32, 64, 16, 2, 2) \
instantiate_implicit_2d_filter(name, itype, 64, 32, 16, 2, 2) \
instantiate_implicit_2d_filter(name, itype, 64, 64, 16, 2, 2)
instantiate_implicit_2d_filter(name, itype, 64, 64, 16, 2, 2) // clang-format on
// clang-format off
instantiate_implicit_2d_blocks(float32, float);
instantiate_implicit_2d_blocks(float16, half);
instantiate_implicit_2d_blocks(bfloat16, bfloat16_t);
instantiate_implicit_2d_blocks(bfloat16, bfloat16_t); // clang-format on

View File

@@ -4,23 +4,25 @@
#include "mlx/backend/metal/kernels/steel/gemm/mma.h"
#include "mlx/backend/metal/kernels/steel/conv/conv.h"
#include "mlx/backend/metal/kernels/steel/conv/params.h"
#include "mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h"
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/steel/conv/conv.h"
#include "mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h"
#include "mlx/backend/metal/kernels/steel/conv/params.h"
using namespace metal;
using namespace mlx::steel;
template <typename T,
int BM,
int BN,
int BK,
int WM,
int WN,
typename AccumType = float,
typename Epilogue = TransformNone<T, AccumType>>
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void implicit_gemm_conv_2d_general(
template <
typename T,
int BM,
int BN,
int BK,
int WM,
int WN,
typename AccumType = float,
typename Epilogue = TransformNone<T, AccumType>>
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void
implicit_gemm_conv_2d_general(
const device T* A [[buffer(0)]],
const device T* B [[buffer(1)]],
device T* C [[buffer(2)]],
@@ -33,9 +35,8 @@ template <typename T,
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
(void)lid;
constexpr bool transpose_a = false;
constexpr bool transpose_b = true;
constexpr short tgp_padding_a = 16 / sizeof(T);
@@ -49,15 +50,15 @@ template <typename T,
constexpr short tgp_mem_size_b = shape_b_cols * shape_b_rows;
constexpr short tgp_size = WM * WN * 32;
// Input loader
using loader_a_t = Conv2DInputBlockLoaderGeneral<
T, BM, BN, BK, tgp_size, tgp_padding_a>;
// Input loader
using loader_a_t =
Conv2DInputBlockLoaderGeneral<T, BM, BN, BK, tgp_size, tgp_padding_a>;
// Weight loader
using loader_b_t = Conv2DWeightBlockLoaderGeneral<
T, BM, BN, BK, tgp_size, tgp_padding_b>;
using loader_b_t =
Conv2DWeightBlockLoaderGeneral<T, BM, BN, BK, tgp_size, tgp_padding_b>;
using mma_t = BlockMMA<
T,
T,
@@ -70,12 +71,12 @@ template <typename T,
transpose_b,
shape_a_cols,
shape_b_cols>;
threadgroup T As[tgp_mem_size_a];
threadgroup T Bs[tgp_mem_size_b];
const int tid_y = ((tid.y) << gemm_params->swizzle_log) +
((tid.x) & ((1 << gemm_params->swizzle_log) - 1));
((tid.x) & ((1 << gemm_params->swizzle_log) - 1));
const int tid_x = (tid.x) >> gemm_params->swizzle_log;
if (gemm_params->tiles_n <= tid_x || gemm_params->tiles_m <= tid_y) {
@@ -103,13 +104,32 @@ template <typename T,
const int2 offsets_b(0, c_col);
// Prepare threadgroup loading operations
loader_a_t loader_a(A, As, offsets_a, params, jump_params, base_wh, base_ww, simd_gid, simd_lid);
loader_b_t loader_b(B, Bs, offsets_b, params, jump_params, base_wh, base_ww, simd_gid, simd_lid);
loader_a_t loader_a(
A,
As,
offsets_a,
params,
jump_params,
base_wh,
base_ww,
simd_gid,
simd_lid);
loader_b_t loader_b(
B,
Bs,
offsets_b,
params,
jump_params,
base_wh,
base_ww,
simd_gid,
simd_lid);
// Prepare threadgroup mma operation
mma_t mma_op(simd_gid, simd_lid);
int gemm_k_iterations = base_wh_size * base_ww_size * gemm_params->gemm_k_iterations;
int gemm_k_iterations =
base_wh_size * base_ww_size * gemm_params->gemm_k_iterations;
for (int k = 0; k < gemm_k_iterations; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
@@ -143,22 +163,24 @@ template <typename T,
STEEL_PRAGMA_UNROLL
for (int i = 0; i < mma_t::TM; i++) {
int cm = offset_m + i * mma_t::TM_stride;
int n = cm / jump_params->adj_out_hw;
int hw = cm % jump_params->adj_out_hw;
int oh = (hw / jump_params->adj_out_w) * jump_params->f_out_jump_h + base_oh;
int ow = (hw % jump_params->adj_out_w) * jump_params->f_out_jump_w + base_ow;
int oh =
(hw / jump_params->adj_out_w) * jump_params->f_out_jump_h + base_oh;
int ow =
(hw % jump_params->adj_out_w) * jump_params->f_out_jump_w + base_ow;
if(n < params->N && oh < params->oS[0] && ow < params->oS[1]) {
int offset_cm = n * params->out_strides[0] + oh * params->out_strides[1] + ow * params->out_strides[2];
if (n < params->N && oh < params->oS[0] && ow < params->oS[1]) {
int offset_cm = n * params->out_strides[0] +
oh * params->out_strides[1] + ow * params->out_strides[2];
STEEL_PRAGMA_UNROLL
for (int j = 0; j < mma_t::TN; j++) {
// Get accumulated result and associated offset in C
thread const auto& accum = mma_op.results[i * mma_t::TN + j].thread_elements();
thread const auto& accum =
mma_op.results[i * mma_t::TN + j].thread_elements();
int offset = offset_cm + (j * mma_t::TN_stride);
// Apply epilogue and output C
@@ -170,40 +192,42 @@ template <typename T,
C[offset + 1] = Epilogue::apply(accum[1]);
}
}
}
}
}
}
#define instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn) \
template [[host_name("implicit_gemm_conv_2d_general_" #name "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn)]] \
[[kernel]] void implicit_gemm_conv_2d_general<itype, bm, bn, bk, wm, wn>( \
const device itype* A [[buffer(0)]], \
const device itype* B [[buffer(1)]], \
device itype* C [[buffer(2)]], \
const constant MLXConvParams<2>* params [[buffer(3)]], \
const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]], \
const constant Conv2DGeneralJumpParams* jump_params [[buffer(5)]], \
const constant Conv2DGeneralBaseInfo* base_h [[buffer(6)]], \
const constant Conv2DGeneralBaseInfo* base_w [[buffer(7)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);
#define instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn) \
template \
[[host_name("implicit_gemm_conv_2d_general_" #name "_bm" #bm "_bn" #bn \
"_bk" #bk "_wm" #wm "_wn" #wn)]] [[kernel]] void \
implicit_gemm_conv_2d_general<itype, bm, bn, bk, wm, wn>( \
const device itype* A [[buffer(0)]], \
const device itype* B [[buffer(1)]], \
device itype* C [[buffer(2)]], \
const constant MLXConvParams<2>* params [[buffer(3)]], \
const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]], \
const constant Conv2DGeneralJumpParams* jump_params [[buffer(5)]], \
const constant Conv2DGeneralBaseInfo* base_h [[buffer(6)]], \
const constant Conv2DGeneralBaseInfo* base_w [[buffer(7)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);
#define instantiate_implicit_2d_filter(name, itype, bm, bn, bk, wm, wn) \
instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn)
instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn)
#define instantiate_implicit_2d_blocks(name, itype) \
// clang-format off
#define instantiate_implicit_2d_blocks(name, itype) \
instantiate_implicit_2d_filter(name, itype, 32, 8, 16, 4, 1) \
instantiate_implicit_2d_filter(name, itype, 64, 8, 16, 4, 1) \
instantiate_implicit_2d_filter(name, itype, 32, 32, 16, 2, 2) \
instantiate_implicit_2d_filter(name, itype, 32, 64, 16, 2, 2) \
instantiate_implicit_2d_filter(name, itype, 64, 32, 16, 2, 2) \
instantiate_implicit_2d_filter(name, itype, 64, 64, 16, 2, 2)
instantiate_implicit_2d_filter(name, itype, 64, 64, 16, 2, 2) // clang-format on
// clang-format off
instantiate_implicit_2d_blocks(float32, float);
instantiate_implicit_2d_blocks(float16, half);
instantiate_implicit_2d_blocks(bfloat16, bfloat16_t);
instantiate_implicit_2d_blocks(bfloat16, bfloat16_t); // clang-format on

View File

@@ -164,10 +164,12 @@ struct GEMMKernel {
// Find block in A, B, C
const int c_row = tid_y * BM;
const int c_col = tid_x * BN;
const size_t c_row_long = size_t(c_row);
const size_t c_col_long = size_t(c_col);
A += transpose_a ? c_row : c_row * params->lda;
B += transpose_b ? c_col * params->ldb : c_col;
D += c_row * params->ldd + c_col;
A += transpose_a ? c_row_long : c_row_long * params->lda;
B += transpose_b ? c_col_long * params->ldb : c_col_long;
D += c_row_long * params->ldd + c_col_long;
// Prepare threadgroup loading operations
thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);

View File

@@ -1,8 +1,8 @@
// Copyright © 2024 Apple Inc.
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
#include "mlx/backend/metal/kernels/utils.h"
using namespace metal;
using namespace mlx::steel;
@@ -11,97 +11,126 @@ using namespace mlx::steel;
// GEMM kernels
///////////////////////////////////////////////////////////////////////////////
template <typename T,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_a,
bool transpose_b,
bool MN_aligned,
bool K_aligned>
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void gemm(
const device T *A [[buffer(0)]],
const device T *B [[buffer(1)]],
device T *D [[buffer(3)]],
template <
typename T,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_a,
bool transpose_b,
bool MN_aligned,
bool K_aligned>
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void gemm(
const device T* A [[buffer(0)]],
const device T* B [[buffer(1)]],
device T* D [[buffer(3)]],
const constant GEMMParams* params [[buffer(4)]],
const constant int* batch_shape [[buffer(6)]],
const constant size_t* batch_strides [[buffer(7)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
using gemm_kernel = GEMMKernel<T, T, BM, BN, BK, WM, WN, transpose_a, transpose_b, MN_aligned, K_aligned>;
threadgroup T As[gemm_kernel::tgp_mem_size_a];
threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
uint3 lid [[thread_position_in_threadgroup]]) {
using gemm_kernel = GEMMKernel<
T,
T,
BM,
BN,
BK,
WM,
WN,
transpose_a,
transpose_b,
MN_aligned,
K_aligned>;
// Adjust for batch
if(params->batch_ndim > 1) {
const constant size_t* A_bstrides = batch_strides;
const constant size_t* B_bstrides = batch_strides + params->batch_ndim;
threadgroup T As[gemm_kernel::tgp_mem_size_a];
threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
ulong2 batch_offsets = elem_to_loc_broadcast(
tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim);
// Adjust for batch
if (params->batch_ndim > 1) {
const constant size_t* A_bstrides = batch_strides;
const constant size_t* B_bstrides = batch_strides + params->batch_ndim;
A += batch_offsets.x;
B += batch_offsets.y;
} else {
A += params->batch_stride_a * tid.z;
B += params->batch_stride_b * tid.z;
}
D += params->batch_stride_d * tid.z;
ulong2 batch_offsets = elem_to_loc_broadcast(
tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim);
gemm_kernel::run(
A, B, D,
params,
As, Bs,
simd_lane_id, simd_group_id, tid, lid
);
A += batch_offsets.x;
B += batch_offsets.y;
} else {
A += params->batch_stride_a * tid.z;
B += params->batch_stride_b * tid.z;
}
D += params->batch_stride_d * tid.z;
gemm_kernel::run(
A, B, D, params, As, Bs, simd_lane_id, simd_group_id, tid, lid);
}
///////////////////////////////////////////////////////////////////////////////
// GEMM kernel initializations
///////////////////////////////////////////////////////////////////////////////
#define instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
template [[host_name("steel_gemm_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_MN_" #aname "_K_" #kname)]] \
[[kernel]] void gemm<itype, bm, bn, bk, wm, wn, trans_a, trans_b, mn_aligned, k_aligned>( \
const device itype *A [[buffer(0)]], \
const device itype *B [[buffer(1)]], \
device itype *D [[buffer(3)]], \
const constant GEMMParams* params [[buffer(4)]], \
const constant int* batch_shape [[buffer(6)]], \
const constant size_t* batch_strides [[buffer(7)]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
uint3 tid [[threadgroup_position_in_grid]], \
#define instantiate_gemm( \
tname, \
trans_a, \
trans_b, \
iname, \
itype, \
oname, \
otype, \
bm, \
bn, \
bk, \
wm, \
wn, \
aname, \
mn_aligned, \
kname, \
k_aligned) \
template [[host_name("steel_gemm_" #tname "_" #iname "_" #oname "_bm" #bm \
"_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_MN_" #aname \
"_K_" #kname)]] [[kernel]] void \
gemm<itype, bm, bn, bk, wm, wn, trans_a, trans_b, mn_aligned, k_aligned>( \
const device itype* A [[buffer(0)]], \
const device itype* B [[buffer(1)]], \
device itype* D [[buffer(3)]], \
const constant GEMMParams* params [[buffer(4)]], \
const constant int* batch_shape [[buffer(6)]], \
const constant size_t* batch_strides [[buffer(7)]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]]);
#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \
// clang-format off
#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, naligned, false) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, taligned, true) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false)
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false) // clang-format on
#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
// clang-format off
#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn)
instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn) // clang-format on
#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \
// clang-format off
#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 16, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 2, 2)
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 2, 2) // clang-format on
// clang-format off
instantiate_gemm_shapes_helper(float16, half, float16, half);
instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t);
instantiate_gemm_shapes_helper(float32, float, float32, float);
instantiate_gemm_shapes_helper(float32, float, float32, float); // clang-format on

View File

@@ -10,23 +10,24 @@ using namespace mlx::steel;
// GEMM kernels
///////////////////////////////////////////////////////////////////////////////
template <typename T,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_a,
bool transpose_b,
bool MN_aligned,
bool K_aligned,
typename AccumType = float,
typename Epilogue = TransformAdd<T, AccumType>>
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void addmm(
const device T *A [[buffer(0)]],
const device T *B [[buffer(1)]],
const device T *C [[buffer(2)]],
device T *D [[buffer(3)]],
template <
typename T,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_a,
bool transpose_b,
bool MN_aligned,
bool K_aligned,
typename AccumType = float,
typename Epilogue = TransformAdd<T, AccumType>>
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void addmm(
const device T* A [[buffer(0)]],
const device T* B [[buffer(1)]],
const device T* C [[buffer(2)]],
device T* D [[buffer(3)]],
const constant GEMMParams* params [[buffer(4)]],
const constant GEMMAddMMParams* addmm_params [[buffer(5)]],
const constant int* batch_shape [[buffer(6)]],
@@ -34,243 +35,306 @@ template <typename T,
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
// Pacifying compiler
(void)lid;
using gemm_kernel =
GEMMKernel<T, T, BM, BN, BK, WM, WN,
transpose_a, transpose_b,
MN_aligned, K_aligned,
AccumType, Epilogue>;
using loader_a_t = typename gemm_kernel::loader_a_t;
using loader_b_t = typename gemm_kernel::loader_b_t;
using mma_t = typename gemm_kernel::mma_t;
threadgroup T As[gemm_kernel::tgp_mem_size_a];
threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
uint3 lid [[thread_position_in_threadgroup]]) {
// Pacifying compiler
(void)lid;
// Adjust for batch
if(params->batch_ndim > 1) {
const constant size_t* A_bstrides = batch_strides;
const constant size_t* B_bstrides = batch_strides + params->batch_ndim;
const constant size_t* C_bstrides = B_bstrides + params->batch_ndim;
using gemm_kernel = GEMMKernel<
T,
T,
BM,
BN,
BK,
WM,
WN,
transpose_a,
transpose_b,
MN_aligned,
K_aligned,
AccumType,
Epilogue>;
ulong3 batch_offsets = elem_to_loc_broadcast(
tid.z, batch_shape, A_bstrides, B_bstrides, C_bstrides, params->batch_ndim);
using loader_a_t = typename gemm_kernel::loader_a_t;
using loader_b_t = typename gemm_kernel::loader_b_t;
using mma_t = typename gemm_kernel::mma_t;
A += batch_offsets.x;
B += batch_offsets.y;
C += batch_offsets.z;
} else {
A += params->batch_stride_a * tid.z;
B += params->batch_stride_b * tid.z;
C += addmm_params->batch_stride_c * tid.z;
}
threadgroup T As[gemm_kernel::tgp_mem_size_a];
threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
D += params->batch_stride_d * tid.z;
// Adjust for batch
if (params->batch_ndim > 1) {
const constant size_t* A_bstrides = batch_strides;
const constant size_t* B_bstrides = batch_strides + params->batch_ndim;
const constant size_t* C_bstrides = B_bstrides + params->batch_ndim;
const int tid_y = ((tid.y) << params->swizzle_log) +
((tid.x) & ((1 << params->swizzle_log) - 1));
const int tid_x = (tid.x) >> params->swizzle_log;
ulong3 batch_offsets = elem_to_loc_broadcast(
tid.z,
batch_shape,
A_bstrides,
B_bstrides,
C_bstrides,
params->batch_ndim);
if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
return;
A += batch_offsets.x;
B += batch_offsets.y;
C += batch_offsets.z;
} else {
A += params->batch_stride_a * tid.z;
B += params->batch_stride_b * tid.z;
C += addmm_params->batch_stride_c * tid.z;
}
D += params->batch_stride_d * tid.z;
const int tid_y = ((tid.y) << params->swizzle_log) +
((tid.x) & ((1 << params->swizzle_log) - 1));
const int tid_x = (tid.x) >> params->swizzle_log;
if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
return;
}
threadgroup_barrier(mem_flags::mem_none);
// Find block in A, B, C
const int c_row = tid_y * BM;
const int c_col = tid_x * BN;
const size_t c_row_long = size_t(c_row);
const size_t c_col_long = size_t(c_col);
A += transpose_a ? c_row_long : c_row_long * params->lda;
B += transpose_b ? c_col_long * params->ldb : c_col_long;
D += c_row_long * params->ldd + c_col_long;
C += c_row_long * addmm_params->ldc + c_col_long * addmm_params->fdc;
// Prepare threadgroup loading operations
thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);
// Prepare threadgroup mma operation
thread mma_t mma_op(simd_group_id, simd_lane_id);
int gemm_k_iterations = params->gemm_k_iterations_aligned;
const Epilogue epilogue_op(addmm_params->alpha, addmm_params->beta);
///////////////////////////////////////////////////////////////////////////////
// MNK aligned loop
if (MN_aligned) {
for (int k = 0; k < gemm_k_iterations; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup
loader_a.load_unsafe();
loader_b.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
threadgroup_barrier(mem_flags::mem_none);
// Find block in A, B, C
const int c_row = tid_y * BM;
const int c_col = tid_x * BN;
// Loop tail
if (!K_aligned) {
int lbk = params->K - params->gemm_k_iterations_aligned * BK;
short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM);
short2 tile_dims_B = transpose_b ? short2(lbk, BN) : short2(BN, lbk);
A += transpose_a ? c_row : c_row * params->lda;
B += transpose_b ? c_col * params->ldb : c_col;
D += c_row * params->ldd + c_col;
loader_a.load_safe(tile_dims_A);
loader_b.load_safe(tile_dims_B);
C += c_row * addmm_params->ldc + c_col * addmm_params->fdc;
threadgroup_barrier(mem_flags::mem_threadgroup);
// Prepare threadgroup loading operations
thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);
mma_op.mma(As, Bs);
}
// Prepare threadgroup mma operation
thread mma_t mma_op(simd_group_id, simd_lane_id);
// Store results to device memory
mma_op.store_result(
D, params->ldd, C, addmm_params->ldc, addmm_params->fdc, epilogue_op);
return;
int gemm_k_iterations = params->gemm_k_iterations_aligned;
}
///////////////////////////////////////////////////////////////////////////////
// MN unaligned loop
else { // Loop over K - unaligned case
short tgp_bm = min(BM, params->M - c_row);
short tgp_bn = min(BN, params->N - c_col);
int leftover_bk = params->K - params->gemm_k_iterations_aligned * BK;
const Epilogue epilogue_op(addmm_params->alpha, addmm_params->beta);
if (tgp_bm == BM && tgp_bn == BN) {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk,
LoopAlignment<true, true, K_aligned>{});
///////////////////////////////////////////////////////////////////////////////
// MNK aligned loop
if (MN_aligned) {
for (int k = 0; k < gemm_k_iterations; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup
loader_a.load_unsafe();
loader_b.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
threadgroup_barrier(mem_flags::mem_none);
// Loop tail
if (!K_aligned) {
int lbk = params->K - params->gemm_k_iterations_aligned * BK;
short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM);
short2 tile_dims_B = transpose_b ? short2(lbk, BN) : short2(BN, lbk);
loader_a.load_safe(tile_dims_A);
loader_b.load_safe(tile_dims_B);
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_op.mma(As, Bs);
}
// Store results to device memory
mma_op.store_result(D, params->ldd, C, addmm_params->ldc, addmm_params->fdc, epilogue_op);
mma_op.store_result(
D, params->ldd, C, addmm_params->ldc, addmm_params->fdc, epilogue_op);
return;
} else if (tgp_bn == BN) {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk,
LoopAlignment<false, true, K_aligned>{});
return mma_op.store_result_safe(
D,
params->ldd,
C,
addmm_params->ldc,
addmm_params->fdc,
short2(tgp_bn, tgp_bm),
epilogue_op);
} else if (tgp_bm == BM) {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk,
LoopAlignment<true, false, K_aligned>{});
return mma_op.store_result_safe(
D,
params->ldd,
C,
addmm_params->ldc,
addmm_params->fdc,
short2(tgp_bn, tgp_bm),
epilogue_op);
} else {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk,
LoopAlignment<false, false, K_aligned>{});
return mma_op.store_result_safe(
D,
params->ldd,
C,
addmm_params->ldc,
addmm_params->fdc,
short2(tgp_bn, tgp_bm),
epilogue_op);
}
///////////////////////////////////////////////////////////////////////////////
// MN unaligned loop
else { // Loop over K - unaligned case
short tgp_bm = min(BM, params->M - c_row);
short tgp_bn = min(BN, params->N - c_col);
short leftover_bk = params->K - params->gemm_k_iterations_aligned * BK;
if (tgp_bm == BM && tgp_bn == BN) {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk,
LoopAlignment<true, true, K_aligned>{});
mma_op.store_result(D, params->ldd, C, addmm_params->ldc, addmm_params->fdc, epilogue_op);
return;
} else if (tgp_bn == BN) {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk,
LoopAlignment<false, true, K_aligned>{});
return mma_op.store_result_safe(
D, params->ldd,
C, addmm_params->ldc, addmm_params->fdc,
short2(tgp_bn, tgp_bm),
epilogue_op);
} else if (tgp_bm == BM) {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk,
LoopAlignment<true, false, K_aligned>{});
return mma_op.store_result_safe(
D, params->ldd,
C, addmm_params->ldc, addmm_params->fdc,
short2(tgp_bn, tgp_bm),
epilogue_op);
} else {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk,
LoopAlignment<false, false, K_aligned>{});
return mma_op.store_result_safe(
D, params->ldd,
C, addmm_params->ldc, addmm_params->fdc,
short2(tgp_bn, tgp_bm),
epilogue_op);
}
}
}
}
///////////////////////////////////////////////////////////////////////////////
// GEMM kernel initializations
///////////////////////////////////////////////////////////////////////////////
#define instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, ep_name, epilogue) \
template [[host_name("steel_addmm_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_MN_" #aname "_K_" #kname "_" #ep_name)]] \
[[kernel]] void addmm<itype, bm, bn, bk, wm, wn, trans_a, trans_b, mn_aligned, k_aligned, float, epilogue<itype, float>>( \
const device itype *A [[buffer(0)]], \
const device itype *B [[buffer(1)]], \
const device itype *C [[buffer(2)]], \
device itype *D [[buffer(3)]], \
const constant GEMMParams* gemm_params [[buffer(4)]], \
const constant GEMMAddMMParams* params [[buffer(5)]], \
const constant int* batch_shape [[buffer(6)]], \
const constant size_t* batch_strides [[buffer(7)]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
uint3 tid [[threadgroup_position_in_grid]], \
#define instantiate_gemm( \
tname, \
trans_a, \
trans_b, \
iname, \
itype, \
oname, \
otype, \
bm, \
bn, \
bk, \
wm, \
wn, \
aname, \
mn_aligned, \
kname, \
k_aligned, \
ep_name, \
epilogue) \
template [[host_name("steel_addmm_" #tname "_" #iname "_" #oname "_bm" #bm \
"_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_MN_" #aname \
"_K_" #kname "_" #ep_name)]] [[kernel]] void \
addmm< \
itype, \
bm, \
bn, \
bk, \
wm, \
wn, \
trans_a, \
trans_b, \
mn_aligned, \
k_aligned, \
float, \
epilogue<itype, float>>( \
const device itype* A [[buffer(0)]], \
const device itype* B [[buffer(1)]], \
const device itype* C [[buffer(2)]], \
device itype* D [[buffer(3)]], \
const constant GEMMParams* gemm_params [[buffer(4)]], \
const constant GEMMAddMMParams* params [[buffer(5)]], \
const constant int* batch_shape [[buffer(6)]], \
const constant size_t* batch_strides [[buffer(7)]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]]);
#define instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, add, TransformAdd) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, axpby, TransformAxpby)
// clang-format off
#define instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, add, TransformAdd) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, axpby, TransformAxpby) // clang-format on
#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \
// clang-format off
#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \
instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, naligned, false) \
instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, taligned, true) \
instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false)
instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false) // clang-format on
#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
// clang-format off
#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn)
instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn) // clang-format on
#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \
// clang-format off
#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 16, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 2, 2)
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 2, 2) // clang-format on
// clang-format off
instantiate_gemm_shapes_helper(float16, half, float16, half);
instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t);
instantiate_gemm_shapes_helper(float32, float, float32, float);
instantiate_gemm_shapes_helper(float32, float, float32, float); // clang-format on

View File

@@ -0,0 +1,168 @@
// Copyright © 2024 Apple Inc.
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
#include "mlx/backend/metal/kernels/utils.h"
using namespace metal;
using namespace mlx::steel;
///////////////////////////////////////////////////////////////////////////////
// GEMM kernels
///////////////////////////////////////////////////////////////////////////////
// Batched GEMM kernel in which the A and B matrices used for each batch
// element are selected indirectly through the index buffers `lhs_indices` and
// `rhs_indices`, instead of being addressed directly by the batch position.
//
// Template parameters:
//   T                        element type of A, B and D.
//   BM, BN, BK, WM, WN       forwarded unchanged to GEMMKernel; WM * WN * 32
//                            is also the declared cap on threads per
//                            threadgroup. NOTE(review): presumably block tile
//                            sizes and simdgroup counts — confirm in gemm.h.
//   transpose_a, transpose_b forwarded to GEMMKernel.
//   MN_aligned, K_aligned    forwarded to GEMMKernel.
//
// Buffers:
//   A, B            input matrices; advanced below to the selected operand.
//   D               output; advanced by params->batch_stride_d per tid.z.
//   params          GEMMParams; this function reads batch_ndim,
//                   batch_stride_a, batch_stride_b and batch_stride_d and
//                   passes the struct on to GEMMKernel::run.
//   batch_shape     shape used to unravel tid.z when batch_ndim > 1.
//   batch_strides   two stride arrays laid out back to back: the first
//                   batch_ndim entries are used for the lhs_indices lookup,
//                   the next batch_ndim entries for the rhs_indices lookup.
//   lhs_indices, rhs_indices
//                   per-batch indices selecting which A / B operand to use.
//   batch_shape_A, batch_strides_A, batch_shape_B, batch_strides_B
//                   shape and strides that turn a selected index into an
//                   element offset into A / B.
//   operand_batch_ndim
//                   .x is the number of batch dims of A, .y that of B.
template <
typename T,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_a,
bool transpose_b,
bool MN_aligned,
bool K_aligned>
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void bs_gemm(
const device T* A [[buffer(0)]],
const device T* B [[buffer(1)]],
device T* D [[buffer(3)]],
const constant GEMMParams* params [[buffer(4)]],
const constant int* batch_shape [[buffer(6)]],
const constant size_t* batch_strides [[buffer(7)]],
const constant uint32_t* lhs_indices [[buffer(10)]],
const constant uint32_t* rhs_indices [[buffer(11)]],
const constant int* batch_shape_A [[buffer(12)]],
const constant size_t* batch_strides_A [[buffer(13)]],
const constant int* batch_shape_B [[buffer(14)]],
const constant size_t* batch_strides_B [[buffer(15)]],
const constant int2& operand_batch_ndim [[buffer(16)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
// Input and output share the element type T (first two template arguments).
using gemm_kernel = GEMMKernel<
T,
T,
BM,
BN,
BK,
WM,
WN,
transpose_a,
transpose_b,
MN_aligned,
K_aligned>;
// Threadgroup scratch buffers for A and B, sized by the GEMMKernel
// configuration and handed to gemm_kernel::run below.
threadgroup T As[gemm_kernel::tgp_mem_size_a];
threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
// Operand indices read from lhs_indices / rhs_indices for this tid.z.
uint32_t indx_A;
uint32_t indx_B;
// Adjust for batch
if (params->batch_ndim > 1) {
// batch_strides holds the lhs strides followed by the rhs strides.
const constant size_t* A_bstrides = batch_strides;
const constant size_t* B_bstrides = batch_strides + params->batch_ndim;
// Map tid.z through batch_shape and both stride arrays in one call; .x is
// the position in lhs_indices and .y the position in rhs_indices.
ulong2 batch_offsets = elem_to_loc_broadcast(
tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim);
indx_A = lhs_indices[batch_offsets.x];
indx_B = rhs_indices[batch_offsets.y];
} else {
// At most one batch dim: the position in each index buffer is a single
// stride multiplied by tid.z.
indx_A = lhs_indices[params->batch_stride_a * tid.z];
indx_B = rhs_indices[params->batch_stride_b * tid.z];
}
int batch_ndim_A = operand_batch_ndim.x;
int batch_ndim_B = operand_batch_ndim.y;
// Convert the selected A index into an element offset. With more than one
// operand batch dim the index goes through elem_to_loc; otherwise it is
// scaled by the first (only) stride.
if (batch_ndim_A > 1) {
A += elem_to_loc(indx_A, batch_shape_A, batch_strides_A, batch_ndim_A);
} else {
A += indx_A * batch_strides_A[0];
}
// Same conversion for the selected B index.
if (batch_ndim_B > 1) {
B += elem_to_loc(indx_B, batch_shape_B, batch_strides_B, batch_ndim_B);
} else {
B += indx_B * batch_strides_B[0];
}
// The output is not indexed indirectly: it advances by a fixed stride per
// tid.z.
D += params->batch_stride_d * tid.z;
// Hand the adjusted pointers to the shared GEMM implementation.
gemm_kernel::run(
A, B, D, params, As, Bs, simd_lane_id, simd_group_id, tid, lid);
}
///////////////////////////////////////////////////////////////////////////////
// GEMM kernel initializations
///////////////////////////////////////////////////////////////////////////////
// Explicitly instantiates one bs_gemm specialization and gives it a host name.
//
// The host name is built by stringizing the macro arguments:
//   steel_block_sparse_gemm_<tname>_<iname>_<oname>_bm<bm>_bn<bn>_bk<bk>
//     _wm<wm>_wn<wn>_MN_<aname>_K_<kname>
//
//   tname               transposition tag used in the name.
//   trans_a, trans_b    values for the transpose_a / transpose_b parameters.
//   iname, itype        input type tag for the name and the element type T.
//   oname, otype        output type tag and type. Only oname is used (in the
//                       name); otype is accepted but not referenced in the
//                       expansion, and D is declared with itype.
//   bm, bn, bk, wm, wn  values for BM, BN, BK, WM, WN; also part of the name.
//   aname, mn_aligned   name tag and value for MN_aligned.
//   kname, k_aligned    name tag and value for K_aligned.
//
// The parameter list below repeats the bs_gemm signature as part of the
// explicit instantiation.
#define instantiate_gemm( \
tname, \
trans_a, \
trans_b, \
iname, \
itype, \
oname, \
otype, \
bm, \
bn, \
bk, \
wm, \
wn, \
aname, \
mn_aligned, \
kname, \
k_aligned) \
template [[host_name("steel_block_sparse_gemm_" #tname "_" #iname "_" #oname \
"_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn \
"_MN_" #aname "_K_" #kname)]] [[kernel]] void \
bs_gemm<itype, bm, bn, bk, wm, wn, trans_a, trans_b, mn_aligned, k_aligned>( \
const device itype* A [[buffer(0)]], \
const device itype* B [[buffer(1)]], \
device itype* D [[buffer(3)]], \
const constant GEMMParams* params [[buffer(4)]], \
const constant int* batch_shape [[buffer(6)]], \
const constant size_t* batch_strides [[buffer(7)]], \
const constant uint32_t* lhs_indices [[buffer(10)]], \
const constant uint32_t* rhs_indices [[buffer(11)]], \
const constant int* batch_shape_A [[buffer(12)]], \
const constant size_t* batch_strides_A [[buffer(13)]], \
const constant int* batch_shape_B [[buffer(14)]], \
const constant size_t* batch_strides_B [[buffer(15)]], \
const constant int2& operand_batch_ndim [[buffer(16)]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]]);
// Expands to four instantiate_gemm calls covering every combination of the
// MN_aligned and K_aligned flags for one transposition / type / tile
// configuration. The tag `taligned` is paired with `true` and `naligned` with
// `false`; the tags end up in the kernel host name.
// clang-format off
#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, naligned, false) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, taligned, true) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false) // clang-format on
// Expands instantiate_gemm_aligned_helper once per transposition combination
// of A and B. The two-letter tag gives A then B, with `n` paired with `false`
// and `t` with `true` for the corresponding trans_a / trans_b argument:
// nn, nt, tn, tt.
// clang-format off
#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn) // clang-format on
// Expands instantiate_gemm_transpose_helper for each supported
// (bm, bn, bk, wm, wn) configuration of one input/output type pair. Five
// configurations are listed, all with wm = wn = 2.
// clang-format off
#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 16, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 2, 2) // clang-format on
// Emit the bs_gemm kernels for each supported element type: half, bfloat16_t
// and float. Each line expands to 5 tile configurations x 4 transpositions x
// 4 alignment combinations = 80 explicit instantiations.
// clang-format off
instantiate_gemm_shapes_helper(float16, half, float16, half);
instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t);
instantiate_gemm_shapes_helper(float32, float, float32, float); // clang-format on

View File

@@ -1,8 +1,8 @@
// Copyright © 2024 Apple Inc.
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
#include "mlx/backend/metal/kernels/utils.h"
using namespace metal;
using namespace mlx::steel;
@@ -11,319 +11,380 @@ using namespace mlx::steel;
// GEMM kernels
///////////////////////////////////////////////////////////////////////////////
template <typename T,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_a,
bool transpose_b,
bool MN_aligned,
bool K_aligned,
bool has_operand_mask=false>
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void block_masked_gemm(
const device T *A [[buffer(0)]],
const device T *B [[buffer(1)]],
device T *D [[buffer(3)]],
template <
typename T,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_a,
bool transpose_b,
bool MN_aligned,
bool K_aligned,
bool has_operand_mask = false>
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void
block_masked_gemm(
const device T* A [[buffer(0)]],
const device T* B [[buffer(1)]],
device T* D [[buffer(3)]],
const constant GEMMParams* params [[buffer(4)]],
const constant int* batch_shape [[buffer(6)]],
const constant size_t* batch_strides [[buffer(7)]],
const device bool *out_mask [[buffer(10)]],
const device bool *lhs_mask [[buffer(11)]],
const device bool *rhs_mask [[buffer(12)]],
const device bool* out_mask [[buffer(10)]],
const device bool* lhs_mask [[buffer(11)]],
const device bool* rhs_mask [[buffer(12)]],
const constant int* mask_strides [[buffer(13)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
uint3 lid [[thread_position_in_threadgroup]]) {
// Appease the compiler
(void)lid;
// Appease the compiler
(void)lid;
using gemm_kernel = GEMMKernel<T, T, BM, BN, BK, WM, WN, transpose_a, transpose_b, MN_aligned, K_aligned>;
const int tid_y = ((tid.y) << params->swizzle_log) +
((tid.x) & ((1 << params->swizzle_log) - 1));
const int tid_x = (tid.x) >> params->swizzle_log;
using gemm_kernel = GEMMKernel<
T,
T,
BM,
BN,
BK,
WM,
WN,
transpose_a,
transpose_b,
MN_aligned,
K_aligned>;
if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
return;
}
const int tid_y = ((tid.y) << params->swizzle_log) +
((tid.x) & ((1 << params->swizzle_log) - 1));
const int tid_x = (tid.x) >> params->swizzle_log;
if(params->batch_ndim > 1) {
const constant size_t* mask_batch_strides = batch_strides + 2 * params->batch_ndim;
out_mask += elem_to_loc(tid.z, batch_shape, mask_batch_strides, params->batch_ndim);
if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
return;
}
if(has_operand_mask) {
const constant size_t* mask_strides_lhs = mask_batch_strides + params->batch_ndim;
const constant size_t* mask_strides_rhs = mask_strides_lhs + params->batch_ndim;
if (params->batch_ndim > 1) {
const constant size_t* mask_batch_strides =
batch_strides + 2 * params->batch_ndim;
out_mask +=
elem_to_loc(tid.z, batch_shape, mask_batch_strides, params->batch_ndim);
ulong2 batch_offsets = elem_to_loc_broadcast(
tid.z, batch_shape, mask_strides_lhs, mask_strides_rhs, params->batch_ndim);
lhs_mask += batch_offsets.x;
rhs_mask += batch_offsets.y;
}
} else {
out_mask += tid.z * batch_strides[2 * params->batch_ndim];
if(has_operand_mask) {
lhs_mask += tid.z * batch_strides[3 * params->batch_ndim];
rhs_mask += tid.z * batch_strides[4 * params->batch_ndim];
}
}
// Adjust for batch
if(params->batch_ndim > 1) {
const constant size_t* A_bstrides = batch_strides;
const constant size_t* B_bstrides = batch_strides + params->batch_ndim;
if (has_operand_mask) {
const constant size_t* mask_strides_lhs =
mask_batch_strides + params->batch_ndim;
const constant size_t* mask_strides_rhs =
mask_strides_lhs + params->batch_ndim;
ulong2 batch_offsets = elem_to_loc_broadcast(
tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim);
tid.z,
batch_shape,
mask_strides_lhs,
mask_strides_rhs,
params->batch_ndim);
A += batch_offsets.x;
B += batch_offsets.y;
} else {
A += params->batch_stride_a * tid.z;
B += params->batch_stride_b * tid.z;
lhs_mask += batch_offsets.x;
rhs_mask += batch_offsets.y;
}
D += params->batch_stride_d * tid.z;
} else {
out_mask += tid.z * batch_strides[2 * params->batch_ndim];
if (has_operand_mask) {
lhs_mask += tid.z * batch_strides[3 * params->batch_ndim];
rhs_mask += tid.z * batch_strides[4 * params->batch_ndim];
}
}
// Find block in A, B, C
const int c_row = tid_y * BM;
const int c_col = tid_x * BN;
// Adjust for batch
if (params->batch_ndim > 1) {
const constant size_t* A_bstrides = batch_strides;
const constant size_t* B_bstrides = batch_strides + params->batch_ndim;
A += transpose_a ? c_row : c_row * params->lda;
B += transpose_b ? c_col * params->ldb : c_col;
D += c_row * params->ldd + c_col;
ulong2 batch_offsets = elem_to_loc_broadcast(
tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim);
A += batch_offsets.x;
B += batch_offsets.y;
bool mask_out = out_mask[tid_y * mask_strides[1] + tid_x * mask_strides[0]];
} else {
A += params->batch_stride_a * tid.z;
B += params->batch_stride_b * tid.z;
}
// Write zeros and return
if(!mask_out) {
constexpr short tgp_size = WM * WN * 32;
constexpr short vec_size = 4;
D += params->batch_stride_d * tid.z;
// Tile threads in threadgroup
constexpr short TN = BN / vec_size;
constexpr short TM = tgp_size / TN;
// Find block in A, B, C
const int c_row = tid_y * BM;
const int c_col = tid_x * BN;
const size_t c_row_long = size_t(c_row);
const size_t c_col_long = size_t(c_col);
const short thread_idx = simd_group_id * 32 + simd_lane_id;
const short bi = thread_idx / TN;
const short bj = vec_size * (thread_idx % TN);
A += transpose_a ? c_row_long : c_row_long * params->lda;
B += transpose_b ? c_col_long * params->ldb : c_col_long;
D += c_row_long * params->ldd + c_col_long;
D += bi * params->ldd + bj;
bool mask_out = out_mask[tid_y * mask_strides[1] + tid_x * mask_strides[0]];
short tgp_bm = min(BM, params->M - c_row);
short tgp_bn = min(BN, params->N - c_col);
// Write zeros and return
if (!mask_out) {
constexpr short tgp_size = WM * WN * 32;
constexpr short vec_size = 4;
if (MN_aligned || (tgp_bm == BM && tgp_bn == BN)) {
for (short ti = 0; ti < BM; ti += TM) {
STEEL_PRAGMA_UNROLL
for(short j = 0; j < vec_size; j++) {
D[ti * params->ldd + j] = T(0.);
}
}
} else {
short jmax = tgp_bn - bj;
jmax = jmax < vec_size ? jmax : vec_size;
for (short ti = 0; (bi + ti) < tgp_bm; ti += TM) {
for(short j = 0; j < jmax; j++) {
D[ti * params->ldd + j] = T(0.);
}
// Tile threads in threadgroup
constexpr short TN = BN / vec_size;
constexpr short TM = tgp_size / TN;
const short thread_idx = simd_group_id * 32 + simd_lane_id;
const short bi = thread_idx / TN;
const short bj = vec_size * (thread_idx % TN);
D += bi * params->ldd + bj;
short tgp_bm = min(BM, params->M - c_row);
short tgp_bn = min(BN, params->N - c_col);
if (MN_aligned || (tgp_bm == BM && tgp_bn == BN)) {
for (short ti = 0; ti < BM; ti += TM) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
D[ti * params->ldd + j] = T(0.);
}
}
return;
} else {
short jmax = tgp_bn - bj;
jmax = jmax < vec_size ? jmax : vec_size;
for (short ti = 0; (bi + ti) < tgp_bm; ti += TM) {
for (short j = 0; j < jmax; j++) {
D[ti * params->ldd + j] = T(0.);
}
}
}
return;
}
threadgroup_barrier(mem_flags::mem_none);
// Prepare threadgroup mma operation
thread typename gemm_kernel::mma_t mma_op(simd_group_id, simd_lane_id);
int gemm_k_iterations = params->gemm_k_iterations_aligned;
threadgroup T As[gemm_kernel::tgp_mem_size_a];
threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
// Prepare threadgroup loading operations
thread typename gemm_kernel::loader_a_t loader_a(
A, params->lda, As, simd_group_id, simd_lane_id);
thread typename gemm_kernel::loader_b_t loader_b(
B, params->ldb, Bs, simd_group_id, simd_lane_id);
///////////////////////////////////////////////////////////////////////////////
// MNK aligned loop
if (MN_aligned) {
for (int k = 0; k < gemm_k_iterations; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
if (!has_operand_mask ||
(lhs_mask
[tid_y * mask_strides[3] + ((k * BK) / BM) * mask_strides[2]] &&
rhs_mask
[((k * BK) / BM) * mask_strides[5] + tid_x * mask_strides[4]])) {
// Load elements into threadgroup
loader_a.load_unsafe();
loader_b.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
}
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
threadgroup_barrier(mem_flags::mem_none);
// Prepare threadgroup mma operation
thread typename gemm_kernel::mma_t mma_op(simd_group_id, simd_lane_id);
// Loop tail
if (!K_aligned) {
if (!has_operand_mask ||
(lhs_mask
[tid_y * mask_strides[3] + (params->K / BM) * mask_strides[2]] &&
rhs_mask
[(params->K / BM) * mask_strides[5] +
tid_x * mask_strides[4]])) {
int lbk = params->K - params->gemm_k_iterations_aligned * BK;
short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM);
short2 tile_dims_B = transpose_b ? short2(lbk, BN) : short2(BN, lbk);
int gemm_k_iterations = params->gemm_k_iterations_aligned;
loader_a.load_safe(tile_dims_A);
loader_b.load_safe(tile_dims_B);
threadgroup T As[gemm_kernel::tgp_mem_size_a];
threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
// Prepare threadgroup loading operations
thread typename gemm_kernel::loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
thread typename gemm_kernel::loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);
///////////////////////////////////////////////////////////////////////////////
// MNK aligned loop
if (MN_aligned) {
for (int k = 0; k < gemm_k_iterations; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
if(!has_operand_mask ||
(lhs_mask[tid_y * mask_strides[3] + ((k * BK) / BM) * mask_strides[2]] &&
rhs_mask[((k * BK) / BM) * mask_strides[5] + tid_x * mask_strides[4]])) {
mma_op.mma(As, Bs);
}
}
// Load elements into threadgroup
// Store results to device memory
mma_op.store_result(D, params->ldd);
return;
}
///////////////////////////////////////////////////////////////////////////////
// MN unaligned loop
else { // Loop over K - unaligned case
short tgp_bm = min(BM, params->M - c_row);
short tgp_bn = min(BN, params->N - c_col);
short lbk = params->K - params->gemm_k_iterations_aligned * BK;
bool M_aligned = (tgp_bm == BM);
bool N_aligned = (tgp_bn == BN);
short2 tile_dims_A = transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm);
short2 tile_dims_B = transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK);
for (int k = 0; k < gemm_k_iterations; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
if (!has_operand_mask ||
(lhs_mask
[tid_y * mask_strides[3] + ((k * BK) / BM) * mask_strides[2]] &&
rhs_mask
[((k * BK) / BM) * mask_strides[5] + tid_x * mask_strides[4]])) {
// Load elements into threadgroup
if (M_aligned) {
loader_a.load_unsafe();
loader_b.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
}
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
threadgroup_barrier(mem_flags::mem_none);
// Loop tail
if (!K_aligned) {
if(!has_operand_mask ||
(lhs_mask[tid_y * mask_strides[3] + (params->K / BM) * mask_strides[2]] &&
rhs_mask[(params->K / BM) * mask_strides[5] + tid_x * mask_strides[4]])) {
int lbk = params->K - params->gemm_k_iterations_aligned * BK;
short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM);
short2 tile_dims_B = transpose_b ? short2(lbk, BN) : short2(BN, lbk);
} else {
loader_a.load_safe(tile_dims_A);
}
if (N_aligned) {
loader_b.load_unsafe();
} else {
loader_b.load_safe(tile_dims_B);
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_op.mma(As, Bs);
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
}
// Store results to device memory
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
if (!K_aligned) {
threadgroup_barrier(mem_flags::mem_threadgroup);
if (!has_operand_mask ||
(lhs_mask
[tid_y * mask_strides[3] + (params->K / BM) * mask_strides[2]] &&
rhs_mask
[(params->K / BM) * mask_strides[5] +
tid_x * mask_strides[4]])) {
short2 tile_dims_A_last =
transpose_a ? short2(tgp_bm, lbk) : short2(lbk, tgp_bm);
short2 tile_dims_B_last =
transpose_b ? short2(lbk, tgp_bn) : short2(tgp_bn, lbk);
loader_a.load_safe(tile_dims_A_last);
loader_b.load_safe(tile_dims_B_last);
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_op.mma(As, Bs);
}
}
if (M_aligned && N_aligned) {
mma_op.store_result(D, params->ldd);
return;
}
///////////////////////////////////////////////////////////////////////////////
// MN unaligned loop
else { // Loop over K - unaligned case
short tgp_bm = min(BM, params->M - c_row);
short tgp_bn = min(BN, params->N - c_col);
short lbk = params->K - params->gemm_k_iterations_aligned * BK;
bool M_aligned = (tgp_bm == BM);
bool N_aligned = (tgp_bn == BN);
short2 tile_dims_A = transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm);
short2 tile_dims_B = transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK);
for (int k = 0; k < gemm_k_iterations; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
if(!has_operand_mask ||
(lhs_mask[tid_y * mask_strides[3] + ((k * BK) / BM) * mask_strides[2]] &&
rhs_mask[((k * BK) / BM) * mask_strides[5] + tid_x * mask_strides[4]])) {
// Load elements into threadgroup
if (M_aligned) {
loader_a.load_unsafe();
} else {
loader_a.load_safe(tile_dims_A);
}
if (N_aligned) {
loader_b.load_unsafe();
} else {
loader_b.load_safe(tile_dims_B);
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
}
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
if (!K_aligned) {
threadgroup_barrier(mem_flags::mem_threadgroup);
if(!has_operand_mask ||
(lhs_mask[tid_y * mask_strides[3] + (params->K / BM) * mask_strides[2]] &&
rhs_mask[(params->K / BM) * mask_strides[5] + tid_x * mask_strides[4]])) {
short2 tile_dims_A_last =
transpose_a ? short2(tgp_bm, lbk) : short2(lbk, tgp_bm);
short2 tile_dims_B_last =
transpose_b ? short2(lbk, tgp_bn) : short2(tgp_bn, lbk);
loader_a.load_safe(tile_dims_A_last);
loader_b.load_safe(tile_dims_B_last);
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_op.mma(As, Bs);
}
}
if(M_aligned && N_aligned) {
mma_op.store_result(D, params->ldd);
} else {
mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
}
} else {
mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
}
}
}
///////////////////////////////////////////////////////////////////////////////
// GEMM kernel initializations
///////////////////////////////////////////////////////////////////////////////
// Explicit-instantiation macro for the block_masked_gemm kernel template.
//
// One expansion emits a single kernel instantiation whose host-visible name is
//   steel_block_masked_gemm_<tname>_<iname>_<oname>_bm<bm>_bn<bn>_bk<bk>
//   _wm<wm>_wn<wn>_MN_<aname>_K_<kname>_op_mask_<omname>
// and whose template arguments are
//   <itype, bm, bn, bk, wm, wn, trans_a, trans_b, mn_aligned, k_aligned, op_mask>.
// The name-like parameters (tname, iname, oname, aname, kname, omname) are only
// stringized into the host name. `otype` is accepted but never referenced in
// the expansion: the output pointer D is declared with `itype`.
// Buffer slots used by the signature: A=0, B=1, D=3, params=4, batch_shape=6,
// batch_strides=7, out_mask=10, lhs_mask=11, rhs_mask=12, mask_strides=13.
//
// NOTE(review): the macro text appears twice below - first with the parameter
// list on one line, then with one parameter per line. This looks like the
// removed and added sides of a formatting-only diff rendered without +/-
// markers; confirm against the real source file before treating the first
// copy as live code (as written, its trailing backslash continues into the
// second #define).
#define instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, omname, op_mask) \
template [[host_name("steel_block_masked_gemm_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_MN_" #aname "_K_" #kname "_op_mask_" #omname)]] \
[[kernel]] void block_masked_gemm<itype, bm, bn, bk, wm, wn, trans_a, trans_b, mn_aligned, k_aligned, op_mask>( \
const device itype *A [[buffer(0)]], \
const device itype *B [[buffer(1)]], \
device itype *D [[buffer(3)]], \
const constant GEMMParams* params [[buffer(4)]], \
const constant int* batch_shape [[buffer(6)]], \
const constant size_t* batch_strides [[buffer(7)]], \
const device bool *out_mask [[buffer(10)]], \
const device bool *lhs_mask [[buffer(11)]], \
const device bool *rhs_mask [[buffer(12)]], \
const constant int* mask_strides [[buffer(13)]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
uint3 tid [[threadgroup_position_in_grid]], \
#define instantiate_gemm( \
tname, \
trans_a, \
trans_b, \
iname, \
itype, \
oname, \
otype, \
bm, \
bn, \
bk, \
wm, \
wn, \
aname, \
mn_aligned, \
kname, \
k_aligned, \
omname, \
op_mask) \
template [[host_name("steel_block_masked_gemm_" #tname "_" #iname "_" #oname \
"_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn \
"_MN_" #aname "_K_" #kname \
"_op_mask_" #omname)]] [[kernel]] void \
block_masked_gemm< \
itype, \
bm, \
bn, \
bk, \
wm, \
wn, \
trans_a, \
trans_b, \
mn_aligned, \
k_aligned, \
op_mask>( \
const device itype* A [[buffer(0)]], \
const device itype* B [[buffer(1)]], \
device itype* D [[buffer(3)]], \
const constant GEMMParams* params [[buffer(4)]], \
const constant int* batch_shape [[buffer(6)]], \
const constant size_t* batch_strides [[buffer(7)]], \
const device bool* out_mask [[buffer(10)]], \
const device bool* lhs_mask [[buffer(11)]], \
const device bool* rhs_mask [[buffer(12)]], \
const constant int* mask_strides [[buffer(13)]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]]);
// Helper macros that fan a single type pair out into every block_masked_gemm
// variant, each layer multiplying the number of instantiations:
//   instantiate_gemm_mask_helper      - op_mask off (name tag N, false) and on
//                                       (name tag T, true)
//   instantiate_gemm_aligned_helper   - the four MN/K alignment combinations
//                                       (taligned/true, naligned/false)
//   instantiate_gemm_transpose_helper - the four transpose layouts nn, nt, tn, tt
//   instantiate_gemm_shapes_helper    - tile configs (bm, bn, bk, wm, wn) =
//                                       (32, 32, 16, 2, 2) and (64, 64, 16, 2, 2)
// The trailing invocations instantiate the set for half, bfloat16_t and float,
// with the same type used for input and output.
//
// NOTE(review): several lines occur twice (once bare, once with a
// "// clang-format on" suffix) and each #define header is repeated after a
// "// clang-format off" line. This looks like the old and new sides of a diff
// that only added clang-format guards, rendered without +/- markers; confirm
// against the real source file.
// clang-format off
#define instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, N, false) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, T, true)
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, N, false) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, T, true) // clang-format on
#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \
// clang-format off
#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \
instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, naligned, false) \
instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, taligned, true) \
instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false)
instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false) // clang-format on
#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
// clang-format off
#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn)
instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn) // clang-format on
#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \
// clang-format off
#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2)
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) // clang-format on
// clang-format off
instantiate_gemm_shapes_helper(float16, half, float16, half);
instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t);
instantiate_gemm_shapes_helper(float32, float, float32, float);
instantiate_gemm_shapes_helper(float32, float, float32, float); // clang-format on

View File

@@ -10,81 +10,99 @@ using namespace mlx::steel;
// GEMM kernels
///////////////////////////////////////////////////////////////////////////////
// Split-K GEMM kernel.
//
// Each threadgroup handles one BM x BN output tile, selected by (tid.x, tid.y),
// for one K partition selected by tid.z. The partition starts at
// k_start = split_k_partition_size * tid.z, and its partial result is written
// to C at an extra offset of split_k_partition_stride * tid.z, so partitions do
// not overwrite each other.
//
// Template parameters:
//   T / U               element types of the inputs (A, B) and of C
//   BM, BN, BK          tile sizes used for the c_row / c_col / K stepping
//   WM, WN              used for the thread limit WM * WN * 32
//   transpose_a/_b      select which operand index is scaled by lda / ldb
//   MN_aligned          when true, tiles are always treated as full
//   K_aligned           forwarded to the last partition's extra gemm_loop
//
// NOTE(review): this block interleaves two renderings of the same function
// (duplicate template headers, parameter lists and statements, differing in
// formatting and in the size_t widening of offsets), and contains "@@" hunk
// headers where unchanged lines were elided. It looks like a diff shown
// without +/- markers, not compilable source; confirm against the real file.
template <typename T,
typename U,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_a,
bool transpose_b,
bool MN_aligned,
bool K_aligned>
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void gemm_splitk(
const device T *A [[buffer(0)]],
const device T *B [[buffer(1)]],
device U *C [[buffer(2)]],
template <
typename T,
typename U,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_a,
bool transpose_b,
bool MN_aligned,
bool K_aligned>
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void gemm_splitk(
const device T* A [[buffer(0)]],
const device T* B [[buffer(1)]],
device U* C [[buffer(2)]],
const constant GEMMSpiltKParams* params [[buffer(3)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
uint3 lid [[thread_position_in_threadgroup]]) {
// lid is not referenced anywhere else in the visible body; the cast marks it
// as intentionally unused.
(void)lid;
(void)lid;
using gemm_kernel = GEMMKernel<T, U, BM, BN, BK, WM, WN, transpose_a, transpose_b, MN_aligned, K_aligned>;
using loader_a_t = typename gemm_kernel::loader_a_t;
using loader_b_t = typename gemm_kernel::loader_b_t;
using mma_t = typename gemm_kernel::mma_t;
// Threadgroup staging buffers for the A and B tiles, sized by gemm_kernel.
threadgroup T As[gemm_kernel::tgp_mem_size_a];
threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
using gemm_kernel = GEMMKernel<
T,
U,
BM,
BN,
BK,
WM,
WN,
transpose_a,
transpose_b,
MN_aligned,
K_aligned>;
using loader_a_t = typename gemm_kernel::loader_a_t;
using loader_b_t = typename gemm_kernel::loader_b_t;
using mma_t = typename gemm_kernel::mma_t;
const int tid_x = tid.x;
const int tid_y = tid.y;
const int tid_z = tid.z;
threadgroup T As[gemm_kernel::tgp_mem_size_a];
threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
// Threadgroups that fall outside the tiles_n x tiles_m grid exit immediately.
if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
return;
}
const int tid_x = tid.x;
const int tid_y = tid.y;
const int tid_z = tid.z;
// Find block in A, B, C
const int c_row = tid_y * BM;
const int c_col = tid_x * BN;
const int k_start = params->split_k_partition_size * tid_z;
if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
return;
}
// Pointer offsets computed directly from the int tile/partition indices.
A += transpose_a ? (c_row + k_start * params->lda) : (k_start + c_row * params->lda);
B += transpose_b ? (k_start + c_col * params->ldb) : (c_col + k_start * params->ldb);
C += (params->split_k_partition_stride * tid_z) + (c_row * params->ldc + c_col);
// Find block in A, B, C
const int c_row = tid_y * BM;
const int c_col = tid_x * BN;
const int k_start = params->split_k_partition_size * tid_z;
// Prepare threadgroup loading operations
thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);
// size_t copies of the indices, so the offset products below are evaluated in
// size_t rather than int.
const size_t c_row_long = size_t(c_row);
const size_t c_col_long = size_t(c_col);
const size_t k_start_long = size_t(k_start);
// Prepare threadgroup mma operation
thread mma_t mma_op(simd_group_id, simd_lane_id);
A += transpose_a ? (c_row_long + k_start_long * params->lda)
: (k_start_long + c_row_long * params->lda);
B += transpose_b ? (k_start_long + c_col_long * params->ldb)
: (c_col_long + k_start_long * params->ldb);
C += (size_t(params->split_k_partition_stride) * tid_z) +
(c_row_long * params->ldc + c_col_long);
int gemm_k_iterations = params->gemm_k_iterations_aligned;
// Prepare threadgroup loading operations
thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);
// Extent of this tile that lies inside the M x N output, and the K remainder
// that does not fill a whole BK block.
short tgp_bm = min(BM, params->M - c_row);
short tgp_bn = min(BN, params->N - c_col);
short leftover_bk = params->K % BK;
// Prepare threadgroup mma operation
thread mma_t mma_op(simd_group_id, simd_lane_id);
// Choose the gemm_loop specialization from which tile edges are full:
// LoopAlignment<true, true, true> when both are, <false, true, true> when only
// tgp_bn == BN, <true, false, true> when only tgp_bm == BM, else
// <false, false, true>.
if(MN_aligned || (tgp_bm == BM && tgp_bn == BN)) {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk,
LoopAlignment<true, true, true>{});
} else if (tgp_bn == BN) {
gemm_kernel::gemm_loop(
int gemm_k_iterations = params->gemm_k_iterations_aligned;
short tgp_bm = min(BM, params->M - c_row);
short tgp_bn = min(BN, params->N - c_col);
short leftover_bk = params->K % BK;
if (MN_aligned || (tgp_bm == BM && tgp_bn == BN)) {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk,
LoopAlignment<true, true, true>{});
} else if (tgp_bn == BN) {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
@@ -95,37 +113,38 @@ template <typename T,
tgp_bn,
leftover_bk,
LoopAlignment<false, true, true>{});
} else if (tgp_bm == BM) {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk,
LoopAlignment<true, false, true>{});
} else {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk,
LoopAlignment<false, false, true>{});
}
} else if (tgp_bm == BM) {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk,
LoopAlignment<true, false, true>{});
} else {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk,
LoopAlignment<false, false, true>{});
}
threadgroup_barrier(mem_flags::mem_threadgroup);
threadgroup_barrier(mem_flags::mem_threadgroup);
// Only the last K partition runs one more gemm_loop, and only when K is not
// aligned or whole BK iterations remain past the end of its partition. That
// call passes K_aligned as the third LoopAlignment argument.
if ((tid_z + 1) == (params->split_k_partitions)) {
int gemm_k_iter_remaining = (params->K - (k_start + params->split_k_partition_size)) / BK;
if(!K_aligned || gemm_k_iter_remaining > 0)
if ((tid_z + 1) == (params->split_k_partitions)) {
int gemm_k_iter_remaining =
(params->K - (k_start + params->split_k_partition_size)) / BK;
if (!K_aligned || gemm_k_iter_remaining > 0)
gemm_kernel::gemm_loop(
As,
Bs,
@@ -137,144 +156,178 @@ template <typename T,
tgp_bn,
leftover_bk,
LoopAlignment<false, false, K_aligned>{});
}
}
// Full tiles use store_result; partial tiles use store_result_safe with the
// (tgp_bn, tgp_bm) extent.
if(MN_aligned || (tgp_bm == BM && tgp_bn == BN)) {
mma_op.store_result(C, params->ldc);
} else {
mma_op.store_result_safe(C, params->ldc, short2(tgp_bn, tgp_bm));
}
if (MN_aligned || (tgp_bm == BM && tgp_bn == BN)) {
mma_op.store_result(C, params->ldc);
} else {
mma_op.store_result_safe(C, params->ldc, short2(tgp_bn, tgp_bm));
}
}
///////////////////////////////////////////////////////////////////////////////
// GEMM kernel initializations
///////////////////////////////////////////////////////////////////////////////
// Explicit-instantiation macro for the gemm_splitk kernel template.
//
// One expansion emits a single kernel instantiation whose host-visible name is
//   steel_gemm_splitk_<tname>_<iname>_<oname>_bm<bm>_bn<bn>_bk<bk>
//   _wm<wm>_wn<wn>_MN_<aname>_K_<kname>
// and whose template arguments are
//   <itype, otype, bm, bn, bk, wm, wn, trans_a, trans_b, mn_aligned, k_aligned>.
// Inputs A and B use `itype`, the output C uses `otype`. The name-like
// parameters (tname, iname, oname, aname, kname) are only stringized into the
// host name. Buffer slots: A=0, B=1, C=2, params=3.
//
// NOTE(review): the macro text appears twice below - first with the parameter
// list on one line, then with one parameter per line. This looks like the
// removed and added sides of a formatting-only diff rendered without +/-
// markers; confirm against the real source file (as written, the first copy's
// trailing backslash continues into the second #define).
#define instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
template [[host_name("steel_gemm_splitk_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_MN_" #aname "_K_" #kname)]] \
[[kernel]] void gemm_splitk<itype, otype, bm, bn, bk, wm, wn, trans_a, trans_b, mn_aligned, k_aligned>( \
const device itype *A [[buffer(0)]], \
const device itype *B [[buffer(1)]], \
device otype *C [[buffer(2)]], \
const constant GEMMSpiltKParams* params [[buffer(3)]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
uint3 tid [[threadgroup_position_in_grid]], \
#define instantiate_gemm( \
tname, \
trans_a, \
trans_b, \
iname, \
itype, \
oname, \
otype, \
bm, \
bn, \
bk, \
wm, \
wn, \
aname, \
mn_aligned, \
kname, \
k_aligned) \
template [[host_name("steel_gemm_splitk_" #tname "_" #iname "_" #oname \
"_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn \
"_MN_" #aname "_K_" #kname)]] [[kernel]] void \
gemm_splitk< \
itype, \
otype, \
bm, \
bn, \
bk, \
wm, \
wn, \
trans_a, \
trans_b, \
mn_aligned, \
k_aligned>( \
const device itype* A [[buffer(0)]], \
const device itype* B [[buffer(1)]], \
device otype* C [[buffer(2)]], \
const constant GEMMSpiltKParams* params [[buffer(3)]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]]);
// Helper macros that fan a single type pair out into every gemm_splitk variant:
//   instantiate_gemm_aligned_helper   - the four MN/K alignment combinations
//                                       (taligned/true, naligned/false)
//   instantiate_gemm_transpose_helper - the four transpose layouts nn, nt, tn, tt
//   instantiate_gemm_shapes_helper    - tile configs (bm, bn, bk, wm, wn) =
//                                       (16, 16, 16, 2, 2), (16, 32, 16, 2, 2),
//                                       (32, 16, 16, 2, 2), (32, 32, 16, 2, 2)
// The trailing invocations instantiate the set for half, bfloat16_t and float
// inputs, always with float output.
//
// NOTE(review): several lines occur twice (once bare, once with a
// "// clang-format on" suffix) and each #define header is repeated after a
// "// clang-format off" line. This looks like the old and new sides of a diff
// that only added clang-format guards, rendered without +/- markers; confirm
// against the real source file.
#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \
// clang-format off
#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, naligned, false) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, taligned, true) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false)
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false) // clang-format on
#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
// clang-format off
#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn)
instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn) // clang-format on
#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \
// clang-format off
#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 16, 16, 16, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 16, 32, 16, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 16, 16, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2)
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) // clang-format on
// clang-format off
instantiate_gemm_shapes_helper(float16, half, float32, float);
instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, float32, float);
instantiate_gemm_shapes_helper(float32, float, float32, float);
instantiate_gemm_shapes_helper(float32, float, float32, float); // clang-format on
///////////////////////////////////////////////////////////////////////////////
// Split k accumulation kernel
// Split k accumulation kernel
///////////////////////////////////////////////////////////////////////////////
// Split-K reduction kernel: one thread per output element.
//
// The thread at grid position gid sums k_partitions values read from C_split,
// spaced partition_stride elements apart, and writes Epilogue::apply(sum) to D.
// Both pointers are first advanced by gid.x + gid.y * ldd.
//
// Template parameters:
//   AccT     element type of C_split and of the running sum
//   OutT     element type of D
//   Epilogue provides a static apply(AccT); defaults to TransformNone<OutT, AccT>
//
// NOTE(review): the template header, the pointer parameters, the pointer
// adjustments, the `offset` declaration and the loop header each appear twice
// (an int-arithmetic form followed by a size_t-widened / reformatted form).
// This looks like the removed and added sides of a diff rendered without +/-
// markers; confirm against the real source file.
template <typename AccT,
typename OutT,
typename Epilogue = TransformNone<OutT, AccT>>
template <
typename AccT,
typename OutT,
typename Epilogue = TransformNone<OutT, AccT>>
[[kernel]] void gemm_splitk_accum(
const device AccT *C_split [[buffer(0)]],
device OutT *D [[buffer(1)]],
const device AccT* C_split [[buffer(0)]],
device OutT* D [[buffer(1)]],
const constant int& k_partitions [[buffer(2)]],
const constant int& partition_stride [[buffer(3)]],
const constant int& ldd [[buffer(4)]],
uint2 gid [[thread_position_in_grid]]) {
// Adjust D and C_split to this thread's element
D += gid.x + gid.y * ldd;
C_split += gid.x + gid.y * ldd;
// Same adjustment with ldd widened to size_t before the multiply
D += gid.x + gid.y * size_t(ldd);
C_split += gid.x + gid.y * size_t(ldd);
int offset = 0;
size_t offset = 0;
AccT out = 0;
// Accumulate this element's partial result from every K partition
for(int i = 0; i < k_partitions; i++) {
for (int i = 0; i < k_partitions; i++) {
out += C_split[offset];
offset += partition_stride;
}
// Write output
// Write output
D[0] = Epilogue::apply(out);
}
// Split-K reduction kernel with an epilogue that also reads C: one thread per
// output element.
//
// The thread at grid position gid sums k_partitions values read from C_split,
// spaced partition_stride elements apart, constructs Epilogue op(alpha, beta)
// and writes op.apply(sum, *C) to D. D and C_split are advanced by
// gid.x + gid.y * ldd; C is advanced by gid.x * fdc + gid.y * ldc.
//
// Template parameters:
//   AccT     element type of C_split and of the running sum
//   OutT     element type of D and C
//   Epilogue constructed from (alpha, beta) and exposes apply(AccT, OutT);
//            defaults to TransformAxpby<OutT, AccT>
//
// NOTE(review): the template header, the pointer parameters, the pointer
// adjustments, the `offset` declaration and the loop header each appear twice
// (an int-arithmetic form followed by a size_t-widened / reformatted form).
// This looks like the removed and added sides of a diff rendered without +/-
// markers; confirm against the real source file.
template <typename AccT,
typename OutT,
typename Epilogue = TransformAxpby<OutT, AccT>>
template <
typename AccT,
typename OutT,
typename Epilogue = TransformAxpby<OutT, AccT>>
[[kernel]] void gemm_splitk_accum_axpby(
const device AccT *C_split [[buffer(0)]],
device OutT *D [[buffer(1)]],
const device AccT* C_split [[buffer(0)]],
device OutT* D [[buffer(1)]],
const constant int& k_partitions [[buffer(2)]],
const constant int& partition_stride [[buffer(3)]],
const constant int& ldd [[buffer(4)]],
const device OutT *C [[buffer(5)]],
const device OutT* C [[buffer(5)]],
const constant int& ldc [[buffer(6)]],
const constant int& fdc [[buffer(7)]],
const constant float& alpha [[buffer(8)]],
const constant float& beta [[buffer(9)]],
uint2 gid [[thread_position_in_grid]]) {
// Adjust C, D and C_split to this thread's element
C += gid.x * fdc + gid.y * ldc;
D += gid.x + gid.y * ldd;
C_split += gid.x + gid.y * ldd;
// Same adjustments with the strides widened to size_t before the multiply
C += gid.x * size_t(fdc) + gid.y * size_t(ldc);
D += gid.x + gid.y * size_t(ldd);
C_split += gid.x + gid.y * size_t(ldd);
int offset = 0;
size_t offset = 0;
AccT out = 0;
// Accumulate this element's partial result from every K partition
for(int i = 0; i < k_partitions; i++) {
for (int i = 0; i < k_partitions; i++) {
out += C_split[offset];
offset += partition_stride;
}
// Write output
// Write output
Epilogue op(alpha, beta);
D[0] = op.apply(out, *C);
}
// Explicit-instantiation macro for the two split-K accumulation kernels.
//
// One expansion emits
//   gemm_splitk_accum<atype, otype>        as "steel_gemm_splitk_accum_<oname>_<aname>"
//   gemm_splitk_accum_axpby<atype, otype>  as "steel_gemm_splitk_accum_<oname>_<aname>_axpby"
// so `atype` is the accumulation (C_split) type and `otype` the output (D, C)
// type. oname and aname are only stringized into the host names.
//
// NOTE(review): the macro text appears twice below, differing only in line
// wrapping and pointer spacing. This looks like the removed and added sides of
// a formatting-only diff rendered without +/- markers; confirm against the
// real source file (as written, the first copy's trailing backslash continues
// into the second #define).
#define instantiate_accum(oname, otype, aname, atype) \
template [[host_name("steel_gemm_splitk_accum_" #oname "_" #aname)]] \
[[kernel]] void gemm_splitk_accum<atype, otype>( \
const device atype *C_split [[buffer(0)]], \
device otype *D [[buffer(1)]], \
const constant int& k_partitions [[buffer(2)]], \
const constant int& partition_stride [[buffer(3)]], \
const constant int& ldd [[buffer(4)]], \
uint2 gid [[thread_position_in_grid]]); \
template [[host_name("steel_gemm_splitk_accum_" #oname "_" #aname "_axpby")]] \
[[kernel]] void gemm_splitk_accum_axpby<atype, otype>( \
const device atype *C_split [[buffer(0)]], \
device otype *D [[buffer(1)]], \
const constant int& k_partitions [[buffer(2)]], \
const constant int& partition_stride [[buffer(3)]], \
const constant int& ldd [[buffer(4)]], \
const device otype *C [[buffer(5)]], \
const constant int& ldc [[buffer(6)]], \
const constant int& fdc [[buffer(7)]], \
const constant float& alpha [[buffer(8)]], \
const constant float& beta [[buffer(9)]], \
#define instantiate_accum(oname, otype, aname, atype) \
template [[host_name("steel_gemm_splitk_accum_" #oname \
"_" #aname)]] [[kernel]] void \
gemm_splitk_accum<atype, otype>( \
const device atype* C_split [[buffer(0)]], \
device otype* D [[buffer(1)]], \
const constant int& k_partitions [[buffer(2)]], \
const constant int& partition_stride [[buffer(3)]], \
const constant int& ldd [[buffer(4)]], \
uint2 gid [[thread_position_in_grid]]); \
template [[host_name("steel_gemm_splitk_accum_" #oname "_" #aname \
"_axpby")]] [[kernel]] void \
gemm_splitk_accum_axpby<atype, otype>( \
const device atype* C_split [[buffer(0)]], \
device otype* D [[buffer(1)]], \
const constant int& k_partitions [[buffer(2)]], \
const constant int& partition_stride [[buffer(3)]], \
const constant int& ldd [[buffer(4)]], \
const device otype* C [[buffer(5)]], \
const constant int& ldc [[buffer(6)]], \
const constant int& fdc [[buffer(7)]], \
const constant float& alpha [[buffer(8)]], \
const constant float& beta [[buffer(9)]], \
uint2 gid [[thread_position_in_grid]]);
// Instantiate both accumulation kernels for bfloat16_t, half and float outputs,
// always accumulating in float.
// NOTE(review): the float32 line appears twice (bare, then with a
// "// clang-format on" suffix) - presumably the old and new sides of a diff;
// confirm against the real source file.
// clang-format off
instantiate_accum(bfloat16, bfloat16_t, float32, float);
instantiate_accum(float16, half, float32, float);
instantiate_accum(float32, float, float32, float);
instantiate_accum(float32, float, float32, float); // clang-format on

View File

@@ -21,9 +21,9 @@ struct GEMMParams {
const int tiles_n;
const int tiles_m;
const int batch_stride_a;
const int batch_stride_b;
const int batch_stride_d;
const size_t batch_stride_a;
const size_t batch_stride_b;
const size_t batch_stride_d;
const int swizzle_log;
const int gemm_k_iterations_aligned;
@@ -54,7 +54,7 @@ struct GEMMAddMMParams {
const int ldc;
const int fdc;
const int batch_stride_c;
const size_t batch_stride_c;
const float alpha;
const float beta;

View File

@@ -3,9 +3,9 @@
#include <metal_integer>
#include <metal_math>
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/ternary.h"
#include "mlx/backend/metal/kernels/utils.h"
template <typename T, typename Op>
[[kernel]] void ternary_op_v(
@@ -65,7 +65,8 @@ template <typename T, typename Op>
auto a_idx = elem_to_loc_3(index, a_strides);
auto b_idx = elem_to_loc_3(index, b_strides);
auto c_idx = elem_to_loc_3(index, c_strides);
size_t out_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
size_t out_idx =
index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
d[out_idx] = Op()(a[a_idx], b[b_idx], c[c_idx]);
}
@@ -81,8 +82,10 @@ template <typename T, typename Op, int DIM>
constant const size_t c_strides[DIM],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto idx = elem_to_loc_3_nd<DIM>(index, shape, a_strides, b_strides, c_strides);
size_t out_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
auto idx =
elem_to_loc_3_nd<DIM>(index, shape, a_strides, b_strides, c_strides);
size_t out_idx =
index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
d[out_idx] = Op()(a[idx.x], b[idx.y], c[idx.z]);
}
@@ -99,103 +102,104 @@ template <typename T, typename Op>
constant const int& ndim,
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto idx = elem_to_loc_3_nd(index, shape, a_strides, b_strides, c_strides, ndim);
auto idx =
elem_to_loc_3_nd(index, shape, a_strides, b_strides, c_strides, ndim);
size_t out_idx = index.x + grid_dim.x * (index.y + grid_dim.y * index.z);
d[out_idx] = Op()(a[idx.x], b[idx.y], c[idx.z]);
}
#define instantiate_ternary_v(name, type, op) \
template [[host_name(name)]] \
[[kernel]] void ternary_op_v<type, op>( \
device const bool* a, \
device const type* b, \
device const type* c, \
device type* d, \
uint index [[thread_position_in_grid]]); \
#define instantiate_ternary_v(name, type, op) \
template [[host_name(name)]] [[kernel]] void ternary_op_v<type, op>( \
device const bool* a, \
device const type* b, \
device const type* c, \
device type* d, \
uint index [[thread_position_in_grid]]);
#define instantiate_ternary_g(name, type, op) \
template [[host_name(name)]] \
[[kernel]] void ternary_op_g<type, op>( \
device const bool* a, \
device const type* b, \
device const type* c, \
device type* d, \
constant const int* shape, \
constant const size_t* a_strides, \
constant const size_t* b_strides, \
constant const size_t* c_strides, \
constant const int& ndim, \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]); \
#define instantiate_ternary_g(name, type, op) \
template [[host_name(name)]] [[kernel]] void ternary_op_g<type, op>( \
device const bool* a, \
device const type* b, \
device const type* c, \
device type* d, \
constant const int* shape, \
constant const size_t* a_strides, \
constant const size_t* b_strides, \
constant const size_t* c_strides, \
constant const int& ndim, \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]);
#define instantiate_ternary_g_dim(name, type, op, dims) \
template [[host_name(name "_" #dims)]] \
[[kernel]] void ternary_op_g_nd<type, op, dims>( \
device const bool* a, \
device const type* b, \
device const type* c, \
device type* d, \
constant const int shape[dims], \
constant const size_t a_strides[dims], \
constant const size_t b_strides[dims], \
constant const size_t c_strides[dims], \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]); \
#define instantiate_ternary_g_dim(name, type, op, dims) \
template [[host_name(name "_" #dims)]] [[kernel]] void \
ternary_op_g_nd<type, op, dims>( \
device const bool* a, \
device const type* b, \
device const type* c, \
device type* d, \
constant const int shape[dims], \
constant const size_t a_strides[dims], \
constant const size_t b_strides[dims], \
constant const size_t c_strides[dims], \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]);
#define instantiate_ternary_g_nd(name, type, op) \
template [[host_name(name "_1")]] \
[[kernel]] void ternary_op_g_nd1<type, op>( \
device const bool* a, \
device const type* b, \
device const type* c, \
device type* d, \
constant const size_t& a_strides, \
constant const size_t& b_strides, \
constant const size_t& c_strides, \
uint index [[thread_position_in_grid]]); \
template [[host_name(name "_2")]] \
[[kernel]] void ternary_op_g_nd2<type, op>( \
device const bool* a, \
device const type* b, \
device const type* c, \
device type* d, \
constant const size_t a_strides[2], \
constant const size_t b_strides[2], \
constant const size_t c_strides[2], \
uint2 index [[thread_position_in_grid]], \
uint2 grid_dim [[threads_per_grid]]); \
template [[host_name(name "_3")]] \
[[kernel]] void ternary_op_g_nd3<type, op>( \
device const bool* a, \
device const type* b, \
device const type* c, \
device type* d, \
constant const size_t a_strides[3], \
constant const size_t b_strides[3], \
constant const size_t c_strides[3], \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]); \
instantiate_ternary_g_dim(name, type, op, 4) \
instantiate_ternary_g_dim(name, type, op, 5) \
#define instantiate_ternary_g_nd(name, type, op) \
template [[host_name(name "_1")]] [[kernel]] void \
ternary_op_g_nd1<type, op>( \
device const bool* a, \
device const type* b, \
device const type* c, \
device type* d, \
constant const size_t& a_strides, \
constant const size_t& b_strides, \
constant const size_t& c_strides, \
uint index [[thread_position_in_grid]]); \
template [[host_name(name "_2")]] [[kernel]] void \
ternary_op_g_nd2<type, op>( \
device const bool* a, \
device const type* b, \
device const type* c, \
device type* d, \
constant const size_t a_strides[2], \
constant const size_t b_strides[2], \
constant const size_t c_strides[2], \
uint2 index [[thread_position_in_grid]], \
uint2 grid_dim [[threads_per_grid]]); \
template [[host_name(name "_3")]] [[kernel]] void \
ternary_op_g_nd3<type, op>( \
device const bool* a, \
device const type* b, \
device const type* c, \
device type* d, \
constant const size_t a_strides[3], \
constant const size_t b_strides[3], \
constant const size_t c_strides[3], \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]); \
instantiate_ternary_g_dim(name, type, op, 4) \
instantiate_ternary_g_dim(name, type, op, 5)
// clang-format off
#define instantiate_ternary_all(name, tname, type, op) \
instantiate_ternary_v("v" #name #tname, type, op) \
instantiate_ternary_g("g" #name #tname, type, op) \
instantiate_ternary_g_nd("g" #name #tname, type, op) \
instantiate_ternary_v("v" #name #tname, type, op) \
instantiate_ternary_g("g" #name #tname, type, op) \
instantiate_ternary_g_nd("g" #name #tname, type, op) // clang-format on
#define instantiate_ternary_types(name, op) \
instantiate_ternary_all(name, bool_, bool, op) \
instantiate_ternary_all(name, uint8, uint8_t, op) \
instantiate_ternary_all(name, uint16, uint16_t, op) \
instantiate_ternary_all(name, uint32, uint32_t, op) \
instantiate_ternary_all(name, uint64, uint64_t, op) \
instantiate_ternary_all(name, int8, int8_t, op) \
instantiate_ternary_all(name, int16, int16_t, op) \
instantiate_ternary_all(name, int32, int32_t, op) \
instantiate_ternary_all(name, int64, int64_t, op) \
instantiate_ternary_all(name, float16, half, op) \
instantiate_ternary_all(name, float32, float, op) \
// clang-format off
#define instantiate_ternary_types(name, op) \
instantiate_ternary_all(name, bool_, bool, op) \
instantiate_ternary_all(name, uint8, uint8_t, op) \
instantiate_ternary_all(name, uint16, uint16_t, op) \
instantiate_ternary_all(name, uint32, uint32_t, op) \
instantiate_ternary_all(name, uint64, uint64_t, op) \
instantiate_ternary_all(name, int8, int8_t, op) \
instantiate_ternary_all(name, int16, int16_t, op) \
instantiate_ternary_all(name, int32, int32_t, op) \
instantiate_ternary_all(name, int64, int64_t, op) \
instantiate_ternary_all(name, float16, half, op) \
instantiate_ternary_all(name, float32, float, op) \
instantiate_ternary_all(name, bfloat16, bfloat16_t, op) \
instantiate_ternary_all(name, complex64, complex64_t, op) \
instantiate_ternary_all(name, complex64, complex64_t, op) // clang-format on
instantiate_ternary_types(select, Select)
instantiate_ternary_types(select, Select)

View File

@@ -22,44 +22,46 @@ template <typename T, typename Op>
out[index] = Op()(in[idx]);
}
#define instantiate_unary_v(name, type, op) \
template [[host_name(name)]] \
[[kernel]] void unary_op_v<type, op>( \
device const type* in, \
device type* out, \
#define instantiate_unary_v(name, type, op) \
template [[host_name(name)]] [[kernel]] void unary_op_v<type, op>( \
device const type* in, \
device type* out, \
uint index [[thread_position_in_grid]]);
#define instantiate_unary_g(name, type, op) \
template [[host_name(name)]] \
[[kernel]] void unary_op_g<type, op>( \
device const type* in, \
device type* out, \
device const int* in_shape, \
device const size_t* in_strides, \
device const int& ndim, \
#define instantiate_unary_g(name, type, op) \
template [[host_name(name)]] [[kernel]] void unary_op_g<type, op>( \
device const type* in, \
device type* out, \
device const int* in_shape, \
device const size_t* in_strides, \
device const int& ndim, \
uint index [[thread_position_in_grid]]);
// clang-format off
#define instantiate_unary_all(name, tname, type, op) \
instantiate_unary_v("v" #name #tname, type, op) \
instantiate_unary_g("g" #name #tname, type, op)
instantiate_unary_v("v" #name #tname, type, op) \
instantiate_unary_g("g" #name #tname, type, op) // clang-format on
#define instantiate_unary_float(name, op) \
instantiate_unary_all(name, float16, half, op) \
instantiate_unary_all(name, float32, float, op) \
instantiate_unary_all(name, bfloat16, bfloat16_t, op) \
// clang-format off
#define instantiate_unary_float(name, op) \
instantiate_unary_all(name, float16, half, op) \
instantiate_unary_all(name, float32, float, op) \
instantiate_unary_all(name, bfloat16, bfloat16_t, op) // clang-format on
#define instantiate_unary_types(name, op) \
instantiate_unary_all(name, bool_, bool, op) \
instantiate_unary_all(name, uint8, uint8_t, op) \
// clang-format off
#define instantiate_unary_types(name, op) \
instantiate_unary_all(name, bool_, bool, op) \
instantiate_unary_all(name, uint8, uint8_t, op) \
instantiate_unary_all(name, uint16, uint16_t, op) \
instantiate_unary_all(name, uint32, uint32_t, op) \
instantiate_unary_all(name, uint64, uint64_t, op) \
instantiate_unary_all(name, int8, int8_t, op) \
instantiate_unary_all(name, int16, int16_t, op) \
instantiate_unary_all(name, int32, int32_t, op) \
instantiate_unary_all(name, int64, int64_t, op) \
instantiate_unary_float(name, op)
instantiate_unary_all(name, int8, int8_t, op) \
instantiate_unary_all(name, int16, int16_t, op) \
instantiate_unary_all(name, int32, int32_t, op) \
instantiate_unary_all(name, int64, int64_t, op) \
instantiate_unary_float(name, op) // clang-format on
// clang-format off
instantiate_unary_types(abs, Abs)
instantiate_unary_float(arccos, ArcCos)
instantiate_unary_float(arccosh, ArcCosh)
@@ -102,4 +104,4 @@ instantiate_unary_all(tan, complex64, complex64_t, Tan)
instantiate_unary_all(tanh, complex64, complex64_t, Tanh)
instantiate_unary_all(round, complex64, complex64_t, Round)
instantiate_unary_all(lnot, bool_, bool, LogicalNot)
instantiate_unary_all(lnot, bool_, bool, LogicalNot) // clang-format on

View File

@@ -260,6 +260,110 @@ inline auto collapse_batches(const array& a, const array& b, const array& c) {
// Steel matmul fallback
///////////////////////////////////////////////////////////////////////////////
void steel_matmul_conv_groups(
const Stream& s,
metal::Device& d,
const array& a,
const array& b,
array& out,
int M,
int N,
int K,
int lda,
int ldb,
int ldd,
bool transpose_a,
bool transpose_b,
int groups,
std::vector<array>& copies) {
using namespace mlx::steel;
/////////////////////////////////////////////////////////////////////////////
// Regular kernel dispatch
// Determine dispatch kernel
int bm = 32, bn = 32, bk = 16;
int wm = 2, wn = 2;
if ((size_t)M * N >= 1ul << 20) {
if (!transpose_a && transpose_b) {
bm = 64;
bn = (out.dtype() == float32) ? 64 : 32;
bk = (out.dtype() == float32) ? 16 : 32;
} else {
bm = 64;
bn = 64;
}
}
// Prepare kernel name
std::ostringstream kname;
kname << "steel_gemm_" << (transpose_a ? 't' : 'n')
<< (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_"
<< type_to_name(out) << "_bm" << bm << "_bn" << bn << "_bk" << bk
<< "_wm" << wm << "_wn" << wn << "_MN_"
<< ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned" << "_K_"
<< ((K % bk == 0) ? "t" : "n") << "aligned";
// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
// Use problem size to determine threadblock swizzle
int tn = (N + bn - 1) / bn;
int tm = (M + bm - 1) / bm;
// TODO: Explore device-based tuning for swizzle
int swizzle_log = 0; // tm >= 6 ? 3 : (tm <= 3 ? 0 : 2);
// Prepare steel matmul params
GEMMParams params{
/* const int M = */ M,
/* const int N = */ N,
/* const int K = */ K,
/* const int lda = */ lda,
/* const int ldb = */ ldb,
/* const int ldd = */ ldd,
/* const int tiles_n = */ tn,
/* const int tiles_m = */ tm,
/* const size_t batch_stride_a = */ size_t(K),
/* const size_t batch_stride_b = */ size_t(N) * K,
/* const size_t batch_stride_d = */ size_t(N),
/* const int swizzle_log = */ swizzle_log,
/* const int gemm_k_iterations_aligned = */ (K / bk),
/* const int batch_ndim = */ 1};
// Prepare launch grid params
int tile = 1 << swizzle_log;
tm = (tm + tile - 1) / tile;
tn = tn * tile;
MTL::Size group_dims = MTL::Size(32, wn, wm);
MTL::Size grid_dims = MTL::Size(tn, tm, groups);
std::vector<int> batch_shape = {1};
std::vector<size_t> batch_strides = {0};
// Launch kernel
compute_encoder.set_input_array(a, 0);
compute_encoder.set_input_array(b, 1);
compute_encoder.set_output_array(out, 3);
compute_encoder->setBytes(&params, sizeof(GEMMParams), 4);
compute_encoder->setBytes(
batch_shape.data(), sizeof(int) * batch_shape.size(), 6);
compute_encoder->setBytes(
batch_strides.data(), sizeof(size_t) * batch_strides.size(), 7);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
// Clear copies
d.get_command_buffer(s.index)->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
return;
}
void steel_matmul(
const Stream& s,
metal::Device& d,
@@ -301,7 +405,7 @@ void steel_matmul(
}
}
int matrix_stride_out = M * N;
size_t matrix_stride_out = size_t(M) * N;
/////////////////////////////////////////////////////////////////////////////
// Split K specialization
@@ -364,7 +468,7 @@ void steel_matmul(
compute_encoder.set_output_array(C_split, 2);
compute_encoder->setBytes(&params, sizeof(GEMMSpiltKParams), 3);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
// Do accum kernel
{
@@ -389,7 +493,7 @@ void steel_matmul(
MTL::Size grid_dims = MTL::Size(N, M, 1);
MTL::Size group_dims = MTL::Size(std::min(1024, N * M), 1, 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
d.get_command_buffer(s.index)->addCompletedHandler(
@@ -446,9 +550,9 @@ void steel_matmul(
/* const int ldd = */ N,
/* const int tiles_n = */ tn,
/* const int tiles_m = */ tm,
/* const int batch_stride_a = */ int(A_batch_stride.back()),
/* const int batch_stride_b = */ int(B_batch_stride.back()),
/* const int batch_stride_d = */ matrix_stride_out,
/* const size_t batch_stride_a = */ A_batch_stride.back(),
/* const size_t batch_stride_b = */ B_batch_stride.back(),
/* const size_t batch_stride_d = */ matrix_stride_out,
/* const int swizzle_log = */ swizzle_log,
/* const int gemm_k_iterations_aligned = */ (K / bk),
/* const int batch_ndim = */ int(batch_shape.size())};
@@ -477,7 +581,7 @@ void steel_matmul(
compute_encoder->setBytes(
batch_strides.data(), sizeof(size_t) * batch_strides.size(), 7);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
// Clear copies
d.get_command_buffer(s.index)->addCompletedHandler(
@@ -541,7 +645,7 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
auto [batch_shape, A_batch_stride, B_batch_stride] = collapse_batches(a, b);
auto batch_size_out = out.size() / (M * N);
auto batch_size_out = out.size() / (size_t(M) * size_t(N));
// Collapse batches into M if needed
if (batch_size_out > 1 && !a_transposed && batch_shape.size() == 1 &&
@@ -644,7 +748,7 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder->setBytes(
batch_strides_mat.data(), batch_ndim * sizeof(size_t), 12);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
d.get_command_buffer(s.index)->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
@@ -749,7 +853,8 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
auto [batch_shape, A_batch_stride, B_batch_stride, C_batch_stride] =
collapse_batches(a, b, c);
auto batch_size_out = out.size() / (M * N);
size_t matrix_stride_out = size_t(M) * size_t(N);
auto batch_size_out = out.size() / (matrix_stride_out);
// Collapse batches into M if needed
if (batch_size_out > 1 && !transpose_a && batch_shape.size() == 1 &&
@@ -765,8 +870,6 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
batch_shape = {1};
}
int matrix_stride_out = M * N;
/////////////////////////////////////////////////////////////////////////////
// Gemv specialization
@@ -865,7 +968,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
int bias_stride = c.strides()[c.ndim() - 1];
compute_encoder->setBytes(&bias_stride, sizeof(int), 14);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
d.get_command_buffer(s.index)->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
@@ -935,7 +1038,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.set_output_array(C_split, 2);
compute_encoder->setBytes(&params, sizeof(GEMMSpiltKParams), 3);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
// Do accum kernel
{
@@ -960,7 +1063,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
MTL::Size grid_dims = MTL::Size(N, M, 1);
MTL::Size group_dims = MTL::Size(std::min(1024, N * M), 1, 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
d.get_command_buffer(s.index)->addCompletedHandler(
@@ -1016,9 +1119,9 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
/* const int ldd = */ N,
/* const int tiles_n = */ tn,
/* const int tiles_m = */ tm,
/* const int batch_stride_a = */ int(A_batch_stride.back()),
/* const int batch_stride_b = */ int(B_batch_stride.back()),
/* const int batch_stride_d = */ matrix_stride_out,
/* const size_t batch_stride_a = */ A_batch_stride.back(),
/* const size_t batch_stride_b = */ B_batch_stride.back(),
/* const size_t batch_stride_d = */ matrix_stride_out,
/* const int swizzle_log = */ swizzle_log,
/* const int gemm_k_iterations_aligned = */ (K / bk),
/* const int batch_ndim = */ int(batch_shape.size())};
@@ -1026,7 +1129,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
GEMMAddMMParams params{
/* const int ldc = */ ldc,
/* const int fdc = */ fdc,
/* const int batch_stride_c = */ int(C_batch_stride.back()),
/* const size_t batch_stride_c = */ C_batch_stride.back(),
/* const float alpha = */ alpha_,
/* const float beta = */ beta_};
@@ -1057,7 +1160,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder->setBytes(
batch_strides.data(), sizeof(size_t) * batch_strides.size(), 7);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
d.get_command_buffer(s.index)->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
@@ -1126,8 +1229,8 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& out_mask = inputs[2];
std::vector<int> batch_shape{1};
int A_batch_str = 0;
int B_batch_str = 0;
size_t A_batch_str = 0;
size_t B_batch_str = 0;
std::vector<size_t> batch_strides;
@@ -1145,8 +1248,8 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
auto [bshape_c, bstrides_c] = collapse_contiguous_dims(bshape, bstrides);
batch_shape = bshape_c;
A_batch_str = int(bstrides_c[0].back());
B_batch_str = int(bstrides_c[1].back());
A_batch_str = bstrides_c[0].back();
B_batch_str = bstrides_c[1].back();
for (auto& bstr : bstrides_c) {
batch_strides.insert(batch_strides.end(), bstr.begin(), bstr.end());
@@ -1155,8 +1258,8 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
batch_strides = std::vector<size_t>(inputs.size(), 0);
}
auto batch_size_out = out.size() / (M * N);
int matrix_stride_out = M * N;
size_t matrix_stride_out = size_t(M) * N;
size_t batch_size_out = out.size() / (matrix_stride_out);
/////////////////////////////////////////////////////////////////////////////
// Regular kernel dispatch
@@ -1197,9 +1300,9 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
/* const int ldd = */ N,
/* const int tiles_n = */ tn,
/* const int tiles_m = */ tm,
/* const int batch_stride_a = */ A_batch_str,
/* const int batch_stride_b = */ B_batch_str,
/* const int batch_stride_d = */ matrix_stride_out,
/* const size_t batch_stride_a = */ A_batch_str,
/* const size_t batch_stride_b = */ B_batch_str,
/* const size_t batch_stride_d = */ matrix_stride_out,
/* const int swizzle_log = */ swizzle_log,
/* const int gemm_k_iterations_aligned = */ (K / bk),
/* const int batch_ndim = */ int(batch_shape.size())};
@@ -1243,7 +1346,317 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.set_input_array(out_mask, 10);
set_vector_bytes(compute_encoder, mask_strides, 13);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
// Clear copies
d.get_command_buffer(s.index)->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
return;
}
void BlockSparseMM::eval_gpu(const std::vector<array>& inputs, array& out) {
using namespace mlx::steel;
// assert(inputs.size() == 2);
if (!issubdtype(out.dtype(), floating)) {
throw std::runtime_error(
"[matmul] Does not yet support non-floating point types.");
}
auto& s = stream();
auto& d = metal::device(s.device);
auto& a_pre = inputs[0];
auto& b_pre = inputs[1];
// Return 0s if either input is empty
if (a_pre.size() == 0 || b_pre.size() == 0) {
array zero = array(0, a_pre.dtype());
copy_gpu(zero, out, CopyType::Scalar, s);
auto command_buffer = d.get_command_buffer(s.index);
command_buffer->addCompletedHandler([zero](MTL::CommandBuffer*) {});
return;
}
out.set_data(allocator::malloc_or_wait(out.nbytes()));
/////////////////////////////////////////////////////////////////////////////
// Init checks and prep
// Keep a vector with copies to be cleared in the completed buffer to release
// the arrays
std::vector<array> copies;
auto check_transpose = [&copies, &s](const array& arr) {
auto stx = arr.strides()[arr.ndim() - 2];
auto sty = arr.strides()[arr.ndim() - 1];
if (sty == 1) {
return std::make_tuple(false, stx, arr);
} else if (stx == 1) {
return std::make_tuple(true, sty, arr);
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy_gpu(arr, arr_copy, CopyType::General, s);
copies.push_back(arr_copy);
size_t stx = arr.shape(-1);
return std::make_tuple(false, stx, arr_copy);
}
};
auto [transpose_a, a_cols, a] = check_transpose(a_pre);
auto [transpose_b, b_cols, b] = check_transpose(b_pre);
int lda = a_cols;
int ldb = b_cols;
int M = a.shape(-2);
int N = b.shape(-1);
int K = a.shape(-1);
/////////////////////////////////////////////////////////////////////////////
// Check and collapse batch dimensions
auto get_batch_dims = [](const auto& v) {
return decltype(v){v.begin(), v.end() - 2};
};
auto& lhs_indices = inputs[2];
auto& rhs_indices = inputs[3];
std::vector<int> batch_shape = get_batch_dims(out.shape());
std::vector<size_t> batch_strides;
batch_strides.insert(
batch_strides.end(),
lhs_indices.strides().begin(),
lhs_indices.strides().end());
size_t lhs_indices_str = batch_strides.empty() ? 0 : batch_strides.back();
batch_strides.insert(
batch_strides.end(),
rhs_indices.strides().begin(),
rhs_indices.strides().end());
size_t rhs_indices_str = batch_strides.empty() ? 0 : batch_strides.back();
int batch_ndim = batch_shape.size();
if (batch_ndim == 0) {
batch_shape = {1};
batch_strides = {0};
}
int batch_ndim_A = a.ndim() - 2;
int batch_ndim_B = b.ndim() - 2;
std::vector<int> operand_batch_ndim = {batch_ndim_A, batch_ndim_B};
std::vector<int> batch_shape_A = get_batch_dims(a.shape());
std::vector<size_t> batch_strides_A = get_batch_dims(a.strides());
std::vector<int> batch_shape_B = get_batch_dims(b.shape());
std::vector<size_t> batch_strides_B = get_batch_dims(b.strides());
if (batch_ndim_A == 0) {
batch_shape_A = {1};
batch_strides_A = {0};
}
if (batch_ndim_B == 0) {
batch_shape_B = {1};
batch_strides_B = {0};
}
size_t matrix_stride_out = size_t(M) * N;
auto batch_size_out = out.size() / matrix_stride_out;
/////////////////////////////////////////////////////////////////////////////
// Gemv specialization
// Route to gemv if needed
if (std::min(M, N) == 1) {
// Collect problem info
bool is_b_matrix = N != 1;
auto& mat = is_b_matrix ? b : a;
auto& vec = is_b_matrix ? a : b;
bool transpose_mat = is_b_matrix ? !transpose_b : transpose_a;
int in_vector_len = K;
int out_vector_len = is_b_matrix ? N : M;
int mat_cols = transpose_mat ? out_vector_len : in_vector_len;
int mat_rows = transpose_mat ? in_vector_len : out_vector_len;
int mat_ld = is_b_matrix ? b_cols : a_cols;
auto batch_strides_mat = is_b_matrix ? batch_strides_B : batch_strides_A;
auto batch_strides_vec = is_b_matrix ? batch_strides_A : batch_strides_B;
auto batch_shape_mat = is_b_matrix ? batch_shape_B : batch_shape_A;
auto batch_shape_vec = is_b_matrix ? batch_shape_A : batch_shape_B;
if (!is_b_matrix) {
batch_strides = rhs_indices.strides();
batch_strides.insert(
batch_strides.end(),
lhs_indices.strides().begin(),
lhs_indices.strides().end());
}
int batch_ndim = batch_shape.size();
// Determine dispatch kernel
int tm = 4, tn = 4;
int bm, bn, n_out_per_tgp;
std::ostringstream kname;
if (transpose_mat) {
bm = 8;
bn = 8;
if (out_vector_len >= 24576) {
bn = 128;
} else if (out_vector_len >= 16384) {
bn = 64;
} else if (out_vector_len >= 8192) {
bn = 16;
}
// Specialized kernel for very small outputs
tn = out_vector_len < tn ? 1 : tn;
n_out_per_tgp = bn * tn;
kname << "gemv_t_bs_" << type_to_name(out);
} else {
bm = out_vector_len >= 4096 ? 8 : 4;
bn = 32;
// Specialized kernel for very small outputs
tm = out_vector_len < tm ? 1 : tm;
n_out_per_tgp = bm * tm;
kname << "gemv_bs_" << type_to_name(out);
}
kname << "_bm" << bm << "_bn" << bn << "_tm" << tm << "_tn" << tn;
// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp;
MTL::Size group_dims = MTL::Size(bn, bm, 1);
MTL::Size grid_dims = MTL::Size(n_tgp, 1, batch_size_out);
compute_encoder.set_input_array(mat, 0);
compute_encoder.set_input_array(vec, 1);
compute_encoder.set_output_array(out, 3);
compute_encoder->setBytes(&in_vector_len, sizeof(int), 4);
compute_encoder->setBytes(&out_vector_len, sizeof(int), 5);
compute_encoder->setBytes(&mat_ld, sizeof(int), 6);
compute_encoder->setBytes(&batch_ndim, sizeof(int), 9);
set_vector_bytes(compute_encoder, batch_shape, 10);
set_vector_bytes(compute_encoder, batch_strides, 11);
int batch_ndim_vec = batch_shape_vec.size();
compute_encoder->setBytes(&batch_ndim_vec, sizeof(int), 12);
set_vector_bytes(compute_encoder, batch_shape_vec, 13);
set_vector_bytes(compute_encoder, batch_strides_vec, 14);
int batch_ndim_mat = batch_shape_mat.size();
compute_encoder->setBytes(&batch_ndim_mat, sizeof(int), 15);
set_vector_bytes(compute_encoder, batch_shape_mat, 16);
set_vector_bytes(compute_encoder, batch_strides_mat, 17);
compute_encoder.set_input_array(lhs_indices, 18 + int(!is_b_matrix));
compute_encoder.set_input_array(rhs_indices, 18 + int(is_b_matrix));
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
d.get_command_buffer(s.index)->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
return;
}
/////////////////////////////////////////////////////////////////////////////
// Regular kernel dispatch
// Determine dispatch kernel
int bm = 32, bn = 32, bk = 16;
int wm = 2, wn = 2;
if ((size_t)batch_size_out * M * N >= 1ul << 20) {
if (!transpose_a && transpose_b) {
bm = 64;
bn = (out.dtype() == float32) ? 64 : 32;
bk = (out.dtype() == float32) ? 16 : 32;
} else {
bm = 64;
bn = 64;
}
}
// Prepare kernel name
std::ostringstream kname;
kname << "steel_block_sparse_gemm_" << (transpose_a ? 't' : 'n')
<< (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_"
<< type_to_name(out) << "_bm" << bm << "_bn" << bn << "_bk" << bk
<< "_wm" << wm << "_wn" << wn << "_MN_"
<< ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned" << "_K_"
<< ((K % bk == 0) ? "t" : "n") << "aligned";
// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
// Use problem size to determine threadblock swizzle
int tn = (N + bn - 1) / bn;
int tm = (M + bm - 1) / bm;
// TODO: Explore device-based tuning for swizzle
int swizzle_log = 0; // tm >= 6 ? 3 : (tm <= 3 ? 0 : 2);
// Prepare steel matmul params
GEMMParams params{
/* const int M = */ M,
/* const int N = */ N,
/* const int K = */ K,
/* const int lda = */ lda,
/* const int ldb = */ ldb,
/* const int ldd = */ N,
/* const int tiles_n = */ tn,
/* const int tiles_m = */ tm,
/* const size_t batch_stride_a = */ lhs_indices_str,
/* const size_t batch_stride_b = */ rhs_indices_str,
/* const size_t batch_stride_d = */ matrix_stride_out,
/* const int swizzle_log = */ swizzle_log,
/* const int gemm_k_iterations_aligned = */ (K / bk),
/* const int batch_ndim = */ batch_ndim};
// Prepare launch grid params
int tile = 1 << swizzle_log;
tm = (tm + tile - 1) / tile;
tn = tn * tile;
MTL::Size group_dims = MTL::Size(32, wn, wm);
MTL::Size grid_dims = MTL::Size(tn, tm, batch_size_out);
// Launch kernel
compute_encoder.set_input_array(a, 0);
compute_encoder.set_input_array(b, 1);
compute_encoder.set_output_array(out, 3);
compute_encoder->setBytes(&params, sizeof(GEMMParams), 4);
set_vector_bytes(compute_encoder, batch_shape, 6);
set_vector_bytes(compute_encoder, batch_strides, 7);
compute_encoder.set_input_array(lhs_indices, 10);
compute_encoder.set_input_array(rhs_indices, 11);
set_vector_bytes(compute_encoder, batch_shape_A, 12);
set_vector_bytes(compute_encoder, batch_strides_A, 13);
set_vector_bytes(compute_encoder, batch_shape_B, 14);
set_vector_bytes(compute_encoder, batch_strides_B, 15);
set_vector_bytes(compute_encoder, operand_batch_ndim, 16);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
// Clear copies
d.get_command_buffer(s.index)->addCompletedHandler(

View File

@@ -12,6 +12,23 @@
namespace mlx::core {
// Encodes a "steel_gemm_*" kernel on the command encoder of stream `s` and
// dispatches it over a threadgroup grid whose depth (z extent) is `groups`.
//
// `a` and `b` are bound as inputs and `out` as the output; `M`, `N`, `K`,
// `lda`, `ldb` and `ldd` are forwarded into the kernel's GEMMParams.
// `transpose_a` / `transpose_b` select the transposed ('t') or
// non-transposed ('n') kernel variant. `copies` is captured by value in the
// command buffer's completed handler and cleared there.
void steel_matmul_conv_groups(
const Stream& s,
metal::Device& d,
const array& a,
const array& b,
array& out,
int M,
int N,
int K,
int lda,
int ldb,
int ldd,
bool transpose_a,
bool transpose_b,
int groups,
std::vector<array>& copies);
void steel_matmul(
const Stream& s,
metal::Device& d,

View File

@@ -27,24 +27,6 @@ int max_ops_per_buffer() {
#define MAX_OPS_PER_BUFFER max_ops_per_buffer()
// Return the command buffer to encode into for stream `s`, bumping its op
// counter. If the stream has no command buffer, a new one is created. If the
// current buffer has reached MAX_OPS_PER_BUFFER ops, it is ended, committed
// (with a completion handler that notifies the scheduler), and replaced by a
// new one.
MTL::CommandBuffer* increment_command_buffer(Stream s) {
  auto& device = metal::device(s.device);
  auto cbuf = device.get_command_buffer(s.index);

  // Only query the op count when a buffer actually exists.
  const bool have_buffer = cbuf != nullptr;
  const bool buffer_full = have_buffer &&
      device.get_command_buffer_ops(s.index) >= MAX_OPS_PER_BUFFER;

  if (buffer_full) {
    // Flush the saturated buffer before replacing it.
    device.end_encoding(s.index);
    scheduler::notify_new_task(s);
    cbuf->addCompletedHandler(
        [s](MTL::CommandBuffer*) { scheduler::notify_task_completion(s); });
    device.commit_command_buffer(s.index);
  }

  if (!have_buffer || buffer_full) {
    cbuf = device.new_command_buffer(s.index);
  }

  device.increment_command_buffer_ops(s.index);
  return cbuf;
}
inline void check_error(MTL::CommandBuffer* cbuf) {
if (cbuf->status() == MTL::CommandBufferStatusError) {
std::ostringstream msg;
@@ -58,7 +40,10 @@ std::function<void()> make_task(array arr, bool signal) {
auto task = [arr = std::move(arr), signal]() mutable {
auto pool = new_scoped_memory_pool();
auto s = arr.primitive().stream();
auto command_buffer = increment_command_buffer(s);
auto& d = metal::device(s.device);
auto command_buffer = d.get_command_buffer(s.index);
d.increment_command_buffer_ops(s.index);
for (auto& input : arr.inputs()) {
if (input.event().valid() &&
input.event().stream() != arr.primitive().stream()) {
@@ -91,11 +76,13 @@ std::function<void()> make_task(array arr, bool signal) {
arr.detach();
}
if (signal) {
metal::device(s.device).end_encoding(s.index);
command_buffer->encodeSignalEvent(
static_cast<MTL::Event*>(arr.event().raw_event().get()),
arr.event().value());
if (signal || d.get_command_buffer_ops(s.index) >= MAX_OPS_PER_BUFFER) {
d.end_encoding(s.index);
if (signal) {
command_buffer->encodeSignalEvent(
static_cast<MTL::Event*>(arr.event().raw_event().get()),
arr.event().value());
}
scheduler::notify_new_task(s);
command_buffer->addCompletedHandler(
[s, buffers = std::move(buffers), event = arr.event()](
@@ -103,7 +90,8 @@ std::function<void()> make_task(array arr, bool signal) {
scheduler::notify_task_completion(s);
check_error(cbuf);
});
metal::device(s.device).commit_command_buffer(s.index);
d.commit_command_buffer(s.index);
d.get_command_buffer(s.index);
} else {
command_buffer->addCompletedHandler(
[s, buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) {
@@ -120,14 +108,12 @@ std::function<void()> make_synchronize_task(
return [s, p = std::move(p)]() {
auto& d = metal::device(s.device);
auto cb = d.get_command_buffer(s.index);
if (cb == nullptr) {
cb = d.new_command_buffer(s.index);
} else {
d.end_encoding(s.index);
}
cb->retain();
d.end_encoding(s.index);
d.commit_command_buffer(s.index);
cb->waitUntilCompleted();
check_error(cb);
cb->release();
p->set_value();
};
}

View File

@@ -2,6 +2,8 @@
#pragma once
#include <variant>
#include "mlx/array.h"
namespace mlx::core::metal {
@@ -18,11 +20,15 @@ size_t get_active_memory();
/* Get the peak amount of used memory in bytes.
*
* The maximum memory used is recorded from the beginning of the program
* execution.
* The maximum memory used is recorded from the beginning of the program
* execution or since the last call to reset_peak_memory.
* */
size_t get_peak_memory();
/* Reset the peak memory to zero.
* */
void reset_peak_memory();
/* Get the cache size in bytes.
*
* The cache includes memory not currently used that has not been returned
@@ -61,4 +67,8 @@ void clear_cache();
void start_capture(std::string path = "");
void stop_capture();
/** Get information about the GPU and system settings. */
std::unordered_map<std::string, std::variant<std::string, size_t>>
device_info();
} // namespace mlx::core::metal

View File

@@ -89,7 +89,7 @@ void RMSNorm::eval_gpu(
compute_encoder->setThreadgroupMemoryLength(
16 * 8, 0); // minimum of 16 bytes
compute_encoder->setThreadgroupMemoryLength(simd_size * sizeof(float), 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
d.get_command_buffer(s.index)->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
@@ -190,7 +190,7 @@ void RMSNormVJP::eval_gpu(
compute_encoder->setBytes(&eps_, sizeof(float), 5);
compute_encoder->setBytes(&axis_size, sizeof(int), 6);
compute_encoder->setBytes(&w_stride, sizeof(uint32_t), 7);
compute_encoder->dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
ReductionPlan plan(
@@ -282,7 +282,7 @@ void LayerNorm::eval_gpu(
compute_encoder->setBytes(&axis_size, sizeof(int), 5);
compute_encoder->setBytes(&w_stride, sizeof(uint32_t), 6);
compute_encoder->setBytes(&b_stride, sizeof(uint32_t), 7);
compute_encoder->dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
d.get_command_buffer(s.index)->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
@@ -401,7 +401,7 @@ void LayerNormVJP::eval_gpu(
compute_encoder->setBytes(&eps_, sizeof(float), 5);
compute_encoder->setBytes(&axis_size, sizeof(int), 6);
compute_encoder->setBytes(&w_stride, sizeof(uint32_t), 7);
compute_encoder->dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
if (gw.ndim() == 1 && gw.size() == axis_size) {

View File

@@ -107,7 +107,7 @@ void binary_op(
}
auto group_dims = get_block_dims(dim0, dim1, rest);
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
compute_encoder->dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
} else {
// Launch a 1D grid of threads
size_t nthreads = out.data_size();
@@ -117,7 +117,7 @@ void binary_op(
thread_group_size = nthreads;
}
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
}
@@ -201,7 +201,7 @@ void binary_op(
}
auto group_dims = get_block_dims(dim0, dim1, rest);
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
compute_encoder->dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
} else {
// Launch a 1D grid of threads
size_t nthreads =
@@ -212,7 +212,7 @@ void binary_op(
thread_group_size = nthreads;
}
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
}
@@ -288,7 +288,7 @@ void ternary_op(
}
MTL::Size group_dims = get_block_dims(dim0, dim1, rest);
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
compute_encoder->dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
} else {
// Launch a 1D grid of threads
size_t nthreads = out.data_size();
@@ -298,7 +298,7 @@ void ternary_op(
thread_group_size = nthreads;
}
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
}
@@ -351,7 +351,7 @@ void unary_op(
int ndim = in.ndim();
compute_encoder->setBytes(&ndim, sizeof(int), 4);
}
compute_encoder->dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
} // namespace
@@ -428,7 +428,7 @@ void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
}
compute_encoder.set_output_array(out, 2);
compute_encoder->dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
void ArcCos::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -451,6 +451,10 @@ void ArcTan::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "arctan");
}
// GPU evaluation of ArcTan2: forwards the inputs and the output array to
// binary_op together with the op-name string "arctan2".
void ArcTan2::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "arctan2");
}
// GPU evaluation of ArcTanh: forwards the inputs and the output array to
// unary_op together with the op-name string "arctanh".
void ArcTanh::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "arctanh");
}
@@ -519,7 +523,7 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder->setBytes(&ndim, sizeof(size_t), 5);
compute_encoder->setBytes(&axis_stride, sizeof(size_t), 6);
compute_encoder->setBytes(&axis_size, sizeof(size_t), 7);
compute_encoder->dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
}
@@ -533,6 +537,26 @@ void AsStrided::eval_gpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
// GPU evaluation of BitwiseBinary: picks the op-name string that corresponds
// to op_ and forwards it, with the inputs and output array, to binary_op.
void BitwiseBinary::eval_gpu(const std::vector<array>& inputs, array& out) {
  // Stays nullptr when op_ matches none of the enumerators below; in that
  // case nothing is dispatched.
  const char* op_name = nullptr;
  switch (op_) {
    case BitwiseBinary::And:
      op_name = "bitwise_and";
      break;
    case BitwiseBinary::Or:
      op_name = "bitwise_or";
      break;
    case BitwiseBinary::Xor:
      op_name = "bitwise_xor";
      break;
    case BitwiseBinary::LeftShift:
      op_name = "left_shift";
      break;
    case BitwiseBinary::RightShift:
      op_name = "right_shift";
      break;
  }
  if (op_name != nullptr) {
    binary_op(inputs, out, op_name);
  }
}
// GPU evaluation of Broadcast: simply delegates to eval(inputs, out).
// NOTE(review): eval is defined elsewhere; presumably it is the shared
// (non-GPU-specific) implementation -- confirm.
void Broadcast::eval_gpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
@@ -810,7 +834,7 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
keys.strides().data(), keys.ndim() * sizeof(size_t), 6);
}
compute_encoder->dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
void Reshape::eval_gpu(const std::vector<array>& inputs, array& out) {

View File

@@ -65,7 +65,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder->setBytes(&D, sizeof(int), 5);
compute_encoder->setBytes(&O, sizeof(int), 6);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
}
// Route to the qmv kernel
@@ -92,7 +92,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder->setBytes(&D, sizeof(int), 5);
compute_encoder->setBytes(&O, sizeof(int), 6);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
}
// Route to the qmm_t kernel
@@ -123,7 +123,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder->setBytes(&O, sizeof(int), 6);
compute_encoder->setBytes(&D, sizeof(int), 7);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
}
} else {
// Route to the qvm kernel
@@ -150,7 +150,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder->setBytes(&D, sizeof(int), 5);
compute_encoder->setBytes(&O, sizeof(int), 6);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
}
// Route to the qmm_n kernel
@@ -188,7 +188,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder->setBytes(&O, sizeof(int), 6);
compute_encoder->setBytes(&D, sizeof(int), 7);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
}
}

View File

@@ -74,7 +74,7 @@ void all_reduce_dispatch(
compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(out, 1);
compute_encoder->setBytes(&in_size, sizeof(size_t), 2);
compute_encoder->dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
} else {
// Allocate intermediate array to store partial reduction results
@@ -88,7 +88,7 @@ void all_reduce_dispatch(
compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(intermediate, 1);
compute_encoder->setBytes(&in_size, sizeof(size_t), 2);
compute_encoder->dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
// Second pass to reduce intermediate reduction results written to DRAM
compute_encoder.set_input_array(intermediate, 0);
@@ -108,7 +108,7 @@ void all_reduce_dispatch(
nthreads = thread_group_size;
group_dims = MTL::Size(thread_group_size, 1, 1);
grid_dims = MTL::Size(nthreads, 1, 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
d.get_command_buffer(s.index)->addCompletedHandler(
[intermediates](MTL::CommandBuffer*) mutable {
@@ -217,7 +217,7 @@ void row_reduce_general_dispatch(
compute_encoder->setBytes(
strides.data(), strides.size() * sizeof(size_t), 6);
compute_encoder->setBytes(&ndim, sizeof(int), 7);
compute_encoder->dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
} else {
// Allocate intermediate array to store partial reduction results
@@ -239,7 +239,7 @@ void row_reduce_general_dispatch(
compute_encoder->setBytes(
strides.data(), strides.size() * sizeof(size_t), 6);
compute_encoder->setBytes(&ndim, sizeof(int), 7);
compute_encoder->dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
// Set up second dispatch
reduction_size = non_row_reductions;
@@ -286,7 +286,7 @@ void row_reduce_general_dispatch(
grid_dims = MTL::Size(n_threads, out.size(), 1);
group_dims = MTL::Size(thread_group_size, 1, 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
d.get_command_buffer(s.index)->addCompletedHandler(
[intermediates](MTL::CommandBuffer*) mutable {
@@ -366,7 +366,7 @@ void strided_reduce_general_dispatch(
compute_encoder->setBytes(&non_col_ndim, sizeof(int), 11);
// Dispatch threads
compute_encoder->dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
return;
}
@@ -435,7 +435,7 @@ void strided_reduce_general_dispatch(
compute_encoder->setThreadgroupMemoryLength(
safe_divup(threadgroup_dim_x * threadgroup_dim_y * out.itemsize(), 16),
0);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
} else {
// Allocate intermediate array to store reduction results from all thread
@@ -470,7 +470,7 @@ void strided_reduce_general_dispatch(
compute_encoder->setThreadgroupMemoryLength(
safe_divup(threadgroup_dim_x * threadgroup_dim_y * out.itemsize(), 16),
0);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
// Perform second pass of reductions
// Reduce results of threadgroups along y, z from first pass, that
@@ -523,7 +523,7 @@ void strided_reduce_general_dispatch(
grid_dims = MTL::Size(n_threads, out.size(), 1);
group_dims = MTL::Size(thread_group_size, 1, 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
d.get_command_buffer(s.index)->addCompletedHandler(
[intermediates](MTL::CommandBuffer*) mutable {
@@ -585,7 +585,7 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_output_array(out, 0);
compute_encoder->dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
// Reduce

View File

@@ -83,7 +83,7 @@ void RoPE::eval_gpu(
int dim2 = in.size() / mat_size;
auto group_dims = get_block_dims(dim0, dim1, dim2);
auto grid_dims = MTL::Size(dim0, dim1, dim2);
compute_encoder->dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
} // namespace mlx::core::fast

View File

@@ -99,7 +99,7 @@ void sdpa_metal(
constexpr const uint tgroupMemorySize = 32768;
compute_encoder->setThreadgroupMemoryLength(tgroupMemorySize, 0);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
{
auto kernel_accum = d.get_kernel(kname_reduce.str());
@@ -114,7 +114,7 @@ void sdpa_metal(
MTL::Size grid_dims_reduce = MTL::Size(heads, 1, batch);
MTL::Size group_dims_reduce = MTL::Size(128, 1, 1);
compute_encoder->dispatchThreadgroups(grid_dims_reduce, group_dims_reduce);
compute_encoder.dispatchThreadgroups(grid_dims_reduce, group_dims_reduce);
d.get_command_buffer(s.index)->addCompletedHandler(
[temporaries](MTL::CommandBuffer*) mutable { temporaries.clear(); });

View File

@@ -77,7 +77,7 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
static_cast<int>(kernel->maxTotalThreadsPerThreadgroup()));
MTL::Size grid_dims = MTL::Size(thread_groups * thread_group_size, 1, 1);
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
} else {
kname << "strided_scan_";
if (reverse_) {
@@ -119,7 +119,7 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
int grid_x = (stride + elements_per_tile_x - 1) / elements_per_tile_x;
MTL::Size grid_dims = MTL::Size(grid_x * tile_x, grid_y * tile_y, 1);
MTL::Size group_dims = MTL::Size(tile_x, tile_y, 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
if (copies.size() > 0) {

View File

@@ -85,7 +85,7 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
in.data_shared_ptr() == nullptr ? out : in, 0);
compute_encoder.set_output_array(out, 1);
compute_encoder->setBytes(&axis_size, sizeof(int), 2);
compute_encoder->dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
d.get_command_buffer(s.index)->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });

View File

@@ -78,7 +78,7 @@ void single_block_sort(
MTL::Size group_dims = MTL::Size(bn, 1, 1);
MTL::Size grid_dims = MTL::Size(1, n_rows, 1);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
}
template <bool ARGSORT>
@@ -155,7 +155,7 @@ void multi_block_sort(
MTL::Size group_dims = MTL::Size(bn, 1, 1);
MTL::Size grid_dims = MTL::Size(n_blocks, n_rows, 1);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
}
// Do merges
@@ -190,7 +190,7 @@ void multi_block_sort(
MTL::Size group_dims = MTL::Size(n_blocks + 1, 1, 1);
MTL::Size grid_dims = MTL::Size(1, n_rows, 1);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
}
// Do merge
@@ -214,7 +214,7 @@ void multi_block_sort(
MTL::Size group_dims = MTL::Size(bn, 1, 1);
MTL::Size grid_dims = MTL::Size(n_blocks, n_rows, 1);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
}
}

View File

@@ -36,6 +36,7 @@ size_t get_active_memory() {
// Stub: always reports a peak memory of 0 bytes.
// NOTE(review): presumably this is the build without the metal backend -- confirm.
size_t get_peak_memory() {
return 0;
}
// Stub: intentionally does nothing.
void reset_peak_memory() {}
// Stub: always reports a cache size of 0 bytes.
size_t get_cache_memory() {
return 0;
}
@@ -49,4 +50,10 @@ void start_capture(std::string path) {}
void stop_capture() {}
void clear_cache() {}
// Never returns a value: always throws std::runtime_error, whose message
// states that device info cannot be obtained without the metal backend.
std::unordered_map<std::string, std::variant<std::string, size_t>>
device_info() {
  const char* const message =
      "[metal::device_info] Cannot get device info without metal backend";
  throw std::runtime_error(message);
}
} // namespace mlx::core::metal

View File

@@ -25,12 +25,16 @@ NO_GPU(ArcCosh)
NO_GPU(ArcSin)
NO_GPU(ArcSinh)
NO_GPU(ArcTan)
NO_GPU(ArcTan2)
NO_GPU(ArcTanh)
NO_GPU(ArgPartition)
NO_GPU(ArgReduce)
NO_GPU(ArgSort)
NO_GPU(AsType)
NO_GPU(AsStrided)
NO_GPU(BitwiseBinary)
NO_GPU(BlockMaskedMM)
NO_GPU(BlockSparseMM)
NO_GPU(Broadcast)
NO_GPU(Ceil)
NO_GPU_MULTI(Compiled)
@@ -99,7 +103,6 @@ NO_GPU(Subtract)
NO_GPU_MULTI(SVD)
NO_GPU(Tan)
NO_GPU(Tanh)
NO_GPU(BlockMaskedMM)
NO_GPU(Transpose)
NO_GPU(Inverse)

View File

@@ -45,7 +45,8 @@ bool is_binary(const Primitive& p) {
typeid(p) == typeid(LogAddExp) || typeid(p) == typeid(Maximum) ||
typeid(p) == typeid(Minimum) || typeid(p) == typeid(Multiply) ||
typeid(p) == typeid(NotEqual) || typeid(p) == typeid(Power) ||
typeid(p) == typeid(Subtract));
typeid(p) == typeid(Subtract) || typeid(p) == typeid(BitwiseBinary) ||
typeid(p) == typeid(ArcTan2));
}
bool is_ternary(const Primitive& p) {

View File

@@ -13,7 +13,7 @@ struct Device {
static constexpr DeviceType cpu = DeviceType::cpu;
static constexpr DeviceType gpu = DeviceType::gpu;
Device(DeviceType type, int index = 0) : type(type), index(index){};
Device(DeviceType type, int index = 0) : type(type), index(index) {};
DeviceType type;
int index;

View File

@@ -51,7 +51,7 @@ struct Dtype {
Val val;
const uint8_t size;
constexpr explicit Dtype(Val val, uint8_t size) : val(val), size(size){};
constexpr explicit Dtype(Val val, uint8_t size) : val(val), size(size) {};
constexpr operator Val() const {
return val;
};

View File

@@ -10,7 +10,7 @@ namespace mlx::core {
class Event {
public:
Event(){};
Event() {};
Event(const Stream& steam);

View File

@@ -12,7 +12,7 @@ class Custom : public Primitive {
explicit Custom(
Stream stream,
std::function<std::vector<array>(std::vector<array>)> fallback)
: Primitive(stream), fallback_(fallback){};
: Primitive(stream), fallback_(fallback) {};
virtual std::pair<std::vector<array>, std::vector<int>> vmap(
const std::vector<array>& inputs,
@@ -39,7 +39,7 @@ class RMSNorm : public Custom {
Stream stream,
std::function<std::vector<array>(std::vector<array>)> fallback,
float eps)
: Custom(stream, fallback), eps_(eps){};
: Custom(stream, fallback), eps_(eps) {};
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override {
@@ -68,7 +68,7 @@ class RMSNormVJP : public Custom {
Stream stream,
std::function<std::vector<array>(std::vector<array>)> fallback,
float eps)
: Custom(stream, fallback), eps_(eps){};
: Custom(stream, fallback), eps_(eps) {};
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override {
@@ -91,7 +91,7 @@ class LayerNorm : public Custom {
Stream stream,
std::function<std::vector<array>(std::vector<array>)> fallback,
float eps)
: Custom(stream, fallback), eps_(eps){};
: Custom(stream, fallback), eps_(eps) {};
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override {
@@ -120,7 +120,7 @@ class LayerNormVJP : public Custom {
Stream stream,
std::function<std::vector<array>(std::vector<array>)> fallback,
float eps)
: Custom(stream, fallback), eps_(eps){};
: Custom(stream, fallback), eps_(eps) {};
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override {
@@ -154,7 +154,7 @@ class RoPE : public Custom {
base_(base),
scale_(scale),
offset_(offset),
forward_(forward){};
forward_(forward) {};
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override {
@@ -189,7 +189,7 @@ class ScaledDotProductAttention : public Custom {
std::function<std::vector<array>(std::vector<array>)> fallback,
const float scale,
const bool needs_mask)
: Custom(stream, fallback), scale_(scale), needs_mask_(needs_mask){};
: Custom(stream, fallback), scale_(scale), needs_mask_(needs_mask) {};
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override {
@@ -204,7 +204,7 @@ class ScaledDotProductAttention : public Custom {
void eval_gpu(const std::vector<array>& inputs, array& out);
bool is_equivalent(const Primitive& other) const override;
DEFINE_PRINT(ScaledDotProductAttention)
DEFINE_PRINT(ScaledDotProductAttention);
private:
std::function<std::vector<array>(std::vector<array>)> fallback_;

View File

@@ -6,8 +6,8 @@
#include "mlx/array.h"
#include "mlx/io/load.h"
#include "mlx/ops.h"
#include "mlx/stream.h"
#include "mlx/utils.h"
namespace mlx::core {
using GGUFMetaData =

View File

@@ -1,33 +1,58 @@
target_sources(
mlx
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
${CMAKE_CURRENT_SOURCE_DIR}/safetensor.cpp
${CMAKE_CURRENT_SOURCE_DIR}/gguf.cpp
${CMAKE_CURRENT_SOURCE_DIR}/gguf_quants.cpp
)
MESSAGE(STATUS "Downloading json")
FetchContent_Declare(json URL https://github.com/nlohmann/json/releases/download/v3.11.3/json.tar.xz)
FetchContent_MakeAvailable(json)
target_include_directories(
mlx PRIVATE
$<BUILD_INTERFACE:${json_SOURCE_DIR}/single_include/nlohmann>
)
if (MLX_BUILD_SAFETENSORS)
MESSAGE(STATUS "Downloading json")
FetchContent_Declare(json URL https://github.com/nlohmann/json/releases/download/v3.11.3/json.tar.xz)
FetchContent_MakeAvailable(json)
target_include_directories(
mlx PRIVATE
$<BUILD_INTERFACE:${json_SOURCE_DIR}/single_include/nlohmann>
)
target_sources(
mlx
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/safetensors.cpp
)
else()
target_sources(
mlx
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/no_safetensors.cpp
)
endif()
MESSAGE(STATUS "Downloading gguflib")
FetchContent_Declare(gguflib
GIT_REPOSITORY https://github.com/antirez/gguf-tools/
GIT_TAG af7d88d808a7608a33723fba067036202910acb3
)
FetchContent_MakeAvailable(gguflib)
target_include_directories(
mlx PRIVATE
$<BUILD_INTERFACE:${gguflib_SOURCE_DIR}>
)
if (MLX_BUILD_GGUF)
MESSAGE(STATUS "Downloading gguflib")
FetchContent_Declare(gguflib
GIT_REPOSITORY https://github.com/antirez/gguf-tools/
GIT_TAG af7d88d808a7608a33723fba067036202910acb3
)
FetchContent_MakeAvailable(gguflib)
target_include_directories(
mlx PRIVATE
$<BUILD_INTERFACE:${gguflib_SOURCE_DIR}>
)
add_library(
gguflib STATIC
${gguflib_SOURCE_DIR}/fp16.c
${gguflib_SOURCE_DIR}/gguflib.c)
target_link_libraries(mlx $<BUILD_INTERFACE:gguflib>)
target_sources(
mlx
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/gguf.cpp
${CMAKE_CURRENT_SOURCE_DIR}/gguf_quants.cpp
)
else()
target_sources(
mlx
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/no_gguf.cpp
)
endif()
add_library(
gguflib STATIC
${gguflib_SOURCE_DIR}/fp16.c
${gguflib_SOURCE_DIR}/gguflib.c)
target_link_libraries(mlx $<BUILD_INTERFACE:gguflib>)

View File

@@ -4,7 +4,8 @@
#include <cstring>
#include <numeric>
#include <mlx/io/gguf.h>
#include "mlx/io/gguf.h"
#include "mlx/ops.h"
namespace mlx::core {

View File

@@ -4,7 +4,7 @@
#include <cstring>
#include <numeric>
#include <mlx/io/gguf.h>
#include "mlx/io/gguf.h"
namespace mlx::core {

20
mlx/io/no_gguf.cpp Normal file
View File

@@ -0,0 +1,20 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/io.h"
namespace mlx::core {

// Stand-ins for the GGUF entry points. Both unconditionally throw
// std::runtime_error whose message asks for a build with MLX_BUILD_GGUF=ON.
// The arguments are never inspected, so every parameter is left unnamed.

// Always throws; no file is opened or read.
GGUFLoad load_gguf(const std::string&, StreamOrDevice) {
  throw std::runtime_error(
      "[load_gguf] Compile with MLX_BUILD_GGUF=ON to enable GGUF support.");
}

// Always throws; nothing is written.
void save_gguf(
    std::string,
    std::unordered_map<std::string, array>,
    std::unordered_map<std::string, GGUFMetaData>) {
  throw std::runtime_error(
      "[save_gguf] Compile with MLX_BUILD_GGUF=ON to enable GGUF support.");
}

} // namespace mlx::core

37
mlx/io/no_safetensors.cpp Normal file
View File

@@ -0,0 +1,37 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/io.h"
namespace mlx::core {

// Stand-ins for the safetensors entry points. Every overload unconditionally
// throws std::runtime_error whose message asks for a build with
// MLX_BUILD_SAFETENSORS=ON. The arguments are never inspected, so every
// parameter is left unnamed.

// Reader-based load: always throws.
SafetensorsLoad load_safetensors(std::shared_ptr<io::Reader>, StreamOrDevice) {
  throw std::runtime_error(
      "[load_safetensors] Compile with MLX_BUILD_SAFETENSORS=ON "
      "to enable safetensors support.");
}

// Path-based load: always throws.
SafetensorsLoad load_safetensors(const std::string&, StreamOrDevice) {
  throw std::runtime_error(
      "[load_safetensors] Compile with MLX_BUILD_SAFETENSORS=ON "
      "to enable safetensors support.");
}

// Writer-based save: always throws.
void save_safetensors(
    std::shared_ptr<io::Writer>,
    std::unordered_map<std::string, array>,
    std::unordered_map<std::string, std::string>) {
  throw std::runtime_error(
      "[save_safetensors] Compile with MLX_BUILD_SAFETENSORS=ON "
      "to enable safetensors support.");
}

// Path-based save: always throws.
void save_safetensors(
    std::string,
    std::unordered_map<std::string, array>,
    std::unordered_map<std::string, std::string>) {
  throw std::runtime_error(
      "[save_safetensors] Compile with MLX_BUILD_SAFETENSORS=ON "
      "to enable safetensors support.");
}

} // namespace mlx::core

View File

@@ -5,6 +5,7 @@
#include "mlx/io.h"
#include "mlx/io/load.h"
#include "mlx/ops.h"
#include "mlx/primitives.h"
using json = nlohmann::json;
@@ -149,7 +150,6 @@ SafetensorsLoad load_safetensors(const std::string& file, StreamOrDevice s) {
return load_safetensors(std::make_shared<io::FileReader>(file), s);
}
/** Save array to out stream in .npy format */
void save_safetensors(
std::shared_ptr<io::Writer> out_stream,
std::unordered_map<std::string, array> a,

View File

@@ -320,7 +320,7 @@ array reshape(
"[reshape] Cannot infer the shape of an empty array");
}
// Check the the reshaping is valid
// Check that the reshaping is valid
if (a.size() != size) {
std::ostringstream msg;
msg << "[reshape] Cannot reshape array of size " << a.size()
@@ -426,7 +426,7 @@ array expand_dims(const array& a, int axis, StreamOrDevice s /* = {} */) {
int ax = axis < 0 ? axis + out_dim : axis;
if (ax < 0 || ax >= out_dim) {
std::ostringstream msg;
msg << "[expand_dims] Invalid axes " << axis << " for output array with "
msg << "[expand_dims] Invalid axis " << axis << " for output array with "
<< a.ndim() << " dimensions.";
throw std::invalid_argument(msg.str());
}
@@ -452,7 +452,7 @@ array expand_dims(
ax = ax < 0 ? ax + out_ndim : ax;
if (ax < 0 || ax >= out_ndim) {
std::ostringstream msg;
msg << "[expand_dims] Invalid axes " << ax << " for output array with "
msg << "[expand_dims] Invalid axis " << ax << " for output array with "
<< a.ndim() << " dimensions.";
throw std::invalid_argument(msg.str());
}
@@ -591,7 +591,6 @@ array slice_update(
if (!has_neg_strides && upd_shape == src.shape()) {
return astype(update_broadcasted, src.dtype(), s);
}
return array(
src.shape(),
src.dtype(),
@@ -2153,6 +2152,14 @@ array arctan(const array& a, StreamOrDevice s /* = {} */) {
a.shape(), dtype, std::make_shared<ArcTan>(to_stream(s)), {input});
}
// Constructs the output array for arctan2 of `a` and `b`, backed by an
// ArcTan2 primitive on the stream resolved from `s`. Both operands are first
// cast to a common dtype (the promoted type passed through at_least_float)
// and then broadcast against each other.
array arctan2(const array& a, const array& b, StreamOrDevice s /* = {} */) {
auto dtype = at_least_float(promote_types(a.dtype(), b.dtype()));
auto inputs = broadcast_arrays(astype(a, dtype, s), astype(b, dtype, s), s);
// `shape` is a reference obtained from inputs[0], while `inputs` itself is
// moved from in the same constructor call below.
// NOTE(review): this relies on the referenced shape staying valid across the
// move of the vector -- confirm against the array constructor's signature.
auto& shape = inputs[0].shape();
return array(
shape, dtype, std::make_shared<ArcTan2>(to_stream(s)), std::move(inputs));
}
array sinh(const array& a, StreamOrDevice s /* = {} */) {
auto dtype = at_least_float(a.dtype());
auto input = astype(a, dtype, s);
@@ -2947,7 +2954,8 @@ inline std::vector<int> conv_out_shape(
return out_shape;
}
inline void run_conv_checks(const array& in, const array& wt, int n_dim) {
inline void
run_conv_checks(const array& in, const array& wt, int n_dim, int groups) {
if (!issubdtype(in.dtype(), floating)) {
std::ostringstream msg;
msg << "[conv] Invalid input array with type " << in.dtype() << "."
@@ -2972,11 +2980,35 @@ inline void run_conv_checks(const array& in, const array& wt, int n_dim) {
throw std::invalid_argument(msg.str());
}
if (in.shape(n_dim + 1) != wt.shape(n_dim + 1)) {
if (in.shape(n_dim + 1) % groups != 0) {
std::ostringstream msg;
msg << "[conv] Expect the input channels in the input"
<< " and weight array to match but got shapes -"
<< " input: " << in.shape() << " and weight: " << wt.shape();
msg << "[conv] The input channels must be divisible by the number"
<< " of groups. Got input with shape " << in.shape() << " and "
<< groups << " groups.";
throw std::invalid_argument(msg.str());
}
if (groups > 1 && wt.shape(0) % groups != 0) {
std::ostringstream msg;
msg << "[conv] If groups > 1, the output channels must be divisible by the number"
<< " of groups. Got " << wt.shape(0) << " output channels and "
<< groups << " groups.";
throw std::invalid_argument(msg.str());
}
if (in.shape(n_dim + 1) != (groups * wt.shape(n_dim + 1))) {
std::ostringstream msg;
if (groups == 1) {
msg << "[conv] Expect the input channels in the input"
<< " and weight array to match but got shapes -"
<< " input: " << in.shape() << " and weight: " << wt.shape();
} else {
msg << "Given groups=" << groups << " and weights of shape " << wt.shape()
<< ", expected to have " << (groups * wt.shape(n_dim + 1))
<< " input channels but got " << in.shape(n_dim + 1)
<< " input channels instead.";
}
throw std::invalid_argument(msg.str());
}
}
@@ -3039,8 +3071,9 @@ array conv_general(
bool flip /* = false */,
StreamOrDevice s /* = {} */) {
// Run checks
if (groups != 1) {
throw std::invalid_argument("[conv] Cannot handle groups != 1 yet");
if (groups != 1 && in.ndim() != 3) {
throw std::invalid_argument(
"[conv] Can only handle groups != 1 in 1D convolutions.");
}
int spatial_dims = in.ndim() - 2;
@@ -3052,7 +3085,7 @@ array conv_general(
}
// Run checks
run_conv_checks(in, wt, spatial_dims);
run_conv_checks(in, wt, spatial_dims, groups);
// Type promotion
auto out_type = promote_types(in.dtype(), wt.dtype());
@@ -3250,7 +3283,9 @@ std::tuple<array, array, array> quantize(
}
// Compute some constants used for the quantization
int n_bins = (1 << bits) - 1; // 2**bits - 1
array n_bins((1 << bits) - 1, w.dtype()); // 2**bits - 1
array eps(1e-7, w.dtype());
array zero(0, w.dtype());
int el_per_int = 32 / bits;
array shifts = power(array(2, uint32), arange(0, 32, bits, uint32, s), s);
shifts = reshape(shifts, {1, 1, -1}, s);
@@ -3274,24 +3309,28 @@ std::tuple<array, array, array> quantize(
reshape(w, {w.shape(0), w.shape(1) / group_size, group_size}, s);
array w_max = max(packed_w, /* axis= */ -1, /* keepdims= */ true, s);
array w_min = min(packed_w, /* axis= */ -1, /* keepdims= */ true, s);
array delta = maximum(
divide(subtract(w_max, w_min, s), array(n_bins, w.dtype()), s),
array(1e-7, w.dtype()),
s);
array scales = squeeze(delta, -1, s);
array biases = squeeze(w_min, -1, s);
// making sure that 0 is represented exactly in the resulting quantization
biases = multiply(round(divide(biases, scales, s), s), scales, s);
array mask = greater(abs(w_min, s), abs(w_max, s), s);
array scales = maximum(divide(subtract(w_max, w_min, s), n_bins, s), eps, s);
scales = where(mask, scales, negative(scales), s);
array edge = where(mask, w_min, w_max, s);
array q0 = round(divide(edge, scales, s), s);
scales = where(not_equal(q0, zero, s), divide(edge, q0, s), scales);
array biases = where(equal(q0, zero, s), zero, edge);
// Quantize and pack w
packed_w =
astype(round(divide(subtract(packed_w, w_min, s), delta, s), s), uint32);
packed_w = astype(
clip(
round(divide(subtract(packed_w, biases, s), scales, s), s),
zero,
n_bins),
uint32);
packed_w = reshape(packed_w, {w.shape(0), -1, el_per_int}, s);
packed_w = sum(
multiply(packed_w, shifts, s), /* axis= */ 2, /* keepdims= */ false, s);
return std::make_tuple(packed_w, scales, biases);
return std::make_tuple(
packed_w, squeeze(scales, -1, s), squeeze(biases, -1, s));
}
array dequantize(
@@ -3754,6 +3793,124 @@ array block_masked_mm(
return out;
}
/**
 * Compute matrix product with matrix-level gather.
 *
 * When neither `lhs_indices_` nor `rhs_indices_` is provided this is exactly
 * `matmul(a, b, s)`. Otherwise a missing index array defaults to an `arange`
 * over the product of that operand's leading (non-matrix) dimensions,
 * reshaped to those dimensions. Both index arrays are cast to uint32 and
 * broadcast against each other; the broadcast shape followed by `{M, N}` is
 * the output shape handed to the `BlockSparseMM` primitive.
 *
 * 1-D operands are reshaped to matrices first (`a` -> {1, -1},
 * `b` -> {-1, 1}) and the inserted singleton dimension is removed from the
 * result again.
 *
 * Throws std::invalid_argument when an input has 0 dimensions, when the
 * contracted dimensions differ, when the promoted type is not a subdtype of
 * `floating`, or when an index array is not a subdtype of `integer`.
 */
array block_sparse_mm(
    array a,
    array b,
    std::optional<array> lhs_indices_ /* = std::nullopt */,
    std::optional<array> rhs_indices_ /* = std::nullopt */,
    StreamOrDevice s /* = {} */) {
  // If no indices, fall back to full matmul
  if (!lhs_indices_ && !rhs_indices_) {
    return matmul(a, b, s);
  }

  // Remember the caller-visible ranks; `a` and `b` may be reshaped below.
  int in_a_ndim = a.ndim();
  int in_b_ndim = b.ndim();

  // Do shape checks for operands
  if (a.ndim() == 0 || b.ndim() == 0) {
    throw std::invalid_argument(
        "[block_sparse_mm] Got 0 dimension input. Inputs must "
        "have at least one dimension.");
  }

  if (a.ndim() == 1) {
    // Insert a singleton dim in the beginning
    a = reshape(a, {1, -1}, s);
  }
  if (b.ndim() == 1) {
    // Insert a singleton dim at the end
    b = reshape(b, {-1, 1}, s);
  }

  if (a.shape(-1) != b.shape(-2)) {
    std::ostringstream msg;
    msg << "[block_sparse_mm] Last dimension of first input with shape "
        << a.shape() << " must match second to last dimension of"
        << " second input with shape " << b.shape() << ".";
    throw std::invalid_argument(msg.str());
  }

  // Type promotion
  auto out_type = result_type(a, b);
  if (!issubdtype(out_type, floating)) {
    std::ostringstream msg;
    msg << "[block_sparse_mm] Only real floating point types are supported but "
        << a.dtype() << " and " << b.dtype()
        << " were provided which results in " << out_type
        << ", which is not a real floating point type.";
    throw std::invalid_argument(msg.str());
  }

  a = astype(a, out_type, s);
  b = astype(b, out_type, s);

  // Handle broadcasting: everything before the trailing two (matrix)
  // dimensions is treated as the batch shape of each operand.
  std::vector<int> bsx_a(a.shape().begin(), a.shape().end() - 2);
  std::vector<int> bsx_b(b.shape().begin(), b.shape().end() - 2);

  // Use the supplied indices, or enumerate every matrix of the operand.
  auto indices_or_default = [&](const std::optional<array>& indices,
                                const std::vector<int>& bsx_shape) {
    if (indices.has_value()) {
      return indices.value();
    } else {
      int n_batch = 1;
      for (auto& i : bsx_shape)
        n_batch *= i;
      return reshape(arange(n_batch, uint32, s), bsx_shape, s);
    }
  };

  // Pull and broadcast indices
  array lhs_indices = indices_or_default(lhs_indices_, bsx_a);
  array rhs_indices = indices_or_default(rhs_indices_, bsx_b);

  if (!issubdtype(lhs_indices.dtype(), integer)) {
    throw std::invalid_argument(
        "[block_sparse_mm] Got lhs_indices with invalid dtype. Indices must be integral.");
  }

  if (!issubdtype(rhs_indices.dtype(), integer)) {
    throw std::invalid_argument(
        "[block_sparse_mm] Got rhs_indices with invalid dtype. Indices must be integral.");
  }

  lhs_indices = astype(lhs_indices, uint32, s);
  rhs_indices = astype(rhs_indices, uint32, s);

  int M = a.shape(-2);
  int N = b.shape(-1);

  auto out_bsx_shape =
      broadcast_shapes(lhs_indices.shape(), rhs_indices.shape());
  lhs_indices = broadcast_to(lhs_indices, out_bsx_shape, s);
  rhs_indices = broadcast_to(rhs_indices, out_bsx_shape, s);

  auto out_shape = out_bsx_shape;
  out_shape.push_back(M);
  out_shape.push_back(N);

  // Calculate array
  auto out = array(
      out_shape,
      out_type,
      std::make_shared<BlockSparseMM>(to_stream(s)),
      {a, b, lhs_indices, rhs_indices});

  // Remove the possibly inserted singleton dimensions: the erased range is
  // [M] when only `a` was 1-D, [N] when only `b` was 1-D, and [M, N] when
  // both were.
  if (in_a_ndim == 1 || in_b_ndim == 1) {
    out_shape.erase(
        out_shape.end() - ((in_a_ndim == 1) ? 2 : 1),
        out_shape.end() - ((in_b_ndim == 1) ? 0 : 1));
    out = reshape(out, out_shape, s);
  }

  return out;
}
array diagonal(
const array& a,
int offset /* = 0 */,
@@ -3944,4 +4101,77 @@ array number_of_elements(
{a}));
}
/**
 * Shared implementation of the binary bitwise operations.
 *
 * Both operands are cast to the promoted type of `a` and `b` and broadcast
 * against each other before the `BitwiseBinary` primitive selected by `op`
 * is attached. `op_name` is only used as the prefix of the error message.
 *
 * Throws std::runtime_error when the promoted type is neither a subdtype of
 * `integer` nor `bool_`.
 */
array bitwise_impl(
    const array& a,
    const array& b,
    BitwiseBinary::Op op,
    const std::string& op_name,
    const StreamOrDevice& s) {
  auto out_type = promote_types(a.dtype(), b.dtype());
  if (!(issubdtype(out_type, integer) || out_type == bool_)) {
    std::ostringstream msg;
    msg << "[" << op_name
        << "] Only allowed on integer or boolean types "
           "but got types "
        << a.dtype() << " and " << b.dtype() << ".";
    throw std::runtime_error(msg.str());
  }
  auto inputs =
      broadcast_arrays(astype(a, out_type, s), astype(b, out_type, s), s);
  // The result must carry the broadcast shape, not the shape of `a`: when
  // `b` has more (or larger) dimensions, `a.shape()` would disagree with the
  // inputs given to the primitive. Copy the shape out before `inputs` is
  // moved, since argument evaluation order is unspecified.
  auto out_shape = inputs[0].shape();
  return array(
      std::move(out_shape),
      out_type,
      std::make_shared<BitwiseBinary>(to_stream(s), op),
      std::move(inputs));
}
/** Bitwise and of `a` and `b`, delegated to bitwise_impl. */
array bitwise_and(const array& a, const array& b, StreamOrDevice s /* = {} */) {
  const auto op = BitwiseBinary::Op::And;
  return bitwise_impl(a, b, op, "bitwise_and", s);
}
/** Operator spelling of bitwise_and(); the stream argument keeps its default. */
array operator&(const array& a, const array& b) {
  array result = bitwise_and(a, b);
  return result;
}
/** Bitwise inclusive or of `a` and `b`, delegated to bitwise_impl. */
array bitwise_or(const array& a, const array& b, StreamOrDevice s /* = {} */) {
  const auto op = BitwiseBinary::Op::Or;
  return bitwise_impl(a, b, op, "bitwise_or", s);
}
/** Operator spelling of bitwise_or(); the stream argument keeps its default. */
array operator|(const array& a, const array& b) {
  array result = bitwise_or(a, b);
  return result;
}
/** Bitwise exclusive or of `a` and `b`, delegated to bitwise_impl. */
array bitwise_xor(const array& a, const array& b, StreamOrDevice s /* = {} */) {
  const auto op = BitwiseBinary::Op::Xor;
  return bitwise_impl(a, b, op, "bitwise_xor", s);
}
/** Operator spelling of bitwise_xor(); the stream argument keeps its default. */
array operator^(const array& a, const array& b) {
  array result = bitwise_xor(a, b);
  return result;
}
/** Shift bits to the left; both operands are cast to a common type first. */
array left_shift(const array& a, const array& b, StreamOrDevice s /* = {} */) {
  // Promoting against uint8 means a bool result type is up-cast to uint8.
  const auto shift_type = promote_types(result_type(a, b), uint8);
  auto lhs = astype(a, shift_type, s);
  auto rhs = astype(b, shift_type, s);
  return bitwise_impl(lhs, rhs, BitwiseBinary::Op::LeftShift, "left_shift", s);
}
/** Operator spelling of left_shift(); the stream argument keeps its default. */
array operator<<(const array& a, const array& b) {
  array result = left_shift(a, b);
  return result;
}
/** Shift bits to the right; both operands are cast to a common type first. */
array right_shift(const array& a, const array& b, StreamOrDevice s /* = {} */) {
  // Promoting against uint8 means a bool result type is up-cast to uint8.
  const auto shift_type = promote_types(result_type(a, b), uint8);
  auto lhs = astype(a, shift_type, s);
  auto rhs = astype(b, shift_type, s);
  return bitwise_impl(
      lhs, rhs, BitwiseBinary::Op::RightShift, "right_shift", s);
}
/** Operator spelling of right_shift(); the stream argument keeps its default. */
array operator>>(const array& a, const array& b) {
  array result = right_shift(a, b);
  return result;
}
} // namespace mlx::core

View File

@@ -11,7 +11,10 @@
namespace mlx::core {
/** Creation operations */
/**
* \defgroup ops Core array operations
* @{
*/
/**
* A 1D array of numbers starting at `start` (optional),
@@ -115,8 +118,6 @@ inline array tri(int n, Dtype type, StreamOrDevice s = {}) {
array tril(array x, int k = 0, StreamOrDevice s = {});
array triu(array x, int k = 0, StreamOrDevice s = {});
/** array manipulation */
/** Reshape an array to the given shape. */
array reshape(const array& a, std::vector<int> shape, StreamOrDevice s = {});
@@ -289,8 +290,6 @@ std::vector<array> broadcast_arrays(
const std::vector<array>& inputs,
StreamOrDevice s = {});
/** Comparison operations */
/** Returns the bool array with (a == b) element-wise. */
array equal(const array& a, const array& b, StreamOrDevice s = {});
inline array operator==(const array& a, const array& b) {
@@ -401,8 +400,6 @@ array where(
const array& y,
StreamOrDevice s = {});
/** Reduction operations */
/** True if all elements in the array are true (or non-zero). **/
array all(const array& a, bool keepdims, StreamOrDevice s = {});
inline array all(const array& a, StreamOrDevice s = {}) {
@@ -710,8 +707,6 @@ array logsumexp(
bool keepdims = false,
StreamOrDevice s = {});
/** Simple arithmetic operations */
/** Absolute value of elements in an array. */
array abs(const array& a, StreamOrDevice s = {});
@@ -833,6 +828,9 @@ array arccos(const array& a, StreamOrDevice s = {});
/** Arc Tangent of the elements of an array */
array arctan(const array& a, StreamOrDevice s = {});
/** Inverse tangent of the ratio of two arrays */
array arctan2(const array& a, const array& b, StreamOrDevice s = {});
/** Hyperbolic Sine of the elements of an array */
array sinh(const array& a, StreamOrDevice s = {});
@@ -1032,17 +1030,6 @@ softmax(const array& a, int axis, bool precise = false, StreamOrDevice s = {}) {
/** Raise elements of a to the power of b element-wise */
array power(const array& a, const array& b, StreamOrDevice s = {});
inline array operator^(const array& a, const array& b) {
return power(a, b);
}
template <typename T>
array operator^(T a, const array& b) {
return power(array(a), b);
}
template <typename T>
array operator^(const array& a, T b) {
return power(a, array(b));
}
/** Cumulative sum of an array. */
array cumsum(
@@ -1076,8 +1063,6 @@ array cummin(
bool inclusive = true,
StreamOrDevice s = {});
/** Convolution operations */
/** General convolution with a filter */
array conv_general(
array input,
@@ -1201,6 +1186,14 @@ array block_masked_mm(
std::optional<array> mask_rhs = std::nullopt,
StreamOrDevice s = {});
/**
 * Compute matrix product with matrix-level gather.
 *
 * `lhs_indices` and `rhs_indices` are optional index arrays for `a` and `b`
 * respectively; they must have an integral dtype. When both are omitted the
 * call is equivalent to `matmul(a, b, s)`.
 *
 * Throws std::invalid_argument for 0-dimensional inputs, mismatched inner
 * dimensions, a non-floating-point result type, or non-integral indices.
 */
array block_sparse_mm(
array a,
array b,
std::optional<array> lhs_indices = std::nullopt,
std::optional<array> rhs_indices = std::nullopt,
StreamOrDevice s = {});
/** Extract a diagonal or construct a diagonal array */
array diagonal(
const array& a,
@@ -1246,4 +1239,26 @@ array number_of_elements(
Dtype dtype = int32,
StreamOrDevice s = {});
/**
 * Bitwise and.
 *
 * Only integer and boolean types are accepted; other types raise
 * std::runtime_error.
 */
array bitwise_and(const array& a, const array& b, StreamOrDevice s = {});
/** Operator form of bitwise_and() with the default stream argument. */
array operator&(const array& a, const array& b);
/**
 * Bitwise inclusive or.
 *
 * Only integer and boolean types are accepted; other types raise
 * std::runtime_error.
 */
array bitwise_or(const array& a, const array& b, StreamOrDevice s = {});
/** Operator form of bitwise_or() with the default stream argument. */
array operator|(const array& a, const array& b);
/**
 * Bitwise exclusive or.
 *
 * Only integer and boolean types are accepted; other types raise
 * std::runtime_error.
 */
array bitwise_xor(const array& a, const array& b, StreamOrDevice s = {});
/** Operator form of bitwise_xor() with the default stream argument. */
array operator^(const array& a, const array& b);
/**
 * Shift bits to the left.
 *
 * The operands are cast to their common result type promoted with uint8, so
 * a shift on bool inputs is up-cast to uint8.
 */
array left_shift(const array& a, const array& b, StreamOrDevice s = {});
/** Operator form of left_shift() with the default stream argument. */
array operator<<(const array& a, const array& b);
/**
 * Shift bits to the right.
 *
 * The operands are cast to their common result type promoted with uint8, so
 * a shift on bool inputs is up-cast to uint8.
 */
array right_shift(const array& a, const array& b, StreamOrDevice s = {});
/** Operator form of right_shift() with the default stream argument. */
array operator>>(const array& a, const array& b);
/** @} */
} // namespace mlx::core

Some files were not shown because too many files have changed in this diff Show More