mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-10 13:07:29 +08:00
Compare commits
93 Commits
Author | SHA1 | Date | |
---|---|---|---|
![]() |
83762691ba | ||
![]() |
2a41caa00e | ||
![]() |
6593281d25 | ||
![]() |
da98e8bce8 | ||
![]() |
be57a16a80 | ||
![]() |
1704809f29 | ||
![]() |
0cae0bdac8 | ||
![]() |
5a1a5d5ed1 | ||
![]() |
1683975acf | ||
![]() |
af705590ac | ||
![]() |
825124af8f | ||
![]() |
9c5e7da507 | ||
![]() |
481349495b | ||
![]() |
9daa6b003f | ||
![]() |
a3a632d567 | ||
![]() |
e496c5a4b4 | ||
![]() |
ea890d8710 | ||
![]() |
aa5d84f102 | ||
![]() |
f1606486d2 | ||
![]() |
87720a8908 | ||
![]() |
bb6565ef14 | ||
![]() |
7bb063bcb3 | ||
![]() |
b36dd472bb | ||
![]() |
167b759a38 | ||
![]() |
99b9868859 | ||
![]() |
6b2d5448f2 | ||
![]() |
eaf709b83e | ||
![]() |
f0e70afff0 | ||
![]() |
86984cad68 | ||
![]() |
fbc89e3ced | ||
![]() |
38c1e720c2 | ||
![]() |
600e87e03c | ||
![]() |
3836445241 | ||
![]() |
1d2c9d6a07 | ||
![]() |
e8ac6bd2f5 | ||
![]() |
fdadc4f22c | ||
![]() |
79b527f45f | ||
![]() |
dc4eada7f0 | ||
![]() |
70ebc3b598 | ||
![]() |
b13f2aed16 | ||
![]() |
5f04c0f818 | ||
![]() |
55935ccae7 | ||
![]() |
b529515eb1 | ||
![]() |
3cde719eb7 | ||
![]() |
5de6d94a90 | ||
![]() |
99eefd2ec0 | ||
![]() |
e9e268336b | ||
![]() |
7275ac7523 | ||
![]() |
c4189a38e4 | ||
![]() |
68d1b3256b | ||
![]() |
9c6953bda7 | ||
![]() |
ef7ece9851 | ||
![]() |
ddaa4b7dcb | ||
![]() |
dfae2c6989 | ||
![]() |
515f104926 | ||
![]() |
9ecefd56db | ||
![]() |
e5d35aa187 | ||
![]() |
00794c42bc | ||
![]() |
08a1bf3f10 | ||
![]() |
60c4154346 | ||
![]() |
f2c85308c1 | ||
![]() |
1a28b69ee2 | ||
![]() |
ba09f01ce8 | ||
![]() |
6cf48872b7 | ||
![]() |
7b3b8fa000 | ||
![]() |
ec5e2aae61 | ||
![]() |
86389bf970 | ||
![]() |
3290bfa690 | ||
![]() |
8777fd104f | ||
![]() |
c41f7565ed | ||
![]() |
9ba81e3da4 | ||
![]() |
c23888acd7 | ||
![]() |
f98ce25ab9 | ||
![]() |
de5f38fd48 | ||
![]() |
ec2854b13a | ||
![]() |
90823d2938 | ||
![]() |
5f5770e3a2 | ||
![]() |
28f39e9038 | ||
![]() |
b2d2b37888 | ||
![]() |
fe597e141c | ||
![]() |
72ca1539e0 | ||
![]() |
13b26775f1 | ||
![]() |
05d7118561 | ||
![]() |
98b901ad66 | ||
![]() |
5580b47291 | ||
![]() |
bc62932984 | ||
![]() |
a6b5d6e759 | ||
![]() |
a8931306e1 | ||
![]() |
fecdb8717e | ||
![]() |
916fd273ea | ||
![]() |
0da8506552 | ||
![]() |
eda7a7b43e | ||
![]() |
022eabb734 |
@@ -24,8 +24,8 @@ jobs:
|
||||
type: boolean
|
||||
default: false
|
||||
macos:
|
||||
xcode: "15.2.0"
|
||||
resource_class: macos.m1.medium.gen1
|
||||
xcode: "16.2.0"
|
||||
resource_class: m2pro.medium
|
||||
steps:
|
||||
- checkout
|
||||
- run:
|
||||
@@ -93,12 +93,10 @@ jobs:
|
||||
- run:
|
||||
name: Install Python package
|
||||
command: |
|
||||
CMAKE_ARGS="-DMLX_BUILD_METAL=OFF
|
||||
CMAKE_COMPILE_WARNING_AS_ERROR=ON" \
|
||||
CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" \
|
||||
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
|
||||
python3 setup.py build_ext --inplace
|
||||
CMAKE_ARGS="-DMLX_BUILD_METAL=OFF \
|
||||
CMAKE_COMPILE_WARNING_AS_ERROR=ON" \
|
||||
CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" \
|
||||
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
|
||||
python3 setup.py develop
|
||||
- run:
|
||||
@@ -127,10 +125,15 @@ jobs:
|
||||
parameters:
|
||||
xcode_version:
|
||||
type: string
|
||||
default: "15.2.0"
|
||||
default: "16.2.0"
|
||||
macosx_deployment_target:
|
||||
type: string
|
||||
default: ""
|
||||
macos:
|
||||
xcode: << parameters.xcode_version >>
|
||||
resource_class: macos.m1.medium.gen1
|
||||
environment:
|
||||
MACOSX_DEPLOYMENT_TARGET: << parameters.macosx_deployment_target >>
|
||||
resource_class: m2pro.medium
|
||||
steps:
|
||||
- checkout
|
||||
- run:
|
||||
@@ -152,7 +155,7 @@ jobs:
|
||||
command: |
|
||||
source env/bin/activate
|
||||
DEBUG=1 CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
|
||||
CMAKE_ARGS="CMAKE_COMPILE_WARNING_AS_ERROR=ON" \
|
||||
CMAKE_ARGS="-DCMAKE_COMPILE_WARNING_AS_ERROR=ON" \
|
||||
pip install -e . -v
|
||||
- run:
|
||||
name: Generate package stubs
|
||||
@@ -216,13 +219,18 @@ jobs:
|
||||
default: "3.9"
|
||||
xcode_version:
|
||||
type: string
|
||||
default: "15.2.0"
|
||||
default: "16.2.0"
|
||||
build_env:
|
||||
type: string
|
||||
default: ""
|
||||
macosx_deployment_target:
|
||||
type: string
|
||||
default: ""
|
||||
macos:
|
||||
xcode: << parameters.xcode_version >>
|
||||
resource_class: macos.m1.medium.gen1
|
||||
resource_class: m2pro.medium
|
||||
environment:
|
||||
MACOSX_DEPLOYMENT_TARGET: << parameters.macosx_deployment_target >>
|
||||
steps:
|
||||
- checkout
|
||||
- run:
|
||||
@@ -243,7 +251,7 @@ jobs:
|
||||
name: Install Python package
|
||||
command: |
|
||||
source env/bin/activate
|
||||
DEV_RELEASE=1 \
|
||||
env -u MACOSX_DEPLOYMENT_TARGET DEV_RELEASE=1 \
|
||||
CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
|
||||
pip install . -v
|
||||
- run:
|
||||
@@ -338,7 +346,7 @@ workflows:
|
||||
- mac_build_and_test:
|
||||
matrix:
|
||||
parameters:
|
||||
xcode_version: ["15.0.0", "15.2.0", "16.0.0"]
|
||||
macosx_deployment_target: ["13.5", "14.0"]
|
||||
- linux_build_and_test
|
||||
- build_documentation
|
||||
|
||||
@@ -358,8 +366,70 @@ workflows:
|
||||
matrix:
|
||||
parameters:
|
||||
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
|
||||
xcode_version: ["15.0.0", "15.2.0"]
|
||||
macosx_deployment_target: ["13.5", "14.0", "15.0"]
|
||||
build_env: ["PYPI_RELEASE=1"]
|
||||
xcode_version: ["16.2.0", "15.0.0"]
|
||||
exclude:
|
||||
- macosx_deployment_target: "13.5"
|
||||
xcode_version: "16.2.0"
|
||||
python_version: "3.9"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- macosx_deployment_target: "13.5"
|
||||
xcode_version: "16.2.0"
|
||||
python_version: "3.10"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- macosx_deployment_target: "13.5"
|
||||
xcode_version: "16.2.0"
|
||||
python_version: "3.11"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- macosx_deployment_target: "13.5"
|
||||
xcode_version: "16.2.0"
|
||||
python_version: "3.12"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- macosx_deployment_target: "13.5"
|
||||
xcode_version: "16.2.0"
|
||||
python_version: "3.13"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- macosx_deployment_target: "14.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.9"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- macosx_deployment_target: "14.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.10"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- macosx_deployment_target: "14.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.11"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- macosx_deployment_target: "14.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.12"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- macosx_deployment_target: "14.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.13"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- macosx_deployment_target: "15.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.9"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- macosx_deployment_target: "15.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.10"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- macosx_deployment_target: "15.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.11"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- macosx_deployment_target: "15.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.12"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- macosx_deployment_target: "15.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.13"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- build_documentation:
|
||||
filters:
|
||||
tags:
|
||||
@@ -382,7 +452,7 @@ workflows:
|
||||
requires: [ hold ]
|
||||
matrix:
|
||||
parameters:
|
||||
xcode_version: ["15.0.0", "15.2.0", "16.0.0"]
|
||||
macosx_deployment_target: ["13.5", "14.0"]
|
||||
- linux_build_and_test:
|
||||
requires: [ hold ]
|
||||
nightly_build:
|
||||
@@ -395,7 +465,54 @@ workflows:
|
||||
matrix:
|
||||
parameters:
|
||||
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
|
||||
xcode_version: ["15.0.0", "15.2.0"]
|
||||
macosx_deployment_target: ["13.5", "14.0", "15.0"]
|
||||
xcode_version: ["16.2.0", "15.0.0"]
|
||||
exclude:
|
||||
- macosx_deployment_target: "13.5"
|
||||
xcode_version: "16.2.0"
|
||||
python_version: "3.9"
|
||||
- macosx_deployment_target: "13.5"
|
||||
xcode_version: "16.2.0"
|
||||
python_version: "3.10"
|
||||
- macosx_deployment_target: "13.5"
|
||||
xcode_version: "16.2.0"
|
||||
python_version: "3.11"
|
||||
- macosx_deployment_target: "13.5"
|
||||
xcode_version: "16.2.0"
|
||||
python_version: "3.12"
|
||||
- macosx_deployment_target: "13.5"
|
||||
xcode_version: "16.2.0"
|
||||
python_version: "3.13"
|
||||
- macosx_deployment_target: "14.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.9"
|
||||
- macosx_deployment_target: "14.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.10"
|
||||
- macosx_deployment_target: "14.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.11"
|
||||
- macosx_deployment_target: "14.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.12"
|
||||
- macosx_deployment_target: "14.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.13"
|
||||
- macosx_deployment_target: "15.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.9"
|
||||
- macosx_deployment_target: "15.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.10"
|
||||
- macosx_deployment_target: "15.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.11"
|
||||
- macosx_deployment_target: "15.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.12"
|
||||
- macosx_deployment_target: "15.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.13"
|
||||
weekly_build:
|
||||
when:
|
||||
and:
|
||||
@@ -406,8 +523,70 @@ workflows:
|
||||
matrix:
|
||||
parameters:
|
||||
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
|
||||
xcode_version: ["15.0.0", "15.2.0", "16.0.0"]
|
||||
macosx_deployment_target: ["13.5", "14.0", "15.0"]
|
||||
build_env: ["DEV_RELEASE=1"]
|
||||
xcode_version: ["16.2.0", "15.0.0"]
|
||||
exclude:
|
||||
- macosx_deployment_target: "13.5"
|
||||
xcode_version: "16.2.0"
|
||||
python_version: "3.9"
|
||||
build_env: "DEV_RELEASE=1"
|
||||
- macosx_deployment_target: "13.5"
|
||||
xcode_version: "16.2.0"
|
||||
python_version: "3.10"
|
||||
build_env: "DEV_RELEASE=1"
|
||||
- macosx_deployment_target: "13.5"
|
||||
xcode_version: "16.2.0"
|
||||
python_version: "3.11"
|
||||
build_env: "DEV_RELEASE=1"
|
||||
- macosx_deployment_target: "13.5"
|
||||
xcode_version: "16.2.0"
|
||||
python_version: "3.12"
|
||||
build_env: "DEV_RELEASE=1"
|
||||
- macosx_deployment_target: "13.5"
|
||||
xcode_version: "16.2.0"
|
||||
python_version: "3.13"
|
||||
build_env: "DEV_RELEASE=1"
|
||||
- macosx_deployment_target: "14.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.9"
|
||||
build_env: "DEV_RELEASE=1"
|
||||
- macosx_deployment_target: "14.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.10"
|
||||
build_env: "DEV_RELEASE=1"
|
||||
- macosx_deployment_target: "14.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.11"
|
||||
build_env: "DEV_RELEASE=1"
|
||||
- macosx_deployment_target: "14.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.12"
|
||||
build_env: "DEV_RELEASE=1"
|
||||
- macosx_deployment_target: "14.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.13"
|
||||
build_env: "DEV_RELEASE=1"
|
||||
- macosx_deployment_target: "15.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.9"
|
||||
build_env: "DEV_RELEASE=1"
|
||||
- macosx_deployment_target: "15.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.10"
|
||||
build_env: "DEV_RELEASE=1"
|
||||
- macosx_deployment_target: "15.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.11"
|
||||
build_env: "DEV_RELEASE=1"
|
||||
- macosx_deployment_target: "15.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.12"
|
||||
build_env: "DEV_RELEASE=1"
|
||||
- macosx_deployment_target: "15.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.13"
|
||||
build_env: "DEV_RELEASE=1"
|
||||
linux_test_release:
|
||||
when:
|
||||
and:
|
||||
|
1
.gitignore
vendored
1
.gitignore
vendored
@@ -36,6 +36,7 @@ share/python-wheels/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
uv.lock
|
||||
|
||||
# vim
|
||||
*.swp
|
||||
|
@@ -34,6 +34,7 @@ option(MLX_BUILD_BENCHMARKS "Build benchmarks for mlx" OFF)
|
||||
option(MLX_BUILD_PYTHON_BINDINGS "Build python bindings for mlx" OFF)
|
||||
option(MLX_BUILD_METAL "Build metal backend" ON)
|
||||
option(MLX_BUILD_CPU "Build cpu backend" ON)
|
||||
option(MLX_BUILD_CUDA "Build cuda backend" OFF)
|
||||
option(MLX_METAL_DEBUG "Enhance metal debug workflow" OFF)
|
||||
option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" OFF)
|
||||
option(MLX_BUILD_GGUF "Include support for GGUF format" ON)
|
||||
@@ -83,6 +84,10 @@ if(MLX_BUILD_METAL)
|
||||
set(QUARTZ_LIB "-framework QuartzCore")
|
||||
endif()
|
||||
|
||||
if(MLX_BUILD_CUDA)
|
||||
enable_language(CUDA)
|
||||
endif()
|
||||
|
||||
if(MLX_BUILD_METAL AND NOT METAL_LIB)
|
||||
message(STATUS "Metal not found. Unable to build GPU")
|
||||
set(MLX_BUILD_METAL OFF)
|
||||
|
@@ -5,26 +5,26 @@ possible.
|
||||
|
||||
## Pull Requests
|
||||
|
||||
1. Fork and submit pull requests to the repo.
|
||||
1. Fork and submit pull requests to the repo.
|
||||
2. If you've added code that should be tested, add tests.
|
||||
3. If a change is likely to impact efficiency, run some of the benchmarks before
|
||||
and after the change. Examples of benchmarks can be found in `benchmarks/python/`.
|
||||
4. If you've changed APIs, update the documentation.
|
||||
5. Every PR should have passing tests and at least one review.
|
||||
5. Every PR should have passing tests and at least one review.
|
||||
6. For code formatting install `pre-commit` using something like `pip install pre-commit` and run `pre-commit install`.
|
||||
This should install hooks for running `black` and `clang-format` to ensure
|
||||
consistent style for C++ and python code.
|
||||
|
||||
|
||||
You can also run the formatters manually as follows:
|
||||
|
||||
```
|
||||
clang-format -i file.cpp
|
||||
```
|
||||
|
||||
```
|
||||
black file.py
|
||||
```
|
||||
|
||||
|
||||
```shell
|
||||
clang-format -i file.cpp
|
||||
```
|
||||
|
||||
```shell
|
||||
black file.py
|
||||
```
|
||||
|
||||
or run `pre-commit run --all-files` to check all files in the repo.
|
||||
|
||||
## Issues
|
||||
|
@@ -1,4 +1,6 @@
|
||||
include CMakeLists.txt
|
||||
include mlx.pc.in
|
||||
recursive-include mlx/ *
|
||||
include cmake/*
|
||||
include python/src/*
|
||||
include python/mlx/py.typed # support type hinting as in PEP-561
|
||||
|
74
benchmarks/python/gather_mm_bench.py
Normal file
74
benchmarks/python/gather_mm_bench.py
Normal file
@@ -0,0 +1,74 @@
|
||||
# Copyright © 2025 Apple Inc.
|
||||
|
||||
import mlx.core as mx
|
||||
from time_utils import time_fn
|
||||
|
||||
N = 1024
|
||||
D = 1024
|
||||
M = 1024
|
||||
E = 32
|
||||
I = 4
|
||||
|
||||
|
||||
def gather_sort(x, indices):
|
||||
N, M = indices.shape
|
||||
indices = indices.flatten()
|
||||
order = mx.argsort(indices)
|
||||
inv_order = mx.argsort(order)
|
||||
return x.flatten(0, -3)[order // M], indices[order], inv_order
|
||||
|
||||
|
||||
def scatter_unsort(x, inv_order, shape=None):
|
||||
x = x[inv_order]
|
||||
if shape is not None:
|
||||
x = mx.unflatten(x, 0, shape)
|
||||
return x
|
||||
|
||||
|
||||
def gather_mm_simulate(x, w, indices):
|
||||
x, idx, inv_order = gather_sort(x, indices)
|
||||
for i in range(2):
|
||||
y = mx.concatenate([x[i] @ w[j].T for i, j in enumerate(idx.tolist())], axis=0)
|
||||
x = y[:, None]
|
||||
x = scatter_unsort(x, inv_order, indices.shape)
|
||||
return x
|
||||
|
||||
|
||||
def time_gather_mm():
|
||||
x = mx.random.normal((N, 1, 1, D)) / 1024**0.5
|
||||
w1 = mx.random.normal((E, M, D)) / 1024**0.5
|
||||
w2 = mx.random.normal((E, D, M)) / 1024**0.5
|
||||
indices = (mx.random.uniform(shape=(N, I)) * E).astype(mx.uint32)
|
||||
sorted_indices = mx.sort(indices.flatten()).reshape(N, I)
|
||||
mx.eval(x, w1, w2, indices, sorted_indices)
|
||||
|
||||
def gather_mm(x, w1, w2, indices, sort):
|
||||
idx = indices
|
||||
inv_order = None
|
||||
if sort:
|
||||
x, idx, inv_order = gather_sort(x, indices)
|
||||
x = mx.gather_mm(x, w1.swapaxes(-1, -2), rhs_indices=idx, sorted_indices=sort)
|
||||
x = mx.gather_mm(x, w2.swapaxes(-1, -2), rhs_indices=idx, sorted_indices=sort)
|
||||
if sort:
|
||||
x = scatter_unsort(x, inv_order, indices.shape)
|
||||
return x
|
||||
|
||||
time_fn(gather_mm, x, w1, w2, indices, False)
|
||||
time_fn(gather_mm, x, w1, w2, sorted_indices, False)
|
||||
time_fn(gather_mm, x, w1, w2, indices, True)
|
||||
|
||||
x = mx.random.normal((N * I, D)) / 1024**0.5
|
||||
w1 = mx.random.normal((M, D)) / 1024**0.5
|
||||
w2 = mx.random.normal((D, M)) / 1024**0.5
|
||||
mx.eval(x, w1, w2)
|
||||
|
||||
def equivalent_matmul(x, w1, w2):
|
||||
x = x @ w1.T
|
||||
x = x @ w2.T
|
||||
return x
|
||||
|
||||
time_fn(equivalent_matmul, x, w1, w2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
time_gather_mm()
|
84
benchmarks/python/gather_qmm_bench.py
Normal file
84
benchmarks/python/gather_qmm_bench.py
Normal file
@@ -0,0 +1,84 @@
|
||||
# Copyright © 2025 Apple Inc.
|
||||
|
||||
import mlx.core as mx
|
||||
from time_utils import time_fn
|
||||
|
||||
N = 1024
|
||||
D = 1024
|
||||
M = 1024
|
||||
E = 32
|
||||
I = 4
|
||||
|
||||
|
||||
def gather_sort(x, indices):
|
||||
N, M = indices.shape
|
||||
indices = indices.flatten()
|
||||
order = mx.argsort(indices)
|
||||
inv_order = mx.argsort(order)
|
||||
return x.flatten(0, -3)[order // M], indices[order], inv_order
|
||||
|
||||
|
||||
def scatter_unsort(x, inv_order, shape=None):
|
||||
x = x[inv_order]
|
||||
if shape is not None:
|
||||
x = mx.unflatten(x, 0, shape)
|
||||
return x
|
||||
|
||||
|
||||
def gather_mm_simulate(x, w, indices):
|
||||
x, idx, inv_order = gather_sort(x, indices)
|
||||
for i in range(2):
|
||||
y = mx.concatenate(
|
||||
[
|
||||
mx.quantized_matmul(x[i], w[0][j], w[1][j], w[2][j], transpose=True)
|
||||
for i, j in enumerate(idx.tolist())
|
||||
],
|
||||
axis=0,
|
||||
)
|
||||
x = y[:, None]
|
||||
x = scatter_unsort(x, inv_order, indices.shape)
|
||||
return x
|
||||
|
||||
|
||||
def time_gather_qmm():
|
||||
x = mx.random.normal((N, 1, 1, D)) / 1024**0.5
|
||||
w1 = mx.random.normal((E, M, D)) / 1024**0.5
|
||||
w2 = mx.random.normal((E, D, M)) / 1024**0.5
|
||||
w1 = mx.quantize(w1)
|
||||
w2 = mx.quantize(w2)
|
||||
indices = (mx.random.uniform(shape=(N, I)) * E).astype(mx.uint32)
|
||||
sorted_indices = mx.sort(indices.flatten()).reshape(N, I)
|
||||
mx.eval(x, w1, w2, indices, sorted_indices)
|
||||
|
||||
def gather_mm(x, w1, w2, indices, sort):
|
||||
idx = indices
|
||||
inv_order = None
|
||||
if sort:
|
||||
x, idx, inv_order = gather_sort(x, indices)
|
||||
x = mx.gather_qmm(x, *w1, transpose=True, rhs_indices=idx, sorted_indices=sort)
|
||||
x = mx.gather_qmm(x, *w2, transpose=True, rhs_indices=idx, sorted_indices=sort)
|
||||
if sort:
|
||||
x = scatter_unsort(x, inv_order, indices.shape)
|
||||
return x
|
||||
|
||||
time_fn(gather_mm, x, w1, w2, indices, False)
|
||||
time_fn(gather_mm, x, w1, w2, sorted_indices, False)
|
||||
time_fn(gather_mm, x, w1, w2, indices, True)
|
||||
|
||||
x = mx.random.normal((N * I, D)) / 1024**0.5
|
||||
w1 = mx.random.normal((M, D)) / 1024**0.5
|
||||
w2 = mx.random.normal((D, M)) / 1024**0.5
|
||||
w1 = mx.quantize(w1)
|
||||
w2 = mx.quantize(w2)
|
||||
mx.eval(x, w1, w2)
|
||||
|
||||
def equivalent_matmul(x, w1, w2):
|
||||
x = mx.quantized_matmul(x, *w1, transpose=True)
|
||||
x = mx.quantized_matmul(x, *w2, transpose=True)
|
||||
return x
|
||||
|
||||
time_fn(equivalent_matmul, x, w1, w2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
time_gather_qmm()
|
@@ -93,9 +93,9 @@ Primitives
|
||||
^^^^^^^^^^^
|
||||
|
||||
A :class:`Primitive` is part of the computation graph of an :class:`array`. It
|
||||
defines how to create outputs arrays given a input arrays. Further, a
|
||||
defines how to create output arrays given input arrays. Further, a
|
||||
:class:`Primitive` has methods to run on the CPU or GPU and for function
|
||||
transformations such as ``vjp`` and ``jvp``. Lets go back to our example to be
|
||||
transformations such as ``vjp`` and ``jvp``. Let's go back to our example to be
|
||||
more concrete:
|
||||
|
||||
.. code-block:: C++
|
||||
@@ -128,7 +128,7 @@ more concrete:
|
||||
/** The vector-Jacobian product. */
|
||||
std::vector<array> vjp(
|
||||
const std::vector<array>& primals,
|
||||
const array& cotan,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>& outputs) override;
|
||||
|
||||
@@ -469,7 +469,7 @@ one we just defined:
|
||||
const std::vector<array>& tangents,
|
||||
const std::vector<int>& argnums) {
|
||||
// Forward mode diff that pushes along the tangents
|
||||
// The jvp transform on the primitive can built with ops
|
||||
// The jvp transform on the primitive can be built with ops
|
||||
// that are scheduled on the same stream as the primitive
|
||||
|
||||
// If argnums = {0}, we only push along x in which case the
|
||||
@@ -481,7 +481,7 @@ one we just defined:
|
||||
auto scale_arr = array(scale, tangents[0].dtype());
|
||||
return {multiply(scale_arr, tangents[0], stream())};
|
||||
}
|
||||
// If, argnums = {0, 1}, we take contributions from both
|
||||
// If argnums = {0, 1}, we take contributions from both
|
||||
// which gives us jvp = tangent_x * alpha + tangent_y * beta
|
||||
else {
|
||||
return {axpby(tangents[0], tangents[1], alpha_, beta_, stream())};
|
||||
@@ -735,7 +735,7 @@ Let's look at a simple script and its results:
|
||||
|
||||
print(f"c shape: {c.shape}")
|
||||
print(f"c dtype: {c.dtype}")
|
||||
print(f"c correct: {mx.all(c == 6.0).item()}")
|
||||
print(f"c is correct: {mx.all(c == 6.0).item()}")
|
||||
|
||||
Output:
|
||||
|
||||
@@ -743,7 +743,7 @@ Output:
|
||||
|
||||
c shape: [3, 4]
|
||||
c dtype: float32
|
||||
c correctness: True
|
||||
c is correct: True
|
||||
|
||||
Results
|
||||
^^^^^^^
|
||||
|
@@ -38,6 +38,7 @@ Array
|
||||
array.log10
|
||||
array.log1p
|
||||
array.log2
|
||||
array.logcumsumexp
|
||||
array.logsumexp
|
||||
array.max
|
||||
array.mean
|
||||
|
@@ -20,3 +20,5 @@ FFT
|
||||
irfft2
|
||||
rfftn
|
||||
irfftn
|
||||
fftshift
|
||||
ifftshift
|
||||
|
@@ -20,5 +20,6 @@ Linear Algebra
|
||||
eigh
|
||||
lu
|
||||
lu_factor
|
||||
pinv
|
||||
solve
|
||||
solve_triangular
|
||||
|
@@ -36,10 +36,12 @@ Operations
|
||||
bitwise_or
|
||||
bitwise_xor
|
||||
block_masked_mm
|
||||
broadcast_arrays
|
||||
broadcast_to
|
||||
ceil
|
||||
clip
|
||||
concatenate
|
||||
contiguous
|
||||
conj
|
||||
conjugate
|
||||
convolve
|
||||
@@ -101,6 +103,7 @@ Operations
|
||||
log10
|
||||
log1p
|
||||
logaddexp
|
||||
logcumsumexp
|
||||
logical_not
|
||||
logical_and
|
||||
logical_or
|
||||
|
@@ -18,3 +18,4 @@ Common Optimizers
|
||||
AdamW
|
||||
Adamax
|
||||
Lion
|
||||
MultiOptimizer
|
||||
|
@@ -9,6 +9,7 @@ Transforms
|
||||
:toctree: _autosummary
|
||||
|
||||
eval
|
||||
async_eval
|
||||
compile
|
||||
custom_function
|
||||
disable_compile
|
||||
|
@@ -5,6 +5,7 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/dtype_utils.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/export.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/einsum.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp
|
||||
@@ -48,5 +49,16 @@ add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/io)
|
||||
if(MLX_BUILD_METAL)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/metal)
|
||||
else()
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_metal)
|
||||
target_sources(mlx
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/no_metal.cpp)
|
||||
endif()
|
||||
|
||||
if(MLX_BUILD_CUDA)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/cuda)
|
||||
endif()
|
||||
|
||||
if(MLX_BUILD_METAL OR MLX_BUILD_CUDA)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/gpu)
|
||||
else()
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_gpu)
|
||||
endif()
|
||||
|
@@ -339,11 +339,11 @@ class array {
|
||||
return allocator::allocator().size(buffer());
|
||||
}
|
||||
|
||||
// Return a copy of the shared pointer
|
||||
// to the array::Data struct
|
||||
std::shared_ptr<Data> data_shared_ptr() const {
|
||||
// Return the shared pointer to the array::Data struct
|
||||
const std::shared_ptr<Data>& data_shared_ptr() const {
|
||||
return array_desc_->data;
|
||||
}
|
||||
|
||||
// Return a raw pointer to the arrays data
|
||||
template <typename T>
|
||||
T* data() {
|
||||
@@ -356,7 +356,7 @@ class array {
|
||||
}
|
||||
|
||||
enum Status {
|
||||
// The ouptut of a computation which has not been scheduled.
|
||||
// The output of a computation which has not been scheduled.
|
||||
// For example, the status of `x` in `auto x = a + b`.
|
||||
unscheduled,
|
||||
|
||||
|
@@ -1,8 +1,10 @@
|
||||
target_sources(
|
||||
mlx
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/broadcasting.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/common.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/transpose.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp)
|
||||
|
24
mlx/backend/common/broadcasting.cpp
Normal file
24
mlx/backend/common/broadcasting.cpp
Normal file
@@ -0,0 +1,24 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include "mlx/backend/common/utils.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void broadcast(const array& in, array& out) {
|
||||
if (out.size() == 0) {
|
||||
out.set_data(nullptr);
|
||||
return;
|
||||
}
|
||||
Strides strides(out.ndim(), 0);
|
||||
int diff = out.ndim() - in.ndim();
|
||||
for (int i = in.ndim() - 1; i >= 0; --i) {
|
||||
strides[i + diff] = (in.shape()[i] == 1) ? 0 : in.strides()[i];
|
||||
}
|
||||
auto flags = in.flags();
|
||||
if (out.size() > in.size()) {
|
||||
flags.row_contiguous = flags.col_contiguous = false;
|
||||
}
|
||||
out.copy_shared_buffer(in, strides, flags, in.data_size());
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
11
mlx/backend/common/broadcasting.h
Normal file
11
mlx/backend/common/broadcasting.h
Normal file
@@ -0,0 +1,11 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/array.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void broadcast(const array& in, array& out);
|
||||
|
||||
} // namespace mlx::core
|
@@ -1,6 +1,8 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
#include <cassert>
|
||||
|
||||
#include "mlx/backend/common/broadcasting.h"
|
||||
#include "mlx/backend/common/transpose.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
@@ -18,47 +20,23 @@ void AsStrided::eval(const std::vector<array>& inputs, array& out) {
|
||||
"AsStrided must be used with row contiguous arrays only.");
|
||||
}
|
||||
|
||||
// Compute the flags given the shape and strides
|
||||
bool row_contiguous = true, col_contiguous = true;
|
||||
size_t r = 1, c = 1;
|
||||
for (int i = strides_.size() - 1, j = 0; i >= 0; i--, j++) {
|
||||
row_contiguous &= (r == strides_[i]) || (shape_[i] == 1);
|
||||
col_contiguous &= (c == strides_[j]) || (shape_[j] == 1);
|
||||
r *= shape_[i];
|
||||
c *= shape_[j];
|
||||
}
|
||||
// Calculate the contiguity based on the given shape and strides
|
||||
auto [ds, rc, cc] = check_contiguity(shape_, strides_);
|
||||
auto flags = in.flags();
|
||||
|
||||
// TODO: Compute the contiguous flag in a better way cause now we are
|
||||
// unnecessarily strict.
|
||||
flags.contiguous = row_contiguous || col_contiguous;
|
||||
flags.row_contiguous = row_contiguous;
|
||||
flags.col_contiguous = col_contiguous;
|
||||
flags.contiguous = rc || cc;
|
||||
flags.row_contiguous = rc;
|
||||
flags.col_contiguous = cc;
|
||||
|
||||
// There is no easy way to compute the actual data size so we use out.size().
|
||||
// The contiguous flag will almost certainly not be set so no code should
|
||||
// rely on data_size anyway.
|
||||
size_t data_size = out.size();
|
||||
// There is no easy way to compute the actual data size so we use out.size()
|
||||
// when the array is not contiguous.
|
||||
size_t data_size = flags.contiguous ? ds : out.size();
|
||||
|
||||
return out.copy_shared_buffer(in, strides_, flags, data_size, offset_);
|
||||
}
|
||||
|
||||
void broadcast(const array& in, array& out) {
|
||||
if (out.size() == 0) {
|
||||
out.set_data(nullptr);
|
||||
return;
|
||||
}
|
||||
Strides strides(out.ndim(), 0);
|
||||
int diff = out.ndim() - in.ndim();
|
||||
for (int i = in.ndim() - 1; i >= 0; --i) {
|
||||
strides[i + diff] = (in.shape()[i] == 1) ? 0 : in.strides()[i];
|
||||
}
|
||||
auto flags = in.flags();
|
||||
if (out.size() > in.size()) {
|
||||
flags.row_contiguous = flags.col_contiguous = false;
|
||||
}
|
||||
out.copy_shared_buffer(in, strides, flags, in.data_size());
|
||||
}
|
||||
|
||||
void Broadcast::eval(const std::vector<array>& inputs, array& out) {
|
||||
broadcast(inputs[0], out);
|
||||
}
|
||||
@@ -286,36 +264,7 @@ void StopGradient::eval(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
void Transpose::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
Strides out_strides(out.ndim());
|
||||
auto& in = inputs[0];
|
||||
for (int ax = 0; ax < axes_.size(); ++ax) {
|
||||
out_strides[ax] = in.strides()[axes_[ax]];
|
||||
}
|
||||
|
||||
// Conditions for {row/col}_contiguous
|
||||
// - array must be contiguous (no gaps)
|
||||
// - underlying buffer size should have the same size as the array
|
||||
// - cumulative product of shapes is equal to the strides (we can ignore axes
|
||||
// with size == 1)
|
||||
// - in the forward direction (column contiguous)
|
||||
// - in the reverse direction (row contiguous)
|
||||
// - vectors are both row and col contiguous (hence if both row/col are
|
||||
// true, they stay true)
|
||||
auto flags = in.flags();
|
||||
if (flags.contiguous && in.data_size() == in.size()) {
|
||||
int64_t f_stride = 1;
|
||||
int64_t b_stride = 1;
|
||||
flags.col_contiguous = true;
|
||||
flags.row_contiguous = true;
|
||||
for (int i = 0, ri = out.ndim() - 1; i < out.ndim(); ++i, --ri) {
|
||||
flags.col_contiguous &= (out_strides[i] == f_stride || out.shape(i) == 1);
|
||||
f_stride *= out.shape(i);
|
||||
flags.row_contiguous &=
|
||||
(out_strides[ri] == b_stride || out.shape(ri) == 1);
|
||||
b_stride *= out.shape(ri);
|
||||
}
|
||||
}
|
||||
out.copy_shared_buffer(in, out_strides, flags, in.data_size());
|
||||
transpose(inputs[0], out, axes_);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -99,7 +99,11 @@ inline std::pair<int, int> decompose_hadamard(int n) {
|
||||
"[hadamard] Only supports n = m*2^k where m in (1, 12, 20, 28).");
|
||||
}
|
||||
}
|
||||
if (n > (1 << 26)) {
|
||||
throw std::invalid_argument(
|
||||
"[hadamard] Only supports n = m*2^k where k <= 26");
|
||||
}
|
||||
return {n, m};
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
} // namespace mlx::core
|
||||
|
57
mlx/backend/common/transpose.cpp
Normal file
57
mlx/backend/common/transpose.cpp
Normal file
@@ -0,0 +1,57 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include <cassert>
|
||||
|
||||
#include "mlx/backend/common/utils.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void transpose(const array& in, array& out, const std::vector<int>& axes) {
|
||||
Strides out_strides(out.ndim());
|
||||
for (int ax = 0; ax < axes.size(); ++ax) {
|
||||
out_strides[ax] = in.strides()[axes[ax]];
|
||||
}
|
||||
|
||||
// Conditions for {row/col}_contiguous
|
||||
// - array must be contiguous (no gaps)
|
||||
// - underlying buffer size should have the same size as the array
|
||||
// - cumulative product of shapes is equal to the strides (we can ignore axes
|
||||
// with size == 1)
|
||||
// - in the forward direction (column contiguous)
|
||||
// - in the reverse direction (row contiguous)
|
||||
// - vectors are both row and col contiguous (hence if both row/col are
|
||||
// true, they stay true)
|
||||
auto flags = in.flags();
|
||||
if (flags.contiguous && in.data_size() == in.size()) {
|
||||
auto [_, rc, cc] = check_contiguity(out.shape(), out_strides);
|
||||
flags.row_contiguous = rc;
|
||||
flags.col_contiguous = cc;
|
||||
}
|
||||
out.copy_shared_buffer(in, out_strides, flags, in.data_size());
|
||||
}
|
||||
|
||||
void as_transposed(array& out, const std::vector<int>& axes) {
|
||||
assert(out.data_size() == out.size() && out.flags().contiguous);
|
||||
|
||||
// Calculate the contiguous strides.
|
||||
Strides strides(out.ndim(), 1);
|
||||
for (int i = out.ndim() - 2; i >= 0; i--) {
|
||||
strides[i] = strides[i + 1] * out.shape(i);
|
||||
}
|
||||
|
||||
// Calculate the new strides for transposing.
|
||||
Strides new_strides;
|
||||
new_strides.reserve(out.ndim());
|
||||
for (auto ax : axes) {
|
||||
new_strides.push_back(strides[ax]);
|
||||
}
|
||||
|
||||
auto [ds, rc, cc] = check_contiguity(out.shape(), new_strides);
|
||||
auto flags = out.flags();
|
||||
flags.row_contiguous = rc;
|
||||
flags.col_contiguous = cc;
|
||||
|
||||
out.copy_shared_buffer(out, new_strides, flags, ds);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
12
mlx/backend/common/transpose.h
Normal file
12
mlx/backend/common/transpose.h
Normal file
@@ -0,0 +1,12 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/array.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void transpose(const array& in, array& out, const std::vector<int>& axes);
|
||||
void as_transposed(array& out, const std::vector<int>& axes);
|
||||
|
||||
} // namespace mlx::core
|
@@ -132,6 +132,11 @@ struct ContiguousIterator {
|
||||
};
|
||||
|
||||
inline auto check_contiguity(const Shape& shape, const Strides& strides) {
|
||||
// Conditions for {row/col}_contiguous
|
||||
// - cumulative product of shapes is equal to the strides (we can ignore axes
|
||||
// with size == 1)
|
||||
// - in the forward direction (column contiguous)
|
||||
// - in the reverse direction (row contiguous)
|
||||
size_t no_broadcast_data_size = 1;
|
||||
int64_t f_stride = 1;
|
||||
int64_t b_stride = 1;
|
||||
|
@@ -40,7 +40,8 @@ add_dependencies(mlx cpu_compiled_preamble)
|
||||
|
||||
target_sources(
|
||||
mlx
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cpp
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/available.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
|
||||
@@ -58,6 +59,7 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/select.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/logsumexp.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/threefry.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
|
||||
@@ -73,8 +75,8 @@ target_sources(
|
||||
if(MLX_BUILD_ACCELERATE)
|
||||
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/bnns.cpp)
|
||||
else()
|
||||
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/no_fp16.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/gemms/no_bf16.cpp)
|
||||
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/simd_fp16.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/gemms/simd_bf16.cpp)
|
||||
endif()
|
||||
|
||||
if(IOS)
|
||||
|
11
mlx/backend/cpu/available.cpp
Normal file
11
mlx/backend/cpu/available.cpp
Normal file
@@ -0,0 +1,11 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cpu/available.h"
|
||||
|
||||
namespace mlx::core::cpu {
|
||||
|
||||
bool is_available() {
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace mlx::core::cpu
|
9
mlx/backend/cpu/available.h
Normal file
9
mlx/backend/cpu/available.h
Normal file
@@ -0,0 +1,9 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace mlx::core::cpu {
|
||||
|
||||
bool is_available();
|
||||
|
||||
} // namespace mlx::core::cpu
|
@@ -172,9 +172,12 @@ void binary_float(
|
||||
case bfloat16:
|
||||
binary_op<bfloat16_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case complex64:
|
||||
binary_op<complex64_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
"[binary_float] Only supports non-complex floating point types.");
|
||||
"[binary_float] Only supports floating point types.");
|
||||
}
|
||||
});
|
||||
}
|
||||
|
@@ -40,7 +40,10 @@ struct CompilerCache {
|
||||
std::shared_mutex mtx;
|
||||
};
|
||||
|
||||
static CompilerCache cache{};
|
||||
static CompilerCache& cache() {
|
||||
static CompilerCache cache_;
|
||||
return cache_;
|
||||
};
|
||||
|
||||
// GPU compile is always available if the GPU is available and since we are in
|
||||
// this file CPU compile is also available.
|
||||
@@ -56,14 +59,16 @@ void* compile(
|
||||
const std::string& kernel_name,
|
||||
const std::function<std::string(void)>& source_builder) {
|
||||
{
|
||||
std::shared_lock lock(cache.mtx);
|
||||
if (auto it = cache.kernels.find(kernel_name); it != cache.kernels.end()) {
|
||||
std::shared_lock lock(cache().mtx);
|
||||
if (auto it = cache().kernels.find(kernel_name);
|
||||
it != cache().kernels.end()) {
|
||||
return it->second;
|
||||
}
|
||||
}
|
||||
|
||||
std::unique_lock lock(cache.mtx);
|
||||
if (auto it = cache.kernels.find(kernel_name); it != cache.kernels.end()) {
|
||||
std::unique_lock lock(cache().mtx);
|
||||
if (auto it = cache().kernels.find(kernel_name);
|
||||
it != cache().kernels.end()) {
|
||||
return it->second;
|
||||
}
|
||||
std::string source_code = source_builder();
|
||||
@@ -120,10 +125,10 @@ void* compile(
|
||||
}
|
||||
|
||||
// load library
|
||||
cache.libs.emplace_back(shared_lib_path);
|
||||
cache().libs.emplace_back(shared_lib_path);
|
||||
|
||||
// Load function
|
||||
void* fun = dlsym(cache.libs.back().lib, kernel_name.c_str());
|
||||
void* fun = dlsym(cache().libs.back().lib, kernel_name.c_str());
|
||||
if (!fun) {
|
||||
std::ostringstream msg;
|
||||
msg << "[Compile::eval_cpu] Failed to load compiled function "
|
||||
@@ -131,7 +136,7 @@ void* compile(
|
||||
<< dlerror();
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
cache.kernels.insert({kernel_name, fun});
|
||||
cache().kernels.insert({kernel_name, fun});
|
||||
return fun;
|
||||
}
|
||||
|
||||
|
@@ -46,8 +46,15 @@ void AllReduce::eval_cpu(
|
||||
case Sum:
|
||||
distributed::detail::all_sum(group(), in, outputs[0], stream());
|
||||
break;
|
||||
case Max:
|
||||
distributed::detail::all_max(group(), in, outputs[0], stream());
|
||||
break;
|
||||
case Min:
|
||||
distributed::detail::all_min(group(), in, outputs[0], stream());
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error("Only all reduce sum is supported for now");
|
||||
throw std::runtime_error(
|
||||
"Only all reduce sum, min and max are supported for now");
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -1,27 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cpu/gemm.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
template <>
|
||||
void matmul<bfloat16_t>(
|
||||
const bfloat16_t*,
|
||||
const bfloat16_t*,
|
||||
bfloat16_t*,
|
||||
bool,
|
||||
bool,
|
||||
size_t,
|
||||
size_t,
|
||||
size_t,
|
||||
float,
|
||||
float,
|
||||
size_t,
|
||||
const Shape&,
|
||||
const Strides&,
|
||||
const Shape&,
|
||||
const Strides&) {
|
||||
throw std::runtime_error("[Matmul::eval_cpu] bfloat16 not supported.");
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
@@ -1,27 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cpu/gemm.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
template <>
|
||||
void matmul<float16_t>(
|
||||
const float16_t*,
|
||||
const float16_t*,
|
||||
float16_t*,
|
||||
bool,
|
||||
bool,
|
||||
size_t,
|
||||
size_t,
|
||||
size_t,
|
||||
float,
|
||||
float,
|
||||
size_t,
|
||||
const Shape&,
|
||||
const Strides&,
|
||||
const Shape&,
|
||||
const Strides&) {
|
||||
throw std::runtime_error("[Matmul::eval_cpu] float16 not supported.");
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
45
mlx/backend/cpu/gemms/simd_bf16.cpp
Normal file
45
mlx/backend/cpu/gemms/simd_bf16.cpp
Normal file
@@ -0,0 +1,45 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/cpu/gemm.h"
|
||||
#include "mlx/backend/cpu/gemms/simd_gemm.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
template <>
|
||||
void matmul<bfloat16_t>(
|
||||
const bfloat16_t* a,
|
||||
const bfloat16_t* b,
|
||||
bfloat16_t* out,
|
||||
bool a_transposed,
|
||||
bool b_transposed,
|
||||
size_t lda,
|
||||
size_t ldb,
|
||||
size_t ldc,
|
||||
float alpha,
|
||||
float beta,
|
||||
size_t batch_size,
|
||||
const Shape& a_shape,
|
||||
const Strides& a_strides,
|
||||
const Shape& b_shape,
|
||||
const Strides& b_strides) {
|
||||
auto ndim = a_shape.size();
|
||||
size_t M = a_shape[ndim - 2];
|
||||
size_t N = b_shape[ndim - 1];
|
||||
size_t K = a_shape[ndim - 1];
|
||||
for (int i = 0; i < batch_size; ++i) {
|
||||
simd_gemm<bfloat16_t, float>(
|
||||
a + elem_to_loc(M * K * i, a_shape, a_strides),
|
||||
b + elem_to_loc(K * N * i, b_shape, b_strides),
|
||||
out + M * N * i,
|
||||
a_transposed,
|
||||
b_transposed,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
alpha,
|
||||
beta);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
45
mlx/backend/cpu/gemms/simd_fp16.cpp
Normal file
45
mlx/backend/cpu/gemms/simd_fp16.cpp
Normal file
@@ -0,0 +1,45 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/cpu/gemm.h"
|
||||
#include "mlx/backend/cpu/gemms/simd_gemm.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
template <>
|
||||
void matmul<float16_t>(
|
||||
const float16_t* a,
|
||||
const float16_t* b,
|
||||
float16_t* out,
|
||||
bool a_transposed,
|
||||
bool b_transposed,
|
||||
size_t lda,
|
||||
size_t ldb,
|
||||
size_t ldc,
|
||||
float alpha,
|
||||
float beta,
|
||||
size_t batch_size,
|
||||
const Shape& a_shape,
|
||||
const Strides& a_strides,
|
||||
const Shape& b_shape,
|
||||
const Strides& b_strides) {
|
||||
auto ndim = a_shape.size();
|
||||
size_t M = a_shape[ndim - 2];
|
||||
size_t N = b_shape[ndim - 1];
|
||||
size_t K = a_shape[ndim - 1];
|
||||
for (int i = 0; i < batch_size; ++i) {
|
||||
simd_gemm<float16_t, float>(
|
||||
a + elem_to_loc(M * K * i, a_shape, a_strides),
|
||||
b + elem_to_loc(K * N * i, b_shape, b_strides),
|
||||
out + M * N * i,
|
||||
a_transposed,
|
||||
b_transposed,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
alpha,
|
||||
beta);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
139
mlx/backend/cpu/gemms/simd_gemm.h
Normal file
139
mlx/backend/cpu/gemms/simd_gemm.h
Normal file
@@ -0,0 +1,139 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
#pragma once
|
||||
|
||||
#include "mlx/backend/cpu/simd/simd.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
inline int ceildiv(int a, int b) {
|
||||
return (a + b - 1) / b;
|
||||
}
|
||||
|
||||
template <int block_size, typename T, typename AccT>
|
||||
void load_block(
|
||||
const T* in,
|
||||
AccT* out,
|
||||
int M,
|
||||
int N,
|
||||
int i,
|
||||
int j,
|
||||
bool transpose) {
|
||||
if (transpose) {
|
||||
for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) {
|
||||
for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) {
|
||||
out[jj * block_size + ii] =
|
||||
in[(i * block_size + ii) * N + j * block_size + jj];
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) {
|
||||
for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) {
|
||||
out[ii * block_size + jj] =
|
||||
in[(i * block_size + ii) * N + j * block_size + jj];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename AccT>
|
||||
void simd_gemm(
|
||||
const T* a,
|
||||
const T* b,
|
||||
T* c,
|
||||
bool a_trans,
|
||||
bool b_trans,
|
||||
int M,
|
||||
int N,
|
||||
int K,
|
||||
float alpha,
|
||||
float beta) {
|
||||
constexpr int block_size = 16;
|
||||
constexpr int simd_size = simd::max_size<AccT>;
|
||||
static_assert(
|
||||
(block_size % simd_size) == 0,
|
||||
"Block size must be divisible by SIMD size");
|
||||
|
||||
int last_k_block_size = K - block_size * (K / block_size);
|
||||
int last_k_simd_block = (last_k_block_size / simd_size) * simd_size;
|
||||
for (int i = 0; i < ceildiv(M, block_size); i++) {
|
||||
for (int j = 0; j < ceildiv(N, block_size); j++) {
|
||||
AccT c_block[block_size * block_size] = {0.0};
|
||||
AccT a_block[block_size * block_size];
|
||||
AccT b_block[block_size * block_size];
|
||||
|
||||
int k = 0;
|
||||
for (; k < K / block_size; k++) {
|
||||
// Load a and b blocks
|
||||
if (a_trans) {
|
||||
load_block<block_size>(a, a_block, K, M, k, i, true);
|
||||
} else {
|
||||
load_block<block_size>(a, a_block, M, K, i, k, false);
|
||||
}
|
||||
if (b_trans) {
|
||||
load_block<block_size>(b, b_block, N, K, j, k, false);
|
||||
} else {
|
||||
load_block<block_size>(b, b_block, K, N, k, j, true);
|
||||
}
|
||||
|
||||
// Multiply and accumulate
|
||||
for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) {
|
||||
for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) {
|
||||
for (int kk = 0; kk < block_size; kk += simd_size) {
|
||||
auto av =
|
||||
simd::load<AccT, simd_size>(a_block + ii * block_size + kk);
|
||||
auto bv =
|
||||
simd::load<AccT, simd_size>(b_block + jj * block_size + kk);
|
||||
c_block[ii * block_size + jj] += simd::sum(av * bv);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if (last_k_block_size) {
|
||||
// Load a and b blocks
|
||||
if (a_trans) {
|
||||
load_block<block_size>(a, a_block, K, M, k, i, true);
|
||||
} else {
|
||||
load_block<block_size>(a, a_block, M, K, i, k, false);
|
||||
}
|
||||
if (b_trans) {
|
||||
load_block<block_size>(b, b_block, N, K, j, k, false);
|
||||
} else {
|
||||
load_block<block_size>(b, b_block, K, N, k, j, true);
|
||||
}
|
||||
|
||||
// Multiply and accumulate
|
||||
for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) {
|
||||
for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) {
|
||||
int kk = 0;
|
||||
for (; kk < last_k_simd_block; kk += simd_size) {
|
||||
auto av =
|
||||
simd::load<AccT, simd_size>(a_block + ii * block_size + kk);
|
||||
auto bv =
|
||||
simd::load<AccT, simd_size>(b_block + jj * block_size + kk);
|
||||
c_block[ii * block_size + jj] += simd::sum(av * bv);
|
||||
}
|
||||
for (; kk < last_k_block_size; ++kk) {
|
||||
c_block[ii * block_size + jj] +=
|
||||
a_block[ii * block_size + kk] * b_block[jj * block_size + kk];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Store
|
||||
for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) {
|
||||
for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) {
|
||||
auto c_idx = (i * block_size + ii) * N + j * block_size + jj;
|
||||
if (beta != 0) {
|
||||
c[c_idx] = static_cast<T>(
|
||||
alpha * c_block[ii * block_size + jj] + beta * c[c_idx]);
|
||||
} else {
|
||||
c[c_idx] = static_cast<T>(alpha * c_block[ii * block_size + jj]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
140
mlx/backend/cpu/logsumexp.cpp
Normal file
140
mlx/backend/cpu/logsumexp.cpp
Normal file
@@ -0,0 +1,140 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <cassert>
|
||||
#include <cmath>
|
||||
|
||||
#include "mlx/backend/cpu/copy.h"
|
||||
#include "mlx/backend/cpu/encoder.h"
|
||||
#include "mlx/backend/cpu/simd/simd.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/types/limits.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
using namespace mlx::core::simd;
|
||||
|
||||
template <typename T, typename AccT>
|
||||
void logsumexp(const array& in, array& out, Stream stream) {
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_output_array(out);
|
||||
|
||||
const T* in_ptr = in.data<T>();
|
||||
T* out_ptr = out.data<T>();
|
||||
|
||||
int M = in.shape().back();
|
||||
int L = in.data_size() / M;
|
||||
|
||||
encoder.dispatch([in_ptr, out_ptr, M, L]() mutable {
|
||||
constexpr int N = std::min(max_size<AccT>, max_size<T>);
|
||||
|
||||
const T* current_in_ptr;
|
||||
|
||||
for (int i = 0; i < L; i++, in_ptr += M, out_ptr += 1) {
|
||||
// Find the maximum
|
||||
current_in_ptr = in_ptr;
|
||||
Simd<AccT, N> vmaximum(-numeric_limits<AccT>::infinity());
|
||||
size_t s = M;
|
||||
while (s >= N) {
|
||||
Simd<AccT, N> vals = load<T, N>(current_in_ptr);
|
||||
vmaximum = maximum(vals, vmaximum);
|
||||
current_in_ptr += N;
|
||||
s -= N;
|
||||
}
|
||||
|
||||
AccT maximum = max(vmaximum);
|
||||
while (s-- > 0) {
|
||||
maximum = std::max(maximum, static_cast<AccT>(*current_in_ptr));
|
||||
current_in_ptr++;
|
||||
}
|
||||
|
||||
// Compute the normalizer and the exponentials
|
||||
Simd<AccT, N> vnormalizer(0.0);
|
||||
current_in_ptr = in_ptr;
|
||||
s = M;
|
||||
while (s >= N) {
|
||||
Simd<AccT, N> vexp = load<T, N>(current_in_ptr);
|
||||
vexp = exp(vexp - maximum);
|
||||
vnormalizer = vnormalizer + vexp;
|
||||
current_in_ptr += N;
|
||||
s -= N;
|
||||
}
|
||||
AccT normalizer = sum(vnormalizer);
|
||||
while (s-- > 0) {
|
||||
AccT _exp = std::exp(*current_in_ptr - maximum);
|
||||
normalizer += _exp;
|
||||
current_in_ptr++;
|
||||
}
|
||||
// Normalize
|
||||
*out_ptr = std::isinf(maximum)
|
||||
? static_cast<T>(maximum)
|
||||
: static_cast<T>(std::log(normalizer) + maximum);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void LogSumExp::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
|
||||
// Make sure that the last dimension is contiguous
|
||||
auto s = stream();
|
||||
auto& encoder = cpu::get_command_encoder(s);
|
||||
auto ensure_contiguous = [&s, &encoder](const array& x) {
|
||||
if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
|
||||
return x;
|
||||
} else {
|
||||
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
|
||||
copy(x, x_copy, CopyType::General, s);
|
||||
encoder.add_temporary(x_copy);
|
||||
return x_copy;
|
||||
}
|
||||
};
|
||||
|
||||
auto in = ensure_contiguous(inputs[0]);
|
||||
if (in.flags().row_contiguous) {
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
} else {
|
||||
auto n = in.shape(-1);
|
||||
auto flags = in.flags();
|
||||
auto strides = in.strides();
|
||||
for (auto& s : strides) {
|
||||
s /= n;
|
||||
}
|
||||
bool col_contig = strides[0] == 1;
|
||||
for (int i = 1; col_contig && i < strides.size(); ++i) {
|
||||
col_contig &=
|
||||
(out.shape(i) == 1 || strides[i - 1] == out.shape(i) * strides[i]);
|
||||
}
|
||||
flags.col_contiguous = col_contig;
|
||||
out.set_data(
|
||||
allocator::malloc(in.nbytes() / n),
|
||||
in.data_size() / n,
|
||||
std::move(strides),
|
||||
flags);
|
||||
}
|
||||
|
||||
switch (in.dtype()) {
|
||||
case float32:
|
||||
logsumexp<float, float>(in, out, stream());
|
||||
break;
|
||||
case float16:
|
||||
logsumexp<float16_t, float>(in, out, stream());
|
||||
break;
|
||||
case bfloat16:
|
||||
logsumexp<bfloat16_t, float>(in, out, stream());
|
||||
break;
|
||||
case float64:
|
||||
logsumexp<double, double>(in, out, stream());
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
"[logsumexp] only supports floating point types");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
@@ -3,6 +3,7 @@
|
||||
#include <cassert>
|
||||
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/cpu/binary_ops.h"
|
||||
#include "mlx/backend/cpu/copy.h"
|
||||
#include "mlx/backend/cpu/encoder.h"
|
||||
#include "mlx/backend/cpu/simd/simd.h"
|
||||
@@ -226,6 +227,16 @@ void scan_dispatch(
|
||||
scan_op<T, U>(in, out, axis, reverse, inclusive, op, init);
|
||||
break;
|
||||
}
|
||||
case Scan::LogAddExp: {
|
||||
auto op = [](U a, T b) {
|
||||
return detail::LogAddExp{}(a, static_cast<U>(b));
|
||||
};
|
||||
auto init = (issubdtype(in.dtype(), floating))
|
||||
? static_cast<U>(-std::numeric_limits<float>::infinity())
|
||||
: std::numeric_limits<U>::min();
|
||||
scan_op<T, U>(in, out, axis, reverse, inclusive, op, init);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -319,7 +330,8 @@ void Scan::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
reduce_type_, in, out, axis_, reverse_, inclusive_);
|
||||
break;
|
||||
case complex64:
|
||||
throw std::runtime_error("Scan ops do not support complex types yet");
|
||||
scan_dispatch<complex64_t, complex64_t>(
|
||||
reduce_type_, in, out, axis_, reverse_, inclusive_);
|
||||
break;
|
||||
}
|
||||
});
|
||||
|
@@ -17,7 +17,7 @@ struct ScalarT<float16_t, N> {
|
||||
#endif
|
||||
|
||||
template <>
|
||||
static constexpr int max_size<float16_t> = N;
|
||||
inline constexpr int max_size<float16_t> = N;
|
||||
|
||||
#define SIMD_FP16_DEFAULT_UNARY(op) \
|
||||
template <> \
|
||||
|
@@ -83,25 +83,25 @@ struct Simd {
|
||||
// Values chosen based on benchmarks on M3 Max
|
||||
// TODO: consider choosing these more optimally
|
||||
template <>
|
||||
static constexpr int max_size<int8_t> = 16;
|
||||
inline constexpr int max_size<int8_t> = 16;
|
||||
template <>
|
||||
static constexpr int max_size<int16_t> = 16;
|
||||
inline constexpr int max_size<int16_t> = 16;
|
||||
template <>
|
||||
static constexpr int max_size<int> = 8;
|
||||
inline constexpr int max_size<int> = 8;
|
||||
template <>
|
||||
static constexpr int max_size<int64_t> = 4;
|
||||
inline constexpr int max_size<int64_t> = 4;
|
||||
template <>
|
||||
static constexpr int max_size<uint8_t> = 16;
|
||||
inline constexpr int max_size<uint8_t> = 16;
|
||||
template <>
|
||||
static constexpr int max_size<uint16_t> = 16;
|
||||
inline constexpr int max_size<uint16_t> = 16;
|
||||
template <>
|
||||
static constexpr int max_size<uint32_t> = 8;
|
||||
inline constexpr int max_size<uint32_t> = 8;
|
||||
template <>
|
||||
static constexpr int max_size<uint64_t> = 4;
|
||||
inline constexpr int max_size<uint64_t> = 4;
|
||||
template <>
|
||||
static constexpr int max_size<float> = 8;
|
||||
inline constexpr int max_size<float> = 8;
|
||||
template <>
|
||||
static constexpr int max_size<double> = 4;
|
||||
inline constexpr int max_size<double> = 4;
|
||||
|
||||
#define SIMD_DEFAULT_UNARY(name, op) \
|
||||
template <typename T, int N> \
|
||||
|
@@ -87,14 +87,45 @@ DEFAULT_UNARY(cosh, std::cosh)
|
||||
DEFAULT_UNARY(expm1, std::expm1)
|
||||
DEFAULT_UNARY(floor, std::floor)
|
||||
DEFAULT_UNARY(log, std::log)
|
||||
DEFAULT_UNARY(log2, std::log2)
|
||||
DEFAULT_UNARY(log10, std::log10)
|
||||
DEFAULT_UNARY(log1p, std::log1p)
|
||||
DEFAULT_UNARY(sinh, std::sinh)
|
||||
DEFAULT_UNARY(sqrt, std::sqrt)
|
||||
DEFAULT_UNARY(tan, std::tan)
|
||||
DEFAULT_UNARY(tanh, std::tanh)
|
||||
|
||||
template <typename T>
|
||||
Simd<T, 1> log1p(Simd<T, 1> in) {
|
||||
if constexpr (is_complex<T>) {
|
||||
auto x = in.value.real();
|
||||
auto y = in.value.imag();
|
||||
auto zabs = std::abs(in.value);
|
||||
auto theta = std::atan2(y, x + 1);
|
||||
if (zabs < 0.5) {
|
||||
auto r = x * (2 + x) + y * y;
|
||||
if (r == 0) { // handle underflow
|
||||
return Simd<T, 1>{T{x, theta}};
|
||||
}
|
||||
return Simd<T, 1>{T{((typeof(x))(0.5)) * std::log1p(r), theta}};
|
||||
} else {
|
||||
auto z0 = std::hypot(x + 1, y);
|
||||
return Simd<T, 1>{T{std::log(z0), theta}};
|
||||
}
|
||||
} else {
|
||||
return Simd<T, 1>{std::log1p(in.value)};
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Simd<T, 1> log2(Simd<T, 1> in) {
|
||||
if constexpr (is_complex<T>) {
|
||||
auto out = std::log(in.value);
|
||||
auto scale = decltype(out.real())(M_LN2);
|
||||
return Simd<T, 1>{T{out.real() / scale, out.imag() / scale}};
|
||||
} else {
|
||||
return Simd<T, 1>{std::log2(in.value)};
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Simd<T, 1> operator~(Simd<T, 1> in) {
|
||||
return ~in.value;
|
||||
|
@@ -119,12 +119,7 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
// Make sure that the last dimension is contiguous
|
||||
auto set_output = [s = stream(), &out](const array& x) {
|
||||
bool no_copy = x.strides()[x.ndim() - 1] == 1;
|
||||
if (x.ndim() > 1) {
|
||||
auto s = x.strides()[x.ndim() - 2];
|
||||
no_copy &= (s == 0 || s == x.shape().back());
|
||||
}
|
||||
if (no_copy) {
|
||||
if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
|
||||
if (x.is_donatable()) {
|
||||
out.copy_shared_buffer(x);
|
||||
} else {
|
||||
@@ -146,18 +141,6 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
auto in = set_output(inputs[0]);
|
||||
|
||||
switch (in.dtype()) {
|
||||
case bool_:
|
||||
case uint8:
|
||||
case uint16:
|
||||
case uint32:
|
||||
case uint64:
|
||||
case int8:
|
||||
case int16:
|
||||
case int32:
|
||||
case int64:
|
||||
throw std::runtime_error(
|
||||
"Softmax is defined only for floating point types");
|
||||
break;
|
||||
case float32:
|
||||
softmax<float, float>(in, out, stream());
|
||||
break;
|
||||
@@ -178,9 +161,9 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
case float64:
|
||||
softmax<double, double>(in, out, stream());
|
||||
break;
|
||||
case complex64:
|
||||
throw std::invalid_argument(
|
||||
"[Softmax] Not yet implemented for complex64");
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
"[softmax] Only defined for floating point types.");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
@@ -1,5 +1,8 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
// Required for using M_LN2 in MSVC.
|
||||
#define _USE_MATH_DEFINES
|
||||
|
||||
#include <cassert>
|
||||
|
||||
#include "mlx/backend/cpu/unary.h"
|
||||
|
@@ -86,13 +86,14 @@ struct Sign {
|
||||
template <int N, typename T>
|
||||
Simd<T, N> operator()(Simd<T, N> x) {
|
||||
auto z = Simd<T, N>{0};
|
||||
auto o = Simd<T, N>{1};
|
||||
auto m = Simd<T, N>{-1};
|
||||
if constexpr (std::is_unsigned_v<T>) {
|
||||
return x != z;
|
||||
return simd::select(x == z, z, o);
|
||||
} else if constexpr (std::is_same_v<T, complex64_t>) {
|
||||
return simd::select(x == z, x, Simd<T, N>(x / simd::abs(x)));
|
||||
} else {
|
||||
return simd::select(
|
||||
x < z, Simd<T, N>{-1}, simd::select(x > z, Simd<T, N>{1}, z));
|
||||
return simd::select(x < z, m, simd::select(x > z, o, z));
|
||||
}
|
||||
}
|
||||
SINGLE()
|
||||
|
57
mlx/backend/cuda/CMakeLists.txt
Normal file
57
mlx/backend/cuda/CMakeLists.txt
Normal file
@@ -0,0 +1,57 @@
|
||||
# Filename rules in cuda backend:
|
||||
#
|
||||
# * Use .cu/.cuh if code contains device code, and .cpp/.h if not.
|
||||
# * Device-only kernel code should be put in kernels/ subdir.
|
||||
# * Files in kernels/ subdir should not include files outside.
|
||||
target_sources(
|
||||
mlx
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/event.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fence.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp)
|
||||
|
||||
target_compile_definitions(mlx PUBLIC MLX_USE_CUDA)
|
||||
|
||||
# Enable defining device lambda functions.
|
||||
target_compile_options(mlx
|
||||
PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--extended-lambda>")
|
||||
|
||||
# Compute capability 7 is required for synchronization between CPU/GPU with
|
||||
# managed memory. TODO: Add more architectures for potential performance gain.
|
||||
set(MLX_CUDA_ARCHITECTURES
|
||||
"75;80"
|
||||
CACHE STRING "CUDA architectures")
|
||||
message(STATUS "CUDA architectures: ${MLX_CUDA_ARCHITECTURES}")
|
||||
set_target_properties(mlx PROPERTIES CUDA_ARCHITECTURES
|
||||
"${MLX_CUDA_ARCHITECTURES}")
|
||||
|
||||
# Use fixed version of CCCL.
|
||||
FetchContent_Declare(
|
||||
cccl
|
||||
URL "https://github.com/NVIDIA/cccl/releases/download/v2.8.1/cccl-v2.8.1.zip")
|
||||
FetchContent_MakeAvailable(cccl)
|
||||
target_include_directories(mlx PRIVATE BEFORE "${cccl_SOURCE_DIR}/include")
|
||||
|
||||
# Use fixed version of NVTX.
|
||||
FetchContent_Declare(
|
||||
nvtx3
|
||||
GIT_REPOSITORY https://github.com/NVIDIA/NVTX.git
|
||||
GIT_TAG v3.1.1
|
||||
GIT_SHALLOW TRUE
|
||||
SOURCE_SUBDIR c EXCLUDE_FROM_ALL)
|
||||
FetchContent_MakeAvailable(nvtx3)
|
||||
target_link_libraries(mlx PUBLIC $<BUILD_INTERFACE:nvtx3-cpp>)
|
||||
|
||||
# Make cuda runtime APIs available in non-cuda files.
|
||||
find_package(CUDAToolkit REQUIRED)
|
||||
target_include_directories(mlx PRIVATE ${CUDAToolkit_INCLUDE_DIRS})
|
||||
|
||||
# Suppress nvcc warnings on MLX headers.
|
||||
target_compile_options(mlx PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xcudafe
|
||||
--diag_suppress=997>)
|
154
mlx/backend/cuda/allocator.cpp
Normal file
154
mlx/backend/cuda/allocator.cpp
Normal file
@@ -0,0 +1,154 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/allocator.h"
|
||||
#include "mlx/backend/cuda/utils.h"
|
||||
#include "mlx/backend/cuda/worker.h"
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include <fmt/format.h>
|
||||
|
||||
#include <cassert>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace cu {
|
||||
|
||||
CudaAllocator::CudaAllocator() {
|
||||
// TODO: Set memory limit for multi-device.
|
||||
size_t free, total;
|
||||
CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total));
|
||||
memory_limit_ = total * 0.8;
|
||||
}
|
||||
|
||||
Buffer CudaAllocator::malloc(size_t size) {
|
||||
// TODO: Check memory limit.
|
||||
auto* buf = new CudaBuffer{nullptr, size};
|
||||
cudaError_t err = cudaMallocManaged(&buf->data, size);
|
||||
if (err != cudaSuccess && err != cudaErrorMemoryAllocation) {
|
||||
throw std::runtime_error(
|
||||
fmt::format("cudaMallocManaged failed: {}.", cudaGetErrorString(err)));
|
||||
}
|
||||
std::lock_guard lock(mutex_);
|
||||
active_memory_ += size;
|
||||
peak_memory_ = std::max(active_memory_, peak_memory_);
|
||||
return Buffer{buf};
|
||||
}
|
||||
|
||||
void CudaAllocator::free(Buffer buffer) {
|
||||
auto* buf = static_cast<CudaBuffer*>(buffer.ptr());
|
||||
if (!buf) {
|
||||
return;
|
||||
}
|
||||
|
||||
// If free() is called from a unregistered thread, reschedule the call to
|
||||
// worker.
|
||||
{
|
||||
std::lock_guard lock(worker_mutex_);
|
||||
if (allowed_threads_.count(std::this_thread::get_id()) == 0) {
|
||||
if (!worker_) {
|
||||
worker_.reset(new Worker);
|
||||
}
|
||||
worker_->add_task([buffer]() { allocator().free(buffer); });
|
||||
worker_->end_batch();
|
||||
worker_->commit();
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
size_t size = buf->size;
|
||||
cudaFree(buf->data);
|
||||
delete buf;
|
||||
std::lock_guard lock(mutex_);
|
||||
active_memory_ -= size;
|
||||
}
|
||||
|
||||
size_t CudaAllocator::size(Buffer buffer) const {
|
||||
auto* buf = static_cast<CudaBuffer*>(buffer.ptr());
|
||||
if (!buf) {
|
||||
return 0;
|
||||
}
|
||||
return buf->size;
|
||||
}
|
||||
|
||||
void CudaAllocator::register_this_thread() {
|
||||
std::lock_guard lock(worker_mutex_);
|
||||
allowed_threads_.insert(std::this_thread::get_id());
|
||||
}
|
||||
|
||||
size_t CudaAllocator::get_active_memory() const {
|
||||
return active_memory_;
|
||||
}
|
||||
|
||||
size_t CudaAllocator::get_peak_memory() const {
|
||||
return peak_memory_;
|
||||
}
|
||||
|
||||
void CudaAllocator::reset_peak_memory() {
|
||||
std::lock_guard lock(mutex_);
|
||||
peak_memory_ = 0;
|
||||
}
|
||||
|
||||
size_t CudaAllocator::get_memory_limit() {
|
||||
return memory_limit_;
|
||||
}
|
||||
|
||||
size_t CudaAllocator::set_memory_limit(size_t limit) {
|
||||
std::lock_guard lock(mutex_);
|
||||
std::swap(limit, memory_limit_);
|
||||
return limit;
|
||||
}
|
||||
|
||||
CudaAllocator& allocator() {
|
||||
// By creating the |allocator_| on heap, the destructor of CudaAllocator
|
||||
// will not be called on exit and buffers in the cache will be leaked. This
|
||||
// can save some time at program exit.
|
||||
static CudaAllocator* allocator_ = new CudaAllocator;
|
||||
return *allocator_;
|
||||
}
|
||||
|
||||
} // namespace cu
|
||||
|
||||
namespace allocator {
|
||||
|
||||
Allocator& allocator() {
|
||||
return cu::allocator();
|
||||
}
|
||||
|
||||
void* Buffer::raw_ptr() {
|
||||
if (!ptr_) {
|
||||
return nullptr;
|
||||
}
|
||||
return static_cast<cu::CudaBuffer*>(ptr_)->data;
|
||||
}
|
||||
|
||||
} // namespace allocator
|
||||
|
||||
size_t get_active_memory() {
|
||||
return cu::allocator().get_active_memory();
|
||||
}
|
||||
size_t get_peak_memory() {
|
||||
return cu::allocator().get_peak_memory();
|
||||
}
|
||||
void reset_peak_memory() {
|
||||
return cu::allocator().reset_peak_memory();
|
||||
}
|
||||
size_t set_memory_limit(size_t limit) {
|
||||
return cu::allocator().set_memory_limit(limit);
|
||||
}
|
||||
size_t get_memory_limit() {
|
||||
return cu::allocator().get_memory_limit();
|
||||
}
|
||||
|
||||
// TODO: Implement buffer cache.
|
||||
size_t get_cache_memory() {
|
||||
return 0;
|
||||
}
|
||||
size_t set_cache_limit(size_t) {
|
||||
return 0;
|
||||
}
|
||||
size_t set_wired_limit(size_t) {
|
||||
return 0;
|
||||
}
|
||||
void clear_cache() {}
|
||||
|
||||
} // namespace mlx::core
|
58
mlx/backend/cuda/allocator.h
Normal file
58
mlx/backend/cuda/allocator.h
Normal file
@@ -0,0 +1,58 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
|
||||
#include <mutex>
|
||||
#include <set>
|
||||
#include <thread>
|
||||
#include <utility>
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
class Worker;
|
||||
|
||||
using allocator::Buffer;
|
||||
|
||||
// Stores cuda-managed unified memory.
|
||||
struct CudaBuffer {
|
||||
void* data;
|
||||
size_t size;
|
||||
};
|
||||
|
||||
class CudaAllocator : public allocator::Allocator {
|
||||
public:
|
||||
Buffer malloc(size_t size) override;
|
||||
void free(Buffer buffer) override;
|
||||
size_t size(Buffer buffer) const override;
|
||||
|
||||
// Register current thread as safe to free buffers.
|
||||
// In cuda freeing a buffer implicitly synchronizes stream, and for threads
|
||||
// that may be waited by gpu stream (for example cpu stream threads), freeing
|
||||
// buffers there would result in dead lock.
|
||||
void register_this_thread();
|
||||
|
||||
size_t get_active_memory() const;
|
||||
size_t get_peak_memory() const;
|
||||
void reset_peak_memory();
|
||||
size_t get_memory_limit();
|
||||
size_t set_memory_limit(size_t limit);
|
||||
|
||||
private:
|
||||
CudaAllocator();
|
||||
friend CudaAllocator& allocator();
|
||||
|
||||
std::mutex worker_mutex_;
|
||||
std::unique_ptr<Worker> worker_;
|
||||
std::set<std::thread::id> allowed_threads_;
|
||||
|
||||
std::mutex mutex_;
|
||||
size_t memory_limit_;
|
||||
size_t active_memory_{0};
|
||||
size_t peak_memory_{0};
|
||||
};
|
||||
|
||||
CudaAllocator& allocator();
|
||||
|
||||
} // namespace mlx::core::cu
|
26
mlx/backend/cuda/copy.cpp
Normal file
26
mlx/backend/cuda/copy.cpp
Normal file
@@ -0,0 +1,26 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void copy_gpu_inplace(
|
||||
const array& in,
|
||||
array& out,
|
||||
const Shape& data_shape,
|
||||
const Strides& strides_in_pre,
|
||||
const Strides& strides_out_pre,
|
||||
int64_t inp_offset,
|
||||
int64_t out_offset,
|
||||
CopyType ctype,
|
||||
const Stream& s,
|
||||
const std::optional<array>& dynamic_i_offset /* = std::nullopt */,
|
||||
const std::optional<array>& dynamic_o_offset /* = std::nullopt */) {
|
||||
throw std::runtime_error("copy_gpu_inplace not implemented in CUDA backend.");
|
||||
}
|
||||
|
||||
void fill_gpu(const array& val, array& out, const Stream& s) {
|
||||
throw std::runtime_error("fill_gpu not implemented in CUDA backend.");
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
117
mlx/backend/cuda/device.cpp
Normal file
117
mlx/backend/cuda/device.cpp
Normal file
@@ -0,0 +1,117 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/worker.h"
|
||||
#include "mlx/backend/metal/metal.h"
|
||||
|
||||
#include <fmt/format.h>
|
||||
#include <nvtx3/nvtx3.hpp>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace cu {
|
||||
|
||||
DeviceStream::DeviceStream(Device& device) : device_(device), stream_(device) {}
|
||||
|
||||
void DeviceStream::synchronize() {
|
||||
cudaStreamSynchronize(stream_);
|
||||
}
|
||||
|
||||
cudaStream_t DeviceStream::schedule_cuda_stream() {
|
||||
// TODO: Return a stream that maximizes parallelism.
|
||||
return stream_;
|
||||
}
|
||||
|
||||
cudaStream_t DeviceStream::last_cuda_stream() {
|
||||
return stream_;
|
||||
}
|
||||
|
||||
CommandEncoder& DeviceStream::get_encoder() {
|
||||
if (!encoder_) {
|
||||
encoder_ = std::make_unique<CommandEncoder>(*this);
|
||||
}
|
||||
return *encoder_;
|
||||
}
|
||||
|
||||
Device::Device(int device) : device_(device) {
|
||||
// Validate the requirements of device.
|
||||
int attr = 0;
|
||||
cudaDeviceGetAttribute(&attr, cudaDevAttrConcurrentManagedAccess, device_);
|
||||
if (attr != 1) {
|
||||
throw std::runtime_error(fmt::format(
|
||||
"Device {} does not support synchronization in managed memory.",
|
||||
device_));
|
||||
}
|
||||
}
|
||||
|
||||
void Device::make_current() {
|
||||
// We need to set/get current CUDA device very frequently, cache it to reduce
|
||||
// actual calls of CUDA APIs. This function assumes single-thread in host.
|
||||
static int current = 0;
|
||||
if (current != device_) {
|
||||
CHECK_CUDA_ERROR(cudaSetDevice(device_));
|
||||
current = device_;
|
||||
}
|
||||
}
|
||||
|
||||
DeviceStream& Device::get_stream(Stream s) {
|
||||
auto it = streams_.find(s.index);
|
||||
if (it == streams_.end()) {
|
||||
it = streams_.try_emplace(s.index, *this).first;
|
||||
}
|
||||
return it->second;
|
||||
}
|
||||
|
||||
CommandEncoder::CommandEncoder(DeviceStream& s)
|
||||
: device_(s.device()), stream_(s) {}
|
||||
|
||||
void CommandEncoder::add_completed_handler(std::function<void()> task) {
|
||||
worker_.add_task(std::move(task));
|
||||
}
|
||||
|
||||
void CommandEncoder::end_encoding() {
|
||||
if (!temporaries_.empty()) {
|
||||
add_completed_handler([temporaries = std::move(temporaries_)]() {});
|
||||
}
|
||||
|
||||
// There is no kernel running, run completion handlers immediately.
|
||||
if (!has_gpu_work_) {
|
||||
worker_.consume_in_this_thread();
|
||||
return;
|
||||
}
|
||||
has_gpu_work_ = false;
|
||||
|
||||
// Put completion handlers in a batch.
|
||||
worker_.end_batch();
|
||||
|
||||
// Signaling kernel completion is expensive, delay until enough batches.
|
||||
// TODO: This number is arbitrarily picked, profile for a better stragety.
|
||||
if (worker_.uncommited_batches() > 8) {
|
||||
commit();
|
||||
}
|
||||
}
|
||||
|
||||
void CommandEncoder::commit() {
|
||||
worker_.commit(stream_.last_cuda_stream());
|
||||
}
|
||||
|
||||
Device& device(mlx::core::Device device) {
|
||||
static std::unordered_map<int, Device> devices;
|
||||
auto it = devices.find(device.index);
|
||||
if (it == devices.end()) {
|
||||
it = devices.try_emplace(device.index, device.index).first;
|
||||
}
|
||||
return it->second;
|
||||
}
|
||||
|
||||
DeviceStream& get_stream(Stream s) {
|
||||
return device(s.device).get_stream(s);
|
||||
}
|
||||
|
||||
CommandEncoder& get_command_encoder(Stream s) {
|
||||
return get_stream(s).get_encoder();
|
||||
}
|
||||
|
||||
} // namespace cu
|
||||
|
||||
} // namespace mlx::core
|
131
mlx/backend/cuda/device.h
Normal file
131
mlx/backend/cuda/device.h
Normal file
@@ -0,0 +1,131 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/backend/cuda/worker.h"
|
||||
#include "mlx/stream.h"
|
||||
|
||||
#include <thrust/execution_policy.h>
|
||||
|
||||
#include <unordered_map>
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
class Device;
|
||||
class CommandEncoder;
|
||||
|
||||
class DeviceStream {
|
||||
public:
|
||||
explicit DeviceStream(Device& device);
|
||||
|
||||
DeviceStream(const DeviceStream&) = delete;
|
||||
DeviceStream& operator=(const DeviceStream&) = delete;
|
||||
|
||||
// Wait until kernels in the stream complete.
|
||||
void synchronize();
|
||||
|
||||
// Return a cuda stream for launching kernels.
|
||||
cudaStream_t schedule_cuda_stream();
|
||||
|
||||
// Return the last cuda stream used.
|
||||
cudaStream_t last_cuda_stream();
|
||||
|
||||
CommandEncoder& get_encoder();
|
||||
|
||||
Device& device() {
|
||||
return device_;
|
||||
}
|
||||
|
||||
private:
|
||||
Device& device_;
|
||||
CudaStream stream_;
|
||||
std::unique_ptr<CommandEncoder> encoder_;
|
||||
};
|
||||
|
||||
class Device {
|
||||
public:
|
||||
explicit Device(int device);
|
||||
|
||||
Device(const Device&) = delete;
|
||||
Device& operator=(const Device&) = delete;
|
||||
|
||||
// Make this device the current cuda device, required by some cuda calls.
|
||||
void make_current();
|
||||
|
||||
DeviceStream& get_stream(Stream s);
|
||||
|
||||
int cuda_device() const {
|
||||
return device_;
|
||||
}
|
||||
|
||||
private:
|
||||
int device_;
|
||||
std::unordered_map<int, DeviceStream> streams_;
|
||||
};
|
||||
|
||||
class CommandEncoder {
|
||||
public:
|
||||
explicit CommandEncoder(DeviceStream& stream);
|
||||
|
||||
CommandEncoder(const CommandEncoder&) = delete;
|
||||
CommandEncoder& operator=(const CommandEncoder&) = delete;
|
||||
|
||||
void set_input_array(const array& arr) {}
|
||||
void set_output_array(const array& arr) {}
|
||||
|
||||
void add_temporary(const array& arr) {
|
||||
temporaries_.push_back(arr.data_shared_ptr());
|
||||
}
|
||||
|
||||
void add_completed_handler(std::function<void()> task);
|
||||
void end_encoding();
|
||||
void commit();
|
||||
|
||||
// Schedule a cuda stream for |fun| to launch kernels, and check error
|
||||
// afterwards.
|
||||
template <typename F>
|
||||
void launch_kernel(F&& fun) {
|
||||
launch_kernel(stream_.schedule_cuda_stream(), std::forward<F>(fun));
|
||||
}
|
||||
|
||||
template <typename F>
|
||||
void launch_kernel(cudaStream_t stream, F&& fun) {
|
||||
device_.make_current();
|
||||
fun(stream);
|
||||
check_cuda_error("kernel launch", cudaGetLastError());
|
||||
has_gpu_work_ = true;
|
||||
}
|
||||
|
||||
Device& device() {
|
||||
return device_;
|
||||
}
|
||||
|
||||
DeviceStream& stream() {
|
||||
return stream_;
|
||||
}
|
||||
|
||||
bool has_gpu_work() const {
|
||||
return has_gpu_work_;
|
||||
}
|
||||
|
||||
private:
|
||||
Device& device_;
|
||||
DeviceStream& stream_;
|
||||
Worker worker_;
|
||||
bool has_gpu_work_{false};
|
||||
std::vector<std::shared_ptr<array::Data>> temporaries_;
|
||||
};
|
||||
|
||||
Device& device(mlx::core::Device device);
|
||||
DeviceStream& get_stream(Stream s);
|
||||
CommandEncoder& get_command_encoder(Stream s);
|
||||
|
||||
// Return an execution policy that does not sync for result.
|
||||
// Note that not all thrust APIs support async policy, confirm before using.
|
||||
inline auto thrust_policy(cudaStream_t stream) {
|
||||
// TODO: Connect thrust's custom allocator with mlx's allocator.
|
||||
return thrust::cuda::par_nosync.on(stream);
|
||||
}
|
||||
|
||||
} // namespace mlx::core::cu
|
35
mlx/backend/cuda/dtype_utils.cuh
Normal file
35
mlx/backend/cuda/dtype_utils.cuh
Normal file
@@ -0,0 +1,35 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cuComplex.h>
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp16.h>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
// Maps CPU types to CUDA types.
|
||||
template <typename T>
|
||||
struct CTypeToCudaType {
|
||||
using type = T;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct CTypeToCudaType<float16_t> {
|
||||
using type = __half;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct CTypeToCudaType<bfloat16_t> {
|
||||
using type = __nv_bfloat16;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct CTypeToCudaType<complex64_t> {
|
||||
using type = cuComplex;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
using cuda_type_t = typename CTypeToCudaType<T>::type;
|
||||
|
||||
} // namespace mlx::core
|
68
mlx/backend/cuda/eval.cpp
Normal file
68
mlx/backend/cuda/eval.cpp
Normal file
@@ -0,0 +1,68 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/gpu/eval.h"
|
||||
#include "mlx/backend/cuda/allocator.h"
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/gpu/available.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
#include <nvtx3/nvtx3.hpp>
|
||||
|
||||
namespace mlx::core::gpu {
|
||||
|
||||
bool is_available() {
|
||||
return true;
|
||||
}
|
||||
|
||||
void new_stream(Stream s) {
|
||||
// Force initalization of cuda, so cuda runtime get destroyed at last.
|
||||
cudaFree(nullptr);
|
||||
// Ensure the static stream objects get created.
|
||||
cu::get_command_encoder(s);
|
||||
// The main thread is safe to free buffers.
|
||||
cu::allocator().register_this_thread();
|
||||
}
|
||||
|
||||
void eval(array& arr) {
|
||||
nvtx3::scoped_range r("gpu::eval");
|
||||
auto outputs = arr.outputs();
|
||||
{
|
||||
// If the array is a tracer hold a reference
|
||||
// to its inputs so they don't get donated
|
||||
std::vector<array> inputs;
|
||||
if (arr.is_tracer()) {
|
||||
inputs = arr.inputs();
|
||||
}
|
||||
arr.primitive().eval_gpu(arr.inputs(), outputs);
|
||||
}
|
||||
|
||||
auto& encoder = cu::get_command_encoder(arr.primitive().stream());
|
||||
if (encoder.has_gpu_work()) {
|
||||
// Keep used buffers alive until kernel finishes running.
|
||||
std::unordered_set<std::shared_ptr<array::Data>> buffers;
|
||||
for (auto& in : arr.inputs()) {
|
||||
buffers.insert(in.data_shared_ptr());
|
||||
}
|
||||
for (auto& s : arr.siblings()) {
|
||||
buffers.insert(s.data_shared_ptr());
|
||||
}
|
||||
// Remove the output if it was donated to by an input.
|
||||
if (auto it = buffers.find(arr.data_shared_ptr()); it != buffers.end()) {
|
||||
buffers.erase(it);
|
||||
}
|
||||
encoder.add_completed_handler([buffers = std::move(buffers)]() {});
|
||||
}
|
||||
encoder.end_encoding();
|
||||
}
|
||||
|
||||
void finalize(Stream s) {
|
||||
nvtx3::scoped_range r("gpu::finalize");
|
||||
cu::get_command_encoder(s).commit();
|
||||
}
|
||||
|
||||
void synchronize(Stream s) {
|
||||
nvtx3::scoped_range r("gpu::synchronize");
|
||||
cu::get_stream(s).synchronize();
|
||||
}
|
||||
|
||||
} // namespace mlx::core::gpu
|
265
mlx/backend/cuda/event.cu
Normal file
265
mlx/backend/cuda/event.cu
Normal file
@@ -0,0 +1,265 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/event.h"
|
||||
#include "mlx/backend/cuda/utils.h"
|
||||
#include "mlx/event.h"
|
||||
#include "mlx/scheduler.h"
|
||||
|
||||
#include <nvtx3/nvtx3.hpp>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace cu {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// CudaEvent implementations
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Cuda event managed with RAII.
|
||||
class CudaEventHandle {
|
||||
public:
|
||||
CudaEventHandle() {
|
||||
CHECK_CUDA_ERROR(cudaEventCreateWithFlags(
|
||||
&event_, cudaEventDisableTiming | cudaEventBlockingSync));
|
||||
}
|
||||
|
||||
~CudaEventHandle() {
|
||||
CHECK_CUDA_ERROR(cudaEventDestroy(event_));
|
||||
}
|
||||
|
||||
CudaEventHandle(const CudaEventHandle&) = delete;
|
||||
CudaEventHandle& operator=(const CudaEventHandle&) = delete;
|
||||
|
||||
operator cudaEvent_t() const {
|
||||
return event_;
|
||||
}
|
||||
|
||||
private:
|
||||
cudaEvent_t event_;
|
||||
};
|
||||
|
||||
CudaEvent::CudaEvent() : event_(std::make_shared<CudaEventHandle>()) {}
|
||||
|
||||
void CudaEvent::wait() {
|
||||
nvtx3::scoped_range r("cu::CudaEvent::wait");
|
||||
if (!recorded_) {
|
||||
throw std::runtime_error("Should not wait on a CudaEvent before record.");
|
||||
}
|
||||
cudaEventSynchronize(*event_);
|
||||
}
|
||||
|
||||
void CudaEvent::wait(cudaStream_t stream) {
|
||||
if (!recorded_) {
|
||||
throw std::runtime_error("Should not wait on a CudaEvent before record.");
|
||||
}
|
||||
cudaStreamWaitEvent(stream, *event_);
|
||||
}
|
||||
|
||||
void CudaEvent::wait(Stream s) {
|
||||
if (s.device == mlx::core::Device::cpu) {
|
||||
scheduler::enqueue(s, [*this]() mutable { wait(); });
|
||||
} else {
|
||||
wait(cu::get_stream(s).last_cuda_stream());
|
||||
}
|
||||
}
|
||||
|
||||
void CudaEvent::record(cudaStream_t stream) {
|
||||
cudaEventRecord(*event_, stream);
|
||||
recorded_ = true;
|
||||
}
|
||||
|
||||
void CudaEvent::record(Stream s) {
|
||||
if (s.device == mlx::core::Device::cpu) {
|
||||
throw std::runtime_error("CudaEvent can not wait on cpu stream.");
|
||||
} else {
|
||||
record(cu::get_stream(s).last_cuda_stream());
|
||||
}
|
||||
}
|
||||
|
||||
bool CudaEvent::completed() const {
|
||||
return cudaEventQuery(*event_) == cudaSuccess;
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// SharedEvent implementations
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace {
|
||||
|
||||
__host__ __device__ void event_wait(SharedEvent::Atomic* ac, uint64_t value) {
|
||||
uint64_t current;
|
||||
while ((current = ac->load()) < value) {
|
||||
ac->wait(current);
|
||||
}
|
||||
}
|
||||
|
||||
__host__ __device__ void event_signal(SharedEvent::Atomic* ac, uint64_t value) {
|
||||
ac->store(value);
|
||||
ac->notify_all();
|
||||
}
|
||||
|
||||
__global__ void event_wait_kernel(SharedEvent::Atomic* ac, uint64_t value) {
|
||||
event_wait(ac, value);
|
||||
}
|
||||
|
||||
__global__ void event_signal_kernel(SharedEvent::Atomic* ac, uint64_t value) {
|
||||
event_signal(ac, value);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
SharedEvent::SharedEvent() {
|
||||
// Allocate cuda::atomic on managed memory.
|
||||
allocator::Buffer buffer = allocator::malloc(sizeof(Atomic));
|
||||
Atomic* ac = static_cast<Atomic*>(buffer.raw_ptr());
|
||||
new (ac) Atomic(0);
|
||||
ac_ = std::shared_ptr<Atomic>(ac, [buffer](Atomic* ptr) {
|
||||
ptr->~Atomic();
|
||||
allocator::free(buffer);
|
||||
});
|
||||
}
|
||||
|
||||
void SharedEvent::wait(uint64_t value) {
|
||||
nvtx3::scoped_range r("cu::SharedEvent::wait");
|
||||
event_wait(ac_.get(), value);
|
||||
}
|
||||
|
||||
void SharedEvent::wait(cudaStream_t stream, uint64_t value) {
|
||||
event_wait_kernel<<<1, 1, 0, stream>>>(ac_.get(), value);
|
||||
}
|
||||
|
||||
void SharedEvent::wait(Stream s, uint64_t value) {
|
||||
nvtx3::scoped_range r("cu::SharedEvent::wait(s)");
|
||||
if (s.device == mlx::core::Device::cpu) {
|
||||
scheduler::enqueue(s, [*this, value]() mutable { wait(value); });
|
||||
} else {
|
||||
auto& encoder = get_command_encoder(s);
|
||||
encoder.launch_kernel(
|
||||
encoder.stream().last_cuda_stream(),
|
||||
[this, value](cudaStream_t stream) { wait(stream, value); });
|
||||
encoder.add_completed_handler([ac = ac_]() {});
|
||||
encoder.end_encoding();
|
||||
}
|
||||
}
|
||||
|
||||
void SharedEvent::signal(uint64_t value) {
|
||||
nvtx3::scoped_range r("cu::SharedEvent::signal");
|
||||
event_signal(ac_.get(), value);
|
||||
}
|
||||
|
||||
void SharedEvent::signal(cudaStream_t stream, uint64_t value) {
|
||||
event_signal_kernel<<<1, 1, 0, stream>>>(ac_.get(), value);
|
||||
}
|
||||
|
||||
void SharedEvent::signal(Stream s, uint64_t value) {
|
||||
nvtx3::scoped_range r("cu::SharedEvent::signal(s)");
|
||||
if (s.device == mlx::core::Device::cpu) {
|
||||
scheduler::enqueue(s, [*this, value]() mutable { signal(value); });
|
||||
} else {
|
||||
auto& encoder = get_command_encoder(s);
|
||||
encoder.launch_kernel(
|
||||
encoder.stream().last_cuda_stream(),
|
||||
[this, value](cudaStream_t stream) { signal(stream, value); });
|
||||
encoder.add_completed_handler([ac = ac_]() {});
|
||||
encoder.end_encoding();
|
||||
}
|
||||
}
|
||||
|
||||
bool SharedEvent::is_signaled(uint64_t value) const {
|
||||
nvtx3::scoped_range r("cu::SharedEvent::is_signaled");
|
||||
return ac_->load() >= value;
|
||||
}
|
||||
|
||||
uint64_t SharedEvent::value() const {
|
||||
nvtx3::scoped_range r("cu::SharedEvent::value");
|
||||
return ac_->load();
|
||||
}
|
||||
|
||||
} // namespace cu
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Event implementations
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace {
|
||||
|
||||
struct EventImpl {
|
||||
// CudaEvent is preferred when possible because it is fast, however we have
|
||||
// to fallback to SharedEvent in following cases:
|
||||
// 1. the event is used to wait/signal a cpu stream;
|
||||
// 2. signal value other than 1 has been specified.
|
||||
std::unique_ptr<cu::CudaEvent> cuda;
|
||||
std::unique_ptr<cu::SharedEvent> shared;
|
||||
|
||||
bool is_created() const {
|
||||
return cuda || shared;
|
||||
}
|
||||
|
||||
void ensure_created(Stream s, uint64_t signal_value) {
|
||||
if (is_created()) {
|
||||
return;
|
||||
}
|
||||
if (s.device == mlx::core::Device::cpu || signal_value > 1) {
|
||||
nvtx3::mark("Using slow SharedEvent");
|
||||
shared = std::make_unique<cu::SharedEvent>();
|
||||
} else {
|
||||
cuda = std::make_unique<cu::CudaEvent>();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
Event::Event(Stream s) : stream_(s) {
|
||||
event_ = std::shared_ptr<void>(
|
||||
new EventImpl(), [](void* ptr) { delete static_cast<EventImpl*>(ptr); });
|
||||
}
|
||||
|
||||
void Event::wait() {
|
||||
auto* event = static_cast<EventImpl*>(event_.get());
|
||||
assert(event->is_created());
|
||||
if (event->cuda) {
|
||||
assert(value() == 1);
|
||||
event->cuda->wait();
|
||||
} else {
|
||||
event->shared->wait(value());
|
||||
}
|
||||
}
|
||||
|
||||
void Event::wait(Stream s) {
|
||||
auto* event = static_cast<EventImpl*>(event_.get());
|
||||
assert(event->is_created());
|
||||
if (event->cuda) {
|
||||
assert(value() == 1);
|
||||
event->cuda->wait(s);
|
||||
} else {
|
||||
event->shared->wait(s, value());
|
||||
}
|
||||
}
|
||||
|
||||
void Event::signal(Stream s) {
|
||||
auto* event = static_cast<EventImpl*>(event_.get());
|
||||
event->ensure_created(s, value());
|
||||
if (event->cuda) {
|
||||
assert(value() == 1);
|
||||
event->cuda->record(s);
|
||||
} else {
|
||||
event->shared->signal(s, value());
|
||||
}
|
||||
}
|
||||
|
||||
bool Event::is_signaled() const {
|
||||
auto* event = static_cast<EventImpl*>(event_.get());
|
||||
if (!event->is_created()) {
|
||||
return false;
|
||||
}
|
||||
if (event->cuda) {
|
||||
assert(value() == 1);
|
||||
return event->cuda->recorded() && event->cuda->completed();
|
||||
} else {
|
||||
return event->shared->is_signaled(value());
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
66
mlx/backend/cuda/event.h
Normal file
66
mlx/backend/cuda/event.h
Normal file
@@ -0,0 +1,66 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/stream.h"
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include <cuda/atomic>
|
||||
|
||||
#include <memory>
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
class CudaEventHandle;
|
||||
|
||||
// Wrapper of native cuda event. It can synchronize between GPU streams, or wait
|
||||
// on GPU stream in CPU stream, but can not wait on CPU stream.
|
||||
class CudaEvent {
|
||||
public:
|
||||
CudaEvent();
|
||||
|
||||
void wait();
|
||||
void wait(cudaStream_t stream);
|
||||
void wait(Stream s);
|
||||
void record(cudaStream_t stream);
|
||||
void record(Stream s);
|
||||
|
||||
// Return whether the recorded kernels have completed. Note that this method
|
||||
// returns true if record() has not been called.
|
||||
bool completed() const;
|
||||
|
||||
bool recorded() const {
|
||||
return recorded_;
|
||||
}
|
||||
|
||||
private:
|
||||
bool recorded_{false};
|
||||
std::shared_ptr<CudaEventHandle> event_;
|
||||
};
|
||||
|
||||
// Event that can synchronize between CPU and GPU. It is much slower than
|
||||
// CudaEvent so the latter should always be preferred when possible.
|
||||
class SharedEvent {
|
||||
public:
|
||||
using Atomic = cuda::atomic<uint64_t>;
|
||||
|
||||
SharedEvent();
|
||||
|
||||
void wait(uint64_t value);
|
||||
void wait(cudaStream_t stream, uint64_t value);
|
||||
void wait(Stream s, uint64_t value);
|
||||
void signal(uint64_t value);
|
||||
void signal(cudaStream_t stream, uint64_t value);
|
||||
void signal(Stream s, uint64_t value);
|
||||
bool is_signaled(uint64_t value) const;
|
||||
uint64_t value() const;
|
||||
|
||||
const std::shared_ptr<Atomic>& atomic() const {
|
||||
return ac_;
|
||||
}
|
||||
|
||||
private:
|
||||
std::shared_ptr<Atomic> ac_;
|
||||
};
|
||||
|
||||
} // namespace mlx::core::cu
|
70
mlx/backend/cuda/fence.cu
Normal file
70
mlx/backend/cuda/fence.cu
Normal file
@@ -0,0 +1,70 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/event.h"
|
||||
#include "mlx/fence.h"
|
||||
#include "mlx/scheduler.h"
|
||||
|
||||
#include <nvtx3/nvtx3.hpp>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
__host__ __device__ void busy_wait(cuda::atomic<uint64_t>* ac, uint64_t value) {
|
||||
while (true) {
|
||||
// In theory the atomic_thread_fence is not needed, but for CUDA 11 without
|
||||
// it the load() may never return new value.
|
||||
cuda::atomic_thread_fence(cuda::memory_order_seq_cst);
|
||||
uint64_t current = ac->load();
|
||||
if (current >= value) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void busy_wait_kernel(cuda::atomic<uint64_t>* ac, uint64_t value) {
|
||||
busy_wait(ac, value);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
struct FenceImpl {
|
||||
uint32_t count;
|
||||
cu::SharedEvent event;
|
||||
};
|
||||
|
||||
Fence::Fence(Stream s) {
|
||||
fence_ = std::shared_ptr<void>(
|
||||
new FenceImpl{0}, [](void* ptr) { delete static_cast<FenceImpl*>(ptr); });
|
||||
}
|
||||
|
||||
void Fence::wait(Stream s, const array&) {
|
||||
auto* fence = static_cast<FenceImpl*>(fence_.get());
|
||||
// We can't use SharedEvent::wait because it could hang in CUDA 11, see also:
|
||||
// https://github.com/ml-explore/mlx/issues/2137
|
||||
const auto& ac = fence->event.atomic();
|
||||
if (s.device == mlx::core::Device::cpu) {
|
||||
scheduler::enqueue(s, [ac, count = fence->count]() {
|
||||
nvtx3::scoped_range r("Fence::wait()");
|
||||
busy_wait(ac.get(), count);
|
||||
});
|
||||
} else {
|
||||
nvtx3::scoped_range r("Fence::wait(s)");
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
encoder.launch_kernel(
|
||||
encoder.stream().last_cuda_stream(), [&](cudaStream_t stream) {
|
||||
busy_wait_kernel<<<1, 1, 0>>>(ac.get(), fence->count);
|
||||
});
|
||||
encoder.add_completed_handler([ac]() {});
|
||||
encoder.end_encoding();
|
||||
}
|
||||
}
|
||||
|
||||
void Fence::update(Stream s, const array&) {
|
||||
auto* fence = static_cast<FenceImpl*>(fence_.get());
|
||||
fence->count++;
|
||||
fence->event.signal(s, fence->count);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
15
mlx/backend/cuda/kernels/arange.cuh
Normal file
15
mlx/backend/cuda/kernels/arange.cuh
Normal file
@@ -0,0 +1,15 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
template <typename T>
|
||||
struct Arange {
|
||||
const T start;
|
||||
const T step;
|
||||
|
||||
__device__ T operator()(uint32_t i) const {
|
||||
return start + i * step;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace mlx::core::cu
|
107
mlx/backend/cuda/kernels/fp16_math.cuh
Normal file
107
mlx/backend/cuda/kernels/fp16_math.cuh
Normal file
@@ -0,0 +1,107 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda/std/limits>
|
||||
#include <cuda/std/type_traits>
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Missing C++ operator overrides for CUDA 7.
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#if CUDART_VERSION < 12000 && __CUDA_ARCH__ < 800
|
||||
|
||||
#define MLX_DEFINE_BF16_OP(OP) \
|
||||
__forceinline__ __device__ __nv_bfloat16 operator OP( \
|
||||
__nv_bfloat16 x, __nv_bfloat16 y) { \
|
||||
return __float2bfloat16(__bfloat162float(x) OP __bfloat162float(y)); \
|
||||
}
|
||||
|
||||
#define MLX_DEFINE_BF16_CMP(OP) \
|
||||
__forceinline__ __device__ bool operator OP( \
|
||||
__nv_bfloat16 x, __nv_bfloat16 y) { \
|
||||
return __float2bfloat16(__bfloat162float(x) OP __bfloat162float(y)); \
|
||||
}
|
||||
|
||||
MLX_DEFINE_BF16_OP(+)
|
||||
MLX_DEFINE_BF16_OP(-)
|
||||
MLX_DEFINE_BF16_OP(*)
|
||||
MLX_DEFINE_BF16_OP(/)
|
||||
MLX_DEFINE_BF16_CMP(>)
|
||||
MLX_DEFINE_BF16_CMP(<)
|
||||
MLX_DEFINE_BF16_CMP(>=)
|
||||
MLX_DEFINE_BF16_CMP(<=)
|
||||
|
||||
#undef MLX_DEFINE_BF16_OP
|
||||
#undef MLX_DEFINE_BF16_CMP
|
||||
|
||||
#endif // CUDART_VERSION < 12000 && __CUDA_ARCH__ < 800
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Additional C++ operator overrides between half types and native types.
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename T, typename U>
|
||||
constexpr bool is_integral_except =
|
||||
cuda::std::is_integral_v<T> && !cuda::std::is_same_v<T, U>;
|
||||
|
||||
template <typename T, typename U>
|
||||
constexpr bool is_arithmetic_except =
|
||||
cuda::std::is_arithmetic_v<T> && !cuda::std::is_same_v<T, U>;
|
||||
|
||||
#define MLX_DEFINE_HALF_OP(HALF, HALF2FLOAT, FLOAT2HALF, OP) \
|
||||
template < \
|
||||
typename T, \
|
||||
typename = cuda::std::enable_if_t<is_integral_except<T, HALF>>> \
|
||||
__forceinline__ __device__ HALF operator OP(HALF x, T y) { \
|
||||
return FLOAT2HALF(HALF2FLOAT(x) OP static_cast<float>(y)); \
|
||||
} \
|
||||
template < \
|
||||
typename T, \
|
||||
typename = cuda::std::enable_if_t<is_integral_except<T, HALF>>> \
|
||||
__forceinline__ __device__ HALF operator OP(T x, HALF y) { \
|
||||
return FLOAT2HALF(static_cast<float>(x) OP HALF2FLOAT(y)); \
|
||||
}
|
||||
|
||||
#define MLX_DEFINE_HALF_CMP(HALF, HALF2FLOAT, OP) \
|
||||
template < \
|
||||
typename T, \
|
||||
typename = cuda::std::enable_if_t<is_arithmetic_except<T, HALF>>> \
|
||||
__forceinline__ __device__ bool operator OP(HALF x, T y) { \
|
||||
return HALF2FLOAT(x) OP static_cast<float>(y); \
|
||||
} \
|
||||
template < \
|
||||
typename T, \
|
||||
typename = cuda::std::enable_if_t<is_arithmetic_except<T, HALF>>> \
|
||||
__forceinline__ __device__ bool operator OP(T x, HALF y) { \
|
||||
return static_cast<float>(y) OP HALF2FLOAT(x); \
|
||||
}
|
||||
|
||||
MLX_DEFINE_HALF_OP(__half, __half2float, __float2half, +)
|
||||
MLX_DEFINE_HALF_OP(__half, __half2float, __float2half, -)
|
||||
MLX_DEFINE_HALF_OP(__half, __half2float, __float2half, *)
|
||||
MLX_DEFINE_HALF_OP(__half, __half2float, __float2half, /)
|
||||
MLX_DEFINE_HALF_OP(__nv_bfloat16, __bfloat162float, __float2bfloat16, +)
|
||||
MLX_DEFINE_HALF_OP(__nv_bfloat16, __bfloat162float, __float2bfloat16, -)
|
||||
MLX_DEFINE_HALF_OP(__nv_bfloat16, __bfloat162float, __float2bfloat16, *)
|
||||
MLX_DEFINE_HALF_OP(__nv_bfloat16, __bfloat162float, __float2bfloat16, /)
|
||||
MLX_DEFINE_HALF_CMP(__half, __half2float, <)
|
||||
MLX_DEFINE_HALF_CMP(__half, __half2float, >)
|
||||
MLX_DEFINE_HALF_CMP(__half, __half2float, <=)
|
||||
MLX_DEFINE_HALF_CMP(__half, __half2float, >=)
|
||||
MLX_DEFINE_HALF_CMP(__half, __half2float, ==)
|
||||
MLX_DEFINE_HALF_CMP(__half, __half2float, !=)
|
||||
MLX_DEFINE_HALF_CMP(__nv_bfloat16, __bfloat162float, <)
|
||||
MLX_DEFINE_HALF_CMP(__nv_bfloat16, __bfloat162float, >)
|
||||
MLX_DEFINE_HALF_CMP(__nv_bfloat16, __bfloat162float, <=)
|
||||
MLX_DEFINE_HALF_CMP(__nv_bfloat16, __bfloat162float, >=)
|
||||
MLX_DEFINE_HALF_CMP(__nv_bfloat16, __bfloat162float, ==)
|
||||
MLX_DEFINE_HALF_CMP(__nv_bfloat16, __bfloat162float, !=)
|
||||
|
||||
#undef MLX_DEFINE_HALF_OP
|
||||
#undef MLX_DEFINE_HALF_CMP
|
||||
|
||||
} // namespace mlx::core::cu
|
163
mlx/backend/cuda/primitives.cu
Normal file
163
mlx/backend/cuda/primitives.cu
Normal file
@@ -0,0 +1,163 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/dtype_utils.cuh"
|
||||
#include "mlx/backend/cuda/kernels/arange.cuh"
|
||||
#include "mlx/backend/cuda/kernels/fp16_math.cuh"
|
||||
#include "mlx/distributed/primitives.h"
|
||||
#include "mlx/dtype_utils.h"
|
||||
#include "mlx/fast_primitives.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
#include <nvtx3/nvtx3.hpp>
|
||||
#include <thrust/device_ptr.h>
|
||||
#include <thrust/transform.h>
|
||||
|
||||
#include <cassert>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
nvtx3::scoped_range r("Arange::eval_gpu");
|
||||
assert(inputs.size() == 0);
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
if (out.size() == 0) {
|
||||
return;
|
||||
}
|
||||
auto& s = stream();
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
encoder.set_output_array(out);
|
||||
encoder.launch_kernel([&, this](cudaStream_t stream) {
|
||||
MLX_SWITCH_INT_FLOAT_TYPES_CHECKED(out.dtype(), "Arange", CTYPE, {
|
||||
using OutType = cuda_type_t<CTYPE>;
|
||||
CTYPE step =
|
||||
static_cast<CTYPE>(start_ + step_) - static_cast<CTYPE>(start_);
|
||||
thrust::transform(
|
||||
cu::thrust_policy(stream),
|
||||
thrust::counting_iterator<uint32_t>(0),
|
||||
thrust::counting_iterator<uint32_t>(out.data_size()),
|
||||
thrust::device_pointer_cast(out.data<OutType>()),
|
||||
cu::Arange<OutType>{
|
||||
static_cast<OutType>(start_), static_cast<OutType>(step)});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
#define NO_GPU_MULTI(func) \
|
||||
void func::eval_gpu( \
|
||||
const std::vector<array>& inputs, std::vector<array>& outputs) { \
|
||||
throw std::runtime_error(#func " has no CUDA implementation."); \
|
||||
}
|
||||
|
||||
#define NO_GPU(func) \
|
||||
void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
|
||||
throw std::runtime_error(#func " has no CUDA implementation."); \
|
||||
}
|
||||
|
||||
NO_GPU(Abs)
|
||||
NO_GPU(Add)
|
||||
NO_GPU(AddMM)
|
||||
NO_GPU(ArcCos)
|
||||
NO_GPU(ArcCosh)
|
||||
NO_GPU(ArcSin)
|
||||
NO_GPU(ArcSinh)
|
||||
NO_GPU(ArcTan)
|
||||
NO_GPU(ArcTan2)
|
||||
NO_GPU(ArcTanh)
|
||||
NO_GPU(ArgPartition)
|
||||
NO_GPU(ArgReduce)
|
||||
NO_GPU(ArgSort)
|
||||
NO_GPU(BitwiseBinary)
|
||||
NO_GPU(BitwiseInvert)
|
||||
NO_GPU(BlockMaskedMM)
|
||||
NO_GPU(Ceil)
|
||||
NO_GPU_MULTI(Compiled)
|
||||
NO_GPU(Conjugate)
|
||||
NO_GPU(Convolution)
|
||||
NO_GPU(Cos)
|
||||
NO_GPU(Cosh)
|
||||
NO_GPU(Divide)
|
||||
NO_GPU_MULTI(DivMod)
|
||||
NO_GPU(DynamicSlice)
|
||||
NO_GPU(DynamicSliceUpdate)
|
||||
NO_GPU(Remainder)
|
||||
NO_GPU(Equal)
|
||||
NO_GPU(Erf)
|
||||
NO_GPU(ErfInv)
|
||||
NO_GPU(Exp)
|
||||
NO_GPU(Expm1)
|
||||
NO_GPU(FFT)
|
||||
NO_GPU(Floor)
|
||||
NO_GPU(Gather)
|
||||
NO_GPU(GatherAxis)
|
||||
NO_GPU(GatherMM)
|
||||
NO_GPU(GatherQMM)
|
||||
NO_GPU(Greater)
|
||||
NO_GPU(GreaterEqual)
|
||||
NO_GPU(Hadamard)
|
||||
NO_GPU(Imag)
|
||||
NO_GPU(Less)
|
||||
NO_GPU(LessEqual)
|
||||
NO_GPU(Load)
|
||||
NO_GPU(Log)
|
||||
NO_GPU(Log1p)
|
||||
NO_GPU(LogicalNot)
|
||||
NO_GPU(LogicalAnd)
|
||||
NO_GPU(LogicalOr)
|
||||
NO_GPU(LogAddExp)
|
||||
NO_GPU(LogSumExp)
|
||||
NO_GPU_MULTI(LUF)
|
||||
NO_GPU(Matmul)
|
||||
NO_GPU(Maximum)
|
||||
NO_GPU(Minimum)
|
||||
NO_GPU(Multiply)
|
||||
NO_GPU(Negative)
|
||||
NO_GPU(NotEqual)
|
||||
NO_GPU(Partition)
|
||||
NO_GPU(Power)
|
||||
NO_GPU_MULTI(QRF)
|
||||
NO_GPU(QuantizedMatmul)
|
||||
NO_GPU(RandomBits)
|
||||
NO_GPU(Real)
|
||||
NO_GPU(Reduce)
|
||||
NO_GPU(Round)
|
||||
NO_GPU(Scan)
|
||||
NO_GPU(Scatter)
|
||||
NO_GPU(ScatterAxis)
|
||||
NO_GPU(Select)
|
||||
NO_GPU(Sigmoid)
|
||||
NO_GPU(Sign)
|
||||
NO_GPU(Sin)
|
||||
NO_GPU(Sinh)
|
||||
NO_GPU(SliceUpdate)
|
||||
NO_GPU(Softmax)
|
||||
NO_GPU(Sort)
|
||||
NO_GPU(Square)
|
||||
NO_GPU(Sqrt)
|
||||
NO_GPU(Subtract)
|
||||
NO_GPU_MULTI(SVD)
|
||||
NO_GPU(Tan)
|
||||
NO_GPU(Tanh)
|
||||
NO_GPU(Inverse)
|
||||
NO_GPU(Cholesky)
|
||||
NO_GPU_MULTI(Eigh)
|
||||
|
||||
namespace fast {
|
||||
NO_GPU_MULTI(LayerNorm)
|
||||
NO_GPU_MULTI(LayerNormVJP)
|
||||
NO_GPU_MULTI(RMSNorm)
|
||||
NO_GPU_MULTI(RMSNormVJP)
|
||||
NO_GPU_MULTI(RoPE)
|
||||
NO_GPU(ScaledDotProductAttention)
|
||||
NO_GPU_MULTI(AffineQuantize)
|
||||
NO_GPU_MULTI(CustomKernel)
|
||||
} // namespace fast
|
||||
|
||||
namespace distributed {
|
||||
NO_GPU_MULTI(AllReduce)
|
||||
NO_GPU_MULTI(AllGather)
|
||||
NO_GPU_MULTI(Send)
|
||||
NO_GPU_MULTI(Recv)
|
||||
} // namespace distributed
|
||||
|
||||
} // namespace mlx::core
|
15
mlx/backend/cuda/slicing.cpp
Normal file
15
mlx/backend/cuda/slicing.cpp
Normal file
@@ -0,0 +1,15 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/gpu/slicing.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void concatenate_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
array& out,
|
||||
int axis,
|
||||
const Stream& s) {
|
||||
throw std::runtime_error("concatenate_gpu not implemented in CUDA backend.");
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
26
mlx/backend/cuda/utils.cpp
Normal file
26
mlx/backend/cuda/utils.cpp
Normal file
@@ -0,0 +1,26 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/utils.h"
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
|
||||
#include <fmt/format.h>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
CudaStream::CudaStream(cu::Device& device) {
|
||||
device.make_current();
|
||||
CHECK_CUDA_ERROR(cudaStreamCreateWithFlags(&stream_, cudaStreamNonBlocking));
|
||||
}
|
||||
|
||||
CudaStream::~CudaStream() {
|
||||
CHECK_CUDA_ERROR(cudaStreamDestroy(stream_));
|
||||
}
|
||||
|
||||
void check_cuda_error(const char* name, cudaError_t err) {
|
||||
if (err != cudaSuccess) {
|
||||
throw std::runtime_error(
|
||||
fmt::format("{} failed: {}", name, cudaGetErrorString(err)));
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
36
mlx/backend/cuda/utils.h
Normal file
36
mlx/backend/cuda/utils.h
Normal file
@@ -0,0 +1,36 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace cu {
|
||||
class Device;
|
||||
}
|
||||
|
||||
// Cuda stream managed with RAII.
|
||||
class CudaStream {
|
||||
public:
|
||||
explicit CudaStream(cu::Device& device);
|
||||
~CudaStream();
|
||||
|
||||
CudaStream(const CudaStream&) = delete;
|
||||
CudaStream& operator=(const CudaStream&) = delete;
|
||||
|
||||
operator cudaStream_t() const {
|
||||
return stream_;
|
||||
}
|
||||
|
||||
private:
|
||||
cudaStream_t stream_;
|
||||
};
|
||||
|
||||
// Throw exception if the cuda API does not succeed.
|
||||
void check_cuda_error(const char* name, cudaError_t err);
|
||||
|
||||
// The macro version that prints the command that failed.
|
||||
#define CHECK_CUDA_ERROR(cmd) check_cuda_error(#cmd, (cmd))
|
||||
|
||||
} // namespace mlx::core
|
90
mlx/backend/cuda/worker.cpp
Normal file
90
mlx/backend/cuda/worker.cpp
Normal file
@@ -0,0 +1,90 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/worker.h"
|
||||
#include "mlx/backend/cuda/allocator.h"
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
Worker::Worker()
|
||||
: signal_stream_(device(mlx::core::Device::gpu)),
|
||||
worker_(&Worker::thread_fn, this) {}
|
||||
|
||||
Worker::~Worker() {
|
||||
{
|
||||
std::lock_guard lock(worker_mutex_);
|
||||
stop_ = true;
|
||||
}
|
||||
worker_event_.signal(batch_ + 1);
|
||||
worker_.join();
|
||||
}
|
||||
|
||||
void Worker::add_task(std::function<void()> task) {
|
||||
pending_tasks_.push_back(std::move(task));
|
||||
}
|
||||
|
||||
void Worker::consume_in_this_thread() {
|
||||
for (auto& task : pending_tasks_) {
|
||||
task();
|
||||
}
|
||||
pending_tasks_.clear();
|
||||
}
|
||||
|
||||
void Worker::end_batch() {
|
||||
batch_++;
|
||||
{
|
||||
std::lock_guard lock(worker_mutex_);
|
||||
worker_tasks_[batch_] = std::move(pending_tasks_);
|
||||
}
|
||||
uncommited_batches_++;
|
||||
}
|
||||
|
||||
void Worker::commit() {
|
||||
if (uncommited_batches_ == 0) {
|
||||
return;
|
||||
}
|
||||
uncommited_batches_ = 0;
|
||||
worker_event_.signal(batch_);
|
||||
}
|
||||
|
||||
void Worker::commit(cudaStream_t stream) {
|
||||
if (uncommited_batches_ == 0) {
|
||||
return;
|
||||
}
|
||||
uncommited_batches_ = 0;
|
||||
// Signal the |worker_event_| in |signal_stream_| after the kernels in
|
||||
// |stream_| finish running.
|
||||
signal_event_.record(stream);
|
||||
signal_event_.wait(signal_stream_);
|
||||
worker_event_.signal(signal_stream_, batch_);
|
||||
}
|
||||
|
||||
void Worker::thread_fn() {
|
||||
// The worker thread is safe to free buffers.
|
||||
allocator().register_this_thread();
|
||||
|
||||
while (!stop_) {
|
||||
uint64_t batch = worker_event_.value();
|
||||
Tasks tasks;
|
||||
{
|
||||
std::lock_guard lock(worker_mutex_);
|
||||
// Move tasks in signaled batches.
|
||||
auto end = worker_tasks_.upper_bound(batch);
|
||||
for (auto it = worker_tasks_.begin(); it != end; ++it) {
|
||||
if (tasks.empty()) {
|
||||
tasks = std::move(it->second);
|
||||
} else {
|
||||
std::move(
|
||||
it->second.begin(), it->second.end(), std::back_inserter(tasks));
|
||||
}
|
||||
}
|
||||
worker_tasks_.erase(worker_tasks_.begin(), end);
|
||||
}
|
||||
for (auto& task : tasks) {
|
||||
task();
|
||||
}
|
||||
worker_event_.wait(batch + 1);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core::cu
|
68
mlx/backend/cuda/worker.h
Normal file
68
mlx/backend/cuda/worker.h
Normal file
@@ -0,0 +1,68 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/backend/cuda/event.h"
|
||||
#include "mlx/backend/cuda/utils.h"
|
||||
|
||||
#include <functional>
|
||||
#include <map>
|
||||
#include <mutex>
|
||||
#include <thread>
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
// Run tasks in worker thread, synchronized with cuda stream.
|
||||
class Worker {
|
||||
public:
|
||||
Worker();
|
||||
~Worker();
|
||||
|
||||
Worker(const Worker&) = delete;
|
||||
Worker& operator=(const Worker&) = delete;
|
||||
|
||||
// Add a pending |task| that will run when consumed or commited.
|
||||
void add_task(std::function<void()> task);
|
||||
|
||||
// Run pending tasks immediately in current thread.
|
||||
void consume_in_this_thread();
|
||||
|
||||
// Put pending tasks in a batch.
|
||||
void end_batch();
|
||||
|
||||
// Inform worker thread to run current batches now.
|
||||
void commit();
|
||||
|
||||
// Inform worker thread to run current batches after kernels in |stream|
|
||||
// finish running.
|
||||
void commit(cudaStream_t stream);
|
||||
|
||||
// Return how many batches have been added but not committed yet.
|
||||
size_t uncommited_batches() const {
|
||||
return uncommited_batches_;
|
||||
}
|
||||
|
||||
private:
|
||||
void thread_fn();
|
||||
|
||||
uint64_t batch_{0};
|
||||
size_t uncommited_batches_{0};
|
||||
|
||||
// Cuda stream and event for signaling kernel completion.
|
||||
CudaStream signal_stream_;
|
||||
CudaEvent signal_event_;
|
||||
|
||||
// Worker thread.
|
||||
SharedEvent worker_event_;
|
||||
std::thread worker_;
|
||||
std::mutex worker_mutex_;
|
||||
bool stop_{false};
|
||||
|
||||
// Tasks are put in |pending_tasks_| first, and then moved to
|
||||
// |worker_tasks_| when end_batch() is called.
|
||||
using Tasks = std::vector<std::function<void()>>;
|
||||
Tasks pending_tasks_;
|
||||
std::map<uint64_t, Tasks> worker_tasks_;
|
||||
};
|
||||
|
||||
} // namespace mlx::core::cu
|
5
mlx/backend/gpu/CMakeLists.txt
Normal file
5
mlx/backend/gpu/CMakeLists.txt
Normal file
@@ -0,0 +1,5 @@
|
||||
target_sources(
|
||||
mlx
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp)
|
9
mlx/backend/gpu/available.h
Normal file
9
mlx/backend/gpu/available.h
Normal file
@@ -0,0 +1,9 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace mlx::core::gpu {
|
||||
|
||||
bool is_available();
|
||||
|
||||
} // namespace mlx::core::gpu
|
49
mlx/backend/gpu/copy.cpp
Normal file
49
mlx/backend/gpu/copy.cpp
Normal file
@@ -0,0 +1,49 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
#include <cassert>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) {
|
||||
bool donated = set_copy_output_data(in, out, ctype);
|
||||
if (donated && in.dtype() == out.dtype()) {
|
||||
// If the output has the same type as the input then there is nothing to
|
||||
// copy, just use the buffer.
|
||||
return;
|
||||
}
|
||||
if (ctype == CopyType::GeneralGeneral) {
|
||||
ctype = CopyType::General;
|
||||
}
|
||||
copy_gpu_inplace(in, out, ctype, s);
|
||||
}
|
||||
|
||||
void copy_gpu(const array& in, array& out, CopyType ctype) {
|
||||
copy_gpu(in, out, ctype, out.primitive().stream());
|
||||
}
|
||||
|
||||
void copy_gpu_inplace(
|
||||
const array& in,
|
||||
array& out,
|
||||
CopyType ctype,
|
||||
const Stream& s) {
|
||||
assert(in.shape() == out.shape());
|
||||
return copy_gpu_inplace(
|
||||
in, out, in.shape(), in.strides(), out.strides(), 0, 0, ctype, s);
|
||||
}
|
||||
|
||||
void copy_gpu_inplace(
|
||||
const array& in,
|
||||
array& out,
|
||||
const Strides& i_strides,
|
||||
int64_t i_offset,
|
||||
CopyType ctype,
|
||||
const Stream& s) {
|
||||
assert(in.shape() == out.shape());
|
||||
return copy_gpu_inplace(
|
||||
in, out, in.shape(), i_strides, out.strides(), i_offset, 0, ctype, s);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
@@ -5,6 +5,8 @@
|
||||
#include "mlx/backend/common/copy.h"
|
||||
#include "mlx/stream.h"
|
||||
|
||||
#include <optional>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
// Generic copy inplace
|
@@ -8,14 +8,11 @@
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/stream.h"
|
||||
|
||||
namespace mlx::core::metal {
|
||||
namespace mlx::core::gpu {
|
||||
|
||||
void new_stream(Stream stream);
|
||||
|
||||
std::unique_ptr<void, std::function<void(void*)>> new_scoped_memory_pool();
|
||||
|
||||
void eval(array& arr);
|
||||
void finalize(Stream s);
|
||||
void synchronize(Stream s);
|
||||
|
||||
} // namespace mlx::core::metal
|
||||
} // namespace mlx::core::gpu
|
222
mlx/backend/gpu/primitives.cpp
Normal file
222
mlx/backend/gpu/primitives.cpp
Normal file
@@ -0,0 +1,222 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/backend/gpu/slicing.h"
|
||||
|
||||
#include <cassert>
|
||||
|
||||
#define MLX_PROFILER_RANGE(message)
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
void reshape(const array& in, array& out, Stream s) {
|
||||
auto [copy_necessary, out_strides] = prepare_reshape(in, out);
|
||||
if (copy_necessary) {
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
copy_gpu_inplace(
|
||||
in,
|
||||
out,
|
||||
in.shape(),
|
||||
in.strides(),
|
||||
make_contiguous_strides(in.shape()),
|
||||
0,
|
||||
0,
|
||||
CopyType::General,
|
||||
s);
|
||||
} else {
|
||||
shared_buffer_reshape(in, out_strides, out);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void AsStrided::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
MLX_PROFILER_RANGE("AsStrided::eval_gpu");
|
||||
eval(inputs, out);
|
||||
}
|
||||
|
||||
void AsType::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
MLX_PROFILER_RANGE("AsType::eval_gpu");
|
||||
CopyType ctype =
|
||||
inputs[0].flags().contiguous ? CopyType::Vector : CopyType::General;
|
||||
copy_gpu(inputs[0], out, ctype);
|
||||
}
|
||||
|
||||
void Broadcast::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
MLX_PROFILER_RANGE("Broadcast::eval_gpu");
|
||||
eval(inputs, out);
|
||||
}
|
||||
|
||||
void BroadcastAxes::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
MLX_PROFILER_RANGE("BroadcastAxes::eval_gpu");
|
||||
eval(inputs, out);
|
||||
}
|
||||
|
||||
void Concatenate::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
MLX_PROFILER_RANGE("Concatenate::eval_gpu");
|
||||
concatenate_gpu(inputs, out, axis_, stream());
|
||||
}
|
||||
|
||||
void Contiguous::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
MLX_PROFILER_RANGE("Contiguous::eval_gpu");
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
constexpr size_t extra_bytes = 16384;
|
||||
if (in.buffer_size() <= out.nbytes() + extra_bytes &&
|
||||
(in.flags().row_contiguous ||
|
||||
(allow_col_major_ && in.flags().col_contiguous))) {
|
||||
out.copy_shared_buffer(in);
|
||||
} else {
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
copy_gpu_inplace(
|
||||
in,
|
||||
out,
|
||||
in.flags().row_contiguous ? CopyType::Vector : CopyType::General,
|
||||
stream());
|
||||
}
|
||||
}
|
||||
|
||||
void Copy::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
MLX_PROFILER_RANGE("Copy::eval_gpu");
|
||||
eval(inputs, out);
|
||||
}
|
||||
|
||||
void CustomTransforms::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
MLX_PROFILER_RANGE("CustomTransforms::eval_gpu");
|
||||
eval(inputs, outputs);
|
||||
}
|
||||
|
||||
void Depends::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
MLX_PROFILER_RANGE("Depends::eval_gpu");
|
||||
eval(inputs, outputs);
|
||||
}
|
||||
|
||||
void ExpandDims::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
MLX_PROFILER_RANGE("ExpandDims::eval_gpu");
|
||||
eval(inputs, out);
|
||||
}
|
||||
|
||||
void Full::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
MLX_PROFILER_RANGE("Full::eval_gpu");
|
||||
auto in = inputs[0];
|
||||
CopyType ctype;
|
||||
if (in.data_size() == 1) {
|
||||
ctype = CopyType::Scalar;
|
||||
} else if (in.flags().contiguous) {
|
||||
ctype = CopyType::Vector;
|
||||
} else {
|
||||
ctype = CopyType::General;
|
||||
}
|
||||
copy_gpu(in, out, ctype);
|
||||
}
|
||||
|
||||
void Flatten::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
MLX_PROFILER_RANGE("Flatten::eval_gpu");
|
||||
reshape(inputs[0], out, stream());
|
||||
}
|
||||
|
||||
void NumberOfElements::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
MLX_PROFILER_RANGE("NumberOfElements::eval_gpu");
|
||||
eval(inputs, out);
|
||||
}
|
||||
|
||||
void Pad::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
// Inputs must be base input array and scalar val array
|
||||
assert(inputs.size() == 2);
|
||||
auto& in = inputs[0];
|
||||
auto& val = inputs[1];
|
||||
|
||||
// Padding value must be a scalar
|
||||
assert(val.size() == 1);
|
||||
|
||||
// Padding value, input and output must be of the same type
|
||||
assert(val.dtype() == in.dtype() && in.dtype() == out.dtype());
|
||||
|
||||
pad_gpu(in, val, out, axes_, low_pad_size_, stream());
|
||||
}
|
||||
|
||||
void Reshape::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
MLX_PROFILER_RANGE("Reshape::eval_gpu");
|
||||
reshape(inputs[0], out, stream());
|
||||
}
|
||||
|
||||
void Split::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
MLX_PROFILER_RANGE("Split::eval_gpu");
|
||||
eval(inputs, outputs);
|
||||
}
|
||||
|
||||
void Slice::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
MLX_PROFILER_RANGE("Slice::eval_gpu");
|
||||
assert(inputs.size() == 1);
|
||||
if (out.size() == 0) {
|
||||
out.set_data(nullptr);
|
||||
return;
|
||||
}
|
||||
|
||||
auto& in = inputs[0];
|
||||
slice_gpu(in, out, start_indices_, strides_, stream());
|
||||
}
|
||||
|
||||
void Squeeze::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
MLX_PROFILER_RANGE("Squeeze::eval_gpu");
|
||||
eval(inputs, out);
|
||||
}
|
||||
|
||||
void StopGradient::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
MLX_PROFILER_RANGE("StopGradient::eval_gpu");
|
||||
eval(inputs, out);
|
||||
}
|
||||
|
||||
void Transpose::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
MLX_PROFILER_RANGE("Transpose::eval_gpu");
|
||||
eval(inputs, out);
|
||||
}
|
||||
|
||||
void Unflatten::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
MLX_PROFILER_RANGE("Unflatten::eval_gpu");
|
||||
reshape(inputs[0], out, stream());
|
||||
}
|
||||
|
||||
void View::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
MLX_PROFILER_RANGE("View::eval_gpu");
|
||||
auto& in = inputs[0];
|
||||
auto ibytes = size_of(in.dtype());
|
||||
auto obytes = size_of(out.dtype());
|
||||
// Conditions for buffer copying (disjunction):
|
||||
// - type size is the same
|
||||
// - type size is smaller and the last axis is contiguous
|
||||
// - the entire array is row contiguous
|
||||
if (ibytes == obytes || (obytes < ibytes && in.strides().back() == 1) ||
|
||||
in.flags().row_contiguous) {
|
||||
auto strides = in.strides();
|
||||
for (int i = 0; i < static_cast<int>(strides.size()) - 1; ++i) {
|
||||
strides[i] *= ibytes;
|
||||
strides[i] /= obytes;
|
||||
}
|
||||
out.copy_shared_buffer(
|
||||
in, strides, in.flags(), in.data_size() * ibytes / obytes);
|
||||
} else {
|
||||
auto tmp = array(in.shape(), in.dtype(), nullptr, {});
|
||||
tmp.set_data(allocator::malloc(tmp.nbytes()));
|
||||
copy_gpu_inplace(in, tmp, CopyType::General, stream());
|
||||
|
||||
auto flags = out.flags();
|
||||
flags.contiguous = true;
|
||||
flags.row_contiguous = true;
|
||||
auto max_dim = std::max_element(out.shape().begin(), out.shape().end());
|
||||
flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim;
|
||||
out.copy_shared_buffer(tmp, out.strides(), flags, out.size());
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
44
mlx/backend/gpu/slicing.cpp
Normal file
44
mlx/backend/gpu/slicing.cpp
Normal file
@@ -0,0 +1,44 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/common/slicing.h"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/backend/gpu/slicing.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void slice_gpu(
|
||||
const array& in,
|
||||
array& out,
|
||||
const Shape& start_indices,
|
||||
const Shape& strides,
|
||||
const Stream& s) {
|
||||
slice(in, out, start_indices, strides);
|
||||
}
|
||||
|
||||
void pad_gpu(
|
||||
const array& in,
|
||||
const array& val,
|
||||
array& out,
|
||||
const std::vector<int>& axes,
|
||||
const Shape& low_pad_size,
|
||||
const Stream& s) {
|
||||
// Fill output with val
|
||||
fill_gpu(val, out, s);
|
||||
|
||||
// Find offset for start of input values
|
||||
size_t data_offset = 0;
|
||||
for (int i = 0; i < axes.size(); i++) {
|
||||
auto ax = axes[i] < 0 ? out.ndim() + axes[i] : axes[i];
|
||||
data_offset += out.strides()[ax] * low_pad_size[i];
|
||||
}
|
||||
|
||||
// Extract slice from output where input will be pasted
|
||||
array out_slice(in.shape(), out.dtype(), nullptr, {});
|
||||
out_slice.copy_shared_buffer(
|
||||
out, out.strides(), out.flags(), out_slice.size(), data_offset);
|
||||
|
||||
// Copy input values into the slice
|
||||
copy_gpu_inplace(in, out_slice, CopyType::GeneralGeneral, s);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
@@ -47,6 +47,7 @@ if(MLX_METAL_JIT)
|
||||
make_jit_source(binary)
|
||||
make_jit_source(binary_two)
|
||||
make_jit_source(fft kernels/fft/radix.h kernels/fft/readwrite.h)
|
||||
make_jit_source(logsumexp)
|
||||
make_jit_source(ternary)
|
||||
make_jit_source(softmax)
|
||||
make_jit_source(scan)
|
||||
@@ -60,6 +61,7 @@ if(MLX_METAL_JIT)
|
||||
kernels/steel/gemm/transforms.h)
|
||||
make_jit_source(steel/gemm/kernels/steel_gemm_fused)
|
||||
make_jit_source(steel/gemm/kernels/steel_gemm_masked kernels/steel/defines.h)
|
||||
make_jit_source(steel/gemm/kernels/steel_gemm_gather)
|
||||
make_jit_source(steel/gemm/kernels/steel_gemm_splitk)
|
||||
make_jit_source(
|
||||
steel/conv/conv
|
||||
@@ -91,10 +93,12 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/event.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/logsumexp.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp
|
||||
|
@@ -1,7 +1,6 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
#include "mlx/backend/metal/allocator.h"
|
||||
#include "mlx/backend/metal/metal.h"
|
||||
#include "mlx/backend/metal/metal_impl.h"
|
||||
#include "mlx/backend/metal/resident.h"
|
||||
#include "mlx/memory.h"
|
||||
|
||||
@@ -33,8 +32,11 @@ namespace metal {
|
||||
|
||||
namespace {
|
||||
|
||||
BufferCache::BufferCache(MTL::Device* device)
|
||||
: device_(device), head_(nullptr), tail_(nullptr), pool_size_(0) {}
|
||||
BufferCache::BufferCache(ResidencySet& residency_set)
|
||||
: head_(nullptr),
|
||||
tail_(nullptr),
|
||||
pool_size_(0),
|
||||
residency_set_(residency_set) {}
|
||||
|
||||
BufferCache::~BufferCache() {
|
||||
auto pool = metal::new_scoped_memory_pool();
|
||||
@@ -45,6 +47,9 @@ int BufferCache::clear() {
|
||||
int n_release = 0;
|
||||
for (auto& [size, holder] : buffer_pool_) {
|
||||
if (holder->buf) {
|
||||
if (!holder->buf->heap()) {
|
||||
residency_set_.erase(holder->buf);
|
||||
}
|
||||
holder->buf->release();
|
||||
n_release++;
|
||||
}
|
||||
@@ -102,6 +107,9 @@ int BufferCache::release_cached_buffers(size_t min_bytes_to_free) {
|
||||
while (tail_ && (total_bytes_freed < min_bytes_to_free)) {
|
||||
if (tail_->buf) {
|
||||
total_bytes_freed += tail_->buf->length();
|
||||
if (!tail_->buf->heap()) {
|
||||
residency_set_.erase(tail_->buf);
|
||||
}
|
||||
tail_->buf->release();
|
||||
tail_->buf = nullptr;
|
||||
n_release++;
|
||||
@@ -156,7 +164,7 @@ void BufferCache::remove_from_list(BufferCache::BufferHolder* to_remove) {
|
||||
MetalAllocator::MetalAllocator()
|
||||
: device_(device(mlx::core::Device::gpu).mtl_device()),
|
||||
residency_set_(device_),
|
||||
buffer_cache_(device_) {
|
||||
buffer_cache_(residency_set_) {
|
||||
auto pool = metal::new_scoped_memory_pool();
|
||||
auto memsize = std::get<size_t>(device_info().at("memory_size"));
|
||||
auto max_rec_size =
|
||||
@@ -263,9 +271,13 @@ Buffer MetalAllocator::malloc(size_t size) {
|
||||
if (!buf) {
|
||||
buf = device_->newBuffer(size, resource_options);
|
||||
}
|
||||
if (!buf) {
|
||||
return Buffer{nullptr};
|
||||
}
|
||||
lk.lock();
|
||||
if (buf) {
|
||||
num_resources_++;
|
||||
num_resources_++;
|
||||
if (!buf->heap()) {
|
||||
residency_set_.insert(buf);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -279,10 +291,6 @@ Buffer MetalAllocator::malloc(size_t size) {
|
||||
get_cache_memory() - max_pool_size_);
|
||||
}
|
||||
|
||||
if (!buf->heap()) {
|
||||
residency_set_.insert(buf);
|
||||
}
|
||||
|
||||
return Buffer{static_cast<void*>(buf)};
|
||||
}
|
||||
|
||||
@@ -298,14 +306,14 @@ void MetalAllocator::free(Buffer buffer) {
|
||||
return;
|
||||
}
|
||||
std::unique_lock lk(mutex_);
|
||||
if (!buf->heap()) {
|
||||
residency_set_.erase(buf);
|
||||
}
|
||||
active_memory_ -= buf->length();
|
||||
if (get_cache_memory() < max_pool_size_) {
|
||||
buffer_cache_.recycle_to_cache(buf);
|
||||
} else {
|
||||
num_resources_--;
|
||||
if (!buf->heap()) {
|
||||
residency_set_.erase(buf);
|
||||
}
|
||||
lk.unlock();
|
||||
auto pool = metal::new_scoped_memory_pool();
|
||||
buf->release();
|
||||
|
@@ -18,7 +18,7 @@ namespace {
|
||||
|
||||
class BufferCache {
|
||||
public:
|
||||
BufferCache(MTL::Device* device);
|
||||
BufferCache(ResidencySet& residency_set);
|
||||
~BufferCache();
|
||||
|
||||
MTL::Buffer* reuse_from_cache(size_t size);
|
||||
@@ -42,13 +42,11 @@ class BufferCache {
|
||||
void add_at_head(BufferHolder* to_add);
|
||||
void remove_from_list(BufferHolder* to_remove);
|
||||
|
||||
MTL::Device* device_;
|
||||
MTL::Heap* heap_{nullptr};
|
||||
|
||||
std::multimap<size_t, BufferHolder*> buffer_pool_;
|
||||
BufferHolder* head_;
|
||||
BufferHolder* tail_;
|
||||
size_t pool_size_;
|
||||
ResidencySet& residency_set_;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
@@ -90,7 +90,7 @@ void binary_op_gpu_inplace(
|
||||
work_per_thread = large ? 4 : 2;
|
||||
} else {
|
||||
large = out.data_size() > UINT32_MAX;
|
||||
work_per_thread = 1;
|
||||
work_per_thread = get_work_per_thread(a.dtype());
|
||||
}
|
||||
std::string kernel_name =
|
||||
get_kernel_name(bopt, op, a, large, shape.size(), work_per_thread);
|
||||
@@ -137,13 +137,20 @@ void binary_op_gpu_inplace(
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
} else {
|
||||
// Launch a 1D or 2D grid of threads
|
||||
size_t nthreads = out.data_size();
|
||||
size_t nthreads = ceildiv(out.data_size(), work_per_thread);
|
||||
if (thread_group_size > nthreads) {
|
||||
thread_group_size = nthreads;
|
||||
}
|
||||
|
||||
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
|
||||
MTL::Size grid_dims = large ? get_2d_grid_dims(out.shape(), out.strides())
|
||||
: MTL::Size(nthreads, 1, 1);
|
||||
MTL::Size grid_dims;
|
||||
if (large) {
|
||||
compute_encoder.set_bytes<int64_t>(out.data_size(), arg_idx++);
|
||||
grid_dims = get_2d_grid_dims(out.shape(), out.strides(), work_per_thread);
|
||||
} else {
|
||||
compute_encoder.set_bytes<int>(out.data_size(), arg_idx++);
|
||||
grid_dims = MTL::Size(nthreads, 1, 1);
|
||||
}
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
}
|
||||
}
|
||||
|
@@ -64,6 +64,7 @@ inline void build_kernel(
|
||||
cnt++);
|
||||
}
|
||||
|
||||
std::string idx_type = use_big_index ? "int64_t" : "uint";
|
||||
if (add_indices) {
|
||||
os += fmt::format(
|
||||
" constant const int64_t* in_strides [[buffer({0})]],\n", cnt++);
|
||||
@@ -83,6 +84,9 @@ inline void build_kernel(
|
||||
" constant const int64_t* output_strides [[buffer({0})]],\n", cnt++);
|
||||
os += fmt::format(
|
||||
" constant const int* output_shape [[buffer({0})]],\n", cnt++);
|
||||
} else {
|
||||
os += fmt::format(
|
||||
" constant const {0}& size [[buffer({1})]],\n", idx_type, cnt++);
|
||||
}
|
||||
if (dynamic_dims) {
|
||||
os += fmt::format(" constant const int& ndim [[buffer({0})]],\n", cnt++);
|
||||
@@ -92,13 +96,14 @@ inline void build_kernel(
|
||||
os += " uint3 pos [[thread_position_in_grid]],\n";
|
||||
os += " uint3 grid [[threads_per_grid]]) {\n";
|
||||
|
||||
std::string idx_type = use_big_index ? "int64_t" : "uint";
|
||||
os += fmt::format(" constexpr int N_ = {0};\n", work_per_thread);
|
||||
if (contiguous && use_big_index) {
|
||||
// This is only used for contiguous kernels which don't have
|
||||
// a third grid dimension
|
||||
os += " int64_t index = pos.x + grid.x * int64_t(pos.y);\n";
|
||||
os += " int64_t index = N_ * (pos.x + grid.x * int64_t(pos.y));\n";
|
||||
} else if (contiguous) {
|
||||
os += " uint index = N_ * pos.x;\n";
|
||||
} else if (work_per_thread > 1) {
|
||||
os += fmt::format(" constexpr int N_ = {0};\n", work_per_thread);
|
||||
os += fmt::format(
|
||||
" int xshape = output_shape[{0}];\n",
|
||||
dynamic_dims ? "ndim - 1" : std::to_string(ndim - 1));
|
||||
@@ -110,6 +115,9 @@ inline void build_kernel(
|
||||
" {0} index = pos.x + grid.x * (pos.y + {0}(grid.y) * pos.z);\n",
|
||||
idx_type);
|
||||
}
|
||||
if (work_per_thread > 1 && contiguous) {
|
||||
os += " for (int i = 0; i < N_ && index < size; ++i) {\n";
|
||||
}
|
||||
|
||||
// Read constant / contiguous inputs in tmps
|
||||
std::vector<array> nc_inputs;
|
||||
@@ -193,7 +201,7 @@ inline void build_kernel(
|
||||
}
|
||||
|
||||
// Open per-thread loop
|
||||
if (work_per_thread > 1) {
|
||||
if (work_per_thread > 1 && !contiguous) {
|
||||
os +=
|
||||
" for (int i = 0; i < N_ && (int(N_ * pos.x) + i) < xshape; ++i) {\n";
|
||||
}
|
||||
@@ -272,6 +280,7 @@ void Compiled::eval_gpu(
|
||||
auto& s = stream();
|
||||
auto& d = metal::device(s.device);
|
||||
auto lib = d.get_library(kernel_lib_, [&]() {
|
||||
int work_per_thread = get_work_per_thread(outputs_[0].dtype());
|
||||
std::string kernel = metal::utils();
|
||||
concatenate(
|
||||
kernel, metal::unary_ops(), metal::binary_ops(), metal::ternary_ops());
|
||||
@@ -284,7 +293,9 @@ void Compiled::eval_gpu(
|
||||
constant_ids_,
|
||||
/* contiguous = */ true,
|
||||
/* ndim = */ 0,
|
||||
/* dynamic_dims = */ false);
|
||||
/* dynamic_dims = */ false,
|
||||
/* use_big_index = */ false,
|
||||
/* work_per_thread = */ work_per_thread);
|
||||
build_kernel(
|
||||
kernel,
|
||||
kernel_lib_ + "_contiguous_large",
|
||||
@@ -295,7 +306,8 @@ void Compiled::eval_gpu(
|
||||
/* contiguous = */ true,
|
||||
/* ndim = */ 0,
|
||||
/* dynamic_dims = */ false,
|
||||
/* use_big_index = */ true);
|
||||
/* use_big_index = */ true,
|
||||
/* work_per_thread = */ work_per_thread);
|
||||
for (int i = 1; i < 8; i++) {
|
||||
build_kernel(
|
||||
kernel,
|
||||
@@ -468,6 +480,13 @@ void Compiled::eval_gpu(
|
||||
if (!contiguous) {
|
||||
compute_encoder.set_vector_bytes(strides[0], cnt++);
|
||||
compute_encoder.set_vector_bytes(shape, cnt++);
|
||||
} else {
|
||||
auto size = outputs[0].data_size();
|
||||
if (large) {
|
||||
compute_encoder.set_bytes<int64_t>(size, cnt++);
|
||||
} else {
|
||||
compute_encoder.set_bytes<int>(size, cnt++);
|
||||
}
|
||||
}
|
||||
|
||||
// Put the number of dims in if it is dynamic
|
||||
@@ -477,12 +496,13 @@ void Compiled::eval_gpu(
|
||||
|
||||
// Launch the kernel
|
||||
if (contiguous) {
|
||||
size_t nthreads = outputs[0].data_size();
|
||||
int work_per_thread = get_work_per_thread(outputs[0].dtype());
|
||||
size_t nthreads = ceildiv(outputs[0].data_size(), work_per_thread);
|
||||
MTL::Size group_dims(
|
||||
std::min(nthreads, kernel->maxTotalThreadsPerThreadgroup()), 1, 1);
|
||||
|
||||
MTL::Size grid_dims = large
|
||||
? get_2d_grid_dims(outputs[0].shape(), outputs[0].strides())
|
||||
? get_2d_grid_dims(
|
||||
outputs[0].shape(), outputs[0].strides(), work_per_thread)
|
||||
: MTL::Size(nthreads, 1, 1);
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
} else {
|
||||
|
@@ -5,7 +5,7 @@
|
||||
#include <numeric>
|
||||
#include <sstream>
|
||||
|
||||
#include "mlx/backend/metal/copy.h"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/kernels.h"
|
||||
#include "mlx/backend/metal/kernels/defines.h"
|
||||
@@ -712,6 +712,65 @@ void winograd_conv_2D_gpu(
|
||||
}
|
||||
}
|
||||
|
||||
void depthwise_conv_2D_gpu(
|
||||
const Stream& s,
|
||||
metal::Device& d,
|
||||
const array& in,
|
||||
const array& wt,
|
||||
array out,
|
||||
const MLXConvParams<2>& conv_params) {
|
||||
std::ostringstream kname;
|
||||
kname << "depthwise_conv_2d_" << type_to_name(out);
|
||||
std::string base_name = kname.str();
|
||||
|
||||
const int N = conv_params.N;
|
||||
const int ker_h = conv_params.wS[0];
|
||||
const int ker_w = conv_params.wS[1];
|
||||
const int str_h = conv_params.str[0];
|
||||
const int str_w = conv_params.str[1];
|
||||
const int tc = 8;
|
||||
const int tw = 8;
|
||||
const int th = 4;
|
||||
const bool do_flip = conv_params.flip;
|
||||
|
||||
metal::MTLFCList func_consts = {
|
||||
{&ker_h, MTL::DataType::DataTypeInt, 00},
|
||||
{&ker_w, MTL::DataType::DataTypeInt, 01},
|
||||
{&str_h, MTL::DataType::DataTypeInt, 10},
|
||||
{&str_w, MTL::DataType::DataTypeInt, 11},
|
||||
{&th, MTL::DataType::DataTypeInt, 100},
|
||||
{&tw, MTL::DataType::DataTypeInt, 101},
|
||||
{&do_flip, MTL::DataType::DataTypeBool, 200},
|
||||
};
|
||||
|
||||
// clang-format off
|
||||
kname << "_ker_h_" << ker_h
|
||||
<< "_ker_w_" << ker_w
|
||||
<< "_str_h_" << str_h
|
||||
<< "_str_w_" << str_w
|
||||
<< "_tgp_h_" << th
|
||||
<< "_tgp_w_" << tw
|
||||
<< "_do_flip_" << (do_flip ? 't' : 'n'); // clang-format on
|
||||
|
||||
std::string hash_name = kname.str();
|
||||
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(base_name, "mlx", hash_name, func_consts);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
compute_encoder.set_input_array(in, 0);
|
||||
compute_encoder.set_input_array(wt, 1);
|
||||
compute_encoder.set_output_array(out, 2);
|
||||
|
||||
compute_encoder.set_bytes(conv_params, 3);
|
||||
|
||||
MTL::Size group_dims = MTL::Size(tc, tw, th);
|
||||
MTL::Size grid_dims = MTL::Size(
|
||||
conv_params.C / tc, conv_params.oS[1] / tw, (conv_params.oS[0] / th) * N);
|
||||
|
||||
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
void conv_2D_gpu(
|
||||
const Stream& s,
|
||||
metal::Device& d,
|
||||
@@ -754,11 +813,20 @@ void conv_2D_gpu(
|
||||
bool is_kdil_one = conv_params.kdil[0] == 1 && conv_params.kdil[1] == 1;
|
||||
bool is_idil_one = conv_params.idil[0] == 1 && conv_params.idil[1] == 1;
|
||||
|
||||
if (groups > 1) {
|
||||
if (is_idil_one && groups > 1) {
|
||||
const int C_per_group = conv_params.C / groups;
|
||||
const int O_per_group = conv_params.O / groups;
|
||||
|
||||
if (is_idil_one && (C_per_group <= 4 || C_per_group % 16 == 0) &&
|
||||
if (C_per_group == 1 && O_per_group == 1 && is_kdil_one &&
|
||||
conv_params.wS[0] <= 7 && conv_params.wS[1] <= 7 &&
|
||||
conv_params.str[0] <= 2 && conv_params.str[1] <= 2 &&
|
||||
conv_params.oS[0] % 8 == 0 && conv_params.oS[1] % 8 == 0 &&
|
||||
conv_params.wt_strides[1] == conv_params.wS[1] &&
|
||||
conv_params.C % 16 == 0 && conv_params.C == conv_params.O) {
|
||||
return depthwise_conv_2D_gpu(s, d, in, wt, out, conv_params);
|
||||
}
|
||||
|
||||
if ((C_per_group <= 4 || C_per_group % 16 == 0) &&
|
||||
(O_per_group <= 16 || O_per_group % 16 == 0)) {
|
||||
return implicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params);
|
||||
} else {
|
||||
|
@@ -1,35 +1,15 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <sstream>
|
||||
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/metal/copy.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/kernels.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
constexpr int MAX_COPY_SPECIALIZED_DIMS = 3;
|
||||
|
||||
void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) {
|
||||
bool donated = set_copy_output_data(in, out, ctype);
|
||||
if (donated && in.dtype() == out.dtype()) {
|
||||
// If the output has the same type as the input then there is nothing to
|
||||
// copy, just use the buffer.
|
||||
return;
|
||||
}
|
||||
if (ctype == CopyType::GeneralGeneral) {
|
||||
ctype = CopyType::General;
|
||||
}
|
||||
copy_gpu_inplace(in, out, ctype, s);
|
||||
}
|
||||
|
||||
void copy_gpu(const array& in, array& out, CopyType ctype) {
|
||||
copy_gpu(in, out, ctype, out.primitive().stream());
|
||||
}
|
||||
|
||||
void copy_gpu_inplace(
|
||||
const array& in,
|
||||
array& out,
|
||||
@@ -104,6 +84,8 @@ void copy_gpu_inplace(
|
||||
"[Copy::eval_gpu] Dynamic output offset requires GeneralGeneral copy");
|
||||
}
|
||||
}
|
||||
} else {
|
||||
work_per_thread = get_work_per_thread(in.dtype());
|
||||
}
|
||||
concatenate(kernel_name, "_copy", type_to_name(in), type_to_name(out));
|
||||
auto kernel = dynamic ? get_dynamic_copy_kernel(d, kernel_name, in, out)
|
||||
@@ -165,39 +147,23 @@ void copy_gpu_inplace(
|
||||
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
} else {
|
||||
size_t nthreads = out.data_size();
|
||||
size_t nthreads = ceildiv(out.data_size(), work_per_thread);
|
||||
if (thread_group_size > nthreads) {
|
||||
thread_group_size = nthreads;
|
||||
}
|
||||
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
|
||||
MTL::Size grid_dims = large ? get_2d_grid_dims(out.shape(), out.strides())
|
||||
: MTL::Size(nthreads, 1, 1);
|
||||
MTL::Size grid_dims;
|
||||
if (large) {
|
||||
compute_encoder.set_bytes<int64_t>(out.data_size(), 2);
|
||||
grid_dims = get_2d_grid_dims(out.shape(), out.strides(), work_per_thread);
|
||||
} else {
|
||||
compute_encoder.set_bytes<int>(out.data_size(), 2);
|
||||
grid_dims = MTL::Size(nthreads, 1, 1);
|
||||
}
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
}
|
||||
}
|
||||
|
||||
void copy_gpu_inplace(
|
||||
const array& in,
|
||||
array& out,
|
||||
CopyType ctype,
|
||||
const Stream& s) {
|
||||
assert(in.shape() == out.shape());
|
||||
return copy_gpu_inplace(
|
||||
in, out, in.shape(), in.strides(), out.strides(), 0, 0, ctype, s);
|
||||
}
|
||||
|
||||
void copy_gpu_inplace(
|
||||
const array& in,
|
||||
array& out,
|
||||
const Strides& i_strides,
|
||||
int64_t i_offset,
|
||||
CopyType ctype,
|
||||
const Stream& s) {
|
||||
assert(in.shape() == out.shape());
|
||||
return copy_gpu_inplace(
|
||||
in, out, in.shape(), i_strides, out.strides(), i_offset, 0, ctype, s);
|
||||
}
|
||||
|
||||
void fill_gpu(const array& val, array& out, const Stream& s) {
|
||||
if (out.size() == 0) {
|
||||
return;
|
||||
@@ -214,14 +180,21 @@ void fill_gpu(const array& val, array& out, const Stream& s) {
|
||||
compute_encoder.set_input_array(val, 0);
|
||||
compute_encoder.set_output_array(out, 1);
|
||||
|
||||
int work_per_thread = get_work_per_thread(val.dtype());
|
||||
auto thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||
size_t nthreads = out.data_size();
|
||||
size_t nthreads = ceildiv(out.data_size(), work_per_thread);
|
||||
if (thread_group_size > nthreads) {
|
||||
thread_group_size = nthreads;
|
||||
}
|
||||
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
|
||||
MTL::Size grid_dims = large ? get_2d_grid_dims(out.shape(), out.strides())
|
||||
: MTL::Size(nthreads, 1, 1);
|
||||
MTL::Size grid_dims;
|
||||
if (large) {
|
||||
compute_encoder.set_bytes<int64_t>(out.data_size(), 2);
|
||||
grid_dims = get_2d_grid_dims(out.shape(), out.strides(), work_per_thread);
|
||||
} else {
|
||||
compute_encoder.set_bytes<int>(out.data_size(), 2);
|
||||
grid_dims = MTL::Size(nthreads, 1, 1);
|
||||
}
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
|
@@ -1,6 +1,6 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include "mlx/backend/metal/copy.h"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/backend/metal/jit/includes.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
#include "mlx/fast_primitives.h"
|
||||
|
@@ -1,20 +1,20 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <cstdlib>
|
||||
#include <filesystem>
|
||||
#include <sstream>
|
||||
|
||||
#include <sys/sysctl.h>
|
||||
|
||||
#define NS_PRIVATE_IMPLEMENTATION
|
||||
#define CA_PRIVATE_IMPLEMENTATION
|
||||
#define MTL_PRIVATE_IMPLEMENTATION
|
||||
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/metal.h"
|
||||
#include "mlx/backend/metal/metal_impl.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
namespace fs = std::filesystem;
|
||||
|
||||
namespace mlx::core::metal {
|
||||
|
||||
namespace {
|
||||
@@ -55,7 +55,10 @@ std::pair<MTL::Library*, NS::Error*> load_library_from_path(
|
||||
}
|
||||
|
||||
#ifdef SWIFTPM_BUNDLE
|
||||
MTL::Library* try_load_bundle(MTL::Device* device, NS::URL* url) {
|
||||
MTL::Library* try_load_bundle(
|
||||
MTL::Device* device,
|
||||
NS::URL* url,
|
||||
const std::string& lib_name) {
|
||||
std::string bundle_path = std::string(url->fileSystemRepresentation()) + "/" +
|
||||
SWIFTPM_BUNDLE + ".bundle";
|
||||
auto bundle = NS::Bundle::alloc()->init(
|
||||
@@ -63,7 +66,7 @@ MTL::Library* try_load_bundle(MTL::Device* device, NS::URL* url) {
|
||||
if (bundle != nullptr) {
|
||||
std::string resource_path =
|
||||
std::string(bundle->resourceURL()->fileSystemRepresentation()) + "/" +
|
||||
"default.metallib";
|
||||
lib_name + ".metallib";
|
||||
auto [lib, error] = load_library_from_path(device, resource_path.c_str());
|
||||
if (lib) {
|
||||
return lib;
|
||||
@@ -73,51 +76,133 @@ MTL::Library* try_load_bundle(MTL::Device* device, NS::URL* url) {
|
||||
}
|
||||
#endif
|
||||
|
||||
// Firstly, search for the metallib in the same path as this binary
|
||||
std::pair<MTL::Library*, NS::Error*> load_colocated_library(
|
||||
MTL::Device* device,
|
||||
const std::string& relative_path) {
|
||||
std::string binary_dir = get_binary_directory();
|
||||
if (binary_dir.size() == 0) {
|
||||
return {nullptr, nullptr};
|
||||
}
|
||||
|
||||
auto path = fs::path(binary_dir) / relative_path;
|
||||
if (!path.has_extension()) {
|
||||
path.replace_extension(".metallib");
|
||||
}
|
||||
|
||||
return load_library_from_path(device, path.c_str());
|
||||
}
|
||||
|
||||
std::pair<MTL::Library*, NS::Error*> load_swiftpm_library(
|
||||
MTL::Device* device,
|
||||
const std::string& lib_name) {
|
||||
#ifdef SWIFTPM_BUNDLE
|
||||
MTL::Library* library =
|
||||
try_load_bundle(device, NS::Bundle::mainBundle()->bundleURL(), lib_name);
|
||||
if (library != nullptr) {
|
||||
return {library, nullptr};
|
||||
}
|
||||
auto bundles = NS::Bundle::allBundles();
|
||||
for (int i = 0, c = (int)bundles->count(); i < c; i++) {
|
||||
auto bundle = reinterpret_cast<NS::Bundle*>(bundles->object(i));
|
||||
library = try_load_bundle(device, bundle->resourceURL(), lib_name);
|
||||
if (library != nullptr) {
|
||||
return {library, nullptr};
|
||||
}
|
||||
}
|
||||
#endif
|
||||
return {nullptr, nullptr};
|
||||
}
|
||||
|
||||
MTL::Library* load_default_library(MTL::Device* device) {
|
||||
NS::Error* error[4];
|
||||
MTL::Library* lib;
|
||||
// First try the colocated mlx.metallib
|
||||
std::tie(lib, error[0]) = load_colocated_library(device, "mlx");
|
||||
if (lib) {
|
||||
return lib;
|
||||
}
|
||||
|
||||
std::tie(lib, error[1]) = load_colocated_library(device, "Resources/mlx");
|
||||
if (lib) {
|
||||
return lib;
|
||||
}
|
||||
|
||||
// Then try default.metallib in a SwiftPM bundle if we have one
|
||||
std::tie(lib, error[2]) = load_swiftpm_library(device, "default");
|
||||
if (lib) {
|
||||
return lib;
|
||||
}
|
||||
|
||||
// Finally try default_mtllib_path
|
||||
std::tie(lib, error[3]) = load_library_from_path(device, default_mtllib_path);
|
||||
if (!lib) {
|
||||
std::ostringstream msg;
|
||||
msg << "Failed to load the default metallib. ";
|
||||
for (int i = 0; i < 4; i++) {
|
||||
if (error[i] != nullptr) {
|
||||
msg << error[i]->localizedDescription()->utf8String() << " ";
|
||||
}
|
||||
}
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
return lib;
|
||||
}
|
||||
|
||||
MTL::Library* load_library(
|
||||
MTL::Device* device,
|
||||
const std::string& lib_name = "mlx",
|
||||
const char* lib_path = default_mtllib_path) {
|
||||
// Firstly, search for the metallib in the same path as this binary
|
||||
std::string first_path = get_colocated_mtllib_path(lib_name);
|
||||
if (first_path.size() != 0) {
|
||||
auto [lib, error] = load_library_from_path(device, first_path.c_str());
|
||||
const std::string& lib_name,
|
||||
const std::string& lib_path) {
|
||||
// We have been given a path that ends in metallib so try to load it
|
||||
if (lib_path.size() > 9 &&
|
||||
std::equal(lib_path.end() - 9, lib_path.end(), ".metallib")) {
|
||||
auto [lib, error] = load_library_from_path(device, lib_path.c_str());
|
||||
if (!lib) {
|
||||
std::ostringstream msg;
|
||||
msg << "Failed to load the metallib from <" << lib_path << "> with error "
|
||||
<< error->localizedDescription()->utf8String();
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
return lib;
|
||||
}
|
||||
|
||||
// We have been given a path so try to load from lib_path / lib_name.metallib
|
||||
if (lib_path.size() > 0) {
|
||||
std::string full_path = lib_path + "/" + lib_name + ".metallib";
|
||||
auto [lib, error] = load_library_from_path(device, full_path.c_str());
|
||||
if (!lib) {
|
||||
std::ostringstream msg;
|
||||
msg << "Failed to load the metallib from <" << full_path
|
||||
<< "> with error " << error->localizedDescription()->utf8String();
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
return lib;
|
||||
}
|
||||
|
||||
// Try to load the colocated library
|
||||
{
|
||||
auto [lib, error] = load_colocated_library(device, lib_name);
|
||||
if (lib) {
|
||||
return lib;
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef SWIFTPM_BUNDLE
|
||||
// try to load from a swiftpm resource bundle -- scan the available bundles to
|
||||
// find one that contains the named bundle
|
||||
// Try to load the library from swiftpm
|
||||
{
|
||||
MTL::Library* library =
|
||||
try_load_bundle(device, NS::Bundle::mainBundle()->bundleURL());
|
||||
if (library != nullptr) {
|
||||
return library;
|
||||
}
|
||||
auto bundles = NS::Bundle::allBundles();
|
||||
for (int i = 0, c = (int)bundles->count(); i < c; i++) {
|
||||
auto bundle = reinterpret_cast<NS::Bundle*>(bundles->object(i));
|
||||
library = try_load_bundle(device, bundle->resourceURL());
|
||||
if (library != nullptr) {
|
||||
return library;
|
||||
}
|
||||
auto [lib, error] = load_swiftpm_library(device, lib_name);
|
||||
if (lib) {
|
||||
return lib;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
// Couldn't find it so let's load it from default_mtllib_path
|
||||
{
|
||||
auto [lib, error] = load_library_from_path(device, lib_path);
|
||||
if (!lib) {
|
||||
std::ostringstream msg;
|
||||
msg << error->localizedDescription()->utf8String() << "\n"
|
||||
<< "Failed to load device library from <" << lib_path << ">"
|
||||
<< " or <" << first_path << ">.";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
return lib;
|
||||
}
|
||||
std::ostringstream msg;
|
||||
msg << "Failed to load the metallib " << lib_name << ".metallib. "
|
||||
<< "We attempted to load it from <" << get_binary_directory() << "/"
|
||||
<< lib_name << ".metallib" << ">";
|
||||
#ifdef SWIFTPM_BUNDLE
|
||||
msg << " and from the Swift PM bundle.";
|
||||
#endif
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
@@ -210,7 +295,7 @@ void CommandEncoder::barrier() {
|
||||
Device::Device() {
|
||||
auto pool = new_scoped_memory_pool();
|
||||
device_ = load_device();
|
||||
library_map_ = {{"mlx", load_library(device_)}};
|
||||
library_map_ = {{"mlx", load_default_library(device_)}};
|
||||
arch_ = std::string(device_->architecture()->name()->utf8String());
|
||||
auto arch = arch_.back();
|
||||
switch (arch) {
|
||||
@@ -684,42 +769,4 @@ std::unique_ptr<void, std::function<void(void*)>> new_scoped_memory_pool() {
|
||||
NS::AutoreleasePool::alloc()->init(), dtor);
|
||||
}
|
||||
|
||||
void new_stream(Stream stream) {
|
||||
if (stream.device == mlx::core::Device::gpu) {
|
||||
device(stream.device).new_queue(stream.index);
|
||||
}
|
||||
}
|
||||
|
||||
const std::unordered_map<std::string, std::variant<std::string, size_t>>&
|
||||
device_info() {
|
||||
auto init_device_info = []()
|
||||
-> std::unordered_map<std::string, std::variant<std::string, size_t>> {
|
||||
auto pool = new_scoped_memory_pool();
|
||||
auto raw_device = device(default_device()).mtl_device();
|
||||
auto name = std::string(raw_device->name()->utf8String());
|
||||
auto arch = std::string(raw_device->architecture()->name()->utf8String());
|
||||
|
||||
size_t memsize = 0;
|
||||
size_t length = sizeof(memsize);
|
||||
sysctlbyname("hw.memsize", &memsize, &length, NULL, 0);
|
||||
|
||||
size_t rsrc_limit = 0;
|
||||
sysctlbyname("iogpu.rsrc_limit", &rsrc_limit, &length, NULL, 0);
|
||||
if (rsrc_limit == 0) {
|
||||
rsrc_limit = 499000;
|
||||
}
|
||||
|
||||
return {
|
||||
{"device_name", name},
|
||||
{"architecture", arch},
|
||||
{"max_buffer_length", raw_device->maxBufferLength()},
|
||||
{"max_recommended_working_set_size",
|
||||
raw_device->recommendedMaxWorkingSetSize()},
|
||||
{"memory_size", memsize},
|
||||
{"resource_limit", rsrc_limit}};
|
||||
};
|
||||
static auto device_info_ = init_device_info();
|
||||
return device_info_;
|
||||
}
|
||||
|
||||
} // namespace mlx::core::metal
|
||||
|
@@ -21,18 +21,14 @@ namespace mlx::core::metal {
|
||||
|
||||
// Note, this function must be left inline in a header so that it is not
|
||||
// dynamically linked.
|
||||
inline std::string get_colocated_mtllib_path(const std::string& lib_name) {
|
||||
inline std::string get_binary_directory() {
|
||||
Dl_info info;
|
||||
std::string mtllib_path;
|
||||
std::string lib_ext = lib_name + ".metallib";
|
||||
|
||||
int success = dladdr((void*)get_colocated_mtllib_path, &info);
|
||||
std::string directory;
|
||||
int success = dladdr((void*)get_binary_directory, &info);
|
||||
if (success) {
|
||||
auto mtllib = fs::path(info.dli_fname).remove_filename() / lib_ext;
|
||||
mtllib_path = mtllib.c_str();
|
||||
directory = fs::path(info.dli_fname).remove_filename().c_str();
|
||||
}
|
||||
|
||||
return mtllib_path;
|
||||
return directory;
|
||||
}
|
||||
|
||||
using MTLFCList =
|
||||
@@ -189,15 +185,7 @@ class Device {
|
||||
|
||||
void register_library(
|
||||
const std::string& lib_name,
|
||||
const std::string& lib_path);
|
||||
|
||||
// Note, this should remain in the header so that it is not dynamically
|
||||
// linked
|
||||
void register_library(const std::string& lib_name) {
|
||||
if (auto it = library_map_.find(lib_name); it == library_map_.end()) {
|
||||
register_library(lib_name, get_colocated_mtllib_path(lib_name));
|
||||
}
|
||||
}
|
||||
const std::string& lib_path = "");
|
||||
|
||||
MTL::Library* get_library(
|
||||
const std::string& name,
|
||||
@@ -278,4 +266,6 @@ class Device {
|
||||
|
||||
Device& device(mlx::core::Device);
|
||||
|
||||
std::unique_ptr<void, std::function<void(void*)>> new_scoped_memory_pool();
|
||||
|
||||
} // namespace mlx::core::metal
|
||||
|
@@ -4,7 +4,7 @@
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/metal/copy.h"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
#include "mlx/distributed/ops.h"
|
||||
|
102
mlx/backend/metal/eval.cpp
Normal file
102
mlx/backend/metal/eval.cpp
Normal file
@@ -0,0 +1,102 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
#include <memory>
|
||||
|
||||
#include "mlx/backend/gpu/available.h"
|
||||
#include "mlx/backend/gpu/eval.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/scheduler.h"
|
||||
|
||||
namespace mlx::core::gpu {
|
||||
|
||||
bool is_available() {
|
||||
return true;
|
||||
}
|
||||
|
||||
void new_stream(Stream stream) {
|
||||
if (stream.device == mlx::core::Device::gpu) {
|
||||
metal::device(stream.device).new_queue(stream.index);
|
||||
}
|
||||
}
|
||||
|
||||
inline void check_error(MTL::CommandBuffer* cbuf) {
|
||||
if (cbuf->status() == MTL::CommandBufferStatusError) {
|
||||
std::ostringstream msg;
|
||||
msg << "[METAL] Command buffer execution failed: "
|
||||
<< cbuf->error()->localizedDescription()->utf8String();
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
}
|
||||
|
||||
void eval(array& arr) {
|
||||
auto pool = metal::new_scoped_memory_pool();
|
||||
auto s = arr.primitive().stream();
|
||||
auto& d = metal::device(s.device);
|
||||
auto command_buffer = d.get_command_buffer(s.index);
|
||||
|
||||
auto outputs = arr.outputs();
|
||||
{
|
||||
// If the array is a tracer hold a reference
|
||||
// to its inputs so they don't get donated
|
||||
std::vector<array> inputs;
|
||||
if (arr.is_tracer()) {
|
||||
inputs = arr.inputs();
|
||||
}
|
||||
|
||||
debug_set_primitive_buffer_label(command_buffer, arr.primitive());
|
||||
arr.primitive().eval_gpu(arr.inputs(), outputs);
|
||||
}
|
||||
std::unordered_set<std::shared_ptr<array::Data>> buffers;
|
||||
for (auto& in : arr.inputs()) {
|
||||
buffers.insert(in.data_shared_ptr());
|
||||
}
|
||||
for (auto& s : arr.siblings()) {
|
||||
buffers.insert(s.data_shared_ptr());
|
||||
}
|
||||
// Remove the output if it was donated to by an input
|
||||
if (auto it = buffers.find(arr.data_shared_ptr()); it != buffers.end()) {
|
||||
buffers.erase(it);
|
||||
}
|
||||
|
||||
if (d.command_buffer_needs_commit(s.index)) {
|
||||
d.end_encoding(s.index);
|
||||
scheduler::notify_new_task(s);
|
||||
command_buffer->addCompletedHandler(
|
||||
[s, buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) {
|
||||
scheduler::notify_task_completion(s);
|
||||
check_error(cbuf);
|
||||
});
|
||||
d.commit_command_buffer(s.index);
|
||||
d.get_command_buffer(s.index);
|
||||
} else {
|
||||
command_buffer->addCompletedHandler(
|
||||
[s, buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) {
|
||||
check_error(cbuf);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
void finalize(Stream s) {
|
||||
auto pool = metal::new_scoped_memory_pool();
|
||||
auto& d = metal::device(s.device);
|
||||
auto cb = d.get_command_buffer(s.index);
|
||||
d.end_encoding(s.index);
|
||||
cb->addCompletedHandler([s](MTL::CommandBuffer* cbuf) { check_error(cbuf); });
|
||||
d.commit_command_buffer(s.index);
|
||||
d.get_command_buffer(s.index);
|
||||
}
|
||||
|
||||
void synchronize(Stream s) {
|
||||
auto pool = metal::new_scoped_memory_pool();
|
||||
auto& d = metal::device(s.device);
|
||||
auto cb = d.get_command_buffer(s.index);
|
||||
cb->retain();
|
||||
d.end_encoding(s.index);
|
||||
d.commit_command_buffer(s.index);
|
||||
cb->waitUntilCompleted();
|
||||
check_error(cb);
|
||||
cb->release();
|
||||
}
|
||||
|
||||
} // namespace mlx::core::gpu
|
@@ -2,7 +2,6 @@
|
||||
|
||||
#include "mlx/event.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/metal_impl.h"
|
||||
#include "mlx/scheduler.h"
|
||||
|
||||
namespace mlx::core {
|
||||
@@ -24,10 +23,6 @@ void Event::wait() {
|
||||
}
|
||||
}
|
||||
|
||||
void Event::signal() {
|
||||
static_cast<MTL::SharedEvent*>(event_.get())->setSignaledValue(value());
|
||||
}
|
||||
|
||||
void Event::wait(Stream stream) {
|
||||
if (stream.device == Device::cpu) {
|
||||
scheduler::enqueue(stream, [*this]() mutable { wait(); });
|
||||
@@ -42,7 +37,9 @@ void Event::wait(Stream stream) {
|
||||
|
||||
void Event::signal(Stream stream) {
|
||||
if (stream.device == Device::cpu) {
|
||||
scheduler::enqueue(stream, [*this]() mutable { signal(); });
|
||||
scheduler::enqueue(stream, [*this]() mutable {
|
||||
static_cast<MTL::SharedEvent*>(event_.get())->setSignaledValue(value());
|
||||
});
|
||||
} else {
|
||||
auto& d = metal::device(stream.device);
|
||||
d.end_encoding(stream.index);
|
||||
|
@@ -1,7 +1,6 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
#include "mlx/fence.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/metal_impl.h"
|
||||
#include "mlx/scheduler.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
@@ -139,7 +138,7 @@ void Fence::update(Stream stream, const array& x) {
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
compute_encoder.set_input_array(x, 0);
|
||||
compute_encoder.set_bytes(nthreads, 1);
|
||||
compute_encoder.dispatch_threadgroups(group_dims, grid_dims);
|
||||
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
||||
|
||||
// Barrier on previous kernels
|
||||
compute_encoder.barrier();
|
||||
|
@@ -1,16 +1,18 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
#include <cassert>
|
||||
#include <complex>
|
||||
#include <iostream>
|
||||
#include <map>
|
||||
#include <numeric>
|
||||
#include <set>
|
||||
|
||||
#include "mlx/3rdparty/pocketfft.h"
|
||||
#include "mlx/backend/common/transpose.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/backend/gpu/slicing.h"
|
||||
#include "mlx/backend/metal/binary.h"
|
||||
#include "mlx/backend/metal/copy.h"
|
||||
#include "mlx/backend/metal/kernels.h"
|
||||
#include "mlx/backend/metal/slicing.h"
|
||||
#include "mlx/backend/metal/unary.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
#include "mlx/utils.h"
|
||||
@@ -27,7 +29,7 @@ using MTLFC = std::tuple<const void*, MTL::DataType, NS::UInteger>;
|
||||
// For strided reads/writes, coalesce at least this many complex64s
|
||||
#define MIN_COALESCE_WIDTH 4
|
||||
|
||||
inline const std::vector<int> supported_radices() {
|
||||
inline constexpr std::array<int, 9> supported_radices() {
|
||||
// Ordered by preference in decomposition.
|
||||
return {13, 11, 8, 7, 6, 5, 4, 3, 2};
|
||||
}
|
||||
@@ -49,6 +51,35 @@ std::vector<int> prime_factors(int n) {
|
||||
return factors;
|
||||
}
|
||||
|
||||
int next_fast_n(int n) {
|
||||
return next_power_of_2(n);
|
||||
}
|
||||
|
||||
std::vector<int> stockham_decompose(int n) {
|
||||
auto radices = supported_radices();
|
||||
std::vector<int> steps(radices.size(), 0);
|
||||
int orig_n = n;
|
||||
|
||||
for (int i = 0; i < radices.size(); i++) {
|
||||
int radix = radices[i];
|
||||
|
||||
// Manually tuned radices for powers of 2
|
||||
if (is_power_of_2(orig_n) && orig_n < 512 && radix > 4) {
|
||||
continue;
|
||||
}
|
||||
|
||||
while (n % radix == 0) {
|
||||
steps[i] += 1;
|
||||
n /= radix;
|
||||
if (n == 1) {
|
||||
return steps;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return {};
|
||||
}
|
||||
|
||||
struct FourStepParams {
|
||||
bool required = false;
|
||||
bool first_step = true;
|
||||
@@ -65,9 +96,10 @@ void fft_op(
|
||||
bool real,
|
||||
const FourStepParams four_step_params,
|
||||
bool inplace,
|
||||
metal::Device& d,
|
||||
const Stream& s);
|
||||
|
||||
struct FFTPlan {
|
||||
struct OldFFTPlan {
|
||||
int n = 0;
|
||||
// Number of steps for each radix in the Stockham decomposition
|
||||
std::vector<int> stockham;
|
||||
@@ -82,9 +114,104 @@ struct FFTPlan {
|
||||
int n2 = 0;
|
||||
};
|
||||
|
||||
int next_fast_n(int n) {
|
||||
return next_power_of_2(n);
|
||||
}
|
||||
class FFTPlan {
|
||||
public:
|
||||
enum FFTType {
|
||||
UNSUPPORTED,
|
||||
NOOP,
|
||||
STOCKHAM,
|
||||
RADER,
|
||||
BLUESTEIN,
|
||||
MULTIUPLOAD_BLUESTEIN,
|
||||
SMALL_FOUR_STEP,
|
||||
LARGE_FOUR_STEP
|
||||
};
|
||||
|
||||
FFTPlan(int n) : n_(n) {
|
||||
// NOOP
|
||||
if (n == 1) {
|
||||
type_ = NOOP;
|
||||
}
|
||||
|
||||
// Too large for Stockham so do four step fft for powers of 2
|
||||
else if (n > MAX_STOCKHAM_FFT_SIZE && is_power_of_2(n)) {
|
||||
if (n <= 1 << 20) {
|
||||
type_ = SMALL_FOUR_STEP;
|
||||
n2_ = n > 65536 ? 1024 : 64;
|
||||
n1_ = n / n2_;
|
||||
steps1_ = stockham_decompose(n1_);
|
||||
steps2_ = stockham_decompose(n2_);
|
||||
} else {
|
||||
type_ = LARGE_FOUR_STEP;
|
||||
}
|
||||
}
|
||||
|
||||
// Too large and not power of 2 so do multi-upload Bluestein fft
|
||||
else if (n > MAX_STOCKHAM_FFT_SIZE) {
|
||||
type_ = MULTIUPLOAD_BLUESTEIN;
|
||||
bluestein_n_ = next_fast_n(2 * n - 1);
|
||||
}
|
||||
|
||||
// Stockham fft
|
||||
else if (auto steps = stockham_decompose(n); steps.size() > 0) {
|
||||
type_ = STOCKHAM;
|
||||
steps_ = steps;
|
||||
}
|
||||
|
||||
// Add rader but for now simply fall back to bluestein when stockham not
|
||||
// posssible
|
||||
else if (n > MAX_BLUESTEIN_FFT_SIZE) {
|
||||
type_ = MULTIUPLOAD_BLUESTEIN;
|
||||
bluestein_n_ = next_fast_n(2 * n - 1);
|
||||
} else {
|
||||
type_ = BLUESTEIN;
|
||||
bluestein_n_ = next_fast_n(2 * n - 1);
|
||||
steps_ = stockham_decompose(bluestein_n_);
|
||||
}
|
||||
}
|
||||
|
||||
FFTType type() const {
|
||||
return type_;
|
||||
}
|
||||
|
||||
int size() const {
|
||||
return n_;
|
||||
}
|
||||
|
||||
const std::vector<int>& steps() const {
|
||||
return steps_;
|
||||
}
|
||||
|
||||
int first_size() const {
|
||||
return n1_;
|
||||
}
|
||||
|
||||
const std::vector<int>& first_steps() const {
|
||||
return steps1_;
|
||||
}
|
||||
|
||||
int second_size() const {
|
||||
return n2_;
|
||||
}
|
||||
|
||||
const std::vector<int>& second_steps() const {
|
||||
return steps2_;
|
||||
}
|
||||
|
||||
int bluestein_size() const {
|
||||
return bluestein_n_;
|
||||
}
|
||||
|
||||
private:
|
||||
int n_;
|
||||
FFTType type_;
|
||||
std::vector<int> steps_;
|
||||
int n1_;
|
||||
std::vector<int> steps1_;
|
||||
int n2_;
|
||||
std::vector<int> steps2_;
|
||||
int bluestein_n_;
|
||||
};
|
||||
|
||||
std::vector<int> plan_stockham_fft(int n) {
|
||||
auto radices = supported_radices();
|
||||
@@ -110,15 +237,12 @@ std::vector<int> plan_stockham_fft(int n) {
|
||||
throw std::runtime_error("Unplannable");
|
||||
}
|
||||
|
||||
FFTPlan plan_fft(int n) {
|
||||
OldFFTPlan plan_fft(int n) {
|
||||
auto radices = supported_radices();
|
||||
std::set<int> radices_set(radices.begin(), radices.end());
|
||||
|
||||
FFTPlan plan;
|
||||
OldFFTPlan plan;
|
||||
plan.n = n;
|
||||
plan.rader = std::vector<int>(radices.size(), 0);
|
||||
auto factors = prime_factors(n);
|
||||
int remaining_n = n;
|
||||
|
||||
// Four Step FFT when N is too large for shared mem.
|
||||
if (n > MAX_STOCKHAM_FFT_SIZE && is_power_of_2(n)) {
|
||||
@@ -128,16 +252,20 @@ FFTPlan plan_fft(int n) {
|
||||
plan.n2 = n > 65536 ? 1024 : 64;
|
||||
plan.n1 = n / plan.n2;
|
||||
return plan;
|
||||
} else if (n > MAX_STOCKHAM_FFT_SIZE) {
|
||||
}
|
||||
|
||||
if (n > MAX_STOCKHAM_FFT_SIZE) {
|
||||
// Otherwise we use a multi-upload Bluestein's
|
||||
plan.four_step = true;
|
||||
plan.bluestein_n = next_fast_n(2 * n - 1);
|
||||
return plan;
|
||||
}
|
||||
|
||||
int remaining_n = n;
|
||||
auto factors = prime_factors(n);
|
||||
for (int factor : factors) {
|
||||
// Make sure the factor is a supported radix
|
||||
if (radices_set.find(factor) == radices_set.end()) {
|
||||
if (std::find(radices.begin(), radices.end(), factor) == radices.end()) {
|
||||
// We only support a single Rader factor currently
|
||||
// TODO(alexbarron) investigate weirdness with large
|
||||
// Rader sizes -- possibly a compiler issue?
|
||||
@@ -154,7 +282,7 @@ FFTPlan plan_fft(int n) {
|
||||
for (int rf : rader_factors) {
|
||||
// We don't nest Rader's algorithm so if `factor - 1`
|
||||
// isn't Stockham decomposable we give up and do Bluestein's.
|
||||
if (radices_set.find(rf) == radices_set.end()) {
|
||||
if (std::find(radices.begin(), radices.end(), rf) == radices.end()) {
|
||||
plan.four_step = n > MAX_BLUESTEIN_FFT_SIZE;
|
||||
plan.bluestein_n = next_fast_n(2 * n - 1);
|
||||
plan.stockham = plan_stockham_fft(plan.bluestein_n);
|
||||
@@ -172,7 +300,7 @@ FFTPlan plan_fft(int n) {
|
||||
return plan;
|
||||
}
|
||||
|
||||
int compute_elems_per_thread(FFTPlan plan) {
|
||||
int compute_elems_per_thread(OldFFTPlan plan) {
|
||||
// Heuristics for selecting an efficient number
|
||||
// of threads to use for a particular mixed-radix FFT.
|
||||
auto n = plan.n;
|
||||
@@ -355,21 +483,17 @@ void multi_upload_bluestein_fft(
|
||||
size_t axis,
|
||||
bool inverse,
|
||||
bool real,
|
||||
FFTPlan& plan,
|
||||
std::vector<array> copies,
|
||||
OldFFTPlan& plan,
|
||||
std::vector<array>& copies,
|
||||
const Stream& s) {
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
// TODO(alexbarron) Implement fused kernels for mutli upload bluestein's
|
||||
// algorithm
|
||||
int n = inverse ? out.shape(axis) : in.shape(axis);
|
||||
auto [w_k, w_q] = compute_bluestein_constants(n, plan.bluestein_n);
|
||||
|
||||
// Broadcast w_q and w_k to the batch size
|
||||
Strides b_strides(in.ndim(), 0);
|
||||
b_strides[axis] = 1;
|
||||
array w_k_broadcast({}, complex64, nullptr, {});
|
||||
array w_q_broadcast({}, complex64, nullptr, {});
|
||||
w_k_broadcast.copy_shared_buffer(w_k, b_strides, {}, w_k.data_size());
|
||||
w_q_broadcast.copy_shared_buffer(w_q, b_strides, {}, w_q.data_size());
|
||||
copies.push_back(w_k);
|
||||
copies.push_back(w_q);
|
||||
|
||||
auto temp_shape = inverse ? out.shape() : in.shape();
|
||||
array temp(temp_shape, complex64, nullptr, {});
|
||||
@@ -378,13 +502,13 @@ void multi_upload_bluestein_fft(
|
||||
if (real && !inverse) {
|
||||
// Convert float32->complex64
|
||||
copy_gpu(in, temp, CopyType::General, s);
|
||||
copies.push_back(temp);
|
||||
} else if (real && inverse) {
|
||||
int back_offset = n % 2 == 0 ? 2 : 1;
|
||||
auto slice_shape = in.shape();
|
||||
slice_shape[axis] -= back_offset;
|
||||
array slice_temp(slice_shape, complex64, nullptr, {});
|
||||
array conj_temp(in.shape(), complex64, nullptr, {});
|
||||
copies.push_back(slice_temp);
|
||||
copies.push_back(conj_temp);
|
||||
|
||||
Shape rstarts(in.ndim(), 0);
|
||||
@@ -394,19 +518,28 @@ void multi_upload_bluestein_fft(
|
||||
unary_op_gpu({in}, conj_temp, "Conjugate", s);
|
||||
slice_gpu(in, slice_temp, rstarts, rstrides, s);
|
||||
concatenate_gpu({conj_temp, slice_temp}, temp, (int)axis, s);
|
||||
copies.push_back(temp);
|
||||
} else if (inverse) {
|
||||
unary_op_gpu({in}, temp, "Conjugate", s);
|
||||
copies.push_back(temp);
|
||||
} else {
|
||||
temp.copy_shared_buffer(in);
|
||||
}
|
||||
|
||||
Strides b_strides(in.ndim(), 0);
|
||||
b_strides[axis] = 1;
|
||||
array w_k_broadcast(temp.shape(), complex64, nullptr, {});
|
||||
w_k_broadcast.copy_shared_buffer(w_k, b_strides, {}, w_k.data_size());
|
||||
binary_op_gpu({temp, w_k_broadcast}, temp1, "Multiply", s);
|
||||
|
||||
std::vector<std::pair<int, int>> pads;
|
||||
auto padded_shape = out.shape();
|
||||
padded_shape[axis] = plan.bluestein_n;
|
||||
array pad_temp(padded_shape, complex64, nullptr, {});
|
||||
pad_gpu(temp1, array(complex64_t{0.0f, 0.0f}), pad_temp, {(int)axis}, {0}, s);
|
||||
auto zero = array(complex64_t{0.0f, 0.0f});
|
||||
copies.push_back(zero);
|
||||
pad_gpu(temp1, zero, pad_temp, {(int)axis}, {0}, s);
|
||||
copies.push_back(pad_temp);
|
||||
|
||||
array pad_temp1(padded_shape, complex64, nullptr, {});
|
||||
fft_op(
|
||||
@@ -417,8 +550,12 @@ void multi_upload_bluestein_fft(
|
||||
/*real=*/false,
|
||||
FourStepParams(),
|
||||
/*inplace=*/false,
|
||||
d,
|
||||
s);
|
||||
copies.push_back(pad_temp1);
|
||||
|
||||
array w_q_broadcast(pad_temp1.shape(), complex64, nullptr, {});
|
||||
w_q_broadcast.copy_shared_buffer(w_q, b_strides, {}, w_q.data_size());
|
||||
binary_op_gpu_inplace({pad_temp1, w_q_broadcast}, pad_temp, "Multiply", s);
|
||||
|
||||
fft_op(
|
||||
@@ -429,15 +566,18 @@ void multi_upload_bluestein_fft(
|
||||
/* real= */ false,
|
||||
FourStepParams(),
|
||||
/*inplace=*/true,
|
||||
d,
|
||||
s);
|
||||
|
||||
int offset = plan.bluestein_n - (2 * n - 1);
|
||||
Shape starts(in.ndim(), 0);
|
||||
Shape strides(in.ndim(), 1);
|
||||
starts[axis] = plan.bluestein_n - offset - n;
|
||||
slice_gpu(pad_temp1, temp, starts, strides, s);
|
||||
|
||||
binary_op_gpu_inplace({temp, w_k_broadcast}, temp1, "Multiply", s);
|
||||
array temp2(temp_shape, complex64, nullptr, {});
|
||||
slice_gpu(pad_temp1, temp2, starts, strides, s);
|
||||
|
||||
binary_op_gpu_inplace({temp2, w_k_broadcast}, temp1, "Multiply", s);
|
||||
|
||||
if (real && !inverse) {
|
||||
Shape rstarts(in.ndim(), 0);
|
||||
@@ -449,26 +589,21 @@ void multi_upload_bluestein_fft(
|
||||
array temp_float(out.shape(), out.dtype(), nullptr, {});
|
||||
copies.push_back(temp_float);
|
||||
copies.push_back(inv_n);
|
||||
copies.push_back(temp1);
|
||||
|
||||
copy_gpu(temp1, temp_float, CopyType::General, s);
|
||||
binary_op_gpu({temp_float, inv_n}, out, "Multiply", s);
|
||||
} else if (inverse) {
|
||||
auto inv_n = array({1.0f / n}, {1}, complex64);
|
||||
unary_op_gpu({temp1}, temp, "Conjugate", s);
|
||||
binary_op_gpu({temp, inv_n}, out, "Multiply", s);
|
||||
array temp3(temp_shape, complex64, nullptr, {});
|
||||
unary_op_gpu({temp1}, temp3, "Conjugate", s);
|
||||
binary_op_gpu({temp3, inv_n}, out, "Multiply", s);
|
||||
copies.push_back(inv_n);
|
||||
copies.push_back(temp1);
|
||||
copies.push_back(temp3);
|
||||
} else {
|
||||
out.copy_shared_buffer(temp1);
|
||||
}
|
||||
|
||||
copies.push_back(w_k);
|
||||
copies.push_back(w_q);
|
||||
copies.push_back(w_k_broadcast);
|
||||
copies.push_back(w_q_broadcast);
|
||||
copies.push_back(temp);
|
||||
copies.push_back(temp1);
|
||||
copies.push_back(pad_temp);
|
||||
copies.push_back(pad_temp1);
|
||||
}
|
||||
|
||||
void four_step_fft(
|
||||
@@ -477,9 +612,10 @@ void four_step_fft(
|
||||
size_t axis,
|
||||
bool inverse,
|
||||
bool real,
|
||||
FFTPlan& plan,
|
||||
std::vector<array> copies,
|
||||
const Stream& s) {
|
||||
OldFFTPlan& plan,
|
||||
std::vector<array>& copies,
|
||||
const Stream& s,
|
||||
bool in_place) {
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
if (plan.bluestein_n == -1) {
|
||||
@@ -489,10 +625,26 @@ void four_step_fft(
|
||||
auto temp_shape = (real && inverse) ? out.shape() : in.shape();
|
||||
array temp(temp_shape, complex64, nullptr, {});
|
||||
fft_op(
|
||||
in, temp, axis, inverse, real, four_step_params, /*inplace=*/false, s);
|
||||
in,
|
||||
temp,
|
||||
axis,
|
||||
inverse,
|
||||
real,
|
||||
four_step_params,
|
||||
/*inplace=*/false,
|
||||
d,
|
||||
s);
|
||||
four_step_params.first_step = false;
|
||||
fft_op(
|
||||
temp, out, axis, inverse, real, four_step_params, /*inplace=*/false, s);
|
||||
temp,
|
||||
out,
|
||||
axis,
|
||||
inverse,
|
||||
real,
|
||||
four_step_params,
|
||||
/*inplace=*/in_place,
|
||||
d,
|
||||
s);
|
||||
copies.push_back(temp);
|
||||
} else {
|
||||
multi_upload_bluestein_fft(in, out, axis, inverse, real, plan, copies, s);
|
||||
@@ -507,9 +659,8 @@ void fft_op(
|
||||
bool real,
|
||||
const FourStepParams four_step_params,
|
||||
bool inplace,
|
||||
metal::Device& d,
|
||||
const Stream& s) {
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
size_t n = out.dtype() == float32 ? out.shape(axis) : in.shape(axis);
|
||||
if (n == 1) {
|
||||
out.copy_shared_buffer(in);
|
||||
@@ -574,7 +725,7 @@ void fft_op(
|
||||
|
||||
auto plan = plan_fft(n);
|
||||
if (plan.four_step) {
|
||||
four_step_fft(in, out, axis, inverse, real, plan, copies, s);
|
||||
four_step_fft(in, out, axis, inverse, real, plan, copies, s, inplace);
|
||||
d.add_temporaries(std::move(copies), s.index);
|
||||
return;
|
||||
}
|
||||
@@ -744,57 +895,517 @@ void fft_op(
|
||||
d.add_temporaries(std::move(copies), s.index);
|
||||
}
|
||||
|
||||
void fft_op(
|
||||
inline int compute_elems_per_thread(int n, const std::vector<int>& steps) {
|
||||
auto radices = supported_radices();
|
||||
std::set<int> used_radices;
|
||||
for (int i = 0; i < steps.size(); i++) {
|
||||
if (steps[i] > 0) {
|
||||
used_radices.insert(radices[i % radices.size()]);
|
||||
}
|
||||
}
|
||||
|
||||
// Manual tuning for 7/11/13
|
||||
if (used_radices.find(7) != used_radices.end() &&
|
||||
(used_radices.find(11) != used_radices.end() ||
|
||||
used_radices.find(13) != used_radices.end())) {
|
||||
return 7;
|
||||
} else if (
|
||||
used_radices.find(11) != used_radices.end() &&
|
||||
used_radices.find(13) != used_radices.end()) {
|
||||
return 11;
|
||||
}
|
||||
|
||||
// TODO(alexbarron) Some really weird stuff is going on
|
||||
// for certain `elems_per_thread` on large composite n.
|
||||
// Possibly a compiler issue?
|
||||
if (n == 3159)
|
||||
return 13;
|
||||
if (n == 3645)
|
||||
return 5;
|
||||
if (n == 3969)
|
||||
return 7;
|
||||
if (n == 1982)
|
||||
return 5;
|
||||
|
||||
if (used_radices.size() == 1) {
|
||||
return *(used_radices.begin());
|
||||
}
|
||||
if (used_radices.size() == 2 &&
|
||||
(used_radices.find(11) != used_radices.end() ||
|
||||
used_radices.find(13) != used_radices.end())) {
|
||||
return std::accumulate(used_radices.begin(), used_radices.end(), 0) / 2;
|
||||
}
|
||||
|
||||
// In all other cases use the second smallest radix.
|
||||
return *(++used_radices.begin());
|
||||
}
|
||||
|
||||
inline array ensure_fastest_moving_axis(
|
||||
const array& x,
|
||||
int axis,
|
||||
metal::Device& d,
|
||||
const Stream& s) {
|
||||
// The axis is already with a stride of 1 so check that we have no overlaps
|
||||
// and broadcasting and avoid the copy.
|
||||
if (x.strides(axis) == 1) {
|
||||
// This is a fairly strict test perhaps consider relaxing it in the future.
|
||||
if (x.flags().row_contiguous || x.flags().col_contiguous) {
|
||||
return x;
|
||||
}
|
||||
}
|
||||
|
||||
// To make it the fastest moving axis simply transpose it, then copy it and
|
||||
// then transpose it back.
|
||||
|
||||
// Transpose it
|
||||
std::vector<int> axes(x.ndim(), 0);
|
||||
for (int ax = 0; ax < axes.size(); ax++) {
|
||||
axes[ax] = (ax < axis) ? ax : ax + 1;
|
||||
}
|
||||
axes.back() = axis;
|
||||
Shape xtshape;
|
||||
xtshape.reserve(axes.size());
|
||||
for (auto ax : axes) {
|
||||
xtshape.push_back(x.shape(ax));
|
||||
}
|
||||
array xt(xtshape, x.dtype(), nullptr, {});
|
||||
transpose(x, xt, axes);
|
||||
|
||||
// Copy it
|
||||
array xtc(xt.shape(), x.dtype(), nullptr, {});
|
||||
copy_gpu(
|
||||
xt,
|
||||
xtc,
|
||||
xt.flags().row_contiguous ? CopyType::Vector : CopyType::General,
|
||||
s);
|
||||
d.add_temporary(xtc, s.index);
|
||||
|
||||
// Transpose it
|
||||
for (int ax = 0; ax < axes.size(); ax++) {
|
||||
axes[ax] = (ax < axis) ? ax : ((ax == axis) ? axes.size() - 1 : ax - 1);
|
||||
}
|
||||
array y(x.shape(), x.dtype(), nullptr, {});
|
||||
transpose(xtc, y, axes);
|
||||
|
||||
return y;
|
||||
}
|
||||
|
||||
inline void prepare_output_array(const array& in, array& out, int axis) {
|
||||
// Prepare the output array such that it matches the input in terms of
|
||||
// stride ordering. Namely we might have moved `axis` around in the `in`
|
||||
// array. We must do the same in `out`. The difference is that we don't have
|
||||
// to copy anything because `out` contains garbage at the moment.
|
||||
|
||||
if (in.flags().row_contiguous && out.flags().row_contiguous) {
|
||||
return;
|
||||
}
|
||||
|
||||
std::vector<int> axes(out.ndim(), 0);
|
||||
for (int ax = 0; ax < axes.size(); ax++) {
|
||||
axes[ax] = (ax < axis) ? ax : ax + 1;
|
||||
}
|
||||
axes.back() = axis;
|
||||
as_transposed(out, axes);
|
||||
}
|
||||
|
||||
void fft_stockham_inplace(
|
||||
const FFTPlan& plan,
|
||||
const array& in_,
|
||||
array& out,
|
||||
size_t axis,
|
||||
bool inverse,
|
||||
bool real,
|
||||
metal::Device& d,
|
||||
const Stream& s) {
|
||||
// Prepare the input and output arrays such that `axis` has stride 1.
|
||||
// Possibly copy the input but never the output as it doesn't have anything
|
||||
// useful in it yet.
|
||||
array in = ensure_fastest_moving_axis(in_, axis, d, s);
|
||||
prepare_output_array(in, out, axis);
|
||||
|
||||
// Prepare the arguments for stockham fft
|
||||
int n = plan.size();
|
||||
bool power_of_2 = is_power_of_2(n);
|
||||
int total_batch_size =
|
||||
out.dtype() == float32 ? out.size() / n : in.size() / n;
|
||||
auto& steps = plan.steps();
|
||||
int elems_per_thread = compute_elems_per_thread(n, steps);
|
||||
int threads_per_fft = ceildiv(n, elems_per_thread);
|
||||
int tg_batch_size = std::max(MIN_THREADGROUP_MEM_SIZE / n, 1);
|
||||
int tg_mem_size = next_power_of_2(tg_batch_size * n);
|
||||
int batch_size = ceildiv(total_batch_size, tg_batch_size);
|
||||
batch_size = real ? ceildiv(batch_size, 2) : batch_size; // 2 RFFTs at once
|
||||
std::vector<MTLFC> func_consts = {
|
||||
{&inverse, MTL::DataType::DataTypeBool, 0},
|
||||
{&power_of_2, MTL::DataType::DataTypeBool, 1},
|
||||
{&elems_per_thread, MTL::DataType::DataTypeInt, 2}};
|
||||
for (int i = 0; i < steps.size(); i++) {
|
||||
func_consts.emplace_back(&steps[i], MTL::DataType::DataTypeInt, 4 + i);
|
||||
}
|
||||
|
||||
// Get the kernel
|
||||
auto in_type = in.dtype() == float32 ? "float" : "float2";
|
||||
auto out_type = out.dtype() == float32 ? "float" : "float2";
|
||||
std::string hash_name;
|
||||
std::string kname;
|
||||
kname.reserve(64);
|
||||
hash_name.reserve(64);
|
||||
concatenate(kname, "fft_mem_", tg_mem_size, "_", in_type, "_", out_type);
|
||||
concatenate(hash_name, kname, "_n", n, "_inv_", inverse);
|
||||
auto template_def =
|
||||
get_template_definition(kname, "fft", tg_mem_size, in_type, out_type);
|
||||
auto kernel = get_fft_kernel(d, kname, hash_name, func_consts, template_def);
|
||||
|
||||
// Launch it
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
compute_encoder.set_input_array(in, 0);
|
||||
compute_encoder.set_output_array(out, 1);
|
||||
compute_encoder.set_bytes(n, 2);
|
||||
compute_encoder.set_bytes(total_batch_size, 3);
|
||||
|
||||
MTL::Size group_dims(1, tg_batch_size, threads_per_fft);
|
||||
MTL::Size grid_dims(batch_size, tg_batch_size, threads_per_fft);
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
void fft_four_step_inplace(
|
||||
const FFTPlan& plan,
|
||||
const array& in_,
|
||||
array& out,
|
||||
size_t axis,
|
||||
bool inverse,
|
||||
bool real,
|
||||
metal::Device& d,
|
||||
const Stream& s) {
|
||||
// Prepare the input and output arrays such that `axis` has stride 1.
|
||||
// Possibly copy the input but never the output as it doesn't have anything
|
||||
// useful in it yet.
|
||||
array in = ensure_fastest_moving_axis(in_, axis, d, s);
|
||||
prepare_output_array(in, out, axis);
|
||||
|
||||
// Also prepare the intermediate array for the four-step fft which is
|
||||
// implemented with 2 kernel calls.
|
||||
array intermediate(
|
||||
(real && inverse) ? out.shape() : in.shape(), complex64, nullptr, {});
|
||||
intermediate.set_data(allocator::malloc(intermediate.nbytes()));
|
||||
prepare_output_array(in, intermediate, axis);
|
||||
d.add_temporary(intermediate, s.index);
|
||||
|
||||
// Make the two calls
|
||||
for (int step = 0; step < 2; step++) {
|
||||
// Create the parameters
|
||||
int n1 = plan.first_size();
|
||||
int n2 = plan.second_size();
|
||||
int n = (step == 0) ? n1 : n2;
|
||||
bool power_of_2 = true;
|
||||
int total_batch_size =
|
||||
out.dtype() == float32 ? out.size() / n : in.size() / n;
|
||||
auto& steps = (step == 0) ? plan.first_steps() : plan.second_steps();
|
||||
int elems_per_thread = compute_elems_per_thread(n, steps);
|
||||
int threads_per_fft = ceildiv(n, elems_per_thread);
|
||||
int tg_batch_size =
|
||||
std::max(MIN_THREADGROUP_MEM_SIZE / n, MIN_COALESCE_WIDTH);
|
||||
int tg_mem_size = next_power_of_2(tg_batch_size * n);
|
||||
int batch_size = ceildiv(total_batch_size, tg_batch_size);
|
||||
std::vector<MTLFC> func_consts = {
|
||||
{&inverse, MTL::DataType::DataTypeBool, 0},
|
||||
{&power_of_2, MTL::DataType::DataTypeBool, 1},
|
||||
{&elems_per_thread, MTL::DataType::DataTypeInt, 2}};
|
||||
for (int i = 0; i < steps.size(); i++) {
|
||||
func_consts.emplace_back(&steps[i], MTL::DataType::DataTypeInt, 4 + i);
|
||||
}
|
||||
|
||||
// Get the kernel
|
||||
auto to_type = [](const array& x) {
|
||||
return x.dtype() == float32 ? "float" : "float2";
|
||||
};
|
||||
auto in_type = step == 0 ? to_type(in) : to_type(intermediate);
|
||||
auto out_type = step == 0 ? to_type(intermediate) : to_type(out);
|
||||
std::string hash_name;
|
||||
std::string kname;
|
||||
kname.reserve(64);
|
||||
hash_name.reserve(64);
|
||||
concatenate(
|
||||
kname,
|
||||
"four_step_mem_",
|
||||
tg_mem_size,
|
||||
"_",
|
||||
in_type,
|
||||
"_",
|
||||
out_type,
|
||||
"_",
|
||||
step,
|
||||
(real ? "_true" : "_false"));
|
||||
concatenate(hash_name, kname, "_n", n, "_inv_", inverse);
|
||||
auto template_def = get_template_definition(
|
||||
kname, "four_step_fft", tg_mem_size, in_type, out_type, step, real);
|
||||
auto kernel =
|
||||
get_fft_kernel(d, kname, hash_name, func_consts, template_def);
|
||||
|
||||
// Launch it
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
compute_encoder.set_input_array((step == 0) ? in : intermediate, 0);
|
||||
compute_encoder.set_output_array((step == 0) ? intermediate : out, 1);
|
||||
compute_encoder.set_bytes(n1, 2);
|
||||
compute_encoder.set_bytes(n2, 3);
|
||||
compute_encoder.set_bytes(total_batch_size, 4);
|
||||
|
||||
MTL::Size group_dims(1, tg_batch_size, threads_per_fft);
|
||||
MTL::Size grid_dims(batch_size, tg_batch_size, threads_per_fft);
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
}
|
||||
}
|
||||
|
||||
void fft_bluestein(
|
||||
const FFTPlan& plan,
|
||||
const array& in_,
|
||||
array& out,
|
||||
size_t axis,
|
||||
bool inverse,
|
||||
bool real,
|
||||
metal::Device& d,
|
||||
const Stream& s) {
|
||||
// Prepare the input and output arrays such that `axis` has stride 1.
|
||||
// Possibly copy the input but never the output as it doesn't have anything
|
||||
// useful in it yet.
|
||||
array in = ensure_fastest_moving_axis(in_, axis, d, s);
|
||||
prepare_output_array(in, out, axis);
|
||||
|
||||
// Prepare the arguments for bluestein fft
|
||||
int n = plan.bluestein_size();
|
||||
bool power_of_2 = true;
|
||||
int total_batch_size = out.dtype() == float32 ? out.size() / plan.size()
|
||||
: in.size() / plan.size();
|
||||
auto& steps = plan.steps();
|
||||
int elems_per_thread = compute_elems_per_thread(n, steps);
|
||||
int threads_per_fft = ceildiv(n, elems_per_thread);
|
||||
int tg_batch_size = std::max(MIN_THREADGROUP_MEM_SIZE / n, 1);
|
||||
int tg_mem_size = next_power_of_2(tg_batch_size * n);
|
||||
int batch_size = ceildiv(total_batch_size, tg_batch_size);
|
||||
batch_size = real ? ceildiv(batch_size, 2) : batch_size; // 2 RFFTs at once
|
||||
std::vector<MTLFC> func_consts = {
|
||||
{&inverse, MTL::DataType::DataTypeBool, 0},
|
||||
{&power_of_2, MTL::DataType::DataTypeBool, 1},
|
||||
{&elems_per_thread, MTL::DataType::DataTypeInt, 2}};
|
||||
for (int i = 0; i < steps.size(); i++) {
|
||||
func_consts.emplace_back(&steps[i], MTL::DataType::DataTypeInt, 4 + i);
|
||||
}
|
||||
|
||||
// Get the kernel
|
||||
auto in_type = in.dtype() == float32 ? "float" : "float2";
|
||||
auto out_type = out.dtype() == float32 ? "float" : "float2";
|
||||
std::string hash_name;
|
||||
std::string kname;
|
||||
kname.reserve(64);
|
||||
hash_name.reserve(64);
|
||||
concatenate(
|
||||
kname, "bluestein_fft_mem_", tg_mem_size, "_", in_type, "_", out_type);
|
||||
concatenate(hash_name, kname, "_n", n, "_inv_", inverse);
|
||||
auto template_def = get_template_definition(
|
||||
kname, "bluestein_fft", tg_mem_size, in_type, out_type);
|
||||
auto kernel = get_fft_kernel(d, kname, hash_name, func_consts, template_def);
|
||||
|
||||
// Get the bluestein constants
|
||||
auto [w_k, w_q] =
|
||||
compute_bluestein_constants(plan.size(), plan.bluestein_size());
|
||||
d.add_temporary(w_k, s.index);
|
||||
d.add_temporary(w_q, s.index);
|
||||
|
||||
// Launch it
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
compute_encoder.set_input_array(in, 0);
|
||||
compute_encoder.set_output_array(out, 1);
|
||||
compute_encoder.set_input_array(w_q, 2);
|
||||
compute_encoder.set_input_array(w_k, 3);
|
||||
compute_encoder.set_bytes(plan.size(), 4);
|
||||
compute_encoder.set_bytes(n, 5);
|
||||
compute_encoder.set_bytes(total_batch_size, 6);
|
||||
|
||||
MTL::Size group_dims(1, tg_batch_size, threads_per_fft);
|
||||
MTL::Size grid_dims(batch_size, tg_batch_size, threads_per_fft);
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
void fft_multi_upload_bluestein(
|
||||
const FFTPlan& plan,
|
||||
const array& in_,
|
||||
array& out,
|
||||
size_t axis,
|
||||
bool inverse,
|
||||
bool real,
|
||||
metal::Device& d,
|
||||
const Stream& s) {
|
||||
// Get Bluestein's constants using the CPU (this is done in the submission
|
||||
// thread which is pretty bad).
|
||||
auto [w_k, w_q] =
|
||||
compute_bluestein_constants(plan.size(), plan.bluestein_size());
|
||||
d.add_temporary(w_k, s.index);
|
||||
d.add_temporary(w_q, s.index);
|
||||
|
||||
// Prepare the input
|
||||
auto in_shape = inverse ? out.shape() : in_.shape();
|
||||
array in(std::move(in_shape), complex64, nullptr, {});
|
||||
if (real && !inverse) {
|
||||
copy_gpu(
|
||||
in_,
|
||||
in,
|
||||
in_.flags().row_contiguous ? CopyType::Vector : CopyType::General,
|
||||
s);
|
||||
d.add_temporary(in, s.index);
|
||||
} else if (real && inverse) {
|
||||
int back_offset = plan.size() % 2 == 0 ? 2 : 1;
|
||||
auto slice_shape = in.shape();
|
||||
slice_shape[axis] -= back_offset;
|
||||
array slice_temp(slice_shape, complex64, nullptr, {});
|
||||
array conj_temp(in.shape(), complex64, nullptr, {});
|
||||
Shape rstarts(in.ndim(), 0);
|
||||
Shape rstrides(in.ndim(), 1);
|
||||
rstarts[axis] = in.shape(axis) - back_offset;
|
||||
rstrides[axis] = -1;
|
||||
unary_op_gpu({in_}, conj_temp, "Conjugate", s);
|
||||
slice_gpu(in_, slice_temp, rstarts, rstrides, s);
|
||||
concatenate_gpu({conj_temp, slice_temp}, in, (int)axis, s);
|
||||
d.add_temporary(conj_temp, s.index);
|
||||
} else if (inverse) {
|
||||
unary_op_gpu({in_}, in, "Conjugate", s);
|
||||
d.add_temporary(in, s.index);
|
||||
} else {
|
||||
in.copy_shared_buffer(in_);
|
||||
}
|
||||
|
||||
// Multiply with
|
||||
Strides b_strides(in.ndim(), 0);
|
||||
b_strides[axis] = 1;
|
||||
array w_k_broadcast(in.shape(), complex64, nullptr, {});
|
||||
w_k_broadcast.copy_shared_buffer(w_k, b_strides, {}, w_k.data_size());
|
||||
array x(in.shape(), complex64, nullptr, {});
|
||||
binary_op_gpu({in, w_k_broadcast}, x, "Multiply", s);
|
||||
d.add_temporary(x, s.index);
|
||||
|
||||
// Pad
|
||||
auto padded_shape = out.shape();
|
||||
padded_shape[axis] = plan.bluestein_size();
|
||||
array padded_x(padded_shape, complex64, nullptr, {});
|
||||
auto zero = array(complex64_t{0.0f, 0.0f});
|
||||
pad_gpu(x, zero, padded_x, {(int)axis}, {0}, s);
|
||||
d.add_temporary(zero, s.index);
|
||||
d.add_temporary(padded_x, s.index);
|
||||
|
||||
// First fft
|
||||
}
|
||||
|
||||
void fft_op_inplace(
|
||||
const array& in,
|
||||
array& out,
|
||||
size_t axis,
|
||||
bool inverse,
|
||||
bool real,
|
||||
bool inplace,
|
||||
metal::Device& d,
|
||||
const Stream& s) {
|
||||
fft_op(in, out, axis, inverse, real, FourStepParams(), inplace, s);
|
||||
// Get the FFT size and plan it
|
||||
auto plan =
|
||||
FFTPlan(out.dtype() == float32 ? out.shape(axis) : in.shape(axis));
|
||||
|
||||
switch (plan.type()) {
|
||||
case FFTPlan::NOOP:
|
||||
std::cout << "--------------> 1-size FFT <-----------------" << std::endl;
|
||||
break;
|
||||
case FFTPlan::STOCKHAM:
|
||||
return fft_stockham_inplace(plan, in, out, axis, inverse, real, d, s);
|
||||
case FFTPlan::SMALL_FOUR_STEP:
|
||||
return fft_four_step_inplace(plan, in, out, axis, inverse, real, d, s);
|
||||
case FFTPlan::BLUESTEIN:
|
||||
return fft_bluestein(plan, in, out, axis, inverse, real, d, s);
|
||||
case FFTPlan::UNSUPPORTED: {
|
||||
std::string msg;
|
||||
concatenate(msg, "FFT of size ", plan.size(), " not supported");
|
||||
throw std::runtime_error(msg);
|
||||
}
|
||||
default:
|
||||
std::cout << "----- NYI ----" << std::endl;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
void nd_fft_op(
|
||||
void nd_fft_op_inplace(
|
||||
const array& in,
|
||||
array& out,
|
||||
const std::vector<size_t>& axes,
|
||||
bool inverse,
|
||||
bool real,
|
||||
metal::Device& d,
|
||||
const Stream& s) {
|
||||
// Perform ND FFT on GPU as a series of 1D FFTs
|
||||
auto temp_shape = inverse ? in.shape() : out.shape();
|
||||
array temp1(temp_shape, complex64, nullptr, {});
|
||||
array temp2(temp_shape, complex64, nullptr, {});
|
||||
std::vector<array> temp_arrs = {temp1, temp2};
|
||||
for (int i = axes.size() - 1; i >= 0; i--) {
|
||||
int reverse_index = axes.size() - i - 1;
|
||||
// For 5D and above, we don't want to reallocate our two temporary arrays
|
||||
bool inplace = reverse_index >= 3 && i != 0;
|
||||
// Opposite order for fft vs ifft
|
||||
int index = inverse ? reverse_index : i;
|
||||
size_t axis = axes[index];
|
||||
// Mirror np.fft.(i)rfftn and perform a real transform
|
||||
// only on the final axis.
|
||||
bool step_real = (real && index == axes.size() - 1);
|
||||
auto step_shape = inverse ? out.shape(axis) : in.shape(axis);
|
||||
const array& in_arr = i == axes.size() - 1 ? in : temp_arrs[1 - i % 2];
|
||||
array& out_arr = i == 0 ? out : temp_arrs[i % 2];
|
||||
fft_op(in_arr, out_arr, axis, inverse, step_real, inplace, s);
|
||||
}
|
||||
// We are going to make and possibly reuse some intermediate arrays that will
|
||||
// hold the intermediate fft results.
|
||||
auto shape = inverse ? in.shape() : out.shape();
|
||||
std::vector<array> intermediates;
|
||||
intermediates.reserve(2);
|
||||
|
||||
auto& d = metal::device(s.device);
|
||||
d.add_temporaries(std::move(temp_arrs), s.index);
|
||||
// Utility to return either in or one of the intermediates.
|
||||
auto get_input_array = [&](int step) -> const array& {
|
||||
// The first step so use the input array
|
||||
if (step == 0) {
|
||||
return in;
|
||||
}
|
||||
|
||||
return intermediates[(step - 1) % 2];
|
||||
};
|
||||
|
||||
// Utility to return either out or one of the intermediates. It also informs
|
||||
// us if we should allocate memory for that output or there is already some
|
||||
// allocated.
|
||||
auto get_output_array = [&](int step) -> array& {
|
||||
// It is the final step so return the output array
|
||||
if (step == axes.size() - 1) {
|
||||
return out;
|
||||
}
|
||||
|
||||
// We already have made an array that we can use so go ahead and use it and
|
||||
// don't reallocate the memory.
|
||||
if (step % 2 < intermediates.size()) {
|
||||
return intermediates[step % 2];
|
||||
}
|
||||
|
||||
array x(shape, complex64, nullptr, {});
|
||||
x.set_data(allocator::malloc(x.nbytes()));
|
||||
intermediates.emplace_back(std::move(x));
|
||||
d.add_temporary(intermediates.back(), s.index);
|
||||
|
||||
return intermediates.back();
|
||||
};
|
||||
|
||||
// Perform ND FFT on GPU as a series of 1D FFTs
|
||||
for (int step = 0; step < axes.size(); step++) {
|
||||
auto x = get_input_array(step);
|
||||
auto y = get_output_array(step);
|
||||
auto step_axis = axes[inverse ? step : axes.size() - step - 1];
|
||||
auto step_real = real && (inverse ? step == axes.size() - 1 : step == 0);
|
||||
fft_op_inplace(x, y, step_axis, inverse, step_real, d, s);
|
||||
}
|
||||
}
|
||||
|
||||
void FFT::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& s = stream();
|
||||
auto& d = metal::device(s.device);
|
||||
auto& in = inputs[0];
|
||||
|
||||
// The FFT ops above have the *_inplace suffix. This means that the memory
|
||||
// needs to be already allocated in the output array. Similar to
|
||||
// copy_gpu_inplace and so on.
|
||||
//
|
||||
// Even though we allocate the memory, we do not necessarily want the
|
||||
// contiguous strides so the *_inplace ops may change the strides and flags
|
||||
// of the array but will not reallocate the memory.
|
||||
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
if (axes_.size() > 1) {
|
||||
nd_fft_op(in, out, axes_, inverse_, real_, s);
|
||||
nd_fft_op_inplace(in, out, axes_, inverse_, real_, d, s);
|
||||
} else {
|
||||
fft_op(in, out, axes_[0], inverse_, real_, /*inplace=*/false, s);
|
||||
fft_op_inplace(in, out, axes_[0], inverse_, real_, d, s);
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -1,11 +1,9 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include <map>
|
||||
|
||||
#include "mlx/backend/common/compiled.h"
|
||||
#include "mlx/backend/common/hadamard.h"
|
||||
#include "mlx/backend/common/compiled.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/metal/copy.h"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/jit/includes.h"
|
||||
#include "mlx/backend/metal/kernels.h"
|
||||
@@ -15,7 +13,6 @@
|
||||
namespace mlx::core {
|
||||
|
||||
constexpr int MAX_HADAMARD_THREADS_PER_GROUP = 256;
|
||||
constexpr int MAX_HADAMARD_BYTES = 32768; // 32KB
|
||||
|
||||
std::string gen_hadamard_codelet(int m) {
|
||||
// Generate a O(m^2) hadamard codelet for a given M
|
||||
@@ -60,121 +57,142 @@ std::string gen_hadamard_codelet(int m) {
|
||||
return source.str();
|
||||
}
|
||||
|
||||
void Hadamard::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& s = stream();
|
||||
void hadamard_mn_contiguous(
|
||||
const array& x,
|
||||
array& y,
|
||||
int m,
|
||||
int n1,
|
||||
int n2,
|
||||
float scale,
|
||||
metal::Device& d,
|
||||
const Stream& s) {
|
||||
int n = n1 * n2;
|
||||
int read_width_n1 = n1 == 2 ? 2 : 4;
|
||||
int read_width_n2 = n2 == 2 ? 2 : 4;
|
||||
int read_width_m = (n == 2 || m == 28) ? 2 : 4;
|
||||
int max_radix_1 = std::min(n1, 16);
|
||||
int max_radix_2 = std::min(n2, 16);
|
||||
float scale_n1 = 1.0;
|
||||
float scale_n2 = (m == 1) ? scale : 1.0;
|
||||
float scale_m = scale;
|
||||
|
||||
auto& in = inputs[0];
|
||||
// n2 is a row contiguous power of 2 hadamard transform
|
||||
MTL::Size group_dims_n2(n2 / max_radix_2, 1, 1);
|
||||
MTL::Size grid_dims_n2(n2 / max_radix_2, x.size() / n2, 1);
|
||||
|
||||
std::vector<array> copies;
|
||||
// Only support the last axis for now
|
||||
int axis = in.ndim() - 1;
|
||||
auto check_input = [&copies, &s](const array& x) {
|
||||
// TODO(alexbarron) pass strides to kernel to relax this constraint
|
||||
bool no_copy = x.flags().row_contiguous;
|
||||
if (no_copy) {
|
||||
return x;
|
||||
} else {
|
||||
copies.push_back(array(x.shape(), x.dtype(), nullptr, {}));
|
||||
copy_gpu(x, copies.back(), CopyType::General, s);
|
||||
return copies.back();
|
||||
// n1 is a strided power of 2 hadamard transform with stride n2
|
||||
MTL::Size group_dims_n1(n1 / max_radix_1, 1, 1);
|
||||
MTL::Size grid_dims_n1(n1 / max_radix_1, x.size() / n, n2);
|
||||
|
||||
// m is a strided hadamard transform with stride n = n1 * n2
|
||||
MTL::Size group_dims_m(
|
||||
std::min(n / read_width_m, MAX_HADAMARD_THREADS_PER_GROUP), 1, 1);
|
||||
MTL::Size grid_dims_m(
|
||||
group_dims_m.width, x.size() / m / read_width_m / group_dims_m.width, 1);
|
||||
|
||||
// Make the kernel
|
||||
std::string kname;
|
||||
kname.reserve(32);
|
||||
concatenate(kname, "hadamard_", n * m, "_", type_to_name(x));
|
||||
auto lib = d.get_library(kname, [&]() {
|
||||
std::string kernel;
|
||||
concatenate(
|
||||
kernel,
|
||||
metal::utils(),
|
||||
gen_hadamard_codelet(m),
|
||||
metal::hadamard(),
|
||||
get_template_definition(
|
||||
"n2" + kname,
|
||||
"hadamard_n",
|
||||
get_type_string(x.dtype()),
|
||||
n2,
|
||||
max_radix_2,
|
||||
read_width_n2));
|
||||
if (n1 > 1) {
|
||||
kernel += get_template_definition(
|
||||
"n1" + kname,
|
||||
"hadamard_n",
|
||||
get_type_string(x.dtype()),
|
||||
n1,
|
||||
max_radix_1,
|
||||
read_width_n1,
|
||||
n2);
|
||||
}
|
||||
};
|
||||
const array& in_contiguous = check_input(in);
|
||||
|
||||
if (in_contiguous.is_donatable()) {
|
||||
out.copy_shared_buffer(in_contiguous);
|
||||
} else {
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
}
|
||||
|
||||
int n, m;
|
||||
std::tie(n, m) = decompose_hadamard(in.shape(axis));
|
||||
|
||||
if (n * (int)size_of(in.dtype()) > MAX_HADAMARD_BYTES) {
|
||||
throw std::invalid_argument(
|
||||
"[hadamard] For n = m*2^k, 2^k > 8192 for FP32 or 2^k > 16384 for FP16/BF16 NYI");
|
||||
}
|
||||
|
||||
int max_radix = std::min(n, 16);
|
||||
// Use read_width 2 for m = 28 to avoid register spilling
|
||||
int read_width = (n == 2 || m == 28) ? 2 : 4;
|
||||
|
||||
std::ostringstream kname;
|
||||
kname << "hadamard_" << n * m << "_" << type_to_name(out);
|
||||
auto kernel_name = kname.str();
|
||||
auto& d = metal::device(s.device);
|
||||
const auto& lib_name = kernel_name;
|
||||
auto lib = d.get_library(lib_name, [&]() {
|
||||
std::ostringstream kernel_source;
|
||||
auto codelet = gen_hadamard_codelet(m);
|
||||
kernel_source << metal::utils() << codelet << metal::hadamard();
|
||||
kernel_source << get_template_definition(
|
||||
"n" + kernel_name,
|
||||
"hadamard_n",
|
||||
get_type_string(in.dtype()),
|
||||
n,
|
||||
max_radix,
|
||||
read_width);
|
||||
kernel_source << get_template_definition(
|
||||
"m" + kernel_name,
|
||||
"hadamard_m",
|
||||
get_type_string(in.dtype()),
|
||||
n,
|
||||
m,
|
||||
read_width);
|
||||
return kernel_source.str();
|
||||
if (m > 1) {
|
||||
kernel += get_template_definition(
|
||||
"m" + kname,
|
||||
"hadamard_m",
|
||||
get_type_string(x.dtype()),
|
||||
n,
|
||||
m,
|
||||
read_width_m);
|
||||
}
|
||||
return kernel;
|
||||
});
|
||||
|
||||
int batch_size = in.size() / n;
|
||||
int threads_per = n / max_radix;
|
||||
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
|
||||
auto launch_hadamard = [&](const array& in,
|
||||
array& out,
|
||||
const std::string& kernel_name,
|
||||
float scale) {
|
||||
auto kernel = d.get_kernel(kernel_name, lib);
|
||||
assert(threads_per <= kernel->maxTotalThreadsPerThreadgroup());
|
||||
|
||||
// Launch the strided transform for n1
|
||||
if (n1 > 1) {
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel("n1" + kname, lib);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
compute_encoder.set_input_array(in, 0);
|
||||
compute_encoder.set_output_array(out, 1);
|
||||
compute_encoder.set_bytes(scale, 2);
|
||||
|
||||
MTL::Size group_dims = MTL::Size(1, threads_per, 1);
|
||||
MTL::Size grid_dims = MTL::Size(batch_size, threads_per, 1);
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
};
|
||||
|
||||
if (m > 1) {
|
||||
// When m is greater than 1, we decompose the
|
||||
// computation into two uploads to the GPU:
|
||||
//
|
||||
// e.g. len(x) = 12*4 = 48, m = 12, n = 4
|
||||
//
|
||||
// y = h48 @ x
|
||||
//
|
||||
// Upload 1:
|
||||
// tmp = a.reshape(12, 4) @ h4
|
||||
//
|
||||
// Upload 2:
|
||||
// y = h12 @ tmp
|
||||
array temp(in.shape(), in.dtype(), nullptr, {});
|
||||
temp.set_data(allocator::malloc(temp.nbytes()));
|
||||
copies.push_back(temp);
|
||||
|
||||
launch_hadamard(in_contiguous, temp, "n" + kernel_name, 1.0);
|
||||
|
||||
// Metal sometimes reports 256 max threads per group for hadamard_m kernel
|
||||
threads_per = std::min(n / read_width, MAX_HADAMARD_THREADS_PER_GROUP);
|
||||
batch_size = in.size() / m / read_width / threads_per;
|
||||
launch_hadamard(temp, out, "m" + kernel_name, scale_);
|
||||
} else {
|
||||
launch_hadamard(in_contiguous, out, "n" + kernel_name, scale_);
|
||||
compute_encoder.set_input_array(x, 0);
|
||||
compute_encoder.set_output_array(y, 1);
|
||||
compute_encoder.set_bytes(scale_n1, 2);
|
||||
compute_encoder.dispatch_threads(grid_dims_n1, group_dims_n1);
|
||||
}
|
||||
|
||||
d.add_temporaries(std::move(copies), s.index);
|
||||
// Launch the transform for n2
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel("n2" + kname, lib);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
compute_encoder.set_input_array(n1 > 1 ? y : x, 0);
|
||||
compute_encoder.set_output_array(y, 1);
|
||||
compute_encoder.set_bytes(scale_n2, 2);
|
||||
compute_encoder.dispatch_threads(grid_dims_n2, group_dims_n2);
|
||||
|
||||
// Launch the strided transform for m
|
||||
if (m > 1) {
|
||||
auto kernel = d.get_kernel("m" + kname, lib);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
compute_encoder.set_input_array(y, 0);
|
||||
compute_encoder.set_output_array(y, 1);
|
||||
compute_encoder.set_bytes(scale_m, 2);
|
||||
compute_encoder.dispatch_threads(grid_dims_m, group_dims_m);
|
||||
}
|
||||
}
|
||||
|
||||
void Hadamard::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& s = stream();
|
||||
auto& d = metal::device(s.device);
|
||||
auto& in = inputs[0];
|
||||
|
||||
// Split the hadamard transform so that all of them work on vectors smaller
|
||||
// than 8192 elements.
|
||||
//
|
||||
// We decompose it in the following way:
|
||||
//
|
||||
// n = m * n1 * n2 = m * 2^k1 * 2^k2
|
||||
//
|
||||
// where m is in (1, 12, 20, 28) and n1 and n2 <= 8192
|
||||
auto [n, m] = decompose_hadamard(in.shape().back());
|
||||
int n1 = 1, n2 = n;
|
||||
if (n > 8192) {
|
||||
for (n2 = 2; n2 * n2 < n; n2 *= 2) {
|
||||
}
|
||||
n1 = n / n2;
|
||||
}
|
||||
|
||||
if (in.flags().row_contiguous) {
|
||||
if (in.is_donatable()) {
|
||||
out.copy_shared_buffer(in);
|
||||
} else {
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
}
|
||||
hadamard_mn_contiguous(in, out, m, n1, n2, scale_, d, s);
|
||||
} else {
|
||||
copy_gpu(in, out, CopyType::General, s);
|
||||
hadamard_mn_contiguous(out, out, m, n1, n2, scale_, d, s);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -2,7 +2,7 @@
|
||||
#include <fmt/format.h>
|
||||
|
||||
#include "mlx/backend/common/compiled.h"
|
||||
#include "mlx/backend/metal/copy.h"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/jit/includes.h"
|
||||
#include "mlx/backend/metal/jit/indexing.h"
|
||||
|
@@ -1,9 +0,0 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
constexpr std::string_view arange_kernels = R"(
|
||||
template [[host_name("{0}")]] [[kernel]] void arange<{1}>(
|
||||
constant const {1}& start,
|
||||
constant const {1}& step,
|
||||
device {1}* out,
|
||||
uint index [[thread_position_in_grid]]);
|
||||
)";
|
@@ -20,6 +20,7 @@ const char* copy();
|
||||
const char* fft();
|
||||
const char* gather_axis();
|
||||
const char* hadamard();
|
||||
const char* logsumexp();
|
||||
const char* quantized();
|
||||
const char* ternary();
|
||||
const char* scan();
|
||||
@@ -32,6 +33,7 @@ const char* gemm();
|
||||
const char* steel_gemm_fused();
|
||||
const char* steel_gemm_masked();
|
||||
const char* steel_gemm_splitk();
|
||||
const char* steel_gemm_gather();
|
||||
const char* conv();
|
||||
const char* steel_conv();
|
||||
const char* steel_conv_general();
|
||||
|
@@ -1,23 +0,0 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
constexpr std::string_view softmax_kernels = R"(
|
||||
template [[host_name("block_{0}")]] [[kernel]] void
|
||||
softmax_single_row<{1}, {2}>(
|
||||
const device {1}* in,
|
||||
device {1}* out,
|
||||
constant int& axis_size,
|
||||
uint gid [[thread_position_in_grid]],
|
||||
uint _lid [[thread_position_in_threadgroup]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
|
||||
template [[host_name("looped_{0}")]] [[kernel]] void
|
||||
softmax_looped<{1}, {2}>(
|
||||
const device {1}* in,
|
||||
device {1}* out,
|
||||
constant int& axis_size,
|
||||
uint gid [[threadgroup_position_in_grid]],
|
||||
uint lid [[thread_position_in_threadgroup]],
|
||||
uint lsize [[threads_per_threadgroup]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
|
||||
)";
|
@@ -1,8 +1,6 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
#include "mlx/backend/common/compiled.h"
|
||||
#include "mlx/backend/metal/jit/arange.h"
|
||||
#include "mlx/backend/metal/jit/includes.h"
|
||||
#include "mlx/backend/metal/jit/softmax.h"
|
||||
#include "mlx/backend/metal/kernels.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
|
||||
@@ -21,13 +19,11 @@ MTL::ComputePipelineState* get_arange_kernel(
|
||||
const std::string& kernel_name,
|
||||
const array& out) {
|
||||
auto lib = d.get_library(kernel_name, [&]() {
|
||||
std::ostringstream kernel_source;
|
||||
kernel_source << metal::utils() << metal::arange()
|
||||
<< fmt::format(
|
||||
arange_kernels,
|
||||
kernel_name,
|
||||
get_type_string(out.dtype()));
|
||||
return kernel_source.str();
|
||||
std::string kernel_source = metal::utils();
|
||||
kernel_source += metal::arange();
|
||||
kernel_source += get_template_definition(
|
||||
kernel_name, "arange", get_type_string(out.dtype()));
|
||||
return kernel_source;
|
||||
});
|
||||
return d.get_kernel(kernel_name, lib);
|
||||
}
|
||||
@@ -259,14 +255,34 @@ MTL::ComputePipelineState* get_softmax_kernel(
|
||||
const array& out) {
|
||||
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
|
||||
auto lib = d.get_library(lib_name, [&] {
|
||||
std::ostringstream kernel_source;
|
||||
kernel_source << metal::utils() << metal::softmax()
|
||||
<< fmt::format(
|
||||
softmax_kernels,
|
||||
lib_name,
|
||||
get_type_string(out.dtype()),
|
||||
get_type_string(precise ? float32 : out.dtype()));
|
||||
return kernel_source.str();
|
||||
std::string kernel_source = metal::utils();
|
||||
auto in_type = get_type_string(out.dtype());
|
||||
auto acc_type = get_type_string(precise ? float32 : out.dtype());
|
||||
kernel_source += metal::softmax();
|
||||
kernel_source += get_template_definition(
|
||||
"block_" + lib_name, "softmax_single_row", in_type, acc_type);
|
||||
kernel_source += get_template_definition(
|
||||
"looped_" + lib_name, "softmax_looped", in_type, acc_type);
|
||||
return kernel_source;
|
||||
});
|
||||
return d.get_kernel(kernel_name, lib);
|
||||
}
|
||||
|
||||
MTL::ComputePipelineState* get_logsumexp_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const array& out) {
|
||||
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
|
||||
auto lib = d.get_library(lib_name, [&] {
|
||||
auto t_str = get_type_string(out.dtype());
|
||||
std::string kernel_source;
|
||||
kernel_source = metal::utils();
|
||||
kernel_source += metal::logsumexp();
|
||||
kernel_source +=
|
||||
get_template_definition("block_" + lib_name, "logsumexp", t_str);
|
||||
kernel_source += get_template_definition(
|
||||
"looped_" + lib_name, "logsumexp_looped", t_str);
|
||||
return kernel_source;
|
||||
});
|
||||
return d.get_kernel(kernel_name, lib);
|
||||
}
|
||||
@@ -568,6 +584,44 @@ MTL::ComputePipelineState* get_steel_gemm_masked_kernel(
|
||||
return d.get_kernel(kernel_name, lib);
|
||||
}
|
||||
|
||||
MTL::ComputePipelineState* get_steel_gemm_gather_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const std::string& hash_name,
|
||||
const metal::MTLFCList& func_consts,
|
||||
const array& out,
|
||||
bool transpose_a,
|
||||
bool transpose_b,
|
||||
int bm,
|
||||
int bn,
|
||||
int bk,
|
||||
int wm,
|
||||
int wn,
|
||||
bool rhs) {
|
||||
const auto& lib_name = kernel_name;
|
||||
auto lib = d.get_library(lib_name, [&]() {
|
||||
std::string kernel_source;
|
||||
concatenate(
|
||||
kernel_source,
|
||||
metal::utils(),
|
||||
metal::gemm(),
|
||||
metal::steel_gemm_gather(),
|
||||
get_template_definition(
|
||||
lib_name,
|
||||
rhs ? "gather_mm_rhs" : "gather_mm",
|
||||
get_type_string(out.dtype()),
|
||||
bm,
|
||||
bn,
|
||||
bk,
|
||||
wm,
|
||||
wn,
|
||||
transpose_a,
|
||||
transpose_b));
|
||||
return kernel_source;
|
||||
});
|
||||
return d.get_kernel(kernel_name, lib, hash_name, func_consts);
|
||||
}
|
||||
|
||||
MTL::ComputePipelineState* get_gemv_masked_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
@@ -698,4 +752,43 @@ MTL::ComputePipelineState* get_quantized_kernel(
|
||||
return d.get_kernel(kernel_name, lib);
|
||||
}
|
||||
|
||||
MTL::ComputePipelineState* get_gather_qmm_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const std::string& hash_name,
|
||||
const metal::MTLFCList& func_consts,
|
||||
const array& x,
|
||||
int group_size,
|
||||
int bits,
|
||||
int bm,
|
||||
int bn,
|
||||
int bk,
|
||||
int wm,
|
||||
int wn,
|
||||
bool transpose) {
|
||||
const auto& lib_name = kernel_name;
|
||||
auto lib = d.get_library(lib_name, [&]() {
|
||||
std::string kernel_source;
|
||||
concatenate(
|
||||
kernel_source,
|
||||
metal::utils(),
|
||||
metal::gemm(),
|
||||
metal::quantized(),
|
||||
get_template_definition(
|
||||
lib_name,
|
||||
"gather_qmm_rhs",
|
||||
get_type_string(x.dtype()),
|
||||
group_size,
|
||||
bits,
|
||||
bm,
|
||||
bn,
|
||||
bk,
|
||||
wm,
|
||||
wn,
|
||||
transpose));
|
||||
return kernel_source;
|
||||
});
|
||||
return d.get_kernel(kernel_name, lib, hash_name, func_consts);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -59,6 +59,11 @@ MTL::ComputePipelineState* get_softmax_kernel(
|
||||
bool precise,
|
||||
const array& out);
|
||||
|
||||
MTL::ComputePipelineState* get_logsumexp_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const array& out);
|
||||
|
||||
MTL::ComputePipelineState* get_scan_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
@@ -155,6 +160,21 @@ MTL::ComputePipelineState* get_steel_gemm_masked_kernel(
|
||||
bool mn_aligned,
|
||||
bool k_aligned);
|
||||
|
||||
MTL::ComputePipelineState* get_steel_gemm_gather_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const std::string& hash_name,
|
||||
const metal::MTLFCList& func_consts,
|
||||
const array& out,
|
||||
bool transpose_a,
|
||||
bool transpose_b,
|
||||
int bm,
|
||||
int bn,
|
||||
int bk,
|
||||
int wm,
|
||||
int wn,
|
||||
bool rhs);
|
||||
|
||||
MTL::ComputePipelineState* get_steel_conv_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
@@ -204,6 +224,21 @@ MTL::ComputePipelineState* get_quantized_kernel(
|
||||
const std::string& kernel_name,
|
||||
const std::string& template_def);
|
||||
|
||||
MTL::ComputePipelineState* get_gather_qmm_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const std::string& hash_name,
|
||||
const metal::MTLFCList& func_consts,
|
||||
const array& x,
|
||||
int group_size,
|
||||
int bits,
|
||||
int bm,
|
||||
int bn,
|
||||
int bk,
|
||||
int wm,
|
||||
int wn,
|
||||
bool transpose);
|
||||
|
||||
// Create a GPU kernel template definition for JIT compilation
|
||||
template <typename... Args>
|
||||
std::string
|
||||
|
@@ -13,6 +13,10 @@ function(build_kernel_base TARGET SRCFILE DEPS)
|
||||
if(MLX_METAL_DEBUG)
|
||||
set(METAL_FLAGS ${METAL_FLAGS} -gline-tables-only -frecord-sources)
|
||||
endif()
|
||||
if(NOT CMAKE_OSX_DEPLOYMENT_TARGET STREQUAL "")
|
||||
set(METAL_FLAGS ${METAL_FLAGS}
|
||||
"-mmacosx-version-min=${CMAKE_OSX_DEPLOYMENT_TARGET}")
|
||||
endif()
|
||||
if(MLX_METAL_VERSION GREATER_EQUAL 310)
|
||||
set(VERSION_INCLUDES
|
||||
${PROJECT_SOURCE_DIR}/mlx/backend/metal/kernels/metal_3_1)
|
||||
@@ -65,6 +69,7 @@ set(STEEL_HEADERS
|
||||
steel/gemm/loader.h
|
||||
steel/gemm/transforms.h
|
||||
steel/gemm/kernels/steel_gemm_fused.h
|
||||
steel/gemm/kernels/steel_gemm_gather.h
|
||||
steel/gemm/kernels/steel_gemm_masked.h
|
||||
steel/gemm/kernels/steel_gemm_splitk.h
|
||||
steel/utils/type_traits.h
|
||||
@@ -105,12 +110,14 @@ if(NOT MLX_METAL_JIT)
|
||||
build_kernel(quantized quantized.h ${STEEL_HEADERS})
|
||||
build_kernel(scan scan.h)
|
||||
build_kernel(softmax softmax.h)
|
||||
build_kernel(logsumexp logsumexp.h)
|
||||
build_kernel(sort sort.h)
|
||||
build_kernel(ternary ternary.h ternary_ops.h)
|
||||
build_kernel(unary unary.h unary_ops.h)
|
||||
build_kernel(steel/conv/kernels/steel_conv ${STEEL_HEADERS})
|
||||
build_kernel(steel/conv/kernels/steel_conv_general ${STEEL_HEADERS})
|
||||
build_kernel(steel/gemm/kernels/steel_gemm_fused ${STEEL_HEADERS})
|
||||
build_kernel(steel/gemm/kernels/steel_gemm_gather ${STEEL_HEADERS})
|
||||
build_kernel(steel/gemm/kernels/steel_gemm_masked ${STEEL_HEADERS})
|
||||
build_kernel(steel/gemm/kernels/steel_gemm_splitk ${STEEL_HEADERS})
|
||||
build_kernel(gemv_masked steel/utils.h)
|
||||
|
@@ -5,11 +5,7 @@
|
||||
#include "mlx/backend/metal/kernels/arange.h"
|
||||
|
||||
#define instantiate_arange(tname, type) \
|
||||
template [[host_name("arange" #tname)]] [[kernel]] void arange<type>( \
|
||||
constant const type& start, \
|
||||
constant const type& step, \
|
||||
device type* out, \
|
||||
uint index [[thread_position_in_grid]]);
|
||||
instantiate_kernel("arange" #tname, arange, type)
|
||||
|
||||
instantiate_arange(uint8, uint8_t)
|
||||
instantiate_arange(uint16, uint16_t)
|
||||
|
@@ -9,64 +9,85 @@ template <typename T, typename U, typename Op>
|
||||
c[index] = Op()(a[0], b[0]);
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
|
||||
[[kernel]] void binary_sv(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
device U* c,
|
||||
constant uint& size,
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
c[index] = Op()(a[0], b[index]);
|
||||
index *= N;
|
||||
for (int i = 0; i < N && (index + i) < size; ++i) {
|
||||
c[index + i] = Op()(a[0], b[index + i]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
|
||||
[[kernel]] void binary_vs(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
device U* c,
|
||||
constant uint& size,
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
c[index] = Op()(a[index], b[0]);
|
||||
index *= N;
|
||||
for (int i = 0; i < N && (index + i) < size; ++i) {
|
||||
c[index + i] = Op()(a[index + i], b[0]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
|
||||
[[kernel]] void binary_vv(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
device U* c,
|
||||
constant uint& size,
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
c[index] = Op()(a[index], b[index]);
|
||||
index *= N;
|
||||
for (int i = 0; i < N && (index + i) < size; ++i) {
|
||||
c[index + i] = Op()(a[index + i], b[index + i]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
|
||||
[[kernel]] void binary_sv2(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
device U* c,
|
||||
constant int64_t& size,
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]) {
|
||||
int64_t offset = index.x + grid_dim.x * int64_t(index.y);
|
||||
c[offset] = Op()(a[0], b[offset]);
|
||||
int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
|
||||
for (int i = 0; i < N && (offset + i) < size; ++i) {
|
||||
c[offset + i] = Op()(a[0], b[offset + i]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
|
||||
[[kernel]] void binary_vs2(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
device U* c,
|
||||
constant int64_t& size,
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]) {
|
||||
int64_t offset = index.x + grid_dim.x * int64_t(index.y);
|
||||
c[offset] = Op()(a[offset], b[0]);
|
||||
int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
|
||||
for (int i = 0; i < N && (offset + i) < size; ++i) {
|
||||
c[offset + i] = Op()(a[offset + i], b[0]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
|
||||
[[kernel]] void binary_vv2(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
device U* c,
|
||||
constant int64_t& size,
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]) {
|
||||
int64_t offset = index.x + grid_dim.x * int64_t(index.y);
|
||||
c[offset] = Op()(a[offset], b[offset]);
|
||||
int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
|
||||
for (int i = 0; i < N && (offset + i) < size; ++i) {
|
||||
c[offset + i] = Op()(a[offset + i], b[offset + i]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op, typename IdxT = int64_t>
|
||||
|
@@ -71,6 +71,7 @@ instantiate_binary_types_bool(Less)
|
||||
instantiate_binary_types_bool(LessEqual)
|
||||
instantiate_binary_types_bool(NotEqual)
|
||||
instantiate_binary_float(LogAddExp)
|
||||
instantiate_binary_all(LogAddExp, complex64, complex64_t, complex64_t)
|
||||
instantiate_binary_types(Maximum)
|
||||
instantiate_binary_types(Minimum)
|
||||
instantiate_binary_types(Multiply)
|
||||
|
@@ -130,6 +130,24 @@ struct LogAddExp {
|
||||
? maxval
|
||||
: (maxval + log1p(metal::exp(minval - maxval)));
|
||||
};
|
||||
|
||||
complex64_t operator()(complex64_t x, complex64_t y) {
|
||||
if (metal::isnan(x.real) || metal::isnan(x.imag) || metal::isnan(y.real) ||
|
||||
metal::isnan(y.imag)) {
|
||||
return metal::numeric_limits<float>::quiet_NaN();
|
||||
}
|
||||
constexpr float inf = metal::numeric_limits<float>::infinity();
|
||||
complex64_t maxval = x > y ? x : y;
|
||||
complex64_t minval = x < y ? x : y;
|
||||
if (minval.real == -inf || maxval.real == inf)
|
||||
return maxval;
|
||||
float m = metal::exp(minval.real - maxval.real);
|
||||
complex64_t dexp{
|
||||
m * metal::cos(minval.imag - maxval.imag),
|
||||
m * metal::sin(minval.imag - maxval.imag),
|
||||
};
|
||||
return maxval + log1p(dexp);
|
||||
}
|
||||
};
|
||||
|
||||
struct Maximum {
|
||||
|
@@ -12,82 +12,103 @@ template <typename T, typename U, typename Op>
|
||||
d[index] = out[1];
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
|
||||
[[kernel]] void binary_sv(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
device U* c,
|
||||
device U* d,
|
||||
constant uint& size,
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
auto out = Op()(a[0], b[index]);
|
||||
c[index] = out[0];
|
||||
d[index] = out[1];
|
||||
index *= N;
|
||||
for (int i = 0; i < N && (index + i) < size; ++i) {
|
||||
auto out = Op()(a[0], b[index + i]);
|
||||
c[index + i] = out[0];
|
||||
d[index + i] = out[1];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
|
||||
[[kernel]] void binary_vs(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
device U* c,
|
||||
device U* d,
|
||||
constant uint& size,
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
auto out = Op()(a[index], b[0]);
|
||||
c[index] = out[0];
|
||||
d[index] = out[1];
|
||||
index *= N;
|
||||
for (int i = 0; i < N && (index + i) < size; ++i) {
|
||||
auto out = Op()(a[index + i], b[0]);
|
||||
c[index + i] = out[0];
|
||||
d[index + i] = out[1];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
|
||||
[[kernel]] void binary_vv(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
device U* c,
|
||||
device U* d,
|
||||
constant uint& size,
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
auto out = Op()(a[index], b[index]);
|
||||
c[index] = out[0];
|
||||
d[index] = out[1];
|
||||
index *= N;
|
||||
for (int i = 0; i < N && (index + i) < size; ++i) {
|
||||
auto out = Op()(a[index + i], b[index + i]);
|
||||
c[index + i] = out[0];
|
||||
d[index + i] = out[1];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
|
||||
[[kernel]] void binary_sv2(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
device U* c,
|
||||
device U* d,
|
||||
constant int64_t& size,
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]) {
|
||||
auto offset = index.x + grid_dim.x * int64_t(index.y);
|
||||
auto out = Op()(a[0], b[offset]);
|
||||
c[offset] = out[0];
|
||||
d[offset] = out[1];
|
||||
auto offset = N * (index.x + grid_dim.x * int64_t(index.y));
|
||||
for (int i = 0; i < N && (offset + i) < size; ++i) {
|
||||
auto out = Op()(a[0], b[offset + i]);
|
||||
c[offset + i] = out[0];
|
||||
d[offset + i] = out[1];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
|
||||
[[kernel]] void binary_vs2(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
device U* c,
|
||||
device U* d,
|
||||
constant int64_t& size,
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]) {
|
||||
auto offset = index.x + grid_dim.x * int64_t(index.y);
|
||||
auto out = Op()(a[offset], b[0]);
|
||||
c[offset] = out[0];
|
||||
d[offset] = out[1];
|
||||
auto offset = N * (index.x + grid_dim.x * int64_t(index.y));
|
||||
for (int i = 0; i < N && (offset + i) < size; ++i) {
|
||||
auto out = Op()(a[offset + i], b[0]);
|
||||
c[offset + i] = out[0];
|
||||
d[offset + i] = out[1];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
|
||||
[[kernel]] void binary_vv2(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
device U* c,
|
||||
device U* d,
|
||||
constant int64_t& size,
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]) {
|
||||
auto offset = index.x + grid_dim.x * int64_t(index.y);
|
||||
auto out = Op()(a[offset], b[offset]);
|
||||
c[offset] = out[0];
|
||||
d[offset] = out[1];
|
||||
auto offset = N * (index.x + grid_dim.x * int64_t(index.y));
|
||||
for (int i = 0; i < N && (offset + i) < size; ++i) {
|
||||
auto out = Op()(a[offset + i], b[offset + i]);
|
||||
c[offset + i] = out[0];
|
||||
d[offset + i] = out[1];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op, typename IdxT = int64_t>
|
||||
|
@@ -104,10 +104,22 @@ constexpr bool operator==(complex64_t a, complex64_t b) {
|
||||
constexpr complex64_t operator+(complex64_t a, complex64_t b) {
|
||||
return {a.real + b.real, a.imag + b.imag};
|
||||
}
|
||||
constexpr complex64_t operator+(float a, complex64_t b) {
|
||||
return {a + b.real, b.imag};
|
||||
}
|
||||
constexpr complex64_t operator+(complex64_t a, float b) {
|
||||
return {a.real + b, a.imag};
|
||||
}
|
||||
|
||||
constexpr complex64_t operator-(complex64_t a, complex64_t b) {
|
||||
return {a.real - b.real, a.imag - b.imag};
|
||||
}
|
||||
constexpr complex64_t operator-(float a, complex64_t b) {
|
||||
return {a - b.real, -b.imag};
|
||||
}
|
||||
constexpr complex64_t operator-(complex64_t a, float b) {
|
||||
return {a.real - b, a.imag};
|
||||
}
|
||||
|
||||
constexpr complex64_t operator*(complex64_t a, complex64_t b) {
|
||||
return {a.real * b.real - a.imag * b.imag, a.real * b.imag + a.imag * b.real};
|
||||
@@ -120,6 +132,13 @@ constexpr complex64_t operator/(complex64_t a, complex64_t b) {
|
||||
return {x / denom, y / denom};
|
||||
}
|
||||
|
||||
constexpr complex64_t operator/(float a, complex64_t b) {
|
||||
auto denom = b.real * b.real + b.imag * b.imag;
|
||||
auto x = a * b.real;
|
||||
auto y = -a * b.imag;
|
||||
return {x / denom, y / denom};
|
||||
}
|
||||
|
||||
constexpr complex64_t operator%(complex64_t a, complex64_t b) {
|
||||
auto real = a.real - (b.real * static_cast<int64_t>(a.real / b.real));
|
||||
auto imag = a.imag - (b.imag * static_cast<int64_t>(a.imag / b.imag));
|
||||
|
@@ -275,6 +275,128 @@ instantiate_naive_conv_2d_blocks(float32, float);
|
||||
instantiate_naive_conv_2d_blocks(float16, half);
|
||||
instantiate_naive_conv_2d_blocks(bfloat16, bfloat16_t);
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
/// Depthwise convolution kernels
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
constant int ker_h [[function_constant(00)]];
|
||||
constant int ker_w [[function_constant(01)]];
|
||||
constant int str_h [[function_constant(10)]];
|
||||
constant int str_w [[function_constant(11)]];
|
||||
constant int tgp_h [[function_constant(100)]];
|
||||
constant int tgp_w [[function_constant(101)]];
|
||||
constant bool do_flip [[function_constant(200)]];
|
||||
|
||||
constant int span_h = tgp_h * str_h + ker_h - 1;
|
||||
constant int span_w = tgp_w * str_w + ker_w - 1;
|
||||
constant int span_hw = span_h * span_w;
|
||||
|
||||
template <typename T>
|
||||
[[kernel]] void depthwise_conv_2d(
|
||||
const device T* in [[buffer(0)]],
|
||||
const device T* wt [[buffer(1)]],
|
||||
device T* out [[buffer(2)]],
|
||||
const constant MLXConvParams<2>& params [[buffer(3)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]],
|
||||
uint3 gid [[thread_position_in_grid]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||
constexpr int tc = 8;
|
||||
constexpr int tw = 8;
|
||||
constexpr int th = 4;
|
||||
|
||||
constexpr int c_per_thr = 8;
|
||||
|
||||
constexpr int TGH = th * 2 + 6;
|
||||
constexpr int TGW = tw * 2 + 6;
|
||||
constexpr int TGC = tc;
|
||||
|
||||
threadgroup T ins[TGH * TGW * TGC];
|
||||
|
||||
const int n_tgblocks_h = params.oS[0] / th;
|
||||
const int n = tid.z / n_tgblocks_h;
|
||||
const int tghid = tid.z % n_tgblocks_h;
|
||||
const int oh = tghid * th + lid.z;
|
||||
const int ow = gid.y;
|
||||
const int c = gid.x;
|
||||
|
||||
in += n * params.in_strides[0];
|
||||
|
||||
// Load in
|
||||
{
|
||||
constexpr int n_threads = th * tw * tc;
|
||||
const int tg_oh = (tghid * th) * str_h - params.pad[0];
|
||||
const int tg_ow = (tid.y * tw) * str_w - params.pad[1];
|
||||
const int tg_c = tid.x * tc;
|
||||
|
||||
const int thread_idx = simd_gid * 32 + simd_lid;
|
||||
constexpr int thr_per_hw = tc / c_per_thr;
|
||||
constexpr int hw_per_group = n_threads / thr_per_hw;
|
||||
|
||||
const int thr_c = thread_idx % thr_per_hw;
|
||||
const int thr_hw = thread_idx / thr_per_hw;
|
||||
|
||||
for (int hw = thr_hw; hw < span_hw; hw += hw_per_group) {
|
||||
const int h = hw / span_w;
|
||||
const int w = hw % span_w;
|
||||
|
||||
const int ih = tg_oh + h;
|
||||
const int iw = tg_ow + w;
|
||||
|
||||
const int in_s_offset = h * span_w * TGC + w * TGC;
|
||||
|
||||
if (ih >= 0 && ih < params.iS[0] && iw >= 0 && iw < params.iS[1]) {
|
||||
const auto in_load =
|
||||
in + ih * params.in_strides[1] + iw * params.in_strides[2] + tg_c;
|
||||
|
||||
MLX_MTL_PRAGMA_UNROLL
|
||||
for (int cc = 0; cc < c_per_thr; ++cc) {
|
||||
ins[in_s_offset + c_per_thr * thr_c + cc] =
|
||||
in_load[c_per_thr * thr_c + cc];
|
||||
}
|
||||
} else {
|
||||
MLX_MTL_PRAGMA_UNROLL
|
||||
for (int cc = 0; cc < c_per_thr; ++cc) {
|
||||
ins[in_s_offset + c_per_thr * thr_c + cc] = T(0);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
wt += c * params.wt_strides[0];
|
||||
|
||||
const auto ins_ptr =
|
||||
&ins[lid.z * str_h * span_w * TGC + lid.y * str_w * TGC + lid.x];
|
||||
float o = 0.;
|
||||
for (int h = 0; h < ker_h; ++h) {
|
||||
for (int w = 0; w < ker_w; ++w) {
|
||||
int wt_h = h;
|
||||
int wt_w = w;
|
||||
if (do_flip) {
|
||||
wt_h = ker_h - h - 1;
|
||||
wt_w = ker_w - w - 1;
|
||||
}
|
||||
auto inv = ins_ptr[h * span_w * TGC + w * TGC];
|
||||
auto wtv = wt[wt_h * ker_w + wt_w];
|
||||
o += inv * wtv;
|
||||
}
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
out += n * params.out_strides[0] + oh * params.out_strides[1] +
|
||||
ow * params.out_strides[2];
|
||||
out[c] = static_cast<T>(o);
|
||||
}
|
||||
|
||||
#define instantiate_depthconv2d(iname, itype) \
|
||||
instantiate_kernel("depthwise_conv_2d_" #iname, depthwise_conv_2d, itype)
|
||||
|
||||
instantiate_depthconv2d(float32, float);
|
||||
instantiate_depthconv2d(float16, half);
|
||||
instantiate_depthconv2d(bfloat16, bfloat16_t);
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
/// Winograd kernels
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user