Merge branch 'ml-explore:main' into adding-Muon-optimizer

Gökdeniz Gülmez
2025-04-21 20:27:33 +02:00
committed by GitHub
166 changed files with 7312 additions and 2043 deletions

View File

@@ -24,8 +24,8 @@ jobs:
type: boolean type: boolean
default: false default: false
macos: macos:
xcode: "15.2.0" xcode: "16.2.0"
resource_class: macos.m1.medium.gen1 resource_class: m2pro.medium
steps: steps:
- checkout - checkout
- run: - run:
@@ -89,15 +89,14 @@ jobs:
pip install numpy pip install numpy
sudo apt-get update sudo apt-get update
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
sudo apt-get install openmpi-bin openmpi-common libopenmpi-dev
- run: - run:
name: Install Python package name: Install Python package
command: | command: |
CMAKE_ARGS="-DMLX_BUILD_METAL=OFF CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" \
CMAKE_COMPILE_WARNING_AS_ERROR=ON" \
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \ CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
python3 setup.py build_ext --inplace python3 setup.py build_ext --inplace
CMAKE_ARGS="-DMLX_BUILD_METAL=OFF \ CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" \
CMAKE_COMPILE_WARNING_AS_ERROR=ON" \
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \ CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
python3 setup.py develop python3 setup.py develop
- run: - run:
@@ -110,6 +109,8 @@ jobs:
name: Run Python tests name: Run Python tests
command: | command: |
python3 -m unittest discover python/tests -v python3 -m unittest discover python/tests -v
mpirun --bind-to none -host localhost:8 -np 8 python python/tests/mpi_test_distributed.py
mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py
- run: - run:
name: Build CPP only name: Build CPP only
command: | command: |
@@ -124,10 +125,15 @@ jobs:
parameters: parameters:
xcode_version: xcode_version:
type: string type: string
default: "15.2.0" default: "16.2.0"
macosx_deployment_target:
type: string
default: ""
macos: macos:
xcode: << parameters.xcode_version >> xcode: << parameters.xcode_version >>
resource_class: macos.m1.medium.gen1 environment:
MACOSX_DEPLOYMENT_TARGET: << parameters.macosx_deployment_target >>
resource_class: m2pro.medium
steps: steps:
- checkout - checkout
- run: - run:
@@ -149,7 +155,7 @@ jobs:
command: | command: |
source env/bin/activate source env/bin/activate
DEBUG=1 CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \ DEBUG=1 CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
CMAKE_ARGS="CMAKE_COMPILE_WARNING_AS_ERROR=ON" \ CMAKE_ARGS="-DCMAKE_COMPILE_WARNING_AS_ERROR=ON" \
pip install -e . -v pip install -e . -v
- run: - run:
name: Generate package stubs name: Generate package stubs
@@ -213,13 +219,18 @@ jobs:
default: "3.9" default: "3.9"
xcode_version: xcode_version:
type: string type: string
default: "15.2.0" default: "16.2.0"
build_env: build_env:
type: string type: string
default: "" default: ""
macosx_deployment_target:
type: string
default: ""
macos: macos:
xcode: << parameters.xcode_version >> xcode: << parameters.xcode_version >>
resource_class: macos.m1.medium.gen1 resource_class: m2pro.medium
environment:
MACOSX_DEPLOYMENT_TARGET: << parameters.macosx_deployment_target >>
steps: steps:
- checkout - checkout
- run: - run:
@@ -240,7 +251,7 @@ jobs:
name: Install Python package name: Install Python package
command: | command: |
source env/bin/activate source env/bin/activate
DEV_RELEASE=1 \ env -u MACOSX_DEPLOYMENT_TARGET DEV_RELEASE=1 \
CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \ CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
pip install . -v pip install . -v
- run: - run:
@@ -335,7 +346,7 @@ workflows:
- mac_build_and_test: - mac_build_and_test:
matrix: matrix:
parameters: parameters:
xcode_version: ["15.0.0", "15.2.0", "16.0.0"] macosx_deployment_target: ["13.5", "14.0"]
- linux_build_and_test - linux_build_and_test
- build_documentation - build_documentation
@@ -355,8 +366,70 @@ workflows:
matrix: matrix:
parameters: parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"] python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
xcode_version: ["15.0.0", "15.2.0"] macosx_deployment_target: ["13.5", "14.0", "15.0"]
build_env: ["PYPI_RELEASE=1"] build_env: ["PYPI_RELEASE=1"]
xcode_version: ["16.2.0", "15.0.0"]
exclude:
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.9"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.10"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.11"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.12"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.13"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.9"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.10"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.11"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.12"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.13"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.9"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.10"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.11"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.12"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.13"
build_env: "PYPI_RELEASE=1"
- build_documentation: - build_documentation:
filters: filters:
tags: tags:
@@ -379,7 +452,7 @@ workflows:
requires: [ hold ] requires: [ hold ]
matrix: matrix:
parameters: parameters:
xcode_version: ["15.0.0", "15.2.0", "16.0.0"] macosx_deployment_target: ["13.5", "14.0"]
- linux_build_and_test: - linux_build_and_test:
requires: [ hold ] requires: [ hold ]
nightly_build: nightly_build:
@@ -392,7 +465,54 @@ workflows:
matrix: matrix:
parameters: parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"] python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
xcode_version: ["15.0.0", "15.2.0"] macosx_deployment_target: ["13.5", "14.0", "15.0"]
xcode_version: ["16.2.0", "15.0.0"]
exclude:
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.9"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.10"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.11"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.12"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.13"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.9"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.10"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.11"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.12"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.13"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.9"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.10"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.11"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.12"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.13"
weekly_build: weekly_build:
when: when:
and: and:
@@ -403,8 +523,70 @@ workflows:
matrix: matrix:
parameters: parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"] python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
xcode_version: ["15.0.0", "15.2.0", "16.0.0"] macosx_deployment_target: ["13.5", "14.0", "15.0"]
build_env: ["DEV_RELEASE=1"] build_env: ["DEV_RELEASE=1"]
xcode_version: ["16.2.0", "15.0.0"]
exclude:
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.9"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.10"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.11"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.12"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.13"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.9"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.10"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.11"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.12"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.13"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.9"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.10"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.11"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.12"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.13"
build_env: "DEV_RELEASE=1"
linux_test_release: linux_test_release:
when: when:
and: and:

View File

@@ -212,24 +212,6 @@ else()
set(MLX_BUILD_ACCELERATE OFF) set(MLX_BUILD_ACCELERATE OFF)
endif() endif()
find_package(MPI)
if(MPI_FOUND)
execute_process(
COMMAND zsh "-c" "mpirun --version"
OUTPUT_VARIABLE MPI_VERSION
ERROR_QUIET)
if(${MPI_VERSION} MATCHES ".*Open MPI.*")
target_include_directories(mlx PRIVATE ${MPI_INCLUDE_PATH})
elseif(MPI_VERSION STREQUAL "")
set(MPI_FOUND FALSE)
message(
WARNING "MPI found but mpirun is not available. Building without MPI.")
else()
set(MPI_FOUND FALSE)
message(WARNING "MPI which is not OpenMPI found. Building without MPI.")
endif()
endif()
message(STATUS "Downloading json") message(STATUS "Downloading json")
FetchContent_Declare( FetchContent_Declare(
json json

View File

@@ -5,26 +5,26 @@ possible.
## Pull Requests ## Pull Requests
1. Fork and submit pull requests to the repo. 1. Fork and submit pull requests to the repo.
2. If you've added code that should be tested, add tests. 2. If you've added code that should be tested, add tests.
3. If a change is likely to impact efficiency, run some of the benchmarks before 3. If a change is likely to impact efficiency, run some of the benchmarks before
and after the change. Examples of benchmarks can be found in `benchmarks/python/`. and after the change. Examples of benchmarks can be found in `benchmarks/python/`.
4. If you've changed APIs, update the documentation. 4. If you've changed APIs, update the documentation.
5. Every PR should have passing tests and at least one review. 5. Every PR should have passing tests and at least one review.
6. For code formatting install `pre-commit` using something like `pip install pre-commit` and run `pre-commit install`. 6. For code formatting install `pre-commit` using something like `pip install pre-commit` and run `pre-commit install`.
This should install hooks for running `black` and `clang-format` to ensure This should install hooks for running `black` and `clang-format` to ensure
consistent style for C++ and python code. consistent style for C++ and python code.
You can also run the formatters manually as follows: You can also run the formatters manually as follows:
``` ```shell
clang-format -i file.cpp clang-format -i file.cpp
``` ```
``` ```shell
black file.py black file.py
``` ```
or run `pre-commit run --all-files` to check all files in the repo. or run `pre-commit run --all-files` to check all files in the repo.
## Issues ## Issues

View File

@@ -0,0 +1,74 @@
# Copyright © 2025 Apple Inc.
import mlx.core as mx
from time_utils import time_fn
N = 1024
D = 1024
M = 1024
E = 32
I = 4
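# Sort the flattened (row, expert) indices so rows routed to the same expert are
# contiguous, and keep the inverse permutation needed to undo the sort later.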
def gather_sort(x, indices):
N, M = indices.shape
indices = indices.flatten()
order = mx.argsort(indices)
inv_order = mx.argsort(order)
return x.flatten(0, -3)[order // M], indices[order], inv_order
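# Undo the permutation applied by gather_sort and optionally restore the
# original leading shape.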
def scatter_unsort(x, inv_order, shape=None):
x = x[inv_order]
if shape is not None:
x = mx.unflatten(x, 0, shape)
return x
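# Reference path: route every row to its expert weight with plain matmuls, for
# comparison against the fused mx.gather_mm calls below.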
def gather_mm_simulate(x, w, indices):
x, idx, inv_order = gather_sort(x, indices)
for i in range(2):
y = mx.concatenate([x[i] @ w[j].T for i, j in enumerate(idx.tolist())], axis=0)
x = y[:, None]
x = scatter_unsort(x, inv_order, indices.shape)
return x
def time_gather_mm():
x = mx.random.normal((N, 1, 1, D)) / 1024**0.5
w1 = mx.random.normal((E, M, D)) / 1024**0.5
w2 = mx.random.normal((E, D, M)) / 1024**0.5
indices = (mx.random.uniform(shape=(N, I)) * E).astype(mx.uint32)
sorted_indices = mx.sort(indices.flatten()).reshape(N, I)
mx.eval(x, w1, w2, indices, sorted_indices)
def gather_mm(x, w1, w2, indices, sort):
idx = indices
inv_order = None
if sort:
x, idx, inv_order = gather_sort(x, indices)
x = mx.gather_mm(x, w1.swapaxes(-1, -2), rhs_indices=idx, sorted_indices=sort)
x = mx.gather_mm(x, w2.swapaxes(-1, -2), rhs_indices=idx, sorted_indices=sort)
if sort:
x = scatter_unsort(x, inv_order, indices.shape)
return x
time_fn(gather_mm, x, w1, w2, indices, False)
time_fn(gather_mm, x, w1, w2, sorted_indices, False)
time_fn(gather_mm, x, w1, w2, indices, True)
x = mx.random.normal((N * I, D)) / 1024**0.5
w1 = mx.random.normal((M, D)) / 1024**0.5
w2 = mx.random.normal((D, M)) / 1024**0.5
mx.eval(x, w1, w2)
def equivalent_matmul(x, w1, w2):
x = x @ w1.T
x = x @ w2.T
return x
time_fn(equivalent_matmul, x, w1, w2)
if __name__ == "__main__":
time_gather_mm()

View File

@@ -0,0 +1,84 @@
# Copyright © 2025 Apple Inc.
import mlx.core as mx
from time_utils import time_fn
N = 1024
D = 1024
M = 1024
E = 32
I = 4
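# Same routing helpers as the gather_mm benchmark, but the expert weights are
# quantized and multiplied with mx.quantized_matmul / mx.gather_qmm.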
def gather_sort(x, indices):
N, M = indices.shape
indices = indices.flatten()
order = mx.argsort(indices)
inv_order = mx.argsort(order)
return x.flatten(0, -3)[order // M], indices[order], inv_order
def scatter_unsort(x, inv_order, shape=None):
x = x[inv_order]
if shape is not None:
x = mx.unflatten(x, 0, shape)
return x
def gather_mm_simulate(x, w, indices):
x, idx, inv_order = gather_sort(x, indices)
for i in range(2):
y = mx.concatenate(
[
mx.quantized_matmul(x[i], w[0][j], w[1][j], w[2][j], transpose=True)
for i, j in enumerate(idx.tolist())
],
axis=0,
)
x = y[:, None]
x = scatter_unsort(x, inv_order, indices.shape)
return x
def time_gather_qmm():
x = mx.random.normal((N, 1, 1, D)) / 1024**0.5
w1 = mx.random.normal((E, M, D)) / 1024**0.5
w2 = mx.random.normal((E, D, M)) / 1024**0.5
w1 = mx.quantize(w1)
w2 = mx.quantize(w2)
indices = (mx.random.uniform(shape=(N, I)) * E).astype(mx.uint32)
sorted_indices = mx.sort(indices.flatten()).reshape(N, I)
mx.eval(x, w1, w2, indices, sorted_indices)
def gather_mm(x, w1, w2, indices, sort):
idx = indices
inv_order = None
if sort:
x, idx, inv_order = gather_sort(x, indices)
x = mx.gather_qmm(x, *w1, transpose=True, rhs_indices=idx, sorted_indices=sort)
x = mx.gather_qmm(x, *w2, transpose=True, rhs_indices=idx, sorted_indices=sort)
if sort:
x = scatter_unsort(x, inv_order, indices.shape)
return x
time_fn(gather_mm, x, w1, w2, indices, False)
time_fn(gather_mm, x, w1, w2, sorted_indices, False)
time_fn(gather_mm, x, w1, w2, indices, True)
x = mx.random.normal((N * I, D)) / 1024**0.5
w1 = mx.random.normal((M, D)) / 1024**0.5
w2 = mx.random.normal((D, M)) / 1024**0.5
w1 = mx.quantize(w1)
w2 = mx.quantize(w2)
mx.eval(x, w1, w2)
def equivalent_matmul(x, w1, w2):
x = mx.quantized_matmul(x, *w1, transpose=True)
x = mx.quantized_matmul(x, *w2, transpose=True)
return x
time_fn(equivalent_matmul, x, w1, w2)
if __name__ == "__main__":
time_gather_qmm()

View File

@@ -13,7 +13,7 @@ EXCLUDE_PATTERNS = */private/*
CREATE_SUBDIRS = NO CREATE_SUBDIRS = NO
FULL_PATH_NAMES = YES FULL_PATH_NAMES = YES
RECURSIVE = YES RECURSIVE = YES
GENERATE_HTML = YES GENERATE_HTML = NO
GENERATE_LATEX = NO GENERATE_LATEX = NO
GENERATE_XML = YES GENERATE_XML = YES
XML_PROGRAMLISTING = YES XML_PROGRAMLISTING = YES

View File

@@ -93,9 +93,9 @@ Primitives
^^^^^^^^^^^ ^^^^^^^^^^^
A :class:`Primitive` is part of the computation graph of an :class:`array`. It A :class:`Primitive` is part of the computation graph of an :class:`array`. It
defines how to create outputs arrays given a input arrays. Further, a defines how to create output arrays given input arrays. Further, a
:class:`Primitive` has methods to run on the CPU or GPU and for function :class:`Primitive` has methods to run on the CPU or GPU and for function
transformations such as ``vjp`` and ``jvp``. Lets go back to our example to be transformations such as ``vjp`` and ``jvp``. Let's go back to our example to be
more concrete: more concrete:
.. code-block:: C++ .. code-block:: C++
@@ -128,7 +128,7 @@ more concrete:
/** The vector-Jacobian product. */ /** The vector-Jacobian product. */
std::vector<array> vjp( std::vector<array> vjp(
const std::vector<array>& primals, const std::vector<array>& primals,
const array& cotan, const std::vector<array>& cotangents,
const std::vector<int>& argnums, const std::vector<int>& argnums,
const std::vector<array>& outputs) override; const std::vector<array>& outputs) override;
@@ -469,7 +469,7 @@ one we just defined:
const std::vector<array>& tangents, const std::vector<array>& tangents,
const std::vector<int>& argnums) { const std::vector<int>& argnums) {
// Forward mode diff that pushes along the tangents // Forward mode diff that pushes along the tangents
// The jvp transform on the primitive can built with ops // The jvp transform on the primitive can be built with ops
// that are scheduled on the same stream as the primitive // that are scheduled on the same stream as the primitive
// If argnums = {0}, we only push along x in which case the // If argnums = {0}, we only push along x in which case the
@@ -481,7 +481,7 @@ one we just defined:
auto scale_arr = array(scale, tangents[0].dtype()); auto scale_arr = array(scale, tangents[0].dtype());
return {multiply(scale_arr, tangents[0], stream())}; return {multiply(scale_arr, tangents[0], stream())};
} }
// If, argnums = {0, 1}, we take contributions from both // If argnums = {0, 1}, we take contributions from both
// which gives us jvp = tangent_x * alpha + tangent_y * beta // which gives us jvp = tangent_x * alpha + tangent_y * beta
else { else {
return {axpby(tangents[0], tangents[1], alpha_, beta_, stream())}; return {axpby(tangents[0], tangents[1], alpha_, beta_, stream())};
@@ -735,7 +735,7 @@ Let's look at a simple script and its results:
print(f"c shape: {c.shape}") print(f"c shape: {c.shape}")
print(f"c dtype: {c.dtype}") print(f"c dtype: {c.dtype}")
print(f"c correct: {mx.all(c == 6.0).item()}") print(f"c is correct: {mx.all(c == 6.0).item()}")
Output: Output:
@@ -743,7 +743,7 @@ Output:
c shape: [3, 4] c shape: [3, 4]
c dtype: float32 c dtype: float32
c correctness: True c is correct: True
Results Results
^^^^^^^ ^^^^^^^

View File

@@ -70,6 +70,7 @@ are the CPU and GPU.
python/fft python/fft
python/linalg python/linalg
python/metal python/metal
python/memory_management
python/nn python/nn
python/optimizers python/optimizers
python/distributed python/distributed

View File

@@ -38,6 +38,7 @@ Array
array.log10 array.log10
array.log1p array.log1p
array.log2 array.log2
array.logcumsumexp
array.logsumexp array.logsumexp
array.max array.max
array.mean array.mean

View File

@@ -20,5 +20,6 @@ Linear Algebra
eigh eigh
lu lu
lu_factor lu_factor
pinv
solve solve
solve_triangular solve_triangular

View File

@@ -0,0 +1,16 @@
Memory Management
=================
.. currentmodule:: mlx.core
.. autosummary::
:toctree: _autosummary
get_active_memory
get_peak_memory
reset_peak_memory
get_cache_memory
set_memory_limit
set_cache_limit
set_wired_limit
clear_cache

View File

@@ -8,13 +8,5 @@ Metal
is_available is_available
device_info device_info
get_active_memory
get_peak_memory
reset_peak_memory
get_cache_memory
set_memory_limit
set_cache_limit
set_wired_limit
clear_cache
start_capture start_capture
stop_capture stop_capture

View File

@@ -36,10 +36,12 @@ Operations
bitwise_or bitwise_or
bitwise_xor bitwise_xor
block_masked_mm block_masked_mm
broadcast_arrays
broadcast_to broadcast_to
ceil ceil
clip clip
concatenate concatenate
contiguous
conj conj
conjugate conjugate
convolve convolve
@@ -101,6 +103,7 @@ Operations
log10 log10
log1p log1p
logaddexp logaddexp
logcumsumexp
logical_not logical_not
logical_and logical_and
logical_or logical_or

View File

@@ -18,3 +18,4 @@ Common Optimizers
AdamW AdamW
Adamax Adamax
Lion Lion
MultiOptimizer

View File

@@ -9,6 +9,7 @@ Transforms
:toctree: _autosummary :toctree: _autosummary
eval eval
async_eval
compile compile
custom_function custom_function
disable_compile disable_compile

View File

@@ -5,6 +5,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp ${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp ${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp
${CMAKE_CURRENT_SOURCE_DIR}/dtype_utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/export.cpp ${CMAKE_CURRENT_SOURCE_DIR}/export.cpp
${CMAKE_CURRENT_SOURCE_DIR}/einsum.cpp ${CMAKE_CURRENT_SOURCE_DIR}/einsum.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp ${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp

View File

@@ -4,7 +4,6 @@
#include <sstream> #include <sstream>
#include "mlx/allocator.h" #include "mlx/allocator.h"
#include "mlx/scheduler.h"
namespace mlx::core::allocator { namespace mlx::core::allocator {
@@ -22,23 +21,4 @@ void free(Buffer buffer) {
allocator().free(buffer); allocator().free(buffer);
} }
Buffer CommonAllocator::malloc(size_t size) {
void* ptr = std::malloc(size + sizeof(size_t));
if (ptr != nullptr) {
*static_cast<size_t*>(ptr) = size;
}
return Buffer{ptr};
}
void CommonAllocator::free(Buffer buffer) {
std::free(buffer.ptr());
}
size_t CommonAllocator::size(Buffer buffer) const {
if (buffer.ptr() == nullptr) {
return 0;
}
return *static_cast<size_t*>(buffer.ptr());
}
} // namespace mlx::core::allocator } // namespace mlx::core::allocator

View File

@@ -49,16 +49,4 @@ class Allocator {
Allocator& allocator(); Allocator& allocator();
class CommonAllocator : public Allocator {
/** A general CPU allocator. */
public:
virtual Buffer malloc(size_t size) override;
virtual void free(Buffer buffer) override;
virtual size_t size(Buffer buffer) const override;
private:
CommonAllocator() = default;
friend Allocator& allocator();
};
} // namespace mlx::core::allocator } // namespace mlx::core::allocator

View File

@@ -339,11 +339,11 @@ class array {
return allocator::allocator().size(buffer()); return allocator::allocator().size(buffer());
} }
// Return a copy of the shared pointer // Return the shared pointer to the array::Data struct
// to the array::Data struct const std::shared_ptr<Data>& data_shared_ptr() const {
std::shared_ptr<Data> data_shared_ptr() const {
return array_desc_->data; return array_desc_->data;
} }
// Return a raw pointer to the arrays data // Return a raw pointer to the arrays data
template <typename T> template <typename T>
T* data() { T* data() {

View File

@@ -1,6 +1,7 @@
target_sources( target_sources(
mlx mlx
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/broadcasting.cpp
${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
${CMAKE_CURRENT_SOURCE_DIR}/common.cpp ${CMAKE_CURRENT_SOURCE_DIR}/common.cpp
${CMAKE_CURRENT_SOURCE_DIR}/load.cpp ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp

View File

@@ -0,0 +1,24 @@
// Copyright © 2024 Apple Inc.
#include "mlx/backend/common/utils.h"
namespace mlx::core {
void broadcast(const array& in, array& out) {
if (out.size() == 0) {
out.set_data(nullptr);
return;
}
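// Broadcasting is a view: dimensions of size 1 get a stride of 0 so the same
// element is reused, and the output shares the input buffer instead of copying.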
Strides strides(out.ndim(), 0);
int diff = out.ndim() - in.ndim();
for (int i = in.ndim() - 1; i >= 0; --i) {
strides[i + diff] = (in.shape()[i] == 1) ? 0 : in.strides()[i];
}
auto flags = in.flags();
if (out.size() > in.size()) {
flags.row_contiguous = flags.col_contiguous = false;
}
out.copy_shared_buffer(in, strides, flags, in.data_size());
}
} // namespace mlx::core

View File

@@ -0,0 +1,11 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include "mlx/array.h"
namespace mlx::core {
void broadcast(const array& in, array& out);
} // namespace mlx::core

View File

@@ -1,6 +1,7 @@
// Copyright © 2024 Apple Inc. // Copyright © 2024 Apple Inc.
#include <cassert> #include <cassert>
#include "mlx/backend/common/broadcasting.h"
#include "mlx/backend/common/utils.h" #include "mlx/backend/common/utils.h"
#include "mlx/primitives.h" #include "mlx/primitives.h"
@@ -42,23 +43,6 @@ void AsStrided::eval(const std::vector<array>& inputs, array& out) {
return out.copy_shared_buffer(in, strides_, flags, data_size, offset_); return out.copy_shared_buffer(in, strides_, flags, data_size, offset_);
} }
void broadcast(const array& in, array& out) {
if (out.size() == 0) {
out.set_data(nullptr);
return;
}
Strides strides(out.ndim(), 0);
int diff = out.ndim() - in.ndim();
for (int i = in.ndim() - 1; i >= 0; --i) {
strides[i + diff] = (in.shape()[i] == 1) ? 0 : in.strides()[i];
}
auto flags = in.flags();
if (out.size() > in.size()) {
flags.row_contiguous = flags.col_contiguous = false;
}
out.copy_shared_buffer(in, strides, flags, in.data_size());
}
void Broadcast::eval(const std::vector<array>& inputs, array& out) { void Broadcast::eval(const std::vector<array>& inputs, array& out) {
broadcast(inputs[0], out); broadcast(inputs[0], out);
} }

View File

@@ -58,6 +58,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp ${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
${CMAKE_CURRENT_SOURCE_DIR}/select.cpp ${CMAKE_CURRENT_SOURCE_DIR}/select.cpp
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
${CMAKE_CURRENT_SOURCE_DIR}/logsumexp.cpp
${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp ${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
${CMAKE_CURRENT_SOURCE_DIR}/threefry.cpp ${CMAKE_CURRENT_SOURCE_DIR}/threefry.cpp
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
@@ -73,8 +74,8 @@ target_sources(
if(MLX_BUILD_ACCELERATE) if(MLX_BUILD_ACCELERATE)
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/bnns.cpp) target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/bnns.cpp)
else() else()
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/no_fp16.cpp target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/simd_fp16.cpp
${CMAKE_CURRENT_SOURCE_DIR}/gemms/no_bf16.cpp) ${CMAKE_CURRENT_SOURCE_DIR}/gemms/simd_bf16.cpp)
endif() endif()
if(IOS) if(IOS)

View File

@@ -46,8 +46,15 @@ void AllReduce::eval_cpu(
case Sum: case Sum:
distributed::detail::all_sum(group(), in, outputs[0], stream()); distributed::detail::all_sum(group(), in, outputs[0], stream());
break; break;
case Max:
distributed::detail::all_max(group(), in, outputs[0], stream());
break;
case Min:
distributed::detail::all_min(group(), in, outputs[0], stream());
break;
default: default:
throw std::runtime_error("Only all reduce sum is supported for now"); throw std::runtime_error(
"Only all reduce sum, min and max are supported for now");
} }
} }

View File

@@ -1,27 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cpu/gemm.h"
namespace mlx::core {
template <>
void matmul<bfloat16_t>(
const bfloat16_t*,
const bfloat16_t*,
bfloat16_t*,
bool,
bool,
size_t,
size_t,
size_t,
float,
float,
size_t,
const Shape&,
const Strides&,
const Shape&,
const Strides&) {
throw std::runtime_error("[Matmul::eval_cpu] bfloat16 not supported.");
}
} // namespace mlx::core

View File

@@ -1,27 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cpu/gemm.h"
namespace mlx::core {
template <>
void matmul<float16_t>(
const float16_t*,
const float16_t*,
float16_t*,
bool,
bool,
size_t,
size_t,
size_t,
float,
float,
size_t,
const Shape&,
const Strides&,
const Shape&,
const Strides&) {
throw std::runtime_error("[Matmul::eval_cpu] float16 not supported.");
}
} // namespace mlx::core

View File

@@ -0,0 +1,45 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/gemm.h"
#include "mlx/backend/cpu/gemms/simd_gemm.h"
namespace mlx::core {
template <>
void matmul<bfloat16_t>(
const bfloat16_t* a,
const bfloat16_t* b,
bfloat16_t* out,
bool a_transposed,
bool b_transposed,
size_t lda,
size_t ldb,
size_t ldc,
float alpha,
float beta,
size_t batch_size,
const Shape& a_shape,
const Strides& a_strides,
const Shape& b_shape,
const Strides& b_strides) {
auto ndim = a_shape.size();
size_t M = a_shape[ndim - 2];
size_t N = b_shape[ndim - 1];
size_t K = a_shape[ndim - 1];
for (int i = 0; i < batch_size; ++i) {
simd_gemm<bfloat16_t, float>(
a + elem_to_loc(M * K * i, a_shape, a_strides),
b + elem_to_loc(K * N * i, b_shape, b_strides),
out + M * N * i,
a_transposed,
b_transposed,
M,
N,
K,
alpha,
beta);
}
}
} // namespace mlx::core

View File

@@ -0,0 +1,45 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/gemm.h"
#include "mlx/backend/cpu/gemms/simd_gemm.h"
namespace mlx::core {
template <>
void matmul<float16_t>(
const float16_t* a,
const float16_t* b,
float16_t* out,
bool a_transposed,
bool b_transposed,
size_t lda,
size_t ldb,
size_t ldc,
float alpha,
float beta,
size_t batch_size,
const Shape& a_shape,
const Strides& a_strides,
const Shape& b_shape,
const Strides& b_strides) {
auto ndim = a_shape.size();
size_t M = a_shape[ndim - 2];
size_t N = b_shape[ndim - 1];
size_t K = a_shape[ndim - 1];
for (int i = 0; i < batch_size; ++i) {
simd_gemm<float16_t, float>(
a + elem_to_loc(M * K * i, a_shape, a_strides),
b + elem_to_loc(K * N * i, b_shape, b_strides),
out + M * N * i,
a_transposed,
b_transposed,
M,
N,
K,
alpha,
beta);
}
}
} // namespace mlx::core

View File

@@ -0,0 +1,139 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include "mlx/backend/cpu/simd/simd.h"
namespace mlx::core {
inline int ceildiv(int a, int b) {
return (a + b - 1) / b;
}
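// Copy a block_size x block_size tile of `in` into the dense scratch buffer
// `out`, transposing on the fly when requested and guarding the ragged edges.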
template <int block_size, typename T, typename AccT>
void load_block(
const T* in,
AccT* out,
int M,
int N,
int i,
int j,
bool transpose) {
if (transpose) {
for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) {
for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) {
out[jj * block_size + ii] =
in[(i * block_size + ii) * N + j * block_size + jj];
}
}
} else {
for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) {
for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) {
out[ii * block_size + jj] =
in[(i * block_size + ii) * N + j * block_size + jj];
}
}
}
}
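// Blocked GEMM: accumulate block_size x block_size tiles of C in the AccT
// accumulator type using SIMD dot products, then scale by alpha/beta on store.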
template <typename T, typename AccT>
void simd_gemm(
const T* a,
const T* b,
T* c,
bool a_trans,
bool b_trans,
int M,
int N,
int K,
float alpha,
float beta) {
constexpr int block_size = 16;
constexpr int simd_size = simd::max_size<AccT>;
static_assert(
(block_size % simd_size) == 0,
"Block size must be divisible by SIMD size");
int last_k_block_size = K - block_size * (K / block_size);
int last_k_simd_block = (last_k_block_size / simd_size) * simd_size;
for (int i = 0; i < ceildiv(M, block_size); i++) {
for (int j = 0; j < ceildiv(N, block_size); j++) {
AccT c_block[block_size * block_size] = {0.0};
AccT a_block[block_size * block_size];
AccT b_block[block_size * block_size];
int k = 0;
for (; k < K / block_size; k++) {
// Load a and b blocks
if (a_trans) {
load_block<block_size>(a, a_block, K, M, k, i, true);
} else {
load_block<block_size>(a, a_block, M, K, i, k, false);
}
if (b_trans) {
load_block<block_size>(b, b_block, N, K, j, k, false);
} else {
load_block<block_size>(b, b_block, K, N, k, j, true);
}
// Multiply and accumulate
for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) {
for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) {
for (int kk = 0; kk < block_size; kk += simd_size) {
auto av =
simd::load<AccT, simd_size>(a_block + ii * block_size + kk);
auto bv =
simd::load<AccT, simd_size>(b_block + jj * block_size + kk);
c_block[ii * block_size + jj] += simd::sum(av * bv);
}
}
}
}
if (last_k_block_size) {
// Load a and b blocks
if (a_trans) {
load_block<block_size>(a, a_block, K, M, k, i, true);
} else {
load_block<block_size>(a, a_block, M, K, i, k, false);
}
if (b_trans) {
load_block<block_size>(b, b_block, N, K, j, k, false);
} else {
load_block<block_size>(b, b_block, K, N, k, j, true);
}
// Multiply and accumulate
for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) {
for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) {
int kk = 0;
for (; kk < last_k_simd_block; kk += simd_size) {
auto av =
simd::load<AccT, simd_size>(a_block + ii * block_size + kk);
auto bv =
simd::load<AccT, simd_size>(b_block + jj * block_size + kk);
c_block[ii * block_size + jj] += simd::sum(av * bv);
}
for (; kk < last_k_block_size; ++kk) {
c_block[ii * block_size + jj] +=
a_block[ii * block_size + kk] * b_block[jj * block_size + kk];
}
}
}
}
// Store
for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) {
for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) {
auto c_idx = (i * block_size + ii) * N + j * block_size + jj;
if (beta != 0) {
c[c_idx] = static_cast<T>(
alpha * c_block[ii * block_size + jj] + beta * c[c_idx]);
} else {
c[c_idx] = static_cast<T>(alpha * c_block[ii * block_size + jj]);
}
}
}
}
}
}
} // namespace mlx::core

View File

@@ -0,0 +1,140 @@
// Copyright © 2023-2024 Apple Inc.
#include <cassert>
#include <cmath>
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/simd/simd.h"
#include "mlx/primitives.h"
#include "mlx/types/limits.h"
namespace mlx::core {
namespace {
using namespace mlx::core::simd;
template <typename T, typename AccT>
void logsumexp(const array& in, array& out, Stream stream) {
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(in);
encoder.set_output_array(out);
const T* in_ptr = in.data<T>();
T* out_ptr = out.data<T>();
int M = in.shape().back();
int L = in.data_size() / M;
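// Each of the L rows of length M is reduced independently: find the row
// maximum, accumulate exp(x - max) with SIMD, then write log(sum) + max.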
encoder.dispatch([in_ptr, out_ptr, M, L]() mutable {
constexpr int N = std::min(max_size<AccT>, max_size<T>);
const T* current_in_ptr;
for (int i = 0; i < L; i++, in_ptr += M, out_ptr += 1) {
// Find the maximum
current_in_ptr = in_ptr;
Simd<AccT, N> vmaximum(-numeric_limits<AccT>::infinity());
size_t s = M;
while (s >= N) {
Simd<AccT, N> vals = load<T, N>(current_in_ptr);
vmaximum = maximum(vals, vmaximum);
current_in_ptr += N;
s -= N;
}
AccT maximum = max(vmaximum);
while (s-- > 0) {
maximum = std::max(maximum, static_cast<AccT>(*current_in_ptr));
current_in_ptr++;
}
// Compute the normalizer and the exponentials
Simd<AccT, N> vnormalizer(0.0);
current_in_ptr = in_ptr;
s = M;
while (s >= N) {
Simd<AccT, N> vexp = load<T, N>(current_in_ptr);
vexp = exp(vexp - maximum);
vnormalizer = vnormalizer + vexp;
current_in_ptr += N;
s -= N;
}
AccT normalizer = sum(vnormalizer);
while (s-- > 0) {
AccT _exp = std::exp(*current_in_ptr - maximum);
normalizer += _exp;
current_in_ptr++;
}
// Normalize
*out_ptr = std::isinf(maximum)
? static_cast<T>(maximum)
: static_cast<T>(std::log(normalizer) + maximum);
}
});
}
} // namespace
void LogSumExp::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
// Make sure that the last dimension is contiguous
auto s = stream();
auto& encoder = cpu::get_command_encoder(s);
auto ensure_contiguous = [&s, &encoder](const array& x) {
if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
return x;
} else {
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
copy(x, x_copy, CopyType::General, s);
encoder.add_temporary(x_copy);
return x_copy;
}
};
auto in = ensure_contiguous(inputs[0]);
if (in.flags().row_contiguous) {
out.set_data(allocator::malloc(out.nbytes()));
} else {
auto n = in.shape(-1);
auto flags = in.flags();
auto strides = in.strides();
for (auto& s : strides) {
s /= n;
}
bool col_contig = strides[0] == 1;
for (int i = 1; col_contig && i < strides.size(); ++i) {
col_contig &=
(out.shape(i) == 1 || strides[i - 1] == out.shape(i) * strides[i]);
}
flags.col_contiguous = col_contig;
out.set_data(
allocator::malloc(in.nbytes() / n),
in.data_size() / n,
std::move(strides),
flags);
}
switch (in.dtype()) {
case float32:
logsumexp<float, float>(in, out, stream());
break;
case float16:
logsumexp<float16_t, float>(in, out, stream());
break;
case bfloat16:
logsumexp<bfloat16_t, float>(in, out, stream());
break;
case float64:
logsumexp<double, double>(in, out, stream());
break;
default:
throw std::runtime_error(
"[logsumexp] only supports floating point types");
break;
}
}
} // namespace mlx::core

View File

@@ -205,8 +205,10 @@ void Concatenate::eval_cpu(const std::vector<array>& inputs, array& out) {
void Contiguous::eval_cpu(const std::vector<array>& inputs, array& out) { void Contiguous::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
auto& in = inputs[0]; auto& in = inputs[0];
if (in.flags().row_contiguous || constexpr size_t extra_bytes = 16384;
(allow_col_major_ && in.flags().col_contiguous)) { if (in.buffer_size() <= out.nbytes() + extra_bytes &&
(in.flags().row_contiguous ||
(allow_col_major_ && in.flags().col_contiguous))) {
out.copy_shared_buffer(in); out.copy_shared_buffer(in);
} else { } else {
copy(in, out, CopyType::General, stream()); copy(in, out, CopyType::General, stream());

View File

@@ -3,6 +3,7 @@
#include <cassert> #include <cassert>
#include "mlx/backend/common/utils.h" #include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/binary_ops.h"
#include "mlx/backend/cpu/copy.h" #include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h" #include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/simd/simd.h" #include "mlx/backend/cpu/simd/simd.h"
@@ -226,6 +227,16 @@ void scan_dispatch(
scan_op<T, U>(in, out, axis, reverse, inclusive, op, init); scan_op<T, U>(in, out, axis, reverse, inclusive, op, init);
break; break;
} }
case Scan::LogAddExp: {
auto op = [](U a, T b) {
return detail::LogAddExp{}(a, static_cast<U>(b));
};
auto init = (issubdtype(in.dtype(), floating))
? static_cast<U>(-std::numeric_limits<float>::infinity())
: std::numeric_limits<U>::min();
scan_op<T, U>(in, out, axis, reverse, inclusive, op, init);
break;
}
} }
} }

View File

@@ -17,7 +17,7 @@ struct ScalarT<float16_t, N> {
#endif #endif
template <> template <>
static constexpr int max_size<float16_t> = N; inline constexpr int max_size<float16_t> = N;
#define SIMD_FP16_DEFAULT_UNARY(op) \ #define SIMD_FP16_DEFAULT_UNARY(op) \
template <> \ template <> \

View File

@@ -83,25 +83,25 @@ struct Simd {
// Values chosen based on benchmarks on M3 Max // Values chosen based on benchmarks on M3 Max
// TODO: consider choosing these more optimally // TODO: consider choosing these more optimally
template <> template <>
static constexpr int max_size<int8_t> = 16; inline constexpr int max_size<int8_t> = 16;
template <> template <>
static constexpr int max_size<int16_t> = 16; inline constexpr int max_size<int16_t> = 16;
template <> template <>
static constexpr int max_size<int> = 8; inline constexpr int max_size<int> = 8;
template <> template <>
static constexpr int max_size<int64_t> = 4; inline constexpr int max_size<int64_t> = 4;
template <> template <>
static constexpr int max_size<uint8_t> = 16; inline constexpr int max_size<uint8_t> = 16;
template <> template <>
static constexpr int max_size<uint16_t> = 16; inline constexpr int max_size<uint16_t> = 16;
template <> template <>
static constexpr int max_size<uint32_t> = 8; inline constexpr int max_size<uint32_t> = 8;
template <> template <>
static constexpr int max_size<uint64_t> = 4; inline constexpr int max_size<uint64_t> = 4;
template <> template <>
static constexpr int max_size<float> = 8; inline constexpr int max_size<float> = 8;
template <> template <>
static constexpr int max_size<double> = 4; inline constexpr int max_size<double> = 4;
#define SIMD_DEFAULT_UNARY(name, op) \ #define SIMD_DEFAULT_UNARY(name, op) \
template <typename T, int N> \ template <typename T, int N> \

View File

@@ -87,7 +87,6 @@ DEFAULT_UNARY(cosh, std::cosh)
DEFAULT_UNARY(expm1, std::expm1) DEFAULT_UNARY(expm1, std::expm1)
DEFAULT_UNARY(floor, std::floor) DEFAULT_UNARY(floor, std::floor)
DEFAULT_UNARY(log, std::log) DEFAULT_UNARY(log, std::log)
DEFAULT_UNARY(log2, std::log2)
DEFAULT_UNARY(log10, std::log10) DEFAULT_UNARY(log10, std::log10)
DEFAULT_UNARY(log1p, std::log1p) DEFAULT_UNARY(log1p, std::log1p)
DEFAULT_UNARY(sinh, std::sinh) DEFAULT_UNARY(sinh, std::sinh)
@@ -95,6 +94,17 @@ DEFAULT_UNARY(sqrt, std::sqrt)
DEFAULT_UNARY(tan, std::tan) DEFAULT_UNARY(tan, std::tan)
DEFAULT_UNARY(tanh, std::tanh) DEFAULT_UNARY(tanh, std::tanh)
template <typename T>
Simd<T, 1> log2(Simd<T, 1> in) {
if constexpr (is_complex<T>) {
auto out = std::log(in.value);
auto scale = decltype(out.real())(M_LN2);
return Simd<T, 1>{T{out.real() / scale, out.imag() / scale}};
} else {
return Simd<T, 1>{std::log2(in.value)};
}
}
template <typename T> template <typename T>
Simd<T, 1> operator~(Simd<T, 1> in) { Simd<T, 1> operator~(Simd<T, 1> in) {
return ~in.value; return ~in.value;

View File

@@ -119,12 +119,7 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
// Make sure that the last dimension is contiguous // Make sure that the last dimension is contiguous
auto set_output = [s = stream(), &out](const array& x) { auto set_output = [s = stream(), &out](const array& x) {
bool no_copy = x.strides()[x.ndim() - 1] == 1; if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
if (x.ndim() > 1) {
auto s = x.strides()[x.ndim() - 2];
no_copy &= (s == 0 || s == x.shape().back());
}
if (no_copy) {
if (x.is_donatable()) { if (x.is_donatable()) {
out.copy_shared_buffer(x); out.copy_shared_buffer(x);
} else { } else {
@@ -146,18 +141,6 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
auto in = set_output(inputs[0]); auto in = set_output(inputs[0]);
switch (in.dtype()) { switch (in.dtype()) {
case bool_:
case uint8:
case uint16:
case uint32:
case uint64:
case int8:
case int16:
case int32:
case int64:
throw std::runtime_error(
"Softmax is defined only for floating point types");
break;
case float32: case float32:
softmax<float, float>(in, out, stream()); softmax<float, float>(in, out, stream());
break; break;
@@ -178,9 +161,9 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
case float64: case float64:
softmax<double, double>(in, out, stream()); softmax<double, double>(in, out, stream());
break; break;
case complex64: default:
throw std::invalid_argument( throw std::runtime_error(
"[Softmax] Not yet implemented for complex64"); "[softmax] Only defined for floating point types.");
break; break;
} }
} }

View File

@@ -1,5 +1,8 @@
// Copyright © 2024 Apple Inc. // Copyright © 2024 Apple Inc.
// Required for using M_LN2 in MSVC.
#define _USE_MATH_DEFINES
#include <cassert> #include <cassert>
#include "mlx/backend/cpu/unary.h" #include "mlx/backend/cpu/unary.h"

View File

@@ -86,13 +86,14 @@ struct Sign {
template <int N, typename T> template <int N, typename T>
Simd<T, N> operator()(Simd<T, N> x) { Simd<T, N> operator()(Simd<T, N> x) {
auto z = Simd<T, N>{0}; auto z = Simd<T, N>{0};
auto o = Simd<T, N>{1};
auto m = Simd<T, N>{-1};
if constexpr (std::is_unsigned_v<T>) { if constexpr (std::is_unsigned_v<T>) {
return x != z; return simd::select(x == z, z, o);
} else if constexpr (std::is_same_v<T, complex64_t>) { } else if constexpr (std::is_same_v<T, complex64_t>) {
return simd::select(x == z, x, Simd<T, N>(x / simd::abs(x))); return simd::select(x == z, x, Simd<T, N>(x / simd::abs(x)));
} else { } else {
return simd::select( return simd::select(x < z, m, simd::select(x > z, o, z));
x < z, Simd<T, N>{-1}, simd::select(x > z, Simd<T, N>{1}, z));
} }
} }
SINGLE() SINGLE()

View File

@@ -47,6 +47,7 @@ if(MLX_METAL_JIT)
make_jit_source(binary) make_jit_source(binary)
make_jit_source(binary_two) make_jit_source(binary_two)
make_jit_source(fft kernels/fft/radix.h kernels/fft/readwrite.h) make_jit_source(fft kernels/fft/radix.h kernels/fft/readwrite.h)
make_jit_source(logsumexp)
make_jit_source(ternary) make_jit_source(ternary)
make_jit_source(softmax) make_jit_source(softmax)
make_jit_source(scan) make_jit_source(scan)
@@ -60,6 +61,7 @@ if(MLX_METAL_JIT)
kernels/steel/gemm/transforms.h) kernels/steel/gemm/transforms.h)
make_jit_source(steel/gemm/kernels/steel_gemm_fused) make_jit_source(steel/gemm/kernels/steel_gemm_fused)
make_jit_source(steel/gemm/kernels/steel_gemm_masked kernels/steel/defines.h) make_jit_source(steel/gemm/kernels/steel_gemm_masked kernels/steel/defines.h)
make_jit_source(steel/gemm/kernels/steel_gemm_gather)
make_jit_source(steel/gemm/kernels/steel_gemm_splitk) make_jit_source(steel/gemm/kernels/steel_gemm_splitk)
make_jit_source( make_jit_source(
steel/conv/conv steel/conv/conv
@@ -95,6 +97,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp ${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/logsumexp.cpp
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cpp ${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cpp
${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp ${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp

View File

@@ -3,6 +3,7 @@
#include "mlx/backend/metal/metal.h" #include "mlx/backend/metal/metal.h"
#include "mlx/backend/metal/metal_impl.h" #include "mlx/backend/metal/metal_impl.h"
#include "mlx/backend/metal/resident.h" #include "mlx/backend/metal/resident.h"
#include "mlx/memory.h"
#include <mach/vm_page_size.h> #include <mach/vm_page_size.h>
#include <unistd.h> #include <unistd.h>
@@ -32,8 +33,11 @@ namespace metal {
namespace { namespace {
BufferCache::BufferCache(MTL::Device* device) BufferCache::BufferCache(ResidencySet& residency_set)
: device_(device), head_(nullptr), tail_(nullptr), pool_size_(0) {} : head_(nullptr),
tail_(nullptr),
pool_size_(0),
residency_set_(residency_set) {}
BufferCache::~BufferCache() { BufferCache::~BufferCache() {
auto pool = metal::new_scoped_memory_pool(); auto pool = metal::new_scoped_memory_pool();
@@ -44,6 +48,9 @@ int BufferCache::clear() {
int n_release = 0; int n_release = 0;
for (auto& [size, holder] : buffer_pool_) { for (auto& [size, holder] : buffer_pool_) {
if (holder->buf) { if (holder->buf) {
if (!holder->buf->heap()) {
residency_set_.erase(holder->buf);
}
holder->buf->release(); holder->buf->release();
n_release++; n_release++;
} }
@@ -101,6 +108,9 @@ int BufferCache::release_cached_buffers(size_t min_bytes_to_free) {
while (tail_ && (total_bytes_freed < min_bytes_to_free)) { while (tail_ && (total_bytes_freed < min_bytes_to_free)) {
if (tail_->buf) { if (tail_->buf) {
total_bytes_freed += tail_->buf->length(); total_bytes_freed += tail_->buf->length();
if (!tail_->buf->heap()) {
residency_set_.erase(tail_->buf);
}
tail_->buf->release(); tail_->buf->release();
tail_->buf = nullptr; tail_->buf = nullptr;
n_release++; n_release++;
@@ -155,7 +165,7 @@ void BufferCache::remove_from_list(BufferCache::BufferHolder* to_remove) {
MetalAllocator::MetalAllocator() MetalAllocator::MetalAllocator()
: device_(device(mlx::core::Device::gpu).mtl_device()), : device_(device(mlx::core::Device::gpu).mtl_device()),
residency_set_(device_), residency_set_(device_),
buffer_cache_(device_) { buffer_cache_(residency_set_) {
auto pool = metal::new_scoped_memory_pool(); auto pool = metal::new_scoped_memory_pool();
auto memsize = std::get<size_t>(device_info().at("memory_size")); auto memsize = std::get<size_t>(device_info().at("memory_size"));
auto max_rec_size = auto max_rec_size =
@@ -262,9 +272,13 @@ Buffer MetalAllocator::malloc(size_t size) {
if (!buf) { if (!buf) {
buf = device_->newBuffer(size, resource_options); buf = device_->newBuffer(size, resource_options);
} }
if (!buf) {
return Buffer{nullptr};
}
lk.lock(); lk.lock();
if (buf) { num_resources_++;
num_resources_++; if (!buf->heap()) {
residency_set_.insert(buf);
} }
} }
@@ -278,10 +292,6 @@ Buffer MetalAllocator::malloc(size_t size) {
get_cache_memory() - max_pool_size_); get_cache_memory() - max_pool_size_);
} }
if (!buf->heap()) {
residency_set_.insert(buf);
}
return Buffer{static_cast<void*>(buf)}; return Buffer{static_cast<void*>(buf)};
} }
@@ -297,14 +307,14 @@ void MetalAllocator::free(Buffer buffer) {
return; return;
} }
std::unique_lock lk(mutex_); std::unique_lock lk(mutex_);
if (!buf->heap()) {
residency_set_.erase(buf);
}
active_memory_ -= buf->length(); active_memory_ -= buf->length();
if (get_cache_memory() < max_pool_size_) { if (get_cache_memory() < max_pool_size_) {
buffer_cache_.recycle_to_cache(buf); buffer_cache_.recycle_to_cache(buf);
} else { } else {
num_resources_--; num_resources_--;
if (!buf->heap()) {
residency_set_.erase(buf);
}
lk.unlock(); lk.unlock();
auto pool = metal::new_scoped_memory_pool(); auto pool = metal::new_scoped_memory_pool();
buf->release(); buf->release();
@@ -323,40 +333,40 @@ MetalAllocator& allocator() {
return *allocator_; return *allocator_;
} }
} // namespace metal
size_t set_cache_limit(size_t limit) { size_t set_cache_limit(size_t limit) {
return allocator().set_cache_limit(limit); return metal::allocator().set_cache_limit(limit);
} }
size_t set_memory_limit(size_t limit) { size_t set_memory_limit(size_t limit) {
return allocator().set_memory_limit(limit); return metal::allocator().set_memory_limit(limit);
} }
size_t get_memory_limit() { size_t get_memory_limit() {
return allocator().get_memory_limit(); return metal::allocator().get_memory_limit();
} }
size_t set_wired_limit(size_t limit) { size_t set_wired_limit(size_t limit) {
if (limit > if (limit > std::get<size_t>(metal::device_info().at(
std::get<size_t>(device_info().at("max_recommended_working_set_size"))) { "max_recommended_working_set_size"))) {
throw std::invalid_argument( throw std::invalid_argument(
"[metal::set_wired_limit] Setting a wired limit larger than " "[metal::set_wired_limit] Setting a wired limit larger than "
"the maximum working set size is not allowed."); "the maximum working set size is not allowed.");
} }
return allocator().set_wired_limit(limit); return metal::allocator().set_wired_limit(limit);
} }
size_t get_active_memory() { size_t get_active_memory() {
return allocator().get_active_memory(); return metal::allocator().get_active_memory();
} }
size_t get_peak_memory() { size_t get_peak_memory() {
return allocator().get_peak_memory(); return metal::allocator().get_peak_memory();
} }
void reset_peak_memory() { void reset_peak_memory() {
allocator().reset_peak_memory(); metal::allocator().reset_peak_memory();
} }
size_t get_cache_memory() { size_t get_cache_memory() {
return allocator().get_cache_memory(); return metal::allocator().get_cache_memory();
} }
void clear_cache() { void clear_cache() {
return allocator().clear_cache(); return metal::allocator().clear_cache();
} }
} // namespace metal
} // namespace mlx::core } // namespace mlx::core

View File

@@ -18,7 +18,7 @@ namespace {
class BufferCache { class BufferCache {
public: public:
BufferCache(MTL::Device* device); BufferCache(ResidencySet& residency_set);
~BufferCache(); ~BufferCache();
MTL::Buffer* reuse_from_cache(size_t size); MTL::Buffer* reuse_from_cache(size_t size);
@@ -42,13 +42,11 @@ class BufferCache {
void add_at_head(BufferHolder* to_add); void add_at_head(BufferHolder* to_add);
void remove_from_list(BufferHolder* to_remove); void remove_from_list(BufferHolder* to_remove);
MTL::Device* device_;
MTL::Heap* heap_{nullptr};
std::multimap<size_t, BufferHolder*> buffer_pool_; std::multimap<size_t, BufferHolder*> buffer_pool_;
BufferHolder* head_; BufferHolder* head_;
BufferHolder* tail_; BufferHolder* tail_;
size_t pool_size_; size_t pool_size_;
ResidencySet& residency_set_;
}; };
} // namespace } // namespace

View File

@@ -712,6 +712,65 @@ void winograd_conv_2D_gpu(
} }
} }
void depthwise_conv_2D_gpu(
const Stream& s,
metal::Device& d,
const array& in,
const array& wt,
array out,
const MLXConvParams<2>& conv_params) {
std::ostringstream kname;
kname << "depthwise_conv_2d_" << type_to_name(out);
std::string base_name = kname.str();
const int N = conv_params.N;
const int ker_h = conv_params.wS[0];
const int ker_w = conv_params.wS[1];
const int str_h = conv_params.str[0];
const int str_w = conv_params.str[1];
const int tc = 8;
const int tw = 8;
const int th = 4;
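// Assumed tiling: each threadgroup covers tc channels, tw output columns and
// th output rows, matching the (tc, tw, th) group dims used below.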
const bool do_flip = conv_params.flip;
metal::MTLFCList func_consts = {
{&ker_h, MTL::DataType::DataTypeInt, 00},
{&ker_w, MTL::DataType::DataTypeInt, 01},
{&str_h, MTL::DataType::DataTypeInt, 10},
{&str_w, MTL::DataType::DataTypeInt, 11},
{&th, MTL::DataType::DataTypeInt, 100},
{&tw, MTL::DataType::DataTypeInt, 101},
{&do_flip, MTL::DataType::DataTypeBool, 200},
};
// clang-format off
kname << "_ker_h_" << ker_h
<< "_ker_w_" << ker_w
<< "_str_h_" << str_h
<< "_str_w_" << str_w
<< "_tgp_h_" << th
<< "_tgp_w_" << tw
<< "_do_flip_" << (do_flip ? 't' : 'n'); // clang-format on
std::string hash_name = kname.str();
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(base_name, "mlx", hash_name, func_consts);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(in, 0);
compute_encoder.set_input_array(wt, 1);
compute_encoder.set_output_array(out, 2);
compute_encoder.set_bytes(conv_params, 3);
MTL::Size group_dims = MTL::Size(tc, tw, th);
MTL::Size grid_dims = MTL::Size(
conv_params.C / tc, conv_params.oS[1] / tw, (conv_params.oS[0] / th) * N);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
void conv_2D_gpu( void conv_2D_gpu(
const Stream& s, const Stream& s,
metal::Device& d, metal::Device& d,
@@ -754,11 +813,20 @@ void conv_2D_gpu(
bool is_kdil_one = conv_params.kdil[0] == 1 && conv_params.kdil[1] == 1; bool is_kdil_one = conv_params.kdil[0] == 1 && conv_params.kdil[1] == 1;
bool is_idil_one = conv_params.idil[0] == 1 && conv_params.idil[1] == 1; bool is_idil_one = conv_params.idil[0] == 1 && conv_params.idil[1] == 1;
if (groups > 1) { if (is_idil_one && groups > 1) {
const int C_per_group = conv_params.C / groups; const int C_per_group = conv_params.C / groups;
const int O_per_group = conv_params.O / groups; const int O_per_group = conv_params.O / groups;
if (is_idil_one && (C_per_group <= 4 || C_per_group % 16 == 0) && if (C_per_group == 1 && O_per_group == 1 && is_kdil_one &&
conv_params.wS[0] <= 7 && conv_params.wS[1] <= 7 &&
conv_params.str[0] <= 2 && conv_params.str[1] <= 2 &&
conv_params.oS[0] % 8 == 0 && conv_params.oS[1] % 8 == 0 &&
conv_params.wt_strides[1] == conv_params.wS[1] &&
conv_params.C % 16 == 0 && conv_params.C == conv_params.O) {
return depthwise_conv_2D_gpu(s, d, in, wt, out, conv_params);
}
if ((C_per_group <= 4 || C_per_group % 16 == 0) &&
(O_per_group <= 16 || O_per_group % 16 == 0)) { (O_per_group <= 16 || O_per_group % 16 == 0)) {
return implicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params); return implicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params);
} else { } else {

View File

@@ -55,7 +55,10 @@ std::pair<MTL::Library*, NS::Error*> load_library_from_path(
} }
#ifdef SWIFTPM_BUNDLE #ifdef SWIFTPM_BUNDLE
MTL::Library* try_load_bundle(MTL::Device* device, NS::URL* url) { MTL::Library* try_load_bundle(
MTL::Device* device,
NS::URL* url,
const std::string& lib_name) {
std::string bundle_path = std::string(url->fileSystemRepresentation()) + "/" + std::string bundle_path = std::string(url->fileSystemRepresentation()) + "/" +
SWIFTPM_BUNDLE + ".bundle"; SWIFTPM_BUNDLE + ".bundle";
auto bundle = NS::Bundle::alloc()->init( auto bundle = NS::Bundle::alloc()->init(
@@ -63,8 +66,8 @@ MTL::Library* try_load_bundle(MTL::Device* device, NS::URL* url) {
if (bundle != nullptr) { if (bundle != nullptr) {
std::string resource_path = std::string resource_path =
std::string(bundle->resourceURL()->fileSystemRepresentation()) + "/" + std::string(bundle->resourceURL()->fileSystemRepresentation()) + "/" +
"default.metallib"; lib_name + ".metallib" auto [lib, error] =
auto [lib, error] = load_library_from_path(device, resource_path.c_str()); load_library_from_path(device, resource_path.c_str());
if (lib) { if (lib) {
return lib; return lib;
} }
@@ -73,51 +76,124 @@ MTL::Library* try_load_bundle(MTL::Device* device, NS::URL* url) {
} }
#endif #endif
// Firstly, search for the metallib in the same path as this binary
std::pair<MTL::Library*, NS::Error*> load_colocated_library(
MTL::Device* device,
const std::string& lib_name) {
std::string lib_path = get_colocated_mtllib_path(lib_name);
if (lib_path.size() != 0) {
return load_library_from_path(device, lib_path.c_str());
}
return {nullptr, nullptr};
}
std::pair<MTL::Library*, NS::Error*> load_swiftpm_library(
MTL::Device* device,
const std::string& lib_name) {
#ifdef SWIFTPM_BUNDLE
MTL::Library* library =
try_load_bundle(device, NS::Bundle::mainBundle()->bundleURL(), lib_name);
if (library != nullptr) {
return {library, nullptr};
}
auto bundles = NS::Bundle::allBundles();
for (int i = 0, c = (int)bundles->count(); i < c; i++) {
auto bundle = reinterpret_cast<NS::Bundle*>(bundles->object(i));
library = try_load_bundle(device, bundle->resourceURL(), lib_name);
if (library != nullptr) {
return {library, nullptr};
}
}
#endif
return {nullptr, nullptr};
}
MTL::Library* load_default_library(MTL::Device* device) {
NS::Error *error1, *error2, *error3;
MTL::Library* lib;
// First try the colocated mlx.metallib
std::tie(lib, error1) = load_colocated_library(device, "mlx");
if (lib) {
return lib;
}
// Then try default.metallib in a SwiftPM bundle if we have one
std::tie(lib, error2) = load_swiftpm_library(device, "default");
if (lib) {
return lib;
}
// Finally try default_mtllib_path
std::tie(lib, error3) = load_library_from_path(device, default_mtllib_path);
if (!lib) {
std::ostringstream msg;
msg << "Failed to load the default metallib. ";
if (error1 != nullptr) {
msg << error1->localizedDescription()->utf8String() << " ";
}
if (error2 != nullptr) {
msg << error2->localizedDescription()->utf8String() << " ";
}
if (error3 != nullptr) {
msg << error3->localizedDescription()->utf8String() << " ";
}
throw std::runtime_error(msg.str());
}
return lib;
}
MTL::Library* load_library( MTL::Library* load_library(
MTL::Device* device, MTL::Device* device,
const std::string& lib_name = "mlx", const std::string& lib_name,
const char* lib_path = default_mtllib_path) { const std::string& lib_path) {
// Firstly, search for the metallib in the same path as this binary // We have been given a path that ends in metallib so try to load it
std::string first_path = get_colocated_mtllib_path(lib_name); if (lib_path.size() > 9 &&
if (first_path.size() != 0) { std::equal(lib_path.end() - 9, lib_path.end(), ".metallib")) {
auto [lib, error] = load_library_from_path(device, first_path.c_str()); auto [lib, error] = load_library_from_path(device, lib_path.c_str());
if (!lib) {
std::ostringstream msg;
msg << "Failed to load the metallib from <" << lib_path << "> with error "
<< error->localizedDescription()->utf8String();
throw std::runtime_error(msg.str());
}
}
// We have been given a path so try to load from lib_path / lib_name.metallib
if (lib_path.size() > 0) {
std::string full_path = lib_path + "/" + lib_name + ".metallib";
auto [lib, error] = load_library_from_path(device, full_path.c_str());
if (!lib) {
std::ostringstream msg;
msg << "Failed to load the metallib from <" << full_path
<< "> with error " << error->localizedDescription()->utf8String();
throw std::runtime_error(msg.str());
}
}
// Try to load the colocated library
{
auto [lib, error] = load_colocated_library(device, lib_name);
if (lib) { if (lib) {
return lib; return lib;
} }
} }
#ifdef SWIFTPM_BUNDLE // Try to load the library from swiftpm
// try to load from a swiftpm resource bundle -- scan the available bundles to
// find one that contains the named bundle
{ {
MTL::Library* library = auto [lib, error] = load_swiftpm_library(device, lib_name);
try_load_bundle(device, NS::Bundle::mainBundle()->bundleURL()); if (lib) {
if (library != nullptr) { return lib;
return library;
}
auto bundles = NS::Bundle::allBundles();
for (int i = 0, c = (int)bundles->count(); i < c; i++) {
auto bundle = reinterpret_cast<NS::Bundle*>(bundles->object(i));
library = try_load_bundle(device, bundle->resourceURL());
if (library != nullptr) {
return library;
}
} }
} }
#endif
// Couldn't find it so let's load it from default_mtllib_path std::ostringstream msg;
{ msg << "Failed to load the metallib " << lib_name << ".metallib. "
auto [lib, error] = load_library_from_path(device, lib_path); << "We attempted to load it from <" << get_colocated_mtllib_path(lib_name)
if (!lib) { << ">";
std::ostringstream msg; #ifdef SWIFTPM_BUNDLE
msg << error->localizedDescription()->utf8String() << "\n" msg << " and from the Swift PM bundle.";
<< "Failed to load device library from <" << lib_path << ">" #endif
<< " or <" << first_path << ">."; throw std::runtime_error(msg.str());
throw std::runtime_error(msg.str());
}
return lib;
}
} }
} // namespace } // namespace
@@ -210,7 +286,7 @@ void CommandEncoder::barrier() {
Device::Device() { Device::Device() {
auto pool = new_scoped_memory_pool(); auto pool = new_scoped_memory_pool();
device_ = load_device(); device_ = load_device();
library_map_ = {{"mlx", load_library(device_)}}; library_map_ = {{"mlx", load_default_library(device_)}};
arch_ = std::string(device_->architecture()->name()->utf8String()); arch_ = std::string(device_->architecture()->name()->utf8String());
auto arch = arch_.back(); auto arch = arch_.back();
switch (arch) { switch (arch) {

View File

@@ -189,15 +189,7 @@ class Device {
void register_library( void register_library(
const std::string& lib_name, const std::string& lib_name,
const std::string& lib_path); const std::string& lib_path = "");
// Note, this should remain in the header so that it is not dynamically
// linked
void register_library(const std::string& lib_name) {
if (auto it = library_map_.find(lib_name); it == library_map_.end()) {
register_library(lib_name, get_colocated_mtllib_path(lib_name));
}
}
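For illustration only, a minimal sketch of how a caller might use the reworked registration API. The library name and path below are hypothetical, and register_library is assumed to forward to the load_library shown above (empty path: colocated <name>.metallib, then the SwiftPM bundle; a path ending in .metallib: loaded directly).
#include "mlx/backend/metal/device.h"
// Hedged sketch, not part of the diff.
void register_my_kernels(const mlx::core::Stream& s) {
  auto& d = mlx::core::metal::device(s.device);
  // Name only: lib_path defaults to "", so the loader falls back to the
  // colocated metallib and then any SwiftPM bundle.
  d.register_library("my_custom_kernels");
  // Or point at an explicit .metallib (hypothetical path), which is loaded
  // directly because the path ends in ".metallib".
  d.register_library("my_custom_kernels", "/path/to/my_custom_kernels.metallib");
}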
MTL::Library* get_library( MTL::Library* get_library(
const std::string& name, const std::string& name,

View File

@@ -24,10 +24,6 @@ void Event::wait() {
} }
} }
void Event::signal() {
static_cast<MTL::SharedEvent*>(event_.get())->setSignaledValue(value());
}
void Event::wait(Stream stream) { void Event::wait(Stream stream) {
if (stream.device == Device::cpu) { if (stream.device == Device::cpu) {
scheduler::enqueue(stream, [*this]() mutable { wait(); }); scheduler::enqueue(stream, [*this]() mutable { wait(); });
@@ -42,7 +38,9 @@ void Event::wait(Stream stream) {
void Event::signal(Stream stream) { void Event::signal(Stream stream) {
if (stream.device == Device::cpu) { if (stream.device == Device::cpu) {
scheduler::enqueue(stream, [*this]() mutable { signal(); }); scheduler::enqueue(stream, [*this]() mutable {
static_cast<MTL::SharedEvent*>(event_.get())->setSignaledValue(value());
});
} else { } else {
auto& d = metal::device(stream.device); auto& d = metal::device(stream.device);
d.end_encoding(stream.index); d.end_encoding(stream.index);

View File

@@ -356,20 +356,14 @@ void multi_upload_bluestein_fft(
bool inverse, bool inverse,
bool real, bool real,
FFTPlan& plan, FFTPlan& plan,
std::vector<array> copies, std::vector<array>& copies,
const Stream& s) { const Stream& s) {
// TODO(alexbarron) Implement fused kernels for multi upload Bluestein's // TODO(alexbarron) Implement fused kernels for multi upload Bluestein's
// algorithm // algorithm
int n = inverse ? out.shape(axis) : in.shape(axis); int n = inverse ? out.shape(axis) : in.shape(axis);
auto [w_k, w_q] = compute_bluestein_constants(n, plan.bluestein_n); auto [w_k, w_q] = compute_bluestein_constants(n, plan.bluestein_n);
copies.push_back(w_k);
// Broadcast w_q and w_k to the batch size copies.push_back(w_q);
Strides b_strides(in.ndim(), 0);
b_strides[axis] = 1;
array w_k_broadcast({}, complex64, nullptr, {});
array w_q_broadcast({}, complex64, nullptr, {});
w_k_broadcast.copy_shared_buffer(w_k, b_strides, {}, w_k.data_size());
w_q_broadcast.copy_shared_buffer(w_q, b_strides, {}, w_q.data_size());
auto temp_shape = inverse ? out.shape() : in.shape(); auto temp_shape = inverse ? out.shape() : in.shape();
array temp(temp_shape, complex64, nullptr, {}); array temp(temp_shape, complex64, nullptr, {});
@@ -378,13 +372,13 @@ void multi_upload_bluestein_fft(
if (real && !inverse) { if (real && !inverse) {
// Convert float32->complex64 // Convert float32->complex64
copy_gpu(in, temp, CopyType::General, s); copy_gpu(in, temp, CopyType::General, s);
copies.push_back(temp);
} else if (real && inverse) { } else if (real && inverse) {
int back_offset = n % 2 == 0 ? 2 : 1; int back_offset = n % 2 == 0 ? 2 : 1;
auto slice_shape = in.shape(); auto slice_shape = in.shape();
slice_shape[axis] -= back_offset; slice_shape[axis] -= back_offset;
array slice_temp(slice_shape, complex64, nullptr, {}); array slice_temp(slice_shape, complex64, nullptr, {});
array conj_temp(in.shape(), complex64, nullptr, {}); array conj_temp(in.shape(), complex64, nullptr, {});
copies.push_back(slice_temp);
copies.push_back(conj_temp); copies.push_back(conj_temp);
Shape rstarts(in.ndim(), 0); Shape rstarts(in.ndim(), 0);
@@ -394,19 +388,28 @@ void multi_upload_bluestein_fft(
unary_op_gpu({in}, conj_temp, "Conjugate", s); unary_op_gpu({in}, conj_temp, "Conjugate", s);
slice_gpu(in, slice_temp, rstarts, rstrides, s); slice_gpu(in, slice_temp, rstarts, rstrides, s);
concatenate_gpu({conj_temp, slice_temp}, temp, (int)axis, s); concatenate_gpu({conj_temp, slice_temp}, temp, (int)axis, s);
copies.push_back(temp);
} else if (inverse) { } else if (inverse) {
unary_op_gpu({in}, temp, "Conjugate", s); unary_op_gpu({in}, temp, "Conjugate", s);
copies.push_back(temp);
} else { } else {
temp.copy_shared_buffer(in); temp.copy_shared_buffer(in);
} }
Strides b_strides(in.ndim(), 0);
b_strides[axis] = 1;
array w_k_broadcast(temp.shape(), complex64, nullptr, {});
w_k_broadcast.copy_shared_buffer(w_k, b_strides, {}, w_k.data_size());
binary_op_gpu({temp, w_k_broadcast}, temp1, "Multiply", s); binary_op_gpu({temp, w_k_broadcast}, temp1, "Multiply", s);
std::vector<std::pair<int, int>> pads; std::vector<std::pair<int, int>> pads;
auto padded_shape = out.shape(); auto padded_shape = out.shape();
padded_shape[axis] = plan.bluestein_n; padded_shape[axis] = plan.bluestein_n;
array pad_temp(padded_shape, complex64, nullptr, {}); array pad_temp(padded_shape, complex64, nullptr, {});
pad_gpu(temp1, array(complex64_t{0.0f, 0.0f}), pad_temp, {(int)axis}, {0}, s); auto zero = array(complex64_t{0.0f, 0.0f});
copies.push_back(zero);
pad_gpu(temp1, zero, pad_temp, {(int)axis}, {0}, s);
copies.push_back(pad_temp);
array pad_temp1(padded_shape, complex64, nullptr, {}); array pad_temp1(padded_shape, complex64, nullptr, {});
fft_op( fft_op(
@@ -418,7 +421,10 @@ void multi_upload_bluestein_fft(
FourStepParams(), FourStepParams(),
/*inplace=*/false, /*inplace=*/false,
s); s);
copies.push_back(pad_temp1);
array w_q_broadcast(pad_temp1.shape(), complex64, nullptr, {});
w_q_broadcast.copy_shared_buffer(w_q, b_strides, {}, w_q.data_size());
binary_op_gpu_inplace({pad_temp1, w_q_broadcast}, pad_temp, "Multiply", s); binary_op_gpu_inplace({pad_temp1, w_q_broadcast}, pad_temp, "Multiply", s);
fft_op( fft_op(
@@ -435,9 +441,11 @@ void multi_upload_bluestein_fft(
Shape starts(in.ndim(), 0); Shape starts(in.ndim(), 0);
Shape strides(in.ndim(), 1); Shape strides(in.ndim(), 1);
starts[axis] = plan.bluestein_n - offset - n; starts[axis] = plan.bluestein_n - offset - n;
slice_gpu(pad_temp1, temp, starts, strides, s);
binary_op_gpu_inplace({temp, w_k_broadcast}, temp1, "Multiply", s); array temp2(temp_shape, complex64, nullptr, {});
slice_gpu(pad_temp1, temp2, starts, strides, s);
binary_op_gpu_inplace({temp2, w_k_broadcast}, temp1, "Multiply", s);
if (real && !inverse) { if (real && !inverse) {
Shape rstarts(in.ndim(), 0); Shape rstarts(in.ndim(), 0);
@@ -449,26 +457,21 @@ void multi_upload_bluestein_fft(
array temp_float(out.shape(), out.dtype(), nullptr, {}); array temp_float(out.shape(), out.dtype(), nullptr, {});
copies.push_back(temp_float); copies.push_back(temp_float);
copies.push_back(inv_n); copies.push_back(inv_n);
copies.push_back(temp1);
copy_gpu(temp1, temp_float, CopyType::General, s); copy_gpu(temp1, temp_float, CopyType::General, s);
binary_op_gpu({temp_float, inv_n}, out, "Multiply", s); binary_op_gpu({temp_float, inv_n}, out, "Multiply", s);
} else if (inverse) { } else if (inverse) {
auto inv_n = array({1.0f / n}, {1}, complex64); auto inv_n = array({1.0f / n}, {1}, complex64);
unary_op_gpu({temp1}, temp, "Conjugate", s); array temp3(temp_shape, complex64, nullptr, {});
binary_op_gpu({temp, inv_n}, out, "Multiply", s); unary_op_gpu({temp1}, temp3, "Conjugate", s);
binary_op_gpu({temp3, inv_n}, out, "Multiply", s);
copies.push_back(inv_n); copies.push_back(inv_n);
copies.push_back(temp1);
copies.push_back(temp3);
} else { } else {
out.copy_shared_buffer(temp1); out.copy_shared_buffer(temp1);
} }
copies.push_back(w_k);
copies.push_back(w_q);
copies.push_back(w_k_broadcast);
copies.push_back(w_q_broadcast);
copies.push_back(temp);
copies.push_back(temp1);
copies.push_back(pad_temp);
copies.push_back(pad_temp1);
} }
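For reference, a brief note on the identity this routine relies on (standard Bluestein chirp-z; the exact contents of w_k and w_q come from compute_bluestein_constants and are assumed here to be the chirp and its padded transform):
\[
X_k = \sum_{n=0}^{N-1} x_n e^{-2\pi i nk/N}
    = e^{-\pi i k^2/N} \sum_{n=0}^{N-1} \left(x_n e^{-\pi i n^2/N}\right) e^{\pi i (k-n)^2/N},
\quad\text{using } nk = \tfrac{1}{2}\left(n^2 + k^2 - (k-n)^2\right).
\]
So the length-N DFT becomes: multiply by a chirp, linearly convolve with the conjugate chirp via FFTs of length bluestein_n (at least 2N - 1), multiply by the chirp again. That matches the multiply / pad / FFT / multiply / inverse-FFT / slice / multiply sequence above, with the extra conjugations and the 1/n scale handling the inverse and real cases.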
void four_step_fft( void four_step_fft(
@@ -478,8 +481,9 @@ void four_step_fft(
bool inverse, bool inverse,
bool real, bool real,
FFTPlan& plan, FFTPlan& plan,
std::vector<array> copies, std::vector<array>& copies,
const Stream& s) { const Stream& s,
bool in_place) {
auto& d = metal::device(s.device); auto& d = metal::device(s.device);
if (plan.bluestein_n == -1) { if (plan.bluestein_n == -1) {
@@ -492,7 +496,14 @@ void four_step_fft(
in, temp, axis, inverse, real, four_step_params, /*inplace=*/false, s); in, temp, axis, inverse, real, four_step_params, /*inplace=*/false, s);
four_step_params.first_step = false; four_step_params.first_step = false;
fft_op( fft_op(
temp, out, axis, inverse, real, four_step_params, /*inplace=*/false, s); temp,
out,
axis,
inverse,
real,
four_step_params,
/*inplace=*/in_place,
s);
copies.push_back(temp); copies.push_back(temp);
} else { } else {
multi_upload_bluestein_fft(in, out, axis, inverse, real, plan, copies, s); multi_upload_bluestein_fft(in, out, axis, inverse, real, plan, copies, s);
@@ -574,7 +585,7 @@ void fft_op(
auto plan = plan_fft(n); auto plan = plan_fft(n);
if (plan.four_step) { if (plan.four_step) {
four_step_fft(in, out, axis, inverse, real, plan, copies, s); four_step_fft(in, out, axis, inverse, real, plan, copies, s, inplace);
d.add_temporaries(std::move(copies), s.index); d.add_temporaries(std::move(copies), s.index);
return; return;
} }

View File

@@ -1,9 +0,0 @@
// Copyright © 2024 Apple Inc.
constexpr std::string_view arange_kernels = R"(
template [[host_name("{0}")]] [[kernel]] void arange<{1}>(
constant const {1}& start,
constant const {1}& step,
device {1}* out,
uint index [[thread_position_in_grid]]);
)";

View File

@@ -20,6 +20,7 @@ const char* copy();
const char* fft(); const char* fft();
const char* gather_axis(); const char* gather_axis();
const char* hadamard(); const char* hadamard();
const char* logsumexp();
const char* quantized(); const char* quantized();
const char* ternary(); const char* ternary();
const char* scan(); const char* scan();
@@ -32,6 +33,7 @@ const char* gemm();
const char* steel_gemm_fused(); const char* steel_gemm_fused();
const char* steel_gemm_masked(); const char* steel_gemm_masked();
const char* steel_gemm_splitk(); const char* steel_gemm_splitk();
const char* steel_gemm_gather();
const char* conv(); const char* conv();
const char* steel_conv(); const char* steel_conv();
const char* steel_conv_general(); const char* steel_conv_general();

View File

@@ -1,23 +0,0 @@
// Copyright © 2024 Apple Inc.
constexpr std::string_view softmax_kernels = R"(
template [[host_name("block_{0}")]] [[kernel]] void
softmax_single_row<{1}, {2}>(
const device {1}* in,
device {1}* out,
constant int& axis_size,
uint gid [[thread_position_in_grid]],
uint _lid [[thread_position_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
template [[host_name("looped_{0}")]] [[kernel]] void
softmax_looped<{1}, {2}>(
const device {1}* in,
device {1}* out,
constant int& axis_size,
uint gid [[threadgroup_position_in_grid]],
uint lid [[thread_position_in_threadgroup]],
uint lsize [[threads_per_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
)";

View File

@@ -1,8 +1,6 @@
// Copyright © 2024 Apple Inc. // Copyright © 2024 Apple Inc.
#include "mlx/backend/common/compiled.h" #include "mlx/backend/common/compiled.h"
#include "mlx/backend/metal/jit/arange.h"
#include "mlx/backend/metal/jit/includes.h" #include "mlx/backend/metal/jit/includes.h"
#include "mlx/backend/metal/jit/softmax.h"
#include "mlx/backend/metal/kernels.h" #include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/utils.h" #include "mlx/backend/metal/utils.h"
@@ -21,13 +19,11 @@ MTL::ComputePipelineState* get_arange_kernel(
const std::string& kernel_name, const std::string& kernel_name,
const array& out) { const array& out) {
auto lib = d.get_library(kernel_name, [&]() { auto lib = d.get_library(kernel_name, [&]() {
std::ostringstream kernel_source; std::string kernel_source = metal::utils();
kernel_source << metal::utils() << metal::arange() kernel_source += metal::arange();
<< fmt::format( kernel_source += get_template_definition(
arange_kernels, kernel_name, "arange", get_type_string(out.dtype()));
kernel_name, return kernel_source;
get_type_string(out.dtype()));
return kernel_source.str();
}); });
return d.get_kernel(kernel_name, lib); return d.get_kernel(kernel_name, lib);
} }
@@ -259,14 +255,34 @@ MTL::ComputePipelineState* get_softmax_kernel(
const array& out) { const array& out) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1); std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name, [&] { auto lib = d.get_library(lib_name, [&] {
std::ostringstream kernel_source; std::string kernel_source = metal::utils();
kernel_source << metal::utils() << metal::softmax() auto in_type = get_type_string(out.dtype());
<< fmt::format( auto acc_type = get_type_string(precise ? float32 : out.dtype());
softmax_kernels, kernel_source += metal::softmax();
lib_name, kernel_source += get_template_definition(
get_type_string(out.dtype()), "block_" + lib_name, "softmax_single_row", in_type, acc_type);
get_type_string(precise ? float32 : out.dtype())); kernel_source += get_template_definition(
return kernel_source.str(); "looped_" + lib_name, "softmax_looped", in_type, acc_type);
return kernel_source;
});
return d.get_kernel(kernel_name, lib);
}
MTL::ComputePipelineState* get_logsumexp_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& out) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name, [&] {
auto t_str = get_type_string(out.dtype());
std::string kernel_source;
kernel_source = metal::utils();
kernel_source += metal::logsumexp();
kernel_source +=
get_template_definition("block_" + lib_name, "logsumexp", t_str);
kernel_source += get_template_definition(
"looped_" + lib_name, "logsumexp_looped", t_str);
return kernel_source;
}); });
return d.get_kernel(kernel_name, lib); return d.get_kernel(kernel_name, lib);
} }
@@ -568,6 +584,44 @@ MTL::ComputePipelineState* get_steel_gemm_masked_kernel(
return d.get_kernel(kernel_name, lib); return d.get_kernel(kernel_name, lib);
} }
MTL::ComputePipelineState* get_steel_gemm_gather_kernel(
metal::Device& d,
const std::string& kernel_name,
const std::string& hash_name,
const metal::MTLFCList& func_consts,
const array& out,
bool transpose_a,
bool transpose_b,
int bm,
int bn,
int bk,
int wm,
int wn,
bool rhs) {
const auto& lib_name = kernel_name;
auto lib = d.get_library(lib_name, [&]() {
std::string kernel_source;
concatenate(
kernel_source,
metal::utils(),
metal::gemm(),
metal::steel_gemm_gather(),
get_template_definition(
lib_name,
rhs ? "gather_mm_rhs" : "gather_mm",
get_type_string(out.dtype()),
bm,
bn,
bk,
wm,
wn,
transpose_a,
transpose_b));
return kernel_source;
});
return d.get_kernel(kernel_name, lib, hash_name, func_consts);
}
MTL::ComputePipelineState* get_gemv_masked_kernel( MTL::ComputePipelineState* get_gemv_masked_kernel(
metal::Device& d, metal::Device& d,
const std::string& kernel_name, const std::string& kernel_name,
@@ -698,4 +752,43 @@ MTL::ComputePipelineState* get_quantized_kernel(
return d.get_kernel(kernel_name, lib); return d.get_kernel(kernel_name, lib);
} }
MTL::ComputePipelineState* get_gather_qmm_kernel(
metal::Device& d,
const std::string& kernel_name,
const std::string& hash_name,
const metal::MTLFCList& func_consts,
const array& x,
int group_size,
int bits,
int bm,
int bn,
int bk,
int wm,
int wn,
bool transpose) {
const auto& lib_name = kernel_name;
auto lib = d.get_library(lib_name, [&]() {
std::string kernel_source;
concatenate(
kernel_source,
metal::utils(),
metal::gemm(),
metal::quantized(),
get_template_definition(
lib_name,
"gather_qmm_rhs",
get_type_string(x.dtype()),
group_size,
bits,
bm,
bn,
bk,
wm,
wn,
transpose));
return kernel_source;
});
return d.get_kernel(kernel_name, lib, hash_name, func_consts);
}
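For illustration, a hedged sketch of how the func_consts passed to these getters might be assembled on the host, mirroring the MTLFCList pattern used for the conv dispatch earlier in this diff. M, N, K and the block sizes bm, bn, bk are hypothetical host-side values; the 200/201/202 indices match the align_M/align_N/align_K function constants declared for the quantized kernels below. Encoding the same flags into the hash name (as the conv kernel does with kname/hash_name) is assumed to keep differently specialized pipelines cached separately.
// Hedged sketch, not part of the diff.
const bool align_m = (M % bm) == 0;
const bool align_n = (N % bn) == 0;
const bool align_k = (K % bk) == 0;
metal::MTLFCList func_consts = {
    {&align_m, MTL::DataType::DataTypeBool, 200},
    {&align_n, MTL::DataType::DataTypeBool, 201},
    {&align_k, MTL::DataType::DataTypeBool, 202},
};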
} // namespace mlx::core } // namespace mlx::core

View File

@@ -59,6 +59,11 @@ MTL::ComputePipelineState* get_softmax_kernel(
bool precise, bool precise,
const array& out); const array& out);
MTL::ComputePipelineState* get_logsumexp_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& out);
MTL::ComputePipelineState* get_scan_kernel( MTL::ComputePipelineState* get_scan_kernel(
metal::Device& d, metal::Device& d,
const std::string& kernel_name, const std::string& kernel_name,
@@ -155,6 +160,21 @@ MTL::ComputePipelineState* get_steel_gemm_masked_kernel(
bool mn_aligned, bool mn_aligned,
bool k_aligned); bool k_aligned);
MTL::ComputePipelineState* get_steel_gemm_gather_kernel(
metal::Device& d,
const std::string& kernel_name,
const std::string& hash_name,
const metal::MTLFCList& func_consts,
const array& out,
bool transpose_a,
bool transpose_b,
int bm,
int bn,
int bk,
int wm,
int wn,
bool rhs);
MTL::ComputePipelineState* get_steel_conv_kernel( MTL::ComputePipelineState* get_steel_conv_kernel(
metal::Device& d, metal::Device& d,
const std::string& kernel_name, const std::string& kernel_name,
@@ -204,6 +224,21 @@ MTL::ComputePipelineState* get_quantized_kernel(
const std::string& kernel_name, const std::string& kernel_name,
const std::string& template_def); const std::string& template_def);
MTL::ComputePipelineState* get_gather_qmm_kernel(
metal::Device& d,
const std::string& kernel_name,
const std::string& hash_name,
const metal::MTLFCList& func_consts,
const array& x,
int group_size,
int bits,
int bm,
int bn,
int bk,
int wm,
int wn,
bool transpose);
// Create a GPU kernel template definition for JIT compilation // Create a GPU kernel template definition for JIT compilation
template <typename... Args> template <typename... Args>
std::string std::string

View File

@@ -13,6 +13,10 @@ function(build_kernel_base TARGET SRCFILE DEPS)
if(MLX_METAL_DEBUG) if(MLX_METAL_DEBUG)
set(METAL_FLAGS ${METAL_FLAGS} -gline-tables-only -frecord-sources) set(METAL_FLAGS ${METAL_FLAGS} -gline-tables-only -frecord-sources)
endif() endif()
if(NOT CMAKE_OSX_DEPLOYMENT_TARGET STREQUAL "")
set(METAL_FLAGS ${METAL_FLAGS}
"-mmacosx-version-min=${CMAKE_OSX_DEPLOYMENT_TARGET}")
endif()
if(MLX_METAL_VERSION GREATER_EQUAL 310) if(MLX_METAL_VERSION GREATER_EQUAL 310)
set(VERSION_INCLUDES set(VERSION_INCLUDES
${PROJECT_SOURCE_DIR}/mlx/backend/metal/kernels/metal_3_1) ${PROJECT_SOURCE_DIR}/mlx/backend/metal/kernels/metal_3_1)
@@ -65,6 +69,7 @@ set(STEEL_HEADERS
steel/gemm/loader.h steel/gemm/loader.h
steel/gemm/transforms.h steel/gemm/transforms.h
steel/gemm/kernels/steel_gemm_fused.h steel/gemm/kernels/steel_gemm_fused.h
steel/gemm/kernels/steel_gemm_gather.h
steel/gemm/kernels/steel_gemm_masked.h steel/gemm/kernels/steel_gemm_masked.h
steel/gemm/kernels/steel_gemm_splitk.h steel/gemm/kernels/steel_gemm_splitk.h
steel/utils/type_traits.h steel/utils/type_traits.h
@@ -105,12 +110,14 @@ if(NOT MLX_METAL_JIT)
build_kernel(quantized quantized.h ${STEEL_HEADERS}) build_kernel(quantized quantized.h ${STEEL_HEADERS})
build_kernel(scan scan.h) build_kernel(scan scan.h)
build_kernel(softmax softmax.h) build_kernel(softmax softmax.h)
build_kernel(logsumexp logsumexp.h)
build_kernel(sort sort.h) build_kernel(sort sort.h)
build_kernel(ternary ternary.h ternary_ops.h) build_kernel(ternary ternary.h ternary_ops.h)
build_kernel(unary unary.h unary_ops.h) build_kernel(unary unary.h unary_ops.h)
build_kernel(steel/conv/kernels/steel_conv ${STEEL_HEADERS}) build_kernel(steel/conv/kernels/steel_conv ${STEEL_HEADERS})
build_kernel(steel/conv/kernels/steel_conv_general ${STEEL_HEADERS}) build_kernel(steel/conv/kernels/steel_conv_general ${STEEL_HEADERS})
build_kernel(steel/gemm/kernels/steel_gemm_fused ${STEEL_HEADERS}) build_kernel(steel/gemm/kernels/steel_gemm_fused ${STEEL_HEADERS})
build_kernel(steel/gemm/kernels/steel_gemm_gather ${STEEL_HEADERS})
build_kernel(steel/gemm/kernels/steel_gemm_masked ${STEEL_HEADERS}) build_kernel(steel/gemm/kernels/steel_gemm_masked ${STEEL_HEADERS})
build_kernel(steel/gemm/kernels/steel_gemm_splitk ${STEEL_HEADERS}) build_kernel(steel/gemm/kernels/steel_gemm_splitk ${STEEL_HEADERS})
build_kernel(gemv_masked steel/utils.h) build_kernel(gemv_masked steel/utils.h)

View File

@@ -5,11 +5,7 @@
#include "mlx/backend/metal/kernels/arange.h" #include "mlx/backend/metal/kernels/arange.h"
#define instantiate_arange(tname, type) \ #define instantiate_arange(tname, type) \
template [[host_name("arange" #tname)]] [[kernel]] void arange<type>( \ instantiate_kernel("arange" #tname, arange, type)
constant const type& start, \
constant const type& step, \
device type* out, \
uint index [[thread_position_in_grid]]);
instantiate_arange(uint8, uint8_t) instantiate_arange(uint8, uint8_t)
instantiate_arange(uint16, uint16_t) instantiate_arange(uint16, uint16_t)

View File

@@ -275,6 +275,128 @@ instantiate_naive_conv_2d_blocks(float32, float);
instantiate_naive_conv_2d_blocks(float16, half); instantiate_naive_conv_2d_blocks(float16, half);
instantiate_naive_conv_2d_blocks(bfloat16, bfloat16_t); instantiate_naive_conv_2d_blocks(bfloat16, bfloat16_t);
///////////////////////////////////////////////////////////////////////////////
/// Depthwise convolution kernels
///////////////////////////////////////////////////////////////////////////////
constant int ker_h [[function_constant(00)]];
constant int ker_w [[function_constant(01)]];
constant int str_h [[function_constant(10)]];
constant int str_w [[function_constant(11)]];
constant int tgp_h [[function_constant(100)]];
constant int tgp_w [[function_constant(101)]];
constant bool do_flip [[function_constant(200)]];
constant int span_h = tgp_h * str_h + ker_h - 1;
constant int span_w = tgp_w * str_w + ker_w - 1;
constant int span_hw = span_h * span_w;
template <typename T>
[[kernel]] void depthwise_conv_2d(
const device T* in [[buffer(0)]],
const device T* wt [[buffer(1)]],
device T* out [[buffer(2)]],
const constant MLXConvParams<2>& params [[buffer(3)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint3 gid [[thread_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
constexpr int tc = 8;
constexpr int tw = 8;
constexpr int th = 4;
constexpr int c_per_thr = 8;
constexpr int TGH = th * 2 + 6;
constexpr int TGW = tw * 2 + 6;
constexpr int TGC = tc;
threadgroup T ins[TGH * TGW * TGC];
const int n_tgblocks_h = params.oS[0] / th;
const int n = tid.z / n_tgblocks_h;
const int tghid = tid.z % n_tgblocks_h;
const int oh = tghid * th + lid.z;
const int ow = gid.y;
const int c = gid.x;
in += n * params.in_strides[0];
// Load in
{
constexpr int n_threads = th * tw * tc;
const int tg_oh = (tghid * th) * str_h - params.pad[0];
const int tg_ow = (tid.y * tw) * str_w - params.pad[1];
const int tg_c = tid.x * tc;
const int thread_idx = simd_gid * 32 + simd_lid;
constexpr int thr_per_hw = tc / c_per_thr;
constexpr int hw_per_group = n_threads / thr_per_hw;
const int thr_c = thread_idx % thr_per_hw;
const int thr_hw = thread_idx / thr_per_hw;
for (int hw = thr_hw; hw < span_hw; hw += hw_per_group) {
const int h = hw / span_w;
const int w = hw % span_w;
const int ih = tg_oh + h;
const int iw = tg_ow + w;
const int in_s_offset = h * span_w * TGC + w * TGC;
if (ih >= 0 && ih < params.iS[0] && iw >= 0 && iw < params.iS[1]) {
const auto in_load =
in + ih * params.in_strides[1] + iw * params.in_strides[2] + tg_c;
MLX_MTL_PRAGMA_UNROLL
for (int cc = 0; cc < c_per_thr; ++cc) {
ins[in_s_offset + c_per_thr * thr_c + cc] =
in_load[c_per_thr * thr_c + cc];
}
} else {
MLX_MTL_PRAGMA_UNROLL
for (int cc = 0; cc < c_per_thr; ++cc) {
ins[in_s_offset + c_per_thr * thr_c + cc] = T(0);
}
}
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
wt += c * params.wt_strides[0];
const auto ins_ptr =
&ins[lid.z * str_h * span_w * TGC + lid.y * str_w * TGC + lid.x];
float o = 0.;
for (int h = 0; h < ker_h; ++h) {
for (int w = 0; w < ker_w; ++w) {
int wt_h = h;
int wt_w = w;
if (do_flip) {
wt_h = ker_h - h - 1;
wt_w = ker_w - w - 1;
}
auto inv = ins_ptr[h * span_w * TGC + w * TGC];
auto wtv = wt[wt_h * ker_w + wt_w];
o += inv * wtv;
}
}
threadgroup_barrier(mem_flags::mem_none);
out += n * params.out_strides[0] + oh * params.out_strides[1] +
ow * params.out_strides[2];
out[c] = static_cast<T>(o);
}
#define instantiate_depthconv2d(iname, itype) \
instantiate_kernel("depthwise_conv_2d_" #iname, depthwise_conv_2d, itype)
instantiate_depthconv2d(float32, float);
instantiate_depthconv2d(float16, half);
instantiate_depthconv2d(bfloat16, bfloat16_t);
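To make the intended computation explicit, here is a hedged host-side reference of what depthwise_conv_2d produces for NHWC inputs (one filter per channel, optional kernel flip). It is a plain restatement of the kernel above, not part of the diff; the struct is a simplified stand-in for MLXConvParams and the weights are assumed contiguous as [C][kH][kW].
#include <vector>

// Simplified stand-in for the MLXConvParams<2> fields used here.
struct DepthwiseParams {
  int N, H, W, C;   // input batch, height, width, channels
  int oH, oW;       // output height, width
  int kH, kW;       // kernel height, width
  int strH, strW;   // strides
  int padH, padW;   // padding
  bool flip;        // flip the kernel (convolution vs. correlation)
};

// out[n, oh, ow, c] = sum_{kh, kw} in[n, oh*strH + kh - padH,
//                                     ow*strW + kw - padW, c] * wt[c, kh, kw]
void depthwise_conv_2d_reference(
    const std::vector<float>& in,
    const std::vector<float>& wt,
    std::vector<float>& out,
    const DepthwiseParams& p) {
  for (int n = 0; n < p.N; ++n)
    for (int oh = 0; oh < p.oH; ++oh)
      for (int ow = 0; ow < p.oW; ++ow)
        for (int c = 0; c < p.C; ++c) {
          float acc = 0.f;
          for (int kh = 0; kh < p.kH; ++kh)
            for (int kw = 0; kw < p.kW; ++kw) {
              // Same flip convention as the kernel's wt_h/wt_w remapping.
              int wkh = p.flip ? p.kH - kh - 1 : kh;
              int wkw = p.flip ? p.kW - kw - 1 : kw;
              int ih = oh * p.strH + kh - p.padH;
              int iw = ow * p.strW + kw - p.padW;
              // Out-of-bounds taps contribute zero, like the zero-filled
              // threadgroup tile in the kernel.
              if (ih < 0 || ih >= p.H || iw < 0 || iw >= p.W) continue;
              acc += in[((n * p.H + ih) * p.W + iw) * p.C + c] *
                  wt[(c * p.kH + wkh) * p.kW + wkw];
            }
          out[((n * p.oH + oh) * p.oW + ow) * p.C + c] = acc;
        }
}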
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
/// Winograd kernels /// Winograd kernels
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////

View File

@@ -483,4 +483,4 @@ template <
perform_fft(fft_idx, &p, m, n, buf); perform_fft(fft_idx, &p, m, n, buf);
read_writer.write_strided(stride, overall_n); read_writer.write_strided(stride, overall_n);
} }

View File

@@ -341,7 +341,7 @@ struct GEMVTKernel {
MLX_MTL_PRAGMA_UNROLL MLX_MTL_PRAGMA_UNROLL
for (int tm = 0; tm < TM; tm++) { for (int tm = 0; tm < TM; tm++) {
auto vc = float(v_coeff[tm]); auto vc = static_cast<AccT>(v_coeff[tm]);
for (int tn = 0; tn < TN; tn++) { for (int tn = 0; tn < TN; tn++) {
inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn]; inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn];
} }

View File

@@ -493,71 +493,11 @@ template <typename T, int N_READS = RMS_N_READS>
} }
// clang-format off // clang-format off
#define instantiate_layer_norm_single_row(name, itype) \ #define instantiate_layer_norm(name, itype) \
template [[host_name("layer_norm" #name)]] [[kernel]] void \ instantiate_kernel("layer_norm" #name, layer_norm_single_row, itype) \
layer_norm_single_row<itype>( \ instantiate_kernel("vjp_layer_norm" #name, vjp_layer_norm_single_row, itype) \
const device itype* x, \ instantiate_kernel("layer_norm_looped" #name, layer_norm_looped, itype) \
const device itype* w, \ instantiate_kernel("vjp_layer_norm_looped" #name, vjp_layer_norm_looped, itype)
const device itype* b, \
device itype* out, \
constant float& eps, \
constant uint& axis_size, \
constant uint& w_stride, \
constant uint& b_stride, \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]); \
template [[host_name("vjp_layer_norm" #name)]] [[kernel]] void \
vjp_layer_norm_single_row<itype>( \
const device itype* x, \
const device itype* w, \
const device itype* g, \
device itype* gx, \
device itype* gw, \
constant float& eps, \
constant uint& axis_size, \
constant uint& w_stride, \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
#define instantiate_layer_norm_looped(name, itype) \
template [[host_name("layer_norm_looped" #name)]] [[kernel]] void \
layer_norm_looped<itype>( \
const device itype* x, \
const device itype* w, \
const device itype* b, \
device itype* out, \
constant float& eps, \
constant uint& axis_size, \
constant uint& w_stride, \
constant uint& b_stride, \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint lsize [[threads_per_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]); \
template [[host_name("vjp_layer_norm_looped" #name)]] [[kernel]] void \
vjp_layer_norm_looped<itype>( \
const device itype* x, \
const device itype* w, \
const device itype* g, \
device itype* gx, \
device itype* gb, \
constant float& eps, \
constant uint& axis_size, \
constant uint& w_stride, \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint lsize [[threads_per_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
#define instantiate_layer_norm(name, itype) \
instantiate_layer_norm_single_row(name, itype) \
instantiate_layer_norm_looped(name, itype)
instantiate_layer_norm(float32, float) instantiate_layer_norm(float32, float)
instantiate_layer_norm(float16, half) instantiate_layer_norm(float16, half)

View File

@@ -0,0 +1,143 @@
// Copyright © 2025 Apple Inc.
template <typename T, typename AccT = float, int N_READS = 4>
[[kernel]] void logsumexp(
const device T* in,
device T* out,
constant int& axis_size,
uint gid [[threadgroup_position_in_grid]],
uint _lid [[thread_position_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
int lid = _lid;
constexpr int SIMD_SIZE = 32;
threadgroup AccT local_max[SIMD_SIZE];
threadgroup AccT local_normalizer[SIMD_SIZE];
AccT ld[N_READS];
in += gid * size_t(axis_size) + lid * N_READS;
if (lid * N_READS + N_READS <= axis_size) {
for (int i = 0; i < N_READS; i++) {
ld[i] = AccT(in[i]);
}
} else {
for (int i = 0; i < N_READS; i++) {
ld[i] =
((lid * N_READS + i) < axis_size) ? AccT(in[i]) : Limits<AccT>::min;
}
}
if (simd_group_id == 0) {
local_max[simd_lane_id] = Limits<AccT>::min;
local_normalizer[simd_lane_id] = 0;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Get the max
AccT maxval = Limits<AccT>::finite_min;
for (int i = 0; i < N_READS; i++) {
maxval = (maxval < ld[i]) ? ld[i] : maxval;
}
maxval = simd_max(maxval);
if (simd_lane_id == 0) {
local_max[simd_group_id] = maxval;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (simd_group_id == 0) {
maxval = simd_max(local_max[simd_lane_id]);
if (simd_lane_id == 0) {
local_max[0] = maxval;
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
maxval = local_max[0];
// Compute exp(x_i - maxval) and store the partial sums in local_normalizer
AccT normalizer = 0;
for (int i = 0; i < N_READS; i++) {
normalizer += fast::exp(ld[i] - maxval);
}
normalizer = simd_sum(normalizer);
if (simd_lane_id == 0) {
local_normalizer[simd_group_id] = normalizer;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (simd_group_id == 0) {
normalizer = simd_sum(local_normalizer[simd_lane_id]);
if (simd_lane_id == 0) {
out[gid] = isinf(maxval) ? T(maxval) : T(log(normalizer) + maxval);
}
}
}
template <typename T, typename AccT = float, int N_READS = 4>
[[kernel]] void logsumexp_looped(
const device T* in,
device T* out,
constant int& axis_size,
uint gid [[threadgroup_position_in_grid]],
uint lid [[thread_position_in_threadgroup]],
uint lsize [[threads_per_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
in += gid * size_t(axis_size);
constexpr int SIMD_SIZE = 32;
threadgroup AccT local_max[SIMD_SIZE];
threadgroup AccT local_normalizer[SIMD_SIZE];
// Get the max and the normalizer in one go
AccT prevmax;
AccT maxval = Limits<AccT>::finite_min;
AccT normalizer = 0;
for (int r = 0; r < static_cast<int>(ceildiv(axis_size, N_READS * lsize));
r++) {
int offset = r * lsize * N_READS + lid * N_READS;
AccT vals[N_READS];
if (offset + N_READS <= axis_size) {
for (int i = 0; i < N_READS; i++) {
vals[i] = AccT(in[offset + i]);
}
} else {
for (int i = 0; i < N_READS; i++) {
vals[i] = (offset + i < axis_size) ? AccT(in[offset + i])
: Limits<AccT>::finite_min;
}
}
prevmax = maxval;
for (int i = 0; i < N_READS; i++) {
maxval = (maxval < vals[i]) ? vals[i] : maxval;
}
normalizer *= fast::exp(prevmax - maxval);
for (int i = 0; i < N_READS; i++) {
normalizer += fast::exp(vals[i] - maxval);
}
}
prevmax = maxval;
maxval = simd_max(maxval);
normalizer *= fast::exp(prevmax - maxval);
normalizer = simd_sum(normalizer);
prevmax = maxval;
if (simd_lane_id == 0) {
local_max[simd_group_id] = maxval;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
maxval = simd_max(local_max[simd_lane_id]);
normalizer *= fast::exp(prevmax - maxval);
if (simd_lane_id == 0) {
local_normalizer[simd_group_id] = normalizer;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
normalizer = simd_sum(local_normalizer[simd_lane_id]);
if (simd_group_id == 0) {
normalizer = simd_sum(local_normalizer[simd_lane_id]);
if (simd_lane_id == 0) {
out[gid] = isinf(maxval) ? T(maxval) : T(log(normalizer) + maxval);
}
}
}
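As a cross-check on what these kernels compute, a hedged scalar reference for the row-wise reduction (standard numerically stable logsumexp; illustrative host code, not part of the diff):
#include <algorithm>
#include <cmath>
#include <limits>
#include <vector>

// Numerically stable logsumexp over one row: max(x) + log(sum(exp(x - max))).
// The two Metal kernels above compute the same quantity, with the reduction
// split across SIMD groups: a single pass for short rows, a looped pass with
// a running rescale for rows longer than one read per thread.
float logsumexp_reference(const std::vector<float>& row) {
  float maxval = -std::numeric_limits<float>::infinity();
  for (float v : row) maxval = std::max(maxval, v);
  // If the max is infinite the result is the max itself, mirroring the
  // isinf(maxval) special case in the kernels.
  if (std::isinf(maxval)) return maxval;
  float normalizer = 0.f;
  for (float v : row) normalizer += std::exp(v - maxval);
  return maxval + std::log(normalizer);
}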

View File

@@ -0,0 +1,18 @@
// Copyright © 2023-2024 Apple Inc.
#include <metal_common>
#include <metal_simdgroup>
using namespace metal;
// clang-format off
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/logsumexp.h"
#define instantiate_logsumexp(name, itype) \
instantiate_kernel("block_logsumexp_" #name, logsumexp, itype) \
instantiate_kernel("looped_logsumexp_" #name, logsumexp_looped, itype) \
instantiate_logsumexp(float32, float)
instantiate_logsumexp(float16, half)
instantiate_logsumexp(bfloat16, bfloat16_t) // clang-format on

View File

@@ -3,6 +3,10 @@
#include <metal_simdgroup> #include <metal_simdgroup>
#include <metal_stdlib> #include <metal_stdlib>
constant bool align_M [[function_constant(200)]];
constant bool align_N [[function_constant(201)]];
constant bool align_K [[function_constant(202)]];
using namespace metal; using namespace metal;
#define MLX_MTL_CONST static constant constexpr const #define MLX_MTL_CONST static constant constexpr const
@@ -586,13 +590,13 @@ METAL_FUNC void qmv_quad_impl(
// Adjust positions // Adjust positions
const int in_vec_size_w = in_vec_size / pack_factor; const int in_vec_size_w = in_vec_size / pack_factor;
const int in_vec_size_g = in_vec_size / group_size; const int in_vec_size_g = in_vec_size / group_size;
const int out_row = tid.x * quads_per_simd * results_per_quadgroup + quad_gid; const int out_row = tid.y * quads_per_simd * results_per_quadgroup + quad_gid;
w += out_row * in_vec_size_w + quad_lid * packs_per_thread; w += out_row * in_vec_size_w + quad_lid * packs_per_thread;
scales += out_row * in_vec_size_g + quad_lid / scale_step_per_thread; scales += out_row * in_vec_size_g + quad_lid / scale_step_per_thread;
biases += out_row * in_vec_size_g + quad_lid / scale_step_per_thread; biases += out_row * in_vec_size_g + quad_lid / scale_step_per_thread;
x += tid.y * in_vec_size + quad_lid * values_per_thread; x += tid.x * in_vec_size + quad_lid * values_per_thread;
y += tid.y * out_vec_size + out_row; y += tid.x * out_vec_size + out_row;
U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread); U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
@@ -1686,26 +1690,26 @@ template <
} }
template <typename T, int group_size, int bits> template <typename T, int group_size, int bits>
[[kernel]] void bs_qmv_fast( [[kernel]] void gather_qmv_fast(
const device uint32_t* w [[buffer(0)]], const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]], const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]], const device T* biases [[buffer(2)]],
const device T* x [[buffer(3)]], const device T* x [[buffer(3)]],
device T* y [[buffer(4)]], const device uint32_t* lhs_indices [[buffer(4)]],
const constant int& in_vec_size [[buffer(5)]], const device uint32_t* rhs_indices [[buffer(5)]],
const constant int& out_vec_size [[buffer(6)]], device T* y [[buffer(6)]],
const constant int& x_batch_ndims [[buffer(7)]], const constant int& in_vec_size [[buffer(7)]],
const constant int* x_shape [[buffer(8)]], const constant int& out_vec_size [[buffer(8)]],
const constant int64_t* x_strides [[buffer(9)]], const constant int& x_batch_ndims [[buffer(9)]],
const constant int& w_batch_ndims [[buffer(10)]], const constant int* x_shape [[buffer(10)]],
const constant int* w_shape [[buffer(11)]], const constant int64_t* x_strides [[buffer(11)]],
const constant int64_t* w_strides [[buffer(12)]], const constant int& w_batch_ndims [[buffer(12)]],
const constant int64_t* s_strides [[buffer(13)]], const constant int* w_shape [[buffer(13)]],
const constant int64_t* b_strides [[buffer(14)]], const constant int64_t* w_strides [[buffer(14)]],
const constant int& batch_ndims [[buffer(15)]], const constant int64_t* s_strides [[buffer(15)]],
const constant int* batch_shape [[buffer(16)]], const constant int64_t* b_strides [[buffer(16)]],
const device uint32_t* lhs_indices [[buffer(17)]], const constant int& batch_ndims [[buffer(17)]],
const device uint32_t* rhs_indices [[buffer(18)]], const constant int* batch_shape [[buffer(18)]],
const constant int64_t* lhs_strides [[buffer(19)]], const constant int64_t* lhs_strides [[buffer(19)]],
const constant int64_t* rhs_strides [[buffer(20)]], const constant int64_t* rhs_strides [[buffer(20)]],
uint3 tid [[threadgroup_position_in_grid]], uint3 tid [[threadgroup_position_in_grid]],
@@ -1748,26 +1752,26 @@ template <typename T, int group_size, int bits>
} }
template <typename T, int group_size, int bits> template <typename T, int group_size, int bits>
[[kernel]] void bs_qmv( [[kernel]] void gather_qmv(
const device uint32_t* w [[buffer(0)]], const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]], const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]], const device T* biases [[buffer(2)]],
const device T* x [[buffer(3)]], const device T* x [[buffer(3)]],
device T* y [[buffer(4)]], const device uint32_t* lhs_indices [[buffer(4)]],
const constant int& in_vec_size [[buffer(5)]], const device uint32_t* rhs_indices [[buffer(5)]],
const constant int& out_vec_size [[buffer(6)]], device T* y [[buffer(6)]],
const constant int& x_batch_ndims [[buffer(7)]], const constant int& in_vec_size [[buffer(7)]],
const constant int* x_shape [[buffer(8)]], const constant int& out_vec_size [[buffer(8)]],
const constant int64_t* x_strides [[buffer(9)]], const constant int& x_batch_ndims [[buffer(9)]],
const constant int& w_batch_ndims [[buffer(10)]], const constant int* x_shape [[buffer(10)]],
const constant int* w_shape [[buffer(11)]], const constant int64_t* x_strides [[buffer(11)]],
const constant int64_t* w_strides [[buffer(12)]], const constant int& w_batch_ndims [[buffer(12)]],
const constant int64_t* s_strides [[buffer(13)]], const constant int* w_shape [[buffer(13)]],
const constant int64_t* b_strides [[buffer(14)]], const constant int64_t* w_strides [[buffer(14)]],
const constant int& batch_ndims [[buffer(15)]], const constant int64_t* s_strides [[buffer(15)]],
const constant int* batch_shape [[buffer(16)]], const constant int64_t* b_strides [[buffer(16)]],
const device uint32_t* lhs_indices [[buffer(17)]], const constant int& batch_ndims [[buffer(17)]],
const device uint32_t* rhs_indices [[buffer(18)]], const constant int* batch_shape [[buffer(18)]],
const constant int64_t* lhs_strides [[buffer(19)]], const constant int64_t* lhs_strides [[buffer(19)]],
const constant int64_t* rhs_strides [[buffer(20)]], const constant int64_t* rhs_strides [[buffer(20)]],
uint3 tid [[threadgroup_position_in_grid]], uint3 tid [[threadgroup_position_in_grid]],
@@ -1810,26 +1814,26 @@ template <typename T, int group_size, int bits>
} }
template <typename T, int group_size, int bits> template <typename T, int group_size, int bits>
[[kernel]] void bs_qvm( [[kernel]] void gather_qvm(
const device uint32_t* w [[buffer(0)]], const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]], const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]], const device T* biases [[buffer(2)]],
const device T* x [[buffer(3)]], const device T* x [[buffer(3)]],
device T* y [[buffer(4)]], const device uint32_t* lhs_indices [[buffer(4)]],
const constant int& in_vec_size [[buffer(5)]], const device uint32_t* rhs_indices [[buffer(5)]],
const constant int& out_vec_size [[buffer(6)]], device T* y [[buffer(6)]],
const constant int& x_batch_ndims [[buffer(7)]], const constant int& in_vec_size [[buffer(7)]],
const constant int* x_shape [[buffer(8)]], const constant int& out_vec_size [[buffer(8)]],
const constant int64_t* x_strides [[buffer(9)]], const constant int& x_batch_ndims [[buffer(9)]],
const constant int& w_batch_ndims [[buffer(10)]], const constant int* x_shape [[buffer(10)]],
const constant int* w_shape [[buffer(11)]], const constant int64_t* x_strides [[buffer(11)]],
const constant int64_t* w_strides [[buffer(12)]], const constant int& w_batch_ndims [[buffer(12)]],
const constant int64_t* s_strides [[buffer(13)]], const constant int* w_shape [[buffer(13)]],
const constant int64_t* b_strides [[buffer(14)]], const constant int64_t* w_strides [[buffer(14)]],
const constant int& batch_ndims [[buffer(15)]], const constant int64_t* s_strides [[buffer(15)]],
const constant int* batch_shape [[buffer(16)]], const constant int64_t* b_strides [[buffer(16)]],
const device uint32_t* lhs_indices [[buffer(17)]], const constant int& batch_ndims [[buffer(17)]],
const device uint32_t* rhs_indices [[buffer(18)]], const constant int* batch_shape [[buffer(18)]],
const constant int64_t* lhs_strides [[buffer(19)]], const constant int64_t* lhs_strides [[buffer(19)]],
const constant int64_t* rhs_strides [[buffer(20)]], const constant int64_t* rhs_strides [[buffer(20)]],
uint3 tid [[threadgroup_position_in_grid]], uint3 tid [[threadgroup_position_in_grid]],
@@ -1879,27 +1883,27 @@ template <
const int BM = 32, const int BM = 32,
const int BK = 32, const int BK = 32,
const int BN = 32> const int BN = 32>
[[kernel]] void bs_qmm_t( [[kernel]] void gather_qmm_t(
const device uint32_t* w [[buffer(0)]], const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]], const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]], const device T* biases [[buffer(2)]],
const device T* x [[buffer(3)]], const device T* x [[buffer(3)]],
device T* y [[buffer(4)]], const device uint32_t* lhs_indices [[buffer(4)]],
const constant int& K [[buffer(5)]], const device uint32_t* rhs_indices [[buffer(5)]],
const constant int& N [[buffer(6)]], device T* y [[buffer(6)]],
const constant int& M [[buffer(7)]], const constant int& K [[buffer(7)]],
const constant int& x_batch_ndims [[buffer(8)]], const constant int& N [[buffer(8)]],
const constant int* x_shape [[buffer(9)]], const constant int& M [[buffer(9)]],
const constant int64_t* x_strides [[buffer(10)]], const constant int& x_batch_ndims [[buffer(10)]],
const constant int& w_batch_ndims [[buffer(11)]], const constant int* x_shape [[buffer(11)]],
const constant int* w_shape [[buffer(12)]], const constant int64_t* x_strides [[buffer(12)]],
const constant int64_t* w_strides [[buffer(13)]], const constant int& w_batch_ndims [[buffer(13)]],
const constant int64_t* s_strides [[buffer(14)]], const constant int* w_shape [[buffer(14)]],
const constant int64_t* b_strides [[buffer(15)]], const constant int64_t* w_strides [[buffer(15)]],
const constant int& batch_ndims [[buffer(16)]], const constant int64_t* s_strides [[buffer(16)]],
const constant int* batch_shape [[buffer(17)]], const constant int64_t* b_strides [[buffer(17)]],
const device uint32_t* lhs_indices [[buffer(18)]], const constant int& batch_ndims [[buffer(18)]],
const device uint32_t* rhs_indices [[buffer(19)]], const constant int* batch_shape [[buffer(19)]],
const constant int64_t* lhs_strides [[buffer(20)]], const constant int64_t* lhs_strides [[buffer(20)]],
const constant int64_t* rhs_strides [[buffer(21)]], const constant int64_t* rhs_strides [[buffer(21)]],
uint3 tid [[threadgroup_position_in_grid]], uint3 tid [[threadgroup_position_in_grid]],
@@ -1946,27 +1950,27 @@ template <
const int BM = 32, const int BM = 32,
const int BK = 32, const int BK = 32,
const int BN = 32> const int BN = 32>
[[kernel]] void bs_qmm_n( [[kernel]] void gather_qmm_n(
const device uint32_t* w [[buffer(0)]], const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]], const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]], const device T* biases [[buffer(2)]],
const device T* x [[buffer(3)]], const device T* x [[buffer(3)]],
device T* y [[buffer(4)]], const device uint32_t* lhs_indices [[buffer(4)]],
const constant int& K [[buffer(5)]], const device uint32_t* rhs_indices [[buffer(5)]],
const constant int& N [[buffer(6)]], device T* y [[buffer(6)]],
const constant int& M [[buffer(7)]], const constant int& K [[buffer(7)]],
const constant int& x_batch_ndims [[buffer(8)]], const constant int& N [[buffer(8)]],
const constant int* x_shape [[buffer(9)]], const constant int& M [[buffer(9)]],
const constant int64_t* x_strides [[buffer(10)]], const constant int& x_batch_ndims [[buffer(10)]],
const constant int& w_batch_ndims [[buffer(11)]], const constant int* x_shape [[buffer(11)]],
const constant int* w_shape [[buffer(12)]], const constant int64_t* x_strides [[buffer(12)]],
const constant int64_t* w_strides [[buffer(13)]], const constant int& w_batch_ndims [[buffer(13)]],
const constant int64_t* s_strides [[buffer(14)]], const constant int* w_shape [[buffer(14)]],
const constant int64_t* b_strides [[buffer(15)]], const constant int64_t* w_strides [[buffer(15)]],
const constant int& batch_ndims [[buffer(16)]], const constant int64_t* s_strides [[buffer(16)]],
const constant int* batch_shape [[buffer(17)]], const constant int64_t* b_strides [[buffer(17)]],
const device uint32_t* lhs_indices [[buffer(18)]], const constant int& batch_ndims [[buffer(18)]],
const device uint32_t* rhs_indices [[buffer(19)]], const constant int* batch_shape [[buffer(19)]],
const constant int64_t* lhs_strides [[buffer(20)]], const constant int64_t* lhs_strides [[buffer(20)]],
const constant int64_t* rhs_strides [[buffer(21)]], const constant int64_t* rhs_strides [[buffer(21)]],
uint3 tid [[threadgroup_position_in_grid]], uint3 tid [[threadgroup_position_in_grid]],
@@ -2007,6 +2011,289 @@ template <
w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid); w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
} }
template <typename T, typename mma_t, typename loader_a_t, typename loader_b_t>
METAL_FUNC void gemm_loop_aligned(
threadgroup T* As,
threadgroup T* Bs,
thread mma_t& mma_op,
thread loader_a_t& loader_a,
thread loader_b_t& loader_b,
const int k_iterations) {
for (int k = 0; k < k_iterations; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup memory
loader_a.load_unsafe();
loader_b.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
}
template <
bool rows_aligned,
bool cols_aligned,
bool transpose,
typename T,
typename mma_t,
typename loader_a_t,
typename loader_b_t>
METAL_FUNC void gemm_loop_unaligned(
threadgroup T* As,
threadgroup T* Bs,
thread mma_t& mma_op,
thread loader_a_t& loader_a,
thread loader_b_t& loader_b,
const int k_iterations,
const short tgp_bm,
const short tgp_bn,
const short tgp_bk) {
for (int k = 0; k < k_iterations; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup memory
if (rows_aligned) {
loader_a.load_unsafe();
} else {
loader_a.load_safe(short2(tgp_bk, tgp_bm));
}
if (cols_aligned) {
loader_b.load_unsafe();
} else {
loader_b.load_safe(
transpose ? short2(tgp_bk, tgp_bn) : short2(tgp_bn, tgp_bk));
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
}
template <typename T, typename mma_t, typename loader_a_t, typename loader_b_t>
METAL_FUNC void gemm_loop_finalize(
threadgroup T* As,
threadgroup T* Bs,
thread mma_t& mma_op,
thread loader_a_t& loader_a,
thread loader_b_t& loader_b,
const short2 tile_a,
const short2 tile_b) {
loader_a.load_safe(tile_a);
loader_b.load_safe(tile_b);
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_op.mma(As, Bs);
}
template <
typename T,
int group_size,
int bits,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose>
[[kernel]] void gather_qmm_rhs(
const device T* x [[buffer(0)]],
const device uint32_t* w [[buffer(1)]],
const device T* scales [[buffer(2)]],
const device T* biases [[buffer(3)]],
const device uint32_t* indices [[buffer(4)]],
device T* y [[buffer(5)]],
const constant int& M [[buffer(6)]],
const constant int& N [[buffer(7)]],
const constant int& K [[buffer(8)]],
uint3 tid [[threadgroup_position_in_grid]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]]) {
constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
constexpr int BK_padded = (BK + 16 / sizeof(T));
constexpr int BN_padded = (BN + 16 / sizeof(T));
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3;
using mma_t = mlx::steel::BlockMMA<
T,
T,
BM,
BN,
BK,
WM,
WN,
false,
transpose,
BK_padded,
transpose ? BK_padded : BN_padded>;
using loader_x_t =
mlx::steel::BlockLoader<T, BM, BK, BK_padded, 1, WM * WN * SIMD_SIZE>;
using loader_w_t = QuantizedBlockLoader<
T,
transpose ? BN : BK,
transpose ? BK : BN,
transpose ? BK_padded : BN_padded,
transpose,
WM * WN * SIMD_SIZE,
group_size,
bits>;
threadgroup T Xs[BM * BK_padded];
threadgroup T Ws[transpose ? BN * BK_padded : BK * BN_padded];
// Compute the block
const int K_w = K * bytes_per_pack / pack_factor;
const int K_g = K / group_size;
const int N_w = N * bytes_per_pack / pack_factor;
const int N_g = N / group_size;
const int K_it = K / BK;
const size_t stride_w = transpose ? N * K_w : K * N_w;
const size_t stride_s = transpose ? N * K_g : K * N_g;
const int y_row = tid.y * BM;
const int y_col = tid.x * BN;
const size_t y_row_long = size_t(y_row);
const size_t y_col_long = size_t(y_col);
// Prepare threadgroup bounds
const short tgp_bm = align_M ? BM : short(min(BM, M - y_row));
const short tgp_bn = align_N ? BN : short(min(BN, N - y_col));
// Calculate the final tiles in the case that K is not aligned
const int k_remain = K - K_it * BK;
const short2 tile_x = short2(k_remain, tgp_bm);
const short2 tile_w =
transpose ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);
// Move x and output to the correct block
auto wl = (const device uint8_t*)w;
x += y_row_long * K;
y += y_row_long * N + y_col_long;
wl += transpose ? y_col_long * K_w : y_col * bytes_per_pack / pack_factor;
scales += transpose ? y_col_long * K_g : y_col / group_size;
biases += transpose ? y_col_long * K_g : y_col / group_size;
// Do as many matmuls as necessary
uint32_t index;
short offset;
uint32_t index_next = indices[y_row];
short offset_next = 0;
int n = 0;
while (n < tgp_bm) {
n++;
offset = offset_next;
index = index_next;
offset_next = tgp_bm;
for (; n < tgp_bm; n++) {
if (indices[y_row + n] != index) {
offset_next = n;
index_next = indices[y_row + n];
break;
}
}
threadgroup_barrier(mem_flags::mem_none);
// Prepare threadgroup mma operation
thread mma_t mma_op(simd_group_id, simd_lane_id);
// Prepare threadgroup loading operations
thread loader_x_t loader_x(x, K, Xs, simd_group_id, simd_lane_id);
thread loader_w_t loader_w(
wl + index * stride_w,
scales + index * stride_s,
biases + index * stride_s,
transpose ? K : N,
Ws,
simd_group_id,
simd_lane_id);
// Matrices are all aligned check nothing
if (align_M && align_N) {
gemm_loop_aligned(Xs, Ws, mma_op, loader_x, loader_w, K_it);
if (!align_K) {
threadgroup_barrier(mem_flags::mem_threadgroup);
gemm_loop_finalize(Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w);
}
// Store results to device memory
if (offset_next - offset == BM) {
mma_op.store_result(y, N);
} else {
mma_op.store_result_slice(
y, N, short2(0, offset), short2(BN, offset_next));
}
} else {
// Tile turned out to be aligned, so the check stays outside of the hot loop
if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) {
gemm_loop_aligned(Xs, Ws, mma_op, loader_x, loader_w, K_it);
if (!align_K) {
threadgroup_barrier(mem_flags::mem_threadgroup);
gemm_loop_finalize(
Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w);
}
// Store results to device memory
if (offset_next - offset == BM) {
mma_op.store_result(y, N);
} else {
mma_op.store_result_slice(
y, N, short2(0, offset), short2(BN, offset_next));
}
}
// Tile partially aligned check rows
else if (align_N || tgp_bn == BN) {
gemm_loop_unaligned<false, true, transpose>(
Xs, Ws, mma_op, loader_x, loader_w, K_it, tgp_bm, tgp_bn, BK);
if (!align_K) {
threadgroup_barrier(mem_flags::mem_threadgroup);
gemm_loop_finalize(
Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w);
}
mma_op.store_result_slice(
y, N, short2(0, offset), short2(BN, offset_next));
}
// Tile partially aligned check cols
else if (align_M || tgp_bm == BM) {
gemm_loop_unaligned<true, false, transpose>(
Xs, Ws, mma_op, loader_x, loader_w, K_it, tgp_bm, tgp_bn, BK);
if (!align_K) {
threadgroup_barrier(mem_flags::mem_threadgroup);
gemm_loop_finalize(
Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w);
}
mma_op.store_result_slice(
y, N, short2(0, offset), short2(tgp_bn, offset_next));
}
// Nothing aligned so check both rows and cols
else {
gemm_loop_unaligned<false, false, transpose>(
Xs, Ws, mma_op, loader_x, loader_w, K_it, tgp_bm, tgp_bn, BK);
if (!align_K) {
threadgroup_barrier(mem_flags::mem_threadgroup);
gemm_loop_finalize(
Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w);
}
mma_op.store_result_slice(
y, N, short2(0, offset), short2(tgp_bn, offset_next));
}
}
}
}
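Note on the packing arithmetic above: a small host-side C++ sketch that mirrors the constexpr pack_factor / bytes_per_pack logic so the K_w and K_g row widths can be checked outside of Metal. The K and group_size values below are hypothetical, for illustration only.
#include <cstdio>
constexpr int pack_factor(int bits) {
  return bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
}
constexpr int bytes_per_pack(int bits) {
  return (bits & (bits - 1)) == 0 ? 1 : 3;
}
int main() {
  const int K = 4096;        // hypothetical row length in elements
  const int group_size = 64; // hypothetical quantization group size
  for (int bits : {2, 3, 4, 6, 8}) {
    int K_w = K * bytes_per_pack(bits) / pack_factor(bits); // packed bytes per row
    int K_g = K / group_size;                               // scale/bias entries per row
    std::printf("bits=%d  K_w=%d bytes  K_g=%d groups\n", bits, K_w, K_g);
  }
  return 0;
}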
template <typename T, const int group_size, const int bits> template <typename T, const int group_size, const int bits>
[[kernel]] void affine_quantize( [[kernel]] void affine_quantize(
const device T* w [[buffer(0)]], const device T* w [[buffer(0)]],

View File

@@ -60,6 +60,20 @@
bits, \ bits, \
split_k) split_k)
#define instantiate_gather_qmm_rhs(func, name, type, group_size, bits, bm, bn, bk, wm, wn, transpose) \
instantiate_kernel( \
#name "_" #type "_gs_" #group_size "_b_" #bits "_bm_" #bm "_bn_" #bn "_bk_" #bk "_wm_" #wm "_wn_" #wn, \
func, \
type, \
group_size, \
bits, \
bm, \
bn, \
bk, \
wm, \
wn, \
transpose)
#define instantiate_quantized_batched_wrap(name, type, group_size, bits) \ #define instantiate_quantized_batched_wrap(name, type, group_size, bits) \
instantiate_quantized_batched(name, type, group_size, bits, 1) \ instantiate_quantized_batched(name, type, group_size, bits, 1) \
instantiate_quantized_batched(name, type, group_size, bits, 0) instantiate_quantized_batched(name, type, group_size, bits, 0)
@@ -73,14 +87,14 @@
#define instantiate_quantized_all_single(type, group_size, bits) \ #define instantiate_quantized_all_single(type, group_size, bits) \
instantiate_quantized(affine_quantize, type, group_size, bits) \ instantiate_quantized(affine_quantize, type, group_size, bits) \
instantiate_quantized(affine_dequantize, type, group_size, bits) \ instantiate_quantized(affine_dequantize, type, group_size, bits) \
instantiate_quantized(bs_qmv_fast, type, group_size, bits) \ instantiate_quantized(gather_qmv_fast, type, group_size, bits) \
instantiate_quantized(bs_qmv, type, group_size, bits) \ instantiate_quantized(gather_qmv, type, group_size, bits) \
instantiate_quantized(bs_qvm, type, group_size, bits) \ instantiate_quantized(gather_qvm, type, group_size, bits) \
instantiate_quantized(bs_qmm_n, type, group_size, bits) instantiate_quantized(gather_qmm_n, type, group_size, bits)
#define instantiate_quantized_all_aligned(type, group_size, bits) \ #define instantiate_quantized_all_aligned(type, group_size, bits) \
instantiate_quantized_aligned(bs_qmm_t, type, group_size, bits, true) \ instantiate_quantized_aligned(gather_qmm_t, type, group_size, bits, true) \
instantiate_quantized_aligned(bs_qmm_t, type, group_size, bits, false) \ instantiate_quantized_aligned(gather_qmm_t, type, group_size, bits, false) \
instantiate_quantized_aligned_batched(qmm_t, type, group_size, bits, true, 1) \ instantiate_quantized_aligned_batched(qmm_t, type, group_size, bits, true, 1) \
instantiate_quantized_aligned_batched(qmm_t, type, group_size, bits, true, 0) \ instantiate_quantized_aligned_batched(qmm_t, type, group_size, bits, true, 0) \
instantiate_quantized_aligned_batched(qmm_t, type, group_size, bits, false, 1) \ instantiate_quantized_aligned_batched(qmm_t, type, group_size, bits, false, 1) \
@@ -96,12 +110,17 @@
instantiate_quantized_split_k(qvm_split_k, type, group_size, bits, 8) \ instantiate_quantized_split_k(qvm_split_k, type, group_size, bits, 8) \
instantiate_quantized_split_k(qvm_split_k, type, group_size, bits, 32) instantiate_quantized_split_k(qvm_split_k, type, group_size, bits, 32)
#define instantiate_quantized_all_rhs(type, group_size, bits) \
instantiate_gather_qmm_rhs(gather_qmm_rhs, gather_qmm_rhs_nt, type, group_size, bits, 16, 32, 32, 1, 2, true) \
instantiate_gather_qmm_rhs(gather_qmm_rhs, gather_qmm_rhs_nn, type, group_size, bits, 16, 32, 32, 1, 2, false)
#define instantiate_quantized_funcs(type, group_size, bits) \ #define instantiate_quantized_funcs(type, group_size, bits) \
instantiate_quantized_all_single(type, group_size, bits) \ instantiate_quantized_all_single(type, group_size, bits) \
instantiate_quantized_all_batched(type, group_size, bits) \ instantiate_quantized_all_batched(type, group_size, bits) \
instantiate_quantized_all_aligned(type, group_size, bits) \ instantiate_quantized_all_aligned(type, group_size, bits) \
instantiate_quantized_all_quad(type, group_size, bits) \ instantiate_quantized_all_quad(type, group_size, bits) \
instantiate_quantized_all_splitk(type, group_size, bits) instantiate_quantized_all_splitk(type, group_size, bits) \
instantiate_quantized_all_rhs(type, group_size, bits)
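For reference, a host-side C++ sketch of how the instantiate_gather_qmm_rhs macro above composes a kernel name from its arguments; the type token and quantization parameters in main are placeholders, not an exhaustive list of instantiated variants.
#include <iostream>
#include <sstream>
#include <string>
std::string gather_qmm_rhs_name(
    const std::string& type, int group_size, int bits,
    int bm, int bn, int bk, int wm, int wn, bool transpose) {
  std::ostringstream os;
  os << "gather_qmm_rhs_" << (transpose ? "nt" : "nn") << "_" << type
     << "_gs_" << group_size << "_b_" << bits << "_bm_" << bm << "_bn_" << bn
     << "_bk_" << bk << "_wm_" << wm << "_wn_" << wn;
  return os.str();
}
int main() {
  // e.g. gather_qmm_rhs_nt_float16_t_gs_64_b_4_bm_16_bn_32_bk_32_wm_1_wn_2
  std::cout << gather_qmm_rhs_name("float16_t", 64, 4, 16, 32, 32, 1, 2, true) << "\n";
  return 0;
}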
#define instantiate_quantized_types(group_size, bits) \ #define instantiate_quantized_types(group_size, bits) \
instantiate_quantized_funcs(float, group_size, bits) \ instantiate_quantized_funcs(float, group_size, bits) \

View File

@@ -380,69 +380,11 @@ template <typename T, int N_READS = RMS_N_READS>
} }
// clang-format off // clang-format off
#define instantiate_rms_single_row(name, itype) \ #define instantiate_rms(name, itype) \
template [[host_name("rms" #name)]] [[kernel]] void \ instantiate_kernel("rms" #name, rms_single_row, itype) \
rms_single_row<itype>( \ instantiate_kernel("vjp_rms" #name, vjp_rms_single_row, itype) \
const device itype* x, \ instantiate_kernel("rms_looped" #name, rms_looped, itype) \
const device itype* w, \ instantiate_kernel("vjp_rms_looped" #name, vjp_rms_looped, itype)
device itype* out, \
constant float& eps, \
constant uint& axis_size, \
constant uint& w_stride, \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]); \
\
template [[host_name("vjp_rms" #name)]] [[kernel]] void \
vjp_rms_single_row<itype>( \
const device itype* x, \
const device itype* w, \
const device itype* g, \
device itype* gx, \
device itype* gw, \
constant float& eps, \
constant uint& axis_size, \
constant uint& w_stride, \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
#define instantiate_rms_looped(name, itype) \
template [[host_name("rms_looped" #name)]] [[kernel]] void \
rms_looped<itype>( \
const device itype* x, \
const device itype* w, \
device itype* out, \
constant float& eps, \
constant uint& axis_size, \
constant uint& w_stride, \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint lsize [[threads_per_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]); \
\
template [[host_name("vjp_rms_looped" #name)]] [[kernel]] void \
vjp_rms_looped<itype>( \
const device itype* x, \
const device itype* w, \
const device itype* g, \
device itype* gx, \
device itype* gw, \
constant float& eps, \
constant uint& axis_size, \
constant uint& w_stride, \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint lsize [[threads_per_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
#define instantiate_rms(name, itype) \
instantiate_rms_single_row(name, itype) \
instantiate_rms_looped(name, itype)
instantiate_rms(float32, float) instantiate_rms(float32, float)
instantiate_rms(float16, half) instantiate_rms(float16, half)

View File

@@ -1,11 +1,11 @@
#include <metal_stdlib> #include <metal_stdlib>
#include "mlx/backend/metal/kernels/sdpa_vector.h" // clang-format off
#include "mlx/backend/metal/kernels/utils.h" #include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/sdpa_vector.h"
using namespace metal; using namespace metal;
// clang-format off
// SDPA vector instantiations // SDPA vector instantiations
#define instantiate_sdpa_vector_aggregation(type, value_dim) \ #define instantiate_sdpa_vector_aggregation(type, value_dim) \
instantiate_kernel( \ instantiate_kernel( \
@@ -32,9 +32,11 @@ using namespace metal;
instantiate_sdpa_vector(type, 64, 64) \ instantiate_sdpa_vector(type, 64, 64) \
instantiate_sdpa_vector(type, 96, 96) \ instantiate_sdpa_vector(type, 96, 96) \
instantiate_sdpa_vector(type, 128, 128) \ instantiate_sdpa_vector(type, 128, 128) \
instantiate_sdpa_vector(type, 256, 256) \
instantiate_sdpa_vector_aggregation(type, 64) \ instantiate_sdpa_vector_aggregation(type, 64) \
instantiate_sdpa_vector_aggregation(type, 96) \ instantiate_sdpa_vector_aggregation(type, 96) \
instantiate_sdpa_vector_aggregation(type, 128) instantiate_sdpa_vector_aggregation(type, 128) \
instantiate_sdpa_vector_aggregation(type, 256)
instantiate_sdpa_vector_heads(float) instantiate_sdpa_vector_heads(float)
instantiate_sdpa_vector_heads(bfloat16_t) instantiate_sdpa_vector_heads(bfloat16_t)

View File

@@ -2,6 +2,8 @@
#pragma once #pragma once
#include "mlx/backend/metal/kernels/binary_ops.h"
#define DEFINE_SIMD_SCAN() \ #define DEFINE_SIMD_SCAN() \
template <typename T, metal::enable_if_t<sizeof(T) < 8, bool> = true> \ template <typename T, metal::enable_if_t<sizeof(T) < 8, bool> = true> \
T simd_scan(T val) { \ T simd_scan(T val) { \
@@ -139,6 +141,29 @@ struct CumMin {
} }
}; };
template <typename U>
struct CumLogaddexp {
static constexpr constant U init = Limits<U>::min;
template <typename T>
U operator()(U a, T b) {
return LogAddExp{}(a, static_cast<U>(b));
}
U simd_scan(U x) {
for (int i = 1; i <= 16; i *= 2) {
U other = simd_shuffle_and_fill_up(x, init, i);
x = LogAddExp{}(x, other);
}
return x;
}
U simd_exclusive_scan(U x) {
x = simd_scan(x);
return simd_shuffle_and_fill_up(x, init, 1);
}
};
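The CumLogaddexp scan above folds values with LogAddExp; below is a host-side C++ sketch of the same combine using the usual numerically stable formulation (an assumption here, not the exact Metal LogAddExp implementation), plus a sequential inclusive scan built from it.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <limits>
#include <vector>
float logaddexp(float a, float b) {
  if (std::isinf(a) && a < 0) return b; // identity element
  if (std::isinf(b) && b < 0) return a;
  float m = std::max(a, b);
  return m + std::log1p(std::exp(-std::fabs(a - b)));
}
int main() {
  std::vector<float> x = {0.f, 1.f, 2.f, 3.f};
  float acc = -std::numeric_limits<float>::infinity(); // plays the role of the init constant
  for (float v : x) {
    acc = logaddexp(acc, v);
    std::printf("%f\n", acc); // running log(sum(exp(x[0..i])))
  }
  return 0;
}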
template <typename T, typename U, int N_READS, bool reverse> template <typename T, typename U, int N_READS, bool reverse>
inline void load_unsafe(U values[N_READS], const device T* input) { inline void load_unsafe(U values[N_READS], const device T* input) {
if (reverse) { if (reverse) {

View File

@@ -101,4 +101,7 @@ instantiate_scan_helper(min_int64_int64, int64_t, int64_t, CumMi
instantiate_scan_helper(min_float16_float16, half, half, CumMin, 4) instantiate_scan_helper(min_float16_float16, half, half, CumMin, 4)
instantiate_scan_helper(min_float32_float32, float, float, CumMin, 4) instantiate_scan_helper(min_float32_float32, float, float, CumMin, 4)
instantiate_scan_helper(min_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumMin, 4) instantiate_scan_helper(min_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumMin, 4)
instantiate_scan_helper(min_complex64_complex64, complex64_t, complex64_t, CumMin, 2) // clang-format on instantiate_scan_helper(min_complex64_complex64, complex64_t, complex64_t, CumMin, 2)
instantiate_scan_helper(logaddexp_float16_float16, half, half, CumLogaddexp, 4)
instantiate_scan_helper(logaddexp_float32_float32, float, float, CumLogaddexp, 4)
instantiate_scan_helper(logaddexp_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumLogaddexp, 4) // clang-format on

View File

@@ -6,6 +6,9 @@ using namespace metal;
constant bool has_mask [[function_constant(20)]]; constant bool has_mask [[function_constant(20)]];
constant bool query_transposed [[function_constant(21)]]; constant bool query_transposed [[function_constant(21)]];
constant bool do_causal [[function_constant(22)]];
constant bool bool_mask [[function_constant(23)]];
constant bool float_mask [[function_constant(24)]];
template <typename T, int D, int V = D> template <typename T, int D, int V = D>
[[kernel]] void sdpa_vector( [[kernel]] void sdpa_vector(
@@ -13,17 +16,21 @@ template <typename T, int D, int V = D>
const device T* keys [[buffer(1)]], const device T* keys [[buffer(1)]],
const device T* values [[buffer(2)]], const device T* values [[buffer(2)]],
device T* out [[buffer(3)]], device T* out [[buffer(3)]],
const constant int& gqa_factor, const constant int& gqa_factor [[buffer(4)]],
const constant int& N, const constant int& N [[buffer(5)]],
const constant size_t& k_head_stride, const constant size_t& k_head_stride [[buffer(6)]],
const constant size_t& k_seq_stride, const constant size_t& k_seq_stride [[buffer(7)]],
const constant size_t& v_head_stride, const constant size_t& v_head_stride [[buffer(8)]],
const constant size_t& v_seq_stride, const constant size_t& v_seq_stride [[buffer(9)]],
const constant float& scale, const constant float& scale [[buffer(10)]],
const device bool* mask [[function_constant(has_mask)]], const device bool* bmask [[buffer(11), function_constant(bool_mask)]],
const constant int& mask_kv_seq_stride [[function_constant(has_mask)]], const device T* fmask [[buffer(12), function_constant(float_mask)]],
const constant int& mask_q_seq_stride [[function_constant(has_mask)]], const constant int& mask_kv_seq_stride
const constant int& mask_head_stride [[function_constant(has_mask)]], [[buffer(13), function_constant(has_mask)]],
const constant int& mask_q_seq_stride
[[buffer(14), function_constant(has_mask)]],
const constant int& mask_head_stride
[[buffer(15), function_constant(has_mask)]],
uint3 tid [[threadgroup_position_in_grid]], uint3 tid [[threadgroup_position_in_grid]],
uint3 tpg [[threadgroups_per_grid]], uint3 tpg [[threadgroups_per_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]],
@@ -57,8 +64,12 @@ template <typename T, int D, int V = D>
simd_lid * qk_per_thread; simd_lid * qk_per_thread;
values += kv_head_idx * v_head_stride + simd_gid * v_seq_stride + values += kv_head_idx * v_head_stride + simd_gid * v_seq_stride +
simd_lid * v_per_thread; simd_lid * v_per_thread;
if (has_mask) { if (bool_mask) {
mask += head_idx * mask_head_stride + simd_gid * mask_kv_seq_stride + bmask += head_idx * mask_head_stride + simd_gid * mask_kv_seq_stride +
q_seq_idx * mask_q_seq_stride;
}
if (float_mask) {
fmask += head_idx * mask_head_stride + simd_gid * mask_kv_seq_stride +
q_seq_idx * mask_q_seq_stride; q_seq_idx * mask_q_seq_stride;
} }
@@ -77,7 +88,13 @@ template <typename T, int D, int V = D>
// For each key // For each key
for (int i = simd_gid; i < N; i += BN) { for (int i = simd_gid; i < N; i += BN) {
if (!has_mask || mask[0]) { bool use_key = true;
if (do_causal) {
use_key = i <= (N - int(tpg.y) + int(q_seq_idx));
} else if (bool_mask) {
use_key = bmask[0];
}
if (use_key) {
// Read the key // Read the key
for (int j = 0; j < qk_per_thread; j++) { for (int j = 0; j < qk_per_thread; j++) {
k[j] = keys[j]; k[j] = keys[j];
@@ -89,6 +106,9 @@ template <typename T, int D, int V = D>
score += q[j] * k[j]; score += q[j] * k[j];
} }
score = simd_sum(score); score = simd_sum(score);
if (float_mask) {
score += max(Limits<U>::finite_min, static_cast<U>(fmask[0]));
}
// Update the accumulators // Update the accumulators
U new_max = max(max_score, score); U new_max = max(max_score, score);
@@ -107,8 +127,11 @@ template <typename T, int D, int V = D>
// Move the pointers to the next kv // Move the pointers to the next kv
keys += inner_k_stride; keys += inner_k_stride;
values += inner_v_stride; values += inner_v_stride;
if (has_mask) { if (bool_mask) {
mask += BN * mask_kv_seq_stride; bmask += BN * mask_kv_seq_stride;
}
if (float_mask) {
fmask += BN * mask_kv_seq_stride;
} }
} }
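A minimal host-side sketch of the do_causal predicate introduced above, assuming tpg.y equals the query sequence length (an assumption, not stated in the diff); it enumerates which keys each query position may attend to when the key length exceeds the query length.
#include <cstdio>
int main() {
  const int N = 6;     // number of keys
  const int q_len = 3; // hypothetical tpg.y (query positions)
  for (int q = 0; q < q_len; q++) {
    std::printf("query %d attends keys:", q);
    for (int i = 0; i < N; i++) {
      bool use_key = i <= (N - q_len + q); // the do_causal condition above
      if (use_key) std::printf(" %d", i);
    }
    std::printf("\n");
  }
  return 0;
}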
@@ -149,17 +172,21 @@ template <typename T, int D, int V = D>
device float* out [[buffer(3)]], device float* out [[buffer(3)]],
device float* sums [[buffer(4)]], device float* sums [[buffer(4)]],
device float* maxs [[buffer(5)]], device float* maxs [[buffer(5)]],
const constant int& gqa_factor, const constant int& gqa_factor [[buffer(6)]],
const constant int& N, const constant int& N [[buffer(7)]],
const constant size_t& k_head_stride, const constant size_t& k_head_stride [[buffer(8)]],
const constant size_t& k_seq_stride, const constant size_t& k_seq_stride [[buffer(9)]],
const constant size_t& v_head_stride, const constant size_t& v_head_stride [[buffer(10)]],
const constant size_t& v_seq_stride, const constant size_t& v_seq_stride [[buffer(11)]],
const constant float& scale, const constant float& scale [[buffer(12)]],
const device bool* mask [[function_constant(has_mask)]], const device bool* bmask [[buffer(13), function_constant(bool_mask)]],
const constant int& mask_kv_seq_stride [[function_constant(has_mask)]], const device T* fmask [[buffer(14), function_constant(float_mask)]],
const constant int& mask_q_seq_stride [[function_constant(has_mask)]], const constant int& mask_kv_seq_stride
const constant int& mask_head_stride [[function_constant(has_mask)]], [[buffer(15), function_constant(has_mask)]],
const constant int& mask_q_seq_stride
[[buffer(16), function_constant(has_mask)]],
const constant int& mask_head_stride
[[buffer(17), function_constant(has_mask)]],
uint3 tid [[threadgroup_position_in_grid]], uint3 tid [[threadgroup_position_in_grid]],
uint3 tpg [[threadgroups_per_grid]], uint3 tpg [[threadgroups_per_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]],
@@ -197,8 +224,13 @@ template <typename T, int D, int V = D>
values += kv_head_idx * v_head_stride + values += kv_head_idx * v_head_stride +
(block_idx * BN + simd_gid) * v_seq_stride + simd_lid * v_per_thread; (block_idx * BN + simd_gid) * v_seq_stride + simd_lid * v_per_thread;
out += o_offset * blocks * V + block_idx * V + simd_lid * v_per_thread; out += o_offset * blocks * V + block_idx * V + simd_lid * v_per_thread;
if (has_mask) { if (bool_mask) {
mask += head_idx * mask_head_stride + bmask += head_idx * mask_head_stride +
(block_idx * BN + simd_gid) * mask_kv_seq_stride +
q_seq_idx * mask_q_seq_stride;
}
if (float_mask) {
fmask += head_idx * mask_head_stride +
(block_idx * BN + simd_gid) * mask_kv_seq_stride + (block_idx * BN + simd_gid) * mask_kv_seq_stride +
q_seq_idx * mask_q_seq_stride; q_seq_idx * mask_q_seq_stride;
} }
@@ -218,7 +250,13 @@ template <typename T, int D, int V = D>
// For each key // For each key
for (int i = block_idx * BN + simd_gid; i < N; i += blocks * BN) { for (int i = block_idx * BN + simd_gid; i < N; i += blocks * BN) {
if (!has_mask || mask[0]) { bool use_key = true;
if (do_causal) {
use_key = i <= (N - int(tpg.y) + int(q_seq_idx));
} else if (bool_mask) {
use_key = bmask[0];
}
if (use_key) {
// Read the key // Read the key
for (int i = 0; i < qk_per_thread; i++) { for (int i = 0; i < qk_per_thread; i++) {
k[i] = keys[i]; k[i] = keys[i];
@@ -230,6 +268,9 @@ template <typename T, int D, int V = D>
score += q[i] * k[i]; score += q[i] * k[i];
} }
score = simd_sum(score); score = simd_sum(score);
if (float_mask) {
score += fmask[0];
}
// Update the accumulators // Update the accumulators
U new_max = max(max_score, score); U new_max = max(max_score, score);
@@ -248,8 +289,11 @@ template <typename T, int D, int V = D>
// Move the pointers to the next kv // Move the pointers to the next kv
keys += blocks * inner_k_stride; keys += blocks * inner_k_stride;
values += blocks * inner_v_stride; values += blocks * inner_v_stride;
if (has_mask) { if (bool_mask) {
mask += BN * blocks * mask_kv_seq_stride; bmask += BN * blocks * mask_kv_seq_stride;
}
if (float_mask) {
fmask += BN * blocks * mask_kv_seq_stride;
} }
} }

View File

@@ -9,47 +9,13 @@ using namespace metal;
#include "mlx/backend/metal/kernels/utils.h" #include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/softmax.h" #include "mlx/backend/metal/kernels/softmax.h"
#define instantiate_softmax(name, itype) \ #define instantiate_softmax(name, itype) \
template [[host_name("block_softmax_" #name)]] [[kernel]] void \ instantiate_kernel("block_softmax_" #name, softmax_single_row, itype) \
softmax_single_row<itype>( \ instantiate_kernel("looped_softmax_" #name, softmax_looped, itype)
const device itype* in, \
device itype* out, \
constant int& axis_size, \
uint gid [[thread_position_in_grid]], \
uint _lid [[thread_position_in_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]); \
template [[host_name("looped_softmax_" #name)]] [[kernel]] void \
softmax_looped<itype>( \
const device itype* in, \
device itype* out, \
constant int& axis_size, \
uint gid [[threadgroup_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint lsize [[threads_per_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
#define instantiate_softmax_precise(name, itype) \ #define instantiate_softmax_precise(name, itype) \
template [[host_name("block_softmax_precise_" #name)]] [[kernel]] void \ instantiate_kernel("block_softmax_precise_" #name, softmax_single_row, itype, float) \
softmax_single_row<itype, float>( \ instantiate_kernel("looped_softmax_precise_" #name, softmax_looped, itype, float)
const device itype* in, \
device itype* out, \
constant int& axis_size, \
uint gid [[thread_position_in_grid]], \
uint _lid [[thread_position_in_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]); \
template [[host_name("looped_softmax_precise_" #name)]] [[kernel]] void \
softmax_looped<itype, float>( \
const device itype* in, \
device itype* out, \
constant int& axis_size, \
uint gid [[threadgroup_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint lsize [[threads_per_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
instantiate_softmax(float32, float) instantiate_softmax(float32, float)
instantiate_softmax(float16, half) instantiate_softmax(float16, half)

View File

@@ -229,7 +229,7 @@ template <
// Init to -Inf // Init to -Inf
STEEL_PRAGMA_UNROLL STEEL_PRAGMA_UNROLL
for (short i = 0; i < kRowsPT; ++i) { for (short i = 0; i < kRowsPT; ++i) {
max_score[i] = Limits<AccumType>::min; max_score[i] = Limits<AccumType>::finite_min;
} }
int kb_lim = params->NK; int kb_lim = params->NK;
@@ -237,6 +237,7 @@ template <
if (do_causal) { if (do_causal) {
int q_max = (tid.x + 1) * BQ + params->qL_off; int q_max = (tid.x + 1) * BQ + params->qL_off;
kb_lim = (q_max + BK - 1) / BK; kb_lim = (q_max + BK - 1) / BK;
kb_lim = min(params->NK, kb_lim);
} }
// Loop over KV seq length // Loop over KV seq length
@@ -272,7 +273,7 @@ template <
if (!align_K && kb == (params->NK_aligned)) { if (!align_K && kb == (params->NK_aligned)) {
using stile_t = decltype(Stile); using stile_t = decltype(Stile);
using selem_t = typename stile_t::elem_type; using selem_t = typename stile_t::elem_type;
constexpr auto neg_inf = -metal::numeric_limits<selem_t>::infinity(); constexpr auto neg_inf = Limits<selem_t>::finite_min;
STEEL_PRAGMA_UNROLL STEEL_PRAGMA_UNROLL
for (short i = 0; i < stile_t::kTileRows; i++) { for (short i = 0; i < stile_t::kTileRows; i++) {
@@ -290,10 +291,10 @@ template <
} }
// Mask out if causal // Mask out if causal
if (do_causal && kb >= (kb_lim - (BQ + BK - 1) / BK - int(!align_K))) { if (do_causal && kb >= (kb_lim - ((BQ + BK - 1) / BK) - int(!align_K))) {
using stile_t = decltype(Stile); using stile_t = decltype(Stile);
using selem_t = typename stile_t::elem_type; using selem_t = typename stile_t::elem_type;
constexpr auto neg_inf = -metal::numeric_limits<selem_t>::infinity(); constexpr auto neg_inf = Limits<selem_t>::finite_min;
STEEL_PRAGMA_UNROLL STEEL_PRAGMA_UNROLL
for (short i = 0; i < stile_t::kTileRows; i++) { for (short i = 0; i < stile_t::kTileRows; i++) {
@@ -316,7 +317,7 @@ template <
if (has_mask) { if (has_mask) {
using stile_t = decltype(Stile); using stile_t = decltype(Stile);
using selem_t = typename stile_t::elem_type; using selem_t = typename stile_t::elem_type;
constexpr auto neg_inf = -metal::numeric_limits<selem_t>::infinity(); constexpr auto neg_inf = Limits<selem_t>::finite_min;
constexpr bool is_bool = is_same_v<MaskType, bool>; constexpr bool is_bool = is_same_v<MaskType, bool>;
using melem_t = typename metal::conditional_t<is_bool, bool, selem_t>; using melem_t = typename metal::conditional_t<is_bool, bool, selem_t>;
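One plausible reading of the switch from -infinity to Limits<...>::finite_min above (an inference from the diff, not a statement from the authors): with an infinite sentinel a fully masked row produces (-inf) - (-inf) = NaN inside the exp, while a finite lowest value keeps everything finite and still underflows to zero against any real row maximum. A small host-side check:
#include <cmath>
#include <cstdio>
#include <limits>
int main() {
  const float neg_inf = -std::numeric_limits<float>::infinity();
  const float finite_min = std::numeric_limits<float>::lowest();
  // Partially masked row: against a real row maximum the masked weight is 0 either way.
  std::printf("masked weight vs finite max: %g\n", std::exp(finite_min - 2.5f));
  // Fully masked row: the running max equals the sentinel itself.
  std::printf("-inf sentinel:   %g\n", std::exp(neg_inf - neg_inf));       // nan
  std::printf("finite sentinel: %g\n", std::exp(finite_min - finite_min)); // 1, stays finite
  return 0;
}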

View File

@@ -15,10 +15,6 @@ constant bool align_M [[function_constant(200)]];
constant bool align_N [[function_constant(201)]]; constant bool align_N [[function_constant(201)]];
constant bool align_K [[function_constant(202)]]; constant bool align_K [[function_constant(202)]];
constant bool do_gather [[function_constant(300)]];
constant bool gather_bias = do_gather && use_out_source;
// clang-format off // clang-format off
template < template <
typename T, typename T,
@@ -39,12 +35,6 @@ template <
const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]], const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]],
const constant int* batch_shape [[buffer(6)]], const constant int* batch_shape [[buffer(6)]],
const constant int64_t* batch_strides [[buffer(7)]], const constant int64_t* batch_strides [[buffer(7)]],
const constant uint32_t* lhs_indices [[buffer(10), function_constant(do_gather)]],
const constant uint32_t* rhs_indices [[buffer(11), function_constant(do_gather)]],
const constant uint32_t* C_indices [[buffer(12), function_constant(gather_bias)]],
const constant int* operand_shape [[buffer(13), function_constant(do_gather)]],
const constant int64_t* operand_strides [[buffer(14), function_constant(do_gather)]],
const constant packed_int3& operand_batch_ndim [[buffer(15), function_constant(do_gather)]],
uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]], uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]], uint3 tid [[threadgroup_position_in_grid]],
@@ -81,84 +71,26 @@ template <
} }
// Adjust for batch // Adjust for batch
if (has_batch) {
const constant auto* A_bstrides = batch_strides;
const constant auto* B_bstrides = batch_strides + params->batch_ndim;
// Handle gather ulong2 batch_offsets = elem_to_loc_broadcast(
if (do_gather) { tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim);
// Read indices
uint32_t indx_A, indx_B, indx_C;
if (has_batch) { A += batch_offsets.x;
const constant auto* indx_A_bstrides = batch_strides; B += batch_offsets.y;
const constant auto* indx_B_bstrides = batch_strides + params->batch_ndim;
ulong2 indx_offsets = elem_to_loc_broadcast(
tid.z,
batch_shape,
indx_A_bstrides,
indx_B_bstrides,
params->batch_ndim);
indx_A = lhs_indices[indx_offsets.x];
indx_B = rhs_indices[indx_offsets.y];
if (use_out_source) {
const constant auto* indx_C_bstrides =
indx_B_bstrides + params->batch_ndim;
auto indx_offset_C = elem_to_loc(
tid.z, batch_shape, indx_C_bstrides, params->batch_ndim);
indx_C = C_indices[indx_offset_C];
}
} else {
indx_A = lhs_indices[params->batch_stride_a * tid.z];
indx_B = rhs_indices[params->batch_stride_b * tid.z];
if (use_out_source) {
indx_C = C_indices[addmm_params->batch_stride_c * tid.z];
}
}
// Translate indices to offsets
int batch_ndim_A = operand_batch_ndim.x;
const constant int* batch_shape_A = operand_shape;
const constant auto* batch_strides_A = operand_strides;
A += elem_to_loc(indx_A, batch_shape_A, batch_strides_A, batch_ndim_A);
int batch_ndim_B = operand_batch_ndim.y;
const constant int* batch_shape_B = batch_shape_A + batch_ndim_A;
const constant auto* batch_strides_B = batch_strides_A + batch_ndim_A;
B += elem_to_loc(indx_B, batch_shape_B, batch_strides_B, batch_ndim_B);
if (use_out_source) { if (use_out_source) {
int batch_ndim_C = operand_batch_ndim.z; const constant auto* C_bstrides = B_bstrides + params->batch_ndim;
const constant int* batch_shape_C = batch_shape_B + batch_ndim_B; C += elem_to_loc(tid.z, batch_shape, C_bstrides, params->batch_ndim);
const constant auto* batch_strides_C = batch_strides_B + batch_ndim_B;
C += elem_to_loc(indx_C, batch_shape_C, batch_strides_C, batch_ndim_C);
} }
} else {
A += params->batch_stride_a * tid.z;
B += params->batch_stride_b * tid.z;
} if (use_out_source) {
C += addmm_params->batch_stride_c * tid.z;
// Handle regular batch
else {
if (has_batch) {
const constant auto* A_bstrides = batch_strides;
const constant auto* B_bstrides = batch_strides + params->batch_ndim;
ulong2 batch_offsets = elem_to_loc_broadcast(
tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim);
A += batch_offsets.x;
B += batch_offsets.y;
if (use_out_source) {
const constant auto* C_bstrides = B_bstrides + params->batch_ndim;
C += elem_to_loc(tid.z, batch_shape, C_bstrides, params->batch_ndim);
}
} else {
A += params->batch_stride_a * tid.z;
B += params->batch_stride_b * tid.z;
if (use_out_source) {
C += addmm_params->batch_stride_c * tid.z;
}
} }
} }
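The batch handling above leans on elem_to_loc / elem_to_loc_broadcast to turn a flat batch index into per-operand offsets. Below is a generic host-side reimplementation for illustration (not the MLX helper itself), applied to a hypothetical broadcast stride pattern.
#include <cstdint>
#include <cstdio>
#include <vector>
int64_t elem_to_loc(
    int elem, const std::vector<int>& shape, const std::vector<int64_t>& strides) {
  int64_t loc = 0;
  for (int i = static_cast<int>(shape.size()) - 1; i >= 0; --i) {
    loc += static_cast<int64_t>(elem % shape[i]) * strides[i];
    elem /= shape[i];
  }
  return loc;
}
int main() {
  // Hypothetical batch of shape (2, 3); the first axis is broadcast (stride 0).
  std::vector<int> shape = {2, 3};
  std::vector<int64_t> strides = {0, 4096};
  for (int z = 0; z < 6; z++) {
    std::printf("tid.z=%d -> offset %lld\n", z,
                static_cast<long long>(elem_to_loc(z, shape, strides)));
  }
  return 0;
}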

View File

@@ -0,0 +1,459 @@
// Copyright © 2024 Apple Inc.
using namespace mlx::steel;
constant bool has_batch [[function_constant(10)]];
constant bool align_M [[function_constant(200)]];
constant bool align_N [[function_constant(201)]];
constant bool align_K [[function_constant(202)]];
template <
typename T,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_a,
bool transpose_b,
typename AccumType = float>
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void gather_mm_rhs(
const device T* A [[buffer(0)]],
const device T* B [[buffer(1)]],
const device uint32_t* rhs_indices [[buffer(2)]],
device T* C [[buffer(3)]],
const constant GEMMParams* params [[buffer(4)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]]) {
using gemm_kernel = GEMMKernel<
T,
T,
BM,
BN,
BK,
WM,
WN,
transpose_a,
transpose_b,
true,
true,
AccumType>;
using loader_a_t = typename gemm_kernel::loader_a_t;
using loader_b_t = typename gemm_kernel::loader_b_t;
using mma_t = typename gemm_kernel::mma_t;
if (params->tiles_n <= static_cast<int>(tid.x) ||
params->tiles_m <= static_cast<int>(tid.y)) {
return;
}
// Prepare threadgroup memory
threadgroup T As[gemm_kernel::tgp_mem_size_a];
threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
// Find the block in A, B, C
const int c_row = tid.y * BM;
const int c_col = tid.x * BN;
const size_t c_row_long = size_t(c_row);
const size_t c_col_long = size_t(c_col);
// Prepare threadgroup bounds
const short tgp_bm = align_M ? BM : short(min(BM, params->M - c_row));
const short tgp_bn = align_N ? BN : short(min(BN, params->N - c_col));
A += transpose_a ? c_row_long : c_row_long * params->lda;
B += transpose_b ? c_col_long * params->ldb : c_col_long;
C += c_row_long * params->ldd + c_col_long;
// Do as many matmuls as necessary
uint32_t index;
short offset;
uint32_t index_next = rhs_indices[c_row];
short offset_next = 0;
int n = 0;
while (n < tgp_bm) {
n++;
offset = offset_next;
index = index_next;
offset_next = tgp_bm;
for (; n < tgp_bm; n++) {
if (rhs_indices[c_row + n] != index) {
offset_next = n;
index_next = rhs_indices[c_row + n];
break;
}
}
threadgroup_barrier(mem_flags::mem_none);
// Prepare threadgroup mma operation
thread mma_t mma_op(simd_group_id, simd_lane_id);
// Prepare threadgroup loading operations
thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
thread loader_b_t loader_b(
B + index * params->batch_stride_b,
params->ldb,
Bs,
simd_group_id,
simd_lane_id);
// Prepare iterations
const int gemm_k_iterations = params->gemm_k_iterations_aligned;
// Do unaligned K iterations first
if (!align_K) {
const int k_last = params->gemm_k_iterations_aligned * BK;
const int k_remain = params->K - k_last;
const size_t k_jump_a =
transpose_a ? params->lda * size_t(k_last) : size_t(k_last);
const size_t k_jump_b =
transpose_b ? size_t(k_last) : params->ldb * size_t(k_last);
// Move loader source ahead to end
loader_a.src += k_jump_a;
loader_b.src += k_jump_b;
// Load tile
const short2 tile_dims_A =
transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
const short2 tile_dims_B =
transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);
loader_a.load_safe(tile_dims_A);
loader_b.load_safe(tile_dims_B);
threadgroup_barrier(mem_flags::mem_threadgroup);
// Do matmul
mma_op.mma(As, Bs);
// Reset source back to start
loader_a.src -= k_jump_a;
loader_b.src -= k_jump_b;
}
// Aligned at the matrix level so bounds are never checked
if (align_M && align_N) {
for (int k = 0; k < gemm_k_iterations; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup
loader_a.load_unsafe();
loader_b.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
// Store results to device memory
if (offset_next - offset == BM) {
mma_op.store_result(C, params->ldd);
} else {
mma_op.store_result_slice(
C, params->ldd, short2(0, offset), short2(BN, offset_next));
}
} else {
const short lbk = 0;
// Tile is aligned so no bounds checks are needed
if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
lbk,
LoopAlignment<true, true, true>{});
if (offset_next - offset == BM) {
mma_op.store_result(C, params->ldd);
} else {
mma_op.store_result_slice(
C, params->ldd, short2(0, offset), short2(BN, offset_next));
}
}
// Tile partially aligned check rows
else if (align_N || tgp_bn == BN) {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
lbk,
LoopAlignment<false, true, true>{});
mma_op.store_result_slice(
C, params->ldd, short2(0, offset), short2(BN, offset_next));
}
// Tile partially aligned check cols
else if (align_M || tgp_bm == BM) {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
lbk,
LoopAlignment<true, false, true>{});
mma_op.store_result_slice(
C, params->ldd, short2(0, offset), short2(tgp_bn, offset_next));
}
// Nothing aligned so check both rows and cols
else {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
lbk,
LoopAlignment<false, false, true>{});
mma_op.store_result_slice(
C, params->ldd, short2(0, offset), short2(tgp_bn, offset_next));
}
}
}
}
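The while loop at the top of gather_mm_rhs groups consecutive rows of the tile that share the same rhs index into [offset, offset_next) runs, issuing one matmul per run. A host-side C++ sketch of the same run detection over a made-up index vector:
#include <cstdio>
#include <vector>
int main() {
  // Hypothetical rhs_indices for the rows of one output tile.
  std::vector<unsigned> indices = {3, 3, 3, 7, 7, 1, 1, 1};
  const int tgp_bm = static_cast<int>(indices.size());
  int n = 0;
  unsigned index_next = indices[0];
  int offset_next = 0;
  while (n < tgp_bm) {
    n++;
    int offset = offset_next;
    unsigned index = index_next;
    offset_next = tgp_bm;
    for (; n < tgp_bm; n++) {
      if (indices[n] != index) {
        offset_next = n;
        index_next = indices[n];
        break;
      }
    }
    std::printf("matmul with B[%u] writes rows [%d, %d)\n", index, offset, offset_next);
  }
  return 0;
}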
template <
typename T,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_a,
bool transpose_b,
typename AccumType = float>
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void gather_mm(
const device T* A [[buffer(0)]],
const device T* B [[buffer(1)]],
const device uint32_t* lhs_indices [[buffer(2)]],
const device uint32_t* rhs_indices [[buffer(3)]],
device T* C [[buffer(4)]],
const constant GEMMParams* params [[buffer(5)]],
const constant int* indices_shape [[buffer(6)]],
const constant int64_t* lhs_strides [[buffer(7)]],
const constant int64_t* rhs_strides [[buffer(8)]],
const constant int& batch_ndim_a [[buffer(9)]],
const constant int* batch_shape_a [[buffer(10)]],
const constant int64_t* batch_strides_a [[buffer(11)]],
const constant int& batch_ndim_b [[buffer(12)]],
const constant int* batch_shape_b [[buffer(13)]],
const constant int64_t* batch_strides_b [[buffer(14)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]]) {
using gemm_kernel = GEMMKernel<
T,
T,
BM,
BN,
BK,
WM,
WN,
transpose_a,
transpose_b,
true,
true,
AccumType>;
using loader_a_t = typename gemm_kernel::loader_a_t;
using loader_b_t = typename gemm_kernel::loader_b_t;
using mma_t = typename gemm_kernel::mma_t;
if (params->tiles_n <= static_cast<int>(tid.x) ||
params->tiles_m <= static_cast<int>(tid.y)) {
return;
}
// Move A and B to the locations pointed by lhs_indices and rhs_indices.
uint32_t indx_A, indx_B;
if (has_batch) {
ulong2 indices_offsets = elem_to_loc_broadcast(
tid.z, indices_shape, lhs_strides, rhs_strides, params->batch_ndim);
indx_A = lhs_indices[indices_offsets.x];
indx_B = rhs_indices[indices_offsets.y];
} else {
indx_A = lhs_indices[params->batch_stride_a * tid.z];
indx_B = rhs_indices[params->batch_stride_b * tid.z];
}
A += elem_to_loc(indx_A, batch_shape_a, batch_strides_a, batch_ndim_a);
B += elem_to_loc(indx_B, batch_shape_b, batch_strides_b, batch_ndim_b);
C += params->batch_stride_d * tid.z;
// Prepare threadgroup memory
threadgroup T As[gemm_kernel::tgp_mem_size_a];
threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
// Just make sure everybody's finished with the indexing math above.
threadgroup_barrier(mem_flags::mem_none);
// Find block in A, B, C
const int c_row = tid.y * BM;
const int c_col = tid.x * BN;
const size_t c_row_long = size_t(c_row);
const size_t c_col_long = size_t(c_col);
A += transpose_a ? c_row_long : c_row_long * params->lda;
B += transpose_b ? c_col_long * params->ldb : c_col_long;
C += c_row_long * params->ldd + c_col_long;
// Prepare threadgroup mma operation
thread mma_t mma_op(simd_group_id, simd_lane_id);
// Prepare threadgroup loading operations
thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);
// Prepare threadgroup bounds
const short tgp_bm = align_M ? BM : short(min(BM, params->M - c_row));
const short tgp_bn = align_N ? BN : short(min(BN, params->N - c_col));
// Prepare iterations
int gemm_k_iterations = params->gemm_k_iterations_aligned;
// Do unaligned K iterations first
if (!align_K) {
const int k_last = params->gemm_k_iterations_aligned * BK;
const int k_remain = params->K - k_last;
const size_t k_jump_a =
transpose_a ? params->lda * size_t(k_last) : size_t(k_last);
const size_t k_jump_b =
transpose_b ? size_t(k_last) : params->ldb * size_t(k_last);
// Move loader source ahead to end
loader_a.src += k_jump_a;
loader_b.src += k_jump_b;
// Load tile
const short2 tile_dims_A =
transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
const short2 tile_dims_B =
transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);
loader_a.load_safe(tile_dims_A);
loader_b.load_safe(tile_dims_B);
threadgroup_barrier(mem_flags::mem_threadgroup);
// Do matmul
mma_op.mma(As, Bs);
// Reset source back to start
loader_a.src -= k_jump_a;
loader_b.src -= k_jump_b;
}
// Aligned at the matrix level so bounds are never checked
if (align_M && align_N) {
for (int k = 0; k < gemm_k_iterations; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup
loader_a.load_unsafe();
loader_b.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
// Store results to device memory
mma_op.store_result(C, params->ldd);
} else {
const short lbk = 0;
// Tile is aligned so no bounds checks are needed
if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
lbk,
LoopAlignment<true, true, true>{});
mma_op.store_result(C, params->ldd);
}
// Tile partially aligned check rows
else if (align_N || tgp_bn == BN) {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
lbk,
LoopAlignment<false, true, true>{});
mma_op.store_result_safe(C, params->ldd, short2(tgp_bn, tgp_bm));
}
// Tile partially aligned check cols
else if (align_M || tgp_bm == BM) {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
lbk,
LoopAlignment<true, false, true>{});
mma_op.store_result_safe(C, params->ldd, short2(tgp_bn, tgp_bm));
}
// Nothing aligned so check both rows and cols
else {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
lbk,
LoopAlignment<false, false, true>{});
mma_op.store_result_safe(C, params->ldd, short2(tgp_bn, tgp_bm));
}
}
}
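Both kernels above handle an unaligned K by processing the leftover columns first with bounds-checked loads and then running the aligned iterations. A tiny host-side sketch of that bookkeeping with hypothetical sizes:
#include <cstdio>
int main() {
  const int K = 100, BK = 16; // hypothetical reduction size and block size
  const int gemm_k_iterations_aligned = K / BK;      // 6 aligned iterations
  const int k_last = gemm_k_iterations_aligned * BK; // 96 columns covered by them
  const int k_remain = K - k_last;                   // 4 leftover columns, loaded with load_safe
  std::printf("aligned iterations=%d, k_last=%d, k_remain=%d\n",
              gemm_k_iterations_aligned, k_last, k_remain);
  return 0;
}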

View File

@@ -0,0 +1,59 @@
// Copyright © 2024 Apple Inc.
// clang-format off
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
#include "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h"
#define instantiate_gather_mm_rhs(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_kernel( \
"steel_gather_mm_rhs_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn \
"_bk" #bk "_wm" #wm "_wn" #wn, \
gather_mm_rhs, \
itype, \
bm, \
bn, \
bk, \
wm, \
wn, \
trans_a, \
trans_b, \
float)
#define instantiate_gather_mm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_kernel( \
"steel_gather_mm_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn \
"_bk" #bk "_wm" #wm "_wn" #wn, \
gather_mm, \
itype, \
bm, \
bn, \
bk, \
wm, \
wn, \
trans_a, \
trans_b, \
float)
#define instantiate_gather_mm_rhs_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gather_mm_rhs(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gather_mm_rhs(nt, false, true, iname, itype, oname, otype, bm, bn, bk, wm, wn)
#define instantiate_gather_mm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gather_mm(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gather_mm(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gather_mm(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gather_mm(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn)
#define instantiate_gather_mm_shapes_helper(iname, itype, oname, otype) \
instantiate_gather_mm_rhs_transpose_helper(iname, itype, oname, otype, 16, 64, 16, 1, 2) \
instantiate_gather_mm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) \
instantiate_gather_mm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 1, 2) \
instantiate_gather_mm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \
instantiate_gather_mm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 1, 2) \
instantiate_gather_mm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2)
// clang-format on
instantiate_gather_mm_shapes_helper(float16, half, float16, half);
instantiate_gather_mm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t);
instantiate_gather_mm_shapes_helper(float32, float, float32, float);

View File

@@ -142,6 +142,42 @@ struct BaseMMAFrag<T, 8, 8> {
} }
} }
template <
typename DstPtrType,
typename StrX,
typename StrY,
typename StartX,
typename StopX,
typename StartY,
typename StopY,
typename OffX,
typename OffY>
METAL_FUNC static constexpr void store_slice(
const thread frag_type& src,
DstPtrType dst,
StrX str_x,
StrY str_y,
StartX start_x,
StopX stop_x,
StartY start_y,
StopY stop_y,
OffX off_x = Int<0>{},
OffY off_y = Int<0>{}) {
using U = pointer_element_t<DstPtrType>;
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kElemRows; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < kElemCols; j++) {
if ((off_x + i) < stop_x && (off_x + i) >= start_x &&
(off_y + j) < stop_y && (off_y + j) >= start_y) {
dst[(off_x + i) * str_x + (off_y + j) * str_y] =
static_cast<U>(src[i * kElemCols + j]);
}
}
}
}
METAL_FUNC static constexpr void mma( METAL_FUNC static constexpr void mma(
thread frag_type& D, thread frag_type& D,
thread frag_type& A, thread frag_type& A,
@@ -335,6 +371,31 @@ struct MMATile {
} }
} }
} }
template <typename U, int w_x, int w_y>
METAL_FUNC void store_slice(
device U* dst,
const int ld,
const short2 start,
const short2 stop) const {
STEEL_PRAGMA_UNROLL
for (int i = 0; i < kTileRows; ++i) {
STEEL_PRAGMA_UNROLL
for (int j = 0; j < kTileCols; ++j) {
MMAFrag_t::store_slice(
frag_at(i, j),
dst,
ld,
Int<1>{},
start.y,
stop.y,
start.x,
stop.x,
(i * kFragRows) * w_x,
(j * kFragCols) * w_y);
}
}
}
}; };
template <typename T, typename U, int M, int N, int K> template <typename T, typename U, int M, int N, int K>
@@ -474,6 +535,26 @@ struct BlockMMA {
Ctile.template store<U, WM, WN>(D, ldd); Ctile.template store<U, WM, WN>(D, ldd);
} }
METAL_FUNC void
store_result_slice(device U* D, const int ldd, short2 start, short2 stop) {
// Apply epilogue
STEEL_PRAGMA_UNROLL
for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) {
Ctile.elems()[i] = Epilogue::apply(Ctile.elems()[i]);
}
D += sm * ldd + sn;
start -= short2(sn, sm);
stop -= short2(sn, sm);
// TODO: Check the start as well
if (stop.y <= 0 || stop.x <= 0) {
return;
}
Ctile.template store_slice<U, WM, WN>(D, ldd, start, stop);
}
METAL_FUNC void METAL_FUNC void
store_result_safe(device U* D, const int ldd, short2 dst_tile_dims) { store_result_safe(device U* D, const int ldd, short2 dst_tile_dims) {
// Apply epilogue // Apply epilogue
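The new store_slice / store_result_slice path writes an element only when its (row, col) position falls inside [start, stop) on both axes. A host-side sketch of that predicate with illustrative bounds:
#include <cstdio>
int main() {
  const int rows = 8, cols = 8;      // tile extent
  const int start_x = 2, stop_x = 5; // row slice [2, 5)
  const int start_y = 0, stop_y = 8; // all columns
  int written = 0;
  for (int i = 0; i < rows; i++) {
    for (int j = 0; j < cols; j++) {
      bool in_slice = i >= start_x && i < stop_x && j >= start_y && j < stop_y;
      if (in_slice) written++;
    }
  }
  std::printf("elements written: %d of %d\n", written, rows * cols); // 24 of 64
  return 0;
}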

View File

@@ -73,6 +73,9 @@ instantiate_unary_all_same(Conjugate, complex64, complex64_t)
instantiate_unary_all_same(Cos, complex64, complex64_t) instantiate_unary_all_same(Cos, complex64, complex64_t)
instantiate_unary_all_same(Cosh, complex64, complex64_t) instantiate_unary_all_same(Cosh, complex64, complex64_t)
instantiate_unary_all_same(Exp, complex64, complex64_t) instantiate_unary_all_same(Exp, complex64, complex64_t)
instantiate_unary_all_same(Log, complex64, complex64_t)
instantiate_unary_all_same(Log2, complex64, complex64_t)
instantiate_unary_all_same(Log10, complex64, complex64_t)
instantiate_unary_all_same(Negative, complex64, complex64_t) instantiate_unary_all_same(Negative, complex64, complex64_t)
instantiate_unary_all_same(Sign, complex64, complex64_t) instantiate_unary_all_same(Sign, complex64, complex64_t)
instantiate_unary_all_same(Sin, complex64, complex64_t) instantiate_unary_all_same(Sin, complex64, complex64_t)

View File

@@ -257,6 +257,13 @@ struct Log {
T operator()(T x) { T operator()(T x) {
return metal::precise::log(x); return metal::precise::log(x);
}; };
template <>
complex64_t operator()(complex64_t x) {
auto r = metal::precise::log(Abs{}(x).real);
auto i = metal::precise::atan2(x.imag, x.real);
return {r, i};
};
}; };
struct Log2 { struct Log2 {
@@ -264,6 +271,12 @@ struct Log2 {
T operator()(T x) { T operator()(T x) {
return metal::precise::log2(x); return metal::precise::log2(x);
}; };
template <>
complex64_t operator()(complex64_t x) {
auto y = Log{}(x);
return {y.real / M_LN2_F, y.imag / M_LN2_F};
};
}; };
struct Log10 { struct Log10 {
@@ -271,6 +284,12 @@ struct Log10 {
T operator()(T x) { T operator()(T x) {
return metal::precise::log10(x); return metal::precise::log10(x);
}; };
template <>
complex64_t operator()(complex64_t x) {
auto y = Log{}(x);
return {y.real / M_LN10_F, y.imag / M_LN10_F};
};
}; };
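A host-side check of the complex logarithm identities used above: log z = log|z| + i*atan2(Im z, Re z), with log2 and log10 obtained by dividing by ln 2 and ln 10. The input value is illustrative.
#include <cmath>
#include <complex>
#include <cstdio>
int main() {
  const std::complex<float> z(3.0f, 4.0f);
  const float ln2 = std::log(2.0f);
  const float ln10 = std::log(10.0f);
  float r = std::log(std::abs(z));          // log|z| = log 5
  float i = std::atan2(z.imag(), z.real()); // principal argument
  std::printf("log(z)   = (%f, %f)\n", r, i);
  std::printf("log2(z)  = (%f, %f)\n", r / ln2, i / ln2);
  std::printf("log10(z) = (%f, %f)\n", r / ln10, i / ln10);
  return 0;
}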
struct Log1p { struct Log1p {

View File

@@ -0,0 +1,96 @@
// Copyright © 2023-2024 Apple Inc.
#include <algorithm>
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"
namespace mlx::core {
constexpr int LOGSUMEXP_LOOPED_LIMIT = 4096;
void LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
if (!issubdtype(out.dtype(), floating)) {
throw std::runtime_error(
"[logsumexp] Does not support non-floating point types.");
}
auto& s = stream();
auto& d = metal::device(s.device);
// Make sure that the last dimension is contiguous
auto ensure_contiguous = [&s, &d](const array& x) {
if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
return x;
} else {
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
d.add_temporary(x_copy, s.index);
return x_copy;
}
};
auto in = ensure_contiguous(inputs[0]);
if (in.flags().row_contiguous) {
out.set_data(allocator::malloc(out.nbytes()));
} else {
auto n = in.shape(-1);
auto flags = in.flags();
auto strides = in.strides();
for (auto& s : strides) {
s /= n;
}
bool col_contig = strides[0] == 1;
for (int i = 1; col_contig && i < strides.size(); ++i) {
col_contig &=
(out.shape(i) == 1 || strides[i - 1] == out.shape(i) * strides[i]);
}
flags.col_contiguous = col_contig;
out.set_data(
allocator::malloc(in.nbytes() / n),
in.data_size() / n,
std::move(strides),
flags);
}
int axis_size = in.shape().back();
int n_rows = in.data_size() / axis_size;
const int simd_size = 32;
const int n_reads = 4;
const int looped_limit = LOGSUMEXP_LOOPED_LIMIT;
std::string kernel_name = (axis_size > looped_limit) ? "looped_" : "block_";
kernel_name += "logsumexp_";
kernel_name += type_to_name(out);
auto kernel = get_logsumexp_kernel(d, kernel_name, out);
auto& compute_encoder = d.get_command_encoder(s.index);
{
MTL::Size grid_dims, group_dims;
if (axis_size <= looped_limit) {
size_t threadgroup_needed = (axis_size + n_reads - 1) / n_reads;
size_t simds_needed = (threadgroup_needed + simd_size - 1) / simd_size;
size_t threadgroup_size = simd_size * simds_needed;
assert(threadgroup_size <= kernel->maxTotalThreadsPerThreadgroup());
size_t n_threads = n_rows * threadgroup_size;
grid_dims = MTL::Size(n_threads, 1, 1);
group_dims = MTL::Size(threadgroup_size, 1, 1);
} else {
size_t threadgroup_size = kernel->maxTotalThreadsPerThreadgroup();
size_t n_threads = n_rows * threadgroup_size;
grid_dims = MTL::Size(n_threads, 1, 1);
group_dims = MTL::Size(threadgroup_size, 1, 1);
}
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(out, 1);
compute_encoder.set_bytes(axis_size, 2);
compute_encoder.dispatch_threads(grid_dims, group_dims);
}
}
} // namespace mlx::core
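A sketch of the dispatch sizing above for the block (non-looped) logsumexp kernel: each row needs ceil(axis_size / n_reads) threads, rounded up to whole 32-wide simdgroups. The axis sizes below are arbitrary examples.
#include <cstdio>
int main() {
  const int simd_size = 32, n_reads = 4;
  for (int axis_size : {128, 1000, 4096}) {
    int threadgroup_needed = (axis_size + n_reads - 1) / n_reads;
    int simds_needed = (threadgroup_needed + simd_size - 1) / simd_size;
    int threadgroup_size = simd_size * simds_needed;
    std::printf("axis_size=%d -> threadgroup of %d threads\n", axis_size, threadgroup_size);
  }
  return 0;
}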

View File

@@ -5,6 +5,7 @@
#include <numeric> #include <numeric>
#include <sstream> #include <sstream>
#include "mlx/backend/common/broadcasting.h"
#include "mlx/backend/common/utils.h" #include "mlx/backend/common/utils.h"
#include "mlx/backend/metal/copy.h" #include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h" #include "mlx/backend/metal/device.h"
@@ -102,6 +103,47 @@ std::tuple<bool, int64_t, array> check_transpose(
} }
}; };
inline array
ensure_row_contiguous(const array& x, metal::Device& d, const Stream& s) {
if (!x.flags().row_contiguous) {
array x_copy(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
d.add_temporary(x_copy, s.index);
return x_copy;
} else {
return x;
}
}
inline std::tuple<bool, int64_t, array>
ensure_batch_contiguous(const array& x, metal::Device& d, const Stream& s) {
if (x.flags().row_contiguous) {
return std::make_tuple(false, x.strides()[x.ndim() - 2], x);
}
bool rc = true;
for (int i = 0; i < x.ndim() - 3; i++) {
rc &= x.strides()[i + 1] * x.shape(i) == x.strides()[i];
}
if (rc) {
auto stx = x.strides()[x.ndim() - 2];
auto sty = x.strides()[x.ndim() - 1];
auto K = x.shape(-2);
auto N = x.shape(-1);
if (sty == 1 && (N != 1 || stx == N)) {
return std::make_tuple(false, stx, x);
}
if (stx == 1 && (N != 1 || sty == K)) {
return std::make_tuple(true, sty, x);
}
}
array x_copy(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
d.add_temporary(x_copy, s.index);
return std::make_tuple(false, x_copy.strides()[x_copy.ndim() - 2], x_copy);
}
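ensure_batch_contiguous above decides whether each (K, N) matrix in the batch can be fed to the kernel directly (returning a transpose flag and leading dimension) or has to be copied. A host-side sketch of just the stride test, over made-up views:
#include <cstdio>
int main() {
  struct View { long stx, sty; int K, N; }; // strides of dims (-2, -1) and the matrix shape
  const View views[] = {
      {8, 1, 4, 8},  // row-contiguous: ldb = 8, transpose = false
      {1, 4, 4, 8},  // column-contiguous: ldb = 4, transpose = true
      {16, 2, 4, 8}, // neither: would be copied to a row-contiguous array
  };
  for (const auto& v : views) {
    if (v.sty == 1 && (v.N != 1 || v.stx == v.N)) {
      std::printf("use as-is, ldb=%ld, transpose=false\n", v.stx);
    } else if (v.stx == 1 && (v.N != 1 || v.sty == v.K)) {
      std::printf("use as-is, ldb=%ld, transpose=true\n", v.sty);
    } else {
      std::printf("copy to row-contiguous\n");
    }
  }
  return 0;
}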
} // namespace } // namespace
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
@@ -230,7 +272,6 @@ void steel_matmul_regular(
const bool align_M = (M % bm) == 0; const bool align_M = (M % bm) == 0;
const bool align_N = (N % bn) == 0; const bool align_N = (N % bn) == 0;
const bool align_K = (K % bk) == 0; const bool align_K = (K % bk) == 0;
const bool do_gather = false;
metal::MTLFCList func_consts = { metal::MTLFCList func_consts = {
{&has_batch, MTL::DataType::DataTypeBool, 10}, {&has_batch, MTL::DataType::DataTypeBool, 10},
@@ -239,7 +280,6 @@ void steel_matmul_regular(
{&align_M, MTL::DataType::DataTypeBool, 200}, {&align_M, MTL::DataType::DataTypeBool, 200},
{&align_N, MTL::DataType::DataTypeBool, 201}, {&align_N, MTL::DataType::DataTypeBool, 201},
{&align_K, MTL::DataType::DataTypeBool, 202}, {&align_K, MTL::DataType::DataTypeBool, 202},
{&do_gather, MTL::DataType::DataTypeBool, 300},
}; };
// clang-format off // clang-format off
@@ -248,8 +288,7 @@ void steel_matmul_regular(
<< "_do_axpby_" << (do_axpby ? 't' : 'n') << "_do_axpby_" << (do_axpby ? 't' : 'n')
<< "_align_M_" << (align_M ? 't' : 'n') << "_align_M_" << (align_M ? 't' : 'n')
<< "_align_N_" << (align_N ? 't' : 'n') << "_align_N_" << (align_N ? 't' : 'n')
<< "_align_K_" << (align_K ? 't' : 'n') << "_align_K_" << (align_K ? 't' : 'n'); // clang-format on
<< "_do_gather_" << (do_gather ? 't' : 'n'); // clang-format on
std::string hash_name = kname.str(); std::string hash_name = kname.str();
@@ -975,7 +1014,6 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
const bool align_M = (M % bm) == 0; const bool align_M = (M % bm) == 0;
const bool align_N = (N % bn) == 0; const bool align_N = (N % bn) == 0;
const bool align_K = (K % bk) == 0; const bool align_K = (K % bk) == 0;
const bool do_gather = false;
metal::MTLFCList func_consts = { metal::MTLFCList func_consts = {
{&has_batch, MTL::DataType::DataTypeBool, 10}, {&has_batch, MTL::DataType::DataTypeBool, 10},
@@ -984,7 +1022,6 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
{&align_M, MTL::DataType::DataTypeBool, 200}, {&align_M, MTL::DataType::DataTypeBool, 200},
{&align_N, MTL::DataType::DataTypeBool, 201}, {&align_N, MTL::DataType::DataTypeBool, 201},
{&align_K, MTL::DataType::DataTypeBool, 202}, {&align_K, MTL::DataType::DataTypeBool, 202},
{&do_gather, MTL::DataType::DataTypeBool, 300},
}; };
// clang-format off // clang-format off
@@ -993,8 +1030,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
<< "_do_axpby_" << (do_axpby ? 't' : 'n') << "_do_axpby_" << (do_axpby ? 't' : 'n')
<< "_align_M_" << (align_M ? 't' : 'n') << "_align_M_" << (align_M ? 't' : 'n')
<< "_align_N_" << (align_N ? 't' : 'n') << "_align_N_" << (align_N ? 't' : 'n')
<< "_align_K_" << (align_K ? 't' : 'n') << "_align_K_" << (align_K ? 't' : 'n'); // clang-format on
<< "_do_gather_" << (do_gather ? 't' : 'n'); // clang-format on
std::string hash_name = kname.str(); std::string hash_name = kname.str();
@@ -1464,267 +1500,337 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
d.add_temporaries(std::move(copies), s.index); d.add_temporaries(std::move(copies), s.index);
} }
void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) { void gather_mm_rhs(
using namespace mlx::steel; const array& a_,
// assert(inputs.size() == 2); const array& b_,
if (!issubdtype(out.dtype(), floating)) { const array& indices_,
throw std::runtime_error( array& out,
"[GatherMM] Does not yet support non-floating point types."); metal::Device& d,
} const Stream& s) {
auto& s = stream(); array indices = ensure_row_contiguous(indices_, d, s);
auto& d = metal::device(s.device); auto [transpose_b, ldb, b] = ensure_batch_contiguous(b_, d, s);
auto& a_pre = inputs[0]; // Broadcast a with indices. If we are here that means lhs_indices were not
auto& b_pre = inputs[1]; // provided so the lhs_indices are implied to be the shape of a broadcasted
// Return 0s if either input is empty // with rhs_indices. We need only broadcast a and copy it as if applying the
if (a_pre.size() == 0 || b_pre.size() == 0) { // lhs_indices.
array zero = array(0, a_pre.dtype()); auto broadcast_with_indices = [&d, &s, &indices](const array& x) {
fill_gpu(zero, out, s); if (x.size() / x.shape(-2) / x.shape(-1) == indices.size()) {
d.add_temporary(std::move(zero), s.index); return ensure_row_contiguous(x, d, s);
return; }
}
out.set_data(allocator::malloc(out.nbytes())); auto x_shape = indices.shape();
x_shape.push_back(x.shape(-2));
x_shape.push_back(x.shape(-1));
array new_x(std::move(x_shape), x.dtype(), nullptr, {});
broadcast(x, new_x);
return ensure_row_contiguous(new_x, d, s);
};
array a = broadcast_with_indices(a_);
///////////////////////////////////////////////////////////////////////////// // Extract the matmul shapes
// Init checks and prep int K = a.shape(-1);
int M = a.size() / K;
int N = b.shape(-1);
int lda = a.strides()[a.ndim() - 2]; // should be K
int M = a_pre.shape(-2); // Define the dispatch blocks
int N = b_pre.shape(-1); int bm = 16, bn = 64, bk = 16;
int K = a_pre.shape(-1); int wm = 1, wn = 2;
// Keep a vector with copies to be cleared in the completed buffer to release const bool align_M = (M % bm) == 0;
// the arrays const bool align_N = (N % bn) == 0;
std::vector<array> copies; const bool align_K = (K % bk) == 0;
auto [transpose_a, a_cols, a] = check_transpose(copies, s, a_pre, M == 1);
auto [transpose_b, b_cols, b] = check_transpose(copies, s, b_pre, N == 1);
int lda = a_cols; // Define the kernel name
int ldb = b_cols; std::string base_name;
base_name.reserve(64);
concatenate(
base_name,
"steel_gather_mm_rhs_n",
transpose_b ? 't' : 'n',
'_',
type_to_name(a),
'_',
type_to_name(out),
"_bm",
bm,
"_bn",
bn,
"_bk",
bk,
"_wm",
wm,
"_wn",
wn);
///////////////////////////////////////////////////////////////////////////// metal::MTLFCList func_consts = {
// Check and collapse batch dimensions {&align_M, MTL::DataType::DataTypeBool, 200},
{&align_N, MTL::DataType::DataTypeBool, 201},
auto get_batch_dims = [](const auto& v) { {&align_K, MTL::DataType::DataTypeBool, 202},
return decltype(v){v.begin(), v.end() - 2};
}; };
auto& lhs_indices = inputs[2]; // And the kernel hash that includes the function constants
auto& rhs_indices = inputs[3]; std::string hash_name;
hash_name.reserve(128);
concatenate(
hash_name,
base_name,
"_align_M_",
align_M ? 't' : 'n',
"_align_N_",
align_N ? 't' : 'n',
"_align_K_",
align_K ? 't' : 'n');
Shape batch_shape = get_batch_dims(out.shape()); // Get and set the kernel
Strides batch_strides; auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = get_steel_gemm_gather_kernel(
d,
base_name,
hash_name,
func_consts,
out,
false,
transpose_b,
bm,
bn,
bk,
wm,
wn,
true);
compute_encoder.set_compute_pipeline_state(kernel);
batch_strides.insert( // Prepare the matmul params
batch_strides.end(), auto batch_stride_b = b.ndim() > 2 ? b.strides()[b.ndim() - 3] : b.size();
lhs_indices.strides().begin(), steel::GEMMParams params{
lhs_indices.strides().end()); /* const int M = */ M,
auto lhs_indices_str = batch_strides.empty() ? 0 : batch_strides.back(); /* const int N = */ N,
/* const int K = */ K,
/* const int lda = */ lda,
/* const int ldb = */ static_cast<int>(ldb),
/* const int ldd = */ N,
/* const int tiles_n = */ (N + bn - 1) / bn,
/* const int tiles_m = */ (M + bm - 1) / bm,
/* const int64_t batch_stride_a = */ 0,
/* const int64_t batch_stride_b = */ static_cast<int64_t>(batch_stride_b),
/* const int64_t batch_stride_d = */ 0,
/* const int swizzle_log = */ 0,
/* const int gemm_k_iterations_aligned = */ (K / bk),
/* const int batch_ndim = */ 0};
batch_strides.insert( // Prepare the grid
batch_strides.end(), MTL::Size group_dims = MTL::Size(32, wn, wm);
rhs_indices.strides().begin(), MTL::Size grid_dims = MTL::Size(params.tiles_n, params.tiles_m, 1);
rhs_indices.strides().end());
auto rhs_indices_str = batch_strides.empty() ? 0 : batch_strides.back();
int batch_ndim = batch_shape.size(); // Launch kernel
compute_encoder.set_input_array(a, 0);
compute_encoder.set_input_array(b, 1);
compute_encoder.set_input_array(indices, 2);
compute_encoder.set_output_array(out, 3);
compute_encoder.set_bytes(params, 4);
if (batch_ndim == 0) { compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
batch_shape = {1}; }
batch_strides = {0};
}
int batch_ndim_A = a.ndim() - 2; void gather_mv(
int batch_ndim_B = b.ndim() - 2; const array& mat_,
std::vector<int> operand_batch_ndim = {batch_ndim_A, batch_ndim_B}; const array& vec_,
const array& mat_indices_,
const array& vec_indices_,
array& out,
int N,
int K,
bool is_mv,
metal::Device& d,
const Stream& s) {
// Copy if needed
std::vector<array> copies;
auto [transpose_mat, mat_cols, mat] =
check_transpose(copies, s, mat_, N == 1);
auto [transpose_vec, vec_cols, vec] = check_transpose(copies, s, vec_, true);
d.add_temporaries(std::move(copies), s.index);
Shape batch_shape_A = get_batch_dims(a.shape()); // If we are doing vector matrix instead of matrix vector we need to flip the
Strides batch_strides_A = get_batch_dims(a.strides()); // matrix transposition. Basically m @ v = v @ m.T assuming that v is treated
Shape batch_shape_B = get_batch_dims(b.shape()); // as a one dimensional array.
Strides batch_strides_B = get_batch_dims(b.strides()); transpose_mat = (!is_mv) ^ transpose_mat;
if (batch_ndim_A == 0) { // Define some shapes
batch_shape_A = {1}; int in_vector_len = K;
batch_strides_A = {0}; int out_vector_len = N;
} int mat_ld = mat_cols;
if (batch_ndim_B == 0) { int batch_size_out = out.size() / N;
batch_shape_B = {1}; int batch_ndim = out.ndim() - 2;
batch_strides_B = {0}; int batch_ndim_mat = mat.ndim() - 2;
} int batch_ndim_vec = vec.ndim() - 2;
Strides index_strides = vec_indices_.strides();
index_strides.insert(
index_strides.end(),
mat_indices_.strides().begin(),
mat_indices_.strides().end());
auto matrix_stride_out = static_cast<int64_t>(M) * N; // Determine dispatch kernel
auto batch_size_out = out.size() / matrix_stride_out; int tm = 4, tn = 4;
int sm = 1, sn = 32;
///////////////////////////////////////////////////////////////////////////// int bm = 1, bn = 1;
// Gemv specialization int n_out_per_tgp;
std::ostringstream kname;
// Route to gemv if needed
if (std::min(M, N) == 1) {
// Collect problem info
bool is_b_matrix = N != 1;
auto& mat = is_b_matrix ? b : a;
auto& vec = is_b_matrix ? a : b;
bool transpose_mat = is_b_matrix ? !transpose_b : transpose_a;
int in_vector_len = K;
int out_vector_len = is_b_matrix ? N : M;
int mat_cols = transpose_mat ? out_vector_len : in_vector_len;
int mat_rows = transpose_mat ? in_vector_len : out_vector_len;
int mat_ld = is_b_matrix ? b_cols : a_cols;
auto batch_strides_mat = is_b_matrix ? batch_strides_B : batch_strides_A;
auto batch_strides_vec = is_b_matrix ? batch_strides_A : batch_strides_B;
auto batch_shape_mat = is_b_matrix ? batch_shape_B : batch_shape_A;
auto batch_shape_vec = is_b_matrix ? batch_shape_A : batch_shape_B;
if (!is_b_matrix) {
batch_strides = rhs_indices.strides();
batch_strides.insert(
batch_strides.end(),
lhs_indices.strides().begin(),
lhs_indices.strides().end());
}
int batch_ndim = batch_shape.size();
// Determine dispatch kernel
int tm = 4, tn = 4;
int sm = 1, sn = 32;
int bm = 1, bn = 1;
int n_out_per_tgp;
std::ostringstream kname;
if (transpose_mat) {
if (in_vector_len >= 8192 && out_vector_len >= 2048) {
sm = 4;
sn = 8;
} else {
sm = 8;
sn = 4;
}
if (out_vector_len >= 2048) {
bn = 16;
} else if (out_vector_len >= 512) {
bn = 4;
} else {
bn = 2;
}
// Specialized kernel for very small outputs
tn = out_vector_len < tn ? 1 : tn;
n_out_per_tgp = bn * sn * tn;
kname << "gemv_t_gather_" << type_to_name(out);
if (transpose_mat) {
if (in_vector_len >= 8192 && out_vector_len >= 2048) {
sm = 4;
sn = 8;
} else { } else {
bm = out_vector_len >= 4096 ? 8 : 4; sm = 8;
sn = 32; sn = 4;
// Specialized kernel for very small outputs
tm = out_vector_len < tm ? 1 : tm;
n_out_per_tgp = bm * sm * tm;
kname << "gemv_gather_" << type_to_name(out);
} }
kname << "_bm" << bm << "_bn" << bn << "_sm" << sm << "_sn" << sn << "_tm" if (out_vector_len >= 2048) {
<< tm << "_tn" << tn; bn = 16;
} else if (out_vector_len >= 512) {
bn = 4;
} else {
bn = 2;
}
// Encode and dispatch kernel // Specialized kernel for very small outputs
auto& compute_encoder = d.get_command_encoder(s.index); tn = out_vector_len < tn ? 1 : tn;
auto kernel = d.get_kernel(kname.str());
compute_encoder.set_compute_pipeline_state(kernel);
int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp; n_out_per_tgp = bn * sn * tn;
MTL::Size group_dims = MTL::Size(32, bn, bm); kname << "gemv_t_gather_" << type_to_name(out);
MTL::Size grid_dims = MTL::Size(n_tgp, 1, batch_size_out);
compute_encoder.set_input_array(mat, 0); } else {
compute_encoder.set_input_array(vec, 1); bm = out_vector_len >= 4096 ? 8 : 4;
compute_encoder.set_output_array(out, 3); sn = 32;
compute_encoder.set_bytes(in_vector_len, 4); // Specialized kernel for very small outputs
compute_encoder.set_bytes(out_vector_len, 5); tm = out_vector_len < tm ? 1 : tm;
compute_encoder.set_bytes(mat_ld, 6);
compute_encoder.set_bytes(batch_ndim, 9); n_out_per_tgp = bm * sm * tm;
compute_encoder.set_vector_bytes(batch_shape, 10); kname << "gemv_gather_" << type_to_name(out);
compute_encoder.set_vector_bytes(batch_strides, 11);
int batch_ndim_vec = batch_shape_vec.size();
compute_encoder.set_bytes(batch_ndim_vec, 12);
compute_encoder.set_vector_bytes(batch_shape_vec, 13);
compute_encoder.set_vector_bytes(batch_strides_vec, 14);
int batch_ndim_mat = batch_shape_mat.size();
compute_encoder.set_bytes(batch_ndim_mat, 15);
compute_encoder.set_vector_bytes(batch_shape_mat, 16);
compute_encoder.set_vector_bytes(batch_strides_mat, 17);
compute_encoder.set_input_array(lhs_indices, 18 + int(!is_b_matrix));
compute_encoder.set_input_array(rhs_indices, 18 + int(is_b_matrix));
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
d.add_temporaries(std::move(copies), s.index);
return;
} }
///////////////////////////////////////////////////////////////////////////// kname << "_bm" << bm << "_bn" << bn << "_sm" << sm << "_sn" << sn << "_tm"
// Regular kernel dispatch << tm << "_tn" << tn;
// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder.set_compute_pipeline_state(kernel);
int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp;
MTL::Size group_dims = MTL::Size(32, bn, bm);
MTL::Size grid_dims = MTL::Size(n_tgp, 1, batch_size_out);
compute_encoder.set_input_array(mat, 0);
compute_encoder.set_input_array(vec, 1);
compute_encoder.set_output_array(out, 3);
compute_encoder.set_bytes(in_vector_len, 4);
compute_encoder.set_bytes(out_vector_len, 5);
compute_encoder.set_bytes(mat_ld, 6);
compute_encoder.set_bytes(batch_ndim, 9);
compute_encoder.set_vector_bytes(out.shape(), 10);
compute_encoder.set_vector_bytes(index_strides, 11);
compute_encoder.set_bytes(batch_ndim_vec, 12);
compute_encoder.set_vector_bytes(vec.shape(), 13);
compute_encoder.set_vector_bytes(vec.strides(), 14);
compute_encoder.set_bytes(batch_ndim_mat, 15);
compute_encoder.set_vector_bytes(mat.shape(), 16);
compute_encoder.set_vector_bytes(mat.strides(), 17);
compute_encoder.set_input_array(vec_indices_, 18);
compute_encoder.set_input_array(mat_indices_, 19);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
void gather_mm(
const array& a_,
const array& b_,
const array& lhs_indices,
const array& rhs_indices,
array& out,
int M,
int N,
int K,
metal::Device& d,
const Stream& s) {
// Copy if needed
std::vector<array> copies;
auto [transpose_a, lda, a] = check_transpose(copies, s, a_, false);
auto [transpose_b, ldb, b] = check_transpose(copies, s, b_, false);
d.add_temporaries(std::move(copies), s.index);
// Determine dispatch kernel // Determine dispatch kernel
int bm = 64, bn = 64, bk = 16; int bm = 64, bn = 64, bk = 16;
int wm = 2, wn = 2; int wm = 2, wn = 2;
size_t batch_size_out = out.size() / M / N;
int batch_ndim = out.ndim() - 2;
int batch_ndim_a = a.ndim() - 2;
int batch_ndim_b = b.ndim() - 2;
char devc = d.get_architecture().back(); char devc = d.get_architecture().back();
GEMM_TPARAM_MACRO(devc) GEMM_TPARAM_MACRO(devc)
// Prepare kernel name
std::ostringstream kname;
kname << "steel_gemm_fused_" << (transpose_a ? 't' : 'n')
<< (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_"
<< type_to_name(out) << "_bm" << bm << "_bn" << bn << "_bk" << bk
<< "_wm" << wm << "_wn" << wn;
std::string base_name = kname.str();
const bool has_batch = batch_ndim > 1; const bool has_batch = batch_ndim > 1;
const bool use_out_source = false;
const bool do_axpby = false;
const bool align_M = (M % bm) == 0; const bool align_M = (M % bm) == 0;
const bool align_N = (N % bn) == 0; const bool align_N = (N % bn) == 0;
const bool align_K = (K % bk) == 0; const bool align_K = (K % bk) == 0;
const bool do_gather = true;
// Define the kernel name
std::string base_name;
base_name.reserve(128);
concatenate(
base_name,
"steel_gather_mm_",
transpose_a ? 't' : 'n',
transpose_b ? 't' : 'n',
"_",
type_to_name(a),
"_",
type_to_name(out),
"_bm",
bm,
"_bn",
bn,
"_bk",
bk,
"_wm",
wm,
"_wn",
wn);
metal::MTLFCList func_consts = { metal::MTLFCList func_consts = {
{&has_batch, MTL::DataType::DataTypeBool, 10}, {&has_batch, MTL::DataType::DataTypeBool, 10},
{&use_out_source, MTL::DataType::DataTypeBool, 100},
{&do_axpby, MTL::DataType::DataTypeBool, 110},
{&align_M, MTL::DataType::DataTypeBool, 200}, {&align_M, MTL::DataType::DataTypeBool, 200},
{&align_N, MTL::DataType::DataTypeBool, 201}, {&align_N, MTL::DataType::DataTypeBool, 201},
{&align_K, MTL::DataType::DataTypeBool, 202}, {&align_K, MTL::DataType::DataTypeBool, 202},
{&do_gather, MTL::DataType::DataTypeBool, 300},
}; };
// clang-format off // And the kernel hash that includes the function constants
kname << "_has_batch_" << (has_batch ? 't' : 'n') std::string hash_name;
<< "_use_out_source_" << (use_out_source ? 't' : 'n') hash_name.reserve(128);
<< "_do_axpby_" << (do_axpby ? 't' : 'n') concatenate(
<< "_align_M_" << (align_M ? 't' : 'n') hash_name,
<< "_align_N_" << (align_N ? 't' : 'n') base_name,
<< "_align_K_" << (align_K ? 't' : 'n') "_has_batch_",
<< "_do_gather_" << (do_gather ? 't' : 'n'); // clang-format on has_batch ? 't' : 'n',
"_align_M_",
align_M ? 't' : 'n',
"_align_N_",
align_N ? 't' : 'n',
"_align_K_",
align_K ? 't' : 'n');
std::string hash_name = kname.str(); // Get and set the kernel
// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = get_steel_gemm_fused_kernel( auto kernel = get_steel_gemm_gather_kernel(
d, d,
base_name, base_name,
hash_name, hash_name,
@@ -1736,72 +1842,96 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
bn, bn,
bk, bk,
wm, wm,
wn); wn,
false);
compute_encoder.set_compute_pipeline_state(kernel); compute_encoder.set_compute_pipeline_state(kernel);
// Use problem size to determine threadblock swizzle // Prepare the matmul params
int tn = (N + bn - 1) / bn; steel::GEMMParams params{
int tm = (M + bm - 1) / bm;
// TODO: Explore device-based tuning for swizzle
int swizzle_log = 0; // tm >= 6 ? 3 : (tm <= 3 ? 0 : 2);
// Prepare steel matmul params
GEMMParams params{
/* const int M = */ M, /* const int M = */ M,
/* const int N = */ N, /* const int N = */ N,
/* const int K = */ K, /* const int K = */ K,
/* const int lda = */ lda, /* const int lda = */ static_cast<int>(lda),
/* const int ldb = */ ldb, /* const int ldb = */ static_cast<int>(ldb),
/* const int ldd = */ N, /* const int ldd = */ N,
/* const int tiles_n = */ tn, /* const int tiles_n = */ (N + bn - 1) / bn,
/* const int tiles_m = */ tm, /* const int tiles_m = */ (M + bm - 1) / bm,
/* const int64_t batch_stride_a = */ lhs_indices_str, /* const int64_t batch_stride_a = */
/* const int64_t batch_stride_b = */ rhs_indices_str, (batch_ndim > 0) ? lhs_indices.strides()[0] : 0,
/* const int64_t batch_stride_d = */ matrix_stride_out, /* const int64_t batch_stride_b = */
/* const int swizzle_log = */ swizzle_log, (batch_ndim > 0) ? rhs_indices.strides()[0] : 0,
/* const int64_t batch_stride_d = */ M * N,
/* const int swizzle_log = */ 0,
/* const int gemm_k_iterations_aligned = */ (K / bk), /* const int gemm_k_iterations_aligned = */ (K / bk),
/* const int batch_ndim = */ batch_ndim}; /* const int batch_ndim = */ batch_ndim};
// Prepare launch grid params // Prepare the grid
int tile = 1 << swizzle_log;
tm = (tm + tile - 1) / tile;
tn = tn * tile;
MTL::Size group_dims = MTL::Size(32, wn, wm); MTL::Size group_dims = MTL::Size(32, wn, wm);
MTL::Size grid_dims = MTL::Size(tn, tm, batch_size_out); MTL::Size grid_dims =
MTL::Size(params.tiles_n, params.tiles_m, batch_size_out);
// Launch kernel // Launch kernel
compute_encoder.set_input_array(a, 0); compute_encoder.set_input_array(a, 0);
compute_encoder.set_input_array(b, 1); compute_encoder.set_input_array(b, 1);
compute_encoder.set_output_array(out, 3); compute_encoder.set_input_array(lhs_indices, 2);
compute_encoder.set_input_array(rhs_indices, 3);
compute_encoder.set_bytes(params, 4); compute_encoder.set_output_array(out, 4);
compute_encoder.set_bytes(params, 5);
compute_encoder.set_vector_bytes(batch_shape, 6); compute_encoder.set_vector_bytes(lhs_indices.shape(), 6);
compute_encoder.set_vector_bytes(batch_strides, 7); compute_encoder.set_vector_bytes(lhs_indices.strides(), 7);
compute_encoder.set_vector_bytes(rhs_indices.strides(), 8);
compute_encoder.set_input_array(lhs_indices, 10); compute_encoder.set_bytes(batch_ndim_a, 9);
compute_encoder.set_input_array(rhs_indices, 11); compute_encoder.set_vector_bytes(a.shape(), 10);
compute_encoder.set_vector_bytes(a.strides(), 11);
std::vector operand_shape = batch_shape_A; compute_encoder.set_bytes(batch_ndim_b, 12);
operand_shape.insert( compute_encoder.set_vector_bytes(b.shape(), 13);
operand_shape.end(), batch_shape_B.begin(), batch_shape_B.end()); compute_encoder.set_vector_bytes(b.strides(), 14);
std::vector operand_strides = batch_strides_A;
operand_strides.insert(
operand_strides.end(), batch_strides_B.begin(), batch_strides_B.end());
operand_batch_ndim.push_back(0);
compute_encoder.set_vector_bytes(operand_shape, 13);
compute_encoder.set_vector_bytes(operand_strides, 14);
compute_encoder.set_vector_bytes(operand_batch_ndim, 15);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims); compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
d.add_temporaries(std::move(copies), s.index); void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& s = stream();
auto& d = metal::device(s.device);
auto& a = inputs[0];
auto& b = inputs[1];
auto& lhs_indices = inputs[2];
auto& rhs_indices = inputs[3];
// Return 0s if either input is empty
if (a.size() == 0 || b.size() == 0) {
array zero = array(0, a.dtype());
fill_gpu(zero, out, s);
d.add_temporary(std::move(zero), s.index);
return;
}
out.set_data(allocator::malloc(out.nbytes()));
// Extract shapes from inputs.
int M = a.shape(-2);
int N = b.shape(-1);
int K = a.shape(-1);
// We are walking a in order and b is also in order so we can batch up the
// matmuls and reuse reading a and b.
if (M == 1 && right_sorted_ == true) {
gather_mm_rhs(a, b, rhs_indices, out, d, s);
return;
}
// Route to gather gemv if any of a or b are vectors
if (M == 1) {
gather_mv(b, a, rhs_indices, lhs_indices, out, N, K, false, d, s);
return;
}
if (N == 1) {
gather_mv(a, b, lhs_indices, rhs_indices, out, M, K, true, d, s);
return;
}
// Route to non specialized gather mm
gather_mm(a, b, lhs_indices, rhs_indices, out, M, N, K, d, s);
} }
} // namespace mlx::core } // namespace mlx::core
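The rewritten GatherMM::eval_gpu above is now a thin router over three specialized paths. A minimal sketch of that routing decision, under assumed names (GatherMMKernel and pick_gather_mm_kernel are illustrative placeholders, not MLX symbols):

// Which specialized path the router above selects (illustrative only).
enum class GatherMMKernel { RhsBatched, GatherGemv, GatherGemm };

GatherMMKernel pick_gather_mm_kernel(int M, int N, bool rhs_sorted) {
  if (M == 1 && rhs_sorted) {
    // a and the rhs indices are walked in order, so the matmuls can be
    // batched and reads of a and b reused (the gather_mm_rhs path).
    return GatherMMKernel::RhsBatched;
  }
  if (M == 1 || N == 1) {
    // One operand is a vector: route to the gather gemv path. For M == 1 the
    // operands are swapped and the matrix transposition flipped, since
    // m @ v = v @ m.T when v is treated as a one dimensional array.
    return GatherMMKernel::GatherGemv;
  }
  return GatherMMKernel::GatherGemm;  // general steel gather GEMM
}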

View File

@@ -12,74 +12,6 @@ namespace mlx::core::metal {
/* Check if the Metal backend is available. */ /* Check if the Metal backend is available. */
bool is_available(); bool is_available();
/* Get the actively used memory in bytes.
*
* Note, this will not always match memory use reported by the system because
* it does not include cached memory buffers.
* */
size_t get_active_memory();
/* Get the peak amount of used memory in bytes.
*
* The maximum memory used recorded from the beginning of the program
* execution or since the last call to reset_peak_memory.
* */
size_t get_peak_memory();
/* Reset the peak memory to zero.
* */
void reset_peak_memory();
/* Get the cache size in bytes.
*
* The cache includes memory not currently used that has not been returned
* to the system allocator.
* */
size_t get_cache_memory();
/* Set the memory limit.
* The memory limit is a guideline for the maximum amount of memory to use
* during graph evaluation. If the memory limit is exceeded and there is no
* more RAM (including swap when available) allocations will result in an
* exception.
*
* When metal is available the memory limit defaults to 1.5 times the maximum
* recommended working set size reported by the device.
*
* Returns the previous memory limit.
* */
size_t set_memory_limit(size_t limit);
/* Get the current memory limit. */
size_t get_memory_limit();
/* Set the free cache limit.
* If using more than the given limit, free memory will be reclaimed
* from the cache on the next allocation. To disable the cache,
* set the limit to 0.
*
* The cache limit defaults to the memory limit.
*
* Returns the previous cache limit.
* */
size_t set_cache_limit(size_t limit);
/* Clear the memory cache. */
void clear_cache();
/* Set the wired size limit.
*
* Note, this function is only useful for macOS 15.0 or higher.
*
* The wired limit is the total size in bytes of memory that will be kept
* resident. The default value is ``0``.
*
* Setting a wired limit larger than system wired limit is an error.
*
* Returns the previous wired limit.
* */
size_t set_wired_limit(size_t limit);
/** Capture a GPU trace, saving it to an absolute file `path` */ /** Capture a GPU trace, saving it to an absolute file `path` */
void start_capture(std::string path = ""); void start_capture(std::string path = "");
void stop_capture(); void stop_capture();

View File

@@ -72,6 +72,13 @@ MTL::ComputePipelineState* get_softmax_kernel(
return d.get_kernel(kernel_name); return d.get_kernel(kernel_name);
} }
MTL::ComputePipelineState* get_logsumexp_kernel(
metal::Device& d,
const std::string& kernel_name,
const array&) {
return d.get_kernel(kernel_name);
}
MTL::ComputePipelineState* get_scan_kernel( MTL::ComputePipelineState* get_scan_kernel(
metal::Device& d, metal::Device& d,
const std::string& kernel_name, const std::string& kernel_name,
@@ -186,6 +193,23 @@ MTL::ComputePipelineState* get_steel_gemm_masked_kernel(
return d.get_kernel(kernel_name); return d.get_kernel(kernel_name);
} }
MTL::ComputePipelineState* get_steel_gemm_gather_kernel(
metal::Device& d,
const std::string& kernel_name,
const std::string& hash_name,
const metal::MTLFCList& func_consts,
const array&,
bool,
bool,
int,
int,
int,
int,
int,
bool) {
return d.get_kernel(kernel_name, "mlx", hash_name, func_consts);
}
MTL::ComputePipelineState* get_gemv_masked_kernel( MTL::ComputePipelineState* get_gemv_masked_kernel(
metal::Device& d, metal::Device& d,
const std::string& kernel_name, const std::string& kernel_name,
@@ -245,4 +269,21 @@ MTL::ComputePipelineState* get_quantized_kernel(
return d.get_kernel(kernel_name); return d.get_kernel(kernel_name);
} }
MTL::ComputePipelineState* get_gather_qmm_kernel(
metal::Device& d,
const std::string& kernel_name,
const std::string& hash_name,
const metal::MTLFCList& func_consts,
const array&,
int,
int,
int,
int,
int,
int,
int,
bool) {
return d.get_kernel(kernel_name, "mlx", hash_name, func_consts);
}
} // namespace mlx::core } // namespace mlx::core

View File

@@ -251,8 +251,10 @@ void Concatenate::eval_gpu(const std::vector<array>& inputs, array& out) {
void Contiguous::eval_gpu(const std::vector<array>& inputs, array& out) { void Contiguous::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
auto& in = inputs[0]; auto& in = inputs[0];
if (in.flags().row_contiguous || constexpr size_t extra_bytes = 16384;
(allow_col_major_ && in.flags().col_contiguous)) { if (in.buffer_size() <= out.nbytes() + extra_bytes &&
(in.flags().row_contiguous ||
(allow_col_major_ && in.flags().col_contiguous))) {
out.copy_shared_buffer(in); out.copy_shared_buffer(in);
} else { } else {
copy_gpu(in, out, CopyType::General); copy_gpu(in, out, CopyType::General);
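The Contiguous change above adds a size guard before aliasing the input's buffer: the output may share the buffer only when the input is contiguous in the accepted sense and its backing buffer is not much larger than the output needs. A standalone sketch of that predicate (the helper name and parameter list are assumptions for illustration):

#include <cstddef>

// Mirrors the condition added to Contiguous::eval_gpu above (illustrative).
bool can_alias_input(size_t in_buffer_bytes, size_t out_bytes,
                     bool row_contiguous, bool col_contiguous,
                     bool allow_col_major) {
  constexpr size_t extra_bytes = 16384;  // tolerated slack over the output size
  bool layout_ok = row_contiguous || (allow_col_major && col_contiguous);
  return layout_ok && in_buffer_bytes <= out_bytes + extra_bytes;
}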

File diff suppressed because it is too large

View File

@@ -22,6 +22,7 @@ ResidencySet::ResidencySet(MTL::Device* d) {
} }
throw std::runtime_error(msg.str()); throw std::runtime_error(msg.str());
} }
wired_set_->requestResidency();
} }
} }
@@ -32,7 +33,6 @@ void ResidencySet::insert(MTL::Allocation* buf) {
if (wired_set_->allocatedSize() + buf->allocatedSize() <= capacity_) { if (wired_set_->allocatedSize() + buf->allocatedSize() <= capacity_) {
wired_set_->addAllocation(buf); wired_set_->addAllocation(buf);
wired_set_->commit(); wired_set_->commit();
wired_set_->requestResidency();
} else { } else {
unwired_set_.insert(buf); unwired_set_.insert(buf);
} }
@@ -76,7 +76,6 @@ void ResidencySet::resize(size_t size) {
} }
} }
wired_set_->commit(); wired_set_->commit();
wired_set_->requestResidency();
} else if (current_size > size) { } else if (current_size > size) {
auto pool = new_scoped_memory_pool(); auto pool = new_scoped_memory_pool();
// Remove wired allocations until under capacity // Remove wired allocations until under capacity

View File

@@ -138,6 +138,7 @@ void sdpa_vector(
const array& v, const array& v,
array& out, array& out,
float scale, float scale,
bool do_causal,
const std::optional<array>& mask) { const std::optional<array>& mask) {
// Set the kernel name // Set the kernel name
std::string kname; std::string kname;
@@ -162,14 +163,20 @@ void sdpa_vector(
MTL::Size grid_dims(B, q.shape(2), 1); MTL::Size grid_dims(B, q.shape(2), 1);
bool has_mask = mask.has_value(); bool has_mask = mask.has_value();
bool bool_mask = has_mask && (*mask).dtype() == bool_;
bool float_mask = has_mask && !bool_mask;
bool query_transposed = !q.flags().row_contiguous; bool query_transposed = !q.flags().row_contiguous;
metal::MTLFCList func_consts = { metal::MTLFCList func_consts = {
{&has_mask, MTL::DataType::DataTypeBool, 20}, {&has_mask, MTL::DataType::DataTypeBool, 20},
{&query_transposed, MTL::DataType::DataTypeBool, 21}, {&query_transposed, MTL::DataType::DataTypeBool, 21},
{&do_causal, MTL::DataType::DataTypeBool, 22},
{&bool_mask, MTL::DataType::DataTypeBool, 23},
{&float_mask, MTL::DataType::DataTypeBool, 24},
}; };
std::string hash_name = kname; std::string hash_name = kname;
hash_name += has_mask ? "_mask" : "_nomask"; hash_name += has_mask ? (bool_mask ? "_boolmask" : "_floatmask") : "_nomask";
hash_name += query_transposed ? "_qt" : "_qnt"; hash_name += query_transposed ? "_qt" : "_qnt";
hash_name += do_causal ? "_c" : "_nc";
// Get the kernel // Get the kernel
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
@@ -191,15 +198,15 @@ void sdpa_vector(
compute_encoder.set_bytes(scale, 10); compute_encoder.set_bytes(scale, 10);
if (has_mask) { if (has_mask) {
auto& m = *mask; auto& m = *mask;
compute_encoder.set_input_array(m, 11); compute_encoder.set_input_array(m, 11 + float_mask);
auto nd = m.ndim(); auto nd = m.ndim();
int32_t kv_seq_stride = int32_t kv_seq_stride =
nd >= 1 && m.shape(-1) > 1 ? m.strides()[nd - 1] : 0; nd >= 1 && m.shape(-1) > 1 ? m.strides()[nd - 1] : 0;
int32_t q_seq_stride = nd >= 2 && m.shape(-2) > 1 ? m.strides()[nd - 2] : 0; int32_t q_seq_stride = nd >= 2 && m.shape(-2) > 1 ? m.strides()[nd - 2] : 0;
int32_t head_stride = nd >= 3 && m.shape(-3) > 1 ? m.strides()[nd - 3] : 0; int32_t head_stride = nd >= 3 && m.shape(-3) > 1 ? m.strides()[nd - 3] : 0;
compute_encoder.set_bytes(kv_seq_stride, 12); compute_encoder.set_bytes(kv_seq_stride, 13);
compute_encoder.set_bytes(q_seq_stride, 13); compute_encoder.set_bytes(q_seq_stride, 14);
compute_encoder.set_bytes(head_stride, 14); compute_encoder.set_bytes(head_stride, 15);
} }
// Launch // Launch
@@ -214,6 +221,7 @@ void sdpa_vector_2pass(
const array& v, const array& v,
array& out, array& out,
float scale, float scale,
bool do_causal,
const std::optional<array>& mask) { const std::optional<array>& mask) {
// Set the kernel name // Set the kernel name
std::string kname; std::string kname;
@@ -256,14 +264,20 @@ void sdpa_vector_2pass(
d.add_temporary(maxs, s.index); d.add_temporary(maxs, s.index);
bool has_mask = mask.has_value(); bool has_mask = mask.has_value();
bool bool_mask = has_mask && (*mask).dtype() == bool_;
bool float_mask = has_mask && !bool_mask;
bool query_transposed = !q.flags().row_contiguous; bool query_transposed = !q.flags().row_contiguous;
metal::MTLFCList func_consts = { metal::MTLFCList func_consts = {
{&has_mask, MTL::DataType::DataTypeBool, 20}, {&has_mask, MTL::DataType::DataTypeBool, 20},
{&query_transposed, MTL::DataType::DataTypeBool, 21}, {&query_transposed, MTL::DataType::DataTypeBool, 21},
{&do_causal, MTL::DataType::DataTypeBool, 22},
{&bool_mask, MTL::DataType::DataTypeBool, 23},
{&float_mask, MTL::DataType::DataTypeBool, 24},
}; };
std::string hash_name = kname; std::string hash_name = kname;
hash_name += has_mask ? "_mask" : "_nomask"; hash_name += has_mask ? (bool_mask ? "_boolmask" : "_floatmask") : "_nomask";
hash_name += query_transposed ? "_qt" : "_qnt"; hash_name += query_transposed ? "_qt" : "_qnt";
hash_name += do_causal ? "_c" : "_nc";
// Get the kernel // Get the kernel
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
@@ -287,15 +301,15 @@ void sdpa_vector_2pass(
compute_encoder.set_bytes(scale, 12); compute_encoder.set_bytes(scale, 12);
if (has_mask) { if (has_mask) {
auto& m = *mask; auto& m = *mask;
compute_encoder.set_input_array(m, 13); compute_encoder.set_input_array(m, 13 + float_mask);
auto nd = m.ndim(); auto nd = m.ndim();
int32_t kv_seq_stride = int32_t kv_seq_stride =
nd >= 1 && m.shape(-1) > 1 ? m.strides()[nd - 1] : 0; nd >= 1 && m.shape(-1) > 1 ? m.strides()[nd - 1] : 0;
int32_t q_seq_stride = nd >= 2 && m.shape(-2) > 1 ? m.strides()[nd - 2] : 0; int32_t q_seq_stride = nd >= 2 && m.shape(-2) > 1 ? m.strides()[nd - 2] : 0;
int32_t head_stride = nd >= 3 && m.shape(-3) > 1 ? m.strides()[nd - 3] : 0; int32_t head_stride = nd >= 3 && m.shape(-3) > 1 ? m.strides()[nd - 3] : 0;
compute_encoder.set_bytes(kv_seq_stride, 14); compute_encoder.set_bytes(kv_seq_stride, 15);
compute_encoder.set_bytes(q_seq_stride, 15); compute_encoder.set_bytes(q_seq_stride, 16);
compute_encoder.set_bytes(head_stride, 16); compute_encoder.set_bytes(head_stride, 17);
} }
// Launch // Launch
@@ -401,12 +415,13 @@ void ScaledDotProductAttention::eval_gpu(
// We route to the 2 pass fused attention if // We route to the 2 pass fused attention if
// - The device is large and the sequence length long // - The device is large and the sequence length long
// - The sequence length is even longer and we have gqa // - The sequence length is even longer and we have gqa
bool do_causal = do_causal_ && q.shape(2) > 1;
char devc = d.get_architecture().back(); char devc = d.get_architecture().back();
if ((devc == 'd' && k.shape(2) >= 1024) || if ((devc == 'd' && k.shape(2) >= 1024) ||
(k.shape(1) < q.shape(1) && k.shape(2) >= 4096)) { (k.shape(1) < q.shape(1) && k.shape(2) >= 4096)) {
sdpa_vector_2pass(s, d, q, k, v, o, scale_, mask); sdpa_vector_2pass(s, d, q, k, v, o, scale_, do_causal, mask);
} else { } else {
sdpa_vector(s, d, q, k, v, o, scale_, mask); sdpa_vector(s, d, q, k, v, o, scale_, do_causal, mask);
} }
} }
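The sdpa_vector changes above split the single mask flag into boolean and float variants, add a do_causal function constant, and shift the mask buffer index by one when the mask is a float mask. A small sketch of how those specialization flags and the kernel hash suffix fit together (the struct and helper names are assumptions; the base mask buffer index is 11 for the single-pass kernel and 13 for the two-pass kernel):

#include <string>

struct SdpaVariant {
  bool has_mask;
  bool bool_mask;
  bool float_mask;
  int mask_buffer;          // where the mask array is bound
  std::string hash_suffix;  // appended to the kernel name for pipeline caching
};

SdpaVariant classify_sdpa(bool mask_provided, bool mask_is_bool,
                          bool query_transposed, bool do_causal,
                          int base_mask_buffer) {
  SdpaVariant v;
  v.has_mask = mask_provided;
  v.bool_mask = mask_provided && mask_is_bool;
  v.float_mask = mask_provided && !mask_is_bool;
  // A float mask is bound one slot later than a bool mask.
  v.mask_buffer = base_mask_buffer + (v.float_mask ? 1 : 0);
  v.hash_suffix = v.has_mask ? (v.bool_mask ? "_boolmask" : "_floatmask")
                             : "_nomask";
  v.hash_suffix += query_transposed ? "_qt" : "_qnt";
  v.hash_suffix += do_causal ? "_c" : "_nc";
  return v;
}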

View File

@@ -60,6 +60,9 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
case Scan::Min: case Scan::Min:
reduce_type = "min"; reduce_type = "min";
break; break;
case Scan::LogAddExp:
reduce_type = "logaddexp";
break;
} }
kname << reduce_type << "_" << type_to_name(in) << "_" << type_to_name(out); kname << reduce_type << "_" << type_to_name(in) << "_" << type_to_name(out);
auto kernel = get_scan_kernel( auto kernel = get_scan_kernel(

View File

@@ -23,12 +23,7 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
// Make sure that the last dimension is contiguous // Make sure that the last dimension is contiguous
auto set_output = [&s, &out](const array& x) { auto set_output = [&s, &out](const array& x) {
bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1; if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
if (no_copy && x.ndim() > 1) {
auto s = x.strides()[x.ndim() - 2];
no_copy &= (s == 0 || s == x.shape().back());
}
if (no_copy) {
if (x.is_donatable()) { if (x.is_donatable()) {
out.copy_shared_buffer(x); out.copy_shared_buffer(x);
} else { } else {

View File

@@ -2,6 +2,8 @@
#pragma once #pragma once
#include <type_traits>
#include "mlx/array.h" #include "mlx/array.h"
#include "mlx/backend/metal/device.h" #include "mlx/backend/metal/device.h"
#include "mlx/primitives.h" #include "mlx/primitives.h"
@@ -58,14 +60,27 @@ inline void debug_set_primitive_buffer_label(
std::string get_primitive_string(Primitive* primitive); std::string get_primitive_string(Primitive* primitive);
template <typename T>
constexpr bool is_numeric_except_char = std::is_arithmetic_v<T> &&
!std::is_same_v<T, char> && !std::is_same_v<T, signed char> &&
!std::is_same_v<T, unsigned char> && !std::is_same_v<T, wchar_t>;
template <typename T> template <typename T>
void concatenate(std::string& acc, T first) { void concatenate(std::string& acc, T first) {
acc += first; if constexpr (is_numeric_except_char<T>) {
acc += std::to_string(first);
} else {
acc += first;
}
} }
template <typename T, typename... Args> template <typename T, typename... Args>
void concatenate(std::string& acc, T first, Args... args) { void concatenate(std::string& acc, T first, Args... args) {
acc += first; if constexpr (is_numeric_except_char<T>) {
acc += std::to_string(first);
} else {
acc += first;
}
concatenate(acc, args...); concatenate(acc, args...);
} }
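The utils change above teaches the variadic concatenate helper to stringify arithmetic arguments (other than character types), so kernel names can be assembled from string literals, single characters, and tile sizes in one call. A self-contained copy of the pattern with an assumed example call (the real call sites are in the matmul code earlier in this diff):

#include <string>
#include <type_traits>

template <typename T>
constexpr bool is_numeric_except_char = std::is_arithmetic_v<T> &&
    !std::is_same_v<T, char> && !std::is_same_v<T, signed char> &&
    !std::is_same_v<T, unsigned char> && !std::is_same_v<T, wchar_t>;

template <typename T>
void concatenate(std::string& acc, T first) {
  if constexpr (is_numeric_except_char<T>) {
    acc += std::to_string(first);  // ints such as tile sizes
  } else {
    acc += first;                  // string literals, std::string, single chars
  }
}

template <typename T, typename... Args>
void concatenate(std::string& acc, T first, Args... args) {
  concatenate(acc, first);
  concatenate(acc, args...);
}

// Example: concatenate(name, "steel_gather_mm_rhs_n", 't', "_bm", 16, "_bn", 64)
// leaves name == "steel_gather_mm_rhs_nt_bm16_bn64".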

View File

@@ -82,6 +82,7 @@ NO_CPU(LogicalNot)
NO_CPU(LogicalAnd) NO_CPU(LogicalAnd)
NO_CPU(LogicalOr) NO_CPU(LogicalOr)
NO_CPU(LogAddExp) NO_CPU(LogAddExp)
NO_CPU(LogSumExp)
NO_CPU_MULTI(LUF) NO_CPU_MULTI(LUF)
NO_CPU(Matmul) NO_CPU(Matmul)
NO_CPU(Maximum) NO_CPU(Maximum)

View File

@@ -1,14 +1,72 @@
// Copyright © 2023 Apple Inc. // Copyright © 2023 Apple Inc.
#include <algorithm>
#include <mutex>
#include "mlx/allocator.h" #include "mlx/allocator.h"
namespace mlx::core::allocator { #ifdef __APPLE__
#include "mlx/backend/no_metal/apple_memory.h"
#elif defined(__linux__)
#include "mlx/backend/no_metal/linux_memory.h"
#else
size_t get_memory_size() {
return 0;
}
#endif
Allocator& allocator() { namespace mlx::core {
namespace allocator {
class CommonAllocator : public Allocator {
/** A general CPU allocator. */
public:
virtual Buffer malloc(size_t size) override;
virtual void free(Buffer buffer) override;
virtual size_t size(Buffer buffer) const override;
size_t get_active_memory() const {
return active_memory_;
};
size_t get_peak_memory() const {
return peak_memory_;
};
void reset_peak_memory() {
std::unique_lock lk(mutex_);
peak_memory_ = 0;
};
size_t get_memory_limit() {
return memory_limit_;
}
size_t set_memory_limit(size_t limit) {
std::unique_lock lk(mutex_);
std::swap(memory_limit_, limit);
return limit;
}
private:
size_t memory_limit_;
size_t active_memory_{0};
size_t peak_memory_{0};
std::mutex mutex_;
CommonAllocator() : memory_limit_(0.8 * get_memory_size()) {
if (memory_limit_ == 0) {
memory_limit_ = 1UL << 33;
}
};
friend CommonAllocator& common_allocator();
};
CommonAllocator& common_allocator() {
static CommonAllocator allocator_; static CommonAllocator allocator_;
return allocator_; return allocator_;
} }
Allocator& allocator() {
return common_allocator();
}
void* Buffer::raw_ptr() { void* Buffer::raw_ptr() {
if (!ptr_) { if (!ptr_) {
return nullptr; return nullptr;
@@ -16,4 +74,59 @@ void* Buffer::raw_ptr() {
return static_cast<size_t*>(ptr_) + 1; return static_cast<size_t*>(ptr_) + 1;
} }
} // namespace mlx::core::allocator Buffer CommonAllocator::malloc(size_t size) {
void* ptr = std::malloc(size + sizeof(size_t));
if (ptr != nullptr) {
*static_cast<size_t*>(ptr) = size;
}
std::unique_lock lk(mutex_);
active_memory_ += size;
peak_memory_ = std::max(active_memory_, peak_memory_);
return Buffer{ptr};
}
void CommonAllocator::free(Buffer buffer) {
auto sz = size(buffer);
std::free(buffer.ptr());
std::unique_lock lk(mutex_);
active_memory_ -= sz;
}
size_t CommonAllocator::size(Buffer buffer) const {
if (buffer.ptr() == nullptr) {
return 0;
}
return *static_cast<size_t*>(buffer.ptr());
}
} // namespace allocator
size_t get_active_memory() {
return allocator::common_allocator().get_active_memory();
}
size_t get_peak_memory() {
return allocator::common_allocator().get_peak_memory();
}
void reset_peak_memory() {
return allocator::common_allocator().reset_peak_memory();
}
size_t set_memory_limit(size_t limit) {
return allocator::common_allocator().set_memory_limit(limit);
}
size_t get_memory_limit() {
return allocator::common_allocator().get_memory_limit();
}
// No-ops for common allocator
size_t get_cache_memory() {
return 0;
}
size_t set_cache_limit(size_t) {
return 0;
}
size_t set_wired_limit(size_t) {
return 0;
}
void clear_cache() {}
} // namespace mlx::core
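The new CommonAllocator above prepends a size_t header to every allocation so size() and free() can recover the byte count, and Buffer::raw_ptr skips that header when handing out the user pointer. The same trick in isolation (function names here are placeholders, not the MLX API; the real allocator additionally updates active and peak memory under a mutex):

#include <cstddef>
#include <cstdlib>

// Allocate `size` usable bytes, stashing the size just before the user data.
void* sized_malloc(size_t size) {
  void* ptr = std::malloc(size + sizeof(size_t));
  if (ptr == nullptr) {
    return nullptr;
  }
  *static_cast<size_t*>(ptr) = size;     // header word
  return static_cast<size_t*>(ptr) + 1;  // user data starts after the header
}

size_t sized_alloc_size(const void* user_ptr) {
  return user_ptr ? *(static_cast<const size_t*>(user_ptr) - 1) : 0;
}

void sized_free(void* user_ptr) {
  if (user_ptr != nullptr) {
    std::free(static_cast<size_t*>(user_ptr) - 1);
  }
}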

View File

@@ -0,0 +1,16 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include <sys/sysctl.h>
namespace {
size_t get_memory_size() {
size_t memsize = 0;
size_t length = sizeof(memsize);
sysctlbyname("hw.memsize", &memsize, &length, NULL, 0);
return memsize;
}
} // namespace

View File

@@ -28,21 +28,19 @@ void Event::wait() {
ec->cv.wait(lk, [value = value(), ec] { return ec->value >= value; }); ec->cv.wait(lk, [value = value(), ec] { return ec->value >= value; });
} }
void Event::signal() {
auto ec = static_cast<EventCounter*>(event_.get());
{
std::lock_guard<std::mutex> lk(ec->mtx);
ec->value = value();
}
ec->cv.notify_all();
}
void Event::wait(Stream stream) { void Event::wait(Stream stream) {
scheduler::enqueue(stream, [*this]() mutable { wait(); }); scheduler::enqueue(stream, [*this]() mutable { wait(); });
} }
void Event::signal(Stream stream) { void Event::signal(Stream stream) {
scheduler::enqueue(stream, [*this]() mutable { signal(); }); scheduler::enqueue(stream, [*this]() mutable {
auto ec = static_cast<EventCounter*>(event_.get());
{
std::lock_guard<std::mutex> lk(ec->mtx);
ec->value = value();
}
ec->cv.notify_all();
});
} }
bool Event::is_signaled() const { bool Event::is_signaled() const {
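The CPU Event above is a counter guarded by a mutex and condition variable: signaling stores the target value and notifies, waiting blocks until the counter reaches the awaited value, and the change moves the signal body into the scheduler-enqueued lambda so it runs on the stream. A self-contained sketch of the same counting-event shape (the class name and explicit value arguments are placeholders):

#include <condition_variable>
#include <cstdint>
#include <mutex>

class CountingEvent {
 public:
  void wait(uint64_t value) {
    std::unique_lock<std::mutex> lk(mtx_);
    cv_.wait(lk, [&] { return value_ >= value; });
  }

  void signal(uint64_t value) {
    {
      std::lock_guard<std::mutex> lk(mtx_);
      value_ = value;
    }
    cv_.notify_all();
  }

 private:
  std::mutex mtx_;
  std::condition_variable cv_;
  uint64_t value_{0};
};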

View File

@@ -0,0 +1,22 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include <sys/sysinfo.h>
namespace {
size_t get_memory_size() {
struct sysinfo info;
if (sysinfo(&info) != 0) {
return 0;
}
size_t total_ram = info.totalram;
total_ram *= info.mem_unit;
return total_ram;
}
} // namespace
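The two headers above give the no-Metal build a total-RAM estimate per platform (sysctlbyname("hw.memsize") on macOS, sysinfo on Linux, zero elsewhere), which the CommonAllocator turns into its default memory limit: 80% of detected RAM, or 8 GiB when detection fails. A sketch of that fallback (the helper name is a placeholder):

#include <cstddef>

size_t default_memory_limit(size_t total_ram_bytes) {
  // 0.8 * detected RAM, or 1UL << 33 (8 GiB) when the size came back as zero.
  size_t limit = static_cast<size_t>(0.8 * total_ram_bytes);
  return limit == 0 ? (1UL << 33) : limit;
}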

View File

@@ -31,33 +31,8 @@ void synchronize(Stream) {
"[metal::synchronize] Cannot synchronize GPU without metal backend"); "[metal::synchronize] Cannot synchronize GPU without metal backend");
} }
// No-ops when Metal is not available.
size_t get_active_memory() {
return 0;
}
size_t get_peak_memory() {
return 0;
}
void reset_peak_memory() {}
size_t get_cache_memory() {
return 0;
}
size_t set_memory_limit(size_t) {
return 0;
}
size_t get_memory_limit() {
return 0;
}
size_t set_cache_limit(size_t) {
return 0;
}
size_t set_wired_limit(size_t) {
return 0;
}
void start_capture(std::string) {} void start_capture(std::string) {}
void stop_capture() {} void stop_capture() {}
void clear_cache() {}
const std::unordered_map<std::string, std::variant<std::string, size_t>>& const std::unordered_map<std::string, std::variant<std::string, size_t>>&
device_info() { device_info() {

View File

@@ -82,6 +82,7 @@ NO_GPU(LogicalNot)
NO_GPU(LogicalAnd) NO_GPU(LogicalAnd)
NO_GPU(LogicalOr) NO_GPU(LogicalOr)
NO_GPU(LogAddExp) NO_GPU(LogAddExp)
NO_GPU(LogSumExp)
NO_GPU_MULTI(LUF) NO_GPU_MULTI(LUF)
NO_GPU(Matmul) NO_GPU(Matmul)
NO_GPU(Maximum) NO_GPU(Maximum)

View File

@@ -15,6 +15,14 @@ void all_sum(Group group, const array& input, array& output, Stream stream) {
group.raw_group()->all_sum(input, output, stream); group.raw_group()->all_sum(input, output, stream);
} }
void all_max(Group group, const array& input, array& output, Stream stream) {
group.raw_group()->all_max(input, output, stream);
}
void all_min(Group group, const array& input, array& output, Stream stream) {
group.raw_group()->all_min(input, output, stream);
}
void all_gather(Group group, const array& input, array& output, Stream stream) { void all_gather(Group group, const array& input, array& output, Stream stream) {
group.raw_group()->all_gather(input, output, stream); group.raw_group()->all_gather(input, output, stream);
} }
@@ -57,6 +65,16 @@ class EmptyGroup : public GroupImpl {
throw std::runtime_error( throw std::runtime_error(
"Communication not implemented in an empty distributed group."); "Communication not implemented in an empty distributed group.");
} }
void all_max(const array&, array&, Stream) override {
throw std::runtime_error(
"Communication not implemented in an empty distributed group.");
}
void all_min(const array&, array&, Stream) override {
throw std::runtime_error(
"Communication not implemented in an empty distributed group.");
}
}; };
} // namespace detail } // namespace detail

View File

@@ -21,6 +21,8 @@ class GroupImpl {
virtual void all_gather(const array& input, array& output, Stream stream) = 0; virtual void all_gather(const array& input, array& output, Stream stream) = 0;
virtual void send(const array& input, int dst, Stream stream) = 0; virtual void send(const array& input, int dst, Stream stream) = 0;
virtual void recv(array& out, int src, Stream stream) = 0; virtual void recv(array& out, int src, Stream stream) = 0;
virtual void all_max(const array& input, array& output, Stream stream) = 0;
virtual void all_min(const array& input, array& output, Stream stream) = 0;
}; };
/* Perform an all reduce sum operation */ /* Perform an all reduce sum operation */
@@ -35,4 +37,10 @@ void send(Group group, const array& input, int dst, Stream stream);
/** Recv an array from the src rank */ /** Recv an array from the src rank */
void recv(Group group, array& out, int src, Stream stream); void recv(Group group, array& out, int src, Stream stream);
/** Max reduction */
void all_max(Group group, const array& input, array& output, Stream stream);
/** Min reduction */
void all_min(Group group, const array& input, array& output, Stream stream);
} // namespace mlx::core::distributed::detail } // namespace mlx::core::distributed::detail

View File

@@ -1,4 +1,4 @@
if(MPI_FOUND AND MLX_BUILD_CPU) if(MLX_BUILD_CPU)
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/mpi.cpp) target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/mpi.cpp)
else() else()
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_mpi.cpp) target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_mpi.cpp)

View File

@@ -1,12 +1,13 @@
// Copyright © 2024 Apple Inc. // Copyright © 2024 Apple Inc.
#include <dlfcn.h> #include <dlfcn.h>
#include <mpi.h> #include <iostream>
#include "mlx/backend/cpu/encoder.h" #include "mlx/backend/cpu/encoder.h"
#include "mlx/distributed/distributed.h" #include "mlx/distributed/distributed.h"
#include "mlx/distributed/distributed_impl.h" #include "mlx/distributed/distributed_impl.h"
#include "mlx/distributed/mpi/mpi.h" #include "mlx/distributed/mpi/mpi.h"
#include "mlx/distributed/mpi/mpi_declarations.h"
#define LOAD_SYMBOL(symbol, variable) \ #define LOAD_SYMBOL(symbol, variable) \
{ \ { \
@@ -18,6 +19,12 @@
} \ } \
} }
#ifdef __APPLE__
static constexpr const char* libmpi_name = "libmpi.dylib";
#else
static constexpr const char* libmpi_name = "libmpi.so";
#endif
namespace mlx::core::distributed::mpi { namespace mlx::core::distributed::mpi {
using GroupImpl = mlx::core::distributed::detail::GroupImpl; using GroupImpl = mlx::core::distributed::detail::GroupImpl;
@@ -43,15 +50,69 @@ void simple_sum(
template void simple_sum<float16_t>(void*, void*, int*, MPI_Datatype*); template void simple_sum<float16_t>(void*, void*, int*, MPI_Datatype*);
template void simple_sum<bfloat16_t>(void*, void*, int*, MPI_Datatype*); template void simple_sum<bfloat16_t>(void*, void*, int*, MPI_Datatype*);
template <typename T>
void simple_max(
void* input,
void* accumulator,
int* len,
MPI_Datatype* datatype) {
T* in = (T*)input;
T* acc = (T*)accumulator;
int N = *len;
while (N-- > 0) {
*acc = std::max(*acc, *in);
acc++;
in++;
}
}
template void simple_max<float16_t>(void*, void*, int*, MPI_Datatype*);
template void simple_max<bfloat16_t>(void*, void*, int*, MPI_Datatype*);
template void simple_max<complex64_t>(void*, void*, int*, MPI_Datatype*);
template <typename T>
void simple_min(
void* input,
void* accumulator,
int* len,
MPI_Datatype* datatype) {
T* in = (T*)input;
T* acc = (T*)accumulator;
int N = *len;
while (N-- > 0) {
*acc = std::min(*acc, *in);
acc++;
in++;
}
}
template void simple_min<float16_t>(void*, void*, int*, MPI_Datatype*);
template void simple_min<bfloat16_t>(void*, void*, int*, MPI_Datatype*);
template void simple_min<complex64_t>(void*, void*, int*, MPI_Datatype*);
struct MPIWrapper { struct MPIWrapper {
MPIWrapper() { MPIWrapper() {
initialized_ = false; initialized_ = false;
libmpi_handle_ = dlopen("libmpi.dylib", RTLD_NOW | RTLD_GLOBAL); libmpi_handle_ = dlopen(libmpi_name, RTLD_NOW | RTLD_GLOBAL);
if (libmpi_handle_ == nullptr) { if (libmpi_handle_ == nullptr) {
return; return;
} }
// Check library version and warn if it isn't Open MPI
int (*get_version)(char*, int*);
LOAD_SYMBOL(MPI_Get_library_version, get_version);
char version_ptr[MPI_MAX_LIBRARY_VERSION_STRING];
int version_length = 0;
get_version(version_ptr, &version_length);
std::string_view version(version_ptr, version_length);
if (version.find("Open MPI") == std::string::npos) {
std::cerr << "[mpi] MPI found but it does not appear to be Open MPI."
<< "MLX requires Open MPI but this is " << version << std::endl;
libmpi_handle_ = nullptr;
return;
}
// API // API
LOAD_SYMBOL(MPI_Init, init); LOAD_SYMBOL(MPI_Init, init);
LOAD_SYMBOL(MPI_Finalize, finalize); LOAD_SYMBOL(MPI_Finalize, finalize);
@@ -72,6 +133,8 @@ struct MPIWrapper {
// Ops // Ops
LOAD_SYMBOL(ompi_mpi_op_sum, op_sum_); LOAD_SYMBOL(ompi_mpi_op_sum, op_sum_);
LOAD_SYMBOL(ompi_mpi_op_max, op_max_);
LOAD_SYMBOL(ompi_mpi_op_min, op_min_);
// Datatypes // Datatypes
LOAD_SYMBOL(ompi_mpi_c_bool, mpi_bool_); LOAD_SYMBOL(ompi_mpi_c_bool, mpi_bool_);
@@ -106,9 +169,15 @@ struct MPIWrapper {
mpi_type_contiguous(2, mpi_uint8_, &mpi_bfloat16_); mpi_type_contiguous(2, mpi_uint8_, &mpi_bfloat16_);
mpi_type_commit(&mpi_bfloat16_); mpi_type_commit(&mpi_bfloat16_);
// Custom sum ops // Custom reduction ops
mpi_op_create(&simple_sum<float16_t>, 1, &op_sum_f16_); mpi_op_create(&simple_sum<float16_t>, 1, &op_sum_f16_);
mpi_op_create(&simple_sum<bfloat16_t>, 1, &op_sum_bf16_); mpi_op_create(&simple_sum<bfloat16_t>, 1, &op_sum_bf16_);
mpi_op_create(&simple_max<float16_t>, 1, &op_max_f16_);
mpi_op_create(&simple_max<bfloat16_t>, 1, &op_max_bf16_);
mpi_op_create(&simple_max<complex64_t>, 1, &op_max_c64_);
mpi_op_create(&simple_min<float16_t>, 1, &op_min_f16_);
mpi_op_create(&simple_min<bfloat16_t>, 1, &op_min_bf16_);
mpi_op_create(&simple_min<complex64_t>, 1, &op_min_c64_);
initialized_ = true; initialized_ = true;
} }
@@ -170,6 +239,32 @@ struct MPIWrapper {
} }
} }
MPI_Op op_max(const array& arr) {
switch (arr.dtype()) {
case float16:
return op_max_f16_;
case bfloat16:
return op_max_bf16_;
case complex64:
return op_max_c64_;
default:
return op_max_;
}
}
MPI_Op op_min(const array& arr) {
switch (arr.dtype()) {
case float16:
return op_min_f16_;
case bfloat16:
return op_min_bf16_;
case complex64:
return op_min_c64_;
default:
return op_min_;
}
}
void* libmpi_handle_; void* libmpi_handle_;
// API // API
@@ -198,6 +293,14 @@ struct MPIWrapper {
MPI_Op op_sum_; MPI_Op op_sum_;
MPI_Op op_sum_f16_; MPI_Op op_sum_f16_;
MPI_Op op_sum_bf16_; MPI_Op op_sum_bf16_;
MPI_Op op_max_;
MPI_Op op_max_f16_;
MPI_Op op_max_bf16_;
MPI_Op op_max_c64_;
MPI_Op op_min_;
MPI_Op op_min_f16_;
MPI_Op op_min_bf16_;
MPI_Op op_min_c64_;
// Datatypes // Datatypes
MPI_Datatype mpi_bool_; MPI_Datatype mpi_bool_;
@@ -285,6 +388,36 @@ class MPIGroup : public GroupImpl {
comm_); comm_);
} }
void all_max(const array& input, array& output, Stream stream) override {
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(input);
encoder.set_output_array(output);
encoder.dispatch(
mpi().all_reduce,
(input.data<void>() == output.data<void>()) ? MPI_IN_PLACE
: input.data<void>(),
output.data<void>(),
input.size(),
mpi().datatype(input),
mpi().op_max(input),
comm_);
}
void all_min(const array& input, array& output, Stream stream) override {
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(input);
encoder.set_output_array(output);
encoder.dispatch(
mpi().all_reduce,
(input.data<void>() == output.data<void>()) ? MPI_IN_PLACE
: input.data<void>(),
output.data<void>(),
input.size(),
mpi().datatype(input),
mpi().op_min(input),
comm_);
}
void all_gather(const array& input, array& output, Stream stream) override { void all_gather(const array& input, array& output, Stream stream) override {
auto& encoder = cpu::get_command_encoder(stream); auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(input); encoder.set_input_array(input);
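Open MPI has no built-in max/min reductions for float16, bfloat16, or complex64, so the wrapper above registers elementwise user functions with MPI_Op_create alongside the existing custom sum ops. The registration pattern in isolation, using float only so the sketch stays self-contained (MPI already ships MPI_MAX for float; the real code applies this same shape to the half and complex types):

#include <algorithm>
#include <mpi.h>

// Elementwise max with the MPI_User_function signature.
void elementwise_max(void* in_v, void* inout_v, int* len, MPI_Datatype*) {
  const float* in = static_cast<const float*>(in_v);
  float* acc = static_cast<float*>(inout_v);
  for (int i = 0; i < *len; ++i) {
    acc[i] = std::max(acc[i], in[i]);
  }
}

// Register once (the op is commutative), then use it like a built-in op:
//   MPI_Op max_op;
//   MPI_Op_create(&elementwise_max, /* commute = */ 1, &max_op);
//   MPI_Allreduce(MPI_IN_PLACE, data, count, MPI_FLOAT, max_op, comm);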

View File

@@ -0,0 +1,28 @@
// Copyright © 2024 Apple Inc.
// Constants
#define MPI_SUCCESS 0
#define MPI_ANY_SOURCE -1
#define MPI_ANY_TAG -1
#define MPI_IN_PLACE ((void*)1)
#define MPI_MAX_LIBRARY_VERSION_STRING 256
// Define all the types that we use so that we don't include <mpi.h> which
// causes linker errors on some platforms.
//
// NOTE: We define everything for openmpi.
typedef void* MPI_Comm;
typedef void* MPI_Datatype;
typedef void* MPI_Op;
typedef void(MPI_User_function)(void*, void*, int*, MPI_Datatype*);
typedef struct ompi_status_public_t {
int MPI_SOURCE;
int MPI_TAG;
int MPI_ERROR;
int _cancelled;
size_t _ucount;
} MPI_Status;

View File

@@ -36,6 +36,40 @@ array all_sum(
{x}); {x});
} }
array all_max(
const array& x,
std::optional<Group> group_ /* = std::nullopt */,
StreamOrDevice s /* = {} */) {
auto group = to_group(group_);
if (group.size() == 1) {
return x;
}
return array(
x.shape(),
x.dtype(),
std::make_shared<AllReduce>(
to_stream(s, Device::cpu), group, AllReduce::Max),
{x});
}
array all_min(
const array& x,
std::optional<Group> group_ /* = std::nullopt */,
StreamOrDevice s /* = {} */) {
auto group = to_group(group_);
if (group.size() == 1) {
return x;
}
return array(
x.shape(),
x.dtype(),
std::make_shared<AllReduce>(
to_stream(s, Device::cpu), group, AllReduce::Min),
{x});
}
array all_gather( array all_gather(
const array& x, const array& x,
std::optional<Group> group_ /* = std::nullopt */, std::optional<Group> group_ /* = std::nullopt */,

View File

@@ -38,4 +38,14 @@ array recv_like(
std::optional<Group> group = std::nullopt, std::optional<Group> group = std::nullopt,
StreamOrDevice s = {}); StreamOrDevice s = {});
array all_max(
const array& x,
std::optional<Group> group = std::nullopt,
StreamOrDevice s = {});
array all_min(
const array& x,
std::optional<Group> group = std::nullopt,
StreamOrDevice s = {});
} // namespace mlx::core::distributed } // namespace mlx::core::distributed
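With the declarations above, the new collectives are used like all_sum: each returns a lazily evaluated array reduced across the group and falls back to the input when the group has a single rank. A hypothetical call site (the include path and the surrounding distributed setup are assumptions):

#include "mlx/distributed/ops.h"

using namespace mlx::core;

// Reduce a per-rank metric across the default distributed group.
array global_max_loss(const array& local_loss) {
  return distributed::all_max(local_loss);  // identity when the group size is 1
}

array global_min_loss(const array& local_loss) {
  return distributed::all_min(local_loss);
}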

Some files were not shown because too many files have changed in this diff.