Compare commits

65 Commits

Author SHA1 Message Date
Awni Hannun
998404ada4 Get trellis to run 2025-04-26 07:02:20 -07:00
Alex Barron
e3d275bc49 rebase on main 2025-04-14 16:37:23 -07:00
Alex Barron
d7acf59fd0 add trellis quant mode 2025-04-14 16:28:23 -07:00
Yury Popov
e9e268336b LogCumSumExp (#2069) 2025-04-13 01:27:29 -07:00
Awni Hannun
7275ac7523 Fix release build (#2072) 2025-04-12 20:41:58 -07:00
Angelos Katharopoulos
c4189a38e4 Add float mask to sdpa vector (#2068) 2025-04-11 17:29:40 -07:00
Awni Hannun
68d1b3256b nit: fix exception handling (#2066) 2025-04-11 14:12:08 -07:00
Awni Hannun
9c6953bda7 Fix stubgen (#2065)
* Fix stubgen

* add multi optim to docs
2025-04-11 12:02:54 -07:00
Awni Hannun
ef7ece9851 fix fft bug (#2062) 2025-04-10 19:41:27 -07:00
Angelos Katharopoulos
ddaa4b7dcb Fix the test and add custom min/max reductions for uncommon MPI types (#2060) 2025-04-10 17:01:17 -07:00
Cheng
dfae2c6989 Fix MSVC build due to use of M_LN2 (#2058) 2025-04-10 07:41:41 -07:00
Anastasiia Filippova
515f104926 Min / max reductions (#2041) 2025-04-09 23:22:20 -07:00
Angelos Katharopoulos
9ecefd56db Do not load the default lib if another is requested (#2055) 2025-04-09 13:31:38 -07:00
Awni Hannun
e5d35aa187 no sdpa in grad (#2054) 2025-04-08 19:13:54 -07:00
Awni Hannun
00794c42bc Fix causal mask sdpa vec (#2053)
* fix sdpa vector causal mask

* test
2025-04-08 09:11:23 -07:00
Cheng
08a1bf3f10 Remove Event::Signal() (#2052) 2025-04-08 06:20:27 -07:00
Awni Hannun
60c4154346 Only request residency once (#2051) 2025-04-07 10:47:51 -07:00
Awni Hannun
f2c85308c1 add a half simd gemm fallback (#2046)
* add a half simd gemm fallback

* nit
2025-04-07 09:31:29 -07:00
Awni Hannun
1a28b69ee2 only add to residency set once (#2049) 2025-04-06 17:38:25 -07:00
Cheng
ba09f01ce8 Remove test of converting negative float to uint (#2048) 2025-04-06 06:21:46 -07:00
Cheng
6cf48872b7 wait_for_one should wait for task to finish (#2047) 2025-04-05 20:05:16 -07:00
Angelos Katharopoulos
7b3b8fa000 Fix ci release (#2045) 2025-04-04 20:25:01 -07:00
Awni Hannun
ec5e2aae61 nit in doc (#2044) 2025-04-04 12:04:17 -07:00
Awni Hannun
86389bf970 patch bump (#2043) 2025-04-03 13:15:18 -07:00
Jagrit Digani
3290bfa690 Add new sdpa function overload (#2035)
* Add new sdpa function overload

* Address comments

* Remove std::variant from cpp sdpa function
2025-04-03 11:58:28 -07:00
Jagrit Digani
8777fd104f Depthwise Conv2D optimization (#2036)
- Add a new specialized kernel for depthwise 2D convolutions with small kernel sizes (<= 7) and small strides (<= 2)
- Add related tests
2025-04-03 09:42:04 -07:00
Awni Hannun
c41f7565ed fix softmax / logsumexp (#2042) 2025-04-03 08:32:59 -07:00
Awni Hannun
9ba81e3da4 tune quant dispatch (#2031) 2025-04-02 20:05:54 -07:00
Awni Hannun
c23888acd7 Fix build warning (#2033) 2025-04-01 14:42:27 -07:00
Awni Hannun
f98ce25ab9 fix residency set for real (#2032) 2025-04-01 12:59:48 -07:00
Awni Hannun
de5f38fd48 Custom logsumexp (#2028)
* initial custom logsumexp

* more tests

* comments + fix
2025-03-31 07:36:55 -07:00
Angelos Katharopoulos
ec2854b13a Swap -inf for finite_minimum value (#2029) 2025-03-30 21:55:04 -07:00
Stephen Panaro
90823d2938 Add missing funcs to docs (#2021) 2025-03-30 18:29:33 -07:00
Jesper Stemann Andersen
5f5770e3a2 Fix CPU sign for unsigned ints (#2024)
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2025-03-30 17:56:59 -07:00
Awni Hannun
28f39e9038 Log for complex numbers in Metal (#2025)
* Log for complex numbers in Metal

* fix log2
2025-03-30 17:04:38 -07:00
Awni Hannun
b2d2b37888 fix residency set clearing (#2027) 2025-03-30 16:27:26 -07:00
Awni Hannun
fe597e141c add pinv to doc (#2020) 2025-03-30 15:54:18 -07:00
Yi Wang
72ca1539e0 Remove unused variable in /setup.py (#2026)
This is a follow up of https://github.com/ml-explore/mlx/pull/2011
2025-03-30 12:52:33 -07:00
Awni Hannun
13b26775f1 use minimum deployment target (#2016) 2025-03-28 14:31:53 -07:00
Awni Hannun
05d7118561 causal vector sdpa (#2018)
* causal vector sdpa

* get rid of memory threshold
2025-03-28 12:36:13 -07:00
Awni Hannun
98b901ad66 enable complex gemm (#2017) 2025-03-28 10:45:13 -07:00
Awni Hannun
5580b47291 iinfo and scalar overflow detection (#2009) 2025-03-27 19:54:56 -07:00
Awni Hannun
bc62932984 sdpa specialization for head dim 256 (#2007) 2025-03-27 19:31:25 -07:00
Awni Hannun
a6b5d6e759 revise cmake minimum for doctest (#2014) 2025-03-27 19:30:58 -07:00
Yi Wang
a8931306e1 Remove unused variable in CMakeBuild (#2011)
Fix https://github.com/ml-explore/mlx/issues/2010
2025-03-27 16:00:51 -07:00
Yi Wang
fecdb8717e Polish CONTRIBUTING.md (#2005) 2025-03-25 19:06:34 -07:00
Awni Hannun
916fd273ea wire cache (#2006) 2025-03-25 18:54:01 -07:00
Yi Wang
0da8506552 Update docs for extensions (#2004) 2025-03-25 18:35:03 -07:00
Cheng
eda7a7b43e Do not join threads during process exit on Windows (#1738) 2025-03-25 06:33:08 -07:00
Chunyang Wen
022eabb734 Remove unused import (#1987) 2025-03-24 20:19:32 -07:00
Awni Hannun
aba899cef8 patch bump (#2000) 2025-03-24 12:47:05 -07:00
Jagrit Digani
6a40e1c176 Fix looping limit in causal attention (#1999) 2025-03-24 12:28:00 -07:00
Jesper Stemann Andersen
9307b2ab8b Fixed 32-bit platform support for distributed/ring implementation (#1996)
Replaced unsigned long integer literals with size_t literals in ring implementation, e.g., 1UL with size_t(1).
2025-03-24 08:08:40 -07:00
Jesper Stemann Andersen
522d8d3917 Added missing netinet/in.h include that fixes build on FreeBSD (#1997)
Defines IPPROTO_TCP.
2025-03-24 08:07:34 -07:00
Awni Hannun
a84cc0123f promote mask when needed (#1998) 2025-03-23 19:58:28 -07:00
Andrey Velichkevich
f018e248cd fix(backend): Include algorithm library in Allocator (#1992)
Signed-off-by: Andrey Velichkevich <andrey.velichkevich@gmail.com>
2025-03-22 21:27:51 -07:00
Awni Hannun
cfd7237a80 fix docs (#1991) 2025-03-21 19:58:53 -07:00
Angelos Katharopoulos
4eef8102c9 Distributed layers (#1270) 2025-03-21 13:52:17 -07:00
Angelos Katharopoulos
69e4dd506b Add a ring all gather (#1985) 2025-03-21 13:36:51 -07:00
Angelos Katharopoulos
25814a9458 Disable mpi on version mismatch (#1989) 2025-03-21 13:36:26 -07:00
Awni Hannun
2a980a76ce Add stats and limit to common allocator and enable tests (#1988)
* add stats to common allocator and enable tests

* linux memory and default

* fix
2025-03-21 12:28:36 -07:00
Angelos Katharopoulos
d343782c8b Cross platform libmpi loading (#1975) 2025-03-21 11:23:10 -07:00
Awni Hannun
4e1994e9d7 move memory APIs into top level mlx.core (#1982) 2025-03-21 07:25:12 -07:00
jiyzhang
65a38c452b update the formula of smooth_l1_loss (#1986) 2025-03-21 06:25:23 -07:00
Awni Hannun
7b7e2352cd fix malloc or wait deadlock (#1976) 2025-03-20 16:48:43 -07:00
184 changed files with 5381 additions and 1444 deletions

View File

@@ -24,8 +24,8 @@ jobs:
type: boolean
default: false
macos:
xcode: "15.2.0"
resource_class: macos.m1.medium.gen1
xcode: "16.2.0"
resource_class: m2pro.medium
steps:
- checkout
- run:
@@ -89,15 +89,14 @@ jobs:
pip install numpy
sudo apt-get update
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
sudo apt-get install openmpi-bin openmpi-common libopenmpi-dev
- run:
name: Install Python package
command: |
CMAKE_ARGS="-DMLX_BUILD_METAL=OFF
CMAKE_COMPILE_WARNING_AS_ERROR=ON" \
CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" \
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
python3 setup.py build_ext --inplace
CMAKE_ARGS="-DMLX_BUILD_METAL=OFF \
CMAKE_COMPILE_WARNING_AS_ERROR=ON" \
CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" \
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
python3 setup.py develop
- run:
@@ -110,6 +109,8 @@ jobs:
name: Run Python tests
command: |
python3 -m unittest discover python/tests -v
mpirun --bind-to none -host localhost:8 -np 8 python python/tests/mpi_test_distributed.py
mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py
- run:
name: Build CPP only
command: |
@@ -124,10 +125,15 @@ jobs:
parameters:
xcode_version:
type: string
default: "15.2.0"
default: "16.2.0"
macosx_deployment_target:
type: string
default: ""
macos:
xcode: << parameters.xcode_version >>
resource_class: macos.m1.medium.gen1
environment:
MACOSX_DEPLOYMENT_TARGET: << parameters.macosx_deployment_target >>
resource_class: m2pro.medium
steps:
- checkout
- run:
@@ -149,7 +155,7 @@ jobs:
command: |
source env/bin/activate
DEBUG=1 CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
CMAKE_ARGS="CMAKE_COMPILE_WARNING_AS_ERROR=ON" \
CMAKE_ARGS="-DCMAKE_COMPILE_WARNING_AS_ERROR=ON" \
pip install -e . -v
- run:
name: Generate package stubs
@@ -213,13 +219,18 @@ jobs:
default: "3.9"
xcode_version:
type: string
default: "15.2.0"
default: "16.2.0"
build_env:
type: string
default: ""
macosx_deployment_target:
type: string
default: ""
macos:
xcode: << parameters.xcode_version >>
resource_class: macos.m1.medium.gen1
resource_class: m2pro.medium
environment:
MACOSX_DEPLOYMENT_TARGET: << parameters.macosx_deployment_target >>
steps:
- checkout
- run:
@@ -240,7 +251,7 @@ jobs:
name: Install Python package
command: |
source env/bin/activate
DEV_RELEASE=1 \
env -u MACOSX_DEPLOYMENT_TARGET DEV_RELEASE=1 \
CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
pip install . -v
- run:
@@ -335,7 +346,7 @@ workflows:
- mac_build_and_test:
matrix:
parameters:
xcode_version: ["15.0.0", "15.2.0", "16.0.0"]
macosx_deployment_target: ["13.5", "14.0"]
- linux_build_and_test
- build_documentation
@@ -355,8 +366,70 @@ workflows:
matrix:
parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
xcode_version: ["15.0.0", "15.2.0"]
macosx_deployment_target: ["13.5", "14.0", "15.0"]
build_env: ["PYPI_RELEASE=1"]
xcode_version: ["16.2.0", "15.0.0"]
exclude:
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.9"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.10"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.11"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.12"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.13"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.9"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.10"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.11"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.12"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.13"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.9"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.10"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.11"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.12"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.13"
build_env: "PYPI_RELEASE=1"
- build_documentation:
filters:
tags:
@@ -379,7 +452,7 @@ workflows:
requires: [ hold ]
matrix:
parameters:
xcode_version: ["15.0.0", "15.2.0", "16.0.0"]
macosx_deployment_target: ["13.5", "14.0"]
- linux_build_and_test:
requires: [ hold ]
nightly_build:
@@ -392,7 +465,54 @@ workflows:
matrix:
parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
xcode_version: ["15.0.0", "15.2.0"]
macosx_deployment_target: ["13.5", "14.0", "15.0"]
xcode_version: ["16.2.0", "15.0.0"]
exclude:
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.9"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.10"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.11"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.12"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.13"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.9"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.10"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.11"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.12"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.13"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.9"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.10"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.11"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.12"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.13"
weekly_build:
when:
and:
@@ -403,8 +523,70 @@ workflows:
matrix:
parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
xcode_version: ["15.0.0", "15.2.0", "16.0.0"]
macosx_deployment_target: ["13.5", "14.0", "15.0"]
build_env: ["DEV_RELEASE=1"]
xcode_version: ["16.2.0", "15.0.0"]
exclude:
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.9"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.10"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.11"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.12"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.13"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.9"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.10"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.11"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.12"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.13"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.9"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.10"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.11"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.12"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.13"
build_env: "DEV_RELEASE=1"
linux_test_release:
when:
and:

View File

@@ -212,24 +212,6 @@ else()
set(MLX_BUILD_ACCELERATE OFF)
endif()
find_package(MPI)
if(MPI_FOUND)
execute_process(
COMMAND zsh "-c" "mpirun --version"
OUTPUT_VARIABLE MPI_VERSION
ERROR_QUIET)
if(${MPI_VERSION} MATCHES ".*Open MPI.*")
target_include_directories(mlx PRIVATE ${MPI_INCLUDE_PATH})
elseif(MPI_VERSION STREQUAL "")
set(MPI_FOUND FALSE)
message(
WARNING "MPI found but mpirun is not available. Building without MPI.")
else()
set(MPI_FOUND FALSE)
message(WARNING "MPI which is not OpenMPI found. Building without MPI.")
endif()
endif()
message(STATUS "Downloading json")
FetchContent_Declare(
json
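The Open MPI build-time probe removed above is superseded by runtime discovery — see the "Cross platform libmpi loading" (#1975) and "Disable mpi on version mismatch" (#1989) commits in this compare. A minimal sketch of what runtime loading means (hypothetical name `load_libmpi`; MLX's actual loader differs):

```cpp
#include <dlfcn.h>

// Hypothetical sketch: resolve MPI at runtime instead of linking at build
// time, so a missing or mismatched MPI simply disables the backend.
void* load_libmpi() {
  // Platform-specific library name; .dylib on macOS, .so on Linux.
  void* lib = dlopen("libmpi.so", RTLD_NOW | RTLD_GLOBAL);
  return lib;  // nullptr => build still works, MPI features are disabled
}
```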

View File

@@ -5,26 +5,26 @@ possible.
## Pull Requests
1. Fork and submit pull requests to the repo.
1. Fork and submit pull requests to the repo.
2. If you've added code that should be tested, add tests.
3. If a change is likely to impact efficiency, run some of the benchmarks before
and after the change. Examples of benchmarks can be found in `benchmarks/python/`.
4. If you've changed APIs, update the documentation.
5. Every PR should have passing tests and at least one review.
5. Every PR should have passing tests and at least one review.
6. For code formatting install `pre-commit` using something like `pip install pre-commit` and run `pre-commit install`.
This should install hooks for running `black` and `clang-format` to ensure
consistent style for C++ and python code.
You can also run the formatters manually as follows:
```
clang-format -i file.cpp
```
```
black file.py
```
```shell
clang-format -i file.cpp
```
```shell
black file.py
```
or run `pre-commit run --all-files` to check all files in the repo.
## Issues

View File

@@ -13,7 +13,7 @@ EXCLUDE_PATTERNS = */private/*
CREATE_SUBDIRS = NO
FULL_PATH_NAMES = YES
RECURSIVE = YES
GENERATE_HTML = YES
GENERATE_HTML = NO
GENERATE_LATEX = NO
GENERATE_XML = YES
XML_PROGRAMLISTING = YES

View File

@@ -93,9 +93,9 @@ Primitives
^^^^^^^^^^^
A :class:`Primitive` is part of the computation graph of an :class:`array`. It
defines how to create outputs arrays given a input arrays. Further, a
defines how to create output arrays given input arrays. Further, a
:class:`Primitive` has methods to run on the CPU or GPU and for function
transformations such as ``vjp`` and ``jvp``. Lets go back to our example to be
transformations such as ``vjp`` and ``jvp``. Let's go back to our example to be
more concrete:
.. code-block:: C++
@@ -128,7 +128,7 @@ more concrete:
/** The vector-Jacobian product. */
std::vector<array> vjp(
const std::vector<array>& primals,
const array& cotan,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>& outputs) override;
@@ -247,9 +247,7 @@ point-wise. This is captured in the templated function :meth:`axpby_impl`.
float alpha_,
float beta_,
mx::Stream stream) {
// Allocate the output with `malloc_or_wait` which synchronously allocates
// memory, potentially waiting if the system is under memory pressure
out.set_data(mx::allocator::malloc_or_wait(out.nbytes()));
out.set_data(mx::allocator::malloc(out.nbytes()));
// Get the CPU command encoder and register input and output arrays
auto& encoder = mx::cpu::get_command_encoder(stream);
@@ -393,7 +391,7 @@ below.
auto& d = metal::device(s.device);
// Allocate output memory
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
// Resolve name of kernel
std::ostringstream kname;
@@ -471,7 +469,7 @@ one we just defined:
const std::vector<array>& tangents,
const std::vector<int>& argnums) {
// Forward mode diff that pushes along the tangents
// The jvp transform on the primitive can built with ops
// The jvp transform on the primitive can be built with ops
// that are scheduled on the same stream as the primitive
// If argnums = {0}, we only push along x in which case the
@@ -483,7 +481,7 @@ one we just defined:
auto scale_arr = array(scale, tangents[0].dtype());
return {multiply(scale_arr, tangents[0], stream())};
}
// If, argnums = {0, 1}, we take contributions from both
// If argnums = {0, 1}, we take contributions from both
// which gives us jvp = tangent_x * alpha + tangent_y * beta
else {
return {axpby(tangents[0], tangents[1], alpha_, beta_, stream())};
@@ -737,7 +735,7 @@ Let's look at a simple script and its results:
print(f"c shape: {c.shape}")
print(f"c dtype: {c.dtype}")
print(f"c correct: {mx.all(c == 6.0).item()}")
print(f"c is correct: {mx.all(c == 6.0).item()}")
Output:
@@ -745,7 +743,7 @@ Output:
c shape: [3, 4]
c dtype: float32
c correctness: True
c is correct: True
Results
^^^^^^^

View File

@@ -70,6 +70,7 @@ are the CPU and GPU.
python/fft
python/linalg
python/metal
python/memory_management
python/nn
python/optimizers
python/distributed

View File

@@ -38,6 +38,7 @@ Array
array.log10
array.log1p
array.log2
array.logcumsumexp
array.logsumexp
array.max
array.mean

View File

@@ -20,5 +20,6 @@ Linear Algebra
eigh
lu
lu_factor
pinv
solve
solve_triangular

View File

@@ -0,0 +1,16 @@
Memory Management
=================
.. currentmodule:: mlx.core
.. autosummary::
:toctree: _autosummary
get_active_memory
get_peak_memory
reset_peak_memory
get_cache_memory
set_memory_limit
set_cache_limit
set_wired_limit
clear_cache

View File

@@ -8,13 +8,5 @@ Metal
is_available
device_info
get_active_memory
get_peak_memory
reset_peak_memory
get_cache_memory
set_memory_limit
set_cache_limit
set_wired_limit
clear_cache
start_capture
stop_capture

View File

@@ -36,10 +36,12 @@ Operations
bitwise_or
bitwise_xor
block_masked_mm
broadcast_arrays
broadcast_to
ceil
clip
concatenate
contiguous
conj
conjugate
convolve
@@ -101,6 +103,7 @@ Operations
log10
log1p
logaddexp
logcumsumexp
logical_not
logical_and
logical_or

View File

@@ -18,3 +18,4 @@ Common Optimizers
AdamW
Adamax
Lion
MultiOptimizer

View File

@@ -9,6 +9,7 @@ Transforms
:toctree: _autosummary
eval
async_eval
compile
custom_function
disable_compile

View File

@@ -72,9 +72,7 @@ void axpby_impl(
float alpha_,
float beta_,
mx::Stream stream) {
// Allocate the output with `malloc_or_wait` which synchronously allocates
// memory, potentially waiting if the system is under memory pressure
out.set_data(mx::allocator::malloc_or_wait(out.nbytes()));
out.set_data(mx::allocator::malloc(out.nbytes()));
// Get the CPU command encoder and register input and output arrays
auto& encoder = mx::cpu::get_command_encoder(stream);
@@ -160,12 +158,12 @@ void Axpby::eval_gpu(
// Allocate output memory with strides based on specialization
if (contiguous_kernel) {
out.set_data(
mx::allocator::malloc_or_wait(x.data_size() * out.itemsize()),
mx::allocator::malloc(x.data_size() * out.itemsize()),
x.data_size(),
x.strides(),
x.flags());
} else {
out.set_data(mx::allocator::malloc_or_wait(out.nbytes()));
out.set_data(mx::allocator::malloc(out.nbytes()));
}
// Resolve name of kernel (corresponds to axpby.metal)

View File

@@ -4,12 +4,11 @@
#include <sstream>
#include "mlx/allocator.h"
#include "mlx/scheduler.h"
namespace mlx::core::allocator {
Buffer malloc(size_t size) {
auto buffer = allocator().malloc(size, /* allow_swap */ true);
auto buffer = allocator().malloc(size);
if (size && !buffer.ptr()) {
std::ostringstream msg;
msg << "[malloc] Unable to allocate " << size << " bytes.";
@@ -22,45 +21,4 @@ void free(Buffer buffer) {
allocator().free(buffer);
}
Buffer CommonAllocator::malloc(size_t size, bool) {
void* ptr = std::malloc(size + sizeof(size_t));
if (ptr != nullptr) {
*static_cast<size_t*>(ptr) = size;
}
return Buffer{ptr};
}
void CommonAllocator::free(Buffer buffer) {
std::free(buffer.ptr());
}
size_t CommonAllocator::size(Buffer buffer) const {
if (buffer.ptr() == nullptr) {
return 0;
}
return *static_cast<size_t*>(buffer.ptr());
}
Buffer malloc_or_wait(size_t size) {
auto buffer = allocator().malloc(size);
while (size && !buffer.ptr() && scheduler::n_active_tasks() > 0) {
scheduler::wait_for_one();
buffer = allocator().malloc(size);
}
// Try swapping if needed
if (size && !buffer.ptr()) {
buffer = allocator().malloc(size, /* allow_swap = */ true);
}
if (size && !buffer.ptr()) {
std::ostringstream msg;
msg << "[malloc_or_wait] Unable to allocate " << size << " bytes.";
throw std::runtime_error(msg.str());
}
return buffer;
}
} // namespace mlx::core::allocator

View File

@@ -32,14 +32,10 @@ Buffer malloc(size_t size);
void free(Buffer buffer);
// Wait for running tasks to finish and free up memory
// if allocation fails
Buffer malloc_or_wait(size_t size);
class Allocator {
/** Abstract base class for a memory allocator. */
public:
virtual Buffer malloc(size_t size, bool allow_swap = false) = 0;
virtual Buffer malloc(size_t size) = 0;
virtual void free(Buffer buffer) = 0;
virtual size_t size(Buffer buffer) const = 0;
@@ -53,16 +49,4 @@ class Allocator {
Allocator& allocator();
class CommonAllocator : public Allocator {
/** A general CPU allocator. */
public:
virtual Buffer malloc(size_t size, bool allow_swap = false) override;
virtual void free(Buffer buffer) override;
virtual size_t size(Buffer buffer) const override;
private:
CommonAllocator() = default;
friend Allocator& allocator();
};
} // namespace mlx::core::allocator
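With `malloc_or_wait` removed, `allocator::malloc` is the single entry point: it throws instead of waiting on scheduler tasks or retrying with swap, and the `allow_swap` parameter is gone from the `Allocator` interface. The callsite pattern used throughout this compare, as a sketch assuming an output array `out`:

```cpp
#include "mlx/allocator.h"
#include "mlx/array.h"

// Post-change allocation pattern: request exactly out.nbytes();
// allocator::malloc throws std::runtime_error if it cannot allocate.
void allocate_output(mlx::core::array& out) {
  out.set_data(mlx::core::allocator::malloc(out.nbytes()));
}
```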

View File

@@ -44,14 +44,14 @@ inline void set_binary_op_output_data(
switch (bopt) {
case BinaryOpType::ScalarScalar:
out.set_data(
allocator::malloc_or_wait(out.itemsize()), 1, a.strides(), a.flags());
allocator::malloc(out.itemsize()), 1, a.strides(), a.flags());
break;
case BinaryOpType::ScalarVector:
if (b_donatable) {
out.copy_shared_buffer(b);
} else {
out.set_data(
allocator::malloc_or_wait(b.data_size() * out.itemsize()),
allocator::malloc(b.data_size() * out.itemsize()),
b.data_size(),
b.strides(),
b.flags());
@@ -62,7 +62,7 @@ inline void set_binary_op_output_data(
out.copy_shared_buffer(a);
} else {
out.set_data(
allocator::malloc_or_wait(a.data_size() * out.itemsize()),
allocator::malloc(a.data_size() * out.itemsize()),
a.data_size(),
a.strides(),
a.flags());
@@ -75,7 +75,7 @@ inline void set_binary_op_output_data(
out.copy_shared_buffer(b);
} else {
out.set_data(
allocator::malloc_or_wait(a.data_size() * out.itemsize()),
allocator::malloc(a.data_size() * out.itemsize()),
a.data_size(),
a.strides(),
a.flags());
@@ -88,7 +88,7 @@ inline void set_binary_op_output_data(
b_donatable && b.flags().row_contiguous && b.size() == out.size()) {
out.copy_shared_buffer(b);
} else {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
}
break;
}

View File

@@ -103,7 +103,7 @@ void ExpandDims::eval(const std::vector<array>& inputs, array& out) {
void NumberOfElements::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
double numel = 1;
for (auto ax : axes_) {

View File

@@ -188,7 +188,7 @@ void compiled_allocate_outputs(
}
for (; o < outputs.size(); ++o) {
outputs[o].set_data(
allocator::malloc_or_wait(data_size * outputs[o].itemsize()),
allocator::malloc(data_size * outputs[o].itemsize()),
data_size,
strides,
flags);
@@ -211,7 +211,7 @@ void compiled_allocate_outputs(
}
}
for (; o < outputs.size(); ++o) {
outputs[o].set_data(allocator::malloc_or_wait(outputs[o].nbytes()));
outputs[o].set_data(allocator::malloc(outputs[o].nbytes()));
}
}
}

View File

@@ -31,14 +31,14 @@ inline bool set_copy_output_data(const array& in, array& out, CopyType ctype) {
return true;
} else {
out.set_data(
allocator::malloc_or_wait(in.data_size() * out.itemsize()),
allocator::malloc(in.data_size() * out.itemsize()),
in.data_size(),
in.strides(),
in.flags());
return false;
}
} else {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
return false;
}
}

View File

@@ -28,7 +28,7 @@ void swap_endianness(uint8_t* data_bytes, size_t N) {
namespace mlx::core {
void Load::eval_cpu(const std::vector<array>& inputs, array& out) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
auto read_task = [out_ptr = out.data<char>(),
size = out.size(),
itemsize = out.itemsize(),

View File

@@ -48,12 +48,12 @@ inline void set_ternary_op_output_data(
switch (topt) {
case TernaryOpType::ScalarScalarScalar:
out.set_data(
allocator::malloc_or_wait(out.itemsize()), 1, b.strides(), b.flags());
allocator::malloc(out.itemsize()), 1, b.strides(), b.flags());
break;
case TernaryOpType::VectorVectorVector:
if (!(maybe_donate(a) || maybe_donate(b) || maybe_donate(c))) {
out.set_data(
allocator::malloc_or_wait(out.itemsize() * b.data_size()),
allocator::malloc(out.itemsize() * b.data_size()),
b.data_size(),
b.strides(),
b.flags());
@@ -64,7 +64,7 @@ inline void set_ternary_op_output_data(
if (!((a.flags().row_contiguous && maybe_donate(a)) ||
(b.flags().row_contiguous && maybe_donate(b)) ||
(c.flags().row_contiguous && maybe_donate(c)))) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
}
break;
}

View File

@@ -58,6 +58,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
${CMAKE_CURRENT_SOURCE_DIR}/select.cpp
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
${CMAKE_CURRENT_SOURCE_DIR}/logsumexp.cpp
${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
${CMAKE_CURRENT_SOURCE_DIR}/threefry.cpp
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
@@ -73,8 +74,8 @@ target_sources(
if(MLX_BUILD_ACCELERATE)
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/bnns.cpp)
else()
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/no_fp16.cpp
${CMAKE_CURRENT_SOURCE_DIR}/gemms/no_bf16.cpp)
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/simd_fp16.cpp
${CMAKE_CURRENT_SOURCE_DIR}/gemms/simd_bf16.cpp)
endif()
if(IOS)

View File

@@ -68,7 +68,7 @@ void arg_reduce_dispatch(
void ArgReduce::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
auto& encoder = cpu::get_command_encoder(stream());
encoder.set_input_array(in);
encoder.set_output_array(out);

View File

@@ -921,7 +921,7 @@ void explicit_gemm_conv_1D_cpu(
if (out.dtype() != float32) {
gemm_out = array(out.shape(), float32, nullptr, {});
gemm_out.set_data(allocator::malloc_or_wait(gemm_out.nbytes()));
gemm_out.set_data(allocator::malloc(gemm_out.nbytes()));
temps.push_back(gemm_out);
}
@@ -1048,7 +1048,7 @@ void explicit_gemm_conv_2D_cpu(
if (out.dtype() != float32) {
gemm_out = array(out.shape(), float32, nullptr, {});
gemm_out.set_data(allocator::malloc_or_wait(gemm_out.nbytes()));
gemm_out.set_data(allocator::malloc(gemm_out.nbytes()));
temps.push_back(gemm_out);
}
@@ -1214,7 +1214,7 @@ void explicit_gemm_conv_ND_cpu(
if (out.dtype() != float32) {
gemm_out = array(out.shape(), float32, nullptr, {});
gemm_out.set_data(allocator::malloc_or_wait(gemm_out.nbytes()));
gemm_out.set_data(allocator::malloc(gemm_out.nbytes()));
temps.push_back(gemm_out);
}
@@ -1327,7 +1327,7 @@ void conv_3D_cpu(
} // namespace
void Convolution::eval_cpu(const std::vector<array>& inputs, array& out) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
auto& in = inputs[0];
auto& wt = inputs[1];

View File

@@ -30,7 +30,7 @@ void AllReduce::eval_cpu(
if (in.is_donatable()) {
out.copy_shared_buffer(in);
} else {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
}
return in;
} else {
@@ -46,8 +46,15 @@ void AllReduce::eval_cpu(
case Sum:
distributed::detail::all_sum(group(), in, outputs[0], stream());
break;
case Max:
distributed::detail::all_max(group(), in, outputs[0], stream());
break;
case Min:
distributed::detail::all_min(group(), in, outputs[0], stream());
break;
default:
throw std::runtime_error("Only all reduce sum is supported for now");
throw std::runtime_error(
"Only all reduce sum, min and max are supported for now");
}
}
@@ -58,7 +65,7 @@ void AllGather::eval_cpu(
assert(outputs.size() == 1);
auto [in, copied] = ensure_row_contiguous(inputs[0], stream());
outputs[0].set_data(allocator::malloc_or_wait(outputs[0].nbytes()));
outputs[0].set_data(allocator::malloc(outputs[0].nbytes()));
distributed::detail::all_gather(group(), in, outputs[0], stream());
if (copied) {
auto& enc = cpu::get_command_encoder(stream());
@@ -87,7 +94,7 @@ void Recv::eval_cpu(
assert(inputs.size() == 0);
assert(outputs.size() == 1);
outputs[0].set_data(allocator::malloc_or_wait(outputs[0].nbytes()));
outputs[0].set_data(allocator::malloc(outputs[0].nbytes()));
distributed::detail::recv(group(), outputs[0], src_, stream());
}
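The new `Max` and `Min` cases dispatch to `distributed::detail::all_max` / `all_min`. On the MPI backend this corresponds to a standard all-reduce with `MPI_MAX` / `MPI_MIN` (a sketch of the underlying collective, not MLX's wrapper; the "custom min/max reductions for uncommon MPI types" commit covers types without built-in ops):

```cpp
#include <mpi.h>

// Sketch: what an all_max over float amounts to at the MPI layer.
void all_max_f32(const float* in, float* out, int n) {
  MPI_Allreduce(in, out, n, MPI_FLOAT, MPI_MAX, MPI_COMM_WORLD);
}
```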

View File

@@ -55,9 +55,8 @@ void eigh_impl(
liwork = iwork;
}
auto work_buf = array::Data{allocator::malloc_or_wait(sizeof(T) * lwork)};
auto iwork_buf =
array::Data{allocator::malloc_or_wait(sizeof(int) * liwork)};
auto work_buf = array::Data{allocator::malloc(sizeof(T) * lwork)};
auto iwork_buf = array::Data{allocator::malloc(sizeof(int) * liwork)};
for (size_t i = 0; i < size / (N * N); ++i) {
syevd<T>(
&jobz,
@@ -98,7 +97,7 @@ void Eigh::eval_cpu(
? outputs[1]
: array(a.shape(), a.dtype(), nullptr, {});
values.set_data(allocator::malloc_or_wait(values.nbytes()));
values.set_data(allocator::malloc(values.nbytes()));
copy(
a,

View File

@@ -22,7 +22,7 @@ void FFT::eval_cpu(const std::vector<array>& inputs, array& out) {
s *= out.itemsize();
}
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
std::vector<size_t> shape;
if (out.dtype() == float32) {

View File

@@ -1,27 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cpu/gemm.h"
namespace mlx::core {
template <>
void matmul<bfloat16_t>(
const bfloat16_t*,
const bfloat16_t*,
bfloat16_t*,
bool,
bool,
size_t,
size_t,
size_t,
float,
float,
size_t,
const Shape&,
const Strides&,
const Shape&,
const Strides&) {
throw std::runtime_error("[Matmul::eval_cpu] bfloat16 not supported.");
}
} // namespace mlx::core

View File

@@ -1,27 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cpu/gemm.h"
namespace mlx::core {
template <>
void matmul<float16_t>(
const float16_t*,
const float16_t*,
float16_t*,
bool,
bool,
size_t,
size_t,
size_t,
float,
float,
size_t,
const Shape&,
const Strides&,
const Shape&,
const Strides&) {
throw std::runtime_error("[Matmul::eval_cpu] float16 not supported.");
}
} // namespace mlx::core

View File

@@ -0,0 +1,45 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/gemm.h"
#include "mlx/backend/cpu/gemms/simd_gemm.h"
namespace mlx::core {
template <>
void matmul<bfloat16_t>(
const bfloat16_t* a,
const bfloat16_t* b,
bfloat16_t* out,
bool a_transposed,
bool b_transposed,
size_t lda,
size_t ldb,
size_t ldc,
float alpha,
float beta,
size_t batch_size,
const Shape& a_shape,
const Strides& a_strides,
const Shape& b_shape,
const Strides& b_strides) {
auto ndim = a_shape.size();
size_t M = a_shape[ndim - 2];
size_t N = b_shape[ndim - 1];
size_t K = a_shape[ndim - 1];
for (int i = 0; i < batch_size; ++i) {
simd_gemm<bfloat16_t, float>(
a + elem_to_loc(M * K * i, a_shape, a_strides),
b + elem_to_loc(K * N * i, b_shape, b_strides),
out + M * N * i,
a_transposed,
b_transposed,
M,
N,
K,
alpha,
beta);
}
}
} // namespace mlx::core

View File

@@ -0,0 +1,45 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/gemm.h"
#include "mlx/backend/cpu/gemms/simd_gemm.h"
namespace mlx::core {
template <>
void matmul<float16_t>(
const float16_t* a,
const float16_t* b,
float16_t* out,
bool a_transposed,
bool b_transposed,
size_t lda,
size_t ldb,
size_t ldc,
float alpha,
float beta,
size_t batch_size,
const Shape& a_shape,
const Strides& a_strides,
const Shape& b_shape,
const Strides& b_strides) {
auto ndim = a_shape.size();
size_t M = a_shape[ndim - 2];
size_t N = b_shape[ndim - 1];
size_t K = a_shape[ndim - 1];
for (int i = 0; i < batch_size; ++i) {
simd_gemm<float16_t, float>(
a + elem_to_loc(M * K * i, a_shape, a_strides),
b + elem_to_loc(K * N * i, b_shape, b_strides),
out + M * N * i,
a_transposed,
b_transposed,
M,
N,
K,
alpha,
beta);
}
}
} // namespace mlx::core

View File

@@ -0,0 +1,139 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include "mlx/backend/cpu/simd/simd.h"
namespace mlx::core {
inline int ceildiv(int a, int b) {
return (a + b - 1) / b;
}
template <int block_size, typename T, typename AccT>
void load_block(
const T* in,
AccT* out,
int M,
int N,
int i,
int j,
bool transpose) {
if (transpose) {
for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) {
for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) {
out[jj * block_size + ii] =
in[(i * block_size + ii) * N + j * block_size + jj];
}
}
} else {
for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) {
for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) {
out[ii * block_size + jj] =
in[(i * block_size + ii) * N + j * block_size + jj];
}
}
}
}
template <typename T, typename AccT>
void simd_gemm(
const T* a,
const T* b,
T* c,
bool a_trans,
bool b_trans,
int M,
int N,
int K,
float alpha,
float beta) {
constexpr int block_size = 16;
constexpr int simd_size = simd::max_size<AccT>;
static_assert(
(block_size % simd_size) == 0,
"Block size must be divisible by SIMD size");
int last_k_block_size = K - block_size * (K / block_size);
int last_k_simd_block = (last_k_block_size / simd_size) * simd_size;
for (int i = 0; i < ceildiv(M, block_size); i++) {
for (int j = 0; j < ceildiv(N, block_size); j++) {
AccT c_block[block_size * block_size] = {0.0};
AccT a_block[block_size * block_size];
AccT b_block[block_size * block_size];
int k = 0;
for (; k < K / block_size; k++) {
// Load a and b blocks
if (a_trans) {
load_block<block_size>(a, a_block, K, M, k, i, true);
} else {
load_block<block_size>(a, a_block, M, K, i, k, false);
}
if (b_trans) {
load_block<block_size>(b, b_block, N, K, j, k, false);
} else {
load_block<block_size>(b, b_block, K, N, k, j, true);
}
// Multiply and accumulate
for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) {
for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) {
for (int kk = 0; kk < block_size; kk += simd_size) {
auto av =
simd::load<AccT, simd_size>(a_block + ii * block_size + kk);
auto bv =
simd::load<AccT, simd_size>(b_block + jj * block_size + kk);
c_block[ii * block_size + jj] += simd::sum(av * bv);
}
}
}
}
if (last_k_block_size) {
// Load a and b blocks
if (a_trans) {
load_block<block_size>(a, a_block, K, M, k, i, true);
} else {
load_block<block_size>(a, a_block, M, K, i, k, false);
}
if (b_trans) {
load_block<block_size>(b, b_block, N, K, j, k, false);
} else {
load_block<block_size>(b, b_block, K, N, k, j, true);
}
// Multiply and accumulate
for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) {
for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) {
int kk = 0;
for (; kk < last_k_simd_block; kk += simd_size) {
auto av =
simd::load<AccT, simd_size>(a_block + ii * block_size + kk);
auto bv =
simd::load<AccT, simd_size>(b_block + jj * block_size + kk);
c_block[ii * block_size + jj] += simd::sum(av * bv);
}
for (; kk < last_k_block_size; ++kk) {
c_block[ii * block_size + jj] +=
a_block[ii * block_size + kk] * b_block[jj * block_size + kk];
}
}
}
}
// Store
for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) {
for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) {
auto c_idx = (i * block_size + ii) * N + j * block_size + jj;
if (beta != 0) {
c[c_idx] = static_cast<T>(
alpha * c_block[ii * block_size + jj] + beta * c[c_idx]);
} else {
c[c_idx] = static_cast<T>(alpha * c_block[ii * block_size + jj]);
}
}
}
}
}
}
} // namespace mlx::core
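This new header is the half/bfloat16 GEMM fallback from commit f2c85308c1: `C = alpha * A * B + beta * C` is computed in 16×16 tiles, staging each tile into float (`AccT`) scratch so low-precision inputs accumulate in float, with the inner dot product vectorized over `simd::max_size` lanes and a scalar tail for the final partial K block. Hypothetical direct use (the template lives in `namespace mlx::core` as shown above):

```cpp
#include "mlx/backend/cpu/gemms/simd_gemm.h"

// Sketch: multiply two row-major fp16 matrices with float accumulation,
// C = 1.0 * A * B + 0.0 * C. Shapes: A is MxK, B is KxN, C is MxN.
void gemm_fp16(const mlx::core::float16_t* a, const mlx::core::float16_t* b,
               mlx::core::float16_t* c, int M, int N, int K) {
  mlx::core::simd_gemm<mlx::core::float16_t, float>(
      a, b, c, /*a_trans=*/false, /*b_trans=*/false, M, N, K,
      /*alpha=*/1.0f, /*beta=*/0.0f);
}
```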

View File

@@ -197,7 +197,7 @@ void dispatch_gather(
}
void Gather::eval_cpu(const std::vector<array>& inputs, array& out) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
auto& src = inputs[0];
std::vector<array> inds;
@@ -354,7 +354,7 @@ void dispatch_gather_axis(
}
void GatherAxis::eval_cpu(const std::vector<array>& inputs, array& out) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
auto& src = inputs[0];
auto& inds = inputs[1];

View File

@@ -11,7 +11,7 @@ namespace mlx::core {
template <typename T>
void general_inv(T* inv, int N) {
int info;
auto ipiv = array::Data{allocator::malloc_or_wait(sizeof(int) * N)};
auto ipiv = array::Data{allocator::malloc(sizeof(int) * N)};
// Compute LU factorization.
getrf<T>(
/* m = */ &N,
@@ -49,7 +49,7 @@ void general_inv(T* inv, int N) {
}
const int lwork = workspace_size;
auto scratch = array::Data{allocator::malloc_or_wait(sizeof(T) * lwork)};
auto scratch = array::Data{allocator::malloc(sizeof(T) * lwork)};
// Compute inverse.
getri<T>(

View File

@@ -0,0 +1,140 @@
// Copyright © 2023-2024 Apple Inc.
#include <cassert>
#include <cmath>
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/simd/simd.h"
#include "mlx/primitives.h"
#include "mlx/types/limits.h"
namespace mlx::core {
namespace {
using namespace mlx::core::simd;
template <typename T, typename AccT>
void logsumexp(const array& in, array& out, Stream stream) {
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(in);
encoder.set_output_array(out);
const T* in_ptr = in.data<T>();
T* out_ptr = out.data<T>();
int M = in.shape().back();
int L = in.data_size() / M;
encoder.dispatch([in_ptr, out_ptr, M, L]() mutable {
constexpr int N = std::min(max_size<AccT>, max_size<T>);
const T* current_in_ptr;
for (int i = 0; i < L; i++, in_ptr += M, out_ptr += 1) {
// Find the maximum
current_in_ptr = in_ptr;
Simd<AccT, N> vmaximum(-numeric_limits<AccT>::infinity());
size_t s = M;
while (s >= N) {
Simd<AccT, N> vals = load<T, N>(current_in_ptr);
vmaximum = maximum(vals, vmaximum);
current_in_ptr += N;
s -= N;
}
AccT maximum = max(vmaximum);
while (s-- > 0) {
maximum = std::max(maximum, static_cast<AccT>(*current_in_ptr));
current_in_ptr++;
}
// Compute the normalizer and the exponentials
Simd<AccT, N> vnormalizer(0.0);
current_in_ptr = in_ptr;
s = M;
while (s >= N) {
Simd<AccT, N> vexp = load<T, N>(current_in_ptr);
vexp = exp(vexp - maximum);
vnormalizer = vnormalizer + vexp;
current_in_ptr += N;
s -= N;
}
AccT normalizer = sum(vnormalizer);
while (s-- > 0) {
AccT _exp = std::exp(*current_in_ptr - maximum);
normalizer += _exp;
current_in_ptr++;
}
// Normalize
*out_ptr = std::isinf(maximum)
? static_cast<T>(maximum)
: static_cast<T>(std::log(normalizer) + maximum);
}
});
}
} // namespace
void LogSumExp::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
// Make sure that the last dimension is contiguous
auto s = stream();
auto& encoder = cpu::get_command_encoder(s);
auto ensure_contiguous = [&s, &encoder](const array& x) {
if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
return x;
} else {
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
copy(x, x_copy, CopyType::General, s);
encoder.add_temporary(x_copy);
return x_copy;
}
};
auto in = ensure_contiguous(inputs[0]);
if (in.flags().row_contiguous) {
out.set_data(allocator::malloc(out.nbytes()));
} else {
auto n = in.shape(-1);
auto flags = in.flags();
auto strides = in.strides();
for (auto& s : strides) {
s /= n;
}
bool col_contig = strides[0] == 1;
for (int i = 1; col_contig && i < strides.size(); ++i) {
col_contig &=
(out.shape(i) == 1 || strides[i - 1] == out.shape(i) * strides[i]);
}
flags.col_contiguous = col_contig;
out.set_data(
allocator::malloc(in.nbytes() / n),
in.data_size() / n,
std::move(strides),
flags);
}
switch (in.dtype()) {
case float32:
logsumexp<float, float>(in, out, stream());
break;
case float16:
logsumexp<float16_t, float>(in, out, stream());
break;
case bfloat16:
logsumexp<bfloat16_t, float>(in, out, stream());
break;
case float64:
logsumexp<double, double>(in, out, stream());
break;
default:
throw std::runtime_error(
"[logsumexp] only supports floating point types");
break;
}
}
} // namespace mlx::core
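The kernel above computes the max-shifted form for numerical stability: with \(m = \max_i x_i\),

\[
\operatorname{logsumexp}(x) \;=\; m + \log \sum_i e^{x_i - m},
\]

which matches the two SIMD passes (find the maximum, then accumulate the normalizer); the final `isinf` check returns `m` directly when the maximum is infinite, so an all `-inf` row yields `-inf` rather than `NaN`.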

View File

@@ -30,8 +30,7 @@ void luf_impl(
auto strides = lu.strides();
strides[ndim - 1] = M;
strides[ndim - 2] = 1;
lu.set_data(
allocator::malloc_or_wait(lu.nbytes()), lu.nbytes(), strides, flags);
lu.set_data(allocator::malloc(lu.nbytes()), lu.nbytes(), strides, flags);
copy_inplace(
a,
lu,
@@ -44,8 +43,8 @@ void luf_impl(
stream);
auto a_ptr = lu.data<T>();
pivots.set_data(allocator::malloc_or_wait(pivots.nbytes()));
row_indices.set_data(allocator::malloc_or_wait(row_indices.nbytes()));
pivots.set_data(allocator::malloc(pivots.nbytes()));
row_indices.set_data(allocator::malloc(row_indices.nbytes()));
auto pivots_ptr = pivots.data<uint32_t>();
auto row_indices_ptr = row_indices.data<uint32_t>();
size_t num_matrices = a.size() / (M * N);

View File

@@ -59,7 +59,7 @@ void BlockMaskedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
throw std::runtime_error(
"[BlockMaskedMM::eval] Currently only supports float32.");
}
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
auto& a_pre = inputs[0];
auto& b_pre = inputs[1];
@@ -318,7 +318,7 @@ void GatherMM::eval_cpu(const std::vector<array>& inputs, array& out) {
throw std::runtime_error(
"[GatherMM::eval] Currently only supports float32.");
}
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
auto& a_pre = inputs[0];
auto& b_pre = inputs[1];

View File

@@ -115,7 +115,7 @@ void matmul_general(
}
void Matmul::eval_cpu(const std::vector<array>& inputs, array& out) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
if (inputs[0].shape(-1) == 0) {
auto& encoder = cpu::get_command_encoder(stream());
encoder.set_output_array(out);

View File

@@ -21,7 +21,7 @@ namespace mlx::core {
void reshape(const array& in, array& out) {
auto [copy_necessary, out_strides] = prepare_reshape(in, out);
if (copy_necessary) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
copy_inplace(in, out, CopyType::General, out.primitive().stream());
} else {
shared_buffer_reshape(in, out_strides, out);
@@ -39,7 +39,7 @@ static std::pair<array, bool> compute_dynamic_offset(
if (donate) {
offset.copy_shared_buffer(indices);
} else {
offset.set_data(allocator::malloc_or_wait(offset.itemsize()));
offset.set_data(allocator::malloc(offset.itemsize()));
}
auto& encoder = cpu::get_command_encoder(stream);
@@ -124,7 +124,7 @@ void Transpose::eval_cpu(const std::vector<array>& inputs, array& out) {
void Arange::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 0);
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
switch (out.dtype()) {
case bool_:
throw std::runtime_error("Bool type unsupported for arange.");
@@ -186,7 +186,7 @@ void Concatenate::eval_cpu(const std::vector<array>& inputs, array& out) {
}
std::partial_sum(sizes.cbegin(), sizes.cend(), sizes.begin());
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
auto strides = out.strides();
auto flags = out.flags();
@@ -205,8 +205,10 @@ void Concatenate::eval_cpu(const std::vector<array>& inputs, array& out) {
void Contiguous::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (in.flags().row_contiguous ||
(allow_col_major_ && in.flags().col_contiguous)) {
constexpr size_t extra_bytes = 16384;
if (in.buffer_size() <= out.nbytes() + extra_bytes &&
(in.flags().row_contiguous ||
(allow_col_major_ && in.flags().col_contiguous))) {
out.copy_shared_buffer(in);
} else {
copy(in, out, CopyType::General, stream());
@@ -276,7 +278,7 @@ void RandomBits::eval_cpu(const std::vector<array>& inputs, array& out) {
size_t elems_per_key = out.size() / num_keys;
size_t bytes_per_key = out.itemsize() * elems_per_key;
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
auto kptr = inputs[0].data<uint32_t>();
auto cptr = out.data<char>();
@@ -335,7 +337,7 @@ void DynamicSlice::eval_cpu(const std::vector<array>& inputs, array& out) {
return;
}
auto& in = inputs[0];
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
auto [in_offset, donated] =
compute_dynamic_offset(inputs[1], in.strides(), axes_, stream());
copy_inplace(
@@ -450,7 +452,7 @@ void View::eval_cpu(const std::vector<array>& inputs, array& out) {
} else {
auto tmp = array(
in.shape(), in.dtype() == bool_ ? uint8 : in.dtype(), nullptr, {});
tmp.set_data(allocator::malloc_or_wait(tmp.nbytes()));
tmp.set_data(allocator::malloc(tmp.nbytes()));
if (in.dtype() == bool_) {
auto in_tmp = array(in.shape(), uint8, nullptr, {});
in_tmp.copy_shared_buffer(in);
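The `Contiguous::eval_cpu` change above adds a size guard before aliasing: a contiguous input is reused only when its underlying buffer is at most 16 KiB larger than the output needs; otherwise the data is copied so an oversized buffer is not kept alive by the output. The condition, as a sketch:

```cpp
#include <cstddef>

// Sketch of the aliasing guard in Contiguous::eval_cpu above.
bool may_alias(size_t in_buffer_size, size_t out_nbytes, bool layout_ok) {
  constexpr size_t extra_bytes = 16384;  // tolerated slack
  return layout_ok && in_buffer_size <= out_nbytes + extra_bytes;
}
```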

View File

@@ -25,12 +25,11 @@ void qrf_impl(const array& a, array& q, array& r, Stream stream) {
auto strides = in.strides();
strides[in.ndim() - 2] = 1;
strides[in.ndim() - 1] = M;
in.set_data(
allocator::malloc_or_wait(in.nbytes()), in.nbytes(), strides, flags);
in.set_data(allocator::malloc(in.nbytes()), in.nbytes(), strides, flags);
copy_inplace(a, in, CopyType::GeneralGeneral, stream);
auto& encoder = cpu::get_command_encoder(stream);
q.set_data(allocator::malloc_or_wait(q.nbytes()));
r.set_data(allocator::malloc_or_wait(r.nbytes()));
q.set_data(allocator::malloc(q.nbytes()));
r.set_data(allocator::malloc(r.nbytes()));
auto in_ptr = in.data<T>();
auto r_ptr = r.data<T>();
@@ -41,8 +40,7 @@ void qrf_impl(const array& a, array& q, array& r, Stream stream) {
encoder.set_output_array(r);
encoder.dispatch([in_ptr, q_ptr, r_ptr, M, N, lda, num_matrices]() {
int num_reflectors = std::min(M, N);
auto tau =
allocator::malloc_or_wait(sizeof(T) * num_matrices * num_reflectors);
auto tau = allocator::malloc(sizeof(T) * num_matrices * num_reflectors);
T optimal_work;
int lwork = -1;
@@ -53,7 +51,7 @@ void qrf_impl(const array& a, array& q, array& r, Stream stream) {
// Update workspace size
lwork = optimal_work;
auto work = allocator::malloc_or_wait(sizeof(T) * lwork);
auto work = allocator::malloc(sizeof(T) * lwork);
// Loop over matrices
for (int i = 0; i < num_matrices; ++i) {
@@ -96,7 +94,7 @@ void qrf_impl(const array& a, array& q, array& r, Stream stream) {
&lwork,
&info);
lwork = optimal_work;
work = allocator::malloc_or_wait(sizeof(T) * lwork);
work = allocator::malloc(sizeof(T) * lwork);
// Loop over matrices
for (int i = 0; i < num_matrices; ++i) {

View File

@@ -515,7 +515,7 @@ void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
auto scales = ensure_row_contiguous(scales_pre);
auto biases = ensure_row_contiguous(biases_pre);
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
auto& encoder = cpu::get_command_encoder(stream());
encoder.add_temporaries(std::move(temps));
@@ -565,7 +565,7 @@ void GatherQMM::eval_cpu(const std::vector<array>& inputs, array& out) {
auto scales = ensure_row_contiguous_last_dims(scales_pre);
auto biases = ensure_row_contiguous_last_dims(biases_pre);
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
auto& encoder = cpu::get_command_encoder(stream());
encoder.add_temporaries(std::move(temps));
@@ -691,12 +691,12 @@ void fast::AffineQuantize::eval_cpu(
auto [w, copied] = ensure_row_contiguous(inputs[0]);
auto& out = outputs[0];
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
auto& scales = outputs[1];
auto& biases = outputs[2];
scales.set_data(allocator::malloc_or_wait(scales.nbytes()));
biases.set_data(allocator::malloc_or_wait(biases.nbytes()));
scales.set_data(allocator::malloc(scales.nbytes()));
biases.set_data(allocator::malloc(biases.nbytes()));
auto& encoder = cpu::get_command_encoder(stream());
if (copied) {
encoder.add_temporary(w);

View File

@@ -433,7 +433,7 @@ void reduce_dispatch_min_max(
void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
auto& encoder = cpu::get_command_encoder(stream());
encoder.set_input_array(in);
encoder.set_output_array(out);

View File

@@ -3,6 +3,7 @@
#include <cassert>
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/binary_ops.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/simd/simd.h"
@@ -226,6 +227,16 @@ void scan_dispatch(
scan_op<T, U>(in, out, axis, reverse, inclusive, op, init);
break;
}
case Scan::LogAddExp: {
auto op = [](U a, T b) {
return detail::LogAddExp{}(a, static_cast<U>(b));
};
auto init = (issubdtype(in.dtype(), floating))
? static_cast<U>(-std::numeric_limits<float>::infinity())
: std::numeric_limits<U>::min();
scan_op<T, U>(in, out, axis, reverse, inclusive, op, init);
break;
}
}
}
@@ -244,7 +255,7 @@ void Scan::eval_cpu(const std::vector<array>& inputs, array& out) {
in = arr_copy;
encoder.add_temporary(arr_copy);
}
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
encoder.set_input_array(in);
encoder.set_output_array(out);
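The new `Scan::LogAddExp` case backs the `logcumsumexp` op added in this compare: a cumulative scan whose combiner is log-add-exp with identity `-inf`. A scalar reference for the combiner (a sketch of the semantics, not MLX's `detail::LogAddExp`):

```cpp
#include <algorithm>
#include <cmath>

// Numerically stable log(exp(a) + exp(b)); the isinf guard keeps
// (-inf, -inf) -> -inf and (+inf, x) -> +inf instead of NaN.
float log_add_exp(float a, float b) {
  float m = std::max(a, b);
  return std::isinf(m) ? m : m + std::log1p(std::exp(-std::fabs(a - b)));
}
```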

View File

@@ -17,7 +17,7 @@ struct ScalarT<float16_t, N> {
#endif
template <>
static constexpr int max_size<float16_t> = N;
inline constexpr int max_size<float16_t> = N;
#define SIMD_FP16_DEFAULT_UNARY(op) \
template <> \

View File

@@ -83,25 +83,25 @@ struct Simd {
// Values chosen based on benchmarks on M3 Max
// TODO: consider choosing these more optimally
template <>
static constexpr int max_size<int8_t> = 16;
inline constexpr int max_size<int8_t> = 16;
template <>
static constexpr int max_size<int16_t> = 16;
inline constexpr int max_size<int16_t> = 16;
template <>
static constexpr int max_size<int> = 8;
inline constexpr int max_size<int> = 8;
template <>
static constexpr int max_size<int64_t> = 4;
inline constexpr int max_size<int64_t> = 4;
template <>
static constexpr int max_size<uint8_t> = 16;
inline constexpr int max_size<uint8_t> = 16;
template <>
static constexpr int max_size<uint16_t> = 16;
inline constexpr int max_size<uint16_t> = 16;
template <>
static constexpr int max_size<uint32_t> = 8;
inline constexpr int max_size<uint32_t> = 8;
template <>
static constexpr int max_size<uint64_t> = 4;
inline constexpr int max_size<uint64_t> = 4;
template <>
static constexpr int max_size<float> = 8;
inline constexpr int max_size<float> = 8;
template <>
static constexpr int max_size<double> = 4;
inline constexpr int max_size<double> = 4;
#define SIMD_DEFAULT_UNARY(name, op) \
template <typename T, int N> \

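These `static` → `inline` changes fix the declarations of the `max_size` variable-template specializations: an explicit specialization may not carry a storage-class specifier (some compilers hard-error on it), and `inline constexpr` gives a single definition shared across translation units. A minimal illustration (not MLX code):

```cpp
template <typename T>
inline constexpr int max_size = 1;  // primary template

template <>
inline constexpr int max_size<float> = 8;  // OK: one definition program-wide

// template <>
// static constexpr int max_size<double> = 4;  // ill-formed on some compilers:
//                                             // storage class on a specialization
```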
View File

@@ -87,7 +87,6 @@ DEFAULT_UNARY(cosh, std::cosh)
DEFAULT_UNARY(expm1, std::expm1)
DEFAULT_UNARY(floor, std::floor)
DEFAULT_UNARY(log, std::log)
DEFAULT_UNARY(log2, std::log2)
DEFAULT_UNARY(log10, std::log10)
DEFAULT_UNARY(log1p, std::log1p)
DEFAULT_UNARY(sinh, std::sinh)
@@ -95,6 +94,17 @@ DEFAULT_UNARY(sqrt, std::sqrt)
DEFAULT_UNARY(tan, std::tan)
DEFAULT_UNARY(tanh, std::tanh)
template <typename T>
Simd<T, 1> log2(Simd<T, 1> in) {
if constexpr (is_complex<T>) {
auto out = std::log(in.value);
auto scale = decltype(out.real())(M_LN2);
return Simd<T, 1>{T{out.real() / scale, out.imag() / scale}};
} else {
return Simd<T, 1>{std::log2(in.value)};
}
}
template <typename T>
Simd<T, 1> operator~(Simd<T, 1> in) {
return ~in.value;

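The new `log2` overload handles complex inputs via change of base applied to the principal logarithm (real inputs keep `std::log2`):

\[
\log_2 z = \frac{\ln z}{\ln 2} = \frac{\ln\lvert z\rvert}{\ln 2} + i\,\frac{\arg z}{\ln 2},
\]

which is exactly the scaling of `out.real()` and `out.imag()` by `1 / M_LN2` in the code above (and why the MSVC fix in this compare defines `_USE_MATH_DEFINES` for `M_LN2`).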
View File

@@ -119,17 +119,12 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
// Make sure that the last dimension is contiguous
auto set_output = [s = stream(), &out](const array& x) {
bool no_copy = x.strides()[x.ndim() - 1] == 1;
if (x.ndim() > 1) {
auto s = x.strides()[x.ndim() - 2];
no_copy &= (s == 0 || s == x.shape().back());
}
if (no_copy) {
if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
if (x.is_donatable()) {
out.copy_shared_buffer(x);
} else {
out.set_data(
allocator::malloc_or_wait(x.data_size() * x.itemsize()),
allocator::malloc(x.data_size() * x.itemsize()),
x.data_size(),
x.strides(),
x.flags());
@@ -146,18 +141,6 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
auto in = set_output(inputs[0]);
switch (in.dtype()) {
case bool_:
case uint8:
case uint16:
case uint32:
case uint64:
case int8:
case int16:
case int32:
case int64:
throw std::runtime_error(
"Softmax is defined only for floating point types");
break;
case float32:
softmax<float, float>(in, out, stream());
break;
@@ -178,9 +161,9 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
case float64:
softmax<double, double>(in, out, stream());
break;
case complex64:
throw std::invalid_argument(
"[Softmax] Not yet implemented for complex64");
default:
throw std::runtime_error(
"[softmax] Only defined for floating point types.");
break;
}
}

View File

@@ -288,7 +288,7 @@ void ArgSort::eval_cpu(const std::vector<array>& inputs, array& out) {
auto& in = inputs[0];
// Allocate output
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
auto& encoder = cpu::get_command_encoder(stream());
encoder.set_input_array(in);
@@ -379,7 +379,7 @@ void ArgPartition::eval_cpu(const std::vector<array>& inputs, array& out) {
auto& in = inputs[0];
// Allocate output
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
auto& encoder = cpu::get_command_encoder(stream());
encoder.set_input_array(in);

View File

@@ -50,9 +50,9 @@ void svd_impl(
array& s = outputs[1];
array& vt = outputs[2];
u.set_data(allocator::malloc_or_wait(u.nbytes()));
s.set_data(allocator::malloc_or_wait(s.nbytes()));
vt.set_data(allocator::malloc_or_wait(vt.nbytes()));
u.set_data(allocator::malloc(u.nbytes()));
s.set_data(allocator::malloc(s.nbytes()));
vt.set_data(allocator::malloc(vt.nbytes()));
encoder.set_output_array(u);
encoder.set_output_array(s);
@@ -64,7 +64,7 @@ void svd_impl(
} else {
array& s = outputs[0];
s.set_data(allocator::malloc_or_wait(s.nbytes()));
s.set_data(allocator::malloc(s.nbytes()));
encoder.set_output_array(s);
@@ -91,7 +91,7 @@ void svd_impl(
// Will contain the indices of eigenvectors that failed to converge (not
// used here but required by lapack).
auto iwork = array::Data{allocator::malloc_or_wait(sizeof(int) * 12 * K)};
auto iwork = array::Data{allocator::malloc(sizeof(int) * 12 * K)};
static const int lwork_query = -1;
@@ -132,7 +132,7 @@ void svd_impl(
}
const int lwork = workspace_dimension;
auto scratch = array::Data{allocator::malloc_or_wait(sizeof(T) * lwork)};
auto scratch = array::Data{allocator::malloc(sizeof(T) * lwork)};
// Loop over matrices.
for (int i = 0; i < num_matrices; i++) {

View File

@@ -1,5 +1,8 @@
// Copyright © 2024 Apple Inc.
// Required for using M_LN2 in MSVC.
#define _USE_MATH_DEFINES
#include <cassert>
#include "mlx/backend/cpu/unary.h"

View File

@@ -18,13 +18,13 @@ void set_unary_output_data(const array& in, array& out) {
} else {
auto size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
allocator::malloc(size * out.itemsize()),
size,
in.strides(),
in.flags());
}
} else {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
}
}

View File

@@ -86,13 +86,14 @@ struct Sign {
template <int N, typename T>
Simd<T, N> operator()(Simd<T, N> x) {
auto z = Simd<T, N>{0};
auto o = Simd<T, N>{1};
auto m = Simd<T, N>{-1};
if constexpr (std::is_unsigned_v<T>) {
return x != z;
return simd::select(x == z, z, o);
} else if constexpr (std::is_same_v<T, complex64_t>) {
return simd::select(x == z, x, Simd<T, N>(x / simd::abs(x)));
} else {
return simd::select(
x < z, Simd<T, N>{-1}, simd::select(x > z, Simd<T, N>{1}, z));
return simd::select(x < z, m, simd::select(x > z, o, z));
}
}
SINGLE()
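The unsigned branch above swaps a raw comparison (which yields a boolean vector) for an explicit select producing 0 or 1 in the element type. A scalar analogue of the fixed behavior, as a sketch:

#include <type_traits>

template <typename T>
T sign_unsigned(T x) {
  static_assert(std::is_unsigned_v<T>, "negative values cannot occur");
  return x == T(0) ? T(0) : T(1); // 0 stays 0, everything else maps to 1
}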

View File

@@ -47,6 +47,7 @@ if(MLX_METAL_JIT)
make_jit_source(binary)
make_jit_source(binary_two)
make_jit_source(fft kernels/fft/radix.h kernels/fft/readwrite.h)
make_jit_source(logsumexp)
make_jit_source(ternary)
make_jit_source(softmax)
make_jit_source(scan)
@@ -95,6 +96,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/logsumexp.cpp
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cpp
${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp

View File

@@ -3,6 +3,7 @@
#include "mlx/backend/metal/metal.h"
#include "mlx/backend/metal/metal_impl.h"
#include "mlx/backend/metal/resident.h"
#include "mlx/memory.h"
#include <mach/vm_page_size.h>
#include <unistd.h>
@@ -32,8 +33,11 @@ namespace metal {
namespace {
BufferCache::BufferCache(MTL::Device* device)
: device_(device), head_(nullptr), tail_(nullptr), pool_size_(0) {}
BufferCache::BufferCache(ResidencySet& residency_set)
: head_(nullptr),
tail_(nullptr),
pool_size_(0),
residency_set_(residency_set) {}
BufferCache::~BufferCache() {
auto pool = metal::new_scoped_memory_pool();
@@ -44,6 +48,9 @@ int BufferCache::clear() {
int n_release = 0;
for (auto& [size, holder] : buffer_pool_) {
if (holder->buf) {
if (!holder->buf->heap()) {
residency_set_.erase(holder->buf);
}
holder->buf->release();
n_release++;
}
@@ -101,6 +108,9 @@ int BufferCache::release_cached_buffers(size_t min_bytes_to_free) {
while (tail_ && (total_bytes_freed < min_bytes_to_free)) {
if (tail_->buf) {
total_bytes_freed += tail_->buf->length();
if (!tail_->buf->heap()) {
residency_set_.erase(tail_->buf);
}
tail_->buf->release();
tail_->buf = nullptr;
n_release++;
@@ -155,7 +165,7 @@ void BufferCache::remove_from_list(BufferCache::BufferHolder* to_remove) {
MetalAllocator::MetalAllocator()
: device_(device(mlx::core::Device::gpu).mtl_device()),
residency_set_(device_),
buffer_cache_(device_) {
buffer_cache_(residency_set_) {
auto pool = metal::new_scoped_memory_pool();
auto memsize = std::get<size_t>(device_info().at("memory_size"));
auto max_rec_size =
@@ -192,16 +202,19 @@ size_t MetalAllocator::set_cache_limit(size_t limit) {
return limit;
};
size_t MetalAllocator::set_memory_limit(size_t limit, bool relaxed) {
size_t MetalAllocator::set_memory_limit(size_t limit) {
std::unique_lock lk(mutex_);
std::swap(limit, block_limit_);
relaxed_ = relaxed;
gc_limit_ = std::min(
block_limit_,
static_cast<size_t>(0.95 * device_->recommendedMaxWorkingSetSize()));
return limit;
};
size_t MetalAllocator::get_memory_limit() {
return block_limit_;
}
size_t MetalAllocator::set_wired_limit(size_t limit) {
std::unique_lock lk(mutex_);
std::swap(limit, wired_limit_);
@@ -209,7 +222,7 @@ size_t MetalAllocator::set_wired_limit(size_t limit) {
return limit;
};
Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
Buffer MetalAllocator::malloc(size_t size) {
// Metal doesn't like empty buffers
if (size == 0) {
return Buffer{nullptr};
@@ -236,11 +249,6 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
if (!buf) {
size_t mem_required = get_active_memory() + get_cache_memory() + size;
// If there is too much memory pressure, fail (likely causes a wait).
if (!(allow_swap && relaxed_) && mem_required >= block_limit_) {
return Buffer{nullptr};
}
auto pool = metal::new_scoped_memory_pool();
// If we have a lot of memory pressure or are over the maximum cache size,
@@ -264,9 +272,13 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
if (!buf) {
buf = device_->newBuffer(size, resource_options);
}
if (!buf) {
return Buffer{nullptr};
}
lk.lock();
if (buf) {
num_resources_++;
num_resources_++;
if (!buf->heap()) {
residency_set_.insert(buf);
}
}
@@ -280,10 +292,6 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
get_cache_memory() - max_pool_size_);
}
if (!buf->heap()) {
residency_set_.insert(buf);
}
return Buffer{static_cast<void*>(buf)};
}
@@ -299,14 +307,14 @@ void MetalAllocator::free(Buffer buffer) {
return;
}
std::unique_lock lk(mutex_);
if (!buf->heap()) {
residency_set_.erase(buf);
}
active_memory_ -= buf->length();
if (get_cache_memory() < max_pool_size_) {
buffer_cache_.recycle_to_cache(buf);
} else {
num_resources_--;
if (!buf->heap()) {
residency_set_.erase(buf);
}
lk.unlock();
auto pool = metal::new_scoped_memory_pool();
buf->release();
@@ -325,37 +333,40 @@ MetalAllocator& allocator() {
return *allocator_;
}
} // namespace metal
size_t set_cache_limit(size_t limit) {
return allocator().set_cache_limit(limit);
return metal::allocator().set_cache_limit(limit);
}
size_t set_memory_limit(size_t limit, bool relaxed /* = true */) {
return allocator().set_memory_limit(limit, relaxed);
size_t set_memory_limit(size_t limit) {
return metal::allocator().set_memory_limit(limit);
}
size_t get_memory_limit() {
return metal::allocator().get_memory_limit();
}
size_t set_wired_limit(size_t limit) {
if (limit >
std::get<size_t>(device_info().at("max_recommended_working_set_size"))) {
if (limit > std::get<size_t>(metal::device_info().at(
"max_recommended_working_set_size"))) {
throw std::invalid_argument(
"[metal::set_wired_limit] Setting a wired limit larger than "
"the maximum working set size is not allowed.");
}
return allocator().set_wired_limit(limit);
return metal::allocator().set_wired_limit(limit);
}
size_t get_active_memory() {
return allocator().get_active_memory();
return metal::allocator().get_active_memory();
}
size_t get_peak_memory() {
return allocator().get_peak_memory();
return metal::allocator().get_peak_memory();
}
void reset_peak_memory() {
allocator().reset_peak_memory();
metal::allocator().reset_peak_memory();
}
size_t get_cache_memory() {
return allocator().get_cache_memory();
return metal::allocator().get_cache_memory();
}
void clear_cache() {
return allocator().clear_cache();
return metal::allocator().clear_cache();
}
} // namespace metal
} // namespace mlx::core
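The allocator changes above concentrate residency bookkeeping at two points: a non-heap buffer is inserted into the residency set exactly once when it is created, stays in the set while it sits in the buffer cache, and is erased only when it is truly released (free past the cache limit, cache eviction, or cache clear). A toy model of that invariant, with std::unordered_set standing in for the Metal ResidencySet:

#include <unordered_set>

struct ToyResidency {
  std::unordered_set<const void*> resident;
  void on_create(const void* buf) { resident.insert(buf); }  // once per buffer
  void on_recycle(const void*) { /* cached buffers stay resident */ }
  void on_release(const void* buf) { resident.erase(buf); }  // once per buffer
};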

View File

@@ -18,7 +18,7 @@ namespace {
class BufferCache {
public:
BufferCache(MTL::Device* device);
BufferCache(ResidencySet& residency_set);
~BufferCache();
MTL::Buffer* reuse_from_cache(size_t size);
@@ -42,13 +42,11 @@ class BufferCache {
void add_at_head(BufferHolder* to_add);
void remove_from_list(BufferHolder* to_remove);
MTL::Device* device_;
MTL::Heap* heap_{nullptr};
std::multimap<size_t, BufferHolder*> buffer_pool_;
BufferHolder* head_;
BufferHolder* tail_;
size_t pool_size_;
ResidencySet& residency_set_;
};
} // namespace
@@ -56,7 +54,7 @@ class BufferCache {
class MetalAllocator : public allocator::Allocator {
/** Allocator for Metal GPUs. */
public:
virtual Buffer malloc(size_t size, bool allow_swap = false) override;
virtual Buffer malloc(size_t size) override;
virtual void free(Buffer buffer) override;
virtual size_t size(Buffer buffer) const override;
size_t get_active_memory() {
@@ -73,7 +71,8 @@ class MetalAllocator : public allocator::Allocator {
return buffer_cache_.cache_size();
};
size_t set_cache_limit(size_t limit);
size_t set_memory_limit(size_t limit, bool relaxed);
size_t set_memory_limit(size_t limit);
size_t get_memory_limit();
size_t set_wired_limit(size_t limit);
void clear_cache();

View File

@@ -37,7 +37,7 @@ void explicit_gemm_conv_ND_gpu(
Shape unfolded_shape{implicit_M, implicit_K};
array in_unfolded(unfolded_shape, in.dtype(), nullptr, {});
in_unfolded.set_data(allocator::malloc_or_wait(in_unfolded.nbytes()));
in_unfolded.set_data(allocator::malloc(in_unfolded.nbytes()));
// Prepare unfolding kernel
std::ostringstream kname;
@@ -115,7 +115,7 @@ void explicit_gemm_conv_group_ND_gpu(
// Prepare unfolding array
Shape unfolded_shape{implicit_M, implicit_K * groups};
array in_unfolded(unfolded_shape, in.dtype(), nullptr, {});
in_unfolded.set_data(allocator::malloc_or_wait(in_unfolded.nbytes()));
in_unfolded.set_data(allocator::malloc(in_unfolded.nbytes()));
// Prepare unfolding kernel
std::ostringstream kname;
@@ -613,7 +613,7 @@ void winograd_conv_2D_gpu(
// Do filter transform
Shape filt_wg_shape = {8 * 8, conv_params.C, conv_params.O};
array filt_wg(std::move(filt_wg_shape), wt.dtype(), nullptr, {});
filt_wg.set_data(allocator::malloc_or_wait(filt_wg.nbytes()));
filt_wg.set_data(allocator::malloc(filt_wg.nbytes()));
copies_w.push_back(filt_wg);
{
int bc = 32;
@@ -640,7 +640,7 @@ void winograd_conv_2D_gpu(
// Do input transform
Shape inp_wg_shape = {8 * 8, N_tiles, conv_params.C};
array inp_wg(std::move(inp_wg_shape), in.dtype(), nullptr, {});
inp_wg.set_data(allocator::malloc_or_wait(inp_wg.nbytes()));
inp_wg.set_data(allocator::malloc(inp_wg.nbytes()));
copies_w.push_back(inp_wg);
{
int bc = 32;
@@ -667,7 +667,7 @@ void winograd_conv_2D_gpu(
// Do batched gemm
Shape out_wg_shape = {8 * 8, N_tiles, conv_params.O};
array out_wg(std::move(out_wg_shape), in.dtype(), nullptr, {});
out_wg.set_data(allocator::malloc_or_wait(out_wg.nbytes()));
out_wg.set_data(allocator::malloc(out_wg.nbytes()));
copies_w.push_back(out_wg);
{
std::vector<array> empty_copies;
@@ -712,6 +712,65 @@ void winograd_conv_2D_gpu(
}
}
void depthwise_conv_2D_gpu(
const Stream& s,
metal::Device& d,
const array& in,
const array& wt,
array out,
const MLXConvParams<2>& conv_params) {
std::ostringstream kname;
kname << "depthwise_conv_2d_" << type_to_name(out);
std::string base_name = kname.str();
const int N = conv_params.N;
const int ker_h = conv_params.wS[0];
const int ker_w = conv_params.wS[1];
const int str_h = conv_params.str[0];
const int str_w = conv_params.str[1];
const int tc = 8;
const int tw = 8;
const int th = 4;
const bool do_flip = conv_params.flip;
metal::MTLFCList func_consts = {
{&ker_h, MTL::DataType::DataTypeInt, 00},
{&ker_w, MTL::DataType::DataTypeInt, 01},
{&str_h, MTL::DataType::DataTypeInt, 10},
{&str_w, MTL::DataType::DataTypeInt, 11},
{&th, MTL::DataType::DataTypeInt, 100},
{&tw, MTL::DataType::DataTypeInt, 101},
{&do_flip, MTL::DataType::DataTypeBool, 200},
};
// clang-format off
kname << "_ker_h_" << ker_h
<< "_ker_w_" << ker_w
<< "_str_h_" << str_h
<< "_str_w_" << str_w
<< "_tgp_h_" << th
<< "_tgp_w_" << tw
<< "_do_flip_" << (do_flip ? 't' : 'n'); // clang-format on
std::string hash_name = kname.str();
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(base_name, "mlx", hash_name, func_consts);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(in, 0);
compute_encoder.set_input_array(wt, 1);
compute_encoder.set_output_array(out, 2);
compute_encoder.set_bytes(conv_params, 3);
MTL::Size group_dims = MTL::Size(tc, tw, th);
MTL::Size grid_dims = MTL::Size(
conv_params.C / tc, conv_params.oS[1] / tw, (conv_params.oS[0] / th) * N);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
void conv_2D_gpu(
const Stream& s,
metal::Device& d,
@@ -754,11 +813,20 @@ void conv_2D_gpu(
bool is_kdil_one = conv_params.kdil[0] == 1 && conv_params.kdil[1] == 1;
bool is_idil_one = conv_params.idil[0] == 1 && conv_params.idil[1] == 1;
if (groups > 1) {
if (is_idil_one && groups > 1) {
const int C_per_group = conv_params.C / groups;
const int O_per_group = conv_params.O / groups;
if (is_idil_one && (C_per_group <= 4 || C_per_group % 16 == 0) &&
if (C_per_group == 1 && O_per_group == 1 && is_kdil_one &&
conv_params.wS[0] <= 7 && conv_params.wS[1] <= 7 &&
conv_params.str[0] <= 2 && conv_params.str[1] <= 2 &&
conv_params.oS[0] % 8 == 0 && conv_params.oS[1] % 8 == 0 &&
conv_params.wt_strides[1] == conv_params.wS[1] &&
conv_params.C % 16 == 0 && conv_params.C == conv_params.O) {
return depthwise_conv_2D_gpu(s, d, in, wt, out, conv_params);
}
if ((C_per_group <= 4 || C_per_group % 16 == 0) &&
(O_per_group <= 16 || O_per_group % 16 == 0)) {
return implicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params);
} else {
@@ -855,7 +923,7 @@ void conv_3D_gpu(
} // namespace
void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
auto& s = stream();
auto& d = metal::device(s.device);
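For readability, the branch conditions added in the dispatch hunk above can be condensed into a predicate; this is a restatement of the hunk, not code from the tree:

// Condensed eligibility test for the new depthwise specialization (the
// caller additionally requires groups > 1 and unit input dilation before
// reaching this check).
bool use_depthwise(const MLXConvParams<2>& p, int groups) {
  const int C_per_group = p.C / groups;
  const int O_per_group = p.O / groups;
  return C_per_group == 1 && O_per_group == 1 &&
      p.kdil[0] == 1 && p.kdil[1] == 1 &&     // unit kernel dilation
      p.wS[0] <= 7 && p.wS[1] <= 7 &&         // small kernel
      p.str[0] <= 2 && p.str[1] <= 2 &&       // small stride
      p.oS[0] % 8 == 0 && p.oS[1] % 8 == 0 && // aligned output tiles
      p.wt_strides[1] == p.wS[1] &&           // dense weight rows
      p.C % 16 == 0 && p.C == p.O;            // depthwise: C == O
}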

View File

@@ -202,7 +202,7 @@ void fill_gpu(const array& val, array& out, const Stream& s) {
if (out.size() == 0) {
return;
}
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
bool large = out.data_size() > UINT32_MAX;
auto& d = metal::device(s.device);
std::string kernel_name = std::string(large ? "s2" : "s") + "_copy" +

View File

@@ -19,7 +19,7 @@ void CustomKernel::eval_gpu(
copies.emplace_back(init_value_.value(), out.dtype());
fill_gpu(copies.back(), out, s);
} else {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
}
}

View File

@@ -55,7 +55,10 @@ std::pair<MTL::Library*, NS::Error*> load_library_from_path(
}
#ifdef SWIFTPM_BUNDLE
MTL::Library* try_load_bundle(MTL::Device* device, NS::URL* url) {
MTL::Library* try_load_bundle(
MTL::Device* device,
NS::URL* url,
const std::string& lib_name) {
std::string bundle_path = std::string(url->fileSystemRepresentation()) + "/" +
SWIFTPM_BUNDLE + ".bundle";
auto bundle = NS::Bundle::alloc()->init(
@@ -63,8 +66,8 @@ MTL::Library* try_load_bundle(MTL::Device* device, NS::URL* url) {
if (bundle != nullptr) {
std::string resource_path =
std::string(bundle->resourceURL()->fileSystemRepresentation()) + "/" +
"default.metallib";
auto [lib, error] = load_library_from_path(device, resource_path.c_str());
lib_name + ".metallib" auto [lib, error] =
load_library_from_path(device, resource_path.c_str());
if (lib) {
return lib;
}
@@ -73,51 +76,124 @@ MTL::Library* try_load_bundle(MTL::Device* device, NS::URL* url) {
}
#endif
// Firstly, search for the metallib in the same path as this binary
std::pair<MTL::Library*, NS::Error*> load_colocated_library(
MTL::Device* device,
const std::string& lib_name) {
std::string lib_path = get_colocated_mtllib_path(lib_name);
if (lib_path.size() != 0) {
return load_library_from_path(device, lib_path.c_str());
}
return {nullptr, nullptr};
}
std::pair<MTL::Library*, NS::Error*> load_swiftpm_library(
MTL::Device* device,
const std::string& lib_name) {
#ifdef SWIFTPM_BUNDLE
MTL::Library* library =
try_load_bundle(device, NS::Bundle::mainBundle()->bundleURL(), lib_name);
if (library != nullptr) {
return {library, nullptr};
}
auto bundles = NS::Bundle::allBundles();
for (int i = 0, c = (int)bundles->count(); i < c; i++) {
auto bundle = reinterpret_cast<NS::Bundle*>(bundles->object(i));
library = try_load_bundle(device, bundle->resourceURL());
if (library != nullptr) {
return {library, nullptr};
}
}
#endif
return {nullptr, nullptr};
}
MTL::Library* load_default_library(MTL::Device* device) {
NS::Error *error1, *error2, *error3;
MTL::Library* lib;
// First try the colocated mlx.metallib
std::tie(lib, error1) = load_colocated_library(device, "mlx");
if (lib) {
return lib;
}
// Then try default.metallib in a SwiftPM bundle if we have one
std::tie(lib, error2) = load_swiftpm_library(device, "default");
if (lib) {
return lib;
}
// Finally try default_mtllib_path
std::tie(lib, error3) = load_library_from_path(device, default_mtllib_path);
if (!lib) {
std::ostringstream msg;
msg << "Failed to load the default metallib. ";
if (error1 != nullptr) {
msg << error1->localizedDescription()->utf8String() << " ";
}
if (error2 != nullptr) {
msg << error2->localizedDescription()->utf8String() << " ";
}
if (error3 != nullptr) {
msg << error3->localizedDescription()->utf8String() << " ";
}
throw std::runtime_error(msg.str());
}
return lib;
}
MTL::Library* load_library(
MTL::Device* device,
const std::string& lib_name = "mlx",
const char* lib_path = default_mtllib_path) {
// Firstly, search for the metallib in the same path as this binary
std::string first_path = get_colocated_mtllib_path(lib_name);
if (first_path.size() != 0) {
auto [lib, error] = load_library_from_path(device, first_path.c_str());
const std::string& lib_name,
const std::string& lib_path) {
// We have been given a path that ends in metallib so try to load it
if (lib_path.size() > 9 &&
std::equal(lib_path.end() - 9, lib_path.end(), ".metallib")) {
auto [lib, error] = load_library_from_path(device, lib_path.c_str());
if (!lib) {
std::ostringstream msg;
msg << "Failed to load the metallib from <" << lib_path << "> with error "
<< error->localizedDescription()->utf8String();
throw std::runtime_error(msg.str());
}
return lib;
}
// We have been given a path so try to load from lib_path / lib_name.metallib
if (lib_path.size() > 0) {
std::string full_path = lib_path + "/" + lib_name + ".metallib";
auto [lib, error] = load_library_from_path(device, full_path.c_str());
if (!lib) {
std::ostringstream msg;
msg << "Failed to load the metallib from <" << full_path
<< "> with error " << error->localizedDescription()->utf8String();
throw std::runtime_error(msg.str());
}
return lib;
}
// Try to load the colocated library
{
auto [lib, error] = load_colocated_library(device, lib_name);
if (lib) {
return lib;
}
}
#ifdef SWIFTPM_BUNDLE
// try to load from a swiftpm resource bundle -- scan the available bundles to
// find one that contains the named bundle
// Try to load the library from swiftpm
{
MTL::Library* library =
try_load_bundle(device, NS::Bundle::mainBundle()->bundleURL());
if (library != nullptr) {
return library;
}
auto bundles = NS::Bundle::allBundles();
for (int i = 0, c = (int)bundles->count(); i < c; i++) {
auto bundle = reinterpret_cast<NS::Bundle*>(bundles->object(i));
library = try_load_bundle(device, bundle->resourceURL());
if (library != nullptr) {
return library;
}
auto [lib, error] = load_swiftpm_library(device, lib_name);
if (lib) {
return lib;
}
}
#endif
// Couldn't find it so let's load it from default_mtllib_path
{
auto [lib, error] = load_library_from_path(device, lib_path);
if (!lib) {
std::ostringstream msg;
msg << error->localizedDescription()->utf8String() << "\n"
<< "Failed to load device library from <" << lib_path << ">"
<< " or <" << first_path << ">.";
throw std::runtime_error(msg.str());
}
return lib;
}
std::ostringstream msg;
msg << "Failed to load the metallib " << lib_name << ".metallib. "
<< "We attempted to load it from <" << get_colocated_mtllib_path(lib_name)
<< ">";
#ifdef SWIFTPM_BUNDLE
msg << " and from the Swift PM bundle.";
#endif
throw std::runtime_error(msg.str());
}
} // namespace
@@ -210,7 +286,7 @@ void CommandEncoder::barrier() {
Device::Device() {
auto pool = new_scoped_memory_pool();
device_ = load_device();
library_map_ = {{"mlx", load_library(device_)}};
library_map_ = {{"mlx", load_default_library(device_)}};
arch_ = std::string(device_->architecture()->name()->utf8String());
auto arch = arch_.back();
switch (arch) {
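Summarizing the lookup order the rewritten load_library implements, a hypothetical call sequence (paths invented for illustration; device obtained via load_device() as in the constructor above):

// 1. An explicit *.metallib path is loaded directly, or the call throws.
auto* a = load_library(device, "custom", "/opt/kernels/custom.metallib");
// 2. A directory path resolves to <lib_path>/<lib_name>.metallib.
auto* b = load_library(device, "custom", "/opt/kernels");
// 3. An empty path falls through the chain: colocated metallib, then the
//    SwiftPM bundle, then a descriptive runtime_error naming what failed.
auto* c = load_library(device, "custom", "");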

View File

@@ -189,15 +189,7 @@ class Device {
void register_library(
const std::string& lib_name,
const std::string& lib_path);
// Note, this should remain in the header so that it is not dynamically
// linked
void register_library(const std::string& lib_name) {
if (auto it = library_map_.find(lib_name); it == library_map_.end()) {
register_library(lib_name, get_colocated_mtllib_path(lib_name));
}
}
const std::string& lib_path = "");
MTL::Library* get_library(
const std::string& name,

View File

@@ -24,10 +24,6 @@ void Event::wait() {
}
}
void Event::signal() {
static_cast<MTL::SharedEvent*>(event_.get())->setSignaledValue(value());
}
void Event::wait(Stream stream) {
if (stream.device == Device::cpu) {
scheduler::enqueue(stream, [*this]() mutable { wait(); });
@@ -42,7 +38,9 @@ void Event::wait(Stream stream) {
void Event::signal(Stream stream) {
if (stream.device == Device::cpu) {
scheduler::enqueue(stream, [*this]() mutable { signal(); });
scheduler::enqueue(stream, [*this]() mutable {
static_cast<MTL::SharedEvent*>(event_.get())->setSignaledValue(value());
});
} else {
auto& d = metal::device(stream.device);
d.end_encoding(stream.index);

View File

@@ -20,7 +20,7 @@ struct FenceImpl {
auto p = metal::new_scoped_memory_pool();
fence = static_cast<void*>(d->newSharedEvent());
} else {
auto buf = allocator::malloc_or_wait(sizeof(uint32_t)).ptr();
auto buf = allocator::malloc(sizeof(uint32_t)).ptr();
fence = static_cast<void*>(buf);
cpu_value()[0] = 0;
}

View File

@@ -281,7 +281,7 @@ std::tuple<array, array, array> compute_raders_constants(
}
array b_q_fft({rader_n - 1}, complex64, nullptr, {});
b_q_fft.set_data(allocator::malloc_or_wait(b_q_fft.nbytes()));
b_q_fft.set_data(allocator::malloc(b_q_fft.nbytes()));
auto b_q_fft_ptr =
reinterpret_cast<std::complex<float>*>(b_q_fft.data<complex64_t>());
std::ptrdiff_t item_size = b_q_fft.itemsize();
@@ -327,11 +327,11 @@ std::pair<array, array> compute_bluestein_constants(int n, int bluestein_n) {
}
array w_k({n}, complex64, nullptr, {});
w_k.set_data(allocator::malloc_or_wait(w_k.nbytes()));
w_k.set_data(allocator::malloc(w_k.nbytes()));
std::copy(w_k_vec.begin(), w_k_vec.end(), w_k.data<complex64_t>());
array w_q({bluestein_n}, complex64, nullptr, {});
w_q.set_data(allocator::malloc_or_wait(w_q.nbytes()));
w_q.set_data(allocator::malloc(w_q.nbytes()));
auto w_q_ptr =
reinterpret_cast<std::complex<float>*>(w_q.data<complex64_t>());
@@ -356,20 +356,14 @@ void multi_upload_bluestein_fft(
bool inverse,
bool real,
FFTPlan& plan,
std::vector<array> copies,
std::vector<array>& copies,
const Stream& s) {
// TODO(alexbarron) Implement fused kernels for multi upload bluestein's
// algorithm
int n = inverse ? out.shape(axis) : in.shape(axis);
auto [w_k, w_q] = compute_bluestein_constants(n, plan.bluestein_n);
// Broadcast w_q and w_k to the batch size
Strides b_strides(in.ndim(), 0);
b_strides[axis] = 1;
array w_k_broadcast({}, complex64, nullptr, {});
array w_q_broadcast({}, complex64, nullptr, {});
w_k_broadcast.copy_shared_buffer(w_k, b_strides, {}, w_k.data_size());
w_q_broadcast.copy_shared_buffer(w_q, b_strides, {}, w_q.data_size());
copies.push_back(w_k);
copies.push_back(w_q);
auto temp_shape = inverse ? out.shape() : in.shape();
array temp(temp_shape, complex64, nullptr, {});
@@ -378,13 +372,13 @@ void multi_upload_bluestein_fft(
if (real && !inverse) {
// Convert float32->complex64
copy_gpu(in, temp, CopyType::General, s);
copies.push_back(temp);
} else if (real && inverse) {
int back_offset = n % 2 == 0 ? 2 : 1;
auto slice_shape = in.shape();
slice_shape[axis] -= back_offset;
array slice_temp(slice_shape, complex64, nullptr, {});
array conj_temp(in.shape(), complex64, nullptr, {});
copies.push_back(slice_temp);
copies.push_back(conj_temp);
Shape rstarts(in.ndim(), 0);
@@ -394,19 +388,28 @@ void multi_upload_bluestein_fft(
unary_op_gpu({in}, conj_temp, "Conjugate", s);
slice_gpu(in, slice_temp, rstarts, rstrides, s);
concatenate_gpu({conj_temp, slice_temp}, temp, (int)axis, s);
copies.push_back(temp);
} else if (inverse) {
unary_op_gpu({in}, temp, "Conjugate", s);
copies.push_back(temp);
} else {
temp.copy_shared_buffer(in);
}
Strides b_strides(in.ndim(), 0);
b_strides[axis] = 1;
array w_k_broadcast(temp.shape(), complex64, nullptr, {});
w_k_broadcast.copy_shared_buffer(w_k, b_strides, {}, w_k.data_size());
binary_op_gpu({temp, w_k_broadcast}, temp1, "Multiply", s);
std::vector<std::pair<int, int>> pads;
auto padded_shape = out.shape();
padded_shape[axis] = plan.bluestein_n;
array pad_temp(padded_shape, complex64, nullptr, {});
pad_gpu(temp1, array(complex64_t{0.0f, 0.0f}), pad_temp, {(int)axis}, {0}, s);
auto zero = array(complex64_t{0.0f, 0.0f});
copies.push_back(zero);
pad_gpu(temp1, zero, pad_temp, {(int)axis}, {0}, s);
copies.push_back(pad_temp);
array pad_temp1(padded_shape, complex64, nullptr, {});
fft_op(
@@ -418,7 +421,10 @@ void multi_upload_bluestein_fft(
FourStepParams(),
/*inplace=*/false,
s);
copies.push_back(pad_temp1);
array w_q_broadcast(pad_temp1.shape(), complex64, nullptr, {});
w_q_broadcast.copy_shared_buffer(w_q, b_strides, {}, w_q.data_size());
binary_op_gpu_inplace({pad_temp1, w_q_broadcast}, pad_temp, "Multiply", s);
fft_op(
@@ -435,9 +441,11 @@ void multi_upload_bluestein_fft(
Shape starts(in.ndim(), 0);
Shape strides(in.ndim(), 1);
starts[axis] = plan.bluestein_n - offset - n;
slice_gpu(pad_temp1, temp, starts, strides, s);
binary_op_gpu_inplace({temp, w_k_broadcast}, temp1, "Multiply", s);
array temp2(temp_shape, complex64, nullptr, {});
slice_gpu(pad_temp1, temp2, starts, strides, s);
binary_op_gpu_inplace({temp2, w_k_broadcast}, temp1, "Multiply", s);
if (real && !inverse) {
Shape rstarts(in.ndim(), 0);
@@ -449,26 +457,21 @@ void multi_upload_bluestein_fft(
array temp_float(out.shape(), out.dtype(), nullptr, {});
copies.push_back(temp_float);
copies.push_back(inv_n);
copies.push_back(temp1);
copy_gpu(temp1, temp_float, CopyType::General, s);
binary_op_gpu({temp_float, inv_n}, out, "Multiply", s);
} else if (inverse) {
auto inv_n = array({1.0f / n}, {1}, complex64);
unary_op_gpu({temp1}, temp, "Conjugate", s);
binary_op_gpu({temp, inv_n}, out, "Multiply", s);
array temp3(temp_shape, complex64, nullptr, {});
unary_op_gpu({temp1}, temp3, "Conjugate", s);
binary_op_gpu({temp3, inv_n}, out, "Multiply", s);
copies.push_back(inv_n);
copies.push_back(temp1);
copies.push_back(temp3);
} else {
out.copy_shared_buffer(temp1);
}
copies.push_back(w_k);
copies.push_back(w_q);
copies.push_back(w_k_broadcast);
copies.push_back(w_q_broadcast);
copies.push_back(temp);
copies.push_back(temp1);
copies.push_back(pad_temp);
copies.push_back(pad_temp1);
}
void four_step_fft(
@@ -478,8 +481,9 @@ void four_step_fft(
bool inverse,
bool real,
FFTPlan& plan,
std::vector<array> copies,
const Stream& s) {
std::vector<array>& copies,
const Stream& s,
bool in_place) {
auto& d = metal::device(s.device);
if (plan.bluestein_n == -1) {
@@ -492,7 +496,14 @@ void four_step_fft(
in, temp, axis, inverse, real, four_step_params, /*inplace=*/false, s);
four_step_params.first_step = false;
fft_op(
temp, out, axis, inverse, real, four_step_params, /*inplace=*/false, s);
temp,
out,
axis,
inverse,
real,
four_step_params,
/*inplace=*/in_place,
s);
copies.push_back(temp);
} else {
multi_upload_bluestein_fft(in, out, axis, inverse, real, plan, copies, s);
@@ -551,8 +562,7 @@ void fft_op(
flags.row_contiguous = is_row_contiguous;
flags.contiguous = data_size == x_copy.size();
x_copy.set_data(
allocator::malloc_or_wait(x.nbytes()), data_size, strides, flags);
x_copy.set_data(allocator::malloc(x.nbytes()), data_size, strides, flags);
copy_gpu_inplace(x, x_copy, CopyType::GeneralGeneral, s);
copies.push_back(x_copy);
return x_copy;
@@ -575,7 +585,7 @@ void fft_op(
auto plan = plan_fft(n);
if (plan.four_step) {
four_step_fft(in, out, axis, inverse, real, plan, copies, s);
four_step_fft(in, out, axis, inverse, real, plan, copies, s, inplace);
d.add_temporaries(std::move(copies), s.index);
return;
}
@@ -583,7 +593,7 @@ void fft_op(
// TODO: allow donation here
if (!inplace) {
out.set_data(
allocator::malloc_or_wait(out.nbytes()),
allocator::malloc(out.nbytes()),
out_data_size,
out_strides,
in_contiguous.flags());
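The signature change from std::vector<array> copies to std::vector<array>& is the crux of this hunk: with pass-by-value, temporaries appended inside the callee never reach the caller's vector, so add_temporaries cannot keep their buffers alive while the GPU still reads them. A toy illustration of the difference:

#include <cassert>
#include <vector>

void collect_by_value(std::vector<int> copies) { copies.push_back(7); }
void collect_by_ref(std::vector<int>& copies) { copies.push_back(7); }

int main() {
  std::vector<int> v;
  collect_by_value(v);
  assert(v.empty());     // the temporary was lost with the copy
  collect_by_ref(v);
  assert(v.size() == 1); // the caller now owns the temporary
}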

View File

@@ -84,7 +84,7 @@ void Hadamard::eval_gpu(const std::vector<array>& inputs, array& out) {
if (in_contiguous.is_donatable()) {
out.copy_shared_buffer(in_contiguous);
} else {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
}
int n, m;
@@ -161,7 +161,7 @@ void Hadamard::eval_gpu(const std::vector<array>& inputs, array& out) {
// Upload 2:
// y = h12 @ tmp
array temp(in.shape(), in.dtype(), nullptr, {});
temp.set_data(allocator::malloc_or_wait(temp.nbytes()));
temp.set_data(allocator::malloc(temp.nbytes()));
copies.push_back(temp);
launch_hadamard(in_contiguous, temp, "n" + kernel_name, 1.0);

View File

@@ -43,7 +43,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
throw std::runtime_error(msg.str());
}
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
if (out.size() == 0) {
return;
}
@@ -393,7 +393,7 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& src = inputs[0];
auto& idx = inputs[1];
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
if (out.size() == 0) {
return;
}

View File

@@ -1,9 +0,0 @@
// Copyright © 2024 Apple Inc.
constexpr std::string_view arange_kernels = R"(
template [[host_name("{0}")]] [[kernel]] void arange<{1}>(
constant const {1}& start,
constant const {1}& step,
device {1}* out,
uint index [[thread_position_in_grid]]);
)";

View File

@@ -20,6 +20,7 @@ const char* copy();
const char* fft();
const char* gather_axis();
const char* hadamard();
const char* logsumexp();
const char* quantized();
const char* ternary();
const char* scan();

View File

@@ -1,23 +0,0 @@
// Copyright © 2024 Apple Inc.
constexpr std::string_view softmax_kernels = R"(
template [[host_name("block_{0}")]] [[kernel]] void
softmax_single_row<{1}, {2}>(
const device {1}* in,
device {1}* out,
constant int& axis_size,
uint gid [[thread_position_in_grid]],
uint _lid [[thread_position_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
template [[host_name("looped_{0}")]] [[kernel]] void
softmax_looped<{1}, {2}>(
const device {1}* in,
device {1}* out,
constant int& axis_size,
uint gid [[threadgroup_position_in_grid]],
uint lid [[thread_position_in_threadgroup]],
uint lsize [[threads_per_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
)";

View File

@@ -1,8 +1,6 @@
// Copyright © 2024 Apple Inc.
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/metal/jit/arange.h"
#include "mlx/backend/metal/jit/includes.h"
#include "mlx/backend/metal/jit/softmax.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/utils.h"
@@ -21,13 +19,11 @@ MTL::ComputePipelineState* get_arange_kernel(
const std::string& kernel_name,
const array& out) {
auto lib = d.get_library(kernel_name, [&]() {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::arange()
<< fmt::format(
arange_kernels,
kernel_name,
get_type_string(out.dtype()));
return kernel_source.str();
std::string kernel_source = metal::utils();
kernel_source += metal::arange();
kernel_source += get_template_definition(
kernel_name, "arange", get_type_string(out.dtype()));
return kernel_source;
});
return d.get_kernel(kernel_name, lib);
}
@@ -259,14 +255,34 @@ MTL::ComputePipelineState* get_softmax_kernel(
const array& out) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name, [&] {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::softmax()
<< fmt::format(
softmax_kernels,
lib_name,
get_type_string(out.dtype()),
get_type_string(precise ? float32 : out.dtype()));
return kernel_source.str();
std::string kernel_source = metal::utils();
auto in_type = get_type_string(out.dtype());
auto acc_type = get_type_string(precise ? float32 : out.dtype());
kernel_source += metal::softmax();
kernel_source += get_template_definition(
"block_" + lib_name, "softmax_single_row", in_type, acc_type);
kernel_source += get_template_definition(
"looped_" + lib_name, "softmax_looped", in_type, acc_type);
return kernel_source;
});
return d.get_kernel(kernel_name, lib);
}
MTL::ComputePipelineState* get_logsumexp_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& out) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name, [&] {
auto t_str = get_type_string(out.dtype());
std::string kernel_source;
kernel_source = metal::utils();
kernel_source += metal::logsumexp();
kernel_source +=
get_template_definition("block_" + lib_name, "logsumexp", t_str);
kernel_source += get_template_definition(
"looped_" + lib_name, "logsumexp_looped", t_str);
return kernel_source;
});
return d.get_kernel(kernel_name, lib);
}

View File

@@ -59,6 +59,11 @@ MTL::ComputePipelineState* get_softmax_kernel(
bool precise,
const array& out);
MTL::ComputePipelineState* get_logsumexp_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& out);
MTL::ComputePipelineState* get_scan_kernel(
metal::Device& d,
const std::string& kernel_name,

View File

@@ -13,6 +13,10 @@ function(build_kernel_base TARGET SRCFILE DEPS)
if(MLX_METAL_DEBUG)
set(METAL_FLAGS ${METAL_FLAGS} -gline-tables-only -frecord-sources)
endif()
if(NOT CMAKE_OSX_DEPLOYMENT_TARGET STREQUAL "")
set(METAL_FLAGS ${METAL_FLAGS}
"-mmacosx-version-min=${CMAKE_OSX_DEPLOYMENT_TARGET}")
endif()
if(MLX_METAL_VERSION GREATER_EQUAL 310)
set(VERSION_INCLUDES
${PROJECT_SOURCE_DIR}/mlx/backend/metal/kernels/metal_3_1)
@@ -105,6 +109,7 @@ if(NOT MLX_METAL_JIT)
build_kernel(quantized quantized.h ${STEEL_HEADERS})
build_kernel(scan scan.h)
build_kernel(softmax softmax.h)
build_kernel(logsumexp logsumexp.h)
build_kernel(sort sort.h)
build_kernel(ternary ternary.h ternary_ops.h)
build_kernel(unary unary.h unary_ops.h)

View File

@@ -5,11 +5,7 @@
#include "mlx/backend/metal/kernels/arange.h"
#define instantiate_arange(tname, type) \
template [[host_name("arange" #tname)]] [[kernel]] void arange<type>( \
constant const type& start, \
constant const type& step, \
device type* out, \
uint index [[thread_position_in_grid]]);
instantiate_kernel("arange" #tname, arange, type)
instantiate_arange(uint8, uint8_t)
instantiate_arange(uint16, uint16_t)

View File

@@ -275,6 +275,128 @@ instantiate_naive_conv_2d_blocks(float32, float);
instantiate_naive_conv_2d_blocks(float16, half);
instantiate_naive_conv_2d_blocks(bfloat16, bfloat16_t);
///////////////////////////////////////////////////////////////////////////////
/// Depthwise convolution kernels
///////////////////////////////////////////////////////////////////////////////
constant int ker_h [[function_constant(00)]];
constant int ker_w [[function_constant(01)]];
constant int str_h [[function_constant(10)]];
constant int str_w [[function_constant(11)]];
constant int tgp_h [[function_constant(100)]];
constant int tgp_w [[function_constant(101)]];
constant bool do_flip [[function_constant(200)]];
constant int span_h = tgp_h * str_h + ker_h - 1;
constant int span_w = tgp_w * str_w + ker_w - 1;
constant int span_hw = span_h * span_w;
template <typename T>
[[kernel]] void depthwise_conv_2d(
const device T* in [[buffer(0)]],
const device T* wt [[buffer(1)]],
device T* out [[buffer(2)]],
const constant MLXConvParams<2>& params [[buffer(3)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint3 gid [[thread_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
constexpr int tc = 8;
constexpr int tw = 8;
constexpr int th = 4;
constexpr int c_per_thr = 8;
constexpr int TGH = th * 2 + 6;
constexpr int TGW = tw * 2 + 6;
constexpr int TGC = tc;
threadgroup T ins[TGH * TGW * TGC];
const int n_tgblocks_h = params.oS[0] / th;
const int n = tid.z / n_tgblocks_h;
const int tghid = tid.z % n_tgblocks_h;
const int oh = tghid * th + lid.z;
const int ow = gid.y;
const int c = gid.x;
in += n * params.in_strides[0];
// Load in
{
constexpr int n_threads = th * tw * tc;
const int tg_oh = (tghid * th) * str_h - params.pad[0];
const int tg_ow = (tid.y * tw) * str_w - params.pad[1];
const int tg_c = tid.x * tc;
const int thread_idx = simd_gid * 32 + simd_lid;
constexpr int thr_per_hw = tc / c_per_thr;
constexpr int hw_per_group = n_threads / thr_per_hw;
const int thr_c = thread_idx % thr_per_hw;
const int thr_hw = thread_idx / thr_per_hw;
for (int hw = thr_hw; hw < span_hw; hw += hw_per_group) {
const int h = hw / span_w;
const int w = hw % span_w;
const int ih = tg_oh + h;
const int iw = tg_ow + w;
const int in_s_offset = h * span_w * TGC + w * TGC;
if (ih >= 0 && ih < params.iS[0] && iw >= 0 && iw < params.iS[1]) {
const auto in_load =
in + ih * params.in_strides[1] + iw * params.in_strides[2] + tg_c;
MLX_MTL_PRAGMA_UNROLL
for (int cc = 0; cc < c_per_thr; ++cc) {
ins[in_s_offset + c_per_thr * thr_c + cc] =
in_load[c_per_thr * thr_c + cc];
}
} else {
MLX_MTL_PRAGMA_UNROLL
for (int cc = 0; cc < c_per_thr; ++cc) {
ins[in_s_offset + c_per_thr * thr_c + cc] = T(0);
}
}
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
wt += c * params.wt_strides[0];
const auto ins_ptr =
&ins[lid.z * str_h * span_w * TGC + lid.y * str_w * TGC + lid.x];
float o = 0.;
for (int h = 0; h < ker_h; ++h) {
for (int w = 0; w < ker_w; ++w) {
int wt_h = h;
int wt_w = w;
if (do_flip) {
wt_h = ker_h - h - 1;
wt_w = ker_w - w - 1;
}
auto inv = ins_ptr[h * span_w * TGC + w * TGC];
auto wtv = wt[wt_h * ker_w + wt_w];
o += inv * wtv;
}
}
threadgroup_barrier(mem_flags::mem_none);
out += n * params.out_strides[0] + oh * params.out_strides[1] +
ow * params.out_strides[2];
out[c] = static_cast<T>(o);
}
#define instantiate_depthconv2d(iname, itype) \
instantiate_kernel("depthwise_conv_2d_" #iname, depthwise_conv_2d, itype)
instantiate_depthconv2d(float32, float);
instantiate_depthconv2d(float16, half);
instantiate_depthconv2d(bfloat16, bfloat16_t);
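A quick sanity check of the threadgroup tile declared in depthwise_conv_2d above: with th = 4, tw = 8, tc = 8 and the supported extremes (kernel size up to 7, stride up to 2), the input span per tile is span_h = th*str_h + ker_h - 1 <= 14 and span_w <= 22, exactly what TGH = th*2 + 6 and TGW = tw*2 + 6 provide. In numbers:

constexpr int th = 4, tw = 8, tc = 8;
constexpr int TGH = th * 2 + 6;                // 14 rows: 4*2 + 7 - 1
constexpr int TGW = tw * 2 + 6;                // 22 cols: 8*2 + 7 - 1
constexpr int tile_elems = TGH * TGW * tc;     // 2464 elements
constexpr int tile_bytes_f32 = tile_elems * 4; // 9856 B for float
static_assert(tile_bytes_f32 <= 32 * 1024, "fits the 32 KB threadgroup limit");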
///////////////////////////////////////////////////////////////////////////////
/// Winograd kernels
///////////////////////////////////////////////////////////////////////////////

View File

@@ -483,4 +483,4 @@ template <
perform_fft(fft_idx, &p, m, n, buf);
read_writer.write_strided(stride, overall_n);
}
}

View File

@@ -341,7 +341,7 @@ struct GEMVTKernel {
MLX_MTL_PRAGMA_UNROLL
for (int tm = 0; tm < TM; tm++) {
auto vc = float(v_coeff[tm]);
auto vc = static_cast<AccT>(v_coeff[tm]);
for (int tn = 0; tn < TN; tn++) {
inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn];
}

View File

@@ -493,71 +493,11 @@ template <typename T, int N_READS = RMS_N_READS>
}
// clang-format off
#define instantiate_layer_norm_single_row(name, itype) \
template [[host_name("layer_norm" #name)]] [[kernel]] void \
layer_norm_single_row<itype>( \
const device itype* x, \
const device itype* w, \
const device itype* b, \
device itype* out, \
constant float& eps, \
constant uint& axis_size, \
constant uint& w_stride, \
constant uint& b_stride, \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]); \
template [[host_name("vjp_layer_norm" #name)]] [[kernel]] void \
vjp_layer_norm_single_row<itype>( \
const device itype* x, \
const device itype* w, \
const device itype* g, \
device itype* gx, \
device itype* gw, \
constant float& eps, \
constant uint& axis_size, \
constant uint& w_stride, \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
#define instantiate_layer_norm_looped(name, itype) \
template [[host_name("layer_norm_looped" #name)]] [[kernel]] void \
layer_norm_looped<itype>( \
const device itype* x, \
const device itype* w, \
const device itype* b, \
device itype* out, \
constant float& eps, \
constant uint& axis_size, \
constant uint& w_stride, \
constant uint& b_stride, \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint lsize [[threads_per_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]); \
template [[host_name("vjp_layer_norm_looped" #name)]] [[kernel]] void \
vjp_layer_norm_looped<itype>( \
const device itype* x, \
const device itype* w, \
const device itype* g, \
device itype* gx, \
device itype* gb, \
constant float& eps, \
constant uint& axis_size, \
constant uint& w_stride, \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint lsize [[threads_per_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
#define instantiate_layer_norm(name, itype) \
instantiate_layer_norm_single_row(name, itype) \
instantiate_layer_norm_looped(name, itype)
#define instantiate_layer_norm(name, itype) \
instantiate_kernel("layer_norm" #name, layer_norm_single_row, itype) \
instantiate_kernel("vjp_layer_norm" #name, vjp_layer_norm_single_row, itype) \
instantiate_kernel("layer_norm_looped" #name, layer_norm_looped, itype) \
instantiate_kernel("vjp_layer_norm_looped" #name, vjp_layer_norm_looped, itype)
instantiate_layer_norm(float32, float)
instantiate_layer_norm(float16, half)

View File

@@ -0,0 +1,143 @@
// Copyright © 2025 Apple Inc.
template <typename T, typename AccT = float, int N_READS = 4>
[[kernel]] void logsumexp(
const device T* in,
device T* out,
constant int& axis_size,
uint gid [[threadgroup_position_in_grid]],
uint _lid [[thread_position_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
int lid = _lid;
constexpr int SIMD_SIZE = 32;
threadgroup AccT local_max[SIMD_SIZE];
threadgroup AccT local_normalizer[SIMD_SIZE];
AccT ld[N_READS];
in += gid * size_t(axis_size) + lid * N_READS;
if (lid * N_READS + N_READS <= axis_size) {
for (int i = 0; i < N_READS; i++) {
ld[i] = AccT(in[i]);
}
} else {
for (int i = 0; i < N_READS; i++) {
ld[i] =
((lid * N_READS + i) < axis_size) ? AccT(in[i]) : Limits<AccT>::min;
}
}
if (simd_group_id == 0) {
local_max[simd_lane_id] = Limits<AccT>::min;
local_normalizer[simd_lane_id] = 0;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Get the max
AccT maxval = Limits<AccT>::finite_min;
for (int i = 0; i < N_READS; i++) {
maxval = (maxval < ld[i]) ? ld[i] : maxval;
}
maxval = simd_max(maxval);
if (simd_lane_id == 0) {
local_max[simd_group_id] = maxval;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (simd_group_id == 0) {
maxval = simd_max(local_max[simd_lane_id]);
if (simd_lane_id == 0) {
local_max[0] = maxval;
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
maxval = local_max[0];
// Compute exp(x_i - maxval) and store the partial sums in local_normalizer
AccT normalizer = 0;
for (int i = 0; i < N_READS; i++) {
normalizer += fast::exp(ld[i] - maxval);
}
normalizer = simd_sum(normalizer);
if (simd_lane_id == 0) {
local_normalizer[simd_group_id] = normalizer;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (simd_group_id == 0) {
normalizer = simd_sum(local_normalizer[simd_lane_id]);
if (simd_lane_id == 0) {
out[gid] = isinf(maxval) ? T(maxval) : T(log(normalizer) + maxval);
}
}
}
template <typename T, typename AccT = float, int N_READS = 4>
[[kernel]] void logsumexp_looped(
const device T* in,
device T* out,
constant int& axis_size,
uint gid [[threadgroup_position_in_grid]],
uint lid [[thread_position_in_threadgroup]],
uint lsize [[threads_per_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
in += gid * size_t(axis_size);
constexpr int SIMD_SIZE = 32;
threadgroup AccT local_max[SIMD_SIZE];
threadgroup AccT local_normalizer[SIMD_SIZE];
// Get the max and the normalizer in one go
AccT prevmax;
AccT maxval = Limits<AccT>::finite_min;
AccT normalizer = 0;
for (int r = 0; r < static_cast<int>(ceildiv(axis_size, N_READS * lsize));
r++) {
int offset = r * lsize * N_READS + lid * N_READS;
AccT vals[N_READS];
if (offset + N_READS <= axis_size) {
for (int i = 0; i < N_READS; i++) {
vals[i] = AccT(in[offset + i]);
}
} else {
for (int i = 0; i < N_READS; i++) {
vals[i] = (offset + i < axis_size) ? AccT(in[offset + i])
: Limits<AccT>::finite_min;
}
}
prevmax = maxval;
for (int i = 0; i < N_READS; i++) {
maxval = (maxval < vals[i]) ? vals[i] : maxval;
}
normalizer *= fast::exp(prevmax - maxval);
for (int i = 0; i < N_READS; i++) {
normalizer += fast::exp(vals[i] - maxval);
}
}
prevmax = maxval;
maxval = simd_max(maxval);
normalizer *= fast::exp(prevmax - maxval);
normalizer = simd_sum(normalizer);
prevmax = maxval;
if (simd_lane_id == 0) {
local_max[simd_group_id] = maxval;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
maxval = simd_max(local_max[simd_lane_id]);
normalizer *= fast::exp(prevmax - maxval);
if (simd_lane_id == 0) {
local_normalizer[simd_group_id] = normalizer;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (simd_group_id == 0) {
normalizer = simd_sum(local_normalizer[simd_lane_id]);
if (simd_lane_id == 0) {
out[gid] = isinf(maxval) ? T(maxval) : T(log(normalizer) + maxval);
}
}
}
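The looped kernel above folds the max pass and the sum pass into one sweep: whenever a new maximum appears, the running sum is rescaled by exp(old_max - new_max) instead of being recomputed. A scalar sketch of the same recurrence:

#include <algorithm>
#include <cmath>

// Single-pass logsumexp with online rescaling of the partial sum
// (assumes n > 0).
float streaming_logsumexp(const float* x, int n) {
  float maxval = -INFINITY;
  float normalizer = 0.f;
  for (int i = 0; i < n; ++i) {
    float prevmax = maxval;
    maxval = std::max(maxval, x[i]);
    normalizer = normalizer * std::exp(prevmax - maxval) // rescale old sum
               + std::exp(x[i] - maxval);
  }
  return std::log(normalizer) + maxval;
}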

View File

@@ -0,0 +1,18 @@
// Copyright © 2023-2024 Apple Inc.
#include <metal_common>
#include <metal_simdgroup>
using namespace metal;
// clang-format off
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/logsumexp.h"
#define instantiate_logsumexp(name, itype) \
instantiate_kernel("block_logsumexp_" #name, logsumexp, itype) \
instantiate_kernel("looped_logsumexp_" #name, logsumexp_looped, itype) \
instantiate_logsumexp(float32, float)
instantiate_logsumexp(float16, half)
instantiate_logsumexp(bfloat16, bfloat16_t) // clang-format on

View File

@@ -586,13 +586,13 @@ METAL_FUNC void qmv_quad_impl(
// Adjust positions
const int in_vec_size_w = in_vec_size / pack_factor;
const int in_vec_size_g = in_vec_size / group_size;
const int out_row = tid.x * quads_per_simd * results_per_quadgroup + quad_gid;
const int out_row = tid.y * quads_per_simd * results_per_quadgroup + quad_gid;
w += out_row * in_vec_size_w + quad_lid * packs_per_thread;
scales += out_row * in_vec_size_g + quad_lid / scale_step_per_thread;
biases += out_row * in_vec_size_g + quad_lid / scale_step_per_thread;
x += tid.y * in_vec_size + quad_lid * values_per_thread;
y += tid.y * out_vec_size + out_row;
x += tid.x * in_vec_size + quad_lid * values_per_thread;
y += tid.x * out_vec_size + out_row;
U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
@@ -684,6 +684,115 @@ METAL_FUNC void qmv_fast_impl(
}
}
template <uint32_t a = 89226354, uint32_t b = 64248484, uint32_t m = 996162400>
float inst3(uint16_t xi) {
uint32_t x = xi;
x = a * x + b;
x = (x & 0b10001111111111111000111111111111) ^ m;
auto xf = reinterpret_cast<thread float16_t*>(&x);
return xf[0] + xf[1];
}
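A host-side C++ mirror of inst3() above, useful for spot-checking the decode on the CPU (assumes a little-endian target and the Clang/GCC _Float16 extension standing in for Metal's float16_t):

#include <cstdint>
#include <cstring>

float inst3_host(uint16_t xi) {
  uint32_t x = xi;
  x = 89226354u * x + 64248484u;                              // LCG step
  x = (x & 0b10001111111111111000111111111111u) ^ 996162400u; // mask + flip
  _Float16 h[2];
  std::memcpy(h, &x, sizeof(h)); // reinterpret as two half-precision values
  return static_cast<float>(h[0]) + static_cast<float>(h[1]);
}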
template <typename T, int bits>
METAL_FUNC void qmv_trellis_impl(
const device uint32_t* w,
const device T* scales,
const device T* biases,
const device T* x,
device T* y,
const constant int& in_vec_size,
const constant int& out_vec_size,
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
constexpr int packs_per_thread = 2;
constexpr int num_simdgroups = 2;
constexpr int results_per_simdgroup = 4;
constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
constexpr int bytes_per_pack = power_of_2_bits ? 4 : 3;
constexpr int values_per_thread = pack_factor * packs_per_thread;
constexpr int block_size = values_per_thread * SIMD_SIZE;
constexpr int reads_per = 16 / bits;
constexpr int local_w_size =
results_per_simdgroup * values_per_thread / reads_per;
const device uint8_t* ws = (const device uint8_t*)w;
typedef float U;
thread U x_thread[values_per_thread];
thread uint16_t w_thread[local_w_size];
thread U result[results_per_simdgroup] = {0};
// Adjust positions
const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor;
const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) +
simd_gid * results_per_simdgroup;
ws += out_row * in_vec_size_w + simd_lid * packs_per_thread * bytes_per_pack;
x += tid.x * in_vec_size + simd_lid * values_per_thread;
y += tid.x * out_vec_size + out_row;
T scale = scales[0];
for (int k = 0; k < in_vec_size; k += block_size) {
#pragma clang loop unroll(full)
for (int i = 0; i < values_per_thread; i++) {
x_thread[i] = x[i];
}
#pragma clang loop unroll(full)
for (int row = 0; row < results_per_simdgroup; row++) {
#pragma clang loop unroll(full)
for (int i = 0; i < values_per_thread / reads_per; i++) {
auto wl = (const device uint16_t*)(ws + row * in_vec_size_w);
w_thread[row * values_per_thread / reads_per + i] = wl[i];
}
}
#pragma clang loop unroll(full)
for (int row = 0; row < results_per_simdgroup; row++) {
#pragma clang loop unroll(full)
for (int i = 0; i < values_per_thread / reads_per; i++) {
int index = row * values_per_thread / reads_per + i;
uint16_t w0 = w_thread[index];
uint16_t w1 = w_thread[(index + 1) % local_w_size];
uint16_t wx = w0 ^ w1;
uint16_t wx1 = wx ^ 1;
uint16_t wf = w0 ^ (1 << bits);
if (bits == 2) {
result[row] += x_thread[8 * i] * inst3(w0);
result[row] += x_thread[8 * i + 1] * inst3(wf ^ (wx1 & 0x3));
result[row] += x_thread[8 * i + 2] * inst3(w0 ^ (wx & 0xf));
result[row] += x_thread[8 * i + 3] * inst3(w0 ^ (wx1 & 0x3f));
result[row] += x_thread[8 * i + 4] * inst3(w0 ^ (wx & 0xff));
result[row] += x_thread[8 * i + 5] * inst3(w0 ^ (wx1 & 0x3ff));
result[row] += x_thread[8 * i + 6] * inst3(w0 ^ (wx & 0xfff));
result[row] += x_thread[8 * i + 7] * inst3(w0 ^ (wx1 & 0x3fff));
} else if (bits == 4) {
result[row] += x_thread[4 * i] * inst3(w0);
result[row] += x_thread[4 * i + 1] * inst3(wf ^ (wx1 & 0xf));
result[row] += x_thread[4 * i + 2] * inst3(w0 ^ (wx & 0xff));
result[row] += x_thread[4 * i + 3] * inst3(w0 ^ (wx1 & 0xfff));
}
}
}
ws += block_size * bytes_per_pack / pack_factor;
x += block_size;
}
for (int row = 0; row < results_per_simdgroup; row++) {
result[row] = simd_sum(result[row]);
if (simd_lid == 0) {
y[row] = static_cast<T>(scale * result[row]);
}
}
}
template <typename T, int group_size, int bits>
METAL_FUNC void qmv_impl(
const device uint32_t* w,
@@ -1302,7 +1411,13 @@ METAL_FUNC void adjust_matrix_offsets(
y += tid.z * output_stride;
}
template <typename T, int group_size, int bits, int D, bool batched>
template <
typename T,
int group_size,
int bits,
int D,
bool batched,
bool trellis = false>
[[kernel]] void qmv_quad(
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
@@ -1354,7 +1469,12 @@ template <typename T, int group_size, int bits, int D, bool batched>
quad_lid);
}
template <typename T, int group_size, int bits, bool batched>
template <
typename T,
int group_size,
int bits,
bool batched,
bool trellis = false>
[[kernel]] void qmv_fast(
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
@@ -1393,20 +1513,39 @@ template <typename T, int group_size, int bits, bool batched>
b_strides,
tid);
}
qmv_fast_impl<T, group_size, bits>(
w,
scales,
biases,
x,
y,
in_vec_size,
out_vec_size,
tid,
simd_gid,
simd_lid);
if (trellis) {
qmv_trellis_impl<T, bits>(
w,
scales,
biases,
x,
y,
in_vec_size,
out_vec_size,
tid,
simd_gid,
simd_lid);
} else {
qmv_fast_impl<T, group_size, bits>(
w,
scales,
biases,
x,
y,
in_vec_size,
out_vec_size,
tid,
simd_gid,
simd_lid);
}
}
template <typename T, const int group_size, const int bits, bool batched>
template <
typename T,
const int group_size,
const int bits,
bool batched,
bool trellis = false>
[[kernel]] void qmv(
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
@@ -1458,7 +1597,12 @@ template <typename T, const int group_size, const int bits, bool batched>
simd_lid);
}
template <typename T, const int group_size, const int bits, bool batched>
template <
typename T,
const int group_size,
const int bits,
bool batched,
bool trellis = false>
[[kernel]] void qvm(
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
@@ -1572,6 +1716,7 @@ template <
const int bits,
const bool aligned_N,
const bool batched,
bool trellis = false,
const int BM = 32,
const int BK = 32,
const int BN = 32>
@@ -1630,6 +1775,7 @@ template <
const int group_size,
const int bits,
const bool batched,
bool trellis = false,
const int BM = 32,
const int BK = 32,
const int BN = 32>
@@ -1685,7 +1831,7 @@ template <
w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
}
template <typename T, int group_size, int bits>
template <typename T, int group_size, int bits, bool trellis = false>
[[kernel]] void bs_qmv_fast(
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
@@ -1734,20 +1880,34 @@ template <typename T, int group_size, int bits>
s_strides,
b_strides,
tid);
qmv_fast_impl<T, group_size, bits>(
w,
scales,
biases,
x,
y,
in_vec_size,
out_vec_size,
tid,
simd_gid,
simd_lid);
if (trellis) {
qmv_trellis_impl<T, bits>(
w,
scales,
biases,
x,
y,
in_vec_size,
out_vec_size,
tid,
simd_gid,
simd_lid);
} else {
qmv_fast_impl<T, group_size, bits>(
w,
scales,
biases,
x,
y,
in_vec_size,
out_vec_size,
tid,
simd_gid,
simd_lid);
}
}
template <typename T, int group_size, int bits>
template <typename T, int group_size, int bits, bool trellis = false>
[[kernel]] void bs_qmv(
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
@@ -1809,7 +1969,7 @@ template <typename T, int group_size, int bits>
simd_lid);
}
template <typename T, int group_size, int bits>
template <typename T, int group_size, int bits, bool trellis = false>
[[kernel]] void bs_qvm(
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
@@ -1876,6 +2036,7 @@ template <
const int group_size,
const int bits,
const bool aligned_N,
bool trellis = false,
const int BM = 32,
const int BK = 32,
const int BN = 32>
@@ -1943,6 +2104,7 @@ template <
typename T,
const int group_size,
const int bits,
bool trellis = false,
const int BM = 32,
const int BK = 32,
const int BN = 32>
@@ -2157,3 +2319,211 @@ template <typename T, const int group_size, const int bits>
}
}
}
template <
typename T,
const bool use_overlap,
const int bits = 2,
const int timesteps = 128>
[[kernel]] void trellis_viterbi(
const device T* w [[buffer(0)]],
device float16_t* score [[buffer(1)]],
device uint8_t* pointers [[buffer(2)]],
const device uint16_t* overlap [[buffer(3)]],
uint3 tid [[thread_position_in_grid]]) {
constexpr uint16_t L = 16;
constexpr uint L2 = 1 << L;
uint16_t idx = tid.y * 16;
threadgroup float16_t swap_V[16384];
thread float16_t min_V[16] = {0};
for (uint16_t t = 0; t < timesteps; t++) {
uint16_t tt = t % 8 == 0 ? L / bits : t % 8;
uint16_t shift = ((tt - 1) % (L / bits)) * bits;
uint16_t flip = (t == 0 || (t > 1 && t % 8 == 1)) ? (1 << bits) + 1 : t % 2;
uint16_t s000 = 1 << (shift - 6);
uint16_t s0 = 1 << (shift - 2);
uint16_t s1 = 1 << (shift);
uint16_t s2 = 1 << (shift + 2);
uint16_t s4 = 1 << (shift + 4);
if (t > 1) {
uint16_t i = 0;
uint16_t loff = 1 << (metal::clamp((shift + 14) % 16, 2, 12));
uint16_t hoff = shift > 4 ? 4 : shift == 4 ? 16 : 1;
uint16_t ind = idx;
if (shift == 0) {
ind >>= 2;
} else if (shift == 14) {
ind = (ind & 0xfff) + (ind >> 12);
} else if (shift == 2) {
} else if (shift == 4) {
ind = ((ind >> 4) & 0x3) + (ind & ~0x3f);
} else if (shift == 6) {
ind = ((ind / s0) % 4) * s1 + ((ind / s1) % 4) + (ind / s2) * s2;
} else {
ind = ((ind / 16) % s000) * 16 + ((ind / s0) % 4) * s1 +
((ind / s1) % 4) + (ind / s2) * s2;
}
for (uint16_t high = 0; high < 4; high++) {
uint16_t sub_ind = ind;
for (uint16_t low = 0; low < 4; low++) {
swap_V[sub_ind] = min_V[i];
i++;
sub_ind += loff;
}
ind += hoff;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
for (uint16_t i = 0; i < 16; i++) {
min_V[i] = swap_V[idx + i];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
}
uint16_t rolled_t = use_overlap ? t : (t + 64) % 128;
T w_t = w[tid.x * timesteps + rolled_t];
for (uint16_t i = 0; i < 4; i++) {
thread float16_t min_val[4] = {INFINITY, INFINITY, INFINITY, INFINITY};
thread uint16_t min_idx[4] = {0};
uint16_t ii = idx * 4 + i * 16;
uint16_t big_idx = ii;
if (shift > 0 && shift < 14) {
big_idx = ((ii / s2) % 4) + (ii / s4 * s4);
if (shift > 2) {
big_idx += ((ii / 16) % s0) * 4;
}
} else if (shift == 14 && t > 0) {
big_idx >>= 2;
}
uint16_t loff = t == 0 ? 4 : s1;
uint16_t hoff = (t == 0 || shift == 14) ? 1 : s2;
for (uint16_t high = 0; high < 4; high++) {
uint16_t sub_ind = big_idx;
for (uint16_t low = 0; low < 4; low++) {
float mse = inst3(sub_ind ^ flip) - w_t;
mse *= mse;
float16_t new_val = min_V[i * 4 + high] + mse;
if (new_val < min_val[low]) {
min_val[low] = new_val;
min_idx[low] = high;
}
sub_ind += loff;
}
big_idx += hoff;
}
for (uint16_t j = 0; j < 4; j++) {
min_V[i * 4 + j] = min_val[j];
pointers[tid.x * L2 / 4 * timesteps + t * L2 / 4 + idx + i * 4 + j] =
min_idx[j];
}
}
if (t == 0 && use_overlap) {
uint16_t over = overlap[tid.x * 128 + 64];
over = over & ((1 << 14) - 1);
for (uint16_t i = 0; i < 16; i++) {
uint16_t rs = (over >> 2) ^ 1;
uint16_t ls = (idx + i) & ((1 << 12) - 1);
min_V[i] = rs == ls ? min_V[i] : INFINITY;
}
}
}
if (use_overlap) {
uint16_t over = overlap[tid.x * 128 + 64];
over = over & ((1 << 14) - 1);
uint16_t node =
(over % 4) * 4096 + ((over / 4) % 1024) * 4 + (over / 4096) % 4;
for (uint16_t i = 0; i < 16; i++) {
min_V[i] = (idx + i) == node ? min_V[i] : INFINITY;
}
}
for (uint16_t i = 0; i < 16; i++) {
score[tid.x * L2 / 4 + idx + i] = min_V[i];
}
}
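The kernel above is dense because the 2^16-state trellis is held as 16 registers per thread and re-permuted through threadgroup memory between timesteps. As a reference for what it computes, here is a minimal CPU sketch of the same forward pass over a generic trellis, assuming a hypothetical decode(state) codebook lookup (the role inst3 plays above) and an explicit predecessor list per state; the kernel's bit gymnastics are an in-register encoding of these predecessor sets.

#include <cstdint>
#include <limits>
#include <vector>

// Reference Viterbi forward pass (CPU sketch, not the shipped kernel): at
// each timestep, every state keeps the lowest running MSE over its
// predecessors and records which predecessor won (the backpointer).
std::vector<std::vector<uint32_t>> viterbi_forward(
    const std::vector<float>& w, // one target value per timestep
    float (*decode)(uint32_t), // hypothetical codebook lookup (cf. inst3)
    const std::vector<std::vector<uint32_t>>& preds, // predecessors per state
    std::vector<float>& score) { // in/out: per-state running cost
  const size_t S = score.size();
  std::vector<std::vector<uint32_t>> ptr(w.size(), std::vector<uint32_t>(S));
  for (size_t t = 0; t < w.size(); t++) {
    std::vector<float> next(S, std::numeric_limits<float>::infinity());
    for (uint32_t s = 0; s < S; s++) {
      float e = decode(s) - w[t];
      float mse = e * e;
      for (uint32_t p : preds[s]) { // 4 candidates per state in the kernel
        if (score[p] + mse < next[s]) {
          next[s] = score[p] + mse;
          ptr[t][s] = p;
        }
      }
    }
    score.swap(next);
  }
  return ptr; // consumed by the backtracking pass
}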
uint16_t remove_bits(uint16_t i, uint16_t shift) {
uint16_t lower = i & ((1 << shift) - 1);
uint16_t upper = i & ~((1 << (shift + 2)) - 1);
return lower + (upper >> 2);
}
uint16_t swap_bits(uint16_t i, uint16_t shift) {
uint16_t diff = ((i >> shift) ^ i) & 0x3;
i = i ^ diff;
i ^= diff << shift;
return i;
}
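Both helpers are plain integer code (nothing Metal-specific), so their behavior is easy to pin down with a host-side check: remove_bits deletes the 2-bit field at position shift, and swap_bits exchanges the low 2-bit field with the 2-bit field at shift. A worked example, compiling the two helpers above as ordinary C++:

#include <cassert>
#include <cstdint>

// remove_bits(0b11011010, 2): drop bits [3:2] (= 01) -> 0b110110
// swap_bits(0b0110, 2): fields 01 (at shift) and 10 (low) trade places -> 0b1001
int main() {
  assert(remove_bits(0b11011010, 2) == 0b110110);
  assert(swap_bits(0b0110, 2) == 0b1001);
  return 0;
}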
template <const bool use_overlap, const int bits = 2, const int timesteps = 128>
[[kernel]] void trellis_backtrack(
const device uint32_t* start [[buffer(0)]],
const device uint8_t* pointers [[buffer(1)]],
device uint16_t* out [[buffer(2)]],
const device uint16_t* overlap [[buffer(3)]],
uint3 tid [[thread_position_in_grid]]) {
constexpr uint16_t L = 16;
uint16_t node = start[tid.x];
uint16_t dir =
pointers[tid.x * timesteps * 16384 + (timesteps - 1) * 16384 + node];
node = (node % 4) * 4096 + ((node / 4) % 1024) * 4 + (node / 4096) % 4;
node ^= 1;
node += dir * 16384;
out[tid.x * timesteps + timesteps - 1] = node;
for (int t = timesteps - 2; t >= 0; t--) {
uint16_t shift = (t % (L / bits)) * bits;
uint16_t mask = ((1 << L) - 1) ^ (((1 << bits) - 1) << shift);
uint16_t flip = t % (L / bits) == 0 ? 1 << bits : 1;
uint16_t i = (node & mask) ^ flip;
if (shift > 0) {
i = remove_bits(i, shift);
}
if (t == 0) {
i >>= 2;
}
if (t % 2 == 1 || t == 0) {
i ^= 1;
}
shift = shift == 0 ? L : shift;
if (t > 0) {
i = swap_bits(i, shift - 2);
}
shift = shift == L ? 0 : shift;
uint16_t last_p = pointers[tid.x * timesteps * 16384 + t * 16384 + i];
if ((t % 8 == 1 && t > 1) || t == 0) {
last_p ^= 1;
}
node = ((node & mask) ^ flip) | (last_p << shift);
if (t == 0 && use_overlap) {
uint16_t over = overlap[tid.x * 128 + 64];
over = over & ((1 << 14) - 1);
node = (node & 0xfffc) + (over & 0x3);
}
out[tid.x * timesteps + t] = node;
}
}
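And the matching CPU backtrack for the sketch above: walk the backpointers from the best final state to recover the path. The Metal kernel does the same walk but must also undo the per-timestep bit permutations (the mask/flip/swap_bits arithmetic) that the forward pass applied to state indices.

#include <cstdint>
#include <vector>

// Reference backtrack (CPU sketch): path[t] is the chosen state after step t.
std::vector<uint32_t> viterbi_backtrack_ref(
    uint32_t best_final_state,
    const std::vector<std::vector<uint32_t>>& ptr) {
  std::vector<uint32_t> path(ptr.size());
  uint32_t node = best_final_state;
  for (int t = int(ptr.size()) - 1; t >= 0; t--) {
    path[t] = node;
    node = ptr[t][node]; // predecessor recorded by the forward pass
  }
  return path;
}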

View File

@@ -120,4 +120,34 @@
instantiate_quantized_groups(6) \
instantiate_quantized_groups(8)
instantiate_kernel(
"trellis_viterbi_float16_t_overlap_0",
trellis_viterbi,
float16_t,
false)
instantiate_kernel(
"trellis_viterbi_float16_t_overlap_1",
trellis_viterbi,
float16_t,
true)
instantiate_kernel(
"trellis_backtrack_overlap_0",
trellis_backtrack,
false)
instantiate_kernel(
"trellis_backtrack_overlap_1",
trellis_backtrack,
true)
instantiate_kernel(
"qmv_fast_float16_t_gs_64_b_2_batch_0_mode_1",
qmv_fast,
float16_t,
64,
2,
false,
true)
instantiate_quantized_all() // clang-format on

View File

@@ -380,69 +380,11 @@ template <typename T, int N_READS = RMS_N_READS>
}
// clang-format off
#define instantiate_rms_single_row(name, itype) \
template [[host_name("rms" #name)]] [[kernel]] void \
rms_single_row<itype>( \
const device itype* x, \
const device itype* w, \
device itype* out, \
constant float& eps, \
constant uint& axis_size, \
constant uint& w_stride, \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]); \
\
template [[host_name("vjp_rms" #name)]] [[kernel]] void \
vjp_rms_single_row<itype>( \
const device itype* x, \
const device itype* w, \
const device itype* g, \
device itype* gx, \
device itype* gw, \
constant float& eps, \
constant uint& axis_size, \
constant uint& w_stride, \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
#define instantiate_rms_looped(name, itype) \
template [[host_name("rms_looped" #name)]] [[kernel]] void \
rms_looped<itype>( \
const device itype* x, \
const device itype* w, \
device itype* out, \
constant float& eps, \
constant uint& axis_size, \
constant uint& w_stride, \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint lsize [[threads_per_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]); \
\
template [[host_name("vjp_rms_looped" #name)]] [[kernel]] void \
vjp_rms_looped<itype>( \
const device itype* x, \
const device itype* w, \
const device itype* g, \
device itype* gx, \
device itype* gw, \
constant float& eps, \
constant uint& axis_size, \
constant uint& w_stride, \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint lsize [[threads_per_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
#define instantiate_rms(name, itype) \
instantiate_rms_single_row(name, itype) \
instantiate_rms_looped(name, itype)
#define instantiate_rms(name, itype) \
instantiate_kernel("rms" #name, rms_single_row, itype) \
instantiate_kernel("vjp_rms" #name, vjp_rms_single_row, itype) \
instantiate_kernel("rms_looped" #name, rms_looped, itype) \
instantiate_kernel("vjp_rms_looped" #name, vjp_rms_looped, itype)
instantiate_rms(float32, float)
instantiate_rms(float16, half)
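This file, and the softmax and scan files below, collapse their hand-written template instantiations into the shared instantiate_kernel helper. For readers following along, a plausible definition of that macro (an assumption about the kernels' shared utils header, which is not shown in this diff) is:

// Plausible sketch: bind a host-visible name to a single template
// instantiation without restating the kernel's full parameter list.
#define instantiate_kernel(name, func, ...) \
  template [[host_name(name)]] [[kernel]]   \
  decltype(func<__VA_ARGS__>) func<__VA_ARGS__>;

The decltype trick is what lets roughly sixty lines of repeated signatures shrink to the four instantiate_kernel calls above.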

View File

@@ -1,11 +1,11 @@
#include <metal_stdlib>
#include "mlx/backend/metal/kernels/sdpa_vector.h"
// clang-format off
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/sdpa_vector.h"
using namespace metal;
// clang-format off
// SDPA vector instantiations
#define instantiate_sdpa_vector_aggregation(type, value_dim) \
instantiate_kernel( \
@@ -32,9 +32,11 @@ using namespace metal;
instantiate_sdpa_vector(type, 64, 64) \
instantiate_sdpa_vector(type, 96, 96) \
instantiate_sdpa_vector(type, 128, 128) \
instantiate_sdpa_vector(type, 256, 256) \
instantiate_sdpa_vector_aggregation(type, 64) \
instantiate_sdpa_vector_aggregation(type, 96) \
instantiate_sdpa_vector_aggregation(type, 128)
instantiate_sdpa_vector_aggregation(type, 128) \
instantiate_sdpa_vector_aggregation(type, 256)
instantiate_sdpa_vector_heads(float)
instantiate_sdpa_vector_heads(bfloat16_t)

View File

@@ -2,6 +2,8 @@
#pragma once
#include "mlx/backend/metal/kernels/binary_ops.h"
#define DEFINE_SIMD_SCAN() \
template <typename T, metal::enable_if_t<sizeof(T) < 8, bool> = true> \
T simd_scan(T val) { \
@@ -139,6 +141,29 @@ struct CumMin {
}
};
template <typename U>
struct CumLogaddexp {
static constexpr constant U init = Limits<U>::min;
template <typename T>
U operator()(U a, T b) {
return LogAddExp{}(a, static_cast<U>(b));
}
U simd_scan(U x) {
for (int i = 1; i <= 16; i *= 2) {
U other = simd_shuffle_and_fill_up(x, init, i);
x = LogAddExp{}(x, other);
}
return x;
}
U simd_exclusive_scan(U x) {
x = simd_scan(x);
return simd_shuffle_and_fill_up(x, init, 1);
}
};
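CumLogaddexp composes the existing LogAddExp binary op into an inclusive simd scan; shuffling up by 1, 2, 4, 8, 16 covers all 32 lanes of a simdgroup, with Limits<U>::min (negative infinity for floats) as the identity that fills vacated lanes. A minimal CPU sketch of the same operation in its numerically stable form:

#include <cmath>
#include <vector>

// Stable log(exp(a) + exp(b)); -inf is the identity, matching init above.
float log_add_exp(float a, float b) {
  if (std::isinf(a) && a < 0) return b;
  if (std::isinf(b) && b < 0) return a;
  float m = std::fmax(a, b);
  return m + std::log1p(std::exp(-std::fabs(a - b)));
}

// Inclusive cumulative logaddexp: the sequential form of the simd scan.
std::vector<float> cum_logaddexp(const std::vector<float>& x) {
  std::vector<float> out(x.size());
  float acc = -INFINITY;
  for (size_t i = 0; i < x.size(); i++) {
    acc = log_add_exp(acc, x[i]);
    out[i] = acc;
  }
  return out;
}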
template <typename T, typename U, int N_READS, bool reverse>
inline void load_unsafe(U values[N_READS], const device T* input) {
if (reverse) {

View File

@@ -101,4 +101,7 @@ instantiate_scan_helper(min_int64_int64, int64_t, int64_t, CumMi
instantiate_scan_helper(min_float16_float16, half, half, CumMin, 4)
instantiate_scan_helper(min_float32_float32, float, float, CumMin, 4)
instantiate_scan_helper(min_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumMin, 4)
instantiate_scan_helper(min_complex64_complex64, complex64_t, complex64_t, CumMin, 2) // clang-format on
instantiate_scan_helper(min_complex64_complex64, complex64_t, complex64_t, CumMin, 2)
instantiate_scan_helper(logaddexp_float16_float16, half, half, CumLogaddexp, 4)
instantiate_scan_helper(logaddexp_float32_float32, float, float, CumLogaddexp, 4)
instantiate_scan_helper(logaddexp_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumLogaddexp, 4) // clang-format on

View File

@@ -6,6 +6,9 @@ using namespace metal;
constant bool has_mask [[function_constant(20)]];
constant bool query_transposed [[function_constant(21)]];
constant bool do_causal [[function_constant(22)]];
constant bool bool_mask [[function_constant(23)]];
constant bool float_mask [[function_constant(24)]];
template <typename T, int D, int V = D>
[[kernel]] void sdpa_vector(
@@ -13,17 +16,21 @@ template <typename T, int D, int V = D>
const device T* keys [[buffer(1)]],
const device T* values [[buffer(2)]],
device T* out [[buffer(3)]],
const constant int& gqa_factor,
const constant int& N,
const constant size_t& k_head_stride,
const constant size_t& k_seq_stride,
const constant size_t& v_head_stride,
const constant size_t& v_seq_stride,
const constant float& scale,
const device bool* mask [[function_constant(has_mask)]],
const constant int& mask_kv_seq_stride [[function_constant(has_mask)]],
const constant int& mask_q_seq_stride [[function_constant(has_mask)]],
const constant int& mask_head_stride [[function_constant(has_mask)]],
const constant int& gqa_factor [[buffer(4)]],
const constant int& N [[buffer(5)]],
const constant size_t& k_head_stride [[buffer(6)]],
const constant size_t& k_seq_stride [[buffer(7)]],
const constant size_t& v_head_stride [[buffer(8)]],
const constant size_t& v_seq_stride [[buffer(9)]],
const constant float& scale [[buffer(10)]],
const device bool* bmask [[buffer(11), function_constant(bool_mask)]],
const device T* fmask [[buffer(12), function_constant(float_mask)]],
const constant int& mask_kv_seq_stride
[[buffer(13), function_constant(has_mask)]],
const constant int& mask_q_seq_stride
[[buffer(14), function_constant(has_mask)]],
const constant int& mask_head_stride
[[buffer(15), function_constant(has_mask)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 tpg [[threadgroups_per_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
@@ -57,8 +64,12 @@ template <typename T, int D, int V = D>
simd_lid * qk_per_thread;
values += kv_head_idx * v_head_stride + simd_gid * v_seq_stride +
simd_lid * v_per_thread;
if (has_mask) {
mask += head_idx * mask_head_stride + simd_gid * mask_kv_seq_stride +
if (bool_mask) {
bmask += head_idx * mask_head_stride + simd_gid * mask_kv_seq_stride +
q_seq_idx * mask_q_seq_stride;
}
if (float_mask) {
fmask += head_idx * mask_head_stride + simd_gid * mask_kv_seq_stride +
q_seq_idx * mask_q_seq_stride;
}
@@ -77,7 +88,13 @@ template <typename T, int D, int V = D>
// For each key
for (int i = simd_gid; i < N; i += BN) {
if (!has_mask || mask[0]) {
bool use_key = true;
if (do_causal) {
use_key = i <= (N - int(tpg.y) + int(q_seq_idx));
} else if (bool_mask) {
use_key = bmask[0];
}
if (use_key) {
// Read the key
for (int j = 0; j < qk_per_thread; j++) {
k[j] = keys[j];
@@ -89,6 +106,9 @@ template <typename T, int D, int V = D>
score += q[j] * k[j];
}
score = simd_sum(score);
if (float_mask) {
score += max(Limits<U>::finite_min, static_cast<U>(fmask[0]));
}
// Update the accumulators
U new_max = max(max_score, score);
@@ -107,8 +127,11 @@ template <typename T, int D, int V = D>
// Move the pointers to the next kv
keys += inner_k_stride;
values += inner_v_stride;
if (has_mask) {
mask += BN * mask_kv_seq_stride;
if (bool_mask) {
bmask += BN * mask_kv_seq_stride;
}
if (float_mask) {
fmask += BN * mask_kv_seq_stride;
}
}
@@ -149,17 +172,21 @@ template <typename T, int D, int V = D>
device float* out [[buffer(3)]],
device float* sums [[buffer(4)]],
device float* maxs [[buffer(5)]],
const constant int& gqa_factor,
const constant int& N,
const constant size_t& k_head_stride,
const constant size_t& k_seq_stride,
const constant size_t& v_head_stride,
const constant size_t& v_seq_stride,
const constant float& scale,
const device bool* mask [[function_constant(has_mask)]],
const constant int& mask_kv_seq_stride [[function_constant(has_mask)]],
const constant int& mask_q_seq_stride [[function_constant(has_mask)]],
const constant int& mask_head_stride [[function_constant(has_mask)]],
const constant int& gqa_factor [[buffer(6)]],
const constant int& N [[buffer(7)]],
const constant size_t& k_head_stride [[buffer(8)]],
const constant size_t& k_seq_stride [[buffer(9)]],
const constant size_t& v_head_stride [[buffer(10)]],
const constant size_t& v_seq_stride [[buffer(11)]],
const constant float& scale [[buffer(12)]],
const device bool* bmask [[buffer(13), function_constant(bool_mask)]],
const device T* fmask [[buffer(14), function_constant(float_mask)]],
const constant int& mask_kv_seq_stride
[[buffer(15), function_constant(has_mask)]],
const constant int& mask_q_seq_stride
[[buffer(16), function_constant(has_mask)]],
const constant int& mask_head_stride
[[buffer(17), function_constant(has_mask)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 tpg [[threadgroups_per_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
@@ -197,8 +224,13 @@ template <typename T, int D, int V = D>
values += kv_head_idx * v_head_stride +
(block_idx * BN + simd_gid) * v_seq_stride + simd_lid * v_per_thread;
out += o_offset * blocks * V + block_idx * V + simd_lid * v_per_thread;
if (has_mask) {
mask += head_idx * mask_head_stride +
if (bool_mask) {
bmask += head_idx * mask_head_stride +
(block_idx * BN + simd_gid) * mask_kv_seq_stride +
q_seq_idx * mask_q_seq_stride;
}
if (float_mask) {
fmask += head_idx * mask_head_stride +
(block_idx * BN + simd_gid) * mask_kv_seq_stride +
q_seq_idx * mask_q_seq_stride;
}
@@ -218,7 +250,13 @@ template <typename T, int D, int V = D>
// For each key
for (int i = block_idx * BN + simd_gid; i < N; i += blocks * BN) {
if (!has_mask || mask[0]) {
bool use_key = true;
if (do_causal) {
use_key = i <= (N - int(tpg.y) + int(q_seq_idx));
} else if (bool_mask) {
use_key = bmask[0];
}
if (use_key) {
// Read the key
for (int i = 0; i < qk_per_thread; i++) {
k[i] = keys[i];
@@ -230,6 +268,9 @@ template <typename T, int D, int V = D>
score += q[i] * k[i];
}
score = simd_sum(score);
if (float_mask) {
score += fmask[0];
}
// Update the accumulators
U new_max = max(max_score, score);
@@ -248,8 +289,11 @@ template <typename T, int D, int V = D>
// Move the pointers to the next kv
keys += blocks * inner_k_stride;
values += blocks * inner_v_stride;
if (has_mask) {
mask += BN * blocks * mask_kv_seq_stride;
if (bool_mask) {
bmask += BN * blocks * mask_kv_seq_stride;
}
if (float_mask) {
fmask += BN * blocks * mask_kv_seq_stride;
}
}
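Summarizing the mask plumbing added in this file: has_mask is now split into bool_mask and float_mask function constants with separate bmask/fmask buffers. A boolean mask (or the causal test) decides whether a key participates at all, while a float mask is added to the logit (in the single-pass kernel, clamped to finite_min so a -inf mask entry cannot poison the running maximum). In scalar form, an illustrative sketch, not the kernel itself:

#include <algorithm>
#include <cmath>

// Scalar model of the per-key mask logic. finite_min stands in for
// Limits<U>::finite_min: a huge negative but finite value that drives the
// key's softmax weight to zero without creating inf - inf = NaN later.
float masked_score(
    float score, // q . k, scaled and simd-summed
    bool do_causal, bool in_causal_window, // i <= N - tpg.y + q_seq_idx
    bool bool_mask, bool bmask_val,
    bool float_mask, float fmask_val,
    float finite_min) {
  bool use_key = true;
  if (do_causal) {
    use_key = in_causal_window;
  } else if (bool_mask) {
    use_key = bmask_val;
  }
  if (!use_key) {
    return finite_min; // the kernel simply skips such keys
  }
  if (float_mask) {
    score += std::max(finite_min, fmask_val); // clamp -inf mask entries
  }
  return score;
}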

View File

@@ -9,47 +9,13 @@ using namespace metal;
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/softmax.h"
#define instantiate_softmax(name, itype) \
template [[host_name("block_softmax_" #name)]] [[kernel]] void \
softmax_single_row<itype>( \
const device itype* in, \
device itype* out, \
constant int& axis_size, \
uint gid [[thread_position_in_grid]], \
uint _lid [[thread_position_in_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]); \
template [[host_name("looped_softmax_" #name)]] [[kernel]] void \
softmax_looped<itype>( \
const device itype* in, \
device itype* out, \
constant int& axis_size, \
uint gid [[threadgroup_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint lsize [[threads_per_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
#define instantiate_softmax(name, itype) \
instantiate_kernel("block_softmax_" #name, softmax_single_row, itype) \
instantiate_kernel("looped_softmax_" #name, softmax_looped, itype)
#define instantiate_softmax_precise(name, itype) \
template [[host_name("block_softmax_precise_" #name)]] [[kernel]] void \
softmax_single_row<itype, float>( \
const device itype* in, \
device itype* out, \
constant int& axis_size, \
uint gid [[thread_position_in_grid]], \
uint _lid [[thread_position_in_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]); \
template [[host_name("looped_softmax_precise_" #name)]] [[kernel]] void \
softmax_looped<itype, float>( \
const device itype* in, \
device itype* out, \
constant int& axis_size, \
uint gid [[threadgroup_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint lsize [[threads_per_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
#define instantiate_softmax_precise(name, itype) \
instantiate_kernel("block_softmax_precise_" #name, softmax_single_row, itype, float) \
instantiate_kernel("looped_softmax_precise_" #name, softmax_looped, itype, float)
instantiate_softmax(float32, float)
instantiate_softmax(float16, half)

View File

@@ -229,7 +229,7 @@ template <
// Init to -Inf
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kRowsPT; ++i) {
max_score[i] = Limits<AccumType>::min;
max_score[i] = Limits<AccumType>::finite_min;
}
int kb_lim = params->NK;
@@ -237,6 +237,7 @@ template <
if (do_causal) {
int q_max = (tid.x + 1) * BQ + params->qL_off;
kb_lim = (q_max + BK - 1) / BK;
kb_lim = min(params->NK, kb_lim);
}
// Loop over KV seq length
@@ -272,7 +273,7 @@ template <
if (!align_K && kb == (params->NK_aligned)) {
using stile_t = decltype(Stile);
using selem_t = typename stile_t::elem_type;
constexpr auto neg_inf = -metal::numeric_limits<selem_t>::infinity();
constexpr auto neg_inf = Limits<selem_t>::finite_min;
STEEL_PRAGMA_UNROLL
for (short i = 0; i < stile_t::kTileRows; i++) {
@@ -290,10 +291,10 @@ template <
}
// Mask out if causal
if (do_causal && kb >= (kb_lim - (BQ + BK - 1) / BK - int(!align_K))) {
if (do_causal && kb >= (kb_lim - ((BQ + BK - 1) / BK) - int(!align_K))) {
using stile_t = decltype(Stile);
using selem_t = typename stile_t::elem_type;
constexpr auto neg_inf = -metal::numeric_limits<selem_t>::infinity();
constexpr auto neg_inf = Limits<selem_t>::finite_min;
STEEL_PRAGMA_UNROLL
for (short i = 0; i < stile_t::kTileRows; i++) {
@@ -316,7 +317,7 @@ template <
if (has_mask) {
using stile_t = decltype(Stile);
using selem_t = typename stile_t::elem_type;
constexpr auto neg_inf = -metal::numeric_limits<selem_t>::infinity();
constexpr auto neg_inf = Limits<selem_t>::finite_min;
constexpr bool is_bool = is_same_v<MaskType, bool>;
using melem_t = typename metal::conditional_t<is_bool, bool, selem_t>;
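The repeated -infinity to Limits::finite_min swap in this kernel (and the matching change in the vector kernels above) guards the online softmax: if the running maximum is -inf and a fully masked score is also -inf, the rescale computes exp(-inf - (-inf)) = exp(NaN). A short host-side demonstration of the failure and the fix:

#include <cmath>
#include <cstdio>
#include <limits>

int main() {
  float masked = -INFINITY; // a fully masked logit
  // Old init (-inf running max): -inf - (-inf) is NaN, so exp() is NaN.
  printf("%f\n", std::exp(masked - (-INFINITY)));
  // A finite_min-style init keeps the difference -inf, so exp() is 0.
  printf("%f\n", std::exp(masked - std::numeric_limits<float>::lowest()));
  return 0;
}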

View File

@@ -73,6 +73,9 @@ instantiate_unary_all_same(Conjugate, complex64, complex64_t)
instantiate_unary_all_same(Cos, complex64, complex64_t)
instantiate_unary_all_same(Cosh, complex64, complex64_t)
instantiate_unary_all_same(Exp, complex64, complex64_t)
instantiate_unary_all_same(Log, complex64, complex64_t)
instantiate_unary_all_same(Log2, complex64, complex64_t)
instantiate_unary_all_same(Log10, complex64, complex64_t)
instantiate_unary_all_same(Negative, complex64, complex64_t)
instantiate_unary_all_same(Sign, complex64, complex64_t)
instantiate_unary_all_same(Sin, complex64, complex64_t)

View File

@@ -257,6 +257,13 @@ struct Log {
T operator()(T x) {
return metal::precise::log(x);
};
template <>
complex64_t operator()(complex64_t x) {
auto r = metal::precise::log(Abs{}(x).real);
auto i = metal::precise::atan2(x.imag, x.real);
return {r, i};
};
};
struct Log2 {
@@ -264,6 +271,12 @@ struct Log2 {
T operator()(T x) {
return metal::precise::log2(x);
};
template <>
complex64_t operator()(complex64_t x) {
auto y = Log{}(x);
return {y.real / M_LN2_F, y.imag / M_LN2_F};
};
};
struct Log10 {
@@ -271,6 +284,12 @@ struct Log10 {
T operator()(T x) {
return metal::precise::log10(x);
};
template <>
complex64_t operator()(complex64_t x) {
auto y = Log{}(x);
return {y.real / M_LN10_F, y.imag / M_LN10_F};
};
};
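These specializations implement log z = log|z| + i*atan2(Im z, Re z), with Log2 and Log10 rescaling both components by 1/ln 2 and 1/ln 10. A quick host-side cross-check against std::complex:

#include <cmath>
#include <complex>
#include <cstdio>

int main() {
  std::complex<float> z(-1.0f, 2.0f);
  // The identity the Metal specialization uses:
  float r = std::log(std::abs(z));
  float i = std::atan2(z.imag(), z.real());
  std::complex<float> ref = std::log(z); // library reference
  printf("ours: %g%+gi  std: %g%+gi\n", r, i, ref.real(), ref.imag());
  return 0;
}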
struct Log1p {

View File

@@ -0,0 +1,96 @@
// Copyright © 2023-2024 Apple Inc.
#include <algorithm>
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"
namespace mlx::core {
constexpr int LOGSUMEXP_LOOPED_LIMIT = 4096;
void LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
if (!issubdtype(out.dtype(), floating)) {
throw std::runtime_error(
"[logsumexp] Does not support non-floating point types.");
}
auto& s = stream();
auto& d = metal::device(s.device);
// Make sure that the last dimension is contiguous
auto ensure_contiguous = [&s, &d](const array& x) {
if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
return x;
} else {
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
d.add_temporary(x_copy, s.index);
return x_copy;
}
};
auto in = ensure_contiguous(inputs[0]);
if (in.flags().row_contiguous) {
out.set_data(allocator::malloc(out.nbytes()));
} else {
auto n = in.shape(-1);
auto flags = in.flags();
auto strides = in.strides();
for (auto& s : strides) {
s /= n;
}
bool col_contig = strides[0] == 1;
for (int i = 1; col_contig && i < strides.size(); ++i) {
col_contig &=
(out.shape(i) == 1 || strides[i - 1] == out.shape(i) * strides[i]);
}
flags.col_contiguous = col_contig;
out.set_data(
allocator::malloc(in.nbytes() / n),
in.data_size() / n,
std::move(strides),
flags);
}
int axis_size = in.shape().back();
int n_rows = in.data_size() / axis_size;
const int simd_size = 32;
const int n_reads = 4;
const int looped_limit = LOGSUMEXP_LOOPED_LIMIT;
std::string kernel_name = (axis_size > looped_limit) ? "looped_" : "block_";
kernel_name += "logsumexp_";
kernel_name += type_to_name(out);
auto kernel = get_logsumexp_kernel(d, kernel_name, out);
auto& compute_encoder = d.get_command_encoder(s.index);
{
MTL::Size grid_dims, group_dims;
if (axis_size <= looped_limit) {
size_t threadgroup_needed = (axis_size + n_reads - 1) / n_reads;
size_t simds_needed = (threadgroup_needed + simd_size - 1) / simd_size;
size_t threadgroup_size = simd_size * simds_needed;
assert(threadgroup_size <= kernel->maxTotalThreadsPerThreadgroup());
size_t n_threads = n_rows * threadgroup_size;
grid_dims = MTL::Size(n_threads, 1, 1);
group_dims = MTL::Size(threadgroup_size, 1, 1);
} else {
size_t threadgroup_size = kernel->maxTotalThreadsPerThreadgroup();
size_t n_threads = n_rows * threadgroup_size;
grid_dims = MTL::Size(n_threads, 1, 1);
group_dims = MTL::Size(threadgroup_size, 1, 1);
}
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(out, 1);
compute_encoder.set_bytes(axis_size, 2);
compute_encoder.dispatch_threads(grid_dims, group_dims);
}
}
} // namespace mlx::core
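For reference, the reduction this file dispatches, in the numerically stable form the kernels implement (the block_/looped_ split at LOGSUMEXP_LOOPED_LIMIT = 4096 only changes how a row is tiled across a threadgroup, not the math):

#include <algorithm>
#include <cmath>
#include <vector>

// Stable logsumexp over one row: shift by the row max so no exp overflows.
float logsumexp_row(const std::vector<float>& row) {
  float m = *std::max_element(row.begin(), row.end());
  if (std::isinf(m) && m < 0) return m; // an all -inf row stays -inf
  float acc = 0.0f;
  for (float v : row) acc += std::exp(v - m);
  return m + std::log(acc);
}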

View File

@@ -382,7 +382,7 @@ void steel_matmul(
int split_k_partition_size = gemm_k_iterations * bk;
array C_split({split_k_partitions, M, N}, float32, nullptr, {});
C_split.set_data(allocator::malloc_or_wait(C_split.nbytes()));
C_split.set_data(allocator::malloc(C_split.nbytes()));
copies.push_back(C_split);
bool mn_aligned = M % bm == 0 && N % bn == 0;
@@ -513,7 +513,7 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
return;
}
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
/////////////////////////////////////////////////////////////////////////////
// Init checks and prep
@@ -677,7 +677,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
throw std::runtime_error(
"[matmul] Does not yet support non-floating point types.");
}
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
auto& s = stream();
auto& d = metal::device(s.device);
@@ -860,7 +860,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
int split_k_partition_size = gemm_k_iterations * bk;
array C_split({split_k_partitions, M, N}, float32, nullptr, {});
C_split.set_data(allocator::malloc_or_wait(C_split.nbytes()));
C_split.set_data(allocator::malloc(C_split.nbytes()));
copies.push_back(C_split);
bool mn_aligned = M % bm == 0 && N % bn == 0;
@@ -1096,7 +1096,7 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
return;
}
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
/////////////////////////////////////////////////////////////////////////////
// Init checks and prep
@@ -1484,7 +1484,7 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
return;
}
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
/////////////////////////////////////////////////////////////////////////////
// Init checks and prep

View File

@@ -12,71 +12,6 @@ namespace mlx::core::metal {
/* Check if the Metal backend is available. */
bool is_available();
/* Get the actively used memory in bytes.
*
* Note, this will not always match memory use reported by the system because
* it does not include cached memory buffers.
* */
size_t get_active_memory();
/* Get the peak amount of used memory in bytes.
*
* The maximum memory used recorded from the beginning of the program
* execution or since the last call to reset_peak_memory.
* */
size_t get_peak_memory();
/* Reset the peak memory to zero.
* */
void reset_peak_memory();
/* Get the cache size in bytes.
*
* The cache includes memory not currently used that has not been returned
* to the system allocator.
* */
size_t get_cache_memory();
/* Set the memory limit.
* Calls to malloc will wait on scheduled tasks if the limit is exceeded. If
* there are no more scheduled tasks, an error is raised if relaxed is
* false; otherwise memory is allocated (potentially including swap) if
* relaxed is true.
*
* The memory limit defaults to 1.5 times the maximum recommended working set
* size reported by the device.
*
* Returns the previous memory limit.
* */
size_t set_memory_limit(size_t limit, bool relaxed = true);
/* Set the free cache limit.
* If using more than the given limit, free memory will be reclaimed
* from the cache on the next allocation. To disable the cache,
* set the limit to 0.
*
* The cache limit defaults to the memory limit.
*
* Returns the previous cache limit.
* */
size_t set_cache_limit(size_t limit);
/* Clear the memory cache. */
void clear_cache();
/* Set the wired size limit.
*
* Note, this function is only useful for macOS 15.0 or higher.
*
* The wired limit is the total size in bytes of memory that will be kept
* resident. The default value is ``0``.
*
* Setting a wired limit larger than the system wired limit is an error.
*
* Returns the previous wired limit.
* */
size_t set_wired_limit(size_t limit);
/** Capture a GPU trace, saving it to an absolute file `path` */
void start_capture(std::string path = "");
void stop_capture();

View File

@@ -72,6 +72,13 @@ MTL::ComputePipelineState* get_softmax_kernel(
return d.get_kernel(kernel_name);
}
MTL::ComputePipelineState* get_logsumexp_kernel(
metal::Device& d,
const std::string& kernel_name,
const array&) {
return d.get_kernel(kernel_name);
}
MTL::ComputePipelineState* get_scan_kernel(
metal::Device& d,
const std::string& kernel_name,

View File

@@ -29,7 +29,7 @@ void RMSNorm::eval_gpu(
out.copy_shared_buffer(x);
} else {
out.set_data(
allocator::malloc_or_wait(x.data_size() * x.itemsize()),
allocator::malloc(x.data_size() * x.itemsize()),
x.data_size(),
x.strides(),
x.flags());
@@ -129,7 +129,7 @@ void RMSNormVJP::eval_gpu(
gx.copy_shared_buffer(g);
g_in_gx = true;
} else {
gx.set_data(allocator::malloc_or_wait(gx.nbytes()));
gx.set_data(allocator::malloc(gx.nbytes()));
}
if (g_copied && !g_in_gx) {
d.add_temporary(g, s.index);
@@ -146,11 +146,11 @@ void RMSNormVJP::eval_gpu(
if (!g_in_gx && donate_g) {
gw_temp.copy_shared_buffer(g);
} else {
gw_temp.set_data(allocator::malloc_or_wait(gw_temp.nbytes()));
gw_temp.set_data(allocator::malloc(gw_temp.nbytes()));
d.add_temporary(gw_temp, s.index);
}
}
gw.set_data(allocator::malloc_or_wait(gw.nbytes()));
gw.set_data(allocator::malloc(gw.nbytes()));
const int simd_size = 32;
const int n_reads = RMS_N_READS;
@@ -226,7 +226,7 @@ void LayerNorm::eval_gpu(
out.copy_shared_buffer(x);
} else {
out.set_data(
allocator::malloc_or_wait(x.data_size() * x.itemsize()),
allocator::malloc(x.data_size() * x.itemsize()),
x.data_size(),
x.strides(),
x.flags());
@@ -331,7 +331,7 @@ void LayerNormVJP::eval_gpu(
gx.copy_shared_buffer(g);
g_in_gx = true;
} else {
gx.set_data(allocator::malloc_or_wait(gx.nbytes()));
gx.set_data(allocator::malloc(gx.nbytes()));
}
if (g_copied && !g_in_gx) {
d.add_temporary(g, s.index);
@@ -348,12 +348,12 @@ void LayerNormVJP::eval_gpu(
if (!g_in_gx && donate_g) {
gw_temp.copy_shared_buffer(g);
} else {
gw_temp.set_data(allocator::malloc_or_wait(gw_temp.nbytes()));
gw_temp.set_data(allocator::malloc(gw_temp.nbytes()));
d.add_temporary(gw_temp, s.index);
}
}
gw.set_data(allocator::malloc_or_wait(gw.nbytes()));
gb.set_data(allocator::malloc_or_wait(gb.nbytes()));
gw.set_data(allocator::malloc(gw.nbytes()));
gb.set_data(allocator::malloc(gb.nbytes()));
// Finish with the gradient for b in case we had a b
auto& compute_encoder = d.get_command_encoder(s.index);

View File

@@ -16,6 +16,8 @@
#include "mlx/scheduler.h"
#include "mlx/utils.h"
#include <iostream>
namespace mlx::core {
template <typename T>
@@ -28,7 +30,7 @@ void arange_set_scalars(T start, T next, metal::CommandEncoder& enc) {
void reshape(const array& in, array& out, Stream s) {
auto [copy_necessary, out_strides] = prepare_reshape(in, out);
if (copy_necessary) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
copy_gpu_inplace(
in,
out,
@@ -58,7 +60,7 @@ static array compute_dynamic_offset(
if (donate) {
offset.copy_shared_buffer(indices);
} else {
offset.set_data(allocator::malloc_or_wait(offset.itemsize()));
offset.set_data(allocator::malloc(offset.itemsize()));
}
d.add_temporary(offset, s.index);
@@ -100,7 +102,7 @@ static array compute_dynamic_offset(
void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 0);
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
if (out.size() == 0) {
return;
}
@@ -158,33 +160,25 @@ void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.dispatch_threads(grid_dims, group_dims);
}
void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto& s = stream();
void arg_reduce_dispatch(
const array& in,
array& out,
int axis,
std::string op_name,
const Stream& s) {
auto& d = metal::device(s.device);
std::string op_name;
switch (reduce_type_) {
case ArgReduce::ArgMin:
op_name = "argmin_";
break;
case ArgReduce::ArgMax:
op_name = "argmax_";
break;
}
// Prepare the shapes, strides and axis arguments.
auto in_strides = in.strides();
auto shape = in.shape();
auto out_strides = out.strides();
auto axis_stride = in_strides[axis_];
size_t axis_size = shape[axis_];
auto axis_stride = in_strides[axis];
size_t axis_size = shape[axis];
if (out_strides.size() == in_strides.size()) {
out_strides.erase(out_strides.begin() + axis_);
out_strides.erase(out_strides.begin() + axis);
}
in_strides.erase(in_strides.begin() + axis_);
shape.erase(shape.begin() + axis_);
in_strides.erase(in_strides.begin() + axis);
shape.erase(shape.begin() + axis);
size_t ndim = shape.size();
// ArgReduce
@@ -192,7 +186,7 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
int n_reads = 4;
auto& compute_encoder = d.get_command_encoder(s.index);
{
auto kernel = d.get_kernel(op_name + type_to_name(in));
auto kernel = d.get_kernel(op_name + "_" + type_to_name(in));
NS::UInteger thread_group_size = std::min(
(axis_size + n_reads - 1) / n_reads,
kernel->maxTotalThreadsPerThreadgroup());
@@ -226,6 +220,23 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
}
}
void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
std::string op_name;
switch (reduce_type_) {
case ArgReduce::ArgMin:
op_name = "argmin";
break;
case ArgReduce::ArgMax:
op_name = "argmax";
break;
}
auto& in = inputs[0];
out.set_data(allocator::malloc(out.nbytes()));
auto& s = stream();
arg_reduce_dispatch(in, out, axis_, op_name, s);
}
void AsType::eval_gpu(const std::vector<array>& inputs, array& out) {
CopyType ctype =
inputs[0].flags().contiguous ? CopyType::Vector : CopyType::General;
@@ -251,8 +262,10 @@ void Concatenate::eval_gpu(const std::vector<array>& inputs, array& out) {
void Contiguous::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (in.flags().row_contiguous ||
(allow_col_major_ && in.flags().col_contiguous)) {
constexpr size_t extra_bytes = 16384;
if (in.buffer_size() <= out.nbytes() + extra_bytes &&
(in.flags().row_contiguous ||
(allow_col_major_ && in.flags().col_contiguous))) {
out.copy_shared_buffer(in);
} else {
copy_gpu(in, out, CopyType::General);
@@ -333,7 +346,7 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
size_t elems_per_key = out.size() / num_keys;
size_t bytes_per_key = out.itemsize() * elems_per_key;
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
if (out.size() == 0) {
return;
}
@@ -397,7 +410,7 @@ void DynamicSlice::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& in = inputs[0];
auto& start = inputs[1];
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
auto s = stream();
auto in_offset = compute_dynamic_offset(start, in.strides(), axes_, s);
copy_gpu_inplace(
@@ -554,7 +567,7 @@ void View::eval_gpu(const std::vector<array>& inputs, array& out) {
in, strides, in.flags(), in.data_size() * ibytes / obytes);
} else {
auto tmp = array(in.shape(), in.dtype(), nullptr, {});
tmp.set_data(allocator::malloc_or_wait(tmp.nbytes()));
tmp.set_data(allocator::malloc(tmp.nbytes()));
copy_gpu_inplace(in, tmp, CopyType::General, stream());
auto flags = out.flags();

View File

@@ -7,11 +7,14 @@
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/reduce.h"
#include "mlx/backend/metal/slicing.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/fast_primitives.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
#include <iostream>
namespace mlx::core {
void launch_qmm(
@@ -31,6 +34,7 @@ void launch_qmm(
bool gather,
bool aligned,
bool quad,
const std::string& mode,
const Stream& s) {
auto& x_pre = inputs[0];
auto& w_pre = inputs[1];
@@ -54,8 +58,12 @@ void launch_qmm(
};
auto x = ensure_row_contiguous_last_dims(x_pre);
auto w = ensure_row_contiguous_last_dims(w_pre);
auto scales = ensure_row_contiguous_last_dims(scales_pre);
auto biases = ensure_row_contiguous_last_dims(biases_pre);
auto scales = scales_pre;
auto biases = biases_pre;
if (mode == "affine") {
scales = ensure_row_contiguous_last_dims(scales_pre);
biases = ensure_row_contiguous_last_dims(biases_pre);
}
int x_batch_ndims = x.ndim() - 2;
auto& x_shape = x.shape();
@@ -68,6 +76,8 @@ void launch_qmm(
std::string aligned_n = (O % 32) == 0 ? "true" : "false";
bool is_trellis = (mode == "trellis");
std::ostringstream kname;
auto type_string = get_type_string(x.dtype());
kname << name << "_" << type_string << "_gs_" << group_size << "_b_" << bits;
@@ -80,24 +90,47 @@ void launch_qmm(
if (!gather) {
kname << "_batch_" << batched;
}
if (mode == "trellis") {
kname << "_mode_" << is_trellis;
}
// Encode and dispatch kernel
std::string template_def;
if (quad) {
template_def = get_template_definition(
kname.str(), name, type_string, group_size, bits, D, batched);
kname.str(),
name,
type_string,
group_size,
bits,
D,
batched,
is_trellis);
} else if (aligned && !gather) {
template_def = get_template_definition(
kname.str(), name, type_string, group_size, bits, aligned_n, batched);
kname.str(),
name,
type_string,
group_size,
bits,
aligned_n,
batched,
is_trellis);
} else if (!gather && !aligned) {
template_def = get_template_definition(
kname.str(), name, type_string, group_size, bits, batched);
kname.str(), name, type_string, group_size, bits, batched, is_trellis);
} else if (aligned && gather) {
template_def = get_template_definition(
kname.str(), name, type_string, group_size, bits, aligned_n);
kname.str(),
name,
type_string,
group_size,
bits,
aligned_n,
is_trellis);
} else {
template_def = get_template_definition(
kname.str(), name, type_string, group_size, bits);
kname.str(), name, type_string, group_size, bits, is_trellis);
}
auto& d = metal::device(s.device);
auto kernel = get_quantized_kernel(d, kname.str(), template_def);
@@ -224,7 +257,7 @@ void qvm_split_k(
auto temp_shape = out.shape();
temp_shape.insert(temp_shape.end() - 2, split_k);
array intermediate(temp_shape, x.dtype(), nullptr, {});
intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes()));
intermediate.set_data(allocator::malloc(intermediate.nbytes()));
d.add_temporary(intermediate, s.index);
std::ostringstream kname;
@@ -276,8 +309,9 @@ void qmm_op(
int group_size,
int bits,
bool gather,
const std::string& mode,
const Stream& s) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
MTL::Size group_dims;
MTL::Size grid_dims;
@@ -298,23 +332,69 @@ void qmm_op(
bool aligned = false;
bool quad = false;
auto get_qmv_batch_limit = [s](int D, int O) {
auto arch = metal::device(s.device).get_architecture();
auto arch_size = arch.back();
auto arch_gen = arch.substr(arch.size() - 3, 2);
if (arch_gen == "13" || arch_gen == "14") {
switch (arch_size) {
case 'd':
if (D <= 2048 && O <= 2048) {
return 32;
} else if (D <= 4096 && O <= 4096) {
return 18;
} else {
return 12;
}
default:
if (D <= 2048 && O <= 2048) {
return 14;
} else if (D <= 4096 && O <= 4096) {
return 10;
} else {
return 6;
}
}
} else {
switch (arch_size) {
case 'd':
if (D <= 2048 && O <= 2048) {
return 32;
} else if (D <= 4096 && O <= 4096) {
return 18;
} else {
return 12;
}
default:
if (D <= 2048 && O <= 2048) {
return 18;
} else if (D <= 4096 && O <= 4096) {
return 12;
} else {
return 10;
}
}
}
};
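// Reading of the heuristic above (not documented in the diff): the largest
// GPU variants (arch_size == 'd') tolerate larger batch sizes before the
// qmm matrix kernels win, the thresholds step down as D/O grow past
// 2048/4096, and the older "13"/"14" generations get slightly lower
// defaults than newer ones.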
if (transpose) {
if (B < 6 && (D == 128 || D == 64) && is_power_of_2(bits)) {
auto qmv_batch_limit = get_qmv_batch_limit(D, O);
if (B < qmv_batch_limit && (D == 128 || D == 64) && is_power_of_2(bits)) {
name += "qmv_quad";
constexpr int quads_per_simd = 8;
constexpr int results_per_quadgroup = 8;
int bo = quads_per_simd * results_per_quadgroup;
int simdgroup_size = 32;
group_dims = MTL::Size(simdgroup_size, 1, 1);
grid_dims = MTL::Size((O + bo - 1) / bo, B, N);
grid_dims = MTL::Size(B, (O + bo - 1) / bo, N);
quad = true;
} else if (B < 6 && O % 8 == 0 && D % 512 == 0 && D >= 512) {
} else if (B < 10000 && O % 8 == 0 && D % 512 == 0 && D >= 512) {
name += "qmv_fast";
int bo = 8;
int bd = 32;
group_dims = MTL::Size(bd, 2, 1);
grid_dims = MTL::Size(B, O / bo, N);
} else if (B < 6) {
} else if (B < qmv_batch_limit) {
name += "qmv";
int bo = 8;
int bd = 32;
@@ -374,19 +454,34 @@ void qmm_op(
gather,
aligned,
quad,
mode,
s);
}
void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 4);
qmm_op(
inputs, out, transpose_, group_size_, bits_, /*gather=*/false, stream());
inputs,
out,
transpose_,
group_size_,
bits_,
/*gather=*/false,
mode_,
stream());
}
void GatherQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 6);
qmm_op(
inputs, out, transpose_, group_size_, bits_, /*gather=*/true, stream());
inputs,
out,
transpose_,
group_size_,
bits_,
/*gather=*/true,
mode_,
stream());
}
void fast::AffineQuantize::eval_gpu(
@@ -394,7 +489,7 @@ void fast::AffineQuantize::eval_gpu(
std::vector<array>& outputs) {
auto& w_pre = inputs[0];
auto& out = outputs[0];
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
auto& s = stream();
auto& d = metal::device(s.device);
@@ -425,8 +520,8 @@ void fast::AffineQuantize::eval_gpu(
} else {
auto& scales = outputs[1];
auto& biases = outputs[2];
scales.set_data(allocator::malloc_or_wait(scales.nbytes()));
biases.set_data(allocator::malloc_or_wait(biases.nbytes()));
scales.set_data(allocator::malloc(scales.nbytes()));
biases.set_data(allocator::malloc(biases.nbytes()));
compute_encoder.set_output_array(out, 1);
compute_encoder.set_output_array(scales, 2);
compute_encoder.set_output_array(biases, 3);
@@ -470,4 +565,123 @@ void fast::AffineQuantize::eval_gpu(
d.add_temporaries(std::move(copies), s.index);
}
void viterbi(
array& w,
array& scores,
array& pointers,
array& start,
array& overlap,
bool use_overlap,
const Stream& s) {
int B = scores.shape(0);
auto& d = metal::device(s.device);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_input_array(w, 0);
compute_encoder.set_output_array(scores, 1);
compute_encoder.set_output_array(pointers, 2);
if (use_overlap) {
compute_encoder.set_input_array(overlap, 3);
}
std::ostringstream kname;
auto type_string = get_type_string(w.dtype());
kname << "trellis_viterbi_" << type_string << "_overlap_" << use_overlap;
auto template_def = get_template_definition(
kname.str(), "trellis_viterbi", type_string, use_overlap);
auto kernel = get_quantized_kernel(d, kname.str(), template_def);
compute_encoder.set_compute_pipeline_state(kernel);
auto group_dims = MTL::Size(1, 1024, 1);
auto grid_dims = MTL::Size(B, 1024, 1);
compute_encoder.dispatch_threads(grid_dims, group_dims);
arg_reduce_dispatch(scores, start, 1, "argmin", s);
}
void viterbi_backtrack(
array& start,
array& pointers,
array& out,
array& overlap,
bool use_overlap,
const Stream& s) {
int B = start.shape(0);
auto& d = metal::device(s.device);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_input_array(start, 0);
compute_encoder.set_input_array(pointers, 1);
compute_encoder.set_output_array(out, 2);
if (use_overlap) {
compute_encoder.set_input_array(overlap, 3);
}
std::ostringstream kname;
kname << "trellis_backtrack" << "_overlap_" << use_overlap;
auto template_def =
get_template_definition(kname.str(), "trellis_backtrack", use_overlap);
auto kernel = get_quantized_kernel(d, kname.str(), template_def);
compute_encoder.set_compute_pipeline_state(kernel);
auto group_dims = MTL::Size(256, 1, 1);
auto grid_dims = MTL::Size(B, 1, 1);
compute_encoder.dispatch_threads(grid_dims, group_dims);
}
void fast::TrellisQuantize::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
auto& w_pre = inputs[0];
auto& out = outputs[0];
out.set_data(allocator::malloc(out.nbytes()));
auto& s = stream();
auto& d = metal::device(s.device);
std::vector<array> copies;
auto ensure_row_contiguous = [&copies, &s](const array& arr) {
if (arr.flags().row_contiguous) {
return arr;
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy_gpu(arr, arr_copy, CopyType::General, s);
copies.push_back(arr_copy);
return arr_copy;
}
};
auto w = ensure_row_contiguous(w_pre);
int B = w.shape(0);
int T = w.shape(1);
constexpr int num_states = 1 << 14;
array scores({B, num_states}, float16, nullptr, {});
scores.set_data(allocator::malloc(scores.nbytes()));
copies.push_back(scores);
array pointers({B, T, num_states}, uint8, nullptr, {});
pointers.set_data(allocator::malloc(pointers.nbytes()));
copies.push_back(pointers);
array start({B}, uint32, nullptr, {});
start.set_data(allocator::malloc(start.nbytes()));
copies.push_back(start);
array rolled({B, T}, uint16, nullptr, {});
rolled.set_data(allocator::malloc(rolled.nbytes()));
copies.push_back(rolled);
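// Two decode rounds (a reading of the code below, not stated in the diff):
// pass one runs without the overlap constraint to produce a provisional
// path in rolled (the kernel rotates timesteps by 64 in this mode), then
// pass two re-decodes with use_overlap = true, pinning each row's midpoint
// state to the pass-one result before backtracking into out.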
viterbi(w, scores, pointers, start, out, false, s);
viterbi_backtrack(start, pointers, rolled, out, false, s);
viterbi(w, scores, pointers, start, rolled, true, s);
viterbi_backtrack(start, pointers, out, rolled, true, s);
d.add_temporaries(std::move(copies), s.index);
}
} // namespace mlx::core

View File

@@ -347,7 +347,7 @@ void all_reduce_dispatch(
// Allocate an intermediate tensor to hold results if needed
array intermediate({n_rows}, out_type, nullptr, {});
intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes()));
intermediate.set_data(allocator::malloc(intermediate.nbytes()));
d.add_temporary(intermediate, s.index);
// 1st pass
@@ -641,7 +641,7 @@ void strided_reduce_longcolumn(
intermediate_shape.insert(
intermediate_shape.end(), out.shape().begin(), out.shape().end());
array intermediate(std::move(intermediate_shape), out_type, nullptr, {});
intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes()));
intermediate.set_data(allocator::malloc(intermediate.nbytes()));
d.add_temporary(intermediate, s.index);
// Prepare the arguments for the kernel
@@ -812,7 +812,7 @@ void strided_reduce_2pass(
intermediate_shape.insert(
intermediate_shape.end(), out.shape().begin(), out.shape().end());
array intermediate(std::move(intermediate_shape), out_type, nullptr, {});
intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes()));
intermediate.set_data(allocator::malloc(intermediate.nbytes()));
d.add_temporary(intermediate, s.index);
// Prepare the arguments for the kernel
@@ -950,7 +950,7 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
// Minimum of 4 bytes since we use size 4 structs for all reduce
// and Metal will complain otherwise
size_t min_bytes = std::max(out.nbytes(), 4ul);
out.set_data(allocator::malloc_or_wait(min_bytes));
out.set_data(allocator::malloc(min_bytes));
std::string op_name;
switch (reduce_type_) {
case Reduce::And:

View File

@@ -38,4 +38,11 @@ void strided_reduce_general_dispatch(
metal::Device& d,
const Stream& s);
void arg_reduce_dispatch(
const array& in,
array& out,
int axis,
std::string op_name,
const Stream& s);
} // namespace mlx::core

View File

@@ -22,6 +22,7 @@ ResidencySet::ResidencySet(MTL::Device* d) {
}
throw std::runtime_error(msg.str());
}
wired_set_->requestResidency();
}
}
@@ -32,7 +33,6 @@ void ResidencySet::insert(MTL::Allocation* buf) {
if (wired_set_->allocatedSize() + buf->allocatedSize() <= capacity_) {
wired_set_->addAllocation(buf);
wired_set_->commit();
wired_set_->requestResidency();
} else {
unwired_set_.insert(buf);
}
@@ -76,7 +76,6 @@ void ResidencySet::resize(size_t size) {
}
}
wired_set_->commit();
wired_set_->requestResidency();
} else if (current_size > size) {
auto pool = new_scoped_memory_pool();
// Remove wired allocations until under capacity

Some files were not shown because too many files have changed in this diff.