Compare commits

...

64 Commits

Author SHA1 Message Date
Jagrit Digani
4c46e17a5d Update benchmark output 2025-04-15 10:50:06 -07:00
Angelos Katharopoulos
99eefd2ec0 Gather mm new kernel and small refactoring (#2040) 2025-04-14 16:37:36 -07:00
Yury Popov
e9e268336b LogCumSumExp (#2069) 2025-04-13 01:27:29 -07:00
Awni Hannun
7275ac7523 Fix release build (#2072) 2025-04-12 20:41:58 -07:00
Angelos Katharopoulos
c4189a38e4 Add float mask to sdpa vector (#2068) 2025-04-11 17:29:40 -07:00
Awni Hannun
68d1b3256b nit: fix exception handling (#2066) 2025-04-11 14:12:08 -07:00
Awni Hannun
9c6953bda7 Fix stubgen (#2065)
* Fix stubgen

* add multi optim to docs
2025-04-11 12:02:54 -07:00
Awni Hannun
ef7ece9851 fix fft bug (#2062) 2025-04-10 19:41:27 -07:00
Angelos Katharopoulos
ddaa4b7dcb Fix the test and add custom min/max reductions for uncommon MPI types (#2060) 2025-04-10 17:01:17 -07:00
Cheng
dfae2c6989 Fix MSVC build due to use of M_LN2 (#2058) 2025-04-10 07:41:41 -07:00
Anastasiia Filippova
515f104926 Min / max reductions (#2041) 2025-04-09 23:22:20 -07:00
Angelos Katharopoulos
9ecefd56db Do not load the default lib if another is requested (#2055) 2025-04-09 13:31:38 -07:00
Awni Hannun
e5d35aa187 no sdpa in grad (#2054) 2025-04-08 19:13:54 -07:00
Awni Hannun
00794c42bc Fix causal mask sdpa vec (#2053)
* fix sdpa vector causal mask

* test
2025-04-08 09:11:23 -07:00
Cheng
08a1bf3f10 Remove Event::Signal() (#2052) 2025-04-08 06:20:27 -07:00
Awni Hannun
60c4154346 Only request residency once (#2051) 2025-04-07 10:47:51 -07:00
Awni Hannun
f2c85308c1 add a half simd gemm fallback (#2046)
* add a half simd gemm fallback

* nit
2025-04-07 09:31:29 -07:00
Awni Hannun
1a28b69ee2 only add to residency set once (#2049) 2025-04-06 17:38:25 -07:00
Cheng
ba09f01ce8 Remove test of converting negative float to uint (#2048) 2025-04-06 06:21:46 -07:00
Cheng
6cf48872b7 wait_for_one should wait for task to finish (#2047) 2025-04-05 20:05:16 -07:00
Angelos Katharopoulos
7b3b8fa000 Fix ci release (#2045) 2025-04-04 20:25:01 -07:00
Awni Hannun
ec5e2aae61 nit in doc (#2044) 2025-04-04 12:04:17 -07:00
Awni Hannun
86389bf970 patch bump (#2043) 2025-04-03 13:15:18 -07:00
Jagrit Digani
3290bfa690 Add new sdpa function overload (#2035)
* Add new sdpa function overload

* Address comments

* Remove std::variant from cpp sdpa function
2025-04-03 11:58:28 -07:00
Jagrit Digani
8777fd104f Depthwise Conv2D optimization (#2036)
- Add a new specialized kernel for small-kernel (kernel size <= 7), small-stride (stride <= 2) depthwise 2D convolutions
- Add related tests
2025-04-03 09:42:04 -07:00
Awni Hannun
c41f7565ed fix softmax / logsumexp (#2042) 2025-04-03 08:32:59 -07:00
Awni Hannun
9ba81e3da4 tune quant dispatch (#2031) 2025-04-02 20:05:54 -07:00
Awni Hannun
c23888acd7 Fix build warning (#2033) 2025-04-01 14:42:27 -07:00
Awni Hannun
f98ce25ab9 fix residency set for real (#2032) 2025-04-01 12:59:48 -07:00
Awni Hannun
de5f38fd48 Custom logsumexp (#2028)
* initial custom logsumexp

* more tests

* comments + fix
2025-03-31 07:36:55 -07:00
Angelos Katharopoulos
ec2854b13a Swap -inf for finite_minimum value (#2029) 2025-03-30 21:55:04 -07:00
Stephen Panaro
90823d2938 Add missing funcs to docs (#2021) 2025-03-30 18:29:33 -07:00
Jesper Stemann Andersen
5f5770e3a2 Fix CPU sign for unsigned ints (#2024)
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2025-03-30 17:56:59 -07:00
Awni Hannun
28f39e9038 Log for complex numbers in Metal (#2025)
* Log for complex numbers in Metal

* fix log2
2025-03-30 17:04:38 -07:00
Awni Hannun
b2d2b37888 fix residency set clearing (#2027) 2025-03-30 16:27:26 -07:00
Awni Hannun
fe597e141c add pinv to doc (#2020) 2025-03-30 15:54:18 -07:00
Yi Wang
72ca1539e0 Remove unused variable in /setup.py (#2026)
This is a follow up of https://github.com/ml-explore/mlx/pull/2011
2025-03-30 12:52:33 -07:00
Awni Hannun
13b26775f1 use minimum deployment target (#2016) 2025-03-28 14:31:53 -07:00
Awni Hannun
05d7118561 causal vector sdpa (#2018)
* causal vector sdpa

* get rid of memory threshold
2025-03-28 12:36:13 -07:00
Awni Hannun
98b901ad66 enable complex gemm (#2017) 2025-03-28 10:45:13 -07:00
Awni Hannun
5580b47291 iinfo and scalar overflow detection (#2009) 2025-03-27 19:54:56 -07:00
Awni Hannun
bc62932984 sdpa specialization for head dim 256 (#2007) 2025-03-27 19:31:25 -07:00
Awni Hannun
a6b5d6e759 revise cmake minimum for doctest (#2014) 2025-03-27 19:30:58 -07:00
Yi Wang
a8931306e1 Remove unused variable in CMakeBuild (#2011)
Fix https://github.com/ml-explore/mlx/issues/2010
2025-03-27 16:00:51 -07:00
Yi Wang
fecdb8717e Polish CONTRIBUTING>md (#2005) 2025-03-25 19:06:34 -07:00
Awni Hannun
916fd273ea wire cache (#2006) 2025-03-25 18:54:01 -07:00
Yi Wang
0da8506552 Update docs for extensions (#2004) 2025-03-25 18:35:03 -07:00
Cheng
eda7a7b43e Do not join threads during process exit on Windows (#1738) 2025-03-25 06:33:08 -07:00
Chunyang Wen
022eabb734 Remove unused import (#1987) 2025-03-24 20:19:32 -07:00
Awni Hannun
aba899cef8 patch bump (#2000) 2025-03-24 12:47:05 -07:00
Jagrit Digani
6a40e1c176 Fix looping limit in causal attention (#1999) 2025-03-24 12:28:00 -07:00
Jesper Stemann Andersen
9307b2ab8b Fixed 32-bit platform support for distributed/ring implementation (#1996)
Replaced unsigned long integer literals with size_t literals in ring implementation, e.g., 1UL with size_t(1).
2025-03-24 08:08:40 -07:00
Jesper Stemann Andersen
522d8d3917 Added missing netinet/in.h include that fixes build on FreeBSD (#1997)
Defines IPPROTO_TCP.
2025-03-24 08:07:34 -07:00
Awni Hannun
a84cc0123f promote mask when needed (#1998) 2025-03-23 19:58:28 -07:00
Andrey Velichkevich
f018e248cd fix(backend): Include algorithm library in Allocator (#1992)
Signed-off-by: Andrey Velichkevich <andrey.velichkevich@gmail.com>
2025-03-22 21:27:51 -07:00
Awni Hannun
cfd7237a80 fix docs (#1991) 2025-03-21 19:58:53 -07:00
Angelos Katharopoulos
4eef8102c9 Distributed layers (#1270) 2025-03-21 13:52:17 -07:00
Angelos Katharopoulos
69e4dd506b Add a ring all gather (#1985) 2025-03-21 13:36:51 -07:00
Angelos Katharopoulos
25814a9458 Disable mpi on version mismatch (#1989) 2025-03-21 13:36:26 -07:00
Awni Hannun
2a980a76ce Add stats and limit to common allocator and enable tests (#1988)
* add stats to common allocator and enable tests

* linux memory and default

* fix
2025-03-21 12:28:36 -07:00
Angelos Katharopoulos
d343782c8b Cross platform libmpi loading (#1975) 2025-03-21 11:23:10 -07:00
Awni Hannun
4e1994e9d7 move memory APIs into top level mlx.core (#1982) 2025-03-21 07:25:12 -07:00
jiyzhang
65a38c452b update the formula of smooth_l1_loss (#1986) 2025-03-21 06:25:23 -07:00
Awni Hannun
7b7e2352cd fix malloc or wait deadlock (#1976) 2025-03-20 16:48:43 -07:00
187 changed files with 5765 additions and 1717 deletions

View File

@@ -24,8 +24,8 @@ jobs:
type: boolean type: boolean
default: false default: false
macos: macos:
xcode: "15.2.0" xcode: "16.2.0"
resource_class: macos.m1.medium.gen1 resource_class: m2pro.medium
steps: steps:
- checkout - checkout
- run: - run:
@@ -89,15 +89,14 @@ jobs:
pip install numpy pip install numpy
sudo apt-get update sudo apt-get update
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
sudo apt-get install openmpi-bin openmpi-common libopenmpi-dev
- run: - run:
name: Install Python package name: Install Python package
command: | command: |
CMAKE_ARGS="-DMLX_BUILD_METAL=OFF CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" \
CMAKE_COMPILE_WARNING_AS_ERROR=ON" \
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \ CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
python3 setup.py build_ext --inplace python3 setup.py build_ext --inplace
CMAKE_ARGS="-DMLX_BUILD_METAL=OFF \ CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" \
CMAKE_COMPILE_WARNING_AS_ERROR=ON" \
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \ CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
python3 setup.py develop python3 setup.py develop
- run: - run:
@@ -110,6 +109,8 @@ jobs:
name: Run Python tests name: Run Python tests
command: | command: |
python3 -m unittest discover python/tests -v python3 -m unittest discover python/tests -v
mpirun --bind-to none -host localhost:8 -np 8 python python/tests/mpi_test_distributed.py
mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py
- run: - run:
name: Build CPP only name: Build CPP only
command: | command: |
@@ -124,10 +125,15 @@ jobs:
parameters: parameters:
xcode_version: xcode_version:
type: string type: string
default: "15.2.0" default: "16.2.0"
macosx_deployment_target:
type: string
default: ""
macos: macos:
xcode: << parameters.xcode_version >> xcode: << parameters.xcode_version >>
resource_class: macos.m1.medium.gen1 environment:
MACOSX_DEPLOYMENT_TARGET: << parameters.macosx_deployment_target >>
resource_class: m2pro.medium
steps: steps:
- checkout - checkout
- run: - run:
@@ -149,7 +155,7 @@ jobs:
command: | command: |
source env/bin/activate source env/bin/activate
DEBUG=1 CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \ DEBUG=1 CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
CMAKE_ARGS="CMAKE_COMPILE_WARNING_AS_ERROR=ON" \ CMAKE_ARGS="-DCMAKE_COMPILE_WARNING_AS_ERROR=ON" \
pip install -e . -v pip install -e . -v
- run: - run:
name: Generate package stubs name: Generate package stubs
@@ -213,13 +219,18 @@ jobs:
default: "3.9" default: "3.9"
xcode_version: xcode_version:
type: string type: string
default: "15.2.0" default: "16.2.0"
build_env: build_env:
type: string type: string
default: "" default: ""
macosx_deployment_target:
type: string
default: ""
macos: macos:
xcode: << parameters.xcode_version >> xcode: << parameters.xcode_version >>
resource_class: macos.m1.medium.gen1 resource_class: m2pro.medium
environment:
MACOSX_DEPLOYMENT_TARGET: << parameters.macosx_deployment_target >>
steps: steps:
- checkout - checkout
- run: - run:
@@ -240,7 +251,7 @@ jobs:
name: Install Python package name: Install Python package
command: | command: |
source env/bin/activate source env/bin/activate
DEV_RELEASE=1 \ env -u MACOSX_DEPLOYMENT_TARGET DEV_RELEASE=1 \
CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \ CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
pip install . -v pip install . -v
- run: - run:
@@ -335,7 +346,7 @@ workflows:
- mac_build_and_test: - mac_build_and_test:
matrix: matrix:
parameters: parameters:
xcode_version: ["15.0.0", "15.2.0", "16.0.0"] macosx_deployment_target: ["13.5", "14.0"]
- linux_build_and_test - linux_build_and_test
- build_documentation - build_documentation
@@ -355,8 +366,70 @@ workflows:
matrix: matrix:
parameters: parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"] python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
xcode_version: ["15.0.0", "15.2.0"] macosx_deployment_target: ["13.5", "14.0", "15.0"]
build_env: ["PYPI_RELEASE=1"] build_env: ["PYPI_RELEASE=1"]
xcode_version: ["16.2.0", "15.0.0"]
exclude:
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.9"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.10"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.11"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.12"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.13"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.9"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.10"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.11"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.12"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.13"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.9"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.10"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.11"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.12"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.13"
build_env: "PYPI_RELEASE=1"
- build_documentation: - build_documentation:
filters: filters:
tags: tags:
@@ -379,7 +452,7 @@ workflows:
requires: [ hold ] requires: [ hold ]
matrix: matrix:
parameters: parameters:
xcode_version: ["15.0.0", "15.2.0", "16.0.0"] macosx_deployment_target: ["13.5", "14.0"]
- linux_build_and_test: - linux_build_and_test:
requires: [ hold ] requires: [ hold ]
nightly_build: nightly_build:
@@ -392,7 +465,54 @@ workflows:
matrix: matrix:
parameters: parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"] python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
xcode_version: ["15.0.0", "15.2.0"] macosx_deployment_target: ["13.5", "14.0", "15.0"]
xcode_version: ["16.2.0", "15.0.0"]
exclude:
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.9"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.10"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.11"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.12"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.13"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.9"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.10"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.11"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.12"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.13"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.9"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.10"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.11"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.12"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.13"
weekly_build: weekly_build:
when: when:
and: and:
@@ -403,8 +523,70 @@ workflows:
matrix: matrix:
parameters: parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"] python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
xcode_version: ["15.0.0", "15.2.0", "16.0.0"] macosx_deployment_target: ["13.5", "14.0", "15.0"]
build_env: ["DEV_RELEASE=1"] build_env: ["DEV_RELEASE=1"]
xcode_version: ["16.2.0", "15.0.0"]
exclude:
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.9"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.10"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.11"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.12"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.13"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.9"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.10"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.11"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.12"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.13"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.9"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.10"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.11"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.12"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.13"
build_env: "DEV_RELEASE=1"
linux_test_release: linux_test_release:
when: when:
and: and:

View File

@@ -212,24 +212,6 @@ else()
set(MLX_BUILD_ACCELERATE OFF) set(MLX_BUILD_ACCELERATE OFF)
endif() endif()
find_package(MPI)
if(MPI_FOUND)
execute_process(
COMMAND zsh "-c" "mpirun --version"
OUTPUT_VARIABLE MPI_VERSION
ERROR_QUIET)
if(${MPI_VERSION} MATCHES ".*Open MPI.*")
target_include_directories(mlx PRIVATE ${MPI_INCLUDE_PATH})
elseif(MPI_VERSION STREQUAL "")
set(MPI_FOUND FALSE)
message(
WARNING "MPI found but mpirun is not available. Building without MPI.")
else()
set(MPI_FOUND FALSE)
message(WARNING "MPI which is not OpenMPI found. Building without MPI.")
endif()
endif()
message(STATUS "Downloading json") message(STATUS "Downloading json")
FetchContent_Declare( FetchContent_Declare(
json json

View File

@@ -5,26 +5,26 @@ possible.
## Pull Requests ## Pull Requests
1. Fork and submit pull requests to the repo. 1. Fork and submit pull requests to the repo.
2. If you've added code that should be tested, add tests. 2. If you've added code that should be tested, add tests.
3. If a change is likely to impact efficiency, run some of the benchmarks before 3. If a change is likely to impact efficiency, run some of the benchmarks before
and after the change. Examples of benchmarks can be found in `benchmarks/python/`. and after the change. Examples of benchmarks can be found in `benchmarks/python/`.
4. If you've changed APIs, update the documentation. 4. If you've changed APIs, update the documentation.
5. Every PR should have passing tests and at least one review. 5. Every PR should have passing tests and at least one review.
6. For code formatting install `pre-commit` using something like `pip install pre-commit` and run `pre-commit install`. 6. For code formatting install `pre-commit` using something like `pip install pre-commit` and run `pre-commit install`.
This should install hooks for running `black` and `clang-format` to ensure This should install hooks for running `black` and `clang-format` to ensure
consistent style for C++ and python code. consistent style for C++ and python code.
You can also run the formatters manually as follows: You can also run the formatters manually as follows:
``` ```shell
clang-format -i file.cpp clang-format -i file.cpp
``` ```
``` ```shell
black file.py black file.py
``` ```
or run `pre-commit run --all-files` to check all files in the repo. or run `pre-commit run --all-files` to check all files in the repo.
## Issues ## Issues

View File

@@ -157,7 +157,7 @@ def bench_shape(B, M, N, K, np_dtype, transpose="nn"):
def get_gflop_count(B, M, N, K): def get_gflop_count(B, M, N, K):
return float(2.0 * N_iter_bench * N_iter_func * B * M * N * K) / float(1024.0**3) return float(2.0 * N_iter_bench * N_iter_func * B * M * N * K) / float(1000.0**3)
if __name__ == "__main__": if __name__ == "__main__":
@@ -175,6 +175,8 @@ if __name__ == "__main__":
(1, 4096, 4096, 4096), (1, 4096, 4096, 4096),
) )
print(f" B, M, N, K, dtype, t, gflops_pt, gflops_mx, diff%")
for dtype in dtypes: for dtype in dtypes:
for transpose in transposes: for transpose in transposes:
for B, M, N, K in shapes: for B, M, N, K in shapes:
@@ -187,7 +189,7 @@ if __name__ == "__main__":
diff = gflops_mx / gflops_pt - 1.0 diff = gflops_mx / gflops_pt - 1.0
print( print(
f"{B:3d}, {M:4d}, {N:4d}, {K:4d}, {dtype}, {transpose}, {gflops_pt:05.3f}, {gflops_mx:05.3f}, {100. * diff:+5.2f}%" f"{B:3d}, {M:4d}, {N:4d}, {K:5d}, {dtype}, {transpose}, {gflops_pt:8.2f}, {gflops_mx:8.2f}, {100. * diff:+5.2f}%"
) )
if gflops_pt >= 2.0 * gflops_mx: if gflops_pt >= 2.0 * gflops_mx:
print("ATTENTION ^^^^^^^") print("ATTENTION ^^^^^^^")

View File

@@ -0,0 +1,74 @@
# Copyright © 2023-2024 Apple Inc.
import mlx.core as mx
from time_utils import time_fn
N = 1024
D = 1024
M = 1024
E = 32
I = 4
def gather_sort(x, indices):
N, M = indices.shape
indices = indices.flatten()
order = mx.argsort(indices)
inv_order = mx.argsort(order)
return x.flatten(0, -3)[order // M], indices[order], inv_order
def scatter_unsort(x, inv_order, shape=None):
x = x[inv_order]
if shape is not None:
x = mx.unflatten(x, 0, shape)
return x
def gather_mm_simulate(x, w, indices):
x, idx, inv_order = gather_sort(x, indices)
for i in range(2):
y = mx.concatenate([x[i] @ w[j].T for i, j in enumerate(idx.tolist())], axis=0)
x = y[:, None]
x = scatter_unsort(x, inv_order, indices.shape)
return x
def time_gather_mm():
x = mx.random.normal((N, 1, 1, D)) / 1024**0.5
w1 = mx.random.normal((E, M, D)) / 1024**0.5
w2 = mx.random.normal((E, D, M)) / 1024**0.5
indices = (mx.random.uniform(shape=(N, I)) * E).astype(mx.uint32)
sorted_indices = mx.sort(indices.flatten()).reshape(N, I)
mx.eval(x, w1, w2, indices, sorted_indices)
def gather_mm(x, w1, w2, indices, sort):
idx = indices
inv_order = None
if sort:
x, idx, inv_order = gather_sort(x, indices)
x = mx.gather_mm(x, w1.swapaxes(-1, -2), rhs_indices=idx, sorted_indices=sort)
x = mx.gather_mm(x, w2.swapaxes(-1, -2), rhs_indices=idx, sorted_indices=sort)
if sort:
x = scatter_unsort(x, inv_order, indices.shape)
return x
time_fn(gather_mm, x, w1, w2, indices, False)
time_fn(gather_mm, x, w1, w2, sorted_indices, False)
time_fn(gather_mm, x, w1, w2, indices, True)
x = mx.random.normal((N * I, D)) / 1024**0.5
w1 = mx.random.normal((M, D)) / 1024**0.5
w2 = mx.random.normal((D, M)) / 1024**0.5
mx.eval(x, w1, w2)
def equivalent_matmul(x, w1, w2):
x = x @ w1.T
x = x @ w2.T
return x
time_fn(equivalent_matmul, x, w1, w2)
if __name__ == "__main__":
time_gather_mm()
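The gather_sort / scatter_unsort pair above relies on a standard inverse-permutation identity: index with the argsort of the indices to group rows by expert, then index with the argsort of that order to undo the grouping. A standalone sketch of that round trip (not part of the benchmark, using only mx.argsort and mx.array_equal):

```python
import mlx.core as mx

indices = mx.array([2, 0, 1, 0])       # expert id per row
x = mx.arange(8.0).reshape(4, 2)       # four rows of features

order = mx.argsort(indices)            # groups rows with equal expert id
inv_order = mx.argsort(order)          # inverse permutation of `order`
x_sorted = x[order]                    # rows regrouped by expert

# Undoing the sort recovers the original row order exactly.
assert mx.array_equal(x_sorted[inv_order], x).item()
```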

View File

@@ -13,7 +13,7 @@ EXCLUDE_PATTERNS = */private/*
CREATE_SUBDIRS = NO CREATE_SUBDIRS = NO
FULL_PATH_NAMES = YES FULL_PATH_NAMES = YES
RECURSIVE = YES RECURSIVE = YES
GENERATE_HTML = YES GENERATE_HTML = NO
GENERATE_LATEX = NO GENERATE_LATEX = NO
GENERATE_XML = YES GENERATE_XML = YES
XML_PROGRAMLISTING = YES XML_PROGRAMLISTING = YES

View File

@@ -93,9 +93,9 @@ Primitives
^^^^^^^^^^^ ^^^^^^^^^^^
A :class:`Primitive` is part of the computation graph of an :class:`array`. It A :class:`Primitive` is part of the computation graph of an :class:`array`. It
defines how to create outputs arrays given a input arrays. Further, a defines how to create output arrays given input arrays. Further, a
:class:`Primitive` has methods to run on the CPU or GPU and for function :class:`Primitive` has methods to run on the CPU or GPU and for function
transformations such as ``vjp`` and ``jvp``. Lets go back to our example to be transformations such as ``vjp`` and ``jvp``. Let's go back to our example to be
more concrete: more concrete:
.. code-block:: C++ .. code-block:: C++
@@ -128,7 +128,7 @@ more concrete:
/** The vector-Jacobian product. */ /** The vector-Jacobian product. */
std::vector<array> vjp( std::vector<array> vjp(
const std::vector<array>& primals, const std::vector<array>& primals,
const array& cotan, const std::vector<array>& cotangents,
const std::vector<int>& argnums, const std::vector<int>& argnums,
const std::vector<array>& outputs) override; const std::vector<array>& outputs) override;
@@ -247,9 +247,7 @@ point-wise. This is captured in the templated function :meth:`axpby_impl`.
float alpha_, float alpha_,
float beta_, float beta_,
mx::Stream stream) { mx::Stream stream) {
// Allocate the output with `malloc_or_wait` which synchronously allocates out.set_data(mx::allocator::malloc(out.nbytes()));
// memory, potentially waiting if the system is under memory pressure
out.set_data(mx::allocator::malloc_or_wait(out.nbytes()));
// Get the CPU command encoder and register input and output arrays // Get the CPU command encoder and register input and output arrays
auto& encoder = mx::cpu::get_command_encoder(stream); auto& encoder = mx::cpu::get_command_encoder(stream);
@@ -393,7 +391,7 @@ below.
auto& d = metal::device(s.device); auto& d = metal::device(s.device);
// Allocate output memory // Allocate output memory
out.set_data(allocator::malloc_or_wait(out.nbytes())); out.set_data(allocator::malloc(out.nbytes()));
// Resolve name of kernel // Resolve name of kernel
std::ostringstream kname; std::ostringstream kname;
@@ -471,7 +469,7 @@ one we just defined:
const std::vector<array>& tangents, const std::vector<array>& tangents,
const std::vector<int>& argnums) { const std::vector<int>& argnums) {
// Forward mode diff that pushes along the tangents // Forward mode diff that pushes along the tangents
// The jvp transform on the primitive can built with ops // The jvp transform on the primitive can be built with ops
// that are scheduled on the same stream as the primitive // that are scheduled on the same stream as the primitive
// If argnums = {0}, we only push along x in which case the // If argnums = {0}, we only push along x in which case the
@@ -483,7 +481,7 @@ one we just defined:
auto scale_arr = array(scale, tangents[0].dtype()); auto scale_arr = array(scale, tangents[0].dtype());
return {multiply(scale_arr, tangents[0], stream())}; return {multiply(scale_arr, tangents[0], stream())};
} }
// If, argnums = {0, 1}, we take contributions from both // If argnums = {0, 1}, we take contributions from both
// which gives us jvp = tangent_x * alpha + tangent_y * beta // which gives us jvp = tangent_x * alpha + tangent_y * beta
else { else {
return {axpby(tangents[0], tangents[1], alpha_, beta_, stream())}; return {axpby(tangents[0], tangents[1], alpha_, beta_, stream())};
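Since axpby is linear in both inputs, the transforms the documentation walks through reduce to simple scalings. Restated for reference (this is just the relation already given in the comments, not additional MLX API):

```latex
\[
  \mathrm{axpby}(x, y) = \alpha x + \beta y, \qquad
  \mathrm{jvp} = \alpha\,\dot{x} + \beta\,\dot{y}, \qquad
  \frac{\partial}{\partial x} = \alpha, \quad
  \frac{\partial}{\partial y} = \beta .
\]
```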
@@ -737,7 +735,7 @@ Let's look at a simple script and its results:
print(f"c shape: {c.shape}") print(f"c shape: {c.shape}")
print(f"c dtype: {c.dtype}") print(f"c dtype: {c.dtype}")
print(f"c correct: {mx.all(c == 6.0).item()}") print(f"c is correct: {mx.all(c == 6.0).item()}")
Output: Output:
@@ -745,7 +743,7 @@ Output:
c shape: [3, 4] c shape: [3, 4]
c dtype: float32 c dtype: float32
c correctness: True c is correct: True
Results Results
^^^^^^^ ^^^^^^^

View File

@@ -70,6 +70,7 @@ are the CPU and GPU.
python/fft python/fft
python/linalg python/linalg
python/metal python/metal
python/memory_management
python/nn python/nn
python/optimizers python/optimizers
python/distributed python/distributed

View File

@@ -38,6 +38,7 @@ Array
array.log10 array.log10
array.log1p array.log1p
array.log2 array.log2
array.logcumsumexp
array.logsumexp array.logsumexp
array.max array.max
array.mean array.mean
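For reference, logcumsumexp is the running (cumulative) form of logsumexp along an axis; the standard definition (not copied from the MLX docs) is:

```latex
\[
  \operatorname{logcumsumexp}(x)_i = \log \sum_{j \le i} e^{x_j},
  \qquad
  \operatorname{logcumsumexp}(x)_n = \operatorname{logsumexp}(x_{1..n}).
\]
```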

View File

@@ -20,5 +20,6 @@ Linear Algebra
eigh eigh
lu lu
lu_factor lu_factor
pinv
solve solve
solve_triangular solve_triangular

View File

@@ -0,0 +1,16 @@
Memory Management
=================
.. currentmodule:: mlx.core
.. autosummary::
:toctree: _autosummary
get_active_memory
get_peak_memory
reset_peak_memory
get_cache_memory
set_memory_limit
set_cache_limit
set_wired_limit
clear_cache
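These functions move from mlx.core.metal to the top-level mlx.core in this range (commit 4e1994e9d7). Assuming they keep the signatures documented above, a quick usage sketch:

```python
import mlx.core as mx

a = mx.random.normal((1024, 1024))
mx.eval(a)

print(mx.get_active_memory())   # bytes currently held by live arrays
print(mx.get_peak_memory())     # high-water mark; reset with mx.reset_peak_memory()
print(mx.get_cache_memory())    # bytes retained in the buffer cache
mx.clear_cache()                # return cached buffers to the system
```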

View File

@@ -8,13 +8,5 @@ Metal
is_available is_available
device_info device_info
get_active_memory
get_peak_memory
reset_peak_memory
get_cache_memory
set_memory_limit
set_cache_limit
set_wired_limit
clear_cache
start_capture start_capture
stop_capture stop_capture

View File

@@ -36,10 +36,12 @@ Operations
bitwise_or bitwise_or
bitwise_xor bitwise_xor
block_masked_mm block_masked_mm
broadcast_arrays
broadcast_to broadcast_to
ceil ceil
clip clip
concatenate concatenate
contiguous
conj conj
conjugate conjugate
convolve convolve
@@ -101,6 +103,7 @@ Operations
log10 log10
log1p log1p
logaddexp logaddexp
logcumsumexp
logical_not logical_not
logical_and logical_and
logical_or logical_or

View File

@@ -18,3 +18,4 @@ Common Optimizers
AdamW AdamW
Adamax Adamax
Lion Lion
MultiOptimizer

View File

@@ -9,6 +9,7 @@ Transforms
:toctree: _autosummary :toctree: _autosummary
eval eval
async_eval
compile compile
custom_function custom_function
disable_compile disable_compile

View File

@@ -72,9 +72,7 @@ void axpby_impl(
float alpha_, float alpha_,
float beta_, float beta_,
mx::Stream stream) { mx::Stream stream) {
// Allocate the output with `malloc_or_wait` which synchronously allocates out.set_data(mx::allocator::malloc(out.nbytes()));
// memory, potentially waiting if the system is under memory pressure
out.set_data(mx::allocator::malloc_or_wait(out.nbytes()));
// Get the CPU command encoder and register input and output arrays // Get the CPU command encoder and register input and output arrays
auto& encoder = mx::cpu::get_command_encoder(stream); auto& encoder = mx::cpu::get_command_encoder(stream);
@@ -160,12 +158,12 @@ void Axpby::eval_gpu(
// Allocate output memory with strides based on specialization // Allocate output memory with strides based on specialization
if (contiguous_kernel) { if (contiguous_kernel) {
out.set_data( out.set_data(
mx::allocator::malloc_or_wait(x.data_size() * out.itemsize()), mx::allocator::malloc(x.data_size() * out.itemsize()),
x.data_size(), x.data_size(),
x.strides(), x.strides(),
x.flags()); x.flags());
} else { } else {
out.set_data(mx::allocator::malloc_or_wait(out.nbytes())); out.set_data(mx::allocator::malloc(out.nbytes()));
} }
// Resolve name of kernel (corresponds to axpby.metal) // Resolve name of kernel (corresponds to axpby.metal)

View File

@@ -4,12 +4,11 @@
#include <sstream> #include <sstream>
#include "mlx/allocator.h" #include "mlx/allocator.h"
#include "mlx/scheduler.h"
namespace mlx::core::allocator { namespace mlx::core::allocator {
Buffer malloc(size_t size) { Buffer malloc(size_t size) {
auto buffer = allocator().malloc(size, /* allow_swap */ true); auto buffer = allocator().malloc(size);
if (size && !buffer.ptr()) { if (size && !buffer.ptr()) {
std::ostringstream msg; std::ostringstream msg;
msg << "[malloc] Unable to allocate " << size << " bytes."; msg << "[malloc] Unable to allocate " << size << " bytes.";
@@ -22,45 +21,4 @@ void free(Buffer buffer) {
allocator().free(buffer); allocator().free(buffer);
} }
Buffer CommonAllocator::malloc(size_t size, bool) {
void* ptr = std::malloc(size + sizeof(size_t));
if (ptr != nullptr) {
*static_cast<size_t*>(ptr) = size;
}
return Buffer{ptr};
}
void CommonAllocator::free(Buffer buffer) {
std::free(buffer.ptr());
}
size_t CommonAllocator::size(Buffer buffer) const {
if (buffer.ptr() == nullptr) {
return 0;
}
return *static_cast<size_t*>(buffer.ptr());
}
Buffer malloc_or_wait(size_t size) {
auto buffer = allocator().malloc(size);
while (size && !buffer.ptr() && scheduler::n_active_tasks() > 0) {
scheduler::wait_for_one();
buffer = allocator().malloc(size);
}
// Try swapping if needed
if (size && !buffer.ptr()) {
buffer = allocator().malloc(size, /* allow_swap = */ true);
}
if (size && !buffer.ptr()) {
std::ostringstream msg;
msg << "[malloc_or_wait] Unable to allocate " << size << " bytes.";
throw std::runtime_error(msg.str());
}
return buffer;
}
} // namespace mlx::core::allocator } // namespace mlx::core::allocator

View File

@@ -32,14 +32,10 @@ Buffer malloc(size_t size);
void free(Buffer buffer); void free(Buffer buffer);
// Wait for running tasks to finish and free up memory
// if allocation fails
Buffer malloc_or_wait(size_t size);
class Allocator { class Allocator {
/** Abstract base class for a memory allocator. */ /** Abstract base class for a memory allocator. */
public: public:
virtual Buffer malloc(size_t size, bool allow_swap = false) = 0; virtual Buffer malloc(size_t size) = 0;
virtual void free(Buffer buffer) = 0; virtual void free(Buffer buffer) = 0;
virtual size_t size(Buffer buffer) const = 0; virtual size_t size(Buffer buffer) const = 0;
@@ -53,16 +49,4 @@ class Allocator {
Allocator& allocator(); Allocator& allocator();
class CommonAllocator : public Allocator {
/** A general CPU allocator. */
public:
virtual Buffer malloc(size_t size, bool allow_swap = false) override;
virtual void free(Buffer buffer) override;
virtual size_t size(Buffer buffer) const override;
private:
CommonAllocator() = default;
friend Allocator& allocator();
};
} // namespace mlx::core::allocator } // namespace mlx::core::allocator

View File

@@ -1,6 +1,7 @@
target_sources( target_sources(
mlx mlx
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/broadcasting.cpp
${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
${CMAKE_CURRENT_SOURCE_DIR}/common.cpp ${CMAKE_CURRENT_SOURCE_DIR}/common.cpp
${CMAKE_CURRENT_SOURCE_DIR}/load.cpp ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp

View File

@@ -44,14 +44,14 @@ inline void set_binary_op_output_data(
switch (bopt) { switch (bopt) {
case BinaryOpType::ScalarScalar: case BinaryOpType::ScalarScalar:
out.set_data( out.set_data(
allocator::malloc_or_wait(out.itemsize()), 1, a.strides(), a.flags()); allocator::malloc(out.itemsize()), 1, a.strides(), a.flags());
break; break;
case BinaryOpType::ScalarVector: case BinaryOpType::ScalarVector:
if (b_donatable) { if (b_donatable) {
out.copy_shared_buffer(b); out.copy_shared_buffer(b);
} else { } else {
out.set_data( out.set_data(
allocator::malloc_or_wait(b.data_size() * out.itemsize()), allocator::malloc(b.data_size() * out.itemsize()),
b.data_size(), b.data_size(),
b.strides(), b.strides(),
b.flags()); b.flags());
@@ -62,7 +62,7 @@ inline void set_binary_op_output_data(
out.copy_shared_buffer(a); out.copy_shared_buffer(a);
} else { } else {
out.set_data( out.set_data(
allocator::malloc_or_wait(a.data_size() * out.itemsize()), allocator::malloc(a.data_size() * out.itemsize()),
a.data_size(), a.data_size(),
a.strides(), a.strides(),
a.flags()); a.flags());
@@ -75,7 +75,7 @@ inline void set_binary_op_output_data(
out.copy_shared_buffer(b); out.copy_shared_buffer(b);
} else { } else {
out.set_data( out.set_data(
allocator::malloc_or_wait(a.data_size() * out.itemsize()), allocator::malloc(a.data_size() * out.itemsize()),
a.data_size(), a.data_size(),
a.strides(), a.strides(),
a.flags()); a.flags());
@@ -88,7 +88,7 @@ inline void set_binary_op_output_data(
b_donatable && b.flags().row_contiguous && b.size() == out.size()) { b_donatable && b.flags().row_contiguous && b.size() == out.size()) {
out.copy_shared_buffer(b); out.copy_shared_buffer(b);
} else { } else {
out.set_data(allocator::malloc_or_wait(out.nbytes())); out.set_data(allocator::malloc(out.nbytes()));
} }
break; break;
} }

View File

@@ -0,0 +1,24 @@
// Copyright © 2024 Apple Inc.
#include "mlx/backend/common/utils.h"
namespace mlx::core {
void broadcast(const array& in, array& out) {
if (out.size() == 0) {
out.set_data(nullptr);
return;
}
Strides strides(out.ndim(), 0);
int diff = out.ndim() - in.ndim();
for (int i = in.ndim() - 1; i >= 0; --i) {
strides[i + diff] = (in.shape()[i] == 1) ? 0 : in.strides()[i];
}
auto flags = in.flags();
if (out.size() > in.size()) {
flags.row_contiguous = flags.col_contiguous = false;
}
out.copy_shared_buffer(in, strides, flags, in.data_size());
}
} // namespace mlx::core
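broadcast() never copies: a size-1 input dimension simply gets stride 0 in the output, so every output index along that axis reads the same element. The same trick in NumPy terms (an analogy, not MLX code):

```python
import numpy as np
from numpy.lib.stride_tricks import as_strided

x = np.arange(3.0).reshape(3, 1)
# A size-1 dimension is broadcast by giving it stride 0, mirroring
# `strides[i + diff] = 0` when in.shape()[i] == 1 in broadcast().
y = as_strided(x, shape=(3, 4), strides=(x.strides[0], 0))
print(y)           # each row repeats its single source element
print(y.strides)   # (8, 0): the broadcast axis has stride 0
```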

View File

@@ -0,0 +1,11 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include "mlx/array.h"
namespace mlx::core {
void broadcast(const array& in, array& out);
} // namespace mlx::core

View File

@@ -1,6 +1,7 @@
// Copyright © 2024 Apple Inc. // Copyright © 2024 Apple Inc.
#include <cassert> #include <cassert>
#include "mlx/backend/common/broadcasting.h"
#include "mlx/backend/common/utils.h" #include "mlx/backend/common/utils.h"
#include "mlx/primitives.h" #include "mlx/primitives.h"
@@ -42,23 +43,6 @@ void AsStrided::eval(const std::vector<array>& inputs, array& out) {
return out.copy_shared_buffer(in, strides_, flags, data_size, offset_); return out.copy_shared_buffer(in, strides_, flags, data_size, offset_);
} }
void broadcast(const array& in, array& out) {
if (out.size() == 0) {
out.set_data(nullptr);
return;
}
Strides strides(out.ndim(), 0);
int diff = out.ndim() - in.ndim();
for (int i = in.ndim() - 1; i >= 0; --i) {
strides[i + diff] = (in.shape()[i] == 1) ? 0 : in.strides()[i];
}
auto flags = in.flags();
if (out.size() > in.size()) {
flags.row_contiguous = flags.col_contiguous = false;
}
out.copy_shared_buffer(in, strides, flags, in.data_size());
}
void Broadcast::eval(const std::vector<array>& inputs, array& out) { void Broadcast::eval(const std::vector<array>& inputs, array& out) {
broadcast(inputs[0], out); broadcast(inputs[0], out);
} }
@@ -103,7 +87,7 @@ void ExpandDims::eval(const std::vector<array>& inputs, array& out) {
void NumberOfElements::eval(const std::vector<array>& inputs, array& out) { void NumberOfElements::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
out.set_data(allocator::malloc_or_wait(out.nbytes())); out.set_data(allocator::malloc(out.nbytes()));
double numel = 1; double numel = 1;
for (auto ax : axes_) { for (auto ax : axes_) {

View File

@@ -188,7 +188,7 @@ void compiled_allocate_outputs(
} }
for (; o < outputs.size(); ++o) { for (; o < outputs.size(); ++o) {
outputs[o].set_data( outputs[o].set_data(
allocator::malloc_or_wait(data_size * outputs[o].itemsize()), allocator::malloc(data_size * outputs[o].itemsize()),
data_size, data_size,
strides, strides,
flags); flags);
@@ -211,7 +211,7 @@ void compiled_allocate_outputs(
} }
} }
for (; o < outputs.size(); ++o) { for (; o < outputs.size(); ++o) {
outputs[o].set_data(allocator::malloc_or_wait(outputs[o].nbytes())); outputs[o].set_data(allocator::malloc(outputs[o].nbytes()));
} }
} }
} }

View File

@@ -31,14 +31,14 @@ inline bool set_copy_output_data(const array& in, array& out, CopyType ctype) {
return true; return true;
} else { } else {
out.set_data( out.set_data(
allocator::malloc_or_wait(in.data_size() * out.itemsize()), allocator::malloc(in.data_size() * out.itemsize()),
in.data_size(), in.data_size(),
in.strides(), in.strides(),
in.flags()); in.flags());
return false; return false;
} }
} else { } else {
out.set_data(allocator::malloc_or_wait(out.nbytes())); out.set_data(allocator::malloc(out.nbytes()));
return false; return false;
} }
} }

View File

@@ -28,7 +28,7 @@ void swap_endianness(uint8_t* data_bytes, size_t N) {
namespace mlx::core { namespace mlx::core {
void Load::eval_cpu(const std::vector<array>& inputs, array& out) { void Load::eval_cpu(const std::vector<array>& inputs, array& out) {
out.set_data(allocator::malloc_or_wait(out.nbytes())); out.set_data(allocator::malloc(out.nbytes()));
auto read_task = [out_ptr = out.data<char>(), auto read_task = [out_ptr = out.data<char>(),
size = out.size(), size = out.size(),
itemsize = out.itemsize(), itemsize = out.itemsize(),

View File

@@ -48,12 +48,12 @@ inline void set_ternary_op_output_data(
switch (topt) { switch (topt) {
case TernaryOpType::ScalarScalarScalar: case TernaryOpType::ScalarScalarScalar:
out.set_data( out.set_data(
allocator::malloc_or_wait(out.itemsize()), 1, b.strides(), b.flags()); allocator::malloc(out.itemsize()), 1, b.strides(), b.flags());
break; break;
case TernaryOpType::VectorVectorVector: case TernaryOpType::VectorVectorVector:
if (!(maybe_donate(a) || maybe_donate(b) || maybe_donate(c))) { if (!(maybe_donate(a) || maybe_donate(b) || maybe_donate(c))) {
out.set_data( out.set_data(
allocator::malloc_or_wait(out.itemsize() * b.data_size()), allocator::malloc(out.itemsize() * b.data_size()),
b.data_size(), b.data_size(),
b.strides(), b.strides(),
b.flags()); b.flags());
@@ -64,7 +64,7 @@ inline void set_ternary_op_output_data(
if (!((a.flags().row_contiguous && maybe_donate(a)) || if (!((a.flags().row_contiguous && maybe_donate(a)) ||
(b.flags().row_contiguous && maybe_donate(b)) || (b.flags().row_contiguous && maybe_donate(b)) ||
(c.flags().row_contiguous && maybe_donate(c)))) { (c.flags().row_contiguous && maybe_donate(c)))) {
out.set_data(allocator::malloc_or_wait(out.nbytes())); out.set_data(allocator::malloc(out.nbytes()));
} }
break; break;
} }

View File

@@ -58,6 +58,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp ${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
${CMAKE_CURRENT_SOURCE_DIR}/select.cpp ${CMAKE_CURRENT_SOURCE_DIR}/select.cpp
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
${CMAKE_CURRENT_SOURCE_DIR}/logsumexp.cpp
${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp ${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
${CMAKE_CURRENT_SOURCE_DIR}/threefry.cpp ${CMAKE_CURRENT_SOURCE_DIR}/threefry.cpp
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
@@ -73,8 +74,8 @@ target_sources(
if(MLX_BUILD_ACCELERATE) if(MLX_BUILD_ACCELERATE)
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/bnns.cpp) target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/bnns.cpp)
else() else()
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/no_fp16.cpp target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/simd_fp16.cpp
${CMAKE_CURRENT_SOURCE_DIR}/gemms/no_bf16.cpp) ${CMAKE_CURRENT_SOURCE_DIR}/gemms/simd_bf16.cpp)
endif() endif()
if(IOS) if(IOS)

View File

@@ -68,7 +68,7 @@ void arg_reduce_dispatch(
void ArgReduce::eval_cpu(const std::vector<array>& inputs, array& out) { void ArgReduce::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
auto& in = inputs[0]; auto& in = inputs[0];
out.set_data(allocator::malloc_or_wait(out.nbytes())); out.set_data(allocator::malloc(out.nbytes()));
auto& encoder = cpu::get_command_encoder(stream()); auto& encoder = cpu::get_command_encoder(stream());
encoder.set_input_array(in); encoder.set_input_array(in);
encoder.set_output_array(out); encoder.set_output_array(out);

View File

@@ -921,7 +921,7 @@ void explicit_gemm_conv_1D_cpu(
if (out.dtype() != float32) { if (out.dtype() != float32) {
gemm_out = array(out.shape(), float32, nullptr, {}); gemm_out = array(out.shape(), float32, nullptr, {});
gemm_out.set_data(allocator::malloc_or_wait(gemm_out.nbytes())); gemm_out.set_data(allocator::malloc(gemm_out.nbytes()));
temps.push_back(gemm_out); temps.push_back(gemm_out);
} }
@@ -1048,7 +1048,7 @@ void explicit_gemm_conv_2D_cpu(
if (out.dtype() != float32) { if (out.dtype() != float32) {
gemm_out = array(out.shape(), float32, nullptr, {}); gemm_out = array(out.shape(), float32, nullptr, {});
gemm_out.set_data(allocator::malloc_or_wait(gemm_out.nbytes())); gemm_out.set_data(allocator::malloc(gemm_out.nbytes()));
temps.push_back(gemm_out); temps.push_back(gemm_out);
} }
@@ -1214,7 +1214,7 @@ void explicit_gemm_conv_ND_cpu(
if (out.dtype() != float32) { if (out.dtype() != float32) {
gemm_out = array(out.shape(), float32, nullptr, {}); gemm_out = array(out.shape(), float32, nullptr, {});
gemm_out.set_data(allocator::malloc_or_wait(gemm_out.nbytes())); gemm_out.set_data(allocator::malloc(gemm_out.nbytes()));
temps.push_back(gemm_out); temps.push_back(gemm_out);
} }
@@ -1327,7 +1327,7 @@ void conv_3D_cpu(
} // namespace } // namespace
void Convolution::eval_cpu(const std::vector<array>& inputs, array& out) { void Convolution::eval_cpu(const std::vector<array>& inputs, array& out) {
out.set_data(allocator::malloc_or_wait(out.nbytes())); out.set_data(allocator::malloc(out.nbytes()));
auto& in = inputs[0]; auto& in = inputs[0];
auto& wt = inputs[1]; auto& wt = inputs[1];

View File

@@ -30,7 +30,7 @@ void AllReduce::eval_cpu(
if (in.is_donatable()) { if (in.is_donatable()) {
out.copy_shared_buffer(in); out.copy_shared_buffer(in);
} else { } else {
out.set_data(allocator::malloc_or_wait(out.nbytes())); out.set_data(allocator::malloc(out.nbytes()));
} }
return in; return in;
} else { } else {
@@ -46,8 +46,15 @@ void AllReduce::eval_cpu(
case Sum: case Sum:
distributed::detail::all_sum(group(), in, outputs[0], stream()); distributed::detail::all_sum(group(), in, outputs[0], stream());
break; break;
case Max:
distributed::detail::all_max(group(), in, outputs[0], stream());
break;
case Min:
distributed::detail::all_min(group(), in, outputs[0], stream());
break;
default: default:
throw std::runtime_error("Only all reduce sum is supported for now"); throw std::runtime_error(
"Only all reduce sum, min and max are supported for now");
} }
} }
@@ -58,7 +65,7 @@ void AllGather::eval_cpu(
assert(outputs.size() == 1); assert(outputs.size() == 1);
auto [in, copied] = ensure_row_contiguous(inputs[0], stream()); auto [in, copied] = ensure_row_contiguous(inputs[0], stream());
outputs[0].set_data(allocator::malloc_or_wait(outputs[0].nbytes())); outputs[0].set_data(allocator::malloc(outputs[0].nbytes()));
distributed::detail::all_gather(group(), in, outputs[0], stream()); distributed::detail::all_gather(group(), in, outputs[0], stream());
if (copied) { if (copied) {
auto& enc = cpu::get_command_encoder(stream()); auto& enc = cpu::get_command_encoder(stream());
@@ -87,7 +94,7 @@ void Recv::eval_cpu(
assert(inputs.size() == 0); assert(inputs.size() == 0);
assert(outputs.size() == 1); assert(outputs.size() == 1);
outputs[0].set_data(allocator::malloc_or_wait(outputs[0].nbytes())); outputs[0].set_data(allocator::malloc(outputs[0].nbytes()));
distributed::detail::recv(group(), outputs[0], src_, stream()); distributed::detail::recv(group(), outputs[0], src_, stream());
} }

View File

@@ -55,9 +55,8 @@ void eigh_impl(
liwork = iwork; liwork = iwork;
} }
auto work_buf = array::Data{allocator::malloc_or_wait(sizeof(T) * lwork)}; auto work_buf = array::Data{allocator::malloc(sizeof(T) * lwork)};
auto iwork_buf = auto iwork_buf = array::Data{allocator::malloc(sizeof(int) * liwork)};
array::Data{allocator::malloc_or_wait(sizeof(int) * liwork)};
for (size_t i = 0; i < size / (N * N); ++i) { for (size_t i = 0; i < size / (N * N); ++i) {
syevd<T>( syevd<T>(
&jobz, &jobz,
@@ -98,7 +97,7 @@ void Eigh::eval_cpu(
? outputs[1] ? outputs[1]
: array(a.shape(), a.dtype(), nullptr, {}); : array(a.shape(), a.dtype(), nullptr, {});
values.set_data(allocator::malloc_or_wait(values.nbytes())); values.set_data(allocator::malloc(values.nbytes()));
copy( copy(
a, a,

View File

@@ -22,7 +22,7 @@ void FFT::eval_cpu(const std::vector<array>& inputs, array& out) {
s *= out.itemsize(); s *= out.itemsize();
} }
out.set_data(allocator::malloc_or_wait(out.nbytes())); out.set_data(allocator::malloc(out.nbytes()));
std::vector<size_t> shape; std::vector<size_t> shape;
if (out.dtype() == float32) { if (out.dtype() == float32) {

View File

@@ -1,27 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cpu/gemm.h"
namespace mlx::core {
template <>
void matmul<bfloat16_t>(
const bfloat16_t*,
const bfloat16_t*,
bfloat16_t*,
bool,
bool,
size_t,
size_t,
size_t,
float,
float,
size_t,
const Shape&,
const Strides&,
const Shape&,
const Strides&) {
throw std::runtime_error("[Matmul::eval_cpu] bfloat16 not supported.");
}
} // namespace mlx::core

View File

@@ -1,27 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cpu/gemm.h"
namespace mlx::core {
template <>
void matmul<float16_t>(
const float16_t*,
const float16_t*,
float16_t*,
bool,
bool,
size_t,
size_t,
size_t,
float,
float,
size_t,
const Shape&,
const Strides&,
const Shape&,
const Strides&) {
throw std::runtime_error("[Matmul::eval_cpu] float16 not supported.");
}
} // namespace mlx::core

View File

@@ -0,0 +1,45 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/gemm.h"
#include "mlx/backend/cpu/gemms/simd_gemm.h"
namespace mlx::core {
template <>
void matmul<bfloat16_t>(
const bfloat16_t* a,
const bfloat16_t* b,
bfloat16_t* out,
bool a_transposed,
bool b_transposed,
size_t lda,
size_t ldb,
size_t ldc,
float alpha,
float beta,
size_t batch_size,
const Shape& a_shape,
const Strides& a_strides,
const Shape& b_shape,
const Strides& b_strides) {
auto ndim = a_shape.size();
size_t M = a_shape[ndim - 2];
size_t N = b_shape[ndim - 1];
size_t K = a_shape[ndim - 1];
for (int i = 0; i < batch_size; ++i) {
simd_gemm<bfloat16_t, float>(
a + elem_to_loc(M * K * i, a_shape, a_strides),
b + elem_to_loc(K * N * i, b_shape, b_strides),
out + M * N * i,
a_transposed,
b_transposed,
M,
N,
K,
alpha,
beta);
}
}
} // namespace mlx::core

View File

@@ -0,0 +1,45 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/gemm.h"
#include "mlx/backend/cpu/gemms/simd_gemm.h"
namespace mlx::core {
template <>
void matmul<float16_t>(
const float16_t* a,
const float16_t* b,
float16_t* out,
bool a_transposed,
bool b_transposed,
size_t lda,
size_t ldb,
size_t ldc,
float alpha,
float beta,
size_t batch_size,
const Shape& a_shape,
const Strides& a_strides,
const Shape& b_shape,
const Strides& b_strides) {
auto ndim = a_shape.size();
size_t M = a_shape[ndim - 2];
size_t N = b_shape[ndim - 1];
size_t K = a_shape[ndim - 1];
for (int i = 0; i < batch_size; ++i) {
simd_gemm<float16_t, float>(
a + elem_to_loc(M * K * i, a_shape, a_strides),
b + elem_to_loc(K * N * i, b_shape, b_strides),
out + M * N * i,
a_transposed,
b_transposed,
M,
N,
K,
alpha,
beta);
}
}
} // namespace mlx::core

View File

@@ -0,0 +1,139 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include "mlx/backend/cpu/simd/simd.h"
namespace mlx::core {
inline int ceildiv(int a, int b) {
return (a + b - 1) / b;
}
template <int block_size, typename T, typename AccT>
void load_block(
const T* in,
AccT* out,
int M,
int N,
int i,
int j,
bool transpose) {
if (transpose) {
for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) {
for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) {
out[jj * block_size + ii] =
in[(i * block_size + ii) * N + j * block_size + jj];
}
}
} else {
for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) {
for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) {
out[ii * block_size + jj] =
in[(i * block_size + ii) * N + j * block_size + jj];
}
}
}
}
template <typename T, typename AccT>
void simd_gemm(
const T* a,
const T* b,
T* c,
bool a_trans,
bool b_trans,
int M,
int N,
int K,
float alpha,
float beta) {
constexpr int block_size = 16;
constexpr int simd_size = simd::max_size<AccT>;
static_assert(
(block_size % simd_size) == 0,
"Block size must be divisible by SIMD size");
int last_k_block_size = K - block_size * (K / block_size);
int last_k_simd_block = (last_k_block_size / simd_size) * simd_size;
for (int i = 0; i < ceildiv(M, block_size); i++) {
for (int j = 0; j < ceildiv(N, block_size); j++) {
AccT c_block[block_size * block_size] = {0.0};
AccT a_block[block_size * block_size];
AccT b_block[block_size * block_size];
int k = 0;
for (; k < K / block_size; k++) {
// Load a and b blocks
if (a_trans) {
load_block<block_size>(a, a_block, K, M, k, i, true);
} else {
load_block<block_size>(a, a_block, M, K, i, k, false);
}
if (b_trans) {
load_block<block_size>(b, b_block, N, K, j, k, false);
} else {
load_block<block_size>(b, b_block, K, N, k, j, true);
}
// Multiply and accumulate
for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) {
for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) {
for (int kk = 0; kk < block_size; kk += simd_size) {
auto av =
simd::load<AccT, simd_size>(a_block + ii * block_size + kk);
auto bv =
simd::load<AccT, simd_size>(b_block + jj * block_size + kk);
c_block[ii * block_size + jj] += simd::sum(av * bv);
}
}
}
}
if (last_k_block_size) {
// Load a and b blocks
if (a_trans) {
load_block<block_size>(a, a_block, K, M, k, i, true);
} else {
load_block<block_size>(a, a_block, M, K, i, k, false);
}
if (b_trans) {
load_block<block_size>(b, b_block, N, K, j, k, false);
} else {
load_block<block_size>(b, b_block, K, N, k, j, true);
}
// Multiply and accumulate
for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) {
for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) {
int kk = 0;
for (; kk < last_k_simd_block; kk += simd_size) {
auto av =
simd::load<AccT, simd_size>(a_block + ii * block_size + kk);
auto bv =
simd::load<AccT, simd_size>(b_block + jj * block_size + kk);
c_block[ii * block_size + jj] += simd::sum(av * bv);
}
for (; kk < last_k_block_size; ++kk) {
c_block[ii * block_size + jj] +=
a_block[ii * block_size + kk] * b_block[jj * block_size + kk];
}
}
}
}
// Store
for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) {
for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) {
auto c_idx = (i * block_size + ii) * N + j * block_size + jj;
if (beta != 0) {
c[c_idx] = static_cast<T>(
alpha * c_block[ii * block_size + jj] + beta * c[c_idx]);
} else {
c[c_idx] = static_cast<T>(alpha * c_block[ii * block_size + jj]);
}
}
}
}
}
}
} // namespace mlx::core
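The new fallback tiles the output into block_size x block_size blocks, loads the matching A and B tiles once, accumulates each tile with SIMD dot products, and handles a ragged tail in K separately. The loop structure, reduced to a NumPy reference (a sketch of the tiling only, not the SIMD kernel itself):

```python
import numpy as np

def blocked_matmul(a, b, block=16):
    # Same i / j / k blocking as simd_gemm: each (i, j) output tile is
    # accumulated across all K blocks before it is written out.
    m, k = a.shape
    _, n = b.shape
    c = np.zeros((m, n), dtype=np.float32)
    for i in range(0, m, block):
        for j in range(0, n, block):
            acc = np.zeros((min(block, m - i), min(block, n - j)), dtype=np.float32)
            for p in range(0, k, block):
                acc += a[i:i + block, p:p + block] @ b[p:p + block, j:j + block]
            c[i:i + block, j:j + block] = acc
    return c

a = np.random.rand(100, 70).astype(np.float32)
b = np.random.rand(70, 90).astype(np.float32)
assert np.allclose(blocked_matmul(a, b), a @ b, atol=1e-3)
```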

View File

@@ -197,7 +197,7 @@ void dispatch_gather(
} }
void Gather::eval_cpu(const std::vector<array>& inputs, array& out) { void Gather::eval_cpu(const std::vector<array>& inputs, array& out) {
out.set_data(allocator::malloc_or_wait(out.nbytes())); out.set_data(allocator::malloc(out.nbytes()));
auto& src = inputs[0]; auto& src = inputs[0];
std::vector<array> inds; std::vector<array> inds;
@@ -354,7 +354,7 @@ void dispatch_gather_axis(
} }
void GatherAxis::eval_cpu(const std::vector<array>& inputs, array& out) { void GatherAxis::eval_cpu(const std::vector<array>& inputs, array& out) {
out.set_data(allocator::malloc_or_wait(out.nbytes())); out.set_data(allocator::malloc(out.nbytes()));
auto& src = inputs[0]; auto& src = inputs[0];
auto& inds = inputs[1]; auto& inds = inputs[1];

View File

@@ -11,7 +11,7 @@ namespace mlx::core {
template <typename T> template <typename T>
void general_inv(T* inv, int N) { void general_inv(T* inv, int N) {
int info; int info;
auto ipiv = array::Data{allocator::malloc_or_wait(sizeof(int) * N)}; auto ipiv = array::Data{allocator::malloc(sizeof(int) * N)};
// Compute LU factorization. // Compute LU factorization.
getrf<T>( getrf<T>(
/* m = */ &N, /* m = */ &N,
@@ -49,7 +49,7 @@ void general_inv(T* inv, int N) {
} }
const int lwork = workspace_size; const int lwork = workspace_size;
auto scratch = array::Data{allocator::malloc_or_wait(sizeof(T) * lwork)}; auto scratch = array::Data{allocator::malloc(sizeof(T) * lwork)};
// Compute inverse. // Compute inverse.
getri<T>( getri<T>(

View File

@@ -0,0 +1,140 @@
// Copyright © 2023-2024 Apple Inc.
#include <cassert>
#include <cmath>
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/simd/simd.h"
#include "mlx/primitives.h"
#include "mlx/types/limits.h"
namespace mlx::core {
namespace {
using namespace mlx::core::simd;
template <typename T, typename AccT>
void logsumexp(const array& in, array& out, Stream stream) {
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(in);
encoder.set_output_array(out);
const T* in_ptr = in.data<T>();
T* out_ptr = out.data<T>();
int M = in.shape().back();
int L = in.data_size() / M;
encoder.dispatch([in_ptr, out_ptr, M, L]() mutable {
constexpr int N = std::min(max_size<AccT>, max_size<T>);
const T* current_in_ptr;
for (int i = 0; i < L; i++, in_ptr += M, out_ptr += 1) {
// Find the maximum
current_in_ptr = in_ptr;
Simd<AccT, N> vmaximum(-numeric_limits<AccT>::infinity());
size_t s = M;
while (s >= N) {
Simd<AccT, N> vals = load<T, N>(current_in_ptr);
vmaximum = maximum(vals, vmaximum);
current_in_ptr += N;
s -= N;
}
AccT maximum = max(vmaximum);
while (s-- > 0) {
maximum = std::max(maximum, static_cast<AccT>(*current_in_ptr));
current_in_ptr++;
}
// Compute the normalizer and the exponentials
Simd<AccT, N> vnormalizer(0.0);
current_in_ptr = in_ptr;
s = M;
while (s >= N) {
Simd<AccT, N> vexp = load<T, N>(current_in_ptr);
vexp = exp(vexp - maximum);
vnormalizer = vnormalizer + vexp;
current_in_ptr += N;
s -= N;
}
AccT normalizer = sum(vnormalizer);
while (s-- > 0) {
AccT _exp = std::exp(*current_in_ptr - maximum);
normalizer += _exp;
current_in_ptr++;
}
// Normalize
*out_ptr = std::isinf(maximum)
? static_cast<T>(maximum)
: static_cast<T>(std::log(normalizer) + maximum);
}
});
}
} // namespace
void LogSumExp::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
// Make sure that the last dimension is contiguous
auto s = stream();
auto& encoder = cpu::get_command_encoder(s);
auto ensure_contiguous = [&s, &encoder](const array& x) {
if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
return x;
} else {
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
copy(x, x_copy, CopyType::General, s);
encoder.add_temporary(x_copy);
return x_copy;
}
};
auto in = ensure_contiguous(inputs[0]);
if (in.flags().row_contiguous) {
out.set_data(allocator::malloc(out.nbytes()));
} else {
auto n = in.shape(-1);
auto flags = in.flags();
auto strides = in.strides();
for (auto& s : strides) {
s /= n;
}
bool col_contig = strides[0] == 1;
for (int i = 1; col_contig && i < strides.size(); ++i) {
col_contig &=
(out.shape(i) == 1 || strides[i - 1] == out.shape(i) * strides[i]);
}
flags.col_contiguous = col_contig;
out.set_data(
allocator::malloc(in.nbytes() / n),
in.data_size() / n,
std::move(strides),
flags);
}
switch (in.dtype()) {
case float32:
logsumexp<float, float>(in, out, stream());
break;
case float16:
logsumexp<float16_t, float>(in, out, stream());
break;
case bfloat16:
logsumexp<bfloat16_t, float>(in, out, stream());
break;
case float64:
logsumexp<double, double>(in, out, stream());
break;
default:
throw std::runtime_error(
"[logsumexp] only supports floating point types");
break;
}
}
} // namespace mlx::core
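For reference, the row-wise reduction the new kernel implements is the standard two-pass, numerically stable logsumexp: find the row maximum, sum the shifted exponentials, then add the log of that sum back onto the maximum. A minimal scalar sketch of the same computation (plain C++, without the Simd<> vectors or the command encoder used above):

#include <algorithm>
#include <cmath>
#include <limits>

// Scalar equivalent of the per-row reduction in the kernel above.
float logsumexp_row(const float* x, int M) {
  float maximum = -std::numeric_limits<float>::infinity();
  for (int i = 0; i < M; ++i) {
    maximum = std::max(maximum, x[i]);
  }
  float normalizer = 0.0f;
  for (int i = 0; i < M; ++i) {
    normalizer += std::exp(x[i] - maximum);
  }
  // An all -inf row returns -inf directly (the normalizer is not used then),
  // mirroring the std::isinf(maximum) check in the kernel.
  return std::isinf(maximum) ? maximum : std::log(normalizer) + maximum;
}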

View File

@@ -30,8 +30,7 @@ void luf_impl(
auto strides = lu.strides(); auto strides = lu.strides();
strides[ndim - 1] = M; strides[ndim - 1] = M;
strides[ndim - 2] = 1; strides[ndim - 2] = 1;
-  lu.set_data(
-      allocator::malloc_or_wait(lu.nbytes()), lu.nbytes(), strides, flags);
+  lu.set_data(allocator::malloc(lu.nbytes()), lu.nbytes(), strides, flags);
copy_inplace( copy_inplace(
a, a,
lu, lu,
@@ -44,8 +43,8 @@ void luf_impl(
stream); stream);
auto a_ptr = lu.data<T>(); auto a_ptr = lu.data<T>();
pivots.set_data(allocator::malloc_or_wait(pivots.nbytes())); pivots.set_data(allocator::malloc(pivots.nbytes()));
row_indices.set_data(allocator::malloc_or_wait(row_indices.nbytes())); row_indices.set_data(allocator::malloc(row_indices.nbytes()));
auto pivots_ptr = pivots.data<uint32_t>(); auto pivots_ptr = pivots.data<uint32_t>();
auto row_indices_ptr = row_indices.data<uint32_t>(); auto row_indices_ptr = row_indices.data<uint32_t>();
size_t num_matrices = a.size() / (M * N); size_t num_matrices = a.size() / (M * N);

View File

@@ -59,7 +59,7 @@ void BlockMaskedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
throw std::runtime_error( throw std::runtime_error(
"[BlockMaskedMM::eval] Currently only supports float32."); "[BlockMaskedMM::eval] Currently only supports float32.");
} }
out.set_data(allocator::malloc_or_wait(out.nbytes())); out.set_data(allocator::malloc(out.nbytes()));
auto& a_pre = inputs[0]; auto& a_pre = inputs[0];
auto& b_pre = inputs[1]; auto& b_pre = inputs[1];
@@ -318,7 +318,7 @@ void GatherMM::eval_cpu(const std::vector<array>& inputs, array& out) {
throw std::runtime_error( throw std::runtime_error(
"[GatherMM::eval] Currently only supports float32."); "[GatherMM::eval] Currently only supports float32.");
} }
out.set_data(allocator::malloc_or_wait(out.nbytes())); out.set_data(allocator::malloc(out.nbytes()));
auto& a_pre = inputs[0]; auto& a_pre = inputs[0];
auto& b_pre = inputs[1]; auto& b_pre = inputs[1];

View File

@@ -115,7 +115,7 @@ void matmul_general(
} }
void Matmul::eval_cpu(const std::vector<array>& inputs, array& out) { void Matmul::eval_cpu(const std::vector<array>& inputs, array& out) {
out.set_data(allocator::malloc_or_wait(out.nbytes())); out.set_data(allocator::malloc(out.nbytes()));
if (inputs[0].shape(-1) == 0) { if (inputs[0].shape(-1) == 0) {
auto& encoder = cpu::get_command_encoder(stream()); auto& encoder = cpu::get_command_encoder(stream());
encoder.set_output_array(out); encoder.set_output_array(out);

View File

@@ -21,7 +21,7 @@ namespace mlx::core {
void reshape(const array& in, array& out) { void reshape(const array& in, array& out) {
auto [copy_necessary, out_strides] = prepare_reshape(in, out); auto [copy_necessary, out_strides] = prepare_reshape(in, out);
if (copy_necessary) { if (copy_necessary) {
out.set_data(allocator::malloc_or_wait(out.nbytes())); out.set_data(allocator::malloc(out.nbytes()));
copy_inplace(in, out, CopyType::General, out.primitive().stream()); copy_inplace(in, out, CopyType::General, out.primitive().stream());
} else { } else {
shared_buffer_reshape(in, out_strides, out); shared_buffer_reshape(in, out_strides, out);
@@ -39,7 +39,7 @@ static std::pair<array, bool> compute_dynamic_offset(
if (donate) { if (donate) {
offset.copy_shared_buffer(indices); offset.copy_shared_buffer(indices);
} else { } else {
offset.set_data(allocator::malloc_or_wait(offset.itemsize())); offset.set_data(allocator::malloc(offset.itemsize()));
} }
auto& encoder = cpu::get_command_encoder(stream); auto& encoder = cpu::get_command_encoder(stream);
@@ -124,7 +124,7 @@ void Transpose::eval_cpu(const std::vector<array>& inputs, array& out) {
void Arange::eval_cpu(const std::vector<array>& inputs, array& out) { void Arange::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 0); assert(inputs.size() == 0);
out.set_data(allocator::malloc_or_wait(out.nbytes())); out.set_data(allocator::malloc(out.nbytes()));
switch (out.dtype()) { switch (out.dtype()) {
case bool_: case bool_:
throw std::runtime_error("Bool type unsupported for arange."); throw std::runtime_error("Bool type unsupported for arange.");
@@ -186,7 +186,7 @@ void Concatenate::eval_cpu(const std::vector<array>& inputs, array& out) {
} }
std::partial_sum(sizes.cbegin(), sizes.cend(), sizes.begin()); std::partial_sum(sizes.cbegin(), sizes.cend(), sizes.begin());
out.set_data(allocator::malloc_or_wait(out.nbytes())); out.set_data(allocator::malloc(out.nbytes()));
auto strides = out.strides(); auto strides = out.strides();
auto flags = out.flags(); auto flags = out.flags();
@@ -205,8 +205,10 @@ void Concatenate::eval_cpu(const std::vector<array>& inputs, array& out) {
void Contiguous::eval_cpu(const std::vector<array>& inputs, array& out) { void Contiguous::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
auto& in = inputs[0]; auto& in = inputs[0];
-  if (in.flags().row_contiguous ||
-      (allow_col_major_ && in.flags().col_contiguous)) {
+  constexpr size_t extra_bytes = 16384;
+  if (in.buffer_size() <= out.nbytes() + extra_bytes &&
+      (in.flags().row_contiguous ||
+       (allow_col_major_ && in.flags().col_contiguous))) {
out.copy_shared_buffer(in); out.copy_shared_buffer(in);
} else { } else {
copy(in, out, CopyType::General, stream()); copy(in, out, CopyType::General, stream());
@@ -276,7 +278,7 @@ void RandomBits::eval_cpu(const std::vector<array>& inputs, array& out) {
size_t elems_per_key = out.size() / num_keys; size_t elems_per_key = out.size() / num_keys;
size_t bytes_per_key = out.itemsize() * elems_per_key; size_t bytes_per_key = out.itemsize() * elems_per_key;
out.set_data(allocator::malloc_or_wait(out.nbytes())); out.set_data(allocator::malloc(out.nbytes()));
auto kptr = inputs[0].data<uint32_t>(); auto kptr = inputs[0].data<uint32_t>();
auto cptr = out.data<char>(); auto cptr = out.data<char>();
@@ -335,7 +337,7 @@ void DynamicSlice::eval_cpu(const std::vector<array>& inputs, array& out) {
return; return;
} }
auto& in = inputs[0]; auto& in = inputs[0];
out.set_data(allocator::malloc_or_wait(out.nbytes())); out.set_data(allocator::malloc(out.nbytes()));
auto [in_offset, donated] = auto [in_offset, donated] =
compute_dynamic_offset(inputs[1], in.strides(), axes_, stream()); compute_dynamic_offset(inputs[1], in.strides(), axes_, stream());
copy_inplace( copy_inplace(
@@ -450,7 +452,7 @@ void View::eval_cpu(const std::vector<array>& inputs, array& out) {
} else { } else {
auto tmp = array( auto tmp = array(
in.shape(), in.dtype() == bool_ ? uint8 : in.dtype(), nullptr, {}); in.shape(), in.dtype() == bool_ ? uint8 : in.dtype(), nullptr, {});
tmp.set_data(allocator::malloc_or_wait(tmp.nbytes())); tmp.set_data(allocator::malloc(tmp.nbytes()));
if (in.dtype() == bool_) { if (in.dtype() == bool_) {
auto in_tmp = array(in.shape(), uint8, nullptr, {}); auto in_tmp = array(in.shape(), uint8, nullptr, {});
in_tmp.copy_shared_buffer(in); in_tmp.copy_shared_buffer(in);

View File

@@ -25,12 +25,11 @@ void qrf_impl(const array& a, array& q, array& r, Stream stream) {
auto strides = in.strides(); auto strides = in.strides();
strides[in.ndim() - 2] = 1; strides[in.ndim() - 2] = 1;
strides[in.ndim() - 1] = M; strides[in.ndim() - 1] = M;
-  in.set_data(
-      allocator::malloc_or_wait(in.nbytes()), in.nbytes(), strides, flags);
+  in.set_data(allocator::malloc(in.nbytes()), in.nbytes(), strides, flags);
copy_inplace(a, in, CopyType::GeneralGeneral, stream); copy_inplace(a, in, CopyType::GeneralGeneral, stream);
auto& encoder = cpu::get_command_encoder(stream); auto& encoder = cpu::get_command_encoder(stream);
q.set_data(allocator::malloc_or_wait(q.nbytes())); q.set_data(allocator::malloc(q.nbytes()));
r.set_data(allocator::malloc_or_wait(r.nbytes())); r.set_data(allocator::malloc(r.nbytes()));
auto in_ptr = in.data<T>(); auto in_ptr = in.data<T>();
auto r_ptr = r.data<T>(); auto r_ptr = r.data<T>();
@@ -41,8 +40,7 @@ void qrf_impl(const array& a, array& q, array& r, Stream stream) {
encoder.set_output_array(r); encoder.set_output_array(r);
encoder.dispatch([in_ptr, q_ptr, r_ptr, M, N, lda, num_matrices]() { encoder.dispatch([in_ptr, q_ptr, r_ptr, M, N, lda, num_matrices]() {
int num_reflectors = std::min(M, N); int num_reflectors = std::min(M, N);
-    auto tau =
-        allocator::malloc_or_wait(sizeof(T) * num_matrices * num_reflectors);
+    auto tau = allocator::malloc(sizeof(T) * num_matrices * num_reflectors);
T optimal_work; T optimal_work;
int lwork = -1; int lwork = -1;
@@ -53,7 +51,7 @@ void qrf_impl(const array& a, array& q, array& r, Stream stream) {
// Update workspace size // Update workspace size
lwork = optimal_work; lwork = optimal_work;
auto work = allocator::malloc_or_wait(sizeof(T) * lwork); auto work = allocator::malloc(sizeof(T) * lwork);
// Loop over matrices // Loop over matrices
for (int i = 0; i < num_matrices; ++i) { for (int i = 0; i < num_matrices; ++i) {
@@ -96,7 +94,7 @@ void qrf_impl(const array& a, array& q, array& r, Stream stream) {
&lwork, &lwork,
&info); &info);
lwork = optimal_work; lwork = optimal_work;
work = allocator::malloc_or_wait(sizeof(T) * lwork); work = allocator::malloc(sizeof(T) * lwork);
// Loop over matrices // Loop over matrices
for (int i = 0; i < num_matrices; ++i) { for (int i = 0; i < num_matrices; ++i) {

View File

@@ -515,7 +515,7 @@ void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
auto scales = ensure_row_contiguous(scales_pre); auto scales = ensure_row_contiguous(scales_pre);
auto biases = ensure_row_contiguous(biases_pre); auto biases = ensure_row_contiguous(biases_pre);
out.set_data(allocator::malloc_or_wait(out.nbytes())); out.set_data(allocator::malloc(out.nbytes()));
auto& encoder = cpu::get_command_encoder(stream()); auto& encoder = cpu::get_command_encoder(stream());
encoder.add_temporaries(std::move(temps)); encoder.add_temporaries(std::move(temps));
@@ -565,7 +565,7 @@ void GatherQMM::eval_cpu(const std::vector<array>& inputs, array& out) {
auto scales = ensure_row_contiguous_last_dims(scales_pre); auto scales = ensure_row_contiguous_last_dims(scales_pre);
auto biases = ensure_row_contiguous_last_dims(biases_pre); auto biases = ensure_row_contiguous_last_dims(biases_pre);
out.set_data(allocator::malloc_or_wait(out.nbytes())); out.set_data(allocator::malloc(out.nbytes()));
auto& encoder = cpu::get_command_encoder(stream()); auto& encoder = cpu::get_command_encoder(stream());
encoder.add_temporaries(std::move(temps)); encoder.add_temporaries(std::move(temps));
@@ -691,12 +691,12 @@ void fast::AffineQuantize::eval_cpu(
auto [w, copied] = ensure_row_contiguous(inputs[0]); auto [w, copied] = ensure_row_contiguous(inputs[0]);
auto& out = outputs[0]; auto& out = outputs[0];
out.set_data(allocator::malloc_or_wait(out.nbytes())); out.set_data(allocator::malloc(out.nbytes()));
auto& scales = outputs[1]; auto& scales = outputs[1];
auto& biases = outputs[2]; auto& biases = outputs[2];
scales.set_data(allocator::malloc_or_wait(scales.nbytes())); scales.set_data(allocator::malloc(scales.nbytes()));
biases.set_data(allocator::malloc_or_wait(biases.nbytes())); biases.set_data(allocator::malloc(biases.nbytes()));
auto& encoder = cpu::get_command_encoder(stream()); auto& encoder = cpu::get_command_encoder(stream());
if (copied) { if (copied) {
encoder.add_temporary(w); encoder.add_temporary(w);

View File

@@ -433,7 +433,7 @@ void reduce_dispatch_min_max(
void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) { void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
auto& in = inputs[0]; auto& in = inputs[0];
out.set_data(allocator::malloc_or_wait(out.nbytes())); out.set_data(allocator::malloc(out.nbytes()));
auto& encoder = cpu::get_command_encoder(stream()); auto& encoder = cpu::get_command_encoder(stream());
encoder.set_input_array(in); encoder.set_input_array(in);
encoder.set_output_array(out); encoder.set_output_array(out);

View File

@@ -3,6 +3,7 @@
#include <cassert> #include <cassert>
#include "mlx/backend/common/utils.h" #include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/binary_ops.h"
#include "mlx/backend/cpu/copy.h" #include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h" #include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/simd/simd.h" #include "mlx/backend/cpu/simd/simd.h"
@@ -226,6 +227,16 @@ void scan_dispatch(
scan_op<T, U>(in, out, axis, reverse, inclusive, op, init); scan_op<T, U>(in, out, axis, reverse, inclusive, op, init);
break; break;
} }
case Scan::LogAddExp: {
auto op = [](U a, T b) {
return detail::LogAddExp{}(a, static_cast<U>(b));
};
auto init = (issubdtype(in.dtype(), floating))
? static_cast<U>(-std::numeric_limits<float>::infinity())
: std::numeric_limits<U>::min();
scan_op<T, U>(in, out, axis, reverse, inclusive, op, init);
break;
}
} }
} }
@@ -244,7 +255,7 @@ void Scan::eval_cpu(const std::vector<array>& inputs, array& out) {
in = arr_copy; in = arr_copy;
encoder.add_temporary(arr_copy); encoder.add_temporary(arr_copy);
} }
out.set_data(allocator::malloc_or_wait(out.nbytes())); out.set_data(allocator::malloc(out.nbytes()));
encoder.set_input_array(in); encoder.set_input_array(in);
encoder.set_output_array(out); encoder.set_output_array(out);
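The new Scan::LogAddExp case accumulates with a log-add-exp operator, which gives a cumulative logsumexp (logcumsumexp) along the scan axis. A minimal scalar sketch of that recurrence, assuming the usual max + log1p(exp(...)) formulation rather than the actual detail::LogAddExp functor routed through scan_op:

#include <algorithm>
#include <cmath>
#include <limits>
#include <vector>

// Running log-add-exp: out[i] = log(exp(out[i-1]) + exp(x[i])), computed
// without overflowing either exponential.
std::vector<float> logcumsumexp(const std::vector<float>& x) {
  std::vector<float> out(x.size());
  float acc = -std::numeric_limits<float>::infinity();
  for (size_t i = 0; i < x.size(); ++i) {
    float hi = std::max(acc, x[i]);
    float lo = std::min(acc, x[i]);
    acc = std::isinf(hi) ? hi : hi + std::log1p(std::exp(lo - hi));
    out[i] = acc;
  }
  return out;
}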

View File

@@ -17,7 +17,7 @@ struct ScalarT<float16_t, N> {
#endif #endif
template <> template <>
static constexpr int max_size<float16_t> = N; inline constexpr int max_size<float16_t> = N;
#define SIMD_FP16_DEFAULT_UNARY(op) \ #define SIMD_FP16_DEFAULT_UNARY(op) \
template <> \ template <> \

View File

@@ -83,25 +83,25 @@ struct Simd {
// Values chosen based on benchmarks on M3 Max // Values chosen based on benchmarks on M3 Max
// TODO: consider choosing these more optimally // TODO: consider choosing these more optimally
template <> template <>
static constexpr int max_size<int8_t> = 16; inline constexpr int max_size<int8_t> = 16;
template <> template <>
static constexpr int max_size<int16_t> = 16; inline constexpr int max_size<int16_t> = 16;
template <> template <>
static constexpr int max_size<int> = 8; inline constexpr int max_size<int> = 8;
template <> template <>
static constexpr int max_size<int64_t> = 4; inline constexpr int max_size<int64_t> = 4;
template <> template <>
static constexpr int max_size<uint8_t> = 16; inline constexpr int max_size<uint8_t> = 16;
template <> template <>
static constexpr int max_size<uint16_t> = 16; inline constexpr int max_size<uint16_t> = 16;
template <> template <>
static constexpr int max_size<uint32_t> = 8; inline constexpr int max_size<uint32_t> = 8;
template <> template <>
static constexpr int max_size<uint64_t> = 4; inline constexpr int max_size<uint64_t> = 4;
template <> template <>
static constexpr int max_size<float> = 8; inline constexpr int max_size<float> = 8;
template <> template <>
static constexpr int max_size<double> = 4; inline constexpr int max_size<double> = 4;
#define SIMD_DEFAULT_UNARY(name, op) \ #define SIMD_DEFAULT_UNARY(name, op) \
template <typename T, int N> \ template <typename T, int N> \

View File

@@ -87,7 +87,6 @@ DEFAULT_UNARY(cosh, std::cosh)
DEFAULT_UNARY(expm1, std::expm1) DEFAULT_UNARY(expm1, std::expm1)
DEFAULT_UNARY(floor, std::floor) DEFAULT_UNARY(floor, std::floor)
DEFAULT_UNARY(log, std::log) DEFAULT_UNARY(log, std::log)
DEFAULT_UNARY(log2, std::log2)
DEFAULT_UNARY(log10, std::log10) DEFAULT_UNARY(log10, std::log10)
DEFAULT_UNARY(log1p, std::log1p) DEFAULT_UNARY(log1p, std::log1p)
DEFAULT_UNARY(sinh, std::sinh) DEFAULT_UNARY(sinh, std::sinh)
@@ -95,6 +94,17 @@ DEFAULT_UNARY(sqrt, std::sqrt)
DEFAULT_UNARY(tan, std::tan) DEFAULT_UNARY(tan, std::tan)
DEFAULT_UNARY(tanh, std::tanh) DEFAULT_UNARY(tanh, std::tanh)
template <typename T>
Simd<T, 1> log2(Simd<T, 1> in) {
if constexpr (is_complex<T>) {
auto out = std::log(in.value);
auto scale = decltype(out.real())(M_LN2);
return Simd<T, 1>{T{out.real() / scale, out.imag() / scale}};
} else {
return Simd<T, 1>{std::log2(in.value)};
}
}
template <typename T> template <typename T>
Simd<T, 1> operator~(Simd<T, 1> in) { Simd<T, 1> operator~(Simd<T, 1> in) {
return ~in.value; return ~in.value;
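The new scalar specialization handles complex inputs by taking the complex natural log and rescaling by 1/ln 2, which is what the division by M_LN2 above does for the real and imaginary parts. A small standalone check with std::complex (illustrative only, not the Simd<complex64_t, 1> wrapper):

#include <cmath>
#include <complex>
#include <cstdio>

int main() {
  std::complex<float> z(0.0f, 4.0f);
  // log2(z) = log(z) / ln(2), applied to the principal value.
  std::complex<float> out = std::log(z) / std::log(2.0f);
  std::printf("log2(0 + 4i) = %f + %fi\n", out.real(), out.imag());
  return 0;
}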

View File

@@ -119,17 +119,12 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
   // Make sure that the last dimension is contiguous
   auto set_output = [s = stream(), &out](const array& x) {
-    bool no_copy = x.strides()[x.ndim() - 1] == 1;
-    if (x.ndim() > 1) {
-      auto s = x.strides()[x.ndim() - 2];
-      no_copy &= (s == 0 || s == x.shape().back());
-    }
-    if (no_copy) {
+    if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
       if (x.is_donatable()) {
         out.copy_shared_buffer(x);
       } else {
         out.set_data(
-            allocator::malloc_or_wait(x.data_size() * x.itemsize()),
+            allocator::malloc(x.data_size() * x.itemsize()),
             x.data_size(),
             x.strides(),
             x.flags());
@@ -146,18 +141,6 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
   auto in = set_output(inputs[0]);
   switch (in.dtype()) {
-    case bool_:
-    case uint8:
-    case uint16:
-    case uint32:
-    case uint64:
-    case int8:
-    case int16:
-    case int32:
-    case int64:
-      throw std::runtime_error(
-          "Softmax is defined only for floating point types");
-      break;
     case float32:
       softmax<float, float>(in, out, stream());
       break;
@@ -178,9 +161,9 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
     case float64:
       softmax<double, double>(in, out, stream());
       break;
-    case complex64:
-      throw std::invalid_argument(
-          "[Softmax] Not yet implemented for complex64");
+    default:
+      throw std::runtime_error(
+          "[softmax] Only defined for floating point types.");
       break;
   }
 }

View File

@@ -288,7 +288,7 @@ void ArgSort::eval_cpu(const std::vector<array>& inputs, array& out) {
auto& in = inputs[0]; auto& in = inputs[0];
// Allocate output // Allocate output
out.set_data(allocator::malloc_or_wait(out.nbytes())); out.set_data(allocator::malloc(out.nbytes()));
auto& encoder = cpu::get_command_encoder(stream()); auto& encoder = cpu::get_command_encoder(stream());
encoder.set_input_array(in); encoder.set_input_array(in);
@@ -379,7 +379,7 @@ void ArgPartition::eval_cpu(const std::vector<array>& inputs, array& out) {
auto& in = inputs[0]; auto& in = inputs[0];
// Allocate output // Allocate output
out.set_data(allocator::malloc_or_wait(out.nbytes())); out.set_data(allocator::malloc(out.nbytes()));
auto& encoder = cpu::get_command_encoder(stream()); auto& encoder = cpu::get_command_encoder(stream());
encoder.set_input_array(in); encoder.set_input_array(in);

View File

@@ -50,9 +50,9 @@ void svd_impl(
array& s = outputs[1]; array& s = outputs[1];
array& vt = outputs[2]; array& vt = outputs[2];
u.set_data(allocator::malloc_or_wait(u.nbytes())); u.set_data(allocator::malloc(u.nbytes()));
s.set_data(allocator::malloc_or_wait(s.nbytes())); s.set_data(allocator::malloc(s.nbytes()));
vt.set_data(allocator::malloc_or_wait(vt.nbytes())); vt.set_data(allocator::malloc(vt.nbytes()));
encoder.set_output_array(u); encoder.set_output_array(u);
encoder.set_output_array(s); encoder.set_output_array(s);
@@ -64,7 +64,7 @@ void svd_impl(
} else { } else {
array& s = outputs[0]; array& s = outputs[0];
s.set_data(allocator::malloc_or_wait(s.nbytes())); s.set_data(allocator::malloc(s.nbytes()));
encoder.set_output_array(s); encoder.set_output_array(s);
@@ -91,7 +91,7 @@ void svd_impl(
// Will contain the indices of eigenvectors that failed to converge (not // Will contain the indices of eigenvectors that failed to converge (not
// used here but required by lapack). // used here but required by lapack).
auto iwork = array::Data{allocator::malloc_or_wait(sizeof(int) * 12 * K)}; auto iwork = array::Data{allocator::malloc(sizeof(int) * 12 * K)};
static const int lwork_query = -1; static const int lwork_query = -1;
@@ -132,7 +132,7 @@ void svd_impl(
} }
const int lwork = workspace_dimension; const int lwork = workspace_dimension;
auto scratch = array::Data{allocator::malloc_or_wait(sizeof(T) * lwork)}; auto scratch = array::Data{allocator::malloc(sizeof(T) * lwork)};
// Loop over matrices. // Loop over matrices.
for (int i = 0; i < num_matrices; i++) { for (int i = 0; i < num_matrices; i++) {

View File

@@ -1,5 +1,8 @@
// Copyright © 2024 Apple Inc. // Copyright © 2024 Apple Inc.
// Required for using M_LN2 in MSVC.
#define _USE_MATH_DEFINES
#include <cassert> #include <cassert>
#include "mlx/backend/cpu/unary.h" #include "mlx/backend/cpu/unary.h"

View File

@@ -18,13 +18,13 @@ void set_unary_output_data(const array& in, array& out) {
} else { } else {
auto size = in.data_size(); auto size = in.data_size();
out.set_data( out.set_data(
allocator::malloc_or_wait(size * out.itemsize()), allocator::malloc(size * out.itemsize()),
size, size,
in.strides(), in.strides(),
in.flags()); in.flags());
} }
} else { } else {
out.set_data(allocator::malloc_or_wait(out.nbytes())); out.set_data(allocator::malloc(out.nbytes()));
} }
} }

View File

@@ -86,13 +86,14 @@ struct Sign {
template <int N, typename T> template <int N, typename T>
Simd<T, N> operator()(Simd<T, N> x) { Simd<T, N> operator()(Simd<T, N> x) {
auto z = Simd<T, N>{0}; auto z = Simd<T, N>{0};
auto o = Simd<T, N>{1};
auto m = Simd<T, N>{-1};
if constexpr (std::is_unsigned_v<T>) { if constexpr (std::is_unsigned_v<T>) {
return x != z; return simd::select(x == z, z, o);
} else if constexpr (std::is_same_v<T, complex64_t>) { } else if constexpr (std::is_same_v<T, complex64_t>) {
return simd::select(x == z, x, Simd<T, N>(x / simd::abs(x))); return simd::select(x == z, x, Simd<T, N>(x / simd::abs(x)));
} else { } else {
-      return simd::select(
-          x < z, Simd<T, N>{-1}, simd::select(x > z, Simd<T, N>{1}, z));
+      return simd::select(x < z, m, simd::select(x > z, o, z));
     }
   }
SINGLE() SINGLE()
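With this change, the sign of an unsigned value is produced with a select instead of returning the comparison result directly: the output is 1 for any nonzero input and 0 for zero, since negative values cannot occur. A scalar illustration of the intended semantics (hypothetical helper, not part of the diff):

#include <cstdint>
#include <cstdio>

// Sign for unsigned integers: 0 stays 0, everything else maps to 1.
uint32_t sign_u32(uint32_t x) {
  return x == 0 ? 0u : 1u;
}

int main() {
  std::printf("%u %u %u\n", sign_u32(0), sign_u32(7), sign_u32(4294967295u));
  // Expected output: 0 1 1
  return 0;
}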

View File

@@ -47,6 +47,7 @@ if(MLX_METAL_JIT)
make_jit_source(binary) make_jit_source(binary)
make_jit_source(binary_two) make_jit_source(binary_two)
make_jit_source(fft kernels/fft/radix.h kernels/fft/readwrite.h) make_jit_source(fft kernels/fft/radix.h kernels/fft/readwrite.h)
make_jit_source(logsumexp)
make_jit_source(ternary) make_jit_source(ternary)
make_jit_source(softmax) make_jit_source(softmax)
make_jit_source(scan) make_jit_source(scan)
@@ -60,6 +61,7 @@ if(MLX_METAL_JIT)
kernels/steel/gemm/transforms.h) kernels/steel/gemm/transforms.h)
make_jit_source(steel/gemm/kernels/steel_gemm_fused) make_jit_source(steel/gemm/kernels/steel_gemm_fused)
make_jit_source(steel/gemm/kernels/steel_gemm_masked kernels/steel/defines.h) make_jit_source(steel/gemm/kernels/steel_gemm_masked kernels/steel/defines.h)
make_jit_source(steel/gemm/kernels/steel_gemm_gather)
make_jit_source(steel/gemm/kernels/steel_gemm_splitk) make_jit_source(steel/gemm/kernels/steel_gemm_splitk)
make_jit_source( make_jit_source(
steel/conv/conv steel/conv/conv
@@ -95,6 +97,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp ${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/logsumexp.cpp
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cpp ${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cpp
${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp ${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp

View File

@@ -3,6 +3,7 @@
#include "mlx/backend/metal/metal.h" #include "mlx/backend/metal/metal.h"
#include "mlx/backend/metal/metal_impl.h" #include "mlx/backend/metal/metal_impl.h"
#include "mlx/backend/metal/resident.h" #include "mlx/backend/metal/resident.h"
#include "mlx/memory.h"
#include <mach/vm_page_size.h> #include <mach/vm_page_size.h>
#include <unistd.h> #include <unistd.h>
@@ -32,8 +33,11 @@ namespace metal {
namespace { namespace {
-BufferCache::BufferCache(MTL::Device* device)
-    : device_(device), head_(nullptr), tail_(nullptr), pool_size_(0) {}
+BufferCache::BufferCache(ResidencySet& residency_set)
+    : head_(nullptr),
+      tail_(nullptr),
+      pool_size_(0),
+      residency_set_(residency_set) {}
BufferCache::~BufferCache() { BufferCache::~BufferCache() {
auto pool = metal::new_scoped_memory_pool(); auto pool = metal::new_scoped_memory_pool();
@@ -44,6 +48,9 @@ int BufferCache::clear() {
int n_release = 0; int n_release = 0;
for (auto& [size, holder] : buffer_pool_) { for (auto& [size, holder] : buffer_pool_) {
if (holder->buf) { if (holder->buf) {
if (!holder->buf->heap()) {
residency_set_.erase(holder->buf);
}
holder->buf->release(); holder->buf->release();
n_release++; n_release++;
} }
@@ -101,6 +108,9 @@ int BufferCache::release_cached_buffers(size_t min_bytes_to_free) {
while (tail_ && (total_bytes_freed < min_bytes_to_free)) { while (tail_ && (total_bytes_freed < min_bytes_to_free)) {
if (tail_->buf) { if (tail_->buf) {
total_bytes_freed += tail_->buf->length(); total_bytes_freed += tail_->buf->length();
if (!tail_->buf->heap()) {
residency_set_.erase(tail_->buf);
}
tail_->buf->release(); tail_->buf->release();
tail_->buf = nullptr; tail_->buf = nullptr;
n_release++; n_release++;
@@ -155,7 +165,7 @@ void BufferCache::remove_from_list(BufferCache::BufferHolder* to_remove) {
MetalAllocator::MetalAllocator() MetalAllocator::MetalAllocator()
: device_(device(mlx::core::Device::gpu).mtl_device()), : device_(device(mlx::core::Device::gpu).mtl_device()),
residency_set_(device_), residency_set_(device_),
buffer_cache_(device_) { buffer_cache_(residency_set_) {
auto pool = metal::new_scoped_memory_pool(); auto pool = metal::new_scoped_memory_pool();
auto memsize = std::get<size_t>(device_info().at("memory_size")); auto memsize = std::get<size_t>(device_info().at("memory_size"));
auto max_rec_size = auto max_rec_size =
@@ -192,16 +202,19 @@ size_t MetalAllocator::set_cache_limit(size_t limit) {
return limit; return limit;
}; };
-size_t MetalAllocator::set_memory_limit(size_t limit, bool relaxed) {
+size_t MetalAllocator::set_memory_limit(size_t limit) {
   std::unique_lock lk(mutex_);
   std::swap(limit, block_limit_);
-  relaxed_ = relaxed;
   gc_limit_ = std::min(
       block_limit_,
       static_cast<size_t>(0.95 * device_->recommendedMaxWorkingSetSize()));
   return limit;
 };
+size_t MetalAllocator::get_memory_limit() {
+  return block_limit_;
+}
size_t MetalAllocator::set_wired_limit(size_t limit) { size_t MetalAllocator::set_wired_limit(size_t limit) {
std::unique_lock lk(mutex_); std::unique_lock lk(mutex_);
std::swap(limit, wired_limit_); std::swap(limit, wired_limit_);
@@ -209,7 +222,7 @@ size_t MetalAllocator::set_wired_limit(size_t limit) {
return limit; return limit;
}; };
-Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
+Buffer MetalAllocator::malloc(size_t size) {
   // Metal doesn't like empty buffers
   if (size == 0) {
     return Buffer{nullptr};
@@ -236,11 +249,6 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
   if (!buf) {
     size_t mem_required = get_active_memory() + get_cache_memory() + size;
-    // If there is too much memory pressure, fail (likely causes a wait).
-    if (!(allow_swap && relaxed_) && mem_required >= block_limit_) {
-      return Buffer{nullptr};
-    }
     auto pool = metal::new_scoped_memory_pool();
     // If we have a lot of memory pressure or are over the maximum cache size,
@@ -264,9 +272,13 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
     if (!buf) {
       buf = device_->newBuffer(size, resource_options);
     }
+    if (!buf) {
+      return Buffer{nullptr};
+    }
     lk.lock();
-    if (buf) {
-      num_resources_++;
-    }
+    num_resources_++;
+    if (!buf->heap()) {
+      residency_set_.insert(buf);
+    }
   }
@@ -280,10 +292,6 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
         get_cache_memory() - max_pool_size_);
   }
-  if (!buf->heap()) {
-    residency_set_.insert(buf);
-  }
   return Buffer{static_cast<void*>(buf)};
 }
@@ -299,14 +307,14 @@ void MetalAllocator::free(Buffer buffer) {
     return;
   }
   std::unique_lock lk(mutex_);
-  if (!buf->heap()) {
-    residency_set_.erase(buf);
-  }
   active_memory_ -= buf->length();
   if (get_cache_memory() < max_pool_size_) {
     buffer_cache_.recycle_to_cache(buf);
   } else {
     num_resources_--;
+    if (!buf->heap()) {
+      residency_set_.erase(buf);
+    }
     lk.unlock();
     auto pool = metal::new_scoped_memory_pool();
     buf->release();
@@ -325,37 +333,40 @@ MetalAllocator& allocator() {
return *allocator_; return *allocator_;
} }
+} // namespace metal
 size_t set_cache_limit(size_t limit) {
-  return allocator().set_cache_limit(limit);
+  return metal::allocator().set_cache_limit(limit);
 }
-size_t set_memory_limit(size_t limit, bool relaxed /* = true */) {
-  return allocator().set_memory_limit(limit, relaxed);
+size_t set_memory_limit(size_t limit) {
+  return metal::allocator().set_memory_limit(limit);
 }
+size_t get_memory_limit() {
+  return metal::allocator().get_memory_limit();
+}
 size_t set_wired_limit(size_t limit) {
-  if (limit >
-      std::get<size_t>(device_info().at("max_recommended_working_set_size"))) {
+  if (limit > std::get<size_t>(metal::device_info().at(
+                  "max_recommended_working_set_size"))) {
     throw std::invalid_argument(
         "[metal::set_wired_limit] Setting a wired limit larger than "
         "the maximum working set size is not allowed.");
   }
-  return allocator().set_wired_limit(limit);
+  return metal::allocator().set_wired_limit(limit);
 }
size_t get_active_memory() { size_t get_active_memory() {
return allocator().get_active_memory(); return metal::allocator().get_active_memory();
} }
size_t get_peak_memory() { size_t get_peak_memory() {
return allocator().get_peak_memory(); return metal::allocator().get_peak_memory();
} }
void reset_peak_memory() { void reset_peak_memory() {
allocator().reset_peak_memory(); metal::allocator().reset_peak_memory();
} }
size_t get_cache_memory() { size_t get_cache_memory() {
return allocator().get_cache_memory(); return metal::allocator().get_cache_memory();
} }
void clear_cache() { void clear_cache() {
return allocator().clear_cache(); return metal::allocator().clear_cache();
} }
} // namespace metal
} // namespace mlx::core } // namespace mlx::core
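With the residency and limit plumbing above, the cache and memory limit entry points now sit at the top level of mlx::core and forward to the Metal allocator; the new include of mlx/memory.h suggests that is where they are declared. A hedged C++ usage sketch, with arbitrary example values:

#include "mlx/memory.h"

void configure_memory_limits() {
  namespace mx = mlx::core;
  // Returns the previous limit, mirroring the std::swap in set_memory_limit.
  size_t old_limit = mx::set_memory_limit(8ull << 30); // 8 GiB block limit
  size_t limit = mx::get_memory_limit();
  mx::set_cache_limit(1ull << 30); // cap the recycled-buffer cache at 1 GiB
  (void)old_limit;
  (void)limit;
}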

View File

@@ -18,7 +18,7 @@ namespace {
class BufferCache { class BufferCache {
public: public:
BufferCache(MTL::Device* device); BufferCache(ResidencySet& residency_set);
~BufferCache(); ~BufferCache();
MTL::Buffer* reuse_from_cache(size_t size); MTL::Buffer* reuse_from_cache(size_t size);
@@ -42,13 +42,11 @@ class BufferCache {
void add_at_head(BufferHolder* to_add); void add_at_head(BufferHolder* to_add);
void remove_from_list(BufferHolder* to_remove); void remove_from_list(BufferHolder* to_remove);
-  MTL::Device* device_;
-  MTL::Heap* heap_{nullptr};
   std::multimap<size_t, BufferHolder*> buffer_pool_;
   BufferHolder* head_;
   BufferHolder* tail_;
   size_t pool_size_;
+  ResidencySet& residency_set_;
}; };
} // namespace } // namespace
@@ -56,7 +54,7 @@ class BufferCache {
class MetalAllocator : public allocator::Allocator { class MetalAllocator : public allocator::Allocator {
/** Allocator for Metal GPUs. */ /** Allocator for Metal GPUs. */
public: public:
virtual Buffer malloc(size_t size, bool allow_swap = false) override; virtual Buffer malloc(size_t size) override;
virtual void free(Buffer buffer) override; virtual void free(Buffer buffer) override;
virtual size_t size(Buffer buffer) const override; virtual size_t size(Buffer buffer) const override;
size_t get_active_memory() { size_t get_active_memory() {
@@ -73,7 +71,8 @@ class MetalAllocator : public allocator::Allocator {
return buffer_cache_.cache_size(); return buffer_cache_.cache_size();
}; };
   size_t set_cache_limit(size_t limit);
-  size_t set_memory_limit(size_t limit, bool relaxed);
+  size_t set_memory_limit(size_t limit);
+  size_t get_memory_limit();
size_t set_wired_limit(size_t limit); size_t set_wired_limit(size_t limit);
void clear_cache(); void clear_cache();

View File

@@ -37,7 +37,7 @@ void explicit_gemm_conv_ND_gpu(
Shape unfolded_shape{implicit_M, implicit_K}; Shape unfolded_shape{implicit_M, implicit_K};
array in_unfolded(unfolded_shape, in.dtype(), nullptr, {}); array in_unfolded(unfolded_shape, in.dtype(), nullptr, {});
in_unfolded.set_data(allocator::malloc_or_wait(in_unfolded.nbytes())); in_unfolded.set_data(allocator::malloc(in_unfolded.nbytes()));
// Prepare unfolding kernel // Prepare unfolding kernel
std::ostringstream kname; std::ostringstream kname;
@@ -115,7 +115,7 @@ void explicit_gemm_conv_group_ND_gpu(
// Prepare unfolding array // Prepare unfolding array
Shape unfolded_shape{implicit_M, implicit_K * groups}; Shape unfolded_shape{implicit_M, implicit_K * groups};
array in_unfolded(unfolded_shape, in.dtype(), nullptr, {}); array in_unfolded(unfolded_shape, in.dtype(), nullptr, {});
in_unfolded.set_data(allocator::malloc_or_wait(in_unfolded.nbytes())); in_unfolded.set_data(allocator::malloc(in_unfolded.nbytes()));
// Prepare unfolding kernel // Prepare unfolding kernel
std::ostringstream kname; std::ostringstream kname;
@@ -613,7 +613,7 @@ void winograd_conv_2D_gpu(
// Do filter transform // Do filter transform
Shape filt_wg_shape = {8 * 8, conv_params.C, conv_params.O}; Shape filt_wg_shape = {8 * 8, conv_params.C, conv_params.O};
array filt_wg(std::move(filt_wg_shape), wt.dtype(), nullptr, {}); array filt_wg(std::move(filt_wg_shape), wt.dtype(), nullptr, {});
filt_wg.set_data(allocator::malloc_or_wait(filt_wg.nbytes())); filt_wg.set_data(allocator::malloc(filt_wg.nbytes()));
copies_w.push_back(filt_wg); copies_w.push_back(filt_wg);
{ {
int bc = 32; int bc = 32;
@@ -640,7 +640,7 @@ void winograd_conv_2D_gpu(
// Do input transform // Do input transform
Shape inp_wg_shape = {8 * 8, N_tiles, conv_params.C}; Shape inp_wg_shape = {8 * 8, N_tiles, conv_params.C};
array inp_wg(std::move(inp_wg_shape), in.dtype(), nullptr, {}); array inp_wg(std::move(inp_wg_shape), in.dtype(), nullptr, {});
inp_wg.set_data(allocator::malloc_or_wait(inp_wg.nbytes())); inp_wg.set_data(allocator::malloc(inp_wg.nbytes()));
copies_w.push_back(inp_wg); copies_w.push_back(inp_wg);
{ {
int bc = 32; int bc = 32;
@@ -667,7 +667,7 @@ void winograd_conv_2D_gpu(
// Do batched gemm // Do batched gemm
Shape out_wg_shape = {8 * 8, N_tiles, conv_params.O}; Shape out_wg_shape = {8 * 8, N_tiles, conv_params.O};
array out_wg(std::move(out_wg_shape), in.dtype(), nullptr, {}); array out_wg(std::move(out_wg_shape), in.dtype(), nullptr, {});
out_wg.set_data(allocator::malloc_or_wait(out_wg.nbytes())); out_wg.set_data(allocator::malloc(out_wg.nbytes()));
copies_w.push_back(out_wg); copies_w.push_back(out_wg);
{ {
std::vector<array> empty_copies; std::vector<array> empty_copies;
@@ -712,6 +712,65 @@ void winograd_conv_2D_gpu(
} }
} }
void depthwise_conv_2D_gpu(
const Stream& s,
metal::Device& d,
const array& in,
const array& wt,
array out,
const MLXConvParams<2>& conv_params) {
std::ostringstream kname;
kname << "depthwise_conv_2d_" << type_to_name(out);
std::string base_name = kname.str();
const int N = conv_params.N;
const int ker_h = conv_params.wS[0];
const int ker_w = conv_params.wS[1];
const int str_h = conv_params.str[0];
const int str_w = conv_params.str[1];
const int tc = 8;
const int tw = 8;
const int th = 4;
const bool do_flip = conv_params.flip;
metal::MTLFCList func_consts = {
{&ker_h, MTL::DataType::DataTypeInt, 00},
{&ker_w, MTL::DataType::DataTypeInt, 01},
{&str_h, MTL::DataType::DataTypeInt, 10},
{&str_w, MTL::DataType::DataTypeInt, 11},
{&th, MTL::DataType::DataTypeInt, 100},
{&tw, MTL::DataType::DataTypeInt, 101},
{&do_flip, MTL::DataType::DataTypeBool, 200},
};
// clang-format off
kname << "_ker_h_" << ker_h
<< "_ker_w_" << ker_w
<< "_str_h_" << str_h
<< "_str_w_" << str_w
<< "_tgp_h_" << th
<< "_tgp_w_" << tw
<< "_do_flip_" << (do_flip ? 't' : 'n'); // clang-format on
std::string hash_name = kname.str();
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(base_name, "mlx", hash_name, func_consts);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(in, 0);
compute_encoder.set_input_array(wt, 1);
compute_encoder.set_output_array(out, 2);
compute_encoder.set_bytes(conv_params, 3);
MTL::Size group_dims = MTL::Size(tc, tw, th);
MTL::Size grid_dims = MTL::Size(
conv_params.C / tc, conv_params.oS[1] / tw, (conv_params.oS[0] / th) * N);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
void conv_2D_gpu( void conv_2D_gpu(
const Stream& s, const Stream& s,
metal::Device& d, metal::Device& d,
@@ -754,11 +813,20 @@ void conv_2D_gpu(
bool is_kdil_one = conv_params.kdil[0] == 1 && conv_params.kdil[1] == 1; bool is_kdil_one = conv_params.kdil[0] == 1 && conv_params.kdil[1] == 1;
bool is_idil_one = conv_params.idil[0] == 1 && conv_params.idil[1] == 1; bool is_idil_one = conv_params.idil[0] == 1 && conv_params.idil[1] == 1;
-  if (groups > 1) {
+  if (is_idil_one && groups > 1) {
     const int C_per_group = conv_params.C / groups;
     const int O_per_group = conv_params.O / groups;
-    if (is_idil_one && (C_per_group <= 4 || C_per_group % 16 == 0) &&
+    if (C_per_group == 1 && O_per_group == 1 && is_kdil_one &&
+        conv_params.wS[0] <= 7 && conv_params.wS[1] <= 7 &&
+        conv_params.str[0] <= 2 && conv_params.str[1] <= 2 &&
+        conv_params.oS[0] % 8 == 0 && conv_params.oS[1] % 8 == 0 &&
+        conv_params.wt_strides[1] == conv_params.wS[1] &&
+        conv_params.C % 16 == 0 && conv_params.C == conv_params.O) {
+      return depthwise_conv_2D_gpu(s, d, in, wt, out, conv_params);
+    }
+    if ((C_per_group <= 4 || C_per_group % 16 == 0) &&
         (O_per_group <= 16 || O_per_group % 16 == 0)) {
       return implicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params);
     } else {
@@ -855,7 +923,7 @@ void conv_3D_gpu(
} // namespace } // namespace
void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) { void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
out.set_data(allocator::malloc_or_wait(out.nbytes())); out.set_data(allocator::malloc(out.nbytes()));
auto& s = stream(); auto& s = stream();
auto& d = metal::device(s.device); auto& d = metal::device(s.device);
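Putting the new dispatch condition in one place: the specialized depthwise path only triggers for fully depthwise convolutions with small kernels and strides, no dilation, packed weights, a channel count that is a multiple of 16, and output sizes divisible by 8 in both spatial dims. A standalone restatement of that gate (hypothetical helper; parameter names are not from the diff):

// True when the depthwise_conv_2D_gpu kernel added above would be selected.
bool use_depthwise_2d(
    int C, int O, int groups,          // channels and group count
    int kh, int kw, int sh, int sw,    // kernel size and strides
    int oh, int ow,                    // output spatial size
    bool no_input_dilation,
    bool no_kernel_dilation,
    bool weights_packed) {             // wt_strides[1] == wS[1]
  bool depthwise = groups > 1 && C / groups == 1 && O / groups == 1 && C == O;
  return no_input_dilation && no_kernel_dilation && depthwise &&
      kh <= 7 && kw <= 7 && sh <= 2 && sw <= 2 &&
      oh % 8 == 0 && ow % 8 == 0 && C % 16 == 0 && weights_packed;
}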

View File

@@ -202,7 +202,7 @@ void fill_gpu(const array& val, array& out, const Stream& s) {
if (out.size() == 0) { if (out.size() == 0) {
return; return;
} }
out.set_data(allocator::malloc_or_wait(out.nbytes())); out.set_data(allocator::malloc(out.nbytes()));
bool large = out.data_size() > UINT32_MAX; bool large = out.data_size() > UINT32_MAX;
auto& d = metal::device(s.device); auto& d = metal::device(s.device);
std::string kernel_name = std::string(large ? "s2" : "s") + "_copy" + std::string kernel_name = std::string(large ? "s2" : "s") + "_copy" +

View File

@@ -19,7 +19,7 @@ void CustomKernel::eval_gpu(
copies.emplace_back(init_value_.value(), out.dtype()); copies.emplace_back(init_value_.value(), out.dtype());
fill_gpu(copies.back(), out, s); fill_gpu(copies.back(), out, s);
} else { } else {
out.set_data(allocator::malloc_or_wait(out.nbytes())); out.set_data(allocator::malloc(out.nbytes()));
} }
} }

View File

@@ -55,7 +55,10 @@ std::pair<MTL::Library*, NS::Error*> load_library_from_path(
} }
#ifdef SWIFTPM_BUNDLE #ifdef SWIFTPM_BUNDLE
-MTL::Library* try_load_bundle(MTL::Device* device, NS::URL* url) {
+MTL::Library* try_load_bundle(
+    MTL::Device* device,
+    NS::URL* url,
+    const std::string& lib_name) {
   std::string bundle_path = std::string(url->fileSystemRepresentation()) + "/" +
       SWIFTPM_BUNDLE + ".bundle";
   auto bundle = NS::Bundle::alloc()->init(
@@ -63,8 +66,8 @@ MTL::Library* try_load_bundle(MTL::Device* device, NS::URL* url) {
   if (bundle != nullptr) {
     std::string resource_path =
         std::string(bundle->resourceURL()->fileSystemRepresentation()) + "/" +
-        "default.metallib";
-    auto [lib, error] = load_library_from_path(device, resource_path.c_str());
+        lib_name + ".metallib";
+    auto [lib, error] =
+        load_library_from_path(device, resource_path.c_str());
     if (lib) {
       return lib;
     }
@@ -73,51 +76,124 @@ MTL::Library* try_load_bundle(MTL::Device* device, NS::URL* url) {
} }
#endif #endif
// Firstly, search for the metallib in the same path as this binary
std::pair<MTL::Library*, NS::Error*> load_colocated_library(
MTL::Device* device,
const std::string& lib_name) {
std::string lib_path = get_colocated_mtllib_path(lib_name);
if (lib_path.size() != 0) {
return load_library_from_path(device, lib_path.c_str());
}
return {nullptr, nullptr};
}
std::pair<MTL::Library*, NS::Error*> load_swiftpm_library(
MTL::Device* device,
const std::string& lib_name) {
#ifdef SWIFTPM_BUNDLE
MTL::Library* library =
try_load_bundle(device, NS::Bundle::mainBundle()->bundleURL(), lib_name);
if (library != nullptr) {
return {library, nullptr};
}
auto bundles = NS::Bundle::allBundles();
for (int i = 0, c = (int)bundles->count(); i < c; i++) {
auto bundle = reinterpret_cast<NS::Bundle*>(bundles->object(i));
library = try_load_bundle(device, bundle->resourceURL());
if (library != nullptr) {
return {library, nullptr};
}
}
#endif
return {nullptr, nullptr};
}
MTL::Library* load_default_library(MTL::Device* device) {
NS::Error *error1, *error2, *error3;
MTL::Library* lib;
// First try the colocated mlx.metallib
std::tie(lib, error1) = load_colocated_library(device, "mlx");
if (lib) {
return lib;
}
// Then try default.metallib in a SwiftPM bundle if we have one
std::tie(lib, error2) = load_swiftpm_library(device, "default");
if (lib) {
return lib;
}
// Finally try default_mtllib_path
std::tie(lib, error3) = load_library_from_path(device, default_mtllib_path);
if (!lib) {
std::ostringstream msg;
msg << "Failed to load the default metallib. ";
if (error1 != nullptr) {
msg << error1->localizedDescription()->utf8String() << " ";
}
if (error2 != nullptr) {
msg << error2->localizedDescription()->utf8String() << " ";
}
if (error3 != nullptr) {
msg << error3->localizedDescription()->utf8String() << " ";
}
throw std::runtime_error(msg.str());
}
return lib;
}
 MTL::Library* load_library(
     MTL::Device* device,
-    const std::string& lib_name = "mlx",
-    const char* lib_path = default_mtllib_path) {
-  // Firstly, search for the metallib in the same path as this binary
-  std::string first_path = get_colocated_mtllib_path(lib_name);
-  if (first_path.size() != 0) {
-    auto [lib, error] = load_library_from_path(device, first_path.c_str());
+    const std::string& lib_name,
+    const std::string& lib_path) {
+  // We have been given a path that ends in metallib so try to load it
+  if (lib_path.size() > 9 &&
+      std::equal(lib_path.end() - 9, lib_path.end(), ".metallib")) {
+    auto [lib, error] = load_library_from_path(device, lib_path.c_str());
+    if (!lib) {
+      std::ostringstream msg;
+      msg << "Failed to load the metallib from <" << lib_path << "> with error "
+          << error->localizedDescription()->utf8String();
+      throw std::runtime_error(msg.str());
+    }
+  }
+  // We have been given a path so try to load from lib_path / lib_name.metallib
+  if (lib_path.size() > 0) {
+    std::string full_path = lib_path + "/" + lib_name + ".metallib";
+    auto [lib, error] = load_library_from_path(device, full_path.c_str());
+    if (!lib) {
+      std::ostringstream msg;
+      msg << "Failed to load the metallib from <" << full_path
+          << "> with error " << error->localizedDescription()->utf8String();
+      throw std::runtime_error(msg.str());
+    }
+  }
+  // Try to load the colocated library
+  {
+    auto [lib, error] = load_colocated_library(device, lib_name);
     if (lib) {
       return lib;
     }
   }
-#ifdef SWIFTPM_BUNDLE
-  // try to load from a swiftpm resource bundle -- scan the available bundles to
-  // find one that contains the named bundle
+  // Try to load the library from swiftpm
   {
-    MTL::Library* library =
-        try_load_bundle(device, NS::Bundle::mainBundle()->bundleURL());
-    if (library != nullptr) {
-      return library;
-    }
-    auto bundles = NS::Bundle::allBundles();
-    for (int i = 0, c = (int)bundles->count(); i < c; i++) {
-      auto bundle = reinterpret_cast<NS::Bundle*>(bundles->object(i));
-      library = try_load_bundle(device, bundle->resourceURL());
-      if (library != nullptr) {
-        return library;
-      }
-    }
+    auto [lib, error] = load_swiftpm_library(device, lib_name);
+    if (lib) {
+      return lib;
+    }
   }
-#endif
-  // Couldn't find it so let's load it from default_mtllib_path
-  {
-    auto [lib, error] = load_library_from_path(device, lib_path);
-    if (!lib) {
-      std::ostringstream msg;
-      msg << error->localizedDescription()->utf8String() << "\n"
-          << "Failed to load device library from <" << lib_path << ">"
-          << " or <" << first_path << ">.";
-      throw std::runtime_error(msg.str());
-    }
-    return lib;
-  }
+  std::ostringstream msg;
+  msg << "Failed to load the metallib " << lib_name << ".metallib. "
+      << "We attempted to load it from <" << get_colocated_mtllib_path(lib_name)
+      << ">";
+#ifdef SWIFTPM_BUNDLE
+  msg << " and from the Swift PM bundle.";
+#endif
+  throw std::runtime_error(msg.str());
 }
} // namespace } // namespace
@@ -210,7 +286,7 @@ void CommandEncoder::barrier() {
Device::Device() { Device::Device() {
auto pool = new_scoped_memory_pool(); auto pool = new_scoped_memory_pool();
device_ = load_device(); device_ = load_device();
library_map_ = {{"mlx", load_library(device_)}}; library_map_ = {{"mlx", load_default_library(device_)}};
arch_ = std::string(device_->architecture()->name()->utf8String()); arch_ = std::string(device_->architecture()->name()->utf8String());
auto arch = arch_.back(); auto arch = arch_.back();
switch (arch) { switch (arch) {

View File

@@ -189,15 +189,7 @@ class Device {
   void register_library(
       const std::string& lib_name,
-      const std::string& lib_path);
-  // Note, this should remain in the header so that it is not dynamically
-  // linked
-  void register_library(const std::string& lib_name) {
-    if (auto it = library_map_.find(lib_name); it == library_map_.end()) {
-      register_library(lib_name, get_colocated_mtllib_path(lib_name));
-    }
-  }
+      const std::string& lib_path = "");
MTL::Library* get_library( MTL::Library* get_library(
const std::string& name, const std::string& name,

View File

@@ -24,10 +24,6 @@ void Event::wait() {
} }
} }
-void Event::signal() {
-  static_cast<MTL::SharedEvent*>(event_.get())->setSignaledValue(value());
-}
 void Event::wait(Stream stream) {
   if (stream.device == Device::cpu) {
     scheduler::enqueue(stream, [*this]() mutable { wait(); });
@@ -42,7 +38,9 @@ void Event::wait(Stream stream) {
 void Event::signal(Stream stream) {
   if (stream.device == Device::cpu) {
-    scheduler::enqueue(stream, [*this]() mutable { signal(); });
+    scheduler::enqueue(stream, [*this]() mutable {
+      static_cast<MTL::SharedEvent*>(event_.get())->setSignaledValue(value());
+    });
} else { } else {
auto& d = metal::device(stream.device); auto& d = metal::device(stream.device);
d.end_encoding(stream.index); d.end_encoding(stream.index);

View File

@@ -20,7 +20,7 @@ struct FenceImpl {
auto p = metal::new_scoped_memory_pool(); auto p = metal::new_scoped_memory_pool();
fence = static_cast<void*>(d->newSharedEvent()); fence = static_cast<void*>(d->newSharedEvent());
} else { } else {
auto buf = allocator::malloc_or_wait(sizeof(uint32_t)).ptr(); auto buf = allocator::malloc(sizeof(uint32_t)).ptr();
fence = static_cast<void*>(buf); fence = static_cast<void*>(buf);
cpu_value()[0] = 0; cpu_value()[0] = 0;
} }

View File

@@ -281,7 +281,7 @@ std::tuple<array, array, array> compute_raders_constants(
} }
array b_q_fft({rader_n - 1}, complex64, nullptr, {}); array b_q_fft({rader_n - 1}, complex64, nullptr, {});
b_q_fft.set_data(allocator::malloc_or_wait(b_q_fft.nbytes())); b_q_fft.set_data(allocator::malloc(b_q_fft.nbytes()));
auto b_q_fft_ptr = auto b_q_fft_ptr =
reinterpret_cast<std::complex<float>*>(b_q_fft.data<complex64_t>()); reinterpret_cast<std::complex<float>*>(b_q_fft.data<complex64_t>());
std::ptrdiff_t item_size = b_q_fft.itemsize(); std::ptrdiff_t item_size = b_q_fft.itemsize();
@@ -327,11 +327,11 @@ std::pair<array, array> compute_bluestein_constants(int n, int bluestein_n) {
} }
array w_k({n}, complex64, nullptr, {}); array w_k({n}, complex64, nullptr, {});
w_k.set_data(allocator::malloc_or_wait(w_k.nbytes())); w_k.set_data(allocator::malloc(w_k.nbytes()));
std::copy(w_k_vec.begin(), w_k_vec.end(), w_k.data<complex64_t>()); std::copy(w_k_vec.begin(), w_k_vec.end(), w_k.data<complex64_t>());
array w_q({bluestein_n}, complex64, nullptr, {}); array w_q({bluestein_n}, complex64, nullptr, {});
w_q.set_data(allocator::malloc_or_wait(w_q.nbytes())); w_q.set_data(allocator::malloc(w_q.nbytes()));
auto w_q_ptr = auto w_q_ptr =
reinterpret_cast<std::complex<float>*>(w_q.data<complex64_t>()); reinterpret_cast<std::complex<float>*>(w_q.data<complex64_t>());
@@ -356,20 +356,14 @@ void multi_upload_bluestein_fft(
bool inverse, bool inverse,
bool real, bool real,
FFTPlan& plan, FFTPlan& plan,
std::vector<array> copies, std::vector<array>& copies,
const Stream& s) { const Stream& s) {
// TODO(alexbarron) Implement fused kernels for mutli upload bluestein's // TODO(alexbarron) Implement fused kernels for mutli upload bluestein's
// algorithm // algorithm
int n = inverse ? out.shape(axis) : in.shape(axis); int n = inverse ? out.shape(axis) : in.shape(axis);
auto [w_k, w_q] = compute_bluestein_constants(n, plan.bluestein_n); auto [w_k, w_q] = compute_bluestein_constants(n, plan.bluestein_n);
-  // Broadcast w_q and w_k to the batch size
-  Strides b_strides(in.ndim(), 0);
-  b_strides[axis] = 1;
-  array w_k_broadcast({}, complex64, nullptr, {});
-  array w_q_broadcast({}, complex64, nullptr, {});
-  w_k_broadcast.copy_shared_buffer(w_k, b_strides, {}, w_k.data_size());
-  w_q_broadcast.copy_shared_buffer(w_q, b_strides, {}, w_q.data_size());
+  copies.push_back(w_k);
+  copies.push_back(w_q);
auto temp_shape = inverse ? out.shape() : in.shape(); auto temp_shape = inverse ? out.shape() : in.shape();
array temp(temp_shape, complex64, nullptr, {}); array temp(temp_shape, complex64, nullptr, {});
@@ -378,13 +372,13 @@ void multi_upload_bluestein_fft(
if (real && !inverse) { if (real && !inverse) {
// Convert float32->complex64 // Convert float32->complex64
copy_gpu(in, temp, CopyType::General, s); copy_gpu(in, temp, CopyType::General, s);
copies.push_back(temp);
} else if (real && inverse) { } else if (real && inverse) {
int back_offset = n % 2 == 0 ? 2 : 1; int back_offset = n % 2 == 0 ? 2 : 1;
auto slice_shape = in.shape(); auto slice_shape = in.shape();
slice_shape[axis] -= back_offset; slice_shape[axis] -= back_offset;
array slice_temp(slice_shape, complex64, nullptr, {}); array slice_temp(slice_shape, complex64, nullptr, {});
array conj_temp(in.shape(), complex64, nullptr, {}); array conj_temp(in.shape(), complex64, nullptr, {});
copies.push_back(slice_temp);
copies.push_back(conj_temp); copies.push_back(conj_temp);
Shape rstarts(in.ndim(), 0); Shape rstarts(in.ndim(), 0);
@@ -394,19 +388,28 @@ void multi_upload_bluestein_fft(
unary_op_gpu({in}, conj_temp, "Conjugate", s); unary_op_gpu({in}, conj_temp, "Conjugate", s);
slice_gpu(in, slice_temp, rstarts, rstrides, s); slice_gpu(in, slice_temp, rstarts, rstrides, s);
concatenate_gpu({conj_temp, slice_temp}, temp, (int)axis, s); concatenate_gpu({conj_temp, slice_temp}, temp, (int)axis, s);
copies.push_back(temp);
} else if (inverse) { } else if (inverse) {
unary_op_gpu({in}, temp, "Conjugate", s); unary_op_gpu({in}, temp, "Conjugate", s);
copies.push_back(temp);
} else { } else {
temp.copy_shared_buffer(in); temp.copy_shared_buffer(in);
} }
Strides b_strides(in.ndim(), 0);
b_strides[axis] = 1;
array w_k_broadcast(temp.shape(), complex64, nullptr, {});
w_k_broadcast.copy_shared_buffer(w_k, b_strides, {}, w_k.data_size());
binary_op_gpu({temp, w_k_broadcast}, temp1, "Multiply", s); binary_op_gpu({temp, w_k_broadcast}, temp1, "Multiply", s);
std::vector<std::pair<int, int>> pads; std::vector<std::pair<int, int>> pads;
auto padded_shape = out.shape(); auto padded_shape = out.shape();
padded_shape[axis] = plan.bluestein_n; padded_shape[axis] = plan.bluestein_n;
array pad_temp(padded_shape, complex64, nullptr, {}); array pad_temp(padded_shape, complex64, nullptr, {});
-  pad_gpu(temp1, array(complex64_t{0.0f, 0.0f}), pad_temp, {(int)axis}, {0}, s);
+  auto zero = array(complex64_t{0.0f, 0.0f});
+  copies.push_back(zero);
+  pad_gpu(temp1, zero, pad_temp, {(int)axis}, {0}, s);
+  copies.push_back(pad_temp);
array pad_temp1(padded_shape, complex64, nullptr, {}); array pad_temp1(padded_shape, complex64, nullptr, {});
fft_op( fft_op(
@@ -418,7 +421,10 @@ void multi_upload_bluestein_fft(
FourStepParams(), FourStepParams(),
/*inplace=*/false, /*inplace=*/false,
s); s);
copies.push_back(pad_temp1);
array w_q_broadcast(pad_temp1.shape(), complex64, nullptr, {});
w_q_broadcast.copy_shared_buffer(w_q, b_strides, {}, w_q.data_size());
binary_op_gpu_inplace({pad_temp1, w_q_broadcast}, pad_temp, "Multiply", s); binary_op_gpu_inplace({pad_temp1, w_q_broadcast}, pad_temp, "Multiply", s);
fft_op( fft_op(
@@ -435,9 +441,11 @@ void multi_upload_bluestein_fft(
Shape starts(in.ndim(), 0); Shape starts(in.ndim(), 0);
Shape strides(in.ndim(), 1); Shape strides(in.ndim(), 1);
starts[axis] = plan.bluestein_n - offset - n; starts[axis] = plan.bluestein_n - offset - n;
-  slice_gpu(pad_temp1, temp, starts, strides, s);
-  binary_op_gpu_inplace({temp, w_k_broadcast}, temp1, "Multiply", s);
+  array temp2(temp_shape, complex64, nullptr, {});
+  slice_gpu(pad_temp1, temp2, starts, strides, s);
+  binary_op_gpu_inplace({temp2, w_k_broadcast}, temp1, "Multiply", s);
if (real && !inverse) { if (real && !inverse) {
Shape rstarts(in.ndim(), 0); Shape rstarts(in.ndim(), 0);
@@ -449,26 +457,21 @@ void multi_upload_bluestein_fft(
array temp_float(out.shape(), out.dtype(), nullptr, {}); array temp_float(out.shape(), out.dtype(), nullptr, {});
copies.push_back(temp_float); copies.push_back(temp_float);
copies.push_back(inv_n); copies.push_back(inv_n);
copies.push_back(temp1);
copy_gpu(temp1, temp_float, CopyType::General, s); copy_gpu(temp1, temp_float, CopyType::General, s);
binary_op_gpu({temp_float, inv_n}, out, "Multiply", s); binary_op_gpu({temp_float, inv_n}, out, "Multiply", s);
   } else if (inverse) {
     auto inv_n = array({1.0f / n}, {1}, complex64);
-    unary_op_gpu({temp1}, temp, "Conjugate", s);
-    binary_op_gpu({temp, inv_n}, out, "Multiply", s);
+    array temp3(temp_shape, complex64, nullptr, {});
+    unary_op_gpu({temp1}, temp3, "Conjugate", s);
+    binary_op_gpu({temp3, inv_n}, out, "Multiply", s);
     copies.push_back(inv_n);
-    copies.push_back(temp1);
+    copies.push_back(temp3);
   } else {
     out.copy_shared_buffer(temp1);
   }
-  copies.push_back(w_k);
-  copies.push_back(w_q);
-  copies.push_back(w_k_broadcast);
-  copies.push_back(w_q_broadcast);
-  copies.push_back(temp);
-  copies.push_back(temp1);
-  copies.push_back(pad_temp);
-  copies.push_back(pad_temp1);
} }
void four_step_fft(
@@ -478,8 +481,9 @@ void four_step_fft(
bool inverse,
bool real,
FFTPlan& plan,
std::vector<array>& copies,
const Stream& s,
bool in_place) {
auto& d = metal::device(s.device);
if (plan.bluestein_n == -1) {
@@ -492,7 +496,14 @@ void four_step_fft(
in, temp, axis, inverse, real, four_step_params, /*inplace=*/false, s);
four_step_params.first_step = false;
fft_op(
temp,
out,
axis,
inverse,
real,
four_step_params,
/*inplace=*/in_place,
s);
copies.push_back(temp);
} else {
multi_upload_bluestein_fft(in, out, axis, inverse, real, plan, copies, s);
@@ -551,8 +562,7 @@ void fft_op(
flags.row_contiguous = is_row_contiguous;
flags.contiguous = data_size == x_copy.size();
x_copy.set_data(allocator::malloc(x.nbytes()), data_size, strides, flags);
copy_gpu_inplace(x, x_copy, CopyType::GeneralGeneral, s);
copies.push_back(x_copy);
return x_copy;
@@ -575,7 +585,7 @@ void fft_op(
auto plan = plan_fft(n);
if (plan.four_step) {
four_step_fft(in, out, axis, inverse, real, plan, copies, s, inplace);
d.add_temporaries(std::move(copies), s.index);
return;
}
@@ -583,7 +593,7 @@ void fft_op(
// TODO: allow donation here
if (!inplace) {
out.set_data(
allocator::malloc(out.nbytes()),
out_data_size,
out_strides,
in_contiguous.flags());


@@ -84,7 +84,7 @@ void Hadamard::eval_gpu(const std::vector<array>& inputs, array& out) {
if (in_contiguous.is_donatable()) {
out.copy_shared_buffer(in_contiguous);
} else {
out.set_data(allocator::malloc(out.nbytes()));
}
int n, m;
@@ -161,7 +161,7 @@ void Hadamard::eval_gpu(const std::vector<array>& inputs, array& out) {
// Upload 2:
// y = h12 @ tmp
array temp(in.shape(), in.dtype(), nullptr, {});
temp.set_data(allocator::malloc(temp.nbytes()));
copies.push_back(temp);
launch_hadamard(in_contiguous, temp, "n" + kernel_name, 1.0);


@@ -43,7 +43,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
throw std::runtime_error(msg.str());
}
out.set_data(allocator::malloc(out.nbytes()));
if (out.size() == 0) {
return;
}
@@ -393,7 +393,7 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& src = inputs[0];
auto& idx = inputs[1];
out.set_data(allocator::malloc(out.nbytes()));
if (out.size() == 0) {
return;
}


@@ -1,9 +0,0 @@
// Copyright © 2024 Apple Inc.
constexpr std::string_view arange_kernels = R"(
template [[host_name("{0}")]] [[kernel]] void arange<{1}>(
constant const {1}& start,
constant const {1}& step,
device {1}* out,
uint index [[thread_position_in_grid]]);
)";


@@ -20,6 +20,7 @@ const char* copy();
const char* fft();
const char* gather_axis();
const char* hadamard();
const char* logsumexp();
const char* quantized();
const char* ternary();
const char* scan();
@@ -32,6 +33,7 @@ const char* gemm();
const char* steel_gemm_fused();
const char* steel_gemm_masked();
const char* steel_gemm_splitk();
const char* steel_gemm_gather();
const char* conv();
const char* steel_conv();
const char* steel_conv_general();


@@ -1,23 +0,0 @@
// Copyright © 2024 Apple Inc.
constexpr std::string_view softmax_kernels = R"(
template [[host_name("block_{0}")]] [[kernel]] void
softmax_single_row<{1}, {2}>(
const device {1}* in,
device {1}* out,
constant int& axis_size,
uint gid [[thread_position_in_grid]],
uint _lid [[thread_position_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
template [[host_name("looped_{0}")]] [[kernel]] void
softmax_looped<{1}, {2}>(
const device {1}* in,
device {1}* out,
constant int& axis_size,
uint gid [[threadgroup_position_in_grid]],
uint lid [[thread_position_in_threadgroup]],
uint lsize [[threads_per_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
)";


@@ -1,8 +1,6 @@
// Copyright © 2024 Apple Inc.
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/metal/jit/arange.h"
#include "mlx/backend/metal/jit/includes.h"
#include "mlx/backend/metal/jit/softmax.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/utils.h"
@@ -21,13 +19,11 @@ MTL::ComputePipelineState* get_arange_kernel(
const std::string& kernel_name,
const array& out) {
auto lib = d.get_library(kernel_name, [&]() {
std::string kernel_source = metal::utils();
kernel_source += metal::arange();
kernel_source += get_template_definition(
kernel_name, "arange", get_type_string(out.dtype()));
return kernel_source;
});
return d.get_kernel(kernel_name, lib);
}
@@ -259,14 +255,34 @@ MTL::ComputePipelineState* get_softmax_kernel(
const array& out) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name, [&] {
std::string kernel_source = metal::utils();
auto in_type = get_type_string(out.dtype());
auto acc_type = get_type_string(precise ? float32 : out.dtype());
kernel_source += metal::softmax();
kernel_source += get_template_definition(
"block_" + lib_name, "softmax_single_row", in_type, acc_type);
kernel_source += get_template_definition(
"looped_" + lib_name, "softmax_looped", in_type, acc_type);
return kernel_source;
});
return d.get_kernel(kernel_name, lib);
}
MTL::ComputePipelineState* get_logsumexp_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& out) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name, [&] {
auto t_str = get_type_string(out.dtype());
std::string kernel_source;
kernel_source = metal::utils();
kernel_source += metal::logsumexp();
kernel_source +=
get_template_definition("block_" + lib_name, "logsumexp", t_str);
kernel_source += get_template_definition(
"looped_" + lib_name, "logsumexp_looped", t_str);
return kernel_source;
});
return d.get_kernel(kernel_name, lib);
}
@@ -568,6 +584,44 @@ MTL::ComputePipelineState* get_steel_gemm_masked_kernel(
return d.get_kernel(kernel_name, lib);
}
MTL::ComputePipelineState* get_steel_gemm_gather_kernel(
metal::Device& d,
const std::string& kernel_name,
const std::string& hash_name,
const metal::MTLFCList& func_consts,
const array& out,
bool transpose_a,
bool transpose_b,
int bm,
int bn,
int bk,
int wm,
int wn,
bool rhs) {
const auto& lib_name = kernel_name;
auto lib = d.get_library(lib_name, [&]() {
std::string kernel_source;
concatenate(
kernel_source,
metal::utils(),
metal::gemm(),
metal::steel_gemm_gather(),
get_template_definition(
lib_name,
rhs ? "gather_mm_rhs" : "gather_mm",
get_type_string(out.dtype()),
bm,
bn,
bk,
wm,
wn,
transpose_a,
transpose_b));
return kernel_source;
});
return d.get_kernel(kernel_name, lib, hash_name, func_consts);
}
MTL::ComputePipelineState* get_gemv_masked_kernel(
metal::Device& d,
const std::string& kernel_name,


@@ -59,6 +59,11 @@ MTL::ComputePipelineState* get_softmax_kernel(
bool precise,
const array& out);
MTL::ComputePipelineState* get_logsumexp_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& out);
MTL::ComputePipelineState* get_scan_kernel(
metal::Device& d,
const std::string& kernel_name,
@@ -155,6 +160,21 @@ MTL::ComputePipelineState* get_steel_gemm_masked_kernel(
bool mn_aligned,
bool k_aligned);
MTL::ComputePipelineState* get_steel_gemm_gather_kernel(
metal::Device& d,
const std::string& kernel_name,
const std::string& hash_name,
const metal::MTLFCList& func_consts,
const array& out,
bool transpose_a,
bool transpose_b,
int bm,
int bn,
int bk,
int wm,
int wn,
bool rhs);
MTL::ComputePipelineState* get_steel_conv_kernel(
metal::Device& d,
const std::string& kernel_name,


@@ -13,6 +13,10 @@ function(build_kernel_base TARGET SRCFILE DEPS)
if(MLX_METAL_DEBUG)
set(METAL_FLAGS ${METAL_FLAGS} -gline-tables-only -frecord-sources)
endif()
if(NOT CMAKE_OSX_DEPLOYMENT_TARGET STREQUAL "")
set(METAL_FLAGS ${METAL_FLAGS}
"-mmacosx-version-min=${CMAKE_OSX_DEPLOYMENT_TARGET}")
endif()
if(MLX_METAL_VERSION GREATER_EQUAL 310)
set(VERSION_INCLUDES
${PROJECT_SOURCE_DIR}/mlx/backend/metal/kernels/metal_3_1)
@@ -65,6 +69,7 @@ set(STEEL_HEADERS
steel/gemm/loader.h
steel/gemm/transforms.h
steel/gemm/kernels/steel_gemm_fused.h
steel/gemm/kernels/steel_gemm_gather.h
steel/gemm/kernels/steel_gemm_masked.h
steel/gemm/kernels/steel_gemm_splitk.h
steel/utils/type_traits.h
@@ -105,12 +110,14 @@ if(NOT MLX_METAL_JIT)
build_kernel(quantized quantized.h ${STEEL_HEADERS})
build_kernel(scan scan.h)
build_kernel(softmax softmax.h)
build_kernel(logsumexp logsumexp.h)
build_kernel(sort sort.h)
build_kernel(ternary ternary.h ternary_ops.h)
build_kernel(unary unary.h unary_ops.h)
build_kernel(steel/conv/kernels/steel_conv ${STEEL_HEADERS})
build_kernel(steel/conv/kernels/steel_conv_general ${STEEL_HEADERS})
build_kernel(steel/gemm/kernels/steel_gemm_fused ${STEEL_HEADERS})
build_kernel(steel/gemm/kernels/steel_gemm_gather ${STEEL_HEADERS})
build_kernel(steel/gemm/kernels/steel_gemm_masked ${STEEL_HEADERS})
build_kernel(steel/gemm/kernels/steel_gemm_splitk ${STEEL_HEADERS})
build_kernel(gemv_masked steel/utils.h)


@@ -5,11 +5,7 @@
#include "mlx/backend/metal/kernels/arange.h"
#define instantiate_arange(tname, type) \
template [[host_name("arange" #tname)]] [[kernel]] void arange<type>( \
constant const type& start, \
constant const type& step, \
device type* out, \
uint index [[thread_position_in_grid]]);
#define instantiate_arange(tname, type) \
instantiate_kernel("arange" #tname, arange, type)
instantiate_arange(uint8, uint8_t)
instantiate_arange(uint16, uint16_t)


@@ -275,6 +275,128 @@ instantiate_naive_conv_2d_blocks(float32, float);
instantiate_naive_conv_2d_blocks(float16, half);
instantiate_naive_conv_2d_blocks(bfloat16, bfloat16_t);
///////////////////////////////////////////////////////////////////////////////
/// Depthwise convolution kernels
///////////////////////////////////////////////////////////////////////////////
constant int ker_h [[function_constant(00)]];
constant int ker_w [[function_constant(01)]];
constant int str_h [[function_constant(10)]];
constant int str_w [[function_constant(11)]];
constant int tgp_h [[function_constant(100)]];
constant int tgp_w [[function_constant(101)]];
constant bool do_flip [[function_constant(200)]];
constant int span_h = tgp_h * str_h + ker_h - 1;
constant int span_w = tgp_w * str_w + ker_w - 1;
constant int span_hw = span_h * span_w;
template <typename T>
[[kernel]] void depthwise_conv_2d(
const device T* in [[buffer(0)]],
const device T* wt [[buffer(1)]],
device T* out [[buffer(2)]],
const constant MLXConvParams<2>& params [[buffer(3)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint3 gid [[thread_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
constexpr int tc = 8;
constexpr int tw = 8;
constexpr int th = 4;
constexpr int c_per_thr = 8;
constexpr int TGH = th * 2 + 6;
constexpr int TGW = tw * 2 + 6;
constexpr int TGC = tc;
threadgroup T ins[TGH * TGW * TGC];
const int n_tgblocks_h = params.oS[0] / th;
const int n = tid.z / n_tgblocks_h;
const int tghid = tid.z % n_tgblocks_h;
const int oh = tghid * th + lid.z;
const int ow = gid.y;
const int c = gid.x;
in += n * params.in_strides[0];
// Load in
{
constexpr int n_threads = th * tw * tc;
const int tg_oh = (tghid * th) * str_h - params.pad[0];
const int tg_ow = (tid.y * tw) * str_w - params.pad[1];
const int tg_c = tid.x * tc;
const int thread_idx = simd_gid * 32 + simd_lid;
constexpr int thr_per_hw = tc / c_per_thr;
constexpr int hw_per_group = n_threads / thr_per_hw;
const int thr_c = thread_idx % thr_per_hw;
const int thr_hw = thread_idx / thr_per_hw;
for (int hw = thr_hw; hw < span_hw; hw += hw_per_group) {
const int h = hw / span_w;
const int w = hw % span_w;
const int ih = tg_oh + h;
const int iw = tg_ow + w;
const int in_s_offset = h * span_w * TGC + w * TGC;
if (ih >= 0 && ih < params.iS[0] && iw >= 0 && iw < params.iS[1]) {
const auto in_load =
in + ih * params.in_strides[1] + iw * params.in_strides[2] + tg_c;
MLX_MTL_PRAGMA_UNROLL
for (int cc = 0; cc < c_per_thr; ++cc) {
ins[in_s_offset + c_per_thr * thr_c + cc] =
in_load[c_per_thr * thr_c + cc];
}
} else {
MLX_MTL_PRAGMA_UNROLL
for (int cc = 0; cc < c_per_thr; ++cc) {
ins[in_s_offset + c_per_thr * thr_c + cc] = T(0);
}
}
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
wt += c * params.wt_strides[0];
const auto ins_ptr =
&ins[lid.z * str_h * span_w * TGC + lid.y * str_w * TGC + lid.x];
float o = 0.;
for (int h = 0; h < ker_h; ++h) {
for (int w = 0; w < ker_w; ++w) {
int wt_h = h;
int wt_w = w;
if (do_flip) {
wt_h = ker_h - h - 1;
wt_w = ker_w - w - 1;
}
auto inv = ins_ptr[h * span_w * TGC + w * TGC];
auto wtv = wt[wt_h * ker_w + wt_w];
o += inv * wtv;
}
}
threadgroup_barrier(mem_flags::mem_none);
out += n * params.out_strides[0] + oh * params.out_strides[1] +
ow * params.out_strides[2];
out[c] = static_cast<T>(o);
}
#define instantiate_depthconv2d(iname, itype) \
instantiate_kernel("depthwise_conv_2d_" #iname, depthwise_conv_2d, itype)
instantiate_depthconv2d(float32, float);
instantiate_depthconv2d(float16, half);
instantiate_depthconv2d(bfloat16, bfloat16_t);
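For orientation only (not part of the patch), a minimal host-side C++ sketch of the input-tile footprint arithmetic the depthwise kernel above stages into threadgroup memory; the tgp, stride, and filter sizes below are illustrative assumptions, not values taken from the diff.

```cpp
#include <cstdio>

int main() {
  int tgp_h = 4, tgp_w = 8; // outputs produced per threadgroup (assumed th x tw)
  int str_h = 1, str_w = 1; // strides (the new kernel targets strides <= 2)
  int ker_h = 5, ker_w = 5; // filter taps (the new kernel targets sizes <= 7)
  // Same expressions as span_h / span_w in the kernel above.
  int span_h = tgp_h * str_h + ker_h - 1;
  int span_w = tgp_w * str_w + ker_w - 1;
  std::printf("staged input tile per channel block: %d x %d\n", span_h, span_w);
  return 0;
}
```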
///////////////////////////////////////////////////////////////////////////////
/// Winograd kernels
///////////////////////////////////////////////////////////////////////////////


@@ -483,4 +483,4 @@ template <
perform_fft(fft_idx, &p, m, n, buf);
read_writer.write_strided(stride, overall_n);
}


@@ -341,7 +341,7 @@ struct GEMVTKernel {
MLX_MTL_PRAGMA_UNROLL
for (int tm = 0; tm < TM; tm++) {
auto vc = static_cast<AccT>(v_coeff[tm]);
for (int tn = 0; tn < TN; tn++) {
inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn];
}


@@ -493,71 +493,11 @@ template <typename T, int N_READS = RMS_N_READS>
}
// clang-format off
#define instantiate_layer_norm_single_row(name, itype) \
template [[host_name("layer_norm" #name)]] [[kernel]] void \
layer_norm_single_row<itype>( \
const device itype* x, \
const device itype* w, \
const device itype* b, \
device itype* out, \
constant float& eps, \
constant uint& axis_size, \
constant uint& w_stride, \
constant uint& b_stride, \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]); \
template [[host_name("vjp_layer_norm" #name)]] [[kernel]] void \
vjp_layer_norm_single_row<itype>( \
const device itype* x, \
const device itype* w, \
const device itype* g, \
device itype* gx, \
device itype* gw, \
constant float& eps, \
constant uint& axis_size, \
constant uint& w_stride, \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
#define instantiate_layer_norm_looped(name, itype) \
template [[host_name("layer_norm_looped" #name)]] [[kernel]] void \
layer_norm_looped<itype>( \
const device itype* x, \
const device itype* w, \
const device itype* b, \
device itype* out, \
constant float& eps, \
constant uint& axis_size, \
constant uint& w_stride, \
constant uint& b_stride, \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint lsize [[threads_per_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]); \
template [[host_name("vjp_layer_norm_looped" #name)]] [[kernel]] void \
vjp_layer_norm_looped<itype>( \
const device itype* x, \
const device itype* w, \
const device itype* g, \
device itype* gx, \
device itype* gb, \
constant float& eps, \
constant uint& axis_size, \
constant uint& w_stride, \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint lsize [[threads_per_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
#define instantiate_layer_norm(name, itype) \
instantiate_layer_norm_single_row(name, itype) \
instantiate_layer_norm_looped(name, itype)
#define instantiate_layer_norm(name, itype) \
instantiate_kernel("layer_norm" #name, layer_norm_single_row, itype) \
instantiate_kernel("vjp_layer_norm" #name, vjp_layer_norm_single_row, itype) \
instantiate_kernel("layer_norm_looped" #name, layer_norm_looped, itype) \
instantiate_kernel("vjp_layer_norm_looped" #name, vjp_layer_norm_looped, itype)
instantiate_layer_norm(float32, float)
instantiate_layer_norm(float16, half)


@@ -0,0 +1,143 @@
// Copyright © 2025 Apple Inc.
template <typename T, typename AccT = float, int N_READS = 4>
[[kernel]] void logsumexp(
const device T* in,
device T* out,
constant int& axis_size,
uint gid [[threadgroup_position_in_grid]],
uint _lid [[thread_position_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
int lid = _lid;
constexpr int SIMD_SIZE = 32;
threadgroup AccT local_max[SIMD_SIZE];
threadgroup AccT local_normalizer[SIMD_SIZE];
AccT ld[N_READS];
in += gid * size_t(axis_size) + lid * N_READS;
if (lid * N_READS + N_READS <= axis_size) {
for (int i = 0; i < N_READS; i++) {
ld[i] = AccT(in[i]);
}
} else {
for (int i = 0; i < N_READS; i++) {
ld[i] =
((lid * N_READS + i) < axis_size) ? AccT(in[i]) : Limits<AccT>::min;
}
}
if (simd_group_id == 0) {
local_max[simd_lane_id] = Limits<AccT>::min;
local_normalizer[simd_lane_id] = 0;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Get the max
AccT maxval = Limits<AccT>::finite_min;
for (int i = 0; i < N_READS; i++) {
maxval = (maxval < ld[i]) ? ld[i] : maxval;
}
maxval = simd_max(maxval);
if (simd_lane_id == 0) {
local_max[simd_group_id] = maxval;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (simd_group_id == 0) {
maxval = simd_max(local_max[simd_lane_id]);
if (simd_lane_id == 0) {
local_max[0] = maxval;
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
maxval = local_max[0];
// Compute exp(x_i - maxval) and store the partial sums in local_normalizer
AccT normalizer = 0;
for (int i = 0; i < N_READS; i++) {
normalizer += fast::exp(ld[i] - maxval);
}
normalizer = simd_sum(normalizer);
if (simd_lane_id == 0) {
local_normalizer[simd_group_id] = normalizer;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (simd_group_id == 0) {
normalizer = simd_sum(local_normalizer[simd_lane_id]);
if (simd_lane_id == 0) {
out[gid] = isinf(maxval) ? T(maxval) : T(log(normalizer) + maxval);
}
}
}
template <typename T, typename AccT = float, int N_READS = 4>
[[kernel]] void logsumexp_looped(
const device T* in,
device T* out,
constant int& axis_size,
uint gid [[threadgroup_position_in_grid]],
uint lid [[thread_position_in_threadgroup]],
uint lsize [[threads_per_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
in += gid * size_t(axis_size);
constexpr int SIMD_SIZE = 32;
threadgroup AccT local_max[SIMD_SIZE];
threadgroup AccT local_normalizer[SIMD_SIZE];
// Get the max and the normalizer in one go
AccT prevmax;
AccT maxval = Limits<AccT>::finite_min;
AccT normalizer = 0;
for (int r = 0; r < static_cast<int>(ceildiv(axis_size, N_READS * lsize));
r++) {
int offset = r * lsize * N_READS + lid * N_READS;
AccT vals[N_READS];
if (offset + N_READS <= axis_size) {
for (int i = 0; i < N_READS; i++) {
vals[i] = AccT(in[offset + i]);
}
} else {
for (int i = 0; i < N_READS; i++) {
vals[i] = (offset + i < axis_size) ? AccT(in[offset + i])
: Limits<AccT>::finite_min;
}
}
prevmax = maxval;
for (int i = 0; i < N_READS; i++) {
maxval = (maxval < vals[i]) ? vals[i] : maxval;
}
normalizer *= fast::exp(prevmax - maxval);
for (int i = 0; i < N_READS; i++) {
normalizer += fast::exp(vals[i] - maxval);
}
}
prevmax = maxval;
maxval = simd_max(maxval);
normalizer *= fast::exp(prevmax - maxval);
normalizer = simd_sum(normalizer);
prevmax = maxval;
if (simd_lane_id == 0) {
local_max[simd_group_id] = maxval;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
maxval = simd_max(local_max[simd_lane_id]);
normalizer *= fast::exp(prevmax - maxval);
if (simd_lane_id == 0) {
local_normalizer[simd_group_id] = normalizer;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
normalizer = simd_sum(local_normalizer[simd_lane_id]);
if (simd_group_id == 0) {
normalizer = simd_sum(local_normalizer[simd_lane_id]);
if (simd_lane_id == 0) {
out[gid] = isinf(maxval) ? T(maxval) : T(log(normalizer) + maxval);
}
}
}
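As a standalone C++ sketch (assumptions: plain host-side floats, with chunked reads standing in for the per-thread N_READS loads), this is the running-max rescaling that logsumexp_looped above performs: whenever a later chunk raises the maximum, the partial normalizer is rescaled by exp(old_max - new_max) before new terms are added.

```cpp
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <limits>
#include <vector>

// Streaming logsumexp over chunks, mirroring the rescaling in the kernel above.
float streaming_logsumexp(const std::vector<float>& x, size_t chunk = 4) {
  float maxval = -std::numeric_limits<float>::max(); // finite_min stand-in
  float normalizer = 0.0f;
  for (size_t start = 0; start < x.size(); start += chunk) {
    size_t end = std::min(x.size(), start + chunk);
    float prevmax = maxval;
    for (size_t i = start; i < end; ++i) {
      maxval = std::max(maxval, x[i]);
    }
    // Rescale the sum accumulated under the previous maximum.
    normalizer *= std::exp(prevmax - maxval);
    for (size_t i = start; i < end; ++i) {
      normalizer += std::exp(x[i] - maxval);
    }
  }
  return std::log(normalizer) + maxval;
}

int main() {
  std::vector<float> x{1.0f, 2.0f, 3.0f, 100.0f, -5.0f};
  std::printf("%f\n", streaming_logsumexp(x)); // ~100.0
  return 0;
}
```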


@@ -0,0 +1,18 @@
// Copyright © 2023-2024 Apple Inc.
#include <metal_common>
#include <metal_simdgroup>
using namespace metal;
// clang-format off
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/logsumexp.h"
#define instantiate_logsumexp(name, itype) \
instantiate_kernel("block_logsumexp_" #name, logsumexp, itype) \
instantiate_kernel("looped_logsumexp_" #name, logsumexp_looped, itype) \
instantiate_logsumexp(float32, float)
instantiate_logsumexp(float16, half)
instantiate_logsumexp(bfloat16, bfloat16_t) // clang-format on


@@ -586,13 +586,13 @@ METAL_FUNC void qmv_quad_impl(
// Adjust positions
const int in_vec_size_w = in_vec_size / pack_factor;
const int in_vec_size_g = in_vec_size / group_size;
const int out_row = tid.y * quads_per_simd * results_per_quadgroup + quad_gid;
w += out_row * in_vec_size_w + quad_lid * packs_per_thread;
scales += out_row * in_vec_size_g + quad_lid / scale_step_per_thread;
biases += out_row * in_vec_size_g + quad_lid / scale_step_per_thread;
x += tid.x * in_vec_size + quad_lid * values_per_thread;
y += tid.x * out_vec_size + out_row;
U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);


@@ -380,69 +380,11 @@ template <typename T, int N_READS = RMS_N_READS>
}
// clang-format off
#define instantiate_rms_single_row(name, itype) \
template [[host_name("rms" #name)]] [[kernel]] void \
rms_single_row<itype>( \
const device itype* x, \
const device itype* w, \
device itype* out, \
constant float& eps, \
constant uint& axis_size, \
constant uint& w_stride, \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]); \
\
template [[host_name("vjp_rms" #name)]] [[kernel]] void \
vjp_rms_single_row<itype>( \
const device itype* x, \
const device itype* w, \
const device itype* g, \
device itype* gx, \
device itype* gw, \
constant float& eps, \
constant uint& axis_size, \
constant uint& w_stride, \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
#define instantiate_rms_looped(name, itype) \
template [[host_name("rms_looped" #name)]] [[kernel]] void \
rms_looped<itype>( \
const device itype* x, \
const device itype* w, \
device itype* out, \
constant float& eps, \
constant uint& axis_size, \
constant uint& w_stride, \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint lsize [[threads_per_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]); \
\
template [[host_name("vjp_rms_looped" #name)]] [[kernel]] void \
vjp_rms_looped<itype>( \
const device itype* x, \
const device itype* w, \
const device itype* g, \
device itype* gx, \
device itype* gw, \
constant float& eps, \
constant uint& axis_size, \
constant uint& w_stride, \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint lsize [[threads_per_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
#define instantiate_rms(name, itype) \
instantiate_rms_single_row(name, itype) \
instantiate_rms_looped(name, itype)
#define instantiate_rms(name, itype) \
instantiate_kernel("rms" #name, rms_single_row, itype) \
instantiate_kernel("vjp_rms" #name, vjp_rms_single_row, itype) \
instantiate_kernel("rms_looped" #name, rms_looped, itype) \
instantiate_kernel("vjp_rms_looped" #name, vjp_rms_looped, itype)
instantiate_rms(float32, float)
instantiate_rms(float16, half)


@@ -1,11 +1,11 @@
#include <metal_stdlib>
// clang-format off
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/sdpa_vector.h"
using namespace metal;
// SDPA vector instantiations
#define instantiate_sdpa_vector_aggregation(type, value_dim) \
instantiate_kernel( \
@@ -32,9 +32,11 @@ using namespace metal;
instantiate_sdpa_vector(type, 64, 64) \
instantiate_sdpa_vector(type, 96, 96) \
instantiate_sdpa_vector(type, 128, 128) \
instantiate_sdpa_vector(type, 256, 256) \
instantiate_sdpa_vector_aggregation(type, 64) \
instantiate_sdpa_vector_aggregation(type, 96) \
instantiate_sdpa_vector_aggregation(type, 128) \
instantiate_sdpa_vector_aggregation(type, 256)
instantiate_sdpa_vector_heads(float)
instantiate_sdpa_vector_heads(bfloat16_t)


@@ -2,6 +2,8 @@
#pragma once
#include "mlx/backend/metal/kernels/binary_ops.h"
#define DEFINE_SIMD_SCAN() \
template <typename T, metal::enable_if_t<sizeof(T) < 8, bool> = true> \
T simd_scan(T val) { \
@@ -139,6 +141,29 @@ struct CumMin {
}
};
template <typename U>
struct CumLogaddexp {
static constexpr constant U init = Limits<U>::min;
template <typename T>
U operator()(U a, T b) {
return LogAddExp{}(a, static_cast<U>(b));
}
U simd_scan(U x) {
for (int i = 1; i <= 16; i *= 2) {
U other = simd_shuffle_and_fill_up(x, init, i);
x = LogAddExp{}(x, other);
}
return x;
}
U simd_exclusive_scan(U x) {
x = simd_scan(x);
return simd_shuffle_and_fill_up(x, init, 1);
}
};
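For reference only, a small host-side C++ sketch (not the Metal code above) of the operation CumLogaddexp accumulates: a numerically stable logaddexp applied as an inclusive running scan, with negative infinity as the identity element.

```cpp
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <limits>
#include <vector>

// Stable log(exp(a) + exp(b)) computed relative to the larger argument.
float log_add_exp(float a, float b) {
  float m = std::max(a, b);
  if (std::isinf(m) && m < 0) {
    return m; // both inputs are -inf
  }
  return m + std::log(std::exp(a - m) + std::exp(b - m));
}

// Inclusive cumulative logaddexp, the serial analogue of the simd scan above.
std::vector<float> cum_logaddexp(const std::vector<float>& x) {
  std::vector<float> out(x.size());
  float acc = -std::numeric_limits<float>::infinity(); // identity element
  for (size_t i = 0; i < x.size(); ++i) {
    acc = log_add_exp(acc, x[i]);
    out[i] = acc;
  }
  return out;
}

int main() {
  for (float v : cum_logaddexp({0.0f, 0.0f, 0.0f})) {
    std::printf("%f\n", v); // log(1), log(2), log(3)
  }
  return 0;
}
```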
template <typename T, typename U, int N_READS, bool reverse>
inline void load_unsafe(U values[N_READS], const device T* input) {
if (reverse) {


@@ -101,4 +101,7 @@ instantiate_scan_helper(min_int64_int64, int64_t, int64_t, CumMi
instantiate_scan_helper(min_float16_float16, half, half, CumMin, 4)
instantiate_scan_helper(min_float32_float32, float, float, CumMin, 4)
instantiate_scan_helper(min_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumMin, 4)
instantiate_scan_helper(min_complex64_complex64, complex64_t, complex64_t, CumMin, 2)
instantiate_scan_helper(logaddexp_float16_float16, half, half, CumLogaddexp, 4)
instantiate_scan_helper(logaddexp_float32_float32, float, float, CumLogaddexp, 4)
instantiate_scan_helper(logaddexp_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumLogaddexp, 4) // clang-format on


@@ -6,6 +6,9 @@ using namespace metal;
constant bool has_mask [[function_constant(20)]];
constant bool query_transposed [[function_constant(21)]];
constant bool do_causal [[function_constant(22)]];
constant bool bool_mask [[function_constant(23)]];
constant bool float_mask [[function_constant(24)]];
template <typename T, int D, int V = D>
[[kernel]] void sdpa_vector(
@@ -13,17 +16,21 @@ template <typename T, int D, int V = D>
const device T* keys [[buffer(1)]],
const device T* values [[buffer(2)]],
device T* out [[buffer(3)]],
const constant int& gqa_factor [[buffer(4)]],
const constant int& N [[buffer(5)]],
const constant size_t& k_head_stride [[buffer(6)]],
const constant size_t& k_seq_stride [[buffer(7)]],
const constant size_t& v_head_stride [[buffer(8)]],
const constant size_t& v_seq_stride [[buffer(9)]],
const constant float& scale [[buffer(10)]],
const device bool* bmask [[buffer(11), function_constant(bool_mask)]],
const device T* fmask [[buffer(12), function_constant(float_mask)]],
const constant int& mask_kv_seq_stride
[[buffer(13), function_constant(has_mask)]],
const constant int& mask_q_seq_stride
[[buffer(14), function_constant(has_mask)]],
const constant int& mask_head_stride
[[buffer(15), function_constant(has_mask)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 tpg [[threadgroups_per_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
@@ -57,8 +64,12 @@ template <typename T, int D, int V = D>
simd_lid * qk_per_thread;
values += kv_head_idx * v_head_stride + simd_gid * v_seq_stride +
simd_lid * v_per_thread;
if (bool_mask) {
bmask += head_idx * mask_head_stride + simd_gid * mask_kv_seq_stride +
q_seq_idx * mask_q_seq_stride;
}
if (float_mask) {
fmask += head_idx * mask_head_stride + simd_gid * mask_kv_seq_stride +
q_seq_idx * mask_q_seq_stride;
}
@@ -77,7 +88,13 @@ template <typename T, int D, int V = D>
// For each key
for (int i = simd_gid; i < N; i += BN) {
bool use_key = true;
if (do_causal) {
use_key = i <= (N - int(tpg.y) + int(q_seq_idx));
} else if (bool_mask) {
use_key = bmask[0];
}
if (use_key) {
// Read the key
for (int j = 0; j < qk_per_thread; j++) {
k[j] = keys[j];
@@ -89,6 +106,9 @@ template <typename T, int D, int V = D>
score += q[j] * k[j];
}
score = simd_sum(score);
if (float_mask) {
score += max(Limits<U>::finite_min, static_cast<U>(fmask[0]));
}
// Update the accumulators
U new_max = max(max_score, score);
@@ -107,8 +127,11 @@ template <typename T, int D, int V = D>
// Move the pointers to the next kv
keys += inner_k_stride;
values += inner_v_stride;
if (bool_mask) {
bmask += BN * mask_kv_seq_stride;
}
if (float_mask) {
fmask += BN * mask_kv_seq_stride;
}
}
@@ -149,17 +172,21 @@ template <typename T, int D, int V = D>
device float* out [[buffer(3)]],
device float* sums [[buffer(4)]],
device float* maxs [[buffer(5)]],
const constant int& gqa_factor [[buffer(6)]],
const constant int& N [[buffer(7)]],
const constant size_t& k_head_stride [[buffer(8)]],
const constant size_t& k_seq_stride [[buffer(9)]],
const constant size_t& v_head_stride [[buffer(10)]],
const constant size_t& v_seq_stride [[buffer(11)]],
const constant float& scale [[buffer(12)]],
const device bool* bmask [[buffer(13), function_constant(bool_mask)]],
const device T* fmask [[buffer(14), function_constant(float_mask)]],
const constant int& mask_kv_seq_stride
[[buffer(15), function_constant(has_mask)]],
const constant int& mask_q_seq_stride
[[buffer(16), function_constant(has_mask)]],
const constant int& mask_head_stride
[[buffer(17), function_constant(has_mask)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 tpg [[threadgroups_per_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
@@ -197,8 +224,13 @@ template <typename T, int D, int V = D>
values += kv_head_idx * v_head_stride +
(block_idx * BN + simd_gid) * v_seq_stride + simd_lid * v_per_thread;
out += o_offset * blocks * V + block_idx * V + simd_lid * v_per_thread;
if (bool_mask) {
bmask += head_idx * mask_head_stride +
(block_idx * BN + simd_gid) * mask_kv_seq_stride +
q_seq_idx * mask_q_seq_stride;
}
if (float_mask) {
fmask += head_idx * mask_head_stride +
(block_idx * BN + simd_gid) * mask_kv_seq_stride +
q_seq_idx * mask_q_seq_stride;
}
@@ -218,7 +250,13 @@ template <typename T, int D, int V = D>
// For each key
for (int i = block_idx * BN + simd_gid; i < N; i += blocks * BN) {
bool use_key = true;
if (do_causal) {
use_key = i <= (N - int(tpg.y) + int(q_seq_idx));
} else if (bool_mask) {
use_key = bmask[0];
}
if (use_key) {
// Read the key
for (int i = 0; i < qk_per_thread; i++) {
k[i] = keys[i];
@@ -230,6 +268,9 @@ template <typename T, int D, int V = D>
score += q[i] * k[i];
}
score = simd_sum(score);
if (float_mask) {
score += fmask[0];
}
// Update the accumulators
U new_max = max(max_score, score);
@@ -248,8 +289,11 @@ template <typename T, int D, int V = D>
// Move the pointers to the next kv
keys += blocks * inner_k_stride;
values += blocks * inner_v_stride;
if (bool_mask) {
bmask += BN * blocks * mask_kv_seq_stride;
}
if (float_mask) {
fmask += BN * blocks * mask_kv_seq_stride;
}
}
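A tiny C++ sketch (illustrative only) of the causal key-selection predicate used in both kernels above: with N keys and Q queries in the block (Q corresponds to tpg.y), query q may attend to key i only when i <= N - Q + q, i.e. up to and including its own position.

```cpp
#include <cstdio>

// Mirrors: use_key = i <= (N - int(tpg.y) + int(q_seq_idx));
bool causal_use_key(int i, int N, int Q, int q) {
  return i <= (N - Q + q);
}

int main() {
  // Example: 5 keys, 2 trailing queries. Query 0 sees keys 0..3, query 1 sees 0..4.
  for (int q = 0; q < 2; ++q) {
    for (int i = 0; i < 5; ++i) {
      std::printf("q=%d key=%d %s\n", q, i, causal_use_key(i, 5, 2, q) ? "use" : "skip");
    }
  }
  return 0;
}
```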


@@ -9,47 +9,13 @@ using namespace metal;
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/softmax.h"
#define instantiate_softmax(name, itype) \
template [[host_name("block_softmax_" #name)]] [[kernel]] void \
softmax_single_row<itype>( \
const device itype* in, \
device itype* out, \
constant int& axis_size, \
uint gid [[thread_position_in_grid]], \
uint _lid [[thread_position_in_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]); \
template [[host_name("looped_softmax_" #name)]] [[kernel]] void \
softmax_looped<itype>( \
const device itype* in, \
device itype* out, \
constant int& axis_size, \
uint gid [[threadgroup_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint lsize [[threads_per_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
#define instantiate_softmax(name, itype) \
instantiate_kernel("block_softmax_" #name, softmax_single_row, itype) \
instantiate_kernel("looped_softmax_" #name, softmax_looped, itype)
#define instantiate_softmax_precise(name, itype) \
template [[host_name("block_softmax_precise_" #name)]] [[kernel]] void \
softmax_single_row<itype, float>( \
const device itype* in, \
device itype* out, \
constant int& axis_size, \
uint gid [[thread_position_in_grid]], \
uint _lid [[thread_position_in_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]); \
template [[host_name("looped_softmax_precise_" #name)]] [[kernel]] void \
softmax_looped<itype, float>( \
const device itype* in, \
device itype* out, \
constant int& axis_size, \
uint gid [[threadgroup_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint lsize [[threads_per_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
#define instantiate_softmax_precise(name, itype) \
instantiate_kernel("block_softmax_precise_" #name, softmax_single_row, itype, float) \
instantiate_kernel("looped_softmax_precise_" #name, softmax_looped, itype, float)
instantiate_softmax(float32, float)
instantiate_softmax(float16, half)


@@ -229,7 +229,7 @@ template <
// Init to -Inf
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kRowsPT; ++i) {
max_score[i] = Limits<AccumType>::finite_min;
}
int kb_lim = params->NK;
@@ -237,6 +237,7 @@ template <
if (do_causal) {
int q_max = (tid.x + 1) * BQ + params->qL_off;
kb_lim = (q_max + BK - 1) / BK;
kb_lim = min(params->NK, kb_lim);
}
// Loop over KV seq length
@@ -272,7 +273,7 @@ template <
if (!align_K && kb == (params->NK_aligned)) {
using stile_t = decltype(Stile);
using selem_t = typename stile_t::elem_type;
constexpr auto neg_inf = Limits<selem_t>::finite_min;
STEEL_PRAGMA_UNROLL
for (short i = 0; i < stile_t::kTileRows; i++) {
@@ -290,10 +291,10 @@ template <
}
// Mask out if causal
if (do_causal && kb >= (kb_lim - ((BQ + BK - 1) / BK) - int(!align_K))) {
using stile_t = decltype(Stile);
using selem_t = typename stile_t::elem_type;
constexpr auto neg_inf = Limits<selem_t>::finite_min;
STEEL_PRAGMA_UNROLL
for (short i = 0; i < stile_t::kTileRows; i++) {
@@ -316,7 +317,7 @@ template <
if (has_mask) {
using stile_t = decltype(Stile);
using selem_t = typename stile_t::elem_type;
constexpr auto neg_inf = Limits<selem_t>::finite_min;
constexpr bool is_bool = is_same_v<MaskType, bool>;
using melem_t = typename metal::conditional_t<is_bool, bool, selem_t>;


@@ -15,10 +15,6 @@ constant bool align_M [[function_constant(200)]];
constant bool align_N [[function_constant(201)]];
constant bool align_K [[function_constant(202)]];
constant bool do_gather [[function_constant(300)]];
constant bool gather_bias = do_gather && use_out_source;
// clang-format off
template <
typename T,
@@ -39,12 +35,6 @@ template <
const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]],
const constant int* batch_shape [[buffer(6)]],
const constant int64_t* batch_strides [[buffer(7)]],
const constant uint32_t* lhs_indices [[buffer(10), function_constant(do_gather)]],
const constant uint32_t* rhs_indices [[buffer(11), function_constant(do_gather)]],
const constant uint32_t* C_indices [[buffer(12), function_constant(gather_bias)]],
const constant int* operand_shape [[buffer(13), function_constant(do_gather)]],
const constant int64_t* operand_strides [[buffer(14), function_constant(do_gather)]],
const constant packed_int3& operand_batch_ndim [[buffer(15), function_constant(do_gather)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
@@ -81,84 +71,26 @@ template <
}
// Adjust for batch
if (has_batch) {
const constant auto* A_bstrides = batch_strides;
const constant auto* B_bstrides = batch_strides + params->batch_ndim;
ulong2 batch_offsets = elem_to_loc_broadcast(
tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim);
A += batch_offsets.x;
B += batch_offsets.y;
if (use_out_source) {
const constant auto* C_bstrides = B_bstrides + params->batch_ndim;
C += elem_to_loc(tid.z, batch_shape, C_bstrides, params->batch_ndim);
}
} else {
A += params->batch_stride_a * tid.z;
B += params->batch_stride_b * tid.z;
if (use_out_source) {
C += addmm_params->batch_stride_c * tid.z;
}
}


@@ -0,0 +1,459 @@
// Copyright © 2024 Apple Inc.
using namespace mlx::steel;
constant bool has_batch [[function_constant(10)]];
constant bool align_M [[function_constant(200)]];
constant bool align_N [[function_constant(201)]];
constant bool align_K [[function_constant(202)]];
template <
typename T,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_a,
bool transpose_b,
typename AccumType = float>
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void gather_mm_rhs(
const device T* A [[buffer(0)]],
const device T* B [[buffer(1)]],
const device uint32_t* rhs_indices [[buffer(2)]],
device T* C [[buffer(3)]],
const constant GEMMParams* params [[buffer(4)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]]) {
using gemm_kernel = GEMMKernel<
T,
T,
BM,
BN,
BK,
WM,
WN,
transpose_a,
transpose_b,
true,
true,
AccumType>;
using loader_a_t = typename gemm_kernel::loader_a_t;
using loader_b_t = typename gemm_kernel::loader_b_t;
using mma_t = typename gemm_kernel::mma_t;
if (params->tiles_n <= static_cast<int>(tid.x) ||
params->tiles_m <= static_cast<int>(tid.y)) {
return;
}
// Prepare threadgroup memory
threadgroup T As[gemm_kernel::tgp_mem_size_a];
threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
// Find the block in A, B, C
const int c_row = tid.y * BM;
const int c_col = tid.x * BN;
const size_t c_row_long = size_t(c_row);
const size_t c_col_long = size_t(c_col);
// Prepare threadgroup bounds
const short tgp_bm = align_M ? BM : short(min(BM, params->M - c_row));
const short tgp_bn = align_N ? BN : short(min(BN, params->N - c_col));
A += transpose_a ? c_row_long : c_row_long * params->lda;
B += transpose_b ? c_col_long * params->ldb : c_col_long;
C += c_row_long * params->ldd + c_col_long;
// Do as many matmuls as necessary
uint32_t index;
short offset;
uint32_t index_next = rhs_indices[c_row];
short offset_next = 0;
int n = 0;
while (n < tgp_bm) {
n++;
offset = offset_next;
index = index_next;
offset_next = tgp_bm;
for (; n < tgp_bm; n++) {
if (rhs_indices[c_row + n] != index) {
offset_next = n;
index_next = rhs_indices[c_row + n];
break;
}
}
threadgroup_barrier(mem_flags::mem_none);
// Prepare threadgroup mma operation
thread mma_t mma_op(simd_group_id, simd_lane_id);
// Prepare threadgroup loading operations
thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
thread loader_b_t loader_b(
B + index * params->batch_stride_b,
params->ldb,
Bs,
simd_group_id,
simd_lane_id);
// Prepare iterations
const int gemm_k_iterations = params->gemm_k_iterations_aligned;
// Do unaligned K iterations first
if (!align_K) {
const int k_last = params->gemm_k_iterations_aligned * BK;
const int k_remain = params->K - k_last;
const size_t k_jump_a =
transpose_a ? params->lda * size_t(k_last) : size_t(k_last);
const size_t k_jump_b =
transpose_b ? size_t(k_last) : params->ldb * size_t(k_last);
// Move loader source ahead to end
loader_a.src += k_jump_a;
loader_b.src += k_jump_b;
// Load tile
const short2 tile_dims_A =
transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
const short2 tile_dims_B =
transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);
loader_a.load_safe(tile_dims_A);
loader_b.load_safe(tile_dims_B);
threadgroup_barrier(mem_flags::mem_threadgroup);
// Do matmul
mma_op.mma(As, Bs);
// Reset source back to start
loader_a.src -= k_jump_a;
loader_b.src -= k_jump_b;
}
// Matrix level aligned never check
if (align_M && align_N) {
for (int k = 0; k < gemm_k_iterations; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup
loader_a.load_unsafe();
loader_b.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
// Store results to device memory
if (offset_next - offset == BM) {
mma_op.store_result(C, params->ldd);
} else {
mma_op.store_result_slice(
C, params->ldd, short2(0, offset), short2(BN, offset_next));
}
} else {
const short lbk = 0;
// Tile aligned don't check
if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
lbk,
LoopAlignment<true, true, true>{});
if (offset_next - offset == BM) {
mma_op.store_result(C, params->ldd);
} else {
mma_op.store_result_slice(
C, params->ldd, short2(0, offset), short2(BN, offset_next));
}
}
// Tile partially aligned check rows
else if (align_N || tgp_bn == BN) {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
lbk,
LoopAlignment<false, true, true>{});
mma_op.store_result_slice(
C, params->ldd, short2(0, offset), short2(BN, offset_next));
}
// Tile partially aligned check cols
else if (align_M || tgp_bm == BM) {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
lbk,
LoopAlignment<true, false, true>{});
mma_op.store_result_slice(
C, params->ldd, short2(0, offset), short2(tgp_bn, offset_next));
}
// Nothing aligned so check both rows and cols
else {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
lbk,
LoopAlignment<false, false, true>{});
mma_op.store_result_slice(
C, params->ldd, short2(0, offset), short2(tgp_bn, offset_next));
}
}
}
}
template <
typename T,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_a,
bool transpose_b,
typename AccumType = float>
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void gather_mm(
const device T* A [[buffer(0)]],
const device T* B [[buffer(1)]],
const device uint32_t* lhs_indices [[buffer(2)]],
const device uint32_t* rhs_indices [[buffer(3)]],
device T* C [[buffer(4)]],
const constant GEMMParams* params [[buffer(5)]],
const constant int* indices_shape [[buffer(6)]],
const constant int64_t* lhs_strides [[buffer(7)]],
const constant int64_t* rhs_strides [[buffer(8)]],
const constant int& batch_ndim_a [[buffer(9)]],
const constant int* batch_shape_a [[buffer(10)]],
const constant int64_t* batch_strides_a [[buffer(11)]],
const constant int& batch_ndim_b [[buffer(12)]],
const constant int* batch_shape_b [[buffer(13)]],
const constant int64_t* batch_strides_b [[buffer(14)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]]) {
using gemm_kernel = GEMMKernel<
T,
T,
BM,
BN,
BK,
WM,
WN,
transpose_a,
transpose_b,
true,
true,
AccumType>;
using loader_a_t = typename gemm_kernel::loader_a_t;
using loader_b_t = typename gemm_kernel::loader_b_t;
using mma_t = typename gemm_kernel::mma_t;
if (params->tiles_n <= static_cast<int>(tid.x) ||
params->tiles_m <= static_cast<int>(tid.y)) {
return;
}
// Move A and B to the locations pointed by lhs_indices and rhs_indices.
uint32_t indx_A, indx_B;
if (has_batch) {
ulong2 indices_offsets = elem_to_loc_broadcast(
tid.z, indices_shape, lhs_strides, rhs_strides, params->batch_ndim);
indx_A = lhs_indices[indices_offsets.x];
indx_B = rhs_indices[indices_offsets.y];
} else {
indx_A = lhs_indices[params->batch_stride_a * tid.z];
indx_B = rhs_indices[params->batch_stride_b * tid.z];
}
A += elem_to_loc(indx_A, batch_shape_a, batch_strides_a, batch_ndim_a);
B += elem_to_loc(indx_B, batch_shape_b, batch_strides_b, batch_ndim_b);
C += params->batch_stride_d * tid.z;
// Prepare threadgroup memory
threadgroup T As[gemm_kernel::tgp_mem_size_a];
threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
// Just make sure everybody's finished with the indexing math above.
threadgroup_barrier(mem_flags::mem_none);
// Find block in A, B, C
const int c_row = tid.y * BM;
const int c_col = tid.x * BN;
const size_t c_row_long = size_t(c_row);
const size_t c_col_long = size_t(c_col);
A += transpose_a ? c_row_long : c_row_long * params->lda;
B += transpose_b ? c_col_long * params->ldb : c_col_long;
C += c_row_long * params->ldd + c_col_long;
// Prepare threadgroup mma operation
thread mma_t mma_op(simd_group_id, simd_lane_id);
// Prepare threadgroup loading operations
thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);
// Prepare threadgroup bounds
const short tgp_bm = align_M ? BM : short(min(BM, params->M - c_row));
const short tgp_bn = align_N ? BN : short(min(BN, params->N - c_col));
// Prepare iterations
int gemm_k_iterations = params->gemm_k_iterations_aligned;
// Do unaligned K iterations first
if (!align_K) {
const int k_last = params->gemm_k_iterations_aligned * BK;
const int k_remain = params->K - k_last;
const size_t k_jump_a =
transpose_a ? params->lda * size_t(k_last) : size_t(k_last);
const size_t k_jump_b =
transpose_b ? size_t(k_last) : params->ldb * size_t(k_last);
// Move loader source ahead to end
loader_a.src += k_jump_a;
loader_b.src += k_jump_b;
// Load tile
const short2 tile_dims_A =
transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
const short2 tile_dims_B =
transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);
loader_a.load_safe(tile_dims_A);
loader_b.load_safe(tile_dims_B);
threadgroup_barrier(mem_flags::mem_threadgroup);
// Do matmul
mma_op.mma(As, Bs);
// Reset source back to start
loader_a.src -= k_jump_a;
loader_b.src -= k_jump_b;
}
// Matrix level aligned never check
if (align_M && align_N) {
for (int k = 0; k < gemm_k_iterations; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup
loader_a.load_unsafe();
loader_b.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
// Store results to device memory
mma_op.store_result(C, params->ldd);
} else {
const short lbk = 0;
// Tile aligned don't check
if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
lbk,
LoopAlignment<true, true, true>{});
mma_op.store_result(C, params->ldd);
}
// Tile partially aligned check rows
else if (align_N || tgp_bn == BN) {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
lbk,
LoopAlignment<false, true, true>{});
mma_op.store_result_safe(C, params->ldd, short2(tgp_bn, tgp_bm));
}
// Tile partially aligned check cols
else if (align_M || tgp_bm == BM) {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
lbk,
LoopAlignment<true, false, true>{});
mma_op.store_result_safe(C, params->ldd, short2(tgp_bn, tgp_bm));
}
// Nothing aligned so check both rows and cols
else {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
lbk,
LoopAlignment<false, false, true>{});
mma_op.store_result_safe(C, params->ldd, short2(tgp_bn, tgp_bm));
}
}
}

View File

@@ -0,0 +1,59 @@
// Copyright © 2024 Apple Inc.
// clang-format off
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
#include "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h"
#define instantiate_gather_mm_rhs(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_kernel( \
"steel_gather_mm_rhs_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn \
"_bk" #bk "_wm" #wm "_wn" #wn, \
gather_mm_rhs, \
itype, \
bm, \
bn, \
bk, \
wm, \
wn, \
trans_a, \
trans_b, \
float)
#define instantiate_gather_mm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_kernel( \
"steel_gather_mm_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn \
"_bk" #bk "_wm" #wm "_wn" #wn, \
gather_mm, \
itype, \
bm, \
bn, \
bk, \
wm, \
wn, \
trans_a, \
trans_b, \
float)
#define instantiate_gather_mm_rhs_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gather_mm_rhs(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gather_mm_rhs(nt, false, true, iname, itype, oname, otype, bm, bn, bk, wm, wn)
#define instantiate_gather_mm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gather_mm(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gather_mm(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gather_mm(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gather_mm(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn)
#define instantiate_gather_mm_shapes_helper(iname, itype, oname, otype) \
instantiate_gather_mm_rhs_transpose_helper(iname, itype, oname, otype, 16, 64, 16, 1, 2) \
instantiate_gather_mm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) \
instantiate_gather_mm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 1, 2) \
instantiate_gather_mm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \
instantiate_gather_mm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 1, 2) \
instantiate_gather_mm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2)
// clang-format on
instantiate_gather_mm_shapes_helper(float16, half, float16, half);
instantiate_gather_mm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t);
instantiate_gather_mm_shapes_helper(float32, float, float32, float);
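For reference, one expansion of the helpers above, e.g. instantiate_gather_mm_rhs(nt, false, true, float16, half, float16, half, 16, 64, 16, 1, 2), should register a kernel named steel_gather_mm_rhs_nt_float16_float16_bm16_bn64_bk16_wm1_wn2, matching the base name the host-side dispatch code reassembles from the tile and type parameters.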

View File

@@ -142,6 +142,42 @@ struct BaseMMAFrag<T, 8, 8> {
}
}
template <
typename DstPtrType,
typename StrX,
typename StrY,
typename StartX,
typename StopX,
typename StartY,
typename StopY,
typename OffX,
typename OffY>
METAL_FUNC static constexpr void store_slice(
const thread frag_type& src,
DstPtrType dst,
StrX str_x,
StrY str_y,
StartX start_x,
StopX stop_x,
StartY start_y,
StopY stop_y,
OffX off_x = Int<0>{},
OffY off_y = Int<0>{}) {
using U = pointer_element_t<DstPtrType>;
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kElemRows; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < kElemCols; j++) {
if ((off_x + i) < stop_x && (off_x + i) >= start_x &&
(off_y + j) < stop_y && (off_y + j) >= start_y) {
dst[(off_x + i) * str_x + (off_y + j) * str_y] =
static_cast<U>(src[i * kElemCols + j]);
}
}
}
}
METAL_FUNC static constexpr void mma(
thread frag_type& D,
thread frag_type& A,
@@ -335,6 +371,31 @@ struct MMATile {
}
}
}
template <typename U, int w_x, int w_y>
METAL_FUNC void store_slice(
device U* dst,
const int ld,
const short2 start,
const short2 stop) const {
STEEL_PRAGMA_UNROLL
for (int i = 0; i < kTileRows; ++i) {
STEEL_PRAGMA_UNROLL
for (int j = 0; j < kTileCols; ++j) {
MMAFrag_t::store_slice(
frag_at(i, j),
dst,
ld,
Int<1>{},
start.y,
stop.y,
start.x,
stop.x,
(i * kFragRows) * w_x,
(j * kFragCols) * w_y);
}
}
}
};
template <typename T, typename U, int M, int N, int K>
@@ -474,6 +535,26 @@ struct BlockMMA {
Ctile.template store<U, WM, WN>(D, ldd);
}
METAL_FUNC void
store_result_slice(device U* D, const int ldd, short2 start, short2 stop) {
// Apply epilogue
STEEL_PRAGMA_UNROLL
for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) {
Ctile.elems()[i] = Epilogue::apply(Ctile.elems()[i]);
}
D += sm * ldd + sn;
start -= short2(sn, sm);
stop -= short2(sn, sm);
// TODO: Check the start as well
if (stop.y <= 0 || stop.x <= 0) {
return;
}
Ctile.template store_slice<U, WM, WN>(D, ldd, start, stop);
}
METAL_FUNC void
store_result_safe(device U* D, const int ldd, short2 dst_tile_dims) {
// Apply epilogue
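As a rough scalar illustration of the new slicing stores above (store_slice / store_result_slice): an element of the output tile at (row, col) is written only when it falls inside rows [start.y, stop.y) and columns [start.x, stop.x); the real code applies the same predicate per simdgroup fragment. Names below are illustrative only.
void store_tile_slice_reference(
    const float* tile,
    float* dst,
    int ld,
    int rows,
    int cols,
    int start_x,
    int stop_x,
    int start_y,
    int stop_y) {
  for (int r = 0; r < rows; ++r) {
    for (int c = 0; c < cols; ++c) {
      // Row range comes from the .y fields, column range from the .x fields.
      if (r >= start_y && r < stop_y && c >= start_x && c < stop_x) {
        dst[r * ld + c] = tile[r * cols + c];
      }
    }
  }
}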

View File

@@ -73,6 +73,9 @@ instantiate_unary_all_same(Conjugate, complex64, complex64_t)
instantiate_unary_all_same(Cos, complex64, complex64_t)
instantiate_unary_all_same(Cosh, complex64, complex64_t)
instantiate_unary_all_same(Exp, complex64, complex64_t)
instantiate_unary_all_same(Log, complex64, complex64_t)
instantiate_unary_all_same(Log2, complex64, complex64_t)
instantiate_unary_all_same(Log10, complex64, complex64_t)
instantiate_unary_all_same(Negative, complex64, complex64_t)
instantiate_unary_all_same(Sign, complex64, complex64_t)
instantiate_unary_all_same(Sin, complex64, complex64_t)

View File

@@ -257,6 +257,13 @@ struct Log {
T operator()(T x) {
return metal::precise::log(x);
};
template <>
complex64_t operator()(complex64_t x) {
auto r = metal::precise::log(Abs{}(x).real);
auto i = metal::precise::atan2(x.imag, x.real);
return {r, i};
};
};
struct Log2 {
@@ -264,6 +271,12 @@ struct Log2 {
T operator()(T x) {
return metal::precise::log2(x);
};
template <>
complex64_t operator()(complex64_t x) {
auto y = Log{}(x);
return {y.real / M_LN2_F, y.imag / M_LN2_F};
};
};
struct Log10 {
@@ -271,6 +284,12 @@ struct Log10 {
T operator()(T x) {
return metal::precise::log10(x);
};
template <>
complex64_t operator()(complex64_t x) {
auto y = Log{}(x);
return {y.real / M_LN10_F, y.imag / M_LN10_F};
};
};
struct Log1p {
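The complex overloads above follow the principal branch, log z = ln|z| + i * atan2(im, re), with Log2 and Log10 rescaling by 1/ln 2 and 1/ln 10. A small host-side sanity check of those identities, assuming nothing beyond the C++ standard library:
#include <cmath>
#include <complex>
#include <cstdio>
int main() {
  std::complex<float> z(-1.0f, 2.0f);
  // Principal-branch log, written the same way as the Metal operator.
  std::complex<float> logz(std::log(std::abs(z)), std::atan2(z.imag(), z.real()));
  std::complex<float> log2z = logz / std::log(2.0f);   // mirrors y / M_LN2_F
  std::complex<float> log10z = logz / std::log(10.0f); // mirrors y / M_LN10_F
  std::printf("log(z)   = (%f, %f)\n", logz.real(), logz.imag());
  std::printf("log2(z)  = (%f, %f)\n", log2z.real(), log2z.imag());
  std::printf("log10(z) = (%f, %f)\n", log10z.real(), log10z.imag());
  return 0;
}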

View File

@@ -0,0 +1,96 @@
// Copyright © 2023-2024 Apple Inc.
#include <algorithm>
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"
namespace mlx::core {
constexpr int LOGSUMEXP_LOOPED_LIMIT = 4096;
void LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
if (!issubdtype(out.dtype(), floating)) {
throw std::runtime_error(
"[logsumexp] Does not support non-floating point types.");
}
auto& s = stream();
auto& d = metal::device(s.device);
// Make sure that the last dimension is contiguous
auto ensure_contiguous = [&s, &d](const array& x) {
if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
return x;
} else {
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
d.add_temporary(x_copy, s.index);
return x_copy;
}
};
auto in = ensure_contiguous(inputs[0]);
if (in.flags().row_contiguous) {
out.set_data(allocator::malloc(out.nbytes()));
} else {
auto n = in.shape(-1);
auto flags = in.flags();
auto strides = in.strides();
for (auto& s : strides) {
s /= n;
}
bool col_contig = strides[0] == 1;
for (int i = 1; col_contig && i < strides.size(); ++i) {
col_contig &=
(out.shape(i) == 1 || strides[i - 1] == out.shape(i) * strides[i]);
}
flags.col_contiguous = col_contig;
out.set_data(
allocator::malloc(in.nbytes() / n),
in.data_size() / n,
std::move(strides),
flags);
}
int axis_size = in.shape().back();
int n_rows = in.data_size() / axis_size;
const int simd_size = 32;
const int n_reads = 4;
const int looped_limit = LOGSUMEXP_LOOPED_LIMIT;
std::string kernel_name = (axis_size > looped_limit) ? "looped_" : "block_";
kernel_name += "logsumexp_";
kernel_name += type_to_name(out);
auto kernel = get_logsumexp_kernel(d, kernel_name, out);
auto& compute_encoder = d.get_command_encoder(s.index);
{
MTL::Size grid_dims, group_dims;
if (axis_size <= looped_limit) {
size_t threadgroup_needed = (axis_size + n_reads - 1) / n_reads;
size_t simds_needed = (threadgroup_needed + simd_size - 1) / simd_size;
size_t threadgroup_size = simd_size * simds_needed;
assert(threadgroup_size <= kernel->maxTotalThreadsPerThreadgroup());
size_t n_threads = n_rows * threadgroup_size;
grid_dims = MTL::Size(n_threads, 1, 1);
group_dims = MTL::Size(threadgroup_size, 1, 1);
} else {
size_t threadgroup_size = kernel->maxTotalThreadsPerThreadgroup();
size_t n_threads = n_rows * threadgroup_size;
grid_dims = MTL::Size(n_threads, 1, 1);
group_dims = MTL::Size(threadgroup_size, 1, 1);
}
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(out, 1);
compute_encoder.set_bytes(axis_size, 2);
compute_encoder.dispatch_threads(grid_dims, group_dims);
}
}
} // namespace mlx::core
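For context, the per-row reduction the kernels dispatched above are expected to perform is the usual max-shifted logsumexp, which keeps exp from overflowing. A minimal CPU sketch of that reduction (illustrative name, single row):
#include <algorithm>
#include <cmath>
#include <vector>
float logsumexp_row_reference(const std::vector<float>& x) {
  // logsumexp(x) = m + log(sum_i exp(x_i - m)) with m = max_i x_i.
  float m = *std::max_element(x.begin(), x.end());
  double acc = 0.0;
  for (float v : x) {
    acc += std::exp(double(v) - double(m));
  }
  return m + float(std::log(acc));
}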

View File

@@ -5,6 +5,7 @@
#include <numeric>
#include <sstream>
#include "mlx/backend/common/broadcasting.h"
#include "mlx/backend/common/utils.h" #include "mlx/backend/common/utils.h"
#include "mlx/backend/metal/copy.h" #include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h" #include "mlx/backend/metal/device.h"
@@ -102,6 +103,47 @@ std::tuple<bool, int64_t, array> check_transpose(
} }
}; };
inline array
ensure_row_contiguous(const array& x, metal::Device& d, const Stream& s) {
if (!x.flags().row_contiguous) {
array x_copy(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
d.add_temporary(x_copy, s.index);
return x_copy;
} else {
return x;
}
}
inline std::tuple<bool, int64_t, array>
ensure_batch_contiguous(const array& x, metal::Device& d, const Stream& s) {
if (x.flags().row_contiguous) {
return std::make_tuple(false, x.strides()[x.ndim() - 2], x);
}
bool rc = true;
for (int i = 0; i < x.ndim() - 3; i++) {
rc &= x.strides()[i + 1] * x.shape(i) == x.strides()[i];
}
if (rc) {
auto stx = x.strides()[x.ndim() - 2];
auto sty = x.strides()[x.ndim() - 1];
auto K = x.shape(-2);
auto N = x.shape(-1);
if (sty == 1 && (N != 1 || stx == N)) {
return std::make_tuple(false, stx, x);
}
if (stx == 1 && (N != 1 || sty == K)) {
return std::make_tuple(true, sty, x);
}
}
array x_copy(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
d.add_temporary(x_copy, s.index);
return std::make_tuple(false, x_copy.strides()[x_copy.ndim() - 2], x_copy);
}
} // namespace
///////////////////////////////////////////////////////////////////////////////
@@ -230,7 +272,6 @@ void steel_matmul_regular(
const bool align_M = (M % bm) == 0;
const bool align_N = (N % bn) == 0;
const bool align_K = (K % bk) == 0;
const bool do_gather = false;
metal::MTLFCList func_consts = {
{&has_batch, MTL::DataType::DataTypeBool, 10},
@@ -239,7 +280,6 @@ void steel_matmul_regular(
{&align_M, MTL::DataType::DataTypeBool, 200},
{&align_N, MTL::DataType::DataTypeBool, 201},
{&align_K, MTL::DataType::DataTypeBool, 202},
{&do_gather, MTL::DataType::DataTypeBool, 300},
};
// clang-format off
@@ -248,8 +288,7 @@ void steel_matmul_regular(
<< "_do_axpby_" << (do_axpby ? 't' : 'n') << "_do_axpby_" << (do_axpby ? 't' : 'n')
<< "_align_M_" << (align_M ? 't' : 'n') << "_align_M_" << (align_M ? 't' : 'n')
<< "_align_N_" << (align_N ? 't' : 'n') << "_align_N_" << (align_N ? 't' : 'n')
<< "_align_K_" << (align_K ? 't' : 'n') << "_align_K_" << (align_K ? 't' : 'n'); // clang-format on
<< "_do_gather_" << (do_gather ? 't' : 'n'); // clang-format on
std::string hash_name = kname.str();
@@ -382,7 +421,7 @@ void steel_matmul(
int split_k_partition_size = gemm_k_iterations * bk;
array C_split({split_k_partitions, M, N}, float32, nullptr, {});
C_split.set_data(allocator::malloc_or_wait(C_split.nbytes())); C_split.set_data(allocator::malloc(C_split.nbytes()));
copies.push_back(C_split);
bool mn_aligned = M % bm == 0 && N % bn == 0;
@@ -513,7 +552,7 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
return;
}
out.set_data(allocator::malloc_or_wait(out.nbytes())); out.set_data(allocator::malloc(out.nbytes()));
/////////////////////////////////////////////////////////////////////////////
// Init checks and prep
@@ -677,7 +716,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
throw std::runtime_error(
"[matmul] Does not yet support non-floating point types.");
}
out.set_data(allocator::malloc_or_wait(out.nbytes())); out.set_data(allocator::malloc(out.nbytes()));
auto& s = stream();
auto& d = metal::device(s.device);
@@ -860,7 +899,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
int split_k_partition_size = gemm_k_iterations * bk;
array C_split({split_k_partitions, M, N}, float32, nullptr, {});
C_split.set_data(allocator::malloc_or_wait(C_split.nbytes())); C_split.set_data(allocator::malloc(C_split.nbytes()));
copies.push_back(C_split);
bool mn_aligned = M % bm == 0 && N % bn == 0;
@@ -975,7 +1014,6 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
const bool align_M = (M % bm) == 0;
const bool align_N = (N % bn) == 0;
const bool align_K = (K % bk) == 0;
const bool do_gather = false;
metal::MTLFCList func_consts = {
{&has_batch, MTL::DataType::DataTypeBool, 10},
@@ -984,7 +1022,6 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
{&align_M, MTL::DataType::DataTypeBool, 200},
{&align_N, MTL::DataType::DataTypeBool, 201},
{&align_K, MTL::DataType::DataTypeBool, 202},
{&do_gather, MTL::DataType::DataTypeBool, 300},
};
// clang-format off
@@ -993,8 +1030,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
<< "_do_axpby_" << (do_axpby ? 't' : 'n') << "_do_axpby_" << (do_axpby ? 't' : 'n')
<< "_align_M_" << (align_M ? 't' : 'n') << "_align_M_" << (align_M ? 't' : 'n')
<< "_align_N_" << (align_N ? 't' : 'n') << "_align_N_" << (align_N ? 't' : 'n')
<< "_align_K_" << (align_K ? 't' : 'n') << "_align_K_" << (align_K ? 't' : 'n'); // clang-format on
<< "_do_gather_" << (do_gather ? 't' : 'n'); // clang-format on
std::string hash_name = kname.str();
@@ -1096,7 +1132,7 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
return;
}
out.set_data(allocator::malloc_or_wait(out.nbytes())); out.set_data(allocator::malloc(out.nbytes()));
/////////////////////////////////////////////////////////////////////////////
// Init checks and prep
@@ -1464,267 +1500,337 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
d.add_temporaries(std::move(copies), s.index);
}
void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) { void gather_mm_rhs(
using namespace mlx::steel; const array& a_,
// assert(inputs.size() == 2); const array& b_,
if (!issubdtype(out.dtype(), floating)) { const array& indices_,
throw std::runtime_error( array& out,
"[GatherMM] Does not yet support non-floating point types."); metal::Device& d,
} const Stream& s) {
auto& s = stream(); array indices = ensure_row_contiguous(indices_, d, s);
auto& d = metal::device(s.device); auto [transpose_b, ldb, b] = ensure_batch_contiguous(b_, d, s);
auto& a_pre = inputs[0]; // Broadcast a with indices. If we are here that means lhs_indices were not
auto& b_pre = inputs[1]; // provided so the lhs_indices are implied to be the shape of a broadcasted
// Return 0s if either input is empty // with rhs_indices. We need only broadcast a and copy it as if applying the
if (a_pre.size() == 0 || b_pre.size() == 0) { // lhs_indices.
array zero = array(0, a_pre.dtype()); auto broadcast_with_indices = [&d, &s, &indices](const array& x) {
fill_gpu(zero, out, s); if (x.size() / x.shape(-2) / x.shape(-1) == indices.size()) {
d.add_temporary(std::move(zero), s.index); return ensure_row_contiguous(x, d, s);
return; }
}
out.set_data(allocator::malloc_or_wait(out.nbytes())); auto x_shape = indices.shape();
x_shape.push_back(x.shape(-2));
x_shape.push_back(x.shape(-1));
array new_x(std::move(x_shape), x.dtype(), nullptr, {});
broadcast(x, new_x);
return ensure_row_contiguous(new_x, d, s);
};
array a = broadcast_with_indices(a_);
///////////////////////////////////////////////////////////////////////////// // Extract the matmul shapes
// Init checks and prep int K = a.shape(-1);
int M = a.size() / K;
int N = b.shape(-1);
int lda = a.strides()[a.ndim() - 2]; // should be K
int M = a_pre.shape(-2); // Define the dispatch blocks
int N = b_pre.shape(-1); int bm = 16, bn = 64, bk = 16;
int K = a_pre.shape(-1); int wm = 1, wn = 2;
// Keep a vector with copies to be cleared in the completed buffer to release const bool align_M = (M % bm) == 0;
// the arrays const bool align_N = (N % bn) == 0;
std::vector<array> copies; const bool align_K = (K % bk) == 0;
auto [transpose_a, a_cols, a] = check_transpose(copies, s, a_pre, M == 1);
auto [transpose_b, b_cols, b] = check_transpose(copies, s, b_pre, N == 1);
int lda = a_cols; // Define the kernel name
int ldb = b_cols; std::string base_name;
base_name.reserve(64);
concatenate(
base_name,
"steel_gather_mm_rhs_n",
transpose_b ? 't' : 'n',
'_',
type_to_name(a),
'_',
type_to_name(out),
"_bm",
bm,
"_bn",
bn,
"_bk",
bk,
"_wm",
wm,
"_wn",
wn);
///////////////////////////////////////////////////////////////////////////// metal::MTLFCList func_consts = {
// Check and collapse batch dimensions {&align_M, MTL::DataType::DataTypeBool, 200},
{&align_N, MTL::DataType::DataTypeBool, 201},
auto get_batch_dims = [](const auto& v) { {&align_K, MTL::DataType::DataTypeBool, 202},
return decltype(v){v.begin(), v.end() - 2};
}; };
auto& lhs_indices = inputs[2]; // And the kernel hash that includes the function constants
auto& rhs_indices = inputs[3]; std::string hash_name;
hash_name.reserve(128);
concatenate(
hash_name,
base_name,
"_align_M_",
align_M ? 't' : 'n',
"_align_N_",
align_N ? 't' : 'n',
"_align_K_",
align_K ? 't' : 'n');
Shape batch_shape = get_batch_dims(out.shape()); // Get and set the kernel
Strides batch_strides; auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = get_steel_gemm_gather_kernel(
d,
base_name,
hash_name,
func_consts,
out,
false,
transpose_b,
bm,
bn,
bk,
wm,
wn,
true);
compute_encoder.set_compute_pipeline_state(kernel);
batch_strides.insert( // Prepare the matmul params
batch_strides.end(), auto batch_stride_b = b.ndim() > 2 ? b.strides()[b.ndim() - 3] : b.size();
lhs_indices.strides().begin(), steel::GEMMParams params{
lhs_indices.strides().end()); /* const int M = */ M,
auto lhs_indices_str = batch_strides.empty() ? 0 : batch_strides.back(); /* const int N = */ N,
/* const int K = */ K,
/* const int lda = */ lda,
/* const int ldb = */ static_cast<int>(ldb),
/* const int ldd = */ N,
/* const int tiles_n = */ (N + bn - 1) / bn,
/* const int tiles_m = */ (M + bm - 1) / bm,
/* const int64_t batch_stride_a = */ 0,
/* const int64_t batch_stride_b = */ static_cast<int64_t>(batch_stride_b),
/* const int64_t batch_stride_d = */ 0,
/* const int swizzle_log = */ 0,
/* const int gemm_k_iterations_aligned = */ (K / bk),
/* const int batch_ndim = */ 0};
batch_strides.insert( // Prepare the grid
batch_strides.end(), MTL::Size group_dims = MTL::Size(32, wn, wm);
rhs_indices.strides().begin(), MTL::Size grid_dims = MTL::Size(params.tiles_n, params.tiles_m, 1);
rhs_indices.strides().end());
auto rhs_indices_str = batch_strides.empty() ? 0 : batch_strides.back();
int batch_ndim = batch_shape.size(); // Launch kernel
compute_encoder.set_input_array(a, 0);
compute_encoder.set_input_array(b, 1);
compute_encoder.set_input_array(indices, 2);
compute_encoder.set_output_array(out, 3);
compute_encoder.set_bytes(params, 4);
if (batch_ndim == 0) { compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
batch_shape = {1}; }
batch_strides = {0};
}
int batch_ndim_A = a.ndim() - 2; void gather_mv(
int batch_ndim_B = b.ndim() - 2; const array& mat_,
std::vector<int> operand_batch_ndim = {batch_ndim_A, batch_ndim_B}; const array& vec_,
const array& mat_indices_,
const array& vec_indices_,
array& out,
int N,
int K,
bool is_mv,
metal::Device& d,
const Stream& s) {
// Copy if needed
std::vector<array> copies;
auto [transpose_mat, mat_cols, mat] =
check_transpose(copies, s, mat_, N == 1);
auto [transpose_vec, vec_cols, vec] = check_transpose(copies, s, vec_, true);
d.add_temporaries(std::move(copies), s.index);
Shape batch_shape_A = get_batch_dims(a.shape()); // If we are doing vector matrix instead of matrix vector we need to flip the
Strides batch_strides_A = get_batch_dims(a.strides()); // matrix transposition. Basically m @ v = v @ m.T assuming that v is treated
Shape batch_shape_B = get_batch_dims(b.shape()); // as a one dimensional array.
Strides batch_strides_B = get_batch_dims(b.strides()); transpose_mat = (!is_mv) ^ transpose_mat;
if (batch_ndim_A == 0) { // Define some shapes
batch_shape_A = {1}; int in_vector_len = K;
batch_strides_A = {0}; int out_vector_len = N;
} int mat_ld = mat_cols;
if (batch_ndim_B == 0) { int batch_size_out = out.size() / N;
batch_shape_B = {1}; int batch_ndim = out.ndim() - 2;
batch_strides_B = {0}; int batch_ndim_mat = mat.ndim() - 2;
} int batch_ndim_vec = vec.ndim() - 2;
Strides index_strides = vec_indices_.strides();
index_strides.insert(
index_strides.end(),
mat_indices_.strides().begin(),
mat_indices_.strides().end());
auto matrix_stride_out = static_cast<int64_t>(M) * N; // Determine dispatch kernel
auto batch_size_out = out.size() / matrix_stride_out; int tm = 4, tn = 4;
int sm = 1, sn = 32;
///////////////////////////////////////////////////////////////////////////// int bm = 1, bn = 1;
// Gemv specialization int n_out_per_tgp;
std::ostringstream kname;
// Route to gemv if needed
if (std::min(M, N) == 1) {
// Collect problem info
bool is_b_matrix = N != 1;
auto& mat = is_b_matrix ? b : a;
auto& vec = is_b_matrix ? a : b;
bool transpose_mat = is_b_matrix ? !transpose_b : transpose_a;
int in_vector_len = K;
int out_vector_len = is_b_matrix ? N : M;
int mat_cols = transpose_mat ? out_vector_len : in_vector_len;
int mat_rows = transpose_mat ? in_vector_len : out_vector_len;
int mat_ld = is_b_matrix ? b_cols : a_cols;
auto batch_strides_mat = is_b_matrix ? batch_strides_B : batch_strides_A;
auto batch_strides_vec = is_b_matrix ? batch_strides_A : batch_strides_B;
auto batch_shape_mat = is_b_matrix ? batch_shape_B : batch_shape_A;
auto batch_shape_vec = is_b_matrix ? batch_shape_A : batch_shape_B;
if (!is_b_matrix) {
batch_strides = rhs_indices.strides();
batch_strides.insert(
batch_strides.end(),
lhs_indices.strides().begin(),
lhs_indices.strides().end());
}
int batch_ndim = batch_shape.size();
// Determine dispatch kernel
int tm = 4, tn = 4;
int sm = 1, sn = 32;
int bm = 1, bn = 1;
int n_out_per_tgp;
std::ostringstream kname;
if (transpose_mat) {
if (in_vector_len >= 8192 && out_vector_len >= 2048) {
sm = 4;
sn = 8;
} else {
sm = 8;
sn = 4;
}
if (out_vector_len >= 2048) {
bn = 16;
} else if (out_vector_len >= 512) {
bn = 4;
} else {
bn = 2;
}
// Specialized kernel for very small outputs
tn = out_vector_len < tn ? 1 : tn;
n_out_per_tgp = bn * sn * tn;
kname << "gemv_t_gather_" << type_to_name(out);
if (transpose_mat) {
if (in_vector_len >= 8192 && out_vector_len >= 2048) {
sm = 4;
sn = 8;
} else { } else {
bm = out_vector_len >= 4096 ? 8 : 4; sm = 8;
sn = 32; sn = 4;
// Specialized kernel for very small outputs
tm = out_vector_len < tm ? 1 : tm;
n_out_per_tgp = bm * sm * tm;
kname << "gemv_gather_" << type_to_name(out);
} }
kname << "_bm" << bm << "_bn" << bn << "_sm" << sm << "_sn" << sn << "_tm" if (out_vector_len >= 2048) {
<< tm << "_tn" << tn; bn = 16;
} else if (out_vector_len >= 512) {
bn = 4;
} else {
bn = 2;
}
// Encode and dispatch kernel // Specialized kernel for very small outputs
auto& compute_encoder = d.get_command_encoder(s.index); tn = out_vector_len < tn ? 1 : tn;
auto kernel = d.get_kernel(kname.str());
compute_encoder.set_compute_pipeline_state(kernel);
int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp; n_out_per_tgp = bn * sn * tn;
MTL::Size group_dims = MTL::Size(32, bn, bm); kname << "gemv_t_gather_" << type_to_name(out);
MTL::Size grid_dims = MTL::Size(n_tgp, 1, batch_size_out);
compute_encoder.set_input_array(mat, 0); } else {
compute_encoder.set_input_array(vec, 1); bm = out_vector_len >= 4096 ? 8 : 4;
compute_encoder.set_output_array(out, 3); sn = 32;
compute_encoder.set_bytes(in_vector_len, 4); // Specialized kernel for very small outputs
compute_encoder.set_bytes(out_vector_len, 5); tm = out_vector_len < tm ? 1 : tm;
compute_encoder.set_bytes(mat_ld, 6);
compute_encoder.set_bytes(batch_ndim, 9); n_out_per_tgp = bm * sm * tm;
compute_encoder.set_vector_bytes(batch_shape, 10); kname << "gemv_gather_" << type_to_name(out);
compute_encoder.set_vector_bytes(batch_strides, 11);
int batch_ndim_vec = batch_shape_vec.size();
compute_encoder.set_bytes(batch_ndim_vec, 12);
compute_encoder.set_vector_bytes(batch_shape_vec, 13);
compute_encoder.set_vector_bytes(batch_strides_vec, 14);
int batch_ndim_mat = batch_shape_mat.size();
compute_encoder.set_bytes(batch_ndim_mat, 15);
compute_encoder.set_vector_bytes(batch_shape_mat, 16);
compute_encoder.set_vector_bytes(batch_strides_mat, 17);
compute_encoder.set_input_array(lhs_indices, 18 + int(!is_b_matrix));
compute_encoder.set_input_array(rhs_indices, 18 + int(is_b_matrix));
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
d.add_temporaries(std::move(copies), s.index);
return;
} }
///////////////////////////////////////////////////////////////////////////// kname << "_bm" << bm << "_bn" << bn << "_sm" << sm << "_sn" << sn << "_tm"
// Regular kernel dispatch << tm << "_tn" << tn;
// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder.set_compute_pipeline_state(kernel);
int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp;
MTL::Size group_dims = MTL::Size(32, bn, bm);
MTL::Size grid_dims = MTL::Size(n_tgp, 1, batch_size_out);
compute_encoder.set_input_array(mat, 0);
compute_encoder.set_input_array(vec, 1);
compute_encoder.set_output_array(out, 3);
compute_encoder.set_bytes(in_vector_len, 4);
compute_encoder.set_bytes(out_vector_len, 5);
compute_encoder.set_bytes(mat_ld, 6);
compute_encoder.set_bytes(batch_ndim, 9);
compute_encoder.set_vector_bytes(out.shape(), 10);
compute_encoder.set_vector_bytes(index_strides, 11);
compute_encoder.set_bytes(batch_ndim_vec, 12);
compute_encoder.set_vector_bytes(vec.shape(), 13);
compute_encoder.set_vector_bytes(vec.strides(), 14);
compute_encoder.set_bytes(batch_ndim_mat, 15);
compute_encoder.set_vector_bytes(mat.shape(), 16);
compute_encoder.set_vector_bytes(mat.strides(), 17);
compute_encoder.set_input_array(vec_indices_, 18);
compute_encoder.set_input_array(mat_indices_, 19);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
void gather_mm(
const array& a_,
const array& b_,
const array& lhs_indices,
const array& rhs_indices,
array& out,
int M,
int N,
int K,
metal::Device& d,
const Stream& s) {
// Copy if needed
std::vector<array> copies;
auto [transpose_a, lda, a] = check_transpose(copies, s, a_, false);
auto [transpose_b, ldb, b] = check_transpose(copies, s, b_, false);
d.add_temporaries(std::move(copies), s.index);
// Determine dispatch kernel
int bm = 64, bn = 64, bk = 16;
int wm = 2, wn = 2;
size_t batch_size_out = out.size() / M / N;
int batch_ndim = out.ndim() - 2;
int batch_ndim_a = a.ndim() - 2;
int batch_ndim_b = b.ndim() - 2;
char devc = d.get_architecture().back();
GEMM_TPARAM_MACRO(devc)
// Prepare kernel name
std::ostringstream kname;
kname << "steel_gemm_fused_" << (transpose_a ? 't' : 'n')
<< (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_"
<< type_to_name(out) << "_bm" << bm << "_bn" << bn << "_bk" << bk
<< "_wm" << wm << "_wn" << wn;
std::string base_name = kname.str();
const bool has_batch = batch_ndim > 1;
const bool use_out_source = false;
const bool do_axpby = false;
const bool align_M = (M % bm) == 0;
const bool align_N = (N % bn) == 0;
const bool align_K = (K % bk) == 0;
const bool do_gather = true;
// Define the kernel name
std::string base_name;
base_name.reserve(128);
concatenate(
base_name,
"steel_gather_mm_",
transpose_a ? 't' : 'n',
transpose_b ? 't' : 'n',
"_",
type_to_name(a),
"_",
type_to_name(out),
"_bm",
bm,
"_bn",
bn,
"_bk",
bk,
"_wm",
wm,
"_wn",
wn);
metal::MTLFCList func_consts = {
{&has_batch, MTL::DataType::DataTypeBool, 10},
{&use_out_source, MTL::DataType::DataTypeBool, 100},
{&do_axpby, MTL::DataType::DataTypeBool, 110},
{&align_M, MTL::DataType::DataTypeBool, 200},
{&align_N, MTL::DataType::DataTypeBool, 201},
{&align_K, MTL::DataType::DataTypeBool, 202},
{&do_gather, MTL::DataType::DataTypeBool, 300},
};
// clang-format off // And the kernel hash that includes the function constants
kname << "_has_batch_" << (has_batch ? 't' : 'n') std::string hash_name;
<< "_use_out_source_" << (use_out_source ? 't' : 'n') hash_name.reserve(128);
<< "_do_axpby_" << (do_axpby ? 't' : 'n') concatenate(
<< "_align_M_" << (align_M ? 't' : 'n') hash_name,
<< "_align_N_" << (align_N ? 't' : 'n') base_name,
<< "_align_K_" << (align_K ? 't' : 'n') "_has_batch_",
<< "_do_gather_" << (do_gather ? 't' : 'n'); // clang-format on has_batch ? 't' : 'n',
"_align_M_",
align_M ? 't' : 'n',
"_align_N_",
align_N ? 't' : 'n',
"_align_K_",
align_K ? 't' : 'n');
std::string hash_name = kname.str(); // Get and set the kernel
// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = get_steel_gemm_fused_kernel( auto kernel = get_steel_gemm_gather_kernel(
d,
base_name,
hash_name,
@@ -1736,72 +1842,97 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
bn,
bk,
wm,
wn); wn,
false);
compute_encoder.set_compute_pipeline_state(kernel);
// Use problem size to determine threadblock swizzle // Prepare the matmul params
int tn = (N + bn - 1) / bn; steel::GEMMParams params{
int tm = (M + bm - 1) / bm;
// TODO: Explore device-based tuning for swizzle
int swizzle_log = 0; // tm >= 6 ? 3 : (tm <= 3 ? 0 : 2);
// Prepare steel matmul params
GEMMParams params{
/* const int M = */ M,
/* const int N = */ N,
/* const int K = */ K,
/* const int lda = */ lda, /* const int lda = */ static_cast<int>(lda),
/* const int ldb = */ ldb, /* const int ldb = */ static_cast<int>(ldb),
/* const int ldd = */ N,
/* const int tiles_n = */ tn, /* const int tiles_n = */ (N + bn - 1) / bn,
/* const int tiles_m = */ tm, /* const int tiles_m = */ (M + bm - 1) / bm,
/* const int64_t batch_stride_a = */ lhs_indices_str, /* const int64_t batch_stride_a = */
/* const int64_t batch_stride_b = */ rhs_indices_str, (batch_ndim > 0) ? lhs_indices.strides()[0] : 0,
/* const int64_t batch_stride_d = */ matrix_stride_out, /* const int64_t batch_stride_b = */
/* const int swizzle_log = */ swizzle_log, (batch_ndim > 0) ? rhs_indices.strides()[0] : 0,
/* const int64_t batch_stride_d = */ M * N,
/* const int swizzle_log = */ 0,
/* const int gemm_k_iterations_aligned = */ (K / bk),
/* const int batch_ndim = */ batch_ndim};
// Prepare launch grid params // Prepare the grid
int tile = 1 << swizzle_log;
tm = (tm + tile - 1) / tile;
tn = tn * tile;
MTL::Size group_dims = MTL::Size(32, wn, wm);
MTL::Size grid_dims = MTL::Size(tn, tm, batch_size_out); MTL::Size grid_dims =
MTL::Size(params.tiles_n, params.tiles_m, batch_size_out);
// Launch kernel
compute_encoder.set_input_array(a, 0);
compute_encoder.set_input_array(b, 1);
compute_encoder.set_output_array(out, 3); compute_encoder.set_input_array(lhs_indices, 2);
compute_encoder.set_input_array(rhs_indices, 3);
compute_encoder.set_bytes(params, 4); compute_encoder.set_output_array(out, 4);
compute_encoder.set_bytes(params, 5);
compute_encoder.set_vector_bytes(batch_shape, 6); compute_encoder.set_vector_bytes(lhs_indices.shape(), 6);
compute_encoder.set_vector_bytes(batch_strides, 7); compute_encoder.set_vector_bytes(lhs_indices.strides(), 7);
compute_encoder.set_vector_bytes(rhs_indices.strides(), 8);
compute_encoder.set_input_array(lhs_indices, 10); compute_encoder.set_bytes(batch_ndim_a, 9);
compute_encoder.set_input_array(rhs_indices, 11); compute_encoder.set_vector_bytes(a.shape(), 10);
compute_encoder.set_vector_bytes(a.strides(), 11);
std::vector operand_shape = batch_shape_A; compute_encoder.set_bytes(batch_ndim_b, 12);
operand_shape.insert( compute_encoder.set_vector_bytes(b.shape(), 13);
operand_shape.end(), batch_shape_B.begin(), batch_shape_B.end()); compute_encoder.set_vector_bytes(b.strides(), 14);
std::vector operand_strides = batch_strides_A;
operand_strides.insert(
operand_strides.end(), batch_strides_B.begin(), batch_strides_B.end());
operand_batch_ndim.push_back(0);
compute_encoder.set_vector_bytes(operand_shape, 13);
compute_encoder.set_vector_bytes(operand_strides, 14);
compute_encoder.set_vector_bytes(operand_batch_ndim, 15);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
d.add_temporaries(std::move(copies), s.index); void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& s = stream();
auto& d = metal::device(s.device);
auto& a = inputs[0];
auto& b = inputs[1];
auto& lhs_indices = inputs[2];
auto& rhs_indices = inputs[3];
// Return 0s if either input is empty
if (a.size() == 0 || b.size() == 0) {
array zero = array(0, a.dtype());
fill_gpu(zero, out, s);
d.add_temporary(std::move(zero), s.index);
return;
}
out.set_data(allocator::malloc(out.nbytes()));
// Extract shapes strides from inputs and copy in case of non-contiguous
// vectors.
int M = a.shape(-2);
int N = b.shape(-1);
int K = a.shape(-1);
// We are walking a in order and b is also in order so we can batch up the
// matmuls and reuse reading a and b.
if (M == 1 && right_sorted_ == true) {
gather_mm_rhs(a, b, rhs_indices, out, d, s);
return;
}
// Route to gather gemv if any of a or b are vectors
if (M == 1) {
gather_mv(b, a, rhs_indices, lhs_indices, out, N, K, false, d, s);
return;
}
if (N == 1) {
gather_mv(a, b, lhs_indices, rhs_indices, out, M, K, true, d, s);
return;
}
// Route to non specialized gather mm
gather_mm(a, b, lhs_indices, rhs_indices, out, M, N, K, d, s);
}
} // namespace mlx::core

Some files were not shown because too many files have changed in this diff.