Compare commits

..

92 Commits

Author SHA1 Message Date
Jagrit Digani
4c46e17a5d Update benchmark output 2025-04-15 10:50:06 -07:00
Angelos Katharopoulos
99eefd2ec0 Gather mm new kernel and small refactoring (#2040) 2025-04-14 16:37:36 -07:00
Yury Popov
e9e268336b LogCumSumExp (#2069) 2025-04-13 01:27:29 -07:00
Awni Hannun
7275ac7523 Fix release build (#2072) 2025-04-12 20:41:58 -07:00
Angelos Katharopoulos
c4189a38e4 Add float mask to sdpa vector (#2068) 2025-04-11 17:29:40 -07:00
Awni Hannun
68d1b3256b nit: fix exception handling (#2066) 2025-04-11 14:12:08 -07:00
Awni Hannun
9c6953bda7 Fix stubgen (#2065)
* Fix stubgen

* add multi optim to docs
2025-04-11 12:02:54 -07:00
Awni Hannun
ef7ece9851 fix fft bug (#2062) 2025-04-10 19:41:27 -07:00
Angelos Katharopoulos
ddaa4b7dcb Fix the test and add custom min/max reductions for uncommon MPI types (#2060) 2025-04-10 17:01:17 -07:00
Cheng
dfae2c6989 Fix MSVC build due to use of M_LN2 (#2058) 2025-04-10 07:41:41 -07:00
Anastasiia Filippova
515f104926 Min / max reductions (#2041) 2025-04-09 23:22:20 -07:00
Angelos Katharopoulos
9ecefd56db Do not load the default lib if another is requested (#2055) 2025-04-09 13:31:38 -07:00
Awni Hannun
e5d35aa187 no sdpa in grad (#2054) 2025-04-08 19:13:54 -07:00
Awni Hannun
00794c42bc Fix causal mask sdpa vec (#2053)
* fix sdpa vector causal mask

* test
2025-04-08 09:11:23 -07:00
Cheng
08a1bf3f10 Remove Event::Signal() (#2052) 2025-04-08 06:20:27 -07:00
Awni Hannun
60c4154346 Only request residency once (#2051) 2025-04-07 10:47:51 -07:00
Awni Hannun
f2c85308c1 add a half simd gemm fallback (#2046)
* add a half simd gemm fallback

* nit
2025-04-07 09:31:29 -07:00
Awni Hannun
1a28b69ee2 only add to residency set once (#2049) 2025-04-06 17:38:25 -07:00
Cheng
ba09f01ce8 Remove test of converting negative float to uint (#2048) 2025-04-06 06:21:46 -07:00
Cheng
6cf48872b7 wait_for_one should wait for task to finish (#2047) 2025-04-05 20:05:16 -07:00
Angelos Katharopoulos
7b3b8fa000 Fix ci release (#2045) 2025-04-04 20:25:01 -07:00
Awni Hannun
ec5e2aae61 nit in doc (#2044) 2025-04-04 12:04:17 -07:00
Awni Hannun
86389bf970 patch bump (#2043) 2025-04-03 13:15:18 -07:00
Jagrit Digani
3290bfa690 Add new sdpa function overload (#2035)
* Add new sdpa function overload

* Address comments

* Remove std::varaint from cpp sdpa function
2025-04-03 11:58:28 -07:00
Jagrit Digani
8777fd104f Depthwise Conv2D optimization (#2036)
- Add new specialized kernel for small kernel (kernels size <= 7), small strides (strides <= 2) depthwise 2d convolutions
- Add related tests
2025-04-03 09:42:04 -07:00
Awni Hannun
c41f7565ed fix softmax / logsumexp (#2042) 2025-04-03 08:32:59 -07:00
Awni Hannun
9ba81e3da4 tune quant dispatch (#2031) 2025-04-02 20:05:54 -07:00
Awni Hannun
c23888acd7 Fix build warning (#2033) 2025-04-01 14:42:27 -07:00
Awni Hannun
f98ce25ab9 fix residency set for real (#2032) 2025-04-01 12:59:48 -07:00
Awni Hannun
de5f38fd48 Custom logsumexp (#2028)
* initial custom logsumexp

* more tests

* comments + fix
2025-03-31 07:36:55 -07:00
Angelos Katharopoulos
ec2854b13a Swap -inf for finite_minimum value (#2029) 2025-03-30 21:55:04 -07:00
Stephen Panaro
90823d2938 Add missing funcs to docs (#2021) 2025-03-30 18:29:33 -07:00
Jesper Stemann Andersen
5f5770e3a2 Fix CPU sign for unsigned ints (#2024)
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2025-03-30 17:56:59 -07:00
Awni Hannun
28f39e9038 Log for complex numbers in Metal (#2025)
* Log for complex numbers in Metal

* fix log2
2025-03-30 17:04:38 -07:00
Awni Hannun
b2d2b37888 fix residency set clearing (#2027) 2025-03-30 16:27:26 -07:00
Awni Hannun
fe597e141c add pinv to doc (#2020) 2025-03-30 15:54:18 -07:00
Yi Wang
72ca1539e0 Remove unused variable in /setup.py (#2026)
This is a follow up of https://github.com/ml-explore/mlx/pull/2011
2025-03-30 12:52:33 -07:00
Awni Hannun
13b26775f1 use minimum deployment target (#2016) 2025-03-28 14:31:53 -07:00
Awni Hannun
05d7118561 causal vector sdpa (#2018)
* causal vector sdpa

* get rid of memory threshold
2025-03-28 12:36:13 -07:00
Awni Hannun
98b901ad66 enable complex gemm (#2017) 2025-03-28 10:45:13 -07:00
Awni Hannun
5580b47291 iinfo and scalar overflow detection (#2009) 2025-03-27 19:54:56 -07:00
Awni Hannun
bc62932984 sdpa specialization for head dim 256 (#2007) 2025-03-27 19:31:25 -07:00
Awni Hannun
a6b5d6e759 revise cmake minimum for doctest (#2014) 2025-03-27 19:30:58 -07:00
Yi Wang
a8931306e1 Remove unused variable in CMakeBuild (#2011)
Fix https://github.com/ml-explore/mlx/issues/2010
2025-03-27 16:00:51 -07:00
Yi Wang
fecdb8717e Polish CONTRIBUTING>md (#2005) 2025-03-25 19:06:34 -07:00
Awni Hannun
916fd273ea wire cache (#2006) 2025-03-25 18:54:01 -07:00
Yi Wang
0da8506552 Update docs for extensions (#2004) 2025-03-25 18:35:03 -07:00
Cheng
eda7a7b43e Do not join threads during process exit on Windows (#1738) 2025-03-25 06:33:08 -07:00
Chunyang Wen
022eabb734 Remove unused import (#1987) 2025-03-24 20:19:32 -07:00
Awni Hannun
aba899cef8 patch bump (#2000) 2025-03-24 12:47:05 -07:00
Jagrit Digani
6a40e1c176 Fix looping limit in causal attention (#1999) 2025-03-24 12:28:00 -07:00
Jesper Stemann Andersen
9307b2ab8b Fixed 32-bit platform support for distributed/ring implementation (#1996)
Replaced unsigned long integer literals with size_t literals in ring implementation, e.g., 1UL with size_t(1).
2025-03-24 08:08:40 -07:00
Jesper Stemann Andersen
522d8d3917 Added missing netinet/in.h include that fixes build on FreeBSD (#1997)
Defines IPPROTO_TCP.
2025-03-24 08:07:34 -07:00
Awni Hannun
a84cc0123f promote mask when needed (#1998) 2025-03-23 19:58:28 -07:00
Andrey Velichkevich
f018e248cd fix(backend): Include algorithm library in Allocator (#1992)
Signed-off-by: Andrey Velichkevich <andrey.velichkevich@gmail.com>
2025-03-22 21:27:51 -07:00
Awni Hannun
cfd7237a80 fix docs (#1991) 2025-03-21 19:58:53 -07:00
Angelos Katharopoulos
4eef8102c9 Distributed layers (#1270) 2025-03-21 13:52:17 -07:00
Angelos Katharopoulos
69e4dd506b Add a ring all gather (#1985) 2025-03-21 13:36:51 -07:00
Angelos Katharopoulos
25814a9458 Disable mpi on version mismatch (#1989) 2025-03-21 13:36:26 -07:00
Awni Hannun
2a980a76ce Add stats and limit to common allocator and enable tests (#1988)
* add stats to common allocator and enable tests

* linux memory and default

* fix
2025-03-21 12:28:36 -07:00
Angelos Katharopoulos
d343782c8b Cross platform libmpi loading (#1975) 2025-03-21 11:23:10 -07:00
Awni Hannun
4e1994e9d7 move memory APIs into top level mlx.core (#1982) 2025-03-21 07:25:12 -07:00
jiyzhang
65a38c452b update the formula of smooth_l1_loss (#1986) 2025-03-21 06:25:23 -07:00
Awni Hannun
7b7e2352cd fix malloc or wait deadlock (#1976) 2025-03-20 16:48:43 -07:00
Awni Hannun
1177d28395 patch bump (#1981) 2025-03-20 15:12:22 -07:00
Awni Hannun
005e7efa64 fix mask in sdpa (#1980)
* fix mask in sdpa

* fix attention mask

* Re-enable routing for array mask

---------

Co-authored-by: Jagrit Digani <digani@apple.com>
2025-03-20 14:53:12 -07:00
Jagrit Digani
b42d13ec84 Update attention tests to show diff, disable array masks (#1978) 2025-03-20 14:25:38 -07:00
Jagrit Digani
9adcd1a650 Support fused masking in Attention (#1924)
* Update API to allow mask='causal' in fast::sdpa

* Add fallback

* Update steel::AttnParams

* Fix typo

* WIP, basic causal

* Update tests

* Update benchmarking

* Update masking loop limits

* Add bool masking and update tests

* Update additive mask

* Update benchmarks

* Update benchmarks

* Update tests

* Update for bfloat error

* Update early exit

* Add random seed to tests
2025-03-20 11:01:32 -07:00
Awni Hannun
3c164fca8c Fix multistream GPU deadlock (#1969)
* fix multistream GPU deadlock

* comments
2025-03-20 07:19:47 -07:00
jiyzhang
95e335db7b Update smooth_l1_loss in losses.py (#1974)
According the definition of smooth_l1_loss, the line 

diff = predictions - targets

Should be updated to 

diff = mx.abs(predictions - targets)

After the modification, the result is consistent with PyTorch smooth_l1_loss
2025-03-19 20:19:02 -07:00
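For reference, a minimal sketch of the corrected loss described above (a standalone toy version, not the library code; the actual change is to the `smooth_l1_loss` definition in `mlx.nn.losses`):

```python
import mlx.core as mx

def smooth_l1_loss(predictions, targets, beta=1.0):
    # The fix: take the absolute difference, matching the definition
    # (and PyTorch's behavior).
    diff = mx.abs(predictions - targets)
    return mx.where(
        diff < beta,
        0.5 * mx.square(diff) / beta,
        diff - 0.5 * beta,
    ).mean()
```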
Awni Hannun
f90206ad74 Guard nullptr dereference (#1972)
* guard nullptr dereference

* comment
2025-03-19 16:24:10 -07:00
Chunyang Wen
3779150750 refactor: all use schedule (#1973) 2025-03-19 11:24:04 -07:00
Cheng
0a9777aa5c Do not define MLX_VERSION globally (#1966) 2025-03-18 07:12:40 -07:00
Chunyang Wen
45ad06aac8 Fix typo; Fix lint warning when reuse the same name (#1968)
* Fix typo; Fix lint warning when reuse the same name

* Add missing period
2025-03-18 07:12:24 -07:00
Awni Hannun
c6ea2ba329 Use same accumulation precision in gemv as gemm (#1962)
* use same accumulation precision in gemv as gemm

* faster

* fix compile
2025-03-16 07:13:24 -07:00
Awni Hannun
2770a10240 fix grad with inplace updates (#1961) 2025-03-13 19:13:09 -07:00
Awni Hannun
d2a94f9e6a Only compile warnings as errors for circle (#1957) 2025-03-12 13:08:19 -07:00
Awni Hannun
32da94507a fix vmap for flatten (#1955) 2025-03-11 10:42:22 -07:00
Awni Hannun
736a340478 reduce binary size (#1952) 2025-03-11 06:30:44 -07:00
Awni Hannun
117e1355a2 fix copy for large arrays (#1953) 2025-03-10 15:04:25 -07:00
Awni Hannun
3c3e558c60 Support transposed head/seq for kv (#1950)
* support transposed head/seq for kv

* fix flaky test

* nit
2025-03-10 10:53:45 -07:00
Chunyang Wen
cffceda6ee Add type hint for _extra_repr (#1948) 2025-03-10 06:05:36 -07:00
Chunyang Wen
048805ad2c Remove unused modules (#1949) 2025-03-10 06:05:26 -07:00
Chunyang Wen
d14c9fe7ea Add file info when raising errors in save (#1943) 2025-03-08 14:51:04 -08:00
Chunyang Wen
5db90ce822 Fix obsured warning (#1944) 2025-03-08 14:50:39 -08:00
Chunyang Wen
d699cc1330 Fix unreachable warning (#1939)
* Fix unreachable warning

* Update error message
2025-03-07 17:23:04 -08:00
Awni Hannun
c4230747a1 redesign for faster cpu/gpu synch (#1869)
* redesign for faster cpu/gpu synch

* load + more async CPU

* use command encoder API and move more ops to use it

* make fence back-end generic + CPU only fence

* faster build

* fix async eval

* fixes + handle temporaries

* fix / improve cpu conv

* remove unused status, fix siblings

* fix extensions

* fix

* fix no cpu build

* format

* comments

* fix perf regression, remove unecessary abort

* fix events, task limit cpu

* fix waiting

* fix donation / temporaries in normalization
2025-03-06 19:23:38 -08:00
Awni Hannun
5245f12a46 always use json (#1938) 2025-03-06 15:35:56 -08:00
Chunyang Wen
a198b2787e Remove unused modules (#1936) 2025-03-06 14:20:27 -08:00
Chunyang Wen
04edad8c59 Add doc string for path (#1937) 2025-03-06 14:20:09 -08:00
David Wisdom
392b3060b0 Fix typo in randint docstring (#1932)
This commit fixes a typo in the docstring for mlx.core.random.randint() by changing "roadcastable" to "broadcastable".
2025-03-05 21:48:00 -08:00
Chunyang Wen
85b34d59bc Clean unused sys (#1929) 2025-03-05 13:48:03 -08:00
243 changed files with 11622 additions and 6183 deletions

View File

@@ -24,8 +24,8 @@ jobs:
       type: boolean
       default: false
     macos:
-      xcode: "15.2.0"
-    resource_class: macos.m1.medium.gen1
+      xcode: "16.2.0"
+    resource_class: m2pro.medium
     steps:
       - checkout
       - run:
@@ -89,6 +89,7 @@ jobs:
             pip install numpy
             sudo apt-get update
             sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
+            sudo apt-get install openmpi-bin openmpi-common libopenmpi-dev
       - run:
           name: Install Python package
           command: |
@@ -108,6 +109,8 @@ jobs:
           name: Run Python tests
           command: |
             python3 -m unittest discover python/tests -v
+            mpirun --bind-to none -host localhost:8 -np 8 python python/tests/mpi_test_distributed.py
+            mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py
       - run:
           name: Build CPP only
           command: |
@@ -122,10 +125,15 @@ jobs:
     parameters:
       xcode_version:
         type: string
-        default: "15.2.0"
+        default: "16.2.0"
+      macosx_deployment_target:
+        type: string
+        default: ""
     macos:
       xcode: << parameters.xcode_version >>
-    resource_class: macos.m1.medium.gen1
+    environment:
+      MACOSX_DEPLOYMENT_TARGET: << parameters.macosx_deployment_target >>
+    resource_class: m2pro.medium
     steps:
       - checkout
       - run:
@@ -146,7 +154,9 @@ jobs:
           name: Install Python package
           command: |
             source env/bin/activate
-            DEBUG=1 CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` pip install -e . -v
+            DEBUG=1 CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
+              CMAKE_ARGS="-DCMAKE_COMPILE_WARNING_AS_ERROR=ON" \
+              pip install -e . -v
       - run:
           name: Generate package stubs
           command: |
@@ -209,13 +219,18 @@ jobs:
default: "3.9" default: "3.9"
xcode_version: xcode_version:
type: string type: string
default: "15.2.0" default: "16.2.0"
build_env: build_env:
type: string type: string
default: "" default: ""
macosx_deployment_target:
type: string
default: ""
macos: macos:
xcode: << parameters.xcode_version >> xcode: << parameters.xcode_version >>
resource_class: macos.m1.medium.gen1 resource_class: m2pro.medium
environment:
MACOSX_DEPLOYMENT_TARGET: << parameters.macosx_deployment_target >>
steps: steps:
- checkout - checkout
- run: - run:
@@ -236,7 +251,7 @@ jobs:
           name: Install Python package
           command: |
             source env/bin/activate
-            DEV_RELEASE=1 \
+            env -u MACOSX_DEPLOYMENT_TARGET DEV_RELEASE=1 \
               CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
               pip install . -v
       - run:
@@ -331,7 +346,7 @@ workflows:
       - mac_build_and_test:
           matrix:
             parameters:
-              xcode_version: ["15.0.0", "15.2.0", "16.0.0"]
+              macosx_deployment_target: ["13.5", "14.0"]
       - linux_build_and_test
       - build_documentation
@@ -351,8 +366,70 @@ workflows:
           matrix:
             parameters:
               python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
-              xcode_version: ["15.0.0", "15.2.0"]
+              macosx_deployment_target: ["13.5", "14.0", "15.0"]
               build_env: ["PYPI_RELEASE=1"]
+              xcode_version: ["16.2.0", "15.0.0"]
exclude:
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.9"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.10"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.11"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.12"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.13"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.9"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.10"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.11"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.12"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.13"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.9"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.10"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.11"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.12"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.13"
build_env: "PYPI_RELEASE=1"
       - build_documentation:
           filters:
             tags:
@@ -375,7 +452,7 @@ workflows:
           requires: [ hold ]
           matrix:
             parameters:
-              xcode_version: ["15.0.0", "15.2.0", "16.0.0"]
+              macosx_deployment_target: ["13.5", "14.0"]
       - linux_build_and_test:
           requires: [ hold ]
   nightly_build:
@@ -388,7 +465,54 @@ workflows:
           matrix:
             parameters:
               python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
-              xcode_version: ["15.0.0", "15.2.0"]
+              macosx_deployment_target: ["13.5", "14.0", "15.0"]
+              xcode_version: ["16.2.0", "15.0.0"]
exclude:
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.9"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.10"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.11"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.12"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.13"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.9"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.10"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.11"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.12"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.13"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.9"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.10"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.11"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.12"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.13"
   weekly_build:
     when:
       and:
@@ -399,8 +523,70 @@ workflows:
           matrix:
             parameters:
               python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
-              xcode_version: ["15.0.0", "15.2.0", "16.0.0"]
+              macosx_deployment_target: ["13.5", "14.0", "15.0"]
               build_env: ["DEV_RELEASE=1"]
+              xcode_version: ["16.2.0", "15.0.0"]
exclude:
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.9"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.10"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.11"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.12"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.13"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.9"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.10"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.11"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.12"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.13"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.9"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.10"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.11"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.12"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.13"
build_env: "DEV_RELEASE=1"
   linux_test_release:
     when:
       and:

View File

@@ -9,6 +9,7 @@ if(NOT MLX_VERSION)
   string(REGEX MATCH "#define MLX_VERSION_PATCH ([0-9]+)" _ "${_mlx_h_version}")
   set(_patch ${CMAKE_MATCH_1})
   set(MLX_PROJECT_VERSION "${_major}.${_minor}.${_patch}")
+  set(MLX_VERSION ${MLX_PROJECT_VERSION})
 else()
   string(REGEX REPLACE "^([0-9]+\.[0-9]+\.[0-9]+).*" "\\1" MLX_PROJECT_VERSION
                        ${MLX_VERSION})
@@ -41,8 +42,6 @@ option(MLX_BUILD_BLAS_FROM_SOURCE "Build OpenBLAS from source code" OFF)
 option(MLX_METAL_JIT "Use JIT compilation for Metal kernels" OFF)
 option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)

-add_compile_definitions("MLX_VERSION=${MLX_VERSION}")

 # --------------------- Processor tests -------------------------
 message(
   STATUS
@@ -77,7 +76,6 @@ include(FetchContent)
 cmake_policy(SET CMP0135 NEW)

 add_library(mlx)
-set_target_properties(mlx PROPERTIES COMPILE_WARNING_AS_ERROR ON)

 if(MLX_BUILD_METAL)
   set(METAL_LIB "-framework Metal")
@@ -214,23 +212,13 @@ else()
   set(MLX_BUILD_ACCELERATE OFF)
 endif()

-find_package(MPI)
-if(MPI_FOUND)
-  execute_process(
-    COMMAND zsh "-c" "mpirun --version"
-    OUTPUT_VARIABLE MPI_VERSION
-    ERROR_QUIET)
-  if(${MPI_VERSION} MATCHES ".*Open MPI.*")
-    target_include_directories(mlx PRIVATE ${MPI_INCLUDE_PATH})
-  elseif(MPI_VERSION STREQUAL "")
-    set(MPI_FOUND FALSE)
-    message(
-      WARNING "MPI found but mpirun is not available. Building without MPI.")
-  else()
-    set(MPI_FOUND FALSE)
-    message(WARNING "MPI which is not OpenMPI found. Building without MPI.")
-  endif()
-endif()
+message(STATUS "Downloading json")
+FetchContent_Declare(
+  json
+  URL https://github.com/nlohmann/json/releases/download/v3.11.3/json.tar.xz)
+FetchContent_MakeAvailable(json)
+target_include_directories(
+  mlx PRIVATE $<BUILD_INTERFACE:${json_SOURCE_DIR}/single_include/nlohmann>)

 add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/mlx)

View File

@@ -5,26 +5,26 @@ possible.
 ## Pull Requests

 1. Fork and submit pull requests to the repo.
 2. If you've added code that should be tested, add tests.
 3. If a change is likely to impact efficiency, run some of the benchmarks before
    and after the change. Examples of benchmarks can be found in `benchmarks/python/`.
 4. If you've changed APIs, update the documentation.
 5. Every PR should have passing tests and at least one review.
 6. For code formatting install `pre-commit` using something like `pip install pre-commit` and run `pre-commit install`.
    This should install hooks for running `black` and `clang-format` to ensure
    consistent style for C++ and python code.

    You can also run the formatters manually as follows:

-    ```
+    ```shell
     clang-format -i file.cpp
     ```

-    ```
+    ```shell
     black file.py
     ```

    or run `pre-commit run --all-files` to check all files in the repo.

 ## Issues

View File

@@ -157,7 +157,7 @@ def bench_shape(B, M, N, K, np_dtype, transpose="nn"):
 def get_gflop_count(B, M, N, K):
-    return float(2.0 * N_iter_bench * N_iter_func * B * M * N * K) / float(1024.0**3)
+    return float(2.0 * N_iter_bench * N_iter_func * B * M * N * K) / float(1000.0**3)

 if __name__ == "__main__":
@@ -175,6 +175,8 @@ if __name__ == "__main__":
         (1, 4096, 4096, 4096),
     )

+    print(f" B, M, N, K, dtype, t, gflops_pt, gflops_mx, diff%")

     for dtype in dtypes:
         for transpose in transposes:
             for B, M, N, K in shapes:
@@ -187,7 +189,7 @@ if __name__ == "__main__":
                 diff = gflops_mx / gflops_pt - 1.0
                 print(
-                    f"{B:3d}, {M:4d}, {N:4d}, {K:4d}, {dtype}, {transpose}, {gflops_pt:05.3f}, {gflops_mx:05.3f}, {100. * diff:+5.2f}%"
+                    f"{B:3d}, {M:4d}, {N:4d}, {K:5d}, {dtype}, {transpose}, {gflops_pt:8.2f}, {gflops_mx:8.2f}, {100. * diff:+5.2f}%"
                 )
                 if gflops_pt >= 2.0 * gflops_mx:
                     print("ATTENTION ^^^^^^^")

View File

@@ -1,7 +1,6 @@
 # Copyright © 2023-2024 Apple Inc.

 import argparse
-from time import time

 import mlx.core as mx
 import torch

View File

@@ -0,0 +1,74 @@
# Copyright © 2023-2024 Apple Inc.
import mlx.core as mx
from time_utils import time_fn
N = 1024
D = 1024
M = 1024
E = 32
I = 4
def gather_sort(x, indices):
N, M = indices.shape
indices = indices.flatten()
order = mx.argsort(indices)
inv_order = mx.argsort(order)
return x.flatten(0, -3)[order // M], indices[order], inv_order
def scatter_unsort(x, inv_order, shape=None):
x = x[inv_order]
if shape is not None:
x = mx.unflatten(x, 0, shape)
return x
def gather_mm_simulate(x, w, indices):
x, idx, inv_order = gather_sort(x, indices)
for i in range(2):
y = mx.concatenate([x[i] @ w[j].T for i, j in enumerate(idx.tolist())], axis=0)
x = y[:, None]
x = scatter_unsort(x, inv_order, indices.shape)
return x
def time_gather_mm():
x = mx.random.normal((N, 1, 1, D)) / 1024**0.5
w1 = mx.random.normal((E, M, D)) / 1024**0.5
w2 = mx.random.normal((E, D, M)) / 1024**0.5
indices = (mx.random.uniform(shape=(N, I)) * E).astype(mx.uint32)
sorted_indices = mx.sort(indices.flatten()).reshape(N, I)
mx.eval(x, w1, w2, indices, sorted_indices)
def gather_mm(x, w1, w2, indices, sort):
idx = indices
inv_order = None
if sort:
x, idx, inv_order = gather_sort(x, indices)
x = mx.gather_mm(x, w1.swapaxes(-1, -2), rhs_indices=idx, sorted_indices=sort)
x = mx.gather_mm(x, w2.swapaxes(-1, -2), rhs_indices=idx, sorted_indices=sort)
if sort:
x = scatter_unsort(x, inv_order, indices.shape)
return x
time_fn(gather_mm, x, w1, w2, indices, False)
time_fn(gather_mm, x, w1, w2, sorted_indices, False)
time_fn(gather_mm, x, w1, w2, indices, True)
x = mx.random.normal((N * I, D)) / 1024**0.5
w1 = mx.random.normal((M, D)) / 1024**0.5
w2 = mx.random.normal((D, M)) / 1024**0.5
mx.eval(x, w1, w2)
def equivalent_matmul(x, w1, w2):
x = x @ w1.T
x = x @ w2.T
return x
time_fn(equivalent_matmul, x, w1, w2)
if __name__ == "__main__":
time_gather_mm()

View File

@@ -28,11 +28,34 @@ def bench(f, *args):
return (e - s) * 1e-9 return (e - s) * 1e-9
def mlx_sdpa_fused_inner(q, k, v, scale): def prepare_inputs(B, qL, kL, D, qH, kH, mask, transpose, dtype):
return mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=None) np_dtype = getattr(np, dtype)
shape_q = (B, qL, qH, D) if transpose else (B, qH, qL, D)
shape_kv = (B, kL, kH, D) if transpose else (B, kH, kL, D)
scale = 1.0 / math.sqrt(D)
q_np = np.random.normal(0.0, 1.0, shape_q).astype(np_dtype)
k_np = np.random.normal(0.0, scale, shape_kv).astype(np_dtype)
v_np = np.random.normal(0.0, scale, shape_kv).astype(np_dtype)
q_mx = mx.array(q_np)
k_mx = mx.array(k_np)
v_mx = mx.array(v_np)
if mask is not None:
if mask == "additive":
mask_np = np.random.normal(0.0, 1.0, (B, qH, qL, kL)).astype(np_dtype)
mask = mx.array(mask_np)
elif mask == "bool":
mask_np = np.random.uniform(0.0, 1.0, (B, qH, qL, kL)) < 0.5
mask = mx.array(mask_np)
return q_mx, k_mx, v_mx, scale, mask
def mlx_sdpa_unfused_inner(q, k, v, scale, f32softmax=False): def mlx_ref_attn(q, k, v, scale=1.0, mask=None):
q_dtype = q.dtype q_dtype = q.dtype
q = q * mx.array(scale, q_dtype) q = q * mx.array(scale, q_dtype)
n_q_heads = q.shape[-3] n_q_heads = q.shape[-3]
@@ -41,6 +64,7 @@ def mlx_sdpa_unfused_inner(q, k, v, scale, f32softmax=False):
     B = q.shape[0]
     L = q.shape[2]
+    kL = k.shape[2]

     if n_repeats > 1:
         q = mx.reshape(q, [B, n_kv_heads, n_repeats, L, -1])
@@ -48,10 +72,27 @@ def mlx_sdpa_unfused_inner(q, k, v, scale, f32softmax=False):
v = mx.expand_dims(v, 2) v = mx.expand_dims(v, 2)
scores = q @ mx.swapaxes(k, -1, -2) scores = q @ mx.swapaxes(k, -1, -2)
if f32softmax:
scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(q_dtype) if mask is not None:
else:
scores = mx.softmax(scores, axis=-1) if mask == "causal":
q_offset = max(0, kL - L)
q_indices = mx.arange(q_offset, q_offset + L)
k_indices = mx.arange(kL)
mask = q_indices[:, None] >= k_indices[None]
if n_repeats > 1 and mask.ndim >= 3:
if mask.shape[-3] == 1:
mask = mx.expand_dims(mask, -3)
else:
mask = mx.unflatten(mask, -3, (n_kv_heads, n_repeats))
if mask.dtype == mx.bool_:
scores = mx.where(mask, scores, -np.float32(np.inf))
else:
scores += mask
scores = mx.softmax(scores, axis=-1, precise=True)
out = scores @ v out = scores @ v
if n_repeats > 1: if n_repeats > 1:
@@ -60,74 +101,55 @@ def mlx_sdpa_unfused_inner(q, k, v, scale, f32softmax=False):
return out return out
def mlx_spda_unfused(q, k, v, scale, transpose): def mlx_fused_attn(q, k, v, scale, mask):
q_out = q return mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask)
def do_attention(f, q, k, v, scale, mask=None, transpose=False):
if transpose: if transpose:
k = mx.transpose(k, (0, 2, 1, 3)) q_t = mx.transpose(q, (0, 2, 1, 3))
v = mx.transpose(v, (0, 2, 1, 3)) k_t = mx.transpose(k, (0, 2, 1, 3))
v_t = mx.transpose(v, (0, 2, 1, 3))
o_t = f(q_t, k_t, v_t, scale=scale, mask=mask)
return mx.transpose(o_t, (0, 2, 1, 3))
else:
return f(q, k, v, scale=scale, mask=mask)
def do_attention_bench(f, q, k, v, scale, mask=None, transpose=False):
q_out = q
for i in range(N_iter_func): for i in range(N_iter_func):
if transpose: q_out = do_attention(f, q_out, k, v, scale, mask=mask, transpose=transpose)
q_out = mx.transpose(q_out, (0, 2, 1, 3))
q_out = mlx_sdpa_unfused_inner(q_out, k, v, scale)
if transpose:
q_out = mx.transpose(q_out, (0, 2, 1, 3))
mx.eval(q_out) mx.eval(q_out)
return q_out return q_out
def mlx_spda_fused(q, k, v, scale, transpose): def bench_shape(
q_out = q B, qsl, ksl, head_dim, n_q_heads, n_kv_heads, dtype, transpose=True, mask_in=None
if transpose: ):
k = mx.transpose(k, (0, 2, 1, 3)) q_mx, k_mx, v_mx, scale, mask = prepare_inputs(
v = mx.transpose(v, (0, 2, 1, 3)) B, qsl, ksl, head_dim, n_q_heads, n_kv_heads, mask_in, transpose, dtype
for i in range(N_iter_func):
if transpose:
q_out = mx.transpose(q_out, (0, 2, 1, 3))
q_out = mlx_sdpa_fused_inner(q_out, k, v, scale)
if transpose:
q_out = mx.transpose(q_out, (0, 2, 1, 3))
mx.eval(q_out)
return q_out
def bench_shape(B, qsl, ksl, head_dim, n_q_heads, n_kv_heads, np_dtype, transpose=True):
shape_q = (
(B, qsl, n_q_heads, head_dim) if transpose else (B, n_q_heads, qsl, head_dim)
)
shape_kv = (
(B, ksl, n_kv_heads, head_dim) if transpose else (B, n_kv_heads, ksl, head_dim)
) )
q_np = np.random.normal(0.0, 1.0 / math.sqrt(head_dim), shape_q).astype(np_dtype) time_mlx_unfused = bench(
k_np = np.random.normal(0.0, 1.0 / math.sqrt(head_dim), shape_kv).astype(np_dtype) do_attention_bench, mlx_ref_attn, q_mx, k_mx, v_mx, scale, mask, transpose
v_np = np.random.normal(0.0, 1.0 / math.sqrt(head_dim), shape_kv).astype(np_dtype) )
time_mlx_fused = bench(
do_attention_bench, mlx_fused_attn, q_mx, k_mx, v_mx, scale, mask, transpose
)
scale = math.sqrt(1.0 / head_dim) o_mlx_fused = do_attention(mlx_ref_attn, q_mx, k_mx, v_mx, scale, mask, transpose)
o_mlx_unfused = do_attention(
mlx_fused_attn, q_mx, k_mx, v_mx, scale, mask, transpose
)
q_mx = mx.array(q_np) atol = 1e-5 if dtype == "float32" else 2e-4
k_mx = mx.array(k_np)
v_mx = mx.array(v_np)
time_mlx_unfused = bench(mlx_spda_unfused, q_mx, k_mx, v_mx, scale, transpose) if not mx.allclose(o_mlx_fused, o_mlx_unfused, atol=atol, rtol=atol):
time_mlx_fused = bench(mlx_spda_fused, q_mx, k_mx, v_mx, scale, transpose)
if transpose:
q_mx = mx.transpose(q_mx, (0, 2, 1, 3))
k_mx = mx.transpose(k_mx, (0, 2, 1, 3))
v_mx = mx.transpose(v_mx, (0, 2, 1, 3))
o_mlx_fused = mlx_sdpa_fused_inner(q_mx, k_mx, v_mx, scale)
o_mlx_unfused = mlx_sdpa_unfused_inner(q_mx, k_mx, v_mx, scale, f32softmax=True)
atol = 1e-5 if np_dtype == np.float32 else 1e-4
if not mx.allclose(o_mlx_fused, o_mlx_unfused, atol=atol):
print( print(
f"Failed at (B: {B}, qsl: {qsl}, ksl: {ksl}, head_dim: {head_dim}, n_qh: {n_q_heads}, n_kvh: {n_kv_heads}) [tpose = {transpose}] with max(|a - b|) = {mx.max(mx.abs(o_mlx_unfused - o_mlx_fused)):3.2e}" f"Failed at (B: {B}, qsl: {qsl}, ksl: {ksl}, head_dim: {head_dim}, n_qh: {n_q_heads}, n_kvh: {n_kv_heads}, mask: {mask_in}) [tpose = {transpose}] with max(|a - b|) = {mx.max(mx.abs(o_mlx_unfused - o_mlx_fused)):3.2e}"
) )
return time_mlx_fused, time_mlx_unfused return time_mlx_fused, time_mlx_unfused
@@ -151,39 +173,51 @@ if __name__ == "__main__":
( 1, 128, 128, 64, 32, 32), ( 1, 128, 128, 64, 32, 32),
( 1, 256, 256, 64, 32, 32), ( 1, 256, 256, 64, 32, 32),
( 1, 512, 512, 64, 32, 32), ( 1, 512, 512, 64, 32, 32),
( 1, 1024, 1024, 64, 32, 32), ( 1, 1024, 1024, 64, 32, 8),
( 1, 2048, 2048, 64, 32, 32), ( 1, 2048, 2048, 64, 32, 8),
( 1, 4096, 4096, 64, 32, 32), ( 1, 4096, 4096, 64, 32, 8),
) )
shapes_80 = ( shapes_80 = (
# ( B, qsl, ksl, head_dim, n_qh, n_kvh) # ( B, qsl, ksl, head_dim, n_qh, n_kvh)
( 1, 1024, 1024, 80, 32, 32), ( 1, 1024, 1024, 80, 32, 8),
( 1, 2048, 2048, 80, 32, 32), ( 1, 2048, 2048, 80, 32, 8),
( 1, 4096, 4096, 80, 32, 32), ( 1, 4096, 4096, 80, 32, 8),
) )
shapes_128 = ( shapes_128 = (
# ( B, qsl, ksl, head_dim, n_qh, n_kvh) # ( B, qsl, ksl, head_dim, n_qh, n_kvh)
( 1, 1024, 1024, 128, 32, 32), ( 1, 1024, 1024, 128, 32, 8),
( 1, 2048, 2048, 128, 32, 32), ( 1, 2048, 2048, 128, 32, 8),
( 1, 4096, 4096, 128, 32, 32), ( 1, 4096, 4096, 128, 32, 8),
) )
# fmt: on # fmt: on
shapes = shapes_64 + shapes_80 + shapes_128 shapes = shapes_64 + shapes_80 + shapes_128
print(" B, qsl, ksl, hdim, n_qh, n_kvh, tpose, dtype, t_unfs, t_fuse, diff%") masks = [None, "bool", "causal"]
print(
" B, qsl, ksl, hdim, n_qh, n_kvh, t, dtype, mask, t_unfs, t_fuse, diff%"
)
for dtype in dtypes: for dtype in dtypes:
for transpose in transposes: for transpose in transposes:
for B, qsl, ksl, head_dim, n_q_heads, n_kv_heads in shapes: for B, qsl, ksl, head_dim, n_q_heads, n_kv_heads in shapes:
np_dtype = getattr(np, dtype) for mask_in in masks:
time_mlx_fused, time_mlx_unfused = bench_shape( time_mlx_fused, time_mlx_unfused = bench_shape(
B, qsl, ksl, head_dim, n_q_heads, n_kv_heads, np_dtype, transpose B,
) qsl,
diff = time_mlx_unfused / time_mlx_fused - 1.0 ksl,
t_str = 1 if transpose else 0 head_dim,
print( n_q_heads,
f"{B:3d}, {qsl:5d}, {ksl:5d}, {head_dim:4d}, {n_q_heads:4d}, {n_kv_heads:5d}, {t_str:5d}, {dtype}, {time_mlx_unfused: 2.3f}, {time_mlx_fused: 2.3f}, {100. * diff:+5.2f}%" n_kv_heads,
) dtype,
transpose,
mask_in,
)
diff = time_mlx_unfused / time_mlx_fused - 1.0
t_str = 1 if transpose else 0
print(
f"{B:3d}, {qsl:5d}, {ksl:5d}, {head_dim:4d}, {n_q_heads:4d}, {n_kv_heads:5d}, {t_str:1d}, {dtype}, {str(mask_in):>8}, {time_mlx_unfused: 2.3f}, {time_mlx_fused: 2.3f}, {100. * diff:+5.2f}%"
)
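The rewritten benchmark above exercises the new mask handling in `mx.fast.scaled_dot_product_attention` (string `"causal"` masks from #1924, plus boolean and additive array masks). A minimal sketch of the fused call against an explicit boolean causal mask; the shapes and tolerance here are illustrative, not taken from the benchmark:

```python
import math
import mlx.core as mx

B, H, L, D = 1, 8, 128, 64
q = mx.random.normal((B, H, L, D))
k = mx.random.normal((B, H, L, D))
v = mx.random.normal((B, H, L, D))
scale = 1.0 / math.sqrt(D)

# Fused kernel with the string mask
fused = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask="causal")

# Unfused reference with an explicit boolean lower-triangular mask
idx = mx.arange(L)
causal = idx[:, None] >= idx[None]
scores = (q * scale) @ mx.swapaxes(k, -1, -2)
scores = mx.where(causal, scores, mx.array(-float("inf")))
ref = mx.softmax(scores, axis=-1, precise=True) @ v

print(mx.allclose(fused, ref, atol=1e-4))  # should print True in float32
```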

View File

@@ -13,7 +13,7 @@ EXCLUDE_PATTERNS = */private/*
 CREATE_SUBDIRS = NO
 FULL_PATH_NAMES = YES
 RECURSIVE = YES
-GENERATE_HTML = YES
+GENERATE_HTML = NO
 GENERATE_LATEX = NO
 GENERATE_XML = YES
 XML_PROGRAMLISTING = YES

View File

@@ -22,12 +22,12 @@ You can do that in MLX directly:
 This function performs that operation while leaving the implementation and
 function transformations to MLX.

-However you may need to customize the underlying implementation, perhaps to
-make it faster or for custom differentiation. In this tutorial we will go
-through adding custom extensions. It will cover:
+However, you may want to customize the underlying implementation, perhaps to
+make it faster. In this tutorial we will go through adding custom extensions.
+It will cover:

 * The structure of the MLX library.
-* Implementing a CPU operation that redirects to Accelerate_ when appropriate.
+* Implementing a CPU operation.
 * Implementing a GPU operation using metal.
 * Adding the ``vjp`` and ``jvp`` function transformation.
 * Building a custom extension and binding it to python.
@@ -45,7 +45,7 @@ Operations
 Operations are the front-end functions that operate on arrays. They are defined
 in the C++ API (:ref:`cpp_ops`), and the Python API (:ref:`ops`) binds them.

-We would like an operation, :meth:`axpby` that takes in two arrays ``x`` and
+We would like an operation :meth:`axpby` that takes in two arrays, ``x`` and
 ``y``, and two scalars, ``alpha`` and ``beta``. This is how to define it in
 C++:
@@ -55,7 +55,7 @@ C++:
     * Scale and sum two vectors element-wise
     * z = alpha * x + beta * y
     *
-    * Follow numpy style broadcasting between x and y
+    * Use NumPy-style broadcasting between x and y
     * Inputs are upcasted to floats if needed
     **/
    array axpby(
@@ -66,7 +66,7 @@ C++:
        StreamOrDevice s = {} // Stream on which to schedule the operation
    );

-The simplest way to this operation is in terms of existing operations:
+The simplest way to implement this is with existing operations:

 .. code-block:: C++
@@ -93,9 +93,9 @@ Primitives
 ^^^^^^^^^^^

 A :class:`Primitive` is part of the computation graph of an :class:`array`. It
-defines how to create outputs arrays given a input arrays. Further, a
+defines how to create output arrays given input arrays. Further, a
 :class:`Primitive` has methods to run on the CPU or GPU and for function
-transformations such as ``vjp`` and ``jvp``. Lets go back to our example to be
+transformations such as ``vjp`` and ``jvp``. Let's go back to our example to be
 more concrete:

 .. code-block:: C++
@@ -128,7 +128,7 @@ more concrete:
     /** The vector-Jacobian product. */
     std::vector<array> vjp(
         const std::vector<array>& primals,
-        const array& cotan,
+        const std::vector<array>& cotangents,
         const std::vector<int>& argnums,
         const std::vector<array>& outputs) override;
@@ -153,9 +153,6 @@ more concrete:
    private:
     float alpha_;
     float beta_;
-
-    /** Fall back implementation for evaluation on CPU */
-    void eval(const std::vector<array>& inputs, array& out);
   };

 The :class:`Axpby` class derives from the base :class:`Primitive` class. The
@@ -188,7 +185,7 @@ Let's reimplement our operation now in terms of our :class:`Axpby` primitive.
     auto promoted_dtype = promote_types(x.dtype(), y.dtype());

     // Upcast to float32 for non-floating point inputs x and y
-    auto out_dtype = is_floating_point(promoted_dtype)
+    auto out_dtype = issubdtype(promoted_dtype, float32)
         ? promoted_dtype
         : promote_types(promoted_dtype, float32);
@@ -234,49 +231,57 @@ the execution of the computation graph, and calls :meth:`Axpby::eval_cpu` or
 Implementing the CPU Back-end
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

-Let's start by implementing a naive and generic version of
-:meth:`Axpby::eval_cpu`. We declared this as a private member function of
-:class:`Axpby` earlier called :meth:`Axpby::eval`.
+Let's start by implementing :meth:`Axpby::eval_cpu`.

-Our naive method will go over each element of the output array, find the
+The method will go over each element of the output array, find the
 corresponding input elements of ``x`` and ``y`` and perform the operation
 point-wise. This is captured in the templated function :meth:`axpby_impl`.

 .. code-block:: C++
template <typename T> template <typename T>
void axpby_impl( void axpby_impl(
const array& x, const mx::array& x,
const array& y, const mx::array& y,
array& out, mx::array& out,
float alpha_, float alpha_,
float beta_) { float beta_,
// We only allocate memory when we are ready to fill the output mx::Stream stream) {
// malloc_or_wait synchronously allocates available memory out.set_data(mx::allocator::malloc(out.nbytes()));
// There may be a wait executed here if the allocation is requested
// under memory-pressured conditions
out.set_data(allocator::malloc_or_wait(out.nbytes()));
// Collect input and output data pointers // Get the CPU command encoder and register input and output arrays
const T* x_ptr = x.data<T>(); auto& encoder = mx::cpu::get_command_encoder(stream);
const T* y_ptr = y.data<T>(); encoder.set_input_array(x);
T* out_ptr = out.data<T>(); encoder.set_input_array(y);
encoder.set_output_array(out);
// Cast alpha and beta to the relevant types // Launch the CPU kernel
T alpha = static_cast<T>(alpha_); encoder.dispatch([x_ptr = x.data<T>(),
T beta = static_cast<T>(beta_); y_ptr = y.data<T>(),
out_ptr = out.data<T>(),
size = out.size(),
shape = out.shape(),
x_strides = x.strides(),
y_strides = y.strides(),
alpha_,
beta_]() {
// Do the element-wise operation for each output // Cast alpha and beta to the relevant types
for (size_t out_idx = 0; out_idx < out.size(); out_idx++) { T alpha = static_cast<T>(alpha_);
// Map linear indices to offsets in x and y T beta = static_cast<T>(beta_);
auto x_offset = elem_to_loc(out_idx, x.shape(), x.strides());
auto y_offset = elem_to_loc(out_idx, y.shape(), y.strides());
// We allocate the output to be contiguous and regularly strided // Do the element-wise operation for each output
// (defaults to row major) and hence it doesn't need additional mapping for (size_t out_idx = 0; out_idx < size; out_idx++) {
out_ptr[out_idx] = alpha * x_ptr[x_offset] + beta * y_ptr[y_offset]; // Map linear indices to offsets in x and y
} auto x_offset = mx::elem_to_loc(out_idx, shape, x_strides);
} auto y_offset = mx::elem_to_loc(out_idx, shape, y_strides);
// We allocate the output to be contiguous and regularly strided
// (defaults to row major) and hence it doesn't need additional mapping
out_ptr[out_idx] = alpha * x_ptr[x_offset] + beta * y_ptr[y_offset];
}
});
}
Our implementation should work for all incoming floating point arrays. Our implementation should work for all incoming floating point arrays.
Accordingly, we add dispatches for ``float32``, ``float16``, ``bfloat16`` and Accordingly, we add dispatches for ``float32``, ``float16``, ``bfloat16`` and
@@ -284,112 +289,32 @@ Accordingly, we add dispatches for ``float32``, ``float16``, ``bfloat16`` and
.. code-block:: C++ .. code-block:: C++
/** Fall back implementation for evaluation on CPU */
void Axpby::eval(
const std::vector<array>& inputs,
const std::vector<array>& outputs) {
auto& x = inputs[0];
auto& y = inputs[1];
auto& out = outputs[0];
// Dispatch to the correct dtype
if (out.dtype() == float32) {
return axpby_impl<float>(x, y, out, alpha_, beta_);
} else if (out.dtype() == float16) {
return axpby_impl<float16_t>(x, y, out, alpha_, beta_);
} else if (out.dtype() == bfloat16) {
return axpby_impl<bfloat16_t>(x, y, out, alpha_, beta_);
} else if (out.dtype() == complex64) {
return axpby_impl<complex64_t>(x, y, out, alpha_, beta_);
} else {
throw std::runtime_error(
"[Axpby] Only supports floating point types.");
}
}
This is good as a fallback implementation. We can use the ``axpby`` routine
provided by the Accelerate_ framework for a faster implementation in certain
cases:
#. Accelerate does not provide implementations of ``axpby`` for half precision
floats. We can only use it for ``float32`` types.
#. Accelerate assumes the inputs ``x`` and ``y`` are contiguous and all
elements have fixed strides between them. We only direct to Accelerate
if both ``x`` and ``y`` are row contiguous or column contiguous.
#. Accelerate performs the routine ``Y = (alpha * X) + (beta * Y)`` in-place.
MLX expects to write the output to a new array. We must copy the elements
of ``y`` into the output and use that as an input to ``axpby``.
Let's write an implementation that uses Accelerate in the right conditions.
It allocates data for the output, copies ``y`` into it, and then calls the
:func:`catlas_saxpby` from accelerate.
.. code-block:: C++
template <typename T>
void axpby_impl_accelerate(
const array& x,
const array& y,
array& out,
float alpha_,
float beta_) {
// Accelerate library provides catlas_saxpby which does
// Y = (alpha * X) + (beta * Y) in place
// To use it, we first copy the data in y over to the output array
out.set_data(allocator::malloc_or_wait(out.nbytes()));
// We then copy over the elements using the contiguous vector specialization
copy_inplace(y, out, CopyType::Vector);
// Get x and y pointers for catlas_saxpby
const T* x_ptr = x.data<T>();
T* y_ptr = out.data<T>();
T alpha = static_cast<T>(alpha_);
T beta = static_cast<T>(beta_);
// Call the inplace accelerate operator
catlas_saxpby(
/* N = */ out.size(),
/* ALPHA = */ alpha,
/* X = */ x_ptr,
/* INCX = */ 1,
/* BETA = */ beta,
/* Y = */ y_ptr,
/* INCY = */ 1);
}
For inputs that do not fit the criteria for accelerate, we fall back to
:meth:`Axpby::eval`. With this in mind, let's finish our
:meth:`Axpby::eval_cpu`.
.. code-block:: C++
/** Evaluate primitive on CPU using accelerate specializations */
void Axpby::eval_cpu( void Axpby::eval_cpu(
const std::vector<array>& inputs, const std::vector<mx::array>& inputs,
const std::vector<array>& outputs) { std::vector<mx::array>& outputs) {
assert(inputs.size() == 2); auto& x = inputs[0];
auto& x = inputs[0]; auto& y = inputs[1];
auto& y = inputs[1]; auto& out = outputs[0];
auto& out = outputs[0];
// Accelerate specialization for contiguous single precision float arrays // Dispatch to the correct dtype
if (out.dtype() == float32 && if (out.dtype() == mx::float32) {
((x.flags().row_contiguous && y.flags().row_contiguous) || return axpby_impl<float>(x, y, out, alpha_, beta_, stream());
(x.flags().col_contiguous && y.flags().col_contiguous))) { } else if (out.dtype() == mx::float16) {
axpby_impl_accelerate<float>(x, y, out, alpha_, beta_); return axpby_impl<mx::float16_t>(x, y, out, alpha_, beta_, stream());
return; } else if (out.dtype() == mx::bfloat16) {
} return axpby_impl<mx::bfloat16_t>(x, y, out, alpha_, beta_, stream());
} else if (out.dtype() == mx::complex64) {
// Fall back to common back-end if specializations are not available return axpby_impl<mx::complex64_t>(x, y, out, alpha_, beta_, stream());
eval(inputs, outputs); } else {
throw std::runtime_error(
"Axpby is only supported for floating point types.");
}
} }
 Just this much is enough to run the operation :meth:`axpby` on a CPU stream! If
 you do not plan on running the operation on the GPU or using transforms on
 computation graphs that contain :class:`Axpby`, you can stop implementing the
-primitive here and enjoy the speed-ups you get from the Accelerate library.
+primitive here.

 Implementing the GPU Back-end
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -466,7 +391,7 @@ below.
auto& d = metal::device(s.device); auto& d = metal::device(s.device);
// Allocate output memory // Allocate output memory
out.set_data(allocator::malloc_or_wait(out.nbytes())); out.set_data(allocator::malloc(out.nbytes()));
// Resolve name of kernel // Resolve name of kernel
std::ostringstream kname; std::ostringstream kname;
@@ -544,7 +469,7 @@ one we just defined:
const std::vector<array>& tangents, const std::vector<array>& tangents,
const std::vector<int>& argnums) { const std::vector<int>& argnums) {
// Forward mode diff that pushes along the tangents // Forward mode diff that pushes along the tangents
// The jvp transform on the primitive can built with ops // The jvp transform on the primitive can be built with ops
// that are scheduled on the same stream as the primitive // that are scheduled on the same stream as the primitive
// If argnums = {0}, we only push along x in which case the // If argnums = {0}, we only push along x in which case the
@@ -556,7 +481,7 @@ one we just defined:
auto scale_arr = array(scale, tangents[0].dtype()); auto scale_arr = array(scale, tangents[0].dtype());
return {multiply(scale_arr, tangents[0], stream())}; return {multiply(scale_arr, tangents[0], stream())};
} }
// If, argnums = {0, 1}, we take contributions from both // If argnums = {0, 1}, we take contributions from both
// which gives us jvp = tangent_x * alpha + tangent_y * beta // which gives us jvp = tangent_x * alpha + tangent_y * beta
else { else {
return {axpby(tangents[0], tangents[1], alpha_, beta_, stream())}; return {axpby(tangents[0], tangents[1], alpha_, beta_, stream())};
@@ -810,7 +735,7 @@ Let's look at a simple script and its results:
print(f"c shape: {c.shape}") print(f"c shape: {c.shape}")
print(f"c dtype: {c.dtype}") print(f"c dtype: {c.dtype}")
print(f"c correct: {mx.all(c == 6.0).item()}") print(f"c is correct: {mx.all(c == 6.0).item()}")
Output: Output:
@@ -818,13 +743,13 @@ Output:
c shape: [3, 4] c shape: [3, 4]
c dtype: float32 c dtype: float32
c correctness: True c is correct: True
Results Results
^^^^^^^ ^^^^^^^
Let's run a quick benchmark and see how our new ``axpby`` operation compares Let's run a quick benchmark and see how our new ``axpby`` operation compares
with the naive :meth:`simple_axpby` we first defined on the CPU. with the naive :meth:`simple_axpby` we first defined.
.. code-block:: python .. code-block:: python
@@ -832,13 +757,11 @@ with the naive :meth:`simple_axpby` we first defined on the CPU.
   from mlx_sample_extensions import axpby
   import time

-  mx.set_default_device(mx.cpu)

   def simple_axpby(x: mx.array, y: mx.array, alpha: float, beta: float) -> mx.array:
       return alpha * x + beta * y

-  M = 256
-  N = 512
+  M = 4096
+  N = 4096

   x = mx.random.normal((M, N))
   y = mx.random.normal((M, N))
@@ -849,24 +772,24 @@ with the naive :meth:`simple_axpby` we first defined on the CPU.
   def bench(f):
       # Warm up
-      for i in range(100):
+      for i in range(5):
           z = f(x, y, alpha, beta)
           mx.eval(z)

       # Timed run
       s = time.time()
-      for i in range(5000):
+      for i in range(100):
           z = f(x, y, alpha, beta)
           mx.eval(z)
       e = time.time()
-      return e - s
+      return 1000 * (e - s) / 100

   simple_time = bench(simple_axpby)
   custom_time = bench(axpby)

-  print(f"Simple axpby: {simple_time:.3f} s | Custom axpby: {custom_time:.3f} s")
+  print(f"Simple axpby: {simple_time:.3f} ms | Custom axpby: {custom_time:.3f} ms")

-The results are ``Simple axpby: 0.114 s | Custom axpby: 0.109 s``. We see
+The results are ``Simple axpby: 1.559 ms | Custom axpby: 0.774 ms``. We see
 modest improvements right away!

 This operation is now good to be used to build other operations, in

View File

@@ -70,6 +70,7 @@ are the CPU and GPU.
    python/fft
    python/linalg
    python/metal
+   python/memory_management
    python/nn
    python/optimizers
    python/distributed

View File

@@ -38,6 +38,7 @@ Array
    array.log10
    array.log1p
    array.log2
+   array.logcumsumexp
    array.logsumexp
    array.max
    array.mean

View File

@@ -20,5 +20,6 @@ Linear Algebra
    eigh
    lu
    lu_factor
+   pinv
    solve
    solve_triangular
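`pinv` was added to the linear algebra docs above (#2020). A quick usage sketch; it is assumed here that, like other factorization-based `mx.linalg` ops, it should be run on a CPU stream:

```python
import mlx.core as mx

a = mx.random.normal((5, 3))
a_pinv = mx.linalg.pinv(a, stream=mx.cpu)

# Moore-Penrose property: a @ pinv(a) @ a recovers a
print(mx.allclose(a @ a_pinv @ a, a, atol=1e-4))
```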

View File

@@ -0,0 +1,16 @@
Memory Management
=================

.. currentmodule:: mlx.core

.. autosummary::
   :toctree: _autosummary

   get_active_memory
   get_peak_memory
   reset_peak_memory
   get_cache_memory
   set_memory_limit
   set_cache_limit
   set_wired_limit
   clear_cache
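These helpers were promoted from `mx.metal` to the top-level `mlx.core` namespace (#1982). A small sketch of typical usage; the array size and cache limit below are arbitrary:

```python
import mlx.core as mx

a = mx.random.normal((4096, 4096))
mx.eval(a @ a)

print("active bytes:", mx.get_active_memory())
print("peak bytes:  ", mx.get_peak_memory())
print("cached bytes:", mx.get_cache_memory())

# Cap the buffer cache (in bytes), drop what is cached, reset the peak counter
mx.set_cache_limit(256 * 1024**2)
mx.clear_cache()
mx.reset_peak_memory()
```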

View File

@@ -8,13 +8,5 @@ Metal
    is_available
    device_info
-   get_active_memory
-   get_peak_memory
-   reset_peak_memory
-   get_cache_memory
-   set_memory_limit
-   set_cache_limit
-   set_wired_limit
-   clear_cache
    start_capture
    stop_capture

View File

@@ -36,10 +36,12 @@ Operations
    bitwise_or
    bitwise_xor
    block_masked_mm
+   broadcast_arrays
    broadcast_to
    ceil
    clip
    concatenate
+   contiguous
    conj
    conjugate
    convolve
@@ -101,6 +103,7 @@ Operations
    log10
    log1p
    logaddexp
+   logcumsumexp
    logical_not
    logical_and
    logical_or
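`logcumsumexp`, documented above, is the new running log-sum-exp op (#2069). A small sketch comparing it to the naive composition; the signature is assumed to mirror `cumsum`:

```python
import mlx.core as mx

x = mx.array([0.5, -1.0, 3.0, 0.25])

stable = mx.logcumsumexp(x, axis=0)

# Naive equivalent; prone to overflow/underflow for large |x|
naive = mx.log(mx.cumsum(mx.exp(x), axis=0))

print(mx.allclose(stable, naive, atol=1e-5))
```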

View File

@@ -18,3 +18,4 @@ Common Optimizers
    AdamW
    Adamax
    Lion
+   MultiOptimizer

View File

@@ -9,6 +9,7 @@ Transforms
    :toctree: _autosummary

    eval
+   async_eval
    compile
    custom_function
    disable_compile
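`async_eval` is now listed in the transforms docs. A minimal sketch of launching evaluation without blocking the Python thread:

```python
import mlx.core as mx

x = mx.random.normal((2048, 2048))
y = (x @ x).sum()

# Start evaluating y in the background
mx.async_eval(y)

# ... other Python-side work can happen here ...

# item() synchronizes and returns the result
print(y.item())
```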

View File

@@ -10,7 +10,6 @@ set(CMAKE_POSITION_INDEPENDENT_CODE ON)
 option(BUILD_SHARED_LIBS "Build extensions as a shared library" ON)

 # ----------------------------- Dependencies -----------------------------
-find_package(MLX CONFIG REQUIRED)
 find_package(
   Python 3.8
   COMPONENTS Interpreter Development.Module
@@ -21,6 +20,12 @@ execute_process(
   OUTPUT_VARIABLE nanobind_ROOT)
 find_package(nanobind CONFIG REQUIRED)

+execute_process(
+  COMMAND "${Python_EXECUTABLE}" -m mlx --cmake-dir
+  OUTPUT_STRIP_TRAILING_WHITESPACE
+  OUTPUT_VARIABLE MLX_ROOT)
+find_package(MLX CONFIG REQUIRED)

 # ----------------------------- Extensions -----------------------------

 # Add library

View File

@@ -1,20 +1,14 @@
-// Copyright © 2023-2024 Apple Inc.
+// Copyright © 2023-2025 Apple Inc.

-#include <cassert>
 #include <iostream>
 #include <sstream>

-#include "mlx/backend/common/copy.h"
 #include "mlx/backend/common/utils.h"
-#include "mlx/backend/cpu/copy.h"
+#include "mlx/backend/cpu/encoder.h"
 #include "mlx/utils.h"

 #include "axpby/axpby.h"

-#ifdef ACCELERATE_NEW_LAPACK
-#include <vecLib/cblas_new.h>
-#endif

 #ifdef _METAL_
 #include "mlx/backend/metal/device.h"
 #include "mlx/backend/metal/utils.h"
@@ -76,136 +70,65 @@ void axpby_impl(
const mx::array& y, const mx::array& y,
mx::array& out, mx::array& out,
float alpha_, float alpha_,
float beta_) { float beta_,
// We only allocate memory when we are ready to fill the output mx::Stream stream) {
// malloc_or_wait synchronously allocates available memory out.set_data(mx::allocator::malloc(out.nbytes()));
// There may be a wait executed here if the allocation is requested
// under memory-pressured conditions
out.set_data(mx::allocator::malloc_or_wait(out.nbytes()));
// Collect input and output data pointers // Get the CPU command encoder and register input and output arrays
const T* x_ptr = x.data<T>(); auto& encoder = mx::cpu::get_command_encoder(stream);
const T* y_ptr = y.data<T>(); encoder.set_input_array(x);
T* out_ptr = out.data<T>(); encoder.set_input_array(y);
encoder.set_output_array(out);
// Cast alpha and beta to the relevant types // Launch the CPU kernel
T alpha = static_cast<T>(alpha_); encoder.dispatch([x_ptr = x.data<T>(),
T beta = static_cast<T>(beta_); y_ptr = y.data<T>(),
out_ptr = out.data<T>(),
size = out.size(),
shape = out.shape(),
x_strides = x.strides(),
y_strides = y.strides(),
alpha_,
beta_]() {
// Cast alpha and beta to the relevant types
T alpha = static_cast<T>(alpha_);
T beta = static_cast<T>(beta_);
// Do the element-wise operation for each output // Do the element-wise operation for each output
for (size_t out_idx = 0; out_idx < out.size(); out_idx++) { for (size_t out_idx = 0; out_idx < size; out_idx++) {
// Map linear indices to offsets in x and y // Map linear indices to offsets in x and y
auto x_offset = mx::elem_to_loc(out_idx, x.shape(), x.strides()); auto x_offset = mx::elem_to_loc(out_idx, shape, x_strides);
auto y_offset = mx::elem_to_loc(out_idx, y.shape(), y.strides()); auto y_offset = mx::elem_to_loc(out_idx, shape, y_strides);
// We allocate the output to be contiguous and regularly strided // We allocate the output to be contiguous and regularly strided
// (defaults to row major) and hence it doesn't need additional mapping // (defaults to row major) and hence it doesn't need additional mapping
out_ptr[out_idx] = alpha * x_ptr[x_offset] + beta * y_ptr[y_offset]; out_ptr[out_idx] = alpha * x_ptr[x_offset] + beta * y_ptr[y_offset];
} }
});
} }
/** Fall back implementation for evaluation on CPU */ void Axpby::eval_cpu(
void Axpby::eval(
const std::vector<mx::array>& inputs, const std::vector<mx::array>& inputs,
std::vector<mx::array>& outputs) { std::vector<mx::array>& outputs) {
// Check the inputs (registered in the op while constructing the out array)
assert(inputs.size() == 2);
auto& x = inputs[0]; auto& x = inputs[0];
auto& y = inputs[1]; auto& y = inputs[1];
auto& out = outputs[0]; auto& out = outputs[0];
// Dispatch to the correct dtype // Dispatch to the correct dtype
if (out.dtype() == mx::float32) { if (out.dtype() == mx::float32) {
return axpby_impl<float>(x, y, out, alpha_, beta_); return axpby_impl<float>(x, y, out, alpha_, beta_, stream());
} else if (out.dtype() == mx::float16) { } else if (out.dtype() == mx::float16) {
return axpby_impl<mx::float16_t>(x, y, out, alpha_, beta_); return axpby_impl<mx::float16_t>(x, y, out, alpha_, beta_, stream());
} else if (out.dtype() == mx::bfloat16) { } else if (out.dtype() == mx::bfloat16) {
return axpby_impl<mx::bfloat16_t>(x, y, out, alpha_, beta_); return axpby_impl<mx::bfloat16_t>(x, y, out, alpha_, beta_, stream());
} else if (out.dtype() == mx::complex64) { } else if (out.dtype() == mx::complex64) {
return axpby_impl<mx::complex64_t>(x, y, out, alpha_, beta_); return axpby_impl<mx::complex64_t>(x, y, out, alpha_, beta_, stream());
} else { } else {
throw std::runtime_error( throw std::runtime_error(
"Axpby is only supported for floating point types."); "Axpby is only supported for floating point types.");
} }
} }
///////////////////////////////////////////////////////////////////////////////
// Primitive Accelerate Backend Implementation
///////////////////////////////////////////////////////////////////////////////
#ifdef ACCELERATE_NEW_LAPACK
template <typename T>
void axpby_impl_accelerate(
const mx::array& x,
const mx::array& y,
mx::array& out,
float alpha_,
float beta_) {
// Accelerate library provides catlas_saxpby which does
// Y = (alpha * X) + (beta * Y) in place
// To use it, we first copy the data in y over to the output array
// This specialization requires both x and y be contiguous in the same mode
// i.e: corresponding linear indices in both point to corresponding elements
// The data in the output array is allocated to match the strides in y
// such that x, y, and out are contiguous in the same mode and
// no transposition is needed
out.set_data(mx::allocator::malloc_or_wait(out.nbytes()));
// We then copy over the elements using the contiguous vector specialization
copy_inplace(y, out, mx::CopyType::Vector);
// Get x and y pointers for catlas_saxpby
const T* x_ptr = x.data<T>();
T* y_ptr = out.data<T>();
T alpha = static_cast<T>(alpha_);
T beta = static_cast<T>(beta_);
// Call the inplace accelerate operator
catlas_saxpby(
/* N = */ out.size(),
/* ALPHA = */ alpha,
/* X = */ x_ptr,
/* INCX = */ 1,
/* BETA = */ beta,
/* Y = */ y_ptr,
/* INCY = */ 1);
}
/** Evaluate primitive on CPU using accelerate specializations */
void Axpby::eval_cpu(
const std::vector<mx::array>& inputs,
std::vector<mx::array>& outputs) {
assert(inputs.size() == 2);
auto& x = inputs[0];
auto& y = inputs[1];
auto& out = outputs[0];
// Accelerate specialization for contiguous single precision float arrays
if (out.dtype() == mx::float32 &&
((x.flags().row_contiguous && y.flags().row_contiguous) ||
(x.flags().col_contiguous && y.flags().col_contiguous))) {
axpby_impl_accelerate<float>(x, y, out, alpha_, beta_);
return;
}
// Fall back to common backend if specializations are not available
eval(inputs, outputs);
}
#else // Accelerate not available
/** Evaluate primitive on CPU falling back to common backend */
void Axpby::eval_cpu(
const std::vector<mx::array>& inputs,
std::vector<mx::array>& outputs) {
eval(inputs, outputs);
}
#endif
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
// Primitive Metal Backend Implementation // Primitive Metal Backend Implementation
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
@@ -217,7 +140,6 @@ void Axpby::eval_gpu(
const std::vector<mx::array>& inputs, const std::vector<mx::array>& inputs,
std::vector<mx::array>& outputs) { std::vector<mx::array>& outputs) {
// Prepare inputs // Prepare inputs
assert(inputs.size() == 2);
auto& x = inputs[0]; auto& x = inputs[0];
auto& y = inputs[1]; auto& y = inputs[1];
auto& out = outputs[0]; auto& out = outputs[0];
@@ -236,12 +158,12 @@ void Axpby::eval_gpu(
// Allocate output memory with strides based on specialization // Allocate output memory with strides based on specialization
if (contiguous_kernel) { if (contiguous_kernel) {
out.set_data( out.set_data(
mx::allocator::malloc_or_wait(x.data_size() * out.itemsize()), mx::allocator::malloc(x.data_size() * out.itemsize()),
x.data_size(), x.data_size(),
x.strides(), x.strides(),
x.flags()); x.flags());
} else { } else {
out.set_data(mx::allocator::malloc_or_wait(out.nbytes())); out.set_data(mx::allocator::malloc(out.nbytes()));
} }
// Resolve name of kernel (corresponds to axpby.metal) // Resolve name of kernel (corresponds to axpby.metal)
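
In the CPU path above, the work now goes through the CPU command encoder: the inputs and output are registered, and the kernel body runs inside a lambda handed to ``dispatch`` that captures plain pointers and sizes by value. A condensed, contiguous-only sketch of that pattern (written for illustration, reusing the includes and ``mx`` alias from the file above; only ``cpu::get_command_encoder``, ``set_input_array``/``set_output_array`` and ``dispatch`` are taken from the diff):

    template <typename T>
    void axpby_contiguous(
        const mx::array& x,
        const mx::array& y,
        mx::array& out,
        float alpha,
        float beta,
        mx::Stream stream) {
      out.set_data(mx::allocator::malloc(out.nbytes()));
      auto& enc = mx::cpu::get_command_encoder(stream);
      enc.set_input_array(x);
      enc.set_input_array(y);
      enc.set_output_array(out);
      // Capture raw pointers and sizes by value; the lambda may run later,
      // after eval_cpu has already returned.
      enc.dispatch([x_ptr = x.data<T>(),
                    y_ptr = y.data<T>(),
                    out_ptr = out.data<T>(),
                    n = out.size(),
                    alpha,
                    beta]() {
        for (size_t i = 0; i < n; ++i) {
          out_ptr[i] = static_cast<T>(alpha) * x_ptr[i] +
              static_cast<T>(beta) * y_ptr[i];
        }
      });
    }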

View File

@@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc. // Copyright © 2023-2025 Apple Inc.
#pragma once #pragma once
@@ -85,11 +85,6 @@ class Axpby : public mx::Primitive {
private: private:
float alpha_; float alpha_;
float beta_; float beta_;
/** Fall back implementation for evaluation on CPU */
void eval(
const std::vector<mx::array>& inputs,
std::vector<mx::array>& outputs);
}; };
} // namespace my_ext } // namespace my_ext

View File

@@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc. // Copyright © 2023-2025 Apple Inc.
#include <metal_stdlib> #include <metal_stdlib>

View File

@@ -17,9 +17,13 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/transforms.cpp ${CMAKE_CURRENT_SOURCE_DIR}/transforms.cpp
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp ${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp
${CMAKE_CURRENT_SOURCE_DIR}/version.cpp
${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/metal.h) ${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/metal.h)
# Define MLX_VERSION only in the version.cpp file.
add_library(mlx_version STATIC ${CMAKE_CURRENT_SOURCE_DIR}/version.cpp)
target_compile_definitions(mlx_version PRIVATE MLX_VERSION="${MLX_VERSION}")
target_link_libraries(mlx PRIVATE $<BUILD_INTERFACE:mlx_version>)
if(MSVC) if(MSVC)
# Disable some MSVC warnings to speed up compilation. # Disable some MSVC warnings to speed up compilation.
target_compile_options(mlx PUBLIC /wd4068 /wd4244 /wd4267 /wd4804) target_compile_options(mlx PUBLIC /wd4068 /wd4244 /wd4267 /wd4804)

View File

@@ -4,12 +4,11 @@
#include <sstream> #include <sstream>
#include "mlx/allocator.h" #include "mlx/allocator.h"
#include "mlx/scheduler.h"
namespace mlx::core::allocator { namespace mlx::core::allocator {
Buffer malloc(size_t size) { Buffer malloc(size_t size) {
auto buffer = allocator().malloc(size, /* allow_swap */ true); auto buffer = allocator().malloc(size);
if (size && !buffer.ptr()) { if (size && !buffer.ptr()) {
std::ostringstream msg; std::ostringstream msg;
msg << "[malloc] Unable to allocate " << size << " bytes."; msg << "[malloc] Unable to allocate " << size << " bytes.";
@@ -22,45 +21,4 @@ void free(Buffer buffer) {
allocator().free(buffer); allocator().free(buffer);
} }
Buffer CommonAllocator::malloc(size_t size, bool) {
void* ptr = std::malloc(size + sizeof(size_t));
if (ptr != nullptr) {
*static_cast<size_t*>(ptr) = size;
}
return Buffer{ptr};
}
void CommonAllocator::free(Buffer buffer) {
std::free(buffer.ptr());
}
size_t CommonAllocator::size(Buffer buffer) const {
if (buffer.ptr() == nullptr) {
return 0;
}
return *static_cast<size_t*>(buffer.ptr());
}
Buffer malloc_or_wait(size_t size) {
auto buffer = allocator().malloc(size);
while (size && !buffer.ptr() && scheduler::n_active_tasks() > 0) {
scheduler::wait_for_one();
buffer = allocator().malloc(size);
}
// Try swapping if needed
if (size && !buffer.ptr()) {
buffer = allocator().malloc(size, /* allow_swap = */ true);
}
if (size && !buffer.ptr()) {
std::ostringstream msg;
msg << "[malloc_or_wait] Unable to allocate " << size << " bytes.";
throw std::runtime_error(msg.str());
}
return buffer;
}
} // namespace mlx::core::allocator } // namespace mlx::core::allocator

View File

@@ -32,14 +32,10 @@ Buffer malloc(size_t size);
void free(Buffer buffer); void free(Buffer buffer);
// Wait for running tasks to finish and free up memory
// if allocation fails
Buffer malloc_or_wait(size_t size);
class Allocator { class Allocator {
/** Abstract base class for a memory allocator. */ /** Abstract base class for a memory allocator. */
public: public:
virtual Buffer malloc(size_t size, bool allow_swap = false) = 0; virtual Buffer malloc(size_t size) = 0;
virtual void free(Buffer buffer) = 0; virtual void free(Buffer buffer) = 0;
virtual size_t size(Buffer buffer) const = 0; virtual size_t size(Buffer buffer) const = 0;
@@ -53,16 +49,4 @@ class Allocator {
Allocator& allocator(); Allocator& allocator();
class CommonAllocator : public Allocator {
/** A general CPU allocator. */
public:
virtual Buffer malloc(size_t size, bool allow_swap = false) override;
virtual void free(Buffer buffer) override;
virtual size_t size(Buffer buffer) const override;
private:
CommonAllocator() = default;
friend Allocator& allocator();
};
} // namespace mlx::core::allocator } // namespace mlx::core::allocator
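
With ``malloc_or_wait`` and the ``allow_swap`` flag gone, call sites allocate through ``allocator::malloc`` directly and rely on the ``std::runtime_error`` it throws to surface allocation failure. A sketch of a typical call site after this change (illustrative only):

    #include "mlx/allocator.h"
    #include "mlx/array.h"

    // Allocate the buffer for a freshly computed output; on failure malloc()
    // itself throws, there is no retry/wait loop any more.
    void allocate_output(mlx::core::array& out) {
      out.set_data(mlx::core::allocator::malloc(out.nbytes()));
    }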

View File

@@ -56,6 +56,18 @@ std::vector<array> array::make_arrays(
return outputs; return outputs;
} }
array array::unsafe_weak_copy(const array& other) {
auto cpy = array(other.shape(), other.dtype(), nullptr, {});
cpy.set_data(
other.buffer(),
other.data_size(),
other.strides(),
other.flags(),
[](auto) {});
cpy.array_desc_->data_ptr = other.array_desc_->data_ptr;
return cpy;
}
array::array(std::initializer_list<float> data) array::array(std::initializer_list<float> data)
: array_desc_(std::make_shared<ArrayDesc>( : array_desc_(std::make_shared<ArrayDesc>(
Shape{static_cast<ShapeElem>(data.size())}, Shape{static_cast<ShapeElem>(data.size())},
@@ -76,35 +88,27 @@ array::array(allocator::Buffer data, Shape shape, Dtype dtype, Deleter deleter)
set_data(data, deleter); set_data(data, deleter);
} }
array::array(
allocator::Buffer data,
Shape shape,
Dtype dtype,
Strides strides,
size_t data_size,
Flags flags,
Deleter deleter)
: array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
set_data(data, data_size, std::move(strides), flags, deleter);
}
void array::detach() { void array::detach() {
array_desc_->primitive = nullptr;
for (auto& s : array_desc_->siblings) {
s.array_desc_->primitive = nullptr;
}
for (auto& s : array_desc_->siblings) { for (auto& s : array_desc_->siblings) {
s.array_desc_->inputs.clear(); s.array_desc_->inputs.clear();
s.array_desc_->siblings.clear(); s.array_desc_->siblings.clear();
s.array_desc_->position = 0; s.array_desc_->position = 0;
s.array_desc_->primitive = nullptr;
} }
array_desc_->inputs.clear(); array_desc_->inputs.clear();
array_desc_->siblings.clear(); array_desc_->siblings.clear();
array_desc_->position = 0; array_desc_->position = 0;
array_desc_->primitive = nullptr;
} }
bool array::is_available() const { bool array::is_available() const {
if (status() == Status::available) { if (status() == Status::available) {
return true; return true;
} else if (status() == Status::evaluated && event().is_signaled()) { } else if (
status() == Status::evaluated &&
(!event().valid() || event().is_signaled())) {
set_status(Status::available); set_status(Status::available);
return true; return true;
} }
@@ -113,7 +117,10 @@ bool array::is_available() const {
void array::wait() { void array::wait() {
if (!is_available()) { if (!is_available()) {
event().wait(); if (event().valid()) {
event().wait();
detach_event();
}
set_status(Status::available); set_status(Status::available);
} }
} }
@@ -174,34 +181,13 @@ void array::copy_shared_buffer(const array& other) {
copy_shared_buffer(other, other.strides(), other.flags(), other.data_size()); copy_shared_buffer(other, other.strides(), other.flags(), other.data_size());
} }
void array::move_shared_buffer(
array other,
const Strides& strides,
Flags flags,
size_t data_size,
size_t offset /* = 0 */) {
array_desc_->data = std::move(other.array_desc_->data);
array_desc_->strides = strides;
array_desc_->flags = flags;
array_desc_->data_size = data_size;
auto char_offset = sizeof(char) * itemsize() * offset;
auto data_ptr = other.array_desc_->data_ptr;
other.array_desc_->data_ptr = nullptr;
array_desc_->data_ptr =
static_cast<void*>(static_cast<char*>(data_ptr) + char_offset);
}
void array::move_shared_buffer(array other) {
move_shared_buffer(other, other.strides(), other.flags(), other.data_size());
}
array::~array() { array::~array() {
if (array_desc_ == nullptr) { if (array_desc_ == nullptr) {
return; return;
} }
// Ignore arrays that might be detached during eval // Detached/detaching
if (status() == array::Status::scheduled) { if (array_desc_->primitive == nullptr) {
return; return;
} }

View File

@@ -199,6 +199,13 @@ class array {
const std::shared_ptr<Primitive>& primitive, const std::shared_ptr<Primitive>& primitive,
const std::vector<array>& inputs); const std::vector<array>& inputs);
/**
* Get a new array that refers to the same data as the input but with a
* non-owning pointer to it. Note the array is detached from the graph and has
* no inputs, siblings or primitive.
*/
static array unsafe_weak_copy(const array& other);
/** A unique identifier for an array. */ /** A unique identifier for an array. */
std::uintptr_t id() const { std::uintptr_t id() const {
return reinterpret_cast<std::uintptr_t>(array_desc_.get()); return reinterpret_cast<std::uintptr_t>(array_desc_.get());
@@ -243,18 +250,6 @@ class array {
bool col_contiguous : 1; bool col_contiguous : 1;
}; };
/** Build an array from all the info held by the array description. Including
* the buffer, strides, flags.
*/
explicit array(
allocator::Buffer data,
Shape shape,
Dtype dtype,
Strides strides,
size_t data_size,
Flags flags,
Deleter deleter = allocator::free);
/** The array's primitive. */ /** The array's primitive. */
Primitive& primitive() const { Primitive& primitive() const {
return *(array_desc_->primitive); return *(array_desc_->primitive);
@@ -365,11 +360,6 @@ class array {
// For example, the status of `x` in `auto x = a + b`. // For example, the status of `x` in `auto x = a + b`.
unscheduled, unscheduled,
// The output of a computation which has been scheduled but `eval_*` has

// not yet been called on the array's primitive. A possible
// status of `x` in `auto x = a + b; eval(x);`
scheduled,
// The array's `eval_*` function has been run, but the computation is not // The array's `eval_*` function has been run, but the computation is not
// necessarily complete. The array will have memory allocated and if it is // necessarily complete. The array will have memory allocated and if it is
// not a tracer then it will be detached from the graph. // not a tracer then it will be detached from the graph.
@@ -406,6 +396,10 @@ class array {
array_desc_->event = std::move(e); array_desc_->event = std::move(e);
} }
void detach_event() const {
array_desc_->event = Event{};
}
// Mark the array as a tracer array (true) or not. // Mark the array as a tracer array (true) or not.
void set_tracer(bool is_tracer) { void set_tracer(bool is_tracer) {
array_desc_->is_tracer = is_tracer; array_desc_->is_tracer = is_tracer;
@@ -431,15 +425,6 @@ class array {
void copy_shared_buffer(const array& other); void copy_shared_buffer(const array& other);
void move_shared_buffer(
array other,
const Strides& strides,
Flags flags,
size_t data_size,
size_t offset = 0);
void move_shared_buffer(array other);
void overwrite_descriptor(const array& other) { void overwrite_descriptor(const array& other) {
array_desc_ = other.array_desc_; array_desc_ = other.array_desc_;
} }
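
``unsafe_weak_copy`` exists so that ``eval_cpu`` implementations can hand the command encoder a view of an array that shares the buffer but carries no primitive, inputs, or siblings, so capturing it in a dispatch lambda does not keep the graph alive. A minimal sketch of the capture pattern (the same shape as the ``eval_cpu`` hunks later in this comparison; kernel body elided, backend includes assumed):

    namespace mx = mlx::core;

    void eval_cpu_sketch(const mx::array& in, mx::array& out, mx::Stream s) {
      out.set_data(mx::allocator::malloc(out.nbytes()));
      auto& enc = mx::cpu::get_command_encoder(s);
      enc.set_input_array(in);
      enc.set_output_array(out);
      enc.dispatch([in = mx::array::unsafe_weak_copy(in),
                    out = mx::array::unsafe_weak_copy(out)]() mutable {
        // Operate on in.data<...>() / out.data<...>() here; the weak copies
        // are detached from the graph and only alias the buffers.
      });
    }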

View File

@@ -1,6 +1,7 @@
target_sources( target_sources(
mlx mlx
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/broadcasting.cpp
${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
${CMAKE_CURRENT_SOURCE_DIR}/common.cpp ${CMAKE_CURRENT_SOURCE_DIR}/common.cpp
${CMAKE_CURRENT_SOURCE_DIR}/load.cpp ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp

View File

@@ -38,25 +38,20 @@ inline void set_binary_op_output_data(
const array& a, const array& a,
const array& b, const array& b,
array& out, array& out,
BinaryOpType bopt, BinaryOpType bopt) {
bool donate_with_move = false) {
bool b_donatable = is_donatable(b, out); bool b_donatable = is_donatable(b, out);
bool a_donatable = is_donatable(a, out); bool a_donatable = is_donatable(a, out);
switch (bopt) { switch (bopt) {
case BinaryOpType::ScalarScalar: case BinaryOpType::ScalarScalar:
out.set_data( out.set_data(
allocator::malloc_or_wait(out.itemsize()), 1, a.strides(), a.flags()); allocator::malloc(out.itemsize()), 1, a.strides(), a.flags());
break; break;
case BinaryOpType::ScalarVector: case BinaryOpType::ScalarVector:
if (b_donatable) { if (b_donatable) {
if (donate_with_move) { out.copy_shared_buffer(b);
out.move_shared_buffer(b);
} else {
out.copy_shared_buffer(b);
}
} else { } else {
out.set_data( out.set_data(
allocator::malloc_or_wait(b.data_size() * out.itemsize()), allocator::malloc(b.data_size() * out.itemsize()),
b.data_size(), b.data_size(),
b.strides(), b.strides(),
b.flags()); b.flags());
@@ -64,14 +59,10 @@ inline void set_binary_op_output_data(
break; break;
case BinaryOpType::VectorScalar: case BinaryOpType::VectorScalar:
if (a_donatable) { if (a_donatable) {
if (donate_with_move) { out.copy_shared_buffer(a);
out.move_shared_buffer(a);
} else {
out.copy_shared_buffer(a);
}
} else { } else {
out.set_data( out.set_data(
allocator::malloc_or_wait(a.data_size() * out.itemsize()), allocator::malloc(a.data_size() * out.itemsize()),
a.data_size(), a.data_size(),
a.strides(), a.strides(),
a.flags()); a.flags());
@@ -79,20 +70,12 @@ inline void set_binary_op_output_data(
break; break;
case BinaryOpType::VectorVector: case BinaryOpType::VectorVector:
if (a_donatable) { if (a_donatable) {
if (donate_with_move) { out.copy_shared_buffer(a);
out.move_shared_buffer(a);
} else {
out.copy_shared_buffer(a);
}
} else if (b_donatable) { } else if (b_donatable) {
if (donate_with_move) { out.copy_shared_buffer(b);
out.move_shared_buffer(b);
} else {
out.copy_shared_buffer(b);
}
} else { } else {
out.set_data( out.set_data(
allocator::malloc_or_wait(a.data_size() * out.itemsize()), allocator::malloc(a.data_size() * out.itemsize()),
a.data_size(), a.data_size(),
a.strides(), a.strides(),
a.flags()); a.flags());
@@ -100,20 +83,12 @@ inline void set_binary_op_output_data(
break; break;
case BinaryOpType::General: case BinaryOpType::General:
if (a_donatable && a.flags().row_contiguous && a.size() == out.size()) { if (a_donatable && a.flags().row_contiguous && a.size() == out.size()) {
if (donate_with_move) { out.copy_shared_buffer(a);
out.move_shared_buffer(a);
} else {
out.copy_shared_buffer(a);
}
} else if ( } else if (
b_donatable && b.flags().row_contiguous && b.size() == out.size()) { b_donatable && b.flags().row_contiguous && b.size() == out.size()) {
if (donate_with_move) { out.copy_shared_buffer(b);
out.move_shared_buffer(b);
} else {
out.copy_shared_buffer(b);
}
} else { } else {
out.set_data(allocator::malloc_or_wait(out.nbytes())); out.set_data(allocator::malloc(out.nbytes()));
} }
break; break;
} }
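
With ``move_shared_buffer`` and the ``donate_with_move`` flag removed, donation is always expressed through ``copy_shared_buffer``, still gated by ``is_donatable`` plus the contiguity and size checks above. The same idea for a one-input op, as a small illustrative helper (not part of the diff; assumes the usual common-backend includes and ``mx`` alias):

    // Reuse the input's buffer for the output when it can be donated,
    // otherwise allocate fresh memory.
    inline void set_unary_output_data(const mx::array& in, mx::array& out) {
      if (in.flags().contiguous && mx::is_donatable(in, out)) {
        out.copy_shared_buffer(in);
      } else {
        out.set_data(mx::allocator::malloc(out.nbytes()));
      }
    }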

View File

@@ -0,0 +1,24 @@
// Copyright © 2024 Apple Inc.
#include "mlx/backend/common/utils.h"
namespace mlx::core {
void broadcast(const array& in, array& out) {
if (out.size() == 0) {
out.set_data(nullptr);
return;
}
Strides strides(out.ndim(), 0);
int diff = out.ndim() - in.ndim();
for (int i = in.ndim() - 1; i >= 0; --i) {
strides[i + diff] = (in.shape()[i] == 1) ? 0 : in.strides()[i];
}
auto flags = in.flags();
if (out.size() > in.size()) {
flags.row_contiguous = flags.col_contiguous = false;
}
out.copy_shared_buffer(in, strides, flags, in.data_size());
}
} // namespace mlx::core
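
The ``broadcast`` helper materializes nothing: it assigns a stride of 0 to every broadcast (size-1 or prepended) axis and re-describes the input's buffer for the output. A worked example with made-up shapes to make the stride rule concrete:

    // Illustrative numbers only: an input of shape {3, 1} with strides {1, 1}
    // broadcast to an output of shape {2, 3, 4} gets output strides {0, 1, 0};
    // the prepended axis and the size-1 axis stride by zero, so every output
    // element aliases an existing input element and no data is copied.
    mlx::core::Strides broadcast_strides_example() {
      mlx::core::Shape in_shape{3, 1};
      mlx::core::Strides in_strides{1, 1};
      int out_ndim = 3;
      mlx::core::Strides strides(out_ndim, 0);
      int diff = out_ndim - static_cast<int>(in_shape.size());
      for (int i = static_cast<int>(in_shape.size()) - 1; i >= 0; --i) {
        strides[i + diff] = (in_shape[i] == 1) ? 0 : in_strides[i];
      }
      return strides;  // {0, 1, 0}
    }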

View File

@@ -1,10 +1,11 @@
// Copyright © 2024 Apple Inc. // Copyright © 2024 Apple Inc.
#pragma once #pragma once
#include "mlx/array.h"
namespace mlx::core { namespace mlx::core {
void encode_wait(Event e); void broadcast(const array& in, array& out);
void encode_signal(Event e);
} // namespace mlx::core } // namespace mlx::core

View File

@@ -1,6 +1,7 @@
// Copyright © 2024 Apple Inc. // Copyright © 2024 Apple Inc.
#include <cassert> #include <cassert>
#include "mlx/backend/common/broadcasting.h"
#include "mlx/backend/common/utils.h" #include "mlx/backend/common/utils.h"
#include "mlx/primitives.h" #include "mlx/primitives.h"
@@ -39,24 +40,7 @@ void AsStrided::eval(const std::vector<array>& inputs, array& out) {
// rely on data_size anyway. // rely on data_size anyway.
size_t data_size = out.size(); size_t data_size = out.size();
return move_or_copy(in, out, strides_, flags, data_size, offset_); return out.copy_shared_buffer(in, strides_, flags, data_size, offset_);
}
void broadcast(const array& in, array& out) {
if (out.size() == 0) {
out.set_data(nullptr);
return;
}
Strides strides(out.ndim(), 0);
int diff = out.ndim() - in.ndim();
for (int i = in.ndim() - 1; i >= 0; --i) {
strides[i + diff] = (in.shape()[i] == 1) ? 0 : in.strides()[i];
}
auto flags = in.flags();
if (out.size() > in.size()) {
flags.row_contiguous = flags.col_contiguous = false;
}
move_or_copy(in, out, strides, flags, in.data_size());
} }
void Broadcast::eval(const std::vector<array>& inputs, array& out) { void Broadcast::eval(const std::vector<array>& inputs, array& out) {
@@ -69,7 +53,7 @@ void BroadcastAxes::eval(const std::vector<array>& inputs, array& out) {
void Copy::eval(const std::vector<array>& inputs, array& out) { void Copy::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
move_or_copy(inputs[0], out); out.copy_shared_buffer(inputs[0]);
} }
void CustomTransforms::eval( void CustomTransforms::eval(
@@ -78,7 +62,7 @@ void CustomTransforms::eval(
assert(inputs.size() > outputs.size()); assert(inputs.size() > outputs.size());
for (int i = 0, j = inputs.size() - outputs.size(); i < outputs.size(); for (int i = 0, j = inputs.size() - outputs.size(); i < outputs.size();
i++, j++) { i++, j++) {
move_or_copy(inputs[j], outputs[i]); outputs[i].copy_shared_buffer(inputs[j]);
} }
} }
@@ -87,7 +71,7 @@ void Depends::eval(
std::vector<array>& outputs) { std::vector<array>& outputs) {
assert(inputs.size() > outputs.size()); assert(inputs.size() > outputs.size());
for (int i = 0; i < outputs.size(); i++) { for (int i = 0; i < outputs.size(); i++) {
move_or_copy(inputs[i], outputs[i]); outputs[i].copy_shared_buffer(inputs[i]);
} }
} }
@@ -98,12 +82,12 @@ void ExpandDims::eval(const std::vector<array>& inputs, array& out) {
for (auto ax : axes_) { for (auto ax : axes_) {
strides.insert(strides.begin() + ax, 1); strides.insert(strides.begin() + ax, 1);
} }
move_or_copy(in, out, strides, in.flags(), in.data_size()); out.copy_shared_buffer(in, strides, in.flags(), in.data_size());
} }
void NumberOfElements::eval(const std::vector<array>& inputs, array& out) { void NumberOfElements::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
out.set_data(allocator::malloc_or_wait(out.nbytes())); out.set_data(allocator::malloc(out.nbytes()));
double numel = 1; double numel = 1;
for (auto ax : axes_) { for (auto ax : axes_) {
@@ -210,7 +194,7 @@ void shared_buffer_reshape(
auto max_dim = std::max_element(out.shape().begin(), out.shape().end()); auto max_dim = std::max_element(out.shape().begin(), out.shape().end());
flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim; flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim;
} }
move_or_copy(in, out, out_strides, flags, in.data_size()); out.copy_shared_buffer(in, out_strides, flags, in.data_size());
} }
void Split::eval( void Split::eval(
@@ -276,12 +260,12 @@ void Squeeze::eval(const std::vector<array>& inputs, array& out) {
strides.push_back(in.strides(i)); strides.push_back(in.strides(i));
} }
} }
move_or_copy(in, out, strides, in.flags(), in.data_size()); out.copy_shared_buffer(in, strides, in.flags(), in.data_size());
} }
void StopGradient::eval(const std::vector<array>& inputs, array& out) { void StopGradient::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
move_or_copy(inputs[0], out); out.copy_shared_buffer(inputs[0]);
} }
void Transpose::eval(const std::vector<array>& inputs, array& out) { void Transpose::eval(const std::vector<array>& inputs, array& out) {
@@ -315,7 +299,7 @@ void Transpose::eval(const std::vector<array>& inputs, array& out) {
b_stride *= out.shape(ri); b_stride *= out.shape(ri);
} }
} }
move_or_copy(in, out, out_strides, flags, in.data_size()); out.copy_shared_buffer(in, out_strides, flags, in.data_size());
} }
} // namespace mlx::core } // namespace mlx::core

View File

@@ -161,8 +161,7 @@ void compiled_allocate_outputs(
std::vector<array>& outputs, std::vector<array>& outputs,
const std::vector<array>& inputs_, const std::vector<array>& inputs_,
const std::unordered_set<uintptr_t>& constant_ids_, const std::unordered_set<uintptr_t>& constant_ids_,
bool contiguous, bool contiguous) {
bool move_buffers /* = false */) {
if (contiguous) { if (contiguous) {
int o = 0; int o = 0;
Strides strides; Strides strides;
@@ -178,11 +177,7 @@ void compiled_allocate_outputs(
if (in.itemsize() == outputs[o].itemsize() && !is_scalar(in) && if (in.itemsize() == outputs[o].itemsize() && !is_scalar(in) &&
in.is_donatable() && in.is_donatable() &&
constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) { constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
if (move_buffers) { outputs[o++].copy_shared_buffer(in);
outputs[o++].move_shared_buffer(in);
} else {
outputs[o++].copy_shared_buffer(in);
}
} }
// Get representative input flags to properly set non-donated outputs // Get representative input flags to properly set non-donated outputs
if (strides.empty() && in.size() == outputs[0].size()) { if (strides.empty() && in.size() == outputs[0].size()) {
@@ -193,7 +188,7 @@ void compiled_allocate_outputs(
} }
for (; o < outputs.size(); ++o) { for (; o < outputs.size(); ++o) {
outputs[o].set_data( outputs[o].set_data(
allocator::malloc_or_wait(data_size * outputs[o].itemsize()), allocator::malloc(data_size * outputs[o].itemsize()),
data_size, data_size,
strides, strides,
flags); flags);
@@ -210,18 +205,13 @@ void compiled_allocate_outputs(
if (in.flags().row_contiguous && in.size() == outputs[o].size() && if (in.flags().row_contiguous && in.size() == outputs[o].size() &&
in.itemsize() == outputs[o].itemsize() && in.is_donatable() && in.itemsize() == outputs[o].itemsize() && in.is_donatable() &&
constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) { constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
if (move_buffers) { outputs[o].copy_shared_buffer(
outputs[o].move_shared_buffer( in, outputs[o].strides(), in.flags(), in.data_size());
in, outputs[o].strides(), in.flags(), in.data_size());
} else {
outputs[o].copy_shared_buffer(
in, outputs[o].strides(), in.flags(), in.data_size());
}
o++; o++;
} }
} }
for (; o < outputs.size(); ++o) { for (; o < outputs.size(); ++o) {
outputs[o].set_data(allocator::malloc_or_wait(outputs[o].nbytes())); outputs[o].set_data(allocator::malloc(outputs[o].nbytes()));
} }
} }
} }

View File

@@ -62,7 +62,6 @@ void compiled_allocate_outputs(
std::vector<array>& outputs, std::vector<array>& outputs,
const std::vector<array>& inputs_, const std::vector<array>& inputs_,
const std::unordered_set<uintptr_t>& constant_ids_, const std::unordered_set<uintptr_t>& constant_ids_,
bool contiguous, bool contiguous);
bool move_buffers = false);
} // namespace mlx::core } // namespace mlx::core

View File

@@ -22,4 +22,25 @@ enum class CopyType {
GeneralGeneral GeneralGeneral
}; };
inline bool set_copy_output_data(const array& in, array& out, CopyType ctype) {
if (ctype == CopyType::Vector) {
// If the input is donatable, we are doing a vector copy and the types
// have the same size, then the input buffer can hold the output.
if (in.is_donatable() && in.itemsize() == out.itemsize()) {
out.copy_shared_buffer(in);
return true;
} else {
out.set_data(
allocator::malloc(in.data_size() * out.itemsize()),
in.data_size(),
in.strides(),
in.flags());
return false;
}
} else {
out.set_data(allocator::malloc(out.nbytes()));
return false;
}
}
} // namespace mlx::core } // namespace mlx::core
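
The return value of ``set_copy_output_data`` reports whether the input buffer was donated. A sketch of how a caller might use it (assumed usage for illustration, not shown in this diff; ``mx`` aliases ``mlx::core``):

    // If the vector-copy donation path was taken and no dtype conversion is
    // needed, the output already aliases the data; otherwise dispatch the
    // actual copy/convert kernel for `ctype`.
    void copy_eval_sketch(const mx::array& in, mx::array& out, mx::CopyType ctype) {
      bool donated = mx::set_copy_output_data(in, out, ctype);
      if (donated && in.dtype() == out.dtype()) {
        return;
      }
      // ...dispatch the copy kernel for `ctype` here...
    }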

View File

@@ -3,7 +3,8 @@
#include <algorithm> #include <algorithm>
#include <utility> #include <utility>
#include "mlx/backend/common/load.h" #include "mlx/primitives.h"
#include "mlx/scheduler.h"
namespace { namespace {
@@ -26,26 +27,31 @@ void swap_endianness(uint8_t* data_bytes, size_t N) {
namespace mlx::core { namespace mlx::core {
void load( void Load::eval_cpu(const std::vector<array>& inputs, array& out) {
array& out, out.set_data(allocator::malloc(out.nbytes()));
size_t offset, auto read_task = [out_ptr = out.data<char>(),
const std::shared_ptr<io::Reader>& reader, size = out.size(),
bool swap_endianness_) { itemsize = out.itemsize(),
reader->read(out.data<char>(), out.nbytes(), offset); offset = offset_,
reader = reader_,
if (swap_endianness_) { swap_endianness_ = swap_endianness_]() mutable {
switch (out.itemsize()) { reader->read(out_ptr, size * itemsize, offset);
case 2: if (swap_endianness_) {
swap_endianness<2>(out.data<uint8_t>(), out.data_size()); switch (itemsize) {
break; case 2:
case 4: swap_endianness<2>(reinterpret_cast<uint8_t*>(out_ptr), size);
swap_endianness<4>(out.data<uint8_t>(), out.data_size()); break;
break; case 4:
case 8: swap_endianness<4>(reinterpret_cast<uint8_t*>(out_ptr), size);
swap_endianness<8>(out.data<uint8_t>(), out.data_size()); break;
break; case 8:
swap_endianness<8>(reinterpret_cast<uint8_t*>(out_ptr), size);
break;
}
} }
} };
auto fut = io::thread_pool().enqueue(std::move(read_task)).share();
scheduler::enqueue(stream(), [fut = std::move(fut)]() { fut.wait(); });
} }
} // namespace mlx::core } // namespace mlx::core

View File

@@ -1,14 +0,0 @@
// Copyright © 2024 Apple Inc.
#include "mlx/array.h"
#include "mlx/io/load.h"
namespace mlx::core {
void load(
array& out,
size_t offset,
const std::shared_ptr<io::Reader>& reader,
bool swap_endianess);
} // namespace mlx::core

View File

@@ -36,7 +36,7 @@ void shared_buffer_slice(
flags.col_contiguous = is_col_contiguous; flags.col_contiguous = is_col_contiguous;
flags.contiguous = (no_bsx_size == data_size); flags.contiguous = (no_bsx_size == data_size);
move_or_copy(in, out, out_strides, flags, data_size, data_offset); out.copy_shared_buffer(in, out_strides, flags, data_size, data_offset);
} }
void slice( void slice(

View File

@@ -36,15 +36,10 @@ inline void set_ternary_op_output_data(
const array& b, const array& b,
const array& c, const array& c,
array& out, array& out,
TernaryOpType topt, TernaryOpType topt) {
bool donate_with_move = false) { auto maybe_donate = [&out](const array& x) {
auto maybe_donate = [&out, donate_with_move](const array& x) {
if (is_donatable(x, out)) { if (is_donatable(x, out)) {
if (donate_with_move) { out.copy_shared_buffer(x);
out.move_shared_buffer(x);
} else {
out.copy_shared_buffer(x);
}
return true; return true;
} }
return false; return false;
@@ -53,12 +48,12 @@ inline void set_ternary_op_output_data(
switch (topt) { switch (topt) {
case TernaryOpType::ScalarScalarScalar: case TernaryOpType::ScalarScalarScalar:
out.set_data( out.set_data(
allocator::malloc_or_wait(out.itemsize()), 1, b.strides(), b.flags()); allocator::malloc(out.itemsize()), 1, b.strides(), b.flags());
break; break;
case TernaryOpType::VectorVectorVector: case TernaryOpType::VectorVectorVector:
if (!(maybe_donate(a) || maybe_donate(b) || maybe_donate(c))) { if (!(maybe_donate(a) || maybe_donate(b) || maybe_donate(c))) {
out.set_data( out.set_data(
allocator::malloc_or_wait(out.itemsize() * b.data_size()), allocator::malloc(out.itemsize() * b.data_size()),
b.data_size(), b.data_size(),
b.strides(), b.strides(),
b.flags()); b.flags());
@@ -69,7 +64,7 @@ inline void set_ternary_op_output_data(
if (!((a.flags().row_contiguous && maybe_donate(a)) || if (!((a.flags().row_contiguous && maybe_donate(a)) ||
(b.flags().row_contiguous && maybe_donate(b)) || (b.flags().row_contiguous && maybe_donate(b)) ||
(c.flags().row_contiguous && maybe_donate(c)))) { (c.flags().row_contiguous && maybe_donate(c)))) {
out.set_data(allocator::malloc_or_wait(out.nbytes())); out.set_data(allocator::malloc(out.nbytes()));
} }
break; break;
} }

View File

@@ -4,28 +4,6 @@
namespace mlx::core { namespace mlx::core {
void move_or_copy(const array& in, array& out) {
if (in.is_donatable()) {
out.move_shared_buffer(in);
} else {
out.copy_shared_buffer(in);
}
}
void move_or_copy(
const array& in,
array& out,
const Strides& strides,
array::Flags flags,
size_t data_size,
size_t offset /* = 0 */) {
if (in.is_donatable()) {
out.move_shared_buffer(in, strides, flags, data_size, offset);
} else {
out.copy_shared_buffer(in, strides, flags, data_size, offset);
}
}
std::tuple<Shape, std::vector<Strides>> collapse_contiguous_dims( std::tuple<Shape, std::vector<Strides>> collapse_contiguous_dims(
const Shape& shape, const Shape& shape,
const std::vector<Strides>& strides, const std::vector<Strides>& strides,

View File

@@ -159,15 +159,6 @@ inline bool is_donatable(const array& in, const array& out) {
in.buffer_size() <= out.nbytes() + donation_extra; in.buffer_size() <= out.nbytes() + donation_extra;
} }
void move_or_copy(const array& in, array& out);
void move_or_copy(
const array& in,
array& out,
const Strides& strides,
array::Flags flags,
size_t data_size,
size_t offset = 0);
std::pair<bool, Strides> prepare_reshape(const array& in, const array& out); std::pair<bool, Strides> prepare_reshape(const array& in, const array& out);
void shared_buffer_reshape( void shared_buffer_reshape(

View File

@@ -44,7 +44,9 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp ${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp
${CMAKE_CURRENT_SOURCE_DIR}/eigh.cpp ${CMAKE_CURRENT_SOURCE_DIR}/eigh.cpp
${CMAKE_CURRENT_SOURCE_DIR}/encoder.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp ${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
@@ -56,6 +58,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp ${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
${CMAKE_CURRENT_SOURCE_DIR}/select.cpp ${CMAKE_CURRENT_SOURCE_DIR}/select.cpp
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
${CMAKE_CURRENT_SOURCE_DIR}/logsumexp.cpp
${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp ${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
${CMAKE_CURRENT_SOURCE_DIR}/threefry.cpp ${CMAKE_CURRENT_SOURCE_DIR}/threefry.cpp
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
@@ -65,13 +68,14 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/inverse.cpp ${CMAKE_CURRENT_SOURCE_DIR}/inverse.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cholesky.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cholesky.cpp
${CMAKE_CURRENT_SOURCE_DIR}/unary.cpp ${CMAKE_CURRENT_SOURCE_DIR}/unary.cpp
${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp) ${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp)
if(MLX_BUILD_ACCELERATE) if(MLX_BUILD_ACCELERATE)
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/bnns.cpp) target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/bnns.cpp)
else() else()
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/no_fp16.cpp target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/simd_fp16.cpp
${CMAKE_CURRENT_SOURCE_DIR}/gemms/no_bf16.cpp) ${CMAKE_CURRENT_SOURCE_DIR}/gemms/simd_bf16.cpp)
endif() endif()
if(IOS) if(IOS)

View File

@@ -2,76 +2,27 @@
#pragma once #pragma once
#include "mlx/allocator.h"
#include "mlx/array.h" #include "mlx/array.h"
#include "mlx/backend/cpu/encoder.h"
namespace mlx::core { namespace mlx::core {
namespace { namespace {
template <typename T> template <typename T>
void arange(T start, T next, array& out, size_t size) { void arange(T start, T next, array& out, size_t size, Stream stream) {
auto ptr = out.data<T>(); auto ptr = out.data<T>();
auto step_size = next - start; auto step_size = next - start;
for (int i = 0; i < size; ++i) { auto& encoder = cpu::get_command_encoder(stream);
ptr[i] = start; encoder.set_output_array(out);
start += step_size; encoder.dispatch([ptr, start, step_size, size]() mutable {
} for (int i = 0; i < size; ++i) {
ptr[i] = start;
start += step_size;
}
});
} }
} // namespace } // namespace
void arange(
const std::vector<array>& inputs,
array& out,
double start,
double step) {
assert(inputs.size() == 0);
out.set_data(allocator::malloc_or_wait(out.nbytes()));
switch (out.dtype()) {
case bool_:
throw std::runtime_error("Bool type unsupported for arange.");
break;
case uint8:
arange<uint8_t>(start, start + step, out, out.size());
break;
case uint16:
arange<uint16_t>(start, start + step, out, out.size());
break;
case uint32:
arange<uint32_t>(start, start + step, out, out.size());
break;
case uint64:
arange<uint64_t>(start, start + step, out, out.size());
break;
case int8:
arange<int8_t>(start, start + step, out, out.size());
break;
case int16:
arange<int16_t>(start, start + step, out, out.size());
break;
case int32:
arange<int32_t>(start, start + step, out, out.size());
break;
case int64:
arange<int64_t>(start, start + step, out, out.size());
break;
case float16:
arange<float16_t>(start, start + step, out, out.size());
break;
case float32:
arange<float>(start, start + step, out, out.size());
break;
case float64:
arange<double>(start, start + step, out, out.size());
break;
case bfloat16:
arange<bfloat16_t>(start, start + step, out, out.size());
break;
case complex64:
arange<complex64_t>(start, start + step, out, out.size());
break;
}
}
} // namespace mlx::core } // namespace mlx::core

View File

@@ -3,6 +3,7 @@
#include <cassert> #include <cassert>
#include "mlx/backend/common/utils.h" #include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/primitives.h" #include "mlx/primitives.h"
namespace mlx::core { namespace mlx::core {
@@ -17,15 +18,18 @@ void arg_reduce(const array& in, array& out, const OpT& op, int axis) {
Shape shape = in.shape(); Shape shape = in.shape();
strides.erase(strides.begin() + axis); strides.erase(strides.begin() + axis);
shape.erase(shape.begin() + axis); shape.erase(shape.begin() + axis);
auto in_ptr = in.data<InT>();
auto out_ptr = out.data<uint32_t>();
for (uint32_t i = 0; i < out.size(); ++i) { for (uint32_t i = 0; i < out.size(); ++i) {
auto loc = elem_to_loc(i, shape, strides); auto loc = elem_to_loc(i, shape, strides);
auto in_ptr = in.data<InT>() + loc; auto local_in_ptr = in_ptr + loc;
uint32_t ind_v = 0; uint32_t ind_v = 0;
InT v = (*in_ptr); InT v = (*local_in_ptr);
for (uint32_t j = 0; j < axis_size; ++j, in_ptr += axis_stride) { for (uint32_t j = 0; j < axis_size; ++j, local_in_ptr += axis_stride) {
op(j, (*in_ptr), &ind_v, &v); op(j, (*local_in_ptr), &ind_v, &v);
} }
out.data<uint32_t>()[i] = ind_v; out_ptr[i] = ind_v;
} }
} }
@@ -64,52 +68,59 @@ void arg_reduce_dispatch(
void ArgReduce::eval_cpu(const std::vector<array>& inputs, array& out) { void ArgReduce::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
auto& in = inputs[0]; auto& in = inputs[0];
out.set_data(allocator::malloc_or_wait(out.nbytes())); out.set_data(allocator::malloc(out.nbytes()));
auto& encoder = cpu::get_command_encoder(stream());
switch (in.dtype()) { encoder.set_input_array(in);
case bool_: encoder.set_output_array(out);
arg_reduce_dispatch<bool>(in, out, reduce_type_, axis_); encoder.dispatch([in = array::unsafe_weak_copy(in),
break; out = array::unsafe_weak_copy(out),
case uint8: reduce_type_ = reduce_type_,
arg_reduce_dispatch<uint8_t>(in, out, reduce_type_, axis_); axis_ = axis_]() mutable {
break; switch (in.dtype()) {
case uint16: case bool_:
arg_reduce_dispatch<uint16_t>(in, out, reduce_type_, axis_); arg_reduce_dispatch<bool>(in, out, reduce_type_, axis_);
break; break;
case uint32: case uint8:
arg_reduce_dispatch<uint32_t>(in, out, reduce_type_, axis_); arg_reduce_dispatch<uint8_t>(in, out, reduce_type_, axis_);
break; break;
case uint64: case uint16:
arg_reduce_dispatch<uint64_t>(in, out, reduce_type_, axis_); arg_reduce_dispatch<uint16_t>(in, out, reduce_type_, axis_);
break; break;
case int8: case uint32:
arg_reduce_dispatch<int8_t>(in, out, reduce_type_, axis_); arg_reduce_dispatch<uint32_t>(in, out, reduce_type_, axis_);
break; break;
case int16: case uint64:
arg_reduce_dispatch<int16_t>(in, out, reduce_type_, axis_); arg_reduce_dispatch<uint64_t>(in, out, reduce_type_, axis_);
break; break;
case int32: case int8:
arg_reduce_dispatch<int32_t>(in, out, reduce_type_, axis_); arg_reduce_dispatch<int8_t>(in, out, reduce_type_, axis_);
break; break;
case int64: case int16:
arg_reduce_dispatch<int64_t>(in, out, reduce_type_, axis_); arg_reduce_dispatch<int16_t>(in, out, reduce_type_, axis_);
break; break;
case float16: case int32:
arg_reduce_dispatch<float16_t>(in, out, reduce_type_, axis_); arg_reduce_dispatch<int32_t>(in, out, reduce_type_, axis_);
break; break;
case float32: case int64:
arg_reduce_dispatch<float>(in, out, reduce_type_, axis_); arg_reduce_dispatch<int64_t>(in, out, reduce_type_, axis_);
break; break;
case bfloat16: case float16:
arg_reduce_dispatch<bfloat16_t>(in, out, reduce_type_, axis_); arg_reduce_dispatch<float16_t>(in, out, reduce_type_, axis_);
break; break;
case float64: case float32:
arg_reduce_dispatch<double>(in, out, reduce_type_, axis_); arg_reduce_dispatch<float>(in, out, reduce_type_, axis_);
break; break;
case complex64: case bfloat16:
arg_reduce_dispatch<complex64_t>(in, out, reduce_type_, axis_); arg_reduce_dispatch<bfloat16_t>(in, out, reduce_type_, axis_);
break; break;
} case float64:
arg_reduce_dispatch<double>(in, out, reduce_type_, axis_);
break;
case complex64:
arg_reduce_dispatch<complex64_t>(in, out, reduce_type_, axis_);
break;
}
});
} }
} // namespace mlx::core } // namespace mlx::core

View File

@@ -8,6 +8,7 @@
#include "mlx/backend/cpu/binary.h" #include "mlx/backend/cpu/binary.h"
#include "mlx/backend/cpu/binary_ops.h" #include "mlx/backend/cpu/binary_ops.h"
#include "mlx/backend/cpu/binary_two.h" #include "mlx/backend/cpu/binary_two.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/primitives.h" #include "mlx/primitives.h"
#include "mlx/utils.h" #include "mlx/utils.h"
@@ -16,51 +17,218 @@ namespace mlx::core {
namespace { namespace {
template <typename Op> template <typename Op>
void comparison_op(const array& a, const array& b, array& out, Op op) { void binary(const array& a, const array& b, array& out, Op op, Stream stream) {
switch (a.dtype()) { auto bopt = get_binary_op_type(a, b);
case bool_: set_binary_op_output_data(a, b, out, bopt);
binary_op<bool, bool>(a, b, out, op);
break; auto& encoder = cpu::get_command_encoder(stream);
case uint8: encoder.set_input_array(a);
binary_op<uint8_t, bool>(a, b, out, op); encoder.set_input_array(b);
break; encoder.set_output_array(out);
case uint16: encoder.dispatch([a = array::unsafe_weak_copy(a),
binary_op<uint16_t, bool>(a, b, out, op); b = array::unsafe_weak_copy(b),
break; out = array::unsafe_weak_copy(out),
case uint32: bopt]() mutable {
binary_op<uint32_t, bool>(a, b, out, op); switch (out.dtype()) {
break; case bool_:
case uint64: binary_op<bool, Op>(a, b, out, bopt);
binary_op<uint64_t, bool>(a, b, out, op); break;
break; case uint8:
case int8: binary_op<uint8_t, Op>(a, b, out, bopt);
binary_op<int8_t, bool>(a, b, out, op); break;
break; case uint16:
case int16: binary_op<uint16_t, Op>(a, b, out, bopt);
binary_op<int16_t, bool>(a, b, out, op); break;
break; case uint32:
case int32: binary_op<uint32_t, Op>(a, b, out, bopt);
binary_op<int32_t, bool>(a, b, out, op); break;
break; case uint64:
case int64: binary_op<uint64_t, Op>(a, b, out, bopt);
binary_op<int64_t, bool>(a, b, out, op); break;
break; case int8:
case float16: binary_op<int8_t, Op>(a, b, out, bopt);
binary_op<float16_t, bool>(a, b, out, op); break;
break; case int16:
case float32: binary_op<int16_t, Op>(a, b, out, bopt);
binary_op<float, bool>(a, b, out, op); break;
break; case int32:
case float64: binary_op<int32_t, Op>(a, b, out, bopt);
binary_op<double, bool>(a, b, out, op); break;
break; case int64:
case bfloat16: binary_op<int64_t, Op>(a, b, out, bopt);
binary_op<bfloat16_t, bool>(a, b, out, op); break;
break; case float16:
case complex64: binary_op<float16_t, Op>(a, b, out, bopt);
binary_op<complex64_t, bool>(a, b, out, op); break;
break; case float32:
} binary_op<float, Op>(a, b, out, bopt);
break;
case float64:
binary_op<double, Op>(a, b, out, bopt);
break;
case bfloat16:
binary_op<bfloat16_t, Op>(a, b, out, bopt);
break;
case complex64:
binary_op<complex64_t, Op>(a, b, out, bopt);
break;
}
});
}
template <typename Op>
void comparison_op(
const array& a,
const array& b,
array& out,
Op op,
Stream stream) {
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out, bopt);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
encoder.dispatch([a = array::unsafe_weak_copy(a),
b = array::unsafe_weak_copy(b),
out = array::unsafe_weak_copy(out),
bopt]() mutable {
switch (a.dtype()) {
case bool_:
binary_op<bool, bool, Op>(a, b, out, bopt);
break;
case uint8:
binary_op<uint8_t, bool, Op>(a, b, out, bopt);
break;
case uint16:
binary_op<uint16_t, bool, Op>(a, b, out, bopt);
break;
case uint32:
binary_op<uint32_t, bool, Op>(a, b, out, bopt);
break;
case uint64:
binary_op<uint64_t, bool, Op>(a, b, out, bopt);
break;
case int8:
binary_op<int8_t, bool, Op>(a, b, out, bopt);
break;
case int16:
binary_op<int16_t, bool, Op>(a, b, out, bopt);
break;
case int32:
binary_op<int32_t, bool, Op>(a, b, out, bopt);
break;
case int64:
binary_op<int64_t, bool, Op>(a, b, out, bopt);
break;
case float16:
binary_op<float16_t, bool, Op>(a, b, out, bopt);
break;
case float32:
binary_op<float, bool, Op>(a, b, out, bopt);
break;
case float64:
binary_op<double, bool, Op>(a, b, out, bopt);
break;
case bfloat16:
binary_op<bfloat16_t, bool, Op>(a, b, out, bopt);
break;
case complex64:
binary_op<complex64_t, bool, Op>(a, b, out, bopt);
break;
}
});
}
template <typename Op>
void binary_float(
const array& a,
const array& b,
array& out,
Op op,
Stream stream) {
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out, bopt);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
encoder.dispatch([a = array::unsafe_weak_copy(a),
b = array::unsafe_weak_copy(b),
out = array::unsafe_weak_copy(out),
bopt]() mutable {
switch (out.dtype()) {
case float16:
binary_op<float16_t, Op>(a, b, out, bopt);
break;
case float32:
binary_op<float, Op>(a, b, out, bopt);
break;
case float64:
binary_op<double, Op>(a, b, out, bopt);
break;
case bfloat16:
binary_op<bfloat16_t, Op>(a, b, out, bopt);
break;
default:
throw std::runtime_error(
"[binary_float] Only supports non-complex floating point types.");
}
});
}
template <typename Op>
void binary_int(
const array& a,
const array& b,
array& out,
Op op,
Stream stream) {
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out, bopt);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
encoder.dispatch([a = array::unsafe_weak_copy(a),
b = array::unsafe_weak_copy(b),
out = array::unsafe_weak_copy(out),
bopt]() mutable {
switch (out.dtype()) {
case bool_:
binary_op<bool, Op>(a, b, out, bopt);
case uint8:
binary_op<uint8_t, Op>(a, b, out, bopt);
break;
case uint16:
binary_op<uint16_t, Op>(a, b, out, bopt);
break;
case uint32:
binary_op<uint32_t, Op>(a, b, out, bopt);
break;
case uint64:
binary_op<uint64_t, Op>(a, b, out, bopt);
break;
case int8:
binary_op<int8_t, Op>(a, b, out, bopt);
break;
case int16:
binary_op<int16_t, Op>(a, b, out, bopt);
break;
case int32:
binary_op<int32_t, Op>(a, b, out, bopt);
break;
case int64:
binary_op<int64_t, Op>(a, b, out, bopt);
break;
default:
throw std::runtime_error("[binary_int] Type not supported");
break;
}
});
} }
} // namespace } // namespace
@@ -69,7 +237,7 @@ void Add::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); assert(inputs.size() == 2);
auto& a = inputs[0]; auto& a = inputs[0];
auto& b = inputs[1]; auto& b = inputs[1];
binary(a, b, out, detail::Add()); binary(a, b, out, detail::Add(), stream());
} }
void DivMod::eval_cpu( void DivMod::eval_cpu(
@@ -78,70 +246,89 @@ void DivMod::eval_cpu(
assert(inputs.size() == 2); assert(inputs.size() == 2);
auto& a = inputs[0]; auto& a = inputs[0];
auto& b = inputs[1]; auto& b = inputs[1];
auto integral_op = [](auto x, auto y) { auto bopt = get_binary_op_type(a, b);
return std::make_pair(x / y, x % y); auto& out_a = outputs[0];
}; auto& out_b = outputs[1];
auto float_op = [](auto x, auto y) { set_binary_op_output_data(a, b, out_a, bopt);
return std::make_pair(std::trunc(x / y), std::fmod(x, y)); set_binary_op_output_data(a, b, out_b, bopt);
};
switch (outputs[0].dtype()) { auto& encoder = cpu::get_command_encoder(stream());
case bool_: encoder.set_input_array(a);
binary_op<bool>(a, b, outputs, integral_op); encoder.set_input_array(b);
case uint8: encoder.set_output_array(out_a);
binary_op<uint8_t>(a, b, outputs, integral_op); encoder.set_output_array(out_b);
break;
case uint16: encoder.dispatch([a = array::unsafe_weak_copy(a),
binary_op<uint16_t>(a, b, outputs, integral_op); b = array::unsafe_weak_copy(b),
break; out_a = array::unsafe_weak_copy(out_a),
case uint32: out_b = array::unsafe_weak_copy(out_b),
binary_op<uint32_t>(a, b, outputs, integral_op); bopt]() mutable {
break; auto integral_op = [](auto x, auto y) {
case uint64: return std::make_pair(x / y, x % y);
binary_op<uint64_t>(a, b, outputs, integral_op); };
break; auto float_op = [](auto x, auto y) {
case int8: return std::make_pair(std::trunc(x / y), std::fmod(x, y));
binary_op<int8_t>(a, b, outputs, integral_op); };
break;
case int16: switch (out_a.dtype()) {
binary_op<int16_t>(a, b, outputs, integral_op); case bool_:
break; binary_op<bool>(a, b, out_a, out_b, integral_op, bopt);
case int32: case uint8:
binary_op<int32_t>(a, b, outputs, integral_op); binary_op<uint8_t>(a, b, out_a, out_b, integral_op, bopt);
break; break;
case int64: case uint16:
binary_op<int64_t>(a, b, outputs, integral_op); binary_op<uint16_t>(a, b, out_a, out_b, integral_op, bopt);
break; break;
case float16: case uint32:
binary_op<float16_t>(a, b, outputs, float_op); binary_op<uint32_t>(a, b, out_a, out_b, integral_op, bopt);
break; break;
case float32: case uint64:
binary_op<float>(a, b, outputs, float_op); binary_op<uint64_t>(a, b, out_a, out_b, integral_op, bopt);
break; break;
case float64: case int8:
binary_op<double>(a, b, outputs, float_op); binary_op<int8_t>(a, b, out_a, out_b, integral_op, bopt);
break; break;
case bfloat16: case int16:
binary_op<bfloat16_t>(a, b, outputs, float_op); binary_op<int16_t>(a, b, out_a, out_b, integral_op, bopt);
break; break;
case complex64: case int32:
// Should never get here binary_op<int32_t>(a, b, out_a, out_b, integral_op, bopt);
throw std::runtime_error("[DivMod] Complex type not supported"); break;
break; case int64:
} binary_op<int64_t>(a, b, out_a, out_b, integral_op, bopt);
break;
case float16:
binary_op<float16_t>(a, b, out_a, out_b, float_op, bopt);
break;
case float32:
binary_op<float>(a, b, out_a, out_b, float_op, bopt);
break;
case float64:
binary_op<double>(a, b, out_a, out_b, float_op, bopt);
break;
case bfloat16:
binary_op<bfloat16_t>(a, b, out_a, out_b, float_op, bopt);
break;
case complex64:
// Should never get here
throw std::runtime_error("[DivMod] Complex type not supported");
break;
}
});
} }
void Divide::eval_cpu(const std::vector<array>& inputs, array& out) { void Divide::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); assert(inputs.size() == 2);
auto& a = inputs[0]; auto& a = inputs[0];
auto& b = inputs[1]; auto& b = inputs[1];
binary(a, b, out, detail::Divide()); binary(a, b, out, detail::Divide(), stream());
} }
void Remainder::eval_cpu(const std::vector<array>& inputs, array& out) { void Remainder::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); assert(inputs.size() == 2);
auto& a = inputs[0]; auto& a = inputs[0];
auto& b = inputs[1]; auto& b = inputs[1];
binary(a, b, out, detail::Remainder()); binary(a, b, out, detail::Remainder(), stream());
} }
void Equal::eval_cpu(const std::vector<array>& inputs, array& out) { void Equal::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -149,181 +336,143 @@ void Equal::eval_cpu(const std::vector<array>& inputs, array& out) {
auto& a = inputs[0]; auto& a = inputs[0];
auto& b = inputs[1]; auto& b = inputs[1];
if (equal_nan_) { if (equal_nan_) {
    switch (a.dtype()) {
      case float16:
        binary_op<float16_t, bool>(a, b, out, detail::NaNEqual());
        break;
      case float32:
        binary_op<float, bool>(a, b, out, detail::NaNEqual());
        break;
      case float64:
        binary_op<double, bool>(a, b, out, detail::NaNEqual());
        break;
      case bfloat16:
        binary_op<bfloat16_t, bool>(a, b, out, detail::NaNEqual());
        break;
      case complex64:
        binary_op<complex64_t, bool>(a, b, out, detail::NaNEqual());
        break;
      default:
        throw std::runtime_error(
            "[NanEqual::eval_cpu] Only for floating point types.");
    }

    auto bopt = get_binary_op_type(a, b);
    set_binary_op_output_data(a, b, out, bopt);

    auto& encoder = cpu::get_command_encoder(stream());
    encoder.set_input_array(a);
    encoder.set_input_array(b);
    encoder.set_output_array(out);
    encoder.dispatch([a = array::unsafe_weak_copy(a),
                      b = array::unsafe_weak_copy(b),
                      out = array::unsafe_weak_copy(out),
                      bopt]() mutable {
      switch (a.dtype()) {
        case float16:
          binary_op<float16_t, bool, detail::NaNEqual>(a, b, out, bopt);
          break;
        case float32:
          binary_op<float, bool, detail::NaNEqual>(a, b, out, bopt);
          break;
        case float64:
          binary_op<double, bool, detail::NaNEqual>(a, b, out, bopt);
          break;
        case bfloat16:
          binary_op<bfloat16_t, bool, detail::NaNEqual>(a, b, out, bopt);
          break;
        case complex64:
          binary_op<complex64_t, bool, detail::NaNEqual>(a, b, out, bopt);
          break;
        default:
          throw std::runtime_error(
              "[NanEqual::eval_cpu] Only for floating point types.");
      }
    });
} else { } else {
comparison_op(a, b, out, detail::Equal()); comparison_op(a, b, out, detail::Equal(), stream());
} }
} }
void Greater::eval_cpu(const std::vector<array>& inputs, array& out) { void Greater::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); assert(inputs.size() == 2);
comparison_op(inputs[0], inputs[1], out, detail::Greater()); comparison_op(inputs[0], inputs[1], out, detail::Greater(), stream());
} }
void GreaterEqual::eval_cpu(const std::vector<array>& inputs, array& out) { void GreaterEqual::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); assert(inputs.size() == 2);
comparison_op(inputs[0], inputs[1], out, detail::GreaterEqual()); comparison_op(inputs[0], inputs[1], out, detail::GreaterEqual(), stream());
} }
void Less::eval_cpu(const std::vector<array>& inputs, array& out) { void Less::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); assert(inputs.size() == 2);
comparison_op(inputs[0], inputs[1], out, detail::Less()); comparison_op(inputs[0], inputs[1], out, detail::Less(), stream());
} }
void LessEqual::eval_cpu(const std::vector<array>& inputs, array& out) { void LessEqual::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); assert(inputs.size() == 2);
comparison_op(inputs[0], inputs[1], out, detail::LessEqual()); comparison_op(inputs[0], inputs[1], out, detail::LessEqual(), stream());
} }
void LogAddExp::eval_cpu(const std::vector<array>& inputs, array& out) { void LogAddExp::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); assert(inputs.size() == 2);
auto& a = inputs[0]; auto& a = inputs[0];
auto& b = inputs[1]; auto& b = inputs[1];
  switch (out.dtype()) {
    case float16:
      binary_op<float16_t>(a, b, out, detail::LogAddExp());
      break;
    case float32:
      binary_op<float>(a, b, out, detail::LogAddExp());
      break;
    case float64:
      binary_op<double>(a, b, out, detail::LogAddExp());
      break;
    case bfloat16:
      binary_op<bfloat16_t>(a, b, out, detail::LogAddExp());
      break;
    default:
      throw std::runtime_error(
          "[LogAddExp::eval_cpu] Only supports non-complex floating point types.");
  }
}

  binary_float(a, b, out, detail::LogAddExp(), stream());
}
void LogicalAnd::eval_cpu(const std::vector<array>& inputs, array& out) { void LogicalAnd::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); // LogicalAnd requires two input arrays assert(inputs.size() == 2); // LogicalAnd requires two input arrays
auto& in1 = inputs[0]; auto& in1 = inputs[0];
auto& in2 = inputs[1]; auto& in2 = inputs[1];
binary(in1, in2, out, detail::LogicalAnd()); binary(in1, in2, out, detail::LogicalAnd(), stream());
} }
void LogicalOr::eval_cpu(const std::vector<array>& inputs, array& out) { void LogicalOr::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); // LogicalOr requires two input arrays assert(inputs.size() == 2); // LogicalOr requires two input arrays
auto& in1 = inputs[0]; auto& in1 = inputs[0];
auto& in2 = inputs[1]; auto& in2 = inputs[1];
binary(in1, in2, out, detail::LogicalOr()); binary(in1, in2, out, detail::LogicalOr(), stream());
} }
void Maximum::eval_cpu(const std::vector<array>& inputs, array& out) { void Maximum::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); assert(inputs.size() == 2);
auto& a = inputs[0]; auto& a = inputs[0];
auto& b = inputs[1]; auto& b = inputs[1];
binary(a, b, out, detail::Maximum()); binary(a, b, out, detail::Maximum(), stream());
} }
void Minimum::eval_cpu(const std::vector<array>& inputs, array& out) { void Minimum::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); assert(inputs.size() == 2);
auto& a = inputs[0]; auto& a = inputs[0];
auto& b = inputs[1]; auto& b = inputs[1];
binary(a, b, out, detail::Minimum()); binary(a, b, out, detail::Minimum(), stream());
} }
void Multiply::eval_cpu(const std::vector<array>& inputs, array& out) { void Multiply::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); assert(inputs.size() == 2);
auto& a = inputs[0]; auto& a = inputs[0];
auto& b = inputs[1]; auto& b = inputs[1];
binary(a, b, out, detail::Multiply()); binary(a, b, out, detail::Multiply(), stream());
} }
void NotEqual::eval_cpu(const std::vector<array>& inputs, array& out) { void NotEqual::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); assert(inputs.size() == 2);
comparison_op(inputs[0], inputs[1], out, detail::NotEqual()); comparison_op(inputs[0], inputs[1], out, detail::NotEqual(), stream());
} }
void Power::eval_cpu(const std::vector<array>& inputs, array& out) { void Power::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); assert(inputs.size() == 2);
auto& a = inputs[0]; auto& a = inputs[0];
auto& b = inputs[1]; auto& b = inputs[1];
binary(a, b, out, detail::Power()); binary(a, b, out, detail::Power(), stream());
} }
void Subtract::eval_cpu(const std::vector<array>& inputs, array& out) { void Subtract::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); assert(inputs.size() == 2);
auto& a = inputs[0]; auto& a = inputs[0];
auto& b = inputs[1]; auto& b = inputs[1];
binary(a, b, out, detail::Subtract()); binary(a, b, out, detail::Subtract(), stream());
} }
void BitwiseBinary::eval_cpu(const std::vector<array>& inputs, array& out) { void BitwiseBinary::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); assert(inputs.size() == 2);
auto& a = inputs[0]; auto& a = inputs[0];
auto& b = inputs[1]; auto& b = inputs[1];
auto dispatch_type = [&a, &b, &out](auto op) {
switch (out.dtype()) {
case bool_:
binary_op<bool>(a, b, out, op);
case uint8:
binary_op<uint8_t>(a, b, out, op);
break;
case uint16:
binary_op<uint16_t>(a, b, out, op);
break;
case uint32:
binary_op<uint32_t>(a, b, out, op);
break;
case uint64:
binary_op<uint64_t>(a, b, out, op);
break;
case int8:
binary_op<int8_t>(a, b, out, op);
break;
case int16:
binary_op<int16_t>(a, b, out, op);
break;
case int32:
binary_op<int32_t>(a, b, out, op);
break;
case int64:
binary_op<int64_t>(a, b, out, op);
break;
default:
throw std::runtime_error(
"[BitwiseBinary::eval_cpu] Type not supported");
break;
}
};
switch (op_) { switch (op_) {
case BitwiseBinary::And: case BitwiseBinary::And:
dispatch_type(detail::BitwiseAnd()); binary_int(a, b, out, detail::BitwiseAnd(), stream());
break; break;
case BitwiseBinary::Or: case BitwiseBinary::Or:
dispatch_type(detail::BitwiseOr()); binary_int(a, b, out, detail::BitwiseOr(), stream());
break; break;
case BitwiseBinary::Xor: case BitwiseBinary::Xor:
dispatch_type(detail::BitwiseXor()); binary_int(a, b, out, detail::BitwiseXor(), stream());
break; break;
case BitwiseBinary::LeftShift: case BitwiseBinary::LeftShift:
dispatch_type(detail::LeftShift()); binary_int(a, b, out, detail::LeftShift(), stream());
break; break;
case BitwiseBinary::RightShift: case BitwiseBinary::RightShift:
dispatch_type(detail::RightShift()); binary_int(a, b, out, detail::RightShift(), stream());
break; break;
} }
} }
@@ -332,23 +481,7 @@ void ArcTan2::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); assert(inputs.size() == 2);
const auto& a = inputs[0]; const auto& a = inputs[0];
const auto& b = inputs[1]; const auto& b = inputs[1];
  switch (out.dtype()) {
    case float16:
      binary_op<float16_t>(a, b, out, detail::ArcTan2());
      break;
    case float32:
      binary_op<float>(a, b, out, detail::ArcTan2());
      break;
    case float64:
      binary_op<double>(a, b, out, detail::ArcTan2());
      break;
    case bfloat16:
      binary_op<bfloat16_t>(a, b, out, detail::ArcTan2());
      break;
    default:
      throw std::runtime_error(
          "[ArcTan2::eval_cpu] Only supports non-complex floating point types.");
  }
}

  binary_float(a, b, out, detail::ArcTan2(), stream());
}
} // namespace mlx::core } // namespace mlx::core
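The stream-aware binary, binary_int, and binary_float helpers called above are not visible in the hunks shown here. A plausible shape for them, mirroring the NaNEqual branch (the helper name, template layout, and the exact dtype switch are assumptions, not code from this change):

template <typename Op>
void binary(const array& a, const array& b, array& out, Op, Stream stream) {
  auto bopt = get_binary_op_type(a, b);
  set_binary_op_output_data(a, b, out, bopt);
  auto& encoder = cpu::get_command_encoder(stream);
  encoder.set_input_array(a);
  encoder.set_input_array(b);
  encoder.set_output_array(out);
  // Weak copies let the worker thread use the arrays without touching
  // reference counts.
  encoder.dispatch([a = array::unsafe_weak_copy(a),
                    b = array::unsafe_weak_copy(b),
                    out = array::unsafe_weak_copy(out),
                    bopt]() mutable {
    switch (out.dtype()) {
      case float32:
        binary_op<float, Op>(a, b, out, bopt);
        break;
      // ... remaining dtypes follow the same pattern as the switches above ...
      default:
        throw std::runtime_error("[binary] Unsupported type.");
    }
  });
}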


@@ -3,7 +3,6 @@
#pragma once #pragma once
#include <cassert> #include <cassert>
#include "mlx/allocator.h"
#include "mlx/array.h" #include "mlx/array.h"
#include "mlx/backend/common/binary.h" #include "mlx/backend/common/binary.h"
#include "mlx/backend/common/utils.h" #include "mlx/backend/common/utils.h"
@@ -14,22 +13,18 @@ namespace mlx::core {
template <typename Op> template <typename Op>
struct VectorScalar { struct VectorScalar {
Op op;
VectorScalar(Op op_) : op(op_) {}
template <typename T, typename U> template <typename T, typename U>
void operator()(const T* a, const T* b, U* dst, int size) { void operator()(const T* a, const T* b, U* dst, int size) {
T scalar = *b; T scalar = *b;
constexpr int N = simd::max_size<T>; constexpr int N = simd::max_size<T>;
while (size >= N) { while (size >= N) {
simd::store(dst, op(simd::load<T, N>(a), simd::Simd<T, N>(scalar))); simd::store(dst, Op{}(simd::load<T, N>(a), simd::Simd<T, N>(scalar)));
dst += N; dst += N;
a += N; a += N;
size -= N; size -= N;
} }
while (size-- > 0) { while (size-- > 0) {
*dst = op(*a, scalar); *dst = Op{}(*a, scalar);
dst++; dst++;
a++; a++;
} }
@@ -38,22 +33,18 @@ struct VectorScalar {
template <typename Op> template <typename Op>
struct ScalarVector { struct ScalarVector {
Op op;
ScalarVector(Op op_) : op(op_) {}
template <typename T, typename U> template <typename T, typename U>
void operator()(const T* a, const T* b, U* dst, int size) { void operator()(const T* a, const T* b, U* dst, int size) {
T scalar = *a; T scalar = *a;
constexpr int N = simd::max_size<T>; constexpr int N = simd::max_size<T>;
while (size >= N) { while (size >= N) {
simd::store(dst, op(simd::Simd<T, N>(scalar), simd::load<T, N>(b))); simd::store(dst, Op{}(simd::Simd<T, N>(scalar), simd::load<T, N>(b)));
dst += N; dst += N;
b += N; b += N;
size -= N; size -= N;
} }
while (size-- > 0) { while (size-- > 0) {
*dst = op(scalar, *b); *dst = Op{}(scalar, *b);
dst++; dst++;
b++; b++;
} }
@@ -62,22 +53,18 @@ struct ScalarVector {
template <typename Op> template <typename Op>
struct VectorVector { struct VectorVector {
Op op;
VectorVector(Op op_) : op(op_) {}
template <typename T, typename U> template <typename T, typename U>
void operator()(const T* a, const T* b, U* dst, int size) { void operator()(const T* a, const T* b, U* dst, int size) {
constexpr int N = simd::max_size<T>; constexpr int N = simd::max_size<T>;
while (size >= N) { while (size >= N) {
simd::store(dst, op(simd::load<T, N>(a), simd::load<T, N>(b))); simd::store(dst, Op{}(simd::load<T, N>(a), simd::load<T, N>(b)));
dst += N; dst += N;
a += N; a += N;
b += N; b += N;
size -= N; size -= N;
} }
while (size-- > 0) { while (size-- > 0) {
*dst = op(*a, *b); *dst = Op{}(*a, *b);
dst++; dst++;
a++; a++;
b++; b++;
@@ -90,7 +77,6 @@ void binary_op_dims(
const T* a, const T* a,
const T* b, const T* b,
U* out, U* out,
Op op,
const Shape& shape, const Shape& shape,
const Strides& a_strides, const Strides& a_strides,
const Strides& b_strides, const Strides& b_strides,
@@ -104,12 +90,12 @@ void binary_op_dims(
for (int i = 0; i < N; i++) { for (int i = 0; i < N; i++) {
if constexpr (D > 1) { if constexpr (D > 1) {
binary_op_dims<T, U, Op, D - 1, Strided>( binary_op_dims<T, U, Op, D - 1, Strided>(
a, b, out, op, shape, a_strides, b_strides, out_strides, axis + 1); a, b, out, shape, a_strides, b_strides, out_strides, axis + 1);
} else { } else {
if constexpr (Strided) { if constexpr (Strided) {
op(a, b, out, stride_out); Op{}(a, b, out, stride_out);
} else { } else {
*out = op(*a, *b); *out = Op{}(*a, *b);
} }
} }
out += stride_out; out += stride_out;
@@ -120,66 +106,38 @@ void binary_op_dims(
template <typename T, typename U, bool Strided, typename Op> template <typename T, typename U, bool Strided, typename Op>
void binary_op_dispatch_dims( void binary_op_dispatch_dims(
const array& a, const T* a,
const array& b, const T* b,
array& out, U* out,
Op op,
int dim, int dim,
int size,
const Shape& shape, const Shape& shape,
const Strides& a_strides, const Strides& a_strides,
const Strides& b_strides, const Strides& b_strides,
const Strides& out_strides) { const Strides& out_strides) {
const T* a_ptr = a.data<T>();
const T* b_ptr = b.data<T>();
U* out_ptr = out.data<U>();
switch (dim) { switch (dim) {
case 1: case 1:
binary_op_dims<T, U, Op, 1, Strided>( binary_op_dims<T, U, Op, 1, Strided>(
a_ptr, a, b, out, shape, a_strides, b_strides, out_strides, 0);
b_ptr,
out_ptr,
op,
shape,
a_strides,
b_strides,
out_strides,
0);
return; return;
case 2: case 2:
binary_op_dims<T, U, Op, 2, Strided>( binary_op_dims<T, U, Op, 2, Strided>(
a_ptr, a, b, out, shape, a_strides, b_strides, out_strides, 0);
b_ptr,
out_ptr,
op,
shape,
a_strides,
b_strides,
out_strides,
0);
return; return;
case 3: case 3:
binary_op_dims<T, U, Op, 3, Strided>( binary_op_dims<T, U, Op, 3, Strided>(
a_ptr, a, b, out, shape, a_strides, b_strides, out_strides, 0);
b_ptr,
out_ptr,
op,
shape,
a_strides,
b_strides,
out_strides,
0);
return; return;
} }
ContiguousIterator a_it(shape, a_strides, dim - 3); ContiguousIterator a_it(shape, a_strides, dim - 3);
ContiguousIterator b_it(shape, b_strides, dim - 3); ContiguousIterator b_it(shape, b_strides, dim - 3);
auto stride = out_strides[dim - 4]; auto stride = out_strides[dim - 4];
for (int64_t elem = 0; elem < a.size(); elem += stride) { for (int64_t elem = 0; elem < size; elem += stride) {
binary_op_dims<T, U, Op, 3, Strided>( binary_op_dims<T, U, Op, 3, Strided>(
a_ptr + a_it.loc, a + a_it.loc,
b_ptr + b_it.loc, b + b_it.loc,
out_ptr + elem, out + elem,
op,
shape, shape,
a_strides, a_strides,
b_strides, b_strides,
@@ -191,40 +149,41 @@ void binary_op_dispatch_dims(
} }
template <typename T, typename U, typename Op> template <typename T, typename U, typename Op>
void binary_op(const array& a, const array& b, array& out, Op op) { void binary_op(const array& a, const array& b, array& out, BinaryOpType bopt) {
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out, bopt);
// The full computation is scalar scalar so call the base op once // The full computation is scalar scalar so call the base op once
auto a_ptr = a.data<T>();
auto b_ptr = b.data<T>();
auto out_ptr = out.data<U>();
if (bopt == BinaryOpType::ScalarScalar) { if (bopt == BinaryOpType::ScalarScalar) {
*(out.data<U>()) = op(*a.data<T>(), *b.data<T>()); *out_ptr = Op{}(*a_ptr, *b_ptr);
return; return;
} }
// The full computation is scalar vector so delegate to the op // The full computation is scalar vector so delegate to the op
if (bopt == BinaryOpType::ScalarVector) { if (bopt == BinaryOpType::ScalarVector) {
ScalarVector{op}(a.data<T>(), b.data<T>(), out.data<U>(), b.data_size()); ScalarVector<Op>{}(a_ptr, b_ptr, out_ptr, b.data_size());
return; return;
} }
// The full computation is vector scalar so delegate to the op // The full computation is vector scalar so delegate to the op
if (bopt == BinaryOpType::VectorScalar) { if (bopt == BinaryOpType::VectorScalar) {
VectorScalar{op}(a.data<T>(), b.data<T>(), out.data<U>(), a.data_size()); VectorScalar<Op>{}(a_ptr, b_ptr, out_ptr, a.data_size());
return; return;
} }
// The full computation is vector vector so delegate to the op // The full computation is vector vector so delegate to the op
if (bopt == BinaryOpType::VectorVector) { if (bopt == BinaryOpType::VectorVector) {
VectorVector{op}(a.data<T>(), b.data<T>(), out.data<U>(), out.size()); VectorVector<Op>{}(a_ptr, b_ptr, out_ptr, a.size());
return; return;
} }
// General computation so let's try to optimize // General computation so let's try to optimize
auto [new_shape, new_strides] = collapse_contiguous_dims( auto [new_shape, new_strides] = collapse_contiguous_dims(
a.shape(), {a.strides(), b.strides(), out.strides()}); a.shape(), {a.strides(), b.strides(), out.strides()});
const auto& a_strides = new_strides[0]; auto& a_strides = new_strides[0];
const auto& b_strides = new_strides[1]; auto& b_strides = new_strides[1];
const auto& strides = new_strides[2]; auto& strides = new_strides[2];
// Get the left-most dim such that the array is row contiguous after // Get the left-most dim such that the array is row contiguous after
auto leftmost_rc_dim = [&strides](const auto& arr_strides) { auto leftmost_rc_dim = [&strides](const auto& arr_strides) {
@@ -248,7 +207,8 @@ void binary_op(const array& a, const array& b, array& out, Op op) {
auto ndim = new_shape.size(); auto ndim = new_shape.size();
// Case 1: LxM and FxM where L and F are broadcastable and M is row contiguous // Case 1: LxM and FxM where L and F are broadcastable and M is row
// contiguous
int dim = ndim; int dim = ndim;
if (int d = std::max(a_rc_dim, b_rc_dim); d < ndim) { if (int d = std::max(a_rc_dim, b_rc_dim); d < ndim) {
bopt = BinaryOpType::VectorVector; bopt = BinaryOpType::VectorVector;
@@ -275,99 +235,59 @@ void binary_op(const array& a, const array& b, array& out, Op op) {
switch (bopt) { switch (bopt) {
case BinaryOpType::VectorVector: case BinaryOpType::VectorVector:
binary_op_dispatch_dims<T, U, true>( binary_op_dispatch_dims<T, U, true, VectorVector<Op>>(
a, a_ptr,
b, b_ptr,
out, out_ptr,
VectorVector{op},
dim, dim,
a.size(),
new_shape, new_shape,
a_strides, a_strides,
b_strides, b_strides,
strides); strides);
break; break;
case BinaryOpType::VectorScalar: case BinaryOpType::VectorScalar:
binary_op_dispatch_dims<T, U, true>( binary_op_dispatch_dims<T, U, true, VectorScalar<Op>>(
a, a_ptr,
b, b_ptr,
out, out_ptr,
VectorScalar{op},
dim, dim,
a.size(),
new_shape, new_shape,
a_strides, a_strides,
b_strides, b_strides,
strides); strides);
break; break;
case BinaryOpType::ScalarVector: case BinaryOpType::ScalarVector:
binary_op_dispatch_dims<T, U, true>( binary_op_dispatch_dims<T, U, true, ScalarVector<Op>>(
a, a_ptr,
b, b_ptr,
out, out_ptr,
ScalarVector{op},
dim, dim,
a.size(),
new_shape, new_shape,
a_strides, a_strides,
b_strides, b_strides,
strides); strides);
break; break;
default: default:
binary_op_dispatch_dims<T, U, false>( binary_op_dispatch_dims<T, U, false, Op>(
a, b, out, op, dim, new_shape, a_strides, b_strides, strides); a_ptr,
b_ptr,
out_ptr,
dim,
a.size(),
new_shape,
a_strides,
b_strides,
strides);
break; break;
} }
} }
template <typename T, typename Op> template <typename T, typename Op>
void binary_op(const array& a, const array& b, array& out, Op op) { void binary_op(const array& a, const array& b, array& out, BinaryOpType bopt) {
binary_op<T, T>(a, b, out, op); binary_op<T, T, Op>(a, b, out, bopt);
}
template <typename Op>
void binary(const array& a, const array& b, array& out, Op op) {
switch (out.dtype()) {
case bool_:
binary_op<bool>(a, b, out, op);
break;
case uint8:
binary_op<uint8_t>(a, b, out, op);
break;
case uint16:
binary_op<uint16_t>(a, b, out, op);
break;
case uint32:
binary_op<uint32_t>(a, b, out, op);
break;
case uint64:
binary_op<uint64_t>(a, b, out, op);
break;
case int8:
binary_op<int8_t>(a, b, out, op);
break;
case int16:
binary_op<int16_t>(a, b, out, op);
break;
case int32:
binary_op<int32_t>(a, b, out, op);
break;
case int64:
binary_op<int64_t>(a, b, out, op);
break;
case float16:
binary_op<float16_t>(a, b, out, op);
break;
case float32:
binary_op<float>(a, b, out, op);
break;
case float64:
binary_op<double>(a, b, out, op);
break;
case bfloat16:
binary_op<bfloat16_t>(a, b, out, op);
break;
case complex64:
binary_op<complex64_t>(a, b, out, op);
break;
}
} }
} // namespace mlx::core } // namespace mlx::core
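The key change in this header is that ops become stateless functor types passed purely as template parameters and constructed on the fly with Op{}, so the SIMD kernels carry no captured state into the dispatched tasks. A self-contained illustration of that pattern (toy types, not MLX code):

#include <cstdio>

// A stateless op: default-constructible, so kernels can write Op{}(x, y)
// instead of storing an op instance.
struct Add {
  template <typename T>
  T operator()(T x, T y) const {
    return x + y;
  }
};

// A kernel parameterized only on the op type, in the spirit of VectorVector<Op>.
template <typename T, typename Op>
void elementwise(const T* a, const T* b, T* dst, int size) {
  for (int i = 0; i < size; ++i) {
    dst[i] = Op{}(a[i], b[i]);  // construct the op on the fly; it has no state
  }
}

int main() {
  float a[3] = {1, 2, 3}, b[3] = {4, 5, 6}, out[3];
  elementwise<float, Add>(a, b, out, 3);
  std::printf("%g %g %g\n", out[0], out[1], out[2]);  // 5 7 9
}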


@@ -58,14 +58,14 @@ void binary_op_dispatch_dims(
Op op) { Op op) {
auto [shape, strides] = collapse_contiguous_dims( auto [shape, strides] = collapse_contiguous_dims(
a.shape(), {a.strides(), b.strides(), out_a.strides()}); a.shape(), {a.strides(), b.strides(), out_a.strides()});
const auto& a_strides = strides[0];
const auto& b_strides = strides[1];
const auto& out_strides = strides[2];
const T* a_ptr = a.data<T>(); const T* a_ptr = a.data<T>();
const T* b_ptr = b.data<T>(); const T* b_ptr = b.data<T>();
U* out_a_ptr = out_a.data<U>(); U* out_a_ptr = out_a.data<U>();
U* out_b_ptr = out_b.data<U>(); U* out_b_ptr = out_b.data<U>();
const auto& a_strides = strides[0];
const auto& b_strides = strides[1];
const auto& out_strides = strides[2];
int ndim = shape.size(); int ndim = shape.size();
switch (ndim) { switch (ndim) {
case 1: case 1:
@@ -120,14 +120,10 @@ template <typename T, typename U = T, typename Op>
void binary_op( void binary_op(
const array& a, const array& a,
const array& b, const array& b,
std::vector<array>& outputs, array& out_a,
Op op) { array& out_b,
auto bopt = get_binary_op_type(a, b); Op op,
auto& out_a = outputs[0]; BinaryOpType bopt) {
auto& out_b = outputs[1];
set_binary_op_output_data(a, b, out_a, bopt);
set_binary_op_output_data(a, b, out_b, bopt);
// The full computation is scalar scalar so call the base op once // The full computation is scalar scalar so call the base op once
if (bopt == BinaryOpType::General) { if (bopt == BinaryOpType::General) {
binary_op_dispatch_dims<T, U, Op>(a, b, out_a, out_b, op); binary_op_dispatch_dims<T, U, Op>(a, b, out_a, out_b, op);
@@ -141,14 +137,14 @@ void binary_op(
if (bopt == BinaryOpType::ScalarScalar) { if (bopt == BinaryOpType::ScalarScalar) {
std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr); std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
} else if (bopt == BinaryOpType::ScalarVector) { } else if (bopt == BinaryOpType::ScalarVector) {
for (size_t i = 0; i < b.size(); ++i) { for (size_t i = 0; i < b.data_size(); ++i) {
std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr); std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
out_a_ptr++; out_a_ptr++;
out_b_ptr++; out_b_ptr++;
b_ptr++; b_ptr++;
} }
} else if (bopt == BinaryOpType::VectorScalar) { } else if (bopt == BinaryOpType::VectorScalar) {
for (size_t i = 0; i < a.size(); ++i) { for (size_t i = 0; i < a.data_size(); ++i) {
std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr); std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
out_a_ptr++; out_a_ptr++;
out_b_ptr++; out_b_ptr++;
@@ -165,58 +161,6 @@ void binary_op(
} }
} }
template <typename Op>
void binary(
const array& a,
const array& b,
std::vector<array>& outputs,
Op op) {
switch (outputs[0].dtype()) {
case bool_:
binary_op<bool>(a, b, outputs, op);
break;
case uint8:
binary_op<uint8_t>(a, b, outputs, op);
break;
case uint16:
binary_op<uint16_t>(a, b, outputs, op);
break;
case uint32:
binary_op<uint32_t>(a, b, outputs, op);
break;
case uint64:
binary_op<uint64_t>(a, b, outputs, op);
break;
case int8:
binary_op<int8_t>(a, b, outputs, op);
break;
case int16:
binary_op<int16_t>(a, b, outputs, op);
break;
case int32:
binary_op<int32_t>(a, b, outputs, op);
break;
case int64:
binary_op<int64_t>(a, b, outputs, op);
break;
case float16:
binary_op<float16_t>(a, b, outputs, op);
break;
case float32:
binary_op<float>(a, b, outputs, op);
break;
case float64:
binary_op<double>(a, b, outputs, op);
break;
case bfloat16:
binary_op<bfloat16_t>(a, b, outputs, op);
break;
case complex64:
binary_op<complex64_t>(a, b, outputs, op);
break;
}
}
} // namespace } // namespace
} // namespace mlx::core } // namespace mlx::core
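For the two-output path (DivMod), the op returns a std::pair that binary_op unpacks with std::tie into the two destination buffers; note the scalar-vector and vector-scalar loops now iterate over data_size() rather than size(). A standalone sketch of that unpacking (illustrative only):

#include <cmath>
#include <cstdio>
#include <tuple>
#include <utility>

// A two-output op in the style of DivMod's float_op lambda.
auto divmod = [](double x, double y) {
  return std::make_pair(std::trunc(x / y), std::fmod(x, y));
};

int main() {
  double a[3] = {7.0, 9.0, 10.0}, b[3] = {2.0, 4.0, 3.0};
  double quot[3], rem[3];
  for (int i = 0; i < 3; ++i) {
    // binary_op does the same unpacking with std::tie(*out_a_ptr, *out_b_ptr).
    std::tie(quot[i], rem[i]) = divmod(a[i], b[i]);
  }
  std::printf("%g %g\n", quot[0], rem[0]);  // 3 1
}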


@@ -2,6 +2,7 @@
#include "mlx/allocator.h" #include "mlx/allocator.h"
#include "mlx/backend/cpu/copy.h" #include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/lapack.h" #include "mlx/backend/cpu/lapack.h"
#include "mlx/linalg.h" #include "mlx/linalg.h"
#include "mlx/primitives.h" #include "mlx/primitives.h"
@@ -9,7 +10,7 @@
namespace mlx::core { namespace mlx::core {
template <typename T> template <typename T>
void cholesky_impl(const array& a, array& factor, bool upper) { void cholesky_impl(const array& a, array& factor, bool upper, Stream stream) {
// Lapack uses the column-major convention. We take advantage of the fact that // Lapack uses the column-major convention. We take advantage of the fact that
// the matrix should be symmetric: // the matrix should be symmetric:
// (A)ᵀ = A // (A)ᵀ = A
@@ -17,60 +18,63 @@ void cholesky_impl(const array& a, array& factor, bool upper) {
// triangular matrix, so uplo is the opposite of what we would expect from // triangular matrix, so uplo is the opposite of what we would expect from
// upper // upper
char uplo = (upper) ? 'L' : 'U';
// The decomposition is computed in place, so just copy the input to the // The decomposition is computed in place, so just copy the input to the
// output. // output.
copy( copy(
a, a,
factor, factor,
a.flags().row_contiguous ? CopyType::Vector : CopyType::General); a.flags().row_contiguous ? CopyType::Vector : CopyType::General,
stream);
const int N = a.shape(-1); auto& encoder = cpu::get_command_encoder(stream);
const size_t num_matrices = a.size() / (N * N); encoder.set_output_array(factor);
encoder.dispatch([matrix = factor.data<T>(),
upper,
N = a.shape(-1),
size = a.size()]() mutable {
char uplo = (upper) ? 'L' : 'U';
size_t num_matrices = size / (N * N);
for (int i = 0; i < num_matrices; i++) {
// Compute Cholesky factorization.
int info;
potrf<T>(
/* uplo = */ &uplo,
/* n = */ &N,
/* a = */ matrix,
/* lda = */ &N,
/* info = */ &info);
T* matrix = factor.data<T>(); // TODO: We do nothing when the matrix is not positive semi-definite
// because throwing an error would result in a crash. If we figure out how
for (int i = 0; i < num_matrices; i++) { // to catch errors from the implementation we should throw.
// Compute Cholesky factorization. if (info < 0) {
int info; std::stringstream msg;
potrf<T>( msg << "[Cholesky::eval_cpu] Cholesky decomposition failed with error code "
/* uplo = */ &uplo, << info;
/* n = */ &N, throw std::runtime_error(msg.str());
/* a = */ matrix, }
/* lda = */ &N,
/* info = */ &info); // Zero out the upper/lower triangle while advancing the pointer to the
// next matrix at the same time.
// TODO: We do nothing when the matrix is not positive semi-definite for (int row = 0; row < N; row++) {
// because throwing an error would result in a crash. If we figure out how if (upper) {
// to catch errors from the implementation we should throw. std::fill(matrix, matrix + row, 0);
if (info < 0) { } else {
std::stringstream msg; std::fill(matrix + row + 1, matrix + N, 0);
msg << "[cholesky] Cholesky decomposition failed with error code " }
<< info; matrix += N;
throw std::runtime_error(msg.str());
}
// Zero out the upper/lower triangle while advancing the pointer to the
// next matrix at the same time.
for (int row = 0; row < N; row++) {
if (upper) {
std::fill(matrix, matrix + row, 0);
} else {
std::fill(matrix + row + 1, matrix + N, 0);
} }
matrix += N;
} }
} });
} }
void Cholesky::eval_cpu(const std::vector<array>& inputs, array& output) { void Cholesky::eval_cpu(const std::vector<array>& inputs, array& output) {
switch (inputs[0].dtype()) { switch (inputs[0].dtype()) {
case float32: case float32:
cholesky_impl<float>(inputs[0], output, upper_); cholesky_impl<float>(inputs[0], output, upper_, stream());
break; break;
case float64: case float64:
cholesky_impl<double>(inputs[0], output, upper_); cholesky_impl<double>(inputs[0], output, upper_, stream());
break; break;
default: default:
throw std::runtime_error( throw std::runtime_error(


@@ -11,6 +11,7 @@
#include "mlx/backend/common/compiled.h" #include "mlx/backend/common/compiled.h"
#include "mlx/backend/cpu/compiled_preamble.h" #include "mlx/backend/cpu/compiled_preamble.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/jit_compiler.h" #include "mlx/backend/cpu/jit_compiler.h"
#include "mlx/device.h" #include "mlx/device.h"
#include "mlx/graph_utils.h" #include "mlx/graph_utils.h"
@@ -288,6 +289,7 @@ void Compiled::eval_cpu(
// Figure out which kernel we are using // Figure out which kernel we are using
auto& shape = outputs[0].shape(); auto& shape = outputs[0].shape();
auto contiguous = compiled_check_contiguity(inputs, shape); auto contiguous = compiled_check_contiguity(inputs, shape);
auto& encoder = cpu::get_command_encoder(stream());
// Handle all broadcasting and collect function input arguments // Handle all broadcasting and collect function input arguments
std::vector<void*> args; std::vector<void*> args;
@@ -298,6 +300,7 @@ void Compiled::eval_cpu(
continue; continue;
} }
auto& x = inputs[i]; auto& x = inputs[i];
encoder.set_input_array(x);
args.push_back((void*)x.data<void>()); args.push_back((void*)x.data<void>());
if (contiguous || is_scalar(x)) { if (contiguous || is_scalar(x)) {
@@ -356,18 +359,25 @@ void Compiled::eval_cpu(
}); });
compiled_allocate_outputs( compiled_allocate_outputs(
inputs, outputs, inputs_, constant_ids_, contiguous, false); inputs, outputs, inputs_, constant_ids_, contiguous);
for (auto& x : outputs) { for (auto& x : outputs) {
args.push_back(x.data<void>()); args.push_back(x.data<void>());
encoder.set_output_array(x);
} }
Shape out_shape;
if (!contiguous) { if (!contiguous) {
args.push_back((void*)outputs[0].shape().data()); out_shape = outputs[0].shape();
args.push_back((void*)out_shape.data());
} else { } else {
args.push_back((void*)outputs[0].data_size()); args.push_back((void*)outputs[0].data_size());
} }
auto fun = (void (*)(void**))fn_ptr; auto fun = (void (*)(void**))fn_ptr;
fun(args.data()); encoder.dispatch(
[fun,
args = std::move(args),
strides = std::move(strides),
out_shape = std::move(out_shape)]() mutable { fun(args.data()); });
} }
} // namespace mlx::core } // namespace mlx::core
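The reason strides and out_shape are moved into the dispatched closure together with args is lifetime: args holds raw pointers into those containers, and the JIT-compiled function may run long after eval_cpu returns. A self-contained sketch of the same keep-alive-by-move pattern (toy code, not MLX):

#include <cstdio>
#include <functional>
#include <vector>

// The queued task owns every container that the raw pointers in `args`
// point into, so running it later is safe.
std::function<void()> make_task() {
  std::vector<int> shape = {2, 3, 4};
  std::vector<void*> args;
  args.push_back(shape.data());  // raw pointer into `shape`

  // Moving a vector transfers its buffer, so args[0] stays valid because the
  // closure now owns the storage it points into.
  return [args = std::move(args), shape = std::move(shape)]() {
    auto* s = static_cast<int*>(args[0]);
    std::printf("%d %d %d\n", s[0], s[1], s[2]);
  };
}

int main() {
  auto task = make_task();
  task();  // prints 2 3 4 even though make_task() has returned
}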

File diff suppressed because it is too large.


@@ -5,6 +5,7 @@
#include "mlx/allocator.h" #include "mlx/allocator.h"
#include "mlx/backend/common/utils.h" #include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/copy.h" #include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/simd/simd.h" #include "mlx/backend/cpu/simd/simd.h"
namespace mlx::core { namespace mlx::core {
@@ -13,19 +14,19 @@ namespace {
template <typename SrcT, typename DstT> template <typename SrcT, typename DstT>
void copy_single(const array& src, array& dst) { void copy_single(const array& src, array& dst) {
auto val = static_cast<DstT>(src.data<SrcT>()[0]); auto src_ptr = src.data<SrcT>();
auto dst_ptr = dst.data<DstT>(); auto dst_ptr = dst.data<DstT>();
for (int i = 0; i < dst.size(); ++i) { auto size = dst.size();
dst_ptr[i] = val; auto val = static_cast<DstT>(src_ptr[0]);
} std::fill_n(dst_ptr, size, val);
} }
template <typename SrcT, typename DstT> template <typename SrcT, typename DstT>
void copy_vector(const array& src, array& dst) { void copy_vector(const array& src, array& dst) {
auto src_ptr = src.data<SrcT>(); auto src_ptr = src.data<SrcT>();
auto dst_ptr = dst.data<DstT>(); auto dst_ptr = dst.data<DstT>();
size_t size = src.data_size(); auto size = src.data_size();
std::copy(src_ptr, src_ptr + src.data_size(), dst_ptr); std::copy(src_ptr, src_ptr + size, dst_ptr);
} }
template <typename SrcT, typename DstT, int D> template <typename SrcT, typename DstT, int D>
@@ -60,36 +61,57 @@ void copy_general_general(
const Strides& i_strides, const Strides& i_strides,
const Strides& o_strides, const Strides& o_strides,
int64_t i_offset, int64_t i_offset,
int64_t o_offset) { int64_t o_offset,
const std::optional<array>& dynamic_i_offset,
const std::optional<array>& dynamic_o_offset) {
auto src_ptr = src.data<SrcT>() + i_offset;
auto dst_ptr = dst.data<DstT>() + o_offset;
auto i_offset_ptr =
dynamic_i_offset ? dynamic_i_offset->data<int64_t>() : nullptr;
auto o_offset_ptr =
dynamic_o_offset ? dynamic_o_offset->data<int64_t>() : nullptr;
auto size = src.size();
if (data_shape.empty()) { if (data_shape.empty()) {
auto val = static_cast<DstT>(*(src.data<SrcT>() + i_offset)); auto val = static_cast<DstT>(*src_ptr);
auto dst_ptr = dst.data<DstT>() + o_offset;
*dst_ptr = val; *dst_ptr = val;
return; return;
} }
auto [shape, strides] = auto [shape, strides] =
collapse_contiguous_dims(data_shape, {i_strides, o_strides}); collapse_contiguous_dims(data_shape, {i_strides, o_strides});
auto src_ptr = src.data<SrcT>() + i_offset;
auto dst_ptr = dst.data<DstT>() + o_offset;
int ndim = shape.size(); int ndim = shape.size();
if (ndim == 1) { if (ndim < 3) {
copy_dims<SrcT, DstT, 1>( if (i_offset_ptr) {
src_ptr, dst_ptr, shape, strides[0], strides[1], 0); src_ptr += i_offset_ptr[0];
return; }
} else if (ndim == 2) { if (o_offset_ptr) {
copy_dims<SrcT, DstT, 2>( dst_ptr += o_offset_ptr[0];
src_ptr, dst_ptr, shape, strides[0], strides[1], 0); }
return;
} else if (ndim == 3) { if (ndim == 1) {
copy_dims<SrcT, DstT, 3>( copy_dims<SrcT, DstT, 1>(
src_ptr, dst_ptr, shape, strides[0], strides[1], 0); src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
} else if (ndim == 2) {
copy_dims<SrcT, DstT, 2>(
src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
} else if (ndim == 3) {
copy_dims<SrcT, DstT, 3>(
src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
}
return; return;
} }
if (i_offset_ptr) {
src_ptr += i_offset_ptr[0];
}
if (o_offset_ptr) {
dst_ptr += o_offset_ptr[0];
}
ContiguousIterator in(shape, strides[0], ndim - 3); ContiguousIterator in(shape, strides[0], ndim - 3);
ContiguousIterator out(shape, strides[1], ndim - 3); ContiguousIterator out(shape, strides[1], ndim - 3);
auto stride = std::accumulate( auto stride = std::accumulate(
shape.end() - 3, shape.end(), 1, std::multiplies<int64_t>()); shape.end() - 3, shape.end(), 1, std::multiplies<int64_t>());
for (int64_t elem = 0; elem < src.size(); elem += stride) { for (int64_t elem = 0; elem < size; elem += stride) {
copy_dims<SrcT, DstT, 3>( copy_dims<SrcT, DstT, 3>(
src_ptr + in.loc, src_ptr + in.loc,
dst_ptr + out.loc, dst_ptr + out.loc,
@@ -105,7 +127,15 @@ void copy_general_general(
template <typename SrcT, typename DstT> template <typename SrcT, typename DstT>
inline void copy_general_general(const array& src, array& dst) { inline void copy_general_general(const array& src, array& dst) {
copy_general_general<SrcT, DstT>( copy_general_general<SrcT, DstT>(
src, dst, src.shape(), src.strides(), dst.strides(), 0, 0); src,
dst,
src.shape(),
src.strides(),
dst.strides(),
0,
0,
std::nullopt,
std::nullopt);
} }
template <typename SrcT, typename DstT> template <typename SrcT, typename DstT>
@@ -116,7 +146,9 @@ void copy_general(
const Strides& i_strides, const Strides& i_strides,
const Strides&, const Strides&,
int64_t i_offset, int64_t i_offset,
int64_t o_offset) { int64_t o_offset,
const std::optional<array>& dynamic_i_offset,
const std::optional<array>& dynamic_o_offset) {
copy_general_general<SrcT, DstT>( copy_general_general<SrcT, DstT>(
src, src,
dst, dst,
@@ -124,7 +156,9 @@ void copy_general(
i_strides, i_strides,
make_contiguous_strides(data_shape), make_contiguous_strides(data_shape),
i_offset, i_offset,
o_offset); o_offset,
dynamic_i_offset,
dynamic_o_offset);
} }
template <typename SrcT, typename DstT> template <typename SrcT, typename DstT>
@@ -136,7 +170,9 @@ inline void copy_general(const array& src, array& dst) {
src.strides(), src.strides(),
make_contiguous_strides(src.shape()), make_contiguous_strides(src.shape()),
0, 0,
0); 0,
std::nullopt,
std::nullopt);
} }
template <typename SrcT, typename DstT, typename... Args> template <typename SrcT, typename DstT, typename... Args>
@@ -259,35 +295,27 @@ inline void copy_inplace_dispatch(
} // namespace } // namespace
void copy_inplace(const array& src, array& dst, CopyType ctype) { void copy_inplace(const array& src, array& dst, CopyType ctype, Stream stream) {
copy_inplace_dispatch(src, dst, ctype); auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(src);
encoder.set_output_array(dst);
encoder.dispatch(
[src = array::unsafe_weak_copy(src),
dst = array::unsafe_weak_copy(dst),
ctype]() mutable { copy_inplace_dispatch(src, dst, ctype); });
} }
void copy(const array& src, array& dst, CopyType ctype) { void copy(const array& src, array& dst, CopyType ctype, Stream stream) {
// Allocate the output bool donated = set_copy_output_data(src, dst, ctype);
switch (ctype) { if (donated && src.dtype() == dst.dtype()) {
case CopyType::Vector: // If the output has the same type as the input then there is nothing to
if (src.is_donatable() && src.itemsize() == dst.itemsize()) { // copy, just use the buffer.
dst.copy_shared_buffer(src); return;
} else {
auto size = src.data_size();
dst.set_data(
allocator::malloc_or_wait(size * dst.itemsize()),
size,
src.strides(),
src.flags());
}
break;
case CopyType::Scalar:
case CopyType::General:
case CopyType::GeneralGeneral:
dst.set_data(allocator::malloc_or_wait(dst.nbytes()));
break;
} }
if (ctype == CopyType::GeneralGeneral) { if (ctype == CopyType::GeneralGeneral) {
ctype = CopyType::General; ctype = CopyType::General;
} }
copy_inplace(src, dst, ctype); copy_inplace(src, dst, ctype, stream);
} }
void copy_inplace( void copy_inplace(
@@ -298,24 +326,51 @@ void copy_inplace(
const Strides& o_strides, const Strides& o_strides,
int64_t i_offset, int64_t i_offset,
int64_t o_offset, int64_t o_offset,
CopyType ctype) { CopyType ctype,
switch (ctype) { Stream stream,
case CopyType::General: const std::optional<array>& dynamic_i_offset, /* = std::nullopt */
case CopyType::GeneralGeneral: const std::optional<array>& dynamic_o_offset /* = std::nullopt */) {
copy_inplace_dispatch( auto& encoder = cpu::get_command_encoder(stream);
src, encoder.set_input_array(src);
dst, encoder.set_output_array(dst);
ctype, auto weak_copy_if_set = [](auto x) -> std::optional<array> {
data_shape, if (x) {
i_strides, return array::unsafe_weak_copy(*x);
o_strides, } else {
i_offset, return std::nullopt;
o_offset); }
break; };
case CopyType::Scalar: encoder.dispatch(
case CopyType::Vector: [src = array::unsafe_weak_copy(src),
copy_inplace_dispatch(src, dst, ctype); dst = array::unsafe_weak_copy(dst),
} data_shape,
i_strides,
o_strides,
i_offset,
o_offset,
ctype,
dynamic_i_offset = weak_copy_if_set(dynamic_i_offset),
dynamic_o_offset = weak_copy_if_set(dynamic_o_offset)]() mutable {
switch (ctype) {
case CopyType::General:
case CopyType::GeneralGeneral:
copy_inplace_dispatch(
src,
dst,
ctype,
data_shape,
i_strides,
o_strides,
i_offset,
o_offset,
dynamic_i_offset,
dynamic_o_offset);
break;
case CopyType::Scalar:
case CopyType::Vector:
copy_inplace_dispatch(src, dst, ctype);
}
});
} }
} // namespace mlx::core } // namespace mlx::core


@@ -2,14 +2,16 @@
#pragma once #pragma once
#include <optional>
#include "mlx/array.h" #include "mlx/array.h"
#include "mlx/backend/common/copy.h" #include "mlx/backend/common/copy.h"
#include "mlx/backend/common/utils.h" #include "mlx/backend/common/utils.h"
namespace mlx::core { namespace mlx::core {
void copy(const array& src, array& dst, CopyType ctype); void copy(const array& src, array& dst, CopyType ctype, Stream stream);
void copy_inplace(const array& src, array& dst, CopyType ctype); void copy_inplace(const array& src, array& dst, CopyType ctype, Stream stream);
void copy_inplace( void copy_inplace(
const array& src, const array& src,
@@ -19,6 +21,9 @@ void copy_inplace(
const Strides& o_strides, const Strides& o_strides,
int64_t i_offset, int64_t i_offset,
int64_t o_offset, int64_t o_offset,
CopyType ctype); CopyType ctype,
Stream stream,
const std::optional<array>& dynamic_i_offset = std::nullopt,
const std::optional<array>& dynamic_o_offset = std::nullopt);
} // namespace mlx::core } // namespace mlx::core
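A hypothetical call site for the new signatures, e.g. making a scratch copy inside some primitive's eval_cpu (the surrounding primitive and the variable names are assumptions; the calls themselves are the ones declared above):

  array tmp(in.shape(), in.dtype(), nullptr, {});
  copy(in, tmp,
       in.flags().row_contiguous ? CopyType::Vector : CopyType::General,
       stream());
  // The copy is only queued here; if tmp is scratch, keep it alive until the
  // queued work has actually run.
  cpu::get_command_encoder(stream()).add_temporary(tmp);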


@@ -0,0 +1,101 @@
// Copyright © 2024 Apple Inc.
#include <cassert>
#include "mlx/allocator.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/distributed/primitives.h"
namespace mlx::core::distributed {
std::pair<array, bool> ensure_row_contiguous(const array& arr, Stream stream) {
if (arr.flags().row_contiguous) {
return {arr, false};
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::General, stream);
return {arr_copy, true};
}
};
void AllReduce::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() == 1);
assert(outputs.size() == 1);
auto donate_or_copy = [s = stream()](const array& in, array& out) {
if (in.flags().row_contiguous) {
if (in.is_donatable()) {
out.copy_shared_buffer(in);
} else {
out.set_data(allocator::malloc(out.nbytes()));
}
return in;
} else {
array arr_copy(in.shape(), in.dtype(), nullptr, {});
copy(in, arr_copy, CopyType::General, s);
out.copy_shared_buffer(arr_copy);
return arr_copy;
}
};
auto in = donate_or_copy(inputs[0], outputs[0]);
switch (reduce_type_) {
case Sum:
distributed::detail::all_sum(group(), in, outputs[0], stream());
break;
case Max:
distributed::detail::all_max(group(), in, outputs[0], stream());
break;
case Min:
distributed::detail::all_min(group(), in, outputs[0], stream());
break;
default:
throw std::runtime_error(
"Only all reduce sum, min and max are supported for now");
}
}
void AllGather::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() == 1);
assert(outputs.size() == 1);
auto [in, copied] = ensure_row_contiguous(inputs[0], stream());
outputs[0].set_data(allocator::malloc(outputs[0].nbytes()));
distributed::detail::all_gather(group(), in, outputs[0], stream());
if (copied) {
auto& enc = cpu::get_command_encoder(stream());
enc.add_temporary(in);
}
}
void Send::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() == 1);
assert(outputs.size() == 1);
auto [in, copied] = ensure_row_contiguous(inputs[0], stream());
distributed::detail::send(group(), in, dst_, stream());
outputs[0].copy_shared_buffer(inputs[0]);
if (copied) {
auto& enc = cpu::get_command_encoder(stream());
enc.add_temporary(in);
}
}
void Recv::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() == 0);
assert(outputs.size() == 1);
outputs[0].set_data(allocator::malloc(outputs[0].nbytes()));
distributed::detail::recv(group(), outputs[0], src_, stream());
}
} // namespace mlx::core::distributed


@@ -3,6 +3,7 @@
#include "mlx/allocator.h" #include "mlx/allocator.h"
#include "mlx/array.h" #include "mlx/array.h"
#include "mlx/backend/cpu/copy.h" #include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/lapack.h" #include "mlx/backend/cpu/lapack.h"
#include "mlx/linalg.h" #include "mlx/linalg.h"
#include "mlx/primitives.h" #include "mlx/primitives.h"
@@ -16,59 +17,71 @@ void eigh_impl(
array& vectors, array& vectors,
array& values, array& values,
const std::string& uplo, const std::string& uplo,
bool compute_eigenvectors) { bool compute_eigenvectors,
Stream stream) {
auto vec_ptr = vectors.data<T>(); auto vec_ptr = vectors.data<T>();
auto eig_ptr = values.data<T>(); auto eig_ptr = values.data<T>();
char jobz = compute_eigenvectors ? 'V' : 'N'; char jobz = compute_eigenvectors ? 'V' : 'N';
auto N = vectors.shape(-1);
// Work query auto& encoder = cpu::get_command_encoder(stream);
int lwork = -1; encoder.set_output_array(vectors);
int liwork = -1; encoder.set_output_array(values);
int info; encoder.dispatch([vec_ptr,
{ eig_ptr,
T work; jobz,
int iwork; uplo = uplo[0],
syevd<T>( N = vectors.shape(-1),
&jobz, size = vectors.size()]() mutable {
uplo.c_str(), // Work query
&N, int lwork = -1;
nullptr, int liwork = -1;
&N, int info;
nullptr, {
&work, T work;
&lwork, int iwork;
&iwork, syevd<T>(
&liwork, &jobz,
&info); &uplo,
lwork = static_cast<int>(work); &N,
liwork = iwork; nullptr,
} &N,
nullptr,
auto work_buf = array::Data{allocator::malloc_or_wait(sizeof(T) * lwork)}; &work,
auto iwork_buf = array::Data{allocator::malloc_or_wait(sizeof(int) * liwork)}; &lwork,
for (size_t i = 0; i < vectors.size() / (N * N); ++i) { &iwork,
syevd<T>( &liwork,
&jobz, &info);
uplo.c_str(), lwork = static_cast<int>(work);
&N, liwork = iwork;
vec_ptr,
&N,
eig_ptr,
static_cast<T*>(work_buf.buffer.raw_ptr()),
&lwork,
static_cast<int*>(iwork_buf.buffer.raw_ptr()),
&liwork,
&info);
vec_ptr += N * N;
eig_ptr += N;
if (info != 0) {
std::stringstream msg;
msg << "[Eigh::eval_cpu] Eigenvalue decomposition failed with error code "
<< info;
throw std::runtime_error(msg.str());
} }
auto work_buf = array::Data{allocator::malloc(sizeof(T) * lwork)};
auto iwork_buf = array::Data{allocator::malloc(sizeof(int) * liwork)};
for (size_t i = 0; i < size / (N * N); ++i) {
syevd<T>(
&jobz,
&uplo,
&N,
vec_ptr,
&N,
eig_ptr,
static_cast<T*>(work_buf.buffer.raw_ptr()),
&lwork,
static_cast<int*>(iwork_buf.buffer.raw_ptr()),
&liwork,
&info);
vec_ptr += N * N;
eig_ptr += N;
if (info != 0) {
std::stringstream msg;
msg << "[Eigh::eval_cpu] Eigenvalue decomposition failed with error code "
<< info;
throw std::runtime_error(msg.str());
}
}
});
if (!compute_eigenvectors) {
encoder.add_temporary(vectors);
} }
} }
@@ -84,12 +97,13 @@ void Eigh::eval_cpu(
? outputs[1] ? outputs[1]
: array(a.shape(), a.dtype(), nullptr, {}); : array(a.shape(), a.dtype(), nullptr, {});
values.set_data(allocator::malloc_or_wait(values.nbytes())); values.set_data(allocator::malloc(values.nbytes()));
copy( copy(
a, a,
vectors, vectors,
a.flags().row_contiguous ? CopyType::Vector : CopyType::General); a.flags().row_contiguous ? CopyType::Vector : CopyType::General,
stream());
if (compute_eigenvectors_) { if (compute_eigenvectors_) {
// Set the strides and flags so the eigenvectors // Set the strides and flags so the eigenvectors
@@ -107,14 +121,15 @@ void Eigh::eval_cpu(
flags.col_contiguous = true; flags.col_contiguous = true;
} }
} }
vectors.move_shared_buffer(vectors, strides, flags, vectors.data_size()); vectors.copy_shared_buffer(vectors, strides, flags, vectors.data_size());
} }
switch (a.dtype()) { switch (a.dtype()) {
case float32: case float32:
eigh_impl<float>(vectors, values, uplo_, compute_eigenvectors_); eigh_impl<float>(vectors, values, uplo_, compute_eigenvectors_, stream());
break; break;
case float64: case float64:
eigh_impl<double>(vectors, values, uplo_, compute_eigenvectors_); eigh_impl<double>(
vectors, values, uplo_, compute_eigenvectors_, stream());
break; break;
default: default:
throw std::runtime_error( throw std::runtime_error(


@@ -0,0 +1,16 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cpu/encoder.h"
namespace mlx::core::cpu {
CommandEncoder& get_command_encoder(Stream stream) {
static std::unordered_map<int, CommandEncoder> encoder_map;
auto it = encoder_map.find(stream.index);
if (it == encoder_map.end()) {
it = encoder_map.emplace(stream.index, stream).first;
}
return it->second;
}
} // namespace mlx::core::cpu
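get_command_encoder lazily builds one non-copyable encoder per stream index and hands back a reference. A self-contained analogue of that emplace-on-miss registry (toy types, not MLX code):

#include <cstdio>
#include <unordered_map>

// One non-copyable encoder per key, created lazily on first lookup.
struct Encoder {
  explicit Encoder(int stream) : stream(stream) {}
  Encoder(const Encoder&) = delete;
  Encoder& operator=(const Encoder&) = delete;
  int stream;
};

Encoder& get_encoder(int stream_index) {
  static std::unordered_map<int, Encoder> encoders;
  auto it = encoders.find(stream_index);
  if (it == encoders.end()) {
    // emplace constructs the Encoder in place, so no copy or move is needed.
    it = encoders.emplace(stream_index, stream_index).first;
  }
  return it->second;
}

int main() {
  std::printf("%d\n", get_encoder(3).stream);  // 3
}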

mlx/backend/cpu/encoder.h Normal file

@@ -0,0 +1,67 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include <unordered_map>
#include "mlx/array.h"
#include "mlx/scheduler.h"
namespace mlx::core::cpu {
// Number of dispatches per scheduler task
constexpr int DISPATCHES_PER_TASK = 10;
struct CommandEncoder {
CommandEncoder(Stream stream) : stream_(stream) {}
CommandEncoder(const CommandEncoder&) = delete;
CommandEncoder& operator=(const CommandEncoder&) = delete;
CommandEncoder(CommandEncoder&&) = delete;
CommandEncoder& operator=(CommandEncoder&&) = delete;
void set_input_array(const array& a) {}
void set_output_array(array& a) {}
// Hold onto a temporary until any already scheduled tasks which use it as
// an input are complete.
void add_temporary(array arr) {
temporaries_.push_back(std::move(arr));
}
void add_temporaries(std::vector<array> arrays) {
temporaries_.insert(
temporaries_.end(),
std::make_move_iterator(arrays.begin()),
std::make_move_iterator(arrays.end()));
}
std::vector<array>& temporaries() {
return temporaries_;
}
template <class F, class... Args>
void dispatch(F&& f, Args&&... args) {
num_ops_ = (num_ops_ + 1) % DISPATCHES_PER_TASK;
auto task = std::bind(std::forward<F>(f), std::forward<Args>(args)...);
if (num_ops_ == 0) {
scheduler::notify_new_task(stream_);
auto task_wrap = [s = stream_, task = std::move(task)]() mutable {
task();
scheduler::notify_task_completion(s);
};
scheduler::enqueue(stream_, std::move(task_wrap));
} else {
scheduler::enqueue(stream_, std::move(task));
}
}
private:
Stream stream_;
std::vector<array> temporaries_;
int num_ops_{0};
};
CommandEncoder& get_command_encoder(Stream stream);
} // namespace mlx::core::cpu
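Every dispatch enqueues a closure on the stream's worker, but only every DISPATCHES_PER_TASK-th one pays for the scheduler's new-task and completion notifications. A toy model of that counting scheme (the real notifications go through scheduler::notify_new_task and scheduler::notify_task_completion; this is not MLX code):

#include <cstdio>
#include <functional>
#include <queue>

constexpr int DISPATCHES_PER_TASK = 10;

struct ToyEncoder {
  std::queue<std::function<void()>> q;
  int num_ops{0};
  int notifications{0};

  template <class F>
  void dispatch(F&& f) {
    num_ops = (num_ops + 1) % DISPATCHES_PER_TASK;
    if (num_ops == 0) {
      ++notifications;  // stands in for the scheduler notifications
      q.push([this, f = std::forward<F>(f)]() mutable { f(); });
    } else {
      q.push(std::forward<F>(f));
    }
  }
};

int main() {
  ToyEncoder enc;
  int sum = 0;
  for (int i = 0; i < 25; ++i) {
    enc.dispatch([&sum, i]() { sum += i; });
  }
  while (!enc.q.empty()) {
    enc.q.front()();
    enc.q.pop();
  }
  std::printf("sum=%d notifications=%d\n", sum, enc.notifications);  // 300, 2
}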

mlx/backend/cpu/eval.cpp Normal file

@@ -0,0 +1,40 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cpu/eval.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/primitives.h"
#include "mlx/scheduler.h"
#include "mlx/utils.h"
namespace mlx::core::cpu {
void eval(array& arr) {
auto s = arr.primitive().stream();
auto outputs = arr.outputs();
{
// If the array is a tracer hold a reference
// to its inputs so they don't get donated
std::vector<array> inputs;
if (arr.is_tracer()) {
inputs = arr.inputs();
}
arr.primitive().eval_cpu(arr.inputs(), outputs);
}
std::unordered_set<std::shared_ptr<array::Data>> buffers;
for (auto& in : arr.inputs()) {
buffers.insert(in.data_shared_ptr());
}
for (auto& s : arr.siblings()) {
buffers.insert(s.data_shared_ptr());
}
// Remove the output if it was donated to by an input
if (auto it = buffers.find(arr.data_shared_ptr()); it != buffers.end()) {
buffers.erase(it);
}
auto& encoder = cpu::get_command_encoder(s);
encoder.dispatch([buffers = std::move(buffers),
temps = std::move(encoder.temporaries())]() {});
}
} // namespace mlx::core::cpu
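The final dispatch in eval is an empty closure whose only job is to own the input and temporary buffers, so nothing referenced by earlier queued tasks can be freed before the queue drains past this point. A minimal standalone sketch of that keep-alive trick (toy code, not MLX):

#include <cstdio>
#include <functional>
#include <memory>
#include <queue>
#include <vector>

int main() {
  std::queue<std::function<void()>> stream;

  auto data = std::make_shared<std::vector<float>>(3, 1.0f);
  float* raw = data->data();

  // Earlier task works through a raw pointer only.
  stream.push([raw]() { raw[0] = 42.0f; });

  // Keep-alive task: captures the shared_ptr and does nothing else.
  stream.push(
      [buffers = std::vector<std::shared_ptr<std::vector<float>>>{data}]() {});

  data.reset();  // caller drops its reference; the queued task still owns one

  while (!stream.empty()) {
    stream.front()();
    stream.pop();
  }
  std::printf("done\n");  // raw stayed valid while the first task ran
}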

mlx/backend/cpu/eval.h Normal file

@@ -0,0 +1,12 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include "mlx/array.h"
#include "mlx/stream.h"
namespace mlx::core::cpu {
void eval(array& arr);
} // namespace mlx::core::cpu


@@ -4,6 +4,7 @@
#include "mlx/3rdparty/pocketfft.h" #include "mlx/3rdparty/pocketfft.h"
#include "mlx/allocator.h" #include "mlx/allocator.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/primitives.h" #include "mlx/primitives.h"
namespace mlx::core { namespace mlx::core {
@@ -21,7 +22,7 @@ void FFT::eval_cpu(const std::vector<array>& inputs, array& out) {
s *= out.itemsize(); s *= out.itemsize();
} }
out.set_data(allocator::malloc_or_wait(out.nbytes())); out.set_data(allocator::malloc(out.nbytes()));
std::vector<size_t> shape; std::vector<size_t> shape;
if (out.dtype() == float32) { if (out.dtype() == float32) {
@@ -38,46 +39,78 @@ void FFT::eval_cpu(const std::vector<array>& inputs, array& out) {
}); });
scale /= nelem; scale /= nelem;
} }
auto& encoder = cpu::get_command_encoder(stream());
encoder.set_input_array(in);
encoder.set_output_array(out);
if (in.dtype() == complex64 && out.dtype() == complex64) { if (in.dtype() == complex64 && out.dtype() == complex64) {
auto in_ptr = auto in_ptr =
reinterpret_cast<const std::complex<float>*>(in.data<complex64_t>()); reinterpret_cast<const std::complex<float>*>(in.data<complex64_t>());
auto out_ptr = auto out_ptr =
reinterpret_cast<std::complex<float>*>(out.data<complex64_t>()); reinterpret_cast<std::complex<float>*>(out.data<complex64_t>());
pocketfft::c2c( encoder.dispatch([shape = std::move(shape),
shape, strides_in = std::move(strides_in),
strides_in, strides_out = std::move(strides_out),
strides_out, axes = axes_,
axes_, inverse = inverse_,
!inverse_, in_ptr,
in_ptr, out_ptr,
out_ptr, scale]() {
scale); pocketfft::c2c(
shape,
strides_in,
strides_out,
axes,
!inverse,
in_ptr,
out_ptr,
scale);
});
} else if (in.dtype() == float32 && out.dtype() == complex64) { } else if (in.dtype() == float32 && out.dtype() == complex64) {
auto in_ptr = in.data<float>(); auto in_ptr = in.data<float>();
auto out_ptr = auto out_ptr =
reinterpret_cast<std::complex<float>*>(out.data<complex64_t>()); reinterpret_cast<std::complex<float>*>(out.data<complex64_t>());
pocketfft::r2c( encoder.dispatch([shape = std::move(shape),
shape, strides_in = std::move(strides_in),
strides_in, strides_out = std::move(strides_out),
strides_out, axes = axes_,
axes_, inverse = inverse_,
!inverse_, in_ptr,
in_ptr, out_ptr,
out_ptr, scale]() {
scale); pocketfft::r2c(
shape,
strides_in,
strides_out,
axes,
!inverse,
in_ptr,
out_ptr,
scale);
});
} else if (in.dtype() == complex64 && out.dtype() == float32) { } else if (in.dtype() == complex64 && out.dtype() == float32) {
auto in_ptr = auto in_ptr =
reinterpret_cast<const std::complex<float>*>(in.data<complex64_t>()); reinterpret_cast<const std::complex<float>*>(in.data<complex64_t>());
auto out_ptr = out.data<float>(); auto out_ptr = out.data<float>();
pocketfft::c2r( encoder.dispatch([shape = std::move(shape),
shape, strides_in = std::move(strides_in),
strides_in, strides_out = std::move(strides_out),
strides_out, axes = axes_,
axes_, inverse = inverse_,
!inverse_, in_ptr,
in_ptr, out_ptr,
out_ptr, scale]() {
scale); pocketfft::c2r(
shape,
strides_in,
strides_out,
axes,
!inverse,
in_ptr,
out_ptr,
scale);
});
} else { } else {
throw std::runtime_error( throw std::runtime_error(
"[FFT] Received unexpected input and output type combination."); "[FFT] Received unexpected input and output type combination.");


@@ -7,14 +7,20 @@ namespace mlx::core {
template <typename T> template <typename T>
void matmul( void matmul(
const array& a, const T* a,
const array& b, const T* b,
array& out, T* out,
bool a_transposed, bool a_transposed,
bool b_transposed, bool b_transposed,
size_t lda, size_t lda,
size_t ldb, size_t ldb,
size_t ldc,
float alpha, float alpha,
float beta); float beta,
size_t batch_size,
const Shape& a_shape,
const Strides& a_strides,
const Shape& b_shape,
const Strides& b_strides);
} // namespace mlx::core } // namespace mlx::core
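A hypothetical call for a single row-major M x K by K x N multiply under the new pointer-based interface (assumes the mlx headers declaring Shape, Strides, and matmul are included; the values are illustrative, not part of this change):

  int M = 2, K = 3, N = 4;
  std::vector<float> A(M * K, 1.0f), B(K * N, 1.0f), C(M * N, 0.0f);
  Shape a_shape = {M, K};
  Strides a_strides = {K, 1};
  Shape b_shape = {K, N};
  Strides b_strides = {N, 1};
  matmul<float>(
      A.data(), B.data(), C.data(),
      /* a_transposed = */ false,
      /* b_transposed = */ false,
      /* lda = */ K,
      /* ldb = */ N,
      /* ldc = */ N,
      /* alpha = */ 1.0f,
      /* beta = */ 0.0f,
      /* batch_size = */ 1,
      a_shape, a_strides, b_shape, b_strides);
  // Every entry of C is now 3 (the shared K dimension).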


@@ -9,39 +9,46 @@
namespace mlx::core { namespace mlx::core {
BNNSDataType to_bnns_dtype(Dtype mlx_dtype) { template <typename T>
uint32_t size_bits = size_of(mlx_dtype) * 8; constexpr BNNSDataType to_bnns_dtype();
switch (kindof(mlx_dtype)) {
case Dtype::Kind::b: template <>
return BNNSDataTypeBoolean; constexpr BNNSDataType to_bnns_dtype<float>() {
case Dtype::Kind::u: return BNNSDataType(BNNSDataTypeFloatBit | 32);
return BNNSDataType(BNNSDataTypeUIntBit | size_bits); }
case Dtype::Kind::i: template <>
return BNNSDataType(BNNSDataTypeIntBit | size_bits); constexpr BNNSDataType to_bnns_dtype<float16_t>() {
case Dtype::Kind::f: return BNNSDataType(BNNSDataTypeFloatBit | 16);
return BNNSDataType(BNNSDataTypeFloatBit | size_bits);
case Dtype::Kind::V:
return BNNSDataTypeBFloat16;
case Dtype::Kind::c:
throw std::invalid_argument("BNNS does not support complex types");
}
} }
template <>
constexpr BNNSDataType to_bnns_dtype<bfloat16_t>() {
return BNNSDataTypeBFloat16;
}
template <typename T>
void matmul_bnns( void matmul_bnns(
const array& a, const T* a,
const array& b, const T* b,
array& out, T* out,
bool a_transposed, bool a_transposed,
bool b_transposed, bool b_transposed,
size_t lda, size_t lda,
size_t ldb, size_t ldb,
size_t ldc,
float alpha, float alpha,
float beta) { float beta,
size_t M = a.shape(-2); size_t batch_size,
size_t N = b.shape(-1); const Shape& a_shape,
size_t K = a.shape(-1); const Strides& a_strides,
const Shape& b_shape,
const Strides& b_strides) {
auto ndim = a_shape.size();
size_t M = a_shape[ndim - 2];
size_t N = b_shape[ndim - 1];
size_t K = a_shape[ndim - 1];
BNNSDataType bnns_dtype = to_bnns_dtype(out.dtype()); BNNSDataType bnns_dtype = to_bnns_dtype<T>();
#pragma GCC diagnostic push #pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wdeprecated-declarations" #pragma GCC diagnostic ignored "-Wdeprecated-declarations"
@@ -115,14 +122,14 @@ void matmul_bnns(
auto bnns_filter = auto bnns_filter =
BNNSFilterCreateLayerBroadcastMatMul(&gemm_params, nullptr); BNNSFilterCreateLayerBroadcastMatMul(&gemm_params, nullptr);
for (int i = 0; i < (a.size() / (M * K)); ++i) { for (int i = 0; i < batch_size; ++i) {
BNNSFilterApplyTwoInput( BNNSFilterApplyTwoInput(
bnns_filter, bnns_filter,
a.data<uint8_t>() + reinterpret_cast<const uint8_t*>(
elem_to_loc(M * K * i, a.shape(), a.strides()) * a.itemsize(), a + elem_to_loc(M * K * i, a_shape, a_strides)),
b.data<uint8_t>() + reinterpret_cast<const uint8_t*>(
elem_to_loc(K * N * i, b.shape(), b.strides()) * b.itemsize(), b + elem_to_loc(K * N * i, b_shape, b_strides)),
out.data<uint8_t>() + M * N * i * out.itemsize()); reinterpret_cast<uint8_t*>(out + M * N * i));
} }
BNNSFilterDestroy(bnns_filter); BNNSFilterDestroy(bnns_filter);
@@ -131,30 +138,72 @@ void matmul_bnns(
template <> template <>
void matmul<float16_t>( void matmul<float16_t>(
const array& a, const float16_t* a,
const array& b, const float16_t* b,
array& out, float16_t* out,
bool a_transposed, bool a_transposed,
bool b_transposed, bool b_transposed,
size_t lda, size_t lda,
size_t ldb, size_t ldb,
size_t ldc,
float alpha, float alpha,
float beta) { float beta,
matmul_bnns(a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta); size_t batch_size,
const Shape& a_shape,
const Strides& a_strides,
const Shape& b_shape,
const Strides& b_strides) {
matmul_bnns(
a,
b,
out,
a_transposed,
b_transposed,
lda,
ldb,
ldc,
alpha,
beta,
batch_size,
a_shape,
a_strides,
b_shape,
b_strides);
} }
template <> template <>
void matmul<bfloat16_t>( void matmul<bfloat16_t>(
const array& a, const bfloat16_t* a,
const array& b, const bfloat16_t* b,
array& out, bfloat16_t* out,
bool a_transposed, bool a_transposed,
bool b_transposed, bool b_transposed,
size_t lda, size_t lda,
size_t ldb, size_t ldb,
size_t ldc,
float alpha, float alpha,
float beta) { float beta,
matmul_bnns(a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta); size_t batch_size,
const Shape& a_shape,
const Strides& a_strides,
const Shape& b_shape,
const Strides& b_strides) {
matmul_bnns(
a,
b,
out,
a_transposed,
b_transposed,
lda,
ldb,
ldc,
alpha,
beta,
batch_size,
a_shape,
a_strides,
b_shape,
b_strides);
} }
} // namespace mlx::core } // namespace mlx::core
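
The hunk above replaces a runtime Dtype switch with compile-time specializations keyed on the element type, so the BNNS dtype is known inside the templated matmul_bnns without touching the array. The same pattern in a self-contained form; the enum and function name below are invented for illustration, while the real code returns BNNSDataType values.

#include <cstdint>

enum class BackendDtype : uint32_t { f32, f64 };

// Primary template is declared but never defined: asking for an unsupported
// element type fails at compile time (in a constant expression) or link time.
template <typename T>
constexpr BackendDtype to_backend_dtype();

template <>
constexpr BackendDtype to_backend_dtype<float>() {
  return BackendDtype::f32;
}

template <>
constexpr BackendDtype to_backend_dtype<double>() {
  return BackendDtype::f64;
}

// Unlike the old switch, the mapping is usable in constant expressions.
static_assert(to_backend_dtype<float>() == BackendDtype::f32);
static_assert(to_backend_dtype<double>() == BackendDtype::f64);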

View File

@@ -8,20 +8,27 @@ namespace mlx::core {
template <> template <>
void matmul<float>( void matmul<float>(
const array& a, const float* a,
const array& b, const float* b,
array& out, float* out,
bool a_transposed, bool a_transposed,
bool b_transposed, bool b_transposed,
size_t lda, size_t lda,
size_t ldb, size_t ldb,
size_t ldc,
float alpha, float alpha,
float beta) { float beta,
size_t M = a.shape(-2); size_t batch_size,
size_t N = b.shape(-1); const Shape& a_shape,
size_t K = a.shape(-1); const Strides& a_strides,
const Shape& b_shape,
const Strides& b_strides) {
auto ndim = a_shape.size();
size_t M = a_shape[ndim - 2];
size_t N = b_shape[ndim - 1];
size_t K = a_shape[ndim - 1];
for (int i = 0; i < (a.size() / (M * K)); ++i) { for (int i = 0; i < batch_size; ++i) {
cblas_sgemm( cblas_sgemm(
CblasRowMajor, CblasRowMajor,
a_transposed ? CblasTrans : CblasNoTrans, // transA a_transposed ? CblasTrans : CblasNoTrans, // transA
@@ -29,34 +36,40 @@ void matmul<float>(
M, M,
N, N,
K, K,
alpha, // alpha alpha,
a.data<float>() + elem_to_loc(M * K * i, a.shape(), a.strides()), a + elem_to_loc(M * K * i, a_shape, a_strides),
lda, lda,
b.data<float>() + elem_to_loc(K * N * i, b.shape(), b.strides()), b + elem_to_loc(K * N * i, b_shape, b_strides),
ldb, ldb,
beta, // beta beta,
out.data<float>() + M * N * i, out + M * N * i,
out.shape(-1) // ldc ldc);
);
} }
} }
template <> template <>
void matmul<double>( void matmul<double>(
const array& a, const double* a,
const array& b, const double* b,
array& out, double* out,
bool a_transposed, bool a_transposed,
bool b_transposed, bool b_transposed,
size_t lda, size_t lda,
size_t ldb, size_t ldb,
size_t ldc,
float alpha, float alpha,
float beta) { float beta,
size_t M = a.shape(-2); size_t batch_size,
size_t N = b.shape(-1); const Shape& a_shape,
size_t K = a.shape(-1); const Strides& a_strides,
const Shape& b_shape,
const Strides& b_strides) {
auto ndim = a_shape.size();
size_t M = a_shape[ndim - 2];
size_t N = b_shape[ndim - 1];
size_t K = a_shape[ndim - 1];
for (int i = 0; i < (a.size() / (M * K)); ++i) { for (int i = 0; i < batch_size; ++i) {
cblas_dgemm( cblas_dgemm(
CblasRowMajor, CblasRowMajor,
a_transposed ? CblasTrans : CblasNoTrans, // transA a_transposed ? CblasTrans : CblasNoTrans, // transA
@@ -64,15 +77,14 @@ void matmul<double>(
M, M,
N, N,
K, K,
alpha, // alpha alpha,
a.data<double>() + elem_to_loc(M * K * i, a.shape(), a.strides()), a + elem_to_loc(M * K * i, a_shape, a_strides),
lda, lda,
b.data<double>() + elem_to_loc(K * N * i, b.shape(), b.strides()), b + elem_to_loc(K * N * i, b_shape, b_strides),
ldb, ldb,
beta, // beta beta,
out.data<double>() + M * N * i, out + M * N * i,
out.shape(-1) // ldc ldc);
);
} }
} }
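
As a quick reference for the call shape above: with row-major, non-transposed operands the leading dimensions are just the row widths (lda = K, ldb = N, ldc = N), and the diff now takes ldc from the caller instead of out.shape(-1). A tiny self-checking example, assuming any CBLAS implementation (Accelerate, OpenBLAS) is linked:

#include <cblas.h>
#include <cstdio>

int main() {
  // C = A * B with A, B, C all 2x2 and row-major.
  float A[4] = {1, 2, 3, 4};
  float B[4] = {5, 6, 7, 8};
  float C[4] = {0, 0, 0, 0};
  cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans,
              /* M = */ 2, /* N = */ 2, /* K = */ 2,
              /* alpha = */ 1.0f, A, /* lda = */ 2, B, /* ldb = */ 2,
              /* beta = */ 0.0f, C, /* ldc = */ 2);
  std::printf("%g %g\n%g %g\n", C[0], C[1], C[2], C[3]); // 19 22 / 43 50
}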

View File

@@ -1,21 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cpu/gemm.h"
namespace mlx::core {
template <>
void matmul<bfloat16_t>(
const array&,
const array&,
array&,
bool,
bool,
size_t,
size_t,
float,
float) {
throw std::runtime_error("[Matmul::eval_cpu] bfloat16 not supported.");
}
} // namespace mlx::core

View File

@@ -1,21 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cpu/gemm.h"
namespace mlx::core {
template <>
void matmul<float16_t>(
const array&,
const array&,
array&,
bool,
bool,
size_t,
size_t,
float,
float) {
throw std::runtime_error("[Matmul::eval_cpu] float16 not supported.");
}
} // namespace mlx::core

View File

@@ -0,0 +1,45 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/gemm.h"
#include "mlx/backend/cpu/gemms/simd_gemm.h"
namespace mlx::core {
template <>
void matmul<bfloat16_t>(
const bfloat16_t* a,
const bfloat16_t* b,
bfloat16_t* out,
bool a_transposed,
bool b_transposed,
size_t lda,
size_t ldb,
size_t ldc,
float alpha,
float beta,
size_t batch_size,
const Shape& a_shape,
const Strides& a_strides,
const Shape& b_shape,
const Strides& b_strides) {
auto ndim = a_shape.size();
size_t M = a_shape[ndim - 2];
size_t N = b_shape[ndim - 1];
size_t K = a_shape[ndim - 1];
for (int i = 0; i < batch_size; ++i) {
simd_gemm<bfloat16_t, float>(
a + elem_to_loc(M * K * i, a_shape, a_strides),
b + elem_to_loc(K * N * i, b_shape, b_strides),
out + M * N * i,
a_transposed,
b_transposed,
M,
N,
K,
alpha,
beta);
}
}
} // namespace mlx::core

View File

@@ -0,0 +1,45 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/gemm.h"
#include "mlx/backend/cpu/gemms/simd_gemm.h"
namespace mlx::core {
template <>
void matmul<float16_t>(
const float16_t* a,
const float16_t* b,
float16_t* out,
bool a_transposed,
bool b_transposed,
size_t lda,
size_t ldb,
size_t ldc,
float alpha,
float beta,
size_t batch_size,
const Shape& a_shape,
const Strides& a_strides,
const Shape& b_shape,
const Strides& b_strides) {
auto ndim = a_shape.size();
size_t M = a_shape[ndim - 2];
size_t N = b_shape[ndim - 1];
size_t K = a_shape[ndim - 1];
for (int i = 0; i < batch_size; ++i) {
simd_gemm<float16_t, float>(
a + elem_to_loc(M * K * i, a_shape, a_strides),
b + elem_to_loc(K * N * i, b_shape, b_strides),
out + M * N * i,
a_transposed,
b_transposed,
M,
N,
K,
alpha,
beta);
}
}
} // namespace mlx::core

View File

@@ -0,0 +1,139 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include "mlx/backend/cpu/simd/simd.h"
namespace mlx::core {
inline int ceildiv(int a, int b) {
return (a + b - 1) / b;
}
template <int block_size, typename T, typename AccT>
void load_block(
const T* in,
AccT* out,
int M,
int N,
int i,
int j,
bool transpose) {
if (transpose) {
for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) {
for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) {
out[jj * block_size + ii] =
in[(i * block_size + ii) * N + j * block_size + jj];
}
}
} else {
for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) {
for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) {
out[ii * block_size + jj] =
in[(i * block_size + ii) * N + j * block_size + jj];
}
}
}
}
template <typename T, typename AccT>
void simd_gemm(
const T* a,
const T* b,
T* c,
bool a_trans,
bool b_trans,
int M,
int N,
int K,
float alpha,
float beta) {
constexpr int block_size = 16;
constexpr int simd_size = simd::max_size<AccT>;
static_assert(
(block_size % simd_size) == 0,
"Block size must be divisible by SIMD size");
int last_k_block_size = K - block_size * (K / block_size);
int last_k_simd_block = (last_k_block_size / simd_size) * simd_size;
for (int i = 0; i < ceildiv(M, block_size); i++) {
for (int j = 0; j < ceildiv(N, block_size); j++) {
AccT c_block[block_size * block_size] = {0.0};
AccT a_block[block_size * block_size];
AccT b_block[block_size * block_size];
int k = 0;
for (; k < K / block_size; k++) {
// Load a and b blocks
if (a_trans) {
load_block<block_size>(a, a_block, K, M, k, i, true);
} else {
load_block<block_size>(a, a_block, M, K, i, k, false);
}
if (b_trans) {
load_block<block_size>(b, b_block, N, K, j, k, false);
} else {
load_block<block_size>(b, b_block, K, N, k, j, true);
}
// Multiply and accumulate
for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) {
for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) {
for (int kk = 0; kk < block_size; kk += simd_size) {
auto av =
simd::load<AccT, simd_size>(a_block + ii * block_size + kk);
auto bv =
simd::load<AccT, simd_size>(b_block + jj * block_size + kk);
c_block[ii * block_size + jj] += simd::sum(av * bv);
}
}
}
}
if (last_k_block_size) {
// Load a and b blocks
if (a_trans) {
load_block<block_size>(a, a_block, K, M, k, i, true);
} else {
load_block<block_size>(a, a_block, M, K, i, k, false);
}
if (b_trans) {
load_block<block_size>(b, b_block, N, K, j, k, false);
} else {
load_block<block_size>(b, b_block, K, N, k, j, true);
}
// Multiply and accumulate
for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) {
for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) {
int kk = 0;
for (; kk < last_k_simd_block; kk += simd_size) {
auto av =
simd::load<AccT, simd_size>(a_block + ii * block_size + kk);
auto bv =
simd::load<AccT, simd_size>(b_block + jj * block_size + kk);
c_block[ii * block_size + jj] += simd::sum(av * bv);
}
for (; kk < last_k_block_size; ++kk) {
c_block[ii * block_size + jj] +=
a_block[ii * block_size + kk] * b_block[jj * block_size + kk];
}
}
}
}
// Store
for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) {
for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) {
auto c_idx = (i * block_size + ii) * N + j * block_size + jj;
if (beta != 0) {
c[c_idx] = static_cast<T>(
alpha * c_block[ii * block_size + jj] + beta * c[c_idx]);
} else {
c[c_idx] = static_cast<T>(alpha * c_block[ii * block_size + jj]);
}
}
}
}
}
}
} // namespace mlx::core
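
A minimal usage sketch for the kernel above, assuming the MLX source tree is on the include path. float inputs are used here only to make the example self-checking; in the new fp16/bf16 files the kernel is instantiated with half-precision inputs and float accumulation.

#include "mlx/backend/cpu/gemms/simd_gemm.h"

#include <cstdio>

int main() {
  constexpr int M = 2, N = 2, K = 3;
  float a[M * K] = {1, 2, 3, 4, 5, 6};     // row-major 2x3
  float b[K * N] = {7, 8, 9, 10, 11, 12};  // row-major 3x2
  float c[M * N] = {};
  // C = alpha * A @ B (+ beta * C when beta != 0); accumulation is in AccT.
  mlx::core::simd_gemm<float, float>(
      a, b, c,
      /* a_trans = */ false, /* b_trans = */ false,
      M, N, K,
      /* alpha = */ 1.0f, /* beta = */ 0.0f);
  std::printf("%g %g\n%g %g\n", c[0], c[1], c[2], c[3]); // expect 58 64 / 139 154
}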

View File

@@ -4,16 +4,17 @@
#include "mlx/backend/common/hadamard.h" #include "mlx/backend/common/hadamard.h"
#include "mlx/backend/cpu/copy.h" #include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/primitives.h" #include "mlx/primitives.h"
namespace mlx::core { namespace mlx::core {
// n = 2^k component // n = 2^k component
template <typename T> template <typename T>
void hadamard_n(array& out, int n, int m, float scale) { void hadamard_n(T* out, int n, int m, float scale, size_t size) {
for (int b = 0; b < out.size() / n; b++) { for (int b = 0; b < size / n; b++) {
size_t loc = b * n; size_t loc = b * n;
T* data_ptr = out.data<T>() + loc; T* data_ptr = out + loc;
int h = 1; int h = 1;
int n_over_2 = n / 2; int n_over_2 = n / 2;
while (h < n) { while (h < n) {
@@ -36,7 +37,7 @@ void hadamard_n(array& out, int n, int m, float scale) {
// m component // m component
template <typename T> template <typename T>
void hadamard_m(array& out, int n, int m, float scale) { void hadamard_m(T* out, int n, int m, float scale, size_t size) {
auto h_matrices = hadamard_matrices(); auto h_matrices = hadamard_matrices();
auto& matrix = h_matrices[m]; auto& matrix = h_matrices[m];
auto start = 1; auto start = 1;
@@ -51,9 +52,9 @@ void hadamard_m(array& out, int n, int m, float scale) {
end = matrix.find('\n', start); end = matrix.find('\n', start);
} }
for (int b = 0; b < out.size() / m / n; b++) { for (int b = 0; b < size / m / n; b++) {
size_t loc = b * n * m; size_t loc = b * n * m;
T* data_ptr = out.data<T>() + loc; T* data_ptr = out + loc;
for (int i = 0; i < n; i++) { for (int i = 0; i < n; i++) {
std::vector<float> out(m); std::vector<float> out(m);
for (int j = 0; j < m; j++) { for (int j = 0; j < m; j++) {
@@ -74,12 +75,17 @@ void hadamard_m(array& out, int n, int m, float scale) {
} }
template <typename T> template <typename T>
void hadamard(array& out, int n, int m, float scale) { void hadamard(array& out, int n, int m, float scale, Stream stream) {
float n_scale = m > 1 ? 1.0 : scale; auto& encoder = cpu::get_command_encoder(stream);
hadamard_n<T>(out, n, m, n_scale); encoder.set_output_array(out);
if (m > 1) { auto out_ptr = out.data<T>();
hadamard_m<T>(out, n, m, scale); encoder.dispatch([out_ptr, size = out.size(), n, m, scale]() {
} float n_scale = m > 1 ? 1.0 : scale;
hadamard_n<T>(out_ptr, n, m, n_scale, size);
if (m > 1) {
hadamard_m<T>(out_ptr, n, m, scale, size);
}
});
} }
void Hadamard::eval_cpu(const std::vector<array>& inputs, array& out) { void Hadamard::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -87,18 +93,26 @@ void Hadamard::eval_cpu(const std::vector<array>& inputs, array& out) {
auto& in = inputs[0]; auto& in = inputs[0];
// Copy input to output // Copy input to output
copy(in, out, CopyType::General); if (in.flags().row_contiguous && in.is_donatable()) {
out.copy_shared_buffer(in);
} else {
copy(
in,
out,
in.flags().row_contiguous ? CopyType::Vector : CopyType::General,
stream());
}
int axis = out.ndim() - 1; int axis = out.ndim() - 1;
auto [n, m] = decompose_hadamard(out.shape(axis)); auto [n, m] = decompose_hadamard(out.shape(axis));
switch (in.dtype()) { switch (in.dtype()) {
case float32: case float32:
return hadamard<float>(out, n, m, scale_); return hadamard<float>(out, n, m, scale_, stream());
case float16: case float16:
return hadamard<float16_t>(out, n, m, scale_); return hadamard<float16_t>(out, n, m, scale_, stream());
case bfloat16: case bfloat16:
return hadamard<bfloat16_t>(out, n, m, scale_); return hadamard<bfloat16_t>(out, n, m, scale_, stream());
default: default:
throw std::invalid_argument("[hadamard] Unsupported type."); throw std::invalid_argument("[hadamard] Unsupported type.");
} }
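
For the n = 2^k component above, the per-segment kernel is the standard in-place fast Walsh-Hadamard butterfly. A plain-C++ reference of that recurrence (illustrative only, not the MLX kernel):

#include <cstdio>

// In-place length-n (n a power of two) Walsh-Hadamard transform of one segment.
void fwht_pow2(float* x, int n, float scale) {
  for (int h = 1; h < n; h *= 2) {
    for (int i = 0; i < n; i += 2 * h) {
      for (int j = i; j < i + h; ++j) {
        float u = x[j];
        float v = x[j + h];
        x[j] = u + v;      // butterfly: sum
        x[j + h] = u - v;  // butterfly: difference
      }
    }
  }
  for (int i = 0; i < n; ++i) {
    x[i] *= scale;
  }
}

int main() {
  float x[4] = {1, 0, 0, 0};
  fwht_pow2(x, 4, 1.0f);
  std::printf("%g %g %g %g\n", x[0], x[1], x[2], x[3]); // 1 1 1 1
}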

View File

@@ -8,6 +8,7 @@
#include "mlx/backend/common/utils.h" #include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/copy.h" #include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
namespace mlx::core { namespace mlx::core {
@@ -21,6 +22,40 @@ inline size_t offset_neg_idx(uint32_t idx, size_t) {
return idx; return idx;
} }
struct None {
template <typename T>
void operator()(T x, T* y) {
(*y) = x;
}
};
struct Sum {
template <typename T>
void operator()(T x, T* y) {
(*y) += x;
}
};
struct Prod {
template <typename T>
void operator()(T x, T* y) {
(*y) *= x;
}
};
struct Max {
template <typename T>
void operator()(T x, T* y) {
(*y) = (*y > x) ? *y : x;
}
};
struct Min {
template <typename T>
void operator()(T x, T* y) {
(*y) = (*y < x) ? *y : x;
}
};
template <typename T, typename IdxT> template <typename T, typename IdxT>
void gather( void gather(
const array& src, const array& src,
@@ -73,13 +108,14 @@ void gather(
size_t ind_size = slice_size == 0 ? 0 : out.size() / slice_size; size_t ind_size = slice_size == 0 ? 0 : out.size() / slice_size;
const T* src_ptr = src.data<T>(); const T* src_ptr = src.data<T>();
T* dst_ptr = out.data<T>(); T* dst_ptr = out.data<T>();
size_t out_idx = 0;
std::vector<ContiguousIterator> its(inds.begin(), inds.end()); std::vector<ContiguousIterator> its(inds.begin(), inds.end());
ContiguousIterator src_it; ContiguousIterator src_it;
if (!can_copy && src.ndim() > 0) { if (!can_copy && src.ndim() > 0) {
src_it = ContiguousIterator(slice_sizes, src.strides(), src.ndim()); src_it = ContiguousIterator(slice_sizes, src.strides(), src.ndim());
} }
size_t out_idx = 0;
for (int idx = 0; idx < ind_size; idx++) { for (int idx = 0; idx < ind_size; idx++) {
size_t src_idx = 0; size_t src_idx = 0;
for (int ii = 0; ii < inds.size(); ++ii) { for (int ii = 0; ii < inds.size(); ++ii) {
@@ -161,46 +197,59 @@ void dispatch_gather(
} }
void Gather::eval_cpu(const std::vector<array>& inputs, array& out) { void Gather::eval_cpu(const std::vector<array>& inputs, array& out) {
out.set_data(allocator::malloc_or_wait(out.nbytes())); out.set_data(allocator::malloc(out.nbytes()));
auto& src = inputs[0]; auto& src = inputs[0];
std::vector<array> inds(inputs.begin() + 1, inputs.end()); std::vector<array> inds;
for (auto it = inputs.begin() + 1; it < inputs.end(); ++it) {
if (inds.empty()) { inds.push_back(array::unsafe_weak_copy(*it));
dispatch_gather<uint8_t>(src, inds, out, axes_, slice_sizes_);
return;
} }
auto& encoder = cpu::get_command_encoder(stream());
switch (inds[0].dtype()) { for (auto& in : inputs) {
case uint8: encoder.set_input_array(in);
}
encoder.set_output_array(out);
encoder.dispatch([axes_ = axes_,
slice_sizes_ = slice_sizes_,
src = array::unsafe_weak_copy(src),
inds = std::move(inds),
out = array::unsafe_weak_copy(out)]() mutable {
if (inds.empty()) {
dispatch_gather<uint8_t>(src, inds, out, axes_, slice_sizes_); dispatch_gather<uint8_t>(src, inds, out, axes_, slice_sizes_);
break; return;
case uint16: }
dispatch_gather<uint16_t>(src, inds, out, axes_, slice_sizes_);
break; switch (inds[0].dtype()) {
case uint32: case uint8:
dispatch_gather<uint32_t>(src, inds, out, axes_, slice_sizes_); dispatch_gather<uint8_t>(src, inds, out, axes_, slice_sizes_);
break; break;
case uint64: case uint16:
dispatch_gather<uint64_t>(src, inds, out, axes_, slice_sizes_); dispatch_gather<uint16_t>(src, inds, out, axes_, slice_sizes_);
break; break;
case int8: case uint32:
dispatch_gather<int8_t>(src, inds, out, axes_, slice_sizes_); dispatch_gather<uint32_t>(src, inds, out, axes_, slice_sizes_);
break; break;
case int16: case uint64:
dispatch_gather<int16_t>(src, inds, out, axes_, slice_sizes_); dispatch_gather<uint64_t>(src, inds, out, axes_, slice_sizes_);
break; break;
case int32: case int8:
dispatch_gather<int32_t>(src, inds, out, axes_, slice_sizes_); dispatch_gather<int8_t>(src, inds, out, axes_, slice_sizes_);
break; break;
case int64: case int16:
dispatch_gather<int64_t>(src, inds, out, axes_, slice_sizes_); dispatch_gather<int16_t>(src, inds, out, axes_, slice_sizes_);
break; break;
default: case int32:
throw std::runtime_error( dispatch_gather<int32_t>(src, inds, out, axes_, slice_sizes_);
"[Gather::eval_cpu] Cannot gather with indices type."); break;
break; case int64:
} dispatch_gather<int64_t>(src, inds, out, axes_, slice_sizes_);
break;
default:
throw std::runtime_error(
"[Gather::eval_cpu] Cannot gather with indices type.");
break;
}
});
} }
template <typename T, typename IdxT> template <typename T, typename IdxT>
void gather_axis( void gather_axis(
@@ -235,6 +284,7 @@ void gather_axis(
for (int i = axis + 1; i < ind.ndim(); ++i) { for (int i = axis + 1; i < ind.ndim(); ++i) {
size_post *= ind.shape(i); size_post *= ind.shape(i);
} }
size_t stride_pre = size_post * ind_ax_size; size_t stride_pre = size_post * ind_ax_size;
for (size_t i = 0; i < size_pre; i++) { for (size_t i = 0; i < size_pre; i++) {
for (size_t k = 0; k < size_post; k++) { for (size_t k = 0; k < size_post; k++) {
@@ -304,39 +354,49 @@ void dispatch_gather_axis(
} }
void GatherAxis::eval_cpu(const std::vector<array>& inputs, array& out) { void GatherAxis::eval_cpu(const std::vector<array>& inputs, array& out) {
out.set_data(allocator::malloc_or_wait(out.nbytes())); out.set_data(allocator::malloc(out.nbytes()));
auto& src = inputs[0]; auto& src = inputs[0];
auto& inds = inputs[1]; auto& inds = inputs[1];
switch (inds.dtype()) { auto& encoder = cpu::get_command_encoder(stream());
case uint8: encoder.set_input_array(src);
dispatch_gather_axis<uint8_t>(src, inds, out, axis_); encoder.set_input_array(inds);
break; encoder.set_output_array(out);
case uint16: encoder.dispatch([axis_ = axis_,
dispatch_gather_axis<uint16_t>(src, inds, out, axis_); src = array::unsafe_weak_copy(src),
break; inds = array::unsafe_weak_copy(inds),
case uint32: out = array::unsafe_weak_copy(out)]() mutable {
dispatch_gather_axis<uint32_t>(src, inds, out, axis_); switch (inds.dtype()) {
break; case uint8:
case uint64: dispatch_gather_axis<uint8_t>(src, inds, out, axis_);
dispatch_gather_axis<uint64_t>(src, inds, out, axis_); break;
break; case uint16:
case int8: dispatch_gather_axis<uint16_t>(src, inds, out, axis_);
dispatch_gather_axis<int8_t>(src, inds, out, axis_); break;
break; case uint32:
case int16: dispatch_gather_axis<uint32_t>(src, inds, out, axis_);
dispatch_gather_axis<int16_t>(src, inds, out, axis_); break;
break; case uint64:
case int32: dispatch_gather_axis<uint64_t>(src, inds, out, axis_);
dispatch_gather_axis<int32_t>(src, inds, out, axis_); break;
break; case int8:
case int64: dispatch_gather_axis<int8_t>(src, inds, out, axis_);
dispatch_gather_axis<int64_t>(src, inds, out, axis_); break;
break; case int16:
default: dispatch_gather_axis<int16_t>(src, inds, out, axis_);
throw std::runtime_error( break;
"[GatherAxis::eval_cpu] Cannot gather with indices type."); case int32:
break; dispatch_gather_axis<int32_t>(src, inds, out, axis_);
} break;
case int64:
dispatch_gather_axis<int64_t>(src, inds, out, axis_);
break;
default:
throw std::runtime_error(
"[GatherAxis::eval_cpu] Cannot gather with indices type.");
break;
}
});
} }
template <typename InT, typename IdxT, typename OpT> template <typename InT, typename IdxT, typename OpT>
@@ -344,8 +404,7 @@ void scatter(
const array& updates, const array& updates,
array& out, array& out,
const std::vector<array>& inds, const std::vector<array>& inds,
const std::vector<int>& axes, const std::vector<int>& axes) {
const OpT& op) {
int nind = inds.size(); int nind = inds.size();
auto inds_ndim = updates.ndim() - out.ndim(); auto inds_ndim = updates.ndim() - out.ndim();
size_t n_updates = nind ? inds[0].size() : 1; size_t n_updates = nind ? inds[0].size() : 1;
@@ -361,9 +420,11 @@ void scatter(
ContiguousIterator update_it(updates); ContiguousIterator update_it(updates);
ContiguousIterator out_it(update_shape, out.strides(), out.ndim()); ContiguousIterator out_it(update_shape, out.strides(), out.ndim());
auto out_ptr = out.data<InT>();
auto upd_ptr = updates.data<InT>();
for (int i = 0; i < n_updates; ++i) { for (int i = 0; i < n_updates; ++i) {
size_t out_offset = 0; size_t out_offset = 0;
for (int j = 0; j < nind; ++j) { for (int j = 0; j < inds.size(); ++j) {
auto ax = axes[j]; auto ax = axes[j];
auto idx_loc = its[j].loc; auto idx_loc = its[j].loc;
its[j].step(); its[j].step();
@@ -373,8 +434,7 @@ void scatter(
} }
update_it.seek(i * update_size); update_it.seek(i * update_size);
for (int j = 0; j < update_size; ++j) { for (int j = 0; j < update_size; ++j) {
op(updates.data<InT>()[update_it.loc], OpT{}(upd_ptr[update_it.loc], out_ptr + out_offset + out_it.loc);
out.data<InT>() + out_offset + out_it.loc);
update_it.step(); update_it.step();
out_it.step(); out_it.step();
} }
@@ -392,26 +452,19 @@ void dispatch_scatter_inds(
Scatter::ReduceType rtype) { Scatter::ReduceType rtype) {
switch (rtype) { switch (rtype) {
case Scatter::None: case Scatter::None:
scatter<InT, IdxT>( scatter<InT, IdxT, None>(updates, out, indices, axes);
updates, out, indices, axes, [](auto x, auto* y) { (*y) = x; });
break; break;
case Scatter::Sum: case Scatter::Sum:
scatter<InT, IdxT>( scatter<InT, IdxT, Sum>(updates, out, indices, axes);
updates, out, indices, axes, [](auto x, auto* y) { (*y) += x; });
break; break;
case Scatter::Prod: case Scatter::Prod:
scatter<InT, IdxT>( scatter<InT, IdxT, Prod>(updates, out, indices, axes);
updates, out, indices, axes, [](auto x, auto* y) { (*y) *= x; });
break; break;
case Scatter::Max: case Scatter::Max:
scatter<InT, IdxT>(updates, out, indices, axes, [](auto x, auto* y) { scatter<InT, IdxT, Max>(updates, out, indices, axes);
(*y) = (*y > x) ? *y : x;
});
break; break;
case Scatter::Min: case Scatter::Min:
scatter<InT, IdxT>(updates, out, indices, axes, [](auto x, auto* y) { scatter<InT, IdxT, Min>(updates, out, indices, axes);
(*y) = (*y < x) ? *y : x;
});
break; break;
} }
} }
@@ -463,67 +516,75 @@ void Scatter::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() >= 2); assert(inputs.size() >= 2);
auto& src = inputs[0]; auto& src = inputs[0];
std::vector<array> inds(inputs.begin() + 1, inputs.end() - 1);
auto& updates = inputs.back(); auto& updates = inputs.back();
// Copy src into out (copy allocates memory for out) // Copy src into out (copy allocates memory for out)
auto ctype = auto ctype =
src.flags().row_contiguous ? CopyType::Vector : CopyType::General; src.flags().row_contiguous ? CopyType::Vector : CopyType::General;
copy(src, out, ctype); copy(src, out, ctype, stream());
switch (src.dtype()) { auto& encoder = cpu::get_command_encoder(stream());
case bool_: std::vector<array> inds;
dispatch_scatter<bool>(out, inds, updates, axes_, reduce_type_); for (auto it = inputs.begin() + 1; it < inputs.end() - 1; ++it) {
break; encoder.set_input_array(*it);
case uint8: inds.push_back(array::unsafe_weak_copy(*it));
dispatch_scatter<uint8_t>(out, inds, updates, axes_, reduce_type_);
break;
case uint16:
dispatch_scatter<uint16_t>(out, inds, updates, axes_, reduce_type_);
break;
case uint32:
dispatch_scatter<uint32_t>(out, inds, updates, axes_, reduce_type_);
break;
case uint64:
dispatch_scatter<uint64_t>(out, inds, updates, axes_, reduce_type_);
break;
case int8:
dispatch_scatter<int8_t>(out, inds, updates, axes_, reduce_type_);
break;
case int16:
dispatch_scatter<int16_t>(out, inds, updates, axes_, reduce_type_);
break;
case int32:
dispatch_scatter<int32_t>(out, inds, updates, axes_, reduce_type_);
break;
case int64:
dispatch_scatter<int64_t>(out, inds, updates, axes_, reduce_type_);
break;
case float16:
dispatch_scatter<float16_t>(out, inds, updates, axes_, reduce_type_);
break;
case float32:
dispatch_scatter<float>(out, inds, updates, axes_, reduce_type_);
break;
case float64:
dispatch_scatter<double>(out, inds, updates, axes_, reduce_type_);
break;
case bfloat16:
dispatch_scatter<bfloat16_t>(out, inds, updates, axes_, reduce_type_);
break;
case complex64:
dispatch_scatter<complex64_t>(out, inds, updates, axes_, reduce_type_);
break;
} }
encoder.set_input_array(updates);
encoder.set_output_array(out);
encoder.dispatch([axes_ = axes_,
reduce_type_ = reduce_type_,
updates = array::unsafe_weak_copy(updates),
inds = std::move(inds),
out = array::unsafe_weak_copy(out)]() mutable {
switch (out.dtype()) {
case bool_:
dispatch_scatter<bool>(out, inds, updates, axes_, reduce_type_);
break;
case uint8:
dispatch_scatter<uint8_t>(out, inds, updates, axes_, reduce_type_);
break;
case uint16:
dispatch_scatter<uint16_t>(out, inds, updates, axes_, reduce_type_);
break;
case uint32:
dispatch_scatter<uint32_t>(out, inds, updates, axes_, reduce_type_);
break;
case uint64:
dispatch_scatter<uint64_t>(out, inds, updates, axes_, reduce_type_);
break;
case int8:
dispatch_scatter<int8_t>(out, inds, updates, axes_, reduce_type_);
break;
case int16:
dispatch_scatter<int16_t>(out, inds, updates, axes_, reduce_type_);
break;
case int32:
dispatch_scatter<int32_t>(out, inds, updates, axes_, reduce_type_);
break;
case int64:
dispatch_scatter<int64_t>(out, inds, updates, axes_, reduce_type_);
break;
case float16:
dispatch_scatter<float16_t>(out, inds, updates, axes_, reduce_type_);
break;
case float32:
dispatch_scatter<float>(out, inds, updates, axes_, reduce_type_);
break;
case float64:
dispatch_scatter<double>(out, inds, updates, axes_, reduce_type_);
break;
case bfloat16:
dispatch_scatter<bfloat16_t>(out, inds, updates, axes_, reduce_type_);
break;
case complex64:
dispatch_scatter<complex64_t>(out, inds, updates, axes_, reduce_type_);
break;
}
});
} }
template <typename T, typename IdxT, typename OpT> template <typename T, typename IdxT, typename OpT>
void scatter_axis( void scatter_axis(array& out, const array idx, const array& upd, int axis) {
array& out,
const array idx,
const array& upd,
int axis,
const OpT& op) {
auto strides = idx.strides(); auto strides = idx.strides();
strides.erase(strides.begin() + axis); strides.erase(strides.begin() + axis);
auto shape = idx.shape(); auto shape = idx.shape();
@@ -557,8 +618,9 @@ void scatter_axis(
for (int j = 0; j < idx_ax_size; ++j) { for (int j = 0; j < idx_ax_size; ++j) {
auto ind_val = offset_neg_idx( auto ind_val = offset_neg_idx(
idx_ptr[idx_it.loc + j * idx_ax_stride], dst_ax_size); idx_ptr[idx_it.loc + j * idx_ax_stride], dst_ax_size);
op(upd_ptr[upd_it.loc + j * upd_ax_stride], OpT{}(
dst_ptr + k + ind_val * dst_ax_stride); upd_ptr[upd_it.loc + j * upd_ax_stride],
dst_ptr + k + ind_val * dst_ax_stride);
} }
idx_it.step(); idx_it.step();
upd_it.step(); upd_it.step();
@@ -576,12 +638,10 @@ void dispatch_scatter_axis_op(
ScatterAxis::ReduceType rtype) { ScatterAxis::ReduceType rtype) {
switch (rtype) { switch (rtype) {
case ScatterAxis::None: case ScatterAxis::None:
scatter_axis<InT, IdxT>( scatter_axis<InT, IdxT, None>(out, idx, updates, axis);
out, idx, updates, axis, [](auto x, auto* y) { (*y) = x; });
break; break;
case ScatterAxis::Sum: case ScatterAxis::Sum:
scatter_axis<InT, IdxT>( scatter_axis<InT, IdxT, Sum>(out, idx, updates, axis);
out, idx, updates, axis, [](auto x, auto* y) { (*y) += x; });
break; break;
} }
} }
@@ -634,53 +694,65 @@ void ScatterAxis::eval_cpu(const std::vector<array>& inputs, array& out) {
// Copy src into out (copy allocates memory for out) // Copy src into out (copy allocates memory for out)
auto ctype = auto ctype =
src.flags().row_contiguous ? CopyType::Vector : CopyType::General; src.flags().row_contiguous ? CopyType::Vector : CopyType::General;
copy(src, out, ctype); copy(src, out, ctype, stream());
switch (src.dtype()) { auto& encoder = cpu::get_command_encoder(stream());
case bool_: encoder.set_input_array(idx);
dispatch_scatter_axis<bool>(out, idx, updates, axis_, reduce_type_); encoder.set_input_array(updates);
break; encoder.set_output_array(out);
case uint8: encoder.dispatch([axis_ = axis_,
dispatch_scatter_axis<uint8_t>(out, idx, updates, axis_, reduce_type_); reduce_type_ = reduce_type_,
break; idx = array::unsafe_weak_copy(idx),
case uint16: updates = array::unsafe_weak_copy(updates),
dispatch_scatter_axis<uint16_t>(out, idx, updates, axis_, reduce_type_); out = array::unsafe_weak_copy(out)]() mutable {
break; switch (out.dtype()) {
case uint32: case bool_:
dispatch_scatter_axis<uint32_t>(out, idx, updates, axis_, reduce_type_); dispatch_scatter_axis<bool>(out, idx, updates, axis_, reduce_type_);
break; break;
case uint64: case uint8:
dispatch_scatter_axis<uint64_t>(out, idx, updates, axis_, reduce_type_); dispatch_scatter_axis<uint8_t>(out, idx, updates, axis_, reduce_type_);
break; break;
case int8: case uint16:
dispatch_scatter_axis<int8_t>(out, idx, updates, axis_, reduce_type_); dispatch_scatter_axis<uint16_t>(out, idx, updates, axis_, reduce_type_);
break; break;
case int16: case uint32:
dispatch_scatter_axis<int16_t>(out, idx, updates, axis_, reduce_type_); dispatch_scatter_axis<uint32_t>(out, idx, updates, axis_, reduce_type_);
break; break;
case int32: case uint64:
dispatch_scatter_axis<int32_t>(out, idx, updates, axis_, reduce_type_); dispatch_scatter_axis<uint64_t>(out, idx, updates, axis_, reduce_type_);
break; break;
case int64: case int8:
dispatch_scatter_axis<int64_t>(out, idx, updates, axis_, reduce_type_); dispatch_scatter_axis<int8_t>(out, idx, updates, axis_, reduce_type_);
break; break;
case float16: case int16:
dispatch_scatter_axis<float16_t>(out, idx, updates, axis_, reduce_type_); dispatch_scatter_axis<int16_t>(out, idx, updates, axis_, reduce_type_);
break; break;
case float32: case int32:
dispatch_scatter_axis<float>(out, idx, updates, axis_, reduce_type_); dispatch_scatter_axis<int32_t>(out, idx, updates, axis_, reduce_type_);
break; break;
case float64: case int64:
dispatch_scatter_axis<double>(out, idx, updates, axis_, reduce_type_); dispatch_scatter_axis<int64_t>(out, idx, updates, axis_, reduce_type_);
break; break;
case bfloat16: case float16:
dispatch_scatter_axis<bfloat16_t>(out, idx, updates, axis_, reduce_type_); dispatch_scatter_axis<float16_t>(
break; out, idx, updates, axis_, reduce_type_);
case complex64: break;
dispatch_scatter_axis<complex64_t>( case float32:
out, idx, updates, axis_, reduce_type_); dispatch_scatter_axis<float>(out, idx, updates, axis_, reduce_type_);
break; break;
} case float64:
dispatch_scatter_axis<double>(out, idx, updates, axis_, reduce_type_);
break;
case bfloat16:
dispatch_scatter_axis<bfloat16_t>(
out, idx, updates, axis_, reduce_type_);
break;
case complex64:
dispatch_scatter_axis<complex64_t>(
out, idx, updates, axis_, reduce_type_);
break;
}
});
} }
} // namespace mlx::core } // namespace mlx::core
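
The scatter paths above now take the reduce op as a stateless functor type (None, Sum, Prod, Max, Min) rather than a lambda argument, so the op is simply default-constructed at the point of use (OpT{}) and the whole dispatch can be captured in the encoder lambda. A self-contained sketch of that pattern with illustrative names:

#include <cstdio>

struct Assign {
  template <typename T>
  void operator()(T x, T* y) { *y = x; }
};

struct Add {
  template <typename T>
  void operator()(T x, T* y) { *y += x; }
};

// The op is a template parameter; no callable needs to be threaded through.
template <typename T, typename OpT>
void scatter_1d(const int* idx, const T* updates, T* out, int n) {
  for (int i = 0; i < n; ++i) {
    OpT{}(updates[i], out + idx[i]); // apply the op at the target location
  }
}

int main() {
  int idx[3] = {0, 2, 2};
  float upd[3] = {1.f, 2.f, 3.f};
  float out[4] = {0.f, 0.f, 0.f, 0.f};
  scatter_1d<float, Add>(idx, upd, out, 3);
  std::printf("%g %g %g %g\n", out[0], out[1], out[2], out[3]); // 1 0 5 0
}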

View File

@@ -2,20 +2,21 @@
#include "mlx/allocator.h" #include "mlx/allocator.h"
#include "mlx/backend/cpu/copy.h" #include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/lapack.h" #include "mlx/backend/cpu/lapack.h"
#include "mlx/primitives.h" #include "mlx/primitives.h"
namespace mlx::core { namespace mlx::core {
template <typename T> template <typename T>
void general_inv(array& inv, int N, int i) { void general_inv(T* inv, int N) {
int info; int info;
auto ipiv = array::Data{allocator::malloc_or_wait(sizeof(int) * N)}; auto ipiv = array::Data{allocator::malloc(sizeof(int) * N)};
// Compute LU factorization. // Compute LU factorization.
getrf<T>( getrf<T>(
/* m = */ &N, /* m = */ &N,
/* n = */ &N, /* n = */ &N,
/* a = */ inv.data<T>() + N * N * i, /* a = */ inv,
/* lda = */ &N, /* lda = */ &N,
/* ipiv = */ static_cast<int*>(ipiv.buffer.raw_ptr()), /* ipiv = */ static_cast<int*>(ipiv.buffer.raw_ptr()),
/* info = */ &info); /* info = */ &info);
@@ -48,12 +49,12 @@ void general_inv(array& inv, int N, int i) {
} }
const int lwork = workspace_size; const int lwork = workspace_size;
auto scratch = array::Data{allocator::malloc_or_wait(sizeof(T) * lwork)}; auto scratch = array::Data{allocator::malloc(sizeof(T) * lwork)};
// Compute inverse. // Compute inverse.
getri<T>( getri<T>(
/* m = */ &N, /* m = */ &N,
/* a = */ inv.data<T>() + N * N * i, /* a = */ inv,
/* lda = */ &N, /* lda = */ &N,
/* ipiv = */ static_cast<int*>(ipiv.buffer.raw_ptr()), /* ipiv = */ static_cast<int*>(ipiv.buffer.raw_ptr()),
/* work = */ static_cast<T*>(scratch.buffer.raw_ptr()), /* work = */ static_cast<T*>(scratch.buffer.raw_ptr()),
@@ -68,29 +69,28 @@ void general_inv(array& inv, int N, int i) {
} }
template <typename T> template <typename T>
void tri_inv(array& inv, int N, int i, bool upper) { void tri_inv(T* inv, int N, bool upper) {
const char uplo = upper ? 'L' : 'U'; const char uplo = upper ? 'L' : 'U';
const char diag = 'N'; const char diag = 'N';
T* data = inv.data<T>() + N * N * i;
int info; int info;
trtri<T>( trtri<T>(
/* uplo = */ &uplo, /* uplo = */ &uplo,
/* diag = */ &diag, /* diag = */ &diag,
/* N = */ &N, /* N = */ &N,
/* a = */ data, /* a = */ inv,
/* lda = */ &N, /* lda = */ &N,
/* info = */ &info); /* info = */ &info);
// zero out the other triangle // zero out the other triangle
if (upper) { if (upper) {
for (int i = 0; i < N; i++) { for (int i = 0; i < N; i++) {
std::fill(data, data + i, 0.0f); std::fill(inv, inv + i, 0.0f);
data += N; inv += N;
} }
} else { } else {
for (int i = 0; i < N; i++) { for (int i = 0; i < N; i++) {
std::fill(data + i + 1, data + N, 0.0f); std::fill(inv + i + 1, inv + N, 0.0f);
data += N; inv += N;
} }
} }
@@ -103,34 +103,53 @@ void tri_inv(array& inv, int N, int i, bool upper) {
} }
template <typename T> template <typename T>
void inverse_impl(const array& a, array& inv, bool tri, bool upper) { void inverse_impl(
const array& a,
array& inv,
bool tri,
bool upper,
Stream stream) {
// Lapack uses the column-major convention. We take advantage of the following // Lapack uses the column-major convention. We take advantage of the following
// identity to avoid transposing (see // identity to avoid transposing (see
// https://math.stackexchange.com/a/340234): // https://math.stackexchange.com/a/340234):
// (A⁻¹)ᵀ = (Aᵀ)⁻¹ // (A⁻¹)ᵀ = (Aᵀ)⁻¹
// The inverse is computed in place, so just copy the input to the output. // The inverse is computed in place, so just copy the input to the output.
copy(a, inv, a.flags().row_contiguous ? CopyType::Vector : CopyType::General); copy(
a,
inv,
a.flags().row_contiguous ? CopyType::Vector : CopyType::General,
stream);
const int N = a.shape(-1); const int N = a.shape(-1);
const size_t num_matrices = a.size() / (N * N); const size_t num_matrices = a.size() / (N * N);
for (int i = 0; i < num_matrices; i++) { auto& encoder = cpu::get_command_encoder(stream);
if (tri) { encoder.set_output_array(inv);
tri_inv<T>(inv, N, i, upper);
} else { auto inv_ptr = inv.data<T>();
general_inv<T>(inv, N, i); if (tri) {
} encoder.dispatch([inv_ptr, N, num_matrices, upper]() {
for (int i = 0; i < num_matrices; i++) {
tri_inv<T>(inv_ptr + N * N * i, N, upper);
}
});
} else {
encoder.dispatch([inv_ptr, N, num_matrices]() {
for (int i = 0; i < num_matrices; i++) {
general_inv<T>(inv_ptr + N * N * i, N);
}
});
} }
} }
void Inverse::eval_cpu(const std::vector<array>& inputs, array& output) { void Inverse::eval_cpu(const std::vector<array>& inputs, array& output) {
switch (inputs[0].dtype()) { switch (inputs[0].dtype()) {
case float32: case float32:
inverse_impl<float>(inputs[0], output, tri_, upper_); inverse_impl<float>(inputs[0], output, tri_, upper_, stream());
break; break;
case float64: case float64:
inverse_impl<double>(inputs[0], output, tri_, upper_); inverse_impl<double>(inputs[0], output, tri_, upper_, stream());
break; break;
default: default:
throw std::runtime_error( throw std::runtime_error(
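
The per-matrix body above is the standard LAPACK two-step inversion: getrf for the in-place LU factorization, then getri to form the inverse from the factors. A standalone sketch with hand-declared prototypes (link against a LAPACK such as Accelerate; a symmetric input is used so row- versus column-major storage does not matter):

#include <cstdio>

extern "C" {
void sgetrf_(int* m, int* n, float* a, int* lda, int* ipiv, int* info);
void sgetri_(int* n, float* a, int* lda, int* ipiv, float* work, int* lwork, int* info);
}

int main() {
  int n = 2, lda = 2, info = 0, lwork = 8;
  float a[4] = {4.f, 1.f, 1.f, 3.f}; // symmetric 2x2
  int ipiv[2];
  float work[8];
  sgetrf_(&n, &n, a, &lda, ipiv, &info);             // LU factorization in place
  if (info == 0) {
    sgetri_(&n, a, &lda, ipiv, work, &lwork, &info); // inverse from the LU factors
  }
  std::printf("info=%d\n%g %g\n%g %g\n", info, a[0], a[1], a[2], a[3]);
  // Expected (up to rounding): 0.2727 -0.0909 / -0.0909 0.3636
}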

View File

@@ -0,0 +1,140 @@
// Copyright © 2023-2024 Apple Inc.
#include <cassert>
#include <cmath>
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/simd/simd.h"
#include "mlx/primitives.h"
#include "mlx/types/limits.h"
namespace mlx::core {
namespace {
using namespace mlx::core::simd;
template <typename T, typename AccT>
void logsumexp(const array& in, array& out, Stream stream) {
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(in);
encoder.set_output_array(out);
const T* in_ptr = in.data<T>();
T* out_ptr = out.data<T>();
int M = in.shape().back();
int L = in.data_size() / M;
encoder.dispatch([in_ptr, out_ptr, M, L]() mutable {
constexpr int N = std::min(max_size<AccT>, max_size<T>);
const T* current_in_ptr;
for (int i = 0; i < L; i++, in_ptr += M, out_ptr += 1) {
// Find the maximum
current_in_ptr = in_ptr;
Simd<AccT, N> vmaximum(-numeric_limits<AccT>::infinity());
size_t s = M;
while (s >= N) {
Simd<AccT, N> vals = load<T, N>(current_in_ptr);
vmaximum = maximum(vals, vmaximum);
current_in_ptr += N;
s -= N;
}
AccT maximum = max(vmaximum);
while (s-- > 0) {
maximum = std::max(maximum, static_cast<AccT>(*current_in_ptr));
current_in_ptr++;
}
// Compute the normalizer and the exponentials
Simd<AccT, N> vnormalizer(0.0);
current_in_ptr = in_ptr;
s = M;
while (s >= N) {
Simd<AccT, N> vexp = load<T, N>(current_in_ptr);
vexp = exp(vexp - maximum);
vnormalizer = vnormalizer + vexp;
current_in_ptr += N;
s -= N;
}
AccT normalizer = sum(vnormalizer);
while (s-- > 0) {
AccT _exp = std::exp(*current_in_ptr - maximum);
normalizer += _exp;
current_in_ptr++;
}
// Normalize
*out_ptr = std::isinf(maximum)
? static_cast<T>(maximum)
: static_cast<T>(std::log(normalizer) + maximum);
}
});
}
} // namespace
void LogSumExp::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
// Make sure that the last dimension is contiguous
auto s = stream();
auto& encoder = cpu::get_command_encoder(s);
auto ensure_contiguous = [&s, &encoder](const array& x) {
if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
return x;
} else {
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
copy(x, x_copy, CopyType::General, s);
encoder.add_temporary(x_copy);
return x_copy;
}
};
auto in = ensure_contiguous(inputs[0]);
if (in.flags().row_contiguous) {
out.set_data(allocator::malloc(out.nbytes()));
} else {
auto n = in.shape(-1);
auto flags = in.flags();
auto strides = in.strides();
for (auto& s : strides) {
s /= n;
}
bool col_contig = strides[0] == 1;
for (int i = 1; col_contig && i < strides.size(); ++i) {
col_contig &=
(out.shape(i) == 1 || strides[i - 1] == out.shape(i) * strides[i]);
}
flags.col_contiguous = col_contig;
out.set_data(
allocator::malloc(in.nbytes() / n),
in.data_size() / n,
std::move(strides),
flags);
}
switch (in.dtype()) {
case float32:
logsumexp<float, float>(in, out, stream());
break;
case float16:
logsumexp<float16_t, float>(in, out, stream());
break;
case bfloat16:
logsumexp<bfloat16_t, float>(in, out, stream());
break;
case float64:
logsumexp<double, double>(in, out, stream());
break;
default:
throw std::runtime_error(
"[logsumexp] only supports floating point types");
break;
}
}
} // namespace mlx::core
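
The kernel above is the usual two-pass, numerically stable reduction: find the row maximum, sum exp(x - max), then add the maximum back after the log. A scalar reference of the same computation:

#include <algorithm>
#include <cmath>
#include <cstdio>

float logsumexp_row(const float* x, int m) {
  float mx = -INFINITY;
  for (int i = 0; i < m; ++i) {
    mx = std::max(mx, x[i]);
  }
  float norm = 0.f;
  for (int i = 0; i < m; ++i) {
    norm += std::exp(x[i] - mx); // never overflows: exponent is <= 0
  }
  // If the max is infinite the result is that infinity, matching the kernel.
  return std::isinf(mx) ? mx : mx + std::log(norm);
}

int main() {
  float x[3] = {1000.f, 1000.f, 1000.f};
  std::printf("%g\n", logsumexp_row(x, 3)); // ~1001.0986, no overflow
}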

View File

@@ -4,15 +4,22 @@
#include "mlx/allocator.h" #include "mlx/allocator.h"
#include "mlx/backend/cpu/copy.h" #include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/lapack.h" #include "mlx/backend/cpu/lapack.h"
#include "mlx/primitives.h" #include "mlx/primitives.h"
namespace mlx::core { namespace mlx::core {
template <typename T> template <typename T>
void luf_impl(const array& a, array& lu, array& pivots, array& row_indices) { void luf_impl(
const array& a,
array& lu,
array& pivots,
array& row_indices,
Stream stream) {
int M = a.shape(-2); int M = a.shape(-2);
int N = a.shape(-1); int N = a.shape(-1);
int K = std::min(M, N);
// Copy a into lu and make it col contiguous // Copy a into lu and make it col contiguous
auto ndim = lu.ndim(); auto ndim = lu.ndim();
@@ -23,60 +30,74 @@ void luf_impl(const array& a, array& lu, array& pivots, array& row_indices) {
auto strides = lu.strides(); auto strides = lu.strides();
strides[ndim - 1] = M; strides[ndim - 1] = M;
strides[ndim - 2] = 1; strides[ndim - 2] = 1;
lu.set_data( lu.set_data(allocator::malloc(lu.nbytes()), lu.nbytes(), strides, flags);
allocator::malloc_or_wait(lu.nbytes()), lu.nbytes(), strides, flags);
copy_inplace( copy_inplace(
a, lu, a.shape(), a.strides(), strides, 0, 0, CopyType::GeneralGeneral); a,
lu,
a.shape(),
a.strides(),
strides,
0,
0,
CopyType::GeneralGeneral,
stream);
auto a_ptr = lu.data<T>(); auto a_ptr = lu.data<T>();
pivots.set_data(allocator::malloc(pivots.nbytes()));
pivots.set_data(allocator::malloc_or_wait(pivots.nbytes())); row_indices.set_data(allocator::malloc(row_indices.nbytes()));
row_indices.set_data(allocator::malloc_or_wait(row_indices.nbytes()));
auto pivots_ptr = pivots.data<uint32_t>(); auto pivots_ptr = pivots.data<uint32_t>();
auto row_indices_ptr = row_indices.data<uint32_t>(); auto row_indices_ptr = row_indices.data<uint32_t>();
int info;
size_t num_matrices = a.size() / (M * N); size_t num_matrices = a.size() / (M * N);
for (size_t i = 0; i < num_matrices; ++i) { auto& encoder = cpu::get_command_encoder(stream);
// Compute LU factorization of A encoder.set_input_array(a);
getrf<T>( encoder.set_output_array(lu);
/* m */ &M, encoder.set_output_array(pivots);
/* n */ &N, encoder.set_output_array(row_indices);
/* a */ a_ptr,
/* lda */ &M,
/* ipiv */ reinterpret_cast<int*>(pivots_ptr),
/* info */ &info);
if (info != 0) { encoder.dispatch(
std::stringstream ss; [a_ptr, pivots_ptr, row_indices_ptr, num_matrices, M, N, K]() mutable {
ss << "[LUF::eval_cpu] sgetrf_ failed with code " << info int info;
<< ((info > 0) ? " because matrix is singular" for (size_t i = 0; i < num_matrices; ++i) {
: " because argument had an illegal value"); // Compute LU factorization of A
throw std::runtime_error(ss.str()); getrf<T>(
} /* m */ &M,
/* n */ &N,
/* a */ a_ptr,
/* lda */ &M,
/* ipiv */ reinterpret_cast<int*>(pivots_ptr),
/* info */ &info);
// Subtract 1 to get 0-based index if (info != 0) {
int j = 0; std::stringstream ss;
for (; j < pivots.shape(-1); ++j) { ss << "[LUF::eval_cpu] sgetrf_ failed with code " << info
pivots_ptr[j]--; << ((info > 0) ? " because matrix is singular"
row_indices_ptr[j] = j; : " because argument had an illegal value");
} throw std::runtime_error(ss.str());
for (; j < row_indices.shape(-1); ++j) { }
row_indices_ptr[j] = j;
}
for (int j = pivots.shape(-1) - 1; j >= 0; --j) {
auto piv = pivots_ptr[j];
auto t1 = row_indices_ptr[piv];
auto t2 = row_indices_ptr[j];
row_indices_ptr[j] = t1;
row_indices_ptr[piv] = t2;
}
// Advance pointers to the next matrix // Subtract 1 to get 0-based index
a_ptr += M * N; int j = 0;
pivots_ptr += pivots.shape(-1); for (; j < K; ++j) {
row_indices_ptr += pivots.shape(-1); pivots_ptr[j]--;
} row_indices_ptr[j] = j;
}
for (; j < M; ++j) {
row_indices_ptr[j] = j;
}
for (int j = K - 1; j >= 0; --j) {
auto piv = pivots_ptr[j];
auto t1 = row_indices_ptr[piv];
auto t2 = row_indices_ptr[j];
row_indices_ptr[j] = t1;
row_indices_ptr[piv] = t2;
}
// Advance pointers to the next matrix
a_ptr += M * N;
pivots_ptr += K;
row_indices_ptr += M;
}
});
} }
void LUF::eval_cpu( void LUF::eval_cpu(
@@ -85,10 +106,10 @@ void LUF::eval_cpu(
assert(inputs.size() == 1); assert(inputs.size() == 1);
switch (inputs[0].dtype()) { switch (inputs[0].dtype()) {
case float32: case float32:
luf_impl<float>(inputs[0], outputs[0], outputs[1], outputs[2]); luf_impl<float>(inputs[0], outputs[0], outputs[1], outputs[2], stream());
break; break;
case float64: case float64:
luf_impl<double>(inputs[0], outputs[0], outputs[1], outputs[2]); luf_impl<double>(inputs[0], outputs[0], outputs[1], outputs[2], stream());
break; break;
default: default:
throw std::runtime_error( throw std::runtime_error(
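
The pivot post-processing above converts LAPACK's 1-based swap vector into an explicit 0-based row permutation by replaying the swaps over an identity index vector from the last pivot to the first. The same logic in isolation:

#include <cstdio>
#include <utility>
#include <vector>

std::vector<int> pivots_to_permutation(std::vector<int> ipiv, int num_rows) {
  for (auto& p : ipiv) {
    p -= 1; // LAPACK pivots are 1-based
  }
  std::vector<int> perm(num_rows);
  for (int j = 0; j < num_rows; ++j) {
    perm[j] = j; // identity
  }
  for (int j = static_cast<int>(ipiv.size()) - 1; j >= 0; --j) {
    std::swap(perm[j], perm[ipiv[j]]); // replay the row interchanges
  }
  return perm;
}

int main() {
  // ipiv = {2, 2} (1-based): row 1 was swapped with row 2, row 2 with itself.
  auto perm = pivots_to_permutation({2, 2}, 3);
  std::printf("%d %d %d\n", perm[0], perm[1], perm[2]); // 1 0 2
}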

View File

@@ -5,6 +5,7 @@
#include "mlx/array.h" #include "mlx/array.h"
#include "mlx/backend/common/utils.h" #include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/copy.h" #include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/lapack.h" #include "mlx/backend/cpu/lapack.h"
#include "mlx/primitives.h" #include "mlx/primitives.h"
@@ -58,42 +59,42 @@ void BlockMaskedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
throw std::runtime_error( throw std::runtime_error(
"[BlockMaskedMM::eval] Currently only supports float32."); "[BlockMaskedMM::eval] Currently only supports float32.");
} }
out.set_data(allocator::malloc_or_wait(out.nbytes())); out.set_data(allocator::malloc(out.nbytes()));
auto& a_pre = inputs[0]; auto& a_pre = inputs[0];
auto& b_pre = inputs[1]; auto& b_pre = inputs[1];
auto check_transpose = auto check_transpose =
[](const array& arr, bool do_copy, bool expand_all = false) { [s = stream()](const array& arr, bool do_copy, bool expand_all = false) {
auto stx = arr.strides()[arr.ndim() - 2]; auto stx = arr.strides()[arr.ndim() - 2];
auto sty = arr.strides()[arr.ndim() - 1]; auto sty = arr.strides()[arr.ndim() - 1];
if (!expand_all && stx == arr.shape(-1) && sty == 1) { if (!expand_all && stx == arr.shape(-1) && sty == 1) {
if (do_copy) { if (do_copy) {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {}); array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::Vector); copy(arr, arr_copy, CopyType::Vector, s);
return std::make_tuple(false, stx, arr_copy); return std::make_tuple(false, stx, arr_copy, true);
} }
return std::make_tuple(false, stx, arr); return std::make_tuple(false, stx, arr, false);
} else if (!expand_all && stx == 1 && sty == arr.shape(-2)) { } else if (!expand_all && stx == 1 && sty == arr.shape(-2)) {
if (do_copy) { if (do_copy) {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {}); array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::Vector); copy(arr, arr_copy, CopyType::Vector, s);
return std::make_tuple(true, sty, arr_copy); return std::make_tuple(true, sty, arr_copy, true);
} }
return std::make_tuple(true, sty, arr); return std::make_tuple(true, sty, arr, false);
} else { } else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {}); array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::General); copy(arr, arr_copy, CopyType::General, s);
int64_t stx = arr.shape(-1); int64_t stx = arr.shape(-1);
return std::make_tuple(false, stx, arr_copy); return std::make_tuple(false, stx, arr_copy, true);
} }
}; };
bool has_op_mask = inputs.size() > 3; bool has_op_mask = inputs.size() > 3;
bool has_out_mask = inputs.size() == 3 || inputs.size() == 5; bool has_out_mask = inputs.size() == 3 || inputs.size() == 5;
auto [a_transposed, lda, a] = auto [a_transposed, lda, a, a_copied] =
check_transpose(a_pre, has_op_mask, inputs.back().dtype() != bool_); check_transpose(a_pre, has_op_mask, inputs.back().dtype() != bool_);
auto [b_transposed, ldb, b] = auto [b_transposed, ldb, b, b_copied] =
check_transpose(b_pre, has_op_mask, inputs.back().dtype() != bool_); check_transpose(b_pre, has_op_mask, inputs.back().dtype() != bool_);
size_t M = a.shape(-2); size_t M = a.shape(-2);
@@ -104,31 +105,39 @@ void BlockMaskedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
return; return;
} }
auto& encoder = cpu::get_command_encoder(stream());
if (K == 0) { if (K == 0) {
std::memset(static_cast<void*>(out.data<float>()), 0, out.nbytes()); encoder.set_output_array(out);
encoder.dispatch([out_ptr = out.data<void>(), nbytes = out.nbytes()]() {
std::memset(out_ptr, 0, nbytes);
});
return; return;
} }
auto mask_array = [](const array& mask, auto mask_array = [](const void* mask,
float* data, float* data,
int block_size, int block_size,
int batch_idx, int batch_idx,
int X, int X,
int Y, int Y,
size_t X_data_str, size_t X_data_str,
size_t Y_data_str) { size_t Y_data_str,
const Shape& mask_shape,
const Strides& mask_strides,
bool is_bool) {
auto ndim = mask_shape.size();
auto mask_offset = elem_to_loc( auto mask_offset = elem_to_loc(
mask.shape(-1) * mask.shape(-2) * batch_idx, mask_shape[ndim - 1] * mask_shape[ndim - 2] * batch_idx,
mask.shape(), mask_shape,
mask.strides()); mask_strides);
auto X_mask_str = mask.strides()[mask.ndim() - 2]; auto X_mask_str = mask_strides[ndim - 2];
auto Y_mask_str = mask.strides()[mask.ndim() - 1]; auto Y_mask_str = mask_strides[ndim - 1];
if (mask.dtype() == bool_) { if (is_bool) {
return mask_matrix( return mask_matrix(
data, data,
mask.data<bool>(), static_cast<const bool*>(mask),
block_size, block_size,
X, X,
Y, Y,
@@ -140,7 +149,7 @@ void BlockMaskedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
} else { } else {
return mask_matrix( return mask_matrix(
data, data,
mask.data<float>(), static_cast<const float*>(mask),
block_size, block_size,
X, X,
Y, Y,
@@ -152,61 +161,155 @@ void BlockMaskedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
} }
}; };
for (int i = 0; i < (out.size() / (M * size_t(N))); ++i) { encoder.set_input_array(a);
// Adjust pointer encoder.set_input_array(b);
float* ai = const void* a_mask_ptr;
a.data<float>() + elem_to_loc(M * K * i, a.shape(), a.strides()); const void* b_mask_ptr;
float* bi = const void* out_mask_ptr;
b.data<float>() + elem_to_loc(K * N * i, b.shape(), b.strides()); Shape a_mask_shape;
float* ci = out.data<float>() + M * N * i; Shape b_mask_shape;
Shape out_mask_shape;
Strides a_mask_strides;
Strides b_mask_strides;
Strides out_mask_strides;
bool a_mask_bool;
bool b_mask_bool;
bool out_mask_bool;
if (has_op_mask) {
auto& a_mask = inputs[inputs.size() - 2];
auto& b_mask = inputs[inputs.size() - 1];
a_mask_ptr = a_mask.data<void>();
b_mask_ptr = b_mask.data<void>();
a_mask_shape = a_mask.shape();
b_mask_shape = b_mask.shape();
a_mask_strides = a_mask.strides();
b_mask_strides = b_mask.strides();
a_mask_bool = (a_mask.dtype() == bool_);
b_mask_bool = (b_mask.dtype() == bool_);
encoder.set_input_array(a_mask);
encoder.set_input_array(b_mask);
}
if (has_out_mask) {
auto& out_mask = inputs[2];
out_mask_ptr = out_mask.data<void>();
out_mask_bool = (out_mask.dtype() == bool_);
encoder.set_input_array(out_mask);
out_mask_shape = out_mask.shape();
out_mask_strides = out_mask.strides();
}
encoder.set_output_array(out);
auto a_ptr = a.data<float>();
auto b_ptr = b.data<float>();
auto out_ptr = out.data<float>();
size_t num_matrices = out.size() / (M * size_t(N));
auto ldc = out.shape(-1);
// Zero out blocks in a and b if needed encoder.dispatch([a_ptr,
if (has_op_mask) { b_ptr,
auto& a_mask = inputs[inputs.size() - 2]; out_ptr,
mask_array( a_mask_ptr,
a_mask, b_mask_ptr,
ai, out_mask_ptr,
block_size_, has_op_mask,
i, has_out_mask,
block_size = block_size_,
num_matrices,
M,
N,
K,
a_transposed = a_transposed,
b_transposed = b_transposed,
lda = lda,
ldb = ldb,
ldc,
a_shape = a.shape(),
a_strides = a.strides(),
b_shape = b.shape(),
b_strides = b.strides(),
a_mask_shape = std::move(a_mask_shape),
b_mask_shape = std::move(b_mask_shape),
out_mask_shape = std::move(out_mask_shape),
a_mask_strides = std::move(a_mask_strides),
b_mask_strides = std::move(b_mask_strides),
out_mask_strides = std::move(out_mask_strides),
mask_array,
a_mask_bool,
b_mask_bool,
out_mask_bool]() {
for (int i = 0; i < num_matrices; ++i) {
// Adjust pointer
float* ai = a_ptr + elem_to_loc(M * K * i, a_shape, a_strides);
float* bi = b_ptr + elem_to_loc(K * N * i, b_shape, b_strides);
float* ci = out_ptr + M * N * i;
// Zero out blocks in a and b if needed
if (has_op_mask) {
mask_array(
a_mask_ptr,
ai,
block_size,
i,
M,
K,
a_transposed ? 1 : lda,
a_transposed ? lda : 1,
a_mask_shape,
a_mask_strides,
a_mask_bool);
mask_array(
b_mask_ptr,
bi,
block_size,
i,
K,
N,
b_transposed ? 1 : ldb,
b_transposed ? ldb : 1,
b_mask_shape,
b_mask_strides,
b_mask_bool);
}
// Do matmul
cblas_sgemm(
CblasRowMajor,
a_transposed ? CblasTrans : CblasNoTrans, // transA
b_transposed ? CblasTrans : CblasNoTrans, // transB
M, M,
K,
a_transposed ? 1 : lda,
a_transposed ? lda : 1);
auto& b_mask = inputs[inputs.size() - 1];
mask_array(
b_mask,
bi,
block_size_,
i,
K,
N, N,
b_transposed ? 1 : ldb, K,
b_transposed ? ldb : 1); 1.0, // alpha
} ai,
lda,
bi,
ldb,
0.0, // beta
ci,
ldc);
// Do matmul // Zero out blocks in out
cblas_sgemm( if (has_out_mask) {
CblasRowMajor, mask_array(
a_transposed ? CblasTrans : CblasNoTrans, // transA out_mask_ptr,
b_transposed ? CblasTrans : CblasNoTrans, // transB ci,
M, block_size,
N, i,
K, M,
1.0, // alpha N,
ai, N,
lda, 1,
bi, out_mask_shape,
ldb, out_mask_strides,
0.0, // beta out_mask_bool);
ci, }
out.shape(-1) // ldc
);
// Zero out blocks in out
if (has_out_mask) {
mask_array(inputs[2], ci, block_size_, i, M, N, N, 1);
} }
});
if (a_copied) {
encoder.add_temporary(a);
}
if (b_copied) {
encoder.add_temporary(b);
} }
} }
@@ -215,12 +318,13 @@ void GatherMM::eval_cpu(const std::vector<array>& inputs, array& out) {
throw std::runtime_error( throw std::runtime_error(
"[GatherMM::eval] Currently only supports float32."); "[GatherMM::eval] Currently only supports float32.");
} }
out.set_data(allocator::malloc_or_wait(out.nbytes())); out.set_data(allocator::malloc(out.nbytes()));
auto& a_pre = inputs[0]; auto& a_pre = inputs[0];
auto& b_pre = inputs[1]; auto& b_pre = inputs[1];
auto check_transpose = [](const array& arr) { std::vector<array> temps;
auto check_transpose = [s = stream(), &temps](const array& arr) {
auto stx = arr.strides()[arr.ndim() - 2]; auto stx = arr.strides()[arr.ndim() - 2];
auto sty = arr.strides()[arr.ndim() - 1]; auto sty = arr.strides()[arr.ndim() - 1];
if (stx == arr.shape(-1) && sty == 1) { if (stx == arr.shape(-1) && sty == 1) {
@@ -228,10 +332,10 @@ void GatherMM::eval_cpu(const std::vector<array>& inputs, array& out) {
} else if (stx == 1 && sty == arr.shape(-2)) { } else if (stx == 1 && sty == arr.shape(-2)) {
return std::make_tuple(true, sty, arr); return std::make_tuple(true, sty, arr);
} else { } else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {}); temps.push_back(array(arr.shape(), arr.dtype(), nullptr, {}));
copy(arr, arr_copy, CopyType::General); copy(arr, temps.back(), CopyType::General, s);
int64_t stx = arr.shape(-1); int64_t stx = arr.shape(-1);
return std::make_tuple(false, stx, arr_copy); return std::make_tuple(false, stx, temps.back());
} }
}; };
@@ -246,8 +350,12 @@ void GatherMM::eval_cpu(const std::vector<array>& inputs, array& out) {
return; return;
} }
auto& encoder = cpu::get_command_encoder(stream());
if (K == 0) { if (K == 0) {
std::memset(static_cast<void*>(out.data<float>()), 0, out.nbytes()); encoder.set_output_array(out);
encoder.dispatch([out_ptr = out.data<float>(), size = out.size()]() {
std::fill_n(out_ptr, size, 0);
});
return; return;
} }
@@ -272,29 +380,61 @@ void GatherMM::eval_cpu(const std::vector<array>& inputs, array& out) {
const uint32_t* lhs_indices_ptr = lhs_indices.data<uint32_t>(); const uint32_t* lhs_indices_ptr = lhs_indices.data<uint32_t>();
const uint32_t* rhs_indices_ptr = rhs_indices.data<uint32_t>(); const uint32_t* rhs_indices_ptr = rhs_indices.data<uint32_t>();
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_input_array(lhs_indices);
encoder.set_input_array(rhs_indices);
encoder.set_output_array(out);
auto ldc = out.shape(-1);
  encoder.dispatch([a_ptr = a.data<float>(),
                    b_ptr = b.data<float>(),
                    out_ptr = out.data<float>(),
                    M,
                    N,
                    K,
                    lda = lda,
                    ldb = ldb,
                    a_transposed = a_transposed,
                    b_transposed = b_transposed,
                    ldc,
                    lhs_indices_ptr,
                    rhs_indices_ptr,
                    lhs_indices_shape = lhs_indices.shape(),
                    lhs_indices_strides = lhs_indices.strides(),
                    rhs_indices_shape = rhs_indices.shape(),
                    rhs_indices_strides = rhs_indices.strides(),
                    batch_size_out,
                    matrix_stride_out,
                    batch_shape_A = std::move(batch_shape_A),
                    batch_shape_B = std::move(batch_shape_B),
                    batch_strides_A = std::move(batch_strides_A),
                    batch_strides_B = std::move(batch_strides_B)]() {
    for (int i = 0; i < batch_size_out; i++) {
      // Get index
      uint32_t indx_A = lhs_indices_ptr[elem_to_loc(
          i, lhs_indices_shape, lhs_indices_strides)];
      uint32_t indx_B = rhs_indices_ptr[elem_to_loc(
          i, rhs_indices_shape, rhs_indices_strides)];
      cblas_sgemm(
          CblasRowMajor,
          a_transposed ? CblasTrans : CblasNoTrans, // transA
          b_transposed ? CblasTrans : CblasNoTrans, // transB
          M,
          N,
          K,
          1.0f, // alpha
          a_ptr + elem_to_loc(indx_A, batch_shape_A, batch_strides_A),
          lda,
          b_ptr + elem_to_loc(indx_B, batch_shape_B, batch_strides_B),
          ldb,
          0.0f, // beta
          out_ptr + matrix_stride_out * i,
          ldc);
    }
  });
encoder.add_temporaries(std::move(temps));
} }
} // namespace mlx::core } // namespace mlx::core
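The hunks above move the BLAS calls out of eval and into a lambda queued on the stream's CPU command encoder, capturing raw data pointers plus value-copied shapes and strides so the queued task stays valid after eval returns. Below is a minimal standalone sketch of that capture-and-defer pattern; ToyEncoder, enqueue_matmul, and the naive triple loop are illustrative stand-ins, not MLX's encoder or GEMM.

    // Standalone sketch, not MLX's classes: a queue of tasks captured by value
    // standing in for the per-stream CPU command encoder used above.
    #include <functional>
    #include <queue>
    #include <vector>

    struct ToyEncoder {
      std::queue<std::function<void()>> tasks;
      void dispatch(std::function<void()> f) {
        tasks.push(std::move(f));
      }
      void run_all() {
        while (!tasks.empty()) {
          tasks.front()();
          tasks.pop();
        }
      }
    };

    // Enqueue C = A * B (row-major, M x K times K x N). Pointers and sizes are
    // captured by value so the task remains valid after the caller returns.
    void enqueue_matmul(
        ToyEncoder& enc, const float* a, const float* b, float* c, int M, int N, int K) {
      enc.dispatch([a, b, c, M, N, K]() {
        for (int i = 0; i < M; ++i) {
          for (int j = 0; j < N; ++j) {
            float acc = 0.0f;
            for (int k = 0; k < K; ++k) {
              acc += a[i * K + k] * b[k * N + j];
            }
            c[i * N + j] = acc;
          }
        }
      });
    }

    int main() {
      std::vector<float> a = {1, 2, 3, 4}; // 2x2
      std::vector<float> b = {5, 6, 7, 8}; // 2x2
      std::vector<float> c(4, 0.0f);
      ToyEncoder enc;
      enqueue_matmul(enc, a.data(), b.data(), c.data(), 2, 2, 2);
      enc.run_all(); // c == {19, 22, 43, 50}
      return 0;
    }

Nothing in the capture list refers to the caller's stack by reference, which is the property the diff relies on when it copies shapes, strides, and index metadata into the dispatched lambdas.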

View File

@@ -3,18 +3,76 @@
#include <cstring> #include <cstring>
#include "mlx/array.h" #include "mlx/array.h"
#include "mlx/backend/cpu/copy.h" #include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/gemm.h" #include "mlx/backend/cpu/gemm.h"
#include "mlx/primitives.h" #include "mlx/primitives.h"
namespace mlx::core { namespace mlx::core {
template <typename T>
void matmul_dispatch(
const array& a,
const array& b,
array& out,
bool a_transposed,
bool b_transposed,
size_t lda,
size_t ldb,
float alpha,
float beta,
Stream stream) {
const T* a_ptr = a.data<T>();
const T* b_ptr = b.data<T>();
T* out_ptr = out.data<T>();
size_t ldc = out.shape(-1);
size_t batch_size = a.size() / (a.shape(-2) * a.shape(-1));
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
encoder.dispatch([a_ptr,
b_ptr,
out_ptr,
a_transposed,
b_transposed,
lda,
ldb,
ldc,
alpha,
beta,
batch_size,
a_shape = a.shape(),
a_strides = a.strides(),
b_shape = b.shape(),
b_strides = b.strides()]() {
matmul<T>(
a_ptr,
b_ptr,
out_ptr,
a_transposed,
b_transposed,
lda,
ldb,
ldc,
alpha,
beta,
batch_size,
a_shape,
a_strides,
b_shape,
b_strides);
});
}
void matmul_general( void matmul_general(
const array& a_pre, const array& a_pre,
const array& b_pre, const array& b_pre,
array& out, array& out,
Stream stream,
float alpha = 1.0f, float alpha = 1.0f,
float beta = 0.0f) { float beta = 0.0f) {
auto check_transpose = [](const array& arr) { std::vector<array> temps;
auto check_transpose = [stream, &temps](const array& arr) {
auto stx = arr.strides()[arr.ndim() - 2]; auto stx = arr.strides()[arr.ndim() - 2];
auto sty = arr.strides()[arr.ndim() - 1]; auto sty = arr.strides()[arr.ndim() - 1];
if (stx == arr.shape(-1) && sty == 1) { if (stx == arr.shape(-1) && sty == 1) {
@@ -22,10 +80,10 @@ void matmul_general(
} else if (stx == 1 && sty == arr.shape(-2)) { } else if (stx == 1 && sty == arr.shape(-2)) {
return std::make_tuple(true, sty, arr); return std::make_tuple(true, sty, arr);
} else { } else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {}); temps.push_back(array(arr.shape(), arr.dtype(), nullptr, {}));
copy(arr, arr_copy, CopyType::General); copy(arr, temps.back(), CopyType::General, stream);
stx = arr.shape(-1); stx = arr.shape(-1);
return std::make_tuple(false, stx, arr_copy); return std::make_tuple(false, stx, temps.back());
} }
}; };
@@ -39,28 +97,34 @@ void matmul_general(
} }
if (out.dtype() == float32) { if (out.dtype() == float32) {
matmul<float>(a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta); matmul_dispatch<float>(
a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta, stream);
} else if (out.dtype() == float16) { } else if (out.dtype() == float16) {
matmul<float16_t>( matmul_dispatch<float16_t>(
a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta); a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta, stream);
} else if (out.dtype() == bfloat16) { } else if (out.dtype() == bfloat16) {
matmul<bfloat16_t>( matmul_dispatch<bfloat16_t>(
a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta); a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta, stream);
} else if (out.dtype() == float64) { } else if (out.dtype() == float64) {
matmul<double>( matmul_dispatch<double>(
a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta); a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta, stream);
} else { } else {
throw std::runtime_error("[Matmul::eval_cpu] Invalid type."); throw std::runtime_error("[Matmul::eval_cpu] Invalid type.");
} }
cpu::get_command_encoder(stream).add_temporaries(std::move(temps));
} }
void Matmul::eval_cpu(const std::vector<array>& inputs, array& out) { void Matmul::eval_cpu(const std::vector<array>& inputs, array& out) {
out.set_data(allocator::malloc_or_wait(out.nbytes())); out.set_data(allocator::malloc(out.nbytes()));
if (inputs[0].shape(-1) == 0) { if (inputs[0].shape(-1) == 0) {
std::memset(out.data<void>(), 0, out.nbytes()); auto& encoder = cpu::get_command_encoder(stream());
encoder.set_output_array(out);
encoder.dispatch([out_ptr = out.data<void>(), nbytes = out.nbytes()]() {
std::memset(out_ptr, 0, nbytes);
});
return; return;
} }
return matmul_general(inputs[0], inputs[1], out); matmul_general(inputs[0], inputs[1], out, stream());
} }
void AddMM::eval_cpu(const std::vector<array>& inputs, array& out) { void AddMM::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -74,9 +138,9 @@ void AddMM::eval_cpu(const std::vector<array>& inputs, array& out) {
CopyType ctype = c.data_size() == 1 CopyType ctype = c.data_size() == 1
? CopyType::Scalar ? CopyType::Scalar
: (c.flags().row_contiguous ? CopyType::Vector : CopyType::General); : (c.flags().row_contiguous ? CopyType::Vector : CopyType::General);
copy(c, out, ctype); copy(c, out, ctype, stream());
return matmul_general(inputs[0], inputs[1], out, alpha_, beta_); matmul_general(inputs[0], inputs[1], out, stream(), alpha_, beta_);
} }
} // namespace mlx::core } // namespace mlx::core
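check_transpose above decides, from the last two strides, whether a matrix can be handed to the GEMM directly (row-contiguous), handed over as transposed (column-contiguous), or must first be copied to a contiguous temporary. A small standalone sketch of that classification, using illustrative names (MatInfo, classify) rather than MLX's:

    // Standalone sketch of the stride check in check_transpose above;
    // "needs_copy" stands in for the temporary-array path taken in the diff.
    #include <cassert>
    #include <cstdint>

    struct MatInfo {
      bool transposed;  // pass the matrix as transposed to the GEMM
      int64_t ld;       // leading dimension
      bool needs_copy;  // neither row- nor column-contiguous
    };

    // rows x cols matrix whose strides (in elements) are stride_row, stride_col.
    MatInfo classify(int64_t rows, int64_t cols, int64_t stride_row, int64_t stride_col) {
      if (stride_row == cols && stride_col == 1) {
        return {false, cols, false};      // row-major: ld = number of columns
      } else if (stride_row == 1 && stride_col == rows) {
        return {true, stride_col, false}; // column-major: treat as transposed
      }
      return {false, cols, true};         // arbitrary strides: copy first
    }

    int main() {
      auto a = classify(3, 4, 4, 1); // contiguous row-major
      assert(!a.transposed && a.ld == 4 && !a.needs_copy);
      auto b = classify(3, 4, 1, 3); // column-major view (a transpose)
      assert(b.transposed && b.ld == 3 && !b.needs_copy);
      auto c = classify(3, 4, 8, 2); // strided slice: must be copied
      assert(c.needs_copy);
      return 0;
    }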

View File

@@ -7,11 +7,11 @@
#include <sstream> #include <sstream>
#include "mlx/allocator.h" #include "mlx/allocator.h"
#include "mlx/backend/common/load.h"
#include "mlx/backend/common/slicing.h" #include "mlx/backend/common/slicing.h"
#include "mlx/backend/common/utils.h" #include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/arange.h" #include "mlx/backend/cpu/arange.h"
#include "mlx/backend/cpu/copy.h" #include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/threefry.h" #include "mlx/backend/cpu/threefry.h"
#include "mlx/primitives.h" #include "mlx/primitives.h"
#include "mlx/utils.h" #include "mlx/utils.h"
@@ -21,40 +21,59 @@ namespace mlx::core {
void reshape(const array& in, array& out) { void reshape(const array& in, array& out) {
auto [copy_necessary, out_strides] = prepare_reshape(in, out); auto [copy_necessary, out_strides] = prepare_reshape(in, out);
if (copy_necessary) { if (copy_necessary) {
out.set_data(allocator::malloc_or_wait(out.nbytes())); out.set_data(allocator::malloc(out.nbytes()));
copy_inplace(in, out, CopyType::General); copy_inplace(in, out, CopyType::General, out.primitive().stream());
} else { } else {
shared_buffer_reshape(in, out_strides, out); shared_buffer_reshape(in, out_strides, out);
} }
} }
static std::pair<array, bool> compute_dynamic_offset(
    const array& indices,
    const Strides& strides,
    const std::vector<int>& axes,
    Stream stream) {
  array offset({1}, int64, nullptr, {});
  bool donate = indices.is_donatable() &&
      (indices.data_size() * indices.itemsize()) >= offset.itemsize();
  if (donate) {
    offset.copy_shared_buffer(indices);
  } else {
    offset.set_data(allocator::malloc(offset.itemsize()));
  }
  auto& encoder = cpu::get_command_encoder(stream);
  encoder.set_input_array(indices);
  encoder.set_output_array(offset);
  auto compute_offset =
      [strides, axes, offset = offset.data<int64_t>()](const auto* indices) {
        int64_t offset_ = 0;
        for (int i = 0; i < axes.size(); ++i) {
          offset_ += indices[i] * strides[axes[i]];
        }
        offset[0] = offset_;
      };
switch (indices.dtype()) { switch (indices.dtype()) {
case int8: case int8:
case uint8: case uint8:
return compute_offset(indices.data<uint8_t>()); encoder.dispatch(compute_offset, indices.data<uint8_t>());
break;
case int16: case int16:
case uint16: case uint16:
return compute_offset(indices.data<uint16_t>()); encoder.dispatch(compute_offset, indices.data<uint16_t>());
break;
case int32: case int32:
case uint32: case uint32:
return compute_offset(indices.data<uint32_t>()); encoder.dispatch(compute_offset, indices.data<uint32_t>());
break;
case int64: case int64:
case uint64: case uint64:
return compute_offset(indices.data<uint64_t>()); encoder.dispatch(compute_offset, indices.data<uint64_t>());
break;
default: default:
throw std::runtime_error("Invalid indices type."); throw std::runtime_error("Invalid indices type.");
} }
return {offset, donate};
} }
void AsStrided::eval_cpu(const std::vector<array>& inputs, array& out) { void AsStrided::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -104,14 +123,59 @@ void Transpose::eval_cpu(const std::vector<array>& inputs, array& out) {
} }
void Arange::eval_cpu(const std::vector<array>& inputs, array& out) { void Arange::eval_cpu(const std::vector<array>& inputs, array& out) {
arange(inputs, out, start_, step_); assert(inputs.size() == 0);
out.set_data(allocator::malloc(out.nbytes()));
switch (out.dtype()) {
case bool_:
throw std::runtime_error("Bool type unsupported for arange.");
break;
case uint8:
arange<uint8_t>(start_, start_ + step_, out, out.size(), stream());
break;
case uint16:
arange<uint16_t>(start_, start_ + step_, out, out.size(), stream());
break;
case uint32:
arange<uint32_t>(start_, start_ + step_, out, out.size(), stream());
break;
case uint64:
arange<uint64_t>(start_, start_ + step_, out, out.size(), stream());
break;
case int8:
arange<int8_t>(start_, start_ + step_, out, out.size(), stream());
break;
case int16:
arange<int16_t>(start_, start_ + step_, out, out.size(), stream());
break;
case int32:
arange<int32_t>(start_, start_ + step_, out, out.size(), stream());
break;
case int64:
arange<int64_t>(start_, start_ + step_, out, out.size(), stream());
break;
case float16:
arange<float16_t>(start_, start_ + step_, out, out.size(), stream());
break;
case float32:
arange<float>(start_, start_ + step_, out, out.size(), stream());
break;
case float64:
arange<double>(start_, start_ + step_, out, out.size(), stream());
break;
case bfloat16:
arange<bfloat16_t>(start_, start_ + step_, out, out.size(), stream());
break;
case complex64:
arange<complex64_t>(start_, start_ + step_, out, out.size(), stream());
break;
}
} }
void AsType::eval_cpu(const std::vector<array>& inputs, array& out) { void AsType::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
auto& in = inputs[0]; auto& in = inputs[0];
CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General; CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General;
copy(in, out, ctype); copy(in, out, ctype, stream());
} }
void Concatenate::eval_cpu(const std::vector<array>& inputs, array& out) { void Concatenate::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -122,7 +186,7 @@ void Concatenate::eval_cpu(const std::vector<array>& inputs, array& out) {
} }
std::partial_sum(sizes.cbegin(), sizes.cend(), sizes.begin()); std::partial_sum(sizes.cbegin(), sizes.cend(), sizes.begin());
out.set_data(allocator::malloc_or_wait(out.nbytes())); out.set_data(allocator::malloc(out.nbytes()));
auto strides = out.strides(); auto strides = out.strides();
auto flags = out.flags(); auto flags = out.flags();
@@ -134,18 +198,20 @@ void Concatenate::eval_cpu(const std::vector<array>& inputs, array& out) {
size_t data_offset = strides[axis_] * sizes[i]; size_t data_offset = strides[axis_] * sizes[i];
out_slice.copy_shared_buffer( out_slice.copy_shared_buffer(
out, strides, flags, out_slice.size(), data_offset); out, strides, flags, out_slice.size(), data_offset);
copy_inplace(inputs[i], out_slice, CopyType::GeneralGeneral); copy_inplace(inputs[i], out_slice, CopyType::GeneralGeneral, stream());
} }
} }
void Contiguous::eval_cpu(const std::vector<array>& inputs, array& out) { void Contiguous::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
auto& in = inputs[0]; auto& in = inputs[0];
if (in.flags().row_contiguous || constexpr size_t extra_bytes = 16384;
(allow_col_major_ && in.flags().col_contiguous)) { if (in.buffer_size() <= out.nbytes() + extra_bytes &&
(in.flags().row_contiguous ||
(allow_col_major_ && in.flags().col_contiguous))) {
out.copy_shared_buffer(in); out.copy_shared_buffer(in);
} else { } else {
copy(in, out, CopyType::General); copy(in, out, CopyType::General, stream());
} }
} }
@@ -169,14 +235,7 @@ void Full::eval_cpu(const std::vector<array>& inputs, array& out) {
} else { } else {
ctype = CopyType::General; ctype = CopyType::General;
} }
copy(in, out, ctype); copy(in, out, ctype, stream());
}
void Load::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 0);
out.set_data(allocator::malloc_or_wait(out.nbytes()));
load(out, offset_, reader_, swap_endianness_);
} }
void Pad::eval_cpu(const std::vector<array>& inputs, array& out) { void Pad::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -192,7 +251,7 @@ void Pad::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(val.dtype() == in.dtype() && in.dtype() == out.dtype()); assert(val.dtype() == in.dtype() && in.dtype() == out.dtype());
// Fill output with val // Fill output with val
copy(val, out, CopyType::Scalar); copy(val, out, CopyType::Scalar, stream());
// Find offset for start of input values // Find offset for start of input values
size_t data_offset = 0; size_t data_offset = 0;
@@ -207,7 +266,7 @@ void Pad::eval_cpu(const std::vector<array>& inputs, array& out) {
out, out.strides(), out.flags(), out_slice.size(), data_offset); out, out.strides(), out.flags(), out_slice.size(), data_offset);
// Copy input values into the slice // Copy input values into the slice
copy_inplace(in, out_slice, CopyType::GeneralGeneral); copy_inplace(in, out_slice, CopyType::GeneralGeneral, stream());
} }
void RandomBits::eval_cpu(const std::vector<array>& inputs, array& out) { void RandomBits::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -219,43 +278,53 @@ void RandomBits::eval_cpu(const std::vector<array>& inputs, array& out) {
size_t elems_per_key = out.size() / num_keys; size_t elems_per_key = out.size() / num_keys;
size_t bytes_per_key = out.itemsize() * elems_per_key; size_t bytes_per_key = out.itemsize() * elems_per_key;
out.set_data(allocator::malloc_or_wait(out.nbytes())); out.set_data(allocator::malloc(out.nbytes()));
auto kptr = inputs[0].data<uint32_t>(); auto kptr = inputs[0].data<uint32_t>();
auto cptr = out.data<char>(); auto cptr = out.data<char>();
  auto& encoder = cpu::get_command_encoder(stream());
  encoder.set_input_array(inputs[0]);
  encoder.set_output_array(out);
  encoder.dispatch([kptr,
                    cptr,
                    bytes_per_key,
                    num_keys,
                    kshape = keys.shape(),
                    kstrides = keys.strides()]() mutable {
    size_t out_skip = (bytes_per_key + 4 - 1) / 4;
    auto half_size = out_skip / 2;
    bool even = out_skip % 2 == 0;
    for (int i = 0; i < num_keys; ++i, cptr += bytes_per_key) {
      auto ptr = reinterpret_cast<uint32_t*>(cptr);
      // Get ith key
      auto kidx = 2 * i;
      auto k1_elem = elem_to_loc(kidx, kshape, kstrides);
      auto k2_elem = elem_to_loc(kidx + 1, kshape, kstrides);
      auto key = std::make_pair(kptr[k1_elem], kptr[k2_elem]);
      std::pair<uintptr_t, uintptr_t> count{0, half_size + !even};
      for (; count.first + 1 < half_size; count.first++, count.second++) {
        std::tie(ptr[count.first], ptr[count.second]) =
            random::threefry2x32_hash(key, count);
      }
      if (count.first < half_size) {
        auto rb = random::threefry2x32_hash(key, count);
        ptr[count.first++] = rb.first;
        if (bytes_per_key % 4 > 0) {
          std::copy(
              reinterpret_cast<char*>(&rb.second),
              reinterpret_cast<char*>(&rb.second) + bytes_per_key % 4,
              cptr + 4 * count.second);
        } else {
          ptr[count.second] = rb.second;
        }
      }
      if (!even) {
        count.second = 0;
        ptr[half_size] = random::threefry2x32_hash(key, count).first;
      }
    }
  });
} }
void Reshape::eval_cpu(const std::vector<array>& inputs, array& out) { void Reshape::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -268,17 +337,24 @@ void DynamicSlice::eval_cpu(const std::vector<array>& inputs, array& out) {
return; return;
} }
auto& in = inputs[0]; auto& in = inputs[0];
out.set_data(allocator::malloc_or_wait(out.nbytes())); out.set_data(allocator::malloc(out.nbytes()));
auto i_offset = compute_dynamic_offset(inputs[1], in.strides(), axes_); auto [in_offset, donated] =
compute_dynamic_offset(inputs[1], in.strides(), axes_, stream());
copy_inplace( copy_inplace(
/* const array& src = */ in, /* const array& src = */ in,
/* array& dst = */ out, /* array& dst = */ out,
/* const Shape& data_shape = */ out.shape(), /* const Shape& data_shape = */ out.shape(),
/* const Strides& i_strides = */ in.strides(), /* const Strides& i_strides = */ in.strides(),
/* const Strides& o_strides = */ out.strides(), /* const Strides& o_strides = */ out.strides(),
/* int64_t i_offset = */ i_offset, /* int64_t i_offset = */ 0,
/* int64_t o_offset = */ 0, /* int64_t o_offset = */ 0,
/* CopyType ctype = */ CopyType::GeneralGeneral); /* CopyType ctype = */ CopyType::GeneralGeneral,
stream(),
/* const std::optional<array>& dynamic_i_offset = */ in_offset,
/* const std::optional<array>& dynamic_o_offset = */ std::nullopt);
if (!donated) {
cpu::get_command_encoder(stream()).add_temporary(std::move(in_offset));
}
} }
void DynamicSliceUpdate::eval_cpu( void DynamicSliceUpdate::eval_cpu(
@@ -296,9 +372,10 @@ void DynamicSliceUpdate::eval_cpu(
auto ctype = in.flags().contiguous && in.size() == in.data_size() auto ctype = in.flags().contiguous && in.size() == in.data_size()
? CopyType::Vector ? CopyType::Vector
: CopyType::General; : CopyType::General;
copy(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype); copy(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, stream());
auto o_offset = compute_dynamic_offset(inputs[2], out.strides(), axes_); auto [out_offset, donated] =
compute_dynamic_offset(inputs[2], out.strides(), axes_, stream());
copy_inplace( copy_inplace(
/* const array& src = */ upd, /* const array& src = */ upd,
/* array& dst = */ out, /* array& dst = */ out,
@@ -306,8 +383,14 @@ void DynamicSliceUpdate::eval_cpu(
/* const std::vector<stride_t>& i_strides = */ upd.strides(), /* const std::vector<stride_t>& i_strides = */ upd.strides(),
/* const std::vector<stride_t>& o_strides = */ out.strides(), /* const std::vector<stride_t>& o_strides = */ out.strides(),
/* int64_t i_offset = */ 0, /* int64_t i_offset = */ 0,
/* int64_t o_offset = */ o_offset, /* int64_t o_offset = */ 0,
/* CopyType ctype = */ CopyType::GeneralGeneral); /* CopyType ctype = */ CopyType::GeneralGeneral,
stream(),
/* const std::optional<array>& dynamic_i_offset = */ std::nullopt,
/* const std::optional<array>& dynamic_o_offset = */ out_offset);
if (!donated) {
cpu::get_command_encoder(stream()).add_temporary(std::move(out_offset));
}
} }
void SliceUpdate::eval_cpu(const std::vector<array>& inputs, array& out) { void SliceUpdate::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -329,7 +412,7 @@ void SliceUpdate::eval_cpu(const std::vector<array>& inputs, array& out) {
auto ctype = in.flags().contiguous && in.size() == in.data_size() auto ctype = in.flags().contiguous && in.size() == in.data_size()
? CopyType::Vector ? CopyType::Vector
: CopyType::General; : CopyType::General;
copy(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype); copy(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, stream());
// Calculate out strides, initial offset and if copy needs to be made // Calculate out strides, initial offset and if copy needs to be made
auto [data_offset, out_strides] = auto [data_offset, out_strides] =
@@ -344,7 +427,8 @@ void SliceUpdate::eval_cpu(const std::vector<array>& inputs, array& out) {
/* const std::vector<stride_t>& o_strides = */ out_strides, /* const std::vector<stride_t>& o_strides = */ out_strides,
/* int64_t i_offset = */ 0, /* int64_t i_offset = */ 0,
/* int64_t o_offset = */ data_offset, /* int64_t o_offset = */ data_offset,
/* CopyType ctype = */ CopyType::GeneralGeneral); /* CopyType ctype = */ CopyType::GeneralGeneral,
stream());
} }
void View::eval_cpu(const std::vector<array>& inputs, array& out) { void View::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -368,13 +452,13 @@ void View::eval_cpu(const std::vector<array>& inputs, array& out) {
} else { } else {
auto tmp = array( auto tmp = array(
in.shape(), in.dtype() == bool_ ? uint8 : in.dtype(), nullptr, {}); in.shape(), in.dtype() == bool_ ? uint8 : in.dtype(), nullptr, {});
tmp.set_data(allocator::malloc_or_wait(tmp.nbytes())); tmp.set_data(allocator::malloc(tmp.nbytes()));
if (in.dtype() == bool_) { if (in.dtype() == bool_) {
auto in_tmp = array(in.shape(), uint8, nullptr, {}); auto in_tmp = array(in.shape(), uint8, nullptr, {});
in_tmp.copy_shared_buffer(in); in_tmp.copy_shared_buffer(in);
copy_inplace(in_tmp, tmp, CopyType::General); copy_inplace(in_tmp, tmp, CopyType::General, stream());
} else { } else {
copy_inplace(in, tmp, CopyType::General); copy_inplace(in, tmp, CopyType::General, stream());
} }
auto flags = out.flags(); auto flags = out.flags();
@@ -382,7 +466,7 @@ void View::eval_cpu(const std::vector<array>& inputs, array& out) {
flags.row_contiguous = true; flags.row_contiguous = true;
auto max_dim = std::max_element(out.shape().begin(), out.shape().end()); auto max_dim = std::max_element(out.shape().begin(), out.shape().end());
flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim; flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim;
out.move_shared_buffer(tmp, out.strides(), flags, out.size()); out.copy_shared_buffer(tmp, out.strides(), flags, out.size());
} }
} }
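compute_dynamic_offset above turns the index input of DynamicSlice / DynamicSliceUpdate into an element offset: the dot product of the indices with the strides of the sliced axes, now computed inside the encoder task and written into a one-element int64 array (donating the indices buffer when possible). A standalone sketch of just the offset arithmetic, with illustrative names:

    // Standalone sketch of the offset computation in compute_dynamic_offset
    // above: the start of a dynamic slice is sum_i indices[i] * strides[axes[i]].
    #include <cstdint>
    #include <iostream>
    #include <vector>

    int64_t dynamic_offset(
        const std::vector<int64_t>& indices,  // one index per sliced axis
        const std::vector<int64_t>& strides,  // strides of the sliced array (elements)
        const std::vector<int>& axes) {       // which axis each index applies to
      int64_t offset = 0;
      for (size_t i = 0; i < axes.size(); ++i) {
        offset += indices[i] * strides[axes[i]];
      }
      return offset;
    }

    int main() {
      // A 4x5 row-major array has strides {5, 1}. A slice starting at row 2,
      // column 3 begins at element offset 2*5 + 3*1 = 13.
      std::vector<int64_t> strides = {5, 1};
      std::cout << dynamic_offset({2, 3}, strides, {0, 1}) << "\n"; // prints 13
      return 0;
    }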

View File

@@ -2,20 +2,18 @@
#include "mlx/allocator.h" #include "mlx/allocator.h"
#include "mlx/backend/cpu/copy.h" #include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/lapack.h" #include "mlx/backend/cpu/lapack.h"
#include "mlx/primitives.h" #include "mlx/primitives.h"
namespace mlx::core { namespace mlx::core {
template <typename T> template <typename T>
void qrf_impl(const array& a, array& q, array& r) { void qrf_impl(const array& a, array& q, array& r, Stream stream) {
const int M = a.shape(-2); const int M = a.shape(-2);
const int N = a.shape(-1); const int N = a.shape(-1);
const int lda = M; const int lda = M;
size_t num_matrices = a.size() / (M * N); size_t num_matrices = a.size() / (M * N);
int num_reflectors = std::min(M, N);
auto tau =
allocator::malloc_or_wait(sizeof(T) * num_matrices * num_reflectors);
// Copy A to inplace input and make it col-contiguous // Copy A to inplace input and make it col-contiguous
array in(a.shape(), a.dtype(), nullptr, {}); array in(a.shape(), a.dtype(), nullptr, {});
@@ -27,95 +25,107 @@ void qrf_impl(const array& a, array& q, array& r) {
auto strides = in.strides(); auto strides = in.strides();
strides[in.ndim() - 2] = 1; strides[in.ndim() - 2] = 1;
strides[in.ndim() - 1] = M; strides[in.ndim() - 1] = M;
  in.set_data(allocator::malloc(in.nbytes()), in.nbytes(), strides, flags);
  copy_inplace(a, in, CopyType::GeneralGeneral, stream);
  auto& encoder = cpu::get_command_encoder(stream);
  q.set_data(allocator::malloc(q.nbytes()));
  r.set_data(allocator::malloc(r.nbytes()));

  auto in_ptr = in.data<T>();
  auto r_ptr = r.data<T>();
  auto q_ptr = q.data<T>();

  encoder.set_input_array(in);
  encoder.set_output_array(q);
  encoder.set_output_array(r);
  encoder.dispatch([in_ptr, q_ptr, r_ptr, M, N, lda, num_matrices]() {
    int num_reflectors = std::min(M, N);
    auto tau = allocator::malloc(sizeof(T) * num_matrices * num_reflectors);

    T optimal_work;
    int lwork = -1;
    int info;

    // Compute workspace size
    geqrf<T>(&M, &N, nullptr, &lda, nullptr, &optimal_work, &lwork, &info);

    // Update workspace size
    lwork = optimal_work;
    auto work = allocator::malloc(sizeof(T) * lwork);

    // Loop over matrices
    for (int i = 0; i < num_matrices; ++i) {
      // Solve
      geqrf<T>(
          &M,
          &N,
          in_ptr + M * N * i,
          &lda,
          static_cast<T*>(tau.raw_ptr()) + num_reflectors * i,
          static_cast<T*>(work.raw_ptr()),
          &lwork,
          &info);
    }
    allocator::free(work);

    for (int i = 0; i < num_matrices; ++i) {
      /// num_reflectors x N
      for (int j = 0; j < num_reflectors; ++j) {
        for (int k = 0; k < j; ++k) {
          r_ptr[i * N * num_reflectors + j * N + k] = 0;
        }
        for (int k = j; k < N; ++k) {
          r_ptr[i * N * num_reflectors + j * N + k] =
              in_ptr[i * N * M + j + k * M];
        }
      }
    }

    // Get work size
    lwork = -1;
    orgqr<T>(
        &M,
        &num_reflectors,
        &num_reflectors,
        nullptr,
        &lda,
        nullptr,
        &optimal_work,
        &lwork,
        &info);
    lwork = optimal_work;
    work = allocator::malloc(sizeof(T) * lwork);

    // Loop over matrices
    for (int i = 0; i < num_matrices; ++i) {
      // Compute Q
      orgqr<T>(
          &M,
          &num_reflectors,
          &num_reflectors,
          in_ptr + M * N * i,
          &lda,
          static_cast<T*>(tau.raw_ptr()) + num_reflectors * i,
          static_cast<T*>(work.raw_ptr()),
          &lwork,
          &info);
    }

    for (int i = 0; i < num_matrices; ++i) {
      // M x num_reflectors
      for (int j = 0; j < M; ++j) {
        for (int k = 0; k < num_reflectors; ++k) {
          q_ptr[i * M * num_reflectors + j * num_reflectors + k] =
              in_ptr[i * N * M + j + k * M];
        }
      }
    }
// Cleanup // Cleanup
allocator::free(work); allocator::free(work);
allocator::free(tau); allocator::free(tau);
});
encoder.add_temporary(in);
} }
void QRF::eval_cpu( void QRF::eval_cpu(
@@ -123,10 +133,10 @@ void QRF::eval_cpu(
std::vector<array>& outputs) { std::vector<array>& outputs) {
switch (inputs[0].dtype()) { switch (inputs[0].dtype()) {
case float32: case float32:
qrf_impl<float>(inputs[0], outputs[0], outputs[1]); qrf_impl<float>(inputs[0], outputs[0], outputs[1], stream());
break; break;
case float64: case float64:
qrf_impl<double>(inputs[0], outputs[0], outputs[1]); qrf_impl<double>(inputs[0], outputs[0], outputs[1], stream());
break; break;
default: default:
throw std::runtime_error( throw std::runtime_error(
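After geqrf, the factored matrix sits in column-major storage, so the loops above read element (j, k) at j + k * M when filling R (the upper triangle, written row-major) and, later, Q. A standalone sketch of the R extraction under those layout assumptions; extract_r is an illustrative name, not MLX's:

    // Standalone sketch of the R-extraction loop in qrf_impl above.
    #include <algorithm>
    #include <iostream>
    #include <vector>

    // in: M x N, column-major. Returns R as num_reflectors x N, row-major,
    // with num_reflectors = min(M, N); entries below the diagonal stay zero.
    std::vector<float> extract_r(const std::vector<float>& in, int M, int N) {
      int num_reflectors = std::min(M, N);
      std::vector<float> r(num_reflectors * N, 0.0f);
      for (int j = 0; j < num_reflectors; ++j) {
        for (int k = j; k < N; ++k) {
          r[j * N + k] = in[j + k * M]; // column-major read, row-major write
        }
      }
      return r;
    }

    int main() {
      // 2x2 column-major matrix [[1, 3], [2, 4]] stored as {1, 2, 3, 4}.
      std::vector<float> in = {1, 2, 3, 4};
      auto r = extract_r(in, 2, 2);
      for (float v : r) std::cout << v << " "; // prints "1 3 0 4"
      std::cout << "\n";
      return 0;
    }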

View File

@@ -3,6 +3,7 @@
#include <cassert> #include <cassert>
#include "mlx/backend/cpu/copy.h" #include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/simd/simd.h" #include "mlx/backend/cpu/simd/simd.h"
#include "mlx/fast_primitives.h" #include "mlx/fast_primitives.h"
#include "mlx/primitives.h" #include "mlx/primitives.h"
@@ -316,7 +317,8 @@ void _qmm_dispatch_typed(
} }
} }
void _qmm_dispatch( template <typename T>
void _qmm_dispatch_typed(
array& out, array& out,
const array& x, const array& x,
const array& w, const array& w,
@@ -328,63 +330,61 @@ void _qmm_dispatch(
int K = x.shape(-1); int K = x.shape(-1);
int M = x.ndim() > 1 ? x.shape(-2) : 1; int M = x.ndim() > 1 ? x.shape(-2) : 1;
int N = out.shape(-1); int N = out.shape(-1);
int w_els = w.ndim() > 2 ? w.shape(-1) * w.shape(-2) : 0; int w_els = w.ndim() > 2 ? w.shape(-1) * w.shape(-2) : 0;
int g_els = w.ndim() > 2 ? scales.shape(-1) * scales.shape(-2) : 0; int g_els = w.ndim() > 2 ? scales.shape(-1) * scales.shape(-2) : 0;
int batch_size = x.size() / (K * M); int batch_size = x.size() / (K * M);
auto out_ptr = out.data<T>();
auto x_ptr = x.data<T>();
auto w_ptr = w.data<uint32_t>();
auto scales_ptr = scales.data<T>();
auto biases_ptr = biases.data<T>();
for (int i = 0; i < batch_size; i++) { for (int i = 0; i < batch_size; i++) {
    _qmm_dispatch_typed<T>(
        out_ptr + i * M * N,
        x_ptr + elem_to_loc(i * M * K, x.shape(), x.strides()),
        w_ptr + elem_to_loc(i * w_els, w.shape(), w.strides()),
        scales_ptr + elem_to_loc(i * g_els, scales.shape(), scales.strides()),
        biases_ptr + elem_to_loc(i * g_els, biases.shape(), biases.strides()),
        M,
        N,
        K,
        bits,
        group_size,
        transposed_w);
} }
} }
void _bs_qmm_dispatch( void _qmm_dispatch(
array& out,
const array& x,
const array& w,
const array& scales,
const array& biases,
int bits,
int group_size,
bool transposed_w) {
switch (x.dtype()) {
case float32:
_qmm_dispatch_typed<float>(
out, x, w, scales, biases, bits, group_size, transposed_w);
break;
case float16:
_qmm_dispatch_typed<float16_t>(
out, x, w, scales, biases, bits, group_size, transposed_w);
break;
case bfloat16:
_qmm_dispatch_typed<bfloat16_t>(
out, x, w, scales, biases, bits, group_size, transposed_w);
break;
default:
throw std::invalid_argument(
"[quantized_matmul] only floating types are supported");
}
}
template <typename T>
void _bs_qmm_dispatch_typed(
array& out, array& out,
const array& x, const array& x,
const array& w, const array& w,
@@ -402,60 +402,90 @@ void _bs_qmm_dispatch(
int w_els = w.shape(-1) * w.shape(-2); int w_els = w.shape(-1) * w.shape(-2);
int g_els = scales.shape(-1) * scales.shape(-2); int g_els = scales.shape(-1) * scales.shape(-2);
const uint32_t* lhs_indices_data = lhs_indices.data<uint32_t>(); auto out_ptr = out.data<T>();
const uint32_t* rhs_indices_data = rhs_indices.data<uint32_t>(); auto x_ptr = x.data<T>();
auto w_ptr = w.data<uint32_t>();
auto scales_ptr = scales.data<T>();
auto biases_ptr = biases.data<T>();
auto lhs_indices_ptr = lhs_indices.data<uint32_t>();
auto rhs_indices_ptr = rhs_indices.data<uint32_t>();
for (int i = 0; i < lhs_indices.size(); i++) { for (int i = 0; i < lhs_indices.size(); i++) {
int x_idx = lhs_indices_data[elem_to_loc(i, lhs_indices)]; int x_idx = lhs_indices_ptr[elem_to_loc(
int w_idx = rhs_indices_data[elem_to_loc(i, rhs_indices)]; i, lhs_indices.shape(), lhs_indices.strides())];
int w_idx = rhs_indices_ptr[elem_to_loc(
i, rhs_indices.shape(), rhs_indices.strides())];
_qmm_dispatch_typed<T>(
out_ptr + i * M * N,
x_ptr + elem_to_loc(x_idx * M * K, x.shape(), x.strides()),
w_ptr + elem_to_loc(w_idx * w_els, w.shape(), w.strides()),
scales_ptr +
elem_to_loc(w_idx * g_els, scales.shape(), scales.strides()),
biases_ptr +
elem_to_loc(w_idx * g_els, biases.shape(), biases.strides()),
M,
N,
K,
bits,
group_size,
transposed_w);
}
}
}

void _bs_qmm_dispatch(
    array& out,
    const array& x,
    const array& w,
    const array& scales,
    const array& biases,
    const array& lhs_indices,
    const array& rhs_indices,
    int bits,
    int group_size,
    bool transposed_w) {
  switch (x.dtype()) {
    case float32:
      _bs_qmm_dispatch_typed<float>(
          out,
          x,
          w,
          scales,
          biases,
          lhs_indices,
          rhs_indices,
          bits,
          group_size,
          transposed_w);
      break;
    case float16:
      _bs_qmm_dispatch_typed<float16_t>(
          out,
          x,
          w,
          scales,
          biases,
          lhs_indices,
          rhs_indices,
          bits,
          group_size,
          transposed_w);
      break;
    case bfloat16:
      _bs_qmm_dispatch_typed<bfloat16_t>(
          out,
          x,
          w,
          scales,
          biases,
          lhs_indices,
          rhs_indices,
          bits,
          group_size,
          transposed_w);
      break;
    default:
      throw std::invalid_argument(
          "[quantized_matmul] only floating types are supported");
  }
}
@@ -469,13 +499,14 @@ void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
auto& scales_pre = inputs[2]; auto& scales_pre = inputs[2];
auto& biases_pre = inputs[3]; auto& biases_pre = inputs[3];
auto ensure_row_contiguous = [](const array& arr) { std::vector<array> temps;
auto ensure_row_contiguous = [s = stream(), &temps](const array& arr) {
if (arr.flags().row_contiguous) { if (arr.flags().row_contiguous) {
return arr; return arr;
} else { } else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {}); temps.push_back(array(arr.shape(), arr.dtype(), nullptr, {}));
copy(arr, arr_copy, CopyType::General); copy(arr, temps.back(), CopyType::General, s);
return arr_copy; return temps.back();
} }
}; };
@@ -484,8 +515,25 @@ void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
auto scales = ensure_row_contiguous(scales_pre); auto scales = ensure_row_contiguous(scales_pre);
auto biases = ensure_row_contiguous(biases_pre); auto biases = ensure_row_contiguous(biases_pre);
out.set_data(allocator::malloc_or_wait(out.nbytes())); out.set_data(allocator::malloc(out.nbytes()));
_qmm_dispatch(out, x, w, scales, biases, group_size_, bits_, transpose_);
auto& encoder = cpu::get_command_encoder(stream());
encoder.add_temporaries(std::move(temps));
encoder.set_input_array(x);
encoder.set_input_array(w);
encoder.set_input_array(scales);
encoder.set_input_array(biases);
encoder.set_output_array(out);
encoder.dispatch([out = array::unsafe_weak_copy(out),
x = array::unsafe_weak_copy(x),
w = array::unsafe_weak_copy(w),
scales = array::unsafe_weak_copy(scales),
biases = array::unsafe_weak_copy(biases),
group_size_ = group_size_,
bits_ = bits_,
transpose_ = transpose_]() mutable {
_qmm_dispatch(out, x, w, scales, biases, group_size_, bits_, transpose_);
});
} }
void GatherQMM::eval_cpu(const std::vector<array>& inputs, array& out) { void GatherQMM::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -498,15 +546,17 @@ void GatherQMM::eval_cpu(const std::vector<array>& inputs, array& out) {
auto& lhs_indices = inputs[4]; auto& lhs_indices = inputs[4];
auto& rhs_indices = inputs[5]; auto& rhs_indices = inputs[5];
auto ensure_row_contiguous_last_dims = [](const array& arr) { std::vector<array> temps;
auto ensure_row_contiguous_last_dims = [s = stream(),
&temps](const array& arr) {
auto stride_0 = arr.strides()[arr.ndim() - 2]; auto stride_0 = arr.strides()[arr.ndim() - 2];
auto stride_1 = arr.strides()[arr.ndim() - 1]; auto stride_1 = arr.strides()[arr.ndim() - 1];
if (stride_0 == arr.shape(-1) && stride_1 == 1) { if (stride_0 == arr.shape(-1) && stride_1 == 1) {
return arr; return arr;
} else { } else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {}); temps.push_back(array(arr.shape(), arr.dtype(), nullptr, {}));
copy(arr, arr_copy, CopyType::General); copy(arr, temps.back(), CopyType::General, s);
return arr_copy; return temps.back();
} }
}; };
@@ -515,42 +565,59 @@ void GatherQMM::eval_cpu(const std::vector<array>& inputs, array& out) {
auto scales = ensure_row_contiguous_last_dims(scales_pre); auto scales = ensure_row_contiguous_last_dims(scales_pre);
auto biases = ensure_row_contiguous_last_dims(biases_pre); auto biases = ensure_row_contiguous_last_dims(biases_pre);
out.set_data(allocator::malloc_or_wait(out.nbytes())); out.set_data(allocator::malloc(out.nbytes()));
_bs_qmm_dispatch(
out, auto& encoder = cpu::get_command_encoder(stream());
x, encoder.add_temporaries(std::move(temps));
w, encoder.set_input_array(x);
scales, encoder.set_input_array(w);
biases, encoder.set_input_array(scales);
lhs_indices, encoder.set_input_array(biases);
rhs_indices, encoder.set_input_array(lhs_indices);
group_size_, encoder.set_input_array(rhs_indices);
bits_, encoder.set_output_array(out);
transpose_); encoder.dispatch([out = array::unsafe_weak_copy(out),
x = array::unsafe_weak_copy(x),
w = array::unsafe_weak_copy(w),
scales = array::unsafe_weak_copy(scales),
biases = array::unsafe_weak_copy(biases),
lhs_indices = array::unsafe_weak_copy(lhs_indices),
rhs_indices = array::unsafe_weak_copy(rhs_indices),
group_size_ = group_size_,
bits_ = bits_,
transpose_ = transpose_]() mutable {
_bs_qmm_dispatch(
out,
x,
w,
scales,
biases,
lhs_indices,
rhs_indices,
group_size_,
bits_,
transpose_);
});
} }
template <typename T, typename U> template <typename T, typename U>
void quantize( void quantize(
const array& w_, const T* w,
array& out_, U* out,
array& scales_, T* scales,
array& biases_, T* biases,
int bits, int bits,
int group_size) { int group_size,
const T* w = w_.data<T>(); size_t w_size) {
auto out = out_.data<U>();
T* scales = scales_.data<T>();
T* biases = biases_.data<T>();
float n_bins = (1 << bits) - 1; float n_bins = (1 << bits) - 1;
float eps = 1e-7; float eps = 1e-7;
bool power_of_2_bits = is_power_of_2(bits); bool power_of_2_bits = is_power_of_2(bits);
int el_per_int = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits; int el_per_int = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
// For 3/6 bits we read 3 uint8s at a time instead of 1 uint32 // For 3/6 bits we read 3 uint8s at a time instead of 1 uint32
int bytes_per_pack = power_of_2_bits ? 1 : 3; int bytes_per_pack = power_of_2_bits ? 1 : 3;
int int_per_group = group_size * bytes_per_pack / el_per_int; int int_per_group = group_size * bytes_per_pack / el_per_int;
size_t n_groups = w_.size() / group_size; size_t n_groups = w_size / group_size;
for (size_t i = 0; i < n_groups; ++i) { for (size_t i = 0; i < n_groups; ++i) {
size_t w_idx = i * group_size; size_t w_idx = i * group_size;
@@ -593,50 +660,86 @@ void quantize(
} }
} }
template <typename T, typename U>
void dispatch_quantize(
const array& w,
array& out,
array& scales,
array& biases,
int bits,
int group_size) {
auto w_ptr = w.data<T>();
auto out_ptr = out.data<U>();
auto scales_ptr = scales.data<T>();
auto biases_ptr = biases.data<T>();
quantize<T, U>(
w_ptr, out_ptr, scales_ptr, biases_ptr, bits, group_size, w.size());
}
void fast::AffineQuantize::eval_cpu( void fast::AffineQuantize::eval_cpu(
const std::vector<array>& inputs, const std::vector<array>& inputs,
std::vector<array>& outputs) { std::vector<array>& outputs) {
auto ensure_row_contiguous = [](const array& arr) { auto ensure_row_contiguous = [s = stream()](const array& arr) {
if (arr.flags().row_contiguous) { if (arr.flags().row_contiguous) {
return arr; return std::make_pair(arr, false);
} else { } else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {}); array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::General); copy(arr, arr_copy, CopyType::General, s);
return arr_copy; return std::make_pair(arr_copy, true);
} }
}; };
auto w = ensure_row_contiguous(inputs[0]);
auto [w, copied] = ensure_row_contiguous(inputs[0]);
auto& out = outputs[0]; auto& out = outputs[0];
out.set_data(allocator::malloc_or_wait(out.nbytes())); out.set_data(allocator::malloc(out.nbytes()));
auto& scales = outputs[1]; auto& scales = outputs[1];
auto& biases = outputs[2]; auto& biases = outputs[2];
scales.set_data(allocator::malloc_or_wait(scales.nbytes())); scales.set_data(allocator::malloc(scales.nbytes()));
biases.set_data(allocator::malloc_or_wait(biases.nbytes())); biases.set_data(allocator::malloc(biases.nbytes()));
  auto& encoder = cpu::get_command_encoder(stream());
  if (copied) {
    encoder.add_temporary(w);
  }
encoder.set_input_array(w);
encoder.set_input_array(scales);
encoder.set_input_array(biases);
encoder.set_output_array(out);
encoder.dispatch([w = array::unsafe_weak_copy(w),
out = array::unsafe_weak_copy(out),
scales = array::unsafe_weak_copy(scales),
biases = array::unsafe_weak_copy(biases),
group_size_ = group_size_,
bits_ = bits_]() mutable {
if (w.dtype() == float16) {
if (is_power_of_2(bits_)) {
dispatch_quantize<float16_t, uint32_t>(
w, out, scales, biases, bits_, group_size_);
} else {
dispatch_quantize<float16_t, uint8_t>(
w, out, scales, biases, bits_, group_size_);
}
} else if (w.dtype() == bfloat16) {
if (is_power_of_2(bits_)) {
dispatch_quantize<bfloat16_t, uint32_t>(
w, out, scales, biases, bits_, group_size_);
} else {
dispatch_quantize<bfloat16_t, uint8_t>(
w, out, scales, biases, bits_, group_size_);
}
} else if (w.dtype() == float32) {
if (is_power_of_2(bits_)) {
dispatch_quantize<float, uint32_t>(
w, out, scales, biases, bits_, group_size_);
} else {
dispatch_quantize<float, uint8_t>(
w, out, scales, biases, bits_, group_size_);
}
} else {
throw std::runtime_error(
"[fast::AffineQuantize::eval_cpu] Only supports floating point inputs");
}
});
} }
} // namespace mlx::core } // namespace mlx::core
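quantize() above maps each group of weights to `bits`-bit codes using a per-group scale and bias derived from n_bins = 2^bits - 1, then packs el_per_int codes into each output word. Below is a simplified standalone sketch of one such affine scheme for the 4-bit, power-of-two path; the exact scale/bias selection and the 3/6-bit byte packing in the real routine may differ, and quantize_group_4bit is an illustrative name.

    // Simplified affine quantization of one group into 4-bit codes packed into
    // uint32 words (8 codes per word); group_size is assumed to be a multiple of 8.
    #include <algorithm>
    #include <cmath>
    #include <cstdint>
    #include <iostream>
    #include <vector>

    void quantize_group_4bit(
        const float* w, int group_size, uint32_t* packed, float& scale, float& bias) {
      const int bits = 4;
      const float n_bins = (1 << bits) - 1; // 15
      float w_min = *std::min_element(w, w + group_size);
      float w_max = *std::max_element(w, w + group_size);
      scale = std::max((w_max - w_min) / n_bins, 1e-7f);
      bias = w_min;
      int el_per_int = 32 / bits; // 8 codes per uint32
      for (int i = 0; i < group_size; i += el_per_int) {
        uint32_t word = 0;
        for (int j = 0; j < el_per_int; ++j) {
          uint32_t q = static_cast<uint32_t>(std::round((w[i + j] - bias) / scale));
          q = std::min<uint32_t>(q, 15);
          word |= q << (bits * j);
        }
        packed[i / el_per_int] = word;
      }
    }

    int main() {
      std::vector<float> w(8);
      for (int i = 0; i < 8; ++i) w[i] = static_cast<float>(i); // 0..7
      uint32_t packed;
      float scale, bias;
      quantize_group_4bit(w.data(), 8, &packed, scale, bias);
      // scale = 7/15, bias = 0; dequantized value i is code_i * scale + bias.
      std::cout << std::hex << packed << "\n";
      return 0;
    }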

View File

@@ -5,6 +5,7 @@
#include <limits> #include <limits>
#include "mlx/backend/common/reduce.h" #include "mlx/backend/common/reduce.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/simd/simd.h" #include "mlx/backend/cpu/simd/simd.h"
#include "mlx/primitives.h" #include "mlx/primitives.h"
@@ -139,25 +140,22 @@ void reduction_op(
const array& x, const array& x,
array& out, array& out,
const std::vector<int>& axes, const std::vector<int>& axes,
U init, U init) {
Op op) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
ReductionPlan plan = get_reduction_plan(x, axes); ReductionPlan plan = get_reduction_plan(x, axes);
auto in_ptr = x.data<T>();
auto out_ptr = out.data<U>();
if (plan.type == ContiguousAllReduce) { if (plan.type == ContiguousAllReduce) {
U* out_ptr = out.data<U>();
*out_ptr = init; *out_ptr = init;
contiguous_reduce(x.data<T>(), out_ptr, x.size(), op, init); contiguous_reduce(in_ptr, out_ptr, x.size(), Op{}, init);
return; return;
} }
if (plan.type == ContiguousReduce && plan.shape.size() == 1) { if (plan.type == ContiguousReduce && plan.shape.size() == 1) {
int reduction_size = plan.shape[0]; int reduction_size = plan.shape[0];
const T* x_ptr = x.data<T>(); for (int i = 0; i < out.size(); i++, out_ptr++, in_ptr += reduction_size) {
U* out_ptr = out.data<U>();
for (int i = 0; i < out.size(); i++, out_ptr++, x_ptr += reduction_size) {
*out_ptr = init; *out_ptr = init;
contiguous_reduce(x_ptr, out_ptr, reduction_size, op, init); contiguous_reduce(in_ptr, out_ptr, reduction_size, Op{}, init);
} }
return; return;
} }
@@ -166,8 +164,6 @@ void reduction_op(
int reduction_size = plan.shape.back(); int reduction_size = plan.shape.back();
plan.shape.pop_back(); plan.shape.pop_back();
plan.strides.pop_back(); plan.strides.pop_back();
const T* x_ptr = x.data<T>();
U* out_ptr = out.data<U>();
// Unrolling the following loop (and implementing it in order for // Unrolling the following loop (and implementing it in order for
// ContiguousReduce) should hold extra performance boost. // ContiguousReduce) should hold extra performance boost.
auto [shape, strides] = shapes_without_reduction_axes(x, axes); auto [shape, strides] = shapes_without_reduction_axes(x, axes);
@@ -175,7 +171,7 @@ void reduction_op(
for (int i = 0; i < out.size(); i++, out_ptr++) { for (int i = 0; i < out.size(); i++, out_ptr++) {
int offset = elem_to_loc(i, shape, strides); int offset = elem_to_loc(i, shape, strides);
*out_ptr = init; *out_ptr = init;
contiguous_reduce(x_ptr + offset, out_ptr, reduction_size, op, init); contiguous_reduce(in_ptr + offset, out_ptr, reduction_size, Op{}, init);
} }
} else { } else {
for (int i = 0; i < out.size(); i++, out_ptr++) { for (int i = 0; i < out.size(); i++, out_ptr++) {
@@ -184,10 +180,10 @@ void reduction_op(
nd_loop( nd_loop(
[&](int extra_offset) { [&](int extra_offset) {
contiguous_reduce( contiguous_reduce(
x_ptr + offset + extra_offset, in_ptr + offset + extra_offset,
out_ptr, out_ptr,
reduction_size, reduction_size,
op, Op{},
init); init);
}, },
plan.shape, plan.shape,
@@ -202,12 +198,10 @@ void reduction_op(
size_t reduction_stride = plan.strides.back(); size_t reduction_stride = plan.strides.back();
plan.shape.pop_back(); plan.shape.pop_back();
plan.strides.pop_back(); plan.strides.pop_back();
const T* x_ptr = x.data<T>();
U* out_ptr = out.data<U>();
for (int i = 0; i < out.size(); i += reduction_stride) { for (int i = 0; i < out.size(); i += reduction_stride) {
std::fill_n(out_ptr, reduction_stride, init); std::fill_n(out_ptr, reduction_stride, init);
strided_reduce(x_ptr, out_ptr, reduction_size, reduction_stride, op); strided_reduce(in_ptr, out_ptr, reduction_size, reduction_stride, Op{});
x_ptr += reduction_stride * reduction_size; in_ptr += reduction_stride * reduction_size;
out_ptr += reduction_stride; out_ptr += reduction_stride;
} }
return; return;
@@ -219,15 +213,14 @@ void reduction_op(
size_t reduction_stride = plan.strides.back(); size_t reduction_stride = plan.strides.back();
plan.shape.pop_back(); plan.shape.pop_back();
plan.strides.pop_back(); plan.strides.pop_back();
const T* x_ptr = x.data<T>();
U* out_ptr = out.data<U>();
auto [shape, strides] = shapes_without_reduction_axes(x, axes); auto [shape, strides] = shapes_without_reduction_axes(x, axes);
if (plan.shape.size() == 0) { if (plan.shape.size() == 0) {
for (int i = 0; i < out.size(); i += reduction_stride) { for (int i = 0; i < out.size(); i += reduction_stride) {
int offset = elem_to_loc(i, shape, strides); int offset = elem_to_loc(i, shape, strides);
std::fill_n(out_ptr, reduction_stride, init); std::fill_n(out_ptr, reduction_stride, init);
strided_reduce( strided_reduce(
x_ptr + offset, out_ptr, reduction_size, reduction_stride, op); in_ptr + offset, out_ptr, reduction_size, reduction_stride, Op{});
out_ptr += reduction_stride; out_ptr += reduction_stride;
} }
} else { } else {
@@ -237,11 +230,11 @@ void reduction_op(
nd_loop( nd_loop(
[&](int extra_offset) { [&](int extra_offset) {
strided_reduce( strided_reduce(
x_ptr + offset + extra_offset, in_ptr + offset + extra_offset,
out_ptr, out_ptr,
reduction_size, reduction_size,
reduction_stride, reduction_stride,
op); Op{});
}, },
plan.shape, plan.shape,
plan.strides); plan.strides);
@@ -252,15 +245,14 @@ void reduction_op(
} }
if (plan.type == GeneralReduce) { if (plan.type == GeneralReduce) {
const T* x_ptr = x.data<T>();
U* out_ptr = out.data<U>();
auto [shape, strides] = shapes_without_reduction_axes(x, axes); auto [shape, strides] = shapes_without_reduction_axes(x, axes);
for (int i = 0; i < out.size(); i++, out_ptr++) { for (int i = 0; i < out.size(); i++, out_ptr++) {
int offset = elem_to_loc(i, shape, strides); int offset = elem_to_loc(i, shape, strides);
U val = init; U val = init;
nd_loop( nd_loop(
[&](int extra_offset) { [&](int extra_offset) {
val = op(val, *(x_ptr + offset + extra_offset)); val = Op{}(val, *(in_ptr + offset + extra_offset));
}, },
plan.shape, plan.shape,
plan.strides); plan.strides);
@@ -396,9 +388,9 @@ void reduce_dispatch_and_or(
Reduce::ReduceType rtype, Reduce::ReduceType rtype,
const std::vector<int>& axes) { const std::vector<int>& axes) {
if (rtype == Reduce::And) { if (rtype == Reduce::And) {
reduction_op<InT, bool>(in, out, axes, true, AndReduce()); reduction_op<InT, bool, AndReduce>(in, out, axes, true);
} else { } else {
reduction_op<InT, bool>(in, out, axes, false, OrReduce()); reduction_op<InT, bool, OrReduce>(in, out, axes, false);
} }
} }
@@ -410,15 +402,15 @@ void reduce_dispatch_sum_prod(
const std::vector<int>& axes) { const std::vector<int>& axes) {
if (rtype == Reduce::Sum) { if (rtype == Reduce::Sum) {
if constexpr (std::is_integral_v<InT> && sizeof(InT) <= 4) { if constexpr (std::is_integral_v<InT> && sizeof(InT) <= 4) {
reduction_op<InT, int32_t>(in, out, axes, 0, SumReduce()); reduction_op<InT, int32_t, SumReduce>(in, out, axes, 0);
} else { } else {
reduction_op<InT, InT>(in, out, axes, 0, SumReduce()); reduction_op<InT, InT, SumReduce>(in, out, axes, 0);
} }
} else { } else {
if constexpr (std::is_integral_v<InT> && sizeof(InT) <= 4) { if constexpr (std::is_integral_v<InT> && sizeof(InT) <= 4) {
reduction_op<InT, int32_t>(in, out, axes, 1, ProdReduce()); reduction_op<InT, int32_t, ProdReduce>(in, out, axes, 1);
} else { } else {
reduction_op<InT, InT>(in, out, axes, 1, ProdReduce()); reduction_op<InT, InT, ProdReduce>(in, out, axes, 1);
} }
} }
} }
@@ -431,132 +423,141 @@ void reduce_dispatch_min_max(
const std::vector<int>& axes) { const std::vector<int>& axes) {
if (rtype == Reduce::Max) { if (rtype == Reduce::Max) {
auto init = Limits<InT>::min; auto init = Limits<InT>::min;
reduction_op<InT, InT>(in, out, axes, init, MaxReduce()); reduction_op<InT, InT, MaxReduce>(in, out, axes, init);
} else { } else {
auto init = Limits<InT>::max; auto init = Limits<InT>::max;
reduction_op<InT, InT>(in, out, axes, init, MinReduce()); reduction_op<InT, InT, MinReduce>(in, out, axes, init);
} }
} }
void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) { void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
auto& in = inputs[0]; auto& in = inputs[0];
  out.set_data(allocator::malloc(out.nbytes()));
  auto& encoder = cpu::get_command_encoder(stream());
  encoder.set_input_array(in);
  encoder.set_output_array(out);
  encoder.dispatch([in = array::unsafe_weak_copy(in),
                    out = array::unsafe_weak_copy(out),
                    reduce_type_ = reduce_type_,
                    axes_ = axes_]() mutable {
    switch (reduce_type_) {
      case Reduce::And:
      case Reduce::Or: {
        switch (in.dtype()) {
          case bool_:
          case uint8:
          case int8:
            reduce_dispatch_and_or<int8_t>(in, out, reduce_type_, axes_);
            break;
          case int16:
          case uint16:
          case float16:
          case bfloat16:
            reduce_dispatch_and_or<int16_t>(in, out, reduce_type_, axes_);
            break;
          case uint32:
          case int32:
          case float32:
            reduce_dispatch_and_or<int32_t>(in, out, reduce_type_, axes_);
            break;
          case uint64:
          case int64:
          case float64:
          case complex64:
            reduce_dispatch_and_or<int64_t>(in, out, reduce_type_, axes_);
            break;
        }
        break;
      }
      case Reduce::Sum:
      case Reduce::Prod: {
        switch (in.dtype()) {
          case bool_:
          case uint8:
          case int8:
            reduce_dispatch_sum_prod<int8_t>(in, out, reduce_type_, axes_);
            break;
          case int16:
          case uint16:
            reduce_dispatch_sum_prod<int16_t>(in, out, reduce_type_, axes_);
            break;
          case int32:
          case uint32:
            reduce_dispatch_sum_prod<int32_t>(in, out, reduce_type_, axes_);
            break;
          case int64:
          case uint64:
            reduce_dispatch_sum_prod<int64_t>(in, out, reduce_type_, axes_);
            break;
          case float16:
            reduce_dispatch_sum_prod<float16_t>(in, out, reduce_type_, axes_);
            break;
          case bfloat16:
            reduce_dispatch_sum_prod<bfloat16_t>(in, out, reduce_type_, axes_);
            break;
          case float32:
            reduce_dispatch_sum_prod<float>(in, out, reduce_type_, axes_);
            break;
          case float64:
            reduce_dispatch_sum_prod<double>(in, out, reduce_type_, axes_);
            break;
          case complex64:
            reduce_dispatch_sum_prod<complex64_t>(in, out, reduce_type_, axes_);
            break;
        }
        break;
      }
      case Reduce::Max:
      case Reduce::Min: {
        switch (in.dtype()) {
          case bool_:
            reduce_dispatch_min_max<bool>(in, out, reduce_type_, axes_);
            break;
          case uint8:
            reduce_dispatch_min_max<uint8_t>(in, out, reduce_type_, axes_);
            break;
          case uint16:
            reduce_dispatch_min_max<uint16_t>(in, out, reduce_type_, axes_);
            break;
          case uint32:
            reduce_dispatch_min_max<uint32_t>(in, out, reduce_type_, axes_);
            break;
          case uint64:
            reduce_dispatch_min_max<uint64_t>(in, out, reduce_type_, axes_);
            break;
          case int8:
            reduce_dispatch_min_max<uint8_t>(in, out, reduce_type_, axes_);
            break;
          case int16:
            reduce_dispatch_min_max<uint16_t>(in, out, reduce_type_, axes_);
            break;
          case int32:
            reduce_dispatch_min_max<int32_t>(in, out, reduce_type_, axes_);
            break;
          case int64:
            reduce_dispatch_min_max<int64_t>(in, out, reduce_type_, axes_);
            break;
          case float16:
            reduce_dispatch_min_max<float16_t>(in, out, reduce_type_, axes_);
            break;
          case float32:
            reduce_dispatch_min_max<float>(in, out, reduce_type_, axes_);
            break;
          case float64:
            reduce_dispatch_min_max<double>(in, out, reduce_type_, axes_);
            break;
          case bfloat16:
            reduce_dispatch_min_max<bfloat16_t>(in, out, reduce_type_, axes_);
            break;
          case complex64:
            reduce_dispatch_min_max<complex64_t>(in, out, reduce_type_, axes_);
            break;
        }
        break;
      }
    }
  });
}
} // namespace mlx::core } // namespace mlx::core
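The core of the Reduce::eval_cpu change above is that the typed reduction no longer runs inline: eval_cpu allocates the output, registers the arrays with the CPU command encoder, and pushes the whole dtype switch into a closure built from unsafe_weak_copy'd arrays so the stream can run it later. A minimal, self-contained toy of that deferred-dispatch shape (TaskEncoder and the buffers here are illustrative stand-ins, not the real mlx::core::cpu encoder):

// Toy model of the encoder-dispatch pattern used above: queue the typed kernel
// as a closure that captures (copies of) everything it needs, and run it later
// on the stream's worker. Illustrative only; not the mlx::core::cpu API.
#include <cstdio>
#include <functional>
#include <queue>
#include <vector>

struct TaskEncoder {
  std::queue<std::function<void()>> tasks;
  void dispatch(std::function<void()> f) { tasks.push(std::move(f)); }
  void run_all() {
    while (!tasks.empty()) {
      tasks.front()();
      tasks.pop();
    }
  }
};

int main() {
  TaskEncoder encoder;
  std::vector<int> in = {1, 2, 3, 4};
  int out = 0;
  // The "kernel" body is captured by value (much like the unsafe_weak_copy'd
  // arrays above) so its inputs stay valid until the task actually runs.
  encoder.dispatch([in, &out]() mutable {
    for (int v : in) {
      out += v;
    }
  });
  encoder.run_all(); // later, on the stream's worker
  std::printf("sum = %d\n", out); // 1+2+3+4 = 10
}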


@@ -3,7 +3,9 @@
 #include <cassert>
 #include "mlx/backend/common/utils.h"
+#include "mlx/backend/cpu/binary_ops.h"
 #include "mlx/backend/cpu/copy.h"
+#include "mlx/backend/cpu/encoder.h"
 #include "mlx/backend/cpu/simd/simd.h"
 #include "mlx/primitives.h"
@@ -153,33 +155,31 @@ void strided_scan(
template <typename T, typename U, typename Op> template <typename T, typename U, typename Op>
void scan_op( void scan_op(
const array& input, const array& in,
array& output, array& out,
int axis, int axis,
bool reverse, bool reverse,
bool inclusive, bool inclusive,
const Op& op, const Op& op,
U init) { U init) {
output.set_data(allocator::malloc_or_wait(output.nbytes())); if (in.flags().row_contiguous) {
if (in.strides()[axis] == 1) {
if (input.flags().row_contiguous) {
if (input.strides()[axis] == 1) {
contiguous_scan( contiguous_scan(
input.data<T>(), in.data<T>(),
output.data<U>(), out.data<U>(),
input.size() / input.shape(axis), in.size() / in.shape(axis),
input.shape(axis), in.shape(axis),
reverse, reverse,
inclusive, inclusive,
op, op,
init); init);
} else { } else {
strided_scan( strided_scan(
input.data<T>(), in.data<T>(),
output.data<U>(), out.data<U>(),
input.size() / input.shape(axis) / input.strides()[axis], in.size() / in.shape(axis) / in.strides()[axis],
input.shape(axis), in.shape(axis),
input.strides()[axis], in.strides()[axis],
reverse, reverse,
inclusive, inclusive,
op, op,
@@ -193,8 +193,8 @@ void scan_op(
template <typename T, typename U> template <typename T, typename U>
void scan_dispatch( void scan_dispatch(
Scan::ReduceType rtype, Scan::ReduceType rtype,
const array& input, const array& in,
array& output, array& out,
int axis, int axis,
bool reverse, bool reverse,
bool inclusive) { bool inclusive) {
@@ -202,29 +202,39 @@ void scan_dispatch(
case Scan::Sum: { case Scan::Sum: {
auto op = [](U y, T x) { return y + x; }; auto op = [](U y, T x) { return y + x; };
auto init = static_cast<U>(0); auto init = static_cast<U>(0);
scan_op<T, U>(input, output, axis, reverse, inclusive, op, init); scan_op<T, U>(in, out, axis, reverse, inclusive, op, init);
break; break;
} }
case Scan::Prod: { case Scan::Prod: {
auto op = [](U y, T x) { return y * x; }; auto op = [](U y, T x) { return y * x; };
auto init = static_cast<U>(1); auto init = static_cast<U>(1);
scan_op<T, U>(input, output, axis, reverse, inclusive, op, init); scan_op<T, U>(in, out, axis, reverse, inclusive, op, init);
break; break;
} }
case Scan::Min: { case Scan::Min: {
auto op = [](U y, T x) { return x < y ? x : y; }; auto op = [](U y, T x) { return x < y ? x : y; };
auto init = (issubdtype(input.dtype(), floating)) auto init = (issubdtype(in.dtype(), floating))
? static_cast<U>(std::numeric_limits<float>::infinity()) ? static_cast<U>(std::numeric_limits<float>::infinity())
: std::numeric_limits<U>::max(); : std::numeric_limits<U>::max();
scan_op<T, U>(input, output, axis, reverse, inclusive, op, init); scan_op<T, U>(in, out, axis, reverse, inclusive, op, init);
break; break;
} }
case Scan::Max: { case Scan::Max: {
auto op = [](U y, T x) { return x < y ? y : x; }; auto op = [](U y, T x) { return x < y ? y : x; };
auto init = (issubdtype(input.dtype(), floating)) auto init = (issubdtype(in.dtype(), floating))
? static_cast<U>(-std::numeric_limits<float>::infinity()) ? static_cast<U>(-std::numeric_limits<float>::infinity())
: std::numeric_limits<U>::min(); : std::numeric_limits<U>::min();
scan_op<T, U>(input, output, axis, reverse, inclusive, op, init); scan_op<T, U>(in, out, axis, reverse, inclusive, op, init);
break;
}
case Scan::LogAddExp: {
auto op = [](U a, T b) {
return detail::LogAddExp{}(a, static_cast<U>(b));
};
auto init = (issubdtype(in.dtype(), floating))
? static_cast<U>(-std::numeric_limits<float>::infinity())
: std::numeric_limits<U>::min();
scan_op<T, U>(in, out, axis, reverse, inclusive, op, init);
break; break;
} }
} }
@@ -235,82 +245,95 @@ void scan_dispatch(
void Scan::eval_cpu(const std::vector<array>& inputs, array& out) { void Scan::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
auto& encoder = cpu::get_command_encoder(stream());
// Ensure contiguity // Ensure contiguity
auto in = inputs[0]; auto in = inputs[0];
if (!in.flags().row_contiguous) { if (!in.flags().row_contiguous) {
array arr_copy(in.shape(), in.dtype(), nullptr, {}); array arr_copy(in.shape(), in.dtype(), nullptr, {});
copy(in, arr_copy, CopyType::General); copy(in, arr_copy, CopyType::General, stream());
in = arr_copy; in = arr_copy;
encoder.add_temporary(arr_copy);
} }
out.set_data(allocator::malloc(out.nbytes()));
switch (in.dtype()) { encoder.set_input_array(in);
case bool_: { encoder.set_output_array(out);
// We could do a full dtype x dtype switch but this is the only case encoder.dispatch([in = array::unsafe_weak_copy(in),
// where we accumulate in a different type, for now. out = array::unsafe_weak_copy(out),
// axis_ = axis_,
// TODO: If we add the option to accumulate floats in higher precision reduce_type_ = reduce_type_,
// floats perhaps we should add the full all-to-all dispatch. reverse_ = reverse_,
if (reduce_type_ == Scan::Sum && out.dtype() == int32) { inclusive_ = inclusive_]() mutable {
scan_dispatch<bool, int32_t>( switch (in.dtype()) {
reduce_type_, in, out, axis_, reverse_, inclusive_); case bool_: {
} else { // We could do a full dtype x dtype switch but this is the only case
scan_dispatch<bool, bool>( // where we accumulate in a different type, for now.
reduce_type_, in, out, axis_, reverse_, inclusive_); //
// TODO: If we add the option to accumulate floats in higher precision
// floats perhaps we should add the full all-to-all dispatch.
if (reduce_type_ == Scan::Sum && out.dtype() == int32) {
scan_dispatch<bool, int32_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
} else {
scan_dispatch<bool, bool>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
}
break;
} }
break; case uint8:
scan_dispatch<uint8_t, uint8_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case uint16:
scan_dispatch<uint16_t, uint16_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case uint32:
scan_dispatch<uint32_t, uint32_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case uint64:
scan_dispatch<uint64_t, uint64_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case int8:
scan_dispatch<int8_t, int8_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case int16:
scan_dispatch<int16_t, int16_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case int32:
scan_dispatch<int32_t, int32_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case int64:
scan_dispatch<int64_t, int64_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case float16:
scan_dispatch<float16_t, float16_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case float32:
scan_dispatch<float, float>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case float64:
scan_dispatch<double, double>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case bfloat16:
scan_dispatch<bfloat16_t, bfloat16_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case complex64:
throw std::runtime_error("Scan ops do not support complex types yet");
break;
} }
case uint8: });
scan_dispatch<uint8_t, uint8_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case uint16:
scan_dispatch<uint16_t, uint16_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case uint32:
scan_dispatch<uint32_t, uint32_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case uint64:
scan_dispatch<uint64_t, uint64_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case int8:
scan_dispatch<int8_t, int8_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case int16:
scan_dispatch<int16_t, int16_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case int32:
scan_dispatch<int32_t, int32_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case int64:
scan_dispatch<int64_t, int64_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case float16:
scan_dispatch<float16_t, float16_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case float32:
scan_dispatch<float, float>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case float64:
scan_dispatch<double, double>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case bfloat16:
scan_dispatch<bfloat16_t, bfloat16_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case complex64:
throw std::runtime_error("Scan ops do not support complex types yet");
break;
}
} }
} // namespace mlx::core } // namespace mlx::core
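Besides moving Scan::eval_cpu onto the encoder, these hunks add the Scan::LogAddExp case that backs the new LogCumSumExp op: the accumulator starts at -inf and each step combines in log space. A standalone sketch of the numerically stable combine step and the resulting running log-sum-exp (not the actual detail::LogAddExp functor):

// Stable log-add-exp: shift by the max so exp() never overflows, and guard the
// all -inf case so we do not produce NaN from (-inf) - (-inf).
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <limits>
#include <vector>

float log_add_exp(float a, float b) {
  float m = std::max(a, b);
  if (std::isinf(m) && m < 0) {
    return m; // both inputs are -inf
  }
  return m + std::log1p(std::exp(std::min(a, b) - m));
}

int main() {
  std::vector<float> x = {0.f, 1.f, 2.f};
  float acc = -std::numeric_limits<float>::infinity(); // same init as the scan
  for (float v : x) {
    acc = log_add_exp(acc, v);
    std::printf("%f\n", acc); // running log(exp(x[0]) + ... + exp(x[i]))
  }
}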


@@ -16,51 +16,70 @@ void select_op(
const array& b, const array& b,
const array& c, const array& c,
array& out, array& out,
Op op) { Op op,
switch (out.dtype()) { Stream stream) {
case bool_: TernaryOpType topt = get_ternary_op_type(a, b, c);
ternary_op<bool, bool, bool, bool>(a, b, c, out, op); set_ternary_op_output_data(a, b, c, out, topt);
break;
case uint8: auto& encoder = cpu::get_command_encoder(stream);
ternary_op<bool, uint8_t, uint8_t, uint8_t>(a, b, c, out, op); encoder.set_input_array(a);
break; encoder.set_input_array(b);
case uint16: encoder.set_input_array(c);
ternary_op<bool, uint16_t, uint16_t, uint16_t>(a, b, c, out, op); encoder.set_output_array(out);
break; encoder.dispatch([a = array::unsafe_weak_copy(a),
case uint32: b = array::unsafe_weak_copy(b),
ternary_op<bool, uint32_t, uint32_t, uint32_t>(a, b, c, out, op); c = array::unsafe_weak_copy(c),
break; out = array::unsafe_weak_copy(out),
case uint64: op,
ternary_op<bool, uint64_t, uint64_t, uint64_t>(a, b, c, out, op); topt]() mutable {
break; switch (out.dtype()) {
case int8: case bool_:
ternary_op<bool, int8_t, int8_t, int8_t>(a, b, c, out, op); ternary_op<bool, bool, bool, bool>(a, b, c, out, op, topt);
break; break;
case int16: case uint8:
ternary_op<bool, int16_t, int16_t, int16_t>(a, b, c, out, op); ternary_op<bool, uint8_t, uint8_t, uint8_t>(a, b, c, out, op, topt);
break; break;
case int32: case uint16:
ternary_op<bool, int32_t, int32_t, int32_t>(a, b, c, out, op); ternary_op<bool, uint16_t, uint16_t, uint16_t>(a, b, c, out, op, topt);
break; break;
case int64: case uint32:
ternary_op<bool, int64_t, int64_t, int64_t>(a, b, c, out, op); ternary_op<bool, uint32_t, uint32_t, uint32_t>(a, b, c, out, op, topt);
break; break;
case float16: case uint64:
ternary_op<bool, float16_t, float16_t, float16_t>(a, b, c, out, op); ternary_op<bool, uint64_t, uint64_t, uint64_t>(a, b, c, out, op, topt);
break; break;
case float32: case int8:
ternary_op<bool, float, float, float>(a, b, c, out, op); ternary_op<bool, int8_t, int8_t, int8_t>(a, b, c, out, op, topt);
break; break;
case float64: case int16:
ternary_op<bool, double, double, double>(a, b, c, out, op); ternary_op<bool, int16_t, int16_t, int16_t>(a, b, c, out, op, topt);
break; break;
case bfloat16: case int32:
ternary_op<bool, bfloat16_t, bfloat16_t, bfloat16_t>(a, b, c, out, op); ternary_op<bool, int32_t, int32_t, int32_t>(a, b, c, out, op, topt);
break; break;
case complex64: case int64:
ternary_op<bool, complex64_t, complex64_t, complex64_t>(a, b, c, out, op); ternary_op<bool, int64_t, int64_t, int64_t>(a, b, c, out, op, topt);
break; break;
} case float16:
ternary_op<bool, float16_t, float16_t, float16_t>(
a, b, c, out, op, topt);
break;
case float32:
ternary_op<bool, float, float, float>(a, b, c, out, op, topt);
break;
case float64:
ternary_op<bool, double, double, double>(a, b, c, out, op, topt);
break;
case bfloat16:
ternary_op<bool, bfloat16_t, bfloat16_t, bfloat16_t>(
a, b, c, out, op, topt);
break;
case complex64:
ternary_op<bool, complex64_t, complex64_t, complex64_t>(
a, b, c, out, op, topt);
break;
}
});
} }
} // namespace } // namespace
@@ -70,7 +89,7 @@ void Select::eval_cpu(const std::vector<array>& inputs, array& out) {
   const auto& condition = inputs[0];
   const auto& a = inputs[1];
   const auto& b = inputs[2];
-  select_op(condition, a, b, out, detail::Select());
+  select_op(condition, a, b, out, detail::Select(), stream());
 }

 } // namespace mlx::core
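select_op now computes the TernaryOpType and output layout up front and dispatches the typed kernel through the encoder, switching on the output dtype while the condition is always read as bool. Stripped of layout handling, the element-wise operation the fast vector path computes is just the following (an illustrative sketch, not the real ternary_op):

// out[i] = cond[i] ? a[i] : b[i], the VectorVectorVector fast path in spirit.
#include <cstdio>
#include <vector>

template <typename T>
std::vector<T> select(const std::vector<bool>& cond,
                      const std::vector<T>& a,
                      const std::vector<T>& b) {
  std::vector<T> out(cond.size());
  for (size_t i = 0; i < cond.size(); ++i) {
    out[i] = cond[i] ? a[i] : b[i];
  }
  return out;
}

int main() {
  auto out = select<float>({true, false, true}, {1.f, 2.f, 3.f}, {9.f, 8.f, 7.f});
  for (float v : out) {
    std::printf("%g ", v); // 1 8 3
  }
}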


@@ -17,7 +17,7 @@ struct ScalarT<float16_t, N> {
 #endif
 template <>
-static constexpr int max_size<float16_t> = N;
+inline constexpr int max_size<float16_t> = N;
 #define SIMD_FP16_DEFAULT_UNARY(op) \
   template <> \


@@ -83,25 +83,25 @@ struct Simd {
 // Values chosen based on benchmarks on M3 Max
 // TODO: consider choosing these more optimally
 template <>
-static constexpr int max_size<int8_t> = 16;
+inline constexpr int max_size<int8_t> = 16;
 template <>
-static constexpr int max_size<int16_t> = 16;
+inline constexpr int max_size<int16_t> = 16;
 template <>
-static constexpr int max_size<int> = 8;
+inline constexpr int max_size<int> = 8;
 template <>
-static constexpr int max_size<int64_t> = 4;
+inline constexpr int max_size<int64_t> = 4;
 template <>
-static constexpr int max_size<uint8_t> = 16;
+inline constexpr int max_size<uint8_t> = 16;
 template <>
-static constexpr int max_size<uint16_t> = 16;
+inline constexpr int max_size<uint16_t> = 16;
 template <>
-static constexpr int max_size<uint32_t> = 8;
+inline constexpr int max_size<uint32_t> = 8;
 template <>
-static constexpr int max_size<uint64_t> = 4;
+inline constexpr int max_size<uint64_t> = 4;
 template <>
-static constexpr int max_size<float> = 8;
+inline constexpr int max_size<float> = 8;
 template <>
-static constexpr int max_size<double> = 4;
+inline constexpr int max_size<double> = 4;
 #define SIMD_DEFAULT_UNARY(name, op) \
   template <typename T, int N> \
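The static → inline change matters because these are explicit specializations of a variable template defined at namespace scope in a header: declaring them inline constexpr (a C++17 inline variable) gives every translation unit that includes the header one shared definition and sidesteps linkage/ODR trouble. A compilable sketch of the pattern with placeholder values (max_size here is a stand-in, not the real simd header):

// C++17 inline variable template with explicit specializations; safe to put in
// a header included by many translation units.
#include <cstdio>

template <typename T>
inline constexpr int max_size = 1; // scalar fallback for unlisted types

template <>
inline constexpr int max_size<float> = 8; // e.g. an 8-wide float vector

template <>
inline constexpr int max_size<double> = 4;

int main() {
  std::printf("%d %d %d\n", max_size<int>, max_size<float>, max_size<double>);
  // prints: 1 8 4
}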


@@ -87,7 +87,6 @@ DEFAULT_UNARY(cosh, std::cosh)
 DEFAULT_UNARY(expm1, std::expm1)
 DEFAULT_UNARY(floor, std::floor)
 DEFAULT_UNARY(log, std::log)
-DEFAULT_UNARY(log2, std::log2)
 DEFAULT_UNARY(log10, std::log10)
 DEFAULT_UNARY(log1p, std::log1p)
 DEFAULT_UNARY(sinh, std::sinh)
@@ -95,6 +94,17 @@ DEFAULT_UNARY(sqrt, std::sqrt)
 DEFAULT_UNARY(tan, std::tan)
 DEFAULT_UNARY(tanh, std::tanh)
+template <typename T>
+Simd<T, 1> log2(Simd<T, 1> in) {
+  if constexpr (is_complex<T>) {
+    auto out = std::log(in.value);
+    auto scale = decltype(out.real())(M_LN2);
+    return Simd<T, 1>{T{out.real() / scale, out.imag() / scale}};
+  } else {
+    return Simd<T, 1>{std::log2(in.value)};
+  }
+}
 template <typename T>
 Simd<T, 1> operator~(Simd<T, 1> in) {
   return ~in.value;
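The new scalar log2 overload handles complex inputs by taking the principal complex log and dividing both parts by ln 2, i.e. log2(z) = log(z) / ln 2; real types still go straight to std::log2. A standalone check of that identity with std::complex (M_LN2 needs _USE_MATH_DEFINES before <cmath> on MSVC, as a later hunk in this compare does):

// log2(z) = log(z) / ln(2), applied to the principal branch.
#define _USE_MATH_DEFINES // for M_LN2 on MSVC
#include <cmath>
#include <complex>
#include <cstdio>

std::complex<float> log2c(std::complex<float> z) {
  auto l = std::log(z); // ln|z| + i*arg(z)
  auto scale = float(M_LN2);
  return {l.real() / scale, l.imag() / scale};
}

int main() {
  auto r = log2c({0.f, 4.f}); // log2(4i) = 2 + i*(pi/2)/ln(2)
  std::printf("%f %f\n", r.real(), r.imag()); // ~2.000000 2.266180
}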


@@ -4,6 +4,7 @@
 #include <cmath>
 #include "mlx/backend/cpu/copy.h"
+#include "mlx/backend/cpu/encoder.h"
 #include "mlx/backend/cpu/simd/simd.h"
 #include "mlx/primitives.h"
 #include "mlx/types/limits.h"
@@ -15,92 +16,100 @@ namespace {
using namespace mlx::core::simd; using namespace mlx::core::simd;
template <typename T, typename AccT> template <typename T, typename AccT>
void softmax(const array& in, array& out) { void softmax(const array& in, array& out, Stream stream) {
constexpr bool same_t = std::is_same_v<T, AccT>; auto& encoder = cpu::get_command_encoder(stream);
constexpr int N = std::min(max_size<AccT>, max_size<T>); encoder.set_input_array(in);
encoder.set_output_array(out);
const T* in_ptr = in.data<T>(); const T* in_ptr = in.data<T>();
T* out_ptr = out.data<T>(); T* out_ptr = out.data<T>();
int M = in.shape().back(); int M = in.shape().back();
int L = in.data_size() / M; int L = in.data_size() / M;
const T* current_in_ptr;
T* current_out_ptr;
for (int i = 0; i < L; i++, in_ptr += M, out_ptr += M) { encoder.dispatch([in_ptr, out_ptr, M, L]() mutable {
// Find the maximum constexpr bool same_t = std::is_same_v<T, AccT>;
current_in_ptr = in_ptr; constexpr int N = std::min(max_size<AccT>, max_size<T>);
Simd<AccT, N> vmaximum(-numeric_limits<AccT>::infinity());
size_t s = M;
while (s >= N) {
Simd<AccT, N> vals = load<T, N>(current_in_ptr);
vmaximum = maximum(vals, vmaximum);
current_in_ptr += N;
s -= N;
}
AccT maximum = max(vmaximum); const T* current_in_ptr;
while (s-- > 0) { T* current_out_ptr;
maximum = std::max(maximum, static_cast<AccT>(*current_in_ptr));
current_in_ptr++;
}
// Compute the normalizer and the exponentials for (int i = 0; i < L; i++, in_ptr += M, out_ptr += M) {
Simd<AccT, N> vnormalizer(0.0); // Find the maximum
current_out_ptr = out_ptr; current_in_ptr = in_ptr;
current_in_ptr = in_ptr; Simd<AccT, N> vmaximum(-numeric_limits<AccT>::infinity());
s = M; size_t s = M;
while (s >= N) { while (s >= N) {
Simd<AccT, N> vexp = load<T, N>(current_in_ptr); Simd<AccT, N> vals = load<T, N>(current_in_ptr);
vexp = exp(vexp - maximum); vmaximum = maximum(vals, vmaximum);
if constexpr (same_t) {
store(current_out_ptr, vexp);
}
vnormalizer = vnormalizer + vexp;
current_in_ptr += N;
current_out_ptr += N;
s -= N;
}
AccT normalizer = sum(vnormalizer);
while (s-- > 0) {
AccT _exp = std::exp(*current_in_ptr - maximum);
if constexpr (same_t) {
*current_out_ptr = _exp;
}
normalizer += _exp;
current_in_ptr++;
current_out_ptr++;
}
normalizer = 1 / normalizer;
// Normalize
current_out_ptr = out_ptr;
current_in_ptr = in_ptr;
s = M;
while (s >= N) {
if constexpr (same_t) {
store(
current_out_ptr,
Simd<T, N>(load<T, N>(current_out_ptr) * normalizer));
} else {
Simd<AccT, N> vexp = load<T, N>(current_in_ptr);
vexp = exp(vexp - maximum) * normalizer;
store(current_out_ptr, Simd<T, N>(vexp));
current_in_ptr += N; current_in_ptr += N;
s -= N;
} }
current_out_ptr += N;
s -= N; AccT maximum = max(vmaximum);
} while (s-- > 0) {
while (s-- > 0) { maximum = std::max(maximum, static_cast<AccT>(*current_in_ptr));
if constexpr (same_t) {
*current_out_ptr *= normalizer;
} else {
AccT _exp = std::exp(*current_in_ptr - maximum);
*current_out_ptr = static_cast<T>(_exp * normalizer);
current_in_ptr++; current_in_ptr++;
} }
current_out_ptr++;
// Compute the normalizer and the exponentials
Simd<AccT, N> vnormalizer(0.0);
current_out_ptr = out_ptr;
current_in_ptr = in_ptr;
s = M;
while (s >= N) {
Simd<AccT, N> vexp = load<T, N>(current_in_ptr);
vexp = exp(vexp - maximum);
if constexpr (same_t) {
store(current_out_ptr, vexp);
}
vnormalizer = vnormalizer + vexp;
current_in_ptr += N;
current_out_ptr += N;
s -= N;
}
AccT normalizer = sum(vnormalizer);
while (s-- > 0) {
AccT _exp = std::exp(*current_in_ptr - maximum);
if constexpr (same_t) {
*current_out_ptr = _exp;
}
normalizer += _exp;
current_in_ptr++;
current_out_ptr++;
}
normalizer = 1 / normalizer;
// Normalize
current_out_ptr = out_ptr;
current_in_ptr = in_ptr;
s = M;
while (s >= N) {
if constexpr (same_t) {
store(
current_out_ptr,
Simd<T, N>(load<T, N>(current_out_ptr) * normalizer));
} else {
Simd<AccT, N> vexp = load<T, N>(current_in_ptr);
vexp = exp(vexp - maximum) * normalizer;
store(current_out_ptr, Simd<T, N>(vexp));
current_in_ptr += N;
}
current_out_ptr += N;
s -= N;
}
while (s-- > 0) {
if constexpr (same_t) {
*current_out_ptr *= normalizer;
} else {
AccT _exp = std::exp(*current_in_ptr - maximum);
*current_out_ptr = static_cast<T>(_exp * normalizer);
current_in_ptr++;
}
current_out_ptr++;
}
} }
} });
} }
} // namespace } // namespace
@@ -109,67 +118,52 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
// Make sure that the last dimension is contiguous // Make sure that the last dimension is contiguous
auto check_input = [](array x) { auto set_output = [s = stream(), &out](const array& x) {
bool no_copy = x.strides()[x.ndim() - 1] == 1; if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
if (x.ndim() > 1) { if (x.is_donatable()) {
auto s = x.strides()[x.ndim() - 2]; out.copy_shared_buffer(x);
no_copy &= (s == 0 || s == x.shape().back()); } else {
} out.set_data(
if (no_copy) { allocator::malloc(x.data_size() * x.itemsize()),
x.data_size(),
x.strides(),
x.flags());
}
return x; return x;
} else { } else {
array x_copy(x.shape(), x.dtype(), nullptr, {}); array x_copy(x.shape(), x.dtype(), nullptr, {});
copy(x, x_copy, CopyType::General); copy(x, x_copy, CopyType::General, s);
out.copy_shared_buffer(x_copy);
return x_copy; return x_copy;
} }
}; };
array in = check_input(std::move(inputs[0]));
if (in.is_donatable()) { auto in = set_output(inputs[0]);
out.copy_shared_buffer(in);
} else {
out.set_data(
allocator::malloc_or_wait(in.data_size() * in.itemsize()),
in.data_size(),
in.strides(),
in.flags());
}
switch (in.dtype()) { switch (in.dtype()) {
case bool_:
case uint8:
case uint16:
case uint32:
case uint64:
case int8:
case int16:
case int32:
case int64:
throw std::runtime_error(
"Softmax is defined only for floating point types");
break;
case float32: case float32:
softmax<float, float>(in, out); softmax<float, float>(in, out, stream());
break; break;
case float16: case float16:
if (precise_) { if (precise_) {
softmax<float16_t, float>(in, out); softmax<float16_t, float>(in, out, stream());
} else { } else {
softmax<float16_t, float16_t>(in, out); softmax<float16_t, float16_t>(in, out, stream());
} }
break; break;
case bfloat16: case bfloat16:
if (precise_) { if (precise_) {
softmax<bfloat16_t, float>(in, out); softmax<bfloat16_t, float>(in, out, stream());
} else { } else {
softmax<bfloat16_t, bfloat16_t>(in, out); softmax<bfloat16_t, bfloat16_t>(in, out, stream());
} }
break; break;
case float64: case float64:
softmax<double, double>(in, out); softmax<double, double>(in, out, stream());
break; break;
case complex64: default:
throw std::invalid_argument( throw std::runtime_error(
"[Softmax] Not yet implemented for complex64"); "[softmax] Only defined for floating point types.");
break; break;
} }
} }
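The Softmax kernel above is the standard numerically stable softmax done in three passes over each row of length M: find the row max, accumulate exp(x - max) (storing the exponentials when the input and accumulation types match), then multiply by the reciprocal of the sum. SIMD and mixed-precision handling aside, a scalar sketch of those passes:

// Scalar sketch of the three passes the SIMD softmax above performs per row:
// (1) row max, (2) exponentials + normalizer, (3) scale by 1/normalizer.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

void softmax_row(const float* in, float* out, int m) {
  float mx = *std::max_element(in, in + m); // pass 1: max for stability
  float norm = 0.f;
  for (int i = 0; i < m; ++i) {             // pass 2: exp and sum
    out[i] = std::exp(in[i] - mx);
    norm += out[i];
  }
  float inv = 1.f / norm;
  for (int i = 0; i < m; ++i) {             // pass 3: normalize
    out[i] *= inv;
  }
}

int main() {
  std::vector<float> x = {1.f, 2.f, 3.f}, y(3);
  softmax_row(x.data(), y.data(), 3);
  std::printf("%f %f %f\n", y[0], y[1], y[2]); // ~0.090 0.245 0.665
}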


@@ -7,6 +7,7 @@
 #include "mlx/backend/common/utils.h"
 #include "mlx/backend/cpu/copy.h"
+#include "mlx/backend/cpu/encoder.h"
 #include "mlx/primitives.h"
@@ -103,16 +104,12 @@ struct StridedIterator {
T* ptr_; T* ptr_;
}; };
template <typename T, typename IdxT = uint32_t> template <typename T>
void sort(const array& in, array& out, int axis) { void sort(array& out, int axis) {
// Copy input to output
CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General;
copy(in, out, ctype);
// Get axis, shape and stride info // Get axis, shape and stride info
axis = axis < 0 ? axis + in.ndim() : axis; axis = axis < 0 ? axis + out.ndim() : axis;
size_t in_size = in.flags().contiguous ? in.data_size() : in.size(); size_t in_size = out.size();
size_t n_rows = in_size / in.shape(axis); size_t n_rows = in_size / out.shape(axis);
auto remaining_shape = out.shape(); auto remaining_shape = out.shape();
remaining_shape.erase(remaining_shape.begin() + axis); remaining_shape.erase(remaining_shape.begin() + axis);
@@ -126,8 +123,9 @@ void sort(const array& in, array& out, int axis) {
// Perform sorting in place // Perform sorting in place
ContiguousIterator src_it( ContiguousIterator src_it(
remaining_shape, remaining_strides, remaining_shape.size()); remaining_shape, remaining_strides, remaining_shape.size());
auto out_ptr = out.data<T>();
for (int i = 0; i < n_rows; i++) { for (int i = 0; i < n_rows; i++) {
T* data_ptr = out.data<T>() + src_it.loc; T* data_ptr = out_ptr + src_it.loc;
StridedIterator st(data_ptr, axis_stride, 0); StridedIterator st(data_ptr, axis_stride, 0);
StridedIterator ed(data_ptr, axis_stride, axis_size); StridedIterator ed(data_ptr, axis_stride, axis_size);
@@ -139,9 +137,6 @@ void sort(const array& in, array& out, int axis) {
template <typename T, typename IdxT = uint32_t> template <typename T, typename IdxT = uint32_t>
void argsort(const array& in, array& out, int axis) { void argsort(const array& in, array& out, int axis) {
// Allocate output
out.set_data(allocator::malloc_or_wait(out.nbytes()));
// Get axis, shape and stride info // Get axis, shape and stride info
axis = axis < 0 ? axis + in.ndim() : axis; axis = axis < 0 ? axis + in.ndim() : axis;
size_t n_rows = in.size() / in.shape(axis); size_t n_rows = in.size() / in.shape(axis);
@@ -167,9 +162,12 @@ void argsort(const array& in, array& out, int axis) {
in_remaining_shape, in_remaining_strides, in_remaining_shape.size()); in_remaining_shape, in_remaining_strides, in_remaining_shape.size());
ContiguousIterator out_it( ContiguousIterator out_it(
out_remaining_shape, out_remaining_strides, out_remaining_shape.size()); out_remaining_shape, out_remaining_strides, out_remaining_shape.size());
auto in_ptr = in.data<T>();
auto out_ptr = out.data<IdxT>();
for (int i = 0; i < n_rows; i++) { for (int i = 0; i < n_rows; i++) {
const T* data_ptr = in.data<T>() + in_it.loc; const T* data_ptr = in_ptr + in_it.loc;
IdxT* idx_ptr = out.data<IdxT>() + out_it.loc; IdxT* idx_ptr = out_ptr + out_it.loc;
in_it.step(); in_it.step();
out_it.step(); out_it.step();
@@ -191,33 +189,30 @@ void argsort(const array& in, array& out, int axis) {
} }
} }
template <typename T, typename IdxT = uint32_t> template <typename T>
void partition(const array& in, array& out, int axis, int kth) { void partition(array& out, int axis, int kth) {
// Copy input to output
CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General;
copy(in, out, ctype);
// Get axis, shape and stride info // Get axis, shape and stride info
axis = axis < 0 ? axis + in.ndim() : axis; axis = axis < 0 ? axis + out.ndim() : axis;
size_t in_size = in.flags().contiguous ? in.data_size() : in.size(); size_t in_size = out.size();
size_t n_rows = in_size / in.shape(axis); size_t n_rows = in_size / out.shape(axis);
auto remaining_shape = in.shape(); auto remaining_shape = out.shape();
remaining_shape.erase(remaining_shape.begin() + axis); remaining_shape.erase(remaining_shape.begin() + axis);
auto remaining_strides = in.strides(); auto remaining_strides = out.strides();
remaining_strides.erase(remaining_strides.begin() + axis); remaining_strides.erase(remaining_strides.begin() + axis);
auto axis_stride = in.strides()[axis]; auto axis_stride = out.strides()[axis];
int axis_size = in.shape(axis); int axis_size = out.shape(axis);
kth = kth < 0 ? kth + axis_size : kth; kth = kth < 0 ? kth + axis_size : kth;
// Perform partition in place // Perform partition in place
ContiguousIterator src_it( ContiguousIterator src_it(
remaining_shape, remaining_strides, remaining_shape.size()); remaining_shape, remaining_strides, remaining_shape.size());
auto out_ptr = out.data<T>();
for (int i = 0; i < n_rows; i++) { for (int i = 0; i < n_rows; i++) {
T* data_ptr = out.data<T>() + src_it.loc; T* data_ptr = out_ptr + src_it.loc;
src_it.step(); src_it.step();
StridedIterator st(data_ptr, axis_stride, 0); StridedIterator st(data_ptr, axis_stride, 0);
@@ -230,9 +225,6 @@ void partition(const array& in, array& out, int axis, int kth) {
template <typename T, typename IdxT = uint32_t> template <typename T, typename IdxT = uint32_t>
void argpartition(const array& in, array& out, int axis, int kth) { void argpartition(const array& in, array& out, int axis, int kth) {
// Allocate output
out.set_data(allocator::malloc_or_wait(out.nbytes()));
// Get axis, shape and stride info // Get axis, shape and stride info
axis = axis < 0 ? axis + in.ndim() : axis; axis = axis < 0 ? axis + in.ndim() : axis;
size_t n_rows = in.size() / in.shape(axis); size_t n_rows = in.size() / in.shape(axis);
@@ -260,9 +252,13 @@ void argpartition(const array& in, array& out, int axis, int kth) {
in_remaining_shape, in_remaining_strides, in_remaining_shape.size()); in_remaining_shape, in_remaining_strides, in_remaining_shape.size());
ContiguousIterator out_it( ContiguousIterator out_it(
out_remaining_shape, out_remaining_strides, out_remaining_shape.size()); out_remaining_shape, out_remaining_strides, out_remaining_shape.size());
auto in_ptr = in.data<T>();
auto out_ptr = out.data<IdxT>();
for (int i = 0; i < n_rows; i++) { for (int i = 0; i < n_rows; i++) {
const T* data_ptr = in.data<T>() + in_it.loc; const T* data_ptr = in_ptr + in_it.loc;
IdxT* idx_ptr = out.data<IdxT>() + out_it.loc; IdxT* idx_ptr = out_ptr + out_it.loc;
in_it.step(); in_it.step();
out_it.step(); out_it.step();
@@ -291,144 +287,184 @@ void ArgSort::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
auto& in = inputs[0]; auto& in = inputs[0];
switch (in.dtype()) { // Allocate output
case bool_: out.set_data(allocator::malloc(out.nbytes()));
return argsort<bool>(in, out, axis_);
case uint8: auto& encoder = cpu::get_command_encoder(stream());
return argsort<uint8_t>(in, out, axis_); encoder.set_input_array(in);
case uint16: encoder.set_input_array(out);
return argsort<uint16_t>(in, out, axis_); encoder.dispatch([in = array::unsafe_weak_copy(in),
case uint32: out = array::unsafe_weak_copy(out),
return argsort<uint32_t>(in, out, axis_); axis_ = axis_]() mutable {
case uint64: switch (in.dtype()) {
return argsort<uint64_t>(in, out, axis_); case bool_:
case int8: return argsort<bool>(in, out, axis_);
return argsort<int8_t>(in, out, axis_); case uint8:
case int16: return argsort<uint8_t>(in, out, axis_);
return argsort<int16_t>(in, out, axis_); case uint16:
case int32: return argsort<uint16_t>(in, out, axis_);
return argsort<int32_t>(in, out, axis_); case uint32:
case int64: return argsort<uint32_t>(in, out, axis_);
return argsort<int64_t>(in, out, axis_); case uint64:
case float32: return argsort<uint64_t>(in, out, axis_);
return argsort<float>(in, out, axis_); case int8:
case float64: return argsort<int8_t>(in, out, axis_);
return argsort<double>(in, out, axis_); case int16:
case float16: return argsort<int16_t>(in, out, axis_);
return argsort<float16_t>(in, out, axis_); case int32:
case bfloat16: return argsort<int32_t>(in, out, axis_);
return argsort<bfloat16_t>(in, out, axis_); case int64:
case complex64: return argsort<int64_t>(in, out, axis_);
return argsort<complex64_t>(in, out, axis_); case float32:
} return argsort<float>(in, out, axis_);
case float64:
return argsort<double>(in, out, axis_);
case float16:
return argsort<float16_t>(in, out, axis_);
case bfloat16:
return argsort<bfloat16_t>(in, out, axis_);
case complex64:
return argsort<complex64_t>(in, out, axis_);
}
});
} }
void Sort::eval_cpu(const std::vector<array>& inputs, array& out) { void Sort::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
auto& in = inputs[0]; auto& in = inputs[0];
switch (in.dtype()) { // Copy input to output
case bool_: CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General;
return sort<bool>(in, out, axis_); copy(in, out, ctype, stream());
case uint8:
return sort<uint8_t>(in, out, axis_); auto& encoder = cpu::get_command_encoder(stream());
case uint16: encoder.set_output_array(out);
return sort<uint16_t>(in, out, axis_); encoder.dispatch(
case uint32: [out = array::unsafe_weak_copy(out), axis_ = axis_]() mutable {
return sort<uint32_t>(in, out, axis_); switch (out.dtype()) {
case uint64: case bool_:
return sort<uint64_t>(in, out, axis_); return sort<bool>(out, axis_);
case int8: case uint8:
return sort<int8_t>(in, out, axis_); return sort<uint8_t>(out, axis_);
case int16: case uint16:
return sort<int16_t>(in, out, axis_); return sort<uint16_t>(out, axis_);
case int32: case uint32:
return sort<int32_t>(in, out, axis_); return sort<uint32_t>(out, axis_);
case int64: case uint64:
return sort<int64_t>(in, out, axis_); return sort<uint64_t>(out, axis_);
case float32: case int8:
return sort<float>(in, out, axis_); return sort<int8_t>(out, axis_);
case float64: case int16:
return sort<double>(in, out, axis_); return sort<int16_t>(out, axis_);
case float16: case int32:
return sort<float16_t>(in, out, axis_); return sort<int32_t>(out, axis_);
case bfloat16: case int64:
return sort<bfloat16_t>(in, out, axis_); return sort<int64_t>(out, axis_);
case complex64: case float32:
return sort<complex64_t>(in, out, axis_); return sort<float>(out, axis_);
} case float64:
return sort<double>(out, axis_);
case float16:
return sort<float16_t>(out, axis_);
case bfloat16:
return sort<bfloat16_t>(out, axis_);
case complex64:
return sort<complex64_t>(out, axis_);
}
});
} }
void ArgPartition::eval_cpu(const std::vector<array>& inputs, array& out) { void ArgPartition::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
auto& in = inputs[0]; auto& in = inputs[0];
switch (in.dtype()) { // Allocate output
case bool_: out.set_data(allocator::malloc(out.nbytes()));
return argpartition<bool>(in, out, axis_, kth_);
case uint8: auto& encoder = cpu::get_command_encoder(stream());
return argpartition<uint8_t>(in, out, axis_, kth_); encoder.set_input_array(in);
case uint16: encoder.set_input_array(out);
return argpartition<uint16_t>(in, out, axis_, kth_); encoder.dispatch([in = array::unsafe_weak_copy(in),
case uint32: out = array::unsafe_weak_copy(out),
return argpartition<uint32_t>(in, out, axis_, kth_); axis_ = axis_,
case uint64: kth_ = kth_]() mutable {
return argpartition<uint64_t>(in, out, axis_, kth_); switch (in.dtype()) {
case int8: case bool_:
return argpartition<int8_t>(in, out, axis_, kth_); return argpartition<bool>(in, out, axis_, kth_);
case int16: case uint8:
return argpartition<int16_t>(in, out, axis_, kth_); return argpartition<uint8_t>(in, out, axis_, kth_);
case int32: case uint16:
return argpartition<int32_t>(in, out, axis_, kth_); return argpartition<uint16_t>(in, out, axis_, kth_);
case int64: case uint32:
return argpartition<int64_t>(in, out, axis_, kth_); return argpartition<uint32_t>(in, out, axis_, kth_);
case float32: case uint64:
return argpartition<float>(in, out, axis_, kth_); return argpartition<uint64_t>(in, out, axis_, kth_);
case float64: case int8:
return argpartition<double>(in, out, axis_, kth_); return argpartition<int8_t>(in, out, axis_, kth_);
case float16: case int16:
return argpartition<float16_t>(in, out, axis_, kth_); return argpartition<int16_t>(in, out, axis_, kth_);
case bfloat16: case int32:
return argpartition<bfloat16_t>(in, out, axis_, kth_); return argpartition<int32_t>(in, out, axis_, kth_);
case complex64: case int64:
return argpartition<complex64_t>(in, out, axis_, kth_); return argpartition<int64_t>(in, out, axis_, kth_);
} case float32:
return argpartition<float>(in, out, axis_, kth_);
case float64:
return argpartition<double>(in, out, axis_, kth_);
case float16:
return argpartition<float16_t>(in, out, axis_, kth_);
case bfloat16:
return argpartition<bfloat16_t>(in, out, axis_, kth_);
case complex64:
return argpartition<complex64_t>(in, out, axis_, kth_);
}
});
} }
void Partition::eval_cpu(const std::vector<array>& inputs, array& out) { void Partition::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
auto& in = inputs[0]; auto& in = inputs[0];
switch (in.dtype()) { // Copy input to output
case bool_: CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General;
return partition<bool>(in, out, axis_, kth_); copy(in, out, ctype, stream());
case uint8:
return partition<uint8_t>(in, out, axis_, kth_); auto& encoder = cpu::get_command_encoder(stream());
case uint16: encoder.set_output_array(out);
return partition<uint16_t>(in, out, axis_, kth_); encoder.dispatch([out = array::unsafe_weak_copy(out),
case uint32: axis_ = axis_,
return partition<uint32_t>(in, out, axis_, kth_); kth_ = kth_]() mutable {
case uint64: switch (out.dtype()) {
return partition<uint64_t>(in, out, axis_, kth_); case bool_:
case int8: return partition<bool>(out, axis_, kth_);
return partition<int8_t>(in, out, axis_, kth_); case uint8:
case int16: return partition<uint8_t>(out, axis_, kth_);
return partition<int16_t>(in, out, axis_, kth_); case uint16:
case int32: return partition<uint16_t>(out, axis_, kth_);
return partition<int32_t>(in, out, axis_, kth_); case uint32:
case int64: return partition<uint32_t>(out, axis_, kth_);
return partition<int64_t>(in, out, axis_, kth_); case uint64:
case float32: return partition<uint64_t>(out, axis_, kth_);
return partition<float>(in, out, axis_, kth_); case int8:
case float64: return partition<int8_t>(out, axis_, kth_);
return partition<double>(in, out, axis_, kth_); case int16:
case float16: return partition<int16_t>(out, axis_, kth_);
return partition<float16_t>(in, out, axis_, kth_); case int32:
case bfloat16: return partition<int32_t>(out, axis_, kth_);
return partition<bfloat16_t>(in, out, axis_, kth_); case int64:
case complex64: return partition<int64_t>(out, axis_, kth_);
return partition<complex64_t>(in, out, axis_, kth_); case float32:
} return partition<float>(out, axis_, kth_);
case float64:
return partition<double>(out, axis_, kth_);
case float16:
return partition<float16_t>(out, axis_, kth_);
case bfloat16:
return partition<bfloat16_t>(out, axis_, kth_);
case complex64:
return partition<complex64_t>(out, axis_, kth_);
}
});
} }
} // namespace mlx::core } // namespace mlx::core
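The restructuring above keeps the algorithmic split: Sort and Partition copy the input into the output and then work in place along the chosen axis through strided iterators, while ArgSort and ArgPartition fill an index array and order it with a comparator keyed on the input values. On a single contiguous row the two ideas reduce to std::nth_element and an index sort (an illustrative sketch only; the real code walks rows with ContiguousIterator/StridedIterator):

// Copy-then-operate-in-place on one row: put the kth element in its sorted
// position (everything smaller lands before it), and build an argsort by
// sorting indices keyed on the values they point at.
#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
  std::vector<float> row = {5.f, 1.f, 4.f, 2.f, 3.f};
  int kth = 2;
  std::nth_element(row.begin(), row.begin() + kth, row.end());
  std::printf("kth value = %g\n", row[kth]); // 3, the third smallest

  std::vector<float> vals = {5.f, 1.f, 4.f, 2.f, 3.f};
  std::vector<int> idx = {0, 1, 2, 3, 4};
  std::sort(idx.begin(), idx.end(),
            [&](int a, int b) { return vals[a] < vals[b]; });
  for (int i : idx) {
    std::printf("%d ", i); // 1 3 4 2 0
  }
}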


@@ -2,13 +2,18 @@
 #include "mlx/allocator.h"
 #include "mlx/backend/cpu/copy.h"
+#include "mlx/backend/cpu/encoder.h"
 #include "mlx/backend/cpu/lapack.h"
 #include "mlx/primitives.h"

 namespace mlx::core {
template <typename T> template <typename T>
void svd_impl(const array& a, T* u_data, T* s_data, T* vt_data) { void svd_impl(
const array& a,
std::vector<array>& outputs,
bool compute_uv,
Stream stream) {
// Lapack uses the column-major convention. To avoid having to transpose // Lapack uses the column-major convention. To avoid having to transpose
// the input and then transpose the outputs, we swap the indices/sizes of the // the input and then transpose the outputs, we swap the indices/sizes of the
// matrices and take advantage of the following identity (see // matrices and take advantage of the following identity (see
@@ -22,75 +27,80 @@ void svd_impl(const array& a, T* u_data, T* s_data, T* vt_data) {
const int N = a.shape(-1); const int N = a.shape(-1);
const int K = std::min(M, N); const int K = std::min(M, N);
// A of shape M x N. The leading dimension is N since lapack receives Aᵀ.
const int lda = N;
// U of shape M x M. (N x N in lapack).
const int ldu = N;
// Vᵀ of shape N x N. (M x M in lapack).
const int ldvt = M;
size_t num_matrices = a.size() / (M * N); size_t num_matrices = a.size() / (M * N);
// lapack clobbers the input, so we have to make a copy. // lapack clobbers the input, so we have to make a copy.
array in(a.shape(), a.dtype(), nullptr, {}); array in(a.shape(), a.dtype(), nullptr, {});
copy(a, in, a.flags().row_contiguous ? CopyType::Vector : CopyType::General); copy(
a,
in,
a.flags().row_contiguous ? CopyType::Vector : CopyType::General,
stream);
auto job_u = (u_data && vt_data) ? "V" : "N"; // Allocate outputs.
auto job_vt = (u_data && vt_data) ? "V" : "N"; auto& encoder = cpu::get_command_encoder(stream);
static constexpr auto range = "A"; encoder.set_input_array(a);
auto in_ptr = in.data<T>();
T* u_ptr;
T* s_ptr;
T* vt_ptr;
// Will contain the number of singular values after the call has returned. if (compute_uv) {
int ns = 0; array& u = outputs[0];
T workspace_dimension = 0; array& s = outputs[1];
array& vt = outputs[2];
// Will contain the indices of eigenvectors that failed to converge (not used u.set_data(allocator::malloc(u.nbytes()));
// here but required by lapack). s.set_data(allocator::malloc(s.nbytes()));
auto iwork = array::Data{allocator::malloc_or_wait(sizeof(int) * 12 * K)}; vt.set_data(allocator::malloc(vt.nbytes()));
static const int lwork_query = -1; encoder.set_output_array(u);
encoder.set_output_array(s);
encoder.set_output_array(vt);
static const int ignored_int = 0; s_ptr = s.data<T>();
static const T ignored_float = 0; u_ptr = u.data<T>();
static T ignored_output = 0; vt_ptr = vt.data<T>();
} else {
array& s = outputs[0];
int info; s.set_data(allocator::malloc(s.nbytes()));
// Compute workspace size. encoder.set_output_array(s);
gesvdx<T>(
/* jobu = */ job_u,
/* jobvt = */ job_vt,
/* range = */ range,
// M and N are swapped since lapack expects column-major.
/* m = */ &N,
/* n = */ &M,
/* a = */ nullptr,
/* lda = */ &lda,
/* vl = */ &ignored_float,
/* vu = */ &ignored_float,
/* il = */ &ignored_int,
/* iu = */ &ignored_int,
/* ns = */ &ns,
/* s = */ nullptr,
/* u = */ nullptr,
/* ldu = */ &ldu,
/* vt = */ nullptr,
/* ldvt = */ &ldvt,
/* work = */ &workspace_dimension,
/* lwork = */ &lwork_query,
/* iwork = */ static_cast<int*>(iwork.buffer.raw_ptr()),
/* info = */ &info);
if (info != 0) { s_ptr = s.data<T>();
std::stringstream ss; u_ptr = nullptr;
ss << "[SVD::eval_cpu] workspace calculation failed with code " << info; vt_ptr = nullptr;
throw std::runtime_error(ss.str());
} }
const int lwork = workspace_dimension; encoder.dispatch([in_ptr, u_ptr, s_ptr, vt_ptr, M, N, K, num_matrices]() {
auto scratch = array::Data{allocator::malloc_or_wait(sizeof(T) * lwork)}; // A of shape M x N. The leading dimension is N since lapack receives Aᵀ.
const int lda = N;
// U of shape M x M. (N x N in lapack).
const int ldu = N;
// Vᵀ of shape N x N. (M x M in lapack).
const int ldvt = M;
// Loop over matrices. auto job_u = (u_ptr) ? "V" : "N";
for (int i = 0; i < num_matrices; i++) { auto job_vt = (u_ptr) ? "V" : "N";
static constexpr auto range = "A";
// Will contain the number of singular values after the call has returned.
int ns = 0;
T workspace_dimension = 0;
// Will contain the indices of eigenvectors that failed to converge (not
// used here but required by lapack).
auto iwork = array::Data{allocator::malloc(sizeof(int) * 12 * K)};
static const int lwork_query = -1;
static const int ignored_int = 0;
static const T ignored_float = 0;
int info;
// Compute workspace size.
gesvdx<T>( gesvdx<T>(
/* jobu = */ job_u, /* jobu = */ job_u,
/* jobvt = */ job_vt, /* jobvt = */ job_vt,
@@ -98,70 +108,93 @@ void svd_impl(const array& a, T* u_data, T* s_data, T* vt_data) {
// M and N are swapped since lapack expects column-major. // M and N are swapped since lapack expects column-major.
/* m = */ &N, /* m = */ &N,
/* n = */ &M, /* n = */ &M,
/* a = */ in.data<T>() + M * N * i, /* a = */ nullptr,
/* lda = */ &lda, /* lda = */ &lda,
/* vl = */ &ignored_float, /* vl = */ &ignored_float,
/* vu = */ &ignored_float, /* vu = */ &ignored_float,
/* il = */ &ignored_int, /* il = */ &ignored_int,
/* iu = */ &ignored_int, /* iu = */ &ignored_int,
/* ns = */ &ns, /* ns = */ &ns,
/* s = */ s_data + K * i, /* s = */ nullptr,
// According to the identity above, lapack will write Vᵀᵀ as U. /* u = */ nullptr,
/* u = */ vt_data ? vt_data + N * N * i : &ignored_output,
/* ldu = */ &ldu, /* ldu = */ &ldu,
// According to the identity above, lapack will write Uᵀ as Vᵀ. /* vt = */ nullptr,
/* vt = */ u_data ? u_data + M * M * i : &ignored_output,
/* ldvt = */ &ldvt, /* ldvt = */ &ldvt,
/* work = */ static_cast<T*>(scratch.buffer.raw_ptr()), /* work = */ &workspace_dimension,
/* lwork = */ &lwork, /* lwork = */ &lwork_query,
/* iwork = */ static_cast<int*>(iwork.buffer.raw_ptr()), /* iwork = */ static_cast<int*>(iwork.buffer.raw_ptr()),
/* info = */ &info); /* info = */ &info);
if (info != 0) { if (info != 0) {
std::stringstream ss; std::stringstream ss;
ss << "[SVD::eval_cpu] failed with code " << info; ss << "[SVD::eval_cpu] workspace calculation failed with code " << info;
throw std::runtime_error(ss.str()); throw std::runtime_error(ss.str());
} }
if (ns != K) { const int lwork = workspace_dimension;
std::stringstream ss; auto scratch = array::Data{allocator::malloc(sizeof(T) * lwork)};
ss << "[SVD::eval_cpu] expected " << K << " singular values, but " << ns
<< " were computed."; // Loop over matrices.
throw std::runtime_error(ss.str()); for (int i = 0; i < num_matrices; i++) {
gesvdx<T>(
/* jobu = */ job_u,
/* jobvt = */ job_vt,
/* range = */ range,
// M and N are swapped since lapack expects column-major.
/* m = */ &N,
/* n = */ &M,
/* a = */ in_ptr + M * N * i,
/* lda = */ &lda,
/* vl = */ &ignored_float,
/* vu = */ &ignored_float,
/* il = */ &ignored_int,
/* iu = */ &ignored_int,
/* ns = */ &ns,
/* s = */ s_ptr + K * i,
// According to the identity above, lapack will write Vᵀᵀ as U.
/* u = */ vt_ptr ? vt_ptr + N * N * i : nullptr,
/* ldu = */ &ldu,
// According to the identity above, lapack will write Uᵀ as Vᵀ.
/* vt = */ u_ptr ? u_ptr + M * M * i : nullptr,
/* ldvt = */ &ldvt,
/* work = */ static_cast<T*>(scratch.buffer.raw_ptr()),
/* lwork = */ &lwork,
/* iwork = */ static_cast<int*>(iwork.buffer.raw_ptr()),
/* info = */ &info);
if (info != 0) {
std::stringstream ss;
ss << "svd_impl: sgesvdx_ failed with code " << info;
throw std::runtime_error(ss.str());
}
if (ns != K) {
std::stringstream ss;
ss << "svd_impl: expected " << K << " singular values, but " << ns
<< " were computed.";
throw std::runtime_error(ss.str());
}
} }
} });
encoder.add_temporary(in);
} }
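The pointer swapping in svd_impl rests on the transpose identity the code comments allude to (this only restates them): if A = U Σ Vᵀ, then

    Aᵀ = (U Σ Vᵀ)ᵀ = V Σ Uᵀ.

LAPACK reads the row-major buffer as Aᵀ, so the matrix it writes as its "U" is V and the one it writes as its "Vᵀ" is Uᵀ. Reading those column-major outputs back as row-major transposes each of them once more, which is why routing LAPACK's U output through vt_ptr yields Vᵀ and routing LAPACK's Vᵀ output through u_ptr yields U, while the singular values are unaffected by the swap.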
template <typename T> template <typename T>
void compute_svd(const array& a, bool compute_uv, std::vector<array>& outputs) { void compute_svd(
if (compute_uv) { const array& a,
array& u = outputs[0]; bool compute_uv,
array& s = outputs[1]; std::vector<array>& outputs,
array& vt = outputs[2]; Stream stream) {}
u.set_data(allocator::malloc_or_wait(u.nbytes()));
s.set_data(allocator::malloc_or_wait(s.nbytes()));
vt.set_data(allocator::malloc_or_wait(vt.nbytes()));
svd_impl<T>(a, u.data<T>(), s.data<T>(), vt.data<T>());
} else {
array& s = outputs[0];
s.set_data(allocator::malloc_or_wait(s.nbytes()));
svd_impl<T>(a, nullptr, s.data<T>(), nullptr);
}
}
void SVD::eval_cpu( void SVD::eval_cpu(
const std::vector<array>& inputs, const std::vector<array>& inputs,
std::vector<array>& outputs) { std::vector<array>& outputs) {
switch (inputs[0].dtype()) { switch (inputs[0].dtype()) {
case float32: case float32:
compute_svd<float>(inputs[0], compute_uv_, outputs); svd_impl<float>(inputs[0], outputs, compute_uv_, stream());
break; break;
case float64: case float64:
compute_svd<double>(inputs[0], compute_uv_, outputs); svd_impl<double>(inputs[0], outputs, compute_uv_, stream());
break; break;
default: default:
throw std::runtime_error( throw std::runtime_error(


@@ -1,10 +1,10 @@
 // Copyright © 2023 Apple Inc.

 #pragma once

-#include "mlx/allocator.h"
 #include "mlx/array.h"
 #include "mlx/backend/common/ternary.h"
 #include "mlx/backend/common/utils.h"
+#include "mlx/backend/cpu/encoder.h"

 namespace mlx::core {
@@ -53,22 +53,18 @@ void ternary_op_dims(
template <typename T1, typename T2, typename T3, typename U, typename Op> template <typename T1, typename T2, typename T3, typename U, typename Op>
void ternary_op_dispatch_dims( void ternary_op_dispatch_dims(
const array& a, const T1* a_ptr,
const array& b, const T2* b_ptr,
const array& c, const T3* c_ptr,
array& out, U* out_ptr,
Op op) { Op op,
auto [shape, strides] = collapse_contiguous_dims( size_t size,
a.shape(), {a.strides(), b.strides(), c.strides(), out.strides()}); Shape& shape,
std::vector<Strides>& strides) {
const auto& a_strides = strides[0]; const auto& a_strides = strides[0];
const auto& b_strides = strides[1]; const auto& b_strides = strides[1];
const auto& c_strides = strides[2]; const auto& c_strides = strides[2];
const auto& out_strides = strides[3]; const auto& out_strides = strides[3];
const T1* a_ptr = a.data<T1>();
const T2* b_ptr = b.data<T2>();
const T3* c_ptr = c.data<T3>();
U* out_ptr = out.data<T3>();
int ndim = shape.size(); int ndim = shape.size();
switch (ndim) { switch (ndim) {
case 1: case 1:
@@ -105,7 +101,7 @@ void ternary_op_dispatch_dims(
ContiguousIterator b_it(shape, b_strides, ndim - 2); ContiguousIterator b_it(shape, b_strides, ndim - 2);
ContiguousIterator c_it(shape, c_strides, ndim - 2); ContiguousIterator c_it(shape, c_strides, ndim - 2);
auto stride = out_strides[ndim - 3]; auto stride = out_strides[ndim - 3];
for (size_t elem = 0; elem < a.size(); elem += stride) { for (size_t elem = 0; elem < size; elem += stride) {
ternary_op_dims<T1, T2, T3, U, Op, 2>( ternary_op_dims<T1, T2, T3, U, Op, 2>(
a_ptr + a_it.loc, a_ptr + a_it.loc,
b_ptr + b_it.loc, b_ptr + b_it.loc,
@@ -130,18 +126,16 @@ void ternary_op(
const array& b, const array& b,
const array& c, const array& c,
array& out, array& out,
Op op) { Op op,
TernaryOpType topt = get_ternary_op_type(a, b, c); TernaryOpType topt) {
set_ternary_op_output_data(a, b, c, out, topt); const T1* a_ptr = a.data<T1>();
const T2* b_ptr = b.data<T2>();
const T3* c_ptr = c.data<T3>();
U* out_ptr = out.data<U>();
// The full computation is scalar-scalar-scalar so we call the base op once.
if (topt == TernaryOpType::ScalarScalarScalar) { if (topt == TernaryOpType::ScalarScalarScalar) {
*(out.data<U>()) = op(*a.data<T1>(), *b.data<T2>(), *c.data<T3>()); *out_ptr = op(*a_ptr, *b_ptr, *c_ptr);
} else if (topt == TernaryOpType::VectorVectorVector) { } else if (topt == TernaryOpType::VectorVectorVector) {
const T1* a_ptr = a.data<T1>();
const T2* b_ptr = b.data<T2>();
const T3* c_ptr = c.data<T3>();
U* out_ptr = out.data<U>();
for (size_t i = 0; i < out.size(); ++i) { for (size_t i = 0; i < out.size(); ++i) {
*out_ptr = op(*a_ptr, *b_ptr, *c_ptr); *out_ptr = op(*a_ptr, *b_ptr, *c_ptr);
a_ptr++; a_ptr++;
@@ -150,7 +144,10 @@ void ternary_op(
out_ptr++; out_ptr++;
} }
} else { } else {
ternary_op_dispatch_dims<T1, T2, T3, U>(a, b, c, out, op); auto [shape, strides] = collapse_contiguous_dims(
a.shape(), {a.strides(), b.strides(), c.strides(), out.strides()});
ternary_op_dispatch_dims<T1, T2, T3, U>(
a_ptr, b_ptr, c_ptr, out_ptr, op, out.size(), shape, strides);
} }
} }
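After the refactor, ternary_op receives raw pointers plus the precomputed TernaryOpType and keeps three paths: a single scalar application for ScalarScalarScalar, a flat loop for VectorVectorVector, and the collapse_contiguous_dims strided walk for everything else. The strided walk leans on the usual trick that a broadcast operand simply carries stride 0; a toy 1-D version of that idea (not the real ternary_op_dims):

// Stride-0 "broadcast" strides let one loop cover scalar and vector operands
// alike; illustrative only.
#include <cstdio>

template <typename T>
void where_1d(const bool* c, long cs, // data pointer + element stride
              const T* a, long as,
              const T* b, long bs,
              T* out, long n) {
  for (long i = 0; i < n; ++i) {
    out[i] = c[i * cs] ? a[i * as] : b[i * bs];
  }
}

int main() {
  bool cond[4] = {true, false, true, false};
  float a = 1.f;                        // scalar, broadcast with stride 0
  float b[4] = {10.f, 20.f, 30.f, 40.f};
  float out[4];
  where_1d(cond, 1, &a, 0, b, 1, out, 4);
  for (float v : out) {
    std::printf("%g ", v); // 1 20 1 40
  }
}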


@@ -1,5 +1,8 @@
 // Copyright © 2024 Apple Inc.

+// Required for using M_LN2 in MSVC.
+#define _USE_MATH_DEFINES
+
 #include <cassert>

 #include "mlx/backend/cpu/unary.h"
@@ -14,88 +17,57 @@ void Abs::eval_cpu(const std::vector<array>& inputs, array& out) {
// No-op for unsigned types // No-op for unsigned types
out.copy_shared_buffer(in); out.copy_shared_buffer(in);
} else { } else {
auto op = detail::Abs{}; unary_signed(in, out, detail::Abs(), stream());
switch (out.dtype()) {
case int8:
unary_op<int8_t>(in, out, op);
break;
case int16:
unary_op<int16_t>(in, out, op);
break;
case int32:
unary_op<int32_t>(in, out, op);
break;
case int64:
unary_op<int64_t>(in, out, op);
break;
case float16:
unary_op<float16_t>(in, out, op);
break;
case float32:
unary_op<float>(in, out, op);
break;
case float64:
unary_op<double>(in, out, op);
break;
case bfloat16:
unary_op<bfloat16_t>(in, out, op);
break;
case complex64:
unary_op<complex64_t>(in, out, op);
break;
default:
throw std::runtime_error("[Abs] Called on unsigned type");
}
} }
} }
void ArcCos::eval_cpu(const std::vector<array>& inputs, array& out) { void ArcCos::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
const auto& in = inputs[0]; const auto& in = inputs[0];
unary_fp(in, out, detail::ArcCos()); unary_fp(in, out, detail::ArcCos(), stream());
} }
void ArcCosh::eval_cpu(const std::vector<array>& inputs, array& out) { void ArcCosh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
const auto& in = inputs[0]; const auto& in = inputs[0];
unary_fp(in, out, detail::ArcCosh()); unary_fp(in, out, detail::ArcCosh(), stream());
} }
void ArcSin::eval_cpu(const std::vector<array>& inputs, array& out) { void ArcSin::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
const auto& in = inputs[0]; const auto& in = inputs[0];
unary_fp(in, out, detail::ArcSin()); unary_fp(in, out, detail::ArcSin(), stream());
} }
void ArcSinh::eval_cpu(const std::vector<array>& inputs, array& out) { void ArcSinh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
const auto& in = inputs[0]; const auto& in = inputs[0];
unary_fp(in, out, detail::ArcSinh()); unary_fp(in, out, detail::ArcSinh(), stream());
} }
void ArcTan::eval_cpu(const std::vector<array>& inputs, array& out) { void ArcTan::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
const auto& in = inputs[0]; const auto& in = inputs[0];
unary_fp(in, out, detail::ArcTan()); unary_fp(in, out, detail::ArcTan(), stream());
} }
void ArcTanh::eval_cpu(const std::vector<array>& inputs, array& out) { void ArcTanh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
const auto& in = inputs[0]; const auto& in = inputs[0];
unary_fp(in, out, detail::ArcTanh()); unary_fp(in, out, detail::ArcTanh(), stream());
} }
void BitwiseInvert::eval_cpu(const std::vector<array>& inputs, array& out) { void BitwiseInvert::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
const auto& in = inputs[0]; const auto& in = inputs[0];
unary_int(in, out, detail::BitwiseInvert()); unary_int(in, out, detail::BitwiseInvert(), stream());
} }
void Ceil::eval_cpu(const std::vector<array>& inputs, array& out) { void Ceil::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
auto& in = inputs[0]; auto& in = inputs[0];
if (issubdtype(in.dtype(), inexact)) { if (issubdtype(in.dtype(), inexact)) {
unary_fp(in, out, detail::Ceil()); unary_fp(in, out, detail::Ceil(), stream());
} else { } else {
// No-op integer types // No-op integer types
out.copy_shared_buffer(in); out.copy_shared_buffer(in);
@@ -104,84 +76,50 @@ void Ceil::eval_cpu(const std::vector<array>& inputs, array& out) {
void Conjugate::eval_cpu(const std::vector<array>& inputs, array& out) { void Conjugate::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
unary_op<complex64_t>(inputs[0], out, detail::Conjugate()); unary_complex(inputs[0], out, detail::Conjugate(), stream());
} }
void Cos::eval_cpu(const std::vector<array>& inputs, array& out) { void Cos::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
const auto& in = inputs[0]; const auto& in = inputs[0];
unary_fp(in, out, detail::Cos()); unary_fp(in, out, detail::Cos(), stream());
} }
void Cosh::eval_cpu(const std::vector<array>& inputs, array& out) { void Cosh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
const auto& in = inputs[0]; const auto& in = inputs[0];
unary_fp(in, out, detail::Cosh()); unary_fp(in, out, detail::Cosh(), stream());
} }
void Erf::eval_cpu(const std::vector<array>& inputs, array& out) { void Erf::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
const auto& in = inputs[0]; const auto& in = inputs[0];
switch (out.dtype()) { unary_real_fp(in, out, detail::Erf(), stream());
case float32:
unary_op<float>(in, out, detail::Erf());
break;
case float16:
unary_op<float16_t>(in, out, detail::Erf());
break;
case float64:
unary_op<double>(in, out, detail::Erf());
break;
case bfloat16:
unary_op<bfloat16_t>(in, out, detail::Erf());
break;
default:
throw std::invalid_argument(
"[erf] Error function only defined for arrays"
" with real floating point type.");
}
} }
void ErfInv::eval_cpu(const std::vector<array>& inputs, array& out) { void ErfInv::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
const auto& in = inputs[0]; const auto& in = inputs[0];
switch (out.dtype()) { unary_real_fp(in, out, detail::ErfInv(), stream());
case float32:
unary_op<float>(in, out, detail::ErfInv());
break;
case float16:
unary_op<float16_t>(in, out, detail::ErfInv());
break;
case float64:
unary_op<double>(in, out, detail::ErfInv());
break;
case bfloat16:
unary_op<bfloat16_t>(in, out, detail::ErfInv());
break;
default:
throw std::invalid_argument(
"[erf_inv] Inverse error function only defined for arrays"
" with real floating point type.");
}
} }
void Exp::eval_cpu(const std::vector<array>& inputs, array& out) { void Exp::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
const auto& in = inputs[0]; const auto& in = inputs[0];
unary_fp(in, out, detail::Exp()); unary_fp(in, out, detail::Exp(), stream());
} }
void Expm1::eval_cpu(const std::vector<array>& inputs, array& out) { void Expm1::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
const auto& in = inputs[0]; const auto& in = inputs[0];
unary_fp(in, out, detail::Expm1()); unary_fp(in, out, detail::Expm1(), stream());
} }
void Floor::eval_cpu(const std::vector<array>& inputs, array& out) { void Floor::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
auto& in = inputs[0]; auto& in = inputs[0];
if (issubdtype(in.dtype(), inexact)) { if (issubdtype(in.dtype(), inexact)) {
unary_fp(in, out, detail::Floor()); unary_fp(in, out, detail::Floor(), stream());
} else { } else {
// No-op integer types // No-op integer types
out.copy_shared_buffer(in); out.copy_shared_buffer(in);
@@ -189,7 +127,7 @@ void Floor::eval_cpu(const std::vector<array>& inputs, array& out) {
} }
void Imag::eval_cpu(const std::vector<array>& inputs, array& out) { void Imag::eval_cpu(const std::vector<array>& inputs, array& out) {
unary_op<complex64_t, float>(inputs[0], out, detail::Imag()); unary_complex_to_float(inputs[0], out, detail::Imag(), stream());
} }
void Log::eval_cpu(const std::vector<array>& inputs, array& out) { void Log::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -197,13 +135,13 @@ void Log::eval_cpu(const std::vector<array>& inputs, array& out) {
const auto& in = inputs[0]; const auto& in = inputs[0];
switch (base_) { switch (base_) {
case Base::e: case Base::e:
unary_fp(in, out, detail::Log()); unary_fp(in, out, detail::Log(), stream());
break; break;
case Base::two: case Base::two:
unary_fp(in, out, detail::Log2()); unary_fp(in, out, detail::Log2(), stream());
break; break;
case Base::ten: case Base::ten:
unary_fp(in, out, detail::Log10()); unary_fp(in, out, detail::Log10(), stream());
break; break;
} }
} }
@@ -211,30 +149,30 @@ void Log::eval_cpu(const std::vector<array>& inputs, array& out) {
void Log1p::eval_cpu(const std::vector<array>& inputs, array& out) { void Log1p::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
const auto& in = inputs[0]; const auto& in = inputs[0];
unary_fp(in, out, detail::Log1p()); unary_fp(in, out, detail::Log1p(), stream());
} }
void LogicalNot::eval_cpu(const std::vector<array>& inputs, array& out) { void LogicalNot::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
auto& in = inputs[0]; auto& in = inputs[0];
unary(in, out, detail::LogicalNot()); unary(in, out, detail::LogicalNot(), stream());
} }
void Negative::eval_cpu(const std::vector<array>& inputs, array& out) { void Negative::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
auto& in = inputs[0]; auto& in = inputs[0];
unary(in, out, detail::Negative()); unary(in, out, detail::Negative(), stream());
} }
void Real::eval_cpu(const std::vector<array>& inputs, array& out) { void Real::eval_cpu(const std::vector<array>& inputs, array& out) {
unary_op<complex64_t, float>(inputs[0], out, detail::Real()); unary_complex_to_float(inputs[0], out, detail::Real(), stream());
} }
void Round::eval_cpu(const std::vector<array>& inputs, array& out) { void Round::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
auto& in = inputs[0]; auto& in = inputs[0];
if (issubdtype(in.dtype(), inexact)) { if (issubdtype(in.dtype(), inexact)) {
unary_fp(in, out, detail::Round()); unary_fp(in, out, detail::Round(), stream());
} else { } else {
// No-op integer types // No-op integer types
out.copy_shared_buffer(in); out.copy_shared_buffer(in);
@@ -244,7 +182,7 @@ void Round::eval_cpu(const std::vector<array>& inputs, array& out) {
void Sigmoid::eval_cpu(const std::vector<array>& inputs, array& out) { void Sigmoid::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
const auto& in = inputs[0]; const auto& in = inputs[0];
unary_fp(in, out, detail::Sigmoid()); unary_fp(in, out, detail::Sigmoid(), stream());
} }
void Sign::eval_cpu(const std::vector<array>& inputs, array& out) { void Sign::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -253,48 +191,48 @@ void Sign::eval_cpu(const std::vector<array>& inputs, array& out) {
if (in.dtype() == bool_) { if (in.dtype() == bool_) {
out.copy_shared_buffer(in); out.copy_shared_buffer(in);
} else { } else {
unary(in, out, detail::Sign()); unary(in, out, detail::Sign(), stream());
} }
} }
void Sin::eval_cpu(const std::vector<array>& inputs, array& out) { void Sin::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
const auto& in = inputs[0]; const auto& in = inputs[0];
unary_fp(in, out, detail::Sin()); unary_fp(in, out, detail::Sin(), stream());
} }
void Sinh::eval_cpu(const std::vector<array>& inputs, array& out) { void Sinh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
const auto& in = inputs[0]; const auto& in = inputs[0];
unary_fp(in, out, detail::Sinh()); unary_fp(in, out, detail::Sinh(), stream());
} }
void Square::eval_cpu(const std::vector<array>& inputs, array& out) { void Square::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
auto& in = inputs[0]; auto& in = inputs[0];
unary(in, out, detail::Square()); unary(in, out, detail::Square(), stream());
} }
void Sqrt::eval_cpu(const std::vector<array>& inputs, array& out) { void Sqrt::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
auto& in = inputs[0]; auto& in = inputs[0];
if (recip_) { if (recip_) {
unary_fp(in, out, detail::Rsqrt()); unary_fp(in, out, detail::Rsqrt(), stream());
} else { } else {
unary_fp(in, out, detail::Sqrt()); unary_fp(in, out, detail::Sqrt(), stream());
} }
} }
void Tan::eval_cpu(const std::vector<array>& inputs, array& out) { void Tan::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
const auto& in = inputs[0]; const auto& in = inputs[0];
unary_fp(in, out, detail::Tan()); unary_fp(in, out, detail::Tan(), stream());
} }
void Tanh::eval_cpu(const std::vector<array>& inputs, array& out) { void Tanh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
const auto& in = inputs[0]; const auto& in = inputs[0];
unary_fp(in, out, detail::Tanh()); unary_fp(in, out, detail::Tanh(), stream());
} }
} // namespace mlx::core } // namespace mlx::core

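A small aside on the _USE_MATH_DEFINES define added at the top of this file: MSVC only exposes the math constants (M_LN2, M_PI, and friends) when that macro is defined before the math headers are included, so the define has to come first. A minimal reproduction:

#define _USE_MATH_DEFINES // must precede <cmath> for MSVC to define M_LN2
#include <cmath>
#include <iostream>

int main() {
  std::cout << M_LN2 << "\n"; // ~0.693147
}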

@@ -5,174 +5,296 @@
#include "mlx/allocator.h" #include "mlx/allocator.h"
#include "mlx/array.h" #include "mlx/array.h"
#include "mlx/backend/common/utils.h" #include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/simd/simd.h" #include "mlx/backend/cpu/simd/simd.h"
#include "mlx/utils.h" #include "mlx/utils.h"
namespace mlx::core { namespace mlx::core {
void set_unary_output_data(const array& in, array& out) { void set_unary_output_data(const array& in, array& out) {
if (is_donatable(in, out)) { if (in.flags().contiguous) {
out.copy_shared_buffer(in); if (is_donatable(in, out)) {
out.copy_shared_buffer(in);
} else {
auto size = in.data_size();
out.set_data(
allocator::malloc(size * out.itemsize()),
size,
in.strides(),
in.flags());
}
} else { } else {
auto size = in.data_size(); out.set_data(allocator::malloc(out.nbytes()));
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
} }
} }
template <typename T, typename U = T, typename Op> template <typename T, typename U = T, typename Op>
void unary_op(const T* a, U* out, Op op, size_t shape, size_t stride) { void unary_op(const T* a, U* out, size_t shape, size_t stride) {
for (size_t i = 0; i < shape; i += 1) { for (size_t i = 0; i < shape; i += 1) {
out[i] = op(*a); out[i] = Op{}(*a);
a += stride; a += stride;
} }
} }
template <typename T, typename U = T, typename Op> template <typename T, typename U = T, typename Op>
void unary_op(const array& a, array& out, Op op) { void unary_op(const array& a, array& out, Op) {
const T* a_ptr = a.data<T>(); const T* src = a.data<T>();
U* dst = out.data<U>();
auto ndim = a.ndim();
if (a.flags().contiguous) { if (a.flags().contiguous) {
set_unary_output_data(a, out); auto size = a.data_size();
U* dst = out.data<U>();
constexpr int N = simd::max_size<T>; constexpr int N = simd::max_size<T>;
size_t size = a.data_size();
while (size >= N) { while (size >= N) {
simd::store(dst, op(simd::load<T, N>(a_ptr))); simd::store(dst, Op{}(simd::load<T, N>(src)));
size -= N; size -= N;
a_ptr += N; src += N;
dst += N; dst += N;
} }
while (size > 0) { while (size > 0) {
*dst = op(*a_ptr); *dst = Op{}(*src);
size--; size--;
dst++; dst++;
a_ptr++; src++;
} }
} else { } else {
out.set_data(allocator::malloc_or_wait(out.nbytes())); size_t shape = ndim > 0 ? a.shape().back() : 1;
U* dst = out.data<U>(); size_t stride = ndim > 0 ? a.strides().back() : 1;
size_t shape = a.ndim() > 0 ? a.shape(-1) : 1; if (ndim <= 1) {
size_t stride = a.ndim() > 0 ? a.strides(-1) : 1; unary_op<T, U, Op>(src, dst, shape, stride);
if (a.ndim() <= 1) {
unary_op(a_ptr, dst, op, shape, stride);
return; return;
} }
ContiguousIterator it(a.shape(), a.strides(), a.ndim() - 1); auto it = ContiguousIterator(a.shape(), a.strides(), ndim - 1);
for (size_t elem = 0; elem < a.size(); elem += shape) { for (size_t elem = 0; elem < a.size(); elem += shape) {
unary_op(a_ptr + it.loc, dst + elem, op, shape, stride); unary_op<T, U, Op>(src + it.loc, dst + elem, shape, stride);
it.step(); it.step();
} }
} }
} }
template <typename Op> template <typename Op>
void unary(const array& a, array& out, Op op) { void unary(const array& a, array& out, Op op, Stream stream) {
switch (out.dtype()) { set_unary_output_data(a, out);
case bool_: auto& encoder = cpu::get_command_encoder(stream);
unary_op<bool>(a, out, op); encoder.set_input_array(a);
break; encoder.set_output_array(out);
case uint8: encoder.dispatch([a = array::unsafe_weak_copy(a),
unary_op<uint8_t>(a, out, op); out = array::unsafe_weak_copy(out),
break; op = op]() mutable {
case uint16: switch (out.dtype()) {
unary_op<uint16_t>(a, out, op); case bool_:
break; unary_op<bool>(a, out, op);
case uint32: break;
unary_op<uint32_t>(a, out, op); case uint8:
break; unary_op<uint8_t>(a, out, op);
case uint64: break;
unary_op<uint64_t>(a, out, op); case uint16:
break; unary_op<uint16_t>(a, out, op);
case int8: break;
unary_op<int8_t>(a, out, op); case uint32:
break; unary_op<uint32_t>(a, out, op);
case int16: break;
unary_op<int16_t>(a, out, op); case uint64:
break; unary_op<uint64_t>(a, out, op);
case int32: break;
unary_op<int32_t>(a, out, op); case int8:
break; unary_op<int8_t>(a, out, op);
case int64: break;
unary_op<int64_t>(a, out, op); case int16:
break; unary_op<int16_t>(a, out, op);
case float16: break;
unary_op<float16_t>(a, out, op); case int32:
break; unary_op<int32_t>(a, out, op);
case float32: break;
unary_op<float>(a, out, op); case int64:
break; unary_op<int64_t>(a, out, op);
case float64: break;
unary_op<double>(a, out, op); case float16:
break; unary_op<float16_t>(a, out, op);
case bfloat16: break;
unary_op<bfloat16_t>(a, out, op); case float32:
break; unary_op<float>(a, out, op);
case complex64: break;
unary_op<complex64_t>(a, out, op); case float64:
break; unary_op<double>(a, out, op);
} break;
case bfloat16:
unary_op<bfloat16_t>(a, out, op);
break;
case complex64:
unary_op<complex64_t>(a, out, op);
break;
}
});
} }
template <typename Op> template <typename Op>
void unary_fp(const array& a, array& out, Op op) { void unary_real_fp(const array& a, array& out, Op op, Stream stream) {
switch (out.dtype()) { set_unary_output_data(a, out);
case bfloat16: auto& encoder = cpu::get_command_encoder(stream);
unary_op<bfloat16_t>(a, out, op); encoder.set_input_array(a);
break; encoder.set_output_array(out);
case float16: encoder.dispatch([a = array::unsafe_weak_copy(a),
unary_op<float16_t>(a, out, op); out = array::unsafe_weak_copy(out),
break; op = op]() mutable {
case float32: switch (out.dtype()) {
unary_op<float>(a, out, op); case bfloat16:
break; unary_op<bfloat16_t>(a, out, op);
case float64: break;
unary_op<double>(a, out, op); case float16:
break; unary_op<float16_t>(a, out, op);
case complex64: break;
unary_op<complex64_t>(a, out, op); case float32:
break; unary_op<float>(a, out, op);
default: break;
std::ostringstream err; case float64:
err << "[unary_fp] Does not support " << out.dtype(); unary_op<double>(a, out, op);
throw std::runtime_error(err.str()); break;
} default:
std::ostringstream err;
err << "[unary_real] Does not support " << out.dtype();
throw std::runtime_error(err.str());
}
});
}
template <typename Op>
void unary_fp(const array& a, array& out, Op op, Stream stream) {
set_unary_output_data(a, out);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_output_array(out);
encoder.dispatch([a = array::unsafe_weak_copy(a),
out = array::unsafe_weak_copy(out),
op = op]() mutable {
switch (out.dtype()) {
case bfloat16:
unary_op<bfloat16_t>(a, out, op);
break;
case float16:
unary_op<float16_t>(a, out, op);
break;
case float32:
unary_op<float>(a, out, op);
break;
case float64:
unary_op<double>(a, out, op);
break;
case complex64:
unary_op<complex64_t>(a, out, op);
break;
default:
std::ostringstream err;
err << "[unary_fp] Does not support " << out.dtype();
throw std::runtime_error(err.str());
}
});
} }
template <typename Op> template <typename Op>
void unary_int(const array& a, array& out, Op op) { void unary_signed(const array& a, array& out, Op op, Stream stream) {
switch (out.dtype()) { set_unary_output_data(a, out);
case uint8: auto& encoder = cpu::get_command_encoder(stream);
unary_op<uint8_t>(a, out, op); encoder.set_input_array(a);
break; encoder.set_output_array(out);
case uint16: encoder.dispatch([a = array::unsafe_weak_copy(a),
unary_op<uint16_t>(a, out, op); out = array::unsafe_weak_copy(out),
break; op = op]() mutable {
case uint32: switch (out.dtype()) {
unary_op<uint32_t>(a, out, op); case int8:
break; unary_op<int8_t>(a, out, op);
case uint64: break;
unary_op<uint64_t>(a, out, op); case int16:
break; unary_op<int16_t>(a, out, op);
case int8: break;
unary_op<int8_t>(a, out, op); case int32:
break; unary_op<int32_t>(a, out, op);
case int16: break;
unary_op<int16_t>(a, out, op); case int64:
break; unary_op<int64_t>(a, out, op);
case int32: break;
unary_op<int32_t>(a, out, op); case float16:
break; unary_op<float16_t>(a, out, op);
case int64: break;
unary_op<int64_t>(a, out, op); case float32:
break; unary_op<float>(a, out, op);
default: break;
std::ostringstream err; case float64:
err << "[unary_int] Does not support " << out.dtype(); unary_op<double>(a, out, op);
throw std::runtime_error(err.str()); break;
} case bfloat16:
unary_op<bfloat16_t>(a, out, op);
break;
case complex64:
unary_op<complex64_t>(a, out, op);
break;
default:
throw std::runtime_error("[Abs] Called on unsigned type");
}
});
}
template <typename Op>
void unary_complex(const array& a, array& out, Op op, Stream stream) {
set_unary_output_data(a, out);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_output_array(out);
encoder.dispatch([a = array::unsafe_weak_copy(a),
out = array::unsafe_weak_copy(out),
op = op]() mutable { unary_op<complex64_t>(a, out, op); });
}
template <typename Op>
void unary_complex_to_float(const array& a, array& out, Op op, Stream stream) {
set_unary_output_data(a, out);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_output_array(out);
encoder.dispatch(
[a = array::unsafe_weak_copy(a),
out = array::unsafe_weak_copy(out),
op = op]() mutable { unary_op<complex64_t, float>(a, out, op); });
}
template <typename Op>
void unary_int(const array& a, array& out, Op op, Stream stream) {
set_unary_output_data(a, out);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_output_array(out);
encoder.dispatch([a = array::unsafe_weak_copy(a),
out = array::unsafe_weak_copy(out),
op = op]() mutable {
switch (out.dtype()) {
case uint8:
unary_op<uint8_t>(a, out, op);
break;
case uint16:
unary_op<uint16_t>(a, out, op);
break;
case uint32:
unary_op<uint32_t>(a, out, op);
break;
case uint64:
unary_op<uint64_t>(a, out, op);
break;
case int8:
unary_op<int8_t>(a, out, op);
break;
case int16:
unary_op<int16_t>(a, out, op);
break;
case int32:
unary_op<int32_t>(a, out, op);
break;
case int64:
unary_op<int64_t>(a, out, op);
break;
default:
std::ostringstream err;
err << "[unary_int] Does not support " << out.dtype();
throw std::runtime_error(err.str());
}
});
} }
} // namespace mlx::core } // namespace mlx::core

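The common thread in the rewritten helpers above is that the output is allocated immediately, the arrays are registered with the stream's CPU command encoder, and the dtype switch plus the actual loop run later inside a dispatched task (the real code captures unsafe weak copies by value so the task does not retain the arrays). A self-contained sketch of that shape of dispatch, with a plain task queue standing in for cpu::get_command_encoder:

#include <cstddef>
#include <functional>
#include <iostream>
#include <queue>
#include <vector>

std::queue<std::function<void()>> encoder_queue; // stand-in for the CPU command encoder

template <typename T, typename Op>
void unary_contiguous(const std::vector<T>& in, std::vector<T>& out, Op op) {
  out.resize(in.size()); // "set_unary_output_data": allocate the output up front
  // The loop itself runs later, when the scheduler drains the encoder's tasks.
  encoder_queue.push([&in, &out, op] {
    for (std::size_t i = 0; i < in.size(); ++i) {
      out[i] = op(in[i]);
    }
  });
}

int main() {
  std::vector<float> x{1.f, -2.f, 3.f};
  std::vector<float> y;
  unary_contiguous(x, y, [](float v) { return v * v; });
  while (!encoder_queue.empty()) { // the stream's worker would normally do this
    encoder_queue.front()();
    encoder_queue.pop();
  }
  std::cout << y[0] << " " << y[1] << " " << y[2] << "\n"; // 1 4 9
}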

@@ -86,13 +86,14 @@ struct Sign {
   template <int N, typename T>
   Simd<T, N> operator()(Simd<T, N> x) {
     auto z = Simd<T, N>{0};
+    auto o = Simd<T, N>{1};
+    auto m = Simd<T, N>{-1};
     if constexpr (std::is_unsigned_v<T>) {
-      return x != z;
+      return simd::select(x == z, z, o);
     } else if constexpr (std::is_same_v<T, complex64_t>) {
       return simd::select(x == z, x, Simd<T, N>(x / simd::abs(x)));
     } else {
-      return simd::select(
-          x < z, Simd<T, N>{-1}, simd::select(x > z, Simd<T, N>{1}, z));
+      return simd::select(x < z, m, simd::select(x > z, o, z));
     }
   }
   SINGLE()

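Why the unsigned branch changed: the old `return x != z;` hands back a comparison mask, and in SIMD lanes "true" is typically all bits set rather than the value 1, so the sign of a nonzero unsigned value could come out as a huge number instead of 1. The select form pins the result to 0 or 1. A scalar illustration of the intended semantics (a hedged sketch, not the SIMD code itself):

#include <cstdint>
#include <iostream>

// Scalar version of the fixed unsigned branch: 0 stays 0, everything else
// becomes 1 (mirrors simd::select(x == z, z, o)).
template <typename T>
T sign_of_unsigned(T x) {
  return x == T{0} ? T{0} : T{1};
}

int main() {
  std::cout << int(sign_of_unsigned<uint8_t>(0)) << " "
            << int(sign_of_unsigned<uint8_t>(200)) << "\n"; // prints: 0 1
}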

@@ -47,6 +47,7 @@ if(MLX_METAL_JIT)
   make_jit_source(binary)
   make_jit_source(binary_two)
   make_jit_source(fft kernels/fft/radix.h kernels/fft/readwrite.h)
+  make_jit_source(logsumexp)
   make_jit_source(ternary)
   make_jit_source(softmax)
   make_jit_source(scan)
@@ -60,6 +61,7 @@ if(MLX_METAL_JIT)
     kernels/steel/gemm/transforms.h)
   make_jit_source(steel/gemm/kernels/steel_gemm_fused)
   make_jit_source(steel/gemm/kernels/steel_gemm_masked kernels/steel/defines.h)
+  make_jit_source(steel/gemm/kernels/steel_gemm_gather)
   make_jit_source(steel/gemm/kernels/steel_gemm_splitk)
   make_jit_source(
     steel/conv/conv
@@ -95,6 +97,7 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/logsumexp.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp


@@ -3,6 +3,7 @@
#include "mlx/backend/metal/metal.h" #include "mlx/backend/metal/metal.h"
#include "mlx/backend/metal/metal_impl.h" #include "mlx/backend/metal/metal_impl.h"
#include "mlx/backend/metal/resident.h" #include "mlx/backend/metal/resident.h"
#include "mlx/memory.h"
#include <mach/vm_page_size.h> #include <mach/vm_page_size.h>
#include <unistd.h> #include <unistd.h>
@@ -20,6 +21,9 @@ Allocator& allocator() {
} }
void* Buffer::raw_ptr() { void* Buffer::raw_ptr() {
if (!ptr_) {
return nullptr;
}
return static_cast<MTL::Buffer*>(ptr_)->contents(); return static_cast<MTL::Buffer*>(ptr_)->contents();
} }
@@ -29,8 +33,11 @@ namespace metal {
namespace { namespace {
BufferCache::BufferCache(MTL::Device* device) BufferCache::BufferCache(ResidencySet& residency_set)
: device_(device), head_(nullptr), tail_(nullptr), pool_size_(0) {} : head_(nullptr),
tail_(nullptr),
pool_size_(0),
residency_set_(residency_set) {}
BufferCache::~BufferCache() { BufferCache::~BufferCache() {
auto pool = metal::new_scoped_memory_pool(); auto pool = metal::new_scoped_memory_pool();
@@ -41,6 +48,9 @@ int BufferCache::clear() {
int n_release = 0; int n_release = 0;
for (auto& [size, holder] : buffer_pool_) { for (auto& [size, holder] : buffer_pool_) {
if (holder->buf) { if (holder->buf) {
if (!holder->buf->heap()) {
residency_set_.erase(holder->buf);
}
holder->buf->release(); holder->buf->release();
n_release++; n_release++;
} }
@@ -98,6 +108,9 @@ int BufferCache::release_cached_buffers(size_t min_bytes_to_free) {
while (tail_ && (total_bytes_freed < min_bytes_to_free)) { while (tail_ && (total_bytes_freed < min_bytes_to_free)) {
if (tail_->buf) { if (tail_->buf) {
total_bytes_freed += tail_->buf->length(); total_bytes_freed += tail_->buf->length();
if (!tail_->buf->heap()) {
residency_set_.erase(tail_->buf);
}
tail_->buf->release(); tail_->buf->release();
tail_->buf = nullptr; tail_->buf = nullptr;
n_release++; n_release++;
@@ -152,7 +165,7 @@ void BufferCache::remove_from_list(BufferCache::BufferHolder* to_remove) {
MetalAllocator::MetalAllocator() MetalAllocator::MetalAllocator()
: device_(device(mlx::core::Device::gpu).mtl_device()), : device_(device(mlx::core::Device::gpu).mtl_device()),
residency_set_(device_), residency_set_(device_),
buffer_cache_(device_) { buffer_cache_(residency_set_) {
auto pool = metal::new_scoped_memory_pool(); auto pool = metal::new_scoped_memory_pool();
auto memsize = std::get<size_t>(device_info().at("memory_size")); auto memsize = std::get<size_t>(device_info().at("memory_size"));
auto max_rec_size = auto max_rec_size =
@@ -189,16 +202,19 @@ size_t MetalAllocator::set_cache_limit(size_t limit) {
return limit; return limit;
}; };
size_t MetalAllocator::set_memory_limit(size_t limit, bool relaxed) { size_t MetalAllocator::set_memory_limit(size_t limit) {
std::unique_lock lk(mutex_); std::unique_lock lk(mutex_);
std::swap(limit, block_limit_); std::swap(limit, block_limit_);
relaxed_ = relaxed;
gc_limit_ = std::min( gc_limit_ = std::min(
block_limit_, block_limit_,
static_cast<size_t>(0.95 * device_->recommendedMaxWorkingSetSize())); static_cast<size_t>(0.95 * device_->recommendedMaxWorkingSetSize()));
return limit; return limit;
}; };
size_t MetalAllocator::get_memory_limit() {
return block_limit_;
}
size_t MetalAllocator::set_wired_limit(size_t limit) { size_t MetalAllocator::set_wired_limit(size_t limit) {
std::unique_lock lk(mutex_); std::unique_lock lk(mutex_);
std::swap(limit, wired_limit_); std::swap(limit, wired_limit_);
@@ -206,7 +222,7 @@ size_t MetalAllocator::set_wired_limit(size_t limit) {
return limit; return limit;
}; };
Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) { Buffer MetalAllocator::malloc(size_t size) {
// Metal doesn't like empty buffers // Metal doesn't like empty buffers
if (size == 0) { if (size == 0) {
return Buffer{nullptr}; return Buffer{nullptr};
@@ -233,11 +249,6 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
if (!buf) { if (!buf) {
size_t mem_required = get_active_memory() + get_cache_memory() + size; size_t mem_required = get_active_memory() + get_cache_memory() + size;
// If there is too much memory pressure, fail (likely causes a wait).
if (!(allow_swap && relaxed_) && mem_required >= block_limit_) {
return Buffer{nullptr};
}
auto pool = metal::new_scoped_memory_pool(); auto pool = metal::new_scoped_memory_pool();
// If we have a lot of memory pressure or are over the maximum cache size, // If we have a lot of memory pressure or are over the maximum cache size,
@@ -261,9 +272,13 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
if (!buf) { if (!buf) {
buf = device_->newBuffer(size, resource_options); buf = device_->newBuffer(size, resource_options);
} }
if (!buf) {
return Buffer{nullptr};
}
lk.lock(); lk.lock();
if (buf) { num_resources_++;
num_resources_++; if (!buf->heap()) {
residency_set_.insert(buf);
} }
} }
@@ -277,10 +292,6 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
get_cache_memory() - max_pool_size_); get_cache_memory() - max_pool_size_);
} }
if (!buf->heap()) {
residency_set_.insert(buf);
}
return Buffer{static_cast<void*>(buf)}; return Buffer{static_cast<void*>(buf)};
} }
@@ -296,14 +307,14 @@ void MetalAllocator::free(Buffer buffer) {
return; return;
} }
std::unique_lock lk(mutex_); std::unique_lock lk(mutex_);
if (!buf->heap()) {
residency_set_.erase(buf);
}
active_memory_ -= buf->length(); active_memory_ -= buf->length();
if (get_cache_memory() < max_pool_size_) { if (get_cache_memory() < max_pool_size_) {
buffer_cache_.recycle_to_cache(buf); buffer_cache_.recycle_to_cache(buf);
} else { } else {
num_resources_--; num_resources_--;
if (!buf->heap()) {
residency_set_.erase(buf);
}
lk.unlock(); lk.unlock();
auto pool = metal::new_scoped_memory_pool(); auto pool = metal::new_scoped_memory_pool();
buf->release(); buf->release();
@@ -322,37 +333,40 @@ MetalAllocator& allocator() {
   return *allocator_;
 }
+} // namespace metal
 size_t set_cache_limit(size_t limit) {
-  return allocator().set_cache_limit(limit);
+  return metal::allocator().set_cache_limit(limit);
 }
-size_t set_memory_limit(size_t limit, bool relaxed /* = true */) {
-  return allocator().set_memory_limit(limit, relaxed);
+size_t set_memory_limit(size_t limit) {
+  return metal::allocator().set_memory_limit(limit);
+}
+size_t get_memory_limit() {
+  return metal::allocator().get_memory_limit();
 }
 size_t set_wired_limit(size_t limit) {
-  if (limit >
-      std::get<size_t>(device_info().at("max_recommended_working_set_size"))) {
+  if (limit > std::get<size_t>(metal::device_info().at(
+                  "max_recommended_working_set_size"))) {
     throw std::invalid_argument(
         "[metal::set_wired_limit] Setting a wired limit larger than "
         "the maximum working set size is not allowed.");
   }
-  return allocator().set_wired_limit(limit);
+  return metal::allocator().set_wired_limit(limit);
 }
 size_t get_active_memory() {
-  return allocator().get_active_memory();
+  return metal::allocator().get_active_memory();
 }
 size_t get_peak_memory() {
-  return allocator().get_peak_memory();
+  return metal::allocator().get_peak_memory();
 }
 void reset_peak_memory() {
-  allocator().reset_peak_memory();
+  metal::allocator().reset_peak_memory();
 }
 size_t get_cache_memory() {
-  return allocator().get_cache_memory();
+  return metal::allocator().get_cache_memory();
 }
 void clear_cache() {
-  return allocator().clear_cache();
+  return metal::allocator().clear_cache();
 }
-} // namespace metal
 } // namespace mlx::core

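Taken together with the header changes just below, the allocator entry points move out of the metal namespace into mlx::core, set_memory_limit loses its relaxed flag (allocation no longer bails out early under memory pressure), and a get_memory_limit getter appears. A hedged usage sketch of the reshaped API, assuming the declarations live in the newly included mlx/memory.h:

#include <cstddef>

#include "mlx/memory.h" // assumed to declare the limit functions after this change

// Cap resident GPU memory at `bytes`; returns the limit that was in place before.
std::size_t cap_metal_memory(std::size_t bytes) {
  std::size_t previous = mlx::core::set_memory_limit(bytes); // returns the old limit
  (void)mlx::core::get_memory_limit();                       // new getter reads the cap back
  return previous;
}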

@@ -18,7 +18,7 @@ namespace {
 class BufferCache {
  public:
-  BufferCache(MTL::Device* device);
+  BufferCache(ResidencySet& residency_set);
   ~BufferCache();
   MTL::Buffer* reuse_from_cache(size_t size);
@@ -42,13 +42,11 @@ class BufferCache {
   void add_at_head(BufferHolder* to_add);
   void remove_from_list(BufferHolder* to_remove);
-  MTL::Device* device_;
-  MTL::Heap* heap_{nullptr};
   std::multimap<size_t, BufferHolder*> buffer_pool_;
   BufferHolder* head_;
   BufferHolder* tail_;
   size_t pool_size_;
+  ResidencySet& residency_set_;
 };
 } // namespace
@@ -56,7 +54,7 @@ class BufferCache {
 class MetalAllocator : public allocator::Allocator {
   /** Allocator for Metal GPUs. */
  public:
-  virtual Buffer malloc(size_t size, bool allow_swap = false) override;
+  virtual Buffer malloc(size_t size) override;
   virtual void free(Buffer buffer) override;
   virtual size_t size(Buffer buffer) const override;
   size_t get_active_memory() {
@@ -73,7 +71,8 @@ class MetalAllocator : public allocator::Allocator {
     return buffer_cache_.cache_size();
   };
   size_t set_cache_limit(size_t limit);
-  size_t set_memory_limit(size_t limit, bool relaxed);
+  size_t set_memory_limit(size_t limit);
+  size_t get_memory_limit();
   size_t set_wired_limit(size_t limit);
   void clear_cache();


@@ -102,16 +102,9 @@ void binary_op_gpu_inplace(
   auto& compute_encoder = d.get_command_encoder(s.index);
   compute_encoder.set_compute_pipeline_state(kernel);
-  // - If a is donated it goes to the first output
-  // - If b is donated it goes to the first output if a was not donated
-  //   otherwise it goes to the second output.
-  // - If there is only one output only one of a and b will be donated.
-  bool donate_a = a.data_shared_ptr() == nullptr;
-  bool donate_b = b.data_shared_ptr() == nullptr;
   int arg_idx = 0;
-  compute_encoder.set_input_array(donate_a ? outputs[0] : a, arg_idx++);
-  compute_encoder.set_input_array(
-      donate_b ? (donate_a ? outputs[1] : outputs[0]) : b, arg_idx++);
+  compute_encoder.set_input_array(a, arg_idx++);
+  compute_encoder.set_input_array(b, arg_idx++);
   compute_encoder.set_output_array(outputs[0], arg_idx++);
   if (outputs.size() == 2) {
     compute_encoder.set_output_array(outputs[1], arg_idx++);
@@ -164,8 +157,8 @@ void binary_op_gpu(
   auto& a = inputs[0];
   auto& b = inputs[1];
   auto bopt = get_binary_op_type(a, b);
-  set_binary_op_output_data(a, b, outputs[0], bopt, true);
-  set_binary_op_output_data(a, b, outputs[1], bopt, true);
+  set_binary_op_output_data(a, b, outputs[0], bopt);
+  set_binary_op_output_data(a, b, outputs[1], bopt);
   binary_op_gpu_inplace(inputs, outputs, op, s);
 }
@@ -195,7 +188,7 @@ void binary_op_gpu(
   auto& a = inputs[0];
   auto& b = inputs[1];
   auto bopt = get_binary_op_type(a, b);
-  set_binary_op_output_data(a, b, out, bopt, true);
+  set_binary_op_output_data(a, b, out, bopt);
   binary_op_gpu_inplace(inputs, out, op, s);
 }


@@ -457,7 +457,7 @@ void Compiled::eval_gpu(
   }
   compiled_allocate_outputs(
-      inputs, outputs, inputs_, constant_ids_, contiguous, true);
+      inputs, outputs, inputs_, constant_ids_, contiguous);
   // Put the outputs in
   for (auto& x : outputs) {


@@ -37,7 +37,7 @@ void explicit_gemm_conv_ND_gpu(
Shape unfolded_shape{implicit_M, implicit_K}; Shape unfolded_shape{implicit_M, implicit_K};
array in_unfolded(unfolded_shape, in.dtype(), nullptr, {}); array in_unfolded(unfolded_shape, in.dtype(), nullptr, {});
in_unfolded.set_data(allocator::malloc_or_wait(in_unfolded.nbytes())); in_unfolded.set_data(allocator::malloc(in_unfolded.nbytes()));
// Prepare unfolding kernel // Prepare unfolding kernel
std::ostringstream kname; std::ostringstream kname;
@@ -115,7 +115,7 @@ void explicit_gemm_conv_group_ND_gpu(
// Prepare unfolding array // Prepare unfolding array
Shape unfolded_shape{implicit_M, implicit_K * groups}; Shape unfolded_shape{implicit_M, implicit_K * groups};
array in_unfolded(unfolded_shape, in.dtype(), nullptr, {}); array in_unfolded(unfolded_shape, in.dtype(), nullptr, {});
in_unfolded.set_data(allocator::malloc_or_wait(in_unfolded.nbytes())); in_unfolded.set_data(allocator::malloc(in_unfolded.nbytes()));
// Prepare unfolding kernel // Prepare unfolding kernel
std::ostringstream kname; std::ostringstream kname;
@@ -613,7 +613,7 @@ void winograd_conv_2D_gpu(
// Do filter transform // Do filter transform
Shape filt_wg_shape = {8 * 8, conv_params.C, conv_params.O}; Shape filt_wg_shape = {8 * 8, conv_params.C, conv_params.O};
array filt_wg(std::move(filt_wg_shape), wt.dtype(), nullptr, {}); array filt_wg(std::move(filt_wg_shape), wt.dtype(), nullptr, {});
filt_wg.set_data(allocator::malloc_or_wait(filt_wg.nbytes())); filt_wg.set_data(allocator::malloc(filt_wg.nbytes()));
copies_w.push_back(filt_wg); copies_w.push_back(filt_wg);
{ {
int bc = 32; int bc = 32;
@@ -640,7 +640,7 @@ void winograd_conv_2D_gpu(
// Do input transform // Do input transform
Shape inp_wg_shape = {8 * 8, N_tiles, conv_params.C}; Shape inp_wg_shape = {8 * 8, N_tiles, conv_params.C};
array inp_wg(std::move(inp_wg_shape), in.dtype(), nullptr, {}); array inp_wg(std::move(inp_wg_shape), in.dtype(), nullptr, {});
inp_wg.set_data(allocator::malloc_or_wait(inp_wg.nbytes())); inp_wg.set_data(allocator::malloc(inp_wg.nbytes()));
copies_w.push_back(inp_wg); copies_w.push_back(inp_wg);
{ {
int bc = 32; int bc = 32;
@@ -667,7 +667,7 @@ void winograd_conv_2D_gpu(
// Do batched gemm // Do batched gemm
Shape out_wg_shape = {8 * 8, N_tiles, conv_params.O}; Shape out_wg_shape = {8 * 8, N_tiles, conv_params.O};
array out_wg(std::move(out_wg_shape), in.dtype(), nullptr, {}); array out_wg(std::move(out_wg_shape), in.dtype(), nullptr, {});
out_wg.set_data(allocator::malloc_or_wait(out_wg.nbytes())); out_wg.set_data(allocator::malloc(out_wg.nbytes()));
copies_w.push_back(out_wg); copies_w.push_back(out_wg);
{ {
std::vector<array> empty_copies; std::vector<array> empty_copies;
@@ -712,6 +712,65 @@ void winograd_conv_2D_gpu(
} }
} }
void depthwise_conv_2D_gpu(
const Stream& s,
metal::Device& d,
const array& in,
const array& wt,
array out,
const MLXConvParams<2>& conv_params) {
std::ostringstream kname;
kname << "depthwise_conv_2d_" << type_to_name(out);
std::string base_name = kname.str();
const int N = conv_params.N;
const int ker_h = conv_params.wS[0];
const int ker_w = conv_params.wS[1];
const int str_h = conv_params.str[0];
const int str_w = conv_params.str[1];
const int tc = 8;
const int tw = 8;
const int th = 4;
const bool do_flip = conv_params.flip;
metal::MTLFCList func_consts = {
{&ker_h, MTL::DataType::DataTypeInt, 00},
{&ker_w, MTL::DataType::DataTypeInt, 01},
{&str_h, MTL::DataType::DataTypeInt, 10},
{&str_w, MTL::DataType::DataTypeInt, 11},
{&th, MTL::DataType::DataTypeInt, 100},
{&tw, MTL::DataType::DataTypeInt, 101},
{&do_flip, MTL::DataType::DataTypeBool, 200},
};
// clang-format off
kname << "_ker_h_" << ker_h
<< "_ker_w_" << ker_w
<< "_str_h_" << str_h
<< "_str_w_" << str_w
<< "_tgp_h_" << th
<< "_tgp_w_" << tw
<< "_do_flip_" << (do_flip ? 't' : 'n'); // clang-format on
std::string hash_name = kname.str();
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(base_name, "mlx", hash_name, func_consts);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(in, 0);
compute_encoder.set_input_array(wt, 1);
compute_encoder.set_output_array(out, 2);
compute_encoder.set_bytes(conv_params, 3);
MTL::Size group_dims = MTL::Size(tc, tw, th);
MTL::Size grid_dims = MTL::Size(
conv_params.C / tc, conv_params.oS[1] / tw, (conv_params.oS[0] / th) * N);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
void conv_2D_gpu( void conv_2D_gpu(
const Stream& s, const Stream& s,
metal::Device& d, metal::Device& d,
@@ -754,11 +813,20 @@ void conv_2D_gpu(
   bool is_kdil_one = conv_params.kdil[0] == 1 && conv_params.kdil[1] == 1;
   bool is_idil_one = conv_params.idil[0] == 1 && conv_params.idil[1] == 1;
-  if (groups > 1) {
+  if (is_idil_one && groups > 1) {
     const int C_per_group = conv_params.C / groups;
     const int O_per_group = conv_params.O / groups;
-    if (is_idil_one && (C_per_group <= 4 || C_per_group % 16 == 0) &&
+    if (C_per_group == 1 && O_per_group == 1 && is_kdil_one &&
+        conv_params.wS[0] <= 7 && conv_params.wS[1] <= 7 &&
+        conv_params.str[0] <= 2 && conv_params.str[1] <= 2 &&
+        conv_params.oS[0] % 8 == 0 && conv_params.oS[1] % 8 == 0 &&
+        conv_params.wt_strides[1] == conv_params.wS[1] &&
+        conv_params.C % 16 == 0 && conv_params.C == conv_params.O) {
+      return depthwise_conv_2D_gpu(s, d, in, wt, out, conv_params);
+    }
+    if ((C_per_group <= 4 || C_per_group % 16 == 0) &&
         (O_per_group <= 16 || O_per_group % 16 == 0)) {
       return implicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params);
     } else {
@@ -855,7 +923,7 @@ void conv_3D_gpu(
} // namespace } // namespace
void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) { void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
out.set_data(allocator::malloc_or_wait(out.nbytes())); out.set_data(allocator::malloc(out.nbytes()));
auto& s = stream(); auto& s = stream();
auto& d = metal::device(s.device); auto& d = metal::device(s.device);

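Read off the new dispatch logic above, the specialized depthwise kernel is only taken for true depthwise convolutions (one input and one output channel per group) with small undilated kernels, stride at most 2, 8-aligned output tiles, densely packed weights, and a channel count that is a multiple of 16. Restated as a standalone predicate (a hedged sketch; the struct below is a simplified stand-in for MLXConvParams<2>):

#include <array>

// Simplified stand-in for MLXConvParams<2>; field names follow the diff above.
struct Conv2DParams {
  int C;                         // input channels
  int O;                         // output channels
  int groups;                    // number of groups
  std::array<int, 2> wS;         // kernel (weight) spatial size
  std::array<int, 2> str;        // strides
  std::array<int, 2> oS;         // output spatial size
  std::array<int, 2> kdil;       // kernel dilation
  std::array<int, 2> idil;       // input dilation
  std::array<int, 2> wt_strides; // last two weight strides
};

// Hedged restatement of when the new depthwise kernel is selected.
bool use_depthwise_kernel(const Conv2DParams& p) {
  const bool idil_one = p.idil[0] == 1 && p.idil[1] == 1;
  const bool kdil_one = p.kdil[0] == 1 && p.kdil[1] == 1;
  const int C_per_group = p.C / p.groups;
  const int O_per_group = p.O / p.groups;
  return idil_one && p.groups > 1 &&
      C_per_group == 1 && O_per_group == 1 &&  // true depthwise
      kdil_one &&                              // no kernel dilation
      p.wS[0] <= 7 && p.wS[1] <= 7 &&          // small kernel
      p.str[0] <= 2 && p.str[1] <= 2 &&        // small stride
      p.oS[0] % 8 == 0 && p.oS[1] % 8 == 0 &&  // output tiles of 8
      p.wt_strides[1] == p.wS[1] &&            // densely packed weights
      p.C % 16 == 0 && p.C == p.O;             // channels a multiple of 16
}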

@@ -14,25 +14,11 @@ namespace mlx::core {
 constexpr int MAX_COPY_SPECIALIZED_DIMS = 3;
 void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) {
-  if (ctype == CopyType::Vector) {
-    // If the input is donateable, we are doing a vector copy and the types
-    // have the same size, then the input buffer can hold the output.
-    if (in.is_donatable() && in.itemsize() == out.itemsize()) {
-      out.move_shared_buffer(in);
-      // If the output has the same type as the input then there is nothing to
-      // copy, just use the buffer.
-      if (in.dtype() == out.dtype()) {
-        return;
-      }
-    } else {
-      out.set_data(
-          allocator::malloc_or_wait(in.data_size() * out.itemsize()),
-          in.data_size(),
-          in.strides(),
-          in.flags());
-    }
-  } else {
-    out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  bool donated = set_copy_output_data(in, out, ctype);
+  if (donated && in.dtype() == out.dtype()) {
+    // If the output has the same type as the input then there is nothing to
+    // copy, just use the buffer.
+    return;
   }
   if (ctype == CopyType::GeneralGeneral) {
     ctype = CopyType::General;
@@ -216,7 +202,7 @@ void fill_gpu(const array& val, array& out, const Stream& s) {
   if (out.size() == 0) {
     return;
   }
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(allocator::malloc(out.nbytes()));
   bool large = out.data_size() > UINT32_MAX;
   auto& d = metal::device(s.device);
   std::string kernel_name = std::string(large ? "s2" : "s") + "_copy" +

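The donation handling that used to live inline in copy_gpu is now folded into set_copy_output_data, which reports whether the input buffer was handed to the output; when it was and the dtypes already match there is nothing left to copy. A hedged sketch of that decision, using a stub type rather than mlx::core::array:

#include <cstddef>

struct BufferStub {
  int dtype = 0;
  std::size_t itemsize = 4;
  bool donatable = false;
};

// Returns true when `out` can simply take over `in`'s buffer (a "donation"),
// mirroring the bool that set_copy_output_data now reports back to copy_gpu.
bool give_output_the_input_buffer(const BufferStub& in, BufferStub& out, bool vector_copy) {
  if (vector_copy && in.donatable && in.itemsize == out.itemsize) {
    out = in; // stand-in for out.move_shared_buffer(in)
    return true;
  }
  return false; // caller allocates a fresh buffer instead (elided here)
}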

@@ -19,7 +19,7 @@ void CustomKernel::eval_gpu(
     copies.emplace_back(init_value_.value(), out.dtype());
     fill_gpu(copies.back(), out, s);
   } else {
-    out.set_data(allocator::malloc_or_wait(out.nbytes()));
+    out.set_data(allocator::malloc(out.nbytes()));
   }
 }


@@ -19,9 +19,6 @@ namespace mlx::core::metal {
 namespace {
-// TODO nicer way to set this or possibly expose as an environment variable
-constexpr int MAX_BUFFERS_PER_QUEUE = 12;
 constexpr const char* default_mtllib_path = METAL_PATH;
 auto get_metal_version() {
@@ -58,7 +55,10 @@ std::pair<MTL::Library*, NS::Error*> load_library_from_path(
 }
 #ifdef SWIFTPM_BUNDLE
-MTL::Library* try_load_bundle(MTL::Device* device, NS::URL* url) {
+MTL::Library* try_load_bundle(
+    MTL::Device* device,
+    NS::URL* url,
+    const std::string& lib_name) {
   std::string bundle_path = std::string(url->fileSystemRepresentation()) + "/" +
       SWIFTPM_BUNDLE + ".bundle";
   auto bundle = NS::Bundle::alloc()->init(
@@ -66,8 +66,8 @@ MTL::Library* try_load_bundle(MTL::Device* device, NS::URL* url) {
   if (bundle != nullptr) {
     std::string resource_path =
         std::string(bundle->resourceURL()->fileSystemRepresentation()) + "/" +
-        "default.metallib";
-    auto [lib, error] = load_library_from_path(device, resource_path.c_str());
+        lib_name + ".metallib";
+    auto [lib, error] =
+        load_library_from_path(device, resource_path.c_str());
     if (lib) {
       return lib;
     }
@@ -76,51 +76,124 @@ MTL::Library* try_load_bundle(MTL::Device* device, NS::URL* url) {
} }
#endif #endif
// Firstly, search for the metallib in the same path as this binary
std::pair<MTL::Library*, NS::Error*> load_colocated_library(
MTL::Device* device,
const std::string& lib_name) {
std::string lib_path = get_colocated_mtllib_path(lib_name);
if (lib_path.size() != 0) {
return load_library_from_path(device, lib_path.c_str());
}
return {nullptr, nullptr};
}
std::pair<MTL::Library*, NS::Error*> load_swiftpm_library(
MTL::Device* device,
const std::string& lib_name) {
#ifdef SWIFTPM_BUNDLE
MTL::Library* library =
try_load_bundle(device, NS::Bundle::mainBundle()->bundleURL(), lib_name);
if (library != nullptr) {
return {library, nullptr};
}
auto bundles = NS::Bundle::allBundles();
for (int i = 0, c = (int)bundles->count(); i < c; i++) {
auto bundle = reinterpret_cast<NS::Bundle*>(bundles->object(i));
library = try_load_bundle(device, bundle->resourceURL());
if (library != nullptr) {
return {library, nullptr};
}
}
#endif
return {nullptr, nullptr};
}
MTL::Library* load_default_library(MTL::Device* device) {
NS::Error *error1, *error2, *error3;
MTL::Library* lib;
// First try the colocated mlx.metallib
std::tie(lib, error1) = load_colocated_library(device, "mlx");
if (lib) {
return lib;
}
// Then try default.metallib in a SwiftPM bundle if we have one
std::tie(lib, error2) = load_swiftpm_library(device, "default");
if (lib) {
return lib;
}
// Finally try default_mtllib_path
std::tie(lib, error3) = load_library_from_path(device, default_mtllib_path);
if (!lib) {
std::ostringstream msg;
msg << "Failed to load the default metallib. ";
if (error1 != nullptr) {
msg << error1->localizedDescription()->utf8String() << " ";
}
if (error2 != nullptr) {
msg << error2->localizedDescription()->utf8String() << " ";
}
if (error3 != nullptr) {
msg << error3->localizedDescription()->utf8String() << " ";
}
throw std::runtime_error(msg.str());
}
return lib;
}
MTL::Library* load_library( MTL::Library* load_library(
MTL::Device* device, MTL::Device* device,
const std::string& lib_name = "mlx", const std::string& lib_name,
const char* lib_path = default_mtllib_path) { const std::string& lib_path) {
// Firstly, search for the metallib in the same path as this binary // We have been given a path that ends in metallib so try to load it
std::string first_path = get_colocated_mtllib_path(lib_name); if (lib_path.size() > 9 &&
if (first_path.size() != 0) { std::equal(lib_path.end() - 9, lib_path.end(), ".metallib")) {
auto [lib, error] = load_library_from_path(device, first_path.c_str()); auto [lib, error] = load_library_from_path(device, lib_path.c_str());
if (!lib) {
std::ostringstream msg;
msg << "Failed to load the metallib from <" << lib_path << "> with error "
<< error->localizedDescription()->utf8String();
throw std::runtime_error(msg.str());
}
}
// We have been given a path so try to load from lib_path / lib_name.metallib
if (lib_path.size() > 0) {
std::string full_path = lib_path + "/" + lib_name + ".metallib";
auto [lib, error] = load_library_from_path(device, full_path.c_str());
if (!lib) {
std::ostringstream msg;
msg << "Failed to load the metallib from <" << full_path
<< "> with error " << error->localizedDescription()->utf8String();
throw std::runtime_error(msg.str());
}
}
// Try to load the colocated library
{
auto [lib, error] = load_colocated_library(device, lib_name);
if (lib) { if (lib) {
return lib; return lib;
} }
} }
#ifdef SWIFTPM_BUNDLE // Try to load the library from swiftpm
// try to load from a swiftpm resource bundle -- scan the available bundles to
// find one that contains the named bundle
{ {
MTL::Library* library = auto [lib, error] = load_swiftpm_library(device, lib_name);
try_load_bundle(device, NS::Bundle::mainBundle()->bundleURL()); if (lib) {
if (library != nullptr) { return lib;
return library;
}
auto bundles = NS::Bundle::allBundles();
for (int i = 0, c = (int)bundles->count(); i < c; i++) {
auto bundle = reinterpret_cast<NS::Bundle*>(bundles->object(i));
library = try_load_bundle(device, bundle->resourceURL());
if (library != nullptr) {
return library;
}
} }
} }
#endif
// Couldn't find it so let's load it from default_mtllib_path std::ostringstream msg;
{ msg << "Failed to load the metallib " << lib_name << ".metallib. "
auto [lib, error] = load_library_from_path(device, lib_path); << "We attempted to load it from <" << get_colocated_mtllib_path(lib_name)
if (!lib) { << ">";
std::ostringstream msg; #ifdef SWIFTPM_BUNDLE
msg << error->localizedDescription()->utf8String() << "\n" msg << " and from the Swift PM bundle.";
<< "Failed to load device library from <" << lib_path << ">" #endif
<< " or <" << first_path << ">."; throw std::runtime_error(msg.str());
throw std::runtime_error(msg.str());
}
return lib;
}
} }
} // namespace } // namespace
@@ -168,9 +241,10 @@ void CommandEncoder::set_output_array(
   register_output_array(a);
 }
-void CommandEncoder::register_output_array(array& a) {
+void CommandEncoder::register_output_array(const array& a) {
   all_outputs_.insert(a.buffer().ptr());
-  auto buf = static_cast<MTL::Resource*>(a.buffer().ptr());
+  auto buf = static_cast<MTL::Resource*>(const_cast<void*>(a.buffer().ptr()));
   if (concurrent_) {
     concurrent_outputs_.insert(buf);
   } else {
@@ -212,7 +286,7 @@ void CommandEncoder::barrier() {
 Device::Device() {
   auto pool = new_scoped_memory_pool();
   device_ = load_device();
-  library_map_ = {{"mlx", load_library(device_)}};
+  library_map_ = {{"mlx", load_default_library(device_)}};
   arch_ = std::string(device_->architecture()->name()->utf8String());
   auto arch = arch_.back();
   switch (arch) {
@@ -255,7 +329,7 @@ Device::~Device() {
 void Device::new_queue(int index) {
   auto thread_pool = metal::new_scoped_memory_pool();
-  auto q = device_->newCommandQueue(MAX_BUFFERS_PER_QUEUE);
+  auto q = device_->newCommandQueue();
   debug_set_stream_queue_label(q, index);
   if (!q) {
     throw std::runtime_error(

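The refactor above splits library loading into load_colocated_library, load_swiftpm_library, and a load_default_library that tries the colocated mlx.metallib, then the SwiftPM bundle, then the compile-time default path, and only throws with the accumulated errors once all three fail. A hedged, self-contained sketch of that fallback chain, with generic loader callbacks standing in for the real helpers:

#include <functional>
#include <optional>
#include <stdexcept>
#include <string>
#include <vector>

// Each loader returns a "library" on success; the real code returns an
// MTL::Library* together with an NS::Error*.
using Loader = std::function<std::optional<std::string>()>;

std::string load_first_available(const std::vector<Loader>& loaders) {
  std::string errors = "Failed to load the default metallib. ";
  for (const auto& load : loaders) {
    if (auto lib = load()) {
      return *lib; // the first loader that succeeds wins
    }
    errors += "one candidate path failed; ";
  }
  throw std::runtime_error(errors); // mirrors the aggregated error message
}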

@@ -62,7 +62,7 @@ struct CommandEncoder {
   void set_input_array(const array& a, int idx, int64_t offset = 0);
   void set_output_array(array& a, int idx, int64_t offset = 0);
-  void register_output_array(array& a);
+  void register_output_array(const array& a);
   void dispatch_threadgroups(MTL::Size grid_dims, MTL::Size group_dims);
   void dispatch_threads(MTL::Size grid_dims, MTL::Size group_dims);
   void maybeInsertBarrier();
@@ -189,15 +189,7 @@ class Device {
   void register_library(
       const std::string& lib_name,
-      const std::string& lib_path);
-
-  // Note, this should remain in the header so that it is not dynamically
-  // linked
-  void register_library(const std::string& lib_name) {
-    if (auto it = library_map_.find(lib_name); it == library_map_.end()) {
-      register_library(lib_name, get_colocated_mtllib_path(lib_name));
-    }
-  }
+      const std::string& lib_path = "");
   MTL::Library* get_library(
       const std::string& name,


@@ -4,149 +4,30 @@
#include "mlx/allocator.h" #include "mlx/allocator.h"
#include "mlx/backend/common/utils.h" #include "mlx/backend/common/utils.h"
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h" #include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/event.h"
#include "mlx/backend/metal/fence.h"
#include "mlx/backend/metal/utils.h" #include "mlx/backend/metal/utils.h"
#include "mlx/distributed/ops.h" #include "mlx/distributed/ops.h"
#include "mlx/distributed/primitives.h" #include "mlx/distributed/primitives.h"
#include "mlx/fence.h"
#include "mlx/scheduler.h" #include "mlx/scheduler.h"
namespace mlx::core::distributed { namespace mlx::core::distributed {
void signal_and_wait(const Event& e_signal, const Event& e_wait) { void AllReduce::eval_gpu(const std::vector<array>&, std::vector<array>&) {
if (e_signal.valid()) { throw std::runtime_error("[AllReduce::eval_gpu] has no GPU implementation.");
encode_signal(e_signal);
}
encode_wait(e_wait);
} }
void AllReduce::eval_gpu( void AllGather::eval_gpu(const std::vector<array>&, std::vector<array>&) {
const std::vector<array>& inputs, throw std::runtime_error("[AllGather::eval_gpu] has no GPU implementation.");
std::vector<array>& outputs) {
assert(inputs.size() == 1);
assert(outputs.size() == 1);
auto& in = inputs[0];
Fence f{stream()};
if (in.event().valid()) {
f.update_gpu(in);
}
auto& out = outputs[0];
if (in.is_donatable()) {
out.move_shared_buffer(in);
} else {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
}
f.wait_gpu(out);
auto task = [in = in,
out = unsafe_weak_copy(out),
f = std::move(f),
reduce_type = reduce_type_,
group = group()]() mutable {
if (in.event().valid()) {
f.wait();
}
switch (reduce_type) {
case Sum:
distributed::detail::all_sum(
group, in.data_shared_ptr() == nullptr ? out : in, out);
break;
default:
throw std::runtime_error("Only all reduce sum is supported for now");
}
f.update();
};
scheduler::enqueue(detail::communication_stream(), std::move(task));
} }
void AllGather::eval_gpu( void Send::eval_gpu(const std::vector<array>&, std::vector<array>&) {
const std::vector<array>& inputs, throw std::runtime_error("[Send::eval_gpu] has no GPU implementation.");
std::vector<array>& outputs) {
assert(inputs.size() == 1);
assert(outputs.size() == 1);
auto& in = inputs[0];
auto& out = outputs[0];
out.set_data(allocator::malloc_or_wait(out.nbytes()));
Fence f{stream()};
if (in.event().valid()) {
f.update_gpu(in);
}
f.wait_gpu(out);
auto task = [in = in,
out = unsafe_weak_copy(out),
f = std::move(f),
group = group()]() mutable {
if (in.event().valid()) {
f.wait();
}
distributed::detail::all_gather(group, in, out);
f.update();
};
scheduler::enqueue(detail::communication_stream(), std::move(task));
} }
void Send::eval_gpu( void Recv::eval_gpu(const std::vector<array>&, std::vector<array>&) {
const std::vector<array>& inputs, throw std::runtime_error("[Recv::eval_gpu] has no GPU implementation.");
std::vector<array>& outputs) {
assert(inputs.size() == 1);
assert(outputs.size() == 1);
auto& in = inputs[0];
// Encode a signal event for the input
Fence f{stream()};
if (in.event().valid()) {
f.update_gpu(in);
}
auto& out = outputs[0];
move_or_copy(in, out);
// Schedule an async send on the comm stream
auto task = [in = in,
out = unsafe_weak_copy(out),
f = std::move(f),
group = group(),
dst = dst_]() mutable {
if (in.event().valid()) {
f.wait();
}
distributed::detail::send(group, out, dst);
};
scheduler::enqueue(detail::communication_stream(), std::move(task));
}
void Recv::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() == 0);
assert(outputs.size() == 1);
auto& out = outputs[0];
out.set_data(allocator::malloc_or_wait(out.nbytes()));
Fence f{stream()};
f.wait_gpu(out);
// Schedule an async recv on the comm stream
auto task = [out = unsafe_weak_copy(out),
f = std::move(f),
group = group(),
src = src_]() mutable {
distributed::detail::recv(group, out, src);
f.update();
};
scheduler::enqueue(detail::communication_stream(), std::move(task));
} }
} // namespace mlx::core::distributed } // namespace mlx::core::distributed

Some files were not shown because too many files have changed in this diff.