Merge branch 'ml-explore:main' into adding-Muon-optimizer

This commit is contained in:
Gökdeniz Gülmez
2025-04-21 20:27:33 +02:00
committed by GitHub
166 changed files with 7312 additions and 2043 deletions

View File

@@ -24,8 +24,8 @@ jobs:
      type: boolean
      default: false
    macos:
-     xcode: "15.2.0"
+     xcode: "16.2.0"
-     resource_class: macos.m1.medium.gen1
+     resource_class: m2pro.medium
    steps:
      - checkout
      - run:
@@ -89,15 +89,14 @@ jobs:
            pip install numpy
            sudo apt-get update
            sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
+           sudo apt-get install openmpi-bin openmpi-common libopenmpi-dev
      - run:
          name: Install Python package
          command: |
-           CMAKE_ARGS="-DMLX_BUILD_METAL=OFF
-           CMAKE_COMPILE_WARNING_AS_ERROR=ON" \
+           CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" \
            CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
            python3 setup.py build_ext --inplace
-           CMAKE_ARGS="-DMLX_BUILD_METAL=OFF \
-           CMAKE_COMPILE_WARNING_AS_ERROR=ON" \
+           CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" \
            CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
            python3 setup.py develop
      - run:
@@ -110,6 +109,8 @@ jobs:
          name: Run Python tests
          command: |
            python3 -m unittest discover python/tests -v
+           mpirun --bind-to none -host localhost:8 -np 8 python python/tests/mpi_test_distributed.py
+           mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py
      - run:
          name: Build CPP only
          command: |
@@ -124,10 +125,15 @@ jobs:
    parameters:
      xcode_version:
        type: string
-       default: "15.2.0"
+       default: "16.2.0"
+     macosx_deployment_target:
+       type: string
+       default: ""
    macos:
      xcode: << parameters.xcode_version >>
-     resource_class: macos.m1.medium.gen1
+     environment:
+       MACOSX_DEPLOYMENT_TARGET: << parameters.macosx_deployment_target >>
+     resource_class: m2pro.medium
    steps:
      - checkout
      - run:
@@ -149,7 +155,7 @@ jobs:
          command: |
            source env/bin/activate
            DEBUG=1 CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
-           CMAKE_ARGS="CMAKE_COMPILE_WARNING_AS_ERROR=ON" \
+           CMAKE_ARGS="-DCMAKE_COMPILE_WARNING_AS_ERROR=ON" \
            pip install -e . -v
      - run:
          name: Generate package stubs
@@ -213,13 +219,18 @@ jobs:
default: "3.9" default: "3.9"
xcode_version: xcode_version:
type: string type: string
default: "15.2.0" default: "16.2.0"
build_env: build_env:
type: string type: string
default: "" default: ""
macosx_deployment_target:
type: string
default: ""
macos: macos:
xcode: << parameters.xcode_version >> xcode: << parameters.xcode_version >>
resource_class: macos.m1.medium.gen1 resource_class: m2pro.medium
environment:
MACOSX_DEPLOYMENT_TARGET: << parameters.macosx_deployment_target >>
steps: steps:
- checkout - checkout
- run: - run:
@@ -240,7 +251,7 @@ jobs:
          name: Install Python package
          command: |
            source env/bin/activate
-           DEV_RELEASE=1 \
+           env -u MACOSX_DEPLOYMENT_TARGET DEV_RELEASE=1 \
            CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
            pip install . -v
      - run:
@@ -335,7 +346,7 @@ workflows:
      - mac_build_and_test:
          matrix:
            parameters:
-             xcode_version: ["15.0.0", "15.2.0", "16.0.0"]
+             macosx_deployment_target: ["13.5", "14.0"]
      - linux_build_and_test
      - build_documentation
@@ -355,8 +366,70 @@ workflows:
          matrix:
            parameters:
              python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
-             xcode_version: ["15.0.0", "15.2.0"]
+             macosx_deployment_target: ["13.5", "14.0", "15.0"]
              build_env: ["PYPI_RELEASE=1"]
+             xcode_version: ["16.2.0", "15.0.0"]
exclude:
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.9"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.10"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.11"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.12"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.13"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.9"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.10"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.11"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.12"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.13"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.9"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.10"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.11"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.12"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.13"
build_env: "PYPI_RELEASE=1"
      - build_documentation:
          filters:
            tags:
@@ -379,7 +452,7 @@ workflows:
          requires: [ hold ]
          matrix:
            parameters:
-             xcode_version: ["15.0.0", "15.2.0", "16.0.0"]
+             macosx_deployment_target: ["13.5", "14.0"]
      - linux_build_and_test:
          requires: [ hold ]
  nightly_build:
@@ -392,7 +465,54 @@ workflows:
          matrix:
            parameters:
              python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
-             xcode_version: ["15.0.0", "15.2.0"]
+             macosx_deployment_target: ["13.5", "14.0", "15.0"]
+             xcode_version: ["16.2.0", "15.0.0"]
exclude:
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.9"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.10"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.11"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.12"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.13"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.9"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.10"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.11"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.12"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.13"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.9"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.10"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.11"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.12"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.13"
  weekly_build:
    when:
      and:
@@ -403,8 +523,70 @@ workflows:
          matrix:
            parameters:
              python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
-             xcode_version: ["15.0.0", "15.2.0", "16.0.0"]
+             macosx_deployment_target: ["13.5", "14.0", "15.0"]
              build_env: ["DEV_RELEASE=1"]
+             xcode_version: ["16.2.0", "15.0.0"]
exclude:
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.9"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.10"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.11"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.12"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.13"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.9"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.10"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.11"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.12"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.13"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.9"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.10"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.11"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.12"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.13"
build_env: "DEV_RELEASE=1"
  linux_test_release:
    when:
      and:

View File

@@ -212,24 +212,6 @@ else()
  set(MLX_BUILD_ACCELERATE OFF)
endif()
find_package(MPI)
if(MPI_FOUND)
execute_process(
COMMAND zsh "-c" "mpirun --version"
OUTPUT_VARIABLE MPI_VERSION
ERROR_QUIET)
if(${MPI_VERSION} MATCHES ".*Open MPI.*")
target_include_directories(mlx PRIVATE ${MPI_INCLUDE_PATH})
elseif(MPI_VERSION STREQUAL "")
set(MPI_FOUND FALSE)
message(
WARNING "MPI found but mpirun is not available. Building without MPI.")
else()
set(MPI_FOUND FALSE)
message(WARNING "MPI which is not OpenMPI found. Building without MPI.")
endif()
endif()
message(STATUS "Downloading json") message(STATUS "Downloading json")
FetchContent_Declare( FetchContent_Declare(
json json

View File

@@ -17,11 +17,11 @@ possible.
You can also run the formatters manually as follows:

-```
+```shell
clang-format -i file.cpp
```

-```
+```shell
black file.py
```

View File

@@ -0,0 +1,74 @@
# Copyright © 2025 Apple Inc.
import mlx.core as mx
from time_utils import time_fn
N = 1024
D = 1024
M = 1024
E = 32
I = 4
def gather_sort(x, indices):
N, M = indices.shape
indices = indices.flatten()
order = mx.argsort(indices)
inv_order = mx.argsort(order)
return x.flatten(0, -3)[order // M], indices[order], inv_order
def scatter_unsort(x, inv_order, shape=None):
x = x[inv_order]
if shape is not None:
x = mx.unflatten(x, 0, shape)
return x
def gather_mm_simulate(x, w, indices):
x, idx, inv_order = gather_sort(x, indices)
for i in range(2):
y = mx.concatenate([x[i] @ w[j].T for i, j in enumerate(idx.tolist())], axis=0)
x = y[:, None]
x = scatter_unsort(x, inv_order, indices.shape)
return x
def time_gather_mm():
x = mx.random.normal((N, 1, 1, D)) / 1024**0.5
w1 = mx.random.normal((E, M, D)) / 1024**0.5
w2 = mx.random.normal((E, D, M)) / 1024**0.5
indices = (mx.random.uniform(shape=(N, I)) * E).astype(mx.uint32)
sorted_indices = mx.sort(indices.flatten()).reshape(N, I)
mx.eval(x, w1, w2, indices, sorted_indices)
def gather_mm(x, w1, w2, indices, sort):
idx = indices
inv_order = None
if sort:
x, idx, inv_order = gather_sort(x, indices)
x = mx.gather_mm(x, w1.swapaxes(-1, -2), rhs_indices=idx, sorted_indices=sort)
x = mx.gather_mm(x, w2.swapaxes(-1, -2), rhs_indices=idx, sorted_indices=sort)
if sort:
x = scatter_unsort(x, inv_order, indices.shape)
return x
time_fn(gather_mm, x, w1, w2, indices, False)
time_fn(gather_mm, x, w1, w2, sorted_indices, False)
time_fn(gather_mm, x, w1, w2, indices, True)
x = mx.random.normal((N * I, D)) / 1024**0.5
w1 = mx.random.normal((M, D)) / 1024**0.5
w2 = mx.random.normal((D, M)) / 1024**0.5
mx.eval(x, w1, w2)
def equivalent_matmul(x, w1, w2):
x = x @ w1.T
x = x @ w2.T
return x
time_fn(equivalent_matmul, x, w1, w2)
if __name__ == "__main__":
time_gather_mm()

View File

@@ -0,0 +1,84 @@
# Copyright © 2025 Apple Inc.
import mlx.core as mx
from time_utils import time_fn
N = 1024
D = 1024
M = 1024
E = 32
I = 4
def gather_sort(x, indices):
N, M = indices.shape
indices = indices.flatten()
order = mx.argsort(indices)
inv_order = mx.argsort(order)
return x.flatten(0, -3)[order // M], indices[order], inv_order
def scatter_unsort(x, inv_order, shape=None):
x = x[inv_order]
if shape is not None:
x = mx.unflatten(x, 0, shape)
return x
def gather_mm_simulate(x, w, indices):
x, idx, inv_order = gather_sort(x, indices)
for i in range(2):
y = mx.concatenate(
[
mx.quantized_matmul(x[i], w[0][j], w[1][j], w[2][j], transpose=True)
for i, j in enumerate(idx.tolist())
],
axis=0,
)
x = y[:, None]
x = scatter_unsort(x, inv_order, indices.shape)
return x
def time_gather_qmm():
x = mx.random.normal((N, 1, 1, D)) / 1024**0.5
w1 = mx.random.normal((E, M, D)) / 1024**0.5
w2 = mx.random.normal((E, D, M)) / 1024**0.5
w1 = mx.quantize(w1)
w2 = mx.quantize(w2)
indices = (mx.random.uniform(shape=(N, I)) * E).astype(mx.uint32)
sorted_indices = mx.sort(indices.flatten()).reshape(N, I)
mx.eval(x, w1, w2, indices, sorted_indices)
def gather_mm(x, w1, w2, indices, sort):
idx = indices
inv_order = None
if sort:
x, idx, inv_order = gather_sort(x, indices)
x = mx.gather_qmm(x, *w1, transpose=True, rhs_indices=idx, sorted_indices=sort)
x = mx.gather_qmm(x, *w2, transpose=True, rhs_indices=idx, sorted_indices=sort)
if sort:
x = scatter_unsort(x, inv_order, indices.shape)
return x
time_fn(gather_mm, x, w1, w2, indices, False)
time_fn(gather_mm, x, w1, w2, sorted_indices, False)
time_fn(gather_mm, x, w1, w2, indices, True)
x = mx.random.normal((N * I, D)) / 1024**0.5
w1 = mx.random.normal((M, D)) / 1024**0.5
w2 = mx.random.normal((D, M)) / 1024**0.5
w1 = mx.quantize(w1)
w2 = mx.quantize(w2)
mx.eval(x, w1, w2)
def equivalent_matmul(x, w1, w2):
x = mx.quantized_matmul(x, *w1, transpose=True)
x = mx.quantized_matmul(x, *w2, transpose=True)
return x
time_fn(equivalent_matmul, x, w1, w2)
if __name__ == "__main__":
time_gather_qmm()

View File

@@ -13,7 +13,7 @@ EXCLUDE_PATTERNS = */private/*
CREATE_SUBDIRS = NO
FULL_PATH_NAMES = YES
RECURSIVE = YES
-GENERATE_HTML = YES
+GENERATE_HTML = NO
GENERATE_LATEX = NO
GENERATE_XML = YES
XML_PROGRAMLISTING = YES

View File

@@ -93,9 +93,9 @@ Primitives
^^^^^^^^^^^

A :class:`Primitive` is part of the computation graph of an :class:`array`. It
-defines how to create outputs arrays given a input arrays. Further, a
+defines how to create output arrays given input arrays. Further, a
:class:`Primitive` has methods to run on the CPU or GPU and for function
-transformations such as ``vjp`` and ``jvp``. Lets go back to our example to be
+transformations such as ``vjp`` and ``jvp``. Let's go back to our example to be
more concrete:

.. code-block:: C++
@@ -128,7 +128,7 @@ more concrete:
    /** The vector-Jacobian product. */
    std::vector<array> vjp(
        const std::vector<array>& primals,
-       const array& cotan,
+       const std::vector<array>& cotangents,
        const std::vector<int>& argnums,
        const std::vector<array>& outputs) override;
@@ -469,7 +469,7 @@ one we just defined:
      const std::vector<array>& tangents,
      const std::vector<int>& argnums) {
    // Forward mode diff that pushes along the tangents
-   // The jvp transform on the primitive can built with ops
+   // The jvp transform on the primitive can be built with ops
    // that are scheduled on the same stream as the primitive

    // If argnums = {0}, we only push along x in which case the
@@ -481,7 +481,7 @@ one we just defined:
      auto scale_arr = array(scale, tangents[0].dtype());
      return {multiply(scale_arr, tangents[0], stream())};
    }
-   // If, argnums = {0, 1}, we take contributions from both
+   // If argnums = {0, 1}, we take contributions from both
    // which gives us jvp = tangent_x * alpha + tangent_y * beta
    else {
      return {axpby(tangents[0], tangents[1], alpha_, beta_, stream())};
@@ -735,7 +735,7 @@ Let's look at a simple script and its results:
print(f"c shape: {c.shape}") print(f"c shape: {c.shape}")
print(f"c dtype: {c.dtype}") print(f"c dtype: {c.dtype}")
print(f"c correct: {mx.all(c == 6.0).item()}") print(f"c is correct: {mx.all(c == 6.0).item()}")
Output: Output:
@@ -743,7 +743,7 @@ Output:
  c shape: [3, 4]
  c dtype: float32
- c correctness: True
+ c is correct: True

Results
^^^^^^^

View File

@@ -70,6 +70,7 @@ are the CPU and GPU.
   python/fft
   python/linalg
   python/metal
+  python/memory_management
   python/nn
   python/optimizers
   python/distributed

View File

@@ -38,6 +38,7 @@ Array
   array.log10
   array.log1p
   array.log2
+  array.logcumsumexp
   array.logsumexp
   array.max
   array.mean

View File

@@ -20,5 +20,6 @@ Linear Algebra
   eigh
   lu
   lu_factor
+  pinv
   solve
   solve_triangular

View File

@@ -0,0 +1,16 @@
Memory Management
=================
.. currentmodule:: mlx.core
.. autosummary::
:toctree: _autosummary
get_active_memory
get_peak_memory
reset_peak_memory
get_cache_memory
set_memory_limit
set_cache_limit
set_wired_limit
clear_cache
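A short usage sketch of the helpers listed on this new page (array sizes are arbitrary and chosen only for illustration):

```python
import mlx.core as mx

a = mx.random.normal((4096, 4096))
mx.eval(a @ a)

# These helpers now live on mlx.core rather than mlx.core.metal.
print(mx.get_active_memory() / 1e6, "MB active")
print(mx.get_peak_memory() / 1e6, "MB peak")
mx.clear_cache()  # release cached buffers back to the system
```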

View File

@@ -8,13 +8,5 @@ Metal
   is_available
   device_info
-  get_active_memory
-  get_peak_memory
-  reset_peak_memory
-  get_cache_memory
-  set_memory_limit
-  set_cache_limit
-  set_wired_limit
-  clear_cache
   start_capture
   stop_capture

View File

@@ -36,10 +36,12 @@ Operations
   bitwise_or
   bitwise_xor
   block_masked_mm
+  broadcast_arrays
   broadcast_to
   ceil
   clip
   concatenate
+  contiguous
   conj
   conjugate
   convolve
@@ -101,6 +103,7 @@ Operations
   log10
   log1p
   logaddexp
+  logcumsumexp
   logical_not
   logical_and
   logical_or
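``logcumsumexp`` is the cumulative analogue of ``logsumexp``. An illustrative sketch; the keyword arguments are assumed to mirror ``cumsum``:

```python
import mlx.core as mx

x = mx.random.normal((8,))
y = mx.logcumsumexp(x, axis=0)

# Mathematically equivalent, but numerically less stable, reference:
ref = mx.log(mx.cumsum(mx.exp(x), axis=0))
```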

View File

@@ -18,3 +18,4 @@ Common Optimizers
   AdamW
   Adamax
   Lion
+  MultiOptimizer
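A hypothetical sketch of ``MultiOptimizer``; the constructor arguments shown here (a list of optimizers plus parameter-path filters, with the last optimizer acting as the fallback) are assumptions for illustration, not taken from the documented API:

```python
import mlx.optimizers as optim

# Assumed signature: route parameters whose path contains "bias" to SGD,
# everything else to AdamW.
opt = optim.MultiOptimizer(
    [optim.SGD(learning_rate=1e-2), optim.AdamW(learning_rate=1e-3)],
    [lambda path, _: "bias" in path],
)
```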

View File

@@ -9,6 +9,7 @@ Transforms
   :toctree: _autosummary

   eval
+  async_eval
   compile
   custom_function
   disable_compile
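``async_eval`` launches evaluation without blocking; a minimal illustrative sketch:

```python
import mlx.core as mx

x = mx.random.normal((1024, 1024))
y = (x @ x).sum()

mx.async_eval(y)  # kick off the computation asynchronously
# ... do other Python-side work here ...
mx.eval(y)        # block only when the value is actually needed
```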

View File

@@ -5,6 +5,7 @@ target_sources(
          ${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp
+         ${CMAKE_CURRENT_SOURCE_DIR}/dtype_utils.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/export.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/einsum.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp

View File

@@ -4,7 +4,6 @@
#include <sstream>

#include "mlx/allocator.h"
-#include "mlx/scheduler.h"

namespace mlx::core::allocator {

@@ -22,23 +21,4 @@ void free(Buffer buffer) {
  allocator().free(buffer);
}

-Buffer CommonAllocator::malloc(size_t size) {
-  void* ptr = std::malloc(size + sizeof(size_t));
-  if (ptr != nullptr) {
-    *static_cast<size_t*>(ptr) = size;
-  }
-  return Buffer{ptr};
-}
-
-void CommonAllocator::free(Buffer buffer) {
-  std::free(buffer.ptr());
-}
-
-size_t CommonAllocator::size(Buffer buffer) const {
-  if (buffer.ptr() == nullptr) {
-    return 0;
-  }
-  return *static_cast<size_t*>(buffer.ptr());
-}
-
} // namespace mlx::core::allocator

View File

@@ -49,16 +49,4 @@ class Allocator {
Allocator& allocator();

-class CommonAllocator : public Allocator {
-  /** A general CPU allocator. */
- public:
-  virtual Buffer malloc(size_t size) override;
-  virtual void free(Buffer buffer) override;
-  virtual size_t size(Buffer buffer) const override;
-
- private:
-  CommonAllocator() = default;
-  friend Allocator& allocator();
-};
-
} // namespace mlx::core::allocator

View File

@@ -339,11 +339,11 @@ class array {
    return allocator::allocator().size(buffer());
  }

- // Return a copy of the shared pointer
- // to the array::Data struct
- std::shared_ptr<Data> data_shared_ptr() const {
+ // Return the shared pointer to the array::Data struct
+ const std::shared_ptr<Data>& data_shared_ptr() const {
    return array_desc_->data;
  }

  // Return a raw pointer to the arrays data
  template <typename T>
  T* data() {

View File

@@ -1,6 +1,7 @@
target_sources(
  mlx
- PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
+ PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/broadcasting.cpp
+         ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/common.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp

View File

@@ -0,0 +1,24 @@
// Copyright © 2024 Apple Inc.
#include "mlx/backend/common/utils.h"
namespace mlx::core {
void broadcast(const array& in, array& out) {
if (out.size() == 0) {
out.set_data(nullptr);
return;
}
Strides strides(out.ndim(), 0);
int diff = out.ndim() - in.ndim();
for (int i = in.ndim() - 1; i >= 0; --i) {
strides[i + diff] = (in.shape()[i] == 1) ? 0 : in.strides()[i];
}
auto flags = in.flags();
if (out.size() > in.size()) {
flags.row_contiguous = flags.col_contiguous = false;
}
out.copy_shared_buffer(in, strides, flags, in.data_size());
}
} // namespace mlx::core
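For reference, the same zero-stride idea expressed in NumPy terms (an illustration of the technique, not MLX code): a broadcast output is just the input viewed with stride 0 along the expanded axes, so no data is copied.

```python
import numpy as np

x = np.arange(3, dtype=np.float32)  # shape (3,), strides (4,)
b = np.lib.stride_tricks.as_strided(x, shape=(4, 3), strides=(0, 4))
assert np.shares_memory(x, b)  # a view, not a copy
```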

View File

@@ -0,0 +1,11 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include "mlx/array.h"
namespace mlx::core {
void broadcast(const array& in, array& out);
} // namespace mlx::core

View File

@@ -1,6 +1,7 @@
// Copyright © 2024 Apple Inc.

#include <cassert>

+#include "mlx/backend/common/broadcasting.h"
#include "mlx/backend/common/utils.h"
#include "mlx/primitives.h"
@@ -42,23 +43,6 @@ void AsStrided::eval(const std::vector<array>& inputs, array& out) {
  return out.copy_shared_buffer(in, strides_, flags, data_size, offset_);
}
void broadcast(const array& in, array& out) {
if (out.size() == 0) {
out.set_data(nullptr);
return;
}
Strides strides(out.ndim(), 0);
int diff = out.ndim() - in.ndim();
for (int i = in.ndim() - 1; i >= 0; --i) {
strides[i + diff] = (in.shape()[i] == 1) ? 0 : in.strides()[i];
}
auto flags = in.flags();
if (out.size() > in.size()) {
flags.row_contiguous = flags.col_contiguous = false;
}
out.copy_shared_buffer(in, strides, flags, in.data_size());
}
void Broadcast::eval(const std::vector<array>& inputs, array& out) {
  broadcast(inputs[0], out);
}

View File

@@ -58,6 +58,7 @@ target_sources(
          ${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/select.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
+         ${CMAKE_CURRENT_SOURCE_DIR}/logsumexp.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/threefry.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
@@ -73,8 +74,8 @@ target_sources(
if(MLX_BUILD_ACCELERATE)
  target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/bnns.cpp)
else()
- target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/no_fp16.cpp
-                            ${CMAKE_CURRENT_SOURCE_DIR}/gemms/no_bf16.cpp)
+ target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/simd_fp16.cpp
+                            ${CMAKE_CURRENT_SOURCE_DIR}/gemms/simd_bf16.cpp)
endif()

if(IOS)

View File

@@ -46,8 +46,15 @@ void AllReduce::eval_cpu(
    case Sum:
      distributed::detail::all_sum(group(), in, outputs[0], stream());
      break;
+   case Max:
+     distributed::detail::all_max(group(), in, outputs[0], stream());
+     break;
+   case Min:
+     distributed::detail::all_min(group(), in, outputs[0], stream());
+     break;
    default:
-     throw std::runtime_error("Only all reduce sum is supported for now");
+     throw std::runtime_error(
+         "Only all reduce sum, min and max are supported for now");
  }
}
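An illustrative sketch of the new reductions from Python; the ``all_max``/``all_min`` entry points are assumed to sit alongside the existing ``all_sum``, which this CPU path backs:

```python
import mlx.core as mx
import mlx.core.distributed as dist

world = dist.init()
x = mx.array([float(world.rank())])

total = dist.all_sum(x)     # existing reduction
largest = dist.all_max(x)   # assumed entry point for the new Max case
smallest = dist.all_min(x)  # assumed entry point for the new Min case
mx.eval(total, largest, smallest)
```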

View File

@@ -1,27 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cpu/gemm.h"
namespace mlx::core {
template <>
void matmul<bfloat16_t>(
const bfloat16_t*,
const bfloat16_t*,
bfloat16_t*,
bool,
bool,
size_t,
size_t,
size_t,
float,
float,
size_t,
const Shape&,
const Strides&,
const Shape&,
const Strides&) {
throw std::runtime_error("[Matmul::eval_cpu] bfloat16 not supported.");
}
} // namespace mlx::core

View File

@@ -1,27 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cpu/gemm.h"
namespace mlx::core {
template <>
void matmul<float16_t>(
const float16_t*,
const float16_t*,
float16_t*,
bool,
bool,
size_t,
size_t,
size_t,
float,
float,
size_t,
const Shape&,
const Strides&,
const Shape&,
const Strides&) {
throw std::runtime_error("[Matmul::eval_cpu] float16 not supported.");
}
} // namespace mlx::core

View File

@@ -0,0 +1,45 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/gemm.h"
#include "mlx/backend/cpu/gemms/simd_gemm.h"
namespace mlx::core {
template <>
void matmul<bfloat16_t>(
const bfloat16_t* a,
const bfloat16_t* b,
bfloat16_t* out,
bool a_transposed,
bool b_transposed,
size_t lda,
size_t ldb,
size_t ldc,
float alpha,
float beta,
size_t batch_size,
const Shape& a_shape,
const Strides& a_strides,
const Shape& b_shape,
const Strides& b_strides) {
auto ndim = a_shape.size();
size_t M = a_shape[ndim - 2];
size_t N = b_shape[ndim - 1];
size_t K = a_shape[ndim - 1];
for (int i = 0; i < batch_size; ++i) {
simd_gemm<bfloat16_t, float>(
a + elem_to_loc(M * K * i, a_shape, a_strides),
b + elem_to_loc(K * N * i, b_shape, b_strides),
out + M * N * i,
a_transposed,
b_transposed,
M,
N,
K,
alpha,
beta);
}
}
} // namespace mlx::core

View File

@@ -0,0 +1,45 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/gemm.h"
#include "mlx/backend/cpu/gemms/simd_gemm.h"
namespace mlx::core {
template <>
void matmul<float16_t>(
const float16_t* a,
const float16_t* b,
float16_t* out,
bool a_transposed,
bool b_transposed,
size_t lda,
size_t ldb,
size_t ldc,
float alpha,
float beta,
size_t batch_size,
const Shape& a_shape,
const Strides& a_strides,
const Shape& b_shape,
const Strides& b_strides) {
auto ndim = a_shape.size();
size_t M = a_shape[ndim - 2];
size_t N = b_shape[ndim - 1];
size_t K = a_shape[ndim - 1];
for (int i = 0; i < batch_size; ++i) {
simd_gemm<float16_t, float>(
a + elem_to_loc(M * K * i, a_shape, a_strides),
b + elem_to_loc(K * N * i, b_shape, b_strides),
out + M * N * i,
a_transposed,
b_transposed,
M,
N,
K,
alpha,
beta);
}
}
} // namespace mlx::core

View File

@@ -0,0 +1,139 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include "mlx/backend/cpu/simd/simd.h"
namespace mlx::core {
inline int ceildiv(int a, int b) {
return (a + b - 1) / b;
}
template <int block_size, typename T, typename AccT>
void load_block(
const T* in,
AccT* out,
int M,
int N,
int i,
int j,
bool transpose) {
if (transpose) {
for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) {
for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) {
out[jj * block_size + ii] =
in[(i * block_size + ii) * N + j * block_size + jj];
}
}
} else {
for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) {
for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) {
out[ii * block_size + jj] =
in[(i * block_size + ii) * N + j * block_size + jj];
}
}
}
}
template <typename T, typename AccT>
void simd_gemm(
const T* a,
const T* b,
T* c,
bool a_trans,
bool b_trans,
int M,
int N,
int K,
float alpha,
float beta) {
constexpr int block_size = 16;
constexpr int simd_size = simd::max_size<AccT>;
static_assert(
(block_size % simd_size) == 0,
"Block size must be divisible by SIMD size");
int last_k_block_size = K - block_size * (K / block_size);
int last_k_simd_block = (last_k_block_size / simd_size) * simd_size;
for (int i = 0; i < ceildiv(M, block_size); i++) {
for (int j = 0; j < ceildiv(N, block_size); j++) {
AccT c_block[block_size * block_size] = {0.0};
AccT a_block[block_size * block_size];
AccT b_block[block_size * block_size];
int k = 0;
for (; k < K / block_size; k++) {
// Load a and b blocks
if (a_trans) {
load_block<block_size>(a, a_block, K, M, k, i, true);
} else {
load_block<block_size>(a, a_block, M, K, i, k, false);
}
if (b_trans) {
load_block<block_size>(b, b_block, N, K, j, k, false);
} else {
load_block<block_size>(b, b_block, K, N, k, j, true);
}
// Multiply and accumulate
for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) {
for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) {
for (int kk = 0; kk < block_size; kk += simd_size) {
auto av =
simd::load<AccT, simd_size>(a_block + ii * block_size + kk);
auto bv =
simd::load<AccT, simd_size>(b_block + jj * block_size + kk);
c_block[ii * block_size + jj] += simd::sum(av * bv);
}
}
}
}
if (last_k_block_size) {
// Load a and b blocks
if (a_trans) {
load_block<block_size>(a, a_block, K, M, k, i, true);
} else {
load_block<block_size>(a, a_block, M, K, i, k, false);
}
if (b_trans) {
load_block<block_size>(b, b_block, N, K, j, k, false);
} else {
load_block<block_size>(b, b_block, K, N, k, j, true);
}
// Multiply and accumulate
for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) {
for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) {
int kk = 0;
for (; kk < last_k_simd_block; kk += simd_size) {
auto av =
simd::load<AccT, simd_size>(a_block + ii * block_size + kk);
auto bv =
simd::load<AccT, simd_size>(b_block + jj * block_size + kk);
c_block[ii * block_size + jj] += simd::sum(av * bv);
}
for (; kk < last_k_block_size; ++kk) {
c_block[ii * block_size + jj] +=
a_block[ii * block_size + kk] * b_block[jj * block_size + kk];
}
}
}
}
// Store
for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) {
for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) {
auto c_idx = (i * block_size + ii) * N + j * block_size + jj;
if (beta != 0) {
c[c_idx] = static_cast<T>(
alpha * c_block[ii * block_size + jj] + beta * c[c_idx]);
} else {
c[c_idx] = static_cast<T>(alpha * c_block[ii * block_size + jj]);
}
}
}
}
}
}
} // namespace mlx::core
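The kernel's structure, restated as a small NumPy sketch (illustration only): one block-size-by-block-size tile of C is accumulated at a time from block panels of A and B, which is the same loop nest `simd_gemm` vectorizes.

```python
import numpy as np

def blocked_gemm(a, b, block=16):
    # Tiles of C accumulated over K in block-sized chunks.
    M, K = a.shape
    _, N = b.shape
    c = np.zeros((M, N), dtype=np.float32)
    for i in range(0, M, block):
        for j in range(0, N, block):
            acc = np.zeros((min(block, M - i), min(block, N - j)), np.float32)
            for k in range(0, K, block):
                acc += a[i:i + block, k:k + block] @ b[k:k + block, j:j + block]
            c[i:i + block, j:j + block] = acc
    return c

a = np.random.randn(70, 50).astype(np.float32)
b = np.random.randn(50, 90).astype(np.float32)
assert np.allclose(blocked_gemm(a, b), a @ b, atol=1e-3)
```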

View File

@@ -0,0 +1,140 @@
// Copyright © 2023-2024 Apple Inc.
#include <cassert>
#include <cmath>
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/simd/simd.h"
#include "mlx/primitives.h"
#include "mlx/types/limits.h"
namespace mlx::core {
namespace {
using namespace mlx::core::simd;
template <typename T, typename AccT>
void logsumexp(const array& in, array& out, Stream stream) {
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(in);
encoder.set_output_array(out);
const T* in_ptr = in.data<T>();
T* out_ptr = out.data<T>();
int M = in.shape().back();
int L = in.data_size() / M;
encoder.dispatch([in_ptr, out_ptr, M, L]() mutable {
constexpr int N = std::min(max_size<AccT>, max_size<T>);
const T* current_in_ptr;
for (int i = 0; i < L; i++, in_ptr += M, out_ptr += 1) {
// Find the maximum
current_in_ptr = in_ptr;
Simd<AccT, N> vmaximum(-numeric_limits<AccT>::infinity());
size_t s = M;
while (s >= N) {
Simd<AccT, N> vals = load<T, N>(current_in_ptr);
vmaximum = maximum(vals, vmaximum);
current_in_ptr += N;
s -= N;
}
AccT maximum = max(vmaximum);
while (s-- > 0) {
maximum = std::max(maximum, static_cast<AccT>(*current_in_ptr));
current_in_ptr++;
}
// Compute the normalizer and the exponentials
Simd<AccT, N> vnormalizer(0.0);
current_in_ptr = in_ptr;
s = M;
while (s >= N) {
Simd<AccT, N> vexp = load<T, N>(current_in_ptr);
vexp = exp(vexp - maximum);
vnormalizer = vnormalizer + vexp;
current_in_ptr += N;
s -= N;
}
AccT normalizer = sum(vnormalizer);
while (s-- > 0) {
AccT _exp = std::exp(*current_in_ptr - maximum);
normalizer += _exp;
current_in_ptr++;
}
// Normalize
*out_ptr = std::isinf(maximum)
? static_cast<T>(maximum)
: static_cast<T>(std::log(normalizer) + maximum);
}
});
}
} // namespace
void LogSumExp::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
// Make sure that the last dimension is contiguous
auto s = stream();
auto& encoder = cpu::get_command_encoder(s);
auto ensure_contiguous = [&s, &encoder](const array& x) {
if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
return x;
} else {
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
copy(x, x_copy, CopyType::General, s);
encoder.add_temporary(x_copy);
return x_copy;
}
};
auto in = ensure_contiguous(inputs[0]);
if (in.flags().row_contiguous) {
out.set_data(allocator::malloc(out.nbytes()));
} else {
auto n = in.shape(-1);
auto flags = in.flags();
auto strides = in.strides();
for (auto& s : strides) {
s /= n;
}
bool col_contig = strides[0] == 1;
for (int i = 1; col_contig && i < strides.size(); ++i) {
col_contig &=
(out.shape(i) == 1 || strides[i - 1] == out.shape(i) * strides[i]);
}
flags.col_contiguous = col_contig;
out.set_data(
allocator::malloc(in.nbytes() / n),
in.data_size() / n,
std::move(strides),
flags);
}
switch (in.dtype()) {
case float32:
logsumexp<float, float>(in, out, stream());
break;
case float16:
logsumexp<float16_t, float>(in, out, stream());
break;
case bfloat16:
logsumexp<bfloat16_t, float>(in, out, stream());
break;
case float64:
logsumexp<double, double>(in, out, stream());
break;
default:
throw std::runtime_error(
"[logsumexp] only supports floating point types");
break;
}
}
} // namespace mlx::core
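The reduction this kernel performs, written out as a small NumPy sketch (illustration only): subtract the row maximum before exponentiating so the sum cannot overflow, then add it back after the log.

```python
import numpy as np

def logsumexp_lastdim(x):
    m = x.max(axis=-1, keepdims=True)
    return np.log(np.exp(x - m).sum(axis=-1)) + m.squeeze(-1)

x = (50 * np.random.randn(4, 1024)).astype(np.float32)
ref = np.log(np.exp(x.astype(np.float64)).sum(axis=-1))
assert np.allclose(logsumexp_lastdim(x), ref, rtol=1e-4)
```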

View File

@@ -205,8 +205,10 @@ void Concatenate::eval_cpu(const std::vector<array>& inputs, array& out) {
void Contiguous::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  auto& in = inputs[0];
- if (in.flags().row_contiguous ||
-     (allow_col_major_ && in.flags().col_contiguous)) {
+ constexpr size_t extra_bytes = 16384;
+ if (in.buffer_size() <= out.nbytes() + extra_bytes &&
+     (in.flags().row_contiguous ||
+      (allow_col_major_ && in.flags().col_contiguous))) {
    out.copy_shared_buffer(in);
  } else {
    copy(in, out, CopyType::General, stream());

View File

@@ -3,6 +3,7 @@
#include <cassert>

#include "mlx/backend/common/utils.h"
+#include "mlx/backend/cpu/binary_ops.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/simd/simd.h"
@@ -226,6 +227,16 @@ void scan_dispatch(
      scan_op<T, U>(in, out, axis, reverse, inclusive, op, init);
      break;
    }
case Scan::LogAddExp: {
auto op = [](U a, T b) {
return detail::LogAddExp{}(a, static_cast<U>(b));
};
auto init = (issubdtype(in.dtype(), floating))
? static_cast<U>(-std::numeric_limits<float>::infinity())
: std::numeric_limits<U>::min();
scan_op<T, U>(in, out, axis, reverse, inclusive, op, init);
break;
}
  }
}

View File

@@ -17,7 +17,7 @@ struct ScalarT<float16_t, N> {
#endif

template <>
-static constexpr int max_size<float16_t> = N;
+inline constexpr int max_size<float16_t> = N;

#define SIMD_FP16_DEFAULT_UNARY(op) \
  template <>                       \

View File

@@ -83,25 +83,25 @@ struct Simd {
// Values chosen based on benchmarks on M3 Max
// TODO: consider choosing these more optimally
template <>
-static constexpr int max_size<int8_t> = 16;
+inline constexpr int max_size<int8_t> = 16;
template <>
-static constexpr int max_size<int16_t> = 16;
+inline constexpr int max_size<int16_t> = 16;
template <>
-static constexpr int max_size<int> = 8;
+inline constexpr int max_size<int> = 8;
template <>
-static constexpr int max_size<int64_t> = 4;
+inline constexpr int max_size<int64_t> = 4;
template <>
-static constexpr int max_size<uint8_t> = 16;
+inline constexpr int max_size<uint8_t> = 16;
template <>
-static constexpr int max_size<uint16_t> = 16;
+inline constexpr int max_size<uint16_t> = 16;
template <>
-static constexpr int max_size<uint32_t> = 8;
+inline constexpr int max_size<uint32_t> = 8;
template <>
-static constexpr int max_size<uint64_t> = 4;
+inline constexpr int max_size<uint64_t> = 4;
template <>
-static constexpr int max_size<float> = 8;
+inline constexpr int max_size<float> = 8;
template <>
-static constexpr int max_size<double> = 4;
+inline constexpr int max_size<double> = 4;

#define SIMD_DEFAULT_UNARY(name, op) \
  template <typename T, int N>       \

View File

@@ -87,7 +87,6 @@ DEFAULT_UNARY(cosh, std::cosh)
DEFAULT_UNARY(expm1, std::expm1)
DEFAULT_UNARY(floor, std::floor)
DEFAULT_UNARY(log, std::log)
-DEFAULT_UNARY(log2, std::log2)
DEFAULT_UNARY(log10, std::log10)
DEFAULT_UNARY(log1p, std::log1p)
DEFAULT_UNARY(sinh, std::sinh)
@@ -95,6 +94,17 @@ DEFAULT_UNARY(sqrt, std::sqrt)
DEFAULT_UNARY(tan, std::tan)
DEFAULT_UNARY(tanh, std::tanh)
template <typename T>
Simd<T, 1> log2(Simd<T, 1> in) {
if constexpr (is_complex<T>) {
auto out = std::log(in.value);
auto scale = decltype(out.real())(M_LN2);
return Simd<T, 1>{T{out.real() / scale, out.imag() / scale}};
} else {
return Simd<T, 1>{std::log2(in.value)};
}
}
template <typename T>
Simd<T, 1> operator~(Simd<T, 1> in) {
  return ~in.value;

View File

@@ -119,12 +119,7 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
  // Make sure that the last dimension is contiguous
  auto set_output = [s = stream(), &out](const array& x) {
-   bool no_copy = x.strides()[x.ndim() - 1] == 1;
-   if (x.ndim() > 1) {
-     auto s = x.strides()[x.ndim() - 2];
-     no_copy &= (s == 0 || s == x.shape().back());
-   }
-   if (no_copy) {
+   if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
      if (x.is_donatable()) {
        out.copy_shared_buffer(x);
      } else {
@@ -146,18 +141,6 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
  auto in = set_output(inputs[0]);
  switch (in.dtype()) {
case bool_:
case uint8:
case uint16:
case uint32:
case uint64:
case int8:
case int16:
case int32:
case int64:
throw std::runtime_error(
"Softmax is defined only for floating point types");
break;
    case float32:
      softmax<float, float>(in, out, stream());
      break;
@@ -178,9 +161,9 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
    case float64:
      softmax<double, double>(in, out, stream());
      break;
-   case complex64:
-     throw std::invalid_argument(
-         "[Softmax] Not yet implemented for complex64");
+   default:
+     throw std::runtime_error(
+         "[softmax] Only defined for floating point types.");
      break;
  }
}

View File

@@ -1,5 +1,8 @@
// Copyright © 2024 Apple Inc.

+// Required for using M_LN2 in MSVC.
+#define _USE_MATH_DEFINES

#include <cassert>

#include "mlx/backend/cpu/unary.h"

View File

@@ -86,13 +86,14 @@ struct Sign {
  template <int N, typename T>
  Simd<T, N> operator()(Simd<T, N> x) {
    auto z = Simd<T, N>{0};
+   auto o = Simd<T, N>{1};
+   auto m = Simd<T, N>{-1};
    if constexpr (std::is_unsigned_v<T>) {
-     return x != z;
+     return simd::select(x == z, z, o);
    } else if constexpr (std::is_same_v<T, complex64_t>) {
      return simd::select(x == z, x, Simd<T, N>(x / simd::abs(x)));
    } else {
-     return simd::select(
-         x < z, Simd<T, N>{-1}, simd::select(x > z, Simd<T, N>{1}, z));
+     return simd::select(x < z, m, simd::select(x > z, o, z));
    }
  }
  SINGLE()
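Observable behavior the revised path implements, as a small illustrative check (not part of the diff): unsigned inputs map to 0 or 1 in the same dtype, signed inputs to -1, 0, or 1.

```python
import mlx.core as mx

print(mx.sign(mx.array([0, 3], dtype=mx.uint32)))  # expected: [0, 1]
print(mx.sign(mx.array([-2.0, 0.0, 5.0])))         # expected: [-1, 0, 1]
```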

View File

@@ -47,6 +47,7 @@ if(MLX_METAL_JIT)
  make_jit_source(binary)
  make_jit_source(binary_two)
  make_jit_source(fft kernels/fft/radix.h kernels/fft/readwrite.h)
+ make_jit_source(logsumexp)
  make_jit_source(ternary)
  make_jit_source(softmax)
  make_jit_source(scan)
@@ -60,6 +61,7 @@ if(MLX_METAL_JIT)
    kernels/steel/gemm/transforms.h)
  make_jit_source(steel/gemm/kernels/steel_gemm_fused)
  make_jit_source(steel/gemm/kernels/steel_gemm_masked kernels/steel/defines.h)
+ make_jit_source(steel/gemm/kernels/steel_gemm_gather)
  make_jit_source(steel/gemm/kernels/steel_gemm_splitk)
  make_jit_source(
    steel/conv/conv
@@ -95,6 +97,7 @@ target_sources(
  ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
+ ${CMAKE_CURRENT_SOURCE_DIR}/logsumexp.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp

View File

@@ -3,6 +3,7 @@
#include "mlx/backend/metal/metal.h" #include "mlx/backend/metal/metal.h"
#include "mlx/backend/metal/metal_impl.h" #include "mlx/backend/metal/metal_impl.h"
#include "mlx/backend/metal/resident.h" #include "mlx/backend/metal/resident.h"
#include "mlx/memory.h"
#include <mach/vm_page_size.h> #include <mach/vm_page_size.h>
#include <unistd.h> #include <unistd.h>
@@ -32,8 +33,11 @@ namespace metal {
namespace {

-BufferCache::BufferCache(MTL::Device* device)
-    : device_(device), head_(nullptr), tail_(nullptr), pool_size_(0) {}
+BufferCache::BufferCache(ResidencySet& residency_set)
+    : head_(nullptr),
+      tail_(nullptr),
+      pool_size_(0),
+      residency_set_(residency_set) {}
BufferCache::~BufferCache() {
  auto pool = metal::new_scoped_memory_pool();
@@ -44,6 +48,9 @@ int BufferCache::clear() {
  int n_release = 0;
  for (auto& [size, holder] : buffer_pool_) {
    if (holder->buf) {
+     if (!holder->buf->heap()) {
+       residency_set_.erase(holder->buf);
+     }
      holder->buf->release();
      n_release++;
    }
@@ -101,6 +108,9 @@ int BufferCache::release_cached_buffers(size_t min_bytes_to_free) {
  while (tail_ && (total_bytes_freed < min_bytes_to_free)) {
    if (tail_->buf) {
      total_bytes_freed += tail_->buf->length();
+     if (!tail_->buf->heap()) {
+       residency_set_.erase(tail_->buf);
+     }
      tail_->buf->release();
      tail_->buf = nullptr;
      n_release++;
@@ -155,7 +165,7 @@ void BufferCache::remove_from_list(BufferCache::BufferHolder* to_remove) {
MetalAllocator::MetalAllocator()
    : device_(device(mlx::core::Device::gpu).mtl_device()),
      residency_set_(device_),
-     buffer_cache_(device_) {
+     buffer_cache_(residency_set_) {
  auto pool = metal::new_scoped_memory_pool();
  auto memsize = std::get<size_t>(device_info().at("memory_size"));
  auto max_rec_size =
@@ -262,9 +272,13 @@ Buffer MetalAllocator::malloc(size_t size) {
  if (!buf) {
    buf = device_->newBuffer(size, resource_options);
  }
+ if (!buf) {
+   return Buffer{nullptr};
+ }
  lk.lock();
- if (buf) {
-   num_resources_++;
- }
+ num_resources_++;
+ if (!buf->heap()) {
+   residency_set_.insert(buf);
+ }
@@ -278,10 +292,6 @@ Buffer MetalAllocator::malloc(size_t size) {
        get_cache_memory() - max_pool_size_);
  }

- if (!buf->heap()) {
-   residency_set_.insert(buf);
- }
  return Buffer{static_cast<void*>(buf)};
}

@@ -297,14 +307,14 @@ void MetalAllocator::free(Buffer buffer) {
    return;
  }
  std::unique_lock lk(mutex_);
+ if (!buf->heap()) {
+   residency_set_.erase(buf);
+ }
  active_memory_ -= buf->length();
  if (get_cache_memory() < max_pool_size_) {
    buffer_cache_.recycle_to_cache(buf);
  } else {
    num_resources_--;
-   if (!buf->heap()) {
-     residency_set_.erase(buf);
-   }
    lk.unlock();
    auto pool = metal::new_scoped_memory_pool();
    buf->release();
@@ -323,40 +333,40 @@ MetalAllocator& allocator() {
  return *allocator_;
}

-} // namespace metal
-
size_t set_cache_limit(size_t limit) {
- return allocator().set_cache_limit(limit);
+ return metal::allocator().set_cache_limit(limit);
}

size_t set_memory_limit(size_t limit) {
- return allocator().set_memory_limit(limit);
+ return metal::allocator().set_memory_limit(limit);
}

size_t get_memory_limit() {
- return allocator().get_memory_limit();
+ return metal::allocator().get_memory_limit();
}

size_t set_wired_limit(size_t limit) {
- if (limit >
-     std::get<size_t>(device_info().at("max_recommended_working_set_size"))) {
+ if (limit > std::get<size_t>(metal::device_info().at(
+                 "max_recommended_working_set_size"))) {
    throw std::invalid_argument(
        "[metal::set_wired_limit] Setting a wired limit larger than "
        "the maximum working set size is not allowed.");
  }
- return allocator().set_wired_limit(limit);
+ return metal::allocator().set_wired_limit(limit);
}

size_t get_active_memory() {
- return allocator().get_active_memory();
+ return metal::allocator().get_active_memory();
}

size_t get_peak_memory() {
- return allocator().get_peak_memory();
+ return metal::allocator().get_peak_memory();
}

void reset_peak_memory() {
- allocator().reset_peak_memory();
+ metal::allocator().reset_peak_memory();
}

size_t get_cache_memory() {
- return allocator().get_cache_memory();
+ return metal::allocator().get_cache_memory();
}

void clear_cache() {
- return allocator().clear_cache();
+ return metal::allocator().clear_cache();
}

+} // namespace metal
+
} // namespace mlx::core

View File

@@ -18,7 +18,7 @@ namespace {
class BufferCache {
 public:
- BufferCache(MTL::Device* device);
+ BufferCache(ResidencySet& residency_set);
  ~BufferCache();

  MTL::Buffer* reuse_from_cache(size_t size);
@@ -42,13 +42,11 @@ class BufferCache {
  void add_at_head(BufferHolder* to_add);
  void remove_from_list(BufferHolder* to_remove);

- MTL::Device* device_;
- MTL::Heap* heap_{nullptr};
  std::multimap<size_t, BufferHolder*> buffer_pool_;
  BufferHolder* head_;
  BufferHolder* tail_;
  size_t pool_size_;
+ ResidencySet& residency_set_;
};

} // namespace

View File

@@ -712,6 +712,65 @@ void winograd_conv_2D_gpu(
} }
} }
void depthwise_conv_2D_gpu(
const Stream& s,
metal::Device& d,
const array& in,
const array& wt,
array out,
const MLXConvParams<2>& conv_params) {
std::ostringstream kname;
kname << "depthwise_conv_2d_" << type_to_name(out);
std::string base_name = kname.str();
const int N = conv_params.N;
const int ker_h = conv_params.wS[0];
const int ker_w = conv_params.wS[1];
const int str_h = conv_params.str[0];
const int str_w = conv_params.str[1];
const int tc = 8;
const int tw = 8;
const int th = 4;
const bool do_flip = conv_params.flip;
metal::MTLFCList func_consts = {
{&ker_h, MTL::DataType::DataTypeInt, 00},
{&ker_w, MTL::DataType::DataTypeInt, 01},
{&str_h, MTL::DataType::DataTypeInt, 10},
{&str_w, MTL::DataType::DataTypeInt, 11},
{&th, MTL::DataType::DataTypeInt, 100},
{&tw, MTL::DataType::DataTypeInt, 101},
{&do_flip, MTL::DataType::DataTypeBool, 200},
};
// clang-format off
kname << "_ker_h_" << ker_h
<< "_ker_w_" << ker_w
<< "_str_h_" << str_h
<< "_str_w_" << str_w
<< "_tgp_h_" << th
<< "_tgp_w_" << tw
<< "_do_flip_" << (do_flip ? 't' : 'n'); // clang-format on
std::string hash_name = kname.str();
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(base_name, "mlx", hash_name, func_consts);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(in, 0);
compute_encoder.set_input_array(wt, 1);
compute_encoder.set_output_array(out, 2);
compute_encoder.set_bytes(conv_params, 3);
MTL::Size group_dims = MTL::Size(tc, tw, th);
MTL::Size grid_dims = MTL::Size(
conv_params.C / tc, conv_params.oS[1] / tw, (conv_params.oS[0] / th) * N);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
void conv_2D_gpu( void conv_2D_gpu(
const Stream& s, const Stream& s,
metal::Device& d, metal::Device& d,
@@ -754,11 +813,20 @@ void conv_2D_gpu(
  bool is_kdil_one = conv_params.kdil[0] == 1 && conv_params.kdil[1] == 1;
  bool is_idil_one = conv_params.idil[0] == 1 && conv_params.idil[1] == 1;

- if (groups > 1) {
+ if (is_idil_one && groups > 1) {
    const int C_per_group = conv_params.C / groups;
    const int O_per_group = conv_params.O / groups;

-   if (is_idil_one && (C_per_group <= 4 || C_per_group % 16 == 0) &&
+   if (C_per_group == 1 && O_per_group == 1 && is_kdil_one &&
+       conv_params.wS[0] <= 7 && conv_params.wS[1] <= 7 &&
+       conv_params.str[0] <= 2 && conv_params.str[1] <= 2 &&
+       conv_params.oS[0] % 8 == 0 && conv_params.oS[1] % 8 == 0 &&
+       conv_params.wt_strides[1] == conv_params.wS[1] &&
+       conv_params.C % 16 == 0 && conv_params.C == conv_params.O) {
+     return depthwise_conv_2D_gpu(s, d, in, wt, out, conv_params);
+   }
+
+   if ((C_per_group <= 4 || C_per_group % 16 == 0) &&
        (O_per_group <= 16 || O_per_group % 16 == 0)) {
      return implicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params);
    } else {

View File

@@ -55,7 +55,10 @@ std::pair<MTL::Library*, NS::Error*> load_library_from_path(
} }
#ifdef SWIFTPM_BUNDLE #ifdef SWIFTPM_BUNDLE
-MTL::Library* try_load_bundle(MTL::Device* device, NS::URL* url) {
+MTL::Library* try_load_bundle(
+    MTL::Device* device,
+    NS::URL* url,
+    const std::string& lib_name) {
std::string bundle_path = std::string(url->fileSystemRepresentation()) + "/" + std::string bundle_path = std::string(url->fileSystemRepresentation()) + "/" +
SWIFTPM_BUNDLE + ".bundle"; SWIFTPM_BUNDLE + ".bundle";
auto bundle = NS::Bundle::alloc()->init( auto bundle = NS::Bundle::alloc()->init(
@@ -63,8 +66,8 @@ MTL::Library* try_load_bundle(MTL::Device* device, NS::URL* url) {
  if (bundle != nullptr) {
    std::string resource_path =
        std::string(bundle->resourceURL()->fileSystemRepresentation()) + "/" +
-       "default.metallib";
-   auto [lib, error] = load_library_from_path(device, resource_path.c_str());
+       lib_name + ".metallib";
+   auto [lib, error] =
+       load_library_from_path(device, resource_path.c_str());
    if (lib) {
      return lib;
    }
@@ -73,51 +76,124 @@ MTL::Library* try_load_bundle(MTL::Device* device, NS::URL* url) {
} }
#endif #endif
MTL::Library* load_library( // Firstly, search for the metallib in the same path as this binary
std::pair<MTL::Library*, NS::Error*> load_colocated_library(
MTL::Device* device, MTL::Device* device,
const std::string& lib_name = "mlx", const std::string& lib_name) {
const char* lib_path = default_mtllib_path) { std::string lib_path = get_colocated_mtllib_path(lib_name);
// Firstly, search for the metallib in the same path as this binary if (lib_path.size() != 0) {
std::string first_path = get_colocated_mtllib_path(lib_name); return load_library_from_path(device, lib_path.c_str());
if (first_path.size() != 0) {
auto [lib, error] = load_library_from_path(device, first_path.c_str());
if (lib) {
return lib;
}
} }
return {nullptr, nullptr};
}
std::pair<MTL::Library*, NS::Error*> load_swiftpm_library(
MTL::Device* device,
const std::string& lib_name) {
#ifdef SWIFTPM_BUNDLE #ifdef SWIFTPM_BUNDLE
// try to load from a swiftpm resource bundle -- scan the available bundles to
// find one that contains the named bundle
{
MTL::Library* library = MTL::Library* library =
try_load_bundle(device, NS::Bundle::mainBundle()->bundleURL()); try_load_bundle(device, NS::Bundle::mainBundle()->bundleURL(), lib_name);
if (library != nullptr) { if (library != nullptr) {
return library; return {library, nullptr};
} }
auto bundles = NS::Bundle::allBundles(); auto bundles = NS::Bundle::allBundles();
for (int i = 0, c = (int)bundles->count(); i < c; i++) { for (int i = 0, c = (int)bundles->count(); i < c; i++) {
auto bundle = reinterpret_cast<NS::Bundle*>(bundles->object(i)); auto bundle = reinterpret_cast<NS::Bundle*>(bundles->object(i));
library = try_load_bundle(device, bundle->resourceURL()); library = try_load_bundle(device, bundle->resourceURL());
if (library != nullptr) { if (library != nullptr) {
return library; return {library, nullptr};
}
} }
} }
#endif #endif
return {nullptr, nullptr};
}
-  // Couldn't find it so let's load it from default_mtllib_path
-  {
-    auto [lib, error] = load_library_from_path(device, lib_path);
-    if (!lib) {
-      std::ostringstream msg;
-      msg << error->localizedDescription()->utf8String() << "\n"
-          << "Failed to load device library from <" << lib_path << ">"
-          << " or <" << first_path << ">.";
-      throw std::runtime_error(msg.str());
-    }
-    return lib;
+MTL::Library* load_default_library(MTL::Device* device) {
+  NS::Error *error1, *error2, *error3;
+  MTL::Library* lib;
+  // First try the colocated mlx.metallib
+  std::tie(lib, error1) = load_colocated_library(device, "mlx");
+  if (lib) {
+    return lib;
+  }
+  // Then try default.metallib in a SwiftPM bundle if we have one
+  std::tie(lib, error2) = load_swiftpm_library(device, "default");
+  if (lib) {
+    return lib;
+  }
+  // Finally try default_mtllib_path
+  std::tie(lib, error3) = load_library_from_path(device, default_mtllib_path);
+  if (!lib) {
+    std::ostringstream msg;
+    msg << "Failed to load the default metallib. ";
+    if (error1 != nullptr) {
+      msg << error1->localizedDescription()->utf8String() << " ";
+    }
+    if (error2 != nullptr) {
+      msg << error2->localizedDescription()->utf8String() << " ";
+    }
+    if (error3 != nullptr) {
+      msg << error3->localizedDescription()->utf8String() << " ";
+    }
+    throw std::runtime_error(msg.str());
+  }
+  return lib;
}
MTL::Library* load_library(
MTL::Device* device,
const std::string& lib_name,
const std::string& lib_path) {
// We have been given a path that ends in metallib so try to load it
if (lib_path.size() > 9 &&
std::equal(lib_path.end() - 9, lib_path.end(), ".metallib")) {
auto [lib, error] = load_library_from_path(device, lib_path.c_str());
if (!lib) {
std::ostringstream msg;
msg << "Failed to load the metallib from <" << lib_path << "> with error "
<< error->localizedDescription()->utf8String();
throw std::runtime_error(msg.str());
} }
}
// We have been given a path so try to load from lib_path / lib_name.metallib
if (lib_path.size() > 0) {
std::string full_path = lib_path + "/" + lib_name + ".metallib";
auto [lib, error] = load_library_from_path(device, full_path.c_str());
if (!lib) {
std::ostringstream msg;
msg << "Failed to load the metallib from <" << full_path
<< "> with error " << error->localizedDescription()->utf8String();
throw std::runtime_error(msg.str());
}
}
// Try to load the colocated library
{
auto [lib, error] = load_colocated_library(device, lib_name);
if (lib) {
return lib;
}
}
// Try to load the library from swiftpm
{
auto [lib, error] = load_swiftpm_library(device, lib_name);
if (lib) {
return lib;
}
}
std::ostringstream msg;
msg << "Failed to load the metallib " << lib_name << ".metallib. "
<< "We attempted to load it from <" << get_colocated_mtllib_path(lib_name)
<< ">";
#ifdef SWIFTPM_BUNDLE
msg << " and from the Swift PM bundle.";
#endif
throw std::runtime_error(msg.str());
} }
} // namespace } // namespace
@@ -210,7 +286,7 @@ void CommandEncoder::barrier() {
Device::Device() { Device::Device() {
auto pool = new_scoped_memory_pool(); auto pool = new_scoped_memory_pool();
device_ = load_device(); device_ = load_device();
library_map_ = {{"mlx", load_library(device_)}}; library_map_ = {{"mlx", load_default_library(device_)}};
arch_ = std::string(device_->architecture()->name()->utf8String()); arch_ = std::string(device_->architecture()->name()->utf8String());
auto arch = arch_.back(); auto arch = arch_.back();
switch (arch) { switch (arch) {

@@ -189,15 +189,7 @@ class Device {
void register_library( void register_library(
const std::string& lib_name, const std::string& lib_name,
const std::string& lib_path); const std::string& lib_path = "");
// Note, this should remain in the header so that it is not dynamically
// linked
void register_library(const std::string& lib_name) {
if (auto it = library_map_.find(lib_name); it == library_map_.end()) {
register_library(lib_name, get_colocated_mtllib_path(lib_name));
}
}
MTL::Library* get_library( MTL::Library* get_library(
const std::string& name, const std::string& name,

@@ -24,10 +24,6 @@ void Event::wait() {
} }
} }
void Event::signal() {
static_cast<MTL::SharedEvent*>(event_.get())->setSignaledValue(value());
}
void Event::wait(Stream stream) { void Event::wait(Stream stream) {
if (stream.device == Device::cpu) { if (stream.device == Device::cpu) {
scheduler::enqueue(stream, [*this]() mutable { wait(); }); scheduler::enqueue(stream, [*this]() mutable { wait(); });
@@ -42,7 +38,9 @@ void Event::wait(Stream stream) {
void Event::signal(Stream stream) { void Event::signal(Stream stream) {
if (stream.device == Device::cpu) { if (stream.device == Device::cpu) {
scheduler::enqueue(stream, [*this]() mutable { signal(); }); scheduler::enqueue(stream, [*this]() mutable {
static_cast<MTL::SharedEvent*>(event_.get())->setSignaledValue(value());
});
} else { } else {
auto& d = metal::device(stream.device); auto& d = metal::device(stream.device);
d.end_encoding(stream.index); d.end_encoding(stream.index);

@@ -356,20 +356,14 @@ void multi_upload_bluestein_fft(
bool inverse, bool inverse,
bool real, bool real,
FFTPlan& plan, FFTPlan& plan,
std::vector<array> copies, std::vector<array>& copies,
const Stream& s) { const Stream& s) {
// TODO(alexbarron) Implement fused kernels for mutli upload bluestein's // TODO(alexbarron) Implement fused kernels for mutli upload bluestein's
// algorithm // algorithm
int n = inverse ? out.shape(axis) : in.shape(axis); int n = inverse ? out.shape(axis) : in.shape(axis);
auto [w_k, w_q] = compute_bluestein_constants(n, plan.bluestein_n); auto [w_k, w_q] = compute_bluestein_constants(n, plan.bluestein_n);
-  // Broadcast w_q and w_k to the batch size
-  Strides b_strides(in.ndim(), 0);
-  b_strides[axis] = 1;
-  array w_k_broadcast({}, complex64, nullptr, {});
-  array w_q_broadcast({}, complex64, nullptr, {});
-  w_k_broadcast.copy_shared_buffer(w_k, b_strides, {}, w_k.data_size());
-  w_q_broadcast.copy_shared_buffer(w_q, b_strides, {}, w_q.data_size());
+  copies.push_back(w_k);
+  copies.push_back(w_q);
auto temp_shape = inverse ? out.shape() : in.shape(); auto temp_shape = inverse ? out.shape() : in.shape();
array temp(temp_shape, complex64, nullptr, {}); array temp(temp_shape, complex64, nullptr, {});
@@ -378,13 +372,13 @@ void multi_upload_bluestein_fft(
if (real && !inverse) { if (real && !inverse) {
// Convert float32->complex64 // Convert float32->complex64
copy_gpu(in, temp, CopyType::General, s); copy_gpu(in, temp, CopyType::General, s);
copies.push_back(temp);
} else if (real && inverse) { } else if (real && inverse) {
int back_offset = n % 2 == 0 ? 2 : 1; int back_offset = n % 2 == 0 ? 2 : 1;
auto slice_shape = in.shape(); auto slice_shape = in.shape();
slice_shape[axis] -= back_offset; slice_shape[axis] -= back_offset;
array slice_temp(slice_shape, complex64, nullptr, {}); array slice_temp(slice_shape, complex64, nullptr, {});
array conj_temp(in.shape(), complex64, nullptr, {}); array conj_temp(in.shape(), complex64, nullptr, {});
copies.push_back(slice_temp);
copies.push_back(conj_temp); copies.push_back(conj_temp);
Shape rstarts(in.ndim(), 0); Shape rstarts(in.ndim(), 0);
@@ -394,19 +388,28 @@ void multi_upload_bluestein_fft(
unary_op_gpu({in}, conj_temp, "Conjugate", s); unary_op_gpu({in}, conj_temp, "Conjugate", s);
slice_gpu(in, slice_temp, rstarts, rstrides, s); slice_gpu(in, slice_temp, rstarts, rstrides, s);
concatenate_gpu({conj_temp, slice_temp}, temp, (int)axis, s); concatenate_gpu({conj_temp, slice_temp}, temp, (int)axis, s);
copies.push_back(temp);
} else if (inverse) { } else if (inverse) {
unary_op_gpu({in}, temp, "Conjugate", s); unary_op_gpu({in}, temp, "Conjugate", s);
copies.push_back(temp);
} else { } else {
temp.copy_shared_buffer(in); temp.copy_shared_buffer(in);
} }
Strides b_strides(in.ndim(), 0);
b_strides[axis] = 1;
array w_k_broadcast(temp.shape(), complex64, nullptr, {});
w_k_broadcast.copy_shared_buffer(w_k, b_strides, {}, w_k.data_size());
binary_op_gpu({temp, w_k_broadcast}, temp1, "Multiply", s); binary_op_gpu({temp, w_k_broadcast}, temp1, "Multiply", s);
std::vector<std::pair<int, int>> pads; std::vector<std::pair<int, int>> pads;
auto padded_shape = out.shape(); auto padded_shape = out.shape();
padded_shape[axis] = plan.bluestein_n; padded_shape[axis] = plan.bluestein_n;
array pad_temp(padded_shape, complex64, nullptr, {}); array pad_temp(padded_shape, complex64, nullptr, {});
-  pad_gpu(temp1, array(complex64_t{0.0f, 0.0f}), pad_temp, {(int)axis}, {0}, s);
+  auto zero = array(complex64_t{0.0f, 0.0f});
+  copies.push_back(zero);
+  pad_gpu(temp1, zero, pad_temp, {(int)axis}, {0}, s);
copies.push_back(pad_temp);
array pad_temp1(padded_shape, complex64, nullptr, {}); array pad_temp1(padded_shape, complex64, nullptr, {});
fft_op( fft_op(
@@ -418,7 +421,10 @@ void multi_upload_bluestein_fft(
FourStepParams(), FourStepParams(),
/*inplace=*/false, /*inplace=*/false,
s); s);
copies.push_back(pad_temp1);
array w_q_broadcast(pad_temp1.shape(), complex64, nullptr, {});
w_q_broadcast.copy_shared_buffer(w_q, b_strides, {}, w_q.data_size());
binary_op_gpu_inplace({pad_temp1, w_q_broadcast}, pad_temp, "Multiply", s); binary_op_gpu_inplace({pad_temp1, w_q_broadcast}, pad_temp, "Multiply", s);
fft_op( fft_op(
@@ -435,9 +441,11 @@ void multi_upload_bluestein_fft(
Shape starts(in.ndim(), 0); Shape starts(in.ndim(), 0);
Shape strides(in.ndim(), 1); Shape strides(in.ndim(), 1);
starts[axis] = plan.bluestein_n - offset - n; starts[axis] = plan.bluestein_n - offset - n;
-  slice_gpu(pad_temp1, temp, starts, strides, s);
-  binary_op_gpu_inplace({temp, w_k_broadcast}, temp1, "Multiply", s);
+  array temp2(temp_shape, complex64, nullptr, {});
+  slice_gpu(pad_temp1, temp2, starts, strides, s);
+  binary_op_gpu_inplace({temp2, w_k_broadcast}, temp1, "Multiply", s);
if (real && !inverse) { if (real && !inverse) {
Shape rstarts(in.ndim(), 0); Shape rstarts(in.ndim(), 0);
@@ -449,26 +457,21 @@ void multi_upload_bluestein_fft(
array temp_float(out.shape(), out.dtype(), nullptr, {}); array temp_float(out.shape(), out.dtype(), nullptr, {});
copies.push_back(temp_float); copies.push_back(temp_float);
copies.push_back(inv_n); copies.push_back(inv_n);
copies.push_back(temp1);
copy_gpu(temp1, temp_float, CopyType::General, s); copy_gpu(temp1, temp_float, CopyType::General, s);
binary_op_gpu({temp_float, inv_n}, out, "Multiply", s); binary_op_gpu({temp_float, inv_n}, out, "Multiply", s);
} else if (inverse) { } else if (inverse) {
auto inv_n = array({1.0f / n}, {1}, complex64); auto inv_n = array({1.0f / n}, {1}, complex64);
-    unary_op_gpu({temp1}, temp, "Conjugate", s);
-    binary_op_gpu({temp, inv_n}, out, "Multiply", s);
+    array temp3(temp_shape, complex64, nullptr, {});
+    unary_op_gpu({temp1}, temp3, "Conjugate", s);
+    binary_op_gpu({temp3, inv_n}, out, "Multiply", s);
copies.push_back(inv_n); copies.push_back(inv_n);
copies.push_back(temp1);
copies.push_back(temp3);
} else { } else {
out.copy_shared_buffer(temp1); out.copy_shared_buffer(temp1);
} }
copies.push_back(w_k);
copies.push_back(w_q);
copies.push_back(w_k_broadcast);
copies.push_back(w_q_broadcast);
copies.push_back(temp);
copies.push_back(temp1);
copies.push_back(pad_temp);
copies.push_back(pad_temp1);
} }
void four_step_fft( void four_step_fft(
@@ -478,8 +481,9 @@ void four_step_fft(
bool inverse, bool inverse,
bool real, bool real,
FFTPlan& plan, FFTPlan& plan,
std::vector<array> copies, std::vector<array>& copies,
const Stream& s) { const Stream& s,
bool in_place) {
auto& d = metal::device(s.device); auto& d = metal::device(s.device);
if (plan.bluestein_n == -1) { if (plan.bluestein_n == -1) {
@@ -492,7 +496,14 @@ void four_step_fft(
in, temp, axis, inverse, real, four_step_params, /*inplace=*/false, s); in, temp, axis, inverse, real, four_step_params, /*inplace=*/false, s);
four_step_params.first_step = false; four_step_params.first_step = false;
fft_op( fft_op(
temp, out, axis, inverse, real, four_step_params, /*inplace=*/false, s); temp,
out,
axis,
inverse,
real,
four_step_params,
/*inplace=*/in_place,
s);
copies.push_back(temp); copies.push_back(temp);
} else { } else {
multi_upload_bluestein_fft(in, out, axis, inverse, real, plan, copies, s); multi_upload_bluestein_fft(in, out, axis, inverse, real, plan, copies, s);
@@ -574,7 +585,7 @@ void fft_op(
auto plan = plan_fft(n); auto plan = plan_fft(n);
if (plan.four_step) { if (plan.four_step) {
four_step_fft(in, out, axis, inverse, real, plan, copies, s); four_step_fft(in, out, axis, inverse, real, plan, copies, s, inplace);
d.add_temporaries(std::move(copies), s.index); d.add_temporaries(std::move(copies), s.index);
return; return;
} }
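A note on the multi_upload_bluestein_fft changes above: Bluestein's algorithm evaluates a length-n DFT as a convolution that can be zero padded up to plan.bluestein_n, which is why the routine multiplies by the precomputed w_k and w_q arrays around an FFT of the padded buffer. As a sketch, the standard chirp-z identity (textbook form; mapping w_k and w_q onto the two chirp factors is our reading of the code, not something stated in the diff) is

\[
X_k \;=\; \sum_{j=0}^{n-1} x_j\, e^{-2\pi i jk/n}
     \;=\; e^{-\pi i k^2/n} \sum_{j=0}^{n-1} \left(x_j\, e^{-\pi i j^2/n}\right) e^{\,\pi i (k-j)^2/n},
\]

so the pipeline is: multiply the input by a chirp, convolve with the conjugate chirp (pad, forward FFT, pointwise multiply, inverse FFT), then multiply by the chirp again, which matches the pad_gpu, fft_op and binary_op_gpu_inplace sequence in the function.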

@@ -1,9 +0,0 @@
// Copyright © 2024 Apple Inc.
constexpr std::string_view arange_kernels = R"(
template [[host_name("{0}")]] [[kernel]] void arange<{1}>(
constant const {1}& start,
constant const {1}& step,
device {1}* out,
uint index [[thread_position_in_grid]]);
)";

@@ -20,6 +20,7 @@ const char* copy();
const char* fft(); const char* fft();
const char* gather_axis(); const char* gather_axis();
const char* hadamard(); const char* hadamard();
const char* logsumexp();
const char* quantized(); const char* quantized();
const char* ternary(); const char* ternary();
const char* scan(); const char* scan();
@@ -32,6 +33,7 @@ const char* gemm();
const char* steel_gemm_fused(); const char* steel_gemm_fused();
const char* steel_gemm_masked(); const char* steel_gemm_masked();
const char* steel_gemm_splitk(); const char* steel_gemm_splitk();
const char* steel_gemm_gather();
const char* conv(); const char* conv();
const char* steel_conv(); const char* steel_conv();
const char* steel_conv_general(); const char* steel_conv_general();

@@ -1,23 +0,0 @@
// Copyright © 2024 Apple Inc.
constexpr std::string_view softmax_kernels = R"(
template [[host_name("block_{0}")]] [[kernel]] void
softmax_single_row<{1}, {2}>(
const device {1}* in,
device {1}* out,
constant int& axis_size,
uint gid [[thread_position_in_grid]],
uint _lid [[thread_position_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
template [[host_name("looped_{0}")]] [[kernel]] void
softmax_looped<{1}, {2}>(
const device {1}* in,
device {1}* out,
constant int& axis_size,
uint gid [[threadgroup_position_in_grid]],
uint lid [[thread_position_in_threadgroup]],
uint lsize [[threads_per_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
)";

@@ -1,8 +1,6 @@
// Copyright © 2024 Apple Inc. // Copyright © 2024 Apple Inc.
#include "mlx/backend/common/compiled.h" #include "mlx/backend/common/compiled.h"
#include "mlx/backend/metal/jit/arange.h"
#include "mlx/backend/metal/jit/includes.h" #include "mlx/backend/metal/jit/includes.h"
#include "mlx/backend/metal/jit/softmax.h"
#include "mlx/backend/metal/kernels.h" #include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/utils.h" #include "mlx/backend/metal/utils.h"
@@ -21,13 +19,11 @@ MTL::ComputePipelineState* get_arange_kernel(
const std::string& kernel_name, const std::string& kernel_name,
const array& out) { const array& out) {
auto lib = d.get_library(kernel_name, [&]() { auto lib = d.get_library(kernel_name, [&]() {
-    std::ostringstream kernel_source;
-    kernel_source << metal::utils() << metal::arange()
-                  << fmt::format(
-                         arange_kernels,
-                         kernel_name,
-                         get_type_string(out.dtype()));
-    return kernel_source.str();
+    std::string kernel_source = metal::utils();
+    kernel_source += metal::arange();
+    kernel_source += get_template_definition(
+        kernel_name, "arange", get_type_string(out.dtype()));
+    return kernel_source;
}); });
return d.get_kernel(kernel_name, lib); return d.get_kernel(kernel_name, lib);
} }
@@ -259,14 +255,34 @@ MTL::ComputePipelineState* get_softmax_kernel(
const array& out) { const array& out) {
  std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
  auto lib = d.get_library(lib_name, [&] {
-    std::ostringstream kernel_source;
-    kernel_source << metal::utils() << metal::softmax()
-                  << fmt::format(
-                         softmax_kernels,
-                         lib_name,
-                         get_type_string(out.dtype()),
-                         get_type_string(precise ? float32 : out.dtype()));
-    return kernel_source.str();
+    std::string kernel_source = metal::utils();
+    auto in_type = get_type_string(out.dtype());
+    auto acc_type = get_type_string(precise ? float32 : out.dtype());
+    kernel_source += metal::softmax();
+    kernel_source += get_template_definition(
+        "block_" + lib_name, "softmax_single_row", in_type, acc_type);
+    kernel_source += get_template_definition(
+        "looped_" + lib_name, "softmax_looped", in_type, acc_type);
return kernel_source;
});
return d.get_kernel(kernel_name, lib);
}
MTL::ComputePipelineState* get_logsumexp_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& out) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name, [&] {
auto t_str = get_type_string(out.dtype());
std::string kernel_source;
kernel_source = metal::utils();
kernel_source += metal::logsumexp();
kernel_source +=
get_template_definition("block_" + lib_name, "logsumexp", t_str);
kernel_source += get_template_definition(
"looped_" + lib_name, "logsumexp_looped", t_str);
return kernel_source;
}); });
return d.get_kernel(kernel_name, lib); return d.get_kernel(kernel_name, lib);
} }
@@ -568,6 +584,44 @@ MTL::ComputePipelineState* get_steel_gemm_masked_kernel(
return d.get_kernel(kernel_name, lib); return d.get_kernel(kernel_name, lib);
} }
MTL::ComputePipelineState* get_steel_gemm_gather_kernel(
metal::Device& d,
const std::string& kernel_name,
const std::string& hash_name,
const metal::MTLFCList& func_consts,
const array& out,
bool transpose_a,
bool transpose_b,
int bm,
int bn,
int bk,
int wm,
int wn,
bool rhs) {
const auto& lib_name = kernel_name;
auto lib = d.get_library(lib_name, [&]() {
std::string kernel_source;
concatenate(
kernel_source,
metal::utils(),
metal::gemm(),
metal::steel_gemm_gather(),
get_template_definition(
lib_name,
rhs ? "gather_mm_rhs" : "gather_mm",
get_type_string(out.dtype()),
bm,
bn,
bk,
wm,
wn,
transpose_a,
transpose_b));
return kernel_source;
});
return d.get_kernel(kernel_name, lib, hash_name, func_consts);
}
MTL::ComputePipelineState* get_gemv_masked_kernel( MTL::ComputePipelineState* get_gemv_masked_kernel(
metal::Device& d, metal::Device& d,
const std::string& kernel_name, const std::string& kernel_name,
@@ -698,4 +752,43 @@ MTL::ComputePipelineState* get_quantized_kernel(
return d.get_kernel(kernel_name, lib); return d.get_kernel(kernel_name, lib);
} }
MTL::ComputePipelineState* get_gather_qmm_kernel(
metal::Device& d,
const std::string& kernel_name,
const std::string& hash_name,
const metal::MTLFCList& func_consts,
const array& x,
int group_size,
int bits,
int bm,
int bn,
int bk,
int wm,
int wn,
bool transpose) {
const auto& lib_name = kernel_name;
auto lib = d.get_library(lib_name, [&]() {
std::string kernel_source;
concatenate(
kernel_source,
metal::utils(),
metal::gemm(),
metal::quantized(),
get_template_definition(
lib_name,
"gather_qmm_rhs",
get_type_string(x.dtype()),
group_size,
bits,
bm,
bn,
bk,
wm,
wn,
transpose));
return kernel_source;
});
return d.get_kernel(kernel_name, lib, hash_name, func_consts);
}
} // namespace mlx::core } // namespace mlx::core

@@ -59,6 +59,11 @@ MTL::ComputePipelineState* get_softmax_kernel(
bool precise, bool precise,
const array& out); const array& out);
MTL::ComputePipelineState* get_logsumexp_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& out);
MTL::ComputePipelineState* get_scan_kernel( MTL::ComputePipelineState* get_scan_kernel(
metal::Device& d, metal::Device& d,
const std::string& kernel_name, const std::string& kernel_name,
@@ -155,6 +160,21 @@ MTL::ComputePipelineState* get_steel_gemm_masked_kernel(
bool mn_aligned, bool mn_aligned,
bool k_aligned); bool k_aligned);
MTL::ComputePipelineState* get_steel_gemm_gather_kernel(
metal::Device& d,
const std::string& kernel_name,
const std::string& hash_name,
const metal::MTLFCList& func_consts,
const array& out,
bool transpose_a,
bool transpose_b,
int bm,
int bn,
int bk,
int wm,
int wn,
bool rhs);
MTL::ComputePipelineState* get_steel_conv_kernel( MTL::ComputePipelineState* get_steel_conv_kernel(
metal::Device& d, metal::Device& d,
const std::string& kernel_name, const std::string& kernel_name,
@@ -204,6 +224,21 @@ MTL::ComputePipelineState* get_quantized_kernel(
const std::string& kernel_name, const std::string& kernel_name,
const std::string& template_def); const std::string& template_def);
MTL::ComputePipelineState* get_gather_qmm_kernel(
metal::Device& d,
const std::string& kernel_name,
const std::string& hash_name,
const metal::MTLFCList& func_consts,
const array& x,
int group_size,
int bits,
int bm,
int bn,
int bk,
int wm,
int wn,
bool transpose);
// Create a GPU kernel template definition for JIT compilation // Create a GPU kernel template definition for JIT compilation
template <typename... Args> template <typename... Args>
std::string std::string

@@ -13,6 +13,10 @@ function(build_kernel_base TARGET SRCFILE DEPS)
if(MLX_METAL_DEBUG) if(MLX_METAL_DEBUG)
set(METAL_FLAGS ${METAL_FLAGS} -gline-tables-only -frecord-sources) set(METAL_FLAGS ${METAL_FLAGS} -gline-tables-only -frecord-sources)
endif() endif()
if(NOT CMAKE_OSX_DEPLOYMENT_TARGET STREQUAL "")
set(METAL_FLAGS ${METAL_FLAGS}
"-mmacosx-version-min=${CMAKE_OSX_DEPLOYMENT_TARGET}")
endif()
if(MLX_METAL_VERSION GREATER_EQUAL 310) if(MLX_METAL_VERSION GREATER_EQUAL 310)
set(VERSION_INCLUDES set(VERSION_INCLUDES
${PROJECT_SOURCE_DIR}/mlx/backend/metal/kernels/metal_3_1) ${PROJECT_SOURCE_DIR}/mlx/backend/metal/kernels/metal_3_1)
@@ -65,6 +69,7 @@ set(STEEL_HEADERS
steel/gemm/loader.h steel/gemm/loader.h
steel/gemm/transforms.h steel/gemm/transforms.h
steel/gemm/kernels/steel_gemm_fused.h steel/gemm/kernels/steel_gemm_fused.h
steel/gemm/kernels/steel_gemm_gather.h
steel/gemm/kernels/steel_gemm_masked.h steel/gemm/kernels/steel_gemm_masked.h
steel/gemm/kernels/steel_gemm_splitk.h steel/gemm/kernels/steel_gemm_splitk.h
steel/utils/type_traits.h steel/utils/type_traits.h
@@ -105,12 +110,14 @@ if(NOT MLX_METAL_JIT)
build_kernel(quantized quantized.h ${STEEL_HEADERS}) build_kernel(quantized quantized.h ${STEEL_HEADERS})
build_kernel(scan scan.h) build_kernel(scan scan.h)
build_kernel(softmax softmax.h) build_kernel(softmax softmax.h)
build_kernel(logsumexp logsumexp.h)
build_kernel(sort sort.h) build_kernel(sort sort.h)
build_kernel(ternary ternary.h ternary_ops.h) build_kernel(ternary ternary.h ternary_ops.h)
build_kernel(unary unary.h unary_ops.h) build_kernel(unary unary.h unary_ops.h)
build_kernel(steel/conv/kernels/steel_conv ${STEEL_HEADERS}) build_kernel(steel/conv/kernels/steel_conv ${STEEL_HEADERS})
build_kernel(steel/conv/kernels/steel_conv_general ${STEEL_HEADERS}) build_kernel(steel/conv/kernels/steel_conv_general ${STEEL_HEADERS})
build_kernel(steel/gemm/kernels/steel_gemm_fused ${STEEL_HEADERS}) build_kernel(steel/gemm/kernels/steel_gemm_fused ${STEEL_HEADERS})
build_kernel(steel/gemm/kernels/steel_gemm_gather ${STEEL_HEADERS})
build_kernel(steel/gemm/kernels/steel_gemm_masked ${STEEL_HEADERS}) build_kernel(steel/gemm/kernels/steel_gemm_masked ${STEEL_HEADERS})
build_kernel(steel/gemm/kernels/steel_gemm_splitk ${STEEL_HEADERS}) build_kernel(steel/gemm/kernels/steel_gemm_splitk ${STEEL_HEADERS})
build_kernel(gemv_masked steel/utils.h) build_kernel(gemv_masked steel/utils.h)

@@ -5,11 +5,7 @@
#include "mlx/backend/metal/kernels/arange.h" #include "mlx/backend/metal/kernels/arange.h"
#define instantiate_arange(tname, type) \ #define instantiate_arange(tname, type) \
-  template [[host_name("arange" #tname)]] [[kernel]] void arange<type>( \
-      constant const type& start, \
-      constant const type& step, \
-      device type* out, \
-      uint index [[thread_position_in_grid]]);
+  instantiate_kernel("arange" #tname, arange, type)
instantiate_arange(uint8, uint8_t) instantiate_arange(uint8, uint8_t)
instantiate_arange(uint16, uint16_t) instantiate_arange(uint16, uint16_t)

@@ -275,6 +275,128 @@ instantiate_naive_conv_2d_blocks(float32, float);
instantiate_naive_conv_2d_blocks(float16, half); instantiate_naive_conv_2d_blocks(float16, half);
instantiate_naive_conv_2d_blocks(bfloat16, bfloat16_t); instantiate_naive_conv_2d_blocks(bfloat16, bfloat16_t);
///////////////////////////////////////////////////////////////////////////////
/// Depthwise convolution kernels
///////////////////////////////////////////////////////////////////////////////
constant int ker_h [[function_constant(00)]];
constant int ker_w [[function_constant(01)]];
constant int str_h [[function_constant(10)]];
constant int str_w [[function_constant(11)]];
constant int tgp_h [[function_constant(100)]];
constant int tgp_w [[function_constant(101)]];
constant bool do_flip [[function_constant(200)]];
constant int span_h = tgp_h * str_h + ker_h - 1;
constant int span_w = tgp_w * str_w + ker_w - 1;
constant int span_hw = span_h * span_w;
template <typename T>
[[kernel]] void depthwise_conv_2d(
const device T* in [[buffer(0)]],
const device T* wt [[buffer(1)]],
device T* out [[buffer(2)]],
const constant MLXConvParams<2>& params [[buffer(3)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint3 gid [[thread_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
constexpr int tc = 8;
constexpr int tw = 8;
constexpr int th = 4;
constexpr int c_per_thr = 8;
constexpr int TGH = th * 2 + 6;
constexpr int TGW = tw * 2 + 6;
constexpr int TGC = tc;
threadgroup T ins[TGH * TGW * TGC];
const int n_tgblocks_h = params.oS[0] / th;
const int n = tid.z / n_tgblocks_h;
const int tghid = tid.z % n_tgblocks_h;
const int oh = tghid * th + lid.z;
const int ow = gid.y;
const int c = gid.x;
in += n * params.in_strides[0];
// Load in
{
constexpr int n_threads = th * tw * tc;
const int tg_oh = (tghid * th) * str_h - params.pad[0];
const int tg_ow = (tid.y * tw) * str_w - params.pad[1];
const int tg_c = tid.x * tc;
const int thread_idx = simd_gid * 32 + simd_lid;
constexpr int thr_per_hw = tc / c_per_thr;
constexpr int hw_per_group = n_threads / thr_per_hw;
const int thr_c = thread_idx % thr_per_hw;
const int thr_hw = thread_idx / thr_per_hw;
for (int hw = thr_hw; hw < span_hw; hw += hw_per_group) {
const int h = hw / span_w;
const int w = hw % span_w;
const int ih = tg_oh + h;
const int iw = tg_ow + w;
const int in_s_offset = h * span_w * TGC + w * TGC;
if (ih >= 0 && ih < params.iS[0] && iw >= 0 && iw < params.iS[1]) {
const auto in_load =
in + ih * params.in_strides[1] + iw * params.in_strides[2] + tg_c;
MLX_MTL_PRAGMA_UNROLL
for (int cc = 0; cc < c_per_thr; ++cc) {
ins[in_s_offset + c_per_thr * thr_c + cc] =
in_load[c_per_thr * thr_c + cc];
}
} else {
MLX_MTL_PRAGMA_UNROLL
for (int cc = 0; cc < c_per_thr; ++cc) {
ins[in_s_offset + c_per_thr * thr_c + cc] = T(0);
}
}
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
wt += c * params.wt_strides[0];
const auto ins_ptr =
&ins[lid.z * str_h * span_w * TGC + lid.y * str_w * TGC + lid.x];
float o = 0.;
for (int h = 0; h < ker_h; ++h) {
for (int w = 0; w < ker_w; ++w) {
int wt_h = h;
int wt_w = w;
if (do_flip) {
wt_h = ker_h - h - 1;
wt_w = ker_w - w - 1;
}
auto inv = ins_ptr[h * span_w * TGC + w * TGC];
auto wtv = wt[wt_h * ker_w + wt_w];
o += inv * wtv;
}
}
threadgroup_barrier(mem_flags::mem_none);
out += n * params.out_strides[0] + oh * params.out_strides[1] +
ow * params.out_strides[2];
out[c] = static_cast<T>(o);
}
#define instantiate_depthconv2d(iname, itype) \
instantiate_kernel("depthwise_conv_2d_" #iname, depthwise_conv_2d, itype)
instantiate_depthconv2d(float32, float);
instantiate_depthconv2d(float16, half);
instantiate_depthconv2d(bfloat16, bfloat16_t);
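A short sizing note for the depthwise_conv_2d kernel above, based only on the expressions in the code: a threadgroup producing a tgp_h by tgp_w output tile with strides str_h, str_w and a ker_h by ker_w filter must stage span_h = tgp_h * str_h + ker_h - 1 rows and span_w = tgp_w * str_w + ker_w - 1 columns of input. With the hard coded th = 4, tw = 8 tile and the largest case the dispatch in conv.cpp admits (7x7 filter, stride 2), that is 4 * 2 + 7 - 1 = 14 rows by 8 * 2 + 7 - 1 = 22 columns, exactly the TGH by TGW = 14 by 22 threadgroup buffer (times TGC = 8 channels) the kernel allocates.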
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
/// Winograd kernels /// Winograd kernels
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////

@@ -341,7 +341,7 @@ struct GEMVTKernel {
MLX_MTL_PRAGMA_UNROLL MLX_MTL_PRAGMA_UNROLL
for (int tm = 0; tm < TM; tm++) { for (int tm = 0; tm < TM; tm++) {
auto vc = float(v_coeff[tm]); auto vc = static_cast<AccT>(v_coeff[tm]);
for (int tn = 0; tn < TN; tn++) { for (int tn = 0; tn < TN; tn++) {
inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn]; inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn];
} }

@@ -493,71 +493,11 @@ template <typename T, int N_READS = RMS_N_READS>
} }
// clang-format off // clang-format off
#define instantiate_layer_norm_single_row(name, itype) \
template [[host_name("layer_norm" #name)]] [[kernel]] void \
layer_norm_single_row<itype>( \
const device itype* x, \
const device itype* w, \
const device itype* b, \
device itype* out, \
constant float& eps, \
constant uint& axis_size, \
constant uint& w_stride, \
constant uint& b_stride, \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]); \
template [[host_name("vjp_layer_norm" #name)]] [[kernel]] void \
vjp_layer_norm_single_row<itype>( \
const device itype* x, \
const device itype* w, \
const device itype* g, \
device itype* gx, \
device itype* gw, \
constant float& eps, \
constant uint& axis_size, \
constant uint& w_stride, \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
#define instantiate_layer_norm_looped(name, itype) \
template [[host_name("layer_norm_looped" #name)]] [[kernel]] void \
layer_norm_looped<itype>( \
const device itype* x, \
const device itype* w, \
const device itype* b, \
device itype* out, \
constant float& eps, \
constant uint& axis_size, \
constant uint& w_stride, \
constant uint& b_stride, \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint lsize [[threads_per_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]); \
template [[host_name("vjp_layer_norm_looped" #name)]] [[kernel]] void \
vjp_layer_norm_looped<itype>( \
const device itype* x, \
const device itype* w, \
const device itype* g, \
device itype* gx, \
device itype* gb, \
constant float& eps, \
constant uint& axis_size, \
constant uint& w_stride, \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint lsize [[threads_per_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
#define instantiate_layer_norm(name, itype) \ #define instantiate_layer_norm(name, itype) \
-  instantiate_layer_norm_single_row(name, itype) \
-  instantiate_layer_norm_looped(name, itype)
+  instantiate_kernel("layer_norm" #name, layer_norm_single_row, itype) \
+  instantiate_kernel("vjp_layer_norm" #name, vjp_layer_norm_single_row, itype) \
+  instantiate_kernel("layer_norm_looped" #name, layer_norm_looped, itype) \
+  instantiate_kernel("vjp_layer_norm_looped" #name, vjp_layer_norm_looped, itype)
instantiate_layer_norm(float32, float) instantiate_layer_norm(float32, float)
instantiate_layer_norm(float16, half) instantiate_layer_norm(float16, half)

@@ -0,0 +1,143 @@
// Copyright © 2025 Apple Inc.
template <typename T, typename AccT = float, int N_READS = 4>
[[kernel]] void logsumexp(
const device T* in,
device T* out,
constant int& axis_size,
uint gid [[threadgroup_position_in_grid]],
uint _lid [[thread_position_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
int lid = _lid;
constexpr int SIMD_SIZE = 32;
threadgroup AccT local_max[SIMD_SIZE];
threadgroup AccT local_normalizer[SIMD_SIZE];
AccT ld[N_READS];
in += gid * size_t(axis_size) + lid * N_READS;
if (lid * N_READS + N_READS <= axis_size) {
for (int i = 0; i < N_READS; i++) {
ld[i] = AccT(in[i]);
}
} else {
for (int i = 0; i < N_READS; i++) {
ld[i] =
((lid * N_READS + i) < axis_size) ? AccT(in[i]) : Limits<AccT>::min;
}
}
if (simd_group_id == 0) {
local_max[simd_lane_id] = Limits<AccT>::min;
local_normalizer[simd_lane_id] = 0;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Get the max
AccT maxval = Limits<AccT>::finite_min;
for (int i = 0; i < N_READS; i++) {
maxval = (maxval < ld[i]) ? ld[i] : maxval;
}
maxval = simd_max(maxval);
if (simd_lane_id == 0) {
local_max[simd_group_id] = maxval;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (simd_group_id == 0) {
maxval = simd_max(local_max[simd_lane_id]);
if (simd_lane_id == 0) {
local_max[0] = maxval;
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
maxval = local_max[0];
// Compute exp(x_i - maxval) and store the partial sums in local_normalizer
AccT normalizer = 0;
for (int i = 0; i < N_READS; i++) {
normalizer += fast::exp(ld[i] - maxval);
}
normalizer = simd_sum(normalizer);
if (simd_lane_id == 0) {
local_normalizer[simd_group_id] = normalizer;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (simd_group_id == 0) {
normalizer = simd_sum(local_normalizer[simd_lane_id]);
if (simd_lane_id == 0) {
out[gid] = isinf(maxval) ? T(maxval) : T(log(normalizer) + maxval);
}
}
}
template <typename T, typename AccT = float, int N_READS = 4>
[[kernel]] void logsumexp_looped(
const device T* in,
device T* out,
constant int& axis_size,
uint gid [[threadgroup_position_in_grid]],
uint lid [[thread_position_in_threadgroup]],
uint lsize [[threads_per_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
in += gid * size_t(axis_size);
constexpr int SIMD_SIZE = 32;
threadgroup AccT local_max[SIMD_SIZE];
threadgroup AccT local_normalizer[SIMD_SIZE];
// Get the max and the normalizer in one go
AccT prevmax;
AccT maxval = Limits<AccT>::finite_min;
AccT normalizer = 0;
for (int r = 0; r < static_cast<int>(ceildiv(axis_size, N_READS * lsize));
r++) {
int offset = r * lsize * N_READS + lid * N_READS;
AccT vals[N_READS];
if (offset + N_READS <= axis_size) {
for (int i = 0; i < N_READS; i++) {
vals[i] = AccT(in[offset + i]);
}
} else {
for (int i = 0; i < N_READS; i++) {
vals[i] = (offset + i < axis_size) ? AccT(in[offset + i])
: Limits<AccT>::finite_min;
}
}
prevmax = maxval;
for (int i = 0; i < N_READS; i++) {
maxval = (maxval < vals[i]) ? vals[i] : maxval;
}
normalizer *= fast::exp(prevmax - maxval);
for (int i = 0; i < N_READS; i++) {
normalizer += fast::exp(vals[i] - maxval);
}
}
prevmax = maxval;
maxval = simd_max(maxval);
normalizer *= fast::exp(prevmax - maxval);
normalizer = simd_sum(normalizer);
prevmax = maxval;
if (simd_lane_id == 0) {
local_max[simd_group_id] = maxval;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
maxval = simd_max(local_max[simd_lane_id]);
normalizer *= fast::exp(prevmax - maxval);
if (simd_lane_id == 0) {
local_normalizer[simd_group_id] = normalizer;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
normalizer = simd_sum(local_normalizer[simd_lane_id]);
if (simd_group_id == 0) {
normalizer = simd_sum(local_normalizer[simd_lane_id]);
if (simd_lane_id == 0) {
out[gid] = isinf(maxval) ? T(maxval) : T(log(normalizer) + maxval);
}
}
}
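Restating what the looped kernel above computes (standard streaming logsumexp; the notation is ours, not from the MLX sources): each thread keeps a running maximum m and normalizer l over its chunk and rescales the normalizer whenever the maximum grows,

\[
m' = \max\!\left(m,\ \max_i v_i\right), \qquad
\ell' = \ell\, e^{\,m - m'} + \sum_i e^{\,v_i - m'}, \qquad
\operatorname{logsumexp} = m' + \log \ell',
\]

which is exactly the normalizer *= fast::exp(prevmax - maxval) step followed by the exp accumulation, applied first per thread, then across the simdgroup with simd_max and simd_sum, and finally across simdgroups through the threadgroup arrays.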

@@ -0,0 +1,18 @@
// Copyright © 2023-2024 Apple Inc.
#include <metal_common>
#include <metal_simdgroup>
using namespace metal;
// clang-format off
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/logsumexp.h"
#define instantiate_logsumexp(name, itype) \
instantiate_kernel("block_logsumexp_" #name, logsumexp, itype) \
instantiate_kernel("looped_logsumexp_" #name, logsumexp_looped, itype) \
instantiate_logsumexp(float32, float)
instantiate_logsumexp(float16, half)
instantiate_logsumexp(bfloat16, bfloat16_t) // clang-format on

@@ -3,6 +3,10 @@
#include <metal_simdgroup> #include <metal_simdgroup>
#include <metal_stdlib> #include <metal_stdlib>
constant bool align_M [[function_constant(200)]];
constant bool align_N [[function_constant(201)]];
constant bool align_K [[function_constant(202)]];
using namespace metal; using namespace metal;
#define MLX_MTL_CONST static constant constexpr const #define MLX_MTL_CONST static constant constexpr const
@@ -586,13 +590,13 @@ METAL_FUNC void qmv_quad_impl(
// Adjust positions // Adjust positions
const int in_vec_size_w = in_vec_size / pack_factor; const int in_vec_size_w = in_vec_size / pack_factor;
const int in_vec_size_g = in_vec_size / group_size; const int in_vec_size_g = in_vec_size / group_size;
const int out_row = tid.x * quads_per_simd * results_per_quadgroup + quad_gid; const int out_row = tid.y * quads_per_simd * results_per_quadgroup + quad_gid;
w += out_row * in_vec_size_w + quad_lid * packs_per_thread; w += out_row * in_vec_size_w + quad_lid * packs_per_thread;
scales += out_row * in_vec_size_g + quad_lid / scale_step_per_thread; scales += out_row * in_vec_size_g + quad_lid / scale_step_per_thread;
biases += out_row * in_vec_size_g + quad_lid / scale_step_per_thread; biases += out_row * in_vec_size_g + quad_lid / scale_step_per_thread;
x += tid.y * in_vec_size + quad_lid * values_per_thread; x += tid.x * in_vec_size + quad_lid * values_per_thread;
y += tid.y * out_vec_size + out_row; y += tid.x * out_vec_size + out_row;
U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread); U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
@@ -1686,26 +1690,26 @@ template <
} }
template <typename T, int group_size, int bits> template <typename T, int group_size, int bits>
[[kernel]] void bs_qmv_fast( [[kernel]] void gather_qmv_fast(
const device uint32_t* w [[buffer(0)]], const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]], const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]], const device T* biases [[buffer(2)]],
const device T* x [[buffer(3)]], const device T* x [[buffer(3)]],
device T* y [[buffer(4)]], const device uint32_t* lhs_indices [[buffer(4)]],
const constant int& in_vec_size [[buffer(5)]], const device uint32_t* rhs_indices [[buffer(5)]],
const constant int& out_vec_size [[buffer(6)]], device T* y [[buffer(6)]],
const constant int& x_batch_ndims [[buffer(7)]], const constant int& in_vec_size [[buffer(7)]],
const constant int* x_shape [[buffer(8)]], const constant int& out_vec_size [[buffer(8)]],
const constant int64_t* x_strides [[buffer(9)]], const constant int& x_batch_ndims [[buffer(9)]],
const constant int& w_batch_ndims [[buffer(10)]], const constant int* x_shape [[buffer(10)]],
const constant int* w_shape [[buffer(11)]], const constant int64_t* x_strides [[buffer(11)]],
const constant int64_t* w_strides [[buffer(12)]], const constant int& w_batch_ndims [[buffer(12)]],
const constant int64_t* s_strides [[buffer(13)]], const constant int* w_shape [[buffer(13)]],
const constant int64_t* b_strides [[buffer(14)]], const constant int64_t* w_strides [[buffer(14)]],
const constant int& batch_ndims [[buffer(15)]], const constant int64_t* s_strides [[buffer(15)]],
const constant int* batch_shape [[buffer(16)]], const constant int64_t* b_strides [[buffer(16)]],
const device uint32_t* lhs_indices [[buffer(17)]], const constant int& batch_ndims [[buffer(17)]],
const device uint32_t* rhs_indices [[buffer(18)]], const constant int* batch_shape [[buffer(18)]],
const constant int64_t* lhs_strides [[buffer(19)]], const constant int64_t* lhs_strides [[buffer(19)]],
const constant int64_t* rhs_strides [[buffer(20)]], const constant int64_t* rhs_strides [[buffer(20)]],
uint3 tid [[threadgroup_position_in_grid]], uint3 tid [[threadgroup_position_in_grid]],
@@ -1748,26 +1752,26 @@ template <typename T, int group_size, int bits>
} }
template <typename T, int group_size, int bits> template <typename T, int group_size, int bits>
[[kernel]] void bs_qmv( [[kernel]] void gather_qmv(
const device uint32_t* w [[buffer(0)]], const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]], const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]], const device T* biases [[buffer(2)]],
const device T* x [[buffer(3)]], const device T* x [[buffer(3)]],
device T* y [[buffer(4)]], const device uint32_t* lhs_indices [[buffer(4)]],
const constant int& in_vec_size [[buffer(5)]], const device uint32_t* rhs_indices [[buffer(5)]],
const constant int& out_vec_size [[buffer(6)]], device T* y [[buffer(6)]],
const constant int& x_batch_ndims [[buffer(7)]], const constant int& in_vec_size [[buffer(7)]],
const constant int* x_shape [[buffer(8)]], const constant int& out_vec_size [[buffer(8)]],
const constant int64_t* x_strides [[buffer(9)]], const constant int& x_batch_ndims [[buffer(9)]],
const constant int& w_batch_ndims [[buffer(10)]], const constant int* x_shape [[buffer(10)]],
const constant int* w_shape [[buffer(11)]], const constant int64_t* x_strides [[buffer(11)]],
const constant int64_t* w_strides [[buffer(12)]], const constant int& w_batch_ndims [[buffer(12)]],
const constant int64_t* s_strides [[buffer(13)]], const constant int* w_shape [[buffer(13)]],
const constant int64_t* b_strides [[buffer(14)]], const constant int64_t* w_strides [[buffer(14)]],
const constant int& batch_ndims [[buffer(15)]], const constant int64_t* s_strides [[buffer(15)]],
const constant int* batch_shape [[buffer(16)]], const constant int64_t* b_strides [[buffer(16)]],
const device uint32_t* lhs_indices [[buffer(17)]], const constant int& batch_ndims [[buffer(17)]],
const device uint32_t* rhs_indices [[buffer(18)]], const constant int* batch_shape [[buffer(18)]],
const constant int64_t* lhs_strides [[buffer(19)]], const constant int64_t* lhs_strides [[buffer(19)]],
const constant int64_t* rhs_strides [[buffer(20)]], const constant int64_t* rhs_strides [[buffer(20)]],
uint3 tid [[threadgroup_position_in_grid]], uint3 tid [[threadgroup_position_in_grid]],
@@ -1810,26 +1814,26 @@ template <typename T, int group_size, int bits>
} }
template <typename T, int group_size, int bits> template <typename T, int group_size, int bits>
[[kernel]] void bs_qvm( [[kernel]] void gather_qvm(
const device uint32_t* w [[buffer(0)]], const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]], const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]], const device T* biases [[buffer(2)]],
const device T* x [[buffer(3)]], const device T* x [[buffer(3)]],
device T* y [[buffer(4)]], const device uint32_t* lhs_indices [[buffer(4)]],
const constant int& in_vec_size [[buffer(5)]], const device uint32_t* rhs_indices [[buffer(5)]],
const constant int& out_vec_size [[buffer(6)]], device T* y [[buffer(6)]],
const constant int& x_batch_ndims [[buffer(7)]], const constant int& in_vec_size [[buffer(7)]],
const constant int* x_shape [[buffer(8)]], const constant int& out_vec_size [[buffer(8)]],
const constant int64_t* x_strides [[buffer(9)]], const constant int& x_batch_ndims [[buffer(9)]],
const constant int& w_batch_ndims [[buffer(10)]], const constant int* x_shape [[buffer(10)]],
const constant int* w_shape [[buffer(11)]], const constant int64_t* x_strides [[buffer(11)]],
const constant int64_t* w_strides [[buffer(12)]], const constant int& w_batch_ndims [[buffer(12)]],
const constant int64_t* s_strides [[buffer(13)]], const constant int* w_shape [[buffer(13)]],
const constant int64_t* b_strides [[buffer(14)]], const constant int64_t* w_strides [[buffer(14)]],
const constant int& batch_ndims [[buffer(15)]], const constant int64_t* s_strides [[buffer(15)]],
const constant int* batch_shape [[buffer(16)]], const constant int64_t* b_strides [[buffer(16)]],
const device uint32_t* lhs_indices [[buffer(17)]], const constant int& batch_ndims [[buffer(17)]],
const device uint32_t* rhs_indices [[buffer(18)]], const constant int* batch_shape [[buffer(18)]],
const constant int64_t* lhs_strides [[buffer(19)]], const constant int64_t* lhs_strides [[buffer(19)]],
const constant int64_t* rhs_strides [[buffer(20)]], const constant int64_t* rhs_strides [[buffer(20)]],
uint3 tid [[threadgroup_position_in_grid]], uint3 tid [[threadgroup_position_in_grid]],
@@ -1879,27 +1883,27 @@ template <
const int BM = 32, const int BM = 32,
const int BK = 32, const int BK = 32,
const int BN = 32> const int BN = 32>
[[kernel]] void bs_qmm_t( [[kernel]] void gather_qmm_t(
const device uint32_t* w [[buffer(0)]], const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]], const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]], const device T* biases [[buffer(2)]],
const device T* x [[buffer(3)]], const device T* x [[buffer(3)]],
device T* y [[buffer(4)]], const device uint32_t* lhs_indices [[buffer(4)]],
const constant int& K [[buffer(5)]], const device uint32_t* rhs_indices [[buffer(5)]],
const constant int& N [[buffer(6)]], device T* y [[buffer(6)]],
const constant int& M [[buffer(7)]], const constant int& K [[buffer(7)]],
const constant int& x_batch_ndims [[buffer(8)]], const constant int& N [[buffer(8)]],
const constant int* x_shape [[buffer(9)]], const constant int& M [[buffer(9)]],
const constant int64_t* x_strides [[buffer(10)]], const constant int& x_batch_ndims [[buffer(10)]],
const constant int& w_batch_ndims [[buffer(11)]], const constant int* x_shape [[buffer(11)]],
const constant int* w_shape [[buffer(12)]], const constant int64_t* x_strides [[buffer(12)]],
const constant int64_t* w_strides [[buffer(13)]], const constant int& w_batch_ndims [[buffer(13)]],
const constant int64_t* s_strides [[buffer(14)]], const constant int* w_shape [[buffer(14)]],
const constant int64_t* b_strides [[buffer(15)]], const constant int64_t* w_strides [[buffer(15)]],
const constant int& batch_ndims [[buffer(16)]], const constant int64_t* s_strides [[buffer(16)]],
const constant int* batch_shape [[buffer(17)]], const constant int64_t* b_strides [[buffer(17)]],
const device uint32_t* lhs_indices [[buffer(18)]], const constant int& batch_ndims [[buffer(18)]],
const device uint32_t* rhs_indices [[buffer(19)]], const constant int* batch_shape [[buffer(19)]],
const constant int64_t* lhs_strides [[buffer(20)]], const constant int64_t* lhs_strides [[buffer(20)]],
const constant int64_t* rhs_strides [[buffer(21)]], const constant int64_t* rhs_strides [[buffer(21)]],
uint3 tid [[threadgroup_position_in_grid]], uint3 tid [[threadgroup_position_in_grid]],
@@ -1946,27 +1950,27 @@ template <
const int BM = 32, const int BM = 32,
const int BK = 32, const int BK = 32,
const int BN = 32> const int BN = 32>
[[kernel]] void bs_qmm_n( [[kernel]] void gather_qmm_n(
const device uint32_t* w [[buffer(0)]], const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]], const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]], const device T* biases [[buffer(2)]],
const device T* x [[buffer(3)]], const device T* x [[buffer(3)]],
device T* y [[buffer(4)]], const device uint32_t* lhs_indices [[buffer(4)]],
const constant int& K [[buffer(5)]], const device uint32_t* rhs_indices [[buffer(5)]],
const constant int& N [[buffer(6)]], device T* y [[buffer(6)]],
const constant int& M [[buffer(7)]], const constant int& K [[buffer(7)]],
const constant int& x_batch_ndims [[buffer(8)]], const constant int& N [[buffer(8)]],
const constant int* x_shape [[buffer(9)]], const constant int& M [[buffer(9)]],
const constant int64_t* x_strides [[buffer(10)]], const constant int& x_batch_ndims [[buffer(10)]],
const constant int& w_batch_ndims [[buffer(11)]], const constant int* x_shape [[buffer(11)]],
const constant int* w_shape [[buffer(12)]], const constant int64_t* x_strides [[buffer(12)]],
const constant int64_t* w_strides [[buffer(13)]], const constant int& w_batch_ndims [[buffer(13)]],
const constant int64_t* s_strides [[buffer(14)]], const constant int* w_shape [[buffer(14)]],
const constant int64_t* b_strides [[buffer(15)]], const constant int64_t* w_strides [[buffer(15)]],
const constant int& batch_ndims [[buffer(16)]], const constant int64_t* s_strides [[buffer(16)]],
const constant int* batch_shape [[buffer(17)]], const constant int64_t* b_strides [[buffer(17)]],
const device uint32_t* lhs_indices [[buffer(18)]], const constant int& batch_ndims [[buffer(18)]],
const device uint32_t* rhs_indices [[buffer(19)]], const constant int* batch_shape [[buffer(19)]],
const constant int64_t* lhs_strides [[buffer(20)]], const constant int64_t* lhs_strides [[buffer(20)]],
const constant int64_t* rhs_strides [[buffer(21)]], const constant int64_t* rhs_strides [[buffer(21)]],
uint3 tid [[threadgroup_position_in_grid]], uint3 tid [[threadgroup_position_in_grid]],
@@ -2007,6 +2011,289 @@ template <
w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid); w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
} }
template <typename T, typename mma_t, typename loader_a_t, typename loader_b_t>
METAL_FUNC void gemm_loop_aligned(
threadgroup T* As,
threadgroup T* Bs,
thread mma_t& mma_op,
thread loader_a_t& loader_a,
thread loader_b_t& loader_b,
const int k_iterations) {
for (int k = 0; k < k_iterations; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup memory
loader_a.load_unsafe();
loader_b.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
}
template <
bool rows_aligned,
bool cols_aligned,
bool transpose,
typename T,
typename mma_t,
typename loader_a_t,
typename loader_b_t>
METAL_FUNC void gemm_loop_unaligned(
threadgroup T* As,
threadgroup T* Bs,
thread mma_t& mma_op,
thread loader_a_t& loader_a,
thread loader_b_t& loader_b,
const int k_iterations,
const short tgp_bm,
const short tgp_bn,
const short tgp_bk) {
for (int k = 0; k < k_iterations; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup memory
if (rows_aligned) {
loader_a.load_unsafe();
} else {
loader_a.load_safe(short2(tgp_bk, tgp_bm));
}
if (cols_aligned) {
loader_b.load_unsafe();
} else {
loader_b.load_safe(
transpose ? short2(tgp_bk, tgp_bn) : short2(tgp_bn, tgp_bk));
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
}
template <typename T, typename mma_t, typename loader_a_t, typename loader_b_t>
METAL_FUNC void gemm_loop_finalize(
threadgroup T* As,
threadgroup T* Bs,
thread mma_t& mma_op,
thread loader_a_t& loader_a,
thread loader_b_t& loader_b,
const short2 tile_a,
const short2 tile_b) {
loader_a.load_safe(tile_a);
loader_b.load_safe(tile_b);
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_op.mma(As, Bs);
}
template <
typename T,
int group_size,
int bits,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose>
[[kernel]] void gather_qmm_rhs(
const device T* x [[buffer(0)]],
const device uint32_t* w [[buffer(1)]],
const device T* scales [[buffer(2)]],
const device T* biases [[buffer(3)]],
const device uint32_t* indices [[buffer(4)]],
device T* y [[buffer(5)]],
const constant int& M [[buffer(6)]],
const constant int& N [[buffer(7)]],
const constant int& K [[buffer(8)]],
uint3 tid [[threadgroup_position_in_grid]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]]) {
constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
constexpr int BK_padded = (BK + 16 / sizeof(T));
constexpr int BN_padded = (BN + 16 / sizeof(T));
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3;
using mma_t = mlx::steel::BlockMMA<
T,
T,
BM,
BN,
BK,
WM,
WN,
false,
transpose,
BK_padded,
transpose ? BK_padded : BN_padded>;
using loader_x_t =
mlx::steel::BlockLoader<T, BM, BK, BK_padded, 1, WM * WN * SIMD_SIZE>;
using loader_w_t = QuantizedBlockLoader<
T,
transpose ? BN : BK,
transpose ? BK : BN,
transpose ? BK_padded : BN_padded,
transpose,
WM * WN * SIMD_SIZE,
group_size,
bits>;
threadgroup T Xs[BM * BK_padded];
threadgroup T Ws[transpose ? BN * BK_padded : BK * BN_padded];
// Compute the block
const int K_w = K * bytes_per_pack / pack_factor;
const int K_g = K / group_size;
const int N_w = N * bytes_per_pack / pack_factor;
const int N_g = N / group_size;
const int K_it = K / BK;
const size_t stride_w = transpose ? N * K_w : K * N_w;
const size_t stride_s = transpose ? N * K_g : K * N_g;
const int y_row = tid.y * BM;
const int y_col = tid.x * BN;
const size_t y_row_long = size_t(y_row);
const size_t y_col_long = size_t(y_col);
// Prepare threadgroup bounds
const short tgp_bm = align_M ? BM : short(min(BM, M - y_row));
const short tgp_bn = align_N ? BN : short(min(BN, N - y_col));
// Calculate the final tiles in the case that K is not aligned
const int k_remain = K - K_it * BK;
const short2 tile_x = short2(k_remain, tgp_bm);
const short2 tile_w =
transpose ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);
// Move x and output to the correct block
auto wl = (const device uint8_t*)w;
x += y_row_long * K;
y += y_row_long * N + y_col_long;
wl += transpose ? y_col_long * K_w : y_col * bytes_per_pack / pack_factor;
scales += transpose ? y_col_long * K_g : y_col / group_size;
biases += transpose ? y_col_long * K_g : y_col / group_size;
// Do as many matmuls as necessary
uint32_t index;
short offset;
uint32_t index_next = indices[y_row];
short offset_next = 0;
int n = 0;
while (n < tgp_bm) {
n++;
offset = offset_next;
index = index_next;
offset_next = tgp_bm;
for (; n < tgp_bm; n++) {
if (indices[y_row + n] != index) {
offset_next = n;
index_next = indices[y_row + n];
break;
}
}
threadgroup_barrier(mem_flags::mem_none);
// Prepare threadgroup mma operation
thread mma_t mma_op(simd_group_id, simd_lane_id);
// Prepare threadgroup loading operations
thread loader_x_t loader_x(x, K, Xs, simd_group_id, simd_lane_id);
thread loader_w_t loader_w(
wl + index * stride_w,
scales + index * stride_s,
biases + index * stride_s,
transpose ? K : N,
Ws,
simd_group_id,
simd_lane_id);
// Matrices are all aligned check nothing
if (align_M && align_N) {
gemm_loop_aligned(Xs, Ws, mma_op, loader_x, loader_w, K_it);
if (!align_K) {
threadgroup_barrier(mem_flags::mem_threadgroup);
gemm_loop_finalize(Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w);
}
// Store results to device memory
if (offset_next - offset == BM) {
mma_op.store_result(y, N);
} else {
mma_op.store_result_slice(
y, N, short2(0, offset), short2(BN, offset_next));
}
} else {
// Tile aligned so check outside of the hot loop
if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) {
gemm_loop_aligned(Xs, Ws, mma_op, loader_x, loader_w, K_it);
if (!align_K) {
threadgroup_barrier(mem_flags::mem_threadgroup);
gemm_loop_finalize(
Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w);
}
// Store results to device memory
if (offset_next - offset == BM) {
mma_op.store_result(y, N);
} else {
mma_op.store_result_slice(
y, N, short2(0, offset), short2(BN, offset_next));
}
}
// Tile partially aligned check rows
else if (align_N || tgp_bn == BN) {
gemm_loop_unaligned<false, true, transpose>(
Xs, Ws, mma_op, loader_x, loader_w, K_it, tgp_bm, tgp_bn, BK);
if (!align_K) {
threadgroup_barrier(mem_flags::mem_threadgroup);
gemm_loop_finalize(
Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w);
}
mma_op.store_result_slice(
y, N, short2(0, offset), short2(BN, offset_next));
}
// Tile partially aligned check cols
else if (align_M || tgp_bm == BM) {
gemm_loop_unaligned<true, false, transpose>(
Xs, Ws, mma_op, loader_x, loader_w, K_it, tgp_bm, tgp_bn, BK);
if (!align_K) {
threadgroup_barrier(mem_flags::mem_threadgroup);
gemm_loop_finalize(
Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w);
}
mma_op.store_result_slice(
y, N, short2(0, offset), short2(tgp_bn, offset_next));
}
// Nothing aligned so check both rows and cols
else {
gemm_loop_unaligned<false, false, transpose>(
Xs, Ws, mma_op, loader_x, loader_w, K_it, tgp_bm, tgp_bn, BK);
if (!align_K) {
threadgroup_barrier(mem_flags::mem_threadgroup);
gemm_loop_finalize(
Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w);
}
mma_op.store_result_slice(
y, N, short2(0, offset), short2(tgp_bn, offset_next));
}
}
}
}
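The loop over `indices` above batches consecutive output rows that share the same index, so one threadgroup can reuse a single quantized weight matrix for a whole run of rows and store only that slice of the output. A minimal host-side C++ sketch of that grouping logic (standalone and illustrative, not the kernel's actual code):
#include <cstdint>
#include <cstdio>
#include <vector>
// Split a block of consecutive indices into runs that share the same value,
// mirroring the (offset, offset_next, index) bookkeeping in the kernel.
void group_runs(const std::vector<uint32_t>& indices) {
  int rows = static_cast<int>(indices.size());
  int n = 0;
  while (n < rows) {
    uint32_t index = indices[n];
    int offset = n;
    // Advance until the index changes or the block ends.
    while (n < rows && indices[n] == index) {
      n++;
    }
    // The kernel would now run one matmul with weight `index` and store
    // only the output rows in [offset, n).
    std::printf("index %u covers rows [%d, %d)\n", index, offset, n);
  }
}
int main() {
  group_runs({3, 3, 3, 7, 7, 9}); // prints three runs
  return 0;
}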
template <typename T, const int group_size, const int bits> template <typename T, const int group_size, const int bits>
[[kernel]] void affine_quantize( [[kernel]] void affine_quantize(
const device T* w [[buffer(0)]], const device T* w [[buffer(0)]],

View File

@@ -60,6 +60,20 @@
bits, \ bits, \
split_k) split_k)
#define instantiate_gather_qmm_rhs(func, name, type, group_size, bits, bm, bn, bk, wm, wn, transpose) \
instantiate_kernel( \
#name "_" #type "_gs_" #group_size "_b_" #bits "_bm_" #bm "_bn_" #bn "_bk_" #bk "_wm_" #wm "_wn_" #wn, \
func, \
type, \
group_size, \
bits, \
bm, \
bn, \
bk, \
wm, \
wn, \
transpose)
#define instantiate_quantized_batched_wrap(name, type, group_size, bits) \ #define instantiate_quantized_batched_wrap(name, type, group_size, bits) \
instantiate_quantized_batched(name, type, group_size, bits, 1) \ instantiate_quantized_batched(name, type, group_size, bits, 1) \
instantiate_quantized_batched(name, type, group_size, bits, 0) instantiate_quantized_batched(name, type, group_size, bits, 0)
@@ -73,14 +87,14 @@
#define instantiate_quantized_all_single(type, group_size, bits) \ #define instantiate_quantized_all_single(type, group_size, bits) \
instantiate_quantized(affine_quantize, type, group_size, bits) \ instantiate_quantized(affine_quantize, type, group_size, bits) \
instantiate_quantized(affine_dequantize, type, group_size, bits) \ instantiate_quantized(affine_dequantize, type, group_size, bits) \
instantiate_quantized(bs_qmv_fast, type, group_size, bits) \ instantiate_quantized(gather_qmv_fast, type, group_size, bits) \
instantiate_quantized(bs_qmv, type, group_size, bits) \ instantiate_quantized(gather_qmv, type, group_size, bits) \
instantiate_quantized(bs_qvm, type, group_size, bits) \ instantiate_quantized(gather_qvm, type, group_size, bits) \
instantiate_quantized(bs_qmm_n, type, group_size, bits) instantiate_quantized(gather_qmm_n, type, group_size, bits)
#define instantiate_quantized_all_aligned(type, group_size, bits) \ #define instantiate_quantized_all_aligned(type, group_size, bits) \
instantiate_quantized_aligned(bs_qmm_t, type, group_size, bits, true) \ instantiate_quantized_aligned(gather_qmm_t, type, group_size, bits, true) \
instantiate_quantized_aligned(bs_qmm_t, type, group_size, bits, false) \ instantiate_quantized_aligned(gather_qmm_t, type, group_size, bits, false) \
instantiate_quantized_aligned_batched(qmm_t, type, group_size, bits, true, 1) \ instantiate_quantized_aligned_batched(qmm_t, type, group_size, bits, true, 1) \
instantiate_quantized_aligned_batched(qmm_t, type, group_size, bits, true, 0) \ instantiate_quantized_aligned_batched(qmm_t, type, group_size, bits, true, 0) \
instantiate_quantized_aligned_batched(qmm_t, type, group_size, bits, false, 1) \ instantiate_quantized_aligned_batched(qmm_t, type, group_size, bits, false, 1) \
@@ -96,12 +110,17 @@
instantiate_quantized_split_k(qvm_split_k, type, group_size, bits, 8) \ instantiate_quantized_split_k(qvm_split_k, type, group_size, bits, 8) \
instantiate_quantized_split_k(qvm_split_k, type, group_size, bits, 32) instantiate_quantized_split_k(qvm_split_k, type, group_size, bits, 32)
#define instantiate_quantized_all_rhs(type, group_size, bits) \
instantiate_gather_qmm_rhs(gather_qmm_rhs, gather_qmm_rhs_nt, type, group_size, bits, 16, 32, 32, 1, 2, true) \
instantiate_gather_qmm_rhs(gather_qmm_rhs, gather_qmm_rhs_nn, type, group_size, bits, 16, 32, 32, 1, 2, false)
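For reference, each expansion of `instantiate_gather_qmm_rhs` above registers a kernel whose name encodes the tile parameters; with example values `type = float`, `group_size = 64`, and `bits = 4` (assumed here for illustration), the non-transposed variant would be named:
// "gather_qmm_rhs_nn_float_gs_64_b_4_bm_16_bn_32_bk_32_wm_1_wn_2"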
#define instantiate_quantized_funcs(type, group_size, bits) \ #define instantiate_quantized_funcs(type, group_size, bits) \
instantiate_quantized_all_single(type, group_size, bits) \ instantiate_quantized_all_single(type, group_size, bits) \
instantiate_quantized_all_batched(type, group_size, bits) \ instantiate_quantized_all_batched(type, group_size, bits) \
instantiate_quantized_all_aligned(type, group_size, bits) \ instantiate_quantized_all_aligned(type, group_size, bits) \
instantiate_quantized_all_quad(type, group_size, bits) \ instantiate_quantized_all_quad(type, group_size, bits) \
instantiate_quantized_all_splitk(type, group_size, bits) instantiate_quantized_all_splitk(type, group_size, bits) \
instantiate_quantized_all_rhs(type, group_size, bits)
#define instantiate_quantized_types(group_size, bits) \ #define instantiate_quantized_types(group_size, bits) \
instantiate_quantized_funcs(float, group_size, bits) \ instantiate_quantized_funcs(float, group_size, bits) \

View File

@@ -380,69 +380,11 @@ template <typename T, int N_READS = RMS_N_READS>
} }
// clang-format off // clang-format off
#define instantiate_rms_single_row(name, itype) \
template [[host_name("rms" #name)]] [[kernel]] void \
rms_single_row<itype>( \
const device itype* x, \
const device itype* w, \
device itype* out, \
constant float& eps, \
constant uint& axis_size, \
constant uint& w_stride, \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]); \
\
template [[host_name("vjp_rms" #name)]] [[kernel]] void \
vjp_rms_single_row<itype>( \
const device itype* x, \
const device itype* w, \
const device itype* g, \
device itype* gx, \
device itype* gw, \
constant float& eps, \
constant uint& axis_size, \
constant uint& w_stride, \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
#define instantiate_rms_looped(name, itype) \
template [[host_name("rms_looped" #name)]] [[kernel]] void \
rms_looped<itype>( \
const device itype* x, \
const device itype* w, \
device itype* out, \
constant float& eps, \
constant uint& axis_size, \
constant uint& w_stride, \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint lsize [[threads_per_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]); \
\
template [[host_name("vjp_rms_looped" #name)]] [[kernel]] void \
vjp_rms_looped<itype>( \
const device itype* x, \
const device itype* w, \
const device itype* g, \
device itype* gx, \
device itype* gw, \
constant float& eps, \
constant uint& axis_size, \
constant uint& w_stride, \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint lsize [[threads_per_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
#define instantiate_rms(name, itype) \ #define instantiate_rms(name, itype) \
instantiate_rms_single_row(name, itype) \ instantiate_kernel("rms" #name, rms_single_row, itype) \
instantiate_rms_looped(name, itype) instantiate_kernel("vjp_rms" #name, vjp_rms_single_row, itype) \
instantiate_kernel("rms_looped" #name, rms_looped, itype) \
instantiate_kernel("vjp_rms_looped" #name, vjp_rms_looped, itype)
instantiate_rms(float32, float) instantiate_rms(float32, float)
instantiate_rms(float16, half) instantiate_rms(float16, half)

View File

@@ -1,11 +1,11 @@
#include <metal_stdlib> #include <metal_stdlib>
#include "mlx/backend/metal/kernels/sdpa_vector.h" // clang-format off
#include "mlx/backend/metal/kernels/utils.h" #include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/sdpa_vector.h"
using namespace metal; using namespace metal;
// clang-format off
// SDPA vector instantiations // SDPA vector instantiations
#define instantiate_sdpa_vector_aggregation(type, value_dim) \ #define instantiate_sdpa_vector_aggregation(type, value_dim) \
instantiate_kernel( \ instantiate_kernel( \
@@ -32,9 +32,11 @@ using namespace metal;
instantiate_sdpa_vector(type, 64, 64) \ instantiate_sdpa_vector(type, 64, 64) \
instantiate_sdpa_vector(type, 96, 96) \ instantiate_sdpa_vector(type, 96, 96) \
instantiate_sdpa_vector(type, 128, 128) \ instantiate_sdpa_vector(type, 128, 128) \
instantiate_sdpa_vector(type, 256, 256) \
instantiate_sdpa_vector_aggregation(type, 64) \ instantiate_sdpa_vector_aggregation(type, 64) \
instantiate_sdpa_vector_aggregation(type, 96) \ instantiate_sdpa_vector_aggregation(type, 96) \
instantiate_sdpa_vector_aggregation(type, 128) instantiate_sdpa_vector_aggregation(type, 128) \
instantiate_sdpa_vector_aggregation(type, 256)
instantiate_sdpa_vector_heads(float) instantiate_sdpa_vector_heads(float)
instantiate_sdpa_vector_heads(bfloat16_t) instantiate_sdpa_vector_heads(bfloat16_t)

View File

@@ -2,6 +2,8 @@
#pragma once #pragma once
#include "mlx/backend/metal/kernels/binary_ops.h"
#define DEFINE_SIMD_SCAN() \ #define DEFINE_SIMD_SCAN() \
template <typename T, metal::enable_if_t<sizeof(T) < 8, bool> = true> \ template <typename T, metal::enable_if_t<sizeof(T) < 8, bool> = true> \
T simd_scan(T val) { \ T simd_scan(T val) { \
@@ -139,6 +141,29 @@ struct CumMin {
} }
}; };
template <typename U>
struct CumLogaddexp {
static constexpr constant U init = Limits<U>::min;
template <typename T>
U operator()(U a, T b) {
return LogAddExp{}(a, static_cast<U>(b));
}
U simd_scan(U x) {
for (int i = 1; i <= 16; i *= 2) {
U other = simd_shuffle_and_fill_up(x, init, i);
x = LogAddExp{}(x, other);
}
return x;
}
U simd_exclusive_scan(U x) {
x = simd_scan(x);
return simd_shuffle_and_fill_up(x, init, 1);
}
};
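The `CumLogaddexp` scan above folds elements with `LogAddExp`, whose numerically stable form is log(exp(a) + exp(b)) = max(a, b) + log1p(exp(-|a - b|)). A small standalone C++ sketch of that reduction applied as a sequential inclusive scan, assuming this is the same formulation the Metal `LogAddExp` functor uses:
#include <cmath>
#include <cstdio>
#include <vector>
// Numerically stable log(exp(a) + exp(b)).
float log_add_exp(float a, float b) {
  if (std::isinf(a) && a < 0) return b; // -inf is the identity element
  if (std::isinf(b) && b < 0) return a;
  float m = std::max(a, b);
  return m + std::log1p(std::exp(-std::fabs(a - b)));
}
int main() {
  std::vector<float> x = {0.5f, -1.0f, 2.0f, 0.0f};
  float acc = -INFINITY; // identity, analogous to Limits<U>::min in the kernel
  for (float v : x) {
    acc = log_add_exp(acc, v);
    std::printf("%f\n", acc); // running log-sum-exp prefix
  }
  return 0;
}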
template <typename T, typename U, int N_READS, bool reverse> template <typename T, typename U, int N_READS, bool reverse>
inline void load_unsafe(U values[N_READS], const device T* input) { inline void load_unsafe(U values[N_READS], const device T* input) {
if (reverse) { if (reverse) {

View File

@@ -101,4 +101,7 @@ instantiate_scan_helper(min_int64_int64, int64_t, int64_t, CumMi
instantiate_scan_helper(min_float16_float16, half, half, CumMin, 4) instantiate_scan_helper(min_float16_float16, half, half, CumMin, 4)
instantiate_scan_helper(min_float32_float32, float, float, CumMin, 4) instantiate_scan_helper(min_float32_float32, float, float, CumMin, 4)
instantiate_scan_helper(min_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumMin, 4) instantiate_scan_helper(min_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumMin, 4)
instantiate_scan_helper(min_complex64_complex64, complex64_t, complex64_t, CumMin, 2) // clang-format on instantiate_scan_helper(min_complex64_complex64, complex64_t, complex64_t, CumMin, 2)
instantiate_scan_helper(logaddexp_float16_float16, half, half, CumLogaddexp, 4)
instantiate_scan_helper(logaddexp_float32_float32, float, float, CumLogaddexp, 4)
instantiate_scan_helper(logaddexp_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumLogaddexp, 4) // clang-format on

View File

@@ -6,6 +6,9 @@ using namespace metal;
constant bool has_mask [[function_constant(20)]]; constant bool has_mask [[function_constant(20)]];
constant bool query_transposed [[function_constant(21)]]; constant bool query_transposed [[function_constant(21)]];
constant bool do_causal [[function_constant(22)]];
constant bool bool_mask [[function_constant(23)]];
constant bool float_mask [[function_constant(24)]];
template <typename T, int D, int V = D> template <typename T, int D, int V = D>
[[kernel]] void sdpa_vector( [[kernel]] void sdpa_vector(
@@ -13,17 +16,21 @@ template <typename T, int D, int V = D>
const device T* keys [[buffer(1)]], const device T* keys [[buffer(1)]],
const device T* values [[buffer(2)]], const device T* values [[buffer(2)]],
device T* out [[buffer(3)]], device T* out [[buffer(3)]],
const constant int& gqa_factor, const constant int& gqa_factor [[buffer(4)]],
const constant int& N, const constant int& N [[buffer(5)]],
const constant size_t& k_head_stride, const constant size_t& k_head_stride [[buffer(6)]],
const constant size_t& k_seq_stride, const constant size_t& k_seq_stride [[buffer(7)]],
const constant size_t& v_head_stride, const constant size_t& v_head_stride [[buffer(8)]],
const constant size_t& v_seq_stride, const constant size_t& v_seq_stride [[buffer(9)]],
const constant float& scale, const constant float& scale [[buffer(10)]],
const device bool* mask [[function_constant(has_mask)]], const device bool* bmask [[buffer(11), function_constant(bool_mask)]],
const constant int& mask_kv_seq_stride [[function_constant(has_mask)]], const device T* fmask [[buffer(12), function_constant(float_mask)]],
const constant int& mask_q_seq_stride [[function_constant(has_mask)]], const constant int& mask_kv_seq_stride
const constant int& mask_head_stride [[function_constant(has_mask)]], [[buffer(13), function_constant(has_mask)]],
const constant int& mask_q_seq_stride
[[buffer(14), function_constant(has_mask)]],
const constant int& mask_head_stride
[[buffer(15), function_constant(has_mask)]],
uint3 tid [[threadgroup_position_in_grid]], uint3 tid [[threadgroup_position_in_grid]],
uint3 tpg [[threadgroups_per_grid]], uint3 tpg [[threadgroups_per_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]],
@@ -57,8 +64,12 @@ template <typename T, int D, int V = D>
simd_lid * qk_per_thread; simd_lid * qk_per_thread;
values += kv_head_idx * v_head_stride + simd_gid * v_seq_stride + values += kv_head_idx * v_head_stride + simd_gid * v_seq_stride +
simd_lid * v_per_thread; simd_lid * v_per_thread;
if (has_mask) { if (bool_mask) {
mask += head_idx * mask_head_stride + simd_gid * mask_kv_seq_stride + bmask += head_idx * mask_head_stride + simd_gid * mask_kv_seq_stride +
q_seq_idx * mask_q_seq_stride;
}
if (float_mask) {
fmask += head_idx * mask_head_stride + simd_gid * mask_kv_seq_stride +
q_seq_idx * mask_q_seq_stride; q_seq_idx * mask_q_seq_stride;
} }
@@ -77,7 +88,13 @@ template <typename T, int D, int V = D>
// For each key // For each key
for (int i = simd_gid; i < N; i += BN) { for (int i = simd_gid; i < N; i += BN) {
if (!has_mask || mask[0]) { bool use_key = true;
if (do_causal) {
use_key = i <= (N - int(tpg.y) + int(q_seq_idx));
} else if (bool_mask) {
use_key = bmask[0];
}
if (use_key) {
// Read the key // Read the key
for (int j = 0; j < qk_per_thread; j++) { for (int j = 0; j < qk_per_thread; j++) {
k[j] = keys[j]; k[j] = keys[j];
@@ -89,6 +106,9 @@ template <typename T, int D, int V = D>
score += q[j] * k[j]; score += q[j] * k[j];
} }
score = simd_sum(score); score = simd_sum(score);
if (float_mask) {
score += max(Limits<U>::finite_min, static_cast<U>(fmask[0]));
}
// Update the accumulators // Update the accumulators
U new_max = max(max_score, score); U new_max = max(max_score, score);
@@ -107,8 +127,11 @@ template <typename T, int D, int V = D>
// Move the pointers to the next kv // Move the pointers to the next kv
keys += inner_k_stride; keys += inner_k_stride;
values += inner_v_stride; values += inner_v_stride;
if (has_mask) { if (bool_mask) {
mask += BN * mask_kv_seq_stride; bmask += BN * mask_kv_seq_stride;
}
if (float_mask) {
fmask += BN * mask_kv_seq_stride;
} }
} }
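In the causal branch above, a key at position `i` is kept when `i <= N - tpg.y + q_seq_idx`, i.e. when the key is not in the future of the query once the query block is right-aligned with the end of the key sequence. A hedged host-side sketch of that predicate, assuming `tpg.y` equals the number of queries and using made-up sizes:
#include <cstdio>
// N: number of keys, q_len: number of queries, q: query position.
bool use_key(int i, int N, int q_len, int q) {
  return i <= (N - q_len + q);
}
int main() {
  // With 6 keys and 2 queries, query 0 may attend keys 0..4 and query 1 keys 0..5.
  for (int q = 0; q < 2; q++) {
    for (int i = 0; i < 6; i++) {
      std::printf("%d", use_key(i, 6, 2, q) ? 1 : 0);
    }
    std::printf("\n");
  }
  return 0;
}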
@@ -149,17 +172,21 @@ template <typename T, int D, int V = D>
device float* out [[buffer(3)]], device float* out [[buffer(3)]],
device float* sums [[buffer(4)]], device float* sums [[buffer(4)]],
device float* maxs [[buffer(5)]], device float* maxs [[buffer(5)]],
const constant int& gqa_factor, const constant int& gqa_factor [[buffer(6)]],
const constant int& N, const constant int& N [[buffer(7)]],
const constant size_t& k_head_stride, const constant size_t& k_head_stride [[buffer(8)]],
const constant size_t& k_seq_stride, const constant size_t& k_seq_stride [[buffer(9)]],
const constant size_t& v_head_stride, const constant size_t& v_head_stride [[buffer(10)]],
const constant size_t& v_seq_stride, const constant size_t& v_seq_stride [[buffer(11)]],
const constant float& scale, const constant float& scale [[buffer(12)]],
const device bool* mask [[function_constant(has_mask)]], const device bool* bmask [[buffer(13), function_constant(bool_mask)]],
const constant int& mask_kv_seq_stride [[function_constant(has_mask)]], const device T* fmask [[buffer(14), function_constant(float_mask)]],
const constant int& mask_q_seq_stride [[function_constant(has_mask)]], const constant int& mask_kv_seq_stride
const constant int& mask_head_stride [[function_constant(has_mask)]], [[buffer(15), function_constant(has_mask)]],
const constant int& mask_q_seq_stride
[[buffer(16), function_constant(has_mask)]],
const constant int& mask_head_stride
[[buffer(17), function_constant(has_mask)]],
uint3 tid [[threadgroup_position_in_grid]], uint3 tid [[threadgroup_position_in_grid]],
uint3 tpg [[threadgroups_per_grid]], uint3 tpg [[threadgroups_per_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]],
@@ -197,8 +224,13 @@ template <typename T, int D, int V = D>
values += kv_head_idx * v_head_stride + values += kv_head_idx * v_head_stride +
(block_idx * BN + simd_gid) * v_seq_stride + simd_lid * v_per_thread; (block_idx * BN + simd_gid) * v_seq_stride + simd_lid * v_per_thread;
out += o_offset * blocks * V + block_idx * V + simd_lid * v_per_thread; out += o_offset * blocks * V + block_idx * V + simd_lid * v_per_thread;
if (has_mask) { if (bool_mask) {
mask += head_idx * mask_head_stride + bmask += head_idx * mask_head_stride +
(block_idx * BN + simd_gid) * mask_kv_seq_stride +
q_seq_idx * mask_q_seq_stride;
}
if (float_mask) {
fmask += head_idx * mask_head_stride +
(block_idx * BN + simd_gid) * mask_kv_seq_stride + (block_idx * BN + simd_gid) * mask_kv_seq_stride +
q_seq_idx * mask_q_seq_stride; q_seq_idx * mask_q_seq_stride;
} }
@@ -218,7 +250,13 @@ template <typename T, int D, int V = D>
// For each key // For each key
for (int i = block_idx * BN + simd_gid; i < N; i += blocks * BN) { for (int i = block_idx * BN + simd_gid; i < N; i += blocks * BN) {
if (!has_mask || mask[0]) { bool use_key = true;
if (do_causal) {
use_key = i <= (N - int(tpg.y) + int(q_seq_idx));
} else if (bool_mask) {
use_key = bmask[0];
}
if (use_key) {
// Read the key // Read the key
for (int i = 0; i < qk_per_thread; i++) { for (int i = 0; i < qk_per_thread; i++) {
k[i] = keys[i]; k[i] = keys[i];
@@ -230,6 +268,9 @@ template <typename T, int D, int V = D>
score += q[i] * k[i]; score += q[i] * k[i];
} }
score = simd_sum(score); score = simd_sum(score);
if (float_mask) {
score += fmask[0];
}
// Update the accumulators // Update the accumulators
U new_max = max(max_score, score); U new_max = max(max_score, score);
@@ -248,8 +289,11 @@ template <typename T, int D, int V = D>
// Move the pointers to the next kv // Move the pointers to the next kv
keys += blocks * inner_k_stride; keys += blocks * inner_k_stride;
values += blocks * inner_v_stride; values += blocks * inner_v_stride;
if (has_mask) { if (bool_mask) {
mask += BN * blocks * mask_kv_seq_stride; bmask += BN * blocks * mask_kv_seq_stride;
}
if (float_mask) {
fmask += BN * blocks * mask_kv_seq_stride;
} }
} }

View File

@@ -10,46 +10,12 @@ using namespace metal;
#include "mlx/backend/metal/kernels/softmax.h" #include "mlx/backend/metal/kernels/softmax.h"
#define instantiate_softmax(name, itype) \ #define instantiate_softmax(name, itype) \
template [[host_name("block_softmax_" #name)]] [[kernel]] void \ instantiate_kernel("block_softmax_" #name, softmax_single_row, itype) \
softmax_single_row<itype>( \ instantiate_kernel("looped_softmax_" #name, softmax_looped, itype)
const device itype* in, \
device itype* out, \
constant int& axis_size, \
uint gid [[thread_position_in_grid]], \
uint _lid [[thread_position_in_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]); \
template [[host_name("looped_softmax_" #name)]] [[kernel]] void \
softmax_looped<itype>( \
const device itype* in, \
device itype* out, \
constant int& axis_size, \
uint gid [[threadgroup_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint lsize [[threads_per_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
#define instantiate_softmax_precise(name, itype) \ #define instantiate_softmax_precise(name, itype) \
template [[host_name("block_softmax_precise_" #name)]] [[kernel]] void \ instantiate_kernel("block_softmax_precise_" #name, softmax_single_row, itype, float) \
softmax_single_row<itype, float>( \ instantiate_kernel("looped_softmax_precise_" #name, softmax_looped, itype, float)
const device itype* in, \
device itype* out, \
constant int& axis_size, \
uint gid [[thread_position_in_grid]], \
uint _lid [[thread_position_in_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]); \
template [[host_name("looped_softmax_precise_" #name)]] [[kernel]] void \
softmax_looped<itype, float>( \
const device itype* in, \
device itype* out, \
constant int& axis_size, \
uint gid [[threadgroup_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint lsize [[threads_per_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
instantiate_softmax(float32, float) instantiate_softmax(float32, float)
instantiate_softmax(float16, half) instantiate_softmax(float16, half)

View File

@@ -229,7 +229,7 @@ template <
// Init to -Inf // Init to -Inf
STEEL_PRAGMA_UNROLL STEEL_PRAGMA_UNROLL
for (short i = 0; i < kRowsPT; ++i) { for (short i = 0; i < kRowsPT; ++i) {
max_score[i] = Limits<AccumType>::min; max_score[i] = Limits<AccumType>::finite_min;
} }
int kb_lim = params->NK; int kb_lim = params->NK;
@@ -237,6 +237,7 @@ template <
if (do_causal) { if (do_causal) {
int q_max = (tid.x + 1) * BQ + params->qL_off; int q_max = (tid.x + 1) * BQ + params->qL_off;
kb_lim = (q_max + BK - 1) / BK; kb_lim = (q_max + BK - 1) / BK;
kb_lim = min(params->NK, kb_lim);
} }
// Loop over KV seq length // Loop over KV seq length
@@ -272,7 +273,7 @@ template <
if (!align_K && kb == (params->NK_aligned)) { if (!align_K && kb == (params->NK_aligned)) {
using stile_t = decltype(Stile); using stile_t = decltype(Stile);
using selem_t = typename stile_t::elem_type; using selem_t = typename stile_t::elem_type;
constexpr auto neg_inf = -metal::numeric_limits<selem_t>::infinity(); constexpr auto neg_inf = Limits<selem_t>::finite_min;
STEEL_PRAGMA_UNROLL STEEL_PRAGMA_UNROLL
for (short i = 0; i < stile_t::kTileRows; i++) { for (short i = 0; i < stile_t::kTileRows; i++) {
@@ -290,10 +291,10 @@ template <
} }
// Mask out if causal // Mask out if causal
if (do_causal && kb >= (kb_lim - (BQ + BK - 1) / BK - int(!align_K))) { if (do_causal && kb >= (kb_lim - ((BQ + BK - 1) / BK) - int(!align_K))) {
using stile_t = decltype(Stile); using stile_t = decltype(Stile);
using selem_t = typename stile_t::elem_type; using selem_t = typename stile_t::elem_type;
constexpr auto neg_inf = -metal::numeric_limits<selem_t>::infinity(); constexpr auto neg_inf = Limits<selem_t>::finite_min;
STEEL_PRAGMA_UNROLL STEEL_PRAGMA_UNROLL
for (short i = 0; i < stile_t::kTileRows; i++) { for (short i = 0; i < stile_t::kTileRows; i++) {
@@ -316,7 +317,7 @@ template <
if (has_mask) { if (has_mask) {
using stile_t = decltype(Stile); using stile_t = decltype(Stile);
using selem_t = typename stile_t::elem_type; using selem_t = typename stile_t::elem_type;
constexpr auto neg_inf = -metal::numeric_limits<selem_t>::infinity(); constexpr auto neg_inf = Limits<selem_t>::finite_min;
constexpr bool is_bool = is_same_v<MaskType, bool>; constexpr bool is_bool = is_same_v<MaskType, bool>;
using melem_t = typename metal::conditional_t<is_bool, bool, selem_t>; using melem_t = typename metal::conditional_t<is_bool, bool, selem_t>;
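One plausible reason for switching from -infinity to `Limits<...>::finite_min` in the running-max initialization and masking above is that the streaming-softmax rescale factor `exp(old_max - new_max)` turns into NaN when both maxima are -inf (a fully masked row), while with a large finite sentinel it stays a well-defined 1. A small C++ illustration of that failure mode, under that assumption about the motivation:
#include <cmath>
#include <cstdio>
#include <limits>
int main() {
  float neg_inf = -std::numeric_limits<float>::infinity();
  float finite_min = std::numeric_limits<float>::lowest();
  // Rescale factor applied when a new running maximum is adopted.
  std::printf("exp(-inf - -inf) = %f\n", std::exp(neg_inf - neg_inf));       // nan
  std::printf("exp(min - min)   = %f\n", std::exp(finite_min - finite_min)); // 1.0
  return 0;
}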

View File

@@ -15,10 +15,6 @@ constant bool align_M [[function_constant(200)]];
constant bool align_N [[function_constant(201)]]; constant bool align_N [[function_constant(201)]];
constant bool align_K [[function_constant(202)]]; constant bool align_K [[function_constant(202)]];
constant bool do_gather [[function_constant(300)]];
constant bool gather_bias = do_gather && use_out_source;
// clang-format off // clang-format off
template < template <
typename T, typename T,
@@ -39,12 +35,6 @@ template <
const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]], const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]],
const constant int* batch_shape [[buffer(6)]], const constant int* batch_shape [[buffer(6)]],
const constant int64_t* batch_strides [[buffer(7)]], const constant int64_t* batch_strides [[buffer(7)]],
const constant uint32_t* lhs_indices [[buffer(10), function_constant(do_gather)]],
const constant uint32_t* rhs_indices [[buffer(11), function_constant(do_gather)]],
const constant uint32_t* C_indices [[buffer(12), function_constant(gather_bias)]],
const constant int* operand_shape [[buffer(13), function_constant(do_gather)]],
const constant int64_t* operand_strides [[buffer(14), function_constant(do_gather)]],
const constant packed_int3& operand_batch_ndim [[buffer(15), function_constant(do_gather)]],
uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]], uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]], uint3 tid [[threadgroup_position_in_grid]],
@@ -81,63 +71,6 @@ template <
} }
// Adjust for batch // Adjust for batch
// Handle gather
if (do_gather) {
// Read indices
uint32_t indx_A, indx_B, indx_C;
if (has_batch) {
const constant auto* indx_A_bstrides = batch_strides;
const constant auto* indx_B_bstrides = batch_strides + params->batch_ndim;
ulong2 indx_offsets = elem_to_loc_broadcast(
tid.z,
batch_shape,
indx_A_bstrides,
indx_B_bstrides,
params->batch_ndim);
indx_A = lhs_indices[indx_offsets.x];
indx_B = rhs_indices[indx_offsets.y];
if (use_out_source) {
const constant auto* indx_C_bstrides =
indx_B_bstrides + params->batch_ndim;
auto indx_offset_C = elem_to_loc(
tid.z, batch_shape, indx_C_bstrides, params->batch_ndim);
indx_C = C_indices[indx_offset_C];
}
} else {
indx_A = lhs_indices[params->batch_stride_a * tid.z];
indx_B = rhs_indices[params->batch_stride_b * tid.z];
if (use_out_source) {
indx_C = C_indices[addmm_params->batch_stride_c * tid.z];
}
}
// Translate indices to offsets
int batch_ndim_A = operand_batch_ndim.x;
const constant int* batch_shape_A = operand_shape;
const constant auto* batch_strides_A = operand_strides;
A += elem_to_loc(indx_A, batch_shape_A, batch_strides_A, batch_ndim_A);
int batch_ndim_B = operand_batch_ndim.y;
const constant int* batch_shape_B = batch_shape_A + batch_ndim_A;
const constant auto* batch_strides_B = batch_strides_A + batch_ndim_A;
B += elem_to_loc(indx_B, batch_shape_B, batch_strides_B, batch_ndim_B);
if (use_out_source) {
int batch_ndim_C = operand_batch_ndim.z;
const constant int* batch_shape_C = batch_shape_B + batch_ndim_B;
const constant auto* batch_strides_C = batch_strides_B + batch_ndim_B;
C += elem_to_loc(indx_C, batch_shape_C, batch_strides_C, batch_ndim_C);
}
}
// Handle regular batch
else {
if (has_batch) { if (has_batch) {
const constant auto* A_bstrides = batch_strides; const constant auto* A_bstrides = batch_strides;
const constant auto* B_bstrides = batch_strides + params->batch_ndim; const constant auto* B_bstrides = batch_strides + params->batch_ndim;
@@ -160,7 +93,6 @@ template <
C += addmm_params->batch_stride_c * tid.z; C += addmm_params->batch_stride_c * tid.z;
} }
} }
}
D += params->batch_stride_d * tid.z; D += params->batch_stride_d * tid.z;

View File

@@ -0,0 +1,459 @@
// Copyright © 2024 Apple Inc.
using namespace mlx::steel;
constant bool has_batch [[function_constant(10)]];
constant bool align_M [[function_constant(200)]];
constant bool align_N [[function_constant(201)]];
constant bool align_K [[function_constant(202)]];
template <
typename T,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_a,
bool transpose_b,
typename AccumType = float>
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void gather_mm_rhs(
const device T* A [[buffer(0)]],
const device T* B [[buffer(1)]],
const device uint32_t* rhs_indices [[buffer(2)]],
device T* C [[buffer(3)]],
const constant GEMMParams* params [[buffer(4)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]]) {
using gemm_kernel = GEMMKernel<
T,
T,
BM,
BN,
BK,
WM,
WN,
transpose_a,
transpose_b,
true,
true,
AccumType>;
using loader_a_t = typename gemm_kernel::loader_a_t;
using loader_b_t = typename gemm_kernel::loader_b_t;
using mma_t = typename gemm_kernel::mma_t;
if (params->tiles_n <= static_cast<int>(tid.x) ||
params->tiles_m <= static_cast<int>(tid.y)) {
return;
}
// Prepare threadgroup memory
threadgroup T As[gemm_kernel::tgp_mem_size_a];
threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
// Find the block in A, B, C
const int c_row = tid.y * BM;
const int c_col = tid.x * BN;
const size_t c_row_long = size_t(c_row);
const size_t c_col_long = size_t(c_col);
// Prepare threadgroup bounds
const short tgp_bm = align_M ? BM : short(min(BM, params->M - c_row));
const short tgp_bn = align_N ? BN : short(min(BN, params->N - c_col));
A += transpose_a ? c_row_long : c_row_long * params->lda;
B += transpose_b ? c_col_long * params->ldb : c_col_long;
C += c_row_long * params->ldd + c_col_long;
// Do as many matmuls as necessary
uint32_t index;
short offset;
uint32_t index_next = rhs_indices[c_row];
short offset_next = 0;
int n = 0;
while (n < tgp_bm) {
n++;
offset = offset_next;
index = index_next;
offset_next = tgp_bm;
for (; n < tgp_bm; n++) {
if (rhs_indices[c_row + n] != index) {
offset_next = n;
index_next = rhs_indices[c_row + n];
break;
}
}
threadgroup_barrier(mem_flags::mem_none);
// Prepare threadgroup mma operation
thread mma_t mma_op(simd_group_id, simd_lane_id);
// Prepare threadgroup loading operations
thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
thread loader_b_t loader_b(
B + index * params->batch_stride_b,
params->ldb,
Bs,
simd_group_id,
simd_lane_id);
// Prepare iterations
const int gemm_k_iterations = params->gemm_k_iterations_aligned;
// Do unaligned K iterations first
if (!align_K) {
const int k_last = params->gemm_k_iterations_aligned * BK;
const int k_remain = params->K - k_last;
const size_t k_jump_a =
transpose_a ? params->lda * size_t(k_last) : size_t(k_last);
const size_t k_jump_b =
transpose_b ? size_t(k_last) : params->ldb * size_t(k_last);
// Move loader source ahead to end
loader_a.src += k_jump_a;
loader_b.src += k_jump_b;
// Load tile
const short2 tile_dims_A =
transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
const short2 tile_dims_B =
transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);
loader_a.load_safe(tile_dims_A);
loader_b.load_safe(tile_dims_B);
threadgroup_barrier(mem_flags::mem_threadgroup);
// Do matmul
mma_op.mma(As, Bs);
// Reset source back to start
loader_a.src -= k_jump_a;
loader_b.src -= k_jump_b;
}
// Whole matrix is aligned; never bounds check
if (align_M && align_N) {
for (int k = 0; k < gemm_k_iterations; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup
loader_a.load_unsafe();
loader_b.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
// Store results to device memory
if (offset_next - offset == BM) {
mma_op.store_result(C, params->ldd);
} else {
mma_op.store_result_slice(
C, params->ldd, short2(0, offset), short2(BN, offset_next));
}
} else {
const short lbk = 0;
// Tile is aligned; no bounds checks needed
if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
lbk,
LoopAlignment<true, true, true>{});
if (offset_next - offset == BM) {
mma_op.store_result(C, params->ldd);
} else {
mma_op.store_result_slice(
C, params->ldd, short2(0, offset), short2(BN, offset_next));
}
}
// Tile partially aligned check rows
else if (align_N || tgp_bn == BN) {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
lbk,
LoopAlignment<false, true, true>{});
mma_op.store_result_slice(
C, params->ldd, short2(0, offset), short2(BN, offset_next));
}
// Tile partially aligned check cols
else if (align_M || tgp_bm == BM) {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
lbk,
LoopAlignment<true, false, true>{});
mma_op.store_result_slice(
C, params->ldd, short2(0, offset), short2(tgp_bn, offset_next));
}
// Nothing aligned so check both rows and cols
else {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
lbk,
LoopAlignment<false, false, true>{});
mma_op.store_result_slice(
C, params->ldd, short2(0, offset), short2(tgp_bn, offset_next));
}
}
}
}
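The `gather_mm_rhs` kernel above handles a reduction dimension that is not a multiple of BK by first consuming the `k_remain = K - gemm_k_iterations_aligned * BK` tail with bounds-checked loads, then looping over the fully aligned iterations. A tiny arithmetic sketch of that split with illustrative values only:
#include <cstdio>
int main() {
  const int K = 70, BK = 16;
  const int aligned_iters = K / BK;      // 4 full BK-wide tiles
  const int k_last = aligned_iters * BK; // 64
  const int k_remain = K - k_last;       // 6 trailing columns loaded with load_safe
  std::printf("aligned iters: %d, tail width: %d\n", aligned_iters, k_remain);
  return 0;
}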
template <
typename T,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_a,
bool transpose_b,
typename AccumType = float>
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void gather_mm(
const device T* A [[buffer(0)]],
const device T* B [[buffer(1)]],
const device uint32_t* lhs_indices [[buffer(2)]],
const device uint32_t* rhs_indices [[buffer(3)]],
device T* C [[buffer(4)]],
const constant GEMMParams* params [[buffer(5)]],
const constant int* indices_shape [[buffer(6)]],
const constant int64_t* lhs_strides [[buffer(7)]],
const constant int64_t* rhs_strides [[buffer(8)]],
const constant int& batch_ndim_a [[buffer(9)]],
const constant int* batch_shape_a [[buffer(10)]],
const constant int64_t* batch_strides_a [[buffer(11)]],
const constant int& batch_ndim_b [[buffer(12)]],
const constant int* batch_shape_b [[buffer(13)]],
const constant int64_t* batch_strides_b [[buffer(14)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]]) {
using gemm_kernel = GEMMKernel<
T,
T,
BM,
BN,
BK,
WM,
WN,
transpose_a,
transpose_b,
true,
true,
AccumType>;
using loader_a_t = typename gemm_kernel::loader_a_t;
using loader_b_t = typename gemm_kernel::loader_b_t;
using mma_t = typename gemm_kernel::mma_t;
if (params->tiles_n <= static_cast<int>(tid.x) ||
params->tiles_m <= static_cast<int>(tid.y)) {
return;
}
// Move A and B to the locations pointed to by lhs_indices and rhs_indices.
uint32_t indx_A, indx_B;
if (has_batch) {
ulong2 indices_offsets = elem_to_loc_broadcast(
tid.z, indices_shape, lhs_strides, rhs_strides, params->batch_ndim);
indx_A = lhs_indices[indices_offsets.x];
indx_B = rhs_indices[indices_offsets.y];
} else {
indx_A = lhs_indices[params->batch_stride_a * tid.z];
indx_B = rhs_indices[params->batch_stride_b * tid.z];
}
A += elem_to_loc(indx_A, batch_shape_a, batch_strides_a, batch_ndim_a);
B += elem_to_loc(indx_B, batch_shape_b, batch_strides_b, batch_ndim_b);
C += params->batch_stride_d * tid.z;
// Prepare threadgroup memory
threadgroup T As[gemm_kernel::tgp_mem_size_a];
threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
// Just make sure everybody's finished with the indexing math above.
threadgroup_barrier(mem_flags::mem_none);
// Find block in A, B, C
const int c_row = tid.y * BM;
const int c_col = tid.x * BN;
const size_t c_row_long = size_t(c_row);
const size_t c_col_long = size_t(c_col);
A += transpose_a ? c_row_long : c_row_long * params->lda;
B += transpose_b ? c_col_long * params->ldb : c_col_long;
C += c_row_long * params->ldd + c_col_long;
// Prepare threadgroup mma operation
thread mma_t mma_op(simd_group_id, simd_lane_id);
// Prepare threadgroup loading operations
thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);
// Prepare threadgroup bounds
const short tgp_bm = align_M ? BM : short(min(BM, params->M - c_row));
const short tgp_bn = align_N ? BN : short(min(BN, params->N - c_col));
// Prepare iterations
int gemm_k_iterations = params->gemm_k_iterations_aligned;
// Do unaligned K iterations first
if (!align_K) {
const int k_last = params->gemm_k_iterations_aligned * BK;
const int k_remain = params->K - k_last;
const size_t k_jump_a =
transpose_a ? params->lda * size_t(k_last) : size_t(k_last);
const size_t k_jump_b =
transpose_b ? size_t(k_last) : params->ldb * size_t(k_last);
// Move loader source ahead to end
loader_a.src += k_jump_a;
loader_b.src += k_jump_b;
// Load tile
const short2 tile_dims_A =
transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
const short2 tile_dims_B =
transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);
loader_a.load_safe(tile_dims_A);
loader_b.load_safe(tile_dims_B);
threadgroup_barrier(mem_flags::mem_threadgroup);
// Do matmul
mma_op.mma(As, Bs);
// Reset source back to start
loader_a.src -= k_jump_a;
loader_b.src -= k_jump_b;
}
// Whole matrix is aligned; never bounds check
if (align_M && align_N) {
for (int k = 0; k < gemm_k_iterations; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup
loader_a.load_unsafe();
loader_b.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
// Store results to device memory
mma_op.store_result(C, params->ldd);
} else {
const short lbk = 0;
// Tile is aligned; no bounds checks needed
if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
lbk,
LoopAlignment<true, true, true>{});
mma_op.store_result(C, params->ldd);
}
// Tile partially aligned check rows
else if (align_N || tgp_bn == BN) {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
lbk,
LoopAlignment<false, true, true>{});
mma_op.store_result_safe(C, params->ldd, short2(tgp_bn, tgp_bm));
}
// Tile partially aligned check cols
else if (align_M || tgp_bm == BM) {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
lbk,
LoopAlignment<true, false, true>{});
mma_op.store_result_safe(C, params->ldd, short2(tgp_bn, tgp_bm));
}
// Nothing aligned so check both rows and cols
else {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
lbk,
LoopAlignment<false, false, true>{});
mma_op.store_result_safe(C, params->ldd, short2(tgp_bn, tgp_bm));
}
}
}

View File

@@ -0,0 +1,59 @@
// Copyright © 2024 Apple Inc.
// clang-format off
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
#include "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h"
#define instantiate_gather_mm_rhs(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_kernel( \
"steel_gather_mm_rhs_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn \
"_bk" #bk "_wm" #wm "_wn" #wn, \
gather_mm_rhs, \
itype, \
bm, \
bn, \
bk, \
wm, \
wn, \
trans_a, \
trans_b, \
float)
#define instantiate_gather_mm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_kernel( \
"steel_gather_mm_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn \
"_bk" #bk "_wm" #wm "_wn" #wn, \
gather_mm, \
itype, \
bm, \
bn, \
bk, \
wm, \
wn, \
trans_a, \
trans_b, \
float)
#define instantiate_gather_mm_rhs_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gather_mm_rhs(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gather_mm_rhs(nt, false, true, iname, itype, oname, otype, bm, bn, bk, wm, wn)
#define instantiate_gather_mm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gather_mm(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gather_mm(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gather_mm(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gather_mm(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn)
#define instantiate_gather_mm_shapes_helper(iname, itype, oname, otype) \
instantiate_gather_mm_rhs_transpose_helper(iname, itype, oname, otype, 16, 64, 16, 1, 2) \
instantiate_gather_mm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) \
instantiate_gather_mm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 1, 2) \
instantiate_gather_mm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \
instantiate_gather_mm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 1, 2) \
instantiate_gather_mm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2)
// clang-format on
instantiate_gather_mm_shapes_helper(float16, half, float16, half);
instantiate_gather_mm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t);
instantiate_gather_mm_shapes_helper(float32, float, float32, float);

View File

@@ -142,6 +142,42 @@ struct BaseMMAFrag<T, 8, 8> {
} }
} }
template <
typename DstPtrType,
typename StrX,
typename StrY,
typename StartX,
typename StopX,
typename StartY,
typename StopY,
typename OffX,
typename OffY>
METAL_FUNC static constexpr void store_slice(
const thread frag_type& src,
DstPtrType dst,
StrX str_x,
StrY str_y,
StartX start_x,
StopX stop_x,
StartY start_y,
StopY stop_y,
OffX off_x = Int<0>{},
OffY off_y = Int<0>{}) {
using U = pointer_element_t<DstPtrType>;
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kElemRows; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < kElemCols; j++) {
if ((off_x + i) < stop_x && (off_x + i) >= start_x &&
(off_y + j) < stop_y && (off_y + j) >= start_y) {
dst[(off_x + i) * str_x + (off_y + j) * str_y] =
static_cast<U>(src[i * kElemCols + j]);
}
}
}
}
METAL_FUNC static constexpr void mma( METAL_FUNC static constexpr void mma(
thread frag_type& D, thread frag_type& D,
thread frag_type& A, thread frag_type& A,
@@ -335,6 +371,31 @@ struct MMATile {
} }
} }
} }
template <typename U, int w_x, int w_y>
METAL_FUNC void store_slice(
device U* dst,
const int ld,
const short2 start,
const short2 stop) const {
STEEL_PRAGMA_UNROLL
for (int i = 0; i < kTileRows; ++i) {
STEEL_PRAGMA_UNROLL
for (int j = 0; j < kTileCols; ++j) {
MMAFrag_t::store_slice(
frag_at(i, j),
dst,
ld,
Int<1>{},
start.y,
stop.y,
start.x,
stop.x,
(i * kFragRows) * w_x,
(j * kFragCols) * w_y);
}
}
}
}; };
template <typename T, typename U, int M, int N, int K> template <typename T, typename U, int M, int N, int K>
@@ -474,6 +535,26 @@ struct BlockMMA {
Ctile.template store<U, WM, WN>(D, ldd); Ctile.template store<U, WM, WN>(D, ldd);
} }
METAL_FUNC void
store_result_slice(device U* D, const int ldd, short2 start, short2 stop) {
// Apply epilogue
STEEL_PRAGMA_UNROLL
for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) {
Ctile.elems()[i] = Epilogue::apply(Ctile.elems()[i]);
}
D += sm * ldd + sn;
start -= short2(sn, sm);
stop -= short2(sn, sm);
// TODO: Check the start as well
if (stop.y <= 0 || stop.x <= 0) {
return;
}
Ctile.template store_slice<U, WM, WN>(D, ldd, start, stop);
}
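The `store_slice`/`store_result_slice` path added above writes only the accumulator elements whose destination row lies in `[start.y, stop.y)` and column in `[start.x, stop.x)`, which is what lets the gather kernels emit just the rows belonging to the current index run. A rough standalone C++ sketch of that bounded store (illustrative only; it ignores the simdgroup fragment layout):
#include <cstdio>
#include <vector>
// Write tile[i][j] into dst only where (row, col) lies inside the half-open bounds.
void store_slice(
    std::vector<float>& dst, int ldd,
    const std::vector<std::vector<float>>& tile,
    int start_row, int stop_row, int start_col, int stop_col) {
  for (int i = 0; i < (int)tile.size(); ++i) {
    for (int j = 0; j < (int)tile[i].size(); ++j) {
      if (i >= start_row && i < stop_row && j >= start_col && j < stop_col) {
        dst[i * ldd + j] = tile[i][j];
      }
    }
  }
}
int main() {
  std::vector<float> out(4 * 4, 0.0f);
  std::vector<std::vector<float>> tile(4, std::vector<float>(4, 1.0f));
  store_slice(out, 4, tile, 1, 3, 0, 4); // keep only rows 1 and 2
  for (int i = 0; i < 4; ++i) {
    for (int j = 0; j < 4; ++j) std::printf("%.0f ", out[i * 4 + j]);
    std::printf("\n");
  }
  return 0;
}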
METAL_FUNC void METAL_FUNC void
store_result_safe(device U* D, const int ldd, short2 dst_tile_dims) { store_result_safe(device U* D, const int ldd, short2 dst_tile_dims) {
// Apply epilogue // Apply epilogue

View File

@@ -73,6 +73,9 @@ instantiate_unary_all_same(Conjugate, complex64, complex64_t)
instantiate_unary_all_same(Cos, complex64, complex64_t) instantiate_unary_all_same(Cos, complex64, complex64_t)
instantiate_unary_all_same(Cosh, complex64, complex64_t) instantiate_unary_all_same(Cosh, complex64, complex64_t)
instantiate_unary_all_same(Exp, complex64, complex64_t) instantiate_unary_all_same(Exp, complex64, complex64_t)
instantiate_unary_all_same(Log, complex64, complex64_t)
instantiate_unary_all_same(Log2, complex64, complex64_t)
instantiate_unary_all_same(Log10, complex64, complex64_t)
instantiate_unary_all_same(Negative, complex64, complex64_t) instantiate_unary_all_same(Negative, complex64, complex64_t)
instantiate_unary_all_same(Sign, complex64, complex64_t) instantiate_unary_all_same(Sign, complex64, complex64_t)
instantiate_unary_all_same(Sin, complex64, complex64_t) instantiate_unary_all_same(Sin, complex64, complex64_t)

View File

@@ -257,6 +257,13 @@ struct Log {
T operator()(T x) { T operator()(T x) {
return metal::precise::log(x); return metal::precise::log(x);
}; };
template <>
complex64_t operator()(complex64_t x) {
auto r = metal::precise::log(Abs{}(x).real);
auto i = metal::precise::atan2(x.imag, x.real);
return {r, i};
};
}; };
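The complex specialization above computes the principal branch log z = log|z| + i·atan2(Im z, Re z); the Log2 and Log10 variants below then divide both components by ln 2 and ln 10. A quick host-side check of the same formula with std::complex (illustrative only):
#include <cmath>
#include <complex>
#include <cstdio>
int main() {
  std::complex<float> z(3.0f, 4.0f);
  // Principal-branch complex log: real part log|z|, imaginary part arg(z).
  float re = std::log(std::abs(z));          // log(5) ~= 1.6094
  float im = std::atan2(z.imag(), z.real()); // ~= 0.9273
  std::printf("manual: %f %f\n", re, im);
  auto ref = std::log(z);                    // library result for comparison
  std::printf("std:    %f %f\n", ref.real(), ref.imag());
  return 0;
}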
struct Log2 { struct Log2 {
@@ -264,6 +271,12 @@ struct Log2 {
T operator()(T x) { T operator()(T x) {
return metal::precise::log2(x); return metal::precise::log2(x);
}; };
template <>
complex64_t operator()(complex64_t x) {
auto y = Log{}(x);
return {y.real / M_LN2_F, y.imag / M_LN2_F};
};
}; };
struct Log10 { struct Log10 {
@@ -271,6 +284,12 @@ struct Log10 {
T operator()(T x) { T operator()(T x) {
return metal::precise::log10(x); return metal::precise::log10(x);
}; };
template <>
complex64_t operator()(complex64_t x) {
auto y = Log{}(x);
return {y.real / M_LN10_F, y.imag / M_LN10_F};
};
}; };
struct Log1p { struct Log1p {

View File

@@ -0,0 +1,96 @@
// Copyright © 2023-2024 Apple Inc.
#include <algorithm>
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"
namespace mlx::core {
constexpr int LOGSUMEXP_LOOPED_LIMIT = 4096;
void LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
if (!issubdtype(out.dtype(), floating)) {
throw std::runtime_error(
"[logsumexp] Does not support non-floating point types.");
}
auto& s = stream();
auto& d = metal::device(s.device);
// Make sure that the last dimension is contiguous
auto ensure_contiguous = [&s, &d](const array& x) {
if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
return x;
} else {
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
d.add_temporary(x_copy, s.index);
return x_copy;
}
};
auto in = ensure_contiguous(inputs[0]);
if (in.flags().row_contiguous) {
out.set_data(allocator::malloc(out.nbytes()));
} else {
auto n = in.shape(-1);
auto flags = in.flags();
auto strides = in.strides();
for (auto& s : strides) {
s /= n;
}
bool col_contig = strides[0] == 1;
for (int i = 1; col_contig && i < strides.size(); ++i) {
col_contig &=
(out.shape(i) == 1 || strides[i - 1] == out.shape(i) * strides[i]);
}
flags.col_contiguous = col_contig;
out.set_data(
allocator::malloc(in.nbytes() / n),
in.data_size() / n,
std::move(strides),
flags);
}
int axis_size = in.shape().back();
int n_rows = in.data_size() / axis_size;
const int simd_size = 32;
const int n_reads = 4;
const int looped_limit = LOGSUMEXP_LOOPED_LIMIT;
std::string kernel_name = (axis_size > looped_limit) ? "looped_" : "block_";
kernel_name += "logsumexp_";
kernel_name += type_to_name(out);
auto kernel = get_logsumexp_kernel(d, kernel_name, out);
auto& compute_encoder = d.get_command_encoder(s.index);
{
MTL::Size grid_dims, group_dims;
if (axis_size <= looped_limit) {
size_t threadgroup_needed = (axis_size + n_reads - 1) / n_reads;
size_t simds_needed = (threadgroup_needed + simd_size - 1) / simd_size;
size_t threadgroup_size = simd_size * simds_needed;
assert(threadgroup_size <= kernel->maxTotalThreadsPerThreadgroup());
size_t n_threads = n_rows * threadgroup_size;
grid_dims = MTL::Size(n_threads, 1, 1);
group_dims = MTL::Size(threadgroup_size, 1, 1);
} else {
size_t threadgroup_size = kernel->maxTotalThreadsPerThreadgroup();
size_t n_threads = n_rows * threadgroup_size;
grid_dims = MTL::Size(n_threads, 1, 1);
group_dims = MTL::Size(threadgroup_size, 1, 1);
}
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(out, 1);
compute_encoder.set_bytes(axis_size, 2);
compute_encoder.dispatch_threads(grid_dims, group_dims);
}
}
} // namespace mlx::core
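The dispatch logic above selects the block kernel for rows up to LOGSUMEXP_LOOPED_LIMIT elements and the looped kernel beyond that; in the block case the threadgroup size is the number of SIMD groups needed to cover `axis_size / n_reads` lanes, rounded up to a multiple of the SIMD width. A small sketch of that sizing arithmetic, mirroring the expressions above with an example row length:
#include <cstdio>
int main() {
  const int simd_size = 32, n_reads = 4;
  const int axis_size = 1000; // example row length below the looped limit
  int threads_needed = (axis_size + n_reads - 1) / n_reads;        // 250
  int simds_needed = (threads_needed + simd_size - 1) / simd_size; // 8
  int threadgroup_size = simd_size * simds_needed;                 // 256
  std::printf("threadgroup size: %d\n", threadgroup_size);
  return 0;
}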

View File

@@ -5,6 +5,7 @@
#include <numeric> #include <numeric>
#include <sstream> #include <sstream>
#include "mlx/backend/common/broadcasting.h"
#include "mlx/backend/common/utils.h" #include "mlx/backend/common/utils.h"
#include "mlx/backend/metal/copy.h" #include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h" #include "mlx/backend/metal/device.h"
@@ -102,6 +103,47 @@ std::tuple<bool, int64_t, array> check_transpose(
} }
}; };
inline array
ensure_row_contiguous(const array& x, metal::Device& d, const Stream& s) {
if (!x.flags().row_contiguous) {
array x_copy(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
d.add_temporary(x_copy, s.index);
return x_copy;
} else {
return x;
}
}
inline std::tuple<bool, int64_t, array>
ensure_batch_contiguous(const array& x, metal::Device& d, const Stream& s) {
if (x.flags().row_contiguous) {
return std::make_tuple(false, x.strides()[x.ndim() - 2], x);
}
bool rc = true;
for (int i = 0; i < x.ndim() - 3; i++) {
rc &= x.strides()[i + 1] * x.shape(i) == x.strides()[i];
}
if (rc) {
auto stx = x.strides()[x.ndim() - 2];
auto sty = x.strides()[x.ndim() - 1];
auto K = x.shape(-2);
auto N = x.shape(-1);
if (sty == 1 && (N != 1 || stx == N)) {
return std::make_tuple(false, stx, x);
}
if (stx == 1 && (N != 1 || sty == K)) {
return std::make_tuple(true, sty, x);
}
}
array x_copy(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
d.add_temporary(x_copy, s.index);
return std::make_tuple(false, x_copy.strides()[x_copy.ndim() - 2], x_copy);
}
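The `ensure_batch_contiguous` helper above accepts a matrix whose last two dimensions are laid out either row- or column-contiguously (and whose batch dimensions collapse into a single stride), returning (transposed, leading dim, array); anything else gets copied. A rough C++ sketch of just the last-two-axes layout check, assuming strides are expressed in elements and omitting the batch-collapse loop for brevity:
#include <cstdint>
#include <cstdio>
#include <vector>
struct Meta {
  std::vector<int> shape;
  std::vector<int64_t> strides;
};
// Returns true (and sets transposed / leading dim) when the last two axes are
// contiguous in one of the two supported layouts; otherwise a copy is needed.
bool batch_contiguous(const Meta& x, bool& transposed, int64_t& ld) {
  int n = (int)x.shape.size();
  int64_t stx = x.strides[n - 2], sty = x.strides[n - 1];
  int K = x.shape[n - 2], N = x.shape[n - 1];
  if (sty == 1 && (N != 1 || stx == N)) { transposed = false; ld = stx; return true; }
  if (stx == 1 && (N != 1 || sty == K)) { transposed = true;  ld = sty; return true; }
  return false;
}
int main() {
  Meta m{{8, 16, 32}, {16 * 32, 32, 1}}; // plain row-major batch of matrices
  bool t = false;
  int64_t ld = 0;
  std::printf("contiguous: %d\n", batch_contiguous(m, t, ld) ? 1 : 0);
  return 0;
}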
} // namespace } // namespace
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
@@ -230,7 +272,6 @@ void steel_matmul_regular(
const bool align_M = (M % bm) == 0; const bool align_M = (M % bm) == 0;
const bool align_N = (N % bn) == 0; const bool align_N = (N % bn) == 0;
const bool align_K = (K % bk) == 0; const bool align_K = (K % bk) == 0;
const bool do_gather = false;
metal::MTLFCList func_consts = { metal::MTLFCList func_consts = {
{&has_batch, MTL::DataType::DataTypeBool, 10}, {&has_batch, MTL::DataType::DataTypeBool, 10},
@@ -239,7 +280,6 @@ void steel_matmul_regular(
{&align_M, MTL::DataType::DataTypeBool, 200}, {&align_M, MTL::DataType::DataTypeBool, 200},
{&align_N, MTL::DataType::DataTypeBool, 201}, {&align_N, MTL::DataType::DataTypeBool, 201},
{&align_K, MTL::DataType::DataTypeBool, 202}, {&align_K, MTL::DataType::DataTypeBool, 202},
{&do_gather, MTL::DataType::DataTypeBool, 300},
}; };
// clang-format off // clang-format off
@@ -248,8 +288,7 @@ void steel_matmul_regular(
<< "_do_axpby_" << (do_axpby ? 't' : 'n') << "_do_axpby_" << (do_axpby ? 't' : 'n')
<< "_align_M_" << (align_M ? 't' : 'n') << "_align_M_" << (align_M ? 't' : 'n')
<< "_align_N_" << (align_N ? 't' : 'n') << "_align_N_" << (align_N ? 't' : 'n')
<< "_align_K_" << (align_K ? 't' : 'n') << "_align_K_" << (align_K ? 't' : 'n'); // clang-format on
<< "_do_gather_" << (do_gather ? 't' : 'n'); // clang-format on
std::string hash_name = kname.str(); std::string hash_name = kname.str();
@@ -975,7 +1014,6 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
const bool align_M = (M % bm) == 0; const bool align_M = (M % bm) == 0;
const bool align_N = (N % bn) == 0; const bool align_N = (N % bn) == 0;
const bool align_K = (K % bk) == 0; const bool align_K = (K % bk) == 0;
const bool do_gather = false;
metal::MTLFCList func_consts = { metal::MTLFCList func_consts = {
{&has_batch, MTL::DataType::DataTypeBool, 10}, {&has_batch, MTL::DataType::DataTypeBool, 10},
@@ -984,7 +1022,6 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
{&align_M, MTL::DataType::DataTypeBool, 200}, {&align_M, MTL::DataType::DataTypeBool, 200},
{&align_N, MTL::DataType::DataTypeBool, 201}, {&align_N, MTL::DataType::DataTypeBool, 201},
{&align_K, MTL::DataType::DataTypeBool, 202}, {&align_K, MTL::DataType::DataTypeBool, 202},
{&do_gather, MTL::DataType::DataTypeBool, 300},
}; };
// clang-format off // clang-format off
@@ -993,8 +1030,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
<< "_do_axpby_" << (do_axpby ? 't' : 'n') << "_do_axpby_" << (do_axpby ? 't' : 'n')
<< "_align_M_" << (align_M ? 't' : 'n') << "_align_M_" << (align_M ? 't' : 'n')
<< "_align_N_" << (align_N ? 't' : 'n') << "_align_N_" << (align_N ? 't' : 'n')
<< "_align_K_" << (align_K ? 't' : 'n') << "_align_K_" << (align_K ? 't' : 'n'); // clang-format on
<< "_do_gather_" << (do_gather ? 't' : 'n'); // clang-format on
std::string hash_name = kname.str(); std::string hash_name = kname.str();
@@ -1464,131 +1500,176 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
d.add_temporaries(std::move(copies), s.index); d.add_temporaries(std::move(copies), s.index);
} }
void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) { void gather_mm_rhs(
using namespace mlx::steel; const array& a_,
// assert(inputs.size() == 2); const array& b_,
if (!issubdtype(out.dtype(), floating)) { const array& indices_,
throw std::runtime_error( array& out,
"[GatherMM] Does not yet support non-floating point types."); metal::Device& d,
} const Stream& s) {
auto& s = stream(); array indices = ensure_row_contiguous(indices_, d, s);
auto& d = metal::device(s.device); auto [transpose_b, ldb, b] = ensure_batch_contiguous(b_, d, s);
auto& a_pre = inputs[0]; // Broadcast a with indices. If we are here that means lhs_indices were not
auto& b_pre = inputs[1]; // provided so the lhs_indices are implied to be the shape of a broadcasted
// Return 0s if either input is empty // with rhs_indices. We need only broadcast a and copy it as if applying the
if (a_pre.size() == 0 || b_pre.size() == 0) { // lhs_indices.
array zero = array(0, a_pre.dtype()); auto broadcast_with_indices = [&d, &s, &indices](const array& x) {
fill_gpu(zero, out, s); if (x.size() / x.shape(-2) / x.shape(-1) == indices.size()) {
d.add_temporary(std::move(zero), s.index); return ensure_row_contiguous(x, d, s);
return;
} }
out.set_data(allocator::malloc(out.nbytes())); auto x_shape = indices.shape();
x_shape.push_back(x.shape(-2));
x_shape.push_back(x.shape(-1));
array new_x(std::move(x_shape), x.dtype(), nullptr, {});
broadcast(x, new_x);
return ensure_row_contiguous(new_x, d, s);
};
array a = broadcast_with_indices(a_);
///////////////////////////////////////////////////////////////////////////// // Extract the matmul shapes
// Init checks and prep int K = a.shape(-1);
int M = a.size() / K;
int N = b.shape(-1);
int lda = a.strides()[a.ndim() - 2]; // should be K
int M = a_pre.shape(-2); // Define the dispatch blocks
int N = b_pre.shape(-1); int bm = 16, bn = 64, bk = 16;
int K = a_pre.shape(-1); int wm = 1, wn = 2;
// Keep a vector with copies to be cleared in the completed buffer to release const bool align_M = (M % bm) == 0;
// the arrays const bool align_N = (N % bn) == 0;
std::vector<array> copies; const bool align_K = (K % bk) == 0;
auto [transpose_a, a_cols, a] = check_transpose(copies, s, a_pre, M == 1);
auto [transpose_b, b_cols, b] = check_transpose(copies, s, b_pre, N == 1);
int lda = a_cols; // Define the kernel name
int ldb = b_cols; std::string base_name;
base_name.reserve(64);
concatenate(
base_name,
"steel_gather_mm_rhs_n",
transpose_b ? 't' : 'n',
'_',
type_to_name(a),
'_',
type_to_name(out),
"_bm",
bm,
"_bn",
bn,
"_bk",
bk,
"_wm",
wm,
"_wn",
wn);
///////////////////////////////////////////////////////////////////////////// metal::MTLFCList func_consts = {
// Check and collapse batch dimensions {&align_M, MTL::DataType::DataTypeBool, 200},
{&align_N, MTL::DataType::DataTypeBool, 201},
auto get_batch_dims = [](const auto& v) { {&align_K, MTL::DataType::DataTypeBool, 202},
return decltype(v){v.begin(), v.end() - 2};
}; };
auto& lhs_indices = inputs[2]; // And the kernel hash that includes the function constants
auto& rhs_indices = inputs[3]; std::string hash_name;
hash_name.reserve(128);
concatenate(
hash_name,
base_name,
"_align_M_",
align_M ? 't' : 'n',
"_align_N_",
align_N ? 't' : 'n',
"_align_K_",
align_K ? 't' : 'n');
Shape batch_shape = get_batch_dims(out.shape()); // Get and set the kernel
Strides batch_strides; auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = get_steel_gemm_gather_kernel(
d,
base_name,
hash_name,
func_consts,
out,
false,
transpose_b,
bm,
bn,
bk,
wm,
wn,
true);
compute_encoder.set_compute_pipeline_state(kernel);
batch_strides.insert( // Prepare the matmul params
batch_strides.end(), auto batch_stride_b = b.ndim() > 2 ? b.strides()[b.ndim() - 3] : b.size();
lhs_indices.strides().begin(), steel::GEMMParams params{
lhs_indices.strides().end()); /* const int M = */ M,
auto lhs_indices_str = batch_strides.empty() ? 0 : batch_strides.back(); /* const int N = */ N,
/* const int K = */ K,
/* const int lda = */ lda,
/* const int ldb = */ static_cast<int>(ldb),
/* const int ldd = */ N,
/* const int tiles_n = */ (N + bn - 1) / bn,
/* const int tiles_m = */ (M + bm - 1) / bm,
/* const int64_t batch_stride_a = */ 0,
/* const int64_t batch_stride_b = */ static_cast<int64_t>(batch_stride_b),
/* const int64_t batch_stride_d = */ 0,
/* const int swizzle_log = */ 0,
/* const int gemm_k_iterations_aligned = */ (K / bk),
/* const int batch_ndim = */ 0};
batch_strides.insert( // Prepare the grid
batch_strides.end(), MTL::Size group_dims = MTL::Size(32, wn, wm);
rhs_indices.strides().begin(), MTL::Size grid_dims = MTL::Size(params.tiles_n, params.tiles_m, 1);
rhs_indices.strides().end());
auto rhs_indices_str = batch_strides.empty() ? 0 : batch_strides.back();
int batch_ndim = batch_shape.size(); // Launch kernel
compute_encoder.set_input_array(a, 0);
compute_encoder.set_input_array(b, 1);
compute_encoder.set_input_array(indices, 2);
compute_encoder.set_output_array(out, 3);
compute_encoder.set_bytes(params, 4);
if (batch_ndim == 0) { compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
batch_shape = {1}; }
batch_strides = {0};
}
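In plain terms, gather_mm_rhs multiplies every row of a by the b matrix selected by the corresponding entry of the index array; the sorted-indices requirement is what lets the kernel batch up reads of a and b. A minimal CPU sketch of those semantics (illustrative only, assuming row-major float buffers; none of these names are part of MLX):

#include <cstddef>
#include <cstdint>
#include <vector>

void gather_mm_rhs_reference(
    const std::vector<float>& a,       // [M, K]
    const std::vector<float>& b,       // [E, K, N], E experts
    const std::vector<uint32_t>& idx,  // [M], idx[i] picks a matrix from b
    std::vector<float>& out,           // [M, N]
    int M,
    int N,
    int K) {
  for (int i = 0; i < M; ++i) {
    const float* a_row = a.data() + static_cast<size_t>(i) * K;
    const float* b_mat = b.data() + static_cast<size_t>(idx[i]) * K * N;
    float* o_row = out.data() + static_cast<size_t>(i) * N;
    for (int n = 0; n < N; ++n) {
      float acc = 0.0f;
      for (int k = 0; k < K; ++k) {
        acc += a_row[k] * b_mat[k * N + n];
      }
      o_row[n] = acc;
    }
  }
}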
void gather_mv(
    const array& mat_,
    const array& vec_,
    const array& mat_indices_,
    const array& vec_indices_,
    array& out,
    int N,
    int K,
    bool is_mv,
    metal::Device& d,
    const Stream& s) {
  // Copy if needed
  std::vector<array> copies;
  auto [transpose_mat, mat_cols, mat] =
      check_transpose(copies, s, mat_, N == 1);
  auto [transpose_vec, vec_cols, vec] = check_transpose(copies, s, vec_, true);
  d.add_temporaries(std::move(copies), s.index);

  // If we are doing vector matrix instead of matrix vector we need to flip the
  // matrix transposition. Basically m @ v = v @ m.T assuming that v is treated
  // as a one dimensional array.
  transpose_mat = (!is_mv) ^ transpose_mat;

  // Define some shapes
  int in_vector_len = K;
  int out_vector_len = N;
  int mat_ld = mat_cols;

  int batch_size_out = out.size() / N;
  int batch_ndim = out.ndim() - 2;
  int batch_ndim_mat = mat.ndim() - 2;
  int batch_ndim_vec = vec.ndim() - 2;
  Strides index_strides = vec_indices_.strides();
  index_strides.insert(
      index_strides.end(),
      mat_indices_.strides().begin(),
      mat_indices_.strides().end());

  // Determine dispatch kernel
  int tm = 4, tn = 4;

@@ -1652,79 +1733,104 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {

  compute_encoder.set_bytes(mat_ld, 6);
  compute_encoder.set_bytes(batch_ndim, 9);
  compute_encoder.set_vector_bytes(out.shape(), 10);
  compute_encoder.set_vector_bytes(index_strides, 11);
  compute_encoder.set_bytes(batch_ndim_vec, 12);
  compute_encoder.set_vector_bytes(vec.shape(), 13);
  compute_encoder.set_vector_bytes(vec.strides(), 14);
  compute_encoder.set_bytes(batch_ndim_mat, 15);
  compute_encoder.set_vector_bytes(mat.shape(), 16);
  compute_encoder.set_vector_bytes(mat.strides(), 17);
  compute_encoder.set_input_array(vec_indices_, 18);
  compute_encoder.set_input_array(mat_indices_, 19);
  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
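The transposition flip above relies on the identity m @ v == v @ m.T when v is treated as one-dimensional, so a single matrix-vector routine can serve both the M == 1 and N == 1 cases. A small sketch of that identity (illustrative only, not MLX code):

#include <vector>

std::vector<float> matvec(
    const std::vector<float>& m, // [R, C], row-major
    const std::vector<float>& v,
    int R,
    int C,
    bool transpose) {
  // transpose == false: y = m @ v with v of length C, y of length R
  // transpose == true : y = m.T @ v with v of length R, i.e. v @ m
  int out_len = transpose ? C : R;
  std::vector<float> y(out_len, 0.0f);
  for (int r = 0; r < R; ++r) {
    for (int c = 0; c < C; ++c) {
      if (transpose) {
        y[c] += m[r * C + c] * v[r];
      } else {
        y[r] += m[r * C + c] * v[c];
      }
    }
  }
  return y;
}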
void gather_mm(
    const array& a_,
    const array& b_,
    const array& lhs_indices,
    const array& rhs_indices,
    array& out,
    int M,
    int N,
    int K,
    metal::Device& d,
    const Stream& s) {
  // Copy if needed
  std::vector<array> copies;
  auto [transpose_a, lda, a] = check_transpose(copies, s, a_, false);
  auto [transpose_b, ldb, b] = check_transpose(copies, s, b_, false);
  d.add_temporaries(std::move(copies), s.index);

  // Determine dispatch kernel
  int bm = 64, bn = 64, bk = 16;
  int wm = 2, wn = 2;
  size_t batch_size_out = out.size() / M / N;
  int batch_ndim = out.ndim() - 2;
  int batch_ndim_a = a.ndim() - 2;
  int batch_ndim_b = b.ndim() - 2;

  char devc = d.get_architecture().back();
  GEMM_TPARAM_MACRO(devc)

  const bool has_batch = batch_ndim > 1;
  const bool align_M = (M % bm) == 0;
  const bool align_N = (N % bn) == 0;
  const bool align_K = (K % bk) == 0;

  // Define the kernel name
  std::string base_name;
  base_name.reserve(128);
  concatenate(
      base_name,
      "steel_gather_mm_",
      transpose_a ? 't' : 'n',
      transpose_b ? 't' : 'n',
      "_",
      type_to_name(a),
      "_",
      type_to_name(out),
      "_bm",
      bm,
      "_bn",
      bn,
      "_bk",
      bk,
      "_wm",
      wm,
      "_wn",
      wn);

  metal::MTLFCList func_consts = {
      {&has_batch, MTL::DataType::DataTypeBool, 10},
      {&align_M, MTL::DataType::DataTypeBool, 200},
      {&align_N, MTL::DataType::DataTypeBool, 201},
      {&align_K, MTL::DataType::DataTypeBool, 202},
  };

  // And the kernel hash that includes the function constants
  std::string hash_name;
  hash_name.reserve(128);
  concatenate(
      hash_name,
      base_name,
      "_has_batch_",
      has_batch ? 't' : 'n',
      "_align_M_",
      align_M ? 't' : 'n',
      "_align_N_",
      align_N ? 't' : 'n',
      "_align_K_",
      align_K ? 't' : 'n');

  // Get and set the kernel
  auto& compute_encoder = d.get_command_encoder(s.index);
  auto kernel = get_steel_gemm_gather_kernel(
      d,
      base_name,
      hash_name,

@@ -1736,72 +1842,96 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {

      bn,
      bk,
      wm,
      wn,
      false);
  compute_encoder.set_compute_pipeline_state(kernel);

  // Prepare the matmul params
  steel::GEMMParams params{
      /* const int M = */ M,
      /* const int N = */ N,
      /* const int K = */ K,
      /* const int lda = */ static_cast<int>(lda),
      /* const int ldb = */ static_cast<int>(ldb),
      /* const int ldd = */ N,
      /* const int tiles_n = */ (N + bn - 1) / bn,
      /* const int tiles_m = */ (M + bm - 1) / bm,
      /* const int64_t batch_stride_a = */
      (batch_ndim > 0) ? lhs_indices.strides()[0] : 0,
      /* const int64_t batch_stride_b = */
      (batch_ndim > 0) ? rhs_indices.strides()[0] : 0,
      /* const int64_t batch_stride_d = */ M * N,
      /* const int swizzle_log = */ 0,
      /* const int gemm_k_iterations_aligned = */ (K / bk),
      /* const int batch_ndim = */ batch_ndim};

  // Prepare the grid
  MTL::Size group_dims = MTL::Size(32, wn, wm);
  MTL::Size grid_dims =
      MTL::Size(params.tiles_n, params.tiles_m, batch_size_out);

  // Launch kernel
  compute_encoder.set_input_array(a, 0);
  compute_encoder.set_input_array(b, 1);
  compute_encoder.set_input_array(lhs_indices, 2);
  compute_encoder.set_input_array(rhs_indices, 3);
  compute_encoder.set_output_array(out, 4);
  compute_encoder.set_bytes(params, 5);
  compute_encoder.set_vector_bytes(lhs_indices.shape(), 6);
  compute_encoder.set_vector_bytes(lhs_indices.strides(), 7);
  compute_encoder.set_vector_bytes(rhs_indices.strides(), 8);
  compute_encoder.set_bytes(batch_ndim_a, 9);
  compute_encoder.set_vector_bytes(a.shape(), 10);
  compute_encoder.set_vector_bytes(a.strides(), 11);
  compute_encoder.set_bytes(batch_ndim_b, 12);
  compute_encoder.set_vector_bytes(b.shape(), 13);
  compute_encoder.set_vector_bytes(b.strides(), 14);
  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
  auto& s = stream();
  auto& d = metal::device(s.device);

  auto& a = inputs[0];
  auto& b = inputs[1];
  auto& lhs_indices = inputs[2];
  auto& rhs_indices = inputs[3];

  // Return 0s if either input is empty
  if (a.size() == 0 || b.size() == 0) {
    array zero = array(0, a.dtype());
    fill_gpu(zero, out, s);
    d.add_temporary(std::move(zero), s.index);
    return;
  }

  out.set_data(allocator::malloc(out.nbytes()));

  // Extract shapes from inputs.
  int M = a.shape(-2);
  int N = b.shape(-1);
  int K = a.shape(-1);

  // We are walking a in order and b is also in order so we can batch up the
  // matmuls and reuse reading a and b.
  if (M == 1 && right_sorted_ == true) {
    gather_mm_rhs(a, b, rhs_indices, out, d, s);
    return;
  }

  // Route to gather gemv if any of a or b are vectors
  if (M == 1) {
    gather_mv(b, a, rhs_indices, lhs_indices, out, N, K, false, d, s);
    return;
  }
  if (N == 1) {
    gather_mv(a, b, lhs_indices, rhs_indices, out, M, K, true, d, s);
    return;
  }

  // Route to non specialized gather mm
  gather_mm(a, b, lhs_indices, rhs_indices, out, M, N, K, d, s);
}

} // namespace mlx::core


@@ -12,74 +12,6 @@ namespace mlx::core::metal {
/* Check if the Metal backend is available. */
bool is_available();
/* Get the actively used memory in bytes.
*
* Note, this will not always match memory use reported by the system because
* it does not include cached memory buffers.
* */
size_t get_active_memory();
/* Get the peak amount of used memory in bytes.
*
* The maximum memory used recorded from the beginning of the program
* execution or since the last call to reset_peak_memory.
* */
size_t get_peak_memory();
/* Reset the peak memory to zero.
* */
void reset_peak_memory();
/* Get the cache size in bytes.
*
* The cache includes memory not currently used that has not been returned
* to the system allocator.
* */
size_t get_cache_memory();
/* Set the memory limit.
* The memory limit is a guideline for the maximum amount of memory to use
* during graph evaluation. If the memory limit is exceeded and there is no
* more RAM (including swap when available) allocations will result in an
* exception.
*
* When metal is available the memory limit defaults to 1.5 times the maximum
* recommended working set size reported by the device.
*
* Returns the previous memory limit.
* */
size_t set_memory_limit(size_t limit);
/* Get the current memory limit. */
size_t get_memory_limit();
/* Set the free cache limit.
* If using more than the given limit, free memory will be reclaimed
* from the cache on the next allocation. To disable the cache,
* set the limit to 0.
*
* The cache limit defaults to the memory limit.
*
* Returns the previous cache limit.
* */
size_t set_cache_limit(size_t limit);
/* Clear the memory cache. */
void clear_cache();
/* Set the wired size limit.
*
* Note, this function is only useful for macOS 15.0 or higher.
*
* The wired limit is the total size in bytes of memory that will be kept
* resident. The default value is ``0``.
*
* Setting a wired limit larger than system wired limit is an error.
*
* Returns the previous wired limit.
* */
size_t set_wired_limit(size_t limit);
/** Capture a GPU trace, saving it to an absolute file `path` */
void start_capture(std::string path = "");
void stop_capture();


@@ -72,6 +72,13 @@ MTL::ComputePipelineState* get_softmax_kernel(
  return d.get_kernel(kernel_name);
}
MTL::ComputePipelineState* get_logsumexp_kernel(
metal::Device& d,
const std::string& kernel_name,
const array&) {
return d.get_kernel(kernel_name);
}
MTL::ComputePipelineState* get_scan_kernel(
    metal::Device& d,
    const std::string& kernel_name,
@@ -186,6 +193,23 @@ MTL::ComputePipelineState* get_steel_gemm_masked_kernel(
  return d.get_kernel(kernel_name);
}
MTL::ComputePipelineState* get_steel_gemm_gather_kernel(
metal::Device& d,
const std::string& kernel_name,
const std::string& hash_name,
const metal::MTLFCList& func_consts,
const array&,
bool,
bool,
int,
int,
int,
int,
int,
bool) {
return d.get_kernel(kernel_name, "mlx", hash_name, func_consts);
}
MTL::ComputePipelineState* get_gemv_masked_kernel(
    metal::Device& d,
    const std::string& kernel_name,
@@ -245,4 +269,21 @@ MTL::ComputePipelineState* get_quantized_kernel(
  return d.get_kernel(kernel_name);
}
MTL::ComputePipelineState* get_gather_qmm_kernel(
metal::Device& d,
const std::string& kernel_name,
const std::string& hash_name,
const metal::MTLFCList& func_consts,
const array&,
int,
int,
int,
int,
int,
int,
int,
bool) {
return d.get_kernel(kernel_name, "mlx", hash_name, func_consts);
}
} // namespace mlx::core


@@ -251,8 +251,10 @@ void Concatenate::eval_gpu(const std::vector<array>& inputs, array& out) {
void Contiguous::eval_gpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  auto& in = inputs[0];
  constexpr size_t extra_bytes = 16384;
  if (in.buffer_size() <= out.nbytes() + extra_bytes &&
      (in.flags().row_contiguous ||
       (allow_col_major_ && in.flags().col_contiguous))) {
    out.copy_shared_buffer(in);
  } else {
    copy_gpu(in, out, CopyType::General);

File diff suppressed because it is too large


@@ -22,6 +22,7 @@ ResidencySet::ResidencySet(MTL::Device* d) {
      }
      throw std::runtime_error(msg.str());
    }
    wired_set_->requestResidency();
  }
}
@@ -32,7 +33,6 @@ void ResidencySet::insert(MTL::Allocation* buf) {
  if (wired_set_->allocatedSize() + buf->allocatedSize() <= capacity_) {
    wired_set_->addAllocation(buf);
    wired_set_->commit();
wired_set_->requestResidency();
  } else {
    unwired_set_.insert(buf);
  }
@@ -76,7 +76,6 @@ void ResidencySet::resize(size_t size) {
      }
    }
    wired_set_->commit();
wired_set_->requestResidency();
  } else if (current_size > size) {
    auto pool = new_scoped_memory_pool();
    // Remove wired allocations until under capacity


@@ -138,6 +138,7 @@ void sdpa_vector(
    const array& v,
    array& out,
    float scale,
    bool do_causal,
    const std::optional<array>& mask) {
  // Set the kernel name
  std::string kname;
@@ -162,14 +163,20 @@ void sdpa_vector(
  MTL::Size grid_dims(B, q.shape(2), 1);

  bool has_mask = mask.has_value();
  bool bool_mask = has_mask && (*mask).dtype() == bool_;
  bool float_mask = has_mask && !bool_mask;
  bool query_transposed = !q.flags().row_contiguous;
  metal::MTLFCList func_consts = {
      {&has_mask, MTL::DataType::DataTypeBool, 20},
      {&query_transposed, MTL::DataType::DataTypeBool, 21},
      {&do_causal, MTL::DataType::DataTypeBool, 22},
      {&bool_mask, MTL::DataType::DataTypeBool, 23},
      {&float_mask, MTL::DataType::DataTypeBool, 24},
  };
  std::string hash_name = kname;
  hash_name += has_mask ? (bool_mask ? "_boolmask" : "_floatmask") : "_nomask";
  hash_name += query_transposed ? "_qt" : "_qnt";
  hash_name += do_causal ? "_c" : "_nc";

  // Get the kernel
  auto& compute_encoder = d.get_command_encoder(s.index);
@@ -191,15 +198,15 @@ void sdpa_vector(
  compute_encoder.set_bytes(scale, 10);
  if (has_mask) {
    auto& m = *mask;
    compute_encoder.set_input_array(m, 11 + float_mask);
    auto nd = m.ndim();
    int32_t kv_seq_stride =
        nd >= 1 && m.shape(-1) > 1 ? m.strides()[nd - 1] : 0;
    int32_t q_seq_stride = nd >= 2 && m.shape(-2) > 1 ? m.strides()[nd - 2] : 0;
    int32_t head_stride = nd >= 3 && m.shape(-3) > 1 ? m.strides()[nd - 3] : 0;
    compute_encoder.set_bytes(kv_seq_stride, 13);
    compute_encoder.set_bytes(q_seq_stride, 14);
    compute_encoder.set_bytes(head_stride, 15);
  }

  // Launch
@@ -214,6 +221,7 @@ void sdpa_vector_2pass(
    const array& v,
    array& out,
    float scale,
    bool do_causal,
    const std::optional<array>& mask) {
  // Set the kernel name
  std::string kname;
@@ -256,14 +264,20 @@ void sdpa_vector_2pass(
  d.add_temporary(maxs, s.index);

  bool has_mask = mask.has_value();
  bool bool_mask = has_mask && (*mask).dtype() == bool_;
  bool float_mask = has_mask && !bool_mask;
  bool query_transposed = !q.flags().row_contiguous;
  metal::MTLFCList func_consts = {
      {&has_mask, MTL::DataType::DataTypeBool, 20},
      {&query_transposed, MTL::DataType::DataTypeBool, 21},
      {&do_causal, MTL::DataType::DataTypeBool, 22},
      {&bool_mask, MTL::DataType::DataTypeBool, 23},
      {&float_mask, MTL::DataType::DataTypeBool, 24},
  };
  std::string hash_name = kname;
  hash_name += has_mask ? (bool_mask ? "_boolmask" : "_floatmask") : "_nomask";
  hash_name += query_transposed ? "_qt" : "_qnt";
  hash_name += do_causal ? "_c" : "_nc";

  // Get the kernel
  auto& compute_encoder = d.get_command_encoder(s.index);
@@ -287,15 +301,15 @@ void sdpa_vector_2pass(
  compute_encoder.set_bytes(scale, 12);
  if (has_mask) {
    auto& m = *mask;
    compute_encoder.set_input_array(m, 13 + float_mask);
    auto nd = m.ndim();
    int32_t kv_seq_stride =
        nd >= 1 && m.shape(-1) > 1 ? m.strides()[nd - 1] : 0;
    int32_t q_seq_stride = nd >= 2 && m.shape(-2) > 1 ? m.strides()[nd - 2] : 0;
    int32_t head_stride = nd >= 3 && m.shape(-3) > 1 ? m.strides()[nd - 3] : 0;
    compute_encoder.set_bytes(kv_seq_stride, 15);
    compute_encoder.set_bytes(q_seq_stride, 16);
    compute_encoder.set_bytes(head_stride, 17);
  }

  // Launch
@@ -401,12 +415,13 @@ void ScaledDotProductAttention::eval_gpu(
  // We route to the 2 pass fused attention if
  // - The device is large and the sequence length long
  // - The sequence length is even longer and we have gqa
  bool do_causal = do_causal_ && q.shape(2) > 1;
  char devc = d.get_architecture().back();
  if ((devc == 'd' && k.shape(2) >= 1024) ||
      (k.shape(1) < q.shape(1) && k.shape(2) >= 4096)) {
    sdpa_vector_2pass(s, d, q, k, v, o, scale_, do_causal, mask);
  } else {
    sdpa_vector(s, d, q, k, v, o, scale_, do_causal, mask);
  }
}
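The new do_causal function constant tells the vector kernels to apply causal masking themselves; with a single query token (q.shape(2) == 1) nothing would ever be masked out, which is why the flag is disabled in that case. A rough CPU sketch of the masking rule the kernels implement (illustrative only, not the Metal source):

#include <limits>
#include <vector>

void apply_causal_mask(std::vector<float>& scores, int qL, int kL) {
  // scores is a [qL, kL] row-major attention score matrix; query i may only
  // attend to keys j <= i + (kL - qL), which handles cached keys (kL >= qL).
  const float neg_inf = -std::numeric_limits<float>::infinity();
  for (int i = 0; i < qL; ++i) {
    for (int j = 0; j < kL; ++j) {
      if (j > i + (kL - qL)) {
        scores[i * kL + j] = neg_inf;
      }
    }
  }
}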


@@ -60,6 +60,9 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
    case Scan::Min:
      reduce_type = "min";
      break;
    case Scan::LogAddExp:
      reduce_type = "logaddexp";
      break;
  }
  kname << reduce_type << "_" << type_to_name(in) << "_" << type_to_name(out);
  auto kernel = get_scan_kernel(
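The new LogAddExp scan corresponds to an inclusive log-cumsum-exp. A CPU reference sketch of the intended semantics (illustrative, not the Metal kernel): out[i] = logaddexp(out[i-1], in[i]), computed stably.

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

std::vector<float> logcumsumexp(const std::vector<float>& x) {
  std::vector<float> out(x.size());
  float acc = -std::numeric_limits<float>::infinity();
  for (std::size_t i = 0; i < x.size(); ++i) {
    float hi = std::max(acc, x[i]);
    float lo = std::min(acc, x[i]);
    // logaddexp(acc, x[i]) = hi + log1p(exp(lo - hi)) when hi is finite
    acc = std::isinf(hi) ? hi : hi + std::log1p(std::exp(lo - hi));
    out[i] = acc;
  }
  return out;
}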


@@ -23,12 +23,7 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
  // Make sure that the last dimension is contiguous
  auto set_output = [&s, &out](const array& x) {
    if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
      if (x.is_donatable()) {
        out.copy_shared_buffer(x);
      } else {


@@ -2,6 +2,8 @@
#pragma once

#include <type_traits>

#include "mlx/array.h"
#include "mlx/backend/metal/device.h"
#include "mlx/primitives.h"

@@ -58,14 +60,27 @@ inline void debug_set_primitive_buffer_label(

std::string get_primitive_string(Primitive* primitive);

template <typename T>
constexpr bool is_numeric_except_char = std::is_arithmetic_v<T> &&
    !std::is_same_v<T, char> && !std::is_same_v<T, signed char> &&
    !std::is_same_v<T, unsigned char> && !std::is_same_v<T, wchar_t>;

template <typename T>
void concatenate(std::string& acc, T first) {
  if constexpr (is_numeric_except_char<T>) {
    acc += std::to_string(first);
  } else {
    acc += first;
  }
}

template <typename T, typename... Args>
void concatenate(std::string& acc, T first, Args... args) {
  if constexpr (is_numeric_except_char<T>) {
    acc += std::to_string(first);
  } else {
    acc += first;
  }
  concatenate(acc, args...);
}
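With the std::to_string branch, numeric arguments can be mixed freely with string literals and characters, which is how the steel kernel names above are assembled without explicit casts. A small usage sketch (assumes this header is included; the wrapper function name is illustrative):

#include <string>

std::string make_kernel_name() {
  std::string name;
  name.reserve(64);
  // ints go through std::to_string; chars and string literals are appended
  concatenate(name, "steel_gather_mm_", 't', 'n', "_bm", 64, "_bn", 64);
  return name; // "steel_gather_mm_tn_bm64_bn64"
}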


@@ -82,6 +82,7 @@ NO_CPU(LogicalNot)
NO_CPU(LogicalAnd)
NO_CPU(LogicalOr)
NO_CPU(LogAddExp)
NO_CPU(LogSumExp)
NO_CPU_MULTI(LUF)
NO_CPU(Matmul)
NO_CPU(Maximum)


@@ -1,14 +1,72 @@
// Copyright © 2023 Apple Inc.

#include <algorithm>
#include <mutex>

#include "mlx/allocator.h"

#ifdef __APPLE__
#include "mlx/backend/no_metal/apple_memory.h"
#elif defined(__linux__)
#include "mlx/backend/no_metal/linux_memory.h"
#else
size_t get_memory_size() {
return 0;
}
#endif
namespace mlx::core {
namespace allocator {
class CommonAllocator : public Allocator {
/** A general CPU allocator. */
public:
virtual Buffer malloc(size_t size) override;
virtual void free(Buffer buffer) override;
virtual size_t size(Buffer buffer) const override;
size_t get_active_memory() const {
return active_memory_;
};
size_t get_peak_memory() const {
return peak_memory_;
};
void reset_peak_memory() {
std::unique_lock lk(mutex_);
peak_memory_ = 0;
};
size_t get_memory_limit() {
return memory_limit_;
}
size_t set_memory_limit(size_t limit) {
std::unique_lock lk(mutex_);
std::swap(memory_limit_, limit);
return limit;
}
private:
size_t memory_limit_;
size_t active_memory_{0};
size_t peak_memory_{0};
std::mutex mutex_;
CommonAllocator() : memory_limit_(0.8 * get_memory_size()) {
if (memory_limit_ == 0) {
memory_limit_ = 1UL << 33;
}
};
friend CommonAllocator& common_allocator();
};
CommonAllocator& common_allocator() {
  static CommonAllocator allocator_;
  return allocator_;
}
Allocator& allocator() {
return common_allocator();
}
void* Buffer::raw_ptr() {
  if (!ptr_) {
    return nullptr;

@@ -16,4 +74,59 @@ void* Buffer::raw_ptr() {

  return static_cast<size_t*>(ptr_) + 1;
}

Buffer CommonAllocator::malloc(size_t size) {
void* ptr = std::malloc(size + sizeof(size_t));
if (ptr != nullptr) {
*static_cast<size_t*>(ptr) = size;
}
std::unique_lock lk(mutex_);
active_memory_ += size;
peak_memory_ = std::max(active_memory_, peak_memory_);
return Buffer{ptr};
}
void CommonAllocator::free(Buffer buffer) {
auto sz = size(buffer);
std::free(buffer.ptr());
std::unique_lock lk(mutex_);
active_memory_ -= sz;
}
size_t CommonAllocator::size(Buffer buffer) const {
if (buffer.ptr() == nullptr) {
return 0;
}
return *static_cast<size_t*>(buffer.ptr());
}
} // namespace allocator
size_t get_active_memory() {
return allocator::common_allocator().get_active_memory();
}
size_t get_peak_memory() {
return allocator::common_allocator().get_peak_memory();
}
void reset_peak_memory() {
return allocator::common_allocator().reset_peak_memory();
}
size_t set_memory_limit(size_t limit) {
return allocator::common_allocator().set_memory_limit(limit);
}
size_t get_memory_limit() {
return allocator::common_allocator().get_memory_limit();
}
// No-ops for common allocator
size_t get_cache_memory() {
return 0;
}
size_t set_cache_limit(size_t) {
return 0;
}
size_t set_wired_limit(size_t) {
return 0;
}
void clear_cache() {}
} // namespace mlx::core
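These free functions give CPU-only builds the same memory accounting surface that the Metal backend already exposes. A short usage sketch (illustrative; it assumes the declarations are visible through the corresponding public header, which is not shown in this diff):

#include <cstdio>

void report_memory() {
  using namespace mlx::core;
  size_t previous = set_memory_limit(1UL << 30); // cap at 1 GiB
  std::printf(
      "active: %zu bytes, peak: %zu bytes, previous limit: %zu bytes\n",
      get_active_memory(),
      get_peak_memory(),
      previous);
  reset_peak_memory();
}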


@@ -0,0 +1,16 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include <sys/sysctl.h>
namespace {
size_t get_memory_size() {
size_t memsize = 0;
size_t length = sizeof(memsize);
sysctlbyname("hw.memsize", &memsize, &length, NULL, 0);
return memsize;
}
} // namespace


@@ -28,21 +28,19 @@ void Event::wait() {
  ec->cv.wait(lk, [value = value(), ec] { return ec->value >= value; });
}

void Event::wait(Stream stream) {
  scheduler::enqueue(stream, [*this]() mutable { wait(); });
}

void Event::signal(Stream stream) {
  scheduler::enqueue(stream, [*this]() mutable {
    auto ec = static_cast<EventCounter*>(event_.get());
    {
      std::lock_guard<std::mutex> lk(ec->mtx);
      ec->value = value();
    }
    ec->cv.notify_all();
  });
}

bool Event::is_signaled() const {


@@ -0,0 +1,22 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include <sys/sysinfo.h>
namespace {
size_t get_memory_size() {
struct sysinfo info;
if (sysinfo(&info) != 0) {
return 0;
}
size_t total_ram = info.totalram;
total_ram *= info.mem_unit;
return total_ram;
}
} // namespace


@@ -31,33 +31,8 @@ void synchronize(Stream) {
"[metal::synchronize] Cannot synchronize GPU without metal backend"); "[metal::synchronize] Cannot synchronize GPU without metal backend");
} }
// No-ops when Metal is not available.
size_t get_active_memory() {
return 0;
}
size_t get_peak_memory() {
return 0;
}
void reset_peak_memory() {}
size_t get_cache_memory() {
return 0;
}
size_t set_memory_limit(size_t) {
return 0;
}
size_t get_memory_limit() {
return 0;
}
size_t set_cache_limit(size_t) {
return 0;
}
size_t set_wired_limit(size_t) {
return 0;
}
void start_capture(std::string) {}
void stop_capture() {}
void clear_cache() {}

const std::unordered_map<std::string, std::variant<std::string, size_t>>&
device_info() {


@@ -82,6 +82,7 @@ NO_GPU(LogicalNot)
NO_GPU(LogicalAnd)
NO_GPU(LogicalOr)
NO_GPU(LogAddExp)
NO_GPU(LogSumExp)
NO_GPU_MULTI(LUF)
NO_GPU(Matmul)
NO_GPU(Maximum)


@@ -15,6 +15,14 @@ void all_sum(Group group, const array& input, array& output, Stream stream) {
  group.raw_group()->all_sum(input, output, stream);
}

void all_max(Group group, const array& input, array& output, Stream stream) {
  group.raw_group()->all_max(input, output, stream);
}

void all_min(Group group, const array& input, array& output, Stream stream) {
  group.raw_group()->all_min(input, output, stream);
}

void all_gather(Group group, const array& input, array& output, Stream stream) {
  group.raw_group()->all_gather(input, output, stream);
}

@@ -57,6 +65,16 @@ class EmptyGroup : public GroupImpl {

    throw std::runtime_error(
        "Communication not implemented in an empty distributed group.");
  }

  void all_max(const array&, array&, Stream) override {
    throw std::runtime_error(
        "Communication not implemented in an empty distributed group.");
  }

  void all_min(const array&, array&, Stream) override {
    throw std::runtime_error(
        "Communication not implemented in an empty distributed group.");
  }
};

} // namespace detail


@@ -21,6 +21,8 @@ class GroupImpl {
  virtual void all_gather(const array& input, array& output, Stream stream) = 0;
  virtual void send(const array& input, int dst, Stream stream) = 0;
  virtual void recv(array& out, int src, Stream stream) = 0;
  virtual void all_max(const array& input, array& output, Stream stream) = 0;
  virtual void all_min(const array& input, array& output, Stream stream) = 0;
};

/* Perform an all reduce sum operation */

@@ -35,4 +37,10 @@ void send(Group group, const array& input, int dst, Stream stream);

/** Recv an array from the src rank */
void recv(Group group, array& out, int src, Stream stream);

/** Max reduction */
void all_max(Group group, const array& input, array& output, Stream stream);

/** Min reduction */
void all_min(Group group, const array& input, array& output, Stream stream);

} // namespace mlx::core::distributed::detail


@@ -1,4 +1,4 @@
if(MLX_BUILD_CPU)
  target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/mpi.cpp)
else()
  target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_mpi.cpp)


@@ -1,12 +1,13 @@
// Copyright © 2024 Apple Inc.

#include <dlfcn.h>
#include <iostream>

#include "mlx/backend/cpu/encoder.h"
#include "mlx/distributed/distributed.h"
#include "mlx/distributed/distributed_impl.h"
#include "mlx/distributed/mpi/mpi.h"
#include "mlx/distributed/mpi/mpi_declarations.h"
#define LOAD_SYMBOL(symbol, variable) \
  {                                   \

@@ -18,6 +19,12 @@

    }                                 \
  }
#ifdef __APPLE__
static constexpr const char* libmpi_name = "libmpi.dylib";
#else
static constexpr const char* libmpi_name = "libmpi.so";
#endif
namespace mlx::core::distributed::mpi {

using GroupImpl = mlx::core::distributed::detail::GroupImpl;

@@ -43,15 +50,69 @@ void simple_sum(

template void simple_sum<float16_t>(void*, void*, int*, MPI_Datatype*);
template void simple_sum<bfloat16_t>(void*, void*, int*, MPI_Datatype*);
template <typename T>
void simple_max(
void* input,
void* accumulator,
int* len,
MPI_Datatype* datatype) {
T* in = (T*)input;
T* acc = (T*)accumulator;
int N = *len;
while (N-- > 0) {
*acc = std::max(*acc, *in);
acc++;
in++;
}
}
template void simple_max<float16_t>(void*, void*, int*, MPI_Datatype*);
template void simple_max<bfloat16_t>(void*, void*, int*, MPI_Datatype*);
template void simple_max<complex64_t>(void*, void*, int*, MPI_Datatype*);
template <typename T>
void simple_min(
void* input,
void* accumulator,
int* len,
MPI_Datatype* datatype) {
T* in = (T*)input;
T* acc = (T*)accumulator;
int N = *len;
while (N-- > 0) {
*acc = std::min(*acc, *in);
acc++;
in++;
}
}
template void simple_min<float16_t>(void*, void*, int*, MPI_Datatype*);
template void simple_min<bfloat16_t>(void*, void*, int*, MPI_Datatype*);
template void simple_min<complex64_t>(void*, void*, int*, MPI_Datatype*);
struct MPIWrapper {
  MPIWrapper() {
    initialized_ = false;
    libmpi_handle_ = dlopen(libmpi_name, RTLD_NOW | RTLD_GLOBAL);
    if (libmpi_handle_ == nullptr) {
      return;
    }
// Check library version and warn if it isn't Open MPI
int (*get_version)(char*, int*);
LOAD_SYMBOL(MPI_Get_library_version, get_version);
char version_ptr[MPI_MAX_LIBRARY_VERSION_STRING];
int version_length = 0;
get_version(version_ptr, &version_length);
std::string_view version(version_ptr, version_length);
if (version.find("Open MPI") == std::string::npos) {
std::cerr << "[mpi] MPI found but it does not appear to be Open MPI."
<< "MLX requires Open MPI but this is " << version << std::endl;
libmpi_handle_ = nullptr;
return;
}
    // API
    LOAD_SYMBOL(MPI_Init, init);
    LOAD_SYMBOL(MPI_Finalize, finalize);

@@ -72,6 +133,8 @@ struct MPIWrapper {

    // Ops
    LOAD_SYMBOL(ompi_mpi_op_sum, op_sum_);
LOAD_SYMBOL(ompi_mpi_op_max, op_max_);
LOAD_SYMBOL(ompi_mpi_op_min, op_min_);
    // Datatypes
    LOAD_SYMBOL(ompi_mpi_c_bool, mpi_bool_);

@@ -106,9 +169,15 @@ struct MPIWrapper {

    mpi_type_contiguous(2, mpi_uint8_, &mpi_bfloat16_);
    mpi_type_commit(&mpi_bfloat16_);

    // Custom reduction ops
    mpi_op_create(&simple_sum<float16_t>, 1, &op_sum_f16_);
    mpi_op_create(&simple_sum<bfloat16_t>, 1, &op_sum_bf16_);
mpi_op_create(&simple_max<float16_t>, 1, &op_max_f16_);
mpi_op_create(&simple_max<bfloat16_t>, 1, &op_max_bf16_);
mpi_op_create(&simple_max<complex64_t>, 1, &op_max_c64_);
mpi_op_create(&simple_min<float16_t>, 1, &op_min_f16_);
mpi_op_create(&simple_min<bfloat16_t>, 1, &op_min_bf16_);
mpi_op_create(&simple_min<complex64_t>, 1, &op_min_c64_);
    initialized_ = true;
  }

@@ -170,6 +239,32 @@ struct MPIWrapper {

    }
  }
MPI_Op op_max(const array& arr) {
switch (arr.dtype()) {
case float16:
return op_max_f16_;
case bfloat16:
return op_max_bf16_;
case complex64:
return op_max_c64_;
default:
return op_max_;
}
}
MPI_Op op_min(const array& arr) {
switch (arr.dtype()) {
case float16:
return op_min_f16_;
case bfloat16:
return op_min_bf16_;
case complex64:
return op_min_c64_;
default:
return op_min_;
}
}
  void* libmpi_handle_;

  // API

@@ -198,6 +293,14 @@ struct MPIWrapper {

  MPI_Op op_sum_;
  MPI_Op op_sum_f16_;
  MPI_Op op_sum_bf16_;
MPI_Op op_max_;
MPI_Op op_max_f16_;
MPI_Op op_max_bf16_;
MPI_Op op_max_c64_;
MPI_Op op_min_;
MPI_Op op_min_f16_;
MPI_Op op_min_bf16_;
MPI_Op op_min_c64_;
  // Datatypes
  MPI_Datatype mpi_bool_;

@@ -285,6 +388,36 @@ class MPIGroup : public GroupImpl {

        comm_);
  }
void all_max(const array& input, array& output, Stream stream) override {
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(input);
encoder.set_output_array(output);
encoder.dispatch(
mpi().all_reduce,
(input.data<void>() == output.data<void>()) ? MPI_IN_PLACE
: input.data<void>(),
output.data<void>(),
input.size(),
mpi().datatype(input),
mpi().op_max(input),
comm_);
}
void all_min(const array& input, array& output, Stream stream) override {
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(input);
encoder.set_output_array(output);
encoder.dispatch(
mpi().all_reduce,
(input.data<void>() == output.data<void>()) ? MPI_IN_PLACE
: input.data<void>(),
output.data<void>(),
input.size(),
mpi().datatype(input),
mpi().op_min(input),
comm_);
}
  void all_gather(const array& input, array& output, Stream stream) override {
    auto& encoder = cpu::get_command_encoder(stream);
    encoder.set_input_array(input);


@@ -0,0 +1,28 @@
// Copyright © 2024 Apple Inc.
// Constants
#define MPI_SUCCESS 0
#define MPI_ANY_SOURCE -1
#define MPI_ANY_TAG -1
#define MPI_IN_PLACE ((void*)1)
#define MPI_MAX_LIBRARY_VERSION_STRING 256
// Define all the types that we use so that we don't include <mpi.h> which
// causes linker errors on some platforms.
//
// NOTE: We define everything for openmpi.
typedef void* MPI_Comm;
typedef void* MPI_Datatype;
typedef void* MPI_Op;
typedef void(MPI_User_function)(void*, void*, int*, MPI_Datatype*);
typedef struct ompi_status_public_t {
int MPI_SOURCE;
int MPI_TAG;
int MPI_ERROR;
int _cancelled;
size_t _ucount;
} MPI_Status;


@@ -36,6 +36,40 @@ array all_sum(
      {x});
}
array all_max(
const array& x,
std::optional<Group> group_ /* = std::nullopt */,
StreamOrDevice s /* = {} */) {
auto group = to_group(group_);
if (group.size() == 1) {
return x;
}
return array(
x.shape(),
x.dtype(),
std::make_shared<AllReduce>(
to_stream(s, Device::cpu), group, AllReduce::Max),
{x});
}
array all_min(
const array& x,
std::optional<Group> group_ /* = std::nullopt */,
StreamOrDevice s /* = {} */) {
auto group = to_group(group_);
if (group.size() == 1) {
return x;
}
return array(
x.shape(),
x.dtype(),
std::make_shared<AllReduce>(
to_stream(s, Device::cpu), group, AllReduce::Min),
{x});
}
array all_gather(
    const array& x,
    std::optional<Group> group_ /* = std::nullopt */,


@@ -38,4 +38,14 @@ array recv_like(
    std::optional<Group> group = std::nullopt,
    StreamOrDevice s = {});
array all_max(
const array& x,
std::optional<Group> group = std::nullopt,
StreamOrDevice s = {});
array all_min(
const array& x,
std::optional<Group> group = std::nullopt,
StreamOrDevice s = {});
} // namespace mlx::core::distributed
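A short usage sketch of the new reductions from C++ (illustrative; assumes a distributed group has been initialized, e.g. with mlx.launch or mpirun):

#include "mlx/mlx.h"

void reduce_example() {
  namespace mx = mlx::core;
  auto group = mx::distributed::init();
  // Each rank contributes its own rank id.
  auto x = mx::full({4}, static_cast<float>(group.rank()));
  auto hi = mx::distributed::all_max(x, group);
  auto lo = mx::distributed::all_min(x, group);
  mx::eval({hi, lo}); // hi == group.size() - 1, lo == 0 on every rank
}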


@@ -15,8 +15,14 @@ std::pair<std::vector<array>, std::vector<int>> AllReduce::vmap(
  switch (reduce_type_) {
    case Sum:
      return {{all_sum(inputs[0], group(), stream())}, axes};
    case Max:
      return {{all_max(inputs[0], group(), stream())}, axes};
    case Min:
      return {{all_min(inputs[0], group(), stream())}, axes};
    default:
      throw std::runtime_error(
          "Only all reduce sum, max and min are supported for now");
  }
}
@@ -27,8 +33,13 @@ std::vector<array> AllReduce::jvp(
  switch (reduce_type_) {
    case Sum:
      return {all_sum(tangents[0], group(), stream())};
    case Max:
      return {all_max(tangents[0], group(), stream())};
    case Min:
      return {all_min(tangents[0], group(), stream())};
    default:
      throw std::runtime_error(
          "Only all reduce sum, max and min are supported for now");
  }
}

Some files were not shown because too many files have changed in this diff