Merge branch 'ml-explore:main' into adding-Muon-optimizer

This commit is contained in:
Gökdeniz Gülmez
2025-04-21 20:27:33 +02:00
committed by GitHub
166 changed files with 7312 additions and 2043 deletions

View File

@@ -24,8 +24,8 @@ jobs:
type: boolean
default: false
macos:
xcode: "15.2.0"
resource_class: macos.m1.medium.gen1
xcode: "16.2.0"
resource_class: m2pro.medium
steps:
- checkout
- run:
@@ -89,15 +89,14 @@ jobs:
pip install numpy
sudo apt-get update
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
sudo apt-get install openmpi-bin openmpi-common libopenmpi-dev
- run:
name: Install Python package
command: |
CMAKE_ARGS="-DMLX_BUILD_METAL=OFF
CMAKE_COMPILE_WARNING_AS_ERROR=ON" \
CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" \
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
python3 setup.py build_ext --inplace
CMAKE_ARGS="-DMLX_BUILD_METAL=OFF \
CMAKE_COMPILE_WARNING_AS_ERROR=ON" \
CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" \
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
python3 setup.py develop
- run:
@@ -110,6 +109,8 @@ jobs:
name: Run Python tests
command: |
python3 -m unittest discover python/tests -v
mpirun --bind-to none -host localhost:8 -np 8 python python/tests/mpi_test_distributed.py
mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py
- run:
name: Build CPP only
command: |
@@ -124,10 +125,15 @@ jobs:
parameters:
xcode_version:
type: string
default: "15.2.0"
default: "16.2.0"
macosx_deployment_target:
type: string
default: ""
macos:
xcode: << parameters.xcode_version >>
resource_class: macos.m1.medium.gen1
environment:
MACOSX_DEPLOYMENT_TARGET: << parameters.macosx_deployment_target >>
resource_class: m2pro.medium
steps:
- checkout
- run:
@@ -149,7 +155,7 @@ jobs:
command: |
source env/bin/activate
DEBUG=1 CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
CMAKE_ARGS="CMAKE_COMPILE_WARNING_AS_ERROR=ON" \
CMAKE_ARGS="-DCMAKE_COMPILE_WARNING_AS_ERROR=ON" \
pip install -e . -v
- run:
name: Generate package stubs
@@ -213,13 +219,18 @@ jobs:
default: "3.9"
xcode_version:
type: string
default: "15.2.0"
default: "16.2.0"
build_env:
type: string
default: ""
macosx_deployment_target:
type: string
default: ""
macos:
xcode: << parameters.xcode_version >>
resource_class: macos.m1.medium.gen1
resource_class: m2pro.medium
environment:
MACOSX_DEPLOYMENT_TARGET: << parameters.macosx_deployment_target >>
steps:
- checkout
- run:
@@ -240,7 +251,7 @@ jobs:
name: Install Python package
command: |
source env/bin/activate
DEV_RELEASE=1 \
env -u MACOSX_DEPLOYMENT_TARGET DEV_RELEASE=1 \
CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
pip install . -v
- run:
@@ -335,7 +346,7 @@ workflows:
- mac_build_and_test:
matrix:
parameters:
xcode_version: ["15.0.0", "15.2.0", "16.0.0"]
macosx_deployment_target: ["13.5", "14.0"]
- linux_build_and_test
- build_documentation
@@ -355,8 +366,70 @@ workflows:
matrix:
parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
xcode_version: ["15.0.0", "15.2.0"]
macosx_deployment_target: ["13.5", "14.0", "15.0"]
build_env: ["PYPI_RELEASE=1"]
xcode_version: ["16.2.0", "15.0.0"]
exclude:
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.9"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.10"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.11"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.12"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.13"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.9"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.10"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.11"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.12"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.13"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.9"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.10"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.11"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.12"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.13"
build_env: "PYPI_RELEASE=1"
- build_documentation:
filters:
tags:
@@ -379,7 +452,7 @@ workflows:
requires: [ hold ]
matrix:
parameters:
xcode_version: ["15.0.0", "15.2.0", "16.0.0"]
macosx_deployment_target: ["13.5", "14.0"]
- linux_build_and_test:
requires: [ hold ]
nightly_build:
@@ -392,7 +465,54 @@ workflows:
matrix:
parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
xcode_version: ["15.0.0", "15.2.0"]
macosx_deployment_target: ["13.5", "14.0", "15.0"]
xcode_version: ["16.2.0", "15.0.0"]
exclude:
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.9"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.10"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.11"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.12"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.13"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.9"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.10"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.11"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.12"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.13"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.9"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.10"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.11"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.12"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.13"
weekly_build:
when:
and:
@@ -403,8 +523,70 @@ workflows:
matrix:
parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
xcode_version: ["15.0.0", "15.2.0", "16.0.0"]
macosx_deployment_target: ["13.5", "14.0", "15.0"]
build_env: ["DEV_RELEASE=1"]
xcode_version: ["16.2.0", "15.0.0"]
exclude:
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.9"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.10"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.11"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.12"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.13"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.9"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.10"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.11"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.12"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.13"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.9"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.10"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.11"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.12"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.13"
build_env: "DEV_RELEASE=1"
linux_test_release:
when:
and:

View File

@@ -212,24 +212,6 @@ else()
set(MLX_BUILD_ACCELERATE OFF)
endif()
find_package(MPI)
if(MPI_FOUND)
execute_process(
COMMAND zsh "-c" "mpirun --version"
OUTPUT_VARIABLE MPI_VERSION
ERROR_QUIET)
if(${MPI_VERSION} MATCHES ".*Open MPI.*")
target_include_directories(mlx PRIVATE ${MPI_INCLUDE_PATH})
elseif(MPI_VERSION STREQUAL "")
set(MPI_FOUND FALSE)
message(
WARNING "MPI found but mpirun is not available. Building without MPI.")
else()
set(MPI_FOUND FALSE)
message(WARNING "MPI which is not OpenMPI found. Building without MPI.")
endif()
endif()
message(STATUS "Downloading json")
FetchContent_Declare(
json

View File

@@ -5,26 +5,26 @@ possible.
## Pull Requests
1. Fork and submit pull requests to the repo.
1. Fork and submit pull requests to the repo.
2. If you've added code that should be tested, add tests.
3. If a change is likely to impact efficiency, run some of the benchmarks before
and after the change. Examples of benchmarks can be found in `benchmarks/python/`.
4. If you've changed APIs, update the documentation.
5. Every PR should have passing tests and at least one review.
5. Every PR should have passing tests and at least one review.
6. For code formatting install `pre-commit` using something like `pip install pre-commit` and run `pre-commit install`.
This should install hooks for running `black` and `clang-format` to ensure
consistent style for C++ and python code.
You can also run the formatters manually as follows:
```
clang-format -i file.cpp
```
```
black file.py
```
```shell
clang-format -i file.cpp
```
```shell
black file.py
```
or run `pre-commit run --all-files` to check all files in the repo.
## Issues

View File

@@ -0,0 +1,74 @@
# Copyright © 2025 Apple Inc.
import mlx.core as mx
from time_utils import time_fn
N = 1024
D = 1024
M = 1024
E = 32
I = 4
def gather_sort(x, indices):
N, M = indices.shape
indices = indices.flatten()
order = mx.argsort(indices)
inv_order = mx.argsort(order)
return x.flatten(0, -3)[order // M], indices[order], inv_order
def scatter_unsort(x, inv_order, shape=None):
x = x[inv_order]
if shape is not None:
x = mx.unflatten(x, 0, shape)
return x
def gather_mm_simulate(x, w, indices):
x, idx, inv_order = gather_sort(x, indices)
for i in range(2):
y = mx.concatenate([x[i] @ w[j].T for i, j in enumerate(idx.tolist())], axis=0)
x = y[:, None]
x = scatter_unsort(x, inv_order, indices.shape)
return x
def time_gather_mm():
x = mx.random.normal((N, 1, 1, D)) / 1024**0.5
w1 = mx.random.normal((E, M, D)) / 1024**0.5
w2 = mx.random.normal((E, D, M)) / 1024**0.5
indices = (mx.random.uniform(shape=(N, I)) * E).astype(mx.uint32)
sorted_indices = mx.sort(indices.flatten()).reshape(N, I)
mx.eval(x, w1, w2, indices, sorted_indices)
def gather_mm(x, w1, w2, indices, sort):
idx = indices
inv_order = None
if sort:
x, idx, inv_order = gather_sort(x, indices)
x = mx.gather_mm(x, w1.swapaxes(-1, -2), rhs_indices=idx, sorted_indices=sort)
x = mx.gather_mm(x, w2.swapaxes(-1, -2), rhs_indices=idx, sorted_indices=sort)
if sort:
x = scatter_unsort(x, inv_order, indices.shape)
return x
time_fn(gather_mm, x, w1, w2, indices, False)
time_fn(gather_mm, x, w1, w2, sorted_indices, False)
time_fn(gather_mm, x, w1, w2, indices, True)
x = mx.random.normal((N * I, D)) / 1024**0.5
w1 = mx.random.normal((M, D)) / 1024**0.5
w2 = mx.random.normal((D, M)) / 1024**0.5
mx.eval(x, w1, w2)
def equivalent_matmul(x, w1, w2):
x = x @ w1.T
x = x @ w2.T
return x
time_fn(equivalent_matmul, x, w1, w2)
if __name__ == "__main__":
time_gather_mm()

View File

@@ -0,0 +1,84 @@
# Copyright © 2025 Apple Inc.
import mlx.core as mx
from time_utils import time_fn
N = 1024
D = 1024
M = 1024
E = 32
I = 4
def gather_sort(x, indices):
N, M = indices.shape
indices = indices.flatten()
order = mx.argsort(indices)
inv_order = mx.argsort(order)
return x.flatten(0, -3)[order // M], indices[order], inv_order
def scatter_unsort(x, inv_order, shape=None):
x = x[inv_order]
if shape is not None:
x = mx.unflatten(x, 0, shape)
return x
def gather_mm_simulate(x, w, indices):
x, idx, inv_order = gather_sort(x, indices)
for i in range(2):
y = mx.concatenate(
[
mx.quantized_matmul(x[i], w[0][j], w[1][j], w[2][j], transpose=True)
for i, j in enumerate(idx.tolist())
],
axis=0,
)
x = y[:, None]
x = scatter_unsort(x, inv_order, indices.shape)
return x
def time_gather_qmm():
x = mx.random.normal((N, 1, 1, D)) / 1024**0.5
w1 = mx.random.normal((E, M, D)) / 1024**0.5
w2 = mx.random.normal((E, D, M)) / 1024**0.5
w1 = mx.quantize(w1)
w2 = mx.quantize(w2)
indices = (mx.random.uniform(shape=(N, I)) * E).astype(mx.uint32)
sorted_indices = mx.sort(indices.flatten()).reshape(N, I)
mx.eval(x, w1, w2, indices, sorted_indices)
def gather_mm(x, w1, w2, indices, sort):
idx = indices
inv_order = None
if sort:
x, idx, inv_order = gather_sort(x, indices)
x = mx.gather_qmm(x, *w1, transpose=True, rhs_indices=idx, sorted_indices=sort)
x = mx.gather_qmm(x, *w2, transpose=True, rhs_indices=idx, sorted_indices=sort)
if sort:
x = scatter_unsort(x, inv_order, indices.shape)
return x
time_fn(gather_mm, x, w1, w2, indices, False)
time_fn(gather_mm, x, w1, w2, sorted_indices, False)
time_fn(gather_mm, x, w1, w2, indices, True)
x = mx.random.normal((N * I, D)) / 1024**0.5
w1 = mx.random.normal((M, D)) / 1024**0.5
w2 = mx.random.normal((D, M)) / 1024**0.5
w1 = mx.quantize(w1)
w2 = mx.quantize(w2)
mx.eval(x, w1, w2)
def equivalent_matmul(x, w1, w2):
x = mx.quantized_matmul(x, *w1, transpose=True)
x = mx.quantized_matmul(x, *w2, transpose=True)
return x
time_fn(equivalent_matmul, x, w1, w2)
if __name__ == "__main__":
time_gather_qmm()

View File

@@ -13,7 +13,7 @@ EXCLUDE_PATTERNS = */private/*
CREATE_SUBDIRS = NO
FULL_PATH_NAMES = YES
RECURSIVE = YES
GENERATE_HTML = YES
GENERATE_HTML = NO
GENERATE_LATEX = NO
GENERATE_XML = YES
XML_PROGRAMLISTING = YES

View File

@@ -93,9 +93,9 @@ Primitives
^^^^^^^^^^^
A :class:`Primitive` is part of the computation graph of an :class:`array`. It
defines how to create outputs arrays given a input arrays. Further, a
defines how to create output arrays given input arrays. Further, a
:class:`Primitive` has methods to run on the CPU or GPU and for function
transformations such as ``vjp`` and ``jvp``. Lets go back to our example to be
transformations such as ``vjp`` and ``jvp``. Let's go back to our example to be
more concrete:
.. code-block:: C++
@@ -128,7 +128,7 @@ more concrete:
/** The vector-Jacobian product. */
std::vector<array> vjp(
const std::vector<array>& primals,
const array& cotan,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>& outputs) override;
@@ -469,7 +469,7 @@ one we just defined:
const std::vector<array>& tangents,
const std::vector<int>& argnums) {
// Forward mode diff that pushes along the tangents
// The jvp transform on the primitive can built with ops
// The jvp transform on the primitive can be built with ops
// that are scheduled on the same stream as the primitive
// If argnums = {0}, we only push along x in which case the
@@ -481,7 +481,7 @@ one we just defined:
auto scale_arr = array(scale, tangents[0].dtype());
return {multiply(scale_arr, tangents[0], stream())};
}
// If, argnums = {0, 1}, we take contributions from both
// If argnums = {0, 1}, we take contributions from both
// which gives us jvp = tangent_x * alpha + tangent_y * beta
else {
return {axpby(tangents[0], tangents[1], alpha_, beta_, stream())};
@@ -735,7 +735,7 @@ Let's look at a simple script and its results:
print(f"c shape: {c.shape}")
print(f"c dtype: {c.dtype}")
print(f"c correct: {mx.all(c == 6.0).item()}")
print(f"c is correct: {mx.all(c == 6.0).item()}")
Output:
@@ -743,7 +743,7 @@ Output:
c shape: [3, 4]
c dtype: float32
c correctness: True
c is correct: True
Results
^^^^^^^

View File

@@ -70,6 +70,7 @@ are the CPU and GPU.
python/fft
python/linalg
python/metal
python/memory_management
python/nn
python/optimizers
python/distributed

View File

@@ -38,6 +38,7 @@ Array
array.log10
array.log1p
array.log2
array.logcumsumexp
array.logsumexp
array.max
array.mean

View File

@@ -20,5 +20,6 @@ Linear Algebra
eigh
lu
lu_factor
pinv
solve
solve_triangular

View File

@@ -0,0 +1,16 @@
Memory Management
=================
.. currentmodule:: mlx.core
.. autosummary::
:toctree: _autosummary
get_active_memory
get_peak_memory
reset_peak_memory
get_cache_memory
set_memory_limit
set_cache_limit
set_wired_limit
clear_cache

View File

@@ -8,13 +8,5 @@ Metal
is_available
device_info
get_active_memory
get_peak_memory
reset_peak_memory
get_cache_memory
set_memory_limit
set_cache_limit
set_wired_limit
clear_cache
start_capture
stop_capture

View File

@@ -36,10 +36,12 @@ Operations
bitwise_or
bitwise_xor
block_masked_mm
broadcast_arrays
broadcast_to
ceil
clip
concatenate
contiguous
conj
conjugate
convolve
@@ -101,6 +103,7 @@ Operations
log10
log1p
logaddexp
logcumsumexp
logical_not
logical_and
logical_or

View File

@@ -18,3 +18,4 @@ Common Optimizers
AdamW
Adamax
Lion
MultiOptimizer

View File

@@ -9,6 +9,7 @@ Transforms
:toctree: _autosummary
eval
async_eval
compile
custom_function
disable_compile

View File

@@ -5,6 +5,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp
${CMAKE_CURRENT_SOURCE_DIR}/dtype_utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/export.cpp
${CMAKE_CURRENT_SOURCE_DIR}/einsum.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp

View File

@@ -4,7 +4,6 @@
#include <sstream>
#include "mlx/allocator.h"
#include "mlx/scheduler.h"
namespace mlx::core::allocator {
@@ -22,23 +21,4 @@ void free(Buffer buffer) {
allocator().free(buffer);
}
Buffer CommonAllocator::malloc(size_t size) {
void* ptr = std::malloc(size + sizeof(size_t));
if (ptr != nullptr) {
*static_cast<size_t*>(ptr) = size;
}
return Buffer{ptr};
}
void CommonAllocator::free(Buffer buffer) {
std::free(buffer.ptr());
}
size_t CommonAllocator::size(Buffer buffer) const {
if (buffer.ptr() == nullptr) {
return 0;
}
return *static_cast<size_t*>(buffer.ptr());
}
} // namespace mlx::core::allocator

View File

@@ -49,16 +49,4 @@ class Allocator {
Allocator& allocator();
class CommonAllocator : public Allocator {
/** A general CPU allocator. */
public:
virtual Buffer malloc(size_t size) override;
virtual void free(Buffer buffer) override;
virtual size_t size(Buffer buffer) const override;
private:
CommonAllocator() = default;
friend Allocator& allocator();
};
} // namespace mlx::core::allocator

View File

@@ -339,11 +339,11 @@ class array {
return allocator::allocator().size(buffer());
}
// Return a copy of the shared pointer
// to the array::Data struct
std::shared_ptr<Data> data_shared_ptr() const {
// Return the shared pointer to the array::Data struct
const std::shared_ptr<Data>& data_shared_ptr() const {
return array_desc_->data;
}
// Return a raw pointer to the arrays data
template <typename T>
T* data() {

View File

@@ -1,6 +1,7 @@
target_sources(
mlx
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/broadcasting.cpp
${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
${CMAKE_CURRENT_SOURCE_DIR}/common.cpp
${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp

View File

@@ -0,0 +1,24 @@
// Copyright © 2024 Apple Inc.
#include "mlx/backend/common/utils.h"
namespace mlx::core {
void broadcast(const array& in, array& out) {
if (out.size() == 0) {
out.set_data(nullptr);
return;
}
Strides strides(out.ndim(), 0);
int diff = out.ndim() - in.ndim();
for (int i = in.ndim() - 1; i >= 0; --i) {
strides[i + diff] = (in.shape()[i] == 1) ? 0 : in.strides()[i];
}
auto flags = in.flags();
if (out.size() > in.size()) {
flags.row_contiguous = flags.col_contiguous = false;
}
out.copy_shared_buffer(in, strides, flags, in.data_size());
}
} // namespace mlx::core

View File

@@ -0,0 +1,11 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include "mlx/array.h"
namespace mlx::core {
void broadcast(const array& in, array& out);
} // namespace mlx::core

View File

@@ -1,6 +1,7 @@
// Copyright © 2024 Apple Inc.
#include <cassert>
#include "mlx/backend/common/broadcasting.h"
#include "mlx/backend/common/utils.h"
#include "mlx/primitives.h"
@@ -42,23 +43,6 @@ void AsStrided::eval(const std::vector<array>& inputs, array& out) {
return out.copy_shared_buffer(in, strides_, flags, data_size, offset_);
}
void broadcast(const array& in, array& out) {
if (out.size() == 0) {
out.set_data(nullptr);
return;
}
Strides strides(out.ndim(), 0);
int diff = out.ndim() - in.ndim();
for (int i = in.ndim() - 1; i >= 0; --i) {
strides[i + diff] = (in.shape()[i] == 1) ? 0 : in.strides()[i];
}
auto flags = in.flags();
if (out.size() > in.size()) {
flags.row_contiguous = flags.col_contiguous = false;
}
out.copy_shared_buffer(in, strides, flags, in.data_size());
}
void Broadcast::eval(const std::vector<array>& inputs, array& out) {
broadcast(inputs[0], out);
}

View File

@@ -58,6 +58,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
${CMAKE_CURRENT_SOURCE_DIR}/select.cpp
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
${CMAKE_CURRENT_SOURCE_DIR}/logsumexp.cpp
${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
${CMAKE_CURRENT_SOURCE_DIR}/threefry.cpp
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
@@ -73,8 +74,8 @@ target_sources(
if(MLX_BUILD_ACCELERATE)
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/bnns.cpp)
else()
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/no_fp16.cpp
${CMAKE_CURRENT_SOURCE_DIR}/gemms/no_bf16.cpp)
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/simd_fp16.cpp
${CMAKE_CURRENT_SOURCE_DIR}/gemms/simd_bf16.cpp)
endif()
if(IOS)

View File

@@ -46,8 +46,15 @@ void AllReduce::eval_cpu(
case Sum:
distributed::detail::all_sum(group(), in, outputs[0], stream());
break;
case Max:
distributed::detail::all_max(group(), in, outputs[0], stream());
break;
case Min:
distributed::detail::all_min(group(), in, outputs[0], stream());
break;
default:
throw std::runtime_error("Only all reduce sum is supported for now");
throw std::runtime_error(
"Only all reduce sum, min and max are supported for now");
}
}

View File

@@ -1,27 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cpu/gemm.h"
namespace mlx::core {
template <>
void matmul<bfloat16_t>(
const bfloat16_t*,
const bfloat16_t*,
bfloat16_t*,
bool,
bool,
size_t,
size_t,
size_t,
float,
float,
size_t,
const Shape&,
const Strides&,
const Shape&,
const Strides&) {
throw std::runtime_error("[Matmul::eval_cpu] bfloat16 not supported.");
}
} // namespace mlx::core

View File

@@ -1,27 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cpu/gemm.h"
namespace mlx::core {
template <>
void matmul<float16_t>(
const float16_t*,
const float16_t*,
float16_t*,
bool,
bool,
size_t,
size_t,
size_t,
float,
float,
size_t,
const Shape&,
const Strides&,
const Shape&,
const Strides&) {
throw std::runtime_error("[Matmul::eval_cpu] float16 not supported.");
}
} // namespace mlx::core

View File

@@ -0,0 +1,45 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/gemm.h"
#include "mlx/backend/cpu/gemms/simd_gemm.h"
namespace mlx::core {
template <>
void matmul<bfloat16_t>(
const bfloat16_t* a,
const bfloat16_t* b,
bfloat16_t* out,
bool a_transposed,
bool b_transposed,
size_t lda,
size_t ldb,
size_t ldc,
float alpha,
float beta,
size_t batch_size,
const Shape& a_shape,
const Strides& a_strides,
const Shape& b_shape,
const Strides& b_strides) {
auto ndim = a_shape.size();
size_t M = a_shape[ndim - 2];
size_t N = b_shape[ndim - 1];
size_t K = a_shape[ndim - 1];
for (int i = 0; i < batch_size; ++i) {
simd_gemm<bfloat16_t, float>(
a + elem_to_loc(M * K * i, a_shape, a_strides),
b + elem_to_loc(K * N * i, b_shape, b_strides),
out + M * N * i,
a_transposed,
b_transposed,
M,
N,
K,
alpha,
beta);
}
}
} // namespace mlx::core

View File

@@ -0,0 +1,45 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/gemm.h"
#include "mlx/backend/cpu/gemms/simd_gemm.h"
namespace mlx::core {
template <>
void matmul<float16_t>(
const float16_t* a,
const float16_t* b,
float16_t* out,
bool a_transposed,
bool b_transposed,
size_t lda,
size_t ldb,
size_t ldc,
float alpha,
float beta,
size_t batch_size,
const Shape& a_shape,
const Strides& a_strides,
const Shape& b_shape,
const Strides& b_strides) {
auto ndim = a_shape.size();
size_t M = a_shape[ndim - 2];
size_t N = b_shape[ndim - 1];
size_t K = a_shape[ndim - 1];
for (int i = 0; i < batch_size; ++i) {
simd_gemm<float16_t, float>(
a + elem_to_loc(M * K * i, a_shape, a_strides),
b + elem_to_loc(K * N * i, b_shape, b_strides),
out + M * N * i,
a_transposed,
b_transposed,
M,
N,
K,
alpha,
beta);
}
}
} // namespace mlx::core

View File

@@ -0,0 +1,139 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include "mlx/backend/cpu/simd/simd.h"
namespace mlx::core {
inline int ceildiv(int a, int b) {
return (a + b - 1) / b;
}
template <int block_size, typename T, typename AccT>
void load_block(
const T* in,
AccT* out,
int M,
int N,
int i,
int j,
bool transpose) {
if (transpose) {
for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) {
for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) {
out[jj * block_size + ii] =
in[(i * block_size + ii) * N + j * block_size + jj];
}
}
} else {
for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) {
for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) {
out[ii * block_size + jj] =
in[(i * block_size + ii) * N + j * block_size + jj];
}
}
}
}
template <typename T, typename AccT>
void simd_gemm(
const T* a,
const T* b,
T* c,
bool a_trans,
bool b_trans,
int M,
int N,
int K,
float alpha,
float beta) {
constexpr int block_size = 16;
constexpr int simd_size = simd::max_size<AccT>;
static_assert(
(block_size % simd_size) == 0,
"Block size must be divisible by SIMD size");
int last_k_block_size = K - block_size * (K / block_size);
int last_k_simd_block = (last_k_block_size / simd_size) * simd_size;
for (int i = 0; i < ceildiv(M, block_size); i++) {
for (int j = 0; j < ceildiv(N, block_size); j++) {
AccT c_block[block_size * block_size] = {0.0};
AccT a_block[block_size * block_size];
AccT b_block[block_size * block_size];
int k = 0;
for (; k < K / block_size; k++) {
// Load a and b blocks
if (a_trans) {
load_block<block_size>(a, a_block, K, M, k, i, true);
} else {
load_block<block_size>(a, a_block, M, K, i, k, false);
}
if (b_trans) {
load_block<block_size>(b, b_block, N, K, j, k, false);
} else {
load_block<block_size>(b, b_block, K, N, k, j, true);
}
// Multiply and accumulate
for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) {
for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) {
for (int kk = 0; kk < block_size; kk += simd_size) {
auto av =
simd::load<AccT, simd_size>(a_block + ii * block_size + kk);
auto bv =
simd::load<AccT, simd_size>(b_block + jj * block_size + kk);
c_block[ii * block_size + jj] += simd::sum(av * bv);
}
}
}
}
if (last_k_block_size) {
// Load a and b blocks
if (a_trans) {
load_block<block_size>(a, a_block, K, M, k, i, true);
} else {
load_block<block_size>(a, a_block, M, K, i, k, false);
}
if (b_trans) {
load_block<block_size>(b, b_block, N, K, j, k, false);
} else {
load_block<block_size>(b, b_block, K, N, k, j, true);
}
// Multiply and accumulate
for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) {
for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) {
int kk = 0;
for (; kk < last_k_simd_block; kk += simd_size) {
auto av =
simd::load<AccT, simd_size>(a_block + ii * block_size + kk);
auto bv =
simd::load<AccT, simd_size>(b_block + jj * block_size + kk);
c_block[ii * block_size + jj] += simd::sum(av * bv);
}
for (; kk < last_k_block_size; ++kk) {
c_block[ii * block_size + jj] +=
a_block[ii * block_size + kk] * b_block[jj * block_size + kk];
}
}
}
}
// Store
for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) {
for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) {
auto c_idx = (i * block_size + ii) * N + j * block_size + jj;
if (beta != 0) {
c[c_idx] = static_cast<T>(
alpha * c_block[ii * block_size + jj] + beta * c[c_idx]);
} else {
c[c_idx] = static_cast<T>(alpha * c_block[ii * block_size + jj]);
}
}
}
}
}
}
} // namespace mlx::core

View File

@@ -0,0 +1,140 @@
// Copyright © 2023-2024 Apple Inc.
#include <cassert>
#include <cmath>
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/simd/simd.h"
#include "mlx/primitives.h"
#include "mlx/types/limits.h"
namespace mlx::core {
namespace {
using namespace mlx::core::simd;
template <typename T, typename AccT>
void logsumexp(const array& in, array& out, Stream stream) {
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(in);
encoder.set_output_array(out);
const T* in_ptr = in.data<T>();
T* out_ptr = out.data<T>();
int M = in.shape().back();
int L = in.data_size() / M;
encoder.dispatch([in_ptr, out_ptr, M, L]() mutable {
constexpr int N = std::min(max_size<AccT>, max_size<T>);
const T* current_in_ptr;
for (int i = 0; i < L; i++, in_ptr += M, out_ptr += 1) {
// Find the maximum
current_in_ptr = in_ptr;
Simd<AccT, N> vmaximum(-numeric_limits<AccT>::infinity());
size_t s = M;
while (s >= N) {
Simd<AccT, N> vals = load<T, N>(current_in_ptr);
vmaximum = maximum(vals, vmaximum);
current_in_ptr += N;
s -= N;
}
AccT maximum = max(vmaximum);
while (s-- > 0) {
maximum = std::max(maximum, static_cast<AccT>(*current_in_ptr));
current_in_ptr++;
}
// Compute the normalizer and the exponentials
Simd<AccT, N> vnormalizer(0.0);
current_in_ptr = in_ptr;
s = M;
while (s >= N) {
Simd<AccT, N> vexp = load<T, N>(current_in_ptr);
vexp = exp(vexp - maximum);
vnormalizer = vnormalizer + vexp;
current_in_ptr += N;
s -= N;
}
AccT normalizer = sum(vnormalizer);
while (s-- > 0) {
AccT _exp = std::exp(*current_in_ptr - maximum);
normalizer += _exp;
current_in_ptr++;
}
// Normalize
*out_ptr = std::isinf(maximum)
? static_cast<T>(maximum)
: static_cast<T>(std::log(normalizer) + maximum);
}
});
}
} // namespace
void LogSumExp::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
// Make sure that the last dimension is contiguous
auto s = stream();
auto& encoder = cpu::get_command_encoder(s);
auto ensure_contiguous = [&s, &encoder](const array& x) {
if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
return x;
} else {
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
copy(x, x_copy, CopyType::General, s);
encoder.add_temporary(x_copy);
return x_copy;
}
};
auto in = ensure_contiguous(inputs[0]);
if (in.flags().row_contiguous) {
out.set_data(allocator::malloc(out.nbytes()));
} else {
auto n = in.shape(-1);
auto flags = in.flags();
auto strides = in.strides();
for (auto& s : strides) {
s /= n;
}
bool col_contig = strides[0] == 1;
for (int i = 1; col_contig && i < strides.size(); ++i) {
col_contig &=
(out.shape(i) == 1 || strides[i - 1] == out.shape(i) * strides[i]);
}
flags.col_contiguous = col_contig;
out.set_data(
allocator::malloc(in.nbytes() / n),
in.data_size() / n,
std::move(strides),
flags);
}
switch (in.dtype()) {
case float32:
logsumexp<float, float>(in, out, stream());
break;
case float16:
logsumexp<float16_t, float>(in, out, stream());
break;
case bfloat16:
logsumexp<bfloat16_t, float>(in, out, stream());
break;
case float64:
logsumexp<double, double>(in, out, stream());
break;
default:
throw std::runtime_error(
"[logsumexp] only supports floating point types");
break;
}
}
} // namespace mlx::core

View File

@@ -205,8 +205,10 @@ void Concatenate::eval_cpu(const std::vector<array>& inputs, array& out) {
void Contiguous::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (in.flags().row_contiguous ||
(allow_col_major_ && in.flags().col_contiguous)) {
constexpr size_t extra_bytes = 16384;
if (in.buffer_size() <= out.nbytes() + extra_bytes &&
(in.flags().row_contiguous ||
(allow_col_major_ && in.flags().col_contiguous))) {
out.copy_shared_buffer(in);
} else {
copy(in, out, CopyType::General, stream());

View File

@@ -3,6 +3,7 @@
#include <cassert>
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/binary_ops.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/simd/simd.h"
@@ -226,6 +227,16 @@ void scan_dispatch(
scan_op<T, U>(in, out, axis, reverse, inclusive, op, init);
break;
}
case Scan::LogAddExp: {
auto op = [](U a, T b) {
return detail::LogAddExp{}(a, static_cast<U>(b));
};
auto init = (issubdtype(in.dtype(), floating))
? static_cast<U>(-std::numeric_limits<float>::infinity())
: std::numeric_limits<U>::min();
scan_op<T, U>(in, out, axis, reverse, inclusive, op, init);
break;
}
}
}

View File

@@ -17,7 +17,7 @@ struct ScalarT<float16_t, N> {
#endif
template <>
static constexpr int max_size<float16_t> = N;
inline constexpr int max_size<float16_t> = N;
#define SIMD_FP16_DEFAULT_UNARY(op) \
template <> \

View File

@@ -83,25 +83,25 @@ struct Simd {
// Values chosen based on benchmarks on M3 Max
// TODO: consider choosing these more optimally
template <>
static constexpr int max_size<int8_t> = 16;
inline constexpr int max_size<int8_t> = 16;
template <>
static constexpr int max_size<int16_t> = 16;
inline constexpr int max_size<int16_t> = 16;
template <>
static constexpr int max_size<int> = 8;
inline constexpr int max_size<int> = 8;
template <>
static constexpr int max_size<int64_t> = 4;
inline constexpr int max_size<int64_t> = 4;
template <>
static constexpr int max_size<uint8_t> = 16;
inline constexpr int max_size<uint8_t> = 16;
template <>
static constexpr int max_size<uint16_t> = 16;
inline constexpr int max_size<uint16_t> = 16;
template <>
static constexpr int max_size<uint32_t> = 8;
inline constexpr int max_size<uint32_t> = 8;
template <>
static constexpr int max_size<uint64_t> = 4;
inline constexpr int max_size<uint64_t> = 4;
template <>
static constexpr int max_size<float> = 8;
inline constexpr int max_size<float> = 8;
template <>
static constexpr int max_size<double> = 4;
inline constexpr int max_size<double> = 4;
#define SIMD_DEFAULT_UNARY(name, op) \
template <typename T, int N> \

View File

@@ -87,7 +87,6 @@ DEFAULT_UNARY(cosh, std::cosh)
DEFAULT_UNARY(expm1, std::expm1)
DEFAULT_UNARY(floor, std::floor)
DEFAULT_UNARY(log, std::log)
DEFAULT_UNARY(log2, std::log2)
DEFAULT_UNARY(log10, std::log10)
DEFAULT_UNARY(log1p, std::log1p)
DEFAULT_UNARY(sinh, std::sinh)
@@ -95,6 +94,17 @@ DEFAULT_UNARY(sqrt, std::sqrt)
DEFAULT_UNARY(tan, std::tan)
DEFAULT_UNARY(tanh, std::tanh)
template <typename T>
Simd<T, 1> log2(Simd<T, 1> in) {
if constexpr (is_complex<T>) {
auto out = std::log(in.value);
auto scale = decltype(out.real())(M_LN2);
return Simd<T, 1>{T{out.real() / scale, out.imag() / scale}};
} else {
return Simd<T, 1>{std::log2(in.value)};
}
}
template <typename T>
Simd<T, 1> operator~(Simd<T, 1> in) {
return ~in.value;

View File

@@ -119,12 +119,7 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
// Make sure that the last dimension is contiguous
auto set_output = [s = stream(), &out](const array& x) {
bool no_copy = x.strides()[x.ndim() - 1] == 1;
if (x.ndim() > 1) {
auto s = x.strides()[x.ndim() - 2];
no_copy &= (s == 0 || s == x.shape().back());
}
if (no_copy) {
if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
if (x.is_donatable()) {
out.copy_shared_buffer(x);
} else {
@@ -146,18 +141,6 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
auto in = set_output(inputs[0]);
switch (in.dtype()) {
case bool_:
case uint8:
case uint16:
case uint32:
case uint64:
case int8:
case int16:
case int32:
case int64:
throw std::runtime_error(
"Softmax is defined only for floating point types");
break;
case float32:
softmax<float, float>(in, out, stream());
break;
@@ -178,9 +161,9 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
case float64:
softmax<double, double>(in, out, stream());
break;
case complex64:
throw std::invalid_argument(
"[Softmax] Not yet implemented for complex64");
default:
throw std::runtime_error(
"[softmax] Only defined for floating point types.");
break;
}
}

View File

@@ -1,5 +1,8 @@
// Copyright © 2024 Apple Inc.
// Required for using M_LN2 in MSVC.
#define _USE_MATH_DEFINES
#include <cassert>
#include "mlx/backend/cpu/unary.h"

View File

@@ -86,13 +86,14 @@ struct Sign {
template <int N, typename T>
Simd<T, N> operator()(Simd<T, N> x) {
auto z = Simd<T, N>{0};
auto o = Simd<T, N>{1};
auto m = Simd<T, N>{-1};
if constexpr (std::is_unsigned_v<T>) {
return x != z;
return simd::select(x == z, z, o);
} else if constexpr (std::is_same_v<T, complex64_t>) {
return simd::select(x == z, x, Simd<T, N>(x / simd::abs(x)));
} else {
return simd::select(
x < z, Simd<T, N>{-1}, simd::select(x > z, Simd<T, N>{1}, z));
return simd::select(x < z, m, simd::select(x > z, o, z));
}
}
SINGLE()

View File

@@ -47,6 +47,7 @@ if(MLX_METAL_JIT)
make_jit_source(binary)
make_jit_source(binary_two)
make_jit_source(fft kernels/fft/radix.h kernels/fft/readwrite.h)
make_jit_source(logsumexp)
make_jit_source(ternary)
make_jit_source(softmax)
make_jit_source(scan)
@@ -60,6 +61,7 @@ if(MLX_METAL_JIT)
kernels/steel/gemm/transforms.h)
make_jit_source(steel/gemm/kernels/steel_gemm_fused)
make_jit_source(steel/gemm/kernels/steel_gemm_masked kernels/steel/defines.h)
make_jit_source(steel/gemm/kernels/steel_gemm_gather)
make_jit_source(steel/gemm/kernels/steel_gemm_splitk)
make_jit_source(
steel/conv/conv
@@ -95,6 +97,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/logsumexp.cpp
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cpp
${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp

View File

@@ -3,6 +3,7 @@
#include "mlx/backend/metal/metal.h"
#include "mlx/backend/metal/metal_impl.h"
#include "mlx/backend/metal/resident.h"
#include "mlx/memory.h"
#include <mach/vm_page_size.h>
#include <unistd.h>
@@ -32,8 +33,11 @@ namespace metal {
namespace {
BufferCache::BufferCache(MTL::Device* device)
: device_(device), head_(nullptr), tail_(nullptr), pool_size_(0) {}
BufferCache::BufferCache(ResidencySet& residency_set)
: head_(nullptr),
tail_(nullptr),
pool_size_(0),
residency_set_(residency_set) {}
BufferCache::~BufferCache() {
auto pool = metal::new_scoped_memory_pool();
@@ -44,6 +48,9 @@ int BufferCache::clear() {
int n_release = 0;
for (auto& [size, holder] : buffer_pool_) {
if (holder->buf) {
if (!holder->buf->heap()) {
residency_set_.erase(holder->buf);
}
holder->buf->release();
n_release++;
}
@@ -101,6 +108,9 @@ int BufferCache::release_cached_buffers(size_t min_bytes_to_free) {
while (tail_ && (total_bytes_freed < min_bytes_to_free)) {
if (tail_->buf) {
total_bytes_freed += tail_->buf->length();
if (!tail_->buf->heap()) {
residency_set_.erase(tail_->buf);
}
tail_->buf->release();
tail_->buf = nullptr;
n_release++;
@@ -155,7 +165,7 @@ void BufferCache::remove_from_list(BufferCache::BufferHolder* to_remove) {
MetalAllocator::MetalAllocator()
: device_(device(mlx::core::Device::gpu).mtl_device()),
residency_set_(device_),
buffer_cache_(device_) {
buffer_cache_(residency_set_) {
auto pool = metal::new_scoped_memory_pool();
auto memsize = std::get<size_t>(device_info().at("memory_size"));
auto max_rec_size =
@@ -262,9 +272,13 @@ Buffer MetalAllocator::malloc(size_t size) {
if (!buf) {
buf = device_->newBuffer(size, resource_options);
}
if (!buf) {
return Buffer{nullptr};
}
lk.lock();
if (buf) {
num_resources_++;
num_resources_++;
if (!buf->heap()) {
residency_set_.insert(buf);
}
}
@@ -278,10 +292,6 @@ Buffer MetalAllocator::malloc(size_t size) {
get_cache_memory() - max_pool_size_);
}
if (!buf->heap()) {
residency_set_.insert(buf);
}
return Buffer{static_cast<void*>(buf)};
}
@@ -297,14 +307,14 @@ void MetalAllocator::free(Buffer buffer) {
return;
}
std::unique_lock lk(mutex_);
if (!buf->heap()) {
residency_set_.erase(buf);
}
active_memory_ -= buf->length();
if (get_cache_memory() < max_pool_size_) {
buffer_cache_.recycle_to_cache(buf);
} else {
num_resources_--;
if (!buf->heap()) {
residency_set_.erase(buf);
}
lk.unlock();
auto pool = metal::new_scoped_memory_pool();
buf->release();
@@ -323,40 +333,40 @@ MetalAllocator& allocator() {
return *allocator_;
}
} // namespace metal
size_t set_cache_limit(size_t limit) {
return allocator().set_cache_limit(limit);
return metal::allocator().set_cache_limit(limit);
}
size_t set_memory_limit(size_t limit) {
return allocator().set_memory_limit(limit);
return metal::allocator().set_memory_limit(limit);
}
size_t get_memory_limit() {
return allocator().get_memory_limit();
return metal::allocator().get_memory_limit();
}
size_t set_wired_limit(size_t limit) {
if (limit >
std::get<size_t>(device_info().at("max_recommended_working_set_size"))) {
if (limit > std::get<size_t>(metal::device_info().at(
"max_recommended_working_set_size"))) {
throw std::invalid_argument(
"[metal::set_wired_limit] Setting a wired limit larger than "
"the maximum working set size is not allowed.");
}
return allocator().set_wired_limit(limit);
return metal::allocator().set_wired_limit(limit);
}
size_t get_active_memory() {
return allocator().get_active_memory();
return metal::allocator().get_active_memory();
}
size_t get_peak_memory() {
return allocator().get_peak_memory();
return metal::allocator().get_peak_memory();
}
void reset_peak_memory() {
allocator().reset_peak_memory();
metal::allocator().reset_peak_memory();
}
size_t get_cache_memory() {
return allocator().get_cache_memory();
return metal::allocator().get_cache_memory();
}
void clear_cache() {
return allocator().clear_cache();
return metal::allocator().clear_cache();
}
} // namespace metal
} // namespace mlx::core

View File

@@ -18,7 +18,7 @@ namespace {
class BufferCache {
public:
BufferCache(MTL::Device* device);
BufferCache(ResidencySet& residency_set);
~BufferCache();
MTL::Buffer* reuse_from_cache(size_t size);
@@ -42,13 +42,11 @@ class BufferCache {
void add_at_head(BufferHolder* to_add);
void remove_from_list(BufferHolder* to_remove);
MTL::Device* device_;
MTL::Heap* heap_{nullptr};
std::multimap<size_t, BufferHolder*> buffer_pool_;
BufferHolder* head_;
BufferHolder* tail_;
size_t pool_size_;
ResidencySet& residency_set_;
};
} // namespace

View File

@@ -712,6 +712,65 @@ void winograd_conv_2D_gpu(
}
}
void depthwise_conv_2D_gpu(
const Stream& s,
metal::Device& d,
const array& in,
const array& wt,
array out,
const MLXConvParams<2>& conv_params) {
std::ostringstream kname;
kname << "depthwise_conv_2d_" << type_to_name(out);
std::string base_name = kname.str();
const int N = conv_params.N;
const int ker_h = conv_params.wS[0];
const int ker_w = conv_params.wS[1];
const int str_h = conv_params.str[0];
const int str_w = conv_params.str[1];
const int tc = 8;
const int tw = 8;
const int th = 4;
const bool do_flip = conv_params.flip;
metal::MTLFCList func_consts = {
{&ker_h, MTL::DataType::DataTypeInt, 00},
{&ker_w, MTL::DataType::DataTypeInt, 01},
{&str_h, MTL::DataType::DataTypeInt, 10},
{&str_w, MTL::DataType::DataTypeInt, 11},
{&th, MTL::DataType::DataTypeInt, 100},
{&tw, MTL::DataType::DataTypeInt, 101},
{&do_flip, MTL::DataType::DataTypeBool, 200},
};
// clang-format off
kname << "_ker_h_" << ker_h
<< "_ker_w_" << ker_w
<< "_str_h_" << str_h
<< "_str_w_" << str_w
<< "_tgp_h_" << th
<< "_tgp_w_" << tw
<< "_do_flip_" << (do_flip ? 't' : 'n'); // clang-format on
std::string hash_name = kname.str();
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(base_name, "mlx", hash_name, func_consts);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(in, 0);
compute_encoder.set_input_array(wt, 1);
compute_encoder.set_output_array(out, 2);
compute_encoder.set_bytes(conv_params, 3);
MTL::Size group_dims = MTL::Size(tc, tw, th);
MTL::Size grid_dims = MTL::Size(
conv_params.C / tc, conv_params.oS[1] / tw, (conv_params.oS[0] / th) * N);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
void conv_2D_gpu(
const Stream& s,
metal::Device& d,
@@ -754,11 +813,20 @@ void conv_2D_gpu(
bool is_kdil_one = conv_params.kdil[0] == 1 && conv_params.kdil[1] == 1;
bool is_idil_one = conv_params.idil[0] == 1 && conv_params.idil[1] == 1;
if (groups > 1) {
if (is_idil_one && groups > 1) {
const int C_per_group = conv_params.C / groups;
const int O_per_group = conv_params.O / groups;
if (is_idil_one && (C_per_group <= 4 || C_per_group % 16 == 0) &&
if (C_per_group == 1 && O_per_group == 1 && is_kdil_one &&
conv_params.wS[0] <= 7 && conv_params.wS[1] <= 7 &&
conv_params.str[0] <= 2 && conv_params.str[1] <= 2 &&
conv_params.oS[0] % 8 == 0 && conv_params.oS[1] % 8 == 0 &&
conv_params.wt_strides[1] == conv_params.wS[1] &&
conv_params.C % 16 == 0 && conv_params.C == conv_params.O) {
return depthwise_conv_2D_gpu(s, d, in, wt, out, conv_params);
}
if ((C_per_group <= 4 || C_per_group % 16 == 0) &&
(O_per_group <= 16 || O_per_group % 16 == 0)) {
return implicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params);
} else {

View File

@@ -55,7 +55,10 @@ std::pair<MTL::Library*, NS::Error*> load_library_from_path(
}
#ifdef SWIFTPM_BUNDLE
MTL::Library* try_load_bundle(MTL::Device* device, NS::URL* url) {
MTL::Library* try_load_bundle(
MTL::Device* device,
NS::URL* url,
const std::string& lib_name) {
std::string bundle_path = std::string(url->fileSystemRepresentation()) + "/" +
SWIFTPM_BUNDLE + ".bundle";
auto bundle = NS::Bundle::alloc()->init(
@@ -63,8 +66,8 @@ MTL::Library* try_load_bundle(MTL::Device* device, NS::URL* url) {
if (bundle != nullptr) {
std::string resource_path =
std::string(bundle->resourceURL()->fileSystemRepresentation()) + "/" +
"default.metallib";
auto [lib, error] = load_library_from_path(device, resource_path.c_str());
lib_name + ".metallib" auto [lib, error] =
load_library_from_path(device, resource_path.c_str());
if (lib) {
return lib;
}
@@ -73,51 +76,124 @@ MTL::Library* try_load_bundle(MTL::Device* device, NS::URL* url) {
}
#endif
// Firstly, search for the metallib in the same path as this binary
std::pair<MTL::Library*, NS::Error*> load_colocated_library(
MTL::Device* device,
const std::string& lib_name) {
std::string lib_path = get_colocated_mtllib_path(lib_name);
if (lib_path.size() != 0) {
return load_library_from_path(device, lib_path.c_str());
}
return {nullptr, nullptr};
}
std::pair<MTL::Library*, NS::Error*> load_swiftpm_library(
MTL::Device* device,
const std::string& lib_name) {
#ifdef SWIFTPM_BUNDLE
MTL::Library* library =
try_load_bundle(device, NS::Bundle::mainBundle()->bundleURL(), lib_name);
if (library != nullptr) {
return {library, nullptr};
}
auto bundles = NS::Bundle::allBundles();
for (int i = 0, c = (int)bundles->count(); i < c; i++) {
auto bundle = reinterpret_cast<NS::Bundle*>(bundles->object(i));
library = try_load_bundle(device, bundle->resourceURL());
if (library != nullptr) {
return {library, nullptr};
}
}
#endif
return {nullptr, nullptr};
}
MTL::Library* load_default_library(MTL::Device* device) {
NS::Error *error1, *error2, *error3;
MTL::Library* lib;
// First try the colocated mlx.metallib
std::tie(lib, error1) = load_colocated_library(device, "mlx");
if (lib) {
return lib;
}
// Then try default.metallib in a SwiftPM bundle if we have one
std::tie(lib, error2) = load_swiftpm_library(device, "default");
if (lib) {
return lib;
}
// Finally try default_mtllib_path
std::tie(lib, error3) = load_library_from_path(device, default_mtllib_path);
if (!lib) {
std::ostringstream msg;
msg << "Failed to load the default metallib. ";
if (error1 != nullptr) {
msg << error1->localizedDescription()->utf8String() << " ";
}
if (error2 != nullptr) {
msg << error2->localizedDescription()->utf8String() << " ";
}
if (error3 != nullptr) {
msg << error3->localizedDescription()->utf8String() << " ";
}
throw std::runtime_error(msg.str());
}
return lib;
}
MTL::Library* load_library(
MTL::Device* device,
const std::string& lib_name = "mlx",
const char* lib_path = default_mtllib_path) {
// Firstly, search for the metallib in the same path as this binary
std::string first_path = get_colocated_mtllib_path(lib_name);
if (first_path.size() != 0) {
auto [lib, error] = load_library_from_path(device, first_path.c_str());
const std::string& lib_name,
const std::string& lib_path) {
// We have been given a path that ends in metallib so try to load it
if (lib_path.size() > 9 &&
std::equal(lib_path.end() - 9, lib_path.end(), ".metallib")) {
auto [lib, error] = load_library_from_path(device, lib_path.c_str());
if (!lib) {
std::ostringstream msg;
msg << "Failed to load the metallib from <" << lib_path << "> with error "
<< error->localizedDescription()->utf8String();
throw std::runtime_error(msg.str());
}
}
// We have been given a path so try to load from lib_path / lib_name.metallib
if (lib_path.size() > 0) {
std::string full_path = lib_path + "/" + lib_name + ".metallib";
auto [lib, error] = load_library_from_path(device, full_path.c_str());
if (!lib) {
std::ostringstream msg;
msg << "Failed to load the metallib from <" << full_path
<< "> with error " << error->localizedDescription()->utf8String();
throw std::runtime_error(msg.str());
}
}
// Try to load the colocated library
{
auto [lib, error] = load_colocated_library(device, lib_name);
if (lib) {
return lib;
}
}
#ifdef SWIFTPM_BUNDLE
// try to load from a swiftpm resource bundle -- scan the available bundles to
// find one that contains the named bundle
// Try to load the library from swiftpm
{
MTL::Library* library =
try_load_bundle(device, NS::Bundle::mainBundle()->bundleURL());
if (library != nullptr) {
return library;
}
auto bundles = NS::Bundle::allBundles();
for (int i = 0, c = (int)bundles->count(); i < c; i++) {
auto bundle = reinterpret_cast<NS::Bundle*>(bundles->object(i));
library = try_load_bundle(device, bundle->resourceURL());
if (library != nullptr) {
return library;
}
auto [lib, error] = load_swiftpm_library(device, lib_name);
if (lib) {
return lib;
}
}
#endif
// Couldn't find it so let's load it from default_mtllib_path
{
auto [lib, error] = load_library_from_path(device, lib_path);
if (!lib) {
std::ostringstream msg;
msg << error->localizedDescription()->utf8String() << "\n"
<< "Failed to load device library from <" << lib_path << ">"
<< " or <" << first_path << ">.";
throw std::runtime_error(msg.str());
}
return lib;
}
std::ostringstream msg;
msg << "Failed to load the metallib " << lib_name << ".metallib. "
<< "We attempted to load it from <" << get_colocated_mtllib_path(lib_name)
<< ">";
#ifdef SWIFTPM_BUNDLE
msg << " and from the Swift PM bundle.";
#endif
throw std::runtime_error(msg.str());
}
} // namespace
@@ -210,7 +286,7 @@ void CommandEncoder::barrier() {
Device::Device() {
auto pool = new_scoped_memory_pool();
device_ = load_device();
library_map_ = {{"mlx", load_library(device_)}};
library_map_ = {{"mlx", load_default_library(device_)}};
arch_ = std::string(device_->architecture()->name()->utf8String());
auto arch = arch_.back();
switch (arch) {

View File

@@ -189,15 +189,7 @@ class Device {
void register_library(
const std::string& lib_name,
const std::string& lib_path);
// Note, this should remain in the header so that it is not dynamically
// linked
void register_library(const std::string& lib_name) {
if (auto it = library_map_.find(lib_name); it == library_map_.end()) {
register_library(lib_name, get_colocated_mtllib_path(lib_name));
}
}
const std::string& lib_path = "");
MTL::Library* get_library(
const std::string& name,

View File

@@ -24,10 +24,6 @@ void Event::wait() {
}
}
void Event::signal() {
static_cast<MTL::SharedEvent*>(event_.get())->setSignaledValue(value());
}
void Event::wait(Stream stream) {
if (stream.device == Device::cpu) {
scheduler::enqueue(stream, [*this]() mutable { wait(); });
@@ -42,7 +38,9 @@ void Event::wait(Stream stream) {
void Event::signal(Stream stream) {
if (stream.device == Device::cpu) {
scheduler::enqueue(stream, [*this]() mutable { signal(); });
scheduler::enqueue(stream, [*this]() mutable {
static_cast<MTL::SharedEvent*>(event_.get())->setSignaledValue(value());
});
} else {
auto& d = metal::device(stream.device);
d.end_encoding(stream.index);

View File

@@ -356,20 +356,14 @@ void multi_upload_bluestein_fft(
bool inverse,
bool real,
FFTPlan& plan,
std::vector<array> copies,
std::vector<array>& copies,
const Stream& s) {
// TODO(alexbarron) Implement fused kernels for mutli upload bluestein's
// algorithm
int n = inverse ? out.shape(axis) : in.shape(axis);
auto [w_k, w_q] = compute_bluestein_constants(n, plan.bluestein_n);
// Broadcast w_q and w_k to the batch size
Strides b_strides(in.ndim(), 0);
b_strides[axis] = 1;
array w_k_broadcast({}, complex64, nullptr, {});
array w_q_broadcast({}, complex64, nullptr, {});
w_k_broadcast.copy_shared_buffer(w_k, b_strides, {}, w_k.data_size());
w_q_broadcast.copy_shared_buffer(w_q, b_strides, {}, w_q.data_size());
copies.push_back(w_k);
copies.push_back(w_q);
auto temp_shape = inverse ? out.shape() : in.shape();
array temp(temp_shape, complex64, nullptr, {});
@@ -378,13 +372,13 @@ void multi_upload_bluestein_fft(
if (real && !inverse) {
// Convert float32->complex64
copy_gpu(in, temp, CopyType::General, s);
copies.push_back(temp);
} else if (real && inverse) {
int back_offset = n % 2 == 0 ? 2 : 1;
auto slice_shape = in.shape();
slice_shape[axis] -= back_offset;
array slice_temp(slice_shape, complex64, nullptr, {});
array conj_temp(in.shape(), complex64, nullptr, {});
copies.push_back(slice_temp);
copies.push_back(conj_temp);
Shape rstarts(in.ndim(), 0);
@@ -394,19 +388,28 @@ void multi_upload_bluestein_fft(
unary_op_gpu({in}, conj_temp, "Conjugate", s);
slice_gpu(in, slice_temp, rstarts, rstrides, s);
concatenate_gpu({conj_temp, slice_temp}, temp, (int)axis, s);
copies.push_back(temp);
} else if (inverse) {
unary_op_gpu({in}, temp, "Conjugate", s);
copies.push_back(temp);
} else {
temp.copy_shared_buffer(in);
}
Strides b_strides(in.ndim(), 0);
b_strides[axis] = 1;
array w_k_broadcast(temp.shape(), complex64, nullptr, {});
w_k_broadcast.copy_shared_buffer(w_k, b_strides, {}, w_k.data_size());
binary_op_gpu({temp, w_k_broadcast}, temp1, "Multiply", s);
std::vector<std::pair<int, int>> pads;
auto padded_shape = out.shape();
padded_shape[axis] = plan.bluestein_n;
array pad_temp(padded_shape, complex64, nullptr, {});
pad_gpu(temp1, array(complex64_t{0.0f, 0.0f}), pad_temp, {(int)axis}, {0}, s);
auto zero = array(complex64_t{0.0f, 0.0f});
copies.push_back(zero);
pad_gpu(temp1, zero, pad_temp, {(int)axis}, {0}, s);
copies.push_back(pad_temp);
array pad_temp1(padded_shape, complex64, nullptr, {});
fft_op(
@@ -418,7 +421,10 @@ void multi_upload_bluestein_fft(
FourStepParams(),
/*inplace=*/false,
s);
copies.push_back(pad_temp1);
array w_q_broadcast(pad_temp1.shape(), complex64, nullptr, {});
w_q_broadcast.copy_shared_buffer(w_q, b_strides, {}, w_q.data_size());
binary_op_gpu_inplace({pad_temp1, w_q_broadcast}, pad_temp, "Multiply", s);
fft_op(
@@ -435,9 +441,11 @@ void multi_upload_bluestein_fft(
Shape starts(in.ndim(), 0);
Shape strides(in.ndim(), 1);
starts[axis] = plan.bluestein_n - offset - n;
slice_gpu(pad_temp1, temp, starts, strides, s);
binary_op_gpu_inplace({temp, w_k_broadcast}, temp1, "Multiply", s);
array temp2(temp_shape, complex64, nullptr, {});
slice_gpu(pad_temp1, temp2, starts, strides, s);
binary_op_gpu_inplace({temp2, w_k_broadcast}, temp1, "Multiply", s);
if (real && !inverse) {
Shape rstarts(in.ndim(), 0);
@@ -449,26 +457,21 @@ void multi_upload_bluestein_fft(
array temp_float(out.shape(), out.dtype(), nullptr, {});
copies.push_back(temp_float);
copies.push_back(inv_n);
copies.push_back(temp1);
copy_gpu(temp1, temp_float, CopyType::General, s);
binary_op_gpu({temp_float, inv_n}, out, "Multiply", s);
} else if (inverse) {
auto inv_n = array({1.0f / n}, {1}, complex64);
unary_op_gpu({temp1}, temp, "Conjugate", s);
binary_op_gpu({temp, inv_n}, out, "Multiply", s);
array temp3(temp_shape, complex64, nullptr, {});
unary_op_gpu({temp1}, temp3, "Conjugate", s);
binary_op_gpu({temp3, inv_n}, out, "Multiply", s);
copies.push_back(inv_n);
copies.push_back(temp1);
copies.push_back(temp3);
} else {
out.copy_shared_buffer(temp1);
}
copies.push_back(w_k);
copies.push_back(w_q);
copies.push_back(w_k_broadcast);
copies.push_back(w_q_broadcast);
copies.push_back(temp);
copies.push_back(temp1);
copies.push_back(pad_temp);
copies.push_back(pad_temp1);
}
void four_step_fft(
@@ -478,8 +481,9 @@ void four_step_fft(
bool inverse,
bool real,
FFTPlan& plan,
std::vector<array> copies,
const Stream& s) {
std::vector<array>& copies,
const Stream& s,
bool in_place) {
auto& d = metal::device(s.device);
if (plan.bluestein_n == -1) {
@@ -492,7 +496,14 @@ void four_step_fft(
in, temp, axis, inverse, real, four_step_params, /*inplace=*/false, s);
four_step_params.first_step = false;
fft_op(
temp, out, axis, inverse, real, four_step_params, /*inplace=*/false, s);
temp,
out,
axis,
inverse,
real,
four_step_params,
/*inplace=*/in_place,
s);
copies.push_back(temp);
} else {
multi_upload_bluestein_fft(in, out, axis, inverse, real, plan, copies, s);
@@ -574,7 +585,7 @@ void fft_op(
auto plan = plan_fft(n);
if (plan.four_step) {
four_step_fft(in, out, axis, inverse, real, plan, copies, s);
four_step_fft(in, out, axis, inverse, real, plan, copies, s, inplace);
d.add_temporaries(std::move(copies), s.index);
return;
}

View File

@@ -1,9 +0,0 @@
// Copyright © 2024 Apple Inc.
constexpr std::string_view arange_kernels = R"(
template [[host_name("{0}")]] [[kernel]] void arange<{1}>(
constant const {1}& start,
constant const {1}& step,
device {1}* out,
uint index [[thread_position_in_grid]]);
)";

View File

@@ -20,6 +20,7 @@ const char* copy();
const char* fft();
const char* gather_axis();
const char* hadamard();
const char* logsumexp();
const char* quantized();
const char* ternary();
const char* scan();
@@ -32,6 +33,7 @@ const char* gemm();
const char* steel_gemm_fused();
const char* steel_gemm_masked();
const char* steel_gemm_splitk();
const char* steel_gemm_gather();
const char* conv();
const char* steel_conv();
const char* steel_conv_general();

View File

@@ -1,23 +0,0 @@
// Copyright © 2024 Apple Inc.
constexpr std::string_view softmax_kernels = R"(
template [[host_name("block_{0}")]] [[kernel]] void
softmax_single_row<{1}, {2}>(
const device {1}* in,
device {1}* out,
constant int& axis_size,
uint gid [[thread_position_in_grid]],
uint _lid [[thread_position_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
template [[host_name("looped_{0}")]] [[kernel]] void
softmax_looped<{1}, {2}>(
const device {1}* in,
device {1}* out,
constant int& axis_size,
uint gid [[threadgroup_position_in_grid]],
uint lid [[thread_position_in_threadgroup]],
uint lsize [[threads_per_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
)";

View File

@@ -1,8 +1,6 @@
// Copyright © 2024 Apple Inc.
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/metal/jit/arange.h"
#include "mlx/backend/metal/jit/includes.h"
#include "mlx/backend/metal/jit/softmax.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/utils.h"
@@ -21,13 +19,11 @@ MTL::ComputePipelineState* get_arange_kernel(
const std::string& kernel_name,
const array& out) {
auto lib = d.get_library(kernel_name, [&]() {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::arange()
<< fmt::format(
arange_kernels,
kernel_name,
get_type_string(out.dtype()));
return kernel_source.str();
std::string kernel_source = metal::utils();
kernel_source += metal::arange();
kernel_source += get_template_definition(
kernel_name, "arange", get_type_string(out.dtype()));
return kernel_source;
});
return d.get_kernel(kernel_name, lib);
}
@@ -259,14 +255,34 @@ MTL::ComputePipelineState* get_softmax_kernel(
const array& out) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name, [&] {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::softmax()
<< fmt::format(
softmax_kernels,
lib_name,
get_type_string(out.dtype()),
get_type_string(precise ? float32 : out.dtype()));
return kernel_source.str();
std::string kernel_source = metal::utils();
auto in_type = get_type_string(out.dtype());
auto acc_type = get_type_string(precise ? float32 : out.dtype());
kernel_source += metal::softmax();
kernel_source += get_template_definition(
"block_" + lib_name, "softmax_single_row", in_type, acc_type);
kernel_source += get_template_definition(
"looped_" + lib_name, "softmax_looped", in_type, acc_type);
return kernel_source;
});
return d.get_kernel(kernel_name, lib);
}
MTL::ComputePipelineState* get_logsumexp_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& out) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name, [&] {
auto t_str = get_type_string(out.dtype());
std::string kernel_source;
kernel_source = metal::utils();
kernel_source += metal::logsumexp();
kernel_source +=
get_template_definition("block_" + lib_name, "logsumexp", t_str);
kernel_source += get_template_definition(
"looped_" + lib_name, "logsumexp_looped", t_str);
return kernel_source;
});
return d.get_kernel(kernel_name, lib);
}
@@ -568,6 +584,44 @@ MTL::ComputePipelineState* get_steel_gemm_masked_kernel(
return d.get_kernel(kernel_name, lib);
}
MTL::ComputePipelineState* get_steel_gemm_gather_kernel(
metal::Device& d,
const std::string& kernel_name,
const std::string& hash_name,
const metal::MTLFCList& func_consts,
const array& out,
bool transpose_a,
bool transpose_b,
int bm,
int bn,
int bk,
int wm,
int wn,
bool rhs) {
const auto& lib_name = kernel_name;
auto lib = d.get_library(lib_name, [&]() {
std::string kernel_source;
concatenate(
kernel_source,
metal::utils(),
metal::gemm(),
metal::steel_gemm_gather(),
get_template_definition(
lib_name,
rhs ? "gather_mm_rhs" : "gather_mm",
get_type_string(out.dtype()),
bm,
bn,
bk,
wm,
wn,
transpose_a,
transpose_b));
return kernel_source;
});
return d.get_kernel(kernel_name, lib, hash_name, func_consts);
}
MTL::ComputePipelineState* get_gemv_masked_kernel(
metal::Device& d,
const std::string& kernel_name,
@@ -698,4 +752,43 @@ MTL::ComputePipelineState* get_quantized_kernel(
return d.get_kernel(kernel_name, lib);
}
MTL::ComputePipelineState* get_gather_qmm_kernel(
metal::Device& d,
const std::string& kernel_name,
const std::string& hash_name,
const metal::MTLFCList& func_consts,
const array& x,
int group_size,
int bits,
int bm,
int bn,
int bk,
int wm,
int wn,
bool transpose) {
const auto& lib_name = kernel_name;
auto lib = d.get_library(lib_name, [&]() {
std::string kernel_source;
concatenate(
kernel_source,
metal::utils(),
metal::gemm(),
metal::quantized(),
get_template_definition(
lib_name,
"gather_qmm_rhs",
get_type_string(x.dtype()),
group_size,
bits,
bm,
bn,
bk,
wm,
wn,
transpose));
return kernel_source;
});
return d.get_kernel(kernel_name, lib, hash_name, func_consts);
}
} // namespace mlx::core

View File

@@ -59,6 +59,11 @@ MTL::ComputePipelineState* get_softmax_kernel(
bool precise,
const array& out);
MTL::ComputePipelineState* get_logsumexp_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& out);
MTL::ComputePipelineState* get_scan_kernel(
metal::Device& d,
const std::string& kernel_name,
@@ -155,6 +160,21 @@ MTL::ComputePipelineState* get_steel_gemm_masked_kernel(
bool mn_aligned,
bool k_aligned);
MTL::ComputePipelineState* get_steel_gemm_gather_kernel(
metal::Device& d,
const std::string& kernel_name,
const std::string& hash_name,
const metal::MTLFCList& func_consts,
const array& out,
bool transpose_a,
bool transpose_b,
int bm,
int bn,
int bk,
int wm,
int wn,
bool rhs);
MTL::ComputePipelineState* get_steel_conv_kernel(
metal::Device& d,
const std::string& kernel_name,
@@ -204,6 +224,21 @@ MTL::ComputePipelineState* get_quantized_kernel(
const std::string& kernel_name,
const std::string& template_def);
MTL::ComputePipelineState* get_gather_qmm_kernel(
metal::Device& d,
const std::string& kernel_name,
const std::string& hash_name,
const metal::MTLFCList& func_consts,
const array& x,
int group_size,
int bits,
int bm,
int bn,
int bk,
int wm,
int wn,
bool transpose);
// Create a GPU kernel template definition for JIT compilation
template <typename... Args>
std::string

View File

@@ -13,6 +13,10 @@ function(build_kernel_base TARGET SRCFILE DEPS)
if(MLX_METAL_DEBUG)
set(METAL_FLAGS ${METAL_FLAGS} -gline-tables-only -frecord-sources)
endif()
if(NOT CMAKE_OSX_DEPLOYMENT_TARGET STREQUAL "")
set(METAL_FLAGS ${METAL_FLAGS}
"-mmacosx-version-min=${CMAKE_OSX_DEPLOYMENT_TARGET}")
endif()
if(MLX_METAL_VERSION GREATER_EQUAL 310)
set(VERSION_INCLUDES
${PROJECT_SOURCE_DIR}/mlx/backend/metal/kernels/metal_3_1)
@@ -65,6 +69,7 @@ set(STEEL_HEADERS
steel/gemm/loader.h
steel/gemm/transforms.h
steel/gemm/kernels/steel_gemm_fused.h
steel/gemm/kernels/steel_gemm_gather.h
steel/gemm/kernels/steel_gemm_masked.h
steel/gemm/kernels/steel_gemm_splitk.h
steel/utils/type_traits.h
@@ -105,12 +110,14 @@ if(NOT MLX_METAL_JIT)
build_kernel(quantized quantized.h ${STEEL_HEADERS})
build_kernel(scan scan.h)
build_kernel(softmax softmax.h)
build_kernel(logsumexp logsumexp.h)
build_kernel(sort sort.h)
build_kernel(ternary ternary.h ternary_ops.h)
build_kernel(unary unary.h unary_ops.h)
build_kernel(steel/conv/kernels/steel_conv ${STEEL_HEADERS})
build_kernel(steel/conv/kernels/steel_conv_general ${STEEL_HEADERS})
build_kernel(steel/gemm/kernels/steel_gemm_fused ${STEEL_HEADERS})
build_kernel(steel/gemm/kernels/steel_gemm_gather ${STEEL_HEADERS})
build_kernel(steel/gemm/kernels/steel_gemm_masked ${STEEL_HEADERS})
build_kernel(steel/gemm/kernels/steel_gemm_splitk ${STEEL_HEADERS})
build_kernel(gemv_masked steel/utils.h)

View File

@@ -5,11 +5,7 @@
#include "mlx/backend/metal/kernels/arange.h"
#define instantiate_arange(tname, type) \
template [[host_name("arange" #tname)]] [[kernel]] void arange<type>( \
constant const type& start, \
constant const type& step, \
device type* out, \
uint index [[thread_position_in_grid]]);
instantiate_kernel("arange" #tname, arange, type)
instantiate_arange(uint8, uint8_t)
instantiate_arange(uint16, uint16_t)

View File

@@ -275,6 +275,128 @@ instantiate_naive_conv_2d_blocks(float32, float);
instantiate_naive_conv_2d_blocks(float16, half);
instantiate_naive_conv_2d_blocks(bfloat16, bfloat16_t);
///////////////////////////////////////////////////////////////////////////////
/// Depthwise convolution kernels
///////////////////////////////////////////////////////////////////////////////
constant int ker_h [[function_constant(00)]];
constant int ker_w [[function_constant(01)]];
constant int str_h [[function_constant(10)]];
constant int str_w [[function_constant(11)]];
constant int tgp_h [[function_constant(100)]];
constant int tgp_w [[function_constant(101)]];
constant bool do_flip [[function_constant(200)]];
constant int span_h = tgp_h * str_h + ker_h - 1;
constant int span_w = tgp_w * str_w + ker_w - 1;
constant int span_hw = span_h * span_w;
template <typename T>
[[kernel]] void depthwise_conv_2d(
const device T* in [[buffer(0)]],
const device T* wt [[buffer(1)]],
device T* out [[buffer(2)]],
const constant MLXConvParams<2>& params [[buffer(3)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint3 gid [[thread_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
constexpr int tc = 8;
constexpr int tw = 8;
constexpr int th = 4;
constexpr int c_per_thr = 8;
constexpr int TGH = th * 2 + 6;
constexpr int TGW = tw * 2 + 6;
constexpr int TGC = tc;
threadgroup T ins[TGH * TGW * TGC];
const int n_tgblocks_h = params.oS[0] / th;
const int n = tid.z / n_tgblocks_h;
const int tghid = tid.z % n_tgblocks_h;
const int oh = tghid * th + lid.z;
const int ow = gid.y;
const int c = gid.x;
in += n * params.in_strides[0];
// Load in
{
constexpr int n_threads = th * tw * tc;
const int tg_oh = (tghid * th) * str_h - params.pad[0];
const int tg_ow = (tid.y * tw) * str_w - params.pad[1];
const int tg_c = tid.x * tc;
const int thread_idx = simd_gid * 32 + simd_lid;
constexpr int thr_per_hw = tc / c_per_thr;
constexpr int hw_per_group = n_threads / thr_per_hw;
const int thr_c = thread_idx % thr_per_hw;
const int thr_hw = thread_idx / thr_per_hw;
for (int hw = thr_hw; hw < span_hw; hw += hw_per_group) {
const int h = hw / span_w;
const int w = hw % span_w;
const int ih = tg_oh + h;
const int iw = tg_ow + w;
const int in_s_offset = h * span_w * TGC + w * TGC;
if (ih >= 0 && ih < params.iS[0] && iw >= 0 && iw < params.iS[1]) {
const auto in_load =
in + ih * params.in_strides[1] + iw * params.in_strides[2] + tg_c;
MLX_MTL_PRAGMA_UNROLL
for (int cc = 0; cc < c_per_thr; ++cc) {
ins[in_s_offset + c_per_thr * thr_c + cc] =
in_load[c_per_thr * thr_c + cc];
}
} else {
MLX_MTL_PRAGMA_UNROLL
for (int cc = 0; cc < c_per_thr; ++cc) {
ins[in_s_offset + c_per_thr * thr_c + cc] = T(0);
}
}
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
wt += c * params.wt_strides[0];
const auto ins_ptr =
&ins[lid.z * str_h * span_w * TGC + lid.y * str_w * TGC + lid.x];
float o = 0.;
for (int h = 0; h < ker_h; ++h) {
for (int w = 0; w < ker_w; ++w) {
int wt_h = h;
int wt_w = w;
if (do_flip) {
wt_h = ker_h - h - 1;
wt_w = ker_w - w - 1;
}
auto inv = ins_ptr[h * span_w * TGC + w * TGC];
auto wtv = wt[wt_h * ker_w + wt_w];
o += inv * wtv;
}
}
threadgroup_barrier(mem_flags::mem_none);
out += n * params.out_strides[0] + oh * params.out_strides[1] +
ow * params.out_strides[2];
out[c] = static_cast<T>(o);
}
#define instantiate_depthconv2d(iname, itype) \
instantiate_kernel("depthwise_conv_2d_" #iname, depthwise_conv_2d, itype)
instantiate_depthconv2d(float32, float);
instantiate_depthconv2d(float16, half);
instantiate_depthconv2d(bfloat16, bfloat16_t);
///////////////////////////////////////////////////////////////////////////////
/// Winograd kernels
///////////////////////////////////////////////////////////////////////////////

View File

@@ -483,4 +483,4 @@ template <
perform_fft(fft_idx, &p, m, n, buf);
read_writer.write_strided(stride, overall_n);
}
}

View File

@@ -341,7 +341,7 @@ struct GEMVTKernel {
MLX_MTL_PRAGMA_UNROLL
for (int tm = 0; tm < TM; tm++) {
auto vc = float(v_coeff[tm]);
auto vc = static_cast<AccT>(v_coeff[tm]);
for (int tn = 0; tn < TN; tn++) {
inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn];
}

View File

@@ -493,71 +493,11 @@ template <typename T, int N_READS = RMS_N_READS>
}
// clang-format off
#define instantiate_layer_norm_single_row(name, itype) \
template [[host_name("layer_norm" #name)]] [[kernel]] void \
layer_norm_single_row<itype>( \
const device itype* x, \
const device itype* w, \
const device itype* b, \
device itype* out, \
constant float& eps, \
constant uint& axis_size, \
constant uint& w_stride, \
constant uint& b_stride, \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]); \
template [[host_name("vjp_layer_norm" #name)]] [[kernel]] void \
vjp_layer_norm_single_row<itype>( \
const device itype* x, \
const device itype* w, \
const device itype* g, \
device itype* gx, \
device itype* gw, \
constant float& eps, \
constant uint& axis_size, \
constant uint& w_stride, \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
#define instantiate_layer_norm_looped(name, itype) \
template [[host_name("layer_norm_looped" #name)]] [[kernel]] void \
layer_norm_looped<itype>( \
const device itype* x, \
const device itype* w, \
const device itype* b, \
device itype* out, \
constant float& eps, \
constant uint& axis_size, \
constant uint& w_stride, \
constant uint& b_stride, \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint lsize [[threads_per_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]); \
template [[host_name("vjp_layer_norm_looped" #name)]] [[kernel]] void \
vjp_layer_norm_looped<itype>( \
const device itype* x, \
const device itype* w, \
const device itype* g, \
device itype* gx, \
device itype* gb, \
constant float& eps, \
constant uint& axis_size, \
constant uint& w_stride, \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint lsize [[threads_per_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
#define instantiate_layer_norm(name, itype) \
instantiate_layer_norm_single_row(name, itype) \
instantiate_layer_norm_looped(name, itype)
#define instantiate_layer_norm(name, itype) \
instantiate_kernel("layer_norm" #name, layer_norm_single_row, itype) \
instantiate_kernel("vjp_layer_norm" #name, vjp_layer_norm_single_row, itype) \
instantiate_kernel("layer_norm_looped" #name, layer_norm_looped, itype) \
instantiate_kernel("vjp_layer_norm_looped" #name, vjp_layer_norm_looped, itype)
instantiate_layer_norm(float32, float)
instantiate_layer_norm(float16, half)

View File

@@ -0,0 +1,143 @@
// Copyright © 2025 Apple Inc.
template <typename T, typename AccT = float, int N_READS = 4>
[[kernel]] void logsumexp(
const device T* in,
device T* out,
constant int& axis_size,
uint gid [[threadgroup_position_in_grid]],
uint _lid [[thread_position_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
int lid = _lid;
constexpr int SIMD_SIZE = 32;
threadgroup AccT local_max[SIMD_SIZE];
threadgroup AccT local_normalizer[SIMD_SIZE];
AccT ld[N_READS];
in += gid * size_t(axis_size) + lid * N_READS;
if (lid * N_READS + N_READS <= axis_size) {
for (int i = 0; i < N_READS; i++) {
ld[i] = AccT(in[i]);
}
} else {
for (int i = 0; i < N_READS; i++) {
ld[i] =
((lid * N_READS + i) < axis_size) ? AccT(in[i]) : Limits<AccT>::min;
}
}
if (simd_group_id == 0) {
local_max[simd_lane_id] = Limits<AccT>::min;
local_normalizer[simd_lane_id] = 0;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Get the max
AccT maxval = Limits<AccT>::finite_min;
for (int i = 0; i < N_READS; i++) {
maxval = (maxval < ld[i]) ? ld[i] : maxval;
}
maxval = simd_max(maxval);
if (simd_lane_id == 0) {
local_max[simd_group_id] = maxval;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (simd_group_id == 0) {
maxval = simd_max(local_max[simd_lane_id]);
if (simd_lane_id == 0) {
local_max[0] = maxval;
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
maxval = local_max[0];
// Compute exp(x_i - maxval) and store the partial sums in local_normalizer
AccT normalizer = 0;
for (int i = 0; i < N_READS; i++) {
normalizer += fast::exp(ld[i] - maxval);
}
normalizer = simd_sum(normalizer);
if (simd_lane_id == 0) {
local_normalizer[simd_group_id] = normalizer;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (simd_group_id == 0) {
normalizer = simd_sum(local_normalizer[simd_lane_id]);
if (simd_lane_id == 0) {
out[gid] = isinf(maxval) ? T(maxval) : T(log(normalizer) + maxval);
}
}
}
template <typename T, typename AccT = float, int N_READS = 4>
[[kernel]] void logsumexp_looped(
const device T* in,
device T* out,
constant int& axis_size,
uint gid [[threadgroup_position_in_grid]],
uint lid [[thread_position_in_threadgroup]],
uint lsize [[threads_per_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
in += gid * size_t(axis_size);
constexpr int SIMD_SIZE = 32;
threadgroup AccT local_max[SIMD_SIZE];
threadgroup AccT local_normalizer[SIMD_SIZE];
// Get the max and the normalizer in one go
AccT prevmax;
AccT maxval = Limits<AccT>::finite_min;
AccT normalizer = 0;
for (int r = 0; r < static_cast<int>(ceildiv(axis_size, N_READS * lsize));
r++) {
int offset = r * lsize * N_READS + lid * N_READS;
AccT vals[N_READS];
if (offset + N_READS <= axis_size) {
for (int i = 0; i < N_READS; i++) {
vals[i] = AccT(in[offset + i]);
}
} else {
for (int i = 0; i < N_READS; i++) {
vals[i] = (offset + i < axis_size) ? AccT(in[offset + i])
: Limits<AccT>::finite_min;
}
}
prevmax = maxval;
for (int i = 0; i < N_READS; i++) {
maxval = (maxval < vals[i]) ? vals[i] : maxval;
}
normalizer *= fast::exp(prevmax - maxval);
for (int i = 0; i < N_READS; i++) {
normalizer += fast::exp(vals[i] - maxval);
}
}
prevmax = maxval;
maxval = simd_max(maxval);
normalizer *= fast::exp(prevmax - maxval);
normalizer = simd_sum(normalizer);
prevmax = maxval;
if (simd_lane_id == 0) {
local_max[simd_group_id] = maxval;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
maxval = simd_max(local_max[simd_lane_id]);
normalizer *= fast::exp(prevmax - maxval);
if (simd_lane_id == 0) {
local_normalizer[simd_group_id] = normalizer;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
normalizer = simd_sum(local_normalizer[simd_lane_id]);
if (simd_group_id == 0) {
normalizer = simd_sum(local_normalizer[simd_lane_id]);
if (simd_lane_id == 0) {
out[gid] = isinf(maxval) ? T(maxval) : T(log(normalizer) + maxval);
}
}
}

View File

@@ -0,0 +1,18 @@
// Copyright © 2023-2024 Apple Inc.
#include <metal_common>
#include <metal_simdgroup>
using namespace metal;
// clang-format off
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/logsumexp.h"
#define instantiate_logsumexp(name, itype) \
instantiate_kernel("block_logsumexp_" #name, logsumexp, itype) \
instantiate_kernel("looped_logsumexp_" #name, logsumexp_looped, itype) \
instantiate_logsumexp(float32, float)
instantiate_logsumexp(float16, half)
instantiate_logsumexp(bfloat16, bfloat16_t) // clang-format on

View File

@@ -3,6 +3,10 @@
#include <metal_simdgroup>
#include <metal_stdlib>
constant bool align_M [[function_constant(200)]];
constant bool align_N [[function_constant(201)]];
constant bool align_K [[function_constant(202)]];
using namespace metal;
#define MLX_MTL_CONST static constant constexpr const
@@ -586,13 +590,13 @@ METAL_FUNC void qmv_quad_impl(
// Adjust positions
const int in_vec_size_w = in_vec_size / pack_factor;
const int in_vec_size_g = in_vec_size / group_size;
const int out_row = tid.x * quads_per_simd * results_per_quadgroup + quad_gid;
const int out_row = tid.y * quads_per_simd * results_per_quadgroup + quad_gid;
w += out_row * in_vec_size_w + quad_lid * packs_per_thread;
scales += out_row * in_vec_size_g + quad_lid / scale_step_per_thread;
biases += out_row * in_vec_size_g + quad_lid / scale_step_per_thread;
x += tid.y * in_vec_size + quad_lid * values_per_thread;
y += tid.y * out_vec_size + out_row;
x += tid.x * in_vec_size + quad_lid * values_per_thread;
y += tid.x * out_vec_size + out_row;
U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
@@ -1686,26 +1690,26 @@ template <
}
template <typename T, int group_size, int bits>
[[kernel]] void bs_qmv_fast(
[[kernel]] void gather_qmv_fast(
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]],
const device T* x [[buffer(3)]],
device T* y [[buffer(4)]],
const constant int& in_vec_size [[buffer(5)]],
const constant int& out_vec_size [[buffer(6)]],
const constant int& x_batch_ndims [[buffer(7)]],
const constant int* x_shape [[buffer(8)]],
const constant int64_t* x_strides [[buffer(9)]],
const constant int& w_batch_ndims [[buffer(10)]],
const constant int* w_shape [[buffer(11)]],
const constant int64_t* w_strides [[buffer(12)]],
const constant int64_t* s_strides [[buffer(13)]],
const constant int64_t* b_strides [[buffer(14)]],
const constant int& batch_ndims [[buffer(15)]],
const constant int* batch_shape [[buffer(16)]],
const device uint32_t* lhs_indices [[buffer(17)]],
const device uint32_t* rhs_indices [[buffer(18)]],
const device uint32_t* lhs_indices [[buffer(4)]],
const device uint32_t* rhs_indices [[buffer(5)]],
device T* y [[buffer(6)]],
const constant int& in_vec_size [[buffer(7)]],
const constant int& out_vec_size [[buffer(8)]],
const constant int& x_batch_ndims [[buffer(9)]],
const constant int* x_shape [[buffer(10)]],
const constant int64_t* x_strides [[buffer(11)]],
const constant int& w_batch_ndims [[buffer(12)]],
const constant int* w_shape [[buffer(13)]],
const constant int64_t* w_strides [[buffer(14)]],
const constant int64_t* s_strides [[buffer(15)]],
const constant int64_t* b_strides [[buffer(16)]],
const constant int& batch_ndims [[buffer(17)]],
const constant int* batch_shape [[buffer(18)]],
const constant int64_t* lhs_strides [[buffer(19)]],
const constant int64_t* rhs_strides [[buffer(20)]],
uint3 tid [[threadgroup_position_in_grid]],
@@ -1748,26 +1752,26 @@ template <typename T, int group_size, int bits>
}
template <typename T, int group_size, int bits>
[[kernel]] void bs_qmv(
[[kernel]] void gather_qmv(
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]],
const device T* x [[buffer(3)]],
device T* y [[buffer(4)]],
const constant int& in_vec_size [[buffer(5)]],
const constant int& out_vec_size [[buffer(6)]],
const constant int& x_batch_ndims [[buffer(7)]],
const constant int* x_shape [[buffer(8)]],
const constant int64_t* x_strides [[buffer(9)]],
const constant int& w_batch_ndims [[buffer(10)]],
const constant int* w_shape [[buffer(11)]],
const constant int64_t* w_strides [[buffer(12)]],
const constant int64_t* s_strides [[buffer(13)]],
const constant int64_t* b_strides [[buffer(14)]],
const constant int& batch_ndims [[buffer(15)]],
const constant int* batch_shape [[buffer(16)]],
const device uint32_t* lhs_indices [[buffer(17)]],
const device uint32_t* rhs_indices [[buffer(18)]],
const device uint32_t* lhs_indices [[buffer(4)]],
const device uint32_t* rhs_indices [[buffer(5)]],
device T* y [[buffer(6)]],
const constant int& in_vec_size [[buffer(7)]],
const constant int& out_vec_size [[buffer(8)]],
const constant int& x_batch_ndims [[buffer(9)]],
const constant int* x_shape [[buffer(10)]],
const constant int64_t* x_strides [[buffer(11)]],
const constant int& w_batch_ndims [[buffer(12)]],
const constant int* w_shape [[buffer(13)]],
const constant int64_t* w_strides [[buffer(14)]],
const constant int64_t* s_strides [[buffer(15)]],
const constant int64_t* b_strides [[buffer(16)]],
const constant int& batch_ndims [[buffer(17)]],
const constant int* batch_shape [[buffer(18)]],
const constant int64_t* lhs_strides [[buffer(19)]],
const constant int64_t* rhs_strides [[buffer(20)]],
uint3 tid [[threadgroup_position_in_grid]],
@@ -1810,26 +1814,26 @@ template <typename T, int group_size, int bits>
}
template <typename T, int group_size, int bits>
[[kernel]] void bs_qvm(
[[kernel]] void gather_qvm(
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]],
const device T* x [[buffer(3)]],
device T* y [[buffer(4)]],
const constant int& in_vec_size [[buffer(5)]],
const constant int& out_vec_size [[buffer(6)]],
const constant int& x_batch_ndims [[buffer(7)]],
const constant int* x_shape [[buffer(8)]],
const constant int64_t* x_strides [[buffer(9)]],
const constant int& w_batch_ndims [[buffer(10)]],
const constant int* w_shape [[buffer(11)]],
const constant int64_t* w_strides [[buffer(12)]],
const constant int64_t* s_strides [[buffer(13)]],
const constant int64_t* b_strides [[buffer(14)]],
const constant int& batch_ndims [[buffer(15)]],
const constant int* batch_shape [[buffer(16)]],
const device uint32_t* lhs_indices [[buffer(17)]],
const device uint32_t* rhs_indices [[buffer(18)]],
const device uint32_t* lhs_indices [[buffer(4)]],
const device uint32_t* rhs_indices [[buffer(5)]],
device T* y [[buffer(6)]],
const constant int& in_vec_size [[buffer(7)]],
const constant int& out_vec_size [[buffer(8)]],
const constant int& x_batch_ndims [[buffer(9)]],
const constant int* x_shape [[buffer(10)]],
const constant int64_t* x_strides [[buffer(11)]],
const constant int& w_batch_ndims [[buffer(12)]],
const constant int* w_shape [[buffer(13)]],
const constant int64_t* w_strides [[buffer(14)]],
const constant int64_t* s_strides [[buffer(15)]],
const constant int64_t* b_strides [[buffer(16)]],
const constant int& batch_ndims [[buffer(17)]],
const constant int* batch_shape [[buffer(18)]],
const constant int64_t* lhs_strides [[buffer(19)]],
const constant int64_t* rhs_strides [[buffer(20)]],
uint3 tid [[threadgroup_position_in_grid]],
@@ -1879,27 +1883,27 @@ template <
const int BM = 32,
const int BK = 32,
const int BN = 32>
[[kernel]] void bs_qmm_t(
[[kernel]] void gather_qmm_t(
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]],
const device T* x [[buffer(3)]],
device T* y [[buffer(4)]],
const constant int& K [[buffer(5)]],
const constant int& N [[buffer(6)]],
const constant int& M [[buffer(7)]],
const constant int& x_batch_ndims [[buffer(8)]],
const constant int* x_shape [[buffer(9)]],
const constant int64_t* x_strides [[buffer(10)]],
const constant int& w_batch_ndims [[buffer(11)]],
const constant int* w_shape [[buffer(12)]],
const constant int64_t* w_strides [[buffer(13)]],
const constant int64_t* s_strides [[buffer(14)]],
const constant int64_t* b_strides [[buffer(15)]],
const constant int& batch_ndims [[buffer(16)]],
const constant int* batch_shape [[buffer(17)]],
const device uint32_t* lhs_indices [[buffer(18)]],
const device uint32_t* rhs_indices [[buffer(19)]],
const device uint32_t* lhs_indices [[buffer(4)]],
const device uint32_t* rhs_indices [[buffer(5)]],
device T* y [[buffer(6)]],
const constant int& K [[buffer(7)]],
const constant int& N [[buffer(8)]],
const constant int& M [[buffer(9)]],
const constant int& x_batch_ndims [[buffer(10)]],
const constant int* x_shape [[buffer(11)]],
const constant int64_t* x_strides [[buffer(12)]],
const constant int& w_batch_ndims [[buffer(13)]],
const constant int* w_shape [[buffer(14)]],
const constant int64_t* w_strides [[buffer(15)]],
const constant int64_t* s_strides [[buffer(16)]],
const constant int64_t* b_strides [[buffer(17)]],
const constant int& batch_ndims [[buffer(18)]],
const constant int* batch_shape [[buffer(19)]],
const constant int64_t* lhs_strides [[buffer(20)]],
const constant int64_t* rhs_strides [[buffer(21)]],
uint3 tid [[threadgroup_position_in_grid]],
@@ -1946,27 +1950,27 @@ template <
const int BM = 32,
const int BK = 32,
const int BN = 32>
[[kernel]] void bs_qmm_n(
[[kernel]] void gather_qmm_n(
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]],
const device T* x [[buffer(3)]],
device T* y [[buffer(4)]],
const constant int& K [[buffer(5)]],
const constant int& N [[buffer(6)]],
const constant int& M [[buffer(7)]],
const constant int& x_batch_ndims [[buffer(8)]],
const constant int* x_shape [[buffer(9)]],
const constant int64_t* x_strides [[buffer(10)]],
const constant int& w_batch_ndims [[buffer(11)]],
const constant int* w_shape [[buffer(12)]],
const constant int64_t* w_strides [[buffer(13)]],
const constant int64_t* s_strides [[buffer(14)]],
const constant int64_t* b_strides [[buffer(15)]],
const constant int& batch_ndims [[buffer(16)]],
const constant int* batch_shape [[buffer(17)]],
const device uint32_t* lhs_indices [[buffer(18)]],
const device uint32_t* rhs_indices [[buffer(19)]],
const device uint32_t* lhs_indices [[buffer(4)]],
const device uint32_t* rhs_indices [[buffer(5)]],
device T* y [[buffer(6)]],
const constant int& K [[buffer(7)]],
const constant int& N [[buffer(8)]],
const constant int& M [[buffer(9)]],
const constant int& x_batch_ndims [[buffer(10)]],
const constant int* x_shape [[buffer(11)]],
const constant int64_t* x_strides [[buffer(12)]],
const constant int& w_batch_ndims [[buffer(13)]],
const constant int* w_shape [[buffer(14)]],
const constant int64_t* w_strides [[buffer(15)]],
const constant int64_t* s_strides [[buffer(16)]],
const constant int64_t* b_strides [[buffer(17)]],
const constant int& batch_ndims [[buffer(18)]],
const constant int* batch_shape [[buffer(19)]],
const constant int64_t* lhs_strides [[buffer(20)]],
const constant int64_t* rhs_strides [[buffer(21)]],
uint3 tid [[threadgroup_position_in_grid]],
@@ -2007,6 +2011,289 @@ template <
w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
}
template <typename T, typename mma_t, typename loader_a_t, typename loader_b_t>
METAL_FUNC void gemm_loop_aligned(
threadgroup T* As,
threadgroup T* Bs,
thread mma_t& mma_op,
thread loader_a_t& loader_a,
thread loader_b_t& loader_b,
const int k_iterations) {
for (int k = 0; k < k_iterations; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup memory
loader_a.load_unsafe();
loader_b.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
}
template <
bool rows_aligned,
bool cols_aligned,
bool transpose,
typename T,
typename mma_t,
typename loader_a_t,
typename loader_b_t>
METAL_FUNC void gemm_loop_unaligned(
threadgroup T* As,
threadgroup T* Bs,
thread mma_t& mma_op,
thread loader_a_t& loader_a,
thread loader_b_t& loader_b,
const int k_iterations,
const short tgp_bm,
const short tgp_bn,
const short tgp_bk) {
for (int k = 0; k < k_iterations; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup memory
if (rows_aligned) {
loader_a.load_unsafe();
} else {
loader_a.load_safe(short2(tgp_bk, tgp_bm));
}
if (cols_aligned) {
loader_b.load_unsafe();
} else {
loader_b.load_safe(
transpose ? short2(tgp_bk, tgp_bn) : short2(tgp_bn, tgp_bk));
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
}
template <typename T, typename mma_t, typename loader_a_t, typename loader_b_t>
METAL_FUNC void gemm_loop_finalize(
threadgroup T* As,
threadgroup T* Bs,
thread mma_t& mma_op,
thread loader_a_t& loader_a,
thread loader_b_t& loader_b,
const short2 tile_a,
const short2 tile_b) {
loader_a.load_safe(tile_a);
loader_b.load_safe(tile_b);
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_op.mma(As, Bs);
}
template <
typename T,
int group_size,
int bits,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose>
[[kernel]] void gather_qmm_rhs(
const device T* x [[buffer(0)]],
const device uint32_t* w [[buffer(1)]],
const device T* scales [[buffer(2)]],
const device T* biases [[buffer(3)]],
const device uint32_t* indices [[buffer(4)]],
device T* y [[buffer(5)]],
const constant int& M [[buffer(6)]],
const constant int& N [[buffer(7)]],
const constant int& K [[buffer(8)]],
uint3 tid [[threadgroup_position_in_grid]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]]) {
constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
constexpr int BK_padded = (BK + 16 / sizeof(T));
constexpr int BN_padded = (BN + 16 / sizeof(T));
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3;
using mma_t = mlx::steel::BlockMMA<
T,
T,
BM,
BN,
BK,
WM,
WN,
false,
transpose,
BK_padded,
transpose ? BK_padded : BN_padded>;
using loader_x_t =
mlx::steel::BlockLoader<T, BM, BK, BK_padded, 1, WM * WN * SIMD_SIZE>;
using loader_w_t = QuantizedBlockLoader<
T,
transpose ? BN : BK,
transpose ? BK : BN,
transpose ? BK_padded : BN_padded,
transpose,
WM * WN * SIMD_SIZE,
group_size,
bits>;
threadgroup T Xs[BM * BK_padded];
threadgroup T Ws[transpose ? BN * BK_padded : BK * BN_padded];
// Compute the block
const int K_w = K * bytes_per_pack / pack_factor;
const int K_g = K / group_size;
const int N_w = N * bytes_per_pack / pack_factor;
const int N_g = N / group_size;
const int K_it = K / BK;
const size_t stride_w = transpose ? N * K_w : K * N_w;
const size_t stride_s = transpose ? N * K_g : K * N_g;
const int y_row = tid.y * BM;
const int y_col = tid.x * BN;
const size_t y_row_long = size_t(y_row);
const size_t y_col_long = size_t(y_col);
// Prepare threadgroup bounds
const short tgp_bm = align_M ? BM : short(min(BM, M - y_row));
const short tgp_bn = align_N ? BN : short(min(BN, N - y_col));
// Calculate the final tiles in the case that K is not aligned
const int k_remain = K - K_it * BK;
const short2 tile_x = short2(k_remain, tgp_bm);
const short2 tile_w =
transpose ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);
// Move x and output to the correct block
auto wl = (const device uint8_t*)w;
x += y_row_long * K;
y += y_row_long * N + y_col_long;
wl += transpose ? y_col_long * K_w : y_col * bytes_per_pack / pack_factor;
scales += transpose ? y_col_long * K_g : y_col / group_size;
biases += transpose ? y_col_long * K_g : y_col / group_size;
// Do as many matmuls as necessary
uint32_t index;
short offset;
uint32_t index_next = indices[y_row];
short offset_next = 0;
int n = 0;
while (n < tgp_bm) {
n++;
offset = offset_next;
index = index_next;
offset_next = tgp_bm;
for (; n < tgp_bm; n++) {
if (indices[y_row + n] != index) {
offset_next = n;
index_next = indices[y_row + n];
break;
}
}
threadgroup_barrier(mem_flags::mem_none);
// Prepare threadgroup mma operation
thread mma_t mma_op(simd_group_id, simd_lane_id);
// Prepare threadgroup loading operations
thread loader_x_t loader_x(x, K, Xs, simd_group_id, simd_lane_id);
thread loader_w_t loader_w(
wl + index * stride_w,
scales + index * stride_s,
biases + index * stride_s,
transpose ? K : N,
Ws,
simd_group_id,
simd_lane_id);
// Matrices are all aligned check nothing
if (align_M && align_N) {
gemm_loop_aligned(Xs, Ws, mma_op, loader_x, loader_w, K_it);
if (!align_K) {
threadgroup_barrier(mem_flags::mem_threadgroup);
gemm_loop_finalize(Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w);
}
// Store results to device memory
if (offset_next - offset == BM) {
mma_op.store_result(y, N);
} else {
mma_op.store_result_slice(
y, N, short2(0, offset), short2(BN, offset_next));
}
} else {
// Tile aligned so check outside of the hot loop
if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) {
gemm_loop_aligned(Xs, Ws, mma_op, loader_x, loader_w, K_it);
if (!align_K) {
threadgroup_barrier(mem_flags::mem_threadgroup);
gemm_loop_finalize(
Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w);
}
// Store results to device memory
if (offset_next - offset == BM) {
mma_op.store_result(y, N);
} else {
mma_op.store_result_slice(
y, N, short2(0, offset), short2(BN, offset_next));
}
}
// Tile partially aligned check rows
else if (align_N || tgp_bn == BN) {
gemm_loop_unaligned<false, true, transpose>(
Xs, Ws, mma_op, loader_x, loader_w, K_it, tgp_bm, tgp_bn, BK);
if (!align_K) {
threadgroup_barrier(mem_flags::mem_threadgroup);
gemm_loop_finalize(
Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w);
}
mma_op.store_result_slice(
y, N, short2(0, offset), short2(BN, offset_next));
}
// Tile partially aligned check cols
else if (align_M || tgp_bm == BM) {
gemm_loop_unaligned<true, false, transpose>(
Xs, Ws, mma_op, loader_x, loader_w, K_it, tgp_bm, tgp_bn, BK);
if (!align_K) {
threadgroup_barrier(mem_flags::mem_threadgroup);
gemm_loop_finalize(
Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w);
}
mma_op.store_result_slice(
y, N, short2(0, offset), short2(tgp_bn, offset_next));
}
// Nothing aligned so check both rows and cols
else {
gemm_loop_unaligned<false, false, transpose>(
Xs, Ws, mma_op, loader_x, loader_w, K_it, tgp_bm, tgp_bn, BK);
if (!align_K) {
threadgroup_barrier(mem_flags::mem_threadgroup);
gemm_loop_finalize(
Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w);
}
mma_op.store_result_slice(
y, N, short2(0, offset), short2(tgp_bn, offset_next));
}
}
}
}
template <typename T, const int group_size, const int bits>
[[kernel]] void affine_quantize(
const device T* w [[buffer(0)]],

View File

@@ -60,6 +60,20 @@
bits, \
split_k)
#define instantiate_gather_qmm_rhs(func, name, type, group_size, bits, bm, bn, bk, wm, wn, transpose) \
instantiate_kernel( \
#name "_" #type "_gs_" #group_size "_b_" #bits "_bm_" #bm "_bn_" #bn "_bk_" #bk "_wm_" #wm "_wn_" #wn, \
func, \
type, \
group_size, \
bits, \
bm, \
bn, \
bk, \
wm, \
wn, \
transpose)
#define instantiate_quantized_batched_wrap(name, type, group_size, bits) \
instantiate_quantized_batched(name, type, group_size, bits, 1) \
instantiate_quantized_batched(name, type, group_size, bits, 0)
@@ -73,14 +87,14 @@
#define instantiate_quantized_all_single(type, group_size, bits) \
instantiate_quantized(affine_quantize, type, group_size, bits) \
instantiate_quantized(affine_dequantize, type, group_size, bits) \
instantiate_quantized(bs_qmv_fast, type, group_size, bits) \
instantiate_quantized(bs_qmv, type, group_size, bits) \
instantiate_quantized(bs_qvm, type, group_size, bits) \
instantiate_quantized(bs_qmm_n, type, group_size, bits)
instantiate_quantized(gather_qmv_fast, type, group_size, bits) \
instantiate_quantized(gather_qmv, type, group_size, bits) \
instantiate_quantized(gather_qvm, type, group_size, bits) \
instantiate_quantized(gather_qmm_n, type, group_size, bits)
#define instantiate_quantized_all_aligned(type, group_size, bits) \
instantiate_quantized_aligned(bs_qmm_t, type, group_size, bits, true) \
instantiate_quantized_aligned(bs_qmm_t, type, group_size, bits, false) \
instantiate_quantized_aligned(gather_qmm_t, type, group_size, bits, true) \
instantiate_quantized_aligned(gather_qmm_t, type, group_size, bits, false) \
instantiate_quantized_aligned_batched(qmm_t, type, group_size, bits, true, 1) \
instantiate_quantized_aligned_batched(qmm_t, type, group_size, bits, true, 0) \
instantiate_quantized_aligned_batched(qmm_t, type, group_size, bits, false, 1) \
@@ -96,12 +110,17 @@
instantiate_quantized_split_k(qvm_split_k, type, group_size, bits, 8) \
instantiate_quantized_split_k(qvm_split_k, type, group_size, bits, 32)
#define instantiate_quantized_all_rhs(type, group_size, bits) \
instantiate_gather_qmm_rhs(gather_qmm_rhs, gather_qmm_rhs_nt, type, group_size, bits, 16, 32, 32, 1, 2, true) \
instantiate_gather_qmm_rhs(gather_qmm_rhs, gather_qmm_rhs_nn, type, group_size, bits, 16, 32, 32, 1, 2, false)
#define instantiate_quantized_funcs(type, group_size, bits) \
instantiate_quantized_all_single(type, group_size, bits) \
instantiate_quantized_all_batched(type, group_size, bits) \
instantiate_quantized_all_aligned(type, group_size, bits) \
instantiate_quantized_all_quad(type, group_size, bits) \
instantiate_quantized_all_splitk(type, group_size, bits)
instantiate_quantized_all_splitk(type, group_size, bits) \
instantiate_quantized_all_rhs(type, group_size, bits)
#define instantiate_quantized_types(group_size, bits) \
instantiate_quantized_funcs(float, group_size, bits) \

View File

@@ -380,69 +380,11 @@ template <typename T, int N_READS = RMS_N_READS>
}
// clang-format off
#define instantiate_rms_single_row(name, itype) \
template [[host_name("rms" #name)]] [[kernel]] void \
rms_single_row<itype>( \
const device itype* x, \
const device itype* w, \
device itype* out, \
constant float& eps, \
constant uint& axis_size, \
constant uint& w_stride, \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]); \
\
template [[host_name("vjp_rms" #name)]] [[kernel]] void \
vjp_rms_single_row<itype>( \
const device itype* x, \
const device itype* w, \
const device itype* g, \
device itype* gx, \
device itype* gw, \
constant float& eps, \
constant uint& axis_size, \
constant uint& w_stride, \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
#define instantiate_rms_looped(name, itype) \
template [[host_name("rms_looped" #name)]] [[kernel]] void \
rms_looped<itype>( \
const device itype* x, \
const device itype* w, \
device itype* out, \
constant float& eps, \
constant uint& axis_size, \
constant uint& w_stride, \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint lsize [[threads_per_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]); \
\
template [[host_name("vjp_rms_looped" #name)]] [[kernel]] void \
vjp_rms_looped<itype>( \
const device itype* x, \
const device itype* w, \
const device itype* g, \
device itype* gx, \
device itype* gw, \
constant float& eps, \
constant uint& axis_size, \
constant uint& w_stride, \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint lsize [[threads_per_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
#define instantiate_rms(name, itype) \
instantiate_rms_single_row(name, itype) \
instantiate_rms_looped(name, itype)
#define instantiate_rms(name, itype) \
instantiate_kernel("rms" #name, rms_single_row, itype) \
instantiate_kernel("vjp_rms" #name, vjp_rms_single_row, itype) \
instantiate_kernel("rms_looped" #name, rms_looped, itype) \
instantiate_kernel("vjp_rms_looped" #name, vjp_rms_looped, itype)
instantiate_rms(float32, float)
instantiate_rms(float16, half)

View File

@@ -1,11 +1,11 @@
#include <metal_stdlib>
#include "mlx/backend/metal/kernels/sdpa_vector.h"
// clang-format off
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/sdpa_vector.h"
using namespace metal;
// clang-format off
// SDPA vector instantiations
#define instantiate_sdpa_vector_aggregation(type, value_dim) \
instantiate_kernel( \
@@ -32,9 +32,11 @@ using namespace metal;
instantiate_sdpa_vector(type, 64, 64) \
instantiate_sdpa_vector(type, 96, 96) \
instantiate_sdpa_vector(type, 128, 128) \
instantiate_sdpa_vector(type, 256, 256) \
instantiate_sdpa_vector_aggregation(type, 64) \
instantiate_sdpa_vector_aggregation(type, 96) \
instantiate_sdpa_vector_aggregation(type, 128)
instantiate_sdpa_vector_aggregation(type, 128) \
instantiate_sdpa_vector_aggregation(type, 256)
instantiate_sdpa_vector_heads(float)
instantiate_sdpa_vector_heads(bfloat16_t)

View File

@@ -2,6 +2,8 @@
#pragma once
#include "mlx/backend/metal/kernels/binary_ops.h"
#define DEFINE_SIMD_SCAN() \
template <typename T, metal::enable_if_t<sizeof(T) < 8, bool> = true> \
T simd_scan(T val) { \
@@ -139,6 +141,29 @@ struct CumMin {
}
};
template <typename U>
struct CumLogaddexp {
static constexpr constant U init = Limits<U>::min;
template <typename T>
U operator()(U a, T b) {
return LogAddExp{}(a, static_cast<U>(b));
}
U simd_scan(U x) {
for (int i = 1; i <= 16; i *= 2) {
U other = simd_shuffle_and_fill_up(x, init, i);
x = LogAddExp{}(x, other);
}
return x;
}
U simd_exclusive_scan(U x) {
x = simd_scan(x);
return simd_shuffle_and_fill_up(x, init, 1);
}
};
template <typename T, typename U, int N_READS, bool reverse>
inline void load_unsafe(U values[N_READS], const device T* input) {
if (reverse) {

View File

@@ -101,4 +101,7 @@ instantiate_scan_helper(min_int64_int64, int64_t, int64_t, CumMi
instantiate_scan_helper(min_float16_float16, half, half, CumMin, 4)
instantiate_scan_helper(min_float32_float32, float, float, CumMin, 4)
instantiate_scan_helper(min_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumMin, 4)
instantiate_scan_helper(min_complex64_complex64, complex64_t, complex64_t, CumMin, 2) // clang-format on
instantiate_scan_helper(min_complex64_complex64, complex64_t, complex64_t, CumMin, 2)
instantiate_scan_helper(logaddexp_float16_float16, half, half, CumLogaddexp, 4)
instantiate_scan_helper(logaddexp_float32_float32, float, float, CumLogaddexp, 4)
instantiate_scan_helper(logaddexp_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumLogaddexp, 4) // clang-format on

View File

@@ -6,6 +6,9 @@ using namespace metal;
constant bool has_mask [[function_constant(20)]];
constant bool query_transposed [[function_constant(21)]];
constant bool do_causal [[function_constant(22)]];
constant bool bool_mask [[function_constant(23)]];
constant bool float_mask [[function_constant(24)]];
template <typename T, int D, int V = D>
[[kernel]] void sdpa_vector(
@@ -13,17 +16,21 @@ template <typename T, int D, int V = D>
const device T* keys [[buffer(1)]],
const device T* values [[buffer(2)]],
device T* out [[buffer(3)]],
const constant int& gqa_factor,
const constant int& N,
const constant size_t& k_head_stride,
const constant size_t& k_seq_stride,
const constant size_t& v_head_stride,
const constant size_t& v_seq_stride,
const constant float& scale,
const device bool* mask [[function_constant(has_mask)]],
const constant int& mask_kv_seq_stride [[function_constant(has_mask)]],
const constant int& mask_q_seq_stride [[function_constant(has_mask)]],
const constant int& mask_head_stride [[function_constant(has_mask)]],
const constant int& gqa_factor [[buffer(4)]],
const constant int& N [[buffer(5)]],
const constant size_t& k_head_stride [[buffer(6)]],
const constant size_t& k_seq_stride [[buffer(7)]],
const constant size_t& v_head_stride [[buffer(8)]],
const constant size_t& v_seq_stride [[buffer(9)]],
const constant float& scale [[buffer(10)]],
const device bool* bmask [[buffer(11), function_constant(bool_mask)]],
const device T* fmask [[buffer(12), function_constant(float_mask)]],
const constant int& mask_kv_seq_stride
[[buffer(13), function_constant(has_mask)]],
const constant int& mask_q_seq_stride
[[buffer(14), function_constant(has_mask)]],
const constant int& mask_head_stride
[[buffer(15), function_constant(has_mask)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 tpg [[threadgroups_per_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
@@ -57,8 +64,12 @@ template <typename T, int D, int V = D>
simd_lid * qk_per_thread;
values += kv_head_idx * v_head_stride + simd_gid * v_seq_stride +
simd_lid * v_per_thread;
if (has_mask) {
mask += head_idx * mask_head_stride + simd_gid * mask_kv_seq_stride +
if (bool_mask) {
bmask += head_idx * mask_head_stride + simd_gid * mask_kv_seq_stride +
q_seq_idx * mask_q_seq_stride;
}
if (float_mask) {
fmask += head_idx * mask_head_stride + simd_gid * mask_kv_seq_stride +
q_seq_idx * mask_q_seq_stride;
}
@@ -77,7 +88,13 @@ template <typename T, int D, int V = D>
// For each key
for (int i = simd_gid; i < N; i += BN) {
if (!has_mask || mask[0]) {
bool use_key = true;
if (do_causal) {
use_key = i <= (N - int(tpg.y) + int(q_seq_idx));
} else if (bool_mask) {
use_key = bmask[0];
}
if (use_key) {
// Read the key
for (int j = 0; j < qk_per_thread; j++) {
k[j] = keys[j];
@@ -89,6 +106,9 @@ template <typename T, int D, int V = D>
score += q[j] * k[j];
}
score = simd_sum(score);
if (float_mask) {
score += max(Limits<U>::finite_min, static_cast<U>(fmask[0]));
}
// Update the accumulators
U new_max = max(max_score, score);
@@ -107,8 +127,11 @@ template <typename T, int D, int V = D>
// Move the pointers to the next kv
keys += inner_k_stride;
values += inner_v_stride;
if (has_mask) {
mask += BN * mask_kv_seq_stride;
if (bool_mask) {
bmask += BN * mask_kv_seq_stride;
}
if (float_mask) {
fmask += BN * mask_kv_seq_stride;
}
}
@@ -149,17 +172,21 @@ template <typename T, int D, int V = D>
device float* out [[buffer(3)]],
device float* sums [[buffer(4)]],
device float* maxs [[buffer(5)]],
const constant int& gqa_factor,
const constant int& N,
const constant size_t& k_head_stride,
const constant size_t& k_seq_stride,
const constant size_t& v_head_stride,
const constant size_t& v_seq_stride,
const constant float& scale,
const device bool* mask [[function_constant(has_mask)]],
const constant int& mask_kv_seq_stride [[function_constant(has_mask)]],
const constant int& mask_q_seq_stride [[function_constant(has_mask)]],
const constant int& mask_head_stride [[function_constant(has_mask)]],
const constant int& gqa_factor [[buffer(6)]],
const constant int& N [[buffer(7)]],
const constant size_t& k_head_stride [[buffer(8)]],
const constant size_t& k_seq_stride [[buffer(9)]],
const constant size_t& v_head_stride [[buffer(10)]],
const constant size_t& v_seq_stride [[buffer(11)]],
const constant float& scale [[buffer(12)]],
const device bool* bmask [[buffer(13), function_constant(bool_mask)]],
const device T* fmask [[buffer(14), function_constant(float_mask)]],
const constant int& mask_kv_seq_stride
[[buffer(15), function_constant(has_mask)]],
const constant int& mask_q_seq_stride
[[buffer(16), function_constant(has_mask)]],
const constant int& mask_head_stride
[[buffer(17), function_constant(has_mask)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 tpg [[threadgroups_per_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
@@ -197,8 +224,13 @@ template <typename T, int D, int V = D>
values += kv_head_idx * v_head_stride +
(block_idx * BN + simd_gid) * v_seq_stride + simd_lid * v_per_thread;
out += o_offset * blocks * V + block_idx * V + simd_lid * v_per_thread;
if (has_mask) {
mask += head_idx * mask_head_stride +
if (bool_mask) {
bmask += head_idx * mask_head_stride +
(block_idx * BN + simd_gid) * mask_kv_seq_stride +
q_seq_idx * mask_q_seq_stride;
}
if (float_mask) {
fmask += head_idx * mask_head_stride +
(block_idx * BN + simd_gid) * mask_kv_seq_stride +
q_seq_idx * mask_q_seq_stride;
}
@@ -218,7 +250,13 @@ template <typename T, int D, int V = D>
// For each key
for (int i = block_idx * BN + simd_gid; i < N; i += blocks * BN) {
if (!has_mask || mask[0]) {
bool use_key = true;
if (do_causal) {
use_key = i <= (N - int(tpg.y) + int(q_seq_idx));
} else if (bool_mask) {
use_key = bmask[0];
}
if (use_key) {
// Read the key
for (int i = 0; i < qk_per_thread; i++) {
k[i] = keys[i];
@@ -230,6 +268,9 @@ template <typename T, int D, int V = D>
score += q[i] * k[i];
}
score = simd_sum(score);
if (float_mask) {
score += fmask[0];
}
// Update the accumulators
U new_max = max(max_score, score);
@@ -248,8 +289,11 @@ template <typename T, int D, int V = D>
// Move the pointers to the next kv
keys += blocks * inner_k_stride;
values += blocks * inner_v_stride;
if (has_mask) {
mask += BN * blocks * mask_kv_seq_stride;
if (bool_mask) {
bmask += BN * blocks * mask_kv_seq_stride;
}
if (float_mask) {
fmask += BN * blocks * mask_kv_seq_stride;
}
}

View File

@@ -9,47 +9,13 @@ using namespace metal;
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/softmax.h"
#define instantiate_softmax(name, itype) \
template [[host_name("block_softmax_" #name)]] [[kernel]] void \
softmax_single_row<itype>( \
const device itype* in, \
device itype* out, \
constant int& axis_size, \
uint gid [[thread_position_in_grid]], \
uint _lid [[thread_position_in_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]); \
template [[host_name("looped_softmax_" #name)]] [[kernel]] void \
softmax_looped<itype>( \
const device itype* in, \
device itype* out, \
constant int& axis_size, \
uint gid [[threadgroup_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint lsize [[threads_per_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
#define instantiate_softmax(name, itype) \
instantiate_kernel("block_softmax_" #name, softmax_single_row, itype) \
instantiate_kernel("looped_softmax_" #name, softmax_looped, itype)
#define instantiate_softmax_precise(name, itype) \
template [[host_name("block_softmax_precise_" #name)]] [[kernel]] void \
softmax_single_row<itype, float>( \
const device itype* in, \
device itype* out, \
constant int& axis_size, \
uint gid [[thread_position_in_grid]], \
uint _lid [[thread_position_in_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]); \
template [[host_name("looped_softmax_precise_" #name)]] [[kernel]] void \
softmax_looped<itype, float>( \
const device itype* in, \
device itype* out, \
constant int& axis_size, \
uint gid [[threadgroup_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint lsize [[threads_per_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
#define instantiate_softmax_precise(name, itype) \
instantiate_kernel("block_softmax_precise_" #name, softmax_single_row, itype, float) \
instantiate_kernel("looped_softmax_precise_" #name, softmax_looped, itype, float)
instantiate_softmax(float32, float)
instantiate_softmax(float16, half)

View File

@@ -229,7 +229,7 @@ template <
// Init to -Inf
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kRowsPT; ++i) {
max_score[i] = Limits<AccumType>::min;
max_score[i] = Limits<AccumType>::finite_min;
}
int kb_lim = params->NK;
@@ -237,6 +237,7 @@ template <
if (do_causal) {
int q_max = (tid.x + 1) * BQ + params->qL_off;
kb_lim = (q_max + BK - 1) / BK;
kb_lim = min(params->NK, kb_lim);
}
// Loop over KV seq length
@@ -272,7 +273,7 @@ template <
if (!align_K && kb == (params->NK_aligned)) {
using stile_t = decltype(Stile);
using selem_t = typename stile_t::elem_type;
constexpr auto neg_inf = -metal::numeric_limits<selem_t>::infinity();
constexpr auto neg_inf = Limits<selem_t>::finite_min;
STEEL_PRAGMA_UNROLL
for (short i = 0; i < stile_t::kTileRows; i++) {
@@ -290,10 +291,10 @@ template <
}
// Mask out if causal
if (do_causal && kb >= (kb_lim - (BQ + BK - 1) / BK - int(!align_K))) {
if (do_causal && kb >= (kb_lim - ((BQ + BK - 1) / BK) - int(!align_K))) {
using stile_t = decltype(Stile);
using selem_t = typename stile_t::elem_type;
constexpr auto neg_inf = -metal::numeric_limits<selem_t>::infinity();
constexpr auto neg_inf = Limits<selem_t>::finite_min;
STEEL_PRAGMA_UNROLL
for (short i = 0; i < stile_t::kTileRows; i++) {
@@ -316,7 +317,7 @@ template <
if (has_mask) {
using stile_t = decltype(Stile);
using selem_t = typename stile_t::elem_type;
constexpr auto neg_inf = -metal::numeric_limits<selem_t>::infinity();
constexpr auto neg_inf = Limits<selem_t>::finite_min;
constexpr bool is_bool = is_same_v<MaskType, bool>;
using melem_t = typename metal::conditional_t<is_bool, bool, selem_t>;

View File

@@ -15,10 +15,6 @@ constant bool align_M [[function_constant(200)]];
constant bool align_N [[function_constant(201)]];
constant bool align_K [[function_constant(202)]];
constant bool do_gather [[function_constant(300)]];
constant bool gather_bias = do_gather && use_out_source;
// clang-format off
template <
typename T,
@@ -39,12 +35,6 @@ template <
const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]],
const constant int* batch_shape [[buffer(6)]],
const constant int64_t* batch_strides [[buffer(7)]],
const constant uint32_t* lhs_indices [[buffer(10), function_constant(do_gather)]],
const constant uint32_t* rhs_indices [[buffer(11), function_constant(do_gather)]],
const constant uint32_t* C_indices [[buffer(12), function_constant(gather_bias)]],
const constant int* operand_shape [[buffer(13), function_constant(do_gather)]],
const constant int64_t* operand_strides [[buffer(14), function_constant(do_gather)]],
const constant packed_int3& operand_batch_ndim [[buffer(15), function_constant(do_gather)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
@@ -81,84 +71,26 @@ template <
}
// Adjust for batch
if (has_batch) {
const constant auto* A_bstrides = batch_strides;
const constant auto* B_bstrides = batch_strides + params->batch_ndim;
// Handle gather
if (do_gather) {
// Read indices
uint32_t indx_A, indx_B, indx_C;
ulong2 batch_offsets = elem_to_loc_broadcast(
tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim);
if (has_batch) {
const constant auto* indx_A_bstrides = batch_strides;
const constant auto* indx_B_bstrides = batch_strides + params->batch_ndim;
ulong2 indx_offsets = elem_to_loc_broadcast(
tid.z,
batch_shape,
indx_A_bstrides,
indx_B_bstrides,
params->batch_ndim);
indx_A = lhs_indices[indx_offsets.x];
indx_B = rhs_indices[indx_offsets.y];
if (use_out_source) {
const constant auto* indx_C_bstrides =
indx_B_bstrides + params->batch_ndim;
auto indx_offset_C = elem_to_loc(
tid.z, batch_shape, indx_C_bstrides, params->batch_ndim);
indx_C = C_indices[indx_offset_C];
}
} else {
indx_A = lhs_indices[params->batch_stride_a * tid.z];
indx_B = rhs_indices[params->batch_stride_b * tid.z];
if (use_out_source) {
indx_C = C_indices[addmm_params->batch_stride_c * tid.z];
}
}
// Translate indices to offsets
int batch_ndim_A = operand_batch_ndim.x;
const constant int* batch_shape_A = operand_shape;
const constant auto* batch_strides_A = operand_strides;
A += elem_to_loc(indx_A, batch_shape_A, batch_strides_A, batch_ndim_A);
int batch_ndim_B = operand_batch_ndim.y;
const constant int* batch_shape_B = batch_shape_A + batch_ndim_A;
const constant auto* batch_strides_B = batch_strides_A + batch_ndim_A;
B += elem_to_loc(indx_B, batch_shape_B, batch_strides_B, batch_ndim_B);
A += batch_offsets.x;
B += batch_offsets.y;
if (use_out_source) {
int batch_ndim_C = operand_batch_ndim.z;
const constant int* batch_shape_C = batch_shape_B + batch_ndim_B;
const constant auto* batch_strides_C = batch_strides_B + batch_ndim_B;
C += elem_to_loc(indx_C, batch_shape_C, batch_strides_C, batch_ndim_C);
const constant auto* C_bstrides = B_bstrides + params->batch_ndim;
C += elem_to_loc(tid.z, batch_shape, C_bstrides, params->batch_ndim);
}
} else {
A += params->batch_stride_a * tid.z;
B += params->batch_stride_b * tid.z;
}
// Handle regular batch
else {
if (has_batch) {
const constant auto* A_bstrides = batch_strides;
const constant auto* B_bstrides = batch_strides + params->batch_ndim;
ulong2 batch_offsets = elem_to_loc_broadcast(
tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim);
A += batch_offsets.x;
B += batch_offsets.y;
if (use_out_source) {
const constant auto* C_bstrides = B_bstrides + params->batch_ndim;
C += elem_to_loc(tid.z, batch_shape, C_bstrides, params->batch_ndim);
}
} else {
A += params->batch_stride_a * tid.z;
B += params->batch_stride_b * tid.z;
if (use_out_source) {
C += addmm_params->batch_stride_c * tid.z;
}
if (use_out_source) {
C += addmm_params->batch_stride_c * tid.z;
}
}

View File

@@ -0,0 +1,459 @@
// Copyright © 2024 Apple Inc.
using namespace mlx::steel;
constant bool has_batch [[function_constant(10)]];
constant bool align_M [[function_constant(200)]];
constant bool align_N [[function_constant(201)]];
constant bool align_K [[function_constant(202)]];
template <
typename T,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_a,
bool transpose_b,
typename AccumType = float>
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void gather_mm_rhs(
const device T* A [[buffer(0)]],
const device T* B [[buffer(1)]],
const device uint32_t* rhs_indices [[buffer(2)]],
device T* C [[buffer(3)]],
const constant GEMMParams* params [[buffer(4)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]]) {
using gemm_kernel = GEMMKernel<
T,
T,
BM,
BN,
BK,
WM,
WN,
transpose_a,
transpose_b,
true,
true,
AccumType>;
using loader_a_t = typename gemm_kernel::loader_a_t;
using loader_b_t = typename gemm_kernel::loader_b_t;
using mma_t = typename gemm_kernel::mma_t;
if (params->tiles_n <= static_cast<int>(tid.x) ||
params->tiles_m <= static_cast<int>(tid.y)) {
return;
}
// Prepare threadgroup memory
threadgroup T As[gemm_kernel::tgp_mem_size_a];
threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
// Find the block in A, B, C
const int c_row = tid.y * BM;
const int c_col = tid.x * BN;
const size_t c_row_long = size_t(c_row);
const size_t c_col_long = size_t(c_col);
// Prepare threadgroup bounds
const short tgp_bm = align_M ? BM : short(min(BM, params->M - c_row));
const short tgp_bn = align_N ? BN : short(min(BN, params->N - c_col));
A += transpose_a ? c_row_long : c_row_long * params->lda;
B += transpose_b ? c_col_long * params->ldb : c_col_long;
C += c_row_long * params->ldd + c_col_long;
// Do as many matmuls as necessary
uint32_t index;
short offset;
uint32_t index_next = rhs_indices[c_row];
short offset_next = 0;
int n = 0;
while (n < tgp_bm) {
n++;
offset = offset_next;
index = index_next;
offset_next = tgp_bm;
for (; n < tgp_bm; n++) {
if (rhs_indices[c_row + n] != index) {
offset_next = n;
index_next = rhs_indices[c_row + n];
break;
}
}
threadgroup_barrier(mem_flags::mem_none);
// Prepare threadgroup mma operation
thread mma_t mma_op(simd_group_id, simd_lane_id);
// Prepare threadgroup loading operations
thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
thread loader_b_t loader_b(
B + index * params->batch_stride_b,
params->ldb,
Bs,
simd_group_id,
simd_lane_id);
// Prepare iterations
const int gemm_k_iterations = params->gemm_k_iterations_aligned;
// Do unaligned K iterations first
if (!align_K) {
const int k_last = params->gemm_k_iterations_aligned * BK;
const int k_remain = params->K - k_last;
const size_t k_jump_a =
transpose_a ? params->lda * size_t(k_last) : size_t(k_last);
const size_t k_jump_b =
transpose_b ? size_t(k_last) : params->ldb * size_t(k_last);
// Move loader source ahead to end
loader_a.src += k_jump_a;
loader_b.src += k_jump_b;
// Load tile
const short2 tile_dims_A =
transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
const short2 tile_dims_B =
transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);
loader_a.load_safe(tile_dims_A);
loader_b.load_safe(tile_dims_B);
threadgroup_barrier(mem_flags::mem_threadgroup);
// Do matmul
mma_op.mma(As, Bs);
// Reset source back to start
loader_a.src -= k_jump_a;
loader_b.src -= k_jump_b;
}
// Matrix level aligned never check
if (align_M && align_N) {
for (int k = 0; k < gemm_k_iterations; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup
loader_a.load_unsafe();
loader_b.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
// Store results to device memory
if (offset_next - offset == BM) {
mma_op.store_result(C, params->ldd);
} else {
mma_op.store_result_slice(
C, params->ldd, short2(0, offset), short2(BN, offset_next));
}
} else {
const short lbk = 0;
// Tile aligned don't check
if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
lbk,
LoopAlignment<true, true, true>{});
if (offset_next - offset == BM) {
mma_op.store_result(C, params->ldd);
} else {
mma_op.store_result_slice(
C, params->ldd, short2(0, offset), short2(BN, offset_next));
}
}
// Tile partially aligned check rows
else if (align_N || tgp_bn == BN) {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
lbk,
LoopAlignment<false, true, true>{});
mma_op.store_result_slice(
C, params->ldd, short2(0, offset), short2(BN, offset_next));
}
// Tile partially aligned check cols
else if (align_M || tgp_bm == BM) {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
lbk,
LoopAlignment<true, false, true>{});
mma_op.store_result_slice(
C, params->ldd, short2(0, offset), short2(tgp_bn, offset_next));
}
// Nothing aligned so check both rows and cols
else {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
lbk,
LoopAlignment<false, false, true>{});
mma_op.store_result_slice(
C, params->ldd, short2(0, offset), short2(tgp_bn, offset_next));
}
}
}
}
template <
typename T,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_a,
bool transpose_b,
typename AccumType = float>
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void gather_mm(
const device T* A [[buffer(0)]],
const device T* B [[buffer(1)]],
const device uint32_t* lhs_indices [[buffer(2)]],
const device uint32_t* rhs_indices [[buffer(3)]],
device T* C [[buffer(4)]],
const constant GEMMParams* params [[buffer(5)]],
const constant int* indices_shape [[buffer(6)]],
const constant int64_t* lhs_strides [[buffer(7)]],
const constant int64_t* rhs_strides [[buffer(8)]],
const constant int& batch_ndim_a [[buffer(9)]],
const constant int* batch_shape_a [[buffer(10)]],
const constant int64_t* batch_strides_a [[buffer(11)]],
const constant int& batch_ndim_b [[buffer(12)]],
const constant int* batch_shape_b [[buffer(13)]],
const constant int64_t* batch_strides_b [[buffer(14)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]]) {
using gemm_kernel = GEMMKernel<
T,
T,
BM,
BN,
BK,
WM,
WN,
transpose_a,
transpose_b,
true,
true,
AccumType>;
using loader_a_t = typename gemm_kernel::loader_a_t;
using loader_b_t = typename gemm_kernel::loader_b_t;
using mma_t = typename gemm_kernel::mma_t;
if (params->tiles_n <= static_cast<int>(tid.x) ||
params->tiles_m <= static_cast<int>(tid.y)) {
return;
}
// Move A and B to the locations pointed by lhs_indices and rhs_indices.
uint32_t indx_A, indx_B;
if (has_batch) {
ulong2 indices_offsets = elem_to_loc_broadcast(
tid.z, indices_shape, lhs_strides, rhs_strides, params->batch_ndim);
indx_A = lhs_indices[indices_offsets.x];
indx_B = rhs_indices[indices_offsets.y];
} else {
indx_A = lhs_indices[params->batch_stride_a * tid.z];
indx_B = rhs_indices[params->batch_stride_b * tid.z];
}
A += elem_to_loc(indx_A, batch_shape_a, batch_strides_a, batch_ndim_a);
B += elem_to_loc(indx_B, batch_shape_b, batch_strides_b, batch_ndim_b);
C += params->batch_stride_d * tid.z;
// Prepare threadgroup memory
threadgroup T As[gemm_kernel::tgp_mem_size_a];
threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
// Just make sure everybody's finished with the indexing math above.
threadgroup_barrier(mem_flags::mem_none);
// Find block in A, B, C
const int c_row = tid.y * BM;
const int c_col = tid.x * BN;
const size_t c_row_long = size_t(c_row);
const size_t c_col_long = size_t(c_col);
A += transpose_a ? c_row_long : c_row_long * params->lda;
B += transpose_b ? c_col_long * params->ldb : c_col_long;
C += c_row_long * params->ldd + c_col_long;
// Prepare threadgroup mma operation
thread mma_t mma_op(simd_group_id, simd_lane_id);
// Prepare threadgroup loading operations
thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);
// Prepare threadgroup bounds
const short tgp_bm = align_M ? BM : short(min(BM, params->M - c_row));
const short tgp_bn = align_N ? BN : short(min(BN, params->N - c_col));
// Prepare iterations
int gemm_k_iterations = params->gemm_k_iterations_aligned;
// Do unaligned K iterations first
if (!align_K) {
const int k_last = params->gemm_k_iterations_aligned * BK;
const int k_remain = params->K - k_last;
const size_t k_jump_a =
transpose_a ? params->lda * size_t(k_last) : size_t(k_last);
const size_t k_jump_b =
transpose_b ? size_t(k_last) : params->ldb * size_t(k_last);
// Move loader source ahead to end
loader_a.src += k_jump_a;
loader_b.src += k_jump_b;
// Load tile
const short2 tile_dims_A =
transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
const short2 tile_dims_B =
transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);
loader_a.load_safe(tile_dims_A);
loader_b.load_safe(tile_dims_B);
threadgroup_barrier(mem_flags::mem_threadgroup);
// Do matmul
mma_op.mma(As, Bs);
// Reset source back to start
loader_a.src -= k_jump_a;
loader_b.src -= k_jump_b;
}
// Matrix level aligned never check
if (align_M && align_N) {
for (int k = 0; k < gemm_k_iterations; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup
loader_a.load_unsafe();
loader_b.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
// Store results to device memory
mma_op.store_result(C, params->ldd);
} else {
const short lbk = 0;
// Tile aligned don't check
if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
lbk,
LoopAlignment<true, true, true>{});
mma_op.store_result(C, params->ldd);
}
// Tile partially aligned check rows
else if (align_N || tgp_bn == BN) {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
lbk,
LoopAlignment<false, true, true>{});
mma_op.store_result_safe(C, params->ldd, short2(tgp_bn, tgp_bm));
}
// Tile partially aligned check cols
else if (align_M || tgp_bm == BM) {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
lbk,
LoopAlignment<true, false, true>{});
mma_op.store_result_safe(C, params->ldd, short2(tgp_bn, tgp_bm));
}
// Nothing aligned so check both rows and cols
else {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
lbk,
LoopAlignment<false, false, true>{});
mma_op.store_result_safe(C, params->ldd, short2(tgp_bn, tgp_bm));
}
}
}

View File

@@ -0,0 +1,59 @@
// Copyright © 2024 Apple Inc.
// clang-format off
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
#include "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h"
#define instantiate_gather_mm_rhs(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_kernel( \
"steel_gather_mm_rhs_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn \
"_bk" #bk "_wm" #wm "_wn" #wn, \
gather_mm_rhs, \
itype, \
bm, \
bn, \
bk, \
wm, \
wn, \
trans_a, \
trans_b, \
float)
#define instantiate_gather_mm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_kernel( \
"steel_gather_mm_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn \
"_bk" #bk "_wm" #wm "_wn" #wn, \
gather_mm, \
itype, \
bm, \
bn, \
bk, \
wm, \
wn, \
trans_a, \
trans_b, \
float)
#define instantiate_gather_mm_rhs_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gather_mm_rhs(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gather_mm_rhs(nt, false, true, iname, itype, oname, otype, bm, bn, bk, wm, wn)
#define instantiate_gather_mm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gather_mm(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gather_mm(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gather_mm(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gather_mm(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn)
#define instantiate_gather_mm_shapes_helper(iname, itype, oname, otype) \
instantiate_gather_mm_rhs_transpose_helper(iname, itype, oname, otype, 16, 64, 16, 1, 2) \
instantiate_gather_mm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) \
instantiate_gather_mm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 1, 2) \
instantiate_gather_mm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \
instantiate_gather_mm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 1, 2) \
instantiate_gather_mm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2)
// clang-format on
instantiate_gather_mm_shapes_helper(float16, half, float16, half);
instantiate_gather_mm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t);
instantiate_gather_mm_shapes_helper(float32, float, float32, float);

View File

@@ -142,6 +142,42 @@ struct BaseMMAFrag<T, 8, 8> {
}
}
template <
typename DstPtrType,
typename StrX,
typename StrY,
typename StartX,
typename StopX,
typename StartY,
typename StopY,
typename OffX,
typename OffY>
METAL_FUNC static constexpr void store_slice(
const thread frag_type& src,
DstPtrType dst,
StrX str_x,
StrY str_y,
StartX start_x,
StopX stop_x,
StartY start_y,
StopY stop_y,
OffX off_x = Int<0>{},
OffY off_y = Int<0>{}) {
using U = pointer_element_t<DstPtrType>;
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kElemRows; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < kElemCols; j++) {
if ((off_x + i) < stop_x && (off_x + i) >= start_x &&
(off_y + j) < stop_y && (off_y + j) >= start_y) {
dst[(off_x + i) * str_x + (off_y + j) * str_y] =
static_cast<U>(src[i * kElemCols + j]);
}
}
}
}
METAL_FUNC static constexpr void mma(
thread frag_type& D,
thread frag_type& A,
@@ -335,6 +371,31 @@ struct MMATile {
}
}
}
template <typename U, int w_x, int w_y>
METAL_FUNC void store_slice(
device U* dst,
const int ld,
const short2 start,
const short2 stop) const {
STEEL_PRAGMA_UNROLL
for (int i = 0; i < kTileRows; ++i) {
STEEL_PRAGMA_UNROLL
for (int j = 0; j < kTileCols; ++j) {
MMAFrag_t::store_slice(
frag_at(i, j),
dst,
ld,
Int<1>{},
start.y,
stop.y,
start.x,
stop.x,
(i * kFragRows) * w_x,
(j * kFragCols) * w_y);
}
}
}
};
template <typename T, typename U, int M, int N, int K>
@@ -474,6 +535,26 @@ struct BlockMMA {
Ctile.template store<U, WM, WN>(D, ldd);
}
METAL_FUNC void
store_result_slice(device U* D, const int ldd, short2 start, short2 stop) {
// Apply epilogue
STEEL_PRAGMA_UNROLL
for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) {
Ctile.elems()[i] = Epilogue::apply(Ctile.elems()[i]);
}
D += sm * ldd + sn;
start -= short2(sn, sm);
stop -= short2(sn, sm);
// TODO: Check the start as well
if (stop.y <= 0 || stop.x <= 0) {
return;
}
Ctile.template store_slice<U, WM, WN>(D, ldd, start, stop);
}
METAL_FUNC void
store_result_safe(device U* D, const int ldd, short2 dst_tile_dims) {
// Apply epilogue

View File

@@ -73,6 +73,9 @@ instantiate_unary_all_same(Conjugate, complex64, complex64_t)
instantiate_unary_all_same(Cos, complex64, complex64_t)
instantiate_unary_all_same(Cosh, complex64, complex64_t)
instantiate_unary_all_same(Exp, complex64, complex64_t)
instantiate_unary_all_same(Log, complex64, complex64_t)
instantiate_unary_all_same(Log2, complex64, complex64_t)
instantiate_unary_all_same(Log10, complex64, complex64_t)
instantiate_unary_all_same(Negative, complex64, complex64_t)
instantiate_unary_all_same(Sign, complex64, complex64_t)
instantiate_unary_all_same(Sin, complex64, complex64_t)

View File

@@ -257,6 +257,13 @@ struct Log {
T operator()(T x) {
return metal::precise::log(x);
};
template <>
complex64_t operator()(complex64_t x) {
auto r = metal::precise::log(Abs{}(x).real);
auto i = metal::precise::atan2(x.imag, x.real);
return {r, i};
};
};
struct Log2 {
@@ -264,6 +271,12 @@ struct Log2 {
T operator()(T x) {
return metal::precise::log2(x);
};
template <>
complex64_t operator()(complex64_t x) {
auto y = Log{}(x);
return {y.real / M_LN2_F, y.imag / M_LN2_F};
};
};
struct Log10 {
@@ -271,6 +284,12 @@ struct Log10 {
T operator()(T x) {
return metal::precise::log10(x);
};
template <>
complex64_t operator()(complex64_t x) {
auto y = Log{}(x);
return {y.real / M_LN10_F, y.imag / M_LN10_F};
};
};
struct Log1p {

View File

@@ -0,0 +1,96 @@
// Copyright © 2023-2024 Apple Inc.
#include <algorithm>
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"
namespace mlx::core {
constexpr int LOGSUMEXP_LOOPED_LIMIT = 4096;
void LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
if (!issubdtype(out.dtype(), floating)) {
throw std::runtime_error(
"[logsumexp] Does not support non-floating point types.");
}
auto& s = stream();
auto& d = metal::device(s.device);
// Make sure that the last dimension is contiguous
auto ensure_contiguous = [&s, &d](const array& x) {
if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
return x;
} else {
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
d.add_temporary(x_copy, s.index);
return x_copy;
}
};
auto in = ensure_contiguous(inputs[0]);
if (in.flags().row_contiguous) {
out.set_data(allocator::malloc(out.nbytes()));
} else {
auto n = in.shape(-1);
auto flags = in.flags();
auto strides = in.strides();
for (auto& s : strides) {
s /= n;
}
bool col_contig = strides[0] == 1;
for (int i = 1; col_contig && i < strides.size(); ++i) {
col_contig &=
(out.shape(i) == 1 || strides[i - 1] == out.shape(i) * strides[i]);
}
flags.col_contiguous = col_contig;
out.set_data(
allocator::malloc(in.nbytes() / n),
in.data_size() / n,
std::move(strides),
flags);
}
int axis_size = in.shape().back();
int n_rows = in.data_size() / axis_size;
const int simd_size = 32;
const int n_reads = 4;
const int looped_limit = LOGSUMEXP_LOOPED_LIMIT;
std::string kernel_name = (axis_size > looped_limit) ? "looped_" : "block_";
kernel_name += "logsumexp_";
kernel_name += type_to_name(out);
auto kernel = get_logsumexp_kernel(d, kernel_name, out);
auto& compute_encoder = d.get_command_encoder(s.index);
{
MTL::Size grid_dims, group_dims;
if (axis_size <= looped_limit) {
size_t threadgroup_needed = (axis_size + n_reads - 1) / n_reads;
size_t simds_needed = (threadgroup_needed + simd_size - 1) / simd_size;
size_t threadgroup_size = simd_size * simds_needed;
assert(threadgroup_size <= kernel->maxTotalThreadsPerThreadgroup());
size_t n_threads = n_rows * threadgroup_size;
grid_dims = MTL::Size(n_threads, 1, 1);
group_dims = MTL::Size(threadgroup_size, 1, 1);
} else {
size_t threadgroup_size = kernel->maxTotalThreadsPerThreadgroup();
size_t n_threads = n_rows * threadgroup_size;
grid_dims = MTL::Size(n_threads, 1, 1);
group_dims = MTL::Size(threadgroup_size, 1, 1);
}
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(out, 1);
compute_encoder.set_bytes(axis_size, 2);
compute_encoder.dispatch_threads(grid_dims, group_dims);
}
}
} // namespace mlx::core

View File

@@ -5,6 +5,7 @@
#include <numeric>
#include <sstream>
#include "mlx/backend/common/broadcasting.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
@@ -102,6 +103,47 @@ std::tuple<bool, int64_t, array> check_transpose(
}
};
inline array
ensure_row_contiguous(const array& x, metal::Device& d, const Stream& s) {
if (!x.flags().row_contiguous) {
array x_copy(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
d.add_temporary(x_copy, s.index);
return x_copy;
} else {
return x;
}
}
inline std::tuple<bool, int64_t, array>
ensure_batch_contiguous(const array& x, metal::Device& d, const Stream& s) {
if (x.flags().row_contiguous) {
return std::make_tuple(false, x.strides()[x.ndim() - 2], x);
}
bool rc = true;
for (int i = 0; i < x.ndim() - 3; i++) {
rc &= x.strides()[i + 1] * x.shape(i) == x.strides()[i];
}
if (rc) {
auto stx = x.strides()[x.ndim() - 2];
auto sty = x.strides()[x.ndim() - 1];
auto K = x.shape(-2);
auto N = x.shape(-1);
if (sty == 1 && (N != 1 || stx == N)) {
return std::make_tuple(false, stx, x);
}
if (stx == 1 && (N != 1 || sty == K)) {
return std::make_tuple(true, sty, x);
}
}
array x_copy(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
d.add_temporary(x_copy, s.index);
return std::make_tuple(false, x_copy.strides()[x_copy.ndim() - 2], x_copy);
}
} // namespace
///////////////////////////////////////////////////////////////////////////////
@@ -230,7 +272,6 @@ void steel_matmul_regular(
const bool align_M = (M % bm) == 0;
const bool align_N = (N % bn) == 0;
const bool align_K = (K % bk) == 0;
const bool do_gather = false;
metal::MTLFCList func_consts = {
{&has_batch, MTL::DataType::DataTypeBool, 10},
@@ -239,7 +280,6 @@ void steel_matmul_regular(
{&align_M, MTL::DataType::DataTypeBool, 200},
{&align_N, MTL::DataType::DataTypeBool, 201},
{&align_K, MTL::DataType::DataTypeBool, 202},
{&do_gather, MTL::DataType::DataTypeBool, 300},
};
// clang-format off
@@ -248,8 +288,7 @@ void steel_matmul_regular(
<< "_do_axpby_" << (do_axpby ? 't' : 'n')
<< "_align_M_" << (align_M ? 't' : 'n')
<< "_align_N_" << (align_N ? 't' : 'n')
<< "_align_K_" << (align_K ? 't' : 'n')
<< "_do_gather_" << (do_gather ? 't' : 'n'); // clang-format on
<< "_align_K_" << (align_K ? 't' : 'n'); // clang-format on
std::string hash_name = kname.str();
@@ -975,7 +1014,6 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
const bool align_M = (M % bm) == 0;
const bool align_N = (N % bn) == 0;
const bool align_K = (K % bk) == 0;
const bool do_gather = false;
metal::MTLFCList func_consts = {
{&has_batch, MTL::DataType::DataTypeBool, 10},
@@ -984,7 +1022,6 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
{&align_M, MTL::DataType::DataTypeBool, 200},
{&align_N, MTL::DataType::DataTypeBool, 201},
{&align_K, MTL::DataType::DataTypeBool, 202},
{&do_gather, MTL::DataType::DataTypeBool, 300},
};
// clang-format off
@@ -993,8 +1030,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
<< "_do_axpby_" << (do_axpby ? 't' : 'n')
<< "_align_M_" << (align_M ? 't' : 'n')
<< "_align_N_" << (align_N ? 't' : 'n')
<< "_align_K_" << (align_K ? 't' : 'n')
<< "_do_gather_" << (do_gather ? 't' : 'n'); // clang-format on
<< "_align_K_" << (align_K ? 't' : 'n'); // clang-format on
std::string hash_name = kname.str();
@@ -1464,267 +1500,337 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
d.add_temporaries(std::move(copies), s.index);
}
void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
using namespace mlx::steel;
// assert(inputs.size() == 2);
if (!issubdtype(out.dtype(), floating)) {
throw std::runtime_error(
"[GatherMM] Does not yet support non-floating point types.");
}
auto& s = stream();
auto& d = metal::device(s.device);
void gather_mm_rhs(
const array& a_,
const array& b_,
const array& indices_,
array& out,
metal::Device& d,
const Stream& s) {
array indices = ensure_row_contiguous(indices_, d, s);
auto [transpose_b, ldb, b] = ensure_batch_contiguous(b_, d, s);
auto& a_pre = inputs[0];
auto& b_pre = inputs[1];
// Return 0s if either input is empty
if (a_pre.size() == 0 || b_pre.size() == 0) {
array zero = array(0, a_pre.dtype());
fill_gpu(zero, out, s);
d.add_temporary(std::move(zero), s.index);
return;
}
// Broadcast a with indices. If we are here that means lhs_indices were not
// provided so the lhs_indices are implied to be the shape of a broadcasted
// with rhs_indices. We need only broadcast a and copy it as if applying the
// lhs_indices.
auto broadcast_with_indices = [&d, &s, &indices](const array& x) {
if (x.size() / x.shape(-2) / x.shape(-1) == indices.size()) {
return ensure_row_contiguous(x, d, s);
}
out.set_data(allocator::malloc(out.nbytes()));
auto x_shape = indices.shape();
x_shape.push_back(x.shape(-2));
x_shape.push_back(x.shape(-1));
array new_x(std::move(x_shape), x.dtype(), nullptr, {});
broadcast(x, new_x);
return ensure_row_contiguous(new_x, d, s);
};
array a = broadcast_with_indices(a_);
/////////////////////////////////////////////////////////////////////////////
// Init checks and prep
// Extract the matmul shapes
int K = a.shape(-1);
int M = a.size() / K;
int N = b.shape(-1);
int lda = a.strides()[a.ndim() - 2]; // should be K
int M = a_pre.shape(-2);
int N = b_pre.shape(-1);
int K = a_pre.shape(-1);
// Define the dispatch blocks
int bm = 16, bn = 64, bk = 16;
int wm = 1, wn = 2;
// Keep a vector with copies to be cleared in the completed buffer to release
// the arrays
std::vector<array> copies;
auto [transpose_a, a_cols, a] = check_transpose(copies, s, a_pre, M == 1);
auto [transpose_b, b_cols, b] = check_transpose(copies, s, b_pre, N == 1);
const bool align_M = (M % bm) == 0;
const bool align_N = (N % bn) == 0;
const bool align_K = (K % bk) == 0;
int lda = a_cols;
int ldb = b_cols;
// Define the kernel name
std::string base_name;
base_name.reserve(64);
concatenate(
base_name,
"steel_gather_mm_rhs_n",
transpose_b ? 't' : 'n',
'_',
type_to_name(a),
'_',
type_to_name(out),
"_bm",
bm,
"_bn",
bn,
"_bk",
bk,
"_wm",
wm,
"_wn",
wn);
/////////////////////////////////////////////////////////////////////////////
// Check and collapse batch dimensions
auto get_batch_dims = [](const auto& v) {
return decltype(v){v.begin(), v.end() - 2};
metal::MTLFCList func_consts = {
{&align_M, MTL::DataType::DataTypeBool, 200},
{&align_N, MTL::DataType::DataTypeBool, 201},
{&align_K, MTL::DataType::DataTypeBool, 202},
};
auto& lhs_indices = inputs[2];
auto& rhs_indices = inputs[3];
// And the kernel hash that includes the function constants
std::string hash_name;
hash_name.reserve(128);
concatenate(
hash_name,
base_name,
"_align_M_",
align_M ? 't' : 'n',
"_align_N_",
align_N ? 't' : 'n',
"_align_K_",
align_K ? 't' : 'n');
Shape batch_shape = get_batch_dims(out.shape());
Strides batch_strides;
// Get and set the kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = get_steel_gemm_gather_kernel(
d,
base_name,
hash_name,
func_consts,
out,
false,
transpose_b,
bm,
bn,
bk,
wm,
wn,
true);
compute_encoder.set_compute_pipeline_state(kernel);
batch_strides.insert(
batch_strides.end(),
lhs_indices.strides().begin(),
lhs_indices.strides().end());
auto lhs_indices_str = batch_strides.empty() ? 0 : batch_strides.back();
// Prepare the matmul params
auto batch_stride_b = b.ndim() > 2 ? b.strides()[b.ndim() - 3] : b.size();
steel::GEMMParams params{
/* const int M = */ M,
/* const int N = */ N,
/* const int K = */ K,
/* const int lda = */ lda,
/* const int ldb = */ static_cast<int>(ldb),
/* const int ldd = */ N,
/* const int tiles_n = */ (N + bn - 1) / bn,
/* const int tiles_m = */ (M + bm - 1) / bm,
/* const int64_t batch_stride_a = */ 0,
/* const int64_t batch_stride_b = */ static_cast<int64_t>(batch_stride_b),
/* const int64_t batch_stride_d = */ 0,
/* const int swizzle_log = */ 0,
/* const int gemm_k_iterations_aligned = */ (K / bk),
/* const int batch_ndim = */ 0};
batch_strides.insert(
batch_strides.end(),
rhs_indices.strides().begin(),
rhs_indices.strides().end());
auto rhs_indices_str = batch_strides.empty() ? 0 : batch_strides.back();
// Prepare the grid
MTL::Size group_dims = MTL::Size(32, wn, wm);
MTL::Size grid_dims = MTL::Size(params.tiles_n, params.tiles_m, 1);
int batch_ndim = batch_shape.size();
// Launch kernel
compute_encoder.set_input_array(a, 0);
compute_encoder.set_input_array(b, 1);
compute_encoder.set_input_array(indices, 2);
compute_encoder.set_output_array(out, 3);
compute_encoder.set_bytes(params, 4);
if (batch_ndim == 0) {
batch_shape = {1};
batch_strides = {0};
}
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
int batch_ndim_A = a.ndim() - 2;
int batch_ndim_B = b.ndim() - 2;
std::vector<int> operand_batch_ndim = {batch_ndim_A, batch_ndim_B};
void gather_mv(
const array& mat_,
const array& vec_,
const array& mat_indices_,
const array& vec_indices_,
array& out,
int N,
int K,
bool is_mv,
metal::Device& d,
const Stream& s) {
// Copy if needed
std::vector<array> copies;
auto [transpose_mat, mat_cols, mat] =
check_transpose(copies, s, mat_, N == 1);
auto [transpose_vec, vec_cols, vec] = check_transpose(copies, s, vec_, true);
d.add_temporaries(std::move(copies), s.index);
Shape batch_shape_A = get_batch_dims(a.shape());
Strides batch_strides_A = get_batch_dims(a.strides());
Shape batch_shape_B = get_batch_dims(b.shape());
Strides batch_strides_B = get_batch_dims(b.strides());
// If we are doing vector matrix instead of matrix vector we need to flip the
// matrix transposition. Basically m @ v = v @ m.T assuming that v is treated
// as a one dimensional array.
transpose_mat = (!is_mv) ^ transpose_mat;
if (batch_ndim_A == 0) {
batch_shape_A = {1};
batch_strides_A = {0};
}
// Define some shapes
int in_vector_len = K;
int out_vector_len = N;
int mat_ld = mat_cols;
if (batch_ndim_B == 0) {
batch_shape_B = {1};
batch_strides_B = {0};
}
int batch_size_out = out.size() / N;
int batch_ndim = out.ndim() - 2;
int batch_ndim_mat = mat.ndim() - 2;
int batch_ndim_vec = vec.ndim() - 2;
Strides index_strides = vec_indices_.strides();
index_strides.insert(
index_strides.end(),
mat_indices_.strides().begin(),
mat_indices_.strides().end());
auto matrix_stride_out = static_cast<int64_t>(M) * N;
auto batch_size_out = out.size() / matrix_stride_out;
/////////////////////////////////////////////////////////////////////////////
// Gemv specialization
// Route to gemv if needed
if (std::min(M, N) == 1) {
// Collect problem info
bool is_b_matrix = N != 1;
auto& mat = is_b_matrix ? b : a;
auto& vec = is_b_matrix ? a : b;
bool transpose_mat = is_b_matrix ? !transpose_b : transpose_a;
int in_vector_len = K;
int out_vector_len = is_b_matrix ? N : M;
int mat_cols = transpose_mat ? out_vector_len : in_vector_len;
int mat_rows = transpose_mat ? in_vector_len : out_vector_len;
int mat_ld = is_b_matrix ? b_cols : a_cols;
auto batch_strides_mat = is_b_matrix ? batch_strides_B : batch_strides_A;
auto batch_strides_vec = is_b_matrix ? batch_strides_A : batch_strides_B;
auto batch_shape_mat = is_b_matrix ? batch_shape_B : batch_shape_A;
auto batch_shape_vec = is_b_matrix ? batch_shape_A : batch_shape_B;
if (!is_b_matrix) {
batch_strides = rhs_indices.strides();
batch_strides.insert(
batch_strides.end(),
lhs_indices.strides().begin(),
lhs_indices.strides().end());
}
int batch_ndim = batch_shape.size();
// Determine dispatch kernel
int tm = 4, tn = 4;
int sm = 1, sn = 32;
int bm = 1, bn = 1;
int n_out_per_tgp;
std::ostringstream kname;
if (transpose_mat) {
if (in_vector_len >= 8192 && out_vector_len >= 2048) {
sm = 4;
sn = 8;
} else {
sm = 8;
sn = 4;
}
if (out_vector_len >= 2048) {
bn = 16;
} else if (out_vector_len >= 512) {
bn = 4;
} else {
bn = 2;
}
// Specialized kernel for very small outputs
tn = out_vector_len < tn ? 1 : tn;
n_out_per_tgp = bn * sn * tn;
kname << "gemv_t_gather_" << type_to_name(out);
// Determine dispatch kernel
int tm = 4, tn = 4;
int sm = 1, sn = 32;
int bm = 1, bn = 1;
int n_out_per_tgp;
std::ostringstream kname;
if (transpose_mat) {
if (in_vector_len >= 8192 && out_vector_len >= 2048) {
sm = 4;
sn = 8;
} else {
bm = out_vector_len >= 4096 ? 8 : 4;
sn = 32;
// Specialized kernel for very small outputs
tm = out_vector_len < tm ? 1 : tm;
n_out_per_tgp = bm * sm * tm;
kname << "gemv_gather_" << type_to_name(out);
sm = 8;
sn = 4;
}
kname << "_bm" << bm << "_bn" << bn << "_sm" << sm << "_sn" << sn << "_tm"
<< tm << "_tn" << tn;
if (out_vector_len >= 2048) {
bn = 16;
} else if (out_vector_len >= 512) {
bn = 4;
} else {
bn = 2;
}
// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder.set_compute_pipeline_state(kernel);
// Specialized kernel for very small outputs
tn = out_vector_len < tn ? 1 : tn;
int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp;
MTL::Size group_dims = MTL::Size(32, bn, bm);
MTL::Size grid_dims = MTL::Size(n_tgp, 1, batch_size_out);
n_out_per_tgp = bn * sn * tn;
kname << "gemv_t_gather_" << type_to_name(out);
compute_encoder.set_input_array(mat, 0);
compute_encoder.set_input_array(vec, 1);
compute_encoder.set_output_array(out, 3);
} else {
bm = out_vector_len >= 4096 ? 8 : 4;
sn = 32;
compute_encoder.set_bytes(in_vector_len, 4);
compute_encoder.set_bytes(out_vector_len, 5);
compute_encoder.set_bytes(mat_ld, 6);
// Specialized kernel for very small outputs
tm = out_vector_len < tm ? 1 : tm;
compute_encoder.set_bytes(batch_ndim, 9);
compute_encoder.set_vector_bytes(batch_shape, 10);
compute_encoder.set_vector_bytes(batch_strides, 11);
int batch_ndim_vec = batch_shape_vec.size();
compute_encoder.set_bytes(batch_ndim_vec, 12);
compute_encoder.set_vector_bytes(batch_shape_vec, 13);
compute_encoder.set_vector_bytes(batch_strides_vec, 14);
int batch_ndim_mat = batch_shape_mat.size();
compute_encoder.set_bytes(batch_ndim_mat, 15);
compute_encoder.set_vector_bytes(batch_shape_mat, 16);
compute_encoder.set_vector_bytes(batch_strides_mat, 17);
compute_encoder.set_input_array(lhs_indices, 18 + int(!is_b_matrix));
compute_encoder.set_input_array(rhs_indices, 18 + int(is_b_matrix));
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
d.add_temporaries(std::move(copies), s.index);
return;
n_out_per_tgp = bm * sm * tm;
kname << "gemv_gather_" << type_to_name(out);
}
/////////////////////////////////////////////////////////////////////////////
// Regular kernel dispatch
kname << "_bm" << bm << "_bn" << bn << "_sm" << sm << "_sn" << sn << "_tm"
<< tm << "_tn" << tn;
// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder.set_compute_pipeline_state(kernel);
int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp;
MTL::Size group_dims = MTL::Size(32, bn, bm);
MTL::Size grid_dims = MTL::Size(n_tgp, 1, batch_size_out);
compute_encoder.set_input_array(mat, 0);
compute_encoder.set_input_array(vec, 1);
compute_encoder.set_output_array(out, 3);
compute_encoder.set_bytes(in_vector_len, 4);
compute_encoder.set_bytes(out_vector_len, 5);
compute_encoder.set_bytes(mat_ld, 6);
compute_encoder.set_bytes(batch_ndim, 9);
compute_encoder.set_vector_bytes(out.shape(), 10);
compute_encoder.set_vector_bytes(index_strides, 11);
compute_encoder.set_bytes(batch_ndim_vec, 12);
compute_encoder.set_vector_bytes(vec.shape(), 13);
compute_encoder.set_vector_bytes(vec.strides(), 14);
compute_encoder.set_bytes(batch_ndim_mat, 15);
compute_encoder.set_vector_bytes(mat.shape(), 16);
compute_encoder.set_vector_bytes(mat.strides(), 17);
compute_encoder.set_input_array(vec_indices_, 18);
compute_encoder.set_input_array(mat_indices_, 19);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
void gather_mm(
const array& a_,
const array& b_,
const array& lhs_indices,
const array& rhs_indices,
array& out,
int M,
int N,
int K,
metal::Device& d,
const Stream& s) {
// Copy if needed
std::vector<array> copies;
auto [transpose_a, lda, a] = check_transpose(copies, s, a_, false);
auto [transpose_b, ldb, b] = check_transpose(copies, s, b_, false);
d.add_temporaries(std::move(copies), s.index);
// Determine dispatch kernel
int bm = 64, bn = 64, bk = 16;
int wm = 2, wn = 2;
size_t batch_size_out = out.size() / M / N;
int batch_ndim = out.ndim() - 2;
int batch_ndim_a = a.ndim() - 2;
int batch_ndim_b = b.ndim() - 2;
char devc = d.get_architecture().back();
GEMM_TPARAM_MACRO(devc)
// Prepare kernel name
std::ostringstream kname;
kname << "steel_gemm_fused_" << (transpose_a ? 't' : 'n')
<< (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_"
<< type_to_name(out) << "_bm" << bm << "_bn" << bn << "_bk" << bk
<< "_wm" << wm << "_wn" << wn;
std::string base_name = kname.str();
const bool has_batch = batch_ndim > 1;
const bool use_out_source = false;
const bool do_axpby = false;
const bool align_M = (M % bm) == 0;
const bool align_N = (N % bn) == 0;
const bool align_K = (K % bk) == 0;
const bool do_gather = true;
// Define the kernel name
std::string base_name;
base_name.reserve(128);
concatenate(
base_name,
"steel_gather_mm_",
transpose_a ? 't' : 'n',
transpose_b ? 't' : 'n',
"_",
type_to_name(a),
"_",
type_to_name(out),
"_bm",
bm,
"_bn",
bn,
"_bk",
bk,
"_wm",
wm,
"_wn",
wn);
metal::MTLFCList func_consts = {
{&has_batch, MTL::DataType::DataTypeBool, 10},
{&use_out_source, MTL::DataType::DataTypeBool, 100},
{&do_axpby, MTL::DataType::DataTypeBool, 110},
{&align_M, MTL::DataType::DataTypeBool, 200},
{&align_N, MTL::DataType::DataTypeBool, 201},
{&align_K, MTL::DataType::DataTypeBool, 202},
{&do_gather, MTL::DataType::DataTypeBool, 300},
};
// clang-format off
kname << "_has_batch_" << (has_batch ? 't' : 'n')
<< "_use_out_source_" << (use_out_source ? 't' : 'n')
<< "_do_axpby_" << (do_axpby ? 't' : 'n')
<< "_align_M_" << (align_M ? 't' : 'n')
<< "_align_N_" << (align_N ? 't' : 'n')
<< "_align_K_" << (align_K ? 't' : 'n')
<< "_do_gather_" << (do_gather ? 't' : 'n'); // clang-format on
// And the kernel hash that includes the function constants
std::string hash_name;
hash_name.reserve(128);
concatenate(
hash_name,
base_name,
"_has_batch_",
has_batch ? 't' : 'n',
"_align_M_",
align_M ? 't' : 'n',
"_align_N_",
align_N ? 't' : 'n',
"_align_K_",
align_K ? 't' : 'n');
std::string hash_name = kname.str();
// Encode and dispatch kernel
// Get and set the kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = get_steel_gemm_fused_kernel(
auto kernel = get_steel_gemm_gather_kernel(
d,
base_name,
hash_name,
@@ -1736,72 +1842,96 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
bn,
bk,
wm,
wn);
wn,
false);
compute_encoder.set_compute_pipeline_state(kernel);
// Use problem size to determine threadblock swizzle
int tn = (N + bn - 1) / bn;
int tm = (M + bm - 1) / bm;
// TODO: Explore device-based tuning for swizzle
int swizzle_log = 0; // tm >= 6 ? 3 : (tm <= 3 ? 0 : 2);
// Prepare steel matmul params
GEMMParams params{
// Prepare the matmul params
steel::GEMMParams params{
/* const int M = */ M,
/* const int N = */ N,
/* const int K = */ K,
/* const int lda = */ lda,
/* const int ldb = */ ldb,
/* const int lda = */ static_cast<int>(lda),
/* const int ldb = */ static_cast<int>(ldb),
/* const int ldd = */ N,
/* const int tiles_n = */ tn,
/* const int tiles_m = */ tm,
/* const int64_t batch_stride_a = */ lhs_indices_str,
/* const int64_t batch_stride_b = */ rhs_indices_str,
/* const int64_t batch_stride_d = */ matrix_stride_out,
/* const int swizzle_log = */ swizzle_log,
/* const int tiles_n = */ (N + bn - 1) / bn,
/* const int tiles_m = */ (M + bm - 1) / bm,
/* const int64_t batch_stride_a = */
(batch_ndim > 0) ? lhs_indices.strides()[0] : 0,
/* const int64_t batch_stride_b = */
(batch_ndim > 0) ? rhs_indices.strides()[0] : 0,
/* const int64_t batch_stride_d = */ M * N,
/* const int swizzle_log = */ 0,
/* const int gemm_k_iterations_aligned = */ (K / bk),
/* const int batch_ndim = */ batch_ndim};
// Prepare launch grid params
int tile = 1 << swizzle_log;
tm = (tm + tile - 1) / tile;
tn = tn * tile;
// Prepare the grid
MTL::Size group_dims = MTL::Size(32, wn, wm);
MTL::Size grid_dims = MTL::Size(tn, tm, batch_size_out);
MTL::Size grid_dims =
MTL::Size(params.tiles_n, params.tiles_m, batch_size_out);
// Launch kernel
compute_encoder.set_input_array(a, 0);
compute_encoder.set_input_array(b, 1);
compute_encoder.set_output_array(out, 3);
compute_encoder.set_bytes(params, 4);
compute_encoder.set_vector_bytes(batch_shape, 6);
compute_encoder.set_vector_bytes(batch_strides, 7);
compute_encoder.set_input_array(lhs_indices, 10);
compute_encoder.set_input_array(rhs_indices, 11);
std::vector operand_shape = batch_shape_A;
operand_shape.insert(
operand_shape.end(), batch_shape_B.begin(), batch_shape_B.end());
std::vector operand_strides = batch_strides_A;
operand_strides.insert(
operand_strides.end(), batch_strides_B.begin(), batch_strides_B.end());
operand_batch_ndim.push_back(0);
compute_encoder.set_vector_bytes(operand_shape, 13);
compute_encoder.set_vector_bytes(operand_strides, 14);
compute_encoder.set_vector_bytes(operand_batch_ndim, 15);
compute_encoder.set_input_array(lhs_indices, 2);
compute_encoder.set_input_array(rhs_indices, 3);
compute_encoder.set_output_array(out, 4);
compute_encoder.set_bytes(params, 5);
compute_encoder.set_vector_bytes(lhs_indices.shape(), 6);
compute_encoder.set_vector_bytes(lhs_indices.strides(), 7);
compute_encoder.set_vector_bytes(rhs_indices.strides(), 8);
compute_encoder.set_bytes(batch_ndim_a, 9);
compute_encoder.set_vector_bytes(a.shape(), 10);
compute_encoder.set_vector_bytes(a.strides(), 11);
compute_encoder.set_bytes(batch_ndim_b, 12);
compute_encoder.set_vector_bytes(b.shape(), 13);
compute_encoder.set_vector_bytes(b.strides(), 14);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
d.add_temporaries(std::move(copies), s.index);
void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& s = stream();
auto& d = metal::device(s.device);
auto& a = inputs[0];
auto& b = inputs[1];
auto& lhs_indices = inputs[2];
auto& rhs_indices = inputs[3];
// Return 0s if either input is empty
if (a.size() == 0 || b.size() == 0) {
array zero = array(0, a.dtype());
fill_gpu(zero, out, s);
d.add_temporary(std::move(zero), s.index);
return;
}
out.set_data(allocator::malloc(out.nbytes()));
// Extract shapes from inputs.
int M = a.shape(-2);
int N = b.shape(-1);
int K = a.shape(-1);
// We are walking a in order and b is also in order so we can batch up the
// matmuls and reuse reading a and b.
if (M == 1 && right_sorted_ == true) {
gather_mm_rhs(a, b, rhs_indices, out, d, s);
return;
}
// Route to gather gemv if any of a or b are vectors
if (M == 1) {
gather_mv(b, a, rhs_indices, lhs_indices, out, N, K, false, d, s);
return;
}
if (N == 1) {
gather_mv(a, b, lhs_indices, rhs_indices, out, M, K, true, d, s);
return;
}
// Route to non specialized gather mm
gather_mm(a, b, lhs_indices, rhs_indices, out, M, N, K, d, s);
}
} // namespace mlx::core

View File

@@ -12,74 +12,6 @@ namespace mlx::core::metal {
/* Check if the Metal backend is available. */
bool is_available();
/* Get the actively used memory in bytes.
*
* Note, this will not always match memory use reported by the system because
* it does not include cached memory buffers.
* */
size_t get_active_memory();
/* Get the peak amount of used memory in bytes.
*
* The maximum memory used recorded from the beginning of the program
* execution or since the last call to reset_peak_memory.
* */
size_t get_peak_memory();
/* Reset the peak memory to zero.
* */
void reset_peak_memory();
/* Get the cache size in bytes.
*
* The cache includes memory not currently used that has not been returned
* to the system allocator.
* */
size_t get_cache_memory();
/* Set the memory limit.
* The memory limit is a guideline for the maximum amount of memory to use
* during graph evaluation. If the memory limit is exceeded and there is no
* more RAM (including swap when available) allocations will result in an
* exception.
*
* When metal is available the memory limit defaults to 1.5 times the maximum
* recommended working set size reported by the device.
*
* Returns the previous memory limit.
* */
size_t set_memory_limit(size_t limit);
/* Get the current memory limit. */
size_t get_memory_limit();
/* Set the free cache limit.
* If using more than the given limit, free memory will be reclaimed
* from the cache on the next allocation. To disable the cache,
* set the limit to 0.
*
* The cache limit defaults to the memory limit.
*
* Returns the previous cache limit.
* */
size_t set_cache_limit(size_t limit);
/* Clear the memory cache. */
void clear_cache();
/* Set the wired size limit.
*
* Note, this function is only useful for macOS 15.0 or higher.
*
* The wired limit is the total size in bytes of memory that will be kept
* resident. The default value is ``0``.
*
* Setting a wired limit larger than system wired limit is an error.
*
* Returns the previous wired limit.
* */
size_t set_wired_limit(size_t limit);
/** Capture a GPU trace, saving it to an absolute file `path` */
void start_capture(std::string path = "");
void stop_capture();

View File

@@ -72,6 +72,13 @@ MTL::ComputePipelineState* get_softmax_kernel(
return d.get_kernel(kernel_name);
}
MTL::ComputePipelineState* get_logsumexp_kernel(
metal::Device& d,
const std::string& kernel_name,
const array&) {
return d.get_kernel(kernel_name);
}
MTL::ComputePipelineState* get_scan_kernel(
metal::Device& d,
const std::string& kernel_name,
@@ -186,6 +193,23 @@ MTL::ComputePipelineState* get_steel_gemm_masked_kernel(
return d.get_kernel(kernel_name);
}
MTL::ComputePipelineState* get_steel_gemm_gather_kernel(
metal::Device& d,
const std::string& kernel_name,
const std::string& hash_name,
const metal::MTLFCList& func_consts,
const array&,
bool,
bool,
int,
int,
int,
int,
int,
bool) {
return d.get_kernel(kernel_name, "mlx", hash_name, func_consts);
}
MTL::ComputePipelineState* get_gemv_masked_kernel(
metal::Device& d,
const std::string& kernel_name,
@@ -245,4 +269,21 @@ MTL::ComputePipelineState* get_quantized_kernel(
return d.get_kernel(kernel_name);
}
MTL::ComputePipelineState* get_gather_qmm_kernel(
metal::Device& d,
const std::string& kernel_name,
const std::string& hash_name,
const metal::MTLFCList& func_consts,
const array&,
int,
int,
int,
int,
int,
int,
int,
bool) {
return d.get_kernel(kernel_name, "mlx", hash_name, func_consts);
}
} // namespace mlx::core

View File

@@ -251,8 +251,10 @@ void Concatenate::eval_gpu(const std::vector<array>& inputs, array& out) {
void Contiguous::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (in.flags().row_contiguous ||
(allow_col_major_ && in.flags().col_contiguous)) {
constexpr size_t extra_bytes = 16384;
if (in.buffer_size() <= out.nbytes() + extra_bytes &&
(in.flags().row_contiguous ||
(allow_col_major_ && in.flags().col_contiguous))) {
out.copy_shared_buffer(in);
} else {
copy_gpu(in, out, CopyType::General);

File diff suppressed because it is too large Load Diff

View File

@@ -22,6 +22,7 @@ ResidencySet::ResidencySet(MTL::Device* d) {
}
throw std::runtime_error(msg.str());
}
wired_set_->requestResidency();
}
}
@@ -32,7 +33,6 @@ void ResidencySet::insert(MTL::Allocation* buf) {
if (wired_set_->allocatedSize() + buf->allocatedSize() <= capacity_) {
wired_set_->addAllocation(buf);
wired_set_->commit();
wired_set_->requestResidency();
} else {
unwired_set_.insert(buf);
}
@@ -76,7 +76,6 @@ void ResidencySet::resize(size_t size) {
}
}
wired_set_->commit();
wired_set_->requestResidency();
} else if (current_size > size) {
auto pool = new_scoped_memory_pool();
// Remove wired allocations until under capacity

View File

@@ -138,6 +138,7 @@ void sdpa_vector(
const array& v,
array& out,
float scale,
bool do_causal,
const std::optional<array>& mask) {
// Set the kernel name
std::string kname;
@@ -162,14 +163,20 @@ void sdpa_vector(
MTL::Size grid_dims(B, q.shape(2), 1);
bool has_mask = mask.has_value();
bool bool_mask = has_mask && (*mask).dtype() == bool_;
bool float_mask = has_mask && !bool_mask;
bool query_transposed = !q.flags().row_contiguous;
metal::MTLFCList func_consts = {
{&has_mask, MTL::DataType::DataTypeBool, 20},
{&query_transposed, MTL::DataType::DataTypeBool, 21},
{&do_causal, MTL::DataType::DataTypeBool, 22},
{&bool_mask, MTL::DataType::DataTypeBool, 23},
{&float_mask, MTL::DataType::DataTypeBool, 24},
};
std::string hash_name = kname;
hash_name += has_mask ? "_mask" : "_nomask";
hash_name += has_mask ? (bool_mask ? "_boolmask" : "_floatmask") : "_nomask";
hash_name += query_transposed ? "_qt" : "_qnt";
hash_name += do_causal ? "_c" : "_nc";
// Get the kernel
auto& compute_encoder = d.get_command_encoder(s.index);
@@ -191,15 +198,15 @@ void sdpa_vector(
compute_encoder.set_bytes(scale, 10);
if (has_mask) {
auto& m = *mask;
compute_encoder.set_input_array(m, 11);
compute_encoder.set_input_array(m, 11 + float_mask);
auto nd = m.ndim();
int32_t kv_seq_stride =
nd >= 1 && m.shape(-1) > 1 ? m.strides()[nd - 1] : 0;
int32_t q_seq_stride = nd >= 2 && m.shape(-2) > 1 ? m.strides()[nd - 2] : 0;
int32_t head_stride = nd >= 3 && m.shape(-3) > 1 ? m.strides()[nd - 3] : 0;
compute_encoder.set_bytes(kv_seq_stride, 12);
compute_encoder.set_bytes(q_seq_stride, 13);
compute_encoder.set_bytes(head_stride, 14);
compute_encoder.set_bytes(kv_seq_stride, 13);
compute_encoder.set_bytes(q_seq_stride, 14);
compute_encoder.set_bytes(head_stride, 15);
}
// Launch
@@ -214,6 +221,7 @@ void sdpa_vector_2pass(
const array& v,
array& out,
float scale,
bool do_causal,
const std::optional<array>& mask) {
// Set the kernel name
std::string kname;
@@ -256,14 +264,20 @@ void sdpa_vector_2pass(
d.add_temporary(maxs, s.index);
bool has_mask = mask.has_value();
bool bool_mask = has_mask && (*mask).dtype() == bool_;
bool float_mask = has_mask && !bool_mask;
bool query_transposed = !q.flags().row_contiguous;
metal::MTLFCList func_consts = {
{&has_mask, MTL::DataType::DataTypeBool, 20},
{&query_transposed, MTL::DataType::DataTypeBool, 21},
{&do_causal, MTL::DataType::DataTypeBool, 22},
{&bool_mask, MTL::DataType::DataTypeBool, 23},
{&float_mask, MTL::DataType::DataTypeBool, 24},
};
std::string hash_name = kname;
hash_name += has_mask ? "_mask" : "_nomask";
hash_name += has_mask ? (bool_mask ? "_boolmask" : "_floatmask") : "_nomask";
hash_name += query_transposed ? "_qt" : "_qnt";
hash_name += do_causal ? "_c" : "_nc";
// Get the kernel
auto& compute_encoder = d.get_command_encoder(s.index);
@@ -287,15 +301,15 @@ void sdpa_vector_2pass(
compute_encoder.set_bytes(scale, 12);
if (has_mask) {
auto& m = *mask;
compute_encoder.set_input_array(m, 13);
compute_encoder.set_input_array(m, 13 + float_mask);
auto nd = m.ndim();
int32_t kv_seq_stride =
nd >= 1 && m.shape(-1) > 1 ? m.strides()[nd - 1] : 0;
int32_t q_seq_stride = nd >= 2 && m.shape(-2) > 1 ? m.strides()[nd - 2] : 0;
int32_t head_stride = nd >= 3 && m.shape(-3) > 1 ? m.strides()[nd - 3] : 0;
compute_encoder.set_bytes(kv_seq_stride, 14);
compute_encoder.set_bytes(q_seq_stride, 15);
compute_encoder.set_bytes(head_stride, 16);
compute_encoder.set_bytes(kv_seq_stride, 15);
compute_encoder.set_bytes(q_seq_stride, 16);
compute_encoder.set_bytes(head_stride, 17);
}
// Launch
@@ -401,12 +415,13 @@ void ScaledDotProductAttention::eval_gpu(
// We route to the 2 pass fused attention if
// - The device is large and the sequence length long
// - The sequence length is even longer and we have gqa
bool do_causal = do_causal_ && q.shape(2) > 1;
char devc = d.get_architecture().back();
if ((devc == 'd' && k.shape(2) >= 1024) ||
(k.shape(1) < q.shape(1) && k.shape(2) >= 4096)) {
sdpa_vector_2pass(s, d, q, k, v, o, scale_, mask);
sdpa_vector_2pass(s, d, q, k, v, o, scale_, do_causal, mask);
} else {
sdpa_vector(s, d, q, k, v, o, scale_, mask);
sdpa_vector(s, d, q, k, v, o, scale_, do_causal, mask);
}
}

View File

@@ -60,6 +60,9 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
case Scan::Min:
reduce_type = "min";
break;
case Scan::LogAddExp:
reduce_type = "logaddexp";
break;
}
kname << reduce_type << "_" << type_to_name(in) << "_" << type_to_name(out);
auto kernel = get_scan_kernel(

View File

@@ -23,12 +23,7 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
// Make sure that the last dimension is contiguous
auto set_output = [&s, &out](const array& x) {
bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1;
if (no_copy && x.ndim() > 1) {
auto s = x.strides()[x.ndim() - 2];
no_copy &= (s == 0 || s == x.shape().back());
}
if (no_copy) {
if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
if (x.is_donatable()) {
out.copy_shared_buffer(x);
} else {

View File

@@ -2,6 +2,8 @@
#pragma once
#include <type_traits>
#include "mlx/array.h"
#include "mlx/backend/metal/device.h"
#include "mlx/primitives.h"
@@ -58,14 +60,27 @@ inline void debug_set_primitive_buffer_label(
std::string get_primitive_string(Primitive* primitive);
template <typename T>
constexpr bool is_numeric_except_char = std::is_arithmetic_v<T> &&
!std::is_same_v<T, char> && !std::is_same_v<T, signed char> &&
!std::is_same_v<T, unsigned char> && !std::is_same_v<T, wchar_t>;
template <typename T>
void concatenate(std::string& acc, T first) {
acc += first;
if constexpr (is_numeric_except_char<T>) {
acc += std::to_string(first);
} else {
acc += first;
}
}
template <typename T, typename... Args>
void concatenate(std::string& acc, T first, Args... args) {
acc += first;
if constexpr (is_numeric_except_char<T>) {
acc += std::to_string(first);
} else {
acc += first;
}
concatenate(acc, args...);
}

View File

@@ -82,6 +82,7 @@ NO_CPU(LogicalNot)
NO_CPU(LogicalAnd)
NO_CPU(LogicalOr)
NO_CPU(LogAddExp)
NO_CPU(LogSumExp)
NO_CPU_MULTI(LUF)
NO_CPU(Matmul)
NO_CPU(Maximum)

View File

@@ -1,14 +1,72 @@
// Copyright © 2023 Apple Inc.
#include <algorithm>
#include <mutex>
#include "mlx/allocator.h"
namespace mlx::core::allocator {
#ifdef __APPLE__
#include "mlx/backend/no_metal/apple_memory.h"
#elif defined(__linux__)
#include "mlx/backend/no_metal/linux_memory.h"
#else
size_t get_memory_size() {
return 0;
}
#endif
Allocator& allocator() {
namespace mlx::core {
namespace allocator {
class CommonAllocator : public Allocator {
/** A general CPU allocator. */
public:
virtual Buffer malloc(size_t size) override;
virtual void free(Buffer buffer) override;
virtual size_t size(Buffer buffer) const override;
size_t get_active_memory() const {
return active_memory_;
};
size_t get_peak_memory() const {
return peak_memory_;
};
void reset_peak_memory() {
std::unique_lock lk(mutex_);
peak_memory_ = 0;
};
size_t get_memory_limit() {
return memory_limit_;
}
size_t set_memory_limit(size_t limit) {
std::unique_lock lk(mutex_);
std::swap(memory_limit_, limit);
return limit;
}
private:
size_t memory_limit_;
size_t active_memory_{0};
size_t peak_memory_{0};
std::mutex mutex_;
CommonAllocator() : memory_limit_(0.8 * get_memory_size()) {
if (memory_limit_ == 0) {
memory_limit_ = 1UL << 33;
}
};
friend CommonAllocator& common_allocator();
};
CommonAllocator& common_allocator() {
static CommonAllocator allocator_;
return allocator_;
}
Allocator& allocator() {
return common_allocator();
}
void* Buffer::raw_ptr() {
if (!ptr_) {
return nullptr;
@@ -16,4 +74,59 @@ void* Buffer::raw_ptr() {
return static_cast<size_t*>(ptr_) + 1;
}
} // namespace mlx::core::allocator
Buffer CommonAllocator::malloc(size_t size) {
void* ptr = std::malloc(size + sizeof(size_t));
if (ptr != nullptr) {
*static_cast<size_t*>(ptr) = size;
}
std::unique_lock lk(mutex_);
active_memory_ += size;
peak_memory_ = std::max(active_memory_, peak_memory_);
return Buffer{ptr};
}
void CommonAllocator::free(Buffer buffer) {
auto sz = size(buffer);
std::free(buffer.ptr());
std::unique_lock lk(mutex_);
active_memory_ -= sz;
}
size_t CommonAllocator::size(Buffer buffer) const {
if (buffer.ptr() == nullptr) {
return 0;
}
return *static_cast<size_t*>(buffer.ptr());
}
} // namespace allocator
size_t get_active_memory() {
return allocator::common_allocator().get_active_memory();
}
size_t get_peak_memory() {
return allocator::common_allocator().get_peak_memory();
}
void reset_peak_memory() {
return allocator::common_allocator().reset_peak_memory();
}
size_t set_memory_limit(size_t limit) {
return allocator::common_allocator().set_memory_limit(limit);
}
size_t get_memory_limit() {
return allocator::common_allocator().get_memory_limit();
}
// No-ops for common allocator
size_t get_cache_memory() {
return 0;
}
size_t set_cache_limit(size_t) {
return 0;
}
size_t set_wired_limit(size_t) {
return 0;
}
void clear_cache() {}
} // namespace mlx::core

View File

@@ -0,0 +1,16 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include <sys/sysctl.h>
namespace {
size_t get_memory_size() {
size_t memsize = 0;
size_t length = sizeof(memsize);
sysctlbyname("hw.memsize", &memsize, &length, NULL, 0);
return memsize;
}
} // namespace

View File

@@ -28,21 +28,19 @@ void Event::wait() {
ec->cv.wait(lk, [value = value(), ec] { return ec->value >= value; });
}
void Event::signal() {
auto ec = static_cast<EventCounter*>(event_.get());
{
std::lock_guard<std::mutex> lk(ec->mtx);
ec->value = value();
}
ec->cv.notify_all();
}
void Event::wait(Stream stream) {
scheduler::enqueue(stream, [*this]() mutable { wait(); });
}
void Event::signal(Stream stream) {
scheduler::enqueue(stream, [*this]() mutable { signal(); });
scheduler::enqueue(stream, [*this]() mutable {
auto ec = static_cast<EventCounter*>(event_.get());
{
std::lock_guard<std::mutex> lk(ec->mtx);
ec->value = value();
}
ec->cv.notify_all();
});
}
bool Event::is_signaled() const {

View File

@@ -0,0 +1,22 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include <sys/sysinfo.h>
namespace {
size_t get_memory_size() {
struct sysinfo info;
if (sysinfo(&info) != 0) {
return 0;
}
size_t total_ram = info.totalram;
total_ram *= info.mem_unit;
return total_ram;
}
} // namespace

View File

@@ -31,33 +31,8 @@ void synchronize(Stream) {
"[metal::synchronize] Cannot synchronize GPU without metal backend");
}
// No-ops when Metal is not available.
size_t get_active_memory() {
return 0;
}
size_t get_peak_memory() {
return 0;
}
void reset_peak_memory() {}
size_t get_cache_memory() {
return 0;
}
size_t set_memory_limit(size_t) {
return 0;
}
size_t get_memory_limit() {
return 0;
}
size_t set_cache_limit(size_t) {
return 0;
}
size_t set_wired_limit(size_t) {
return 0;
}
void start_capture(std::string) {}
void stop_capture() {}
void clear_cache() {}
const std::unordered_map<std::string, std::variant<std::string, size_t>>&
device_info() {

View File

@@ -82,6 +82,7 @@ NO_GPU(LogicalNot)
NO_GPU(LogicalAnd)
NO_GPU(LogicalOr)
NO_GPU(LogAddExp)
NO_GPU(LogSumExp)
NO_GPU_MULTI(LUF)
NO_GPU(Matmul)
NO_GPU(Maximum)

View File

@@ -15,6 +15,14 @@ void all_sum(Group group, const array& input, array& output, Stream stream) {
group.raw_group()->all_sum(input, output, stream);
}
void all_max(Group group, const array& input, array& output, Stream stream) {
group.raw_group()->all_max(input, output, stream);
}
void all_min(Group group, const array& input, array& output, Stream stream) {
group.raw_group()->all_min(input, output, stream);
}
void all_gather(Group group, const array& input, array& output, Stream stream) {
group.raw_group()->all_gather(input, output, stream);
}
@@ -57,6 +65,16 @@ class EmptyGroup : public GroupImpl {
throw std::runtime_error(
"Communication not implemented in an empty distributed group.");
}
void all_max(const array&, array&, Stream) override {
throw std::runtime_error(
"Communication not implemented in an empty distributed group.");
}
void all_min(const array&, array&, Stream) override {
throw std::runtime_error(
"Communication not implemented in an empty distributed group.");
}
};
} // namespace detail

View File

@@ -21,6 +21,8 @@ class GroupImpl {
virtual void all_gather(const array& input, array& output, Stream stream) = 0;
virtual void send(const array& input, int dst, Stream stream) = 0;
virtual void recv(array& out, int src, Stream stream) = 0;
virtual void all_max(const array& input, array& output, Stream stream) = 0;
virtual void all_min(const array& input, array& output, Stream stream) = 0;
};
/* Perform an all reduce sum operation */
@@ -35,4 +37,10 @@ void send(Group group, const array& input, int dst, Stream stream);
/** Recv an array from the src rank */
void recv(Group group, array& out, int src, Stream stream);
/** Max reduction */
void all_max(Group group, const array& input, array& output, Stream stream);
/** Min reduction */
void all_min(Group group, const array& input, array& output, Stream stream);
} // namespace mlx::core::distributed::detail

View File

@@ -1,4 +1,4 @@
if(MPI_FOUND AND MLX_BUILD_CPU)
if(MLX_BUILD_CPU)
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/mpi.cpp)
else()
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_mpi.cpp)

View File

@@ -1,12 +1,13 @@
// Copyright © 2024 Apple Inc.
#include <dlfcn.h>
#include <mpi.h>
#include <iostream>
#include "mlx/backend/cpu/encoder.h"
#include "mlx/distributed/distributed.h"
#include "mlx/distributed/distributed_impl.h"
#include "mlx/distributed/mpi/mpi.h"
#include "mlx/distributed/mpi/mpi_declarations.h"
#define LOAD_SYMBOL(symbol, variable) \
{ \
@@ -18,6 +19,12 @@
} \
}
#ifdef __APPLE__
static constexpr const char* libmpi_name = "libmpi.dylib";
#else
static constexpr const char* libmpi_name = "libmpi.so";
#endif
namespace mlx::core::distributed::mpi {
using GroupImpl = mlx::core::distributed::detail::GroupImpl;
@@ -43,15 +50,69 @@ void simple_sum(
template void simple_sum<float16_t>(void*, void*, int*, MPI_Datatype*);
template void simple_sum<bfloat16_t>(void*, void*, int*, MPI_Datatype*);
template <typename T>
void simple_max(
void* input,
void* accumulator,
int* len,
MPI_Datatype* datatype) {
T* in = (T*)input;
T* acc = (T*)accumulator;
int N = *len;
while (N-- > 0) {
*acc = std::max(*acc, *in);
acc++;
in++;
}
}
template void simple_max<float16_t>(void*, void*, int*, MPI_Datatype*);
template void simple_max<bfloat16_t>(void*, void*, int*, MPI_Datatype*);
template void simple_max<complex64_t>(void*, void*, int*, MPI_Datatype*);
template <typename T>
void simple_min(
void* input,
void* accumulator,
int* len,
MPI_Datatype* datatype) {
T* in = (T*)input;
T* acc = (T*)accumulator;
int N = *len;
while (N-- > 0) {
*acc = std::min(*acc, *in);
acc++;
in++;
}
}
template void simple_min<float16_t>(void*, void*, int*, MPI_Datatype*);
template void simple_min<bfloat16_t>(void*, void*, int*, MPI_Datatype*);
template void simple_min<complex64_t>(void*, void*, int*, MPI_Datatype*);
struct MPIWrapper {
MPIWrapper() {
initialized_ = false;
libmpi_handle_ = dlopen("libmpi.dylib", RTLD_NOW | RTLD_GLOBAL);
libmpi_handle_ = dlopen(libmpi_name, RTLD_NOW | RTLD_GLOBAL);
if (libmpi_handle_ == nullptr) {
return;
}
// Check library version and warn if it isn't Open MPI
int (*get_version)(char*, int*);
LOAD_SYMBOL(MPI_Get_library_version, get_version);
char version_ptr[MPI_MAX_LIBRARY_VERSION_STRING];
int version_length = 0;
get_version(version_ptr, &version_length);
std::string_view version(version_ptr, version_length);
if (version.find("Open MPI") == std::string::npos) {
std::cerr << "[mpi] MPI found but it does not appear to be Open MPI."
<< "MLX requires Open MPI but this is " << version << std::endl;
libmpi_handle_ = nullptr;
return;
}
// API
LOAD_SYMBOL(MPI_Init, init);
LOAD_SYMBOL(MPI_Finalize, finalize);
@@ -72,6 +133,8 @@ struct MPIWrapper {
// Ops
LOAD_SYMBOL(ompi_mpi_op_sum, op_sum_);
LOAD_SYMBOL(ompi_mpi_op_max, op_max_);
LOAD_SYMBOL(ompi_mpi_op_min, op_min_);
// Datatypes
LOAD_SYMBOL(ompi_mpi_c_bool, mpi_bool_);
@@ -106,9 +169,15 @@ struct MPIWrapper {
mpi_type_contiguous(2, mpi_uint8_, &mpi_bfloat16_);
mpi_type_commit(&mpi_bfloat16_);
// Custom sum ops
// Custom reduction ops
mpi_op_create(&simple_sum<float16_t>, 1, &op_sum_f16_);
mpi_op_create(&simple_sum<bfloat16_t>, 1, &op_sum_bf16_);
mpi_op_create(&simple_max<float16_t>, 1, &op_max_f16_);
mpi_op_create(&simple_max<bfloat16_t>, 1, &op_max_bf16_);
mpi_op_create(&simple_max<complex64_t>, 1, &op_max_c64_);
mpi_op_create(&simple_min<float16_t>, 1, &op_min_f16_);
mpi_op_create(&simple_min<bfloat16_t>, 1, &op_min_bf16_);
mpi_op_create(&simple_min<complex64_t>, 1, &op_min_c64_);
initialized_ = true;
}
@@ -170,6 +239,32 @@ struct MPIWrapper {
}
}
MPI_Op op_max(const array& arr) {
switch (arr.dtype()) {
case float16:
return op_max_f16_;
case bfloat16:
return op_max_bf16_;
case complex64:
return op_max_c64_;
default:
return op_max_;
}
}
MPI_Op op_min(const array& arr) {
switch (arr.dtype()) {
case float16:
return op_min_f16_;
case bfloat16:
return op_min_bf16_;
case complex64:
return op_min_c64_;
default:
return op_min_;
}
}
void* libmpi_handle_;
// API
@@ -198,6 +293,14 @@ struct MPIWrapper {
MPI_Op op_sum_;
MPI_Op op_sum_f16_;
MPI_Op op_sum_bf16_;
MPI_Op op_max_;
MPI_Op op_max_f16_;
MPI_Op op_max_bf16_;
MPI_Op op_max_c64_;
MPI_Op op_min_;
MPI_Op op_min_f16_;
MPI_Op op_min_bf16_;
MPI_Op op_min_c64_;
// Datatypes
MPI_Datatype mpi_bool_;
@@ -285,6 +388,36 @@ class MPIGroup : public GroupImpl {
comm_);
}
void all_max(const array& input, array& output, Stream stream) override {
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(input);
encoder.set_output_array(output);
encoder.dispatch(
mpi().all_reduce,
(input.data<void>() == output.data<void>()) ? MPI_IN_PLACE
: input.data<void>(),
output.data<void>(),
input.size(),
mpi().datatype(input),
mpi().op_max(input),
comm_);
}
void all_min(const array& input, array& output, Stream stream) override {
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(input);
encoder.set_output_array(output);
encoder.dispatch(
mpi().all_reduce,
(input.data<void>() == output.data<void>()) ? MPI_IN_PLACE
: input.data<void>(),
output.data<void>(),
input.size(),
mpi().datatype(input),
mpi().op_min(input),
comm_);
}
void all_gather(const array& input, array& output, Stream stream) override {
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(input);

View File

@@ -0,0 +1,28 @@
// Copyright © 2024 Apple Inc.
// Constants
#define MPI_SUCCESS 0
#define MPI_ANY_SOURCE -1
#define MPI_ANY_TAG -1
#define MPI_IN_PLACE ((void*)1)
#define MPI_MAX_LIBRARY_VERSION_STRING 256
// Define all the types that we use so that we don't include <mpi.h> which
// causes linker errors on some platforms.
//
// NOTE: We define everything for openmpi.
typedef void* MPI_Comm;
typedef void* MPI_Datatype;
typedef void* MPI_Op;
typedef void(MPI_User_function)(void*, void*, int*, MPI_Datatype*);
typedef struct ompi_status_public_t {
int MPI_SOURCE;
int MPI_TAG;
int MPI_ERROR;
int _cancelled;
size_t _ucount;
} MPI_Status;

View File

@@ -36,6 +36,40 @@ array all_sum(
{x});
}
array all_max(
const array& x,
std::optional<Group> group_ /* = std::nullopt */,
StreamOrDevice s /* = {} */) {
auto group = to_group(group_);
if (group.size() == 1) {
return x;
}
return array(
x.shape(),
x.dtype(),
std::make_shared<AllReduce>(
to_stream(s, Device::cpu), group, AllReduce::Max),
{x});
}
array all_min(
const array& x,
std::optional<Group> group_ /* = std::nullopt */,
StreamOrDevice s /* = {} */) {
auto group = to_group(group_);
if (group.size() == 1) {
return x;
}
return array(
x.shape(),
x.dtype(),
std::make_shared<AllReduce>(
to_stream(s, Device::cpu), group, AllReduce::Min),
{x});
}
array all_gather(
const array& x,
std::optional<Group> group_ /* = std::nullopt */,

View File

@@ -38,4 +38,14 @@ array recv_like(
std::optional<Group> group = std::nullopt,
StreamOrDevice s = {});
array all_max(
const array& x,
std::optional<Group> group = std::nullopt,
StreamOrDevice s = {});
array all_min(
const array& x,
std::optional<Group> group = std::nullopt,
StreamOrDevice s = {});
} // namespace mlx::core::distributed

Some files were not shown because too many files have changed in this diff Show More